Skip to content

Commit 6a6cb76

Browse files
fix(pydantic_ai): strip token metrics from wrapper spans
#312 retyped wrapper spans (agent_run, model_request, streaming wrappers) from LLM to TASK, but they kept logging the same prompt_tokens/completion_tokens/tokens as their nested leaf `chat <model>` span. The server derives estimated_cost per-span from tokens+metadata.model (brainstore estimated_cost.rs), rolls up trace totals with `coalesce_add` over every non-scorer span regardless of type (summary.rs accumulate_metrics), and sums experiment-level token/cost over all non-scorer spans without filtering on span_type='llm' (summary.ts experimentScanSpanSummary). Retyping to TASK therefore did not stop double-counting on any of those three axes.

Route every wrapper log site through a new `_wrapper_span_metrics` helper that emits only {start, end, duration, optional time_to_first_token}. Leaf `chat <model>` spans (from _wrap_concrete_model_class and _DirectStreamWrapper when span_type=LLM) keep full _extract_response_metrics. `_DirectStreamWrapper` now branches on span_type since it serves as both leaf and wrapper. Delete now-dead `_extract_usage_metrics` and `_extract_stream_usage_metrics`.

Flip existing cassette-backed assertions (test_agent_run_async, test_agent_run_sync, test_agent_run_stream_async, test_agent_with_tools, test_agent_run_stream_sync) to assert prompt_tokens / completion_tokens / tokens / prompt_cached_tokens are absent from wrapper spans and present only on the leaf. No cassette re-recording needed -- the change is purely in post-processing.
1 parent 31d7c07 commit 6a6cb76

2 files changed

Lines changed: 69 additions & 161 deletions

File tree

py/src/braintrust/integrations/pydantic_ai/test_pydantic_ai_integration.py

Lines changed: 32 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -140,23 +140,21 @@ async def test_agent_run_async(memory_logger):
140140
assert chat_span["metadata"]["provider"] == "openai"
141141
_assert_metrics_are_valid(chat_span["metrics"], start, end)
142142

143-
# Agent spans should have token metrics
144-
assert "prompt_tokens" in agent_span["metrics"]
145-
assert "completion_tokens" in agent_span["metrics"]
146-
assert agent_span["metrics"]["prompt_tokens"] > 0
147-
assert agent_span["metrics"]["completion_tokens"] > 0
148-
149-
# Regression: no double-counting of cost/tokens. Experiment-level aggregations
150-
# sum metrics across type='llm' spans, so a single agent turn must contribute
151-
# its tokens exactly once. The wrapper agent_run span logs the same usage as
152-
# the leaf chat span; only the leaf should be type=LLM.
143+
# Regression: wrapper agent_run span must NOT log token metrics. The leaf chat
144+
# span already logs them, and trace-tree rollup (self + descendants) plus any
145+
# unfiltered sum over metrics would otherwise double-count tokens/cost at the
146+
# parent regardless of span type.
147+
for token_key in ("prompt_tokens", "completion_tokens", "tokens", "prompt_cached_tokens"):
148+
assert token_key not in agent_span["metrics"], (
149+
f"wrapper span must not log {token_key}; it duplicates the leaf chat span"
150+
)
151+
152+
# Only the leaf chat span should be type=LLM, and it must own the token totals.
153153
llm_spans = [s for s in spans if s["span_attributes"]["type"] == SpanTypeAttribute.LLM]
154154
assert len(llm_spans) == 1, f"expected exactly one LLM-typed span, got {len(llm_spans)}"
155155
assert llm_spans[0]["span_id"] == chat_span["span_id"]
156-
llm_prompt_tokens_sum = sum(s["metrics"].get("prompt_tokens", 0) for s in llm_spans)
157-
llm_completion_tokens_sum = sum(s["metrics"].get("completion_tokens", 0) for s in llm_spans)
158-
assert llm_prompt_tokens_sum == chat_span["metrics"]["prompt_tokens"]
159-
assert llm_completion_tokens_sum == chat_span["metrics"]["completion_tokens"]
156+
assert chat_span["metrics"]["prompt_tokens"] > 0
157+
assert chat_span["metrics"]["completion_tokens"] > 0
160158

