@@ -429,6 +429,37 @@ async def _validate_accept_header(self, request: Request, scope: Scope, send: Se
429429 return False
430430 return True
431431
432+ def _validate_init_session (self , request : Request ) -> bool :
433+ """Check if an initialization request has a valid session ID."""
434+ if not self .mcp_session_id :
435+ return True
436+ request_session_id = self ._get_session_id (request )
437+ if request_session_id and request_session_id != self .mcp_session_id :
438+ return False
439+ return True
440+
441+ async def _parse_jsonrpc_body (
442+ self , body : bytes , scope : Scope , receive : Receive , send : Send
443+ ) -> JSONRPCMessage | None :
444+ """Parse request body into a JSON-RPC message, sending error responses on failure."""
445+ try :
446+ raw_message = pydantic_core .from_json (body )
447+ except ValueError as e :
448+ response = self ._create_error_response (f"Parse error: { str (e )} " , HTTPStatus .BAD_REQUEST , PARSE_ERROR )
449+ await response (scope , receive , send )
450+ return None
451+
452+ try :
453+ return jsonrpc_message_adapter .validate_python (raw_message , by_name = False )
454+ except ValidationError as e : # pragma: no cover
455+ response = self ._create_error_response (
456+ f"Validation error: { str (e )} " ,
457+ HTTPStatus .BAD_REQUEST ,
458+ INVALID_PARAMS ,
459+ )
460+ await response (scope , receive , send )
461+ return None
462+
432463 async def _handle_post_request (self , scope : Scope , request : Request , receive : Receive , send : Send ) -> None :
433464 """Handle POST requests containing JSON-RPC messages."""
434465 writer = self ._read_stream_writer
@@ -451,41 +482,21 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
451482 # Parse the body - only read it once
452483 body = await request .body ()
453484
454- try :
455- raw_message = pydantic_core .from_json (body )
456- except ValueError as e :
457- response = self ._create_error_response (f"Parse error: { str (e )} " , HTTPStatus .BAD_REQUEST , PARSE_ERROR )
458- await response (scope , receive , send )
459- return
460-
461- try :
462- message = jsonrpc_message_adapter .validate_python (raw_message , by_name = False )
463- except ValidationError as e : # pragma: no cover
464- response = self ._create_error_response (
465- f"Validation error: { str (e )} " ,
466- HTTPStatus .BAD_REQUEST ,
467- INVALID_PARAMS ,
468- )
469- await response (scope , receive , send )
485+ message = await self ._parse_jsonrpc_body (body , scope , receive , send )
486+ if message is None :
470487 return
471488
472489 # Check if this is an initialization request
473490 is_initialization_request = isinstance (message , JSONRPCRequest ) and message .method == "initialize"
474491
475492 if is_initialization_request :
476- # Check if the server already has an established session
477- if self .mcp_session_id :
478- # Check if request has a session ID
479- request_session_id = self ._get_session_id (request )
480-
481- # If request has a session ID but doesn't match, return 404
482- if request_session_id and request_session_id != self .mcp_session_id : # pragma: no cover
483- response = self ._create_error_response (
484- "Not Found: Invalid or expired session ID" ,
485- HTTPStatus .NOT_FOUND ,
486- )
487- await response (scope , receive , send )
488- return
493+ if not self ._validate_init_session (request ): # pragma: no cover
494+ response = self ._create_error_response (
495+ "Not Found: Invalid or expired session ID" ,
496+ HTTPStatus .NOT_FOUND ,
497+ )
498+ await response (scope , receive , send )
499+ return
489500 elif not await self ._validate_request_headers (request , send ): # pragma: no cover
490501 return
491502
0 commit comments