Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 123 additions & 0 deletions packages/kaos/src/kaos/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from asyncio.subprocess import Process as AsyncioProcess
from collections.abc import AsyncGenerator
from pathlib import Path, PurePath
from stat import S_ISDIR
from typing import TYPE_CHECKING, Literal

if os.name == "nt":
Expand Down Expand Up @@ -176,5 +177,127 @@ async def exec(self, *args: str, env: Mapping[str, str] | None = None) -> KaosPr
return self.Process(process)


class ScopedLocalKaos(LocalKaos):
"""Local KAOS backend with an instance-local working directory."""

def __init__(self, cwd: StrOrKaosPath | None = None) -> None:
base_cwd = Path.cwd()
self._cwd = self._normalize_local_path(
self._coerce_local_path(cwd) if cwd is not None else base_cwd,
base=base_cwd,
)

@staticmethod
def _coerce_local_path(path: StrOrKaosPath) -> Path:
return path.unsafe_to_local_path() if isinstance(path, KaosPath) else Path(path)

def _normalize_local_path(self, path: Path, *, base: Path | None = None) -> Path:
if not path.is_absolute():
path = (base or self._cwd) / path
return Path(pathmodule.normpath(str(path)))

def _resolve_local_path(self, path: StrOrKaosPath) -> Path:
return self._normalize_local_path(self._coerce_local_path(path))

def _resolve_kaos_path(self, path: StrOrKaosPath) -> KaosPath:
return KaosPath.unsafe_from_local_path(self._resolve_local_path(path))

def getcwd(self) -> KaosPath:
return KaosPath.unsafe_from_local_path(self._cwd)

async def chdir(self, path: StrOrKaosPath) -> None:
local_path = self._resolve_local_path(path)
st = await aiofiles.os.stat(local_path)
if not S_ISDIR(st.st_mode):
raise NotADirectoryError(str(local_path))
self._cwd = local_path

async def stat(self, path: StrOrKaosPath, *, follow_symlinks: bool = True) -> StatResult:
return await super().stat(self._resolve_kaos_path(path), follow_symlinks=follow_symlinks)

async def iterdir(self, path: StrOrKaosPath) -> AsyncGenerator[KaosPath]:
async for entry in super().iterdir(self._resolve_kaos_path(path)):
yield entry

async def glob(
self, path: StrOrKaosPath, pattern: str, *, case_sensitive: bool = True
) -> AsyncGenerator[KaosPath]:
async for entry in super().glob(
self._resolve_kaos_path(path),
pattern,
case_sensitive=case_sensitive,
):
yield entry

async def readbytes(self, path: StrOrKaosPath, n: int | None = None) -> bytes:
return await super().readbytes(self._resolve_kaos_path(path), n=n)

async def readtext(
self,
path: str | KaosPath,
*,
encoding: str = "utf-8",
errors: Literal["strict", "ignore", "replace"] = "strict",
) -> str:
return await super().readtext(
self._resolve_kaos_path(path),
encoding=encoding,
errors=errors,
)

async def readlines(
self,
path: str | KaosPath,
*,
encoding: str = "utf-8",
errors: Literal["strict", "ignore", "replace"] = "strict",
) -> AsyncGenerator[str]:
async for line in super().readlines(
self._resolve_kaos_path(path),
encoding=encoding,
errors=errors,
):
yield line

async def writebytes(self, path: StrOrKaosPath, data: bytes) -> int:
return await super().writebytes(self._resolve_kaos_path(path), data)

async def writetext(
self,
path: str | KaosPath,
data: str,
*,
mode: Literal["w"] | Literal["a"] = "w",
encoding: str = "utf-8",
errors: Literal["strict", "ignore", "replace"] = "strict",
) -> int:
return await super().writetext(
self._resolve_kaos_path(path),
data,
mode=mode,
encoding=encoding,
errors=errors,
)

async def mkdir(
self, path: StrOrKaosPath, parents: bool = False, exist_ok: bool = False
) -> None:
await super().mkdir(self._resolve_kaos_path(path), parents=parents, exist_ok=exist_ok)

async def exec(self, *args: str, env: Mapping[str, str] | None = None) -> KaosProcess:
if not args:
raise ValueError("At least one argument (the program to execute) is required.")

process = await asyncio.create_subprocess_exec(
*args,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=str(self._cwd),
env=env,
)
return self.Process(process)


local_kaos = LocalKaos()
"""The default local KAOS instance."""
44 changes: 43 additions & 1 deletion packages/kaos/tests/test_local_kaos.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pytest

