Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 75 additions & 19 deletions aai_cli/commands/evaluate/_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

from __future__ import annotations

import math
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from enum import StrEnum
Expand Down Expand Up @@ -54,56 +56,97 @@ def _pct(value: object) -> str:
return f"{jsonshape.as_float(value):.2%}"


def _secs(value: object) -> str:
    """Format a latency given in seconds for display: two decimals plus an ``s`` suffix."""
    seconds = jsonshape.as_float(value)
    return f"{seconds:.2f}s"


def _percentile(values: list[float], q: float) -> float:
    """Return the q-quantile of ``values``, linearly interpolated between ranks.

    The quantile position is ``q * (len(values) - 1)`` in the sorted data; when
    it falls between two ranks the result is interpolated linearly between
    them (numpy's default method).

    Args:
        values: The samples, in any order. Must be non-empty.
        q: The quantile to compute, in ``[0, 1]``.

    Returns:
        The interpolated quantile.

    Raises:
        ValueError: If ``values`` is empty or ``q`` lies outside ``[0, 1]``.
    """
    if not values:
        raise ValueError("values must be non-empty")
    if not 0.0 <= q <= 1.0:
        # Without this guard a negative q would index from the end of the
        # sorted list and silently return the wrong element.
        raise ValueError(f"q must be in [0, 1], got {q!r}")
    ordered = sorted(values)
    pos = q * (len(ordered) - 1)
    low = math.floor(pos)
    high = math.ceil(pos)
    if low == high:
        # The position lands exactly on a rank: no interpolation needed.
        return ordered[low]
    return ordered[low] + (ordered[high] - ordered[low]) * (pos - low)


@dataclass(frozen=True)
class _ItemResult:
    """One scored row: the emitted dict plus the score and latency kept for pooling."""

    # The per-item dict emitted in the output rows.
    row: dict[str, object]
    # The word-level score; None for a row whose transcription failed.
    words: wer.Score | None
    # The latency recorded for this row's transcription request.
    latency: float


def _failed_result(item: eval_data.EvalItem, err: CLIError, latency: float) -> _ItemResult:
    """Build the result for a row whose transcription failed.

    The error message and the request latency ride along in the emitted row;
    ``words`` is ``None``, so no scores are pooled for this row.
    """
    return _ItemResult(
        row={"item": item.item_id, "error": err.message, "latency": latency},
        words=None,
        latency=latency,
    )


def _score_item(
    item: eval_data.EvalItem, transcript: aai.Transcript, latency: float
) -> _ItemResult:
    """Score one transcript against its reference and build the emitted row.

    A missing or empty transcript text is scored as the empty string. The
    latency is stored both in the emitted row and on the result itself.
    """
    words = wer.score(item.reference, str(transcript.text or ""))
    row: dict[str, object] = {
        "item": item.item_id,
        "words": words.words,
        "errors": words.errors,
        "wer": words.wer,
        "latency": latency,
    }
    return _ItemResult(row=row, words=words, latency=latency)


def _pooled_metrics(results: list[_ItemResult]) -> dict[str, object]:
    """The summary metrics: WER pooled over the scored rows (failed rows carry none),
    and the latency distribution over every row that ran a transcription.

    Returns an empty dict when ``results`` is empty. The ``words``/``errors``/``wer``
    keys appear only when at least one row was scored; ``latency_p50`` and
    ``latency_p90`` appear whenever there is at least one row, failed or not.
    """
    metrics: dict[str, object] = {}
    word_scores = [result.words for result in results if result.words is not None]
    if word_scores:
        total = wer.pooled(word_scores)
        metrics.update({"words": total.words, "errors": total.errors, "wer": total.wer})
    # Failed rows still carry a latency, so they count towards the distribution.
    latencies = [result.latency for result in results]
    if latencies:
        metrics["latency_p50"] = _percentile(latencies, 0.5)
        metrics["latency_p90"] = _percentile(latencies, 0.9)
    return metrics


@dataclass(frozen=True)
class _Timed:
    """The outcome of a single transcription together with its wall-clock latency in seconds."""

    # Either the transcript or the CLIError the item failed with.
    outcome: aai.Transcript | CLIError
    # Elapsed seconds for the request.
    latency: float