161159

162160
@pytest.mark.vcr
@@ -245,9 +243,12 @@ def is_descendant(child_span, ancestor_id):
245243
assert chat_span["metadata"]["provider"] == "openai"
246244
_assert_metrics_are_valid(chat_span["metrics"], start, end)
247245

248-
# Agent spans should have token metrics
249-
assert "prompt_tokens" in agent_sync_span["metrics"]
250-
assert "completion_tokens" in agent_sync_span["metrics"]
246+
# Wrapper agent_run_sync span must not log token metrics (would double-count at rollup).
247+
assert "prompt_tokens" not in agent_sync_span["metrics"]
248+
assert "completion_tokens" not in agent_sync_span["metrics"]
249+
# Tokens live on the leaf chat span only.
250+
assert chat_span["metrics"]["prompt_tokens"] > 0
251+
assert chat_span["metrics"]["completion_tokens"] > 0
251252

252253

253254
def test_agent_to_cli_sync(memory_logger, monkeypatch):
@@ -544,9 +545,12 @@ async def test_agent_run_stream(memory_logger):
544545
print(f"span_parents: {chat_span['span_parents']}")
545546
print(f"metrics: {chat_span['metrics']}")
546547

547-
# Agent spans should have token metrics
548-
assert "prompt_tokens" in agent_span["metrics"]
549-
assert "completion_tokens" in agent_span["metrics"]
548+
# Wrapper stream span must not log token metrics (would double-count at rollup).
549+
# time_to_first_token is asserted above; it's a non-summable timing metric and stays.
550+
assert "prompt_tokens" not in agent_span["metrics"]
551+
assert "completion_tokens" not in agent_span["metrics"]
552+
assert chat_span["metrics"]["prompt_tokens"] > 0
553+
assert chat_span["metrics"]["completion_tokens"] > 0
550554

551555

552556
@pytest.mark.vcr
@@ -842,9 +846,11 @@ def is_descendant(child_span, ancestor_id):
842846
assert chat_span["metadata"]["provider"] == "openai"
843847
_assert_metrics_are_valid(chat_span["metrics"], start, end)
844848

845-
# Agent spans should have token metrics
846-
assert "prompt_tokens" in agent_span["metrics"]
847-
assert "completion_tokens" in agent_span["metrics"]
849+
# Wrapper agent_run span must not log token metrics (would double-count at rollup).
850+
assert "prompt_tokens" not in agent_span["metrics"]
851+
assert "completion_tokens" not in agent_span["metrics"]
852+
assert chat_span["metrics"]["prompt_tokens"] > 0
853+
assert chat_span["metrics"]["completion_tokens"] > 0
848854

849855

850856
@pytest.mark.vcr
@@ -1143,9 +1149,9 @@ def is_descendant(child_span, ancestor_id):
11431149
# Chat span may not have complete metrics since it's an intermediate span
11441150
assert "start" in chat_span["metrics"]
11451151

1146-
# Agent spans should have token metrics
1147-
assert "prompt_tokens" in agent_span["metrics"]
1148-
assert "completion_tokens" in agent_span["metrics"]
1152+
# Wrapper agent_run_stream_sync span must not log token metrics.
1153+
assert "prompt_tokens" not in agent_span["metrics"]
1154+
assert "completion_tokens" not in agent_span["metrics"]
11491155

11501156

11511157
@pytest.mark.vcr

py/src/braintrust/integrations/pydantic_ai/tracing.py

Lines changed: 37 additions & 135 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,7 @@ async def _agent_run_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any
8383
_maybe_create_tool_spans_from_messages(result)
8484