from kaos import reset_current_kaos, set_current_kaos
from kaos.local import LocalKaos
from kaos.local import LocalKaos, ScopedLocalKaos
from kaos.path import KaosPath


Expand Down Expand Up @@ -196,3 +196,45 @@ async def test_exec_wait_timeout(local_kaos: LocalKaos):
if process.returncode is None:
await process.kill()
await process.wait()


@pytest.fixture
def scoped_local_kaos(tmp_path: Path) -> Generator[ScopedLocalKaos]:
"""Set a scoped local Kaos as current without mutating process cwd."""
scoped = ScopedLocalKaos(str(tmp_path))
token = set_current_kaos(scoped)
try:
yield scoped
finally:
reset_current_kaos(token)


async def test_scoped_local_kaos_tracks_cwd_without_process_chdir(
scoped_local_kaos: ScopedLocalKaos,
):
original_cwd = Path.cwd()
initial_scoped_cwd = str(scoped_local_kaos.getcwd())
await scoped_local_kaos.mkdir("nested")

await scoped_local_kaos.chdir("nested")

assert Path.cwd() == original_cwd
assert str(scoped_local_kaos.getcwd()) == str(Path(initial_scoped_cwd) / "nested")


async def test_scoped_local_kaos_exec_respects_scoped_cwd(scoped_local_kaos: ScopedLocalKaos):
await scoped_local_kaos.mkdir("nested")
await scoped_local_kaos.chdir("nested")

process = await scoped_local_kaos.exec(
*(
sys.executable,
"-c",
"from pathlib import Path; print(Path.cwd()); Path('scoped.txt').write_text('ok')",
)
)

stdout_data = await process.stdout.read()
assert await process.wait() == 0
assert stdout_data.decode("utf-8").strip() == str(scoped_local_kaos.getcwd())
assert await (scoped_local_kaos.getcwd() / "scoped.txt").is_file()
15 changes: 13 additions & 2 deletions src/kimi_cli/acp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Any, NamedTuple

import acp
from kaos.local import ScopedLocalKaos
from kaos.path import KaosPath

from kimi_cli.acp.kaos import ACPKaos
Expand Down Expand Up @@ -157,7 +158,12 @@ async def new_session(
mcp_configs=[mcp_config],
)
config = cli_instance.soul.runtime.config
acp_kaos = ACPKaos(self.conn, session.id, self.client_capabilities)
acp_kaos = ACPKaos(
self.conn,
session.id,
self.client_capabilities,
fallback=ScopedLocalKaos(session.work_dir),
)
acp_session = ACPSession(session.id, cli_instance, self.conn, kaos=acp_kaos)
model_id_conv = _ModelIDConv(config.default_model, config.default_thinking)
self.sessions[session.id] = (acp_session, model_id_conv)
Expand Down Expand Up @@ -227,7 +233,12 @@ async def _setup_session(
resumed=True, # _setup_session loads existing sessions
)
config = cli_instance.soul.runtime.config
acp_kaos = ACPKaos(self.conn, session.id, self.client_capabilities)
acp_kaos = ACPKaos(
self.conn,
session.id,
self.client_capabilities,
fallback=ScopedLocalKaos(session.work_dir),
)
acp_session = ACPSession(session.id, cli_instance, self.conn, kaos=acp_kaos)
model_id_conv = _ModelIDConv(config.default_model, config.default_thinking)
self.sessions[session.id] = (acp_session, model_id_conv)
Expand Down
31 changes: 21 additions & 10 deletions src/kimi_cli/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any

import kaos
from kaos import get_current_kaos, reset_current_kaos, set_current_kaos
from kaos.local import ScopedLocalKaos, local_kaos
from kaos.path import KaosPath
from pydantic import SecretStr

