diff --git a/py/src/braintrust/integrations/anthropic/_utils.py b/py/src/braintrust/integrations/anthropic/_utils.py index e117451c..df5a412d 100644 --- a/py/src/braintrust/integrations/anthropic/_utils.py +++ b/py/src/braintrust/integrations/anthropic/_utils.py @@ -2,6 +2,8 @@ from typing import Any +from braintrust.util import is_numeric + class Wrapper: """Base wrapper class with __getattr__ delegation to preserve original types.""" @@ -13,73 +15,104 @@ def __getattr__(self, name: str) -> Any: return getattr(self.__wrapped, name) -def extract_anthropic_usage(usage: Any) -> dict[str, float]: - """Extract and normalize usage metrics from Anthropic usage object or dict. +_ANTHROPIC_USAGE_METRIC_FIELDS = ( + ("input_tokens", "prompt_tokens"), + ("output_tokens", "completion_tokens"), + ("cache_read_input_tokens", "prompt_cached_tokens"), + ("cache_creation_input_tokens", "prompt_cache_creation_tokens"), +) - Converts Anthropic's usage format to Braintrust's standard token metric names. - Handles both object attributes and dictionary keys. +_ANTHROPIC_CACHE_CREATION_METRIC_FIELDS = ( + ("ephemeral_5m_input_tokens", "prompt_cache_creation_ephemeral_5m_tokens"), + ("ephemeral_1h_input_tokens", "prompt_cache_creation_ephemeral_1h_tokens"), +) - Args: - usage: Anthropic usage object (from Message.usage) or dict +_ANTHROPIC_SERVER_TOOL_USE_METRIC_FIELDS = ( + ("web_search_requests", "server_tool_use_web_search_requests"), + ("web_fetch_requests", "server_tool_use_web_fetch_requests"), +) - Returns: - Dictionary with normalized metric names: - - prompt_tokens (from input_tokens) - - completion_tokens (from output_tokens) - - prompt_cached_tokens (from cache_read_input_tokens) - - prompt_cache_creation_tokens (from cache_creation_input_tokens) - """ - metrics: dict[str, float] = {} +_ANTHROPIC_USAGE_METADATA_FIELDS = frozenset( + { + "service_tier", + "inference_geo", + } +) - if not usage: - return metrics - def get_value(key: str) -> Any: - if isinstance(usage, dict): - return usage.get(key) - return getattr(usage, key, None) +def _try_to_dict(obj: Any) -> dict[str, Any] | None: + if isinstance(obj, dict): + return obj - input_tokens = get_value("input_tokens") - if input_tokens is not None: + if hasattr(obj, "model_dump"): try: - metrics["prompt_tokens"] = float(input_tokens) - except (ValueError, TypeError): - pass + candidate = obj.model_dump(mode="python") + except TypeError: + candidate = obj.model_dump() + return candidate if isinstance(candidate, dict) else None - output_tokens = get_value("output_tokens") - if output_tokens is not None: - try: - metrics["completion_tokens"] = float(output_tokens) - except (ValueError, TypeError): - pass + if hasattr(obj, "to_dict"): + candidate = obj.to_dict() + return candidate if isinstance(candidate, dict) else None - cache_read_tokens = get_value("cache_read_input_tokens") - if cache_read_tokens is not None: - try: - metrics["prompt_cached_tokens"] = float(cache_read_tokens) - except (ValueError, TypeError): - pass + if hasattr(obj, "dict"): + candidate = obj.dict() + return candidate if isinstance(candidate, dict) else None + + if hasattr(obj, "__dict__"): + return vars(obj) + + return None - cache_creation_tokens = get_value("cache_creation_input_tokens") - if cache_creation_tokens is not None: - try: - metrics["prompt_cache_creation_tokens"] = float(cache_creation_tokens) - except (ValueError, TypeError): - pass - return metrics +def _set_numeric_metric(metrics: dict[str, float], name: str, value: Any) -> None: + if is_numeric(value): + metrics[name] = float(value) -def finalize_anthropic_tokens(metrics: dict[str, float]) -> dict[str, float]: - """Finalize Anthropic token calculations.""" - total_prompt_tokens = ( - metrics.get("prompt_tokens", 0) - + metrics.get("prompt_cached_tokens", 0) - + metrics.get("prompt_cache_creation_tokens", 0) - ) +def extract_anthropic_usage(usage: Any) -> tuple[dict[str, float], dict[str, Any]]: + """Extract normalized metrics and allowlisted metadata from Anthropic usage. - return { - **metrics, - "prompt_tokens": total_prompt_tokens, - "tokens": total_prompt_tokens + metrics.get("completion_tokens", 0), + Numeric usage fields are converted into Braintrust metrics. Allowlisted + non-numeric fields are attached as span metadata with a ``usage_`` prefix. + """ + usage = _try_to_dict(usage) + if usage is None: + return {}, {} + + metrics: dict[str, float] = {} + for source_name, metric_name in _ANTHROPIC_USAGE_METRIC_FIELDS: + _set_numeric_metric(metrics, metric_name, usage.get(source_name)) + + cache_creation = _try_to_dict(usage.get("cache_creation")) + cache_creation_breakdown: list[float] = [] + if cache_creation is not None: + for source_name, metric_name in _ANTHROPIC_CACHE_CREATION_METRIC_FIELDS: + value = cache_creation.get(source_name) + _set_numeric_metric(metrics, metric_name, value) + if is_numeric(value): + cache_creation_breakdown.append(float(value)) + + server_tool_use = _try_to_dict(usage.get("server_tool_use")) + if server_tool_use is not None: + for source_name, metric_name in _ANTHROPIC_SERVER_TOOL_USE_METRIC_FIELDS: + _set_numeric_metric(metrics, metric_name, server_tool_use.get(source_name)) + + if "prompt_cache_creation_tokens" not in metrics and cache_creation_breakdown: + metrics["prompt_cache_creation_tokens"] = sum(cache_creation_breakdown) + + if metrics: + total_prompt_tokens = ( + metrics.get("prompt_tokens", 0) + + metrics.get("prompt_cached_tokens", 0) + + metrics.get("prompt_cache_creation_tokens", 0) + ) + metrics["prompt_tokens"] = total_prompt_tokens + metrics["tokens"] = total_prompt_tokens + metrics.get("completion_tokens", 0) + + metadata = { + f"usage_{name}": value + for name, value in usage.items() + if name in _ANTHROPIC_USAGE_METADATA_FIELDS and value is not None } + return metrics, metadata diff --git a/py/src/braintrust/integrations/anthropic/test_anthropic.py b/py/src/braintrust/integrations/anthropic/test_anthropic.py index b363fb1e..e187365a 100644 --- a/py/src/braintrust/integrations/anthropic/test_anthropic.py +++ b/py/src/braintrust/integrations/anthropic/test_anthropic.py @@ -11,6 +11,7 @@ import pytest from braintrust import logger from braintrust.integrations.anthropic import AnthropicIntegration, wrap_anthropic +from braintrust.integrations.anthropic._utils import extract_anthropic_usage from braintrust.integrations.anthropic.tracing import _log_message_to_span from braintrust.test_helpers import init_test_logger @@ -73,6 +74,7 @@ def test_log_message_to_span_includes_stop_reason_and_stop_sequence(): "tokens": 18.0, "time_to_first_token": 0.123, }, + metadata={}, ) @@ -539,18 +541,82 @@ def test_setup_creates_spans(memory_logger): AnthropicIntegration.setup() client = anthropic.Anthropic() - client.messages.create( + message = client.messages.create( model=MODEL, max_tokens=100, messages=[{"role": "user", "content": "hi"}], ) + usage = message.usage + spans = memory_logger.pop() assert len(spans) == 1 span = spans[0] assert span["metadata"]["model"] == MODEL assert span["metadata"]["provider"] == "anthropic" + cache_creation = getattr(usage, "cache_creation", None) + if cache_creation is None: + pytest.skip("Anthropic SDK version does not expose nested cache_creation usage fields") + + if isinstance(cache_creation, dict): + ephemeral_5m = cache_creation["ephemeral_5m_input_tokens"] + ephemeral_1h = cache_creation["ephemeral_1h_input_tokens"] + else: + ephemeral_5m = cache_creation.ephemeral_5m_input_tokens + ephemeral_1h = cache_creation.ephemeral_1h_input_tokens + + assert span["metadata"]["usage_service_tier"] == usage.service_tier + assert span["metadata"]["usage_inference_geo"] == usage.inference_geo + metrics = span["metrics"] + assert metrics["prompt_tokens"] == ( + usage.input_tokens + usage.cache_read_input_tokens + usage.cache_creation_input_tokens + ) + assert metrics["completion_tokens"] == usage.output_tokens + assert metrics["prompt_cache_creation_tokens"] == usage.cache_creation_input_tokens + assert metrics["prompt_cache_creation_ephemeral_5m_tokens"] == ephemeral_5m + assert metrics["prompt_cache_creation_ephemeral_1h_tokens"] == ephemeral_1h + assert "service_tier" not in metrics + + +def test_extract_anthropic_usage_preserves_nested_numeric_fields(): + usage = { + "input_tokens": 8, + "output_tokens": 12, + "cache_creation": { + "ephemeral_5m_input_tokens": 3, + "ephemeral_1h_input_tokens": 4, + }, + "server_tool_use": { + "web_search_requests": 2, + "web_fetch_requests": 1, + }, + "service_tier": "standard", + "inference_geo": "not_available", + } + metrics, metadata = extract_anthropic_usage(usage) + + assert metrics["prompt_tokens"] == 15 + assert metrics["completion_tokens"] == 12 + assert metrics["tokens"] == 27 + assert metrics["prompt_cache_creation_tokens"] == 7 + assert metrics["prompt_cache_creation_ephemeral_5m_tokens"] == 3 + assert metrics["prompt_cache_creation_ephemeral_1h_tokens"] == 4 + assert metrics["server_tool_use_web_search_requests"] == 2 + assert metrics["server_tool_use_web_fetch_requests"] == 1 + assert "service_tier" not in metrics + assert metadata == { + "usage_service_tier": "standard", + "usage_inference_geo": "not_available", + } + + +def test_extract_anthropic_usage_skips_empty_usage(): + metrics, metadata = extract_anthropic_usage(SimpleNamespace()) + + assert metrics == {} + assert metadata == {} + def _make_batch_requests(): return [ diff --git a/py/src/braintrust/integrations/anthropic/tracing.py b/py/src/braintrust/integrations/anthropic/tracing.py index 15974cf4..3e07b6ba 100644 --- a/py/src/braintrust/integrations/anthropic/tracing.py +++ b/py/src/braintrust/integrations/anthropic/tracing.py @@ -3,7 +3,7 @@ import warnings from contextlib import contextmanager -from braintrust.integrations.anthropic._utils import Wrapper, extract_anthropic_usage, finalize_anthropic_tokens +from braintrust.integrations.anthropic._utils import Wrapper, extract_anthropic_usage from braintrust.logger import NOOP_SPAN, log_exc_info_to_span, start_span @@ -445,7 +445,7 @@ def _start_span(name, kwargs): def _log_message_to_span(message, span, time_to_first_token: float | None = None): with _catch_exceptions(): usage = getattr(message, "usage", {}) - metrics = finalize_anthropic_tokens(extract_anthropic_usage(usage)) + metrics, metadata = extract_anthropic_usage(usage) if time_to_first_token is not None: metrics["time_to_first_token"] = time_to_first_token @@ -462,7 +462,7 @@ def _log_message_to_span(message, span, time_to_first_token: float | None = None if v is not None } or None - span.log(output=output, metrics=metrics) + span.log(output=output, metrics=metrics, metadata=metadata) @contextmanager diff --git a/py/src/braintrust/integrations/claude_agent_sdk/test_claude_agent_sdk.py b/py/src/braintrust/integrations/claude_agent_sdk/test_claude_agent_sdk.py index b9a3202d..ff71a13a 100644 --- a/py/src/braintrust/integrations/claude_agent_sdk/test_claude_agent_sdk.py +++ b/py/src/braintrust/integrations/claude_agent_sdk/test_claude_agent_sdk.py @@ -24,6 +24,7 @@ print("Claude Agent SDK not installed, skipping integration tests") from braintrust import logger +from braintrust.integrations.anthropic._utils import extract_anthropic_usage from braintrust.integrations.claude_agent_sdk import setup_claude_agent_sdk from braintrust.integrations.claude_agent_sdk._test_transport import make_cassette_transport from braintrust.integrations.claude_agent_sdk.tracing import ( @@ -31,7 +32,6 @@ _build_llm_input, _create_client_wrapper_class, _create_tool_wrapper_class, - _extract_usage_from_result_message, _parse_tool_name, _serialize_content_blocks, _serialize_system_message, @@ -184,6 +184,8 @@ async def calculator_handler(args): for metric_name in ("prompt_tokens", "completion_tokens", "tokens"): if metric_name in llm_span.get("metrics", {}): assert llm_span["metrics"][metric_name] > 0 + assert any(llm_span.get("metadata", {}).get("usage_service_tier") == "standard" for llm_span in llm_spans) + assert any("usage_inference_geo" in llm_span.get("metadata", {}) for llm_span in llm_spans) tool_spans = [s for s in spans if s["span_attributes"]["type"] == SpanTypeAttribute.TOOL] for tool_span in tool_spans: assert tool_span["span_attributes"]["name"] == "calculator" @@ -1828,9 +1830,9 @@ def test_serialize_system_message_extracts_known_fields(message, expected): assert _serialize_system_message(message) == expected -def test_extract_usage_from_result_message_normalizes_anthropic_tokens(): - metrics = _extract_usage_from_result_message( - ResultMessage(input_tokens=5, output_tokens=3, cache_creation_input_tokens=2) +def test_extract_anthropic_usage_normalizes_claude_result_message_usage(): + metrics, metadata = extract_anthropic_usage( + ResultMessage(input_tokens=5, output_tokens=3, cache_creation_input_tokens=2).usage ) assert metrics == { @@ -1839,6 +1841,7 @@ def test_extract_usage_from_result_message_normalizes_anthropic_tokens(): "prompt_cache_creation_tokens": 2.0, "tokens": 10.0, } + assert metadata == {} @pytest.mark.parametrize( diff --git a/py/src/braintrust/integrations/claude_agent_sdk/tracing.py b/py/src/braintrust/integrations/claude_agent_sdk/tracing.py index 98726802..d31c3551 100644 --- a/py/src/braintrust/integrations/claude_agent_sdk/tracing.py +++ b/py/src/braintrust/integrations/claude_agent_sdk/tracing.py @@ -7,7 +7,7 @@ from collections.abc import AsyncGenerator, AsyncIterable from typing import Any -from braintrust.integrations.anthropic._utils import Wrapper, extract_anthropic_usage, finalize_anthropic_tokens +from braintrust.integrations.anthropic._utils import Wrapper, extract_anthropic_usage from braintrust.integrations.claude_agent_sdk._constants import ( ANTHROPIC_MESSAGES_CREATE_SPAN_NAME, CLAUDE_AGENT_TASK_SPAN_NAME, @@ -711,10 +711,10 @@ def _handle_user(self, message: Any) -> None: def _handle_result(self, message: Any) -> None: self._active_key = None if hasattr(message, "usage"): - usage_metrics = _extract_usage_from_result_message(message) + usage_metrics, usage_metadata = extract_anthropic_usage(message.usage) ctx = self._get_context(None) - if ctx.llm_span and usage_metrics: - ctx.llm_span.log(metrics=usage_metrics) + if ctx.llm_span and (usage_metrics or usage_metadata): + ctx.llm_span.log(metrics=usage_metrics or None, metadata=usage_metadata or None) result_metadata = { k: v for k, v in { @@ -1203,25 +1203,6 @@ def _serialize_content_blocks(content: Any) -> Any: return content -def _extract_usage_from_result_message(result_message: Any) -> dict[str, float]: - """Extracts and normalizes usage metrics from a ResultMessage. - - Uses shared Anthropic utilities for consistent metric extraction. - """ - if not hasattr(result_message, "usage"): - return {} - - usage = result_message.usage - if not usage: - return {} - - metrics = extract_anthropic_usage(usage) - if metrics: - metrics = finalize_anthropic_tokens(metrics) - - return metrics - - def _build_llm_input(prompt: Any, conversation_history: list[dict[str, Any]]) -> list[dict[str, Any]] | None: """Builds the input array for an LLM span from the initial prompt and conversation history.