Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 88 additions & 55 deletions py/src/braintrust/integrations/anthropic/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from typing import Any

from braintrust.util import is_numeric


class Wrapper:
"""Base wrapper class with __getattr__ delegation to preserve original types."""
Expand All @@ -13,73 +15,104 @@ def __getattr__(self, name: str) -> Any:
return getattr(self.__wrapped, name)


def extract_anthropic_usage(usage: Any) -> dict[str, float]:
"""Extract and normalize usage metrics from Anthropic usage object or dict.
_ANTHROPIC_USAGE_METRIC_FIELDS = (
("input_tokens", "prompt_tokens"),
("output_tokens", "completion_tokens"),
("cache_read_input_tokens", "prompt_cached_tokens"),
("cache_creation_input_tokens", "prompt_cache_creation_tokens"),
)

Converts Anthropic's usage format to Braintrust's standard token metric names.
Handles both object attributes and dictionary keys.
_ANTHROPIC_CACHE_CREATION_METRIC_FIELDS = (
("ephemeral_5m_input_tokens", "prompt_cache_creation_ephemeral_5m_tokens"),
("ephemeral_1h_input_tokens", "prompt_cache_creation_ephemeral_1h_tokens"),
)

Args:
usage: Anthropic usage object (from Message.usage) or dict
_ANTHROPIC_SERVER_TOOL_USE_METRIC_FIELDS = (
("web_search_requests", "server_tool_use_web_search_requests"),
("web_fetch_requests", "server_tool_use_web_fetch_requests"),
)

Returns:
Dictionary with normalized metric names:
- prompt_tokens (from input_tokens)
- completion_tokens (from output_tokens)
- prompt_cached_tokens (from cache_read_input_tokens)
- prompt_cache_creation_tokens (from cache_creation_input_tokens)
"""
metrics: dict[str, float] = {}
_ANTHROPIC_USAGE_METADATA_FIELDS = frozenset(
{
"service_tier",
"inference_geo",
}
)

if not usage:
return metrics

def get_value(key: str) -> Any:
if isinstance(usage, dict):
return usage.get(key)
return getattr(usage, key, None)
def _try_to_dict(obj: Any) -> dict[str, Any] | None:
if isinstance(obj, dict):
return obj

input_tokens = get_value("input_tokens")
if input_tokens is not None:
if hasattr(obj, "model_dump"):
try:
metrics["prompt_tokens"] = float(input_tokens)
except (ValueError, TypeError):
pass
candidate = obj.model_dump(mode="python")
except TypeError:
candidate = obj.model_dump()
return candidate if isinstance(candidate, dict) else None

output_tokens = get_value("output_tokens")
if output_tokens is not None:
try:
metrics["completion_tokens"] = float(output_tokens)
except (ValueError, TypeError):
pass
if hasattr(obj, "to_dict"):
candidate = obj.to_dict()
return candidate if isinstance(candidate, dict) else None

cache_read_tokens = get_value("cache_read_input_tokens")
if cache_read_tokens is not None:
try:
metrics["prompt_cached_tokens"] = float(cache_read_tokens)
except (ValueError, TypeError):
pass
if hasattr(obj, "dict"):
candidate = obj.dict()
return candidate if isinstance(candidate, dict) else None

if hasattr(obj, "__dict__"):
return vars(obj)

return None

cache_creation_tokens = get_value("cache_creation_input_tokens")
if cache_creation_tokens is not None:
try:
metrics["prompt_cache_creation_tokens"] = float(cache_creation_tokens)
except (ValueError, TypeError):
pass

return metrics
def _set_numeric_metric(metrics: dict[str, float], name: str, value: Any) -> None:
if is_numeric(value):
metrics[name] = float(value)


