Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 78 additions & 64 deletions aai_cli/commands/account.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

from collections.abc import Mapping
from datetime import UTC, date, datetime, timedelta
from typing import Annotated

import typer
from pydantic import BaseModel, BeforeValidator, ConfigDict, Field
from rich.markup import escape
from rich.text import Text

Expand Down Expand Up @@ -43,67 +44,85 @@ def _format_usage_number(value: object) -> str:
return f"{number:,.6f}".rstrip("0").rstrip(".")


def _usage_items(data: Mapping[str, object]) -> list[dict[str, object]]:
return jsonshape.mapping_list(data.get("usage_items"))


def _format_dollars(cents: float) -> str:
return f"${cents / 100:,.2f}"


def _window_total_cents(item: Mapping[str, object]) -> float:
"""Sum a window's spend (cents) from its ``line_items``.
# The product/feature label keys a usage line item may carry, in preference order.
_LABEL_KEYS = ("name", "product", "service", "feature", "model", "type", "description")

The AMS usage endpoint returns ``total: 0.0`` on every window; the real
spend lives in each window's ``line_items[].price`` (cents, like
``balance_in_cents``), so the window total is derived from them rather than
the dead top-level ``total``.
"""
return sum(
jsonshape.as_float(line_item.get("price"))
for line_item in jsonshape.mapping_list(item.get("line_items"))
)


def _window_label(item: Mapping[str, object]) -> str:
start = timeparse.parse_iso_utc(item.get("start_timestamp"))
end = timeparse.parse_iso_utc(item.get("end_timestamp"))
if start is None or end is None:
return timeparse.format_utc_day(item.get("start_timestamp"))
if end.date() == start.date() + timedelta(days=1):
return start.date().isoformat()
return f"{start.date().isoformat()} to {end.date().isoformat()}"


def _line_item_name(line_item: Mapping[str, object]) -> str:
"""The product/feature label for a usage line item, or ``""`` if it carries none."""
return next(
(
str(value)
for key in ("name", "product", "service", "feature", "model", "type", "description")
if (value := line_item.get(key))
),
"",
)


def _line_items_summary(item: Mapping[str, object]) -> str:
"""Per-product spend for a window, in dollars, aggregated by product and ordered
biggest-first.

Both this and the window total derive from ``line_items[].price`` (cents), so the
breakdown is shown in the same unit as the ``total`` column and the products sum to
that total — they reconcile, instead of mixing dollars with raw quantities. Products
are aggregated by name (the AMS endpoint can return several rows for one product),
a row with no recognizable product is grouped under ``other``, and zero-dollar
products are dropped as noise (they don't affect the reconciliation).
# AMS payload shapes drift, so the usage models are deliberately tolerant: unknown
# fields are ignored, a junk price falls back to 0.0, and a non-list/non-object
# never raises — `assembly usage` must degrade gracefully, not crash. They are
# parse-side only: `--json` passes the raw AMS dict through untouched.
_MappingList = BeforeValidator(jsonshape.mapping_list)


class _LineItem(BaseModel):
"""One usage line item: a ``price`` in cents plus whichever label key AMS used."""

model_config = ConfigDict(extra="allow")

price: Annotated[float, BeforeValidator(jsonshape.as_float)] = 0.0

@property
def label(self) -> str:
"""The product/feature label for the item, or ``""`` if it carries none."""
extra = self.model_extra or {}
return next((str(value) for key in _LABEL_KEYS if (value := extra.get(key))), "")


class _Window(BaseModel):
"""One usage window.

