diff --git a/src/kimi_cli/ui/shell/__init__.py b/src/kimi_cli/ui/shell/__init__.py index 1f1905b0e..f1bc0e492 100644 --- a/src/kimi_cli/ui/shell/__init__.py +++ b/src/kimi_cli/ui/shell/__init__.py @@ -2,7 +2,9 @@ import asyncio import contextlib +import os import shlex +import sys import time from collections import deque from collections.abc import Awaitable, Callable, Coroutine @@ -22,6 +24,12 @@ from kimi_cli.soul import LLMNotSet, LLMNotSupported, MaxStepsReached, RunCancelled, Soul, run_soul from kimi_cli.soul.kimisoul import KimiSoul from kimi_cli.ui.shell import update as _update_mod +from kimi_cli.ui.shell.capture import inject_to_context + +if sys.platform != "win32": + from kimi_cli.ui.shell.capture import execute_with_pty_capture as _execute_capture +else: + from kimi_cli.ui.shell.capture import execute_with_pipe_capture as _execute_capture from kimi_cli.ui.shell.console import console from kimi_cli.ui.shell.echo import render_user_echo_text from kimi_cli.ui.shell.mcp_status import render_mcp_prompt @@ -40,7 +48,6 @@ visualize, ) from kimi_cli.utils.envvar import get_env_bool -from kimi_cli.utils.logging import open_original_stderr from kimi_cli.utils.signals import install_sigint_handler from kimi_cli.utils.slashcmd import SlashCommand, SlashCommandCall, parse_slash_command_call from kimi_cli.utils.subprocess_env import get_clean_env @@ -487,7 +494,7 @@ async def _invalidate_after_mcp_loading() -> None: return shell_ok async def _run_shell_command(self, command: str) -> None: - """Run a shell command in foreground.""" + """Run a shell command in foreground, capturing output for context injection.""" if not command.strip(): return @@ -503,45 +510,91 @@ async def _run_shell_command(self, command: str) -> None: ) return - # Check if user is trying to use 'cd' command + # Handle bare `cd` / `cd ` — resolve and persist globally. + # Compound commands like `cd /tmp && ls` are left to the shell. stripped_cmd = command.strip() split_cmd: list[str] | None = None try: split_cmd = shlex.split(stripped_cmd) except ValueError as exc: logger.debug("Failed to parse shell command for cd check: {error}", error=exc) - if split_cmd and len(split_cmd) == 2 and split_cmd[0] == "cd": - console.print( - "[yellow]Warning: Directory changes are not preserved across command executions." - "[/yellow]" - ) + if split_cmd and split_cmd[0] == "cd" and len(split_cmd) <= 2: + await self._handle_cd(split_cmd) return logger.info("Running shell command: {cmd}", cmd=command) - proc: asyncio.subprocess.Process | None = None - - def _handler(): - logger.debug("SIGINT received.") - if proc: - proc.terminate() - - loop = asyncio.get_running_loop() - remove_sigint = install_sigint_handler(loop, _handler) + exit_code: int | None = None + raw_output: str | None = None try: - # TODO: For the sake of simplicity, we now use `create_subprocess_shell`. - # Later we should consider making this behave like a real shell. - with open_original_stderr() as stderr: - kwargs: dict[str, Any] = {} - if stderr is not None: - kwargs["stderr"] = stderr - proc = await asyncio.create_subprocess_shell(command, env=get_clean_env(), **kwargs) - await proc.wait() + exit_code, raw_output = await _execute_capture( + command, env=get_clean_env(), cwd=os.getcwd() + ) except Exception as e: logger.exception("Failed to run shell command:") console.print(f"[red]Failed to run shell command: {e}[/red]") - finally: - remove_sigint() + + # Inject captured output into conversation context + if raw_output is not None and isinstance(self.soul, KimiSoul): + try: + await inject_to_context( + self.soul.context.append_message, command, raw_output, exit_code + ) + except Exception: + logger.debug("Failed to inject shell output to context", exc_info=True) + + async def _handle_cd(self, args: list[str]) -> None: + """Resolve ``cd`` via a real shell and persist the directory change. + + Only called for bare ``cd`` / ``cd `` (at most 2 tokens). + """ + target = args[1] if len(args) > 1 else "~" + + # Provide OLDPWD so `cd -` works across invocations. + env = get_clean_env() + old_cwd = os.getcwd() + if "OLDPWD" not in env: + env["OLDPWD"] = old_cwd + + # Use shlex.quote for safety, but NOT for targets that need shell + # expansion: ~, -, and $VAR references. For those we let the real + # shell handle expansion directly. + needs_shell_expansion = target.startswith("~") or target.startswith("$") or target == "-" + quoted_target = target if needs_shell_expansion else shlex.quote(target) + + # Let the shell resolve ~, -, $HOME, CDPATH, etc. + probe = await asyncio.create_subprocess_shell( + f"cd {quoted_target} && pwd", + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + env=env, + ) + stdout, stderr = await probe.communicate() + + if probe.returncode != 0: + err = stderr.decode("utf-8", errors="replace").strip() + console.print(f"[red]cd: {err or 'failed'}[/red]") + return + + # `cd -` prints the destination before `pwd`; take the last non-empty line. + lines = [ln for ln in stdout.decode("utf-8", errors="replace").splitlines() if ln.strip()] + new_cwd = lines[-1].strip() if lines else "" + if not new_cwd or not os.path.isdir(new_cwd): + console.print(f"[red]cd: not a directory: {target}[/red]") + return + + os.chdir(new_cwd) + # Set OLDPWD for the next `cd -` invocation. + os.environ["OLDPWD"] = old_cwd + + # Keep the session's work_dir in sync so agent background tasks + # (which use session.work_dir as cwd) also see the new directory. + if isinstance(self.soul, KimiSoul): + from kaos.path import KaosPath + + self.soul.runtime.session.work_dir = KaosPath.unsafe_from_local_path( + __import__("pathlib").Path(new_cwd) + ) async def _run_slash_command(self, command_call: SlashCommandCall) -> None: from kimi_cli.cli import Reload, SwitchToWeb diff --git a/src/kimi_cli/ui/shell/capture.py b/src/kimi_cli/ui/shell/capture.py new file mode 100644 index 000000000..54ca65f05 --- /dev/null +++ b/src/kimi_cli/ui/shell/capture.py @@ -0,0 +1,284 @@ +"""Shared utilities for shell output capture, cleaning, and context injection. + +Both Ctrl+X shell mode and the ``!`` prefix route through these functions +so there is a single implementation for output cleaning, truncation, and +context injection. +""" + +from __future__ import annotations + +import asyncio +import contextlib +import os +import re +import sys +from collections.abc import Callable, Coroutine +from typing import Any + +# POSIX-only modules — guarded so the module stays importable on Windows. +if sys.platform != "win32": + import fcntl + import termios + +from kosong.message import Message +from rich.text import Text + +from kimi_cli import logger +from kimi_cli.soul.message import system_reminder + +SHELL_OUTPUT_MAX_BYTES = 50_000 +"""Maximum bytes of shell output to inject into context.""" + +_CAPTURE_HARD_LIMIT = 2 * SHELL_OUTPUT_MAX_BYTES +"""Hard cap on in-memory capture to prevent unbounded growth from binary / infinite output.""" + + +# --------------------------------------------------------------------------- +# Output cleaning +# --------------------------------------------------------------------------- + + +def clean_output(raw: str) -> str: + """Strip ANSI escapes and resolve carriage-return overwrites. + + Processing order: + 1. Normalise ``\\r\\n`` → ``\\n`` (terminal / ``script`` line endings). + 2. Strip C0 control characters (``\\x00``–``\\x08``, ``\\x0e``–``\\x1f``) + except ``\\n``, ``\\r``, ``\\t`` which carry meaning. + 3. Resolve standalone ``\\r`` overwrites (``"50%\\r100%"`` → ``"100%"``). + 4. Let ``rich.text.Text.from_ansi`` strip SGR / CSI / OSC sequences. + """ + # 1. Normalise line endings + text = raw.replace("\r\n", "\n") + + # 2. Strip troublesome C0 control chars (e.g. \x08 backspace). + # Preserve \x1b (ESC) so ANSI sequences survive until step 4. + text = re.sub(r"[\x00-\x08\x0b\x0c\x0e-\x1a\x1c-\x1f]", "", text) + + # 3. Resolve \r overwrites + lines: list[str] = [] + for line in text.split("\n"): + parts = line.split("\r") + if len(parts) > 1: + line = parts[-1] + lines.append(line) + text = "\n".join(lines) + + # 4. Strip ANSI via rich + return Text.from_ansi(text).plain + + +# --------------------------------------------------------------------------- +# PTY-based execution with capture +# --------------------------------------------------------------------------- + + +async def execute_with_pty_capture( + command: str, + env: dict[str, str] | None = None, + cwd: str | None = None, +) -> tuple[int | None, str]: + """Run *command* in a pseudo-terminal, teeing output to the real terminal. + + The subprocess's **stdin** is inherited (the real terminal) so interactive + input works. **stdout** and **stderr** are routed through a PTY so that + programs see ``isatty() == True`` and produce coloured / formatted output. + We read from the PTY master and write every chunk to the real stdout *and* + an in-memory buffer. + + Returns ``(exit_code, raw_output)``. + """ + logger.debug("PTY capture: running {command!r}", command=command) + master_fd, slave_fd = os.openpty() + + # Match PTY size to real terminal so columnar output renders correctly. + try: + ws = fcntl.ioctl(sys.stdout.fileno(), termios.TIOCGWINSZ, b"\x00" * 8) + fcntl.ioctl(slave_fd, termios.TIOCSWINSZ, ws) + except OSError: + pass + + try: + proc = await asyncio.create_subprocess_shell( + command, + stdout=slave_fd, + stderr=slave_fd, + env=env, + cwd=cwd, + ) + except Exception: + os.close(master_fd) + raise + finally: + # Parent no longer needs slave fd after subprocess inherits it. + # On success the child owns a dup; on failure the except branch + # already closed master_fd so we only need to close slave_fd here. + os.close(slave_fd) + + # Save terminal state so we can restore it if the child corrupts it. + stdin_fd = sys.stdin.fileno() + saved_termios: list[Any] | None = None + with contextlib.suppress(OSError, termios.error): + saved_termios = termios.tcgetattr(stdin_fd) + + loop = asyncio.get_running_loop() + captured: list[bytes] = [] + captured_bytes = 0 + eof_event = asyncio.Event() + + try: + stdout_fd = sys.stdout.fileno() + except (AttributeError, OSError): + stdout_fd = 1 # fallback + + def _on_master_readable() -> None: + nonlocal captured_bytes + try: + data = os.read(master_fd, 4096) + except OSError: + loop.remove_reader(master_fd) + eof_event.set() + return + if data: + with contextlib.suppress(OSError): + os.write(stdout_fd, data) + # Always append; we trim to _CAPTURE_HARD_LIMIT (tail) after the + # process exits so the final context injection gets the most recent + # output rather than the head. + captured.append(data) + captured_bytes += len(data) + # Evict old chunks once we exceed twice the hard limit to bound memory. + if captured_bytes > _CAPTURE_HARD_LIMIT * 2: + while captured and captured_bytes > _CAPTURE_HARD_LIMIT: + evicted = captured.pop(0) + captured_bytes -= len(evicted) + else: + loop.remove_reader(master_fd) + eof_event.set() + + loop.add_reader(master_fd, _on_master_readable) + + try: + await proc.wait() + # Drain any remaining buffered output after process exits. + with contextlib.suppress(TimeoutError): + await asyncio.wait_for(eof_event.wait(), timeout=0.5) + except (KeyboardInterrupt, asyncio.CancelledError): + # Forward SIGINT to child if it is still running. + logger.debug("PTY capture: interrupted, forwarding SIGINT to child") + if proc.returncode is None: + proc.send_signal(2) # SIGINT + try: + await asyncio.wait_for(proc.wait(), timeout=3.0) + except (TimeoutError, ProcessLookupError, asyncio.CancelledError): + logger.debug("PTY capture: child did not exit after SIGINT, killing") + proc.kill() + finally: + loop.remove_reader(master_fd) + os.close(master_fd) + # Restore terminal state unconditionally. + if saved_termios is not None: + with contextlib.suppress(OSError, termios.error): + termios.tcsetattr(stdin_fd, termios.TCSAFLUSH, saved_termios) + + raw = b"".join(captured).decode("utf-8", errors="replace") + return proc.returncode, raw + + +# --------------------------------------------------------------------------- +# PIPE-based execution with tee (for ``!`` prefix / non-TTY contexts) +# --------------------------------------------------------------------------- + + +async def execute_with_pipe_capture( + command: str, + env: dict[str, str] | None = None, + cwd: str | None = None, +) -> tuple[int | None, str]: + """Run *command* with PIPE capture, printing output as it arrives. + + Simpler than PTY — programs won't see a TTY, so colours are lost. + Suitable for the ``!`` prefix where TTY fidelity is less important. + + Returns ``(exit_code, raw_output)``. + """ + proc = await asyncio.create_subprocess_shell( + command, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.STDOUT, + env=env, + cwd=cwd, + ) + assert proc.stdout is not None + chunks: list[bytes] = [] + captured_bytes = 0 + while True: + chunk = await proc.stdout.read(4096) + if not chunk: + break + try: + sys.stdout.buffer.write(chunk) + sys.stdout.buffer.flush() + except (OSError, AttributeError): + pass + chunks.append(chunk) + captured_bytes += len(chunk) + if captured_bytes > _CAPTURE_HARD_LIMIT * 2: + while chunks and captured_bytes > _CAPTURE_HARD_LIMIT: + evicted = chunks.pop(0) + captured_bytes -= len(evicted) + await proc.wait() + raw = b"".join(chunks).decode("utf-8", errors="replace") + return proc.returncode, raw + + +# --------------------------------------------------------------------------- +# Context injection +# --------------------------------------------------------------------------- + + +async def inject_to_context( + append_message: Callable[[Message], Coroutine[Any, Any, None]], + command: str, + raw_output: str, + exit_code: int | None, +) -> None: + """Build a context message from shell output and append it. + + The caller provides an ``append_message`` callback (e.g. + ``soul.context.append_message``) so this module stays decoupled from + any specific soul / context implementation. + + The output is cleaned, truncated, wrapped in ````, and + appended as a user message. + """ + output = clean_output(raw_output) + + # Neutralise any system-reminder tags in the output to prevent injection. + output = output.replace("", "<system-reminder>") + output = output.replace("", "</system-reminder>") + + # Truncate, keeping the tail (usually more informative). + truncated = False + encoded = output.encode("utf-8") + if len(encoded) > SHELL_OUTPUT_MAX_BYTES: + output = encoded[-SHELL_OUTPUT_MAX_BYTES:].decode("utf-8", errors="replace") + truncated = True + + status = f"exit code {exit_code}" if exit_code is not None else "unknown exit code" + parts = [ + f"The user ran a shell command in shell mode ({status}).", + "This is output from a command the user ran directly, not from the AI.", + "The output may be useful context for subsequent requests.", + ] + if truncated: + parts.append(f"(Output truncated to last {SHELL_OUTPUT_MAX_BYTES} bytes)") + + header = " ".join(parts) + body = f"$ {command}\n{output}" + + message = Message( + role="user", + content=[system_reminder(f"{header}\n\n{body}")], + ) + await append_message(message)