def finalize_anthropic_tokens(metrics: dict[str, float]) -> dict[str, float]:
"""Finalize Anthropic token calculations."""
total_prompt_tokens = (
metrics.get("prompt_tokens", 0)
+ metrics.get("prompt_cached_tokens", 0)
+ metrics.get("prompt_cache_creation_tokens", 0)
)
def extract_anthropic_usage(usage: Any) -> tuple[dict[str, float], dict[str, Any]]:
"""Extract normalized metrics and allowlisted metadata from Anthropic usage.

return {
**metrics,
"prompt_tokens": total_prompt_tokens,
"tokens": total_prompt_tokens + metrics.get("completion_tokens", 0),
Numeric usage fields are converted into Braintrust metrics. Allowlisted
non-numeric fields are attached as span metadata with a ``usage_`` prefix.
"""
usage = _try_to_dict(usage)
if usage is None:
return {}, {}

metrics: dict[str, float] = {}
for source_name, metric_name in _ANTHROPIC_USAGE_METRIC_FIELDS:
_set_numeric_metric(metrics, metric_name, usage.get(source_name))

cache_creation = _try_to_dict(usage.get("cache_creation"))
cache_creation_breakdown: list[float] = []
if cache_creation is not None:
for source_name, metric_name in _ANTHROPIC_CACHE_CREATION_METRIC_FIELDS:
value = cache_creation.get(source_name)
_set_numeric_metric(metrics, metric_name, value)
if is_numeric(value):
cache_creation_breakdown.append(float(value))

server_tool_use = _try_to_dict(usage.get("server_tool_use"))
if server_tool_use is not None:
for source_name, metric_name in _ANTHROPIC_SERVER_TOOL_USE_METRIC_FIELDS:
_set_numeric_metric(metrics, metric_name, server_tool_use.get(source_name))

if "prompt_cache_creation_tokens" not in metrics and cache_creation_breakdown:
metrics["prompt_cache_creation_tokens"] = sum(cache_creation_breakdown)

if metrics:
total_prompt_tokens = (
metrics.get("prompt_tokens", 0)
+ metrics.get("prompt_cached_tokens", 0)
+ metrics.get("prompt_cache_creation_tokens", 0)
)
metrics["prompt_tokens"] = total_prompt_tokens
metrics["tokens"] = total_prompt_tokens + metrics.get("completion_tokens", 0)

metadata = {
f"usage_{name}": value
for name, value in usage.items()
if name in _ANTHROPIC_USAGE_METADATA_FIELDS and value is not None
}
return metrics, metadata
68 changes: 67 additions & 1 deletion py/src/braintrust/integrations/anthropic/test_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pytest
from braintrust import logger
from braintrust.integrations.anthropic import AnthropicIntegration, wrap_anthropic
from braintrust.integrations.anthropic._utils import extract_anthropic_usage
from braintrust.integrations.anthropic.tracing import _log_message_to_span
from braintrust.test_helpers import init_test_logger

Expand Down Expand Up @@ -73,6 +74,7 @@ def test_log_message_to_span_includes_stop_reason_and_stop_sequence():
"tokens": 18.0,
"time_to_first_token": 0.123,
},
metadata={},
)


Expand Down Expand Up @@ -539,18 +541,82 @@ def test_setup_creates_spans(memory_logger):
AnthropicIntegration.setup()

