@@ -28,7 +28,9 @@ async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult:
2828be instantiated directly by users of the MCP framework.
2929"""
3030
31+ import logging
3132from enum import Enum
33+ from types import TracebackType
3234from typing import Any , TypeVar , overload
3335
3436import anyio
@@ -51,11 +53,59 @@ async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult:
5153)
5254from mcp .shared .version import SUPPORTED_PROTOCOL_VERSIONS
5355
56+ logger = logging .getLogger (__name__ )
57+
5458
5559class InitializationState (Enum ):
60+ """Represents the lifecycle states of a server session.
61+
62+ State transitions:
63+ NotInitialized -> Initializing -> Initialized -> Closing -> Closed
64+ Stateless -> Closing -> Closed
65+ Any state -> Error (on unrecoverable failure)
66+ """
67+
5668 NotInitialized = 1
5769 Initializing = 2
5870 Initialized = 3
71+ Stateless = 4
72+ Closing = 5
73+ Closed = 6
74+ Error = 7
75+
76+
77+ # Valid state transitions: maps each state to the set of states it can transition to.
78+ _VALID_TRANSITIONS : dict [InitializationState , set [InitializationState ]] = {
79+ InitializationState .NotInitialized : {
80+ InitializationState .Initializing ,
81+ InitializationState .Initialized , # client may send notification without prior request
82+ InitializationState .Closing ,
83+ InitializationState .Error ,
84+ },
85+ InitializationState .Initializing : {
86+ InitializationState .Initialized ,
87+ InitializationState .Closing ,
88+ InitializationState .Error ,
89+ },
90+ InitializationState .Initialized : {
91+ InitializationState .Initializing , # re-initialization
92+ InitializationState .Closing ,
93+ InitializationState .Error ,
94+ },
95+ InitializationState .Stateless : {
96+ InitializationState .Closing ,
97+ InitializationState .Error ,
98+ },
99+ InitializationState .Closing : {
100+ InitializationState .Closed ,
101+ InitializationState .Error ,
102+ },
103+ InitializationState .Closed : set (),
104+ InitializationState .Error : {
105+ InitializationState .Closing ,
106+ InitializationState .Closed ,
107+ },
108+ }
59109
60110
61111ServerSessionT = TypeVar ("ServerSessionT" , bound = "ServerSession" )
@@ -74,7 +124,7 @@ class ServerSession(
74124 types .ClientNotification ,
75125 ]
76126):
77- _initialized : InitializationState = InitializationState .NotInitialized
127+ _initialization_state : InitializationState = InitializationState .NotInitialized
78128 _client_params : types .InitializeRequestParams | None = None
79129 _experimental_features : ExperimentalServerSessionFeatures | None = None
80130
@@ -87,9 +137,7 @@ def __init__(
87137 ) -> None :
88138 super ().__init__ (read_stream , write_stream )
89139 self ._stateless = stateless
90- self ._initialization_state = (
91- InitializationState .Initialized if stateless else InitializationState .NotInitialized
92- )
140+ self ._initialization_state = InitializationState .Stateless if stateless else InitializationState .NotInitialized
93141
94142 self ._init_options = init_options
95143 self ._incoming_message_stream_writer , self ._incoming_message_stream_reader = anyio .create_memory_object_stream [
@@ -105,6 +153,39 @@ def _receive_request_adapter(self) -> TypeAdapter[types.ClientRequest]:
105153 def _receive_notification_adapter (self ) -> TypeAdapter [types .ClientNotification ]:
106154 return types .client_notification_adapter
107155
156+ @property
157+ def initialization_state (self ) -> InitializationState :
158+ """Return the current initialization state of the session."""
159+ return self ._initialization_state
160+
161+ @property
162+ def is_initialized (self ) -> bool :
163+ """Check whether the session is ready to process requests.
164+
165+ Returns True when the session has completed initialization handshake
166+ (Initialized) or is operating in stateless mode (Stateless).
167+ """
168+ return self ._initialization_state in (
169+ InitializationState .Initialized ,
170+ InitializationState .Stateless ,
171+ )
172+
173+ def _transition_state (self , new_state : InitializationState ) -> None :
174+ """Transition the session to a new state, validating the transition.
175+
176+ Args:
177+ new_state: The target state to transition to.
178+
179+ Raises:
180+ RuntimeError: If the transition is not valid from the current state.
181+ """
182+ current = self ._initialization_state
183+ valid_targets = _VALID_TRANSITIONS .get (current , set ())
184+ if new_state not in valid_targets :
185+ raise RuntimeError (f"Invalid session state transition: { current .name } -> { new_state .name } " )
186+ logger .debug ("Session state transition: %s -> %s" , current .name , new_state .name )
187+ self ._initialization_state = new_state
188+
108189 @property
109190 def client_params (self ) -> types .InitializeRequestParams | None :
110191 return self ._client_params
@@ -162,11 +243,34 @@ async def _receive_loop(self) -> None:
162243 async with self ._incoming_message_stream_writer :
163244 await super ()._receive_loop ()
164245
246+ async def __aexit__ (
247+ self ,
248+ exc_type : type [BaseException ] | None ,
249+ exc_val : BaseException | None ,
250+ exc_tb : TracebackType | None ,
251+ ) -> bool | None :
252+ """Clean up the session with proper state transitions."""
253+ try :
254+ if self ._initialization_state not in (
255+ InitializationState .Closed ,
256+ InitializationState .Closing ,
257+ ):
258+ self ._transition_state (InitializationState .Closing )
259+ except RuntimeError :
260+ logger .debug ("Could not transition to Closing from %s" , self ._initialization_state .name )
261+ try :
262+ return await super ().__aexit__ (exc_type , exc_val , exc_tb )
263+ finally :
264+ try :
265+ self ._transition_state (InitializationState .Closed )
266+ except RuntimeError :
267+ logger .debug ("Could not transition to Closed from %s" , self ._initialization_state .name )
268+
165269 async def _received_request (self , responder : RequestResponder [types .ClientRequest , types .ServerResult ]):
166270 match responder .request :
167271 case types .InitializeRequest (params = params ):
168272 requested_version = params .protocol_version
169- self ._initialization_state = InitializationState .Initializing
273+ self ._transition_state ( InitializationState .Initializing )
170274 self ._client_params = params
171275 with responder :
172276 await responder .respond (
@@ -186,22 +290,27 @@ async def _received_request(self, responder: RequestResponder[types.ClientReques
186290 instructions = self ._init_options .instructions ,
187291 )
188292 )
189- self ._initialization_state = InitializationState .Initialized
293+ self ._transition_state ( InitializationState .Initialized )
190294 case types .PingRequest ():
191295 # Ping requests are allowed at any time
192296 pass
193297 case _:
194- if self ._initialization_state != InitializationState . Initialized :
298+ if not self .is_initialized :
195299 raise RuntimeError ("Received request before initialization was complete" )
196300
197301 async def _received_notification (self , notification : types .ClientNotification ) -> None :
198302 # Need this to avoid ASYNC910
199303 await anyio .lowlevel .checkpoint ()
200304 match notification :
201305 case types .InitializedNotification ():
202- self ._initialization_state = InitializationState .Initialized
306+ # Transition to Initialized if not already there (e.g. stateless mode)
307+ if self ._initialization_state in (
308+ InitializationState .NotInitialized ,
309+ InitializationState .Initializing ,
310+ ):
311+ self ._transition_state (InitializationState .Initialized )
203312 case _:
204- if self ._initialization_state != InitializationState . Initialized : # pragma: no cover
313+ if not self .is_initialized : # pragma: no cover
205314 raise RuntimeError ("Received notification before initialization was complete" )
206315
207316 async def send_log_message (
0 commit comments