diff --git a/src/kimi_cli/wire/server.py b/src/kimi_cli/wire/server.py index fecc48c83..17275c955 100644 --- a/src/kimi_cli/wire/server.py +++ b/src/kimi_cli/wire/server.py @@ -79,6 +79,8 @@ def __init__(self, soul: Soul): # soul running stuffs self._soul = soul self._cancel_event: asyncio.Event | None = None + self._turn_active: bool = False + """True only while the soul is actively running a turn (between run_soul start and return).""" self._pending_requests: dict[str, Request] = {} """Maps JSON RPC message IDs to pending `Request`s.""" self._client_supports_question: bool = False @@ -540,6 +542,7 @@ async def _handle_prompt( ) self._cancel_event = asyncio.Event() + self._turn_active = True try: runtime = self._soul.runtime if isinstance(self._soul, KimiSoul) else None await run_soul( @@ -580,6 +583,9 @@ async def _handle_prompt( result={"status": Statuses.CANCELLED}, ) finally: + # Mark the turn as inactive immediately so concurrent steer/cancel + # handlers see the correct state before _cancel_event is cleared. + self._turn_active = False # Clean up any remaining pending requests from this turn. # After run_soul() returns, the soul and all subagents are done, # so any unresolved requests are stale. @@ -609,7 +615,7 @@ async def _handle_prompt( async def _handle_steer( self, msg: JSONRPCSteerMessage ) -> JSONRPCSuccessResponse | JSONRPCErrorResponse: - if not isinstance(self._soul, KimiSoul) or not self._is_streaming: + if not isinstance(self._soul, KimiSoul) or not self._turn_active: return JSONRPCErrorResponse( id=msg.id, error=JSONRPCErrorObject( diff --git a/tests/core/test_wire_server_steer.py b/tests/core/test_wire_server_steer.py index 15a35edce..02ee1497d 100644 --- a/tests/core/test_wire_server_steer.py +++ b/tests/core/test_wire_server_steer.py @@ -66,6 +66,7 @@ async def test_handle_steer_queues_input_when_streaming( monkeypatch.setattr(soul, "steer", lambda user_input: queued.append(user_input)) server._cancel_event = asyncio.Event() + server._turn_active = True response = await server._handle_steer( JSONRPCSteerMessage( @@ -79,6 +80,31 @@ async def test_handle_steer_queues_input_when_streaming( assert queued == [[TextPart(text="follow-up")]] +@pytest.mark.asyncio +async def test_handle_steer_rejects_after_turn_ends( + runtime: Runtime, + tmp_path: Path, +) -> None: + """Steer sent after run_soul() returns but before _cancel_event is cleared should be rejected.""" + soul = _make_soul(runtime, tmp_path) + server = WireServer(soul) + + # Simulate the state after run_soul() returns but before finally cleanup: + # _cancel_event is still set (not yet None), but _turn_active is False. + server._cancel_event = asyncio.Event() + server._turn_active = False + + response = await server._handle_steer( + JSONRPCSteerMessage( + id="1", + params=JSONRPCSteerMessage.Params(user_input=[TextPart(text="too late")]), + ) + ) + + assert isinstance(response, JSONRPCErrorResponse) + assert response.error.code == ErrorCodes.INVALID_STATE + + @pytest.mark.asyncio async def test_shutdown_rejects_foreground_approval_in_runtime( runtime: Runtime,