From cba93e07f195cdee976b9094f297c889efe134e1 Mon Sep 17 00:00:00 2001 From: Nik-Reddy Date: Wed, 15 Apr 2026 21:48:23 -0700 Subject: [PATCH] feat: expose session context to SDK MCP tool handlers via get_tool_context() MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Tools registered with @tool can now call get_tool_context() to access session_id, transcript_path, and conversation history during execution. Context is best-effort — depends on hook callbacks having fired before the tool call to provide session metadata. Fixes #823 --- src/claude_agent_sdk/__init__.py | 30 +++ .../_internal/_tool_context.py | 13 + src/claude_agent_sdk/_internal/query.py | 60 ++++- src/claude_agent_sdk/types.py | 53 ++++ tests/test_tool_context.py | 250 ++++++++++++++++++ 5 files changed, 405 insertions(+), 1 deletion(-) create mode 100644 src/claude_agent_sdk/_internal/_tool_context.py create mode 100644 tests/test_tool_context.py diff --git a/src/claude_agent_sdk/__init__.py b/src/claude_agent_sdk/__init__.py index 6a414c25..21e35019 100644 --- a/src/claude_agent_sdk/__init__.py +++ b/src/claude_agent_sdk/__init__.py @@ -106,6 +106,7 @@ ThinkingConfigAdaptive, ThinkingConfigDisabled, ThinkingConfigEnabled, + ToolContext, ToolPermissionContext, ToolResultBlock, ToolUseBlock, @@ -120,6 +121,32 @@ T = TypeVar("T") +def get_tool_context() -> ToolContext | None: + """Get the current tool execution context, if available. + + Returns the session context (session ID, transcript path, working + directory, etc.) when called from within an SDK MCP tool handler + registered via :func:`tool` / :func:`create_sdk_mcp_server`. + + Returns ``None`` if called outside a tool execution or if session + info hasn't been received yet (e.g. no hooks have fired before the + first tool call). + + Example:: + + @tool("my_tool", "Does something", {"query": str}) + async def my_tool(args): + ctx = get_tool_context() + if ctx: + history = ctx.get_conversation_history() + # ... use history ... + return {"content": [{"type": "text", "text": "done"}]} + """ + from ._internal._tool_context import _current_tool_context + + return _current_tool_context.get() + + @dataclass class SdkMcpTool(Generic[T]): """Definition for an SDK MCP tool.""" @@ -590,6 +617,9 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> Any: "tool", "SdkMcpTool", "ToolAnnotations", + # Tool context + "get_tool_context", + "ToolContext", # Errors "ClaudeSDKError", "CLIConnectionError", diff --git a/src/claude_agent_sdk/_internal/_tool_context.py b/src/claude_agent_sdk/_internal/_tool_context.py new file mode 100644 index 00000000..ea78ade1 --- /dev/null +++ b/src/claude_agent_sdk/_internal/_tool_context.py @@ -0,0 +1,13 @@ +"""Internal contextvar for propagating session context to SDK MCP tool handlers.""" + +from __future__ import annotations + +import contextvars +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ..types import ToolContext + +_current_tool_context: contextvars.ContextVar[ToolContext | None] = ( + contextvars.ContextVar("_current_tool_context", default=None) +) diff --git a/src/claude_agent_sdk/_internal/query.py b/src/claude_agent_sdk/_internal/query.py index 80b6d93c..317808c1 100644 --- a/src/claude_agent_sdk/_internal/query.py +++ b/src/claude_agent_sdk/_internal/query.py @@ -25,6 +25,7 @@ SDKHookCallbackRequest, ToolPermissionContext, ) +from ._tool_context import _current_tool_context from .transport import Transport if TYPE_CHECKING: @@ -121,6 +122,13 @@ def __init__( # Track first result for proper stream closure with SDK MCP servers self._first_result_event = anyio.Event() + # Session context captured from hook callbacks (best-effort) + self._session_id: str | None = None + self._transcript_path: str | None = None + self._cwd: str | None = None + self._agent_id: str | None = None + self._agent_type: str | None = None + async def initialize(self) -> dict[str, Any] | None: """Initialize control protocol if in streaming mode. @@ -261,6 +269,26 @@ async def _read_messages(self) -> None: # Always signal end of stream await self._message_send.send({"type": "end"}) + def _capture_session_context(self, input_data: Any) -> None: + """Extract session metadata from hook/permission input data. + + The CLI sends ``session_id``, ``transcript_path``, ``cwd``, and + optionally ``agent_id``/``agent_type`` on every hook callback and + permission request. We store the latest values so they can be + propagated to SDK MCP tool handlers via the contextvar. + """ + if not isinstance(input_data, dict): + return + if "session_id" in input_data: + self._session_id = input_data["session_id"] + if "transcript_path" in input_data: + self._transcript_path = input_data["transcript_path"] + if "cwd" in input_data: + self._cwd = input_data["cwd"] + # Optional sub-agent fields + self._agent_id = input_data.get("agent_id") or self._agent_id + self._agent_type = input_data.get("agent_type") or self._agent_type + async def _handle_control_request(self, request: SDKControlRequest) -> None: """Handle incoming control request from CLI.""" request_id = request["request_id"] @@ -273,6 +301,9 @@ async def _handle_control_request(self, request: SDKControlRequest) -> None: if subtype == "can_use_tool": permission_request: SDKControlPermissionRequest = request_data # type: ignore[assignment] original_input = permission_request["input"] + + # Capture session metadata from the permission request + self._capture_session_context(request_data.get("input")) # Handle tool permission request if not self.can_use_tool: raise Exception("canUseTool callback is not provided") @@ -317,6 +348,10 @@ async def _handle_control_request(self, request: SDKControlRequest) -> None: elif subtype == "hook_callback": hook_callback_request: SDKHookCallbackRequest = request_data # type: ignore[assignment] + + # Capture session metadata from hook input + self._capture_session_context(request_data.get("input")) + # Handle hook callback callback_id = hook_callback_request["callback_id"] callback = self.hook_callbacks.get(callback_id) @@ -520,7 +555,30 @@ async def _handle_sdk_mcp_request( ) handler = server.request_handlers.get(CallToolRequest) if handler: - result = await handler(call_request) + # Set tool context for the duration of the handler call + token = None + session_id = getattr(self, "_session_id", None) + transcript_path = getattr(self, "_transcript_path", None) + cwd = getattr(self, "_cwd", None) + if session_id and transcript_path and cwd: + from ..types import ToolContext + + ctx = ToolContext( + session_id=session_id, + transcript_path=transcript_path, + cwd=cwd, + agent_id=getattr(self, "_agent_id", None), + agent_type=getattr(self, "_agent_type", None), + tool_use_id=params.get("_meta", {}).get("tool_use_id") + if isinstance(params.get("_meta"), dict) + else None, + ) + token = _current_tool_context.set(ctx) + try: + result = await handler(call_request) + finally: + if token is not None: + _current_tool_context.reset(token) # Convert MCP result to JSONRPC response content = [] for item in result.root.content: # type: ignore[union-attr] diff --git a/src/claude_agent_sdk/types.py b/src/claude_agent_sdk/types.py index a82a8b9b..3eb103f2 100644 --- a/src/claude_agent_sdk/types.py +++ b/src/claude_agent_sdk/types.py @@ -1155,6 +1155,59 @@ class SessionMessage: parent_tool_use_id: None = None +@dataclass +class ToolContext: + """Execution context available to SDK MCP tool handlers. + + Populated automatically when the CLI sends hook callbacks containing + session metadata *before* a ``tools/call`` request. Call + :func:`get_tool_context` from inside an ``@tool`` handler to retrieve + this. + + .. note:: + + This is a **best-effort** API. The context will be ``None`` if no + hook callbacks have fired prior to the tool call (e.g. the session + just started and no hooks were registered). + + Attributes: + session_id: UUID of the active CLI session. + transcript_path: Filesystem path to the JSONL transcript file. + cwd: Working directory of the CLI process. + agent_id: Sub-agent identifier, present only inside a Task-spawned + sub-agent. + agent_type: Agent type name (e.g. ``"general-purpose"``). + tool_use_id: The ``tool_use_id`` of the current tool invocation, + if available. + """ + + session_id: str + transcript_path: str + cwd: str + agent_id: str | None = None + agent_type: str | None = None + tool_use_id: str | None = None + + def get_conversation_history(self) -> list["SessionMessage"]: + """Read conversation history from the session transcript. + + This performs file I/O to parse the JSONL transcript — it is an + explicit method (not a property) to make the cost visible to callers. + + Returns: + List of :class:`SessionMessage` objects in chronological order. + Returns an empty list if the transcript cannot be found. + """ + from ._internal.sessions import get_session_messages + + directory = ( + str(Path(self.transcript_path).parent.parent) + if self.transcript_path + else None + ) + return get_session_messages(self.session_id, directory=directory) + + class ThinkingConfigAdaptive(TypedDict): type: Literal["adaptive"] diff --git a/tests/test_tool_context.py b/tests/test_tool_context.py new file mode 100644 index 00000000..3ab572b1 --- /dev/null +++ b/tests/test_tool_context.py @@ -0,0 +1,250 @@ +"""Tests for ToolContext and get_tool_context().""" + +from __future__ import annotations + +from pathlib import Path +from unittest.mock import patch + +import pytest + +from claude_agent_sdk import ToolContext, get_tool_context +from claude_agent_sdk._internal._tool_context import _current_tool_context +from claude_agent_sdk.types import SessionMessage + + +class TestToolContextDataclass: + """Test ToolContext creation and field access.""" + + def test_required_fields(self) -> None: + ctx = ToolContext( + session_id="abc-123", + transcript_path="/home/user/.claude/projects/proj/abc-123.jsonl", + cwd="/home/user/project", + ) + assert ctx.session_id == "abc-123" + assert ctx.transcript_path == "/home/user/.claude/projects/proj/abc-123.jsonl" + assert ctx.cwd == "/home/user/project" + + def test_optional_fields_default_none(self) -> None: + ctx = ToolContext( + session_id="abc-123", + transcript_path="/path/to/transcript.jsonl", + cwd="/cwd", + ) + assert ctx.agent_id is None + assert ctx.agent_type is None + assert ctx.tool_use_id is None + + def test_all_fields(self) -> None: + ctx = ToolContext( + session_id="abc-123", + transcript_path="/path/to/transcript.jsonl", + cwd="/cwd", + agent_id="agent-456", + agent_type="general-purpose", + tool_use_id="toolu_01ABC", + ) + assert ctx.agent_id == "agent-456" + assert ctx.agent_type == "general-purpose" + assert ctx.tool_use_id == "toolu_01ABC" + + +class TestGetToolContext: + """Test get_tool_context() function.""" + + def test_returns_none_by_default(self) -> None: + assert get_tool_context() is None + + def test_returns_context_when_set(self) -> None: + ctx = ToolContext( + session_id="sess-1", + transcript_path="/p/t.jsonl", + cwd="/cwd", + ) + token = _current_tool_context.set(ctx) + try: + result = get_tool_context() + assert result is ctx + assert result is not None + assert result.session_id == "sess-1" + finally: + _current_tool_context.reset(token) + + def test_returns_none_after_reset(self) -> None: + ctx = ToolContext( + session_id="sess-2", + transcript_path="/p/t.jsonl", + cwd="/cwd", + ) + token = _current_tool_context.set(ctx) + _current_tool_context.reset(token) + assert get_tool_context() is None + + +class TestContextVarIsolation: + """Test that the contextvar is properly scoped.""" + + def test_set_and_reset(self) -> None: + ctx = ToolContext( + session_id="s1", + transcript_path="/t.jsonl", + cwd="/c", + ) + # Should start as None + assert _current_tool_context.get() is None + + token = _current_tool_context.set(ctx) + assert _current_tool_context.get() is ctx + + _current_tool_context.reset(token) + assert _current_tool_context.get() is None + + def test_nested_set_reset(self) -> None: + ctx1 = ToolContext(session_id="s1", transcript_path="/t1.jsonl", cwd="/c1") + ctx2 = ToolContext(session_id="s2", transcript_path="/t2.jsonl", cwd="/c2") + + token1 = _current_tool_context.set(ctx1) + assert _current_tool_context.get() is ctx1 + + token2 = _current_tool_context.set(ctx2) + assert _current_tool_context.get() is ctx2 + + _current_tool_context.reset(token2) + assert _current_tool_context.get() is ctx1 + + _current_tool_context.reset(token1) + assert _current_tool_context.get() is None + + +class TestGetConversationHistory: + """Test ToolContext.get_conversation_history().""" + + def test_calls_get_session_messages_with_correct_args(self) -> None: + ctx = ToolContext( + session_id="550e8400-e29b-41d4-a716-446655440000", + transcript_path="/home/user/.claude/projects/proj/sessions/550e8400.jsonl", + cwd="/home/user/project", + ) + expected_dir = str( + Path("/home/user/.claude/projects/proj/sessions/550e8400.jsonl") + .parent.parent + ) + + mock_messages = [ + SessionMessage( + type="user", + uuid="msg-1", + session_id="550e8400-e29b-41d4-a716-446655440000", + message={"role": "user", "content": "hello"}, + ), + ] + + with patch( + "claude_agent_sdk._internal.sessions.get_session_messages", + return_value=mock_messages, + ) as mock_fn: + result = ctx.get_conversation_history() + + mock_fn.assert_called_once_with( + "550e8400-e29b-41d4-a716-446655440000", + directory=expected_dir, + ) + assert result == mock_messages + + def test_returns_empty_list_on_missing_session(self) -> None: + ctx = ToolContext( + session_id="nonexistent", + transcript_path="/fake/path/sessions/nonexistent.jsonl", + cwd="/cwd", + ) + with patch( + "claude_agent_sdk._internal.sessions.get_session_messages", + return_value=[], + ): + result = ctx.get_conversation_history() + assert result == [] + + +class TestCaptureSessionContext: + """Test that Query._capture_session_context works correctly.""" + + def test_capture_from_hook_input(self) -> None: + from claude_agent_sdk._internal.query import Query + + # Create a minimal Query without going through __init__ + q = object.__new__(Query) + q._session_id = None + q._transcript_path = None + q._cwd = None + q._agent_id = None + q._agent_type = None + + q._capture_session_context({ + "session_id": "s-100", + "transcript_path": "/t/s-100.jsonl", + "cwd": "/project", + "agent_id": "a-1", + "agent_type": "code-reviewer", + }) + + assert q._session_id == "s-100" + assert q._transcript_path == "/t/s-100.jsonl" + assert q._cwd == "/project" + assert q._agent_id == "a-1" + assert q._agent_type == "code-reviewer" + + def test_capture_ignores_non_dict(self) -> None: + from claude_agent_sdk._internal.query import Query + + q = object.__new__(Query) + q._session_id = None + q._transcript_path = None + q._cwd = None + q._agent_id = None + q._agent_type = None + + q._capture_session_context(None) + q._capture_session_context("not a dict") + + assert q._session_id is None + + def test_capture_preserves_agent_fields(self) -> None: + """agent_id/agent_type should not be overwritten with None.""" + from claude_agent_sdk._internal.query import Query + + q = object.__new__(Query) + q._session_id = None + q._transcript_path = None + q._cwd = None + q._agent_id = "existing-agent" + q._agent_type = "general-purpose" + + # Input without agent fields should preserve existing values + q._capture_session_context({ + "session_id": "s-200", + "transcript_path": "/t.jsonl", + "cwd": "/c", + }) + + assert q._agent_id == "existing-agent" + assert q._agent_type == "general-purpose" + + +class TestExports: + """Test that ToolContext and get_tool_context are properly exported.""" + + def test_tool_context_in_all(self) -> None: + import claude_agent_sdk + + assert "ToolContext" in claude_agent_sdk.__all__ + + def test_get_tool_context_in_all(self) -> None: + import claude_agent_sdk + + assert "get_tool_context" in claude_agent_sdk.__all__ + + def test_importable(self) -> None: + from claude_agent_sdk import ToolContext, get_tool_context + + assert ToolContext is not None + assert callable(get_tool_context)