8585
output = _serialize_result_output(result)
86-
metrics = _extract_usage_metrics(result, start_time, end_time)
87-
88-
agent_span.log(output=output, metrics=metrics)
86+
agent_span.log(output=output, metrics=_wrapper_span_metrics(start_time, end_time))
8987
return result
9088
finally:
9189
_reset_tool_trace_capture(tool_trace_token)
@@ -109,9 +107,7 @@ def _agent_run_sync_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any)
109107
_maybe_create_tool_spans_from_messages(result)
110108

111109
output = _serialize_result_output(result)
112-
metrics = _extract_usage_metrics(result, start_time, end_time)
113-
114-
agent_span.log(output=output, metrics=metrics)
110+
agent_span.log(output=output, metrics=_wrapper_span_metrics(start_time, end_time))
115111
return result
116112
finally:
117113
_reset_tool_trace_capture(tool_trace_token)
@@ -131,7 +127,7 @@ def _agent_to_cli_sync_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: A
131127
start_time = time.time()
132128
result = wrapped(*args, **kwargs)
133129
end_time = time.time()
134-
agent_span.log(metrics={"start": start_time, "end": end_time, "duration": end_time - start_time})
130+
agent_span.log(metrics=_wrapper_span_metrics(start_time, end_time))
135131
return result
136132

137133

