From d470a0c4f4c0de8e0b55945e1e4f4b86ec3dd85c Mon Sep 17 00:00:00 2001 From: nic Date: Mon, 30 Mar 2026 16:16:43 +0800 Subject: [PATCH 1/6] feat(web): add embedded session runtime --- src/kimi_cli/web/app.py | 19 ++- src/kimi_cli/web/runner/embedded_process.py | 147 +++++++++++++++++++ src/kimi_cli/web/runner/embedded_worker.py | 150 ++++++++++++++++++++ src/kimi_cli/web/runner/process.py | 60 ++++---- tests/web/test_runner.py | 63 ++++++++ 5 files changed, 410 insertions(+), 29 deletions(-) create mode 100644 src/kimi_cli/web/runner/embedded_process.py create mode 100644 src/kimi_cli/web/runner/embedded_worker.py create mode 100644 tests/web/test_runner.py diff --git a/src/kimi_cli/web/app.py b/src/kimi_cli/web/app.py index 013ddfa6a..8adc2c042 100644 --- a/src/kimi_cli/web/app.py +++ b/src/kimi_cli/web/app.py @@ -7,7 +7,7 @@ from collections.abc import Callable from contextlib import asynccontextmanager from pathlib import Path -from typing import Any, cast +from typing import Any, Literal, cast from urllib.parse import quote import scalar_fastapi @@ -63,6 +63,8 @@ ENV_ENFORCE_ORIGIN = "KIMI_WEB_ENFORCE_ORIGIN" ENV_RESTRICT_SENSITIVE_APIS = "KIMI_WEB_RESTRICT_SENSITIVE_APIS" ENV_MAX_PUBLIC_PATH_DEPTH = "KIMI_WEB_MAX_PUBLIC_PATH_DEPTH" +ENV_RUNTIME = "KIMI_WEB_RUNTIME" +RuntimeMode = Literal["process", "embedded"] # Cache durations _IMMUTABLE_MAX_AGE = 365 * 24 * 3600 # 1 year for content-hashed assets @@ -109,6 +111,18 @@ def _load_env_flag(key: str) -> bool: return os.environ.get(key, "").strip().lower() in {"1", "true", "yes", "on"} +def _load_runtime_mode() -> RuntimeMode: + runtime = os.environ.get(ENV_RUNTIME, "process").strip().lower() or "process" + if runtime not in {"process", "embedded"}: + logger.warning( + "Invalid {env}={value}, falling back to process", + env=ENV_RUNTIME, + value=runtime, + ) + return "process" + return cast(RuntimeMode, runtime) + + ENV_LAN_ONLY = "KIMI_WEB_LAN_ONLY" @@ -131,6 +145,7 @@ def create_app( int(env_max_depth_str) if env_max_depth_str and env_max_depth_str.isdigit() else None ) env_lan_only = _load_env_flag(ENV_LAN_ONLY) + env_runtime = _load_runtime_mode() session_token = session_token if session_token is not None else env_token allowed_origins = allowed_origins if allowed_origins is not None else env_origins @@ -154,7 +169,7 @@ async def lifespan(app: FastAPI): app.state.lan_only = lan_only # Start KimiCLI runner - runner = KimiCLIRunner() + runner = KimiCLIRunner(runtime_mode=env_runtime) app.state.runner = runner runner.start() diff --git a/src/kimi_cli/web/runner/embedded_process.py b/src/kimi_cli/web/runner/embedded_process.py new file mode 100644 index 000000000..957782fef --- /dev/null +++ b/src/kimi_cli/web/runner/embedded_process.py @@ -0,0 +1,147 @@ +"""Embedded session process for the Kimi CLI web interface.""" + +from __future__ import annotations + +import json +import time +from typing import Any +from uuid import UUID, uuid4 + +from kimi_cli import logger +from kimi_cli.app import KimiCLI +from kimi_cli.cli.mcp import get_global_mcp_config_file +from kimi_cli.exception import MCPConfigError +from kimi_cli.web.runner.embedded_worker import EmbeddedWireWorker +from kimi_cli.web.runner.process import SessionProcess +from kimi_cli.web.store.sessions import load_session_by_id +from kimi_cli.wire.jsonrpc import ( + JSONRPCCancelMessage, + JSONRPCInMessageAdapter, + JSONRPCPromptMessage, + JSONRPCSuccessResponse, +) + + +async def _create_kimi_cli_for_session(session_id: UUID) -> KimiCLI: + joint_session = load_session_by_id(session_id) + if joint_session is None: + raise ValueError(f"Session not found: {session_id}") + + session = joint_session.kimi_cli_session + + default_mcp_file = get_global_mcp_config_file() + mcp_configs: list[dict[str, Any]] = [] + if default_mcp_file.exists(): + raw = default_mcp_file.read_text(encoding="utf-8") + try: + mcp_configs = [json.loads(raw)] + except json.JSONDecodeError: + logger.warning( + "Invalid JSON in MCP config file: {path}", + path=default_mcp_file, + ) + + try: + return await KimiCLI.create(session, mcp_configs=mcp_configs or None) + except MCPConfigError as exc: + logger.warning( + "Invalid MCP config in {path}: {error}. Starting without MCP.", + path=default_mcp_file, + error=exc, + ) + return await KimiCLI.create(session, mcp_configs=None) + + +class EmbeddedSessionProcess(SessionProcess): + """Manage one session using an in-process wire worker.""" + + def __init__(self, session_id: UUID) -> None: + super().__init__(session_id) + self._worker: EmbeddedWireWorker | None = None + + @property + def is_alive(self) -> bool: + return self._worker is not None + + async def start( + self, + *, + reason: str | None = None, + detail: str | None = None, + restart_started_at: float | None = None, + ) -> None: + """Start the embedded worker.""" + async with self._lock: + if self.is_alive: + return + + self._in_flight_prompt_ids.clear() + self._worker_id = str(uuid4()) + + try: + kimi_cli = await _create_kimi_cli_for_session(self.session_id) + worker = EmbeddedWireWorker( + kimi_cli, + emit_json=self._process_worker_output_line, + ) + await worker.start() + except Exception: + self._worker_id = None + raise + + self._worker = worker + + if restart_started_at is not None: + elapsed_ms = int((time.perf_counter() - restart_started_at) * 1000) + detail = f"restart_ms={elapsed_ms}" + await self._emit_status("idle", reason=reason or "start", detail=detail) + await self._emit_restart_notice(reason=reason, restart_ms=elapsed_ms) + else: + await self._emit_status("idle", reason=reason or "start", detail=None) + + async def stop_worker( + self, + *, + reason: str | None = None, + emit_status: bool = True, + ) -> None: + """Stop only the embedded worker, keeping WebSockets connected.""" + async with self._lock: + worker = self._worker + self._worker = None + + if worker is not None: + await worker.stop() + + self._in_flight_prompt_ids.clear() + self._worker_id = None + if emit_status: + await self._emit_status("stopped", reason=reason or "stop") + + async def send_message(self, message: str) -> None: + """Send a message to the embedded worker.""" + await self.start() + worker = self._worker + assert worker is not None + + try: + in_message = JSONRPCInMessageAdapter.validate_json(message) + if isinstance(in_message, JSONRPCPromptMessage): + was_busy = self.is_busy + self._in_flight_prompt_ids.add(in_message.id) + if not was_busy: + await self._emit_status("busy", reason="prompt") + elif isinstance(in_message, JSONRPCCancelMessage) and not self.is_busy: + await self._broadcast( + JSONRPCSuccessResponse(id=in_message.id, result={}).model_dump_json() + ) + return + + new_message = await self._handle_in_message(in_message) + if new_message is not None: + message = new_message + except ValueError as e: + logger.error(f"{e.__class__.__name__} {e}: Invalid JSONRPC in message: {message}") + return + + await worker.handle_message(message) diff --git a/src/kimi_cli/web/runner/embedded_worker.py b/src/kimi_cli/web/runner/embedded_worker.py new file mode 100644 index 000000000..31645eadf --- /dev/null +++ b/src/kimi_cli/web/runner/embedded_worker.py @@ -0,0 +1,150 @@ +"""In-process wire worker for the Kimi CLI web interface.""" + +from __future__ import annotations + +import asyncio +import json +from collections.abc import Awaitable, Callable +from typing import Any + +import pydantic + +from kimi_cli.app import KimiCLI +from kimi_cli.soul.kimisoul import KimiSoul +from kimi_cli.wire.jsonrpc import ( + ErrorCodes, + JSONRPCErrorObject, + JSONRPCErrorResponse, + JSONRPCErrorResponseNullableID, + JSONRPCInMessageAdapter, + JSONRPCMessage, + JSONRPCOutMessage, +) +from kimi_cli.wire.server import WireServer + +_PROMPT_ENV_LOCK = asyncio.Lock() + + +class EmbeddedWireWorker(WireServer): + """Run the wire protocol against a local ``KimiCLI`` instance.""" + + def __init__( + self, + kimi_cli: KimiCLI, + *, + emit_json: Callable[[str], Awaitable[None]], + ) -> None: + super().__init__(kimi_cli.soul) + self._kimi_cli = kimi_cli + self._emit_json = emit_json + self._emit_lock = asyncio.Lock() + + async def start(self) -> None: + """Start the in-process worker.""" + if isinstance(self._soul, KimiSoul) and self._root_hub_task is None: + runtime = self._kimi_cli.soul.runtime + if runtime.root_wire_hub is not None: + self._root_hub_queue = runtime.root_wire_hub.subscribe() + self._root_hub_task = asyncio.create_task(self._root_hub_loop()) + + async def stop(self) -> None: + """Stop the in-process worker.""" + await self._shutdown() + + async def handle_message(self, message: str) -> None: + """Handle a JSON-RPC message from the web client.""" + try: + msg_json = json.loads(message) + except ValueError: + await self._emit_out_message( + JSONRPCErrorResponseNullableID( + id=None, + error=JSONRPCErrorObject( + code=ErrorCodes.PARSE_ERROR, + message="Invalid JSON format", + ), + ) + ) + return + + try: + generic_msg = JSONRPCMessage.model_validate(msg_json) + except pydantic.ValidationError: + await self._emit_out_message( + JSONRPCErrorResponseNullableID( + id=None, + error=JSONRPCErrorObject( + code=ErrorCodes.INVALID_REQUEST, + message="Invalid request", + ), + ) + ) + return + + if generic_msg.is_response(): + try: + msg = JSONRPCInMessageAdapter.validate_python(msg_json) + except pydantic.ValidationError: + await self._emit_out_message( + JSONRPCErrorResponseNullableID( + id=None, + error=JSONRPCErrorObject( + code=ErrorCodes.INVALID_REQUEST, + message="Invalid response", + ), + ) + ) + return + self._dispatch(msg) + return + + if not generic_msg.method_is_inbound(): + if generic_msg.id is not None: + await self._emit_out_message( + JSONRPCErrorResponse( + id=generic_msg.id, + error=JSONRPCErrorObject( + code=ErrorCodes.METHOD_NOT_FOUND, + message=f"Unexpected method received: {generic_msg.method}", + ), + ) + ) + return + + try: + msg = JSONRPCInMessageAdapter.validate_python(msg_json) + except pydantic.ValidationError: + if generic_msg.id is not None: + await self._emit_out_message( + JSONRPCErrorResponse( + id=generic_msg.id, + error=JSONRPCErrorObject( + code=ErrorCodes.INVALID_PARAMS, + message=f"Invalid parameters for method `{generic_msg.method}`", + ), + ) + ) + return + + self._dispatch(msg) + + def _dispatch(self, msg: Any) -> None: + task = asyncio.create_task(self._dispatch_msg(msg)) + task.add_done_callback(self._dispatch_tasks.discard) + self._dispatch_tasks.add(task) + + async def _send_msg(self, msg: JSONRPCOutMessage) -> None: + await self._emit_out_message(msg) + + async def _emit_out_message(self, msg: Any) -> None: + payload = msg.model_dump_json() + async with self._emit_lock: + await self._emit_json(payload) + + async def _handle_prompt(self, msg): # type: ignore[override] + # ``KimiCLI.run_wire_stdio()`` normally keeps the entire wire server inside + # ``KimiCLI._env()``. For the embedded worker we scope that environment to + # each foreground turn and serialize embedded prompts to avoid cross-session + # cwd races from ``kaos.chdir()``. + async with _PROMPT_ENV_LOCK, self._kimi_cli._env(): + return await super()._handle_prompt(msg) diff --git a/src/kimi_cli/web/runner/process.py b/src/kimi_cli/web/runner/process.py index 91bfe7eb0..805bd100d 100644 --- a/src/kimi_cli/web/runner/process.py +++ b/src/kimi_cli/web/runner/process.py @@ -14,6 +14,7 @@ from dataclasses import dataclass from datetime import UTC, datetime from pathlib import Path +from typing import Literal from uuid import UUID, uuid4 from kosong.message import ContentPart, ImageURLPart, TextPart @@ -325,37 +326,34 @@ async def _read_loop(self) -> None: else: continue - await self._broadcast(line.decode("utf-8").rstrip("\n")) - - # Handle out message - try: - msg = json.loads(line) - match msg.get("method"): - case "event": - msg["params"] = deserialize_wire_message(msg["params"]) - await self._handle_out_message(JSONRPCEventMessage.model_validate(msg)) - case "request": - msg["params"] = deserialize_wire_message(msg["params"]) - await self._handle_out_message( - JSONRPCRequestMessage.model_validate(msg) - ) - case _: - if msg.get("error"): - await self._handle_out_message( - JSONRPCErrorResponse.model_validate(msg) - ) - else: - await self._handle_out_message( - JSONRPCSuccessResponse.model_validate(msg) - ) - except json.JSONDecodeError: - logger.error(f"Invalid JSONRPC out message: {line}") + await self._process_worker_output_line(line.decode("utf-8").rstrip("\n")) except asyncio.CancelledError: raise except Exception as e: logger.warning(f"Unexpected error in read loop: {e.__class__.__name__} {e}") + async def _process_worker_output_line(self, message: str) -> None: + """Broadcast and process one JSON-RPC line from a worker.""" + await self._broadcast(message) + + try: + msg = json.loads(message) + match msg.get("method"): + case "event": + msg["params"] = deserialize_wire_message(msg["params"]) + await self._handle_out_message(JSONRPCEventMessage.model_validate(msg)) + case "request": + msg["params"] = deserialize_wire_message(msg["params"]) + await self._handle_out_message(JSONRPCRequestMessage.model_validate(msg)) + case _: + if msg.get("error"): + await self._handle_out_message(JSONRPCErrorResponse.model_validate(msg)) + else: + await self._handle_out_message(JSONRPCSuccessResponse.model_validate(msg)) + except json.JSONDecodeError: + logger.error(f"Invalid JSONRPC out message: {message}") + async def _handle_out_message(self, message: JSONRPCOutMessage) -> None: """Handle outbound message from worker.""" match message: @@ -659,8 +657,9 @@ async def send_message(self, message: str) -> None: class KimiCLIRunner: """Manages multiple session processes.""" - def __init__(self) -> None: + def __init__(self, *, runtime_mode: Literal["process", "embedded"] = "process") -> None: """Initialize the runner.""" + self._runtime_mode = runtime_mode self._sessions: dict[UUID, SessionProcess] = {} self._lock = asyncio.Lock() @@ -685,9 +684,16 @@ async def get_or_create_session(self, session_id: UUID) -> SessionProcess: """Get or create a session process.""" async with self._lock: if session_id not in self._sessions: - self._sessions[session_id] = SessionProcess(session_id) + self._sessions[session_id] = self._create_session_process(session_id) return self._sessions[session_id] + def _create_session_process(self, session_id: UUID) -> SessionProcess: + if self._runtime_mode == "embedded": + from kimi_cli.web.runner.embedded_process import EmbeddedSessionProcess + + return EmbeddedSessionProcess(session_id) + return SessionProcess(session_id) + def get_session(self, session_id: UUID) -> SessionProcess | None: """Get a session process if it exists.""" return self._sessions.get(session_id) diff --git a/tests/web/test_runner.py b/tests/web/test_runner.py new file mode 100644 index 000000000..ab9956d98 --- /dev/null +++ b/tests/web/test_runner.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from uuid import uuid4 + +from fastapi.testclient import TestClient + +from kimi_cli.web import app as web_app +from kimi_cli.web.runner.embedded_process import EmbeddedSessionProcess +from kimi_cli.web.runner.process import KimiCLIRunner + + +def test_runner_creates_embedded_session_process() -> None: + runner = KimiCLIRunner(runtime_mode="embedded") + + session_process = runner._create_session_process(uuid4()) + + assert isinstance(session_process, EmbeddedSessionProcess) + + +def test_create_app_uses_embedded_runtime_from_env(monkeypatch) -> None: + captured: dict[str, str] = {} + + class FakeRunner: + def __init__(self, *, runtime_mode: str) -> None: + captured["runtime_mode"] = runtime_mode + + def start(self) -> None: + pass + + async def stop(self) -> None: + pass + + monkeypatch.setenv(web_app.ENV_RUNTIME, "embedded") + monkeypatch.setattr(web_app, "KimiCLIRunner", FakeRunner) + + with TestClient(web_app.create_app()) as client: + response = client.get("/healthz") + + assert response.status_code == 200 + assert captured["runtime_mode"] == "embedded" + + +def test_create_app_invalid_runtime_falls_back_to_process(monkeypatch) -> None: + captured: dict[str, str] = {} + + class FakeRunner: + def __init__(self, *, runtime_mode: str) -> None: + captured["runtime_mode"] = runtime_mode + + def start(self) -> None: + pass + + async def stop(self) -> None: + pass + + monkeypatch.setenv(web_app.ENV_RUNTIME, "invalid") + monkeypatch.setattr(web_app, "KimiCLIRunner", FakeRunner) + + with TestClient(web_app.create_app()) as client: + response = client.get("/healthz") + + assert response.status_code == 200 + assert captured["runtime_mode"] == "process" From 47bc2f699d3f83740bdedca728503dd8e99d02ab Mon Sep 17 00:00:00 2001 From: nic Date: Mon, 30 Mar 2026 17:05:40 +0800 Subject: [PATCH 2/6] feat(web): default web runtime to embedded --- src/kimi_cli/web/app.py | 6 +++--- src/kimi_cli/web/runner/process.py | 2 +- tests/web/test_runner.py | 30 ++++++++++++++++++++++++++---- 3 files changed, 30 insertions(+), 8 deletions(-) diff --git a/src/kimi_cli/web/app.py b/src/kimi_cli/web/app.py index 8adc2c042..61daf818a 100644 --- a/src/kimi_cli/web/app.py +++ b/src/kimi_cli/web/app.py @@ -112,14 +112,14 @@ def _load_env_flag(key: str) -> bool: def _load_runtime_mode() -> RuntimeMode: - runtime = os.environ.get(ENV_RUNTIME, "process").strip().lower() or "process" + runtime = os.environ.get(ENV_RUNTIME, "embedded").strip().lower() or "embedded" if runtime not in {"process", "embedded"}: logger.warning( - "Invalid {env}={value}, falling back to process", + "Invalid {env}={value}, falling back to embedded", env=ENV_RUNTIME, value=runtime, ) - return "process" + return "embedded" return cast(RuntimeMode, runtime) diff --git a/src/kimi_cli/web/runner/process.py b/src/kimi_cli/web/runner/process.py index 805bd100d..0011e8da0 100644 --- a/src/kimi_cli/web/runner/process.py +++ b/src/kimi_cli/web/runner/process.py @@ -657,7 +657,7 @@ async def send_message(self, message: str) -> None: class KimiCLIRunner: """Manages multiple session processes.""" - def __init__(self, *, runtime_mode: Literal["process", "embedded"] = "process") -> None: + def __init__(self, *, runtime_mode: Literal["process", "embedded"] = "embedded") -> None: """Initialize the runner.""" self._runtime_mode = runtime_mode self._sessions: dict[UUID, SessionProcess] = {} diff --git a/tests/web/test_runner.py b/tests/web/test_runner.py index ab9956d98..c87cfa770 100644 --- a/tests/web/test_runner.py +++ b/tests/web/test_runner.py @@ -17,7 +17,7 @@ def test_runner_creates_embedded_session_process() -> None: assert isinstance(session_process, EmbeddedSessionProcess) -def test_create_app_uses_embedded_runtime_from_env(monkeypatch) -> None: +def test_create_app_defaults_to_embedded_runtime(monkeypatch) -> None: captured: dict[str, str] = {} class FakeRunner: @@ -30,7 +30,6 @@ def start(self) -> None: async def stop(self) -> None: pass - monkeypatch.setenv(web_app.ENV_RUNTIME, "embedded") monkeypatch.setattr(web_app, "KimiCLIRunner", FakeRunner) with TestClient(web_app.create_app()) as client: @@ -40,7 +39,7 @@ async def stop(self) -> None: assert captured["runtime_mode"] == "embedded" -def test_create_app_invalid_runtime_falls_back_to_process(monkeypatch) -> None: +def test_create_app_explicit_process_runtime(monkeypatch) -> None: captured: dict[str, str] = {} class FakeRunner: @@ -53,7 +52,7 @@ def start(self) -> None: async def stop(self) -> None: pass - monkeypatch.setenv(web_app.ENV_RUNTIME, "invalid") + monkeypatch.setenv(web_app.ENV_RUNTIME, "process") monkeypatch.setattr(web_app, "KimiCLIRunner", FakeRunner) with TestClient(web_app.create_app()) as client: @@ -61,3 +60,26 @@ async def stop(self) -> None: assert response.status_code == 200 assert captured["runtime_mode"] == "process" + + +def test_create_app_invalid_runtime_falls_back_to_embedded(monkeypatch) -> None: + captured: dict[str, str] = {} + + class FakeRunner: + def __init__(self, *, runtime_mode: str) -> None: + captured["runtime_mode"] = runtime_mode + + def start(self) -> None: + pass + + async def stop(self) -> None: + pass + + monkeypatch.setenv(web_app.ENV_RUNTIME, "invalid") + monkeypatch.setattr(web_app, "KimiCLIRunner", FakeRunner) + + with TestClient(web_app.create_app()) as client: + response = client.get("/healthz") + + assert response.status_code == 200 + assert captured["runtime_mode"] == "embedded" From dd248cdf91c336d4bcf25817ca9b7f1ec4570d9e Mon Sep 17 00:00:00 2001 From: nic Date: Mon, 30 Mar 2026 17:37:11 +0800 Subject: [PATCH 3/6] fix(web): avoid protected env usage in embedded worker --- src/kimi_cli/app.py | 16 +++++++++++----- src/kimi_cli/web/runner/embedded_worker.py | 4 ++-- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/kimi_cli/app.py b/src/kimi_cli/app.py index 40465f191..28b9596eb 100644 --- a/src/kimi_cli/app.py +++ b/src/kimi_cli/app.py @@ -298,6 +298,12 @@ async def _env(self) -> AsyncGenerator[None]: finally: await kaos.chdir(original_cwd) + @contextlib.asynccontextmanager + async def env(self) -> AsyncGenerator[None]: + """Run with the session working directory and refreshed runtime auth.""" + async with self._env(): + yield + async def run( self, user_input: str | list[ContentPart], @@ -322,7 +328,7 @@ async def run( MaxStepsReached: When the maximum number of steps is reached. RunCancelled: When the run is cancelled by the cancel event. """ - async with self._env(): + async with self.env(): wire_future = asyncio.Future[WireUISide]() stop_ui_loop = asyncio.Event() approval_bridge_tasks: dict[str, asyncio.Task[None]] = {} @@ -513,7 +519,7 @@ async def run_shell( level=WelcomeInfoItem.Level.INFO, ) ) - async with self._env(): + async with self.env(): shell = Shell(self._soul, welcome_info=welcome_info, prefill_text=prefill_text) return await shell.run(command) @@ -528,7 +534,7 @@ async def run_print( """Run the Kimi Code CLI instance with print UI.""" from kimi_cli.ui.print import Print - async with self._env(): + async with self.env(): print_ = Print( self._soul, input_format, @@ -542,7 +548,7 @@ async def run_acp(self) -> None: """Run the Kimi Code CLI instance as ACP server.""" from kimi_cli.ui.acp import ACP - async with self._env(): + async with self.env(): acp = ACP(self._soul) await acp.run() @@ -550,6 +556,6 @@ async def run_wire_stdio(self) -> None: """Run the Kimi Code CLI instance as Wire server over stdio.""" from kimi_cli.wire.server import WireServer - async with self._env(): + async with self.env(): server = WireServer(self._soul) await server.serve() diff --git a/src/kimi_cli/web/runner/embedded_worker.py b/src/kimi_cli/web/runner/embedded_worker.py index 31645eadf..7bdf34425 100644 --- a/src/kimi_cli/web/runner/embedded_worker.py +++ b/src/kimi_cli/web/runner/embedded_worker.py @@ -143,8 +143,8 @@ async def _emit_out_message(self, msg: Any) -> None: async def _handle_prompt(self, msg): # type: ignore[override] # ``KimiCLI.run_wire_stdio()`` normally keeps the entire wire server inside - # ``KimiCLI._env()``. For the embedded worker we scope that environment to + # the session environment context. For the embedded worker we scope that environment to # each foreground turn and serialize embedded prompts to avoid cross-session # cwd races from ``kaos.chdir()``. - async with _PROMPT_ENV_LOCK, self._kimi_cli._env(): + async with _PROMPT_ENV_LOCK, self._kimi_cli.env(): return await super()._handle_prompt(msg) From b85cef11153a0c3cbfcc4ae203c61460770180f9 Mon Sep 17 00:00:00 2001 From: nic Date: Mon, 30 Mar 2026 18:27:09 +0800 Subject: [PATCH 4/6] fix(web): scope kaos cwd per session --- packages/kaos/src/kaos/local.py | 123 +++++++++++++++++++++ packages/kaos/tests/test_local_kaos.py | 44 +++++++- src/kimi_cli/app.py | 15 ++- src/kimi_cli/soul/kimisoul.py | 13 ++- src/kimi_cli/soul/toolset.py | 8 +- src/kimi_cli/subagents/runner.py | 5 +- src/kimi_cli/web/runner/embedded_worker.py | 8 +- 7 files changed, 191 insertions(+), 25 deletions(-) diff --git a/packages/kaos/src/kaos/local.py b/packages/kaos/src/kaos/local.py index 7ab2f7cd0..d727defca 100644 --- a/packages/kaos/src/kaos/local.py +++ b/packages/kaos/src/kaos/local.py @@ -5,6 +5,7 @@ from asyncio.subprocess import Process as AsyncioProcess from collections.abc import AsyncGenerator from pathlib import Path, PurePath +from stat import S_ISDIR from typing import TYPE_CHECKING, Literal if os.name == "nt": @@ -176,5 +177,127 @@ async def exec(self, *args: str, env: Mapping[str, str] | None = None) -> KaosPr return self.Process(process) +class ScopedLocalKaos(LocalKaos): + """Local KAOS backend with an instance-local working directory.""" + + def __init__(self, cwd: StrOrKaosPath | None = None) -> None: + base_cwd = Path.cwd() + self._cwd = self._normalize_local_path( + self._coerce_local_path(cwd) if cwd is not None else base_cwd, + base=base_cwd, + ) + + @staticmethod + def _coerce_local_path(path: StrOrKaosPath) -> Path: + return path.unsafe_to_local_path() if isinstance(path, KaosPath) else Path(path) + + def _normalize_local_path(self, path: Path, *, base: Path | None = None) -> Path: + if not path.is_absolute(): + path = (base or self._cwd) / path + return Path(pathmodule.normpath(str(path))) + + def _resolve_local_path(self, path: StrOrKaosPath) -> Path: + return self._normalize_local_path(self._coerce_local_path(path)) + + def _resolve_kaos_path(self, path: StrOrKaosPath) -> KaosPath: + return KaosPath.unsafe_from_local_path(self._resolve_local_path(path)) + + def getcwd(self) -> KaosPath: + return KaosPath.unsafe_from_local_path(self._cwd) + + async def chdir(self, path: StrOrKaosPath) -> None: + local_path = self._resolve_local_path(path) + st = await aiofiles.os.stat(local_path) + if not S_ISDIR(st.st_mode): + raise NotADirectoryError(str(local_path)) + self._cwd = local_path + + async def stat(self, path: StrOrKaosPath, *, follow_symlinks: bool = True) -> StatResult: + return await super().stat(self._resolve_kaos_path(path), follow_symlinks=follow_symlinks) + + async def iterdir(self, path: StrOrKaosPath) -> AsyncGenerator[KaosPath]: + async for entry in super().iterdir(self._resolve_kaos_path(path)): + yield entry + + async def glob( + self, path: StrOrKaosPath, pattern: str, *, case_sensitive: bool = True + ) -> AsyncGenerator[KaosPath]: + async for entry in super().glob( + self._resolve_kaos_path(path), + pattern, + case_sensitive=case_sensitive, + ): + yield entry + + async def readbytes(self, path: StrOrKaosPath, n: int | None = None) -> bytes: + return await super().readbytes(self._resolve_kaos_path(path), n=n) + + async def readtext( + self, + path: str | KaosPath, + *, + encoding: str = "utf-8", + errors: Literal["strict", "ignore", "replace"] = "strict", + ) -> str: + return await super().readtext( + self._resolve_kaos_path(path), + encoding=encoding, + errors=errors, + ) + + async def readlines( + self, + path: str | KaosPath, + *, + encoding: str = "utf-8", + errors: Literal["strict", "ignore", "replace"] = "strict", + ) -> AsyncGenerator[str]: + async for line in super().readlines( + self._resolve_kaos_path(path), + encoding=encoding, + errors=errors, + ): + yield line + + async def writebytes(self, path: StrOrKaosPath, data: bytes) -> int: + return await super().writebytes(self._resolve_kaos_path(path), data) + + async def writetext( + self, + path: str | KaosPath, + data: str, + *, + mode: Literal["w"] | Literal["a"] = "w", + encoding: str = "utf-8", + errors: Literal["strict", "ignore", "replace"] = "strict", + ) -> int: + return await super().writetext( + self._resolve_kaos_path(path), + data, + mode=mode, + encoding=encoding, + errors=errors, + ) + + async def mkdir( + self, path: StrOrKaosPath, parents: bool = False, exist_ok: bool = False + ) -> None: + await super().mkdir(self._resolve_kaos_path(path), parents=parents, exist_ok=exist_ok) + + async def exec(self, *args: str, env: Mapping[str, str] | None = None) -> KaosProcess: + if not args: + raise ValueError("At least one argument (the program to execute) is required.") + + process = await asyncio.create_subprocess_exec( + *args, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=str(self._cwd), + env=env, + ) + return self.Process(process) + + local_kaos = LocalKaos() """The default local KAOS instance.""" diff --git a/packages/kaos/tests/test_local_kaos.py b/packages/kaos/tests/test_local_kaos.py index 80709f4f5..5f04b4f33 100644 --- a/packages/kaos/tests/test_local_kaos.py +++ b/packages/kaos/tests/test_local_kaos.py @@ -9,7 +9,7 @@ import pytest from kaos import reset_current_kaos, set_current_kaos -from kaos.local import LocalKaos +from kaos.local import LocalKaos, ScopedLocalKaos from kaos.path import KaosPath @@ -196,3 +196,45 @@ async def test_exec_wait_timeout(local_kaos: LocalKaos): if process.returncode is None: await process.kill() await process.wait() + + +@pytest.fixture +def scoped_local_kaos(tmp_path: Path) -> Generator[ScopedLocalKaos]: + """Set a scoped local Kaos as current without mutating process cwd.""" + scoped = ScopedLocalKaos(tmp_path) + token = set_current_kaos(scoped) + try: + yield scoped + finally: + reset_current_kaos(token) + + +async def test_scoped_local_kaos_tracks_cwd_without_process_chdir( + scoped_local_kaos: ScopedLocalKaos, +): + original_cwd = Path.cwd() + initial_scoped_cwd = str(scoped_local_kaos.getcwd()) + await scoped_local_kaos.mkdir("nested") + + await scoped_local_kaos.chdir("nested") + + assert Path.cwd() == original_cwd + assert str(scoped_local_kaos.getcwd()) == str(Path(initial_scoped_cwd) / "nested") + + +async def test_scoped_local_kaos_exec_respects_scoped_cwd(scoped_local_kaos: ScopedLocalKaos): + await scoped_local_kaos.mkdir("nested") + await scoped_local_kaos.chdir("nested") + + process = await scoped_local_kaos.exec( + *( + sys.executable, + "-c", + "from pathlib import Path; print(Path.cwd()); Path('scoped.txt').write_text('ok')", + ) + ) + + stdout_data = await process.stdout.read() + assert await process.wait() == 0 + assert stdout_data.decode("utf-8").strip() == str(scoped_local_kaos.getcwd()) + assert await (scoped_local_kaos.getcwd() / "scoped.txt").is_file() diff --git a/src/kimi_cli/app.py b/src/kimi_cli/app.py index 28b9596eb..8c39c1cbd 100644 --- a/src/kimi_cli/app.py +++ b/src/kimi_cli/app.py @@ -8,7 +8,8 @@ from pathlib import Path from typing import TYPE_CHECKING, Any -import kaos +from kaos import get_current_kaos, reset_current_kaos, set_current_kaos +from kaos.local import ScopedLocalKaos, local_kaos from kaos.path import KaosPath from pydantic import SecretStr @@ -256,17 +257,19 @@ async def create( soul.set_hook_engine(hook_engine) runtime.hook_engine = hook_engine - return KimiCLI(soul, runtime, env_overrides) + return KimiCLI(soul, runtime, env_overrides, ScopedLocalKaos(session.work_dir)) def __init__( self, _soul: KimiSoul, _runtime: Runtime, _env_overrides: dict[str, str], + _kaos_backend: ScopedLocalKaos | None = None, ) -> None: self._soul = _soul self._runtime = _runtime self._env_overrides = _env_overrides + self._kaos_backend = _kaos_backend or ScopedLocalKaos(_runtime.session.work_dir) @property def soul(self) -> KimiSoul: @@ -288,15 +291,17 @@ def shutdown_background_tasks(self) -> None: @contextlib.asynccontextmanager async def _env(self) -> AsyncGenerator[None]: - original_cwd = KaosPath.cwd() - await kaos.chdir(self._runtime.session.work_dir) + kaos_token = None + if get_current_kaos() is local_kaos: + kaos_token = set_current_kaos(self._kaos_backend) try: # to ignore possible warnings from dateparser warnings.filterwarnings("ignore", category=DeprecationWarning) async with self._runtime.oauth.refreshing(self._runtime): yield finally: - await kaos.chdir(original_cwd) + if kaos_token is not None: + reset_current_kaos(kaos_token) @contextlib.asynccontextmanager async def env(self) -> AsyncGenerator[None]: diff --git a/src/kimi_cli/soul/kimisoul.py b/src/kimi_cli/soul/kimisoul.py index 697c1c9cd..63bf3104b 100644 --- a/src/kimi_cli/soul/kimisoul.py +++ b/src/kimi_cli/soul/kimisoul.py @@ -10,6 +10,7 @@ import kosong import tenacity +from kaos.path import KaosPath from kosong import StepResult from kosong.chat_provider import ( APIConnectionError, @@ -482,7 +483,7 @@ async def run(self, user_input: str | list[ContentPart]): matcher_value=text_input_for_hook, input_data=events.user_prompt_submit( session_id=self._runtime.session.id, - cwd=str(Path.cwd()), + cwd=str(KaosPath.cwd()), prompt=text_input_for_hook, ), ) @@ -521,7 +522,7 @@ async def run(self, user_input: str | list[ContentPart]): "Stop", input_data=events.stop( session_id=self._runtime.session.id, - cwd=str(Path.cwd()), + cwd=str(KaosPath.cwd()), stop_hook_active=False, ), ) @@ -714,7 +715,7 @@ async def _agent_loop(self) -> TurnOutcome: matcher_value=type(e).__name__, input_data=_hook_events.stop_failure( session_id=self._runtime.session.id, - cwd=str(Path.cwd()), + cwd=str(KaosPath.cwd()), error_type=type(e).__name__, error_message=str(e), ), @@ -766,7 +767,7 @@ async def _append_notification(view: NotificationView) -> None: matcher_value=view.event.type, input_data=events.notification( session_id=self._runtime.session.id, - cwd=str(Path.cwd()), + cwd=str(KaosPath.cwd()), sink="llm", notification_type=view.event.type, title=view.event.title, @@ -959,7 +960,7 @@ async def _compact_with_retry() -> CompactionResult: matcher_value=trigger_reason, input_data=events.pre_compact( session_id=self._runtime.session.id, - cwd=str(Path.cwd()), + cwd=str(KaosPath.cwd()), trigger=trigger_reason, token_count=self._context.token_count, ), @@ -1000,7 +1001,7 @@ async def _compact_with_retry() -> CompactionResult: matcher_value=trigger_reason, input_data=events.post_compact( session_id=self._runtime.session.id, - cwd=str(Path.cwd()), + cwd=str(KaosPath.cwd()), trigger=trigger_reason, estimated_token_count=estimated_token_count, ), diff --git a/src/kimi_cli/soul/toolset.py b/src/kimi_cli/soul/toolset.py index 3722147ef..541dfb2c9 100644 --- a/src/kimi_cli/soul/toolset.py +++ b/src/kimi_cli/soul/toolset.py @@ -8,9 +8,9 @@ from contextvars import ContextVar from dataclasses import dataclass from datetime import timedelta -from pathlib import Path from typing import TYPE_CHECKING, Any, Literal, overload +from kaos.path import KaosPath from kosong.tooling import ( CallableTool, CallableTool2, @@ -153,7 +153,7 @@ async def _call(): matcher_value=tool_call.function.name, input_data=events.pre_tool_use( session_id=_get_session_id(), - cwd=str(Path.cwd()), + cwd=str(KaosPath.cwd()), tool_name=tool_call.function.name, tool_input=tool_input_dict, tool_call_id=tool_call.id, @@ -180,7 +180,7 @@ async def _call(): matcher_value=tool_call.function.name, input_data=events.post_tool_use_failure( session_id=_get_session_id(), - cwd=str(Path.cwd()), + cwd=str(KaosPath.cwd()), tool_name=tool_call.function.name, tool_input=tool_input_dict, error=str(e), @@ -203,7 +203,7 @@ async def _call(): matcher_value=tool_call.function.name, input_data=events.post_tool_use( session_id=_get_session_id(), - cwd=str(Path.cwd()), + cwd=str(KaosPath.cwd()), tool_name=tool_call.function.name, tool_input=tool_input_dict, tool_output=str(ret)[:2000], diff --git a/src/kimi_cli/subagents/runner.py b/src/kimi_cli/subagents/runner.py index c609d0637..9ec0872b4 100644 --- a/src/kimi_cli/subagents/runner.py +++ b/src/kimi_cli/subagents/runner.py @@ -6,6 +6,7 @@ from pathlib import Path from typing import TYPE_CHECKING +from kaos.path import KaosPath from kosong.chat_provider import APIStatusError, ChatProviderError from kosong.tooling import ToolError, ToolOk, ToolReturnValue @@ -257,7 +258,7 @@ async def run(self, req: ForegroundRunRequest) -> ToolReturnValue: matcher_value=actual_type, input_data=hook_events.subagent_start( session_id=self._runtime.session.id, - cwd=str(Path.cwd()), + cwd=str(KaosPath.cwd()), agent_name=actual_type, prompt=req.prompt[:500], ), @@ -283,7 +284,7 @@ async def run(self, req: ForegroundRunRequest) -> ToolReturnValue: matcher_value=actual_type, input_data=hook_events.subagent_stop( session_id=self._runtime.session.id, - cwd=str(Path.cwd()), + cwd=str(KaosPath.cwd()), agent_name=actual_type, response=(final_response or "")[:500], ), diff --git a/src/kimi_cli/web/runner/embedded_worker.py b/src/kimi_cli/web/runner/embedded_worker.py index 7bdf34425..66dab7f65 100644 --- a/src/kimi_cli/web/runner/embedded_worker.py +++ b/src/kimi_cli/web/runner/embedded_worker.py @@ -22,8 +22,6 @@ ) from kimi_cli.wire.server import WireServer -_PROMPT_ENV_LOCK = asyncio.Lock() - class EmbeddedWireWorker(WireServer): """Run the wire protocol against a local ``KimiCLI`` instance.""" @@ -142,9 +140,5 @@ async def _emit_out_message(self, msg: Any) -> None: await self._emit_json(payload) async def _handle_prompt(self, msg): # type: ignore[override] - # ``KimiCLI.run_wire_stdio()`` normally keeps the entire wire server inside - # the session environment context. For the embedded worker we scope that environment to - # each foreground turn and serialize embedded prompts to avoid cross-session - # cwd races from ``kaos.chdir()``. - async with _PROMPT_ENV_LOCK, self._kimi_cli.env(): + async with self._kimi_cli.env(): return await super()._handle_prompt(msg) From f709b0321aae99b57355ad1dc01a8f1fe9ac260d Mon Sep 17 00:00:00 2001 From: nic Date: Mon, 30 Mar 2026 18:32:47 +0800 Subject: [PATCH 5/6] fix(pykaos): align scoped kaos test input type --- packages/kaos/tests/test_local_kaos.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/kaos/tests/test_local_kaos.py b/packages/kaos/tests/test_local_kaos.py index 5f04b4f33..1c4d9054b 100644 --- a/packages/kaos/tests/test_local_kaos.py +++ b/packages/kaos/tests/test_local_kaos.py @@ -201,7 +201,7 @@ async def test_exec_wait_timeout(local_kaos: LocalKaos): @pytest.fixture def scoped_local_kaos(tmp_path: Path) -> Generator[ScopedLocalKaos]: """Set a scoped local Kaos as current without mutating process cwd.""" - scoped = ScopedLocalKaos(tmp_path) + scoped = ScopedLocalKaos(str(tmp_path)) token = set_current_kaos(scoped) try: yield scoped From 7eb85d80fa783cfce710841a0bb8e4299bd79869 Mon Sep 17 00:00:00 2001 From: nic Date: Mon, 30 Mar 2026 19:14:30 +0800 Subject: [PATCH 6/6] fix(acp): scope kaos cwd per session --- src/kimi_cli/acp/server.py | 15 +++++++++++++-- tests/acp/test_kaos.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 2 deletions(-) create mode 100644 tests/acp/test_kaos.py diff --git a/src/kimi_cli/acp/server.py b/src/kimi_cli/acp/server.py index 8461d2286..7eb495de5 100644 --- a/src/kimi_cli/acp/server.py +++ b/src/kimi_cli/acp/server.py @@ -8,6 +8,7 @@ from typing import Any, NamedTuple import acp +from kaos.local import ScopedLocalKaos from kaos.path import KaosPath from kimi_cli.acp.kaos import ACPKaos @@ -157,7 +158,12 @@ async def new_session( mcp_configs=[mcp_config], ) config = cli_instance.soul.runtime.config - acp_kaos = ACPKaos(self.conn, session.id, self.client_capabilities) + acp_kaos = ACPKaos( + self.conn, + session.id, + self.client_capabilities, + fallback=ScopedLocalKaos(session.work_dir), + ) acp_session = ACPSession(session.id, cli_instance, self.conn, kaos=acp_kaos) model_id_conv = _ModelIDConv(config.default_model, config.default_thinking) self.sessions[session.id] = (acp_session, model_id_conv) @@ -227,7 +233,12 @@ async def _setup_session( resumed=True, # _setup_session loads existing sessions ) config = cli_instance.soul.runtime.config - acp_kaos = ACPKaos(self.conn, session.id, self.client_capabilities) + acp_kaos = ACPKaos( + self.conn, + session.id, + self.client_capabilities, + fallback=ScopedLocalKaos(session.work_dir), + ) acp_session = ACPSession(session.id, cli_instance, self.conn, kaos=acp_kaos) model_id_conv = _ModelIDConv(config.default_model, config.default_thinking) self.sessions[session.id] = (acp_session, model_id_conv) diff --git a/tests/acp/test_kaos.py b/tests/acp/test_kaos.py new file mode 100644 index 000000000..9f6179d0f --- /dev/null +++ b/tests/acp/test_kaos.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from pathlib import Path + +from kaos import reset_current_kaos, set_current_kaos +from kaos.local import ScopedLocalKaos +from kaos.path import KaosPath + +from kimi_cli.acp.kaos import ACPKaos + + +class _FakeACPClient: + pass + + +async def test_acp_kaos_resolves_relative_paths_from_session_work_dir(tmp_path: Path) -> None: + work_dir = tmp_path / "project" + work_dir.mkdir() + (work_dir / "note.txt").write_text("hello", encoding="utf-8") + + acp_kaos = ACPKaos( + _FakeACPClient(), # type: ignore[arg-type] + "session-1", + None, + fallback=ScopedLocalKaos(str(work_dir)), + ) + token = set_current_kaos(acp_kaos) + try: + assert str(KaosPath("note.txt").canonical()) == str(work_dir / "note.txt") + assert await acp_kaos.readtext("note.txt") == "hello" + finally: + reset_current_kaos(token)