The AMS usage endpoint returns ``total: 0.0`` on every window; the real spend
lives in each window's ``line_items[].price`` (cents, like ``balance_in_cents``),
so the window total is derived from them rather than the dead top-level ``total``.
"""
totals: dict[str, float] = {}
for line_item in jsonshape.mapping_list(item.get("line_items")):
name = _line_item_name(line_item) or "other"
totals[name] = totals.get(name, 0.0) + jsonshape.as_float(line_item.get("price"))
ordered = sorted(((n, c) for n, c in totals.items() if c), key=lambda nc: (-nc[1], nc[0]))
return ", ".join(f"{name}: {_format_dollars(cents)}" for name, cents in ordered)

start_timestamp: object = None
end_timestamp: object = None
line_items: Annotated[list[_LineItem], _MappingList] = Field(default_factory=list[_LineItem])

@property
def total_cents(self) -> float:
return sum(item.price for item in self.line_items)

@property
def label(self) -> str:
start = timeparse.parse_iso_utc(self.start_timestamp)
end = timeparse.parse_iso_utc(self.end_timestamp)
if start is None or end is None:
return timeparse.format_utc_day(self.start_timestamp)
if end.date() == start.date() + timedelta(days=1):
return start.date().isoformat()
return f"{start.date().isoformat()} to {end.date().isoformat()}"

@property
def breakdown(self) -> str:
"""Per-product spend for the window, in dollars, aggregated by product and
ordered biggest-first.