client = anthropic.Anthropic()
client.messages.create(
message = client.messages.create(
model=MODEL,
max_tokens=100,
messages=[{"role": "user", "content": "hi"}],
)

usage = message.usage

spans = memory_logger.pop()
assert len(spans) == 1
span = spans[0]
assert span["metadata"]["model"] == MODEL
assert span["metadata"]["provider"] == "anthropic"

cache_creation = getattr(usage, "cache_creation", None)
if cache_creation is None:
pytest.skip("Anthropic SDK version does not expose nested cache_creation usage fields")

if isinstance(cache_creation, dict):
ephemeral_5m = cache_creation["ephemeral_5m_input_tokens"]
ephemeral_1h = cache_creation["ephemeral_1h_input_tokens"]
else:
ephemeral_5m = cache_creation.ephemeral_5m_input_tokens
ephemeral_1h = cache_creation.ephemeral_1h_input_tokens

assert span["metadata"]["usage_service_tier"] == usage.service_tier
assert span["metadata"]["usage_inference_geo"] == usage.inference_geo
metrics = span["metrics"]
assert metrics["prompt_tokens"] == (
usage.input_tokens + usage.cache_read_input_tokens + usage.cache_creation_input_tokens
)
assert metrics["completion_tokens"] == usage.output_tokens
assert metrics["prompt_cache_creation_tokens"] == usage.cache_creation_input_tokens
assert metrics["prompt_cache_creation_ephemeral_5m_tokens"] == ephemeral_5m
assert metrics["prompt_cache_creation_ephemeral_1h_tokens"] == ephemeral_1h
assert "service_tier" not in metrics


def test_extract_anthropic_usage_preserves_nested_numeric_fields():
usage = {
"input_tokens": 8,
"output_tokens": 12,
"cache_creation": {
"ephemeral_5m_input_tokens": 3,
"ephemeral_1h_input_tokens": 4,
},
"server_tool_use": {
"web_search_requests": 2,
"web_fetch_requests": 1,
},
"service_tier": "standard",
"inference_geo": "not_available",
}
metrics, metadata = extract_anthropic_usage(usage)

assert metrics["prompt_tokens"] == 15
assert metrics["completion_tokens"] == 12
assert metrics["tokens"] == 27
assert metrics["prompt_cache_creation_tokens"] == 7
assert metrics["prompt_cache_creation_ephemeral_5m_tokens"] == 3
assert metrics["prompt_cache_creation_ephemeral_1h_tokens"] == 4
assert metrics["server_tool_use_web_search_requests"] == 2
assert metrics["server_tool_use_web_fetch_requests"] == 1
assert "service_tier" not in metrics
assert metadata == {
"usage_service_tier": "standard",
"usage_inference_geo": "not_available",
}


def test_extract_anthropic_usage_skips_empty_usage():
metrics, metadata = extract_anthropic_usage(SimpleNamespace())

assert metrics == {}
assert metadata == {}


def _make_batch_requests():
return [
Expand Down
6 changes: 3 additions & 3 deletions py/src/braintrust/integrations/anthropic/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import warnings
from contextlib import contextmanager

from braintrust.integrations.anthropic._utils import Wrapper, extract_anthropic_usage, finalize_anthropic_tokens
from braintrust.integrations.anthropic._utils import Wrapper, extract_anthropic_usage
from braintrust.logger import NOOP_SPAN, log_exc_info_to_span, start_span


Expand Down Expand Up @@ -445,7 +445,7 @@ def _start_span(name, kwargs):
def _log_message_to_span(message, span, time_to_first_token: float | None = None):
with _catch_exceptions():
usage = getattr(message, "usage", {})
metrics = finalize_anthropic_tokens(extract_anthropic_usage(usage))
metrics, metadata = extract_anthropic_usage(usage)

if time_to_first_token is not None:
metrics["time_to_first_token"] = time_to_first_token
Expand All @@ -462,7 +462,7 @@ def _log_message_to_span(message, span, time_to_first_token: float | None = None
if v is not None
} or None

span.log(output=output, metrics=metrics)
span.log(output=output, metrics=metrics, metadata=metadata)


@contextmanager
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@
print("Claude Agent SDK not installed, skipping integration tests")

from braintrust import logger
from braintrust.integrations.anthropic._utils import extract_anthropic_usage
from braintrust.integrations.claude_agent_sdk import setup_claude_agent_sdk
from braintrust.integrations.claude_agent_sdk._test_transport import make_cassette_transport
from braintrust.integrations.claude_agent_sdk.tracing import (
ToolSpanTracker,
_build_llm_input,
_create_client_wrapper_class,
_create_tool_wrapper_class,
_extract_usage_from_result_message,
_parse_tool_name,
_serialize_content_blocks,
_serialize_system_message,
Expand Down Expand Up @@ -184,6 +184,8 @@ async def calculator_handler(args):
for metric_name in ("prompt_tokens", "completion_tokens", "tokens"):
if metric_name in llm_span.get("metrics", {}):
assert llm_span["metrics"][metric_name] > 0
assert any(llm_span.get("metadata", {}).get("usage_service_tier") == "standard" for llm_span in llm_spans)
assert any("usage_inference_geo" in llm_span.get("metadata", {}) for llm_span in llm_spans)
tool_spans = [s for s in spans if s["span_attributes"]["type"] == SpanTypeAttribute.TOOL]
for tool_span in tool_spans:
assert tool_span["span_attributes"]["name"] == "calculator"
Expand Down Expand Up @@ -1828,9 +1830,9 @@ def test_serialize_system_message_extracts_known_fields(message, expected):
assert _serialize_system_message(message) == expected


def test_extract_usage_from_result_message_normalizes_anthropic_tokens():
metrics = _extract_usage_from_result_message(
ResultMessage(input_tokens=5, output_tokens=3, cache_creation_input_tokens=2)
def test_extract_anthropic_usage_normalizes_claude_result_message_usage():
metrics, metadata = extract_anthropic_usage(
ResultMessage(input_tokens=5, output_tokens=3, cache_creation_input_tokens=2).usage
)

assert metrics == {
Expand All @@ -1839,6 +1841,7 @@ def test_extract_usage_from_result_message_normalizes_anthropic_tokens():
"prompt_cache_creation_tokens": 2.0,
"tokens": 10.0,
}
assert metadata == {}


@pytest.mark.parametrize(
Expand Down
27 changes: 4 additions & 23 deletions py/src/braintrust/integrations/claude_agent_sdk/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from collections.abc import AsyncGenerator, AsyncIterable
from typing import Any

from braintrust.integrations.anthropic._utils import Wrapper, extract_anthropic_usage, finalize_anthropic_tokens
from braintrust.integrations.anthropic._utils import Wrapper, extract_anthropic_usage
from braintrust.integrations.claude_agent_sdk._constants import (
ANTHROPIC_MESSAGES_CREATE_SPAN_NAME,
CLAUDE_AGENT_TASK_SPAN_NAME,
Expand Down Expand Up @@ -711,10 +711,10 @@ def _handle_user(self, message: Any) -> None:
def _handle_result(self, message: Any) -> None:
self._active_key = None
if hasattr(message, "usage"):
usage_metrics = _extract_usage_from_result_message(message)
usage_metrics, usage_metadata = extract_anthropic_usage(message.usage)
ctx = self._get_context(None)
if ctx.llm_span and usage_metrics:
ctx.llm_span.log(metrics=usage_metrics)
if ctx.llm_span and (usage_metrics or usage_metadata):
ctx.llm_span.log(metrics=usage_metrics or None, metadata=usage_metadata or None)
result_metadata = {
k: v
for k, v in {
Expand Down Expand Up @@ -1203,25 +1203,6 @@ def _serialize_content_blocks(content: Any) -> Any:
return content


def _extract_usage_from_result_message(result_message: Any) -> dict[str, float]:
"""Extracts and normalizes usage metrics from a ResultMessage.

Uses shared Anthropic utilities for consistent metric extraction.
"""
if not hasattr(result_message, "usage"):
return {}

usage = result_message.usage
if not usage:
return {}

metrics = extract_anthropic_usage(usage)
if metrics:
metrics = finalize_anthropic_tokens(metrics)

return metrics


def _build_llm_input(prompt: Any, conversation_history: list[dict[str, Any]]) -> list[dict[str, Any]] | None:
"""Builds the input array for an LLM span from the initial prompt and conversation history.

Expand Down
Loading