|
5 | 5 | the actual Claude Agent SDK. |
6 | 6 | """ |
7 | 7 |
|
| 8 | +import asyncio |
| 9 | +import gc |
| 10 | +import sys |
| 11 | +import types |
| 12 | +from typing import Type |
| 13 | + |
8 | 14 | import pytest |
9 | 15 |
|
10 | 16 | # Try to import the Claude Agent SDK - skip tests if not available |
|
19 | 25 | from braintrust import logger |
20 | 26 | from braintrust.span_types import SpanTypeAttribute |
21 | 27 | from braintrust.test_helpers import init_test_logger |
| 28 | +from braintrust.wrappers.claude_agent_sdk import setup_claude_agent_sdk |
22 | 29 | from braintrust.wrappers.claude_agent_sdk._wrapper import ( |
23 | 30 | _create_client_wrapper_class, |
24 | 31 | _create_tool_wrapper_class, |
@@ -292,3 +299,110 @@ class TestAutoInstrumentClaudeAgentSDK: |
292 | 299 | def test_auto_instrument_claude_agent_sdk(self): |
293 | 300 | """Test auto_instrument patches Claude Agent SDK and creates spans.""" |
294 | 301 | verify_autoinstrument_script("test_auto_claude_agent_sdk.py") |
| 302 | + |
| 303 | + |
| 304 | +class _FakeClaudeAgentOptions: |
| 305 | + def __init__(self, model, permission_mode=None): |
| 306 | + self.model = model |
| 307 | + self.permission_mode = permission_mode |
| 308 | + |
| 309 | + |
| 310 | +class _FakeMessage: |
| 311 | + def __init__(self, content): |
| 312 | + self.content = content |
| 313 | + |
| 314 | + |
| 315 | +class _FakeResultMessage: |
| 316 | + def __init__(self): |
| 317 | + self.usage = types.SimpleNamespace(input_tokens=1, output_tokens=1, cache_creation_input_tokens=0) |
| 318 | + self.num_turns = 1 |
| 319 | + self.session_id = "session-123" |
| 320 | + |
| 321 | + |
| 322 | +class _FakeClaudeSDKClient: |
| 323 | + def __init__(self, options): |
| 324 | + self.options = options |
| 325 | + self._prompt = None |
| 326 | + |
| 327 | + async def __aenter__(self): |
| 328 | + return self |
| 329 | + |
| 330 | + async def __aexit__(self, *args): |
| 331 | + return None |
| 332 | + |
| 333 | + async def query(self, prompt): |
| 334 | + self._prompt = prompt |
| 335 | + |
| 336 | + async def receive_response(self): |
| 337 | + yield _FakeMessage("Hello") |
| 338 | + await asyncio.sleep(0) |
| 339 | + yield _FakeResultMessage() |
| 340 | + |
| 341 | + |
| 342 | +class _FakeClaudeSdkModule(types.ModuleType): |
| 343 | + ClaudeSDKClient: Type[_FakeClaudeSDKClient] |
| 344 | + ClaudeAgentOptions: Type[_FakeClaudeAgentOptions] |
| 345 | + SdkMcpTool = None |
| 346 | + tool = None |
| 347 | + |
| 348 | + |
| 349 | +class _FakeConsumerModule(types.ModuleType): |
| 350 | + ClaudeSDKClient: Type[_FakeClaudeSDKClient] |
| 351 | + ClaudeAgentOptions: Type[_FakeClaudeAgentOptions] |
| 352 | + |
| 353 | + |
| 354 | +def _install_fake_claude_sdk(monkeypatch): |
| 355 | + fake_module = _FakeClaudeSdkModule("claude_agent_sdk") |
| 356 | + fake_module.ClaudeSDKClient = _FakeClaudeSDKClient |
| 357 | + fake_module.ClaudeAgentOptions = _FakeClaudeAgentOptions |
| 358 | + monkeypatch.setitem(sys.modules, "claude_agent_sdk", fake_module) |
| 359 | + return fake_module |
| 360 | + |
| 361 | + |
| 362 | +@pytest.mark.asyncio |
| 363 | +async def test_setup_claude_agent_sdk_repro_import_before_setup(memory_logger, monkeypatch): |
| 364 | + """Regression test for https://github.com/braintrustdata/braintrust-sdk-python/issues/7.""" |
| 365 | + assert not memory_logger.pop() |
| 366 | + |
| 367 | + fake_sdk = _install_fake_claude_sdk(monkeypatch) |
| 368 | + consumer_module_name = "test_issue7_repro_module" |
| 369 | + consumer_module = _FakeConsumerModule(consumer_module_name) |
| 370 | + consumer_module.ClaudeSDKClient = fake_sdk.ClaudeSDKClient |
| 371 | + consumer_module.ClaudeAgentOptions = fake_sdk.ClaudeAgentOptions |
| 372 | + monkeypatch.setitem(sys.modules, consumer_module_name, consumer_module) |
| 373 | + |
| 374 | + # Mirror the reported import pattern: |
| 375 | + # from claude_agent_sdk import ClaudeSDKClient, ClaudeAgentOptions |
| 376 | + assert setup_claude_agent_sdk(project=PROJECT_NAME, api_key=logger.TEST_API_KEY) |
| 377 | + assert consumer_module.ClaudeSDKClient is not _FakeClaudeSDKClient |
| 378 | + |
| 379 | + loop_errors = [] |
| 380 | + received_types = [] |
| 381 | + |
| 382 | + async def main(): |
| 383 | + loop = asyncio.get_running_loop() |
| 384 | + loop.set_exception_handler(lambda loop, ctx: loop_errors.append(ctx.get("exception") or ctx.get("message"))) |
| 385 | + |
| 386 | + options = consumer_module.ClaudeAgentOptions( |
| 387 | + model="claude-sonnet-4-20250514", |
| 388 | + permission_mode="bypassPermissions", |
| 389 | + ) |
| 390 | + async with consumer_module.ClaudeSDKClient(options=options) as client: |
| 391 | + await client.query("Hello") |
| 392 | + async for message in client.receive_response(): |
| 393 | + received_types.append(type(message).__name__) |
| 394 | + |
| 395 | + await asyncio.sleep(0) |
| 396 | + gc.collect() |
| 397 | + await asyncio.sleep(0.01) |
| 398 | + |
| 399 | + await main() |
| 400 | + |
| 401 | + assert loop_errors == [] |
| 402 | + assert received_types == ["_FakeMessage", "_FakeResultMessage"] |
| 403 | + |
| 404 | + spans = memory_logger.pop() |
| 405 | + task_spans = [s for s in spans if s["span_attributes"]["type"] == SpanTypeAttribute.TASK] |
| 406 | + assert len(task_spans) == 1 |
| 407 | + assert task_spans[0]["span_attributes"]["name"] == "Claude Agent" |
| 408 | + assert task_spans[0]["input"] == "Hello" |
0 commit comments