Skip to content

Commit 4f7b898

Browse files
committed
feat: add public setter methods for ClientSession callbacks
Add set_sampling_callback(), set_elicitation_callback(), and set_list_roots_callback() methods to ClientSession, allowing callbacks to be updated at runtime after initialization without mutating private attributes directly. Also removes the # pragma: no cover from _default_elicitation_callback and adds coverage via the new test for set_elicitation_callback(None). Reported-by: dgenio Github-Issue: #2379
1 parent cf4e435 commit 4f7b898

File tree

4 files changed

+177
-1
lines changed

4 files changed

+177
-1
lines changed

src/mcp/client/session.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ async def _default_elicitation_callback(
7474
context: RequestContext[ClientSession],
7575
params: types.ElicitRequestParams,
7676
) -> types.ElicitResult | types.ErrorData:
77-
return types.ErrorData( # pragma: no cover
77+
return types.ErrorData(
7878
code=types.INVALID_REQUEST,
7979
message="Elicitation not supported",
8080
)
@@ -216,6 +216,48 @@ def experimental(self) -> ExperimentalClientFeatures:
216216
self._experimental_features = ExperimentalClientFeatures(self)
217217
return self._experimental_features
218218

219+
def set_sampling_callback(self, callback: SamplingFnT | None) -> None:
220+
"""Update the sampling callback.
221+
222+
Note: Client capabilities are advertised to the server during :meth:`initialize`
223+
and will not be re-negotiated when this setter is called. If a sampling
224+
callback is set after initialization, the server may not be aware of the
225+
capability.
226+
227+
Args:
228+
callback: The new sampling callback, or ``None`` to restore the default
229+
(which rejects all sampling requests with an error).
230+
"""
231+
self._sampling_callback = callback or _default_sampling_callback
232+
233+
def set_elicitation_callback(self, callback: ElicitationFnT | None) -> None:
234+
"""Update the elicitation callback.
235+
236+
Note: Client capabilities are advertised to the server during :meth:`initialize`
237+
and will not be re-negotiated when this setter is called. If an elicitation
238+
callback is set after initialization, the server may not be aware of the
239+
capability.
240+
241+
Args:
242+
callback: The new elicitation callback, or ``None`` to restore the default
243+
(which rejects all elicitation requests with an error).
244+
"""
245+
self._elicitation_callback = callback or _default_elicitation_callback
246+
247+
def set_list_roots_callback(self, callback: ListRootsFnT | None) -> None:
248+
"""Update the list roots callback.
249+
250+
Note: Client capabilities are advertised to the server during :meth:`initialize`
251+
and will not be re-negotiated when this setter is called. If a list-roots
252+
callback is set after initialization, the server may not be aware of the
253+
capability.
254+
255+
Args:
256+
callback: The new list roots callback, or ``None`` to restore the default
257+
(which rejects all list-roots requests with an error).
258+
"""
259+
self._list_roots_callback = callback or _default_list_roots_callback
260+
219261
async def send_ping(self, *, meta: RequestParamsMeta | None = None) -> types.EmptyResult:
220262
"""Send a ping request."""
221263
return await self.send_request(types.PingRequest(params=types.RequestParams(_meta=meta)), types.EmptyResult)
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from __future__ import annotations
2+
3+
import pytest
4+
from pydantic import BaseModel, Field
5+
6+
from mcp import Client
7+
from mcp.client.session import ClientSession
8+
from mcp.server.mcpserver import Context, MCPServer
9+
from mcp.shared._context import RequestContext
10+
from mcp.types import ElicitRequestParams, ElicitResult, TextContent
11+
12+
13+
class AnswerSchema(BaseModel):
14+
answer: str = Field(description="The user's answer")
15+
16+
17+
@pytest.mark.anyio
18+
async def test_set_elicitation_callback():
19+
server = MCPServer("test")
20+
21+
updated_answer = "Updated answer"
22+
23+
async def updated_callback(
24+
context: RequestContext[ClientSession],
25+
params: ElicitRequestParams,
26+
) -> ElicitResult:
27+
return ElicitResult(action="accept", content={"answer": updated_answer})
28+
29+
@server.tool("ask")
30+
async def ask(prompt: str, ctx: Context) -> str:
31+
result = await ctx.elicit(message=prompt, schema=AnswerSchema)
32+
if result.action == "accept" and result.data:
33+
return result.data.answer
34+
return "no answer" # pragma: no cover
35+
36+
async with Client(server) as client:
37+
# Before setting callback — default rejects with error
38+
result = await client.call_tool("ask", {"prompt": "question?"})
39+
assert result.is_error is True
40+
41+
# Set new callback — should succeed
42+
client.session.set_elicitation_callback(updated_callback)
43+
result = await client.call_tool("ask", {"prompt": "question?"})
44+
assert result.is_error is False
45+
assert isinstance(result.content[0], TextContent)
46+
assert result.content[0].text == updated_answer
47+
48+
# Reset to None — back to default error
49+
client.session.set_elicitation_callback(None)
50+
result = await client.call_tool("ask", {"prompt": "question?"})
51+
assert result.is_error is True

tests/client/test_list_roots_callback.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,42 @@ async def test_list_roots(context: Context, message: str):
4545
assert result.is_error is True
4646
assert isinstance(result.content[0], TextContent)
4747
assert result.content[0].text == "Error executing tool test_list_roots: List roots not supported"
48+
49+
50+
@pytest.mark.anyio
51+
async def test_set_list_roots_callback():
52+
server = MCPServer("test")
53+
54+
updated_result = ListRootsResult(
55+
roots=[
56+
Root(uri=FileUrl("file://users/fake/updated"), name="Updated Root"),
57+
]
58+
)
59+
60+
async def updated_callback(
61+
context: RequestContext[ClientSession],
62+
) -> ListRootsResult:
63+
return updated_result
64+
65+
@server.tool("get_roots")
66+
async def get_roots(context: Context, param: str) -> bool:
67+
roots = await context.session.list_roots()
68+
assert roots == updated_result
69+
return True
70+
71+
async with Client(server) as client:
72+
# Before setting callback — default rejects with error
73+
result = await client.call_tool("get_roots", {"param": "x"})
74+
assert result.is_error is True
75+
76+
# Set new callback — should succeed
77+
client.session.set_list_roots_callback(updated_callback)
78+
result = await client.call_tool("get_roots", {"param": "x"})
79+
assert result.is_error is False
80+
assert isinstance(result.content[0], TextContent)
81+
assert result.content[0].text == "true"
82+
83+
# Reset to None — back to default error
84+
client.session.set_list_roots_callback(None)
85+
result = await client.call_tool("get_roots", {"param": "x"})
86+
assert result.is_error is True

tests/client/test_sampling_callback.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,50 @@ async def test_sampling_tool(message: str, ctx: Context) -> bool:
5757
assert result.content[0].text == "Error executing tool test_sampling: Sampling not supported"
5858

5959

60+
@pytest.mark.anyio
61+
async def test_set_sampling_callback():
62+
server = MCPServer("test")
63+
64+
updated_return = CreateMessageResult(
65+
role="assistant",
66+
content=TextContent(type="text", text="Updated response"),
67+
model="updated-model",
68+
stop_reason="endTurn",
69+
)
70+
71+
async def updated_callback(
72+
context: RequestContext[ClientSession],
73+
params: CreateMessageRequestParams,
74+
) -> CreateMessageResult:
75+
return updated_return
76+
77+
@server.tool("do_sample")
78+
async def do_sample(message: str, ctx: Context) -> bool:
79+
value = await ctx.session.create_message(
80+
messages=[SamplingMessage(role="user", content=TextContent(type="text", text=message))],
81+
max_tokens=100,
82+
)
83+
assert value == updated_return
84+
return True
85+
86+
async with Client(server) as client:
87+
# Before setting callback — default rejects with error
88+
result = await client.call_tool("do_sample", {"message": "test"})
89+
assert result.is_error is True
90+
91+
# Set new callback — should succeed
92+
client.session.set_sampling_callback(updated_callback)
93+
result = await client.call_tool("do_sample", {"message": "test"})
94+
assert result.is_error is False
95+
assert isinstance(result.content[0], TextContent)
96+
assert result.content[0].text == "true"
97+
98+
# Reset to None — back to default error
99+
client.session.set_sampling_callback(None)
100+
result = await client.call_tool("do_sample", {"message": "test"})
101+
assert result.is_error is True
102+
103+
60104
@pytest.mark.anyio
61105
async def test_create_message_backwards_compat_single_content():
62106
"""Test backwards compatibility: create_message without tools returns single content."""

0 commit comments

Comments
 (0)