Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions src/claude_agent_sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@
ThinkingConfigAdaptive,
ThinkingConfigDisabled,
ThinkingConfigEnabled,
ToolContext,
ToolPermissionContext,
ToolResultBlock,
ToolUseBlock,
Expand All @@ -120,6 +121,32 @@
T = TypeVar("T")


def get_tool_context() -> ToolContext | None:
"""Get the current tool execution context, if available.

Returns the session context (session ID, transcript path, working
directory, etc.) when called from within an SDK MCP tool handler
registered via :func:`tool` / :func:`create_sdk_mcp_server`.

Returns ``None`` if called outside a tool execution or if session
info hasn't been received yet (e.g. no hooks have fired before the
first tool call).

Example::

@tool("my_tool", "Does something", {"query": str})
async def my_tool(args):
ctx = get_tool_context()
if ctx:
history = ctx.get_conversation_history()
# ... use history ...
return {"content": [{"type": "text", "text": "done"}]}
"""
from ._internal._tool_context import _current_tool_context

return _current_tool_context.get()


@dataclass
class SdkMcpTool(Generic[T]):
"""Definition for an SDK MCP tool."""
Expand Down Expand Up @@ -590,6 +617,9 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> Any:
"tool",
"SdkMcpTool",
"ToolAnnotations",
# Tool context
"get_tool_context",
"ToolContext",
# Errors
"ClaudeSDKError",
"CLIConnectionError",
Expand Down
13 changes: 13 additions & 0 deletions src/claude_agent_sdk/_internal/_tool_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""Internal contextvar for propagating session context to SDK MCP tool handlers."""

from __future__ import annotations

import contextvars
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from ..types import ToolContext

_current_tool_context: contextvars.ContextVar[ToolContext | None] = (
contextvars.ContextVar("_current_tool_context", default=None)
)
60 changes: 59 additions & 1 deletion src/claude_agent_sdk/_internal/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
SDKHookCallbackRequest,
ToolPermissionContext,
)
from ._tool_context import _current_tool_context
from .transport import Transport

if TYPE_CHECKING:
Expand Down Expand Up @@ -121,6 +122,13 @@ def __init__(
# Track first result for proper stream closure with SDK MCP servers
self._first_result_event = anyio.Event()

# Session context captured from hook callbacks (best-effort)
self._session_id: str | None = None
self._transcript_path: str | None = None
self._cwd: str | None = None
self._agent_id: str | None = None
self._agent_type: str | None = None

async def initialize(self) -> dict[str, Any] | None:
"""Initialize control protocol if in streaming mode.

Expand Down Expand Up @@ -261,6 +269,26 @@ async def _read_messages(self) -> None:
# Always signal end of stream
await self._message_send.send({"type": "end"})

def _capture_session_context(self, input_data: Any) -> None:
"""Extract session metadata from hook/permission input data.

The CLI sends ``session_id``, ``transcript_path``, ``cwd``, and
optionally ``agent_id``/``agent_type`` on every hook callback and
permission request. We store the latest values so they can be
propagated to SDK MCP tool handlers via the contextvar.
"""
if not isinstance(input_data, dict):
return
if "session_id" in input_data:
self._session_id = input_data["session_id"]
if "transcript_path" in input_data:
self._transcript_path = input_data["transcript_path"]
if "cwd" in input_data:
self._cwd = input_data["cwd"]
# Optional sub-agent fields
self._agent_id = input_data.get("agent_id") or self._agent_id
self._agent_type = input_data.get("agent_type") or self._agent_type

async def _handle_control_request(self, request: SDKControlRequest) -> None:
"""Handle incoming control request from CLI."""
request_id = request["request_id"]
Expand All @@ -273,6 +301,9 @@ async def _handle_control_request(self, request: SDKControlRequest) -> None:
if subtype == "can_use_tool":
permission_request: SDKControlPermissionRequest = request_data # type: ignore[assignment]
original_input = permission_request["input"]

# Capture session metadata from the permission request
self._capture_session_context(request_data.get("input"))
# Handle tool permission request
if not self.can_use_tool:
raise Exception("canUseTool callback is not provided")
Expand Down Expand Up @@ -317,6 +348,10 @@ async def _handle_control_request(self, request: SDKControlRequest) -> None:

elif subtype == "hook_callback":
hook_callback_request: SDKHookCallbackRequest = request_data # type: ignore[assignment]

# Capture session metadata from hook input
self._capture_session_context(request_data.get("input"))

# Handle hook callback
callback_id = hook_callback_request["callback_id"]
callback = self.hook_callbacks.get(callback_id)
Expand Down Expand Up @@ -520,7 +555,30 @@ async def _handle_sdk_mcp_request(
)
handler = server.request_handlers.get(CallToolRequest)
if handler:
result = await handler(call_request)
# Set tool context for the duration of the handler call
token = None
session_id = getattr(self, "_session_id", None)
transcript_path = getattr(self, "_transcript_path", None)
cwd = getattr(self, "_cwd", None)
if session_id and transcript_path and cwd:
from ..types import ToolContext

ctx = ToolContext(
session_id=session_id,
transcript_path=transcript_path,
cwd=cwd,
agent_id=getattr(self, "_agent_id", None),
agent_type=getattr(self, "_agent_type", None),
tool_use_id=params.get("_meta", {}).get("tool_use_id")
if isinstance(params.get("_meta"), dict)
else None,
)
token = _current_tool_context.set(ctx)
try:
result = await handler(call_request)
finally:
if token is not None:
_current_tool_context.reset(token)
# Convert MCP result to JSONRPC response
content = []
for item in result.root.content: # type: ignore[union-attr]
Expand Down
53 changes: 53 additions & 0 deletions src/claude_agent_sdk/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1155,6 +1155,59 @@ class SessionMessage:
parent_tool_use_id: None = None


@dataclass
class ToolContext:
"""Execution context available to SDK MCP tool handlers.

Populated automatically when the CLI sends hook callbacks containing
session metadata *before* a ``tools/call`` request. Call
:func:`get_tool_context` from inside an ``@tool`` handler to retrieve
this.

.. note::

This is a **best-effort** API. The context will be ``None`` if no
hook callbacks have fired prior to the tool call (e.g. the session
just started and no hooks were registered).

Attributes:
session_id: UUID of the active CLI session.
transcript_path: Filesystem path to the JSONL transcript file.
cwd: Working directory of the CLI process.
agent_id: Sub-agent identifier, present only inside a Task-spawned
sub-agent.
agent_type: Agent type name (e.g. ``"general-purpose"``).
tool_use_id: The ``tool_use_id`` of the current tool invocation,
if available.
"""

session_id: str
transcript_path: str
cwd: str
agent_id: str | None = None
agent_type: str | None = None
tool_use_id: str | None = None

def get_conversation_history(self) -> list["SessionMessage"]:
"""Read conversation history from the session transcript.

This performs file I/O to parse the JSONL transcript — it is an
explicit method (not a property) to make the cost visible to callers.

Returns:
List of :class:`SessionMessage` objects in chronological order.
Returns an empty list if the transcript cannot be found.
"""
from ._internal.sessions import get_session_messages

directory = (
str(Path(self.transcript_path).parent.parent)
if self.transcript_path
else None
)
return get_session_messages(self.session_id, directory=directory)


class ThinkingConfigAdaptive(TypedDict):
type: Literal["adaptive"]

Expand Down
Loading