def _transcribe_one(
    api_key: str, item: eval_data.EvalItem, config: aai.TranscriptionConfig
) -> _Timed:
    """One item's timed outcome: its transcript (or the CLIError it failed with) and
    the wall-clock latency of the request.

    A bad item must not discard the other (paid) items, so per-item failures
    are recorded rather than raised — except ``NotAuthenticated`` (one rejected
    key fails every row identically) and non-CLIError bugs, which propagate and
    abort the run.
    """
    start = time.perf_counter()
    try:
        outcome: aai.Transcript | CLIError = client.transcribe(api_key, item.audio, config=config)
    except NotAuthenticated:
        # Listed before the CLIError handler so it is re-raised instead of recorded.
        raise
    except CLIError as err:
        outcome = err
    # The timer covers the failing call too, so failed rows also report a latency.
    return _Timed(outcome=outcome, latency=time.perf_counter() - start)


def _concurrent_transcripts(
Expand All @@ -112,7 +155,7 @@ def _concurrent_transcripts(
*,
transcription_config: aai.TranscriptionConfig,
concurrency: int,
) -> list[aai.Transcript | CLIError]:
) -> list[_Timed]:
with ThreadPoolExecutor(max_workers=concurrency) as pool:
futures = [
pool.submit(_transcribe_one, api_key, item, transcription_config) for item in items
Expand All @@ -134,15 +177,15 @@ def _transcripts(
concurrency: int,
json_mode: bool,
quiet: bool,
) -> list[aai.Transcript | CLIError]:
"""Each item's transcript — or the CLIError it failed with — in dataset order.
) -> list[_Timed]:
"""Each item's timed transcript — or the CLIError it failed with — in dataset order.

Sequential by default, with a per-item spinner; ``--concurrency`` fans the
API calls out across a thread pool (see ``_transcribe_one`` for which
failures are per-item outcomes and which abort the run).
"""
if concurrency == 1:
outcomes: list[aai.Transcript | CLIError] = []
outcomes: list[_Timed] = []
for index, item in enumerate(items, start=1):
with output.status(
f"[{index}/{len(items)}] Transcribing {item.item_id}…",
Expand Down Expand Up @@ -185,6 +228,11 @@ def _summary(payload: dict[str, object]) -> str:
parts.append(
f"WER {_pct(payload.get('wer'))} ({errors} {noun} / {payload.get('words')} words)"
)
if "latency_p50" in payload:
parts.append(
f"latency p50 {_secs(payload.get('latency_p50'))}"
f" · p90 {_secs(payload.get('latency_p90'))}"
)
return output.heading(" ".join(parts))


Expand All @@ -197,19 +245,27 @@ def _pct_cell(row: dict[str, object], key: str) -> str:
return _pct(row[key]) if key in row else ""


def _secs_cell(row: dict[str, object], key: str) -> str:
    """Seconds-formatted table cell for ``row[key]``; blank when the row has no such key."""
    if key not in row:
        return ""
    return _secs(row[key])


def _render(payload: dict[str, object]) -> RenderableType:
has_wer = "wer" in payload
has_failed = "failed" in payload
has_latency = "latency_p50" in payload
columns = [
"ITEM",
*(["WORDS", "ERRORS", "WER"] if has_wer else []),
*(["LATENCY"] if has_latency else []),
*(["ERROR"] if has_failed else []),
]
table = output.data_table(*columns)
for row in jsonshape.mapping_list(payload.get("rows")):
cells = [str(row.get("item"))]
if has_wer:
cells += [_cell(row, "words"), _cell(row, "errors"), _pct_cell(row, "wer")]
if has_latency:
cells.append(_secs_cell(row, "latency"))
if has_failed:
cells.append(_cell(row, "error"))
table.add_row(*cells)
Expand Down Expand Up @@ -245,10 +301,10 @@ def run_evaluate(opts: EvalOptions, state: AppState, *, json_mode: bool) -> None
quiet=state.quiet,
)
results = [
_failed_result(item, outcome)
if isinstance(outcome, CLIError)
else _score_item(item, outcome)
for item, outcome in zip(
_failed_result(item, timed.outcome, timed.latency)
if isinstance(timed.outcome, CLIError)
else _score_item(item, timed.outcome, timed.latency)
for item, timed in zip(
data.items,
outcomes,
strict=True, # pragma: no mutate (defensive invariant; _transcripts returns one outcome per item)
Expand Down
97 changes: 94 additions & 3 deletions tests/test_eval_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,23 @@ def _payload_of(result):
)


def _without_latency(row):
    """Return a copy of ``row`` with its ``latency`` entry removed (``row`` is left untouched)."""
    trimmed = dict(row)
    trimmed.pop("latency", None)
    return trimmed


def _fake_perf_counter(mocker, ticks):
    """Replace the eval timer with a scripted sequence so row latencies are known constants.

    On the sequential path every item reads perf_counter twice (start, then
    end), so ``ticks`` is consumed in pairs: item 1's (start, end), item 2's, …

    NOTE(review): the patch target resolves through ``_exec``'s ``time``
    attribute to the shared ``time`` module, so while the patch is active any
    other ``time.perf_counter()`` call would also consume a tick — worth
    confirming nothing else on the code path reads it.
    """
    readings = [*ticks]
    return mocker.patch(
        "aai_cli.commands.evaluate._exec.time.perf_counter",
        side_effect=readings,
        autospec=True,
    )


def test_wer_table_with_per_file_and_pooled_scores(tmp_path, mocker):
_auth()
_write_wer_manifest(tmp_path)
Expand Down Expand Up @@ -94,8 +111,13 @@ def test_json_payload_shape(tmp_path, mocker):
assert payload["words"] == 4
assert payload["errors"] == 1
assert payload["wer"] == 0.25
assert payload["rows"][0] == {"item": "a.wav", "words": 2, "errors": 0, "wer": 0.0}
assert payload["rows"][1] == {"item": "b.wav", "words": 2, "errors": 1, "wer": 0.5}
assert _without_latency(payload["rows"][0]) == {
"item": "a.wav", "words": 2, "errors": 0, "wer": 0.0
} # fmt: skip
assert _without_latency(payload["rows"][1]) == {
"item": "b.wav", "words": 2, "errors": 1, "wer": 0.5
} # fmt: skip
assert all(isinstance(row["latency"], float) for row in payload["rows"])
assert "failed" not in payload # only present when a row failed


Expand Down Expand Up @@ -144,11 +166,20 @@ def _assign(obj, attribute, value):
def test_item_results_are_immutable():
    """``_ItemResult`` is frozen: assigning to a field raises ``FrozenInstanceError``."""
    from aai_cli.commands.evaluate._exec import _ItemResult

    # latency is a required field, so it must be supplied to build the instance.
    result = _ItemResult(row={}, words=None, latency=0.0)
    with pytest.raises(dataclasses.FrozenInstanceError):
        _assign(result, "words", None)


def test_timed_outcome_is_immutable():
    """``_Timed`` is frozen: rebinding ``latency`` raises ``FrozenInstanceError``."""
    from aai_cli.commands.evaluate._exec import _Timed
    from aai_cli.core.errors import APIError

    failure = APIError("boom")
    frozen = _Timed(outcome=failure, latency=1.0)
    with pytest.raises(dataclasses.FrozenInstanceError):
        _assign(frozen, "latency", 2.0)


def test_missing_transcript_text_scores_as_all_deletions(tmp_path, mocker):
_auth()
(tmp_path / "a.wav").write_bytes(b"fake-audio")
Expand Down Expand Up @@ -247,3 +278,63 @@ def test_unauthenticated_exits_with_auth_code(tmp_path):
(tmp_path / "m.csv").write_text("audio,text\na.wav,hello\n", encoding="utf-8")
result = runner.invoke(app, ["eval", "m.csv"])
assert result.exit_code == 4


def test_per_row_latency_and_percentiles_in_json(tmp_path, mocker):
    """The JSON payload carries each row's latency plus the pooled p50/p90."""
    _auth()
    _write_wer_manifest(tmp_path)
    _mock_transcribe(mocker, [_transcript("hello there"), _transcript("goodbye now")])
    # Ticks are (start, end) pairs: row a runs for 1.5s, row b for 0.5s. The starts
    # are nonzero so that a mutated `end + start` cannot pass for `end - start`.
    _fake_perf_counter(mocker, [10.0, 11.5, 20.0, 20.5])
    result = runner.invoke(app, ["eval", "manifest.csv", "--json"])
    payload = _payload_of(result)
    rows = payload["rows"]
    assert rows[0]["latency"] == 1.5
    assert rows[1]["latency"] == 0.5
    # Sorted latencies are [0.5, 1.5]: p50 = 1.0 and p90 = 0.5 + 1.0 * 0.9 = 1.4.
    assert payload["latency_p50"] == pytest.approx(1.0)
    assert payload["latency_p90"] == pytest.approx(1.4)


def test_human_output_shows_latency_column_and_summary(tmp_path, mocker):
    """The human-readable table shows a LATENCY column and the pooled latency summary."""
    _auth()
    _write_wer_manifest(tmp_path)
    _mock_transcribe(mocker, [_transcript("hello there"), _transcript("goodbye now")])
    _fake_perf_counter(mocker, [10.0, 11.5, 20.0, 20.5])
    result = runner.invoke(app, ["eval", "manifest.csv"])
    assert result.exit_code == 0
    rendered = result.output
    # Column header, row a's latency, row b's latency, then the pooled summary.
    assert "LATENCY" in rendered
    assert "1.50s" in rendered
    assert "0.50s" in rendered
    assert "latency p50 1.00s · p90 1.40s" in rendered


def test_failed_row_still_carries_latency(tmp_path, mocker):
    """A row whose transcription failed still reports latency and is pooled into p50."""
    from aai_cli.core.errors import APIError

    _auth()
    _write_wer_manifest(tmp_path)
    outcomes = [_transcript("hello there"), APIError("rate limited")]
    _mock_transcribe(mocker, outcomes)
    _fake_perf_counter(mocker, [10.0, 11.0, 20.0, 20.25])
    result = runner.invoke(app, ["eval", "manifest.csv", "--json"])
    payload = _payload_of(result)
    failed_row = next(entry for entry in payload["rows"] if "error" in entry)
    # The timer wraps the failing call too.
    assert failed_row["latency"] == 0.25
    # Latencies [0.25, 1.0] are pooled across the failed and the scored row.
    assert payload["latency_p50"] == pytest.approx(0.625)


@pytest.mark.parametrize(
    ("values", "q", "expected"),
    [
        # A single value is every quantile.
        ([5.0], 0.5, 5.0),
        # Odd count with an exact rank: the median element itself.
        ([1.0, 2.0, 3.0], 0.5, 2.0),
        # Even count: the midpoint of the two middle elements.
        ([1.0, 2.0, 3.0, 4.0], 0.5, 2.5),
        # Interpolated between the two ranks.
        ([0.0, 10.0], 0.9, 9.0),
        # q=0 is the minimum, q=1 the maximum.
        ([1.0, 2.0, 3.0, 4.0], 0.0, 1.0),
        ([1.0, 2.0, 3.0, 4.0], 1.0, 4.0),
    ],
)
def test_percentile_interpolates_between_ranks(values, q, expected):
    """``_percentile`` sorts its input and interpolates linearly between ranks."""
    from aai_cli.commands.evaluate._exec import _percentile

    # Feed the samples in descending order so the result depends on the sort.
    descending = values[::-1]
    assert _percentile(descending, q) == pytest.approx(expected)
9 changes: 6 additions & 3 deletions tests/test_eval_failures.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
_mock_transcribe,
_payload_of,
_transcript,
_without_latency,
_write_wer_manifest,
)

Expand Down Expand Up @@ -172,9 +173,11 @@ def test_failed_row_keeps_completed_rows_and_summary_pools_scored_only(tmp_path,
payload = _payload_of(result)
assert payload["items"] == 3
assert payload["failed"] == 1
assert payload["rows"][0] == {"item": "a.wav", "words": 2, "errors": 0, "wer": 0.0}
assert payload["rows"][1] == {"item": "b.wav", "error": "rate limited"}
assert payload["rows"][2] == {"item": "c.wav", "words": 2, "errors": 0, "wer": 0.0}
rows = [_without_latency(row) for row in payload["rows"]]
assert rows[0] == {"item": "a.wav", "words": 2, "errors": 0, "wer": 0.0}
assert rows[1] == {"item": "b.wav", "error": "rate limited"}
assert rows[2] == {"item": "c.wav", "words": 2, "errors": 0, "wer": 0.0}
assert all(isinstance(row["latency"], float) for row in payload["rows"])
# Pooled over the two scored rows only — the failed row contributes no words.
assert payload["words"] == 4
assert payload["errors"] == 0
Expand Down
Loading