From 5896a32da7a8bc9e2df4ba117c0a04edf71665ee Mon Sep 17 00:00:00 2001 From: Slime Turing Date: Sat, 7 Mar 2026 14:43:12 +0800 Subject: [PATCH] Fix truncated Opus tail in streaming writer --- README.md | 2 + api/src/services/streaming_audio_writer.py | 48 +++++++++++++++++----- api/tests/test_audio_service.py | 31 +++++++++++--- 3 files changed, 65 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 60a30c40..64e5a908 100644 --- a/README.md +++ b/README.md @@ -256,6 +256,8 @@ response = requests.post( - m4a - pcm +Implementation note: Opus output is encoded on a 48 kHz codec clock. When streaming to an in-memory buffer, the final Ogg/Opus pages are only available after the muxer is closed, so the last response chunk must be read after finalization to avoid truncating the tail of the clip. +

Audio Format Comparison

diff --git a/api/src/services/streaming_audio_writer.py b/api/src/services/streaming_audio_writer.py index de9c84e3..b60893da 100644 --- a/api/src/services/streaming_audio_writer.py +++ b/api/src/services/streaming_audio_writer.py @@ -1,14 +1,11 @@ """Audio conversion service with proper streaming support""" -import struct from io import BytesIO from typing import Optional import av import numpy as np -import soundfile as sf from loguru import logger -from pydub import AudioSegment class StreamingAudioWriter: @@ -20,6 +17,8 @@ def __init__(self, format: str, sample_rate: int, channels: int = 1): self.channels = channels self.bytes_written = 0 self.pts = 0 + # Opus is muxed on a 48 kHz clock even when the source PCM is 24 kHz. + self.codec_sample_rate = 48000 if self.format == "opus" else self.sample_rate codec_map = { "wav": "pcm_s16le", @@ -47,21 +46,31 @@ def __init__(self, format: str, sample_rate: int, channels: int = 1): ) self.stream = self.container.add_stream( codec_map[self.format], - rate=self.sample_rate, + rate=self.codec_sample_rate, layout="mono" if self.channels == 1 else "stereo", ) # Set bit_rate only for codecs where it's applicable and useful if self.format in ['mp3', 'aac', 'opus']: self.stream.bit_rate = 128000 + + if self.format == "opus": + # Resample the model's 24 kHz PCM into the codec clock expected by Opus. + self.resampler = av.AudioResampler( + format="s16", + layout="mono" if self.channels == 1 else "stereo", + rate=self.codec_sample_rate, + ) else: raise ValueError(f"Unsupported format: {self.format}") # Use self.format here def close(self): if hasattr(self, "container"): self.container.close() + del self.container if hasattr(self, "output_buffer"): self.output_buffer.close() + del self.output_buffer def write_chunk( self, audio_data: Optional[np.ndarray] = None, finalize: bool = False @@ -84,9 +93,15 @@ def write_chunk( # No explicit flush method is available or needed here. logger.debug("Muxed final packets.") + # The Opus/Ogg muxer keeps the final pages in memory until close(). + # Reading the buffer before close truncates the tail of the stream. + self.container.close() + del self.container + # Get the final bytes from the buffer *before* closing it data = self.output_buffer.getvalue() - self.close() # Close container and buffer + self.output_buffer.close() + del self.output_buffer return data if audio_data is None or len(audio_data) == 0: @@ -103,12 +118,23 @@ def write_chunk( ) frame.sample_rate = self.sample_rate - frame.pts = self.pts - self.pts += frame.samples - - packets = self.stream.encode(frame) - for packet in packets: - self.container.mux(packet) + frames_to_encode = [frame] + if self.format == "opus": + resampled = self.resampler.resample(frame) + if resampled is None: + frames_to_encode = [] + elif isinstance(resampled, list): + frames_to_encode = resampled + else: + frames_to_encode = [resampled] + + for encode_frame in frames_to_encode: + encode_frame.pts = self.pts + self.pts += encode_frame.samples + + packets = self.stream.encode(encode_frame) + for packet in packets: + self.container.mux(packet) data = self.output_buffer.getvalue() self.output_buffer.seek(0) diff --git a/api/tests/test_audio_service.py b/api/tests/test_audio_service.py index 5ba53928..6236fe2e 100644 --- a/api/tests/test_audio_service.py +++ b/api/tests/test_audio_service.py @@ -1,7 +1,9 @@ """Tests for AudioService""" +from io import BytesIO from unittest.mock import patch +import av import numpy as np import pytest @@ -80,16 +82,35 @@ async def test_convert_to_opus(sample_audio): writer = StreamingAudioWriter("opus", sample_rate=24000) audio_chunk = await AudioService.convert_audio( - AudioChunk(audio_data), "opus", writer + AudioChunk(audio_data), "opus", writer, is_last_chunk=False, trim_audio=False + ) + final_chunk = await AudioService.convert_audio( + AudioChunk(np.array([], dtype=np.int16)), + "opus", + writer, + is_last_chunk=True, + trim_audio=False, ) - writer.close() + encoded = audio_chunk.output + final_chunk.output - assert isinstance(audio_chunk.output, bytes) + assert isinstance(encoded, bytes) assert isinstance(audio_chunk, AudioChunk) - assert len(audio_chunk.output) > 0 + assert len(encoded) > 0 # Check OGG header - assert audio_chunk.output.startswith(b"OggS") + assert encoded.startswith(b"OggS") + + with av.open(BytesIO(encoded), mode="r", format="ogg") as container: + decoded_samples = 0 + decoded_rate = None + for frame in container.decode(audio=0): + decoded_samples += frame.samples + decoded_rate = frame.sample_rate + + assert decoded_rate == 48000 + assert decoded_samples / decoded_rate == pytest.approx( + len(audio_data) / sample_rate, abs=0.03 + ) @pytest.mark.asyncio