Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,8 @@ response = requests.post(
- m4a
- pcm

Implementation note: Opus output is encoded on a 48 kHz codec clock. When streaming to an in-memory buffer, the final Ogg/Opus pages are only available after the muxer is closed, so the last response chunk must be read after finalization to avoid truncating the tail of the clip.

<p align="center">
<img src="assets/format_comparison.png" width="80%" alt="Audio Format Comparison" style="border: 2px solid #333; padding: 10px;">
</p>
Expand Down
48 changes: 37 additions & 11 deletions api/src/services/streaming_audio_writer.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
"""Audio conversion service with proper streaming support"""

import struct
from io import BytesIO
from typing import Optional

import av
import numpy as np
import soundfile as sf
from loguru import logger
from pydub import AudioSegment


class StreamingAudioWriter:
Expand All @@ -20,6 +17,8 @@ def __init__(self, format: str, sample_rate: int, channels: int = 1):
self.channels = channels
self.bytes_written = 0
self.pts = 0
# Opus is muxed on a 48 kHz clock even when the source PCM is 24 kHz.
self.codec_sample_rate = 48000 if self.format == "opus" else self.sample_rate

codec_map = {
"wav": "pcm_s16le",
Expand Down Expand Up @@ -47,21 +46,31 @@ def __init__(self, format: str, sample_rate: int, channels: int = 1):
)
self.stream = self.container.add_stream(
codec_map[self.format],
rate=self.sample_rate,
rate=self.codec_sample_rate,
layout="mono" if self.channels == 1 else "stereo",
)
# Set bit_rate only for codecs where it's applicable and useful
if self.format in ['mp3', 'aac', 'opus']:
self.stream.bit_rate = 128000

if self.format == "opus":
# Resample the model's 24 kHz PCM into the codec clock expected by Opus.
self.resampler = av.AudioResampler(
format="s16",
layout="mono" if self.channels == 1 else "stereo",
rate=self.codec_sample_rate,
)
else:
raise ValueError(f"Unsupported format: {self.format}") # Use self.format here

def close(self):
if hasattr(self, "container"):
self.container.close()
del self.container

if hasattr(self, "output_buffer"):
self.output_buffer.close()
del self.output_buffer

def write_chunk(
self, audio_data: Optional[np.ndarray] = None, finalize: bool = False
Expand All @@ -84,9 +93,15 @@ def write_chunk(
# No explicit flush method is available or needed here.
logger.debug("Muxed final packets.")

# The Opus/Ogg muxer keeps the final pages in memory until close().
# Reading the buffer before close truncates the tail of the stream.
self.container.close()
del self.container

# Get the final bytes from the buffer *before* closing it
data = self.output_buffer.getvalue()
self.close() # Close container and buffer
self.output_buffer.close()
del self.output_buffer
return data

if audio_data is None or len(audio_data) == 0:
Expand All @@ -103,12 +118,23 @@ def write_chunk(
)
frame.sample_rate = self.sample_rate

frame.pts = self.pts
self.pts += frame.samples

packets = self.stream.encode(frame)
for packet in packets:
self.container.mux(packet)
frames_to_encode = [frame]
if self.format == "opus":
resampled = self.resampler.resample(frame)
if resampled is None:
frames_to_encode = []
elif isinstance(resampled, list):
frames_to_encode = resampled
else:
frames_to_encode = [resampled]

for encode_frame in frames_to_encode:
encode_frame.pts = self.pts
self.pts += encode_frame.samples

packets = self.stream.encode(encode_frame)
for packet in packets:
self.container.mux(packet)

data = self.output_buffer.getvalue()
self.output_buffer.seek(0)
Expand Down
31 changes: 26 additions & 5 deletions api/tests/test_audio_service.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Tests for AudioService"""

from io import BytesIO
from unittest.mock import patch

import av
import numpy as np
import pytest

Expand Down Expand Up @@ -80,16 +82,35 @@ async def test_convert_to_opus(sample_audio):
writer = StreamingAudioWriter("opus", sample_rate=24000)

audio_chunk = await AudioService.convert_audio(
AudioChunk(audio_data), "opus", writer
AudioChunk(audio_data), "opus", writer, is_last_chunk=False, trim_audio=False
)
final_chunk = await AudioService.convert_audio(
AudioChunk(np.array([], dtype=np.int16)),
"opus",
writer,
is_last_chunk=True,
trim_audio=False,
)

writer.close()
encoded = audio_chunk.output + final_chunk.output

assert isinstance(audio_chunk.output, bytes)
assert isinstance(encoded, bytes)
assert isinstance(audio_chunk, AudioChunk)
assert len(audio_chunk.output) > 0
assert len(encoded) > 0
# Check OGG header
assert audio_chunk.output.startswith(b"OggS")
assert encoded.startswith(b"OggS")

with av.open(BytesIO(encoded), mode="r", format="ogg") as container:
decoded_samples = 0
decoded_rate = None
for frame in container.decode(audio=0):
decoded_samples += frame.samples
decoded_rate = frame.sample_rate

assert decoded_rate == 48000
assert decoded_samples / decoded_rate == pytest.approx(
len(audio_data) / sample_rate, abs=0.03
)


@pytest.mark.asyncio
Expand Down