@@ -83,9 +83,7 @@ async def _agent_run_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any
8383 _maybe_create_tool_spans_from_messages (result )
8484
8585 output = _serialize_result_output (result )
86- metrics = _extract_usage_metrics (result , start_time , end_time )
87-
88- agent_span .log (output = output , metrics = metrics )
86+ agent_span .log (output = output , metrics = _wrapper_span_metrics (start_time , end_time ))
8987 return result
9088 finally :
9189 _reset_tool_trace_capture (tool_trace_token )
@@ -109,9 +107,7 @@ def _agent_run_sync_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any)
109107 _maybe_create_tool_spans_from_messages (result )
110108
111109 output = _serialize_result_output (result )
112- metrics = _extract_usage_metrics (result , start_time , end_time )
113-
114- agent_span .log (output = output , metrics = metrics )
110+ agent_span .log (output = output , metrics = _wrapper_span_metrics (start_time , end_time ))
115111 return result
116112 finally :
117113 _reset_tool_trace_capture (tool_trace_token )
@@ -131,7 +127,7 @@ def _agent_to_cli_sync_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: A
131127 start_time = time .time ()
132128 result = wrapped (* args , ** kwargs )
133129 end_time = time .time ()
134- agent_span .log (metrics = { "start" : start_time , "end" : end_time , "duration" : end_time - start_time } )
130+ agent_span .log (metrics = _wrapper_span_metrics ( start_time , end_time ) )
135131 return result
136132
137133
@@ -211,17 +207,13 @@ async def _agent_run_stream_events_wrapper(wrapped: Any, instance: Any, args: An
211207 _maybe_create_tool_spans_from_messages (final_result )
212208
213209 output = None
214- metrics = {
215- "start" : start_time ,
216- "end" : end_time ,
217- "duration" : end_time - start_time ,
210+ metrics : dict [str , float ] = {
211+ ** _wrapper_span_metrics (start_time , end_time ),
218212 "event_count" : event_count ,
219213 }
220214
221215 if final_result :
222216 output = _serialize_result_output (final_result )
223- usage_metrics = _extract_usage_metrics (final_result , start_time , end_time )
224- metrics .update (usage_metrics )
225217
226218 agent_span .log (output = output , metrics = metrics )
227219 finally :
@@ -245,9 +237,7 @@ async def wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
245237 end_time = time .time ()
246238
247239 output = _serialize_model_response (result )
248- metrics = _extract_response_metrics (result , start_time , end_time )
249-
250- span .log (output = output , metrics = metrics )
240+ span .log (output = output , metrics = _wrapper_span_metrics (start_time , end_time ))
251241 return result
252242
253243 return wrapper
@@ -270,9 +260,7 @@ def wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
270260 end_time = time .time ()
271261
272262 output = _serialize_model_response (result )
273- metrics = _extract_response_metrics (result , start_time , end_time )
274-
275- span .log (output = output , metrics = metrics )
263+ span .log (output = output , metrics = _wrapper_span_metrics (start_time , end_time ))
276264 return result
277265
278266 return wrapper
@@ -326,9 +314,7 @@ async def wrapper(*args, **kwargs):
326314 end_time = time .time ()
327315
328316 output = _serialize_model_response (result )
329- metrics = _extract_response_metrics (result , start_time , end_time )
330-
331- span .log (output = output , metrics = metrics )
317+ span .log (output = output , metrics = _wrapper_span_metrics (start_time , end_time ))
332318 return result
333319
334320 return wrapper
@@ -349,9 +335,7 @@ def wrapper(*args, **kwargs):
349335 end_time = time .time ()
350336
351337 output = _serialize_model_response (result )
352- metrics = _extract_response_metrics (result , start_time , end_time )
353-
354- span .log (output = output , metrics = metrics )
338+ span .log (output = output , metrics = _wrapper_span_metrics (start_time , end_time ))
355339 return result
356340
357341 return wrapper
@@ -492,10 +476,10 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
492476 _maybe_create_tool_spans_from_messages (self .stream_result )
493477
494478 output = _serialize_stream_output (self .stream_result )
495- metrics = _extract_stream_usage_metrics (
496- self .stream_result , self .start_time , end_time , self ._first_token_time
479+ self .span_cm .log (
480+ output = output ,
481+ metrics = _wrapper_span_metrics (self .start_time , end_time , self ._first_token_time ),
497482 )
498- self .span_cm .log (output = output , metrics = metrics )
499483
500484 # Clean up span context
501485 if self .span_cm :
@@ -593,9 +577,12 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
593577 try :
594578 final_response = self .stream .get ()
595579 output = _serialize_model_response (final_response )
596- metrics = _extract_response_metrics (
597- final_response , self .start_time , end_time , self ._first_token_time
598- )
580+ if self .span_type == SpanTypeAttribute .LLM :
581+ metrics = _extract_response_metrics (
582+ final_response , self .start_time , end_time , self ._first_token_time
583+ )
584+ else :
585+ metrics = _wrapper_span_metrics (self .start_time , end_time , self ._first_token_time )
599586 self .span_cm .log (output = output , metrics = metrics )
600587 except Exception as e :
601588 logger .debug (f"Failed to extract stream output/metrics: { e } " )
@@ -700,10 +687,10 @@ def _finalize(self):
700687 _maybe_create_tool_spans_from_messages (self ._stream_result )
701688
702689 output = _serialize_stream_output (self ._stream_result )
703- metrics = _extract_stream_usage_metrics (
704- self ._stream_result , self ._start_time , end_time , self ._first_token_time
690+ self ._span .log (
691+ output = output ,
692+ metrics = _wrapper_span_metrics (self ._start_time , end_time , self ._first_token_time ),
705693 )
706- self ._span .log (output = output , metrics = metrics )
707694 self ._logged = True
708695 finally :
709696 try :
@@ -761,10 +748,10 @@ def __exit__(self, exc_type, exc_val, exc_tb):
761748 try :
762749 final_response = self .stream .get ()
763750 output = _serialize_model_response (final_response )
764- metrics = _extract_response_metrics (
765- final_response , self .start_time , end_time , self ._first_token_time
751+ self .span_cm .log (
752+ output = output ,
753+ metrics = _wrapper_span_metrics (self .start_time , end_time , self ._first_token_time ),
766754 )
767- self .span_cm .log (output = output , metrics = metrics )
768755 except Exception as e :
769756 logger .debug (f"Failed to extract stream output/metrics: { e } " )
770757
@@ -1155,105 +1142,20 @@ def _parse_model_string(model: Any) -> tuple[str | None, str | None]:
11551142 return model_str , None
11561143
11571144
1158- def _extract_usage_metrics (result : Any , start_time : float , end_time : float ) -> dict [str , float ] | None :
1159- """Extract usage metrics from agent run result."""
1160- metrics : dict [str , float ] = {}
1161-
1162- metrics ["start" ] = start_time
1163- metrics ["end" ] = end_time
1164- metrics ["duration" ] = end_time - start_time
1165-
1166- usage = None
1167- if hasattr (result , "response" ):
1168- try :
1169- response = result .response
1170- if hasattr (response , "usage" ):
1171- usage = response .usage
1172- except (AttributeError , ValueError ):
1173- pass
1174-
1175- if usage is None and hasattr (result , "usage" ):
1176- usage = result .usage
1177-
1178- if usage is None :
1179- return metrics
1180-
1181- if hasattr (usage , "input_tokens" ):
1182- input_tokens = usage .input_tokens
1183- if input_tokens is not None :
1184- metrics ["prompt_tokens" ] = float (input_tokens )
1185-
1186- if hasattr (usage , "output_tokens" ):
1187- output_tokens = usage .output_tokens
1188- if output_tokens is not None :
1189- metrics ["completion_tokens" ] = float (output_tokens )
1190-
1191- if hasattr (usage , "total_tokens" ):
1192- total_tokens = usage .total_tokens
1193- if total_tokens is not None :
1194- metrics ["tokens" ] = float (total_tokens )
1195-
1196- if hasattr (usage , "cache_read_tokens" ) and usage .cache_read_tokens is not None :
1197- metrics ["prompt_cached_tokens" ] = float (usage .cache_read_tokens )
1198-
1199- if hasattr (usage , "cache_write_tokens" ) and usage .cache_write_tokens is not None :
1200- metrics ["prompt_cache_creation_tokens" ] = float (usage .cache_write_tokens )
1201-
1202- if hasattr (usage , "input_audio_tokens" ) and usage .input_audio_tokens is not None :
1203- metrics ["prompt_audio_tokens" ] = float (usage .input_audio_tokens )
1204-
1205- if hasattr (usage , "output_audio_tokens" ) and usage .output_audio_tokens is not None :
1206- metrics ["completion_audio_tokens" ] = float (usage .output_audio_tokens )
1207-
1208- if hasattr (usage , "details" ) and isinstance (usage .details , dict ):
1209- details = usage .details
1210-
1211- if "reasoning_tokens" in details :
1212- metrics ["completion_reasoning_tokens" ] = float (details ["reasoning_tokens" ])
1213-
1214- if "cached_tokens" in details :
1215- metrics ["prompt_cached_tokens" ] = float (details ["cached_tokens" ])
1216-
1217- return metrics if metrics else None
1218-
1219-
1220- def _extract_stream_usage_metrics (
1221- stream_result : Any , start_time : float , end_time : float , first_token_time : float | None
1222- ) -> dict [str , float ] | None :
1223- """Extract usage metrics from stream result."""
1224- metrics : dict [str , float ] = {}
1225-
1226- metrics ["start" ] = start_time
1227- metrics ["end" ] = end_time
1228- metrics ["duration" ] = end_time - start_time
1229-
1230- if first_token_time :
1145+ def _wrapper_span_metrics (
1146+ start_time : float , end_time : float , first_token_time : float | None = None
1147+ ) -> dict [str , float ]:
1148+ # Wrapper spans (agent_run, model_request, streaming wrappers) must NOT log token or
1149+ # cost metrics. The leaf `chat <model>` span already logs them, and trace-tree rollup
1150+ # (self + descendants) would then double-count tokens/cost at every wrapper ancestor.
1151+ metrics : dict [str , float ] = {
1152+ "start" : start_time ,
1153+ "end" : end_time ,
1154+ "duration" : end_time - start_time ,
1155+ }
1156+ if first_token_time is not None :
12311157 metrics ["time_to_first_token" ] = first_token_time - start_time
1232-
1233- if hasattr (stream_result , "usage" ):
1234- usage_func = stream_result .usage
1235- if callable (usage_func ):
1236- usage = usage_func ()
1237- else :
1238- usage = usage_func
1239-
1240- if usage :
1241- if hasattr (usage , "input_tokens" ) and usage .input_tokens is not None :
1242- metrics ["prompt_tokens" ] = float (usage .input_tokens )
1243-
1244- if hasattr (usage , "output_tokens" ) and usage .output_tokens is not None :
1245- metrics ["completion_tokens" ] = float (usage .output_tokens )
1246-
1247- if hasattr (usage , "total_tokens" ) and usage .total_tokens is not None :
1248- metrics ["tokens" ] = float (usage .total_tokens )
1249-
1250- if hasattr (usage , "cache_read_tokens" ) and usage .cache_read_tokens is not None :
1251- metrics ["prompt_cached_tokens" ] = float (usage .cache_read_tokens )
1252-
1253- if hasattr (usage , "cache_write_tokens" ) and usage .cache_write_tokens is not None :
1254- metrics ["prompt_cache_creation_tokens" ] = float (usage .cache_write_tokens )
1255-
1256- return metrics if metrics else None
1158+ return metrics
12571159
12581160
12591161def _extract_response_metrics (
0 commit comments