Expand Down Expand Up @@ -256,17 +257,19 @@ async def create(
soul.set_hook_engine(hook_engine)
runtime.hook_engine = hook_engine

return KimiCLI(soul, runtime, env_overrides)
return KimiCLI(soul, runtime, env_overrides, ScopedLocalKaos(session.work_dir))

def __init__(
self,
_soul: KimiSoul,
_runtime: Runtime,
_env_overrides: dict[str, str],
_kaos_backend: ScopedLocalKaos | None = None,
) -> None:
self._soul = _soul
self._runtime = _runtime
self._env_overrides = _env_overrides
self._kaos_backend = _kaos_backend or ScopedLocalKaos(_runtime.session.work_dir)

@property
def soul(self) -> KimiSoul:
Expand All @@ -288,15 +291,23 @@ def shutdown_background_tasks(self) -> None:

@contextlib.asynccontextmanager
async def _env(self) -> AsyncGenerator[None]:
original_cwd = KaosPath.cwd()
await kaos.chdir(self._runtime.session.work_dir)
kaos_token = None
if get_current_kaos() is local_kaos:
kaos_token = set_current_kaos(self._kaos_backend)
try:
# to ignore possible warnings from dateparser
warnings.filterwarnings("ignore", category=DeprecationWarning)
async with self._runtime.oauth.refreshing(self._runtime):
yield
finally:
await kaos.chdir(original_cwd)
if kaos_token is not None:
reset_current_kaos(kaos_token)

@contextlib.asynccontextmanager
async def env(self) -> AsyncGenerator[None]:
"""Run with the session working directory and refreshed runtime auth."""
async with self._env():
yield

async def run(
self,
Expand All @@ -322,7 +333,7 @@ async def run(
MaxStepsReached: When the maximum number of steps is reached.
RunCancelled: When the run is cancelled by the cancel event.
"""
async with self._env():
async with self.env():
wire_future = asyncio.Future[WireUISide]()
stop_ui_loop = asyncio.Event()
approval_bridge_tasks: dict[str, asyncio.Task[None]] = {}
Expand Down Expand Up @@ -513,7 +524,7 @@ async def run_shell(
level=WelcomeInfoItem.Level.INFO,
)
)
async with self._env():
async with self.env():
shell = Shell(self._soul, welcome_info=welcome_info, prefill_text=prefill_text)
return await shell.run(command)

Expand All @@ -528,7 +539,7 @@ async def run_print(
"""Run the Kimi Code CLI instance with print UI."""
from kimi_cli.ui.print import Print

async with self._env():
async with self.env():
print_ = Print(
self._soul,
input_format,
Expand All @@ -542,14 +553,14 @@ async def run_acp(self) -> None:
"""Run the Kimi Code CLI instance as ACP server."""
from kimi_cli.ui.acp import ACP

async with self._env():
async with self.env():
acp = ACP(self._soul)
await acp.run()

async def run_wire_stdio(self) -> None:
"""Run the Kimi Code CLI instance as Wire server over stdio."""
from kimi_cli.wire.server import WireServer

async with self._env():
async with self.env():
server = WireServer(self._soul)
await server.serve()
13 changes: 7 additions & 6 deletions src/kimi_cli/soul/kimisoul.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import kosong
import tenacity
from kaos.path import KaosPath
from kosong import StepResult
from kosong.chat_provider import (
APIConnectionError,
Expand Down Expand Up @@ -482,7 +483,7 @@ async def run(self, user_input: str | list[ContentPart]):
matcher_value=text_input_for_hook,
input_data=events.user_prompt_submit(
session_id=self._runtime.session.id,
cwd=str(Path.cwd()),
cwd=str(KaosPath.cwd()),
prompt=text_input_for_hook,
),
)
Expand Down Expand Up @@ -521,7 +522,7 @@ async def run(self, user_input: str | list[ContentPart]):
"Stop",
input_data=events.stop(
session_id=self._runtime.session.id,
cwd=str(Path.cwd()),
cwd=str(KaosPath.cwd()),
stop_hook_active=False,
),
)
Expand Down Expand Up @@ -714,7 +715,7 @@ async def _agent_loop(self) -> TurnOutcome:
matcher_value=type(e).__name__,
input_data=_hook_events.stop_failure(
session_id=self._runtime.session.id,
cwd=str(Path.cwd()),
cwd=str(KaosPath.cwd()),
error_type=type(e).__name__,
error_message=str(e),
),
Expand Down Expand Up @@ -766,7 +767,7 @@ async def _append_notification(view: NotificationView) -> None:
matcher_value=view.event.type,
input_data=events.notification(
session_id=self._runtime.session.id,
cwd=str(Path.cwd()),
cwd=str(KaosPath.cwd()),
sink="llm",
notification_type=view.event.type,
title=view.event.title,
Expand Down Expand Up @@ -959,7 +960,7 @@ async def _compact_with_retry() -> CompactionResult:
matcher_value=trigger_reason,
input_data=events.pre_compact(
session_id=self._runtime.session.id,
cwd=str(Path.cwd()),
cwd=str(KaosPath.cwd()),
trigger=trigger_reason,
token_count=self._context.token_count,
),
Expand Down Expand Up @@ -1000,7 +1001,7 @@ async def _compact_with_retry() -> CompactionResult:
matcher_value=trigger_reason,
input_data=events.post_compact(
session_id=self._runtime.session.id,
cwd=str(Path.cwd()),
cwd=str(KaosPath.cwd()),
trigger=trigger_reason,
estimated_token_count=estimated_token_count,
),
Expand Down
Loading
Loading