From 9143998569c8da76d0f6e7f857c429a81b65e28e Mon Sep 17 00:00:00 2001 From: smartchoice Date: Thu, 26 Mar 2026 15:02:53 +0800 Subject: [PATCH] fix(wire): reject steer messages after turn ends to prevent silent message loss MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The WireServer used `_is_streaming` (based on `_cancel_event is not None`) to gate steer acceptance. Since `_cancel_event` is only cleared in the `finally` block of `_handle_prompt()`, there was a window after `run_soul()` returns but before cleanup where steers would be accepted with `status=steered` but never consumed — silently lost. This adds a `_turn_active` flag that tracks the actual soul execution lifecycle, set to True at turn start and False immediately when `run_soul()` returns. `_handle_steer()` now checks `_turn_active` instead of `_is_streaming`. --- src/kimi_cli/wire/server.py | 8 +++++++- tests/core/test_wire_server_steer.py | 26 ++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/src/kimi_cli/wire/server.py b/src/kimi_cli/wire/server.py index fecc48c83..17275c955 100644 --- a/src/kimi_cli/wire/server.py +++ b/src/kimi_cli/wire/server.py @@ -79,6 +79,8 @@ def __init__(self, soul: Soul): # soul running stuffs self._soul = soul self._cancel_event: asyncio.Event | None = None + self._turn_active: bool = False + """True only while the soul is actively running a turn (between run_soul start and return).""" self._pending_requests: dict[str, Request] = {} """Maps JSON RPC message IDs to pending `Request`s.""" self._client_supports_question: bool = False @@ -540,6 +542,7 @@ async def _handle_prompt( ) self._cancel_event = asyncio.Event() + self._turn_active = True try: runtime = self._soul.runtime if isinstance(self._soul, KimiSoul) else None await run_soul( @@ -580,6 +583,9 @@ async def _handle_prompt( result={"status": Statuses.CANCELLED}, ) finally: + # Mark the turn as inactive immediately so concurrent steer/cancel + # handlers see the correct state before _cancel_event is cleared. + self._turn_active = False # Clean up any remaining pending requests from this turn. # After run_soul() returns, the soul and all subagents are done, # so any unresolved requests are stale. @@ -609,7 +615,7 @@ async def _handle_prompt( async def _handle_steer( self, msg: JSONRPCSteerMessage ) -> JSONRPCSuccessResponse | JSONRPCErrorResponse: - if not isinstance(self._soul, KimiSoul) or not self._is_streaming: + if not isinstance(self._soul, KimiSoul) or not self._turn_active: return JSONRPCErrorResponse( id=msg.id, error=JSONRPCErrorObject( diff --git a/tests/core/test_wire_server_steer.py b/tests/core/test_wire_server_steer.py index 15a35edce..02ee1497d 100644 --- a/tests/core/test_wire_server_steer.py +++ b/tests/core/test_wire_server_steer.py @@ -66,6 +66,7 @@ async def test_handle_steer_queues_input_when_streaming( monkeypatch.setattr(soul, "steer", lambda user_input: queued.append(user_input)) server._cancel_event = asyncio.Event() + server._turn_active = True response = await server._handle_steer( JSONRPCSteerMessage( @@ -79,6 +80,31 @@ async def test_handle_steer_queues_input_when_streaming( assert queued == [[TextPart(text="follow-up")]] +@pytest.mark.asyncio +async def test_handle_steer_rejects_after_turn_ends( + runtime: Runtime, + tmp_path: Path, +) -> None: + """Steer sent after run_soul() returns but before _cancel_event is cleared should be rejected.""" + soul = _make_soul(runtime, tmp_path) + server = WireServer(soul) + + # Simulate the state after run_soul() returns but before finally cleanup: + # _cancel_event is still set (not yet None), but _turn_active is False. + server._cancel_event = asyncio.Event() + server._turn_active = False + + response = await server._handle_steer( + JSONRPCSteerMessage( + id="1", + params=JSONRPCSteerMessage.Params(user_input=[TextPart(text="too late")]), + ) + ) + + assert isinstance(response, JSONRPCErrorResponse) + assert response.error.code == ErrorCodes.INVALID_STATE + + @pytest.mark.asyncio async def test_shutdown_rejects_foreground_approval_in_runtime( runtime: Runtime,