66
77"""
88
9+ from __future__ import annotations as _annotations
10+
911import asyncio
1012import os
13+ import socketserver
1114import threading
1215import time
1316import webbrowser
1417from http .server import BaseHTTPRequestHandler , HTTPServer
15- from typing import Any
18+ from typing import Any , Callable
1619from urllib .parse import parse_qs , urlparse
1720
1821import httpx
22+ from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
1923from mcp .client .auth import OAuthClientProvider , TokenStorage
2024from mcp .client .session import ClientSession
2125from mcp .client .sse import sse_client
2226from mcp .client .streamable_http import streamable_http_client
2327from mcp .shared .auth import OAuthClientInformationFull , OAuthClientMetadata , OAuthToken
28+ from mcp .shared .message import SessionMessage
2429
2530
2631class InMemoryTokenStorage (TokenStorage ):
@@ -46,7 +51,13 @@ async def set_client_info(self, client_info: OAuthClientInformationFull) -> None
4651class CallbackHandler (BaseHTTPRequestHandler ):
4752 """Simple HTTP handler to capture OAuth callback."""
4853
49- def __init__ (self , request , client_address , server , callback_data ):
54+ def __init__ (
55+ self ,
56+ request : Any ,
57+ client_address : tuple [str , int ],
58+ server : socketserver .BaseServer ,
59+ callback_data : dict [str , Any ],
60+ ):
5061 """Initialize with callback data storage."""
5162 self .callback_data = callback_data
5263 super ().__init__ (request , client_address , server )
@@ -91,15 +102,14 @@ def do_GET(self):
91102 self .send_response (404 )
92103 self .end_headers ()
93104
94- def log_message (self , format , * args ):
105+ def log_message (self , format : str , * args : Any ):
95106 """Suppress default logging."""
96- pass
97107
98108
99109class CallbackServer :
100110 """Simple server to handle OAuth callbacks."""
101111
102- def __init__ (self , port = 3000 ):
112+ def __init__ (self , port : int = 3000 ):
103113 self .port = port
104114 self .server = None
105115 self .thread = None
@@ -110,7 +120,12 @@ def _create_handler_with_data(self):
110120 callback_data = self .callback_data
111121
112122 class DataCallbackHandler (CallbackHandler ):
113- def __init__ (self , request , client_address , server ):
123+ def __init__ (
124+ self ,
125+ request : BaseHTTPRequestHandler ,
126+ client_address : tuple [str , int ],
127+ server : socketserver .BaseServer ,
128+ ):
114129 super ().__init__ (request , client_address , server , callback_data )
115130
116131 return DataCallbackHandler
@@ -131,7 +146,7 @@ def stop(self):
131146 if self .thread :
132147 self .thread .join (timeout = 1 )
133148
134- def wait_for_callback (self , timeout = 300 ):
149+ def wait_for_callback (self , timeout : int = 300 ):
135150 """Wait for OAuth callback with timeout."""
136151 start_time = time .time ()
137152 while time .time () - start_time < timeout :
@@ -225,7 +240,12 @@ async def _default_redirect_handler(authorization_url: str) -> None:
225240
226241 traceback .print_exc ()
227242
228- async def _run_session (self , read_stream , write_stream , get_session_id ):
243+ async def _run_session (
244+ self ,
245+ read_stream : MemoryObjectReceiveStream [SessionMessage | Exception ],
246+ write_stream : MemoryObjectSendStream [SessionMessage ],
247+ get_session_id : Callable [[], str | None ] | None = None ,
248+ ):
229249 """Run the MCP session with the given streams."""
230250 print ("🤝 Initializing MCP session..." )
231251 async with ClientSession (read_stream , write_stream ) as session :
@@ -314,7 +334,7 @@ async def interactive_loop(self):
314334 continue
315335
316336 # Parse arguments (simple JSON-like format)
317- arguments = {}
337+ arguments : dict [ str , Any ] = {}
318338 if len (parts ) > 2 :
319339 import json
320340
0 commit comments