Skip to content

Commit e00f990

Browse files
committed
feat: expand InitializationState with explicit lifecycle state machine
Expand the InitializationState enum with new states (Stateless, Closing, Closed, Error) and add centralized transition validation via a _VALID_TRANSITIONS table and _transition_state() method. Key changes: - Add Stateless state for sessions that skip the initialization handshake - Add Closing/Closed states for orderly shutdown tracking - Add Error state for unrecoverable failures with recovery paths - Add _transition_state() method that validates transitions against a table - Add initialization_state property (read-only) and is_initialized property - Override __aexit__ to transition through Closing -> Closed on exit - Update _received_request and _received_notification to use new APIs - Add comprehensive test suite (20 tests) covering all state transitions Github-Issue: #1691
1 parent cf4e435 commit e00f990

2 files changed

Lines changed: 421 additions & 9 deletions

File tree

src/mcp/server/session.py

Lines changed: 118 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult:
2828
be instantiated directly by users of the MCP framework.
2929
"""
3030

31+
import logging
3132
from enum import Enum
33+
from types import TracebackType
3234
from typing import Any, TypeVar, overload
3335

3436
import anyio
@@ -51,11 +53,59 @@ async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult:
5153
)
5254
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
5355

56+
logger = logging.getLogger(__name__)
57+
5458

5559
class InitializationState(Enum):
60+
"""Represents the lifecycle states of a server session.
61+
62+
State transitions:
63+
NotInitialized -> Initializing -> Initialized -> Closing -> Closed
64+
Stateless -> Closing -> Closed
65+
Any state -> Error (on unrecoverable failure)
66+
"""
67+
5668
NotInitialized = 1
5769
Initializing = 2
5870
Initialized = 3
71+
Stateless = 4
72+
Closing = 5
73+
Closed = 6
74+
Error = 7
75+
76+
77+
# Valid state transitions: maps each state to the set of states it can transition to.
78+
_VALID_TRANSITIONS: dict[InitializationState, set[InitializationState]] = {
79+
InitializationState.NotInitialized: {
80+
InitializationState.Initializing,
81+
InitializationState.Initialized, # client may send notification without prior request
82+
InitializationState.Closing,
83+
InitializationState.Error,
84+
},
85+
InitializationState.Initializing: {
86+
InitializationState.Initialized,
87+
InitializationState.Closing,
88+
InitializationState.Error,
89+
},
90+
InitializationState.Initialized: {
91+
InitializationState.Initializing, # re-initialization
92+
InitializationState.Closing,
93+
InitializationState.Error,
94+
},
95+
InitializationState.Stateless: {
96+
InitializationState.Closing,
97+
InitializationState.Error,
98+
},
99+
InitializationState.Closing: {
100+
InitializationState.Closed,
101+
InitializationState.Error,
102+
},
103+
InitializationState.Closed: set(),
104+
InitializationState.Error: {
105+
InitializationState.Closing,
106+
InitializationState.Closed,
107+
},
108+
}
59109

60110

61111
ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession")
@@ -74,7 +124,7 @@ class ServerSession(
74124
types.ClientNotification,
75125
]
76126
):
77-
_initialized: InitializationState = InitializationState.NotInitialized
127+
_initialization_state: InitializationState = InitializationState.NotInitialized
78128
_client_params: types.InitializeRequestParams | None = None
79129
_experimental_features: ExperimentalServerSessionFeatures | None = None
80130

@@ -87,9 +137,7 @@ def __init__(
87137
) -> None:
88138
super().__init__(read_stream, write_stream)
89139
self._stateless = stateless
90-
self._initialization_state = (
91-
InitializationState.Initialized if stateless else InitializationState.NotInitialized
92-
)
140+
self._initialization_state = InitializationState.Stateless if stateless else InitializationState.NotInitialized
93141

94142
self._init_options = init_options
95143
self._incoming_message_stream_writer, self._incoming_message_stream_reader = anyio.create_memory_object_stream[
@@ -105,6 +153,39 @@ def _receive_request_adapter(self) -> TypeAdapter[types.ClientRequest]:
105153
def _receive_notification_adapter(self) -> TypeAdapter[types.ClientNotification]:
106154
return types.client_notification_adapter
107155

156+
@property
157+
def initialization_state(self) -> InitializationState:
158+
"""Return the current initialization state of the session."""
159+
return self._initialization_state
160+
161+
@property
162+
def is_initialized(self) -> bool:
163+
"""Check whether the session is ready to process requests.
164+
165+
Returns True when the session has completed initialization handshake
166+
(Initialized) or is operating in stateless mode (Stateless).
167+
"""
168+
return self._initialization_state in (
169+
InitializationState.Initialized,
170+
InitializationState.Stateless,
171+
)
172+
173+
def _transition_state(self, new_state: InitializationState) -> None:
174+
"""Transition the session to a new state, validating the transition.
175+
176+
Args:
177+
new_state: The target state to transition to.
178+
179+
Raises:
180+
RuntimeError: If the transition is not valid from the current state.
181+
"""
182+
current = self._initialization_state
183+
valid_targets = _VALID_TRANSITIONS.get(current, set())
184+
if new_state not in valid_targets:
185+
raise RuntimeError(f"Invalid session state transition: {current.name} -> {new_state.name}")
186+
logger.debug("Session state transition: %s -> %s", current.name, new_state.name)
187+
self._initialization_state = new_state
188+
108189
@property
109190
def client_params(self) -> types.InitializeRequestParams | None:
110191
return self._client_params
@@ -162,11 +243,34 @@ async def _receive_loop(self) -> None:
162243
async with self._incoming_message_stream_writer:
163244
await super()._receive_loop()
164245

246+
async def __aexit__(
247+
self,
248+
exc_type: type[BaseException] | None,
249+
exc_val: BaseException | None,
250+
exc_tb: TracebackType | None,
251+
) -> bool | None:
252+
"""Clean up the session with proper state transitions."""
253+
try:
254+
if self._initialization_state not in (
255+
InitializationState.Closed,
256+
InitializationState.Closing,
257+
):
258+
self._transition_state(InitializationState.Closing)
259+
except RuntimeError:
260+
logger.debug("Could not transition to Closing from %s", self._initialization_state.name)
261+
try:
262+
return await super().__aexit__(exc_type, exc_val, exc_tb)
263+
finally:
264+
try:
265+
self._transition_state(InitializationState.Closed)
266+
except RuntimeError:
267+
logger.debug("Could not transition to Closed from %s", self._initialization_state.name)
268+
165269
async def _received_request(self, responder: RequestResponder[types.ClientRequest, types.ServerResult]):
166270
match responder.request:
167271
case types.InitializeRequest(params=params):
168272
requested_version = params.protocol_version
169-
self._initialization_state = InitializationState.Initializing
273+
self._transition_state(InitializationState.Initializing)
170274
self._client_params = params
171275
with responder:
172276
await responder.respond(
@@ -186,22 +290,27 @@ async def _received_request(self, responder: RequestResponder[types.ClientReques
186290
instructions=self._init_options.instructions,
187291
)
188292
)
189-
self._initialization_state = InitializationState.Initialized
293+
self._transition_state(InitializationState.Initialized)
190294
case types.PingRequest():
191295
# Ping requests are allowed at any time
192296
pass
193297
case _:
194-
if self._initialization_state != InitializationState.Initialized:
298+
if not self.is_initialized:
195299
raise RuntimeError("Received request before initialization was complete")
196300

197301
async def _received_notification(self, notification: types.ClientNotification) -> None:
198302
# Need this to avoid ASYNC910
199303
await anyio.lowlevel.checkpoint()
200304
match notification:
201305
case types.InitializedNotification():
202-
self._initialization_state = InitializationState.Initialized
306+
# Transition to Initialized if not already there (e.g. stateless mode)
307+
if self._initialization_state in (
308+
InitializationState.NotInitialized,
309+
InitializationState.Initializing,
310+
):
311+
self._transition_state(InitializationState.Initialized)
203312
case _:
204-
if self._initialization_state != InitializationState.Initialized: # pragma: no cover
313+
if not self.is_initialized: # pragma: no cover
205314
raise RuntimeError("Received notification before initialization was complete")
206315

207316
async def send_log_message(

0 commit comments

Comments
 (0)