Both this and ``total_cents`` derive from ``line_items[].price`` (cents), so
the breakdown is shown in the same unit as the ``total`` column and the
products sum to that total — they reconcile, instead of mixing dollars with
raw quantities. Products are aggregated by label (the AMS endpoint can return
several rows for one product), a row with no recognizable product is grouped
under ``other``, and zero-dollar products are dropped as noise (they don't
affect the reconciliation).
"""
totals: dict[str, float] = {}
for item in self.line_items:
name = item.label or "other"
totals[name] = totals.get(name, 0.0) + item.price
ordered = sorted(((n, c) for n, c in totals.items() if c), key=lambda nc: (-nc[1], nc[0]))
return ", ".join(f"{name}: {_format_dollars(cents)}" for name, cents in ordered)


class _Usage(BaseModel):
"""The AMS usage response: just the windows; everything else is passthrough."""

usage_items: Annotated[list[_Window], _MappingList] = Field(default_factory=list[_Window])


app = typer.Typer(help="Account billing, usage, and limits.")
Expand Down Expand Up @@ -192,7 +211,7 @@ def body(state: AppState, json_mode: bool) -> None:
data = ams.get_usage(jwt, start_date, end_date, window)

def render(d: dict[str, object]) -> object:
windows = [(item, _window_total_cents(item)) for item in _usage_items(d)]
windows = [(item, item.total_cents) for item in _Usage.model_validate(d).usage_items]
shown = windows if include_zero else [w for w in windows if w[1]]
total = sum(cents for _, cents in windows)
range_label = (
Expand All @@ -211,9 +230,7 @@ def render(d: dict[str, object]) -> object:
)
return output.stack(summary, output.muted(message))

shown_with_breakdown = [
(item, cents, _line_items_summary(item)) for item, cents in shown
]
shown_with_breakdown = [(item, cents, item.breakdown) for item, cents in shown]
show_breakdown = any(breakdown for _, _, breakdown in shown_with_breakdown)
table = (
output.data_table("period", "total", "breakdown")
Expand All @@ -222,10 +239,7 @@ def render(d: dict[str, object]) -> object:
)
hidden_count = len(windows) - len(shown)
for item, cents, breakdown in shown_with_breakdown:
row = [
escape(_window_label(item)),
_format_dollars(cents),
]
row = [escape(item.label), _format_dollars(cents)]
if show_breakdown:
row.append(escape(breakdown))
table.add_row(*row)
Expand Down
67 changes: 20 additions & 47 deletions aai_cli/config_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import assemblyai as aai
from assemblyai.streaming.v3 import SpeechModel, StreamingParameters
from pydantic import JsonValue, TypeAdapter, ValidationError

from aai_cli import jsonshape
from aai_cli.errors import UsageError
Expand Down Expand Up @@ -183,59 +184,31 @@ def _coerce_table(model_cls: type, names: tuple[str, ...]) -> dict[str, str]:
TRANSCRIBE_FIELDS = TRANSCRIBE_COERCE
STREAM_FIELDS = STREAM_COERCE

_TRUE = {"1", "true", "yes", "on"}
_FALSE = {"0", "false", "no", "off"}


def _coerce_bool(field: str, raw: str) -> object:
low = raw.strip().lower()
if low in _TRUE:
return True
if low in _FALSE:
return False
raise UsageError(f"{field} expects a boolean (true/false), got {raw!r}.")


def _coerce_int(field: str, raw: str) -> object:
try:
return int(raw)
except ValueError as exc:
raise UsageError(f"{field} expects an integer, got {raw!r}.") from exc


def _coerce_float(field: str, raw: str) -> object:
try:
return float(raw)
except ValueError as exc:
raise UsageError(f"{field} expects a number, got {raw!r}.") from exc


def _coerce_list(_field: str, raw: str) -> object:
return [part.strip() for part in raw.split(",") if part.strip()]


def _coerce_json(field: str, raw: str) -> object:
try:
return json.loads(raw)
except json.JSONDecodeError as exc:
raise UsageError(f"{field} expects a JSON value, got {raw!r}.") from exc


# Coercion kind -> coercer. Kinds absent here ("str", and any unknown) pass through raw.
_COERCERS: dict[str, Callable[[str, str], object]] = {
"bool": _coerce_bool,
"int": _coerce_int,
"float": _coerce_float,
"list": _coerce_list,
"json": _coerce_json,
# Coercion kind -> (lax pydantic parser, expectation named in the error). Pydantic
# parses the CLI's string inputs (bool spellings, int/float, raw JSON values);
# "list" (CSV split) and "str"/unknown (passthrough) are handled inline in
# `coerce_value`.
_VALIDATORS: dict[str, tuple[Callable[[str], object], str]] = {
"bool": (TypeAdapter[object](bool).validate_python, "a boolean (true/false)"),
"int": (TypeAdapter[object](int).validate_python, "an integer"),
"float": (TypeAdapter[object](float).validate_python, "a number"),
"json": (TypeAdapter[object](JsonValue).validate_json, "a JSON value"),
}


def coerce_value(field: str, raw: str) -> object:
"""Coerce a string --config value to the type expected by `field`."""
kind = TRANSCRIBE_COERCE.get(field) or STREAM_COERCE.get(field, "str")
coercer = _COERCERS.get(kind)
return coercer(field, raw) if coercer is not None else raw
if kind == "list":
return [part.strip() for part in raw.split(",") if part.strip()]
entry = _VALIDATORS.get(kind)
if entry is None: # "str" and any unknown kind pass through raw
return raw
validate, expected = entry
try:
return validate(raw)
except ValidationError as exc:
raise UsageError(f"{field} expects {expected}, got {raw!r}.") from exc


def parse_config_overrides(
Expand Down
64 changes: 39 additions & 25 deletions tests/test_account_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,75 +109,89 @@ def test_usage_renders_table_human(monkeypatch, mocker):
assert "2026-05-01" in result.output and "$12.50" in result.output


def test_usage_helpers_format_windows_and_line_items():
assert account._usage_items({"usage_items": "bad"}) == []
assert account._usage_items({"usage_items": [{"total": 1}, "bad"]}) == [{"total": 1}]
def _window(payload):
return account._Window.model_validate(payload)


def _line_item(payload):
return account._LineItem.model_validate(payload)


def test_usage_models_tolerate_junk_shapes():
# A non-list `usage_items` and non-object windows degrade to nothing, not a crash.
assert account._Usage.model_validate({"usage_items": "bad"}).usage_items == []
kept = account._Usage.model_validate({"usage_items": [{"total": 1}, "bad"]}).usage_items
assert kept == [account._Window()] # the object survives; the junk item is dropped
assert _window({"line_items": "bad"}).breakdown == ""
# A line item without a price counts as zero spend (pins the 0.0 field default).
assert _window({"line_items": [{"name": "x"}]}).total_cents == 0.0
assert _line_item({"price": "junk"}).price == 0.0


def test_usage_models_format_windows_and_line_items():
# Window total is the sum of line-item `price` (cents); the dead top-level
# `total` field the AMS endpoint returns is ignored.
assert (
account._window_total_cents(
{"total": 0.0, "line_items": [{"price": 1250.0}, {"price": 0.5}]}
)
_window({"total": 0.0, "line_items": [{"price": 1250.0}, {"price": 0.5}]}).total_cents
== 1250.5
)
assert account._window_total_cents({"total": 99.0, "line_items": []}) == 0.0
assert account._window_label({"start_timestamp": "bad"}) == "bad"
assert _window({"total": 99.0, "line_items": []}).total_cents == 0.0
assert _window({"start_timestamp": "bad"}).label == "bad"
assert (
account._window_label(
_window(
{
"start_timestamp": "2026-01-01T00:00:00Z",
"end_timestamp": "2026-01-03T00:00:00Z",
}
)
).label
== "2026-01-01 to 2026-01-03"
)
# Exactly one parseable bound falls back to the single start-day label (pins the
# `start is None or end is None` guard; an `and` would dereference the None end).
assert account._window_label({"start_timestamp": "2026-01-01T00:00:00Z"}) == "2026-01-01"
assert _window({"start_timestamp": "2026-01-01T00:00:00Z"}).label == "2026-01-01"
# A one-day window (end == start + 1 day) collapses to a single day, not a range
# (pins the `start.date() + timedelta(days=1)`).
assert (
account._window_label(
_window(
{
"start_timestamp": "2026-01-01T00:00:00Z",
"end_timestamp": "2026-01-02T00:00:00Z",
}
)
).label
== "2026-01-01"
)
# Every recognized label key resolves (pins each entry in the lookup tuple).
for key in ("name", "product", "service", "feature", "model", "type", "description"):
assert account._line_item_name({key: "X"}) == "X"
assert account._line_item_name({"name": "minutes", "total": "12.500"}) == "minutes"
assert account._line_item_name({"product": "streaming"}) == "streaming"
assert account._line_item_name({"quantity": 3}) == ""
assert account._line_item_name({}) == ""
assert _line_item({key: "X"}).label == "X"
assert _line_item({"name": "minutes", "total": "12.500"}).label == "minutes"
assert _line_item({"product": "streaming"}).label == "streaming"
assert _line_item({"quantity": 3}).label == ""
assert _line_item({}).label == ""
# Breakdown aggregates by product and shows dollars (from `price` cents), biggest
# first, so the line items sum to the window total and reconcile with it.
assert (
account._line_items_summary(
_window(
{
"line_items": [
{"name": "minutes", "price": 1000.0},
{"name": "streaming", "price": 2500.0},
{"name": "minutes", "price": 250.0},
]
}
)
).breakdown
== "streaming: $25.00, minutes: $12.50"
)
# Equal-dollar products break the tie by name (pins the nc[0] secondary sort key).
assert (
account._line_items_summary(
_window(
{"line_items": [{"name": "zeta", "price": 500.0}, {"name": "alpha", "price": 500.0}]}
)
).breakdown
== "alpha: $5.00, zeta: $5.00"
)
# A line item with no recognizable product label is grouped under "other".
assert account._line_items_summary({"line_items": [{"price": 500.0}]}) == "other: $5.00"
assert _window({"line_items": [{"price": 500.0}]}).breakdown == "other: $5.00"
# Zero-dollar products are dropped (they only add noise and still reconcile to 0).
assert account._line_items_summary({"line_items": [{"name": "free", "price": 0.0}]}) == ""
assert account._line_items_summary({"line_items": "bad"}) == ""
assert _window({"line_items": [{"name": "free", "price": 0.0}]}).breakdown == ""


def test_usage_human_renders_breakdown(monkeypatch, mocker):
Expand Down
Loading