@@ -211,17 +207,13 @@ async def _agent_run_stream_events_wrapper(wrapped: Any, instance: Any, args: An
211207
_maybe_create_tool_spans_from_messages(final_result)
212208

213209
output = None
214-
metrics = {
215-
"start": start_time,
216-
"end": end_time,
217-
"duration": end_time - start_time,
210+
metrics: dict[str, float] = {
211+
**_wrapper_span_metrics(start_time, end_time),
218212
"event_count": event_count,
219213
}
220214

221215
if final_result:
222216
output = _serialize_result_output(final_result)
223-
usage_metrics = _extract_usage_metrics(final_result, start_time, end_time)
224-
metrics.update(usage_metrics)
225217

226218
agent_span.log(output=output, metrics=metrics)
227219
finally:
@@ -245,9 +237,7 @@ async def wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
245237
end_time = time.time()
246238

247239
output = _serialize_model_response(result)
248-
metrics = _extract_response_metrics(result, start_time, end_time)
249-
250-
span.log(output=output, metrics=metrics)
240+
span.log(output=output, metrics=_wrapper_span_metrics(start_time, end_time))
251241
return result
252242

253243
return wrapper
@@ -270,9 +260,7 @@ def wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
270260
end_time = time.time()
271261

272262
output = _serialize_model_response(result)
273-
metrics = _extract_response_metrics(result, start_time, end_time)
274-
275-
span.log(output=output, metrics=metrics)
263+
span.log(output=output, metrics=_wrapper_span_metrics(start_time, end_time))
276264
return result
277265

278266
return wrapper
@@ -326,9 +314,7 @@ async def wrapper(*args, **kwargs):
326314
end_time = time.time()
327315

328316
output = _serialize_model_response(result)
329-
metrics = _extract_response_metrics(result, start_time, end_time)
330-
331-
span.log(output=output, metrics=metrics)
317+
span.log(output=output, metrics=_wrapper_span_metrics(start_time, end_time))
332318
return result
333319

334320
return wrapper
@@ -349,9 +335,7 @@ def wrapper(*args, **kwargs):
349335
end_time = time.time()
350336

351337
output = _serialize_model_response(result)
352-
metrics = _extract_response_metrics(result, start_time, end_time)
353-
354-
span.log(output=output, metrics=metrics)
338+
span.log(output=output, metrics=_wrapper_span_metrics(start_time, end_time))
355339
return result
356340

357341
return wrapper
@@ -492,10 +476,10 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
492476
_maybe_create_tool_spans_from_messages(self.stream_result)
493477

494478
output = _serialize_stream_output(self.stream_result)
495-
metrics = _extract_stream_usage_metrics(
496-
self.stream_result, self.start_time, end_time, self._first_token_time
479+
self.span_cm.log(
480+
output=output,
481+
metrics=_wrapper_span_metrics(self.start_time, end_time, self._first_token_time),
497482
)
498-
self.span_cm.log(output=output, metrics=metrics)
499483

500484
# Clean up span context
501485
if self.span_cm:
@@ -593,9 +577,12 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
593577
try:
594578
final_response = self.stream.get()
595579
output = _serialize_model_response(final_response)
596-
metrics = _extract_response_metrics(
597-
final_response, self.start_time, end_time, self._first_token_time
598-
)
580+
if self.span_type == SpanTypeAttribute.LLM:
581+
metrics = _extract_response_metrics(
582+
final_response, self.start_time, end_time, self._first_token_time
583+
)
584+
else:
585+
metrics = _wrapper_span_metrics(self.start_time, end_time, self._first_token_time)
599586
self.span_cm.log(output=output, metrics=metrics)
600587
except Exception as e:
601588
logger.debug(f"Failed to extract stream output/metrics: {e}")
@@ -700,10 +687,10 @@ def _finalize(self):
700687
_maybe_create_tool_spans_from_messages(self._stream_result)
701688

702689
output = _serialize_stream_output(self._stream_result)
703-
metrics = _extract_stream_usage_metrics(
704-
self._stream_result, self._start_time, end_time, self._first_token_time
690+
self._span.log(
691+
output=output,
692+
metrics=_wrapper_span_metrics(self._start_time, end_time, self._first_token_time),
705693
)
706-
self._span.log(output=output, metrics=metrics)
707694
self._logged = True
708695
finally:
709696
try:
@@ -761,10 +748,10 @@ def __exit__(self, exc_type, exc_val, exc_tb):
761748
try:
762749
final_response = self.stream.get()
763750
output = _serialize_model_response(final_response)
764-
metrics = _extract_response_metrics(
765-
final_response, self.start_time, end_time, self._first_token_time
751+
self.span_cm.log(
752+
output=output,
753+
metrics=_wrapper_span_metrics(self.start_time, end_time, self._first_token_time),
766754
)
767-
self.span_cm.log(output=output, metrics=metrics)
768755
except Exception as e:
769756
logger.debug(f"Failed to extract stream output/metrics: {e}")
770757

@@ -1155,105 +1142,20 @@ def _parse_model_string(model: Any) -> tuple[str | None, str | None]:
11551142
return model_str, None
11561143

11571144

1158-
def _extract_usage_metrics(result: Any, start_time: float, end_time: float) -> dict[str, float] | None:
1159-
"""Extract usage metrics from agent run result."""
1160-
metrics: dict[str, float] = {}
1161-
1162-
metrics["start"] = start_time
1163-
metrics["end"] = end_time
1164-
metrics["duration"] = end_time - start_time
1165-
1166-
usage = None
1167-
if hasattr(result, "response"):
1168-
try:
1169-
response = result.response
1170-
if hasattr(response, "usage"):
1171-
usage = response.usage
1172-
except (AttributeError, ValueError):
1173-
pass
1174-
1175-
if usage is None and hasattr(result, "usage"):
1176-
usage = result.usage
1177-
1178-
if usage is None:
1179-
return metrics
1180-
1181-
if hasattr(usage, "input_tokens"):
1182-
input_tokens = usage.input_tokens
1183-
if input_tokens is not None:
1184-
metrics["prompt_tokens"] = float(input_tokens)
1185-
1186-
if hasattr(usage, "output_tokens"):
1187-
output_tokens = usage.output_tokens
1188-
if output_tokens is not None:
1189-
metrics["completion_tokens"] = float(output_tokens)
1190-
1191-
if hasattr(usage, "total_tokens"):
1192-
total_tokens = usage.total_tokens
1193-
if total_tokens is not None:
1194-
metrics["tokens"] = float(total_tokens)
1195-
1196-
if hasattr(usage, "cache_read_tokens") and usage.cache_read_tokens is not None:
1197-
metrics["prompt_cached_tokens"] = float(usage.cache_read_tokens)
1198-
1199-
if hasattr(usage, "cache_write_tokens") and usage.cache_write_tokens is not None:
1200-
metrics["prompt_cache_creation_tokens"] = float(usage.cache_write_tokens)
1201-
1202-
if hasattr(usage, "input_audio_tokens") and usage.input_audio_tokens is not None:
1203-
metrics["prompt_audio_tokens"] = float(usage.input_audio_tokens)
1204-
1205-
if hasattr(usage, "output_audio_tokens") and usage.output_audio_tokens is not None:
1206-
metrics["completion_audio_tokens"] = float(usage.output_audio_tokens)
1207-
1208-
if hasattr(usage, "details") and isinstance(usage.details, dict):
1209-
details = usage.details
1210-
1211-
if "reasoning_tokens" in details:
1212-
metrics["completion_reasoning_tokens"] = float(details["reasoning_tokens"])
1213-
1214-
if "cached_tokens" in details:
1215-
metrics["prompt_cached_tokens"] = float(details["cached_tokens"])
1216-
1217-
return metrics if metrics else None
1218-
1219-
1220-
def _extract_stream_usage_metrics(
1221-
stream_result: Any, start_time: float, end_time: float, first_token_time: float | None
1222-
) -> dict[str, float] | None:
1223-
"""Extract usage metrics from stream result."""
1224-
metrics: dict[str, float] = {}
1225-
1226-
metrics["start"] = start_time
1227-
metrics["end"] = end_time
1228-
metrics["duration"] = end_time - start_time
1229-
1230-
if first_token_time:
1145+
def _wrapper_span_metrics(
1146+
start_time: float, end_time: float, first_token_time: float | None = None
1147+
) -> dict[str, float]:
1148+
# Wrapper spans (agent_run, model_request, streaming wrappers) must NOT log token or
1149+
# cost metrics. The leaf `chat <model>` span already logs them, and trace-tree rollup
1150+
# (self + descendants) would then double-count tokens/cost at every wrapper ancestor.
1151+
metrics: dict[str, float] = {
1152+
"start": start_time,
1153+
"end": end_time,
1154+
"duration": end_time - start_time,
1155+
}
1156+
if first_token_time is not None:
12311157
metrics["time_to_first_token"] = first_token_time - start_time
1232-
1233-
if hasattr(stream_result, "usage"):
1234-
usage_func = stream_result.usage
1235-
if callable(usage_func):
1236-
usage = usage_func()
1237-
else:
1238-
usage = usage_func
1239-
1240-
if usage:
1241-
if hasattr(usage, "input_tokens") and usage.input_tokens is not None:
1242-
metrics["prompt_tokens"] = float(usage.input_tokens)
1243-
1244-
if hasattr(usage, "output_tokens") and usage.output_tokens is not None:
1245-
metrics["completion_tokens"] = float(usage.output_tokens)
1246-
1247-
if hasattr(usage, "total_tokens") and usage.total_tokens is not None:
1248-
metrics["tokens"] = float(usage.total_tokens)
1249-
1250-
if hasattr(usage, "cache_read_tokens") and usage.cache_read_tokens is not None:
1251-
metrics["prompt_cached_tokens"] = float(usage.cache_read_tokens)
1252-
1253-
if hasattr(usage, "cache_write_tokens") and usage.cache_write_tokens is not None:
1254-
metrics["prompt_cache_creation_tokens"] = float(usage.cache_write_tokens)
1255-
1256-
return metrics if metrics else None
1158+
return metrics
12571159

12581160

12591161
def _extract_response_metrics(

0 commit comments

Comments
 (0)