Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions sdk/rt/speechmatics/rt/_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ._exceptions import TimeoutError
from ._exceptions import TranscriptionError
from ._logging import get_logger
from ._models import AudioEncoding
from ._models import AudioEventsConfig
from ._models import AudioFormat
from ._models import ClientMessageType
Expand Down Expand Up @@ -97,6 +98,8 @@ def __init__(
self.on(ServerMessageType.WARNING, self._on_warning)
self.on(ServerMessageType.AUDIO_ADDED, self._on_audio_added)

self._audio_format = AudioFormat(encoding=AudioEncoding.PCM_S16LE, sample_rate=44100, chunk_size=4096)

self._logger.debug("AsyncClient initialized (request_id=%s)", self._session.request_id)

async def start_session(
Expand Down Expand Up @@ -133,6 +136,9 @@ async def start_session(
... await client.start_session()
... await client.send_audio(frame)
"""
if audio_format is not None:
self._audio_format = audio_format

await self._start_recognition_session(
transcription_config=transcription_config,
audio_format=audio_format,
Expand Down Expand Up @@ -161,16 +167,24 @@ async def stop_session(self) -> None:
await self._session_done_evt.wait() # Wait for end of transcript event to indicate we can stop listening
await self.close()

async def force_end_of_utterance(self) -> None:
async def force_end_of_utterance(self, timestamp: Optional[float] = None) -> None:
"""
This method sends a ForceEndOfUtterance message to the server to signal
the end of an utterance. Forcing end of utterance will cause the final
transcript to be sent to the client early.

Takes an optional timestamp parameter to specify a marker for the engine
to use for timing of the end of the utterance. If not provided, the timestamp
will be calculated based on the cumulative audio sent to the server.

Args:
timestamp: Optional timestamp for the request.

Raises:
ConnectionError: If the WebSocket connection fails.
TranscriptionError: If the server reports an error during teardown.
TimeoutError: If the connection or teardown times out.
ValueError: If the audio format does not have an encoding set.

Examples:
Basic streaming:
Expand All @@ -179,7 +193,19 @@ async def force_end_of_utterance(self) -> None:
... await client.send_audio(frame)
... await client.force_end_of_utterance()
"""
await self.send_message({"message": ClientMessageType.FORCE_END_OF_UTTERANCE})
if timestamp is None:
timestamp = self.audio_seconds_sent

await self.send_message({"message": ClientMessageType.FORCE_END_OF_UTTERANCE, "timestamp": timestamp})

@property
def audio_seconds_sent(self) -> float:
    """Duration, in seconds, of the audio sent to the server so far.

    Computed as the cumulative number of audio bytes sent divided by the
    byte rate of the current audio format (sample rate multiplied by the
    number of bytes per sample).

    Raises:
        ValueError: If the audio format does not have an encoding set.
    """
    sent_bytes = self._audio_bytes_sent
    fmt = self._audio_format
    bytes_per_second = fmt.sample_rate * fmt.bytes_per_sample
    return sent_bytes / bytes_per_second

async def transcribe(
self,
Expand Down
7 changes: 7 additions & 0 deletions sdk/rt/speechmatics/rt/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(self, transport: Transport) -> None:
self._recv_task: Optional[asyncio.Task[None]] = None
self._closed_evt = asyncio.Event()
self._eos_sent = False
self._audio_bytes_sent = 0
self._seq_no = 0

self._logger = get_logger("speechmatics.rt.base_client")
Expand Down Expand Up @@ -122,11 +123,17 @@ async def send_audio(self, payload: bytes) -> None:

try:
await self._transport.send_message(payload)
self._audio_bytes_sent += len(payload)
self._seq_no += 1
except Exception:
self._closed_evt.set()
raise

@property
def audio_bytes_sent(self) -> int:
    """Running total of audio payload bytes sent to the server."""
    return self._audio_bytes_sent

async def send_message(self, message: dict[str, Any]) -> None:
"""
Send a message through the WebSocket.
Expand Down
23 changes: 23 additions & 0 deletions sdk/rt/speechmatics/rt/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,29 @@ class AudioFormat:
sample_rate: int = 44100
chunk_size: int = 4096

# Width of a single audio sample, in bytes, for each raw encoding this
# table knows about. Looked up by the `bytes_per_sample` property.
_BYTES_PER_SAMPLE = {
AudioEncoding.PCM_F32LE: 4,
AudioEncoding.PCM_S16LE: 2,
AudioEncoding.MULAW: 1,
}

@property
def bytes_per_sample(self) -> int:
    """Number of bytes used by one audio sample for the configured encoding.

    Returns:
        The sample width in bytes looked up from ``_BYTES_PER_SAMPLE``.

    Raises:
        ValueError: If encoding is None (file type) or unrecognized.
    """
    if self.encoding is None:
        raise ValueError(
            "Cannot determine bytes per sample for file-type audio format. "
            "Set an explicit encoding on AudioFormat."
        )
    try:
        return self._BYTES_PER_SAMPLE[self.encoding]
    except KeyError:
        # The failed table lookup is an implementation detail; suppress the
        # KeyError context so callers see a single, self-explanatory error.
        raise ValueError(f"Unknown encoding: {self.encoding}") from None

def to_dict(self) -> dict[str, Any]:
"""
Convert audio format to dictionary.
Expand Down
18 changes: 13 additions & 5 deletions sdk/voice/speechmatics/voice/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@ def __init__(
)

# Uses ForceEndOfUtterance message
# TODO: fix this logic - using ForceEndOfUtterance (FEOU) is not the same as simply not using fixed EOU.
self._uses_forced_eou: bool = not self._uses_fixed_eou
self._forced_eou_active: bool = False
self._last_forced_eou_latency: float = 0.0
Expand Down Expand Up @@ -472,12 +473,15 @@ def _prepare_config(
# LIFECYCLE METHODS
# ============================================================================

async def connect(self) -> None:
async def connect(self, ws_headers: Optional[dict] = None) -> None:
"""Connect to the Speechmatics API.

Establishes WebSocket connection and starts the transcription session.
This must be called before sending audio.

Args:
ws_headers: Optional headers to pass to the WebSocket connection.

Raises:
Exception: If connection fails.

Expand Down Expand Up @@ -521,6 +525,7 @@ async def connect(self) -> None:
await self.start_session(
transcription_config=self._transcription_config,
audio_format=self._audio_format,
ws_headers=ws_headers,
)
self._is_connected = True
self._start_metrics_task()
Expand Down Expand Up @@ -717,14 +722,11 @@ def update_diarization_config(self, config: SpeakerFocusConfig) -> None:
# PUBLIC UTTERANCE / TURN MANAGEMENT
# ============================================================================

def finalize(self, end_of_turn: bool = False) -> None:
def finalize(self) -> None:
"""Finalize segments.

This function will emit segments in the buffer without any further checks
on the contents of the segments.

Args:
end_of_turn: Whether to emit an end of turn message.
"""

# Clear smart turn cutoff
Expand Down Expand Up @@ -1526,6 +1528,12 @@ async def _calculate_finalize_delay(
# Smart Turn enabled
if self._smart_turn_detector:
annotation.add(AnnotationFlags.SMART_TURN_ACTIVE)
# If Smart Turn hasn't returned a result yet but is enabled, add NO_SIGNAL annotation.
# This covers the case where the TTL fires before VAD triggers Smart Turn inference.
if not annotation.has(AnnotationFlags.SMART_TURN_TRUE) and not annotation.has(
AnnotationFlags.SMART_TURN_FALSE
):
annotation.add(AnnotationFlags.SMART_TURN_NO_SIGNAL)
else:
annotation.add(AnnotationFlags.SMART_TURN_INACTIVE)

Expand Down
33 changes: 28 additions & 5 deletions sdk/voice/speechmatics/voice/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ class AnnotationFlags(str, Enum):
SMART_TURN_INACTIVE = "smart_turn_inactive"
SMART_TURN_TRUE = "smart_turn_true"
SMART_TURN_FALSE = "smart_turn_false"
SMART_TURN_NO_SIGNAL = "smart_turn_no_signal"


# ==============================================================================
Expand Down Expand Up @@ -417,24 +418,46 @@ class EndOfTurnConfig(BaseModel):
min_end_of_turn_delay: float = 0.01
penalties: list[EndOfTurnPenaltyItem] = Field(
default_factory=lambda: [
# Increase delay
#
# Speaker rate increases expected TTL
EndOfTurnPenaltyItem(penalty=3.0, annotation=[AnnotationFlags.VERY_SLOW_SPEAKER]),
EndOfTurnPenaltyItem(penalty=2.0, annotation=[AnnotationFlags.SLOW_SPEAKER]),
#
# High / low rate of disfluencies
EndOfTurnPenaltyItem(penalty=2.5, annotation=[AnnotationFlags.ENDS_WITH_DISFLUENCY]),
EndOfTurnPenaltyItem(penalty=1.1, annotation=[AnnotationFlags.HAS_DISFLUENCY]),
#
# We do NOT have an end of sentence character
EndOfTurnPenaltyItem(
penalty=2.0,
annotation=[AnnotationFlags.ENDS_WITH_EOS],
is_not=True,
),
# Decrease delay
#
# We have finals and end of sentence
EndOfTurnPenaltyItem(
penalty=0.5, annotation=[AnnotationFlags.ENDS_WITH_FINAL, AnnotationFlags.ENDS_WITH_EOS]
),
# Smart Turn + VAD
EndOfTurnPenaltyItem(penalty=0.2, annotation=[AnnotationFlags.SMART_TURN_TRUE]),
#
# Smart Turn - when false, wait longer to prevent premature end of turn
EndOfTurnPenaltyItem(
penalty=0.2, annotation=[AnnotationFlags.VAD_STOPPED, AnnotationFlags.SMART_TURN_INACTIVE]
penalty=0.2, annotation=[AnnotationFlags.SMART_TURN_TRUE, AnnotationFlags.SMART_TURN_ACTIVE]
),
EndOfTurnPenaltyItem(
penalty=2.0, annotation=[AnnotationFlags.SMART_TURN_FALSE, AnnotationFlags.SMART_TURN_ACTIVE]
),
EndOfTurnPenaltyItem(
penalty=1.5, annotation=[AnnotationFlags.SMART_TURN_NO_SIGNAL, AnnotationFlags.SMART_TURN_ACTIVE]
),
#
# VAD - only applied when smart turn is not in use and on the speaker stopping
EndOfTurnPenaltyItem(
penalty=0.2,
annotation=[
AnnotationFlags.VAD_STOPPED,
AnnotationFlags.VAD_ACTIVE,
AnnotationFlags.SMART_TURN_INACTIVE,
],
),
]
)
Expand Down
98 changes: 53 additions & 45 deletions tests/voice/test_17_eou_feou.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,41 +48,41 @@ class TranscriptionTests(BaseModel):
SAMPLES: TranscriptionTests = TranscriptionTests.from_dict(
{
"samples": [
# {
# "id": "07b",
# "path": "./assets/audio_07b_16kHz.wav",
# "sample_rate": 16000,
# "language": "en",
# "segments": [
# {"text": "Hello.", "start_time": 1.05, "end_time": 1.67},
# {"text": "Tomorrow.", "start_time": 3.5, "end_time": 4.1},
# {"text": "Wednesday.", "start_time": 6.05, "end_time": 6.73},
# {"text": "Of course. That's fine.", "start_time": 8.8, "end_time": 9.96},
# {"text": "Behind.", "start_time": 12.03, "end_time": 12.73},
# {"text": "In front.", "start_time": 14.84, "end_time": 15.52},
# {"text": "Do you think so?", "start_time": 17.54, "end_time": 18.32},
# {"text": "Brilliant.", "start_time": 20.55, "end_time": 21.08},
# {"text": "Banana.", "start_time": 22.98, "end_time": 23.53},
# {"text": "When?", "start_time": 25.49, "end_time": 25.96},
# {"text": "Today.", "start_time": 27.66, "end_time": 28.15},
# {"text": "This morning.", "start_time": 29.91, "end_time": 30.47},
# {"text": "Goodbye.", "start_time": 32.21, "end_time": 32.68},
# ],
# },
# {
# "id": "08",
# "path": "./assets/audio_08_16kHz.wav",
# "sample_rate": 16000,
# "language": "en",
# "segments": [
# {"text": "Hello.", "start_time": 0.4, "end_time": 0.75},
# {"text": "Goodbye.", "start_time": 2.12, "end_time": 2.5},
# {"text": "Banana.", "start_time": 3.84, "end_time": 4.27},
# {"text": "Breakaway.", "start_time": 5.62, "end_time": 6.42},
# {"text": "Before.", "start_time": 7.76, "end_time": 8.16},
# {"text": "After.", "start_time": 9.56, "end_time": 10.05},
# ],
# },
{
"id": "07b",
"path": "./assets/audio_07b_16kHz.wav",
"sample_rate": 16000,
"language": "en",
"segments": [
{"text": "Hello.", "start_time": 1.05, "end_time": 1.67},
{"text": "Tomorrow.", "start_time": 3.5, "end_time": 4.1},
{"text": "Wednesday.", "start_time": 6.05, "end_time": 6.73},
{"text": "Of course. That's fine.", "start_time": 8.8, "end_time": 9.96},
{"text": "Behind.", "start_time": 12.03, "end_time": 12.73},
{"text": "In front.", "start_time": 14.84, "end_time": 15.52},
{"text": "Do you think so?", "start_time": 17.54, "end_time": 18.32},
{"text": "Brilliant.", "start_time": 20.55, "end_time": 21.08},
{"text": "Banana.", "start_time": 22.98, "end_time": 23.53},
{"text": "When?", "start_time": 25.49, "end_time": 25.96},
{"text": "Today.", "start_time": 27.66, "end_time": 28.15},
{"text": "This morning.", "start_time": 29.91, "end_time": 30.47},
{"text": "Goodbye.", "start_time": 32.21, "end_time": 32.68},
],
},
{
"id": "08",
"path": "./assets/audio_08_16kHz.wav",
"sample_rate": 16000,
"language": "en",
"segments": [
{"text": "Hello.", "start_time": 0.4, "end_time": 0.75},
{"text": "Goodbye.", "start_time": 2.12, "end_time": 2.5},
{"text": "Banana.", "start_time": 3.84, "end_time": 4.27},
{"text": "Breakaway.", "start_time": 5.62, "end_time": 6.42},
{"text": "Before.", "start_time": 7.76, "end_time": 8.16},
{"text": "After.", "start_time": 9.56, "end_time": 10.05},
],
},
{
"id": "09",
"path": "./assets/audio_09_16kHz.wav",
Expand All @@ -97,12 +97,12 @@ class TranscriptionTests(BaseModel):
)

# VAD delay
VAD_DELAY_S: list[float] = [0.18, 0.22]
VAD_DELAY_S: list[float] = [0.18] # , 0.22]

# Endpoints
ENDPOINTS: list[str] = [
# "wss://eu-west-2-research.speechmatics.cloud/v2",
"wss://eu.rt.speechmatics.com/v2",
"wss://eu-west-2-research.speechmatics.cloud/v2",
# "wss://eu.rt.speechmatics.com/v2",
# "wss://us.rt.speechmatics.com/v2",
]

Expand Down Expand Up @@ -177,6 +177,11 @@ async def run_test(endpoint: str, sample: TranscriptionTest, config: VoiceAgentC
# Start time
start_time = datetime.datetime.now()

# Zero time
def zero_time(message):
global start_time
start_time = datetime.datetime.now()

# Finalized segment
def add_segments(message):
segments = message["segments"]
Expand Down Expand Up @@ -213,19 +218,20 @@ def log_message(message):
log = json.dumps({"ts": round(ts, 3), "payload": message})
print(log)

# Custom listeners
client.on(AgentServerMessageType.RECOGNITION_STARTED, zero_time)
client.on(AgentServerMessageType.END_OF_TURN, eot_detected)
client.on(AgentServerMessageType.ADD_SEGMENT, add_segments)
client.on(AgentServerMessageType.ADD_PARTIAL_TRANSCRIPT, rx_partial)
client.on(AgentServerMessageType.ADD_TRANSCRIPT, rx_partial)

# Add listeners
if SHOW_LOG:
message_types = [m for m in AgentServerMessageType if m != AgentServerMessageType.AUDIO_ADDED]
# message_types = [AgentServerMessageType.ADD_SEGMENT]
for message_type in message_types:
client.on(message_type, log_message)

# Custom listeners
client.on(AgentServerMessageType.END_OF_TURN, eot_detected)
client.on(AgentServerMessageType.ADD_SEGMENT, add_segments)
client.on(AgentServerMessageType.ADD_PARTIAL_TRANSCRIPT, rx_partial)
client.on(AgentServerMessageType.ADD_TRANSCRIPT, rx_partial)

# HEADER
if SHOW_LOG:
print()
Expand Down Expand Up @@ -326,7 +332,9 @@ def log_message(message):
# Calculate the CER
cer = TextUtils.cer(normalized_expected, normalized_received)

print(f"[{idx}] `{normalized_expected}` -> `{normalized_received}` (CER: {cer:.1%})")
# Debug metrics
if SHOW_LOG:
print(f"[{idx}] `{normalized_expected}` -> `{normalized_received}` (CER: {cer:.1%})")

# Check CER
if cer > CER_THRESHOLD:
Expand Down
Loading