Skip to content

Commit c3bf48e

Browse files
committed
feat: Improve compatibility with argo 2.5.0
1 parent f60a2d3 commit c3bf48e

6 files changed

Lines changed: 718 additions & 58 deletions

File tree

py/noxfile.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,9 @@
6161
LITELLM_VERSIONS = (LATEST, "1.74.0")
6262
# CLI bundling started in 0.1.10 - older versions require external Claude Code installation
6363
CLAUDE_AGENT_SDK_VERSIONS = (LATEST, "0.1.10")
64-
AGNO_VERSIONS = (LATEST, "2.1.0")
64+
# Keep LATEST for newest API coverage, and pin 2.4.0 to cover the 2.4 -> 2.5 breaking change
65+
# to internals we leverage for instrumentation.
66+
AGNO_VERSIONS = (LATEST, "2.4.0", "2.1.0")
6567
# pydantic_ai 1.x requires Python >= 3.10
6668
# Two test suites with different version requirements:
6769
# 1. wrap_openai approach: works with older versions (0.1.9+)

py/src/braintrust/wrappers/agno/agent.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from braintrust.span_types import SpanTypeAttribute
66
from wrapt import wrap_function_wrapper
77

8+
from .run_helpers import arun_public_dispatch_wrapper, run_public_dispatch_wrapper
89
from .utils import (
910
_aggregate_agent_chunks,
1011
extract_metadata,
@@ -46,19 +47,18 @@ def _run_wrapper_private(wrapped: Any, instance: Any, args: Any, kwargs: Any):
4647
return _create_run_span(wrapped, instance, args, kwargs, input_data)
4748

4849
def _run_wrapper_public(wrapped: Any, instance: Any, args: Any, kwargs: Any):
49-
"""Entry point for public run(input)."""
50-
input_arg = args[0] if len(args) > 0 else kwargs.get("input")
51-
input_data = {"input": input_arg}
52-
return _create_run_span(wrapped, instance, args, kwargs, input_data)
50+
return run_public_dispatch_wrapper(
51+
wrapped, instance, args, kwargs, default_name="Agent", metadata_component="agent"
52+
)
5353

5454
# Wrap private method if it exists, otherwise wrap public method
5555
if hasattr(Agent, "_run"):
5656
wrap_function_wrapper(Agent, "_run", _run_wrapper_private)
5757
elif hasattr(Agent, "run"):
5858
wrap_function_wrapper(Agent, "run", _run_wrapper_public)
5959

60-
async def _create_arun_span(wrapped: Any, instance: Any, args: Any, kwargs: Any, input_data: dict):
61-
"""Shared logic to create span and execute arun method."""
60+
async def _create_arun_span_private(wrapped: Any, instance: Any, args: Any, kwargs: Any, input_data: dict):
61+
"""Shared logic to create span and execute async private _arun method."""
6262
agent_name = getattr(instance, "name", None) or "Agent"
6363
span_name = f"{agent_name}.arun"
6464

@@ -80,19 +80,16 @@ async def _arun_wrapper_private(wrapped: Any, instance: Any, args: Any, kwargs:
8080
run_response = args[0] if len(args) > 0 else kwargs.get("run_response")
8181
input_arg = args[1] if len(args) > 1 else kwargs.get("input")
8282
input_data = {"run_response": run_response, "input": input_arg}
83-
return await _create_arun_span(wrapped, instance, args, kwargs, input_data)
83+
return await _create_arun_span_private(wrapped, instance, args, kwargs, input_data)
8484

85-
async def _arun_wrapper_public(wrapped: Any, instance: Any, args: Any, kwargs: Any):
86-
"""Entry point for public arun(input)."""
87-
input_arg = args[0] if len(args) > 0 else kwargs.get("input")
88-
input_data = {"input": input_arg}
89-
return await _create_arun_span(wrapped, instance, args, kwargs, input_data)
85+
def _arun_wrapper_public(wrapped: Any, instance: Any, args: Any, kwargs: Any):
86+
return arun_public_dispatch_wrapper(
87+
wrapped, instance, args, kwargs, default_name="Agent", metadata_component="agent"
88+
)
9089

9190
# Wrap private method if it exists, otherwise wrap public method
9291
if hasattr(Agent, "_arun"):
9392
wrap_function_wrapper(Agent, "_arun", _arun_wrapper_private)
94-
elif hasattr(Agent, "arun"):
95-
wrap_function_wrapper(Agent, "arun", _arun_wrapper_public)
9693

9794
def run_stream_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
9895
agent_name = getattr(instance, "name", None) or "Agent"
@@ -211,6 +208,9 @@ async def _trace_stream():
211208

212209
if hasattr(Agent, "_arun_stream"):
213210
wrap_function_wrapper(Agent, "_arun_stream", arun_stream_wrapper)
211+
elif not hasattr(Agent, "_arun") and hasattr(Agent, "arun"):
212+
# Agno >= 2.5 routes through public arun(..., stream=...)
213+
wrap_function_wrapper(Agent, "arun", _arun_wrapper_public)
214214

215215
mark_patched(Agent)
216216
return Agent
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
import time
2+
from inspect import isawaitable
3+
from typing import Any
4+
5+
from braintrust.logger import start_span
6+
from braintrust.span_types import SpanTypeAttribute
7+
8+
from .utils import (
9+
extract_metadata,
10+
extract_metrics,
11+
is_async_iterator,
12+
is_sync_iterator,
13+
omit,
14+
trace_async_stream_result,
15+
trace_sync_stream_result,
16+
)
17+
18+
19+
def run_public_dispatch_wrapper(
20+
wrapped: Any,
21+
instance: Any,
22+
args: Any,
23+
kwargs: Any,
24+
*,
25+
default_name: str,
26+
metadata_component: str,
27+
) -> Any:
28+
"""Trace a public synchronous `run(...)` dispatch method.
29+
30+
Handles both non-streaming return values and synchronous streaming iterators.
31+
For iterator results, span lifecycle is delegated to `trace_sync_stream_result`.
32+
"""
33+
component_name = getattr(instance, "name", None) or default_name
34+
input_arg = args[0] if len(args) > 0 else kwargs.get("input")
35+
input_data = {"input": input_arg}
36+
metadata = {**omit(kwargs, ["input"]), **extract_metadata(instance, metadata_component)}
37+
38+
span = start_span(
39+
name=f"{component_name}.run",
40+
type=SpanTypeAttribute.TASK,
41+
input=input_data,
42+
metadata=metadata,
43+
)
44+
span.set_current()
45+
start = time.time()
46+
try:
47+
result = wrapped(*args, **kwargs)
48+
if is_sync_iterator(result):
49+
return trace_sync_stream_result(result, span, start)
50+
span.log(
51+
output=result,
52+
metrics=extract_metrics(result),
53+
)
54+
span.unset_current()
55+
span.end()
56+
return result
57+
except Exception as e:
58+
span.log(error=str(e))
59+
span.unset_current()
60+
span.end()
61+
raise
62+
63+
64+
def arun_public_dispatch_wrapper(
65+
wrapped: Any,
66+
instance: Any,
67+
args: Any,
68+
kwargs: Any,
69+
*,
70+
default_name: str,
71+
metadata_component: str,
72+
) -> Any:
73+
"""Trace a public `arun(...)` dispatch method across async return contracts.
74+
75+
Supports all observed `arun` dispatcher behaviors:
76+
- immediate return value
77+
- awaitable returning a value
78+
- direct async iterator
79+
- awaitable returning an async iterator
80+
81+
If an async iterator is returned (directly or after await), span lifecycle is
82+
delegated to `trace_async_stream_result` so the span remains open until stream
83+
consumption completes.
84+
"""
85+
component_name = getattr(instance, "name", None) or default_name
86+
input_arg = args[0] if len(args) > 0 else kwargs.get("input")
87+
input_data = {"input": input_arg}
88+
metadata = {**omit(kwargs, ["input"]), **extract_metadata(instance, metadata_component)}
89+
90+
span = start_span(
91+
name=f"{component_name}.arun",
92+
type=SpanTypeAttribute.TASK,
93+
input=input_data,
94+
metadata=metadata,
95+
)
96+
span.set_current()
97+
start = time.time()
98+
try:
99+
result = wrapped(*args, **kwargs)
100+
101+
if isawaitable(result):
102+
103+
async def _trace_awaitable():
104+
should_end_span = True
105+
try:
106+
awaited = await result
107+
if is_async_iterator(awaited):
108+
should_end_span = False
109+
return trace_async_stream_result(awaited, span, start)
110+
span.log(
111+
output=awaited,
112+
metrics=extract_metrics(awaited),
113+
)
114+
return awaited
115+
except Exception as e:
116+
span.log(error=str(e))
117+
raise
118+
finally:
119+
if should_end_span:
120+
span.unset_current()
121+
span.end()
122+
123+
return _trace_awaitable()
124+
125+
if is_async_iterator(result):
126+
return trace_async_stream_result(result, span, start)
127+
128+
span.log(
129+
output=result,
130+
metrics=extract_metrics(result),
131+
)
132+
span.unset_current()
133+
span.end()
134+
return result
135+
except Exception as e:
136+
span.log(error=str(e))
137+
span.unset_current()
138+
span.end()
139+
raise

py/src/braintrust/wrappers/agno/team.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from braintrust.span_types import SpanTypeAttribute
66
from wrapt import wrap_function_wrapper
77

8+
from .run_helpers import arun_public_dispatch_wrapper, run_public_dispatch_wrapper
89
from .utils import (
910
_aggregate_agent_chunks,
1011
extract_metadata,
@@ -46,19 +47,18 @@ def _run_wrapper_private(wrapped: Any, instance: Any, args: Any, kwargs: Any):
4647
return _create_run_span(wrapped, instance, args, kwargs, input_data)
4748

4849
def _run_wrapper_public(wrapped: Any, instance: Any, args: Any, kwargs: Any):
49-
"""Entry point for public run(input)."""
50-
input_arg = args[0] if len(args) > 0 else kwargs.get("input")
51-
input_data = {"input": input_arg}
52-
return _create_run_span(wrapped, instance, args, kwargs, input_data)
50+
return run_public_dispatch_wrapper(
51+
wrapped, instance, args, kwargs, default_name="Team", metadata_component="team"
52+
)
5353

5454
# Wrap private method if it exists, otherwise wrap public method
5555
if hasattr(Team, "_run"):
5656
wrap_function_wrapper(Team, "_run", _run_wrapper_private)
5757
elif hasattr(Team, "run"):
5858
wrap_function_wrapper(Team, "run", _run_wrapper_public)
5959

60-
async def _create_arun_span(wrapped: Any, instance: Any, args: Any, kwargs: Any, input_data: dict):
61-
"""Shared logic to create span and execute arun method."""
60+
async def _create_arun_span_private(wrapped: Any, instance: Any, args: Any, kwargs: Any, input_data: dict):
61+
"""Shared logic to create span and execute async private _arun method."""
6262
agent_name = getattr(instance, "name", None) or "Team"
6363
span_name = f"{agent_name}.arun"
6464

@@ -80,19 +80,16 @@ async def _arun_wrapper_private(wrapped: Any, instance: Any, args: Any, kwargs:
8080
run_response = args[0] if len(args) > 0 else kwargs.get("run_response")
8181
input_arg = args[1] if len(args) > 1 else kwargs.get("input")
8282
input_data = {"run_response": run_response, "input": input_arg}
83-
return await _create_arun_span(wrapped, instance, args, kwargs, input_data)
83+
return await _create_arun_span_private(wrapped, instance, args, kwargs, input_data)
8484

85-
async def _arun_wrapper_public(wrapped: Any, instance: Any, args: Any, kwargs: Any):
86-
"""Entry point for public arun(input)."""
87-
input_arg = args[0] if len(args) > 0 else kwargs.get("input")
88-
input_data = {"input": input_arg}
89-
return await _create_arun_span(wrapped, instance, args, kwargs, input_data)
85+
def _arun_wrapper_public(wrapped: Any, instance: Any, args: Any, kwargs: Any):
86+
return arun_public_dispatch_wrapper(
87+
wrapped, instance, args, kwargs, default_name="Team", metadata_component="team"
88+
)
9089

9190
# Wrap private method if it exists, otherwise wrap public method
9291
if hasattr(Team, "_arun"):
9392
wrap_function_wrapper(Team, "_arun", _arun_wrapper_private)
94-
elif hasattr(Team, "arun"):
95-
wrap_function_wrapper(Team, "arun", _arun_wrapper_public)
9693

9794
def run_stream_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
9895
agent_name = getattr(instance, "name", None) or "Team"
@@ -211,6 +208,9 @@ async def _trace_stream():
211208

212209
if hasattr(Team, "_arun_stream"):
213210
wrap_function_wrapper(Team, "_arun_stream", arun_stream_wrapper)
211+
elif not hasattr(Team, "_arun") and hasattr(Team, "arun"):
212+
# Agno >= 2.5 routes through public arun(..., stream=...)
213+
wrap_function_wrapper(Team, "arun", _arun_wrapper_public)
214214

215215
mark_patched(Team)
216216
return Team

0 commit comments

Comments
 (0)