diff --git a/.github/workflows/ci-kimi-cli-ts.yml b/.github/workflows/ci-kimi-cli-ts.yml new file mode 100644 index 000000000..9eb39de93 --- /dev/null +++ b/.github/workflows/ci-kimi-cli-ts.yml @@ -0,0 +1,134 @@ +name: CI - Kimi CLI (TypeScript) + +on: + push: + branches: [main] + pull_request: + branches: [main] + +permissions: + contents: write + +jobs: + typecheck-and-test: + name: Typecheck & Test + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: oven-sh/setup-bun@v2 + with: + bun-version: latest + + - run: bun install --frozen-lockfile + + - name: Typecheck + run: ./node_modules/.bin/tsc --noEmit --skipLibCheck + + - name: Test + run: bun test + + build-binaries: + name: Build Binary (${{ matrix.os }}-${{ matrix.arch }}) + needs: typecheck-and-test + runs-on: ${{ matrix.runner }} + strategy: + fail-fast: false + matrix: + include: + - os: linux + arch: x64 + runner: ubuntu-latest + target: bun-linux-x64 + artifact: kimi-linux-x64 + - os: linux + arch: arm64 + runner: ubuntu-latest + target: bun-linux-arm64 + artifact: kimi-linux-arm64 + - os: darwin + arch: x64 + runner: macos-13 + target: bun-darwin-x64 + artifact: kimi-darwin-x64 + - os: darwin + arch: arm64 + runner: macos-14 + target: bun-darwin-arm64 + artifact: kimi-darwin-arm64 + - os: windows + arch: x64 + runner: windows-latest + target: bun-windows-x64 + artifact: kimi-windows-x64 + + steps: + - uses: actions/checkout@v4 + + - uses: oven-sh/setup-bun@v2 + with: + bun-version: latest + + - run: bun install --frozen-lockfile + + - name: Build standalone binary + run: bun build src/kimi_cli/index.ts --compile --outfile dist/kimi --target=${{ matrix.target }} + + - name: Verify binary (unix) + if: matrix.os != 'windows' + run: | + chmod +x dist/kimi + dist/kimi --version + dist/kimi --help + + - name: Verify binary (windows) + if: matrix.os == 'windows' + run: | + dist\kimi.exe --version + dist\kimi.exe --help + + - name: Upload artifact + uses: actions/upload-artifact@v4 + with: + name: ${{ matrix.artifact }} + path: dist/kimi* + if-no-files-found: error + + release: + name: Create Release + needs: build-binaries + if: github.event_name == 'push' && github.ref == 'refs/heads/main' && startsWith(github.event.head_commit.message, 'release:') + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Get version + id: version + run: echo "version=$(node -p "require('./package.json').version")" >> $GITHUB_OUTPUT + + - name: Download all artifacts + uses: actions/download-artifact@v4 + with: + path: artifacts/ + + - name: Prepare release assets + run: | + mkdir -p release + for dir in artifacts/kimi-*; do + name=$(basename "$dir") + if [[ "$name" == *windows* ]]; then + (cd "$dir" && zip -r "../../release/${name}.zip" .) + else + chmod +x "$dir/kimi" + tar -czf "release/${name}.tar.gz" -C "$dir" kimi + fi + done + ls -la release/ + + - name: Create GitHub Release + uses: softprops/action-gh-release@v2 + with: + tag_name: v${{ steps.version.outputs.version }} + name: v${{ steps.version.outputs.version }} + files: release/* + generate_release_notes: true diff --git a/.gitignore b/.gitignore index f00b02169..38208fdbe 100644 --- a/.gitignore +++ b/.gitignore @@ -1,47 +1,35 @@ -# Python-generated files -__pycache__/ -*.py[oc] -build/ -dist/ -wheels/ -*.egg-info +# dependencies (bun install) +node_modules -# Virtual environments -.venv +# output +out +dist +*.tgz -# Project files -.vscode -.env -.env.local -/tests_local -uv.toml -.idea/* +# code coverage +coverage +*.lcov -# Build dependencies -src/kimi_cli/deps/bin -src/kimi_cli/deps/tmp +# logs +logs +_.log +report.[0-9]_.[0-9]_.[0-9]_.[0-9]_.json -# Web build artifacts -src/kimi_cli/web/static/assets/ - -# Vis build artifacts -src/kimi_cli/vis/static/ +# dotenv environment variable files +.env +.env.development.local +.env.test.local +.env.production.local +.env.local -# Generated reports -tests_ai/report.json +# caches +.eslintcache +.cache +*.tsbuildinfo -# nix build result -result -result-* +# IntelliJ based IDEs +.idea -# macOS files +# Finder (MacOS) folder config .DS_Store - -# Rust files -target/ - -node_modules/ -static/ -.memo/ -.entire -.claude \ No newline at end of file +.ftp diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 000000000..b8100b77e --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,111 @@ +--- +description: Use Bun instead of Node.js, npm, pnpm, or vite. +globs: "*.ts, *.tsx, *.html, *.css, *.js, *.jsx, package.json" +alwaysApply: false +--- + +Default to using Bun instead of Node.js. + +- Use `bun ` instead of `node ` or `ts-node ` +- Use `bun test` instead of `jest` or `vitest` +- Use `bun build ` instead of `webpack` or `esbuild` +- Use `bun install` instead of `npm install` or `yarn install` or `pnpm install` +- Use `bun run + + +``` + +With the following `frontend.tsx`: + +```tsx#frontend.tsx +import React from "react"; + +// import .css files directly and it works +import './index.css'; + +import { createRoot } from "react-dom/client"; + +const root = createRoot(document.body); + +export default function Frontend() { + return

Hello, world!

; +} + +root.render(); +``` + +Then, run index.ts + +```sh +bun --hot ./index.ts +``` + +For more information, read the Bun API docs in `node_modules/bun-types/docs/**.md`. diff --git a/README.md b/README.md index 544eac559..05820f161 100644 --- a/README.md +++ b/README.md @@ -1,174 +1,15 @@ -# Kimi Code CLI +# ts -[![Commit Activity](https://img.shields.io/github/commit-activity/w/MoonshotAI/kimi-cli)](https://github.com/MoonshotAI/kimi-cli/graphs/commit-activity) -[![Checks](https://img.shields.io/github/check-runs/MoonshotAI/kimi-cli/main)](https://github.com/MoonshotAI/kimi-cli/actions) -[![Version](https://img.shields.io/pypi/v/kimi-cli)](https://pypi.org/project/kimi-cli/) -[![Downloads](https://img.shields.io/pypi/dw/kimi-cli)](https://pypistats.org/packages/kimi-cli) -[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/MoonshotAI/kimi-cli) +To install dependencies: -[Kimi Code](https://www.kimi.com/code/) | [Documentation](https://moonshotai.github.io/kimi-cli/en/) | [文档](https://moonshotai.github.io/kimi-cli/zh/) - -Kimi Code CLI is an AI agent that runs in the terminal, helping you complete software development tasks and terminal operations. It can read and edit code, execute shell commands, search and fetch web pages, and autonomously plan and adjust actions during execution. - -## Getting Started - -See [Getting Started](https://moonshotai.github.io/kimi-cli/en/guides/getting-started.html) for how to install and start using Kimi Code CLI. - -## Key Features - -### Shell command mode - -Kimi Code CLI is not only a coding agent, but also a shell. You can switch the shell command mode by pressing `Ctrl-X`. In this mode, you can directly run shell commands without leaving Kimi Code CLI. - -![](./docs/media/shell-mode.gif) - -> [!NOTE] -> Built-in shell commands like `cd` are not supported yet. - -### VS Code extension - -Kimi Code CLI can be integrated with [Visual Studio Code](https://code.visualstudio.com/) via the [Kimi Code VS Code Extension](https://marketplace.visualstudio.com/items?itemName=moonshot-ai.kimi-code). - -![VS Code Extension](./docs/media/vscode.png) - -### IDE integration via ACP - -Kimi Code CLI supports [Agent Client Protocol] out of the box. You can use it together with any ACP-compatible editor or IDE. - -[Agent Client Protocol]: https://github.com/agentclientprotocol/agent-client-protocol - -To use Kimi Code CLI with ACP clients, make sure to run Kimi Code CLI in the terminal and send `/login` to complete the login first. Then, you can configure your ACP client to start Kimi Code CLI as an ACP agent server with command `kimi acp`. - -For example, to use Kimi Code CLI with [Zed](https://zed.dev/) or [JetBrains](https://blog.jetbrains.com/ai/2025/12/bring-your-own-ai-agent-to-jetbrains-ides/), add the following configuration to your `~/.config/zed/settings.json` or `~/.jetbrains/acp.json` file: - -```json -{ - "agent_servers": { - "Kimi Code CLI": { - "type": "custom", - "command": "kimi", - "args": ["acp"], - "env": {} - } - } -} -``` - -Then you can create Kimi Code CLI threads in IDE's agent panel. - -![](./docs/media/acp-integration.gif) - -### Zsh integration - -You can use Kimi Code CLI together with Zsh, to empower your shell experience with AI agent capabilities. - -Install the [zsh-kimi-cli](https://github.com/MoonshotAI/zsh-kimi-cli) plugin via: - -```sh -git clone https://github.com/MoonshotAI/zsh-kimi-cli.git \ - ${ZSH_CUSTOM:-~/.oh-my-zsh/custom}/plugins/kimi-cli +```bash +bun install ``` -> [!NOTE] -> If you are using a plugin manager other than Oh My Zsh, you may need to refer to the plugin's README for installation instructions. - -Then add `kimi-cli` to your Zsh plugin list in `~/.zshrc`: - -```sh -plugins=(... kimi-cli) -``` - -After restarting Zsh, you can switch to agent mode by pressing `Ctrl-X`. - -### MCP support - -Kimi Code CLI supports MCP (Model Context Protocol) tools. - -**`kimi mcp` sub-command group** - -You can manage MCP servers with `kimi mcp` sub-command group. For example: - -```sh -# Add streamable HTTP server: -kimi mcp add --transport http context7 https://mcp.context7.com/mcp --header "CONTEXT7_API_KEY: ctx7sk-your-key" - -# Add streamable HTTP server with OAuth authorization: -kimi mcp add --transport http --auth oauth linear https://mcp.linear.app/mcp - -# Add stdio server: -kimi mcp add --transport stdio chrome-devtools -- npx chrome-devtools-mcp@latest - -# List added MCP servers: -kimi mcp list - -# Remove an MCP server: -kimi mcp remove chrome-devtools - -# Authorize an MCP server: -kimi mcp auth linear -``` - -**Ad-hoc MCP configuration** - -Kimi Code CLI also supports ad-hoc MCP server configuration via CLI option. - -Given an MCP config file in the well-known MCP config format like the following: - -```json -{ - "mcpServers": { - "context7": { - "url": "https://mcp.context7.com/mcp", - "headers": { - "CONTEXT7_API_KEY": "YOUR_API_KEY" - } - }, - "chrome-devtools": { - "command": "npx", - "args": ["-y", "chrome-devtools-mcp@latest"] - } - } -} -``` - -Run `kimi` with `--mcp-config-file` option to connect to the specified MCP servers: - -```sh -kimi --mcp-config-file /path/to/mcp.json -``` - -### More - -See more features in the [Documentation](https://moonshotai.github.io/kimi-cli/en/). - -## Development - -To develop Kimi Code CLI, run: - -```sh -git clone https://github.com/MoonshotAI/kimi-cli.git -cd kimi-cli - -make prepare # prepare the development environment -``` - -Then you can start working on Kimi Code CLI. - -Refer to the following commands after you make changes: - -```sh -uv run kimi # run Kimi Code CLI +To run: -make format # format code -make check # run linting and type checking -make test # run tests -make test-kimi-cli # run Kimi Code CLI tests only -make test-kosong # run kosong tests only -make test-pykaos # run pykaos tests only -make build-web # build the web UI and sync it into the package (requires Node.js/npm) -make build # build python packages -make build-bin # build standalone binary -make help # show all make targets +```bash +bun run index.ts ``` -Note: `make build` and `make build-bin` automatically run `make build-web` to embed the web UI. +This project was created using `bun init` in bun v1.3.3. [Bun](https://bun.com) is a fast all-in-one JavaScript runtime. diff --git a/bun.lock b/bun.lock new file mode 100644 index 000000000..2b62cb02b --- /dev/null +++ b/bun.lock @@ -0,0 +1,320 @@ +{ + "lockfileVersion": 1, + "configVersion": 1, + "workspaces": { + "": { + "name": "ts", + "dependencies": { + "@anthropic-ai/sdk": "^0.81.0", + "@google/genai": "^1.48.0", + "@iarna/toml": "^2.2.5", + "chalk": "^5.6.2", + "commander": "^14.0.3", + "globby": "^16.2.0", + "ink": "^6.8.0", + "ink-spinner": "^5.0.0", + "ink-text-input": "^6.0.0", + "micromatch": "^4.0.8", + "nanoid": "^5.1.7", + "openai": "^6.33.0", + "react": "^19.2.4", + "smol-toml": "^1.6.1", + "zod": "^4.3.6", + "zod-to-json-schema": "^3.25.2", + }, + "devDependencies": { + "@biomejs/biome": "^2.4.10", + "@types/bun": "latest", + "@types/micromatch": "^4.0.10", + "@types/react": "^19.2.14", + "react-devtools-core": "^7.0.1", + }, + "peerDependencies": { + "typescript": "^5", + }, + }, + }, + "packages": { + "@alcalzone/ansi-tokenize": ["@alcalzone/ansi-tokenize@0.2.5", "https://mirrors.tencent.com/npm/@alcalzone/ansi-tokenize/-/ansi-tokenize-0.2.5.tgz", { "dependencies": { "ansi-styles": "^6.2.1", "is-fullwidth-code-point": "^5.0.0" } }, "sha512-3NX/MpTdroi0aKz134A6RC2Gb2iXVECN4QaAXnvCIxxIm3C3AVB1mkUe8NaaiyvOpDfsrqWhYtj+Q6a62RrTsw=="], + + "@anthropic-ai/sdk": ["@anthropic-ai/sdk@0.81.0", "https://mirrors.tencent.com/npm/@anthropic-ai/sdk/-/sdk-0.81.0.tgz", { "dependencies": { "json-schema-to-ts": "^3.1.1" }, "peerDependencies": { "zod": "^3.25.0 || ^4.0.0" }, "optionalPeers": ["zod"], "bin": { "anthropic-ai-sdk": "bin/cli" } }, "sha512-D4K5PvEV6wPiRtVlVsJHIUhHAmOZ6IT/I9rKlTf84gR7GyyAurPJK7z9BOf/AZqC5d1DhYQGJNKRmV+q8dGhgw=="], + + "@babel/runtime": ["@babel/runtime@7.29.2", "https://mirrors.tencent.com/npm/@babel/runtime/-/runtime-7.29.2.tgz", {}, "sha512-JiDShH45zKHWyGe4ZNVRrCjBz8Nh9TMmZG1kh4QTK8hCBTWBi8Da+i7s1fJw7/lYpM4ccepSNfqzZ/QvABBi5g=="], + + "@biomejs/biome": ["@biomejs/biome@2.4.10", "https://mirrors.tencent.com/npm/@biomejs/biome/-/biome-2.4.10.tgz", { "optionalDependencies": { "@biomejs/cli-darwin-arm64": "2.4.10", "@biomejs/cli-darwin-x64": "2.4.10", "@biomejs/cli-linux-arm64": "2.4.10", "@biomejs/cli-linux-arm64-musl": "2.4.10", "@biomejs/cli-linux-x64": "2.4.10", "@biomejs/cli-linux-x64-musl": "2.4.10", "@biomejs/cli-win32-arm64": "2.4.10", "@biomejs/cli-win32-x64": "2.4.10" }, "bin": { "biome": "bin/biome" } }, "sha512-xxA3AphFQ1geij4JTHXv4EeSTda1IFn22ye9LdyVPoJU19fNVl0uzfEuhsfQ4Yue/0FaLs2/ccVi4UDiE7R30w=="], + + "@biomejs/cli-darwin-arm64": ["@biomejs/cli-darwin-arm64@2.4.10", "https://mirrors.tencent.com/npm/@biomejs/cli-darwin-arm64/-/cli-darwin-arm64-2.4.10.tgz", { "os": "darwin", "cpu": "arm64" }, "sha512-vuzzI1cWqDVzOMIkYyHbKqp+AkQq4K7k+UCXWpkYcY/HDn1UxdsbsfgtVpa40shem8Kax4TLDLlx8kMAecgqiw=="], + + "@biomejs/cli-darwin-x64": ["@biomejs/cli-darwin-x64@2.4.10", "https://mirrors.tencent.com/npm/@biomejs/cli-darwin-x64/-/cli-darwin-x64-2.4.10.tgz", { "os": "darwin", "cpu": "x64" }, "sha512-14fzASRo+BPotwp7nWULy2W5xeUyFnTaq1V13Etrrxkrih+ez/2QfgFm5Ehtf5vSjtgx/IJycMMpn5kPd5ZNaA=="], + + "@biomejs/cli-linux-arm64": ["@biomejs/cli-linux-arm64@2.4.10", "https://mirrors.tencent.com/npm/@biomejs/cli-linux-arm64/-/cli-linux-arm64-2.4.10.tgz", { "os": "linux", "cpu": "arm64" }, "sha512-7MH1CMW5uuxQ/s7FLST63qF8B3Hgu2HRdZ7tA1X1+mk+St4JOuIrqdhIBnnyqeyWJNI+Bww7Es5QZ0wIc1Cmkw=="], + + "@biomejs/cli-linux-arm64-musl": ["@biomejs/cli-linux-arm64-musl@2.4.10", "https://mirrors.tencent.com/npm/@biomejs/cli-linux-arm64-musl/-/cli-linux-arm64-musl-2.4.10.tgz", { "os": "linux", "cpu": "arm64" }, "sha512-WrJY6UuiSD/Dh+nwK2qOTu8kdMDlLV3dLMmychIghHPAysWFq1/DGC1pVZx8POE3ZkzKR3PUUnVrtZfMfaJjyQ=="], + + "@biomejs/cli-linux-x64": ["@biomejs/cli-linux-x64@2.4.10", "https://mirrors.tencent.com/npm/@biomejs/cli-linux-x64/-/cli-linux-x64-2.4.10.tgz", { "os": "linux", "cpu": "x64" }, "sha512-tZLvEEi2u9Xu1zAqRjTcpIDGVtldigVvzug2fTuPG0ME/g8/mXpRPcNgLB22bGn6FvLJpHHnqLnwliOu8xjYrg=="], + + "@biomejs/cli-linux-x64-musl": ["@biomejs/cli-linux-x64-musl@2.4.10", "https://mirrors.tencent.com/npm/@biomejs/cli-linux-x64-musl/-/cli-linux-x64-musl-2.4.10.tgz", { "os": "linux", "cpu": "x64" }, "sha512-kDTi3pI6PBN6CiczsWYOyP2zk0IJI08EWEQyDMQWW221rPaaEz6FvjLhnU07KMzLv8q3qSuoB93ua6inSQ55Tw=="], + + "@biomejs/cli-win32-arm64": ["@biomejs/cli-win32-arm64@2.4.10", "https://mirrors.tencent.com/npm/@biomejs/cli-win32-arm64/-/cli-win32-arm64-2.4.10.tgz", { "os": "win32", "cpu": "arm64" }, "sha512-umwQU6qPzH+ISTf/eHyJ/QoQnJs3V9Vpjz2OjZXe9MVBZ7prgGafMy7yYeRGnlmDAn87AKTF3Q6weLoMGpeqdQ=="], + + "@biomejs/cli-win32-x64": ["@biomejs/cli-win32-x64@2.4.10", "https://mirrors.tencent.com/npm/@biomejs/cli-win32-x64/-/cli-win32-x64-2.4.10.tgz", { "os": "win32", "cpu": "x64" }, "sha512-aW/JU5GuyH4uxMrNYpoC2kjaHlyJGLgIa3XkhPEZI0uKhZhJZU8BuEyJmvgzSPQNGozBwWjC972RaNdcJ9KyJg=="], + + "@google/genai": ["@google/genai@1.48.0", "https://mirrors.tencent.com/npm/@google/genai/-/genai-1.48.0.tgz", { "dependencies": { "google-auth-library": "^10.3.0", "p-retry": "^4.6.2", "protobufjs": "^7.5.4", "ws": "^8.18.0" }, "peerDependencies": { "@modelcontextprotocol/sdk": "^1.25.2" }, "optionalPeers": ["@modelcontextprotocol/sdk"] }, "sha512-plonYK4ML2PrxsRD9SeqmFt76eREWkQdPCglOA6aYDzL1AAbE+7PUnT54SvpWGfws13L0AZEqGSpL7+1IPnTxQ=="], + + "@iarna/toml": ["@iarna/toml@2.2.5", "https://mirrors.tencent.com/npm/@iarna/toml/-/toml-2.2.5.tgz", {}, "sha512-trnsAYxU3xnS1gPHPyU961coFyLkh4gAD/0zQ5mymY4yOZ+CYvsPqUbOFSw0aDM4y0tV7tiFxL/1XfXPNC6IPg=="], + + "@nodelib/fs.scandir": ["@nodelib/fs.scandir@2.1.5", "https://mirrors.tencent.com/npm/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz", { "dependencies": { "@nodelib/fs.stat": "2.0.5", "run-parallel": "^1.1.9" } }, "sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g=="], + + "@nodelib/fs.stat": ["@nodelib/fs.stat@2.0.5", "https://mirrors.tencent.com/npm/@nodelib/fs.stat/-/fs.stat-2.0.5.tgz", {}, "sha512-RkhPPp2zrqDAQA/2jNhnztcPAlv64XdhIp7a7454A5ovI7Bukxgt7MX7udwAu3zg1DcpPU0rz3VV1SeaqvY4+A=="], + + "@nodelib/fs.walk": ["@nodelib/fs.walk@1.2.8", "https://mirrors.tencent.com/npm/@nodelib/fs.walk/-/fs.walk-1.2.8.tgz", { "dependencies": { "@nodelib/fs.scandir": "2.1.5", "fastq": "^1.6.0" } }, "sha512-oGB+UxlgWcgQkgwo8GcEGwemoTFt3FIO9ababBmaGwXIoBKZ+GTy0pP185beGg7Llih/NSHSV2XAs1lnznocSg=="], + + "@protobufjs/aspromise": ["@protobufjs/aspromise@1.1.2", "https://mirrors.tencent.com/npm/@protobufjs/aspromise/-/aspromise-1.1.2.tgz", {}, "sha512-j+gKExEuLmKwvz3OgROXtrJ2UG2x8Ch2YZUxahh+s1F2HZ+wAceUNLkvy6zKCPVRkU++ZWQrdxsUeQXmcg4uoQ=="], + + "@protobufjs/base64": ["@protobufjs/base64@1.1.2", "https://mirrors.tencent.com/npm/@protobufjs/base64/-/base64-1.1.2.tgz", {}, "sha512-AZkcAA5vnN/v4PDqKyMR5lx7hZttPDgClv83E//FMNhR2TMcLUhfRUBHCmSl0oi9zMgDDqRUJkSxO3wm85+XLg=="], + + "@protobufjs/codegen": ["@protobufjs/codegen@2.0.4", "https://mirrors.tencent.com/npm/@protobufjs/codegen/-/codegen-2.0.4.tgz", {}, "sha512-YyFaikqM5sH0ziFZCN3xDC7zeGaB/d0IUb9CATugHWbd1FRFwWwt4ld4OYMPWu5a3Xe01mGAULCdqhMlPl29Jg=="], + + "@protobufjs/eventemitter": ["@protobufjs/eventemitter@1.1.0", "https://mirrors.tencent.com/npm/@protobufjs/eventemitter/-/eventemitter-1.1.0.tgz", {}, "sha512-j9ednRT81vYJ9OfVuXG6ERSTdEL1xVsNgqpkxMsbIabzSo3goCjDIveeGv5d03om39ML71RdmrGNjG5SReBP/Q=="], + + "@protobufjs/fetch": ["@protobufjs/fetch@1.1.0", "https://mirrors.tencent.com/npm/@protobufjs/fetch/-/fetch-1.1.0.tgz", { "dependencies": { "@protobufjs/aspromise": "^1.1.1", "@protobufjs/inquire": "^1.1.0" } }, "sha512-lljVXpqXebpsijW71PZaCYeIcE5on1w5DlQy5WH6GLbFryLUrBD4932W/E2BSpfRJWseIL4v/KPgBFxDOIdKpQ=="], + + "@protobufjs/float": ["@protobufjs/float@1.0.2", "https://mirrors.tencent.com/npm/@protobufjs/float/-/float-1.0.2.tgz", {}, "sha512-Ddb+kVXlXst9d+R9PfTIxh1EdNkgoRe5tOX6t01f1lYWOvJnSPDBlG241QLzcyPdoNTsblLUdujGSE4RzrTZGQ=="], + + "@protobufjs/inquire": ["@protobufjs/inquire@1.1.0", "https://mirrors.tencent.com/npm/@protobufjs/inquire/-/inquire-1.1.0.tgz", {}, "sha512-kdSefcPdruJiFMVSbn801t4vFK7KB/5gd2fYvrxhuJYg8ILrmn9SKSX2tZdV6V+ksulWqS7aXjBcRXl3wHoD9Q=="], + + "@protobufjs/path": ["@protobufjs/path@1.1.2", "https://mirrors.tencent.com/npm/@protobufjs/path/-/path-1.1.2.tgz", {}, "sha512-6JOcJ5Tm08dOHAbdR3GrvP+yUUfkjG5ePsHYczMFLq3ZmMkAD98cDgcT2iA1lJ9NVwFd4tH/iSSoe44YWkltEA=="], + + "@protobufjs/pool": ["@protobufjs/pool@1.1.0", "https://mirrors.tencent.com/npm/@protobufjs/pool/-/pool-1.1.0.tgz", {}, "sha512-0kELaGSIDBKvcgS4zkjz1PeddatrjYcmMWOlAuAPwAeccUrPHdUqo/J6LiymHHEiJT5NrF1UVwxY14f+fy4WQw=="], + + "@protobufjs/utf8": ["@protobufjs/utf8@1.1.0", "https://mirrors.tencent.com/npm/@protobufjs/utf8/-/utf8-1.1.0.tgz", {}, "sha512-Vvn3zZrhQZkkBE8LSuW3em98c0FwgO4nxzv6OdSxPKJIEKY2bGbHn+mhGIPerzI4twdxaP8/0+06HBpwf345Lw=="], + + "@sindresorhus/merge-streams": ["@sindresorhus/merge-streams@4.0.0", "https://mirrors.tencent.com/npm/@sindresorhus/merge-streams/-/merge-streams-4.0.0.tgz", {}, "sha512-tlqY9xq5ukxTUZBmoOp+m61cqwQD5pHJtFY3Mn8CA8ps6yghLH/Hw8UPdqg4OLmFW3IFlcXnQNmo/dh8HzXYIQ=="], + + "@types/braces": ["@types/braces@3.0.5", "https://mirrors.tencent.com/npm/@types/braces/-/braces-3.0.5.tgz", {}, "sha512-SQFof9H+LXeWNz8wDe7oN5zu7ket0qwMu5vZubW4GCJ8Kkeh6nBWUz87+KTz/G3Kqsrp0j/W253XJb3KMEeg3w=="], + + "@types/bun": ["@types/bun@1.3.11", "https://mirrors.tencent.com/npm/@types/bun/-/bun-1.3.11.tgz", { "dependencies": { "bun-types": "1.3.11" } }, "sha512-5vPne5QvtpjGpsGYXiFyycfpDF2ECyPcTSsFBMa0fraoxiQyMJ3SmuQIGhzPg2WJuWxVBoxWJ2kClYTcw/4fAg=="], + + "@types/micromatch": ["@types/micromatch@4.0.10", "https://mirrors.tencent.com/npm/@types/micromatch/-/micromatch-4.0.10.tgz", { "dependencies": { "@types/braces": "*" } }, "sha512-5jOhFDElqr4DKTrTEbnW8DZ4Hz5LRUEmyrGpCMrD/NphYv3nUnaF08xmSLx1rGGnyEs/kFnhiw6dCgcDqMr5PQ=="], + + "@types/node": ["@types/node@25.5.0", "https://mirrors.tencent.com/npm/@types/node/-/node-25.5.0.tgz", { "dependencies": { "undici-types": "~7.18.0" } }, "sha512-jp2P3tQMSxWugkCUKLRPVUpGaL5MVFwF8RDuSRztfwgN1wmqJeMSbKlnEtQqU8UrhTmzEmZdu2I6v2dpp7XIxw=="], + + "@types/react": ["@types/react@19.2.14", "https://mirrors.tencent.com/npm/@types/react/-/react-19.2.14.tgz", { "dependencies": { "csstype": "^3.2.2" } }, "sha512-ilcTH/UniCkMdtexkoCN0bI7pMcJDvmQFPvuPvmEaYA/NSfFTAgdUSLAoVjaRJm7+6PvcM+q1zYOwS4wTYMF9w=="], + + "@types/retry": ["@types/retry@0.12.0", "https://mirrors.tencent.com/npm/@types/retry/-/retry-0.12.0.tgz", {}, "sha512-wWKOClTTiizcZhXnPY4wikVAwmdYHp8q6DmC+EJUzAMsycb7HB32Kh9RN4+0gExjmPmZSAQjgURXIGATPegAvA=="], + + "agent-base": ["agent-base@7.1.4", "https://mirrors.tencent.com/npm/agent-base/-/agent-base-7.1.4.tgz", {}, "sha512-MnA+YT8fwfJPgBx3m60MNqakm30XOkyIoH1y6huTQvC0PwZG7ki8NacLBcrPbNoo8vEZy7Jpuk7+jMO+CUovTQ=="], + + "ansi-escapes": ["ansi-escapes@7.3.0", "https://mirrors.tencent.com/npm/ansi-escapes/-/ansi-escapes-7.3.0.tgz", { "dependencies": { "environment": "^1.0.0" } }, "sha512-BvU8nYgGQBxcmMuEeUEmNTvrMVjJNSH7RgW24vXexN4Ven6qCvy4TntnvlnwnMLTVlcRQQdbRY8NKnaIoeWDNg=="], + + "ansi-regex": ["ansi-regex@6.2.2", "https://mirrors.tencent.com/npm/ansi-regex/-/ansi-regex-6.2.2.tgz", {}, "sha512-Bq3SmSpyFHaWjPk8If9yc6svM8c56dB5BAtW4Qbw5jHTwwXXcTLoRMkpDJp6VL0XzlWaCHTXrkFURMYmD0sLqg=="], + + "ansi-styles": ["ansi-styles@6.2.3", "https://mirrors.tencent.com/npm/ansi-styles/-/ansi-styles-6.2.3.tgz", {}, "sha512-4Dj6M28JB+oAH8kFkTLUo+a2jwOFkuqb3yucU0CANcRRUbxS0cP0nZYCGjcc3BNXwRIsUVmDGgzawme7zvJHvg=="], + + "auto-bind": ["auto-bind@5.0.1", "https://mirrors.tencent.com/npm/auto-bind/-/auto-bind-5.0.1.tgz", {}, "sha512-ooviqdwwgfIfNmDwo94wlshcdzfO64XV0Cg6oDsDYBJfITDz1EngD2z7DkbvCWn+XIMsIqW27sEVF6qcpJrRcg=="], + + "base64-js": ["base64-js@1.5.1", "https://mirrors.tencent.com/npm/base64-js/-/base64-js-1.5.1.tgz", {}, "sha512-AKpaYlHn8t4SVbOHCy+b5+KKgvR4vrsD8vbvrbiQJps7fKDTkjkDry6ji0rUJjC0kzbNePLwzxq8iypo41qeWA=="], + + "bignumber.js": ["bignumber.js@9.3.1", "https://mirrors.tencent.com/npm/bignumber.js/-/bignumber.js-9.3.1.tgz", {}, "sha512-Ko0uX15oIUS7wJ3Rb30Fs6SkVbLmPBAKdlm7q9+ak9bbIeFf0MwuBsQV6z7+X768/cHsfg+WlysDWJcmthjsjQ=="], + + "braces": ["braces@3.0.3", "https://mirrors.tencent.com/npm/braces/-/braces-3.0.3.tgz", { "dependencies": { "fill-range": "^7.1.1" } }, "sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA=="], + + "buffer-equal-constant-time": ["buffer-equal-constant-time@1.0.1", "https://mirrors.tencent.com/npm/buffer-equal-constant-time/-/buffer-equal-constant-time-1.0.1.tgz", {}, "sha512-zRpUiDwd/xk6ADqPMATG8vc9VPrkck7T07OIx0gnjmJAnHnTVXNQG3vfvWNuiZIkwu9KrKdA1iJKfsfTVxE6NA=="], + + "bun-types": ["bun-types@1.3.11", "https://mirrors.tencent.com/npm/bun-types/-/bun-types-1.3.11.tgz", { "dependencies": { "@types/node": "*" } }, "sha512-1KGPpoxQWl9f6wcZh57LvrPIInQMn2TQ7jsgxqpRzg+l0QPOFvJVH7HmvHo/AiPgwXy+/Thf6Ov3EdVn1vOabg=="], + + "chalk": ["chalk@5.6.2", "https://mirrors.tencent.com/npm/chalk/-/chalk-5.6.2.tgz", {}, "sha512-7NzBL0rN6fMUW+f7A6Io4h40qQlG+xGmtMxfbnH/K7TAtt8JQWVQK+6g0UXKMeVJoyV5EkkNsErQ8pVD3bLHbA=="], + + "cli-boxes": ["cli-boxes@3.0.0", "https://mirrors.tencent.com/npm/cli-boxes/-/cli-boxes-3.0.0.tgz", {}, "sha512-/lzGpEWL/8PfI0BmBOPRwp0c/wFNX1RdUML3jK/RcSBA9T8mZDdQpqYBKtCFTOfQbwPqWEOpjqW+Fnayc0969g=="], + + "cli-cursor": ["cli-cursor@4.0.0", "https://mirrors.tencent.com/npm/cli-cursor/-/cli-cursor-4.0.0.tgz", { "dependencies": { "restore-cursor": "^4.0.0" } }, "sha512-VGtlMu3x/4DOtIUwEkRezxUZ2lBacNJCHash0N0WeZDBS+7Ux1dm3XWAgWYxLJFMMdOeXMHXorshEFhbMSGelg=="], + + "cli-spinners": ["cli-spinners@2.9.2", "https://mirrors.tencent.com/npm/cli-spinners/-/cli-spinners-2.9.2.tgz", {}, "sha512-ywqV+5MmyL4E7ybXgKys4DugZbX0FC6LnwrhjuykIjnK9k8OQacQ7axGKnjDXWNhns0xot3bZI5h55H8yo9cJg=="], + + "cli-truncate": ["cli-truncate@5.2.0", "https://mirrors.tencent.com/npm/cli-truncate/-/cli-truncate-5.2.0.tgz", { "dependencies": { "slice-ansi": "^8.0.0", "string-width": "^8.2.0" } }, "sha512-xRwvIOMGrfOAnM1JYtqQImuaNtDEv9v6oIYAs4LIHwTiKee8uwvIi363igssOC0O5U04i4AlENs79LQLu9tEMw=="], + + "code-excerpt": ["code-excerpt@4.0.0", "https://mirrors.tencent.com/npm/code-excerpt/-/code-excerpt-4.0.0.tgz", { "dependencies": { "convert-to-spaces": "^2.0.1" } }, "sha512-xxodCmBen3iy2i0WtAK8FlFNrRzjUqjRsMfho58xT/wvZU1YTM3fCnRjcy1gJPMepaRlgm/0e6w8SpWHpn3/cA=="], + + "commander": ["commander@14.0.3", "https://mirrors.tencent.com/npm/commander/-/commander-14.0.3.tgz", {}, "sha512-H+y0Jo/T1RZ9qPP4Eh1pkcQcLRglraJaSLoyOtHxu6AapkjWVCy2Sit1QQ4x3Dng8qDlSsZEet7g5Pq06MvTgw=="], + + "convert-to-spaces": ["convert-to-spaces@2.0.1", "https://mirrors.tencent.com/npm/convert-to-spaces/-/convert-to-spaces-2.0.1.tgz", {}, "sha512-rcQ1bsQO9799wq24uE5AM2tAILy4gXGIK/njFWcVQkGNZ96edlpY+A7bjwvzjYvLDyzmG1MmMLZhpcsb+klNMQ=="], + + "csstype": ["csstype@3.2.3", "https://mirrors.tencent.com/npm/csstype/-/csstype-3.2.3.tgz", {}, "sha512-z1HGKcYy2xA8AGQfwrn0PAy+PB7X/GSj3UVJW9qKyn43xWa+gl5nXmU4qqLMRzWVLFC8KusUX8T/0kCiOYpAIQ=="], + + "data-uri-to-buffer": ["data-uri-to-buffer@4.0.1", "https://mirrors.tencent.com/npm/data-uri-to-buffer/-/data-uri-to-buffer-4.0.1.tgz", {}, "sha512-0R9ikRb668HB7QDxT1vkpuUBtqc53YyAwMwGeUFKRojY/NWKvdZ+9UYtRfGmhqNbRkTSVpMbmyhXipFFv2cb/A=="], + + "debug": ["debug@4.4.3", "https://mirrors.tencent.com/npm/debug/-/debug-4.4.3.tgz", { "dependencies": { "ms": "^2.1.3" } }, "sha512-RGwwWnwQvkVfavKVt22FGLw+xYSdzARwm0ru6DhTVA3umU5hZc28V3kO4stgYryrTlLpuvgI9GiijltAjNbcqA=="], + + "ecdsa-sig-formatter": ["ecdsa-sig-formatter@1.0.11", "https://mirrors.tencent.com/npm/ecdsa-sig-formatter/-/ecdsa-sig-formatter-1.0.11.tgz", { "dependencies": { "safe-buffer": "^5.0.1" } }, "sha512-nagl3RYrbNv6kQkeJIpt6NJZy8twLB/2vtz6yN9Z4vRKHN4/QZJIEbqohALSgwKdnksuY3k5Addp5lg8sVoVcQ=="], + + "emoji-regex": ["emoji-regex@10.6.0", "https://mirrors.tencent.com/npm/emoji-regex/-/emoji-regex-10.6.0.tgz", {}, "sha512-toUI84YS5YmxW219erniWD0CIVOo46xGKColeNQRgOzDorgBi1v4D71/OFzgD9GO2UGKIv1C3Sp8DAn0+j5w7A=="], + + "environment": ["environment@1.1.0", "https://mirrors.tencent.com/npm/environment/-/environment-1.1.0.tgz", {}, "sha512-xUtoPkMggbz0MPyPiIWr1Kp4aeWJjDZ6SMvURhimjdZgsRuDplF5/s9hcgGhyXMhs+6vpnuoiZ2kFiu3FMnS8Q=="], + + "es-toolkit": ["es-toolkit@1.45.1", "https://mirrors.tencent.com/npm/es-toolkit/-/es-toolkit-1.45.1.tgz", {}, "sha512-/jhoOj/Fx+A+IIyDNOvO3TItGmlMKhtX8ISAHKE90c4b/k1tqaqEZ+uUqfpU8DMnW5cgNJv606zS55jGvza0Xw=="], + + "escape-string-regexp": ["escape-string-regexp@2.0.0", "https://mirrors.tencent.com/npm/escape-string-regexp/-/escape-string-regexp-2.0.0.tgz", {}, "sha512-UpzcLCXolUWcNu5HtVMHYdXJjArjsF9C0aNnquZYY4uW/Vu0miy5YoWvbV345HauVvcAUnpRuhMMcqTcGOY2+w=="], + + "extend": ["extend@3.0.2", "https://mirrors.tencent.com/npm/extend/-/extend-3.0.2.tgz", {}, "sha512-fjquC59cD7CyW6urNXK0FBufkZcoiGG80wTuPujX590cB5Ttln20E2UB4S/WARVqhXffZl2LNgS+gQdPIIim/g=="], + + "fast-glob": ["fast-glob@3.3.3", "https://mirrors.tencent.com/npm/fast-glob/-/fast-glob-3.3.3.tgz", { "dependencies": { "@nodelib/fs.stat": "^2.0.2", "@nodelib/fs.walk": "^1.2.3", "glob-parent": "^5.1.2", "merge2": "^1.3.0", "micromatch": "^4.0.8" } }, "sha512-7MptL8U0cqcFdzIzwOTHoilX9x5BrNqye7Z/LuC7kCMRio1EMSyqRK3BEAUD7sXRq4iT4AzTVuZdhgQ2TCvYLg=="], + + "fastq": ["fastq@1.20.1", "https://mirrors.tencent.com/npm/fastq/-/fastq-1.20.1.tgz", { "dependencies": { "reusify": "^1.0.4" } }, "sha512-GGToxJ/w1x32s/D2EKND7kTil4n8OVk/9mycTc4VDza13lOvpUZTGX3mFSCtV9ksdGBVzvsyAVLM6mHFThxXxw=="], + + "fetch-blob": ["fetch-blob@3.2.0", "https://mirrors.tencent.com/npm/fetch-blob/-/fetch-blob-3.2.0.tgz", { "dependencies": { "node-domexception": "^1.0.0", "web-streams-polyfill": "^3.0.3" } }, "sha512-7yAQpD2UMJzLi1Dqv7qFYnPbaPx7ZfFK6PiIxQ4PfkGPyNyl2Ugx+a/umUonmKqjhM4DnfbMvdX6otXq83soQQ=="], + + "fill-range": ["fill-range@7.1.1", "https://mirrors.tencent.com/npm/fill-range/-/fill-range-7.1.1.tgz", { "dependencies": { "to-regex-range": "^5.0.1" } }, "sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg=="], + + "formdata-polyfill": ["formdata-polyfill@4.0.10", "https://mirrors.tencent.com/npm/formdata-polyfill/-/formdata-polyfill-4.0.10.tgz", { "dependencies": { "fetch-blob": "^3.1.2" } }, "sha512-buewHzMvYL29jdeQTVILecSaZKnt/RJWjoZCF5OW60Z67/GmSLBkOFM7qh1PI3zFNtJbaZL5eQu1vLfazOwj4g=="], + + "gaxios": ["gaxios@7.1.4", "https://mirrors.tencent.com/npm/gaxios/-/gaxios-7.1.4.tgz", { "dependencies": { "extend": "^3.0.2", "https-proxy-agent": "^7.0.1", "node-fetch": "^3.3.2" } }, "sha512-bTIgTsM2bWn3XklZISBTQX7ZSddGW+IO3bMdGaemHZ3tbqExMENHLx6kKZ/KlejgrMtj8q7wBItt51yegqalrA=="], + + "gcp-metadata": ["gcp-metadata@8.1.2", "https://mirrors.tencent.com/npm/gcp-metadata/-/gcp-metadata-8.1.2.tgz", { "dependencies": { "gaxios": "^7.0.0", "google-logging-utils": "^1.0.0", "json-bigint": "^1.0.0" } }, "sha512-zV/5HKTfCeKWnxG0Dmrw51hEWFGfcF2xiXqcA3+J90WDuP0SvoiSO5ORvcBsifmx/FoIjgQN3oNOGaQ5PhLFkg=="], + + "get-east-asian-width": ["get-east-asian-width@1.5.0", "https://mirrors.tencent.com/npm/get-east-asian-width/-/get-east-asian-width-1.5.0.tgz", {}, "sha512-CQ+bEO+Tva/qlmw24dCejulK5pMzVnUOFOijVogd3KQs07HnRIgp8TGipvCCRT06xeYEbpbgwaCxglFyiuIcmA=="], + + "glob-parent": ["glob-parent@5.1.2", "https://mirrors.tencent.com/npm/glob-parent/-/glob-parent-5.1.2.tgz", { "dependencies": { "is-glob": "^4.0.1" } }, "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow=="], + + "globby": ["globby@16.2.0", "https://mirrors.tencent.com/npm/globby/-/globby-16.2.0.tgz", { "dependencies": { "@sindresorhus/merge-streams": "^4.0.0", "fast-glob": "^3.3.3", "ignore": "^7.0.5", "is-path-inside": "^4.0.0", "slash": "^5.1.0", "unicorn-magic": "^0.4.0" } }, "sha512-QrJia2qDf5BB/V6HYlDTs0I0lBahyjLzpGQg3KT7FnCdTonAyPy2RtY802m2k4ALx6Dp752f82WsOczEVr3l6Q=="], + + "google-auth-library": ["google-auth-library@10.6.2", "https://mirrors.tencent.com/npm/google-auth-library/-/google-auth-library-10.6.2.tgz", { "dependencies": { "base64-js": "^1.3.0", "ecdsa-sig-formatter": "^1.0.11", "gaxios": "^7.1.4", "gcp-metadata": "8.1.2", "google-logging-utils": "1.1.3", "jws": "^4.0.0" } }, "sha512-e27Z6EThmVNNvtYASwQxose/G57rkRuaRbQyxM2bvYLLX/GqWZ5chWq2EBoUchJbCc57eC9ArzO5wMsEmWftCw=="], + + "google-logging-utils": ["google-logging-utils@1.1.3", "https://mirrors.tencent.com/npm/google-logging-utils/-/google-logging-utils-1.1.3.tgz", {}, "sha512-eAmLkjDjAFCVXg7A1unxHsLf961m6y17QFqXqAXGj/gVkKFrEICfStRfwUlGNfeCEjNRa32JEWOUTlYXPyyKvA=="], + + "https-proxy-agent": ["https-proxy-agent@7.0.6", "https://mirrors.tencent.com/npm/https-proxy-agent/-/https-proxy-agent-7.0.6.tgz", { "dependencies": { "agent-base": "^7.1.2", "debug": "4" } }, "sha512-vK9P5/iUfdl95AI+JVyUuIcVtd4ofvtrOr3HNtM2yxC9bnMbEdp3x01OhQNnjb8IJYi38VlTE3mBXwcfvywuSw=="], + + "ignore": ["ignore@7.0.5", "https://mirrors.tencent.com/npm/ignore/-/ignore-7.0.5.tgz", {}, "sha512-Hs59xBNfUIunMFgWAbGX5cq6893IbWg4KnrjbYwX3tx0ztorVgTDA6B2sxf8ejHJ4wz8BqGUMYlnzNBer5NvGg=="], + + "indent-string": ["indent-string@5.0.0", "https://mirrors.tencent.com/npm/indent-string/-/indent-string-5.0.0.tgz", {}, "sha512-m6FAo/spmsW2Ab2fU35JTYwtOKa2yAwXSwgjSv1TJzh4Mh7mC3lzAOVLBprb72XsTrgkEIsl7YrFNAiDiRhIGg=="], + + "ink": ["ink@6.8.0", "https://mirrors.tencent.com/npm/ink/-/ink-6.8.0.tgz", { "dependencies": { "@alcalzone/ansi-tokenize": "^0.2.4", "ansi-escapes": "^7.3.0", "ansi-styles": "^6.2.1", "auto-bind": "^5.0.1", "chalk": "^5.6.0", "cli-boxes": "^3.0.0", "cli-cursor": "^4.0.0", "cli-truncate": "^5.1.1", "code-excerpt": "^4.0.0", "es-toolkit": "^1.39.10", "indent-string": "^5.0.0", "is-in-ci": "^2.0.0", "patch-console": "^2.0.0", "react-reconciler": "^0.33.0", "scheduler": "^0.27.0", "signal-exit": "^3.0.7", "slice-ansi": "^8.0.0", "stack-utils": "^2.0.6", "string-width": "^8.1.1", "terminal-size": "^4.0.1", "type-fest": "^5.4.1", "widest-line": "^6.0.0", "wrap-ansi": "^9.0.0", "ws": "^8.18.0", "yoga-layout": "~3.2.1" }, "peerDependencies": { "@types/react": ">=19.0.0", "react": ">=19.0.0", "react-devtools-core": ">=6.1.2" }, "optionalPeers": ["@types/react", "react-devtools-core"] }, "sha512-sbl1RdLOgkO9isK42WCZlJCFN9hb++sX9dsklOvfd1YQ3bQ2AiFu12Q6tFlr0HvEUvzraJntQCCpfEoUe9DSzA=="], + + "ink-spinner": ["ink-spinner@5.0.0", "https://mirrors.tencent.com/npm/ink-spinner/-/ink-spinner-5.0.0.tgz", { "dependencies": { "cli-spinners": "^2.7.0" }, "peerDependencies": { "ink": ">=4.0.0", "react": ">=18.0.0" } }, "sha512-EYEasbEjkqLGyPOUc8hBJZNuC5GvXGMLu0w5gdTNskPc7Izc5vO3tdQEYnzvshucyGCBXc86ig0ujXPMWaQCdA=="], + + "ink-text-input": ["ink-text-input@6.0.0", "https://mirrors.tencent.com/npm/ink-text-input/-/ink-text-input-6.0.0.tgz", { "dependencies": { "chalk": "^5.3.0", "type-fest": "^4.18.2" }, "peerDependencies": { "ink": ">=5", "react": ">=18" } }, "sha512-Fw64n7Yha5deb1rHY137zHTAbSTNelUKuB5Kkk2HACXEtwIHBCf9OH2tP/LQ9fRYTl1F0dZgbW0zPnZk6FA9Lw=="], + + "is-extglob": ["is-extglob@2.1.1", "https://mirrors.tencent.com/npm/is-extglob/-/is-extglob-2.1.1.tgz", {}, "sha512-SbKbANkN603Vi4jEZv49LeVJMn4yGwsbzZworEoyEiutsN3nJYdbO36zfhGJ6QEDpOZIFkDtnq5JRxmvl3jsoQ=="], + + "is-fullwidth-code-point": ["is-fullwidth-code-point@5.1.0", "https://mirrors.tencent.com/npm/is-fullwidth-code-point/-/is-fullwidth-code-point-5.1.0.tgz", { "dependencies": { "get-east-asian-width": "^1.3.1" } }, "sha512-5XHYaSyiqADb4RnZ1Bdad6cPp8Toise4TzEjcOYDHZkTCbKgiUl7WTUCpNWHuxmDt91wnsZBc9xinNzopv3JMQ=="], + + "is-glob": ["is-glob@4.0.3", "https://mirrors.tencent.com/npm/is-glob/-/is-glob-4.0.3.tgz", { "dependencies": { "is-extglob": "^2.1.1" } }, "sha512-xelSayHH36ZgE7ZWhli7pW34hNbNl8Ojv5KVmkJD4hBdD3th8Tfk9vYasLM+mXWOZhFkgZfxhLSnrwRr4elSSg=="], + + "is-in-ci": ["is-in-ci@2.0.0", "https://mirrors.tencent.com/npm/is-in-ci/-/is-in-ci-2.0.0.tgz", { "bin": { "is-in-ci": "cli.js" } }, "sha512-cFeerHriAnhrQSbpAxL37W1wcJKUUX07HyLWZCW1URJT/ra3GyUTzBgUnh24TMVfNTV2Hij2HLxkPHFZfOZy5w=="], + + "is-number": ["is-number@7.0.0", "https://mirrors.tencent.com/npm/is-number/-/is-number-7.0.0.tgz", {}, "sha512-41Cifkg6e8TylSpdtTpeLVMqvSBEVzTttHvERD741+pnZ8ANv0004MRL43QKPDlK9cGvNp6NZWZUBlbGXYxxng=="], + + "is-path-inside": ["is-path-inside@4.0.0", "https://mirrors.tencent.com/npm/is-path-inside/-/is-path-inside-4.0.0.tgz", {}, "sha512-lJJV/5dYS+RcL8uQdBDW9c9uWFLLBNRyFhnAKXw5tVqLlKZ4RMGZKv+YQ/IA3OhD+RpbJa1LLFM1FQPGyIXvOA=="], + + "json-bigint": ["json-bigint@1.0.0", "https://mirrors.tencent.com/npm/json-bigint/-/json-bigint-1.0.0.tgz", { "dependencies": { "bignumber.js": "^9.0.0" } }, "sha512-SiPv/8VpZuWbvLSMtTDU8hEfrZWg/mH/nV/b4o0CYbSxu1UIQPLdwKOCIyLQX+VIPO5vrLX3i8qtqFyhdPSUSQ=="], + + "json-schema-to-ts": ["json-schema-to-ts@3.1.1", "https://mirrors.tencent.com/npm/json-schema-to-ts/-/json-schema-to-ts-3.1.1.tgz", { "dependencies": { "@babel/runtime": "^7.18.3", "ts-algebra": "^2.0.0" } }, "sha512-+DWg8jCJG2TEnpy7kOm/7/AxaYoaRbjVB4LFZLySZlWn8exGs3A4OLJR966cVvU26N7X9TWxl+Jsw7dzAqKT6g=="], + + "jwa": ["jwa@2.0.1", "https://mirrors.tencent.com/npm/jwa/-/jwa-2.0.1.tgz", { "dependencies": { "buffer-equal-constant-time": "^1.0.1", "ecdsa-sig-formatter": "1.0.11", "safe-buffer": "^5.0.1" } }, "sha512-hRF04fqJIP8Abbkq5NKGN0Bbr3JxlQ+qhZufXVr0DvujKy93ZCbXZMHDL4EOtodSbCWxOqR8MS1tXA5hwqCXDg=="], + + "jws": ["jws@4.0.1", "https://mirrors.tencent.com/npm/jws/-/jws-4.0.1.tgz", { "dependencies": { "jwa": "^2.0.1", "safe-buffer": "^5.0.1" } }, "sha512-EKI/M/yqPncGUUh44xz0PxSidXFr/+r0pA70+gIYhjv+et7yxM+s29Y+VGDkovRofQem0fs7Uvf4+YmAdyRduA=="], + + "long": ["long@5.3.2", "https://mirrors.tencent.com/npm/long/-/long-5.3.2.tgz", {}, "sha512-mNAgZ1GmyNhD7AuqnTG3/VQ26o760+ZYBPKjPvugO8+nLbYfX6TVpJPseBvopbdY+qpZ/lKUnmEc1LeZYS3QAA=="], + + "merge2": ["merge2@1.4.1", "https://mirrors.tencent.com/npm/merge2/-/merge2-1.4.1.tgz", {}, "sha512-8q7VEgMJW4J8tcfVPy8g09NcQwZdbwFEqhe/WZkoIzjn/3TGDwtOCYtXGxA3O8tPzpczCCDgv+P2P5y00ZJOOg=="], + + "micromatch": ["micromatch@4.0.8", "https://mirrors.tencent.com/npm/micromatch/-/micromatch-4.0.8.tgz", { "dependencies": { "braces": "^3.0.3", "picomatch": "^2.3.1" } }, "sha512-PXwfBhYu0hBCPw8Dn0E+WDYb7af3dSLVWKi3HGv84IdF4TyFoC0ysxFd0Goxw7nSv4T/PzEJQxsYsEiFCKo2BA=="], + + "mimic-fn": ["mimic-fn@2.1.0", "https://mirrors.tencent.com/npm/mimic-fn/-/mimic-fn-2.1.0.tgz", {}, "sha512-OqbOk5oEQeAZ8WXWydlu9HJjz9WVdEIvamMCcXmuqUYjTknH/sqsWvhQ3vgwKFRR1HpjvNBKQ37nbJgYzGqGcg=="], + + "ms": ["ms@2.1.3", "https://mirrors.tencent.com/npm/ms/-/ms-2.1.3.tgz", {}, "sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA=="], + + "nanoid": ["nanoid@5.1.7", "https://mirrors.tencent.com/npm/nanoid/-/nanoid-5.1.7.tgz", { "bin": { "nanoid": "bin/nanoid.js" } }, "sha512-ua3NDgISf6jdwezAheMOk4mbE1LXjm1DfMUDMuJf4AqxLFK3ccGpgWizwa5YV7Yz9EpXwEaWoRXSb/BnV0t5dQ=="], + + "node-domexception": ["node-domexception@1.0.0", "https://mirrors.tencent.com/npm/node-domexception/-/node-domexception-1.0.0.tgz", {}, "sha512-/jKZoMpw0F8GRwl4/eLROPA3cfcXtLApP0QzLmUT/HuPCZWyB7IY9ZrMeKw2O/nFIqPQB3PVM9aYm0F312AXDQ=="], + + "node-fetch": ["node-fetch@3.3.2", "https://mirrors.tencent.com/npm/node-fetch/-/node-fetch-3.3.2.tgz", { "dependencies": { "data-uri-to-buffer": "^4.0.0", "fetch-blob": "^3.1.4", "formdata-polyfill": "^4.0.10" } }, "sha512-dRB78srN/l6gqWulah9SrxeYnxeddIG30+GOqK/9OlLVyLg3HPnr6SqOWTWOXKRwC2eGYCkZ59NNuSgvSrpgOA=="], + + "onetime": ["onetime@5.1.2", "https://mirrors.tencent.com/npm/onetime/-/onetime-5.1.2.tgz", { "dependencies": { "mimic-fn": "^2.1.0" } }, "sha512-kbpaSSGJTWdAY5KPVeMOKXSrPtr8C8C7wodJbcsd51jRnmD+GZu8Y0VoU6Dm5Z4vWr0Ig/1NKuWRKf7j5aaYSg=="], + + "openai": ["openai@6.33.0", "https://mirrors.tencent.com/npm/openai/-/openai-6.33.0.tgz", { "peerDependencies": { "ws": "^8.18.0", "zod": "^3.25 || ^4.0" }, "optionalPeers": ["ws", "zod"], "bin": { "openai": "bin/cli" } }, "sha512-xAYN1W3YsDXJWA5F277135YfkEk6H7D3D6vWwRhJ3OEkzRgcyK8z/P5P9Gyi/wB4N8kK9kM5ZjprfvyHagKmpw=="], + + "p-retry": ["p-retry@4.6.2", "https://mirrors.tencent.com/npm/p-retry/-/p-retry-4.6.2.tgz", { "dependencies": { "@types/retry": "0.12.0", "retry": "^0.13.1" } }, "sha512-312Id396EbJdvRONlngUx0NydfrIQ5lsYu0znKVUzVvArzEIt08V1qhtyESbGVd1FGX7UKtiFp5uwKZdM8wIuQ=="], + + "patch-console": ["patch-console@2.0.0", "https://mirrors.tencent.com/npm/patch-console/-/patch-console-2.0.0.tgz", {}, "sha512-0YNdUceMdaQwoKce1gatDScmMo5pu/tfABfnzEqeG0gtTmd7mh/WcwgUjtAeOU7N8nFFlbQBnFK2gXW5fGvmMA=="], + + "picomatch": ["picomatch@2.3.2", "https://mirrors.tencent.com/npm/picomatch/-/picomatch-2.3.2.tgz", {}, "sha512-V7+vQEJ06Z+c5tSye8S+nHUfI51xoXIXjHQ99cQtKUkQqqO1kO/KCJUfZXuB47h/YBlDhah2H3hdUGXn8ie0oA=="], + + "protobufjs": ["protobufjs@7.5.4", "https://mirrors.tencent.com/npm/protobufjs/-/protobufjs-7.5.4.tgz", { "dependencies": { "@protobufjs/aspromise": "^1.1.2", "@protobufjs/base64": "^1.1.2", "@protobufjs/codegen": "^2.0.4", "@protobufjs/eventemitter": "^1.1.0", "@protobufjs/fetch": "^1.1.0", "@protobufjs/float": "^1.0.2", "@protobufjs/inquire": "^1.1.0", "@protobufjs/path": "^1.1.2", "@protobufjs/pool": "^1.1.0", "@protobufjs/utf8": "^1.1.0", "@types/node": ">=13.7.0", "long": "^5.0.0" } }, "sha512-CvexbZtbov6jW2eXAvLukXjXUW1TzFaivC46BpWc/3BpcCysb5Vffu+B3XHMm8lVEuy2Mm4XGex8hBSg1yapPg=="], + + "queue-microtask": ["queue-microtask@1.2.3", "https://mirrors.tencent.com/npm/queue-microtask/-/queue-microtask-1.2.3.tgz", {}, "sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A=="], + + "react": ["react@19.2.4", "https://mirrors.tencent.com/npm/react/-/react-19.2.4.tgz", {}, "sha512-9nfp2hYpCwOjAN+8TZFGhtWEwgvWHXqESH8qT89AT/lWklpLON22Lc8pEtnpsZz7VmawabSU0gCjnj8aC0euHQ=="], + + "react-devtools-core": ["react-devtools-core@7.0.1", "https://mirrors.tencent.com/npm/react-devtools-core/-/react-devtools-core-7.0.1.tgz", { "dependencies": { "shell-quote": "^1.6.1", "ws": "^7" } }, "sha512-C3yNvRHaizlpiASzy7b9vbnBGLrhvdhl1CbdU6EnZgxPNbai60szdLtl+VL76UNOt5bOoVTOz5rNWZxgGt+Gsw=="], + + "react-reconciler": ["react-reconciler@0.33.0", "https://mirrors.tencent.com/npm/react-reconciler/-/react-reconciler-0.33.0.tgz", { "dependencies": { "scheduler": "^0.27.0" }, "peerDependencies": { "react": "^19.2.0" } }, "sha512-KetWRytFv1epdpJc3J4G75I4WrplZE5jOL7Yq0p34+OVOKF4Se7WrdIdVC45XsSSmUTlht2FM/fM1FZb1mfQeA=="], + + "restore-cursor": ["restore-cursor@4.0.0", "https://mirrors.tencent.com/npm/restore-cursor/-/restore-cursor-4.0.0.tgz", { "dependencies": { "onetime": "^5.1.0", "signal-exit": "^3.0.2" } }, "sha512-I9fPXU9geO9bHOt9pHHOhOkYerIMsmVaWB0rA2AI9ERh/+x/i7MV5HKBNrg+ljO5eoPVgCcnFuRjJ9uH6I/3eg=="], + + "retry": ["retry@0.13.1", "https://mirrors.tencent.com/npm/retry/-/retry-0.13.1.tgz", {}, "sha512-XQBQ3I8W1Cge0Seh+6gjj03LbmRFWuoszgK9ooCpwYIrhhoO80pfq4cUkU5DkknwfOfFteRwlZ56PYOGYyFWdg=="], + + "reusify": ["reusify@1.1.0", "https://mirrors.tencent.com/npm/reusify/-/reusify-1.1.0.tgz", {}, "sha512-g6QUff04oZpHs0eG5p83rFLhHeV00ug/Yf9nZM6fLeUrPguBTkTQOdpAWWspMh55TZfVQDPaN3NQJfbVRAxdIw=="], + + "run-parallel": ["run-parallel@1.2.0", "https://mirrors.tencent.com/npm/run-parallel/-/run-parallel-1.2.0.tgz", { "dependencies": { "queue-microtask": "^1.2.2" } }, "sha512-5l4VyZR86LZ/lDxZTR6jqL8AFE2S0IFLMP26AbjsLVADxHdhB/c0GUsH+y39UfCi3dzz8OlQuPmnaJOMoDHQBA=="], + + "safe-buffer": ["safe-buffer@5.2.1", "https://mirrors.tencent.com/npm/safe-buffer/-/safe-buffer-5.2.1.tgz", {}, "sha512-rp3So07KcdmmKbGvgaNxQSJr7bGVSVk5S9Eq1F+ppbRo70+YeaDxkw5Dd8NPN+GD6bjnYm2VuPuCXmpuYvmCXQ=="], + + "scheduler": ["scheduler@0.27.0", "https://mirrors.tencent.com/npm/scheduler/-/scheduler-0.27.0.tgz", {}, "sha512-eNv+WrVbKu1f3vbYJT/xtiF5syA5HPIMtf9IgY/nKg0sWqzAUEvqY/xm7OcZc/qafLx/iO9FgOmeSAp4v5ti/Q=="], + + "shell-quote": ["shell-quote@1.8.3", "https://mirrors.tencent.com/npm/shell-quote/-/shell-quote-1.8.3.tgz", {}, "sha512-ObmnIF4hXNg1BqhnHmgbDETF8dLPCggZWBjkQfhZpbszZnYur5DUljTcCHii5LC3J5E0yeO/1LIMyH+UvHQgyw=="], + + "signal-exit": ["signal-exit@3.0.7", "https://mirrors.tencent.com/npm/signal-exit/-/signal-exit-3.0.7.tgz", {}, "sha512-wnD2ZE+l+SPC/uoS0vXeE9L1+0wuaMqKlfz9AMUo38JsyLSBWSFcHR1Rri62LZc12vLr1gb3jl7iwQhgwpAbGQ=="], + + "slash": ["slash@5.1.0", "https://mirrors.tencent.com/npm/slash/-/slash-5.1.0.tgz", {}, "sha512-ZA6oR3T/pEyuqwMgAKT0/hAv8oAXckzbkmR0UkUosQ+Mc4RxGoJkRmwHgHufaenlyAgE1Mxgpdcrf75y6XcnDg=="], + + "slice-ansi": ["slice-ansi@8.0.0", "https://mirrors.tencent.com/npm/slice-ansi/-/slice-ansi-8.0.0.tgz", { "dependencies": { "ansi-styles": "^6.2.3", "is-fullwidth-code-point": "^5.1.0" } }, "sha512-stxByr12oeeOyY2BlviTNQlYV5xOj47GirPr4yA1hE9JCtxfQN0+tVbkxwCtYDQWhEKWFHsEK48ORg5jrouCAg=="], + + "smol-toml": ["smol-toml@1.6.1", "https://mirrors.tencent.com/npm/smol-toml/-/smol-toml-1.6.1.tgz", {}, "sha512-dWUG8F5sIIARXih1DTaQAX4SsiTXhInKf1buxdY9DIg4ZYPZK5nGM1VRIYmEbDbsHt7USo99xSLFu5Q1IqTmsg=="], + + "stack-utils": ["stack-utils@2.0.6", "https://mirrors.tencent.com/npm/stack-utils/-/stack-utils-2.0.6.tgz", { "dependencies": { "escape-string-regexp": "^2.0.0" } }, "sha512-XlkWvfIm6RmsWtNJx+uqtKLS8eqFbxUg0ZzLXqY0caEy9l7hruX8IpiDnjsLavoBgqCCR71TqWO8MaXYheJ3RQ=="], + + "string-width": ["string-width@8.2.0", "https://mirrors.tencent.com/npm/string-width/-/string-width-8.2.0.tgz", { "dependencies": { "get-east-asian-width": "^1.5.0", "strip-ansi": "^7.1.2" } }, "sha512-6hJPQ8N0V0P3SNmP6h2J99RLuzrWz2gvT7VnK5tKvrNqJoyS9W4/Fb8mo31UiPvy00z7DQXkP2hnKBVav76thw=="], + + "strip-ansi": ["strip-ansi@7.2.0", "https://mirrors.tencent.com/npm/strip-ansi/-/strip-ansi-7.2.0.tgz", { "dependencies": { "ansi-regex": "^6.2.2" } }, "sha512-yDPMNjp4WyfYBkHnjIRLfca1i6KMyGCtsVgoKe/z1+6vukgaENdgGBZt+ZmKPc4gavvEZ5OgHfHdrazhgNyG7w=="], + + "tagged-tag": ["tagged-tag@1.0.0", "https://mirrors.tencent.com/npm/tagged-tag/-/tagged-tag-1.0.0.tgz", {}, "sha512-yEFYrVhod+hdNyx7g5Bnkkb0G6si8HJurOoOEgC8B/O0uXLHlaey/65KRv6cuWBNhBgHKAROVpc7QyYqE5gFng=="], + + "terminal-size": ["terminal-size@4.0.1", "https://mirrors.tencent.com/npm/terminal-size/-/terminal-size-4.0.1.tgz", {}, "sha512-avMLDQpUI9I5XFrklECw1ZEUPJhqzcwSWsyyI8blhRLT+8N1jLJWLWWYQpB2q2xthq8xDvjZPISVh53T/+CLYQ=="], + + "to-regex-range": ["to-regex-range@5.0.1", "https://mirrors.tencent.com/npm/to-regex-range/-/to-regex-range-5.0.1.tgz", { "dependencies": { "is-number": "^7.0.0" } }, "sha512-65P7iz6X5yEr1cwcgvQxbbIw7Uk3gOy5dIdtZ4rDveLqhrdJP+Li/Hx6tyK0NEb+2GCyneCMJiGqrADCSNk8sQ=="], + + "ts-algebra": ["ts-algebra@2.0.0", "https://mirrors.tencent.com/npm/ts-algebra/-/ts-algebra-2.0.0.tgz", {}, "sha512-FPAhNPFMrkwz76P7cdjdmiShwMynZYN6SgOujD1urY4oNm80Ou9oMdmbR45LotcKOXoy7wSmHkRFE6Mxbrhefw=="], + + "type-fest": ["type-fest@5.5.0", "https://mirrors.tencent.com/npm/type-fest/-/type-fest-5.5.0.tgz", { "dependencies": { "tagged-tag": "^1.0.0" } }, "sha512-PlBfpQwiUvGViBNX84Yxwjsdhd1TUlXr6zjX7eoirtCPIr08NAmxwa+fcYBTeRQxHo9YC9wwF3m9i700sHma8g=="], + + "typescript": ["typescript@5.9.3", "https://mirrors.tencent.com/npm/typescript/-/typescript-5.9.3.tgz", { "bin": { "tsc": "bin/tsc", "tsserver": "bin/tsserver" } }, "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw=="], + + "undici-types": ["undici-types@7.18.2", "https://mirrors.tencent.com/npm/undici-types/-/undici-types-7.18.2.tgz", {}, "sha512-AsuCzffGHJybSaRrmr5eHr81mwJU3kjw6M+uprWvCXiNeN9SOGwQ3Jn8jb8m3Z6izVgknn1R0FTCEAP2QrLY/w=="], + + "unicorn-magic": ["unicorn-magic@0.4.0", "https://mirrors.tencent.com/npm/unicorn-magic/-/unicorn-magic-0.4.0.tgz", {}, "sha512-wH590V9VNgYH9g3lH9wWjTrUoKsjLF6sGLjhR4sH1LWpLmCOH0Zf7PukhDA8BiS7KHe4oPNkcTHqYkj7SOGUOw=="], + + "web-streams-polyfill": ["web-streams-polyfill@3.3.3", "https://mirrors.tencent.com/npm/web-streams-polyfill/-/web-streams-polyfill-3.3.3.tgz", {}, "sha512-d2JWLCivmZYTSIoge9MsgFCZrt571BikcWGYkjC1khllbTeDlGqZ2D8vD8E/lJa8WGWbb7Plm8/XJYV7IJHZZw=="], + + "widest-line": ["widest-line@6.0.0", "https://mirrors.tencent.com/npm/widest-line/-/widest-line-6.0.0.tgz", { "dependencies": { "string-width": "^8.1.0" } }, "sha512-U89AsyEeAsyoF0zVJBkG9zBgekjgjK7yk9sje3F4IQpXBJ10TF6ByLlIfjMhcmHMJgHZI4KHt4rdNfktzxIAMA=="], + + "wrap-ansi": ["wrap-ansi@9.0.2", "https://mirrors.tencent.com/npm/wrap-ansi/-/wrap-ansi-9.0.2.tgz", { "dependencies": { "ansi-styles": "^6.2.1", "string-width": "^7.0.0", "strip-ansi": "^7.1.0" } }, "sha512-42AtmgqjV+X1VpdOfyTGOYRi0/zsoLqtXQckTmqTeybT+BDIbM/Guxo7x3pE2vtpr1ok6xRqM9OpBe+Jyoqyww=="], + + "ws": ["ws@7.5.10", "https://mirrors.tencent.com/npm/ws/-/ws-7.5.10.tgz", { "peerDependencies": { "bufferutil": "^4.0.1", "utf-8-validate": "^5.0.2" }, "optionalPeers": ["bufferutil", "utf-8-validate"] }, "sha512-+dbF1tHwZpXcbOJdVOkzLDxZP1ailvSxM6ZweXTegylPny803bFhA+vqBYw4s31NSAk4S2Qz+AKXK9a4wkdjcQ=="], + + "yoga-layout": ["yoga-layout@3.2.1", "https://mirrors.tencent.com/npm/yoga-layout/-/yoga-layout-3.2.1.tgz", {}, "sha512-0LPOt3AxKqMdFBZA3HBAt/t/8vIKq7VaQYbuA8WxCgung+p9TVyKRYdpvCb80HcdTN2NkbIKbhNwKUfm3tQywQ=="], + + "zod": ["zod@4.3.6", "https://mirrors.tencent.com/npm/zod/-/zod-4.3.6.tgz", {}, "sha512-rftlrkhHZOcjDwkGlnUtZZkvaPHCsDATp4pGpuOOMDaTdDDXF91wuVDJoWoPsKX/3YPQ5fHuF3STjcYyKr+Qhg=="], + + "zod-to-json-schema": ["zod-to-json-schema@3.25.2", "https://mirrors.tencent.com/npm/zod-to-json-schema/-/zod-to-json-schema-3.25.2.tgz", { "peerDependencies": { "zod": "^3.25.28 || ^4" } }, "sha512-O/PgfnpT1xKSDeQYSCfRI5Gy3hPf91mKVDuYLUHZJMiDFptvP41MSnWofm8dnCm0256ZNfZIM7DSzuSMAFnjHA=="], + + "@google/genai/ws": ["ws@8.20.0", "https://mirrors.tencent.com/npm/ws/-/ws-8.20.0.tgz", { "peerDependencies": { "bufferutil": "^4.0.1", "utf-8-validate": ">=5.0.2" }, "optionalPeers": ["bufferutil", "utf-8-validate"] }, "sha512-sAt8BhgNbzCtgGbt2OxmpuryO63ZoDk/sqaB/znQm94T4fCEsy/yV+7CdC1kJhOU9lboAEU7R3kquuycDoibVA=="], + + "ink/ws": ["ws@8.20.0", "https://mirrors.tencent.com/npm/ws/-/ws-8.20.0.tgz", { "peerDependencies": { "bufferutil": "^4.0.1", "utf-8-validate": ">=5.0.2" }, "optionalPeers": ["bufferutil", "utf-8-validate"] }, "sha512-sAt8BhgNbzCtgGbt2OxmpuryO63ZoDk/sqaB/znQm94T4fCEsy/yV+7CdC1kJhOU9lboAEU7R3kquuycDoibVA=="], + + "ink-text-input/type-fest": ["type-fest@4.41.0", "https://mirrors.tencent.com/npm/type-fest/-/type-fest-4.41.0.tgz", {}, "sha512-TeTSQ6H5YHvpqVwBRcnLDCBnDOHWYu7IvGbHT6N8AOymcr9PJGjc1GTtiWZTYg0NCgYwvnYWEkVChQAr9bjfwA=="], + + "openai/ws": ["ws@8.20.0", "https://mirrors.tencent.com/npm/ws/-/ws-8.20.0.tgz", { "peerDependencies": { "bufferutil": "^4.0.1", "utf-8-validate": ">=5.0.2" }, "optionalPeers": ["bufferutil", "utf-8-validate"] }, "sha512-sAt8BhgNbzCtgGbt2OxmpuryO63ZoDk/sqaB/znQm94T4fCEsy/yV+7CdC1kJhOU9lboAEU7R3kquuycDoibVA=="], + + "wrap-ansi/string-width": ["string-width@7.2.0", "https://mirrors.tencent.com/npm/string-width/-/string-width-7.2.0.tgz", { "dependencies": { "emoji-regex": "^10.3.0", "get-east-asian-width": "^1.0.0", "strip-ansi": "^7.1.0" } }, "sha512-tsaTIkKW9b4N+AEj+SVA+WhJzV7/zMhcSu78mLKWSk7cXMOSHsBKFWUs0fWwq8QyK3MgJBQRX6Gbi4kYbdvGkQ=="], + } +} diff --git a/index.ts b/index.ts new file mode 100644 index 000000000..f67b2c645 --- /dev/null +++ b/index.ts @@ -0,0 +1 @@ +console.log("Hello via Bun!"); \ No newline at end of file diff --git a/package.json b/package.json new file mode 100644 index 000000000..701112368 --- /dev/null +++ b/package.json @@ -0,0 +1,52 @@ +{ + "name": "kimi-cli", + "version": "2.0.0", + "module": "src/kimi_cli/index.ts", + "type": "module", + "private": true, + "bin": { + "kimi": "src/kimi_cli/index.ts" + }, + "scripts": { + "start": "bun run src/kimi_cli/index.ts", + "dev": "bun --watch run src/kimi_cli/index.ts", + "test": "bun test", + "build": "bun build src/kimi_cli/index.ts --compile --outfile dist/kimi", + "build:linux-x64": "bun build src/kimi_cli/index.ts --compile --outfile dist/kimi --target=bun-linux-x64", + "build:linux-arm64": "bun build src/kimi_cli/index.ts --compile --outfile dist/kimi --target=bun-linux-arm64", + "build:darwin-x64": "bun build src/kimi_cli/index.ts --compile --outfile dist/kimi --target=bun-darwin-x64", + "build:darwin-arm64": "bun build src/kimi_cli/index.ts --compile --outfile dist/kimi --target=bun-darwin-arm64", + "build:windows-x64": "bun build src/kimi_cli/index.ts --compile --outfile dist/kimi.exe --target=bun-windows-x64", + "lint": "biome check src/", + "format": "biome format --write src/", + "typecheck": "tsc --noEmit" + }, + "devDependencies": { + "@biomejs/biome": "^2.4.10", + "@types/bun": "latest", + "@types/micromatch": "^4.0.10", + "@types/react": "^19.2.14", + "react-devtools-core": "^7.0.1" + }, + "peerDependencies": { + "typescript": "^5" + }, + "dependencies": { + "@anthropic-ai/sdk": "^0.81.0", + "@google/genai": "^1.48.0", + "@iarna/toml": "^2.2.5", + "chalk": "^5.6.2", + "commander": "^14.0.3", + "globby": "^16.2.0", + "ink": "^6.8.0", + "ink-spinner": "^5.0.0", + "ink-text-input": "^6.0.0", + "micromatch": "^4.0.8", + "nanoid": "^5.1.7", + "openai": "^6.33.0", + "react": "^19.2.4", + "smol-toml": "^1.6.1", + "zod": "^4.3.6", + "zod-to-json-schema": "^3.25.2" + } +} diff --git a/src/kimi_cli/CHANGELOG.md b/src/kimi_cli/CHANGELOG.md deleted file mode 120000 index 699cc9e7b..000000000 --- a/src/kimi_cli/CHANGELOG.md +++ /dev/null @@ -1 +0,0 @@ -../../CHANGELOG.md \ No newline at end of file diff --git a/src/kimi_cli/__init__.py b/src/kimi_cli/__init__.py deleted file mode 100644 index 7c1453fc7..000000000 --- a/src/kimi_cli/__init__.py +++ /dev/null @@ -1,29 +0,0 @@ -from __future__ import annotations - -from typing import Any, cast - - -class _LazyLogger: - """Import loguru only when logging is actually used.""" - - def __init__(self) -> None: - self._logger: Any | None = None - - def _get(self) -> Any: - if self._logger is None: - from loguru import logger as real_logger - - # Disable logging by default for library usage. - # Application entry points (e.g., kimi_cli.cli) should call logger.enable("kimi_cli") - # to enable logging. - real_logger.disable("kimi_cli") - self._logger = real_logger - return self._logger - - def __getattr__(self, name: str) -> Any: - return getattr(self._get(), name) - - -logger = cast(Any, _LazyLogger()) - -__all__ = ["logger"] diff --git a/src/kimi_cli/__main__.py b/src/kimi_cli/__main__.py deleted file mode 100644 index 9f5cd4f38..000000000 --- a/src/kimi_cli/__main__.py +++ /dev/null @@ -1,34 +0,0 @@ -from __future__ import annotations - -import sys -from collections.abc import Sequence -from pathlib import Path - - -def _prog_name() -> str: - return Path(sys.argv[0]).name or "kimi" - - -def main(argv: Sequence[str] | None = None) -> int | str | None: - from kimi_cli.utils.proxy import normalize_proxy_env - - normalize_proxy_env() - - args = list(sys.argv[1:] if argv is None else argv) - - if len(args) == 1 and args[0] in {"--version", "-V"}: - from kimi_cli.constant import get_version - - print(f"kimi, version {get_version()}") - return 0 - - from kimi_cli.cli import cli - - try: - return cli(args=args, prog_name=_prog_name()) - except SystemExit as exc: - return exc.code - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/src/kimi_cli/acp/AGENTS.md b/src/kimi_cli/acp/AGENTS.md deleted file mode 100644 index 24efeacaa..000000000 --- a/src/kimi_cli/acp/AGENTS.md +++ /dev/null @@ -1,92 +0,0 @@ -# ACP Integration Notes (kimi-cli) - -## Protocol summary (ACP overview) -- ACP is JSON-RPC 2.0 with request/response methods plus one-way notifications. -- Typical flow: `initialize` -> optional `authenticate` -> `session/new` or `session/load` - -> `session/prompt` - with `session/update` notifications and optional `session/cancel`. -- Clients provide `session/request_permission` and optional terminal/filesystem methods. -- All ACP file paths must be absolute; line numbers are 1-based. - -## Entry points and server modes -- **Single-session server**: `KimiCLI.run_acp()` uses `ACP` -> `ACPServerSingleSession`. - - Code: `src/kimi_cli/app.py`, `src/kimi_cli/ui/acp/__init__.py`. - - Used when running CLI with `--acp` UI mode. -- **Multi-session server**: `acp_main()` runs `ACPServer` with `use_unstable_protocol=True`. - - Code: `src/kimi_cli/acp/__init__.py`, `src/kimi_cli/acp/server.py`. - - Exposed via the `kimi acp` command in `src/kimi_cli/cli/__init__.py`. - -## Capabilities advertised -- `prompt_capabilities`: `embedded_context=False`, `image=True`, `audio=False`. -- `mcp_capabilities`: `http=True`, `sse=False`. -- Single-session: `load_session=False`, no session list capabilities. -- Multi-session: `load_session=True`, `session_capabilities.list` supported. -- `auth_methods=[]` (no authentication methods advertised). - -## Session lifecycle (implemented behavior) -- `session/new` - - Multi-session: creates a persisted `Session`, builds `KimiCLI`, stores `ACPSession`. - - Single-session: wraps the existing `Soul` into a `Wire` loop and creates `ACPSession`. - - Both send `AvailableCommandsUpdate` for slash commands on session creation. - - MCP servers passed by ACP are converted via `acp_mcp_servers_to_mcp_config`. -- `session/load` - - Multi-session only: loads by `Session.find`, then builds `KimiCLI` and `ACPSession`. - - No history replay yet (TODO). - - Single-session: not implemented. -- `session/list` - - Multi-session only: lists sessions via `Session.list`, no pagination. - - Single-session: not implemented. -- `session/prompt` - - Uses `ACPSession.prompt()` to stream updates and produce a `stop_reason`. - - Stop reasons: `end_turn`, `max_turn_requests`, `cancelled`. -- `session/cancel` - - Sets the per-turn cancel event to stop the prompt. - -## Streaming updates and content mapping -- Text chunks -> `AgentMessageChunk`. -- Think chunks -> `AgentThoughtChunk`. -- Tool calls: - - Start -> `ToolCallStart` with JSON args as text content. - - Streaming args -> `ToolCallProgress` with updated title/args. - - Results -> `ToolCallProgress` with `completed` or `failed`. - - Tool call IDs are prefixed with turn ID to avoid collisions across turns. -- Plan updates: - - `TodoDisplayBlock` is converted into `AgentPlanUpdate`. -- Available commands: - - `AvailableCommandsUpdate` is sent right after session creation. - -## Prompt/content conversion -- Incoming prompt blocks: - - Supported: `TextContentBlock`, `ImageContentBlock` (converted to data URL). - - Unsupported types are logged and ignored. -- Tool result display blocks: - - `DiffDisplayBlock` -> `FileEditToolCallContent`. - - `HideOutputDisplayBlock` suppresses tool output in ACP (used by terminal tool). - -## Tool integration and permission flow -- ACP sessions use `ACPKaos` to route filesystem reads/writes through ACP clients. -- If the client advertises `terminal` capability, the `Shell` tool is replaced by an - ACP-backed `Terminal` tool. - - Uses ACP `terminal/create`, waits for exit, streams `TerminalToolCallContent`, - then releases the terminal handle. -- Approval requests in the core tool system are bridged to ACP - `session/request_permission` with allow-once/allow-always/reject options. - -## Current gaps / not implemented -- `authenticate` method (not used by current Zed ACP client). -- `session/set_mode` and `session/set_model` (no multi-mode/model switching in kimi-cli). -- `ext_method` / `ext_notification` for custom ACP extensions are stubbed. -- Single-session server does not implement `session/load` or `session/list`. - -## Filesystem (ACP client-backed) -- When the client advertises `fs.readTextFile` / `fs.writeTextFile`, `ACPKaos` routes - reads and writes through ACP `fs/*` methods. -- `ReadFile` uses `KaosPath.read_lines`, which `ACPKaos` implements via ACP reads. -- `ReadMediaFile` uses `KaosPath.read_bytes` to load image/video payloads through ACP reads. -- `WriteFile` uses `KaosPath.read_text/write_text/append_text` and still generates diffs - and approvals in the tool layer. - -## Zed-specific notes (as of current integration) -- Zed does not currently call `authenticate`. -- Zed’s external agent server session management is not yet available, so - `session/load` is not exercised in practice. diff --git a/src/kimi_cli/acp/__init__.py b/src/kimi_cli/acp/__init__.py deleted file mode 100644 index daba368d7..000000000 --- a/src/kimi_cli/acp/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -def acp_main() -> None: - """Entry point for the multi-session ACP server.""" - import asyncio - - import acp - - from kimi_cli.acp.server import ACPServer - from kimi_cli.app import enable_logging - from kimi_cli.utils.logging import logger - - enable_logging() - logger.info("Starting ACP server on stdio") - asyncio.run(acp.run_agent(ACPServer(), use_unstable_protocol=True)) diff --git a/src/kimi_cli/acp/convert.py b/src/kimi_cli/acp/convert.py deleted file mode 100644 index 84f8c62ad..000000000 --- a/src/kimi_cli/acp/convert.py +++ /dev/null @@ -1,128 +0,0 @@ -from __future__ import annotations - -import acp - -from kimi_cli.acp.types import ACPContentBlock -from kimi_cli.utils.logging import logger -from kimi_cli.wire.types import ( - ContentPart, - DiffDisplayBlock, - DisplayBlock, - ImageURLPart, - TextPart, - ToolReturnValue, -) - - -def acp_blocks_to_content_parts(prompt: list[ACPContentBlock]) -> list[ContentPart]: - content: list[ContentPart] = [] - for block in prompt: - match block: - case acp.schema.TextContentBlock(): - content.append(TextPart(text=block.text)) - case acp.schema.ImageContentBlock(): - content.append( - ImageURLPart( - image_url=ImageURLPart.ImageURL( - url=f"data:{block.mime_type};base64,{block.data}" - ) - ) - ) - case acp.schema.EmbeddedResourceContentBlock(): - resource = block.resource - if isinstance(resource, acp.schema.TextResourceContents): - uri = resource.uri - text = resource.text - content.append(TextPart(text=f"\n{text}\n")) - else: - logger.warning( - "Unsupported embedded resource type: {type}", - type=type(resource).__name__, - ) - case acp.schema.ResourceContentBlock(): - # ResourceContentBlock is a link reference without inline content; - # include the URI so the model is at least aware of the reference. - content.append( - TextPart(text=f"") - ) - case _: - logger.warning("Unsupported prompt content block: {block}", block=block) - return content - - -def display_block_to_acp_content( - block: DisplayBlock, -) -> acp.schema.FileEditToolCallContent | None: - if isinstance(block, DiffDisplayBlock): - return acp.schema.FileEditToolCallContent( - type="diff", - path=block.path, - old_text=block.old_text, - new_text=block.new_text, - ) - - return None - - -def tool_result_to_acp_content( - tool_ret: ToolReturnValue, -) -> list[ - acp.schema.ContentToolCallContent - | acp.schema.FileEditToolCallContent - | acp.schema.TerminalToolCallContent -]: - from kimi_cli.acp.tools import HideOutputDisplayBlock - - def _to_acp_content( - part: ContentPart, - ) -> ( - acp.schema.ContentToolCallContent - | acp.schema.FileEditToolCallContent - | acp.schema.TerminalToolCallContent - ): - if isinstance(part, TextPart): - return acp.schema.ContentToolCallContent( - type="content", content=acp.schema.TextContentBlock(type="text", text=part.text) - ) - logger.warning("Unsupported content part in tool result: {part}", part=part) - return acp.schema.ContentToolCallContent( - type="content", - content=acp.schema.TextContentBlock(type="text", text=f"[{part.__class__.__name__}]"), - ) - - def _to_text_block(text: str) -> acp.schema.ContentToolCallContent: - return acp.schema.ContentToolCallContent( - type="content", content=acp.schema.TextContentBlock(type="text", text=text) - ) - - contents: list[ - acp.schema.ContentToolCallContent - | acp.schema.FileEditToolCallContent - | acp.schema.TerminalToolCallContent - ] = [] - - for block in tool_ret.display: - if isinstance(block, HideOutputDisplayBlock): - # return early to indicate no output should be shown - return [] - - content = display_block_to_acp_content(block) - if content is not None: - contents.append(content) - # TODO: better concatenation of `display` blocks and `output`? - - output = tool_ret.output - if isinstance(output, str): - if output: - contents.append(_to_text_block(output)) - else: - # NOTE: At the moment, ToolReturnValue.output is either a string or a - # list of ContentPart. We avoid an unnecessary isinstance() check here - # to keep pyright happy while still handling list outputs. - contents.extend(_to_acp_content(part) for part in output) - - if not contents and tool_ret.message: - # Fallback to the `message` for LLM if there's no other content - contents.append(_to_text_block(tool_ret.message)) - - return contents diff --git a/src/kimi_cli/acp/kaos.py b/src/kimi_cli/acp/kaos.py deleted file mode 100644 index 50319031f..000000000 --- a/src/kimi_cli/acp/kaos.py +++ /dev/null @@ -1,291 +0,0 @@ -from __future__ import annotations - -import asyncio -from collections.abc import AsyncGenerator, Iterable, Mapping -from contextlib import suppress -from typing import Literal - -import acp -from kaos import AsyncReadable, AsyncWritable, Kaos, KaosProcess, StatResult, StrOrKaosPath -from kaos.local import local_kaos -from kaos.path import KaosPath - -_DEFAULT_TERMINAL_OUTPUT_LIMIT = 50_000 -_DEFAULT_POLL_INTERVAL = 0.2 -_TRUNCATION_NOTICE = "[acp output truncated]\n" - - -class _NullWritable: - def can_write_eof(self) -> bool: - return False - - def close(self) -> None: - return None - - async def drain(self) -> None: - return None - - def is_closing(self) -> bool: - return False - - async def wait_closed(self) -> None: - return None - - def write(self, data: bytes) -> None: - return None - - def writelines(self, data: Iterable[bytes], /) -> None: - return None - - def write_eof(self) -> None: - return None - - -class ACPProcess: - """KAOS process adapter for ACP terminal execution.""" - - def __init__( - self, - client: acp.Client, - session_id: str, - terminal_id: str, - *, - poll_interval: float = _DEFAULT_POLL_INTERVAL, - ) -> None: - self._client = client - self._session_id = session_id - self._terminal_id = terminal_id - self._poll_interval = poll_interval - self._stdin = _NullWritable() - self._stdout = asyncio.StreamReader() - self._stderr = asyncio.StreamReader() - self.stdin: AsyncWritable = self._stdin - self.stdout: AsyncReadable = self._stdout - # ACP does not expose stderr separately; keep stderr empty. - self.stderr: AsyncReadable = self._stderr - self._returncode: int | None = None - self._last_output = "" - self._truncation_noted = False - self._exit_future: asyncio.Future[int] = asyncio.get_running_loop().create_future() - self._poll_task = asyncio.create_task(self._poll_output()) - - @property - def pid(self) -> int: - return -1 - - @property - def returncode(self) -> int | None: - return self._returncode - - async def wait(self) -> int: - return await self._exit_future - - async def kill(self) -> None: - await self._client.kill_terminal( - session_id=self._session_id, - terminal_id=self._terminal_id, - ) - - def _feed_output(self, output_response: acp.schema.TerminalOutputResponse) -> None: - output = output_response.output - reset = output_response.truncated or ( - self._last_output and not output.startswith(self._last_output) - ) - if reset and self._last_output and not self._truncation_noted: - self._stdout.feed_data(_TRUNCATION_NOTICE.encode("utf-8")) - self._truncation_noted = True - - delta = output if reset else output[len(self._last_output) :] - if delta: - self._stdout.feed_data(delta.encode("utf-8", "replace")) - self._last_output = output - - @staticmethod - def _normalize_exit_code(exit_code: int | None) -> int: - return 1 if exit_code is None else exit_code - - async def _poll_output(self) -> None: - exit_task = asyncio.create_task( - self._client.wait_for_terminal_exit( - session_id=self._session_id, - terminal_id=self._terminal_id, - ) - ) - exit_code: int | None = None - try: - while True: - if exit_task.done(): - exit_response = exit_task.result() - exit_code = exit_response.exit_code - break - - output_response = await self._client.terminal_output( - session_id=self._session_id, - terminal_id=self._terminal_id, - ) - self._feed_output(output_response) - if output_response.exit_status: - exit_code = output_response.exit_status.exit_code - try: - exit_response = await exit_task - exit_code = exit_response.exit_code or exit_code - except Exception: - pass - break - - await asyncio.sleep(self._poll_interval) - - final_output = await self._client.terminal_output( - session_id=self._session_id, - terminal_id=self._terminal_id, - ) - self._feed_output(final_output) - except Exception as exc: - error_note = f"[acp terminal error] {exc}\n" - self._stdout.feed_data(error_note.encode("utf-8", "replace")) - if exit_code is None: - exit_code = 1 - finally: - if not exit_task.done(): - exit_task.cancel() - with suppress(Exception): - await exit_task - self._returncode = self._normalize_exit_code(exit_code) - self._stdout.feed_eof() - self._stderr.feed_eof() - if not self._exit_future.done(): - self._exit_future.set_result(self._returncode) - with suppress(Exception): - await self._client.release_terminal( - session_id=self._session_id, - terminal_id=self._terminal_id, - ) - - -class ACPKaos: - """KAOS backend that routes supported operations through ACP.""" - - name: str = "acp" - - def __init__( - self, - client: acp.Client, - session_id: str, - client_capabilities: acp.schema.ClientCapabilities | None, - fallback: Kaos | None = None, - *, - output_byte_limit: int | None = _DEFAULT_TERMINAL_OUTPUT_LIMIT, - poll_interval: float = _DEFAULT_POLL_INTERVAL, - ) -> None: - self._client = client - self._session_id = session_id - self._fallback = fallback or local_kaos - fs = client_capabilities.fs if client_capabilities else None - self._supports_read = bool(fs and fs.read_text_file) - self._supports_write = bool(fs and fs.write_text_file) - self._supports_terminal = bool(client_capabilities and client_capabilities.terminal) - self._output_byte_limit = output_byte_limit - self._poll_interval = poll_interval - - def pathclass(self): - return self._fallback.pathclass() - - def normpath(self, path: StrOrKaosPath) -> KaosPath: - return self._fallback.normpath(path) - - def gethome(self) -> KaosPath: - return self._fallback.gethome() - - def getcwd(self) -> KaosPath: - return self._fallback.getcwd() - - async def chdir(self, path: StrOrKaosPath) -> None: - await self._fallback.chdir(path) - - async def stat(self, path: StrOrKaosPath, *, follow_symlinks: bool = True) -> StatResult: - return await self._fallback.stat(path, follow_symlinks=follow_symlinks) - - def iterdir(self, path: StrOrKaosPath) -> AsyncGenerator[KaosPath]: - return self._fallback.iterdir(path) - - def glob( - self, path: StrOrKaosPath, pattern: str, *, case_sensitive: bool = True - ) -> AsyncGenerator[KaosPath]: - return self._fallback.glob(path, pattern, case_sensitive=case_sensitive) - - async def readbytes(self, path: StrOrKaosPath, n: int | None = None) -> bytes: - return await self._fallback.readbytes(path, n=n) - - async def readtext( - self, - path: StrOrKaosPath, - *, - encoding: str = "utf-8", - errors: Literal["strict", "ignore", "replace"] = "strict", - ) -> str: - abs_path = self._abs_path(path) - if not self._supports_read: - return await self._fallback.readtext(abs_path, encoding=encoding, errors=errors) - response = await self._client.read_text_file(path=abs_path, session_id=self._session_id) - return response.content - - async def readlines( - self, - path: StrOrKaosPath, - *, - encoding: str = "utf-8", - errors: Literal["strict", "ignore", "replace"] = "strict", - ) -> AsyncGenerator[str]: - text = await self.readtext(path, encoding=encoding, errors=errors) - for line in text.splitlines(keepends=True): - yield line - - async def writebytes(self, path: StrOrKaosPath, data: bytes) -> int: - return await self._fallback.writebytes(path, data) - - async def writetext( - self, - path: StrOrKaosPath, - data: str, - *, - mode: Literal["w", "a"] = "w", - encoding: str = "utf-8", - errors: Literal["strict", "ignore", "replace"] = "strict", - ) -> int: - abs_path = self._abs_path(path) - if mode == "a": - if self._supports_read and self._supports_write: - existing = await self.readtext(abs_path, encoding=encoding, errors=errors) - await self._client.write_text_file( - path=abs_path, - content=existing + data, - session_id=self._session_id, - ) - return len(data) - return await self._fallback.writetext( - abs_path, data, mode="a", encoding=encoding, errors=errors - ) - - if not self._supports_write: - return await self._fallback.writetext( - abs_path, data, mode=mode, encoding=encoding, errors=errors - ) - - await self._client.write_text_file( - path=abs_path, - content=data, - session_id=self._session_id, - ) - return len(data) - - async def mkdir( - self, path: StrOrKaosPath, parents: bool = False, exist_ok: bool = False - ) -> None: - await self._fallback.mkdir(path, parents=parents, exist_ok=exist_ok) - - async def exec(self, *args: str, env: Mapping[str, str] | None = None) -> KaosProcess: - return await self._fallback.exec(*args, env=env) - - def _abs_path(self, path: StrOrKaosPath) -> str: - kaos_path = path if isinstance(path, KaosPath) else KaosPath(path) - return str(kaos_path.canonical()) diff --git a/src/kimi_cli/acp/mcp.py b/src/kimi_cli/acp/mcp.py deleted file mode 100644 index b3acc6df0..000000000 --- a/src/kimi_cli/acp/mcp.py +++ /dev/null @@ -1,46 +0,0 @@ -from __future__ import annotations - -from typing import Any - -import acp.schema -from fastmcp.mcp_config import MCPConfig -from pydantic import ValidationError - -from kimi_cli.acp.types import MCPServer -from kimi_cli.exception import MCPConfigError - - -def acp_mcp_servers_to_mcp_config(mcp_servers: list[MCPServer]) -> MCPConfig: - if not mcp_servers: - return MCPConfig() - - try: - return MCPConfig.model_validate( - {"mcpServers": {server.name: _convert_acp_mcp_server(server) for server in mcp_servers}} - ) - except ValidationError as exc: - raise MCPConfigError(f"Invalid MCP config from ACP client: {exc}") from exc - - -def _convert_acp_mcp_server(server: MCPServer) -> dict[str, Any]: - """Convert an ACP MCP server to a dictionary representation.""" - match server: - case acp.schema.HttpMcpServer(): - return { - "url": server.url, - "transport": "http", - "headers": {header.name: header.value for header in server.headers}, - } - case acp.schema.SseMcpServer(): - return { - "url": server.url, - "transport": "sse", - "headers": {header.name: header.value for header in server.headers}, - } - case acp.schema.McpServerStdio(): - return { - "command": server.command, - "args": server.args, - "env": {item.name: item.value for item in server.env}, - "transport": "stdio", - } diff --git a/src/kimi_cli/acp/server.py b/src/kimi_cli/acp/server.py deleted file mode 100644 index 31852d036..000000000 --- a/src/kimi_cli/acp/server.py +++ /dev/null @@ -1,457 +0,0 @@ -from __future__ import annotations - -import asyncio -import sys -import time -from datetime import datetime -from pathlib import Path -from typing import Any, NamedTuple - -import acp -from kaos.path import KaosPath - -from kimi_cli.acp.kaos import ACPKaos -from kimi_cli.acp.mcp import acp_mcp_servers_to_mcp_config -from kimi_cli.acp.session import ACPSession -from kimi_cli.acp.tools import replace_tools -from kimi_cli.acp.types import ACPContentBlock, MCPServer -from kimi_cli.acp.version import ACPVersionSpec, negotiate_version -from kimi_cli.app import KimiCLI -from kimi_cli.auth.oauth import KIMI_CODE_OAUTH_KEY, load_tokens -from kimi_cli.config import LLMModel, OAuthRef, load_config, save_config -from kimi_cli.constant import NAME, VERSION -from kimi_cli.llm import create_llm, derive_model_capabilities -from kimi_cli.session import Session -from kimi_cli.soul.slash import registry as soul_slash_registry -from kimi_cli.soul.toolset import KimiToolset -from kimi_cli.utils.logging import logger - - -class ACPServer: - def __init__(self) -> None: - self.client_capabilities: acp.schema.ClientCapabilities | None = None - self.conn: acp.Client | None = None - self.sessions: dict[str, tuple[ACPSession, _ModelIDConv]] = {} - self.negotiated_version: ACPVersionSpec | None = None - self._auth_methods: list[acp.schema.AuthMethod] = [] - - def on_connect(self, conn: acp.Client) -> None: - logger.info("ACP client connected") - self.conn = conn - - async def initialize( - self, - protocol_version: int, - client_capabilities: acp.schema.ClientCapabilities | None = None, - client_info: acp.schema.Implementation | None = None, - **kwargs: Any, - ) -> acp.InitializeResponse: - self.negotiated_version = negotiate_version(protocol_version) - logger.info( - "ACP server initialized with client protocol version: {version}, " - "negotiated version: {negotiated}, " - "client capabilities: {capabilities}, client info: {info}", - version=protocol_version, - negotiated=self.negotiated_version, - capabilities=client_capabilities, - info=client_info, - ) - self.client_capabilities = client_capabilities - - # get command and args of current process for terminal-auth - command = sys.argv[0] - args: list[str] = [] - - # Build terminal auth data for error response - terminal_args = args + ["login"] - - # Build and cache auth methods for reuse in AUTH_REQUIRED errors - self._auth_methods = [ - acp.schema.AuthMethod( - id="login", - name="Login with Kimi account", - description=( - "Run `kimi login` command in the terminal, " - "then follow the instructions to finish login." - ), - # Store auth data in field_meta for building AUTH_REQUIRED error - field_meta={ - "terminal-auth": { - "command": command, - "args": terminal_args, - "label": "Kimi Code Login", - "env": {}, - "type": "terminal", - } - }, - ), - ] - - return acp.InitializeResponse( - protocol_version=self.negotiated_version.protocol_version, - agent_capabilities=acp.schema.AgentCapabilities( - load_session=True, - prompt_capabilities=acp.schema.PromptCapabilities( - embedded_context=True, image=True, audio=False - ), - mcp_capabilities=acp.schema.McpCapabilities(http=True, sse=False), - session_capabilities=acp.schema.SessionCapabilities( - list=acp.schema.SessionListCapabilities(), - resume=acp.schema.SessionResumeCapabilities(), - ), - ), - auth_methods=self._auth_methods, - agent_info=acp.schema.Implementation(name=NAME, version=VERSION), - ) - - @staticmethod - def _check_token_usable() -> str | None: - """Return ``None`` if the persisted OAuth token is usable, else a reason string.""" - ref = OAuthRef(storage="file", key=KIMI_CODE_OAUTH_KEY) - token = load_tokens(ref) - - if token is None or not token.access_token: - return "no valid token found" - if token.expires_at and token.expires_at < time.time() and not token.refresh_token: - # Token expired and no refresh token — background refresh cannot help. - return "token expired and no refresh token available" - return None - - def _check_auth(self) -> None: - """Check if Kimi Code authentication is complete. Raise AUTH_REQUIRED if not.""" - reason = self._check_token_usable() - if reason: - auth_methods_data: list[dict[str, Any]] = [] - for m in self._auth_methods: - if m.field_meta and "terminal-auth" in m.field_meta: - terminal_auth = m.field_meta["terminal-auth"] - auth_methods_data.append( - { - "id": m.id, - "name": m.name, - "description": m.description, - "type": terminal_auth.get("type", "terminal"), - "args": terminal_auth.get("args", []), - "env": terminal_auth.get("env", {}), - } - ) - - logger.warning("Authentication required, {reason}", reason=reason) - raise acp.RequestError.auth_required({"authMethods": auth_methods_data}) - - async def new_session( - self, cwd: str, mcp_servers: list[MCPServer] | None = None, **kwargs: Any - ) -> acp.NewSessionResponse: - logger.info("Creating new session for working directory: {cwd}", cwd=cwd) - assert self.conn is not None, "ACP client not connected" - assert self.client_capabilities is not None, "ACP connection not initialized" - - # Check authentication before creating session - self._check_auth() - - session = await Session.create(KaosPath.unsafe_from_local_path(Path(cwd))) - - mcp_config = acp_mcp_servers_to_mcp_config(mcp_servers or []) - cli_instance = await KimiCLI.create( - session, - mcp_configs=[mcp_config], - ) - config = cli_instance.soul.runtime.config - acp_kaos = ACPKaos(self.conn, session.id, self.client_capabilities) - acp_session = ACPSession(session.id, cli_instance, self.conn, kaos=acp_kaos) - model_id_conv = _ModelIDConv(config.default_model, config.default_thinking) - self.sessions[session.id] = (acp_session, model_id_conv) - - if isinstance(cli_instance.soul.agent.toolset, KimiToolset): - replace_tools( - self.client_capabilities, - self.conn, - session.id, - cli_instance.soul.agent.toolset, - cli_instance.soul.runtime, - ) - - available_commands = [ - acp.schema.AvailableCommand(name=cmd.name, description=cmd.description) - for cmd in soul_slash_registry.list_commands() - ] - asyncio.create_task( - self.conn.session_update( - session_id=session.id, - update=acp.schema.AvailableCommandsUpdate( - session_update="available_commands_update", - available_commands=available_commands, - ), - ) - ) - return acp.NewSessionResponse( - session_id=session.id, - modes=acp.schema.SessionModeState( - available_modes=[ - acp.schema.SessionMode( - id="default", - name="Default", - description="The default mode.", - ), - ], - current_mode_id="default", - ), - models=acp.schema.SessionModelState( - available_models=_expand_llm_models(config.models), - current_model_id=model_id_conv.to_acp_model_id(), - ), - ) - - async def _setup_session( - self, - cwd: str, - session_id: str, - mcp_servers: list[MCPServer] | None = None, - ) -> tuple[ACPSession, _ModelIDConv]: - """Load or resume a session. Shared by load_session and resume_session.""" - assert self.conn is not None, "ACP client not connected" - assert self.client_capabilities is not None, "ACP connection not initialized" - - work_dir = KaosPath.unsafe_from_local_path(Path(cwd)) - session = await Session.find(work_dir, session_id) - if session is None: - logger.error( - "Session not found: {id} for working directory: {cwd}", id=session_id, cwd=cwd - ) - raise acp.RequestError.invalid_params({"session_id": "Session not found"}) - - mcp_config = acp_mcp_servers_to_mcp_config(mcp_servers or []) - cli_instance = await KimiCLI.create( - session, - mcp_configs=[mcp_config], - ) - config = cli_instance.soul.runtime.config - acp_kaos = ACPKaos(self.conn, session.id, self.client_capabilities) - acp_session = ACPSession(session.id, cli_instance, self.conn, kaos=acp_kaos) - model_id_conv = _ModelIDConv(config.default_model, config.default_thinking) - self.sessions[session.id] = (acp_session, model_id_conv) - - if isinstance(cli_instance.soul.agent.toolset, KimiToolset): - replace_tools( - self.client_capabilities, - self.conn, - session.id, - cli_instance.soul.agent.toolset, - cli_instance.soul.runtime, - ) - - return acp_session, model_id_conv - - async def load_session( - self, cwd: str, session_id: str, mcp_servers: list[MCPServer] | None = None, **kwargs: Any - ) -> None: - logger.info("Loading session: {id} for working directory: {cwd}", id=session_id, cwd=cwd) - - if session_id in self.sessions: - logger.warning("Session already loaded: {id}", id=session_id) - return - - # Check authentication before loading session - self._check_auth() - - await self._setup_session(cwd, session_id, mcp_servers) - # TODO: replay session history? - - async def resume_session( - self, cwd: str, session_id: str, mcp_servers: list[MCPServer] | None = None, **kwargs: Any - ) -> acp.schema.ResumeSessionResponse: - logger.info("Resuming session: {id} for working directory: {cwd}", id=session_id, cwd=cwd) - - if session_id not in self.sessions: - await self._setup_session(cwd, session_id, mcp_servers) - - acp_session, model_id_conv = self.sessions[session_id] - config = acp_session.cli.soul.runtime.config - return acp.schema.ResumeSessionResponse( - modes=acp.schema.SessionModeState( - available_modes=[ - acp.schema.SessionMode( - id="default", - name="Default", - description="The default mode.", - ), - ], - current_mode_id="default", - ), - models=acp.schema.SessionModelState( - available_models=_expand_llm_models(config.models), - current_model_id=model_id_conv.to_acp_model_id(), - ), - ) - - async def fork_session( - self, cwd: str, session_id: str, mcp_servers: list[MCPServer] | None = None, **kwargs: Any - ) -> acp.schema.ForkSessionResponse: - raise NotImplementedError - - async def list_sessions( - self, cursor: str | None = None, cwd: str | None = None, **kwargs: Any - ) -> acp.schema.ListSessionsResponse: - logger.info("Listing sessions for working directory: {cwd}", cwd=cwd) - if cwd is None: - return acp.schema.ListSessionsResponse(sessions=[], next_cursor=None) - work_dir = KaosPath.unsafe_from_local_path(Path(cwd)) - sessions = await Session.list(work_dir) - return acp.schema.ListSessionsResponse( - sessions=[ - acp.schema.SessionInfo( - cwd=cwd, - session_id=s.id, - title=s.title, - updated_at=datetime.fromtimestamp(s.updated_at).astimezone().isoformat(), - ) - for s in sessions - ], - next_cursor=None, - ) - - async def set_session_mode(self, mode_id: str, session_id: str, **kwargs: Any) -> None: - assert mode_id == "default", "Only default mode is supported" - - async def set_session_model(self, model_id: str, session_id: str, **kwargs: Any) -> None: - logger.info( - "Setting session model to {model_id} for session: {id}", - model_id=model_id, - id=session_id, - ) - if session_id not in self.sessions: - logger.error("Session not found: {id}", id=session_id) - raise acp.RequestError.invalid_params({"session_id": "Session not found"}) - - acp_session, current_model_id = self.sessions[session_id] - cli_instance = acp_session.cli - model_id_conv = _ModelIDConv.from_acp_model_id(model_id) - if model_id_conv == current_model_id: - return - - config = cli_instance.soul.runtime.config - new_model = config.models.get(model_id_conv.model_key) - if new_model is None: - logger.error("Model not found: {model_key}", model_key=model_id_conv.model_key) - raise acp.RequestError.invalid_params({"model_id": "Model not found"}) - new_provider = config.providers.get(new_model.provider) - if new_provider is None: - logger.error( - "Provider not found: {provider} for model: {model_key}", - provider=new_model.provider, - model_key=model_id_conv.model_key, - ) - raise acp.RequestError.invalid_params({"model_id": "Model's provider not found"}) - - new_llm = create_llm( - new_provider, - new_model, - session_id=acp_session.id, - thinking=model_id_conv.thinking, - oauth=cli_instance.soul.runtime.oauth, - ) - cli_instance.soul.runtime.llm = new_llm - - config.default_model = model_id_conv.model_key - config.default_thinking = model_id_conv.thinking - assert config.is_from_default_location, "`kimi acp` must use the default config location" - config_for_save = load_config() - config_for_save.default_model = model_id_conv.model_key - config_for_save.default_thinking = model_id_conv.thinking - save_config(config_for_save) - - async def authenticate(self, method_id: str, **kwargs: Any) -> acp.AuthenticateResponse | None: - """ - For Terminal Auth, this method is typically not called directly - (user completes auth in terminal). Implement for completeness. - """ - if method_id == "login": - reason = self._check_token_usable() - if reason is None: - logger.info("Authentication successful for method: {id}", id=method_id) - return acp.AuthenticateResponse() - else: - logger.warning( - "Authentication not complete for method: {id} ({reason})", - id=method_id, - reason=reason, - ) - raise acp.RequestError.auth_required( - { - "message": "Please complete login in terminal first", - "authMethods": self._auth_methods, - } - ) - - logger.error("Unknown auth method: {method_id}", method_id=method_id) - raise acp.RequestError.invalid_params({"method_id": "Unknown auth method"}) - - async def prompt( - self, prompt: list[ACPContentBlock], session_id: str, **kwargs: Any - ) -> acp.PromptResponse: - logger.info("Received prompt request for session: {id}", id=session_id) - if session_id not in self.sessions: - logger.error("Session not found: {id}", id=session_id) - raise acp.RequestError.invalid_params({"session_id": "Session not found"}) - acp_session, *_ = self.sessions[session_id] - return await acp_session.prompt(prompt) - - async def cancel(self, session_id: str, **kwargs: Any) -> None: - logger.info("Received cancel request for session: {id}", id=session_id) - if session_id not in self.sessions: - logger.error("Session not found: {id}", id=session_id) - raise acp.RequestError.invalid_params({"session_id": "Session not found"}) - acp_session, *_ = self.sessions[session_id] - await acp_session.cancel() - - async def ext_method(self, method: str, params: dict[str, Any]) -> dict[str, Any]: - raise NotImplementedError - - async def ext_notification(self, method: str, params: dict[str, Any]) -> None: - raise NotImplementedError - - -class _ModelIDConv(NamedTuple): - model_key: str - thinking: bool - - @classmethod - def from_acp_model_id(cls, model_id: str) -> _ModelIDConv: - if model_id.endswith(",thinking"): - return _ModelIDConv(model_id[: -len(",thinking")], True) - return _ModelIDConv(model_id, False) - - def to_acp_model_id(self) -> str: - if self.thinking: - return f"{self.model_key},thinking" - return self.model_key - - -def _expand_llm_models(models: dict[str, LLMModel]) -> list[acp.schema.ModelInfo]: - expanded_models: list[acp.schema.ModelInfo] = [] - for model_key, model in models.items(): - capabilities = derive_model_capabilities(model) - if "thinking" in model.model or "reason" in model.model: - # always-thinking models - expanded_models.append( - acp.schema.ModelInfo( - model_id=_ModelIDConv(model_key, True).to_acp_model_id(), - name=f"{model.model}", - ) - ) - else: - expanded_models.append( - acp.schema.ModelInfo( - model_id=model_key, - name=model.model, - ) - ) - if "thinking" in capabilities: - # add thinking variant - expanded_models.append( - acp.schema.ModelInfo( - model_id=_ModelIDConv(model_key, True).to_acp_model_id(), - name=f"{model.model} (thinking)", - ) - ) - return expanded_models diff --git a/src/kimi_cli/acp/session.py b/src/kimi_cli/acp/session.py deleted file mode 100644 index 1d0939dcb..000000000 --- a/src/kimi_cli/acp/session.py +++ /dev/null @@ -1,496 +0,0 @@ -from __future__ import annotations - -import asyncio -import uuid -from contextvars import ContextVar - -import acp -import streamingjson # type: ignore[reportMissingTypeStubs] -from kaos import Kaos, reset_current_kaos, set_current_kaos -from kosong.chat_provider import APIStatusError, ChatProviderError - -from kimi_cli.acp.convert import ( - acp_blocks_to_content_parts, - display_block_to_acp_content, - tool_result_to_acp_content, -) -from kimi_cli.acp.types import ACPContentBlock -from kimi_cli.app import KimiCLI -from kimi_cli.soul import LLMNotSet, LLMNotSupported, MaxStepsReached, RunCancelled -from kimi_cli.tools import extract_key_argument -from kimi_cli.utils.logging import logger -from kimi_cli.wire.types import ( - ApprovalRequest, - ApprovalResponse, - CompactionBegin, - CompactionEnd, - ContentPart, - MCPLoadingBegin, - MCPLoadingEnd, - Notification, - PlanDisplay, - QuestionRequest, - StatusUpdate, - SteerInput, - StepBegin, - StepInterrupted, - SubagentEvent, - TextPart, - ThinkPart, - TodoDisplayBlock, - ToolCall, - ToolCallPart, - ToolCallRequest, - ToolResult, - TurnBegin, - TurnEnd, -) - -_current_turn_id = ContextVar[str | None]("current_turn_id", default=None) -_terminal_tool_call_ids = ContextVar[set[str] | None]("terminal_tool_call_ids", default=None) - - -def get_current_acp_tool_call_id_or_none() -> str | None: - """See `_ToolCallState.acp_tool_call_id`.""" - from kimi_cli.soul.toolset import get_current_tool_call_or_none - - turn_id = _current_turn_id.get() - if turn_id is None: - return None - tool_call = get_current_tool_call_or_none() - if tool_call is None: - return None - return f"{turn_id}/{tool_call.id}" - - -def register_terminal_tool_call_id(tool_call_id: str) -> None: - calls = _terminal_tool_call_ids.get() - if calls is not None: - calls.add(tool_call_id) - - -def should_hide_terminal_output(tool_call_id: str) -> bool: - calls = _terminal_tool_call_ids.get() - return calls is not None and tool_call_id in calls - - -class _ToolCallState: - """Manages the state of a single tool call for streaming updates.""" - - def __init__(self, tool_call: ToolCall): - self.tool_call = tool_call - self.args = tool_call.function.arguments or "" - self.lexer = streamingjson.Lexer() - if tool_call.function.arguments is not None: - self.lexer.append_string(tool_call.function.arguments) - - @property - def acp_tool_call_id(self) -> str: - # When the user rejected or cancelled a tool call, the step result may not - # be appended to the context. In this case, future step may emit tool call - # with the same tool call ID (on the LLM side). To avoid confusion of the - # ACP client, we ensure the uniqueness by prefixing with the turn ID. - turn_id = _current_turn_id.get() - assert turn_id is not None - return f"{turn_id}/{self.tool_call.id}" - - def append_args_part(self, args_part: str) -> None: - """Append a new arguments part to the accumulated args and lexer.""" - self.args += args_part - self.lexer.append_string(args_part) - - def get_title(self) -> str: - """Get the current title with subtitle if available.""" - tool_name = self.tool_call.function.name - subtitle = extract_key_argument(self.lexer, tool_name) - if subtitle: - return f"{tool_name}: {subtitle}" - return tool_name - - -class _TurnState: - def __init__(self): - self.id = str(uuid.uuid4()) - """Unique ID for the turn.""" - self.tool_calls: dict[str, _ToolCallState] = {} - """Map of tool call ID (LLM-side ID) to tool call state.""" - self.last_tool_call: _ToolCallState | None = None - self.cancel_event = asyncio.Event() - - -class ACPSession: - def __init__( - self, - id: str, - cli: KimiCLI, - acp_conn: acp.Client, - kaos: Kaos | None = None, - ) -> None: - self._id = id - self._cli = cli - self._conn = acp_conn - self._kaos = kaos - self._turn_state: _TurnState | None = None - - @property - def id(self) -> str: - """The ID of the ACP session.""" - return self._id - - @property - def cli(self) -> KimiCLI: - """The Kimi Code CLI instance bound to this ACP session.""" - return self._cli - - def _is_oauth_session(self) -> bool: - """Return True if the current session uses OAuth-based authentication.""" - try: - llm = self._cli.soul.runtime.llm - return llm is not None and getattr(llm.provider_config, "oauth", None) is not None - except AttributeError: - return False - - async def prompt(self, prompt: list[ACPContentBlock]) -> acp.PromptResponse: - user_input = acp_blocks_to_content_parts(prompt) - self._turn_state = _TurnState() - token = _current_turn_id.set(self._turn_state.id) - kaos_token = set_current_kaos(self._kaos) if self._kaos is not None else None - terminal_tool_calls_token = _terminal_tool_call_ids.set(set()) - try: - async for msg in self._cli.run(user_input, self._turn_state.cancel_event): - match msg: - case TurnBegin(): - pass - case SteerInput(): - pass - case TurnEnd(): - pass - case StepBegin(): - pass - case StepInterrupted(): - break - case CompactionBegin(): - pass - case CompactionEnd(): - pass - case MCPLoadingBegin(): - pass - case MCPLoadingEnd(): - pass - case StatusUpdate(): - pass - case Notification(): - await self._send_notification(msg) - case ThinkPart(think=think): - await self._send_thinking(think) - case TextPart(text=text): - await self._send_text(text) - case ContentPart(): - logger.warning("Unsupported content part: {part}", part=msg) - await self._send_text(f"[{msg.__class__.__name__}]") - case ToolCall(): - await self._send_tool_call(msg) - case ToolCallPart(): - await self._send_tool_call_part(msg) - case ToolResult(): - await self._send_tool_result(msg) - case ApprovalResponse(): - pass - case SubagentEvent(): - pass - case PlanDisplay(): - pass - case ApprovalRequest(): - await self._handle_approval_request(msg) - case ToolCallRequest(): - logger.warning("Unexpected ToolCallRequest in ACP session: {msg}", msg=msg) - case QuestionRequest(): - logger.warning( - "QuestionRequest is unsupported in ACP session; resolving empty answer." - ) - msg.resolve({}) - case _: - pass - except LLMNotSet as e: - logger.exception("LLM not set:") - raise acp.RequestError.auth_required() from e - except LLMNotSupported as e: - logger.exception("LLM not supported:") - raise acp.RequestError.internal_error({"error": str(e)}) from e - except APIStatusError as e: - if e.status_code == 401 and self._is_oauth_session(): - logger.warning("Authentication failed (401), prompting re-login") - raise acp.RequestError.auth_required() from e - logger.exception("LLM API status error:") - raise acp.RequestError.internal_error({"error": str(e)}) from e - except ChatProviderError as e: - logger.exception("LLM provider error:") - raise acp.RequestError.internal_error({"error": str(e)}) from e - except MaxStepsReached as e: - logger.warning("Max steps reached: {n_steps}", n_steps=e.n_steps) - return acp.PromptResponse(stop_reason="max_turn_requests") - except RunCancelled: - logger.info("Prompt cancelled by user") - return acp.PromptResponse(stop_reason="cancelled") - except Exception as e: - logger.exception("Unexpected error during prompt:") - raise acp.RequestError.internal_error({"error": str(e)}) from e - finally: - self._turn_state = None - if kaos_token is not None: - reset_current_kaos(kaos_token) - _terminal_tool_call_ids.reset(terminal_tool_calls_token) - _current_turn_id.reset(token) - return acp.PromptResponse(stop_reason="end_turn") - - async def cancel(self) -> None: - if self._turn_state is None: - logger.warning("Cancel requested but no prompt is running") - return - - self._turn_state.cancel_event.set() - - async def _send_thinking(self, think: str): - """Send thinking content to client.""" - if not self._id or not self._conn: - return - - await self._conn.session_update( - self._id, - acp.schema.AgentThoughtChunk( - content=acp.schema.TextContentBlock(type="text", text=think), - session_update="agent_thought_chunk", - ), - ) - - async def _send_text(self, text: str): - """Send text chunk to client.""" - if not self._id or not self._conn: - return - - await self._conn.session_update( - session_id=self._id, - update=acp.schema.AgentMessageChunk( - content=acp.schema.TextContentBlock(type="text", text=text), - session_update="agent_message_chunk", - ), - ) - - async def _send_notification(self, notification: Notification): - """Send a system notification to the client as a text chunk.""" - body = notification.body.strip() - text = f"[Notification] {notification.title}" - if body: - text = f"{text}\n{body}" - await self._send_text(text) - - async def _send_tool_call(self, tool_call: ToolCall): - """Send tool call to client.""" - assert self._turn_state is not None - if not self._id or not self._conn: - return - - # Create and store tool call state - state = _ToolCallState(tool_call) - self._turn_state.tool_calls[tool_call.id] = state - self._turn_state.last_tool_call = state - - await self._conn.session_update( - session_id=self._id, - update=acp.schema.ToolCallStart( - session_update="tool_call", - tool_call_id=state.acp_tool_call_id, - title=state.get_title(), - status="in_progress", - content=[ - acp.schema.ContentToolCallContent( - type="content", - content=acp.schema.TextContentBlock(type="text", text=state.args), - ) - ], - ), - ) - logger.debug("Sent tool call: {name}", name=tool_call.function.name) - - async def _send_tool_call_part(self, part: ToolCallPart): - """Send tool call part (streaming arguments).""" - assert self._turn_state is not None - if ( - not self._id - or not self._conn - or not part.arguments_part - or self._turn_state.last_tool_call is None - ): - return - - # Append new arguments part to the last tool call - self._turn_state.last_tool_call.append_args_part(part.arguments_part) - - # Update the tool call with new content and title - update = acp.schema.ToolCallProgress( - session_update="tool_call_update", - tool_call_id=self._turn_state.last_tool_call.acp_tool_call_id, - title=self._turn_state.last_tool_call.get_title(), - status="in_progress", - content=[ - acp.schema.ContentToolCallContent( - type="content", - content=acp.schema.TextContentBlock( - type="text", text=self._turn_state.last_tool_call.args - ), - ) - ], - ) - - await self._conn.session_update(session_id=self._id, update=update) - logger.debug("Sent tool call update: {delta}", delta=part.arguments_part[:50]) - - async def _send_tool_result(self, result: ToolResult): - """Send tool result to client.""" - assert self._turn_state is not None - if not self._id or not self._conn: - return - - tool_ret = result.return_value - - state = self._turn_state.tool_calls.pop(result.tool_call_id, None) - if state is None: - logger.warning("Tool call not found: {id}", id=result.tool_call_id) - return - - update = acp.schema.ToolCallProgress( - session_update="tool_call_update", - tool_call_id=state.acp_tool_call_id, - status="failed" if tool_ret.is_error else "completed", - ) - - contents = ( - [] - if should_hide_terminal_output(state.acp_tool_call_id) - else tool_result_to_acp_content(tool_ret) - ) - if contents: - update.content = contents - - await self._conn.session_update(session_id=self._id, update=update) - logger.debug("Sent tool result: {id}", id=result.tool_call_id) - - for block in tool_ret.display: - if isinstance(block, TodoDisplayBlock): - await self._send_plan_update(block) - - async def _handle_approval_request(self, request: ApprovalRequest): - """Handle approval request by sending permission request to client.""" - assert self._turn_state is not None - if not self._id or not self._conn: - logger.warning("No session ID, auto-rejecting approval request") - request.resolve("reject") - return - - state = self._turn_state.tool_calls.get(request.tool_call_id, None) - if state is None: - logger.warning("Tool call not found: {id}", id=request.tool_call_id) - request.resolve("reject") - return - - try: - content: list[ - acp.schema.ContentToolCallContent - | acp.schema.FileEditToolCallContent - | acp.schema.TerminalToolCallContent - ] = [] - if request.display: - for block in request.display: - diff_content = display_block_to_acp_content(block) - if diff_content is not None: - content.append(diff_content) - if not content: - content.append( - acp.schema.ContentToolCallContent( - type="content", - content=acp.schema.TextContentBlock( - type="text", - text=f"Requesting approval to perform: {request.description}", - ), - ) - ) - - # Send permission request and wait for response - logger.debug("Requesting permission for action: {action}", action=request.action) - response = await self._conn.request_permission( - [ - acp.schema.PermissionOption( - option_id="approve", - name="Approve once", - kind="allow_once", - ), - acp.schema.PermissionOption( - option_id="approve_for_session", - name="Approve for this session", - kind="allow_always", - ), - acp.schema.PermissionOption( - option_id="reject", - name="Reject", - kind="reject_once", - ), - ], - self._id, - acp.schema.ToolCallUpdate( - tool_call_id=state.acp_tool_call_id, - title=state.get_title(), - content=content, - ), - ) - logger.debug("Received permission response: {response}", response=response) - - # Process the outcome - if isinstance(response.outcome, acp.schema.AllowedOutcome): - # selected - option_id = response.outcome.option_id - if option_id == "approve": - logger.debug("Permission granted for: {action}", action=request.action) - request.resolve("approve") - elif option_id == "approve_for_session": - logger.debug("Permission granted for session: {action}", action=request.action) - request.resolve("approve_for_session") - else: - logger.debug("Permission denied for: {action}", action=request.action) - request.resolve("reject") - else: - # cancelled - logger.debug("Permission request cancelled for: {action}", action=request.action) - request.resolve("reject") - except Exception: - logger.exception("Error handling approval request:") - # On error, reject the request - request.resolve("reject") - - async def _send_plan_update(self, block: TodoDisplayBlock) -> None: - """Send todo list updates as ACP agent plan updates.""" - - status_map: dict[str, acp.schema.PlanEntryStatus] = { - "pending": "pending", - "in progress": "in_progress", - "in_progress": "in_progress", - "done": "completed", - "completed": "completed", - } - entries: list[acp.schema.PlanEntry] = [ - acp.schema.PlanEntry( - content=todo.title, - priority="medium", - status=status_map.get(todo.status.lower(), "pending"), - ) - for todo in block.items - if todo.title - ] - - if not entries: - logger.warning("No valid todo items to send in plan update: {todos}", todos=block.items) - return - - await self._conn.session_update( - session_id=self._id, - update=acp.schema.AgentPlanUpdate(session_update="plan", entries=entries), - ) diff --git a/src/kimi_cli/acp/tools.py b/src/kimi_cli/acp/tools.py deleted file mode 100644 index 055c9edb3..000000000 --- a/src/kimi_cli/acp/tools.py +++ /dev/null @@ -1,167 +0,0 @@ -import asyncio -from contextlib import suppress - -import acp -from kaos import get_current_kaos -from kaos.local import local_kaos -from kosong.tooling import CallableTool2, ToolReturnValue - -from kimi_cli.soul.agent import Runtime -from kimi_cli.soul.approval import Approval -from kimi_cli.soul.toolset import KimiToolset -from kimi_cli.tools.shell import Params as ShellParams -from kimi_cli.tools.shell import Shell -from kimi_cli.tools.utils import ToolResultBuilder -from kimi_cli.wire.types import DisplayBlock - - -def replace_tools( - client_capabilities: acp.schema.ClientCapabilities, - acp_conn: acp.Client, - acp_session_id: str, - toolset: KimiToolset, - runtime: Runtime, -) -> None: - current_kaos = get_current_kaos().name - if current_kaos not in (local_kaos.name, "acp"): - # Only replace tools when running locally or under ACPKaos. - return - - if client_capabilities.terminal and (shell_tool := toolset.find(Shell)): - # Replace the Shell tool with the ACP Terminal tool if supported. - toolset.add( - Terminal( - shell_tool, - acp_conn, - acp_session_id, - runtime.approval, - ) - ) - - -class HideOutputDisplayBlock(DisplayBlock): - """A special DisplayBlock that indicates output should be hidden in ACP clients.""" - - type: str = "acp/hide_output" - - -class Terminal(CallableTool2[ShellParams]): - def __init__( - self, - shell_tool: Shell, - acp_conn: acp.Client, - acp_session_id: str, - approval: Approval, - ) -> None: - # Use the `name`, `description`, and `params` from the existing Shell tool, - # so that when this is added to the toolset, it replaces the original Shell tool. - super().__init__(shell_tool.name, shell_tool.description, shell_tool.params) - self._acp_conn = acp_conn - self._acp_session_id = acp_session_id - self._approval = approval - - async def __call__(self, params: ShellParams) -> ToolReturnValue: - from kimi_cli.acp.session import get_current_acp_tool_call_id_or_none - - builder = ToolResultBuilder() - # Hide tool output because we use `TerminalToolCallContent` which already streams output - # directly to the user. - builder.display(HideOutputDisplayBlock()) - - if not params.command: - return builder.error("Command cannot be empty.", brief="Empty command") - - approval_result = await self._approval.request( - self.name, - "run shell command", - f"Run command `{params.command}`", - ) - if not approval_result: - return approval_result.rejection_error() - - timeout_seconds = float(params.timeout) - timeout_label = f"{timeout_seconds:g}s" - terminal_id: str | None = None - exit_status: ( - acp.schema.WaitForTerminalExitResponse | acp.schema.TerminalExitStatus | None - ) = None - timed_out = False - - try: - resp = await self._acp_conn.create_terminal( - command=params.command, - session_id=self._acp_session_id, - output_byte_limit=builder.max_chars, - ) - terminal_id = resp.terminal_id - - acp_tool_call_id = get_current_acp_tool_call_id_or_none() - assert acp_tool_call_id, "Expected to have an ACP tool call ID in context" - await self._acp_conn.session_update( - session_id=self._acp_session_id, - update=acp.schema.ToolCallProgress( - session_update="tool_call_update", - tool_call_id=acp_tool_call_id, - status="in_progress", - content=[ - acp.schema.TerminalToolCallContent( - type="terminal", - terminal_id=terminal_id, - ) - ], - ), - ) - - try: - async with asyncio.timeout(timeout_seconds): - exit_status = await self._acp_conn.wait_for_terminal_exit( - session_id=self._acp_session_id, - terminal_id=terminal_id, - ) - except TimeoutError: - timed_out = True - await self._acp_conn.kill_terminal( - session_id=self._acp_session_id, - terminal_id=terminal_id, - ) - - output_response = await self._acp_conn.terminal_output( - session_id=self._acp_session_id, - terminal_id=terminal_id, - ) - builder.write(output_response.output) - if output_response.exit_status: - exit_status = output_response.exit_status - - exit_code = exit_status.exit_code if exit_status else None - exit_signal = exit_status.signal if exit_status else None - - truncated_note = ( - " Output was truncated by the client output limit." - if output_response.truncated - else "" - ) - - if timed_out: - return builder.error( - f"Command killed by timeout ({timeout_label}){truncated_note}", - brief=f"Killed by timeout ({timeout_label})", - ) - if exit_signal: - return builder.error( - f"Command terminated by signal: {exit_signal}.{truncated_note}", - brief=f"Signal: {exit_signal}", - ) - if exit_code not in (None, 0): - return builder.error( - f"Command failed with exit code: {exit_code}.{truncated_note}", - brief=f"Failed with exit code: {exit_code}", - ) - return builder.ok(f"Command executed successfully.{truncated_note}") - finally: - if terminal_id is not None: - with suppress(Exception): - await self._acp_conn.release_terminal( - session_id=self._acp_session_id, - terminal_id=terminal_id, - ) diff --git a/src/kimi_cli/acp/types.py b/src/kimi_cli/acp/types.py deleted file mode 100644 index 288e83dc2..000000000 --- a/src/kimi_cli/acp/types.py +++ /dev/null @@ -1,13 +0,0 @@ -from __future__ import annotations - -import acp - -MCPServer = acp.schema.HttpMcpServer | acp.schema.SseMcpServer | acp.schema.McpServerStdio - -ACPContentBlock = ( - acp.schema.TextContentBlock - | acp.schema.ImageContentBlock - | acp.schema.AudioContentBlock - | acp.schema.ResourceContentBlock - | acp.schema.EmbeddedResourceContentBlock -) diff --git a/src/kimi_cli/acp/version.py b/src/kimi_cli/acp/version.py deleted file mode 100644 index 6c51cd0d3..000000000 --- a/src/kimi_cli/acp/version.py +++ /dev/null @@ -1,45 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass - - -@dataclass(frozen=True) -class ACPVersionSpec: - """Describes one supported ACP protocol version.""" - - protocol_version: int # negotiation integer (currently 1) - spec_tag: str # ACP spec tag (e.g. "v0.10.8") - sdk_version: str # corresponding SDK version (e.g. "0.8.0") - - -CURRENT_VERSION = ACPVersionSpec( - protocol_version=1, - spec_tag="v0.10.8", - sdk_version="0.8.0", -) - -SUPPORTED_VERSIONS: dict[int, ACPVersionSpec] = { - 1: CURRENT_VERSION, -} - -MIN_PROTOCOL_VERSION = 1 - - -def negotiate_version(client_protocol_version: int) -> ACPVersionSpec: - """Negotiate the protocol version with the client. - - Returns the highest server-supported version that does not exceed the - client's requested version. If the client version is lower than - ``MIN_PROTOCOL_VERSION`` the server still returns its own current - version so the client can decide whether to disconnect. - """ - if client_protocol_version < MIN_PROTOCOL_VERSION: - return CURRENT_VERSION - - # Find the highest supported version <= client version - best: ACPVersionSpec | None = None - for ver, spec in SUPPORTED_VERSIONS.items(): - if ver <= client_protocol_version and (best is None or ver > best.protocol_version): - best = spec - - return best if best is not None else CURRENT_VERSION diff --git a/src/kimi_cli/agents/okabe/agent.yaml b/src/kimi_cli/agents/okabe/agent.yaml deleted file mode 100644 index 06848a492..000000000 --- a/src/kimi_cli/agents/okabe/agent.yaml +++ /dev/null @@ -1,22 +0,0 @@ -version: 1 -agent: - extend: default - tools: - - "kimi_cli.tools.agent:Agent" - - "kimi_cli.tools.dmail:SendDMail" - - "kimi_cli.tools.ask_user:AskUserQuestion" - - "kimi_cli.tools.todo:SetTodoList" - - "kimi_cli.tools.shell:Shell" - - "kimi_cli.tools.background:TaskList" - - "kimi_cli.tools.background:TaskOutput" - - "kimi_cli.tools.background:TaskStop" - - "kimi_cli.tools.file:ReadFile" - - "kimi_cli.tools.file:ReadMediaFile" - - "kimi_cli.tools.file:Glob" - - "kimi_cli.tools.file:Grep" - - "kimi_cli.tools.file:WriteFile" - - "kimi_cli.tools.file:StrReplaceFile" - - "kimi_cli.tools.web:SearchWeb" - - "kimi_cli.tools.web:FetchURL" - - "kimi_cli.tools.plan:ExitPlanMode" - - "kimi_cli.tools.plan.enter:EnterPlanMode" diff --git a/src/kimi_cli/agentspec.py b/src/kimi_cli/agentspec.py deleted file mode 100644 index b2808db01..000000000 --- a/src/kimi_cli/agentspec.py +++ /dev/null @@ -1,160 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -from pathlib import Path -from typing import Any, NamedTuple - -import yaml -from pydantic import BaseModel, Field - -from kimi_cli.exception import AgentSpecError - -DEFAULT_AGENT_SPEC_VERSION = "1" -SUPPORTED_AGENT_SPEC_VERSIONS = (DEFAULT_AGENT_SPEC_VERSION,) - - -def get_agents_dir() -> Path: - return Path(__file__).parent / "agents" - - -DEFAULT_AGENT_FILE = get_agents_dir() / "default" / "agent.yaml" -OKABE_AGENT_FILE = get_agents_dir() / "okabe" / "agent.yaml" - - -class Inherit(NamedTuple): - """Marker class for inheritance in agent spec.""" - - -inherit = Inherit() - - -class AgentSpec(BaseModel): - """Agent specification.""" - - extend: str | None = Field(default=None, description="Agent file to extend") - name: str | Inherit = Field(default=inherit, description="Agent name") # required - system_prompt_path: Path | Inherit = Field( - default=inherit, description="System prompt path" - ) # required - system_prompt_args: dict[str, str] = Field( - default_factory=dict, description="System prompt arguments" - ) - model: str | None = Field(default=None, description="Default model alias") - when_to_use: str | None = Field(default=None, description="Usage guidance") - tools: list[str] | None | Inherit = Field(default=inherit, description="Tools") # required - allowed_tools: list[str] | None | Inherit = Field(default=inherit, description="Allowed tools") - exclude_tools: list[str] | None | Inherit = Field( - default=inherit, description="Tools to exclude" - ) - subagents: dict[str, SubagentSpec] | None | Inherit = Field( - default=inherit, description="Subagents" - ) - - -class SubagentSpec(BaseModel): - """Subagent specification.""" - - path: Path = Field(description="Subagent file path") - description: str = Field(description="Subagent description") - - -@dataclass(frozen=True, slots=True, kw_only=True) -class ResolvedAgentSpec: - """Resolved agent specification.""" - - name: str - system_prompt_path: Path - system_prompt_args: dict[str, str] - model: str | None - when_to_use: str - tools: list[str] - allowed_tools: list[str] | None - exclude_tools: list[str] - subagents: dict[str, SubagentSpec] - - -def load_agent_spec(agent_file: Path) -> ResolvedAgentSpec: - """ - Load agent specification from file. - - Raises: - FileNotFoundError: If the agent spec file is not found. - AgentSpecError: If the agent spec is not valid. - """ - agent_spec = _load_agent_spec(agent_file) - assert agent_spec.extend is None, "agent extension should be recursively resolved" - if isinstance(agent_spec.name, Inherit): - raise AgentSpecError("Agent name is required") - if isinstance(agent_spec.system_prompt_path, Inherit): - raise AgentSpecError("System prompt path is required") - if isinstance(agent_spec.tools, Inherit): - raise AgentSpecError("Tools are required") - if isinstance(agent_spec.allowed_tools, Inherit): - agent_spec.allowed_tools = None - if isinstance(agent_spec.exclude_tools, Inherit): - agent_spec.exclude_tools = [] - if isinstance(agent_spec.subagents, Inherit): - agent_spec.subagents = {} - return ResolvedAgentSpec( - name=agent_spec.name, - system_prompt_path=agent_spec.system_prompt_path, - system_prompt_args=agent_spec.system_prompt_args, - model=agent_spec.model, - when_to_use=agent_spec.when_to_use or "", - tools=agent_spec.tools or [], - allowed_tools=agent_spec.allowed_tools, - exclude_tools=agent_spec.exclude_tools or [], - subagents=agent_spec.subagents or {}, - ) - - -def _load_agent_spec(agent_file: Path) -> AgentSpec: - if not agent_file.exists(): - raise AgentSpecError(f"Agent spec file not found: {agent_file}") - if not agent_file.is_file(): - raise AgentSpecError(f"Agent spec path is not a file: {agent_file}") - try: - with open(agent_file, encoding="utf-8") as f: - data: dict[str, Any] = yaml.safe_load(f) - except yaml.YAMLError as e: - raise AgentSpecError(f"Invalid YAML in agent spec file: {e}") from e - - version = str(data.get("version", DEFAULT_AGENT_SPEC_VERSION)) - if version not in SUPPORTED_AGENT_SPEC_VERSIONS: - raise AgentSpecError(f"Unsupported agent spec version: {version}") - - agent_spec = AgentSpec(**data.get("agent", {})) - if isinstance(agent_spec.system_prompt_path, Path): - agent_spec.system_prompt_path = ( - agent_file.parent / agent_spec.system_prompt_path - ).absolute() - if isinstance(agent_spec.subagents, dict): - for v in agent_spec.subagents.values(): - v.path = (agent_file.parent / v.path).absolute() - if agent_spec.extend: - if agent_spec.extend == "default": - base_agent_file = DEFAULT_AGENT_FILE - else: - base_agent_file = (agent_file.parent / agent_spec.extend).absolute() - base_agent_spec = _load_agent_spec(base_agent_file) - if not isinstance(agent_spec.name, Inherit): - base_agent_spec.name = agent_spec.name - if not isinstance(agent_spec.system_prompt_path, Inherit): - base_agent_spec.system_prompt_path = agent_spec.system_prompt_path - for k, v in agent_spec.system_prompt_args.items(): - # system prompt args should be merged instead of overwritten - base_agent_spec.system_prompt_args[k] = v - if agent_spec.model is not None: - base_agent_spec.model = agent_spec.model - if agent_spec.when_to_use is not None: - base_agent_spec.when_to_use = agent_spec.when_to_use - if not isinstance(agent_spec.tools, Inherit): - base_agent_spec.tools = agent_spec.tools - if not isinstance(agent_spec.allowed_tools, Inherit): - base_agent_spec.allowed_tools = agent_spec.allowed_tools - if not isinstance(agent_spec.exclude_tools, Inherit): - base_agent_spec.exclude_tools = agent_spec.exclude_tools - if not isinstance(agent_spec.subagents, Inherit): - base_agent_spec.subagents = agent_spec.subagents - agent_spec = base_agent_spec - return agent_spec diff --git a/src/kimi_cli/agentspec.ts b/src/kimi_cli/agentspec.ts new file mode 100644 index 000000000..2fa6331a2 --- /dev/null +++ b/src/kimi_cli/agentspec.ts @@ -0,0 +1,167 @@ +/** + * Agent spec loader — corresponds to Python agentspec.py + * Loads agent YAML specifications with inheritance support. + */ + +import { join, dirname, resolve } from "node:path"; +import { z } from "zod/v4"; +import { parse as parseYaml } from "./utils/yaml.ts"; + +// ── Constants ─────────────────────────────────────────── + +const DEFAULT_AGENT_SPEC_VERSION = "1"; +const SUPPORTED_VERSIONS = new Set([DEFAULT_AGENT_SPEC_VERSION]); + +export function getAgentsDir(): string { + return join(dirname(import.meta.dir), "kimi_cli", "agents"); +} + +const INHERIT = Symbol("inherit"); +type Inherit = typeof INHERIT; + +// ── Types ─────────────────────────────────────────────── + +export interface SubagentSpec { + path: string; + description: string; +} + +export interface AgentSpec { + extend?: string; + name: string | Inherit; + systemPromptPath: string | Inherit; + systemPromptArgs: Record; + model?: string; + whenToUse?: string; + tools: string[] | null | Inherit; + allowedTools: string[] | null | Inherit; + excludeTools: string[] | null | Inherit; + subagents: Record | null | Inherit; +} + +export interface ResolvedAgentSpec { + name: string; + systemPromptPath: string; + systemPromptArgs: Record; + model: string | null; + whenToUse: string; + tools: string[]; + allowedTools: string[] | null; + excludeTools: string[]; + subagents: Record; +} + +// ── Errors ────────────────────────────────────────────── + +export class AgentSpecError extends Error { + constructor(message: string) { + super(message); + this.name = "AgentSpecError"; + } +} + +// ── Loader ────────────────────────────────────────────── + +function parseAgentData(data: Record, agentFileDir: string): AgentSpec { + const agent = (data.agent ?? {}) as Record; + + const spec: AgentSpec = { + extend: agent.extend as string | undefined, + name: agent.name != null ? String(agent.name) : INHERIT, + systemPromptPath: agent.system_prompt_path != null + ? resolve(agentFileDir, String(agent.system_prompt_path)) + : INHERIT, + systemPromptArgs: (agent.system_prompt_args as Record) ?? {}, + model: agent.model != null ? String(agent.model) : undefined, + whenToUse: agent.when_to_use != null ? String(agent.when_to_use) : undefined, + tools: agent.tools !== undefined ? (agent.tools as string[] | null) : INHERIT, + allowedTools: agent.allowed_tools !== undefined ? (agent.allowed_tools as string[] | null) : INHERIT, + excludeTools: agent.exclude_tools !== undefined ? (agent.exclude_tools as string[] | null) : INHERIT, + subagents: agent.subagents !== undefined + ? parseSubagents(agent.subagents as Record, agentFileDir) + : INHERIT, + }; + + return spec; +} + +function parseSubagents( + raw: Record | null, + baseDir: string, +): Record | null { + if (!raw) return null; + const result: Record = {}; + for (const [key, val] of Object.entries(raw)) { + const v = val as Record; + result[key] = { + path: resolve(baseDir, String(v.path)), + description: String(v.description ?? ""), + }; + } + return result; +} + +async function loadAgentSpecRaw(agentFile: string): Promise { + const file = Bun.file(agentFile); + if (!(await file.exists())) { + throw new AgentSpecError(`Agent spec file not found: ${agentFile}`); + } + + const text = await file.text(); + const data = parseYaml(text) as Record; + + const version = String(data.version ?? DEFAULT_AGENT_SPEC_VERSION); + if (!SUPPORTED_VERSIONS.has(version)) { + throw new AgentSpecError(`Unsupported agent spec version: ${version}`); + } + + const agentFileDir = dirname(agentFile); + const spec = parseAgentData(data, agentFileDir); + + // Handle inheritance + if (spec.extend) { + const baseFile = spec.extend === "default" + ? join(getAgentsDir(), "default", "agent.yaml") + : resolve(agentFileDir, spec.extend); + + const base = await loadAgentSpecRaw(baseFile); + + if (spec.name !== INHERIT) base.name = spec.name; + if (spec.systemPromptPath !== INHERIT) base.systemPromptPath = spec.systemPromptPath; + // Merge system prompt args + for (const [k, v] of Object.entries(spec.systemPromptArgs)) { + base.systemPromptArgs[k] = v; + } + if (spec.model != null) base.model = spec.model; + if (spec.whenToUse != null) base.whenToUse = spec.whenToUse; + if (spec.tools !== INHERIT) base.tools = spec.tools; + if (spec.allowedTools !== INHERIT) base.allowedTools = spec.allowedTools; + if (spec.excludeTools !== INHERIT) base.excludeTools = spec.excludeTools; + if (spec.subagents !== INHERIT) base.subagents = spec.subagents; + + base.extend = undefined; + return base; + } + + return spec; +} + +export async function loadAgentSpec(agentFile: string): Promise { + const spec = await loadAgentSpecRaw(agentFile); + + if (spec.name === INHERIT) throw new AgentSpecError("Agent name is required"); + if (spec.systemPromptPath === INHERIT) throw new AgentSpecError("System prompt path is required"); + if (spec.tools === INHERIT) throw new AgentSpecError("Tools are required"); + + return { + name: spec.name as string, + systemPromptPath: spec.systemPromptPath as string, + systemPromptArgs: spec.systemPromptArgs, + model: spec.model ?? null, + whenToUse: spec.whenToUse ?? "", + tools: (spec.tools as string[]) ?? [], + allowedTools: spec.allowedTools === INHERIT ? null : (spec.allowedTools as string[] | null), + excludeTools: spec.excludeTools === INHERIT ? [] : ((spec.excludeTools as string[]) ?? []), + subagents: spec.subagents === INHERIT ? {} : ((spec.subagents as Record) ?? {}), + }; +} diff --git a/src/kimi_cli/app.py b/src/kimi_cli/app.py deleted file mode 100644 index ad2e2747a..000000000 --- a/src/kimi_cli/app.py +++ /dev/null @@ -1,540 +0,0 @@ -from __future__ import annotations - -import asyncio -import contextlib -import dataclasses -import warnings -from collections.abc import AsyncGenerator, Callable -from pathlib import Path -from typing import TYPE_CHECKING, Any - -import kaos -from kaos.path import KaosPath -from pydantic import SecretStr - -from kimi_cli.agentspec import DEFAULT_AGENT_FILE -from kimi_cli.auth.oauth import OAuthManager -from kimi_cli.cli import InputFormat, OutputFormat -from kimi_cli.config import Config, LLMModel, LLMProvider, load_config -from kimi_cli.llm import augment_provider_with_env_vars, create_llm, model_display_name -from kimi_cli.session import Session -from kimi_cli.share import get_share_dir -from kimi_cli.soul import run_soul -from kimi_cli.soul.agent import Runtime, load_agent -from kimi_cli.soul.context import Context -from kimi_cli.soul.kimisoul import KimiSoul -from kimi_cli.utils.aioqueue import QueueShutDown -from kimi_cli.utils.logging import logger, redirect_stderr_to_logger -from kimi_cli.utils.path import shorten_home -from kimi_cli.wire import Wire, WireUISide -from kimi_cli.wire.types import ApprovalRequest, ApprovalResponse, ContentPart, WireMessage - -if TYPE_CHECKING: - from fastmcp.mcp_config import MCPConfig - - -def enable_logging(debug: bool = False, *, redirect_stderr: bool = True) -> None: - # NOTE: stderr redirection is implemented by swapping the process-level fd=2 (dup2). - # That can hide Click/Typer error output during CLI startup, so some entrypoints delay - # installing it until after critical initialization succeeds. - logger.remove() # Remove default stderr handler - logger.enable("kimi_cli") - if debug: - logger.enable("kosong") - logger.add( - get_share_dir() / "logs" / "kimi.log", - # FIXME: configure level for different modules - level="TRACE" if debug else "INFO", - rotation="06:00", - retention="10 days", - ) - if redirect_stderr: - redirect_stderr_to_logger() - - -def _cleanup_stale_foreground_subagents(runtime: Runtime) -> None: - subagent_store = getattr(runtime, "subagent_store", None) - if subagent_store is None: - return - - stale_agent_ids = [ - record.agent_id - for record in subagent_store.list_instances() - if record.status == "running_foreground" - ] - for agent_id in stale_agent_ids: - logger.warning( - "Marking stale foreground subagent instance as failed during startup: {agent_id}", - agent_id=agent_id, - ) - subagent_store.update_instance(agent_id, status="failed") - - -class KimiCLI: - @staticmethod - async def create( - session: Session, - *, - # Basic configuration - config: Config | Path | None = None, - model_name: str | None = None, - thinking: bool | None = None, - # Run mode - yolo: bool = False, - # Extensions - agent_file: Path | None = None, - mcp_configs: list[MCPConfig] | list[dict[str, Any]] | None = None, - skills_dirs: list[KaosPath] | None = None, - # Loop control - max_steps_per_turn: int | None = None, - max_retries_per_step: int | None = None, - max_ralph_iterations: int | None = None, - startup_progress: Callable[[str], None] | None = None, - defer_mcp_loading: bool = False, - ) -> KimiCLI: - """ - Create a KimiCLI instance. - - Args: - session (Session): A session created by `Session.create` or `Session.continue_`. - config (Config | Path | None, optional): Configuration to use, or path to config file. - Defaults to None. - model_name (str | None, optional): Name of the model to use. Defaults to None. - thinking (bool | None, optional): Whether to enable thinking mode. Defaults to None. - yolo (bool, optional): Approve all actions without confirmation. Defaults to False. - agent_file (Path | None, optional): Path to the agent file. Defaults to None. - mcp_configs (list[MCPConfig | dict[str, Any]] | None, optional): MCP configs to load - MCP tools from. Defaults to None. - skills_dirs (list[KaosPath] | None, optional): Custom skills directories that - override default user/project discovery. Defaults to None. - max_steps_per_turn (int | None, optional): Maximum number of steps in one turn. - Defaults to None. - max_retries_per_step (int | None, optional): Maximum number of retries in one step. - Defaults to None. - max_ralph_iterations (int | None, optional): Extra iterations after the first turn in - Ralph mode. Defaults to None. - startup_progress (Callable[[str], None] | None, optional): Progress callback used by - interactive startup UI. Defaults to None. - defer_mcp_loading (bool, optional): Defer MCP startup until the interactive shell is - ready. Defaults to False. - - Raises: - FileNotFoundError: When the agent file is not found. - ConfigError(KimiCLIException, ValueError): When the configuration is invalid. - AgentSpecError(KimiCLIException, ValueError): When the agent specification is invalid. - SystemPromptTemplateError(KimiCLIException, ValueError): When the system prompt - template is invalid. - InvalidToolError(KimiCLIException, ValueError): When any tool cannot be loaded. - MCPConfigError(KimiCLIException, ValueError): When any MCP configuration is invalid. - MCPRuntimeError(KimiCLIException, RuntimeError): When any MCP server cannot be - connected. - """ - if startup_progress is not None: - startup_progress("Loading configuration...") - - config = config if isinstance(config, Config) else load_config(config) - if max_steps_per_turn is not None: - config.loop_control.max_steps_per_turn = max_steps_per_turn - if max_retries_per_step is not None: - config.loop_control.max_retries_per_step = max_retries_per_step - if max_ralph_iterations is not None: - config.loop_control.max_ralph_iterations = max_ralph_iterations - logger.info("Loaded config: {config}", config=config) - - oauth = OAuthManager(config) - - model: LLMModel | None = None - provider: LLMProvider | None = None - - # try to use config file - if not model_name and config.default_model: - # no --model specified && default model is set in config - model = config.models[config.default_model] - provider = config.providers[model.provider] - if model_name and model_name in config.models: - # --model specified && model is set in config - model = config.models[model_name] - provider = config.providers[model.provider] - - if not model: - model = LLMModel(provider="", model="", max_context_size=100_000) - provider = LLMProvider(type="kimi", base_url="", api_key=SecretStr("")) - - # try overwrite with environment variables - assert provider is not None - assert model is not None - env_overrides = augment_provider_with_env_vars(provider, model) - - # determine thinking mode - thinking = config.default_thinking if thinking is None else thinking - - # determine yolo mode - yolo = yolo if yolo else config.default_yolo - - llm = create_llm( - provider, - model, - thinking=thinking, - session_id=session.id, - oauth=oauth, - ) - if llm is not None: - logger.info("Using LLM provider: {provider}", provider=provider) - logger.info("Using LLM model: {model}", model=model) - logger.info("Thinking mode: {thinking}", thinking=thinking) - - if startup_progress is not None: - startup_progress("Scanning workspace...") - - runtime = await Runtime.create( - config, - oauth, - llm, - session, - yolo, - skills_dirs=skills_dirs, - ) - runtime.notifications.recover() - runtime.background_tasks.reconcile() - _cleanup_stale_foreground_subagents(runtime) - - # Refresh plugin configs with fresh credentials (e.g. OAuth tokens) - try: - from kimi_cli.plugin.manager import ( - collect_host_values, - get_plugins_dir, - refresh_plugin_configs, - ) - - host_values = collect_host_values(config, oauth) - if host_values.get("api_key"): - refresh_plugin_configs(get_plugins_dir(), host_values) - except Exception: - logger.debug("Failed to refresh plugin configs, skipping") - - if agent_file is None: - agent_file = DEFAULT_AGENT_FILE - if startup_progress is not None: - startup_progress("Loading agent...") - - agent = await load_agent( - agent_file, - runtime, - mcp_configs=mcp_configs or [], - start_mcp_loading=not defer_mcp_loading, - ) - - if startup_progress is not None: - startup_progress("Restoring conversation...") - context = Context(session.context_file) - await context.restore() - - if context.system_prompt is not None: - agent = dataclasses.replace(agent, system_prompt=context.system_prompt) - else: - await context.write_system_prompt(agent.system_prompt) - - soul = KimiSoul(agent, context=context) - - # Create and inject hook engine - from kimi_cli.hooks.engine import HookEngine - - hook_engine = HookEngine(config.hooks, cwd=str(session.work_dir)) - soul.set_hook_engine(hook_engine) - runtime.hook_engine = hook_engine - - return KimiCLI(soul, runtime, env_overrides) - - def __init__( - self, - _soul: KimiSoul, - _runtime: Runtime, - _env_overrides: dict[str, str], - ) -> None: - self._soul = _soul - self._runtime = _runtime - self._env_overrides = _env_overrides - - @property - def soul(self) -> KimiSoul: - """Get the KimiSoul instance.""" - return self._soul - - @property - def session(self) -> Session: - """Get the Session instance.""" - return self._runtime.session - - def shutdown_background_tasks(self) -> None: - """Kill active background tasks on exit, unless keep_alive_on_exit is configured.""" - if self._runtime.config.background.keep_alive_on_exit: - return - killed = self._runtime.background_tasks.kill_all_active(reason="CLI session ended") - if killed: - logger.info("Stopped {n} background task(s) on exit: {ids}", n=len(killed), ids=killed) - - @contextlib.asynccontextmanager - async def _env(self) -> AsyncGenerator[None]: - original_cwd = KaosPath.cwd() - await kaos.chdir(self._runtime.session.work_dir) - try: - # to ignore possible warnings from dateparser - warnings.filterwarnings("ignore", category=DeprecationWarning) - async with self._runtime.oauth.refreshing(self._runtime): - yield - finally: - await kaos.chdir(original_cwd) - - async def run( - self, - user_input: str | list[ContentPart], - cancel_event: asyncio.Event, - merge_wire_messages: bool = False, - ) -> AsyncGenerator[WireMessage]: - """ - Run the Kimi Code CLI instance without any UI and yield Wire messages directly. - - Args: - user_input (str | list[ContentPart]): The user input to the agent. - cancel_event (asyncio.Event): An event to cancel the run. - merge_wire_messages (bool): Whether to merge Wire messages as much as possible. - - Yields: - WireMessage: The Wire messages from the `KimiSoul`. - - Raises: - LLMNotSet: When the LLM is not set. - LLMNotSupported: When the LLM does not have required capabilities. - ChatProviderError: When the LLM provider returns an error. - MaxStepsReached: When the maximum number of steps is reached. - RunCancelled: When the run is cancelled by the cancel event. - """ - async with self._env(): - wire_future = asyncio.Future[WireUISide]() - stop_ui_loop = asyncio.Event() - approval_bridge_tasks: dict[str, asyncio.Task[None]] = {} - forwarded_approval_requests: dict[str, ApprovalRequest] = {} - - async def _bridge_approval_request(request: ApprovalRequest) -> None: - try: - response = await request.wait() - assert self._runtime.approval_runtime is not None - self._runtime.approval_runtime.resolve( - request.id, response, feedback=request.feedback - ) - finally: - approval_bridge_tasks.pop(request.id, None) - forwarded_approval_requests.pop(request.id, None) - - def _forward_approval_request(wire: Wire, request: ApprovalRequest) -> None: - if request.id in forwarded_approval_requests: - return - forwarded_approval_requests[request.id] = request - if request.id not in approval_bridge_tasks: - approval_bridge_tasks[request.id] = asyncio.create_task( - _bridge_approval_request(request) - ) - wire.soul_side.send(request) - - async def _ui_loop_fn(wire: Wire) -> None: - wire_future.set_result(wire.ui_side(merge=merge_wire_messages)) - assert self._runtime.root_wire_hub is not None - assert self._runtime.approval_runtime is not None - root_hub_queue = self._runtime.root_wire_hub.subscribe() - stop_task = asyncio.create_task(stop_ui_loop.wait()) - queue_task = asyncio.create_task(root_hub_queue.get()) - try: - for pending in self._runtime.approval_runtime.list_pending(): - _forward_approval_request( - wire, - ApprovalRequest( - id=pending.id, - tool_call_id=pending.tool_call_id, - sender=pending.sender, - action=pending.action, - description=pending.description, - display=pending.display, - source_kind=pending.source.kind, - source_id=pending.source.id, - agent_id=pending.source.agent_id, - subagent_type=pending.source.subagent_type, - ), - ) - while True: - done, _ = await asyncio.wait( - [stop_task, queue_task], - return_when=asyncio.FIRST_COMPLETED, - ) - if stop_task in done: - break - try: - msg = queue_task.result() - except QueueShutDown: - break - match msg: - case ApprovalRequest() as request: - _forward_approval_request(wire, request) - queue_task = asyncio.create_task(root_hub_queue.get()) - continue - case ApprovalResponse() as response: - if ( - request := forwarded_approval_requests.get(response.request_id) - ) and not request.resolved: - request.resolve(response.response, response.feedback) - case _: - pass - wire.soul_side.send(msg) - queue_task = asyncio.create_task(root_hub_queue.get()) - finally: - stop_task.cancel() - queue_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await stop_task - with contextlib.suppress(asyncio.CancelledError): - await queue_task - for task in list(approval_bridge_tasks.values()): - task.cancel() - for task in list(approval_bridge_tasks.values()): - with contextlib.suppress(asyncio.CancelledError): - await task - approval_bridge_tasks.clear() - forwarded_approval_requests.clear() - assert self._runtime.root_wire_hub is not None - self._runtime.root_wire_hub.unsubscribe(root_hub_queue) - - soul_task = asyncio.create_task( - run_soul( - self.soul, - user_input, - _ui_loop_fn, - cancel_event, - runtime=self._runtime, - ) - ) - - try: - wire_ui = await wire_future - while True: - msg = await wire_ui.receive() - yield msg - except QueueShutDown: - pass - finally: - # stop consuming Wire messages - stop_ui_loop.set() - # wait for the soul task to finish, or raise - await soul_task - - async def run_shell(self, command: str | None = None) -> bool: - """Run the Kimi Code CLI instance with shell UI.""" - from kimi_cli.ui.shell import Shell, WelcomeInfoItem - - welcome_info = [ - WelcomeInfoItem( - name="Directory", value=str(shorten_home(self._runtime.session.work_dir)) - ), - WelcomeInfoItem(name="Session", value=self._runtime.session.id), - ] - if base_url := self._env_overrides.get("KIMI_BASE_URL"): - welcome_info.append( - WelcomeInfoItem( - name="API URL", - value=f"{base_url} (from KIMI_BASE_URL)", - level=WelcomeInfoItem.Level.WARN, - ) - ) - if self._env_overrides.get("KIMI_API_KEY"): - welcome_info.append( - WelcomeInfoItem( - name="API Key", - value="****** (from KIMI_API_KEY)", - level=WelcomeInfoItem.Level.WARN, - ) - ) - if not self._runtime.llm: - welcome_info.append( - WelcomeInfoItem( - name="Model", - value="not set, send /login to login", - level=WelcomeInfoItem.Level.WARN, - ) - ) - elif "KIMI_MODEL_NAME" in self._env_overrides: - welcome_info.append( - WelcomeInfoItem( - name="Model", - value=f"{self._soul.model_name} (from KIMI_MODEL_NAME)", - level=WelcomeInfoItem.Level.WARN, - ) - ) - else: - welcome_info.append( - WelcomeInfoItem( - name="Model", - value=model_display_name(self._soul.model_name), - level=WelcomeInfoItem.Level.INFO, - ) - ) - if self._soul.model_name not in ( - "kimi-for-coding", - "kimi-code", - "kimi-k2.5", - "kimi-k2-5", - ): - welcome_info.append( - WelcomeInfoItem( - name="Tip", - value="send /login to use our latest kimi-k2.5 model", - level=WelcomeInfoItem.Level.WARN, - ) - ) - welcome_info.append( - WelcomeInfoItem( - name="\nTip", - value=( - "Spot a bug or have feedback? Type /feedback right in this session" - " — every report makes Kimi better." - ), - level=WelcomeInfoItem.Level.INFO, - ) - ) - async with self._env(): - shell = Shell(self._soul, welcome_info=welcome_info) - return await shell.run(command) - - async def run_print( - self, - input_format: InputFormat, - output_format: OutputFormat, - command: str | None = None, - *, - final_only: bool = False, - ) -> int: - """Run the Kimi Code CLI instance with print UI.""" - from kimi_cli.ui.print import Print - - async with self._env(): - print_ = Print( - self._soul, - input_format, - output_format, - self._runtime.session.context_file, - final_only=final_only, - ) - return await print_.run(command) - - async def run_acp(self) -> None: - """Run the Kimi Code CLI instance as ACP server.""" - from kimi_cli.ui.acp import ACP - - async with self._env(): - acp = ACP(self._soul) - await acp.run() - - async def run_wire_stdio(self) -> None: - """Run the Kimi Code CLI instance as Wire server over stdio.""" - from kimi_cli.wire.server import WireServer - - async with self._env(): - server = WireServer(self._soul) - await server.serve() diff --git a/src/kimi_cli/app.ts b/src/kimi_cli/app.ts new file mode 100644 index 000000000..fab468767 --- /dev/null +++ b/src/kimi_cli/app.ts @@ -0,0 +1,283 @@ +/** + * KimiCLI app orchestrator — corresponds to Python app.py + * Creates and wires together all components. + */ + +import { loadConfig, type Config, type ConfigMeta } from "./config.ts"; +import { createLLM, augmentProviderWithEnvVars, type LLM } from "./llm.ts"; +import { OAuthManager, loadTokens, commonHeaders } from "./auth/oauth.ts"; +import { Session } from "./session.ts"; +import { HookEngine } from "./hooks/engine.ts"; +import { Context } from "./soul/context.ts"; +import { Runtime, Agent, loadAgent } from "./soul/agent.ts"; +import { KimiSoul, type SoulCallbacks } from "./soul/kimisoul.ts"; +import { logger } from "./utils/logging.ts"; + +// ── KimiCLI ───────────────────────────────────────── + +export class KimiCLI { + readonly soul: KimiSoul; + readonly agent: Agent; + readonly session: Session; + readonly config: Config; + readonly configMeta: ConfigMeta; + readonly context: Context; + + private constructor(opts: { + soul: KimiSoul; + agent: Agent; + session: Session; + config: Config; + configMeta: ConfigMeta; + context: Context; + }) { + this.soul = opts.soul; + this.agent = opts.agent; + this.session = opts.session; + this.config = opts.config; + this.configMeta = opts.configMeta; + this.context = opts.context; + } + + // ── Factory ────────────────────────────────────── + + static async create(opts: { + workDir?: string; + additionalDirs?: string[]; + configFile?: string; + modelName?: string; + thinking?: boolean; + yolo?: boolean; + sessionId?: string; + continueSession?: boolean; + maxStepsPerTurn?: number; + callbacks?: SoulCallbacks; + }): Promise { + const workDir = opts.workDir ?? process.cwd(); + + // 1. Load config + const { config, meta: configMeta } = await loadConfig(opts.configFile); + + // Override settings from CLI flags + if (opts.maxStepsPerTurn) { + config.loop_control.max_steps_per_turn = opts.maxStepsPerTurn; + } + if (opts.yolo) { + config.default_yolo = true; + } + + // 2. Determine model + const modelName = opts.modelName ?? config.default_model; + let llm: LLM | null = null; + + if (modelName && config.models[modelName]) { + const modelConfig = config.models[modelName]!; + const providerName = modelConfig.provider; + const providerConfig = config.providers[providerName]; + + if (providerConfig) { + // Resolve API key: if OAuth is configured, load the access token + let apiKey = providerConfig.api_key; + if (providerConfig.oauth) { + const token = await loadTokens(providerConfig.oauth); + if (token) { + apiKey = token.access_token; + } + } + + // Build platform identification headers (matches Python _kimi_default_headers) + const platformHeaders = await commonHeaders(); + const mergedCustomHeaders: Record = { + "User-Agent": `KimiCLI/2.0.0`, + ...platformHeaders, + ...(providerConfig.custom_headers ?? {}), + }; + + // Convert snake_case config to camelCase LLM interface + const llmProvider = { + type: providerConfig.type as any, + baseUrl: providerConfig.base_url, + apiKey, + customHeaders: mergedCustomHeaders, + env: providerConfig.env, + oauth: providerConfig.oauth?.key ?? null, + }; + const llmModel = { + model: modelConfig.model, + provider: modelConfig.provider, + maxContextSize: modelConfig.max_context_size, + capabilities: modelConfig.capabilities, + }; + + // Apply env var overrides + augmentProviderWithEnvVars(llmProvider, llmModel); + + llm = createLLM(llmProvider, llmModel, { + thinking: opts.thinking ?? config.default_thinking, + }); + } + } + + // Fallback: create LLM from environment variables directly + // Supports: KIMI_BASE_URL, KIMI_API_KEY, KIMI_MODEL_NAME + if (!llm) { + const envBaseUrl = process.env.KIMI_BASE_URL; + const envApiKey = process.env.KIMI_API_KEY; + const envModel = process.env.KIMI_MODEL_NAME; + + if (envBaseUrl && envApiKey && envModel) { + const envPlatformHeaders = await commonHeaders(); + const llmProvider = { + type: "kimi" as const, + baseUrl: envBaseUrl, + apiKey: envApiKey, + customHeaders: { + "User-Agent": `KimiCLI/2.0.0`, + ...envPlatformHeaders, + } as Record, + }; + const llmModel = { + model: envModel, + provider: "env", + maxContextSize: parseInt( + process.env.KIMI_MODEL_MAX_CONTEXT_SIZE ?? "131072", + 10, + ), + capabilities: undefined as any, + }; + + llm = createLLM(llmProvider, llmModel, { + thinking: opts.thinking ?? config.default_thinking, + }); + + if (llm) { + logger.info( + `LLM from env: ${envModel} @ ${envBaseUrl}`, + ); + } + } + } + + if (!llm) { + logger.warn( + `No LLM configured for model "${modelName}". ` + + "Set up a model in ~/.kimi/config.toml", + ); + } + + // 3. Create/restore session + let session: Session; + if (opts.sessionId) { + const found = await Session.find(workDir, opts.sessionId); + session = found ?? (await Session.create(workDir)); + } else if (opts.continueSession) { + const continued = await Session.continue_(workDir); + if (continued) { + session = continued; + logger.info(`Continuing session ${session.id}`); + } else { + session = await Session.create(workDir); + logger.info("No previous session found, starting new session"); + } + } else { + session = await Session.create(workDir); + } + + // Store additional dirs in session state + if (opts.additionalDirs && opts.additionalDirs.length > 0) { + session.state.additional_dirs = opts.additionalDirs.map((d) => + d.startsWith("/") ? d : `${workDir}/${d}`, + ); + } + + // 4. Create hook engine + const hookEngine = new HookEngine({ + hooks: config.hooks, + cwd: workDir, + }); + + // 5. Create runtime + const runtime = await Runtime.create({ + config, + llm, + session, + hookEngine, + }); + + // 6. Load agent + const agent = await loadAgent({ runtime }); + + // 7. Create/restore context + const context = new Context(session.contextFile); + await context.restore(); + + // 8. Write system prompt if new context; otherwise use restored prompt + if (!context.systemPrompt) { + await context.writeSystemPrompt(agent.systemPrompt); + } else { + // On session continuation, use the system prompt from the restored context + // to ensure consistency (the prompt may have changed between versions) + (agent as any).systemPrompt = context.systemPrompt; + } + + // 9. Create KimiSoul + const soul = new KimiSoul({ + agent, + context, + callbacks: opts.callbacks ?? {}, + }); + + // Wire slash commands + soul.wireSlashCommands(); + + // Wire tool context (plan mode, ask user, etc.) + soul.wireToolContext(); + + return new KimiCLI({ + soul, + agent, + session, + config, + configMeta, + context, + }); + } + + // ── Run modes ──────────────────────────────────── + + /** + * Run in interactive shell mode (React Ink TUI). + */ + async runShell(initialCommand?: string): Promise { + // This will be called from cli/index.ts with Ink rendering + // The shell component will call soul.run() directly + if (initialCommand) { + await this.soul.run(initialCommand); + } + return true; // continue + } + + /** + * Run in print mode (non-interactive). + */ + async runPrint(input: string): Promise { + await this.soul.run(input); + } + + // ── Lifecycle ────────────────────────────────── + + async shutdown(): Promise { + this.soul.abort(); + await this.agent.toolset.cleanup(); + + // Clean up empty sessions (no real messages exchanged) + if (await this.session.isEmpty()) { + await this.session.delete(); + logger.debug("Deleted empty session"); + } else { + await this.session.saveState(); + } + + logger.info("KimiCLI shutdown complete"); + } +} diff --git a/src/kimi_cli/approval_runtime/__init__.py b/src/kimi_cli/approval_runtime/__init__.py deleted file mode 100644 index dfcbf0f57..000000000 --- a/src/kimi_cli/approval_runtime/__init__.py +++ /dev/null @@ -1,29 +0,0 @@ -from kimi_cli.approval_runtime.models import ( - ApprovalRequestRecord, - ApprovalResponseKind, - ApprovalRuntimeEvent, - ApprovalSource, - ApprovalSourceKind, - ApprovalStatus, -) -from kimi_cli.approval_runtime.runtime import ( - ApprovalCancelledError, - ApprovalRuntime, - get_current_approval_source_or_none, - reset_current_approval_source, - set_current_approval_source, -) - -__all__ = [ - "ApprovalCancelledError", - "ApprovalRequestRecord", - "ApprovalResponseKind", - "ApprovalRuntime", - "ApprovalRuntimeEvent", - "ApprovalSource", - "ApprovalSourceKind", - "ApprovalStatus", - "get_current_approval_source_or_none", - "reset_current_approval_source", - "set_current_approval_source", -] diff --git a/src/kimi_cli/approval_runtime/index.ts b/src/kimi_cli/approval_runtime/index.ts new file mode 100644 index 000000000..4f147c724 --- /dev/null +++ b/src/kimi_cli/approval_runtime/index.ts @@ -0,0 +1,194 @@ +/** + * Approval runtime — corresponds to Python approval_runtime/ + * Manages approval requests lifecycle: create, wait, resolve, cancel. + */ + +import { randomUUID } from "node:crypto"; +import { logger } from "../utils/logging.ts"; + +// ── Types ─────────────────────────────────────────────── + +export type ApprovalResponseKind = "approve" | "approve_for_session" | "reject"; +export type ApprovalSourceKind = "foreground_turn" | "background_agent"; +export type ApprovalStatus = "pending" | "resolved" | "cancelled"; +export type ApprovalRuntimeEventKind = "request_created" | "request_resolved"; + +export interface ApprovalSource { + kind: ApprovalSourceKind; + id: string; + agentId?: string; + subagentType?: string; +} + +export interface ApprovalRequestRecord { + id: string; + toolCallId: string; + sender: string; + action: string; + description: string; + display: unknown[]; + source: ApprovalSource; + createdAt: number; + status: ApprovalStatus; + resolvedAt: number | null; + response: ApprovalResponseKind | null; + feedback: string; +} + +export interface ApprovalRuntimeEvent { + kind: ApprovalRuntimeEventKind; + request: ApprovalRequestRecord; +} + +// ── Errors ────────────────────────────────────────────── + +export class ApprovalCancelledError extends Error { + constructor(requestId: string) { + super(`Approval cancelled: ${requestId}`); + this.name = "ApprovalCancelledError"; + } +} + +// ── Waiter (promise-based future) ─────────────────────── + +interface Waiter { + resolve: (value: [ApprovalResponseKind, string]) => void; + reject: (reason: Error) => void; + promise: Promise<[ApprovalResponseKind, string]>; +} + +function createWaiter(): Waiter { + let resolve!: Waiter["resolve"]; + let reject!: Waiter["reject"]; + const promise = new Promise<[ApprovalResponseKind, string]>((res, rej) => { + resolve = res; + reject = rej; + }); + return { resolve, reject, promise }; +} + +// ── Runtime ───────────────────────────────────────────── + +export type EventSubscriber = (event: ApprovalRuntimeEvent) => void; + +export class ApprovalRuntime { + private requests = new Map(); + private waiters = new Map(); + private subscribers = new Map(); + + createRequest(opts: { + requestId?: string; + toolCallId: string; + sender: string; + action: string; + description: string; + display?: unknown[]; + source: ApprovalSource; + }): ApprovalRequestRecord { + const request: ApprovalRequestRecord = { + id: opts.requestId ?? randomUUID(), + toolCallId: opts.toolCallId, + sender: opts.sender, + action: opts.action, + description: opts.description, + display: opts.display ?? [], + source: opts.source, + createdAt: Date.now() / 1000, + status: "pending", + resolvedAt: null, + response: null, + feedback: "", + }; + this.requests.set(request.id, request); + this.publishEvent({ kind: "request_created", request }); + return request; + } + + async waitForResponse(requestId: string): Promise<[ApprovalResponseKind, string]> { + const request = this.requests.get(requestId); + if (!request) throw new Error(`Approval request not found: ${requestId}`); + + if (request.status === "cancelled") { + throw new ApprovalCancelledError(requestId); + } + if (request.status === "resolved" && request.response) { + return [request.response, request.feedback]; + } + + let waiter = this.waiters.get(requestId); + if (!waiter) { + waiter = createWaiter(); + this.waiters.set(requestId, waiter); + } + return waiter.promise; + } + + resolve(requestId: string, response: ApprovalResponseKind, feedback = ""): boolean { + const request = this.requests.get(requestId); + if (!request || request.status !== "pending") return false; + + request.status = "resolved"; + request.response = response; + request.feedback = feedback; + request.resolvedAt = Date.now() / 1000; + + const waiter = this.waiters.get(requestId); + if (waiter) { + waiter.resolve([response, feedback]); + this.waiters.delete(requestId); + } + this.publishEvent({ kind: "request_resolved", request }); + return true; + } + + cancelBySource(sourceKind: ApprovalSourceKind, sourceId: string): number { + let cancelled = 0; + for (const [requestId, request] of this.requests) { + if (request.status !== "pending") continue; + if (request.source.kind !== sourceKind || request.source.id !== sourceId) continue; + + request.status = "cancelled"; + request.response = "reject"; + request.resolvedAt = Date.now() / 1000; + + const waiter = this.waiters.get(requestId); + if (waiter) { + waiter.reject(new ApprovalCancelledError(requestId)); + this.waiters.delete(requestId); + } + this.publishEvent({ kind: "request_resolved", request }); + cancelled++; + } + return cancelled; + } + + listPending(): ApprovalRequestRecord[] { + return [...this.requests.values()] + .filter((r) => r.status === "pending") + .sort((a, b) => a.createdAt - b.createdAt); + } + + getRequest(requestId: string): ApprovalRequestRecord | undefined { + return this.requests.get(requestId); + } + + subscribe(callback: EventSubscriber): string { + const token = randomUUID(); + this.subscribers.set(token, callback); + return token; + } + + unsubscribe(token: string): void { + this.subscribers.delete(token); + } + + private publishEvent(event: ApprovalRuntimeEvent): void { + for (const cb of this.subscribers.values()) { + try { + cb(event); + } catch (err) { + logger.error("Approval runtime event subscriber failed", err); + } + } + } +} diff --git a/src/kimi_cli/approval_runtime/models.py b/src/kimi_cli/approval_runtime/models.py deleted file mode 100644 index 4409c1571..000000000 --- a/src/kimi_cli/approval_runtime/models.py +++ /dev/null @@ -1,42 +0,0 @@ -from __future__ import annotations - -import time -from dataclasses import dataclass, field -from typing import Literal - -from kimi_cli.wire.types import DisplayBlock - -type ApprovalResponseKind = Literal["approve", "approve_for_session", "reject"] -type ApprovalSourceKind = Literal["foreground_turn", "background_agent"] -type ApprovalStatus = Literal["pending", "resolved", "cancelled"] -type ApprovalRuntimeEventKind = Literal["request_created", "request_resolved"] - - -@dataclass(frozen=True, slots=True, kw_only=True) -class ApprovalSource: - kind: ApprovalSourceKind - id: str - agent_id: str | None = None - subagent_type: str | None = None - - -@dataclass(slots=True, kw_only=True) -class ApprovalRequestRecord: - id: str - tool_call_id: str - sender: str - action: str - description: str - display: list[DisplayBlock] - source: ApprovalSource - created_at: float = field(default_factory=time.time) - status: ApprovalStatus = "pending" - resolved_at: float | None = None - response: ApprovalResponseKind | None = None - feedback: str = "" - - -@dataclass(frozen=True, slots=True, kw_only=True) -class ApprovalRuntimeEvent: - kind: ApprovalRuntimeEventKind - request: ApprovalRequestRecord diff --git a/src/kimi_cli/approval_runtime/runtime.py b/src/kimi_cli/approval_runtime/runtime.py deleted file mode 100644 index b310e0fc9..000000000 --- a/src/kimi_cli/approval_runtime/runtime.py +++ /dev/null @@ -1,189 +0,0 @@ -from __future__ import annotations - -import asyncio -import uuid -from contextvars import ContextVar, Token -from typing import TYPE_CHECKING - -from kimi_cli.utils.logging import logger -from kimi_cli.wire.types import ApprovalRequest, ApprovalResponse - -from .models import ( - ApprovalRequestRecord, - ApprovalResponseKind, - ApprovalRuntimeEvent, - ApprovalSource, - ApprovalSourceKind, -) - -if TYPE_CHECKING: - from collections.abc import Callable - - from kimi_cli.wire.root_hub import RootWireHub - from kimi_cli.wire.types import DisplayBlock - - -class ApprovalCancelledError(Exception): - """Raised when a pending approval is cancelled by its source lifecycle.""" - - -_current_approval_source = ContextVar[ApprovalSource | None]( - "current_approval_source", - default=None, -) - - -def get_current_approval_source_or_none() -> ApprovalSource | None: - return _current_approval_source.get() - - -def set_current_approval_source(source: ApprovalSource) -> Token[ApprovalSource | None]: - return _current_approval_source.set(source) - - -def reset_current_approval_source(token: Token[ApprovalSource | None]) -> None: - _current_approval_source.reset(token) - - -class ApprovalRuntime: - def __init__(self) -> None: - self._requests: dict[str, ApprovalRequestRecord] = {} - self._waiters: dict[str, asyncio.Future[tuple[ApprovalResponseKind, str]]] = {} - self._subscribers: dict[str, Callable[[ApprovalRuntimeEvent], None]] = {} - self._root_wire_hub: RootWireHub | None = None - - def bind_root_wire_hub(self, root_wire_hub: RootWireHub) -> None: - if self._root_wire_hub is root_wire_hub: - return - self._root_wire_hub = root_wire_hub - - def create_request( - self, - *, - sender: str, - action: str, - description: str, - tool_call_id: str, - display: list[DisplayBlock], - source: ApprovalSource, - request_id: str | None = None, - ) -> ApprovalRequestRecord: - request = ApprovalRequestRecord( - id=request_id or str(uuid.uuid4()), - tool_call_id=tool_call_id, - sender=sender, - action=action, - description=description, - display=display, - source=source, - ) - self._requests[request.id] = request - self._publish_event(ApprovalRuntimeEvent(kind="request_created", request=request)) - self._publish_wire_request(request) - return request - - async def wait_for_response(self, request_id: str) -> tuple[ApprovalResponseKind, str]: - waiter = self._waiters.get(request_id) - request = self._requests.get(request_id) - if request is None: - raise KeyError(f"Approval request not found: {request_id}") - if waiter is None: - if request.status == "cancelled": - raise ApprovalCancelledError(request_id) - if request.status == "resolved": - assert request.response is not None - return request.response, request.feedback - waiter = asyncio.get_running_loop().create_future() - self._waiters[request_id] = waiter - return await waiter - - def resolve(self, request_id: str, response: ApprovalResponseKind, feedback: str = "") -> bool: - request = self._requests.get(request_id) - if request is None or request.status != "pending": - return False - request.status = "resolved" - request.response = response - request.feedback = feedback - import time - - request.resolved_at = time.time() - waiter = self._waiters.pop(request_id, None) - if waiter is not None and not waiter.done(): - waiter.set_result((response, feedback)) - self._publish_event(ApprovalRuntimeEvent(kind="request_resolved", request=request)) - self._publish_wire_response(request_id, response, feedback) - return True - - def cancel_by_source(self, source_kind: ApprovalSourceKind, source_id: str) -> int: - cancelled = 0 - import time - - for request_id, request in self._requests.items(): - if request.status != "pending": - continue - if request.source.kind != source_kind or request.source.id != source_id: - continue - request.status = "cancelled" - request.response = "reject" - request.resolved_at = time.time() - waiter = self._waiters.pop(request_id, None) - if waiter is not None and not waiter.done(): - waiter.set_exception(ApprovalCancelledError(request_id)) - self._publish_event(ApprovalRuntimeEvent(kind="request_resolved", request=request)) - self._publish_wire_response(request_id, "reject") - cancelled += 1 - return cancelled - - def list_pending(self) -> list[ApprovalRequestRecord]: - pending = [request for request in self._requests.values() if request.status == "pending"] - pending.sort(key=lambda request: request.created_at) - return pending - - def get_request(self, request_id: str) -> ApprovalRequestRecord | None: - return self._requests.get(request_id) - - def subscribe(self, callback: Callable[[ApprovalRuntimeEvent], None]) -> str: - token = uuid.uuid4().hex - self._subscribers[token] = callback - return token - - def unsubscribe(self, token: str) -> None: - self._subscribers.pop(token, None) - - def _publish_event(self, event: ApprovalRuntimeEvent) -> None: - for callback in list(self._subscribers.values()): - try: - callback(event) - except Exception: - logger.exception("Approval runtime event subscriber failed") - - def _publish_wire_request(self, request: ApprovalRequestRecord) -> None: - if self._root_wire_hub is None: - return - self._root_wire_hub.publish_nowait( - ApprovalRequest( - id=request.id, - tool_call_id=request.tool_call_id, - sender=request.sender, - action=request.action, - description=request.description, - display=request.display, - source_kind=request.source.kind, - source_id=request.source.id, - agent_id=request.source.agent_id, - subagent_type=request.source.subagent_type, - ) - ) - - def _publish_wire_response( - self, request_id: str, response: ApprovalResponseKind, feedback: str = "" - ) -> None: - if self._root_wire_hub is None: - return - self._root_wire_hub.publish_nowait( - ApprovalResponse( - request_id=request_id, - response=response, - feedback=feedback, - ) - ) diff --git a/src/kimi_cli/auth/__init__.py b/src/kimi_cli/auth/__init__.py deleted file mode 100644 index aadb45556..000000000 --- a/src/kimi_cli/auth/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from __future__ import annotations - -KIMI_CODE_PLATFORM_ID = "kimi-code" - -__all__ = ["KIMI_CODE_PLATFORM_ID"] diff --git a/src/kimi_cli/auth/index.ts b/src/kimi_cli/auth/index.ts new file mode 100644 index 000000000..b59d75d78 --- /dev/null +++ b/src/kimi_cli/auth/index.ts @@ -0,0 +1,5 @@ +/** + * Auth module — corresponds to Python auth/__init__.py + */ + +export { KIMI_CODE_PLATFORM_ID } from "./platforms.ts"; diff --git a/src/kimi_cli/auth/oauth.py b/src/kimi_cli/auth/oauth.py deleted file mode 100644 index 5c8de336b..000000000 --- a/src/kimi_cli/auth/oauth.py +++ /dev/null @@ -1,804 +0,0 @@ -from __future__ import annotations - -import asyncio -import json -import os -import platform -import socket -import sys -import time -import uuid -import webbrowser -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager, suppress -from dataclasses import dataclass -from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal, cast - -import aiohttp -import keyring -from pydantic import SecretStr - -from kimi_cli.auth import KIMI_CODE_PLATFORM_ID -from kimi_cli.auth.platforms import ( - ModelInfo, - get_platform_by_id, - list_models, - managed_model_key, - managed_provider_key, -) -from kimi_cli.config import ( - Config, - LLMModel, - LLMProvider, - MoonshotFetchConfig, - MoonshotSearchConfig, - OAuthRef, - save_config, -) -from kimi_cli.constant import VERSION -from kimi_cli.share import get_share_dir -from kimi_cli.utils.aiohttp import new_client_session -from kimi_cli.utils.logging import logger - -if TYPE_CHECKING: - from kimi_cli.soul.agent import Runtime - - -KIMI_CODE_CLIENT_ID = "17e5f671-d194-4dfb-9706-5516cb48c098" -KIMI_CODE_OAUTH_KEY = "oauth/kimi-code" -DEFAULT_OAUTH_HOST = "https://auth.kimi.com" -KEYRING_SERVICE = "kimi-code" -REFRESH_INTERVAL_SECONDS = 60 -REFRESH_THRESHOLD_SECONDS = 300 - - -class OAuthError(RuntimeError): - """OAuth flow error.""" - - -class OAuthUnauthorized(OAuthError): - """OAuth credentials rejected.""" - - -class OAuthDeviceExpired(OAuthError): - """Device authorization expired.""" - - -OAuthEventKind = Literal["info", "error", "waiting", "verification_url", "success"] - - -@dataclass(slots=True, frozen=True) -class OAuthEvent: - type: OAuthEventKind - message: str - data: dict[str, Any] | None = None - - def __str__(self) -> str: - return self.message - - @property - def json(self) -> str: - payload: dict[str, Any] = {"type": self.type, "message": self.message} - if self.data is not None: - payload["data"] = self.data - return json.dumps(payload, ensure_ascii=False) - - -@dataclass(slots=True) -class OAuthToken: - access_token: str - refresh_token: str - expires_at: float - scope: str - token_type: str - - @classmethod - def from_response(cls, payload: dict[str, Any]) -> OAuthToken: - expires_in = float(payload["expires_in"]) - return cls( - access_token=str(payload["access_token"]), - refresh_token=str(payload["refresh_token"]), - expires_at=time.time() + expires_in, - scope=str(payload["scope"]), - token_type=str(payload["token_type"]), - ) - - def to_dict(self) -> dict[str, Any]: - return { - "access_token": self.access_token, - "refresh_token": self.refresh_token, - "expires_at": self.expires_at, - "scope": self.scope, - "token_type": self.token_type, - } - - @classmethod - def from_dict(cls, payload: dict[str, Any]) -> OAuthToken: - expires_at_value = payload.get("expires_at") - return cls( - access_token=str(payload.get("access_token") or ""), - refresh_token=str(payload.get("refresh_token") or ""), - expires_at=float(expires_at_value) if expires_at_value is not None else 0.0, - scope=str(payload.get("scope") or ""), - token_type=str(payload.get("token_type") or ""), - ) - - -@dataclass(slots=True) -class DeviceAuthorization: - user_code: str - device_code: str - verification_uri: str - verification_uri_complete: str - expires_in: int | None - interval: int - - -def _oauth_host() -> str: - return os.getenv("KIMI_CODE_OAUTH_HOST") or os.getenv("KIMI_OAUTH_HOST") or DEFAULT_OAUTH_HOST - - -def _device_id_path() -> Path: - return get_share_dir() / "device_id" - - -def _ensure_private_file(path: Path) -> None: - with suppress(OSError): - os.chmod(path, 0o600) - - -def _device_model() -> str: - system = platform.system() - arch = platform.machine() or "" - if system == "Darwin": - version = platform.mac_ver()[0] or platform.release() - if version and arch: - return f"macOS {version} {arch}" - if version: - return f"macOS {version}" - return f"macOS {arch}".strip() - if system == "Windows": - release = platform.release() - if release == "10": - try: - build = sys.getwindowsversion().build # type: ignore[attr-defined] - except Exception: - build = None - if build and build >= 22000: - release = "11" - if release and arch: - return f"Windows {release} {arch}" - if release: - return f"Windows {release}" - return f"Windows {arch}".strip() - if system: - version = platform.release() - if version and arch: - return f"{system} {version} {arch}" - if version: - return f"{system} {version}" - return f"{system} {arch}".strip() - return "Unknown" - - -def get_device_id() -> str: - path = _device_id_path() - if path.exists(): - return path.read_text(encoding="utf-8").strip() - device_id = uuid.uuid4().hex - path.write_text(device_id, encoding="utf-8") - _ensure_private_file(path) - return device_id - - -def _ascii_header_value(value: str, *, fallback: str = "unknown") -> str: - try: - value.encode("ascii") - return value.strip() - except UnicodeEncodeError: - sanitized = value.encode("ascii", errors="ignore").decode("ascii").strip() - return sanitized or fallback - - -def _common_headers() -> dict[str, str]: - device_name = platform.node() or socket.gethostname() - device_model = _device_model() - headers = { - "X-Msh-Platform": "kimi_cli", - "X-Msh-Version": VERSION, - "X-Msh-Device-Name": device_name, - "X-Msh-Device-Model": device_model, - "X-Msh-Os-Version": platform.version(), - "X-Msh-Device-Id": get_device_id(), - } - return {key: _ascii_header_value(value) for key, value in headers.items()} - - -def _credentials_dir() -> Path: - path = get_share_dir() / "credentials" - path.mkdir(parents=True, exist_ok=True) - return path - - -def _credentials_path(key: str) -> Path: - name = key.removeprefix("oauth/").split("/")[-1] or key - return _credentials_dir() / f"{name}.json" - - -def _load_from_keyring(key: str) -> OAuthToken | None: - try: - raw = keyring.get_password(KEYRING_SERVICE, key) - except Exception as exc: - logger.warning("Failed to read token from keyring: {error}", error=exc) - return None - if not raw: - return None - try: - payload = json.loads(raw) - except json.JSONDecodeError: - return None - if not isinstance(payload, dict): - return None - payload = cast(dict[str, Any], payload) - return OAuthToken.from_dict(payload) - - -def _delete_from_keyring(key: str) -> None: - try: - keyring.delete_password(KEYRING_SERVICE, key) - except Exception: - return - - -def _load_from_file(key: str) -> OAuthToken | None: - path = _credentials_path(key) - if not path.exists(): - return None - try: - payload = json.loads(path.read_text(encoding="utf-8")) - except json.JSONDecodeError: - return None - if not isinstance(payload, dict): - return None - payload = cast(dict[str, Any], payload) - return OAuthToken.from_dict(payload) - - -def _save_to_file(key: str, token: OAuthToken) -> None: - path = _credentials_path(key) - path.write_text(json.dumps(token.to_dict(), ensure_ascii=False), encoding="utf-8") - _ensure_private_file(path) - - -def _delete_from_file(key: str) -> None: - path = _credentials_path(key) - if path.exists(): - path.unlink() - - -def load_tokens(ref: OAuthRef) -> OAuthToken | None: - file_token = _load_from_file(ref.key) - if file_token is not None: - return file_token - if ref.storage != "keyring": - return None - token = _load_from_keyring(ref.key) - if token is None: - return None - try: - _save_to_file(ref.key, token) - except OSError as exc: - logger.warning("Failed to migrate token from keyring to file: {error}", error=exc) - else: - with suppress(Exception): - _delete_from_keyring(ref.key) - return token - - -def save_tokens(ref: OAuthRef, token: OAuthToken) -> OAuthRef: - if ref.storage == "keyring": - logger.warning("Keyring storage is deprecated; saving OAuth tokens to file.") - ref = OAuthRef(storage="file", key=ref.key) - _save_to_file(ref.key, token) - return ref - - -def delete_tokens(ref: OAuthRef) -> None: - if ref.storage == "keyring": - _delete_from_keyring(ref.key) - _delete_from_file(ref.key) - - -async def request_device_authorization() -> DeviceAuthorization: - async with ( - new_client_session() as session, - session.post( - f"{_oauth_host().rstrip('/')}/api/oauth/device_authorization", - data={"client_id": KIMI_CODE_CLIENT_ID}, - headers=_common_headers(), - ) as response, - ): - data = await response.json(content_type=None) - status = response.status - if status != 200: - raise OAuthError(f"Device authorization failed: {data}") - return DeviceAuthorization( - user_code=str(data["user_code"]), - device_code=str(data["device_code"]), - verification_uri=str(data.get("verification_uri") or ""), - verification_uri_complete=str(data["verification_uri_complete"]), - expires_in=int(data.get("expires_in") or 0) or None, - interval=int(data.get("interval") or 5), - ) - - -async def _request_device_token(auth: DeviceAuthorization) -> tuple[int, dict[str, Any]]: - try: - async with ( - new_client_session() as session, - session.post( - f"{_oauth_host().rstrip('/')}/api/oauth/token", - data={ - "client_id": KIMI_CODE_CLIENT_ID, - "device_code": auth.device_code, - "grant_type": "urn:ietf:params:oauth:grant-type:device_code", - }, - headers=_common_headers(), - ) as response, - ): - data_any: Any = await response.json(content_type=None) - status = response.status - except aiohttp.ClientError as exc: - raise OAuthError("Token polling request failed.") from exc - if not isinstance(data_any, dict): - raise OAuthError("Unexpected token polling response.") - data = cast(dict[str, Any], data_any) - if status >= 500: - raise OAuthError(f"Token polling server error: {status}.") - return status, data - - -async def refresh_token(refresh_token: str) -> OAuthToken: - async with ( - new_client_session() as session, - session.post( - f"{_oauth_host().rstrip('/')}/api/oauth/token", - data={ - "client_id": KIMI_CODE_CLIENT_ID, - "grant_type": "refresh_token", - "refresh_token": refresh_token, - }, - headers=_common_headers(), - ) as response, - ): - data = await response.json(content_type=None) - status = response.status - if status in (401, 403): - raise OAuthUnauthorized(data.get("error_description") or "Token refresh unauthorized.") - if status != 200: - raise OAuthError(data.get("error_description") or "Token refresh failed.") - return OAuthToken.from_response(data) - - -def _select_default_model_and_thinking(models: list[ModelInfo]) -> tuple[ModelInfo, bool] | None: - if not models: - return None - selected_model = models[0] - capabilities = selected_model.capabilities - thinking = "thinking" in capabilities or "always_thinking" in capabilities - return selected_model, thinking - - -def _apply_kimi_code_config( - config: Config, - *, - models: list[ModelInfo], - selected_model: ModelInfo, - thinking: bool, - oauth_ref: OAuthRef, -) -> None: - platform = get_platform_by_id(KIMI_CODE_PLATFORM_ID) - if platform is None: - raise OAuthError("Kimi Code platform not found.") - - provider_key = managed_provider_key(platform.id) - config.providers[provider_key] = LLMProvider( - type="kimi", - base_url=platform.base_url, - api_key=SecretStr(""), - oauth=oauth_ref, - ) - - for key, model in list(config.models.items()): - if model.provider == provider_key: - del config.models[key] - - for model_info in models: - capabilities = model_info.capabilities or None - config.models[managed_model_key(platform.id, model_info.id)] = LLMModel( - provider=provider_key, - model=model_info.id, - max_context_size=model_info.context_length, - capabilities=capabilities, - ) - - config.default_model = managed_model_key(platform.id, selected_model.id) - config.default_thinking = thinking - - if platform.search_url: - config.services.moonshot_search = MoonshotSearchConfig( - base_url=platform.search_url, - api_key=SecretStr(""), - oauth=oauth_ref, - ) - - if platform.fetch_url: - config.services.moonshot_fetch = MoonshotFetchConfig( - base_url=platform.fetch_url, - api_key=SecretStr(""), - oauth=oauth_ref, - ) - - -async def login_kimi_code( - config: Config, *, open_browser: bool = True -) -> AsyncIterator[OAuthEvent]: - if not config.is_from_default_location: - yield OAuthEvent( - "error", - "Login requires the default config file; restart without --config/--config-file.", - ) - return - - platform = get_platform_by_id(KIMI_CODE_PLATFORM_ID) - if platform is None: - yield OAuthEvent("error", "Kimi Code platform is unavailable.") - return - - auth: DeviceAuthorization - token: OAuthToken | None = None - while True: - try: - auth = await request_device_authorization() - except Exception as exc: - yield OAuthEvent("error", f"Login failed: {exc}") - return - - yield OAuthEvent( - "info", - "Please visit the following URL to finish authorization.", - ) - yield OAuthEvent( - "verification_url", - f"Verification URL: {auth.verification_uri_complete}", - data={ - "verification_url": auth.verification_uri_complete, - "user_code": auth.user_code, - }, - ) - if open_browser: - try: - webbrowser.open(auth.verification_uri_complete) - except Exception as exc: - logger.warning("Failed to open browser: {error}", error=exc) - - interval = max(auth.interval, 1) - printed_wait = False - try: - while True: - status, data = await _request_device_token(auth) - if status == 200 and "access_token" in data: - token = OAuthToken.from_response(data) - break - error_code = str(data.get("error") or "unknown_error") - if error_code == "expired_token": - raise OAuthDeviceExpired("Device code expired.") - error_description = str(data.get("error_description") or "") - if not printed_wait: - yield OAuthEvent( - "waiting", - f"Waiting for user authorization...: {error_description.strip()}", - data={ - "error": error_code, - "error_description": error_description, - }, - ) - printed_wait = True - await asyncio.sleep(interval) - except OAuthDeviceExpired: - yield OAuthEvent("info", "Device code expired, restarting login...") - continue - except Exception as exc: - yield OAuthEvent("error", f"Login failed: {exc}") - return - break - - assert token is not None - - oauth_ref = OAuthRef(storage="file", key=KIMI_CODE_OAUTH_KEY) - oauth_ref = save_tokens(oauth_ref, token) - - try: - models = await list_models(platform, token.access_token) - except Exception as exc: - logger.error("Failed to get models: {error}", error=exc) - yield OAuthEvent("error", f"Failed to get models: {exc}") - return - - if not models: - yield OAuthEvent("error", "No models available for the selected platform.") - return - - selection = _select_default_model_and_thinking(models) - if selection is None: - return - selected_model, thinking = selection - - _apply_kimi_code_config( - config, - models=models, - selected_model=selected_model, - thinking=thinking, - oauth_ref=oauth_ref, - ) - save_config(config) - yield OAuthEvent("success", "Logged in successfully.") - return - - -async def logout_kimi_code(config: Config) -> AsyncIterator[OAuthEvent]: - if not config.is_from_default_location: - yield OAuthEvent( - "error", - "Logout requires the default config file; restart without --config/--config-file.", - ) - return - - delete_tokens(OAuthRef(storage="keyring", key=KIMI_CODE_OAUTH_KEY)) - delete_tokens(OAuthRef(storage="file", key=KIMI_CODE_OAUTH_KEY)) - - provider_key = managed_provider_key(KIMI_CODE_PLATFORM_ID) - if provider_key in config.providers: - del config.providers[provider_key] - - removed_default = False - for key, model in list(config.models.items()): - if model.provider != provider_key: - continue - del config.models[key] - if config.default_model == key: - removed_default = True - - if removed_default: - config.default_model = "" - - config.services.moonshot_search = None - config.services.moonshot_fetch = None - - save_config(config) - yield OAuthEvent("success", "Logged out successfully.") - return - - -class OAuthManager: - def __init__(self, config: Config) -> None: - self._config = config - # Cache access tokens only; refresh tokens are always read from persisted storage. - self._access_tokens: dict[str, str] = {} - self._refresh_lock = asyncio.Lock() - self._migrate_oauth_storage() - self._load_initial_tokens() - - def _iter_oauth_refs(self) -> list[OAuthRef]: - refs: list[OAuthRef] = [] - for provider in self._config.providers.values(): - if provider.oauth: - refs.append(provider.oauth) - for service in ( - self._config.services.moonshot_search, - self._config.services.moonshot_fetch, - ): - if service and service.oauth: - refs.append(service.oauth) - return refs - - def _migrate_oauth_storage(self) -> None: - migrated_keys: set[str] = set() - changed = False - - def _migrate_ref(ref: OAuthRef) -> OAuthRef: - nonlocal changed - if ref.storage != "keyring": - return ref - if ref.key not in migrated_keys: - load_tokens(ref) - migrated_keys.add(ref.key) - changed = True - return OAuthRef(storage="file", key=ref.key) - - for provider in self._config.providers.values(): - if provider.oauth: - provider.oauth = _migrate_ref(provider.oauth) - - for service in ( - self._config.services.moonshot_search, - self._config.services.moonshot_fetch, - ): - if service and service.oauth: - service.oauth = _migrate_ref(service.oauth) - - if changed and self._config.is_from_default_location: - save_config(self._config) - - def _load_initial_tokens(self) -> None: - for ref in self._iter_oauth_refs(): - token = load_tokens(ref) - if token: - self._cache_access_token(ref, token) - - def _cache_access_token(self, ref: OAuthRef, token: OAuthToken) -> None: - if not token.access_token: - self._access_tokens.pop(ref.key, None) - return - self._access_tokens[ref.key] = token.access_token - - def common_headers(self) -> dict[str, str]: - return _common_headers() - - def resolve_api_key(self, api_key: SecretStr, oauth: OAuthRef | None) -> str: - if oauth: - token = self._access_tokens.get(oauth.key) - if token is None: - persisted = load_tokens(oauth) - if persisted: - self._cache_access_token(oauth, persisted) - token = self._access_tokens.get(oauth.key) - if token: - return token - logger.warning( - "OAuth ref present (key={key}) but no access token resolved; " - "falling back to configured api_key", - key=oauth.key, - ) - return api_key.get_secret_value() - - def _kimi_code_ref(self) -> OAuthRef | None: - provider_key = managed_provider_key(KIMI_CODE_PLATFORM_ID) - provider = self._config.providers.get(provider_key) - if provider and provider.oauth: - return provider.oauth - for service in ( - self._config.services.moonshot_search, - self._config.services.moonshot_fetch, - ): - if service and service.oauth and service.oauth.key == KIMI_CODE_OAUTH_KEY: - return service.oauth - return None - - async def ensure_fresh(self, runtime: Runtime | None = None) -> None: - """Load persisted tokens, cache them, and refresh if close to expiry. - - Args: - runtime: When provided the live LLM client's API key is updated - in-place. Pass ``None`` for lightweight callers (e.g. title - generation) that only need the internal cache to be current. - """ - ref = self._kimi_code_ref() - if ref is None: - return - token = load_tokens(ref) - if token is None: - return - self._cache_access_token(ref, token) - self._apply_access_token(runtime, token.access_token) - await self._refresh_tokens(ref, token, runtime) - - @asynccontextmanager - async def refreshing(self, runtime: Runtime) -> AsyncIterator[None]: - stop_event = asyncio.Event() - - async def _runner() -> None: - try: - while True: - try: - await asyncio.wait_for( - stop_event.wait(), - timeout=REFRESH_INTERVAL_SECONDS, - ) - return - except TimeoutError: - pass - try: - await self.ensure_fresh(runtime) - except Exception as exc: - logger.warning( - "Failed to refresh OAuth token in background: {error}", - error=exc, - ) - except asyncio.CancelledError: - pass - - await self.ensure_fresh(runtime) - refresh_task = asyncio.create_task(_runner()) - try: - yield - finally: - stop_event.set() - refresh_task.cancel() - with suppress(asyncio.CancelledError): - await refresh_task - - async def _refresh_tokens( - self, - ref: OAuthRef, - token: OAuthToken, - runtime: Runtime | None, - ) -> None: - # Always prefer persisted tokens before refresh to avoid stale cache - # when multiple sessions might have already rotated the refresh token. - persisted = load_tokens(ref) - if persisted: - self._cache_access_token(ref, persisted) - current_token = persisted or token - if not current_token.refresh_token: - return - async with self._refresh_lock: - # Re-check persisted token inside the lock to reduce races. - persisted = load_tokens(ref) - if persisted: - self._cache_access_token(ref, persisted) - current = persisted or current_token - now = time.time() - if ( - current.expires_at - and current.expires_at > now - and current.expires_at - now >= REFRESH_THRESHOLD_SECONDS - ): - return - refresh_token_value = current.refresh_token - if not refresh_token_value: - return - try: - refreshed = await refresh_token(refresh_token_value) - except OAuthUnauthorized as exc: - # If another session refreshed and persisted a new token, - # do not delete it. Just sync memory and exit. - latest = load_tokens(ref) - if latest and latest.refresh_token != refresh_token_value: - self._cache_access_token(ref, latest) - self._apply_access_token(runtime, latest.access_token) - return - logger.warning( - "OAuth credentials rejected, deleting stored tokens: {error}", - error=exc, - ) - self._access_tokens.pop(ref.key, None) - delete_tokens(ref) - self._apply_access_token(runtime, "") - return - except Exception as exc: - logger.warning("Failed to refresh OAuth token: {error}", error=exc) - return - save_tokens(ref, refreshed) - self._cache_access_token(ref, refreshed) - self._apply_access_token(runtime, refreshed.access_token) - - def _apply_access_token(self, runtime: Runtime | None, access_token: str) -> None: - if runtime is None: - return - provider_key = managed_provider_key(KIMI_CODE_PLATFORM_ID) - if runtime.llm is None or runtime.llm.model_config is None: - return - if runtime.llm.model_config.provider != provider_key: - return - from kosong.chat_provider.kimi import Kimi - - assert isinstance(runtime.llm.chat_provider, Kimi), "Expected Kimi chat provider" - runtime.llm.chat_provider.client.api_key = access_token - - -if __name__ == "__main__": - from rich import print - - print(_common_headers()) diff --git a/src/kimi_cli/auth/oauth.ts b/src/kimi_cli/auth/oauth.ts new file mode 100644 index 000000000..17405bfcd --- /dev/null +++ b/src/kimi_cli/auth/oauth.ts @@ -0,0 +1,560 @@ +/** + * OAuth module — corresponds to Python auth/oauth.py + * Device-code OAuth flow, token storage & refresh for Kimi Code. + */ + +import { join } from "node:path"; +import { randomUUID } from "node:crypto"; +import { hostname, platform, arch, release } from "node:os"; +import { getShareDir, type Config, saveConfig } from "../config.ts"; +import type { OAuthRef } from "../config.ts"; +import { getVersion } from "../constant.ts"; +import { + KIMI_CODE_PLATFORM_ID, + getPlatformById, + listModels, + managedProviderKey, + managedModelKey, + deriveModelCapabilities, + type ModelInfo, +} from "./platforms.ts"; +import { logger } from "../utils/logging.ts"; + +// ── Constants ─────────────────────────────────────────── + +const KIMI_CODE_CLIENT_ID = "17e5f671-d194-4dfb-9706-5516cb48c098"; +export const KIMI_CODE_OAUTH_KEY = "oauth/kimi-code"; +const DEFAULT_OAUTH_HOST = "https://auth.kimi.com"; +const KEYRING_SERVICE = "kimi-code"; +export const REFRESH_INTERVAL_SECONDS = 60; +export const REFRESH_THRESHOLD_SECONDS = 300; + +// ── Errors ────────────────────────────────────────────── + +export class OAuthError extends Error { + constructor(message: string) { + super(message); + this.name = "OAuthError"; + } +} + +export class OAuthUnauthorized extends OAuthError { + constructor(message = "OAuth credentials rejected.") { + super(message); + this.name = "OAuthUnauthorized"; + } +} + +export class OAuthDeviceExpired extends OAuthError { + constructor(message = "Device authorization expired.") { + super(message); + this.name = "OAuthDeviceExpired"; + } +} + +// ── Event / Token types ───────────────────────────────── + +export type OAuthEventKind = "info" | "error" | "waiting" | "verification_url" | "success"; + +export interface OAuthEvent { + type: OAuthEventKind; + message: string; + data?: Record; +} + +export interface OAuthToken { + access_token: string; + refresh_token: string; + expires_at: number; + scope: string; + token_type: string; +} + +export interface DeviceAuthorization { + user_code: string; + device_code: string; + verification_uri: string; + verification_uri_complete: string; + expires_in: number | null; + interval: number; +} + +// ── Helpers ───────────────────────────────────────────── + +function oauthHost(): string { + return process.env.KIMI_CODE_OAUTH_HOST ?? process.env.KIMI_OAUTH_HOST ?? DEFAULT_OAUTH_HOST; +} + +function credentialsDir(): string { + return join(getShareDir(), "credentials"); +} + +function credentialsPath(key: string): string { + const name = key.replace(/^oauth\//, "").split("/").pop() ?? key; + return join(credentialsDir(), `${name}.json`); +} + +function deviceIdPath(): string { + return join(getShareDir(), "device_id"); +} + +export async function getDeviceId(): Promise { + const path = deviceIdPath(); + const file = Bun.file(path); + if (await file.exists()) { + return (await file.text()).trim(); + } + const deviceId = randomUUID().replace(/-/g, ""); + await Bun.$`mkdir -p ${getShareDir()}`.quiet(); + await Bun.write(path, deviceId); + return deviceId; +} + +function deviceModel(): string { + const sys = platform(); + const a = arch(); + if (sys === "darwin") return `macOS ${a}`; + if (sys === "win32") return `Windows ${a}`; + if (sys === "linux") return `Linux ${a}`; + return `${sys} ${a}`; +} + +export async function commonHeaders(): Promise> { + return { + "X-Msh-Platform": "kimi_cli", + "X-Msh-Version": getVersion(), + "X-Msh-Device-Name": hostname(), + "X-Msh-Device-Model": deviceModel(), + "X-Msh-Os-Version": release(), + "X-Msh-Device-Id": await getDeviceId(), + }; +} + +// ── Token persistence (file-based) ───────────────────── + +export async function loadTokens(ref: OAuthRef): Promise { + const path = credentialsPath(ref.key); + const file = Bun.file(path); + if (!(await file.exists())) return null; + try { + return (await file.json()) as OAuthToken; + } catch { + return null; + } +} + +export async function saveTokens(ref: OAuthRef, token: OAuthToken): Promise { + const path = credentialsPath(ref.key); + await Bun.$`mkdir -p ${credentialsDir()}`.quiet(); + await Bun.write(path, JSON.stringify(token)); + return { storage: "file", key: ref.key }; +} + +export async function deleteTokens(ref: OAuthRef): Promise { + const path = credentialsPath(ref.key); + const file = Bun.file(path); + if (await file.exists()) { + await Bun.$`rm -f ${path}`.quiet(); + } +} + +// ── Device authorization flow ─────────────────────────── + +export async function requestDeviceAuthorization(): Promise { + const host = oauthHost().replace(/\/+$/, ""); + const headers = await commonHeaders(); + const res = await fetch(`${host}/api/oauth/device_authorization`, { + method: "POST", + headers: { ...headers, "Content-Type": "application/x-www-form-urlencoded" }, + body: new URLSearchParams({ client_id: KIMI_CODE_CLIENT_ID }), + }); + const data = (await res.json()) as Record; + if (res.status !== 200) throw new OAuthError(`Device authorization failed: ${JSON.stringify(data)}`); + return { + user_code: String(data.user_code), + device_code: String(data.device_code), + verification_uri: String(data.verification_uri ?? ""), + verification_uri_complete: String(data.verification_uri_complete), + expires_in: data.expires_in ? Number(data.expires_in) : null, + interval: Number(data.interval ?? 5), + }; +} + +/** Poll the token endpoint once. Corresponds to Python _request_device_token. */ +async function requestDeviceToken(auth: DeviceAuthorization): Promise<{ status: number; data: Record }> { + const host = oauthHost().replace(/\/+$/, ""); + const headers = await commonHeaders(); + try { + const res = await fetch(`${host}/api/oauth/token`, { + method: "POST", + headers: { ...headers, "Content-Type": "application/x-www-form-urlencoded" }, + body: new URLSearchParams({ + client_id: KIMI_CODE_CLIENT_ID, + device_code: auth.device_code, + grant_type: "urn:ietf:params:oauth:grant-type:device_code", + }), + }); + const data = (await res.json()) as Record; + if (res.status >= 500) throw new OAuthError(`Token polling server error: ${res.status}.`); + return { status: res.status, data }; + } catch (err) { + if (err instanceof OAuthError) throw err; + throw new OAuthError("Token polling request failed."); + } +} + +export async function refreshToken(refreshTokenValue: string): Promise { + const host = oauthHost().replace(/\/+$/, ""); + const headers = await commonHeaders(); + const res = await fetch(`${host}/api/oauth/token`, { + method: "POST", + headers: { ...headers, "Content-Type": "application/x-www-form-urlencoded" }, + body: new URLSearchParams({ + client_id: KIMI_CODE_CLIENT_ID, + grant_type: "refresh_token", + refresh_token: refreshTokenValue, + }), + }); + const data = (await res.json()) as Record; + if (res.status === 401 || res.status === 403) { + throw new OAuthUnauthorized((data.error_description as string) ?? "Token refresh unauthorized."); + } + if (res.status !== 200) { + throw new OAuthError((data.error_description as string) ?? "Token refresh failed."); + } + return { + access_token: String(data.access_token), + refresh_token: String(data.refresh_token), + expires_at: Date.now() / 1000 + Number(data.expires_in), + scope: String(data.scope), + token_type: String(data.token_type), + }; +} + +// ── Kimi Code login/logout (async generator) ──────────── + +function selectDefaultModelAndThinking(models: ModelInfo[]): { model: ModelInfo; thinking: boolean } | null { + if (!models.length) return null; + const model = models[0]!; + const caps = deriveModelCapabilities(model); + const thinking = caps.has("thinking") || caps.has("always_thinking"); + return { model, thinking }; +} + +function applyKimiCodeConfig( + config: Config, + opts: { + models: ModelInfo[]; + selectedModel: ModelInfo; + thinking: boolean; + oauthRef: OAuthRef; + }, +): void { + const plat = getPlatformById(KIMI_CODE_PLATFORM_ID); + if (!plat) throw new OAuthError("Kimi Code platform not found."); + + const providerKey = managedProviderKey(plat.id); + config.providers[providerKey] = { + type: "kimi", + base_url: plat.baseUrl, + api_key: "", + oauth: opts.oauthRef, + }; + + // Remove old models for this provider + for (const [key, model] of Object.entries(config.models)) { + if (model.provider === providerKey) delete config.models[key]; + } + + // Add fresh models + for (const modelInfo of opts.models) { + const caps = deriveModelCapabilities(modelInfo); + config.models[managedModelKey(plat.id, modelInfo.id)] = { + provider: providerKey, + model: modelInfo.id, + max_context_size: modelInfo.contextLength, + capabilities: caps.size > 0 ? ([...caps] as any) : undefined, + }; + } + + config.default_model = managedModelKey(plat.id, opts.selectedModel.id); + config.default_thinking = opts.thinking; + + if (plat.searchUrl) { + config.services = config.services ?? {}; + (config.services as any).moonshot_search = { + base_url: plat.searchUrl, + api_key: "", + oauth: opts.oauthRef, + }; + } + if (plat.fetchUrl) { + config.services = config.services ?? {}; + (config.services as any).moonshot_fetch = { + base_url: plat.fetchUrl, + api_key: "", + oauth: opts.oauthRef, + }; + } +} + +/** + * Run the Kimi Code OAuth device-code login flow. + * Yields OAuthEvent objects for UI display. + * Corresponds to Python login_kimi_code(). + */ +export async function* loginKimiCode( + config: Config, + opts: { openBrowser?: boolean } = {}, +): AsyncGenerator { + const plat = getPlatformById(KIMI_CODE_PLATFORM_ID); + if (!plat) { + yield { type: "error", message: "Kimi Code platform is unavailable." }; + return; + } + + let token: OAuthToken | null = null; + + // Retry loop — device codes can expire + while (true) { + let auth: DeviceAuthorization; + try { + auth = await requestDeviceAuthorization(); + } catch (err) { + yield { type: "error", message: `Login failed: ${err}` }; + return; + } + + yield { type: "info", message: "Please visit the following URL to finish authorization." }; + yield { + type: "verification_url", + message: `Verification URL: ${auth.verification_uri_complete}`, + data: { verification_url: auth.verification_uri_complete, user_code: auth.user_code }, + }; + + if (opts.openBrowser !== false) { + try { + // Use Bun.spawn to open URL in default browser + const proc = Bun.spawn( + process.platform === "darwin" + ? ["open", auth.verification_uri_complete] + : process.platform === "win32" + ? ["cmd", "/c", "start", auth.verification_uri_complete] + : ["xdg-open", auth.verification_uri_complete], + { stdout: "ignore", stderr: "ignore" }, + ); + await proc.exited; + } catch { + // Ignore browser open failures + } + } + + let interval = Math.max(auth.interval, 1); + let printedWait = false; + + try { + while (true) { + const { status, data } = await requestDeviceToken(auth); + if (status === 200 && data.access_token) { + token = { + access_token: String(data.access_token), + refresh_token: String(data.refresh_token), + expires_at: Date.now() / 1000 + Number(data.expires_in), + scope: String(data.scope ?? ""), + token_type: String(data.token_type ?? "bearer"), + }; + break; + } + const errorCode = String(data.error ?? "unknown_error"); + if (errorCode === "expired_token") throw new OAuthDeviceExpired(); + if (!printedWait) { + const desc = String(data.error_description ?? ""); + yield { + type: "waiting", + message: `Waiting for user authorization...${desc ? ": " + desc.trim() : ""}`, + data: { error: errorCode, error_description: desc }, + }; + printedWait = true; + } + await new Promise((r) => setTimeout(r, interval * 1000)); + } + } catch (err) { + if (err instanceof OAuthDeviceExpired) { + yield { type: "info", message: "Device code expired, restarting login..." }; + continue; // Retry outer loop + } + yield { type: "error", message: `Login failed: ${err}` }; + return; + } + break; // Got token, exit retry loop + } + + if (!token) return; + + // Save token + const oauthRef: OAuthRef = { storage: "file", key: KIMI_CODE_OAUTH_KEY }; + await saveTokens(oauthRef, token); + + // Fetch models + let models: ModelInfo[]; + try { + models = await listModels(plat, token.access_token); + } catch (err) { + logger.error(`Failed to get models: ${err}`); + yield { type: "error", message: `Failed to get models: ${err}` }; + return; + } + + if (!models.length) { + yield { type: "error", message: "No models available for the selected platform." }; + return; + } + + const selection = selectDefaultModelAndThinking(models); + if (!selection) return; + + applyKimiCodeConfig(config, { + models, + selectedModel: selection.model, + thinking: selection.thinking, + oauthRef, + }); + await saveConfig(config); + yield { type: "success", message: "Logged in successfully." }; +} + +/** + * Logout from Kimi Code — delete tokens and clean up config. + * Corresponds to Python logout_kimi_code(). + */ +export async function* logoutKimiCode(config: Config): AsyncGenerator { + // Delete stored tokens (both keyring and file) + await deleteTokens({ storage: "keyring", key: KIMI_CODE_OAUTH_KEY }); + await deleteTokens({ storage: "file", key: KIMI_CODE_OAUTH_KEY }); + + const providerKey = managedProviderKey(KIMI_CODE_PLATFORM_ID); + if (config.providers[providerKey]) { + delete config.providers[providerKey]; + } + + let removedDefault = false; + for (const [key, model] of Object.entries(config.models)) { + if (model.provider !== providerKey) continue; + delete config.models[key]; + if (config.default_model === key) removedDefault = true; + } + if (removedDefault) config.default_model = ""; + + if (config.services) { + (config.services as any).moonshot_search = undefined; + (config.services as any).moonshot_fetch = undefined; + } + + await saveConfig(config); + yield { type: "success", message: "Logged out successfully." }; +} + +// ── OAuthManager ──────────────────────────────────────── + +export class OAuthManager { + private config: { providers: Record }; + private accessTokens = new Map(); + + constructor(config: { providers: Record }) { + this.config = config; + } + + async initialize(): Promise { + for (const provider of Object.values(this.config.providers)) { + if (provider.oauth) { + const token = await loadTokens(provider.oauth); + if (token) this.accessTokens.set(provider.oauth.key, token.access_token); + } + } + } + + async resolveApiKey(apiKey: string, oauth?: OAuthRef): Promise { + if (oauth) { + const cached = this.accessTokens.get(oauth.key); + if (cached) return cached; + const persisted = await loadTokens(oauth); + if (persisted) { + this.accessTokens.set(oauth.key, persisted.access_token); + return persisted.access_token; + } + logger.warn(`OAuth ref present (key=${oauth.key}) but no access token; falling back to api_key`); + } + return apiKey; + } + + async ensureFresh(): Promise { + for (const provider of Object.values(this.config.providers)) { + if (!provider.oauth) continue; + const token = await loadTokens(provider.oauth); + if (!token || !token.refresh_token) continue; + + this.accessTokens.set(provider.oauth.key, token.access_token); + + const now = Date.now() / 1000; + if (token.expires_at && token.expires_at > now && token.expires_at - now >= REFRESH_THRESHOLD_SECONDS) { + continue; + } + try { + const refreshed = await refreshToken(token.refresh_token); + await saveTokens(provider.oauth, refreshed); + this.accessTokens.set(provider.oauth.key, refreshed.access_token); + } catch (err) { + if (err instanceof OAuthUnauthorized) { + this.accessTokens.delete(provider.oauth.key); + await deleteTokens(provider.oauth); + } else { + logger.warn("Failed to refresh OAuth token", err); + } + } + } + } + + /** + * Background refresh loop — corresponds to Python OAuthManager.refreshing(). + * Periodically calls ensureFresh() until the returned abort function is called. + * Returns an AbortController; call abort() to stop the background loop. + */ + refreshing(): AbortController { + const controller = new AbortController(); + const signal = controller.signal; + + const run = async () => { + // Initial ensure fresh + try { + await this.ensureFresh(); + } catch (err) { + logger.warn(`Failed initial OAuth token refresh: ${err}`); + } + + while (!signal.aborted) { + try { + await new Promise((resolve, reject) => { + const timer = setTimeout(resolve, REFRESH_INTERVAL_SECONDS * 1000); + signal.addEventListener("abort", () => { + clearTimeout(timer); + reject(new Error("aborted")); + }, { once: true }); + }); + } catch { + break; // Aborted + } + + try { + await this.ensureFresh(); + } catch (err) { + logger.warn(`Failed to refresh OAuth token in background: ${err}`); + } + } + }; + + // Fire-and-forget background loop + run().catch(() => {}); + + return controller; + } +} diff --git a/src/kimi_cli/auth/platforms.py b/src/kimi_cli/auth/platforms.py deleted file mode 100644 index a474dd5a0..000000000 --- a/src/kimi_cli/auth/platforms.py +++ /dev/null @@ -1,293 +0,0 @@ -from __future__ import annotations - -import os -from typing import Any, NamedTuple, cast - -import aiohttp -from pydantic import BaseModel - -from kimi_cli.auth import KIMI_CODE_PLATFORM_ID -from kimi_cli.config import Config, LLMModel, load_config, save_config -from kimi_cli.llm import ModelCapability -from kimi_cli.utils.aiohttp import new_client_session -from kimi_cli.utils.logging import logger - - -class ModelInfo(BaseModel): - """Model information returned from the API.""" - - id: str - context_length: int - supports_reasoning: bool - supports_image_in: bool - supports_video_in: bool - - @property - def capabilities(self) -> set[ModelCapability]: - """Derive capabilities from model info.""" - caps: set[ModelCapability] = set() - if self.supports_reasoning: - caps.add("thinking") - # Models with "thinking" in name are always-thinking - if "thinking" in self.id.lower(): - caps.update(("thinking", "always_thinking")) - if self.supports_image_in: - caps.add("image_in") - if self.supports_video_in: - caps.add("video_in") - if "kimi-k2.5" in self.id.lower(): - caps.update(("thinking", "image_in", "video_in")) - return caps - - -class Platform(NamedTuple): - id: str - name: str - base_url: str - search_url: str | None = None - fetch_url: str | None = None - allowed_prefixes: list[str] | None = None - - -def _kimi_code_base_url() -> str: - if base_url := os.getenv("KIMI_CODE_BASE_URL"): - return base_url - return "https://api.kimi.com/coding/v1" - - -PLATFORMS: list[Platform] = [ - Platform( - id=KIMI_CODE_PLATFORM_ID, - name="Kimi Code", - base_url=_kimi_code_base_url(), - search_url=f"{_kimi_code_base_url()}/search", - fetch_url=f"{_kimi_code_base_url()}/fetch", - ), - Platform( - id="moonshot-cn", - name="Moonshot AI Open Platform (moonshot.cn)", - base_url="https://api.moonshot.cn/v1", - allowed_prefixes=["kimi-k"], - ), - Platform( - id="moonshot-ai", - name="Moonshot AI Open Platform (moonshot.ai)", - base_url="https://api.moonshot.ai/v1", - allowed_prefixes=["kimi-k"], - ), -] - -_PLATFORM_BY_ID = {platform.id: platform for platform in PLATFORMS} -_PLATFORM_BY_NAME = {platform.name: platform for platform in PLATFORMS} - - -def get_platform_by_id(platform_id: str) -> Platform | None: - return _PLATFORM_BY_ID.get(platform_id) - - -def get_platform_by_name(name: str) -> Platform | None: - return _PLATFORM_BY_NAME.get(name) - - -MANAGED_PROVIDER_PREFIX = "managed:" - - -def managed_provider_key(platform_id: str) -> str: - return f"{MANAGED_PROVIDER_PREFIX}{platform_id}" - - -def managed_model_key(platform_id: str, model_id: str) -> str: - return f"{platform_id}/{model_id}" - - -def parse_managed_provider_key(provider_key: str) -> str | None: - if not provider_key.startswith(MANAGED_PROVIDER_PREFIX): - return None - return provider_key.removeprefix(MANAGED_PROVIDER_PREFIX) - - -def is_managed_provider_key(provider_key: str) -> bool: - return provider_key.startswith(MANAGED_PROVIDER_PREFIX) - - -def get_platform_name_for_provider(provider_key: str) -> str | None: - platform_id = parse_managed_provider_key(provider_key) - if not platform_id: - return None - platform = get_platform_by_id(platform_id) - return platform.name if platform else None - - -async def refresh_managed_models(config: Config) -> bool: - if not config.is_from_default_location: - return False - - managed_providers = { - key: provider for key, provider in config.providers.items() if is_managed_provider_key(key) - } - if not managed_providers: - return False - - changed = False - updates: list[tuple[str, str, list[ModelInfo]]] = [] - for provider_key, provider in managed_providers.items(): - platform_id = parse_managed_provider_key(provider_key) - if not platform_id: - continue - platform = get_platform_by_id(platform_id) - if platform is None: - logger.warning("Managed platform not found: {platform}", platform=platform_id) - continue - - api_key = provider.api_key.get_secret_value() - if not api_key and provider.oauth: - from kimi_cli.auth.oauth import load_tokens - - token = load_tokens(provider.oauth) - if token: - api_key = token.access_token - if not api_key: - logger.warning( - "Missing API key for managed provider: {provider}", - provider=provider_key, - ) - continue - try: - models = await list_models(platform, api_key) - except Exception as exc: - logger.error( - "Failed to refresh models for {platform}: {error}", - platform=platform_id, - error=exc, - ) - continue - - updates.append((provider_key, platform_id, models)) - if _apply_models(config, provider_key, platform_id, models): - changed = True - - if changed: - config_for_save = load_config() - save_changed = False - for provider_key, platform_id, models in updates: - if _apply_models(config_for_save, provider_key, platform_id, models): - save_changed = True - if save_changed: - save_config(config_for_save) - return changed - - -async def list_models(platform: Platform, api_key: str) -> list[ModelInfo]: - async with new_client_session() as session: - models = await _list_models( - session, - base_url=platform.base_url, - api_key=api_key, - ) - if platform.allowed_prefixes is None: - return models - prefixes = tuple(platform.allowed_prefixes) - return [model for model in models if model.id.startswith(prefixes)] - - -async def _list_models( - session: aiohttp.ClientSession, - *, - base_url: str, - api_key: str, -) -> list[ModelInfo]: - models_url = f"{base_url.rstrip('/')}/models" - try: - async with session.get( - models_url, - headers={"Authorization": f"Bearer {api_key}"}, - raise_for_status=True, - ) as response: - resp_json = await response.json() - except aiohttp.ClientError: - raise - - data = resp_json.get("data") - if not isinstance(data, list): - raise ValueError(f"Unexpected models response for {base_url}") - - result: list[ModelInfo] = [] - for item in cast(list[dict[str, Any]], data): - model_id = item.get("id") - if not model_id: - continue - result.append( - ModelInfo( - id=str(model_id), - context_length=int(item.get("context_length") or 0), - supports_reasoning=bool(item.get("supports_reasoning")), - supports_image_in=bool(item.get("supports_image_in")), - supports_video_in=bool(item.get("supports_video_in")), - ) - ) - return result - - -def _apply_models( - config: Config, - provider_key: str, - platform_id: str, - models: list[ModelInfo], -) -> bool: - changed = False - model_keys: list[str] = [] - - for model in models: - model_key = managed_model_key(platform_id, model.id) - model_keys.append(model_key) - - existing = config.models.get(model_key) - capabilities = model.capabilities or None # empty set -> None - - if existing is None: - config.models[model_key] = LLMModel( - provider=provider_key, - model=model.id, - max_context_size=model.context_length, - capabilities=capabilities, - ) - changed = True - continue - - if existing.provider != provider_key: - existing.provider = provider_key - changed = True - if existing.model != model.id: - existing.model = model.id - changed = True - if existing.max_context_size != model.context_length: - existing.max_context_size = model.context_length - changed = True - if existing.capabilities != capabilities: - existing.capabilities = capabilities - changed = True - - removed_default = False - model_keys_set = set(model_keys) - for key, model in list(config.models.items()): - if model.provider != provider_key: - continue - if key in model_keys_set: - continue - del config.models[key] - if config.default_model == key: - removed_default = True - changed = True - - if removed_default: - if model_keys: - config.default_model = model_keys[0] - else: - config.default_model = next(iter(config.models), "") - changed = True - - if config.default_model and config.default_model not in config.models: - config.default_model = next(iter(config.models), "") - changed = True - - return changed diff --git a/src/kimi_cli/auth/platforms.ts b/src/kimi_cli/auth/platforms.ts new file mode 100644 index 000000000..8e00e9f4a --- /dev/null +++ b/src/kimi_cli/auth/platforms.ts @@ -0,0 +1,230 @@ +/** + * Platform definitions and model management. + * Corresponds to Python's auth/platforms.py. + */ + +import type { Config, OAuthRef } from "../config.ts"; +import type { ModelCapability } from "../types.ts"; +import { logger } from "../utils/logging.ts"; + +// ── Constants ──────────────────────────────────────────── + +export const KIMI_CODE_PLATFORM_ID = "kimi-code"; +export const MANAGED_PROVIDER_PREFIX = "managed:"; + +// ── Types ──────────────────────────────────────────────── + +export interface ModelInfo { + id: string; + contextLength: number; + supportsReasoning: boolean; + supportsImageIn: boolean; + supportsVideoIn: boolean; +} + +export function deriveModelCapabilities(model: ModelInfo): Set { + const caps = new Set(); + if (model.supportsReasoning) caps.add("thinking"); + if (model.id.toLowerCase().includes("thinking")) { + caps.add("thinking"); + caps.add("always_thinking"); + } + if (model.supportsImageIn) caps.add("image_in"); + if (model.supportsVideoIn) caps.add("video_in"); + if (model.id.toLowerCase().includes("kimi-k2.5")) { + caps.add("thinking"); + caps.add("image_in"); + caps.add("video_in"); + } + return caps; +} + +export interface Platform { + id: string; + name: string; + baseUrl: string; + searchUrl?: string; + fetchUrl?: string; + allowedPrefixes?: string[]; +} + +// ── Platform registry ──────────────────────────────────── + +function kimiCodeBaseUrl(): string { + return process.env.KIMI_CODE_BASE_URL ?? "https://api.kimi.com/coding/v1"; +} + +export const PLATFORMS: Platform[] = [ + { + id: KIMI_CODE_PLATFORM_ID, + name: "Kimi Code", + baseUrl: kimiCodeBaseUrl(), + searchUrl: `${kimiCodeBaseUrl()}/search`, + fetchUrl: `${kimiCodeBaseUrl()}/fetch`, + }, + { + id: "moonshot-cn", + name: "Moonshot AI Open Platform (moonshot.cn)", + baseUrl: "https://api.moonshot.cn/v1", + allowedPrefixes: ["kimi-k"], + }, + { + id: "moonshot-ai", + name: "Moonshot AI Open Platform (moonshot.ai)", + baseUrl: "https://api.moonshot.ai/v1", + allowedPrefixes: ["kimi-k"], + }, +]; + +const _platformById = new Map(PLATFORMS.map((p) => [p.id, p])); +const _platformByName = new Map(PLATFORMS.map((p) => [p.name, p])); + +export function getPlatformById(platformId: string): Platform | undefined { + return _platformById.get(platformId); +} + +export function getPlatformByName(name: string): Platform | undefined { + return _platformByName.get(name); +} + +// ── Key helpers ────────────────────────────────────────── + +export function managedProviderKey(platformId: string): string { + return `${MANAGED_PROVIDER_PREFIX}${platformId}`; +} + +export function managedModelKey(platformId: string, modelId: string): string { + return `${platformId}/${modelId}`; +} + +export function parseManagedProviderKey(providerKey: string): string | null { + if (!providerKey.startsWith(MANAGED_PROVIDER_PREFIX)) return null; + return providerKey.slice(MANAGED_PROVIDER_PREFIX.length); +} + +export function isManagedProviderKey(providerKey: string): boolean { + return providerKey.startsWith(MANAGED_PROVIDER_PREFIX); +} + +export function getPlatformNameForProvider(providerKey: string): string | null { + const platformId = parseManagedProviderKey(providerKey); + if (!platformId) return null; + const platform = getPlatformById(platformId); + return platform?.name ?? null; +} + +// ── Model listing ──────────────────────────────────────── + +export async function listModels(platform: Platform, apiKey: string): Promise { + const modelsUrl = `${platform.baseUrl.replace(/\/+$/, "")}/models`; + const res = await fetch(modelsUrl, { + headers: { Authorization: `Bearer ${apiKey}` }, + }); + if (!res.ok) { + throw new Error(`Failed to list models (HTTP ${res.status})`); + } + const json = (await res.json()) as { data?: unknown[] }; + const data = json.data; + if (!Array.isArray(data)) { + throw new Error(`Unexpected models response for ${platform.baseUrl}`); + } + + const models: ModelInfo[] = []; + for (const item of data as Record[]) { + const modelId = item.id; + if (!modelId) continue; + models.push({ + id: String(modelId), + contextLength: Number(item.context_length ?? 0), + supportsReasoning: Boolean(item.supports_reasoning), + supportsImageIn: Boolean(item.supports_image_in), + supportsVideoIn: Boolean(item.supports_video_in), + }); + } + + if (platform.allowedPrefixes) { + const prefixes = platform.allowedPrefixes; + return models.filter((m) => prefixes.some((p) => m.id.startsWith(p))); + } + return models; +} + +// ── Refresh managed models ─────────────────────────────── + +export async function refreshManagedModels(config: Config): Promise { + // Lazy import to avoid circular dependency (oauth.ts imports from platforms.ts) + const { loadTokens } = await import("./oauth.ts"); + + let changed = false; + for (const [providerKey, provider] of Object.entries(config.providers)) { + const platformId = parseManagedProviderKey(providerKey); + if (!platformId) continue; + const platform = getPlatformById(platformId); + if (!platform) continue; + + let apiKey = provider.api_key; + if (!apiKey && provider.oauth) { + const token = await loadTokens(provider.oauth); + if (token) apiKey = token.access_token; + } + if (!apiKey) continue; + + try { + const models = await listModels(platform, apiKey); + if (applyModels(config, providerKey, platformId, models)) { + changed = true; + } + } catch (err) { + logger.error(`Failed to refresh models for ${platformId}: ${err}`); + } + } + return changed; +} + +function applyModels( + config: Config, + providerKey: string, + platformId: string, + models: ModelInfo[], +): boolean { + let changed = false; + const modelKeys = new Set(); + + for (const model of models) { + const modelKey = managedModelKey(platformId, model.id); + modelKeys.add(modelKey); + + const existing = config.models[modelKey]; + const capabilities = deriveModelCapabilities(model); + const capsArray = capabilities.size > 0 ? [...capabilities] as ModelCapability[] : undefined; + + if (!existing) { + config.models[modelKey] = { + provider: providerKey, + model: model.id, + max_context_size: model.contextLength, + capabilities: capsArray, + }; + changed = true; + continue; + } + + if (existing.provider !== providerKey) { + existing.provider = providerKey; + changed = true; + } + } + + // Remove stale models + for (const [key, model] of Object.entries(config.models)) { + if (model.provider !== providerKey) continue; + if (modelKeys.has(key)) continue; + delete config.models[key]; + if (config.default_model === key) { + config.default_model = ""; + } + changed = true; + } + + return changed; +} diff --git a/src/kimi_cli/background/__init__.py b/src/kimi_cli/background/__init__.py deleted file mode 100644 index 3d5088101..000000000 --- a/src/kimi_cli/background/__init__.py +++ /dev/null @@ -1,36 +0,0 @@ -from .ids import generate_task_id -from .manager import BackgroundTaskManager -from .models import ( - TaskConsumerState, - TaskControl, - TaskKind, - TaskOutputChunk, - TaskRuntime, - TaskSpec, - TaskStatus, - TaskView, - is_terminal_status, -) -from .store import BackgroundTaskStore -from .summary import build_active_task_snapshot, format_task, format_task_list, list_task_views -from .worker import run_background_task_worker - -__all__ = [ - "BackgroundTaskManager", - "BackgroundTaskStore", - "TaskConsumerState", - "TaskControl", - "TaskKind", - "TaskOutputChunk", - "TaskRuntime", - "TaskSpec", - "TaskStatus", - "TaskView", - "build_active_task_snapshot", - "format_task", - "format_task_list", - "generate_task_id", - "is_terminal_status", - "list_task_views", - "run_background_task_worker", -] diff --git a/src/kimi_cli/background/agent_runner.py b/src/kimi_cli/background/agent_runner.py deleted file mode 100644 index 1625e0642..000000000 --- a/src/kimi_cli/background/agent_runner.py +++ /dev/null @@ -1,209 +0,0 @@ -# pyright: reportPrivateUsage=false -from __future__ import annotations - -import asyncio -import contextlib -from dataclasses import replace -from typing import TYPE_CHECKING - -from kimi_cli.approval_runtime import ( - ApprovalSource, - reset_current_approval_source, - set_current_approval_source, -) -from kimi_cli.subagents.builder import SubagentBuilder -from kimi_cli.subagents.core import SubagentRunSpec, prepare_soul -from kimi_cli.subagents.output import SubagentOutputWriter -from kimi_cli.subagents.runner import run_with_summary_continuation -from kimi_cli.utils.logging import logger -from kimi_cli.wire import Wire - -if TYPE_CHECKING: - from kimi_cli.approval_runtime.models import ApprovalRuntimeEvent - from kimi_cli.background.manager import BackgroundTaskManager - from kimi_cli.soul.agent import Runtime - - -class BackgroundAgentRunner: - def __init__( - self, - *, - runtime: Runtime, - manager: BackgroundTaskManager, - task_id: str, - agent_id: str, - subagent_type: str, - prompt: str, - model_override: str | None, - timeout_s: int | None = None, - resumed: bool = False, - ) -> None: - self._runtime = runtime - self._manager = manager - self._task_id = task_id - self._agent_id = agent_id - self._subagent_type = subagent_type - self._prompt = prompt - self._model_override = model_override - self._timeout_s = timeout_s - self._resumed = resumed - self._builder = SubagentBuilder(runtime) - self._approval_update_tasks: set[asyncio.Task[None]] = set() - - async def run(self) -> None: - assert self._runtime.approval_runtime is not None - assert self._runtime.subagent_store is not None - token = set_current_approval_source( - ApprovalSource( - kind="background_agent", - id=self._task_id, - agent_id=self._agent_id, - subagent_type=self._subagent_type, - ) - ) - approval_subscription = self._runtime.approval_runtime.subscribe( - self._on_approval_runtime_event - ) - task_output_path = self._manager.store.output_path(self._task_id) - output = SubagentOutputWriter( - self._runtime.subagent_store.output_path(self._agent_id), - extra_paths=[task_output_path], - ) - - try: - if self._timeout_s is not None: - await asyncio.wait_for(self._run_core(output), timeout=self._timeout_s) - else: - await self._run_core(output) - except TimeoutError as exc: - if isinstance(exc.__cause__, asyncio.CancelledError): - # Task-level timeout from wait_for (it raises TimeoutError from CancelledError) - logger.warning( - "Background agent task {id} timed out after {t}s", - id=self._task_id, - t=self._timeout_s, - ) - self._runtime.subagent_store.update_instance(self._agent_id, status="failed") - self._manager._mark_task_timed_out( - self._task_id, f"Agent task timed out after {self._timeout_s}s" - ) - output.error(f"Agent task timed out after {self._timeout_s}s") - else: - # Internal timeout (e.g. aiohttp request) — treat as generic failure - logger.exception("Background agent runner failed") - self._runtime.subagent_store.update_instance(self._agent_id, status="failed") - self._manager._mark_task_failed(self._task_id, str(exc)) - output.error(str(exc)) - except asyncio.CancelledError: - self._runtime.subagent_store.update_instance(self._agent_id, status="killed") - self._manager._mark_task_killed(self._task_id, "Stopped by TaskStop") - output.stage("cancelled") - raise - except Exception as exc: - logger.exception("Background agent runner failed") - self._runtime.subagent_store.update_instance(self._agent_id, status="failed") - self._manager._mark_task_failed(self._task_id, str(exc)) - output.error(str(exc)) - finally: - for task in list(self._approval_update_tasks): - task.cancel() - for task in list(self._approval_update_tasks): - with contextlib.suppress(asyncio.CancelledError): - await task - self._runtime.approval_runtime.unsubscribe(approval_subscription) - self._runtime.approval_runtime.cancel_by_source("background_agent", self._task_id) - reset_current_approval_source(token) - self._manager._live_agent_tasks.pop(self._task_id, None) - - async def _run_core(self, output: SubagentOutputWriter) -> None: - assert self._runtime.subagent_store is not None - self._manager._mark_task_running(self._task_id) - output.stage("runner_started") - - type_def = self._runtime.labor_market.require_builtin_type(self._subagent_type) - record = self._runtime.subagent_store.require_instance(self._agent_id) - launch_spec = record.launch_spec - if self._model_override is not None: - launch_spec = replace( - launch_spec, - model_override=self._model_override, - effective_model=self._model_override, - ) - - spec = SubagentRunSpec( - agent_id=self._agent_id, - type_def=type_def, - launch_spec=launch_spec, - prompt=self._prompt, - resumed=self._resumed, - ) - soul, prompt = await prepare_soul( - spec, - self._runtime, - self._builder, - self._runtime.subagent_store, - on_stage=output.stage, - ) - - async def _ui_loop_fn(wire: Wire) -> None: - wire_ui = wire.ui_side(merge=True) - while True: - msg = await wire_ui.receive() - output.write_wire_message(msg) - - output.stage("run_soul_start") - final_response, failure = await run_with_summary_continuation( - soul, - prompt, - _ui_loop_fn, - self._runtime.subagent_store.wire_path(self._agent_id), - ) - if failure is not None: - self._manager._mark_task_failed(self._task_id, failure.message) - self._runtime.subagent_store.update_instance(self._agent_id, status="failed") - output.stage(f"failed: {failure.brief}") - return - output.stage("run_soul_finished") - - assert final_response is not None - output.summary(final_response) - self._runtime.subagent_store.update_instance(self._agent_id, status="idle") - self._manager._mark_task_completed(self._task_id) - - def _on_approval_runtime_event(self, event: ApprovalRuntimeEvent) -> None: - request = event.request - if request.source.kind != "background_agent" or request.source.id != self._task_id: - return - task = asyncio.create_task(self._apply_approval_runtime_event(event)) - self._approval_update_tasks.add(task) - task.add_done_callback(self._approval_update_tasks.discard) - task.add_done_callback(self._log_approval_update_failure) - - async def _apply_approval_runtime_event(self, event: ApprovalRuntimeEvent) -> None: - request = event.request - if event.kind == "request_created": - await asyncio.to_thread( - self._manager._mark_task_awaiting_approval, - self._task_id, - request.description, - ) - elif event.kind == "request_resolved": - assert self._runtime.approval_runtime is not None - pending_for_task = [ - pending - for pending in self._runtime.approval_runtime.list_pending() - if pending.source.kind == "background_agent" and pending.source.id == self._task_id - ] - if pending_for_task: - return - await asyncio.to_thread( - self._manager._mark_task_running, - self._task_id, - ) - - @staticmethod - def _log_approval_update_failure(task: asyncio.Task[None]) -> None: - with contextlib.suppress(asyncio.CancelledError): - exc = task.exception() - if exc is not None: - logger.opt(exception=exc).error("Failed to apply background approval state update") diff --git a/src/kimi_cli/background/ids.py b/src/kimi_cli/background/ids.py deleted file mode 100644 index 282ac0d08..000000000 --- a/src/kimi_cli/background/ids.py +++ /dev/null @@ -1,19 +0,0 @@ -from __future__ import annotations - -import secrets - -from .models import TaskKind - -_ALPHABET = "0123456789abcdefghijklmnopqrstuvwxyz" - - -_TASK_ID_PREFIXES: dict[TaskKind, str] = { - "bash": "bash", - "agent": "agent", -} - - -def generate_task_id(kind: TaskKind) -> str: - prefix = _TASK_ID_PREFIXES[kind] - suffix = "".join(secrets.choice(_ALPHABET) for _ in range(8)) - return f"{prefix}-{suffix}" diff --git a/src/kimi_cli/background/ids.ts b/src/kimi_cli/background/ids.ts new file mode 100644 index 000000000..185518556 --- /dev/null +++ b/src/kimi_cli/background/ids.ts @@ -0,0 +1,21 @@ +/** + * Background task ID generation — corresponds to Python background/ids.py + */ + +import type { TaskKind } from "./models.ts"; + +const ALPHABET = "0123456789abcdefghijklmnopqrstuvwxyz"; + +const TASK_ID_PREFIXES: Record = { + bash: "bash", + agent: "agent", +}; + +export function generateTaskId(kind: TaskKind): string { + const prefix = TASK_ID_PREFIXES[kind]; + let suffix = ""; + for (let i = 0; i < 8; i++) { + suffix += ALPHABET[Math.floor(Math.random() * ALPHABET.length)]; + } + return `${prefix}-${suffix}`; +} diff --git a/src/kimi_cli/background/manager.py b/src/kimi_cli/background/manager.py deleted file mode 100644 index fc6399f15..000000000 --- a/src/kimi_cli/background/manager.py +++ /dev/null @@ -1,580 +0,0 @@ -from __future__ import annotations - -import asyncio -import os -import signal -import subprocess -import sys -import time -from pathlib import Path -from typing import TYPE_CHECKING, Any - -from kaos.local import local_kaos - -from kimi_cli.config import BackgroundConfig -from kimi_cli.notifications import NotificationEvent, NotificationManager -from kimi_cli.session import Session -from kimi_cli.utils.logging import logger - -if TYPE_CHECKING: - from kimi_cli.soul.agent import Runtime - -from .ids import generate_task_id -from .models import ( - TaskOutputChunk, - TaskRuntime, - TaskSpec, - TaskStatus, - TaskView, - is_terminal_status, -) -from .store import BackgroundTaskStore - - -class BackgroundTaskManager: - def __init__( - self, - session: Session, - config: BackgroundConfig, - *, - notifications: NotificationManager, - owner_role: str = "root", - ) -> None: - self._session = session - self._config = config - self._notifications = notifications - self._owner_role = owner_role - self._store = BackgroundTaskStore(session.context_file.parent / "tasks") - self._runtime: Runtime | None = None - self._live_agent_tasks: dict[str, asyncio.Task[None]] = {} - self._completion_event: asyncio.Event = asyncio.Event() - - @property - def completion_event(self) -> asyncio.Event: - """Event set when a new terminal notification is published. - - Not set immediately when a task becomes terminal — only after - ``reconcile()`` / ``publish_terminal_notifications()`` runs. - Deduplicated notifications do not trigger a repeat signal. - """ - return self._completion_event - - @property - def store(self) -> BackgroundTaskStore: - return self._store - - @property - def role(self) -> str: - return self._owner_role - - def copy_for_role(self, role: str) -> BackgroundTaskManager: - manager = BackgroundTaskManager( - self._session, - self._config, - notifications=self._notifications, - owner_role=role, - ) - manager._runtime = self._runtime - return manager - - def bind_runtime(self, runtime: Runtime) -> None: - self._runtime = runtime - - def _ensure_root(self) -> None: - if self._owner_role != "root": - raise RuntimeError("Background tasks are only supported from the root agent.") - - def _ensure_local_backend(self) -> None: - if self._session.work_dir_meta.kaos != local_kaos.name: - raise RuntimeError("Background tasks are only supported on local sessions.") - - def _active_task_count(self) -> int: - return sum( - 1 for view in self._store.list_views() if not is_terminal_status(view.runtime.status) - ) - - def _worker_command(self, task_dir: Path) -> list[str]: - if getattr(sys, "frozen", False): - return [ - sys.executable, - "__background-task-worker", - "--task-dir", - str(task_dir), - "--heartbeat-interval-ms", - str(self._config.worker_heartbeat_interval_ms), - "--control-poll-interval-ms", - str(self._config.wait_poll_interval_ms), - "--kill-grace-period-ms", - str(self._config.kill_grace_period_ms), - ] - return [ - sys.executable, - "-m", - "kimi_cli.cli", - "__background-task-worker", - "--task-dir", - str(task_dir), - "--heartbeat-interval-ms", - str(self._config.worker_heartbeat_interval_ms), - "--control-poll-interval-ms", - str(self._config.wait_poll_interval_ms), - "--kill-grace-period-ms", - str(self._config.kill_grace_period_ms), - ] - - def _launch_worker(self, task_dir: Path) -> int: - kwargs: dict[str, Any] = { - "stdin": subprocess.DEVNULL, - "stdout": subprocess.DEVNULL, - "stderr": subprocess.DEVNULL, - "cwd": str(task_dir), - } - if os.name == "nt": - kwargs["creationflags"] = getattr(subprocess, "CREATE_NEW_PROCESS_GROUP", 0) - else: - kwargs["start_new_session"] = True - - process = subprocess.Popen(self._worker_command(task_dir), **kwargs) - return process.pid - - def create_bash_task( - self, - *, - command: str, - description: str, - timeout_s: int, - tool_call_id: str, - shell_name: str, - shell_path: str, - cwd: str, - ) -> TaskView: - self._ensure_root() - self._ensure_local_backend() - - if self._active_task_count() >= self._config.max_running_tasks: - raise RuntimeError("Too many background tasks are already running.") - - task_id = generate_task_id("bash") - spec = TaskSpec( - id=task_id, - kind="bash", - session_id=self._session.id, - description=description, - tool_call_id=tool_call_id, - owner_role="root", - command=command, - shell_name=shell_name, - shell_path=shell_path, - cwd=cwd, - timeout_s=timeout_s, - ) - self._store.create_task(spec) - - runtime = self._store.read_runtime(task_id) - task_dir = self._store.task_dir(task_id) - try: - worker_pid = self._launch_worker(task_dir) - except Exception as exc: - runtime.status = "failed" - runtime.failure_reason = f"Failed to launch worker: {exc}" - runtime.finished_at = time.time() - runtime.updated_at = runtime.finished_at - self._store.write_runtime(task_id, runtime) - raise - - runtime = self._store.read_runtime(task_id) - if runtime.finished_at is None and ( - runtime.status == "created" - or (runtime.status == "starting" and runtime.worker_pid is None) - ): - runtime.status = "starting" - runtime.worker_pid = worker_pid - runtime.updated_at = time.time() - self._store.write_runtime(task_id, runtime) - return self._store.merged_view(task_id) - - def create_agent_task( - self, - *, - agent_id: str, - subagent_type: str, - prompt: str, - description: str, - tool_call_id: str, - model_override: str | None, - timeout_s: int | None = None, - resumed: bool = False, - ) -> TaskView: - from .agent_runner import BackgroundAgentRunner - - self._ensure_root() - self._ensure_local_backend() - if self._runtime is None: - raise RuntimeError("Background task manager is not bound to a runtime.") - if self._active_task_count() >= self._config.max_running_tasks: - raise RuntimeError("Too many background tasks are already running.") - - task_id = generate_task_id("agent") - spec = TaskSpec( - id=task_id, - kind="agent", - session_id=self._session.id, - description=description, - tool_call_id=tool_call_id, - owner_role="root", - kind_payload={ - "agent_id": agent_id, - "subagent_type": subagent_type, - "prompt": prompt, - "model_override": model_override, - "launch_mode": "background", - }, - ) - self._store.create_task(spec) - runtime = self._store.read_runtime(task_id) - runtime.status = "starting" - runtime.updated_at = time.time() - self._store.write_runtime(task_id, runtime) - effective_timeout = timeout_s or self._config.agent_task_timeout_s - task = asyncio.create_task( - BackgroundAgentRunner( - runtime=self._runtime, - manager=self, - task_id=task_id, - agent_id=agent_id, - subagent_type=subagent_type, - prompt=prompt, - model_override=model_override, - timeout_s=effective_timeout, - resumed=resumed, - ).run() - ) - self._live_agent_tasks[task_id] = task - return self._store.merged_view(task_id) - - def list_tasks( - self, - *, - status: TaskStatus | None = None, - limit: int | None = 20, - ) -> list[TaskView]: - tasks = self._store.list_views() - if status is not None: - tasks = [task for task in tasks if task.runtime.status == status] - if limit is None: - return tasks - return tasks[:limit] - - def get_task(self, task_id: str) -> TaskView | None: - try: - return self._store.merged_view(task_id) - except (FileNotFoundError, ValueError): - return None - - def resolve_output_path(self, task_id: str) -> Path: - """Return the canonical output path for *task_id*.""" - return self._store.output_path(task_id) - - def read_output( - self, - task_id: str, - *, - offset: int = 0, - max_bytes: int | None = None, - ) -> TaskOutputChunk: - view = self._store.merged_view(task_id) - return self._store.read_output( - task_id, - offset, - max_bytes or self._config.read_max_bytes, - status=view.runtime.status, - ) - - def tail_output( - self, - task_id: str, - *, - max_bytes: int | None = None, - max_lines: int | None = None, - ) -> str: - self._store.merged_view(task_id) - return self._store.tail_output( - task_id, - max_bytes=max_bytes or self._config.read_max_bytes, - max_lines=max_lines or self._config.notification_tail_lines, - ) - - async def wait(self, task_id: str, *, timeout_s: int = 30) -> TaskView: - end_time = time.monotonic() + timeout_s - while True: - view = self._store.merged_view(task_id) - if is_terminal_status(view.runtime.status): - return view - if time.monotonic() >= end_time: - return view - await asyncio.sleep(self._config.wait_poll_interval_ms / 1000) - - def _best_effort_kill(self, runtime: TaskRuntime) -> None: - try: - if os.name == "nt": - pid = runtime.child_pid or runtime.worker_pid - if pid is None: - return - subprocess.run( - ["taskkill", "/PID", str(pid), "/T", "/F"], - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - check=False, - ) - return - - if runtime.child_pgid is not None: - os.killpg(runtime.child_pgid, signal.SIGTERM) - return - if runtime.child_pid is not None: - os.kill(runtime.child_pid, signal.SIGTERM) - except ProcessLookupError: - pass - except Exception: - logger.exception("Failed to send best-effort kill signal") - - def kill(self, task_id: str, *, reason: str = "Killed by user") -> TaskView: - self._ensure_root() - view = self._store.merged_view(task_id) - if is_terminal_status(view.runtime.status): - return view - - if view.spec.kind == "agent": - self._mark_task_killed(task_id, reason) - if self._runtime is not None and self._runtime.approval_runtime is not None: - self._runtime.approval_runtime.cancel_by_source("background_agent", task_id) - task = self._live_agent_tasks.pop(task_id, None) - if task is not None: - task.cancel() - return self._store.merged_view(task_id) - - control = view.control.model_copy( - update={ - "kill_requested_at": time.time(), - "kill_reason": reason, - "force": False, - } - ) - self._store.write_control(task_id, control) - self._best_effort_kill(view.runtime) - return self._store.merged_view(task_id) - - def kill_all_active(self, *, reason: str = "CLI session ended") -> list[str]: - """Kill all non-terminal background tasks. Used during CLI shutdown.""" - killed: list[str] = [] - for view in self._store.list_views(): - if is_terminal_status(view.runtime.status): - continue - try: - self.kill(view.spec.id, reason=reason) - killed.append(view.spec.id) - except Exception: - logger.exception( - "Failed to kill task {task_id} during shutdown", - task_id=view.spec.id, - ) - return killed - - def recover(self) -> None: - now = time.time() - stale_after = self._config.worker_stale_after_ms / 1000 - for view in self._store.list_views(): - if is_terminal_status(view.runtime.status): - continue - if view.spec.kind == "agent": - if view.spec.id in self._live_agent_tasks: - continue - runtime = view.runtime.model_copy() - runtime.finished_at = now - runtime.updated_at = now - runtime.status = "lost" - runtime.failure_reason = "In-process background agent is no longer running" - self._store.write_runtime(view.spec.id, runtime) - agent_id = (view.spec.kind_payload or {}).get("agent_id") - if ( - isinstance(agent_id, str) - and self._runtime is not None - and self._runtime.subagent_store is not None - ): - record = self._runtime.subagent_store.get_instance(agent_id) - if record is not None and record.status == "running_background": - self._runtime.subagent_store.update_instance(agent_id, status="failed") - continue - last_progress_at = ( - view.runtime.heartbeat_at - or view.runtime.started_at - or view.runtime.updated_at - or view.spec.created_at - ) - if now - last_progress_at <= stale_after: - continue - - # Re-read runtime to narrow the race window with the worker process. - fresh_runtime = self._store.read_runtime(view.spec.id) - if is_terminal_status(fresh_runtime.status): - continue - fresh_progress = ( - fresh_runtime.heartbeat_at - or fresh_runtime.started_at - or fresh_runtime.updated_at - or view.spec.created_at - ) - if now - fresh_progress <= stale_after: - continue - - runtime = fresh_runtime.model_copy() - runtime.finished_at = now - runtime.updated_at = now - if view.control.kill_requested_at is not None: - runtime.status = "killed" - runtime.interrupted = True - runtime.failure_reason = view.control.kill_reason or "Killed during recovery" - else: - runtime.status = "lost" - runtime.failure_reason = ( - "Background worker never heartbeat after startup" - if fresh_runtime.heartbeat_at is None - else "Background worker heartbeat expired" - ) - self._store.write_runtime(view.spec.id, runtime) - - def reconcile(self, *, limit: int | None = None) -> list[str]: - self.recover() - return self.publish_terminal_notifications(limit=limit) - - def publish_terminal_notifications(self, *, limit: int | None = None) -> list[str]: - published: list[str] = [] - for view in self._store.list_views(): - if not is_terminal_status(view.runtime.status): - continue - - status = view.runtime.status - terminal_reason = "timed_out" if view.runtime.timed_out else status - match terminal_reason: - case "completed": - severity = "success" - title = f"Background task completed: {view.spec.description}" - case "timed_out": - severity = "error" - title = f"Background task timed out: {view.spec.description}" - case "failed": - severity = "error" - title = f"Background task failed: {view.spec.description}" - case "killed": - severity = "warning" - title = f"Background task stopped: {view.spec.description}" - case "lost": - severity = "warning" - title = f"Background task lost: {view.spec.description}" - case _: - severity = "info" - title = f"Background task updated: {view.spec.description}" - - body_lines = [ - f"Task ID: {view.spec.id}", - f"Status: {status}", - f"Description: {view.spec.description}", - ] - if terminal_reason != status: - body_lines.append(f"Terminal reason: {terminal_reason}") - if view.runtime.exit_code is not None: - body_lines.append(f"Exit code: {view.runtime.exit_code}") - if view.runtime.failure_reason: - body_lines.append(f"Failure reason: {view.runtime.failure_reason}") - - event = NotificationEvent( - id=self._notifications.new_id(), - category="task", - type=f"task.{terminal_reason}", - source_kind="background_task", - source_id=view.spec.id, - title=title, - body="\n".join(body_lines), - severity=severity, - payload={ - "task_id": view.spec.id, - "task_kind": view.spec.kind, - "status": status, - "description": view.spec.description, - "exit_code": view.runtime.exit_code, - "interrupted": view.runtime.interrupted, - "timed_out": view.runtime.timed_out, - "terminal_reason": terminal_reason, - "failure_reason": view.runtime.failure_reason, - }, - dedupe_key=f"background_task:{view.spec.id}:{terminal_reason}", - ) - notification = self._notifications.publish(event) - if notification.event.id == event.id: - published.append(notification.event.id) - self._completion_event.set() - if limit is not None and len(published) >= limit: - break - return published - - def _mark_task_running(self, task_id: str) -> None: - runtime = self._store.read_runtime(task_id) - if is_terminal_status(runtime.status): - return - runtime.status = "running" - runtime.updated_at = time.time() - runtime.heartbeat_at = runtime.updated_at - runtime.failure_reason = None - self._store.write_runtime(task_id, runtime) - - def _mark_task_awaiting_approval(self, task_id: str, reason: str) -> None: - runtime = self._store.read_runtime(task_id) - if is_terminal_status(runtime.status): - return - runtime.status = "awaiting_approval" - runtime.updated_at = time.time() - runtime.failure_reason = reason - self._store.write_runtime(task_id, runtime) - - def _mark_task_completed(self, task_id: str) -> None: - runtime = self._store.read_runtime(task_id) - if is_terminal_status(runtime.status): - return - runtime.status = "completed" - runtime.updated_at = time.time() - runtime.finished_at = runtime.updated_at - runtime.failure_reason = None - self._store.write_runtime(task_id, runtime) - - def _mark_task_failed(self, task_id: str, reason: str) -> None: - runtime = self._store.read_runtime(task_id) - if is_terminal_status(runtime.status): - return - runtime.status = "failed" - runtime.updated_at = time.time() - runtime.finished_at = runtime.updated_at - runtime.failure_reason = reason - self._store.write_runtime(task_id, runtime) - - def _mark_task_timed_out(self, task_id: str, reason: str) -> None: - runtime = self._store.read_runtime(task_id) - if is_terminal_status(runtime.status): - return - runtime.status = "failed" - runtime.updated_at = time.time() - runtime.finished_at = runtime.updated_at - runtime.interrupted = True - runtime.timed_out = True - runtime.failure_reason = reason - self._store.write_runtime(task_id, runtime) - - def _mark_task_killed(self, task_id: str, reason: str) -> None: - runtime = self._store.read_runtime(task_id) - if is_terminal_status(runtime.status): - return - runtime.status = "killed" - runtime.updated_at = time.time() - runtime.finished_at = runtime.updated_at - runtime.interrupted = True - runtime.failure_reason = reason - self._store.write_runtime(task_id, runtime) diff --git a/src/kimi_cli/background/manager.ts b/src/kimi_cli/background/manager.ts new file mode 100644 index 000000000..70f144bf1 --- /dev/null +++ b/src/kimi_cli/background/manager.ts @@ -0,0 +1,367 @@ +/** + * Background task manager — corresponds to Python background/manager.py + * Manages task lifecycle: create, list, stop, recover. + */ + +import { join } from "node:path"; +import { logger } from "../utils/logging.ts"; +import type { BackgroundConfig } from "../config.ts"; +import type { Session } from "../session.ts"; +import { generateTaskId } from "./ids.ts"; +import { + type TaskSpec, + type TaskRuntime, + type TaskView, + type TaskOutputChunk, + type TaskStatus, + isTerminalStatus, +} from "./models.ts"; +import { BackgroundTaskStore } from "./store.ts"; + +export class BackgroundTaskManager { + private _session: Session; + private _config: BackgroundConfig; + private _ownerRole: string; + private _store: BackgroundTaskStore; + + constructor( + session: Session, + config: BackgroundConfig, + opts?: { ownerRole?: string }, + ) { + this._session = session; + this._config = config; + this._ownerRole = opts?.ownerRole ?? "root"; + // Store tasks dir next to context file + const tasksDir = join(session.dir, "tasks"); + this._store = new BackgroundTaskStore(tasksDir); + } + + get store(): BackgroundTaskStore { + return this._store; + } + + get role(): string { + return this._ownerRole; + } + + private ensureRoot(): void { + if (this._ownerRole !== "root") { + throw new Error("Background tasks are only supported from the root agent."); + } + } + + private activeTaskCount(): number { + return this._store.listViews().filter((v) => !isTerminalStatus(v.runtime.status)).length; + } + + createBashTask(opts: { + command: string; + description: string; + timeoutS: number; + toolCallId: string; + shellName: string; + shellPath: string; + cwd: string; + }): TaskView { + this.ensureRoot(); + + if (this.activeTaskCount() >= this._config.max_running_tasks) { + throw new Error("Too many background tasks are already running."); + } + + const taskId = generateTaskId("bash"); + const now = Date.now() / 1000; + const spec: TaskSpec = { + version: 1, + id: taskId, + kind: "bash", + sessionId: this._session.id, + description: opts.description, + toolCallId: opts.toolCallId, + ownerRole: "root", + createdAt: now, + command: opts.command, + shellName: opts.shellName, + shellPath: opts.shellPath, + cwd: opts.cwd, + timeoutS: opts.timeoutS, + }; + this._store.createTask(spec); + + // Launch worker subprocess + const taskDir = this._store.taskDir(taskId); + let runtime = this._store.readRuntime(taskId); + try { + const workerPid = this.launchWorker(taskDir); + runtime = this._store.readRuntime(taskId); + if ( + runtime.finishedAt == null && + (runtime.status === "created" || + (runtime.status === "starting" && runtime.workerPid == null)) + ) { + runtime.status = "starting"; + runtime.workerPid = workerPid; + runtime.updatedAt = Date.now() / 1000; + this._store.writeRuntime(taskId, runtime); + } + } catch (err) { + runtime.status = "failed"; + runtime.failureReason = `Failed to launch worker: ${err}`; + runtime.finishedAt = Date.now() / 1000; + runtime.updatedAt = runtime.finishedAt; + this._store.writeRuntime(taskId, runtime); + throw err; + } + + return this._store.mergedView(taskId); + } + + private launchWorker(taskDir: string): number { + const proc = Bun.spawn( + [process.execPath, "--run", "background-worker", "--task-dir", taskDir], + { + stdin: "ignore", + stdout: "ignore", + stderr: "ignore", + cwd: taskDir, + }, + ); + return proc.pid; + } + + createAgentTask(opts: { + agentId: string; + subagentType: string; + prompt: string; + description: string; + toolCallId: string; + modelOverride?: string; + timeoutS?: number; + resumed?: boolean; + }): TaskView { + this.ensureRoot(); + + if (this.activeTaskCount() >= this._config.max_running_tasks) { + throw new Error("Too many background tasks are already running."); + } + + const taskId = generateTaskId("agent"); + const now = Date.now() / 1000; + const spec: TaskSpec = { + version: 1, + id: taskId, + kind: "agent", + sessionId: this._session.id, + description: opts.description, + toolCallId: opts.toolCallId, + ownerRole: "root", + createdAt: now, + kindPayload: { + agent_id: opts.agentId, + subagent_type: opts.subagentType, + prompt: opts.prompt, + model_override: opts.modelOverride, + launch_mode: "background", + }, + }; + this._store.createTask(spec); + + const runtime = this._store.readRuntime(taskId); + runtime.status = "starting"; + runtime.updatedAt = Date.now() / 1000; + this._store.writeRuntime(taskId, runtime); + + return this._store.mergedView(taskId); + } + + listTasks(opts?: { status?: TaskStatus; limit?: number }): TaskView[] { + let tasks = this._store.listViews(); + if (opts?.status != null) { + tasks = tasks.filter((t) => t.runtime.status === opts.status); + } + const limit = opts?.limit ?? 20; + return tasks.slice(0, limit); + } + + getTask(taskId: string): TaskView | undefined { + try { + return this._store.mergedView(taskId); + } catch { + return undefined; + } + } + + readOutput(taskId: string, opts?: { offset?: number; maxBytes?: number }): TaskOutputChunk { + const view = this._store.mergedView(taskId); + return this._store.readOutput( + taskId, + opts?.offset ?? 0, + opts?.maxBytes ?? this._config.read_max_bytes, + view.runtime.status, + ); + } + + tailOutput(taskId: string, opts?: { maxBytes?: number; maxLines?: number }): string { + this._store.mergedView(taskId); // validate existence + return this._store.tailOutput( + taskId, + opts?.maxBytes ?? this._config.read_max_bytes, + opts?.maxLines ?? this._config.notification_tail_lines, + ); + } + + async wait(taskId: string, timeoutS = 30): Promise { + const endTime = performance.now() + timeoutS * 1000; + while (true) { + const view = this._store.mergedView(taskId); + if (isTerminalStatus(view.runtime.status)) return view; + if (performance.now() >= endTime) return view; + await Bun.sleep(this._config.wait_poll_interval_ms); + } + } + + kill(taskId: string, reason = "Killed by user"): TaskView { + this.ensureRoot(); + const view = this._store.mergedView(taskId); + if (isTerminalStatus(view.runtime.status)) return view; + + if (view.spec.kind === "agent") { + this.markTaskKilled(taskId, reason); + return this._store.mergedView(taskId); + } + + // Bash: write control file, best-effort signal + const control = { ...view.control }; + control.killRequestedAt = Date.now() / 1000; + control.killReason = reason; + control.force = false; + this._store.writeControl(taskId, control); + this.bestEffortKill(view.runtime); + return this._store.mergedView(taskId); + } + + killAllActive(reason = "CLI session ended"): string[] { + const killed: string[] = []; + for (const view of this._store.listViews()) { + if (isTerminalStatus(view.runtime.status)) continue; + try { + this.kill(view.spec.id, reason); + killed.push(view.spec.id); + } catch { + logger.error(`Failed to kill task ${view.spec.id} during shutdown`); + } + } + return killed; + } + + recover(): void { + const now = Date.now() / 1000; + const staleAfter = this._config.worker_stale_after_ms / 1000; + + for (const view of this._store.listViews()) { + if (isTerminalStatus(view.runtime.status)) continue; + + if (view.spec.kind === "agent") { + // Agent tasks without live runner are lost + const runtime = { ...view.runtime }; + runtime.finishedAt = now; + runtime.updatedAt = now; + runtime.status = "lost"; + runtime.failureReason = "In-process background agent is no longer running"; + this._store.writeRuntime(view.spec.id, runtime); + continue; + } + + const lastProgressAt = + view.runtime.heartbeatAt ?? + view.runtime.startedAt ?? + view.runtime.updatedAt ?? + view.spec.createdAt; + if (now - lastProgressAt <= staleAfter) continue; + + // Re-read to narrow race window + const freshRuntime = this._store.readRuntime(view.spec.id); + if (isTerminalStatus(freshRuntime.status)) continue; + const freshProgress = + freshRuntime.heartbeatAt ?? + freshRuntime.startedAt ?? + freshRuntime.updatedAt ?? + view.spec.createdAt; + if (now - freshProgress <= staleAfter) continue; + + const runtime = { ...freshRuntime }; + runtime.finishedAt = now; + runtime.updatedAt = now; + if (view.control.killRequestedAt != null) { + runtime.status = "killed"; + runtime.interrupted = true; + runtime.failureReason = view.control.killReason ?? "Killed during recovery"; + } else { + runtime.status = "lost"; + runtime.failureReason = + freshRuntime.heartbeatAt == null + ? "Background worker never heartbeat after startup" + : "Background worker heartbeat expired"; + } + this._store.writeRuntime(view.spec.id, runtime); + } + } + + reconcile(): void { + this.recover(); + } + + // ── Internal status helpers ── + + markTaskRunning(taskId: string): void { + const runtime = this._store.readRuntime(taskId); + if (isTerminalStatus(runtime.status)) return; + runtime.status = "running"; + runtime.updatedAt = Date.now() / 1000; + runtime.heartbeatAt = runtime.updatedAt; + runtime.failureReason = undefined; + this._store.writeRuntime(taskId, runtime); + } + + markTaskCompleted(taskId: string): void { + const runtime = this._store.readRuntime(taskId); + if (isTerminalStatus(runtime.status)) return; + runtime.status = "completed"; + runtime.updatedAt = Date.now() / 1000; + runtime.finishedAt = runtime.updatedAt; + runtime.failureReason = undefined; + this._store.writeRuntime(taskId, runtime); + } + + markTaskFailed(taskId: string, reason: string): void { + const runtime = this._store.readRuntime(taskId); + if (isTerminalStatus(runtime.status)) return; + runtime.status = "failed"; + runtime.updatedAt = Date.now() / 1000; + runtime.finishedAt = runtime.updatedAt; + runtime.failureReason = reason; + this._store.writeRuntime(taskId, runtime); + } + + markTaskKilled(taskId: string, reason: string): void { + const runtime = this._store.readRuntime(taskId); + if (isTerminalStatus(runtime.status)) return; + runtime.status = "killed"; + runtime.updatedAt = Date.now() / 1000; + runtime.finishedAt = runtime.updatedAt; + runtime.interrupted = true; + runtime.failureReason = reason; + this._store.writeRuntime(taskId, runtime); + } + + private bestEffortKill(runtime: TaskRuntime): void { + try { + const pid = runtime.childPgid ?? runtime.childPid ?? runtime.workerPid; + if (pid == null) return; + process.kill(pid, "SIGTERM"); + } catch { + // Process may already be gone + } + } +} diff --git a/src/kimi_cli/background/models.py b/src/kimi_cli/background/models.py deleted file mode 100644 index a3cb033ba..000000000 --- a/src/kimi_cli/background/models.py +++ /dev/null @@ -1,105 +0,0 @@ -from __future__ import annotations - -import time -from typing import Any, Literal - -from pydantic import BaseModel, ConfigDict, Field, field_validator - -type TaskKind = Literal["bash", "agent"] -type TaskStatus = Literal[ - "created", - "starting", - "running", - "awaiting_approval", - "completed", - "failed", - "killed", - "lost", -] -type TaskOwnerRole = Literal["root", "subagent"] - -TERMINAL_TASK_STATUSES: tuple[TaskStatus, ...] = ("completed", "failed", "killed", "lost") - - -def is_terminal_status(status: TaskStatus) -> bool: - return status in TERMINAL_TASK_STATUSES - - -class TaskSpec(BaseModel): - model_config = ConfigDict(extra="ignore") - - version: int = 1 - id: str - kind: TaskKind - session_id: str - description: str - tool_call_id: str - owner_role: TaskOwnerRole = "root" - created_at: float = Field(default_factory=time.time) - - @field_validator("owner_role", mode="before") - @classmethod - def _normalize_owner_role(cls, v: str) -> str: - if v in ("fixed_subagent", "dynamic_subagent"): - return "subagent" - return v - - # Bash-specific fields for V1. Future task types can use kind_payload. - command: str | None = None - shell_name: str | None = None - shell_path: str | None = None - cwd: str | None = None - timeout_s: int | None = None - kind_payload: dict[str, Any] | None = None - - -class TaskRuntime(BaseModel): - model_config = ConfigDict(extra="ignore") - - status: TaskStatus = "created" - worker_pid: int | None = None - child_pid: int | None = None - child_pgid: int | None = None - started_at: float | None = None - heartbeat_at: float | None = None - updated_at: float = Field(default_factory=time.time) - finished_at: float | None = None - exit_code: int | None = None - interrupted: bool = False - timed_out: bool = False - failure_reason: str | None = None - - -class TaskControl(BaseModel): - model_config = ConfigDict(extra="ignore") - - kill_requested_at: float | None = None - kill_reason: str | None = None - force: bool = False - - -class TaskConsumerState(BaseModel): - model_config = ConfigDict(extra="ignore") - - last_seen_output_size: int = 0 - last_viewed_at: float | None = None - - -class TaskView(BaseModel): - model_config = ConfigDict(extra="ignore") - - spec: TaskSpec - runtime: TaskRuntime - control: TaskControl - consumer: TaskConsumerState - - -class TaskOutputChunk(BaseModel): - model_config = ConfigDict(extra="ignore") - - task_id: str - offset: int - next_offset: int - text: str - eof: bool - status: TaskStatus diff --git a/src/kimi_cli/background/models.ts b/src/kimi_cli/background/models.ts new file mode 100644 index 000000000..5c9a4744b --- /dev/null +++ b/src/kimi_cli/background/models.ts @@ -0,0 +1,212 @@ +/** + * Background task models — corresponds to Python background/models.py + */ + +export type TaskKind = "bash" | "agent"; +export type TaskStatus = + | "created" + | "starting" + | "running" + | "awaiting_approval" + | "completed" + | "failed" + | "killed" + | "lost"; +export type TaskOwnerRole = "root" | "subagent"; + +export const TERMINAL_TASK_STATUSES: readonly TaskStatus[] = [ + "completed", + "failed", + "killed", + "lost", +] as const; + +export function isTerminalStatus(status: TaskStatus): boolean { + return (TERMINAL_TASK_STATUSES as readonly string[]).includes(status); +} + +export interface TaskSpec { + version: number; + id: string; + kind: TaskKind; + sessionId: string; + description: string; + toolCallId: string; + ownerRole: TaskOwnerRole; + createdAt: number; + // Bash-specific + command?: string; + shellName?: string; + shellPath?: string; + cwd?: string; + timeoutS?: number; + // Generic payload for other task types + kindPayload?: Record; +} + +export interface TaskRuntime { + status: TaskStatus; + workerPid?: number; + childPid?: number; + childPgid?: number; + startedAt?: number; + heartbeatAt?: number; + updatedAt: number; + finishedAt?: number; + exitCode?: number; + interrupted: boolean; + timedOut: boolean; + failureReason?: string; +} + +export interface TaskControl { + killRequestedAt?: number; + killReason?: string; + force: boolean; +} + +export interface TaskConsumerState { + lastSeenOutputSize: number; + lastViewedAt?: number; +} + +export interface TaskView { + spec: TaskSpec; + runtime: TaskRuntime; + control: TaskControl; + consumer: TaskConsumerState; +} + +export interface TaskOutputChunk { + taskId: string; + offset: number; + nextOffset: number; + text: string; + eof: boolean; + status: TaskStatus; +} + +// ── JSON serialization helpers (snake_case ↔ camelCase) ── + +export function taskSpecToJson(spec: TaskSpec): Record { + return { + version: spec.version, + id: spec.id, + kind: spec.kind, + session_id: spec.sessionId, + description: spec.description, + tool_call_id: spec.toolCallId, + owner_role: spec.ownerRole, + created_at: spec.createdAt, + command: spec.command, + shell_name: spec.shellName, + shell_path: spec.shellPath, + cwd: spec.cwd, + timeout_s: spec.timeoutS, + kind_payload: spec.kindPayload, + }; +} + +export function taskSpecFromJson(data: Record): TaskSpec { + let ownerRole = String(data.owner_role ?? "root"); + if (ownerRole === "fixed_subagent" || ownerRole === "dynamic_subagent") { + ownerRole = "subagent"; + } + return { + version: Number(data.version ?? 1), + id: String(data.id), + kind: String(data.kind) as TaskKind, + sessionId: String(data.session_id), + description: String(data.description ?? ""), + toolCallId: String(data.tool_call_id ?? ""), + ownerRole: ownerRole as TaskOwnerRole, + createdAt: Number(data.created_at ?? Date.now() / 1000), + command: data.command != null ? String(data.command) : undefined, + shellName: data.shell_name != null ? String(data.shell_name) : undefined, + shellPath: data.shell_path != null ? String(data.shell_path) : undefined, + cwd: data.cwd != null ? String(data.cwd) : undefined, + timeoutS: data.timeout_s != null ? Number(data.timeout_s) : undefined, + kindPayload: data.kind_payload as Record | undefined, + }; +} + +export function taskRuntimeToJson(rt: TaskRuntime): Record { + return { + status: rt.status, + worker_pid: rt.workerPid, + child_pid: rt.childPid, + child_pgid: rt.childPgid, + started_at: rt.startedAt, + heartbeat_at: rt.heartbeatAt, + updated_at: rt.updatedAt, + finished_at: rt.finishedAt, + exit_code: rt.exitCode, + interrupted: rt.interrupted, + timed_out: rt.timedOut, + failure_reason: rt.failureReason, + }; +} + +export function taskRuntimeFromJson(data: Record): TaskRuntime { + return { + status: (String(data.status ?? "created")) as TaskStatus, + workerPid: data.worker_pid != null ? Number(data.worker_pid) : undefined, + childPid: data.child_pid != null ? Number(data.child_pid) : undefined, + childPgid: data.child_pgid != null ? Number(data.child_pgid) : undefined, + startedAt: data.started_at != null ? Number(data.started_at) : undefined, + heartbeatAt: data.heartbeat_at != null ? Number(data.heartbeat_at) : undefined, + updatedAt: Number(data.updated_at ?? Date.now() / 1000), + finishedAt: data.finished_at != null ? Number(data.finished_at) : undefined, + exitCode: data.exit_code != null ? Number(data.exit_code) : undefined, + interrupted: Boolean(data.interrupted ?? false), + timedOut: Boolean(data.timed_out ?? false), + failureReason: data.failure_reason != null ? String(data.failure_reason) : undefined, + }; +} + +export function taskControlToJson(ctrl: TaskControl): Record { + return { + kill_requested_at: ctrl.killRequestedAt, + kill_reason: ctrl.killReason, + force: ctrl.force, + }; +} + +export function taskControlFromJson(data: Record): TaskControl { + return { + killRequestedAt: data.kill_requested_at != null ? Number(data.kill_requested_at) : undefined, + killReason: data.kill_reason != null ? String(data.kill_reason) : undefined, + force: Boolean(data.force ?? false), + }; +} + +export function taskConsumerToJson(cs: TaskConsumerState): Record { + return { + last_seen_output_size: cs.lastSeenOutputSize, + last_viewed_at: cs.lastViewedAt, + }; +} + +export function taskConsumerFromJson(data: Record): TaskConsumerState { + return { + lastSeenOutputSize: Number(data.last_seen_output_size ?? 0), + lastViewedAt: data.last_viewed_at != null ? Number(data.last_viewed_at) : undefined, + }; +} + +export function newTaskRuntime(): TaskRuntime { + return { + status: "created", + updatedAt: Date.now() / 1000, + interrupted: false, + timedOut: false, + }; +} + +export function newTaskControl(): TaskControl { + return { force: false }; +} + +export function newTaskConsumerState(): TaskConsumerState { + return { lastSeenOutputSize: 0 }; +} diff --git a/src/kimi_cli/background/store.py b/src/kimi_cli/background/store.py deleted file mode 100644 index ff3552793..000000000 --- a/src/kimi_cli/background/store.py +++ /dev/null @@ -1,196 +0,0 @@ -from __future__ import annotations - -import os -import re -from pathlib import Path - -from kimi_cli.utils.io import atomic_json_write - -from .models import ( - TaskConsumerState, - TaskControl, - TaskOutputChunk, - TaskRuntime, - TaskSpec, - TaskStatus, - TaskView, -) - -_VALID_TASK_ID = re.compile(r"^[a-z0-9][a-z0-9\-]{1,24}$") - - -def _validate_task_id(task_id: str) -> None: - if not _VALID_TASK_ID.match(task_id): - raise ValueError(f"Invalid task_id: {task_id!r}") - - -class BackgroundTaskStore: - SPEC_FILE = "spec.json" - RUNTIME_FILE = "runtime.json" - CONTROL_FILE = "control.json" - CONSUMER_FILE = "consumer.json" - OUTPUT_FILE = "output.log" - - def __init__(self, root: Path): - self._root = root - - @property - def root(self) -> Path: - return self._root - - def _ensure_root(self) -> Path: - """Return the root directory, creating it if it does not exist.""" - self._root.mkdir(parents=True, exist_ok=True) - return self._root - - def task_dir(self, task_id: str) -> Path: - _validate_task_id(task_id) - path = self._ensure_root() / task_id - path.mkdir(parents=True, exist_ok=True) - return path - - def task_path(self, task_id: str) -> Path: - _validate_task_id(task_id) - return self.root / task_id - - def spec_path(self, task_id: str) -> Path: - return self.task_path(task_id) / self.SPEC_FILE - - def runtime_path(self, task_id: str) -> Path: - return self.task_path(task_id) / self.RUNTIME_FILE - - def control_path(self, task_id: str) -> Path: - return self.task_path(task_id) / self.CONTROL_FILE - - def consumer_path(self, task_id: str) -> Path: - return self.task_path(task_id) / self.CONSUMER_FILE - - def output_path(self, task_id: str) -> Path: - return self.task_path(task_id) / self.OUTPUT_FILE - - def create_task(self, spec: TaskSpec) -> None: - task_dir = self.task_dir(spec.id) - atomic_json_write(spec.model_dump(mode="json"), task_dir / self.SPEC_FILE) - atomic_json_write(TaskRuntime().model_dump(mode="json"), task_dir / self.RUNTIME_FILE) - atomic_json_write(TaskControl().model_dump(mode="json"), task_dir / self.CONTROL_FILE) - atomic_json_write( - TaskConsumerState().model_dump(mode="json"), - task_dir / self.CONSUMER_FILE, - ) - self.output_path(spec.id).touch(exist_ok=True) - - def list_task_ids(self) -> list[str]: - if not self.root.exists(): - return [] - task_ids: list[str] = [] - for path in sorted(self.root.iterdir()): - if not path.is_dir(): - continue - if not (path / self.SPEC_FILE).exists(): - continue - task_ids.append(path.name) - return task_ids - - def write_spec(self, spec: TaskSpec) -> None: - atomic_json_write(spec.model_dump(mode="json"), self.spec_path(spec.id)) - - def read_spec(self, task_id: str) -> TaskSpec: - return TaskSpec.model_validate_json(self.spec_path(task_id).read_text(encoding="utf-8")) - - def write_runtime(self, task_id: str, runtime: TaskRuntime) -> None: - atomic_json_write(runtime.model_dump(mode="json"), self.runtime_path(task_id)) - - def read_runtime(self, task_id: str) -> TaskRuntime: - path = self.runtime_path(task_id) - if not path.exists(): - return TaskRuntime() - return TaskRuntime.model_validate_json(path.read_text(encoding="utf-8")) - - def write_control(self, task_id: str, control: TaskControl) -> None: - atomic_json_write(control.model_dump(mode="json"), self.control_path(task_id)) - - def read_control(self, task_id: str) -> TaskControl: - path = self.control_path(task_id) - if not path.exists(): - return TaskControl() - return TaskControl.model_validate_json(path.read_text(encoding="utf-8")) - - def write_consumer(self, task_id: str, consumer: TaskConsumerState) -> None: - atomic_json_write(consumer.model_dump(mode="json"), self.consumer_path(task_id)) - - def read_consumer(self, task_id: str) -> TaskConsumerState: - path = self.consumer_path(task_id) - if not path.exists(): - return TaskConsumerState() - return TaskConsumerState.model_validate_json(path.read_text(encoding="utf-8")) - - def merged_view(self, task_id: str) -> TaskView: - return TaskView( - spec=self.read_spec(task_id), - runtime=self.read_runtime(task_id), - control=self.read_control(task_id), - consumer=self.read_consumer(task_id), - ) - - def list_views(self) -> list[TaskView]: - views = [self.merged_view(task_id) for task_id in self.list_task_ids()] - views.sort( - key=lambda view: view.runtime.updated_at or view.spec.created_at, - reverse=True, - ) - return views - - def read_output( - self, - task_id: str, - offset: int, - max_bytes: int, - *, - status: TaskStatus, - path_override: Path | None = None, - ) -> TaskOutputChunk: - path = path_override if path_override is not None else self.output_path(task_id) - if not path.exists(): - return TaskOutputChunk( - task_id=task_id, - offset=offset, - next_offset=offset, - text="", - eof=True, - status=status, - ) - - with path.open("rb") as f: - f.seek(0, os.SEEK_END) - total_size = f.tell() - bounded_offset = min(max(offset, 0), total_size) - f.seek(bounded_offset) - content = f.read(max_bytes) - - next_offset = bounded_offset + len(content) - return TaskOutputChunk( - task_id=task_id, - offset=bounded_offset, - next_offset=next_offset, - text=content.decode("utf-8", errors="replace"), - eof=next_offset >= total_size, - status=status, - ) - - def tail_output(self, task_id: str, max_bytes: int, max_lines: int) -> str: - path = self.output_path(task_id) - if not path.exists(): - return "" - - with path.open("rb") as f: - f.seek(0, os.SEEK_END) - total_size = f.tell() - start = max(0, total_size - max_bytes) - f.seek(start) - content = f.read() - - text = content.decode("utf-8", errors="replace") - lines = text.splitlines() - if len(lines) > max_lines: - lines = lines[-max_lines:] - return "\n".join(lines) diff --git a/src/kimi_cli/background/store.ts b/src/kimi_cli/background/store.ts new file mode 100644 index 000000000..6ba12f8fc --- /dev/null +++ b/src/kimi_cli/background/store.ts @@ -0,0 +1,228 @@ +/** + * Background task store — corresponds to Python background/store.py + * File-based persistence: per-task directory with spec.json, runtime.json, etc. + */ + +import { join } from "node:path"; +import { mkdirSync, existsSync, readdirSync, readFileSync, writeFileSync, statSync } from "node:fs"; +import { + type TaskSpec, + type TaskRuntime, + type TaskControl, + type TaskConsumerState, + type TaskView, + type TaskOutputChunk, + type TaskStatus, + taskSpecToJson, + taskSpecFromJson, + taskRuntimeToJson, + taskRuntimeFromJson, + taskControlToJson, + taskControlFromJson, + taskConsumerToJson, + taskConsumerFromJson, + newTaskRuntime, + newTaskControl, + newTaskConsumerState, +} from "./models.ts"; + +const VALID_TASK_ID = /^[a-z0-9][a-z0-9\-]{1,24}$/; + +function validateTaskId(taskId: string): void { + if (!VALID_TASK_ID.test(taskId)) { + throw new Error(`Invalid task_id: ${taskId}`); + } +} + +function atomicJsonWrite(data: Record, filePath: string): void { + const tmpPath = filePath + ".tmp"; + writeFileSync(tmpPath, JSON.stringify(data, null, 2), "utf-8"); + // Bun.fs.renameSync is atomic on same filesystem + const { renameSync } = require("node:fs"); + renameSync(tmpPath, filePath); +} + +export class BackgroundTaskStore { + static readonly SPEC_FILE = "spec.json"; + static readonly RUNTIME_FILE = "runtime.json"; + static readonly CONTROL_FILE = "control.json"; + static readonly CONSUMER_FILE = "consumer.json"; + static readonly OUTPUT_FILE = "output.log"; + + private _root: string; + + constructor(root: string) { + this._root = root; + } + + get root(): string { + return this._root; + } + + private ensureRoot(): string { + if (!existsSync(this._root)) { + mkdirSync(this._root, { recursive: true }); + } + return this._root; + } + + taskDir(taskId: string): string { + validateTaskId(taskId); + const path = join(this.ensureRoot(), taskId); + if (!existsSync(path)) { + mkdirSync(path, { recursive: true }); + } + return path; + } + + taskPath(taskId: string): string { + validateTaskId(taskId); + return join(this._root, taskId); + } + + specPath(taskId: string): string { + return join(this.taskPath(taskId), BackgroundTaskStore.SPEC_FILE); + } + + runtimePath(taskId: string): string { + return join(this.taskPath(taskId), BackgroundTaskStore.RUNTIME_FILE); + } + + controlPath(taskId: string): string { + return join(this.taskPath(taskId), BackgroundTaskStore.CONTROL_FILE); + } + + consumerPath(taskId: string): string { + return join(this.taskPath(taskId), BackgroundTaskStore.CONSUMER_FILE); + } + + outputPath(taskId: string): string { + return join(this.taskPath(taskId), BackgroundTaskStore.OUTPUT_FILE); + } + + createTask(spec: TaskSpec): void { + const dir = this.taskDir(spec.id); + atomicJsonWrite(taskSpecToJson(spec), join(dir, BackgroundTaskStore.SPEC_FILE)); + atomicJsonWrite(taskRuntimeToJson(newTaskRuntime()), join(dir, BackgroundTaskStore.RUNTIME_FILE)); + atomicJsonWrite(taskControlToJson(newTaskControl()), join(dir, BackgroundTaskStore.CONTROL_FILE)); + atomicJsonWrite(taskConsumerToJson(newTaskConsumerState()), join(dir, BackgroundTaskStore.CONSUMER_FILE)); + // Touch output file + writeFileSync(join(dir, BackgroundTaskStore.OUTPUT_FILE), "", "utf-8"); + } + + listTaskIds(): string[] { + if (!existsSync(this._root)) return []; + const taskIds: string[] = []; + for (const entry of readdirSync(this._root).sort()) { + const dirPath = join(this._root, entry); + try { + if (!statSync(dirPath).isDirectory()) continue; + } catch { + continue; + } + if (!existsSync(join(dirPath, BackgroundTaskStore.SPEC_FILE))) continue; + taskIds.push(entry); + } + return taskIds; + } + + writeSpec(spec: TaskSpec): void { + atomicJsonWrite(taskSpecToJson(spec), this.specPath(spec.id)); + } + + readSpec(taskId: string): TaskSpec { + const data = JSON.parse(readFileSync(this.specPath(taskId), "utf-8")); + return taskSpecFromJson(data); + } + + writeRuntime(taskId: string, runtime: TaskRuntime): void { + atomicJsonWrite(taskRuntimeToJson(runtime), this.runtimePath(taskId)); + } + + readRuntime(taskId: string): TaskRuntime { + const path = this.runtimePath(taskId); + if (!existsSync(path)) return newTaskRuntime(); + const data = JSON.parse(readFileSync(path, "utf-8")); + return taskRuntimeFromJson(data); + } + + writeControl(taskId: string, control: TaskControl): void { + atomicJsonWrite(taskControlToJson(control), this.controlPath(taskId)); + } + + readControl(taskId: string): TaskControl { + const path = this.controlPath(taskId); + if (!existsSync(path)) return newTaskControl(); + const data = JSON.parse(readFileSync(path, "utf-8")); + return taskControlFromJson(data); + } + + writeConsumer(taskId: string, consumer: TaskConsumerState): void { + atomicJsonWrite(taskConsumerToJson(consumer), this.consumerPath(taskId)); + } + + readConsumer(taskId: string): TaskConsumerState { + const path = this.consumerPath(taskId); + if (!existsSync(path)) return newTaskConsumerState(); + const data = JSON.parse(readFileSync(path, "utf-8")); + return taskConsumerFromJson(data); + } + + mergedView(taskId: string): TaskView { + return { + spec: this.readSpec(taskId), + runtime: this.readRuntime(taskId), + control: this.readControl(taskId), + consumer: this.readConsumer(taskId), + }; + } + + listViews(): TaskView[] { + const views = this.listTaskIds().map((id) => this.mergedView(id)); + views.sort((a, b) => (b.runtime.updatedAt || b.spec.createdAt) - (a.runtime.updatedAt || a.spec.createdAt)); + return views; + } + + readOutput( + taskId: string, + offset: number, + maxBytes: number, + status: TaskStatus, + ): TaskOutputChunk { + const path = this.outputPath(taskId); + if (!existsSync(path)) { + return { taskId, offset, nextOffset: offset, text: "", eof: true, status }; + } + + const buf = readFileSync(path); + const totalSize = buf.length; + const boundedOffset = Math.min(Math.max(offset, 0), totalSize); + const content = buf.subarray(boundedOffset, boundedOffset + maxBytes); + const nextOffset = boundedOffset + content.length; + + return { + taskId, + offset: boundedOffset, + nextOffset, + text: content.toString("utf-8"), + eof: nextOffset >= totalSize, + status, + }; + } + + tailOutput(taskId: string, maxBytes: number, maxLines: number): string { + const path = this.outputPath(taskId); + if (!existsSync(path)) return ""; + + const buf = readFileSync(path); + const totalSize = buf.length; + const start = Math.max(0, totalSize - maxBytes); + const content = buf.subarray(start); + const text = content.toString("utf-8"); + let lines = text.split("\n"); + if (lines.length > maxLines) { + lines = lines.slice(-maxLines); + } + return lines.join("\n"); + } +} diff --git a/src/kimi_cli/background/summary.py b/src/kimi_cli/background/summary.py deleted file mode 100644 index 1c2802270..000000000 --- a/src/kimi_cli/background/summary.py +++ /dev/null @@ -1,66 +0,0 @@ -from __future__ import annotations - -from .manager import BackgroundTaskManager -from .models import TaskView, is_terminal_status - - -def list_task_views( - manager: BackgroundTaskManager, - *, - active_only: bool = True, - limit: int = 20, -) -> list[TaskView]: - views = manager.list_tasks(limit=None) - if active_only: - views = [view for view in views if not is_terminal_status(view.runtime.status)] - return views[:limit] - - -def format_task(view: TaskView, *, include_command: bool = False) -> str: - lines = [ - f"task_id: {view.spec.id}", - f"kind: {view.spec.kind}", - f"status: {view.runtime.status}", - f"description: {view.spec.description}", - ] - if view.spec.kind == "agent" and view.spec.kind_payload: - if agent_id := view.spec.kind_payload.get("agent_id"): - lines.append(f"agent_id: {agent_id}") - if subagent_type := view.spec.kind_payload.get("subagent_type"): - lines.append(f"subagent_type: {subagent_type}") - if include_command and view.spec.command: - lines.append(f"command: {view.spec.command}") - if view.runtime.exit_code is not None: - lines.append(f"exit_code: {view.runtime.exit_code}") - if view.runtime.failure_reason: - lines.append(f"reason: {view.runtime.failure_reason}") - return "\n".join(lines) - - -def format_task_list( - views: list[TaskView], - *, - active_only: bool = True, - include_command: bool = True, -) -> str: - header = "active_background_tasks" if active_only else "background_tasks" - if not views: - return f"{header}: 0\n[no tasks]" - - lines = [f"{header}: {len(views)}", ""] - for index, view in enumerate(views, start=1): - lines.extend([f"[{index}]", format_task(view, include_command=include_command), ""]) - return "\n".join(lines).rstrip() - - -def build_active_task_snapshot(manager: BackgroundTaskManager, *, limit: int = 20) -> str | None: - views = list_task_views(manager, active_only=True, limit=limit) - if not views: - return None - return "\n".join( - [ - "", - format_task_list(views, active_only=True, include_command=False), - "", - ] - ) diff --git a/src/kimi_cli/background/summary.ts b/src/kimi_cli/background/summary.ts new file mode 100644 index 000000000..525a5c209 --- /dev/null +++ b/src/kimi_cli/background/summary.ts @@ -0,0 +1,79 @@ +/** + * Background task summary/formatting — corresponds to Python background/summary.py + */ + +import type { TaskView } from "./models.ts"; +import { isTerminalStatus } from "./models.ts"; +import type { BackgroundTaskManager } from "./manager.ts"; + +export function listTaskViews( + manager: BackgroundTaskManager, + opts?: { activeOnly?: boolean; limit?: number }, +): TaskView[] { + const activeOnly = opts?.activeOnly ?? true; + const limit = opts?.limit ?? 20; + let views = manager.listTasks({ limit: undefined }); + if (activeOnly) { + views = views.filter((v) => !isTerminalStatus(v.runtime.status)); + } + return views.slice(0, limit); +} + +export function formatTask(view: TaskView, opts?: { includeCommand?: boolean }): string { + const lines = [ + `task_id: ${view.spec.id}`, + `kind: ${view.spec.kind}`, + `status: ${view.runtime.status}`, + `description: ${view.spec.description}`, + ]; + + if (view.spec.kind === "agent" && view.spec.kindPayload) { + const agentId = view.spec.kindPayload.agent_id; + if (agentId) lines.push(`agent_id: ${agentId}`); + const subagentType = view.spec.kindPayload.subagent_type; + if (subagentType) lines.push(`subagent_type: ${subagentType}`); + } + + if (opts?.includeCommand && view.spec.command) { + lines.push(`command: ${view.spec.command}`); + } + if (view.runtime.exitCode != null) { + lines.push(`exit_code: ${view.runtime.exitCode}`); + } + if (view.runtime.failureReason) { + lines.push(`reason: ${view.runtime.failureReason}`); + } + return lines.join("\n"); +} + +export function formatTaskList( + views: TaskView[], + opts?: { activeOnly?: boolean; includeCommand?: boolean }, +): string { + const activeOnly = opts?.activeOnly ?? true; + const includeCommand = opts?.includeCommand ?? true; + const header = activeOnly ? "active_background_tasks" : "background_tasks"; + + if (views.length === 0) { + return `${header}: 0\n[no tasks]`; + } + + const lines = [`${header}: ${views.length}`, ""]; + for (let i = 0; i < views.length; i++) { + lines.push(`[${i + 1}]`, formatTask(views[i]!, { includeCommand }), ""); + } + return lines.join("\n").trimEnd(); +} + +export function buildActiveTaskSnapshot( + manager: BackgroundTaskManager, + opts?: { limit?: number }, +): string | undefined { + const views = listTaskViews(manager, { activeOnly: true, limit: opts?.limit ?? 20 }); + if (views.length === 0) return undefined; + return [ + "", + formatTaskList(views, { activeOnly: true, includeCommand: false }), + "", + ].join("\n"); +} diff --git a/src/kimi_cli/background/worker.py b/src/kimi_cli/background/worker.py deleted file mode 100644 index 9ccb7ad06..000000000 --- a/src/kimi_cli/background/worker.py +++ /dev/null @@ -1,209 +0,0 @@ -from __future__ import annotations - -import asyncio -import contextlib -import os -import signal -import subprocess -import time -from pathlib import Path -from typing import Any - -from kimi_cli.utils.logging import logger -from kimi_cli.utils.subprocess_env import get_clean_env - -from .models import TaskControl -from .store import BackgroundTaskStore - - -def terminate_process_tree_windows(pid: int, *, force: bool) -> None: - args = ["taskkill", "/PID", str(pid), "/T"] - if force: - args.append("/F") - subprocess.run( - args, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - check=False, - ) - - -async def run_background_task_worker( - task_dir: Path, - *, - heartbeat_interval_ms: int = 5000, - control_poll_interval_ms: int = 500, - kill_grace_period_ms: int = 2000, -) -> None: - task_dir = task_dir.expanduser().resolve() - task_id = task_dir.name - store = BackgroundTaskStore(task_dir.parent) - spec = store.read_spec(task_id) - runtime = store.read_runtime(task_id) - - runtime.status = "starting" - runtime.worker_pid = os.getpid() - runtime.started_at = time.time() - runtime.heartbeat_at = runtime.started_at - runtime.updated_at = runtime.started_at - store.write_runtime(task_id, runtime) - - control = store.read_control(task_id) - if control.kill_requested_at is not None: - runtime.status = "killed" - runtime.interrupted = True - runtime.finished_at = time.time() - runtime.updated_at = runtime.finished_at - runtime.failure_reason = control.kill_reason or "Killed before command start" - store.write_runtime(task_id, runtime) - return - - if spec.command is None or spec.shell_path is None or spec.cwd is None: - runtime.status = "failed" - runtime.finished_at = time.time() - runtime.updated_at = runtime.finished_at - runtime.failure_reason = "Task spec is incomplete for bash worker" - store.write_runtime(task_id, runtime) - return - - process: asyncio.subprocess.Process | None = None - control_task: asyncio.Task[None] | None = None - heartbeat_task: asyncio.Task[None] | None = None - stop_event = asyncio.Event() - kill_sent_at: float | None = None - timed_out = False - timeout_reason: str | None = None - - async def _heartbeat_loop() -> None: - while not stop_event.is_set(): - await asyncio.sleep(heartbeat_interval_ms / 1000) - current = store.read_runtime(task_id) - if current.finished_at is not None: - return - current.heartbeat_at = time.time() - current.updated_at = current.heartbeat_at - store.write_runtime(task_id, current) - - async def _terminate_process(force: bool = False) -> None: - nonlocal kill_sent_at - if process is None or process.returncode is not None: - return - kill_sent_at = kill_sent_at or time.time() - - try: - if os.name == "nt": - terminate_process_tree_windows(process.pid, force=force) - return - - target_pgid = process.pid - if force: - os.killpg(target_pgid, signal.SIGKILL) - else: - os.killpg(target_pgid, signal.SIGTERM) - except ProcessLookupError: - pass - - async def _control_loop() -> None: - nonlocal kill_sent_at - while not stop_event.is_set(): - await asyncio.sleep(control_poll_interval_ms / 1000) - current_control: TaskControl = store.read_control(task_id) - if current_control.kill_requested_at is not None: - await _terminate_process(force=current_control.force) - if ( - kill_sent_at is not None - and process is not None - and process.returncode is None - and time.time() - kill_sent_at >= kill_grace_period_ms / 1000 - ): - await _terminate_process(force=True) - - try: - output_path = store.output_path(task_id) - with output_path.open("ab") as output_file: - spawn_kwargs: dict[str, Any] = { - "stdin": subprocess.DEVNULL, - "stdout": output_file, - "stderr": output_file, - "cwd": spec.cwd, - "env": get_clean_env(), - } - if os.name == "nt": - spawn_kwargs["creationflags"] = getattr(subprocess, "CREATE_NEW_PROCESS_GROUP", 0) - else: - spawn_kwargs["start_new_session"] = True - - args = ( - (spec.shell_path, "-command", spec.command) - if spec.shell_name == "Windows PowerShell" - else (spec.shell_path, "-c", spec.command) - ) - process = await asyncio.create_subprocess_exec(*args, **spawn_kwargs) - - runtime = store.read_runtime(task_id) - runtime.status = "running" - runtime.child_pid = process.pid - runtime.child_pgid = process.pid if os.name != "nt" else None - runtime.updated_at = time.time() - runtime.heartbeat_at = runtime.updated_at - store.write_runtime(task_id, runtime) - last_known_runtime = runtime - - heartbeat_task = asyncio.create_task(_heartbeat_loop()) - control_task = asyncio.create_task(_control_loop()) - if spec.timeout_s is None: - returncode = await process.wait() - else: - try: - returncode = await asyncio.wait_for(process.wait(), timeout=spec.timeout_s) - except TimeoutError: - timed_out = True - timeout_reason = f"Command timed out after {spec.timeout_s}s" - await _terminate_process(force=False) - try: - returncode = await asyncio.wait_for( - process.wait(), - timeout=kill_grace_period_ms / 1000, - ) - except TimeoutError: - await _terminate_process(force=True) - returncode = await process.wait() - except Exception as exc: - logger.exception("Background task worker failed") - runtime = store.read_runtime(task_id) - runtime.status = "failed" - runtime.finished_at = time.time() - runtime.updated_at = runtime.finished_at - runtime.failure_reason = str(exc) - store.write_runtime(task_id, runtime) - return - finally: - stop_event.set() - for task in (heartbeat_task, control_task): - if task is not None: - task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await task - - runtime = last_known_runtime.model_copy() - control = store.read_control(task_id) - runtime.finished_at = time.time() - runtime.updated_at = runtime.finished_at - runtime.exit_code = returncode - runtime.heartbeat_at = runtime.finished_at - if timed_out: - runtime.status = "failed" - runtime.interrupted = True - runtime.timed_out = True - runtime.failure_reason = timeout_reason - elif control.kill_requested_at is not None: - runtime.status = "killed" - runtime.interrupted = True - runtime.failure_reason = control.kill_reason or "Killed" - elif returncode == 0: - runtime.status = "completed" - runtime.failure_reason = None - else: - runtime.status = "failed" - runtime.failure_reason = f"Command failed with exit code {returncode}" - store.write_runtime(task_id, runtime) diff --git a/src/kimi_cli/background/worker.ts b/src/kimi_cli/background/worker.ts new file mode 100644 index 000000000..42693509b --- /dev/null +++ b/src/kimi_cli/background/worker.ts @@ -0,0 +1,193 @@ +/** + * Background task worker — corresponds to Python background/worker.py + * Runs a bash command in a subprocess with heartbeat and kill polling. + */ + +import { join } from "node:path"; +import { existsSync, openSync, closeSync } from "node:fs"; +import { BackgroundTaskStore } from "./store.ts"; +import { isTerminalStatus } from "./models.ts"; + +export async function runBackgroundTaskWorker( + taskDir: string, + opts?: { + heartbeatIntervalMs?: number; + controlPollIntervalMs?: number; + killGracePeriodMs?: number; + }, +): Promise { + const heartbeatIntervalMs = opts?.heartbeatIntervalMs ?? 5000; + const controlPollIntervalMs = opts?.controlPollIntervalMs ?? 500; + const killGracePeriodMs = opts?.killGracePeriodMs ?? 2000; + + const taskId = taskDir.split("/").pop()!; + const storeRoot = join(taskDir, ".."); + const store = new BackgroundTaskStore(storeRoot); + const spec = store.readSpec(taskId); + let runtime = store.readRuntime(taskId); + + const now = Date.now() / 1000; + runtime.status = "starting"; + runtime.workerPid = process.pid; + runtime.startedAt = now; + runtime.heartbeatAt = now; + runtime.updatedAt = now; + store.writeRuntime(taskId, runtime); + + // Check if already killed before launch + const control = store.readControl(taskId); + if (control.killRequestedAt != null) { + runtime.status = "killed"; + runtime.interrupted = true; + runtime.finishedAt = Date.now() / 1000; + runtime.updatedAt = runtime.finishedAt; + runtime.failureReason = control.killReason ?? "Killed before command start"; + store.writeRuntime(taskId, runtime); + return; + } + + if (!spec.command || !spec.shellPath || !spec.cwd) { + runtime.status = "failed"; + runtime.finishedAt = Date.now() / 1000; + runtime.updatedAt = runtime.finishedAt; + runtime.failureReason = "Task spec is incomplete for bash worker"; + store.writeRuntime(taskId, runtime); + return; + } + + let timedOut = false; + let timeoutReason: string | undefined; + let killSentAt: number | undefined; + + const outputPath = store.outputPath(taskId); + const outputFd = openSync(outputPath, "a"); + + let proc: ReturnType | undefined; + let heartbeatTimer: ReturnType | undefined; + let controlTimer: ReturnType | undefined; + + try { + const args = + spec.shellName === "Windows PowerShell" + ? [spec.shellPath, "-command", spec.command] + : [spec.shellPath, "-c", spec.command]; + + proc = Bun.spawn(args, { + stdin: "ignore", + stdout: outputFd, + stderr: outputFd, + cwd: spec.cwd, + }); + + runtime = store.readRuntime(taskId); + runtime.status = "running"; + runtime.childPid = proc.pid; + runtime.childPgid = proc.pid; + runtime.updatedAt = Date.now() / 1000; + runtime.heartbeatAt = runtime.updatedAt; + store.writeRuntime(taskId, runtime); + + // Heartbeat loop + heartbeatTimer = setInterval(() => { + try { + const current = store.readRuntime(taskId); + if (current.finishedAt != null) return; + current.heartbeatAt = Date.now() / 1000; + current.updatedAt = current.heartbeatAt; + store.writeRuntime(taskId, current); + } catch { + // Ignore + } + }, heartbeatIntervalMs); + + // Control poll loop + controlTimer = setInterval(() => { + try { + const ctrl = store.readControl(taskId); + if (ctrl.killRequestedAt != null && proc) { + try { + proc.kill(); + } catch { + // Process may be gone + } + if ( + killSentAt != null && + proc.exitCode == null && + Date.now() / 1000 - killSentAt >= killGracePeriodMs / 1000 + ) { + try { + proc.kill(9); // SIGKILL + } catch { + // Ignore + } + } + if (killSentAt == null) { + killSentAt = Date.now() / 1000; + } + } + } catch { + // Ignore + } + }, controlPollIntervalMs); + + // Wait for process with optional timeout + let exitCode: number; + if (spec.timeoutS != null) { + const timeoutPromise = new Promise<"timeout">((resolve) => + setTimeout(() => resolve("timeout"), spec.timeoutS! * 1000), + ); + const result = await Promise.race([proc.exited, timeoutPromise]); + if (result === "timeout") { + timedOut = true; + timeoutReason = `Command timed out after ${spec.timeoutS}s`; + proc.kill(); + try { + exitCode = await proc.exited; + } catch { + exitCode = -1; + } + } else { + exitCode = result; + } + } else { + exitCode = await proc.exited; + } + + // Write final runtime + const finalControl = store.readControl(taskId); + const finalRuntime = store.readRuntime(taskId); + finalRuntime.finishedAt = Date.now() / 1000; + finalRuntime.updatedAt = finalRuntime.finishedAt; + finalRuntime.exitCode = exitCode; + finalRuntime.heartbeatAt = finalRuntime.finishedAt; + + if (timedOut) { + finalRuntime.status = "failed"; + finalRuntime.interrupted = true; + finalRuntime.timedOut = true; + finalRuntime.failureReason = timeoutReason; + } else if (finalControl.killRequestedAt != null) { + finalRuntime.status = "killed"; + finalRuntime.interrupted = true; + finalRuntime.failureReason = finalControl.killReason ?? "Killed"; + } else if (exitCode === 0) { + finalRuntime.status = "completed"; + finalRuntime.failureReason = undefined; + } else { + finalRuntime.status = "failed"; + finalRuntime.failureReason = `Command failed with exit code ${exitCode}`; + } + store.writeRuntime(taskId, finalRuntime); + } catch (err) { + runtime = store.readRuntime(taskId); + runtime.status = "failed"; + runtime.finishedAt = Date.now() / 1000; + runtime.updatedAt = runtime.finishedAt; + runtime.failureReason = String(err); + store.writeRuntime(taskId, runtime); + } finally { + if (heartbeatTimer) clearInterval(heartbeatTimer); + if (controlTimer) clearInterval(controlTimer); + closeSync(outputFd); + } +} diff --git a/src/kimi_cli/cli/__init__.py b/src/kimi_cli/cli/__init__.py deleted file mode 100644 index 0fb044220..000000000 --- a/src/kimi_cli/cli/__init__.py +++ /dev/null @@ -1,969 +0,0 @@ -from __future__ import annotations - -from pathlib import Path -from typing import TYPE_CHECKING, Annotated, Literal - -import typer - -if TYPE_CHECKING: - from kimi_cli.session import Session - -from ._lazy_group import LazySubcommandGroup - - -class Reload(Exception): - """Reload configuration.""" - - def __init__(self, session_id: str | None = None): - super().__init__("reload") - self.session_id = session_id - self.source_session: Session | None = None - - -class SwitchToWeb(Exception): - """Switch to web interface.""" - - def __init__(self, session_id: str | None = None): - super().__init__("switch_to_web") - self.session_id = session_id - - -class SwitchToVis(Exception): - """Switch to vis (tracing visualizer) interface.""" - - def __init__(self, session_id: str | None = None): - super().__init__("switch_to_vis") - self.session_id = session_id - - -cli = typer.Typer( - cls=LazySubcommandGroup, - epilog="""\b\ -Documentation: https://moonshotai.github.io/kimi-cli/\n -LLM friendly version: https://moonshotai.github.io/kimi-cli/llms.txt""", - add_completion=False, - context_settings={"help_option_names": ["-h", "--help"]}, - help="Kimi, your next CLI agent.", -) - -UIMode = Literal["shell", "print", "acp", "wire"] - - -class ExitCode: - SUCCESS = 0 - FAILURE = 1 - RETRYABLE = 75 # EX_TEMPFAIL from sysexits.h - - -InputFormat = Literal["text", "stream-json"] -OutputFormat = Literal["text", "stream-json"] - - -def _version_callback(value: bool) -> None: - if value: - from kimi_cli.constant import get_version - - typer.echo(f"kimi, version {get_version()}") - raise typer.Exit() - - -@cli.callback(invoke_without_command=True) -def kimi( - ctx: typer.Context, - # Meta - version: Annotated[ - bool, - typer.Option( - "--version", - "-V", - help="Show version and exit.", - callback=_version_callback, - is_eager=True, - ), - ] = False, - verbose: Annotated[ - bool, - typer.Option( - "--verbose", - help="Print verbose information. Default: no.", - ), - ] = False, - debug: Annotated[ - bool, - typer.Option( - "--debug", - help="Log debug information. Default: no.", - ), - ] = False, - # Basic configuration - local_work_dir: Annotated[ - Path | None, - typer.Option( - "--work-dir", - "-w", - exists=True, - file_okay=False, - dir_okay=True, - readable=True, - writable=True, - help="Working directory for the agent. Default: current directory.", - ), - ] = None, - local_add_dirs: Annotated[ - list[Path] | None, - typer.Option( - "--add-dir", - exists=True, - file_okay=False, - dir_okay=True, - readable=True, - help=( - "Add an additional directory to the workspace scope. " - "Can be specified multiple times." - ), - ), - ] = None, - session_id: Annotated[ - str | None, - typer.Option( - "--session", - "-S", - help="Session ID to resume for the working directory. Default: new session.", - ), - ] = None, - continue_: Annotated[ - bool, - typer.Option( - "--continue", - "-C", - help="Continue the previous session for the working directory. Default: no.", - ), - ] = False, - config_string: Annotated[ - str | None, - typer.Option( - "--config", - help="Config TOML/JSON string to load. Default: none.", - ), - ] = None, - config_file: Annotated[ - Path | None, - typer.Option( - "--config-file", - exists=True, - file_okay=True, - dir_okay=False, - readable=True, - help="Config TOML/JSON file to load. Default: ~/.kimi/config.toml.", - ), - ] = None, - model_name: Annotated[ - str | None, - typer.Option( - "--model", - "-m", - help="LLM model to use. Default: default model set in config file.", - ), - ] = None, - thinking: Annotated[ - bool | None, - typer.Option( - "--thinking/--no-thinking", - help="Enable thinking mode. Default: default thinking mode set in config file.", - ), - ] = None, - # Run mode - yolo: Annotated[ - bool, - typer.Option( - "--yolo", - "--yes", - "-y", - "--auto-approve", - help="Automatically approve all actions. Default: no.", - ), - ] = False, - prompt: Annotated[ - str | None, - typer.Option( - "--prompt", - "-p", - "--command", - "-c", - help="User prompt to the agent. Default: prompt interactively.", - ), - ] = None, - print_mode: Annotated[ - bool, - typer.Option( - "--print", - help=( - "Run in print mode (non-interactive). Note: print mode implicitly adds `--yolo`." - ), - ), - ] = False, - acp_mode: Annotated[ - bool, - typer.Option( - "--acp", - help="(Deprecated, use `kimi acp` instead) Run as ACP server.", - ), - ] = False, - wire_mode: Annotated[ - bool, - typer.Option( - "--wire", - help="Run as Wire server (experimental).", - ), - ] = False, - input_format: Annotated[ - InputFormat | None, - typer.Option( - "--input-format", - help=( - "Input format to use. Must be used with `--print` " - "and the input must be piped in via stdin. " - "Default: text." - ), - ), - ] = None, - output_format: Annotated[ - OutputFormat | None, - typer.Option( - "--output-format", - help="Output format to use. Must be used with `--print`. Default: text.", - ), - ] = None, - final_message_only: Annotated[ - bool, - typer.Option( - "--final-message-only", - help="Only print the final assistant message (print UI).", - ), - ] = False, - quiet: Annotated[ - bool, - typer.Option( - "--quiet", - help="Alias for `--print --output-format text --final-message-only`.", - ), - ] = False, - # Customization - agent: Annotated[ - Literal["default", "okabe"] | None, - typer.Option( - "--agent", - help="Builtin agent specification to use. Default: builtin default agent.", - ), - ] = None, - agent_file: Annotated[ - Path | None, - typer.Option( - "--agent-file", - exists=True, - file_okay=True, - dir_okay=False, - readable=True, - help="Custom agent specification file. Default: builtin default agent.", - ), - ] = None, - mcp_config_file: Annotated[ - list[Path] | None, - typer.Option( - "--mcp-config-file", - exists=True, - file_okay=True, - dir_okay=False, - readable=True, - help=( - "MCP config file to load. Add this option multiple times to specify multiple MCP " - "configs. Default: none." - ), - ), - ] = None, - mcp_config: Annotated[ - list[str] | None, - typer.Option( - "--mcp-config", - help=( - "MCP config JSON to load. Add this option multiple times to specify multiple MCP " - "configs. Default: none." - ), - ), - ] = None, - local_skills_dir: Annotated[ - list[Path] | None, - typer.Option( - "--skills-dir", - exists=True, - file_okay=False, - dir_okay=True, - readable=True, - help="Custom skills directories (repeatable). Overrides default discovery.", - ), - ] = None, - # Loop control - max_steps_per_turn: Annotated[ - int | None, - typer.Option( - "--max-steps-per-turn", - min=1, - help="Maximum number of steps in one turn. Default: from config.", - ), - ] = None, - max_retries_per_step: Annotated[ - int | None, - typer.Option( - "--max-retries-per-step", - min=1, - help="Maximum number of retries in one step. Default: from config.", - ), - ] = None, - max_ralph_iterations: Annotated[ - int | None, - typer.Option( - "--max-ralph-iterations", - min=-1, - help=( - "Extra iterations after the first turn in Ralph mode. Use -1 for unlimited. " - "Default: from config." - ), - ), - ] = None, -): - """Kimi, your next CLI agent.""" - import asyncio - import contextlib - import json - - from kimi_cli.utils.proctitle import init_process_name - - init_process_name("Kimi Code") - - if ctx.invoked_subcommand is not None: - return # skip rest if a subcommand is invoked - - del version # handled in the callback - - from kaos.path import KaosPath - - from kimi_cli.agentspec import DEFAULT_AGENT_FILE, OKABE_AGENT_FILE - from kimi_cli.app import KimiCLI, enable_logging - from kimi_cli.config import Config, load_config_from_string - from kimi_cli.exception import ConfigError - from kimi_cli.hooks import events as hook_events - from kimi_cli.metadata import load_metadata, save_metadata - from kimi_cli.session import Session - from kimi_cli.ui.shell.startup import ShellStartupProgress - from kimi_cli.utils.logging import logger, open_original_stderr, redirect_stderr_to_logger - - from .mcp import get_global_mcp_config_file - - # Don't redirect stderr during argument parsing. Our stderr redirector - # replaces fd=2 with a pipe, which would swallow Click/Typer startup errors. - # Redirection is installed later, right before KimiCLI.create(), so that - # MCP server stderr noise is captured into logs from the start. - enable_logging(debug, redirect_stderr=False) - - def _emit_fatal_error(message: str) -> None: - # Prefer writing to the original stderr fd even if we later redirect fd=2. - # This ensures fatal errors are visible to the user. - with open_original_stderr() as stream: - if stream is not None: - stream.write((message.rstrip() + "\n").encode("utf-8", errors="replace")) - stream.flush() - return - typer.echo(message, err=True) - - if session_id is not None: - session_id = session_id.strip() - if not session_id: - raise typer.BadParameter("Session ID cannot be empty", param_hint="--session") - - if quiet: - if acp_mode or wire_mode: - raise typer.BadParameter( - "Quiet mode cannot be combined with ACP or Wire UI", - param_hint="--quiet", - ) - if output_format not in (None, "text"): - raise typer.BadParameter( - "Quiet mode implies `--output-format text`", - param_hint="--quiet", - ) - print_mode = True - output_format = "text" - final_message_only = True - - conflict_option_sets = [ - { - "--print": print_mode, - "--acp": acp_mode, - "--wire": wire_mode, - }, - { - "--agent": agent is not None, - "--agent-file": agent_file is not None, - }, - { - "--continue": continue_, - "--session": session_id is not None, - }, - { - "--config": config_string is not None, - "--config-file": config_file is not None, - }, - ] - for option_set in conflict_option_sets: - active_options = [flag for flag, active in option_set.items() if active] - if len(active_options) > 1: - raise typer.BadParameter( - f"Cannot combine {', '.join(active_options)}.", - param_hint=active_options[0], - ) - - if agent is not None: - match agent: - case "default": - agent_file = DEFAULT_AGENT_FILE - case "okabe": - agent_file = OKABE_AGENT_FILE - - ui: UIMode = "shell" - if print_mode: - ui = "print" - elif acp_mode: - ui = "acp" - elif wire_mode: - ui = "wire" - - if prompt is not None: - prompt = prompt.strip() - if not prompt: - raise typer.BadParameter("Prompt cannot be empty", param_hint="--prompt") - - if input_format is not None and ui != "print": - raise typer.BadParameter( - "Input format is only supported for print UI", - param_hint="--input-format", - ) - if output_format is not None and ui != "print": - raise typer.BadParameter( - "Output format is only supported for print UI", - param_hint="--output-format", - ) - if final_message_only and ui != "print": - raise typer.BadParameter( - "Final-message-only output is only supported for print UI", - param_hint="--final-message-only", - ) - - config: Config | Path | None = None - if config_string is not None: - config_string = config_string.strip() - if not config_string: - raise typer.BadParameter("Config cannot be empty", param_hint="--config") - try: - config = load_config_from_string(config_string) - except ConfigError as e: - raise typer.BadParameter(str(e), param_hint="--config") from e - elif config_file is not None: - config = config_file - - file_configs = list(mcp_config_file or []) - raw_mcp_config = list(mcp_config or []) - - # Use default MCP config file if no MCP config is provided - if not file_configs: - default_mcp_file = get_global_mcp_config_file() - if default_mcp_file.exists(): - file_configs.append(default_mcp_file) - - try: - mcp_configs = [json.loads(conf.read_text(encoding="utf-8")) for conf in file_configs] - except json.JSONDecodeError as e: - raise typer.BadParameter(f"Invalid JSON: {e}", param_hint="--mcp-config-file") from e - - try: - mcp_configs += [json.loads(conf) for conf in raw_mcp_config] - except json.JSONDecodeError as e: - raise typer.BadParameter(f"Invalid JSON: {e}", param_hint="--mcp-config") from e - - skills_dirs: list[KaosPath] | None = None - if local_skills_dir: - skills_dirs = [KaosPath.unsafe_from_local_path(p) for p in local_skills_dir] - - work_dir = KaosPath.unsafe_from_local_path(local_work_dir) if local_work_dir else KaosPath.cwd() - - # Tracks the most recently created/loaded session so that _reload_loop's - # exception handler can clean it up even when _run() fails before returning. - _latest_created_session: Session | None = None - - async def _run(session_id: str | None) -> tuple[Session, int]: - """ - Create/load session and run the CLI instance. - - Returns: - The session and the exit code (0 = success, 1 = failure, 75 = retryable). - """ - startup_progress = ShellStartupProgress(enabled=ui == "shell") - try: - startup_progress.update("Preparing session...") - - if session_id is not None: - session = await Session.find(work_dir, session_id) - if session is None: - logger.info( - "Session {session_id} not found, creating new session", - session_id=session_id, - ) - session = await Session.create(work_dir, session_id) - logger.info("Switching to session: {session_id}", session_id=session.id) - elif continue_: - session = await Session.continue_(work_dir) - if session is None: - raise typer.BadParameter( - "No previous session found for the working directory", - param_hint="--continue", - ) - logger.info("Continuing previous session: {session_id}", session_id=session.id) - else: - session = await Session.create(work_dir) - logger.info("Created new session: {session_id}", session_id=session.id) - - nonlocal _latest_created_session - _latest_created_session = session - - # Add CLI-provided additional directories to session state - if local_add_dirs: - from kimi_cli.utils.path import is_within_directory - - canonical_work_dir = work_dir.canonical() - changed = False - for d in local_add_dirs: - dir_path = KaosPath.unsafe_from_local_path(d).canonical() - dir_str = str(dir_path) - # Skip dirs within work_dir (already accessible) - if is_within_directory(dir_path, canonical_work_dir): - logger.info( - "Skipping --add-dir {dir}: already within working directory", - dir=dir_str, - ) - continue - if dir_str not in session.state.additional_dirs: - session.state.additional_dirs.append(dir_str) - changed = True - if changed: - session.save_state() - - # Redirect stderr *before* KimiCLI.create() so that MCP server - # subprocesses (e.g. mcp-remote OAuth debug logs) write to the log - # file instead of polluting the user's terminal. CLI argument - # parsing has already succeeded at this point, so Typer/Click - # startup errors are no longer a concern. Fatal errors from - # create() are still visible because _emit_fatal_error() writes to - # the saved original stderr fd. - redirect_stderr_to_logger() - - instance = await KimiCLI.create( - session, - config=config, - model_name=model_name, - thinking=thinking, - yolo=yolo or (ui == "print"), # print mode implies yolo - agent_file=agent_file, - mcp_configs=mcp_configs, - skills_dirs=skills_dirs, - max_steps_per_turn=max_steps_per_turn, - max_retries_per_step=max_retries_per_step, - max_ralph_iterations=max_ralph_iterations, - startup_progress=startup_progress.update if ui == "shell" else None, - defer_mcp_loading=ui == "shell" and prompt is None, - ) - startup_progress.stop() - - # --- SessionStart hook --- - _session_source = "resume" if continue_ else "startup" - await instance.soul.hook_engine.trigger( - "SessionStart", - matcher_value=_session_source, - input_data=hook_events.session_start( - session_id=session.id, - cwd=str(work_dir), - source=_session_source, - ), - ) - - # Install stderr redirection only after initialization succeeded, so runtime - # stderr noise is captured into logs without hiding startup failures. - redirect_stderr_to_logger() - preserve_background_tasks = False - try: - match ui: - case "shell": - shell_ok = await instance.run_shell(prompt) - exit_code = ExitCode.SUCCESS if shell_ok else ExitCode.FAILURE - case "print": - exit_code = await instance.run_print( - input_format or "text", - output_format or "text", - prompt, - final_only=final_message_only, - ) - case "acp": - if prompt is not None: - logger.warning("ACP server ignores prompt argument") - await instance.run_acp() - exit_code = ExitCode.SUCCESS - case "wire": - if prompt is not None: - logger.warning("Wire server ignores prompt argument") - await instance.run_wire_stdio() - exit_code = ExitCode.SUCCESS - except Reload as e: - preserve_background_tasks = True - if e.session_id is None: - r = Reload(session_id=session.id) - r.source_session = session - raise r from e - e.source_session = session - raise - except SwitchToWeb: - preserve_background_tasks = True - raise - except SwitchToVis: - preserve_background_tasks = True - raise - finally: - # --- SessionEnd hook --- - with contextlib.suppress(Exception): - await asyncio.wait_for( - instance.soul.hook_engine.trigger( - "SessionEnd", - matcher_value="exit", - input_data=hook_events.session_end( - session_id=session.id, - cwd=str(work_dir), - reason="exit", - ), - ), - timeout=5, - ) - - if not preserve_background_tasks: - instance.shutdown_background_tasks() - - return session, exit_code - finally: - startup_progress.stop() - - async def _delete_empty_session(session: Session) -> None: - """Delete an empty session directory and clear last_session_id if it pointed to it.""" - logger.info( - "Session {session_id} has empty context, removing it", - session_id=session.id, - ) - await session.delete() - meta = load_metadata() - wdm = meta.get_work_dir_meta(session.work_dir) - if wdm is not None and wdm.last_session_id == session.id: - wdm.last_session_id = None - save_metadata(meta) - - async def _post_run(last_session: Session, exit_code: int) -> None: - if last_session.is_empty(): - # Always clean up empty sessions regardless of exit code - await _delete_empty_session(last_session) - elif exit_code == ExitCode.SUCCESS: - metadata = load_metadata() - work_dir_meta = metadata.get_work_dir_meta(last_session.work_dir) - if work_dir_meta is None: - logger.warning( - "Work dir metadata missing when marking last session, recreating: {work_dir}", - work_dir=last_session.work_dir, - ) - work_dir_meta = metadata.new_work_dir_meta(last_session.work_dir) - work_dir_meta.last_session_id = last_session.id - save_metadata(metadata) - - async def _reload_loop(session_id: str | None) -> tuple[str | None, int]: - """Run the main loop, handling Reload/SwitchToWeb/SwitchToVis. - - Returns: - (switch_target, exit_code) where switch_target is "web", "vis", - or None if the session ended normally. - """ - last_session: Session | None = None - try: - while True: - try: - last_session, exit_code = await _run(session_id) - break - except Reload as e: - # Clean up old empty session when switching to a different session - old = e.source_session - if old is not None and old.id != e.session_id and old.is_empty(): - await _delete_empty_session(old) - last_session = None - else: - last_session = e.source_session - session_id = e.session_id - continue - except SwitchToWeb as e: - if e.session_id is not None: - session = await Session.find(work_dir, e.session_id) - if session is not None: - await _post_run(session, ExitCode.SUCCESS) - return "web", ExitCode.SUCCESS - except SwitchToVis as e: - if e.session_id is not None: - session = await Session.find(work_dir, e.session_id) - if session is not None: - await _post_run(session, ExitCode.SUCCESS) - return "vis", ExitCode.SUCCESS - assert last_session is not None - await _post_run(last_session, exit_code) - return None, exit_code - except (SwitchToWeb, SwitchToVis): - # Currently handled inside the loop (return), but re-raise explicitly - # so the generic except below never treats them as unexpected errors. - raise - except Exception: - # Best-effort cleanup: _latest_created_session is the session from - # the most recent _run() call, which may have failed before returning. - # last_session is from a *previous* iteration and must not be touched. - if _latest_created_session is not None and _latest_created_session.is_empty(): - with contextlib.suppress(Exception): - await _delete_empty_session(_latest_created_session) - raise - - try: - switch_target, exit_code = asyncio.run(_reload_loop(session_id)) - except (typer.BadParameter, typer.Exit): - # Let Typer/Click format these errors (rich panel + correct exit code). - raise - except Exception as exc: - import click - - if isinstance(exc, click.ClickException): - # ClickException includes the errors Typer knows how to render; don't - # wrap them, or we'd lose the standard error UI and exit codes. - raise - logger.exception("Fatal error when running CLI") - if debug: - import traceback - - # In debug mode, show full traceback for quick diagnosis. - _emit_fatal_error(traceback.format_exc()) - else: - from kimi_cli.share import get_share_dir - - log_path = get_share_dir() / "logs" / "kimi.log" - # In non-debug mode, print a concise error and point users to logs. - _emit_fatal_error(f"{exc}\nSee logs: {log_path}") - raise typer.Exit(code=1) from exc - if switch_target in ("web", "vis"): - from kimi_cli.utils.logging import restore_stderr - - restore_stderr() - - # Restore default SIGINT handler and terminal state after the shell's - # asyncio.run() to ensure Ctrl+C works in the uvicorn web server. - import signal - - signal.signal(signal.SIGINT, signal.default_int_handler) - - from kimi_cli.utils.term import ensure_tty_sane - - ensure_tty_sane() - - if switch_target == "web": - from kimi_cli.web.app import run_web_server - - run_web_server(open_browser=True) - else: - from kimi_cli.vis.app import run_vis_server - - run_vis_server(open_browser=True) - elif exit_code != ExitCode.SUCCESS: - raise typer.Exit(code=exit_code) - - -@cli.command() -def login( - json: bool = typer.Option( - False, - "--json", - help="Emit OAuth events as JSON lines.", - ), -) -> None: - """Login to your Kimi account.""" - import asyncio - - from rich.console import Console - from rich.status import Status - - from kimi_cli.auth.oauth import login_kimi_code - from kimi_cli.config import load_config - - async def _run() -> bool: - if json: - ok = True - async for event in login_kimi_code(load_config()): - typer.echo(event.json) - if event.type == "error": - ok = False - return ok - - console = Console() - ok = True - status: Status | None = None - try: - async for event in login_kimi_code(load_config()): - if event.type == "waiting": - if status is None: - status = console.status("Waiting for user authorization...") - status.start() - continue - if status is not None: - status.stop() - status = None - match event.type: - case "error": - style = "red" - case "success": - style = "green" - case _: - style = None - console.print(event.message, markup=False, style=style) - if event.type == "error": - ok = False - finally: - if status is not None: - status.stop() - return ok - - ok = asyncio.run(_run()) - if not ok: - raise typer.Exit(code=1) - - -@cli.command() -def logout( - json: bool = typer.Option( - False, - "--json", - help="Emit OAuth events as JSON lines.", - ), -) -> None: - """Logout from your Kimi account.""" - import asyncio - - from rich.console import Console - - from kimi_cli.auth.oauth import logout_kimi_code - from kimi_cli.config import load_config - - async def _run() -> bool: - ok = True - if json: - async for event in logout_kimi_code(load_config()): - typer.echo(event.json) - if event.type == "error": - ok = False - return ok - - console = Console() - async for event in logout_kimi_code(load_config()): - match event.type: - case "error": - style = "red" - case "success": - style = "green" - case _: - style = None - console.print(event.message, markup=False, style=style) - if event.type == "error": - ok = False - return ok - - ok = asyncio.run(_run()) - if not ok: - raise typer.Exit(code=1) - - -@cli.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}) -def term( - ctx: typer.Context, -) -> None: - """Run Toad TUI backed by Kimi Code CLI ACP server.""" - from .toad import run_term - - run_term(ctx) - - -@cli.command() -def acp(): - """Run Kimi Code CLI ACP server.""" - from kimi_cli.acp import acp_main - - acp_main() - - -@cli.command(name="__background-task-worker", hidden=True) -def background_task_worker( - task_dir: Annotated[Path, typer.Option("--task-dir")], - heartbeat_interval_ms: Annotated[int, typer.Option("--heartbeat-interval-ms")] = 5000, - control_poll_interval_ms: Annotated[int, typer.Option("--control-poll-interval-ms")] = 500, - kill_grace_period_ms: Annotated[int, typer.Option("--kill-grace-period-ms")] = 2000, -) -> None: - """Run background task worker subprocess (internal).""" - import asyncio - - from kimi_cli.background import run_background_task_worker - from kimi_cli.utils.proctitle import set_process_title - - set_process_title("kimi-code-bg-worker") - - from kimi_cli.app import enable_logging - - enable_logging(debug=False) - asyncio.run( - run_background_task_worker( - task_dir, - heartbeat_interval_ms=heartbeat_interval_ms, - control_poll_interval_ms=control_poll_interval_ms, - kill_grace_period_ms=kill_grace_period_ms, - ) - ) - - -@cli.command(name="__web-worker", hidden=True) -def web_worker(session_id: str) -> None: - """Run web worker subprocess (internal).""" - import asyncio - from uuid import UUID - - from kimi_cli.utils.proctitle import set_process_title - - set_process_title("kimi-code-worker") - - from kimi_cli.app import enable_logging - from kimi_cli.web.runner.worker import run_worker - - try: - parsed_session_id = UUID(session_id) - except ValueError as exc: - raise typer.BadParameter(f"Invalid session ID: {session_id}") from exc - - enable_logging(debug=False) - asyncio.run(run_worker(parsed_session_id)) - - -if __name__ == "__main__": - import sys - - if "kimi_cli.cli" not in sys.modules: - sys.modules["kimi_cli.cli"] = sys.modules[__name__] - - sys.exit(cli()) diff --git a/src/kimi_cli/cli/__main__.py b/src/kimi_cli/cli/__main__.py deleted file mode 100644 index 3ad011483..000000000 --- a/src/kimi_cli/cli/__main__.py +++ /dev/null @@ -1,11 +0,0 @@ -from __future__ import annotations - -import sys - -from kimi_cli.cli import cli - -if __name__ == "__main__": - from kimi_cli.utils.proxy import normalize_proxy_env - - normalize_proxy_env() - sys.exit(cli()) diff --git a/src/kimi_cli/cli/_lazy_group.py b/src/kimi_cli/cli/_lazy_group.py deleted file mode 100644 index f3fa701b9..000000000 --- a/src/kimi_cli/cli/_lazy_group.py +++ /dev/null @@ -1,222 +0,0 @@ -# pyright: reportAttributeAccessIssue=false, reportMissingParameterType=false, reportPrivateImportUsage=false, reportPrivateUsage=false, reportUnknownArgumentType=false, reportUnknownMemberType=false, reportUnknownParameterType=false, reportUnknownVariableType=false, reportUntypedBaseClass=false -from __future__ import annotations - -from importlib import import_module -from typing import Any, cast - -import click -import typer -from click.core import HelpFormatter -from typer.main import get_command - - -class LazySubcommandGroup(typer.core.TyperGroup): - """Load heavyweight subcommands only when they are actually invoked.""" - - lazy_subcommands: dict[str, tuple[str, str, str]] = { - "info": ("kimi_cli.cli.info", "cli", "Show version and protocol information."), - "export": ("kimi_cli.cli.export", "cli", "Export session data."), - "mcp": ("kimi_cli.cli.mcp", "cli", "Manage MCP server configurations."), - "plugin": ("kimi_cli.cli.plugin", "cli", "Manage plugins."), - "vis": ("kimi_cli.cli.vis", "cli", "Run Kimi Agent Tracing Visualizer."), - "web": ("kimi_cli.cli.web", "cli", "Run Kimi Code CLI web interface."), - } - lazy_command_order: tuple[str, ...] = ( - "info", - "export", - "mcp", - "plugin", - "vis", - "web", - ) - - def list_commands(self, ctx: click.Context) -> list[str]: - commands = list(super().list_commands(ctx)) - for name in self.lazy_command_order: - if name not in commands: - commands.append(name) - return commands - - def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None: - command = super().get_command(ctx, cmd_name) - if command is not None: - return command - - lazy_spec = self.lazy_subcommands.get(cmd_name) - if lazy_spec is None: - return None - - module_name, attribute_name, _ = lazy_spec - command = get_command(getattr(import_module(module_name), attribute_name)) - command.name = cmd_name - self.commands[cmd_name] = command - return command - - def format_help(self, ctx: click.Context, formatter: HelpFormatter) -> None: - if not typer.core.HAS_RICH or self.rich_markup_mode is None: - return super().format_help(ctx, formatter) - - from typer import rich_utils - - rich_utils_any = cast(Any, rich_utils) - console = rich_utils_any._get_rich_console() - console.print( - rich_utils_any.Padding( - rich_utils_any.highlighter(self.get_usage(ctx)), - 1, - ), - style=rich_utils_any.STYLE_USAGE_COMMAND, - ) - - if self.help: - console.print( - rich_utils_any.Padding( - rich_utils_any.Align( - rich_utils_any._get_help_text( - obj=self, - markup_mode=self.rich_markup_mode, - ), - pad=False, - ), - (0, 1, 1, 1), - ) - ) - - panel_to_arguments: dict[str, list[click.Argument]] = {} - panel_to_options: dict[str, list[click.Option]] = {} - for param in self.get_params(ctx): - if getattr(param, "hidden", False): - continue - if isinstance(param, click.Argument): - panel_name = ( - getattr(param, rich_utils_any._RICH_HELP_PANEL_NAME, None) - or rich_utils_any.ARGUMENTS_PANEL_TITLE - ) - panel_to_arguments.setdefault(panel_name, []).append(param) - elif isinstance(param, click.Option): - panel_name = ( - getattr(param, rich_utils_any._RICH_HELP_PANEL_NAME, None) - or rich_utils_any.OPTIONS_PANEL_TITLE - ) - panel_to_options.setdefault(panel_name, []).append(param) - - default_arguments = panel_to_arguments.get(rich_utils_any.ARGUMENTS_PANEL_TITLE, []) - rich_utils_any._print_options_panel( - name=rich_utils_any.ARGUMENTS_PANEL_TITLE, - params=default_arguments, - ctx=ctx, - markup_mode=self.rich_markup_mode, - console=console, - ) - for panel_name, arguments in panel_to_arguments.items(): - if panel_name == rich_utils_any.ARGUMENTS_PANEL_TITLE: - continue - rich_utils_any._print_options_panel( - name=panel_name, - params=arguments, - ctx=ctx, - markup_mode=self.rich_markup_mode, - console=console, - ) - - default_options = panel_to_options.get(rich_utils_any.OPTIONS_PANEL_TITLE, []) - rich_utils_any._print_options_panel( - name=rich_utils_any.OPTIONS_PANEL_TITLE, - params=default_options, - ctx=ctx, - markup_mode=self.rich_markup_mode, - console=console, - ) - for panel_name, options in panel_to_options.items(): - if panel_name == rich_utils_any.OPTIONS_PANEL_TITLE: - continue - rich_utils_any._print_options_panel( - name=panel_name, - params=options, - ctx=ctx, - markup_mode=self.rich_markup_mode, - console=console, - ) - - panel_to_commands: dict[str, list[click.Command]] = {} - for command_name in self.list_commands(ctx): - command = self.commands.get(command_name) - if command is None: - lazy_spec = self.lazy_subcommands.get(command_name) - if lazy_spec is None: - continue - command = click.Command(command_name, help=lazy_spec[2]) - if command.hidden: - continue - panel_name = ( - getattr(command, rich_utils_any._RICH_HELP_PANEL_NAME, None) - or rich_utils_any.COMMANDS_PANEL_TITLE - ) - panel_to_commands.setdefault(panel_name, []).append(command) - - max_cmd_len = max( - ( - len(command.name or "") - for commands in panel_to_commands.values() - for command in commands - ), - default=0, - ) - default_commands = panel_to_commands.get(rich_utils_any.COMMANDS_PANEL_TITLE, []) - rich_utils_any._print_commands_panel( - name=rich_utils_any.COMMANDS_PANEL_TITLE, - commands=default_commands, - markup_mode=self.rich_markup_mode, - console=console, - cmd_len=max_cmd_len, - ) - for panel_name, commands in panel_to_commands.items(): - if panel_name == rich_utils_any.COMMANDS_PANEL_TITLE: - continue - rich_utils_any._print_commands_panel( - name=panel_name, - commands=commands, - markup_mode=self.rich_markup_mode, - console=console, - cmd_len=max_cmd_len, - ) - - if self.epilog: - lines = self.epilog.split("\n\n") - epilogue = "\n".join(x.replace("\n", " ").strip() for x in lines) - epilogue_text = rich_utils_any._make_rich_text( - text=epilogue, - markup_mode=self.rich_markup_mode, - ) - console.print(rich_utils_any.Padding(rich_utils_any.Align(epilogue_text, pad=False), 1)) - - def format_commands(self, ctx: click.Context, formatter: HelpFormatter) -> None: - entries: list[tuple[str, str | None]] = [] - for subcommand in self.list_commands(ctx): - command = self.commands.get(subcommand) - if command is not None: - if command.hidden: - continue - entries.append((subcommand, None)) - continue - - lazy_spec = self.lazy_subcommands.get(subcommand) - if lazy_spec is None: - continue - entries.append((subcommand, lazy_spec[2])) - - if not entries: - return - - limit = formatter.width - 6 - max(len(name) for name, _ in entries) - rows: list[tuple[str, str]] = [] - for subcommand, short_help in entries: - command = self.commands.get(subcommand) - if command is not None: - rows.append((subcommand, command.get_short_help_str(limit))) - continue - rows.append((subcommand, short_help or "")) - - if rows: - with formatter.section("Commands"): - formatter.write_dl(rows) diff --git a/src/kimi_cli/cli/export.py b/src/kimi_cli/cli/export.py deleted file mode 100644 index 2b7b37e99..000000000 --- a/src/kimi_cli/cli/export.py +++ /dev/null @@ -1,74 +0,0 @@ -"""Export command for packaging session data.""" - -from __future__ import annotations - -import io -import zipfile -from pathlib import Path -from typing import Annotated - -import typer - -cli = typer.Typer(help="Export session data.") - - -def _find_session_by_id(session_id: str) -> Path | None: - """Find a session directory by session ID across all work directories.""" - from kimi_cli.share import get_share_dir - - sessions_root = get_share_dir() / "sessions" - if not sessions_root.exists(): - return None - - for work_dir_hash_dir in sessions_root.iterdir(): - if not work_dir_hash_dir.is_dir(): - continue - candidate = work_dir_hash_dir / session_id - if candidate.is_dir(): - return candidate - - return None - - -@cli.callback(invoke_without_command=True) -def export( - session_id: Annotated[ - str, - typer.Argument(help="Session ID to export."), - ], - output: Annotated[ - Path | None, - typer.Option( - "--output", - "-o", - help="Output ZIP file path. Default: session-{id}.zip in current directory.", - ), - ] = None, -) -> None: - """Export a session as a ZIP archive.""" - session_dir = _find_session_by_id(session_id) - if session_dir is None: - typer.echo(f"Error: session '{session_id}' not found.", err=True) - raise typer.Exit(code=1) - - # Collect files - files = sorted(f for f in session_dir.iterdir() if f.is_file()) - if not files: - typer.echo(f"Error: session '{session_id}' has no files.", err=True) - raise typer.Exit(code=1) - - # Determine output path - if output is None: - output = Path.cwd() / f"session-{session_id}.zip" - - # Create ZIP - buf = io.BytesIO() - with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED) as zf: - for file_path in files: - zf.write(file_path, arcname=file_path.name) - buf.seek(0) - - output.parent.mkdir(parents=True, exist_ok=True) - output.write_bytes(buf.getvalue()) - - typer.echo(str(output)) diff --git a/src/kimi_cli/cli/export.ts b/src/kimi_cli/cli/export.ts new file mode 100644 index 000000000..8f25d84c1 --- /dev/null +++ b/src/kimi_cli/cli/export.ts @@ -0,0 +1,86 @@ +/** + * CLI export command — corresponds to Python cli/export.py + * Exports a session as a ZIP archive. + */ + +import { Command } from "commander"; +import { join, resolve } from "node:path"; +import { readdirSync, statSync, readFileSync, writeFileSync, mkdirSync, existsSync } from "node:fs"; + +function findSessionById(sessionId: string): string | null { + const { getShareDir } = require("../config.ts"); + const sessionsRoot = join(getShareDir(), "sessions"); + if (!existsSync(sessionsRoot)) return null; + + try { + for (const workDirHash of readdirSync(sessionsRoot)) { + const workDirHashDir = join(sessionsRoot, workDirHash); + try { + if (!statSync(workDirHashDir).isDirectory()) continue; + } catch { + continue; + } + const candidate = join(workDirHashDir, sessionId); + try { + if (statSync(candidate).isDirectory()) return candidate; + } catch { + continue; + } + } + } catch { + // ignore + } + return null; +} + +export const exportCommand = new Command("export") + .description("Export a session as a ZIP archive.") + .argument("", "Session ID to export.") + .option("-o, --output ", "Output ZIP file path. Default: session-{id}.zip in current directory.") + .action(async (sessionId: string, options: { output?: string }) => { + const sessionDir = findSessionById(sessionId); + if (!sessionDir) { + console.error(`Error: session '${sessionId}' not found.`); + process.exit(1); + } + + // Collect files + let files: string[]; + try { + files = readdirSync(sessionDir) + .filter((f) => { + try { + return statSync(join(sessionDir, f)).isFile(); + } catch { + return false; + } + }) + .sort(); + } catch { + files = []; + } + + if (files.length === 0) { + console.error(`Error: session '${sessionId}' has no files.`); + process.exit(1); + } + + // Determine output path + const outputPath = options.output + ? resolve(options.output) + : resolve(process.cwd(), `session-${sessionId}.zip`); + + // Use Bun's built-in zip capabilities or fall back to shell + try { + const outputDir = outputPath.substring(0, outputPath.lastIndexOf("/")); + mkdirSync(outputDir, { recursive: true }); + + // Use shell zip command + const fileArgs = files.map((f) => join(sessionDir, f)); + await Bun.$`zip -j ${outputPath} ${fileArgs}`.quiet(); + console.log(outputPath); + } catch (err) { + console.error(`Error creating ZIP: ${err}`); + process.exit(1); + } + }); diff --git a/src/kimi_cli/cli/index.ts b/src/kimi_cli/cli/index.ts new file mode 100644 index 000000000..69f3c085d --- /dev/null +++ b/src/kimi_cli/cli/index.ts @@ -0,0 +1,342 @@ +/** + * CLI router — corresponds to Python cli/__init__.py + * Uses Commander.js (replaces Typer) + */ + +import { Command } from "commander"; +import React from "react"; +import { render } from "ink"; +import { KimiCLI } from "../app.ts"; +import type { SoulCallbacks } from "../soul/kimisoul.ts"; +import { Shell } from "../ui/shell/Shell.tsx"; +import type { WireUIEvent } from "../ui/shell/events.ts"; +import chalk from "chalk"; + +// ── Re-exports from Python cli/__init__.py ────────────── + +export class Reload extends Error { + sessionId: string | null; + constructor(sessionId: string | null = null) { + super("reload"); + this.name = "Reload"; + this.sessionId = sessionId; + } +} + +export class SwitchToWeb extends Error { + sessionId: string | null; + constructor(sessionId: string | null = null) { + super("switch_to_web"); + this.name = "SwitchToWeb"; + this.sessionId = sessionId; + } +} + +export class SwitchToVis extends Error { + sessionId: string | null; + constructor(sessionId: string | null = null) { + super("switch_to_vis"); + this.name = "SwitchToVis"; + this.sessionId = sessionId; + } +} + +export type UIMode = "shell" | "print" | "acp" | "wire"; +export type InputFormat = "text" | "stream-json"; +export type OutputFormat = "text" | "stream-json"; + +export const ExitCode = { + SUCCESS: 0, + FAILURE: 1, + RETRYABLE: 75, // EX_TEMPFAIL from sysexits.h +} as const; + +// ── Subcommands ────────────────────────────────────────── + +import { loginCommand } from "./login.ts"; +import { logoutCommand } from "./logout.ts"; +import { infoCommand } from "./info.ts"; +import { exportCommand } from "./export.ts"; + +// ── Version callback ───────────────────────────────────── + +function getVersionString(): string { + try { + const { getVersion } = require("../constant.ts"); + return getVersion(); + } catch { + return "0.0.0"; + } +} + +// ── Program ────────────────────────────────────────────── + +const program = new Command() + .name("kimi") + .description("Kimi, your next CLI agent.") + .version(getVersionString(), "-V, --version") + .addCommand(loginCommand) + .addCommand(logoutCommand) + .addCommand(infoCommand) + .addCommand(exportCommand); + +// Main chat command (default) +program + .argument("[prompt...]", "Initial prompt to send") + .option("-m, --model ", "Model to use") + .option("--thinking", "Enable thinking mode") + .option("--no-thinking", "Disable thinking mode") + .option("--yolo", "Auto-approve all tool calls") + .option("--print", "Print mode (non-interactive)") + .option("-w, --work-dir ", "Working directory") + .option("--add-dir ", "Add additional directories to the workspace") + .option("--max-steps-per-turn ", "Max steps per turn", parseInt) + .option("--max-retries-per-step ", "Max retries per step", parseInt) + .option("--config-file ", "Config TOML/JSON file to load") + .option("--config ", "Config TOML/JSON string to load") + .option("--session ", "Resume session by ID") + .option("-C, --continue", "Continue the most recent session") + .option("--input-format ", "Input format (text, stream-json). Print mode only.") + .option("--output-format ", "Output format (text, stream-json). Print mode only.") + .option("--quiet", "Alias for --print --output-format text --final-message-only") + .option("--final-message-only", "Only print the final assistant message (print UI)") + .option("-p, --prompt ", "User prompt to the agent") + .option("--verbose", "Verbose output") + .option("--debug", "Debug mode") + .option("--wire", "Run as Wire server (experimental)") + .option("--agent ", "Builtin agent specification to use") + .option("--agent-file ", "Custom agent specification file") + .option("--mcp-config-file ", "MCP config file(s) to load") + .option("--mcp-config ", "MCP config JSON to load") + .action( + async ( + promptParts: string[], + options: { + model?: string; + thinking?: boolean; + yolo?: boolean; + print?: boolean; + workDir?: string; + addDir?: string[]; + maxStepsPerTurn?: number; + maxRetriesPerStep?: number; + configFile?: string; + config?: string; + session?: string; + continue?: boolean; + inputFormat?: string; + outputFormat?: string; + quiet?: boolean; + finalMessageOnly?: boolean; + prompt?: string; + verbose?: boolean; + debug?: boolean; + wire?: boolean; + agent?: string; + agentFile?: string; + mcpConfigFile?: string[]; + mcpConfig?: string[]; + }, + ) => { + // Handle --quiet alias + if (options.quiet) { + options.print = true; + options.outputFormat = "text"; + options.finalMessageOnly = true; + } + + // Resolve prompt from either positional args or --prompt option + const prompt = + promptParts.length > 0 + ? promptParts.join(" ") + : options.prompt ?? undefined; + + // Determine config source: --config-file takes precedence over legacy --config as path + const configFile = options.configFile ?? undefined; + + try { + if (options.print) { + // ── Print mode: callbacks write directly to stdout/stderr ── + const callbacks: SoulCallbacks = { + onTextDelta: (text) => process.stdout.write(text), + onThinkDelta: (text) => process.stderr.write(chalk.dim(text)), + onError: (err) => + process.stderr.write(chalk.red(`[ERROR] ${err.message}\n`)), + onTurnEnd: () => process.stdout.write("\n"), + onStatusUpdate: (status) => { + if (options.verbose && status.tokenUsage) { + process.stderr.write( + chalk.dim( + `[tokens] in=${status.tokenUsage.inputTokens} out=${status.tokenUsage.outputTokens}\n`, + ), + ); + } + }, + }; + + const app = await KimiCLI.create({ + workDir: options.workDir, + additionalDirs: options.addDir, + configFile, + modelName: options.model, + thinking: options.thinking, + yolo: options.yolo ?? true, // print mode implies yolo + sessionId: options.session, + continueSession: options.continue, + maxStepsPerTurn: options.maxStepsPerTurn ?? options.maxRetriesPerStep, + callbacks, + }); + + if (prompt) await app.runPrint(prompt); + await app.shutdown(); + } else if (options.wire) { + // ── Wire mode ── + const app = await KimiCLI.create({ + workDir: options.workDir, + additionalDirs: options.addDir, + configFile, + modelName: options.model, + thinking: options.thinking, + yolo: options.yolo, + sessionId: options.session, + continueSession: options.continue, + maxStepsPerTurn: options.maxStepsPerTurn ?? options.maxRetriesPerStep, + callbacks: {}, + }); + // Wire mode not yet implemented + console.error("Wire mode is not yet implemented."); + await app.shutdown(); + process.exit(1); + } else { + // ── Interactive mode: callbacks push wire events to React Ink UI ── + + // pushEvent will be set by Shell's onWireReady callback + let pushEvent: ((event: WireUIEvent) => void) | null = null; + + const callbacks: SoulCallbacks = { + onTurnBegin: (userInput) => { + const text = + typeof userInput === "string" + ? userInput + : "[complex input]"; + pushEvent?.({ type: "turn_begin", userInput: text }); + }, + onTurnEnd: () => { + pushEvent?.({ type: "turn_end" }); + }, + onStepBegin: (n) => { + pushEvent?.({ type: "step_begin", n }); + }, + onTextDelta: (text) => { + pushEvent?.({ type: "text_delta", text }); + }, + onThinkDelta: (text) => { + pushEvent?.({ type: "think_delta", text }); + }, + onToolCall: (tc) => { + pushEvent?.({ + type: "tool_call", + id: tc.id, + name: tc.name, + arguments: tc.arguments, + }); + }, + onToolResult: (toolCallId, result) => { + pushEvent?.({ + type: "tool_result", + toolCallId, + result: { + tool_call_id: toolCallId, + return_value: { + isError: result.isError, + output: result.output, + message: result.message, + }, + display: [], + }, + }); + }, + onStatusUpdate: (status) => { + pushEvent?.({ + type: "status_update", + status: { + context_usage: status.contextUsage ?? null, + context_tokens: status.contextTokens ?? null, + max_context_tokens: status.maxContextTokens ?? null, + token_usage: status.tokenUsage ?? null, + message_id: null, + plan_mode: status.planMode ?? null, + mcp_status: null, + }, + }); + }, + onCompactionBegin: () => { + pushEvent?.({ type: "compaction_begin" }); + }, + onCompactionEnd: () => { + pushEvent?.({ type: "compaction_end" }); + }, + onError: (err) => { + pushEvent?.({ type: "error", message: err.message }); + }, + onNotification: (title, body) => { + pushEvent?.({ type: "notification", title, body }); + }, + }; + + const app = await KimiCLI.create({ + workDir: options.workDir, + additionalDirs: options.addDir, + configFile, + modelName: options.model, + thinking: options.thinking, + yolo: options.yolo, + sessionId: options.session, + continueSession: options.continue, + maxStepsPerTurn: options.maxStepsPerTurn ?? options.maxRetriesPerStep, + callbacks, + }); + + const { waitUntilExit } = render( + React.createElement(Shell, { + modelName: app.soul.modelName, + workDir: options.workDir ?? process.cwd(), + sessionId: app.session.id, + thinking: app.soul.thinking, + onSubmit: (input: string) => { + app.soul.run(input); + }, + onInterrupt: () => { + app.soul.abort(); + }, + onWireReady: (push) => { + pushEvent = push; + }, + extraSlashCommands: app.soul.availableSlashCommands, + }), + ); + + // Run initial prompt if provided + if (prompt) { + app.soul.run(prompt); + } + + await waitUntilExit(); + await app.shutdown(); + } + } catch (err) { + console.error("Error:", err); + process.exit(ExitCode.FAILURE); + } + }, + ); + +export async function cli(argv: string[]): Promise { + try { + await program.parseAsync(argv); + return ExitCode.SUCCESS; + } catch (error) { + console.error("Fatal error:", error); + return ExitCode.FAILURE; + } +} diff --git a/src/kimi_cli/cli/info.py b/src/kimi_cli/cli/info.py deleted file mode 100644 index b2678b889..000000000 --- a/src/kimi_cli/cli/info.py +++ /dev/null @@ -1,62 +0,0 @@ -from __future__ import annotations - -import json -import platform -from typing import Annotated, TypedDict - -import typer - - -class InfoData(TypedDict): - kimi_cli_version: str - agent_spec_versions: list[str] - wire_protocol_version: str - python_version: str - - -def _collect_info() -> InfoData: - from kimi_cli.agentspec import SUPPORTED_AGENT_SPEC_VERSIONS - from kimi_cli.constant import get_version - from kimi_cli.wire.protocol import WIRE_PROTOCOL_VERSION - - return { - "kimi_cli_version": get_version(), - "agent_spec_versions": [str(version) for version in SUPPORTED_AGENT_SPEC_VERSIONS], - "wire_protocol_version": WIRE_PROTOCOL_VERSION, - "python_version": platform.python_version(), - } - - -def _emit_info(json_output: bool) -> None: - info = _collect_info() - if json_output: - typer.echo(json.dumps(info, ensure_ascii=False)) - return - - agent_versions_text = ", ".join(str(version) for version in info["agent_spec_versions"]) - - lines = [ - f"kimi-cli version: {info['kimi_cli_version']}", - f"agent spec versions: {agent_versions_text}", - f"wire protocol: {info['wire_protocol_version']}", - f"python version: {info['python_version']}", - ] - for line in lines: - typer.echo(line) - - -cli = typer.Typer(help="Show version and protocol information.") - - -@cli.callback(invoke_without_command=True) -def info( - json_output: Annotated[ - bool, - typer.Option( - "--json", - help="Output information as JSON.", - ), - ] = False, -): - """Show version and protocol information.""" - _emit_info(json_output) diff --git a/src/kimi_cli/cli/info.ts b/src/kimi_cli/cli/info.ts new file mode 100644 index 000000000..c4447168f --- /dev/null +++ b/src/kimi_cli/cli/info.ts @@ -0,0 +1,48 @@ +/** + * CLI info command — corresponds to Python cli/info.py + * Displays version, agent spec, wire protocol, and runtime info. + */ + +import { Command } from "commander"; + +interface InfoData { + kimi_cli_version: string; + wire_protocol_version: string; + runtime: string; + runtime_version: string; +} + +function collectInfo(): InfoData { + // Lazy imports to avoid circular dependencies at module load time + const { getVersion } = require("../constant.ts"); + let wireProtocolVersion = "unknown"; + try { + const wire = require("../wire/protocol.ts"); + wireProtocolVersion = wire.WIRE_PROTOCOL_VERSION ?? "unknown"; + } catch { + // wire module may not be available + } + + return { + kimi_cli_version: getVersion(), + wire_protocol_version: wireProtocolVersion, + runtime: "bun", + runtime_version: typeof Bun !== "undefined" ? Bun.version : process.version, + }; +} + +export const infoCommand = new Command("info") + .description("Show version and protocol information.") + .option("--json", "Output information as JSON.", false) + .action((options: { json?: boolean }) => { + const info = collectInfo(); + + if (options.json) { + console.log(JSON.stringify(info, null, 2)); + return; + } + + console.log(`kimi-cli version: ${info.kimi_cli_version}`); + console.log(`wire protocol: ${info.wire_protocol_version}`); + console.log(`runtime: ${info.runtime} ${info.runtime_version}`); + }); diff --git a/src/kimi_cli/cli/login.ts b/src/kimi_cli/cli/login.ts new file mode 100644 index 000000000..49bbb1178 --- /dev/null +++ b/src/kimi_cli/cli/login.ts @@ -0,0 +1,53 @@ +/** + * CLI login command — corresponds to Python cli login() command + * Device-code OAuth flow with console/JSON output. + */ + +import { Command } from "commander"; + +export const loginCommand = new Command("login") + .description("Login to your Kimi account.") + .option("--json", "Emit OAuth events as JSON lines.", false) + .action(async (options: { json?: boolean }) => { + const { loginKimiCode } = await import("../auth/oauth.ts"); + const { loadConfig } = await import("../config.ts"); + + const { config } = await loadConfig(); + let ok = true; + + if (options.json) { + for await (const event of loginKimiCode(config)) { + console.log( + JSON.stringify({ + type: event.type, + message: event.message, + ...(event.data ? { data: event.data } : {}), + }), + ); + if (event.type === "error") ok = false; + } + } else { + let waiting = false; + for await (const event of loginKimiCode(config)) { + if (event.type === "waiting") { + if (!waiting) { + process.stderr.write("Waiting for user authorization...\n"); + waiting = true; + } + continue; + } + waiting = false; + const color = + event.type === "error" + ? "\x1b[31m" + : event.type === "success" + ? "\x1b[32m" + : ""; + const reset = color ? "\x1b[0m" : ""; + console.log(`${color}${event.message}${reset}`); + if (event.type === "error") ok = false; + } + } + + if (!ok) process.exit(1); + }); diff --git a/src/kimi_cli/cli/logout.ts b/src/kimi_cli/cli/logout.ts new file mode 100644 index 000000000..0023aa888 --- /dev/null +++ b/src/kimi_cli/cli/logout.ts @@ -0,0 +1,44 @@ +/** + * CLI logout command — corresponds to Python cli logout() command + * Clears OAuth tokens and config. + */ + +import { Command } from "commander"; + +export const logoutCommand = new Command("logout") + .description("Logout from your Kimi account.") + .option("--json", "Emit OAuth events as JSON lines.", false) + .action(async (options: { json?: boolean }) => { + const { logoutKimiCode } = await import("../auth/oauth.ts"); + const { loadConfig } = await import("../config.ts"); + + const { config } = await loadConfig(); + let ok = true; + + if (options.json) { + for await (const event of logoutKimiCode(config)) { + console.log( + JSON.stringify({ + type: event.type, + message: event.message, + ...(event.data ? { data: event.data } : {}), + }), + ); + if (event.type === "error") ok = false; + } + } else { + for await (const event of logoutKimiCode(config)) { + const color = + event.type === "error" + ? "\x1b[31m" + : event.type === "success" + ? "\x1b[32m" + : ""; + const reset = color ? "\x1b[0m" : ""; + console.log(`${color}${event.message}${reset}`); + if (event.type === "error") ok = false; + } + } + + if (!ok) process.exit(1); + }); diff --git a/src/kimi_cli/cli/mcp.py b/src/kimi_cli/cli/mcp.py deleted file mode 100644 index 835774eb3..000000000 --- a/src/kimi_cli/cli/mcp.py +++ /dev/null @@ -1,349 +0,0 @@ -import json -from pathlib import Path -from typing import Annotated, Any, Literal - -import typer - -cli = typer.Typer(help="Manage MCP server configurations.") - - -def get_global_mcp_config_file() -> Path: - """Get the global MCP config file path.""" - from kimi_cli.share import get_share_dir - - return get_share_dir() / "mcp.json" - - -def _load_mcp_config() -> dict[str, Any]: - """Load MCP config from global mcp config file.""" - from fastmcp.mcp_config import MCPConfig - from pydantic import ValidationError - - mcp_file = get_global_mcp_config_file() - if not mcp_file.exists(): - return {"mcpServers": {}} - try: - config = json.loads(mcp_file.read_text(encoding="utf-8")) - except json.JSONDecodeError as e: - raise typer.BadParameter(f"Invalid JSON in MCP config file '{mcp_file}': {e}") from e - - try: - MCPConfig.model_validate(config) - except ValidationError as e: - raise typer.BadParameter(f"Invalid MCP config in '{mcp_file}': {e}") from e - - return config - - -def _save_mcp_config(config: dict[str, Any]) -> None: - """Save MCP config to default file.""" - mcp_file = get_global_mcp_config_file() - mcp_file.write_text(json.dumps(config, indent=2, ensure_ascii=False), encoding="utf-8") - - -def _get_mcp_server(name: str, *, require_remote: bool = False) -> dict[str, Any]: - """Get MCP server config by name.""" - config = _load_mcp_config() - servers = config.get("mcpServers", {}) - if name not in servers: - typer.echo(f"MCP server '{name}' not found.", err=True) - raise typer.Exit(code=1) - server = servers[name] - if require_remote and "url" not in server: - typer.echo(f"MCP server '{name}' is not a remote server.", err=True) - raise typer.Exit(code=1) - return server - - -def _parse_key_value_pairs( - items: list[str], option_name: str, *, separator: str = "=", strip_whitespace: bool = False -) -> dict[str, str]: - """Parse key/value pairs from CLI options.""" - parsed: dict[str, str] = {} - for item in items: - if separator not in item: - typer.echo( - f"Invalid {option_name} format: {item} (expected KEY{separator}VALUE).", - err=True, - ) - raise typer.Exit(code=1) - key, value = item.split(separator, 1) - if strip_whitespace: - key, value = key.strip(), value.strip() - if not key: - typer.echo(f"Invalid {option_name} format: {item} (empty key).", err=True) - raise typer.Exit(code=1) - parsed[key] = value - return parsed - - -Transport = Literal["stdio", "http"] - - -@cli.command( - "add", - epilog=""" - Examples:\n - \n - # Add streamable HTTP server:\n - kimi mcp add --transport http context7 https://mcp.context7.com/mcp --header \"CONTEXT7_API_KEY: ctx7sk-your-key\"\n - \n - # Add streamable HTTP server with OAuth authorization:\n - kimi mcp add --transport http --auth oauth linear https://mcp.linear.app/mcp\n - \n - # Add stdio server:\n - kimi mcp add --transport stdio chrome-devtools -- npx chrome-devtools-mcp@latest - """.strip(), # noqa: E501 -) -def mcp_add( - name: Annotated[ - str, - typer.Argument(help="Name of the MCP server to add."), - ], - server_args: Annotated[ - list[str] | None, - typer.Argument( - metavar="TARGET_OR_COMMAND...", - help="For http: server URL. For stdio: command to run (prefix with `--`).", - ), - ] = None, - transport: Annotated[ - Transport, - typer.Option( - "--transport", - "-t", - help="Transport type for the MCP server. Default: stdio.", - ), - ] = "stdio", - env: Annotated[ - list[str] | None, - typer.Option( - "--env", - "-e", - help="Environment variables in KEY=VALUE format. Can be specified multiple times.", - ), - ] = None, - header: Annotated[ - list[str] | None, - typer.Option( - "--header", - "-H", - help="HTTP headers in KEY:VALUE format. Can be specified multiple times.", - ), - ] = None, - auth: Annotated[ - str | None, - typer.Option( - "--auth", - "-a", - help="Authorization type (e.g., 'oauth').", - ), - ] = None, -): - """Add an MCP server.""" - config = _load_mcp_config() - server_args = server_args or [] - - if transport not in {"stdio", "http"}: - typer.echo(f"Unsupported transport: {transport}.", err=True) - raise typer.Exit(code=1) - - if transport == "stdio": - if not server_args: - typer.echo( - "For stdio transport, provide the command to start the MCP server after `--`.", - err=True, - ) - raise typer.Exit(code=1) - if header: - typer.echo("--header is only valid for http transport.", err=True) - raise typer.Exit(code=1) - if auth: - typer.echo("--auth is only valid for http transport.", err=True) - raise typer.Exit(code=1) - command, *command_args = server_args - server_config: dict[str, Any] = {"command": command, "args": command_args} - if env: - server_config["env"] = _parse_key_value_pairs(env, "env") - else: - if env: - typer.echo("--env is only supported for stdio transport.", err=True) - raise typer.Exit(code=1) - if not server_args: - typer.echo("URL is required for http transport.", err=True) - raise typer.Exit(code=1) - if len(server_args) > 1: - typer.echo( - "Multiple targets provided. Supply a single URL for http transport.", - err=True, - ) - raise typer.Exit(code=1) - server_config = {"url": server_args[0], "transport": "http"} - if header: - server_config["headers"] = _parse_key_value_pairs( - header, "header", separator=":", strip_whitespace=True - ) - if auth: - server_config["auth"] = auth - - if "mcpServers" not in config: - config["mcpServers"] = {} - config["mcpServers"][name] = server_config - _save_mcp_config(config) - typer.echo(f"Added MCP server '{name}' to {get_global_mcp_config_file()}.") - - -@cli.command("remove") -def mcp_remove( - name: Annotated[ - str, - typer.Argument(help="Name of the MCP server to remove."), - ], -): - """Remove an MCP server.""" - _get_mcp_server(name) - config = _load_mcp_config() - del config["mcpServers"][name] - _save_mcp_config(config) - typer.echo(f"Removed MCP server '{name}' from {get_global_mcp_config_file()}.") - - -def _has_oauth_tokens(server_url: str) -> bool: - """Check if OAuth tokens exist for the server.""" - import asyncio - - async def _check() -> bool: - try: - from fastmcp.client.auth.oauth import FileTokenStorage - - storage = FileTokenStorage(server_url=server_url) - tokens = await storage.get_tokens() - return tokens is not None - except Exception: - return False - - return asyncio.run(_check()) - - -@cli.command("list") -def mcp_list(): - """List all MCP servers.""" - config_file = get_global_mcp_config_file() - config = _load_mcp_config() - servers: dict[str, Any] = config.get("mcpServers", {}) - - typer.echo(f"MCP config file: {config_file}") - if not servers: - typer.echo("No MCP servers configured.") - return - - for name, server in servers.items(): - if "command" in server: - cmd = server["command"] - cmd_args = " ".join(server.get("args", [])) - line = f"{name} (stdio): {cmd} {cmd_args}".rstrip() - elif "url" in server: - transport = server.get("transport") or "http" - if transport == "streamable-http": - transport = "http" - line = f"{name} ({transport}): {server['url']}" - if server.get("auth") == "oauth" and not _has_oauth_tokens(server["url"]): - line += " [authorization required - run: kimi mcp auth " + name + "]" - else: - line = f"{name}: {server}" - typer.echo(f" {line}") - - -@cli.command("auth") -def mcp_auth( - name: Annotated[ - str, - typer.Argument(help="Name of the MCP server to authorize."), - ], -): - """Authorize with an OAuth-enabled MCP server.""" - import asyncio - - server = _get_mcp_server(name, require_remote=True) - if server.get("auth") != "oauth": - typer.echo(f"MCP server '{name}' does not use OAuth. Add with --auth oauth.", err=True) - raise typer.Exit(code=1) - - async def _auth() -> None: - import fastmcp - - typer.echo(f"Authorizing with '{name}'...") - typer.echo("A browser window will open for authorization.") - - client = fastmcp.Client({"mcpServers": {name: server}}) - try: - async with client: - tools = await client.list_tools() - typer.echo(f"Successfully authorized with '{name}'.") - typer.echo(f"Available tools: {len(tools)}") - except Exception as e: - typer.echo(f"Authorization failed: {type(e).__name__}: {e}", err=True) - raise typer.Exit(code=1) from None - - asyncio.run(_auth()) - - -@cli.command("reset-auth") -def mcp_reset_auth( - name: Annotated[ - str, - typer.Argument(help="Name of the MCP server to reset authorization."), - ], -): - """Reset OAuth authorization for an MCP server (clear cached tokens).""" - server = _get_mcp_server(name, require_remote=True) - - try: - from fastmcp.client.auth.oauth import FileTokenStorage - - storage = FileTokenStorage(server_url=server["url"]) - storage.clear() - typer.echo(f"OAuth tokens cleared for '{name}'.") - except ImportError: - typer.echo("OAuth support not available.", err=True) - raise typer.Exit(code=1) from None - except Exception as e: - typer.echo(f"Failed to clear tokens: {type(e).__name__}: {e}", err=True) - raise typer.Exit(code=1) from None - - -@cli.command("test") -def mcp_test( - name: Annotated[ - str, - typer.Argument(help="Name of the MCP server to test."), - ], -): - """Test connection to an MCP server and list available tools.""" - import asyncio - - server = _get_mcp_server(name) - - async def _test() -> None: - import fastmcp - - typer.echo(f"Testing connection to '{name}'...") - client = fastmcp.Client({"mcpServers": {name: server}}) - - try: - async with client: - tools = await client.list_tools() - typer.echo(f"✓ Connected to '{name}'") - typer.echo(f" Available tools: {len(tools)}") - if tools: - typer.echo(" Tools:") - for tool in tools: - desc = tool.description or "" - if len(desc) > 50: - desc = desc[:47] + "..." - typer.echo(f" - {tool.name}: {desc}") - except Exception as e: - typer.echo(f"✗ Connection failed: {type(e).__name__}: {e}", err=True) - raise typer.Exit(code=1) from None - - asyncio.run(_test()) diff --git a/src/kimi_cli/cli/plugin.py b/src/kimi_cli/cli/plugin.py deleted file mode 100644 index 4a3830ceb..000000000 --- a/src/kimi_cli/cli/plugin.py +++ /dev/null @@ -1,302 +0,0 @@ -"""CLI commands for plugin management.""" - -from __future__ import annotations - -from pathlib import Path -from typing import Annotated - -import typer - -from kimi_cli.plugin import PluginError - -cli = typer.Typer(help="Manage plugins.") - - -def _parse_git_url(target: str) -> tuple[str, str | None, str | None]: - """Parse a git URL into (clone_url, subpath, branch). - - Splits .git URLs at the .git boundary. For GitHub/GitLab short URLs, - treats the first two path segments as owner/repo and the rest as subpath. - Strips ``tree/{branch}/`` or ``-/tree/{branch}/`` prefixes from - browser-copied URLs and returns the branch name. - """ - # Path 1: URL contains .git followed by / or end-of-string - idx = target.find(".git/") - if idx == -1 and target.endswith(".git"): - return target, None, None - if idx != -1: - clone_url = target[: idx + 4] # up to and including ".git" - rest = target[idx + 5 :] # after ".git/" - subpath = rest.strip("/") or None - return clone_url, subpath, None - - # Path 2: GitHub/GitLab short URL (no .git) - from urllib.parse import urlparse - - parsed = urlparse(target) - segments = [s for s in parsed.path.split("/") if s] - if len(segments) < 2: - return target, None, None - - owner_repo = "/".join(segments[:2]) - clone_url = f"{parsed.scheme}://{parsed.netloc}/{owner_repo}" - rest_segments = segments[2:] - - # GitLab uses /-/tree/{branch}/, strip leading "-" - if rest_segments and rest_segments[0] == "-": - rest_segments = rest_segments[1:] - - # Strip tree/{branch}/ prefix and extract branch - branch: str | None = None - if len(rest_segments) >= 2 and rest_segments[0] == "tree": - branch = rest_segments[1] - rest_segments = rest_segments[2:] - - subpath = "/".join(rest_segments) or None - return clone_url, subpath, branch - - -def _resolve_source(target: str) -> tuple[Path, Path | None]: - """Resolve plugin source to (local_dir, tmp_to_cleanup). - - Returns the source directory and an optional temp directory that - the caller must clean up after use. - """ - import shutil - import tempfile - - # Git URL - if target.startswith(("https://", "git@", "http://")) and ( - ".git/" in target - or target.endswith(".git") - or "github.com/" in target - or "gitlab.com/" in target - ): - import subprocess - - clone_url, subpath, branch = _parse_git_url(target) - - tmp = Path(tempfile.mkdtemp(prefix="kimi-plugin-")) - typer.echo(f"Cloning {clone_url}...") - clone_cmd = ["git", "clone", "--depth", "1"] - if branch: - clone_cmd += ["--branch", branch] - clone_cmd += [clone_url, str(tmp / "repo")] - result = subprocess.run( - clone_cmd, - capture_output=True, - text=True, - ) - if result.returncode != 0: - shutil.rmtree(tmp, ignore_errors=True) - typer.echo( - f"Error: git clone failed: {result.stderr.strip()}", - err=True, - ) - raise typer.Exit(1) - - repo_root = tmp / "repo" - - if subpath: - source = (repo_root / subpath).resolve() - if not source.is_relative_to(repo_root.resolve()): - shutil.rmtree(tmp, ignore_errors=True) - typer.echo( - f"Error: subpath escapes repository: {subpath}", - err=True, - ) - raise typer.Exit(1) - if not source.is_dir(): - shutil.rmtree(tmp, ignore_errors=True) - typer.echo( - f"Error: subpath '{subpath}' not found in repository", - err=True, - ) - raise typer.Exit(1) - if not (source / "plugin.json").exists(): - shutil.rmtree(tmp, ignore_errors=True) - typer.echo( - f"Error: no plugin.json in '{subpath}'", - err=True, - ) - raise typer.Exit(1) - return source, tmp - - # No subpath — check root first - if (repo_root / "plugin.json").exists(): - return repo_root, tmp - - # Scan one level for available plugins - available = sorted( - d.name for d in repo_root.iterdir() if d.is_dir() and (d / "plugin.json").exists() - ) - if available: - names = "\n".join(f" - {n}" for n in available) - typer.echo( - f"Error: No plugin.json at repository root. " - f"Available plugins:\n{names}\n" - f"Use: kimi plugin install /", - err=True, - ) - else: - typer.echo( - "Error: No plugin.json found in repository", - err=True, - ) - shutil.rmtree(tmp, ignore_errors=True) - raise typer.Exit(1) - - p = Path(target).expanduser().resolve() - - # Zip file - if p.is_file() and p.suffix == ".zip": - import zipfile - - tmp = Path(tempfile.mkdtemp(prefix="kimi-plugin-")) - typer.echo(f"Extracting {p.name}...") - with zipfile.ZipFile(p, "r") as zf: - # Reject zip members that escape the extraction directory - for member in zf.namelist(): - member_path = (tmp / member).resolve() - if not member_path.is_relative_to(tmp.resolve()): - shutil.rmtree(tmp, ignore_errors=True) - typer.echo(f"Error: zip contains unsafe path: {member}", err=True) - raise typer.Exit(1) - zf.extractall(tmp) - # Find the directory containing plugin.json (may be nested one level) - for candidate in [tmp] + sorted(tmp.iterdir()): - if candidate.is_dir() and (candidate / "plugin.json").exists(): - return candidate, tmp - # Check for __MACOSX and similar artifacts - dirs = [d for d in tmp.iterdir() if d.is_dir() and not d.name.startswith("_")] - if len(dirs) == 1 and (dirs[0] / "plugin.json").exists(): - return dirs[0], tmp - shutil.rmtree(tmp, ignore_errors=True) - typer.echo("Error: No plugin.json found in zip", err=True) - raise typer.Exit(1) - - # Local directory - if p.is_dir(): - return p, None - - typer.echo(f"Error: {target} is not a directory, zip file, or git URL", err=True) - raise typer.Exit(1) - - -@cli.command("install") -def install_cmd( - target: Annotated[str, typer.Argument(help="Plugin source: directory, .zip, or git URL")], -) -> None: - """Install a plugin and inject host configuration.""" - import shutil - - from kimi_cli.config import load_config - from kimi_cli.constant import VERSION - from kimi_cli.plugin.manager import get_plugins_dir, install_plugin - - source, tmp_dir = _resolve_source(target) - - try: - config = load_config() - - from kimi_cli.auth.oauth import OAuthManager - from kimi_cli.llm import augment_provider_with_env_vars - from kimi_cli.plugin.manager import collect_host_values - - # Apply env var overrides (install runs outside normal startup) - if config.default_model and config.default_model in config.models: - model = config.models[config.default_model] - if model.provider in config.providers: - augment_provider_with_env_vars(config.providers[model.provider], model) - - oauth = OAuthManager(config) - host_values = collect_host_values(config, oauth) - - if not host_values.get("api_key"): - typer.echo( - "Warning: No LLM provider configured. " - "Plugins requiring API key injection will fail. " - "Run 'kimi login' or configure a provider first.", - err=True, - ) - - spec = install_plugin( - source=source, - plugins_dir=get_plugins_dir(), - host_values=host_values, - host_name="kimi-code", - host_version=VERSION, - ) - except PluginError as exc: - typer.echo(f"Error: {exc}", err=True) - raise typer.Exit(1) from exc - finally: - # Clean up temp directory from zip/git extraction - if tmp_dir is not None: - shutil.rmtree(tmp_dir, ignore_errors=True) - - typer.echo(f"Installed plugin '{spec.name}' v{spec.version}") - if spec.runtime: - typer.echo(f" runtime: host={spec.runtime.host}, version={spec.runtime.host_version}") - - -@cli.command("list") -def list_cmd() -> None: - """List installed plugins.""" - from kimi_cli.plugin.manager import get_plugins_dir, list_plugins - - plugins = list_plugins(get_plugins_dir()) - if not plugins: - typer.echo("No plugins installed.") - return - - for p in plugins: - status = "installed" if p.runtime else "not configured" - typer.echo(f" {p.name} v{p.version} ({status})") - - -@cli.command("remove") -def remove_cmd( - name: Annotated[str, typer.Argument(help="Plugin name to remove")], -) -> None: - """Remove an installed plugin.""" - from kimi_cli.plugin.manager import get_plugins_dir, remove_plugin - - try: - remove_plugin(name, get_plugins_dir()) - except PluginError as exc: - typer.echo(f"Error: {exc}", err=True) - raise typer.Exit(1) from exc - - typer.echo(f"Removed plugin '{name}'") - - -@cli.command("info") -def info_cmd( - name: Annotated[str, typer.Argument(help="Plugin name")], -) -> None: - """Show plugin details.""" - from kimi_cli.plugin import parse_plugin_json - from kimi_cli.plugin.manager import get_plugins_dir - - plugin_json = get_plugins_dir() / name / "plugin.json" - if not plugin_json.exists(): - typer.echo(f"Error: Plugin '{name}' not found", err=True) - raise typer.Exit(1) - - try: - spec = parse_plugin_json(plugin_json) - except PluginError as exc: - typer.echo(f"Error: {exc}", err=True) - raise typer.Exit(1) from exc - - typer.echo(f"Name: {spec.name}") - typer.echo(f"Version: {spec.version}") - typer.echo(f"Description: {spec.description or '(none)'}") - typer.echo(f"Config file: {spec.config_file or '(none)'}") - if spec.inject: - typer.echo(f"Inject: {', '.join(f'{k} <- {v}' for k, v in spec.inject.items())}") - if spec.runtime: - typer.echo(f"Runtime: host={spec.runtime.host}, version={spec.runtime.host_version}") - else: - typer.echo("Runtime: (not installed via host)") diff --git a/src/kimi_cli/cli/toad.py b/src/kimi_cli/cli/toad.py deleted file mode 100644 index b57766f04..000000000 --- a/src/kimi_cli/cli/toad.py +++ /dev/null @@ -1,73 +0,0 @@ -import importlib.util -import shlex -import shutil -import subprocess -import sys -from pathlib import Path - -import typer - - -def _default_acp_command() -> list[str]: - argv0 = sys.argv[0] - if argv0: - resolved = shutil.which(argv0) - resolved_path = Path(resolved).expanduser() if resolved else Path(argv0).expanduser() - if ( - resolved_path.exists() - and resolved_path.suffix != ".py" - and not resolved_path.name.startswith(("python", "pypy")) - ): - return [str(resolved_path), "acp"] - - return [sys.executable, "-m", "kimi_cli.cli", "acp"] - - -def _default_toad_command() -> list[str]: - if sys.version_info < (3, 14): - typer.echo("`kimi term` requires Python 3.14+ because Toad requires it.", err=True) - raise typer.Exit(code=1) - if importlib.util.find_spec("toad") is None: - typer.echo( - "Toad dependency is missing. Install kimi-cli with Python 3.14+ to use `kimi term`.", - err=True, - ) - raise typer.Exit(code=1) - return [sys.executable, "-m", "toad.cli"] - - -def _extract_project_dir(extra_args: list[str]) -> Path | None: - work_dir: str | None = None - idx = 0 - while idx < len(extra_args): - arg = extra_args[idx] - if arg in ("--work-dir", "-w"): - if idx + 1 < len(extra_args): - work_dir = extra_args[idx + 1] - idx += 2 - continue - elif arg.startswith("--work-dir=") or arg.startswith("-w="): - work_dir = arg.split("=", 1)[1] - elif arg.startswith("-w") and len(arg) > 2: - work_dir = arg[2:] - idx += 1 - - if not work_dir: - return None - - return Path(work_dir).expanduser().resolve() - - -def run_term(ctx: typer.Context) -> None: - extra_args = list(ctx.args) - acp_args = _default_acp_command() - acp_command = shlex.join(acp_args) - toad_parts = _default_toad_command() - args = [*toad_parts, "acp", acp_command] - project_dir = _extract_project_dir(extra_args) - if project_dir is not None: - args.append(str(project_dir)) - - result = subprocess.run(args) - if result.returncode != 0: - raise typer.Exit(code=result.returncode) diff --git a/src/kimi_cli/cli/vis.py b/src/kimi_cli/cli/vis.py deleted file mode 100644 index 4ce95c485..000000000 --- a/src/kimi_cli/cli/vis.py +++ /dev/null @@ -1,38 +0,0 @@ -"""Vis command for Kimi Agent Tracing Visualizer.""" - -from typing import Annotated - -import typer - -cli = typer.Typer(help="Run Kimi Agent Tracing Visualizer.") - - -@cli.callback(invoke_without_command=True) -def vis( - ctx: typer.Context, - host: Annotated[ - str | None, - typer.Option("--host", "-h", help="Bind to specific IP address"), - ] = None, - network: Annotated[ - bool, - typer.Option("--network", "-n", help="Enable network access (bind to 0.0.0.0)"), - ] = False, - port: Annotated[int, typer.Option("--port", "-p", help="Port to bind to")] = 5495, - open_browser: Annotated[ - bool, typer.Option("--open/--no-open", help="Open browser automatically") - ] = True, - reload: Annotated[bool, typer.Option("--reload", help="Enable auto-reload")] = False, -): - """Launch the agent tracing visualizer.""" - from kimi_cli.vis.app import run_vis_server - - # Determine bind address (same logic as kimi web) - if host: - bind_host = host - elif network: - bind_host = "0.0.0.0" - else: - bind_host = "127.0.0.1" - - run_vis_server(host=bind_host, port=port, open_browser=open_browser, reload=reload) diff --git a/src/kimi_cli/cli/web.py b/src/kimi_cli/cli/web.py deleted file mode 100644 index d81731944..000000000 --- a/src/kimi_cli/cli/web.py +++ /dev/null @@ -1,80 +0,0 @@ -"""Web UI command for Kimi Code CLI.""" - -from typing import Annotated - -import typer - -cli = typer.Typer(help="Run Kimi Code CLI web interface.") - - -@cli.callback(invoke_without_command=True) -def web( - ctx: typer.Context, - host: Annotated[ - str | None, - typer.Option("--host", "-h", help="Bind to specific IP address"), - ] = None, - network: Annotated[ - bool, - typer.Option("--network", "-n", help="Enable network access (bind to 0.0.0.0)"), - ] = False, - port: Annotated[int, typer.Option("--port", "-p", help="Port to bind to")] = 5494, - reload: Annotated[bool, typer.Option("--reload", help="Enable auto-reload")] = False, - open_browser: Annotated[ - bool, typer.Option("--open/--no-open", help="Open browser automatically") - ] = True, - auth_token: Annotated[ - str | None, - typer.Option("--auth-token", help="Bearer token for API authentication."), - ] = None, - allowed_origins: Annotated[ - str | None, - typer.Option( - "--allowed-origins", - help="Comma-separated list of allowed Origin values.", - ), - ] = None, - dangerously_omit_auth: Annotated[ - bool, - typer.Option( - "--dangerously-omit-auth", - help="Disable auth checks (dangerous in public networks).", - ), - ] = False, - restrict_sensitive_apis: Annotated[ - bool | None, - typer.Option( - "--restrict-sensitive-apis/--no-restrict-sensitive-apis", - help="Disable sensitive APIs (config write, open-in, file access limits).", - ), - ] = None, - lan_only: Annotated[ - bool, - typer.Option( - "--lan-only/--public", - help="Only allow access from local network (default) or allow public access.", - ), - ] = True, -): - """Run Kimi Code CLI web interface.""" - from kimi_cli.web.app import run_web_server - - # Determine bind address - if host: - bind_host = host - elif network: - bind_host = "0.0.0.0" - else: - bind_host = "127.0.0.1" - - run_web_server( - host=bind_host, - port=port, - reload=reload, - open_browser=open_browser, - auth_token=auth_token, - allowed_origins=allowed_origins, - dangerously_omit_auth=dangerously_omit_auth, - restrict_sensitive_apis=restrict_sensitive_apis, - lan_only=lan_only, - ) diff --git a/src/kimi_cli/config.py b/src/kimi_cli/config.py deleted file mode 100644 index 50e8a50f4..000000000 --- a/src/kimi_cli/config.py +++ /dev/null @@ -1,377 +0,0 @@ -from __future__ import annotations - -import json -from pathlib import Path -from typing import Literal, Self - -import tomlkit -from pydantic import ( - AliasChoices, - BaseModel, - Field, - SecretStr, - ValidationError, - field_serializer, - model_validator, -) -from tomlkit.exceptions import TOMLKitError - -from kimi_cli.exception import ConfigError -from kimi_cli.hooks.config import HookDef -from kimi_cli.llm import ModelCapability, ProviderType -from kimi_cli.share import get_share_dir -from kimi_cli.utils.logging import logger - - -class OAuthRef(BaseModel): - """Reference to OAuth credentials stored outside the config file.""" - - storage: Literal["keyring", "file"] = "file" - """Credential storage backend.""" - key: str - """Storage key to locate OAuth credentials.""" - - -class LLMProvider(BaseModel): - """LLM provider configuration.""" - - type: ProviderType - """Provider type""" - base_url: str - """API base URL""" - api_key: SecretStr - """API key""" - env: dict[str, str] | None = None - """Environment variables to set before creating the provider instance""" - custom_headers: dict[str, str] | None = None - """Custom headers to include in API requests""" - oauth: OAuthRef | None = None - """OAuth credential reference (do not store tokens here).""" - - @field_serializer("api_key", when_used="json") - def dump_secret(self, v: SecretStr): - return v.get_secret_value() - - -class LLMModel(BaseModel): - """LLM model configuration.""" - - provider: str - """Provider name""" - model: str - """Model name""" - max_context_size: int - """Maximum context size (unit: tokens)""" - capabilities: set[ModelCapability] | None = None - """Model capabilities""" - - -class LoopControl(BaseModel): - """Agent loop control configuration.""" - - max_steps_per_turn: int = Field( - default=100, - ge=1, - validation_alias=AliasChoices("max_steps_per_turn", "max_steps_per_run"), - ) - """Maximum number of steps in one turn""" - max_retries_per_step: int = Field(default=3, ge=1) - """Maximum number of retries in one step""" - max_ralph_iterations: int = Field(default=0, ge=-1) - """Extra iterations after the first turn in Ralph mode. Use -1 for unlimited.""" - reserved_context_size: int = Field(default=50_000, ge=1000) - """Reserved token count for LLM response generation. Auto-compaction triggers when - either context_tokens + reserved_context_size >= max_context_size or - context_tokens >= max_context_size * compaction_trigger_ratio. Default is 50000.""" - compaction_trigger_ratio: float = Field(default=0.85, ge=0.5, le=0.99) - """Context usage ratio threshold for auto-compaction. Default is 0.85 (85%). - Auto-compaction triggers when context_tokens >= max_context_size * compaction_trigger_ratio - or when context_tokens + reserved_context_size >= max_context_size.""" - - -class BackgroundConfig(BaseModel): - """Background task runtime configuration.""" - - max_running_tasks: int = Field(default=4, ge=1) - read_max_bytes: int = Field(default=30_000, ge=1024) - notification_tail_lines: int = Field(default=20, ge=1) - notification_tail_chars: int = Field(default=3_000, ge=256) - wait_poll_interval_ms: int = Field(default=500, ge=50) - worker_heartbeat_interval_ms: int = Field(default=5_000, ge=100) - worker_stale_after_ms: int = Field(default=15_000, ge=1000) - kill_grace_period_ms: int = Field(default=2_000, ge=100) - keep_alive_on_exit: bool = Field( - default=False, - description="Keep background tasks alive when CLI exits. Default: kill on exit.", - ) - agent_task_timeout_s: int = Field(default=900, ge=60) - """Maximum runtime in seconds for a background agent task. Default: 900 (15 min).""" - - -class NotificationConfig(BaseModel): - """Notification runtime configuration.""" - - claim_stale_after_ms: int = Field(default=15_000, ge=1000) - - -class MoonshotSearchConfig(BaseModel): - """Moonshot Search configuration.""" - - base_url: str - """Base URL for Moonshot Search service.""" - api_key: SecretStr - """API key for Moonshot Search service.""" - custom_headers: dict[str, str] | None = None - """Custom headers to include in API requests.""" - oauth: OAuthRef | None = None - """OAuth credential reference (do not store tokens here).""" - - @field_serializer("api_key", when_used="json") - def dump_secret(self, v: SecretStr): - return v.get_secret_value() - - -class MoonshotFetchConfig(BaseModel): - """Moonshot Fetch configuration.""" - - base_url: str - """Base URL for Moonshot Fetch service.""" - api_key: SecretStr - """API key for Moonshot Fetch service.""" - custom_headers: dict[str, str] | None = None - """Custom headers to include in API requests.""" - oauth: OAuthRef | None = None - """OAuth credential reference (do not store tokens here).""" - - @field_serializer("api_key", when_used="json") - def dump_secret(self, v: SecretStr): - return v.get_secret_value() - - -class Services(BaseModel): - """Services configuration.""" - - moonshot_search: MoonshotSearchConfig | None = None - """Moonshot Search configuration.""" - moonshot_fetch: MoonshotFetchConfig | None = None - """Moonshot Fetch configuration.""" - - -class MCPClientConfig(BaseModel): - """MCP client configuration.""" - - tool_call_timeout_ms: int = 60000 - """Timeout for tool calls in milliseconds.""" - - -class MCPConfig(BaseModel): - """MCP configuration.""" - - client: MCPClientConfig = Field( - default_factory=MCPClientConfig, description="MCP client configuration" - ) - - -class Config(BaseModel): - """Main configuration structure.""" - - is_from_default_location: bool = Field( - default=False, - description="Whether the config was loaded from the default location", - exclude=True, - ) - source_file: Path | None = Field( - default=None, - description="Path to the loaded config file. None when loaded from --config text.", - exclude=True, - ) - default_model: str = Field(default="", description="Default model to use") - default_thinking: bool = Field(default=False, description="Default thinking mode") - default_yolo: bool = Field(default=False, description="Default yolo (auto-approve) mode") - default_editor: str = Field( - default="", - description="Default external editor command (e.g. 'vim', 'code --wait')", - ) - theme: Literal["dark", "light"] = Field( - default="dark", - description="Terminal color theme. Use 'light' for light terminal backgrounds.", - ) - models: dict[str, LLMModel] = Field(default_factory=dict, description="List of LLM models") - providers: dict[str, LLMProvider] = Field( - default_factory=dict, description="List of LLM providers" - ) - loop_control: LoopControl = Field(default_factory=LoopControl, description="Agent loop control") - background: BackgroundConfig = Field( - default_factory=BackgroundConfig, description="Background task configuration" - ) - notifications: NotificationConfig = Field( - default_factory=NotificationConfig, description="Notification configuration" - ) - services: Services = Field(default_factory=Services, description="Services configuration") - mcp: MCPConfig = Field(default_factory=MCPConfig, description="MCP configuration") - hooks: list[HookDef] = Field(default_factory=list, description="Hook definitions") # pyright: ignore[reportUnknownVariableType] - - @model_validator(mode="after") - def validate_model(self) -> Self: - if self.default_model and self.default_model not in self.models: - raise ValueError(f"Default model {self.default_model} not found in models") - for model in self.models.values(): - if model.provider not in self.providers: - raise ValueError(f"Provider {model.provider} not found in providers") - return self - - -def get_config_file() -> Path: - """Get the configuration file path.""" - return get_share_dir() / "config.toml" - - -def get_default_config() -> Config: - """Get the default configuration.""" - return Config( - default_model="", - models={}, - providers={}, - services=Services(), - ) - - -def load_config(config_file: Path | None = None) -> Config: - """ - Load configuration from config file. - If the config file does not exist, create it with default configuration. - - Args: - config_file (Path | None): Path to the configuration file. If None, use default path. - - Returns: - Validated Config object. - - Raises: - ConfigError: If the configuration file is invalid. - """ - default_config_file = get_config_file().expanduser().resolve(strict=False) - if config_file is None: - config_file = default_config_file - config_file = config_file.expanduser().resolve(strict=False) - is_default_config_file = config_file == default_config_file - logger.debug("Loading config from file: {file}", file=config_file) - - # If the user hasn't provided an explicit config path, migrate legacy JSON config once. - if is_default_config_file and not config_file.exists(): - _migrate_json_config_to_toml() - - if not config_file.exists(): - config = get_default_config() - logger.debug("No config file found, creating default config: {config}", config=config) - save_config(config, config_file) - config.is_from_default_location = is_default_config_file - config.source_file = config_file - return config - - try: - config_text = config_file.read_text(encoding="utf-8") - if config_file.suffix.lower() == ".json": - data = json.loads(config_text) - else: - data = tomlkit.loads(config_text) - config = Config.model_validate(data) - except json.JSONDecodeError as e: - raise ConfigError(f"Invalid JSON in configuration file {config_file}: {e}") from e - except TOMLKitError as e: - raise ConfigError(f"Invalid TOML in configuration file {config_file}: {e}") from e - except ValidationError as e: - raise ConfigError(f"Invalid configuration file {config_file}: {e}") from e - config.is_from_default_location = is_default_config_file - config.source_file = config_file - return config - - -def load_config_from_string(config_string: str) -> Config: - """ - Load configuration from a TOML or JSON string. - - Args: - config_string (str): TOML or JSON configuration text. - - Returns: - Validated Config object. - - Raises: - ConfigError: If the configuration text is invalid. - """ - if not config_string.strip(): - raise ConfigError("Configuration text cannot be empty") - - json_error: json.JSONDecodeError | None = None - try: - data = json.loads(config_string) - except json.JSONDecodeError as exc: - json_error = exc - data = None - - if data is None: - try: - data = tomlkit.loads(config_string) - except TOMLKitError as toml_error: - raise ConfigError( - f"Invalid configuration text: {json_error}; {toml_error}" - ) from toml_error - - try: - config = Config.model_validate(data) - except ValidationError as e: - raise ConfigError(f"Invalid configuration text: {e}") from e - config.is_from_default_location = False - config.source_file = None - return config - - -def save_config(config: Config, config_file: Path | None = None): - """ - Save configuration to config file. - - Args: - config (Config): Config object to save. - config_file (Path | None): Path to the configuration file. If None, use default path. - """ - config_file = config_file or get_config_file() - logger.debug("Saving config to file: {file}", file=config_file) - config_file.parent.mkdir(parents=True, exist_ok=True) - config_data = config.model_dump(mode="json", exclude_none=True) - with open(config_file, "w", encoding="utf-8") as f: - if config_file.suffix.lower() == ".json": - f.write(json.dumps(config_data, ensure_ascii=False, indent=2)) - else: - f.write(tomlkit.dumps(config_data)) # type: ignore[reportUnknownMemberType] - - -def _migrate_json_config_to_toml() -> None: - old_json_config_file = get_share_dir() / "config.json" - new_toml_config_file = get_share_dir() / "config.toml" - - if not old_json_config_file.exists(): - return - if new_toml_config_file.exists(): - return - - logger.info( - "Migrating legacy config file from {old} to {new}", - old=old_json_config_file, - new=new_toml_config_file, - ) - - try: - with open(old_json_config_file, encoding="utf-8") as f: - data = json.load(f) - config = Config.model_validate(data) - except json.JSONDecodeError as e: - raise ConfigError(f"Invalid JSON in legacy configuration file: {e}") from e - except ValidationError as e: - raise ConfigError(f"Invalid legacy configuration file: {e}") from e - - # Write new TOML config, then keep a backup of the original JSON file. - save_config(config, new_toml_config_file) - backup_path = old_json_config_file.with_name("config.json.bak") - old_json_config_file.replace(backup_path) - logger.info("Legacy config backed up to {file}", file=backup_path) diff --git a/src/kimi_cli/config.ts b/src/kimi_cli/config.ts new file mode 100644 index 000000000..d4c98b457 --- /dev/null +++ b/src/kimi_cli/config.ts @@ -0,0 +1,303 @@ +/** + * Configuration module — corresponds to Python config.py + * Loads/saves TOML config with Zod validation. + */ + +import { z } from "zod/v4"; +import TOML from "@iarna/toml"; +import { ModelCapability } from "./types.ts"; + +// ── Sub-schemas ───────────────────────────────────────── + +export const OAuthRef = z.object({ + storage: z.enum(["keyring", "file"]).default("file"), + key: z.string(), +}); +export type OAuthRef = z.infer; + +export const ProviderType = z.enum([ + "kimi", + "openai_legacy", + "openai_responses", + "anthropic", + "google_genai", + "gemini", + "vertexai", + "_echo", + "_scripted_echo", + "_chaos", +]); +export type ProviderType = z.infer; + +export const LLMProvider = z.object({ + type: ProviderType, + base_url: z.string(), + api_key: z.string(), + env: z.record(z.string(), z.string()).optional(), + custom_headers: z.record(z.string(), z.string()).optional(), + oauth: OAuthRef.optional(), +}); +export type LLMProvider = z.infer; + +export const LLMModel = z.object({ + provider: z.string(), + model: z.string(), + max_context_size: z.number().int(), + capabilities: z.array(ModelCapability).optional(), +}); +export type LLMModel = z.infer; + +export const LoopControl = z.object({ + max_steps_per_turn: z.number().int().min(1).default(100), + max_retries_per_step: z.number().int().min(1).default(3), + max_ralph_iterations: z.number().int().min(-1).default(0), + reserved_context_size: z.number().int().min(1000).default(50_000), + compaction_trigger_ratio: z.number().min(0.5).max(0.99).default(0.85), +}); +export type LoopControl = z.infer; + +export const BackgroundConfig = z.object({ + max_running_tasks: z.number().int().min(1).default(4), + read_max_bytes: z.number().int().min(1024).default(30_000), + notification_tail_lines: z.number().int().min(1).default(20), + notification_tail_chars: z.number().int().min(256).default(3_000), + wait_poll_interval_ms: z.number().int().min(50).default(500), + worker_heartbeat_interval_ms: z.number().int().min(100).default(5_000), + worker_stale_after_ms: z.number().int().min(1000).default(15_000), + kill_grace_period_ms: z.number().int().min(100).default(2_000), + keep_alive_on_exit: z.boolean().default(false), + agent_task_timeout_s: z.number().int().min(60).default(900), +}); +export type BackgroundConfig = z.infer; + +export const NotificationConfig = z.object({ + claim_stale_after_ms: z.number().int().min(1000).default(15_000), +}); +export type NotificationConfig = z.infer; + +export const MoonshotSearchConfig = z.object({ + base_url: z.string(), + api_key: z.string(), + custom_headers: z.record(z.string(), z.string()).optional(), + oauth: OAuthRef.optional(), +}); +export type MoonshotSearchConfig = z.infer; + +export const MoonshotFetchConfig = z.object({ + base_url: z.string(), + api_key: z.string(), + custom_headers: z.record(z.string(), z.string()).optional(), + oauth: OAuthRef.optional(), +}); +export type MoonshotFetchConfig = z.infer; + +export const Services = z.object({ + moonshot_search: MoonshotSearchConfig.optional(), + moonshot_fetch: MoonshotFetchConfig.optional(), +}); +export type Services = z.infer; + +export const MCPClientConfig = z.object({ + tool_call_timeout_ms: z.number().int().default(60000), +}); +export type MCPClientConfig = z.infer; + +export const MCPConfig = z.object({ + client: MCPClientConfig.default({} as any), +}); +export type MCPConfig = z.infer; + +export const HookEventType = z.enum([ + "PreToolUse", + "PostToolUse", + "PostToolUseFailure", + "UserPromptSubmit", + "Stop", + "StopFailure", + "SessionStart", + "SessionEnd", + "SubagentStart", + "SubagentStop", + "PreCompact", + "PostCompact", + "Notification", +]); +export type HookEventType = z.infer; + +export const HookDef = z.object({ + event: HookEventType, + command: z.string(), + matcher: z.string().default(""), + timeout: z.number().int().min(1).max(600).default(30), +}); +export type HookDef = z.infer; + +export const Config = z + .object({ + default_model: z.string().default(""), + default_thinking: z.boolean().default(false), + default_yolo: z.boolean().default(false), + default_editor: z.string().default(""), + theme: z.enum(["dark", "light"]).default("dark"), + models: z.record(z.string(), LLMModel).default({}), + providers: z.record(z.string(), LLMProvider).default({}), + loop_control: LoopControl.default({} as any), + background: BackgroundConfig.default({} as any), + notifications: NotificationConfig.default({} as any), + services: Services.default({}), + mcp: MCPConfig.default({} as any), + hooks: z.array(HookDef).default([]), + }) + .refine( + (cfg) => { + if (cfg.default_model && !(cfg.default_model in cfg.models)) return false; + for (const m of Object.values(cfg.models) as LLMModel[]) { + if (!(m.provider in cfg.providers)) return false; + } + return true; + }, + { message: "default_model must exist in models, and all model providers must exist in providers" }, + ); + +export type Config = z.infer; + +/** Runtime metadata attached after loading (not persisted). */ +export interface ConfigMeta { + isFromDefaultLocation: boolean; + sourceFile: string | null; +} + +// ── Paths ─────────────────────────────────────────────── + +import { homedir } from "node:os"; +import { join, resolve } from "node:path"; +import { ConfigError } from "./exception.ts"; + +export { ConfigError }; + +export function getShareDir(): string { + return process.env.KIMI_SHARE_DIR ?? join(homedir(), ".kimi"); +} + +export function getConfigFile(): string { + return join(getShareDir(), "config.toml"); +} + +// ── Secret masking helper ──────────────────────────────── + +/** Mask a secret string for safe logging (shows first 4 chars + ***). */ +export function maskSecret(value: string): string { + if (!value || value.length <= 4) return "***"; + return value.slice(0, 4) + "***"; +} + +// ── JSON → TOML migration ─────────────────────────────── + +async function migrateJsonConfigToToml(): Promise { + const oldJsonConfigFile = join(getShareDir(), "config.json"); + const newTomlConfigFile = join(getShareDir(), "config.toml"); + + const oldFile = Bun.file(oldJsonConfigFile); + const newFile = Bun.file(newTomlConfigFile); + if (!(await oldFile.exists())) return; + if (await newFile.exists()) return; + + try { + const data = await oldFile.json(); + const config = Config.parse(data); + await saveConfig(config, newTomlConfigFile); + // Backup old file + const backupPath = oldJsonConfigFile.replace(/\.json$/, ".json.bak"); + await Bun.$`mv ${oldJsonConfigFile} ${backupPath}`.quiet(); + } catch (err) { + // If migration fails, continue with default config + } +} + +export function getDefaultConfig(): Config { + return Config.parse({}); +} + +export async function loadConfig( + configFile?: string, +): Promise<{ config: Config; meta: ConfigMeta }> { + const defaultConfigFile = resolve(getConfigFile()); + const resolvedPath = configFile ? resolve(configFile) : defaultConfigFile; + const isDefault = resolvedPath === defaultConfigFile; + + // If using default config and it doesn't exist, try migrating from JSON + if (isDefault) { + const file = Bun.file(resolvedPath); + if (!(await file.exists())) { + await migrateJsonConfigToToml(); + } + } + + const file = Bun.file(resolvedPath); + if (!(await file.exists())) { + const config = getDefaultConfig(); + await saveConfig(config, resolvedPath); + return { config, meta: { isFromDefaultLocation: isDefault, sourceFile: resolvedPath } }; + } + + try { + const text = await file.text(); + let data: unknown; + if (resolvedPath.toLowerCase().endsWith(".json")) { + data = JSON.parse(text); + } else { + const rawData = TOML.parse(text); + // @iarna/toml adds Symbol properties that break Zod validation — strip them via JSON roundtrip + data = JSON.parse(JSON.stringify(rawData)); + } + const config = Config.parse(data); + + // Environment variable overrides + if (process.env.KIMI_MODEL_NAME) config.default_model = process.env.KIMI_MODEL_NAME; + + return { config, meta: { isFromDefaultLocation: isDefault, sourceFile: resolvedPath } }; + } catch (err) { + if (err instanceof z.ZodError) { + throw new ConfigError(`Invalid configuration file ${resolvedPath}: ${err.message}`); + } + throw new ConfigError(`Failed to parse configuration file ${resolvedPath}: ${err}`); + } +} + +export async function loadConfigFromString(text: string): Promise<{ config: Config; meta: ConfigMeta }> { + if (!text.trim()) throw new ConfigError("Configuration text cannot be empty"); + + let data: unknown; + try { + data = JSON.parse(text); + } catch { + try { + data = TOML.parse(text); + } catch (tomlErr) { + throw new ConfigError(`Invalid configuration text: ${tomlErr}`); + } + } + + try { + const config = Config.parse(data); + return { config, meta: { isFromDefaultLocation: false, sourceFile: null } }; + } catch (err) { + throw new ConfigError(`Invalid configuration text: ${err}`); + } +} + +export async function saveConfig(config: Config, configFile?: string): Promise { + const filePath = configFile ?? getConfigFile(); + const dir = filePath.substring(0, filePath.lastIndexOf("/")); + await Bun.$`mkdir -p ${dir}`.quiet(); + + // Strip undefined/null values for clean output + const data = JSON.parse(JSON.stringify(config)); + + if (filePath.toLowerCase().endsWith(".json")) { + await Bun.write(filePath, JSON.stringify(data, null, 2)); + } else { + const tomlStr = TOML.stringify(data as any); + await Bun.write(filePath, tomlStr); + } +} diff --git a/src/kimi_cli/constant.py b/src/kimi_cli/constant.py deleted file mode 100644 index 84458fe67..000000000 --- a/src/kimi_cli/constant.py +++ /dev/null @@ -1,33 +0,0 @@ -from __future__ import annotations - -from functools import cache -from typing import TYPE_CHECKING - -NAME = "Kimi Code CLI" - -if TYPE_CHECKING: - VERSION: str - USER_AGENT: str - - -@cache -def get_version() -> str: - from importlib import metadata - - return metadata.version("kimi-cli") - - -@cache -def get_user_agent() -> str: - return f"KimiCLI/{get_version()}" - - -def __getattr__(name: str) -> str: - if name == "VERSION": - return get_version() - if name == "USER_AGENT": - return get_user_agent() - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") - - -__all__ = ["NAME", "VERSION", "USER_AGENT", "get_version", "get_user_agent"] diff --git a/src/kimi_cli/constant.ts b/src/kimi_cli/constant.ts new file mode 100644 index 000000000..57847f2c9 --- /dev/null +++ b/src/kimi_cli/constant.ts @@ -0,0 +1,30 @@ +/** + * Constants module — corresponds to Python constant.py + * Exports NAME, VERSION, USER_AGENT, and helper functions. + */ + +import { join } from "node:path"; + +export const NAME = "Kimi Code CLI"; + +let _version: string | null = null; + +export function getVersion(): string { + if (_version) return _version; + try { + // Read version from package.json at build/runtime + const pkgPath = join(import.meta.dir, "../../../package.json"); + const pkg = require(pkgPath); + _version = String(pkg.version ?? "0.0.0"); + } catch { + _version = "0.0.0"; + } + return _version; +} + +export function getUserAgent(): string { + return `KimiCLI/${getVersion()}`; +} + +export const VERSION: string = getVersion(); +export const USER_AGENT: string = getUserAgent(); diff --git a/src/kimi_cli/deps/Makefile b/src/kimi_cli/deps/Makefile deleted file mode 100644 index ed750181c..000000000 --- a/src/kimi_cli/deps/Makefile +++ /dev/null @@ -1,84 +0,0 @@ -THIS_DIR := $(patsubst %/,%,$(dir $(lastword $(MAKEFILE_LIST)))) -BIN_DIR := $(THIS_DIR)/bin -TMP_DIR := $(THIS_DIR)/tmp - -# Allow override via environment: RG_VERSION=15.0.0 make download-ripgrep -RG_VERSION ?= 15.0.0 -OS := $(shell uname -s) -ARCH := $(shell uname -m | tr '[:upper:]' '[:lower:]') -RG_ARCHIVE_EXT := tar.gz -RG_ARCHIVE_BIN := rg -RG_BIN_SUFFIX := - -# Map OS/ARCH to ripgrep TARGET name -# See: https://github.com/BurntSushi/ripgrep/releases -ifeq ($(OS),Darwin) - ifeq ($(ARCH),arm64) - RG_TARGET := aarch64-apple-darwin - else ifeq ($(ARCH),x86_64) - RG_TARGET := x86_64-apple-darwin - else - $(error Unsupported macOS architecture: $(ARCH)) - endif -else ifeq ($(OS),Linux) - ifeq ($(ARCH),x86_64) - RG_TARGET := x86_64-unknown-linux-musl - else ifeq ($(ARCH),aarch64) - RG_TARGET := aarch64-unknown-linux-gnu - else ifeq ($(ARCH),armv7l) - RG_TARGET := arm-unknown-linux-gnueabihf - else - $(error Unsupported Linux architecture: $(ARCH)) - endif -else ifneq (,$(filter MSYS% MINGW%,$(OS))) - ifeq ($(ARCH),x86_64) - RG_TARGET := x86_64-pc-windows-msvc - else ifeq ($(ARCH),aarch64) - RG_TARGET := aarch64-pc-windows-msvc - else - $(error Unsupported Windows architecture: $(ARCH)) - endif - RG_ARCHIVE_EXT := zip - RG_ARCHIVE_BIN := rg.exe - RG_BIN_SUFFIX := .exe -else - $(error Unsupported OS: $(OS)) -endif - -RG_ARCHIVE := ripgrep-$(RG_VERSION)-$(RG_TARGET).$(RG_ARCHIVE_EXT) -RG_URL := https://github.com/BurntSushi/ripgrep/releases/download/$(RG_VERSION)/$(RG_ARCHIVE) - - -.PHONY: download-ripgrep -download-ripgrep: - @echo "==> Ensuring ripgrep is installed" - @if [ -f "$(BIN_DIR)/rg$(RG_BIN_SUFFIX)" ]; then \ - echo "rg already installed at $(BIN_DIR)/rg$(RG_BIN_SUFFIX)"; \ - else \ - echo "Downloading ripgrep $(RG_VERSION) from: $(RG_URL)"; \ - mkdir -p "$(BIN_DIR)" "$(TMP_DIR)"; \ - ARCHIVE_PATH="$(TMP_DIR)/$(RG_ARCHIVE)"; \ - if command -v curl >/dev/null 2>&1; then \ - curl -L --fail -o "$$ARCHIVE_PATH" "$(RG_URL)"; \ - else \ - if command -v wget >/dev/null 2>&1; then \ - wget -O "$$ARCHIVE_PATH" "$(RG_URL)"; \ - else \ - echo "Error: neither curl nor wget is available" && exit 1; \ - fi; \ - fi; \ - if [ "$(RG_ARCHIVE_EXT)" = "zip" ]; then \ - ARCHIVE_PATH="$$ARCHIVE_PATH" TMP_DIR="$(TMP_DIR)" python -c "import os, zipfile; zipfile.ZipFile(os.environ['ARCHIVE_PATH']).extractall(os.environ['TMP_DIR'])"; \ - else \ - tar -xzf "$$ARCHIVE_PATH" -C "$(TMP_DIR)"; \ - fi; \ - SRC_PATH="$(TMP_DIR)/ripgrep-$(RG_VERSION)-$(RG_TARGET)/$(RG_ARCHIVE_BIN)"; \ - DST_PATH="$(BIN_DIR)/rg$(RG_BIN_SUFFIX)"; \ - cp "$$SRC_PATH" "$$DST_PATH"; \ - chmod +x "$$DST_PATH"; \ - echo "rg installed at $$DST_PATH"; \ - fi - - -.PHONY: download-deps -download-deps: download-ripgrep diff --git a/src/kimi_cli/exception.py b/src/kimi_cli/exception.py deleted file mode 100644 index d8c21c75e..000000000 --- a/src/kimi_cli/exception.py +++ /dev/null @@ -1,43 +0,0 @@ -from __future__ import annotations - - -class KimiCLIException(Exception): - """Base exception class for Kimi Code CLI.""" - - pass - - -class ConfigError(KimiCLIException, ValueError): - """Configuration error.""" - - pass - - -class AgentSpecError(KimiCLIException, ValueError): - """Agent specification error.""" - - pass - - -class InvalidToolError(KimiCLIException, ValueError): - """Invalid tool error.""" - - pass - - -class SystemPromptTemplateError(KimiCLIException, ValueError): - """System prompt template error.""" - - pass - - -class MCPConfigError(KimiCLIException, ValueError): - """MCP config error.""" - - pass - - -class MCPRuntimeError(KimiCLIException, RuntimeError): - """MCP runtime error.""" - - pass diff --git a/src/kimi_cli/exception.ts b/src/kimi_cli/exception.ts new file mode 100644 index 000000000..a158b9d15 --- /dev/null +++ b/src/kimi_cli/exception.ts @@ -0,0 +1,60 @@ +/** + * Exception hierarchy — corresponds to Python exception.py + * All custom error classes for kimi-cli. + */ + +/** Base exception class for Kimi Code CLI. */ +export class KimiCLIException extends Error { + constructor(message: string) { + super(message); + this.name = "KimiCLIException"; + } +} + +/** Configuration error. */ +export class ConfigError extends KimiCLIException { + constructor(message: string) { + super(message); + this.name = "ConfigError"; + } +} + +/** Agent specification error. */ +export class AgentSpecError extends KimiCLIException { + constructor(message: string) { + super(message); + this.name = "AgentSpecError"; + } +} + +/** Invalid tool error. */ +export class InvalidToolError extends KimiCLIException { + constructor(message: string) { + super(message); + this.name = "InvalidToolError"; + } +} + +/** System prompt template error. */ +export class SystemPromptTemplateError extends KimiCLIException { + constructor(message: string) { + super(message); + this.name = "SystemPromptTemplateError"; + } +} + +/** MCP config error. */ +export class MCPConfigError extends KimiCLIException { + constructor(message: string) { + super(message); + this.name = "MCPConfigError"; + } +} + +/** MCP runtime error. */ +export class MCPRuntimeError extends KimiCLIException { + constructor(message: string) { + super(message); + this.name = "MCPRuntimeError"; + } +} diff --git a/src/kimi_cli/hooks/__init__.py b/src/kimi_cli/hooks/__init__.py deleted file mode 100644 index 1bb365e2d..000000000 --- a/src/kimi_cli/hooks/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from kimi_cli.hooks.config import HOOK_EVENT_TYPES, HookDef, HookEventType -from kimi_cli.hooks.engine import HookEngine - -__all__ = ["HookDef", "HookEventType", "HOOK_EVENT_TYPES", "HookEngine"] diff --git a/src/kimi_cli/hooks/config.py b/src/kimi_cli/hooks/config.py deleted file mode 100644 index ea6963371..000000000 --- a/src/kimi_cli/hooks/config.py +++ /dev/null @@ -1,34 +0,0 @@ -from typing import Literal - -from pydantic import BaseModel, Field - -HookEventType = Literal[ - "PreToolUse", - "PostToolUse", - "PostToolUseFailure", - "UserPromptSubmit", - "Stop", - "StopFailure", - "SessionStart", - "SessionEnd", - "SubagentStart", - "SubagentStop", - "PreCompact", - "PostCompact", - "Notification", -] - -HOOK_EVENT_TYPES: list[str] = list(HookEventType.__args__) # type: ignore[attr-defined] - - -class HookDef(BaseModel): - """A single hook definition in config.toml.""" - - event: HookEventType - """Which lifecycle event triggers this hook.""" - command: str - """Shell command to execute. Receives JSON on stdin.""" - matcher: str = "" - """Regex pattern to filter. Empty matches everything.""" - timeout: int = Field(default=30, ge=1, le=600) - """Timeout in seconds. Fail-open on timeout.""" diff --git a/src/kimi_cli/hooks/config.ts b/src/kimi_cli/hooks/config.ts new file mode 100644 index 000000000..54b103728 --- /dev/null +++ b/src/kimi_cli/hooks/config.ts @@ -0,0 +1,23 @@ +/** + * Hook configuration — corresponds to Python hooks/config.py + * HookDef and HookEventType are already defined in config.ts, + * re-exported here for convenience. + */ + +export { HookDef, HookEventType } from "../config.ts"; + +export const HOOK_EVENT_TYPES: string[] = [ + "PreToolUse", + "PostToolUse", + "PostToolUseFailure", + "UserPromptSubmit", + "Stop", + "StopFailure", + "SessionStart", + "SessionEnd", + "SubagentStart", + "SubagentStop", + "PreCompact", + "PostCompact", + "Notification", +]; diff --git a/src/kimi_cli/hooks/engine.py b/src/kimi_cli/hooks/engine.py deleted file mode 100644 index 20202485c..000000000 --- a/src/kimi_cli/hooks/engine.py +++ /dev/null @@ -1,310 +0,0 @@ -from __future__ import annotations - -import asyncio -import re -import time -import uuid -from collections.abc import Awaitable, Callable -from dataclasses import dataclass, field -from typing import Any - -from kimi_cli import logger -from kimi_cli.hooks.config import HookDef, HookEventType -from kimi_cli.hooks.runner import HookResult, run_hook - -# Callback signatures for wire integration -type OnTriggered = Callable[[str, str, int], None] -"""(event, target, hook_count) -> None""" - -type OnResolved = Callable[[str, str, str, str, int], None] -"""(event, target, action, reason, duration_ms) -> None""" - -type OnWireHookRequest = Callable[[WireHookHandle], Awaitable[None]] -"""Called when a wire hook needs client handling. The callback should send -the request over the wire and resolve the handle when the client responds.""" - - -@dataclass -class WireHookSubscription: - """A client-side hook subscription registered via wire initialize.""" - - id: str - event: str - matcher: str = "" - timeout: int = 30 - - -@dataclass -class WireHookHandle: - """A pending wire hook request waiting for client response.""" - - id: str = field(default_factory=lambda: uuid.uuid4().hex[:12]) - subscription_id: str = "" - event: str = "" - target: str = "" - input_data: dict[str, Any] = field(default_factory=lambda: {}) - _future: asyncio.Future[HookResult] | None = field(default=None, repr=False) - - def _get_future(self) -> asyncio.Future[HookResult]: - if self._future is None: - self._future = asyncio.get_event_loop().create_future() - return self._future - - async def wait(self) -> HookResult: - """Wait for the client to respond.""" - return await self._get_future() - - def resolve(self, action: str = "allow", reason: str = "") -> None: - """Resolve with client's decision.""" - result = HookResult(action=action, reason=reason) # type: ignore[arg-type] - future = self._get_future() - if not future.done(): - future.set_result(result) - - -class HookEngine: - """Loads hook definitions and executes matching hooks in parallel. - - Supports two hook sources: - - Server-side (config.toml): shell commands executed locally - - Client-side (wire subscriptions): forwarded to client via HookRequest - """ - - def __init__( - self, - hooks: list[HookDef] | None = None, - cwd: str | None = None, - *, - on_triggered: OnTriggered | None = None, - on_resolved: OnResolved | None = None, - on_wire_hook: OnWireHookRequest | None = None, - ): - self._hooks: list[HookDef] = list(hooks) if hooks else [] - self._wire_subs: list[WireHookSubscription] = [] - self._cwd = cwd - self._on_triggered = on_triggered - self._on_resolved = on_resolved - self._on_wire_hook = on_wire_hook - self._by_event: dict[str, list[HookDef]] = {} - self._wire_by_event: dict[str, list[WireHookSubscription]] = {} - self._rebuild_index() - - def _rebuild_index(self) -> None: - self._by_event.clear() - for h in self._hooks: - self._by_event.setdefault(h.event, []).append(h) - self._wire_by_event.clear() - for s in self._wire_subs: - self._wire_by_event.setdefault(s.event, []).append(s) - - def add_hooks(self, hooks: list[HookDef]) -> None: - """Add server-side hooks at runtime. Rebuilds index.""" - self._hooks.extend(hooks) - self._rebuild_index() - - def add_wire_subscriptions(self, subs: list[WireHookSubscription]) -> None: - """Register client-side hook subscriptions from wire initialize.""" - self._wire_subs.extend(subs) - self._rebuild_index() - - def set_callbacks( - self, - on_triggered: OnTriggered | None = None, - on_resolved: OnResolved | None = None, - on_wire_hook: OnWireHookRequest | None = None, - ) -> None: - """Set wire event callbacks.""" - self._on_triggered = on_triggered - self._on_resolved = on_resolved - self._on_wire_hook = on_wire_hook - - @property - def has_hooks(self) -> bool: - return bool(self._hooks) or bool(self._wire_subs) - - def has_hooks_for(self, event: HookEventType) -> bool: - return bool(self._by_event.get(event)) or bool(self._wire_by_event.get(event)) - - @property - def summary(self) -> dict[str, int]: - """Event -> total count of hooks (server + wire).""" - counts: dict[str, int] = {} - for event, hooks in self._by_event.items(): - counts[event] = counts.get(event, 0) + len(hooks) - for event, subs in self._wire_by_event.items(): - counts[event] = counts.get(event, 0) + len(subs) - return counts - - def details(self) -> dict[str, list[dict[str, str]]]: - """Event -> list of {matcher, command/type} for display.""" - result: dict[str, list[dict[str, str]]] = {} - for event, hooks in self._by_event.items(): - entries = result.setdefault(event, []) - for h in hooks: - entries.append( - { - "matcher": h.matcher or "(all)", - "source": "server", - "command": h.command, - } - ) - for event, subs in self._wire_by_event.items(): - entries = result.setdefault(event, []) - for s in subs: - entries.append( - { - "matcher": s.matcher or "(all)", - "source": "wire", - "command": "(client-side)", - } - ) - return result - - def _match_regex(self, pattern: str, value: str) -> bool: - if not pattern: - return True - try: - return bool(re.search(pattern, value)) - except re.error: - logger.warning("Invalid regex in hook matcher: {}", pattern) - return False - - async def trigger( - self, - event: HookEventType, - *, - matcher_value: str = "", - input_data: dict[str, Any], - ) -> list[HookResult]: - """Run all matching hooks (server + wire) in parallel.""" - - # --- Match server-side hooks --- - seen_commands: set[str] = set() - server_matched: list[HookDef] = [] - for h in self._by_event.get(event, []): - if not self._match_regex(h.matcher, matcher_value): - continue - if h.command in seen_commands: - continue - seen_commands.add(h.command) - server_matched.append(h) - - # --- Match wire subscriptions --- - wire_matched: list[WireHookSubscription] = [] - for s in self._wire_by_event.get(event, []): - if not self._match_regex(s.matcher, matcher_value): - continue - wire_matched.append(s) - - total = len(server_matched) + len(wire_matched) - if total == 0: - return [] - - try: - return await self._execute_hooks( - event, matcher_value, server_matched, wire_matched, input_data - ) - except Exception: - logger.warning("Hook engine error for {}, failing open", event) - return [] - - async def _execute_hooks( - self, - event: str, - matcher_value: str, - server_matched: list[HookDef], - wire_matched: list[WireHookSubscription], - input_data: dict[str, Any], - ) -> list[HookResult]: - """Run matched hooks and emit wire events. Separated for fail-open wrapping.""" - total = len(server_matched) + len(wire_matched) - logger.debug( - "Triggering {} hooks for {} ({} server, {} wire)", - total, - event, - len(server_matched), - len(wire_matched), - ) - - # --- HookTriggered --- - if self._on_triggered: - try: - self._on_triggered(event, matcher_value, total) - except Exception: - logger.debug("HookTriggered callback failed, continuing") - - t0 = time.monotonic() - tasks: list[asyncio.Task[HookResult]] = [] - - # Server-side: run shell commands - for h in server_matched: - tasks.append( - asyncio.create_task( - run_hook(h.command, input_data, timeout=h.timeout, cwd=self._cwd) - ) - ) - - # Wire-side: send request to client, wait for response - for s in wire_matched: - tasks.append( - asyncio.create_task( - self._dispatch_wire_hook( - s.id, event, matcher_value, input_data, timeout=s.timeout - ) - ) - ) - - results = list(await asyncio.gather(*tasks)) - duration_ms = int((time.monotonic() - t0) * 1000) - - # Aggregate: block if any hook blocked - action = "allow" - reason = "" - for r in results: - if r.action == "block": - action = "block" - reason = r.reason - break - - # --- HookResolved --- - if self._on_resolved: - try: - self._on_resolved(event, matcher_value, action, reason, duration_ms) - except Exception: - logger.debug("HookResolved callback failed, continuing") - - return results - - async def _dispatch_wire_hook( - self, - subscription_id: str, - event: str, - target: str, - input_data: dict[str, Any], - *, - timeout: int = 30, - ) -> HookResult: - """Send a hook request to the wire client and wait for response.""" - if not self._on_wire_hook: - return HookResult(action="allow") - - handle = WireHookHandle( - subscription_id=subscription_id, - event=event, - target=target, - input_data=input_data, - ) - # Run the callback in background so timeout applies to the - # full client round-trip, not just handle.wait(). - hook_task: asyncio.Task[None] = asyncio.ensure_future(self._on_wire_hook(handle)) - hook_task.add_done_callback(lambda t: t.exception() if not t.cancelled() else None) - try: - return await asyncio.wait_for(handle.wait(), timeout=timeout) - except TimeoutError: - hook_task.cancel() - logger.warning("Wire hook timed out: {} {}", event, target) - return HookResult(action="allow", timed_out=True) - except Exception as e: - hook_task.cancel() - logger.warning("Wire hook failed: {} {}: {}", event, target, e) - return HookResult(action="allow") diff --git a/src/kimi_cli/hooks/engine.ts b/src/kimi_cli/hooks/engine.ts new file mode 100644 index 000000000..826327292 --- /dev/null +++ b/src/kimi_cli/hooks/engine.ts @@ -0,0 +1,352 @@ +/** + * Hook engine — corresponds to Python hooks/engine.py + * Runs matching hooks (shell commands) in parallel on lifecycle events. + */ + +import type { HookDef, HookEventType } from "./config.ts"; +import { logger } from "../utils/logging.ts"; + +// ── Types ─────────────────────────────────────────────── + +export interface HookResult { + action: "allow" | "block"; + reason: string; + stdout?: string; + stderr?: string; + exitCode?: number; + timedOut?: boolean; +} + +export interface WireHookSubscription { + id: string; + event: string; + matcher: string; + timeout: number; +} + +export type OnTriggered = (event: string, target: string, hookCount: number) => void; +export type OnResolved = (event: string, target: string, action: string, reason: string, durationMs: number) => void; +export type OnWireHookRequest = (handle: WireHookHandle) => Promise; + +// ── Wire hook handle ──────────────────────────────────── + +let _handleIdCounter = 0; + +export class WireHookHandle { + readonly id: string; + readonly subscriptionId: string; + readonly event: string; + readonly target: string; + readonly inputData: Record; + + private _resolve?: (result: HookResult) => void; + private _promise: Promise; + + constructor(opts: { + subscriptionId: string; + event: string; + target: string; + inputData: Record; + }) { + this.id = `wh${(++_handleIdCounter).toString(36)}`; + this.subscriptionId = opts.subscriptionId; + this.event = opts.event; + this.target = opts.target; + this.inputData = opts.inputData; + this._promise = new Promise((resolve) => { + this._resolve = resolve; + }); + } + + wait(): Promise { + return this._promise; + } + + resolve(action: "allow" | "block" = "allow", reason = ""): void { + this._resolve?.({ action, reason }); + } +} + +// ── Hook runner ───────────────────────────────────────── + +async function runHook( + command: string, + inputData: Record, + opts?: { timeout?: number; cwd?: string }, +): Promise { + const timeout = (opts?.timeout ?? 30) * 1000; + try { + const proc = Bun.spawn(["sh", "-c", command], { + stdin: new Blob([JSON.stringify(inputData)]), + stdout: "pipe", + stderr: "pipe", + cwd: opts?.cwd, + }); + + const timer = setTimeout(() => proc.kill(), timeout); + + const exitCode = await proc.exited; + clearTimeout(timer); + + const stdout = await new Response(proc.stdout).text(); + const stderr = await new Response(proc.stderr).text(); + + // Exit 2 = block + if (exitCode === 2) { + return { + action: "block", + reason: stderr.trim(), + stdout, + stderr, + exitCode: 2, + }; + } + + // Exit 0 + JSON stdout = structured decision + if (exitCode === 0 && stdout.trim()) { + try { + const parsed = JSON.parse(stdout.trim()); + if (parsed && typeof parsed === "object") { + // Direct action field (e.g. {"action":"block","reason":"..."}) + if (parsed.action === "block") { + return { + action: "block", + reason: String(parsed.reason ?? ""), + stdout, + stderr, + exitCode: 0, + }; + } + // Claude Code-style hookSpecificOutput + const hookOutput = parsed.hookSpecificOutput; + if (hookOutput?.permissionDecision === "deny") { + return { + action: "block", + reason: String(hookOutput.permissionDecisionReason ?? ""), + stdout, + stderr, + exitCode: 0, + }; + } + } + } catch { + // Not JSON — that's fine + } + } + + return { action: "allow", reason: "", stdout, stderr, exitCode: exitCode ?? 0 }; + } catch { + return { action: "allow", reason: "" }; + } +} + +// ── Engine ────────────────────────────────────────────── + +export class HookEngine { + private hooks: HookDef[]; + private wireSubs: WireHookSubscription[] = []; + private cwd?: string; + private onTriggered?: OnTriggered; + private onResolved?: OnResolved; + private onWireHook?: OnWireHookRequest; + private byEvent = new Map(); + private wireByEvent = new Map(); + + constructor(opts?: { + hooks?: HookDef[]; + cwd?: string; + onTriggered?: OnTriggered; + onResolved?: OnResolved; + onWireHook?: OnWireHookRequest; + }) { + this.hooks = opts?.hooks ? [...opts.hooks] : []; + this.cwd = opts?.cwd; + this.onTriggered = opts?.onTriggered; + this.onResolved = opts?.onResolved; + this.onWireHook = opts?.onWireHook; + this.rebuildIndex(); + } + + private rebuildIndex(): void { + this.byEvent.clear(); + for (const h of this.hooks) { + const list = this.byEvent.get(h.event) ?? []; + list.push(h); + this.byEvent.set(h.event, list); + } + this.wireByEvent.clear(); + for (const s of this.wireSubs) { + const list = this.wireByEvent.get(s.event) ?? []; + list.push(s); + this.wireByEvent.set(s.event, list); + } + } + + addHooks(hooks: HookDef[]): void { + this.hooks.push(...hooks); + this.rebuildIndex(); + } + + addWireSubscriptions(subs: WireHookSubscription[]): void { + this.wireSubs.push(...subs); + this.rebuildIndex(); + } + + setCallbacks(opts: { onTriggered?: OnTriggered; onResolved?: OnResolved; onWireHook?: OnWireHookRequest }): void { + this.onTriggered = opts.onTriggered; + this.onResolved = opts.onResolved; + this.onWireHook = opts.onWireHook; + } + + get hasHooks(): boolean { + return this.hooks.length > 0 || this.wireSubs.length > 0; + } + + hasHooksFor(event: HookEventType): boolean { + return (this.byEvent.get(event)?.length ?? 0) > 0 || (this.wireByEvent.get(event)?.length ?? 0) > 0; + } + + get summary(): Record { + const counts: Record = {}; + for (const [event, hooks] of this.byEvent) { + counts[event] = (counts[event] ?? 0) + hooks.length; + } + for (const [event, subs] of this.wireByEvent) { + counts[event] = (counts[event] ?? 0) + subs.length; + } + return counts; + } + + private matchRegex(pattern: string, value: string): boolean { + if (!pattern) return true; + try { + return new RegExp(pattern).test(value); + } catch { + logger.warn(`Invalid regex in hook matcher: ${pattern}`); + return false; + } + } + + async trigger( + event: HookEventType, + opts: { matcherValue?: string; inputData: Record }, + ): Promise { + const matcherValue = opts.matcherValue ?? ""; + + // Match server-side hooks + const seenCommands = new Set(); + const serverMatched: HookDef[] = []; + for (const h of this.byEvent.get(event) ?? []) { + if (!this.matchRegex(h.matcher, matcherValue)) continue; + if (seenCommands.has(h.command)) continue; + seenCommands.add(h.command); + serverMatched.push(h); + } + + // Match wire subscriptions + const wireMatched: WireHookSubscription[] = []; + for (const s of this.wireByEvent.get(event) ?? []) { + if (!this.matchRegex(s.matcher, matcherValue)) continue; + wireMatched.push(s); + } + + const total = serverMatched.length + wireMatched.length; + if (total === 0) return []; + + try { + return await this.executeHooks(event, matcherValue, serverMatched, wireMatched, opts.inputData); + } catch { + logger.warn(`Hook engine error for ${event}, failing open`); + return []; + } + } + + private async executeHooks( + event: string, + matcherValue: string, + serverMatched: HookDef[], + wireMatched: WireHookSubscription[], + inputData: Record, + ): Promise { + const total = serverMatched.length + wireMatched.length; + + if (this.onTriggered) { + try { + this.onTriggered(event, matcherValue, total); + } catch { + // ignore + } + } + + const t0 = performance.now(); + + // Server-side: run shell commands + const tasks: Promise[] = serverMatched.map((h) => + runHook(h.command, inputData, { timeout: h.timeout, cwd: this.cwd }), + ); + + // Wire-side: dispatch to client + for (const s of wireMatched) { + tasks.push(this.dispatchWireHook(s.id, event, matcherValue, inputData, s.timeout)); + } + + const results = await Promise.all(tasks); + const durationMs = Math.round(performance.now() - t0); + + let action = "allow"; + let reason = ""; + for (const r of results) { + if (r.action === "block") { + action = "block"; + reason = r.reason; + break; + } + } + + if (this.onResolved) { + try { + this.onResolved(event, matcherValue, action, reason, durationMs); + } catch { + // ignore + } + } + + return results; + } + + private async dispatchWireHook( + subscriptionId: string, + event: string, + target: string, + inputData: Record, + timeout: number = 30, + ): Promise { + if (!this.onWireHook) { + return { action: "allow", reason: "" }; + } + + const handle = new WireHookHandle({ + subscriptionId, + event, + target, + inputData, + }); + + const hookPromise = this.onWireHook(handle); + hookPromise.catch(() => {}); // Suppress unhandled rejection + + try { + const timeoutMs = timeout * 1000; + const result = await Promise.race([ + handle.wait(), + new Promise((_, reject) => + setTimeout(() => reject(new Error("timeout")), timeoutMs), + ), + ]); + return result; + } catch { + logger.warn(`Wire hook timed out: ${event} ${target}`); + return { action: "allow", reason: "", timedOut: true }; + } + } +} diff --git a/src/kimi_cli/hooks/events.py b/src/kimi_cli/hooks/events.py deleted file mode 100644 index 66e447a41..000000000 --- a/src/kimi_cli/hooks/events.py +++ /dev/null @@ -1,190 +0,0 @@ -"""Input payload builders for each hook event type.""" - -from __future__ import annotations - -from typing import Any - - -def _base(event: str, session_id: str, cwd: str) -> dict[str, Any]: - return {"hook_event_name": event, "session_id": session_id, "cwd": cwd} - - -def pre_tool_use( - *, - session_id: str, - cwd: str, - tool_name: str, - tool_input: dict[str, Any], - tool_call_id: str = "", -) -> dict[str, Any]: - return { - **_base("PreToolUse", session_id, cwd), - "tool_name": tool_name, - "tool_input": tool_input, - "tool_call_id": tool_call_id, - } - - -def post_tool_use( - *, - session_id: str, - cwd: str, - tool_name: str, - tool_input: dict[str, Any], - tool_output: str = "", - tool_call_id: str = "", -) -> dict[str, Any]: - return { - **_base("PostToolUse", session_id, cwd), - "tool_name": tool_name, - "tool_input": tool_input, - "tool_output": tool_output, - "tool_call_id": tool_call_id, - } - - -def post_tool_use_failure( - *, - session_id: str, - cwd: str, - tool_name: str, - tool_input: dict[str, Any], - error: str, - tool_call_id: str = "", -) -> dict[str, Any]: - return { - **_base("PostToolUseFailure", session_id, cwd), - "tool_name": tool_name, - "tool_input": tool_input, - "error": error, - "tool_call_id": tool_call_id, - } - - -def user_prompt_submit( - *, - session_id: str, - cwd: str, - prompt: str, -) -> dict[str, Any]: - return {**_base("UserPromptSubmit", session_id, cwd), "prompt": prompt} - - -def stop( - *, - session_id: str, - cwd: str, - stop_hook_active: bool = False, -) -> dict[str, Any]: - return { - **_base("Stop", session_id, cwd), - "stop_hook_active": stop_hook_active, - } - - -def stop_failure( - *, - session_id: str, - cwd: str, - error_type: str, - error_message: str, -) -> dict[str, Any]: - return { - **_base("StopFailure", session_id, cwd), - "error_type": error_type, - "error_message": error_message, - } - - -def session_start( - *, - session_id: str, - cwd: str, - source: str, -) -> dict[str, Any]: - return {**_base("SessionStart", session_id, cwd), "source": source} - - -def session_end( - *, - session_id: str, - cwd: str, - reason: str, -) -> dict[str, Any]: - return {**_base("SessionEnd", session_id, cwd), "reason": reason} - - -def subagent_start( - *, - session_id: str, - cwd: str, - agent_name: str, - prompt: str, -) -> dict[str, Any]: - return { - **_base("SubagentStart", session_id, cwd), - "agent_name": agent_name, - "prompt": prompt, - } - - -def subagent_stop( - *, - session_id: str, - cwd: str, - agent_name: str, - response: str = "", -) -> dict[str, Any]: - return { - **_base("SubagentStop", session_id, cwd), - "agent_name": agent_name, - "response": response, - } - - -def pre_compact( - *, - session_id: str, - cwd: str, - trigger: str, - token_count: int, -) -> dict[str, Any]: - return { - **_base("PreCompact", session_id, cwd), - "trigger": trigger, - "token_count": token_count, - } - - -def post_compact( - *, - session_id: str, - cwd: str, - trigger: str, - estimated_token_count: int, -) -> dict[str, Any]: - return { - **_base("PostCompact", session_id, cwd), - "trigger": trigger, - "estimated_token_count": estimated_token_count, - } - - -def notification( - *, - session_id: str, - cwd: str, - sink: str, - notification_type: str, - title: str = "", - body: str = "", - severity: str = "info", -) -> dict[str, Any]: - return { - **_base("Notification", session_id, cwd), - "sink": sink, - "notification_type": notification_type, - "title": title, - "body": body, - "severity": severity, - } diff --git a/src/kimi_cli/hooks/events.ts b/src/kimi_cli/hooks/events.ts new file mode 100644 index 000000000..c836c2655 --- /dev/null +++ b/src/kimi_cli/hooks/events.ts @@ -0,0 +1,176 @@ +/** + * Hook event payload builders — corresponds to Python hooks/events.py + * Each function returns a payload dict for a specific hook event. + */ + +function _base(event: string, sessionId: string, cwd: string): Record { + return { hook_event_name: event, session_id: sessionId, cwd }; +} + +export function preToolUse(opts: { + sessionId: string; + cwd: string; + toolName: string; + toolInput: Record; + toolCallId?: string; +}): Record { + return { + ..._base("PreToolUse", opts.sessionId, opts.cwd), + tool_name: opts.toolName, + tool_input: opts.toolInput, + tool_call_id: opts.toolCallId ?? "", + }; +} + +export function postToolUse(opts: { + sessionId: string; + cwd: string; + toolName: string; + toolInput: Record; + toolOutput?: string; + toolCallId?: string; +}): Record { + return { + ..._base("PostToolUse", opts.sessionId, opts.cwd), + tool_name: opts.toolName, + tool_input: opts.toolInput, + tool_output: opts.toolOutput ?? "", + tool_call_id: opts.toolCallId ?? "", + }; +} + +export function postToolUseFailure(opts: { + sessionId: string; + cwd: string; + toolName: string; + toolInput: Record; + error: string; + toolCallId?: string; +}): Record { + return { + ..._base("PostToolUseFailure", opts.sessionId, opts.cwd), + tool_name: opts.toolName, + tool_input: opts.toolInput, + error: opts.error, + tool_call_id: opts.toolCallId ?? "", + }; +} + +export function userPromptSubmit(opts: { + sessionId: string; + cwd: string; + prompt: string; +}): Record { + return { ..._base("UserPromptSubmit", opts.sessionId, opts.cwd), prompt: opts.prompt }; +} + +export function stop(opts: { + sessionId: string; + cwd: string; + stopHookActive?: boolean; +}): Record { + return { + ..._base("Stop", opts.sessionId, opts.cwd), + stop_hook_active: opts.stopHookActive ?? false, + }; +} + +export function stopFailure(opts: { + sessionId: string; + cwd: string; + errorType: string; + errorMessage: string; +}): Record { + return { + ..._base("StopFailure", opts.sessionId, opts.cwd), + error_type: opts.errorType, + error_message: opts.errorMessage, + }; +} + +export function sessionStart(opts: { + sessionId: string; + cwd: string; + source: string; +}): Record { + return { ..._base("SessionStart", opts.sessionId, opts.cwd), source: opts.source }; +} + +export function sessionEnd(opts: { + sessionId: string; + cwd: string; + reason: string; +}): Record { + return { ..._base("SessionEnd", opts.sessionId, opts.cwd), reason: opts.reason }; +} + +export function subagentStart(opts: { + sessionId: string; + cwd: string; + agentName: string; + prompt: string; +}): Record { + return { + ..._base("SubagentStart", opts.sessionId, opts.cwd), + agent_name: opts.agentName, + prompt: opts.prompt, + }; +} + +export function subagentStop(opts: { + sessionId: string; + cwd: string; + agentName: string; + response?: string; +}): Record { + return { + ..._base("SubagentStop", opts.sessionId, opts.cwd), + agent_name: opts.agentName, + response: opts.response ?? "", + }; +} + +export function preCompact(opts: { + sessionId: string; + cwd: string; + trigger: string; + tokenCount: number; +}): Record { + return { + ..._base("PreCompact", opts.sessionId, opts.cwd), + trigger: opts.trigger, + token_count: opts.tokenCount, + }; +} + +export function postCompact(opts: { + sessionId: string; + cwd: string; + trigger: string; + estimatedTokenCount: number; +}): Record { + return { + ..._base("PostCompact", opts.sessionId, opts.cwd), + trigger: opts.trigger, + estimated_token_count: opts.estimatedTokenCount, + }; +} + +export function notification(opts: { + sessionId: string; + cwd: string; + sink: string; + notificationType: string; + title?: string; + body?: string; + severity?: string; +}): Record { + return { + ..._base("Notification", opts.sessionId, opts.cwd), + sink: opts.sink, + notification_type: opts.notificationType, + title: opts.title ?? "", + body: opts.body ?? "", + severity: opts.severity ?? "info", + }; +} diff --git a/src/kimi_cli/hooks/runner.py b/src/kimi_cli/hooks/runner.py deleted file mode 100644 index 3c81d6d27..000000000 --- a/src/kimi_cli/hooks/runner.py +++ /dev/null @@ -1,89 +0,0 @@ -from __future__ import annotations - -import asyncio -import json -from dataclasses import dataclass -from typing import Any, Literal, cast - -from kimi_cli import logger - - -@dataclass -class HookResult: - """Result of a single hook execution.""" - - action: Literal["allow", "block"] = "allow" - reason: str = "" - stdout: str = "" - stderr: str = "" - exit_code: int = 0 - timed_out: bool = False - - -async def run_hook( - command: str, - input_data: dict[str, Any], - *, - timeout: int = 30, - cwd: str | None = None, -) -> HookResult: - """Execute a single hook command. Fail-open: errors/timeouts -> allow.""" - try: - proc = await asyncio.create_subprocess_shell( - command, - stdin=asyncio.subprocess.PIPE, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - cwd=cwd, - ) - try: - stdout_bytes, stderr_bytes = await asyncio.wait_for( - proc.communicate(input=json.dumps(input_data).encode()), - timeout=timeout, - ) - except TimeoutError: - proc.kill() - await proc.wait() - logger.warning("Hook timed out after {}s: {}", timeout, command) - return HookResult(action="allow", timed_out=True) - except asyncio.CancelledError: - proc.kill() - await proc.wait() - raise - except Exception as e: - logger.warning("Hook failed: {}: {}", command, e) - return HookResult(action="allow", stderr=str(e)) - - stdout = stdout_bytes.decode(errors="replace") - stderr = stderr_bytes.decode(errors="replace") - exit_code = proc.returncode or 0 - - # Exit 2 = block - if exit_code == 2: - return HookResult( - action="block", - reason=stderr.strip(), - stdout=stdout, - stderr=stderr, - exit_code=2, - ) - - # Exit 0 + JSON stdout = structured decision - if exit_code == 0 and stdout.strip(): - try: - raw = json.loads(stdout) - if isinstance(raw, dict): - parsed = cast(dict[str, Any], raw) - hook_output = cast(dict[str, Any], parsed.get("hookSpecificOutput", {})) - if hook_output.get("permissionDecision") == "deny": - return HookResult( - action="block", - reason=str(hook_output.get("permissionDecisionReason", "")), - stdout=stdout, - stderr=stderr, - exit_code=0, - ) - except (json.JSONDecodeError, TypeError): - pass - - return HookResult(action="allow", stdout=stdout, stderr=stderr, exit_code=exit_code) diff --git a/src/kimi_cli/index.ts b/src/kimi_cli/index.ts new file mode 100644 index 000000000..c5bfcbe9b --- /dev/null +++ b/src/kimi_cli/index.ts @@ -0,0 +1,10 @@ +#!/usr/bin/env bun +/** + * Kimi CLI - AI Agent for Terminal + * Entry point (corresponds to Python __main__.py) + */ + +import { cli } from "./cli/index.ts"; + +const exitCode = await cli(process.argv); +process.exit(exitCode); diff --git a/src/kimi_cli/llm.py b/src/kimi_cli/llm.py deleted file mode 100644 index fba2f10e9..000000000 --- a/src/kimi_cli/llm.py +++ /dev/null @@ -1,301 +0,0 @@ -from __future__ import annotations - -import json -import os -from dataclasses import dataclass -from pathlib import Path -from typing import TYPE_CHECKING, Literal, cast, get_args - -from kosong.chat_provider import ChatProvider -from pydantic import SecretStr - -from kimi_cli.constant import USER_AGENT - -if TYPE_CHECKING: - from kimi_cli.auth.oauth import OAuthManager - from kimi_cli.config import Config, LLMModel, LLMProvider - -type ProviderType = Literal[ - "kimi", - "openai_legacy", - "openai_responses", - "anthropic", - "google_genai", # for backward-compatibility, equals to `gemini` - "gemini", - "vertexai", - "_echo", - "_scripted_echo", - "_chaos", -] - -type ModelCapability = Literal["image_in", "video_in", "thinking", "always_thinking"] -ALL_MODEL_CAPABILITIES: set[ModelCapability] = set(get_args(ModelCapability.__value__)) - - -@dataclass(slots=True) -class LLM: - chat_provider: ChatProvider - max_context_size: int - capabilities: set[ModelCapability] - model_config: LLMModel | None = None - provider_config: LLMProvider | None = None - - @property - def model_name(self) -> str: - return self.chat_provider.model_name - - -def model_display_name(model_name: str | None) -> str: - if not model_name: - return "" - if model_name in ("kimi-for-coding", "kimi-code"): - return f"{model_name} (powered by kimi-k2.5)" - return model_name - - -def augment_provider_with_env_vars(provider: LLMProvider, model: LLMModel) -> dict[str, str]: - """Override provider/model settings from environment variables. - - Returns: - Mapping of environment variables that were applied. - """ - applied: dict[str, str] = {} - - match provider.type: - case "kimi": - if base_url := os.getenv("KIMI_BASE_URL"): - provider.base_url = base_url - applied["KIMI_BASE_URL"] = base_url - if api_key := os.getenv("KIMI_API_KEY"): - provider.api_key = SecretStr(api_key) - applied["KIMI_API_KEY"] = "******" - if model_name := os.getenv("KIMI_MODEL_NAME"): - model.model = model_name - applied["KIMI_MODEL_NAME"] = model_name - if max_context_size := os.getenv("KIMI_MODEL_MAX_CONTEXT_SIZE"): - model.max_context_size = int(max_context_size) - applied["KIMI_MODEL_MAX_CONTEXT_SIZE"] = max_context_size - if capabilities := os.getenv("KIMI_MODEL_CAPABILITIES"): - caps_lower = (cap.strip().lower() for cap in capabilities.split(",") if cap.strip()) - model.capabilities = set( - cast(ModelCapability, cap) - for cap in caps_lower - if cap in get_args(ModelCapability.__value__) - ) - applied["KIMI_MODEL_CAPABILITIES"] = capabilities - case "openai_legacy" | "openai_responses": - if base_url := os.getenv("OPENAI_BASE_URL"): - provider.base_url = base_url - if api_key := os.getenv("OPENAI_API_KEY"): - provider.api_key = SecretStr(api_key) - case _: - pass - - return applied - - -def _kimi_default_headers(provider: LLMProvider, oauth: OAuthManager | None) -> dict[str, str]: - headers = {"User-Agent": USER_AGENT} - if oauth: - headers.update(oauth.common_headers()) - if provider.custom_headers: - headers.update(provider.custom_headers) - return headers - - -def create_llm( - provider: LLMProvider, - model: LLMModel, - *, - thinking: bool | None = None, - session_id: str | None = None, - oauth: OAuthManager | None = None, -) -> LLM | None: - if provider.type not in {"_echo", "_scripted_echo"} and ( - not provider.base_url or not model.model - ): - return None - - resolved_api_key = ( - oauth.resolve_api_key(provider.api_key, provider.oauth) - if oauth and provider.oauth - else provider.api_key.get_secret_value() - ) - - match provider.type: - case "kimi": - from kosong.chat_provider.kimi import Kimi - - chat_provider = Kimi( - model=model.model, - base_url=provider.base_url, - api_key=resolved_api_key, - default_headers=_kimi_default_headers(provider, oauth), - ) - - gen_kwargs: Kimi.GenerationKwargs = {} - if session_id: - gen_kwargs["prompt_cache_key"] = session_id - if temperature := os.getenv("KIMI_MODEL_TEMPERATURE"): - gen_kwargs["temperature"] = float(temperature) - if top_p := os.getenv("KIMI_MODEL_TOP_P"): - gen_kwargs["top_p"] = float(top_p) - if max_tokens := os.getenv("KIMI_MODEL_MAX_TOKENS"): - gen_kwargs["max_tokens"] = int(max_tokens) - - if gen_kwargs: - chat_provider = chat_provider.with_generation_kwargs(**gen_kwargs) - case "openai_legacy": - from kosong.contrib.chat_provider.openai_legacy import OpenAILegacy - - chat_provider = OpenAILegacy( - model=model.model, - base_url=provider.base_url, - api_key=resolved_api_key, - ) - case "openai_responses": - from kosong.contrib.chat_provider.openai_responses import OpenAIResponses - - chat_provider = OpenAIResponses( - model=model.model, - base_url=provider.base_url, - api_key=resolved_api_key, - ) - case "anthropic": - from kosong.contrib.chat_provider.anthropic import Anthropic - - chat_provider = Anthropic( - model=model.model, - base_url=provider.base_url, - api_key=resolved_api_key, - default_max_tokens=50000, - metadata={"user_id": session_id} if session_id else None, - ) - case "google_genai" | "gemini": - from kosong.contrib.chat_provider.google_genai import GoogleGenAI - - chat_provider = GoogleGenAI( - model=model.model, - base_url=provider.base_url, - api_key=resolved_api_key, - ) - case "vertexai": - from kosong.contrib.chat_provider.google_genai import GoogleGenAI - - os.environ.update(provider.env or {}) - chat_provider = GoogleGenAI( - model=model.model, - base_url=provider.base_url, - api_key=resolved_api_key, - vertexai=True, - ) - case "_echo": - from kosong.chat_provider.echo import EchoChatProvider - - chat_provider = EchoChatProvider() - case "_scripted_echo": - from kosong.chat_provider.echo import ScriptedEchoChatProvider - - if provider.env: - os.environ.update(provider.env) - scripts = _load_scripted_echo_scripts() - trace_value = os.getenv("KIMI_SCRIPTED_ECHO_TRACE", "") - trace = trace_value.strip().lower() in {"1", "true", "yes", "on"} - chat_provider = ScriptedEchoChatProvider(scripts, trace=trace) - case "_chaos": - from kosong.chat_provider.chaos import ChaosChatProvider, ChaosConfig - from kosong.chat_provider.kimi import Kimi - - chat_provider = ChaosChatProvider( - provider=Kimi( - model=model.model, - base_url=provider.base_url, - api_key=resolved_api_key, - default_headers=_kimi_default_headers(provider, oauth), - ), - chaos_config=ChaosConfig( - error_probability=0.8, - error_types=[429, 500, 503], - ), - ) - - capabilities = derive_model_capabilities(model) - - # Apply thinking if specified or if model always requires thinking - if "always_thinking" in capabilities or (thinking is True and "thinking" in capabilities): - chat_provider = chat_provider.with_thinking("high") - elif thinking is False: - chat_provider = chat_provider.with_thinking("off") - # If thinking is None and model doesn't always think, leave as-is (default behavior) - - return LLM( - chat_provider=chat_provider, - max_context_size=model.max_context_size, - capabilities=capabilities, - model_config=model, - provider_config=provider, - ) - - -def clone_llm_with_model_alias( - llm: LLM | None, - config: Config, - model_alias: str | None, - *, - session_id: str, - oauth: OAuthManager | None, -) -> LLM | None: - if model_alias is None: - return llm - if model_alias not in config.models: - raise KeyError(f"Unknown model alias: {model_alias}") - model = config.models[model_alias] - provider = config.providers[model.provider] - thinking: bool | None = None - if llm is not None: - effort = getattr(llm.chat_provider, "thinking_effort", None) - if effort is not None: - thinking = effort != "off" - return create_llm( - provider, - model, - thinking=thinking, - session_id=session_id, - oauth=oauth, - ) - - -def derive_model_capabilities(model: LLMModel) -> set[ModelCapability]: - capabilities = set(model.capabilities or ()) - # Models with "thinking" in their name are always-thinking models - if "thinking" in model.model.lower() or "reason" in model.model.lower(): - capabilities.update(("thinking", "always_thinking")) - # These models support thinking but can be toggled on/off - elif model.model in {"kimi-for-coding", "kimi-code"}: - capabilities.update(("thinking", "image_in", "video_in")) - return capabilities - - -def _load_scripted_echo_scripts() -> list[str]: - script_path = os.getenv("KIMI_SCRIPTED_ECHO_SCRIPTS") - if not script_path: - raise ValueError("KIMI_SCRIPTED_ECHO_SCRIPTS is required for _scripted_echo.") - path = Path(script_path).expanduser() - if not path.exists(): - raise ValueError(f"Scripted echo file not found: {path}") - text = path.read_text(encoding="utf-8") - try: - data: object = json.loads(text) - except json.JSONDecodeError: - scripts = [chunk.strip() for chunk in text.split("\n---\n") if chunk.strip()] - if scripts: - return scripts - raise ValueError( - "Scripted echo file must be a JSON array of strings or a text file " - "split by '\\n---\\n'." - ) from None - if isinstance(data, list): - data_list = cast(list[object], data) - if all(isinstance(item, str) for item in data_list): - return cast(list[str], data_list) - raise ValueError("Scripted echo JSON must be an array of strings.") diff --git a/src/kimi_cli/llm.ts b/src/kimi_cli/llm.ts new file mode 100644 index 000000000..b75460aa7 --- /dev/null +++ b/src/kimi_cli/llm.ts @@ -0,0 +1,722 @@ +/** + * LLM abstraction layer — corresponds to Python's llm.py + * Provides a unified interface for multiple LLM providers. + */ + +import type { Message, ModelCapability, TokenUsage } from "./types"; + +// ── Provider Types ───────────────────────────────────────── + +export type ProviderType = + | "kimi" + | "openai_legacy" + | "openai_responses" + | "anthropic" + | "google_genai" + | "gemini" + | "vertexai" + | "_echo" + | "_scripted_echo" + | "_chaos"; + +// ── Stream Chunk Types ───────────────────────────────────── + +export interface TextChunk { + type: "text"; + text: string; +} + +export interface ThinkChunk { + type: "think"; + text: string; +} + +export interface ToolCallChunk { + type: "tool_call"; + id: string; + name: string; + arguments: string; +} + +export interface UsageChunk { + type: "usage"; + usage: TokenUsage; +} + +export interface DoneChunk { + type: "done"; + messageId?: string; +} + +export type StreamChunk = + | TextChunk + | ThinkChunk + | ToolCallChunk + | UsageChunk + | DoneChunk; + +// ── LLM Provider Interface ──────────────────────────────── + +export interface LLMProviderConfig { + type: ProviderType; + baseUrl: string; + apiKey: string; + customHeaders?: Record; + env?: Record; + oauth?: string | null; +} + +export interface LLMModelConfig { + model: string; + provider: string; + maxContextSize: number; + capabilities?: ModelCapability[]; +} + +export interface ChatOptions { + /** System prompt */ + system?: string; + /** Generation temperature */ + temperature?: number; + /** Top-p nucleus sampling */ + topP?: number; + /** Maximum output tokens */ + maxTokens?: number; + /** Enable/disable thinking */ + thinking?: "high" | "low" | "off"; + /** Tool definitions for the model */ + tools?: ToolDefinition[]; + /** Abort signal for cancellation */ + signal?: AbortSignal; +} + +export interface ToolDefinition { + name: string; + description: string; + parameters: Record; +} + +/** + * Abstract interface for LLM providers. + * Each provider (Anthropic, OpenAI, Kimi, etc.) implements this. + */ +export interface LLMProvider { + readonly modelName: string; + + /** + * Send a chat completion request and return a stream of chunks. + */ + chat( + messages: Message[], + options?: ChatOptions + ): AsyncIterable; +} + +// ── LLM Class ────────────────────────────────────────────── + +/** + * Wraps an LLM provider with model capabilities and context limits. + */ +export class LLM { + readonly provider: LLMProvider; + readonly maxContextSize: number; + readonly capabilities: Set; + readonly modelConfig: LLMModelConfig | null; + readonly providerConfig: LLMProviderConfig | null; + + constructor(opts: { + provider: LLMProvider; + maxContextSize: number; + capabilities: Set; + modelConfig?: LLMModelConfig | null; + providerConfig?: LLMProviderConfig | null; + }) { + this.provider = opts.provider; + this.maxContextSize = opts.maxContextSize; + this.capabilities = opts.capabilities; + this.modelConfig = opts.modelConfig ?? null; + this.providerConfig = opts.providerConfig ?? null; + } + + get modelName(): string { + return this.provider.modelName; + } + + /** + * Check if the model has a specific capability. + */ + hasCapability(cap: ModelCapability): boolean { + return this.capabilities.has(cap); + } + + /** + * Stream a chat completion. + */ + chat( + messages: Message[], + options?: ChatOptions + ): AsyncIterable { + return this.provider.chat(messages, options); + } +} + +// ── Model Display Name ───────────────────────────────────── + +export function modelDisplayName(modelName: string | null): string { + if (!modelName) return ""; + if (modelName === "kimi-for-coding" || modelName === "kimi-code") { + return `${modelName} (powered by kimi-k2.5)`; + } + return modelName; +} + +// ── Capability Detection ─────────────────────────────────── + +const ALL_MODEL_CAPABILITIES: Set = new Set([ + "image_in", + "video_in", + "thinking", + "always_thinking", +]); + +/** + * Derive model capabilities from model config. + */ +export function deriveModelCapabilities( + model: LLMModelConfig +): Set { + const capabilities = new Set(model.capabilities ?? []); + const lowerName = model.model.toLowerCase(); + + if (lowerName.includes("thinking") || lowerName.includes("reason")) { + capabilities.add("thinking"); + capabilities.add("always_thinking"); + } else if ( + model.model === "kimi-for-coding" || + model.model === "kimi-code" + ) { + capabilities.add("thinking"); + capabilities.add("image_in"); + capabilities.add("video_in"); + } + + return capabilities; +} + +// ── Environment Variable Overrides ───────────────────────── + +/** + * Override provider/model settings from environment variables. + * Returns a mapping of env vars that were applied. + */ +export function augmentProviderWithEnvVars( + provider: LLMProviderConfig, + model: LLMModelConfig +): Record { + const applied: Record = {}; + + switch (provider.type) { + case "kimi": { + const baseUrl = Bun.env.KIMI_BASE_URL; + if (baseUrl) { + provider.baseUrl = baseUrl; + applied["KIMI_BASE_URL"] = baseUrl; + } + const apiKey = Bun.env.KIMI_API_KEY; + if (apiKey) { + provider.apiKey = apiKey; + applied["KIMI_API_KEY"] = "******"; + } + const modelName = Bun.env.KIMI_MODEL_NAME; + if (modelName) { + model.model = modelName; + applied["KIMI_MODEL_NAME"] = modelName; + } + const maxCtx = Bun.env.KIMI_MODEL_MAX_CONTEXT_SIZE; + if (maxCtx) { + model.maxContextSize = parseInt(maxCtx, 10); + applied["KIMI_MODEL_MAX_CONTEXT_SIZE"] = maxCtx; + } + const caps = Bun.env.KIMI_MODEL_CAPABILITIES; + if (caps) { + const parsed = caps + .split(",") + .map((c) => c.trim().toLowerCase()) + .filter((c): c is ModelCapability => ALL_MODEL_CAPABILITIES.has(c as ModelCapability)); + model.capabilities = parsed; + applied["KIMI_MODEL_CAPABILITIES"] = caps; + } + break; + } + case "openai_legacy": + case "openai_responses": { + const baseUrl = Bun.env.OPENAI_BASE_URL; + if (baseUrl) provider.baseUrl = baseUrl; + const apiKey = Bun.env.OPENAI_API_KEY; + if (apiKey) provider.apiKey = apiKey; + break; + } + default: + break; + } + + return applied; +} + +// ── Token Estimation ─────────────────────────────────────── + +/** + * Simple token count estimation (~4 chars per token). + */ +export function estimateTokenCount(text: string): number { + return Math.ceil(text.length / 4); +} + +/** + * Estimate tokens for an array of messages. + */ +export function estimateMessagesTokenCount(messages: Message[]): number { + let total = 0; + for (const msg of messages) { + if (typeof msg.content === "string") { + total += estimateTokenCount(msg.content); + } else { + for (const part of msg.content) { + if ("text" in part) { + total += estimateTokenCount((part as { text: string }).text); + } + } + } + // Overhead per message (role, separators) + total += 4; + } + return total; +} + +// ── Factory (placeholder providers) ──────────────────────── + +/** + * Create an LLM instance from provider and model config. + */ +export function createLLM( + provider: LLMProviderConfig, + model: LLMModelConfig, + options?: { + thinking?: boolean | null; + sessionId?: string | null; + } +): LLM | null { + if ( + provider.type !== "_echo" && + provider.type !== "_scripted_echo" && + (!provider.baseUrl || !model.model) + ) { + return null; + } + + const capabilities = deriveModelCapabilities(model); + + // Determine thinking mode + let thinkingMode: "high" | "off" | undefined; + if ( + capabilities.has("always_thinking") || + (options?.thinking === true && capabilities.has("thinking")) + ) { + thinkingMode = "high"; + } else if (options?.thinking === false) { + thinkingMode = "off"; + } + + // Create real provider based on type + let llmProvider: LLMProvider; + + switch (provider.type) { + case "kimi": + case "openai_legacy": + case "openai_responses": + llmProvider = new OpenAICompatibleProvider({ + baseUrl: provider.baseUrl, + apiKey: provider.apiKey, + modelName: model.model, + customHeaders: provider.customHeaders, + thinkingMode, + }); + break; + + case "_echo": + llmProvider = { + modelName: model.model, + async *chat(messages: Message[]) { + const lastMsg = messages[messages.length - 1]; + const text = lastMsg + ? typeof lastMsg.content === "string" + ? lastMsg.content + : "[echo]" + : "[empty]"; + yield { type: "text" as const, text }; + yield { + type: "usage" as const, + usage: { inputTokens: 10, outputTokens: text.length }, + }; + yield { type: "done" as const }; + }, + }; + break; + + default: + llmProvider = { + modelName: model.model, + async *chat() { + throw new Error( + `LLM provider "${provider.type}" is not yet implemented. Model: ${model.model}` + ); + }, + }; + break; + } + + return new LLM({ + provider: llmProvider, + maxContextSize: model.maxContextSize, + capabilities, + modelConfig: model, + providerConfig: provider, + }); +} + +// ── OpenAI-Compatible Provider ────────────────────────────── +// Works with Kimi API, OpenAI API, and any OpenAI-compatible endpoint. + +interface OpenAICompatibleProviderConfig { + baseUrl: string; + apiKey: string; + modelName: string; + customHeaders?: Record; + thinkingMode?: "high" | "low" | "off"; +} + +interface OpenAIMessage { + role: "system" | "user" | "assistant" | "tool"; + content?: string | OpenAIContentPart[] | null; + tool_calls?: OpenAIToolCall[]; + tool_call_id?: string; +} + +interface OpenAIContentPart { + type: string; + text?: string; + image_url?: { url: string }; +} + +interface OpenAIToolCall { + id: string; + type: "function"; + function: { name: string; arguments: string }; +} + +interface OpenAITool { + type: "function"; + function: { + name: string; + description: string; + parameters: Record; + }; +} + +class OpenAICompatibleProvider implements LLMProvider { + readonly modelName: string; + private baseUrl: string; + private apiKey: string; + private customHeaders: Record; + private thinkingMode?: "high" | "low" | "off"; + + constructor(config: OpenAICompatibleProviderConfig) { + this.modelName = config.modelName; + this.baseUrl = config.baseUrl.replace(/\/+$/, ""); + this.apiKey = config.apiKey; + this.customHeaders = config.customHeaders ?? {}; + this.thinkingMode = config.thinkingMode; + } + + async *chat( + messages: Message[], + options?: ChatOptions, + ): AsyncIterable { + // Convert messages to OpenAI format + const openaiMessages = this.convertMessages(messages, options?.system); + + // Build request body + const body: Record = { + model: this.modelName, + messages: openaiMessages, + stream: true, + stream_options: { include_usage: true }, + }; + + // Default max_tokens if not provided (Python Kimi provider defaults to 32000) + body.max_tokens = options?.maxTokens ?? 32000; + if (options?.temperature != null) body.temperature = options.temperature; + if (options?.topP != null) body.top_p = options.topP; + + // Thinking mode configuration (Kimi-specific) + const thinking = options?.thinking ?? this.thinkingMode; + if (thinking && thinking !== "off") { + body.reasoning_effort = thinking; // "high" | "low" + body.thinking = { type: "enabled" }; + } else if (thinking === "off") { + body.thinking = { type: "disabled" }; + } + + // Tools + if (options?.tools && options.tools.length > 0) { + body.tools = options.tools.map( + (t): OpenAITool => ({ + type: "function", + function: { + name: t.name, + description: t.description, + parameters: t.parameters, + }, + }), + ); + } + + // Fetch streaming response + const url = `${this.baseUrl}/chat/completions`; + const headers: Record = { + "Content-Type": "application/json", + Authorization: `Bearer ${this.apiKey}`, + ...this.customHeaders, + }; + + const response = await fetch(url, { + method: "POST", + headers, + body: JSON.stringify(body), + signal: options?.signal, + }); + + if (!response.ok) { + const text = await response.text(); + throw new Error( + `LLM API error ${response.status}: ${text.slice(0, 500)}`, + ); + } + + if (!response.body) { + throw new Error("LLM API returned no body"); + } + + // Parse SSE stream + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ""; + let totalInputTokens = 0; + let totalOutputTokens = 0; + let cacheReadTokens = 0; + const pendingToolCalls = new Map< + number, + { id: string; name: string; arguments: string } + >(); + + try { + while (true) { + const { done, value } = await reader.read(); + if (done) break; + + buffer += decoder.decode(value, { stream: true }); + const lines = buffer.split("\n"); + buffer = lines.pop() ?? ""; + + for (const line of lines) { + const trimmed = line.trim(); + if (!trimmed || trimmed === "data: [DONE]" || trimmed === "data:[DONE]") continue; + // Support both "data: {...}" and "data:{...}" (Kimi API omits the space) + if (!trimmed.startsWith("data:")) continue; + const jsonStr = trimmed.startsWith("data: ") ? trimmed.slice(6) : trimmed.slice(5); + let data: any; + try { + data = JSON.parse(jsonStr); + } catch { + continue; + } + + // Extract usage if present (handle both standard and Kimi-specific formats) + if (data.usage) { + const u = data.usage; + totalInputTokens = u.prompt_tokens ?? u.input_tokens ?? 0; + totalOutputTokens = u.completion_tokens ?? u.output_tokens ?? 0; + // Kimi-specific: cached_tokens at root level + cacheReadTokens = u.cached_tokens ?? u.prompt_tokens_details?.cached_tokens ?? 0; + } + // Kimi may also embed usage in choice + if (data.choices?.[0]?.usage) { + const cu = data.choices[0].usage; + totalInputTokens = cu.prompt_tokens ?? totalInputTokens; + totalOutputTokens = cu.completion_tokens ?? totalOutputTokens; + cacheReadTokens = cu.cached_tokens ?? cacheReadTokens; + } + + const choices = data.choices; + if (!choices || choices.length === 0) continue; + + const delta = choices[0].delta; + if (!delta) continue; + + // Text content + if (delta.content) { + yield { type: "text", text: delta.content }; + } + + // Reasoning/thinking content (Kimi k2.5 specific) + if (delta.reasoning_content) { + yield { type: "think", text: delta.reasoning_content }; + } + + // Tool calls + if (delta.tool_calls) { + for (const tc of delta.tool_calls) { + const idx = tc.index ?? 0; + if (tc.id) { + // New tool call + pendingToolCalls.set(idx, { + id: tc.id, + name: tc.function?.name ?? "", + arguments: tc.function?.arguments ?? "", + }); + } else if (pendingToolCalls.has(idx)) { + // Append to existing + const existing = pendingToolCalls.get(idx)!; + if (tc.function?.name) existing.name += tc.function.name; + if (tc.function?.arguments) + existing.arguments += tc.function.arguments; + } + } + } + + // Check for finish reason + if (choices[0].finish_reason) { + // Emit any pending tool calls + for (const [, tc] of pendingToolCalls) { + yield { + type: "tool_call", + id: tc.id, + name: tc.name, + arguments: tc.arguments, + }; + } + pendingToolCalls.clear(); + } + } + } + } finally { + reader.releaseLock(); + } + + // Emit usage + if (totalInputTokens > 0 || totalOutputTokens > 0) { + yield { + type: "usage", + usage: { + inputTokens: totalInputTokens, + outputTokens: totalOutputTokens, + ...(cacheReadTokens > 0 ? { cacheReadTokens } : {}), + }, + }; + } + + yield { type: "done" }; + } + + private convertMessages( + messages: Message[], + system?: string, + ): OpenAIMessage[] { + const result: OpenAIMessage[] = []; + + // System prompt + if (system) { + result.push({ role: "system", content: system }); + } + + for (const msg of messages) { + if (typeof msg.content === "string") { + result.push({ + role: msg.role as "user" | "assistant" | "system", + content: msg.content, + }); + } else { + // Complex content with parts + const textParts: string[] = []; + const toolUseParts: OpenAIToolCall[] = []; + const toolResultParts: { toolCallId: string; content: string }[] = []; + + for (const part of msg.content) { + switch (part.type) { + case "text": + textParts.push(part.text); + break; + case "tool_use": + toolUseParts.push({ + id: part.id, + type: "function", + function: { + name: part.name, + arguments: JSON.stringify(part.input), + }, + }); + break; + case "tool_result": + toolResultParts.push({ + toolCallId: part.toolUseId, + content: part.content, + }); + break; + case "image": + // Skip images for now + break; + } + } + + if (msg.role === "assistant" && toolUseParts.length > 0) { + const assistantMsg: OpenAIMessage = { + role: "assistant", + content: textParts.join("\n") || null, + tool_calls: toolUseParts, + }; + // Preserve reasoning_content for multi-turn thinking + if ((msg as any).reasoning_content) { + (assistantMsg as any).reasoning_content = (msg as any).reasoning_content; + } + result.push(assistantMsg); + } else if (msg.role === "assistant") { + const assistantMsg: OpenAIMessage = { + role: "assistant", + content: textParts.join("\n") || null, + }; + // Preserve reasoning_content for multi-turn thinking + if ((msg as any).reasoning_content) { + (assistantMsg as any).reasoning_content = (msg as any).reasoning_content; + } + result.push(assistantMsg); + } else if (toolResultParts.length > 0) { + // Tool results become individual tool messages + for (const tr of toolResultParts) { + result.push({ + role: "tool", + tool_call_id: tr.toolCallId, + content: tr.content, + }); + } + } else { + result.push({ + role: msg.role as "user" | "assistant" | "system", + content: textParts.join("\n"), + }); + } + } + } + + return result; + } +} diff --git a/src/kimi_cli/metadata.py b/src/kimi_cli/metadata.py deleted file mode 100644 index 88242d89e..000000000 --- a/src/kimi_cli/metadata.py +++ /dev/null @@ -1,79 +0,0 @@ -from __future__ import annotations - -import json -from hashlib import md5 -from pathlib import Path - -from kaos import get_current_kaos -from kaos.local import local_kaos -from kaos.path import KaosPath -from pydantic import BaseModel, ConfigDict, Field - -from kimi_cli.share import get_share_dir -from kimi_cli.utils.io import atomic_json_write -from kimi_cli.utils.logging import logger - - -def get_metadata_file() -> Path: - return get_share_dir() / "kimi.json" - - -class WorkDirMeta(BaseModel): - """Metadata for a work directory.""" - - path: str - """The full path of the work directory.""" - - kaos: str = local_kaos.name - """The name of the KAOS where the work directory is located.""" - - last_session_id: str | None = None - """Last session ID of this work directory.""" - - @property - def sessions_dir(self) -> Path: - """The directory to store sessions for this work directory.""" - path_md5 = md5(self.path.encode(encoding="utf-8")).hexdigest() - dir_basename = path_md5 if self.kaos == local_kaos.name else f"{self.kaos}_{path_md5}" - session_dir = get_share_dir() / "sessions" / dir_basename - session_dir.mkdir(parents=True, exist_ok=True) - return session_dir - - -class Metadata(BaseModel): - """Kimi metadata structure.""" - - model_config = ConfigDict(extra="ignore") - - work_dirs: list[WorkDirMeta] = Field(default_factory=list[WorkDirMeta]) - """Work directory list.""" - - def get_work_dir_meta(self, path: KaosPath) -> WorkDirMeta | None: - """Get the metadata for a work directory.""" - for wd in self.work_dirs: - if wd.path == str(path) and wd.kaos == get_current_kaos().name: - return wd - return None - - def new_work_dir_meta(self, path: KaosPath) -> WorkDirMeta: - """Create a new work directory metadata.""" - wd_meta = WorkDirMeta(path=str(path), kaos=get_current_kaos().name) - self.work_dirs.append(wd_meta) - return wd_meta - - -def load_metadata() -> Metadata: - metadata_file = get_metadata_file() - logger.debug("Loading metadata from file: {file}", file=metadata_file) - if not metadata_file.exists(): - logger.debug("No metadata file found, creating empty metadata") - return Metadata() - with open(metadata_file, encoding="utf-8") as f: - data = json.load(f) - return Metadata(**data) - - -def save_metadata(metadata: Metadata): - metadata_file = get_metadata_file() - logger.debug("Saving metadata to file: {file}", file=metadata_file) - atomic_json_write(metadata.model_dump(), metadata_file) diff --git a/src/kimi_cli/metadata.ts b/src/kimi_cli/metadata.ts new file mode 100644 index 000000000..2eebf02a1 --- /dev/null +++ b/src/kimi_cli/metadata.ts @@ -0,0 +1,118 @@ +/** + * Metadata module — corresponds to Python metadata.py + * Tracks work directories and their sessions using MD5 hashes for directory names. + */ + +import { createHash } from "node:crypto"; +import { join } from "node:path"; +import { getShareDir } from "./config.ts"; +import { logger } from "./utils/logging.ts"; + +// ── Metadata file ──────────────────────────────────────── + +export function getMetadataFile(): string { + return join(getShareDir(), "kimi.json"); +} + +// ── WorkDirMeta ────────────────────────────────────────── + +export interface WorkDirMeta { + /** The full path of the work directory. */ + path: string; + /** The name of the KAOS where the work directory is located. */ + kaos: string; + /** Last session ID of this work directory. */ + lastSessionId: string | null; +} + +/** Compute the sessions directory for a work directory using MD5 hash (compatible with Python). */ +export function getSessionsDir(workDirMeta: WorkDirMeta): string { + const pathMd5 = createHash("md5") + .update(workDirMeta.path, "utf-8") + .digest("hex"); + // For local kaos, just use the MD5 hash; otherwise prefix with kaos name + const localKaos = "local"; + const dirBasename = + workDirMeta.kaos === localKaos + ? pathMd5 + : `${workDirMeta.kaos}_${pathMd5}`; + return join(getShareDir(), "sessions", dirBasename); +} + +// ── Metadata ───────────────────────────────────────────── + +export interface Metadata { + workDirs: WorkDirMeta[]; +} + +/** Get the metadata for a work directory. */ +export function getWorkDirMeta( + metadata: Metadata, + path: string, + kaos = "local", +): WorkDirMeta | null { + for (const wd of metadata.workDirs) { + if (wd.path === path && wd.kaos === kaos) { + return wd; + } + } + return null; +} + +/** Create a new work directory metadata entry. */ +export function newWorkDirMeta( + metadata: Metadata, + path: string, + kaos = "local", +): WorkDirMeta { + const wdMeta: WorkDirMeta = { + path, + kaos, + lastSessionId: null, + }; + metadata.workDirs.push(wdMeta); + return wdMeta; +} + +// ── Load / Save ────────────────────────────────────────── + +export async function loadMetadata(): Promise { + const metadataFile = getMetadataFile(); + logger.debug(`Loading metadata from file: ${metadataFile}`); + const file = Bun.file(metadataFile); + if (!(await file.exists())) { + logger.debug("No metadata file found, creating empty metadata"); + return { workDirs: [] }; + } + try { + const data = await file.json(); + // Map Python-style snake_case to camelCase + const workDirs: WorkDirMeta[] = (data.work_dirs ?? data.workDirs ?? []).map( + (wd: any) => ({ + path: wd.path ?? "", + kaos: wd.kaos ?? "local", + lastSessionId: wd.last_session_id ?? wd.lastSessionId ?? null, + }), + ); + return { workDirs }; + } catch (err) { + logger.warn(`Failed to load metadata: ${err}`); + return { workDirs: [] }; + } +} + +export async function saveMetadata(metadata: Metadata): Promise { + const metadataFile = getMetadataFile(); + logger.debug(`Saving metadata to file: ${metadataFile}`); + const dir = metadataFile.substring(0, metadataFile.lastIndexOf("/")); + await Bun.$`mkdir -p ${dir}`.quiet(); + // Save in Python-compatible snake_case format + const data = { + work_dirs: metadata.workDirs.map((wd) => ({ + path: wd.path, + kaos: wd.kaos, + last_session_id: wd.lastSessionId, + })), + }; + await Bun.write(metadataFile, JSON.stringify(data, null, 2)); +} diff --git a/src/kimi_cli/notifications/__init__.py b/src/kimi_cli/notifications/__init__.py deleted file mode 100644 index db0d3cd2d..000000000 --- a/src/kimi_cli/notifications/__init__.py +++ /dev/null @@ -1,33 +0,0 @@ -from .llm import build_notification_message, extract_notification_ids, is_notification_message -from .manager import NotificationManager -from .models import ( - NotificationCategory, - NotificationDelivery, - NotificationDeliveryStatus, - NotificationEvent, - NotificationSeverity, - NotificationSink, - NotificationSinkState, - NotificationView, -) -from .notifier import NotificationWatcher -from .store import NotificationStore -from .wire import to_wire_notification - -__all__ = [ - "NotificationCategory", - "NotificationDelivery", - "NotificationDeliveryStatus", - "NotificationEvent", - "NotificationManager", - "NotificationSeverity", - "NotificationSink", - "NotificationSinkState", - "NotificationStore", - "NotificationView", - "NotificationWatcher", - "build_notification_message", - "extract_notification_ids", - "is_notification_message", - "to_wire_notification", -] diff --git a/src/kimi_cli/notifications/index.ts b/src/kimi_cli/notifications/index.ts new file mode 100644 index 000000000..1eda2b2c3 --- /dev/null +++ b/src/kimi_cli/notifications/index.ts @@ -0,0 +1,21 @@ +/** + * Notification system — corresponds to Python notifications/ + */ + +export type { + NotificationCategory, + NotificationSeverity, + NotificationSink, + NotificationDeliveryStatus, + NotificationEvent, + NotificationSinkState, + NotificationDelivery, + NotificationView, +} from "./models.ts"; +export { newNotificationEvent, eventToJson, eventFromJson, deliveryToJson, deliveryFromJson } from "./models.ts"; +export { NotificationStore } from "./store.ts"; +export { NotificationManager } from "./manager.ts"; +export type { NotificationConfig } from "./manager.ts"; +export { NotificationWatcher } from "./notifier.ts"; +export { toWireNotification } from "./wire.ts"; +export type { WireNotification } from "./wire.ts"; diff --git a/src/kimi_cli/notifications/llm.py b/src/kimi_cli/notifications/llm.py deleted file mode 100644 index c922bef20..000000000 --- a/src/kimi_cli/notifications/llm.py +++ /dev/null @@ -1,77 +0,0 @@ -from __future__ import annotations - -import re -from collections.abc import Sequence -from typing import TYPE_CHECKING - -from kosong.message import Message - -from kimi_cli.wire.types import TextPart - -from .models import NotificationView - -if TYPE_CHECKING: - from kimi_cli.soul.agent import Runtime - -_NOTIFICATION_ID_RE = re.compile(r' Message: - event = view.event - lines = [ - ( - f'' - ), - f"Title: {event.title}", - f"Severity: {event.severity}", - event.body, - ] - - if event.category == "task" and event.source_kind == "background_task": - task_view = runtime.background_tasks.get_task(event.source_id) - if task_view is not None: - tail = runtime.background_tasks.tail_output( - task_view.spec.id, - max_bytes=runtime.config.background.notification_tail_chars, - max_lines=runtime.config.background.notification_tail_lines, - ) - lines.extend( - [ - "", - f"Task ID: {task_view.spec.id}", - f"Task Type: {task_view.spec.kind}", - f"Description: {task_view.spec.description}", - f"Status: {task_view.runtime.status}", - ] - ) - if task_view.runtime.exit_code is not None: - lines.append(f"Exit code: {task_view.runtime.exit_code}") - if task_view.runtime.failure_reason: - lines.append(f"Failure reason: {task_view.runtime.failure_reason}") - if tail: - lines.extend(["Output tail:", tail]) - lines.append("") - - lines.append("") - return Message(role="user", content=[TextPart(text="\n".join(lines))]) - - -def extract_notification_ids(history: Sequence[Message]) -> set[str]: - ids: set[str] = set() - for message in history: - if message.role != "user": - continue - for part in message.content: - if not isinstance(part, TextPart): - continue - for match in _NOTIFICATION_ID_RE.finditer(part.text): - ids.add(match.group(1)) - return ids - - -def is_notification_message(message: Message) -> bool: - if message.role != "user" or len(message.content) != 1: - return False - part = message.content[0] - return isinstance(part, TextPart) and part.text.lstrip().startswith(" None: - self._config = config - self._store = NotificationStore(root) - - @property - def store(self) -> NotificationStore: - return self._store - - def new_id(self) -> str: - return f"n{uuid.uuid4().hex[:8]}" - - def _initial_delivery(self, event: NotificationEvent) -> NotificationDelivery: - return NotificationDelivery(sinks={sink: NotificationSinkState() for sink in event.targets}) - - def find_by_dedupe_key(self, dedupe_key: str) -> NotificationView | None: - for view in self._store.list_views(): - if view.event.dedupe_key == dedupe_key: - return view - return None - - def publish(self, event: NotificationEvent) -> NotificationView: - if event.dedupe_key: - existing = self.find_by_dedupe_key(event.dedupe_key) - if existing is not None: - return existing - delivery = self._initial_delivery(event) - self._store.create_notification(event, delivery) - return NotificationView(event=event, delivery=delivery) - - def recover(self) -> None: - now = time.time() - stale_after = self._config.claim_stale_after_ms / 1000 - for view in self._store.list_views(): - updated = False - delivery = view.delivery.model_copy(deep=True) - for sink_state in delivery.sinks.values(): - if sink_state.status != "claimed" or sink_state.claimed_at is None: - continue - if now - sink_state.claimed_at <= stale_after: - continue - sink_state.status = "pending" - sink_state.claimed_at = None - updated = True - if updated: - self._store.write_delivery(view.event.id, delivery) - - def has_pending_for_sink(self, sink: str) -> bool: - """Check whether any notification has a pending delivery for *sink*.""" - for view in self._store.list_views(): - sink_state = view.delivery.sinks.get(sink) - if sink_state is not None and sink_state.status == "pending": - return True - return False - - def claim_for_sink(self, sink: str, *, limit: int = 8) -> list[NotificationView]: - self.recover() - claimed: list[NotificationView] = [] - now = time.time() - for view in reversed(self._store.list_views()): - sink_state = view.delivery.sinks.get(sink) - if sink_state is None or sink_state.status == "acked": - continue - if sink_state.status == "claimed": - continue - delivery = view.delivery.model_copy(deep=True) - target_state = delivery.sinks[sink] - target_state.status = "claimed" - target_state.claimed_at = now - self._store.write_delivery(view.event.id, delivery) - claimed.append(NotificationView(event=view.event, delivery=delivery)) - if len(claimed) >= limit: - break - return claimed - - async def deliver_pending( - self, - sink: str, - *, - on_notification: Callable[[NotificationView], Awaitable[None] | None], - limit: int = 8, - before_claim: Callable[[], object] | None = None, - ) -> list[NotificationView]: - """Deliver pending notifications for one sink using a shared claim/ack flow. - - If the handler raises for a notification, the error is logged and that - notification stays in ``claimed`` state (will be recovered later). - Delivery continues for remaining notifications. - """ - if before_claim is not None: - before_claim() - - delivered: list[NotificationView] = [] - for view in self.claim_for_sink(sink, limit=limit): - try: - result = on_notification(view) - if result is not None: - await result - except Exception: - logger.exception( - "Notification handler failed for {sink}/{id}, leaving claimed for recovery", - sink=sink, - id=view.event.id, - ) - continue - delivered.append(self.ack(sink, view.event.id)) - return delivered - - def ack(self, sink: str, notification_id: str) -> NotificationView: - view = self._store.merged_view(notification_id) - delivery = view.delivery.model_copy(deep=True) - sink_state = delivery.sinks.get(sink) - if sink_state is None: - return view - sink_state.status = "acked" - sink_state.acked_at = time.time() - sink_state.claimed_at = None - self._store.write_delivery(notification_id, delivery) - return NotificationView(event=view.event, delivery=delivery) - - def ack_ids(self, sink: str, notification_ids: set[str]) -> None: - for notification_id in notification_ids: - try: - self.ack(sink, notification_id) - except (FileNotFoundError, ValueError): - continue diff --git a/src/kimi_cli/notifications/manager.ts b/src/kimi_cli/notifications/manager.ts new file mode 100644 index 000000000..7655d0944 --- /dev/null +++ b/src/kimi_cli/notifications/manager.ts @@ -0,0 +1,156 @@ +/** + * Notification manager — corresponds to Python notifications/manager.py + * Publishes notifications with deduplication and claim/ack delivery flow. + */ + +import { randomUUID } from "node:crypto"; +import { logger } from "../utils/logging.ts"; +import type { + NotificationEvent, + NotificationDelivery, + NotificationSinkState, + NotificationView, +} from "./models.ts"; +import { NotificationStore } from "./store.ts"; + +export interface NotificationConfig { + claimStaleAfterMs: number; +} + +export class NotificationManager { + private _config: NotificationConfig; + private _store: NotificationStore; + + constructor(root: string, config?: Partial) { + this._config = { claimStaleAfterMs: config?.claimStaleAfterMs ?? 15_000 }; + this._store = new NotificationStore(root); + } + + get store(): NotificationStore { + return this._store; + } + + newId(): string { + return `n${randomUUID().replace(/-/g, "").slice(0, 8)}`; + } + + private initialDelivery(event: NotificationEvent): NotificationDelivery { + const sinks: Record = {}; + for (const sink of event.targets) { + sinks[sink] = { status: "pending" }; + } + return { sinks }; + } + + findByDedupeKey(dedupeKey: string): NotificationView | undefined { + for (const view of this._store.listViews()) { + if (view.event.dedupeKey === dedupeKey) { + return view; + } + } + return undefined; + } + + publish(event: NotificationEvent): NotificationView { + if (event.dedupeKey) { + const existing = this.findByDedupeKey(event.dedupeKey); + if (existing) return existing; + } + const delivery = this.initialDelivery(event); + this._store.createNotification(event, delivery); + return { event, delivery }; + } + + recover(): void { + const now = Date.now() / 1000; + const staleAfter = this._config.claimStaleAfterMs / 1000; + for (const view of this._store.listViews()) { + let updated = false; + const delivery = structuredClone(view.delivery); + for (const sinkState of Object.values(delivery.sinks)) { + if (sinkState.status !== "claimed" || sinkState.claimedAt == null) continue; + if (now - sinkState.claimedAt <= staleAfter) continue; + sinkState.status = "pending"; + sinkState.claimedAt = undefined; + updated = true; + } + if (updated) { + this._store.writeDelivery(view.event.id, delivery); + } + } + } + + hasPendingForSink(sink: string): boolean { + for (const view of this._store.listViews()) { + const sinkState = view.delivery.sinks[sink]; + if (sinkState && sinkState.status === "pending") return true; + } + return false; + } + + claimForSink(sink: string, limit = 8): NotificationView[] { + this.recover(); + const claimed: NotificationView[] = []; + const now = Date.now() / 1000; + const views = this._store.listViews(); + // Process in reverse (oldest first) + for (let i = views.length - 1; i >= 0; i--) { + const view = views[i]!; + const sinkState = view.delivery.sinks[sink]; + if (!sinkState || sinkState.status === "acked" || sinkState.status === "claimed") continue; + const delivery = structuredClone(view.delivery); + const targetState = delivery.sinks[sink]!; + targetState.status = "claimed"; + targetState.claimedAt = now; + this._store.writeDelivery(view.event.id, delivery); + claimed.push({ event: view.event, delivery }); + if (claimed.length >= limit) break; + } + return claimed; + } + + async deliverPending( + sink: string, + opts: { + onNotification: (view: NotificationView) => Promise | void; + limit?: number; + beforeClaim?: () => void; + }, + ): Promise { + if (opts.beforeClaim) opts.beforeClaim(); + const delivered: NotificationView[] = []; + for (const view of this.claimForSink(sink, opts.limit ?? 8)) { + try { + const result = opts.onNotification(view); + if (result instanceof Promise) await result; + } catch { + logger.warn(`Notification handler failed for ${sink}/${view.event.id}, leaving claimed`); + continue; + } + delivered.push(this.ack(sink, view.event.id)); + } + return delivered; + } + + ack(sink: string, notificationId: string): NotificationView { + const view = this._store.mergedView(notificationId); + const delivery = structuredClone(view.delivery); + const sinkState = delivery.sinks[sink]; + if (!sinkState) return view; + sinkState.status = "acked"; + sinkState.ackedAt = Date.now() / 1000; + sinkState.claimedAt = undefined; + this._store.writeDelivery(notificationId, delivery); + return { event: view.event, delivery }; + } + + ackIds(sink: string, notificationIds: Set): void { + for (const id of notificationIds) { + try { + this.ack(sink, id); + } catch { + // ignore missing + } + } + } +} diff --git a/src/kimi_cli/notifications/models.py b/src/kimi_cli/notifications/models.py deleted file mode 100644 index 33918c1f5..000000000 --- a/src/kimi_cli/notifications/models.py +++ /dev/null @@ -1,50 +0,0 @@ -from __future__ import annotations - -import time -from typing import Any, Literal - -from pydantic import BaseModel, ConfigDict, Field - -type NotificationCategory = Literal["task", "agent", "system"] -type NotificationSeverity = Literal["info", "success", "warning", "error"] -type NotificationSink = Literal["llm", "wire", "shell"] -type NotificationDeliveryStatus = Literal["pending", "claimed", "acked"] - - -class NotificationEvent(BaseModel): - model_config = ConfigDict(extra="ignore") - - version: int = 1 - id: str - category: NotificationCategory - type: str - source_kind: str - source_id: str - title: str - body: str - severity: NotificationSeverity = "info" - created_at: float = Field(default_factory=time.time) - payload: dict[str, Any] = Field(default_factory=dict) - targets: list[NotificationSink] = Field(default_factory=lambda: ["llm", "wire", "shell"]) - dedupe_key: str | None = None - - -class NotificationSinkState(BaseModel): - model_config = ConfigDict(extra="ignore") - - status: NotificationDeliveryStatus = "pending" - claimed_at: float | None = None - acked_at: float | None = None - - -class NotificationDelivery(BaseModel): - model_config = ConfigDict(extra="ignore") - - sinks: dict[str, NotificationSinkState] = Field(default_factory=dict) - - -class NotificationView(BaseModel): - model_config = ConfigDict(extra="ignore") - - event: NotificationEvent - delivery: NotificationDelivery diff --git a/src/kimi_cli/notifications/models.ts b/src/kimi_cli/notifications/models.ts new file mode 100644 index 000000000..43b45b94f --- /dev/null +++ b/src/kimi_cli/notifications/models.ts @@ -0,0 +1,132 @@ +/** + * Notification models — corresponds to Python notifications/models.py + */ + +export type NotificationCategory = "task" | "agent" | "system"; +export type NotificationSeverity = "info" | "success" | "warning" | "error"; +export type NotificationSink = "llm" | "wire" | "shell"; +export type NotificationDeliveryStatus = "pending" | "claimed" | "acked"; + +export interface NotificationEvent { + version: number; + id: string; + category: NotificationCategory; + type: string; + sourceKind: string; + sourceId: string; + title: string; + body: string; + severity: NotificationSeverity; + createdAt: number; + payload: Record; + targets: NotificationSink[]; + dedupeKey?: string; +} + +export interface NotificationSinkState { + status: NotificationDeliveryStatus; + claimedAt?: number; + ackedAt?: number; +} + +export interface NotificationDelivery { + sinks: Record; +} + +export interface NotificationView { + event: NotificationEvent; + delivery: NotificationDelivery; +} + +// ── JSON serialization helpers (snake_case ↔ camelCase) ── + +export function eventToJson(e: NotificationEvent): Record { + return { + version: e.version, + id: e.id, + category: e.category, + type: e.type, + source_kind: e.sourceKind, + source_id: e.sourceId, + title: e.title, + body: e.body, + severity: e.severity, + created_at: e.createdAt, + payload: e.payload, + targets: e.targets, + dedupe_key: e.dedupeKey, + }; +} + +export function eventFromJson(data: Record): NotificationEvent { + return { + version: Number(data.version ?? 1), + id: String(data.id), + category: String(data.category ?? "system") as NotificationCategory, + type: String(data.type ?? ""), + sourceKind: String(data.source_kind ?? data.sourceKind ?? ""), + sourceId: String(data.source_id ?? data.sourceId ?? ""), + title: String(data.title ?? ""), + body: String(data.body ?? ""), + severity: String(data.severity ?? "info") as NotificationSeverity, + createdAt: Number(data.created_at ?? data.createdAt ?? Date.now() / 1000), + payload: (data.payload as Record) ?? {}, + targets: (data.targets as NotificationSink[]) ?? ["llm", "wire", "shell"], + dedupeKey: (data.dedupe_key ?? data.dedupeKey) as string | undefined, + }; +} + +export function deliveryToJson(d: NotificationDelivery): Record { + const sinks: Record = {}; + for (const [key, state] of Object.entries(d.sinks)) { + sinks[key] = { + status: state.status, + claimed_at: state.claimedAt, + acked_at: state.ackedAt, + }; + } + return { sinks }; +} + +export function deliveryFromJson(data: Record): NotificationDelivery { + const rawSinks = (data.sinks as Record>) ?? {}; + const sinks: Record = {}; + for (const [key, raw] of Object.entries(rawSinks)) { + sinks[key] = { + status: String(raw.status ?? "pending") as NotificationDeliveryStatus, + claimedAt: raw.claimed_at != null ? Number(raw.claimed_at) : undefined, + ackedAt: raw.acked_at != null ? Number(raw.acked_at) : undefined, + }; + } + return { sinks }; +} + +export function newNotificationEvent(opts: { + id: string; + category: NotificationCategory; + type: string; + sourceKind: string; + sourceId: string; + title: string; + body: string; + severity?: NotificationSeverity; + payload?: Record; + targets?: NotificationSink[]; + dedupeKey?: string; +}): NotificationEvent { + return { + version: 1, + id: opts.id, + category: opts.category, + type: opts.type, + sourceKind: opts.sourceKind, + sourceId: opts.sourceId, + title: opts.title, + body: opts.body, + severity: opts.severity ?? "info", + createdAt: Date.now() / 1000, + payload: opts.payload ?? {}, + targets: opts.targets ?? ["llm", "wire", "shell"], + dedupeKey: opts.dedupeKey, + }; +} diff --git a/src/kimi_cli/notifications/notifier.py b/src/kimi_cli/notifications/notifier.py deleted file mode 100644 index 161780f80..000000000 --- a/src/kimi_cli/notifications/notifier.py +++ /dev/null @@ -1,41 +0,0 @@ -import asyncio -from collections.abc import Awaitable, Callable - -from kimi_cli.utils.logging import logger - -from .manager import NotificationManager -from .models import NotificationSink, NotificationView - - -class NotificationWatcher: - def __init__( - self, - manager: NotificationManager, - *, - sink: NotificationSink, - on_notification: Callable[[NotificationView], Awaitable[None] | None], - before_poll: Callable[[], object] | None = None, - interval_s: float = 1.0, - ) -> None: - self._manager = manager - self._sink = sink - self._on_notification = on_notification - self._before_poll = before_poll - self._interval_s = interval_s - - async def poll_once(self) -> list[NotificationView]: - return await self._manager.deliver_pending( - self._sink, - on_notification=self._on_notification, - before_claim=self._before_poll, - ) - - async def run_forever(self) -> None: - while True: - try: - await self.poll_once() - except asyncio.CancelledError: - raise - except Exception: - logger.exception("NotificationWatcher poll failed") - await asyncio.sleep(self._interval_s) diff --git a/src/kimi_cli/notifications/notifier.ts b/src/kimi_cli/notifications/notifier.ts new file mode 100644 index 000000000..9798a7eb3 --- /dev/null +++ b/src/kimi_cli/notifications/notifier.ts @@ -0,0 +1,49 @@ +/** + * Notification watcher — corresponds to Python notifications/notifier.py + * Polls the manager for pending notifications on a given sink. + */ + +import { logger } from "../utils/logging.ts"; +import type { NotificationManager } from "./manager.ts"; +import type { NotificationSink, NotificationView } from "./models.ts"; + +export class NotificationWatcher { + private _manager: NotificationManager; + private _sink: NotificationSink; + private _onNotification: (view: NotificationView) => Promise | void; + private _beforePoll?: () => void; + private _intervalS: number; + + constructor(opts: { + manager: NotificationManager; + sink: NotificationSink; + onNotification: (view: NotificationView) => Promise | void; + beforePoll?: () => void; + intervalS?: number; + }) { + this._manager = opts.manager; + this._sink = opts.sink; + this._onNotification = opts.onNotification; + this._beforePoll = opts.beforePoll; + this._intervalS = opts.intervalS ?? 1.0; + } + + async pollOnce(): Promise { + return this._manager.deliverPending(this._sink, { + onNotification: this._onNotification, + beforeClaim: this._beforePoll, + }); + } + + async runForever(signal?: AbortSignal): Promise { + while (!signal?.aborted) { + try { + await this.pollOnce(); + } catch (err) { + if (signal?.aborted) break; + logger.error("NotificationWatcher poll failed"); + } + await Bun.sleep(this._intervalS * 1000); + } + } +} diff --git a/src/kimi_cli/notifications/store.py b/src/kimi_cli/notifications/store.py deleted file mode 100644 index f1f0601ed..000000000 --- a/src/kimi_cli/notifications/store.py +++ /dev/null @@ -1,99 +0,0 @@ -from __future__ import annotations - -import re -from pathlib import Path - -from kimi_cli.utils.io import atomic_json_write - -from .models import NotificationDelivery, NotificationEvent, NotificationView - -_VALID_NOTIFICATION_ID = re.compile(r"^[a-z0-9]{2,20}$") - - -def _validate_notification_id(notification_id: str) -> None: - if not _VALID_NOTIFICATION_ID.match(notification_id): - raise ValueError(f"Invalid notification_id: {notification_id!r}") - - -class NotificationStore: - EVENT_FILE = "event.json" - DELIVERY_FILE = "delivery.json" - - def __init__(self, root: Path): - self._root = root - - @property - def root(self) -> Path: - return self._root - - def _ensure_root(self) -> Path: - """Return the root directory, creating it if it does not exist.""" - self._root.mkdir(parents=True, exist_ok=True) - return self._root - - def notification_dir(self, notification_id: str) -> Path: - _validate_notification_id(notification_id) - path = self._ensure_root() / notification_id - path.mkdir(parents=True, exist_ok=True) - return path - - def notification_path(self, notification_id: str) -> Path: - _validate_notification_id(notification_id) - return self.root / notification_id - - def event_path(self, notification_id: str) -> Path: - return self.notification_path(notification_id) / self.EVENT_FILE - - def delivery_path(self, notification_id: str) -> Path: - return self.notification_path(notification_id) / self.DELIVERY_FILE - - def create_notification( - self, - event: NotificationEvent, - delivery: NotificationDelivery, - ) -> None: - notification_dir = self.notification_dir(event.id) - atomic_json_write(event.model_dump(mode="json"), notification_dir / self.EVENT_FILE) - atomic_json_write(delivery.model_dump(mode="json"), notification_dir / self.DELIVERY_FILE) - - def list_notification_ids(self) -> list[str]: - if not self.root.exists(): - return [] - notification_ids: list[str] = [] - for path in sorted(self.root.iterdir()): - if not path.is_dir(): - continue - if not (path / self.EVENT_FILE).exists(): - continue - notification_ids.append(path.name) - return notification_ids - - def read_event(self, notification_id: str) -> NotificationEvent: - return NotificationEvent.model_validate_json( - self.event_path(notification_id).read_text(encoding="utf-8") - ) - - def write_event(self, event: NotificationEvent) -> None: - atomic_json_write(event.model_dump(mode="json"), self.event_path(event.id)) - - def read_delivery(self, notification_id: str) -> NotificationDelivery: - path = self.delivery_path(notification_id) - if not path.exists(): - return NotificationDelivery() - return NotificationDelivery.model_validate_json(path.read_text(encoding="utf-8")) - - def write_delivery(self, notification_id: str, delivery: NotificationDelivery) -> None: - atomic_json_write(delivery.model_dump(mode="json"), self.delivery_path(notification_id)) - - def merged_view(self, notification_id: str) -> NotificationView: - return NotificationView( - event=self.read_event(notification_id), - delivery=self.read_delivery(notification_id), - ) - - def list_views(self) -> list[NotificationView]: - views = [ - self.merged_view(notification_id) for notification_id in self.list_notification_ids() - ] - views.sort(key=lambda view: view.event.created_at, reverse=True) - return views diff --git a/src/kimi_cli/notifications/store.ts b/src/kimi_cli/notifications/store.ts new file mode 100644 index 000000000..af340f4d6 --- /dev/null +++ b/src/kimi_cli/notifications/store.ts @@ -0,0 +1,138 @@ +/** + * Notification store — corresponds to Python notifications/store.py + * File-based persistence for notification events and delivery state. + */ + +import { join } from "node:path"; +import { + existsSync, + mkdirSync, + readFileSync, + writeFileSync, + readdirSync, + statSync, +} from "node:fs"; +import { + type NotificationEvent, + type NotificationDelivery, + type NotificationView, + type NotificationSinkState, + eventToJson, + eventFromJson, + deliveryToJson, + deliveryFromJson, +} from "./models.ts"; + +const VALID_NOTIFICATION_ID = /^[a-z0-9]{2,20}$/; + +function validateNotificationId(id: string): void { + if (!VALID_NOTIFICATION_ID.test(id)) { + throw new Error(`Invalid notification_id: ${id}`); + } +} + +function atomicJsonWrite(data: Record, filePath: string): void { + const tmpPath = filePath + ".tmp"; + writeFileSync(tmpPath, JSON.stringify(data, null, 2), "utf-8"); + const { renameSync } = require("node:fs"); + renameSync(tmpPath, filePath); +} + +export class NotificationStore { + static readonly EVENT_FILE = "event.json"; + static readonly DELIVERY_FILE = "delivery.json"; + + private _root: string; + + constructor(root: string) { + this._root = root; + } + + get root(): string { + return this._root; + } + + private ensureRoot(): string { + if (!existsSync(this._root)) { + mkdirSync(this._root, { recursive: true }); + } + return this._root; + } + + notificationDir(notificationId: string): string { + validateNotificationId(notificationId); + const path = join(this.ensureRoot(), notificationId); + if (!existsSync(path)) { + mkdirSync(path, { recursive: true }); + } + return path; + } + + notificationPath(notificationId: string): string { + validateNotificationId(notificationId); + return join(this._root, notificationId); + } + + eventPath(notificationId: string): string { + return join(this.notificationPath(notificationId), NotificationStore.EVENT_FILE); + } + + deliveryPath(notificationId: string): string { + return join(this.notificationPath(notificationId), NotificationStore.DELIVERY_FILE); + } + + createNotification(event: NotificationEvent, delivery: NotificationDelivery): void { + const dir = this.notificationDir(event.id); + atomicJsonWrite(eventToJson(event), join(dir, NotificationStore.EVENT_FILE)); + atomicJsonWrite(deliveryToJson(delivery), join(dir, NotificationStore.DELIVERY_FILE)); + } + + listNotificationIds(): string[] { + if (!existsSync(this._root)) return []; + const ids: string[] = []; + for (const entry of readdirSync(this._root).sort()) { + const dirPath = join(this._root, entry); + try { + if (!statSync(dirPath).isDirectory()) continue; + } catch { + continue; + } + if (!existsSync(join(dirPath, NotificationStore.EVENT_FILE))) continue; + ids.push(entry); + } + return ids; + } + + readEvent(notificationId: string): NotificationEvent { + const data = JSON.parse(readFileSync(this.eventPath(notificationId), "utf-8")); + return eventFromJson(data); + } + + writeEvent(event: NotificationEvent): void { + atomicJsonWrite(eventToJson(event), this.eventPath(event.id)); + } + + readDelivery(notificationId: string): NotificationDelivery { + const path = this.deliveryPath(notificationId); + if (!existsSync(path)) return { sinks: {} }; + const data = JSON.parse(readFileSync(path, "utf-8")); + return deliveryFromJson(data); + } + + writeDelivery(notificationId: string, delivery: NotificationDelivery): void { + atomicJsonWrite(deliveryToJson(delivery), this.deliveryPath(notificationId)); + } + + mergedView(notificationId: string): NotificationView { + return { + event: this.readEvent(notificationId), + delivery: this.readDelivery(notificationId), + }; + } + + listViews(): NotificationView[] { + const views = this.listNotificationIds().map((id) => this.mergedView(id)); + views.sort((a, b) => b.event.createdAt - a.event.createdAt); + return views; + } +} diff --git a/src/kimi_cli/notifications/wire.py b/src/kimi_cli/notifications/wire.py deleted file mode 100644 index 341e31846..000000000 --- a/src/kimi_cli/notifications/wire.py +++ /dev/null @@ -1,21 +0,0 @@ -from __future__ import annotations - -from kimi_cli.wire.types import Notification - -from .models import NotificationView - - -def to_wire_notification(view: NotificationView) -> Notification: - event = view.event - return Notification( - id=event.id, - category=event.category, - type=event.type, - source_kind=event.source_kind, - source_id=event.source_id, - title=event.title, - body=event.body, - severity=event.severity, - created_at=event.created_at, - payload=event.payload, - ) diff --git a/src/kimi_cli/notifications/wire.ts b/src/kimi_cli/notifications/wire.ts new file mode 100644 index 000000000..5de3d24aa --- /dev/null +++ b/src/kimi_cli/notifications/wire.ts @@ -0,0 +1,35 @@ +/** + * Wire notification bridge — corresponds to Python notifications/wire.py + * Converts NotificationView to wire protocol Notification. + */ + +import type { NotificationView } from "./models.ts"; + +export interface WireNotification { + id: string; + category: string; + type: string; + source_kind: string; + source_id: string; + title: string; + body: string; + severity: string; + created_at: number; + payload: Record; +} + +export function toWireNotification(view: NotificationView): WireNotification { + const e = view.event; + return { + id: e.id, + category: e.category, + type: e.type, + source_kind: e.sourceKind, + source_id: e.sourceId, + title: e.title, + body: e.body, + severity: e.severity, + created_at: e.createdAt, + payload: e.payload, + }; +} diff --git a/src/kimi_cli/plugin/__init__.py b/src/kimi_cli/plugin/__init__.py deleted file mode 100644 index 7359c3c43..000000000 --- a/src/kimi_cli/plugin/__init__.py +++ /dev/null @@ -1,124 +0,0 @@ -"""Plugin specification parsing and config injection.""" - -import json -from pathlib import Path -from typing import Any - -from pydantic import BaseModel, ConfigDict, Field - - -class PluginError(Exception): - """Raised when plugin.json is invalid or an operation fails.""" - - -class PluginRuntime(BaseModel): - """Runtime information written by the host after installation.""" - - host: str - host_version: str - - -class PluginToolSpec(BaseModel): - """A tool declared by a plugin.""" - - name: str - description: str - command: list[str] - parameters: dict[str, object] = Field(default_factory=dict) - - -class PluginSpec(BaseModel): - """Parsed representation of a plugin.json file.""" - - model_config = ConfigDict(extra="ignore") - - name: str - version: str - description: str = "" - config_file: str | None = None - inject: dict[str, str] = Field(default_factory=dict) - tools: list[PluginToolSpec] = Field(default_factory=list) # pyright: ignore[reportUnknownVariableType] - runtime: PluginRuntime | None = None - - -PLUGIN_JSON = "plugin.json" - - -def parse_plugin_json(path: Path) -> PluginSpec: - """Parse a plugin.json file and return a validated PluginSpec.""" - try: - data = json.loads(path.read_text(encoding="utf-8")) - except (OSError, json.JSONDecodeError) as exc: - raise PluginError(f"Failed to read {path}: {exc}") from exc - - if "name" not in data: - raise PluginError(f"Missing required field 'name' in {path}") - if "version" not in data: - raise PluginError(f"Missing required field 'version' in {path}") - if data.get("inject") and not data.get("config_file"): - raise PluginError(f"'inject' requires 'config_file' in {path}") - - try: - return PluginSpec.model_validate(data) - except Exception as exc: - raise PluginError(f"Invalid plugin.json schema in {path}: {exc}") from exc - - -def inject_config(plugin_dir: Path, spec: PluginSpec, values: dict[str, str]) -> None: - """Inject host values into the plugin's config file. - - Args: - plugin_dir: Root directory of the installed plugin. - spec: Parsed plugin spec. - values: Map of standard inject keys to actual values (e.g. {"api_key": "sk-xxx"}). - """ - if not spec.inject or not spec.config_file: - return - - config_path = (plugin_dir / spec.config_file).resolve() - if not config_path.is_relative_to(plugin_dir.resolve()): - raise PluginError(f"config_file escapes plugin directory: {spec.config_file}") - if not config_path.exists(): - raise PluginError(f"Config file not found: {config_path}") - - try: - config = json.loads(config_path.read_text(encoding="utf-8")) - except (OSError, json.JSONDecodeError) as exc: - raise PluginError(f"Failed to read config file {config_path}: {exc}") from exc - - for target_path, source_key in spec.inject.items(): - if source_key not in values: - raise PluginError(f"Host does not provide required inject key '{source_key}'") - _set_nested(config, target_path, values[source_key]) - - config_path.write_text( - json.dumps(config, ensure_ascii=False, indent=2), - encoding="utf-8", - ) - - -def write_runtime(plugin_dir: Path, runtime: PluginRuntime) -> None: - """Write runtime info into plugin.json.""" - plugin_json_path = plugin_dir / PLUGIN_JSON - try: - data = json.loads(plugin_json_path.read_text(encoding="utf-8")) - except (OSError, json.JSONDecodeError) as exc: - raise PluginError(f"Failed to read {plugin_json_path}: {exc}") from exc - data["runtime"] = runtime.model_dump() - plugin_json_path.write_text( - json.dumps(data, ensure_ascii=False, indent=2), - encoding="utf-8", - ) - - -def _set_nested(obj: dict[str, Any], dotted_path: str, value: object) -> None: - """Set a value in a nested dict using dot-separated path. - - Creates intermediate dicts if they don't exist. - """ - keys = dotted_path.split(".") - for key in keys[:-1]: - if key not in obj or not isinstance(obj[key], dict): - obj[key] = {} - obj = obj[key] - obj[keys[-1]] = value diff --git a/src/kimi_cli/plugin/manager.py b/src/kimi_cli/plugin/manager.py deleted file mode 100644 index c58e0ffc0..000000000 --- a/src/kimi_cli/plugin/manager.py +++ /dev/null @@ -1,153 +0,0 @@ -"""Plugin installation, removal, and listing.""" - -from __future__ import annotations - -import shutil -import tempfile -from pathlib import Path -from typing import TYPE_CHECKING - -from kimi_cli.plugin import ( - PLUGIN_JSON, - PluginError, - PluginRuntime, - PluginSpec, - inject_config, - parse_plugin_json, - write_runtime, -) -from kimi_cli.share import get_share_dir - -if TYPE_CHECKING: - from kimi_cli.auth.oauth import OAuthManager - from kimi_cli.config import Config - - -def get_plugins_dir() -> Path: - """Return the plugins installation directory (~/.kimi/plugins/).""" - return get_share_dir() / "plugins" - - -def collect_host_values(config: Config, oauth: OAuthManager) -> dict[str, str]: - """Collect host values (api_key, base_url) for plugin injection. - - Resolves credentials from the default provider, handling OAuth tokens - and static API keys. Callers that run outside the normal startup flow - (e.g. ``install_cmd``) should apply environment-variable overrides - (``augment_provider_with_env_vars``) to the provider **before** calling - this function; the main app startup already does that. - """ - values: dict[str, str] = {} - if not config.default_model or config.default_model not in config.models: - return values - model = config.models[config.default_model] - if model.provider not in config.providers: - return values - provider = config.providers[model.provider] - api_key = oauth.resolve_api_key(provider.api_key, provider.oauth) - if api_key: - values["api_key"] = api_key - values["base_url"] = provider.base_url - return values - - -def _validate_name(name: str, plugins_dir: Path) -> Path: - """Resolve and validate plugin name, returning the safe destination path.""" - dest = (plugins_dir / name).resolve() - if not dest.is_relative_to(plugins_dir.resolve()): - raise PluginError(f"Invalid plugin name: {name}") - return dest - - -def install_plugin( - *, - source: Path, - plugins_dir: Path, - host_values: dict[str, str], - host_name: str, - host_version: str, -) -> PluginSpec: - """Install a plugin from a source directory. - - Stages the new copy to a temp dir first, so a failed upgrade - does not destroy the previous installation. - """ - source_plugin_json = source / PLUGIN_JSON - if not source_plugin_json.exists(): - raise PluginError(f"No plugin.json found in {source}") - - spec = parse_plugin_json(source_plugin_json) - dest = _validate_name(spec.name, plugins_dir) - - # Stage to a temp dir inside plugins_dir so rename is atomic on same fs - plugins_dir.mkdir(parents=True, exist_ok=True) - staging = Path(tempfile.mkdtemp(prefix=f".{spec.name}-", dir=plugins_dir)) - try: - # Copy source into staging - staging_plugin = staging / spec.name - shutil.copytree(source, staging_plugin) - - # Apply inject + runtime on the staged copy - inject_config(staging_plugin, spec, host_values) - runtime = PluginRuntime(host=host_name, host_version=host_version) - write_runtime(staging_plugin, runtime) - - # Swap: remove old, move staged into place - if dest.exists(): - shutil.rmtree(dest) - staging_plugin.rename(dest) - except Exception: - # On any failure, clean up staging but leave existing install intact - shutil.rmtree(staging, ignore_errors=True) - raise - finally: - # Clean up staging dir shell (may be empty after successful rename) - shutil.rmtree(staging, ignore_errors=True) - - # Re-read to return the installed spec (with runtime) - return parse_plugin_json(dest / PLUGIN_JSON) - - -def refresh_plugin_configs(plugins_dir: Path, host_values: dict[str, str]) -> None: - """Re-inject host values into all installed plugin config files. - - Called at startup so that OAuth tokens and other credentials - stay fresh even after the initial install. - """ - if not plugins_dir.is_dir(): - return - - for child in sorted(plugins_dir.iterdir()): - plugin_json = child / PLUGIN_JSON - if not child.is_dir() or not plugin_json.is_file(): - continue - try: - spec = parse_plugin_json(plugin_json) - if spec.inject and spec.config_file: - inject_config(child, spec, host_values) - except Exception: - continue - - -def list_plugins(plugins_dir: Path) -> list[PluginSpec]: - """List all installed plugins.""" - if not plugins_dir.is_dir(): - return [] - - plugins: list[PluginSpec] = [] - for child in sorted(plugins_dir.iterdir()): - plugin_json = child / PLUGIN_JSON - if child.is_dir() and plugin_json.is_file(): - try: - plugins.append(parse_plugin_json(plugin_json)) - except PluginError: - continue - return plugins - - -def remove_plugin(name: str, plugins_dir: Path) -> None: - """Remove an installed plugin.""" - dest = _validate_name(name, plugins_dir) - if not dest.exists(): - raise PluginError(f"Plugin '{name}' not found in {plugins_dir}") - shutil.rmtree(dest) diff --git a/src/kimi_cli/plugin/manager.ts b/src/kimi_cli/plugin/manager.ts new file mode 100644 index 000000000..fdebfa83c --- /dev/null +++ b/src/kimi_cli/plugin/manager.ts @@ -0,0 +1,267 @@ +/** + * Plugin manager — corresponds to Python plugin/manager.py + * Plugin installation, removal, and listing. + */ + +import { join, resolve } from "node:path"; +import { + existsSync, + mkdirSync, + readdirSync, + readFileSync, + writeFileSync, + statSync, + rmSync, + cpSync, + renameSync, + mkdtempSync, +} from "node:fs"; +import { homedir, tmpdir } from "node:os"; +import { logger } from "../utils/logging.ts"; + +// ── Types ── + +export class PluginError extends Error { + constructor(message: string) { + super(message); + this.name = "PluginError"; + } +} + +export interface PluginRuntime { + host: string; + hostVersion: string; +} + +export interface PluginToolSpec { + name: string; + description: string; + command: string[]; + parameters: Record; +} + +export interface PluginSpec { + name: string; + version: string; + description: string; + configFile?: string; + inject: Record; + tools: PluginToolSpec[]; + runtime?: PluginRuntime; +} + +export const PLUGIN_JSON = "plugin.json"; + +// ── Parsing ── + +export function parsePluginJson(path: string): PluginSpec { + let data: Record; + try { + data = JSON.parse(readFileSync(path, "utf-8")); + } catch (err) { + throw new PluginError(`Failed to read ${path}: ${err}`); + } + + if (!data.name) throw new PluginError(`Missing required field 'name' in ${path}`); + if (!data.version) throw new PluginError(`Missing required field 'version' in ${path}`); + if (data.inject && !data.config_file) { + throw new PluginError(`'inject' requires 'config_file' in ${path}`); + } + + const tools: PluginToolSpec[] = []; + if (Array.isArray(data.tools)) { + for (const t of data.tools) { + tools.push({ + name: String(t.name ?? ""), + description: String(t.description ?? ""), + command: Array.isArray(t.command) ? t.command.map(String) : [], + parameters: (t.parameters as Record) ?? {}, + }); + } + } + + const runtime = data.runtime as Record | undefined; + + return { + name: String(data.name), + version: String(data.version), + description: String(data.description ?? ""), + configFile: data.config_file ? String(data.config_file) : undefined, + inject: (data.inject as Record) ?? {}, + tools, + runtime: runtime + ? { host: String(runtime.host ?? ""), hostVersion: String(runtime.host_version ?? "") } + : undefined, + }; +} + +// ── Directory helpers ── + +export function getPluginsDir(): string { + const shareDir = join(homedir(), ".kimi"); + return join(shareDir, "plugins"); +} + +// ── Config injection ── + +function setNested(obj: Record, dottedPath: string, value: unknown): void { + const keys = dottedPath.split("."); + let current = obj; + for (let i = 0; i < keys.length - 1; i++) { + const key = keys[i]!; + if (!(key in current) || typeof current[key] !== "object" || current[key] === null) { + current[key] = {}; + } + current = current[key] as Record; + } + current[keys[keys.length - 1]!] = value; +} + +export function injectConfig( + pluginDir: string, + spec: PluginSpec, + values: Record, +): void { + if (!spec.inject || !spec.configFile) return; + + const configPath = resolve(join(pluginDir, spec.configFile)); + if (!configPath.startsWith(resolve(pluginDir))) { + throw new PluginError(`config_file escapes plugin directory: ${spec.configFile}`); + } + if (!existsSync(configPath)) { + throw new PluginError(`Config file not found: ${configPath}`); + } + + let config: Record; + try { + config = JSON.parse(readFileSync(configPath, "utf-8")); + } catch (err) { + throw new PluginError(`Failed to read config file ${configPath}: ${err}`); + } + + for (const [targetPath, sourceKey] of Object.entries(spec.inject)) { + if (!(sourceKey in values)) { + throw new PluginError(`Host does not provide required inject key '${sourceKey}'`); + } + setNested(config, targetPath, values[sourceKey]!); + } + + writeFileSync(configPath, JSON.stringify(config, null, 2), "utf-8"); +} + +export function writeRuntime(pluginDir: string, runtime: PluginRuntime): void { + const pluginJsonPath = join(pluginDir, PLUGIN_JSON); + let data: Record; + try { + data = JSON.parse(readFileSync(pluginJsonPath, "utf-8")); + } catch (err) { + throw new PluginError(`Failed to read ${pluginJsonPath}: ${err}`); + } + data.runtime = { host: runtime.host, host_version: runtime.hostVersion }; + writeFileSync(pluginJsonPath, JSON.stringify(data, null, 2), "utf-8"); +} + +// ── Installation ── + +function validateName(name: string, pluginsDir: string): string { + const dest = resolve(join(pluginsDir, name)); + if (!dest.startsWith(resolve(pluginsDir))) { + throw new PluginError(`Invalid plugin name: ${name}`); + } + return dest; +} + +export function installPlugin(opts: { + source: string; + pluginsDir: string; + hostValues: Record; + hostName: string; + hostVersion: string; +}): PluginSpec { + const sourcePluginJson = join(opts.source, PLUGIN_JSON); + if (!existsSync(sourcePluginJson)) { + throw new PluginError(`No plugin.json found in ${opts.source}`); + } + + const spec = parsePluginJson(sourcePluginJson); + const dest = validateName(spec.name, opts.pluginsDir); + + mkdirSync(opts.pluginsDir, { recursive: true }); + const staging = mkdtempSync(join(opts.pluginsDir, `.${spec.name}-`)); + + try { + const stagingPlugin = join(staging, spec.name); + cpSync(opts.source, stagingPlugin, { recursive: true }); + + injectConfig(stagingPlugin, spec, opts.hostValues); + writeRuntime(stagingPlugin, { host: opts.hostName, hostVersion: opts.hostVersion }); + + if (existsSync(dest)) rmSync(dest, { recursive: true, force: true }); + renameSync(stagingPlugin, dest); + } catch (err) { + rmSync(staging, { recursive: true, force: true }); + throw err; + } finally { + try { + rmSync(staging, { recursive: true, force: true }); + } catch { + // Ignore + } + } + + return parsePluginJson(join(dest, PLUGIN_JSON)); +} + +export function refreshPluginConfigs(pluginsDir: string, hostValues: Record): void { + if (!existsSync(pluginsDir)) return; + try { + if (!statSync(pluginsDir).isDirectory()) return; + } catch { + return; + } + + for (const child of readdirSync(pluginsDir).sort()) { + const childPath = join(pluginsDir, child); + const pluginJson = join(childPath, PLUGIN_JSON); + try { + if (!statSync(childPath).isDirectory() || !existsSync(pluginJson)) continue; + const spec = parsePluginJson(pluginJson); + if (spec.inject && spec.configFile) { + injectConfig(childPath, spec, hostValues); + } + } catch { + continue; + } + } +} + +export function listPlugins(pluginsDir: string): PluginSpec[] { + if (!existsSync(pluginsDir)) return []; + try { + if (!statSync(pluginsDir).isDirectory()) return []; + } catch { + return []; + } + + const plugins: PluginSpec[] = []; + for (const child of readdirSync(pluginsDir).sort()) { + const childPath = join(pluginsDir, child); + const pluginJson = join(childPath, PLUGIN_JSON); + try { + if (statSync(childPath).isDirectory() && existsSync(pluginJson)) { + plugins.push(parsePluginJson(pluginJson)); + } + } catch { + continue; + } + } + return plugins; +} + +export function removePlugin(name: string, pluginsDir: string): void { + const dest = validateName(name, pluginsDir); + if (!existsSync(dest)) { + throw new PluginError(`Plugin '${name}' not found in ${pluginsDir}`); + } + rmSync(dest, { recursive: true, force: true }); +} diff --git a/src/kimi_cli/plugin/tool.py b/src/kimi_cli/plugin/tool.py deleted file mode 100644 index abd08bace..000000000 --- a/src/kimi_cli/plugin/tool.py +++ /dev/null @@ -1,173 +0,0 @@ -"""Plugin tool wrapper — runs plugin-declared tools as subprocesses.""" - -from __future__ import annotations - -import asyncio -import json -from pathlib import Path -from typing import TYPE_CHECKING, Any - -from kosong.tooling import CallableTool, ToolError, ToolOk -from kosong.tooling.error import ToolRuntimeError -from loguru import logger - -from kimi_cli.plugin import PluginToolSpec -from kimi_cli.tools.utils import ToolRejectedError -from kimi_cli.utils.subprocess_env import get_clean_env -from kimi_cli.wire.types import ToolReturnValue - -if TYPE_CHECKING: - from kimi_cli.config import Config - from kimi_cli.soul.approval import Approval - - -def _get_host_values(config: Config) -> dict[str, str]: - """Extract current host values (api_key, base_url) from config. - - Reads the latest provider credentials, which may have been - refreshed by OAuth since plugin install time. - """ - from kimi_cli.auth.oauth import OAuthManager - from kimi_cli.plugin.manager import collect_host_values - - oauth = OAuthManager(config) - return collect_host_values(config, oauth) - - -class PluginTool(CallableTool): - """A tool that executes a plugin command in a subprocess. - - Parameters are passed via stdin as JSON. - stdout is captured as the tool result. - Host credentials are injected as environment variables at runtime - (not baked into config files) to handle OAuth token refresh. - """ - - def __init__( - self, - tool_spec: PluginToolSpec, - plugin_dir: Path, - *, - inject: dict[str, str], - config: Config, - approval: Approval | None = None, - **kwargs: Any, - ): - super().__init__( - name=tool_spec.name, - description=tool_spec.description, - parameters=tool_spec.parameters or {"type": "object", "properties": {}}, - **kwargs, - ) - self._command = tool_spec.command - self._plugin_dir = plugin_dir - self._inject = inject # e.g. {"kimiCodeAPIKey": "api_key"} - self._config = config - self._approval = approval - - def _build_env(self) -> dict[str, str]: - """Build env vars with fresh host credentials for the subprocess.""" - env = get_clean_env() - if self._inject: - host_values = _get_host_values(self._config) - for target_key, source_key in self._inject.items(): - if source_key in host_values: - # Inject as env var using the plugin's config key name - # e.g. kimiCodeAPIKey= - env[target_key] = host_values[source_key] - return env - - async def __call__(self, *args: Any, **kwargs: Any) -> ToolReturnValue: - if self._approval is not None: - description = f"Run plugin tool `{self.name}`." - if not await self._approval.request(self.name, f"plugin:{self.name}", description): - return ToolRejectedError() - - params_json = json.dumps(kwargs, ensure_ascii=False) - - try: - proc = await asyncio.create_subprocess_exec( - *self._command, - stdin=asyncio.subprocess.PIPE, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - cwd=str(self._plugin_dir), - env=self._build_env(), - ) - except Exception as exc: - return ToolRuntimeError(str(exc)) - - try: - stdout, stderr = await asyncio.wait_for( - proc.communicate(input=params_json.encode("utf-8")), - timeout=120, - ) - except asyncio.CancelledError: - proc.kill() - await proc.wait() - raise - except TimeoutError: - proc.kill() - await proc.wait() - return ToolError( - message=f"Plugin tool '{self.name}' timed out after 120s.", - brief="Timeout", - ) - - output = stdout.decode("utf-8", errors="replace").strip() - err_output = stderr.decode("utf-8", errors="replace").strip() - - if proc.returncode != 0: - error_msg = err_output or output or f"Exit code {proc.returncode}" - return ToolError( - message=f"Plugin tool '{self.name}' failed: {error_msg}", - brief=f"Exit {proc.returncode}", - ) - - if err_output: - logger.debug("Plugin tool {name} stderr: {err}", name=self.name, err=err_output) - - return ToolOk(output=output) - - -def load_plugin_tools( - plugins_dir: Path, config: Config, *, approval: Approval | None = None -) -> list[PluginTool]: - """Scan installed plugins and create PluginTool instances for declared tools.""" - from kimi_cli.plugin import PLUGIN_JSON, PluginError, parse_plugin_json - - if not plugins_dir.is_dir(): - return [] - - tools: list[PluginTool] = [] - for child in sorted(plugins_dir.iterdir()): - plugin_json = child / PLUGIN_JSON - if not child.is_dir() or not plugin_json.is_file(): - continue - try: - spec = parse_plugin_json(plugin_json) - except PluginError: - continue - for tool_spec in spec.tools: - try: - tool = PluginTool( - tool_spec, - plugin_dir=child, - inject=spec.inject, - config=config, - approval=approval, - ) - except Exception: - logger.warning( - "Skipping invalid plugin tool: {name} (from {plugin})", - name=tool_spec.name, - plugin=spec.name, - ) - continue - tools.append(tool) - logger.info( - "Loaded plugin tool: {name} (from {plugin})", - name=tool_spec.name, - plugin=spec.name, - ) - return tools diff --git a/src/kimi_cli/plugin/tool.ts b/src/kimi_cli/plugin/tool.ts new file mode 100644 index 000000000..d18a8b9ce --- /dev/null +++ b/src/kimi_cli/plugin/tool.ts @@ -0,0 +1,142 @@ +/** + * Plugin tool wrapper — corresponds to Python plugin/tool.py + * Runs plugin-declared tools as subprocesses. + */ + +import { join } from "node:path"; +import { existsSync, readdirSync, statSync } from "node:fs"; +import { logger } from "../utils/logging.ts"; +import { + type PluginToolSpec, + type PluginSpec, + PluginError, + PLUGIN_JSON, + parsePluginJson, +} from "./manager.ts"; + +export interface PluginToolResult { + ok: boolean; + output: string; + brief?: string; +} + +export class PluginTool { + readonly name: string; + readonly description: string; + readonly parameters: Record; + private _command: string[]; + private _pluginDir: string; + private _inject: Record; + + constructor(opts: { + toolSpec: PluginToolSpec; + pluginDir: string; + inject: Record; + }) { + this.name = opts.toolSpec.name; + this.description = opts.toolSpec.description; + this.parameters = opts.toolSpec.parameters || { type: "object", properties: {} }; + this._command = opts.toolSpec.command; + this._pluginDir = opts.pluginDir; + this._inject = opts.inject; + } + + private buildEnv(): Record { + const env: Record = { ...process.env } as Record; + // Inject values are directly set as env vars + for (const [targetKey, sourceKey] of Object.entries(this._inject)) { + // The values should be resolved by the caller + if (sourceKey) { + env[targetKey] = sourceKey; + } + } + return env; + } + + async execute(params: Record): Promise { + const paramsJson = JSON.stringify(params); + + try { + const proc = Bun.spawn(this._command, { + stdin: new Blob([paramsJson]), + stdout: "pipe", + stderr: "pipe", + cwd: this._pluginDir, + env: this.buildEnv(), + }); + + const timer = setTimeout(() => proc.kill(), 120_000); + const exitCode = await proc.exited; + clearTimeout(timer); + + const stdout = await new Response(proc.stdout).text(); + const stderr = await new Response(proc.stderr).text(); + const output = stdout.trim(); + const errOutput = stderr.trim(); + + if (exitCode !== 0) { + const errorMsg = errOutput || output || `Exit code ${exitCode}`; + return { + ok: false, + output: `Plugin tool '${this.name}' failed: ${errorMsg}`, + brief: `Exit ${exitCode}`, + }; + } + + if (errOutput) { + logger.debug(`Plugin tool ${this.name} stderr: ${errOutput}`); + } + + return { ok: true, output }; + } catch (err) { + return { ok: false, output: String(err), brief: "Runtime error" }; + } + } +} + +/** + * Scan installed plugins and create PluginTool instances for declared tools. + */ +export function loadPluginTools(pluginsDir: string): PluginTool[] { + if (!existsSync(pluginsDir)) return []; + try { + if (!statSync(pluginsDir).isDirectory()) return []; + } catch { + return []; + } + + const tools: PluginTool[] = []; + for (const child of readdirSync(pluginsDir).sort()) { + const childPath = join(pluginsDir, child); + const pluginJson = join(childPath, PLUGIN_JSON); + if (!existsSync(pluginJson)) continue; + try { + if (!statSync(childPath).isDirectory()) continue; + } catch { + continue; + } + + let spec: PluginSpec; + try { + spec = parsePluginJson(pluginJson); + } catch { + continue; + } + + for (const toolSpec of spec.tools) { + try { + tools.push( + new PluginTool({ + toolSpec, + pluginDir: childPath, + inject: spec.inject, + }), + ); + logger.info(`Loaded plugin tool: ${toolSpec.name} (from ${spec.name})`); + } catch { + logger.warn(`Skipping invalid plugin tool: ${toolSpec.name} (from ${spec.name})`); + } + } + } + return tools; +} diff --git a/src/kimi_cli/prompts/__init__.py b/src/kimi_cli/prompts/__init__.py deleted file mode 100644 index fe9e992fb..000000000 --- a/src/kimi_cli/prompts/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from __future__ import annotations - -from pathlib import Path - -INIT = (Path(__file__).parent / "init.md").read_text(encoding="utf-8") -COMPACT = (Path(__file__).parent / "compact.md").read_text(encoding="utf-8") diff --git a/src/kimi_cli/py.typed b/src/kimi_cli/py.typed deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/kimi_cli/session.py b/src/kimi_cli/session.py deleted file mode 100644 index ba56741b4..000000000 --- a/src/kimi_cli/session.py +++ /dev/null @@ -1,309 +0,0 @@ -from __future__ import annotations - -import asyncio -import builtins -import json -import shutil -import uuid -from dataclasses import dataclass -from pathlib import Path -from textwrap import shorten - -from kaos.path import KaosPath -from kosong.message import Message - -from kimi_cli.metadata import WorkDirMeta, load_metadata, save_metadata -from kimi_cli.session_state import SessionState, load_session_state, save_session_state -from kimi_cli.utils.logging import logger -from kimi_cli.wire.file import WireFile -from kimi_cli.wire.types import TurnBegin - - -@dataclass(slots=True, kw_only=True) -class Session: - """A session of a work directory.""" - - # static metadata - id: str - """The session ID.""" - work_dir: KaosPath - """The absolute path of the work directory.""" - work_dir_meta: WorkDirMeta - """The metadata of the work directory.""" - context_file: Path - """The absolute path to the file storing the message history.""" - wire_file: WireFile - """The wire message log file wrapper.""" - - # session state - state: SessionState - """Persisted session state (approval settings, plan mode, workspace scope, etc.).""" - - # refreshable metadata - title: str - """The title of the session.""" - updated_at: float - """The timestamp of the last update to the session.""" - - @property - def dir(self) -> Path: - """The absolute path of the session directory.""" - path = self.work_dir_meta.sessions_dir / self.id - path.mkdir(parents=True, exist_ok=True) - return path - - @property - def subagents_dir(self) -> Path: - """The absolute path of the subagent instances directory.""" - path = self.dir / "subagents" - path.mkdir(parents=True, exist_ok=True) - return path - - def is_empty(self) -> bool: - """Whether the session has any context history or a custom title.""" - if self.state.custom_title: - return False - if not self.wire_file.is_empty(): - return False - try: - with self.context_file.open(encoding="utf-8") as f: - for line in f: - line = line.strip() - if not line: - continue - role = json.loads(line, strict=False).get("role") - if isinstance(role, str) and not role.startswith("_"): - return False - except FileNotFoundError: - return True - except (OSError, ValueError, TypeError): - logger.exception("Failed to read context file {file}:", file=self.context_file) - return False - return True - - def save_state(self) -> None: - """Persist the session state to disk. - - Reloads externally-mutable fields (title, archive) from disk first - to avoid overwriting concurrent changes made by the web API. - """ - fresh = load_session_state(self.dir) - self.state.custom_title = fresh.custom_title - self.state.title_generated = fresh.title_generated - self.state.title_generate_attempts = fresh.title_generate_attempts - self.state.archived = fresh.archived - self.state.archived_at = fresh.archived_at - self.state.auto_archive_exempt = fresh.auto_archive_exempt - save_session_state(self.state, self.dir) - - async def delete(self) -> None: - """Delete the session directory.""" - session_dir = self.work_dir_meta.sessions_dir / self.id - if not session_dir.exists(): - return - await asyncio.to_thread(shutil.rmtree, session_dir, True) - - async def refresh(self) -> None: - self.title = "Untitled" - self.updated_at = self.context_file.stat().st_mtime if self.context_file.exists() else 0.0 - - if self.state.custom_title: - self.title = self.state.custom_title - return - - try: - async for record in self.wire_file.iter_records(): - wire_msg = record.to_wire_message() - if isinstance(wire_msg, TurnBegin): - self.title = shorten( - Message(role="user", content=wire_msg.user_input).extract_text(" "), - width=50, - ) - return - except Exception: - logger.exception( - "Failed to derive session title from wire file {file}:", - file=self.wire_file.path, - ) - - @staticmethod - async def create( - work_dir: KaosPath, - session_id: str | None = None, - _context_file: Path | None = None, - ) -> Session: - """Create a new session for a work directory.""" - work_dir = work_dir.canonical() - logger.debug("Creating new session for work directory: {work_dir}", work_dir=work_dir) - - metadata = load_metadata() - work_dir_meta = metadata.get_work_dir_meta(work_dir) - if work_dir_meta is None: - work_dir_meta = metadata.new_work_dir_meta(work_dir) - - if session_id is None: - session_id = str(uuid.uuid4()) - session_dir = work_dir_meta.sessions_dir / session_id - session_dir.mkdir(parents=True, exist_ok=True) - - if _context_file is None: - context_file = session_dir / "context.jsonl" - else: - logger.warning( - "Using provided context file: {context_file}", context_file=_context_file - ) - _context_file.parent.mkdir(parents=True, exist_ok=True) - if _context_file.exists(): - assert _context_file.is_file() - context_file = _context_file - - if context_file.exists(): - # truncate if exists - logger.warning( - "Context file already exists, truncating: {context_file}", context_file=context_file - ) - context_file.unlink() - context_file.touch() - - save_metadata(metadata) - - session = Session( - id=session_id, - work_dir=work_dir, - work_dir_meta=work_dir_meta, - context_file=context_file, - wire_file=WireFile(path=session_dir / "wire.jsonl"), - state=SessionState(), - title="", - updated_at=0.0, - ) - await session.refresh() - return session - - @staticmethod - async def find(work_dir: KaosPath, session_id: str) -> Session | None: - """Find a session by work directory and session ID.""" - work_dir = work_dir.canonical() - logger.debug( - "Finding session for work directory: {work_dir}, session ID: {session_id}", - work_dir=work_dir, - session_id=session_id, - ) - - metadata = load_metadata() - work_dir_meta = metadata.get_work_dir_meta(work_dir) - if work_dir_meta is None: - logger.debug("Work directory never been used") - return None - - _migrate_session_context_file(work_dir_meta, session_id) - - session_dir = work_dir_meta.sessions_dir / session_id - if not session_dir.is_dir(): - logger.debug("Session directory not found: {session_dir}", session_dir=session_dir) - return None - - context_file = session_dir / "context.jsonl" - if not context_file.exists(): - logger.debug( - "Session context file not found: {context_file}", context_file=context_file - ) - return None - - session = Session( - id=session_id, - work_dir=work_dir, - work_dir_meta=work_dir_meta, - context_file=context_file, - wire_file=WireFile(path=session_dir / "wire.jsonl"), - state=load_session_state(session_dir), - title="", - updated_at=0.0, - ) - await session.refresh() - return session - - @staticmethod - async def list(work_dir: KaosPath) -> builtins.list[Session]: - """List all sessions for a work directory.""" - work_dir = work_dir.canonical() - logger.debug("Listing sessions for work directory: {work_dir}", work_dir=work_dir) - - metadata = load_metadata() - work_dir_meta = metadata.get_work_dir_meta(work_dir) - if work_dir_meta is None: - logger.debug("Work directory never been used") - return [] - - session_ids = { - path.name if path.is_dir() else path.stem - for path in work_dir_meta.sessions_dir.iterdir() - if path.is_dir() or path.suffix == ".jsonl" - } - - sessions: list[Session] = [] - for session_id in session_ids: - _migrate_session_context_file(work_dir_meta, session_id) - session_dir = work_dir_meta.sessions_dir / session_id - if not session_dir.is_dir(): - logger.debug("Session directory not found: {session_dir}", session_dir=session_dir) - continue - context_file = session_dir / "context.jsonl" - if not context_file.exists(): - logger.debug( - "Session context file not found: {context_file}", context_file=context_file - ) - continue - session = Session( - id=session_id, - work_dir=work_dir, - work_dir_meta=work_dir_meta, - context_file=context_file, - wire_file=WireFile(path=session_dir / "wire.jsonl"), - state=load_session_state(session_dir), - title="", - updated_at=0.0, - ) - if session.is_empty(): - logger.debug( - "Session context file is empty: {context_file}", context_file=context_file - ) - continue - await session.refresh() - sessions.append(session) - sessions.sort(key=lambda session: session.updated_at, reverse=True) - return sessions - - @staticmethod - async def continue_(work_dir: KaosPath) -> Session | None: - """Get the last session for a work directory.""" - work_dir = work_dir.canonical() - logger.debug("Continuing session for work directory: {work_dir}", work_dir=work_dir) - - metadata = load_metadata() - work_dir_meta = metadata.get_work_dir_meta(work_dir) - if work_dir_meta is None: - logger.debug("Work directory never been used") - return None - if work_dir_meta.last_session_id is None: - logger.debug("Work directory never had a session") - return None - - logger.debug( - "Found last session for work directory: {session_id}", - session_id=work_dir_meta.last_session_id, - ) - return await Session.find(work_dir, work_dir_meta.last_session_id) - - -def _migrate_session_context_file(work_dir_meta: WorkDirMeta, session_id: str) -> None: - old_context_file = work_dir_meta.sessions_dir / f"{session_id}.jsonl" - new_context_file = work_dir_meta.sessions_dir / session_id / "context.jsonl" - if old_context_file.exists() and not new_context_file.exists(): - new_context_file.parent.mkdir(parents=True, exist_ok=True) - old_context_file.rename(new_context_file) - logger.info( - "Migrated session context file from {old} to {new}", - old=old_context_file, - new=new_context_file, - ) diff --git a/src/kimi_cli/session.ts b/src/kimi_cli/session.ts new file mode 100644 index 000000000..2e085effe --- /dev/null +++ b/src/kimi_cli/session.ts @@ -0,0 +1,337 @@ +/** + * Session module — corresponds to Python session.py + session_state.py + * Manages per-workdir sessions with context files and state persistence. + */ + +import { z } from "zod/v4"; +import { join, resolve } from "node:path"; +import { createHash, randomUUID } from "node:crypto"; +import { getShareDir } from "./config.ts"; +import { logger } from "./utils/logging.ts"; +import { + loadMetadata, + saveMetadata, + getWorkDirMeta, + newWorkDirMeta, + getSessionsDir, + type Metadata, + type WorkDirMeta, +} from "./metadata.ts"; + +// ── Session State ─────────────────────────────────────── + +export const ApprovalStateData = z.object({ + yolo: z.boolean().default(false), + auto_approve_actions: z.array(z.string()).default([]), +}); +export type ApprovalStateData = z.infer; + +export const SessionState = z.object({ + version: z.number().int().default(1), + approval: ApprovalStateData.default({} as any), + additional_dirs: z.array(z.string()).default([]), + custom_title: z.string().nullable().default(null), + title_generated: z.boolean().default(false), + title_generate_attempts: z.number().int().default(0), + plan_mode: z.boolean().default(false), + plan_session_id: z.string().nullable().default(null), + plan_slug: z.string().nullable().default(null), + wire_mtime: z.number().nullable().default(null), + archived: z.boolean().default(false), + archived_at: z.number().nullable().default(null), + auto_archive_exempt: z.boolean().default(false), +}); +export type SessionState = z.infer; + +const STATE_FILE_NAME = "state.json"; + +export async function loadSessionState(sessionDir: string): Promise { + const stateFile = join(sessionDir, STATE_FILE_NAME); + const file = Bun.file(stateFile); + if (!(await file.exists())) { + return SessionState.parse({}); + } + try { + const data = await file.json(); + return SessionState.parse(data); + } catch { + logger.warn(`Corrupted state file, using defaults: ${stateFile}`); + return SessionState.parse({}); + } +} + +export async function saveSessionState(state: SessionState, sessionDir: string): Promise { + const stateFile = join(sessionDir, STATE_FILE_NAME); + await Bun.write(stateFile, JSON.stringify(state, null, 2)); +} + +// ── WorkDir Metadata (uses metadata.ts for Python-compatible MD5 hashing) ── + +function getSessionsBaseDir(workDir: string): string { + // Use MD5 hash of the work directory path, compatible with Python metadata.py + const pathMd5 = createHash("md5").update(workDir, "utf-8").digest("hex"); + return join(getShareDir(), "sessions", pathMd5); +} + +// ── Session class ─────────────────────────────────────── + +export class Session { + readonly id: string; + readonly workDir: string; + readonly sessionsDir: string; + readonly contextFile: string; + readonly wireFile: string; + state: SessionState; + title: string; + updatedAt: number; + + constructor(opts: { + id: string; + workDir: string; + sessionsDir: string; + contextFile: string; + wireFile: string; + state: SessionState; + title?: string; + updatedAt?: number; + }) { + this.id = opts.id; + this.workDir = opts.workDir; + this.sessionsDir = opts.sessionsDir; + this.contextFile = opts.contextFile; + this.wireFile = opts.wireFile; + this.state = opts.state; + this.title = opts.title ?? "Untitled"; + this.updatedAt = opts.updatedAt ?? 0; + } + + get dir(): string { + const path = join(this.sessionsDir, this.id); + // Note: directory creation is handled by save operations (saveState, create) + return path; + } + + get subagentsDir(): string { + return join(this.dir, "subagents"); + } + + /** Ensure the session directory exists (call before writing). */ + async ensureDir(): Promise { + const path = this.dir; + await Bun.$`mkdir -p ${path}`.quiet(); + return path; + } + + async isEmpty(): Promise { + if (this.state.custom_title) return false; + + const contextBunFile = Bun.file(this.contextFile); + if (!(await contextBunFile.exists())) return true; + + try { + const text = await contextBunFile.text(); + for (const line of text.split("\n")) { + const trimmed = line.trim(); + if (!trimmed) continue; + try { + const parsed = JSON.parse(trimmed); + if (typeof parsed.role === "string" && !parsed.role.startsWith("_")) { + return false; + } + } catch { + continue; + } + } + } catch { + return false; + } + return true; + } + + async saveState(): Promise { + await Bun.$`mkdir -p ${this.dir}`.quiet(); + + // Reload externally-mutable fields from disk first to avoid + // overwriting concurrent changes made by the web API (matches Python behavior). + const fresh = await loadSessionState(this.dir); + this.state.custom_title = fresh.custom_title; + this.state.title_generated = fresh.title_generated; + this.state.title_generate_attempts = fresh.title_generate_attempts; + this.state.archived = fresh.archived; + this.state.archived_at = fresh.archived_at; + this.state.auto_archive_exempt = fresh.auto_archive_exempt; + + await saveSessionState(this.state, this.dir); + } + + async delete(): Promise { + const sessionDir = join(this.sessionsDir, this.id); + const file = Bun.file(sessionDir); + if (await file.exists()) { + await Bun.$`rm -rf ${sessionDir}`.quiet(); + } + } + + async refresh(): Promise { + this.title = "Untitled"; + const contextBunFile = Bun.file(this.contextFile); + if (await contextBunFile.exists()) { + const stat = await Bun.$`stat -f %m ${this.contextFile} 2>/dev/null || stat -c %Y ${this.contextFile} 2>/dev/null`.quiet().text(); + this.updatedAt = Number.parseFloat(stat.trim()) || 0; + } else { + this.updatedAt = 0; + } + + if (this.state.custom_title) { + this.title = this.state.custom_title; + return; + } + + // Try to derive title from wire file first turn + const wireBunFile = Bun.file(this.wireFile); + if (await wireBunFile.exists()) { + try { + const text = await wireBunFile.text(); + for (const line of text.split("\n")) { + const trimmed = line.trim(); + if (!trimmed) continue; + try { + const record = JSON.parse(trimmed); + if (record.type === "turn_begin" && record.user_input) { + const raw = typeof record.user_input === "string" + ? record.user_input + : JSON.stringify(record.user_input); + this.title = raw.slice(0, 50); + return; + } + } catch { + continue; + } + } + } catch { + // ignore + } + } + } + + // ── Static factories ─────────────────────────────── + + static async create(workDir: string, sessionId?: string): Promise { + workDir = resolve(workDir); + + // Ensure work dir is tracked in global metadata + const metadata = await loadMetadata(); + let wdMeta = getWorkDirMeta(metadata, workDir); + if (!wdMeta) { + wdMeta = newWorkDirMeta(metadata, workDir); + } + + const sessionsDir = getSessionsBaseDir(workDir); + const id = sessionId ?? randomUUID(); + const sessionDir = join(sessionsDir, id); + await Bun.$`mkdir -p ${sessionDir}`.quiet(); + + const contextFile = join(sessionDir, "context.jsonl"); + // Truncate if exists + await Bun.write(contextFile, ""); + + await saveMetadata(metadata); + + const session = new Session({ + id, + workDir, + sessionsDir, + contextFile, + wireFile: join(sessionDir, "wire.jsonl"), + state: SessionState.parse({}), + }); + await session.refresh(); + return session; + } + + static async find(workDir: string, sessionId: string): Promise { + workDir = resolve(workDir); + const sessionsDir = getSessionsBaseDir(workDir); + const sessionDir = join(sessionsDir, sessionId); + + const dirFile = Bun.file(join(sessionDir, "context.jsonl")); + if (!(await dirFile.exists())) return null; + + const state = await loadSessionState(sessionDir); + const session = new Session({ + id: sessionId, + workDir, + sessionsDir, + contextFile: join(sessionDir, "context.jsonl"), + wireFile: join(sessionDir, "wire.jsonl"), + state, + }); + await session.refresh(); + return session; + } + + static async list(workDir: string): Promise { + workDir = resolve(workDir); + const sessionsDir = getSessionsBaseDir(workDir); + + const dirFile = Bun.file(sessionsDir); + if (!(await dirFile.exists())) return []; + + let entries: string[]; + try { + const output = await Bun.$`ls ${sessionsDir}`.quiet().text(); + entries = output.trim().split("\n").filter(Boolean); + } catch { + return []; + } + + const sessions: Session[] = []; + for (const entry of entries) { + const sessionDir = join(sessionsDir, entry); + const contextFile = join(sessionDir, "context.jsonl"); + const ctxFile = Bun.file(contextFile); + if (!(await ctxFile.exists())) continue; + + const state = await loadSessionState(sessionDir); + const session = new Session({ + id: entry, + workDir, + sessionsDir, + contextFile, + wireFile: join(sessionDir, "wire.jsonl"), + state, + }); + + if (await session.isEmpty()) continue; + await session.refresh(); + sessions.push(session); + } + + sessions.sort((a, b) => b.updatedAt - a.updatedAt); + return sessions; + } + + /** + * Continue the most recent session for a workDir. + * Returns the last session or null if none exists. + */ + static async continue_(workDir: string): Promise { + workDir = resolve(workDir); + + // Try global metadata first (Python-compatible) + const metadata = await loadMetadata(); + const wdMeta = getWorkDirMeta(metadata, workDir); + if (wdMeta?.lastSessionId) { + const session = await Session.find(workDir, wdMeta.lastSessionId); + if (session) return session; + } + + // Fallback: find the most recently updated session + const sessions = await Session.list(workDir); + if (sessions.length > 0) { + return sessions[0]!; + } + + return null; + } +} diff --git a/src/kimi_cli/session_state.py b/src/kimi_cli/session_state.py deleted file mode 100644 index f7a54bde9..000000000 --- a/src/kimi_cli/session_state.py +++ /dev/null @@ -1,118 +0,0 @@ -from __future__ import annotations - -import json -from pathlib import Path - -from pydantic import BaseModel, Field, ValidationError - -from kimi_cli.utils.io import atomic_json_write -from kimi_cli.utils.logging import logger - -STATE_FILE_NAME = "state.json" - - -class ApprovalStateData(BaseModel): - yolo: bool = False - auto_approve_actions: set[str] = Field(default_factory=set) - - -class SessionState(BaseModel): - version: int = 1 - approval: ApprovalStateData = Field(default_factory=ApprovalStateData) - additional_dirs: list[str] = Field(default_factory=list) - custom_title: str | None = None - title_generated: bool = False - title_generate_attempts: int = 0 - plan_mode: bool = False - plan_session_id: str | None = None - plan_slug: str | None = None - # Archive state (previously in metadata.json) - wire_mtime: float | None = None - archived: bool = False - archived_at: float | None = None - auto_archive_exempt: bool = False - - -_LEGACY_METADATA_FILENAME = "metadata.json" - - -def _migrate_legacy_metadata(session_dir: Path, state: SessionState) -> str: - """Migrate fields from legacy metadata.json into SessionState. - - Returns: - "migrated" - fields were merged into state, caller should save and delete legacy file - "no_change" - legacy file parsed but no fields needed, caller can delete legacy file - "skip" - legacy file missing or unreadable, caller should not touch it - """ - metadata_file = session_dir / _LEGACY_METADATA_FILENAME - if not metadata_file.exists(): - return "skip" - try: - data = json.loads(metadata_file.read_text(encoding="utf-8")) - except Exception: - # Leave the file intact for future retry — it may be temporarily unreadable - return "skip" - - changed = False - - # Migrate title fields (only if state has defaults) - if state.custom_title is None and data.get("title") and data["title"] != "Untitled": - state.custom_title = data["title"] - changed = True - if not state.title_generated and data.get("title_generated"): - state.title_generated = True - changed = True - if state.title_generate_attempts == 0 and data.get("title_generate_attempts", 0) > 0: - state.title_generate_attempts = data["title_generate_attempts"] - changed = True - - # Migrate archive fields - if not state.archived and data.get("archived"): - state.archived = True - changed = True - if state.archived_at is None and data.get("archived_at") is not None: - state.archived_at = data["archived_at"] - changed = True - if not state.auto_archive_exempt and data.get("auto_archive_exempt"): - state.auto_archive_exempt = True - changed = True - - # Migrate wire_mtime - if state.wire_mtime is None and data.get("wire_mtime") is not None: - state.wire_mtime = data["wire_mtime"] - changed = True - - return "migrated" if changed else "no_change" - - -def load_session_state(session_dir: Path) -> SessionState: - state_file = session_dir / STATE_FILE_NAME - if not state_file.exists(): - state = SessionState() - else: - try: - with open(state_file, encoding="utf-8") as f: - state = SessionState.model_validate(json.load(f)) - except (json.JSONDecodeError, ValidationError, UnicodeDecodeError): - logger.warning("Corrupted state file, using defaults: {path}", path=state_file) - state = SessionState() - - # One-time migration from legacy metadata.json (best-effort) - migration = _migrate_legacy_metadata(session_dir, state) - if migration in ("migrated", "no_change"): - try: - if migration == "migrated": - save_session_state(state, session_dir) - (session_dir / _LEGACY_METADATA_FILENAME).unlink(missing_ok=True) - except OSError: - logger.warning( - "Failed to persist migration for {path}, will retry next load", - path=session_dir, - ) - - return state - - -def save_session_state(state: SessionState, session_dir: Path) -> None: - state_file = session_dir / STATE_FILE_NAME - atomic_json_write(state.model_dump(mode="json"), state_file) diff --git a/src/kimi_cli/share.py b/src/kimi_cli/share.py deleted file mode 100644 index 485869008..000000000 --- a/src/kimi_cli/share.py +++ /dev/null @@ -1,14 +0,0 @@ -from __future__ import annotations - -import os -from pathlib import Path - - -def get_share_dir() -> Path: - """Get the share directory path.""" - if share_dir := os.getenv("KIMI_SHARE_DIR"): - share_dir = Path(share_dir) - else: - share_dir = Path.home() / ".kimi" - share_dir.mkdir(parents=True, exist_ok=True) - return share_dir diff --git a/src/kimi_cli/skill/__init__.py b/src/kimi_cli/skill/__init__.py deleted file mode 100644 index 1e42c7c67..000000000 --- a/src/kimi_cli/skill/__init__.py +++ /dev/null @@ -1,311 +0,0 @@ -"""Skill specification discovery and loading utilities.""" - -from __future__ import annotations - -from collections.abc import Callable, Iterable, Iterator, Sequence -from pathlib import Path -from typing import Literal - -from kaos import get_current_kaos -from kaos.local import local_kaos -from kaos.path import KaosPath -from pydantic import BaseModel, ConfigDict - -from kimi_cli import logger -from kimi_cli.skill.flow import Flow, FlowError -from kimi_cli.skill.flow.d2 import parse_d2_flowchart -from kimi_cli.skill.flow.mermaid import parse_mermaid_flowchart -from kimi_cli.utils.frontmatter import parse_frontmatter - -SkillType = Literal["standard", "flow"] - - -def get_builtin_skills_dir() -> Path: - """ - Get the built-in skills directory path. - """ - return Path(__file__).parent.parent / "skills" - - -def get_user_skills_dir_candidates() -> tuple[KaosPath, ...]: - """ - Get user-level skills directory candidates in priority order. - """ - return ( - KaosPath.home() / ".config" / "agents" / "skills", - KaosPath.home() / ".agents" / "skills", - KaosPath.home() / ".kimi" / "skills", - KaosPath.home() / ".claude" / "skills", - KaosPath.home() / ".codex" / "skills", - ) - - -def get_project_skills_dir_candidates(work_dir: KaosPath) -> tuple[KaosPath, ...]: - """ - Get project-level skills directory candidates in priority order. - """ - return ( - work_dir / ".agents" / "skills", - work_dir / ".kimi" / "skills", - work_dir / ".claude" / "skills", - work_dir / ".codex" / "skills", - ) - - -def _supports_builtin_skills() -> bool: - """Return True when the active KAOS backend can read bundled skills.""" - current_name = get_current_kaos().name - return current_name in (local_kaos.name, "acp") - - -async def find_first_existing_dir(candidates: Iterable[KaosPath]) -> KaosPath | None: - """ - Return the first existing directory from candidates. - """ - for candidate in candidates: - if await candidate.is_dir(): - return candidate - return None - - -async def find_user_skills_dir() -> KaosPath | None: - """ - Return the first existing user-level skills directory. - """ - return await find_first_existing_dir(get_user_skills_dir_candidates()) - - -async def find_project_skills_dir(work_dir: KaosPath) -> KaosPath | None: - """ - Return the first existing project-level skills directory. - """ - return await find_first_existing_dir(get_project_skills_dir_candidates(work_dir)) - - -async def resolve_skills_roots( - work_dir: KaosPath, - *, - skills_dirs: Sequence[KaosPath] | None = None, -) -> list[KaosPath]: - """ - Resolve layered skill roots in priority order. - - Built-in skills load first when supported by the active KAOS backend. - When custom directories are provided via ``--skills-dir``, they **override** - user/project discovery. Plugins are always discoverable. - """ - from kimi_cli.plugin.manager import get_plugins_dir - - roots: list[KaosPath] = [] - if _supports_builtin_skills(): - roots.append(KaosPath.unsafe_from_local_path(get_builtin_skills_dir())) - if skills_dirs: - roots.extend(skills_dirs) - else: - if user_dir := await find_user_skills_dir(): - roots.append(user_dir) - if project_dir := await find_project_skills_dir(work_dir): - roots.append(project_dir) - # Plugins are always discoverable - plugins_path = get_plugins_dir() - if plugins_path.is_dir(): - roots.append(KaosPath.unsafe_from_local_path(plugins_path)) - return roots - - -def normalize_skill_name(name: str) -> str: - """Normalize a skill name for lookup.""" - return name.casefold() - - -def index_skills(skills: Iterable[Skill]) -> dict[str, Skill]: - """Build a lookup table for skills by normalized name.""" - return {normalize_skill_name(skill.name): skill for skill in skills} - - -async def discover_skills_from_roots(skills_dirs: Iterable[KaosPath]) -> list[Skill]: - """ - Discover skills from multiple directory roots. - """ - skills_by_name: dict[str, Skill] = {} - for skills_dir in skills_dirs: - for skill in await discover_skills(skills_dir): - skills_by_name.setdefault(normalize_skill_name(skill.name), skill) - return sorted(skills_by_name.values(), key=lambda s: s.name) - - -async def read_skill_text(skill: Skill) -> str | None: - """Read the SKILL.md contents for a skill.""" - try: - return (await skill.skill_md_file.read_text(encoding="utf-8")).strip() - except OSError as exc: - logger.warning( - "Failed to read skill file {path}: {error}", - path=skill.skill_md_file, - error=exc, - ) - return None - - -class Skill(BaseModel): - """Information about a single skill.""" - - model_config = ConfigDict(extra="ignore", arbitrary_types_allowed=True) - - name: str - description: str - type: SkillType = "standard" - dir: KaosPath - flow: Flow | None = None - - @property - def skill_md_file(self) -> KaosPath: - """Path to the SKILL.md file.""" - return self.dir / "SKILL.md" - - -async def discover_skills(skills_dir: KaosPath) -> list[Skill]: - """ - Discover all skills in the given directory. - - Args: - skills_dir: Kaos path to the directory containing skills. - - Returns: - List of Skill objects, one for each valid skill found. - """ - if not await skills_dir.is_dir(): - return [] - - skills: list[Skill] = [] - - async for skill_dir in skills_dir.iterdir(): - if not await skill_dir.is_dir(): - continue - - skill_md = skill_dir / "SKILL.md" - if not await skill_md.is_file(): - continue - - try: - content = await skill_md.read_text(encoding="utf-8") - skills.append(parse_skill_text(content, dir_path=skill_dir)) - except Exception as exc: - logger.info("Skipping invalid skill at {}: {}", skill_md, exc) - continue - - return sorted(skills, key=lambda s: s.name) - - -def parse_skill_text(content: str, *, dir_path: KaosPath) -> Skill: - """ - Parse SKILL.md contents to extract name and description. - """ - frontmatter = parse_frontmatter(content) or {} - - name = frontmatter.get("name") or dir_path.name - description = frontmatter.get("description") or "No description provided." - skill_type = frontmatter.get("type") or "standard" - if skill_type not in ("standard", "flow"): - raise ValueError(f'Invalid skill type "{skill_type}"') - flow = None - if skill_type == "flow": - try: - flow = _parse_flow_from_skill(content) - except ValueError as exc: - logger.error("Failed to parse flow skill {name}: {error}", name=name, error=exc) - skill_type = "standard" - flow = None - - return Skill( - name=name, - description=description, - type=skill_type, - dir=dir_path, - flow=flow, - ) - - -def _parse_flow_from_skill(content: str) -> Flow: - for lang, code in _iter_fenced_codeblocks(content): - if lang == "mermaid": - return _parse_flow_block(parse_mermaid_flowchart, code) - if lang == "d2": - return _parse_flow_block(parse_d2_flowchart, code) - raise ValueError("Flow skills require a mermaid or d2 code block in SKILL.md.") - - -def _parse_flow_block(parser: Callable[[str], Flow], code: str) -> Flow: - try: - return parser(code) - except FlowError as exc: - raise ValueError(f"Invalid flow diagram: {exc}") from exc - - -def _iter_fenced_codeblocks(content: str) -> Iterator[tuple[str, str]]: - fence = "" - fence_char = "" - lang = "" - buf: list[str] = [] - in_block = False - - for line in content.splitlines(): - stripped = line.lstrip() - if not in_block: - if match := _parse_fence_open(stripped): - fence, fence_char, info = match - lang = _normalize_code_lang(info) - in_block = True - buf = [] - continue - - if _is_fence_close(stripped, fence_char, len(fence)): - yield lang, "\n".join(buf).strip("\n") - in_block = False - fence = "" - fence_char = "" - lang = "" - buf = [] - continue - - buf.append(line) - - -def _normalize_code_lang(info: str) -> str: - if not info: - return "" - lang = info.split()[0].strip().lower() - if lang.startswith("{") and lang.endswith("}"): - lang = lang[1:-1].strip() - return lang - - -def _parse_fence_open(line: str) -> tuple[str, str, str] | None: - if not line or line[0] not in ("`", "~"): - return None - fence_char = line[0] - count = 0 - for ch in line: - if ch == fence_char: - count += 1 - else: - break - if count < 3: - return None - fence = fence_char * count - info = line[count:].strip() - return fence, fence_char, info - - -def _is_fence_close(line: str, fence_char: str, fence_len: int) -> bool: - if not fence_char or not line or line[0] != fence_char: - return False - count = 0 - for ch in line: - if ch == fence_char: - count += 1 - else: - break - if count < fence_len: - return False - return not line[count:].strip() diff --git a/src/kimi_cli/skill/flow/__init__.py b/src/kimi_cli/skill/flow/__init__.py deleted file mode 100644 index 7f2541a2e..000000000 --- a/src/kimi_cli/skill/flow/__init__.py +++ /dev/null @@ -1,99 +0,0 @@ -from __future__ import annotations - -import re -from dataclasses import dataclass -from typing import Literal - -from kosong.message import ContentPart - -FlowNodeKind = Literal["begin", "end", "task", "decision"] - - -class FlowError(ValueError): - """Base error for flow parsing/validation.""" - - -class FlowParseError(FlowError): - """Raised when prompt flow parsing fails.""" - - -class FlowValidationError(FlowError): - """Raised when a flowchart fails validation.""" - - -@dataclass(frozen=True, slots=True) -class FlowNode: - id: str - label: str | list[ContentPart] - kind: FlowNodeKind - - -@dataclass(frozen=True, slots=True) -class FlowEdge: - src: str - dst: str - label: str | None - - -@dataclass(slots=True) -class Flow: - nodes: dict[str, FlowNode] - outgoing: dict[str, list[FlowEdge]] - begin_id: str - end_id: str - - -_CHOICE_RE = re.compile(r"([^<]*)") - - -def parse_choice(text: str) -> str | None: - matches = _CHOICE_RE.findall(text or "") - if not matches: - return None - return matches[-1].strip() - - -def validate_flow( - nodes: dict[str, FlowNode], - outgoing: dict[str, list[FlowEdge]], -) -> tuple[str, str]: - begin_ids = [node.id for node in nodes.values() if node.kind == "begin"] - end_ids = [node.id for node in nodes.values() if node.kind == "end"] - - if len(begin_ids) != 1: - raise FlowValidationError(f"Expected exactly one BEGIN node, found {len(begin_ids)}") - if len(end_ids) != 1: - raise FlowValidationError(f"Expected exactly one END node, found {len(end_ids)}") - - begin_id = begin_ids[0] - end_id = end_ids[0] - - reachable: set[str] = set() - queue: list[str] = [begin_id] - while queue: - node_id = queue.pop() - if node_id in reachable: - continue - reachable.add(node_id) - for edge in outgoing.get(node_id, []): - if edge.dst not in reachable: - queue.append(edge.dst) - - for node in nodes.values(): - if node.id not in reachable: - continue - edges = outgoing.get(node.id, []) - if len(edges) <= 1: - continue - labels: list[str] = [] - for edge in edges: - if edge.label is None or not edge.label.strip(): - raise FlowValidationError(f'Node "{node.id}" has an unlabeled edge') - labels.append(edge.label) - if len(set(labels)) != len(labels): - raise FlowValidationError(f'Node "{node.id}" has duplicate edge labels') - - if end_id not in reachable: - raise FlowValidationError("END node is not reachable from BEGIN") - - return begin_id, end_id diff --git a/src/kimi_cli/skill/flow/d2.py b/src/kimi_cli/skill/flow/d2.py deleted file mode 100644 index 9c5a74290..000000000 --- a/src/kimi_cli/skill/flow/d2.py +++ /dev/null @@ -1,482 +0,0 @@ -from __future__ import annotations - -import re -from collections.abc import Iterable -from dataclasses import dataclass - -from . import ( - Flow, - FlowEdge, - FlowNode, - FlowNodeKind, - FlowParseError, - validate_flow, -) - -_NODE_ID_RE = re.compile(r"[A-Za-z0-9_][A-Za-z0-9_./-]*") -_BLOCK_TAG_RE = re.compile(r"^\|md$") -_PROPERTY_SEGMENTS = { - "shape", - "style", - "label", - "link", - "icon", - "near", - "width", - "height", - "direction", - "grid-rows", - "grid-columns", - "grid-gap", - "font-size", - "font-family", - "font-color", - "stroke", - "fill", - "opacity", - "padding", - "border-radius", - "shadow", - "sketch", - "animated", - "multiple", - "constraint", - "tooltip", -} - - -@dataclass(frozen=True, slots=True) -class _NodeDef: - node: FlowNode - explicit: bool - - -def parse_d2_flowchart(text: str) -> Flow: - # Normalize D2 markdown blocks into quoted labels so the parser can stay line-based. - text = _normalize_markdown_blocks(text) - nodes: dict[str, _NodeDef] = {} - outgoing: dict[str, list[FlowEdge]] = {} - - for line_no, statement in _iter_top_level_statements(text): - if _has_unquoted_token(statement, "->"): - _parse_edge_statement(statement, line_no, nodes, outgoing) - else: - _parse_node_statement(statement, line_no, nodes) - - flow_nodes = {node_id: node_def.node for node_id, node_def in nodes.items()} - for node_id in flow_nodes: - outgoing.setdefault(node_id, []) - - flow_nodes = _infer_decision_nodes(flow_nodes, outgoing) - begin_id, end_id = validate_flow(flow_nodes, outgoing) - return Flow(nodes=flow_nodes, outgoing=outgoing, begin_id=begin_id, end_id=end_id) - - -def _normalize_markdown_blocks(text: str) -> str: - normalized = text.replace("\r\n", "\n").replace("\r", "\n") - lines = normalized.split("\n") - out_lines: list[str] = [] - i = 0 - line_no = 1 - - while i < len(lines): - line = lines[i] - prefix, suffix = _split_unquoted_once(line, ":") - if suffix is None: - out_lines.append(line) - i += 1 - line_no += 1 - continue - - suffix_clean = _strip_unquoted_comment(suffix).strip() - # Only treat `: |md` as a markdown block starter. - if not _BLOCK_TAG_RE.fullmatch(suffix_clean): - out_lines.append(line) - i += 1 - line_no += 1 - continue - - start_line = line_no - block_lines: list[str] = [] - i += 1 - line_no += 1 - while i < len(lines): - block_line = lines[i] - if block_line.strip() == "|": - break - block_lines.append(block_line) - i += 1 - line_no += 1 - if i >= len(lines): - raise FlowParseError(_line_error(start_line, "Unclosed markdown block")) - - # Convert the block into a multiline quoted string label. - dedented = _dedent_block(block_lines) - if dedented: - escaped = [_escape_quoted_line(line) for line in dedented] - out_lines.append(f'{prefix}: "{escaped[0]}') - for line in escaped[1:]: - out_lines.append(line) - out_lines[-1] = f'{out_lines[-1]}"' - out_lines.extend(["", ""]) - else: - out_lines.append(f'{prefix}: ""') - out_lines.append("") - - i += 1 - line_no += 1 - - return "\n".join(out_lines) - - -def _strip_unquoted_comment(text: str) -> str: - in_single = False - in_double = False - escape = False - for idx, ch in enumerate(text): - if escape: - escape = False - continue - if ch == "\\" and (in_single or in_double): - escape = True - continue - if ch == "'" and not in_double: - in_single = not in_single - continue - if ch == '"' and not in_single: - in_double = not in_double - continue - if ch == "#" and not in_single and not in_double: - return text[:idx] - return text - - -def _dedent_block(lines: list[str]) -> list[str]: - indent: int | None = None - for line in lines: - if not line.strip(): - continue - stripped = line.lstrip(" \t") - lead = len(line) - len(stripped) - if indent is None or lead < indent: - indent = lead - if indent is None: - return ["" for _ in lines] - return [line[indent:] if len(line) >= indent else "" for line in lines] - - -def _escape_quoted_line(line: str) -> str: - return line.replace("\\", "\\\\").replace('"', '\\"') - - -def _iter_top_level_statements(text: str) -> Iterable[tuple[int, str]]: - text = text.replace("\r\n", "\n").replace("\r", "\n") - brace_depth = 0 - in_single = False - in_double = False - escape = False - drop_line = False - buf: list[str] = [] - line_no = 1 - stmt_line = 1 - i = 0 - - while i < len(text): - ch = text[i] - next_ch = text[i + 1] if i + 1 < len(text) else "" - - if ch == "\\" and next_ch == "\n": - i += 2 - line_no += 1 - continue - - if ch == "\n": - # Preserve newlines inside quoted strings (used for markdown block labels). - if (in_single or in_double) and brace_depth == 0 and not drop_line: - buf.append("\n") - line_no += 1 - i += 1 - continue - if brace_depth == 0 and not in_single and not in_double and not drop_line: - statement = "".join(buf).strip() - if statement: - yield stmt_line, statement - buf = [] - drop_line = False - stmt_line = line_no + 1 - line_no += 1 - i += 1 - continue - - if not in_single and not in_double: - if ch == "#": - while i < len(text) and text[i] != "\n": - i += 1 - continue - if ch == "{": - if brace_depth == 0: - statement = "".join(buf).strip() - if statement: - yield stmt_line, statement - drop_line = True - buf.clear() - brace_depth += 1 - i += 1 - continue - if ch == "}" and brace_depth > 0: - brace_depth -= 1 - i += 1 - continue - if ch == "}" and brace_depth == 0: - raise FlowParseError(_line_error(line_no, "Unmatched '}'")) - - if ch == "'" and not in_double and not escape: - in_single = not in_single - elif ch == '"' and not in_single and not escape: - in_double = not in_double - - if escape: - escape = False - elif ch == "\\" and (in_single or in_double): - escape = True - - if brace_depth == 0 and not drop_line: - buf.append(ch) - - i += 1 - - if brace_depth != 0: - raise FlowParseError(_line_error(line_no, "Unclosed '{' block")) - if in_single or in_double: - raise FlowParseError(_line_error(line_no, "Unclosed string")) - - statement = "".join(buf).strip() - if statement: - yield stmt_line, statement - - -def _has_unquoted_token(text: str, token: str) -> bool: - parts = _split_on_token(text, token) - return len(parts) > 1 - - -def _parse_edge_statement( - statement: str, - line_no: int, - nodes: dict[str, _NodeDef], - outgoing: dict[str, list[FlowEdge]], -) -> None: - parts = _split_on_token(statement, "->") - if len(parts) < 2: - raise FlowParseError(_line_error(line_no, "Expected edge arrow")) - - last_part = parts[-1] - target_text, edge_label = _split_unquoted_once(last_part, ":") - parts[-1] = target_text - - node_ids: list[str] = [] - for idx, part in enumerate(parts): - node_id = _parse_node_id(part, line_no, allow_inline_label=(idx < len(parts) - 1)) - node_ids.append(node_id) - - if any(_is_property_path(node_id) for node_id in node_ids): - return - if len(node_ids) < 2: - raise FlowParseError(_line_error(line_no, "Edge must have at least two nodes")) - - label = _parse_label(edge_label, line_no) if edge_label is not None else None - for idx in range(len(node_ids) - 1): - edge = FlowEdge( - src=node_ids[idx], - dst=node_ids[idx + 1], - label=label if idx == len(node_ids) - 2 else None, - ) - outgoing.setdefault(edge.src, []).append(edge) - outgoing.setdefault(edge.dst, []) - - for node_id in node_ids: - _add_node(nodes, node_id=node_id, label=None, explicit=False, line_no=line_no) - - -def _parse_node_statement(statement: str, line_no: int, nodes: dict[str, _NodeDef]) -> None: - node_text, label_text = _split_unquoted_once(statement, ":") - if label_text is not None and _is_property_path(node_text): - return - node_id = _parse_node_id(node_text, line_no, allow_inline_label=False) - label = None - explicit = False - if label_text is not None and not label_text.strip(): - return - if label_text is not None: - label = _parse_label(label_text, line_no) - explicit = True - _add_node(nodes, node_id=node_id, label=label, explicit=explicit, line_no=line_no) - - -def _parse_node_id(text: str, line_no: int, *, allow_inline_label: bool) -> str: - cleaned = text.strip() - if allow_inline_label and ":" in cleaned: - cleaned = _split_unquoted_once(cleaned, ":")[0].strip() - if not cleaned: - raise FlowParseError(_line_error(line_no, "Expected node id")) - match = _NODE_ID_RE.fullmatch(cleaned) - if not match: - raise FlowParseError(_line_error(line_no, f'Invalid node id "{cleaned}"')) - return match.group(0) - - -def _is_property_path(node_id: str) -> bool: - if "." not in node_id: - return False - parts = [part for part in node_id.split(".") if part] - for part in parts[1:]: - if part in _PROPERTY_SEGMENTS or part.startswith("style"): - return True - return parts[-1] in _PROPERTY_SEGMENTS - - -def _parse_label(text: str, line_no: int) -> str: - label = text.strip() - if not label: - raise FlowParseError(_line_error(line_no, "Label cannot be empty")) - if label[0] in {"'", '"'}: - return _parse_quoted_label(label, line_no) - return label - - -def _parse_quoted_label(text: str, line_no: int) -> str: - quote = text[0] - buf: list[str] = [] - escape = False - i = 1 - while i < len(text): - ch = text[i] - if escape: - buf.append(ch) - escape = False - i += 1 - continue - if ch == "\\": - escape = True - i += 1 - continue - if ch == quote: - trailing = text[i + 1 :].strip() - if trailing: - raise FlowParseError(_line_error(line_no, "Unexpected trailing content")) - return "".join(buf) - buf.append(ch) - i += 1 - raise FlowParseError(_line_error(line_no, "Unclosed quoted label")) - - -def _split_on_token(text: str, token: str) -> list[str]: - parts: list[str] = [] - buf: list[str] = [] - in_single = False - in_double = False - escape = False - i = 0 - - while i < len(text): - if not in_single and not in_double and text.startswith(token, i): - parts.append("".join(buf).strip()) - buf = [] - i += len(token) - continue - ch = text[i] - if escape: - escape = False - elif ch == "\\" and (in_single or in_double): - escape = True - elif ch == "'" and not in_double: - in_single = not in_single - elif ch == '"' and not in_single: - in_double = not in_double - buf.append(ch) - i += 1 - - if in_single or in_double: - raise FlowParseError("Unclosed string in statement") - parts.append("".join(buf).strip()) - return parts - - -def _split_unquoted_once(text: str, token: str) -> tuple[str, str | None]: - in_single = False - in_double = False - escape = False - for idx, ch in enumerate(text): - if escape: - escape = False - continue - if ch == "\\" and (in_single or in_double): - escape = True - continue - if ch == "'" and not in_double: - in_single = not in_single - continue - if ch == '"' and not in_single: - in_double = not in_double - continue - if ch == token and not in_single and not in_double: - return text[:idx].strip(), text[idx + 1 :].strip() - return text.strip(), None - - -def _add_node( - nodes: dict[str, _NodeDef], - *, - node_id: str, - label: str | None, - explicit: bool, - line_no: int, -) -> FlowNode: - label = label if label is not None else node_id - label_norm = label.strip().lower() - if not label: - raise FlowParseError(_line_error(line_no, "Node label cannot be empty")) - - kind: FlowNodeKind = "task" - if label_norm == "begin": - kind = "begin" - elif label_norm == "end": - kind = "end" - - node = FlowNode(id=node_id, label=label, kind=kind) - existing = nodes.get(node_id) - if existing is None: - nodes[node_id] = _NodeDef(node=node, explicit=explicit) - return node - - if existing.node == node: - return existing.node - - if not explicit and existing.explicit: - return existing.node - - if explicit and not existing.explicit: - nodes[node_id] = _NodeDef(node=node, explicit=True) - return node - - raise FlowParseError(_line_error(line_no, f'Conflicting definition for node "{node_id}"')) - - -def _infer_decision_nodes( - nodes: dict[str, FlowNode], - outgoing: dict[str, list[FlowEdge]], -) -> dict[str, FlowNode]: - updated: dict[str, FlowNode] = {} - for node_id, node in nodes.items(): - kind = node.kind - if kind == "task" and len(outgoing.get(node_id, [])) > 1: - kind = "decision" - if kind != node.kind: - updated[node_id] = FlowNode(id=node.id, label=node.label, kind=kind) - else: - updated[node_id] = node - return updated - - -def _line_error(line_no: int, message: str) -> str: - return f"Line {line_no}: {message}" diff --git a/src/kimi_cli/skill/flow/d2.ts b/src/kimi_cli/skill/flow/d2.ts new file mode 100644 index 000000000..f357e62f5 --- /dev/null +++ b/src/kimi_cli/skill/flow/d2.ts @@ -0,0 +1,435 @@ +/** + * D2 flowchart parser — corresponds to Python skill/flow/d2.py + */ + +import { + type Flow, + type FlowEdge, + type FlowNode, + type FlowNodeKind, + FlowParseError, + validateFlow, +} from "./index.ts"; + +const NODE_ID_RE = /^[A-Za-z0-9_][A-Za-z0-9_./-]*/; +const BLOCK_TAG_RE = /^\|md$/; +const PROPERTY_SEGMENTS = new Set([ + "shape", "style", "label", "link", "icon", "near", "width", "height", + "direction", "grid-rows", "grid-columns", "grid-gap", "font-size", + "font-family", "font-color", "stroke", "fill", "opacity", "padding", + "border-radius", "shadow", "sketch", "animated", "multiple", + "constraint", "tooltip", +]); + +interface NodeDef { + node: FlowNode; + explicit: boolean; +} + +export function parseD2Flowchart(text: string): Flow { + const normalized = normalizeMarkdownBlocks(text); + const nodes = new Map(); + const outgoing = new Map(); + + for (const [lineNo, statement] of iterTopLevelStatements(normalized)) { + if (hasUnquotedToken(statement, "->")) { + parseEdgeStatement(statement, lineNo, nodes, outgoing); + } else { + parseNodeStatement(statement, lineNo, nodes); + } + } + + const flowNodes: Record = {}; + for (const [id, def] of nodes) { + flowNodes[id] = def.node; + if (!outgoing.has(id)) outgoing.set(id, []); + } + + const outgoingRecord: Record = {}; + for (const [k, v] of outgoing) outgoingRecord[k] = v; + + const inferred = inferDecisionNodes(flowNodes, outgoingRecord); + const [beginId, endId] = validateFlow(inferred, outgoingRecord); + return { nodes: inferred, outgoing: outgoingRecord, beginId, endId }; +} + +function normalizeMarkdownBlocks(text: string): string { + const normalized = text.replace(/\r\n/g, "\n").replace(/\r/g, "\n"); + const lines = normalized.split("\n"); + const outLines: string[] = []; + let i = 0; + let lineNo = 1; + + while (i < lines.length) { + const line = lines[i]!; + const [prefix, suffix] = splitUnquotedOnce(line, ":"); + + if (suffix == null) { + outLines.push(line); + i++; + lineNo++; + continue; + } + + const suffixClean = stripUnquotedComment(suffix).trim(); + if (!BLOCK_TAG_RE.test(suffixClean)) { + outLines.push(line); + i++; + lineNo++; + continue; + } + + const startLine = lineNo; + const blockLines: string[] = []; + i++; + lineNo++; + while (i < lines.length) { + const blockLine = lines[i]!; + if (blockLine.trim() === "|") break; + blockLines.push(blockLine); + i++; + lineNo++; + } + if (i >= lines.length) { + throw new FlowParseError(lineError(startLine, "Unclosed markdown block")); + } + + const dedented = dedentBlock(blockLines); + if (dedented.length > 0 && dedented.some((l) => l.length > 0)) { + const escaped = dedented.map(escapeQuotedLine); + outLines.push(`${prefix}: "${escaped[0]}`); + for (let j = 1; j < escaped.length; j++) { + outLines.push(escaped[j]!); + } + outLines[outLines.length - 1] = `${outLines[outLines.length - 1]}"`; + outLines.push("", ""); + } else { + outLines.push(`${prefix}: ""`); + outLines.push(""); + } + + i++; + lineNo++; + } + + return outLines.join("\n"); +} + +function stripUnquotedComment(text: string): string { + let inSingle = false; + let inDouble = false; + let escape = false; + for (let idx = 0; idx < text.length; idx++) { + const ch = text[idx]!; + if (escape) { escape = false; continue; } + if (ch === "\\" && (inSingle || inDouble)) { escape = true; continue; } + if (ch === "'" && !inDouble) { inSingle = !inSingle; continue; } + if (ch === '"' && !inSingle) { inDouble = !inDouble; continue; } + if (ch === "#" && !inSingle && !inDouble) return text.slice(0, idx); + } + return text; +} + +function dedentBlock(lines: string[]): string[] { + let indent: number | undefined; + for (const line of lines) { + if (!line.trim()) continue; + const stripped = line.replace(/^[ \t]+/, ""); + const lead = line.length - stripped.length; + if (indent === undefined || lead < indent) indent = lead; + } + if (indent === undefined) return lines.map(() => ""); + return lines.map((line) => (line.length >= indent! ? line.slice(indent!) : "")); +} + +function escapeQuotedLine(line: string): string { + return line.replace(/\\/g, "\\\\").replace(/"/g, '\\"'); +} + +function* iterTopLevelStatements(text: string): Generator<[number, string]> { + const normalized = text.replace(/\r\n/g, "\n").replace(/\r/g, "\n"); + let braceDepth = 0; + let inSingle = false; + let inDouble = false; + let escape = false; + let dropLine = false; + let buf: string[] = []; + let lineNo = 1; + let stmtLine = 1; + let i = 0; + + while (i < normalized.length) { + const ch = normalized[i]!; + const nextCh = i + 1 < normalized.length ? normalized[i + 1]! : ""; + + if (ch === "\\" && nextCh === "\n") { + i += 2; + lineNo++; + continue; + } + + if (ch === "\n") { + if ((inSingle || inDouble) && braceDepth === 0 && !dropLine) { + buf.push("\n"); + lineNo++; + i++; + continue; + } + if (braceDepth === 0 && !inSingle && !inDouble && !dropLine) { + const statement = buf.join("").trim(); + if (statement) yield [stmtLine, statement]; + } + buf = []; + dropLine = false; + stmtLine = lineNo + 1; + lineNo++; + i++; + continue; + } + + if (!inSingle && !inDouble) { + if (ch === "#") { + while (i < normalized.length && normalized[i] !== "\n") i++; + continue; + } + if (ch === "{") { + if (braceDepth === 0) { + const statement = buf.join("").trim(); + if (statement) yield [stmtLine, statement]; + dropLine = true; + buf = []; + } + braceDepth++; + i++; + continue; + } + if (ch === "}" && braceDepth > 0) { + braceDepth--; + i++; + continue; + } + if (ch === "}" && braceDepth === 0) { + throw new FlowParseError(lineError(lineNo, "Unmatched '}'")); + } + } + + if (ch === "'" && !inDouble && !escape) inSingle = !inSingle; + else if (ch === '"' && !inSingle && !escape) inDouble = !inDouble; + + if (escape) escape = false; + else if (ch === "\\" && (inSingle || inDouble)) escape = true; + + if (braceDepth === 0 && !dropLine) buf.push(ch); + i++; + } + + if (braceDepth !== 0) throw new FlowParseError(lineError(lineNo, "Unclosed '{' block")); + if (inSingle || inDouble) throw new FlowParseError(lineError(lineNo, "Unclosed string")); + + const statement = buf.join("").trim(); + if (statement) yield [stmtLine, statement]; +} + +function hasUnquotedToken(text: string, token: string): boolean { + return splitOnToken(text, token).length > 1; +} + +function parseEdgeStatement( + statement: string, + lineNo: number, + nodes: Map, + outgoing: Map, +): void { + const parts = splitOnToken(statement, "->"); + if (parts.length < 2) throw new FlowParseError(lineError(lineNo, "Expected edge arrow")); + + const lastPart = parts[parts.length - 1]!; + const [targetText, edgeLabel] = splitUnquotedOnce(lastPart, ":"); + parts[parts.length - 1] = targetText; + + const nodeIds: string[] = []; + for (let idx = 0; idx < parts.length; idx++) { + const nodeId = parseNodeId(parts[idx]!, lineNo, idx < parts.length - 1); + nodeIds.push(nodeId); + } + + if (nodeIds.some(isPropertyPath)) return; + if (nodeIds.length < 2) throw new FlowParseError(lineError(lineNo, "Edge must have at least two nodes")); + + const label = edgeLabel != null ? parseLabelText(edgeLabel, lineNo) : undefined; + for (let idx = 0; idx < nodeIds.length - 1; idx++) { + const edge: FlowEdge = { + src: nodeIds[idx]!, + dst: nodeIds[idx + 1]!, + label: idx === nodeIds.length - 2 ? label : undefined, + }; + if (!outgoing.has(edge.src)) outgoing.set(edge.src, []); + outgoing.get(edge.src)!.push(edge); + if (!outgoing.has(edge.dst)) outgoing.set(edge.dst, []); + } + + for (const nodeId of nodeIds) { + addNode(nodes, nodeId, undefined, false, lineNo); + } +} + +function parseNodeStatement(statement: string, lineNo: number, nodes: Map): void { + const [nodeText, labelText] = splitUnquotedOnce(statement, ":"); + if (labelText != null && isPropertyPath(nodeText)) return; + const nodeId = parseNodeId(nodeText, lineNo, false); + let label: string | undefined; + let explicit = false; + if (labelText != null && !labelText.trim()) return; + if (labelText != null) { + label = parseLabelText(labelText, lineNo); + explicit = true; + } + addNode(nodes, nodeId, label, explicit, lineNo); +} + +function parseNodeId(text: string, lineNo: number, allowInlineLabel: boolean): string { + let cleaned = text.trim(); + if (allowInlineLabel && cleaned.includes(":")) { + cleaned = splitUnquotedOnce(cleaned, ":")[0].trim(); + } + if (!cleaned) throw new FlowParseError(lineError(lineNo, "Expected node id")); + const match = cleaned.match(NODE_ID_RE); + if (!match || match[0] !== cleaned) { + throw new FlowParseError(lineError(lineNo, `Invalid node id "${cleaned}"`)); + } + return match[0]!; +} + +function isPropertyPath(nodeId: string): boolean { + if (!nodeId.includes(".")) return false; + const parts = nodeId.split(".").filter(Boolean); + for (let i = 1; i < parts.length; i++) { + if (PROPERTY_SEGMENTS.has(parts[i]!) || parts[i]!.startsWith("style")) return true; + } + return PROPERTY_SEGMENTS.has(parts[parts.length - 1]!); +} + +function parseLabelText(text: string, lineNo: number): string { + const label = text.trim(); + if (!label) throw new FlowParseError(lineError(lineNo, "Label cannot be empty")); + if (label[0] === "'" || label[0] === '"') return parseQuotedLabel(label, lineNo); + return label; +} + +function parseQuotedLabel(text: string, lineNo: number): string { + const quote = text[0]!; + const buf: string[] = []; + let escape = false; + let i = 1; + while (i < text.length) { + const ch = text[i]!; + if (escape) { buf.push(ch); escape = false; i++; continue; } + if (ch === "\\") { escape = true; i++; continue; } + if (ch === quote) { + const trailing = text.slice(i + 1).trim(); + if (trailing) throw new FlowParseError(lineError(lineNo, "Unexpected trailing content")); + return buf.join(""); + } + buf.push(ch); + i++; + } + throw new FlowParseError(lineError(lineNo, "Unclosed quoted label")); +} + +function splitOnToken(text: string, token: string): string[] { + const parts: string[] = []; + let buf: string[] = []; + let inSingle = false; + let inDouble = false; + let escape = false; + let i = 0; + + while (i < text.length) { + if (!inSingle && !inDouble && text.startsWith(token, i)) { + parts.push(buf.join("").trim()); + buf = []; + i += token.length; + continue; + } + const ch = text[i]!; + if (escape) escape = false; + else if (ch === "\\" && (inSingle || inDouble)) escape = true; + else if (ch === "'" && !inDouble) inSingle = !inSingle; + else if (ch === '"' && !inSingle) inDouble = !inDouble; + buf.push(ch); + i++; + } + if (inSingle || inDouble) throw new FlowParseError("Unclosed string in statement"); + parts.push(buf.join("").trim()); + return parts; +} + +function splitUnquotedOnce(text: string, token: string): [string, string | undefined] { + let inSingle = false; + let inDouble = false; + let escape = false; + for (let idx = 0; idx < text.length; idx++) { + const ch = text[idx]!; + if (escape) { escape = false; continue; } + if (ch === "\\" && (inSingle || inDouble)) { escape = true; continue; } + if (ch === "'" && !inDouble) { inSingle = !inSingle; continue; } + if (ch === '"' && !inSingle) { inDouble = !inDouble; continue; } + if (ch === token && !inSingle && !inDouble) { + return [text.slice(0, idx).trim(), text.slice(idx + 1).trim()]; + } + } + return [text.trim(), undefined]; +} + +function addNode( + nodes: Map, + nodeId: string, + label: string | undefined, + explicit: boolean, + lineNo: number, +): FlowNode { + const effectiveLabel = label ?? nodeId; + const labelNorm = effectiveLabel.trim().toLowerCase(); + if (!effectiveLabel) throw new FlowParseError(lineError(lineNo, "Node label cannot be empty")); + + let kind: FlowNodeKind = "task"; + if (labelNorm === "begin") kind = "begin"; + else if (labelNorm === "end") kind = "end"; + + const node: FlowNode = { id: nodeId, label: effectiveLabel, kind }; + const existing = nodes.get(nodeId); + + if (!existing) { + nodes.set(nodeId, { node, explicit }); + return node; + } + + if (existing.node.id === node.id && existing.node.label === node.label && existing.node.kind === node.kind) { + return existing.node; + } + + if (!explicit && existing.explicit) return existing.node; + if (explicit && !existing.explicit) { + nodes.set(nodeId, { node, explicit: true }); + return node; + } + + throw new FlowParseError(lineError(lineNo, `Conflicting definition for node "${nodeId}"`)); +} + +function inferDecisionNodes( + nodes: Record, + outgoing: Record, +): Record { + const updated: Record = {}; + for (const [nodeId, node] of Object.entries(nodes)) { + let kind = node.kind; + if (kind === "task" && (outgoing[nodeId]?.length ?? 0) > 1) kind = "decision"; + updated[nodeId] = kind !== node.kind ? { id: node.id, label: node.label, kind } : node; + } + return updated; +} + +function lineError(lineNo: number, message: string): string { + return `Line ${lineNo}: ${message}`; +} diff --git a/src/kimi_cli/skill/flow/index.ts b/src/kimi_cli/skill/flow/index.ts new file mode 100644 index 000000000..87f337ec3 --- /dev/null +++ b/src/kimi_cli/skill/flow/index.ts @@ -0,0 +1,110 @@ +/** + * Flow graph types and validation — corresponds to Python skill/flow/__init__.py + */ + +export type FlowNodeKind = "begin" | "end" | "task" | "decision"; + +export class FlowError extends Error { + constructor(message: string) { + super(message); + this.name = "FlowError"; + } +} + +export class FlowParseError extends FlowError { + constructor(message: string) { + super(message); + this.name = "FlowParseError"; + } +} + +export class FlowValidationError extends FlowError { + constructor(message: string) { + super(message); + this.name = "FlowValidationError"; + } +} + +export interface FlowNode { + readonly id: string; + readonly label: string; + readonly kind: FlowNodeKind; +} + +export interface FlowEdge { + readonly src: string; + readonly dst: string; + readonly label: string | undefined; +} + +export interface Flow { + readonly nodes: Record; + readonly outgoing: Record; + readonly beginId: string; + readonly endId: string; +} + +const CHOICE_RE = /([^<]*)<\/choice>/g; + +export function parseChoice(text: string): string | undefined { + const matches = [...(text || "").matchAll(CHOICE_RE)]; + if (matches.length === 0) return undefined; + return matches[matches.length - 1]![1]!.trim(); +} + +export function validateFlow( + nodes: Record, + outgoing: Record, +): [string, string] { + const beginIds = Object.values(nodes) + .filter((n) => n.kind === "begin") + .map((n) => n.id); + const endIds = Object.values(nodes) + .filter((n) => n.kind === "end") + .map((n) => n.id); + + if (beginIds.length !== 1) { + throw new FlowValidationError(`Expected exactly one BEGIN node, found ${beginIds.length}`); + } + if (endIds.length !== 1) { + throw new FlowValidationError(`Expected exactly one END node, found ${endIds.length}`); + } + + const beginId = beginIds[0]!; + const endId = endIds[0]!; + + // BFS reachability + const reachable = new Set(); + const queue = [beginId]; + while (queue.length > 0) { + const nodeId = queue.pop()!; + if (reachable.has(nodeId)) continue; + reachable.add(nodeId); + for (const edge of outgoing[nodeId] ?? []) { + if (!reachable.has(edge.dst)) queue.push(edge.dst); + } + } + + // Validate decision nodes have labeled, unique edges + for (const node of Object.values(nodes)) { + if (!reachable.has(node.id)) continue; + const edges = outgoing[node.id] ?? []; + if (edges.length <= 1) continue; + const labels: string[] = []; + for (const edge of edges) { + if (!edge.label?.trim()) { + throw new FlowValidationError(`Node "${node.id}" has an unlabeled edge`); + } + labels.push(edge.label); + } + if (new Set(labels).size !== labels.length) { + throw new FlowValidationError(`Node "${node.id}" has duplicate edge labels`); + } + } + + if (!reachable.has(endId)) { + throw new FlowValidationError("END node is not reachable from BEGIN"); + } + + return [beginId, endId]; +} diff --git a/src/kimi_cli/skill/flow/mermaid.py b/src/kimi_cli/skill/flow/mermaid.py deleted file mode 100644 index 77dea345d..000000000 --- a/src/kimi_cli/skill/flow/mermaid.py +++ /dev/null @@ -1,266 +0,0 @@ -from __future__ import annotations - -import re -from dataclasses import dataclass - -from . import ( - Flow, - FlowEdge, - FlowNode, - FlowNodeKind, - FlowParseError, - validate_flow, -) - - -@dataclass(frozen=True, slots=True) -class _NodeSpec: - node_id: str - label: str | None - - -@dataclass(slots=True) -class _NodeDef: - node: FlowNode - explicit: bool - - -_NODE_ID_RE = re.compile(r"[A-Za-z0-9_][A-Za-z0-9_-]*") -_HEADER_RE = re.compile(r"^(flowchart|graph)\b", re.IGNORECASE) - -_SHAPES = { - "[": "]", - "(": ")", - "{": "}", -} -_PIPE_LABEL_RE = re.compile(r"\|([^|]*)\|") -_EDGE_LABEL_RE = re.compile(r"--\s*([^>-][^>]*)\s*-->") -_ARROW_RE = re.compile(r"[-.=]+>") - - -def parse_mermaid_flowchart(text: str) -> Flow: - nodes: dict[str, _NodeDef] = {} - outgoing: dict[str, list[FlowEdge]] = {} - - for line_no, raw_line in enumerate(text.splitlines(), start=1): - line = _strip_comment(raw_line).strip() - if not line or line.startswith("%%"): - continue - if _HEADER_RE.match(line): - continue - if _is_style_line(line): - continue - line = _strip_style_tokens(line) - - edge = _try_parse_edge_line(line, line_no) - if edge is not None: - src_spec, label, dst_spec = edge - src_node = _add_node(nodes, src_spec, line_no) - dst_node = _add_node(nodes, dst_spec, line_no) - flow_edge = FlowEdge(src=src_node.id, dst=dst_node.id, label=label) - outgoing.setdefault(flow_edge.src, []).append(flow_edge) - outgoing.setdefault(flow_edge.dst, []) - continue - - node_spec = _try_parse_node_line(line, line_no) - if node_spec is not None: - _add_node(nodes, node_spec, line_no) - - flow_nodes = {node_id: node_def.node for node_id, node_def in nodes.items()} - for node_id in flow_nodes: - outgoing.setdefault(node_id, []) - - flow_nodes = _infer_decision_nodes(flow_nodes, outgoing) - begin_id, end_id = validate_flow(flow_nodes, outgoing) - return Flow(nodes=flow_nodes, outgoing=outgoing, begin_id=begin_id, end_id=end_id) - - -def _try_parse_edge_line(line: str, line_no: int) -> tuple[_NodeSpec, str | None, _NodeSpec] | None: - try: - src_spec, idx = _parse_node_token(line, 0, line_no) - except FlowParseError: - return None - - normalized, label = _normalize_edge_line(line) - idx = _skip_ws(normalized, idx) - if ">" not in normalized[idx:]: - if "---" not in normalized[idx:]: - return None - normalized = normalized[:idx] + normalized[idx:].replace("---", "-->", 1) - - normalized = _ARROW_RE.sub("-->", normalized) - arrow_idx = normalized.rfind(">") - if arrow_idx == -1: - return None - - dst_start = _skip_ws(normalized, arrow_idx + 1) - try: - dst_spec, _ = _parse_node_token(normalized, dst_start, line_no) - except FlowParseError: - return None - - return src_spec, label, dst_spec - - -def _parse_node_token(line: str, idx: int, line_no: int) -> tuple[_NodeSpec, int]: - match = _NODE_ID_RE.match(line, idx) - if not match: - raise FlowParseError(_line_error(line_no, "Expected node id")) - node_id = match.group(0) - idx = match.end() - - if idx >= len(line) or line[idx] not in _SHAPES: - return _NodeSpec(node_id=node_id, label=None), idx - - close_char = _SHAPES[line[idx]] - idx += 1 - label, idx = _parse_label(line, idx, close_char, line_no) - return _NodeSpec(node_id=node_id, label=label), idx - - -def _parse_label(line: str, idx: int, close_char: str, line_no: int) -> tuple[str, int]: - if idx >= len(line): - raise FlowParseError(_line_error(line_no, "Expected node label")) - if close_char == ")" and line[idx] == "[": - label, idx = _parse_label(line, idx + 1, "]", line_no) - while idx < len(line) and line[idx].isspace(): - idx += 1 - if idx >= len(line) or line[idx] != ")": - raise FlowParseError(_line_error(line_no, "Unclosed node label")) - return label, idx + 1 - if line[idx] == '"': - idx += 1 - buf: list[str] = [] - while idx < len(line): - ch = line[idx] - if ch == '"': - idx += 1 - while idx < len(line) and line[idx].isspace(): - idx += 1 - if idx >= len(line) or line[idx] != close_char: - raise FlowParseError(_line_error(line_no, "Unclosed node label")) - return "".join(buf), idx + 1 - if ch == "\\" and idx + 1 < len(line): - buf.append(line[idx + 1]) - idx += 2 - continue - buf.append(ch) - idx += 1 - raise FlowParseError(_line_error(line_no, "Unclosed quoted label")) - - end = line.find(close_char, idx) - if end == -1: - raise FlowParseError(_line_error(line_no, "Unclosed node label")) - label = line[idx:end].strip() - if not label: - raise FlowParseError(_line_error(line_no, "Node label cannot be empty")) - return label, end + 1 - - -def _skip_ws(line: str, idx: int) -> int: - while idx < len(line) and line[idx].isspace(): - idx += 1 - return idx - - -def _add_node(nodes: dict[str, _NodeDef], spec: _NodeSpec, line_no: int) -> FlowNode: - label = spec.label if spec.label is not None else spec.node_id - label_norm = label.strip().lower() - if not label: - raise FlowParseError(_line_error(line_no, "Node label cannot be empty")) - - kind: FlowNodeKind = "task" - if label_norm == "begin": - kind = "begin" - elif label_norm == "end": - kind = "end" - - node = FlowNode(id=spec.node_id, label=label, kind=kind) - explicit = spec.label is not None - - existing = nodes.get(spec.node_id) - if existing is None: - nodes[spec.node_id] = _NodeDef(node=node, explicit=explicit) - return node - - if existing.node == node: - return existing.node - - if not explicit and existing.explicit: - return existing.node - - if explicit and not existing.explicit: - nodes[spec.node_id] = _NodeDef(node=node, explicit=True) - return node - - raise FlowParseError(_line_error(line_no, f'Conflicting definition for node "{spec.node_id}"')) - - -def _line_error(line_no: int, message: str) -> str: - return f"Line {line_no}: {message}" - - -def _strip_comment(line: str) -> str: - if "%%" not in line: - return line - return line.split("%%", 1)[0] - - -def _is_style_line(line: str) -> bool: - lowered = line.lower() - if lowered in ("end",): - return True - return lowered.startswith( - ( - "classdef ", - "class ", - "style ", - "linkstyle ", - "click ", - "subgraph ", - "direction ", - ) - ) - - -def _strip_style_tokens(line: str) -> str: - return re.sub(r":::[A-Za-z0-9_-]+", "", line) - - -def _try_parse_node_line(line: str, line_no: int) -> _NodeSpec | None: - try: - node_spec, _ = _parse_node_token(line, 0, line_no) - except FlowParseError: - return None - return node_spec - - -def _normalize_edge_line(line: str) -> tuple[str, str | None]: - label = None - normalized = line - pipe_match = _PIPE_LABEL_RE.search(normalized) - if pipe_match: - label = pipe_match.group(1).strip() or None - normalized = normalized[: pipe_match.start()] + normalized[pipe_match.end() :] - if label is None: - edge_match = _EDGE_LABEL_RE.search(normalized) - if edge_match: - label = edge_match.group(1).strip() or None - normalized = normalized[: edge_match.start()] + "-->" + normalized[edge_match.end() :] - return normalized, label - - -def _infer_decision_nodes( - nodes: dict[str, FlowNode], - outgoing: dict[str, list[FlowEdge]], -) -> dict[str, FlowNode]: - updated: dict[str, FlowNode] = {} - for node_id, node in nodes.items(): - kind = node.kind - if kind == "task" and len(outgoing.get(node_id, [])) > 1: - kind = "decision" - if kind != node.kind: - updated[node_id] = FlowNode(id=node.id, label=node.label, kind=kind) - else: - updated[node_id] = node - return updated diff --git a/src/kimi_cli/skill/flow/mermaid.ts b/src/kimi_cli/skill/flow/mermaid.ts new file mode 100644 index 000000000..3a816da26 --- /dev/null +++ b/src/kimi_cli/skill/flow/mermaid.ts @@ -0,0 +1,273 @@ +/** + * Mermaid flowchart parser — corresponds to Python skill/flow/mermaid.py + */ + +import { + type Flow, + type FlowEdge, + type FlowNode, + type FlowNodeKind, + FlowParseError, + validateFlow, +} from "./index.ts"; + +interface NodeSpec { + nodeId: string; + label: string | undefined; +} + +interface NodeDef { + node: FlowNode; + explicit: boolean; +} + +const NODE_ID_RE = /^[A-Za-z0-9_][A-Za-z0-9_-]*/; +const HEADER_RE = /^(flowchart|graph)\b/i; + +const SHAPES: Record = { "[": "]", "(": ")", "{": "}" }; +const PIPE_LABEL_RE = /\|([^|]*)\|/; +const EDGE_LABEL_RE = /--\s*([^>-][^>]*)\s*-->/; +const ARROW_RE = /[-.=]+>/g; + +export function parseMermaidFlowchart(text: string): Flow { + const nodes = new Map(); + const outgoing = new Map(); + + for (const [lineNo, rawLine] of text.split("\n").entries()) { + const line = stripComment(rawLine).trim(); + if (!line || line.startsWith("%%")) continue; + if (HEADER_RE.test(line)) continue; + if (isStyleLine(line)) continue; + const cleaned = stripStyleTokens(line); + + const edge = tryParseEdgeLine(cleaned, lineNo + 1); + if (edge) { + const [srcSpec, label, dstSpec] = edge; + const srcNode = addNode(nodes, srcSpec, lineNo + 1); + const dstNode = addNode(nodes, dstSpec, lineNo + 1); + const flowEdge: FlowEdge = { src: srcNode.id, dst: dstNode.id, label }; + if (!outgoing.has(flowEdge.src)) outgoing.set(flowEdge.src, []); + outgoing.get(flowEdge.src)!.push(flowEdge); + if (!outgoing.has(flowEdge.dst)) outgoing.set(flowEdge.dst, []); + continue; + } + + const nodeSpec = tryParseNodeLine(cleaned, lineNo + 1); + if (nodeSpec) addNode(nodes, nodeSpec, lineNo + 1); + } + + const flowNodes: Record = {}; + for (const [id, def] of nodes) { + flowNodes[id] = def.node; + if (!outgoing.has(id)) outgoing.set(id, []); + } + + const outgoingRecord: Record = {}; + for (const [k, v] of outgoing) outgoingRecord[k] = v; + + const inferred = inferDecisionNodes(flowNodes, outgoingRecord); + const [beginId, endId] = validateFlow(inferred, outgoingRecord); + return { nodes: inferred, outgoing: outgoingRecord, beginId, endId }; +} + +function tryParseEdgeLine(line: string, lineNo: number): [NodeSpec, string | undefined, NodeSpec] | undefined { + let srcSpec: NodeSpec; + let idx: number; + try { + [srcSpec, idx] = parseNodeToken(line, 0, lineNo); + } catch { + return undefined; + } + + const [normalized, label] = normalizeEdgeLine(line); + idx = skipWs(normalized, idx); + let norm = normalized; + if (!norm.slice(idx).includes(">")) { + if (!norm.slice(idx).includes("---")) return undefined; + norm = norm.slice(0, idx) + norm.slice(idx).replace("---", "-->"); + } + + norm = norm.replace(ARROW_RE, "-->"); + const arrowIdx = norm.lastIndexOf(">"); + if (arrowIdx === -1) return undefined; + + const dstStart = skipWs(norm, arrowIdx + 1); + let dstSpec: NodeSpec; + try { + [dstSpec] = parseNodeToken(norm, dstStart, lineNo); + } catch { + return undefined; + } + + return [srcSpec, label, dstSpec]; +} + +function parseNodeToken(line: string, idx: number, lineNo: number): [NodeSpec, number] { + const match = line.slice(idx).match(NODE_ID_RE); + if (!match) throw new FlowParseError(lineError(lineNo, "Expected node id")); + const nodeId = match[0]!; + idx += match[0]!.length; + + if (idx >= line.length || !(line[idx]! in SHAPES)) { + return [{ nodeId, label: undefined }, idx]; + } + + const closeChar = SHAPES[line[idx]!]!; + idx++; + const [label, newIdx] = parseLabel(line, idx, closeChar, lineNo); + return [{ nodeId, label }, newIdx]; +} + +function parseLabel(line: string, idx: number, closeChar: string, lineNo: number): [string, number] { + if (idx >= line.length) throw new FlowParseError(lineError(lineNo, "Expected node label")); + + if (closeChar === ")" && line[idx] === "[") { + const [label, newIdx] = parseLabel(line, idx + 1, "]", lineNo); + let i = newIdx; + while (i < line.length && line[i] === " ") i++; + if (i >= line.length || line[i] !== ")") { + throw new FlowParseError(lineError(lineNo, "Unclosed node label")); + } + return [label, i + 1]; + } + + if (line[idx] === '"') { + idx++; + const buf: string[] = []; + while (idx < line.length) { + const ch = line[idx]!; + if (ch === '"') { + idx++; + while (idx < line.length && line[idx] === " ") idx++; + if (idx >= line.length || line[idx] !== closeChar) { + throw new FlowParseError(lineError(lineNo, "Unclosed node label")); + } + return [buf.join(""), idx + 1]; + } + if (ch === "\\" && idx + 1 < line.length) { + buf.push(line[idx + 1]!); + idx += 2; + continue; + } + buf.push(ch); + idx++; + } + throw new FlowParseError(lineError(lineNo, "Unclosed quoted label")); + } + + const end = line.indexOf(closeChar, idx); + if (end === -1) throw new FlowParseError(lineError(lineNo, "Unclosed node label")); + const label = line.slice(idx, end).trim(); + if (!label) throw new FlowParseError(lineError(lineNo, "Node label cannot be empty")); + return [label, end + 1]; +} + +function skipWs(line: string, idx: number): number { + while (idx < line.length && line[idx] === " ") idx++; + return idx; +} + +function addNode(nodes: Map, spec: NodeSpec, lineNo: number): FlowNode { + const label = spec.label ?? spec.nodeId; + const labelNorm = label.trim().toLowerCase(); + if (!label) throw new FlowParseError(lineError(lineNo, "Node label cannot be empty")); + + let kind: FlowNodeKind = "task"; + if (labelNorm === "begin") kind = "begin"; + else if (labelNorm === "end") kind = "end"; + + const node: FlowNode = { id: spec.nodeId, label, kind }; + const explicit = spec.label != null; + const existing = nodes.get(spec.nodeId); + + if (!existing) { + nodes.set(spec.nodeId, { node, explicit }); + return node; + } + + if (existing.node.id === node.id && existing.node.label === node.label && existing.node.kind === node.kind) { + return existing.node; + } + + if (!explicit && existing.explicit) return existing.node; + if (explicit && !existing.explicit) { + nodes.set(spec.nodeId, { node, explicit: true }); + return node; + } + + throw new FlowParseError(lineError(lineNo, `Conflicting definition for node "${spec.nodeId}"`)); +} + +function lineError(lineNo: number, message: string): string { + return `Line ${lineNo}: ${message}`; +} + +function stripComment(line: string): string { + if (!line.includes("%%")) return line; + return line.split("%%")[0]!; +} + +function isStyleLine(line: string): boolean { + const lowered = line.toLowerCase(); + if (lowered === "end") return true; + return lowered.startsWith("classdef ") || + lowered.startsWith("class ") || + lowered.startsWith("style ") || + lowered.startsWith("linkstyle ") || + lowered.startsWith("click ") || + lowered.startsWith("subgraph ") || + lowered.startsWith("direction "); +} + +function stripStyleTokens(line: string): string { + return line.replace(/:::[A-Za-z0-9_-]+/g, ""); +} + +function tryParseNodeLine(line: string, lineNo: number): NodeSpec | undefined { + try { + const [spec] = parseNodeToken(line, 0, lineNo); + return spec; + } catch { + return undefined; + } +} + +function normalizeEdgeLine(line: string): [string, string | undefined] { + let label: string | undefined; + let normalized = line; + + const pipeMatch = PIPE_LABEL_RE.exec(normalized); + if (pipeMatch) { + label = pipeMatch[1]!.trim() || undefined; + normalized = normalized.slice(0, pipeMatch.index) + normalized.slice(pipeMatch.index! + pipeMatch[0]!.length); + } + + if (label == null) { + const edgeMatch = EDGE_LABEL_RE.exec(normalized); + if (edgeMatch) { + label = edgeMatch[1]!.trim() || undefined; + normalized = normalized.slice(0, edgeMatch.index) + "-->" + normalized.slice(edgeMatch.index! + edgeMatch[0]!.length); + } + } + + return [normalized, label]; +} + +function inferDecisionNodes( + nodes: Record, + outgoing: Record, +): Record { + const updated: Record = {}; + for (const [nodeId, node] of Object.entries(nodes)) { + let kind = node.kind; + if (kind === "task" && (outgoing[nodeId]?.length ?? 0) > 1) { + kind = "decision"; + } + if (kind !== node.kind) { + updated[nodeId] = { id: node.id, label: node.label, kind }; + } else { + updated[nodeId] = node; + } + } + return updated; +} diff --git a/src/kimi_cli/skill/index.ts b/src/kimi_cli/skill/index.ts new file mode 100644 index 000000000..79d3a8b71 --- /dev/null +++ b/src/kimi_cli/skill/index.ts @@ -0,0 +1,283 @@ +/** + * Skill specification discovery and loading — corresponds to Python skill/__init__.py + */ + +import { join, resolve, dirname } from "node:path"; +import { existsSync, readdirSync, readFileSync, statSync } from "node:fs"; +import { homedir } from "node:os"; +import { logger } from "../utils/logging.ts"; +import type { Flow } from "./flow/index.ts"; +import { FlowError } from "./flow/index.ts"; +import { parseMermaidFlowchart } from "./flow/mermaid.ts"; +import { parseD2Flowchart } from "./flow/d2.ts"; + +export type SkillType = "standard" | "flow"; + +export interface Skill { + readonly name: string; + readonly description: string; + readonly type: SkillType; + readonly dir: string; + readonly flow?: Flow; + readonly skillMdFile: string; +} + +// ── Directory discovery ── + +export function getBuiltinSkillsDir(): string { + return join(dirname(new URL(import.meta.url).pathname), "..", "skills"); +} + +export function getUserSkillsDirCandidates(): string[] { + const home = homedir(); + return [ + join(home, ".config", "agents", "skills"), + join(home, ".agents", "skills"), + join(home, ".kimi", "skills"), + join(home, ".claude", "skills"), + join(home, ".codex", "skills"), + ]; +} + +export function getProjectSkillsDirCandidates(workDir: string): string[] { + return [ + join(workDir, ".agents", "skills"), + join(workDir, ".kimi", "skills"), + join(workDir, ".claude", "skills"), + join(workDir, ".codex", "skills"), + ]; +} + +export function findFirstExistingDir(candidates: string[]): string | undefined { + for (const candidate of candidates) { + try { + if (existsSync(candidate) && statSync(candidate).isDirectory()) { + return candidate; + } + } catch { + continue; + } + } + return undefined; +} + +export function findUserSkillsDir(): string | undefined { + return findFirstExistingDir(getUserSkillsDirCandidates()); +} + +export function findProjectSkillsDir(workDir: string): string | undefined { + return findFirstExistingDir(getProjectSkillsDirCandidates(workDir)); +} + +export function resolveSkillsRoots(workDir: string, opts?: { skillsDirs?: string[] }): string[] { + const roots: string[] = []; + const builtinDir = getBuiltinSkillsDir(); + if (existsSync(builtinDir)) roots.push(builtinDir); + + if (opts?.skillsDirs && opts.skillsDirs.length > 0) { + roots.push(...opts.skillsDirs); + } else { + const userDir = findUserSkillsDir(); + if (userDir) roots.push(userDir); + const projectDir = findProjectSkillsDir(workDir); + if (projectDir) roots.push(projectDir); + } + return roots; +} + +// ── Skill parsing ── + +export function normalizeSkillName(name: string): string { + return name.toLowerCase(); +} + +export function indexSkills(skills: Skill[]): Map { + const map = new Map(); + for (const skill of skills) { + map.set(normalizeSkillName(skill.name), skill); + } + return map; +} + +export function discoverSkillsFromRoots(skillsDirs: string[]): Skill[] { + const skillsByName = new Map(); + for (const dir of skillsDirs) { + for (const skill of discoverSkills(dir)) { + const key = normalizeSkillName(skill.name); + if (!skillsByName.has(key)) { + skillsByName.set(key, skill); + } + } + } + return [...skillsByName.values()].sort((a, b) => a.name.localeCompare(b.name)); +} + +export function readSkillText(skill: Skill): string | undefined { + try { + return readFileSync(skill.skillMdFile, "utf-8").trim(); + } catch { + logger.warn(`Failed to read skill file ${skill.skillMdFile}`); + return undefined; + } +} + +export function discoverSkills(skillsDir: string): Skill[] { + if (!existsSync(skillsDir)) return []; + try { + if (!statSync(skillsDir).isDirectory()) return []; + } catch { + return []; + } + + const skills: Skill[] = []; + for (const entry of readdirSync(skillsDir)) { + const skillDir = join(skillsDir, entry); + try { + if (!statSync(skillDir).isDirectory()) continue; + } catch { + continue; + } + const skillMd = join(skillDir, "SKILL.md"); + if (!existsSync(skillMd)) continue; + + try { + const content = readFileSync(skillMd, "utf-8"); + skills.push(parseSkillText(content, skillDir)); + } catch (err) { + logger.info(`Skipping invalid skill at ${skillMd}: ${err}`); + } + } + return skills.sort((a, b) => a.name.localeCompare(b.name)); +} + +export function parseSkillText(content: string, dirPath: string): Skill { + const frontmatter = parseFrontmatter(content) ?? {}; + const name = (frontmatter.name as string) || dirPath.split("/").pop() || "unknown"; + const description = (frontmatter.description as string) || "No description provided."; + let skillType: SkillType = ((frontmatter.type as string) || "standard") as SkillType; + + if (skillType !== "standard" && skillType !== "flow") { + throw new Error(`Invalid skill type "${skillType}"`); + } + + let flow: Flow | undefined; + if (skillType === "flow") { + try { + flow = parseFlowFromSkill(content); + } catch (err) { + logger.error(`Failed to parse flow skill ${name}: ${err}`); + skillType = "standard"; + flow = undefined; + } + } + + return { + name, + description, + type: skillType, + dir: dirPath, + flow, + skillMdFile: join(dirPath, "SKILL.md"), + }; +} + +function parseFlowFromSkill(content: string): Flow { + for (const [lang, code] of iterFencedCodeblocks(content)) { + if (lang === "mermaid") return parseMermaidFlowchart(code); + if (lang === "d2") return parseD2Flowchart(code); + } + throw new Error("Flow skills require a mermaid or d2 code block in SKILL.md."); +} + +function* iterFencedCodeblocks(content: string): Generator<[string, string]> { + let fence = ""; + let fenceChar = ""; + let lang = ""; + let buf: string[] = []; + let inBlock = false; + + for (const line of content.split("\n")) { + const stripped = line.trimStart(); + if (!inBlock) { + const match = parseFenceOpen(stripped); + if (match) { + [fence, fenceChar, lang] = match; + lang = normalizeCodeLang(lang); + inBlock = true; + buf = []; + } + continue; + } + + if (isFenceClose(stripped, fenceChar, fence.length)) { + yield [lang, buf.join("\n").replace(/^\n+|\n+$/g, "")]; + inBlock = false; + fence = ""; + fenceChar = ""; + lang = ""; + buf = []; + continue; + } + + buf.push(line); + } +} + +function normalizeCodeLang(info: string): string { + if (!info) return ""; + let lang = info.split(/\s+/)[0]!.trim().toLowerCase(); + if (lang.startsWith("{") && lang.endsWith("}")) { + lang = lang.slice(1, -1).trim(); + } + return lang; +} + +function parseFenceOpen(line: string): [string, string, string] | undefined { + if (!line || (line[0] !== "`" && line[0] !== "~")) return undefined; + const fenceChar = line[0]!; + let count = 0; + for (const ch of line) { + if (ch === fenceChar) count++; + else break; + } + if (count < 3) return undefined; + const fence = fenceChar.repeat(count); + const info = line.slice(count).trim(); + return [fence, fenceChar, info]; +} + +function isFenceClose(line: string, fenceChar: string, fenceLen: number): boolean { + if (!fenceChar || !line || line[0] !== fenceChar) return false; + let count = 0; + for (const ch of line) { + if (ch === fenceChar) count++; + else break; + } + if (count < fenceLen) return false; + return !line.slice(count).trim(); +} + +// Simple frontmatter parser +function parseFrontmatter(content: string): Record | undefined { + const lines = content.split("\n"); + if (lines[0]?.trim() !== "---") return undefined; + + const result: Record = {}; + for (let i = 1; i < lines.length; i++) { + const line = lines[i]!; + if (line.trim() === "---") return result; + const colonIdx = line.indexOf(":"); + if (colonIdx > 0) { + const key = line.slice(0, colonIdx).trim(); + const value = line.slice(colonIdx + 1).trim(); + // Strip quotes + if ((value.startsWith('"') && value.endsWith('"')) || + (value.startsWith("'") && value.endsWith("'"))) { + result[key] = value.slice(1, -1); + } else { + result[key] = value; + } + } + } + return undefined; // No closing --- +} diff --git a/src/kimi_cli/skills/kimi-cli-help/SKILL.md b/src/kimi_cli/skills/kimi-cli-help/SKILL.md deleted file mode 100644 index 19617bd6c..000000000 --- a/src/kimi_cli/skills/kimi-cli-help/SKILL.md +++ /dev/null @@ -1,55 +0,0 @@ ---- -name: kimi-cli-help -description: Answer Kimi Code CLI usage, configuration, and troubleshooting questions. Use when user asks about Kimi Code CLI installation, setup, configuration, slash commands, keyboard shortcuts, MCP integration, providers, environment variables, how something works internally, or any questions about Kimi Code CLI itself. ---- - -# Kimi Code CLI Help - -Help users with Kimi Code CLI questions by consulting documentation and source code. - -## Strategy - -1. **Prefer official documentation** for most questions -2. **Read local source** when in kimi-cli project itself, or when user is developing with kimi-cli as a library (e.g., importing from `kimi_cli` in their code) -3. **Clone and explore source** for complex internals not covered in docs - **ask user for confirmation first** - -## Documentation - -Base URL: `https://moonshotai.github.io/kimi-cli/` - -Fetch documentation index to find relevant pages: - -``` -https://moonshotai.github.io/kimi-cli/llms.txt -``` - -### Page URL Pattern - -- English: `https://moonshotai.github.io/kimi-cli/en/...` -- Chinese: `https://moonshotai.github.io/kimi-cli/zh/...` - -### Topic Mapping - -| Topic | Page | -|-------|------| -| Installation, first run | `/en/guides/getting-started.md` | -| Config files | `/en/configuration/config-files.md` | -| Providers, models | `/en/configuration/providers.md` | -| Environment variables | `/en/configuration/env-vars.md` | -| Slash commands | `/en/reference/slash-commands.md` | -| CLI flags | `/en/reference/kimi-command.md` | -| Keyboard shortcuts | `/en/reference/keyboard.md` | -| MCP | `/en/customization/mcp.md` | -| Agents | `/en/customization/agents.md` | -| Skills | `/en/customization/skills.md` | -| FAQ | `/en/faq.md` | - -## Source Code - -Repository: `https://github.com/MoonshotAI/kimi-cli` - -When to read source: - -- In kimi-cli project directory (check `pyproject.toml` for `name = "kimi-cli"`) -- User is importing `kimi_cli` as a library in their project -- Question about internals not covered in docs (ask user before cloning) diff --git a/src/kimi_cli/skills/skill-creator/SKILL.md b/src/kimi_cli/skills/skill-creator/SKILL.md deleted file mode 100644 index 143eadafc..000000000 --- a/src/kimi_cli/skills/skill-creator/SKILL.md +++ /dev/null @@ -1,367 +0,0 @@ ---- -name: skill-creator -description: Guide for creating effective skills. This skill should be used when users want to create a new skill (or update an existing skill) that extends Kimi's capabilities with specialized knowledge, workflows, or tool integrations. ---- - -# Skill Creator - -This skill provides guidance for creating effective skills. - -## About Skills - -Skills are modular, self-contained packages that extend Kimi's capabilities by providing -specialized knowledge, workflows, and tools. Think of them as "onboarding guides" for specific -domains or tasks—they transform Kimi from a general-purpose agent into a specialized agent -equipped with procedural knowledge that no model can fully possess. - -### What Skills Provide - -1. Specialized workflows - Multi-step procedures for specific domains -2. Tool integrations - Instructions for working with specific file formats or APIs -3. Domain expertise - Company-specific knowledge, schemas, business logic -4. Bundled resources - Scripts, references, and assets for complex and repetitive tasks - -## Core Principles - -### Concise is Key - -The context window is a public good. Skills share the context window with everything else Kimi needs: system prompt, conversation history, other Skills' metadata, and the actual user request. - -**Default assumption: Kimi is already very smart.** Only add context Kimi doesn't already have. Challenge each piece of information: "Does Kimi really need this explanation?" and "Does this paragraph justify its token cost?" - -Prefer concise examples over verbose explanations. - -### Set Appropriate Degrees of Freedom - -Match the level of specificity to the task's fragility and variability: - -**High freedom (text-based instructions)**: Use when multiple approaches are valid, decisions depend on context, or heuristics guide the approach. - -**Medium freedom (pseudocode or scripts with parameters)**: Use when a preferred pattern exists, some variation is acceptable, or configuration affects behavior. - -**Low freedom (specific scripts, few parameters)**: Use when operations are fragile and error-prone, consistency is critical, or a specific sequence must be followed. - -Think of Kimi as exploring a path: a narrow bridge with cliffs needs specific guardrails (low freedom), while an open field allows many routes (high freedom). - -### Anatomy of a Skill - -Every skill consists of a required SKILL.md file and optional bundled resources: - -``` -skill-name/ -├── SKILL.md (required) -│ ├── YAML frontmatter metadata (required) -│ │ ├── name: (required) -│ │ └── description: (required) -│ └── Markdown instructions (required) -└── Bundled Resources (optional) - ├── scripts/ - Executable code (Python/Bash/etc.) - ├── references/ - Documentation intended to be loaded into context as needed - └── assets/ - Files used in output (templates, icons, fonts, etc.) -``` - -#### SKILL.md (required) - -Every SKILL.md consists of: - -- **Frontmatter** (YAML): Contains `name` and `description` fields. These are the only fields that Kimi reads to determine when the skill gets used, thus it is very important to be clear and comprehensive in describing what the skill is, and when it should be used. -- **Body** (Markdown): Instructions and guidance for using the skill. Only loaded AFTER the skill triggers (if at all). - -#### Bundled Resources (optional) - -##### Scripts (`scripts/`) - -Executable code (Python/Bash/etc.) for tasks that require deterministic reliability or are repeatedly rewritten. - -- **When to include**: When the same code is being rewritten repeatedly or deterministic reliability is needed -- **Example**: `scripts/rotate_pdf.py` for PDF rotation tasks -- **Benefits**: Token efficient, deterministic, may be executed without loading into context -- **Note**: Scripts may still need to be read by Kimi for patching or environment-specific adjustments - -##### References (`references/`) - -Documentation and reference material intended to be loaded as needed into context to inform Kimi's process and thinking. - -- **When to include**: For documentation that Kimi should reference while working -- **Examples**: `references/finance.md` for financial schemas, `references/mnda.md` for company NDA template, `references/policies.md` for company policies, `references/api_docs.md` for API specifications -- **Use cases**: Database schemas, API documentation, domain knowledge, company policies, detailed workflow guides -- **Benefits**: Keeps SKILL.md lean, loaded only when Kimi determines it's needed -- **Best practice**: If files are large (>10k words), include grep search patterns in SKILL.md -- **Avoid duplication**: Information should live in either SKILL.md or references files, not both. Prefer references files for detailed information unless it's truly core to the skill—this keeps SKILL.md lean while making information discoverable without hogging the context window. Keep only essential procedural instructions and workflow guidance in SKILL.md; move detailed reference material, schemas, and examples to references files. - -##### Assets (`assets/`) - -Files not intended to be loaded into context, but rather used within the output Kimi produces. - -- **When to include**: When the skill needs files that will be used in the final output -- **Examples**: `assets/logo.png` for brand assets, `assets/slides.pptx` for PowerPoint templates, `assets/frontend-template/` for HTML/React boilerplate, `assets/font.ttf` for typography -- **Use cases**: Templates, images, icons, boilerplate code, fonts, sample documents that get copied or modified -- **Benefits**: Separates output resources from documentation, enables Kimi to use files without loading them into context - -#### What to Not Include in a Skill - -A skill should only contain essential files that directly support its functionality. Do NOT create extraneous documentation or auxiliary files, including: - -- README.md -- INSTALLATION_GUIDE.md -- QUICK_REFERENCE.md -- CHANGELOG.md -- etc. - -The skill should only contain the information needed for an AI agent to do the job at hand. It should not contain auxiliary context about the process that went into creating it, setup and testing procedures, user-facing documentation, etc. Creating additional documentation files just adds clutter and confusion. - -### Progressive Disclosure Design Principle - -Skills use a three-level loading system to manage context efficiently: - -1. **Metadata (name + description)** - Always in context (~100 words) -2. **SKILL.md body** - When skill triggers (<5k words) -3. **Bundled resources** - As needed by Kimi (Unlimited because scripts can be executed without reading into context window) - -#### Progressive Disclosure Patterns - -Keep SKILL.md body to the essentials and under 500 lines to minimize context bloat. Split content into separate files when approaching this limit. When splitting out content into other files, it is very important to reference them from SKILL.md and describe clearly when to read them, to ensure the reader of the skill knows they exist and when to use them. - -**Key principle:** When a skill supports multiple variations, frameworks, or options, keep only the core workflow and selection guidance in SKILL.md. Move variant-specific details (patterns, examples, configuration) into separate reference files. - -**Pattern 1: High-level guide with references** - -```markdown -# PDF Processing - -## Quick start - -Extract text with pdfplumber: -[code example] - -## Advanced features - -- **Form filling**: See [FORMS.md](FORMS.md) for complete guide -- **API reference**: See [REFERENCE.md](REFERENCE.md) for all methods -- **Examples**: See [EXAMPLES.md](EXAMPLES.md) for common patterns -``` - -Kimi loads FORMS.md, REFERENCE.md, or EXAMPLES.md only when needed. - -**Pattern 2: Domain-specific organization** - -For Skills with multiple domains, organize content by domain to avoid loading irrelevant context: - -``` -bigquery-skill/ -├── SKILL.md (overview and navigation) -└── reference/ - ├── finance.md (revenue, billing metrics) - ├── sales.md (opportunities, pipeline) - ├── product.md (API usage, features) - └── marketing.md (campaigns, attribution) -``` - -When a user asks about sales metrics, Kimi only reads sales.md. - -Similarly, for skills supporting multiple frameworks or variants, organize by variant: - -``` -cloud-deploy/ -├── SKILL.md (workflow + provider selection) -└── references/ - ├── aws.md (AWS deployment patterns) - ├── gcp.md (GCP deployment patterns) - └── azure.md (Azure deployment patterns) -``` - -When the user chooses AWS, Kimi only reads aws.md. - -**Pattern 3: Conditional details** - -Show basic content, link to advanced content: - -```markdown -# DOCX Processing - -## Creating documents - -Use docx-js for new documents. See [DOCX-JS.md](DOCX-JS.md). - -## Editing documents - -For simple edits, modify the XML directly. - -**For tracked changes**: See [REDLINING.md](REDLINING.md) -**For OOXML details**: See [OOXML.md](OOXML.md) -``` - -Kimi reads REDLINING.md or OOXML.md only when the user needs those features. - -**Important guidelines:** - -- **Avoid deeply nested references** - Keep references one level deep from SKILL.md. All reference files should link directly from SKILL.md. -- **Structure longer reference files** - For files longer than 100 lines, include a table of contents at the top so Kimi can see the full scope when previewing. - -## Skill Locations and Discovery - -Kimi Code CLI loads skills in layers (built-in -> user -> project). Within each layer, it uses the -first existing directory in priority order. Built-in skills only load for LocalKaos or ACPKaos. - -**User level** (by priority): -- `~/.config/agents/skills/` (recommended) -- `~/.kimi/skills/` -- `~/.claude/skills/` - -**Project level**: -- `.agents/skills/` - -`--skills-dir` overrides discovery and loads only that directory (built-ins still load when -supported). - -## Skill Creation Process - -Skill creation involves these steps: - -1. Understand the skill with concrete examples -2. Plan reusable skill contents (scripts, references, assets) -3. Initialize the skill (run init_skill.py) -4. Edit the skill (implement resources and write SKILL.md) -5. Package the skill (run package_skill.py) -6. Iterate based on real usage - -Follow these steps in order, skipping only if there is a clear reason why they are not applicable. - -### Skill Naming - -- Use lowercase letters, digits, and hyphens only; normalize user-provided titles to hyphen-case (e.g., "Plan Mode" -> `plan-mode`). -- When generating names, generate a name under 64 characters (letters, digits, hyphens). -- Prefer short, verb-led phrases that describe the action. -- Namespace by tool when it improves clarity or triggering (e.g., `gh-address-comments`, `linear-address-issue`). -- Name the skill folder exactly after the skill name. - -### Step 1: Understanding the Skill with Concrete Examples - -Skip this step only when the skill's usage patterns are already clearly understood. It remains valuable even when working with an existing skill. - -To create an effective skill, clearly understand concrete examples of how the skill will be used. This understanding can come from either direct user examples or generated examples that are validated with user feedback. - -For example, when building an image-editor skill, relevant questions include: - -- "What functionality should the image-editor skill support? Editing, rotating, anything else?" -- "Can you give some examples of how this skill would be used?" -- "I can imagine users asking for things like 'Remove the red-eye from this image' or 'Rotate this image'. Are there other ways you imagine this skill being used?" -- "What would a user say that should trigger this skill?" - -To avoid overwhelming users, avoid asking too many questions in a single message. Start with the most important questions and follow up as needed for better effectiveness. - -Conclude this step when there is a clear sense of the functionality the skill should support. - -### Step 2: Planning the Reusable Skill Contents - -To turn concrete examples into an effective skill, analyze each example by: - -1. Considering how to execute on the example from scratch -2. Identifying what scripts, references, and assets would be helpful when executing these workflows repeatedly - -Example: When building a `pdf-editor` skill to handle queries like "Help me rotate this PDF," the analysis shows: - -1. Rotating a PDF requires re-writing the same code each time -2. A `scripts/rotate_pdf.py` script would be helpful to store in the skill - -Example: When designing a `frontend-webapp-builder` skill for queries like "Build me a todo app" or "Build me a dashboard to track my steps," the analysis shows: - -1. Writing a frontend webapp requires the same boilerplate HTML/React each time -2. An `assets/hello-world/` template containing the boilerplate HTML/React project files would be helpful to store in the skill - -Example: When building a `big-query` skill to handle queries like "How many users have logged in today?" the analysis shows: - -1. Querying BigQuery requires re-discovering the table schemas and relationships each time -2. A `references/schema.md` file documenting the table schemas would be helpful to store in the skill - -To establish the skill's contents, analyze each concrete example to create a list of the reusable resources to include: scripts, references, and assets. - -### Step 3: Initializing the Skill - -At this point, it is time to actually create the skill. - -Skip this step only if the skill being developed already exists, and iteration or packaging is needed. In this case, continue to the next step. - -When creating a new skill from scratch, create a new skill directory with a required `SKILL.md` -file and any optional resource directories that the skill needs (`scripts/`, `references/`, -`assets/`). Create only the directories you intend to populate. - -After initialization, customize the SKILL.md and add resources as needed. - -### Step 4: Edit the Skill - -When editing the (newly-generated or existing) skill, remember that the skill is being created for another instance of Kimi to use. Include information that would be beneficial and non-obvious to Kimi. Consider what procedural knowledge, domain-specific details, or reusable assets would help another Kimi instance execute these tasks more effectively. - -#### Learn Proven Design Patterns - -Capture proven design patterns directly in this SKILL.md: - -- **Multi-step processes**: Clearly describe sequential workflows and conditional branches, including triggers, decision points, and expected outputs at each step. -- **Specific output formats or quality standards**: Document required output shapes, templates, and examples directly in this SKILL.md so they are easy to follow. - -#### Start with Reusable Skill Contents - -To begin implementation, start with the reusable resources identified above: `scripts/`, `references/`, and `assets/` files. Note that this step may require user input. For example, when implementing a `brand-guidelines` skill, the user may need to provide brand assets or templates to store in `assets/`, or documentation to store in `references/`. - -Added scripts must be tested by actually running them to ensure there are no bugs and that the output matches what is expected. If there are many similar scripts, only a representative sample needs to be tested to ensure confidence that they all work while balancing time to completion. - -Delete any placeholder files that are not needed for the skill. Only create resource directories that are actually required. - -#### Update SKILL.md - -**Writing Guidelines:** Always use imperative/infinitive form. - -##### Frontmatter - -Write the YAML frontmatter with `name` and `description`: - -- `name`: The skill name -- `description`: This is the primary triggering mechanism for your skill, and helps Kimi understand when to use the skill. - - Include both what the Skill does and specific triggers/contexts for when to use it. - - Include all "when to use" information here - Not in the body. The body is only loaded after triggering, so "When to Use This Skill" sections in the body are not helpful to Kimi. - - Example description for a `docx` skill: "Comprehensive document creation, editing, and analysis with support for tracked changes, comments, formatting preservation, and text extraction. Use when Kimi needs to work with professional documents (.docx files) for: (1) Creating new documents, (2) Modifying or editing content, (3) Working with tracked changes, (4) Adding comments, or any other document tasks" - -Do not include any other fields in YAML frontmatter. - -##### Body - -Write instructions for using the skill and its bundled resources. - -### Step 5: Packaging a Skill - -Once development of the skill is complete, package it into a distributable `.skill` file (a zip -archive). Before packaging, validate that the skill meets all requirements: - -1. **Validate** the skill, checking: - - - YAML frontmatter format and required fields - - Skill naming conventions and directory structure - - Description completeness and quality - - File organization and resource references - -2. **Package** the skill if validation passes: - - - Create an archive of the skill's root folder (the folder containing `SKILL.md` and all related files). - - Ensure the archive preserves the internal directory structure. - - Name the archive `.skill` (for example, `my-skill.skill`). The `.skill` file is a zip file with a `.skill` extension. - -Example packaging command: - -```bash -cd -zip -r my-skill.skill my-skill -``` - -If validation fails (for example, due to malformed frontmatter, missing files, or an incomplete -description), fix the issues and repackage the skill. - -### Step 6: Iterate - -After testing the skill, users may request improvements. Often this happens right after using the skill, with fresh context of how the skill performed. - -**Iteration workflow:** - -1. Use the skill on real tasks -2. Notice struggles or inefficiencies -3. Identify how SKILL.md or bundled resources should be updated -4. Implement changes and test again diff --git a/src/kimi_cli/soul/__init__.py b/src/kimi_cli/soul/__init__.py deleted file mode 100644 index 1ca049a8a..000000000 --- a/src/kimi_cli/soul/__init__.py +++ /dev/null @@ -1,288 +0,0 @@ -from __future__ import annotations - -import asyncio -import contextlib -from collections.abc import Callable, Coroutine -from contextvars import ContextVar -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable - -from kimi_cli.hooks.engine import HookEngine -from kimi_cli.utils.aioqueue import QueueShutDown -from kimi_cli.utils.logging import logger -from kimi_cli.wire import Wire -from kimi_cli.wire.file import WireFile -from kimi_cli.wire.types import ContentPart, MCPStatusSnapshot, WireMessage - -if TYPE_CHECKING: - from kimi_cli.llm import LLM, ModelCapability - from kimi_cli.soul.agent import Runtime - from kimi_cli.utils.slashcmd import SlashCommand - - -class LLMNotSet(Exception): - """Raised when the LLM is not set.""" - - def __init__(self) -> None: - super().__init__("LLM not set") - - -class LLMNotSupported(Exception): - """Raised when the LLM does not have required capabilities.""" - - def __init__(self, llm: LLM, capabilities: list[ModelCapability]): - self.llm = llm - self.capabilities = capabilities - capabilities_str = "capability" if len(capabilities) == 1 else "capabilities" - super().__init__( - f"LLM model '{llm.model_name}' does not support required {capabilities_str}: " - f"{', '.join(capabilities)}." - ) - - -class MaxStepsReached(Exception): - """Raised when the maximum number of steps is reached.""" - - n_steps: int - """The number of steps that have been taken.""" - - def __init__(self, n_steps: int): - super().__init__(f"Max number of steps reached: {n_steps}") - self.n_steps = n_steps - - -def format_token_count(n: int) -> str: - """Format token count as compact string, e.g. 28.5k, 128k, 1.2m.""" - suffix = "" - if n >= 1_000_000: - value = n / 1_000_000 - suffix = "m" - elif n >= 1_000: - value = n / 1_000 - suffix = "k" - else: - return str(n) - - # Keep one decimal when needed, but drop trailing ".0". - compact = f"{value:.1f}".rstrip("0").rstrip(".") - return f"{compact}{suffix}" - - -def format_context_status( - context_usage: float, - context_tokens: int = 0, - max_context_tokens: int = 0, -) -> str: - """Format context status string for display in status bar.""" - bounded = max(0.0, min(context_usage, 1.0)) - if max_context_tokens > 0: - used = format_token_count(context_tokens) - total = format_token_count(max_context_tokens) - return f"context: {bounded:.1%} ({used}/{total})" - return f"context: {bounded:.1%}" - - -@dataclass(frozen=True, slots=True) -class StatusSnapshot: - context_usage: float - """The usage of the context, in percentage.""" - yolo_enabled: bool = False - """Whether YOLO (auto-approve) mode is enabled.""" - plan_mode: bool = False - """Whether plan mode (read-only research and planning) is active.""" - context_tokens: int = 0 - """The number of tokens currently in the context.""" - max_context_tokens: int = 0 - """The maximum number of tokens the context can hold.""" - mcp_status: MCPStatusSnapshot | None = None - """The current MCP startup snapshot, if MCP is configured.""" - - -@runtime_checkable -class Soul(Protocol): - @property - def name(self) -> str: - """The name of the soul.""" - ... - - @property - def model_name(self) -> str: - """The name of the LLM model used by the soul. Empty string if LLM is not set.""" - ... - - @property - def model_capabilities(self) -> set[ModelCapability] | None: - """The capabilities of the LLM model used by the soul. None if LLM is not set.""" - ... - - @property - def thinking(self) -> bool | None: - """ - Whether thinking mode is currently enabled. - None if LLM is not set or thinking mode is not set explicitly. - """ - ... - - @property - def status(self) -> StatusSnapshot: - """The current status of the soul. The returned value is immutable.""" - ... - - @property - def hook_engine(self) -> HookEngine: - """The hook engine for this soul.""" - ... - - @property - def available_slash_commands(self) -> list[SlashCommand[Any]]: - """List of available slash commands supported by the soul.""" - ... - - async def run(self, user_input: str | list[ContentPart]): - """ - Run the agent with the given user input until the max steps or no more tool calls. - - Args: - user_input (str | list[ContentPart]): The user input to the agent. - Can be a slash command call or natural language input. - - Raises: - LLMNotSet: When the LLM is not set. - LLMNotSupported: When the LLM does not have required capabilities. - ChatProviderError: When the LLM provider returns an error. - MaxStepsReached: When the maximum number of steps is reached. - asyncio.CancelledError: When the run is cancelled by user. - """ - ... - - -type UILoopFn = Callable[[Wire], Coroutine[Any, Any, None]] -"""A long-running async function to visualize the agent behavior.""" - - -class RunCancelled(Exception): - """The run was cancelled by the cancel event.""" - - -async def run_soul( - soul: Soul, - user_input: str | list[ContentPart], - ui_loop_fn: UILoopFn, - cancel_event: asyncio.Event, - wire_file: WireFile | None = None, - runtime: Runtime | None = None, -) -> None: - """ - Run the soul with the given user input, connecting it to the UI loop with a `Wire`. - - `cancel_event` is a outside handle that can be used to cancel the run. When the - event is set, the run will be gracefully stopped and a `RunCancelled` will be raised. - - Raises: - LLMNotSet: When the LLM is not set. - LLMNotSupported: When the LLM does not have required capabilities. - ChatProviderError: When the LLM provider returns an error. - MaxStepsReached: When the maximum number of steps is reached. - RunCancelled: When the run is cancelled by the cancel event. - """ - wire = Wire(file_backend=wire_file) - wire_token = _current_wire.set(wire) - - logger.debug("Starting UI loop with function: {ui_loop_fn}", ui_loop_fn=ui_loop_fn) - ui_task = asyncio.create_task(ui_loop_fn(wire)) - - logger.debug("Starting soul run") - soul_task = asyncio.create_task(soul.run(user_input)) - notification_task = asyncio.create_task(_pump_notifications_to_wire(runtime, wire)) - - cancel_event_task = asyncio.create_task(cancel_event.wait()) - await asyncio.wait( - [soul_task, cancel_event_task], - return_when=asyncio.FIRST_COMPLETED, - ) - - try: - if cancel_event.is_set(): - logger.debug("Cancelling the run task") - soul_task.cancel() - try: - await soul_task - except asyncio.CancelledError: - raise RunCancelled from None - else: - assert soul_task.done() # either stop event is set or the run task is done - cancel_event_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await cancel_event_task - soul_task.result() # this will raise if any exception was raised in the run task - finally: - notification_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await notification_task - try: - await _deliver_notifications_to_wire_once(runtime, wire) - except Exception: - logger.exception("Failed to flush notifications to wire during shutdown") - logger.debug("Shutting down the UI loop") - # shutting down the wire should break the UI loop - wire.shutdown() - await wire.join() - try: - await asyncio.wait_for(ui_task, timeout=0.5) - except QueueShutDown: - logger.debug("UI loop shut down") - pass - except TimeoutError: - logger.warning("UI loop timed out") - finally: - _current_wire.reset(wire_token) - - -_current_wire = ContextVar[Wire | None]("current_wire", default=None) - - -def get_wire_or_none() -> Wire | None: - """ - Get the current wire or None. - Expect to be not None when called from anywhere in the agent loop. - """ - return _current_wire.get() - - -def wire_send(msg: WireMessage) -> None: - """ - Send a wire message to the current wire. - Take this as `print` and `input` for souls. - Souls should always use this function to send wire messages. - """ - wire = get_wire_or_none() - assert wire is not None, "Wire is expected to be set when soul is running" - wire.soul_side.send(msg) - - -async def _pump_notifications_to_wire(runtime: Runtime | None, wire: Wire) -> None: - while True: - try: - await _deliver_notifications_to_wire_once(runtime, wire) - except asyncio.CancelledError: - raise - except Exception: - logger.exception("Notification wire pump failed") - await asyncio.sleep(1.0) - - -async def _deliver_notifications_to_wire_once(runtime: Runtime | None, wire: Wire) -> None: - if runtime is None or runtime.role != "root": - return - - from kimi_cli.notifications import NotificationView, to_wire_notification - - def _send_notification(view: NotificationView) -> None: - wire.soul_side.send(to_wire_notification(view)) - - await runtime.notifications.deliver_pending( - "wire", - limit=8, - before_claim=runtime.background_tasks.reconcile, - on_notification=_send_notification, - ) diff --git a/src/kimi_cli/soul/agent.py b/src/kimi_cli/soul/agent.py deleted file mode 100644 index e05d18761..000000000 --- a/src/kimi_cli/soul/agent.py +++ /dev/null @@ -1,528 +0,0 @@ -from __future__ import annotations - -import asyncio -from dataclasses import asdict, dataclass -from datetime import datetime -from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal - -import pydantic -from jinja2 import Environment as JinjaEnvironment -from jinja2 import FileSystemLoader, StrictUndefined, TemplateError, UndefinedError -from kaos.path import KaosPath -from kosong.tooling import Toolset - -from kimi_cli.agentspec import load_agent_spec -from kimi_cli.approval_runtime import ApprovalRuntime -from kimi_cli.auth.oauth import OAuthManager -from kimi_cli.background import BackgroundTaskManager -from kimi_cli.config import Config -from kimi_cli.exception import MCPConfigError, SystemPromptTemplateError -from kimi_cli.llm import LLM -from kimi_cli.notifications import NotificationManager -from kimi_cli.session import Session -from kimi_cli.skill import ( - Skill, - discover_skills_from_roots, - index_skills, - resolve_skills_roots, -) -from kimi_cli.soul.approval import Approval, ApprovalState -from kimi_cli.soul.denwarenji import DenwaRenji -from kimi_cli.soul.toolset import KimiToolset -from kimi_cli.subagents.models import AgentTypeDefinition, ToolPolicy -from kimi_cli.subagents.registry import LaborMarket -from kimi_cli.subagents.store import SubagentStore -from kimi_cli.utils.environment import Environment -from kimi_cli.utils.logging import logger -from kimi_cli.utils.path import is_within_directory, list_directory -from kimi_cli.wire.root_hub import RootWireHub - -if TYPE_CHECKING: - from fastmcp.mcp_config import MCPConfig - - -@dataclass(frozen=True, slots=True, kw_only=True) -class BuiltinSystemPromptArgs: - """Builtin system prompt arguments.""" - - KIMI_NOW: str - """The current datetime.""" - KIMI_WORK_DIR: KaosPath - """The absolute path of current working directory.""" - KIMI_WORK_DIR_LS: str - """The directory listing of current working directory.""" - KIMI_AGENTS_MD: str # TODO: move to first message from system prompt - """The merged content of AGENTS.md files (from project root to work_dir).""" - KIMI_SKILLS: str - """Formatted information about available skills.""" - KIMI_ADDITIONAL_DIRS_INFO: str - """Formatted information about additional directories in the workspace.""" - KIMI_OS: str - """The operating system kind, e.g. 'Windows', 'macOS', 'Linux'.""" - KIMI_SHELL: str - """The shell executable used by the Shell tool, e.g. 'bash (`/bin/bash`)'.""" - - -_AGENTS_MD_MAX_BYTES = 32 * 1024 # 32 KiB - - -async def _find_project_root(work_dir: KaosPath) -> KaosPath: - """Walk up from *work_dir* to find the nearest directory containing ``.git``. - - Returns *work_dir* itself if no ``.git`` marker is found. - """ - current = work_dir - while True: - if await (current / ".git").exists(): - return current - parent = current.parent - if parent == current: # filesystem root - return work_dir - current = parent - - -async def _dirs_root_to_leaf(work_dir: KaosPath, project_root: KaosPath) -> list[KaosPath]: - """Return the list of directories from *project_root* down to *work_dir* (inclusive).""" - dirs: list[KaosPath] = [] - current = work_dir - while True: - dirs.append(current) - if current == project_root: - break - parent = current.parent - if parent == current: - break - current = parent - dirs.reverse() # root → leaf - return dirs - - -async def load_agents_md(work_dir: KaosPath) -> str | None: - """Discover and merge ``AGENTS.md`` files from the project root down to *work_dir*. - - For each directory on the path, the following candidates are checked in order: - - 1. ``.kimi/AGENTS.md`` — project-local kimi config (highest priority) - 2. ``AGENTS.md`` — standard location - 3. ``agents.md`` — lowercase variant (mutually exclusive with 2) - - Within a single directory, ``.kimi/AGENTS.md`` and ``AGENTS.md``/``agents.md`` - are **both** loaded (with ``.kimi/`` first), but ``AGENTS.md`` and ``agents.md`` - are mutually exclusive (uppercase wins). - - All discovered files are concatenated root→leaf, separated by ``\\n\\n``, with - source annotations. Total size is capped at :data:`_AGENTS_MD_MAX_BYTES`. - Budget is allocated leaf-first so deeper (more specific) files are never - truncated in favour of shallower ones. - """ - project_root = await _find_project_root(work_dir) - dirs = await _dirs_root_to_leaf(work_dir, project_root) - - # Phase 1: collect all candidate files (root → leaf order) - discovered: list[tuple[KaosPath, str]] = [] # (path, content) - for d in dirs: - # .kimi/AGENTS.md is always checked independently (can coexist with root-level file) - kimi_path = d / ".kimi" / "AGENTS.md" - # AGENTS.md and agents.md are mutually exclusive (uppercase wins) - root_candidates = [d / "AGENTS.md", d / "agents.md"] - - candidates: list[KaosPath] = [] - if await kimi_path.is_file(): - candidates.append(kimi_path) - for rc in root_candidates: - if await rc.is_file(): - candidates.append(rc) - break - - for path in candidates: - content = (await path.read_text()).strip() - if content: - discovered.append((path, content)) - logger.info("Loaded agents.md: {path}", path=path) - - if not discovered: - logger.info( - "No AGENTS.md found from {root} to {cwd}", - root=project_root, - cwd=work_dir, - ) - return None - - # Phase 2: allocate budget leaf-first so deeper (more specific) files - # are never truncated in favour of shallower ones. - # The annotation overhead (\n and \n\n separators) - # is included in the budget so the final output never exceeds the limit. - remaining = _AGENTS_MD_MAX_BYTES - budgeted: list[tuple[KaosPath, str]] = [None] * len(discovered) # type: ignore[list-item] - for i in reversed(range(len(discovered))): - path, content = discovered[i] - annotation = f"\n" - # Reserve space for the annotation and the \n\n separator between parts - separator_cost = len(b"\n\n") if i < len(discovered) - 1 else 0 - overhead = len(annotation.encode()) + separator_cost - remaining -= overhead - if remaining <= 0: - budgeted[i] = (path, "") - remaining = 0 - continue - encoded = content.encode() - if len(encoded) > remaining: - content = encoded[:remaining].decode(errors="ignore").strip() - logger.warning("AGENTS.md truncated due to size limit: {path}", path=path) - remaining -= len(content.encode()) - budgeted[i] = (path, content) - - # Phase 3: assemble in root → leaf order, skipping entries emptied by truncation - parts: list[str] = [] - for path, content in budgeted: - if content: - parts.append(f"\n{content}") - - return "\n\n".join(parts) if parts else None - - -@dataclass(slots=True, kw_only=True) -class Runtime: - """Agent runtime.""" - - config: Config - oauth: OAuthManager - llm: LLM | None # we do not freeze the `Runtime` dataclass because LLM can be changed - session: Session - builtin_args: BuiltinSystemPromptArgs - denwa_renji: DenwaRenji - approval: Approval - labor_market: LaborMarket - environment: Environment - notifications: NotificationManager - background_tasks: BackgroundTaskManager - skills: dict[str, Skill] - additional_dirs: list[KaosPath] - skills_dirs: list[KaosPath] - subagent_store: SubagentStore | None = None - approval_runtime: ApprovalRuntime | None = None - root_wire_hub: RootWireHub | None = None - subagent_id: str | None = None - subagent_type: str | None = None - role: Literal["root", "subagent"] = "root" - hook_engine: Any = None - """HookEngine instance, set by KimiCLI after soul creation.""" - - def __post_init__(self) -> None: - if self.subagent_store is None: - self.subagent_store = SubagentStore(self.session) - if self.root_wire_hub is None: - self.root_wire_hub = RootWireHub() - if self.approval_runtime is None: - self.approval_runtime = ApprovalRuntime() - self.approval_runtime.bind_root_wire_hub(self.root_wire_hub) - self.approval.set_runtime(self.approval_runtime) - self.background_tasks.bind_runtime(self) - - @staticmethod - async def create( - config: Config, - oauth: OAuthManager, - llm: LLM | None, - session: Session, - yolo: bool, - skills_dirs: list[KaosPath] | None = None, - ) -> Runtime: - ls_output, agents_md, environment = await asyncio.gather( - list_directory(session.work_dir), - load_agents_md(session.work_dir), - Environment.detect(), - ) - - # Discover and format skills - skills_roots = await resolve_skills_roots( - session.work_dir, - skills_dirs=skills_dirs, - ) - # Canonicalize so symlinked skill directories match resolved paths - skills_roots_canonical = [r.canonical() for r in skills_roots] - skills = await discover_skills_from_roots(skills_roots) - skills_by_name = index_skills(skills) - logger.info("Discovered {count} skill(s)", count=len(skills)) - skills_formatted = "\n".join( - ( - f"- {skill.name}\n" - f" - Path: {skill.skill_md_file}\n" - f" - Description: {skill.description}" - ) - for skill in skills - ) - - # Restore additional directories from session state, pruning stale entries - additional_dirs: list[KaosPath] = [] - pruned = False - valid_dir_strs: list[str] = [] - for dir_str in session.state.additional_dirs: - d = KaosPath(dir_str).canonical() - if await d.is_dir(): - additional_dirs.append(d) - valid_dir_strs.append(dir_str) - else: - logger.warning( - "Additional directory no longer exists, removing from state: {dir}", - dir=dir_str, - ) - pruned = True - if pruned: - session.state.additional_dirs = valid_dir_strs - session.save_state() - - # Format additional dirs info for system prompt - additional_dirs_info = "" - if additional_dirs: - parts: list[str] = [] - for d in additional_dirs: - try: - dir_ls = await list_directory(d) - except OSError: - logger.warning( - "Cannot list additional directory, skipping listing: {dir}", dir=d - ) - dir_ls = "[directory not readable]" - parts.append(f"### `{d}`\n\n```\n{dir_ls}\n```") - additional_dirs_info = "\n\n".join(parts) - - # Merge CLI flag with persisted session state - effective_yolo = yolo or session.state.approval.yolo - saved_actions = set(session.state.approval.auto_approve_actions) - - def _on_approval_change() -> None: - session.state.approval.yolo = approval_state.yolo - session.state.approval.auto_approve_actions = set(approval_state.auto_approve_actions) - session.save_state() - - approval_state = ApprovalState( - yolo=effective_yolo, - auto_approve_actions=saved_actions, - on_change=_on_approval_change, - ) - notifications = NotificationManager( - session.context_file.parent / "notifications", - config.notifications, - ) - - return Runtime( - config=config, - oauth=oauth, - llm=llm, - session=session, - builtin_args=BuiltinSystemPromptArgs( - KIMI_NOW=datetime.now().astimezone().isoformat(), - KIMI_WORK_DIR=session.work_dir, - KIMI_WORK_DIR_LS=ls_output, - KIMI_AGENTS_MD=agents_md or "", - KIMI_SKILLS=skills_formatted or "No skills found.", - KIMI_ADDITIONAL_DIRS_INFO=additional_dirs_info, - KIMI_OS=environment.os_kind, - KIMI_SHELL=f"{environment.shell_name} (`{environment.shell_path}`)", - ), - denwa_renji=DenwaRenji(), - approval=Approval(state=approval_state), - labor_market=LaborMarket(), - environment=environment, - notifications=notifications, - background_tasks=BackgroundTaskManager( - session, - config.background, - notifications=notifications, - ), - skills=skills_by_name, - additional_dirs=additional_dirs, - # Only expose skills roots outside the workspace for Glob access; - # project-level roots are already within work_dir. - skills_dirs=[ - r for r in skills_roots_canonical if not is_within_directory(r, session.work_dir) - ], - subagent_store=SubagentStore(session), - approval_runtime=ApprovalRuntime(), - root_wire_hub=RootWireHub(), - role="root", - ) - - def copy_for_subagent( - self, - *, - agent_id: str, - subagent_type: str, - llm_override: LLM | None = None, - ) -> Runtime: - """Clone runtime for a subagent.""" - return Runtime( - config=self.config, - oauth=self.oauth, - llm=llm_override if llm_override is not None else self.llm, - session=self.session, - builtin_args=self.builtin_args, - denwa_renji=DenwaRenji(), # subagent must have its own DenwaRenji - approval=self.approval.share(), - labor_market=self.labor_market, - environment=self.environment, - notifications=self.notifications, - background_tasks=self.background_tasks.copy_for_role("subagent"), - skills=self.skills, - # Share the same list reference so /add-dir mutations propagate to all agents - additional_dirs=self.additional_dirs, - skills_dirs=self.skills_dirs, - subagent_store=self.subagent_store, - approval_runtime=self.approval_runtime, - root_wire_hub=self.root_wire_hub, - subagent_id=agent_id, - subagent_type=subagent_type, - role="subagent", - ) - - -@dataclass(frozen=True, slots=True, kw_only=True) -class Agent: - """The loaded agent.""" - - name: str - system_prompt: str - toolset: Toolset - runtime: Runtime - """Each agent has its own runtime, which should be derived from its main agent.""" - - -async def load_agent( - agent_file: Path, - runtime: Runtime, - *, - mcp_configs: list[MCPConfig] | list[dict[str, Any]], - start_mcp_loading: bool = True, -) -> Agent: - """ - Load agent from specification file. - - Raises: - FileNotFoundError: When the agent file is not found. - AgentSpecError(KimiCLIException, ValueError): When the agent specification is invalid. - SystemPromptTemplateError(KimiCLIException, ValueError): When the system prompt template - is invalid. - InvalidToolError(KimiCLIException, ValueError): When any tool cannot be loaded. - MCPConfigError(KimiCLIException, ValueError): When any MCP configuration is invalid. - MCPRuntimeError(KimiCLIException, RuntimeError): When any MCP server cannot be connected. - """ - logger.info("Loading agent: {agent_file}", agent_file=agent_file) - agent_spec = load_agent_spec(agent_file) - - system_prompt = _load_system_prompt( - agent_spec.system_prompt_path, - agent_spec.system_prompt_args, - runtime.builtin_args, - ) - - # Register built-in subagent types before loading tools because some tools render - # descriptions from the labor market on initialization. - for subagent_name, subagent_spec in agent_spec.subagents.items(): - logger.debug( - "Registering builtin subagent type: {subagent_name}", subagent_name=subagent_name - ) - builtin_spec = load_agent_spec(subagent_spec.path) - tool_policy = ( - ToolPolicy(mode="allowlist", tools=tuple(builtin_spec.allowed_tools)) - if builtin_spec.allowed_tools is not None - else ToolPolicy(mode="inherit") - ) - runtime.labor_market.add_builtin_type( - AgentTypeDefinition( - name=subagent_name, - description=subagent_spec.description, - agent_file=subagent_spec.path, - when_to_use=builtin_spec.when_to_use, - default_model=builtin_spec.model, - tool_policy=tool_policy, - ) - ) - - toolset = KimiToolset() - tool_deps = { - KimiToolset: toolset, - Runtime: runtime, - # TODO: remove all the following dependencies and use Runtime instead - Config: runtime.config, - BuiltinSystemPromptArgs: runtime.builtin_args, - Session: runtime.session, - DenwaRenji: runtime.denwa_renji, - Approval: runtime.approval, - LaborMarket: runtime.labor_market, - Environment: runtime.environment, - } - tools = agent_spec.allowed_tools if agent_spec.allowed_tools is not None else agent_spec.tools - if agent_spec.exclude_tools: - logger.debug("Excluding tools: {tools}", tools=agent_spec.exclude_tools) - tools = [tool for tool in tools if tool not in agent_spec.exclude_tools] - toolset.load_tools(tools, tool_deps) - - # Load plugin tools - from kimi_cli.plugin.manager import get_plugins_dir - from kimi_cli.plugin.tool import load_plugin_tools - - plugin_tools = load_plugin_tools(get_plugins_dir(), runtime.config, approval=runtime.approval) - for plugin_tool in plugin_tools: - if toolset.find(plugin_tool.name) is not None: - logger.warning( - "Plugin tool '{name}' conflicts with an existing tool, skipping", - name=plugin_tool.name, - ) - continue - toolset.add(plugin_tool) - - if mcp_configs: - validated_mcp_configs: list[MCPConfig] = [] - if mcp_configs: - from fastmcp.mcp_config import MCPConfig - - for mcp_config in mcp_configs: - try: - validated_mcp_configs.append( - mcp_config - if isinstance(mcp_config, MCPConfig) - else MCPConfig.model_validate(mcp_config) - ) - except pydantic.ValidationError as e: - raise MCPConfigError(f"Invalid MCP config: {e}") from e - if start_mcp_loading: - await toolset.load_mcp_tools(validated_mcp_configs, runtime, in_background=True) - else: - toolset.defer_mcp_tool_loading(validated_mcp_configs, runtime) - - return Agent( - name=agent_spec.name, - system_prompt=system_prompt, - toolset=toolset, - runtime=runtime, - ) - - -def _load_system_prompt( - path: Path, args: dict[str, str], builtin_args: BuiltinSystemPromptArgs -) -> str: - logger.info("Loading system prompt: {path}", path=path) - system_prompt = path.read_text(encoding="utf-8").strip() - logger.debug( - "Substituting system prompt with builtin args: {builtin_args}, spec args: {spec_args}", - builtin_args=builtin_args, - spec_args=args, - ) - env = JinjaEnvironment( - loader=FileSystemLoader(path.parent), - keep_trailing_newline=True, - lstrip_blocks=True, - trim_blocks=True, - variable_start_string="${", - variable_end_string="}", - undefined=StrictUndefined, - ) - try: - template = env.from_string(system_prompt) - return template.render(asdict(builtin_args), **args) - except UndefinedError as exc: - raise SystemPromptTemplateError(f"Missing system prompt arg in {path}: {exc}") from exc - except TemplateError as exc: - raise SystemPromptTemplateError(f"Invalid system prompt template: {path}: {exc}") from exc diff --git a/src/kimi_cli/soul/agent.ts b/src/kimi_cli/soul/agent.ts new file mode 100644 index 000000000..4940731e4 --- /dev/null +++ b/src/kimi_cli/soul/agent.ts @@ -0,0 +1,445 @@ +/** + * Agent & Runtime — corresponds to Python soul/agent.py + * Runtime execution environment and Agent loading. + */ + +import type { Config, LoopControl } from "../config.ts"; +import type { LLM } from "../llm.ts"; +import type { Session } from "../session.ts"; +import type { HookEngine } from "../hooks/engine.ts"; +import type { ModelCapability } from "../types.ts"; +import { Approval, ApprovalState } from "./approval.ts"; +import { KimiToolset } from "./toolset.ts"; +import { SlashCommandRegistry, createDefaultRegistry } from "./slash.ts"; +import { Context } from "./context.ts"; +import { logger } from "../utils/logging.ts"; + +// ── Built-in system prompt args ────────────────────── + +export interface BuiltinSystemPromptArgs { + KIMI_NOW: string; + KIMI_WORK_DIR: string; + KIMI_WORK_DIR_LS: string; + KIMI_AGENTS_MD: string; + KIMI_SKILLS: string; + KIMI_ADDITIONAL_DIRS_INFO: string; + KIMI_OS: string; + KIMI_SHELL: string; +} + +// ── Runtime ────────────────────────────────────────── + +export class Runtime { + config: Config; + llm: LLM | null; + session: Session; + approval: Approval; + hookEngine: HookEngine; + builtinArgs: BuiltinSystemPromptArgs; + role: "root" | "subagent"; + additionalDirs: string[]; + + constructor(opts: { + config: Config; + llm: LLM | null; + session: Session; + approval: Approval; + hookEngine: HookEngine; + builtinArgs: BuiltinSystemPromptArgs; + role?: "root" | "subagent"; + additionalDirs?: string[]; + }) { + this.config = opts.config; + this.llm = opts.llm; + this.session = opts.session; + this.approval = opts.approval; + this.hookEngine = opts.hookEngine; + this.builtinArgs = opts.builtinArgs; + this.role = opts.role ?? "root"; + this.additionalDirs = opts.additionalDirs ?? []; + } + + get loopControl(): LoopControl { + return this.config.loop_control; + } + + /** Create runtime with defaults. */ + static async create(opts: { + config: Config; + llm: LLM | null; + session: Session; + hookEngine: HookEngine; + }): Promise { + const workDir = opts.session.workDir; + + // Build system prompt args + let workDirLs = ""; + try { + const result = await Bun.$`ls -la ${workDir}`.quiet().text(); + workDirLs = result.trim(); + } catch { + workDirLs = "(unable to list directory)"; + } + + const osType = + process.platform === "darwin" + ? "macOS" + : process.platform === "win32" + ? "Windows" + : "Linux"; + + const shell = process.env.SHELL ?? "/bin/bash"; + + const builtinArgs: BuiltinSystemPromptArgs = { + KIMI_NOW: new Date().toISOString(), + KIMI_WORK_DIR: workDir, + KIMI_WORK_DIR_LS: workDirLs, + KIMI_AGENTS_MD: await loadAgentsMd(workDir) ?? "", + KIMI_SKILLS: "", // TODO: list skills + KIMI_ADDITIONAL_DIRS_INFO: opts.session.state.additional_dirs.length > 0 + ? `Additional directories: ${opts.session.state.additional_dirs.join(", ")}` + : "", + KIMI_OS: osType, + KIMI_SHELL: shell, + }; + + // Restore additional directories from session state + const additionalDirs = opts.session.state.additional_dirs.filter( + (d: string) => { + try { + const { statSync } = require("node:fs"); + return statSync(d).isDirectory(); + } catch { + return false; + } + }, + ); + + // Restore approval state from session + const approvalState = new ApprovalState({ + yolo: + opts.config.default_yolo || opts.session.state.approval.yolo, + autoApproveActions: new Set( + opts.session.state.approval.auto_approve_actions, + ), + }); + + const approval = new Approval({ state: approvalState }); + + return new Runtime({ + config: opts.config, + llm: opts.llm, + session: opts.session, + approval, + hookEngine: opts.hookEngine, + builtinArgs, + additionalDirs, + }); + } + + /** Create a copy for subagents with shared state. */ + copyForSubagent(): Runtime { + return new Runtime({ + config: this.config, + llm: this.llm, + session: this.session, + approval: this.approval.share(), + hookEngine: this.hookEngine, + builtinArgs: { + ...this.builtinArgs, + KIMI_NOW: new Date().toISOString(), + }, + role: "subagent", + // Share the same list reference so /add-dir mutations propagate to all agents + additionalDirs: this.additionalDirs, + }); + } +} + +// ── Agent ────────────────────────────────────────────── + +export class Agent { + readonly name: string; + readonly systemPrompt: string; + readonly toolset: KimiToolset; + readonly runtime: Runtime; + readonly slashCommands: SlashCommandRegistry; + + constructor(opts: { + name: string; + systemPrompt: string; + toolset: KimiToolset; + runtime: Runtime; + slashCommands?: SlashCommandRegistry; + }) { + this.name = opts.name; + this.systemPrompt = opts.systemPrompt; + this.toolset = opts.toolset; + this.runtime = opts.runtime; + this.slashCommands = opts.slashCommands ?? createDefaultRegistry(); + } + + get modelCapabilities(): Set | null { + return this.runtime.llm?.capabilities ?? null; + } + + get modelName(): string { + return this.runtime.llm?.modelName ?? "unknown"; + } +} + +// ── Agent loader ───────────────────────────────────── + +/** + * Load an agent with its toolset and system prompt. + */ +export async function loadAgent(opts: { + runtime: Runtime; + agentName?: string; + systemPromptOverride?: string; +}): Promise { + const { runtime, agentName = "default" } = opts; + + // Load system prompt + let systemPrompt = opts.systemPromptOverride ?? ""; + if (!systemPrompt) { + systemPrompt = await loadSystemPrompt(agentName, runtime.builtinArgs); + } + + // Create toolset + const toolset = new KimiToolset({ + context: { + workingDir: runtime.session.workDir, + signal: new AbortController().signal, + approval: async (toolName: string, _action: string, description: string) => { + const result = await runtime.approval.request( + toolName, + toolName, + description, + ); + return result.approved ? "approve" : "reject"; + }, + wireEmit: () => {}, // Will be wired by KimiSoul + serviceConfig: { + moonshotSearch: runtime.config.services.moonshot_search + ? { + baseUrl: runtime.config.services.moonshot_search.base_url, + apiKey: runtime.config.services.moonshot_search.api_key, + customHeaders: runtime.config.services.moonshot_search.custom_headers, + } + : undefined, + moonshotFetch: runtime.config.services.moonshot_fetch + ? { + baseUrl: runtime.config.services.moonshot_fetch.base_url, + apiKey: runtime.config.services.moonshot_fetch.api_key, + customHeaders: runtime.config.services.moonshot_fetch.custom_headers, + } + : undefined, + }, + }, + hookEngine: runtime.hookEngine, + }); + + // Register built-in tools + await registerBuiltinTools(toolset); + + return new Agent({ + name: agentName, + systemPrompt, + toolset, + runtime, + }); +} + +async function loadSystemPrompt( + agentName: string, + args: BuiltinSystemPromptArgs, +): Promise { + // Try to load from agents/default/system.md + const paths = [ + `src/kimi_cli/agents/${agentName}/system.md`, + `agents/${agentName}/system.md`, + ]; + + for (const p of paths) { + const file = Bun.file(p); + if (await file.exists()) { + let template = await file.text(); + // Simple template substitution (${VAR} syntax) + for (const [key, value] of Object.entries(args)) { + template = template.replaceAll(`\${${key}}`, String(value)); + } + return template; + } + } + + // Fallback system prompt + return [ + "You are Kimi, an AI assistant running in a terminal.", + `Current working directory: ${args.KIMI_WORK_DIR}`, + `OS: ${args.KIMI_OS}, Shell: ${args.KIMI_SHELL}`, + `Current date: ${args.KIMI_NOW}`, + "", + "You have access to tools for reading/writing files, running shell commands,", + "and searching the web. Use them to help the user with their tasks.", + ].join("\n"); +} + +async function registerBuiltinTools(toolset: KimiToolset): Promise { + // Import and register all built-in tools + const toolModules = [ + () => import("../tools/file/read.ts"), + () => import("../tools/file/write.ts"), + () => import("../tools/file/replace.ts"), + () => import("../tools/file/glob.ts"), + () => import("../tools/file/grep.ts"), + () => import("../tools/shell/shell.ts"), + () => import("../tools/web/fetch.ts"), + () => import("../tools/web/search.ts"), + () => import("../tools/think/think.ts"), + () => import("../tools/ask_user/ask_user.ts"), + () => import("../tools/todo/todo.ts"), + () => import("../tools/plan/plan.ts"), + ]; + + for (const loadModule of toolModules) { + try { + const mod = await loadModule(); + // Find exported classes that look like tools + for (const [_key, value] of Object.entries(mod)) { + if ( + typeof value === "function" && + value.prototype && + typeof value.prototype.execute === "function" && + typeof value.prototype.toDefinition === "function" + ) { + try { + const instance = new (value as new () => any)(); + if (instance.name) { + toolset.add(instance); + } + } catch { + // Some tools need constructor args, skip + } + } + } + } catch (err) { + logger.warn(`Failed to load tool module: ${err}`); + } + } +} + +// ── AGENTS.md loader ──────────────────────────────── + +const AGENTS_MD_MAX_BYTES = 32 * 1024; // 32 KiB + +/** + * Find the nearest git root by walking up from workDir. + */ +async function findProjectRoot(workDir: string): Promise { + const { resolve, dirname } = await import("node:path"); + let current = resolve(workDir); + while (true) { + const gitFile = Bun.file(`${current}/.git`); + if (await gitFile.exists()) return current; + const parent = dirname(current); + if (parent === current) return resolve(workDir); + current = parent; + } +} + +/** + * Return the list of directories from projectRoot down to workDir (inclusive). + */ +function dirsRootToLeaf(workDir: string, projectRoot: string): string[] { + const { resolve, dirname } = require("node:path"); + const dirs: string[] = []; + let current = resolve(workDir); + const root = resolve(projectRoot); + while (true) { + dirs.push(current); + if (current === root) break; + const parent = dirname(current); + if (parent === current) break; + current = parent; + } + dirs.reverse(); // root → leaf + return dirs; +} + +/** + * Discover and merge AGENTS.md files from the project root down to workDir. + * Matches Python's `load_agents_md` behavior. + */ +export async function loadAgentsMd(workDir: string): Promise { + const projectRoot = await findProjectRoot(workDir); + const dirs = dirsRootToLeaf(workDir, projectRoot); + + // Phase 1: collect all candidate files (root → leaf order) + const discovered: { path: string; content: string }[] = []; + for (const d of dirs) { + const candidates: string[] = []; + + // .kimi/AGENTS.md — highest priority + const kimiPath = `${d}/.kimi/AGENTS.md`; + if (await Bun.file(kimiPath).exists()) { + candidates.push(kimiPath); + } + + // AGENTS.md or agents.md — mutually exclusive + const upperPath = `${d}/AGENTS.md`; + const lowerPath = `${d}/agents.md`; + if (await Bun.file(upperPath).exists()) { + candidates.push(upperPath); + } else if (await Bun.file(lowerPath).exists()) { + candidates.push(lowerPath); + } + + for (const path of candidates) { + const content = (await Bun.file(path).text()).trim(); + if (content) { + discovered.push({ path, content }); + logger.info(`Loaded agents.md: ${path}`); + } + } + } + + if (discovered.length === 0) return null; + + // Phase 2: allocate budget leaf-first + let remaining = AGENTS_MD_MAX_BYTES; + const budgeted: { path: string; content: string }[] = new Array(discovered.length); + for (let i = discovered.length - 1; i >= 0; i--) { + const { path, content } = discovered[i]!; + const annotation = `\n`; + const separatorCost = i < discovered.length - 1 ? 2 : 0; // "\n\n" + const overhead = Buffer.byteLength(annotation) + separatorCost; + remaining -= overhead; + if (remaining <= 0) { + budgeted[i] = { path, content: "" }; + remaining = 0; + continue; + } + const encoded = Buffer.from(content); + if (encoded.length > remaining) { + budgeted[i] = { + path, + content: encoded.subarray(0, remaining).toString("utf-8").trim(), + }; + remaining = 0; + } else { + budgeted[i] = { path, content }; + remaining -= encoded.length; + } + } + + // Phase 3: assemble root → leaf + const parts: string[] = []; + for (const { path, content } of budgeted) { + if (content) { + parts.push(`\n${content}`); + } + } + + return parts.length > 0 ? parts.join("\n\n") : null; +} diff --git a/src/kimi_cli/soul/approval.py b/src/kimi_cli/soul/approval.py deleted file mode 100644 index f25ef7786..000000000 --- a/src/kimi_cli/soul/approval.py +++ /dev/null @@ -1,171 +0,0 @@ -from __future__ import annotations - -import uuid -from collections.abc import Callable -from typing import Literal - -from kimi_cli.approval_runtime import ( - ApprovalCancelledError, - ApprovalRuntime, - ApprovalSource, - get_current_approval_source_or_none, -) -from kimi_cli.soul.toolset import get_current_tool_call_or_none -from kimi_cli.tools.utils import ToolRejectedError -from kimi_cli.utils.logging import logger -from kimi_cli.wire.types import DisplayBlock - -type Response = Literal["approve", "approve_for_session", "reject"] - - -class ApprovalResult: - """Result of an approval request. Behaves as bool for backward compatibility.""" - - __slots__ = ("approved", "feedback") - - def __init__(self, approved: bool, feedback: str = ""): - self.approved = approved - self.feedback = feedback - - def __bool__(self) -> bool: - return self.approved - - def rejection_error(self) -> ToolRejectedError: - if self.feedback: - return ToolRejectedError( - message=(f"The tool call is rejected by the user. User feedback: {self.feedback}"), - brief=f"Rejected: {self.feedback}", - has_feedback=True, - ) - source = get_current_approval_source_or_none() - is_subagent = source is not None and source.agent_id is not None - if is_subagent: - return ToolRejectedError( - message=( - "The tool call is rejected by the user. " - "Try a different approach to complete your task, or explain the " - "limitation in your summary if no alternative is available. " - "Do not retry the same tool call, and do not attempt to bypass " - "this restriction through indirect means." - ), - ) - return ToolRejectedError() - - -class ApprovalState: - def __init__( - self, - yolo: bool = False, - auto_approve_actions: set[str] | None = None, - on_change: Callable[[], None] | None = None, - ): - self.yolo = yolo - self.auto_approve_actions: set[str] = auto_approve_actions or set() - """Set of action names that should automatically be approved.""" - self._on_change = on_change - - def notify_change(self) -> None: - if self._on_change is not None: - self._on_change() - - -class Approval: - def __init__( - self, - yolo: bool = False, - *, - state: ApprovalState | None = None, - runtime: ApprovalRuntime | None = None, - ): - self._state = state or ApprovalState(yolo=yolo) - self._runtime = runtime or ApprovalRuntime() - - def share(self) -> Approval: - """Create a new approval queue that shares state (yolo + auto-approve).""" - return Approval(state=self._state, runtime=self._runtime) - - def set_runtime(self, runtime: ApprovalRuntime) -> None: - self._runtime = runtime - - @property - def runtime(self) -> ApprovalRuntime: - return self._runtime - - def set_yolo(self, yolo: bool) -> None: - self._state.yolo = yolo - self._state.notify_change() - - def is_yolo(self) -> bool: - return self._state.yolo - - async def request( - self, - sender: str, - action: str, - description: str, - display: list[DisplayBlock] | None = None, - ) -> ApprovalResult: - """ - Request approval for the given action. Intended to be called by tools. - - Args: - sender (str): The name of the sender. - action (str): The action to request approval for. - This is used to identify the action for auto-approval. - description (str): The description of the action. This is used to display to the user. - - Returns: - ApprovalResult: Result with ``approved`` flag and optional ``feedback``. - Behaves as ``bool`` via ``__bool__``, so ``if not result:`` works. - - Raises: - RuntimeError: If the approval is requested from outside a tool call. - """ - tool_call = get_current_tool_call_or_none() - if tool_call is None: - raise RuntimeError("Approval must be requested from a tool call.") - - logger.debug( - "{tool_name} ({tool_call_id}) requesting approval: {action} {description}", - tool_name=tool_call.function.name, - tool_call_id=tool_call.id, - action=action, - description=description, - ) - if self._state.yolo: - return ApprovalResult(approved=True) - - if action in self._state.auto_approve_actions: - return ApprovalResult(approved=True) - - request_id = str(uuid.uuid4()) - display_blocks = display or [] - source = get_current_approval_source_or_none() or ApprovalSource( - kind="foreground_turn", - id=tool_call.id, - ) - self._runtime.create_request( - request_id=request_id, - tool_call_id=tool_call.id, - sender=sender, - action=action, - description=description, - display=display_blocks, - source=source, - ) - try: - response, feedback = await self._runtime.wait_for_response(request_id) - except ApprovalCancelledError: - return ApprovalResult(approved=False) - match response: - case "approve": - return ApprovalResult(approved=True) - case "approve_for_session": - self._state.auto_approve_actions.add(action) - self._state.notify_change() - for pending in self._runtime.list_pending(): - if pending.action == action: - self._runtime.resolve(pending.id, "approve") - return ApprovalResult(approved=True) - case "reject": - return ApprovalResult(approved=False, feedback=feedback) diff --git a/src/kimi_cli/soul/approval.ts b/src/kimi_cli/soul/approval.ts new file mode 100644 index 000000000..4eeb8d25e --- /dev/null +++ b/src/kimi_cli/soul/approval.ts @@ -0,0 +1,134 @@ +/** + * Approval system — corresponds to Python soul/approval.py + * High-level approval request/response logic used by tools. + */ + +import { randomUUID } from "node:crypto"; +import { + ApprovalRuntime, + ApprovalCancelledError, + type ApprovalResponseKind, + type ApprovalSource, +} from "../approval_runtime/index.ts"; + +// ── ApprovalResult ────────────────────────────────────── + +export class ApprovalResult { + readonly approved: boolean; + readonly feedback: string; + + constructor(approved: boolean, feedback = "") { + this.approved = approved; + this.feedback = feedback; + } + + /** Allow `if (result)` / `if (!result)` usage. */ + valueOf(): boolean { + return this.approved; + } +} + +// ── ApprovalState ─────────────────────────────────────── + +export class ApprovalState { + yolo: boolean; + autoApproveActions: Set; + private onChange?: () => void; + + constructor(opts?: { yolo?: boolean; autoApproveActions?: Set; onChange?: () => void }) { + this.yolo = opts?.yolo ?? false; + this.autoApproveActions = opts?.autoApproveActions ?? new Set(); + this.onChange = opts?.onChange; + } + + notifyChange(): void { + this.onChange?.(); + } +} + +// ── Approval ──────────────────────────────────────────── + +export class Approval { + private state: ApprovalState; + private _runtime: ApprovalRuntime; + + constructor(opts?: { yolo?: boolean; state?: ApprovalState; runtime?: ApprovalRuntime }) { + this.state = opts?.state ?? new ApprovalState({ yolo: opts?.yolo }); + this._runtime = opts?.runtime ?? new ApprovalRuntime(); + } + + /** Create a new Approval that shares state (yolo + auto-approve). */ + share(): Approval { + return new Approval({ state: this.state, runtime: this._runtime }); + } + + get runtime(): ApprovalRuntime { + return this._runtime; + } + + setRuntime(runtime: ApprovalRuntime): void { + this._runtime = runtime; + } + + setYolo(yolo: boolean): void { + this.state.yolo = yolo; + this.state.notifyChange(); + } + + isYolo(): boolean { + return this.state.yolo; + } + + async request( + sender: string, + action: string, + description: string, + opts?: { + toolCallId?: string; + display?: unknown[]; + source?: ApprovalSource; + }, + ): Promise { + const toolCallId = opts?.toolCallId ?? randomUUID(); + const source: ApprovalSource = opts?.source ?? { kind: "foreground_turn", id: toolCallId }; + + if (this.state.yolo) return new ApprovalResult(true); + if (this.state.autoApproveActions.has(action)) return new ApprovalResult(true); + + const requestId = randomUUID(); + this._runtime.createRequest({ + requestId, + toolCallId, + sender, + action, + description, + display: opts?.display, + source, + }); + + try { + const [response, feedback] = await this._runtime.waitForResponse(requestId); + switch (response) { + case "approve": + return new ApprovalResult(true); + case "approve_for_session": + this.state.autoApproveActions.add(action); + this.state.notifyChange(); + // Auto-approve other pending requests for the same action + for (const pending of this._runtime.listPending()) { + if (pending.action === action) { + this._runtime.resolve(pending.id, "approve"); + } + } + return new ApprovalResult(true); + case "reject": + return new ApprovalResult(false, feedback); + } + } catch (err) { + if (err instanceof ApprovalCancelledError) { + return new ApprovalResult(false); + } + throw err; + } + } +} diff --git a/src/kimi_cli/soul/compaction.py b/src/kimi_cli/soul/compaction.py deleted file mode 100644 index 7db37d5d8..000000000 --- a/src/kimi_cli/soul/compaction.py +++ /dev/null @@ -1,189 +0,0 @@ -from __future__ import annotations - -from collections.abc import Sequence -from typing import TYPE_CHECKING, NamedTuple, Protocol, runtime_checkable - -import kosong -from kosong.chat_provider import TokenUsage -from kosong.message import Message -from kosong.tooling.empty import EmptyToolset - -import kimi_cli.prompts as prompts -from kimi_cli.llm import LLM -from kimi_cli.soul.message import system -from kimi_cli.utils.logging import logger -from kimi_cli.wire.types import ContentPart, TextPart, ThinkPart - - -class CompactionResult(NamedTuple): - messages: Sequence[Message] - usage: TokenUsage | None - - @property - def estimated_token_count(self) -> int: - """Estimate the token count of the compacted messages. - - When LLM usage is available, ``usage.output`` gives the exact token count - of the generated summary (the first message). Preserved messages (all - subsequent messages) are estimated from their text length. - - When usage is not available (no compaction LLM call was made), all - messages are estimated from text length. - - The estimate is intentionally conservative — it will be replaced by the - real value on the next LLM call. - """ - if self.usage is not None and len(self.messages) > 0: - summary_tokens = self.usage.output - preserved_tokens = estimate_text_tokens(self.messages[1:]) - return summary_tokens + preserved_tokens - - return estimate_text_tokens(self.messages) - - -def estimate_text_tokens(messages: Sequence[Message]) -> int: - """Estimate tokens from message text content using a character-based heuristic.""" - total_chars = 0 - for msg in messages: - for part in msg.content: - if isinstance(part, TextPart): - total_chars += len(part.text) - # ~4 chars per token for English; somewhat underestimates for CJK text, - # but this is a temporary estimate that gets corrected on the next LLM call. - return total_chars // 4 - - -def should_auto_compact( - token_count: int, - max_context_size: int, - *, - trigger_ratio: float, - reserved_context_size: int, -) -> bool: - """Determine whether auto-compaction should be triggered. - - Returns True when either condition is met (whichever fires first): - - Ratio-based: token_count >= max_context_size * trigger_ratio - - Reserved-based: token_count + reserved_context_size >= max_context_size - """ - return ( - token_count >= max_context_size * trigger_ratio - or token_count + reserved_context_size >= max_context_size - ) - - -@runtime_checkable -class Compaction(Protocol): - async def compact( - self, messages: Sequence[Message], llm: LLM, *, custom_instruction: str = "" - ) -> CompactionResult: - """ - Compact a sequence of messages into a new sequence of messages. - - Args: - messages (Sequence[Message]): The messages to compact. - llm (LLM): The LLM to use for compaction. - custom_instruction: Optional user instruction to guide compaction focus. - - Returns: - CompactionResult: The compacted messages and token usage from the compaction LLM call. - - Raises: - ChatProviderError: When the chat provider returns an error. - """ - ... - - -if TYPE_CHECKING: - - def type_check(simple: SimpleCompaction): - _: Compaction = simple - - -class SimpleCompaction: - def __init__(self, max_preserved_messages: int = 2) -> None: - self.max_preserved_messages = max_preserved_messages - - async def compact( - self, messages: Sequence[Message], llm: LLM, *, custom_instruction: str = "" - ) -> CompactionResult: - compact_message, to_preserve = self.prepare(messages, custom_instruction=custom_instruction) - if compact_message is None: - return CompactionResult(messages=to_preserve, usage=None) - - # Call kosong.step to get the compacted context - # TODO: set max completion tokens - logger.debug("Compacting context...") - result = await kosong.step( - chat_provider=llm.chat_provider, - system_prompt="You are a helpful assistant that compacts conversation context.", - toolset=EmptyToolset(), - history=[compact_message], - ) - if result.usage: - logger.debug( - "Compaction used {input} input tokens and {output} output tokens", - input=result.usage.input, - output=result.usage.output, - ) - - content: list[ContentPart] = [ - system("Previous context has been compacted. Here is the compaction output:") - ] - compacted_msg = result.message - - # drop thinking parts if any - content.extend(part for part in compacted_msg.content if not isinstance(part, ThinkPart)) - compacted_messages: list[Message] = [Message(role="user", content=content)] - compacted_messages.extend(to_preserve) - return CompactionResult(messages=compacted_messages, usage=result.usage) - - class PrepareResult(NamedTuple): - compact_message: Message | None - to_preserve: Sequence[Message] - - def prepare( - self, messages: Sequence[Message], *, custom_instruction: str = "" - ) -> PrepareResult: - if not messages or self.max_preserved_messages <= 0: - return self.PrepareResult(compact_message=None, to_preserve=messages) - - history = list(messages) - preserve_start_index = len(history) - n_preserved = 0 - for index in range(len(history) - 1, -1, -1): - if history[index].role in {"user", "assistant"}: - n_preserved += 1 - if n_preserved == self.max_preserved_messages: - preserve_start_index = index - break - - if n_preserved < self.max_preserved_messages: - return self.PrepareResult(compact_message=None, to_preserve=messages) - - to_compact = history[:preserve_start_index] - to_preserve = history[preserve_start_index:] - - if not to_compact: - # Let's hope this won't exceed the context size limit - return self.PrepareResult(compact_message=None, to_preserve=to_preserve) - - # Create input message for compaction - compact_message = Message(role="user", content=[]) - for i, msg in enumerate(to_compact): - compact_message.content.append( - TextPart(text=f"## Message {i + 1}\nRole: {msg.role}\nContent:\n") - ) - compact_message.content.extend( - part for part in msg.content if isinstance(part, TextPart) - ) - prompt_text = "\n" + prompts.COMPACT - if custom_instruction: - prompt_text += ( - "\n\n**User's Custom Compaction Instruction:**\n" - "The user has specifically requested the following focus during compaction. " - "You MUST prioritize this instruction above the default compression priorities:\n" - f"{custom_instruction}" - ) - compact_message.content.append(TextPart(text=prompt_text)) - return self.PrepareResult(compact_message=compact_message, to_preserve=to_preserve) diff --git a/src/kimi_cli/soul/compaction.ts b/src/kimi_cli/soul/compaction.ts new file mode 100644 index 000000000..3dcc1be1a --- /dev/null +++ b/src/kimi_cli/soul/compaction.ts @@ -0,0 +1,224 @@ +/** + * Compaction — corresponds to Python soul/compaction.py + * Summarizes conversation history when context window is getting full. + * Supports preserved messages: the last N user/assistant turns are kept verbatim. + */ + +import type { LLM } from "../llm.ts"; +import type { Message } from "../types.ts"; +import type { Context } from "./context.ts"; +import { logger } from "../utils/logging.ts"; + +/** Default number of recent user/assistant turns to preserve during compaction. */ +const DEFAULT_MAX_PRESERVED_MESSAGES = 2; + +/** + * Estimate tokens from message text content using a character-based heuristic. + * ~4 chars per token for English; somewhat underestimates for CJK text. + */ +export function estimateTextTokens(messages: readonly Message[]): number { + let totalChars = 0; + for (const msg of messages) { + if (typeof msg.content === "string") { + totalChars += msg.content.length; + } else { + for (const part of msg.content) { + if (part.type === "text") { + totalChars += part.text.length; + } + } + } + } + return Math.floor(totalChars / 4); +} + +/** + * Prepare messages for compaction by splitting into to-compact and to-preserve. + * Preserves the last `maxPreservedMessages` user/assistant turns verbatim. + */ +export function prepareCompaction( + messages: readonly Message[], + maxPreservedMessages = DEFAULT_MAX_PRESERVED_MESSAGES, +): { toCompact: Message[]; toPreserve: Message[] } { + if (!messages.length || maxPreservedMessages <= 0) { + return { toCompact: [], toPreserve: [...messages] }; + } + + const history = [...messages]; + let preserveStartIndex = history.length; + let nPreserved = 0; + + for (let index = history.length - 1; index >= 0; index--) { + if (history[index]!.role === "user" || history[index]!.role === "assistant") { + nPreserved++; + if (nPreserved === maxPreservedMessages) { + preserveStartIndex = index; + break; + } + } + } + + if (nPreserved < maxPreservedMessages) { + return { toCompact: [], toPreserve: [...messages] }; + } + + const toCompact = history.slice(0, preserveStartIndex); + const toPreserve = history.slice(preserveStartIndex); + + if (toCompact.length === 0) { + return { toCompact: [], toPreserve }; + } + + return { toCompact, toPreserve }; +} + +/** + * Simple compaction strategy: ask the LLM to summarize the conversation. + * Preserves recent messages verbatim. + */ +export async function compactContext( + context: Context, + llm: LLM, + opts?: { + focus?: string; + maxPreservedMessages?: number; + onBegin?: () => void; + onEnd?: () => void; + }, +): Promise { + const history = context.history; + if (history.length === 0) return; + + opts?.onBegin?.(); + + try { + const maxPreserved = opts?.maxPreservedMessages ?? DEFAULT_MAX_PRESERVED_MESSAGES; + const { toCompact, toPreserve } = prepareCompaction(history, maxPreserved); + + // Nothing to compact — preserve all + if (toCompact.length === 0) { + return; + } + + // Build summary request from to-compact messages + const summaryPrompt = buildSummaryPrompt(toCompact, opts?.focus); + + // Ask LLM to summarize + let summary = ""; + try { + const stream = llm.chat( + [{ role: "user", content: summaryPrompt }], + { + system: + "You are a helpful assistant that compacts conversation context.", + maxTokens: 4096, + }, + ); + + for await (const chunk of stream) { + if (chunk.type === "text") { + summary += chunk.text; + } + } + } catch (err) { + logger.warn(`Compaction LLM call failed, using fallback: ${err}`); + summary = buildFallbackSummary(toCompact); + } + + // Clear context and inject summary + preserved messages + await context.compact(); + + if (summary) { + await context.appendMessage({ + role: "user", + content: `Previous context has been compacted. Here is the compaction output:\n${summary}`, + }); + } + + // Re-append preserved messages + for (const msg of toPreserve) { + await context.appendMessage(msg); + } + } finally { + opts?.onEnd?.(); + } +} + +function buildSummaryPrompt( + messages: readonly Message[], + focus?: string, +): string { + const parts: string[] = []; + + // Build structured input matching Python's format + for (let i = 0; i < messages.length; i++) { + const msg = messages[i]!; + parts.push(`## Message ${i + 1}\nRole: ${msg.role}\nContent:`); + if (typeof msg.content === "string") { + parts.push(msg.content); + } else { + for (const part of msg.content) { + if (part.type === "text") { + parts.push(part.text); + } + } + } + } + + let promptText = "\n" + COMPACT_PROMPT; + if (focus) { + promptText += + "\n\n**User's Custom Compaction Instruction:**\n" + + "The user has specifically requested the following focus during compaction. " + + "You MUST prioritize this instruction above the default compression priorities:\n" + + focus; + } + parts.push(promptText); + + return parts.join("\n"); +} + +const COMPACT_PROMPT = `Summarize the conversation above concisely. Preserve: +- Key decisions and outcomes +- Important file paths and code changes +- Tool call results that are still relevant +- Any pending tasks or goals +Be thorough but concise.`; + +function buildFallbackSummary(history: readonly Message[]): string { + // Simple fallback: keep last few messages as summary + const last = history.slice(-6); + const parts = ["[Fallback summary - LLM compaction failed]"]; + + for (const msg of last) { + const content = + typeof msg.content === "string" + ? msg.content.slice(0, 500) + : msg.content + .map((p) => ("text" in p ? p.text : `[${p.type}]`)) + .join("\n") + .slice(0, 500); + parts.push(`[${msg.role}]: ${content}`); + } + + return parts.join("\n"); +} + +/** + * Determine whether auto-compaction should be triggered. + * + * Returns true when either condition is met (whichever fires first): + * - Ratio-based: tokenCount >= maxContextSize * triggerRatio + * - Reserved-based: tokenCount + reservedContextSize >= maxContextSize + */ +export function shouldCompact( + tokenCount: number, + maxContextSize: number, + reservedContextSize: number, + triggerRatio: number, +): boolean { + return ( + tokenCount >= maxContextSize * triggerRatio || + tokenCount + reservedContextSize >= maxContextSize + ); +} diff --git a/src/kimi_cli/soul/context.py b/src/kimi_cli/soul/context.py deleted file mode 100644 index e5af074c9..000000000 --- a/src/kimi_cli/soul/context.py +++ /dev/null @@ -1,239 +0,0 @@ -from __future__ import annotations - -import asyncio -import json -from collections.abc import Sequence -from pathlib import Path - -import aiofiles -import aiofiles.os -from kosong.message import Message - -from kimi_cli.soul.compaction import estimate_text_tokens -from kimi_cli.soul.message import system -from kimi_cli.utils.logging import logger -from kimi_cli.utils.path import next_available_rotation - - -class Context: - def __init__(self, file_backend: Path): - self._file_backend = file_backend - self._history: list[Message] = [] - self._token_count: int = 0 - self._pending_token_estimate: int = 0 - self._next_checkpoint_id: int = 0 - """The ID of the next checkpoint, starting from 0, incremented after each checkpoint.""" - self._system_prompt: str | None = None - - async def restore(self) -> bool: - logger.debug("Restoring context from file: {file_backend}", file_backend=self._file_backend) - if self._history: - logger.error("The context storage is already modified") - raise RuntimeError("The context storage is already modified") - if not self._file_backend.exists(): - logger.debug("No context file found, skipping restoration") - return False - if self._file_backend.stat().st_size == 0: - logger.debug("Empty context file, skipping restoration") - return False - - messages_after_last_usage: list[Message] = [] - async with aiofiles.open(self._file_backend, encoding="utf-8") as f: - async for line in f: - if not line.strip(): - continue - line_json = json.loads(line, strict=False) - if line_json["role"] == "_system_prompt": - self._system_prompt = line_json["content"] - continue - if line_json["role"] == "_usage": - self._token_count = line_json["token_count"] - messages_after_last_usage.clear() - continue - if line_json["role"] == "_checkpoint": - self._next_checkpoint_id = line_json["id"] + 1 - continue - message = Message.model_validate(line_json) - self._history.append(message) - messages_after_last_usage.append(message) - - self._pending_token_estimate = estimate_text_tokens(messages_after_last_usage) - return True - - @property - def history(self) -> Sequence[Message]: - return self._history - - @property - def token_count(self) -> int: - return self._token_count - - @property - def token_count_with_pending(self) -> int: - return self._token_count + self._pending_token_estimate - - @property - def n_checkpoints(self) -> int: - return self._next_checkpoint_id - - @property - def system_prompt(self) -> str | None: - return self._system_prompt - - @property - def file_backend(self) -> Path: - return self._file_backend - - async def write_system_prompt(self, prompt: str) -> None: - """Write the system prompt as the first record of the context file. - - If the file is empty, writes it directly. If the file already has content - (e.g. a legacy session without system prompt), prepends it atomically via a - temporary file to avoid corruption on crash and avoid loading the entire file - into memory. - """ - prompt_line = json.dumps({"role": "_system_prompt", "content": prompt}) + "\n" - - def _write_system_prompt_sync() -> None: - if not self._file_backend.exists() or self._file_backend.stat().st_size == 0: - self._file_backend.write_text(prompt_line, encoding="utf-8") - return - - tmp_path = self._file_backend.with_suffix(".tmp") - with ( - tmp_path.open("w", encoding="utf-8") as tmp_f, - self._file_backend.open(encoding="utf-8") as src_f, - ): - tmp_f.write(prompt_line) - while True: - chunk = src_f.read(64 * 1024) - if not chunk: - break - tmp_f.write(chunk) - tmp_path.replace(self._file_backend) - - await asyncio.to_thread(_write_system_prompt_sync) - - self._system_prompt = prompt - - async def checkpoint(self, add_user_message: bool): - checkpoint_id = self._next_checkpoint_id - self._next_checkpoint_id += 1 - logger.debug("Checkpointing, ID: {id}", id=checkpoint_id) - - async with aiofiles.open(self._file_backend, "a", encoding="utf-8") as f: - await f.write(json.dumps({"role": "_checkpoint", "id": checkpoint_id}) + "\n") - if add_user_message: - await self.append_message( - Message(role="user", content=[system(f"CHECKPOINT {checkpoint_id}")]) - ) - - async def revert_to(self, checkpoint_id: int): - """ - Revert the context to the specified checkpoint. - After this, the specified checkpoint and all subsequent content will be - removed from the context. File backend will be rotated. - - Args: - checkpoint_id (int): The ID of the checkpoint to revert to. 0 is the first checkpoint. - - Raises: - ValueError: When the checkpoint does not exist. - RuntimeError: When no available rotation path is found. - """ - - logger.debug("Reverting checkpoint, ID: {id}", id=checkpoint_id) - if checkpoint_id >= self._next_checkpoint_id: - logger.error("Checkpoint {checkpoint_id} does not exist", checkpoint_id=checkpoint_id) - raise ValueError(f"Checkpoint {checkpoint_id} does not exist") - - # rotate the context file - rotated_file_path = await next_available_rotation(self._file_backend) - if rotated_file_path is None: - logger.error("No available rotation path found") - raise RuntimeError("No available rotation path found") - await aiofiles.os.replace(self._file_backend, rotated_file_path) - logger.debug( - "Rotated context file: {rotated_file_path}", rotated_file_path=rotated_file_path - ) - - # restore the context until the specified checkpoint - self._history.clear() - self._token_count = 0 - self._next_checkpoint_id = 0 - self._system_prompt = None - messages_after_last_usage: list[Message] = [] - async with ( - aiofiles.open(rotated_file_path, encoding="utf-8") as old_file, - aiofiles.open(self._file_backend, "w", encoding="utf-8") as new_file, - ): - async for line in old_file: - if not line.strip(): - continue - - line_json = json.loads(line, strict=False) - if line_json["role"] == "_checkpoint" and line_json["id"] == checkpoint_id: - break - - await new_file.write(line) - if line_json["role"] == "_system_prompt": - self._system_prompt = line_json["content"] - elif line_json["role"] == "_usage": - self._token_count = line_json["token_count"] - messages_after_last_usage.clear() - elif line_json["role"] == "_checkpoint": - self._next_checkpoint_id = line_json["id"] + 1 - else: - message = Message.model_validate(line_json) - self._history.append(message) - messages_after_last_usage.append(message) - - self._pending_token_estimate = estimate_text_tokens(messages_after_last_usage) - - async def clear(self): - """ - Clear the context history. - This is almost equivalent to revert_to(0), but without relying on the assumption - that the first checkpoint exists. - File backend will be rotated. - - Raises: - RuntimeError: When no available rotation path is found. - """ - - logger.debug("Clearing context") - - # rotate the context file - rotated_file_path = await next_available_rotation(self._file_backend) - if rotated_file_path is None: - logger.error("No available rotation path found") - raise RuntimeError("No available rotation path found") - await aiofiles.os.replace(self._file_backend, rotated_file_path) - self._file_backend.touch() - logger.debug( - "Rotated context file: {rotated_file_path}", rotated_file_path=rotated_file_path - ) - - self._history.clear() - self._token_count = 0 - self._pending_token_estimate = 0 - self._next_checkpoint_id = 0 - self._system_prompt = None - - async def append_message(self, message: Message | Sequence[Message]): - logger.debug("Appending message(s) to context: {message}", message=message) - messages = [message] if isinstance(message, Message) else message - self._history.extend(messages) - self._pending_token_estimate += estimate_text_tokens(messages) - - async with aiofiles.open(self._file_backend, "a", encoding="utf-8") as f: - for message in messages: - await f.write(message.model_dump_json(exclude_none=True) + "\n") - - async def update_token_count(self, token_count: int): - logger.debug("Updating token count in context: {token_count}", token_count=token_count) - self._token_count = token_count - self._pending_token_estimate = 0 - - async with aiofiles.open(self._file_backend, "a", encoding="utf-8") as f: - await f.write(json.dumps({"role": "_usage", "token_count": token_count}) + "\n") diff --git a/src/kimi_cli/soul/context.ts b/src/kimi_cli/soul/context.ts new file mode 100644 index 000000000..688f37991 --- /dev/null +++ b/src/kimi_cli/soul/context.ts @@ -0,0 +1,288 @@ +/** + * Context window management — corresponds to Python soul/context.py + * Manages conversation message history with token tracking and persistence. + */ + +import type { Message, TokenUsage } from "../types.ts"; +import { estimateMessagesTokenCount } from "../llm.ts"; +import { logger } from "../utils/logging.ts"; + +// ── Special JSONL record markers ──────────────────────── + +interface SystemPromptRecord { + _system_prompt: string; +} + +interface UsageRecord { + _usage: { input_tokens: number; output_tokens: number }; +} + +interface CheckpointRecord { + _checkpoint: { id: number; reminder?: string }; +} + +type ContextRecord = Message | SystemPromptRecord | UsageRecord | CheckpointRecord; + +// ── Context class ─────────────────────────────────────── + +export class Context { + private _history: Message[] = []; + private _tokenCount = 0; + private _pendingTokenEstimate = 0; + private _nextCheckpointId = 0; + private _systemPrompt: string | null = null; + private _filePath: string; + + constructor(filePath: string) { + this._filePath = filePath; + } + + // ── Properties ─────────────────────────────────── + + get history(): readonly Message[] { + return this._history; + } + + get tokenCount(): number { + return this._tokenCount; + } + + get tokenCountWithPending(): number { + return this._tokenCount + this._pendingTokenEstimate; + } + + get systemPrompt(): string | null { + return this._systemPrompt; + } + + get nCheckpoints(): number { + return this._nextCheckpointId; + } + + get filePath(): string { + return this._filePath; + } + + // ── Restore from file ──────────────────────────── + + async restore(): Promise { + const file = Bun.file(this._filePath); + if (!(await file.exists())) return; + + const text = await file.text(); + const lines = text.split("\n"); + let lastUsageLineIdx = -1; + + this._history = []; + this._systemPrompt = null; + this._tokenCount = 0; + this._nextCheckpointId = 0; + + for (let i = 0; i < lines.length; i++) { + const line = lines[i]!.trim(); + if (!line) continue; + + try { + const record: ContextRecord = JSON.parse(line); + + if ("_system_prompt" in record) { + this._systemPrompt = record._system_prompt; + } else if ("_usage" in record) { + // Only input tokens count toward context window (matches Python behavior) + this._tokenCount = record._usage.input_tokens; + lastUsageLineIdx = i; + } else if ("_checkpoint" in record) { + this._nextCheckpointId = record._checkpoint.id + 1; + if (record._checkpoint.reminder) { + // Checkpoint with system reminder → inject as user message + this._history.push({ + role: "user", + content: `\n${record._checkpoint.reminder}\n`, + }); + } + } else if ("role" in record) { + this._history.push(record as Message); + } + } catch { + logger.warn(`Skipping corrupt context line ${i}: ${line.slice(0, 80)}`); + } + } + + // Estimate tokens for messages after last usage record + if (lastUsageLineIdx >= 0) { + const postUsageMessages: Message[] = []; + let postUsageCount = 0; + for (let i = lastUsageLineIdx + 1; i < lines.length; i++) { + const line = lines[i]!.trim(); + if (!line) continue; + try { + const record = JSON.parse(line); + if ("role" in record) { + postUsageMessages.push(record); + postUsageCount++; + } + } catch { + // skip + } + } + if (postUsageCount > 0) { + this._pendingTokenEstimate = + estimateMessagesTokenCount(postUsageMessages); + } + } else { + // No usage record at all → estimate everything + this._pendingTokenEstimate = estimateMessagesTokenCount(this._history); + } + } + + // ── Append message ────────────────────────────── + + async appendMessage(message: Message): Promise { + this._history.push(message); + const estimate = estimateMessagesTokenCount([message]); + this._pendingTokenEstimate += estimate; + await this._appendToFile(message); + } + + // ── Write system prompt ───────────────────────── + + async writeSystemPrompt(systemPrompt: string): Promise { + this._systemPrompt = systemPrompt; + const record: SystemPromptRecord = { _system_prompt: systemPrompt }; + // Prepend to file (rewrite) + const file = Bun.file(this._filePath); + let existing = ""; + if (await file.exists()) { + existing = await file.text(); + } + const line = JSON.stringify(record) + "\n"; + await Bun.write(this._filePath, line + existing); + } + + // ── Update token count ────────────────────────── + + async updateTokenCount(usage: TokenUsage): Promise { + // Only input tokens count toward context window size (output doesn't consume context) + this._tokenCount = usage.inputTokens + (usage.cacheReadTokens ?? 0); + this._pendingTokenEstimate = 0; + const record: UsageRecord = { + _usage: { + input_tokens: usage.inputTokens, + output_tokens: usage.outputTokens, + }, + }; + await this._appendToFile(record); + } + + // ── Checkpoint ────────────────────────────────── + + async checkpoint(reminder?: string): Promise { + const id = this._nextCheckpointId++; + const record: CheckpointRecord = { + _checkpoint: { id, ...(reminder ? { reminder } : {}) }, + }; + if (reminder) { + this._history.push({ + role: "user", + content: `\n${reminder}\n`, + }); + } + await this._appendToFile(record); + return id; + } + + // ── Clear context ────────────────────────────── + + async clear(): Promise { + // Clear all state, keep system prompt + this._history = []; + this._tokenCount = 0; + this._pendingTokenEstimate = 0; + this._nextCheckpointId = 0; + + // Write empty file (with system prompt if present) + if (this._systemPrompt) { + const record: SystemPromptRecord = { + _system_prompt: this._systemPrompt, + }; + await Bun.write(this._filePath, JSON.stringify(record) + "\n"); + } else { + await Bun.write(this._filePath, ""); + } + } + + // ── Compact (clear and rotate) ───────────────── + + async compact(): Promise { + // Rotate old file + const backupPath = this._filePath + ".bak"; + const file = Bun.file(this._filePath); + if (await file.exists()) { + const content = await file.text(); + await Bun.write(backupPath, content); + } + + // Clear state + this._history = []; + this._tokenCount = 0; + this._pendingTokenEstimate = 0; + this._nextCheckpointId = 0; + + // Write empty file (with system prompt if present) + if (this._systemPrompt) { + const record: SystemPromptRecord = { + _system_prompt: this._systemPrompt, + }; + await Bun.write(this._filePath, JSON.stringify(record) + "\n"); + } else { + await Bun.write(this._filePath, ""); + } + } + + // ── Revert to checkpoint ─────────────────────── + + async revertTo(checkpointId: number): Promise { + const file = Bun.file(this._filePath); + if (!(await file.exists())) return; + + const text = await file.text(); + const lines = text.split("\n"); + const kept: string[] = []; + let found = false; + + for (const line of lines) { + const trimmed = line.trim(); + if (!trimmed) continue; + kept.push(trimmed); + try { + const record = JSON.parse(trimmed); + if ("_checkpoint" in record && record._checkpoint.id === checkpointId) { + found = true; + break; + } + } catch { + // keep the line + } + } + + if (!found) { + logger.warn(`Checkpoint ${checkpointId} not found, no revert`); + return; + } + + // Backup and rewrite + await Bun.write(this._filePath + ".bak", text); + await Bun.write(this._filePath, kept.join("\n") + "\n"); + + // Reload + await this.restore(); + } + + // ── Private helpers ──────────────────────────── + + private async _appendToFile(record: ContextRecord): Promise { + const line = JSON.stringify(record) + "\n"; + const { appendFile } = await import("node:fs/promises"); + await appendFile(this._filePath, line, "utf-8"); + } +} diff --git a/src/kimi_cli/soul/denwarenji.py b/src/kimi_cli/soul/denwarenji.py deleted file mode 100644 index aa1ba8ef3..000000000 --- a/src/kimi_cli/soul/denwarenji.py +++ /dev/null @@ -1,39 +0,0 @@ -from __future__ import annotations - -from pydantic import BaseModel, Field - - -class DMail(BaseModel): - message: str = Field(description="The message to send.") - checkpoint_id: int = Field(description="The checkpoint to send the message back to.", ge=0) - # TODO: allow restoring filesystem state to the checkpoint - - -class DenwaRenjiError(Exception): - pass - - -class DenwaRenji: - def __init__(self): - self._pending_dmail: DMail | None = None - self._n_checkpoints: int = 0 - - def send_dmail(self, dmail: DMail): - """Send a D-Mail. Intended to be called by the SendDMail tool.""" - if self._pending_dmail is not None: - raise DenwaRenjiError("Only one D-Mail can be sent at a time") - if dmail.checkpoint_id < 0: - raise DenwaRenjiError("The checkpoint ID can not be negative") - if dmail.checkpoint_id >= self._n_checkpoints: - raise DenwaRenjiError("There is no checkpoint with the given ID") - self._pending_dmail = dmail - - def set_n_checkpoints(self, n_checkpoints: int): - """Set the number of checkpoints. Intended to be called by the soul.""" - self._n_checkpoints = n_checkpoints - - def fetch_pending_dmail(self) -> DMail | None: - """Fetch a pending D-Mail. Intended to be called by the soul.""" - pending_dmail = self._pending_dmail - self._pending_dmail = None - return pending_dmail diff --git a/src/kimi_cli/soul/dynamic_injection.py b/src/kimi_cli/soul/dynamic_injection.py deleted file mode 100644 index 9eabf320e..000000000 --- a/src/kimi_cli/soul/dynamic_injection.py +++ /dev/null @@ -1,66 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from collections.abc import Sequence -from dataclasses import dataclass -from typing import TYPE_CHECKING - -from kosong.message import Message - -from kimi_cli.notifications import is_notification_message - -if TYPE_CHECKING: - from kimi_cli.soul.kimisoul import KimiSoul - - -@dataclass(frozen=True, slots=True) -class DynamicInjection: - """A dynamic prompt content to be injected before an LLM step.""" - - type: str # identifier, e.g. "plan_mode" - content: str # text content (will be wrapped in tags) - - -class DynamicInjectionProvider(ABC): - """Base class for dynamic injection providers. - - Called before each LLM step. Implementations handle their own throttling. - Providers can access all runtime state via the ``soul`` parameter - (context_usage, runtime, config, etc.). - """ - - @abstractmethod - async def get_injections( - self, - history: Sequence[Message], - soul: KimiSoul, - ) -> list[DynamicInjection]: ... - - -def normalize_history(history: Sequence[Message]) -> list[Message]: - """Merge adjacent user messages to produce a clean API input sequence. - - Dynamic injections are stored as standalone user messages in history; - normalization merges them into the adjacent user message. - - Only ``user`` role messages are merged. Assistant and tool messages - are never merged because their ``tool_calls`` / ``tool_call_id`` - fields form linked pairs that must stay intact. - """ - if not history: - return [] - - result: list[Message] = [] - for msg in history: - if ( - result - and result[-1].role == msg.role - and msg.role == "user" - and not is_notification_message(result[-1]) - and not is_notification_message(msg) - ): - merged_content = list(result[-1].content) + list(msg.content) - result[-1] = Message(role="user", content=merged_content) - else: - result.append(msg) - return result diff --git a/src/kimi_cli/soul/dynamic_injection.ts b/src/kimi_cli/soul/dynamic_injection.ts new file mode 100644 index 000000000..8bf9b465f --- /dev/null +++ b/src/kimi_cli/soul/dynamic_injection.ts @@ -0,0 +1,102 @@ +/** + * Dynamic injection system — corresponds to Python soul/dynamic_injection.py + * Provides an extensible provider pattern for injecting dynamic prompts before LLM steps. + */ + +import type { Message, ContentPart } from "../types.ts"; +import type { KimiSoul } from "./kimisoul.ts"; + +// ── DynamicInjection ───────────────────────────────── + +export interface DynamicInjection { + /** Identifier, e.g. "plan_mode", "yolo_mode" */ + readonly type: string; + /** Text content (will be wrapped in tags) */ + readonly content: string; +} + +// ── DynamicInjectionProvider ───────────────────────── + +/** + * Base interface for dynamic injection providers. + * + * Called before each LLM step. Implementations handle their own throttling. + * Providers can access all runtime state via the `soul` parameter. + */ +export interface DynamicInjectionProvider { + getInjections( + history: readonly Message[], + soul: KimiSoul, + ): Promise; +} + +// ── normalizeHistory ───────────────────────────────── + +/** + * Merge adjacent user messages to produce a clean API input sequence. + * + * Dynamic injections are stored as standalone user messages in history; + * normalization merges them into the adjacent user message. + * + * Only `user` role messages are merged. Assistant and tool messages + * are never merged because their tool_calls / tool_call_id fields + * form linked pairs that must stay intact. + */ +export function normalizeHistory(messages: readonly Message[]): Message[] { + if (messages.length === 0) return []; + + const result: Message[] = []; + for (const msg of messages) { + const prev = result[result.length - 1]; + if ( + prev && + prev.role === "user" && + msg.role === "user" && + !isNotificationMessage(prev) && + !isNotificationMessage(msg) + ) { + // Merge content + const prevParts = toContentArray(prev.content); + const curParts = toContentArray(msg.content); + result[result.length - 1] = { + role: "user", + content: [...prevParts, ...curParts], + }; + } else { + result.push(msg); + } + } + return result; +} + +// ── Helpers ────────────────────────────────────────── + +function toContentArray( + content: string | readonly ContentPart[], +): ContentPart[] { + if (typeof content === "string") { + return [{ type: "text" as const, text: content }]; + } + return [...content]; +} + +/** + * Minimal check: notification messages are user messages whose text + * starts with a notification tag. This keeps us decoupled from the + * full notifications module. + */ +function isNotificationMessage(msg: Message): boolean { + if (msg.role !== "user") return false; + const text = extractText(msg.content); + return text.includes("") || text.includes(" p.type === "text") + .map((p) => p.text) + .join(""); +} diff --git a/src/kimi_cli/soul/dynamic_injections/__init__.py b/src/kimi_cli/soul/dynamic_injections/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/kimi_cli/soul/dynamic_injections/plan_mode.py b/src/kimi_cli/soul/dynamic_injections/plan_mode.py deleted file mode 100644 index 1861473d4..000000000 --- a/src/kimi_cli/soul/dynamic_injections/plan_mode.py +++ /dev/null @@ -1,238 +0,0 @@ -from __future__ import annotations - -from collections.abc import Sequence -from typing import TYPE_CHECKING - -from kosong.message import Message, TextPart - -from kimi_cli.soul.dynamic_injection import DynamicInjection, DynamicInjectionProvider - -if TYPE_CHECKING: - from kimi_cli.soul.kimisoul import KimiSoul - -# Inject a reminder every N assistant turns. -_TURN_INTERVAL = 5 -# Every N-th reminder is the full version; others are sparse. -_FULL_EVERY_N = 5 - - -class PlanModeInjectionProvider(DynamicInjectionProvider): - """Periodically injects read-only reminders while plan mode is active. - - Throttling is inferred from history: scan backwards to the last - plan mode reminder and count assistant messages in between. - Only inject when the count exceeds ``_TURN_INTERVAL``. - """ - - def __init__(self) -> None: - self._inject_count: int = 0 - - async def get_injections( - self, - history: Sequence[Message], - soul: KimiSoul, - ) -> list[DynamicInjection]: - if not soul.plan_mode: - self._inject_count = 0 - return [] - - plan_path = soul.get_plan_file_path() - plan_path_str = str(plan_path) if plan_path else None - plan_exists = plan_path is not None and plan_path.exists() - - # Manual toggles schedule a one-shot activation reminder for the next LLM step. - if soul.consume_pending_plan_activation_injection(): - self._inject_count = 1 - # When re-entering with an existing plan, use the reentry reminder. - if plan_exists: - return [ - DynamicInjection( - type="plan_mode_reentry", - content=_reentry_reminder(plan_path_str), - ) - ] - return [ - DynamicInjection( - type="plan_mode", - content=_full_reminder(plan_path_str, plan_exists), - ) - ] - - # Scan history backwards to find the last plan mode reminder. - turns_since_last = 0 - found_previous = False - for msg in reversed(history): - if msg.role == "user" and _has_plan_reminder(msg): - found_previous = True - break - if msg.role == "assistant": - turns_since_last += 1 - - # First time (no reminder in history yet) -> inject full version. - if not found_previous: - self._inject_count = 1 - return [ - DynamicInjection( - type="plan_mode", - content=_full_reminder(plan_path_str, plan_exists), - ) - ] - - # Not enough turns since last reminder -> skip. - if turns_since_last < _TURN_INTERVAL: - return [] - - # Inject. - self._inject_count += 1 - is_full = self._inject_count % _FULL_EVERY_N == 1 - if is_full: - content = _full_reminder(plan_path_str, plan_exists) - else: - content = _sparse_reminder(plan_path_str) - return [DynamicInjection(type="plan_mode", content=content)] - - -def _has_plan_reminder(msg: Message) -> bool: - """Check whether a message contains a plan mode reminder. - - Detects by matching against stable prefixes of the actual reminder texts - so changes to the reminder wording stay automatically in sync. - """ - keys = ( - _sparse_reminder().split(".")[0], # "Plan mode still active ..." - _full_reminder().split("\n")[0], # "Plan mode is active. ..." - ) - for part in msg.content: - if isinstance(part, TextPart) and any(key in part.text for key in keys): - return True - return False - - -def _full_reminder( - plan_file_path: str | None = None, - plan_exists: bool = False, -) -> str: - lines = [ - "Plan mode is active. You MUST NOT make any edits " - "(with the exception of the plan file below), run non-readonly tools, " - "or otherwise make changes to the system. " - "This supersedes any other instructions you have received.", - ] - # Plan file info block - if plan_file_path: - lines.append("") - if plan_exists: - lines.append( - f"Plan file: {plan_file_path} " - "(exists — read first, then update it with WriteFile or StrReplaceFile)" - ) - else: - lines.append( - f"Plan file: {plan_file_path} " - "(create it with WriteFile; once it exists, you can modify it with " - "WriteFile or StrReplaceFile)" - ) - lines.append("This is the only file you are allowed to edit.") - # Workflow - lines.extend( - [ - "", - "Workflow:", - "1. Understand — explore the codebase with Glob, Grep, ReadFile", - "2. Design — converge on the best approach; " - "consider trade-offs but aim for a single recommendation", - "3. Review — re-read key files to verify understanding", - "4. Write Plan — modify the plan file with WriteFile or StrReplaceFile. " - "Use WriteFile if the plan file does not exist yet", - "5. Exit — call ExitPlanMode for user approval", - ] - ) - # Multi-approach handling - lines.extend( - [ - "", - "## Handling multiple approaches", - "Keep it focused: at most 2-3 meaningfully different approaches. " - "Do NOT pad with minor variations — if one approach is clearly " - "superior, just propose that one.", - "When the best approach depends on user preferences, constraints, " - "or context you don't have, use AskUserQuestion to clarify first. " - "This helps you write a better, more targeted plan rather than " - "dumping multiple options for the user to sort through.", - "When you do include multiple approaches in the plan, you MUST pass them " - "as the `options` parameter when calling ExitPlanMode, so the user can select which " - "approach to execute at approval time.", - "NEVER write multiple approaches in the plan and call ExitPlanMode without the " - "`options` parameter — the user will only see Approve/Reject with no way to choose.", - ] - ) - # Turn ending constraint + anti-pattern - lines.extend( - [ - "", - "AskUserQuestion is for clarifying missing requirements or user preferences " - "that affect the plan.", - "Never ask about plan approval via text or AskUserQuestion.", - "Your turn must end with either AskUserQuestion " - "(to clarify requirements or preferences) " - "or ExitPlanMode (to request plan approval). " - "Do NOT end your turn any other way.", - "Do NOT use AskUserQuestion to ask about plan approval or reference " - '"the plan" — the user cannot see the plan until you call ExitPlanMode.', - ] - ) - return "\n".join(lines) - - -def _sparse_reminder(plan_file_path: str | None = None) -> str: - parts = [ - "Plan mode still active (see full instructions earlier).", - ] - if plan_file_path: - parts.append(f"Read-only except plan file ({plan_file_path}).") - else: - parts.append("Read-only.") - parts.extend( - [ - "Use WriteFile or StrReplaceFile to modify the plan file. " - "If it does not exist yet, create it with WriteFile first.", - "Use AskUserQuestion to clarify user preferences " - "when it helps you write a better plan.", - "If the plan has multiple approaches, " - "pass options to ExitPlanMode so the user can choose.", - "End turns with AskUserQuestion (for clarifications) or ExitPlanMode (for approval).", - "Never ask about plan approval via text or AskUserQuestion.", - ] - ) - return " ".join(parts) - - -def _reentry_reminder(plan_file_path: str | None = None) -> str: - """One-shot reminder when re-entering plan mode with an existing plan.""" - lines = [ - "Plan mode is active. You MUST NOT make any edits " - "(with the exception of the plan file below), run non-readonly tools, " - "or otherwise make changes to the system. " - "This supersedes any other instructions you have received.", - "", - "## Re-entering Plan Mode", - ( - f"A plan file exists at {plan_file_path} from a previous planning session." - if plan_file_path - else "A plan file from a previous planning session already exists." - ), - "Before proceeding:", - "1. Read the existing plan file to understand what was previously planned", - "2. Evaluate the user's current request against that plan", - "3. If different task: replace the old plan with a fresh one. " - "If same task: update the existing plan.", - "4. You may use WriteFile or StrReplaceFile to modify the plan file. " - "If the file does not exist yet, create it with WriteFile first.", - "5. Use AskUserQuestion to clarify missing requirements " - "or user preferences that affect the plan.", - "6. Always edit the plan file before calling ExitPlanMode.", - "", - "Your turn must end with either AskUserQuestion (to clarify requirements) " - "or ExitPlanMode (to request plan approval).", - ] - return "\n".join(lines) diff --git a/src/kimi_cli/soul/dynamic_injections/plan_mode.ts b/src/kimi_cli/soul/dynamic_injections/plan_mode.ts new file mode 100644 index 000000000..9059e6149 --- /dev/null +++ b/src/kimi_cli/soul/dynamic_injections/plan_mode.ts @@ -0,0 +1,236 @@ +/** + * Plan mode dynamic injection — corresponds to Python soul/dynamic_injections/plan_mode.py + * Periodically injects read-only reminders while plan mode is active. + */ + +import type { Message } from "../../types.ts"; +import type { KimiSoul } from "../kimisoul.ts"; +import type { DynamicInjection, DynamicInjectionProvider } from "../dynamic_injection.ts"; + +/** Inject a reminder every N assistant turns. */ +const TURN_INTERVAL = 5; +/** Every N-th reminder is the full version; others are sparse. */ +const FULL_EVERY_N = 5; + +export class PlanModeInjectionProvider implements DynamicInjectionProvider { + private _injectCount = 0; + + async getInjections( + history: readonly Message[], + soul: KimiSoul, + ): Promise { + if (!soul.planMode) { + this._injectCount = 0; + return []; + } + + const planPath = soul.getPlanFilePath(); + const planPathStr = planPath ?? null; + const planExists = planPath != null && (await fileExists(planPath)); + + // Manual toggles schedule a one-shot activation reminder for the next LLM step. + if (soul.consumePendingPlanActivationInjection()) { + this._injectCount = 1; + if (planExists) { + return [ + { type: "plan_mode_reentry", content: reentryReminder(planPathStr) }, + ]; + } + return [ + { type: "plan_mode", content: fullReminder(planPathStr, planExists) }, + ]; + } + + // Scan history backwards to find the last plan mode reminder. + let turnsSinceLast = 0; + let foundPrevious = false; + for (let i = history.length - 1; i >= 0; i--) { + const msg = history[i]!; + if (msg.role === "user" && hasPlanReminder(msg)) { + foundPrevious = true; + break; + } + if (msg.role === "assistant") { + turnsSinceLast++; + } + } + + // First time (no reminder in history yet) -> inject full version. + if (!foundPrevious) { + this._injectCount = 1; + return [ + { type: "plan_mode", content: fullReminder(planPathStr, planExists) }, + ]; + } + + // Not enough turns since last reminder -> skip. + if (turnsSinceLast < TURN_INTERVAL) { + return []; + } + + // Inject. + this._injectCount++; + const isFull = this._injectCount % FULL_EVERY_N === 1; + const content = isFull + ? fullReminder(planPathStr, planExists) + : sparseReminder(planPathStr); + return [{ type: "plan_mode", content }]; + } +} + +// ── Reminder text builders ─────────────────────────── + +function hasPlanReminder(msg: Message): boolean { + const sparseKey = sparseReminder().split(".")[0]!; + const fullKey = fullReminder().split("\n")[0]!; + + const text = extractText(msg.content); + return text.includes(sparseKey) || text.includes(fullKey); +} + +export function fullReminder( + planFilePath: string | null = null, + planExists = false, +): string { + const lines: string[] = [ + "Plan mode is active. You MUST NOT make any edits " + + "(with the exception of the plan file below), run non-readonly tools, " + + "or otherwise make changes to the system. " + + "This supersedes any other instructions you have received.", + ]; + + if (planFilePath) { + lines.push(""); + if (planExists) { + lines.push( + `Plan file: ${planFilePath} ` + + "(exists — read first, then update it with WriteFile or StrReplaceFile)", + ); + } else { + lines.push( + `Plan file: ${planFilePath} ` + + "(create it with WriteFile; once it exists, you can modify it with " + + "WriteFile or StrReplaceFile)", + ); + } + lines.push("This is the only file you are allowed to edit."); + } + + lines.push( + "", + "Workflow:", + "1. Understand — explore the codebase with Glob, Grep, ReadFile", + "2. Design — converge on the best approach; " + + "consider trade-offs but aim for a single recommendation", + "3. Review — re-read key files to verify understanding", + "4. Write Plan — modify the plan file with WriteFile or StrReplaceFile. " + + "Use WriteFile if the plan file does not exist yet", + "5. Exit — call ExitPlanMode for user approval", + ); + + lines.push( + "", + "## Handling multiple approaches", + "Keep it focused: at most 2-3 meaningfully different approaches. " + + "Do NOT pad with minor variations — if one approach is clearly " + + "superior, just propose that one.", + "When the best approach depends on user preferences, constraints, " + + "or context you don't have, use AskUserQuestion to clarify first. " + + "This helps you write a better, more targeted plan rather than " + + "dumping multiple options for the user to sort through.", + "When you do include multiple approaches in the plan, you MUST pass them " + + "as the `options` parameter when calling ExitPlanMode, so the user can select which " + + "approach to execute at approval time.", + "NEVER write multiple approaches in the plan and call ExitPlanMode without the " + + "`options` parameter — the user will only see Approve/Reject with no way to choose.", + ); + + lines.push( + "", + "AskUserQuestion is for clarifying missing requirements or user preferences " + + "that affect the plan.", + "Never ask about plan approval via text or AskUserQuestion.", + "Your turn must end with either AskUserQuestion " + + "(to clarify requirements or preferences) " + + "or ExitPlanMode (to request plan approval). " + + "Do NOT end your turn any other way.", + 'Do NOT use AskUserQuestion to ask about plan approval or reference ' + + '"the plan" — the user cannot see the plan until you call ExitPlanMode.', + ); + + return lines.join("\n"); +} + +export function sparseReminder(planFilePath: string | null = null): string { + const parts: string[] = [ + "Plan mode still active (see full instructions earlier).", + ]; + + if (planFilePath) { + parts.push(`Read-only except plan file (${planFilePath}).`); + } else { + parts.push("Read-only."); + } + + parts.push( + "Use WriteFile or StrReplaceFile to modify the plan file. " + + "If it does not exist yet, create it with WriteFile first.", + "Use AskUserQuestion to clarify user preferences " + + "when it helps you write a better plan.", + "If the plan has multiple approaches, " + + "pass options to ExitPlanMode so the user can choose.", + "End turns with AskUserQuestion (for clarifications) or ExitPlanMode (for approval).", + "Never ask about plan approval via text or AskUserQuestion.", + ); + + return parts.join(" "); +} + +function reentryReminder(planFilePath: string | null = null): string { + const lines: string[] = [ + "Plan mode is active. You MUST NOT make any edits " + + "(with the exception of the plan file below), run non-readonly tools, " + + "or otherwise make changes to the system. " + + "This supersedes any other instructions you have received.", + "", + "## Re-entering Plan Mode", + planFilePath + ? `A plan file exists at ${planFilePath} from a previous planning session.` + : "A plan file from a previous planning session already exists.", + "Before proceeding:", + "1. Read the existing plan file to understand what was previously planned", + "2. Evaluate the user's current request against that plan", + "3. If different task: replace the old plan with a fresh one. " + + "If same task: update the existing plan.", + "4. You may use WriteFile or StrReplaceFile to modify the plan file. " + + "If the file does not exist yet, create it with WriteFile first.", + "5. Use AskUserQuestion to clarify missing requirements " + + "or user preferences that affect the plan.", + "6. Always edit the plan file before calling ExitPlanMode.", + "", + "Your turn must end with either AskUserQuestion (to clarify requirements) " + + "or ExitPlanMode (to request plan approval).", + ]; + return lines.join("\n"); +} + +// ── Helpers ────────────────────────────────────────── + +function extractText( + content: string | readonly { type: string; [key: string]: unknown }[], +): string { + if (typeof content === "string") return content; + return content + .filter((p): p is { type: "text"; text: string } => p.type === "text") + .map((p) => p.text) + .join(""); +} + +async function fileExists(path: string): Promise { + try { + const file = Bun.file(path); + return await file.exists(); + } catch { + return false; + } +} diff --git a/src/kimi_cli/soul/dynamic_injections/yolo_mode.py b/src/kimi_cli/soul/dynamic_injections/yolo_mode.py deleted file mode 100644 index a6c8430df..000000000 --- a/src/kimi_cli/soul/dynamic_injections/yolo_mode.py +++ /dev/null @@ -1,41 +0,0 @@ -from __future__ import annotations - -from collections.abc import Sequence -from typing import TYPE_CHECKING - -from kosong.message import Message - -from kimi_cli.soul.dynamic_injection import DynamicInjection, DynamicInjectionProvider - -if TYPE_CHECKING: - from kimi_cli.soul.kimisoul import KimiSoul - -_YOLO_INJECTION_TYPE = "yolo_mode" - -_YOLO_PROMPT = ( - "You are running in non-interactive mode. The user cannot answer questions " - "or provide feedback during execution.\n" - "- Do NOT call AskUserQuestion. If you need to make a decision, make your " - "best judgment and proceed.\n" - "- For EnterPlanMode / ExitPlanMode, they will be auto-approved. You can use " - "them normally but expect no user feedback." -) - - -class YoloModeInjectionProvider(DynamicInjectionProvider): - """Injects a one-time reminder when yolo mode is active.""" - - def __init__(self) -> None: - self._injected: bool = False - - async def get_injections( - self, - history: Sequence[Message], - soul: KimiSoul, - ) -> list[DynamicInjection]: - if not soul.is_yolo: - return [] - if self._injected: - return [] - self._injected = True - return [DynamicInjection(type=_YOLO_INJECTION_TYPE, content=_YOLO_PROMPT)] diff --git a/src/kimi_cli/soul/dynamic_injections/yolo_mode.ts b/src/kimi_cli/soul/dynamic_injections/yolo_mode.ts new file mode 100644 index 000000000..b23156a7c --- /dev/null +++ b/src/kimi_cli/soul/dynamic_injections/yolo_mode.ts @@ -0,0 +1,36 @@ +/** + * YOLO mode dynamic injection — corresponds to Python soul/dynamic_injections/yolo_mode.py + * Injects a one-time reminder when yolo mode is active. + */ + +import type { Message } from "../../types.ts"; +import type { KimiSoul } from "../kimisoul.ts"; +import type { DynamicInjection, DynamicInjectionProvider } from "../dynamic_injection.ts"; + +const YOLO_INJECTION_TYPE = "yolo_mode"; + +const YOLO_PROMPT = + "You are running in non-interactive mode. The user cannot answer questions " + + "or provide feedback during execution.\n" + + "- Do NOT call AskUserQuestion. If you need to make a decision, make your " + + "best judgment and proceed.\n" + + "- For EnterPlanMode / ExitPlanMode, they will be auto-approved. You can use " + + "them normally but expect no user feedback."; + +export class YoloModeInjectionProvider implements DynamicInjectionProvider { + private _injected = false; + + async getInjections( + _history: readonly Message[], + soul: KimiSoul, + ): Promise { + if (!soul.isYolo) { + return []; + } + if (this._injected) { + return []; + } + this._injected = true; + return [{ type: YOLO_INJECTION_TYPE, content: YOLO_PROMPT }]; + } +} diff --git a/src/kimi_cli/soul/kimisoul.py b/src/kimi_cli/soul/kimisoul.py deleted file mode 100644 index 7e79b5d9e..000000000 --- a/src/kimi_cli/soul/kimisoul.py +++ /dev/null @@ -1,1244 +0,0 @@ -from __future__ import annotations - -import asyncio -import uuid -from collections.abc import Awaitable, Callable, Sequence -from dataclasses import dataclass -from functools import partial -from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal, cast - -import kosong -import tenacity -from kosong import StepResult -from kosong.chat_provider import ( - APIConnectionError, - APIEmptyResponseError, - APIStatusError, - APITimeoutError, - RetryableChatProvider, -) -from kosong.message import Message -from tenacity import RetryCallState, retry_if_exception, stop_after_attempt, wait_exponential_jitter - -from kimi_cli.approval_runtime import ( - ApprovalSource, - get_current_approval_source_or_none, - reset_current_approval_source, - set_current_approval_source, -) -from kimi_cli.background import build_active_task_snapshot -from kimi_cli.hooks.engine import HookEngine -from kimi_cli.llm import ModelCapability -from kimi_cli.notifications import ( - NotificationView, - build_notification_message, - extract_notification_ids, -) -from kimi_cli.skill import Skill, read_skill_text -from kimi_cli.skill.flow import Flow, FlowEdge, FlowNode, parse_choice -from kimi_cli.soul import ( - LLMNotSet, - LLMNotSupported, - MaxStepsReached, - Soul, - StatusSnapshot, - wire_send, -) -from kimi_cli.soul.agent import Agent, Runtime -from kimi_cli.soul.compaction import ( - CompactionResult, - SimpleCompaction, - estimate_text_tokens, - should_auto_compact, -) -from kimi_cli.soul.context import Context -from kimi_cli.soul.dynamic_injection import ( - DynamicInjection, - DynamicInjectionProvider, - normalize_history, -) -from kimi_cli.soul.dynamic_injections.plan_mode import PlanModeInjectionProvider -from kimi_cli.soul.dynamic_injections.yolo_mode import YoloModeInjectionProvider -from kimi_cli.soul.message import check_message, system, system_reminder, tool_result_to_message -from kimi_cli.soul.slash import registry as soul_slash_registry -from kimi_cli.soul.toolset import KimiToolset -from kimi_cli.tools.dmail import NAME as SendDMail_NAME -from kimi_cli.tools.utils import ToolRejectedError -from kimi_cli.utils.logging import logger -from kimi_cli.utils.slashcmd import SlashCommand, parse_slash_command_call -from kimi_cli.wire.file import WireFile -from kimi_cli.wire.types import ( - CompactionBegin, - CompactionEnd, - ContentPart, - MCPLoadingBegin, - MCPLoadingEnd, - StatusUpdate, - SteerInput, - StepBegin, - StepInterrupted, - TextPart, - ToolResult, - TurnBegin, - TurnEnd, -) - -if TYPE_CHECKING: - - def type_check(soul: KimiSoul): - _: Soul = soul - - -SKILL_COMMAND_PREFIX = "skill:" -FLOW_COMMAND_PREFIX = "flow:" -DEFAULT_MAX_FLOW_MOVES = 1000 - - -type StepStopReason = Literal["no_tool_calls", "tool_rejected"] - - -@dataclass(frozen=True, slots=True) -class StepOutcome: - stop_reason: StepStopReason - assistant_message: Message - - -type TurnStopReason = StepStopReason - - -@dataclass(frozen=True, slots=True) -class TurnOutcome: - stop_reason: TurnStopReason - final_message: Message | None - step_count: int - - -class KimiSoul: - """The soul of Kimi Code CLI.""" - - def __init__( - self, - agent: Agent, - *, - context: Context, - ): - """ - Initialize the soul. - - Args: - agent (Agent): The agent to run. - context (Context): The context of the agent. - """ - self._agent = agent - self._runtime = agent.runtime - self._denwa_renji = agent.runtime.denwa_renji - self._approval = agent.runtime.approval - self._context = context - self._loop_control = agent.runtime.config.loop_control - self._compaction = SimpleCompaction() # TODO: maybe configurable and composable - - for tool in agent.toolset.tools: - if tool.name == SendDMail_NAME: - self._checkpoint_with_user_message = True - break - else: - self._checkpoint_with_user_message = False - - self._steer_queue: asyncio.Queue[str | list[ContentPart]] = asyncio.Queue() - self._plan_mode: bool = self._runtime.session.state.plan_mode - self._plan_session_id: str | None = self._runtime.session.state.plan_session_id - # Pre-warm slug cache so the persisted slug survives process restarts - if self._plan_session_id is not None and self._runtime.session.state.plan_slug is not None: - from kimi_cli.tools.plan.heroes import seed_slug_cache - - seed_slug_cache(self._plan_session_id, self._runtime.session.state.plan_slug) - self._pending_plan_activation_injection: bool = False - if self._plan_mode: - self._ensure_plan_session_id() - self._injection_providers: list[DynamicInjectionProvider] = [ - PlanModeInjectionProvider(), - YoloModeInjectionProvider(), - ] - self._hook_engine: HookEngine = HookEngine() - self._stop_hook_active: bool = False - if self._runtime.role == "root": - self._runtime.notifications.ack_ids("llm", extract_notification_ids(context.history)) - - # Bind plan mode state to tools that support it - self._bind_plan_mode_tools() - - self._slash_commands = self._build_slash_commands() - self._slash_command_map = self._index_slash_commands(self._slash_commands) - - @property - def name(self) -> str: - return self._agent.name - - @property - def model_name(self) -> str: - return self._runtime.llm.chat_provider.model_name if self._runtime.llm else "" - - @property - def model_capabilities(self) -> set[ModelCapability] | None: - if self._runtime.llm is None: - return None - return self._runtime.llm.capabilities - - @property - def is_yolo(self) -> bool: - """Whether yolo (auto-approve / non-interactive) mode is enabled.""" - return self._approval.is_yolo() - - @property - def plan_mode(self) -> bool: - """Whether plan mode (read-only research and planning) is active.""" - return self._plan_mode - - @property - def hook_engine(self) -> HookEngine: - return self._hook_engine - - def set_hook_engine(self, engine: HookEngine) -> None: - self._hook_engine = engine - if isinstance(self._agent.toolset, KimiToolset): - self._agent.toolset.set_hook_engine(engine) - - def add_injection_provider(self, provider: DynamicInjectionProvider) -> None: - """Register an additional dynamic injection provider.""" - self._injection_providers.append(provider) - - async def _collect_injections(self) -> list[DynamicInjection]: - """Collect dynamic injections from all registered providers.""" - injections: list[DynamicInjection] = [] - for provider in self._injection_providers: - try: - result = await provider.get_injections(self._context.history, self) - injections.extend(result) - except Exception: - logger.warning( - "injection provider %s failed", - type(provider).__name__, - exc_info=True, - ) - return injections - - def _bind_plan_mode_tools(self) -> None: - """Bind plan mode state to tools that support it.""" - if not isinstance(self._agent.toolset, KimiToolset): - return - - def checker() -> bool: - return self._plan_mode - - def path_getter() -> Path | None: - return self.get_plan_file_path() - - # WriteFile gets both checker and path_getter (for plan file auto-approve) - from kimi_cli.tools.file.write import WriteFile - - write_tool = self._agent.toolset.find("WriteFile") - if isinstance(write_tool, WriteFile): - write_tool.bind_plan_mode(checker, path_getter) - - from kimi_cli.tools.file.replace import StrReplaceFile - - replace_tool = self._agent.toolset.find("StrReplaceFile") - if isinstance(replace_tool, StrReplaceFile): - replace_tool.bind_plan_mode(checker, path_getter) - - # ExitPlanMode has a special bind() method - from kimi_cli.tools.plan import ExitPlanMode - - exit_tool = self._agent.toolset.find("ExitPlanMode") - if isinstance(exit_tool, ExitPlanMode): - exit_tool.bind(self.toggle_plan_mode, path_getter, checker, self._approval.is_yolo) - - # EnterPlanMode has a special bind() method - from kimi_cli.tools.plan.enter import EnterPlanMode - - enter_tool = self._agent.toolset.find("EnterPlanMode") - if isinstance(enter_tool, EnterPlanMode): - enter_tool.bind(self.toggle_plan_mode, path_getter, checker, self._approval.is_yolo) - - # AskUserQuestion — bind yolo checker for auto-dismiss - from kimi_cli.tools.ask_user import AskUserQuestion - - ask_tool = self._agent.toolset.find("AskUserQuestion") - if isinstance(ask_tool, AskUserQuestion): - ask_tool.bind_approval(self._approval.is_yolo) - - def _ensure_plan_session_id(self) -> None: - """Allocate a stable plan session ID on first activation.""" - if self._plan_session_id is None: - import uuid - - self._plan_session_id = uuid.uuid4().hex - self._runtime.session.state.plan_session_id = self._plan_session_id - # Compute and persist slug immediately so the path survives process restarts - from kimi_cli.tools.plan.heroes import get_or_create_slug - - slug = get_or_create_slug(self._plan_session_id) - self._runtime.session.state.plan_slug = slug - self._runtime.session.save_state() - - def _set_plan_mode(self, enabled: bool, *, source: Literal["manual", "tool"]) -> bool: - """Update plan mode state for either manual or tool-driven toggles.""" - if enabled == self._plan_mode: - return self._plan_mode - self._plan_mode = enabled - if enabled: - self._ensure_plan_session_id() - self._pending_plan_activation_injection = source == "manual" - else: - self._pending_plan_activation_injection = False - self._plan_session_id = None - self._runtime.session.state.plan_session_id = None - self._runtime.session.state.plan_slug = None - # Persist plan mode to session state so it survives process restarts - self._runtime.session.state.plan_mode = self._plan_mode - self._runtime.session.save_state() - return self._plan_mode - - def get_plan_file_path(self) -> Path | None: - """Get the plan file path for the current session.""" - if self._plan_session_id is None: - return None - from kimi_cli.tools.plan.heroes import get_plan_file_path - - return get_plan_file_path(self._plan_session_id) - - def read_current_plan(self) -> str | None: - """Read the current plan file content.""" - if self._plan_session_id is None: - return None - from kimi_cli.tools.plan.heroes import read_plan_file - - return read_plan_file(self._plan_session_id) - - def clear_current_plan(self) -> None: - """Delete the current plan file.""" - path = self.get_plan_file_path() - if path and path.exists(): - path.unlink() - - async def toggle_plan_mode(self) -> bool: - """Toggle plan mode on/off. Returns the new state. - - Tools are not hidden/unhidden — instead, each tool checks plan mode - state at call time and rejects if blocked. - Periodic reminders are handled by the dynamic injection system. - """ - return self._set_plan_mode(not self._plan_mode, source="tool") - - async def toggle_plan_mode_from_manual(self) -> bool: - """Toggle plan mode from UI/manual entry points (slash command, keybinding).""" - return self._set_plan_mode(not self._plan_mode, source="manual") - - async def set_plan_mode_from_manual(self, enabled: bool) -> bool: - """Set plan mode to a specific state from UI/manual entry points. - - Unlike toggle, this accepts the desired state directly, avoiding - race conditions when the caller already knows the target value. - """ - return self._set_plan_mode(enabled, source="manual") - - def consume_pending_plan_activation_injection(self) -> bool: - """Consume the next-step activation reminder scheduled by a manual toggle.""" - if not self._plan_mode or not self._pending_plan_activation_injection: - return False - self._pending_plan_activation_injection = False - return True - - @property - def thinking(self) -> bool | None: - """Whether thinking mode is enabled.""" - if self._runtime.llm is None: - return None - if thinking_effort := self._runtime.llm.chat_provider.thinking_effort: - return thinking_effort != "off" - return None - - @property - def status(self) -> StatusSnapshot: - token_count = self._context.token_count - max_size = self._runtime.llm.max_context_size if self._runtime.llm is not None else 0 - return StatusSnapshot( - context_usage=self._context_usage, - yolo_enabled=self._approval.is_yolo(), - plan_mode=self._plan_mode, - context_tokens=token_count, - max_context_tokens=max_size, - mcp_status=self._mcp_status_snapshot(), - ) - - @property - def agent(self) -> Agent: - return self._agent - - @property - def runtime(self) -> Runtime: - return self._runtime - - @property - def context(self) -> Context: - return self._context - - @property - def _context_usage(self) -> float: - if self._runtime.llm is not None: - return self._context.token_count / self._runtime.llm.max_context_size - return 0.0 - - @property - def wire_file(self) -> WireFile: - return self._runtime.session.wire_file - - def _mcp_status_snapshot(self): - if not isinstance(self._agent.toolset, KimiToolset): - return None - return self._agent.toolset.mcp_status_snapshot() - - async def start_background_mcp_loading(self) -> bool: - """Start deferred MCP loading, if any, without exposing toolset internals.""" - if not isinstance(self._agent.toolset, KimiToolset): - return False - return await self._agent.toolset.start_deferred_mcp_tool_loading() - - async def wait_for_background_mcp_loading(self) -> None: - """Wait for any in-flight MCP startup to finish.""" - if not isinstance(self._agent.toolset, KimiToolset): - return - await self._agent.toolset.wait_for_mcp_tools() - - async def _checkpoint(self): - await self._context.checkpoint(self._checkpoint_with_user_message) - - def steer(self, content: str | list[ContentPart]) -> None: - """Queue a steer message for injection into the current turn.""" - self._steer_queue.put_nowait(content) - - async def _consume_pending_steers(self) -> bool: - """Drain the steer queue and inject as follow-up user messages. - - Returns True if any steers were consumed. - """ - consumed = False - while not self._steer_queue.empty(): - content = self._steer_queue.get_nowait() - await self._inject_steer(content) - wire_send(SteerInput(user_input=content)) - consumed = True - return consumed - - async def _inject_steer(self, content: str | list[ContentPart]) -> None: - """Inject a single steer as a regular follow-up user message.""" - parts = cast( - list[ContentPart], - [TextPart(text=content)] if isinstance(content, str) else list(content), - ) - message = Message(role="user", content=parts) - if self._runtime.llm is None: - raise LLMNotSet() - if missing_caps := check_message(message, self._runtime.llm.capabilities): - raise LLMNotSupported(self._runtime.llm, list(missing_caps)) - await self._context.append_message(message) - - @property - def available_slash_commands(self) -> list[SlashCommand[Any]]: - return self._slash_commands - - async def run(self, user_input: str | list[ContentPart]): - approval_source_token = None - if get_current_approval_source_or_none() is None: - approval_source_token = set_current_approval_source( - ApprovalSource(kind="foreground_turn", id=uuid.uuid4().hex) - ) - try: - # Refresh OAuth tokens on each turn to avoid idle-time expirations. - await self._runtime.oauth.ensure_fresh(self._runtime) - - # Set session_id ContextVar for toolset hooks - from kimi_cli.soul.toolset import set_session_id - - set_session_id(self._runtime.session.id) - - # --- UserPromptSubmit hook --- - text_input_for_hook = user_input if isinstance(user_input, str) else "" - from kimi_cli.hooks import events - - hook_results = await self._hook_engine.trigger( - "UserPromptSubmit", - matcher_value=text_input_for_hook, - input_data=events.user_prompt_submit( - session_id=self._runtime.session.id, - cwd=str(Path.cwd()), - prompt=text_input_for_hook, - ), - ) - for result in hook_results: - if result.action == "block": - wire_send(TurnBegin(user_input=user_input)) - wire_send(TextPart(text=result.reason or "Prompt blocked by hook.")) - wire_send(TurnEnd()) - return - - wire_send(TurnBegin(user_input=user_input)) - user_message = Message(role="user", content=user_input) - text_input = user_message.extract_text(" ").strip() - - if command_call := parse_slash_command_call(text_input): - command = self._find_slash_command(command_call.name) - if command is None: - # this should not happen actually, the shell should have filtered it out - wire_send(TextPart(text=f'Unknown slash command "/{command_call.name}".')) - else: - ret = command.func(self, command_call.args) - if isinstance(ret, Awaitable): - await ret - elif self._loop_control.max_ralph_iterations != 0: - runner = FlowRunner.ralph_loop( - user_message, - self._loop_control.max_ralph_iterations, - ) - await runner.run(self, "") - else: - await self._turn(user_message) - - # --- Stop hook (max 1 re-trigger to prevent infinite loop) --- - if not self._stop_hook_active: - stop_results = await self._hook_engine.trigger( - "Stop", - input_data=events.stop( - session_id=self._runtime.session.id, - cwd=str(Path.cwd()), - stop_hook_active=False, - ), - ) - for result in stop_results: - if result.action == "block" and result.reason: - self._stop_hook_active = True - try: - await self._turn(Message(role="user", content=result.reason)) - finally: - self._stop_hook_active = False - break - - wire_send(TurnEnd()) - - # Auto-set title after first real turn (skip slash commands) - if not command_call: - session = self._runtime.session - if session.state.custom_title is None: - from textwrap import shorten - - title = shorten( - Message(role="user", content=user_input).extract_text(" "), - width=50, - ) - if title: - from kimi_cli.session_state import ( - load_session_state, - save_session_state, - ) - - # Read-modify-write: load fresh state to avoid - # overwriting concurrent web changes - fresh = load_session_state(session.dir) - if fresh.custom_title is None: - fresh.custom_title = title - save_session_state(fresh, session.dir) - session.state.custom_title = fresh.custom_title - finally: - if approval_source_token is not None: - reset_current_approval_source(approval_source_token) - - async def _turn(self, user_message: Message) -> TurnOutcome: - if self._runtime.llm is None: - raise LLMNotSet() - - if missing_caps := check_message(user_message, self._runtime.llm.capabilities): - raise LLMNotSupported(self._runtime.llm, list(missing_caps)) - - await self._checkpoint() # this creates the checkpoint 0 on first run - await self._context.append_message(user_message) - logger.debug("Appended user message to context") - return await self._agent_loop() - - def _build_slash_commands(self) -> list[SlashCommand[Any]]: - commands: list[SlashCommand[Any]] = list(soul_slash_registry.list_commands()) - seen_names = {cmd.name for cmd in commands} - - for skill in self._runtime.skills.values(): - if skill.type not in ("standard", "flow"): - continue - name = f"{SKILL_COMMAND_PREFIX}{skill.name}" - if name in seen_names: - logger.warning( - "Skipping skill slash command /{name}: name already registered", - name=name, - ) - continue - commands.append( - SlashCommand( - name=name, - func=self._make_skill_runner(skill), - description=skill.description or "", - aliases=[], - ) - ) - seen_names.add(name) - - for skill in self._runtime.skills.values(): - if skill.type != "flow": - continue - if skill.flow is None: - logger.warning("Flow skill {name} has no flow; skipping", name=skill.name) - continue - command_name = f"{FLOW_COMMAND_PREFIX}{skill.name}" - if command_name in seen_names: - logger.warning( - "Skipping prompt flow slash command /{name}: name already registered", - name=command_name, - ) - continue - runner = FlowRunner(skill.flow, name=skill.name) - commands.append( - SlashCommand( - name=command_name, - func=runner.run, - description=skill.description or "", - aliases=[], - ) - ) - seen_names.add(command_name) - - return commands - - @staticmethod - def _index_slash_commands( - commands: list[SlashCommand[Any]], - ) -> dict[str, SlashCommand[Any]]: - indexed: dict[str, SlashCommand[Any]] = {} - for command in commands: - indexed[command.name] = command - for alias in command.aliases: - indexed[alias] = command - return indexed - - def _find_slash_command(self, name: str) -> SlashCommand[Any] | None: - return self._slash_command_map.get(name) - - def _make_skill_runner(self, skill: Skill) -> Callable[[KimiSoul, str], None | Awaitable[None]]: - async def _run_skill(soul: KimiSoul, args: str, *, _skill: Skill = skill) -> None: - skill_text = await read_skill_text(_skill) - if skill_text is None: - wire_send( - TextPart(text=f'Failed to load skill "/{SKILL_COMMAND_PREFIX}{_skill.name}".') - ) - return - extra = args.strip() - if extra: - skill_text = f"{skill_text}\n\nUser request:\n{extra}" - await soul._turn(Message(role="user", content=skill_text)) - - _run_skill.__doc__ = skill.description - return _run_skill - - async def _agent_loop(self) -> TurnOutcome: - """The main agent loop for one run.""" - assert self._runtime.llm is not None - - # Discard any stale steers from a previous turn. - while not self._steer_queue.empty(): - self._steer_queue.get_nowait() - - if isinstance(self._agent.toolset, KimiToolset): - await self.start_background_mcp_loading() - loading = bool((snapshot := self._mcp_status_snapshot()) and snapshot.loading) - if loading: - wire_send(StatusUpdate(mcp_status=snapshot)) - wire_send(MCPLoadingBegin()) - try: - await self.wait_for_background_mcp_loading() - finally: - if loading: - wire_send(StatusUpdate(mcp_status=self._mcp_status_snapshot())) - wire_send(MCPLoadingEnd()) - - step_no = 0 - while True: - step_no += 1 - if step_no > self._loop_control.max_steps_per_turn: - raise MaxStepsReached(self._loop_control.max_steps_per_turn) - - wire_send(StepBegin(n=step_no)) - back_to_the_future: BackToTheFuture | None = None - step_outcome: StepOutcome | None = None - try: - # compact the context if needed - if should_auto_compact( - self._context.token_count_with_pending, - self._runtime.llm.max_context_size, - trigger_ratio=self._loop_control.compaction_trigger_ratio, - reserved_context_size=self._loop_control.reserved_context_size, - ): - logger.info("Context too long, compacting...") - await self.compact_context() - - logger.debug("Beginning step {step_no}", step_no=step_no) - await self._checkpoint() - self._denwa_renji.set_n_checkpoints(self._context.n_checkpoints) - step_outcome = await self._step() - except BackToTheFuture as e: - back_to_the_future = e - except Exception as e: - # any other exception should interrupt the step - wire_send(StepInterrupted()) - # --- StopFailure hook --- - from kimi_cli.hooks import events as _hook_events - - _hook_task = asyncio.create_task( - self._hook_engine.trigger( - "StopFailure", - matcher_value=type(e).__name__, - input_data=_hook_events.stop_failure( - session_id=self._runtime.session.id, - cwd=str(Path.cwd()), - error_type=type(e).__name__, - error_message=str(e), - ), - ) - ) - _hook_task.add_done_callback(lambda t: t.exception() if not t.cancelled() else None) - # break the agent loop - raise - - if step_outcome is not None: - has_steers = await self._consume_pending_steers() - if has_steers: - continue # steers injected, force another LLM step - final_message = ( - step_outcome.assistant_message - if step_outcome.stop_reason == "no_tool_calls" - else None - ) - return TurnOutcome( - stop_reason=step_outcome.stop_reason, - final_message=final_message, - step_count=step_no, - ) - - if back_to_the_future is not None: - await self._context.revert_to(back_to_the_future.checkpoint_id) - await self._checkpoint() - await self._context.append_message(back_to_the_future.messages) - - # Consume any pending steers between steps - await self._consume_pending_steers() - - async def _step(self) -> StepOutcome | None: - """Run a single step and return a stop outcome, or None to continue.""" - # already checked in `run` - assert self._runtime.llm is not None - chat_provider = self._runtime.llm.chat_provider - - if self._runtime.role == "root": - - async def _append_notification(view: NotificationView) -> None: - await self._context.append_message(build_notification_message(view, self._runtime)) - # --- Notification hook --- - from kimi_cli.hooks import events - - _hook_task = asyncio.create_task( - self._hook_engine.trigger( - "Notification", - matcher_value=view.event.type, - input_data=events.notification( - session_id=self._runtime.session.id, - cwd=str(Path.cwd()), - sink="llm", - notification_type=view.event.type, - title=view.event.title, - body=view.event.body, - severity=view.event.severity, - ), - ) - ) - _hook_task.add_done_callback(lambda t: t.exception() if not t.cancelled() else None) - - await self._runtime.notifications.deliver_pending( - "llm", - limit=4, - before_claim=self._runtime.background_tasks.reconcile, - on_notification=_append_notification, - ) - - # Dynamic injection - injections = await self._collect_injections() - if injections: - combined_reminders = "\n".join(system_reminder(inj.content).text for inj in injections) - await self._context.append_message( - Message( - role="user", - content=[TextPart(text=combined_reminders)], - ) - ) - - # Normalize: merge adjacent user messages for clean API input - effective_history = normalize_history(self._context.history) - - async def _run_step_once() -> StepResult: - # run an LLM step (may be interrupted) - return await kosong.step( - chat_provider, - self._agent.system_prompt, - self._agent.toolset, - effective_history, - on_message_part=wire_send, - on_tool_result=wire_send, - ) - - @tenacity.retry( - retry=retry_if_exception(self._is_retryable_error), - before_sleep=partial(self._retry_log, "step"), - wait=wait_exponential_jitter(initial=0.3, max=5, jitter=0.5), - stop=stop_after_attempt(self._loop_control.max_retries_per_step), - reraise=True, - ) - async def _kosong_step_with_retry() -> StepResult: - return await self._run_with_connection_recovery( - "step", - _run_step_once, - chat_provider=chat_provider, - ) - - result = await _kosong_step_with_retry() - logger.debug("Got step result: {result}", result=result) - status_update = StatusUpdate( - token_usage=result.usage, message_id=result.id, plan_mode=self._plan_mode - ) - if result.usage is not None: - # mark the token count for the context before the step - await self._context.update_token_count(result.usage.input) - snap = self.status - status_update.context_usage = snap.context_usage - status_update.context_tokens = snap.context_tokens - status_update.max_context_tokens = snap.max_context_tokens - wire_send(status_update) - - # wait for all tool results (may be interrupted) - plan_mode_before_tools = self._plan_mode - results = await result.tool_results() - logger.debug("Got tool results: {results}", results=results) - - # If a tool (EnterPlanMode/ExitPlanMode) changed plan mode during execution, - # send a corrected StatusUpdate so the client sees the up-to-date state. - if self._plan_mode != plan_mode_before_tools: - wire_send(StatusUpdate(plan_mode=self._plan_mode)) - - # shield the context manipulation from interruption - await asyncio.shield(self._grow_context(result, results)) - - rejected_errors = [ - result.return_value - for result in results - if isinstance(result.return_value, ToolRejectedError) - ] - if ( - rejected_errors - and not any(e.has_feedback for e in rejected_errors) - and self._runtime.role != "subagent" - ): - # Pure rejection (no user feedback) — stop the turn. - # Subagents skip this so the LLM can see the rejection and try - # an alternative approach instead of terminating immediately. - _ = self._denwa_renji.fetch_pending_dmail() - return StepOutcome(stop_reason="tool_rejected", assistant_message=result.message) - - # handle pending D-Mail - if dmail := self._denwa_renji.fetch_pending_dmail(): - assert dmail.checkpoint_id >= 0, "DenwaRenji guarantees checkpoint_id >= 0" - assert dmail.checkpoint_id < self._context.n_checkpoints, ( - "DenwaRenji guarantees checkpoint_id < n_checkpoints" - ) - # raise to let the main loop take us back to the future - raise BackToTheFuture( - dmail.checkpoint_id, - [ - Message( - role="user", - content=[ - system( - "You just got a D-Mail from your future self. " - "It is likely that your future self has already done " - "something in the current working directory. Please read " - "the D-Mail and decide what to do next. You MUST NEVER " - "mention to the user about this information. " - f"D-Mail content:\n\n{dmail.message.strip()}" - ) - ], - ) - ], - ) - - if result.tool_calls: - return None - return StepOutcome(stop_reason="no_tool_calls", assistant_message=result.message) - - async def _grow_context(self, result: StepResult, tool_results: list[ToolResult]): - logger.debug("Growing context with result: {result}", result=result) - - assert self._runtime.llm is not None - tool_messages = [tool_result_to_message(tr) for tr in tool_results] - for tm in tool_messages: - if missing_caps := check_message(tm, self._runtime.llm.capabilities): - logger.warning( - "Tool result message requires unsupported capabilities: {caps}", - caps=missing_caps, - ) - raise LLMNotSupported(self._runtime.llm, list(missing_caps)) - - await self._context.append_message(result.message) - if result.usage is not None: - await self._context.update_token_count(result.usage.total) - - logger.debug( - "Appending tool messages to context: {tool_messages}", tool_messages=tool_messages - ) - await self._context.append_message(tool_messages) - # token count of tool results are not available yet - - async def compact_context(self, custom_instruction: str = "") -> None: - """ - Compact the context. - - Raises: - LLMNotSet: When the LLM is not set. - ChatProviderError: When the chat provider returns an error. - """ - - chat_provider = self._runtime.llm.chat_provider if self._runtime.llm is not None else None - - async def _run_compaction_once() -> CompactionResult: - if self._runtime.llm is None: - raise LLMNotSet() - return await self._compaction.compact( - self._context.history, self._runtime.llm, custom_instruction=custom_instruction - ) - - @tenacity.retry( - retry=retry_if_exception(self._is_retryable_error), - before_sleep=partial(self._retry_log, "compaction"), - wait=wait_exponential_jitter(initial=0.3, max=5, jitter=0.5), - stop=stop_after_attempt(self._loop_control.max_retries_per_step), - reraise=True, - ) - async def _compact_with_retry() -> CompactionResult: - return await self._run_with_connection_recovery( - "compaction", - _run_compaction_once, - chat_provider=chat_provider, - ) - - trigger_reason = "manual" if custom_instruction else "auto" - from kimi_cli.hooks import events - - await self._hook_engine.trigger( - "PreCompact", - matcher_value=trigger_reason, - input_data=events.pre_compact( - session_id=self._runtime.session.id, - cwd=str(Path.cwd()), - trigger=trigger_reason, - token_count=self._context.token_count, - ), - ) - - wire_send(CompactionBegin()) - compaction_result = await _compact_with_retry() - await self._context.clear() - await self._context.write_system_prompt(self._agent.system_prompt) - await self._checkpoint() - await self._context.append_message(compaction_result.messages) - estimated_token_count = compaction_result.estimated_token_count - - if self._runtime.role == "root": - active_task_snapshot = build_active_task_snapshot(self._runtime.background_tasks) - if active_task_snapshot is not None: - active_task_message = Message( - role="user", - content=[ - system( - "The following background tasks are still active after compaction. " - "Use TaskList if you need to re-enumerate them later." - ), - TextPart(text=active_task_snapshot), - ], - ) - await self._context.append_message(active_task_message) - estimated_token_count += estimate_text_tokens([active_task_message]) - - # Estimate token count so context_usage is not reported as 0% - await self._context.update_token_count(estimated_token_count) - - wire_send(CompactionEnd()) - - _hook_task = asyncio.create_task( - self._hook_engine.trigger( - "PostCompact", - matcher_value=trigger_reason, - input_data=events.post_compact( - session_id=self._runtime.session.id, - cwd=str(Path.cwd()), - trigger=trigger_reason, - estimated_token_count=estimated_token_count, - ), - ) - ) - _hook_task.add_done_callback(lambda t: t.exception() if not t.cancelled() else None) - - @staticmethod - def _is_retryable_error(exception: BaseException) -> bool: - if isinstance(exception, (APIConnectionError, APITimeoutError)): - return not bool(getattr(exception, "_kimi_recovery_exhausted", False)) - if isinstance(exception, APIEmptyResponseError): - return True - return isinstance(exception, APIStatusError) and exception.status_code in ( - 429, # Too Many Requests - 500, # Internal Server Error - 502, # Bad Gateway - 503, # Service Unavailable - 504, # Gateway Timeout - ) - - async def _run_with_connection_recovery( - self, - name: str, - operation: Callable[[], Awaitable[Any]], - *, - chat_provider: object | None = None, - ) -> Any: - try: - return await operation() - except (APIConnectionError, APITimeoutError) as error: - if not isinstance(chat_provider, RetryableChatProvider): - raise - try: - recovered = chat_provider.on_retryable_error(error) - except Exception: - logger.exception( - "Failed to recover chat provider during {name} after {error_type}.", - name=name, - error_type=type(error).__name__, - ) - raise - if not recovered: - raise - logger.info( - "Recovered chat provider during {name} after {error_type}; retrying once.", - name=name, - error_type=type(error).__name__, - ) - try: - return await operation() - except (APIConnectionError, APITimeoutError) as second_error: - second_error._kimi_recovery_exhausted = True # type: ignore[attr-defined] - raise - - @staticmethod - def _retry_log(name: str, retry_state: RetryCallState): - logger.info( - "Retrying {name} for the {n} time. Waiting {sleep} seconds.", - name=name, - n=retry_state.attempt_number, - sleep=retry_state.next_action.sleep - if retry_state.next_action is not None - else "unknown", - ) - - -class BackToTheFuture(Exception): - """ - Raise when we need to revert the context to a previous checkpoint. - The main agent loop should catch this exception and handle it. - """ - - def __init__(self, checkpoint_id: int, messages: Sequence[Message]): - self.checkpoint_id = checkpoint_id - self.messages = messages - - -class FlowRunner: - def __init__( - self, - flow: Flow, - *, - name: str | None = None, - max_moves: int = DEFAULT_MAX_FLOW_MOVES, - ) -> None: - self._flow = flow - self._name = name - self._max_moves = max_moves - - @staticmethod - def ralph_loop( - user_message: Message, - max_ralph_iterations: int, - ) -> FlowRunner: - prompt_content = list(user_message.content) - prompt_text = Message(role="user", content=prompt_content).extract_text(" ").strip() - total_runs = max_ralph_iterations + 1 - if max_ralph_iterations < 0: - total_runs = 1000000000000000 # effectively infinite - - nodes: dict[str, FlowNode] = { - "BEGIN": FlowNode(id="BEGIN", label="BEGIN", kind="begin"), - "END": FlowNode(id="END", label="END", kind="end"), - } - outgoing: dict[str, list[FlowEdge]] = {"BEGIN": [], "END": []} - - nodes["R1"] = FlowNode(id="R1", label=prompt_content, kind="task") - nodes["R2"] = FlowNode( - id="R2", - label=( - f"{prompt_text}. (You are running in an automated loop where the same " - "prompt is fed repeatedly. Only choose STOP when the task is fully complete. " - "Including it will stop further iterations. If you are not 100% sure, " - "choose CONTINUE.)" - ).strip(), - kind="decision", - ) - outgoing["R1"] = [] - outgoing["R2"] = [] - - outgoing["BEGIN"].append(FlowEdge(src="BEGIN", dst="R1", label=None)) - outgoing["R1"].append(FlowEdge(src="R1", dst="R2", label=None)) - outgoing["R2"].append(FlowEdge(src="R2", dst="R2", label="CONTINUE")) - outgoing["R2"].append(FlowEdge(src="R2", dst="END", label="STOP")) - - flow = Flow(nodes=nodes, outgoing=outgoing, begin_id="BEGIN", end_id="END") - max_moves = total_runs - return FlowRunner(flow, max_moves=max_moves) - - async def run(self, soul: KimiSoul, args: str) -> None: - if args.strip(): - command = f"/{FLOW_COMMAND_PREFIX}{self._name}" if self._name else "/flow" - logger.warning("Agent flow {command} ignores args: {args}", command=command, args=args) - return - - current_id = self._flow.begin_id - moves = 0 - total_steps = 0 - while True: - node = self._flow.nodes[current_id] - edges = self._flow.outgoing.get(current_id, []) - - if node.kind == "end": - logger.info("Agent flow reached END node {node_id}", node_id=current_id) - return - - if node.kind == "begin": - if not edges: - logger.error( - 'Agent flow BEGIN node "{node_id}" has no outgoing edges; stopping.', - node_id=node.id, - ) - return - current_id = edges[0].dst - continue - - if moves >= self._max_moves: - raise MaxStepsReached(total_steps) - next_id, steps_used = await self._execute_flow_node(soul, node, edges) - total_steps += steps_used - if next_id is None: - return - moves += 1 - current_id = next_id - - async def _execute_flow_node( - self, - soul: KimiSoul, - node: FlowNode, - edges: list[FlowEdge], - ) -> tuple[str | None, int]: - if not edges: - logger.error( - 'Agent flow node "{node_id}" has no outgoing edges; stopping.', - node_id=node.id, - ) - return None, 0 - - base_prompt = self._build_flow_prompt(node, edges) - prompt = base_prompt - steps_used = 0 - while True: - result = await self._flow_turn(soul, prompt) - steps_used += result.step_count - if result.stop_reason == "tool_rejected": - logger.error("Agent flow stopped after tool rejection.") - return None, steps_used - - if node.kind != "decision": - return edges[0].dst, steps_used - - choice = ( - parse_choice(result.final_message.extract_text(" ")) - if result.final_message - else None - ) - next_id = self._match_flow_edge(edges, choice) - if next_id is not None: - return next_id, steps_used - - options = ", ".join(edge.label or "" for edge in edges) - logger.warning( - "Agent flow invalid choice. Got: {choice}. Available: {options}.", - choice=choice or "", - options=options, - ) - prompt = ( - f"{base_prompt}\n\n" - "Your last response did not include a valid choice. " - "Reply with one of the choices using ...." - ) - - @staticmethod - def _build_flow_prompt(node: FlowNode, edges: list[FlowEdge]) -> str | list[ContentPart]: - if node.kind != "decision": - return node.label - - if not isinstance(node.label, str): - label_text = Message(role="user", content=node.label).extract_text(" ") - else: - label_text = node.label - choices = [edge.label for edge in edges if edge.label] - lines = [ - label_text, - "", - "Available branches:", - *(f"- {choice}" for choice in choices), - "", - "Reply with a choice using ....", - ] - return "\n".join(lines) - - @staticmethod - def _match_flow_edge(edges: list[FlowEdge], choice: str | None) -> str | None: - if not choice: - return None - for edge in edges: - if edge.label == choice: - return edge.dst - return None - - @staticmethod - async def _flow_turn( - soul: KimiSoul, - prompt: str | list[ContentPart], - ) -> TurnOutcome: - wire_send(TurnBegin(user_input=prompt)) - res = await soul._turn(Message(role="user", content=prompt)) # type: ignore[reportPrivateUsage] - wire_send(TurnEnd()) - return res diff --git a/src/kimi_cli/soul/kimisoul.ts b/src/kimi_cli/soul/kimisoul.ts new file mode 100644 index 000000000..58748699a --- /dev/null +++ b/src/kimi_cli/soul/kimisoul.ts @@ -0,0 +1,970 @@ +/** + * KimiSoul — corresponds to Python soul/kimisoul.py + * The core agent loop: receive input → call LLM → execute tools → repeat. + */ + +import type { Message, ContentPart, ToolCall, TokenUsage, StatusSnapshot, SlashCommand, ModelCapability } from "../types.ts"; +import type { ToolResult } from "../tools/types.ts"; +import type { LLM, StreamChunk, ChatOptions } from "../llm.ts"; +import type { HookEngine } from "../hooks/engine.ts"; +import type { Config } from "../config.ts"; +import type { Session } from "../session.ts"; +import { Context } from "./context.ts"; +import { Agent, type Runtime } from "./agent.ts"; +import { KimiToolset } from "./toolset.ts"; +import { SlashCommandRegistry } from "./slash.ts"; +import { compactContext, shouldCompact } from "./compaction.ts"; +import { toolResultMessage, systemReminder } from "./message.ts"; +import type { DynamicInjection, DynamicInjectionProvider } from "./dynamic_injection.ts"; +import { normalizeHistory } from "./dynamic_injection.ts"; +import { PlanModeInjectionProvider } from "./dynamic_injections/plan_mode.ts"; +import { YoloModeInjectionProvider } from "./dynamic_injections/yolo_mode.ts"; +import { handleNew, handleSessions, handleTitle } from "../ui/shell/commands/session.ts"; +import { handleModel } from "../ui/shell/commands/model.ts"; +import { handleLogin, handleLogout, createLoginPanel } from "../ui/shell/commands/login.ts"; +import { handleHooks, handleMcp, handleDebug, handleChangelog } from "../ui/shell/commands/info.ts"; +import { handleExport, handleImport } from "../ui/shell/commands/export_import.ts"; +import { handleWeb, handleVis, handleReload, handleTask } from "../ui/shell/commands/misc.ts"; +import { handleUsage } from "../ui/shell/commands/usage.ts"; +import { handleFeedback } from "../ui/shell/commands/feedback.ts"; +import { handleEditor } from "../ui/shell/commands/editor.ts"; +import { handleInit } from "../ui/shell/commands/init.ts"; +import { handleAddDir } from "../ui/shell/commands/add_dir.ts"; +import { logger } from "../utils/logging.ts"; + +// ── Errors ───────────────────────────────────────── + +export class MaxStepsReached extends Error { + readonly maxSteps: number; + constructor(maxSteps: number) { + super(`Reached max steps per turn: ${maxSteps}`); + this.name = "MaxStepsReached"; + this.maxSteps = maxSteps; + } +} + +// ── Wire event callbacks ──────────────────────────── + +export interface SoulCallbacks { + onTurnBegin?: (userInput: string | ContentPart[]) => void; + onTurnEnd?: () => void; + onStepBegin?: (stepNum: number) => void; + onStepInterrupted?: () => void; + onTextDelta?: (text: string) => void; + onThinkDelta?: (text: string) => void; + onToolCall?: (toolCall: ToolCall) => void; + onToolResult?: (toolCallId: string, result: ToolResult) => void; + onStatusUpdate?: (status: Partial) => void; + onCompactionBegin?: () => void; + onCompactionEnd?: () => void; + onError?: (error: Error) => void; + onNotification?: (title: string, body: string) => void; +} + +// ── Retry helpers ─────────────────────────────────── + +function isRetryableError(err: unknown): boolean { + if (!(err instanceof Error)) return false; + const msg = err.message.toLowerCase(); + // HTTP status codes that are retryable + if (/\b(429|500|502|503|504)\b/.test(msg)) return true; + // Network errors + if (msg.includes("timeout") || msg.includes("econnreset") || + msg.includes("econnrefused") || msg.includes("connection") || + msg.includes("network") || msg.includes("fetch failed") || + msg.includes("socket hang up")) return true; + // Empty response + if (msg.includes("empty response") || msg.includes("no body")) return true; + return false; +} + +async function sleep(ms: number): Promise { + return new Promise((resolve) => setTimeout(resolve, ms)); +} + +/** Retry with exponential backoff and jitter. */ +async function withRetry( + fn: () => Promise, + maxRetries: number, + label: string, +): Promise { + let lastError: unknown; + for (let attempt = 0; attempt <= maxRetries; attempt++) { + try { + return await fn(); + } catch (err) { + lastError = err; + if (attempt >= maxRetries || !isRetryableError(err)) { + throw err; + } + const baseDelay = Math.min(300 * Math.pow(2, attempt), 5000); + const jitter = Math.random() * 500; + const delay = baseDelay + jitter; + logger.warn(`${label}: retryable error (attempt ${attempt + 1}/${maxRetries}), retrying in ${Math.round(delay)}ms: ${err instanceof Error ? err.message : err}`); + await sleep(delay); + } + } + throw lastError; +} + +// ── KimiSoul ──────────────────────────────────────── + +export class KimiSoul { + private agent: Agent; + private context: Context; + private callbacks: SoulCallbacks; + private abortController: AbortController | null = null; + private _isRunning = false; + private _planMode = false; + private _planSessionId: string | null = null; + private _pendingPlanActivationInjection = false; + private _injectionProviders: DynamicInjectionProvider[]; + private _stepCount = 0; + private _totalUsage: TokenUsage = { + inputTokens: 0, + outputTokens: 0, + }; + // Steer queue: messages injected during a running turn + private _pendingSteers: Message[] = []; + // Track whether any tool was rejected without feedback this turn + private _toolRejectedNoFeedback = false; + + constructor(opts: { + agent: Agent; + context: Context; + callbacks?: SoulCallbacks; + }) { + this.agent = opts.agent; + this.context = opts.context; + this.callbacks = opts.callbacks ?? {}; + + // Restore plan mode from session state + this._planMode = opts.agent.runtime.session.state.plan_mode ?? false; + this._planSessionId = opts.agent.runtime.session.state.plan_session_id ?? null; + if (this._planMode) { + this._ensurePlanSessionId(); + } + + // Initialize dynamic injection providers + this._injectionProviders = [ + new PlanModeInjectionProvider(), + new YoloModeInjectionProvider(), + ]; + } + + // ── Properties ─────────────────────────────────── + + get runtime(): Runtime { + return this.agent.runtime; + } + + get config(): Config { + return this.agent.runtime.config; + } + + get session(): Session { + return this.agent.runtime.session; + } + + get ctx(): Context { + return this.context; + } + + get name(): string { + return this.agent.name; + } + + get modelName(): string { + return this.agent.modelName; + } + + get modelCapabilities(): Set | null { + return this.agent.modelCapabilities; + } + + get thinking(): boolean { + return this.agent.runtime.llm?.hasCapability("thinking") ?? false; + } + + get isRunning(): boolean { + return this._isRunning; + } + + get planMode(): boolean { + return this._planMode; + } + + get isYolo(): boolean { + return this.agent.runtime.approval.isYolo(); + } + + get status(): StatusSnapshot { + const llm = this.agent.runtime.llm; + const maxCtx = llm?.maxContextSize ?? 0; + const tokenCount = this.context.tokenCountWithPending; + return { + contextUsage: maxCtx > 0 ? tokenCount / maxCtx : null, + contextTokens: tokenCount, + maxContextTokens: maxCtx, + tokenUsage: this._totalUsage, + planMode: this._planMode, + mcpStatus: null, + }; + } + + get hookEngine(): HookEngine { + return this.agent.runtime.hookEngine; + } + + /** Push a notification to the UI (appears in message list). */ + notify(title: string, body: string): void { + this.callbacks.onNotification?.(title, body); + } + + get availableSlashCommands(): SlashCommand[] { + return this.agent.slashCommands.list(); + } + + // ── Plan mode ──────────────────────────────────── + + /** Toggle plan mode from a tool call. Returns the new state. */ + async togglePlanMode(): Promise { + return this._setPlanMode(!this._planMode, "tool"); + } + + /** Toggle plan mode from a manual entry point (slash command, keybinding). */ + async togglePlanModeFromManual(): Promise { + return this._setPlanMode(!this._planMode, "manual"); + } + + /** Set plan mode to a specific state from manual entry points. */ + async setPlanModeFromManual(enabled: boolean): Promise { + return this._setPlanMode(enabled, "manual"); + } + + setPlanMode(on: boolean): void { + this._setPlanMode(on, "tool"); + } + + private _setPlanMode(enabled: boolean, source: "manual" | "tool"): boolean { + if (enabled === this._planMode) return this._planMode; + this._planMode = enabled; + if (enabled) { + this._ensurePlanSessionId(); + this._pendingPlanActivationInjection = source === "manual"; + } else { + this._pendingPlanActivationInjection = false; + this._planSessionId = null; + this.agent.runtime.session.state.plan_session_id = null; + } + // Persist to session state + this.agent.runtime.session.state.plan_mode = this._planMode; + this.callbacks.onStatusUpdate?.({ planMode: this._planMode }); + return this._planMode; + } + + private _ensurePlanSessionId(): void { + if (this._planSessionId == null) { + this._planSessionId = crypto.randomUUID().replace(/-/g, ""); + this.agent.runtime.session.state.plan_session_id = this._planSessionId; + } + } + + /** Get the plan file path for the current session. */ + getPlanFilePath(): string | null { + if (this._planSessionId == null) return null; + const workDir = this.agent.runtime.session.workDir; + return `${workDir}/.kimi/plans/${this._planSessionId}.md`; + } + + /** Read the current plan file content. */ + readCurrentPlan(): string | null { + const path = this.getPlanFilePath(); + if (!path) return null; + try { + const file = Bun.file(path); + // Synchronous existence check is not available — use a simple approach + return file.size > 0 ? null : null; // Will be refined when plan tools exist + } catch { + return null; + } + } + + /** Delete the current plan file. */ + clearCurrentPlan(): void { + const path = this.getPlanFilePath(); + if (!path) return; + try { + const { unlinkSync } = require("node:fs"); + unlinkSync(path); + } catch { + // File may not exist + } + } + + /** Consume the next-step activation reminder scheduled by a manual toggle. */ + consumePendingPlanActivationInjection(): boolean { + if (!this._planMode || !this._pendingPlanActivationInjection) return false; + this._pendingPlanActivationInjection = false; + return true; + } + + /** Register an additional dynamic injection provider. */ + addInjectionProvider(provider: DynamicInjectionProvider): void { + this._injectionProviders.push(provider); + } + + // ── Yolo mode ──────────────────────────────────── + + setYolo(yolo: boolean): void { + this.agent.runtime.approval.setYolo(yolo); + } + + // ── Main entry point ───────────────────────────── + + async run(userInput: string | ContentPart[]): Promise { + if (this._isRunning) { + logger.warn("Soul is already running, ignoring input"); + return; + } + + // Check for slash commands + if (typeof userInput === "string" && userInput.trim().startsWith("/")) { + const handled = await this.agent.slashCommands.execute(userInput); + if (handled) return; + } + + this._isRunning = true; + this.abortController = new AbortController(); + this._toolRejectedNoFeedback = false; + + try { + this.callbacks.onTurnBegin?.(userInput); + this._wireLog({ type: "turn_begin", user_input: typeof userInput === "string" ? userInput : "[complex]" }); + await this._turn(userInput); + this._wireLog({ type: "turn_end" }); + this.callbacks.onTurnEnd?.(); + } catch (err) { + if (err instanceof Error && err.name === "AbortError") { + logger.info("Turn aborted"); + this.callbacks.onStepInterrupted?.(); + } else if (err instanceof MaxStepsReached) { + logger.warn(err.message); + this.callbacks.onError?.(err); + this.callbacks.onTurnEnd?.(); + } else { + logger.error(`Turn error: ${err}`); + this.callbacks.onError?.( + err instanceof Error ? err : new Error(String(err)), + ); + } + } finally { + this._isRunning = false; + this.abortController = null; + this._pendingSteers = []; + } + } + + /** Abort the current turn. */ + abort(): void { + this.abortController?.abort(); + } + + /** Steer: inject follow-up input during a running turn. */ + async steer(content: string | ContentPart[]): Promise { + if (!this._isRunning) return; + const msg: Message = { + role: "user", + content: typeof content === "string" ? content : content, + }; + this._pendingSteers.push(msg); + } + + // ── Turn execution ────────────────────────────── + + private async _turn(userInput: string | ContentPart[]): Promise { + // Append user message + const userMsg: Message = { + role: "user", + content: typeof userInput === "string" ? userInput : userInput, + }; + await this.context.appendMessage(userMsg); + + // Agent loop + await this._agentLoop(); + } + + // ── Agent loop ────────────────────────────────── + + private async _agentLoop(): Promise { + const maxSteps = this.agent.runtime.config.loop_control.max_steps_per_turn; + this._stepCount = 0; + + while (true) { + // Check max steps — raise exception like Python + if (this._stepCount >= maxSteps) { + throw new MaxStepsReached(maxSteps); + } + + // Check abort + if (this.abortController?.signal.aborted) { + this.callbacks.onStepInterrupted?.(); + break; + } + + // Consume pending steers + const hadSteers = await this._consumePendingSteers(); + + // Check if compaction needed + await this._maybeCompact(); + + // Execute one step + this._stepCount++; + this.callbacks.onStepBegin?.(this._stepCount); + + const maxRetries = this.agent.runtime.config.loop_control.max_retries_per_step; + const toolCalls = await withRetry( + () => this._step(), + maxRetries, + `step ${this._stepCount}`, + ); + + // No tool calls = turn is done (unless steers are pending) + if (toolCalls.length === 0) { + // Check for pending steers — if any, force another iteration + if (this._pendingSteers.length > 0) { + continue; + } + break; + } + + // Execute tools and collect results — shielded from abort + await this._executeToolsShielded(toolCalls); + + // If a tool was rejected without feedback, stop the turn + if (this._toolRejectedNoFeedback && this.agent.runtime.role !== "subagent") { + logger.info("Turn stopped: tool was rejected without feedback"); + break; + } + } + } + + /** + * Execute tools and append results to context. + * This is "shielded" from abort to keep context consistent — + * once we start appending, we finish even if abort fires. + */ + private async _executeToolsShielded(toolCalls: ToolCall[]): Promise { + for (const tc of toolCalls) { + // Check abort before each tool, but don't interrupt mid-append + if (this.abortController?.signal.aborted) break; + + const result = await this.agent.toolset.handle(tc); + + // Detect tool rejection without feedback + if (result.isError && result.message?.includes("rejected by the user")) { + // If the rejection message is just the standard template, no user feedback + if (!result.extras?.userFeedback) { + this._toolRejectedNoFeedback = true; + } + } + + // Build tool result message and append to context + const resultMsg = toolResultMessage({ + toolCallId: tc.id, + output: result.output, + isError: result.isError, + message: result.message, + }); + // Append atomically — even if abort was signaled during tool execution, + // we still append the result to keep context consistent + await this.context.appendMessage(resultMsg); + } + } + + /** Drain the steer queue into context. Returns true if any steers were consumed. */ + private async _consumePendingSteers(): Promise { + if (this._pendingSteers.length === 0) return false; + const steers = this._pendingSteers.splice(0); + for (const msg of steers) { + await this.context.appendMessage(msg); + } + return true; + } + + // ── Single step ───────────────────────────────── + + private async _step(): Promise { + const llm = this.agent.runtime.llm; + if (!llm) { + throw new Error("No LLM configured"); + } + + // Build messages for LLM — normalize to merge adjacent user messages + const rawMessages = [...this.context.history] as Message[]; + + // Collect dynamic injections from providers (plan mode, yolo mode, etc.) + const injections = await this._collectInjections(); + if (injections.length > 0) { + // Add as the last user message wrapped in system-reminder tags + const injectionContent = injections + .map((inj) => `\n${inj.content}\n`) + .join("\n\n"); + rawMessages.push({ + role: "user", + content: injectionContent, + }); + } + + // Normalize: merge adjacent user messages to avoid API errors + const messages = normalizeHistory(rawMessages); + + // Call LLM + const chatOptions: ChatOptions = { + system: this.agent.systemPrompt, + tools: this.agent.toolset.definitions(), + signal: this.abortController?.signal, + }; + + let assistantText = ""; + let thinkText = ""; + const toolCalls: ToolCall[] = []; + let usage: TokenUsage | null = null; + + const stream = llm.chat(messages, chatOptions); + + for await (const chunk of stream) { + switch (chunk.type) { + case "text": + assistantText += chunk.text; + this.callbacks.onTextDelta?.(chunk.text); + break; + + case "think": + thinkText += chunk.text; + this.callbacks.onThinkDelta?.(chunk.text); + break; + + case "tool_call": + toolCalls.push({ + id: chunk.id, + name: chunk.name, + arguments: chunk.arguments, + }); + this.callbacks.onToolCall?.({ + id: chunk.id, + name: chunk.name, + arguments: chunk.arguments, + }); + break; + + case "usage": + usage = chunk.usage; + this._totalUsage = { + inputTokens: + this._totalUsage.inputTokens + chunk.usage.inputTokens, + outputTokens: + this._totalUsage.outputTokens + chunk.usage.outputTokens, + }; + break; + + case "done": + break; + } + } + + // Build assistant message content + const contentParts: ContentPart[] = []; + if (assistantText) { + contentParts.push({ type: "text", text: assistantText }); + } + for (const tc of toolCalls) { + contentParts.push({ + type: "tool_use", + id: tc.id, + name: tc.name, + input: JSON.parse(tc.arguments || "{}"), + }); + } + + // Append assistant message to context + // Note: reasoning_content (thinkText) is stored separately in the message + // and will be serialized as reasoning_content field for the API + if (contentParts.length > 0) { + const assistantMsg: Message & { reasoning_content?: string } = { + role: "assistant", + content: contentParts, + }; + // Preserve thinking content so it can be sent back to the model + if (thinkText) { + assistantMsg.reasoning_content = thinkText; + } + await this.context.appendMessage(assistantMsg); + } + + // Update token count + if (usage) { + await this.context.updateTokenCount(usage); + } + + // Wire log step results + if (assistantText) { + this._wireLog({ type: "text_part", text: assistantText }); + } + for (const tc of toolCalls) { + this._wireLog({ type: "tool_call", name: tc.name, id: tc.id }); + } + + // Send status update + this.callbacks.onStatusUpdate?.(this.status); + + return toolCalls; + } + + // ── Dynamic injections ────────────────────────── + + /** Collect dynamic injections from all registered providers. */ + private async _collectInjections(): Promise { + const injections: DynamicInjection[] = []; + for (const provider of this._injectionProviders) { + try { + const result = await provider.getInjections(this.context.history, this); + injections.push(...result); + } catch (err) { + logger.warn(`Injection provider failed: ${err}`); + } + } + return injections; + } + + // ── Compaction ────────────────────────────────── + + private async _maybeCompact(): Promise { + const llm = this.agent.runtime.llm; + if (!llm) return; + + const lc = this.agent.runtime.config.loop_control; + + if ( + shouldCompact( + this.context.tokenCountWithPending, + llm.maxContextSize, + lc.reserved_context_size, + lc.compaction_trigger_ratio, + ) + ) { + this.callbacks.onCompactionBegin?.(); + try { + await compactContext(this.context, llm); + } catch (err) { + logger.error(`Compaction failed: ${err}`); + } + this.callbacks.onCompactionEnd?.(); + } + } + + // ── Slash command wiring ──────────────────────── + + wireSlashCommands(): void { + const registry = this.agent.slashCommands; + + // Wire /clear + const clearCmd = registry.get("clear"); + if (clearCmd) { + clearCmd.handler = async () => { + await this.context.clear(); + await this.context.writeSystemPrompt(this.agent.systemPrompt); + logger.info("Context cleared"); + this.callbacks.onStatusUpdate?.(this.status); + }; + } + + // Wire /compact + const compactCmd = registry.get("compact"); + if (compactCmd) { + compactCmd.handler = async (args: string) => { + if (this.context.nCheckpoints === 0) { + logger.info("The context is empty."); + return; + } + const llm = this.agent.runtime.llm; + if (!llm) return; + logger.info("Running `/compact`"); + await compactContext(this.context, llm, { focus: args || undefined }); + this.callbacks.onStatusUpdate?.(this.status); + }; + } + + // Wire /yolo + const yoloCmd = registry.get("yolo"); + if (yoloCmd) { + yoloCmd.handler = async () => { + if (this.agent.runtime.approval.isYolo()) { + this.agent.runtime.approval.setYolo(false); + logger.info("YOLO mode: OFF"); + } else { + this.agent.runtime.approval.setYolo(true); + logger.info("YOLO mode: ON"); + } + }; + } + + // Wire /plan with subcmd support (on/off/view/clear/toggle) + const planCmd = registry.get("plan"); + if (planCmd) { + planCmd.handler = async (args: string) => { + const subcmd = args.trim().toLowerCase(); + if (subcmd === "on") { + if (!this._planMode) await this.togglePlanModeFromManual(); + const planPath = this.getPlanFilePath(); + logger.info(`Plan mode ON. Plan file: ${planPath}`); + this.callbacks.onStatusUpdate?.({ planMode: this._planMode }); + } else if (subcmd === "off") { + if (this._planMode) await this.togglePlanModeFromManual(); + logger.info("Plan mode OFF. All tools are now available."); + this.callbacks.onStatusUpdate?.({ planMode: this._planMode }); + } else if (subcmd === "view") { + const content = this.readCurrentPlan(); + if (content) { + logger.info(content); + } else { + logger.info("No plan file found for this session."); + } + } else if (subcmd === "clear") { + this.clearCurrentPlan(); + logger.info("Plan cleared."); + } else { + // Default: toggle + const newState = await this.togglePlanModeFromManual(); + if (newState) { + const planPath = this.getPlanFilePath(); + logger.info(`Plan mode ON. Write your plan to: ${planPath}`); + } else { + logger.info("Plan mode OFF. All tools are now available."); + } + this.callbacks.onStatusUpdate?.({ planMode: this._planMode }); + } + }; + } + + // Wire /model + const modelCmd = registry.get("model"); + if (modelCmd) { + modelCmd.handler = async () => { + await handleModel(this.agent.runtime.config, { isFromDefaultLocation: true, sourceFile: null }); + }; + } + + // Wire /export + const exportCmd = registry.get("export"); + if (exportCmd) { + exportCmd.handler = async (args: string) => { + await handleExport(this.context, this.agent.runtime.session, args); + }; + } + + // Wire /import + const importCmd = registry.get("import"); + if (importCmd) { + importCmd.handler = async (args: string) => { + await handleImport(this.context, this.agent.runtime.session, args); + }; + } + + // Wire /web + const webCmd = registry.get("web"); + if (webCmd) { + webCmd.handler = async () => { + handleWeb(this.agent.runtime.session.id); + }; + } + + // Wire /vis + const visCmd = registry.get("vis"); + if (visCmd) { + visCmd.handler = async () => { + handleVis(this.agent.runtime.session.id); + }; + } + + // Wire /reload + const reloadCmd = registry.get("reload"); + if (reloadCmd) { + reloadCmd.handler = async () => { + handleReload(); + }; + } + + // Wire /task + const taskCmd = registry.get("task"); + if (taskCmd) { + taskCmd.handler = async () => { + handleTask(); + }; + } + + // Wire /login + const loginCmd = registry.get("login"); + if (loginCmd) { + const notify = (t: string, b: string) => this.notify(t, b); + loginCmd.handler = async () => { + await handleLogin(this.agent.runtime.config, notify); + }; + loginCmd.panel = () => createLoginPanel(this.agent.runtime.config, notify); + } + + // Wire /logout + const logoutCmd = registry.get("logout"); + if (logoutCmd) { + logoutCmd.handler = async () => { + await handleLogout(this.agent.runtime.config, (t, b) => this.notify(t, b)); + }; + } + + // Wire /usage + const usageCmd = registry.get("usage"); + if (usageCmd) { + usageCmd.handler = async () => { + await handleUsage(this.agent.runtime.config, this.agent.runtime.config.default_model || undefined); + }; + } + + // Wire /feedback + const feedbackCmd = registry.get("feedback"); + if (feedbackCmd) { + feedbackCmd.handler = async (args: string) => { + await handleFeedback( + this.agent.runtime.config, + args, + this.agent.runtime.session.id, + this.agent.runtime.config.default_model || undefined, + ); + }; + } + + // Wire /editor + const editorCmd = registry.get("editor"); + if (editorCmd) { + editorCmd.handler = async (args: string) => { + await handleEditor(this.agent.runtime.config, { isFromDefaultLocation: true, sourceFile: null }, args); + }; + } + + // Wire /hooks + const hooksCmd = registry.get("hooks"); + if (hooksCmd) { + hooksCmd.handler = async () => { + handleHooks(this.agent.runtime.hookEngine); + }; + } + + // Wire /mcp + const mcpCmd = registry.get("mcp"); + if (mcpCmd) { + mcpCmd.handler = async () => { + handleMcp(this.agent.runtime.config); + }; + } + + // Wire /debug + const debugCmd = registry.get("debug"); + if (debugCmd) { + debugCmd.handler = async () => { + handleDebug(this.context); + }; + } + + // Wire /changelog + const changelogCmd = registry.get("changelog"); + if (changelogCmd) { + changelogCmd.handler = async () => { + handleChangelog(); + }; + } + + // Wire /new + const newCmd = registry.get("new"); + if (newCmd) { + newCmd.handler = async () => { + await handleNew(this.agent.runtime.session); + }; + } + + // Wire /sessions + const sessionsCmd = registry.get("sessions"); + if (sessionsCmd) { + sessionsCmd.handler = async () => { + await handleSessions(this.agent.runtime.session); + }; + } + + // Wire /title + const titleCmd = registry.get("title"); + if (titleCmd) { + titleCmd.handler = async (args: string) => { + await handleTitle(this.agent.runtime.session, args); + }; + } + + // Wire /init + const initCmd = registry.get("init"); + if (initCmd) { + initCmd.handler = async () => { + const result = await handleInit(this.agent.runtime.session.workDir); + if (result) { + // Inject the generated AGENTS.md into context so the LLM knows about it + await this.context.appendMessage({ + role: "user", + content: `The user ran /init. Generated AGENTS.md:\n${result}`, + }); + } + }; + } + + // Wire /add-dir + const addDirCmd = registry.get("add-dir"); + if (addDirCmd) { + addDirCmd.handler = async (args: string) => { + const result = await handleAddDir( + this.agent.runtime.session, + this.agent.runtime.session.workDir, + args, + ); + if (result) { + // Inject directory info into context so the LLM knows about it + await this.context.appendMessage({ + role: "user", + content: result, + }); + } + }; + } + } + + /** Wire tool context callbacks (plan mode, ask user, etc.) to the soul. */ + wireToolContext(): void { + const ctx = this.agent.toolset.context; + ctx.setPlanMode = (on: boolean) => this.setPlanMode(on); + ctx.getPlanMode = () => this._planMode; + ctx.getPlanFilePath = () => this.getPlanFilePath() ?? undefined; + ctx.togglePlanMode = () => this.togglePlanMode(); + } + + // ── Wire file logging ──────────────────────────── + + /** + * Append a wire event to the session's wire.jsonl file. + * Used for session title generation and debugging. + */ + private async _wireLog(event: Record): Promise { + const wireFile = this.agent.runtime.session.wireFile; + if (!wireFile) return; + try { + const { appendFile } = await import("node:fs/promises"); + const line = JSON.stringify({ ...event, ts: Date.now() }) + "\n"; + await appendFile(wireFile, line, "utf-8"); + } catch { + // Wire logging is best-effort — don't crash on failure + } + } +} diff --git a/src/kimi_cli/soul/message.py b/src/kimi_cli/soul/message.py deleted file mode 100644 index e8e823d45..000000000 --- a/src/kimi_cli/soul/message.py +++ /dev/null @@ -1,92 +0,0 @@ -from __future__ import annotations - -from collections.abc import Sequence - -from kosong.message import Message -from kosong.tooling.error import ToolRuntimeError - -from kimi_cli.llm import ModelCapability -from kimi_cli.wire.types import ( - ContentPart, - ImageURLPart, - TextPart, - ThinkPart, - ToolResult, - VideoURLPart, -) - - -def system(message: str) -> ContentPart: - return TextPart(text=f"{message}") - - -def system_reminder(message: str) -> TextPart: - return TextPart(text=f"\n{message}\n") - - -def is_system_reminder_message(message: Message) -> bool: - """Check whether a message is an internal system-reminder user message.""" - if message.role != "user" or len(message.content) != 1: - return False - part = message.content[0] - return isinstance(part, TextPart) and part.text.strip().startswith("") - - -def tool_result_to_message(tool_result: ToolResult) -> Message: - """Convert a tool result to a message.""" - if tool_result.return_value.is_error: - assert tool_result.return_value.message, "Error return value should have a message" - message = tool_result.return_value.message - if isinstance(tool_result.return_value, ToolRuntimeError): - message += "\nThis is an unexpected error and the tool is probably not working." - content: list[ContentPart] = [system(f"ERROR: {message}")] - if tool_result.return_value.output: - content.extend(_output_to_content_parts(tool_result.return_value.output)) - else: - content: list[ContentPart] = [] - if tool_result.return_value.message: - content.append(system(tool_result.return_value.message)) - if tool_result.return_value.output: - content.extend(_output_to_content_parts(tool_result.return_value.output)) - if not content: - content.append(system("Tool output is empty.")) - elif not any(isinstance(part, TextPart) for part in content): - # Ensure at least one TextPart exists so the LLM API won't reject - # the message with "text content is empty" (see #1663). - content.insert(0, system("Tool returned non-text content.")) - - return Message( - role="tool", - content=content, - tool_call_id=tool_result.tool_call_id, - ) - - -def _output_to_content_parts( - output: str | ContentPart | Sequence[ContentPart], -) -> list[ContentPart]: - content: list[ContentPart] = [] - match output: - case str(text): - if text: - content.append(TextPart(text=text)) - case ContentPart(): - content.append(output) - case _: - content.extend(output) - return content - - -def check_message( - message: Message, model_capabilities: set[ModelCapability] -) -> set[ModelCapability]: - """Check the message content, return the missing model capabilities.""" - capabilities_needed = set[ModelCapability]() - for part in message.content: - if isinstance(part, ImageURLPart): - capabilities_needed.add("image_in") - elif isinstance(part, VideoURLPart): - capabilities_needed.add("video_in") - elif isinstance(part, ThinkPart): - capabilities_needed.add("thinking") - return capabilities_needed - model_capabilities diff --git a/src/kimi_cli/soul/message.ts b/src/kimi_cli/soul/message.ts new file mode 100644 index 000000000..216a052d8 --- /dev/null +++ b/src/kimi_cli/soul/message.ts @@ -0,0 +1,106 @@ +/** + * Message utility functions — corresponds to Python soul/message.py + * Helpers for constructing system/tool messages. + */ + +import type { ContentPart, Message, ModelCapability } from "../types.ts"; + +/** Wrap text in tags. */ +export function system(message: string): ContentPart { + return { type: "text", text: `${message}` }; +} + +/** Wrap text in tags. */ +export function systemReminder(message: string): ContentPart { + return { type: "text", text: `\n${message}\n` }; +} + +/** Check whether a message is an internal system-reminder user message. */ +export function isSystemReminderMessage(message: Message): boolean { + if (message.role !== "user") return false; + if (typeof message.content === "string") { + return message.content.trim().startsWith(""); + } + if (Array.isArray(message.content) && message.content.length === 1) { + const part = message.content[0]!; + if (part.type === "text") { + return part.text.trim().startsWith(""); + } + } + return false; +} + +/** Build a tool result message from output. */ +export function toolResultMessage(opts: { + toolCallId: string; + output: string | ContentPart | ContentPart[]; + isError?: boolean; + message?: string; +}): Message { + const parts: ContentPart[] = []; + + if (opts.isError) { + const errMsg = opts.message ?? "Unknown error"; + parts.push(system(`ERROR: ${errMsg}`)); + const outputParts = outputToContentParts(opts.output); + parts.push(...outputParts); + } else { + if (opts.message) { + parts.push(system(opts.message)); + } + const outputParts = outputToContentParts(opts.output); + parts.push(...outputParts); + if (parts.length === 0) { + parts.push(system("Tool output is empty.")); + } else if (!parts.some((p) => p.type === "text")) { + // Ensure at least one TextPart exists so the LLM API won't reject + parts.unshift(system("Tool returned non-text content.")); + } + } + + return { + role: "tool", + content: [ + { + type: "tool_result", + toolUseId: opts.toolCallId, + content: parts.map((p) => (p.type === "text" ? p.text : JSON.stringify(p))).join("\n"), + isError: opts.isError, + }, + ], + }; +} + +/** Convert various output formats to ContentPart array. */ +function outputToContentParts( + output: string | ContentPart | ContentPart[], +): ContentPart[] { + if (typeof output === "string") { + return output ? [{ type: "text", text: output }] : []; + } + if (Array.isArray(output)) { + return output; + } + // Single ContentPart + return [output]; +} + +/** Check message content for required model capabilities, return missing ones. */ +export function checkMessage( + message: Message, + modelCapabilities: Set, +): Set { + const needed = new Set(); + const content = typeof message.content === "string" ? [] : message.content; + for (const part of content) { + if (part.type === "image") needed.add("image_in"); + if ((part as any).type === "video") needed.add("video_in"); + if ((part as any).type === "thinking") needed.add("thinking"); + } + // Return only the capabilities that are missing + const missing = new Set(); + for (const cap of needed) { + if (!modelCapabilities.has(cap)) missing.add(cap); + } + return missing; +} diff --git a/src/kimi_cli/soul/slash.py b/src/kimi_cli/soul/slash.py deleted file mode 100644 index fbe5a4541..000000000 --- a/src/kimi_cli/soul/slash.py +++ /dev/null @@ -1,285 +0,0 @@ -from __future__ import annotations - -import tempfile -from collections.abc import Awaitable, Callable -from pathlib import Path -from typing import TYPE_CHECKING - -from kaos.path import KaosPath -from kosong.message import Message - -import kimi_cli.prompts as prompts -from kimi_cli import logger -from kimi_cli.soul import wire_send -from kimi_cli.soul.agent import load_agents_md -from kimi_cli.soul.context import Context -from kimi_cli.soul.message import system -from kimi_cli.utils.export import is_sensitive_file -from kimi_cli.utils.path import sanitize_cli_path, shorten_home -from kimi_cli.utils.slashcmd import SlashCommandRegistry -from kimi_cli.wire.types import StatusUpdate, TextPart - -if TYPE_CHECKING: - from kimi_cli.soul.kimisoul import KimiSoul - -type SoulSlashCmdFunc = Callable[[KimiSoul, str], None | Awaitable[None]] -""" -A function that runs as a KimiSoul-level slash command. - -Raises: - Any exception that can be raised by `Soul.run`. -""" - -registry = SlashCommandRegistry[SoulSlashCmdFunc]() - - -@registry.command -async def init(soul: KimiSoul, args: str): - """Analyze the codebase and generate an `AGENTS.md` file""" - from kimi_cli.soul.kimisoul import KimiSoul - - with tempfile.TemporaryDirectory() as temp_dir: - tmp_context = Context(file_backend=Path(temp_dir) / "context.jsonl") - tmp_soul = KimiSoul(soul.agent, context=tmp_context) - await tmp_soul.run(prompts.INIT) - - agents_md = await load_agents_md(soul.runtime.builtin_args.KIMI_WORK_DIR) - system_message = system( - "The user just ran `/init` slash command. " - "The system has analyzed the codebase and generated an `AGENTS.md` file. " - f"Latest AGENTS.md file content:\n{agents_md}" - ) - await soul.context.append_message(Message(role="user", content=[system_message])) - - -@registry.command -async def compact(soul: KimiSoul, args: str): - """Compact the context (optionally with a custom focus, e.g. /compact keep db discussions)""" - if soul.context.n_checkpoints == 0: - wire_send(TextPart(text="The context is empty.")) - return - - logger.info("Running `/compact`") - await soul.compact_context(custom_instruction=args.strip()) - wire_send(TextPart(text="The context has been compacted.")) - snap = soul.status - wire_send( - StatusUpdate( - context_usage=snap.context_usage, - context_tokens=snap.context_tokens, - max_context_tokens=snap.max_context_tokens, - ) - ) - - -@registry.command(aliases=["reset"]) -async def clear(soul: KimiSoul, args: str): - """Clear the context""" - logger.info("Running `/clear`") - await soul.context.clear() - await soul.context.write_system_prompt(soul.agent.system_prompt) - wire_send(TextPart(text="The context has been cleared.")) - snap = soul.status - wire_send( - StatusUpdate( - context_usage=snap.context_usage, - context_tokens=snap.context_tokens, - max_context_tokens=snap.max_context_tokens, - ) - ) - - -@registry.command -async def yolo(soul: KimiSoul, args: str): - """Toggle YOLO mode (auto-approve all actions)""" - if soul.runtime.approval.is_yolo(): - soul.runtime.approval.set_yolo(False) - wire_send(TextPart(text="You only die once! Actions will require approval.")) - else: - soul.runtime.approval.set_yolo(True) - wire_send(TextPart(text="You only live once! All actions will be auto-approved.")) - - -@registry.command -async def plan(soul: KimiSoul, args: str): - """Toggle plan mode. Usage: /plan [on|off|view|clear]""" - subcmd = args.strip().lower() - - if subcmd == "on": - if not soul.plan_mode: - await soul.toggle_plan_mode_from_manual() - plan_path = soul.get_plan_file_path() - wire_send(TextPart(text=f"Plan mode ON. Plan file: {plan_path}")) - wire_send(StatusUpdate(plan_mode=soul.plan_mode)) - elif subcmd == "off": - if soul.plan_mode: - await soul.toggle_plan_mode_from_manual() - wire_send(TextPart(text="Plan mode OFF. All tools are now available.")) - wire_send(StatusUpdate(plan_mode=soul.plan_mode)) - elif subcmd == "view": - content = soul.read_current_plan() - if content: - wire_send(TextPart(text=content)) - else: - wire_send(TextPart(text="No plan file found for this session.")) - elif subcmd == "clear": - soul.clear_current_plan() - wire_send(TextPart(text="Plan cleared.")) - else: - # Default: toggle - new_state = await soul.toggle_plan_mode_from_manual() - if new_state: - plan_path = soul.get_plan_file_path() - wire_send( - TextPart( - text=f"Plan mode ON. Write your plan to: {plan_path}\n" - "Use ExitPlanMode when done, or /plan off to exit manually." - ) - ) - else: - wire_send(TextPart(text="Plan mode OFF. All tools are now available.")) - wire_send(StatusUpdate(plan_mode=soul.plan_mode)) - - -@registry.command(name="add-dir") -async def add_dir(soul: KimiSoul, args: str): - """Add a directory to the workspace. Usage: /add-dir . Run without args to list added dirs""" # noqa: E501 - from kaos.path import KaosPath - - from kimi_cli.utils.path import is_within_directory, list_directory - - args = sanitize_cli_path(args) - if not args: - if not soul.runtime.additional_dirs: - wire_send(TextPart(text="No additional directories. Usage: /add-dir ")) - else: - lines = ["Additional directories:"] - for d in soul.runtime.additional_dirs: - lines.append(f" - {d}") - wire_send(TextPart(text="\n".join(lines))) - return - - path = KaosPath(args).expanduser().canonical() - - if not await path.exists(): - wire_send(TextPart(text=f"Directory does not exist: {path}")) - return - if not await path.is_dir(): - wire_send(TextPart(text=f"Not a directory: {path}")) - return - - # Check if already added (exact match) - if path in soul.runtime.additional_dirs: - wire_send(TextPart(text=f"Directory already in workspace: {path}")) - return - - # Check if it's within the work_dir (already accessible) - work_dir = soul.runtime.builtin_args.KIMI_WORK_DIR - if is_within_directory(path, work_dir): - wire_send(TextPart(text=f"Directory is already within the working directory: {path}")) - return - - # Check if it's within an already-added additional directory (redundant) - for existing in soul.runtime.additional_dirs: - if is_within_directory(path, existing): - wire_send( - TextPart( - text=f"Directory is already within an added directory `{existing}`: {path}" - ) - ) - return - - # Validate readability before committing any state changes - try: - ls_output = await list_directory(path) - except OSError as e: - wire_send(TextPart(text=f"Cannot read directory: {path} ({e})")) - return - - # Add the directory (only after readability is confirmed) - soul.runtime.additional_dirs.append(path) - - # Persist to session state - soul.runtime.session.state.additional_dirs.append(str(path)) - soul.runtime.session.save_state() - - # Inject a system message to inform the LLM about the new directory - system_message = system( - f"The user has added an additional directory to the workspace: `{path}`\n\n" - f"Directory listing:\n```\n{ls_output}\n```\n\n" - "You can now read, write, search, and glob files in this directory " - "as if it were part of the working directory." - ) - await soul.context.append_message(Message(role="user", content=[system_message])) - - wire_send(TextPart(text=f"Added directory to workspace: {path}")) - logger.info("Added additional directory: {path}", path=path) - - -@registry.command -async def export(soul: KimiSoul, args: str): - """Export current session context to a markdown file""" - from kimi_cli.utils.export import perform_export - - session = soul.runtime.session - result = await perform_export( - history=list(soul.context.history), - session_id=session.id, - work_dir=str(session.work_dir), - token_count=soul.context.token_count, - args=args, - default_dir=Path(str(session.work_dir)), - ) - if isinstance(result, str): - wire_send(TextPart(text=result)) - return - output, count = result - display = shorten_home(KaosPath(str(output))) - wire_send(TextPart(text=f"Exported {count} messages to {display}")) - wire_send( - TextPart( - text=" Note: The exported file may contain sensitive information. " - "Please be cautious when sharing it externally." - ) - ) - - -@registry.command(name="import") -async def import_context(soul: KimiSoul, args: str): - """Import context from a file or session ID""" - from kimi_cli.utils.export import perform_import - - target = sanitize_cli_path(args) - if not target: - wire_send(TextPart(text="Usage: /import ")) - return - - session = soul.runtime.session - raw_max_context_size = ( - soul.runtime.llm.max_context_size if soul.runtime.llm is not None else None - ) - max_context_size = ( - raw_max_context_size - if isinstance(raw_max_context_size, int) and raw_max_context_size > 0 - else None - ) - result = await perform_import( - target=target, - current_session_id=session.id, - work_dir=session.work_dir, - context=soul.context, - max_context_size=max_context_size, - ) - if isinstance(result, str): - wire_send(TextPart(text=result)) - return - - source_desc, content_len = result - wire_send(TextPart(text=f"Imported context from {source_desc} ({content_len} chars).")) - if source_desc.startswith("file") and is_sensitive_file(Path(target).name): - wire_send( - TextPart( - text="Warning: This file may contain secrets (API keys, tokens, credentials). " - "The content is now part of your session context." - ) - ) diff --git a/src/kimi_cli/soul/slash.ts b/src/kimi_cli/soul/slash.ts new file mode 100644 index 000000000..3b17ec602 --- /dev/null +++ b/src/kimi_cli/soul/slash.ts @@ -0,0 +1,206 @@ +/** + * Slash command registry — corresponds to Python soul/slash concepts + * Provides registration + dispatch for /commands in the CLI. + */ + +import type { SlashCommand } from "../types.ts"; + +export class SlashCommandRegistry { + private commands = new Map(); + private aliases = new Map(); + + register(command: SlashCommand): void { + this.commands.set(command.name, command); + if (command.aliases) { + for (const alias of command.aliases) { + this.aliases.set(alias, command.name); + } + } + } + + get(name: string): SlashCommand | undefined { + const resolved = this.aliases.get(name) ?? name; + return this.commands.get(resolved); + } + + has(name: string): boolean { + return this.commands.has(name) || this.aliases.has(name); + } + + list(): SlashCommand[] { + return [...this.commands.values()]; + } + + async execute(input: string): Promise { + const trimmed = input.trim(); + if (!trimmed.startsWith("/")) return false; + + const spaceIdx = trimmed.indexOf(" "); + const name = spaceIdx === -1 ? trimmed.slice(1) : trimmed.slice(1, spaceIdx); + const args = spaceIdx === -1 ? "" : trimmed.slice(spaceIdx + 1).trim(); + + const cmd = this.get(name); + if (!cmd) return false; + + await cmd.handler(args); + return true; + } +} + +/** + * Create a default registry with built-in commands. + * Handlers are stubs — the real app wires them up. + */ +export function createDefaultRegistry(): SlashCommandRegistry { + const registry = new SlashCommandRegistry(); + + const builtins: SlashCommand[] = [ + { + name: "clear", + description: "Clear conversation history", + aliases: ["reset"], + handler: async () => { + /* wired by app */ + }, + }, + { + name: "compact", + description: "Compact conversation context", + handler: async () => {}, + }, + { + name: "yolo", + description: "Toggle auto-approve mode", + aliases: ["auto-approve"], + handler: async () => {}, + }, + { + name: "plan", + description: "Toggle plan mode", + handler: async () => {}, + }, + { + name: "model", + description: "Switch model", + handler: async () => {}, + }, + { + name: "help", + description: "Show help", + aliases: ["?"], + handler: async () => {}, + }, + { + name: "init", + description: "Initialize project configuration", + handler: async () => {}, + }, + { + name: "add-dir", + description: "Add directory to workspace scope", + handler: async () => {}, + }, + // ── Commands below are newly registered to match Python version ── + { + name: "login", + description: "Login or setup a platform", + aliases: ["setup"], + handler: async () => {}, + }, + { + name: "logout", + description: "Logout from the current platform", + handler: async () => {}, + }, + { + name: "new", + description: "Start a new session", + handler: async () => {}, + }, + { + name: "sessions", + description: "List sessions and resume", + aliases: ["resume"], + handler: async () => {}, + }, + { + name: "title", + description: "Set or show the session title", + aliases: ["rename"], + handler: async () => {}, + }, + { + name: "task", + description: "Browse and manage background tasks", + handler: async () => {}, + }, + { + name: "editor", + description: "Set default external editor", + handler: async () => {}, + }, + { + name: "reload", + description: "Reload configuration", + handler: async () => {}, + }, + { + name: "usage", + description: "Display API usage and quota information", + aliases: ["status"], + handler: async () => {}, + }, + { + name: "changelog", + description: "Show release notes", + aliases: ["release-notes"], + handler: async () => {}, + }, + { + name: "feedback", + description: "Submit feedback", + handler: async () => {}, + }, + { + name: "hooks", + description: "List configured hooks", + handler: async () => {}, + }, + { + name: "mcp", + description: "Show MCP servers and tools", + handler: async () => {}, + }, + { + name: "web", + description: "Open Kimi Code Web UI in browser", + handler: async () => {}, + }, + { + name: "vis", + description: "Open Kimi Agent Tracing Visualizer", + handler: async () => {}, + }, + { + name: "export", + description: "Export session context to markdown", + handler: async () => {}, + }, + { + name: "import", + description: "Import context from file or session", + handler: async () => {}, + }, + { + name: "debug", + description: "Debug the context", + handler: async () => {}, + }, + ]; + + for (const cmd of builtins) { + registry.register(cmd); + } + + return registry; +} diff --git a/src/kimi_cli/soul/toolset.py b/src/kimi_cli/soul/toolset.py deleted file mode 100644 index 3722147ef..000000000 --- a/src/kimi_cli/soul/toolset.py +++ /dev/null @@ -1,610 +0,0 @@ -from __future__ import annotations - -import asyncio -import contextlib -import importlib -import inspect -import json -from contextvars import ContextVar -from dataclasses import dataclass -from datetime import timedelta -from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal, overload - -from kosong.tooling import ( - CallableTool, - CallableTool2, - HandleResult, - Tool, - ToolError, - ToolOk, - Toolset, -) -from kosong.tooling.error import ( - ToolNotFoundError, - ToolParseError, - ToolRuntimeError, -) -from kosong.tooling.mcp import convert_mcp_content -from kosong.utils.typing import JsonType - -from kimi_cli import logger -from kimi_cli.exception import InvalidToolError, MCPRuntimeError -from kimi_cli.hooks.engine import HookEngine -from kimi_cli.tools import SkipThisTool -from kimi_cli.wire.types import ( - ContentPart, - MCPServerSnapshot, - MCPStatusSnapshot, - ToolCall, - ToolCallRequest, - ToolResult, - ToolReturnValue, -) - -if TYPE_CHECKING: - import fastmcp - import mcp - from fastmcp.client.client import CallToolResult - from fastmcp.client.transports import ClientTransport - from fastmcp.mcp_config import MCPConfig - - from kimi_cli.soul.agent import Runtime - -current_tool_call = ContextVar[ToolCall | None]("current_tool_call", default=None) - -_current_session_id: ContextVar[str] = ContextVar("_current_session_id", default="") - - -def set_session_id(sid: str) -> None: - _current_session_id.set(sid) - - -def _get_session_id() -> str: - return _current_session_id.get() - - -def get_current_tool_call_or_none() -> ToolCall | None: - """ - Get the current tool call or None. - Expect to be not None when called from a `__call__` method of a tool. - """ - return current_tool_call.get() - - -type ToolType = CallableTool | CallableTool2[Any] - - -if TYPE_CHECKING: - - def type_check(kimi_toolset: KimiToolset): - _: Toolset = kimi_toolset - - -class KimiToolset: - def __init__(self) -> None: - self._tool_dict: dict[str, ToolType] = {} - self._hidden_tools: set[str] = set() - self._mcp_servers: dict[str, MCPServerInfo] = {} - self._mcp_loading_task: asyncio.Task[None] | None = None - self._deferred_mcp_load: tuple[list[MCPConfig], Runtime] | None = None - self._hook_engine: HookEngine = HookEngine() - - def set_hook_engine(self, engine: HookEngine) -> None: - self._hook_engine = engine - - def add(self, tool: ToolType) -> None: - self._tool_dict[tool.name] = tool - - def hide(self, tool_name: str) -> bool: - """Hide a tool from the LLM tool list. Returns True if the tool exists.""" - if tool_name in self._tool_dict: - self._hidden_tools.add(tool_name) - return True - return False - - def unhide(self, tool_name: str) -> None: - """Restore a hidden tool to the LLM tool list.""" - self._hidden_tools.discard(tool_name) - - @overload - def find(self, tool_name_or_type: str) -> ToolType | None: ... - @overload - def find[T: ToolType](self, tool_name_or_type: type[T]) -> T | None: ... - def find(self, tool_name_or_type: str | type[ToolType]) -> ToolType | None: - if isinstance(tool_name_or_type, str): - return self._tool_dict.get(tool_name_or_type) - else: - for tool in self._tool_dict.values(): - if isinstance(tool, tool_name_or_type): - return tool - return None - - @property - def tools(self) -> list[Tool]: - return [ - tool.base for tool in self._tool_dict.values() if tool.name not in self._hidden_tools - ] - - def handle(self, tool_call: ToolCall) -> HandleResult: - token = current_tool_call.set(tool_call) - try: - if tool_call.function.name not in self._tool_dict: - return ToolResult( - tool_call_id=tool_call.id, - return_value=ToolNotFoundError(tool_call.function.name), - ) - - tool = self._tool_dict[tool_call.function.name] - - try: - arguments: JsonType = json.loads(tool_call.function.arguments or "{}", strict=False) - except json.JSONDecodeError as e: - return ToolResult(tool_call_id=tool_call.id, return_value=ToolParseError(str(e))) - - async def _call(): - tool_input_dict = arguments if isinstance(arguments, dict) else {} - - # --- PreToolUse --- - from kimi_cli.hooks import events - - results = await self._hook_engine.trigger( - "PreToolUse", - matcher_value=tool_call.function.name, - input_data=events.pre_tool_use( - session_id=_get_session_id(), - cwd=str(Path.cwd()), - tool_name=tool_call.function.name, - tool_input=tool_input_dict, - tool_call_id=tool_call.id, - ), - ) - for result in results: - if result.action == "block": - return ToolResult( - tool_call_id=tool_call.id, - return_value=ToolError( - message=result.reason or "Blocked by PreToolUse hook", - brief="Hook blocked", - ), - ) - - # --- Execute tool --- - try: - ret = await tool.call(arguments) - except Exception as e: - # --- PostToolUseFailure (fire-and-forget) --- - _hook_task = asyncio.create_task( - self._hook_engine.trigger( - "PostToolUseFailure", - matcher_value=tool_call.function.name, - input_data=events.post_tool_use_failure( - session_id=_get_session_id(), - cwd=str(Path.cwd()), - tool_name=tool_call.function.name, - tool_input=tool_input_dict, - error=str(e), - tool_call_id=tool_call.id, - ), - ) - ) - _hook_task.add_done_callback( - lambda t: t.exception() if not t.cancelled() else None - ) - return ToolResult( - tool_call_id=tool_call.id, - return_value=ToolRuntimeError(str(e)), - ) - - # --- PostToolUse (fire-and-forget) --- - _hook_task = asyncio.create_task( - self._hook_engine.trigger( - "PostToolUse", - matcher_value=tool_call.function.name, - input_data=events.post_tool_use( - session_id=_get_session_id(), - cwd=str(Path.cwd()), - tool_name=tool_call.function.name, - tool_input=tool_input_dict, - tool_output=str(ret)[:2000], - tool_call_id=tool_call.id, - ), - ) - ) - _hook_task.add_done_callback(lambda t: t.exception() if not t.cancelled() else None) - - return ToolResult(tool_call_id=tool_call.id, return_value=ret) - - return asyncio.create_task(_call()) - finally: - current_tool_call.reset(token) - - def register_external_tool( - self, - name: str, - description: str, - parameters: dict[str, Any], - ) -> tuple[bool, str | None]: - if name in self._tool_dict: - existing = self._tool_dict[name] - if not isinstance(existing, WireExternalTool): - return False, "tool name conflicts with existing tool" - try: - tool = WireExternalTool( - name=name, - description=description, - parameters=parameters, - ) - except Exception as e: - return False, str(e) - self.add(tool) - return True, None - - @property - def mcp_servers(self) -> dict[str, MCPServerInfo]: - """Get MCP servers info.""" - return self._mcp_servers - - def mcp_status_snapshot(self) -> MCPStatusSnapshot | None: - """Return a read-only snapshot of current MCP startup state.""" - if not self._mcp_servers: - return None - - servers = tuple( - MCPServerSnapshot( - name=name, - status=info.status, - tools=tuple(tool.name for tool in info.tools), - ) - for name, info in self._mcp_servers.items() - ) - return MCPStatusSnapshot( - loading=self.has_pending_mcp_tools(), - connected=sum(1 for server in servers if server.status == "connected"), - total=len(servers), - tools=sum(len(server.tools) for server in servers), - servers=servers, - ) - - def defer_mcp_tool_loading(self, mcp_configs: list[MCPConfig], runtime: Runtime) -> None: - """Store MCP configs for a later background startup.""" - self._deferred_mcp_load = (list(mcp_configs), runtime) - - def has_deferred_mcp_tools(self) -> bool: - """Return True when MCP loading is configured but has not started yet.""" - return self._deferred_mcp_load is not None - - async def start_deferred_mcp_tool_loading(self) -> bool: - """Start any deferred MCP loading in the background.""" - if self._deferred_mcp_load is None: - return False - if self._mcp_loading_task is not None or self._mcp_servers: - self._deferred_mcp_load = None - return False - - mcp_configs, runtime = self._deferred_mcp_load - self._deferred_mcp_load = None - await self.load_mcp_tools(mcp_configs, runtime, in_background=True) - return True - - def load_tools(self, tool_paths: list[str], dependencies: dict[type[Any], Any]) -> None: - """ - Load tools from paths like `kimi_cli.tools.shell:Shell`. - - Raises: - InvalidToolError(KimiCLIException, ValueError): When any tool cannot be loaded. - """ - - good_tools: list[str] = [] - bad_tools: list[str] = [] - - for tool_path in tool_paths: - try: - tool = self._load_tool(tool_path, dependencies) - except SkipThisTool: - logger.info("Skipping tool: {tool_path}", tool_path=tool_path) - continue - if tool: - self.add(tool) - good_tools.append(tool_path) - else: - bad_tools.append(tool_path) - logger.info("Loaded tools: {good_tools}", good_tools=good_tools) - if bad_tools: - raise InvalidToolError(f"Invalid tools: {bad_tools}") - - @staticmethod - def _load_tool(tool_path: str, dependencies: dict[type[Any], Any]) -> ToolType | None: - logger.debug("Loading tool: {tool_path}", tool_path=tool_path) - module_name, class_name = tool_path.rsplit(":", 1) - try: - module = importlib.import_module(module_name) - except ImportError: - return None - tool_cls = getattr(module, class_name, None) - if tool_cls is None: - return None - args: list[Any] = [] - if "__init__" in tool_cls.__dict__: - # the tool class overrides the `__init__` of base class - for param in inspect.signature(tool_cls).parameters.values(): - if param.kind == inspect.Parameter.KEYWORD_ONLY: - # once we encounter a keyword-only parameter, we stop injecting dependencies - break - # all positional parameters should be dependencies to be injected - if param.annotation not in dependencies: - raise ValueError(f"Tool dependency not found: {param.annotation}") - args.append(dependencies[param.annotation]) - return tool_cls(*args) - - # TODO(rc): remove `in_background` parameter and always load in background - async def load_mcp_tools( - self, mcp_configs: list[MCPConfig], runtime: Runtime, in_background: bool = True - ) -> None: - """ - Load MCP tools from specified MCP configs. - - Raises: - MCPRuntimeError(KimiCLIException, RuntimeError): When any MCP server cannot be - connected. - """ - import fastmcp - from fastmcp.mcp_config import MCPConfig, RemoteMCPServer - - from kimi_cli.ui.shell.prompt import toast - - async def _check_oauth_tokens(server_url: str) -> bool: - """Check if OAuth tokens exist for the server.""" - try: - from fastmcp.client.auth.oauth import FileTokenStorage - - storage = FileTokenStorage(server_url=server_url) - tokens = await storage.get_tokens() - return tokens is not None - except Exception: - return False - - def _toast_mcp(message: str) -> None: - if in_background: - toast( - message, - duration=10.0, - topic="mcp", - immediate=True, - position="right", - ) - - oauth_servers: dict[str, str] = {} - - async def _connect_server( - server_name: str, server_info: MCPServerInfo - ) -> tuple[str, Exception | None]: - if server_info.status != "pending": - return server_name, None - - server_info.status = "connecting" - try: - async with server_info.client as client: - for tool in await client.list_tools(): - server_info.tools.append( - MCPTool(server_name, tool, client, runtime=runtime) - ) - - for tool in server_info.tools: - self.add(tool) - - server_info.status = "connected" - logger.info("Connected MCP server: {server_name}", server_name=server_name) - return server_name, None - except Exception as e: - logger.error( - "Failed to connect MCP server: {server_name}, error: {error}", - server_name=server_name, - error=e, - ) - server_info.status = "failed" - return server_name, e - - async def _connect(): - _toast_mcp("connecting to mcp servers...") - unauthorized_servers: dict[str, str] = {} - for server_name, server_info in self._mcp_servers.items(): - server_url = oauth_servers.get(server_name) - if not server_url: - continue - if not await _check_oauth_tokens(server_url): - logger.warning( - "Skipping OAuth MCP server '{server_name}': not authorized. " - "Run 'kimi mcp auth {server_name}' first.", - server_name=server_name, - ) - server_info.status = "unauthorized" - unauthorized_servers[server_name] = server_url - - tasks = [ - asyncio.create_task(_connect_server(server_name, server_info)) - for server_name, server_info in self._mcp_servers.items() - if server_info.status == "pending" - ] - results = await asyncio.gather(*tasks) if tasks else [] - failed_servers = {name: error for name, error in results if error is not None} - - for mcp_config in mcp_configs: - # Skip empty MCP configs (no servers defined) - if not mcp_config.mcpServers: - logger.debug("Skipping empty MCP config: {mcp_config}", mcp_config=mcp_config) - continue - - if failed_servers: - _toast_mcp("mcp connection failed") - raise MCPRuntimeError(f"Failed to connect MCP servers: {failed_servers}") - if unauthorized_servers: - _toast_mcp("mcp authorization needed") - else: - _toast_mcp("mcp servers connected") - - for mcp_config in mcp_configs: - if not mcp_config.mcpServers: - logger.debug("Skipping empty MCP config: {mcp_config}", mcp_config=mcp_config) - continue - - for server_name, server_config in mcp_config.mcpServers.items(): - if isinstance(server_config, RemoteMCPServer) and server_config.auth == "oauth": - oauth_servers[server_name] = server_config.url - - client = fastmcp.Client(MCPConfig(mcpServers={server_name: server_config})) - self._mcp_servers[server_name] = MCPServerInfo( - status="pending", client=client, tools=[] - ) - - if in_background: - self._mcp_loading_task = asyncio.create_task(_connect()) - else: - await _connect() - - def has_pending_mcp_tools(self) -> bool: - """Return True if the background MCP tool-loading task is still running.""" - return self._mcp_loading_task is not None and not self._mcp_loading_task.done() - - async def wait_for_mcp_tools(self) -> None: - """Wait for background MCP tool loading to finish.""" - task = self._mcp_loading_task - if not task: - return - try: - await task - finally: - if self._mcp_loading_task is task and task.done(): - self._mcp_loading_task = None - - async def cleanup(self) -> None: - """Cleanup any resources held by the toolset.""" - self._deferred_mcp_load = None - if self._mcp_loading_task: - self._mcp_loading_task.cancel() - with contextlib.suppress(Exception): - await self._mcp_loading_task - for server_info in self._mcp_servers.values(): - await server_info.client.close() - - -@dataclass(slots=True) -class MCPServerInfo: - status: Literal["pending", "connecting", "connected", "failed", "unauthorized"] - client: fastmcp.Client[Any] - tools: list[MCPTool[Any]] - - -class MCPTool[T: ClientTransport](CallableTool): - def __init__( - self, - server_name: str, - mcp_tool: mcp.Tool, - client: fastmcp.Client[T], - *, - runtime: Runtime, - **kwargs: Any, - ): - super().__init__( - name=mcp_tool.name, - description=( - f"This is an MCP (Model Context Protocol) tool from MCP server `{server_name}`.\n\n" - f"{mcp_tool.description or 'No description provided.'}" - ), - parameters=mcp_tool.inputSchema, - **kwargs, - ) - self._mcp_tool = mcp_tool - self._client = client - self._runtime = runtime - self._timeout = timedelta(milliseconds=runtime.config.mcp.client.tool_call_timeout_ms) - self._action_name = f"mcp:{mcp_tool.name}" - - async def __call__(self, *args: Any, **kwargs: Any) -> ToolReturnValue: - description = f"Call MCP tool `{self._mcp_tool.name}`." - result = await self._runtime.approval.request(self.name, self._action_name, description) - if not result: - return result.rejection_error() - - try: - async with self._client as client: - result = await client.call_tool( - self._mcp_tool.name, - kwargs, - timeout=self._timeout, - raise_on_error=False, - ) - return convert_mcp_tool_result(result) - except Exception as e: - # fastmcp raises `RuntimeError` on timeout and we cannot tell it from other errors - exc_msg = str(e).lower() - if "timeout" in exc_msg or "timed out" in exc_msg: - return ToolError( - message=( - f"Timeout while calling MCP tool `{self._mcp_tool.name}`. " - "You may explain to the user that the timeout config is set too low." - ), - brief="Timeout", - ) - raise - - -class WireExternalTool(CallableTool): - def __init__(self, *, name: str, description: str, parameters: dict[str, Any]) -> None: - super().__init__( - name=name, - description=description or "No description provided.", - parameters=parameters, - ) - - async def __call__(self, *args: Any, **kwargs: Any) -> ToolReturnValue: - tool_call = get_current_tool_call_or_none() - if tool_call is None: - return ToolError( - message="External tool calls must be invoked from a tool call context.", - brief="Invalid tool call", - ) - - from kimi_cli.soul import get_wire_or_none - - wire = get_wire_or_none() - if wire is None: - logger.error( - "Wire is not available for external tool call: {tool_name}", tool_name=self.name - ) - return ToolError( - message="Wire is not available for external tool calls.", - brief="Wire unavailable", - ) - - external_tool_call = ToolCallRequest.from_tool_call(tool_call) - wire.soul_side.send(external_tool_call) - try: - return await external_tool_call.wait() - except asyncio.CancelledError: - raise - except Exception as e: - logger.exception("External tool call failed: {tool_name}:", tool_name=self.name) - return ToolError( - message=f"External tool call failed: {e}", - brief="External tool error", - ) - - -def convert_mcp_tool_result(result: CallToolResult) -> ToolReturnValue: - """Convert MCP tool result to kosong tool return value. - - Raises: - ValueError: If any content part has unsupported type or mime type. - """ - content: list[ContentPart] = [] - for part in result.content: - content.append(convert_mcp_content(part)) - if result.is_error: - return ToolError( - output=content, - message="Tool returned an error. The output may be error message or incomplete output", - brief="", - ) - else: - return ToolOk(output=content) diff --git a/src/kimi_cli/soul/toolset.ts b/src/kimi_cli/soul/toolset.ts new file mode 100644 index 000000000..1dc3be49b --- /dev/null +++ b/src/kimi_cli/soul/toolset.ts @@ -0,0 +1,193 @@ +/** + * Toolset — corresponds to Python soul/toolset.py + * Extended tool registry with hook integration, wire event emission, + * currentToolCall tracking, and sessionId context. + */ + +import type { CallableTool } from "../tools/base.ts"; +import { ToolRegistry } from "../tools/registry.ts"; +import type { ToolContext, ToolResult } from "../tools/types.ts"; +import type { HookEngine } from "../hooks/engine.ts"; +import type { ToolCall } from "../types.ts"; +import { logger } from "../utils/logging.ts"; + +// ── Context variables (module-level singletons) ────── + +let _currentToolCall: ToolCall | null = null; +let _currentSessionId = ""; + +/** Set the current session ID for tool call context. */ +export function setSessionId(sid: string): void { + _currentSessionId = sid; +} + +/** Get the current session ID. */ +export function getSessionId(): string { + return _currentSessionId; +} + +/** Get the current tool call, or null if not in a tool execution. */ +export function getCurrentToolCallOrNull(): ToolCall | null { + return _currentToolCall; +} + +export interface ToolsetOptions { + context: ToolContext; + hookEngine?: HookEngine; + onToolCall?: (toolCall: ToolCall) => void; + onToolResult?: (toolCallId: string, result: ToolResult) => void; +} + +export class KimiToolset { + private registry: ToolRegistry; + private hookEngine?: HookEngine; + private hiddenTools = new Set(); + private onToolCall?: (toolCall: ToolCall) => void; + private onToolResult?: (toolCallId: string, result: ToolResult) => void; + + constructor(opts: ToolsetOptions) { + this.registry = new ToolRegistry(opts.context); + this.hookEngine = opts.hookEngine; + this.onToolCall = opts.onToolCall; + this.onToolResult = opts.onToolResult; + } + + get context(): ToolContext { + return this.registry.context; + } + + // ── Tool management ───────────────────────────── + + add(tool: CallableTool): void { + this.registry.register(tool); + } + + find(name: string): CallableTool | undefined { + return this.registry.find(name); + } + + list(): CallableTool[] { + return this.registry.list(); + } + + hide(toolName: string): void { + this.hiddenTools.add(toolName); + } + + unhide(toolName: string): void { + this.hiddenTools.delete(toolName); + } + + /** Get tool definitions for LLM, excluding hidden tools. */ + definitions(): Array<{ + name: string; + description: string; + parameters: Record; + }> { + return this.registry + .list() + .filter((t) => !this.hiddenTools.has(t.name)) + .map((t) => t.toDefinition()); + } + + // ── Tool execution with hooks ──────────────────── + + async handle(toolCall: ToolCall): Promise { + const { id, name, arguments: argsStr } = toolCall; + + // Set current tool call context + const prevToolCall = _currentToolCall; + _currentToolCall = toolCall; + + try { + // Notify about tool call + this.onToolCall?.(toolCall); + + // Parse arguments + let args: Record; + try { + args = argsStr ? JSON.parse(argsStr) : {}; + } catch { + const result: ToolResult = { + isError: true, + output: "", + message: `Failed to parse arguments for tool "${name}": ${argsStr}`, + }; + this.onToolResult?.(id, result); + return result; + } + + // Run PreToolUse hook + if (this.hookEngine?.hasHooksFor("PreToolUse")) { + const hookResults = await this.hookEngine.trigger("PreToolUse", { + matcherValue: name, + inputData: { + session_id: _currentSessionId, + tool_name: name, + tool_input: args, + tool_call_id: id, + }, + }); + + for (const hr of hookResults) { + if (hr.action === "block") { + const result: ToolResult = { + isError: true, + output: "", + message: `Tool "${name}" blocked by hook: ${hr.reason}`, + }; + this.onToolResult?.(id, result); + return result; + } + } + } + + // Execute tool + let result: ToolResult; + try { + result = await this.registry.execute(name, args); + } catch (err) { + logger.error(`Tool "${name}" threw an error: ${err}`); + result = { + isError: true, + output: "", + message: `Tool "${name}" failed: ${err instanceof Error ? err.message : String(err)}`, + }; + } + + // Run PostToolUse / PostToolUseFailure hook (fire-and-forget) + if (this.hookEngine) { + const hookEvent = result.isError ? "PostToolUseFailure" : "PostToolUse"; + if (this.hookEngine.hasHooksFor(hookEvent as any)) { + this.hookEngine + .trigger(hookEvent as any, { + matcherValue: name, + inputData: { + session_id: _currentSessionId, + tool_name: name, + tool_input: args, + tool_output: (result.output ?? "").slice(0, 2000), + tool_error: result.isError ? result.message : undefined, + tool_call_id: id, + }, + }) + .catch(() => {}); // fire-and-forget + } + } + + // Notify about result + this.onToolResult?.(id, result); + + return result; + } finally { + // Restore previous tool call context + _currentToolCall = prevToolCall; + } + } + + // ── Cleanup ─────────────────────────────────────── + + async cleanup(): Promise { + // Cleanup MCP connections, etc. (future) + } +} diff --git a/src/kimi_cli/subagents/__init__.py b/src/kimi_cli/subagents/__init__.py deleted file mode 100644 index 0aae431a7..000000000 --- a/src/kimi_cli/subagents/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -from kimi_cli.subagents.models import ( - AgentInstanceRecord, - AgentLaunchSpec, - AgentTypeDefinition, - SubagentStatus, - ToolPolicy, - ToolPolicyMode, -) -from kimi_cli.subagents.registry import LaborMarket -from kimi_cli.subagents.store import SubagentStore - -__all__ = [ - "AgentInstanceRecord", - "AgentLaunchSpec", - "AgentTypeDefinition", - "LaborMarket", - "SubagentStatus", - "SubagentStore", - "ToolPolicy", - "ToolPolicyMode", -] diff --git a/src/kimi_cli/subagents/builder.py b/src/kimi_cli/subagents/builder.py deleted file mode 100644 index ad2a28e99..000000000 --- a/src/kimi_cli/subagents/builder.py +++ /dev/null @@ -1,42 +0,0 @@ -from __future__ import annotations - -from kimi_cli.llm import clone_llm_with_model_alias -from kimi_cli.soul.agent import Agent, Runtime, load_agent -from kimi_cli.subagents.models import AgentLaunchSpec, AgentTypeDefinition - - -class SubagentBuilder: - def __init__(self, root_runtime: Runtime): - self._root_runtime = root_runtime - - async def build_builtin_instance( - self, - *, - agent_id: str, - type_def: AgentTypeDefinition, - launch_spec: AgentLaunchSpec, - ) -> Agent: - effective_model = self.resolve_effective_model(type_def=type_def, launch_spec=launch_spec) - llm_override = clone_llm_with_model_alias( - self._root_runtime.llm, - self._root_runtime.config, - effective_model, - session_id=self._root_runtime.session.id, - oauth=self._root_runtime.oauth, - ) - runtime = self._root_runtime.copy_for_subagent( - agent_id=agent_id, - subagent_type=type_def.name, - llm_override=llm_override, - ) - return await load_agent( - type_def.agent_file, - runtime, - mcp_configs=[], - ) - - @staticmethod - def resolve_effective_model( - *, type_def: AgentTypeDefinition, launch_spec: AgentLaunchSpec - ) -> str | None: - return launch_spec.model_override or launch_spec.effective_model or type_def.default_model diff --git a/src/kimi_cli/subagents/builder.ts b/src/kimi_cli/subagents/builder.ts new file mode 100644 index 000000000..dabe34681 --- /dev/null +++ b/src/kimi_cli/subagents/builder.ts @@ -0,0 +1,23 @@ +/** + * Subagent builder — corresponds to Python subagents/builder.py + * Constructs subagent instances from type definitions. + */ + +import type { AgentLaunchSpec, AgentTypeDefinition } from "./models.ts"; + +export class SubagentBuilder { + /** + * Determine the effective model for a subagent launch. + * Priority: launch spec override > launch spec effective > type definition default. + */ + static resolveEffectiveModel(opts: { + typeDef: AgentTypeDefinition; + launchSpec: AgentLaunchSpec; + }): string | undefined { + return ( + opts.launchSpec.modelOverride ?? + opts.launchSpec.effectiveModel ?? + opts.typeDef.defaultModel + ); + } +} diff --git a/src/kimi_cli/subagents/core.py b/src/kimi_cli/subagents/core.py deleted file mode 100644 index 183e5c544..000000000 --- a/src/kimi_cli/subagents/core.py +++ /dev/null @@ -1,86 +0,0 @@ -"""Shared core logic for preparing a subagent soul. - -Both ``ForegroundSubagentRunner`` and ``BackgroundAgentRunner`` delegate -the repetitive build-restore-prompt pipeline to :func:`prepare_soul` so -that prompt enhancements (e.g. git context injection) only need to be -implemented once. -""" - -from __future__ import annotations - -from collections.abc import Callable -from dataclasses import dataclass, replace -from typing import TYPE_CHECKING - -from kimi_cli.soul.context import Context -from kimi_cli.soul.kimisoul import KimiSoul -from kimi_cli.subagents.builder import SubagentBuilder -from kimi_cli.subagents.models import AgentLaunchSpec, AgentTypeDefinition -from kimi_cli.subagents.store import SubagentStore - -if TYPE_CHECKING: - from kimi_cli.soul.agent import Runtime - - -@dataclass(frozen=True, slots=True, kw_only=True) -class SubagentRunSpec: - """Everything needed to prepare a soul, without lifecycle concerns.""" - - agent_id: str - type_def: AgentTypeDefinition - launch_spec: AgentLaunchSpec - prompt: str - resumed: bool - - -async def prepare_soul( - spec: SubagentRunSpec, - runtime: Runtime, - builder: SubagentBuilder, - store: SubagentStore, - on_stage: Callable[[str], None] | None = None, -) -> tuple[KimiSoul, str]: - """Build agent, restore context, handle system prompt, write prompt file. - - Returns ``(soul, final_prompt)`` ready for execution via - :func:`run_with_summary_continuation`. - """ - - # 1. Build agent from type definition - agent = await builder.build_builtin_instance( - agent_id=spec.agent_id, - type_def=spec.type_def, - launch_spec=spec.launch_spec, - ) - if on_stage: - on_stage("agent_built") - - # 2. Restore conversation context - context = Context(store.context_path(spec.agent_id)) - await context.restore() - if on_stage: - on_stage("context_restored") - - # 3. System prompt: reuse persisted prompt on resume, persist on first run - if context.system_prompt is not None: - agent = replace(agent, system_prompt=context.system_prompt) - else: - await context.write_system_prompt(agent.system_prompt) - if on_stage: - on_stage("context_ready") - - # 4. For new (non-resumed) explore agents, prepend git context to the prompt - prompt = spec.prompt - if spec.type_def.name == "explore" and not spec.resumed: - from kimi_cli.subagents.git_context import collect_git_context - - git_ctx = await collect_git_context(runtime.builtin_args.KIMI_WORK_DIR) - if git_ctx: - prompt = f"{git_ctx}\n\n{prompt}" - - # 5. Write prompt snapshot (debugging aid) - store.prompt_path(spec.agent_id).write_text(prompt, encoding="utf-8") - - # 6. Create soul - soul = KimiSoul(agent, context=context) - return soul, prompt diff --git a/src/kimi_cli/subagents/core.ts b/src/kimi_cli/subagents/core.ts new file mode 100644 index 000000000..b26eba2a2 --- /dev/null +++ b/src/kimi_cli/subagents/core.ts @@ -0,0 +1,13 @@ +/** + * Subagent run spec and prepare_soul pipeline — corresponds to Python subagents/core.py + */ + +import type { AgentLaunchSpec, AgentTypeDefinition } from "./models.ts"; + +export interface SubagentRunSpec { + readonly agentId: string; + readonly typeDef: AgentTypeDefinition; + readonly launchSpec: AgentLaunchSpec; + readonly prompt: string; + readonly resumed: boolean; +} diff --git a/src/kimi_cli/subagents/git_context.py b/src/kimi_cli/subagents/git_context.py deleted file mode 100644 index 23afd0464..000000000 --- a/src/kimi_cli/subagents/git_context.py +++ /dev/null @@ -1,170 +0,0 @@ -"""Collect git repository context for explore subagents.""" - -from __future__ import annotations - -import asyncio -import re -from urllib.parse import urlparse - -import kaos -from kaos.path import KaosPath - -from kimi_cli.utils.logging import logger - -_TIMEOUT = 5.0 -_MAX_DIRTY_FILES = 20 - - -async def collect_git_context(work_dir: KaosPath) -> str: - """Collect git context information for the explore agent. - - Returns a formatted ```` block, or an empty string if the - directory is not a git repository or all git commands fail. Every git - command is individually guarded so a single failure never breaks the whole - collection. - """ - cwd = str(work_dir) - - # Quick check: is this a git repo? - if await _run_git(["rev-parse", "--is-inside-work-tree"], cwd) is None: - return "" - - # Run all git commands in parallel for speed - remote_url, branch, dirty_raw, log_raw = await asyncio.gather( - _run_git(["remote", "get-url", "origin"], cwd), - _run_git(["branch", "--show-current"], cwd), - _run_git(["status", "--porcelain"], cwd), - _run_git(["log", "-3", "--format=%h %s"], cwd), - ) - - sections: list[str] = [] - sections.append(f"Working directory: {cwd}") - - # Remote origin & project name - if remote_url: - safe_url = _sanitize_remote_url(remote_url) - if safe_url: - sections.append(f"Remote: {safe_url}") - project = _parse_project_name(remote_url) - if project: - sections.append(f"Project: {project}") - - # Current branch - if branch: - sections.append(f"Branch: {branch}") - - # Dirty files - if dirty_raw is not None: - dirty_lines = [line for line in dirty_raw.splitlines() if line.strip()] - if dirty_lines: - total = len(dirty_lines) - shown = dirty_lines[:_MAX_DIRTY_FILES] - header = f"Dirty files ({total}):" - body = "\n".join(f" {line}" for line in shown) - if total > _MAX_DIRTY_FILES: - body += f"\n ... and {total - _MAX_DIRTY_FILES} more" - sections.append(f"{header}\n{body}") - - # Recent commits - if log_raw: - log_lines = [line for line in log_raw.splitlines() if line.strip()] - if log_lines: - body = "\n".join(f" {line[:200]}" for line in log_lines) - sections.append(f"Recent commits:\n{body}") - - if len(sections) <= 1: - # Only the working directory line — nothing useful collected - return "" - - content = "\n".join(sections) - return f"\n{content}\n" - - -async def _run_git(args: list[str], cwd: str, timeout: float = _TIMEOUT) -> str | None: - """Run a single git command via kaos.exec and return stripped stdout, or None on failure. - - Uses ``git -C `` so the command runs in the specified directory - regardless of the kaos backend's current working directory. Works - transparently on both local and remote (SSH) backends. - """ - proc = None - try: - proc = await kaos.exec("git", "-C", cwd, *args) - proc.stdin.close() - stdout_bytes = await asyncio.wait_for(proc.stdout.read(-1), timeout=timeout) - exit_code = await asyncio.wait_for(proc.wait(), timeout=timeout) - if exit_code != 0: - return None - return stdout_bytes.decode("utf-8", errors="replace").strip() - except TimeoutError: - logger.debug("git {args} timed out after {t}s", args=args, t=timeout) - if proc is not None: - await proc.kill() - await proc.wait() - return None - except Exception: - logger.debug("git {args} failed", args=args) - if proc is not None and proc.returncode is None: - await proc.kill() - await proc.wait() - return None - - -# Well-known public hosts whose remote URLs are safe to surface and -# recognizable enough for the model to infer project ecosystem context. -_ALLOWED_HOSTS = ( - "github.com", - "gitlab.com", - "gitee.com", - "bitbucket.org", - "codeberg.org", - "sr.ht", -) - - -def _sanitize_remote_url(remote_url: str) -> str | None: - """Return the remote URL if it points to a well-known public host. - - Credentials are stripped from HTTPS URLs. - - Recognizable remote URLs help orient the agent within the broader project - ecosystem (e.g. issue tracker conventions, CI patterns). Self-hosted or - unrecognized hosts are excluded to avoid leaking internal infrastructure. - """ - # SSH format: git@host:owner/repo.git — no credentials possible - for host in _ALLOWED_HOSTS: - if re.match(rf"^git@{re.escape(host)}:", remote_url): - return remote_url - - # HTTPS format: parse hostname exactly, strip userinfo - try: - parsed = urlparse(remote_url) - _ = parsed.port # raises ValueError on malformed port like :443.evil - except ValueError: - return None - if parsed.hostname in _ALLOWED_HOSTS: - # Rebuild without userinfo: https://host[:port]/path - port_part = f":{parsed.port}" if parsed.port else "" - return f"https://{parsed.hostname}{port_part}{parsed.path}" - - return None - - -def _parse_project_name(remote_url: str) -> str | None: - """Extract ``owner/repo`` from a git remote URL. - - Supports typical SSH (e.g. ``git@github.com:owner/repo.git``, - ``git@gitlab.com:owner/repo.git``) and HTTPS (e.g. - ``https://github.com/owner/repo.git``, - ``https://gitee.com/owner/repo.git``) formats by taking the - trailing ``owner/repo`` component regardless of host. - """ - # SSH format: git@host:owner/repo.git - m = re.search(r":([^/]+/[^/]+?)(?:\.git)?$", remote_url) - if m: - return m.group(1) - # HTTPS format: https://host/owner/repo.git - m = re.search(r"/([^/]+/[^/]+?)(?:\.git)?$", remote_url) - if m: - return m.group(1) - return None diff --git a/src/kimi_cli/subagents/git_context.ts b/src/kimi_cli/subagents/git_context.ts new file mode 100644 index 000000000..c0eb09aa4 --- /dev/null +++ b/src/kimi_cli/subagents/git_context.ts @@ -0,0 +1,127 @@ +/** + * Git context collection for explore subagents — corresponds to Python subagents/git_context.py + * Collects git repository metadata (remote URL, branch, dirty files, recent commits). + */ + +import { logger } from "../utils/logging.ts"; + +const TIMEOUT = 5000; // ms +const MAX_DIRTY_FILES = 20; + +const ALLOWED_HOSTS = [ + "github.com", + "gitlab.com", + "gitee.com", + "bitbucket.org", + "codeberg.org", + "sr.ht", +]; + +export async function collectGitContext(workDir: string): Promise { + // Quick check: is this a git repo? + if ((await runGit(["rev-parse", "--is-inside-work-tree"], workDir)) == null) { + return ""; + } + + // Run all git commands in parallel + const [remoteUrl, branch, dirtyRaw, logRaw] = await Promise.all([ + runGit(["remote", "get-url", "origin"], workDir), + runGit(["branch", "--show-current"], workDir), + runGit(["status", "--porcelain"], workDir), + runGit(["log", "-3", "--format=%h %s"], workDir), + ]); + + const sections: string[] = []; + sections.push(`Working directory: ${workDir}`); + + // Remote origin & project name + if (remoteUrl) { + const safeUrl = sanitizeRemoteUrl(remoteUrl); + if (safeUrl) sections.push(`Remote: ${safeUrl}`); + const project = parseProjectName(remoteUrl); + if (project) sections.push(`Project: ${project}`); + } + + // Current branch + if (branch) sections.push(`Branch: ${branch}`); + + // Dirty files + if (dirtyRaw != null) { + const dirtyLines = dirtyRaw.split("\n").filter((l) => l.trim()); + if (dirtyLines.length > 0) { + const total = dirtyLines.length; + const shown = dirtyLines.slice(0, MAX_DIRTY_FILES); + const header = `Dirty files (${total}):`; + let body = shown.map((l) => ` ${l}`).join("\n"); + if (total > MAX_DIRTY_FILES) { + body += `\n ... and ${total - MAX_DIRTY_FILES} more`; + } + sections.push(`${header}\n${body}`); + } + } + + // Recent commits + if (logRaw) { + const logLines = logRaw.split("\n").filter((l) => l.trim()); + if (logLines.length > 0) { + const body = logLines.map((l) => ` ${l.slice(0, 200)}`).join("\n"); + sections.push(`Recent commits:\n${body}`); + } + } + + if (sections.length <= 1) return ""; + const content = sections.join("\n"); + return `\n${content}\n`; +} + +async function runGit(args: string[], cwd: string): Promise { + try { + const proc = Bun.spawn(["git", "-C", cwd, ...args], { + stdout: "pipe", + stderr: "pipe", + stdin: "ignore", + }); + + const timer = setTimeout(() => proc.kill(), TIMEOUT); + const exitCode = await proc.exited; + clearTimeout(timer); + + if (exitCode !== 0) return undefined; + const stdout = await new Response(proc.stdout).text(); + return stdout.trim(); + } catch { + logger.debug(`git ${args.join(" ")} failed`); + return undefined; + } +} + +function sanitizeRemoteUrl(remoteUrl: string): string | undefined { + // SSH format: git@host:owner/repo.git + for (const host of ALLOWED_HOSTS) { + const pattern = new RegExp(`^git@${host.replace(".", "\\.")}:`); + if (pattern.test(remoteUrl)) return remoteUrl; + } + + // HTTPS format + try { + const url = new URL(remoteUrl); + if (ALLOWED_HOSTS.includes(url.hostname)) { + const port = url.port ? `:${url.port}` : ""; + return `https://${url.hostname}${port}${url.pathname}`; + } + } catch { + // Not a valid URL + } + + return undefined; +} + +function parseProjectName(remoteUrl: string): string | undefined { + // SSH format: git@host:owner/repo.git + const sshMatch = remoteUrl.match(/:([^/]+\/[^/]+?)(?:\.git)?$/); + if (sshMatch) return sshMatch[1]; + // HTTPS format: https://host/owner/repo.git + const httpsMatch = remoteUrl.match(/\/([^/]+\/[^/]+?)(?:\.git)?$/); + if (httpsMatch) return httpsMatch[1]; + return undefined; +} diff --git a/src/kimi_cli/subagents/models.py b/src/kimi_cli/subagents/models.py deleted file mode 100644 index f31aa530f..000000000 --- a/src/kimi_cli/subagents/models.py +++ /dev/null @@ -1,54 +0,0 @@ -from __future__ import annotations - -import time -from dataclasses import dataclass, field -from pathlib import Path -from typing import Literal - -type ToolPolicyMode = Literal["inherit", "allowlist"] -type SubagentStatus = Literal[ - "idle", - "running_foreground", - "running_background", - "completed", - "failed", - "killed", -] - - -@dataclass(frozen=True, slots=True, kw_only=True) -class ToolPolicy: - mode: ToolPolicyMode - tools: tuple[str, ...] = () - - -@dataclass(frozen=True, slots=True, kw_only=True) -class AgentTypeDefinition: - name: str - description: str - agent_file: Path - when_to_use: str = "" - default_model: str | None = None - tool_policy: ToolPolicy = field(default_factory=lambda: ToolPolicy(mode="inherit")) - supports_background: bool = True - - -@dataclass(frozen=True, slots=True, kw_only=True) -class AgentLaunchSpec: - agent_id: str - subagent_type: str - model_override: str | None - effective_model: str | None - created_at: float = field(default_factory=time.time) - - -@dataclass(frozen=True, slots=True, kw_only=True) -class AgentInstanceRecord: - agent_id: str - subagent_type: str - status: SubagentStatus - description: str - created_at: float - updated_at: float - last_task_id: str | None - launch_spec: AgentLaunchSpec diff --git a/src/kimi_cli/subagents/models.ts b/src/kimi_cli/subagents/models.ts new file mode 100644 index 000000000..cdc156525 --- /dev/null +++ b/src/kimi_cli/subagents/models.ts @@ -0,0 +1,52 @@ +/** + * Subagent models — corresponds to Python subagents/models.py + */ + +export type ToolPolicyMode = "inherit" | "allowlist"; +export type SubagentStatus = + | "idle" + | "running_foreground" + | "running_background" + | "completed" + | "failed" + | "killed"; + +export interface ToolPolicy { + readonly mode: ToolPolicyMode; + readonly tools: readonly string[]; +} + +export interface AgentTypeDefinition { + readonly name: string; + readonly description: string; + readonly agentFile: string; + readonly whenToUse: string; + readonly defaultModel?: string; + readonly toolPolicy: ToolPolicy; + readonly supportsBackground: boolean; +} + +export interface AgentLaunchSpec { + readonly agentId: string; + readonly subagentType: string; + readonly modelOverride?: string; + readonly effectiveModel?: string; + readonly createdAt: number; +} + +export interface AgentInstanceRecord { + readonly agentId: string; + readonly subagentType: string; + readonly status: SubagentStatus; + readonly description: string; + readonly createdAt: number; + readonly updatedAt: number; + readonly lastTaskId?: string; + readonly launchSpec: AgentLaunchSpec; +} + +// ── Defaults ── + +export function defaultToolPolicy(): ToolPolicy { + return { mode: "inherit", tools: [] }; +} diff --git a/src/kimi_cli/subagents/output.py b/src/kimi_cli/subagents/output.py deleted file mode 100644 index 849ca585a..000000000 --- a/src/kimi_cli/subagents/output.py +++ /dev/null @@ -1,71 +0,0 @@ -"""Unified output writer for subagent executions (foreground and background).""" - -from __future__ import annotations - -from collections.abc import Sequence -from pathlib import Path - -from kosong.message import TextPart, ToolCall, ToolCallPart -from kosong.tooling import ToolResult - - -class SubagentOutputWriter: - """Appends human-readable transcript lines to one or more output files. - - Both foreground and background runners use this so the output format - is identical regardless of execution mode. When *extra_paths* are - provided every write is tee'd to those files as well (used by the - background agent runner to keep the task ``output.log`` in sync with - the canonical subagent output). - """ - - def __init__(self, path: Path, *, extra_paths: Sequence[Path] = ()) -> None: - self._path = path - self._extra_paths = extra_paths - - def stage(self, name: str) -> None: - self._append(f"[stage] {name}\n") - - def tool_call(self, tc: ToolCall) -> None: - name = tc.function.name if tc.function else "?" - self._append(f"[tool] {name}\n") - - def tool_result(self, tr: ToolResult) -> None: - status = "error" if tr.return_value.is_error else "success" - brief = getattr(tr.return_value, "brief", None) - if brief: - self._append(f"[tool_result] {status}: {brief}\n") - else: - self._append(f"[tool_result] {status}\n") - - def text(self, text: str) -> None: - if text: - self._append(text) - - def summary(self, text: str) -> None: - if text: - self._append(f"\n[summary]\n{text}\n") - - def error(self, message: str) -> None: - self._append(f"[error] {message}\n") - - def write_wire_message(self, msg: object) -> None: - """Dispatch a wire message to the appropriate writer method.""" - if isinstance(msg, TextPart): - self.text(msg.text) - elif isinstance(msg, ToolCall): - self.tool_call(msg) - elif isinstance(msg, ToolResult): - self.tool_result(msg) - elif isinstance(msg, ToolCallPart): - pass # incremental argument chunks — not useful in transcript - - def _append(self, text: str) -> None: - with self._path.open("a", encoding="utf-8") as f: - f.write(text) - for p in self._extra_paths: - try: - with p.open("a", encoding="utf-8") as f: - f.write(text) - except OSError: - pass # best-effort — never interrupt the agent for a tee failure diff --git a/src/kimi_cli/subagents/output.ts b/src/kimi_cli/subagents/output.ts new file mode 100644 index 000000000..59d7d2c7e --- /dev/null +++ b/src/kimi_cli/subagents/output.ts @@ -0,0 +1,59 @@ +/** + * Subagent output writer — corresponds to Python subagents/output.py + * Appends human-readable transcript lines to output files. + */ + +import { appendFileSync } from "node:fs"; + +export class SubagentOutputWriter { + private _path: string; + private _extraPaths: string[]; + + constructor(path: string, extraPaths: string[] = []) { + this._path = path; + this._extraPaths = extraPaths; + } + + stage(name: string): void { + this.append(`[stage] ${name}\n`); + } + + toolCall(name: string): void { + this.append(`[tool] ${name}\n`); + } + + toolResult(status: "success" | "error", brief?: string): void { + if (brief) { + this.append(`[tool_result] ${status}: ${brief}\n`); + } else { + this.append(`[tool_result] ${status}\n`); + } + } + + text(text: string): void { + if (text) this.append(text); + } + + summary(text: string): void { + if (text) this.append(`\n[summary]\n${text}\n`); + } + + error(message: string): void { + this.append(`[error] ${message}\n`); + } + + private append(text: string): void { + try { + appendFileSync(this._path, text, "utf-8"); + } catch { + // Ignore write errors + } + for (const p of this._extraPaths) { + try { + appendFileSync(p, text, "utf-8"); + } catch { + // Best-effort tee + } + } + } +} diff --git a/src/kimi_cli/subagents/registry.py b/src/kimi_cli/subagents/registry.py deleted file mode 100644 index 82a6791b7..000000000 --- a/src/kimi_cli/subagents/registry.py +++ /dev/null @@ -1,28 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping - -from kimi_cli.subagents.models import AgentTypeDefinition - - -class LaborMarket: - """Registry of built-in subagent types.""" - - def __init__(self) -> None: - self._builtin_types: dict[str, AgentTypeDefinition] = {} - - @property - def builtin_types(self) -> Mapping[str, AgentTypeDefinition]: - return self._builtin_types - - def add_builtin_type(self, type_def: AgentTypeDefinition) -> None: - self._builtin_types[type_def.name] = type_def - - def get_builtin_type(self, name: str) -> AgentTypeDefinition | None: - return self._builtin_types.get(name) - - def require_builtin_type(self, name: str) -> AgentTypeDefinition: - type_def = self.get_builtin_type(name) - if type_def is None: - raise KeyError(f"Builtin subagent type not found: {name}") - return type_def diff --git a/src/kimi_cli/subagents/registry.ts b/src/kimi_cli/subagents/registry.ts new file mode 100644 index 000000000..3463db11c --- /dev/null +++ b/src/kimi_cli/subagents/registry.ts @@ -0,0 +1,30 @@ +/** + * Subagent type registry — corresponds to Python subagents/registry.py + * LaborMarket holds the available agent type definitions. + */ + +import type { AgentTypeDefinition } from "./models.ts"; + +export class LaborMarket { + private _builtinTypes = new Map(); + + get builtinTypes(): ReadonlyMap { + return this._builtinTypes; + } + + addBuiltinType(typeDef: AgentTypeDefinition): void { + this._builtinTypes.set(typeDef.name, typeDef); + } + + getBuiltinType(name: string): AgentTypeDefinition | undefined { + return this._builtinTypes.get(name); + } + + requireBuiltinType(name: string): AgentTypeDefinition { + const typeDef = this._builtinTypes.get(name); + if (!typeDef) { + throw new Error(`Builtin subagent type not found: ${name}`); + } + return typeDef; + } +} diff --git a/src/kimi_cli/subagents/runner.py b/src/kimi_cli/subagents/runner.py deleted file mode 100644 index 36a0426b7..000000000 --- a/src/kimi_cli/subagents/runner.py +++ /dev/null @@ -1,370 +0,0 @@ -from __future__ import annotations - -import asyncio -import uuid -from dataclasses import dataclass, replace -from pathlib import Path -from typing import TYPE_CHECKING - -from kosong.tooling import ToolError, ToolOk, ToolReturnValue - -from kimi_cli.approval_runtime import ( - ApprovalSource, - reset_current_approval_source, - set_current_approval_source, -) -from kimi_cli.soul import MaxStepsReached, UILoopFn, get_wire_or_none, run_soul -from kimi_cli.soul.kimisoul import KimiSoul -from kimi_cli.soul.toolset import get_current_tool_call_or_none -from kimi_cli.subagents.builder import SubagentBuilder -from kimi_cli.subagents.core import SubagentRunSpec, prepare_soul -from kimi_cli.subagents.models import AgentInstanceRecord, AgentLaunchSpec -from kimi_cli.subagents.output import SubagentOutputWriter -from kimi_cli.subagents.store import SubagentStore -from kimi_cli.wire import Wire -from kimi_cli.wire.file import WireFile -from kimi_cli.wire.types import ( - ApprovalRequest, - ApprovalResponse, - HookRequest, - QuestionRequest, - SubagentEvent, - ToolCallRequest, -) - -if TYPE_CHECKING: - from kimi_cli.soul.agent import Runtime - -SUMMARY_MIN_LENGTH = 200 -SUMMARY_CONTINUATION_ATTEMPTS = 1 -SUMMARY_CONTINUATION_PROMPT = """ -Your previous response was too brief. Please provide a more comprehensive summary that includes: - -1. Specific technical details and implementations -2. Detailed findings and analysis -3. All important information that the parent agent should know -""".strip() - - -# --------------------------------------------------------------------------- -# Shared result types and execution helpers (used by both foreground and -# background runners). -# --------------------------------------------------------------------------- - - -@dataclass(frozen=True, slots=True, kw_only=True) -class SoulRunFailure: - """Describes why a soul run did not produce a usable result.""" - - message: str - brief: str - - -async def run_soul_checked( - soul: KimiSoul, - prompt: str, - ui_loop_fn: UILoopFn, - wire_path: Path, - phase: str, -) -> SoulRunFailure | None: - """Run a single soul turn and validate the result. - - Returns a ``SoulRunFailure`` if the run failed or produced an invalid - result, or ``None`` on success. ``MaxStepsReached`` is converted to a - failure; ``CancelledError`` and other exceptions are re-raised. - """ - try: - await run_soul( - soul, - prompt, - ui_loop_fn, - asyncio.Event(), - wire_file=WireFile(wire_path), - runtime=soul.runtime, - ) - except MaxStepsReached as exc: - return SoulRunFailure( - message=( - f"Max steps {exc.n_steps} reached when {phase}. " - "Please try splitting the task into smaller subtasks." - ), - brief="Max steps reached", - ) - - context = soul.context - if not context.history or context.history[-1].role != "assistant": - return SoulRunFailure( - message="The agent did not produce a valid assistant response.", - brief="Invalid agent result", - ) - return None - - -async def run_with_summary_continuation( - soul: KimiSoul, - prompt: str, - ui_loop_fn: UILoopFn, - wire_path: Path, -) -> tuple[str | None, SoulRunFailure | None]: - """Run soul, then optionally extend the summary if it is too short. - - Returns ``(final_response, failure)``. On success ``failure`` is - ``None`` and ``final_response`` contains the agent's output text. - On failure ``final_response`` is ``None``. - """ - failure = await run_soul_checked(soul, prompt, ui_loop_fn, wire_path, "running agent") - if failure is not None: - return None, failure - - final_response = soul.context.history[-1].extract_text(sep="\n") - remaining = SUMMARY_CONTINUATION_ATTEMPTS - while remaining > 0 and len(final_response) < SUMMARY_MIN_LENGTH: - remaining -= 1 - failure = await run_soul_checked( - soul, - SUMMARY_CONTINUATION_PROMPT, - ui_loop_fn, - wire_path, - "continuing the agent summary", - ) - if failure is not None: - return None, failure - final_response = soul.context.history[-1].extract_text(sep="\n") - - return final_response, None - - -# --------------------------------------------------------------------------- -# Foreground runner -# --------------------------------------------------------------------------- - - -@dataclass(frozen=True, slots=True, kw_only=True) -class ForegroundRunRequest: - description: str - prompt: str - requested_type: str - model: str | None - resume: str | None - - -@dataclass(frozen=True, slots=True, kw_only=True) -class PreparedInstance: - record: AgentInstanceRecord - actual_type: str - resumed: bool - - -class ForegroundSubagentRunner: - def __init__(self, runtime: Runtime): - self._runtime = runtime - assert runtime.subagent_store is not None - self._store: SubagentStore = runtime.subagent_store - self._builder = SubagentBuilder(runtime) - - async def run(self, req: ForegroundRunRequest) -> ToolReturnValue: - prepared = await self._prepare_instance(req) - agent_id = prepared.record.agent_id - actual_type = prepared.actual_type - resumed = prepared.resumed - - type_def = self._runtime.labor_market.require_builtin_type(actual_type) - launch_spec = prepared.record.launch_spec - if req.model is not None: - launch_spec = replace( - launch_spec, - model_override=req.model, - effective_model=req.model, - ) - - output_writer = SubagentOutputWriter(self._store.output_path(agent_id)) - output_writer.stage("runner_started") - - spec = SubagentRunSpec( - agent_id=agent_id, - type_def=type_def, - launch_spec=launch_spec, - prompt=req.prompt, - resumed=resumed, - ) - soul, prompt = await prepare_soul( - spec, - self._runtime, - self._builder, - self._store, - on_stage=output_writer.stage, - ) - - self._store.update_instance( - agent_id, - status="running_foreground", - description=req.description.strip(), - ) - # Propagate hook engine from parent runtime to subagent soul - if self._runtime.hook_engine is not None: - soul.set_hook_engine(self._runtime.hook_engine) - tool_call = get_current_tool_call_or_none() - ui_loop_fn = self._make_ui_loop_fn( - parent_tool_call_id=tool_call.id if tool_call is not None else None, - agent_id=agent_id, - subagent_type=actual_type, - output_writer=output_writer, - ) - - # Use a single stable ApprovalSource for the entire run (including summary - # continuation). This ensures cancel_by_source can reliably cancel all - # pending approval requests belonging to this foreground subagent execution. - approval_source = ApprovalSource( - kind="foreground_turn", - id=uuid.uuid4().hex, - agent_id=agent_id, - subagent_type=actual_type, - ) - approval_source_token = set_current_approval_source(approval_source) - try: - # --- SubagentStart hook --- - hook_engine = soul.hook_engine - from kimi_cli.hooks import events as hook_events - - await hook_engine.trigger( - "SubagentStart", - matcher_value=actual_type, - input_data=hook_events.subagent_start( - session_id=self._runtime.session.id, - cwd=str(Path.cwd()), - agent_name=actual_type, - prompt=req.prompt[:500], - ), - ) - - output_writer.stage("run_soul_start") - final_response, failure = await run_with_summary_continuation( - soul, - prompt, - ui_loop_fn, - self._store.wire_path(agent_id), - ) - if failure is not None: - self._store.update_instance(agent_id, status="failed") - output_writer.stage(f"failed: {failure.brief}") - return ToolError(message=failure.message, brief=failure.brief) - output_writer.stage("run_soul_finished") - - # --- SubagentStop hook --- - _hook_task = asyncio.create_task( - hook_engine.trigger( - "SubagentStop", - matcher_value=actual_type, - input_data=hook_events.subagent_stop( - session_id=self._runtime.session.id, - cwd=str(Path.cwd()), - agent_name=actual_type, - response=(final_response or "")[:500], - ), - ) - ) - _hook_task.add_done_callback(lambda t: t.exception() if not t.cancelled() else None) - except asyncio.CancelledError: - self._store.update_instance(agent_id, status="killed") - output_writer.stage("cancelled") - raise - except Exception: - self._store.update_instance(agent_id, status="failed") - output_writer.stage("failed_exception") - raise - finally: - reset_current_approval_source(approval_source_token) - if self._runtime.approval_runtime is not None: - self._runtime.approval_runtime.cancel_by_source( - approval_source.kind, approval_source.id - ) - - assert final_response is not None - self._store.update_instance(agent_id, status="idle") - output_writer.summary(final_response) - lines = [ - f"agent_id: {agent_id}", - "resumed: true" if resumed else "resumed: false", - ] - if resumed and req.requested_type and req.requested_type != actual_type: - lines.append(f"requested_subagent_type: {req.requested_type}") - lines.extend( - [ - f"actual_subagent_type: {actual_type}", - "status: completed", - "", - "[summary]", - final_response, - ] - ) - return ToolOk(output="\n".join(lines)) - - async def _prepare_instance(self, req: ForegroundRunRequest) -> PreparedInstance: - if req.resume: - record = self._store.require_instance(req.resume) - if record.status in {"running_foreground", "running_background"}: - raise RuntimeError( - f"Agent instance {record.agent_id} is still {record.status} and cannot be " - "resumed concurrently." - ) - return PreparedInstance( - record=record, - actual_type=record.subagent_type, - resumed=True, - ) - - actual_type = req.requested_type or "coder" - type_def = self._runtime.labor_market.require_builtin_type(actual_type) - agent_id = f"a{uuid.uuid4().hex[:8]}" - record = self._store.create_instance( - agent_id=agent_id, - description=req.description.strip(), - launch_spec=AgentLaunchSpec( - agent_id=agent_id, - subagent_type=actual_type, - model_override=req.model, - effective_model=req.model or type_def.default_model, - ), - ) - return PreparedInstance( - record=record, - actual_type=actual_type, - resumed=False, - ) - - @staticmethod - def _make_ui_loop_fn( - *, - parent_tool_call_id: str | None, - agent_id: str, - subagent_type: str, - output_writer: SubagentOutputWriter, - ): - super_wire = get_wire_or_none() - - async def _ui_loop_fn(wire: Wire) -> None: - wire_ui = wire.ui_side(merge=True) - while True: - msg = await wire_ui.receive() - # Always write to output file regardless of wire availability. - output_writer.write_wire_message(msg) - if super_wire is None or parent_tool_call_id is None: - continue - if isinstance( - msg, - ApprovalRequest | ApprovalResponse | ToolCallRequest | QuestionRequest, - ): - super_wire.soul_side.send(msg) - continue - if isinstance(msg, HookRequest): - continue - super_wire.soul_side.send( - SubagentEvent( - parent_tool_call_id=parent_tool_call_id, - agent_id=agent_id, - subagent_type=subagent_type, - event=msg, - ) - ) - - return _ui_loop_fn diff --git a/src/kimi_cli/subagents/runner.ts b/src/kimi_cli/subagents/runner.ts new file mode 100644 index 000000000..d6d8970f5 --- /dev/null +++ b/src/kimi_cli/subagents/runner.ts @@ -0,0 +1,31 @@ +/** + * Foreground subagent runner — corresponds to Python subagents/runner.py + * Manages the lifecycle of foreground subagent executions. + */ + +export interface ForegroundRunRequest { + readonly description: string; + readonly prompt: string; + readonly requestedType: string; + readonly model?: string; + readonly resume?: string; +} + +export interface PreparedInstance { + readonly agentId: string; + readonly actualType: string; + readonly resumed: boolean; +} + +export interface SoulRunFailure { + readonly message: string; + readonly brief: string; +} + +export const SUMMARY_MIN_LENGTH = 200; +export const SUMMARY_CONTINUATION_ATTEMPTS = 1; +export const SUMMARY_CONTINUATION_PROMPT = `Your previous response was too brief. Please provide a more comprehensive summary that includes: + +1. Specific technical details and implementations +2. Detailed findings and analysis +3. All important information that the parent agent should know`; diff --git a/src/kimi_cli/subagents/store.py b/src/kimi_cli/subagents/store.py deleted file mode 100644 index ac5285568..000000000 --- a/src/kimi_cli/subagents/store.py +++ /dev/null @@ -1,148 +0,0 @@ -from __future__ import annotations - -import shutil -from dataclasses import asdict -from pathlib import Path -from typing import Any, cast - -from kimi_cli.session import Session -from kimi_cli.subagents.models import AgentInstanceRecord, AgentLaunchSpec, SubagentStatus -from kimi_cli.utils.io import atomic_json_write - - -def _record_from_dict(data: dict[str, Any]) -> AgentInstanceRecord: - launch_spec = data["launch_spec"] - return AgentInstanceRecord( - agent_id=data["agent_id"], - subagent_type=data["subagent_type"], - status=data["status"], - description=data["description"], - created_at=data["created_at"], - updated_at=data["updated_at"], - last_task_id=data.get("last_task_id"), - launch_spec=AgentLaunchSpec(**launch_spec), - ) - - -class SubagentStore: - def __init__(self, session: Session) -> None: - self._session = session - - @property - def root(self) -> Path: - return self._session.dir / "subagents" - - def instance_dir(self, agent_id: str, *, create: bool = False) -> Path: - path = self.root / agent_id - if create: - path.mkdir(parents=True, exist_ok=True) - return path - - def context_path(self, agent_id: str) -> Path: - return self.instance_dir(agent_id) / "context.jsonl" - - def wire_path(self, agent_id: str) -> Path: - return self.instance_dir(agent_id) / "wire.jsonl" - - def meta_path(self, agent_id: str) -> Path: - return self.instance_dir(agent_id) / "meta.json" - - def prompt_path(self, agent_id: str) -> Path: - return self.instance_dir(agent_id) / "prompt.txt" - - def output_path(self, agent_id: str) -> Path: - return self.instance_dir(agent_id) / "output" - - def create_instance( - self, - *, - agent_id: str, - description: str, - launch_spec: AgentLaunchSpec, - ) -> AgentInstanceRecord: - self._initialize_instance_files(agent_id) - record = AgentInstanceRecord( - agent_id=agent_id, - subagent_type=launch_spec.subagent_type, - status="idle", - description=description, - created_at=launch_spec.created_at, - updated_at=launch_spec.created_at, - last_task_id=None, - launch_spec=launch_spec, - ) - self.write_instance(record) - return record - - def write_instance(self, record: AgentInstanceRecord) -> None: - instance_dir = self.instance_dir(record.agent_id) - atomic_json_write(asdict(record), instance_dir / "meta.json") - - def _initialize_instance_files(self, agent_id: str) -> None: - instance_dir = self.instance_dir(agent_id, create=True) - (instance_dir / "context.jsonl").touch(exist_ok=True) - (instance_dir / "wire.jsonl").touch(exist_ok=True) - (instance_dir / "prompt.txt").touch(exist_ok=True) - (instance_dir / "output").touch(exist_ok=True) - - def get_instance(self, agent_id: str) -> AgentInstanceRecord | None: - meta = self.meta_path(agent_id) - if not meta.exists(): - return None - import json - - return _record_from_dict(json.loads(meta.read_text(encoding="utf-8"))) - - def require_instance(self, agent_id: str) -> AgentInstanceRecord: - record = self.get_instance(agent_id) - if record is None: - raise FileNotFoundError(f"Subagent instance not found: {agent_id}") - return record - - def update_instance( - self, - agent_id: str, - *, - status: SubagentStatus | None = None, - description: str | None = None, - last_task_id: str | None | object = ..., - ) -> AgentInstanceRecord: - import time - - current = self.require_instance(agent_id) - record = AgentInstanceRecord( - agent_id=current.agent_id, - subagent_type=current.subagent_type, - status=current.status if status is None else status, - description=current.description if description is None else description, - created_at=current.created_at, - updated_at=time.time(), - last_task_id=( - current.last_task_id if last_task_id is ... else cast(str | None, last_task_id) - ), - launch_spec=current.launch_spec, - ) - self.write_instance(record) - return record - - def list_instances(self) -> list[AgentInstanceRecord]: - records: list[AgentInstanceRecord] = [] - if not self.root.exists(): - return records - for path in self.root.iterdir(): - if not path.is_dir(): - continue - meta = path / "meta.json" - if not meta.exists(): - continue - import json - - records.append(_record_from_dict(json.loads(meta.read_text(encoding="utf-8")))) - records.sort(key=lambda record: record.updated_at, reverse=True) - return records - - def delete_instance(self, agent_id: str) -> None: - instance_dir = self.instance_dir(agent_id) - if not instance_dir.exists(): - return - shutil.rmtree(instance_dir) diff --git a/src/kimi_cli/subagents/store.ts b/src/kimi_cli/subagents/store.ts new file mode 100644 index 000000000..a59b34ff0 --- /dev/null +++ b/src/kimi_cli/subagents/store.ts @@ -0,0 +1,199 @@ +/** + * Subagent store — corresponds to Python subagents/store.py + * File-based persistence for subagent instance metadata. + */ + +import { join } from "node:path"; +import { + existsSync, + mkdirSync, + readFileSync, + writeFileSync, + readdirSync, + statSync, + rmSync, +} from "node:fs"; +import type { + AgentInstanceRecord, + AgentLaunchSpec, + SubagentStatus, +} from "./models.ts"; + +function recordFromJson(data: Record): AgentInstanceRecord { + const launchSpec = data.launch_spec as Record ?? data.launchSpec as Record ?? {}; + return { + agentId: String(data.agent_id ?? data.agentId ?? ""), + subagentType: String(data.subagent_type ?? data.subagentType ?? ""), + status: String(data.status ?? "idle") as SubagentStatus, + description: String(data.description ?? ""), + createdAt: Number(data.created_at ?? data.createdAt ?? 0), + updatedAt: Number(data.updated_at ?? data.updatedAt ?? 0), + lastTaskId: (data.last_task_id ?? data.lastTaskId) as string | undefined, + launchSpec: { + agentId: String(launchSpec.agent_id ?? launchSpec.agentId ?? ""), + subagentType: String(launchSpec.subagent_type ?? launchSpec.subagentType ?? ""), + modelOverride: (launchSpec.model_override ?? launchSpec.modelOverride) as string | undefined, + effectiveModel: (launchSpec.effective_model ?? launchSpec.effectiveModel) as string | undefined, + createdAt: Number(launchSpec.created_at ?? launchSpec.createdAt ?? 0), + }, + }; +} + +function recordToJson(record: AgentInstanceRecord): Record { + return { + agent_id: record.agentId, + subagent_type: record.subagentType, + status: record.status, + description: record.description, + created_at: record.createdAt, + updated_at: record.updatedAt, + last_task_id: record.lastTaskId, + launch_spec: { + agent_id: record.launchSpec.agentId, + subagent_type: record.launchSpec.subagentType, + model_override: record.launchSpec.modelOverride, + effective_model: record.launchSpec.effectiveModel, + created_at: record.launchSpec.createdAt, + }, + }; +} + +export class SubagentStore { + private _root: string; + + constructor(root: string) { + this._root = root; + } + + get root(): string { + return this._root; + } + + instanceDir(agentId: string, create = false): string { + const path = join(this._root, agentId); + if (create && !existsSync(path)) { + mkdirSync(path, { recursive: true }); + } + return path; + } + + contextPath(agentId: string): string { + return join(this.instanceDir(agentId), "context.jsonl"); + } + + wirePath(agentId: string): string { + return join(this.instanceDir(agentId), "wire.jsonl"); + } + + metaPath(agentId: string): string { + return join(this.instanceDir(agentId), "meta.json"); + } + + promptPath(agentId: string): string { + return join(this.instanceDir(agentId), "prompt.txt"); + } + + outputPath(agentId: string): string { + return join(this.instanceDir(agentId), "output"); + } + + createInstance(opts: { + agentId: string; + description: string; + launchSpec: AgentLaunchSpec; + }): AgentInstanceRecord { + this.initializeInstanceFiles(opts.agentId); + const record: AgentInstanceRecord = { + agentId: opts.agentId, + subagentType: opts.launchSpec.subagentType, + status: "idle", + description: opts.description, + createdAt: opts.launchSpec.createdAt, + updatedAt: opts.launchSpec.createdAt, + launchSpec: opts.launchSpec, + }; + this.writeInstance(record); + return record; + } + + writeInstance(record: AgentInstanceRecord): void { + const dir = this.instanceDir(record.agentId, true); + const tmpPath = join(dir, "meta.json.tmp"); + const targetPath = join(dir, "meta.json"); + writeFileSync(tmpPath, JSON.stringify(recordToJson(record), null, 2), "utf-8"); + const { renameSync } = require("node:fs"); + renameSync(tmpPath, targetPath); + } + + private initializeInstanceFiles(agentId: string): void { + const dir = this.instanceDir(agentId, true); + for (const name of ["context.jsonl", "wire.jsonl", "prompt.txt", "output"]) { + const path = join(dir, name); + if (!existsSync(path)) { + writeFileSync(path, "", "utf-8"); + } + } + } + + getInstance(agentId: string): AgentInstanceRecord | undefined { + const meta = this.metaPath(agentId); + if (!existsSync(meta)) return undefined; + const data = JSON.parse(readFileSync(meta, "utf-8")); + return recordFromJson(data); + } + + requireInstance(agentId: string): AgentInstanceRecord { + const record = this.getInstance(agentId); + if (!record) { + throw new Error(`Subagent instance not found: ${agentId}`); + } + return record; + } + + updateInstance( + agentId: string, + opts?: { + status?: SubagentStatus; + description?: string; + lastTaskId?: string | null; + }, + ): AgentInstanceRecord { + const current = this.requireInstance(agentId); + const record: AgentInstanceRecord = { + agentId: current.agentId, + subagentType: current.subagentType, + status: opts?.status ?? current.status, + description: opts?.description ?? current.description, + createdAt: current.createdAt, + updatedAt: Date.now() / 1000, + lastTaskId: opts?.lastTaskId !== undefined ? (opts.lastTaskId ?? undefined) : current.lastTaskId, + launchSpec: current.launchSpec, + }; + this.writeInstance(record); + return record; + } + + listInstances(): AgentInstanceRecord[] { + if (!existsSync(this._root)) return []; + const records: AgentInstanceRecord[] = []; + for (const entry of readdirSync(this._root)) { + const dirPath = join(this._root, entry); + try { + if (!statSync(dirPath).isDirectory()) continue; + } catch { + continue; + } + const meta = join(dirPath, "meta.json"); + if (!existsSync(meta)) continue; + records.push(recordFromJson(JSON.parse(readFileSync(meta, "utf-8")))); + } + records.sort((a, b) => b.updatedAt - a.updatedAt); + return records; + } + + deleteInstance(agentId: string): void { + const dir = this.instanceDir(agentId); + if (!existsSync(dir)) return; + rmSync(dir, { recursive: true, force: true }); + } +} diff --git a/src/kimi_cli/tools/AGENTS.md b/src/kimi_cli/tools/AGENTS.md deleted file mode 100644 index 5e3f4d6cf..000000000 --- a/src/kimi_cli/tools/AGENTS.md +++ /dev/null @@ -1,5 +0,0 @@ -# Kimi Code CLI Tools - -## Guidelines - -- Tools should not refer to types in `kimi_cli/wire/` unless they are explicitly implementing a UI / runtime bridge. When importing things like `ToolReturnValue` or `DisplayBlock`, prefer `kosong.tooling`. diff --git a/src/kimi_cli/tools/__init__.py b/src/kimi_cli/tools/__init__.py deleted file mode 100644 index 371a9d498..000000000 --- a/src/kimi_cli/tools/__init__.py +++ /dev/null @@ -1,105 +0,0 @@ -import json -from typing import cast - -import streamingjson # type: ignore[reportMissingTypeStubs] -from kaos.path import KaosPath -from kosong.utils.typing import JsonType - -from kimi_cli.utils.string import shorten_middle - - -class SkipThisTool(Exception): - """Raised when a tool decides to skip itself from the loading process.""" - - pass - - -def extract_key_argument(json_content: str | streamingjson.Lexer, tool_name: str) -> str | None: - if isinstance(json_content, streamingjson.Lexer): - json_str = json_content.complete_json() - else: - json_str = json_content - try: - curr_args: JsonType = json.loads(json_str, strict=False) - except json.JSONDecodeError: - return None - if not curr_args: - return None - key_argument: str = "" - match tool_name: - case "Agent": - if not isinstance(curr_args, dict) or not curr_args.get("description"): - return None - key_argument = str(curr_args["description"]) - case "SendDMail": - return None - case "Think": - if not isinstance(curr_args, dict) or not curr_args.get("thought"): - return None - key_argument = str(curr_args["thought"]) - case "SetTodoList": - return None - case "Shell": - if not isinstance(curr_args, dict) or not curr_args.get("command"): - return None - key_argument = str(curr_args["command"]) - case "TaskOutput": - if not isinstance(curr_args, dict) or not curr_args.get("task_id"): - return None - key_argument = str(curr_args["task_id"]) - case "TaskList": - if not isinstance(curr_args, dict): - return None - key_argument = "active" if curr_args.get("active_only", True) else "all" - case "TaskStop": - if not isinstance(curr_args, dict) or not curr_args.get("task_id"): - return None - key_argument = str(curr_args["task_id"]) - case "ReadFile": - if not isinstance(curr_args, dict) or not curr_args.get("path"): - return None - key_argument = _normalize_path(str(curr_args["path"])) - case "ReadMediaFile": - if not isinstance(curr_args, dict) or not curr_args.get("path"): - return None - key_argument = _normalize_path(str(curr_args["path"])) - case "Glob": - if not isinstance(curr_args, dict) or not curr_args.get("pattern"): - return None - key_argument = str(curr_args["pattern"]) - case "Grep": - if not isinstance(curr_args, dict) or not curr_args.get("pattern"): - return None - key_argument = str(curr_args["pattern"]) - case "WriteFile": - if not isinstance(curr_args, dict) or not curr_args.get("path"): - return None - key_argument = _normalize_path(str(curr_args["path"])) - case "StrReplaceFile": - if not isinstance(curr_args, dict) or not curr_args.get("path"): - return None - key_argument = _normalize_path(str(curr_args["path"])) - case "SearchWeb": - if not isinstance(curr_args, dict) or not curr_args.get("query"): - return None - key_argument = str(curr_args["query"]) - case "FetchURL": - if not isinstance(curr_args, dict) or not curr_args.get("url"): - return None - key_argument = str(curr_args["url"]) - case _: - if isinstance(json_content, streamingjson.Lexer): - # lexer.json_content is list[str] based on streamingjson source code - content: list[str] = cast(list[str], json_content.json_content) # type: ignore[reportUnknownMemberType] - key_argument = "".join(content) - else: - key_argument = json_content - key_argument = shorten_middle(key_argument, width=50) - return key_argument - - -def _normalize_path(path: str) -> str: - cwd = str(KaosPath.cwd().canonical()) - if path.startswith(cwd): - path = path[len(cwd) :].lstrip("/\\") - return path diff --git a/src/kimi_cli/tools/agent/__init__.py b/src/kimi_cli/tools/agent/__init__.py deleted file mode 100644 index cb27e41eb..000000000 --- a/src/kimi_cli/tools/agent/__init__.py +++ /dev/null @@ -1,276 +0,0 @@ -import asyncio -from pathlib import Path -from typing import override - -from kosong.tooling import CallableTool2, ToolError, ToolReturnValue -from pydantic import BaseModel, Field - -from kimi_cli.soul.agent import Runtime -from kimi_cli.soul.toolset import get_current_tool_call_or_none -from kimi_cli.subagents.models import AgentLaunchSpec, AgentTypeDefinition -from kimi_cli.subagents.runner import ForegroundRunRequest, ForegroundSubagentRunner -from kimi_cli.tools.utils import load_desc -from kimi_cli.utils.logging import logger - -NAME = "Agent" - -MAX_FOREGROUND_TIMEOUT = 60 * 60 # 1 hour -MAX_BACKGROUND_TIMEOUT = 60 * 60 # 1 hour - - -class Params(BaseModel): - description: str = Field(description="A short (3-5 word) description of the task") - prompt: str = Field(description="The task for the agent to perform") - subagent_type: str = Field( - default="coder", - description="The built-in agent type to use. Defaults to `coder`.", - ) - model: str | None = Field( - default=None, - description=( - "Optional model override. Selection priority is: this parameter, then the built-in " - "type default model, then the parent agent's current model." - ), - ) - resume: str | None = Field( - default=None, - description="Optional agent ID to resume instead of creating a new instance.", - ) - run_in_background: bool = Field( - default=False, - description=( - "Whether to run the agent in the background. Prefer false unless the task can " - "continue independently and there is a clear benefit to returning control before " - "the result is needed." - ), - ) - timeout: int | None = Field( - default=None, - description=( - "Timeout in seconds for the agent task. " - "Foreground: no default timeout (runs until completion), max 3600s (1hr). " - "Background: default from config (15min), max 3600s (1hr). " - "The agent is stopped if it exceeds this limit." - ), - ge=30, - le=MAX_BACKGROUND_TIMEOUT, - ) - - @property - def effective_timeout(self) -> int | None: - """Return the user-specified timeout, or None to use the system default.""" - return self.timeout - - -class AgentTool(CallableTool2[Params]): - name: str = NAME - params: type[Params] = Params - - def __init__(self, runtime: Runtime): - super().__init__( - description=load_desc( - Path(__file__).parent / "description.md", - { - "BUILTIN_AGENT_TYPES_MD": self._builtin_type_lines(runtime), - }, - ) - ) - self._runtime = runtime - - @staticmethod - def _builtin_type_lines(runtime: Runtime) -> str: - lines: list[str] = [] - for name, type_def in runtime.labor_market.builtin_types.items(): - tool_names = AgentTool._tool_summary(type_def) - model = type_def.default_model or "inherit" - suffix = ( - f" When to use: {AgentTool._normalize_summary(type_def.when_to_use)}" - if type_def.when_to_use - else "" - ) - background = "yes" if type_def.supports_background else "no" - lines.append( - f"- `{name}`: {type_def.description} " - f"(Tools: {tool_names}, Model: {model}, Background: {background}).{suffix}" - ) - return "\n".join(lines) - - @staticmethod - def _normalize_summary(text: str) -> str: - return " ".join(text.split()) - - @staticmethod - def _tool_summary(type_def: AgentTypeDefinition) -> str: - if type_def.tool_policy.mode != "allowlist": - return "*" - if not type_def.tool_policy.tools: - return "(none)" - return ", ".join(AgentTool._unique_tool_names(type_def.tool_policy.tools)) - - @staticmethod - def _unique_tool_names(tool_paths: tuple[str, ...]) -> list[str]: - names: list[str] = [] - for path in tool_paths: - name = path.split(":")[-1] - if name not in names: - names.append(name) - return names - - @override - async def __call__(self, params: Params) -> ToolReturnValue: - if self._runtime.role != "root": - return ToolError( - message="Subagents cannot launch other subagents.", - brief="Agent unavailable", - ) - if params.model is not None and params.model not in self._runtime.config.models: - return ToolError( - message=f"Unknown model alias: {params.model}", - brief="Invalid model alias", - ) - if params.run_in_background: - return await self._run_in_background(params) - try: - runner = ForegroundSubagentRunner(self._runtime) - req = ForegroundRunRequest( - description=params.description, - prompt=params.prompt, - requested_type=params.subagent_type or "coder", - model=params.model, - resume=params.resume, - ) - timeout = params.effective_timeout - if timeout is not None: - return await asyncio.wait_for(runner.run(req), timeout=timeout) - return await runner.run(req) - except TimeoutError as exc: - if isinstance(exc.__cause__, asyncio.CancelledError): - # Task-level timeout from wait_for (it raises TimeoutError from CancelledError) - t = params.effective_timeout - logger.warning("Foreground agent timed out after {t}s", t=t) - return ToolError( - message=f"Agent timed out after {t}s.", - brief=f"Agent timed out ({t}s)", - ) - # Internal timeout (e.g. aiohttp request) — treat as generic failure - logger.exception("Foreground agent run failed") - return ToolError(message=f"Failed to run agent: {exc}", brief="Agent failed") - except Exception as exc: - logger.exception("Foreground agent run failed") - return ToolError(message=f"Failed to run agent: {exc}", brief="Agent failed") - - async def _run_in_background(self, params: Params) -> ToolReturnValue: - assert self._runtime.subagent_store is not None - try: - tool_call = get_current_tool_call_or_none() - if tool_call is None: - return ToolError( - message="Background agent requires a tool call context.", - brief="No tool call context", - ) - - requested_type = params.subagent_type or "coder" - if params.resume: - record = self._runtime.subagent_store.require_instance(params.resume) - if record.status in {"running_foreground", "running_background"}: - return ToolError( - message=( - f"Agent instance {record.agent_id} is still {record.status} and cannot " - "be resumed concurrently." - ), - brief="Agent already running", - ) - actual_type = record.subagent_type - agent_id = record.agent_id - # Validate the effective model for resumed instances — the model - # stored in the launch spec may have been removed from config since - # the instance was created. params.model is already validated in - # __call__, so only check the stored effective_model fallback here. - if params.model is None: - type_def = self._runtime.labor_market.require_builtin_type(actual_type) - effective = record.launch_spec.effective_model or type_def.default_model - if effective is not None and effective not in self._runtime.config.models: - return ToolError( - message=f"Unknown model alias: {effective}", - brief="Invalid model alias", - ) - else: - actual_type = requested_type - import uuid - - agent_id = f"a{uuid.uuid4().hex[:8]}" - record = None - - created_instance = False - if not params.resume: - type_def = self._runtime.labor_market.require_builtin_type(actual_type) - self._runtime.subagent_store.create_instance( - agent_id=agent_id, - description=params.description.strip(), - launch_spec=AgentLaunchSpec( - agent_id=agent_id, - subagent_type=actual_type, - model_override=params.model, - effective_model=params.model or type_def.default_model, - ), - ) - created_instance = True - - # Mark running_background synchronously before dispatching the - # async task so that concurrent resume attempts see the guard - # immediately (asyncio.create_task only queues the coroutine). - self._runtime.subagent_store.update_instance( - agent_id, - status="running_background", - ) - try: - view = self._runtime.background_tasks.create_agent_task( - agent_id=agent_id, - subagent_type=actual_type, - prompt=params.prompt, - description=params.description.strip(), - tool_call_id=tool_call.id, - model_override=params.model, - timeout_s=params.effective_timeout, - resumed=params.resume is not None, - ) - except Exception: - self._runtime.subagent_store.update_instance( - agent_id, - status="idle", - ) - if created_instance: - self._runtime.subagent_store.delete_instance(agent_id) - raise - lines = [ - f"task_id: {view.spec.id}", - f"kind: {view.spec.kind}", - f"status: {view.runtime.status}", - f"description: {view.spec.description}", - f"agent_id: {agent_id}", - f"actual_subagent_type: {actual_type}", - "automatic_notification: true", - "next_step: You will be automatically notified when it completes.", - ( - "next_step: Use TaskOutput with this task_id for a non-blocking status/output " - "snapshot. Only set block=true when you intentionally want to wait." - ), - f'resume_hint: Use Agent(resume="{agent_id}", prompt="...") to continue this ' - "instance later.", - ] - return ToolReturnValue( - is_error=False, - output="\n".join(lines), - message="Background task started.", - display=[], - ) - except FileNotFoundError as exc: - return ToolError(message=str(exc), brief="Agent not found") - except KeyError as exc: - return ToolError(message=str(exc), brief="Invalid subagent type") - except RuntimeError as exc: - logger.exception("Background agent launch failed") - return ToolError(message=str(exc), brief="Background start failed") - - -Agent = AgentTool diff --git a/src/kimi_cli/tools/agent/agent.ts b/src/kimi_cli/tools/agent/agent.ts new file mode 100644 index 000000000..9cfc113c6 --- /dev/null +++ b/src/kimi_cli/tools/agent/agent.ts @@ -0,0 +1,67 @@ +/** + * Agent tool — spawn subagent instances. + * Corresponds to Python tools/agent/__init__.py + * Stub: full implementation requires subagent runner integration. + */ + +import { z } from "zod/v4"; +import { CallableTool } from "../base.ts"; +import type { ToolContext, ToolResult } from "../types.ts"; +import { ToolError, ToolOk } from "../types.ts"; + +const DESCRIPTION = `Start a subagent instance to work on a focused task. + +**Usage:** +- Always provide a short \`description\` (3-5 words). +- Use \`subagent_type\` to select a built-in agent type. If omitted, \`coder\` is used. +- Use \`model\` when you need to override the default model. +- Default to foreground execution. Use \`run_in_background=true\` only when needed. +- Be explicit about whether the subagent should write code or only do research. +- The subagent result is only visible to you. If the user should see it, summarize it yourself.`; + +const ParamsSchema = z.object({ + description: z + .string() + .describe("A short (3-5 word) description of the task"), + prompt: z.string().describe("The task for the agent to perform"), + subagent_type: z + .string() + .default("coder") + .describe("The built-in agent type to use. Defaults to `coder`."), + model: z + .string() + .nullish() + .describe("Optional model override."), + resume: z + .string() + .nullish() + .describe( + "Optional agent ID to resume instead of creating a new instance.", + ), + run_in_background: z + .boolean() + .default(false) + .describe("Whether to run the agent in the background."), + timeout: z + .number() + .int() + .min(30) + .max(3600) + .nullish() + .describe("Timeout in seconds for the agent task."), +}); + +type Params = z.infer; + +export class AgentTool extends CallableTool { + readonly name = "Agent"; + readonly description = DESCRIPTION; + readonly schema = ParamsSchema; + + async execute(params: Params, _ctx: ToolContext): Promise { + // Stub: full implementation requires subagent runner + return ToolError( + "Subagent system is not yet implemented in this version.", + ); + } +} diff --git a/src/kimi_cli/tools/agent/description.md b/src/kimi_cli/tools/agent/description.md deleted file mode 100644 index 8e6a3d461..000000000 --- a/src/kimi_cli/tools/agent/description.md +++ /dev/null @@ -1,41 +0,0 @@ -Start a subagent instance to work on a focused task. - -The Agent tool can either create a new subagent instance or resume an existing one by `agent_id`. -Each instance keeps its own context history under the current session, so repeated use of the same -instance can preserve previous findings and work. - -**Available Built-in Agent Types** - -${BUILTIN_AGENT_TYPES_MD} - -**Usage** - -- Always provide a short `description` (3-5 words). -- Use `subagent_type` to select a built-in agent type. If omitted, `coder` is used. -- Use `model` when you need to override the built-in type's default model or the parent agent's current model. -- Use `resume` when you want to continue an existing instance instead of starting a new one. -- If an existing subagent already has relevant context or the task is a continuation of its prior work, prefer `resume` over creating a new instance. -- Default to foreground execution. Use `run_in_background=true` only when the task can continue independently, you do not need the result immediately, and there is a clear benefit to returning control before it finishes. -- Be explicit about whether the subagent should write code or only do research. -- The subagent result is only visible to you. If the user should see it, summarize it yourself. - -**Explore Agent — Preferred for Codebase Research** - -When you need to understand the codebase before making changes, fixing bugs, or planning features, -prefer `subagent_type="explore"` over doing the search yourself. The explore agent is optimized for -fast, read-only codebase investigation. Use it when: -- Your task will clearly require more than 3 search queries -- You need to understand how a module, feature, or code path works -- You are about to enter plan mode and want to gather context first -- You want to investigate multiple independent questions — launch multiple explore agents concurrently - -When calling explore, specify the desired thoroughness in the prompt: -- "quick": targeted lookups — find a specific file, function, or config value -- "medium": understand a module — how does auth work, what calls this API -- "thorough": cross-cutting analysis — architecture overview, dependency mapping, multi-module investigation - -**When Not To Use Agent** - -- Reading a known file path -- Searching a small number of known files -- Tasks that can be completed in one or two direct tool calls diff --git a/src/kimi_cli/tools/ask_user/__init__.py b/src/kimi_cli/tools/ask_user/__init__.py deleted file mode 100644 index b68c25b6d..000000000 --- a/src/kimi_cli/tools/ask_user/__init__.py +++ /dev/null @@ -1,154 +0,0 @@ -from __future__ import annotations - -import json -import logging -from collections.abc import Callable -from pathlib import Path -from typing import override -from uuid import uuid4 - -from kosong.tooling import BriefDisplayBlock, CallableTool2, ToolError, ToolReturnValue -from pydantic import BaseModel, Field - -from kimi_cli.soul import get_wire_or_none, wire_send -from kimi_cli.soul.toolset import get_current_tool_call_or_none -from kimi_cli.tools.utils import load_desc -from kimi_cli.wire.types import QuestionItem, QuestionNotSupported, QuestionOption, QuestionRequest - -logger = logging.getLogger(__name__) - -NAME = "AskUserQuestion" - -_BASE_DESCRIPTION = load_desc(Path(__file__).parent / "description.md") - - -class QuestionOptionParam(BaseModel): - label: str = Field( - description="Concise display text (1-5 words). If recommended, append '(Recommended)'." - ) - description: str = Field( - default="", - description="Brief explanation of trade-offs or implications of choosing this option.", - ) - - -class QuestionParam(BaseModel): - question: str = Field(description="A specific, actionable question. End with '?'.") - header: str = Field( - default="", description="Short category tag (max 12 chars, e.g. 'Auth', 'Style')." - ) - options: list[QuestionOptionParam] = Field( - description=( - "2-4 meaningful, distinct options. Do NOT include an 'Other' option — " - "the system adds one automatically." - ), - min_length=2, - max_length=4, - ) - multi_select: bool = Field( - default=False, - description="Whether the user can select multiple options.", - ) - - -class Params(BaseModel): - questions: list[QuestionParam] = Field( - description="The questions to ask the user (1-4 questions).", - min_length=1, - max_length=4, - ) - - -class AskUserQuestion(CallableTool2[Params]): - name: str = NAME - description: str = _BASE_DESCRIPTION - params: type[Params] = Params - - def __init__(self) -> None: - super().__init__() - self._is_yolo: Callable[[], bool] | None = None - - def bind_approval(self, is_yolo: Callable[[], bool]) -> None: - """Late-bind yolo checker so we can auto-dismiss in non-interactive mode.""" - self._is_yolo = is_yolo - - @override - async def __call__(self, params: Params) -> ToolReturnValue: - if self._is_yolo and self._is_yolo(): - return ToolReturnValue( - is_error=False, - output=( - '{"answers": {}, "note": "Running in non-interactive' - ' (yolo) mode. Make your own decision."}' - ), - message="Non-interactive mode, auto-dismissed.", - display=[BriefDisplayBlock(text="Auto-dismissed (yolo)")], - ) - - wire = get_wire_or_none() - if wire is None: - return ToolError( - message="Cannot ask user questions: Wire is not available.", - brief="Wire unavailable", - ) - - tool_call = get_current_tool_call_or_none() - if tool_call is None: - return ToolError( - message="AskUserQuestion must be called from a tool call context.", - brief="Invalid context", - ) - - questions = [ - QuestionItem( - question=q.question, - header=q.header, - options=[ - QuestionOption(label=o.label, description=o.description) for o in q.options - ], - multi_select=q.multi_select, - ) - for q in params.questions - ] - - request = QuestionRequest( - id=str(uuid4()), - tool_call_id=tool_call.id, - questions=questions, - ) - - wire_send(request) - - try: - answers = await request.wait() - except QuestionNotSupported: - return ToolError( - message=( - "The connected client does not support interactive questions. " - "Do NOT call this tool again. " - "Ask the user directly in your text response instead." - ), - brief="Client unsupported", - ) - except Exception: - logger.exception("Failed to get user response for question %s", request.id) - return ToolError( - message="Failed to get user response.", - brief="Question failed", - ) - - if not answers: - return ToolReturnValue( - is_error=False, - output='{"answers": {}, "note": "User dismissed the question without answering."}', - message="User dismissed the question without answering.", - display=[BriefDisplayBlock(text="User dismissed")], - ) - - formatted = json.dumps({"answers": answers}, ensure_ascii=False) - return ToolReturnValue( - is_error=False, - output=formatted, - message="User has answered.", - display=[BriefDisplayBlock(text="User answered")], - ) diff --git a/src/kimi_cli/tools/ask_user/ask_user.ts b/src/kimi_cli/tools/ask_user/ask_user.ts new file mode 100644 index 000000000..eff2ec079 --- /dev/null +++ b/src/kimi_cli/tools/ask_user/ask_user.ts @@ -0,0 +1,99 @@ +/** + * AskUserQuestion tool — ask the user structured questions. + * Corresponds to Python tools/ask_user/__init__.py + */ + +import { z } from "zod/v4"; +import { CallableTool } from "../base.ts"; +import type { ToolContext, ToolResult } from "../types.ts"; +import { ToolOk } from "../types.ts"; + +const DESCRIPTION = `Use this tool when you need to ask the user questions with structured options during execution. This allows you to: +1. Collect user preferences or requirements before proceeding +2. Resolve ambiguous or underspecified instructions +3. Let the user decide between implementation approaches as you work +4. Present concrete options when multiple valid directions exist + +**When NOT to use:** +- When you can infer the answer from context — be decisive and proceed +- Trivial decisions that don't materially affect the outcome + +**Usage notes:** +- Users always have an "Other" option for custom input +- Use multi_select to allow multiple answers +- Keep option labels concise (1-5 words) +- Each question should have 2-4 meaningful, distinct options`; + +const QuestionOptionSchema = z.object({ + label: z + .string() + .describe( + "Concise display text (1-5 words). If recommended, append '(Recommended)'.", + ), + description: z + .string() + .default("") + .describe("Brief explanation of trade-offs or implications."), +}); + +const QuestionSchema = z.object({ + question: z + .string() + .describe("A specific, actionable question. End with '?'."), + header: z + .string() + .default("") + .describe("Short category tag (max 12 chars, e.g. 'Auth', 'Style')."), + options: z + .array(QuestionOptionSchema) + .min(2) + .max(4) + .describe("2-4 meaningful, distinct options."), + multi_select: z + .boolean() + .default(false) + .describe("Whether the user can select multiple options."), +}); + +const ParamsSchema = z.object({ + questions: z + .array(QuestionSchema) + .min(1) + .max(4) + .describe("The questions to ask the user (1-4 questions)."), +}); + +type Params = z.infer; + +export class AskUserQuestion extends CallableTool { + readonly name = "AskUserQuestion"; + readonly description = DESCRIPTION; + readonly schema = ParamsSchema; + + async execute(params: Params, ctx: ToolContext): Promise { + const answers: Record = {}; + + for (const q of params.questions) { + const optionLabels = q.options.map((o) => o.label); + + if (ctx.askUser) { + // Wire-connected: actually ask the user + try { + const answer = await ctx.askUser(q.question, optionLabels); + answers[q.question] = answer; + } catch { + // User didn't respond or error — use first option as default + answers[q.question] = optionLabels[0] ?? "No answer"; + } + } else { + // Not connected (print mode, yolo mode, etc.) — auto-select first option + answers[q.question] = optionLabels[0] ?? "No answer"; + } + } + + return ToolOk( + JSON.stringify({ answers }, null, 2), + "User responses collected.", + ); + } +} diff --git a/src/kimi_cli/tools/ask_user/description.md b/src/kimi_cli/tools/ask_user/description.md deleted file mode 100644 index ec0c553a5..000000000 --- a/src/kimi_cli/tools/ask_user/description.md +++ /dev/null @@ -1,19 +0,0 @@ -Use this tool when you need to ask the user questions with structured options during execution. This allows you to: -1. Collect user preferences or requirements before proceeding -2. Resolve ambiguous or underspecified instructions -3. Let the user decide between implementation approaches as you work -4. Present concrete options when multiple valid directions exist - -**When NOT to use:** -- When you can infer the answer from context — be decisive and proceed -- Trivial decisions that don't materially affect the outcome - -Overusing this tool interrupts the user's flow. Only use it when the user's input genuinely changes your next action. - -**Usage notes:** -- Users always have an "Other" option for custom input — don't create one yourself -- Use multi_select to allow multiple answers to be selected for a question -- Keep option labels concise (1-5 words), use descriptions for trade-offs and details -- Each question should have 2-4 meaningful, distinct options -- You can ask 1-4 questions at a time; group related questions to minimize interruptions -- If you recommend a specific option, list it first and append "(Recommended)" to its label diff --git a/src/kimi_cli/tools/background/__init__.py b/src/kimi_cli/tools/background/__init__.py deleted file mode 100644 index 1dbab1393..000000000 --- a/src/kimi_cli/tools/background/__init__.py +++ /dev/null @@ -1,318 +0,0 @@ -import time -from pathlib import Path -from typing import override - -from kosong.tooling import CallableTool2, ToolError, ToolReturnValue -from pydantic import BaseModel, Field - -from kimi_cli.background import TaskView, format_task, format_task_list, list_task_views -from kimi_cli.soul.agent import Runtime -from kimi_cli.soul.approval import Approval -from kimi_cli.tools.display import BackgroundTaskDisplayBlock -from kimi_cli.tools.utils import load_desc - -TASK_OUTPUT_PREVIEW_BYTES = 32 << 10 -TASK_OUTPUT_READ_HINT_LINES = 300 - - -def _ensure_root(runtime: Runtime) -> ToolError | None: - if runtime.role != "root": - return ToolError( - message="Background tasks can only be managed by the root agent.", - brief="Background task unavailable", - ) - return None - - -def _task_display(runtime: Runtime, task_id: str) -> BackgroundTaskDisplayBlock: - view = runtime.background_tasks.store.merged_view(task_id) - return BackgroundTaskDisplayBlock( - task_id=view.spec.id, - kind=view.spec.kind, - status=view.runtime.status, - description=view.spec.description, - ) - - -def _format_task_output( - view: TaskView, - *, - retrieval_status: str, - output: str, - output_path: Path, - full_output_available: bool, - output_size_bytes: int, - output_preview_bytes: int, - output_truncated: bool, -) -> str: - terminal_reason = "timed_out" if view.runtime.timed_out else view.runtime.status - output_path_str = str(output_path.resolve()) - lines = [ - f"retrieval_status: {retrieval_status}", - f"task_id: {view.spec.id}", - f"kind: {view.spec.kind}", - f"status: {view.runtime.status}", - f"description: {view.spec.description}", - ] - if view.spec.kind == "agent" and view.spec.kind_payload: - if agent_id := view.spec.kind_payload.get("agent_id"): - lines.append(f"agent_id: {agent_id}") - if subagent_type := view.spec.kind_payload.get("subagent_type"): - lines.append(f"subagent_type: {subagent_type}") - if view.spec.command: - lines.append(f"command: {view.spec.command}") - lines.extend( - [ - f"interrupted: {str(view.runtime.interrupted).lower()}", - f"timed_out: {str(view.runtime.timed_out).lower()}", - f"terminal_reason: {terminal_reason}", - ] - ) - if view.runtime.exit_code is not None: - lines.append(f"exit_code: {view.runtime.exit_code}") - if view.runtime.failure_reason: - lines.append(f"reason: {view.runtime.failure_reason}") - full_output_hint = ( - ( - "full_output_hint: " - f'Use ReadFile(path="{output_path_str}", line_offset=1, ' - f"n_lines={TASK_OUTPUT_READ_HINT_LINES}) to inspect the full log. " - "Increase line_offset to continue paging through the file." - ) - if full_output_available - else "full_output_hint: No output file is currently available for this task." - ) - lines.extend( - [ - "", - f"output_path: {output_path_str}", - f"output_size_bytes: {output_size_bytes}", - f"output_preview_bytes: {output_preview_bytes}", - f"output_truncated: {str(output_truncated).lower()}", - "", - f"full_output_available: {str(full_output_available).lower()}", - "full_output_tool: ReadFile", - full_output_hint, - ] - ) - rendered_output = output or "[no output available]" - if output_truncated: - rendered_output = f"[Truncated. Full output: {output_path_str}]\n\n{rendered_output}" - return "\n".join( - lines - + [ - "", - "[output]", - rendered_output, - ] - ) - - -class TaskOutputParams(BaseModel): - task_id: str = Field(description="The background task ID to inspect.") - block: bool = Field( - default=False, - description="Whether to wait for the task to finish before returning.", - ) - timeout: int = Field( - default=30, - ge=0, - le=3600, - description="Maximum number of seconds to wait when block=true.", - ) - - -class TaskStopParams(BaseModel): - task_id: str = Field(description="The background task ID to stop.") - reason: str = Field( - default="Stopped by TaskStop", - description="Short reason recorded when the task is stopped.", - ) - - -class TaskListParams(BaseModel): - active_only: bool = Field( - default=True, - description="Whether to list only non-terminal background tasks.", - ) - limit: int = Field( - default=20, - ge=1, - le=100, - description="Maximum number of tasks to return.", - ) - - -class TaskList(CallableTool2[TaskListParams]): - name: str = "TaskList" - description: str = load_desc(Path(__file__).parent / "list.md") - params: type[TaskListParams] = TaskListParams - - def __init__(self, runtime: Runtime): - super().__init__() - self._runtime = runtime - - @override - async def __call__(self, params: TaskListParams) -> ToolReturnValue: - if err := _ensure_root(self._runtime): - return err - - views = list_task_views( - self._runtime.background_tasks, - active_only=params.active_only, - limit=params.limit, - ) - display = [ - BackgroundTaskDisplayBlock( - task_id=view.spec.id, - kind=view.spec.kind, - status=view.runtime.status, - description=view.spec.description, - ) - for view in views - ] - return ToolReturnValue( - is_error=False, - output=format_task_list(views, active_only=params.active_only), - message="Task list retrieved.", - display=list(display), - ) - - -class TaskOutput(CallableTool2[TaskOutputParams]): - name: str = "TaskOutput" - description: str = load_desc(Path(__file__).parent / "output.md") - params: type[TaskOutputParams] = TaskOutputParams - - def __init__(self, runtime: Runtime): - super().__init__() - self._runtime = runtime - - def _render_output_preview(self, task_id: str) -> tuple[str, bool, int, int, bool, Path]: - manager = self._runtime.background_tasks - output_path = manager.resolve_output_path(task_id) - try: - output_size = output_path.stat().st_size if output_path.exists() else 0 - except OSError: - output_size = 0 - preview_offset = max(0, output_size - TASK_OUTPUT_PREVIEW_BYTES) - chunk = manager.read_output( - task_id, - offset=preview_offset, - max_bytes=TASK_OUTPUT_PREVIEW_BYTES, - ) - return ( - chunk.text.rstrip("\n"), - output_size > 0, - output_size, - chunk.next_offset - chunk.offset, - preview_offset > 0, - output_path, - ) - - @override - async def __call__(self, params: TaskOutputParams) -> ToolReturnValue: - if err := _ensure_root(self._runtime): - return err - - view = self._runtime.background_tasks.get_task(params.task_id) - if view is None: - return ToolError(message=f"Task not found: {params.task_id}", brief="Task not found") - - if params.block: - view = await self._runtime.background_tasks.wait( - params.task_id, - timeout_s=params.timeout, - ) - retrieval_status = ( - "success" - if view.runtime.status in {"completed", "failed", "killed", "lost"} - else "timeout" - ) - else: - retrieval_status = ( - "success" - if view.runtime.status in {"completed", "failed", "killed", "lost"} - else "not_ready" - ) - - ( - output, - full_output_available, - output_size, - output_preview_bytes, - output_truncated, - output_path, - ) = self._render_output_preview(params.task_id) - consumer = view.consumer.model_copy( - update={ - "last_seen_output_size": output_size, - "last_viewed_at": time.time(), - } - ) - self._runtime.background_tasks.store.write_consumer(params.task_id, consumer) - - return ToolReturnValue( - is_error=False, - output=_format_task_output( - view, - retrieval_status=retrieval_status, - output=output, - output_path=output_path, - full_output_available=full_output_available, - output_size_bytes=output_size, - output_preview_bytes=output_preview_bytes, - output_truncated=output_truncated, - ), - message=( - "Task snapshot retrieved." - if not params.block and retrieval_status == "not_ready" - else "Task output retrieved." - ), - display=[_task_display(self._runtime, params.task_id)], - ) - - -class TaskStop(CallableTool2[TaskStopParams]): - name: str = "TaskStop" - description: str = load_desc(Path(__file__).parent / "stop.md") - params: type[TaskStopParams] = TaskStopParams - - def __init__(self, runtime: Runtime, approval: Approval): - super().__init__() - self._runtime = runtime - self._approval = approval - - @override - async def __call__(self, params: TaskStopParams) -> ToolReturnValue: - if err := _ensure_root(self._runtime): - return err - if self._runtime.session.state.plan_mode: - return ToolError( - message="TaskStop is not available in plan mode.", - brief="Blocked in plan mode", - ) - - view = self._runtime.background_tasks.get_task(params.task_id) - if view is None: - return ToolError(message=f"Task not found: {params.task_id}", brief="Task not found") - - result = await self._approval.request( - self.name, - "stop background task", - f"Stop background task `{params.task_id}`", - display=[_task_display(self._runtime, params.task_id)], - ) - if not result: - return result.rejection_error() - - view = self._runtime.background_tasks.kill( - params.task_id, - reason=params.reason.strip() or "Stopped by TaskStop", - ) - return ToolReturnValue( - is_error=False, - output=format_task(view, include_command=True), - message="Task stop requested.", - display=[_task_display(self._runtime, params.task_id)], - ) diff --git a/src/kimi_cli/tools/background/background.ts b/src/kimi_cli/tools/background/background.ts new file mode 100644 index 000000000..04a1832fb --- /dev/null +++ b/src/kimi_cli/tools/background/background.ts @@ -0,0 +1,170 @@ +/** + * Background task tools — TaskList, TaskOutput, TaskStop. + * Corresponds to Python tools/background/__init__.py + * Uses BackgroundTaskManager for real task management. + */ + +import { z } from "zod/v4"; +import { CallableTool } from "../base.ts"; +import type { ToolContext, ToolResult } from "../types.ts"; +import { ToolError, ToolOk } from "../types.ts"; +import { listTaskViews, formatTaskList } from "../../background/summary.ts"; +import type { BackgroundTaskManager } from "../../background/manager.ts"; +import { isTerminalStatus } from "../../background/models.ts"; + +// Shared manager reference — bound at session startup +let _manager: BackgroundTaskManager | undefined; + +export function bindBackgroundManager(manager: BackgroundTaskManager): void { + _manager = manager; +} + +export function getBackgroundManager(): BackgroundTaskManager | undefined { + return _manager; +} + +// ── TaskList ──────────────────────────────────────────── + +const TaskListParamsSchema = z.object({ + active_only: z + .boolean() + .default(true) + .describe("Whether to list only non-terminal background tasks."), + limit: z + .number() + .int() + .min(1) + .max(100) + .default(20) + .describe("Maximum number of tasks to return."), +}); + +export class TaskList extends CallableTool { + readonly name = "TaskList"; + readonly description = + "List background tasks. Returns task IDs, statuses, and descriptions."; + readonly schema = TaskListParamsSchema; + + async execute( + params: z.infer, + _ctx: ToolContext, + ): Promise { + if (!_manager) { + return ToolOk("No background tasks.", "Task list retrieved."); + } + _manager.reconcile(); + const views = listTaskViews(_manager, { + activeOnly: params.active_only, + limit: params.limit, + }); + const text = formatTaskList(views, { + activeOnly: params.active_only, + includeCommand: true, + }); + return ToolOk(text, "Task list retrieved."); + } +} + +// ── TaskOutput ────────────────────────────────────────── + +const TaskOutputParamsSchema = z.object({ + task_id: z.string().describe("The background task ID to inspect."), + block: z + .boolean() + .default(false) + .describe("Whether to wait for the task to finish before returning."), + timeout: z + .number() + .int() + .min(0) + .max(3600) + .default(30) + .describe("Maximum number of seconds to wait when block=true."), +}); + +export class TaskOutput extends CallableTool { + readonly name = "TaskOutput"; + readonly description = + "Retrieve output from a background task by its ID."; + readonly schema = TaskOutputParamsSchema; + + async execute( + params: z.infer, + _ctx: ToolContext, + ): Promise { + if (!_manager) { + return ToolError(`Task not found: ${params.task_id}`); + } + + _manager.reconcile(); + let view = _manager.getTask(params.task_id); + if (!view) { + return ToolError(`Task not found: ${params.task_id}`); + } + + // Block if requested and task is still running + if (params.block && !isTerminalStatus(view.runtime.status)) { + view = await _manager.wait(params.task_id, params.timeout); + } + + const chunk = _manager.readOutput(params.task_id); + const lines = [ + `task_id: ${view.spec.id}`, + `status: ${view.runtime.status}`, + `kind: ${view.spec.kind}`, + ]; + if (view.runtime.exitCode != null) { + lines.push(`exit_code: ${view.runtime.exitCode}`); + } + if (view.runtime.failureReason) { + lines.push(`reason: ${view.runtime.failureReason}`); + } + if (chunk.text) { + lines.push("", "[output]", chunk.text); + } + if (!chunk.eof) { + lines.push(`[truncated at offset ${chunk.nextOffset}]`); + } + return ToolOk(lines.join("\n"), `Output for task ${params.task_id}.`); + } +} + +// ── TaskStop ──────────────────────────────────────────── + +const TaskStopParamsSchema = z.object({ + task_id: z.string().describe("The background task ID to stop."), + reason: z + .string() + .default("Stopped by TaskStop") + .describe("Short reason recorded when the task is stopped."), +}); + +export class TaskStop extends CallableTool { + readonly name = "TaskStop"; + readonly description = "Stop a running background task by its ID."; + readonly schema = TaskStopParamsSchema; + + async execute( + params: z.infer, + _ctx: ToolContext, + ): Promise { + if (!_manager) { + return ToolError(`Task not found: ${params.task_id}`); + } + + const existing = _manager.getTask(params.task_id); + if (!existing) { + return ToolError(`Task not found: ${params.task_id}`); + } + + if (isTerminalStatus(existing.runtime.status)) { + return ToolOk(`Task ${params.task_id} already in terminal state: ${existing.runtime.status}`); + } + + const view = _manager.kill(params.task_id, params.reason); + return ToolOk( + `Task ${params.task_id} stop requested. Current status: ${view.runtime.status}`, + `Task ${params.task_id} stopped.`, + ); + } +} diff --git a/src/kimi_cli/tools/background/list.md b/src/kimi_cli/tools/background/list.md deleted file mode 100644 index 4f409fe6c..000000000 --- a/src/kimi_cli/tools/background/list.md +++ /dev/null @@ -1,10 +0,0 @@ -List background tasks from the current session. - -Use this when you need to re-enumerate which background tasks still exist, especially after context compaction or when you are no longer confident which task IDs are still active. - -Guidelines: - -- Prefer the default `active_only=true` unless you specifically need completed or failed tasks. -- Use `TaskOutput` to inspect one task in detail after you have identified the correct task ID. -- Do not guess which tasks are still running when you can call this tool directly. -- This tool is read-only and safe to use in plan mode. diff --git a/src/kimi_cli/tools/background/output.md b/src/kimi_cli/tools/background/output.md deleted file mode 100644 index 8e7724307..000000000 --- a/src/kimi_cli/tools/background/output.md +++ /dev/null @@ -1,11 +0,0 @@ -Retrieve output from a running or completed background task. - -Use this after `Shell(run_in_background=true)` when you need to inspect progress or explicitly wait for completion. - -Guidelines: -- Prefer relying on automatic completion notifications. Use this tool only when you need task output before the automatic notification arrives. -- By default this tool is non-blocking and returns a current status/output snapshot. -- Use `block=true` only when you intentionally want to wait for completion or timeout. -- This tool returns structured task metadata, a fixed-size output preview, and an `output_path` for the full log. -- When the preview is truncated, use `ReadFile` with the returned `output_path` to inspect the full log in pages. -- This tool works with the generic background task system and should remain the primary read path for future task types, not just bash. diff --git a/src/kimi_cli/tools/background/stop.md b/src/kimi_cli/tools/background/stop.md deleted file mode 100644 index cb8ab580b..000000000 --- a/src/kimi_cli/tools/background/stop.md +++ /dev/null @@ -1,8 +0,0 @@ -Stop a running background task. - -Use this only when a background task must be cancelled. For normal task completion, prefer waiting for the automatic notification or using `TaskOutput`. - -Guidelines: -- This is a generic task stop capability, not a bash-specific kill tool. -- Use it sparingly because stopping a task is destructive and may leave partial side effects. -- If the task is already complete, this tool will simply return its current state. diff --git a/src/kimi_cli/tools/base.ts b/src/kimi_cli/tools/base.ts new file mode 100644 index 000000000..3ad72c4f1 --- /dev/null +++ b/src/kimi_cli/tools/base.ts @@ -0,0 +1,29 @@ +/** + * Abstract base class for all tools. + * Corresponds to Python's CallableTool2. + */ + +import { z } from "zod/v4"; +import type { ToolContext, ToolDefinition, ToolResult } from "./types.ts"; + +// eslint-disable-next-line @typescript-eslint/no-explicit-any +export abstract class CallableTool = z.ZodType> { + abstract readonly name: string; + abstract readonly description: string; + abstract readonly schema: TParams; + + /** Execute the tool with validated parameters. */ + abstract execute( + params: z.infer, + ctx: ToolContext, + ): Promise; + + /** Convert this tool into a ToolDefinition for LLM function calling. */ + toDefinition(): ToolDefinition { + return { + name: this.name, + description: this.description, + parameters: z.toJSONSchema(this.schema) as Record, + }; + } +} diff --git a/src/kimi_cli/tools/display.py b/src/kimi_cli/tools/display.py deleted file mode 100644 index 95064f948..000000000 --- a/src/kimi_cli/tools/display.py +++ /dev/null @@ -1,46 +0,0 @@ -from typing import Literal - -from kosong.tooling import DisplayBlock -from pydantic import BaseModel - - -class DiffDisplayBlock(DisplayBlock): - """Display block describing a file diff.""" - - type: str = "diff" - path: str - old_text: str - new_text: str - old_start: int = 1 - new_start: int = 1 - is_summary: bool = False - - -class TodoDisplayItem(BaseModel): - title: str - status: Literal["pending", "in_progress", "done"] - - -class TodoDisplayBlock(DisplayBlock): - """Display block describing a todo list update.""" - - type: str = "todo" - items: list[TodoDisplayItem] - - -class ShellDisplayBlock(DisplayBlock): - """Display block describing a shell command.""" - - type: str = "shell" - language: str - command: str - - -class BackgroundTaskDisplayBlock(DisplayBlock): - """Display block describing a background task.""" - - type: str = "background_task" - task_id: str - kind: str - status: str - description: str diff --git a/src/kimi_cli/tools/display.ts b/src/kimi_cli/tools/display.ts new file mode 100644 index 000000000..335c794b8 --- /dev/null +++ b/src/kimi_cli/tools/display.ts @@ -0,0 +1,44 @@ +/** + * Display block types for UI rendering. + * Corresponds to Python tools/display.py + */ + +export interface DiffDisplayBlock { + type: "diff"; + path: string; + oldText: string; + newText: string; + oldStart?: number; + newStart?: number; + isSummary?: boolean; +} + +export interface TodoDisplayItem { + title: string; + status: "pending" | "in_progress" | "done"; +} + +export interface TodoDisplayBlock { + type: "todo"; + items: TodoDisplayItem[]; +} + +export interface ShellDisplayBlock { + type: "shell"; + language: string; + command: string; +} + +export interface BackgroundTaskDisplayBlock { + type: "background_task"; + taskId: string; + kind: string; + status: string; + description: string; +} + +export type DisplayBlock = + | DiffDisplayBlock + | TodoDisplayBlock + | ShellDisplayBlock + | BackgroundTaskDisplayBlock; diff --git a/src/kimi_cli/tools/dmail/__init__.py b/src/kimi_cli/tools/dmail/__init__.py deleted file mode 100644 index 5d3c22be9..000000000 --- a/src/kimi_cli/tools/dmail/__init__.py +++ /dev/null @@ -1,38 +0,0 @@ -from pathlib import Path -from typing import override - -from kosong.tooling import CallableTool2, ToolError, ToolOk, ToolReturnValue - -from kimi_cli.soul.denwarenji import DenwaRenji, DenwaRenjiError, DMail -from kimi_cli.tools.utils import load_desc - -NAME = "SendDMail" - - -class SendDMail(CallableTool2[DMail]): - name: str = NAME - description: str = load_desc(Path(__file__).parent / "dmail.md") - params: type[DMail] = DMail - - def __init__(self, denwa_renji: DenwaRenji) -> None: - super().__init__() - self._denwa_renji = denwa_renji - - @override - async def __call__(self, params: DMail) -> ToolReturnValue: - try: - self._denwa_renji.send_dmail(params) - except DenwaRenjiError as e: - return ToolError( - output="", - message=f"Failed to send D-Mail. Error: {str(e)}", - brief="Failed to send D-Mail", - ) - return ToolOk( - output="", - message=( - "If you see this message, the D-Mail was NOT sent successfully. " - "This may be because some other tool that needs approval was rejected." - ), - brief="El Psy Kongroo", - ) diff --git a/src/kimi_cli/tools/dmail/dmail.md b/src/kimi_cli/tools/dmail/dmail.md deleted file mode 100644 index 15bf57569..000000000 --- a/src/kimi_cli/tools/dmail/dmail.md +++ /dev/null @@ -1,17 +0,0 @@ -Send a message to the past, just like sending a D-Mail in Steins;Gate. - -This tool is provided to enable you to proactively manage the context. You can see some `user` messages with text `CHECKPOINT {checkpoint_id}` wrapped in `` tags in the context. When you feel there is too much irrelevant information in the current context, you can send a D-Mail to revert the context to a previous checkpoint with a message containing only the useful information. When you send a D-Mail, you must specify an existing checkpoint ID from the before-mentioned messages. - -Typical scenarios you may want to send a D-Mail: - -- You read a file, found it very large and most of the content is not relevant to the current task. In this case you can send a D-Mail immediately to the checkpoint before you read the file and give your past self only the useful part. -- You searched the web, the result is large. - - If you got what you need, you may send a D-Mail to the checkpoint before you searched the web and put only the useful result in the mail message. - - If you did not get what you need, you may send a D-Mail to tell your past self to try another query. -- You wrote some code and it did not work as expected. You spent many struggling steps to fix it but the process is not relevant to the ultimate goal. In this case you can send a D-Mail to the checkpoint before you wrote the code and give your past self the fixed version of the code and tell yourself no need to write it again because you already wrote to the filesystem. - -After a D-Mail is sent, the system will revert the current context to the specified checkpoint, after which, you will no longer see any messages which you can now see after that checkpoint. The message in the D-Mail will be appended to the end of the context. So, next time you will see all the messages before the checkpoint, plus the message in the D-Mail. You must make it very clear in the message, tell your past self what you have done/changed, what you have learned and any other information that may be useful, so that your past self can continue the task without confusion and will not repeat the steps you have already done. - -You must understand that, unlike D-Mail in Steins;Gate, the D-Mail you send here will not revert the filesystem or any external state. That means, you are basically folding the recent messages in your context into a single message, which can significantly reduce the waste of context window. - -When sending a D-Mail, DO NOT explain to the user. The user do not care about this. Just explain to your past self. diff --git a/src/kimi_cli/tools/dmail/dmail.ts b/src/kimi_cli/tools/dmail/dmail.ts new file mode 100644 index 000000000..1c42ce1a7 --- /dev/null +++ b/src/kimi_cli/tools/dmail/dmail.ts @@ -0,0 +1,43 @@ +/** + * SendDMail tool — send a D-Mail to revert context to a checkpoint. + * Corresponds to Python tools/dmail/__init__.py + * Stub: full implementation requires denwa_renji integration. + */ + +import { z } from "zod/v4"; +import { CallableTool } from "../base.ts"; +import type { ToolContext, ToolResult } from "../types.ts"; +import { ToolOk } from "../types.ts"; + +const DESCRIPTION = `Send a message to the past, just like sending a D-Mail in Steins;Gate. + +This tool is provided to enable you to proactively manage the context. You can see some \`user\` messages with text \`CHECKPOINT {checkpoint_id}\` wrapped in \`\` tags in the context. When you feel there is too much irrelevant information in the current context, you can send a D-Mail to revert the context to a previous checkpoint with a message containing only the useful information. + +After a D-Mail is sent, the system will revert the current context to the specified checkpoint. You must make it very clear in the message what you have done/changed, what you have learned, so that your past self can continue the task without confusion. + +When sending a D-Mail, DO NOT explain to the user. Just explain to your past self.`; + +const ParamsSchema = z.object({ + checkpoint_id: z.string().describe("The checkpoint ID to revert to."), + message: z + .string() + .describe( + "The message to send to your past self with useful information.", + ), +}); + +type Params = z.infer; + +export class SendDMail extends CallableTool { + readonly name = "SendDMail"; + readonly description = DESCRIPTION; + readonly schema = ParamsSchema; + + async execute(params: Params, _ctx: ToolContext): Promise { + // Stub: full implementation requires denwa_renji + return ToolOk( + "", + "If you see this message, the D-Mail was NOT sent successfully.", + ); + } +} diff --git a/src/kimi_cli/tools/file/__init__.py b/src/kimi_cli/tools/file/__init__.py deleted file mode 100644 index 5b9a43699..000000000 --- a/src/kimi_cli/tools/file/__init__.py +++ /dev/null @@ -1,30 +0,0 @@ -from enum import StrEnum - - -class FileOpsWindow: - """Maintains a window of file operations.""" - - pass - - -class FileActions(StrEnum): - READ = "read file" - EDIT = "edit file" - EDIT_OUTSIDE = "edit file outside of working directory" - - -from .glob import Glob # noqa: E402 -from .grep_local import Grep # noqa: E402 -from .read import ReadFile # noqa: E402 -from .read_media import ReadMediaFile # noqa: E402 -from .replace import StrReplaceFile # noqa: E402 -from .write import WriteFile # noqa: E402 - -__all__ = ( - "ReadFile", - "ReadMediaFile", - "Glob", - "Grep", - "WriteFile", - "StrReplaceFile", -) diff --git a/src/kimi_cli/tools/file/glob.md b/src/kimi_cli/tools/file/glob.md deleted file mode 100644 index bf6968637..000000000 --- a/src/kimi_cli/tools/file/glob.md +++ /dev/null @@ -1,17 +0,0 @@ -Find files and directories using glob patterns. This tool supports standard glob syntax like `*`, `?`, and `**` for recursive searches. - -**When to use:** -- Find files matching specific patterns (e.g., all Python files: `*.py`) -- Search for files recursively in subdirectories (e.g., `src/**/*.js`) -- Locate configuration files (e.g., `*.config.*`, `*.json`) -- Find test files (e.g., `test_*.py`, `*_test.go`) - -**Example patterns:** -- `*.py` - All Python files in current directory -- `src/**/*.js` - All JavaScript files in src directory recursively -- `test_*.py` - Python test files starting with "test_" -- `*.config.{js,ts}` - Config files with .js or .ts extension - -**Bad example patterns:** -- `**`, `**/*.py` - Any pattern starting with '**' will be rejected. Because it would recursively search all directories and subdirectories, which is very likely to yield large result that exceeds your context size. Always use more specific patterns like `src/**/*.py` instead. -- `node_modules/**/*.js` - Although this does not start with '**', it would still highly possible to yield large result because `node_modules` is well-known to contain too many directories and files. Avoid recursively searching in such directories, other examples include `venv`, `.venv`, `__pycache__`, `target`. If you really need to search in a dependency, use more specific patterns like `node_modules/react/src/*` instead. diff --git a/src/kimi_cli/tools/file/glob.py b/src/kimi_cli/tools/file/glob.py deleted file mode 100644 index 65359d5ca..000000000 --- a/src/kimi_cli/tools/file/glob.py +++ /dev/null @@ -1,156 +0,0 @@ -"""Glob tool implementation.""" - -from pathlib import Path -from typing import override - -from kaos.path import KaosPath -from kosong.tooling import CallableTool2, ToolError, ToolOk, ToolReturnValue -from pydantic import BaseModel, Field - -from kimi_cli.soul.agent import Runtime -from kimi_cli.tools.utils import load_desc -from kimi_cli.utils.path import is_within_directory, is_within_workspace, list_directory - -MAX_MATCHES = 1000 - - -class Params(BaseModel): - pattern: str = Field(description=("Glob pattern to match files/directories.")) - directory: str | None = Field( - description=( - "Absolute path to the directory to search in (defaults to working directory)." - ), - default=None, - ) - include_dirs: bool = Field( - description="Whether to include directories in results.", - default=True, - ) - - -class Glob(CallableTool2[Params]): - name: str = "Glob" - description: str = load_desc( - Path(__file__).parent / "glob.md", - { - "MAX_MATCHES": str(MAX_MATCHES), - }, - ) - params: type[Params] = Params - - def __init__(self, runtime: Runtime) -> None: - super().__init__() - self._work_dir = runtime.builtin_args.KIMI_WORK_DIR - self._additional_dirs = runtime.additional_dirs - self._skills_dirs = runtime.skills_dirs - - async def _validate_pattern(self, pattern: str) -> ToolError | None: - """Validate that the pattern is safe to use.""" - if pattern.startswith("**"): - ls_result = await list_directory(self._work_dir) - return ToolError( - output=ls_result, - message=( - f"Pattern `{pattern}` starts with '**' which is not allowed. " - "This would recursively search all directories and may include large " - "directories like `node_modules`. Use more specific patterns instead. " - "For your convenience, a list of all files and directories in the " - "top level of the working directory is provided below." - ), - brief="Unsafe pattern", - ) - return None - - async def _validate_directory(self, directory: KaosPath) -> ToolError | None: - """Validate that the directory is safe to search.""" - resolved_dir = directory.canonical() - - # Allow directories within the workspace (work_dir or additional dirs) - if is_within_workspace(resolved_dir, self._work_dir, self._additional_dirs): - return None - - # Allow directories within any discovered skills root - if any(is_within_directory(resolved_dir, d) for d in self._skills_dirs): - return None - - return ToolError( - message=( - f"`{directory}` is outside the workspace. " - "You can only search within the working directory, " - "additional directories, and skills directories." - ), - brief="Directory outside workspace", - ) - - @override - async def __call__(self, params: Params) -> ToolReturnValue: - try: - # Validate pattern safety - pattern_error = await self._validate_pattern(params.pattern) - if pattern_error: - return pattern_error - - dir_path = ( - KaosPath(params.directory).expanduser() if params.directory else self._work_dir - ) - - if not dir_path.is_absolute(): - return ToolError( - message=( - f"`{params.directory}` is not an absolute path. " - "You must provide an absolute path to search." - ), - brief="Invalid directory", - ) - - # Validate directory safety - dir_error = await self._validate_directory(dir_path) - if dir_error: - return dir_error - - if not await dir_path.exists(): - return ToolError( - message=f"`{params.directory}` does not exist.", - brief="Directory not found", - ) - if not await dir_path.is_dir(): - return ToolError( - message=f"`{params.directory}` is not a directory.", - brief="Invalid directory", - ) - - # Perform the glob search - users can use ** directly in pattern - matches: list[KaosPath] = [] - async for match in dir_path.glob(params.pattern): - matches.append(match) - - # Filter out directories if not requested - if not params.include_dirs: - matches = [p for p in matches if await p.is_file()] - - # Sort for consistent output - matches.sort() - - # Limit matches - message = ( - f"Found {len(matches)} matches for pattern `{params.pattern}`." - if len(matches) > 0 - else f"No matches found for pattern `{params.pattern}`." - ) - if len(matches) > MAX_MATCHES: - matches = matches[:MAX_MATCHES] - message += ( - f" Only the first {MAX_MATCHES} matches are returned. " - "You may want to use a more specific pattern." - ) - - return ToolOk( - output="\n".join(str(p.relative_to(dir_path)) for p in matches), - message=message, - ) - - except Exception as e: - return ToolError( - message=f"Failed to search for pattern {params.pattern}. Error: {e}", - brief="Glob failed", - ) diff --git a/src/kimi_cli/tools/file/glob.ts b/src/kimi_cli/tools/file/glob.ts new file mode 100644 index 000000000..c01905747 --- /dev/null +++ b/src/kimi_cli/tools/file/glob.ts @@ -0,0 +1,97 @@ +/** + * Glob tool — find files and directories using glob patterns. + * Corresponds to Python tools/file/glob.py + */ + +import { z } from "zod/v4"; +import { globby } from "globby"; +import { CallableTool } from "../base.ts"; +import type { ToolContext, ToolResult } from "../types.ts"; +import { ToolError, ToolOk } from "../types.ts"; + +const MAX_MATCHES = 1000; + +const DESCRIPTION = `Find files and directories using glob patterns. This tool supports standard glob syntax like \`*\`, \`?\`, and \`**\` for recursive searches. + +**When to use:** +- Find files matching specific patterns (e.g., all Python files: \`*.py\`) +- Search for files recursively in subdirectories (e.g., \`src/**/*.js\`) +- Locate configuration files (e.g., \`*.config.*\`, \`*.json\`) + +**Bad example patterns:** +- \`**\`, \`**/*.py\` - Any pattern starting with '**' will be rejected. +- \`node_modules/**/*.js\` - Avoid recursively searching in large directories.`; + +const ParamsSchema = z.object({ + pattern: z.string().describe("Glob pattern to match files/directories."), + directory: z + .string() + .nullish() + .describe( + "Absolute path to the directory to search in (defaults to working directory).", + ), + include_dirs: z + .boolean() + .default(true) + .describe("Whether to include directories in results."), +}); + +type Params = z.infer; + +export class Glob extends CallableTool { + readonly name = "Glob"; + readonly description = DESCRIPTION; + readonly schema = ParamsSchema; + + async execute(params: Params, ctx: ToolContext): Promise { + try { + // Validate pattern safety + if (params.pattern.startsWith("**")) { + return ToolError( + `Pattern \`${params.pattern}\` starts with '**' which is not allowed. ` + + "This would recursively search all directories. Use more specific patterns instead.", + ); + } + + const dirPath = params.directory || ctx.workingDir; + + if (!dirPath.startsWith("/")) { + return ToolError( + `\`${params.directory}\` is not an absolute path. You must provide an absolute path to search.`, + ); + } + + // Perform the glob search + let matches = await globby(params.pattern, { + cwd: dirPath, + dot: true, + onlyFiles: !params.include_dirs, + ignore: [ + ".git", + ".svn", + ".hg", + "node_modules/**", + ], + }); + + // Sort for consistent output + matches.sort(); + + let message = + matches.length > 0 + ? `Found ${matches.length} matches for pattern \`${params.pattern}\`.` + : `No matches found for pattern \`${params.pattern}\`.`; + + if (matches.length > MAX_MATCHES) { + matches = matches.slice(0, MAX_MATCHES); + message += ` Only the first ${MAX_MATCHES} matches are returned. You may want to use a more specific pattern.`; + } + + return ToolOk(matches.join("\n"), message); + } catch (e) { + return ToolError( + `Failed to search for pattern ${params.pattern}. Error: ${e}`, + ); + } + } +} diff --git a/src/kimi_cli/tools/file/grep.md b/src/kimi_cli/tools/file/grep.md deleted file mode 100644 index bef02fb7a..000000000 --- a/src/kimi_cli/tools/file/grep.md +++ /dev/null @@ -1,5 +0,0 @@ -A powerful search tool based-on ripgrep. - -**Tips:** -- ALWAYS use Grep tool instead of running `grep` or `rg` command with Shell tool. -- Use the ripgrep pattern syntax, not grep syntax. E.g. you need to escape braces like `\\{` to search for `{`. diff --git a/src/kimi_cli/tools/file/grep.ts b/src/kimi_cli/tools/file/grep.ts new file mode 100644 index 000000000..be6f2abc1 --- /dev/null +++ b/src/kimi_cli/tools/file/grep.ts @@ -0,0 +1,339 @@ +/** + * Grep tool — regex search using ripgrep. + * Corresponds to Python tools/file/grep_local.py + */ + +import { z } from "zod/v4"; +import { CallableTool } from "../base.ts"; +import type { ToolContext, ToolResult } from "../types.ts"; +import { ToolError, ToolResultBuilder } from "../types.ts"; + +const RG_TIMEOUT = 20_000; // 20 seconds in ms +const RG_MAX_BUFFER = 20_000_000; // 20MB +const RG_KILL_GRACE = 5_000; // 5 seconds: SIGTERM → SIGKILL + +const DESCRIPTION = `A powerful search tool based on ripgrep. + +**Tips:** +- ALWAYS use Grep tool instead of running \`grep\` or \`rg\` command with Shell tool. +- Use the ripgrep pattern syntax, not grep syntax. E.g. you need to escape braces like \`\\{\` to search for \`{\`.`; + +const ParamsSchema = z.object({ + pattern: z + .string() + .describe( + "The regular expression pattern to search for in file contents", + ), + path: z + .string() + .default(".") + .describe( + "File or directory to search in. Defaults to current working directory.", + ), + glob: z + .string() + .nullish() + .describe("Glob pattern to filter files (e.g. `*.js`, `*.{ts,tsx}`)."), + output_mode: z + .string() + .default("files_with_matches") + .describe( + "`content`: Show matching lines; `files_with_matches`: Show file paths; `count_matches`: Show total number of matches.", + ), + "-B": z + .number() + .int() + .nullish() + .describe("Number of lines to show before each match."), + "-A": z + .number() + .int() + .nullish() + .describe("Number of lines to show after each match."), + "-C": z + .number() + .int() + .nullish() + .describe("Number of lines to show before and after each match."), + "-n": z.boolean().default(true).describe("Show line numbers in output."), + "-i": z.boolean().default(false).describe("Case insensitive search."), + type: z + .string() + .nullish() + .describe("File type to search (e.g. py, js, ts, go, java)."), + head_limit: z + .number() + .int() + .min(0) + .default(250) + .describe("Limit output to first N lines/entries. 0 for unlimited."), + offset: z + .number() + .int() + .min(0) + .default(0) + .describe("Skip first N lines/entries before applying head_limit."), + multiline: z + .boolean() + .default(false) + .describe("Enable multiline mode where `.` matches newlines."), +}); + +type Params = z.infer; + +function buildRgArgs( + params: Params, + searchPath: string, + opts?: { singleThreaded?: boolean }, +): string[] { + const args: string[] = ["rg"]; + + // Fixed args + if (params.output_mode !== "content") { + args.push("--max-columns", "500"); + } + args.push("--hidden"); + for (const vcsDir of [".git", ".svn", ".hg", ".bzr", ".jj", ".sl"]) { + args.push("--glob", `!${vcsDir}`); + } + + if (opts?.singleThreaded) { + args.push("-j", "1"); + } + + // Search options + if (params["-i"]) args.push("--ignore-case"); + if (params.multiline) args.push("--multiline", "--multiline-dotall"); + + // Content display options + if (params.output_mode === "content") { + if (params["-B"] != null) args.push("--before-context", String(params["-B"])); + if (params["-A"] != null) args.push("--after-context", String(params["-A"])); + if (params["-C"] != null) args.push("--context", String(params["-C"])); + if (params["-n"]) args.push("--line-number"); + } + + // File filtering + if (params.glob) args.push("--glob", params.glob); + if (params.type) args.push("--type", params.type); + + // Output mode + if (params.output_mode === "files_with_matches") { + args.push("--files-with-matches"); + } else if (params.output_mode === "count_matches") { + args.push("--count-matches"); + } + + // Pattern and path + args.push("--", params.pattern, searchPath); + + return args; +} + +function stripPathPrefix(output: string, searchBase: string): string { + const prefix = searchBase.replace(/[/\\]$/, "") + "/"; + return output + .split("\n") + .map((line) => (line.startsWith(prefix) ? line.slice(prefix.length) : line)) + .join("\n"); +} + +function isEagain(stderr: string): boolean { + return stderr.includes("os error 11") || stderr.includes("Resource temporarily unavailable"); +} + +/** Two-phase kill: SIGTERM → grace period → SIGKILL. */ +async function killProcess(proc: ReturnType): Promise { + proc.kill(); // SIGTERM + try { + await Promise.race([ + proc.exited, + new Promise((_, reject) => + setTimeout(() => reject(new Error("kill_grace_timeout")), RG_KILL_GRACE), + ), + ]); + } catch { + // Grace period expired, send SIGKILL + proc.kill(9); + await proc.exited; + } +} + +export class Grep extends CallableTool { + readonly name = "Grep"; + readonly description = DESCRIPTION; + readonly schema = ParamsSchema; + + async execute( + params: Params, + ctx: ToolContext, + opts?: { _retry?: boolean }, + ): Promise { + try { + const builder = new ToolResultBuilder(); + let message = ""; + + // Resolve the search path + let searchPath = params.path; + if (!searchPath.startsWith("/")) { + searchPath = `${ctx.workingDir}/${searchPath}`; + } + searchPath = searchPath.replace(/^~/, process.env.HOME || ""); + + const args = buildRgArgs(params, searchPath, { + singleThreaded: opts?._retry, + }); + + // Execute ripgrep using Bun.spawn + const proc = Bun.spawn(args, { + stdout: "pipe", + stderr: "pipe", + }); + + let timedOut = false; + let output: string; + let stderrStr: string; + + try { + const timeoutPromise = new Promise((_, reject) => { + setTimeout(() => reject(new Error("timeout")), RG_TIMEOUT); + }); + + const resultPromise = (async () => { + const stdoutBytes = await new Response(proc.stdout).arrayBuffer(); + const stderrBytes = await new Response(proc.stderr).arrayBuffer(); + return { + stdout: new TextDecoder().decode(stdoutBytes), + stderr: new TextDecoder().decode(stderrBytes), + }; + })(); + + const result = await Promise.race([resultPromise, timeoutPromise]); + output = result.stdout; + stderrStr = result.stderr; + await proc.exited; + } catch (e) { + if (e instanceof Error && e.message === "timeout") { + await killProcess(proc); + timedOut = true; + output = ""; + stderrStr = ""; + } else { + throw e; + } + } + + // Buffer truncation + let bufferTruncated = false; + if (output.length > RG_MAX_BUFFER) { + output = output.slice(0, RG_MAX_BUFFER); + const lastNl = output.lastIndexOf("\n"); + output = lastNl >= 0 ? output.slice(0, lastNl) : ""; + bufferTruncated = true; + message = "Output exceeded buffer limit. Some results omitted."; + } + + // Timeout handling + if (timedOut) { + if (!output.trim()) { + return ToolError( + `Grep timed out after ${RG_TIMEOUT / 1000}s. Try a more specific path or pattern.`, + ); + } + const timeoutMsg = `Grep timed out after ${RG_TIMEOUT / 1000}s. Partial results returned.`; + message = message ? `${message} ${timeoutMsg}` : timeoutMsg; + } + + // rg exit codes: 0=matches found, 1=no matches, 2+=error + if (!timedOut && proc.exitCode !== 0 && proc.exitCode !== 1) { + // EAGAIN: retry once with single-threaded mode + if (!opts?._retry && isEagain(stderrStr)) { + return this.execute(params, ctx, { _retry: true }); + } + return ToolError(`Failed to grep. Error: ${stderrStr}`); + } + + // Post-processing: strip path prefix + let searchBase = searchPath; + try { + const { stat } = await import("node:fs/promises"); + const info = await stat(searchBase); + if (info.isFile()) { + searchBase = searchBase.replace(/\/[^/]+$/, ""); + } + } catch { + // path doesn't exist or inaccessible, use as-is + } + output = stripPathPrefix(output, searchBase); + + // Split into lines + let lines = output.split("\n"); + if (lines.length > 0 && lines[lines.length - 1] === "") { + lines = lines.slice(0, -1); + } + + // Sort files_with_matches by mtime (most recently modified first) + if (!timedOut && params.output_mode === "files_with_matches" && lines.length > 0) { + const { stat: fsStat } = await import("node:fs/promises"); + const withMtime = await Promise.all( + lines.map(async (filePath) => { + try { + const fullPath = filePath.startsWith("/") ? filePath : `${searchBase}/${filePath}`; + const info = await fsStat(fullPath); + return { filePath, mtime: info.mtimeMs }; + } catch { + return { filePath, mtime: 0 }; + } + }), + ); + withMtime.sort((a, b) => b.mtime - a.mtime); + lines = withMtime.map((x) => x.filePath); + } + + // count_matches summary + if (params.output_mode === "count_matches") { + let totalMatches = 0; + let totalFiles = 0; + for (const line of lines) { + const idx = line.lastIndexOf(":"); + if (idx > 0) { + const count = parseInt(line.slice(idx + 1), 10); + if (!isNaN(count)) { + totalMatches += count; + totalFiles += 1; + } + } + } + const countSummary = `Found ${totalMatches} total occurrences across ${totalFiles} files.`; + message = message ? `${message} ${countSummary}` : countSummary; + } + + // Offset + head_limit pagination + if (params.offset > 0) { + lines = lines.slice(params.offset); + } + + const effectiveLimit = params.head_limit; + if (effectiveLimit && lines.length > effectiveLimit) { + const total = lines.length + params.offset; + lines = lines.slice(0, effectiveLimit); + output = lines.join("\n"); + const truncationMsg = + `Results truncated to ${effectiveLimit} lines (total: ${total}). ` + + `Use offset=${params.offset + effectiveLimit} to see more.`; + message = message ? `${message} ${truncationMsg}` : truncationMsg; + } else { + output = lines.join("\n"); + } + + if (!output && !bufferTruncated) { + return builder.ok("No matches found"); + } + + builder.write(output); + return builder.ok(message); + } catch (e) { + return ToolError(`Failed to grep. Error: ${String(e)}`); + } + } +} diff --git a/src/kimi_cli/tools/file/grep_local.py b/src/kimi_cli/tools/file/grep_local.py deleted file mode 100644 index 62afdd998..000000000 --- a/src/kimi_cli/tools/file/grep_local.py +++ /dev/null @@ -1,524 +0,0 @@ -""" -The local version of the Grep tool using ripgrep. -Be cautious that `KaosPath` is not used in this implementation. -""" - -import asyncio -import os -import platform -import shutil -import stat -import tarfile -import tempfile -import zipfile -from pathlib import Path -from typing import override - -import aiohttp -from kosong.tooling import CallableTool2, ToolError, ToolReturnValue -from pydantic import BaseModel, Field - -import kimi_cli -from kimi_cli.share import get_share_dir -from kimi_cli.tools.utils import ToolResultBuilder, load_desc -from kimi_cli.utils.aiohttp import new_client_session -from kimi_cli.utils.logging import logger - - -class Params(BaseModel): - pattern: str = Field( - description="The regular expression pattern to search for in file contents" - ) - path: str = Field( - description=( - "File or directory to search in. Defaults to current working directory. " - "If specified, it must be an absolute path." - ), - default=".", - ) - glob: str | None = Field( - description=( - "Glob pattern to filter files (e.g. `*.js`, `*.{ts,tsx}`). No filter by default." - ), - default=None, - ) - output_mode: str = Field( - description=( - "`content`: Show matching lines (supports `-B`, `-A`, `-C`, `-n`, `head_limit`); " - "`files_with_matches`: Show file paths (supports `head_limit`); " - "`count_matches`: Show total number of matches. " - "Defaults to `files_with_matches`." - ), - default="files_with_matches", - ) - before_context: int | None = Field( - alias="-B", - description=( - "Number of lines to show before each match (the `-B` option). " - "Requires `output_mode` to be `content`." - ), - default=None, - ) - after_context: int | None = Field( - alias="-A", - description=( - "Number of lines to show after each match (the `-A` option). " - "Requires `output_mode` to be `content`." - ), - default=None, - ) - context: int | None = Field( - alias="-C", - description=( - "Number of lines to show before and after each match (the `-C` option). " - "Requires `output_mode` to be `content`." - ), - default=None, - ) - line_number: bool = Field( - alias="-n", - description=( - "Show line numbers in output (the `-n` option). " - "Requires `output_mode` to be `content`. Defaults to true." - ), - default=True, - ) - ignore_case: bool = Field( - alias="-i", - description="Case insensitive search (the `-i` option).", - default=False, - ) - type: str | None = Field( - description=( - "File type to search. Examples: py, rust, js, ts, go, java, etc. " - "More efficient than `glob` for standard file types." - ), - default=None, - ) - head_limit: int | None = Field( - description=( - "Limit output to first N lines/entries, equivalent to `| head -N`. " - "Works across all output modes: content (limits output lines), " - "files_with_matches (limits file paths), count_matches (limits count entries). " - "Defaults to 250. " - "Pass 0 for unlimited (use sparingly — large result sets waste context)." - ), - default=250, - ge=0, - ) - offset: int = Field( - description=( - "Skip first N lines/entries before applying head_limit, " - "equivalent to `| tail -n +N | head -N`. " - "Works across all output modes. Defaults to 0." - ), - default=0, - ge=0, - ) - multiline: bool = Field( - description=( - "Enable multiline mode where `.` matches newlines and patterns can span " - "lines (the `-U` and `--multiline-dotall` options). " - "By default, multiline mode is disabled." - ), - default=False, - ) - - -RG_VERSION = "15.0.0" -RG_BASE_URL = "http://cdn.kimi.com/binaries/kimi-cli/rg" -RG_TIMEOUT = 20 # seconds -RG_MAX_BUFFER = 20_000_000 # 20MB stdout/stderr buffer limit -RG_KILL_GRACE = 5 # seconds: SIGTERM → SIGKILL -_RG_DOWNLOAD_LOCK = asyncio.Lock() - - -def _rg_binary_name() -> str: - return "rg.exe" if platform.system() == "Windows" else "rg" - - -def _find_existing_rg(bin_name: str) -> Path | None: - share_bin = get_share_dir() / "bin" / bin_name - if share_bin.is_file(): - return share_bin - - assert kimi_cli.__file__ is not None - local_dep = Path(kimi_cli.__file__).parent / "deps" / "bin" / bin_name - if local_dep.is_file(): - return local_dep - - system_rg = shutil.which("rg") - if system_rg: - return Path(system_rg) - - return None - - -def _detect_target() -> str | None: - sys_name = platform.system() - mach = platform.machine().lower() - - if mach in ("x86_64", "amd64"): - arch = "x86_64" - elif mach in ("arm64", "aarch64"): - arch = "aarch64" - else: - logger.error("Unsupported architecture for ripgrep: {mach}", mach=mach) - return None - - if sys_name == "Darwin": - os_name = "apple-darwin" - elif sys_name == "Linux": - os_name = "unknown-linux-musl" if arch == "x86_64" else "unknown-linux-gnu" - elif sys_name == "Windows": - os_name = "pc-windows-msvc" - else: - logger.error("Unsupported operating system for ripgrep: {sys_name}", sys_name=sys_name) - return None - - return f"{arch}-{os_name}" - - -async def _download_and_install_rg(bin_name: str) -> Path: - target = _detect_target() - if not target: - raise RuntimeError("Unsupported platform for ripgrep download") - - is_windows = "windows" in target - archive_ext = "zip" if is_windows else "tar.gz" - filename = f"ripgrep-{RG_VERSION}-{target}.{archive_ext}" - url = f"{RG_BASE_URL}/{filename}" - logger.info("Downloading ripgrep from {url}", url=url) - - share_bin_dir = get_share_dir() / "bin" - share_bin_dir.mkdir(parents=True, exist_ok=True) - destination = share_bin_dir / bin_name - - # Downloading the ripgrep binary can be slow on constrained networks. - download_timeout = aiohttp.ClientTimeout(total=600, sock_read=60, sock_connect=15) - async with new_client_session(timeout=download_timeout) as session: - with tempfile.TemporaryDirectory(prefix="kimi-rg-") as tmpdir: - tar_path = Path(tmpdir) / filename - - try: - async with session.get(url) as resp: - resp.raise_for_status() - with open(tar_path, "wb") as fh: - async for chunk in resp.content.iter_chunked(1024 * 64): - if chunk: - fh.write(chunk) - except (aiohttp.ClientError, TimeoutError) as exc: - raise RuntimeError("Failed to download ripgrep binary") from exc - - try: - if is_windows: - with zipfile.ZipFile(tar_path, "r") as zf: - member_name = next( - (name for name in zf.namelist() if Path(name).name == bin_name), - None, - ) - if not member_name: - raise RuntimeError("Ripgrep binary not found in archive") - with zf.open(member_name) as source, open(destination, "wb") as dest_fh: - shutil.copyfileobj(source, dest_fh) - else: - with tarfile.open(tar_path, "r:gz") as tar: - member = next( - (m for m in tar.getmembers() if Path(m.name).name == bin_name), - None, - ) - if not member: - raise RuntimeError("Ripgrep binary not found in archive") - extracted = tar.extractfile(member) - if not extracted: - raise RuntimeError("Failed to extract ripgrep binary") - with open(destination, "wb") as dest_fh: - shutil.copyfileobj(extracted, dest_fh) - except (zipfile.BadZipFile, tarfile.TarError, OSError) as exc: - raise RuntimeError("Failed to extract ripgrep archive") from exc - - destination.chmod(destination.stat().st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH) - logger.info("Installed ripgrep to {destination}", destination=destination) - return destination - - -async def _ensure_rg_path() -> str: - bin_name = _rg_binary_name() - existing = _find_existing_rg(bin_name) - if existing: - return str(existing) - - async with _RG_DOWNLOAD_LOCK: - existing = _find_existing_rg(bin_name) - if existing: - return str(existing) - - downloaded = await _download_and_install_rg(bin_name) - return str(downloaded) - - -def _build_rg_args(rg_path: str, params: Params, *, single_threaded: bool = False) -> list[str]: - """Build ripgrep command-line arguments from Params.""" - args: list[str] = [rg_path] - - # Fixed args - if params.output_mode != "content": - args.extend(["--max-columns", "500"]) - args.append("--hidden") - for vcs_dir in (".git", ".svn", ".hg", ".bzr", ".jj", ".sl"): - args.extend(["--glob", f"!{vcs_dir}"]) - - if single_threaded: - args.extend(["-j", "1"]) - - # Search options - if params.ignore_case: - args.append("--ignore-case") - if params.multiline: - args.extend(["--multiline", "--multiline-dotall"]) - - # Content display options (only for content mode) - if params.output_mode == "content": - if params.before_context is not None: - args.extend(["--before-context", str(params.before_context)]) - if params.after_context is not None: - args.extend(["--after-context", str(params.after_context)]) - if params.context is not None: - args.extend(["--context", str(params.context)]) - if params.line_number: - args.append("--line-number") - - # File filtering options - if params.glob: - args.extend(["--glob", params.glob]) - if params.type: - args.extend(["--type", params.type]) - - # Output mode - if params.output_mode == "files_with_matches": - args.append("--files-with-matches") - elif params.output_mode == "count_matches": - args.append("--count-matches") - - # Separate pattern from flags to avoid ambiguity (e.g. pattern starting with -) - args.append("--") - args.append(params.pattern) - args.append(os.path.expanduser(params.path)) - - return args - - -async def _read_stream( - stream: asyncio.StreamReader, - buffer: bytearray, - limit: int, - truncated_flag: list[bool] | None = None, -) -> bool: - """Incrementally read from stream into buffer, up to limit bytes. - - After hitting the limit, continues draining the pipe (discarding data) - so the child process doesn't block on a full pipe buffer. - - Args: - truncated_flag: If provided, truncated_flag[0] is set to True at the - moment truncation occurs (synchronously, before the next await). - This ensures the flag is available even if the coroutine is - cancelled by asyncio.wait_for timeout. - - Returns True if output was truncated (exceeded limit). - """ - truncated = False - while True: - chunk = await stream.read(65536) - if not chunk: - break - if len(buffer) < limit: - needed = limit - len(buffer) - buffer.extend(chunk[:needed]) - if len(chunk) > needed: - truncated = True - if truncated_flag is not None: - truncated_flag[0] = True - else: - truncated = True - if truncated_flag is not None: - truncated_flag[0] = True - return truncated - - -async def _kill_process(process: asyncio.subprocess.Process) -> None: - """Two-phase kill: SIGTERM → grace period → SIGKILL.""" - process.terminate() - try: - await asyncio.wait_for(process.wait(), timeout=RG_KILL_GRACE) - except TimeoutError: - process.kill() - await process.wait() - - -def _is_eagain(stderr: str) -> bool: - return "os error 11" in stderr or "Resource temporarily unavailable" in stderr - - -def _strip_path_prefix(output: str, search_base: str) -> str: - """Strip search_base prefix from each line to produce relative paths.""" - prefix = search_base.rstrip("/\\") + os.sep - return "\n".join( - line[len(prefix) :] if line.startswith(prefix) else line for line in output.split("\n") - ) - - -class Grep(CallableTool2[Params]): - name: str = "Grep" - description: str = load_desc(Path(__file__).parent / "grep.md") - params: type[Params] = Params - - @override - async def __call__(self, params: Params, *, _retry: bool = False) -> ToolReturnValue: - try: - builder = ToolResultBuilder() - message = "" - - # Build rg command - rg_path = await _ensure_rg_path() - logger.debug("Using ripgrep binary: {rg_bin}", rg_bin=rg_path) - args = _build_rg_args(rg_path, params, single_threaded=_retry) - - # Execute search as async subprocess (non-blocking, cancellable) - process = await asyncio.create_subprocess_exec( - *args, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - ) - - # Stream stdout/stderr incrementally with buffer limit - stdout_buf = bytearray() - stderr_buf = bytearray() - timed_out = False - stdout_truncated_flag: list[bool] = [False] - - try: - assert process.stdout is not None - assert process.stderr is not None - await asyncio.wait_for( - asyncio.gather( - _read_stream( - process.stdout, stdout_buf, RG_MAX_BUFFER, stdout_truncated_flag - ), - _read_stream(process.stderr, stderr_buf, RG_MAX_BUFFER), - ), - timeout=RG_TIMEOUT, - ) - await process.wait() - except asyncio.CancelledError: - await _kill_process(process) - raise - except TimeoutError: - await _kill_process(process) - timed_out = True - - output = stdout_buf.decode("utf-8", errors="replace") - stderr_str = stderr_buf.decode("utf-8", errors="replace") - - # truncated_flag is set synchronously inside _read_stream at - # the moment of truncation, so it's available even after timeout. - buffer_truncated = stdout_truncated_flag[0] - - # Drop last incomplete line if buffer was truncated - if buffer_truncated: - last_nl = output.rfind("\n") - output = output[:last_nl] if last_nl >= 0 else "" - message = "Output exceeded buffer limit. Some results omitted." - - # Timeout: return partial results if available, otherwise error - if timed_out: - if not output.strip(): - return ToolError( - message=( - f"Grep timed out after {RG_TIMEOUT}s. " - "Try a more specific path or pattern." - ), - brief="Grep timed out", - ) - timeout_msg = f"Grep timed out after {RG_TIMEOUT}s. Partial results returned." - message = f"{message} {timeout_msg}" if message else timeout_msg - - # rg exit codes: 0=matches found, 1=no matches, 2+=error - if not timed_out and process.returncode not in (0, 1): - # EAGAIN: retry once with single-threaded mode - if not _retry and _is_eagain(stderr_str): - logger.warning("rg EAGAIN error, retrying with -j 1") - return await self.__call__(params, _retry=True) - return ToolError( - message=f"Failed to grep. Error: {stderr_str}", - brief="Failed to grep", - ) - - # --- Post-processing pipeline --- - - # Step 1: mtime sorting (files_with_matches only, skip on timeout) - if not timed_out and params.output_mode == "files_with_matches": - lines = [x for x in output.split("\n") if x.strip()] - lines.sort( - key=lambda p: os.path.getmtime(p) if os.path.exists(p) else 0, - reverse=True, - ) - output = "\n".join(lines) - - # Step 2: shorten paths to relative (prefix stripping) - search_base = os.path.abspath(os.path.expanduser(params.path)) - if os.path.isfile(search_base): - search_base = os.path.dirname(search_base) - output = _strip_path_prefix(output, search_base) - - # Step 3: count_matches summary (before pagination, on full results) - lines = output.split("\n") - if lines and lines[-1] == "": - lines = lines[:-1] - - if params.output_mode == "count_matches": - total_matches = 0 - total_files = 0 - for line in lines: - idx = line.rfind(":") - if idx > 0: - try: - total_matches += int(line[idx + 1 :]) - total_files += 1 - except ValueError: - pass - count_summary = ( - f"Found {total_matches} total occurrences across {total_files} files." - ) - message = f"{message} {count_summary}" if message else count_summary - - # Step 4: offset + head_limit pagination - if params.offset > 0: - lines = lines[params.offset :] - - effective_limit = params.head_limit - if effective_limit and len(lines) > effective_limit: - total = len(lines) + params.offset - lines = lines[:effective_limit] - output = "\n".join(lines) - truncation_msg = ( - f"Results truncated to {effective_limit} lines (total: {total}). " - f"Use offset={params.offset + effective_limit} to see more." - ) - message = f"{message} {truncation_msg}" if message else truncation_msg - else: - output = "\n".join(lines) - - if not output and not buffer_truncated: - return builder.ok(message="No matches found") - - builder.write(output) - return builder.ok(message=message) - - except asyncio.CancelledError: - raise - except Exception as e: - return ToolError( - message=f"Failed to grep. Error: {str(e)}", - brief="Failed to grep", - ) diff --git a/src/kimi_cli/tools/file/plan_mode.py b/src/kimi_cli/tools/file/plan_mode.py deleted file mode 100644 index fc25bf56a..000000000 --- a/src/kimi_cli/tools/file/plan_mode.py +++ /dev/null @@ -1,45 +0,0 @@ -from __future__ import annotations - -from collections.abc import Callable -from dataclasses import dataclass -from pathlib import Path - -from kaos.path import KaosPath -from kosong.tooling import ToolError - - -@dataclass(frozen=True) -class PlanEditTarget: - active: bool - plan_path: Path | None - is_plan_target: bool - - -def inspect_plan_edit_target( - path: KaosPath, - *, - plan_mode_checker: Callable[[], bool] | None, - plan_file_path_getter: Callable[[], Path | None] | None, -) -> PlanEditTarget | ToolError: - """Resolve whether a file edit is targeting the current plan artifact.""" - if plan_mode_checker is None or not plan_mode_checker(): - return PlanEditTarget(active=False, plan_path=None, is_plan_target=False) - - plan_path = plan_file_path_getter() if plan_file_path_getter is not None else None - if plan_path is None: - return ToolError( - message="Plan mode is active, but the current plan file is unavailable.", - brief="Plan file unavailable", - ) - - canonical_plan_path = KaosPath(str(plan_path)).canonical() - if str(path) != str(canonical_plan_path): - return ToolError( - message=( - "Plan mode is active. You may only edit the current plan file: " - f"`{canonical_plan_path}`." - ), - brief="Plan mode restriction", - ) - - return PlanEditTarget(active=True, plan_path=plan_path, is_plan_target=True) diff --git a/src/kimi_cli/tools/file/plan_mode.ts b/src/kimi_cli/tools/file/plan_mode.ts new file mode 100644 index 000000000..9d1e4d595 --- /dev/null +++ b/src/kimi_cli/tools/file/plan_mode.ts @@ -0,0 +1,50 @@ +/** + * Plan mode edit target validation. + * Corresponds to Python tools/file/plan_mode.py + */ + +import { resolve } from "node:path"; +import type { ToolResult } from "../types.ts"; +import { ToolError } from "../types.ts"; + +export interface PlanEditTarget { + active: boolean; + planPath: string | null; + isPlanTarget: boolean; +} + +/** + * Resolve whether a file edit is targeting the current plan artifact. + * Returns a PlanEditTarget on success, or a ToolResult error on failure. + */ +export function inspectPlanEditTarget( + filePath: string, + opts: { + planModeChecker?: () => boolean; + planFilePathGetter?: () => string | null; + }, +): PlanEditTarget | ToolResult { + const { planModeChecker, planFilePathGetter } = opts; + + if (!planModeChecker || !planModeChecker()) { + return { active: false, planPath: null, isPlanTarget: false }; + } + + const planPath = planFilePathGetter?.() ?? null; + if (planPath === null) { + return ToolError( + "Plan mode is active, but the current plan file is unavailable.", + ); + } + + const canonicalPlanPath = resolve(planPath); + const canonicalFilePath = resolve(filePath); + + if (canonicalFilePath !== canonicalPlanPath) { + return ToolError( + `Plan mode is active. You may only edit the current plan file: \`${canonicalPlanPath}\`.`, + ); + } + + return { active: true, planPath, isPlanTarget: true }; +} diff --git a/src/kimi_cli/tools/file/read.md b/src/kimi_cli/tools/file/read.md deleted file mode 100644 index 57e08a24c..000000000 --- a/src/kimi_cli/tools/file/read.md +++ /dev/null @@ -1,14 +0,0 @@ -Read text content from a file. - -**Tips:** -- Make sure you follow the description of each tool parameter. -- A `` tag will be given before the read file content. -- The system will notify you when there is anything wrong when reading the file. -- This tool is a tool that you typically want to use in parallel. Always read multiple files in one response when possible. -- This tool can only read text files. To read images or videos, use other appropriate tools. To list directories, use the Glob tool or `ls` command via the Shell tool. To read other file types, use appropriate commands via the Shell tool. -- If the file doesn't exist or path is invalid, an error will be returned. -- If you want to search for a certain content/pattern, prefer Grep tool over ReadFile. -- Content will be returned with a line number before each line like `cat -n` format. -- Use `line_offset` and `n_lines` parameters when you only need to read a part of the file. -- The maximum number of lines that can be read at once is ${MAX_LINES}. -- Any lines longer than ${MAX_LINE_LENGTH} characters will be truncated, ending with "...". diff --git a/src/kimi_cli/tools/file/read.py b/src/kimi_cli/tools/file/read.py deleted file mode 100644 index 1f0c0da81..000000000 --- a/src/kimi_cli/tools/file/read.py +++ /dev/null @@ -1,189 +0,0 @@ -from pathlib import Path -from typing import override - -from kaos.path import KaosPath -from kosong.tooling import CallableTool2, ToolError, ToolOk, ToolReturnValue -from pydantic import BaseModel, Field - -from kimi_cli.soul.agent import Runtime -from kimi_cli.tools.file.utils import MEDIA_SNIFF_BYTES, detect_file_type -from kimi_cli.tools.utils import load_desc, truncate_line -from kimi_cli.utils.path import is_within_workspace - -MAX_LINES = 1000 -MAX_LINE_LENGTH = 2000 -MAX_BYTES = 100 << 10 # 100KB - - -class Params(BaseModel): - path: str = Field( - description=( - "The path to the file to read. Absolute paths are required when reading files " - "outside the working directory." - ) - ) - line_offset: int = Field( - description=( - "The line number to start reading from. " - "By default read from the beginning of the file. " - "Set this when the file is too large to read at once." - ), - default=1, - ge=1, - ) - n_lines: int = Field( - description=( - "The number of lines to read. " - f"By default read up to {MAX_LINES} lines, which is the max allowed value. " - "Set this value when the file is too large to read at once." - ), - default=MAX_LINES, - ge=1, - ) - - -class ReadFile(CallableTool2[Params]): - name: str = "ReadFile" - params: type[Params] = Params - - def __init__(self, runtime: Runtime) -> None: - description = load_desc( - Path(__file__).parent / "read.md", - { - "MAX_LINES": MAX_LINES, - "MAX_LINE_LENGTH": MAX_LINE_LENGTH, - "MAX_BYTES": MAX_BYTES, - }, - ) - super().__init__(description=description) - self._runtime = runtime - self._work_dir = runtime.builtin_args.KIMI_WORK_DIR - self._additional_dirs = runtime.additional_dirs - - async def _validate_path(self, path: KaosPath) -> ToolError | None: - """Validate that the path is safe to read.""" - resolved_path = path.canonical() - - if ( - not is_within_workspace(resolved_path, self._work_dir, self._additional_dirs) - and not path.is_absolute() - ): - # Outside files can only be read with absolute paths - return ToolError( - message=( - f"`{path}` is not an absolute path. " - "You must provide an absolute path to read a file " - "outside the working directory." - ), - brief="Invalid path", - ) - return None - - @override - async def __call__(self, params: Params) -> ToolReturnValue: - # TODO: checks: - # - check if the path may contain secrets - - if not params.path: - return ToolError( - message="File path cannot be empty.", - brief="Empty file path", - ) - - try: - p = KaosPath(params.path).expanduser() - if err := await self._validate_path(p): - return err - p = p.canonical() - - if not await p.exists(): - return ToolError( - message=f"`{params.path}` does not exist.", - brief="File not found", - ) - if not await p.is_file(): - return ToolError( - message=f"`{params.path}` is not a file.", - brief="Invalid path", - ) - - header = await p.read_bytes(MEDIA_SNIFF_BYTES) - file_type = detect_file_type(str(p), header=header) - if file_type.kind in ("image", "video"): - return ToolError( - message=( - f"`{params.path}` is a {file_type.kind} file. " - "Use other appropriate tools to read image or video files." - ), - brief="Unsupported file type", - ) - - if file_type.kind == "unknown": - return ToolError( - message=( - f"`{params.path}` seems not readable. " - "You may need to read it with proper shell commands, Python tools " - "or MCP tools if available. " - "If you read/operate it with Python, you MUST ensure that any " - "third-party packages are installed in a virtual environment (venv)." - ), - brief="File not readable", - ) - - assert params.line_offset >= 1 - assert params.n_lines >= 1 - - lines: list[str] = [] - n_bytes = 0 - truncated_line_numbers: list[int] = [] - max_lines_reached = False - max_bytes_reached = False - current_line_no = 0 - async for line in p.read_lines(errors="replace"): - current_line_no += 1 - if current_line_no < params.line_offset: - continue - truncated = truncate_line(line, MAX_LINE_LENGTH) - if truncated != line: - truncated_line_numbers.append(current_line_no) - lines.append(truncated) - n_bytes += len(truncated.encode("utf-8")) - if len(lines) >= params.n_lines: - break - if len(lines) >= MAX_LINES: - max_lines_reached = True - break - if n_bytes >= MAX_BYTES: - max_bytes_reached = True - break - - # Format output with line numbers like `cat -n` - lines_with_no: list[str] = [] - for line_num, line in zip( - range(params.line_offset, params.line_offset + len(lines)), lines, strict=True - ): - # Use 6-digit line number width, right-aligned, with tab separator - lines_with_no.append(f"{line_num:6d}\t{line}") - - message = ( - f"{len(lines)} lines read from file starting from line {params.line_offset}." - if len(lines) > 0 - else "No lines read from file." - ) - if max_lines_reached: - message += f" Max {MAX_LINES} lines reached." - elif max_bytes_reached: - message += f" Max {MAX_BYTES} bytes reached." - elif len(lines) < params.n_lines: - message += " End of file reached." - if truncated_line_numbers: - message += f" Lines {truncated_line_numbers} were truncated." - return ToolOk( - output="".join(lines_with_no), # lines already contain \n, just join them - message=message, - ) - except Exception as e: - return ToolError( - message=f"Failed to read {params.path}. Error: {e}", - brief="Failed to read file", - ) diff --git a/src/kimi_cli/tools/file/read.ts b/src/kimi_cli/tools/file/read.ts new file mode 100644 index 000000000..5fff28202 --- /dev/null +++ b/src/kimi_cli/tools/file/read.ts @@ -0,0 +1,240 @@ +/** + * ReadFile tool — read text content from a file. + * Corresponds to Python tools/file/read.py + */ + +import { z } from "zod/v4"; +import { CallableTool } from "../base.ts"; +import type { ToolContext, ToolResult } from "../types.ts"; +import { ToolError, ToolOk } from "../types.ts"; + +const MAX_LINES = 1000; +const MAX_LINE_LENGTH = 2000; +const MAX_BYTES = 100 * 1024; // 100KB + +const DESCRIPTION = `Read text content from a file. + +**Tips:** +- A \`\` tag will be given before the read file content. +- This tool can only read text files. +- Content will be returned with a line number before each line like \`cat -n\` format. +- Use \`line_offset\` and \`n_lines\` parameters when you only need to read a part of the file. +- The maximum number of lines that can be read at once is ${MAX_LINES}. +- Any lines longer than ${MAX_LINE_LENGTH} characters will be truncated, ending with "...".`; + +const ParamsSchema = z.object({ + path: z.string().describe( + "The path to the file to read. Absolute paths are required when reading files outside the working directory.", + ), + line_offset: z + .number() + .int() + .min(1) + .default(1) + .describe("The line number to start reading from. Defaults to 1."), + n_lines: z + .number() + .int() + .min(1) + .default(MAX_LINES) + .describe( + `The number of lines to read. Defaults to ${MAX_LINES} (max allowed).`, + ), +}); + +type Params = z.infer; + +function truncateLine(line: string, maxLength: number): string { + if (line.length <= maxLength) return line; + return line.slice(0, maxLength - 3) + "..."; +} + +function resolvePath(filePath: string, workingDir: string): string { + if (filePath.startsWith("/") || filePath.startsWith("~")) { + if (filePath.startsWith("~")) { + const home = process.env.HOME || process.env.USERPROFILE || ""; + return filePath.replace(/^~/, home); + } + return filePath; + } + return `${workingDir}/${filePath}`; +} + +// ── Binary file detection ───────────────────────────── +// Magic byte signatures for common binary formats + +const BINARY_SIGNATURES: Array<{ bytes: number[]; type: string }> = [ + { bytes: [0x89, 0x50, 0x4e, 0x47], type: "PNG image" }, + { bytes: [0xff, 0xd8, 0xff], type: "JPEG image" }, + { bytes: [0x47, 0x49, 0x46, 0x38], type: "GIF image" }, + { bytes: [0x52, 0x49, 0x46, 0x46], type: "RIFF (WebP/AVI)" }, + { bytes: [0x50, 0x4b, 0x03, 0x04], type: "ZIP archive" }, + { bytes: [0x1f, 0x8b], type: "gzip archive" }, + { bytes: [0x25, 0x50, 0x44, 0x46], type: "PDF document" }, + { bytes: [0x7f, 0x45, 0x4c, 0x46], type: "ELF binary" }, + { bytes: [0xfe, 0xed, 0xfa, 0xce], type: "Mach-O binary" }, + { bytes: [0xfe, 0xed, 0xfa, 0xcf], type: "Mach-O binary (64-bit)" }, + { bytes: [0xce, 0xfa, 0xed, 0xfe], type: "Mach-O binary (reverse)" }, + { bytes: [0xcf, 0xfa, 0xed, 0xfe], type: "Mach-O binary (64-bit reverse)" }, + { bytes: [0xca, 0xfe, 0xba, 0xbe], type: "Mach-O universal binary" }, + { bytes: [0x4d, 0x5a], type: "Windows executable" }, +]; + +const NON_TEXT_EXTENSIONS = new Set([ + // Images + "png", "jpg", "jpeg", "gif", "bmp", "ico", "webp", "svg", "tiff", "tif", "avif", "heic", "heif", + // Video + "mp4", "mkv", "avi", "mov", "wmv", "flv", "webm", "m4v", "3gp", + // Audio + "mp3", "wav", "ogg", "flac", "aac", "wma", "m4a", "opus", + // Archives + "zip", "tar", "gz", "bz2", "xz", "7z", "rar", "zst", + // Binaries + "exe", "dll", "so", "dylib", "o", "a", "lib", "bin", "dat", + // Documents (binary) + "pdf", "doc", "docx", "xls", "xlsx", "ppt", "pptx", + // Database + "db", "sqlite", "sqlite3", + // Fonts + "ttf", "otf", "woff", "woff2", "eot", + // Other + "pyc", "pyo", "class", "jar", "war", "deb", "rpm", "dmg", "iso", "img", +]); + +function detectBinaryType(resolvedPath: string, headerBytes: Uint8Array): string | null { + // Check magic bytes + for (const sig of BINARY_SIGNATURES) { + if (sig.bytes.every((b, i) => headerBytes[i] === b)) { + return sig.type; + } + } + + // Check for NUL bytes in first 8KB (strong binary indicator) + for (let i = 0; i < Math.min(headerBytes.length, 8192); i++) { + if (headerBytes[i] === 0x00) { + return "binary file (contains NUL bytes)"; + } + } + + // Check extension + const ext = resolvedPath.split(".").pop()?.toLowerCase() ?? ""; + if (NON_TEXT_EXTENSIONS.has(ext)) { + return `binary file (.${ext})`; + } + + return null; +} + +export class ReadFile extends CallableTool { + readonly name = "ReadFile"; + readonly description = DESCRIPTION; + readonly schema = ParamsSchema; + + async execute(params: Params, ctx: ToolContext): Promise { + if (!params.path) { + return ToolError("File path cannot be empty."); + } + + try { + const resolvedPath = resolvePath(params.path, ctx.workingDir); + const file = Bun.file(resolvedPath); + + if (!(await file.exists())) { + return ToolError(`\`${params.path}\` does not exist.`); + } + + // Check if it's a directory + const { stat } = await import("node:fs/promises"); + try { + const info = await stat(resolvedPath); + if (info.isDirectory()) { + return ToolError(`\`${params.path}\` is a directory, not a file. Use the Glob tool to list directory contents.`); + } + } catch { + // stat failed — continue, file.text() will catch it + } + + // Binary detection: read header bytes first + const headerSize = Math.min(file.size, 8192); + if (headerSize > 0) { + const headerBuf = await file.slice(0, headerSize).arrayBuffer(); + const headerBytes = new Uint8Array(headerBuf); + const binaryType = detectBinaryType(resolvedPath, headerBytes); + if (binaryType) { + return ToolError( + `\`${params.path}\` is a ${binaryType}. This tool can only read text files. ` + + `Use the Shell tool if you need to inspect binary files (e.g. \`file\`, \`hexdump\`).` + ); + } + } + + // Read file content — stream to avoid OOM on large files + const text = await file.text(); + const allLines = text.split("\n"); + + const lineOffset = params.line_offset; + const nLines = params.n_lines; + + const lines: string[] = []; + const truncatedLineNumbers: number[] = []; + let nBytes = 0; + let maxLinesReached = false; + let maxBytesReached = false; + + for ( + let i = lineOffset - 1; + i < allLines.length && lines.length < nLines; + i++ + ) { + const lineNo = i + 1; + let line = allLines[i] ?? ""; + // Add newline back except for last line if original doesn't end with \n + if (i < allLines.length - 1 || text.endsWith("\n")) { + line += "\n"; + } + + const truncated = truncateLine(line, MAX_LINE_LENGTH); + if (truncated !== line) { + truncatedLineNumbers.push(lineNo); + } + lines.push(truncated); + nBytes += new TextEncoder().encode(truncated).length; + + if (lines.length >= MAX_LINES) { + maxLinesReached = true; + break; + } + if (nBytes >= MAX_BYTES) { + maxBytesReached = true; + break; + } + } + + // Format output with line numbers (cat -n format) + const linesWithNo = lines.map((line: string, idx: number) => { + const lineNum = lineOffset + idx; + return `${String(lineNum).padStart(6)}\t${line}`; + }); + + let message = + lines.length > 0 + ? `${lines.length} lines read from file starting from line ${lineOffset}.` + : "No lines read from file."; + + if (maxLinesReached) { + message += ` Max ${MAX_LINES} lines reached.`; + } else if (maxBytesReached) { + message += ` Max ${MAX_BYTES} bytes reached.`; + } else if (lines.length < nLines) { + message += " End of file reached."; + } + if (truncatedLineNumbers.length > 0) { + message += ` Lines [${truncatedLineNumbers.join(", ")}] were truncated.`; + } + + return ToolOk(linesWithNo.join(""), message); + } catch (e) { + return ToolError(`Failed to read ${params.path}. Error: ${e}`); + } + } +} diff --git a/src/kimi_cli/tools/file/read_media.md b/src/kimi_cli/tools/file/read_media.md deleted file mode 100644 index 7fdc51f17..000000000 --- a/src/kimi_cli/tools/file/read_media.md +++ /dev/null @@ -1,24 +0,0 @@ -Read media content from a file. - -**Tips:** -- Make sure you follow the description of each tool parameter. -- A `` tag will be given before the read file content. -- The system will notify you when there is anything wrong when reading the file. -- This tool is a tool that you typically want to use in parallel. Always read multiple files in one response when possible. -- This tool can only read image or video files. To read other types of files, use the ReadFile tool. To list directories, use the Glob tool or `ls` command via the Shell tool. -- If the file doesn't exist or path is invalid, an error will be returned. -- The maximum size that can be read is ${MAX_MEDIA_MEGABYTES}MB. An error will be returned if the file is larger than this limit. -- The media content will be returned in a form that you can directly view and understand. - -**Capabilities** -{% if "image_in" in capabilities and "video_in" in capabilities %} -- This tool supports image and video files for the current model. -{% elif "image_in" in capabilities %} -- This tool supports image files for the current model. -- Video files are not supported by the current model. -{% elif "video_in" in capabilities %} -- This tool supports video files for the current model. -- Image files are not supported by the current model. -{% else %} -- The current model does not support image or video input. -{% endif %} diff --git a/src/kimi_cli/tools/file/read_media.py b/src/kimi_cli/tools/file/read_media.py deleted file mode 100644 index 209854985..000000000 --- a/src/kimi_cli/tools/file/read_media.py +++ /dev/null @@ -1,215 +0,0 @@ -import base64 -from io import BytesIO -from pathlib import Path -from typing import override - -from kaos.path import KaosPath -from kosong.chat_provider.kimi import Kimi -from kosong.tooling import CallableTool2, ToolError, ToolOk, ToolReturnValue -from pydantic import BaseModel, Field - -from kimi_cli.soul.agent import Runtime -from kimi_cli.tools import SkipThisTool -from kimi_cli.tools.file.utils import MEDIA_SNIFF_BYTES, FileType, detect_file_type -from kimi_cli.tools.utils import load_desc -from kimi_cli.utils.media_tags import wrap_media_part -from kimi_cli.utils.path import is_within_workspace -from kimi_cli.wire.types import ImageURLPart, VideoURLPart - -MAX_MEDIA_MEGABYTES = 100 - - -def _to_data_url(mime_type: str, data: bytes) -> str: - encoded = base64.b64encode(data).decode("ascii") - return f"data:{mime_type};base64,{encoded}" - - -def _extract_image_size(data: bytes) -> tuple[int, int] | None: - try: - from PIL import Image - except Exception: - return None - try: - with Image.open(BytesIO(data)) as image: - image.load() - return image.size - except Exception: - return None - - -class Params(BaseModel): - path: str = Field( - description=( - "The path to the file to read. Absolute paths are required when reading files " - "outside the working directory." - ) - ) - - -class ReadMediaFile(CallableTool2[Params]): - name: str = "ReadMediaFile" - params: type[Params] = Params - - def __init__(self, runtime: Runtime) -> None: - capabilities = runtime.llm.capabilities if runtime.llm else set[str]() - if "image_in" not in capabilities and "video_in" not in capabilities: - raise SkipThisTool() - - description = load_desc( - Path(__file__).parent / "read_media.md", - { - "MAX_MEDIA_MEGABYTES": MAX_MEDIA_MEGABYTES, - "capabilities": capabilities, - }, - ) - super().__init__(description=description) - - self._runtime = runtime - self._work_dir = runtime.builtin_args.KIMI_WORK_DIR - self._additional_dirs = runtime.additional_dirs - self._capabilities = capabilities - - async def _validate_path(self, path: KaosPath) -> ToolError | None: - """Validate that the path is safe to read.""" - resolved_path = path.canonical() - - if ( - not is_within_workspace(resolved_path, self._work_dir, self._additional_dirs) - and not path.is_absolute() - ): - # Outside files can only be read with absolute paths - return ToolError( - message=( - f"`{path}` is not an absolute path. " - "You must provide an absolute path to read a file " - "outside the working directory." - ), - brief="Invalid path", - ) - return None - - async def _read_media(self, path: KaosPath, file_type: FileType) -> ToolReturnValue: - assert file_type.kind in ("image", "video") - - media_path = str(path) - stat = await path.stat() - size = stat.st_size - if size == 0: - return ToolError( - message=f"`{path}` is empty.", - brief="Empty file", - ) - if size > (MAX_MEDIA_MEGABYTES << 20): - return ToolError( - message=( - f"`{path}` is {size} bytes, which exceeds the max " - f"{MAX_MEDIA_MEGABYTES}MB bytes for media files." - ), - brief="File too large", - ) - - match file_type.kind: - case "image": - data = await path.read_bytes() - data_url = _to_data_url(file_type.mime_type, data) - part = ImageURLPart(image_url=ImageURLPart.ImageURL(url=data_url)) - wrapped = wrap_media_part(part, tag="image", attrs={"path": media_path}) - image_size = _extract_image_size(data) - case "video": - data = await path.read_bytes() - if (llm := self._runtime.llm) and isinstance(llm.chat_provider, Kimi): - part = await llm.chat_provider.files.upload_video( - data=data, - mime_type=file_type.mime_type, - ) - wrapped = wrap_media_part(part, tag="video", attrs={"path": media_path}) - else: - data_url = _to_data_url(file_type.mime_type, data) - part = VideoURLPart(video_url=VideoURLPart.VideoURL(url=data_url)) - wrapped = wrap_media_part(part, tag="video", attrs={"path": media_path}) - image_size = None - - size_hint = "" - if image_size: - size_hint = f", original size {image_size[0]}x{image_size[1]}px" - note = ( - " If you need to output coordinates, output relative coordinates first and " - "compute absolute coordinates using the original image size; if you generate or " - "edit images/videos via commands or scripts, read the result back immediately " - "before continuing." - ) - return ToolOk( - output=wrapped, - message=( - f"Loaded {file_type.kind} file `{path}` " - f"({file_type.mime_type}, {size} bytes{size_hint}).{note}" - ), - ) - - @override - async def __call__(self, params: Params) -> ToolReturnValue: - if not params.path: - return ToolError( - message="File path cannot be empty.", - brief="Empty file path", - ) - - try: - p = KaosPath(params.path).expanduser() - if err := await self._validate_path(p): - return err - p = p.canonical() - - if not await p.exists(): - return ToolError( - message=f"`{params.path}` does not exist.", - brief="File not found", - ) - if not await p.is_file(): - return ToolError( - message=f"`{params.path}` is not a file.", - brief="Invalid path", - ) - - header = await p.read_bytes(MEDIA_SNIFF_BYTES) - file_type = detect_file_type(str(p), header=header) - if file_type.kind == "text": - return ToolError( - message=f"`{params.path}` is a text file. Use ReadFile to read text files.", - brief="Unsupported file type", - ) - if file_type.kind == "unknown": - return ToolError( - message=( - f"`{params.path}` seems not readable as an image or video file. " - "You may need to read it with proper shell commands, Python tools " - "or MCP tools if available. " - "If you read/operate it with Python, you MUST ensure that any " - "third-party packages are installed in a virtual environment (venv)." - ), - brief="File not readable", - ) - - if file_type.kind == "image" and "image_in" not in self._capabilities: - return ToolError( - message=( - "The current model does not support image input. " - "Tell the user to use a model with image input capability." - ), - brief="Unsupported media type", - ) - if file_type.kind == "video" and "video_in" not in self._capabilities: - return ToolError( - message=( - "The current model does not support video input. " - "Tell the user to use a model with video input capability." - ), - brief="Unsupported media type", - ) - - return await self._read_media(p, file_type) - except Exception as e: - return ToolError( - message=f"Failed to read {params.path}. Error: {e}", - brief="Failed to read file", - ) diff --git a/src/kimi_cli/tools/file/read_media.ts b/src/kimi_cli/tools/file/read_media.ts new file mode 100644 index 000000000..5f512185f --- /dev/null +++ b/src/kimi_cli/tools/file/read_media.ts @@ -0,0 +1,115 @@ +/** + * ReadMediaFile tool — read images and videos. + * Corresponds to Python tools/file/read_media.py + */ + +import { resolve } from "node:path"; +import { z } from "zod/v4"; +import { CallableTool } from "../base.ts"; +import type { ToolContext, ToolResult } from "../types.ts"; +import { ToolError, ToolOk } from "../types.ts"; +import { MEDIA_SNIFF_BYTES, detectFileType, type FileType } from "./utils.ts"; + +const MAX_MEDIA_MEGABYTES = 100; + +const DESCRIPTION = `Read an image or video file from disk. + +**Tips:** +- Use this tool to view images and videos directly. +- Maximum file size: ${MAX_MEDIA_MEGABYTES}MB. +- For text files, use ReadFile instead.`; + +const ParamsSchema = z.object({ + path: z.string().describe( + "The path to the file to read. Absolute paths are required when reading files outside the working directory.", + ), +}); + +type Params = z.infer; + +function toDataUrl(mimeType: string, data: Uint8Array): string { + const base64 = Buffer.from(data).toString("base64"); + return `data:${mimeType};base64,${base64}`; +} + +function resolvePath(filePath: string, workingDir: string): string { + if (filePath.startsWith("/") || filePath.startsWith("~")) { + if (filePath.startsWith("~")) { + const home = process.env.HOME || process.env.USERPROFILE || ""; + return filePath.replace(/^~/, home); + } + return filePath; + } + return resolve(workingDir, filePath); +} + +export class ReadMediaFile extends CallableTool { + readonly name = "ReadMediaFile"; + readonly description = DESCRIPTION; + readonly schema = ParamsSchema; + + async execute(params: Params, ctx: ToolContext): Promise { + if (!params.path) { + return ToolError("File path cannot be empty."); + } + + try { + const resolvedPath = resolvePath(params.path, ctx.workingDir); + const file = Bun.file(resolvedPath); + + if (!(await file.exists())) { + return ToolError(`\`${params.path}\` does not exist.`); + } + + const { stat: fsStat } = await import("node:fs/promises"); + const info = await fsStat(resolvedPath); + if (!info.isFile()) { + return ToolError(`\`${params.path}\` is not a file.`); + } + + const size = info.size; + if (size === 0) { + return ToolError(`\`${params.path}\` is empty.`); + } + if (size > MAX_MEDIA_MEGABYTES * 1024 * 1024) { + return ToolError( + `\`${params.path}\` is ${size} bytes, which exceeds the max ${MAX_MEDIA_MEGABYTES}MB for media files.`, + ); + } + + // Read header for file type detection + const headerBuf = await file.slice(0, MEDIA_SNIFF_BYTES).arrayBuffer(); + const header = new Uint8Array(headerBuf); + const fileType = detectFileType(resolvedPath, header); + + if (fileType.kind === "text") { + return ToolError( + `\`${params.path}\` is a text file. Use ReadFile to read text files.`, + ); + } + if (fileType.kind === "unknown") { + return ToolError( + `\`${params.path}\` seems not readable as an image or video file. ` + + "You may need to read it with proper shell commands or other tools.", + ); + } + + // Read the full file + const data = new Uint8Array(await file.arrayBuffer()); + const dataUrl = toDataUrl(fileType.mimeType, data); + + const note = + " If you need to output coordinates, output relative coordinates first and " + + "compute absolute coordinates using the original image size; if you generate or " + + "edit images/videos via commands or scripts, read the result back immediately " + + "before continuing."; + + return ToolOk( + dataUrl, + `Loaded ${fileType.kind} file \`${params.path}\` (${fileType.mimeType}, ${size} bytes).${note}`, + ); + } catch (e) { + return ToolError(`Failed to read ${params.path}. Error: ${e}`); + } + } +} diff --git a/src/kimi_cli/tools/file/replace.md b/src/kimi_cli/tools/file/replace.md deleted file mode 100644 index fab288073..000000000 --- a/src/kimi_cli/tools/file/replace.md +++ /dev/null @@ -1,7 +0,0 @@ -Replace specific strings within a specified file. - -**Tips:** -- Only use this tool on text files. -- Multi-line strings are supported. -- Can specify a single edit or a list of edits in one call. -- You should prefer this tool over WriteFile tool and Shell `sed` command. diff --git a/src/kimi_cli/tools/file/replace.py b/src/kimi_cli/tools/file/replace.py deleted file mode 100644 index 5c509777d..000000000 --- a/src/kimi_cli/tools/file/replace.py +++ /dev/null @@ -1,193 +0,0 @@ -from collections.abc import Callable -from pathlib import Path -from typing import override - -from kaos.path import KaosPath -from kosong.tooling import CallableTool2, ToolError, ToolReturnValue -from pydantic import BaseModel, Field - -from kimi_cli.soul.agent import Runtime -from kimi_cli.soul.approval import Approval -from kimi_cli.tools.display import DisplayBlock -from kimi_cli.tools.file import FileActions -from kimi_cli.tools.file.plan_mode import inspect_plan_edit_target -from kimi_cli.tools.utils import load_desc -from kimi_cli.utils.diff import build_diff_blocks -from kimi_cli.utils.path import is_within_workspace - -_BASE_DESCRIPTION = load_desc(Path(__file__).parent / "replace.md") - - -class Edit(BaseModel): - old: str = Field(description="The old string to replace. Can be multi-line.") - new: str = Field(description="The new string to replace with. Can be multi-line.") - replace_all: bool = Field(description="Whether to replace all occurrences.", default=False) - - -class Params(BaseModel): - path: str = Field( - description=( - "The path to the file to edit. Absolute paths are required when editing files " - "outside the working directory." - ) - ) - edit: Edit | list[Edit] = Field( - description=( - "The edit(s) to apply to the file. " - "You can provide a single edit or a list of edits here." - ) - ) - - -class StrReplaceFile(CallableTool2[Params]): - name: str = "StrReplaceFile" - description: str = _BASE_DESCRIPTION - params: type[Params] = Params - - def __init__(self, runtime: Runtime, approval: Approval): - super().__init__() - self._work_dir = runtime.builtin_args.KIMI_WORK_DIR - self._additional_dirs = runtime.additional_dirs - self._approval = approval - self._plan_mode_checker: Callable[[], bool] | None = None - self._plan_file_path_getter: Callable[[], Path | None] | None = None - - def bind_plan_mode( - self, checker: Callable[[], bool], path_getter: Callable[[], Path | None] - ) -> None: - """Bind plan mode state checker and plan file path getter.""" - self._plan_mode_checker = checker - self._plan_file_path_getter = path_getter - - async def _validate_path(self, path: KaosPath) -> ToolError | None: - """Validate that the path is safe to edit.""" - resolved_path = path.canonical() - - if ( - not is_within_workspace(resolved_path, self._work_dir, self._additional_dirs) - and not path.is_absolute() - ): - return ToolError( - message=( - f"`{path}` is not an absolute path. " - "You must provide an absolute path to edit a file " - "outside the working directory." - ), - brief="Invalid path", - ) - return None - - def _apply_edit(self, content: str, edit: Edit) -> str: - """Apply a single edit to the content.""" - if edit.replace_all: - return content.replace(edit.old, edit.new) - else: - return content.replace(edit.old, edit.new, 1) - - @override - async def __call__(self, params: Params) -> ToolReturnValue: - if not params.path: - return ToolError( - message="File path cannot be empty.", - brief="Empty file path", - ) - - try: - p = KaosPath(params.path).expanduser() - if err := await self._validate_path(p): - return err - p = p.canonical() - - plan_target = inspect_plan_edit_target( - p, - plan_mode_checker=self._plan_mode_checker, - plan_file_path_getter=self._plan_file_path_getter, - ) - if isinstance(plan_target, ToolError): - return plan_target - - is_plan_file_edit = plan_target.is_plan_target - - if not await p.exists(): - if is_plan_file_edit: - return ToolError( - message=( - "The current plan file does not exist yet. " - "Use WriteFile to create it before calling StrReplaceFile." - ), - brief="Plan file not created", - ) - return ToolError( - message=f"`{params.path}` does not exist.", - brief="File not found", - ) - if not await p.is_file(): - return ToolError( - message=f"`{params.path}` is not a file.", - brief="Invalid path", - ) - - # Read the file content - content = await p.read_text(errors="replace") - - original_content = content - edits = [params.edit] if isinstance(params.edit, Edit) else params.edit - - # Apply all edits - for edit in edits: - content = self._apply_edit(content, edit) - - # Check if any changes were made - if content == original_content: - return ToolError( - message="No replacements were made. The old string was not found in the file.", - brief="No replacements made", - ) - - diff_blocks: list[DisplayBlock] = await build_diff_blocks( - str(p), original_content, content - ) - - action = ( - FileActions.EDIT - if is_within_workspace(p, self._work_dir, self._additional_dirs) - else FileActions.EDIT_OUTSIDE - ) - - # Plan file edits are auto-approved; all other edits need approval. - if not is_plan_file_edit: - result = await self._approval.request( - self.name, - action, - f"Edit file `{p}`", - display=diff_blocks, - ) - if not result: - return result.rejection_error() - - # Write the modified content back to the file - await p.write_text(content, errors="replace") - - # Count changes for success message - total_replacements = 0 - for edit in edits: - if edit.replace_all: - total_replacements += original_content.count(edit.old) - else: - total_replacements += 1 if edit.old in original_content else 0 - - return ToolReturnValue( - is_error=False, - output="", - message=( - f"File successfully edited. " - f"Applied {len(edits)} edit(s) with {total_replacements} total replacement(s)." - ), - display=diff_blocks, - ) - - except Exception as e: - return ToolError( - message=f"Failed to edit. Error: {e}", - brief="Failed to edit file", - ) diff --git a/src/kimi_cli/tools/file/replace.ts b/src/kimi_cli/tools/file/replace.ts new file mode 100644 index 000000000..584c3b999 --- /dev/null +++ b/src/kimi_cli/tools/file/replace.ts @@ -0,0 +1,184 @@ +/** + * StrReplaceFile tool — edit/replace strings in a file. + * Corresponds to Python tools/file/replace.py + */ + +import { resolve } from "node:path"; +import { z } from "zod/v4"; +import { CallableTool } from "../base.ts"; +import type { ToolContext, ToolResult } from "../types.ts"; +import { ToolError } from "../types.ts"; +import { inspectPlanEditTarget } from "./plan_mode.ts"; + +const DESCRIPTION = `Replace specific strings within a specified file. + +**Tips:** +- Only use this tool on text files. +- Multi-line strings are supported. +- Can specify a single edit or a list of edits in one call. +- You should prefer this tool over WriteFile tool and Shell \`sed\` command.`; + +const EditSchema = z.object({ + old: z.string().describe("The old string to replace. Can be multi-line."), + new: z.string().describe("The new string to replace with. Can be multi-line."), + replace_all: z + .boolean() + .default(false) + .describe("Whether to replace all occurrences."), +}); + +const ParamsSchema = z.object({ + path: z.string().describe( + "The path to the file to edit. Absolute paths are required when editing files outside the working directory.", + ), + edit: z + .union([EditSchema, z.array(EditSchema)]) + .describe("The edit(s) to apply to the file."), +}); + +type Params = z.infer; +type Edit = z.infer; + +function resolvePath(filePath: string, workingDir: string): string { + if (filePath.startsWith("/") || filePath.startsWith("~")) { + if (filePath.startsWith("~")) { + const home = process.env.HOME || process.env.USERPROFILE || ""; + return filePath.replace(/^~/, home); + } + return filePath; + } + return resolve(workingDir, filePath); +} + +function applyEdit(content: string, edit: Edit): string { + if (edit.replace_all) { + return content.split(edit.old).join(edit.new); + } + const idx = content.indexOf(edit.old); + if (idx === -1) return content; + return content.slice(0, idx) + edit.new + content.slice(idx + edit.old.length); +} + +export class StrReplaceFile extends CallableTool { + readonly name = "StrReplaceFile"; + readonly description = DESCRIPTION; + readonly schema = ParamsSchema; + + /** Optional plan mode bindings. */ + private _planModeChecker?: () => boolean; + private _planFilePathGetter?: () => string | null; + + /** Bind plan mode state checker and plan file path getter. */ + bindPlanMode(checker: () => boolean, pathGetter: () => string | null): void { + this._planModeChecker = checker; + this._planFilePathGetter = pathGetter; + } + + async execute(params: Params, ctx: ToolContext): Promise { + if (!params.path) { + return ToolError("File path cannot be empty."); + } + + try { + const resolvedPath = resolvePath(params.path, ctx.workingDir); + + // Check plan mode restrictions + const planTarget = inspectPlanEditTarget(resolvedPath, { + planModeChecker: this._planModeChecker ?? ctx.getPlanMode, + planFilePathGetter: this._planFilePathGetter, + }); + if ("isError" in planTarget && planTarget.isError) { + return planTarget; + } + const isPlanFileEdit = !("isError" in planTarget) && planTarget.isPlanTarget; + + const file = Bun.file(resolvedPath); + + if (!(await file.exists())) { + if (isPlanFileEdit) { + return ToolError( + "The current plan file does not exist yet. " + + "Use WriteFile to create it before calling StrReplaceFile.", + ); + } + return ToolError(`\`${params.path}\` does not exist.`); + } + + // Check if it's actually a file + const { stat: fsStat } = await import("node:fs/promises"); + try { + const info = await fsStat(resolvedPath); + if (!info.isFile()) { + return ToolError(`\`${params.path}\` is not a file.`); + } + } catch { + // stat failed — continue + } + + // Read the file content + const originalContent = await file.text(); + let content = originalContent; + + const edits: Edit[] = Array.isArray(params.edit) + ? params.edit + : [params.edit]; + + // Apply all edits + for (const edit of edits) { + content = applyEdit(content, edit); + } + + // Check if any changes were made + if (content === originalContent) { + return ToolError( + "No replacements were made. The old string was not found in the file.", + ); + } + + // Plan file edits are auto-approved; all other edits need approval + if (!isPlanFileEdit) { + // Build diff preview + const diffLines: string[] = []; + for (const edit of edits) { + if (edit.old.length < 200 && edit.new.length < 200) { + diffLines.push(`-${edit.old.split("\n").join("\n-")}`); + diffLines.push(`+${edit.new.split("\n").join("\n+")}`); + } + } + const diffPreview = diffLines.length > 0 ? `\n${diffLines.join("\n")}` : ""; + + const decision = await ctx.approval( + "StrReplaceFile", + "edit", + `Edit file \`${resolvedPath}\` (${edits.length} edit(s))${diffPreview}`, + ); + if (decision === "reject") { + return ToolError( + "The tool call is rejected by the user. Stop what you are doing and wait for the user to tell you how to proceed.", + ); + } + } + + // Write the modified content back + await Bun.write(resolvedPath, content); + + // Count changes for success message + let totalReplacements = 0; + for (const edit of edits) { + if (edit.replace_all) { + totalReplacements += originalContent.split(edit.old).length - 1; + } else { + totalReplacements += originalContent.includes(edit.old) ? 1 : 0; + } + } + + return { + isError: false, + output: "", + message: `File successfully edited. Applied ${edits.length} edit(s) with ${totalReplacements} total replacement(s).`, + }; + } catch (e) { + return ToolError(`Failed to edit. Error: ${e}`); + } + } +} diff --git a/src/kimi_cli/tools/file/utils.py b/src/kimi_cli/tools/file/utils.py deleted file mode 100644 index d674f8989..000000000 --- a/src/kimi_cli/tools/file/utils.py +++ /dev/null @@ -1,257 +0,0 @@ -from __future__ import annotations - -import mimetypes -from dataclasses import dataclass -from pathlib import PurePath -from typing import Literal - -MEDIA_SNIFF_BYTES = 512 - -_EXTRA_MIME_TYPES = { - ".avif": "image/avif", - ".heic": "image/heic", - ".heif": "image/heif", - ".mkv": "video/x-matroska", - ".m4v": "video/x-m4v", - ".3gp": "video/3gpp", - ".3g2": "video/3gpp2", - # TypeScript files: override mimetypes default (video/mp2t for MPEG Transport Stream) - ".ts": "text/typescript", - ".tsx": "text/typescript", - ".mts": "text/typescript", - ".cts": "text/typescript", -} - -for suffix, mime_type in _EXTRA_MIME_TYPES.items(): - mimetypes.add_type(mime_type, suffix) - -_IMAGE_MIME_BY_SUFFIX = { - ".png": "image/png", - ".jpg": "image/jpeg", - ".jpeg": "image/jpeg", - ".gif": "image/gif", - ".bmp": "image/bmp", - ".tif": "image/tiff", - ".tiff": "image/tiff", - ".webp": "image/webp", - ".ico": "image/x-icon", - ".heic": "image/heic", - ".heif": "image/heif", - ".avif": "image/avif", - ".svgz": "image/svg+xml", -} -_VIDEO_MIME_BY_SUFFIX = { - ".mp4": "video/mp4", - ".mkv": "video/x-matroska", - ".avi": "video/x-msvideo", - ".mov": "video/quicktime", - ".wmv": "video/x-ms-wmv", - ".webm": "video/webm", - ".m4v": "video/x-m4v", - ".flv": "video/x-flv", - ".3gp": "video/3gpp", - ".3g2": "video/3gpp2", -} -_TEXT_MIME_BY_SUFFIX = { - ".svg": "image/svg+xml", -} - -_ASF_HEADER = b"\x30\x26\xb2\x75\x8e\x66\xcf\x11\xa6\xd9\x00\xaa\x00\x62\xce\x6c" -_FTYP_IMAGE_BRANDS = { - "avif": "image/avif", - "avis": "image/avif", - "heic": "image/heic", - "heif": "image/heif", - "heix": "image/heif", - "hevc": "image/heic", - "mif1": "image/heif", - "msf1": "image/heif", -} -_FTYP_VIDEO_BRANDS = { - "isom": "video/mp4", - "iso2": "video/mp4", - "iso5": "video/mp4", - "mp41": "video/mp4", - "mp42": "video/mp4", - "avc1": "video/mp4", - "mp4v": "video/mp4", - "m4v": "video/x-m4v", - "qt": "video/quicktime", - "3gp4": "video/3gpp", - "3gp5": "video/3gpp", - "3gp6": "video/3gpp", - "3gp7": "video/3gpp", - "3g2": "video/3gpp2", -} - -_NON_TEXT_SUFFIXES = { - ".icns", - ".psd", - ".ai", - ".eps", - # Documents / office formats - ".pdf", - ".doc", - ".docx", - ".dot", - ".dotx", - ".rtf", - ".odt", - ".xls", - ".xlsx", - ".xlsm", - ".xlt", - ".xltx", - ".xltm", - ".ods", - ".ppt", - ".pptx", - ".pptm", - ".pps", - ".ppsx", - ".odp", - ".pages", - ".numbers", - ".key", - # Archives / compressed - ".zip", - ".rar", - ".7z", - ".tar", - ".gz", - ".tgz", - ".bz2", - ".xz", - ".zst", - ".lz", - ".lz4", - ".br", - ".cab", - ".ar", - ".deb", - ".rpm", - # Audio - ".mp3", - ".wav", - ".flac", - ".ogg", - ".oga", - ".opus", - ".aac", - ".m4a", - ".wma", - # Fonts - ".ttf", - ".otf", - ".woff", - ".woff2", - # Binaries / bundles - ".exe", - ".dll", - ".so", - ".dylib", - ".bin", - ".apk", - ".ipa", - ".jar", - ".class", - ".pyc", - ".pyo", - ".wasm", - # Disk images / databases - ".dmg", - ".iso", - ".img", - ".sqlite", - ".sqlite3", - ".db", - ".db3", -} - - -@dataclass(frozen=True) -class FileType: - kind: Literal["text", "image", "video", "unknown"] - mime_type: str - - -def _sniff_ftyp_brand(header: bytes) -> str | None: - if len(header) < 12 or header[4:8] != b"ftyp": - return None - brand = header[8:12].decode("ascii", errors="ignore").lower() - return brand.strip() - - -def sniff_media_from_magic(data: bytes) -> FileType | None: - header = data[:MEDIA_SNIFF_BYTES] - if header.startswith(b"\x89PNG\r\n\x1a\n"): - return FileType(kind="image", mime_type="image/png") - if header.startswith(b"\xff\xd8\xff"): - return FileType(kind="image", mime_type="image/jpeg") - if header.startswith((b"GIF87a", b"GIF89a")): - return FileType(kind="image", mime_type="image/gif") - if header.startswith(b"BM"): - return FileType(kind="image", mime_type="image/bmp") - if header.startswith((b"II*\x00", b"MM\x00*")): - return FileType(kind="image", mime_type="image/tiff") - if header.startswith(b"\x00\x00\x01\x00"): - return FileType(kind="image", mime_type="image/x-icon") - if header.startswith(b"RIFF") and len(header) >= 12: - chunk = header[8:12] - if chunk == b"WEBP": - return FileType(kind="image", mime_type="image/webp") - if chunk == b"AVI ": - return FileType(kind="video", mime_type="video/x-msvideo") - if header.startswith(b"FLV"): - return FileType(kind="video", mime_type="video/x-flv") - if header.startswith(_ASF_HEADER): - return FileType(kind="video", mime_type="video/x-ms-wmv") - if header.startswith(b"\x1a\x45\xdf\xa3"): - lowered = header.lower() - if b"webm" in lowered: - return FileType(kind="video", mime_type="video/webm") - if b"matroska" in lowered: - return FileType(kind="video", mime_type="video/x-matroska") - if brand := _sniff_ftyp_brand(header): - if brand in _FTYP_IMAGE_BRANDS: - return FileType(kind="image", mime_type=_FTYP_IMAGE_BRANDS[brand]) - if brand in _FTYP_VIDEO_BRANDS: - return FileType(kind="video", mime_type=_FTYP_VIDEO_BRANDS[brand]) - return None - - -def detect_file_type(path: str | PurePath, header: bytes | None = None) -> FileType: - suffix = PurePath(str(path)).suffix.lower() - media_hint: FileType | None = None - if suffix in _TEXT_MIME_BY_SUFFIX: - media_hint = FileType(kind="text", mime_type=_TEXT_MIME_BY_SUFFIX[suffix]) - elif suffix in _IMAGE_MIME_BY_SUFFIX: - media_hint = FileType(kind="image", mime_type=_IMAGE_MIME_BY_SUFFIX[suffix]) - elif suffix in _VIDEO_MIME_BY_SUFFIX: - media_hint = FileType(kind="video", mime_type=_VIDEO_MIME_BY_SUFFIX[suffix]) - else: - mime_type, _ = mimetypes.guess_type(str(path)) - if mime_type: - if mime_type.startswith("image/"): - media_hint = FileType(kind="image", mime_type=mime_type) - elif mime_type.startswith("video/"): - media_hint = FileType(kind="video", mime_type=mime_type) - - if media_hint and media_hint.kind in ("image", "video"): - return media_hint - - if header is not None: - sniffed = sniff_media_from_magic(header) - if sniffed: - if media_hint and sniffed.kind != media_hint.kind: - return FileType(kind="unknown", mime_type="") - return sniffed - # NUL bytes are a strong signal of binary content. - if b"\x00" in header: - return FileType(kind="unknown", mime_type="") - - if media_hint: - return media_hint - if suffix in _NON_TEXT_SUFFIXES: - return FileType(kind="unknown", mime_type="") - return FileType(kind="text", mime_type="text/plain") diff --git a/src/kimi_cli/tools/file/utils.ts b/src/kimi_cli/tools/file/utils.ts new file mode 100644 index 000000000..41cd83e1f --- /dev/null +++ b/src/kimi_cli/tools/file/utils.ts @@ -0,0 +1,223 @@ +/** + * File type detection utilities. + * Corresponds to Python tools/file/utils.py + */ + +import { extname } from "node:path"; + +export const MEDIA_SNIFF_BYTES = 512; + +export interface FileType { + kind: "text" | "image" | "video" | "unknown"; + mimeType: string; +} + +const IMAGE_MIME_BY_SUFFIX: Record = { + ".png": "image/png", + ".jpg": "image/jpeg", + ".jpeg": "image/jpeg", + ".gif": "image/gif", + ".bmp": "image/bmp", + ".tif": "image/tiff", + ".tiff": "image/tiff", + ".webp": "image/webp", + ".ico": "image/x-icon", + ".heic": "image/heic", + ".heif": "image/heif", + ".avif": "image/avif", + ".svgz": "image/svg+xml", +}; + +const VIDEO_MIME_BY_SUFFIX: Record = { + ".mp4": "video/mp4", + ".mkv": "video/x-matroska", + ".avi": "video/x-msvideo", + ".mov": "video/quicktime", + ".wmv": "video/x-ms-wmv", + ".webm": "video/webm", + ".m4v": "video/x-m4v", + ".flv": "video/x-flv", + ".3gp": "video/3gpp", + ".3g2": "video/3gpp2", +}; + +const TEXT_MIME_BY_SUFFIX: Record = { + ".svg": "image/svg+xml", +}; + +const FTYP_IMAGE_BRANDS: Record = { + avif: "image/avif", + avis: "image/avif", + heic: "image/heic", + heif: "image/heif", + heix: "image/heif", + hevc: "image/heic", + mif1: "image/heif", + msf1: "image/heif", +}; + +const FTYP_VIDEO_BRANDS: Record = { + isom: "video/mp4", + iso2: "video/mp4", + iso5: "video/mp4", + mp41: "video/mp4", + mp42: "video/mp4", + avc1: "video/mp4", + mp4v: "video/mp4", + m4v: "video/x-m4v", + qt: "video/quicktime", + "3gp4": "video/3gpp", + "3gp5": "video/3gpp", + "3gp6": "video/3gpp", + "3gp7": "video/3gpp", + "3g2": "video/3gpp2", +}; + +const NON_TEXT_SUFFIXES = new Set([ + ".icns", ".psd", ".ai", ".eps", + // Documents / office formats + ".pdf", ".doc", ".docx", ".dot", ".dotx", ".rtf", ".odt", + ".xls", ".xlsx", ".xlsm", ".xlt", ".xltx", ".xltm", ".ods", + ".ppt", ".pptx", ".pptm", ".pps", ".ppsx", ".odp", + ".pages", ".numbers", ".key", + // Archives / compressed + ".zip", ".rar", ".7z", ".tar", ".gz", ".tgz", ".bz2", ".xz", + ".zst", ".lz", ".lz4", ".br", ".cab", ".ar", ".deb", ".rpm", + // Audio + ".mp3", ".wav", ".flac", ".ogg", ".oga", ".opus", ".aac", ".m4a", ".wma", + // Fonts + ".ttf", ".otf", ".woff", ".woff2", + // Binaries / bundles + ".exe", ".dll", ".so", ".dylib", ".bin", ".apk", ".ipa", + ".jar", ".class", ".pyc", ".pyo", ".wasm", + // Disk images / databases + ".dmg", ".iso", ".img", ".sqlite", ".sqlite3", ".db", ".db3", +]); + +const ASF_HEADER = new Uint8Array([ + 0x30, 0x26, 0xb2, 0x75, 0x8e, 0x66, 0xcf, 0x11, + 0xa6, 0xd9, 0x00, 0xaa, 0x00, 0x62, 0xce, 0x6c, +]); + +function bufStartsWith(buf: Uint8Array, prefix: Uint8Array | number[]): boolean { + if (buf.length < prefix.length) return false; + for (let i = 0; i < prefix.length; i++) { + if (buf[i] !== prefix[i]) return false; + } + return true; +} + +function sniffFtypBrand(header: Uint8Array): string | null { + if (header.length < 12) return null; + // Check for "ftyp" at bytes 4-8 + if ( + header[4] !== 0x66 || // f + header[5] !== 0x74 || // t + header[6] !== 0x79 || // y + header[7] !== 0x70 // p + ) return null; + const brand = String.fromCharCode(...header.slice(8, 12)).toLowerCase().trim(); + return brand; +} + +/** Detect media type from raw magic bytes. */ +export function sniffMediaFromMagic(data: Uint8Array): FileType | null { + const header = data.slice(0, MEDIA_SNIFF_BYTES); + + // PNG + if (bufStartsWith(header, [0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a])) { + return { kind: "image", mimeType: "image/png" }; + } + // JPEG + if (bufStartsWith(header, [0xff, 0xd8, 0xff])) { + return { kind: "image", mimeType: "image/jpeg" }; + } + // GIF + if (bufStartsWith(header, [0x47, 0x49, 0x46, 0x38, 0x37, 0x61]) || // GIF87a + bufStartsWith(header, [0x47, 0x49, 0x46, 0x38, 0x39, 0x61])) { // GIF89a + return { kind: "image", mimeType: "image/gif" }; + } + // BMP + if (bufStartsWith(header, [0x42, 0x4d])) { + return { kind: "image", mimeType: "image/bmp" }; + } + // TIFF (II or MM) + if (bufStartsWith(header, [0x49, 0x49, 0x2a, 0x00]) || + bufStartsWith(header, [0x4d, 0x4d, 0x00, 0x2a])) { + return { kind: "image", mimeType: "image/tiff" }; + } + // ICO + if (bufStartsWith(header, [0x00, 0x00, 0x01, 0x00])) { + return { kind: "image", mimeType: "image/x-icon" }; + } + // RIFF (WEBP or AVI) + if (bufStartsWith(header, [0x52, 0x49, 0x46, 0x46]) && header.length >= 12) { + const chunk = String.fromCharCode(header[8]!, header[9]!, header[10]!, header[11]!); + if (chunk === "WEBP") return { kind: "image", mimeType: "image/webp" }; + if (chunk === "AVI ") return { kind: "video", mimeType: "video/x-msvideo" }; + } + // FLV + if (bufStartsWith(header, [0x46, 0x4c, 0x56])) { + return { kind: "video", mimeType: "video/x-flv" }; + } + // ASF (WMV) + if (bufStartsWith(header, ASF_HEADER)) { + return { kind: "video", mimeType: "video/x-ms-wmv" }; + } + // WebM / Matroska + if (bufStartsWith(header, [0x1a, 0x45, 0xdf, 0xa3])) { + const lowered = new TextDecoder().decode(header).toLowerCase(); + if (lowered.includes("webm")) return { kind: "video", mimeType: "video/webm" }; + if (lowered.includes("matroska")) return { kind: "video", mimeType: "video/x-matroska" }; + } + // ftyp container (MP4, HEIC, AVIF, etc.) + const brand = sniffFtypBrand(header); + if (brand) { + if (brand in FTYP_IMAGE_BRANDS) { + return { kind: "image", mimeType: FTYP_IMAGE_BRANDS[brand]! }; + } + if (brand in FTYP_VIDEO_BRANDS) { + return { kind: "video", mimeType: FTYP_VIDEO_BRANDS[brand]! }; + } + } + + return null; +} + +/** Detect file type from path extension and optional header bytes. */ +export function detectFileType(path: string, header?: Uint8Array): FileType { + const suffix = extname(path).toLowerCase(); + + let mediaHint: FileType | null = null; + if (suffix in TEXT_MIME_BY_SUFFIX) { + mediaHint = { kind: "text", mimeType: TEXT_MIME_BY_SUFFIX[suffix]! }; + } else if (suffix in IMAGE_MIME_BY_SUFFIX) { + mediaHint = { kind: "image", mimeType: IMAGE_MIME_BY_SUFFIX[suffix]! }; + } else if (suffix in VIDEO_MIME_BY_SUFFIX) { + mediaHint = { kind: "video", mimeType: VIDEO_MIME_BY_SUFFIX[suffix]! }; + } + + if (mediaHint && (mediaHint.kind === "image" || mediaHint.kind === "video")) { + return mediaHint; + } + + if (header !== undefined) { + const sniffed = sniffMediaFromMagic(header); + if (sniffed) { + if (mediaHint && sniffed.kind !== mediaHint.kind) { + return { kind: "unknown", mimeType: "" }; + } + return sniffed; + } + // NUL bytes are a strong signal of binary content + if (header.includes(0x00)) { + return { kind: "unknown", mimeType: "" }; + } + } + + if (mediaHint) return mediaHint; + if (NON_TEXT_SUFFIXES.has(suffix)) { + return { kind: "unknown", mimeType: "" }; + } + return { kind: "text", mimeType: "text/plain" }; +} diff --git a/src/kimi_cli/tools/file/write.md b/src/kimi_cli/tools/file/write.md deleted file mode 100644 index bf04d0fe9..000000000 --- a/src/kimi_cli/tools/file/write.md +++ /dev/null @@ -1,5 +0,0 @@ -Write content to a file. - -**Tips:** -- When `mode` is not specified, it defaults to `overwrite`. Always write with caution. -- When the content to write is too long (e.g. > 100 lines), use this tool multiple times instead of a single call. Use `overwrite` mode at the first time, then use `append` mode after the first write. diff --git a/src/kimi_cli/tools/file/write.py b/src/kimi_cli/tools/file/write.py deleted file mode 100644 index f5758b350..000000000 --- a/src/kimi_cli/tools/file/write.py +++ /dev/null @@ -1,175 +0,0 @@ -from collections.abc import Callable -from pathlib import Path -from typing import Literal, override - -from kaos.path import KaosPath -from kosong.tooling import CallableTool2, ToolError, ToolReturnValue -from pydantic import BaseModel, Field - -from kimi_cli.soul.agent import Runtime -from kimi_cli.soul.approval import Approval -from kimi_cli.tools.display import DisplayBlock -from kimi_cli.tools.file import FileActions -from kimi_cli.tools.file.plan_mode import inspect_plan_edit_target -from kimi_cli.tools.utils import load_desc -from kimi_cli.utils.diff import build_diff_blocks -from kimi_cli.utils.path import is_within_workspace - -_BASE_DESCRIPTION = load_desc(Path(__file__).parent / "write.md") - - -class Params(BaseModel): - path: str = Field( - description=( - "The path to the file to write. Absolute paths are required when writing files " - "outside the working directory." - ) - ) - content: str = Field(description="The content to write to the file") - mode: Literal["overwrite", "append"] = Field( - description=( - "The mode to use to write to the file. " - "Two modes are supported: `overwrite` for overwriting the whole file and " - "`append` for appending to the end of an existing file." - ), - default="overwrite", - ) - - -class WriteFile(CallableTool2[Params]): - name: str = "WriteFile" - description: str = _BASE_DESCRIPTION - params: type[Params] = Params - - def __init__(self, runtime: Runtime, approval: Approval): - super().__init__() - self._work_dir = runtime.builtin_args.KIMI_WORK_DIR - self._additional_dirs = runtime.additional_dirs - self._approval = approval - self._plan_mode_checker: Callable[[], bool] | None = None - self._plan_file_path_getter: Callable[[], Path | None] | None = None - - def bind_plan_mode( - self, checker: Callable[[], bool], path_getter: Callable[[], Path | None] - ) -> None: - """Bind plan mode state checker and plan file path getter.""" - self._plan_mode_checker = checker - self._plan_file_path_getter = path_getter - - async def _validate_path(self, path: KaosPath) -> ToolError | None: - """Validate that the path is safe to write.""" - resolved_path = path.canonical() - - if ( - not is_within_workspace(resolved_path, self._work_dir, self._additional_dirs) - and not path.is_absolute() - ): - return ToolError( - message=( - f"`{path}` is not an absolute path. " - "You must provide an absolute path to write a file " - "outside the working directory." - ), - brief="Invalid path", - ) - return None - - @override - async def __call__(self, params: Params) -> ToolReturnValue: - # TODO: checks: - # - check if the path may contain secrets - if not params.path: - return ToolError( - message="File path cannot be empty.", - brief="Empty file path", - ) - - try: - p = KaosPath(params.path).expanduser() - - if err := await self._validate_path(p): - return err - p = p.canonical() - - plan_target = inspect_plan_edit_target( - p, - plan_mode_checker=self._plan_mode_checker, - plan_file_path_getter=self._plan_file_path_getter, - ) - if isinstance(plan_target, ToolError): - return plan_target - - is_plan_file_write = plan_target.is_plan_target - if is_plan_file_write and plan_target.plan_path is not None: - plan_target.plan_path.parent.mkdir(parents=True, exist_ok=True) - - if not await p.parent.exists(): - return ToolError( - message=f"`{params.path}` parent directory does not exist.", - brief="Parent directory not found", - ) - - # Validate mode parameter - if params.mode not in ["overwrite", "append"]: - return ToolError( - message=( - f"Invalid write mode: `{params.mode}`. " - "Mode must be either `overwrite` or `append`." - ), - brief="Invalid write mode", - ) - - file_existed = await p.exists() - old_text = None - if file_existed: - old_text = await p.read_text(errors="replace") - - new_text = ( - params.content if params.mode == "overwrite" else (old_text or "") + params.content - ) - diff_blocks: list[DisplayBlock] = await build_diff_blocks( - str(p), - old_text or "", - new_text, - ) - - # Plan file writes are auto-approved; other writes need approval - if not is_plan_file_write: - action = ( - FileActions.EDIT - if is_within_workspace(p, self._work_dir, self._additional_dirs) - else FileActions.EDIT_OUTSIDE - ) - - # Request approval - result = await self._approval.request( - self.name, - action, - f"Write file `{p}`", - display=diff_blocks, - ) - if not result: - return result.rejection_error() - - # Write content to file - match params.mode: - case "overwrite": - await p.write_text(params.content) - case "append": - await p.append_text(params.content) - - # Get file info for success message - file_size = (await p.stat()).st_size - action = "overwritten" if params.mode == "overwrite" else "appended to" - return ToolReturnValue( - is_error=False, - output="", - message=(f"File successfully {action}. Current size: {file_size} bytes."), - display=diff_blocks, - ) - - except Exception as e: - return ToolError( - message=f"Failed to write to {params.path}. Error: {e}", - brief="Failed to write file", - ) diff --git a/src/kimi_cli/tools/file/write.ts b/src/kimi_cli/tools/file/write.ts new file mode 100644 index 000000000..cbd79fff5 --- /dev/null +++ b/src/kimi_cli/tools/file/write.ts @@ -0,0 +1,185 @@ +/** + * WriteFile tool — write content to a file. + * Corresponds to Python tools/file/write.py + */ + +import { resolve, dirname } from "node:path"; +import { z } from "zod/v4"; +import { CallableTool } from "../base.ts"; +import type { ToolContext, ToolResult } from "../types.ts"; +import { ToolError } from "../types.ts"; +import { inspectPlanEditTarget } from "./plan_mode.ts"; +import type { DiffDisplayBlock } from "../display.ts"; + +const DESCRIPTION = `Write content to a file. + +**Tips:** +- When \`mode\` is not specified, it defaults to \`overwrite\`. Always write with caution. +- When the content to write is too long (e.g. > 100 lines), use this tool multiple times instead of a single call. Use \`overwrite\` mode at the first time, then use \`append\` mode after the first write.`; + +const ParamsSchema = z.object({ + path: z.string().describe( + "The path to the file to write. Absolute paths are required when writing files outside the working directory.", + ), + content: z.string().describe("The content to write to the file"), + mode: z + .enum(["overwrite", "append"]) + .default("overwrite") + .describe("The mode to use: `overwrite` or `append`."), +}); + +type Params = z.infer; + +function resolvePath(filePath: string, workingDir: string): string { + if (filePath.startsWith("/") || filePath.startsWith("~")) { + if (filePath.startsWith("~")) { + const home = process.env.HOME || process.env.USERPROFILE || ""; + return filePath.replace(/^~/, home); + } + return filePath; + } + return resolve(workingDir, filePath); +} + +/** Build a simple unified diff for display. */ +function buildSimpleDiff(oldContent: string, newContent: string, path: string): string { + const oldLines = oldContent.split("\n"); + const newLines = newContent.split("\n"); + const maxPreview = 50; + const diffLines: string[] = [`--- a/${path}`, `+++ b/${path}`]; + + let shown = 0; + const maxLen = Math.max(oldLines.length, newLines.length); + for (let i = 0; i < maxLen && shown < maxPreview; i++) { + const oldLine = oldLines[i]; + const newLine = newLines[i]; + if (oldLine !== newLine) { + if (oldLine !== undefined) { + diffLines.push(`-${oldLine}`); + shown++; + } + if (newLine !== undefined) { + diffLines.push(`+${newLine}`); + shown++; + } + } + } + + if (shown >= maxPreview) { + diffLines.push(`... (diff truncated, ${maxLen} total lines)`); + } + + return diffLines.join("\n"); +} + +export class WriteFile extends CallableTool { + readonly name = "WriteFile"; + readonly description = DESCRIPTION; + readonly schema = ParamsSchema; + + /** Optional plan mode bindings. */ + private _planModeChecker?: () => boolean; + private _planFilePathGetter?: () => string | null; + + /** Bind plan mode state checker and plan file path getter. */ + bindPlanMode(checker: () => boolean, pathGetter: () => string | null): void { + this._planModeChecker = checker; + this._planFilePathGetter = pathGetter; + } + + async execute(params: Params, ctx: ToolContext): Promise { + if (!params.path) { + return ToolError("File path cannot be empty."); + } + + try { + const resolvedPath = resolvePath(params.path, ctx.workingDir); + + // Check plan mode restrictions + const planTarget = inspectPlanEditTarget(resolvedPath, { + planModeChecker: this._planModeChecker ?? ctx.getPlanMode, + planFilePathGetter: this._planFilePathGetter, + }); + if ("isError" in planTarget && planTarget.isError) { + return planTarget; + } + const isPlanFileWrite = !("isError" in planTarget) && planTarget.isPlanTarget; + + // Ensure parent directory for plan file writes + if (isPlanFileWrite && !("isError" in planTarget) && planTarget.planPath) { + const { mkdirSync } = await import("node:fs"); + mkdirSync(dirname(planTarget.planPath), { recursive: true }); + } + + // Check if parent directory exists + const parentDir = dirname(resolvedPath); + const { stat: fsStat, mkdir } = await import("node:fs/promises"); + try { + const parentInfo = await fsStat(parentDir); + if (!parentInfo.isDirectory()) { + return ToolError(`Parent path \`${parentDir}\` exists but is not a directory.`); + } + } catch (err: any) { + if (err?.code === "ENOENT") { + await mkdir(parentDir, { recursive: true }); + } else { + return ToolError(`Cannot access parent directory \`${parentDir}\`: ${err?.message}`); + } + } + + const file = Bun.file(resolvedPath); + const fileExisted = await file.exists(); + + // Build diff for approval display + let diffPreview = ""; + if (fileExisted) { + try { + const oldContent = await file.text(); + const newContent = params.mode === "append" ? oldContent + params.content : params.content; + diffPreview = buildSimpleDiff(oldContent, newContent, params.path); + } catch { + // Can't read old file — skip diff + } + } + + // Plan file writes are auto-approved; other writes need approval + if (!isPlanFileWrite) { + const approvalSummary = fileExisted + ? `${params.mode === "append" ? "Append to" : "Overwrite"} file \`${params.path}\`${diffPreview ? `\n${diffPreview}` : ""}` + : `Create file \`${params.path}\` (${params.content.length} chars)`; + + const decision = await ctx.approval( + "WriteFile", + fileExisted ? "edit" : "create", + approvalSummary, + ); + if (decision === "reject") { + return ToolError( + "The tool call is rejected by the user. Stop what you are doing and wait for the user to tell you how to proceed.", + ); + } + } + + if (params.mode === "append" && fileExisted) { + const { appendFile } = await import("node:fs/promises"); + await appendFile(resolvedPath, params.content, "utf-8"); + } else { + await Bun.write(resolvedPath, params.content); + } + + const newFile = Bun.file(resolvedPath); + const fileSize = newFile.size; + const action = + params.mode === "overwrite" + ? (fileExisted ? "overwritten" : "created") + : "appended to"; + return { + isError: false, + output: "", + message: `File successfully ${action}. Current size: ${fileSize} bytes.`, + }; + } catch (e) { + return ToolError(`Failed to write to ${params.path}. Error: ${e}`); + } + } +} diff --git a/src/kimi_cli/tools/plan/__init__.py b/src/kimi_cli/tools/plan/__init__.py deleted file mode 100644 index 99827efe6..000000000 --- a/src/kimi_cli/tools/plan/__init__.py +++ /dev/null @@ -1,325 +0,0 @@ -"""ExitPlanMode tool — lets the LLM submit a plan for user approval.""" - -from __future__ import annotations - -import asyncio -import logging -from collections.abc import Awaitable, Callable -from pathlib import Path -from typing import override -from uuid import uuid4 - -from kosong.tooling import BriefDisplayBlock, CallableTool2, ToolError, ToolReturnValue -from pydantic import BaseModel, Field, field_validator - -from kimi_cli.soul import get_wire_or_none, wire_send -from kimi_cli.soul.toolset import get_current_tool_call_or_none -from kimi_cli.tools.utils import ToolRejectedError, load_desc -from kimi_cli.wire.types import ( - PlanDisplay, - QuestionItem, - QuestionNotSupported, - QuestionOption, - QuestionRequest, -) - -logger = logging.getLogger(__name__) - -NAME = "ExitPlanMode" - -_RESERVED_LABELS = {"reject", "revise", "approve", "reject and exit"} - - -class PlanOption(BaseModel): - """A selectable approach/option within the plan.""" - - label: str = Field( - description=( - "Short name for this option (1-8 words). " - "Append '(Recommended)' if you recommend this option." - ), - ) - description: str = Field( - default="", - description="Brief summary of this approach and its trade-offs.", - ) - - @field_validator("label") - @classmethod - def label_not_reserved(cls, v: str) -> str: - if v.strip().lower() in _RESERVED_LABELS: - reserved = ", ".join(f"'{w.title()}'" for w in sorted(_RESERVED_LABELS)) - raise ValueError( - f"Option label {v!r} is reserved. Do not use {reserved} as option labels." - ) - return v - - -class Params(BaseModel): - options: list[PlanOption] | None = Field( - default=None, - max_length=3, - description=( - "When the plan contains multiple alternative approaches, list them here " - "so the user can choose which one to execute. 2-3 options. " - "Each option represents a distinct approach from the plan. " - "Do not use 'Reject', 'Revise', 'Approve', or 'Reject and Exit' as labels." - ), - ) - - @field_validator("options") - @classmethod - def options_labels_unique(cls, v: list[PlanOption] | None) -> list[PlanOption] | None: - if v is None: - return v - labels = [opt.label for opt in v] - if len(labels) != len(set(labels)): - raise ValueError("Option labels must be unique. Found duplicate label(s).") - return v - - -class ExitPlanMode(CallableTool2[Params]): - name: str = NAME - description: str = load_desc(Path(__file__).parent / "description.md") - params: type[Params] = Params - - def __init__(self) -> None: - super().__init__() - self._toggle_callback: Callable[[], Awaitable[bool]] | None = None - self._plan_file_path_getter: Callable[[], Path | None] | None = None - self._plan_mode_checker: Callable[[], bool] | None = None - self._is_yolo: Callable[[], bool] | None = None - - def bind( - self, - toggle_callback: Callable[[], Awaitable[bool]], - plan_file_path_getter: Callable[[], Path | None], - plan_mode_checker: Callable[[], bool], - is_yolo: Callable[[], bool] | None = None, - ) -> None: - """Late-bind soul callbacks after KimiSoul is constructed.""" - self._toggle_callback = toggle_callback - self._plan_file_path_getter = plan_file_path_getter - self._plan_mode_checker = plan_mode_checker - self._is_yolo = is_yolo - - @override - async def __call__(self, params: Params) -> ToolReturnValue: - # Guard: only works in plan mode - if not self._plan_mode_checker or not self._plan_mode_checker(): - return ToolError( - message="Not in plan mode. ExitPlanMode is only available during plan mode.", - brief="Not in plan mode", - ) - - if not self._toggle_callback or not self._plan_file_path_getter: - return ToolError( - message="ExitPlanMode is not properly initialized.", - brief="Not initialized", - ) - - # Read the plan file - plan_path = self._plan_file_path_getter() - plan_content: str | None = None - if plan_path and await asyncio.to_thread(plan_path.exists): - plan_content = await asyncio.to_thread(plan_path.read_text, encoding="utf-8") - - if not plan_content: - return ToolError( - message=f"No plan file found. Write your plan to {plan_path} first, " - "then call ExitPlanMode.", - brief="No plan file", - ) - - # In yolo mode, auto-approve the plan - if self._is_yolo and self._is_yolo(): - await self._toggle_callback() - return ToolReturnValue( - is_error=False, - output=( - f"Plan approved (auto-approved in non-interactive mode). " - f"Plan mode deactivated. All tools are now available.\n" - f"Plan saved to: {plan_path}\n\n" - f"## Approved Plan:\n{plan_content}" - ), - message="Plan approved (auto)", - display=[BriefDisplayBlock(text="Plan approved (auto)")], - ) - - # Present plan to user via QuestionRequest - wire = get_wire_or_none() - if wire is None: - return ToolError( - message="Cannot present plan: Wire is not available.", - brief="Wire unavailable", - ) - - tool_call = get_current_tool_call_or_none() - if tool_call is None: - return ToolError( - message="ExitPlanMode must be called from a tool call context.", - brief="Invalid context", - ) - - has_options = params.options is not None and len(params.options) >= 2 - - _reject_options = [ - QuestionOption( - label="Reject", - description="Reject and stay in plan mode", - ), - QuestionOption( - label="Reject and Exit", - description="Reject and exit plan mode", - ), - ] - - if has_options: - assert params.options is not None - question_options = [ - QuestionOption(label=opt.label, description=opt.description) - for opt in params.options - ] - question_options.extend(_reject_options) - else: - question_options = [ - QuestionOption( - label="Approve", - description="Exit plan mode and start execution", - ), - *_reject_options, - ] - - # Display plan content inline in the chat - wire_send(PlanDisplay(content=plan_content, file_path=str(plan_path))) - - request = QuestionRequest( - id=str(uuid4()), - tool_call_id=tool_call.id, - questions=[ - QuestionItem( - question="Approve this plan", - header="Plan", - options=question_options, - other_label="Revise", - other_description="Stay in plan mode and provide feedback", - ) - ], - ) - - wire_send(request) - - try: - answers = await request.wait() - except QuestionNotSupported: - return ToolError( - message="The connected client does not support plan mode. " - "Do NOT call this tool again.", - brief="Client unsupported", - ) - except Exception: - logger.exception("Failed to get user response for ExitPlanMode") - return ToolError( - message="Failed to get user response.", - brief="Question failed", - ) - - if not answers: - return ToolReturnValue( - is_error=False, - output="User dismissed without choosing. Plan mode remains active. " - "Continue working on your plan or call ExitPlanMode again when ready.", - message="Dismissed", - display=[BriefDisplayBlock(text="Dismissed")], - ) - - # Parse user choice — exact match on option label - chose_reject_and_exit = any(v == "Reject and Exit" for v in answers.values()) - - if chose_reject_and_exit: - await self._toggle_callback() - return ToolRejectedError( - message=( - "Plan rejected by user. Plan mode deactivated. " - "All tools are now available. " - "Wait for the user's next message." - ), - brief="Plan rejected, exited plan mode", - ) - - chose_reject = any(v == "Reject" for v in answers.values()) - - if chose_reject: - return ToolRejectedError( - message=( - "Plan rejected by user. Stay in plan mode. " - "The user will provide feedback via conversation. " - "Wait for the user's next message before revising." - ), - brief="Plan rejected", - ) - - # Approve — multi-approach (user selected a specific option) - if has_options: - assert params.options is not None - option_labels = {opt.label for opt in params.options} - chosen_option = None - for v in answers.values(): - if v in option_labels: - chosen_option = v - break - - if chosen_option: - await self._toggle_callback() - return ToolReturnValue( - is_error=False, - output=( - f'Plan approved by user. Selected approach: "{chosen_option}"\n' - f"Plan mode deactivated. All tools are now available.\n" - f"Plan saved to: {plan_path}\n\n" - f'IMPORTANT: Execute ONLY the selected approach "{chosen_option}". ' - f"Ignore other approaches in the plan.\n\n" - f"## Approved Plan:\n{plan_content}" - ), - message=f"Plan approved: {chosen_option}", - display=[BriefDisplayBlock(text=f"Plan approved: {chosen_option}")], - ) - - # Approve — single-approach only (has_options uses option labels, not "Approve") - chose_approve = not has_options and any(v == "Approve" for v in answers.values()) - if chose_approve: - await self._toggle_callback() - return ToolReturnValue( - is_error=False, - output=( - f"Plan approved by user. Plan mode deactivated. " - f"All tools are now available.\n" - f"Plan saved to: {plan_path}\n\n" - f"## Approved Plan:\n{plan_content}" - ), - message="Plan approved", - display=[BriefDisplayBlock(text="Plan approved")], - ) - - # Revise — user selected the free-text "Revise" option (fallback) - feedback = "" - for v in answers.values(): - if v not in ("Approve", "Reject", "Reject and Exit"): - feedback = v - if feedback: - msg = ( - "User wants to revise the plan. Stay in plan mode. " - "Revise based on the feedback below.\n\n" - f"User feedback: {feedback}" - ) - else: - msg = ( - "User wants to revise the plan. Stay in plan mode. " - "Wait for the user's next message with feedback before revising." - ) - return ToolReturnValue( - is_error=False, - output=msg, - message="Plan revision requested", - display=[BriefDisplayBlock(text="Plan revision requested")], - ) diff --git a/src/kimi_cli/tools/plan/description.md b/src/kimi_cli/tools/plan/description.md deleted file mode 100644 index e827339f3..000000000 --- a/src/kimi_cli/tools/plan/description.md +++ /dev/null @@ -1,25 +0,0 @@ -Use this tool when you are in plan mode and have finished writing your plan to the plan file and are ready for user approval. - -## How This Tool Works -- You should have already written your plan to the plan file specified in the plan mode reminder. -- This tool does NOT take the plan content as a parameter — it reads the plan from the file you wrote. -- The user will see the contents of your plan file when they review it. - -## When to Use -Only use this tool for tasks that require planning implementation steps. For research tasks (searching files, reading code, understanding the codebase), do NOT use this tool. - -## Multiple Approaches -If your plan contains multiple alternative approaches: -- Pass them via the `options` parameter so the user can choose which approach to execute. -- Each option should have a concise label and a brief description of trade-offs. -- If you recommend one option, append "(Recommended)" to its label. -- The user will see all options alongside Reject and Revise choices. -- Provide 2-3 options at most (the system appends a "Reject" option automatically, so the total shown to the user is 3-4). -- Do NOT use "Reject", "Revise", or "Approve" as option labels — these are reserved by the system. - -## Before Using -- If you have unresolved questions, use AskUserQuestion first. -- If you have multiple approaches and haven't narrowed down yet, consider using AskUserQuestion first to let the user choose, then write a plan for the chosen approach only. -- Once your plan is finalized, use THIS tool to request approval. -- Do NOT use AskUserQuestion to ask "Is this plan OK?" or "Should I proceed?" — that is exactly what ExitPlanMode does. -- If rejected, revise based on feedback and call ExitPlanMode again. diff --git a/src/kimi_cli/tools/plan/enter.py b/src/kimi_cli/tools/plan/enter.py deleted file mode 100644 index cb8af0d6d..000000000 --- a/src/kimi_cli/tools/plan/enter.py +++ /dev/null @@ -1,183 +0,0 @@ -"""EnterPlanMode tool — lets the LLM request to enter plan mode.""" - -from __future__ import annotations - -import logging -from collections.abc import Awaitable, Callable -from pathlib import Path -from typing import override -from uuid import uuid4 - -from kosong.tooling import BriefDisplayBlock, CallableTool2, ToolError, ToolReturnValue -from pydantic import BaseModel - -from kimi_cli.soul import get_wire_or_none, wire_send -from kimi_cli.soul.toolset import get_current_tool_call_or_none -from kimi_cli.tools.utils import load_desc -from kimi_cli.wire.types import QuestionItem, QuestionNotSupported, QuestionOption, QuestionRequest - -logger = logging.getLogger(__name__) - -NAME = "EnterPlanMode" - -_DESCRIPTION = load_desc(Path(__file__).parent / "enter_description.md") - - -class Params(BaseModel): - pass - - -class EnterPlanMode(CallableTool2[Params]): - name: str = NAME - description: str = _DESCRIPTION - params: type[Params] = Params - - def __init__(self) -> None: - super().__init__() - self._toggle_callback: Callable[[], Awaitable[bool]] | None = None - self._plan_file_path_getter: Callable[[], Path | None] | None = None - self._plan_mode_checker: Callable[[], bool] | None = None - self._is_yolo: Callable[[], bool] | None = None - - def bind( - self, - toggle_callback: Callable[[], Awaitable[bool]], - plan_file_path_getter: Callable[[], Path | None], - plan_mode_checker: Callable[[], bool], - is_yolo: Callable[[], bool] | None = None, - ) -> None: - """Late-bind soul callbacks after KimiSoul is constructed.""" - self._toggle_callback = toggle_callback - self._plan_file_path_getter = plan_file_path_getter - self._plan_mode_checker = plan_mode_checker - self._is_yolo = is_yolo - - @override - async def __call__(self, params: Params) -> ToolReturnValue: - # Guard: already in plan mode - if self._plan_mode_checker and self._plan_mode_checker(): - return ToolError( - message="Already in plan mode. Use ExitPlanMode when your plan is ready.", - brief="Already in plan mode", - ) - - if not self._toggle_callback or not self._plan_file_path_getter: - return ToolError( - message="EnterPlanMode is not properly initialized.", - brief="Not initialized", - ) - - # In yolo mode, auto-approve entering plan mode - if self._is_yolo and self._is_yolo(): - await self._toggle_callback() - plan_path = self._plan_file_path_getter() - return ToolReturnValue( - is_error=False, - output=( - f"Plan mode activated (auto-approved in non-interactive mode).\n" - f"Plan file: {plan_path}\n" - f"Workflow: identify key questions about the codebase → " - f"use Agent(subagent_type='explore') to investigate if needed → " - f"design approach → " - f"modify the plan file with WriteFile or StrReplaceFile " - f"(create it with WriteFile first if it does not exist) → " - f"call ExitPlanMode.\n" - ), - message="Plan mode on (auto)", - display=[BriefDisplayBlock(text="Plan mode on (auto)")], - ) - - # Present confirmation dialog to user via QuestionRequest - wire = get_wire_or_none() - if wire is None: - return ToolError( - message="Cannot request user confirmation: Wire is not available.", - brief="Wire unavailable", - ) - - tool_call = get_current_tool_call_or_none() - if tool_call is None: - return ToolError( - message="EnterPlanMode must be called from a tool call context.", - brief="Invalid context", - ) - - request = QuestionRequest( - id=str(uuid4()), - tool_call_id=tool_call.id, - questions=[ - QuestionItem( - question="Enter plan mode?", - header="Plan Mode", - options=[ - QuestionOption( - label="Yes", - description="Enter plan mode to explore and design an approach", - ), - QuestionOption( - label="No", - description="Skip planning, start implementing now", - ), - ], - ) - ], - ) - - wire_send(request) - - try: - answers = await request.wait() - except QuestionNotSupported: - return ToolError( - message="The connected client does not support plan mode. " - "Do NOT call this tool again.", - brief="Client unsupported", - ) - except Exception: - logger.exception("Failed to get user response for EnterPlanMode") - return ToolError( - message="Failed to get user response.", - brief="Question failed", - ) - - if not answers: - return ToolReturnValue( - is_error=False, - output="User dismissed without choosing. Proceed with implementation directly.", - message="Dismissed", - display=[BriefDisplayBlock(text="Dismissed")], - ) - - # Parse user choice — exact match on option label - chose_yes = any(v == "Yes" for v in answers.values()) - if chose_yes: - await self._toggle_callback() - plan_path = self._plan_file_path_getter() - return ToolReturnValue( - is_error=False, - output=( - f"Plan mode activated. You MUST NOT edit code files — only read and plan.\n" - f"Plan file: {plan_path}\n" - f"Workflow: identify key questions about the codebase → " - f"use Agent(subagent_type='explore') to investigate if needed → " - f"design approach → " - f"modify the plan file with WriteFile or StrReplaceFile " - f"(create it with WriteFile first if it does not exist) → " - f"call ExitPlanMode.\n" - f"Use AskUserQuestion only to clarify missing requirements or choose " - f"between approaches.\n" - f"Do NOT use AskUserQuestion to ask about plan approval." - ), - message="Plan mode on", - display=[BriefDisplayBlock(text="Plan mode on")], - ) - else: - return ToolReturnValue( - is_error=False, - output=( - "User declined to enter plan mode. Please check with user whether " - "to proceed with implementation directly." - ), - message="Declined", - display=[BriefDisplayBlock(text="Declined")], - ) diff --git a/src/kimi_cli/tools/plan/enter_description.md b/src/kimi_cli/tools/plan/enter_description.md deleted file mode 100644 index 6f9860c86..000000000 --- a/src/kimi_cli/tools/plan/enter_description.md +++ /dev/null @@ -1,30 +0,0 @@ -Use this tool proactively when you're about to start a non-trivial implementation task. -Getting user sign-off on your approach before writing code prevents wasted effort. - -Use it when ANY of these conditions apply: - -1. New Feature Implementation — e.g. "Add a caching layer to the API" -2. Multiple Valid Approaches — e.g. "Optimize database queries" (indexing vs rewrite vs caching) -3. Code Modifications — e.g. "Refactor auth module to support OAuth" -4. Architectural Decisions — e.g. "Add WebSocket support" -5. Multi-File Changes — involves more than 2-3 files -6. Unclear Requirements — need exploration to understand scope -7. User Preferences Matter — if you'd use AskUserQuestion to clarify approach, use EnterPlanMode instead - -Yolo mode note: -- Yolo mode users chose continuous execution. -- In yolo mode, use EnterPlanMode only when the user explicitly asks for planning or when - there is exceptional architectural ambiguity that requires user input before proceeding. - -When NOT to use: -- Single-line or few-line fixes (typos, obvious bugs, small tweaks) -- User gave very specific, detailed instructions -- Pure research/exploration tasks - -## What Happens in Plan Mode -In plan mode, you will: -1. Identify 2-3 key questions about the codebase that are critical to your plan. If you are not confident about the codebase structure or relevant code paths, use `Agent(subagent_type="explore")` to investigate these questions first — this is strongly recommended for non-trivial tasks. -2. Explore the codebase using Glob, Grep, ReadFile (read-only) for any remaining quick lookups -3. Design an implementation approach based on your findings -4. Write your plan to a plan file -5. Present your plan to the user via ExitPlanMode for approval diff --git a/src/kimi_cli/tools/plan/heroes.py b/src/kimi_cli/tools/plan/heroes.py deleted file mode 100644 index c889d198f..000000000 --- a/src/kimi_cli/tools/plan/heroes.py +++ /dev/null @@ -1,277 +0,0 @@ -"""Plan file slug generation using Marvel and DC hero names.""" - -from __future__ import annotations - -import secrets -from pathlib import Path - -PLANS_DIR = Path.home() / ".kimi" / "plans" - -HERO_NAMES: list[str] = [ - # --- Marvel --- - "iron-man", - "spider-man", - "captain-america", - "thor", - "hulk", - "black-widow", - "hawkeye", - "black-panther", - "doctor-strange", - "scarlet-witch", - "vision", - "falcon", - "war-machine", - "ant-man", - "wasp", - "captain-marvel", - "gamora", - "star-lord", - "groot", - "rocket", - "drax", - "mantis", - "nebula", - "shang-chi", - "moon-knight", - "ms-marvel", - "she-hulk", - "echo", - "wolverine", - "cyclops", - "storm", - "jean-grey", - "rogue", - "beast", - "nightcrawler", - "colossus", - "shadowcat", - "jubilee", - "cable", - "deadpool", - "bishop", - "magik", - "iceman", - "archangel", - "psylocke", - "dazzler", - "forge", - "havok", - "polaris", - "emma-frost", - "namor", - "silver-surfer", - "adam-warlock", - "nova", - "quasar", - "sentry", - "blue-marvel", - "spectrum", - "squirrel-girl", - "cloak", - "dagger", - "punisher", - "elektra", - "luke-cage", - "iron-fist", - "jessica-jones", - "daredevil", - "blade", - "ghost-rider", - "morbius", - "venom", - "carnage", - "silk", - "spider-gwen", - "miles-morales", - "america-chavez", - "kate-bishop", - "yelena-belova", - "white-tiger", - "moon-girl", - "devil-dinosaur", - "amadeus-cho", - "riri-williams", - "kamala-khan", - "sam-alexander", - "nova-prime", - "medusa", - "black-bolt", - "crystal", - "karnak", - "gorgon", - "lockjaw", - "quake", - "mockingbird", - "bobbi-morse", - "maria-hill", - "nick-fury", - "phil-coulson", - "winter-soldier", - "us-agent", - "patriot", - "speed", - "wiccan", - "hulkling", - "stature", - "yellowjacket", - "tigra", - "hellcat", - "valkyrie", - "sif", - "beta-ray-bill", - "hercules", - "wonder-man", - "taskmaster", - "domino", - "cannonball", - "sunspot", - "wolfsbane", - "warpath", - "multiple-man", - "banshee", - "siryn", - "monet", - "rictor", - "shatterstar", - "longshot", - "daken", - "x-23", - "fantomex", - "batman", - "superman", - "wonder-woman", - "flash", - "aquaman", - "green-lantern", - "martian-manhunter", - "cyborg", - "hawkgirl", - "green-arrow", - "black-canary", - "zatanna", - "constantine", - "shazam", - "blue-beetle", - "booster-gold", - "firestorm", - "atom", - "hawkman", - "plastic-man", - "red-tornado", - "starfire", - "raven", - "beast-boy", - "robin", - "nightwing", - "batgirl", - "batwoman", - "red-hood", - "signal", - "orphan", - "spoiler", - "catwoman", - "huntress", - "supergirl", - "superboy", - "power-girl", - "steel", - "stargirl", - "wildcat", - "doctor-fate", - "mister-terrific", - "hourman", - "sandman", - "spectre", - "phantom-stranger", - "swamp-thing", - "animal-man", - "deadman", - "vixen", - "black-lightning", - "static", - "icon", - "rocket-dc", - "captain-atom", - "fire", - "ice", - "elongated-man", - "metamorpho", - "black-hawk", - "crimson-avenger", - "doctor-mid-nite", - "jakeem-thunder", - "mister-miracle", - "big-barda", - "orion", - "lightray", - "forager", - "killer-frost", - "jessica-cruz", - "simon-baz", - "john-stewart", - "guy-gardner", - "kyle-rayner", - "hal-jordan", - "wally-west", - "barry-allen", - "jay-garrick", - "impulse", - "kid-flash", - "donna-troy", - "tempest", - "aqualad", - "miss-martian", - "terra", - "jericho", - "ravager", - "red-star", - "pantha", - "argent", - "damage", - "jade", - "obsidian", - "cyclone", - "atom-smasher", - "maxima", - "starman", - "liberty-belle", -] - -_slug_cache: dict[str, str] = {} - - -def seed_slug_cache(session_id: str, slug: str) -> None: - """Pre-warm the in-process slug cache with a previously persisted slug.""" - _slug_cache[session_id] = slug - - -def get_or_create_slug(session_id: str) -> str: - """Get or create a plan file slug for the given session.""" - if session_id in _slug_cache: - return _slug_cache[session_id] - PLANS_DIR.mkdir(parents=True, exist_ok=True) - slug = "" - for _ in range(20): - words = [secrets.choice(HERO_NAMES) for _ in range(3)] - slug = "-".join(words) - if not (PLANS_DIR / f"{slug}.md").exists(): - break - else: - # All 20 attempts collided; append session prefix for uniqueness - slug = f"{slug}-{session_id[:8]}" - _slug_cache[session_id] = slug - return slug - - -def get_plan_file_path(session_id: str) -> Path: - """Get the plan file path for the given session.""" - return PLANS_DIR / f"{get_or_create_slug(session_id)}.md" - - -def read_plan_file(session_id: str) -> str | None: - """Read the plan file content for the given session, or None if not found.""" - path = get_plan_file_path(session_id) - if path.exists(): - return path.read_text(encoding="utf-8") - return None diff --git a/src/kimi_cli/tools/plan/heroes.ts b/src/kimi_cli/tools/plan/heroes.ts new file mode 100644 index 000000000..513eea744 --- /dev/null +++ b/src/kimi_cli/tools/plan/heroes.ts @@ -0,0 +1,292 @@ +/** + * Plan file slug generation using Marvel and DC hero names. + * Corresponds to Python tools/plan/heroes.py + */ + +import { join } from "node:path"; +import { homedir } from "node:os"; +import { existsSync, mkdirSync, readFileSync } from "node:fs"; + +export const PLANS_DIR = join(homedir(), ".kimi", "plans"); + +export const HERO_NAMES: string[] = [ + // --- Marvel --- + "iron-man", + "spider-man", + "captain-america", + "thor", + "hulk", + "black-widow", + "hawkeye", + "black-panther", + "doctor-strange", + "scarlet-witch", + "vision", + "falcon", + "war-machine", + "ant-man", + "wasp", + "captain-marvel", + "gamora", + "star-lord", + "groot", + "rocket", + "drax", + "mantis", + "nebula", + "shang-chi", + "moon-knight", + "ms-marvel", + "she-hulk", + "echo", + "wolverine", + "cyclops", + "storm", + "jean-grey", + "rogue", + "beast", + "nightcrawler", + "colossus", + "shadowcat", + "jubilee", + "cable", + "deadpool", + "bishop", + "magik", + "iceman", + "archangel", + "psylocke", + "dazzler", + "forge", + "havok", + "polaris", + "emma-frost", + "namor", + "silver-surfer", + "adam-warlock", + "nova", + "quasar", + "sentry", + "blue-marvel", + "spectrum", + "squirrel-girl", + "cloak", + "dagger", + "punisher", + "elektra", + "luke-cage", + "iron-fist", + "jessica-jones", + "daredevil", + "blade", + "ghost-rider", + "morbius", + "venom", + "carnage", + "silk", + "spider-gwen", + "miles-morales", + "america-chavez", + "kate-bishop", + "yelena-belova", + "white-tiger", + "moon-girl", + "devil-dinosaur", + "amadeus-cho", + "riri-williams", + "kamala-khan", + "sam-alexander", + "nova-prime", + "medusa", + "black-bolt", + "crystal", + "karnak", + "gorgon", + "lockjaw", + "quake", + "mockingbird", + "bobbi-morse", + "maria-hill", + "nick-fury", + "phil-coulson", + "winter-soldier", + "us-agent", + "patriot", + "speed", + "wiccan", + "hulkling", + "stature", + "yellowjacket", + "tigra", + "hellcat", + "valkyrie", + "sif", + "beta-ray-bill", + "hercules", + "wonder-man", + "taskmaster", + "domino", + "cannonball", + "sunspot", + "wolfsbane", + "warpath", + "multiple-man", + "banshee", + "siryn", + "monet", + "rictor", + "shatterstar", + "longshot", + "daken", + "x-23", + "fantomex", + // --- DC --- + "batman", + "superman", + "wonder-woman", + "flash", + "aquaman", + "green-lantern", + "martian-manhunter", + "cyborg", + "hawkgirl", + "green-arrow", + "black-canary", + "zatanna", + "constantine", + "shazam", + "blue-beetle", + "booster-gold", + "firestorm", + "atom", + "hawkman", + "plastic-man", + "red-tornado", + "starfire", + "raven", + "beast-boy", + "robin", + "nightwing", + "batgirl", + "batwoman", + "red-hood", + "signal", + "orphan", + "spoiler", + "catwoman", + "huntress", + "supergirl", + "superboy", + "power-girl", + "steel", + "stargirl", + "wildcat", + "doctor-fate", + "mister-terrific", + "hourman", + "sandman", + "spectre", + "phantom-stranger", + "swamp-thing", + "animal-man", + "deadman", + "vixen", + "black-lightning", + "static", + "icon", + "rocket-dc", + "captain-atom", + "fire", + "ice", + "elongated-man", + "metamorpho", + "black-hawk", + "crimson-avenger", + "doctor-mid-nite", + "jakeem-thunder", + "mister-miracle", + "big-barda", + "orion", + "lightray", + "forager", + "killer-frost", + "jessica-cruz", + "simon-baz", + "john-stewart", + "guy-gardner", + "kyle-rayner", + "hal-jordan", + "wally-west", + "barry-allen", + "jay-garrick", + "impulse", + "kid-flash", + "donna-troy", + "tempest", + "aqualad", + "miss-martian", + "terra", + "jericho", + "ravager", + "red-star", + "pantha", + "argent", + "damage", + "jade", + "obsidian", + "cyclone", + "atom-smasher", + "maxima", + "starman", + "liberty-belle", +]; + +const _slugCache = new Map(); + +/** Pre-warm the in-process slug cache with a previously persisted slug. */ +export function seedSlugCache(sessionId: string, slug: string): void { + _slugCache.set(sessionId, slug); +} + +/** Get or create a plan file slug for the given session. */ +export function getOrCreateSlug(sessionId: string): string { + const cached = _slugCache.get(sessionId); + if (cached) return cached; + + mkdirSync(PLANS_DIR, { recursive: true }); + + let slug = ""; + for (let i = 0; i < 20; i++) { + const words: string[] = []; + for (let j = 0; j < 3; j++) { + words.push(HERO_NAMES[Math.floor(Math.random() * HERO_NAMES.length)]!); + } + slug = words.join("-"); + if (!existsSync(join(PLANS_DIR, `${slug}.md`))) { + break; + } + // If last attempt and still colliding, append session prefix + if (i === 19) { + slug = `${slug}-${sessionId.slice(0, 8)}`; + } + } + + _slugCache.set(sessionId, slug); + return slug; +} + +/** Get the plan file path for the given session. */ +export function getPlanFilePath(sessionId: string): string { + return join(PLANS_DIR, `${getOrCreateSlug(sessionId)}.md`); +} + +/** Read the plan file content for the given session, or null if not found. */ +export function readPlanFile(sessionId: string): string | null { + const path = getPlanFilePath(sessionId); + try { + if (!existsSync(path)) return null; + return readFileSync(path, "utf-8"); + } catch { + return null; + } +} diff --git a/src/kimi_cli/tools/plan/plan.ts b/src/kimi_cli/tools/plan/plan.ts new file mode 100644 index 000000000..769f60ff6 --- /dev/null +++ b/src/kimi_cli/tools/plan/plan.ts @@ -0,0 +1,337 @@ +/** + * Plan mode tools — lets the LLM enter/exit plan mode. + * Corresponds to Python tools/plan/enter.py and tools/plan/__init__.py + */ + +import { existsSync, mkdirSync } from "node:fs"; +import { dirname } from "node:path"; +import { z } from "zod/v4"; +import { CallableTool } from "../base.ts"; +import type { ToolContext, ToolResult } from "../types.ts"; +import { ToolError, ToolOk } from "../types.ts"; +import { getPlanFilePath, readPlanFile } from "./heroes.ts"; + +// ── EnterPlanMode ────────────────────────────────── + +const ENTER_DESCRIPTION = `Use this tool proactively when you're about to start a non-trivial implementation task. +Getting user sign-off on your approach before writing code prevents wasted effort. + +Use it when ANY of these conditions apply: +1. New Feature Implementation +2. Multiple Valid Approaches +3. Code Modifications +4. Architectural Decisions +5. Multi-File Changes +6. Unclear Requirements +7. User Preferences Matter + +When NOT to use: +- Single-line or few-line fixes +- User gave very specific, detailed instructions +- Pure research/exploration tasks`; + +const EnterParamsSchema = z.object({}); + +export class EnterPlanMode extends CallableTool { + readonly name = "EnterPlanMode"; + readonly description = ENTER_DESCRIPTION; + readonly schema = EnterParamsSchema; + + /** Session ID for plan file management. */ + private _sessionId: string; + private _isYolo: (() => boolean) | null = null; + + constructor(sessionId = "default") { + super(); + this._sessionId = sessionId; + } + + /** Bind optional YOLO mode checker. */ + bindYolo(isYolo: () => boolean): void { + this._isYolo = isYolo; + } + + async execute(_params: unknown, ctx: ToolContext): Promise { + // Guard: already in plan mode + if (ctx.getPlanMode?.()) { + return ToolError( + "Already in plan mode. Use ExitPlanMode when your plan is ready.", + ); + } + + const planPath = getPlanFilePath(this._sessionId); + + // In YOLO mode, auto-approve entering plan mode + if (this._isYolo?.()) { + ctx.setPlanMode?.(true); + return ToolOk( + `Plan mode activated (auto-approved in non-interactive mode).\n` + + `Plan file: ${planPath}\n` + + `Workflow: identify key questions about the codebase → ` + + `use Agent(subagent_type='explore') to investigate if needed → ` + + `design approach → ` + + `modify the plan file with WriteFile or StrReplaceFile ` + + `(create it with WriteFile first if it does not exist) → ` + + `call ExitPlanMode.\n`, + "Plan mode on (auto)", + ); + } + + // In interactive mode, ask user for confirmation + if (ctx.askUser) { + try { + const answer = await ctx.askUser("Enter plan mode?", ["Yes", "No"]); + + if (answer === "Yes") { + ctx.setPlanMode?.(true); + return ToolOk( + `Plan mode activated. You MUST NOT edit code files — only read and plan.\n` + + `Plan file: ${planPath}\n` + + `Workflow: identify key questions about the codebase → ` + + `use Agent(subagent_type='explore') to investigate if needed → ` + + `design approach → ` + + `modify the plan file with WriteFile or StrReplaceFile ` + + `(create it with WriteFile first if it does not exist) → ` + + `call ExitPlanMode.\n` + + `Use AskUserQuestion only to clarify missing requirements or choose ` + + `between approaches.\n` + + `Do NOT use AskUserQuestion to ask about plan approval.`, + "Plan mode on", + ); + } else { + return ToolOk( + "User declined to enter plan mode. Please check with user whether " + + "to proceed with implementation directly.", + "Declined", + ); + } + } catch { + // askUser not supported, fall through to auto-enter + } + } + + // Fallback: auto-enter plan mode + ctx.setPlanMode?.(true); + return ToolOk( + "Entered plan mode. You should now focus on exploring the codebase and designing an implementation approach.\n" + + "In plan mode, you should:\n" + + "1. Thoroughly explore the codebase to understand existing patterns\n" + + "2. Consider multiple approaches and their trade-offs\n" + + "3. Design a concrete implementation strategy\n" + + `4. Write your plan to: ${planPath}\n` + + "5. When ready, use ExitPlanMode to present your plan for approval\n" + + "\n" + + "Remember: DO NOT write or edit any files yet. This is a read-only exploration and planning phase.", + "Plan mode activated.", + ); + } +} + +// ── ExitPlanMode ────────────────────────────────── + +const EXIT_DESCRIPTION = `Use this tool when you are in plan mode and have finished writing your plan. +This signals that you're done planning and ready for the user to review and approve. + +IMPORTANT: Only use this tool when the task requires planning the implementation of a task that requires writing code.`; + +const RESERVED_LABELS = new Set(["reject", "revise", "approve", "reject and exit"]); + +const PlanOptionSchema = z.object({ + label: z.string().describe( + "Short name for this option (1-8 words). Append '(Recommended)' if you recommend this option.", + ), + description: z.string().default("").describe( + "Brief summary of this approach and its trade-offs.", + ), +}); + +const ExitParamsSchema = z.object({ + options: z + .array(PlanOptionSchema) + .max(3) + .nullish() + .describe( + "When the plan contains multiple alternative approaches, list them here " + + "so the user can choose which one to execute. 2-3 options. " + + "Do not use 'Reject', 'Revise', 'Approve', or 'Reject and Exit' as labels.", + ), +}); + +type ExitParams = z.infer; + +export class ExitPlanMode extends CallableTool { + readonly name = "ExitPlanMode"; + readonly description = EXIT_DESCRIPTION; + readonly schema = ExitParamsSchema; + + /** Session ID for plan file management. */ + private _sessionId: string; + private _isYolo: (() => boolean) | null = null; + + constructor(sessionId = "default") { + super(); + this._sessionId = sessionId; + } + + /** Bind optional YOLO mode checker. */ + bindYolo(isYolo: () => boolean): void { + this._isYolo = isYolo; + } + + async execute(params: ExitParams, ctx: ToolContext): Promise { + // Guard: only works in plan mode + if (!ctx.getPlanMode?.()) { + return ToolError( + "Not in plan mode. ExitPlanMode is only available during plan mode.", + ); + } + + // Read the plan file + const planPath = getPlanFilePath(this._sessionId); + const planContent = readPlanFile(this._sessionId); + + if (!planContent) { + return ToolError( + `No plan file found. Write your plan to ${planPath} first, then call ExitPlanMode.`, + ); + } + + // Validate option labels + if (params.options) { + for (const opt of params.options) { + if (RESERVED_LABELS.has(opt.label.trim().toLowerCase())) { + return ToolError( + `Option label '${opt.label}' is reserved. Do not use Reject, Revise, Approve, or Reject and Exit as option labels.`, + ); + } + } + // Check uniqueness + const labels = params.options.map((o) => o.label); + if (new Set(labels).size !== labels.length) { + return ToolError("Option labels must be unique."); + } + } + + // In YOLO mode, auto-approve + if (this._isYolo?.()) { + ctx.setPlanMode?.(false); + return ToolOk( + `Plan approved (auto-approved in non-interactive mode). ` + + `Plan mode deactivated. All tools are now available.\n` + + `Plan saved to: ${planPath}\n\n` + + `## Approved Plan:\n${planContent}`, + "Plan approved (auto)", + ); + } + + const hasOptions = params.options != null && params.options.length >= 2; + + // Interactive mode: ask user for approval + if (ctx.askUser) { + try { + // Build option list + let choices: string[]; + if (hasOptions) { + choices = [ + ...params.options!.map((o) => o.label), + "Reject", + "Reject and Exit", + ]; + } else { + choices = ["Approve", "Reject", "Reject and Exit"]; + } + + // Display plan content via wireEmit if available + ctx.wireEmit?.({ + type: "plan_display", + content: planContent, + filePath: planPath, + }); + + const answer = await ctx.askUser("Approve this plan", choices); + + // Handle the answer + if (answer === "Reject and Exit") { + ctx.setPlanMode?.(false); + return ToolError( + "Plan rejected by user. Plan mode deactivated. " + + "All tools are now available. " + + "Wait for the user's next message.", + ); + } + + if (answer === "Reject") { + return ToolError( + "Plan rejected by user. Stay in plan mode. " + + "The user will provide feedback via conversation. " + + "Wait for the user's next message before revising.", + ); + } + + // Approve — multi-approach (user selected a specific option) + if (hasOptions) { + const optionLabels = new Set(params.options!.map((o) => o.label)); + if (optionLabels.has(answer)) { + ctx.setPlanMode?.(false); + return ToolOk( + `Plan approved by user. Selected approach: "${answer}"\n` + + `Plan mode deactivated. All tools are now available.\n` + + `Plan saved to: ${planPath}\n\n` + + `IMPORTANT: Execute ONLY the selected approach "${answer}". ` + + `Ignore other approaches in the plan.\n\n` + + `## Approved Plan:\n${planContent}`, + `Plan approved: ${answer}`, + ); + } + } + + // Approve — single-approach + if (answer === "Approve") { + ctx.setPlanMode?.(false); + return ToolOk( + `Plan approved by user. Plan mode deactivated. ` + + `All tools are now available.\n` + + `Plan saved to: ${planPath}\n\n` + + `## Approved Plan:\n${planContent}`, + "Plan approved", + ); + } + + // Revise — user provided free-text feedback + if (answer) { + return ToolOk( + `User wants to revise the plan. Stay in plan mode. ` + + `Revise based on the feedback below.\n\n` + + `User feedback: ${answer}`, + "Plan revision requested", + ); + } + + return ToolOk( + "User dismissed without choosing. Plan mode remains active. " + + "Continue working on your plan or call ExitPlanMode again when ready.", + "Dismissed", + ); + } catch { + // askUser not supported, auto-approve + ctx.setPlanMode?.(false); + return ToolOk( + `Plan approved (client does not support interactive review). ` + + `Plan mode deactivated.\n` + + `Plan saved to: ${planPath}\n\n` + + `## Approved Plan:\n${planContent}`, + "Plan approved", + ); + } + } + + // Fallback: auto-exit with plan content + ctx.setPlanMode?.(false); + return ToolOk( + `Exited plan mode. All tools are now available.\n` + + `Plan saved to: ${planPath}\n\n` + + `## Plan:\n${planContent}`, + "Plan mode deactivated.", + ); + } +} diff --git a/src/kimi_cli/tools/registry.ts b/src/kimi_cli/tools/registry.ts new file mode 100644 index 000000000..c69d695a2 --- /dev/null +++ b/src/kimi_cli/tools/registry.ts @@ -0,0 +1,97 @@ +/** + * Tool registry — register, find, and list all tools. + * Also acts as a DI container for ToolContext. + * Corresponds to Python tools/__init__.py and tools/registry. + */ + +import type { CallableTool } from "./base.ts"; +import type { ToolContext, ToolDefinition, ToolResult } from "./types.ts"; +import { SkipThisTool, extractKeyArgument } from "./types.ts"; + +// Re-export for convenience +export { SkipThisTool, extractKeyArgument }; + +export class ToolRegistry { + private tools = new Map(); + private _ctx: ToolContext; + + constructor(ctx: ToolContext) { + this._ctx = ctx; + } + + get context(): ToolContext { + return this._ctx; + } + + /** Register a tool instance. Silently skips if SkipThisTool is thrown during construction. */ + register(tool: CallableTool): void { + if (this.tools.has(tool.name)) { + throw new Error(`Tool "${tool.name}" is already registered.`); + } + this.tools.set(tool.name, tool); + } + + /** + * Safely register a tool, catching SkipThisTool during construction. + * Returns true if registered, false if skipped. + */ + tryRegister(factory: () => CallableTool): boolean { + try { + const tool = factory(); + this.register(tool); + return true; + } catch (e) { + if (e instanceof SkipThisTool) { + return false; + } + throw e; + } + } + + /** Find a tool by name. */ + find(name: string): CallableTool | undefined { + return this.tools.get(name); + } + + /** List all registered tools. */ + list(): CallableTool[] { + return [...this.tools.values()]; + } + + /** Get all tool definitions for LLM function calling. */ + definitions(): ToolDefinition[] { + return this.list().map((t) => t.toDefinition()); + } + + /** Execute a tool by name with raw JSON arguments. */ + async execute( + name: string, + rawArgs: Record, + ): Promise { + const tool = this.tools.get(name); + if (!tool) { + return { + isError: true, + output: "", + message: `Tool "${name}" not found.`, + }; + } + + // Validate params through tool schema + const parsed = tool.schema.safeParse(rawArgs); + if (!parsed.success) { + return { + isError: true, + output: "", + message: `Invalid parameters for tool "${name}": ${parsed.error.message}`, + }; + } + + return tool.execute(parsed.data, this._ctx); + } + + /** Extract a key argument for display/logging from raw JSON arguments. */ + extractKeyArgument(jsonContent: string, toolName: string): string | null { + return extractKeyArgument(jsonContent, toolName); + } +} diff --git a/src/kimi_cli/tools/shell/__init__.py b/src/kimi_cli/tools/shell/__init__.py deleted file mode 100644 index 1796ddded..000000000 --- a/src/kimi_cli/tools/shell/__init__.py +++ /dev/null @@ -1,235 +0,0 @@ -import asyncio -from collections.abc import Callable -from pathlib import Path -from typing import Self, override - -import kaos -from kaos import AsyncReadable -from kosong.tooling import CallableTool2, ToolReturnValue -from pydantic import BaseModel, Field, model_validator - -from kimi_cli.background import TaskView, format_task -from kimi_cli.soul.agent import Runtime -from kimi_cli.soul.approval import Approval -from kimi_cli.soul.toolset import get_current_tool_call_or_none -from kimi_cli.tools.display import BackgroundTaskDisplayBlock, ShellDisplayBlock -from kimi_cli.tools.utils import ToolResultBuilder, load_desc -from kimi_cli.utils.environment import Environment -from kimi_cli.utils.subprocess_env import get_noninteractive_env - -MAX_FOREGROUND_TIMEOUT = 5 * 60 -MAX_BACKGROUND_TIMEOUT = 24 * 60 * 60 - - -class Params(BaseModel): - command: str = Field(description="The command to execute.") - timeout: int = Field( - description=( - "The timeout in seconds for the command to execute. " - "If the command takes longer than this, it will be killed." - ), - default=60, - ge=1, - le=MAX_BACKGROUND_TIMEOUT, - ) - run_in_background: bool = Field( - default=False, - description="Whether to run the command as a background task.", - ) - description: str = Field( - default="", - description=( - "A short description for the background task. Required when run_in_background=true." - ), - ) - - @model_validator(mode="after") - def _validate_background_fields(self) -> Self: - if self.run_in_background and not self.description.strip(): - raise ValueError("description is required when run_in_background is true") - if not self.run_in_background and self.timeout > MAX_FOREGROUND_TIMEOUT: - raise ValueError( - f"timeout must be <= {MAX_FOREGROUND_TIMEOUT}s for foreground commands; " - f"use run_in_background=true for longer timeouts (up to {MAX_BACKGROUND_TIMEOUT}s)" - ) - return self - - -class Shell(CallableTool2[Params]): - name: str = "Shell" - params: type[Params] = Params - - def __init__(self, approval: Approval, environment: Environment, runtime: Runtime): - is_powershell = environment.shell_name == "Windows PowerShell" - super().__init__( - description=load_desc( - Path(__file__).parent / ("powershell.md" if is_powershell else "bash.md"), - {"SHELL": f"{environment.shell_name} (`{environment.shell_path}`)"}, - ) - ) - self._approval = approval - self._is_powershell = is_powershell - self._shell_path = environment.shell_path - self._runtime = runtime - - @override - async def __call__(self, params: Params) -> ToolReturnValue: - builder = ToolResultBuilder() - - if not params.command: - return builder.error("Command cannot be empty.", brief="Empty command") - - if params.run_in_background: - return await self._run_in_background(params) - - result = await self._approval.request( - self.name, - "run command", - f"Run command `{params.command}`", - display=[ - ShellDisplayBlock( - language="powershell" if self._is_powershell else "bash", - command=params.command, - ) - ], - ) - if not result: - return result.rejection_error() - - def stdout_cb(line: bytes): - line_str = line.decode(encoding="utf-8", errors="replace") - builder.write(line_str) - - def stderr_cb(line: bytes): - line_str = line.decode(encoding="utf-8", errors="replace") - builder.write(line_str) - - try: - exitcode = await self._run_shell_command( - params.command, stdout_cb, stderr_cb, params.timeout - ) - - if exitcode == 0: - return builder.ok("Command executed successfully.") - else: - return builder.error( - f"Command failed with exit code: {exitcode}.", - brief=f"Failed with exit code: {exitcode}", - ) - except TimeoutError: - return builder.error( - f"Command killed by timeout ({params.timeout}s)", - brief=f"Killed by timeout ({params.timeout}s)", - ) - - async def _run_in_background(self, params: Params) -> ToolReturnValue: - tool_call = get_current_tool_call_or_none() - if tool_call is None: - return ToolResultBuilder().error( - "Background shell requires a tool call context.", - brief="No tool call context", - ) - - result = await self._approval.request( - self.name, - "run background command", - f"Run background command `{params.command}`", - display=[ - ShellDisplayBlock( - language="powershell" if self._is_powershell else "bash", - command=params.command, - ) - ], - ) - if not result: - return result.rejection_error() - - try: - view = self._runtime.background_tasks.create_bash_task( - command=params.command, - description=params.description.strip(), - timeout_s=params.timeout, - tool_call_id=tool_call.id, - shell_name="Windows PowerShell" if self._is_powershell else "bash", - shell_path=str(self._shell_path), - cwd=str(self._runtime.session.work_dir), - ) - except Exception as exc: - builder = ToolResultBuilder() - return builder.error(f"Failed to start background task: {exc}", brief="Start failed") - - return self._background_ok(view) - - def _background_ok(self, view: TaskView) -> ToolReturnValue: - builder = ToolResultBuilder() - builder.write( - "\n".join( - [ - format_task(view, include_command=True), - "automatic_notification: true", - "next_step: You will be automatically notified when it completes.", - ( - "next_step: Use TaskOutput with this task_id for a non-blocking " - "status/output snapshot. Only set block=true when you intentionally " - "want to wait." - ), - "next_step: Use TaskStop only if the task must be cancelled.", - ( - "human_shell_hint: For users in the interactive shell, " - "the only task-management slash command is /task. " - "Do not suggest /task list, /task output, /task stop, or /tasks." - ), - ] - ) - ) - builder.display( - BackgroundTaskDisplayBlock( - task_id=view.spec.id, - kind=view.spec.kind, - status=view.runtime.status, - description=view.spec.description, - ) - ) - return builder.ok("Background task started", brief=f"Started {view.spec.id}") - - async def _run_shell_command( - self, - command: str, - stdout_cb: Callable[[bytes], None], - stderr_cb: Callable[[bytes], None], - timeout: int, - ) -> int: - async def _read_stream(stream: AsyncReadable, cb: Callable[[bytes], None]): - while True: - line = await stream.readline() - if line: - cb(line) - else: - break - - process = await kaos.exec(*self._shell_args(command), env=get_noninteractive_env()) - - # Close stdin immediately so interactive prompts (e.g. git password) get - # EOF instead of hanging forever waiting for input that will never come. - process.stdin.close() - - try: - await asyncio.wait_for( - asyncio.gather( - _read_stream(process.stdout, stdout_cb), - _read_stream(process.stderr, stderr_cb), - ), - timeout, - ) - return await process.wait() - except asyncio.CancelledError: - await process.kill() - raise - except TimeoutError: - await process.kill() - raise - - def _shell_args(self, command: str) -> tuple[str, ...]: - if self._is_powershell: - return (str(self._shell_path), "-command", command) - return (str(self._shell_path), "-c", command) diff --git a/src/kimi_cli/tools/shell/bash.md b/src/kimi_cli/tools/shell/bash.md deleted file mode 100644 index 010d27af2..000000000 --- a/src/kimi_cli/tools/shell/bash.md +++ /dev/null @@ -1,35 +0,0 @@ -Execute a ${SHELL} command. Use this tool to explore the filesystem, edit files, run scripts, get system information, etc. - -**Output:** -The stdout and stderr will be combined and returned as a string. The output may be truncated if it is too long. If the command failed, the exit code will be provided in a system tag. - -If `run_in_background=true`, the command will be started as a background task and this tool will return a task ID instead of waiting for command completion. When doing that, you must provide a short `description`. You will be automatically notified when the task completes. Use `TaskOutput` for a non-blocking status/output snapshot, and only set `block=true` when you explicitly want to wait for completion. Use `TaskStop` only if the task must be cancelled. For human users in the interactive shell, background tasks are managed through `/task` only; do not suggest `/task list`, `/task output`, `/task stop`, `/tasks`, or any other invented shell subcommands. - -**Guidelines for safety and security:** -- Each shell tool call will be executed in a fresh shell environment. The shell variables, current working directory changes, and the shell history is not preserved between calls. -- The tool call will return after the command is finished. You shall not use this tool to execute an interactive command or a command that may run forever. For possibly long-running commands, you shall set `timeout` argument to a reasonable value. -- Avoid using `..` to access files or directories outside of the working directory. -- Avoid modifying files outside of the working directory unless explicitly instructed to do so. -- Never run commands that require superuser privileges unless explicitly instructed to do so. - -**Guidelines for efficiency:** -- For multiple related commands, use `&&` to chain them in a single call, e.g. `cd /path && ls -la` -- Use `;` to run commands sequentially regardless of success/failure -- Use `||` for conditional execution (run second command only if first fails) -- Use pipe operations (`|`) and redirections (`>`, `>>`) to chain input and output between commands -- Always quote file paths containing spaces with double quotes (e.g., cd "/path with spaces/") -- Use `if`, `case`, `for`, `while` control flows to execute complex logic in a single call. -- Verify directory structure before create/edit/delete files or directories to reduce the risk of failure. -- Prefer `run_in_background=true` for long-running builds, tests, watchers, or servers when you need the conversation to continue before the command finishes. -- After starting a background task, do not guess its outcome. Rely on the automatic completion notification whenever possible. Use `TaskOutput` for non-blocking progress snapshots by default, and set `block=true` only when you intentionally want to wait. -- If you need to tell a human shell user how to manage background tasks, only mention `/task`. Do not invent `/task list`, `/task output`, `/task stop`, or `/tasks`. - -**Commands available:** -- Shell environment: cd, pwd, export, unset, env -- File system operations: ls, find, mkdir, rm, cp, mv, touch, chmod, chown -- File viewing/editing: cat, grep, head, tail, diff, patch -- Text processing: awk, sed, sort, uniq, wc -- System information/operations: ps, kill, top, df, free, uname, whoami, id, date -- Network operations: curl, wget, ping, telnet, ssh -- Archive operations: tar, zip, unzip -- Other: Other commands available in the shell environment. Check the existence of a command by running `which ` before using it. diff --git a/src/kimi_cli/tools/shell/powershell.md b/src/kimi_cli/tools/shell/powershell.md deleted file mode 100644 index 9696a3a1d..000000000 --- a/src/kimi_cli/tools/shell/powershell.md +++ /dev/null @@ -1,30 +0,0 @@ -Execute a ${SHELL} command. Use this tool to explore the filesystem, inspect or edit files, run Windows scripts, collect system information, etc., whenever the agent is running on Windows. - -Note that you are running on Windows, so make sure to use Windows commands, paths, and conventions. - -**Output:** -The stdout and stderr streams are combined and returned as a single string. Extremely long output may be truncated. When a command fails, the exit code is provided in a system tag. - -If `run_in_background=true`, the command will be started as a background task and this tool will return a task ID instead of waiting for completion. When doing that, you must provide a short `description`. You will be automatically notified when the task completes. Use `TaskOutput` for a non-blocking status/output snapshot, and only set `block=true` when you explicitly want to wait for completion. Use `TaskStop` only if the task must be cancelled. For human users in the interactive shell, background tasks are managed through `/task` only; do not suggest `/task list`, `/task output`, `/task stop`, `/tasks`, or any other invented shell subcommands. - -**Guidelines for safety and security:** -- Every tool call starts a fresh ${SHELL} session. Environment variables, `cd` changes, and command history do not persist between calls. -- Do not launch interactive programs or anything that is expected to block indefinitely; ensure each command finishes promptly. Provide a `timeout` argument for potentially long runs. -- Avoid using `..` to leave the working directory, and never touch files outside that directory unless explicitly instructed. -- Never attempt commands that require elevated (Administrator) privileges unless explicitly authorized. - -**Guidelines for efficiency:** -- Chain related commands with `;` and use `if ($?)` or `if (-not $?)` to conditionally execute commands based on the success or failure of previous ones. -- Redirect or pipe output with `>`, `>>`, `|`, and leverage `for /f`, `if`, and `set` to build richer one-liners instead of multiple tool calls. -- Reuse built-in utilities (e.g., `findstr`, `where`) to filter, transform, or locate data in a single invocation. -- Prefer `run_in_background=true` for long-running builds, tests, watchers, or servers when you need the conversation to continue before the command finishes. -- After starting a background task, do not guess its outcome. Rely on the automatic completion notification whenever possible. Use `TaskOutput` for non-blocking progress snapshots by default, and set `block=true` only when you intentionally want to wait. -- If you need to tell a human shell user how to manage background tasks, only mention `/task`. Do not invent `/task list`, `/task output`, `/task stop`, or `/tasks`. - -**Commands available:** -- Shell environment: `cd`, `dir`, `set`, `setlocal`, `echo`, `call`, `where` -- File operations: `type`, `copy`, `move`, `del`, `erase`, `mkdir`, `rmdir`, `attrib`, `mklink` -- Text/search: `find`, `findstr`, `more`, `sort`, `Get-Content` -- System info: `ver`, `systeminfo`, `tasklist`, `wmic`, `hostname` -- Archives/scripts: `tar`, `Compress-Archive`, `powershell`, `python`, `node` -- Other: Any other binaries available on the system PATH; run `where ` first if unsure. diff --git a/src/kimi_cli/tools/shell/shell.ts b/src/kimi_cli/tools/shell/shell.ts new file mode 100644 index 000000000..b01a31a7c --- /dev/null +++ b/src/kimi_cli/tools/shell/shell.ts @@ -0,0 +1,213 @@ +/** + * Shell tool — execute shell commands. + * Corresponds to Python tools/shell/__init__.py + */ + +import { z } from "zod/v4"; +import { CallableTool } from "../base.ts"; +import type { ToolContext, ToolResult } from "../types.ts"; +import { ToolError, ToolResultBuilder } from "../types.ts"; + +const MAX_FOREGROUND_TIMEOUT = 5 * 60; // 5 minutes +const MAX_BACKGROUND_TIMEOUT = 24 * 60 * 60; // 24 hours + +const DESCRIPTION = `Execute a shell command. Use this tool to explore the filesystem, edit files, run scripts, get system information, etc. + +**Output:** +The stdout and stderr will be combined and returned as a string. The output may be truncated if it is too long. + +**Guidelines for safety and security:** +- Each shell tool call will be executed in a fresh shell environment. +- Avoid using \`..\ to access files outside of the working directory. +- Never run commands that require superuser privileges unless explicitly instructed. + +**Guidelines for efficiency:** +- For multiple related commands, use \`&&\` to chain them in a single call. +- Prefer \`run_in_background=true\` for long-running builds, tests, or servers.`; + +const ParamsSchema = z + .object({ + command: z.string().describe("The command to execute."), + timeout: z + .number() + .int() + .min(1) + .max(MAX_BACKGROUND_TIMEOUT) + .default(60) + .describe("The timeout in seconds for the command to execute."), + run_in_background: z + .boolean() + .default(false) + .describe("Whether to run the command as a background task."), + description: z + .string() + .default("") + .describe( + "A short description for the background task. Required when run_in_background=true.", + ), + }) + .refine( + (data) => !data.run_in_background || data.description.trim().length > 0, + { + message: "description is required when run_in_background is true", + path: ["description"], + }, + ) + .refine( + (data) => + data.run_in_background || data.timeout <= MAX_FOREGROUND_TIMEOUT, + { + message: `timeout must be <= ${MAX_FOREGROUND_TIMEOUT}s for foreground commands; use run_in_background=true for longer timeouts`, + path: ["timeout"], + }, + ); + +type Params = z.infer; + +/** Build a non-interactive environment to prevent prompts from hanging. */ +function getNoninteractiveEnv(): Record { + return { + ...process.env, + GIT_TERMINAL_PROMPT: "0", + TERM: "dumb", + // Prevent SSH from trying to open a tty for passphrase/password + SSH_ASKPASS: "", + SSH_ASKPASS_REQUIRE: "never", + // Prevent GPG pinentry + GPG_TTY: "", + // Disable pager for git, man, etc. + GIT_PAGER: "cat", + PAGER: "cat", + // Disable color in common tools (helps with output parsing) + NO_COLOR: "1", + }; +} + +/** Read a stream and write chunks to builder, interleaving stdout and stderr. */ +async function readStreamToBuilder( + stream: ReadableStream | null, + builder: ToolResultBuilder, +): Promise { + if (!stream) return ""; + const reader = stream.getReader(); + const decoder = new TextDecoder("utf-8", { fatal: false }); + const chunks: string[] = []; + try { + while (true) { + const { done, value } = await reader.read(); + if (done) break; + const text = decoder.decode(value, { stream: true }); + chunks.push(text); + builder.write(text); + } + } finally { + reader.releaseLock(); + } + return chunks.join(""); +} + +export class Shell extends CallableTool { + readonly name = "Shell"; + readonly description = DESCRIPTION; + readonly schema = ParamsSchema; + + async execute(params: Params, ctx: ToolContext): Promise { + const builder = new ToolResultBuilder(); + + if (!params.command) { + return builder.error("Command cannot be empty."); + } + + if (params.run_in_background) { + // Background mode - stub for now + return builder.error( + "Background tasks are not yet implemented in this version.", + ); + } + + // Request approval + const decision = await ctx.approval( + "Shell", + "run command", + `Run command \`${params.command}\``, + ); + if (decision === "reject") { + return ToolError( + "The tool call is rejected by the user. Stop what you are doing and wait for the user to tell you how to proceed.", + ); + } + + try { + const shellPath = process.env.SHELL || "/bin/bash"; + + // Redirect stderr to stdout so they're interleaved in order + // Use shell syntax: command 2>&1 + const wrappedCommand = `${params.command} 2>&1`; + + const proc = Bun.spawn([shellPath, "-c", wrappedCommand], { + stdout: "pipe", + stderr: "pipe", // stderr still piped for safety (but most goes to stdout via 2>&1) + stdin: "pipe", + cwd: ctx.workingDir, + env: getNoninteractiveEnv(), + }); + + // Close stdin immediately so interactive prompts get EOF + try { + proc.stdin.end(); + } catch { + // Bun may not support .end() on all platforms + } + + let timedOut = false; + let timeoutId: ReturnType | null = null; + + try { + const timeoutPromise = new Promise((_, reject) => { + timeoutId = setTimeout( + () => reject(new Error("timeout")), + params.timeout * 1000, + ); + }); + + // Stream stdout to builder for real-time interleaved output + const readPromise = (async () => { + await readStreamToBuilder(proc.stdout as ReadableStream, builder); + // Also drain stderr (in case anything bypassed 2>&1) + const stderrBytes = await new Response(proc.stderr).arrayBuffer(); + const stderrStr = new TextDecoder("utf-8", { fatal: false }).decode(stderrBytes); + if (stderrStr) builder.write(stderrStr); + })(); + + await Promise.race([readPromise, timeoutPromise]); + if (timeoutId !== null) clearTimeout(timeoutId); + + await proc.exited; + } catch (e) { + if (timeoutId !== null) clearTimeout(timeoutId); + if (e instanceof Error && e.message === "timeout") { + proc.kill(); + timedOut = true; + } else { + throw e; + } + } + + if (timedOut) { + return builder.error( + `Command killed by timeout (${params.timeout}s)`, + ); + } + + const exitCode = proc.exitCode; + if (exitCode === 0) { + return builder.ok("Command executed successfully."); + } + return builder.error( + `Command failed with exit code: ${exitCode}.`, + ); + } catch (e) { + return builder.error(`Failed to execute command. Error: ${e}`); + } + } +} diff --git a/src/kimi_cli/tools/test.py b/src/kimi_cli/tools/test.py deleted file mode 100644 index b488f5083..000000000 --- a/src/kimi_cli/tools/test.py +++ /dev/null @@ -1,55 +0,0 @@ -import asyncio -from typing import override - -from kosong.tooling import CallableTool2, ToolOk, ToolReturnValue -from pydantic import BaseModel - - -class PlusParams(BaseModel): - a: float - b: float - - -class Plus(CallableTool2[PlusParams]): - name: str = "plus" - description: str = "Add two numbers" - params: type[PlusParams] = PlusParams - - @override - async def __call__(self, params: PlusParams) -> ToolReturnValue: - return ToolOk(output=str(params.a + params.b)) - - -class CompareParams(BaseModel): - a: float - b: float - - -class Compare(CallableTool2[CompareParams]): - name: str = "compare" - description: str = "Compare two numbers" - params: type[CompareParams] = CompareParams - - @override - async def __call__(self, params: CompareParams) -> ToolReturnValue: - if params.a > params.b: - return ToolOk(output="greater") - elif params.a < params.b: - return ToolOk(output="less") - else: - return ToolOk(output="equal") - - -class PanicParams(BaseModel): - message: str - - -class Panic(CallableTool2[PanicParams]): - name: str = "panic" - description: str = "Raise an exception to cause the tool call to fail." - params: type[PanicParams] = PanicParams - - @override - async def __call__(self, params: PanicParams) -> ToolReturnValue: - await asyncio.sleep(2) - raise Exception(f"panicked with a message with {len(params.message)} characters") diff --git a/src/kimi_cli/tools/think/__init__.py b/src/kimi_cli/tools/think/__init__.py deleted file mode 100644 index 9c8b16d37..000000000 --- a/src/kimi_cli/tools/think/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -from pathlib import Path -from typing import override - -from kosong.tooling import CallableTool2, ToolOk, ToolReturnValue -from pydantic import BaseModel, Field - -from kimi_cli.tools.utils import load_desc - - -class Params(BaseModel): - thought: str = Field(description=("A thought to think about.")) - - -class Think(CallableTool2[Params]): - name: str = "Think" - description: str = load_desc(Path(__file__).parent / "think.md", {}) - params: type[Params] = Params - - @override - async def __call__(self, params: Params) -> ToolReturnValue: - return ToolOk(output="", message="Thought logged") diff --git a/src/kimi_cli/tools/think/think.md b/src/kimi_cli/tools/think/think.md deleted file mode 100644 index f3378c36b..000000000 --- a/src/kimi_cli/tools/think/think.md +++ /dev/null @@ -1 +0,0 @@ -Use the tool to think about something. It will not obtain new information or change the database, but just append the thought to the log. Use it when complex reasoning or some cache memory is needed. diff --git a/src/kimi_cli/tools/think/think.ts b/src/kimi_cli/tools/think/think.ts new file mode 100644 index 000000000..7b741db9f --- /dev/null +++ b/src/kimi_cli/tools/think/think.ts @@ -0,0 +1,28 @@ +/** + * Think tool — give the LLM thinking space. + * Corresponds to Python tools/think/__init__.py + */ + +import { z } from "zod/v4"; +import { CallableTool } from "../base.ts"; +import type { ToolContext, ToolResult } from "../types.ts"; +import { ToolOk } from "../types.ts"; + +const DESCRIPTION = + "Use the tool to think about something. It will not obtain new information or change the database, but just append the thought to the log. Use it when complex reasoning or some cache memory is needed."; + +const ParamsSchema = z.object({ + thought: z.string().describe("A thought to think about."), +}); + +type Params = z.infer; + +export class Think extends CallableTool { + readonly name = "Think"; + readonly description = DESCRIPTION; + readonly schema = ParamsSchema; + + async execute(_params: Params, _ctx: ToolContext): Promise { + return ToolOk("", "Thought logged"); + } +} diff --git a/src/kimi_cli/tools/todo/__init__.py b/src/kimi_cli/tools/todo/__init__.py deleted file mode 100644 index b5b8b61f0..000000000 --- a/src/kimi_cli/tools/todo/__init__.py +++ /dev/null @@ -1,33 +0,0 @@ -from pathlib import Path -from typing import Literal, override - -from kosong.tooling import CallableTool2, ToolReturnValue -from pydantic import BaseModel, Field - -from kimi_cli.tools.display import TodoDisplayBlock, TodoDisplayItem -from kimi_cli.tools.utils import load_desc - - -class Todo(BaseModel): - title: str = Field(description="The title of the todo", min_length=1) - status: Literal["pending", "in_progress", "done"] = Field(description="The status of the todo") - - -class Params(BaseModel): - todos: list[Todo] = Field(description="The updated todo list") - - -class SetTodoList(CallableTool2[Params]): - name: str = "SetTodoList" - description: str = load_desc(Path(__file__).parent / "set_todo_list.md") - params: type[Params] = Params - - @override - async def __call__(self, params: Params) -> ToolReturnValue: - items = [TodoDisplayItem(title=todo.title, status=todo.status) for todo in params.todos] - return ToolReturnValue( - is_error=False, - output="", - message="Todo list updated", - display=[TodoDisplayBlock(items=items)], - ) diff --git a/src/kimi_cli/tools/todo/set_todo_list.md b/src/kimi_cli/tools/todo/set_todo_list.md deleted file mode 100644 index d28889940..000000000 --- a/src/kimi_cli/tools/todo/set_todo_list.md +++ /dev/null @@ -1,15 +0,0 @@ -Update the whole todo list. - -Todo list is a simple yet powerful tool to help you get things done. You typically want to use this tool when the given task involves multiple subtasks/milestones, or, multiple tasks are given in a single request. This tool can help you to break down the task and track the progress. - -This is the only todo list tool available to you. That said, each time you want to operate on the todo list, you need to update the whole. Make sure to maintain the todo items and their statuses properly. - -Once you finished a subtask/milestone, remember to update the todo list to reflect the progress. Also, you can give yourself a self-encouragement to keep you motivated. - -Abusing this tool to track too small steps will just waste your time and make your context messy. For example, here are some cases you should not use this tool: - -- When the user just simply ask you a question. E.g. "What language and framework is used in the project?", "What is the best practice for x?" -- When it only takes a few steps/tool calls to complete the task. E.g. "Fix the unit test function 'test_xxx'", "Refactor the function 'xxx' to make it more solid." -- When the user prompt is very specific and the only thing you need to do is brainlessly following the instructions. E.g. "Replace xxx to yyy in the file zzz", "Create a file xxx with content yyy." - -However, do not get stuck in a rut. Be flexible. Sometimes, you may try to use todo list at first, then realize the task is too simple and you can simply stop using it; or, sometimes, you may realize the task is complex after a few steps and then you can start using todo list to break it down. diff --git a/src/kimi_cli/tools/todo/todo.ts b/src/kimi_cli/tools/todo/todo.ts new file mode 100644 index 000000000..89e37c7a1 --- /dev/null +++ b/src/kimi_cli/tools/todo/todo.ts @@ -0,0 +1,50 @@ +/** + * SetTodoList tool — manage a todo list. + * Corresponds to Python tools/todo/__init__.py + */ + +import { z } from "zod/v4"; +import { CallableTool } from "../base.ts"; +import type { ToolContext, ToolResult } from "../types.ts"; + +const DESCRIPTION = `Update the whole todo list. + +Todo list is a simple yet powerful tool to help you get things done. Use this tool when the given task involves multiple subtasks/milestones. + +Each time you want to operate on the todo list, you need to update the whole. Make sure to maintain the todo items and their statuses properly.`; + +const TodoSchema = z.object({ + title: z.string().min(1).describe("The title of the todo"), + status: z + .enum(["pending", "in_progress", "done"]) + .describe("The status of the todo"), +}); + +const ParamsSchema = z.object({ + todos: z.array(TodoSchema).describe("The updated todo list"), +}); + +type Params = z.infer; + +export class SetTodoList extends CallableTool { + readonly name = "SetTodoList"; + readonly description = DESCRIPTION; + readonly schema = ParamsSchema; + + async execute(params: Params, _ctx: ToolContext): Promise { + return { + isError: false, + output: "", + message: "Todo list updated", + display: [ + { + type: "todo", + items: params.todos.map((t) => ({ + title: t.title, + status: t.status, + })), + }, + ], + }; + } +} diff --git a/src/kimi_cli/tools/types.ts b/src/kimi_cli/tools/types.ts new file mode 100644 index 000000000..a27652df9 --- /dev/null +++ b/src/kimi_cli/tools/types.ts @@ -0,0 +1,346 @@ +/** + * Tool-related types — corresponds to Python tools/utils.py and kosong.tooling types. + */ + +import type { ApprovalDecision, JsonValue } from "../types.ts"; + +// ── ToolContext ────────────────────────────────────────── + +/** Context injected into every tool execution. */ +export interface ToolContext { + /** Current working directory. */ + workingDir: string; + /** AbortSignal for cooperative cancellation. */ + signal?: AbortSignal; + /** Request user approval; returns the decision. */ + approval: ( + toolName: string, + action: string, + summary: string, + ) => Promise; + /** Emit a wire event (for UI communication). */ + wireEmit?: (event: unknown) => void; + /** Toggle plan mode on/off. */ + setPlanMode?: (on: boolean) => void; + /** Get current plan mode status. */ + getPlanMode?: () => boolean; + /** Get plan file path. */ + getPlanFilePath?: () => string | undefined; + /** Toggle plan mode (manual toggle from slash command). */ + togglePlanMode?: () => void; + /** Ask the user a question and get the answer (for AskUserQuestion tool). */ + askUser?: (question: string, options?: string[]) => Promise; + /** Access to service config (for SearchWeb, FetchURL). */ + serviceConfig?: { + moonshotSearch?: { baseUrl: string; apiKey: string; customHeaders?: Record }; + moonshotFetch?: { baseUrl: string; apiKey: string; customHeaders?: Record }; + }; +} + +// ── ToolResult ────────────────────────────────────────── + +export interface ToolResult { + isError: boolean; + output: string; + message?: string; + display?: unknown[]; + extras?: Record; +} + +/** Create a successful ToolResult. */ +export function ToolOk( + output: string, + message?: string, + display?: unknown[], + extras?: Record, +): ToolResult { + return { isError: false, output, message, display, extras }; +} + +/** Create an error ToolResult. */ +export function ToolError( + message: string, + output = "", + display?: unknown[], +): ToolResult { + return { isError: true, output, message, display }; +} + +// ── ToolDefinition ────────────────────────────────────── + +export interface ToolDefinition { + name: string; + description: string; + parameters: Record; +} + +// ── ToolResultBuilder ─────────────────────────────────── + +const DEFAULT_MAX_CHARS = 50_000; +const DEFAULT_MAX_LINE_LENGTH = 2000; + +function truncateLine(line: string, maxLength: number, marker = "..."): string { + if (line.length <= maxLength) return line; + + // Find trailing line breaks + const m = line.match(/[\r\n]+$/); + const linebreak = m ? m[0] : ""; + const end = marker + linebreak; + const effectiveMax = Math.max(maxLength, end.length); + return line.slice(0, effectiveMax - end.length) + end; +} + +export class ToolResultBuilder { + private maxChars: number; + private maxLineLength: number | null; + private marker = "[...truncated]"; + private buffer: string[] = []; + private _nChars = 0; + private _nLines = 0; + private _truncationHappened = false; + private _display: unknown[] = []; + private _extras: Record | null = null; + + constructor( + maxChars = DEFAULT_MAX_CHARS, + maxLineLength: number | null = DEFAULT_MAX_LINE_LENGTH, + ) { + this.maxChars = maxChars; + this.maxLineLength = maxLineLength; + } + + get isFull(): boolean { + return this._nChars >= this.maxChars; + } + + get nChars(): number { + return this._nChars; + } + + get nLines(): number { + return this._nLines; + } + + /** Write text to the output buffer. Returns number of characters written. */ + write(text: string): number { + if (this.isFull) return 0; + + // Split keeping line endings + const lines = text.split(/(?<=\n)/); + if (lines.length === 0) return 0; + + let charsWritten = 0; + + for (const originalLine of lines) { + if (this.isFull) break; + if (!originalLine) continue; + + const remainingChars = this.maxChars - this._nChars; + const limit = + this.maxLineLength !== null + ? Math.min(remainingChars, this.maxLineLength) + : remainingChars; + const line = truncateLine(originalLine, limit, this.marker); + if (line !== originalLine) { + this._truncationHappened = true; + } + + this.buffer.push(line); + charsWritten += line.length; + this._nChars += line.length; + if (line.endsWith("\n")) { + this._nLines += 1; + } + } + + return charsWritten; + } + + display(...blocks: unknown[]): void { + this._display.push(...blocks); + } + + extras(extra: Record): void { + if (this._extras === null) { + this._extras = {}; + } + Object.assign(this._extras, extra); + } + + ok(message = ""): ToolResult { + const output = this.buffer.join(""); + + let finalMessage = message; + if (finalMessage && !finalMessage.endsWith(".")) { + finalMessage += "."; + } + const truncationMsg = "Output is truncated to fit in the message."; + if (this._truncationHappened) { + finalMessage = finalMessage + ? `${finalMessage} ${truncationMsg}` + : truncationMsg; + } + return { + isError: false, + output, + message: finalMessage || undefined, + display: this._display.length > 0 ? this._display : undefined, + extras: this._extras ?? undefined, + }; + } + + error(message: string): ToolResult { + const output = this.buffer.join(""); + + let finalMessage = message; + if (this._truncationHappened) { + const truncationMsg = "Output is truncated to fit in the message."; + finalMessage = finalMessage + ? `${finalMessage} ${truncationMsg}` + : truncationMsg; + } + + return { + isError: true, + output, + message: finalMessage, + display: this._display.length > 0 ? this._display : undefined, + extras: this._extras ?? undefined, + }; + } +} + +// ── ToolRejectedError ─────────────────────────────── + +/** + * Thrown / returned when a tool call is rejected by the user. + * Corresponds to Python utils.ToolRejectedError. + */ +export class ToolRejectedError extends Error { + readonly isError = true as const; + readonly hasFeedback: boolean; + readonly brief: string; + + constructor(opts?: { + message?: string; + brief?: string; + hasFeedback?: boolean; + }) { + super( + opts?.message ?? + "The tool call is rejected by the user. " + + "Stop what you are doing and wait for the user to tell you how to proceed.", + ); + this.name = "ToolRejectedError"; + this.brief = opts?.brief ?? "Rejected by user"; + this.hasFeedback = opts?.hasFeedback ?? false; + } + + /** Convert to a ToolResult for returning from a tool handler. */ + toToolResult(): ToolResult { + return { + isError: true, + output: "", + message: this.message, + }; + } +} + +// ── SkipThisTool ──────────────────────────────────── + +/** + * Thrown when a tool decides to skip itself from the loading process. + * Corresponds to Python __init__.SkipThisTool. + */ +export class SkipThisTool extends Error { + constructor(reason?: string) { + super(reason ?? "Tool skipped"); + this.name = "SkipThisTool"; + } +} + +// ── extractKeyArgument ────────────────────────────── + +/** + * Extract a key argument string from tool call JSON arguments. + * Used for logging / display summaries. + * Corresponds to Python __init__.extract_key_argument. + */ +export function extractKeyArgument( + jsonContent: string, + toolName: string, +): string | null { + let args: Record; + try { + args = JSON.parse(jsonContent); + } catch { + return null; + } + if (!args || typeof args !== "object") return null; + + let keyArg = ""; + + switch (toolName) { + case "Agent": + if (!args.description) return null; + keyArg = String(args.description); + break; + case "SendDMail": + case "SetTodoList": + return null; + case "Think": + if (!args.thought) return null; + keyArg = String(args.thought); + break; + case "Shell": + if (!args.command) return null; + keyArg = String(args.command); + break; + case "TaskOutput": + case "TaskStop": + if (!args.task_id) return null; + keyArg = String(args.task_id); + break; + case "TaskList": + keyArg = args.active_only !== false ? "active" : "all"; + break; + case "ReadFile": + case "ReadMediaFile": + case "WriteFile": + case "StrReplaceFile": + if (!args.path) return null; + keyArg = _normalizePath(String(args.path)); + break; + case "Glob": + case "Grep": + if (!args.pattern) return null; + keyArg = String(args.pattern); + break; + case "SearchWeb": + if (!args.query) return null; + keyArg = String(args.query); + break; + case "FetchURL": + if (!args.url) return null; + keyArg = String(args.url); + break; + default: + keyArg = jsonContent; + } + + return _shortenMiddle(keyArg, 50); +} + +function _normalizePath(path: string): string { + const cwd = process.cwd(); + if (path.startsWith(cwd)) { + path = path.slice(cwd.length).replace(/^[/\\]/, ""); + } + return path; +} + +function _shortenMiddle(s: string, width: number): string { + if (s.length <= width) return s; + const half = Math.floor((width - 3) / 2); + return s.slice(0, half) + "..." + s.slice(s.length - half); +} diff --git a/src/kimi_cli/tools/utils.py b/src/kimi_cli/tools/utils.py deleted file mode 100644 index 8427703a2..000000000 --- a/src/kimi_cli/tools/utils.py +++ /dev/null @@ -1,199 +0,0 @@ -import re -from pathlib import Path - -from jinja2 import Environment, Undefined -from kosong.tooling import BriefDisplayBlock, DisplayBlock, ToolError, ToolReturnValue -from kosong.utils.typing import JsonType - - -class _KeepPlaceholderUndefined(Undefined): - def __str__(self) -> str: - if self._undefined_name is None: - return "" - return f"${{{self._undefined_name}}}" - - __repr__ = __str__ - - -def load_desc(path: Path, context: dict[str, object] | None = None) -> str: - """Load a tool description from a file, rendered via Jinja2.""" - description = path.read_text(encoding="utf-8") - env = Environment( - keep_trailing_newline=True, - lstrip_blocks=True, - trim_blocks=True, - variable_start_string="${", - variable_end_string="}", - undefined=_KeepPlaceholderUndefined, - ) - template = env.from_string(description) - return template.render(context or {}) - - -def truncate_line(line: str, max_length: int, marker: str = "...") -> str: - """ - Truncate a line if it exceeds `max_length`, preserving the beginning and the line break. - The output may be longer than `max_length` if it is too short to fit the marker. - """ - if len(line) <= max_length: - return line - - # Find line breaks at the end of the line - m = re.search(r"[\r\n]+$", line) - linebreak = m.group(0) if m else "" - end = marker + linebreak - max_length = max(max_length, len(end)) - return line[: max_length - len(end)] + end - - -# Default output limits -DEFAULT_MAX_CHARS = 50_000 -DEFAULT_MAX_LINE_LENGTH = 2000 - - -class ToolResultBuilder: - """ - Builder for tool results with character and line limits. - """ - - def __init__( - self, - max_chars: int = DEFAULT_MAX_CHARS, - max_line_length: int | None = DEFAULT_MAX_LINE_LENGTH, - ): - self.max_chars = max_chars - self.max_line_length = max_line_length - self._marker = "[...truncated]" - if max_line_length is not None: - assert max_line_length > len(self._marker) - self._buffer: list[str] = [] - self._n_chars = 0 - self._n_lines = 0 - self._truncation_happened = False - self._display: list[DisplayBlock] = [] - self._extras: dict[str, JsonType] | None = None - - @property - def is_full(self) -> bool: - """Check if output buffer is full due to character limit.""" - return self._n_chars >= self.max_chars - - @property - def n_chars(self) -> int: - """Get current character count.""" - return self._n_chars - - @property - def n_lines(self) -> int: - """Get current line count.""" - return self._n_lines - - def write(self, text: str) -> int: - """ - Write text to the output buffer. - - Returns: - int: Number of characters actually written - """ - if self.is_full: - return 0 - - lines = text.splitlines(keepends=True) - if not lines: - return 0 - - chars_written = 0 - - for line in lines: - if self.is_full: - break - - original_line = line - remaining_chars = self.max_chars - self._n_chars - limit = ( - min(remaining_chars, self.max_line_length) - if self.max_line_length is not None - else remaining_chars - ) - line = truncate_line(line, limit, self._marker) - if line != original_line: - self._truncation_happened = True - - self._buffer.append(line) - chars_written += len(line) - self._n_chars += len(line) - if line.endswith("\n"): - self._n_lines += 1 - - return chars_written - - def display(self, *blocks: DisplayBlock) -> None: - """Add display blocks to the tool result.""" - self._display.extend(blocks) - - def extras(self, **extras: JsonType) -> None: - """Add extra data to the tool result.""" - if self._extras is None: - self._extras = {} - self._extras.update(extras) - - def ok(self, message: str = "", *, brief: str = "") -> ToolReturnValue: - """Create a ToolReturnValue with is_error=False and the current output.""" - output = "".join(self._buffer) - - final_message = message - if final_message and not final_message.endswith("."): - final_message += "." - truncation_msg = "Output is truncated to fit in the message." - if self._truncation_happened: - if final_message: - final_message += f" {truncation_msg}" - else: - final_message = truncation_msg - return ToolReturnValue( - is_error=False, - output=output, - message=final_message, - display=([BriefDisplayBlock(text=brief)] if brief else []) + self._display, - extras=self._extras, - ) - - def error(self, message: str, *, brief: str) -> ToolReturnValue: - """Create a ToolReturnValue with is_error=True and the current output.""" - output = "".join(self._buffer) - - final_message = message - if self._truncation_happened: - truncation_msg = "Output is truncated to fit in the message." - if final_message: - final_message += f" {truncation_msg}" - else: - final_message = truncation_msg - - return ToolReturnValue( - is_error=True, - output=output, - message=final_message, - display=([BriefDisplayBlock(text=brief)] if brief else []) + self._display, - extras=self._extras, - ) - - -class ToolRejectedError(ToolError): - has_feedback: bool = False - - def __init__( - self, - message: str | None = None, - brief: str = "Rejected by user", - has_feedback: bool = False, - ): - super().__init__( - message=message - or ( - "The tool call is rejected by the user. " - "Stop what you are doing and wait for the user to tell you how to proceed." - ), - brief=brief, - ) - self.has_feedback = has_feedback diff --git a/src/kimi_cli/tools/web/__init__.py b/src/kimi_cli/tools/web/__init__.py deleted file mode 100644 index 012f0ba83..000000000 --- a/src/kimi_cli/tools/web/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .fetch import FetchURL -from .search import SearchWeb - -__all__ = ("SearchWeb", "FetchURL") diff --git a/src/kimi_cli/tools/web/fetch.md b/src/kimi_cli/tools/web/fetch.md deleted file mode 100644 index 73ebcc803..000000000 --- a/src/kimi_cli/tools/web/fetch.md +++ /dev/null @@ -1 +0,0 @@ -Fetch a web page from a URL and extract main text content from it. diff --git a/src/kimi_cli/tools/web/fetch.py b/src/kimi_cli/tools/web/fetch.py deleted file mode 100644 index dbe3fe5f5..000000000 --- a/src/kimi_cli/tools/web/fetch.py +++ /dev/null @@ -1,173 +0,0 @@ -from pathlib import Path -from typing import override - -import aiohttp -import trafilatura -from kosong.tooling import CallableTool2, ToolReturnValue -from pydantic import BaseModel, Field - -from kimi_cli.config import Config -from kimi_cli.constant import USER_AGENT -from kimi_cli.soul.agent import Runtime -from kimi_cli.soul.toolset import get_current_tool_call_or_none -from kimi_cli.tools.utils import ToolResultBuilder, load_desc -from kimi_cli.utils.aiohttp import new_client_session -from kimi_cli.utils.logging import logger - - -class Params(BaseModel): - url: str = Field(description="The URL to fetch content from.") - - -class FetchURL(CallableTool2[Params]): - name: str = "FetchURL" - description: str = load_desc(Path(__file__).parent / "fetch.md", {}) - params: type[Params] = Params - - def __init__(self, config: Config, runtime: Runtime): - super().__init__() - self._runtime = runtime - self._service_config = config.services.moonshot_fetch - - @override - async def __call__(self, params: Params) -> ToolReturnValue: - if self._service_config: - ret = await self._fetch_with_service(params) - if not ret.is_error: - return ret - logger.warning("Failed to fetch URL via service: {error}", error=ret.message) - # fallback to local fetch if service fetch fails - return await self.fetch_with_http_get(params) - - @staticmethod - async def fetch_with_http_get(params: Params) -> ToolReturnValue: - builder = ToolResultBuilder(max_line_length=None) - try: - # Fetching arbitrary web pages can take a while on large/slow sites. - fetch_timeout = aiohttp.ClientTimeout(total=180, sock_read=60, sock_connect=15) - async with ( - new_client_session(timeout=fetch_timeout) as session, - session.get( - params.url, - headers={ - "User-Agent": ( - "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 " - "(KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" - ), - }, - ) as response, - ): - if response.status >= 400: - return builder.error( - ( - f"Failed to fetch URL. Status: {response.status}. " - f"This may indicate the page is not accessible or the server is down." - ), - brief=f"HTTP {response.status} error", - ) - - resp_text = await response.text() - - content_type = response.headers.get(aiohttp.hdrs.CONTENT_TYPE, "").lower() - if content_type.startswith(("text/plain", "text/markdown")): - builder.write(resp_text) - return builder.ok("The returned content is the full content of the page.") - except TimeoutError: - return builder.error( - "Failed to fetch URL: request timed out. The server may be slow or unreachable.", - brief="Request timed out", - ) - except aiohttp.ClientError as e: - return builder.error( - ( - f"Failed to fetch URL due to network error: {e}. " - "This may indicate the URL is invalid or the server is unreachable." - ), - brief="Network error", - ) - - if not resp_text: - return builder.ok( - "The response body is empty.", - brief="Empty response body", - ) - - extracted_text = trafilatura.extract( - resp_text, - include_comments=True, - include_tables=True, - include_formatting=False, - output_format="txt", - with_metadata=True, - ) - - if not extracted_text: - return builder.error( - ( - "Failed to extract meaningful content from the page. " - "This may indicate the page content is not suitable for text extraction, " - "or the page requires JavaScript to render its content." - ), - brief="No content extracted", - ) - - builder.write(extracted_text) - return builder.ok("The returned content is the main text content extracted from the page.") - - async def _fetch_with_service(self, params: Params) -> ToolReturnValue: - assert self._service_config is not None - - tool_call = get_current_tool_call_or_none() - assert tool_call is not None, "Tool call is expected to be set" - - builder = ToolResultBuilder(max_line_length=None) - api_key = self._runtime.oauth.resolve_api_key( - self._service_config.api_key, self._service_config.oauth - ) - if not api_key: - return builder.error( - "Fetch service is not configured. You may want to try other methods to fetch.", - brief="Fetch service not configured", - ) - headers = { - "User-Agent": USER_AGENT, - "Authorization": f"Bearer {api_key}", - "Accept": "text/markdown", - "X-Msh-Tool-Call-Id": tool_call.id, - **self._runtime.oauth.common_headers(), - **(self._service_config.custom_headers or {}), - } - - try: - async with ( - new_client_session() as session, - session.post( - self._service_config.base_url, - headers=headers, - json={"url": params.url}, - ) as response, - ): - if response.status != 200: - return builder.error( - f"Failed to fetch URL via service. Status: {response.status}.", - brief="Failed to fetch URL via fetch service", - ) - - content = await response.text() - builder.write(content) - return builder.ok( - "The returned content is the main content extracted from the page." - ) - except TimeoutError: - return builder.error( - "Failed to fetch URL via service: request timed out.", - brief="Service request timed out", - ) - except aiohttp.ClientError as e: - return builder.error( - ( - f"Failed to fetch URL via service due to network error: {e}. " - "This may indicate the service is unreachable." - ), - brief="Network error when calling fetch service", - ) diff --git a/src/kimi_cli/tools/web/fetch.ts b/src/kimi_cli/tools/web/fetch.ts new file mode 100644 index 000000000..8136e7fa5 --- /dev/null +++ b/src/kimi_cli/tools/web/fetch.ts @@ -0,0 +1,252 @@ +/** + * FetchURL tool — fetch a web page and extract main text content. + * Corresponds to Python tools/web/fetch.py + */ + +import { z } from "zod/v4"; +import { CallableTool } from "../base.ts"; +import type { ToolContext, ToolResult } from "../types.ts"; +import { ToolResultBuilder } from "../types.ts"; + +const DESCRIPTION = + "Fetch a web page from a URL and extract main text content from it."; + +const ParamsSchema = z.object({ + url: z.string().describe("The URL to fetch content from."), +}); + +type Params = z.infer; + +export class FetchURL extends CallableTool { + readonly name = "FetchURL"; + readonly description = DESCRIPTION; + readonly schema = ParamsSchema; + + async execute(params: Params, ctx: ToolContext): Promise { + const builder = new ToolResultBuilder(50_000, null); + + try { + // Try service-based fetch first (if configured) + const fetchConfig = ctx.serviceConfig?.moonshotFetch; + if (fetchConfig?.baseUrl && fetchConfig?.apiKey) { + try { + const serviceResult = await fetchViaService(params.url, fetchConfig); + if (serviceResult) { + builder.write(serviceResult); + return builder.ok("Content fetched via service."); + } + } catch { + // Fall through to direct fetch + } + } + + // Direct HTTP fetch + const controller = new AbortController(); + const timeout = setTimeout(() => controller.abort(), 60_000); + + const response = await fetch(params.url, { + headers: { + "User-Agent": + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", + Accept: "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8", + }, + signal: controller.signal, + redirect: "follow", + }); + + clearTimeout(timeout); + + if (response.status >= 400) { + return builder.error( + `Failed to fetch URL. Status: ${response.status}. This may indicate the page is not accessible or the server is down.`, + ); + } + + const respText = await response.text(); + const contentType = response.headers.get("content-type") || ""; + + if ( + contentType.startsWith("text/plain") || + contentType.startsWith("text/markdown") || + contentType.startsWith("application/json") + ) { + builder.write(respText); + return builder.ok( + "The returned content is the full content of the page.", + ); + } + + if (!respText) { + return builder.ok("The response body is empty."); + } + + // Extract main content from HTML + const extracted = extractContent(respText); + + if (!extracted || extracted.length < 10) { + return builder.error( + "Failed to extract meaningful content from the page. " + + "The page may require JavaScript to render its content.", + ); + } + + builder.write(extracted); + return builder.ok( + "The returned content is the main text content extracted from the page.", + ); + } catch (e) { + if (e instanceof DOMException && e.name === "AbortError") { + return builder.error( + "Failed to fetch URL: request timed out. The server may be slow or unreachable.", + ); + } + return builder.error( + `Failed to fetch URL due to network error: ${e}. The URL may be invalid or the server is unreachable.`, + ); + } + } +} + +/** Try fetching via moonshot fetch service. */ +async function fetchViaService( + url: string, + config: { baseUrl: string; apiKey: string; customHeaders?: Record }, +): Promise { + const controller = new AbortController(); + const timeout = setTimeout(() => controller.abort(), 90_000); + + const response = await fetch(config.baseUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${config.apiKey}`, + ...(config.customHeaders ?? {}), + }, + body: JSON.stringify({ url }), + signal: controller.signal, + }); + + clearTimeout(timeout); + + if (!response.ok) return null; + + const data = (await response.json()) as { content?: string; text?: string }; + return data.content || data.text || null; +} + +// ── HTML Content Extraction ────────────────────────── +// A proper content extraction that handles structure, entities, and main content detection. + +/** Decode all HTML entities (named + numeric). */ +function decodeEntities(text: string): string { + const namedEntities: Record = { + nbsp: " ", amp: "&", lt: "<", gt: ">", quot: '"', apos: "'", + ldquo: "\u201C", rdquo: "\u201D", lsquo: "\u2018", rsquo: "\u2019", + mdash: "\u2014", ndash: "\u2013", hellip: "\u2026", + copy: "\u00A9", reg: "\u00AE", trade: "\u2122", + bull: "\u2022", middot: "\u00B7", laquo: "\u00AB", raquo: "\u00BB", + }; + + return text + .replace(/&#(\d+);/g, (_, n) => String.fromCodePoint(parseInt(n, 10))) + .replace(/&#x([0-9a-fA-F]+);/g, (_, n) => String.fromCodePoint(parseInt(n, 16))) + .replace(/&(\w+);/g, (match, name) => namedEntities[name.toLowerCase()] ?? match); +} + +/** Extract readable text content from HTML. */ +function extractContent(html: string): string { + // Step 1: Remove non-content elements entirely + let text = html + .replace(/]*>[\s\S]*?<\/script>/gi, "") + .replace(/]*>[\s\S]*?<\/style>/gi, "") + .replace(/]*>[\s\S]*?<\/noscript>/gi, "") + .replace(/]*>[\s\S]*?<\/svg>/gi, "") + .replace(/]*>[\s\S]*?<\/nav>/gi, "") // Navigation + .replace(/]*>[\s\S]*?<\/footer>/gi, "") // Footer + .replace(//g, ""); // HTML comments + + // Step 2: Extract title + const titleMatch = text.match(/]*>([\s\S]*?)<\/title>/i); + const title = titleMatch ? decodeEntities(titleMatch[1]!.trim()) : ""; + + // Step 3: Try to find main content area + let mainContent = ""; + const mainPatterns = [ + /]*>([\s\S]*?)<\/main>/i, + /]*>([\s\S]*?)<\/article>/i, + /]*(?:class|id)="[^"]*(?:content|article|post|entry|main)[^"]*"[^>]*>([\s\S]*?)<\/div>/i, + ]; + + for (const pattern of mainPatterns) { + const match = text.match(pattern); + if (match && match[1]!.length > 200) { + mainContent = match[1]!; + break; + } + } + + // Fall back to body or full text + if (!mainContent) { + const bodyMatch = text.match(/]*>([\s\S]*?)<\/body>/i); + mainContent = bodyMatch ? bodyMatch[1]! : text; + } + + // Step 4: Convert headings to markdown-style + mainContent = mainContent + .replace(/]*>([\s\S]*?)<\/h1>/gi, "\n\n# $1\n\n") + .replace(/]*>([\s\S]*?)<\/h2>/gi, "\n\n## $1\n\n") + .replace(/]*>([\s\S]*?)<\/h3>/gi, "\n\n### $1\n\n") + .replace(/]*>([\s\S]*?)<\/h[4-6]>/gi, "\n\n#### $1\n\n"); + + // Step 5: Handle links — extract text with URL + mainContent = mainContent.replace( + /]*href="([^"]*)"[^>]*>([\s\S]*?)<\/a>/gi, + (_, href, text) => { + const linkText = text.replace(/<[^>]+>/g, "").trim(); + if (!linkText) return ""; + // Only show URL for external links + if (href.startsWith("http")) return `[${linkText}](${href})`; + return linkText; + }, + ); + + // Step 6: Handle lists + mainContent = mainContent + .replace(/]*>/gi, "\n- ") + .replace(/<\/li>/gi, ""); + + // Step 7: Handle tables — convert to simple text + mainContent = mainContent + .replace(/]*>/gi, "\n") + .replace(/<\/tr>/gi, "") + .replace(/]*>/gi, "\t") + .replace(/<\/t[hd]>/gi, ""); + + // Step 8: Handle block elements + mainContent = mainContent + .replace(/<\/?(p|div|section|header|blockquote)[^>]*>/gi, "\n\n") + .replace(//gi, "\n") + .replace(//gi, "\n---\n") + .replace(/<\/?(pre|code)[^>]*>/gi, "\n```\n"); + + // Step 9: Remove all remaining HTML tags + mainContent = mainContent.replace(/<[^>]+>/g, ""); + + // Step 10: Decode entities + mainContent = decodeEntities(mainContent); + + // Step 11: Clean up whitespace + mainContent = mainContent + .replace(/[ \t]+/g, " ") // Collapse horizontal whitespace + .replace(/\n[ \t]+/g, "\n") // Remove leading whitespace on lines + .replace(/[ \t]+\n/g, "\n") // Remove trailing whitespace on lines + .replace(/\n{3,}/g, "\n\n") // Max 2 consecutive newlines + .trim(); + + // Prepend title if found + if (title && !mainContent.startsWith(title)) { + mainContent = `# ${title}\n\n${mainContent}`; + } + + return mainContent; +} diff --git a/src/kimi_cli/tools/web/search.md b/src/kimi_cli/tools/web/search.md deleted file mode 100644 index 19e4cec77..000000000 --- a/src/kimi_cli/tools/web/search.md +++ /dev/null @@ -1 +0,0 @@ -WebSearch tool allows you to search on the internet to get latest information, including news, documents, release notes, blog posts, papers, etc. diff --git a/src/kimi_cli/tools/web/search.py b/src/kimi_cli/tools/web/search.py deleted file mode 100644 index 4c00ddf7a..000000000 --- a/src/kimi_cli/tools/web/search.py +++ /dev/null @@ -1,146 +0,0 @@ -from pathlib import Path -from typing import override - -import aiohttp -from kosong.tooling import CallableTool2, ToolReturnValue -from pydantic import BaseModel, Field, ValidationError - -from kimi_cli.config import Config -from kimi_cli.constant import USER_AGENT -from kimi_cli.soul.agent import Runtime -from kimi_cli.soul.toolset import get_current_tool_call_or_none -from kimi_cli.tools import SkipThisTool -from kimi_cli.tools.utils import ToolResultBuilder, load_desc -from kimi_cli.utils.aiohttp import new_client_session - - -class Params(BaseModel): - query: str = Field(description="The query text to search for.") - limit: int = Field( - description=( - "The number of results to return. " - "Typically you do not need to set this value. " - "When the results do not contain what you need, " - "you probably want to give a more concrete query." - ), - default=5, - ge=1, - le=20, - ) - include_content: bool = Field( - description=( - "Whether to include the content of the web pages in the results. " - "It can consume a large amount of tokens when this is set to True. " - "You should avoid enabling this when `limit` is set to a large value." - ), - default=False, - ) - - -class SearchWeb(CallableTool2[Params]): - name: str = "SearchWeb" - description: str = load_desc(Path(__file__).parent / "search.md", {}) - params: type[Params] = Params - - def __init__(self, config: Config, runtime: Runtime): - super().__init__() - if config.services.moonshot_search is None: - raise SkipThisTool() - self._runtime = runtime - self._base_url = config.services.moonshot_search.base_url - self._api_key = config.services.moonshot_search.api_key - self._oauth_ref = config.services.moonshot_search.oauth - self._custom_headers = config.services.moonshot_search.custom_headers or {} - - @override - async def __call__(self, params: Params) -> ToolReturnValue: - builder = ToolResultBuilder(max_line_length=None) - - api_key = self._runtime.oauth.resolve_api_key(self._api_key, self._oauth_ref) - if not self._base_url or not api_key: - return builder.error( - "Search service is not configured. You may want to try other methods to search.", - brief="Search service not configured", - ) - - tool_call = get_current_tool_call_or_none() - assert tool_call is not None, "Tool call is expected to be set" - - try: - # Server-side timeout is 30s, but page crawling can take longer. - search_timeout = aiohttp.ClientTimeout(total=180, sock_read=90, sock_connect=15) - async with ( - new_client_session(timeout=search_timeout) as session, - session.post( - self._base_url, - headers={ - "User-Agent": USER_AGENT, - "Authorization": f"Bearer {api_key}", - "X-Msh-Tool-Call-Id": tool_call.id, - **self._runtime.oauth.common_headers(), - **self._custom_headers, - }, - json={ - "text_query": params.query, - "limit": params.limit, - "enable_page_crawling": params.include_content, - "timeout_seconds": 30, - }, - ) as response, - ): - if response.status != 200: - return builder.error( - ( - f"Failed to search. Status: {response.status}. " - "This may indicate that the search service is currently unavailable." - ), - brief="Failed to search", - ) - - try: - results = Response(**await response.json()).search_results - except ValidationError as e: - return builder.error( - ( - f"Failed to parse search results. Error: {e}. " - "This may indicate that the search service is currently unavailable." - ), - brief="Failed to parse search results", - ) - except TimeoutError: - return builder.error( - "Search request timed out. The search service may be slow or unavailable.", - brief="Search request timed out", - ) - except aiohttp.ClientError as e: - return builder.error( - f"Search request failed: {e}. The search service may be unavailable.", - brief="Search request failed", - ) - - for i, result in enumerate(results): - if i > 0: - builder.write("---\n\n") - builder.write( - f"Title: {result.title}\nDate: {result.date}\n" - f"URL: {result.url}\nSummary: {result.snippet}\n\n" - ) - if result.content: - builder.write(f"{result.content}\n\n") - - return builder.ok() - - -class SearchResult(BaseModel): - site_name: str - title: str - url: str - snippet: str - content: str = "" - date: str = "" - icon: str = "" - mime: str = "" - - -class Response(BaseModel): - search_results: list[SearchResult] diff --git a/src/kimi_cli/tools/web/search.ts b/src/kimi_cli/tools/web/search.ts new file mode 100644 index 000000000..e1da67dae --- /dev/null +++ b/src/kimi_cli/tools/web/search.ts @@ -0,0 +1,119 @@ +/** + * SearchWeb tool — web search via moonshot search service. + * Corresponds to Python tools/web/search.py + */ + +import { z } from "zod/v4"; +import { CallableTool } from "../base.ts"; +import type { ToolContext, ToolResult } from "../types.ts"; +import { ToolError, ToolResultBuilder } from "../types.ts"; + +const DESCRIPTION = + "WebSearch tool allows you to search on the internet to get latest information, including news, documents, release notes, blog posts, papers, etc."; + +const ParamsSchema = z.object({ + query: z.string().describe("The query text to search for."), + limit: z + .number() + .int() + .min(1) + .max(20) + .default(5) + .describe("The number of results to return."), + include_content: z + .boolean() + .default(false) + .describe( + "Whether to include the content of the web pages in the results. Can consume many tokens.", + ), +}); + +type Params = z.infer; + +interface SearchResult { + site_name: string; + title: string; + url: string; + snippet: string; + content?: string; + date?: string; +} + +export class SearchWeb extends CallableTool { + readonly name = "SearchWeb"; + readonly description = DESCRIPTION; + readonly schema = ParamsSchema; + + async execute(params: Params, ctx: ToolContext): Promise { + const builder = new ToolResultBuilder(50_000, null); + + const searchConfig = ctx.serviceConfig?.moonshotSearch; + if (!searchConfig?.baseUrl || !searchConfig?.apiKey) { + return builder.error( + "Search service is not configured. You may want to try other methods to search.", + ); + } + + try { + const controller = new AbortController(); + const timeoutId = setTimeout(() => controller.abort(), 180_000); // 3 min total timeout + + const response = await fetch(searchConfig.baseUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${searchConfig.apiKey}`, + ...(searchConfig.customHeaders ?? {}), + }, + body: JSON.stringify({ + text_query: params.query, + limit: params.limit, + enable_page_crawling: params.include_content, + timeout_seconds: 30, + }), + signal: controller.signal, + }); + + clearTimeout(timeoutId); + + if (!response.ok) { + return builder.error( + `Failed to search. Status: ${response.status}. ` + + "This may indicate that the search service is currently unavailable.", + ); + } + + const data = (await response.json()) as { search_results?: SearchResult[] }; + const results = data.search_results ?? []; + + if (results.length === 0) { + return builder.ok("No search results found."); + } + + for (let i = 0; i < results.length; i++) { + const result = results[i]!; + if (i > 0) builder.write("---\n\n"); + builder.write( + `Title: ${result.title}\n` + + `Date: ${result.date ?? ""}\n` + + `URL: ${result.url}\n` + + `Summary: ${result.snippet}\n\n`, + ); + if (result.content) { + builder.write(`${result.content}\n\n`); + } + } + + return builder.ok(`Found ${results.length} search results.`); + } catch (err) { + if (err instanceof Error && err.name === "AbortError") { + return builder.error( + "Search request timed out. The search service may be slow or unavailable.", + ); + } + return builder.error( + `Search request failed: ${err instanceof Error ? err.message : err}. The search service may be unavailable.`, + ); + } + } +} diff --git a/src/kimi_cli/types.ts b/src/kimi_cli/types.ts new file mode 100644 index 000000000..8fa01ab70 --- /dev/null +++ b/src/kimi_cli/types.ts @@ -0,0 +1,133 @@ +/** + * Shared types used across the codebase + * Corresponds to common types from Python's Pydantic models + */ + +import { z } from "zod/v4"; + +// ── Content Types (LLM message content) ────────────────── + +export const TextPart = z.object({ + type: z.literal("text"), + text: z.string(), +}); + +export const ImagePart = z.object({ + type: z.literal("image"), + source: z.object({ + type: z.enum(["base64", "url"]), + mediaType: z.string().optional(), + data: z.string(), + }), +}); + +export const ToolUsePart = z.object({ + type: z.literal("tool_use"), + id: z.string(), + name: z.string(), + input: z.record(z.string(), z.unknown()), +}); + +export const ToolResultPart = z.object({ + type: z.literal("tool_result"), + toolUseId: z.string(), + content: z.string(), + isError: z.boolean().optional(), +}); + +export const ContentPart = z.union([TextPart, ImagePart, ToolUsePart, ToolResultPart]); +export type ContentPart = z.infer; + +// ── Message Types ──────────────────────────────────────── + +export const Message = z.object({ + role: z.enum(["user", "assistant", "system", "tool"]), + content: z.union([z.string(), z.array(ContentPart)]), +}); +export type Message = z.infer; + +// ── Usage / Token Tracking ────────────────────────────── + +export const TokenUsage = z.object({ + inputTokens: z.number(), + outputTokens: z.number(), + cacheReadTokens: z.number().optional(), + cacheWriteTokens: z.number().optional(), +}); +export type TokenUsage = z.infer; + +// ── Model Capabilities ────────────────────────────────── + +export const ModelCapability = z.enum([ + "image_in", + "video_in", + "thinking", + "always_thinking", +]); +export type ModelCapability = z.infer; + +// ── Tool Types ────────────────────────────────────────── + +export const ToolCall = z.object({ + id: z.string(), + name: z.string(), + arguments: z.string(), // JSON string +}); +export type ToolCall = z.infer; + +export const ToolReturnValue = z.object({ + isError: z.boolean().default(false), + output: z.string(), + message: z.string().optional(), + display: z.array(z.unknown()).optional(), + extras: z.record(z.string(), z.unknown()).optional(), +}); +export type ToolReturnValue = z.infer; + +// ── Approval ──────────────────────────────────────────── + +export type ApprovalDecision = "approve" | "approve_for_session" | "reject"; + +// ── Status ────────────────────────────────────────────── + +export interface StatusSnapshot { + contextUsage: number | null; + contextTokens: number | null; + maxContextTokens: number | null; + tokenUsage: TokenUsage | null; + planMode: boolean; + mcpStatus: Record | null; +} + +// ── Slash Commands ────────────────────────────────────── + +export interface PanelChoiceItem { + label: string; + value: string; + description?: string; + current?: boolean; +} + +export type CommandPanelConfig = + | { type: "choice"; title: string; items: PanelChoiceItem[]; onSelect: (value: string) => CommandPanelConfig | Promise | void } + | { type: "content"; title: string; content: string } + | { type: "input"; title: string; placeholder?: string; password?: boolean; onSubmit: (value: string) => CommandPanelConfig | Promise | void }; + +export interface SlashCommand { + name: string; + description: string; + aliases?: string[]; + handler: (args: string) => Promise; + /** If defined, selecting from menu renders a secondary panel instead of executing handler */ + panel?: () => CommandPanelConfig | null; +} + +// ── JSON utility type ─────────────────────────────────── + +export type JsonValue = + | string + | number + | boolean + | null + | JsonValue[] + | { [key: string]: JsonValue }; diff --git a/src/kimi_cli/ui/__init__.py b/src/kimi_cli/ui/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/kimi_cli/ui/acp/__init__.py b/src/kimi_cli/ui/acp/__init__.py deleted file mode 100644 index b71668bdc..000000000 --- a/src/kimi_cli/ui/acp/__init__.py +++ /dev/null @@ -1,99 +0,0 @@ -from __future__ import annotations - -from typing import Any, NoReturn - -import acp - -from kimi_cli.acp.types import ACPContentBlock, MCPServer -from kimi_cli.soul import Soul -from kimi_cli.utils.logging import logger - -_DEPRECATED_MESSAGE = ( - "`kimi --acp` is deprecated. " - "Update your ACP client settings to use `kimi acp` without any flags or options." -) - - -class ACPServerSingleSession: - def __init__(self, soul: Soul): - self.soul = soul - - def on_connect(self, conn: acp.Client) -> None: - logger.info("ACP client connected") - - def _raise(self) -> NoReturn: - logger.error(_DEPRECATED_MESSAGE) - raise acp.RequestError.invalid_params({"error": _DEPRECATED_MESSAGE}) - - async def initialize( - self, - protocol_version: int, - client_capabilities: acp.schema.ClientCapabilities | None = None, - client_info: acp.schema.Implementation | None = None, - **kwargs: Any, - ) -> acp.InitializeResponse: - self._raise() - - async def new_session( - self, cwd: str, mcp_servers: list[MCPServer] | None = None, **kwargs: Any - ) -> acp.NewSessionResponse: - self._raise() - - async def load_session( - self, cwd: str, session_id: str, mcp_servers: list[MCPServer] | None = None, **kwargs: Any - ) -> None: - self._raise() - - async def resume_session( - self, cwd: str, session_id: str, mcp_servers: list[MCPServer] | None = None, **kwargs: Any - ) -> acp.schema.ResumeSessionResponse: - self._raise() - - async def fork_session( - self, cwd: str, session_id: str, mcp_servers: list[MCPServer] | None = None, **kwargs: Any - ) -> acp.schema.ForkSessionResponse: - self._raise() - - async def list_sessions( - self, cursor: str | None = None, cwd: str | None = None, **kwargs: Any - ) -> acp.schema.ListSessionsResponse: - self._raise() - - async def set_session_mode( - self, mode_id: str, session_id: str, **kwargs: Any - ) -> acp.SetSessionModeResponse | None: - self._raise() - - async def set_session_model( - self, model_id: str, session_id: str, **kwargs: Any - ) -> acp.SetSessionModelResponse | None: - self._raise() - - async def authenticate(self, method_id: str, **kwargs: Any) -> acp.AuthenticateResponse | None: - self._raise() - - async def prompt( - self, prompt: list[ACPContentBlock], session_id: str, **kwargs: Any - ) -> acp.PromptResponse: - self._raise() - - async def cancel(self, session_id: str, **kwargs: Any) -> None: - self._raise() - - async def ext_method(self, method: str, params: dict[str, Any]) -> dict[str, Any]: - self._raise() - - async def ext_notification(self, method: str, params: dict[str, Any]) -> None: - self._raise() - - -class ACP: - """ACP server using the official acp library.""" - - def __init__(self, soul: Soul): - self.soul = soul - - async def run(self): - """Run the ACP server.""" - logger.info("Starting ACP server (single session) on stdio") - await acp.run_agent(ACPServerSingleSession(self.soul)) diff --git a/src/kimi_cli/ui/components/ApprovalPrompt.tsx b/src/kimi_cli/ui/components/ApprovalPrompt.tsx new file mode 100644 index 000000000..eac3d26fb --- /dev/null +++ b/src/kimi_cli/ui/components/ApprovalPrompt.tsx @@ -0,0 +1,105 @@ +/** + * ApprovalPrompt component — approval request UI. + * Corresponds to Python's approval_panel.py. + * + * Shows: action description, display blocks, and approval choices. + * [y] Allow / [n] Deny / [a] Always allow + */ + +import React, { useCallback } from "react"; +import { Box, Text, useInput } from "ink"; +import { getMessageColors } from "../theme"; +import type { ApprovalRequest, ApprovalResponseKind } from "../../wire/types"; + +interface ApprovalPromptProps { + request: ApprovalRequest; + onRespond: (decision: ApprovalResponseKind, feedback?: string) => void; +} + +export function ApprovalPrompt({ request, onRespond }: ApprovalPromptProps) { + const colors = getMessageColors(); + + useInput((input, key) => { + switch (input.toLowerCase()) { + case "y": + onRespond("approve"); + break; + case "n": + onRespond("reject"); + break; + case "a": + onRespond("approve_for_session"); + break; + } + }); + + return ( + + + ⚠ Approval Required + + + {/* Source info */} + {request.source_description && ( + + From: {request.source_description} + + )} + + {/* Action */} + + + {request.action} + + + + {/* Description */} + {request.description} + + {/* Display blocks preview */} + {request.display.length > 0 && ( + + {request.display.slice(0, 3).map((block, idx) => { + if (block.type === "brief") { + return ( + + {(block as { brief: string }).brief} + + ); + } + if (block.type === "shell") { + return ( + + $ {(block as { command: string }).command} + + ); + } + return null; + })} + + )} + + {/* Choices */} + + + [y] + Allow + + + [n] + Deny + + + [a] + Always + + + + ); +} diff --git a/src/kimi_cli/ui/components/CommandPanel.tsx b/src/kimi_cli/ui/components/CommandPanel.tsx new file mode 100644 index 000000000..3de4df9a9 --- /dev/null +++ b/src/kimi_cli/ui/components/CommandPanel.tsx @@ -0,0 +1,288 @@ +/** + * CommandPanel.tsx — Secondary interactive panel for slash commands. + * Renders below the input area when a command needs a secondary menu. + * + * Supports three modes: + * - "choice": selectable list (↑↓ navigate, Enter select, Esc close) + * - "content": scrollable text (↑↓ scroll, Esc close) + * - "input": text input field (type, Enter submit, Esc close) + * + * Panels can chain: onSelect/onSubmit may return a new CommandPanelConfig + * to transition to the next step (wizard pattern). + */ + +import React, { useState, useCallback } from "react"; +import { Box, Text, useInput, useStdout } from "ink"; +import TextInput from "ink-text-input"; +import type { CommandPanelConfig } from "../../types.ts"; + +const DIM = "#888888"; +const HIGHLIGHT = "#1e90ff"; +const BORDER_COLOR = "#555555"; + +interface CommandPanelProps { + config: CommandPanelConfig; + onClose: () => void; +} + +export function CommandPanel({ config, onClose }: CommandPanelProps) { + // Support panel transitions: when a child panel returns a new config, + // we replace the current config with the new one. + const [currentConfig, setCurrentConfig] = useState(config); + + // Reset when external config changes (e.g. opening a different command panel) + React.useEffect(() => { + setCurrentConfig(config); + }, [config]); + + const handleTransition = useCallback( + (result: CommandPanelConfig | Promise | void) => { + if (!result) return; + if (result instanceof Promise) { + result.then((next) => { + if (next) setCurrentConfig(next); + }); + } else { + setCurrentConfig(result); + } + }, + [], + ); + + if (currentConfig.type === "choice") { + return ; + } + if (currentConfig.type === "input") { + return ; + } + return ; +} + +// ── Choice Panel ──────────────────────────────────────── + +function ChoicePanel({ + config, + onClose, + onTransition, +}: { + config: Extract; + onClose: () => void; + onTransition: (result: CommandPanelConfig | Promise | void) => void; +}) { + const { items, onSelect, title } = config; + const defaultIndex = items.findIndex((item) => item.current); + const [selectedIndex, setSelectedIndex] = useState( + defaultIndex >= 0 ? defaultIndex : 0, + ); + const { stdout } = useStdout(); + const columns = stdout?.columns ?? 80; + + useInput( + useCallback( + (_input: string, key: { upArrow?: boolean; downArrow?: boolean; return?: boolean; escape?: boolean }) => { + if (key.escape) { + onClose(); + return; + } + if (key.upArrow) { + setSelectedIndex((i) => Math.max(0, i - 1)); + return; + } + if (key.downArrow) { + setSelectedIndex((i) => Math.min(items.length - 1, i + 1)); + return; + } + if (key.return) { + const item = items[selectedIndex]; + if (item) { + const result = onSelect(item.value); + if (result) { + onTransition(result); + } else { + onClose(); + } + } + return; + } + }, + [items, selectedIndex, onSelect, onClose, onTransition], + ), + ); + + return ( + + {"─".repeat(columns)} + + + {title} + + (↑↓ select, Enter confirm, Esc cancel) + + {"─".repeat(columns)} + {items.map((item, i) => { + const isSelected = i === selectedIndex; + return ( + + + {isSelected ? "▸ " : " "} + + + {item.label} + + {item.description && ( + {" " + item.description} + )} + {item.current && ( + (current) + )} + + ); + })} + + ); +} + +// ── Input Panel ───────────────────────────────────────── + +function InputPanel({ + config, + onClose, + onTransition, +}: { + config: Extract; + onClose: () => void; + onTransition: (result: CommandPanelConfig | Promise | void) => void; +}) { + const { title, placeholder, password, onSubmit } = config; + const [value, setValue] = useState(""); + const { stdout } = useStdout(); + const columns = stdout?.columns ?? 80; + + useInput( + useCallback( + (_input: string, key: { escape?: boolean }) => { + if (key.escape) { + onClose(); + } + }, + [onClose], + ), + ); + + const handleSubmit = useCallback( + (input: string) => { + const trimmed = input.trim(); + if (!trimmed) return; + const result = onSubmit(trimmed); + if (result) { + onTransition(result); + } else { + onClose(); + } + }, + [onSubmit, onClose, onTransition], + ); + + // For password fields, mask the input + const displayValue = password ? "•".repeat(value.length) : value; + + return ( + + {"─".repeat(columns)} + + + {title} + + (Enter submit, Esc cancel) + + {"─".repeat(columns)} + + {"▸ "} + {password ? ( + + ) : ( + + )} + + + ); +} + +// ── Content Panel ─────────────────────────────────────── + +function ContentPanel({ + config, + onClose, +}: { + config: Extract; + onClose: () => void; +}) { + const { content, title } = config; + const { stdout } = useStdout(); + const columns = stdout?.columns ?? 80; + const maxVisibleLines = Math.max((stdout?.rows ?? 24) - 8, 10); + + const lines = content.split("\n"); + const [scrollOffset, setScrollOffset] = useState(0); + const maxScroll = Math.max(0, lines.length - maxVisibleLines); + + useInput( + useCallback( + (_input: string, key: { upArrow?: boolean; downArrow?: boolean; escape?: boolean }) => { + if (key.escape) { + onClose(); + return; + } + if (key.upArrow) { + setScrollOffset((o) => Math.max(0, o - 1)); + return; + } + if (key.downArrow) { + setScrollOffset((o) => Math.min(maxScroll, o + 1)); + return; + } + }, + [maxScroll, onClose], + ), + ); + + const visibleLines = lines.slice(scrollOffset, scrollOffset + maxVisibleLines); + const hasMore = scrollOffset < maxScroll; + + return ( + + {"─".repeat(columns)} + + + {title} + + (↑↓ scroll, Esc close) + {maxScroll > 0 && ( + + {` [${scrollOffset + 1}-${Math.min(scrollOffset + maxVisibleLines, lines.length)}/${lines.length}]`} + + )} + + {"─".repeat(columns)} + + {visibleLines.map((line, i) => ( + {line || " "} + ))} + + {hasMore && ( + ↓ more... + )} + + ); +} diff --git a/src/kimi_cli/ui/components/SlashMenu.tsx b/src/kimi_cli/ui/components/SlashMenu.tsx new file mode 100644 index 000000000..9524b83b1 --- /dev/null +++ b/src/kimi_cli/ui/components/SlashMenu.tsx @@ -0,0 +1,90 @@ +/** + * SlashMenu.tsx — Slash command completion menu. + * Renders a list of matching commands when user types '/'. + * Corresponds to Python's SlashCommandCompletionMenu. + */ + +import React from "react"; +import { Box, Text, useStdout } from "ink"; +import type { SlashCommand } from "../../types.ts"; + +const DIM = "#888888"; +const HIGHLIGHT_BG = "#1e90ff"; + +interface SlashMenuProps { + /** All available commands */ + commands: SlashCommand[]; + /** Current filter text (what user typed after '/') */ + filter: string; + /** Currently selected index */ + selectedIndex: number; +} + +export function SlashMenu({ commands, filter, selectedIndex }: SlashMenuProps) { + const { stdout } = useStdout(); + const columns = stdout?.columns ?? 80; + + // Fuzzy filter commands + const filtered = filterCommands(commands, filter); + + if (filtered.length === 0) return null; + + const separator = "─".repeat(columns); + + return ( + + {separator} + {filtered.map((cmd, i) => { + const isSelected = i === selectedIndex; + return ( + + + {isSelected ? "▸ " : " "} + + + /{cmd.name} + + + {" " + cmd.description} + + + ); + })} + + ); +} + +/** Fuzzy-filter commands by name or alias */ +function filterCommands( + commands: SlashCommand[], + filter: string, +): SlashCommand[] { + if (!filter) return commands; + + const lower = filter.toLowerCase(); + return commands.filter((cmd) => { + if (cmd.name.toLowerCase().includes(lower)) return true; + if (cmd.aliases) { + return cmd.aliases.some((a) => a.toLowerCase().includes(lower)); + } + return false; + }); +} + +/** Get filtered command count (used by parent to know menu size) */ +export function getFilteredCommandCount( + commands: SlashCommand[], + filter: string, +): number { + return filterCommands(commands, filter).length; +} + +/** Get the command at a given index after filtering */ +export function getFilteredCommand( + commands: SlashCommand[], + filter: string, + index: number, +): SlashCommand | undefined { + const filtered = filterCommands(commands, filter); + return filtered[index]; +} diff --git a/src/kimi_cli/ui/components/Spinner.tsx b/src/kimi_cli/ui/components/Spinner.tsx new file mode 100644 index 000000000..edc77c274 --- /dev/null +++ b/src/kimi_cli/ui/components/Spinner.tsx @@ -0,0 +1,54 @@ +/** + * Spinner component — loading indicators. + * Uses ink-spinner for animated spinners. + */ + +import React from "react"; +import { Box, Text } from "ink"; +import InkSpinner from "ink-spinner"; +import { getMessageColors } from "../theme"; + +interface SpinnerProps { + /** Text to display next to the spinner */ + label?: string; + /** Spinner color */ + color?: string; +} + +export function Spinner({ label = "Thinking...", color }: SpinnerProps) { + const colors = getMessageColors(); + const spinnerColor = color || colors.highlight; + + return ( + + + + + {label && ( + {label} + )} + + ); +} + +interface CompactionSpinnerProps { + /** Whether compaction is in progress */ + active: boolean; +} + +export function CompactionSpinner({ active }: CompactionSpinnerProps) { + if (!active) return null; + return ; +} + +interface StreamingSpinnerProps { + stepCount: number; +} + +export function StreamingSpinner({ stepCount }: StreamingSpinnerProps) { + return ( + 0 ? `Thinking... (step ${stepCount})` : "Thinking..."} + /> + ); +} diff --git a/src/kimi_cli/ui/components/StatusBar.tsx b/src/kimi_cli/ui/components/StatusBar.tsx new file mode 100644 index 000000000..cb742dd25 --- /dev/null +++ b/src/kimi_cli/ui/components/StatusBar.tsx @@ -0,0 +1,110 @@ +/** + * StatusBar component — bottom status bar. + * Matches Python's toolbar: separator line + single status line. + * + * Layout: + * ────────────────────────────────────────────────────── + * agent (kimi-k2.5 ●) ~/workdir main context: 0.0% + */ + +import React from "react"; +import { Box, Text, useStdout } from "ink"; +import type { StatusUpdate } from "../../wire/types"; + +const DIM = "#888888"; + +interface StatusBarProps { + modelName?: string; + workDir?: string; + status: StatusUpdate | null; + isStreaming: boolean; + stepCount: number; + isCompacting?: boolean; + planMode?: boolean; + yolo?: boolean; + thinking?: boolean; +} + +export function StatusBar({ + modelName = "", + workDir, + status, + isStreaming, + stepCount, + isCompacting = false, + planMode = false, + yolo = false, + thinking = false, +}: StatusBarProps) { + const { stdout } = useStdout(); + const columns = stdout?.columns ?? 80; + + // Context usage + const contextUsage = status?.context_usage; + const contextPercent = + contextUsage != null ? (contextUsage * 100).toFixed(1) : "0.0"; + + // Shorten workDir + const home = process.env.HOME || process.env.USERPROFILE || ""; + const displayDir = workDir + ? workDir.startsWith(home) + ? "~" + workDir.slice(home.length) + : workDir + : ""; + + // Build left section: [yolo] [plan] agent (model ●) + const leftParts: string[] = []; + if (yolo) leftParts.push("yolo"); + if (planMode) leftParts.push("plan"); + + const thinkingDot = thinking ? "●" : "○"; + const modeStr = modelName + ? `agent (${modelName} ${thinkingDot})` + : "agent"; + leftParts.push(modeStr); + const leftText = leftParts.join(" "); + + // Build right section + const rightText = `context: ${contextPercent}%`; + + // Separator + const separator = "─".repeat(columns); + + return ( + + {separator} + + + {yolo && ( + + yolo + + )} + {planMode && ( + + plan + + )} + {modeStr} + {displayDir && {displayDir}} + {isStreaming && ( + step {stepCount} + )} + {isCompacting && compacting...} + + + + shift-tab: plan mode | ctrl-o: editor + + {rightText} + + + + ); +} + +function formatTokenCount(count: number): string { + if (count < 1000) return String(count); + if (count < 1_000_000) return `${(count / 1000).toFixed(1)}k`; + return `${(count / 1_000_000).toFixed(1)}M`; +} diff --git a/src/kimi_cli/ui/components/WelcomeBox.tsx b/src/kimi_cli/ui/components/WelcomeBox.tsx new file mode 100644 index 000000000..fa2797c8d --- /dev/null +++ b/src/kimi_cli/ui/components/WelcomeBox.tsx @@ -0,0 +1,91 @@ +/** + * WelcomeBox.tsx — Welcome panel displayed on first render. + * Matches Python's welcome box layout with logo, directory, session, model info. + */ + +import React from "react"; +import { Box, Text } from "ink"; + +const KIMI_BLUE = "#1e90ff"; + +interface WelcomeBoxProps { + workDir?: string; + sessionId?: string; + modelName?: string; + tip?: string; +} + +export function WelcomeBox({ + workDir, + sessionId, + modelName, + tip, +}: WelcomeBoxProps) { + // Shorten home directory + const home = process.env.HOME || process.env.USERPROFILE || ""; + const displayDir = workDir + ? workDir.startsWith(home) + ? "~" + workDir.slice(home.length) + : workDir + : "~"; + + return ( + + {/* Logo + Welcome */} + + + ▐█▛█▛█▌ + ▐█████▌ + + + Welcome to Kimi Code CLI! + Send /help for help information. + + + + {/* Blank line */} + + + {/* Directory */} + + Directory: + {displayDir} + + + {/* Session */} + {sessionId && ( + + Session: + {sessionId} + + )} + + {/* Model */} + + Model: + {modelName ? ( + {modelName} + ) : ( + not set, send /login to login + )} + + + {/* Tip */} + {tip && ( + <> + + + Tip: + {tip} + + + )} + + ); +} diff --git a/src/kimi_cli/ui/hooks/index.ts b/src/kimi_cli/ui/hooks/index.ts new file mode 100644 index 000000000..e6ee5b445 --- /dev/null +++ b/src/kimi_cli/ui/hooks/index.ts @@ -0,0 +1,6 @@ +export { useWire } from "./useWire"; +export type { WireState, UseWireOptions } from "./useWire"; +export { useApproval, createApprovalManager } from "./useApproval"; +export type { ApprovalState, UseApprovalOptions } from "./useApproval"; +export { useInputHistory } from "./useInput"; +export type { InputHistoryState } from "./useInput"; diff --git a/src/kimi_cli/ui/hooks/useApproval.ts b/src/kimi_cli/ui/hooks/useApproval.ts new file mode 100644 index 000000000..1c6aae74f --- /dev/null +++ b/src/kimi_cli/ui/hooks/useApproval.ts @@ -0,0 +1,66 @@ +/** + * useApproval hook — manages approval request state machine. + * Corresponds to approval handling in Python's visualize.py. + */ + +import { useState, useCallback } from "react"; +import type { ApprovalRequest, ApprovalResponseKind } from "../../wire/types"; + +export interface ApprovalState { + pending: ApprovalRequest | null; + respond: (decision: ApprovalResponseKind, feedback?: string) => void; + dismiss: () => void; +} + +export interface UseApprovalOptions { + onRespond?: ( + requestId: string, + decision: ApprovalResponseKind, + feedback?: string, + ) => void; +} + +/** + * Hook for managing approval request lifecycle. + */ +export function useApproval(options?: UseApprovalOptions): ApprovalState { + const [pending, setPending] = useState(null); + + const respond = useCallback( + (decision: ApprovalResponseKind, feedback?: string) => { + if (!pending) return; + options?.onRespond?.(pending.id, decision, feedback); + setPending(null); + }, + [pending, options], + ); + + const dismiss = useCallback(() => { + setPending(null); + }, []); + + return { pending, respond, dismiss }; +} + +/** + * Set the pending approval from external source (e.g., wire events). + * This is used by the Shell to inject approval requests into the hook. + */ +export function createApprovalManager(options?: UseApprovalOptions) { + let _pending: ApprovalRequest | null = null; + let _setPending: ((req: ApprovalRequest | null) => void) | null = null; + + return { + setPendingRef: (setter: (req: ApprovalRequest | null) => void) => { + _setPending = setter; + }, + enqueue: (request: ApprovalRequest) => { + _pending = request; + _setPending?.(request); + }, + clear: () => { + _pending = null; + _setPending?.(null); + }, + }; +} diff --git a/src/kimi_cli/ui/hooks/useInput.ts b/src/kimi_cli/ui/hooks/useInput.ts new file mode 100644 index 000000000..fc4c0623a --- /dev/null +++ b/src/kimi_cli/ui/hooks/useInput.ts @@ -0,0 +1,103 @@ +/** + * useInputHistory hook — manages input history and slash command parsing. + * Corresponds to history logic in Python's prompt.py. + */ + +import { useState, useCallback, useRef } from "react"; +import type { SlashCommand } from "../../types"; + +export interface InputHistoryState { + /** Current input value */ + value: string; + /** Set input value */ + setValue: (v: string) => void; + /** Navigate to previous history entry */ + historyPrev: () => void; + /** Navigate to next history entry */ + historyNext: () => void; + /** Add current value to history */ + addToHistory: (entry: string) => void; + /** Check if current input is a slash command */ + isSlashCommand: boolean; + /** Parse slash command name and args */ + parseSlashCommand: () => { name: string; args: string } | null; +} + +/** + * Hook for input history management and slash command parsing. + */ +export function useInputHistory(maxHistory = 100): InputHistoryState { + const [value, setValue] = useState(""); + const history = useRef([]); + const historyIndex = useRef(-1); + const savedInput = useRef(""); + + const addToHistory = useCallback( + (entry: string) => { + const trimmed = entry.trim(); + if (!trimmed) return; + // Deduplicate: remove if already exists at end + if ( + history.current.length > 0 && + history.current[history.current.length - 1] === trimmed + ) { + // Already the last entry + } else { + history.current.push(trimmed); + if (history.current.length > maxHistory) { + history.current.shift(); + } + } + historyIndex.current = -1; + savedInput.current = ""; + }, + [maxHistory], + ); + + const historyPrev = useCallback(() => { + if (history.current.length === 0) return; + if (historyIndex.current === -1) { + savedInput.current = value; + historyIndex.current = history.current.length - 1; + } else if (historyIndex.current > 0) { + historyIndex.current -= 1; + } + setValue(history.current[historyIndex.current] ?? ""); + }, [value]); + + const historyNext = useCallback(() => { + if (historyIndex.current === -1) return; + if (historyIndex.current < history.current.length - 1) { + historyIndex.current += 1; + setValue(history.current[historyIndex.current] ?? ""); + } else { + historyIndex.current = -1; + setValue(savedInput.current); + } + }, []); + + const isSlashCommand = value.startsWith("/"); + + const parseSlashCommand = useCallback(() => { + if (!value.startsWith("/")) return null; + const trimmed = value.slice(1).trim(); + const spaceIdx = trimmed.indexOf(" "); + if (spaceIdx === -1) { + return { name: trimmed, args: "" }; + } + return { + name: trimmed.slice(0, spaceIdx), + args: trimmed.slice(spaceIdx + 1).trim(), + }; + }, [value]); + + return { + value, + setValue, + historyPrev, + historyNext, + addToHistory, + isSlashCommand, + parseSlashCommand, + }; +} diff --git a/src/kimi_cli/ui/hooks/useWire.ts b/src/kimi_cli/ui/hooks/useWire.ts new file mode 100644 index 000000000..182dd6cd9 --- /dev/null +++ b/src/kimi_cli/ui/hooks/useWire.ts @@ -0,0 +1,221 @@ +/** + * useWire hook — subscribes to Wire EventBus and accumulates renderable messages. + * Corresponds to the event-processing logic in Python's visualize.py. + */ + +import { useState, useEffect, useCallback, useRef } from "react"; +import type { + UIMessage, + WireUIEvent, + TextSegment, + ThinkSegment, + ToolCallSegment, +} from "../shell/events"; +import type { StatusUpdate, ApprovalRequest } from "../../wire/types"; +import { nanoid } from "nanoid"; + +export interface WireState { + messages: UIMessage[]; + isStreaming: boolean; + pendingApproval: ApprovalRequest | null; + status: StatusUpdate | null; + stepCount: number; + isCompacting: boolean; +} + +export interface UseWireOptions { + /** External event source — call pushEvent to feed events */ + onReady?: (pushEvent: (event: WireUIEvent) => void) => void; +} + +/** + * Hook that accumulates wire events into a renderable message list. + */ +export function useWire(options?: UseWireOptions): WireState & { + pushEvent: (event: WireUIEvent) => void; + clearMessages: () => void; +} { + const [messages, setMessages] = useState([]); + const [isStreaming, setIsStreaming] = useState(false); + const [pendingApproval, setPendingApproval] = + useState(null); + const [status, setStatus] = useState(null); + const [stepCount, setStepCount] = useState(0); + const [isCompacting, setIsCompacting] = useState(false); + + // Use ref for current assistant message being built + const currentAssistantRef = useRef(null); + + const pushEvent = useCallback((event: WireUIEvent) => { + switch (event.type) { + case "turn_begin": { + // Add user message + const userMsg: UIMessage = { + id: nanoid(), + role: "user", + segments: [{ type: "text", text: event.userInput }], + timestamp: Date.now(), + }; + setMessages((prev) => [...prev, userMsg]); + setIsStreaming(true); + setStepCount(0); + // Start new assistant message + const assistantMsg: UIMessage = { + id: nanoid(), + role: "assistant", + segments: [], + timestamp: Date.now(), + }; + currentAssistantRef.current = assistantMsg; + setMessages((prev) => [...prev, assistantMsg]); + break; + } + + case "turn_end": { + currentAssistantRef.current = null; + setIsStreaming(false); + break; + } + + case "step_begin": { + setStepCount(event.n); + break; + } + + case "step_interrupted": { + setIsStreaming(false); + break; + } + + case "text_delta": { + if (!currentAssistantRef.current) break; + const msg = currentAssistantRef.current; + const lastSeg = msg.segments[msg.segments.length - 1]; + if (lastSeg && lastSeg.type === "text") { + (lastSeg as TextSegment).text += event.text; + } else { + msg.segments.push({ type: "text", text: event.text }); + } + setMessages((prev) => [...prev.slice(0, -1), { ...msg }]); + break; + } + + case "think_delta": { + if (!currentAssistantRef.current) break; + const msg = currentAssistantRef.current; + const lastSeg = msg.segments[msg.segments.length - 1]; + if (lastSeg && lastSeg.type === "think") { + (lastSeg as ThinkSegment).text += event.text; + } else { + msg.segments.push({ type: "think", text: event.text }); + } + setMessages((prev) => [...prev.slice(0, -1), { ...msg }]); + break; + } + + case "tool_call": { + if (!currentAssistantRef.current) break; + const msg = currentAssistantRef.current; + msg.segments.push({ + type: "tool_call", + id: event.id, + name: event.name, + arguments: event.arguments, + collapsed: false, + }); + setMessages((prev) => [...prev.slice(0, -1), { ...msg }]); + break; + } + + case "tool_result": { + if (!currentAssistantRef.current) break; + const msg = currentAssistantRef.current; + const toolSeg = msg.segments.find( + (s) => + s.type === "tool_call" && + (s as ToolCallSegment).id === event.toolCallId, + ) as ToolCallSegment | undefined; + if (toolSeg) { + toolSeg.result = event.result; + toolSeg.collapsed = true; + } + setMessages((prev) => [...prev.slice(0, -1), { ...msg }]); + break; + } + + case "approval_request": { + setPendingApproval(event.request); + break; + } + + case "approval_response": { + setPendingApproval(null); + break; + } + + case "status_update": { + setStatus(event.status); + break; + } + + case "compaction_begin": { + setIsCompacting(true); + break; + } + + case "compaction_end": { + setIsCompacting(false); + break; + } + + case "notification": { + const sysMsg: UIMessage = { + id: nanoid(), + role: "system", + segments: [ + { type: "text", text: `${event.title}: ${event.body}` }, + ], + timestamp: Date.now(), + }; + setMessages((prev) => [...prev, sysMsg]); + break; + } + + case "error": { + const errMsg: UIMessage = { + id: nanoid(), + role: "system", + segments: [{ type: "text", text: `Error: ${event.message}` }], + timestamp: Date.now(), + }; + setMessages((prev) => [...prev, errMsg]); + setIsStreaming(false); + break; + } + } + }, []); + + const clearMessages = useCallback(() => { + setMessages([]); + currentAssistantRef.current = null; + setIsStreaming(false); + setPendingApproval(null); + setStepCount(0); + }, []); + + // Notify caller that pushEvent is ready + useEffect(() => { + options?.onReady?.(pushEvent); + }, [pushEvent, options]); + + return { + messages, + isStreaming, + pendingApproval, + status, + stepCount, + isCompacting, + pushEvent, + clearMessages, + }; +} diff --git a/src/kimi_cli/ui/print/__init__.py b/src/kimi_cli/ui/print/__init__.py deleted file mode 100644 index 6242c33af..000000000 --- a/src/kimi_cli/ui/print/__init__.py +++ /dev/null @@ -1,167 +0,0 @@ -from __future__ import annotations - -import asyncio -import json -import sys -from functools import partial -from pathlib import Path - -from kosong.chat_provider import ( - APIConnectionError, - APIEmptyResponseError, - APIStatusError, - APITimeoutError, - ChatProviderError, -) -from kosong.message import Message -from rich import print - -from kimi_cli.cli import ExitCode, InputFormat, OutputFormat -from kimi_cli.soul import ( - LLMNotSet, - LLMNotSupported, - MaxStepsReached, - RunCancelled, - Soul, - run_soul, -) -from kimi_cli.soul.kimisoul import KimiSoul -from kimi_cli.ui.print.visualize import visualize -from kimi_cli.utils.logging import logger -from kimi_cli.utils.signals import install_sigint_handler - - -class Print: - """ - An app implementation that prints the agent behavior to the console. - - Args: - soul (Soul): The soul to run. - input_format (InputFormat): The input format to use. - output_format (OutputFormat): The output format to use. - context_file (Path): The file to store the context. - final_only (bool): Whether to only print the final assistant message. - """ - - def __init__( - self, - soul: Soul, - input_format: InputFormat, - output_format: OutputFormat, - context_file: Path, - *, - final_only: bool = False, - ): - self.soul = soul - self.input_format: InputFormat = input_format - self.output_format: OutputFormat = output_format - self.context_file = context_file - self.final_only = final_only - - async def run(self, command: str | None = None) -> int: - cancel_event = asyncio.Event() - - def _handler(): - logger.debug("SIGINT received.") - cancel_event.set() - - loop = asyncio.get_running_loop() - remove_sigint = install_sigint_handler(loop, _handler) - - if command is None and not sys.stdin.isatty() and self.input_format == "text": - command = sys.stdin.read().strip() - logger.info("Read command from stdin: {command}", command=command) - - try: - while True: - if command is None: - if self.input_format == "text": - return ExitCode.SUCCESS - else: - assert self.input_format == "stream-json" - command = self._read_next_command() - if command is None: - return ExitCode.SUCCESS - - if command: - logger.info("Running agent with command: {command}", command=command) - if self.output_format == "text" and not self.final_only: - print(command) - runtime = self.soul.runtime if isinstance(self.soul, KimiSoul) else None - await run_soul( - self.soul, - command, - partial(visualize, self.output_format, self.final_only), - cancel_event, - runtime.session.wire_file if runtime else None, - runtime, - ) - else: - logger.info("Empty command, skipping") - - command = None - except LLMNotSet as e: - logger.exception("LLM not set:") - print(str(e)) - return ExitCode.FAILURE - except LLMNotSupported as e: - logger.exception("LLM not supported:") - print(str(e)) - return ExitCode.FAILURE - except ChatProviderError as e: - logger.exception("LLM provider error:") - print(str(e)) - return self._classify_provider_error(e) - except MaxStepsReached as e: - logger.warning("Max steps reached: {n_steps}", n_steps=e.n_steps) - print(str(e)) - return ExitCode.FAILURE - except RunCancelled: - logger.error("Interrupted by user") - print("Interrupted by user") - return ExitCode.FAILURE - except BaseException as e: - logger.exception("Unknown error:") - print(f"Unknown error: {e}") - raise - finally: - remove_sigint() - return ExitCode.FAILURE - - _RETRYABLE_STATUS_CODES = {429, 500, 502, 503, 504} - - @staticmethod - def _classify_provider_error(e: ChatProviderError) -> int: - """Classify a ChatProviderError into an exit code.""" - if isinstance(e, (APIConnectionError, APITimeoutError, APIEmptyResponseError)): - return ExitCode.RETRYABLE - if isinstance(e, APIStatusError): - if e.status_code in Print._RETRYABLE_STATUS_CODES: - return ExitCode.RETRYABLE - return ExitCode.FAILURE - return ExitCode.FAILURE - - def _read_next_command(self) -> str | None: - while True: - json_line = sys.stdin.readline() - if not json_line: - # EOF - return None - - json_line = json_line.strip() - if not json_line: - # for empty line, read next line - continue - - try: - data = json.loads(json_line) - message = Message.model_validate(data) - if message.role == "user": - return message.extract_text(sep="\n") - logger.warning( - "Ignoring message with role `{role}`: {json_line}", - role=message.role, - json_line=json_line, - ) - except Exception: - logger.warning("Ignoring invalid user message: {json_line}", json_line=json_line) diff --git a/src/kimi_cli/ui/print/index.ts b/src/kimi_cli/ui/print/index.ts new file mode 100644 index 000000000..26018ef15 --- /dev/null +++ b/src/kimi_cli/ui/print/index.ts @@ -0,0 +1,464 @@ +/** + * Print mode — non-interactive output. + * Corresponds to Python's ui/print/__init__.py + ui/print/visualize.py. + * + * Provides multiple printer strategies: + * - TextPrinter: prints wire events as rich text + * - JsonPrinter: outputs JSON messages with content merging + * - FinalOnlyTextPrinter: only prints the final assistant text + * - FinalOnlyJsonPrinter: only prints the final assistant message as JSON + * - PrintMode: legacy event-based printer (wraps the above) + */ + +import chalk from "chalk"; +import type { WireUIEvent } from "../shell/events"; + +export type OutputFormat = "text" | "stream-json"; + +export interface PrintOptions { + outputFormat: OutputFormat; + finalOnly: boolean; +} + +// ── Printer Protocol ──────────────────────────────────── + +export interface Printer { + feed(event: WireUIEvent): void; + flush(): void; +} + +// ── Content Part Merging ──────────────────────────────── + +interface ContentBuffer { + type: "text" | "think"; + text: string; +} + +function mergeContent(buffer: ContentBuffer[], part: ContentBuffer): void { + const last = buffer[buffer.length - 1]; + if (last && last.type === part.type) { + last.text += part.text; + } else { + buffer.push({ ...part }); + } +} + +// ── TextPrinter ───────────────────────────────────────── + +export class TextPrinter implements Printer { + feed(event: WireUIEvent): void { + switch (event.type) { + case "text_delta": + process.stdout.write(event.text); + break; + case "think_delta": + process.stdout.write(chalk.italic.grey(event.text)); + break; + case "tool_call": + process.stderr.write( + chalk.dim(`[tool] ${event.name}(${truncateStr(event.arguments, 60)})\n`), + ); + break; + case "tool_call_delta": + // Streaming tool call args — ignore in text mode + break; + case "plan_display": + process.stdout.write(chalk.blue.bold("📋 Plan") + chalk.grey(` (${(event as any).filePath})`) + "\n"); + process.stdout.write((event as any).content + "\n"); + break; + case "tool_result": + if (event.result.return_value.isError) { + process.stderr.write( + chalk.red(`[error] ${truncateStr(event.result.return_value.output, 100)}\n`), + ); + } + break; + case "step_begin": + break; + case "step_interrupted": + process.stderr.write(chalk.yellow("[interrupted]\n")); + break; + case "error": + process.stderr.write(chalk.red(`Error: ${event.message}\n`)); + break; + case "notification": { + const sev = (event as any).severity; + const prefix = sev === "error" ? chalk.red("[error]") : sev === "warning" ? chalk.yellow("[warn]") : chalk.dim(`[${event.title}]`); + process.stderr.write(`${prefix} ${event.body}\n`); + break; + } + case "turn_end": + process.stdout.write("\n"); + break; + } + } + + flush(): void {} +} + +// ── JsonPrinter ───────────────────────────────────────── + +export class JsonPrinter implements Printer { + private contentBuffer: ContentBuffer[] = []; + private toolCalls: Array<{ id: string; name: string; arguments: string }> = []; + private pendingNotifications: Array<{ title: string; body: string }> = []; + + feed(event: WireUIEvent): void { + switch (event.type) { + case "step_begin": + case "step_interrupted": + this.flush(); + break; + case "notification": + if (this.contentBuffer.length > 0 || this.toolCalls.length > 0) { + this.pendingNotifications.push({ title: event.title, body: event.body }); + } else { + this.flushAssistantMessage(); + this.flushNotifications(); + this.emitJson({ type: "notification", title: event.title, body: event.body }); + } + break; + case "text_delta": + mergeContent(this.contentBuffer, { type: "text", text: event.text }); + break; + case "think_delta": + mergeContent(this.contentBuffer, { type: "think", text: event.text }); + break; + case "tool_call": + this.toolCalls.push({ id: event.id, name: event.name, arguments: event.arguments }); + break; + case "tool_result": + this.flushAssistantMessage(); + this.flushNotifications(); + this.emitJson({ + role: "tool", + tool_call_id: event.toolCallId, + content: event.result.return_value.output, + is_error: event.result.return_value.isError, + }); + break; + case "plan_display": + this.flushAssistantMessage(); + this.flushNotifications(); + this.emitJson({ + type: "plan_display", + content: (event as any).content, + file_path: (event as any).filePath, + }); + break; + case "error": + process.stderr.write(chalk.red(`Error: ${event.message}\n`)); + break; + } + } + + private flushAssistantMessage(): void { + if (this.contentBuffer.length === 0 && this.toolCalls.length === 0) return; + const content = this.contentBuffer.map((part) => ({ + type: part.type, + [part.type === "think" ? "think" : "text"]: part.text, + })); + const msg: Record = { role: "assistant", content }; + if (this.toolCalls.length > 0) { + msg.tool_calls = this.toolCalls.map((tc) => ({ + id: tc.id, + type: "function", + function: { name: tc.name, arguments: tc.arguments }, + })); + } + this.emitJson(msg); + this.contentBuffer = []; + this.toolCalls = []; + } + + private flushNotifications(): void { + for (const n of this.pendingNotifications) { + this.emitJson({ type: "notification", ...n }); + } + this.pendingNotifications = []; + } + + private emitJson(data: Record): void { + process.stdout.write(JSON.stringify(data) + "\n"); + } + + flush(): void { + this.flushAssistantMessage(); + this.flushNotifications(); + } +} + +// ── FinalOnlyTextPrinter ──────────────────────────────── + +export class FinalOnlyTextPrinter implements Printer { + private contentBuffer: ContentBuffer[] = []; + + feed(event: WireUIEvent): void { + switch (event.type) { + case "step_begin": + case "step_interrupted": + this.contentBuffer = []; + break; + case "text_delta": + mergeContent(this.contentBuffer, { type: "text", text: event.text }); + break; + case "error": + process.stderr.write(chalk.red(`Error: ${event.message}\n`)); + break; + } + } + + flush(): void { + const text = this.contentBuffer + .filter((p) => p.type === "text") + .map((p) => p.text) + .join(""); + if (text) { + process.stdout.write(text + "\n"); + } + this.contentBuffer = []; + } +} + +// ── FinalOnlyJsonPrinter ──────────────────────────────── + +export class FinalOnlyJsonPrinter implements Printer { + private contentBuffer: ContentBuffer[] = []; + + feed(event: WireUIEvent): void { + switch (event.type) { + case "step_begin": + case "step_interrupted": + this.contentBuffer = []; + break; + case "text_delta": + mergeContent(this.contentBuffer, { type: "text", text: event.text }); + break; + case "error": + process.stderr.write(chalk.red(`Error: ${event.message}\n`)); + break; + } + } + + flush(): void { + const text = this.contentBuffer + .filter((p) => p.type === "text") + .map((p) => p.text) + .join(""); + if (text) { + process.stdout.write( + JSON.stringify({ + role: "assistant", + content: [{ type: "text", text }], + }) + "\n", + ); + } + this.contentBuffer = []; + } +} + +// ── StreamJsonPrinter ────────────────────────────────── + +export class StreamJsonPrinter implements Printer { + feed(event: WireUIEvent): void { + switch (event.type) { + case "text_delta": + case "think_delta": + case "tool_call": + case "tool_result": + case "notification": + case "step_begin": + case "step_interrupted": + case "turn_end": + this.emitJson(event as unknown as Record); + break; + case "error": + process.stderr.write(chalk.red(`Error: ${(event as any).message}\n`)); + break; + } + } + + private emitJson(data: Record): void { + process.stdout.write(JSON.stringify(data) + "\n"); + } + + flush(): void {} +} + +// ── FinalOnlyStreamJsonPrinter ──────────────────────── + +export class FinalOnlyStreamJsonPrinter implements Printer { + private textBuffer = ""; + + feed(event: WireUIEvent): void { + switch (event.type) { + case "text_delta": + this.textBuffer += event.text; + break; + case "step_begin": + case "step_interrupted": + this.textBuffer = ""; + break; + case "error": + process.stderr.write(chalk.red(`Error: ${(event as any).message}\n`)); + break; + } + } + + flush(): void { + if (this.textBuffer) { + process.stdout.write(JSON.stringify({ type: "final_text", text: this.textBuffer }) + "\n"); + } + this.textBuffer = ""; + } +} + +// ── Factory ───────────────────────────────────────────── + +export function createPrinter(options: PrintOptions): Printer { + if (options.finalOnly) { + return options.outputFormat === "text" + ? new FinalOnlyTextPrinter() + : new FinalOnlyStreamJsonPrinter(); + } + return options.outputFormat === "text" + ? new TextPrinter() + : new StreamJsonPrinter(); +} + +// ── Legacy PrintMode (wraps Printer) ──────────────────── + +export class PrintMode { + private printer: Printer; + + constructor(options: PrintOptions) { + this.printer = createPrinter(options); + } + + handleEvent(event: WireUIEvent): void { + this.printer.feed(event); + if (event.type === "turn_end") { + this.printer.flush(); + } + } + + flush(): void { + this.printer.flush(); + } +} + +/** + * Classify error for exit codes. + */ +export function classifyError( + error: unknown, +): "retryable" | "permanent" | "unknown" { + if (error instanceof Error) { + const msg = error.message.toLowerCase(); + if ( + msg.includes("429") || + msg.includes("500") || + msg.includes("502") || + msg.includes("503") || + msg.includes("504") || + msg.includes("timeout") || + msg.includes("connection") + ) { + return "retryable"; + } + return "permanent"; + } + return "unknown"; +} + +// ── Stream-JSON Input Parser ───────────────────────────── + +/** + * Parse a stream-json input line into a user command. + * Returns null if the line is invalid or non-user role. + * Corresponds to Python Print._read_next_command(). + */ +export function parseStreamJsonInput(jsonLine: string): string | null { + const trimmed = jsonLine.trim(); + if (!trimmed) return null; + + try { + const data = JSON.parse(trimmed); + if (!data || typeof data !== "object") return null; + + // Expect { role: "user", content: "..." } or { role: "user", content: [...] } + if (data.role !== "user") return null; + + if (typeof data.content === "string") { + return data.content; + } + + if (Array.isArray(data.content)) { + // Extract text parts and join + const texts: string[] = []; + for (const part of data.content) { + if (part && typeof part === "object" && part.type === "text" && typeof part.text === "string") { + texts.push(part.text); + } + } + return texts.length > 0 ? texts.join("\n") : null; + } + + return null; + } catch { + return null; + } +} + +/** + * Read stream-json lines from a ReadableStream, yielding user commands. + * Corresponds to the Python Print._read_next_command loop. + */ +export async function* readStreamJsonInput( + input: ReadableStream | AsyncIterable, +): AsyncGenerator { + let buffer = ""; + + if ("getReader" in input) { + const reader = (input as ReadableStream).getReader(); + const decoder = new TextDecoder(); + try { + while (true) { + const { done, value } = await reader.read(); + if (done) break; + buffer += decoder.decode(value, { stream: true }); + const lines = buffer.split("\n"); + buffer = lines.pop() ?? ""; + for (const line of lines) { + const command = parseStreamJsonInput(line); + if (command) yield command; + } + } + // Process remaining buffer + if (buffer.trim()) { + const command = parseStreamJsonInput(buffer); + if (command) yield command; + } + } finally { + reader.releaseLock(); + } + } else { + for await (const line of input as AsyncIterable) { + const command = parseStreamJsonInput(line); + if (command) yield command; + } + } +} + +// ── Exit Codes ────────────────────────────────────────────── + +export const ExitCode = { + SUCCESS: 0, + FAILURE: 1, + RETRYABLE: 2, +} as const; + +function truncateStr(text: string, maxLen: number): string { + if (text.length <= maxLen) return text; + return text.slice(0, maxLen) + "…"; +} diff --git a/src/kimi_cli/ui/print/visualize.py b/src/kimi_cli/ui/print/visualize.py deleted file mode 100644 index 875470ed8..000000000 --- a/src/kimi_cli/ui/print/visualize.py +++ /dev/null @@ -1,185 +0,0 @@ -from typing import Protocol - -import rich -from kosong.message import Message - -from kimi_cli.cli import OutputFormat -from kimi_cli.soul.message import tool_result_to_message -from kimi_cli.utils.aioqueue import QueueShutDown -from kimi_cli.wire import Wire -from kimi_cli.wire.types import ( - ContentPart, - Notification, - PlanDisplay, - StepBegin, - StepInterrupted, - ToolCall, - ToolCallPart, - ToolResult, - WireMessage, -) - - -class Printer(Protocol): - def feed(self, msg: WireMessage) -> None: ... - def flush(self) -> None: ... - - -def _merge_content(buffer: list[ContentPart], part: ContentPart) -> None: - if not buffer or not buffer[-1].merge_in_place(part): - buffer.append(part) - - -class TextPrinter(Printer): - def feed(self, msg: WireMessage) -> None: - rich.print(msg) - - def flush(self) -> None: - pass - - -class JsonPrinter(Printer): - def __init__(self) -> None: - self._content_buffer: list[ContentPart] = [] - """The buffer to merge content parts.""" - self._tool_call_buffer: list[ToolCall] = [] - """The buffer to store the current assistant message's tool calls.""" - self._pending_notifications: list[Notification] = [] - """Notifications buffered until the current assistant message reaches a safe boundary.""" - self._last_tool_call: ToolCall | None = None - - def feed(self, msg: WireMessage) -> None: - match msg: - case StepBegin() | StepInterrupted(): - self.flush() - case Notification() as notification: - if self._content_buffer or self._tool_call_buffer: - self._pending_notifications.append(notification) - else: - self._flush_assistant_message() - self._flush_notifications() - self._emit_notification(notification) - case ContentPart() as part: - # merge with previous parts as much as possible - _merge_content(self._content_buffer, part) - case ToolCall() as call: - self._tool_call_buffer.append(call) - self._last_tool_call = call - case ToolCallPart() as part: - if self._last_tool_call is None: - return - assert self._last_tool_call.merge_in_place(part) - case ToolResult() as result: - self._flush_assistant_message() - self._flush_notifications() - message = tool_result_to_message(result) - print(message.model_dump_json(exclude_none=True), flush=True) - case PlanDisplay() as plan: - self._flush_assistant_message() - self._flush_notifications() - print(plan.model_dump_json(exclude_none=True), flush=True) - case _: - # ignore other messages - pass - - def _flush_assistant_message(self) -> None: - if not self._content_buffer and not self._tool_call_buffer: - return - - message = Message( - role="assistant", - content=self._content_buffer, - tool_calls=self._tool_call_buffer or None, - ) - print(message.model_dump_json(exclude_none=True), flush=True) - - self._content_buffer.clear() - self._tool_call_buffer.clear() - self._last_tool_call = None - - def _emit_notification(self, notification: Notification) -> None: - print(notification.model_dump_json(exclude_none=True), flush=True) - - def _flush_notifications(self) -> None: - for notification in self._pending_notifications: - self._emit_notification(notification) - self._pending_notifications.clear() - - def flush(self) -> None: - self._flush_assistant_message() - self._flush_notifications() - - -class FinalOnlyTextPrinter(Printer): - def __init__(self) -> None: - self._content_buffer: list[ContentPart] = [] - - def feed(self, msg: WireMessage) -> None: - match msg: - case StepBegin() | StepInterrupted(): - self._content_buffer.clear() - case ContentPart() as part: - _merge_content(self._content_buffer, part) - case _: - pass - - def flush(self) -> None: - if not self._content_buffer: - return - message = Message(role="assistant", content=self._content_buffer) - text = message.extract_text() - if text: - print(text, flush=True) - self._content_buffer.clear() - - -class FinalOnlyJsonPrinter(Printer): - def __init__(self) -> None: - self._content_buffer: list[ContentPart] = [] - - def feed(self, msg: WireMessage) -> None: - match msg: - case StepBegin() | StepInterrupted(): - self._content_buffer.clear() - case ContentPart() as part: - _merge_content(self._content_buffer, part) - case _: - pass - - def flush(self) -> None: - if not self._content_buffer: - return - message = Message(role="assistant", content=self._content_buffer) - text = message.extract_text() - if text: - final_message = Message(role="assistant", content=text) - print(final_message.model_dump_json(exclude_none=True), flush=True) - self._content_buffer.clear() - - -async def visualize(output_format: OutputFormat, final_only: bool, wire: Wire) -> None: - if final_only: - match output_format: - case "text": - handler = FinalOnlyTextPrinter() - case "stream-json": - handler = FinalOnlyJsonPrinter() - else: - match output_format: - case "text": - handler = TextPrinter() - case "stream-json": - handler = JsonPrinter() - - wire_ui = wire.ui_side(merge=True) - while True: - try: - msg = await wire_ui.receive() - except QueueShutDown: - handler.flush() - break - - handler.feed(msg) - - if isinstance(msg, StepInterrupted): - break diff --git a/src/kimi_cli/ui/shell/ApprovalPanel.tsx b/src/kimi_cli/ui/shell/ApprovalPanel.tsx new file mode 100644 index 000000000..39badd006 --- /dev/null +++ b/src/kimi_cli/ui/shell/ApprovalPanel.tsx @@ -0,0 +1,328 @@ +/** + * ApprovalPanel.tsx — Full approval request panel with React Ink. + * Corresponds to Python's ui/shell/approval_panel.py. + * + * Features: + * - 4 options: approve once (y), approve for session (a), reject (n), reject with feedback (f) + * - Diff preview panel + * - Inline feedback input + * - Truncation with expand hint + * - Keyboard navigation (↑↓ or 1-4 number keys) + */ + +import React, { useState, useCallback } from "react"; +import { Box, Text, useInput } from "ink"; +import type { + ApprovalRequest, + ApprovalResponseKind, + DisplayBlock, + DiffDisplayBlock, + ShellDisplayBlock, + BriefDisplayBlock, +} from "../../wire/types"; + +const MAX_PREVIEW_LINES = 4; + +interface ApprovalOption { + label: string; + response: ApprovalResponseKind; +} + +const OPTIONS: ApprovalOption[] = [ + { label: "Approve once", response: "approve" }, + { label: "Approve for this session", response: "approve_for_session" }, + { label: "Reject", response: "reject" }, + { label: "Reject, tell the model what to do instead", response: "reject" }, +]; + +const FEEDBACK_OPTION_INDEX = 3; + +// ── DiffPreview ────────────────────────────────────────── + +function DiffPreview({ blocks }: { blocks: DisplayBlock[] }) { + const diffBlocks = blocks.filter( + (b): b is DiffDisplayBlock => b.type === "diff", + ); + if (diffBlocks.length === 0) return null; + + // Group by path + const byPath = new Map(); + for (const block of diffBlocks) { + const existing = byPath.get(block.path) || []; + existing.push(block); + byPath.set(block.path, existing); + } + + return ( + + {[...byPath.entries()].map(([path, diffs]) => ( + + + {path} + + {diffs.map((diff, idx) => ( + + {diff.old_text + .split("\n") + .slice(0, MAX_PREVIEW_LINES) + .map((line, lineIdx) => ( + + - {line} + + ))} + {diff.new_text + .split("\n") + .slice(0, MAX_PREVIEW_LINES) + .map((line, lineIdx) => ( + + + {line} + + ))} + + ))} + + ))} + + ); +} + +// ── ContentPreview ─────────────────────────────────────── + +function ContentPreview({ blocks }: { blocks: DisplayBlock[] }) { + let budget = MAX_PREVIEW_LINES; + let truncated = false; + const elements: React.ReactNode[] = []; + + for (let i = 0; i < blocks.length; i++) { + const block = blocks[i]; + if (budget <= 0) { + truncated = true; + break; + } + + if (!block) continue; + if (block.type === "shell") { + const shellBlock = block as ShellDisplayBlock; + const lines = shellBlock.command.trim().split("\n"); + const showLines = lines.slice(0, budget); + if (lines.length > budget) truncated = true; + budget -= showLines.length; + elements.push( + + {showLines.join("\n")} + , + ); + } else if (block.type === "brief") { + const briefBlock = block as BriefDisplayBlock; + const lines = briefBlock.brief.trim().split("\n"); + const showLines = lines.slice(0, budget); + if (lines.length > budget) truncated = true; + budget -= showLines.length; + elements.push( + + {showLines.join("\n")} + , + ); + } + } + + return ( + + {elements} + {truncated && ( + + ... (truncated, ctrl-e to expand) + + )} + + ); +} + +// ── ApprovalPanel ──────────────────────────────────────── + +export interface ApprovalPanelProps { + request: ApprovalRequest; + onRespond: ( + decision: ApprovalResponseKind, + feedback?: string, + ) => void; +} + +export function ApprovalPanel({ request, onRespond }: ApprovalPanelProps) { + const [selectedIndex, setSelectedIndex] = useState(0); + const [feedbackMode, setFeedbackMode] = useState(false); + const [feedbackText, setFeedbackText] = useState(""); + + const isFeedbackSelected = selectedIndex === FEEDBACK_OPTION_INDEX; + + const submit = useCallback( + (index: number) => { + if (index === FEEDBACK_OPTION_INDEX) { + setFeedbackMode(true); + return; + } + onRespond(OPTIONS[index]!.response); + }, + [onRespond], + ); + + useInput((input, key) => { + if (feedbackMode) { + if (key.return) { + if (feedbackText.trim()) { + onRespond("reject", feedbackText.trim()); + } + return; + } + if (key.escape) { + onRespond("reject", ""); + return; + } + if (key.backspace || key.delete) { + setFeedbackText((t) => t.slice(0, -1)); + return; + } + if (input && !key.ctrl && !key.meta) { + setFeedbackText((t) => t + input); + } + return; + } + + // Normal navigation + if (key.upArrow) { + setSelectedIndex((i) => (i - 1 + OPTIONS.length) % OPTIONS.length); + } else if (key.downArrow) { + setSelectedIndex((i) => (i + 1) % OPTIONS.length); + } else if (key.return) { + submit(selectedIndex); + } else if (key.escape) { + onRespond("reject"); + } else if (input >= "1" && input <= "4") { + const idx = parseInt(input) - 1; + if (idx < OPTIONS.length) { + setSelectedIndex(idx); + if (idx !== FEEDBACK_OPTION_INDEX) { + submit(idx); + } else { + setFeedbackMode(true); + } + } + } + }); + + const hasDiff = request.display.some((b) => b.type === "diff"); + const hasContent = + hasDiff || + !!request.description || + request.display.some((b) => b.type === "shell" || b.type === "brief"); + + return ( + + {/* Title */} + + ⚠ ACTION REQUIRED + + + + {/* Request header */} + + + {request.sender} is requesting approval to {request.action}: + + + {/* Source metadata */} + {(request.subagent_type || request.agent_id) && ( + + Subagent:{" "} + {request.subagent_type && request.agent_id + ? `${request.subagent_type} (${request.agent_id})` + : request.subagent_type || request.agent_id} + + )} + {request.source_description && ( + Task: {request.source_description} + )} + + + + + {/* Description */} + {request.description && !request.display.length && ( + + {truncateLines(request.description, MAX_PREVIEW_LINES)} + + )} + + {/* Diff preview */} + {hasDiff && ( + + + + )} + + {/* Non-diff content preview */} + {request.display.some( + (b) => b.type === "shell" || b.type === "brief", + ) && ( + + + + )} + + + + {/* Options */} + {OPTIONS.map((option, i) => { + const num = i + 1; + const isSelected = i === selectedIndex; + const isFeedback = i === FEEDBACK_OPTION_INDEX; + + if (isFeedback && feedbackMode && isSelected) { + return ( + + → [{num}] Reject: {feedbackText}█ + + ); + } + + return ( + + {isSelected ? "→" : " "} [{num}] {option.label} + + ); + })} + + + + {/* Keyboard hints */} + {feedbackMode ? ( + + {" "}Type your feedback, then press Enter to submit. + + ) : ( + + {" "}▲/▼ select {" "}1/2/3/4 choose {" "}↵ confirm + {hasContent ? " ctrl-e expand" : ""} + + )} + + ); +} + +// ── Helpers ────────────────────────────────────────────── + +function truncateLines(text: string, maxLines: number): string { + const lines = text.split("\n"); + if (lines.length <= maxLines) return text; + return lines.slice(0, maxLines).join("\n") + "\n..."; +} + +export default ApprovalPanel; diff --git a/src/kimi_cli/ui/shell/DebugPanel.tsx b/src/kimi_cli/ui/shell/DebugPanel.tsx new file mode 100644 index 000000000..459f35c09 --- /dev/null +++ b/src/kimi_cli/ui/shell/DebugPanel.tsx @@ -0,0 +1,183 @@ +/** + * DebugPanel.tsx — Context debug viewer. + * Corresponds to Python's ui/shell/debug.py. + * + * Features: + * - Display full context (all messages with role colors) + * - Token count + * - Checkpoint information + */ + +import React from "react"; +import { Box, Text } from "ink"; + +// ── Types ─────────────────────────────────────────────── + +export interface ContextInfo { + totalMessages: number; + tokenCount: number; + checkpoints: number; + trajectory?: string; +} + +export interface DebugMessage { + role: string; + content: string; + name?: string; + toolCallId?: string; + toolCalls?: Array<{ + id: string; + name: string; + arguments: string; + }>; + partial?: boolean; +} + +export interface DebugPanelProps { + context: ContextInfo; + messages: DebugMessage[]; +} + +// ── Role colors ───────────────────────────────────────── + +const ROLE_COLORS: Record = { + system: "magenta", + developer: "magenta", + user: "green", + assistant: "blue", + tool: "yellow", +}; + +function getRoleColor(role: string): string { + return ROLE_COLORS[role] || "white"; +} + +// ── ContentPart formatting ────────────────────────────── + +function formatContent(content: string): React.ReactNode { + const trimmed = content.trim(); + if (trimmed.startsWith("") && trimmed.endsWith("")) { + const inner = trimmed.slice(8, -9).trim(); + return ( + + system + {inner} + + ); + } + return {content}; +} + +// ── ToolCall formatting ───────────────────────────────── + +function ToolCallDebugView({ + toolCall, +}: { + toolCall: { id: string; name: string; arguments: string }; +}) { + let argsFormatted: string; + try { + argsFormatted = JSON.stringify(JSON.parse(toolCall.arguments), null, 2); + } catch { + argsFormatted = toolCall.arguments; + } + + return ( + + Tool Call + Function: {toolCall.name} + Call ID: {toolCall.id} + Arguments: + {argsFormatted} + + ); +} + +// ── Message formatting ────────────────────────────────── + +function MessageDebugView({ + msg, + index, +}: { + msg: DebugMessage; + index: number; +}) { + const roleColor = getRoleColor(msg.role); + let title = `#${index + 1} ${msg.role.toUpperCase()}`; + if (msg.name) title += ` (${msg.name})`; + if (msg.toolCallId) title += ` → ${msg.toolCallId}`; + if (msg.partial) title += " (partial)"; + + return ( + + {title} + {msg.content ? ( + formatContent(msg.content) + ) : ( + [empty message] + )} + {msg.toolCalls?.map((tc) => ( + + ))} + + ); +} + +// ── DebugPanel ────────────────────────────────────────── + +export function DebugPanel({ context, messages }: DebugPanelProps) { + if (messages.length === 0) { + return ( + + Context is empty - no messages yet + + ); + } + + return ( + + {/* Context info */} + + Context Info + Total messages: {context.totalMessages} + Token count: {context.tokenCount.toLocaleString()} + Checkpoints: {context.checkpoints} + {context.trajectory && ( + Trajectory: {context.trajectory} + )} + + + {/* Separator */} + {"─".repeat(60)} + + {/* All messages */} + {messages.map((msg, idx) => ( + + ))} + + ); +} + +export default DebugPanel; diff --git a/src/kimi_cli/ui/shell/Prompt.tsx b/src/kimi_cli/ui/shell/Prompt.tsx new file mode 100644 index 000000000..3530892eb --- /dev/null +++ b/src/kimi_cli/ui/shell/Prompt.tsx @@ -0,0 +1,173 @@ +/** + * Prompt.tsx — Input prompt component with slash command completion. + * Uses ✨ sparkles emoji matching Python version. + * Slash menu renders BELOW the input (pushes up from bottom). + */ + +import React, { useState, useCallback } from "react"; +import { Box, Text, useInput, useStdout } from "ink"; +import TextInput from "ink-text-input"; +import { useInputHistory } from "../hooks/useInput.ts"; +import { + SlashMenu, + getFilteredCommandCount, + getFilteredCommand, +} from "../components/SlashMenu.tsx"; +import type { SlashCommand } from "../../types.ts"; + +interface PromptProps { + onSubmit: (input: string) => void; + onOpenPanel?: (cmd: SlashCommand) => void; + disabled?: boolean; + placeholder?: string; + isStreaming?: boolean; + commands?: SlashCommand[]; + onSlashMenuChange?: (visible: boolean) => void; + /** Incremented by parent to signal "clear the input box" */ + clearSignal?: number; +} + +export function Prompt({ + onSubmit, + onOpenPanel, + disabled = false, + placeholder = "Send a message... (/ for commands)", + isStreaming = false, + commands = [], + onSlashMenuChange, + clearSignal = 0, +}: PromptProps) { + const { value, setValue, historyPrev, historyNext, addToHistory } = + useInputHistory(); + + const [slashMenuIndex, setSlashMenuIndex] = useState(0); + + // React to clearSignal from parent (double-Esc) + React.useEffect(() => { + if (clearSignal > 0) { + setValue(""); + } + }, [clearSignal, setValue]); + + // Detect slash completion mode + const isSlashMode = + value.startsWith("/") && !value.includes(" ") && commands.length > 0; + const slashFilter = isSlashMode ? value.slice(1) : ""; + const menuCount = isSlashMode + ? getFilteredCommandCount(commands, slashFilter) + : 0; + const showSlashMenu = isSlashMode && menuCount > 0; + + // Notify parent about slash menu visibility + React.useEffect(() => { + onSlashMenuChange?.(showSlashMenu); + }, [showSlashMenu, onSlashMenuChange]); + + // Reset menu index when filter changes + React.useEffect(() => { + setSlashMenuIndex(0); + }, [slashFilter]); + + const handleChange = useCallback( + (newValue: string) => { + setValue(newValue); + }, + [setValue], + ); + + const handleSubmit = useCallback( + (input: string) => { + if (showSlashMenu) { + const selected = getFilteredCommand( + commands, + slashFilter, + slashMenuIndex, + ); + if (selected) { + const cmd = `/${selected.name}`; + addToHistory(cmd); + setValue(""); + // If the command has a panel, open it instead of submitting + if (selected.panel && onOpenPanel) { + onOpenPanel(selected); + return; + } + onSubmit(cmd); + return; + } + } + + const trimmed = input.trim(); + if (!trimmed) return; + addToHistory(trimmed); + setValue(""); + onSubmit(trimmed); + }, + [ + onSubmit, + onOpenPanel, + addToHistory, + setValue, + showSlashMenu, + commands, + slashFilter, + slashMenuIndex, + ], + ); + + // Handle up/down/tab for navigation + useInput( + (_input, key) => { + if (showSlashMenu) { + if (key.upArrow) { + setSlashMenuIndex((i) => Math.max(0, i - 1)); + } else if (key.downArrow) { + setSlashMenuIndex((i) => Math.min(menuCount - 1, i + 1)); + } else if (key.tab) { + const selected = getFilteredCommand( + commands, + slashFilter, + slashMenuIndex, + ); + if (selected) { + setValue(`/${selected.name} `); + } + } + } else { + if (key.upArrow) historyPrev(); + else if (key.downArrow) historyNext(); + } + }, + { isActive: !disabled }, + ); + + const { stdout } = useStdout(); + const columns = stdout?.columns ?? 80; + + return ( + + {/* Separator line above input */} + {"─".repeat(columns)} + + {/* Input line — always rendered, always on top */} + + {isStreaming ? "🔄 " : "✨ "} + + + + {/* Slash command menu — renders below input, pushes up from bottom */} + {showSlashMenu && ( + + )} + + ); +} diff --git a/src/kimi_cli/ui/shell/QuestionPanel.tsx b/src/kimi_cli/ui/shell/QuestionPanel.tsx new file mode 100644 index 000000000..e517fdab9 --- /dev/null +++ b/src/kimi_cli/ui/shell/QuestionPanel.tsx @@ -0,0 +1,406 @@ +/** + * QuestionPanel.tsx — Interactive question panel with React Ink. + * Corresponds to Python's ui/shell/question_panel.py. + * + * Features: + * - Multi-question tabs (◀/▶ to switch) + * - Number key selection (1-6) + * - Multi-select with space toggle + * - "Other" free text input + * - Body content area + */ + +import React, { useState, useCallback } from "react"; +import { Box, Text, useInput } from "ink"; +import type { QuestionRequest, QuestionItem, QuestionOption } from "../../wire/types"; + +const OTHER_OPTION_LABEL = "Other"; + +interface OptionEntry { + label: string; + description: string; +} + +export interface QuestionPanelProps { + request: QuestionRequest; + onAnswer: (answers: Record) => void; + onCancel: () => void; +} + +export function QuestionPanel({ + request, + onAnswer, + onCancel, +}: QuestionPanelProps) { + const [questionIndex, setQuestionIndex] = useState(0); + const [selectedIndex, setSelectedIndex] = useState(0); + const [multiSelected, setMultiSelected] = useState>(new Set()); + const [answers, setAnswers] = useState>({}); + const [otherMode, setOtherMode] = useState(false); + const [otherText, setOtherText] = useState(""); + const [otherDrafts, setOtherDrafts] = useState>({}); + + const question: QuestionItem = request.questions[questionIndex]!; + const options: OptionEntry[] = [ + ...question.options.map((o) => ({ + label: o.label, + description: o.description, + })), + { + label: question.other_label || OTHER_OPTION_LABEL, + description: question.other_description || "", + }, + ]; + const isMultiSelect = question.multi_select; + const isOtherSelected = selectedIndex === options.length - 1; + const otherIdx = options.length - 1; + + const advance = useCallback(() => { + const newAnswers = { ...answers }; + // Find next unanswered + const total = request.questions.length; + if (Object.keys(newAnswers).length >= total) { + onAnswer(newAnswers); + return true; + } + for (let offset = 1; offset <= total; offset++) { + const idx = (questionIndex + offset) % total; + if (!(request.questions[idx]!.question in newAnswers)) { + setQuestionIndex(idx); + setSelectedIndex(0); + setMultiSelected(new Set()); + setOtherMode(false); + setOtherText(""); + return false; + } + } + onAnswer(newAnswers); + return true; + }, [answers, questionIndex, request.questions, onAnswer]); + + const submitCurrent = useCallback(() => { + if (isMultiSelect) { + if (otherIdx in [...multiSelected] && multiSelected.has(otherIdx)) { + // Need other input first + setOtherMode(true); + return; + } + const selected = [...multiSelected] + .filter((i) => i < question.options.length) + .sort() + .map((i) => options[i]!.label); + if (selected.length === 0) return; + const newAnswers = { ...answers, [question.question]: selected.join(", ") }; + setAnswers(newAnswers); + // Check if all answered + if (Object.keys(newAnswers).length >= request.questions.length) { + onAnswer(newAnswers); + } else { + advance(); + } + } else { + if (isOtherSelected) { + setOtherMode(true); + return; + } + const newAnswers = { + ...answers, + [question.question]: options[selectedIndex]!.label, + }; + setAnswers(newAnswers); + if (Object.keys(newAnswers).length >= request.questions.length) { + onAnswer(newAnswers); + } else { + // Find next unanswered + const total = request.questions.length; + for (let offset = 1; offset <= total; offset++) { + const idx = (questionIndex + offset) % total; + if (!(request.questions[idx]!.question in newAnswers)) { + setQuestionIndex(idx); + setSelectedIndex(0); + setMultiSelected(new Set()); + return; + } + } + onAnswer(newAnswers); + } + } + }, [ + isMultiSelect, + multiSelected, + otherIdx, + question, + options, + selectedIndex, + isOtherSelected, + answers, + request.questions, + questionIndex, + onAnswer, + advance, + ]); + + const submitOther = useCallback( + (text: string) => { + let newAnswers: Record; + if (isMultiSelect) { + const selected = [...multiSelected] + .filter((i) => i < question.options.length && i !== otherIdx) + .sort() + .map((i) => options[i]!.label); + if (text) selected.push(text); + newAnswers = { + ...answers, + [question.question]: selected.join(", ") || text, + }; + } else { + newAnswers = { ...answers, [question.question]: text }; + } + setAnswers(newAnswers); + setOtherMode(false); + setOtherText(""); + if (Object.keys(newAnswers).length >= request.questions.length) { + onAnswer(newAnswers); + } else { + const total = request.questions.length; + for (let offset = 1; offset <= total; offset++) { + const idx = (questionIndex + offset) % total; + if (!(request.questions[idx]!.question in newAnswers)) { + setQuestionIndex(idx); + setSelectedIndex(0); + setMultiSelected(new Set()); + return; + } + } + onAnswer(newAnswers); + } + }, + [ + isMultiSelect, + multiSelected, + question, + otherIdx, + options, + answers, + request.questions, + questionIndex, + onAnswer, + ], + ); + + useInput((input, key) => { + // Other text input mode + if (otherMode) { + if (key.return) { + submitOther(otherText.trim()); + return; + } + if (key.escape) { + setOtherMode(false); + setOtherText(""); + onCancel(); + return; + } + if (key.backspace || key.delete) { + setOtherText((t) => t.slice(0, -1)); + return; + } + if (input && !key.ctrl && !key.meta) { + setOtherText((t) => t + input); + } + return; + } + + // Navigation + if (key.upArrow) { + setSelectedIndex((i) => (i - 1 + options.length) % options.length); + } else if (key.downArrow) { + setSelectedIndex((i) => (i + 1) % options.length); + } else if (key.leftArrow) { + // Previous tab + if (questionIndex > 0) { + setQuestionIndex(questionIndex - 1); + setSelectedIndex(0); + setMultiSelected(new Set()); + } + } else if (key.rightArrow || key.tab) { + // Next tab + if (questionIndex < request.questions.length - 1) { + setQuestionIndex(questionIndex + 1); + setSelectedIndex(0); + setMultiSelected(new Set()); + } + } else if (input === " " && isMultiSelect) { + // Toggle multi-select + setMultiSelected((prev) => { + const next = new Set(prev); + if (next.has(selectedIndex)) { + next.delete(selectedIndex); + } else { + next.add(selectedIndex); + } + return next; + }); + } else if (key.return) { + submitCurrent(); + } else if (key.escape) { + onCancel(); + } else if (input >= "1" && input <= "6") { + const idx = parseInt(input) - 1; + if (idx < options.length) { + setSelectedIndex(idx); + if (isMultiSelect) { + setMultiSelected((prev) => { + const next = new Set(prev); + if (next.has(idx)) { + next.delete(idx); + } else { + next.add(idx); + } + return next; + }); + } else if (idx !== otherIdx) { + // Direct submit for non-other + const newAnswers = { + ...answers, + [question.question]: options[idx]!.label, + }; + setAnswers(newAnswers); + if (Object.keys(newAnswers).length >= request.questions.length) { + onAnswer(newAnswers); + } else { + const total = request.questions.length; + for (let offset = 1; offset <= total; offset++) { + const nextIdx = (questionIndex + offset) % total; + if ( + !(request.questions[nextIdx]!.question in newAnswers) + ) { + setQuestionIndex(nextIdx); + setSelectedIndex(0); + setMultiSelected(new Set()); + return; + } + } + onAnswer(newAnswers); + } + } else { + setOtherMode(true); + } + } + } + }); + + return ( + + {/* Title */} + + ? QUESTION + + + + {/* Tabs for multi-question */} + {request.questions.length > 1 && ( + <> + + {request.questions.map((q, i) => { + const label = q.header || `Q${i + 1}`; + const isActive = i === questionIndex; + const isAnswered = q.question in answers; + const icon = isActive ? "●" : isAnswered ? "✓" : "○"; + const color = isActive ? "cyan" : isAnswered ? "green" : "grey"; + return ( + + ({icon}) {label} + + ); + })} + + + + )} + + {/* Question text */} + ? {question.question} + {isMultiSelect && ( + + {" "}(SPACE to toggle, ENTER to submit) + + )} + + + {/* Body hint */} + {question.body && ( + <> + + {" "}▶ Press ctrl-e to view full content + + + + )} + + {/* Options */} + {options.map((option, i) => { + const num = i + 1; + const isSelected = i === selectedIndex; + const isOther = i === otherIdx; + + if (isMultiSelect) { + const checked = multiSelected.has(i) ? "✓" : " "; + return ( + + + [{checked}] {option.label} + + {option.description && !isSelected && ( + {" "}{option.description} + )} + + ); + } + + if (isOther && otherMode && isSelected) { + return ( + + → [{num}] {option.label}: {otherText}█ + + ); + } + + return ( + + + {isSelected ? "→" : " "} [{num}] {option.label} + + {option.description && !(isOther && otherMode) && ( + {" "}{option.description} + )} + + ); + })} + + {/* Hints */} + + {otherMode ? ( + + {" "}Type your answer, then press Enter to submit. + + ) : request.questions.length > 1 ? ( + + {" "}◄/► switch question {" "}▲/▼ select {" "}↵ submit {" "}esc + exit + + ) : ( + + {" "}▲/▼ select {" "}↵ submit {" "}esc exit + + )} + + ); +} + +export default QuestionPanel; diff --git a/src/kimi_cli/ui/shell/ReplayPanel.tsx b/src/kimi_cli/ui/shell/ReplayPanel.tsx new file mode 100644 index 000000000..732226551 --- /dev/null +++ b/src/kimi_cli/ui/shell/ReplayPanel.tsx @@ -0,0 +1,219 @@ +/** + * ReplayPanel.tsx — Session reconnection replay. + * Corresponds to Python's ui/shell/replay.py. + * + * Features: + * - Replays the most recent turns when reconnecting to a session + * - Shows user messages and assistant responses + * - Renders tool calls and results + */ + +import React from "react"; +import { Box, Text } from "ink"; +import type { WireUIEvent } from "./events"; + +const MAX_REPLAY_TURNS = 5; + +// ── Types ─────────────────────────────────────────────── + +export interface ReplayTurn { + userInput: string; + events: ReplayEvent[]; + stepCount: number; +} + +export interface ReplayEvent { + type: "text" | "think" | "tool_call" | "tool_result" | "step_begin" | "notification" | "plan_display"; + text?: string; + toolName?: string; + toolArgs?: string; + toolCallId?: string; + isError?: boolean; + title?: string; + body?: string; + content?: string; + filePath?: string; +} + +export interface ReplayPanelProps { + turns: ReplayTurn[]; +} + +// ── ReplayTurnView ────────────────────────────────────── + +function ReplayTurnView({ turn }: { turn: ReplayTurn }) { + return ( + + + You: + {turn.userInput} + + {turn.events.map((event, idx) => ( + + ))} + + ); +} + +// ── ReplayEventView ───────────────────────────────────── + +function ReplayEventView({ event }: { event: ReplayEvent }) { + switch (event.type) { + case "text": + return {event.text}; + case "think": + return ( + + 💭 {event.text} + + ); + case "tool_call": + return ( + + + {event.toolName} + {event.toolArgs && {truncate(event.toolArgs, 60)}} + + ); + case "tool_result": + return ( + + + {event.isError ? "✗" : "✓"}{" "} + + {event.text && {truncate(event.text, 100)}} + + ); + case "notification": + return ( + + + [{event.title}] {event.body} + + ); + case "plan_display": + return ( + + 📋 Plan + {event.content && {truncate(event.content, 200)}} + + ); + case "step_begin": + return null; + default: + return null; + } +} + +// ── ReplayPanel ───────────────────────────────────────── + +export function ReplayPanel({ turns }: ReplayPanelProps) { + if (turns.length === 0) return null; + + const recentTurns = turns.slice(-MAX_REPLAY_TURNS); + + return ( + + ─── Replaying recent history ─── + {recentTurns.map((turn, idx) => ( + + ))} + ─── End of replay ─── + + ); +} + +// ── Helpers ───────────────────────────────────────────── + +function truncate(text: string, maxLen: number): string { + if (text.length <= maxLen) return text; + return text.slice(0, maxLen) + "…"; +} + +/** + * Build replay turns from wire events. + */ +export function buildReplayTurnsFromEvents(events: WireUIEvent[]): ReplayTurn[] { + const turns: ReplayTurn[] = []; + let currentTurn: ReplayTurn | null = null; + + for (const event of events) { + switch (event.type) { + case "turn_begin": + currentTurn = { userInput: event.userInput, events: [], stepCount: 0 }; + turns.push(currentTurn); + break; + case "step_begin": + if (currentTurn) { + currentTurn.stepCount = event.n; + currentTurn.events.push({ type: "step_begin" }); + } + break; + case "text_delta": + if (currentTurn) { + const last = currentTurn.events[currentTurn.events.length - 1]; + if (last && last.type === "text") { + last.text = (last.text || "") + event.text; + } else { + currentTurn.events.push({ type: "text", text: event.text }); + } + } + break; + case "think_delta": + if (currentTurn) { + const last = currentTurn.events[currentTurn.events.length - 1]; + if (last && last.type === "think") { + last.text = (last.text || "") + event.text; + } else { + currentTurn.events.push({ type: "think", text: event.text }); + } + } + break; + case "tool_call": + if (currentTurn) { + currentTurn.events.push({ + type: "tool_call", + toolName: event.name, + toolArgs: event.arguments, + toolCallId: event.id, + }); + } + break; + case "tool_result": + if (currentTurn) { + currentTurn.events.push({ + type: "tool_result", + toolCallId: event.toolCallId, + text: event.result.return_value.output, + isError: event.result.return_value.isError, + }); + } + break; + case "notification": + if (currentTurn) { + currentTurn.events.push({ + type: "notification", + title: event.title, + body: event.body, + }); + } + break; + case "plan_display": + if (currentTurn) { + currentTurn.events.push({ + type: "plan_display", + content: (event as any).content, + filePath: (event as any).filePath, + }); + } + break; + case "turn_end": + currentTurn = null; + break; + } + } + + return turns.slice(-MAX_REPLAY_TURNS); +} + +export default ReplayPanel; diff --git a/src/kimi_cli/ui/shell/SetupWizard.tsx b/src/kimi_cli/ui/shell/SetupWizard.tsx new file mode 100644 index 000000000..d17aa2618 --- /dev/null +++ b/src/kimi_cli/ui/shell/SetupWizard.tsx @@ -0,0 +1,258 @@ +/** + * SetupWizard.tsx — First-time setup wizard. + * Corresponds to Python's ui/shell/setup.py. + * + * Features: + * - Platform selection + * - API key input + * - Model selection + * - Thinking mode toggle + */ + +import React, { useState, useCallback } from "react"; +import { Box, Text, useInput } from "ink"; + +// ── Types ─────────────────────────────────────────────── + +export interface PlatformInfo { + id: string; + name: string; +} + +export interface ModelInfo { + id: string; + contextLength?: number; + capabilities?: string[]; +} + +export type SetupStep = + | "platform" + | "api_key" + | "verifying" + | "model" + | "thinking" + | "done" + | "error"; + +export interface SetupResult { + platformId: string; + platformName: string; + apiKey: string; + modelId: string; + thinking: boolean; +} + +export interface SetupWizardProps { + platforms: PlatformInfo[]; + onVerifyKey?: (platformId: string, apiKey: string) => Promise; + onComplete?: (result: SetupResult) => void; + onCancel?: () => void; +} + +// ── SetupWizard ───────────────────────────────────────── + +export function SetupWizard({ + platforms, + onVerifyKey, + onComplete, + onCancel, +}: SetupWizardProps) { + const [step, setStep] = useState("platform"); + const [selectedIndex, setSelectedIndex] = useState(0); + const [selectedPlatform, setSelectedPlatform] = useState(null); + const [apiKey, setApiKey] = useState(""); + const [models, setModels] = useState([]); + const [selectedModel, setSelectedModel] = useState(null); + const [thinking, setThinking] = useState(false); + const [error, setError] = useState(""); + + const finishSetup = useCallback( + (model: ModelInfo, thinkingMode: boolean) => { + setStep("done"); + setSelectedModel(model); + setThinking(thinkingMode); + onComplete?.({ + platformId: selectedPlatform!.id, + platformName: selectedPlatform!.name, + apiKey: apiKey.trim(), + modelId: model.id, + thinking: thinkingMode, + }); + }, + [selectedPlatform, apiKey, onComplete], + ); + + const handleVerifyKey = useCallback(async () => { + if (!selectedPlatform || !apiKey.trim()) return; + setStep("verifying"); + try { + const result = await onVerifyKey?.(selectedPlatform.id, apiKey.trim()); + if (result && result.length > 0) { + setModels(result); + setStep("model"); + setSelectedIndex(0); + } else { + setError("No models available for the selected platform."); + setStep("error"); + } + } catch (e) { + setError(e instanceof Error ? e.message : "Failed to verify API key."); + setStep("error"); + } + }, [selectedPlatform, apiKey, onVerifyKey]); + + useInput((input, key) => { + if (key.escape) { + onCancel?.(); + return; + } + + switch (step) { + case "platform": { + if (key.upArrow) { + setSelectedIndex((i) => (i - 1 + platforms.length) % platforms.length); + } else if (key.downArrow) { + setSelectedIndex((i) => (i + 1) % platforms.length); + } else if (key.return) { + const platform = platforms[selectedIndex]!; + setSelectedPlatform(platform); + setStep("api_key"); + setSelectedIndex(0); + } + break; + } + + case "api_key": { + if (key.return && apiKey.trim()) { + handleVerifyKey(); + } else if (key.backspace || key.delete) { + setApiKey((k) => k.slice(0, -1)); + } else if (input && !key.ctrl && !key.meta) { + setApiKey((k) => k + input); + } + break; + } + + case "model": { + if (key.upArrow) { + setSelectedIndex((i) => (i - 1 + models.length) % models.length); + } else if (key.downArrow) { + setSelectedIndex((i) => (i + 1) % models.length); + } else if (key.return) { + const model = models[selectedIndex]!; + const caps = model.capabilities || []; + if (caps.includes("always_thinking")) { + finishSetup(model, true); + } else if (caps.includes("thinking")) { + setSelectedModel(model); + setStep("thinking"); + setSelectedIndex(0); + } else { + finishSetup(model, false); + } + } + break; + } + + case "thinking": { + const choices = ["on", "off"]; + if (key.upArrow) { + setSelectedIndex((i) => (i - 1 + choices.length) % choices.length); + } else if (key.downArrow) { + setSelectedIndex((i) => (i + 1) % choices.length); + } else if (key.return) { + finishSetup(selectedModel!, selectedIndex === 0); + } + break; + } + + case "error": { + if (key.return) { + setStep("api_key"); + setApiKey(""); + setError(""); + } + break; + } + } + }); + + return ( + + 🔧 Setup Wizard + + + {step === "platform" && ( + <> + Select a platform (↑↓ navigate, Enter select, Esc cancel): + + {platforms.map((platform, i) => ( + + {i === selectedIndex ? "→" : " "} {platform.name} + + ))} + + )} + + {step === "api_key" && ( + <> + Enter your API key for {selectedPlatform?.name}: + + + {">"} + {apiKey ? "•".repeat(apiKey.length) : ""} + + + + Press Enter to verify, Esc to cancel. + + )} + + {step === "verifying" && Verifying API key...} + + {step === "model" && ( + <> + Select a model (↑↓ navigate, Enter select): + + {models.map((model, i) => ( + + {i === selectedIndex ? "→" : " "} {model.id} + + ))} + + )} + + {step === "thinking" && ( + <> + Enable thinking mode? (↑↓ navigate, Enter select): + + {["on", "off"].map((choice, i) => ( + + {i === selectedIndex ? "→" : " "} {choice} + + ))} + + )} + + {step === "done" && ( + <> + ✓ Setup complete! + Platform: {selectedPlatform?.name} + Model: {selectedModel?.id} + Thinking: {thinking ? "on" : "off"} + Reloading... + + )} + + {step === "error" && ( + <> + {error} + + Press Enter to try again, Esc to cancel. + + )} + + ); +} + +export default SetupWizard; diff --git a/src/kimi_cli/ui/shell/Shell.tsx b/src/kimi_cli/ui/shell/Shell.tsx new file mode 100644 index 000000000..ad138cd65 --- /dev/null +++ b/src/kimi_cli/ui/shell/Shell.tsx @@ -0,0 +1,278 @@ +/** + * Shell.tsx — Main REPL component. + * Corresponds to Python's ui/shell/__init__.py. + * + * Layout logic: + * - WelcomeBox: fixed at top (will scroll off when content grows) + * - ChatList: height = content lines (grows as messages added) + * - InputBox: flexGrow=1 + minHeight=6, fills remaining space + * - text starts from top (row 0) + * - when ChatList grows, InputBox shrinks down to minHeight + * - when InputBox is at minHeight, total layout exceeds screen → scrollable + * - StatusBar: always at bottom + */ + +import React, { useCallback, useEffect, useState } from "react"; +import { Box, useApp, useStdout } from "ink"; +import { MessageList } from "./Visualize.tsx"; +import { Prompt } from "./Prompt.tsx"; +import { WelcomeBox } from "../components/WelcomeBox.tsx"; +import { StatusBar } from "../components/StatusBar.tsx"; +import { ApprovalPrompt } from "../components/ApprovalPrompt.tsx"; +import { CommandPanel } from "../components/CommandPanel.tsx"; +import { StreamingSpinner, CompactionSpinner } from "../components/Spinner.tsx"; +import { useWire } from "../hooks/useWire.ts"; +import { useKeyboard } from "./keyboard.ts"; +import { + createShellSlashCommands, + parseSlashCommand, + findSlashCommand, +} from "./slash.ts"; +import { setActiveTheme } from "../theme.ts"; +import type { WireUIEvent } from "./events.ts"; +import type { ApprovalResponseKind } from "../../wire/types.ts"; +import type { SlashCommand, CommandPanelConfig } from "../../types.ts"; + +const INPUT_MIN_HEIGHT = 6; + +/** Deduplicate commands by name, shell commands take priority */ +function deduplicateCommands(commands: SlashCommand[]): SlashCommand[] { + const seen = new Map(); + for (const cmd of commands) { + if (!seen.has(cmd.name)) { + seen.set(cmd.name, cmd); + } + } + return [...seen.values()]; +} + +export interface ShellProps { + modelName?: string; + workDir?: string; + sessionId?: string; + thinking?: boolean; + onSubmit?: (input: string) => void; + onInterrupt?: () => void; + onApprovalResponse?: ( + requestId: string, + decision: ApprovalResponseKind, + feedback?: string, + ) => void; + onWireReady?: (pushEvent: (event: WireUIEvent) => void) => void; + extraSlashCommands?: SlashCommand[]; +} + +export function Shell({ + modelName = "", + workDir, + sessionId, + thinking = false, + onSubmit, + onInterrupt, + onApprovalResponse, + onWireReady, + extraSlashCommands = [], +}: ShellProps) { + const { exit } = useApp(); + const { stdout } = useStdout(); + const [termHeight, setTermHeight] = useState(stdout?.rows || 24); + const [slashMenuVisible, setSlashMenuVisible] = useState(false); + const [activePanel, setActivePanel] = useState(null); + const [clearInputSignal, setClearInputSignal] = useState(0); + + // Wire state + const wire = useWire({ onReady: onWireReady }); + + // Helper to push notifications to chat area + const pushNotification = useCallback( + (title: string, body: string) => { + wire.pushEvent({ type: "notification", title, body }); + }, + [wire], + ); + + // Shell slash commands + const shellCommands = createShellSlashCommands({ + clearMessages: wire.clearMessages, + exit: () => exit(), + setTheme: (theme) => setActiveTheme(theme), + getAllCommands: () => allCommands, + pushNotification, + }); + + const allCommands = deduplicateCommands([ + ...shellCommands, + ...extraSlashCommands, + ]); + + // Handle terminal resize + useEffect(() => { + const onResize = () => setTermHeight(stdout?.rows || 24); + stdout?.on("resize", onResize); + return () => { + stdout?.off("resize", onResize); + }; + }, [stdout]); + + // Global keyboard handling: Ctrl+C / Esc + useKeyboard({ + onAction: (action) => { + switch (action) { + case "interrupt": + if (activePanel) { + // Close command panel on interrupt + setActivePanel(null); + } else if (wire.isStreaming) { + // Interrupt the running turn: abort the soul + push UI event + onInterrupt?.(); + wire.pushEvent({ type: "error", message: "Interrupted by user" }); + } + break; + case "clear-input": + // Double-Esc: clear the input box + setClearInputSignal((n) => n + 1); + break; + // "exit" is handled internally by useKeyboard (calls exit()) + } + }, + active: true, + }); + + // Handle user input + const handleSubmit = useCallback( + (input: string) => { + const parsed = parseSlashCommand(input); + if (parsed) { + const cmd = findSlashCommand(allCommands, parsed.name); + if (cmd) { + // If command has panel and no args provided, try opening panel + if (cmd.panel && !parsed.args) { + const panelConfig = cmd.panel(); + if (panelConfig) { + setActivePanel(panelConfig); + return; + } + } + cmd.handler(parsed.args); + return; + } + wire.pushEvent({ + type: "notification", + title: "Unknown command", + body: `/${parsed.name} is not a recognized command. Type /help for available commands.`, + }); + return; + } + onSubmit?.(input); + }, + [allCommands, onSubmit, wire], + ); + + // Handle opening a command panel from slash menu + const handleOpenPanel = useCallback( + (cmd: SlashCommand) => { + if (cmd.panel) { + const panelConfig = cmd.panel(); + if (panelConfig) { + setActivePanel(panelConfig); + return; + } + } + // Fallback: execute handler directly + cmd.handler(""); + }, + [], + ); + + // Close command panel + const handleClosePanel = useCallback(() => { + setActivePanel(null); + }, []); + + // Handle approval response + const handleApprovalResponse = useCallback( + (decision: ApprovalResponseKind, feedback?: string) => { + if (wire.pendingApproval) { + onApprovalResponse?.(wire.pendingApproval.id, decision, feedback); + wire.pushEvent({ + type: "approval_response", + requestId: wire.pendingApproval.id, + response: decision, + }); + } + }, + [wire.pendingApproval, onApprovalResponse, wire], + ); + + // Calculate status bar height (separator + 2 lines of status) + const statusBarHeight = slashMenuVisible ? 0 : 3; + + return ( + + {/* ═══ Top: Welcome box ═══ */} + + + {/* ═══ ChatList: height follows content ═══ */} + + + + {wire.isStreaming && !wire.isCompacting && ( + + )} + + + + {wire.pendingApproval && ( + + )} + + + {/* ═══ InputBox: fills remaining, min 6 lines, text at top ═══ */} + + {activePanel ? ( + + ) : ( + + )} + + + {/* ═══ Bottom: Status bar (always visible, hidden when slash menu or panel) ═══ */} + {!slashMenuVisible && !activePanel && ( + + )} + + ); +} diff --git a/src/kimi_cli/ui/shell/TaskBrowser.tsx b/src/kimi_cli/ui/shell/TaskBrowser.tsx new file mode 100644 index 000000000..3beb2bd78 --- /dev/null +++ b/src/kimi_cli/ui/shell/TaskBrowser.tsx @@ -0,0 +1,284 @@ +/** + * TaskBrowser.tsx — Background task browser with React Ink. + * Corresponds to Python's ui/shell/task_browser.py. + * + * Features: + * - Background task list + * - Detail/preview panel + * - Stop/confirm operations + * - Filter (all/active) + */ + +import React, { useState, useCallback } from "react"; +import { Box, Text, useInput } from "ink"; + +// ── Types ─────────────────────────────────────────────── + +export type TaskStatus = + | "running" + | "starting" + | "completed" + | "failed" + | "killed" + | "lost"; + +export type TaskBrowserFilter = "all" | "active"; + +export interface TaskViewSpec { + id: string; + description: string; + kind: string; + command?: string; + cwd?: string; + createdAt: number; +} + +export interface TaskViewRuntime { + status: TaskStatus; + exitCode?: number | null; + failureReason?: string; + startedAt?: number | null; + finishedAt?: number | null; + updatedAt: number; + timedOut?: boolean; +} + +export interface TaskView { + spec: TaskViewSpec; + runtime: TaskViewRuntime; +} + +export interface TaskBrowserProps { + tasks: TaskView[]; + onStop?: (taskId: string) => void; + onViewOutput?: (taskId: string) => void; + onRefresh?: () => void; + onClose?: () => void; +} + +const TERMINAL_STATUSES = new Set([ + "completed", + "failed", + "killed", + "lost", +]); + +function isTerminal(status: TaskStatus): boolean { + return TERMINAL_STATUSES.has(status); +} + +// ── Helpers ───────────────────────────────────────────── + +function formatDuration(seconds: number): string { + if (seconds < 60) return `${seconds}s`; + if (seconds < 3600) return `${Math.floor(seconds / 60)}m`; + const h = Math.floor(seconds / 3600); + const m = Math.floor((seconds % 3600) / 60); + return m > 0 ? `${h}h${m}m` : `${h}h`; +} + +function formatRelativeTime(ts: number): string { + const delta = Math.max(0, Math.floor(Date.now() / 1000 - ts)); + if (delta < 5) return "just now"; + if (delta < 60) return `${delta}s ago`; + if (delta < 3600) return `${Math.floor(delta / 60)}m ago`; + return `${Math.floor(delta / 3600)}h ago`; +} + +function taskTimingLabel(view: TaskView): string { + const now = Date.now() / 1000; + if (view.runtime.finishedAt != null) { + return `finished ${formatRelativeTime(view.runtime.finishedAt)}`; + } + if (view.runtime.startedAt != null) { + const seconds = Math.max(0, Math.floor(now - view.runtime.startedAt)); + return `running ${formatDuration(seconds)}`; + } + return `updated ${formatRelativeTime(view.runtime.updatedAt)}`; +} + +// ── TaskBrowser ───────────────────────────────────────── + +export function TaskBrowser({ + tasks, + onStop, + onViewOutput, + onRefresh, + onClose, +}: TaskBrowserProps) { + const [filterMode, setFilterMode] = useState("all"); + const [selectedIndex, setSelectedIndex] = useState(0); + const [pendingStopId, setPendingStopId] = useState(null); + const [flashMessage, setFlashMessage] = useState(""); + + // Filter tasks + const visibleTasks = + filterMode === "active" + ? tasks.filter((t) => !isTerminal(t.runtime.status)) + : [...tasks]; + + // Sort: active first, then by created time + visibleTasks.sort((a, b) => { + const aTerminal = isTerminal(a.runtime.status) ? 1 : 0; + const bTerminal = isTerminal(b.runtime.status) ? 1 : 0; + if (aTerminal !== bTerminal) return aTerminal - bTerminal; + return a.spec.createdAt - b.spec.createdAt; + }); + + const clampedIndex = Math.min(selectedIndex, Math.max(0, visibleTasks.length - 1)); + const selectedTask = visibleTasks[clampedIndex] || null; + + // Status counts + const counts: Record = { + running: 0, starting: 0, completed: 0, failed: 0, killed: 0, lost: 0, + }; + for (const t of tasks) { + counts[t.runtime.status] = (counts[t.runtime.status] || 0) + 1; + } + + const flash = useCallback((msg: string) => { + setFlashMessage(msg); + setTimeout(() => setFlashMessage(""), 3000); + }, []); + + useInput((input, key) => { + // Confirm stop mode + if (pendingStopId !== null) { + if (input === "y" || input === "Y") { + onStop?.(pendingStopId); + flash(`Stop requested for task ${pendingStopId}.`); + setPendingStopId(null); + return; + } + if (input === "n" || input === "N" || key.escape) { + flash("Stop cancelled."); + setPendingStopId(null); + return; + } + return; + } + + // Normal mode + if (key.upArrow) { + setSelectedIndex((i) => Math.max(0, i - 1)); + } else if (key.downArrow) { + setSelectedIndex((i) => Math.min(visibleTasks.length - 1, i + 1)); + } else if (input === "q" || key.escape) { + onClose?.(); + } else if (key.tab) { + const newFilter = filterMode === "all" ? "active" : "all"; + setFilterMode(newFilter); + flash(newFilter === "active" ? "Showing active tasks only." : "Showing all tasks."); + } else if (input === "r" || input === "R") { + onRefresh?.(); + flash("Refreshed."); + } else if (input === "s" || input === "S") { + if (selectedTask) { + if (isTerminal(selectedTask.runtime.status)) { + flash(`Task ${selectedTask.spec.id} is already ${selectedTask.runtime.status}.`); + } else { + setPendingStopId(selectedTask.spec.id); + } + } + } else if (key.return || input === "o") { + if (selectedTask) { + onViewOutput?.(selectedTask.spec.id); + } + } + }); + + return ( + + {/* Header */} + + TASK BROWSER + filter={filterMode.toUpperCase()} + {counts.running} running + {counts.starting} starting + {counts.failed} failed + {counts.completed} completed + {(counts.killed || 0) + (counts.lost || 0)} interrupted + {tasks.length} total + + + + {/* Task list */} + + Tasks [{filterMode}] + {visibleTasks.length === 0 ? ( + + {filterMode === "active" ? "No active background tasks." : "No background tasks in this session."} + + ) : ( + visibleTasks.map((task, idx) => { + const isSelected = idx === clampedIndex; + const description = task.spec.description.trim() || "(no description)"; + const timing = taskTimingLabel(task); + const line = `[${task.runtime.status}] ${description} · ${task.spec.id} · ${task.spec.kind} · ${timing}`; + return ( + + {isSelected ? ">" : " "} {line} + + ); + }) + )} + + + {/* Detail + Preview */} + + + Detail + {selectedTask ? ( + + Task ID: {selectedTask.spec.id} + Status: {selectedTask.runtime.status} + Description: {selectedTask.spec.description} + Kind: {selectedTask.spec.kind} + Time: {taskTimingLabel(selectedTask)} + {selectedTask.spec.cwd && Cwd: {selectedTask.spec.cwd}} + {selectedTask.spec.command && Command: {selectedTask.spec.command}} + {selectedTask.runtime.exitCode != null && Exit code: {selectedTask.runtime.exitCode}} + {selectedTask.runtime.failureReason && Reason: {selectedTask.runtime.failureReason}} + + ) : ( + Select a task from the list. + )} + + + Preview Output + + {selectedTask ? "Press Enter or O to view full output." : "No output to preview."} + + + + + + {/* Footer */} + + {pendingStopId !== null ? ( + <> + Confirm stop {pendingStopId}? + Y confirm + N cancel + + ) : ( + <> + Enter output + S stop + R refresh + Tab filter + Q exit + {flashMessage && | {flashMessage}} + + )} + + + ); +} + +export default TaskBrowser; diff --git a/src/kimi_cli/ui/shell/UsagePanel.tsx b/src/kimi_cli/ui/shell/UsagePanel.tsx new file mode 100644 index 000000000..f61b9c735 --- /dev/null +++ b/src/kimi_cli/ui/shell/UsagePanel.tsx @@ -0,0 +1,257 @@ +/** + * UsagePanel.tsx — API usage and quota display panel. + * Corresponds to Python's ui/shell/usage.py. + * + * Features: + * - API quota usage display + * - Progress bars + * - Reset timer + */ + +import React from "react"; +import { Box, Text } from "ink"; + +export interface UsageRow { + label: string; + used: number; + limit: number; + resetHint?: string | null; +} + +export interface UsagePanelProps { + summary?: UsageRow | null; + limits: UsageRow[]; + loading?: boolean; + error?: string | null; +} + +function ProgressBar({ + completed, + total, + width = 20, +}: { + completed: number; + total: number; + width?: number; +}) { + const ratio = total > 0 ? Math.min(completed / total, 1) : 0; + const filledWidth = Math.round(ratio * width); + const emptyWidth = width - filledWidth; + const color = ratioColor(total > 0 ? (total - completed) / total : 0); + + return ( + + {"█".repeat(filledWidth)} + {"░".repeat(emptyWidth)} + + ); +} + +function ratioColor(ratio: number): string { + if (ratio >= 0.9) return "red"; + if (ratio >= 0.7) return "yellow"; + return "green"; +} + +function UsageRowView({ row, labelWidth }: { row: UsageRow; labelWidth: number }) { + const remaining = row.limit > 0 ? (row.limit - row.used) / row.limit : 0; + const percent = remaining * 100; + + return ( + + + {row.label.padEnd(labelWidth)} + + + + + + {` ${percent.toFixed(0)}% left`} + {row.resetHint && ( + {` (${row.resetHint})`} + )} + + + ); +} + +export function UsagePanel({ summary, limits, loading, error }: UsagePanelProps) { + if (loading) { + return ( + + Fetching usage... + + ); + } + + if (error) { + return ( + + {error} + + ); + } + + const rows = [...(summary ? [summary] : []), ...limits]; + if (rows.length === 0) { + return ( + + No usage data + + ); + } + + const labelWidth = Math.max(6, ...rows.map((r) => r.label.length)); + + return ( + + API Usage + + {rows.map((row, idx) => ( + + ))} + + ); +} + +// ── Usage data parsing helpers ────────────────────────── + +export function parseUsagePayload(payload: Record): { + summary: UsageRow | null; + limits: UsageRow[]; +} { + let summary: UsageRow | null = null; + const limits: UsageRow[] = []; + + const usage = payload.usage; + if (usage && typeof usage === "object") { + summary = toUsageRow(usage as Record, "Weekly limit"); + } + + const rawLimits = payload.limits; + if (Array.isArray(rawLimits)) { + for (let idx = 0; idx < rawLimits.length; idx++) { + const item = rawLimits[idx]; + if (!item || typeof item !== "object") continue; + const itemMap = item as Record; + const detail = + itemMap.detail && typeof itemMap.detail === "object" + ? (itemMap.detail as Record) + : itemMap; + const window = + itemMap.window && typeof itemMap.window === "object" + ? (itemMap.window as Record) + : {}; + const label = limitLabel(itemMap, detail, window, idx); + const row = toUsageRow(detail, label); + if (row) limits.push(row); + } + } + + return { summary, limits }; +} + +function toUsageRow( + data: Record, + defaultLabel: string, +): UsageRow | null { + const limit = toInt(data.limit); + let used = toInt(data.used); + if (used == null) { + const remaining = toInt(data.remaining); + if (remaining != null && limit != null) { + used = limit - remaining; + } + } + if (used == null && limit == null) return null; + return { + label: String(data.name || data.title || defaultLabel), + used: used || 0, + limit: limit || 0, + resetHint: resetHint(data), + }; +} + +function limitLabel( + item: Record, + detail: Record, + window: Record, + idx: number, +): string { + for (const key of ["name", "title", "scope"]) { + const val = item[key] || detail[key]; + if (val) return String(val); + } + const duration = toInt( + window.duration || item.duration || detail.duration, + ); + const timeUnit = String( + window.timeUnit || item.timeUnit || detail.timeUnit || "", + ); + if (duration) { + if (timeUnit.includes("MINUTE")) { + if (duration >= 60 && duration % 60 === 0) return `${duration / 60}h limit`; + return `${duration}m limit`; + } + if (timeUnit.includes("HOUR")) return `${duration}h limit`; + if (timeUnit.includes("DAY")) return `${duration}d limit`; + return `${duration}s limit`; + } + return `Limit #${idx + 1}`; +} + +function resetHint(data: Record): string | null { + for (const key of ["reset_at", "resetAt", "reset_time", "resetTime"]) { + if (data[key]) return formatResetTime(String(data[key])); + } + for (const key of ["reset_in", "resetIn", "ttl", "window"]) { + const seconds = toInt(data[key]); + if (seconds) return `resets in ${formatDuration(seconds)}`; + } + return null; +} + +function formatResetTime(val: string): string { + try { + const dt = new Date(val); + const now = Date.now(); + const delta = dt.getTime() - now; + if (delta <= 0) return "reset"; + return `resets in ${formatDuration(Math.floor(delta / 1000))}`; + } catch { + return `resets at ${val}`; + } +} + +function formatDuration(seconds: number): string { + if (seconds < 60) return `${seconds}s`; + if (seconds < 3600) return `${Math.floor(seconds / 60)}m`; + const hours = Math.floor(seconds / 3600); + const mins = Math.floor((seconds % 3600) / 60); + return mins > 0 ? `${hours}h${mins}m` : `${hours}h`; +} + +function toInt(value: unknown): number | null { + if (value == null) return null; + const n = Number(value); + return Number.isFinite(n) ? Math.floor(n) : null; +} + +export default UsagePanel; diff --git a/src/kimi_cli/ui/shell/Visualize.tsx b/src/kimi_cli/ui/shell/Visualize.tsx new file mode 100644 index 000000000..4eaf90fad --- /dev/null +++ b/src/kimi_cli/ui/shell/Visualize.tsx @@ -0,0 +1,917 @@ +/** + * Visualize.tsx — Message visualization components. + * Corresponds to Python's ui/shell/visualize.py. + * + * Components: + * - MessageList: renders all messages with step headers + * - MessageView: single message with role-based styling + * - ToolCallView: tool call display (collapsible) with key argument extraction + * - StreamingText: streaming text with cursor + full markdown rendering + * - ThinkingView: thinking/reasoning display + * - CodeBlockView: syntax-highlighted code blocks + * - TableView: markdown table rendering + * - ErrorRecoveryView: API error classification display + */ + +import React, { useState, useMemo } from "react"; +import { Box, Text, Newline } from "ink"; +import chalk from "chalk"; +import { getStyles, getMessageColors, getDiffColors } from "../theme.ts"; +import type { + UIMessage, + MessageSegment, + TextSegment, + ThinkSegment, + ToolCallSegment, +} from "./events.ts"; +import type { ToolResult, DisplayBlock } from "../../wire/types.ts"; + +// ── MessageList ──────────────────────────────────────────── + +interface MessageListProps { + messages: UIMessage[]; + isStreaming: boolean; + stepCount?: number; +} + +export function MessageList({ messages, isStreaming, stepCount }: MessageListProps) { + return ( + + {messages.map((msg, idx) => ( + + ))} + + ); +} + +// ── MessageView ──────────────────────────────────────────── + +interface MessageViewProps { + message: UIMessage; + isLast: boolean; + isStreaming: boolean; + stepCount?: number; +} + +function MessageView({ message, isLast, isStreaming, stepCount }: MessageViewProps) { + const colors = getMessageColors(); + + const roleLabel = getRoleLabel(message.role); + const roleColor = getRoleColor(message.role, colors); + + return ( + + {/* Step count header for assistant messages */} + {message.role === "assistant" && stepCount != null && stepCount > 0 && ( + + ─── Step {stepCount} ─── + + )} + + {roleLabel} + + {message.segments.map((segment, idx) => ( + + ))} + + ); +} + +function getRoleLabel(role: string): string { + switch (role) { + case "user": + return "You"; + case "assistant": + return "Assistant"; + case "system": + return "System"; + case "tool": + return "Tool"; + default: + return role; + } +} + +function getRoleColor( + role: string, + colors: ReturnType, +): string { + switch (role) { + case "user": + return colors.user; + case "assistant": + return colors.assistant; + case "system": + return colors.system; + case "tool": + return colors.tool; + default: + return colors.dim; + } +} + +// ── SegmentView ──────────────────────────────────────────── + +interface SegmentViewProps { + segment: MessageSegment; + isStreaming: boolean; +} + +function SegmentView({ segment, isStreaming }: SegmentViewProps) { + switch (segment.type) { + case "text": + return ; + case "think": + return ; + case "tool_call": + return ; + default: + return null; + } +} + +// ── StreamingText ────────────────────────────────────────── + +interface StreamingTextProps { + text: string; + isStreaming: boolean; +} + +export function StreamingText({ text, isStreaming }: StreamingTextProps) { + const rendered = useMemo(() => renderMarkdown(text), [text]); + + return ( + + {rendered} + {isStreaming && } + + ); +} + +// ── ThinkingView ─────────────────────────────────────────── + +interface ThinkingViewProps { + text: string; +} + +export function ThinkingView({ text }: ThinkingViewProps) { + const colors = getMessageColors(); + // Truncate thinking to max 6 lines for preview + const lines = text.split("\n"); + const preview = lines.length > 6 + ? lines.slice(0, 6).join("\n") + `\n… ${lines.length - 6} more lines` + : text; + + return ( + + + 💭 {preview} + + + ); +} + +// ── ToolCallView ─────────────────────────────────────────── + +interface ToolCallViewProps { + toolCall: ToolCallSegment; +} + +export function ToolCallView({ toolCall }: ToolCallViewProps) { + const [collapsed, setCollapsed] = useState(toolCall.collapsed); + const colors = getMessageColors(); + const statusIcon = toolCall.result + ? toolCall.result.return_value.isError + ? "✗" + : "✓" + : "⟳"; + const statusColor = toolCall.result + ? toolCall.result.return_value.isError + ? colors.error + : colors.highlight + : colors.dim; + + // Format arguments for display — extract key argument + let argsPreview = ""; + try { + const parsed = JSON.parse(toolCall.arguments); + const key = extractKeyArgument(toolCall.name, parsed); + argsPreview = key || truncate(toolCall.arguments, 60); + } catch { + // Streaming JSON: show partial arguments + argsPreview = renderStreamingJson(toolCall.arguments); + } + + return ( + + + {statusIcon} + + {toolCall.name} + + {argsPreview} + + {!collapsed && toolCall.result && ( + + + + )} + + ); +} + +// ── ToolResultView ───────────────────────────────────────── + +interface ToolResultViewProps { + result: ToolResult; +} + +function ToolResultView({ result }: ToolResultViewProps) { + const colors = getMessageColors(); + const output = result.return_value.output; + const isError = result.return_value.isError; + const truncated = truncate(output, 500); + + return ( + + {result.display.map((block, idx) => ( + + ))} + {!result.display.length && ( + {truncated} + )} + + ); +} + +// ── DisplayBlockView ─────────────────────────────────────── + +interface DisplayBlockViewProps { + block: DisplayBlock; +} + +function DisplayBlockView({ block }: DisplayBlockViewProps) { + const colors = getMessageColors(); + const diffColors = getDiffColors(); + const b = block as Record; + + switch (block.type) { + case "brief": + return {b.brief as string}; + case "diff": + return ( + + ); + case "shell": + return ( + + $ + {b.command as string} + + ); + case "todo": { + const items = b.items as Array<{ + title: string; + status: string; + }>; + return ( + + {items.map((item, idx) => ( + + + {item.status === "done" + ? "✓" + : item.status === "in_progress" + ? "⟳" + : "○"}{" "} + {item.title} + + + ))} + + ); + } + case "background_task": { + return ( + + + [{b.kind as string}] + {b.description as string} + ({b.status as string}) + + ); + } + default: + return null; + } +} + +// ── DiffView ─────────────────────────────────────────────── + +function DiffView({ + block, +}: { + block: { path: string; old_text: string; new_text: string }; +}) { + const diffColors = getDiffColors(); + return ( + + + {block.path} + + {block.old_text + .split("\n") + .filter(Boolean) + .map((line, idx) => ( + + - {line} + + ))} + {block.new_text + .split("\n") + .filter(Boolean) + .map((line, idx) => ( + + + {line} + + ))} + + ); +} + +// ── ErrorRecoveryView ────────────────────────────────────── + +export interface ErrorInfo { + type: "rate_limit" | "server_error" | "network" | "auth" | "unknown"; + message: string; + retryable: boolean; + retryAfter?: number; +} + +export function ErrorRecoveryView({ error }: { error: ErrorInfo }) { + const icon = error.retryable ? "⟳" : "✗"; + const color = error.retryable ? "#f2cc60" : "#ff7b72"; + const typeLabel = { + rate_limit: "Rate Limited", + server_error: "Server Error", + network: "Network Error", + auth: "Authentication Error", + unknown: "Error", + }[error.type]; + + return ( + + + + {icon} {typeLabel} + + + + {error.message} + + {error.retryable && error.retryAfter && ( + + + Retrying in {error.retryAfter}s… + + + )} + + ); +} + +/** + * Classify API error for display. + */ +export function classifyApiError(err: unknown): ErrorInfo { + const msg = err instanceof Error ? err.message : String(err); + const lower = msg.toLowerCase(); + + if (lower.includes("429") || lower.includes("rate limit")) { + const retryMatch = lower.match(/retry.after.*?(\d+)/); + return { + type: "rate_limit", + message: msg, + retryable: true, + retryAfter: retryMatch ? parseInt(retryMatch[1]!, 10) : 60, + }; + } + if (lower.includes("500") || lower.includes("502") || lower.includes("503") || lower.includes("504")) { + return { type: "server_error", message: msg, retryable: true, retryAfter: 5 }; + } + if (lower.includes("timeout") || lower.includes("econnrefused") || lower.includes("network")) { + return { type: "network", message: msg, retryable: true, retryAfter: 3 }; + } + if (lower.includes("401") || lower.includes("403") || lower.includes("auth")) { + return { type: "auth", message: msg, retryable: false }; + } + return { type: "unknown", message: msg, retryable: false }; +} + +// ── Markdown Rendering ───────────────────────────────────── + +/** + * Full markdown rendering to React Ink components. + * Supports: headings, code blocks (with language hint), tables, lists, + * blockquotes, horizontal rules, and inline formatting. + */ +function renderMarkdown(text: string): React.ReactNode { + const lines = text.split("\n"); + const elements: React.ReactNode[] = []; + let i = 0; + + while (i < lines.length) { + const line = lines[i]!; + + // Fenced code block + const codeMatch = line.match(/^```(\w*)/); + if (codeMatch) { + const lang = codeMatch[1] || ""; + const codeLines: string[] = []; + i++; + while (i < lines.length && !lines[i]!.startsWith("```")) { + codeLines.push(lines[i]!); + i++; + } + if (i < lines.length) i++; // skip closing ``` + elements.push( + , + ); + continue; + } + + // Table detection (| header | header |) + if (line.includes("|") && line.trim().startsWith("|")) { + const tableLines: string[] = [line]; + i++; + while (i < lines.length && lines[i]!.includes("|") && lines[i]!.trim().startsWith("|")) { + tableLines.push(lines[i]!); + i++; + } + if (tableLines.length >= 2) { + elements.push( + , + ); + continue; + } + // Not a real table, render as text + for (const tl of tableLines) { + elements.push( + {renderInlineFormatting(tl)}, + ); + } + continue; + } + + // Heading + const headingMatch = line.match(/^(#{1,6})\s+(.+)/); + if (headingMatch) { + const level = headingMatch[1]!.length; + const headingText = headingMatch[2]!; + const color = level <= 2 ? "#56a4ff" : level <= 4 ? "#e6e6e6" : "#9ca3af"; + elements.push( + + {level <= 2 ? "█ " : level <= 4 ? "▌ " : "▎ "}{renderInlineFormatting(headingText)} + , + ); + i++; + continue; + } + + // Blockquote + if (line.startsWith("> ") || line === ">") { + const quoteLines: string[] = []; + while (i < lines.length && (lines[i]!.startsWith("> ") || lines[i] === ">")) { + quoteLines.push(lines[i]!.replace(/^>\s?/, "")); + i++; + } + elements.push( + + + {quoteLines.join("\n")} + + , + ); + continue; + } + + // Horizontal rule + if (/^(-{3,}|\*{3,}|_{3,})$/.test(line.trim())) { + elements.push( + + {"─".repeat(60)} + , + ); + i++; + continue; + } + + // Unordered list + if (/^\s*[-*+]\s+/.test(line)) { + const indent = line.match(/^(\s*)/)?.[1]?.length ?? 0; + const bullet = indent >= 4 ? " ◦ " : indent >= 2 ? " ◦ " : "• "; + const content = line.replace(/^\s*[-*+]\s+/, ""); + elements.push( + + {" ".repeat(Math.floor(indent / 2))}{bullet}{renderInlineFormatting(content)} + , + ); + i++; + continue; + } + + // Ordered list + const olMatch = line.match(/^\s*(\d+)[.)]\s+(.*)/); + if (olMatch) { + const num = olMatch[1]!; + const content = olMatch[2]!; + elements.push( + + {" "}{chalk.bold(num + ".")} {renderInlineFormatting(content)} + , + ); + i++; + continue; + } + + // Regular text + if (line.trim() === "") { + elements.push({" "}); + } else { + elements.push( + {renderInlineFormatting(line)}, + ); + } + i++; + } + + return <>{elements}; +} + +// ── Code Block View ──────────────────────────────────────── + +function CodeBlockView({ code, language }: { code: string; language: string }) { + const colors = getMessageColors(); + + // Simple keyword-based syntax coloring + const highlighted = language + ? highlightCode(code, language) + : code; + + return ( + + {language && ( + + {language} + + )} + {highlighted} + + ); +} + +/** + * Simple syntax highlighting using chalk. + * Covers common patterns: keywords, strings, comments, numbers. + */ +function highlightCode(code: string, language: string): string { + const lang = language.toLowerCase(); + + // Language-specific keywords + const KEYWORDS: Record = { + js: ["const", "let", "var", "function", "return", "if", "else", "for", "while", "class", "import", "export", "from", "async", "await", "new", "this", "try", "catch", "throw", "typeof", "instanceof", "switch", "case", "default", "break", "continue"], + ts: ["const", "let", "var", "function", "return", "if", "else", "for", "while", "class", "import", "export", "from", "async", "await", "new", "this", "try", "catch", "throw", "typeof", "instanceof", "interface", "type", "enum", "implements", "extends", "switch", "case", "default", "break", "continue"], + typescript: ["const", "let", "var", "function", "return", "if", "else", "for", "while", "class", "import", "export", "from", "async", "await", "new", "this", "try", "catch", "throw", "typeof", "instanceof", "interface", "type", "enum", "implements", "extends", "switch", "case", "default", "break", "continue"], + javascript: ["const", "let", "var", "function", "return", "if", "else", "for", "while", "class", "import", "export", "from", "async", "await", "new", "this", "try", "catch", "throw", "typeof", "instanceof", "switch", "case", "default", "break", "continue"], + python: ["def", "class", "return", "if", "elif", "else", "for", "while", "import", "from", "as", "try", "except", "raise", "with", "yield", "lambda", "pass", "break", "continue", "and", "or", "not", "in", "is", "None", "True", "False", "self", "async", "await"], + py: ["def", "class", "return", "if", "elif", "else", "for", "while", "import", "from", "as", "try", "except", "raise", "with", "yield", "lambda", "pass", "break", "continue", "and", "or", "not", "in", "is", "None", "True", "False", "self", "async", "await"], + rust: ["fn", "let", "mut", "const", "if", "else", "for", "while", "loop", "match", "struct", "enum", "impl", "trait", "pub", "use", "mod", "crate", "self", "super", "return", "async", "await", "move", "type", "where"], + go: ["func", "var", "const", "if", "else", "for", "range", "switch", "case", "default", "return", "type", "struct", "interface", "package", "import", "go", "chan", "select", "defer", "map", "make", "new", "nil", "true", "false"], + bash: ["if", "then", "else", "elif", "fi", "for", "while", "do", "done", "case", "esac", "function", "return", "local", "export", "echo", "exit"], + sh: ["if", "then", "else", "elif", "fi", "for", "while", "do", "done", "case", "esac", "function", "return", "local", "export", "echo", "exit"], + }; + + const keywords = KEYWORDS[lang] || []; + if (keywords.length === 0) return code; + + // Apply highlighting line-by-line + return code.split("\n").map((line) => { + // Comments + if (line.trimStart().startsWith("//") || line.trimStart().startsWith("#")) { + return chalk.hex("#6b7280")(line); + } + + // String literals (basic) + let result = line; + result = result.replace(/(["'`])(?:(?!\1|\\).|\\.)*\1/g, (m) => chalk.hex("#a5d6a7")(m)); + + // Numbers + result = result.replace(/\b(\d+\.?\d*)\b/g, (m) => chalk.hex("#f2cc60")(m)); + + // Keywords + const kwPattern = new RegExp(`\\b(${keywords.join("|")})\\b`, "g"); + result = result.replace(kwPattern, (m) => chalk.hex("#c792ea")(m)); + + return result; + }).join("\n"); +} + +// ── Table View ───────────────────────────────────────────── + +function TableView({ lines }: { lines: string[] }) { + const colors = getMessageColors(); + + // Parse table + const rows = lines + .filter((line) => !line.match(/^\|[\s-:|]+\|$/)) // Skip separator rows + .map((line) => + line + .split("|") + .slice(1, -1) + .map((cell) => cell.trim()), + ); + + if (rows.length === 0) return null; + + const header = rows[0]!; + const body = rows.slice(1); + + // Calculate column widths + const colWidths = header.map((h, colIdx) => { + const maxContent = Math.max( + h.length, + ...body.map((row) => (row[colIdx] || "").length), + ); + return Math.min(maxContent + 2, 40); + }); + + const separator = "┼" + colWidths.map((w) => "─".repeat(w)).join("┼") + "┼"; + + return ( + + {/* Header */} + + {"┌" + colWidths.map((w) => "─".repeat(w)).join("┬") + "┐"} + + + {"│"} + {header.map((cell, idx) => ( + chalk.bold(cell.padEnd(colWidths[idx]!)) + )).join("│")} + {"│"} + + + {"├" + colWidths.map((w) => "─".repeat(w)).join("┼") + "┤"} + + {/* Body */} + {body.map((row, rowIdx) => ( + + {"│"} + {row.map((cell, colIdx) => ( + (cell || "").padEnd(colWidths[colIdx]!) + )).join("│")} + {"│"} + + ))} + + {"└" + colWidths.map((w) => "─".repeat(w)).join("┴") + "┘"} + + + ); +} + +// ── Inline Formatting ────────────────────────────────────── + +/** + * Render inline markdown formatting: bold, italic, code, strikethrough, links. + */ +function renderInlineFormatting(text: string): string { + return text + // Bold + italic + .replace(/\*\*\*(.+?)\*\*\*/g, (_, p1) => chalk.bold.italic(p1)) + // Bold + .replace(/\*\*(.+?)\*\*/g, (_, p1) => chalk.bold(p1)) + // Italic + .replace(/\*(.+?)\*/g, (_, p1) => chalk.italic(p1)) + .replace(/_(.+?)_/g, (_, p1) => chalk.italic(p1)) + // Strikethrough + .replace(/~~(.+?)~~/g, (_, p1) => chalk.strikethrough(p1)) + // Inline code + .replace(/`(.+?)`/g, (_, p1) => chalk.cyan(p1)) + // Links [text](url) + .replace(/\[(.+?)\]\((.+?)\)/g, (_, text, url) => + chalk.underline.hex("#56a4ff")(text) + chalk.hex("#6b7280")(` (${url})`), + ); +} + +// ── Streaming JSON Rendering ─────────────────────────────── + +/** + * Render partial/streaming JSON arguments for tool calls. + * Shows key-value pairs as they arrive. + */ +function renderStreamingJson(partial: string): string { + // Try to extract readable key-value pairs from partial JSON + const pairs: string[] = []; + const kvPattern = /"(\w+)":\s*"([^"]*)"?/g; + let match; + while ((match = kvPattern.exec(partial)) !== null) { + const key = match[1]!; + const value = match[2]!; + if (key.length < 20 && value.length < 80) { + pairs.push(`${key}=${truncate(value, 40)}`); + } + } + if (pairs.length > 0) { + return pairs.slice(0, 3).join(", "); + } + return truncate(partial, 60); +} + +// ── NotificationView ──────────────────────────────────────── + +export interface NotificationViewProps { + title: string; + body: string; + severity?: string; +} + +export function NotificationView({ title, body, severity }: NotificationViewProps) { + const icon = severity === "error" ? "✗" : severity === "warning" ? "⚠" : "ℹ"; + const color = severity === "error" ? "#ff7b72" : severity === "warning" ? "#f2cc60" : "#56a4ff"; + + return ( + + + {icon} {title} + + {body && ( + + {body} + + )} + + ); +} + +// ── StatusView (context token usage) ──────────────────────── + +export interface StatusViewProps { + contextTokens: number; + maxContextTokens: number; + contextUsage?: number | null; +} + +export function StatusView({ contextTokens, maxContextTokens, contextUsage }: StatusViewProps) { + const ratio = maxContextTokens > 0 ? contextTokens / maxContextTokens : 0; + const percent = (ratio * 100).toFixed(0); + const barWidth = 20; + const filled = Math.round(ratio * barWidth); + const empty = barWidth - filled; + const color = ratio >= 0.9 ? "#ff7b72" : ratio >= 0.7 ? "#f2cc60" : "#56d364"; + + return ( + + context + {"█".repeat(filled)} + {"░".repeat(empty)} + {percent}% ({(contextTokens / 1000).toFixed(1)}k/{(maxContextTokens / 1000).toFixed(1)}k) + + ); +} + +// ── PlanDisplayView ───────────────────────────────────────── + +export function PlanDisplayView({ content, filePath }: { content: string; filePath: string }) { + const rendered = renderMarkdown(content); + return ( + + + 📋 Plan + ({filePath}) + + + {rendered} + + + ); +} + +// ── HookView ──────────────────────────────────────────────── + +export function HookTriggeredView({ event, target, hookCount }: { event: string; target: string; hookCount: number }) { + return ( + + ⟳ hook + {event} + {target && → {target}} + {hookCount > 1 && ({hookCount} hooks)} + + ); +} + +export function HookResolvedView({ event, target, action, reason, durationMs }: { event: string; target: string; action: string; reason: string; durationMs: number }) { + const icon = action === "allow" ? "✓" : "✗"; + const color = action === "allow" ? "#56d364" : "#ff7b72"; + return ( + + {icon} hook + {event} + {target && → {target}} + ({action}{reason ? `: ${reason}` : ""}) {durationMs}ms + + ); +} + +// ── Enhanced DiffView with line numbers and context ───────── + +function EnhancedDiffView({ + block, +}: { + block: { path: string; old_text: string; new_text: string; old_start?: number; new_start?: number }; +}) { + const diffColors = getDiffColors(); + const oldStart = block.old_start ?? 1; + const newStart = block.new_start ?? 1; + const oldLines = block.old_text.split("\n").filter(Boolean); + const newLines = block.new_text.split("\n").filter(Boolean); + + // Determine max line number width for alignment + const maxLineNum = Math.max(oldStart + oldLines.length, newStart + newLines.length); + const lineNumWidth = String(maxLineNum).length; + + return ( + + + {block.path} + + + @@ -{oldStart},{oldLines.length} +{newStart},{newLines.length} @@ + + {oldLines.map((line, idx) => ( + + {String(oldStart + idx).padStart(lineNumWidth)} - {line} + + ))} + {newLines.map((line, idx) => ( + + {String(newStart + idx).padStart(lineNumWidth)} + {line} + + ))} + + ); +} + +// ── Helpers ──────────────────────────────────────────────── + +function truncate(text: string, maxLen: number): string { + if (text.length <= maxLen) return text; + return `${text.slice(0, maxLen)}…`; +} + +/** + * Extract the most relevant argument from a tool call for preview. + */ +function extractKeyArgument( + toolName: string, + args: Record, +): string { + // Try common key argument names + const keyNames = ["path", "file_path", "command", "query", "url", "name", "pattern", "description"]; + for (const key of keyNames) { + if (key in args && typeof args[key] === "string") { + return truncate(args[key] as string, 80); + } + } + // Fall back to first string argument + for (const [_, val] of Object.entries(args)) { + if (typeof val === "string" && val.length < 100) { + return val; + } + } + return ""; +} diff --git a/src/kimi_cli/ui/shell/__init__.py b/src/kimi_cli/ui/shell/__init__.py deleted file mode 100644 index 4628423d2..000000000 --- a/src/kimi_cli/ui/shell/__init__.py +++ /dev/null @@ -1,991 +0,0 @@ -from __future__ import annotations - -import asyncio -import contextlib -import shlex -import time -from collections import deque -from collections.abc import Awaitable, Callable, Coroutine -from dataclasses import dataclass -from enum import Enum -from typing import Any - -from kosong.chat_provider import APIStatusError, ChatProviderError -from rich.console import Group, RenderableType -from rich.panel import Panel -from rich.table import Table -from rich.text import Text - -from kimi_cli import logger -from kimi_cli.background import list_task_views -from kimi_cli.notifications import NotificationManager, NotificationWatcher -from kimi_cli.soul import LLMNotSet, LLMNotSupported, MaxStepsReached, RunCancelled, Soul, run_soul -from kimi_cli.soul.kimisoul import KimiSoul -from kimi_cli.ui.shell import update as _update_mod -from kimi_cli.ui.shell.console import console -from kimi_cli.ui.shell.echo import render_user_echo_text -from kimi_cli.ui.shell.mcp_status import render_mcp_prompt -from kimi_cli.ui.shell.prompt import ( - CustomPromptSession, - PromptMode, - UserInput, - toast, -) -from kimi_cli.ui.shell.replay import replay_recent_history -from kimi_cli.ui.shell.slash import registry as shell_slash_registry -from kimi_cli.ui.shell.slash import shell_mode_registry -from kimi_cli.ui.shell.update import LATEST_VERSION_FILE, UpdateResult, do_update, semver_tuple -from kimi_cli.ui.shell.visualize import ( - ApprovalPromptDelegate, - visualize, -) -from kimi_cli.utils.aioqueue import QueueShutDown -from kimi_cli.utils.envvar import get_env_bool -from kimi_cli.utils.logging import open_original_stderr -from kimi_cli.utils.signals import install_sigint_handler -from kimi_cli.utils.slashcmd import SlashCommand, SlashCommandCall, parse_slash_command_call -from kimi_cli.utils.subprocess_env import get_clean_env -from kimi_cli.utils.term import ensure_new_line, ensure_tty_sane -from kimi_cli.wire.types import ( - ApprovalRequest, - ApprovalResponse, - ContentPart, - StatusUpdate, - WireMessage, -) - - -@dataclass(slots=True) -class _PromptEvent: - kind: str - user_input: UserInput | None = None - - -_MAX_BG_AUTO_TRIGGER_FAILURES = 3 -"""Stop auto-triggering after this many consecutive failures.""" - - -class _BackgroundCompletionWatcher: - """Watches for background task completions and auto-triggers the agent. - - Sits between the idle event loop and the soul: when a background task - finishes while the agent is idle *and* the LLM hasn't consumed the - notification yet, it triggers a soul run. - """ - - def __init__(self, soul: Soul) -> None: - self._event: asyncio.Event | None = None - self._notifications: NotificationManager | None = None - if isinstance(soul, KimiSoul): - self._event = soul.runtime.background_tasks.completion_event - self._notifications = soul.runtime.notifications - - @property - def enabled(self) -> bool: - return self._event is not None - - def clear(self) -> None: - """Clear stale signals from the previous soul run.""" - if self._event is not None: - self._event.clear() - - async def wait_for_next(self, idle_events: asyncio.Queue[_PromptEvent]) -> _PromptEvent | None: - """Wait for either a user prompt event or a background completion. - - Returns the prompt event if user input arrived first, or ``None`` - if a background task completed with unclaimed LLM notifications. - User input always takes priority over background completions. - """ - if self.enabled and self._has_pending_llm_notifications(): - # Pending notifications exist, but user input still wins. - try: - return idle_events.get_nowait() - except asyncio.QueueEmpty: - return None - - idle_task = asyncio.create_task(idle_events.get()) - if not self.enabled: - return await idle_task - - assert self._event is not None - bg_wait_task = asyncio.create_task(self._event.wait()) - - done, _ = await asyncio.wait( - [idle_task, bg_wait_task], - return_when=asyncio.FIRST_COMPLETED, - ) - for t in (idle_task, bg_wait_task): - if t not in done: - t.cancel() - with contextlib.suppress(asyncio.CancelledError): - await t - - if idle_task in done: - if bg_wait_task in done: - self._event.clear() - return idle_task.result() - - # Only bg fired - self._event.clear() - if self._has_pending_llm_notifications(): - return None - return _PromptEvent(kind="bg_noop") - - def _has_pending_llm_notifications(self) -> bool: - if self._notifications is None: - return False - return self._notifications.has_pending_for_sink("llm") - - -class Shell: - def __init__(self, soul: Soul, welcome_info: list[WelcomeInfoItem] | None = None): - self.soul = soul - self._welcome_info = list(welcome_info or []) - self._background_tasks: set[asyncio.Task[Any]] = set() - self._prompt_session: CustomPromptSession | None = None - self._running_input_handler: Callable[[UserInput], None] | None = None - self._running_interrupt_handler: Callable[[], None] | None = None - self._active_approval_sink: Any | None = None - self._pending_approval_requests = deque[ApprovalRequest]() - self._current_prompt_approval_request: ApprovalRequest | None = None - self._approval_modal: ApprovalPromptDelegate | None = None - self._exit_after_run = False - self._available_slash_commands: dict[str, SlashCommand[Any]] = { - **{cmd.name: cmd for cmd in soul.available_slash_commands}, - **{cmd.name: cmd for cmd in shell_slash_registry.list_commands()}, - } - """Shell-level slash commands + soul-level slash commands. Name to command mapping.""" - - @property - def available_slash_commands(self) -> dict[str, SlashCommand[Any]]: - """Get all available slash commands, including shell-level and soul-level commands.""" - return self._available_slash_commands - - @staticmethod - def _should_exit_input(user_input: UserInput) -> bool: - return user_input.command.strip() in {"exit", "quit", "/exit", "/quit"} - - @staticmethod - def _agent_slash_command_call(user_input: UserInput) -> SlashCommandCall | None: - if user_input.mode != PromptMode.AGENT: - return None - display_call = parse_slash_command_call(user_input.command) - if display_call is None: - return None - resolved_call = parse_slash_command_call(user_input.resolved_command) - if resolved_call is None or resolved_call.name != display_call.name: - return display_call - return resolved_call - - @staticmethod - def _should_echo_agent_input(user_input: UserInput) -> bool: - if user_input.mode != PromptMode.AGENT: - return False - if Shell._should_exit_input(user_input): - return False - return Shell._agent_slash_command_call(user_input) is None - - @staticmethod - def _echo_agent_input(user_input: UserInput) -> None: - console.print(render_user_echo_text(user_input.command)) - - def _bind_running_input( - self, - on_input: Callable[[UserInput], None], - on_interrupt: Callable[[], None], - ) -> None: - self._running_input_handler = on_input - self._running_interrupt_handler = on_interrupt - - def _unbind_running_input(self) -> None: - self._running_input_handler = None - self._running_interrupt_handler = None - - async def _route_prompt_events( - self, - prompt_session: CustomPromptSession, - idle_events: asyncio.Queue[_PromptEvent], - resume_prompt: asyncio.Event, - ) -> None: - while True: - # Keep exactly one active prompt read. Idle submissions pause the - # router until the shell decides whether the next prompt should - # wait for a blocking action or stay live during an agent run. - await resume_prompt.wait() - ensure_tty_sane() - try: - ensure_new_line() - user_input = await prompt_session.prompt_next() - except KeyboardInterrupt: - logger.debug("Prompt router got KeyboardInterrupt") - if ( - self._running_input_handler is not None - and prompt_session.running_prompt_accepts_submission() - ): - if self._running_interrupt_handler is not None: - self._running_interrupt_handler() - continue - resume_prompt.clear() - await idle_events.put(_PromptEvent(kind="interrupt")) - continue - except EOFError: - logger.debug("Prompt router got EOF") - if ( - self._running_input_handler is not None - and prompt_session.running_prompt_accepts_submission() - ): - self._exit_after_run = True - if self._running_interrupt_handler is not None: - self._running_interrupt_handler() - return - resume_prompt.clear() - await idle_events.put(_PromptEvent(kind="eof")) - return - except Exception: - logger.exception("Prompt router crashed") - resume_prompt.clear() - await idle_events.put(_PromptEvent(kind="error")) - return - - if prompt_session.last_submission_was_running: # noqa: SIM102 - if self._running_input_handler is not None: - if user_input: - self._running_input_handler(user_input) - continue - # Handler already unbound — fall through to idle path. - - resume_prompt.clear() - await idle_events.put(_PromptEvent(kind="input", user_input=user_input)) - - async def run(self, command: str | None = None) -> bool: - # Initialize theme from config - if isinstance(self.soul, KimiSoul): - from kimi_cli.ui.theme import set_active_theme - - set_active_theme(self.soul.runtime.config.theme) - - if command is not None: - # run single command and exit - logger.info("Running agent with command: {command}", command=command) - if isinstance(self.soul, KimiSoul): - self._start_background_task(self._watch_root_wire_hub()) - try: - return await self.run_soul_command(command) - finally: - self._cancel_background_tasks() - - # Start auto-update background task if not disabled - if get_env_bool("KIMI_CLI_NO_AUTO_UPDATE"): - logger.info("Auto-update disabled by KIMI_CLI_NO_AUTO_UPDATE environment variable") - else: - self._start_background_task(self._auto_update()) - - _print_welcome_info(self.soul.name or "Kimi Code CLI", self._welcome_info) - - if isinstance(self.soul, KimiSoul): - watcher = NotificationWatcher( - self.soul.runtime.notifications, - sink="shell", - before_poll=self.soul.runtime.background_tasks.reconcile, - on_notification=lambda notification: toast( - f"[{notification.event.type}] {notification.event.title}", - topic="notification", - duration=10.0, - ), - ) - self._start_background_task(watcher.run_forever()) - self._start_background_task(self._watch_root_wire_hub()) - await replay_recent_history( - self.soul.context.history, - wire_file=self.soul.wire_file, - ) - await self.soul.start_background_mcp_loading() - - async def _plan_mode_toggle() -> bool: - if isinstance(self.soul, KimiSoul): - return await self.soul.toggle_plan_mode_from_manual() - return False - - def _mcp_status_block(columns: int): - if not isinstance(self.soul, KimiSoul): - return None - snapshot = self.soul.status.mcp_status - if snapshot is None: - return None - return render_mcp_prompt(snapshot) - - def _mcp_status_loading() -> bool: - if not isinstance(self.soul, KimiSoul): - return False - snapshot = self.soul.status.mcp_status - return bool(snapshot and snapshot.loading) - - @dataclass - class _BgCountCache: - time: float = 0.0 - count: int = 0 - - _bg_cache = _BgCountCache() - - def _bg_task_count() -> int: - if not isinstance(self.soul, KimiSoul): - return 0 - now = time.monotonic() - if now - _bg_cache.time < 1.0: - return _bg_cache.count - views = list_task_views(self.soul.runtime.background_tasks, active_only=True) - _bg_cache.count = sum(1 for v in views if v.spec.kind == "bash") - _bg_cache.time = now - return _bg_cache.count - - with CustomPromptSession( - status_provider=lambda: self.soul.status, - status_block_provider=_mcp_status_block, - fast_refresh_provider=_mcp_status_loading, - background_task_count_provider=_bg_task_count, - model_capabilities=self.soul.model_capabilities or set(), - model_name=self.soul.model_name, - thinking=self.soul.thinking or False, - agent_mode_slash_commands=list(self._available_slash_commands.values()), - shell_mode_slash_commands=shell_mode_registry.list_commands(), - editor_command_provider=lambda: ( - self.soul.runtime.config.default_editor if isinstance(self.soul, KimiSoul) else "" - ), - plan_mode_toggle_callback=_plan_mode_toggle, - ) as prompt_session: - self._prompt_session = prompt_session - if isinstance(self.soul, KimiSoul): - kimi_soul = self.soul - snapshot = kimi_soul.status.mcp_status - if snapshot and snapshot.loading: - - async def _invalidate_after_mcp_loading() -> None: - try: - await kimi_soul.wait_for_background_mcp_loading() - except Exception: - logger.debug("MCP loading finished with error while refreshing prompt") - if self._prompt_session is prompt_session: - prompt_session.invalidate() - - self._start_background_task(_invalidate_after_mcp_loading()) - self._exit_after_run = False - idle_events: asyncio.Queue[_PromptEvent] = asyncio.Queue() - # resume_prompt controls whether the prompt router reads input. - # Set BEFORE an await = prompt stays live during the operation - # (agent runs that accept steer input); set AFTER = prompt is - # paused until the operation finishes. - resume_prompt = asyncio.Event() - resume_prompt.set() - prompt_task = asyncio.create_task( - self._route_prompt_events(prompt_session, idle_events, resume_prompt) - ) - bg_watcher = _BackgroundCompletionWatcher(self.soul) - - shell_ok = True - bg_auto_failures = 0 - try: - while True: - bg_watcher.clear() - if bg_auto_failures >= _MAX_BG_AUTO_TRIGGER_FAILURES: - result = await idle_events.get() - else: - result = await bg_watcher.wait_for_next(idle_events) - - if result is None: - logger.info("Background task completed while idle, triggering agent") - resume_prompt.set() - ok = await self.run_soul_command( - "" - "Background tasks completed while you" - " were idle." - "" - ) - console.print() - if not ok: - bg_auto_failures += 1 - logger.warning( - "Background auto-trigger failed ({n}/{max})", - n=bg_auto_failures, - max=_MAX_BG_AUTO_TRIGGER_FAILURES, - ) - else: - bg_auto_failures = 0 - if self._exit_after_run: - console.print("Bye!") - break - continue - - event = result - - if event.kind == "bg_noop": - continue - - if event.kind == "interrupt": - console.print("[grey50]Tip: press Ctrl-D or send 'exit' to quit[/grey50]") - resume_prompt.set() - continue - - if event.kind == "eof": - console.print("Bye!") - break - - if event.kind == "error": - shell_ok = False - break - - user_input = event.user_input - assert user_input is not None - bg_auto_failures = 0 - if not user_input: - logger.debug("Got empty input, skipping") - resume_prompt.set() - continue - logger.debug("Got user input: {user_input}", user_input=user_input) - - if self._should_echo_agent_input(user_input): - self._echo_agent_input(user_input) - - if self._should_exit_input(user_input): - logger.debug("Exiting by slash command") - console.print("Bye!") - break - - if user_input.mode == PromptMode.SHELL: - await self._run_shell_command(user_input.command) - resume_prompt.set() - continue - - if slash_cmd_call := self._agent_slash_command_call(user_input): - is_soul_slash = ( - slash_cmd_call.name in self._available_slash_commands - and shell_slash_registry.find_command(slash_cmd_call.name) is None - ) - if is_soul_slash: - resume_prompt.set() - await self.run_soul_command(slash_cmd_call.raw_input) - console.print() - if self._exit_after_run: - console.print("Bye!") - break - else: - await self._run_slash_command(slash_cmd_call) - resume_prompt.set() - continue - - resume_prompt.set() - await self.run_soul_command(user_input.content) - console.print() - if self._exit_after_run: - console.print("Bye!") - break - finally: - prompt_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await prompt_task - self._running_input_handler = None - self._running_interrupt_handler = None - if self._prompt_session is prompt_session and self._approval_modal is not None: - prompt_session.detach_modal(self._approval_modal) - self._approval_modal = None - self._prompt_session = None - self._cancel_background_tasks() - ensure_tty_sane() - - return shell_ok - - async def _run_shell_command(self, command: str) -> None: - """Run a shell command in foreground.""" - if not command.strip(): - return - - # Check if it's an allowed slash command in shell mode - if slash_cmd_call := parse_slash_command_call(command): - if shell_mode_registry.find_command(slash_cmd_call.name): - await self._run_slash_command(slash_cmd_call) - return - else: - console.print( - f'[yellow]"/{slash_cmd_call.name}" is not available in shell mode. ' - "Press Ctrl-X to switch to agent mode.[/yellow]" - ) - return - - # Check if user is trying to use 'cd' command - stripped_cmd = command.strip() - split_cmd: list[str] | None = None - try: - split_cmd = shlex.split(stripped_cmd) - except ValueError as exc: - logger.debug("Failed to parse shell command for cd check: {error}", error=exc) - if split_cmd and len(split_cmd) == 2 and split_cmd[0] == "cd": - console.print( - "[yellow]Warning: Directory changes are not preserved across command executions." - "[/yellow]" - ) - return - - logger.info("Running shell command: {cmd}", cmd=command) - - proc: asyncio.subprocess.Process | None = None - - def _handler(): - logger.debug("SIGINT received.") - if proc: - proc.terminate() - - loop = asyncio.get_running_loop() - remove_sigint = install_sigint_handler(loop, _handler) - try: - # TODO: For the sake of simplicity, we now use `create_subprocess_shell`. - # Later we should consider making this behave like a real shell. - with open_original_stderr() as stderr: - kwargs: dict[str, Any] = {} - if stderr is not None: - kwargs["stderr"] = stderr - proc = await asyncio.create_subprocess_shell(command, env=get_clean_env(), **kwargs) - await proc.wait() - except Exception as e: - logger.exception("Failed to run shell command:") - console.print(f"[red]Failed to run shell command: {e}[/red]") - finally: - remove_sigint() - - async def _run_slash_command(self, command_call: SlashCommandCall) -> None: - from kimi_cli.cli import Reload, SwitchToVis, SwitchToWeb - - if command_call.name not in self._available_slash_commands: - logger.info("Unknown slash command /{command}", command=command_call.name) - console.print( - f'[red]Unknown slash command "/{command_call.name}", ' - 'type "/" for all available commands[/red]' - ) - return - - command = shell_slash_registry.find_command(command_call.name) - if command is None: - # the input is a soul-level slash command call - await self.run_soul_command(command_call.raw_input) - return - - logger.debug( - "Running shell-level slash command: /{command} with args: {args}", - command=command_call.name, - args=command_call.args, - ) - - try: - ret = command.func(self, command_call.args) - if isinstance(ret, Awaitable): - await ret - except (Reload, SwitchToWeb, SwitchToVis): - # just propagate - raise - except (asyncio.CancelledError, KeyboardInterrupt): - # Handle Ctrl-C during slash command execution, return to shell prompt - logger.debug("Slash command interrupted by KeyboardInterrupt") - console.print("[red]Interrupted by user[/red]") - except Exception as e: - logger.exception("Unknown error:") - console.print(f"[red]Unknown error: {e}[/red]") - raise # re-raise unknown error - - async def run_soul_command(self, user_input: str | list[ContentPart]) -> bool: - """ - Run the soul and handle any known exceptions. - - Returns: - bool: Whether the run is successful. - """ - logger.info("Running soul with user input: {user_input}", user_input=user_input) - - cancel_event = asyncio.Event() - - def _handler(): - logger.debug("SIGINT received.") - cancel_event.set() - - loop = asyncio.get_running_loop() - remove_sigint = install_sigint_handler(loop, _handler) - - try: - snap = self.soul.status - runtime = self.soul.runtime if isinstance(self.soul, KimiSoul) else None - await run_soul( - self.soul, - user_input, - lambda wire: visualize( - wire.ui_side(merge=False), # shell UI maintain its own merge buffer - initial_status=StatusUpdate( - context_usage=snap.context_usage, - context_tokens=snap.context_tokens, - max_context_tokens=snap.max_context_tokens, - mcp_status=snap.mcp_status, - ), - cancel_event=cancel_event, - prompt_session=self._prompt_session, - steer=self.soul.steer if isinstance(self.soul, KimiSoul) else None, - bind_running_input=self._bind_running_input, - unbind_running_input=self._unbind_running_input, - on_view_ready=self._set_active_approval_sink, - on_view_closed=self._clear_active_approval_sink, - ), - cancel_event, - runtime.session.wire_file if runtime else None, - runtime, - ) - return True - except LLMNotSet: - logger.exception("LLM not set:") - console.print('[red]LLM not set, send "/login" to login[/red]') - except LLMNotSupported as e: - # actually unsupported input/mode should already be blocked by prompt session - logger.exception("LLM not supported:") - console.print(f"[red]{e}[/red]") - except ChatProviderError as e: - logger.exception("LLM provider error:") - if isinstance(e, APIStatusError) and e.status_code == 401: - console.print("[red]Authorization failed, please check your login status[/red]") - elif isinstance(e, APIStatusError) and e.status_code == 402: - console.print("[red]Membership expired, please renew your plan[/red]") - elif isinstance(e, APIStatusError) and e.status_code == 403: - console.print("[red]Quota exceeded, please upgrade your plan or retry later[/red]") - else: - console.print(f"[red]LLM provider error: {e}[/red]") - except MaxStepsReached as e: - logger.warning("Max steps reached: {n_steps}", n_steps=e.n_steps) - console.print(f"[yellow]{e}[/yellow]") - except RunCancelled: - logger.info("Cancelled by user") - console.print("[red]Interrupted by user[/red]") - except Exception as e: - logger.exception("Unexpected error:") - console.print(f"[red]Unexpected error: {e}[/red]") - raise # re-raise unknown error - finally: - self._maybe_present_pending_approvals() - remove_sigint() - return False - - async def _watch_root_wire_hub(self) -> None: - if not isinstance(self.soul, KimiSoul): - return - if self.soul.runtime.root_wire_hub is None: - return - queue = self.soul.runtime.root_wire_hub.subscribe() - try: - while True: - try: - msg = await queue.get() - except QueueShutDown: - return - try: - await self._handle_root_hub_message(msg) - except Exception: - logger.exception("Failed to handle root hub message:") - finally: - self.soul.runtime.root_wire_hub.unsubscribe(queue) - - async def _handle_root_hub_message(self, msg: WireMessage) -> None: - if not isinstance(self.soul, KimiSoul): - return - match msg: - case ApprovalRequest() as request: - request = self._enrich_approval_request_for_ui(request) - if self.soul.runtime.approval_runtime is None: - return - record = self.soul.runtime.approval_runtime.get_request(request.id) - if record is None or record.status != "pending": - return - if self._prompt_session is not None: - # Interactive mode: queue and present via modal - self._queue_approval_request(request) - self._maybe_present_pending_approvals() - self._prompt_session.invalidate() - elif self._active_approval_sink is not None: - # Non-interactive with live view: forward to sink - self._forward_approval_to_sink(request) - else: - # Queue for later - self._queue_approval_request(request) - case ApprovalResponse() as response: - # External resolution (e.g. from web UI) - if ( - self._approval_modal is not None - and self._approval_modal.request.id == response.request_id - ): - if not self._approval_modal.request.resolved: - self._approval_modal.request.resolve(response.response) - self._clear_current_prompt_approval_request(response.request_id) - self._activate_prompt_approval_modal() - self._remove_pending_approval_request(response.request_id) - self._maybe_present_pending_approvals() - if self._prompt_session is not None: - self._prompt_session.invalidate() - case _: - return - - def _enrich_approval_request_for_ui(self, request: ApprovalRequest) -> ApprovalRequest: - if not isinstance(self.soul, KimiSoul): - return request - if request.agent_id is None: - return request - if self.soul.runtime.subagent_store is None: - return request - record = self.soul.runtime.subagent_store.get_instance(request.agent_id) - if record is None: - return request - return request.model_copy(update={"source_description": record.description}) - - def _set_active_approval_sink(self, sink: Any) -> None: - self._active_approval_sink = sink - # Flush pending approvals to the newly active sink - while self._pending_approval_requests: - request = self._pending_approval_requests.popleft() - - if not isinstance(self.soul, KimiSoul) or self.soul.runtime.approval_runtime is None: - break - record = self.soul.runtime.approval_runtime.get_request(request.id) - if record is None or record.status != "pending": - continue - self._forward_approval_to_sink(request) - - def _clear_active_approval_sink(self) -> None: - self._active_approval_sink = None - # Re-queue any approval requests that were forwarded to the sink - # but not yet resolved. Without this, those requests would be - # silently lost when the live view closes between turns. - if not isinstance(self.soul, KimiSoul) or self.soul.runtime.approval_runtime is None: - return - for record in self.soul.runtime.approval_runtime.list_pending(): - self._queue_approval_request( - self._enrich_approval_request_for_ui( - ApprovalRequest( - id=record.id, - tool_call_id=record.tool_call_id, - sender=record.sender, - action=record.action, - description=record.description, - display=record.display, - source_kind=record.source.kind, - source_id=record.source.id, - agent_id=record.source.agent_id, - subagent_type=record.source.subagent_type, - ) - ) - ) - - def _forward_approval_to_sink(self, request: ApprovalRequest) -> None: - """Forward an approval request to the active live view sink and bridge the response.""" - if self._active_approval_sink is None: - self._queue_approval_request(request) - return - self._active_approval_sink.enqueue_external_message(request) - - async def _bridge() -> None: - try: - response = await request.wait() - if ( - isinstance(self.soul, KimiSoul) - and self.soul.runtime.approval_runtime is not None - ): - self.soul.runtime.approval_runtime.resolve( - request.id, response, feedback=request.feedback - ) - finally: - if self._prompt_session is not None: - self._prompt_session.invalidate() - - self._start_background_task(_bridge()) - - def _queue_approval_request(self, request: ApprovalRequest) -> None: - if self._approval_modal is not None and self._approval_modal.request.id == request.id: - return - if ( - self._current_prompt_approval_request is not None - and self._current_prompt_approval_request.id == request.id - ): - return - if any(r.id == request.id for r in self._pending_approval_requests): - return - self._pending_approval_requests.append(request) - - def _remove_pending_approval_request(self, request_id: str) -> None: - self._clear_current_prompt_approval_request(request_id) - self._pending_approval_requests = deque( - r for r in self._pending_approval_requests if r.id != request_id - ) - - def _clear_current_prompt_approval_request(self, request_id: str) -> None: - if ( - self._current_prompt_approval_request is not None - and self._current_prompt_approval_request.id == request_id - ): - self._current_prompt_approval_request = None - - def _maybe_present_pending_approvals(self) -> None: - if self._prompt_session is not None: - self._activate_prompt_approval_modal() - return - if self._active_approval_sink is not None: - while self._pending_approval_requests: - request = self._pending_approval_requests.popleft() - - if not isinstance(self.soul, KimiSoul): - break - if self.soul.runtime.approval_runtime is None: - break - record = self.soul.runtime.approval_runtime.get_request(request.id) - if record is None or record.status != "pending": - continue - self._forward_approval_to_sink(request) - - def _activate_prompt_approval_modal(self) -> None: - if self._prompt_session is None: - return - current_request = self._current_prompt_approval_request - if current_request is None: - current_request = self._pop_next_pending_approval_request() - self._current_prompt_approval_request = current_request - if current_request is None: - if self._approval_modal is not None: - self._prompt_session.detach_modal(self._approval_modal) - self._approval_modal = None - return - if self._approval_modal is None: - self._approval_modal = ApprovalPromptDelegate( - current_request, - on_response=self._handle_prompt_approval_response, - buffer_text_provider=( - lambda: self._prompt_session._session.default_buffer.text # pyright: ignore[reportPrivateUsage] - if self._prompt_session is not None - else "" - ), - text_expander=self._prompt_session._get_placeholder_manager().serialize_for_history, # pyright: ignore[reportPrivateUsage] - ) - self._prompt_session.attach_modal(self._approval_modal) - else: - if self._approval_modal.request.id != current_request.id: - self._approval_modal.set_request(current_request) - self._prompt_session.invalidate() - - def _handle_prompt_approval_response( - self, - request: ApprovalRequest, - response: ApprovalResponse.Kind, - feedback: str = "", - ) -> None: - if not isinstance(self.soul, KimiSoul): - return - if self.soul.runtime.approval_runtime is None: - return - self.soul.runtime.approval_runtime.resolve(request.id, response, feedback=feedback) - self._clear_current_prompt_approval_request(request.id) - self._activate_prompt_approval_modal() - - def _pop_next_pending_approval_request(self) -> ApprovalRequest | None: - if not isinstance(self.soul, KimiSoul) or self.soul.runtime.approval_runtime is None: - return None - while self._pending_approval_requests: - request = self._pending_approval_requests.popleft() - - record = self.soul.runtime.approval_runtime.get_request(request.id) - if record is None or record.status != "pending": - continue - return request - return None - - async def _auto_update(self) -> None: - result = await do_update(print=False, check_only=True) - if result == UpdateResult.UPDATE_AVAILABLE: - while True: - toast( - f"new version found, run `{_update_mod.UPGRADE_COMMAND}` to upgrade", - topic="update", - duration=30.0, - ) - await asyncio.sleep(60.0) - elif result == UpdateResult.UPDATED: - toast("auto updated, restart to use the new version", topic="update", duration=5.0) - - def _start_background_task(self, coro: Coroutine[Any, Any, Any]) -> asyncio.Task[Any]: - task = asyncio.create_task(coro) - self._background_tasks.add(task) - - def _cleanup(t: asyncio.Task[Any]) -> None: - self._background_tasks.discard(t) - try: - t.result() - except asyncio.CancelledError: - pass - except Exception: - logger.exception("Background task failed:") - - task.add_done_callback(_cleanup) - return task - - def _cancel_background_tasks(self) -> None: - """Cancel all background tasks (notification watcher, auto-update, etc.).""" - for task in self._background_tasks: - task.cancel() - self._background_tasks.clear() - - -_KIMI_BLUE = "dodger_blue1" -_LOGO = f"""\ -[{_KIMI_BLUE}]\ -▐█▛█▛█▌ -▐█████▌\ -[{_KIMI_BLUE}]\ -""" - - -@dataclass(slots=True) -class WelcomeInfoItem: - class Level(Enum): - INFO = "grey50" - WARN = "yellow" - ERROR = "red" - - name: str - value: str - level: Level = Level.INFO - - -def _print_welcome_info(name: str, info_items: list[WelcomeInfoItem]) -> None: - head = Text.from_markup("Welcome to Kimi Code CLI!") - help_text = Text.from_markup("[grey50]Send /help for help information.[/grey50]") - - # Use Table for precise width control - logo = Text.from_markup(_LOGO) - table = Table(show_header=False, show_edge=False, box=None, padding=(0, 1), expand=False) - table.add_column(justify="left") - table.add_column(justify="left") - table.add_row(logo, Group(head, help_text)) - - rows: list[RenderableType] = [table] - - if info_items: - rows.append(Text("")) # empty line - for item in info_items: - rows.append(Text(f"{item.name}: {item.value}", style=item.level.value)) - - if LATEST_VERSION_FILE.exists(): - from kimi_cli.constant import VERSION as current_version - - latest_version = LATEST_VERSION_FILE.read_text(encoding="utf-8").strip() - if semver_tuple(latest_version) > semver_tuple(current_version): - rows.append( - Text.from_markup( - f"\n[yellow]New version available: {latest_version}. " - f"Please run `{_update_mod.UPGRADE_COMMAND}` to upgrade.[/yellow]" - ) - ) - - console.print( - Panel( - Group(*rows), - border_style=_KIMI_BLUE, - expand=False, - padding=(1, 2), - ) - ) diff --git a/src/kimi_cli/ui/shell/approval_panel.py b/src/kimi_cli/ui/shell/approval_panel.py deleted file mode 100644 index 2bfcfb828..000000000 --- a/src/kimi_cli/ui/shell/approval_panel.py +++ /dev/null @@ -1,481 +0,0 @@ -from __future__ import annotations - -from collections.abc import Callable -from typing import NamedTuple - -from prompt_toolkit.application.run_in_terminal import run_in_terminal -from prompt_toolkit.buffer import Buffer -from prompt_toolkit.document import Document -from prompt_toolkit.formatted_text import ANSI -from prompt_toolkit.key_binding import KeyPressEvent -from rich.console import Group, RenderableType -from rich.markup import escape -from rich.padding import Padding -from rich.panel import Panel -from rich.text import Text - -from kimi_cli.ui.shell.console import console, render_to_ansi -from kimi_cli.ui.shell.keyboard import KeyEvent -from kimi_cli.utils.rich.diff_render import ( - collect_diff_hunks, - render_diff_panel, - render_diff_preview, - render_diff_summary_panel, - render_diff_summary_preview, -) -from kimi_cli.utils.rich.syntax import KimiSyntax -from kimi_cli.wire.types import ( - ApprovalRequest, - ApprovalResponse, - BriefDisplayBlock, - DiffDisplayBlock, - ShellDisplayBlock, -) - -# Truncation limits for approval request display -MAX_PREVIEW_LINES = 4 - - -class ApprovalContentBlock(NamedTuple): - """A pre-rendered content block for approval request with line count.""" - - text: str - lines: int - style: str = "" - lexer: str = "" - - -class ApprovalRequestPanel: - FEEDBACK_OPTION_INDEX = 3 - - def __init__(self, request: ApprovalRequest): - self.request = request - self.options: list[tuple[str, ApprovalResponse.Kind]] = [ - ("Approve once", "approve"), - ("Approve for this session", "approve_for_session"), - ("Reject", "reject"), - ("Reject, tell the model what to do instead", "reject"), - ] - self.selected_index = 0 - - # Pre-render content for the preview. - # All blocks (diff and non-diff) are rendered in original display order - # into a single list of renderables to preserve interleaving. - self._preview_renderables: list[RenderableType] = [] - self._has_diff = False - self._non_diff_truncated = False - # Legacy content blocks for non-diff blocks (used by render_full fallback) - self._content_blocks: list[ApprovalContentBlock] = [] - - # Line budget for non-diff blocks - non_diff_budget = MAX_PREVIEW_LINES - - # Handle description (only if no display blocks) - if request.description and not request.display: - text = request.description.rstrip("\n") - line_count = text.count("\n") + 1 - self._content_blocks.append(ApprovalContentBlock(text=text, lines=line_count)) - preview_text = text - if line_count > non_diff_budget: - preview_text = "\n".join(text.split("\n")[:non_diff_budget]) - self._non_diff_truncated = True - self._preview_renderables.append(Text(preview_text)) - non_diff_budget -= min(line_count, non_diff_budget) - - # Handle display blocks — group consecutive same-file DiffDisplayBlocks - display = request.display - idx = 0 - while idx < len(display): - block = display[idx] - if isinstance(block, DiffDisplayBlock): - path = block.path - diff_blocks: list[DiffDisplayBlock] = [] - while idx < len(display): - b = display[idx] - if not isinstance(b, DiffDisplayBlock) or b.path != path: - break - diff_blocks.append(b) - idx += 1 - if any(b.is_summary for b in diff_blocks): - self._has_diff = True - self._preview_renderables.extend(render_diff_summary_preview(path, diff_blocks)) - else: - hunks, added, removed = collect_diff_hunks(diff_blocks) - if hunks: - self._has_diff = True - renderables, _remaining = render_diff_preview( - path, - hunks, - added, - removed, - ) - self._preview_renderables.extend(renderables) - elif isinstance(block, ShellDisplayBlock): - text = block.command.rstrip("\n") - line_count = text.count("\n") + 1 - self._content_blocks.append( - ApprovalContentBlock(text=text, lines=line_count, lexer=block.language) - ) - if non_diff_budget > 0: - truncated = text - if line_count > non_diff_budget: - truncated = "\n".join(text.split("\n")[:non_diff_budget]) - self._non_diff_truncated = True - self._preview_renderables.append(KimiSyntax(truncated, block.language)) - non_diff_budget -= min(line_count, non_diff_budget) - else: - self._non_diff_truncated = True - idx += 1 - elif isinstance(block, BriefDisplayBlock) and block.text: - text = block.text.rstrip("\n") - line_count = text.count("\n") + 1 - self._content_blocks.append( - ApprovalContentBlock(text=text, lines=line_count, style="grey50") - ) - if non_diff_budget > 0: - truncated = text - if line_count > non_diff_budget: - truncated = "\n".join(text.split("\n")[:non_diff_budget]) - self._non_diff_truncated = True - self._preview_renderables.append(Text(truncated, style="grey50")) - non_diff_budget -= min(line_count, non_diff_budget) - else: - self._non_diff_truncated = True - idx += 1 - else: - idx += 1 - - # P1: diff pager always has context lines not shown in preview - # P2: non-diff blocks may have been truncated - self.has_expandable_content = self._has_diff or self._non_diff_truncated - - def render(self, *, feedback_text: str | None = None) -> RenderableType: - """Render the approval menu as a bordered panel.""" - content_lines: list[RenderableType] = [ - Text.from_markup( - "[yellow]" - f"{escape(self.request.sender)} is requesting approval to " - f"{escape(self.request.action)}:[/yellow]" - ) - ] - content_lines.extend(self._render_source_metadata_lines()) - content_lines.append(Text("")) - - # Render preview (diff + non-diff in original display order) - content_lines.extend(self._preview_renderables) - - if self.has_expandable_content and self._non_diff_truncated: - content_lines.append(Text("... (truncated, ctrl-e to expand)", style="dim italic")) - - lines: list[RenderableType] = [] - if content_lines: - lines.append(Padding(Group(*content_lines), (0, 0, 0, 1))) - - # Whether inline feedback input is active - show_inline_feedback = feedback_text is not None and self.is_feedback_selected - - # Add menu options with number key labels - if lines: - lines.append(Text("")) - for i, (option_text, _) in enumerate(self.options): - num = i + 1 - is_feedback_option = i == self.FEEDBACK_OPTION_INDEX - if i == self.selected_index: - if is_feedback_option and show_inline_feedback: - input_display = escape(feedback_text) if feedback_text else "" - lines.append( - Text.from_markup( - f"[cyan]\u2192 \\[{num}] Reject: {input_display}\u2588[/cyan]" - ) - ) - else: - lines.append(Text(f"\u2192 [{num}] {option_text}", style="cyan")) - else: - lines.append(Text(f" [{num}] {option_text}", style="grey50")) - - # Keyboard hints - lines.append(Text("")) - if show_inline_feedback: - hint = " Type your feedback, then press Enter to submit." - else: - hint = " \u25b2/\u25bc select 1/2/3/4 choose \u21b5 confirm" - if self.has_expandable_content: - hint += " ctrl-e expand" - lines.append(Text(hint, style="dim")) - - return Panel( - Group(*lines), - border_style="bold yellow", - title="[bold yellow]\u26a0 ACTION REQUIRED[/bold yellow]", - title_align="left", - padding=(0, 1), - ) - - def _render_block( - self, block: ApprovalContentBlock, max_lines: int | None = None - ) -> RenderableType: - """Render a content block, optionally truncated.""" - text = block.text - if max_lines is not None and block.lines > max_lines: - text = "\n".join(text.split("\n")[:max_lines]) - - if block.lexer: - return KimiSyntax(text, block.lexer) - return Text(text, style=block.style) - - def render_full(self) -> list[RenderableType]: - """Render full content for pager (no truncation).""" - return [self._render_block(block) for block in self._content_blocks] - - def _render_source_metadata_lines(self) -> list[RenderableType]: - lines: list[RenderableType] = [] - if self.request.subagent_type is not None or self.request.agent_id is not None: - if self.request.subagent_type is not None and self.request.agent_id is not None: - subagent_text = f"{self.request.subagent_type} ({self.request.agent_id})" - elif self.request.subagent_type is not None: - subagent_text = self.request.subagent_type - else: - assert self.request.agent_id is not None - subagent_text = self.request.agent_id - lines.append(Text(f"Subagent: {subagent_text}", style="grey50")) - if self.request.source_description: - lines.append(Text(f"Task: {self.request.source_description}", style="grey50")) - return lines - - def move_up(self): - """Move selection up.""" - self.selected_index = (self.selected_index - 1) % len(self.options) - - def move_down(self): - """Move selection down.""" - self.selected_index = (self.selected_index + 1) % len(self.options) - - @property - def is_feedback_selected(self) -> bool: - return self.selected_index == self.FEEDBACK_OPTION_INDEX - - def get_selected_response(self) -> ApprovalResponse.Kind: - """Get the approval response based on selected option.""" - return self.options[self.selected_index][1] - - -def show_approval_in_pager(panel: ApprovalRequestPanel) -> None: - """Show the full approval request content in a pager.""" - with console.screen(), console.pager(styles=True): - console.print( - Text.from_markup( - "[yellow]⚠ " - f"{escape(panel.request.sender)} is requesting approval to " - f"{escape(panel.request.action)}:[/yellow]" - ) - ) - console.print() - - # Render display blocks with the unified diff renderer. - display = panel.request.display - rendered_any = False - idx = 0 - while idx < len(display): - block = display[idx] - if isinstance(block, DiffDisplayBlock): - path = block.path - diff_blocks: list[DiffDisplayBlock] = [] - while idx < len(display): - b = display[idx] - if not isinstance(b, DiffDisplayBlock) or b.path != path: - break - diff_blocks.append(b) - idx += 1 - if any(b.is_summary for b in diff_blocks): - console.print(render_diff_summary_panel(path, diff_blocks)) - rendered_any = True - else: - hunks, added, removed = collect_diff_hunks(diff_blocks) - if hunks: - console.print(render_diff_panel(path, hunks, added, removed)) - rendered_any = True - elif isinstance(block, ShellDisplayBlock): - console.print(KimiSyntax(block.command.rstrip("\n"), block.language)) - rendered_any = True - idx += 1 - elif isinstance(block, BriefDisplayBlock) and block.text: - console.print(Text(block.text.rstrip("\n"), style="grey50")) - rendered_any = True - idx += 1 - else: - idx += 1 - - # Fallback: if nothing was rendered (e.g. type mismatch after deserialization), - # use legacy pre-rendered content blocks. - if not rendered_any: - for renderable in panel.render_full(): - console.print(renderable) - - -class ApprovalPromptDelegate: - modal_priority = 20 - _KEY_MAP: dict[str, KeyEvent] = { - "up": KeyEvent.UP, - "down": KeyEvent.DOWN, - "enter": KeyEvent.ENTER, - "1": KeyEvent.NUM_1, - "2": KeyEvent.NUM_2, - "3": KeyEvent.NUM_3, - "4": KeyEvent.NUM_4, - "escape": KeyEvent.ESCAPE, - "c-c": KeyEvent.ESCAPE, - "c-d": KeyEvent.ESCAPE, - } - - def __init__( - self, - request: ApprovalRequest, - *, - on_response: Callable[[ApprovalRequest, ApprovalResponse.Kind, str], None], - buffer_text_provider: Callable[[], str] | None = None, - text_expander: Callable[[str], str] | None = None, - ) -> None: - self._panel = ApprovalRequestPanel(request) - self._on_response = on_response - self._buffer_text_provider = buffer_text_provider - self._text_expander = text_expander - self._feedback_draft: str = "" - - @property - def request(self) -> ApprovalRequest: - return self._panel.request - - def set_request(self, request: ApprovalRequest) -> None: - self._panel = ApprovalRequestPanel(request) - self._feedback_draft = "" - - def _is_inline_feedback_active(self) -> bool: - return self._panel.is_feedback_selected and self._buffer_text_provider is not None - - def render_running_prompt_body(self, columns: int) -> ANSI: - feedback_text: str | None = None - if self._is_inline_feedback_active(): - feedback_text = self._buffer_text_provider() if self._buffer_text_provider else "" - body = render_to_ansi( - self._panel.render(feedback_text=feedback_text), - columns=columns, - ).rstrip("\n") - return ANSI(body) - - def running_prompt_placeholder(self) -> str | None: - return None - - def running_prompt_allows_text_input(self) -> bool: - return self._is_inline_feedback_active() - - def running_prompt_hides_input_buffer(self) -> bool: - return True - - def running_prompt_accepts_submission(self) -> bool: - return False - - def should_handle_running_prompt_key(self, key: str) -> bool: - if key == "c-e": - return self._panel.has_expandable_content - if self._is_inline_feedback_active(): - return key in {"enter", "escape", "c-c", "c-d", "up", "down"} - return key in { - "up", - "down", - "enter", - "1", - "2", - "3", - "4", - "escape", - "c-c", - "c-d", - "c-e", - } - - def handle_running_prompt_key(self, key: str, event: KeyPressEvent) -> None: - if key == "c-e": - event.app.create_background_task(self._show_panel_in_pager()) - return - - # Inline feedback mode: user is typing in the "Reject + feedback" field - if self._is_inline_feedback_active(): - mapped = self._KEY_MAP.get(key) - if key == "enter" or mapped == KeyEvent.ENTER: - text = event.current_buffer.text.strip() - if text: - if self._text_expander is not None: - text = self._text_expander(text) - self._clear_buffer(event.current_buffer) - self._feedback_draft = "" - self._panel.request.resolve("reject") - self._on_response(self._panel.request, "reject", text) - # Empty enter: do nothing (keep editing) - return - if mapped == KeyEvent.ESCAPE: - self._clear_buffer(event.current_buffer) - self._feedback_draft = "" - self._panel.request.resolve("reject") - self._on_response(self._panel.request, "reject", "") - return - if mapped in {KeyEvent.UP, KeyEvent.DOWN}: - self._feedback_draft = event.current_buffer.text - self._clear_buffer(event.current_buffer) - if mapped == KeyEvent.UP: - self._panel.move_up() - else: - self._panel.move_down() - return - return - - mapped = self._KEY_MAP.get(key) - if mapped is None: - return - match mapped: - case KeyEvent.UP: - self._panel.move_up() - self._maybe_restore_feedback_draft(event.current_buffer) - case KeyEvent.DOWN: - self._panel.move_down() - self._maybe_restore_feedback_draft(event.current_buffer) - case KeyEvent.ENTER: - self._submit_current_request(event.current_buffer) - case KeyEvent.ESCAPE: - self._panel.request.resolve("reject") - self._on_response(self._panel.request, "reject", "") - case KeyEvent.NUM_1 | KeyEvent.NUM_2 | KeyEvent.NUM_3 | KeyEvent.NUM_4: - num_map = { - KeyEvent.NUM_1: 0, - KeyEvent.NUM_2: 1, - KeyEvent.NUM_3: 2, - KeyEvent.NUM_4: 3, - } - idx = num_map[mapped] - if idx < len(self._panel.options): - self._panel.selected_index = idx - if not self._is_inline_feedback_active(): - self._submit_current_request(event.current_buffer) - case _: - pass - - async def _show_panel_in_pager(self) -> None: - await run_in_terminal(lambda: show_approval_in_pager(self._panel)) - - def _maybe_restore_feedback_draft(self, buffer: Buffer) -> None: - if self._is_inline_feedback_active() and self._feedback_draft: - buffer.set_document( - Document(text=self._feedback_draft, cursor_position=len(self._feedback_draft)), - bypass_readonly=True, - ) - - @staticmethod - def _clear_buffer(buffer: Buffer) -> None: - if buffer.text: - buffer.set_document(Document(text="", cursor_position=0), bypass_readonly=True) - - def _submit_current_request(self, buffer: Buffer) -> None: - self._clear_buffer(buffer) - self._feedback_draft = "" - response = self._panel.get_selected_response() - self._panel.request.resolve(response) - self._on_response(self._panel.request, response, "") diff --git a/src/kimi_cli/ui/shell/commands/add_dir.ts b/src/kimi_cli/ui/shell/commands/add_dir.ts new file mode 100644 index 000000000..5d804f537 --- /dev/null +++ b/src/kimi_cli/ui/shell/commands/add_dir.ts @@ -0,0 +1,91 @@ +/** + * /add-dir slash command handler. + * Adds a directory to the workspace scope. + * Corresponds to Python soul/slash.py add_dir command. + */ + +import { resolve } from "node:path"; +import { logger } from "../../../utils/logging.ts"; +import type { Session } from "../../../session.ts"; +import { saveSessionState } from "../../../session.ts"; + +export async function handleAddDir( + session: Session, + workDir: string, + args: string, +): Promise { + const arg = args.trim(); + + // No args: list currently added directories + if (!arg) { + const dirs = session.state.additional_dirs; + if (!dirs.length) { + logger.info("No additional directories. Usage: /add-dir "); + } else { + logger.info("Additional directories:"); + for (const d of dirs) { + logger.info(` - ${d}`); + } + } + return null; + } + + // Resolve the path + const dirPath = resolve(arg.replace(/^~/, process.env.HOME ?? "~")); + + // Check existence + const dirFile = Bun.file(dirPath); + try { + const stat = await Bun.$`test -d ${dirPath}`.quiet(); + if (stat.exitCode !== 0) { + logger.info(`Not a directory: ${dirPath}`); + return null; + } + } catch { + logger.info(`Directory does not exist: ${dirPath}`); + return null; + } + + // Check if already added + if (session.state.additional_dirs.includes(dirPath)) { + logger.info(`Directory already in workspace: ${dirPath}`); + return null; + } + + // Check if within work dir + if (dirPath.startsWith(workDir + "/") || dirPath === workDir) { + logger.info(`Directory is already within the working directory: ${dirPath}`); + return null; + } + + // Check if within an already-added directory + for (const existing of session.state.additional_dirs) { + if (dirPath.startsWith(existing + "/") || dirPath === existing) { + logger.info(`Directory is already within added directory ${existing}: ${dirPath}`); + return null; + } + } + + // Validate readability + let lsOutput = ""; + try { + lsOutput = await Bun.$`ls -la ${dirPath}`.quiet().text(); + } catch (e) { + logger.info(`Cannot read directory: ${dirPath}`); + return null; + } + + // Add the directory + session.state.additional_dirs.push(dirPath); + await saveSessionState(session.state, session.dir); + + logger.info(`Added directory to workspace: ${dirPath}`); + + // Return info string for injecting into context + return ( + `The user has added an additional directory to the workspace: \`${dirPath}\`\n\n` + + `Directory listing:\n\`\`\`\n${lsOutput.trim()}\n\`\`\`\n\n` + + "You can now read, write, search, and glob files in this directory " + + "as if it were part of the working directory." + ); +} diff --git a/src/kimi_cli/ui/shell/commands/editor.ts b/src/kimi_cli/ui/shell/commands/editor.ts new file mode 100644 index 000000000..bb8091deb --- /dev/null +++ b/src/kimi_cli/ui/shell/commands/editor.ts @@ -0,0 +1,47 @@ +import { loadConfig, saveConfig, type Config, type ConfigMeta } from "../../../config.ts"; +import { logger } from "../../../utils/logging.ts"; +import { which } from "bun"; + +export async function handleEditor(config: Config, configMeta: ConfigMeta, args: string): Promise { + const currentEditor = config.default_editor; + + if (!args.trim()) { + // Show current and available options + logger.info(`Current editor: ${currentEditor || "auto-detect"}`); + logger.info(""); + logger.info("Available editors:"); + logger.info(" /editor code --wait (VS Code)"); + logger.info(" /editor vim"); + logger.info(" /editor nano"); + logger.info(" /editor (any editor command)"); + logger.info(' /editor "" (auto-detect from $VISUAL/$EDITOR)'); + return; + } + + const newEditor = args.trim(); + + // Validate binary exists + if (newEditor) { + const binary = newEditor.split(/\s+/)[0]!; + const found = which(binary); + if (!found) { + logger.info(`Warning: '${binary}' not found in PATH. Saving anyway.`); + } + } + + if (newEditor === currentEditor) { + logger.info(`Editor is already set to: ${newEditor || "auto-detect"}`); + return; + } + + // Save to config + try { + const freshConfig = (await loadConfig(configMeta.sourceFile ?? undefined)).config; + freshConfig.default_editor = newEditor; + await saveConfig(freshConfig, configMeta.sourceFile ?? undefined); + config.default_editor = newEditor; + logger.info(`Editor set to: ${newEditor || "auto-detect"}`); + } catch (err) { + logger.info(`Failed to save config: ${err instanceof Error ? err.message : err}`); + } +} diff --git a/src/kimi_cli/ui/shell/commands/export_import.ts b/src/kimi_cli/ui/shell/commands/export_import.ts new file mode 100644 index 000000000..87fe40654 --- /dev/null +++ b/src/kimi_cli/ui/shell/commands/export_import.ts @@ -0,0 +1,112 @@ +import type { Context } from "../../../soul/context.ts"; +import type { Session } from "../../../session.ts"; +import type { ContentPart } from "../../../types.ts"; +import { join } from "node:path"; +import { homedir } from "node:os"; +import { logger } from "../../../utils/logging.ts"; + +export async function handleExport(context: Context, session: Session, args: string): Promise { + const history = context.history; + if (!history.length) { + logger.info("Nothing to export - context is empty."); + return; + } + + // Determine output path + const outputDir = args.trim() || session.workDir; + const filename = `kimi-export-${session.id.slice(0, 8)}.md`; + const outputPath = join(outputDir, filename); + + // Build markdown + const lines: string[] = []; + lines.push(`# Kimi CLI Session Export`); + lines.push(`Session: ${session.id}`); + lines.push(`Exported: ${new Date().toISOString()}`); + lines.push(`Messages: ${history.length}`); + lines.push(`Tokens: ${context.tokenCountWithPending}`); + lines.push(""); + + for (let i = 0; i < history.length; i++) { + const msg = history[i]!; + lines.push(`## ${msg.role.toUpperCase()} (#${i + 1})`); + lines.push(""); + if (typeof msg.content === "string") { + lines.push(msg.content); + } else if (Array.isArray(msg.content)) { + for (const part of msg.content as ContentPart[]) { + if (part.type === "text") { + lines.push(part.text); + } else if (part.type === "tool_use") { + lines.push(`**Tool Call: ${part.name}**`); + lines.push("```json"); + lines.push(JSON.stringify(part.input, null, 2)); + lines.push("```"); + } else if (part.type === "tool_result") { + lines.push(`**Tool Result** (${part.isError ? "error" : "success"})`); + lines.push("```"); + lines.push(part.content); + lines.push("```"); + } + } + } + lines.push(""); + } + + try { + await Bun.write(outputPath, lines.join("\n")); + // Shorten home dir for display + const display = outputPath.replace(homedir(), "~"); + logger.info(`Exported ${history.length} messages to ${display}`); + logger.info("Note: The exported file may contain sensitive information."); + } catch (err) { + logger.info(`Failed to export: ${err instanceof Error ? err.message : err}`); + } +} + +export async function handleImport(context: Context, session: Session, args: string): Promise { + const target = args.trim(); + if (!target) { + logger.info("Usage: /import "); + return; + } + + // Check if it's a file path + const file = Bun.file(target); + if (await file.exists()) { + try { + const content = await file.text(); + // Append as a user message with import marker + await context.appendMessage({ + role: "user", + content: `[Imported from ${target}]\n\n${content}`, + }); + logger.info(`Imported ${content.length} chars from ${target}`); + } catch (err) { + logger.info(`Failed to import: ${err instanceof Error ? err.message : err}`); + } + return; + } + + // Try as session ID + const { Session: SessionClass } = await import("../../../session.ts"); + const otherSession = await SessionClass.find(session.workDir, target); + if (!otherSession) { + logger.info(`File not found and no session with ID: ${target}`); + return; + } + + // Read other session's context + const contextFile = Bun.file(otherSession.contextFile); + if (!(await contextFile.exists())) { + logger.info("Target session has no context."); + return; + } + + const text = await contextFile.text(); + const messageCount = text.split("\n").filter(l => l.trim()).length; + await context.appendMessage({ + role: "user", + content: `[Imported context from session ${target}]\n\n${text}`, + }); + logger.info(`Imported context from session ${target} (~${messageCount} entries)`); +} diff --git a/src/kimi_cli/ui/shell/commands/feedback.ts b/src/kimi_cli/ui/shell/commands/feedback.ts new file mode 100644 index 000000000..2df5416ff --- /dev/null +++ b/src/kimi_cli/ui/shell/commands/feedback.ts @@ -0,0 +1,67 @@ +import { loadTokens } from "../../../auth/oauth.ts"; +import type { Config } from "../../../config.ts"; +import { logger } from "../../../utils/logging.ts"; +import { platform, release } from "node:os"; + +const ISSUE_URL = "https://github.com/MoonshotAI/kimi-cli/issues"; + +export async function handleFeedback( + config: Config, + args: string, + sessionId: string, + modelKey: string | undefined, +): Promise { + const content = args.trim(); + if (!content) { + logger.info("Usage: /feedback "); + logger.info(`Or submit at: ${ISSUE_URL}`); + return; + } + + // Try to find a provider with OAuth for posting feedback + let apiKey: string | null = null; + let baseUrl: string | null = null; + + for (const [, provider] of Object.entries(config.providers)) { + if (provider.oauth) { + const token = await loadTokens(provider.oauth); + if (token) { + apiKey = token.access_token; + baseUrl = provider.base_url.replace(/\/+$/, ""); + break; + } + } + } + + if (!apiKey || !baseUrl) { + logger.info(`No authenticated platform found. Please submit feedback at: ${ISSUE_URL}`); + return; + } + + const payload = { + session_id: sessionId, + content, + version: "2.0.0", + os: `${platform()} ${release()}`, + model: modelKey || null, + }; + + try { + const res = await fetch(`${baseUrl}/feedback`, { + method: "POST", + headers: { + "Authorization": `Bearer ${apiKey}`, + "Content-Type": "application/json", + }, + body: JSON.stringify(payload), + }); + if (res.ok) { + logger.info(`Feedback submitted, thank you! Session ID: ${sessionId}`); + } else { + logger.info(`Failed to submit feedback (HTTP ${res.status}). Try: ${ISSUE_URL}`); + } + } catch (err) { + logger.info(`Failed to submit feedback: ${err instanceof Error ? err.message : err}`); + logger.info(`Please submit at: ${ISSUE_URL}`); + } +} diff --git a/src/kimi_cli/ui/shell/commands/info.ts b/src/kimi_cli/ui/shell/commands/info.ts new file mode 100644 index 000000000..f989d8c20 --- /dev/null +++ b/src/kimi_cli/ui/shell/commands/info.ts @@ -0,0 +1,81 @@ +/** + * Info slash-command handlers: /hooks, /mcp, /debug, /changelog + * Corresponds to Python ui/shell/commands/info.py + */ + +import type { HookEngine } from "../../../hooks/engine.ts"; +import type { Config } from "../../../config.ts"; +import type { Context } from "../../../soul/context.ts"; +import type { ContentPart } from "../../../types.ts"; +import { CHANGELOG } from "../../../utils/changelog.ts"; +import { logger } from "../../../utils/logging.ts"; + +export function handleHooks(hookEngine: HookEngine): void { + const summary = hookEngine.summary; + if (!Object.keys(summary).length) { + logger.info("No hooks configured. Add [[hooks]] sections to config.toml."); + return; + } + logger.info("\nConfigured Hooks:"); + for (const [event, count] of Object.entries(summary)) { + logger.info(` ${event}: ${count} hook(s)`); + } + logger.info(""); +} + +export function handleMcp(config: Config): void { + logger.info("MCP Configuration:"); + logger.info(` Client timeout: ${config.mcp.client.tool_call_timeout_ms}ms`); + logger.info("\nNote: MCP server management available via 'kimi mcp' CLI commands."); +} + +export function handleDebug(context: Context): void { + const history = context.history; + if (!history.length) { + logger.info("Context is empty - no messages yet."); + return; + } + + logger.info(`\n=== Context Debug ===`); + logger.info(`Total messages: ${history.length}`); + logger.info(`Token count: ${context.tokenCountWithPending}`); + logger.info(`---`); + + for (let i = 0; i < history.length; i++) { + const msg = history[i]!; + const role = msg.role.toUpperCase(); + + if (typeof msg.content === "string") { + const preview = + msg.content.length > 200 + ? msg.content.slice(0, 200) + "..." + : msg.content; + logger.info(`#${i + 1} [${role}] ${preview}`); + } else if (Array.isArray(msg.content)) { + const parts = msg.content as ContentPart[]; + const summary = parts + .map((p: any) => { + if (p.type === "text") + return p.text.length > 100 ? p.text.slice(0, 100) + "..." : p.text; + if (p.type === "tool_use") return `[tool_use: ${p.name}]`; + if (p.type === "tool_result") return `[tool_result]`; + if (p.type === "image") return `[image]`; + return `[${p.type}]`; + }) + .join(" | "); + logger.info(`#${i + 1} [${role}] ${summary}`); + } + } + logger.info(`=== End Debug ===\n`); +} + +export function handleChangelog(): void { + logger.info("\n Release Notes:\n"); + for (const [version, entry] of Object.entries(CHANGELOG)) { + logger.info(` ${version}: ${entry.description}`); + for (const item of entry.entries) { + logger.info(` \u2022 ${item}`); + } + logger.info(""); + } +} diff --git a/src/kimi_cli/ui/shell/commands/init.ts b/src/kimi_cli/ui/shell/commands/init.ts new file mode 100644 index 000000000..1e7516934 --- /dev/null +++ b/src/kimi_cli/ui/shell/commands/init.ts @@ -0,0 +1,64 @@ +/** + * /init slash command handler. + * Analyzes the codebase and generates an AGENTS.md file. + * Corresponds to Python soul/slash.py init command. + */ + +import { logger } from "../../../utils/logging.ts"; + +/** + * Handle /init command — trigger codebase analysis and AGENTS.md generation. + * In the Python version this creates a temporary context, runs a prompt through the LLM, + * then reloads the generated AGENTS.md. For now we provide a simplified version. + */ +export async function handleInit(workDir: string): Promise { + const agentsMdPath = `${workDir}/AGENTS.md`; + + // Check if AGENTS.md already exists + const existing = Bun.file(agentsMdPath); + if (await existing.exists()) { + logger.info(`AGENTS.md already exists at ${agentsMdPath}`); + logger.info("To regenerate, delete it first and run /init again."); + return null; + } + + logger.info("Analyzing codebase to generate AGENTS.md..."); + logger.info("Note: Full /init requires an LLM call. Generating a basic template."); + + // Generate a basic template + let lsOutput = ""; + try { + lsOutput = await Bun.$`ls -la ${workDir}`.quiet().text(); + } catch { + lsOutput = "(unable to list directory)"; + } + + const template = [ + "# AGENTS.md", + "", + "## Project Overview", + "", + "", + "", + "## Directory Structure", + "", + "```", + lsOutput.trim(), + "```", + "", + "## Conventions", + "", + "", + "", + ].join("\n"); + + try { + await Bun.write(agentsMdPath, template); + logger.info(`Generated AGENTS.md at ${agentsMdPath}`); + logger.info("Edit it to describe your project for better AI assistance."); + return template; + } catch (err) { + logger.info(`Failed to generate AGENTS.md: ${err instanceof Error ? err.message : err}`); + return null; + } +} diff --git a/src/kimi_cli/ui/shell/commands/login.ts b/src/kimi_cli/ui/shell/commands/login.ts new file mode 100644 index 000000000..46e48091c --- /dev/null +++ b/src/kimi_cli/ui/shell/commands/login.ts @@ -0,0 +1,275 @@ +/** + * /login and /logout slash command handlers. + * Corresponds to Python's ui/shell/oauth.py + setup.py. + * + * /login uses a multi-level panel flow: + * 1. Select platform (choice) + * 2a. Kimi Code → OAuth device-code flow (events shown in message list) + * 2b. Others → Enter API key (input) → Select model (choice) → Thinking (choice) + */ + +import { + loginKimiCode, + logoutKimiCode, +} from "../../../auth/oauth.ts"; +import { + PLATFORMS, + KIMI_CODE_PLATFORM_ID, + getPlatformById, + managedProviderKey, + managedModelKey, + listModels, + deriveModelCapabilities, + type Platform, + type ModelInfo, +} from "../../../auth/platforms.ts"; +import { saveConfig, type Config } from "../../../config.ts"; +import type { CommandPanelConfig } from "../../../types.ts"; + +type Notify = (title: string, body: string) => void; + +// ── Kimi Code OAuth login (via async generator) ────────── + +async function runLoginKimiCode(config: Config, notify: Notify): Promise { + for await (const event of loginKimiCode(config)) { + switch (event.type) { + case "info": + case "verification_url": + case "waiting": + notify("Login", event.message); + break; + case "success": + notify("Login", `✓ ${event.message}`); + break; + case "error": + notify("Login", `✗ ${event.message}`); + break; + } + } +} + +// ── Non-OAuth platform setup (multi-step wizard) ───────── + +/** + * Step 2: After API key entered, verify it and show model selection. + */ +function buildModelSelectPanel( + platform: Platform, + apiKey: string, + config: Config, + notify: Notify, +): CommandPanelConfig { + // We return a content panel first as "loading", then transition + return { + type: "content", + title: `${platform.name} — Verifying API key...`, + content: "Fetching available models, please wait...", + }; +} + +/** + * Actually fetch models and build the model choice panel. + */ +async function fetchAndBuildModelPanel( + platform: Platform, + apiKey: string, + config: Config, + notify: Notify, +): Promise { + let models: ModelInfo[]; + try { + models = await listModels(platform, apiKey); + } catch (err: any) { + if (err?.message?.includes("401")) { + notify("Login", `✗ API key verification failed. Please check your key.`); + } else { + notify("Login", `✗ Failed to fetch models: ${err?.message ?? err}`); + } + return; + } + + if (!models.length) { + notify("Login", "✗ No models available for this platform."); + return; + } + + return { + type: "choice", + title: `${platform.name} — Select Model`, + items: models.map((m) => ({ + label: m.id, + value: m.id, + description: `ctx: ${m.contextLength}`, + })), + onSelect: (modelId: string): CommandPanelConfig | void => { + const model = models.find((m) => m.id === modelId); + if (!model) return; + const caps = deriveModelCapabilities(model); + + // If model supports optional thinking, ask + if (caps.has("thinking") && !caps.has("always_thinking")) { + return buildThinkingPanel(platform, apiKey, model, models, config, notify); + } + // Otherwise, auto-decide + const thinking = caps.has("always_thinking") || caps.has("thinking"); + applyNonOAuthConfig(platform, apiKey, model, models, thinking, config, notify); + }, + }; +} + +/** + * Step 3: Select thinking mode. + */ +function buildThinkingPanel( + platform: Platform, + apiKey: string, + selectedModel: ModelInfo, + models: ModelInfo[], + config: Config, + notify: Notify, +): CommandPanelConfig { + return { + type: "choice", + title: `${platform.name} — Thinking Mode`, + items: [ + { label: "On", value: "on", description: "Enable extended thinking" }, + { label: "Off", value: "off", description: "Standard mode" }, + ], + onSelect: (value: string) => { + const thinking = value === "on"; + applyNonOAuthConfig(platform, apiKey, selectedModel, models, thinking, config, notify); + }, + }; +} + +/** + * Apply config for non-OAuth platforms (API key based). + * Corresponds to Python's _apply_setup_result(). + */ +function applyNonOAuthConfig( + platform: Platform, + apiKey: string, + selectedModel: ModelInfo, + models: ModelInfo[], + thinking: boolean, + config: Config, + notify: Notify, +): void { + const providerKey = managedProviderKey(platform.id); + + config.providers[providerKey] = { + type: "kimi", + base_url: platform.baseUrl, + api_key: apiKey, + }; + + // Remove old models for this provider + for (const [key, model] of Object.entries(config.models)) { + if (model.provider === providerKey) delete config.models[key]; + } + + // Add all available models + for (const m of models) { + const caps = deriveModelCapabilities(m); + config.models[managedModelKey(platform.id, m.id)] = { + provider: providerKey, + model: m.id, + max_context_size: m.contextLength, + capabilities: caps.size > 0 ? ([...caps] as any) : undefined, + }; + } + + config.default_model = managedModelKey(platform.id, selectedModel.id); + config.default_thinking = thinking; + + if (platform.searchUrl) { + config.services = config.services ?? {}; + (config.services as any).moonshot_search = { + base_url: platform.searchUrl, + api_key: apiKey, + }; + } + if (platform.fetchUrl) { + config.services = config.services ?? {}; + (config.services as any).moonshot_fetch = { + base_url: platform.fetchUrl, + api_key: apiKey, + }; + } + + saveConfig(config).then(() => { + const thinkLabel = thinking ? "on" : "off"; + notify("Login", [ + `✓ Setup complete!`, + ` Platform: ${platform.name}`, + ` Model: ${selectedModel.id}`, + ` Thinking: ${thinkLabel}`, + ].join("\n")); + }).catch((err) => { + notify("Login", `✗ Failed to save config: ${err}`); + }); +} + +// ── Public API ─────────────────────────────────────────── + +/** + * Handle /login — dispatches to the correct flow based on platform. + * When called without a panel (e.g. `/login` typed directly), defaults to Kimi Code. + */ +export async function handleLogin(config: Config, notify: Notify): Promise { + await runLoginKimiCode(config, notify); +} + +/** + * Handle /logout — delete stored OAuth tokens and clean up config. + */ +export async function handleLogout(config: Config, notify: Notify): Promise { + for await (const event of logoutKimiCode(config)) { + switch (event.type) { + case "success": + notify("Logout", `✓ ${event.message}`); + break; + case "error": + notify("Logout", `✗ ${event.message}`); + break; + default: + notify("Logout", event.message); + } + } +} + +/** + * Create panel config for /login — platform selection → multi-step wizard. + * Corresponds to Python's select_platform() → setup_platform(). + */ +export function createLoginPanel(config: Config, notify: Notify): CommandPanelConfig { + return { + type: "choice", + title: "Login — Select Platform", + items: PLATFORMS.map((p) => ({ + label: p.name, + value: p.id, + })), + onSelect: (platformId: string): CommandPanelConfig | Promise | void => { + const platform = getPlatformById(platformId); + if (!platform) return; + + if (platform.id === KIMI_CODE_PLATFORM_ID) { + // Kimi Code uses OAuth — events go to message list + runLoginKimiCode(config, notify); + return; // Close panel, events stream into chat + } + + // Other platforms: multi-step wizard + return { + type: "input", + title: `${platform.name} — Enter API Key`, + placeholder: "Paste your API key here...", + password: true, + onSubmit: (apiKey: string): Promise => { + return fetchAndBuildModelPanel(platform, apiKey, config, notify); + }, + }; + }, + }; +} diff --git a/src/kimi_cli/ui/shell/commands/misc.ts b/src/kimi_cli/ui/shell/commands/misc.ts new file mode 100644 index 000000000..ac491f37e --- /dev/null +++ b/src/kimi_cli/ui/shell/commands/misc.ts @@ -0,0 +1,20 @@ +import { logger } from "../../../utils/logging.ts"; + +export function handleWeb(sessionId: string): void { + logger.info("Web UI is not yet available in the TypeScript version."); + logger.info("Use 'kimi web' CLI command to start the web server."); +} + +export function handleVis(sessionId: string): void { + logger.info("Visualizer is not yet available in the TypeScript version."); + logger.info("Use 'kimi vis' CLI command to start the visualizer."); +} + +export function handleReload(): void { + logger.info("Configuration reloaded. If changes don't take effect, please restart the CLI."); +} + +export function handleTask(): void { + logger.info("Background task browser is not yet available in the TypeScript version."); + logger.info("Background tasks are managed automatically during agent execution."); +} diff --git a/src/kimi_cli/ui/shell/commands/model.ts b/src/kimi_cli/ui/shell/commands/model.ts new file mode 100644 index 000000000..c21ec222c --- /dev/null +++ b/src/kimi_cli/ui/shell/commands/model.ts @@ -0,0 +1,31 @@ +import type { Config, ConfigMeta } from "../../../config.ts"; +import { logger } from "../../../utils/logging.ts"; + +export async function handleModel(config: Config, configMeta: ConfigMeta): Promise { + if (!Object.keys(config.models).length) { + logger.info("No models configured. Run /login to set up."); + return; + } + + if (!configMeta.isFromDefaultLocation) { + logger.info("Model switching requires the default config file."); + return; + } + + const currentModel = config.default_model; + logger.info("Available models:"); + + const modelNames = Object.keys(config.models).sort(); + for (let i = 0; i < modelNames.length; i++) { + const name = modelNames[i]!; + const modelCfg = config.models[name]!; + const providerName = modelCfg.provider; + const current = name === currentModel ? " (current)" : ""; + const capabilities = modelCfg.capabilities?.join(", ") || "none"; + logger.info(` [${i + 1}] ${modelCfg.model} (${providerName})${current} [${capabilities}]`); + } + + logger.info(""); + logger.info("To switch models, use: kimi --model "); + logger.info("Or edit ~/.kimi/config.toml and set default_model"); +} diff --git a/src/kimi_cli/ui/shell/commands/session.ts b/src/kimi_cli/ui/shell/commands/session.ts new file mode 100644 index 000000000..853d005b1 --- /dev/null +++ b/src/kimi_cli/ui/shell/commands/session.ts @@ -0,0 +1,53 @@ +/** + * Session-related slash commands: /new, /sessions, /title + */ + +import { Session, loadSessionState, saveSessionState } from "../../../session.ts"; +import { logger } from "../../../utils/logging.ts"; + +export async function handleNew(session: Session): Promise { + const workDir = session.workDir; + if (await session.isEmpty()) { + await session.delete(); + } + const newSession = await Session.create(workDir); + logger.info(`New session created: ${newSession.id}. Please restart to switch.`); +} + +export async function handleSessions(session: Session): Promise { + const sessions = await Session.list(session.workDir); + if (sessions.length === 0) { + logger.info("No sessions found."); + return; + } + for (const s of sessions) { + const current = s.id === session.id ? " (current)" : ""; + const timeAgo = formatRelativeTime(s.updatedAt); + logger.info(` ${s.title} (${s.id}) - ${timeAgo}${current}`); + } +} + +export async function handleTitle(session: Session, args: string): Promise { + if (!args.trim()) { + logger.info(`Session title: ${session.title}`); + return; + } + const newTitle = args.trim().slice(0, 200); + const freshState = await loadSessionState(session.dir); + freshState.custom_title = newTitle; + freshState.title_generated = true; + await saveSessionState(freshState, session.dir); + session.state.custom_title = newTitle; + session.title = newTitle; + logger.info(`Session title set to: ${newTitle}`); +} + +function formatRelativeTime(timestamp: number): string { + if (!timestamp) return "unknown"; + const now = Date.now() / 1000; + const diff = now - timestamp; + if (diff < 60) return "just now"; + if (diff < 3600) return `${Math.floor(diff / 60)} minutes ago`; + if (diff < 86400) return `${Math.floor(diff / 3600)} hours ago`; + return `${Math.floor(diff / 86400)} days ago`; +} diff --git a/src/kimi_cli/ui/shell/commands/usage.ts b/src/kimi_cli/ui/shell/commands/usage.ts new file mode 100644 index 000000000..04fdc85b3 --- /dev/null +++ b/src/kimi_cli/ui/shell/commands/usage.ts @@ -0,0 +1,69 @@ +import { loadTokens } from "../../../auth/oauth.ts"; +import type { Config } from "../../../config.ts"; +import { logger } from "../../../utils/logging.ts"; + +export async function handleUsage(config: Config, modelKey: string | undefined): Promise { + if (!modelKey || !config.models[modelKey]) { + logger.info("No model selected. Run /login first."); + return; + } + const modelCfg = config.models[modelKey]!; + const providerCfg = config.providers[modelCfg.provider]; + if (!providerCfg) { + logger.info("Provider not found."); + return; + } + + // Resolve API key (try OAuth token first) + let apiKey = providerCfg.api_key; + if (providerCfg.oauth) { + const token = await loadTokens(providerCfg.oauth); + if (token) apiKey = token.access_token; + } + + const baseUrl = providerCfg.base_url.replace(/\/+$/, ""); + const usageUrl = `${baseUrl}/usages`; + + try { + const res = await fetch(usageUrl, { + headers: { "Authorization": `Bearer ${apiKey}` }, + }); + if (!res.ok) { + if (res.status === 401) logger.info("Authorization failed. Please check your API key."); + else if (res.status === 404) logger.info("Usage endpoint not available."); + else logger.info(`Failed to fetch usage (HTTP ${res.status}).`); + return; + } + const data = await res.json() as Record; + + // Parse and display usage + const usage = data.usage; + if (usage) { + const limit = usage.limit || 0; + const used = usage.used ?? (limit - (usage.remaining || 0)); + const pct = limit > 0 ? ((limit - used) / limit * 100).toFixed(0) : "?"; + const label = usage.name || usage.title || "Weekly limit"; + logger.info(`\n API Usage:`); + logger.info(` ${label}: ${used}/${limit} used (${pct}% remaining)`); + } + + // Parse limits array + const limits = data.limits; + if (Array.isArray(limits) && limits.length > 0) { + for (const item of limits) { + const detail = item.detail || item; + const limit = detail.limit || 0; + const used = detail.used ?? (limit - (detail.remaining || 0)); + const pct = limit > 0 ? ((limit - used) / limit * 100).toFixed(0) : "?"; + const name = item.name || item.title || detail.name || "Limit"; + logger.info(` ${name}: ${used}/${limit} used (${pct}% remaining)`); + } + } + + if (!usage && (!limits || !limits.length)) { + logger.info("No usage data available."); + } + } catch (err) { + logger.info(`Failed to fetch usage: ${err instanceof Error ? err.message : err}`); + } +} diff --git a/src/kimi_cli/ui/shell/console.py b/src/kimi_cli/ui/shell/console.py deleted file mode 100644 index 9576260b8..000000000 --- a/src/kimi_cli/ui/shell/console.py +++ /dev/null @@ -1,104 +0,0 @@ -from __future__ import annotations - -import os -import pydoc -import re - -from rich.console import Console, PagerContext, RenderableType -from rich.pager import Pager -from rich.theme import Theme - -NEUTRAL_MARKDOWN_THEME = Theme( - { - "markdown.paragraph": "none", - "markdown.block_quote": "none", - "markdown.hr": "none", - "markdown.item": "none", - "markdown.item.bullet": "none", - "markdown.item.number": "none", - "markdown.link": "none", - "markdown.link_url": "none", - "markdown.h1": "none", - "markdown.h1.border": "none", - "markdown.h2": "none", - "markdown.h3": "none", - "markdown.h4": "none", - "markdown.h5": "none", - "markdown.h6": "none", - "markdown.em": "none", - "markdown.strong": "none", - "markdown.s": "none", - "status.spinner": "none", - }, - inherit=True, -) - -_NEUTRAL_MARKDOWN_THEME = NEUTRAL_MARKDOWN_THEME - - -class _KimiPager(Pager): - """Pager that ignores MANPAGER to avoid garbled output. - - ``pydoc.getpager()`` reads ``MANPAGER`` before ``PAGER``. When the user - sets ``MANPAGER`` to a man-specific pipeline (e.g. - ``sh -c 'col -bx | bat -l man -p'``), that pipeline mangles the ANSI - rich-text we emit. This pager strips ``MANPAGER`` from the subprocess - environment so only ``PAGER`` (or the default ``less``) is used. - """ - - def show(self, content: str) -> None: - saved = os.environ.pop("MANPAGER", None) - try: - pydoc.pager(content) - finally: - if saved is not None: - os.environ["MANPAGER"] = saved - - -class _KimiConsole(Console): - """Console subclass that defaults to :class:`_KimiPager`.""" - - def pager( - self, - pager: Pager | None = None, - styles: bool = False, - links: bool = False, - ) -> PagerContext: - if pager is None: - pager = _KimiPager() - return super().pager(pager=pager, styles=styles, links=links) - - -console = _KimiConsole(highlight=False, theme=NEUTRAL_MARKDOWN_THEME) - - -# Matches OSC 8 hyperlink open/close markers emitted by Rich's Style(link=...). -# Format: ESC ] 8 ; ; ST where ST is ESC \ or BEL (\x07). -# prompt_toolkit's ANSI parser does not understand OSC 8 and renders the raw -# escape bytes as visible garbage (e.g. "8;id=391551;https://…"). We wrap each -# marker in \001…\002 so prompt_toolkit treats it as a ZeroWidthEscape and -# passes it through to the terminal via write_raw, preserving clickable links. -_OSC8_RE = re.compile(r"\x1b\]8;[^\x07\x1b]*(?:\x1b\\|\x07)") - - -def _wrap_osc8_as_zero_width(m: re.Match[str]) -> str: - """Wrap an OSC 8 marker in \\001…\\002 for prompt_toolkit ZeroWidthEscape.""" - return f"\x01{m.group(0)}\x02" - - -def render_to_ansi(renderable: RenderableType, *, columns: int) -> str: - """Render a Rich renderable to an ANSI string for prompt_toolkit integration.""" - from io import StringIO - - width = max(20, columns) - buf = StringIO() - temp = Console( - file=buf, - force_terminal=True, - width=width, - theme=NEUTRAL_MARKDOWN_THEME, - highlight=False, - ) - temp.print(renderable, end="") - result = buf.getvalue() - return _OSC8_RE.sub(_wrap_osc8_as_zero_width, result) diff --git a/src/kimi_cli/ui/shell/console.ts b/src/kimi_cli/ui/shell/console.ts new file mode 100644 index 000000000..8483fcf9d --- /dev/null +++ b/src/kimi_cli/ui/shell/console.ts @@ -0,0 +1,24 @@ +/** + * Console utilities — corresponds to Python's ui/shell/console.py + * Terminal size detection and helpers. + */ + +/** + * Get current terminal dimensions. + */ +export function getTerminalSize(): { columns: number; rows: number } { + return { + columns: process.stdout.columns || 80, + rows: process.stdout.rows || 24, + }; +} + +/** + * Listen for terminal resize events. + */ +export function onResize(callback: () => void): () => void { + process.stdout.on("resize", callback); + return () => { + process.stdout.off("resize", callback); + }; +} diff --git a/src/kimi_cli/ui/shell/debug.py b/src/kimi_cli/ui/shell/debug.py deleted file mode 100644 index 47e7a9f30..000000000 --- a/src/kimi_cli/ui/shell/debug.py +++ /dev/null @@ -1,190 +0,0 @@ -from __future__ import annotations - -import json -from typing import TYPE_CHECKING - -from kosong.message import Message -from rich.console import Group, RenderableType -from rich.panel import Panel -from rich.rule import Rule -from rich.syntax import Syntax -from rich.text import Text - -from kimi_cli.soul.kimisoul import KimiSoul -from kimi_cli.ui.shell.console import console -from kimi_cli.ui.shell.slash import registry -from kimi_cli.wire.types import ( - AudioURLPart, - ContentPart, - ImageURLPart, - TextPart, - ThinkPart, - ToolCall, - VideoURLPart, -) - -if TYPE_CHECKING: - from kimi_cli.ui.shell import Shell - - -def _format_content_part(part: ContentPart) -> Text | Panel | Group: - """Format a single content part.""" - match part: - case TextPart(text=text): - # Check if it looks like a system tag - if text.strip().startswith("") and text.strip().endswith(""): - return Panel( - text.strip()[8:-9].strip(), - title="[dim]system[/dim]", - border_style="dim yellow", - padding=(0, 1), - ) - return Text(text, style="white") - - case ThinkPart(think=think): - return Panel( - think, - title="[dim]thinking[/dim]", - border_style="dim cyan", - padding=(0, 1), - ) - - case ImageURLPart(image_url=img): - url_display = img.url[:80] + "..." if len(img.url) > 80 else img.url - return Text(f"[Image] {url_display}", style="blue") - - case AudioURLPart(audio_url=audio): - url_display = audio.url[:80] + "..." if len(audio.url) > 80 else audio.url - id_text = f" (id: {audio.id})" if audio.id else "" - return Text(f"[Audio{id_text}] {url_display}", style="blue") - - case VideoURLPart(video_url=video): - url_display = video.url[:80] + "..." if len(video.url) > 80 else video.url - return Text(f"[Video] {url_display}", style="blue") - - case _: - return Text(f"[Unknown content type: {type(part).__name__}]", style="red") - - -def _format_tool_call(tool_call: ToolCall) -> Panel: - """Format a tool call.""" - args = tool_call.function.arguments or "{}" - try: - args_formatted = json.dumps(json.loads(args, strict=False), indent=2) - args_syntax = Syntax(args_formatted, "json", theme="monokai", padding=(0, 1)) - except json.JSONDecodeError: - args_syntax = Text(args, style="red") - - content = Group( - Text(f"Function: {tool_call.function.name}", style="bold cyan"), - Text(f"Call ID: {tool_call.id}", style="dim"), - Text("Arguments:", style="bold"), - args_syntax, - ) - - return Panel( - content, - title="[bold yellow]Tool Call[/bold yellow]", - border_style="yellow", - padding=(0, 1), - ) - - -def _format_message(msg: Message, index: int) -> Panel: - """Format a single message.""" - # Role styling - role_colors = { - "system": "magenta", - "developer": "magenta", - "user": "green", - "assistant": "blue", - "tool": "yellow", - } - role_color = role_colors.get(msg.role, "white") - role_text = f"[bold {role_color}]{msg.role.upper()}[/bold {role_color}]" - - # Add name if present - if msg.name: - role_text += f" [dim]({msg.name})[/dim]" - - # Add tool call ID for tool messages - if msg.tool_call_id: - role_text += f" [dim]→ {msg.tool_call_id}[/dim]" - - # Format content - content_items: list[RenderableType] = [] - - for part in msg.content: - formatted = _format_content_part(part) - content_items.append(formatted) - - # Add tool calls if present - if msg.tool_calls: - if content_items: - content_items.append(Text()) # Empty line - for tool_call in msg.tool_calls: - content_items.append(_format_tool_call(tool_call)) - - # Combine all content - if not content_items: - content_items.append(Text("[empty message]", style="dim italic")) - - group = Group(*content_items) - - # Create panel - title = f"#{index + 1} {role_text}" - if msg.partial: - title += " [dim italic](partial)[/dim italic]" - - return Panel( - group, - title=title, - border_style=role_color, - padding=(0, 1), - ) - - -@registry.command -def debug(app: Shell, args: str): - """Debug the context""" - assert isinstance(app.soul, KimiSoul) - - context = app.soul.context - history = context.history - - if not history: - console.print( - Panel( - "Context is empty - no messages yet", - border_style="yellow", - padding=(1, 2), - ) - ) - return - - # Build the debug output - output_items = [ - Panel( - Group( - Text(f"Total messages: {len(history)}", style="bold"), - Text(f"Token count: {context.token_count:,}", style="bold"), - Text(f"Checkpoints: {context.n_checkpoints}", style="bold"), - Text(f"Trajectory: {context.file_backend}", style="dim"), - ), - title="[bold]Context Info[/bold]", - border_style="cyan", - padding=(0, 1), - ), - Rule(style="dim"), - ] - - # Add all messages - for idx, msg in enumerate(history): - output_items.append(_format_message(msg, idx)) - - # Display using rich pager - display_group = Group(*output_items) - - # Use pager to display - with console.pager(styles=True): - console.print(display_group) diff --git a/src/kimi_cli/ui/shell/echo.py b/src/kimi_cli/ui/shell/echo.py deleted file mode 100644 index 5e6bbdd7c..000000000 --- a/src/kimi_cli/ui/shell/echo.py +++ /dev/null @@ -1,17 +0,0 @@ -from __future__ import annotations - -from kosong.message import Message -from rich.text import Text - -from kimi_cli.ui.shell.prompt import PROMPT_SYMBOL -from kimi_cli.utils.message import message_stringify - - -def render_user_echo(message: Message) -> Text: - """Render a user message as literal shell transcript text.""" - return Text(f"{PROMPT_SYMBOL} {message_stringify(message)}") - - -def render_user_echo_text(text: str) -> Text: - """Render the local prompt text exactly as the user saw it in the buffer.""" - return Text(f"{PROMPT_SYMBOL} {text}") diff --git a/src/kimi_cli/ui/shell/events.ts b/src/kimi_cli/ui/shell/events.ts new file mode 100644 index 000000000..d9a5aa6a4 --- /dev/null +++ b/src/kimi_cli/ui/shell/events.ts @@ -0,0 +1,71 @@ +/** + * Wire event types for UI consumption. + * Simplified interface that UI components use to render messages. + */ + +import type { + StatusUpdate, + ApprovalRequest, + ToolResult, + DisplayBlock, +} from "../../wire/types"; + +// ── UI Message Types ────────────────────────────────────── + +export type UIMessageRole = "user" | "assistant" | "system" | "tool"; + +export interface TextSegment { + type: "text"; + text: string; +} + +export interface ThinkSegment { + type: "think"; + text: string; +} + +export interface ToolCallSegment { + type: "tool_call"; + id: string; + name: string; + arguments: string; + result?: ToolResult; + collapsed: boolean; +} + +export type MessageSegment = TextSegment | ThinkSegment | ToolCallSegment; + +export interface UIMessage { + id: string; + role: UIMessageRole; + segments: MessageSegment[]; + timestamp: number; +} + +// ── Wire Events (simplified for UI) ─────────────────────── + +export type WireUIEvent = + | { type: "turn_begin"; userInput: string } + | { type: "turn_end" } + | { type: "step_begin"; n: number } + | { type: "step_interrupted" } + | { type: "text_delta"; text: string } + | { type: "think_delta"; text: string } + | { type: "tool_call"; id: string; name: string; arguments: string } + | { type: "tool_call_delta"; id: string; arguments: string } + | { type: "tool_result"; toolCallId: string; result: ToolResult } + | { type: "approval_request"; request: ApprovalRequest } + | { type: "approval_response"; requestId: string; response: string } + | { type: "question_request"; request: import("../../wire/types.ts").QuestionRequest } + | { type: "question_response"; requestId: string; answers: Record } + | { type: "status_update"; status: StatusUpdate } + | { type: "compaction_begin" } + | { type: "compaction_end" } + | { type: "notification"; title: string; body: string; severity?: string } + | { type: "plan_display"; content: string; filePath: string } + | { type: "hook_triggered"; event: string; target: string; hookCount: number } + | { type: "hook_resolved"; event: string; target: string; action: string; reason: string; durationMs: number } + | { type: "mcp_loading_begin" } + | { type: "mcp_loading_end" } + | { type: "subagent_event"; parentToolCallId: string | null; agentId: string | null; subagentType: string | null; event: Record } + | { type: "error"; message: string; retryable?: boolean; retryAfter?: number }; diff --git a/src/kimi_cli/ui/shell/export_import.py b/src/kimi_cli/ui/shell/export_import.py deleted file mode 100644 index a1e28c4e7..000000000 --- a/src/kimi_cli/ui/shell/export_import.py +++ /dev/null @@ -1,111 +0,0 @@ -from __future__ import annotations - -from pathlib import Path -from typing import TYPE_CHECKING - -from kaos.path import KaosPath - -from kimi_cli.ui.shell.console import console -from kimi_cli.ui.shell.slash import ensure_kimi_soul, registry, shell_mode_registry -from kimi_cli.utils.export import is_sensitive_file -from kimi_cli.utils.path import sanitize_cli_path, shorten_home -from kimi_cli.wire.types import TurnBegin, TurnEnd - -if TYPE_CHECKING: - from kimi_cli.ui.shell import Shell - - -# --------------------------------------------------------------------------- -# /export command -# --------------------------------------------------------------------------- - - -@registry.command -@shell_mode_registry.command -async def export(app: Shell, args: str): - """Export current session context to a markdown file""" - from kimi_cli.utils.export import perform_export - - soul = ensure_kimi_soul(app) - if soul is None: - return - - session = soul.runtime.session - result = await perform_export( - history=list(soul.context.history), - session_id=session.id, - work_dir=str(session.work_dir), - token_count=soul.context.token_count, - args=args, - default_dir=Path(str(session.work_dir)), - ) - if isinstance(result, str): - console.print(f"[yellow]{result}[/yellow]") - return - - output, count = result - display = shorten_home(KaosPath(str(output))) - console.print(f"[green]Exported {count} messages to {display}[/green]") - console.print( - "[yellow]Note: The exported file may contain sensitive information. " - "Please be cautious when sharing it externally.[/yellow]" - ) - - -# --------------------------------------------------------------------------- -# /import command -# --------------------------------------------------------------------------- - - -@registry.command(name="import") -@shell_mode_registry.command(name="import") -async def import_context(app: Shell, args: str): - """Import context from a file or session ID""" - from kimi_cli.utils.export import perform_import - - soul = ensure_kimi_soul(app) - if soul is None: - return - - target = sanitize_cli_path(args) - if not target: - console.print("[yellow]Usage: /import [/yellow]") - return - - session = soul.runtime.session - raw_max_context_size = ( - soul.runtime.llm.max_context_size if soul.runtime.llm is not None else None - ) - max_context_size = ( - raw_max_context_size - if isinstance(raw_max_context_size, int) and raw_max_context_size > 0 - else None - ) - result = await perform_import( - target=target, - current_session_id=session.id, - work_dir=session.work_dir, - context=soul.context, - max_context_size=max_context_size, - ) - if isinstance(result, str): - console.print(f"[red]{result}[/red]") - return - - source_desc, content_len = result - - # Write to wire file so the import appears in session replay - await soul.wire_file.append_message( - TurnBegin(user_input=f"[Imported context from {source_desc}]") - ) - await soul.wire_file.append_message(TurnEnd()) - - console.print( - f"[green]Imported context from {source_desc} " - f"({content_len} chars) into current session.[/green]" - ) - if source_desc.startswith("file") and is_sensitive_file(Path(target).name): - console.print( - "[yellow]Warning: This file may contain secrets (API keys, tokens, credentials). " - "The content is now part of your session context.[/yellow]" - ) diff --git a/src/kimi_cli/ui/shell/index.ts b/src/kimi_cli/ui/shell/index.ts new file mode 100644 index 000000000..2e8868240 --- /dev/null +++ b/src/kimi_cli/ui/shell/index.ts @@ -0,0 +1,48 @@ +export { Shell } from "./Shell.tsx"; +export type { ShellProps } from "./Shell.tsx"; +export { Prompt } from "./Prompt.tsx"; +export { + MessageList, + StreamingText, + ThinkingView, + ToolCallView, + ErrorRecoveryView, + classifyApiError, + NotificationView, + StatusView, + PlanDisplayView, + HookTriggeredView, + HookResolvedView, +} from "./Visualize.tsx"; +export type { ErrorInfo, NotificationViewProps, StatusViewProps } from "./Visualize.tsx"; +export { ApprovalPanel } from "./ApprovalPanel.tsx"; +export type { ApprovalPanelProps } from "./ApprovalPanel.tsx"; +export { QuestionPanel } from "./QuestionPanel.tsx"; +export type { QuestionPanelProps } from "./QuestionPanel.tsx"; +export { DebugPanel } from "./DebugPanel.tsx"; +export type { DebugPanelProps, DebugMessage, ContextInfo } from "./DebugPanel.tsx"; +export { UsagePanel, parseUsagePayload } from "./UsagePanel.tsx"; +export type { UsagePanelProps, UsageRow } from "./UsagePanel.tsx"; +export { TaskBrowser } from "./TaskBrowser.tsx"; +export type { TaskBrowserProps, TaskView, TaskViewSpec, TaskViewRuntime, TaskStatus } from "./TaskBrowser.tsx"; +export { SetupWizard } from "./SetupWizard.tsx"; +export type { SetupWizardProps, SetupResult, PlatformInfo, ModelInfo } from "./SetupWizard.tsx"; +export { ReplayPanel, buildReplayTurnsFromEvents } from "./ReplayPanel.tsx"; +export type { ReplayPanelProps, ReplayTurn, ReplayEvent } from "./ReplayPanel.tsx"; +export { useKeyboard } from "./keyboard.ts"; +export type { KeyAction } from "./keyboard.ts"; +export { getTerminalSize, onResize } from "./console.ts"; +export { + createShellSlashCommands, + parseSlashCommand, + findSlashCommand, +} from "./slash.ts"; +export type { + WireUIEvent, + UIMessage, + UIMessageRole, + MessageSegment, + TextSegment, + ThinkSegment, + ToolCallSegment, +} from "./events.ts"; diff --git a/src/kimi_cli/ui/shell/keyboard.py b/src/kimi_cli/ui/shell/keyboard.py deleted file mode 100644 index 54b3a7157..000000000 --- a/src/kimi_cli/ui/shell/keyboard.py +++ /dev/null @@ -1,300 +0,0 @@ -from __future__ import annotations - -import asyncio -import sys -import threading -import time -from collections.abc import AsyncGenerator, Callable -from enum import Enum, auto - -from kimi_cli.utils.aioqueue import Queue - - -class KeyEvent(Enum): - UP = auto() - DOWN = auto() - LEFT = auto() - RIGHT = auto() - ENTER = auto() - ESCAPE = auto() - TAB = auto() - SPACE = auto() - CTRL_E = auto() - NUM_1 = auto() - NUM_2 = auto() - NUM_3 = auto() - NUM_4 = auto() - NUM_5 = auto() - NUM_6 = auto() - - -class KeyboardListener: - def __init__(self) -> None: - self._queue = Queue[KeyEvent]() - self._cancel_event = threading.Event() - self._pause_event = threading.Event() - self._paused_event = threading.Event() - self._listener: threading.Thread | None = None - self._loop: asyncio.AbstractEventLoop | None = None - - async def start(self) -> None: - if self._listener is not None: - return - self._loop = asyncio.get_running_loop() - - def emit(event: KeyEvent) -> None: - if self._loop is None: - return - self._loop.call_soon_threadsafe(self._queue.put_nowait, event) - - self._listener = threading.Thread( - target=_listen_for_keyboard_thread, - args=(self._cancel_event, self._pause_event, self._paused_event, emit), - name="kimi-cli-keyboard-listener", - daemon=True, - ) - self._listener.start() - - async def stop(self) -> None: - self._cancel_event.set() - self._pause_event.clear() - if self._listener and self._listener.is_alive(): - await asyncio.to_thread(self._listener.join) - - def _pause_sync(self) -> None: - self._pause_event.set() - self._paused_event.wait() - - async def pause(self) -> None: - await asyncio.to_thread(self._pause_sync) - - def _resume_sync(self) -> None: - self._pause_event.clear() - while self._paused_event.is_set() and not self._cancel_event.is_set(): - time.sleep(0.01) - - async def resume(self) -> None: - await asyncio.to_thread(self._resume_sync) - - async def get(self) -> KeyEvent: - return await self._queue.get() - - -async def listen_for_keyboard() -> AsyncGenerator[KeyEvent]: - listener = KeyboardListener() - await listener.start() - - try: - while True: - yield await listener.get() - finally: - await listener.stop() - - -def _listen_for_keyboard_thread( - cancel: threading.Event, - pause: threading.Event, - paused: threading.Event, - emit: Callable[[KeyEvent], None], -) -> None: - if sys.platform == "win32": - _listen_for_keyboard_windows(cancel, pause, paused, emit) - else: - _listen_for_keyboard_unix(cancel, pause, paused, emit) - - -def _listen_for_keyboard_unix( - cancel: threading.Event, - pause: threading.Event, - paused: threading.Event, - emit: Callable[[KeyEvent], None], -) -> None: - if sys.platform == "win32": - raise RuntimeError("Unix keyboard listener requires a non-Windows platform") - - import termios - - fd = sys.stdin.fileno() - oldterm = termios.tcgetattr(fd) - rawattr = termios.tcgetattr(fd) - rawattr[3] = rawattr[3] & ~termios.ICANON & ~termios.ECHO - rawattr[6][termios.VMIN] = 0 - rawattr[6][termios.VTIME] = 0 - raw_enabled = False - - def enable_raw() -> None: - nonlocal raw_enabled - if raw_enabled: - return - termios.tcsetattr(fd, termios.TCSANOW, rawattr) - raw_enabled = True - - def disable_raw() -> None: - nonlocal raw_enabled - if not raw_enabled: - return - termios.tcsetattr(fd, termios.TCSANOW, oldterm) - raw_enabled = False - - enable_raw() - - try: - while not cancel.is_set(): - if pause.is_set(): - disable_raw() - paused.set() - time.sleep(0.01) - continue - if paused.is_set(): - paused.clear() - enable_raw() - - try: - c = sys.stdin.buffer.read(1) - except (OSError, ValueError): - c = b"" - - if not c: - if cancel.is_set(): - break - time.sleep(0.01) - continue - - if c == b"\x1b": - sequence = c - for _ in range(2): - if cancel.is_set(): - break - try: - fragment = sys.stdin.buffer.read(1) - except (OSError, ValueError): - fragment = b"" - if not fragment: - break - sequence += fragment - if sequence in _ARROW_KEY_MAP: - break - - event = _ARROW_KEY_MAP.get(sequence) - if event is not None: - emit(event) - elif sequence == b"\x1b": - emit(KeyEvent.ESCAPE) - elif c in (b"\r", b"\n"): - emit(KeyEvent.ENTER) - elif c == b" ": - emit(KeyEvent.SPACE) - elif c == b"\t": - emit(KeyEvent.TAB) - elif c == b"\x05": # Ctrl+E - emit(KeyEvent.CTRL_E) - elif c == b"1": - emit(KeyEvent.NUM_1) - elif c == b"2": - emit(KeyEvent.NUM_2) - elif c == b"3": - emit(KeyEvent.NUM_3) - elif c == b"4": - emit(KeyEvent.NUM_4) - elif c == b"5": - emit(KeyEvent.NUM_5) - elif c == b"6": - emit(KeyEvent.NUM_6) - finally: - termios.tcsetattr(fd, termios.TCSAFLUSH, oldterm) - - -def _listen_for_keyboard_windows( - cancel: threading.Event, - pause: threading.Event, - paused: threading.Event, - emit: Callable[[KeyEvent], None], -) -> None: - if sys.platform != "win32": - raise RuntimeError("Windows keyboard listener requires a Windows platform") - - import msvcrt - - while not cancel.is_set(): - if pause.is_set(): - paused.set() - time.sleep(0.01) - continue - if paused.is_set(): - paused.clear() - - if msvcrt.kbhit(): - c = msvcrt.getch() - - # Handle special keys (arrow keys, etc.) - if c in (b"\x00", b"\xe0"): - # Extended key, read the next byte - extended = msvcrt.getch() - event = _WINDOWS_KEY_MAP.get(extended) - if event is not None: - emit(event) - elif c == b"\x1b": - sequence = c - for _ in range(2): - if cancel.is_set(): - break - fragment = msvcrt.getch() if msvcrt.kbhit() else b"" - if not fragment: - break - sequence += fragment - if sequence in _ARROW_KEY_MAP: - break - - event = _ARROW_KEY_MAP.get(sequence) - if event is not None: - emit(event) - elif sequence == b"\x1b": - emit(KeyEvent.ESCAPE) - elif c in (b"\r", b"\n"): - emit(KeyEvent.ENTER) - elif c == b" ": - emit(KeyEvent.SPACE) - elif c == b"\t": - emit(KeyEvent.TAB) - elif c == b"\x05": # Ctrl+E - emit(KeyEvent.CTRL_E) - elif c == b"1": - emit(KeyEvent.NUM_1) - elif c == b"2": - emit(KeyEvent.NUM_2) - elif c == b"3": - emit(KeyEvent.NUM_3) - elif c == b"4": - emit(KeyEvent.NUM_4) - elif c == b"5": - emit(KeyEvent.NUM_5) - elif c == b"6": - emit(KeyEvent.NUM_6) - else: - if cancel.is_set(): - break - time.sleep(0.01) - - -_ARROW_KEY_MAP: dict[bytes, KeyEvent] = { - b"\x1b[A": KeyEvent.UP, - b"\x1b[B": KeyEvent.DOWN, - b"\x1b[C": KeyEvent.RIGHT, - b"\x1b[D": KeyEvent.LEFT, -} - -_WINDOWS_KEY_MAP: dict[bytes, KeyEvent] = { - b"H": KeyEvent.UP, # Up arrow - b"P": KeyEvent.DOWN, # Down arrow - b"M": KeyEvent.RIGHT, # Right arrow - b"K": KeyEvent.LEFT, # Left arrow -} - - -if __name__ == "__main__": - - async def dev_main(): - async for event in listen_for_keyboard(): - print(event) - - asyncio.run(dev_main()) diff --git a/src/kimi_cli/ui/shell/keyboard.ts b/src/kimi_cli/ui/shell/keyboard.ts new file mode 100644 index 000000000..1a43129db --- /dev/null +++ b/src/kimi_cli/ui/shell/keyboard.ts @@ -0,0 +1,97 @@ +/** + * Keyboard handling — corresponds to Python's ui/shell/keyboard.py + * Uses Ink's useInput hook for keyboard events in the React tree. + * + * Behavior: + * - Ctrl+C ×1: interrupt current streaming turn + * - Ctrl+C ×2 (within 500ms): exit the application + * - Esc ×1: interrupt current streaming turn + * - Esc ×2 (within 500ms): clear the input box + */ + +import { useInput, useApp } from "ink"; +import { useRef } from "react"; + +export type KeyAction = + | "interrupt" + | "exit" + | "clear-input"; + +export interface UseKeyboardOptions { + onAction: (action: KeyAction) => void; + /** Whether keyboard input is active (default true) */ + active?: boolean; +} + +const DOUBLE_PRESS_WINDOW = 500; // ms + +/** + * Hook that handles global keyboard shortcuts for the shell. + * + * Ctrl+C: 1st press = interrupt, 2nd press within 500ms = exit + * Escape: 1st press = interrupt, 2nd press within 500ms = clear input + */ +export function useKeyboard({ onAction, active = true }: UseKeyboardOptions) { + const { exit } = useApp(); + + // Ctrl+C double-press tracking + const ctrlCCount = useRef(0); + const ctrlCTimer = useRef | null>(null); + + // Esc double-press tracking + const escCount = useRef(0); + const escTimer = useRef | null>(null); + + useInput( + (input, key) => { + // ── Ctrl+C ──────────────────────────────────── + if (input === "c" && key.ctrl) { + // Reset Esc counter on Ctrl+C + escCount.current = 0; + + ctrlCCount.current += 1; + if (ctrlCCount.current >= 2) { + // Double Ctrl+C → exit + ctrlCCount.current = 0; + if (ctrlCTimer.current) clearTimeout(ctrlCTimer.current); + exit(); + return; + } + // Start/reset the window timer + if (ctrlCTimer.current) clearTimeout(ctrlCTimer.current); + ctrlCTimer.current = setTimeout(() => { + ctrlCCount.current = 0; + }, DOUBLE_PRESS_WINDOW); + onAction("interrupt"); + return; + } + + // ── Escape ──────────────────────────────────── + if (key.escape) { + // Reset Ctrl+C counter on Esc + ctrlCCount.current = 0; + + escCount.current += 1; + if (escCount.current >= 2) { + // Double Esc → clear input + escCount.current = 0; + if (escTimer.current) clearTimeout(escTimer.current); + onAction("clear-input"); + return; + } + // Start/reset the window timer + if (escTimer.current) clearTimeout(escTimer.current); + escTimer.current = setTimeout(() => { + escCount.current = 0; + }, DOUBLE_PRESS_WINDOW); + onAction("interrupt"); + return; + } + + // Any other key resets both counters + ctrlCCount.current = 0; + escCount.current = 0; + }, + { isActive: active }, + ); +} diff --git a/src/kimi_cli/ui/shell/mcp_status.py b/src/kimi_cli/ui/shell/mcp_status.py deleted file mode 100644 index 32cc498ab..000000000 --- a/src/kimi_cli/ui/shell/mcp_status.py +++ /dev/null @@ -1,111 +0,0 @@ -from __future__ import annotations - -import time - -from prompt_toolkit.formatted_text import FormattedText -from rich.console import Group, RenderableType -from rich.spinner import Spinner -from rich.text import Text - -from kimi_cli.ui.theme import get_mcp_prompt_colors -from kimi_cli.utils.rich.columns import BulletColumns -from kimi_cli.wire.types import MCPServerSnapshot, MCPStatusSnapshot - -_SPINNER_FRAMES = ("⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏") - - -def render_mcp_console(snapshot: MCPStatusSnapshot) -> RenderableType: - header_text = Text.assemble( - ("MCP Servers: ", "bold"), - f"{snapshot.connected}/{snapshot.total} connected, {snapshot.tools} tools", - ) - header: RenderableType = Spinner("dots", header_text) if snapshot.loading else header_text - - renderables: list[RenderableType] = [BulletColumns(header)] - for server in snapshot.servers: - color = _status_color(server.status) - server_text = f"[{color}]{server.name}[/{color}]" - if server.status == "unauthorized": - server_text += f" [grey50](unauthorized - run: kimi mcp auth {server.name})[/grey50]" - elif server.status != "connected": - server_text += f" [grey50]({server.status})[/grey50]" - - lines: list[RenderableType] = [Text.from_markup(server_text)] - for tool_name in server.tools: - lines.append( - BulletColumns( - Text.from_markup(f"[grey50]{tool_name}[/grey50]"), - bullet_style="grey50", - ) - ) - renderables.append(BulletColumns(Group(*lines), bullet_style=color)) - - return Group(*renderables) - - -def render_mcp_prompt(snapshot: MCPStatusSnapshot, *, now: float | None = None) -> FormattedText: - if not snapshot.loading: - return FormattedText([]) - - fragments: list[tuple[str, str]] = [] - colors = get_mcp_prompt_colors() - prefix = f"{_spinner_frame(now)} " if snapshot.loading else "" - fragments.append( - ( - colors.text, - ( - f"{prefix}MCP Servers: " - f"{snapshot.connected}/{snapshot.total} connected, {snapshot.tools} tools" - ), - ) - ) - fragments.append(("", "\n")) - - for server in snapshot.servers: - fragments.append((_prompt_status_style(server.status), f"• {server.name}")) - detail = _prompt_server_detail(server) - if detail: - fragments.append((colors.detail, detail)) - fragments.append(("", "\n")) - - return FormattedText(fragments) - - -def _spinner_frame(now: float | None = None) -> str: - timestamp = time.monotonic() if now is None else now - return _SPINNER_FRAMES[int(timestamp * 8) % len(_SPINNER_FRAMES)] - - -def _status_color(status: str) -> str: - return { - "connected": "green", - "connecting": "cyan", - "pending": "yellow", - "failed": "red", - "unauthorized": "red", - }.get(status, "red") - - -def _prompt_status_style(status: str) -> str: - colors = get_mcp_prompt_colors() - return { - "connected": colors.connected, - "connecting": colors.connecting, - "pending": colors.pending, - "failed": colors.failed, - "unauthorized": colors.failed, - }.get(status, colors.failed) - - -def _prompt_server_detail(server: MCPServerSnapshot) -> str: - if server.status == "unauthorized": - return f" (unauthorized - run: kimi mcp auth {server.name})" - - parts: list[str] = [] - if server.status != "connected": - parts.append(server.status) - if server.tools: - label = "tool" if len(server.tools) == 1 else "tools" - parts.append(f"{len(server.tools)} {label}") - - return f" ({', '.join(parts)})" if parts else "" diff --git a/src/kimi_cli/ui/shell/oauth.py b/src/kimi_cli/ui/shell/oauth.py deleted file mode 100644 index 059bd1582..000000000 --- a/src/kimi_cli/ui/shell/oauth.py +++ /dev/null @@ -1,143 +0,0 @@ -from __future__ import annotations - -import asyncio -from typing import TYPE_CHECKING - -from rich.status import Status - -from kimi_cli.auth import KIMI_CODE_PLATFORM_ID -from kimi_cli.auth.oauth import login_kimi_code, logout_kimi_code -from kimi_cli.auth.platforms import is_managed_provider_key, parse_managed_provider_key -from kimi_cli.cli import Reload -from kimi_cli.config import save_config -from kimi_cli.soul.kimisoul import KimiSoul -from kimi_cli.ui.shell.console import console -from kimi_cli.ui.shell.setup import select_platform, setup_platform -from kimi_cli.ui.shell.slash import ensure_kimi_soul, registry - -if TYPE_CHECKING: - from kimi_cli.ui.shell import Shell - - -async def _login_kimi_code(soul: KimiSoul) -> bool: - status: Status | None = None - ok = True - try: - async for event in login_kimi_code(soul.runtime.config): - if event.type == "waiting": - if status is None: - status = console.status("[cyan]Waiting for user authorization...[/cyan]") - status.start() - continue - if status is not None: - status.stop() - status = None - match event.type: - case "error": - style = "red" - case "success": - style = "green" - case _: - style = None - console.print(event.message, markup=False, style=style) - if event.type == "error": - ok = False - finally: - if status is not None: - status.stop() - return ok - - -def current_model_key(soul: KimiSoul) -> str | None: - config = soul.runtime.config - curr_model_cfg = soul.runtime.llm.model_config if soul.runtime.llm else None - if curr_model_cfg is not None: - for name, model_cfg in config.models.items(): - if model_cfg == curr_model_cfg: - return name - return config.default_model or None - - -@registry.command(aliases=["setup"]) -async def login(app: Shell, args: str) -> None: - """Login or setup a platform.""" - soul = ensure_kimi_soul(app) - if soul is None: - return - platform = await select_platform() - if platform is None: - return - if platform.id == KIMI_CODE_PLATFORM_ID: - ok = await _login_kimi_code(soul) - else: - ok = await setup_platform(platform) - if not ok: - return - await asyncio.sleep(1) - console.clear() - raise Reload - - -@registry.command -async def logout(app: Shell, args: str) -> None: - """Logout from the current platform.""" - soul = ensure_kimi_soul(app) - if soul is None: - return - config = soul.runtime.config - if not config.is_from_default_location: - console.print( - "[red]Logout requires the default config file; " - "restart without --config/--config-file.[/red]" - ) - return - model_key = current_model_key(soul) - if not model_key: - console.print("[yellow]No model selected; nothing to logout.[/yellow]") - return - model_cfg = config.models.get(model_key) - if model_cfg is None: - console.print("[yellow]Current model not found; nothing to logout.[/yellow]") - return - provider_key = model_cfg.provider - if not is_managed_provider_key(provider_key): - console.print("[yellow]Current provider is not managed; nothing to logout.[/yellow]") - return - platform_id = parse_managed_provider_key(provider_key) - if not platform_id: - console.print("[yellow]Current provider is not managed; nothing to logout.[/yellow]") - return - - if platform_id == KIMI_CODE_PLATFORM_ID: - ok = True - async for event in logout_kimi_code(config): - match event.type: - case "error": - style = "red" - case "success": - style = "green" - case _: - style = None - console.print(event.message, markup=False, style=style) - if event.type == "error": - ok = False - if not ok: - return - else: - if provider_key in config.providers: - del config.providers[provider_key] - removed_default = False - for key, model in list(config.models.items()): - if model.provider != provider_key: - continue - del config.models[key] - if config.default_model == key: - removed_default = True - if removed_default: - config.default_model = "" - save_config(config) - console.print("[green]✓[/green] Logged out successfully.") - - await asyncio.sleep(1) - console.clear() - raise Reload diff --git a/src/kimi_cli/ui/shell/placeholders.py b/src/kimi_cli/ui/shell/placeholders.py deleted file mode 100644 index f58c39c84..000000000 --- a/src/kimi_cli/ui/shell/placeholders.py +++ /dev/null @@ -1,531 +0,0 @@ -from __future__ import annotations - -import base64 -import mimetypes -import re -from collections.abc import Callable, Sequence -from dataclasses import dataclass -from difflib import SequenceMatcher -from hashlib import sha256 -from io import BytesIO -from pathlib import Path -from typing import Literal, Protocol - -from PIL import Image - -from kimi_cli.share import get_share_dir -from kimi_cli.utils.envvar import get_env_int -from kimi_cli.utils.logging import logger -from kimi_cli.utils.media_tags import wrap_media_part -from kimi_cli.utils.string import random_string -from kimi_cli.wire.types import ContentPart, ImageURLPart, TextPart - -_DEFAULT_PROMPT_CACHE_ROOT = get_share_dir() / "prompt-cache" -_LEGACY_PROMPT_CACHE_ROOT = Path("/tmp/kimi") - -_IMAGE_PLACEHOLDER_RE = re.compile( - r"\[(?P[a-zA-Z0-9_\-]+):(?P[a-zA-Z0-9_\-\.]+)" - r"(?:,(?P\d+)x(?P\d+))?\]" -) -_PASTED_TEXT_PLACEHOLDER_RE = re.compile( - r"\[Pasted text #(?P\d+)(?: \+(?P\d+) lines?)?\]" -) - -_TEXT_PASTE_CHAR_THRESHOLD = get_env_int("KIMI_CLI_PASTE_CHAR_THRESHOLD", 1000) -_TEXT_PASTE_LINE_THRESHOLD = get_env_int("KIMI_CLI_PASTE_LINE_THRESHOLD", 15) - - -def sanitize_surrogates(text: str) -> str: - """Replace lone UTF-16 surrogates that cannot be encoded as UTF-8. - - Windows clipboard data sometimes contains unpaired surrogates from - applications that use UTF-16 internally. Passing such strings to - ``json.dumps`` or writing them to a UTF-8 file raises - ``UnicodeEncodeError``, so we replace them with U+FFFD early. - """ - return text.encode("utf-8", errors="surrogatepass").decode("utf-8", errors="replace") - - -def normalize_pasted_text(text: str) -> str: - """Normalize pasted text into the same newline format used by prompt_toolkit.""" - return text.replace("\r\n", "\n").replace("\r", "\n") - - -def count_text_lines(text: str) -> int: - if not text: - return 1 - return text.count("\n") + 1 - - -def should_placeholderize_pasted_text(text: str) -> bool: - normalized = normalize_pasted_text(text) - return ( - len(normalized) >= _TEXT_PASTE_CHAR_THRESHOLD - or count_text_lines(normalized) >= _TEXT_PASTE_LINE_THRESHOLD - ) - - -def build_pasted_text_placeholder(paste_id: int, text: str) -> str: - line_count = count_text_lines(text) - if line_count <= 1: - return f"[Pasted text #{paste_id}]" - return f"[Pasted text #{paste_id} +{line_count} lines]" - - -def _guess_image_mime(path: Path) -> str: - mime, _ = mimetypes.guess_type(path.name) - if mime: - return mime - return "image/png" - - -def _build_image_part(image_bytes: bytes, mime_type: str) -> ImageURLPart: - image_base64 = base64.b64encode(image_bytes).decode("ascii") - return ImageURLPart( - image_url=ImageURLPart.ImageURL( - url=f"data:{mime_type};base64,{image_base64}", - ) - ) - - -type CachedAttachmentKind = Literal["image"] - - -@dataclass(slots=True) -class CachedAttachment: - kind: CachedAttachmentKind - attachment_id: str - path: Path - - -class AttachmentCache: - """Persistent cache for placeholder payloads that can safely survive history recall.""" - - def __init__( - self, - root: Path | None = None, - *, - legacy_roots: Sequence[Path] | None = None, - ) -> None: - self._root = root or _DEFAULT_PROMPT_CACHE_ROOT - self._legacy_roots = tuple(legacy_roots or (_LEGACY_PROMPT_CACHE_ROOT,)) - self._dir_map: dict[CachedAttachmentKind, str] = {"image": "images"} - self._payload_map: dict[tuple[CachedAttachmentKind, str, str], CachedAttachment] = {} - - def _dir_for(self, kind: CachedAttachmentKind, *, root: Path | None = None) -> Path: - return (self._root if root is None else root) / self._dir_map[kind] - - def _ensure_dir(self, kind: CachedAttachmentKind) -> Path | None: - path = self._dir_for(kind) - try: - path.mkdir(parents=True, exist_ok=True) - except OSError as exc: - logger.warning( - "Failed to create attachment cache dir: {dir} ({error})", - dir=path, - error=exc, - ) - return None - return path - - def _reserve_id(self, dir_path: Path, suffix: str) -> str: - for _ in range(5): - candidate = f"{random_string(8)}{suffix}" - if not (dir_path / candidate).exists(): - return candidate - return f"{random_string(12)}{suffix}" - - def store_bytes( - self, kind: CachedAttachmentKind, suffix: str, payload: bytes - ) -> CachedAttachment | None: - dir_path = self._ensure_dir(kind) - if dir_path is None: - return None - - payload_hash = sha256(payload).hexdigest() - cache_key = (kind, suffix, payload_hash) - cached = self._payload_map.get(cache_key) - if cached is not None: - if cached.path.exists(): - return cached - self._payload_map.pop(cache_key, None) - - attachment_id = self._reserve_id(dir_path, suffix) - path = dir_path / attachment_id - try: - path.write_bytes(payload) - except OSError as exc: - logger.warning( - "Failed to write cached attachment: {file} ({error})", - file=path, - error=exc, - ) - return None - - cached = CachedAttachment(kind=kind, attachment_id=attachment_id, path=path) - self._payload_map[cache_key] = cached - return cached - - def store_image(self, image: Image.Image) -> CachedAttachment | None: - png_bytes = BytesIO() - image.save(png_bytes, format="PNG") - return self.store_bytes("image", ".png", png_bytes.getvalue()) - - def _candidate_paths(self, kind: CachedAttachmentKind, attachment_id: str) -> list[Path]: - roots = (self._root, *self._legacy_roots) - return [self._dir_for(kind, root=root) / attachment_id for root in roots] - - def load_bytes( - self, kind: CachedAttachmentKind, attachment_id: str - ) -> tuple[Path, bytes] | None: - for path in self._candidate_paths(kind, attachment_id): - if not path.exists(): - continue - try: - return path, path.read_bytes() - except OSError as exc: - logger.warning( - "Failed to read cached attachment: {file} ({error})", - file=path, - error=exc, - ) - return None - return None - - def load_content_parts( - self, kind: CachedAttachmentKind, attachment_id: str - ) -> list[ContentPart] | None: - if kind == "image": - payload = self.load_bytes(kind, attachment_id) - if payload is None: - return None - path, image_bytes = payload - mime_type = _guess_image_mime(path) - part = _build_image_part(image_bytes, mime_type) - return wrap_media_part(part, tag="image", attrs={"path": str(path)}) - return None - - -def parse_attachment_kind(raw_kind: str) -> CachedAttachmentKind | None: - if raw_kind == "image": - return "image" - return None - - -_parse_attachment_kind = parse_attachment_kind - - -@dataclass(slots=True) -class PlaceholderTokenMatch: - start: int - end: int - raw: str - handler: PlaceholderHandler - match: re.Match[str] - - -class PlaceholderHandler(Protocol): - def find_next(self, text: str, start: int = 0) -> PlaceholderTokenMatch | None: ... - - def resolve_content(self, match: PlaceholderTokenMatch) -> list[ContentPart] | None: ... - - def expand_text(self, match: PlaceholderTokenMatch) -> str | None: ... - - def serialize_for_history(self, match: PlaceholderTokenMatch) -> str | None: ... - - def expand_for_editor(self, match: PlaceholderTokenMatch) -> str | None: ... - - -@dataclass(slots=True) -class PastedTextEntry: - paste_id: int - text: str - - @property - def token(self) -> str: - return build_pasted_text_placeholder(self.paste_id, self.text) - - -class PastedTextPlaceholderHandler: - def __init__(self) -> None: - self._entries: dict[int, PastedTextEntry] = {} - self._next_id = 1 - - def create_placeholder(self, text: str) -> str: - normalized = sanitize_surrogates(normalize_pasted_text(text)) - entry = PastedTextEntry(paste_id=self._next_id, text=normalized) - self._entries[entry.paste_id] = entry - self._next_id += 1 - return entry.token - - def maybe_placeholderize(self, text: str) -> str: - normalized = normalize_pasted_text(text) - if not should_placeholderize_pasted_text(normalized): - return normalized - return self.create_placeholder(normalized) - - def entry_for_id(self, paste_id: int) -> PastedTextEntry | None: - return self._entries.get(paste_id) - - def iter_entries_for_command( - self, command: str - ) -> list[tuple[PlaceholderTokenMatch, PastedTextEntry]]: - entries: list[tuple[PlaceholderTokenMatch, PastedTextEntry]] = [] - cursor = 0 - while match := self.find_next(command, cursor): - paste_id = int(match.match.group("id")) - entry = self.entry_for_id(paste_id) - if entry is not None: - entries.append((match, entry)) - cursor = match.end - return entries - - def find_next(self, text: str, start: int = 0) -> PlaceholderTokenMatch | None: - match = _PASTED_TEXT_PLACEHOLDER_RE.search(text, start) - if match is None: - return None - return PlaceholderTokenMatch( - start=match.start(), - end=match.end(), - raw=match.group(0), - handler=self, - match=match, - ) - - def resolve_content(self, match: PlaceholderTokenMatch) -> list[ContentPart] | None: - paste_id = int(match.match.group("id")) - entry = self.entry_for_id(paste_id) - if entry is None: - return None - return [TextPart(text=entry.text)] - - def expand_text(self, match: PlaceholderTokenMatch) -> str | None: - paste_id = int(match.match.group("id")) - entry = self.entry_for_id(paste_id) - return None if entry is None else entry.text - - def serialize_for_history(self, match: PlaceholderTokenMatch) -> str | None: - return self.expand_text(match) - - def expand_for_editor(self, match: PlaceholderTokenMatch) -> str | None: - return self.expand_text(match) - - def refold_after_editor(self, edited_text: str, original_command: str) -> str: - expanded_original, intervals = self._expanded_text_and_intervals(original_command) - if not intervals: - return edited_text - - opcodes = SequenceMatcher( - a=expanded_original, - b=edited_text, - autojunk=False, - ).get_opcodes() - replacements: list[tuple[int, int, str]] = [] - for start, end, token, expected_text in intervals: - mapped = self._map_interval(opcodes, start, end) - if mapped is None: - continue - mapped_start, mapped_end = mapped - if edited_text[mapped_start:mapped_end] != expected_text: - continue - replacements.append((mapped_start, mapped_end, token)) - - result = edited_text - for start, end, token in reversed(replacements): - result = result[:start] + token + result[end:] - return result - - def _expanded_text_and_intervals( - self, command: str - ) -> tuple[str, list[tuple[int, int, str, str]]]: - parts: list[str] = [] - intervals: list[tuple[int, int, str, str]] = [] - cursor = 0 - expanded_cursor = 0 - for match, entry in self.iter_entries_for_command(command): - literal = command[cursor : match.start] - if literal: - parts.append(literal) - expanded_cursor += len(literal) - start = expanded_cursor - parts.append(entry.text) - expanded_cursor += len(entry.text) - intervals.append((start, expanded_cursor, match.raw, entry.text)) - cursor = match.end - if cursor < len(command): - parts.append(command[cursor:]) - return "".join(parts), intervals - - @staticmethod - def _map_interval( - opcodes: Sequence[tuple[str, int, int, int, int]], start: int, end: int - ) -> tuple[int, int] | None: - mapped_start: int | None = None - mapped_end: int | None = None - cursor = start - for tag, i1, i2, j1, _j2 in opcodes: - if i2 <= cursor: - continue - if i1 >= end: - break - overlap_start = max(i1, cursor, start) - overlap_end = min(i2, end) - if overlap_start >= overlap_end: - continue - if tag != "equal": - return None - segment_start = j1 + (overlap_start - i1) - segment_end = j1 + (overlap_end - i1) - if mapped_start is None: - mapped_start = segment_start - elif mapped_end != segment_start: - return None - mapped_end = segment_end - cursor = overlap_end - if cursor != end or mapped_start is None or mapped_end is None: - return None - return mapped_start, mapped_end - - -class ImagePlaceholderHandler: - def __init__(self, attachment_cache: AttachmentCache) -> None: - self._attachment_cache = attachment_cache - - def create_placeholder(self, image: Image.Image) -> str | None: - cached = self._attachment_cache.store_image(image) - if cached is None: - return None - return f"[image:{cached.attachment_id},{image.width}x{image.height}]" - - def find_next(self, text: str, start: int = 0) -> PlaceholderTokenMatch | None: - match = _IMAGE_PLACEHOLDER_RE.search(text, start) - if match is None: - return None - return PlaceholderTokenMatch( - start=match.start(), - end=match.end(), - raw=match.group(0), - handler=self, - match=match, - ) - - def resolve_content(self, match: PlaceholderTokenMatch) -> list[ContentPart] | None: - kind = parse_attachment_kind(match.match.group("type")) - if kind is None: - return None - return self._attachment_cache.load_content_parts(kind, match.match.group("id")) - - def expand_text(self, match: PlaceholderTokenMatch) -> str | None: - return match.raw - - def serialize_for_history(self, match: PlaceholderTokenMatch) -> str | None: - return match.raw - - def expand_for_editor(self, match: PlaceholderTokenMatch) -> str | None: - return match.raw - - -@dataclass(slots=True) -class ResolvedPromptCommand: - display_command: str - resolved_text: str - content: list[ContentPart] - - -class PromptPlaceholderManager: - def __init__(self, attachment_cache: AttachmentCache | None = None) -> None: - self._attachment_cache = attachment_cache or AttachmentCache() - self._text_handler = PastedTextPlaceholderHandler() - self._image_handler = ImagePlaceholderHandler(self._attachment_cache) - self._handlers: tuple[PlaceholderHandler, ...] = ( - self._text_handler, - self._image_handler, - ) - - @property - def attachment_cache(self) -> AttachmentCache: - return self._attachment_cache - - def maybe_placeholderize_pasted_text(self, text: str) -> str: - return self._text_handler.maybe_placeholderize(text) - - def create_image_placeholder(self, image: Image.Image) -> str | None: - return self._image_handler.create_placeholder(image) - - def resolve_command(self, command: str) -> ResolvedPromptCommand: - content: list[ContentPart] = [] - resolved_chunks: list[str] = [] - cursor = 0 - - while match := self._find_next_match(command, cursor): - if match.start > cursor: - literal = command[cursor : match.start] - content.append(TextPart(text=literal)) - resolved_chunks.append(literal) - - resolved_content = match.handler.resolve_content(match) - if resolved_content is None: - content.append(TextPart(text=match.raw)) - resolved_chunks.append(match.raw) - else: - content.extend(resolved_content) - expanded = match.handler.expand_text(match) - resolved_chunks.append(match.raw if expanded is None else expanded) - - cursor = match.end - - if cursor < len(command): - literal = command[cursor:] - content.append(TextPart(text=literal)) - resolved_chunks.append(literal) - - return ResolvedPromptCommand( - display_command=command, - resolved_text="".join(resolved_chunks), - content=content, - ) - - def serialize_for_history(self, command: str) -> str: - return self._rewrite_command( - command, - lambda handler, match: handler.serialize_for_history(match), - ) - - def expand_for_editor(self, command: str) -> str: - return self._rewrite_command( - command, - lambda handler, match: handler.expand_for_editor(match), - ) - - def refold_after_editor(self, edited_text: str, original_command: str) -> str: - return self._text_handler.refold_after_editor(edited_text, original_command) - - def _find_next_match(self, text: str, start: int = 0) -> PlaceholderTokenMatch | None: - earliest: PlaceholderTokenMatch | None = None - for handler in self._handlers: - match = handler.find_next(text, start) - if match is None: - continue - if earliest is None or match.start < earliest.start: - earliest = match - return earliest - - def _rewrite_command( - self, - command: str, - replacer: Callable[[PlaceholderHandler, PlaceholderTokenMatch], str | None], - ) -> str: - parts: list[str] = [] - cursor = 0 - - while match := self._find_next_match(command, cursor): - if match.start > cursor: - parts.append(command[cursor : match.start]) - replacement = replacer(match.handler, match) - parts.append(match.raw if replacement is None else replacement) - cursor = match.end - - if cursor < len(command): - parts.append(command[cursor:]) - - return "".join(parts) diff --git a/src/kimi_cli/ui/shell/prompt.py b/src/kimi_cli/ui/shell/prompt.py deleted file mode 100644 index eea559e21..000000000 --- a/src/kimi_cli/ui/shell/prompt.py +++ /dev/null @@ -1,2124 +0,0 @@ -from __future__ import annotations - -import asyncio -import contextlib -import json -import os -import random -import re -import shlex -import subprocess -import time -from collections import deque -from collections.abc import Awaitable, Callable, Iterable, Sequence -from dataclasses import dataclass -from enum import Enum -from hashlib import md5 -from pathlib import Path -from typing import Any, Literal, Protocol, cast, override - -from kaos.path import KaosPath -from prompt_toolkit import PromptSession -from prompt_toolkit.application.current import get_app_or_none -from prompt_toolkit.buffer import Buffer -from prompt_toolkit.clipboard.pyperclip import PyperclipClipboard -from prompt_toolkit.completion import ( - CompleteEvent, - Completer, - Completion, - FuzzyCompleter, - WordCompleter, - merge_completers, -) -from prompt_toolkit.data_structures import Point -from prompt_toolkit.document import Document -from prompt_toolkit.filters import Condition, has_completions, has_focus, is_done -from prompt_toolkit.formatted_text import AnyFormattedText, FormattedText, to_formatted_text -from prompt_toolkit.history import InMemoryHistory -from prompt_toolkit.key_binding import KeyBindings, KeyPressEvent -from prompt_toolkit.keys import Keys -from prompt_toolkit.layout.containers import ( - ConditionalContainer, - DynamicContainer, - Float, - FloatContainer, - HSplit, - Window, -) -from prompt_toolkit.layout.controls import BufferControl, UIContent, UIControl -from prompt_toolkit.layout.dimension import Dimension -from prompt_toolkit.layout.menus import CompletionsMenu -from prompt_toolkit.patch_stdout import patch_stdout -from prompt_toolkit.utils import get_cwidth -from pydantic import BaseModel, ValidationError - -from kimi_cli.llm import ModelCapability -from kimi_cli.share import get_share_dir -from kimi_cli.soul import StatusSnapshot, format_context_status -from kimi_cli.ui.shell import placeholders as prompt_placeholders -from kimi_cli.ui.shell.console import console -from kimi_cli.ui.shell.placeholders import ( - PromptPlaceholderManager, - normalize_pasted_text, - sanitize_surrogates, -) -from kimi_cli.ui.theme import get_prompt_style, get_toolbar_colors -from kimi_cli.utils.clipboard import ( - grab_media_from_clipboard, - is_clipboard_available, -) -from kimi_cli.utils.logging import logger -from kimi_cli.utils.slashcmd import SlashCommand -from kimi_cli.wire.types import ContentPart - -AttachmentCache = prompt_placeholders.AttachmentCache -CachedAttachment = prompt_placeholders.CachedAttachment -_parse_attachment_kind = prompt_placeholders.parse_attachment_kind - -PROMPT_SYMBOL = "✨" -PROMPT_SYMBOL_SHELL = "$" -PROMPT_SYMBOL_THINKING = "💫" -PROMPT_SYMBOL_PLAN = "📋" - - -class SlashCommandCompleter(Completer): - """ - A completer that: - - Shows one line per slash command using the canonical "/name" - - Fuzzy-matches by primary name or any alias while inserting the canonical "/name" - - Only activates when the current token starts with '/' - """ - - def __init__(self, available_commands: Sequence[SlashCommand[Any]]) -> None: - super().__init__() - self._available_commands = list(available_commands) - self._command_lookup: dict[str, list[SlashCommand[Any]]] = {} - words: list[str] = [] - - for cmd in sorted(self._available_commands, key=lambda c: c.name): - if cmd.name not in self._command_lookup: - self._command_lookup[cmd.name] = [] - words.append(cmd.name) - self._command_lookup[cmd.name].append(cmd) - for alias in cmd.aliases: - if alias in self._command_lookup: - self._command_lookup[alias].append(cmd) - else: - self._command_lookup[alias] = [cmd] - words.append(alias) - - self._word_pattern = re.compile(r"[^\s]+") - self._fuzzy_pattern = r"^[^\s]*" - self._word_completer = WordCompleter(words, WORD=False, pattern=self._word_pattern) - self._fuzzy = FuzzyCompleter(self._word_completer, WORD=False, pattern=self._fuzzy_pattern) - - @staticmethod - def should_complete(document: Document) -> bool: - """Return whether slash command completion should be active for the current buffer.""" - text = document.text_before_cursor - - if document.text_after_cursor.strip(): - return False - - last_space = text.rfind(" ") - token = text[last_space + 1 :] - prefix = text[: last_space + 1] if last_space != -1 else "" - - return not prefix.strip() and token.startswith("/") - - @override - def get_completions( - self, document: Document, complete_event: CompleteEvent - ) -> Iterable[Completion]: - if not self.should_complete(document): - return - text = document.text_before_cursor - last_space = text.rfind(" ") - token = text[last_space + 1 :] - - typed = token[1:] - if typed and typed in self._command_lookup: - return - mention_doc = Document(text=typed, cursor_position=len(typed)) - candidates = list(self._fuzzy.get_completions(mention_doc, complete_event)) - - seen: set[str] = set() - - for candidate in candidates: - commands = self._command_lookup.get(candidate.text) - if not commands: - continue - for cmd in commands: - if cmd.name in seen: - continue - seen.add(cmd.name) - yield Completion( - text=f"/{cmd.name}", - start_position=-len(token), - display=f"/{cmd.name}", - display_meta=cmd.description, - ) - - -def _truncate_to_width(text: str, width: int) -> str: - if width <= 0: - return "" - - total = 0 - chars: list[str] = [] - for ch in text: - ch_width = get_cwidth(ch) - if total + ch_width > width: - break - chars.append(ch) - total += ch_width - - if total == get_cwidth(text): - return text + (" " * max(0, width - total)) - - ellipsis = "..." - ellipsis_width = get_cwidth(ellipsis) - if width <= ellipsis_width: - return "." * width - - available = width - ellipsis_width - total = 0 - chars = [] - for ch in text: - ch_width = get_cwidth(ch) - if total + ch_width > available: - break - chars.append(ch) - total += ch_width - return "".join(chars) + ellipsis + (" " * max(0, width - total - ellipsis_width)) - - -def _wrap_to_width(text: str, width: int, *, max_lines: int | None = None) -> list[str]: - if width <= 0: - return [] - - words = text.split() - if not words: - return [""] - - lines: list[str] = [] - current_words: list[str] = [] - current_width = 0 - index = 0 - - while index < len(words): - word = words[index] - word_width = get_cwidth(word) - separator_width = 1 if current_words else 0 - - if current_words and current_width + separator_width + word_width <= width: - current_words.append(word) - current_width += separator_width + word_width - index += 1 - continue - - if not current_words and word_width <= width: - current_words.append(word) - current_width = word_width - index += 1 - continue - - if not current_words and word_width > width: - current_words.append(_truncate_to_width(word, width).rstrip()) - current_width = get_cwidth(current_words[0]) - index += 1 - - lines.append(" ".join(current_words)) - current_words = [] - current_width = 0 - - if max_lines is not None and len(lines) == max_lines: - remaining = " ".join(words[index:]) - if remaining: - prefix = f"{lines[-1]} " if lines[-1] else "" - lines[-1] = _truncate_to_width(prefix + remaining, width).rstrip() - return lines - - if current_words: - line = " ".join(current_words) - if max_lines is not None and len(lines) + 1 > max_lines: - if lines: - lines[-1] = _truncate_to_width(f"{lines[-1]} {line}", width).rstrip() - else: - lines.append(_truncate_to_width(line, width).rstrip()) - else: - lines.append(line) - - return lines - - -def _find_prompt_float_container(layout_container: object) -> FloatContainer | None: - if not isinstance(layout_container, HSplit): - return None - - for child in cast(Sequence[object], layout_container.children): - float_container = _extract_float_container(child) - if float_container is not None: - return float_container - return None - - -def _extract_float_container(container: object) -> FloatContainer | None: - if isinstance(container, FloatContainer): - return container - if isinstance(container, ConditionalContainer): - if isinstance(container.content, FloatContainer): - return container.content - if isinstance(container.alternative_content, FloatContainer): - return container.alternative_content - return None - - -def _find_default_buffer_container( - layout_container: object, - target_buffer: Buffer, -) -> ConditionalContainer | None: - seen: set[int] = set() - - def _walk(node: object) -> ConditionalContainer | None: - if id(node) in seen: - return None - seen.add(id(node)) - - if isinstance(node, ConditionalContainer): - content = getattr(node, "content", None) - if isinstance(content, Window): - control = content.content - if isinstance(control, BufferControl) and control.buffer is target_buffer: - return node - - if isinstance(node, DynamicContainer): - with contextlib.suppress(Exception): - found = _walk(node.get_container()) - if found is not None: - return found - - for attr in ("children", "content", "floats", "container"): - if not hasattr(node, attr): - continue - value = getattr(node, attr) - if attr == "children" and isinstance(value, Sequence): - for child in value: # pyright: ignore[reportUnknownVariableType] - found = _walk(child) # pyright: ignore[reportUnknownArgumentType] - if found is not None: - return found - elif attr == "floats" and isinstance(value, Sequence): - for float_ in value: # pyright: ignore[reportUnknownVariableType] - content = getattr(float_, "content", None) # pyright: ignore[reportUnknownArgumentType] - if content is None: - continue - found = _walk(content) - if found is not None: - return found - elif ( - attr in {"content", "container"} - and value is not None - and type(value).__module__.startswith("prompt_toolkit") - ): - found = _walk(value) - if found is not None: - return found - return None - - return _walk(layout_container) - - -class SlashCommandMenuControl(UIControl): - """Render slash command completions as a full-width menu that matches the shell UI.""" - - _MAX_EXPANDED_META_LINES = 3 - - def __init__( - self, - *, - left_padding: Callable[[], int], - scroll_offset: int = 1, - ) -> None: - self._left_padding = left_padding - self._scroll_offset = scroll_offset - - def has_focus(self) -> bool: - return False - - def preferred_width(self, max_available_width: int) -> int | None: - return max_available_width - - def preferred_height( - self, - width: int, - max_available_height: int, - wrap_lines: bool, - get_line_prefix: Callable[..., AnyFormattedText] | None, - ) -> int | None: - app = get_app_or_none() - complete_state = ( - getattr(app.current_buffer, "complete_state", None) if app is not None else None - ) - if complete_state is None: - return 0 - completions = complete_state.completions - selected_index = complete_state.complete_index - if selected_index is None: - return min(max_available_height, len(completions) + 1) - menu_width = max(0, width - self._left_padding()) - marker_width = 2 - command_width = self._command_column_width(completions, menu_width, marker_width) - gap_width = 3 if menu_width > command_width + 6 else 1 - meta_width = max(0, menu_width - marker_width - command_width - gap_width) - selected_meta_lines = self._selected_meta_lines( - completions[selected_index].display_meta_text, - meta_width, - ) - return min(max_available_height, len(completions) + len(selected_meta_lines)) - - def create_content(self, width: int, height: int) -> UIContent: - app = get_app_or_none() - complete_state = ( - getattr(app.current_buffer, "complete_state", None) if app is not None else None - ) - if complete_state is None or not complete_state.completions: - return UIContent() - - completions = complete_state.completions - selected_index = complete_state.complete_index - available_rows = max(1, height - 1) - - menu_width = max(0, width - self._left_padding()) - marker_width = 2 - command_width = self._command_column_width(completions, menu_width, marker_width) - gap_width = 3 if menu_width > command_width + 6 else 1 - meta_width = max(0, menu_width - marker_width - command_width - gap_width) - - rendered_lines: list[FormattedText] = [ - FormattedText([("class:slash-completion-menu.separator", "─" * max(0, width))]) - ] - selected_line_index = 0 - - if selected_index is None: - end = min(len(completions) - 1, available_rows - 1) - for index in range(0, end + 1): - rendered_lines.append( - self._render_single_line_item( - width=width, - completion=completions[index], - marker_width=marker_width, - command_width=command_width, - meta_width=meta_width, - gap_width=gap_width, - is_current=False, - ) - ) - - return UIContent( - get_line=lambda i: rendered_lines[i], - line_count=len(rendered_lines), - cursor_position=Point(x=0, y=selected_line_index), - ) - - selected_meta_lines = self._selected_meta_lines( - completions[selected_index].display_meta_text, - meta_width, - ) - start, end = self._visible_window_bounds( - completion_count=len(completions), - selected_index=selected_index, - available_rows=available_rows, - selected_item_height=len(selected_meta_lines), - ) - selected_line_index = 1 - - for index in range(start, end + 1): - completion = completions[index] - if index == selected_index: - selected_line_index = len(rendered_lines) - rendered_lines.extend( - self._render_selected_item_lines( - width=width, - completion=completion, - marker_width=marker_width, - command_width=command_width, - meta_width=meta_width, - gap_width=gap_width, - meta_lines=selected_meta_lines, - ) - ) - continue - - rendered_lines.append( - self._render_single_line_item( - width=width, - completion=completion, - marker_width=marker_width, - command_width=command_width, - meta_width=meta_width, - gap_width=gap_width, - is_current=False, - ) - ) - - return UIContent( - get_line=lambda i: rendered_lines[i], - line_count=len(rendered_lines), - cursor_position=Point(x=0, y=selected_line_index), - ) - - def _selected_meta_lines(self, text: str, meta_width: int) -> list[str]: - lines = _wrap_to_width( - text, - meta_width, - max_lines=self._MAX_EXPANDED_META_LINES, - ) - return lines or [""] - - def _visible_window_bounds( - self, - *, - completion_count: int, - selected_index: int, - available_rows: int, - selected_item_height: int, - ) -> tuple[int, int]: - selected_item_height = min(selected_item_height, available_rows) - remaining_rows = max(0, available_rows - selected_item_height) - - before = min(self._scroll_offset, selected_index, remaining_rows) - remaining_rows -= before - after = min(completion_count - selected_index - 1, remaining_rows) - remaining_rows -= after - - extra_before = min(selected_index - before, remaining_rows) - before += extra_before - remaining_rows -= extra_before - - extra_after = min(completion_count - selected_index - 1 - after, remaining_rows) - after += extra_after - - return selected_index - before, selected_index + after - - def _command_column_width( - self, - completions: Sequence[Completion], - menu_width: int, - marker_width: int, - ) -> int: - if menu_width <= 0: - return 0 - longest = max((get_cwidth(c.display_text) for c in completions), default=0) - preferred = longest + 2 - usable_width = max(0, menu_width - marker_width) - minimum = min(usable_width, 18) - maximum = max(minimum, min(28, usable_width // 2)) - return max(minimum, min(preferred, maximum)) - - def _render_single_line_item( - self, - *, - width: int, - completion: Completion, - marker_width: int, - command_width: int, - meta_width: int, - gap_width: int, - is_current: bool, - ) -> FormattedText: - padding_width = max(0, width - marker_width - command_width - meta_width - gap_width) - left_padding = min(self._left_padding(), padding_width) - trailing_width = max( - 0, - width - left_padding - marker_width - command_width - gap_width - meta_width, - ) - - command_style = ( - "class:slash-completion-menu.command.current" - if is_current - else "class:slash-completion-menu.command" - ) - meta_style = ( - "class:slash-completion-menu.meta.current" - if is_current - else "class:slash-completion-menu.meta" - ) - marker_style = ( - "class:slash-completion-menu.marker.current" - if is_current - else "class:slash-completion-menu.marker" - ) - marker = "› " if is_current else " " - - fragments: FormattedText = FormattedText() - fragments.append(("class:slash-completion-menu", " " * left_padding)) - fragments.append((marker_style, marker.ljust(marker_width))) - fragments.append( - (command_style, _truncate_to_width(completion.display_text, command_width)) - ) - fragments.append(("class:slash-completion-menu", " " * gap_width)) - fragments.append((meta_style, _truncate_to_width(completion.display_meta_text, meta_width))) - fragments.append(("class:slash-completion-menu", " " * trailing_width)) - return fragments - - def _render_selected_item_lines( - self, - *, - width: int, - completion: Completion, - marker_width: int, - command_width: int, - meta_width: int, - gap_width: int, - meta_lines: Sequence[str], - ) -> list[FormattedText]: - lines = [ - self._render_single_line_item( - width=width, - completion=Completion( - text=completion.text, - start_position=completion.start_position, - display=completion.display, - display_meta=meta_lines[0], - ), - marker_width=marker_width, - command_width=command_width, - meta_width=meta_width, - gap_width=gap_width, - is_current=True, - ) - ] - - continuation_prefix = ( - " " * self._left_padding() + " " * marker_width + " " * command_width + " " * gap_width - ) - continuation_trailing = max( - 0, - width - get_cwidth(continuation_prefix) - meta_width, - ) - for meta_line in meta_lines[1:]: - fragments: FormattedText = FormattedText() - fragments.append(("class:slash-completion-menu", continuation_prefix)) - fragments.append( - ( - "class:slash-completion-menu.meta.current", - _truncate_to_width(meta_line, meta_width), - ) - ) - fragments.append(("class:slash-completion-menu", " " * continuation_trailing)) - lines.append(fragments) - - return lines - - -class LocalFileMentionCompleter(Completer): - """Offer fuzzy `@` path completion by indexing workspace files.""" - - _FRAGMENT_PATTERN = re.compile(r"[^\s@]+") - _TRIGGER_GUARDS = frozenset((".", "-", "_", "`", "'", '"', ":", "@", "#", "~")) - _IGNORED_NAME_GROUPS: dict[str, tuple[str, ...]] = { - "vcs_metadata": (".DS_Store", ".bzr", ".git", ".hg", ".svn"), - "tooling_caches": ( - ".build", - ".cache", - ".coverage", - ".fleet", - ".gradle", - ".idea", - ".ipynb_checkpoints", - ".pnpm-store", - ".pytest_cache", - ".pub-cache", - ".ruff_cache", - ".swiftpm", - ".tox", - ".venv", - ".vs", - ".vscode", - ".yarn", - ".yarn-cache", - ), - "js_frontend": ( - ".next", - ".nuxt", - ".parcel-cache", - ".svelte-kit", - ".turbo", - ".vercel", - "node_modules", - ), - "python_packaging": ( - "__pycache__", - "build", - "coverage", - "dist", - "htmlcov", - "pip-wheel-metadata", - "venv", - ), - "java_jvm": (".mvn", "out", "target"), - "dotnet_native": ("bin", "cmake-build-debug", "cmake-build-release", "obj"), - "bazel_buck": ("bazel-bin", "bazel-out", "bazel-testlogs", "buck-out"), - "misc_artifacts": ( - ".dart_tool", - ".serverless", - ".stack-work", - ".terraform", - ".terragrunt-cache", - "DerivedData", - "Pods", - "deps", - "tmp", - "vendor", - ), - } - _IGNORED_NAMES = frozenset(name for group in _IGNORED_NAME_GROUPS.values() for name in group) - _IGNORED_PATTERN_PARTS: tuple[str, ...] = ( - r".*_cache$", - r".*-cache$", - r".*\.egg-info$", - r".*\.dist-info$", - r".*\.py[co]$", - r".*\.class$", - r".*\.sw[po]$", - r".*~$", - r".*\.(?:tmp|bak)$", - ) - _IGNORED_PATTERNS = re.compile( - "|".join(f"(?:{part})" for part in _IGNORED_PATTERN_PARTS), - re.IGNORECASE, - ) - - def __init__( - self, - root: Path, - *, - refresh_interval: float = 2.0, - limit: int = 1000, - ) -> None: - self._root = root - self._refresh_interval = refresh_interval - self._limit = limit - self._cache_time: float = 0.0 - self._cached_paths: list[str] = [] - self._top_cache_time: float = 0.0 - self._top_cached_paths: list[str] = [] - self._fragment_hint: str | None = None - - self._word_completer = WordCompleter( - self._get_paths, - WORD=False, - pattern=self._FRAGMENT_PATTERN, - ) - - self._fuzzy = FuzzyCompleter( - self._word_completer, - WORD=False, - pattern=r"^[^\s@]*", - ) - - @classmethod - def _is_ignored(cls, name: str) -> bool: - if not name: - return True - if name in cls._IGNORED_NAMES: - return True - return bool(cls._IGNORED_PATTERNS.fullmatch(name)) - - def _get_paths(self) -> list[str]: - fragment = self._fragment_hint or "" - if "/" not in fragment and len(fragment) < 3: - return self._get_top_level_paths() - return self._get_deep_paths() - - def _get_top_level_paths(self) -> list[str]: - now = time.monotonic() - if now - self._top_cache_time <= self._refresh_interval: - return self._top_cached_paths - - entries: list[str] = [] - try: - for entry in sorted(self._root.iterdir(), key=lambda p: p.name): - name = entry.name - if self._is_ignored(name): - continue - entries.append(f"{name}/" if entry.is_dir() else name) - if len(entries) >= self._limit: - break - except OSError: - return self._top_cached_paths - - self._top_cached_paths = entries - self._top_cache_time = now - return self._top_cached_paths - - def _get_deep_paths(self) -> list[str]: - now = time.monotonic() - if now - self._cache_time <= self._refresh_interval: - return self._cached_paths - - paths: list[str] = [] - try: - for current_root, dirs, files in os.walk(self._root): - relative_root = Path(current_root).relative_to(self._root) - - # Prevent descending into ignored directories. - dirs[:] = sorted(d for d in dirs if not self._is_ignored(d)) - - if relative_root.parts and any( - self._is_ignored(part) for part in relative_root.parts - ): - dirs[:] = [] - continue - - if relative_root.parts: - paths.append(relative_root.as_posix() + "/") - if len(paths) >= self._limit: - break - - for file_name in sorted(files): - if self._is_ignored(file_name): - continue - relative = (relative_root / file_name).as_posix() - if not relative: - continue - paths.append(relative) - if len(paths) >= self._limit: - break - - if len(paths) >= self._limit: - break - except OSError: - return self._cached_paths - - self._cached_paths = paths - self._cache_time = now - return self._cached_paths - - @staticmethod - def _extract_fragment(text: str) -> str | None: - index = text.rfind("@") - if index == -1: - return None - - if index > 0: - prev = text[index - 1] - if prev.isalnum() or prev in LocalFileMentionCompleter._TRIGGER_GUARDS: - return None - - fragment = text[index + 1 :] - if not fragment: - return "" - - if any(ch.isspace() for ch in fragment): - return None - - return fragment - - def _is_completed_file(self, fragment: str) -> bool: - candidate = fragment.rstrip("/") - if not candidate: - return False - try: - return (self._root / candidate).is_file() - except OSError: - return False - - @override - def get_completions( - self, document: Document, complete_event: CompleteEvent - ) -> Iterable[Completion]: - fragment = self._extract_fragment(document.text_before_cursor) - if fragment is None: - return - if self._is_completed_file(fragment): - return - - mention_doc = Document(text=fragment, cursor_position=len(fragment)) - self._fragment_hint = fragment - try: - # First, ask the fuzzy completer for candidates. - candidates = list(self._fuzzy.get_completions(mention_doc, complete_event)) - - # re-rank: prefer basename matches - frag_lower = fragment.lower() - - def _rank(c: Completion) -> tuple[int, ...]: - path = c.text - base = path.rstrip("/").split("/")[-1].lower() - if base.startswith(frag_lower): - cat = 0 - elif frag_lower in base: - cat = 1 - else: - cat = 2 - # preserve original FuzzyCompleter's order in the same category - return (cat,) - - candidates.sort(key=_rank) - yield from candidates - finally: - self._fragment_hint = None - - -class _HistoryEntry(BaseModel): - content: str - - -def _load_history_entries(history_file: Path) -> list[_HistoryEntry]: - entries: list[_HistoryEntry] = [] - if not history_file.exists(): - return entries - - try: - with history_file.open(encoding="utf-8") as f: - for raw_line in f: - line = raw_line.strip() - if not line: - continue - try: - record = json.loads(line) - except json.JSONDecodeError: - logger.warning( - "Failed to parse user history line; skipping: {line}", - line=line, - ) - continue - try: - entry = _HistoryEntry.model_validate(record) - entries.append(entry) - except ValidationError: - logger.warning( - "Failed to validate user history entry; skipping: {line}", - line=line, - ) - continue - except OSError as exc: - logger.warning( - "Failed to load user history file: {file} ({error})", - file=history_file, - error=exc, - ) - - return entries - - -class PromptMode(Enum): - AGENT = "agent" - SHELL = "shell" - - def toggle(self) -> PromptMode: - return PromptMode.SHELL if self == PromptMode.AGENT else PromptMode.AGENT - - def __str__(self) -> str: - return self.value - - -class PromptUIState(Enum): - NORMAL_INPUT = "normal_input" - MODAL_HIDDEN_INPUT = "modal_hidden_input" - MODAL_TEXT_INPUT = "modal_text_input" - - -class UserInput(BaseModel): - mode: PromptMode - command: str - """The plain text representation of the user input.""" - resolved_command: str - """The text command after UI-only placeholders are expanded.""" - content: list[ContentPart] - """The rich content parts.""" - - def __str__(self) -> str: - return self.command - - def __bool__(self) -> bool: - return bool(self.command) - - -_IDLE_REFRESH_INTERVAL = 1.0 -_RUNNING_REFRESH_INTERVAL = 0.1 - -_GIT_BRANCH_TTL = 5.0 -_GIT_STATUS_TTL = 15.0 -_TIP_ROTATE_INTERVAL = 30.0 -_MAX_CWD_COLS = 30 -_MAX_BRANCH_COLS = 22 - - -@dataclass -class _GitBranchState: - timestamp: float = 0.0 - branch: str | None = None - proc: subprocess.Popen[str] | None = None - - -@dataclass -class _GitStatusState: - timestamp: float = 0.0 - dirty: bool = False - ahead: int = 0 - behind: int = 0 - proc: subprocess.Popen[str] | None = None - - -_git_branch_state = _GitBranchState() -_git_status_state = _GitStatusState() - -_GIT_STATUS_AB_RE = re.compile(r"\[(?:ahead (\d+))?(?:, )?(?:behind (\d+))?\]") - - -def _get_git_branch() -> str | None: - """Return the current git branch name via a non-blocking cached subprocess.""" - state = _git_branch_state - now = time.monotonic() - - # Collect result if a previously launched process has finished - if state.proc is not None: - returncode = state.proc.poll() - if returncode is not None: - try: - stdout, _ = state.proc.communicate() - new_branch = stdout.strip() or None - # Branch changed — discard any in-flight status subprocess so it cannot - # write stale results for the old branch, then force an immediate refresh. - if new_branch != state.branch: - if _git_status_state.proc is not None: - with contextlib.suppress(Exception): - _git_status_state.proc.terminate() - _git_status_state.proc = None - _git_status_state.timestamp = 0.0 - state.branch = new_branch - except Exception: - state.branch = None - state.proc = None - - # Launch a new process when the TTL has expired and nothing is running - if state.timestamp + _GIT_BRANCH_TTL <= now and state.proc is None: - state.timestamp = now - try: - state.proc = subprocess.Popen( - ["git", "branch", "--show-current"], - stdout=subprocess.PIPE, - stderr=subprocess.DEVNULL, - text=True, - ) - except Exception: - state.branch = None - - return state.branch - - -def _get_git_status() -> tuple[bool, int, int]: - """Return (dirty, ahead, behind) via a non-blocking cached subprocess. - - Runs ``git status --porcelain -b`` (includes untracked files so newly created - files show as dirty). TTL is longer than the branch check because file-tree - scanning is expensive. - """ - state = _git_status_state - now = time.monotonic() - - if state.proc is not None: - returncode = state.proc.poll() - if returncode is not None: - try: - stdout, _ = state.proc.communicate() - dirty = False - ahead = 0 - behind = 0 - for line in stdout.splitlines(): - if line.startswith("## "): - m = _GIT_STATUS_AB_RE.search(line) - if m: - ahead = int(m.group(1) or 0) - behind = int(m.group(2) or 0) - elif line.strip(): - dirty = True - state.dirty = dirty - state.ahead = ahead - state.behind = behind - except Exception: - pass - state.proc = None - elif now - state.timestamp > _GIT_STATUS_TTL: - # Subprocess is stuck (e.g. OS pipe buffer full from many untracked files). - # Terminate it so the toolbar is not permanently frozen; retry after next TTL. - with contextlib.suppress(Exception): - state.proc.terminate() - state.proc = None - state.timestamp = now # delay next spawn by one full TTL - - if state.timestamp + _GIT_STATUS_TTL <= now and state.proc is None: - state.timestamp = now - with contextlib.suppress(Exception): - state.proc = subprocess.Popen( - ["git", "status", "--porcelain", "-b"], - stdout=subprocess.PIPE, - stderr=subprocess.DEVNULL, - text=True, - ) - - return state.dirty, state.ahead, state.behind - - -def _format_git_badge(branch: str, dirty: bool, ahead: int, behind: int) -> str: - """Format branch name with an optional status badge: ``main [± ↑3↓1]``.""" - parts: list[str] = [] - if dirty: - parts.append("±") - sync = "" - if ahead: - sync += f"↑{ahead}" - if behind: - sync += f"↓{behind}" - if sync: - parts.append(sync) - if not parts: - return branch - return f"{branch} [{' '.join(parts)}]" - - -def _shorten_cwd(path: str) -> str: - """Replace the home directory prefix in *path* with ``~``.""" - home = str(Path.home()) - if path == home: - return "~" - if path.startswith(home + os.sep): - return "~" + path[len(home) :] - return path - - -def _display_width(text: str) -> int: - """Return the terminal column width of *text*, handling wide Unicode characters.""" - return sum(get_cwidth(c) for c in text) - - -def _truncate_left(text: str, max_cols: int) -> str: - """Truncate *text* from the left, prepending '…' if it exceeds *max_cols*.""" - if max_cols <= 0: - return "" - if _display_width(text) <= max_cols: - return text - ellipsis = "…" - budget = max_cols - _display_width(ellipsis) - chars: list[str] = [] - width = 0 - for ch in reversed(text): - w = get_cwidth(ch) - if width + w > budget: - break - chars.append(ch) - width += w - return ellipsis + "".join(reversed(chars)) - - -def _truncate_right(text: str, max_cols: int) -> str: - """Truncate *text* from the right, appending '…' if it exceeds *max_cols*.""" - if max_cols <= 0: - return "" - if _display_width(text) <= max_cols: - return text - ellipsis = "…" - budget = max_cols - _display_width(ellipsis) - chars: list[str] = [] - width = 0 - for ch in text: - w = get_cwidth(ch) - if width + w > budget: - break - chars.append(ch) - width += w - return "".join(chars) + ellipsis - - -@dataclass(slots=True) -class _ToastEntry: - topic: str | None - """There can be only one toast of each non-None topic in the queue.""" - message: str - expires_at: float - - -class RunningPromptDelegate(Protocol): - modal_priority: int - - def render_running_prompt_body(self, columns: int) -> AnyFormattedText: ... - - def running_prompt_placeholder(self) -> AnyFormattedText | None: ... - - def running_prompt_allows_text_input(self) -> bool: ... - - def running_prompt_hides_input_buffer(self) -> bool: ... - - def running_prompt_accepts_submission(self) -> bool: ... - - def should_handle_running_prompt_key(self, key: str) -> bool: ... - - def handle_running_prompt_key(self, key: str, event: KeyPressEvent) -> None: ... - - -_toast_queues: dict[Literal["left", "right"], deque[_ToastEntry]] = { - "left": deque(), - "right": deque(), -} -"""The queue of toasts to show, including the one currently being shown (the first one).""" - - -def toast( - message: str, - duration: float = 5.0, - topic: str | None = None, - immediate: bool = False, - position: Literal["left", "right"] = "left", -) -> None: - queue = _toast_queues[position] - duration = max(duration, _IDLE_REFRESH_INTERVAL) - entry = _ToastEntry(topic=topic, message=message, expires_at=time.monotonic() + duration) - if topic is not None: - # Remove existing toasts with the same topic - for existing in list(queue): - if existing.topic == topic: - queue.remove(existing) - if immediate: - queue.appendleft(entry) - else: - queue.append(entry) - - -def _current_toast(position: Literal["left", "right"] = "left") -> _ToastEntry | None: - queue = _toast_queues[position] - now = time.monotonic() - while queue and queue[0].expires_at <= now: - queue.popleft() - if not queue: - return None - return queue[0] - - -def _build_toolbar_tips(clipboard_available: bool) -> list[str]: - tips = [ - "ctrl-x: toggle mode", - "shift-tab: plan mode", - "ctrl-o: editor", - "ctrl-j: newline", - "/feedback: send feedback", - "/theme: switch dark/light", - ] - if clipboard_available: - tips.append("ctrl-v: paste clipboard") - tips.append("@: mention files") - return tips - - -_TIP_SEPARATOR = " | " - - -class CustomPromptSession: - def __init__( - self, - *, - status_provider: Callable[[], StatusSnapshot], - status_block_provider: Callable[[int], AnyFormattedText | None] | None = None, - fast_refresh_provider: Callable[[], bool] | None = None, - background_task_count_provider: Callable[[], int] | None = None, - model_capabilities: set[ModelCapability], - model_name: str | None, - thinking: bool, - agent_mode_slash_commands: Sequence[SlashCommand[Any]], - shell_mode_slash_commands: Sequence[SlashCommand[Any]], - editor_command_provider: Callable[[], str] = lambda: "", - plan_mode_toggle_callback: Callable[[], Awaitable[bool]] | None = None, - ) -> None: - history_dir = get_share_dir() / "user-history" - history_dir.mkdir(parents=True, exist_ok=True) - work_dir_id = md5(str(KaosPath.cwd()).encode(encoding="utf-8")).hexdigest() - self._history_file = (history_dir / work_dir_id).with_suffix(".jsonl") - self._status_provider = status_provider - self._status_block_provider = status_block_provider - self._fast_refresh_provider = fast_refresh_provider - self._background_task_count_provider = background_task_count_provider - self._editor_command_provider = editor_command_provider - self._plan_mode_toggle_callback = plan_mode_toggle_callback - self._model_capabilities = model_capabilities - self._model_name = model_name - self._last_history_content: str | None = None - self._mode: PromptMode = PromptMode.AGENT - self._thinking = thinking - self._placeholder_manager = PromptPlaceholderManager() - # Keep the old attribute for test compatibility and for any external imports. - self._attachment_cache = self._placeholder_manager.attachment_cache - self._last_tip_rotate_time: float = time.monotonic() - self._last_submission_was_running = False - self._running_prompt_previous_mode: PromptMode | None = None - self._running_prompt_delegate: RunningPromptDelegate | None = None - self._modal_delegates: list[RunningPromptDelegate] = [] - self._prompt_buffer_container: ConditionalContainer | None = None - self._last_ui_state: PromptUIState = PromptUIState.NORMAL_INPUT - self._suspended_buffer_document: Document | None = None - clipboard_available = is_clipboard_available() - self._tips = _build_toolbar_tips(clipboard_available) - self._tip_rotation_index: int = random.randrange(len(self._tips)) if self._tips else 0 - - history_entries = _load_history_entries(self._history_file) - history = InMemoryHistory() - for entry in history_entries: - history.append_string(entry.content) - - if history_entries: - # for consecutive deduplication - self._last_history_content = history_entries[-1].content - - # Build completers - self._agent_mode_completer = merge_completers( - [ - SlashCommandCompleter(agent_mode_slash_commands), - # TODO(kaos): we need an async KaosFileMentionCompleter - LocalFileMentionCompleter(KaosPath.cwd().unsafe_to_local_path()), - ], - deduplicate=True, - ) - self._shell_mode_completer = SlashCommandCompleter(shell_mode_slash_commands) - - # Build key bindings - _kb = KeyBindings() - - @_kb.add("enter", filter=has_completions) - def _(event: KeyPressEvent) -> None: - """Accept the first completion when Enter is pressed and completions are shown.""" - buff = event.current_buffer - if buff.complete_state and buff.complete_state.completions: - # Get the current completion, or use the first one if none is selected - completion = buff.complete_state.current_completion - if not completion: - completion = buff.complete_state.completions[0] - buff.apply_completion(completion) - - @_kb.add("c-x", eager=True) - def _(event: KeyPressEvent) -> None: - if self._active_prompt_delegate() is not None: - return - self._mode = self._mode.toggle() - # Apply mode-specific settings - self._apply_mode(event) - # Redraw UI - event.app.invalidate() - - @_kb.add("s-tab", eager=True) - def _(event: KeyPressEvent) -> None: - """Toggle plan mode with Shift+Tab.""" - if self._active_prompt_delegate() is not None: - return - if self._plan_mode_toggle_callback is not None: - - async def _toggle() -> None: - assert self._plan_mode_toggle_callback is not None - new_state = await self._plan_mode_toggle_callback() - if new_state: - toast("plan mode ON", topic="plan_mode", duration=3.0, immediate=True) - else: - toast("plan mode OFF", topic="plan_mode", duration=3.0, immediate=True) - event.app.invalidate() - - event.app.create_background_task(_toggle()) - event.app.invalidate() - - @_kb.add("escape", "enter", eager=True) - @_kb.add("c-j", eager=True) - def _(event: KeyPressEvent) -> None: - """Insert a newline when Alt-Enter or Ctrl-J is pressed.""" - event.current_buffer.insert_text("\n") - - @_kb.add("c-o", eager=True) - def _(event: KeyPressEvent) -> None: - """Open current buffer in external editor.""" - self._open_in_external_editor(event) - - @_kb.add( - "up", - eager=True, - filter=Condition(lambda: self._should_handle_running_prompt_key("up")), - ) - def _(event: KeyPressEvent) -> None: - self._handle_running_prompt_key("up", event) - - @_kb.add( - "down", - eager=True, - filter=Condition(lambda: self._should_handle_running_prompt_key("down")), - ) - def _(event: KeyPressEvent) -> None: - self._handle_running_prompt_key("down", event) - - @_kb.add( - "left", - eager=True, - filter=Condition(lambda: self._should_handle_running_prompt_key("left")), - ) - def _(event: KeyPressEvent) -> None: - self._handle_running_prompt_key("left", event) - - @_kb.add( - "right", - eager=True, - filter=Condition(lambda: self._should_handle_running_prompt_key("right")), - ) - def _(event: KeyPressEvent) -> None: - self._handle_running_prompt_key("right", event) - - @_kb.add( - "tab", - eager=True, - filter=Condition(lambda: self._should_handle_running_prompt_key("tab")), - ) - def _(event: KeyPressEvent) -> None: - self._handle_running_prompt_key("tab", event) - - @_kb.add( - "enter", - eager=True, - filter=Condition(lambda: self._should_handle_running_prompt_key("enter")), - ) - def _(event: KeyPressEvent) -> None: - self._handle_running_prompt_key("enter", event) - - @_kb.add( - "space", - eager=True, - filter=Condition(lambda: self._should_handle_running_prompt_key("space")), - ) - def _(event: KeyPressEvent) -> None: - self._handle_running_prompt_key("space", event) - - @_kb.add( - "c-e", - eager=True, - filter=Condition(lambda: self._should_handle_running_prompt_key("c-e")), - ) - def _(event: KeyPressEvent) -> None: - self._handle_running_prompt_key("c-e", event) - - @_kb.add( - "c-c", - eager=True, - filter=Condition(lambda: self._should_handle_running_prompt_key("c-c")), - ) - def _(event: KeyPressEvent) -> None: - self._handle_running_prompt_key("c-c", event) - - @_kb.add( - "c-d", - eager=True, - filter=Condition(lambda: self._should_handle_running_prompt_key("c-d")), - ) - def _(event: KeyPressEvent) -> None: - self._handle_running_prompt_key("c-d", event) - - @_kb.add( - "escape", - eager=True, - filter=Condition(lambda: self._should_handle_running_prompt_key("escape")), - ) - def _(event: KeyPressEvent) -> None: - self._handle_running_prompt_key("escape", event) - - @_kb.add( - "1", - eager=True, - filter=Condition(lambda: self._should_handle_running_prompt_key("1")), - ) - def _(event: KeyPressEvent) -> None: - self._handle_running_prompt_key("1", event) - - @_kb.add( - "2", - eager=True, - filter=Condition(lambda: self._should_handle_running_prompt_key("2")), - ) - def _(event: KeyPressEvent) -> None: - self._handle_running_prompt_key("2", event) - - @_kb.add( - "3", - eager=True, - filter=Condition(lambda: self._should_handle_running_prompt_key("3")), - ) - def _(event: KeyPressEvent) -> None: - self._handle_running_prompt_key("3", event) - - @_kb.add( - "4", - eager=True, - filter=Condition(lambda: self._should_handle_running_prompt_key("4")), - ) - def _(event: KeyPressEvent) -> None: - self._handle_running_prompt_key("4", event) - - @_kb.add( - "5", - eager=True, - filter=Condition(lambda: self._should_handle_running_prompt_key("5")), - ) - def _(event: KeyPressEvent) -> None: - self._handle_running_prompt_key("5", event) - - @_kb.add( - "6", - eager=True, - filter=Condition(lambda: self._should_handle_running_prompt_key("6")), - ) - def _(event: KeyPressEvent) -> None: - self._handle_running_prompt_key("6", event) - - @_kb.add(Keys.BracketedPaste, eager=True) - def _(event: KeyPressEvent) -> None: - self._handle_bracketed_paste(event) - - if clipboard_available: - - @_kb.add("c-v", eager=True) - def _(event: KeyPressEvent) -> None: - if self._try_paste_media(event): - return - clipboard_data = event.app.clipboard.get_data() - if clipboard_data is None: # type: ignore[reportUnnecessaryComparison] - return - self._insert_pasted_text(event.current_buffer, clipboard_data.text) - event.app.invalidate() - - clipboard = PyperclipClipboard() - else: - clipboard = None - - self._session = PromptSession[str]( - message=self._render_message, - # prompt_continuation=FormattedText([("fg:#4d4d4d", "... ")]), - completer=self._agent_mode_completer, - complete_while_typing=True, - reserve_space_for_menu=6, - key_bindings=_kb, - clipboard=clipboard, - history=history, - bottom_toolbar=self._render_bottom_toolbar, - style=get_prompt_style(), - ) - self._session.default_buffer.read_only = Condition( - lambda: ( - (delegate := self._active_prompt_delegate()) is not None - and not delegate.running_prompt_allows_text_input() - ) - ) - self._install_slash_completion_menu() - self._install_prompt_buffer_visibility() - self._apply_mode() - - # Allow completion to be triggered when the text is changed, - # such as when backspace is used to delete text. - @self._session.default_buffer.on_text_changed.add_handler - def _(buffer: Buffer) -> None: - if buffer.complete_while_typing(): - buffer.start_completion() - - self._status_refresh_task: asyncio.Task[None] | None = None - - def _install_slash_completion_menu(self) -> None: - float_container = _find_prompt_float_container(self._session.layout.container) - if not isinstance(float_container, FloatContainer): - return - - slash_menu_filter = ( - has_focus(self._session.default_buffer) - & has_completions - & ~is_done - & Condition(self._should_show_slash_completion_menu) - ) - slash_menu = ConditionalContainer( - Window( - content=SlashCommandMenuControl(left_padding=self._slash_menu_left_padding), - dont_extend_height=True, - height=Dimension(max=10), - style="class:slash-completion-menu", - ), - filter=slash_menu_filter, - ) - float_container.floats.insert( - 0, - Float( - left=0, - right=0, - ycursor=True, - content=slash_menu, - z_index=10**8, - ), - ) - - original_float = next( - ( - float_ - for float_ in float_container.floats[1:] - if isinstance(float_.content, CompletionsMenu) - ), - None, - ) - if original_float is None: - return - original_float.content = ConditionalContainer( - original_float.content, - filter=~Condition(self._should_show_slash_completion_menu), - ) - - def _install_prompt_buffer_visibility(self) -> None: - buffer_container = _find_default_buffer_container( - self._session.layout.container, - self._session.default_buffer, - ) - if buffer_container is None: - return - buffer_container.filter = buffer_container.filter & Condition( - self._should_render_input_buffer - ) - self._prompt_buffer_container = buffer_container - - def _should_show_slash_completion_menu(self) -> bool: - document = self._session.default_buffer.document - return SlashCommandCompleter.should_complete(document) - - def _slash_menu_left_padding(self) -> int: - if self._mode == PromptMode.SHELL: - return max(1, get_cwidth(f"{PROMPT_SYMBOL_SHELL} ") - 2) - if self._status_provider().plan_mode: - return max(1, get_cwidth(f"{PROMPT_SYMBOL_PLAN} ") - 2) - symbol = PROMPT_SYMBOL_THINKING if self._thinking else PROMPT_SYMBOL - return max(1, get_cwidth(f"{symbol} ") - 2) - - def _render_message(self) -> FormattedText: - if self._mode == PromptMode.SHELL: - return self._render_shell_prompt_message() - return self._render_agent_prompt_message() - - def _render_shell_prompt_message(self) -> FormattedText: - app = get_app_or_none() - columns = app.output.get_size().columns if app is not None else 80 - fragments: FormattedText = FormattedText() - body = self._render_agent_prompt_body(columns) - if body: - fragments.extend(body) - if not body[-1][1].endswith("\n"): - fragments.append(("", "\n")) - if self._active_modal_delegate() is not None: - return fragments - if body: - fragments.append(("", "\n")) - fragments.append(("class:running-prompt-separator", "─" * max(0, columns))) - fragments.append(("", "\n")) - fragments.append(("bold", f"{PROMPT_SYMBOL_SHELL} ")) - return fragments - - def _open_in_external_editor(self, event: KeyPressEvent) -> None: - """Open the current buffer content in an external editor.""" - from prompt_toolkit.application.run_in_terminal import run_in_terminal - - from kimi_cli.utils.editor import edit_text_in_editor, get_editor_command - - configured = self._editor_command_provider() - - if get_editor_command(configured) is None: - toast("No editor found. Set $VISUAL/$EDITOR or run /editor.") - return - - buff = event.current_buffer - original_text = buff.text - editor_text = self._get_placeholder_manager().expand_for_editor(original_text) - - async def _run_editor() -> None: - result = await run_in_terminal( - lambda: edit_text_in_editor(editor_text, configured), in_executor=True - ) - if result is not None: - refolded = self._get_placeholder_manager().refold_after_editor( - result, original_text - ) - buff.document = Document(text=refolded, cursor_position=len(refolded)) - - event.app.create_background_task(_run_editor()) - - def _apply_mode(self, event: KeyPressEvent | None = None) -> None: - # Apply mode to the active buffer (not the PromptSession itself) - try: - buff = event.current_buffer if event is not None else self._session.default_buffer - except Exception: - buff = None - - if self._mode == PromptMode.SHELL: - if buff is not None: - buff.completer = self._shell_mode_completer - else: - if buff is not None: - buff.completer = self._agent_mode_completer - self._sync_erase_when_done() - - def _sync_erase_when_done(self) -> None: - app = getattr(self._session, "app", None) - if app is not None: - app.erase_when_done = self._mode == PromptMode.AGENT - - def _active_modal_delegate(self) -> RunningPromptDelegate | None: - modal_delegates = getattr(self, "_modal_delegates", []) - if not modal_delegates: - return None - _, delegate = max( - enumerate(modal_delegates), - key=lambda item: (item[1].modal_priority, item[0]), - ) - return delegate - - def _active_prompt_delegate(self) -> RunningPromptDelegate | None: - if delegate := self._active_modal_delegate(): - return delegate - return getattr(self, "_running_prompt_delegate", None) - - def _active_ui_state(self) -> PromptUIState: - delegate = self._active_modal_delegate() - if delegate is None: - return PromptUIState.NORMAL_INPUT - if delegate.running_prompt_hides_input_buffer(): - return PromptUIState.MODAL_HIDDEN_INPUT - if delegate.running_prompt_allows_text_input(): - return PromptUIState.MODAL_TEXT_INPUT - return PromptUIState.NORMAL_INPUT - - def _should_render_input_buffer(self) -> bool: - return self._active_ui_state() != PromptUIState.MODAL_HIDDEN_INPUT - - def _should_handle_running_prompt_key(self, key: str) -> bool: - delegate = self._active_prompt_delegate() - return delegate is not None and delegate.should_handle_running_prompt_key(key) - - def _handle_running_prompt_key(self, key: str, event: KeyPressEvent) -> None: - delegate = self._active_prompt_delegate() - if delegate is None: - return - delegate.handle_running_prompt_key(key, event) - event.app.invalidate() - - def invalidate(self) -> None: - self._sync_prompt_ui_state() - app = get_app_or_none() - if app is not None: - app.invalidate() - - def _sync_prompt_ui_state(self) -> None: - new_state = self._active_ui_state() - old_state = getattr(self, "_last_ui_state", PromptUIState.NORMAL_INPUT) - buffer = self._session.default_buffer - - if ( - old_state != PromptUIState.MODAL_HIDDEN_INPUT - and new_state == PromptUIState.MODAL_HIDDEN_INPUT - ): - if self._suspended_buffer_document is None and buffer.text: - self._suspended_buffer_document = buffer.document - buffer.set_document(Document(), bypass_readonly=True) - elif ( - old_state == PromptUIState.MODAL_HIDDEN_INPUT - and new_state != PromptUIState.MODAL_HIDDEN_INPUT - ): - if self._suspended_buffer_document is not None and not buffer.text: - buffer.set_document(self._suspended_buffer_document, bypass_readonly=True) - self._suspended_buffer_document = None - - self._last_ui_state = new_state - - def _render_agent_prompt_message(self) -> FormattedText: - app = get_app_or_none() - columns = app.output.get_size().columns if app is not None else 80 - fragments: FormattedText = FormattedText() - body = self._render_agent_prompt_body(columns) - if body: - fragments.extend(body) - if not body[-1][1].endswith("\n"): - fragments.append(("", "\n")) - if self._active_modal_delegate() is not None: - return fragments - fragments.append(("", "\n")) - fragments.append(("class:running-prompt-separator", "─" * max(0, columns))) - fragments.append(("", "\n")) - fragments.extend(self._render_agent_prompt_label()) - return fragments - - def _render_agent_prompt_body(self, columns: int) -> FormattedText: - delegate = self._active_prompt_delegate() - if delegate is None: - return self._render_status_block(columns) - return to_formatted_text(delegate.render_running_prompt_body(columns)) - - def _render_status_block(self, columns: int) -> FormattedText: - status_block_provider = getattr(self, "_status_block_provider", None) - if status_block_provider is None: - return FormattedText([]) - block = status_block_provider(columns) - if block is None: - return FormattedText([]) - return to_formatted_text(block) - - def _render_agent_prompt_label(self) -> FormattedText: - status = self._status_provider() - if status.plan_mode: - return FormattedText([(get_toolbar_colors().plan_prompt, f"{PROMPT_SYMBOL_PLAN} ")]) - symbol = PROMPT_SYMBOL_THINKING if self._thinking else PROMPT_SYMBOL - return FormattedText([("", f"{symbol} ")]) - - def __enter__(self) -> CustomPromptSession: - if self._status_refresh_task is not None and not self._status_refresh_task.done(): - return self - - async def _refresh() -> None: - try: - while True: - app = get_app_or_none() - if app is not None: - app.invalidate() - - try: - asyncio.get_running_loop() - except RuntimeError: - logger.warning("No running loop found, exiting status refresh task") - self._status_refresh_task = None - break - - interval = ( - _RUNNING_REFRESH_INTERVAL - if self._active_prompt_delegate() is not None - or ( - self._fast_refresh_provider is not None - and self._fast_refresh_provider() - ) - else _IDLE_REFRESH_INTERVAL - ) - await asyncio.sleep(interval) - except asyncio.CancelledError: - # graceful exit - pass - - self._status_refresh_task = asyncio.create_task(_refresh()) - return self - - def __exit__(self, *_) -> None: - if self._status_refresh_task is not None and not self._status_refresh_task.done(): - self._status_refresh_task.cancel() - self._status_refresh_task = None - - def _get_placeholder_manager(self) -> PromptPlaceholderManager: - manager = getattr(self, "_placeholder_manager", None) - if manager is None: - attachment_cache = getattr(self, "_attachment_cache", None) - manager = PromptPlaceholderManager(attachment_cache=attachment_cache) - self._placeholder_manager = manager - self._attachment_cache = manager.attachment_cache - return manager - - def _insert_pasted_text(self, buffer: Buffer, text: str) -> None: - normalized = normalize_pasted_text(text) - if self._mode != PromptMode.AGENT: - buffer.insert_text(normalized) - return - token_or_text = self._get_placeholder_manager().maybe_placeholderize_pasted_text(normalized) - buffer.insert_text(token_or_text) - - def _handle_bracketed_paste(self, event: KeyPressEvent) -> None: - self._insert_pasted_text(event.current_buffer, event.data) - event.app.invalidate() - - def _try_paste_media(self, event: KeyPressEvent) -> bool: - """Try to paste media from the clipboard. - - Reads the clipboard once and handles all detected content: - non-image files (videos, PDFs, etc.) are inserted as paths, - image files are cached and inserted as placeholders. - Returns True if any media content was inserted. - """ - result = grab_media_from_clipboard() - if result is None: - return False - - parts: list[str] = [] - - # 1. Insert file paths (videos, PDFs, etc.) - if result.file_paths: - logger.debug("Pasted {count} file path(s) from clipboard", count=len(result.file_paths)) - for p in result.file_paths: - text = str(p) - if self._mode == PromptMode.SHELL: - text = shlex.quote(text) - parts.append(text) - - # 2. Insert images via cache. - if result.images: - if "image_in" not in self._model_capabilities: - console.print( - "[yellow]Image input is not supported by the selected LLM model[/yellow]" - ) - else: - for image in result.images: - token = self._get_placeholder_manager().create_image_placeholder(image) - if token is None: - continue - logger.debug( - "Pasted image from clipboard placeholder: {token}, {image_size}", - token=token, - image_size=image.size, - ) - parts.append(token) - - if parts: - event.current_buffer.insert_text(" ".join(parts)) - event.app.invalidate() - return bool(parts) - - async def prompt_next(self) -> UserInput: - return await self._prompt_once(append_history=None) - - @property - def last_submission_was_running(self) -> bool: - return getattr(self, "_last_submission_was_running", False) - - def attach_running_prompt(self, delegate: RunningPromptDelegate) -> None: - current = getattr(self, "_running_prompt_delegate", None) - if current is delegate: - return - if current is None: - self._running_prompt_previous_mode = self._mode - self._running_prompt_delegate = delegate - self._mode = PromptMode.AGENT - self._apply_mode() - self.invalidate() - - def detach_running_prompt(self, delegate: RunningPromptDelegate) -> None: - if getattr(self, "_running_prompt_delegate", None) is not delegate: - return - previous_mode = getattr(self, "_running_prompt_previous_mode", None) - self._running_prompt_delegate = None - self._running_prompt_previous_mode = None - if previous_mode is not None: - self._mode = previous_mode - self._apply_mode() - self.invalidate() - - def attach_modal(self, delegate: RunningPromptDelegate) -> None: - modal_delegates: list[RunningPromptDelegate] | None = getattr( - self, "_modal_delegates", None - ) - if modal_delegates is None: - modal_delegates = [] - self._modal_delegates = modal_delegates - if delegate in modal_delegates: - return - modal_delegates.append(delegate) - self.invalidate() - - def detach_modal(self, delegate: RunningPromptDelegate) -> None: - modal_delegates = getattr(self, "_modal_delegates", None) - if not modal_delegates or delegate not in modal_delegates: - return - modal_delegates.remove(delegate) - self.invalidate() - - def running_prompt_accepts_submission(self) -> bool: - delegate = self._active_prompt_delegate() - if delegate is None: - return False - return delegate.running_prompt_accepts_submission() - - async def _prompt_once(self, *, append_history: bool | None) -> UserInput: - placeholder = None - if (delegate := self._active_prompt_delegate()) is not None: - placeholder = delegate.running_prompt_placeholder() - with patch_stdout(raw=True): - command = str(await self._session.prompt_async(placeholder=placeholder)).strip() - command = command.replace("\x00", "") # just in case null bytes are somehow inserted - # Sanitize UTF-16 surrogates that may come from Windows clipboard - command = sanitize_surrogates(command) - was_running = self.running_prompt_accepts_submission() - self._last_submission_was_running = was_running - if append_history is None: - append_history = not was_running - if append_history: - self._append_history_entry(command) - self._tip_rotation_index += 1 - return self._build_user_input(command) - - def _build_user_input(self, command: str) -> UserInput: - resolved = self._get_placeholder_manager().resolve_command(command) - - return UserInput( - mode=self._mode, - command=resolved.display_command, - resolved_command=resolved.resolved_text, - content=resolved.content, - ) - - def _append_history_entry(self, text: str) -> None: - safe_history_text = self._get_placeholder_manager().serialize_for_history(text).strip() - entry = _HistoryEntry(content=safe_history_text) - if not entry.content: - return - - # skip if same as last entry - if entry.content == self._last_history_content: - return - - try: - self._history_file.parent.mkdir(parents=True, exist_ok=True) - with self._history_file.open("a", encoding="utf-8") as f: - f.write(entry.model_dump_json(ensure_ascii=False) + "\n") - self._last_history_content = entry.content - except OSError as exc: - logger.warning( - "Failed to append user history entry: {file} ({error})", - file=self._history_file, - error=exc, - ) - - def _render_bottom_toolbar(self) -> FormattedText: - if ( - hasattr(self, "_session") - and self._should_show_slash_completion_menu() - and self._session.default_buffer.complete_state is not None - ): - return FormattedText([]) - app = get_app_or_none() - assert app is not None - columns = app.output.get_size().columns - - fragments: list[tuple[str, str]] = [] - tc = get_toolbar_colors() - - fragments.append((tc.separator, "─" * columns)) - fragments.append(("", "\n")) - - remaining = columns - - # Time-based tip rotation (every 30 s, independent of user submissions) - now = time.monotonic() - if now - self._last_tip_rotate_time >= _TIP_ROTATE_INTERVAL: - self._tip_rotation_index += 1 - self._last_tip_rotate_time = now - - # Status flags: yolo / plan - status = self._status_provider() - if status.yolo_enabled: - fragments.extend([(tc.yolo_label, "yolo"), ("", " ")]) - remaining -= 6 # "yolo" = 4, " " = 2 - if status.plan_mode: - fragments.extend([(tc.plan_label, "plan"), ("", " ")]) - remaining -= 6 - - # Mode indicator (agent / shell) + model name + thinking indicator. - # Degrade gracefully on narrow terminals: - # full: "agent (model-name ○)" → mid: "agent ○" → bare: "agent" - mode = str(self._mode) - if self._mode == PromptMode.AGENT and self._model_name: - thinking_dot = "●" if self._thinking else "○" - mode_full = f"{mode} ({self._model_name} {thinking_dot})" - mode_mid = f"{mode} {thinking_dot}" - if _display_width(mode_full) <= remaining - 2: - mode = mode_full - elif _display_width(mode_mid) <= remaining - 2: - mode = mode_mid - # else: keep bare mode name — model_name and dot are both dropped - fragments.extend([("", mode), ("", " ")]) - remaining -= _display_width(mode) + 2 - - # CWD (truncated from left) + git branch with status badge - # Degrade gracefully on narrow terminals: full → cwd-only → truncated cwd → skip - cwd = _truncate_left(_shorten_cwd(str(KaosPath.cwd())), _MAX_CWD_COLS) - branch = _get_git_branch() - if branch: - dirty, ahead, behind = _get_git_status() - branch = _truncate_right(branch, _MAX_BRANCH_COLS) - badge = _format_git_badge(branch, dirty, ahead, behind) - cwd_text = f"{cwd} {badge}" - else: - cwd_text = cwd - cwd_w = _display_width(cwd_text) - if cwd_w > remaining - 2: - cwd_text = cwd # drop badge - cwd_w = _display_width(cwd_text) - if cwd_w > remaining - 2: - cwd_text = _truncate_right(cwd, max(0, remaining - 2)) - cwd_w = _display_width(cwd_text) - if cwd_text and remaining >= cwd_w + 2: - fragments.extend([(tc.cwd, cwd_text), ("", " ")]) - remaining -= cwd_w + 2 - - # Active background bash task count - bg_count = ( - self._background_task_count_provider() if self._background_task_count_provider else 0 - ) - if bg_count > 0: - bg_text = f"⚙ bash: {bg_count}" - bg_width = _display_width(bg_text) - if remaining >= bg_width + 2: - fragments.extend([(tc.bg_tasks, bg_text), ("", " ")]) - remaining -= bg_width + 2 - - # Tips fill remaining space on line 1 - tip_text = self._get_two_rotating_tips() - if tip_text and _display_width(tip_text) > remaining: - tip_text = self._get_one_rotating_tip() - if tip_text and _display_width(tip_text) <= remaining: - fragments.append((tc.tip, tip_text)) - - # ── line 2: toast (left) + context (right) — always rendered ────── - fragments.append(("", "\n")) - - right_text = self._render_right_span(status) - right_width = _display_width(right_text) - - left_toast = _current_toast("left") - if left_toast is not None: - max_left = max(0, columns - right_width - 2) - if max_left > 0: - left_text = left_toast.message - if _display_width(left_text) > max_left: - left_text = _truncate_right(left_text, max_left) - left_width = _display_width(left_text) - fragments.append(("", left_text)) - else: - left_width = 0 - else: - left_width = 0 - - fragments.append(("", " " * max(0, columns - left_width - right_width))) - fragments.append(("", right_text)) - - return FormattedText(fragments) - - def _get_two_rotating_tips(self) -> str | None: - """Return a string with exactly 2 tips from the rotation, or fewer if not enough.""" - n = len(self._tips) - if n == 0: - return None - if n == 1: - return self._tips[0] - offset = self._tip_rotation_index % n - tip1 = self._tips[offset] - tip2 = self._tips[(offset + 1) % n] - return f"{tip1}{_TIP_SEPARATOR}{tip2}" - - def _get_one_rotating_tip(self) -> str | None: - """Return the single leading tip for the current rotation.""" - if not self._tips: - return None - return self._tips[self._tip_rotation_index % len(self._tips)] - - @staticmethod - def _render_right_span(status: StatusSnapshot) -> str: - current_toast = _current_toast("right") - if current_toast is None: - return format_context_status( - status.context_usage, - status.context_tokens, - status.max_context_tokens, - ) - return current_toast.message diff --git a/src/kimi_cli/ui/shell/question_panel.py b/src/kimi_cli/ui/shell/question_panel.py deleted file mode 100644 index 67a48f21e..000000000 --- a/src/kimi_cli/ui/shell/question_panel.py +++ /dev/null @@ -1,586 +0,0 @@ -from __future__ import annotations - -from collections.abc import Callable - -from prompt_toolkit import PromptSession -from prompt_toolkit.application.run_in_terminal import run_in_terminal -from prompt_toolkit.buffer import Buffer -from prompt_toolkit.document import Document -from prompt_toolkit.formatted_text import ANSI -from prompt_toolkit.key_binding import KeyPressEvent -from rich.console import Group, RenderableType -from rich.markup import escape -from rich.panel import Panel -from rich.text import Text - -from kimi_cli.ui.shell.console import console, render_to_ansi -from kimi_cli.ui.shell.keyboard import KeyEvent -from kimi_cli.utils.rich.markdown import Markdown -from kimi_cli.wire.types import QuestionRequest - -OTHER_OPTION_LABEL = "Other" - - -class QuestionRequestPanel: - """Renders structured questions for the user to answer interactively.""" - - def __init__(self, request: QuestionRequest): - self.request = request - self._current_question_index = 0 - self._answers: dict[str, str] = {} - self._saved_selections: dict[int, tuple[int, set[int]]] = {} - self._other_drafts: dict[int, str] = {} - self._selected_index = 0 - self._multi_selected: set[int] = set() - self._body_text: str = "" - self.has_expandable_content: bool = False - self._setup_current_question() - - def _setup_current_question(self) -> None: - q = self._current_question - self._options = [(o.label, o.description) for o in q.options] - other_label = q.other_label or OTHER_OPTION_LABEL - other_desc = q.other_description or "" - self._options.append((other_label, other_desc)) - idx = self._current_question_index - if idx in self._saved_selections: - saved_idx, saved_multi = self._saved_selections[idx] - self._selected_index = min(saved_idx, len(self._options) - 1) - self._multi_selected = saved_multi - elif q.question in self._answers: - answer = self._answers[q.question] - if q.multi_select: - answer_labels = [a.strip() for a in answer.split(", ")] - known_labels = {label for label, _ in self._options[:-1]} - self._multi_selected = set() - for i, (label, _) in enumerate(self._options[:-1]): - if label in answer_labels: - self._multi_selected.add(i) - if any(answer_label not in known_labels for answer_label in answer_labels): - self._multi_selected.add(len(self._options) - 1) - self._selected_index = min(self._multi_selected) if self._multi_selected else 0 - else: - for i, (label, _) in enumerate(self._options): - if label == answer: - self._selected_index = i - break - else: - self._selected_index = len(self._options) - 1 - self._multi_selected = set() - else: - self._selected_index = 0 - self._multi_selected = set() - self._recompute_body() - - def _recompute_body(self) -> None: - body = self._current_question.body - self._body_text = body.rstrip("\n") if body else "" - self.has_expandable_content = bool(self._body_text) - - @property - def _current_question(self): - return self.request.questions[self._current_question_index] - - @property - def is_other_selected(self) -> bool: - return self._selected_index == len(self._options) - 1 - - @property - def is_multi_select(self) -> bool: - return self._current_question.multi_select - - @property - def current_question_text(self) -> str: - return self._current_question.question - - def should_prompt_other_input(self) -> bool: - if not self.is_multi_select: - return self.is_other_selected - other_idx = len(self._options) - 1 - return other_idx in self._multi_selected - - def select_index(self, index: int) -> bool: - if not (0 <= index < len(self._options)): - return False - self._selected_index = index - return True - - def render(self, *, other_input_text: str | None = None) -> RenderableType: - q = self._current_question - lines: list[RenderableType] = [] - - if len(self.request.questions) > 1: - tab_parts: list[str] = [] - for i, qi in enumerate(self.request.questions): - label = escape(qi.header or f"Q{i + 1}") - if i == self._current_question_index: - icon, style = "\u25cf", "bold cyan" - elif qi.question in self._answers: - icon, style = "\u2713", "green" - else: - icon, style = "\u25cb", "grey50" - tab_parts.append(f"[{style}]({icon}) {label}[/{style}]") - lines.append(Text.from_markup(" ".join(tab_parts))) - lines.append(Text("")) - - lines.append(Text.from_markup(f"[yellow]? {escape(q.question)}[/yellow]")) - if q.multi_select: - lines.append(Text(" (SPACE to toggle, ENTER to submit)", style="dim italic")) - lines.append(Text("")) - - if self._body_text: - lines.append( - Text.from_markup( - "[bold cyan] \u25b6 Press ctrl-e to view full content[/bold cyan]" - ) - ) - lines.append(Text("")) - - show_inline_input = other_input_text is not None and self.is_other_selected - - for i, (label, description) in enumerate(self._options): - num = i + 1 - is_other = i == len(self._options) - 1 - if q.multi_select: - checked = "\u2713" if i in self._multi_selected else " " - prefix = f"\\[{checked}]" - if i == self._selected_index: - option_line = Text.from_markup(f"[cyan]{prefix} {escape(label)}[/cyan]") - else: - option_line = Text.from_markup(f"[grey50]{prefix} {escape(label)}[/grey50]") - else: - if i == self._selected_index: - if is_other and show_inline_input: - input_display = escape(other_input_text) if other_input_text else "" - option_line = Text.from_markup( - f"[cyan]\u2192 \\[{num}] {escape(label)}: {input_display}\u2588[/cyan]" - ) - else: - option_line = Text.from_markup( - f"[cyan]\u2192 \\[{num}] {escape(label)}[/cyan]" - ) - else: - option_line = Text.from_markup(f"[grey50] \\[{num}] {escape(label)}[/grey50]") - lines.append(option_line) - - if description and not (is_other and show_inline_input): - lines.append(Text(f" {description}", style="dim")) - - if show_inline_input: - lines.append(Text("")) - lines.append( - Text(" Type your answer, then press Enter to submit.", style="dim italic") - ) - elif len(self.request.questions) > 1: - lines.append(Text("")) - lines.append( - Text( - " \u25c4/\u25ba switch question " - "\u25b2/\u25bc select \u21b5 submit esc exit", - style="dim", - ) - ) - - return Panel( - Group(*lines), - border_style="bold cyan", - title="[bold cyan]? QUESTION[/bold cyan]", - title_align="left", - padding=(0, 1), - ) - - def save_other_draft(self, text: str) -> None: - if text: - self._other_drafts[self._current_question_index] = text - else: - self._other_drafts.pop(self._current_question_index, None) - - def get_other_draft(self) -> str: - return self._other_drafts.get(self._current_question_index, "") - - def go_to(self, index: int) -> None: - if index == self._current_question_index: - return - if not (0 <= index < len(self.request.questions)): - return - self._saved_selections[self._current_question_index] = ( - self._selected_index, - set(self._multi_selected), - ) - self._current_question_index = index - self._setup_current_question() - - def next_tab(self) -> None: - if self._current_question_index < len(self.request.questions) - 1: - self.go_to(self._current_question_index + 1) - - def prev_tab(self) -> None: - if self._current_question_index > 0: - self.go_to(self._current_question_index - 1) - - def move_up(self) -> None: - self._selected_index = (self._selected_index - 1) % len(self._options) - - def move_down(self) -> None: - self._selected_index = (self._selected_index + 1) % len(self._options) - - def toggle_select(self) -> None: - if not self.is_multi_select: - return - if self._selected_index in self._multi_selected: - self._multi_selected.discard(self._selected_index) - else: - self._multi_selected.add(self._selected_index) - - def submit(self) -> bool: - q = self._current_question - if q.multi_select: - other_idx = len(self._options) - 1 - if other_idx in self._multi_selected: - return False - selected_labels = [ - self._options[i][0] for i in sorted(self._multi_selected) if i < len(q.options) - ] - if not selected_labels: - return False - self._answers[q.question] = ", ".join(selected_labels) - else: - if self.is_other_selected: - return False - self._answers[q.question] = self._options[self._selected_index][0] - self._saved_selections.pop(self._current_question_index, None) - self._other_drafts.pop(self._current_question_index, None) - return self._advance() - - def submit_other(self, text: str) -> bool: - q = self._current_question - if q.multi_select: - other_idx = len(self._options) - 1 - selected_labels = [ - self._options[i][0] - for i in sorted(self._multi_selected) - if i < len(q.options) and i != other_idx - ] - if text: - selected_labels.append(text) - self._answers[q.question] = ", ".join(selected_labels) if selected_labels else text - else: - self._answers[q.question] = text - self._saved_selections.pop(self._current_question_index, None) - self._other_drafts.pop(self._current_question_index, None) - return self._advance() - - def _advance(self) -> bool: - total = len(self.request.questions) - if len(self._answers) >= total: - return True - for offset in range(1, total + 1): - idx = (self._current_question_index + offset) % total - if self.request.questions[idx].question not in self._answers: - self._current_question_index = idx - self._setup_current_question() - return False - return True - - def get_answers(self) -> dict[str, str]: - return self._answers - - def render_full_body(self) -> list[RenderableType]: - if not self._body_text: - return [] - return [Markdown(self._body_text)] - - -def show_question_body_in_pager(panel: QuestionRequestPanel) -> None: - with console.screen(), console.pager(styles=True): - console.print(Text.from_markup(f"[yellow]? {escape(panel.current_question_text)}[/yellow]")) - console.print() - for renderable in panel.render_full_body(): - console.print(renderable) - - -async def prompt_other_input(question_text: str) -> str: - console.print(Text.from_markup(f"\n[yellow]? {escape(question_text)}[/yellow]")) - console.print(Text(" Enter your answer:", style="dim")) - try: - session: PromptSession[str] = PromptSession() - return (await session.prompt_async(" > ")).strip() - except (EOFError, KeyboardInterrupt): - return "" - - -class QuestionPromptDelegate: - modal_priority = 10 - _KEY_MAP: dict[str, KeyEvent] = { - "up": KeyEvent.UP, - "down": KeyEvent.DOWN, - "left": KeyEvent.LEFT, - "right": KeyEvent.RIGHT, - "tab": KeyEvent.TAB, - "space": KeyEvent.SPACE, - "enter": KeyEvent.ENTER, - "escape": KeyEvent.ESCAPE, - "c-c": KeyEvent.ESCAPE, - "c-d": KeyEvent.ESCAPE, - "1": KeyEvent.NUM_1, - "2": KeyEvent.NUM_2, - "3": KeyEvent.NUM_3, - "4": KeyEvent.NUM_4, - "5": KeyEvent.NUM_5, - "6": KeyEvent.NUM_6, - } - - def __init__( - self, - panel: QuestionRequestPanel, - *, - on_advance: Callable[[], QuestionRequestPanel | None], - on_invalidate: Callable[[], None], - buffer_text_provider: Callable[[], str] | None = None, - text_expander: Callable[[str], str] | None = None, - ) -> None: - self._panel: QuestionRequestPanel | None = panel - self._awaiting_other_input = False - self._on_advance = on_advance - self._on_invalidate = on_invalidate - self._buffer_text_provider = buffer_text_provider - self._text_expander = text_expander - - @property - def panel(self) -> QuestionRequestPanel | None: - return self._panel - - def set_panel(self, panel: QuestionRequestPanel | None) -> None: - self._panel = panel - self._awaiting_other_input = False - - def _is_inline_other_active(self) -> bool: - return ( - self._panel is not None - and self._panel.is_other_selected - and self._buffer_text_provider is not None - and not self._panel.is_multi_select - ) - - def render_running_prompt_body(self, columns: int) -> ANSI: - if self._panel is None: - return ANSI("") - other_input_text: str | None = None - if self._is_inline_other_active(): - other_input_text = self._buffer_text_provider() if self._buffer_text_provider else "" - body = render_to_ansi( - self._panel.render(other_input_text=other_input_text), - columns=columns, - ).rstrip("\n") - return ANSI(body if body else "") - - def running_prompt_placeholder(self) -> str | None: - return None - - def running_prompt_allows_text_input(self) -> bool: - if self._awaiting_other_input: - return True - return self._is_inline_other_active() - - def running_prompt_hides_input_buffer(self) -> bool: - return self._panel is not None - - def running_prompt_accepts_submission(self) -> bool: - return self._panel is not None - - def should_handle_running_prompt_key(self, key: str) -> bool: - if self._panel is None: - return False - if key == "c-e": - return self._panel.has_expandable_content - if self._awaiting_other_input: - return key in {"enter", "escape", "c-c", "c-d"} - if self._is_inline_other_active(): - return key in {"enter", "escape", "c-c", "c-d", "up", "down", "left", "right", "tab"} - return key in { - "up", - "down", - "left", - "right", - "tab", - "space", - "enter", - "escape", - "c-c", - "c-d", - "1", - "2", - "3", - "4", - "5", - "6", - } - - def handle_running_prompt_key(self, key: str, event: KeyPressEvent) -> None: - if key == "c-e": - event.app.create_background_task(self._show_panel_in_pager()) - return - if self._awaiting_other_input: - if key == "enter": - self._submit_other_input(event.current_buffer) - else: - self._clear_buffer(event.current_buffer) - self._awaiting_other_input = False - if self._panel is not None: - self._panel.request.resolve({}) - self._advance() - self._on_invalidate() - return - - if self._is_inline_other_active(): - mapped = self._KEY_MAP.get(key) - if key == "enter" or mapped == KeyEvent.ENTER: - text = event.current_buffer.text.strip() - if text: - self._submit_other_input(event.current_buffer) - self._on_invalidate() - return - if mapped == KeyEvent.ESCAPE: - self._clear_buffer(event.current_buffer) - if self._panel is not None: - self._panel.request.resolve({}) - self._advance() - self._on_invalidate() - return - if mapped in {KeyEvent.UP, KeyEvent.DOWN, KeyEvent.LEFT, KeyEvent.RIGHT, KeyEvent.TAB}: - self._save_and_clear_buffer(event.current_buffer) - self._dispatch_keyboard_event(mapped) - self._restore_draft_to_buffer(event.current_buffer) - self._on_invalidate() - return - return - - mapped = self._KEY_MAP.get(key) - if mapped is None: - return - if mapped in {KeyEvent.ENTER, KeyEvent.SPACE} and self._should_prompt_other_for_key(mapped): - text = event.current_buffer.text.strip() - if text: - self._submit_other_input(event.current_buffer) - else: - self._clear_buffer(event.current_buffer) - self._awaiting_other_input = True - self._on_invalidate() - return - - if mapped == KeyEvent.ESCAPE: - if self._panel is not None: - self._panel.request.resolve({}) - self._advance() - self._on_invalidate() - return - - if self._panel is not None: - self._save_and_clear_buffer(event.current_buffer) - self._dispatch_keyboard_event(mapped) - self._restore_draft_to_buffer(event.current_buffer) - self._on_invalidate() - - def _should_prompt_other_for_key(self, key: KeyEvent) -> bool: - if self._panel is None or not self._panel.should_prompt_other_input(): - return False - return key == KeyEvent.ENTER or (key == KeyEvent.SPACE and not self._panel.is_multi_select) - - def _dispatch_keyboard_event(self, event: KeyEvent) -> None: - panel = self._panel - if panel is None: - return - match event: - case KeyEvent.UP: - panel.move_up() - case KeyEvent.DOWN: - panel.move_down() - case KeyEvent.LEFT: - panel.prev_tab() - case KeyEvent.RIGHT | KeyEvent.TAB: - panel.next_tab() - case KeyEvent.SPACE: - if panel.is_multi_select: - panel.toggle_select() - else: - self._try_submit() - case KeyEvent.ENTER: - self._try_submit() - case ( - KeyEvent.NUM_1 - | KeyEvent.NUM_2 - | KeyEvent.NUM_3 - | KeyEvent.NUM_4 - | KeyEvent.NUM_5 - | KeyEvent.NUM_6 - ): - num_map = { - KeyEvent.NUM_1: 0, - KeyEvent.NUM_2: 1, - KeyEvent.NUM_3: 2, - KeyEvent.NUM_4: 3, - KeyEvent.NUM_5: 4, - KeyEvent.NUM_6: 5, - } - idx = num_map[event] - if panel.select_index(idx): - if panel.is_multi_select: - panel.toggle_select() - elif not panel.is_other_selected: - self._try_submit() - case _: - pass - - def _try_submit(self) -> None: - if self._panel is None: - return - all_done = self._panel.submit() - if all_done: - self._panel.request.resolve(self._panel.get_answers()) - self._advance() - - def _submit_other_input(self, buffer: Buffer) -> None: - if self._panel is None: - self._clear_buffer(buffer) - self._awaiting_other_input = False - return - text = buffer.text.strip() - if self._text_expander is not None: - text = self._text_expander(text) - self._clear_buffer(buffer) - self._awaiting_other_input = False - all_done = self._panel.submit_other(text) - if all_done: - self._panel.request.resolve(self._panel.get_answers()) - self._advance() - - def _advance(self) -> None: - next_panel = self._on_advance() - self._panel = next_panel - self._awaiting_other_input = False - - def _save_and_clear_buffer(self, buffer: Buffer) -> None: - if self._panel is not None and buffer.text: - self._panel.save_other_draft(buffer.text) - self._clear_buffer(buffer) - - def _restore_draft_to_buffer(self, buffer: Buffer) -> None: - if self._is_inline_other_active() and self._panel is not None: - draft = self._panel.get_other_draft() - if draft: - buffer.set_document( - Document(text=draft, cursor_position=len(draft)), - bypass_readonly=True, - ) - - @staticmethod - def _clear_buffer(buffer: Buffer) -> None: - if buffer.text: - buffer.set_document(Document(text="", cursor_position=0), bypass_readonly=True) - - async def _show_panel_in_pager(self) -> None: - if self._panel is None: - return - panel = self._panel - await run_in_terminal(lambda: show_question_body_in_pager(panel)) - self._on_invalidate() diff --git a/src/kimi_cli/ui/shell/replay.py b/src/kimi_cli/ui/shell/replay.py deleted file mode 100644 index 175994be4..000000000 --- a/src/kimi_cli/ui/shell/replay.py +++ /dev/null @@ -1,210 +0,0 @@ -from __future__ import annotations - -import asyncio -import contextlib -from collections import deque -from collections.abc import Sequence -from dataclasses import dataclass -from typing import cast - -from kosong.message import ContentPart, Message -from kosong.tooling import ToolError, ToolOk - -from kimi_cli.notifications.llm import is_notification_message -from kimi_cli.soul.message import is_system_reminder_message -from kimi_cli.ui.shell.console import console -from kimi_cli.ui.shell.echo import render_user_echo -from kimi_cli.ui.shell.visualize import visualize -from kimi_cli.utils.aioqueue import QueueShutDown -from kimi_cli.utils.logging import logger -from kimi_cli.utils.message import message_stringify -from kimi_cli.utils.slashcmd import parse_slash_command_call -from kimi_cli.wire import Wire -from kimi_cli.wire.file import WireFile -from kimi_cli.wire.types import ( - Event, - StatusUpdate, - SteerInput, - StepBegin, - TextPart, - ToolResult, - TurnBegin, - is_event, -) - -MAX_REPLAY_TURNS = 5 - - -@dataclass(slots=True) -class _ReplayTurn: - user_message: Message - events: list[Event] - n_steps: int = 0 - - -async def replay_recent_history( - history: Sequence[Message], - *, - wire_file: WireFile | None = None, -) -> None: - """ - Replay the most recent user-initiated turns from the provided message history or wire file. - """ - if not history: - # if the context history is empty,either this is a new session - # or the context has been cleared - return - - start_idx = _find_replay_start(history) - history_turns = ( - [] if start_idx is None else _build_replay_turns_from_history(history[start_idx:]) - ) - turns = await _build_replay_turns_from_wire(wire_file) - if not turns or (history_turns and not _same_user_turns(turns, history_turns)): - turns = history_turns - if not turns: - return - - for turn in turns: - wire = Wire() - console.print(render_user_echo(turn.user_message)) - ui_task = asyncio.create_task( - visualize(wire.ui_side(merge=False), initial_status=StatusUpdate()) - ) - for event in turn.events: - wire.soul_side.send(event) - await asyncio.sleep(0) # yield to UI loop - wire.shutdown() - with contextlib.suppress(QueueShutDown): - await ui_task - - -async def _build_replay_turns_from_wire(wire_file: WireFile | None) -> list[_ReplayTurn]: - if wire_file is None or not wire_file.path.exists(): - return [] - - size = wire_file.path.stat().st_size - if size > 20 * 1024 * 1024: - logger.info( - "Wire file too large for replay, skipping: {file} ({size} bytes)", - file=wire_file.path, - size=size, - ) - return [] - - turns: deque[_ReplayTurn] = deque(maxlen=MAX_REPLAY_TURNS) - try: - async for record in wire_file.iter_records(): - wire_msg = record.to_wire_message() - - if isinstance(wire_msg, TurnBegin): - if _is_clear_command_input(wire_msg.user_input): - turns.clear() - continue - turns.append( - _ReplayTurn( - user_message=_message_from_user_input(wire_msg.user_input), - events=[], - ) - ) - continue - - if isinstance(wire_msg, SteerInput): - turns.append( - _ReplayTurn( - user_message=_message_from_user_input(wire_msg.user_input), - events=[], - ) - ) - continue - - if not is_event(wire_msg) or not turns: - continue - - current_turn = turns[-1] - if isinstance(wire_msg, StepBegin): - current_turn.n_steps = wire_msg.n - current_turn.events.append(wire_msg) - except Exception: - logger.exception("Failed to build replay turns from wire file {file}:", file=wire_file.path) - return [] - return list(turns) - - -def _message_from_user_input(user_input: str | list[ContentPart]) -> Message: - content = cast( - list[ContentPart], - list(user_input) if isinstance(user_input, list) else [TextPart(text=user_input)], - ) - return Message(role="user", content=content) - - -def _same_user_turns(lhs: Sequence[_ReplayTurn], rhs: Sequence[_ReplayTurn]) -> bool: - return [message_stringify(turn.user_message) for turn in lhs] == [ - message_stringify(turn.user_message) for turn in rhs - ] - - -def _is_clear_command_input(user_input: str | list[ContentPart]) -> bool: - if isinstance(user_input, list): - text = Message(role="user", content=user_input).extract_text(" ").strip() - else: - text = str(user_input).strip() - call = parse_slash_command_call(text) - if call is None: - return False - return call.name in {"clear", "reset"} - - -def _is_user_message(message: Message) -> bool: - # FIXME: should consider non-text tool call results which are sent as user messages - if message.role != "user": - return False - if message.extract_text().startswith("CHECKPOINT"): - return False - if is_notification_message(message): - return False - return not is_system_reminder_message(message) - - -def _find_replay_start(history: Sequence[Message]) -> int | None: - indices = [idx for idx, message in enumerate(history) if _is_user_message(message)] - if not indices: - return None - # only replay last MAX_REPLAY_TURNS messages - return indices[max(0, len(indices) - MAX_REPLAY_TURNS)] - - -def _build_replay_turns_from_history(history: Sequence[Message]) -> list[_ReplayTurn]: - turns: list[_ReplayTurn] = [] - current_turn: _ReplayTurn | None = None - for message in history: - if _is_user_message(message): - # start a new turn - if current_turn is not None: - turns.append(current_turn) - current_turn = _ReplayTurn(user_message=message, events=[]) - elif message.role == "assistant": - if current_turn is None: - continue - current_turn.n_steps += 1 - current_turn.events.append(StepBegin(n=current_turn.n_steps)) - current_turn.events.extend(message.content) - current_turn.events.extend(message.tool_calls or []) - elif message.role == "tool": - if current_turn is None: - continue - assert message.tool_call_id is not None - if any( - isinstance(part, TextPart) and part.text.startswith("ERROR") - for part in message.content - ): - result = ToolError(message="", output="", brief="") - else: - result = ToolOk(output=message.content) - current_turn.events.append( - ToolResult(tool_call_id=message.tool_call_id, return_value=result) - ) - if current_turn is not None: - turns.append(current_turn) - return turns diff --git a/src/kimi_cli/ui/shell/setup.py b/src/kimi_cli/ui/shell/setup.py deleted file mode 100644 index e44d398c4..000000000 --- a/src/kimi_cli/ui/shell/setup.py +++ /dev/null @@ -1,212 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, NamedTuple - -import aiohttp -from prompt_toolkit import PromptSession -from prompt_toolkit.shortcuts.choice_input import ChoiceInput -from pydantic import SecretStr - -from kimi_cli import logger -from kimi_cli.auth import KIMI_CODE_PLATFORM_ID -from kimi_cli.auth.platforms import ( - PLATFORMS, - ModelInfo, - Platform, - get_platform_by_name, - list_models, - managed_model_key, - managed_provider_key, -) -from kimi_cli.config import ( - LLMModel, - LLMProvider, - MoonshotFetchConfig, - MoonshotSearchConfig, - load_config, - save_config, -) -from kimi_cli.ui.shell.console import console -from kimi_cli.ui.shell.slash import registry - -if TYPE_CHECKING: - from kimi_cli.ui.shell import Shell - - -async def select_platform() -> Platform | None: - platform_name = await _prompt_choice( - header="Select a platform (↑↓ navigate, Enter select, Ctrl+C cancel):", - choices=[platform.name for platform in PLATFORMS], - ) - if not platform_name: - console.print("[red]No platform selected[/red]") - return None - - platform = get_platform_by_name(platform_name) - if platform is None: - console.print("[red]Unknown platform[/red]") - return None - return platform - - -async def setup_platform(platform: Platform) -> bool: - result = await _setup_platform(platform) - if not result: - # error message already printed - return False - - _apply_setup_result(result) - thinking_label = "on" if result.thinking else "off" - console.print("[green]✓ Setup complete![/green]") - console.print(f" Platform: [bold]{result.platform.name}[/bold]") - console.print(f" Model: [bold]{result.selected_model.id}[/bold]") - console.print(f" Thinking: [bold]{thinking_label}[/bold]") - console.print(" Reloading...") - return True - - -class _SetupResult(NamedTuple): - platform: Platform - api_key: SecretStr - selected_model: ModelInfo - models: list[ModelInfo] - thinking: bool - - -async def _setup_platform(platform: Platform) -> _SetupResult | None: - # enter the API key - api_key = await _prompt_text("Enter your API key", is_password=True) - if not api_key: - return None - - # list models - try: - with console.status("[cyan]Verifying API key...[/cyan]"): - models = await list_models(platform, api_key) - except aiohttp.ClientResponseError as e: - logger.error("Failed to get models: {error}", error=e) - console.print(f"[red]Failed to get models: {e.message}[/red]") - if e.status == 401 and platform.id != KIMI_CODE_PLATFORM_ID: - console.print( - "[yellow]Hint: If your API key was obtained from Kimi Code, " - 'please select "Kimi Code" instead.[/yellow]' - ) - return None - except Exception as e: - logger.error("Failed to get models: {error}", error=e) - console.print(f"[red]Failed to get models: {e}[/red]") - return None - - # select the model - if not models: - console.print("[red]No models available for the selected platform[/red]") - return None - - model_map = {model.id: model for model in models} - model_id = await _prompt_choice( - header="Select a model (↑↓ navigate, Enter select, Ctrl+C cancel):", - choices=list(model_map), - ) - if not model_id: - console.print("[red]No model selected[/red]") - return None - - selected_model = model_map[model_id] - - # Determine thinking mode based on model capabilities - capabilities = selected_model.capabilities - thinking: bool - - if "always_thinking" in capabilities: - thinking = True - elif "thinking" in capabilities: - thinking_selection = await _prompt_choice( - header="Enable thinking mode? (↑↓ navigate, Enter select, Ctrl+C cancel):", - choices=["on", "off"], - ) - if not thinking_selection: - return None - thinking = thinking_selection == "on" - else: - thinking = False - - return _SetupResult( - platform=platform, - api_key=SecretStr(api_key), - selected_model=selected_model, - models=models, - thinking=thinking, - ) - - -def _apply_setup_result(result: _SetupResult) -> None: - config = load_config() - provider_key = managed_provider_key(result.platform.id) - model_key = managed_model_key(result.platform.id, result.selected_model.id) - config.providers[provider_key] = LLMProvider( - type="kimi", - base_url=result.platform.base_url, - api_key=result.api_key, - ) - for key, model in list(config.models.items()): - if model.provider == provider_key: - del config.models[key] - for model_info in result.models: - capabilities = model_info.capabilities or None - config.models[managed_model_key(result.platform.id, model_info.id)] = LLMModel( - provider=provider_key, - model=model_info.id, - max_context_size=model_info.context_length, - capabilities=capabilities, - ) - config.default_model = model_key - config.default_thinking = result.thinking - - if result.platform.search_url: - config.services.moonshot_search = MoonshotSearchConfig( - base_url=result.platform.search_url, - api_key=result.api_key, - ) - - if result.platform.fetch_url: - config.services.moonshot_fetch = MoonshotFetchConfig( - base_url=result.platform.fetch_url, - api_key=result.api_key, - ) - - save_config(config) - - -async def _prompt_choice(*, header: str, choices: list[str]) -> str | None: - if not choices: - return None - - try: - return await ChoiceInput( - message=header, - options=[(choice, choice) for choice in choices], - default=choices[0], - ).prompt_async() - except (EOFError, KeyboardInterrupt): - return None - - -async def _prompt_text(prompt: str, *, is_password: bool = False) -> str | None: - session = PromptSession[str]() - try: - return str( - await session.prompt_async( - f" {prompt}: ", - is_password=is_password, - ) - ).strip() - except (EOFError, KeyboardInterrupt): - return None - - -@registry.command -def reload(app: Shell, args: str): - """Reload configuration""" - from kimi_cli.cli import Reload - - raise Reload diff --git a/src/kimi_cli/ui/shell/slash.py b/src/kimi_cli/ui/shell/slash.py deleted file mode 100644 index 2c9a101ea..000000000 --- a/src/kimi_cli/ui/shell/slash.py +++ /dev/null @@ -1,741 +0,0 @@ -from __future__ import annotations - -import asyncio -from collections.abc import Awaitable, Callable -from typing import TYPE_CHECKING, Any, cast - -from prompt_toolkit.shortcuts.choice_input import ChoiceInput - -from kimi_cli import logger -from kimi_cli.auth.platforms import get_platform_name_for_provider, refresh_managed_models -from kimi_cli.cli import Reload, SwitchToVis, SwitchToWeb -from kimi_cli.config import load_config, save_config -from kimi_cli.exception import ConfigError -from kimi_cli.session import Session -from kimi_cli.soul.kimisoul import KimiSoul -from kimi_cli.ui.shell.console import console -from kimi_cli.ui.shell.mcp_status import render_mcp_console -from kimi_cli.ui.shell.task_browser import TaskBrowserApp -from kimi_cli.utils.changelog import CHANGELOG -from kimi_cli.utils.datetime import format_relative_time -from kimi_cli.utils.slashcmd import SlashCommand, SlashCommandRegistry - -if TYPE_CHECKING: - from kimi_cli.ui.shell import Shell - -type ShellSlashCmdFunc = Callable[[Shell, str], None | Awaitable[None]] -""" -A function that runs as a Shell-level slash command. - -Raises: - Reload: When the configuration should be reloaded. -""" - - -registry = SlashCommandRegistry[ShellSlashCmdFunc]() -shell_mode_registry = SlashCommandRegistry[ShellSlashCmdFunc]() - - -def ensure_kimi_soul(app: Shell) -> KimiSoul | None: - if not isinstance(app.soul, KimiSoul): - console.print("[red]KimiSoul required[/red]") - return None - return app.soul - - -@registry.command(aliases=["quit"]) -@shell_mode_registry.command(aliases=["quit"]) -def exit(app: Shell, args: str): - """Exit the application""" - # should be handled by `Shell` - raise NotImplementedError - - -SKILL_COMMAND_PREFIX = "skill:" - -_KEYBOARD_SHORTCUTS = [ - ("Ctrl-X", "Toggle agent/shell mode"), - ("Shift-Tab", "Toggle plan mode (read-only research)"), - ("Ctrl-O", "Edit in external editor ($VISUAL/$EDITOR)"), - ("Ctrl-J / Alt-Enter", "Insert newline"), - ("Ctrl-V", "Paste (supports images)"), - ("Ctrl-D", "Exit"), - ("Ctrl-C", "Interrupt"), -] - - -@registry.command(aliases=["h", "?"]) -@shell_mode_registry.command(aliases=["h", "?"]) -def help(app: Shell, args: str): - """Show help information""" - from rich.console import Group, RenderableType - from rich.text import Text - - from kimi_cli.utils.rich.columns import BulletColumns - - def section(title: str, items: list[tuple[str, str]], color: str) -> BulletColumns: - lines: list[RenderableType] = [Text.from_markup(f"[bold]{title}:[/bold]")] - for name, desc in items: - lines.append( - BulletColumns( - Text.from_markup(f"[{color}]{name}[/{color}]: [grey50]{desc}[/grey50]"), - bullet_style=color, - ) - ) - return BulletColumns(Group(*lines)) - - renderables: list[RenderableType] = [] - renderables.append( - BulletColumns( - Group( - Text.from_markup("[grey50]Help! I need somebody. Help! Not just anybody.[/grey50]"), - Text.from_markup("[grey50]Help! You know I need someone. Help![/grey50]"), - Text.from_markup("[grey50]\u2015 The Beatles, [italic]Help![/italic][/grey50]"), - ), - bullet_style="grey50", - ) - ) - renderables.append( - BulletColumns( - Text( - "Sure, Kimi is ready to help! " - "Just send me messages and I will help you get things done!" - ), - ) - ) - - commands: list[SlashCommand[Any]] = [] - skills: list[SlashCommand[Any]] = [] - for cmd in app.available_slash_commands.values(): - if cmd.name.startswith(SKILL_COMMAND_PREFIX): - skills.append(cmd) - else: - commands.append(cmd) - - renderables.append(section("Keyboard shortcuts", _KEYBOARD_SHORTCUTS, "yellow")) - renderables.append( - section( - "Slash commands", - [(c.slash_name(), c.description) for c in sorted(commands, key=lambda c: c.name)], - "blue", - ) - ) - if skills: - renderables.append( - section( - "Skills", - [(c.slash_name(), c.description) for c in sorted(skills, key=lambda c: c.name)], - "cyan", - ) - ) - - with console.pager(styles=True): - console.print(Group(*renderables)) - - -@registry.command -@shell_mode_registry.command -def version(app: Shell, args: str): - """Show version information""" - from kimi_cli.constant import VERSION - - console.print(f"kimi, version {VERSION}") - - -@registry.command -async def model(app: Shell, args: str): - """Switch LLM model or thinking mode""" - from kimi_cli.llm import derive_model_capabilities - - soul = ensure_kimi_soul(app) - if soul is None: - return - config = soul.runtime.config - - await refresh_managed_models(config) - - if not config.models: - console.print('[yellow]No models configured, send "/login" to login.[/yellow]') - return - - if not config.is_from_default_location: - console.print( - "[yellow]Model switching requires the default config file; " - "restart without --config/--config-file.[/yellow]" - ) - return - - # Find current model/thinking from runtime (may be overridden by --model/--thinking) - curr_model_cfg = soul.runtime.llm.model_config if soul.runtime.llm else None - curr_model_name: str | None = None - if curr_model_cfg is not None: - for name, model_cfg in config.models.items(): - if model_cfg == curr_model_cfg: - curr_model_name = name - break - curr_thinking = soul.thinking - - # Step 1: Select model - model_choices: list[tuple[str, str]] = [] - for name in sorted(config.models): - model_cfg = config.models[name] - provider_label = get_platform_name_for_provider(model_cfg.provider) or model_cfg.provider - marker = " (current)" if name == curr_model_name else "" - label = f"{model_cfg.model} ({provider_label}){marker}" - model_choices.append((name, label)) - - try: - selected_model_name = await ChoiceInput( - message="Select a model (↑↓ navigate, Enter select, Ctrl+C cancel):", - options=model_choices, - default=curr_model_name or model_choices[0][0], - ).prompt_async() - except (EOFError, KeyboardInterrupt): - return - - if not selected_model_name: - return - - selected_model_cfg = config.models[selected_model_name] - selected_provider = config.providers.get(selected_model_cfg.provider) - if selected_provider is None: - console.print(f"[red]Provider not found: {selected_model_cfg.provider}[/red]") - return - - # Step 2: Determine thinking mode - capabilities = derive_model_capabilities(selected_model_cfg) - new_thinking: bool - - if "always_thinking" in capabilities: - new_thinking = True - elif "thinking" in capabilities: - thinking_choices: list[tuple[str, str]] = [ - ("off", "off" + (" (current)" if not curr_thinking else "")), - ("on", "on" + (" (current)" if curr_thinking else "")), - ] - try: - thinking_selection = await ChoiceInput( - message="Enable thinking mode? (↑↓ navigate, Enter select, Ctrl+C cancel):", - options=thinking_choices, - default="on" if curr_thinking else "off", - ).prompt_async() - except (EOFError, KeyboardInterrupt): - return - - if not thinking_selection: - return - - new_thinking = thinking_selection == "on" - else: - new_thinking = False - - # Check if anything changed - model_changed = curr_model_name != selected_model_name - thinking_changed = curr_thinking != new_thinking - - if not model_changed and not thinking_changed: - console.print( - f"[yellow]Already using {selected_model_name} " - f"with thinking {'on' if new_thinking else 'off'}.[/yellow]" - ) - return - - # Save and reload - prev_model = config.default_model - prev_thinking = config.default_thinking - config.default_model = selected_model_name - config.default_thinking = new_thinking - try: - config_for_save = load_config() - config_for_save.default_model = selected_model_name - config_for_save.default_thinking = new_thinking - save_config(config_for_save) - except (ConfigError, OSError) as exc: - config.default_model = prev_model - config.default_thinking = prev_thinking - console.print(f"[red]Failed to save config: {exc}[/red]") - return - - console.print( - f"[green]Switched to {selected_model_name} " - f"with thinking {'on' if new_thinking else 'off'}. " - "Reloading...[/green]" - ) - raise Reload(session_id=soul.runtime.session.id) - - -@registry.command -@shell_mode_registry.command -async def editor(app: Shell, args: str): - """Set default external editor for Ctrl-O""" - from kimi_cli.utils.editor import get_editor_command - - soul = ensure_kimi_soul(app) - if soul is None: - return - config = soul.runtime.config - config_file = config.source_file - if config_file is None: - console.print( - "[yellow]Editor switching is unavailable with inline --config; " - "use --config-file to persist this setting.[/yellow]" - ) - return - - current_editor = config.default_editor - - # If args provided directly, use as editor command - if args.strip(): - new_editor = args.strip() - else: - options: list[tuple[str, str]] = [ - ("code --wait", "VS Code (code --wait)"), - ("vim", "Vim"), - ("nano", "Nano"), - ("", "Auto-detect (use $VISUAL/$EDITOR)"), - ] - # Mark current selection - options = [ - (val, label + (" ← current" if val == current_editor else "")) for val, label in options - ] - - try: - choice = cast( - str | None, - await ChoiceInput( - message="Select an editor (↑↓ navigate, Enter select, Ctrl+C cancel):", - options=options, - default=( - current_editor - if current_editor in {v for v, _ in options} - else "code --wait" - ), - ).prompt_async(), - ) - except (EOFError, KeyboardInterrupt): - return - - if choice is None: - return - new_editor = choice - - # Validate the editor binary is available - if new_editor: - import shlex - import shutil - - try: - parts = shlex.split(new_editor) - except ValueError: - console.print(f"[red]Invalid editor command: {new_editor}[/red]") - return - - binary = parts[0] - if not shutil.which(binary): - console.print( - f"[yellow]Warning: '{binary}' not found in PATH. " - f"Saving anyway — make sure it's installed before using Ctrl-O.[/yellow]" - ) - - if new_editor == current_editor: - console.print(f"[yellow]Editor is already set to: {new_editor or 'auto-detect'}[/yellow]") - return - - # Save to disk - try: - config_for_save = load_config(config_file) - config_for_save.default_editor = new_editor - save_config(config_for_save, config_file) - except (ConfigError, OSError) as exc: - console.print(f"[red]Failed to save config: {exc}[/red]") - return - - # Sync in-memory config so Ctrl-O picks it up immediately - config.default_editor = new_editor - - if new_editor: - console.print(f"[green]Editor set to: {new_editor}[/green]") - else: - resolved = get_editor_command() - label = " ".join(resolved) if resolved else "none" - console.print(f"[green]Editor set to auto-detect (resolved: {label})[/green]") - - -@registry.command(aliases=["release-notes"]) -@shell_mode_registry.command(aliases=["release-notes"]) -def changelog(app: Shell, args: str): - """Show release notes""" - from rich.console import Group, RenderableType - from rich.text import Text - - from kimi_cli.utils.rich.columns import BulletColumns - - renderables: list[RenderableType] = [] - for ver, entry in CHANGELOG.items(): - title = f"[bold]{ver}[/bold]" - if entry.description: - title += f": {entry.description}" - - lines: list[RenderableType] = [Text.from_markup(title)] - for item in entry.entries: - if item.lower().startswith("lib:"): - continue - lines.append( - BulletColumns( - Text.from_markup(f"[grey50]{item}[/grey50]"), - bullet_style="grey50", - ), - ) - renderables.append(BulletColumns(Group(*lines))) - - with console.pager(styles=True): - console.print(Group(*renderables)) - - -@registry.command -@shell_mode_registry.command -async def feedback(app: Shell, args: str): - """Submit feedback to make Kimi Code CLI better""" - import platform - import webbrowser - - import aiohttp - - from kimi_cli.auth import KIMI_CODE_PLATFORM_ID - from kimi_cli.auth.platforms import get_platform_by_id, managed_provider_key - from kimi_cli.constant import VERSION - from kimi_cli.ui.shell.oauth import current_model_key - from kimi_cli.utils.aiohttp import new_client_session - - ISSUE_URL = "https://github.com/MoonshotAI/kimi-cli/issues" - - def _fallback_to_issues(): - if not webbrowser.open(ISSUE_URL): - console.print(f"Please submit feedback at [underline]{ISSUE_URL}[/underline].") - - soul = ensure_kimi_soul(app) - if soul is None: - _fallback_to_issues() - return - - kimi_platform = get_platform_by_id(KIMI_CODE_PLATFORM_ID) - if kimi_platform is None: - _fallback_to_issues() - return - - provider = soul.runtime.config.providers.get(managed_provider_key(KIMI_CODE_PLATFORM_ID)) - if provider is None or provider.oauth is None: - _fallback_to_issues() - return - - from prompt_toolkit import PromptSession - - prompt_session: PromptSession[str] = PromptSession() - try: - content = await prompt_session.prompt_async("Enter your feedback: ") - except (EOFError, KeyboardInterrupt): - console.print("[grey50]Feedback cancelled.[/grey50]") - return - - content = content.strip() - if not content: - console.print("[yellow]Feedback cannot be empty.[/yellow]") - return - - api_key = soul.runtime.oauth.resolve_api_key(provider.api_key, provider.oauth) - feedback_url = f"{kimi_platform.base_url.rstrip('/')}/feedback" - - payload = { - "session_id": soul.runtime.session.id, - "content": content, - "version": VERSION, - "os": f"{platform.system()} {platform.release()}", - "model": current_model_key(soul), - } - - with console.status("[cyan]Submitting feedback...[/cyan]"): - try: - async with ( - new_client_session() as session, - session.post( - feedback_url, - json=payload, - headers={ - "Authorization": f"Bearer {api_key}", - **(provider.custom_headers or {}), - }, - raise_for_status=True, - ), - ): - pass - session_id = soul.runtime.session.id - console.print( - f"[green]Feedback submitted, thank you! Your session ID is: {session_id}[/green]" - ) - except TimeoutError: - console.print("[red]Feedback submission timed out.[/red]") - _fallback_to_issues() - except aiohttp.ClientError as e: - status = getattr(e, "status", None) - if status: - msg = f"Failed to submit feedback (HTTP {status})." - else: - msg = "Network error, failed to submit feedback." - console.print(f"[red]{msg}[/red]") - _fallback_to_issues() - - -@registry.command(aliases=["reset"]) -async def clear(app: Shell, args: str): - """Clear the context""" - if ensure_kimi_soul(app) is None: - return - await app.run_soul_command("/clear") - raise Reload() - - -@registry.command -async def new(app: Shell, args: str): - """Start a new session""" - soul = ensure_kimi_soul(app) - if soul is None: - return - current_session = soul.runtime.session - work_dir = current_session.work_dir - # Clean up the current session if it has no content, so that chaining - # /new commands (or switching away before the first message) does not - # leave orphan empty session directories on disk. - if current_session.is_empty(): - await current_session.delete() - session = await Session.create(work_dir) - console.print("[green]New session created. Switching...[/green]") - raise Reload(session_id=session.id) - - -@registry.command(name="title", aliases=["rename"]) -async def title(app: Shell, args: str): - """Set or show the session title""" - soul = ensure_kimi_soul(app) - if soul is None: - return - session = soul.runtime.session - if not args.strip(): - console.print(f"Session title: [bold]{session.title}[/bold]") - return - - from kimi_cli.session_state import load_session_state, save_session_state - - new_title = args.strip()[:200] - # Read-modify-write: load fresh state to avoid overwriting concurrent web changes - fresh = load_session_state(session.dir) - fresh.custom_title = new_title - fresh.title_generated = True - save_session_state(fresh, session.dir) - session.state.custom_title = new_title - session.state.title_generated = True - session.title = new_title - console.print(f"[green]Session title set to: {new_title}[/green]") - - -@registry.command(name="sessions", aliases=["resume"]) -async def list_sessions(app: Shell, args: str): - """List sessions and resume optionally""" - soul = ensure_kimi_soul(app) - if soul is None: - return - - work_dir = soul.runtime.session.work_dir - current_session = soul.runtime.session - current_session_id = current_session.id - sessions = [ - session for session in await Session.list(work_dir) if session.id != current_session_id - ] - - await current_session.refresh() - sessions.insert(0, current_session) - - choices: list[tuple[str, str]] = [] - for session in sessions: - time_str = format_relative_time(session.updated_at) - marker = " (current)" if session.id == current_session_id else "" - label = f"{session.title} ({session.id}), {time_str}{marker}" - choices.append((session.id, label)) - - try: - selection = await ChoiceInput( - message="Select a session to switch to (↑↓ navigate, Enter select, Ctrl+C cancel):", - options=choices, - default=choices[0][0], - ).prompt_async() - except (EOFError, KeyboardInterrupt): - return - - if not selection: - return - - if selection == current_session_id: - console.print("[yellow]You are already in this session.[/yellow]") - return - - console.print(f"[green]Switching to session {selection}...[/green]") - raise Reload(session_id=selection) - - -@registry.command(name="task") -@shell_mode_registry.command(name="task") -async def task(app: Shell, args: str): - """Browse and manage background tasks""" - soul = ensure_kimi_soul(app) - if soul is None: - return - if args.strip(): - console.print('[yellow]Usage: "/task" opens the interactive task browser.[/yellow]') - return - if soul.runtime.role != "root": - console.print("[yellow]Background tasks are only available from the root agent.[/yellow]") - return - - await TaskBrowserApp(soul).run() - - -@registry.command -@shell_mode_registry.command -def theme(app: Shell, args: str): - """Switch terminal color theme (dark/light)""" - from kimi_cli.ui.theme import get_active_theme - - soul = ensure_kimi_soul(app) - if soul is None: - return - - current = get_active_theme() - arg = args.strip().lower() - - if not arg: - console.print(f"Current theme: [bold]{current}[/bold]") - console.print("[grey50]Usage: /theme dark | /theme light[/grey50]") - return - - if arg not in ("dark", "light"): - console.print(f"[red]Unknown theme: {arg}. Use 'dark' or 'light'.[/red]") - return - - if arg == current: - console.print(f"[yellow]Already using {arg} theme.[/yellow]") - return - - config_file = soul.runtime.config.source_file - if config_file is None: - console.print( - "[yellow]Theme switching requires a config file; " - "restart without --config to persist this setting.[/yellow]" - ) - return - - # Persist to disk first — only update in-memory state after success - try: - config_for_save = load_config(config_file) - config_for_save.theme = arg # type: ignore[assignment] - save_config(config_for_save, config_file) - except (ConfigError, OSError) as exc: - console.print(f"[red]Failed to save config: {exc}[/red]") - return - - console.print(f"[green]Switched to {arg} theme. Reloading...[/green]") - raise Reload(session_id=soul.runtime.session.id) - - -@registry.command -def web(app: Shell, args: str): - """Open Kimi Code Web UI in browser""" - soul = ensure_kimi_soul(app) - session_id = soul.runtime.session.id if soul else None - raise SwitchToWeb(session_id=session_id) - - -@registry.command -def vis(app: Shell, args: str): - """Open Kimi Agent Tracing Visualizer in browser""" - soul = ensure_kimi_soul(app) - session_id = soul.runtime.session.id if soul else None - raise SwitchToVis(session_id=session_id) - - -@registry.command -async def mcp(app: Shell, args: str): - """Show MCP servers and tools""" - from rich.live import Live - - soul = ensure_kimi_soul(app) - if soul is None: - return - await soul.start_background_mcp_loading() - snapshot = soul.status.mcp_status - if snapshot is None: - console.print("[yellow]No MCP servers configured.[/yellow]") - return - - if not snapshot.loading: - console.print(render_mcp_console(snapshot)) - return - - with Live( - render_mcp_console(snapshot), - console=console, - refresh_per_second=8, - transient=False, - ) as live: - while True: - snapshot = soul.status.mcp_status - if snapshot is None: - break - live.update(render_mcp_console(snapshot), refresh=True) - if not snapshot.loading: - break - await asyncio.sleep(0.125) - try: - await soul.wait_for_background_mcp_loading() - except Exception as e: - logger.debug("MCP loading completed with error while rendering /mcp: {error}", error=e) - snapshot = soul.status.mcp_status - if snapshot is not None: - live.update(render_mcp_console(snapshot), refresh=True) - - -@registry.command -@shell_mode_registry.command -def hooks(app: Shell, args: str): - """List configured hooks""" - soul = ensure_kimi_soul(app) - if soul is None: - return - - engine = soul.hook_engine - if not engine.summary: - console.print( - "[yellow]No hooks configured. " - "Add [[hooks]] sections to your config.toml to set up hooks.[/yellow]" - ) - return - - console.print() - console.print("[bold]Configured Hooks:[/bold]") - console.print() - - for event, entries in engine.details().items(): - console.print(f" [cyan]{event}[/cyan]: {len(entries)} hook(s)") - for entry in entries: - source_tag = f" [dim]({entry['source']})[/dim]" if entry["source"] == "wire" else "" - console.print(f" [dim]{entry['matcher']}[/dim] {entry['command']}{source_tag}") - - console.print() - - -from . import ( # noqa: E402 - debug, # noqa: F401 # type: ignore[reportUnusedImport] - export_import, # noqa: F401 # type: ignore[reportUnusedImport] - oauth, # noqa: F401 # type: ignore[reportUnusedImport] - setup, # noqa: F401 # type: ignore[reportUnusedImport] - update, # noqa: F401 # type: ignore[reportUnusedImport] - usage, # noqa: F401 # type: ignore[reportUnusedImport] -) diff --git a/src/kimi_cli/ui/shell/slash.ts b/src/kimi_cli/ui/shell/slash.ts new file mode 100644 index 000000000..ddcc79dc0 --- /dev/null +++ b/src/kimi_cli/ui/shell/slash.ts @@ -0,0 +1,169 @@ +/** + * Shell slash commands — corresponds to Python's ui/shell/slash.py. + * Shell-level commands: /clear, /help, /exit, /theme, /version. + */ + +import type { SlashCommand, CommandPanelConfig } from "../../types"; +import { getActiveTheme } from "../theme.ts"; + +export type SlashCommandHandler = (args: string) => Promise; + +export interface ShellSlashContext { + clearMessages: () => void; + exit: () => void; + setTheme: (theme: "dark" | "light") => void; + getAllCommands: () => SlashCommand[]; + pushNotification: (title: string, body: string) => void; +} + +/** + * Create shell-level slash commands. + */ +export function createShellSlashCommands( + ctx: ShellSlashContext, +): SlashCommand[] { + return [ + { + name: "clear", + description: "Clear conversation history", + aliases: ["cls"], + handler: async () => { + ctx.clearMessages(); + }, + }, + { + name: "exit", + description: "Exit the application", + aliases: ["quit", "q"], + handler: async () => { + ctx.exit(); + }, + }, + { + name: "help", + description: "Show help information", + aliases: ["h", "?"], + handler: async () => { + // Fallback when panel is not used (e.g. direct /help invocation) + const allCmds = ctx.getAllCommands(); + ctx.pushNotification("Help", formatHelp(allCmds)); + }, + panel: (): CommandPanelConfig => { + const allCmds = ctx.getAllCommands(); + return { + type: "content", + title: "Help", + content: formatHelp(allCmds), + }; + }, + }, + { + name: "theme", + description: "Toggle dark/light theme", + handler: async (args: string) => { + const theme = args.trim() as "dark" | "light"; + if (theme === "dark" || theme === "light") { + ctx.setTheme(theme); + ctx.pushNotification("Theme", `Switched to ${theme} theme.`); + } else { + // Toggle + const current = getActiveTheme(); + const next = current === "dark" ? "light" : "dark"; + ctx.setTheme(next); + ctx.pushNotification("Theme", `Switched to ${next} theme.`); + } + }, + panel: (): CommandPanelConfig => { + const current = getActiveTheme(); + return { + type: "choice", + title: "Theme", + items: [ + { label: "🌙 Dark", value: "dark", current: current === "dark" }, + { label: "☀️ Light", value: "light", current: current === "light" }, + ], + onSelect: (value: string) => { + const theme = value as "dark" | "light"; + ctx.setTheme(theme); + ctx.pushNotification("Theme", `Switched to ${theme} theme.`); + }, + }; + }, + }, + { + name: "version", + description: "Show version information", + handler: async () => { + ctx.pushNotification("Version", "kimi-cli v2.0.0 (TypeScript)"); + }, + }, + ]; +} + +/** + * Parse a slash command from input string. + * Returns null if not a slash command. + */ +export function parseSlashCommand( + input: string, +): { name: string; args: string } | null { + if (!input.startsWith("/")) return null; + const trimmed = input.slice(1).trim(); + if (!trimmed) return null; + const spaceIdx = trimmed.indexOf(" "); + if (spaceIdx === -1) { + return { name: trimmed, args: "" }; + } + return { + name: trimmed.slice(0, spaceIdx), + args: trimmed.slice(spaceIdx + 1).trim(), + }; +} + +/** + * Find a slash command by name or alias. + */ +export function findSlashCommand( + commands: SlashCommand[], + name: string, +): SlashCommand | undefined { + return commands.find( + (cmd) => cmd.name === name || cmd.aliases?.includes(name), + ); +} + +function formatHelp(commands: SlashCommand[]): string { + const lines = [ + "Kimi Code CLI — Help", + "", + "Keyboard Shortcuts:", + " Ctrl+X Toggle agent/shell mode", + " Shift+Tab Toggle plan mode", + " Ctrl+O Edit in external editor", + " Ctrl+J / Alt+Enter Insert newline", + " Ctrl+V Paste (supports images)", + " Ctrl+D Exit", + " Ctrl+C Interrupt", + "", + "Slash Commands:", + ]; + + // Deduplicate by name and sort + const seen = new Set(); + const sorted = commands + .filter((c) => { + if (seen.has(c.name)) return false; + seen.add(c.name); + return true; + }) + .sort((a, b) => a.name.localeCompare(b.name)); + + for (const cmd of sorted) { + const aliases = cmd.aliases?.length ? `, /${cmd.aliases.join(", /")}` : ""; + const nameStr = `/${cmd.name}${aliases}`; + lines.push(` ${nameStr.padEnd(22)} ${cmd.description}`); + } + + lines.push(""); + return lines.join("\n"); +} diff --git a/src/kimi_cli/ui/shell/startup.py b/src/kimi_cli/ui/shell/startup.py deleted file mode 100644 index 0babee8b3..000000000 --- a/src/kimi_cli/ui/shell/startup.py +++ /dev/null @@ -1,32 +0,0 @@ -from __future__ import annotations - -from rich.status import Status - -from kimi_cli.ui.shell.console import console - - -class ShellStartupProgress: - """Transient startup status shown while the shell is initializing.""" - - def __init__(self, *, enabled: bool | None = None) -> None: - self._enabled = console.is_terminal if enabled is None else enabled - self._status: Status | None = None - - def update(self, message: str) -> None: - if not self._enabled: - return - - status_message = f"[cyan]{message}[/cyan]" - if self._status is None: - self._status = console.status(status_message, spinner="dots") - self._status.start() - return - - self._status.update(status_message) - - def stop(self) -> None: - if self._status is None: - return - - self._status.stop() - self._status = None diff --git a/src/kimi_cli/ui/shell/task_browser.py b/src/kimi_cli/ui/shell/task_browser.py deleted file mode 100644 index b03f51f04..000000000 --- a/src/kimi_cli/ui/shell/task_browser.py +++ /dev/null @@ -1,486 +0,0 @@ -import time -from dataclasses import dataclass, field -from typing import Literal - -from prompt_toolkit.application import Application -from prompt_toolkit.application.run_in_terminal import run_in_terminal -from prompt_toolkit.filters import Condition -from prompt_toolkit.formatted_text import StyleAndTextTuples -from prompt_toolkit.key_binding import KeyBindings, KeyPressEvent -from prompt_toolkit.layout import HSplit, Layout, VSplit, Window -from prompt_toolkit.layout.controls import FormattedTextControl -from prompt_toolkit.styles import Style -from prompt_toolkit.widgets import Box, Frame, RadioList -from rich.console import Group -from rich.panel import Panel -from rich.text import Text - -from kimi_cli.background import TaskView, is_terminal_status -from kimi_cli.soul.kimisoul import KimiSoul -from kimi_cli.ui.shell.console import console -from kimi_cli.utils.datetime import format_duration, format_relative_time - -TaskBrowserFilter = Literal["all", "active"] - -_EMPTY_TASK_ID = "__empty__" -_PREVIEW_MAX_LINES = 6 -_PREVIEW_MAX_BYTES = 4_000 -_FULL_OUTPUT_MAX_BYTES = 200_000 -_FULL_OUTPUT_MAX_LINES = 4_000 -_AUTO_REFRESH_SECONDS = 1.0 -_FLASH_MESSAGE_SECONDS = 3.0 - - -def format_task_choice(view: TaskView, *, now: float | None = None) -> str: - description = view.spec.description.strip() or "(no description)" - return " · ".join( - [ - f"[{view.runtime.status}]", - description, - view.spec.id, - view.spec.kind, - _task_timing_label(view, now=now) or "updated just now", - ] - ) - - -@dataclass(slots=True) -class TaskBrowserModel: - soul: KimiSoul - filter_mode: TaskBrowserFilter = "all" - message: str = "" - message_expires_at: float | None = None - pending_stop_task_id: str | None = None - all_views: list[TaskView] = field(default_factory=lambda: []) - visible_views: list[TaskView] = field(default_factory=lambda: []) - - @property - def manager(self): - return self.soul.runtime.background_tasks - - @property - def config(self): - return self.soul.runtime.config.background - - def refresh(self, selected_task_id: str | None = None) -> tuple[list[tuple[str, str]], str]: - self.manager.reconcile() - self.all_views = self.manager.list_tasks(limit=None) - self.all_views.sort(key=_task_sort_key) - - if self.filter_mode == "active": - self.visible_views = [ - view for view in self.all_views if not is_terminal_status(view.runtime.status) - ] - else: - self.visible_views = list(self.all_views) - - if not self.visible_views: - label = ( - "No active background tasks." - if self.filter_mode == "active" - else "No background tasks in this session." - ) - self.pending_stop_task_id = None - return [(_EMPTY_TASK_ID, label)], _EMPTY_TASK_ID - - values = [(view.spec.id, format_task_choice(view)) for view in self.visible_views] - valid_ids = {task_id for task_id, _label in values} - selected = selected_task_id if selected_task_id in valid_ids else values[0][0] - - if self.pending_stop_task_id not in valid_ids: - self.pending_stop_task_id = None - return values, selected - - def view_for(self, task_id: str | None) -> TaskView | None: - if not task_id or task_id == _EMPTY_TASK_ID: - return None - for view in self.visible_views: - if view.spec.id == task_id: - return view - return self.manager.get_task(task_id) - - def set_message(self, text: str, *, duration_s: float = _FLASH_MESSAGE_SECONDS) -> None: - self.message = text - self.message_expires_at = time.time() + duration_s - - def current_message(self) -> str | None: - if not self.message: - return None - if self.message_expires_at is None: - return self.message - if time.time() > self.message_expires_at: - self.message = "" - self.message_expires_at = None - return None - return self.message - - def summary_fragments(self) -> StyleAndTextTuples: - counts = { - "running": 0, - "starting": 0, - "failed": 0, - "completed": 0, - "killed": 0, - "lost": 0, - } - for view in self.all_views: - counts[view.runtime.status] = counts.get(view.runtime.status, 0) + 1 - - scope = "ALL" if self.filter_mode == "all" else "ACTIVE" - return [ - ("class:header.title", " TASK BROWSER "), - ("class:header.meta", f" filter={scope} "), - ("class:status.running", f" {counts['running']} running "), - ("class:status.info", f" {counts['starting']} starting "), - ("class:status.error", f" {counts['failed']} failed "), - ("class:status.success", f" {counts['completed']} completed "), - ("class:status.warning", f" {counts['killed'] + counts['lost']} interrupted "), - ("class:header.meta", f" {len(self.all_views)} total "), - ] - - def detail_text(self, task_id: str | None) -> str: - view = self.view_for(task_id) - if view is None: - return "Select a task from the list." - - terminal_reason = "timed_out" if view.runtime.timed_out else view.runtime.status - lines = [ - f"Task ID: {view.spec.id}", - f"Status: {view.runtime.status}", - f"Description: {view.spec.description}", - f"Kind: {view.spec.kind}", - ] - timing = _task_timing_label(view) - if timing: - lines.append(f"Time: {timing}") - if view.spec.cwd: - lines.append(f"Cwd: {view.spec.cwd}") - if view.spec.command: - lines.append(f"Command: {view.spec.command}") - if view.runtime.exit_code is not None: - lines.append(f"Exit code: {view.runtime.exit_code}") - lines.append(f"Terminal reason: {terminal_reason}") - if view.runtime.failure_reason: - lines.append(f"Reason: {view.runtime.failure_reason}") - return "\n".join(lines) - - def preview_text(self, task_id: str | None) -> str: - view = self.view_for(task_id) - if view is None: - return "No output to preview." - - preview = self.manager.tail_output( - view.spec.id, - max_bytes=_PREVIEW_MAX_BYTES, - max_lines=_PREVIEW_MAX_LINES, - ) - if not preview: - return "[no output available]" - return preview - - def full_output(self, task_id: str | None) -> str: - view = self.view_for(task_id) - if view is None: - return "[no output available]" - - path = self.manager.resolve_output_path(view.spec.id) - total_size = path.stat().st_size if path.exists() else 0 - output = self.manager.tail_output( - view.spec.id, - max_bytes=max(self.config.read_max_bytes * 10, _FULL_OUTPUT_MAX_BYTES), - max_lines=_FULL_OUTPUT_MAX_LINES, - ) - max_bytes = max(self.config.read_max_bytes * 10, _FULL_OUTPUT_MAX_BYTES) - if total_size > max_bytes: - return ( - f"[showing last {max_bytes} bytes of {total_size} bytes]\n\n" - f"{output or '[no output available]'}" - ) - return output or "[no output available]" - - def footer_fragments(self, task_id: str | None) -> StyleAndTextTuples: - if self.pending_stop_task_id is not None: - label = self.pending_stop_task_id - return [ - ("class:footer.warning", f" Confirm stop {label}? "), - ("class:footer.key", "Y"), - ("class:footer.text", " confirm "), - ("class:footer.key", "N"), - ("class:footer.text", " cancel "), - ] - - fragments: StyleAndTextTuples = [ - ("class:footer.key", " Enter "), - ("class:footer.text", "output "), - ("class:footer.key", "S"), - ("class:footer.text", " stop "), - ("class:footer.key", "R"), - ("class:footer.text", " refresh "), - ("class:footer.key", "Tab"), - ("class:footer.text", " filter "), - ("class:footer.key", "Q"), - ("class:footer.text", " exit "), - ("class:footer.meta", f" auto-refresh {_AUTO_REFRESH_SECONDS:.0f}s "), - ] - if message := self.current_message(): - fragments.extend( - [ - ("class:footer.meta", " | "), - ("class:footer.flash", f" {message} "), - ] - ) - return fragments - - -class TaskBrowserApp: - def __init__(self, soul: KimiSoul): - self._model = TaskBrowserModel(soul) - task_values, selected = self._model.refresh() - self._task_list = RadioList( - values=task_values, - default=selected, - show_numbers=False, - select_on_focus=True, - open_character="", - select_character=">", - close_character="", - show_cursor=False, - show_scrollbar=False, - container_style="class:task-list", - checked_style="class:task-list.checked", - ) - self._app = self._build_app() - - async def run(self) -> None: - await self._app.run_async() - - @property - def _selected_task_id(self) -> str | None: - current = self._task_list.current_value - if current == _EMPTY_TASK_ID: - return None - return current - - def _open_output(self, app: Application[object], task_id: str) -> None: - app.create_background_task(self._show_output_in_terminal(task_id)) - - async def _show_output_in_terminal(self, task_id: str) -> None: - def render() -> None: - view = self._model.view_for(task_id) - if view is None: - console.print(f"[yellow]Task not found: {task_id}[/yellow]") - return - with console.pager(styles=True): - console.print(_build_full_output_renderable(view, self._model.full_output(task_id))) - - await run_in_terminal(render) - - def _toggle_filter(self) -> None: - self._model.filter_mode = "active" if self._model.filter_mode == "all" else "all" - self._model.set_message( - "Showing active tasks only." - if self._model.filter_mode == "active" - else "Showing all tasks." - ) - self._sync_views() - - def _refresh_views(self) -> None: - self._model.set_message("Refreshed.") - self._sync_views() - - def _request_stop_for_selected_task(self) -> None: - view = self._model.view_for(self._selected_task_id) - if view is None: - self._model.set_message("No task selected.") - elif is_terminal_status(view.runtime.status): - self._model.set_message(f"Task {view.spec.id} is already {view.runtime.status}.") - else: - self._model.pending_stop_task_id = view.spec.id - self._model.message = "" - self._model.message_expires_at = None - - def _confirm_stop_request(self) -> None: - task_id = self._model.pending_stop_task_id - self._model.pending_stop_task_id = None - if task_id is None: - return - view = self._model.view_for(task_id) - if view is None: - self._model.set_message(f"Task not found: {task_id}") - elif is_terminal_status(view.runtime.status): - self._model.set_message(f"Task {task_id} is already {view.runtime.status}.") - else: - self._model.manager.kill(task_id) - self._model.set_message(f"Stop requested for task {task_id}.") - self._sync_views() - - def _cancel_stop_request(self) -> None: - self._model.pending_stop_task_id = None - self._model.set_message("Stop cancelled.") - - def _build_app(self) -> Application[None]: - kb = KeyBindings() - - @Condition - def stop_pending() -> bool: - return self._model.pending_stop_task_id is not None - - @kb.add("q") - @kb.add("escape", filter=~stop_pending) - @kb.add("c-c") - def _exit(event: KeyPressEvent) -> None: - event.app.exit() - - @kb.add("tab", filter=~stop_pending) - def _toggle_filter(event: KeyPressEvent) -> None: - self._toggle_filter() - event.app.invalidate() - - @kb.add("r", filter=~stop_pending) - def _refresh(event: KeyPressEvent) -> None: - self._refresh_views() - event.app.invalidate() - - @kb.add("s", filter=~stop_pending) - def _stop(event: KeyPressEvent) -> None: - self._request_stop_for_selected_task() - event.app.invalidate() - - @kb.add("y", filter=stop_pending) - def _confirm_stop(event: KeyPressEvent) -> None: - self._confirm_stop_request() - event.app.invalidate() - - @kb.add("n", filter=stop_pending) - @kb.add("escape", filter=stop_pending) - def _cancel_stop(event: KeyPressEvent) -> None: - self._cancel_stop_request() - event.app.invalidate() - - @kb.add("enter", filter=~stop_pending, eager=True) - @kb.add("o", filter=~stop_pending) - def _show_output(event: KeyPressEvent) -> None: - task_id = self._selected_task_id - if task_id is None: - self._model.set_message("No task selected.") - event.app.invalidate() - return - self._open_output(event.app, task_id) - - # Handlers are registered via @kb.add decorators above; mark as accessed. - _ = (_exit, _toggle_filter, _refresh, _stop, _confirm_stop, _cancel_stop, _show_output) - - body = VSplit( - [ - Frame( - Box(self._task_list, padding=1), - title=lambda: f" Tasks [{self._model.filter_mode}] ", - ), - HSplit( - [ - Frame( - Window( - FormattedTextControl(self._detail_fragments), - wrap_lines=True, - ), - title=" Detail ", - ), - Frame( - Window( - FormattedTextControl(self._preview_fragments), - wrap_lines=True, - ), - title=" Preview Output ", - ), - ] - ), - ] - ) - footer = Window( - FormattedTextControl(self._footer_fragments), - height=1, - style="class:footer", - ) - header = Window( - FormattedTextControl(self._header_fragments), - height=1, - style="class:header", - ) - - return Application( - layout=Layout( - HSplit( - [ - header, - body, - footer, - ] - ), - focused_element=self._task_list, - ), - key_bindings=kb, - full_screen=True, - erase_when_done=True, - style=_task_browser_style(), - refresh_interval=_AUTO_REFRESH_SECONDS, - before_render=lambda _app: self._sync_views(), - ) - - def _sync_views(self) -> None: - values, selected = self._model.refresh(self._selected_task_id) - self._task_list.values = values - self._task_list.current_value = selected - self._task_list.current_values = [selected] - for index, (value, _label) in enumerate(values): - if value == selected: - self._task_list._selected_index = index # pyright: ignore[reportPrivateUsage] - break - - def _header_fragments(self) -> StyleAndTextTuples: - return self._model.summary_fragments() - - def _detail_fragments(self) -> StyleAndTextTuples: - return [("", self._model.detail_text(self._selected_task_id))] - - def _preview_fragments(self) -> StyleAndTextTuples: - return [("", self._model.preview_text(self._selected_task_id))] - - def _footer_fragments(self) -> StyleAndTextTuples: - return self._model.footer_fragments(self._selected_task_id) - - -def _build_full_output_renderable(view: TaskView, output: str) -> Panel: - return Panel( - Group( - Text(f"Task ID: {view.spec.id}", style="bold"), - Text(f"Status: {view.runtime.status}"), - Text(f"Description: {view.spec.description}"), - Text(""), - Text(output), - ), - title="Background Task Output", - border_style="cyan", - ) - - -def _task_sort_key(view: TaskView) -> tuple[int, float]: - if not is_terminal_status(view.runtime.status): - return (0, view.spec.created_at) - finished_at = view.runtime.finished_at or view.runtime.updated_at or view.spec.created_at - return (1, -finished_at) - - -def _task_timing_label(view: TaskView, *, now: float | None = None) -> str | None: - current = now if now is not None else time.time() - if view.runtime.finished_at is not None: - return f"finished {format_relative_time(view.runtime.finished_at)}" - if view.runtime.started_at is not None: - seconds = max(0, int(current - view.runtime.started_at)) - return f"running {format_duration(seconds)}" - return f"updated {format_relative_time(view.runtime.updated_at)}" - - -def _task_browser_style() -> Style: - from kimi_cli.ui.theme import get_task_browser_style - - return get_task_browser_style() diff --git a/src/kimi_cli/ui/shell/update.py b/src/kimi_cli/ui/shell/update.py deleted file mode 100644 index bb7e686b9..000000000 --- a/src/kimi_cli/ui/shell/update.py +++ /dev/null @@ -1,217 +0,0 @@ -from __future__ import annotations - -import asyncio -import os -import platform -import re -import shutil -import stat -import tarfile -import tempfile -from enum import Enum, auto -from pathlib import Path - -import aiohttp - -from kimi_cli.share import get_share_dir -from kimi_cli.ui.shell.console import console -from kimi_cli.utils.aiohttp import new_client_session -from kimi_cli.utils.logging import logger - -BASE_URL = "https://cdn.kimi.com/binaries/kimi-cli" -LATEST_VERSION_URL = f"{BASE_URL}/latest" -INSTALL_DIR = Path.home() / ".local" / "bin" - -# Upgrade command shown in toast notifications. Can be overridden by wrappers -UPGRADE_COMMAND = "uv tool upgrade kimi-cli" - - -class UpdateResult(Enum): - UPDATE_AVAILABLE = auto() - UPDATED = auto() - UP_TO_DATE = auto() - FAILED = auto() - UNSUPPORTED = auto() - - -_UPDATE_LOCK = asyncio.Lock() - - -def semver_tuple(version: str) -> tuple[int, int, int]: - v = version.strip() - if v.startswith("v"): - v = v[1:] - match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", v) - if not match: - return (0, 0, 0) - major = int(match.group(1)) - minor = int(match.group(2)) - patch = int(match.group(3) or 0) - return (major, minor, patch) - - -def _detect_target() -> str | None: - sys_name = platform.system() - mach = platform.machine() - if mach in ("x86_64", "amd64", "AMD64"): - arch = "x86_64" - elif mach in ("arm64", "aarch64"): - arch = "aarch64" - else: - logger.error("Unsupported architecture: {mach}", mach=mach) - return None - if sys_name == "Darwin": - os_name = "apple-darwin" - elif sys_name == "Linux": - os_name = "unknown-linux-gnu" - else: - logger.error("Unsupported OS: {sys_name}", sys_name=sys_name) - return None - return f"{arch}-{os_name}" - - -async def _get_latest_version(session: aiohttp.ClientSession) -> str | None: - try: - async with session.get(LATEST_VERSION_URL) as resp: - resp.raise_for_status() - data = await resp.text() - return data.strip() - except (TimeoutError, aiohttp.ClientError): - logger.exception("Failed to get latest version:") - return None - - -async def do_update(*, print: bool = True, check_only: bool = False) -> UpdateResult: - async with _UPDATE_LOCK: - return await _do_update(print=print, check_only=check_only) - - -LATEST_VERSION_FILE = get_share_dir() / "latest_version.txt" - - -async def _do_update(*, print: bool, check_only: bool) -> UpdateResult: - from kimi_cli.constant import VERSION as current_version - - def _print(message: str) -> None: - if print: - console.print(message) - - target = _detect_target() - if not target: - _print("[red]Failed to detect target platform.[/red]") - return UpdateResult.UNSUPPORTED - - # Version check is fast, but the binary download can be large on slow links. - download_timeout = aiohttp.ClientTimeout(total=600, sock_read=60, sock_connect=15) - async with new_client_session(timeout=download_timeout) as session: - logger.info("Checking for updates...") - _print("Checking for updates...") - latest_version = await _get_latest_version(session) - if not latest_version: - _print("[red]Failed to check for updates.[/red]") - return UpdateResult.FAILED - - logger.debug("Latest version: {latest_version}", latest_version=latest_version) - LATEST_VERSION_FILE.write_text(latest_version, encoding="utf-8") - - cur_t = semver_tuple(current_version) - lat_t = semver_tuple(latest_version) - - if cur_t >= lat_t: - logger.debug("Already up to date: {current_version}", current_version=current_version) - _print("[green]Already up to date.[/green]") - return UpdateResult.UP_TO_DATE - - if check_only: - logger.info( - "Update available: current={current_version}, latest={latest_version}", - current_version=current_version, - latest_version=latest_version, - ) - _print(f"[yellow]Update available: {latest_version}[/yellow]") - return UpdateResult.UPDATE_AVAILABLE - - logger.info( - "Updating from {current_version} to {latest_version}...", - current_version=current_version, - latest_version=latest_version, - ) - _print(f"Updating from {current_version} to {latest_version}...") - - filename = f"kimi-{latest_version}-{target}.tar.gz" - download_url = f"{BASE_URL}/{latest_version}/{filename}" - - with tempfile.TemporaryDirectory(prefix="kimi-cli-") as tmpdir: - tar_path = os.path.join(tmpdir, filename) - - logger.info("Downloading from {download_url}...", download_url=download_url) - _print("[grey50]Downloading...[/grey50]") - try: - async with session.get(download_url) as resp: - resp.raise_for_status() - with open(tar_path, "wb") as f: - async for chunk in resp.content.iter_chunked(1024 * 64): - if chunk: - f.write(chunk) - except (TimeoutError, aiohttp.ClientError): - logger.exception( - "Failed to download update from {download_url}", - download_url=download_url, - ) - _print("[red]Failed to download.[/red]") - return UpdateResult.FAILED - except Exception: - logger.exception("Failed to download:") - _print("[red]Failed to download.[/red]") - return UpdateResult.FAILED - - logger.info("Extracting archive {tar_path}...", tar_path=tar_path) - _print("[grey50]Extracting...[/grey50]") - try: - with tarfile.open(tar_path, "r:gz") as tar: - tar.extractall(tmpdir) - binary_path = None - for root, _, files in os.walk(tmpdir): - if "kimi" in files: - binary_path = os.path.join(root, "kimi") - break - if not binary_path: - logger.error("Binary 'kimi' not found in archive.") - _print("[red]Binary 'kimi' not found in archive.[/red]") - return UpdateResult.FAILED - except Exception: - logger.exception("Failed to extract archive:") - _print("[red]Failed to extract archive.[/red]") - return UpdateResult.FAILED - - INSTALL_DIR.mkdir(parents=True, exist_ok=True) - dest_path = INSTALL_DIR / "kimi" - logger.info("Installing to {dest_path}...", dest_path=dest_path) - _print("[grey50]Installing...[/grey50]") - - try: - shutil.copy2(binary_path, dest_path) - os.chmod( - dest_path, - os.stat(dest_path).st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH, - ) - except Exception: - logger.exception("Failed to install:") - _print("[red]Failed to install.[/red]") - return UpdateResult.FAILED - - _print("[green]Updated successfully![/green]") - _print("[yellow]Restart Kimi Code CLI to use the new version.[/yellow]") - return UpdateResult.UPDATED - - -# @meta_command -# async def update(app: "Shell", args: list[str]): -# """Check for updates""" -# await do_update(print=True) - - -# @meta_command(name="check-update") -# async def check_update(app: "Shell", args: list[str]): -# """Check for updates""" -# await do_update(print=True, check_only=True) diff --git a/src/kimi_cli/ui/shell/usage.py b/src/kimi_cli/ui/shell/usage.py deleted file mode 100644 index a7fff2a09..000000000 --- a/src/kimi_cli/ui/shell/usage.py +++ /dev/null @@ -1,281 +0,0 @@ -"""This file is pure vibe-coded. If any bugs are found, let's just rewrite it...""" - -from __future__ import annotations - -from collections.abc import Mapping, Sequence -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, cast - -import aiohttp -from rich.console import Group, RenderableType -from rich.panel import Panel -from rich.progress_bar import ProgressBar -from rich.table import Table -from rich.text import Text - -from kimi_cli.auth import KIMI_CODE_PLATFORM_ID -from kimi_cli.auth.platforms import get_platform_by_id, parse_managed_provider_key -from kimi_cli.config import LLMModel -from kimi_cli.soul.kimisoul import KimiSoul -from kimi_cli.ui.shell.console import console -from kimi_cli.ui.shell.slash import registry -from kimi_cli.utils.aiohttp import new_client_session -from kimi_cli.utils.datetime import format_duration - -if TYPE_CHECKING: - from kimi_cli.ui.shell import Shell - - -@dataclass(slots=True, frozen=True) -class UsageRow: - label: str - used: int - limit: int - reset_hint: str | None = None - - -@registry.command(aliases=["/status"]) -async def usage(app: Shell, args: str): - """Display API usage and quota information""" - assert isinstance(app.soul, KimiSoul) - if app.soul.runtime.llm is None: - console.print("[red]LLM not set. Please run /login first.[/red]") - return - - provider = app.soul.runtime.llm.provider_config - if provider is None: - console.print("[red]LLM provider configuration not found.[/red]") - return - - usage_url = _usage_url(app.soul.runtime.llm.model_config) - if usage_url is None: - console.print("[yellow]Usage is available on Kimi Code platform only.[/yellow]") - return - - with console.status("[cyan]Fetching usage...[/cyan]"): - api_key = app.soul.runtime.oauth.resolve_api_key(provider.api_key, provider.oauth) - try: - payload = await _fetch_usage(usage_url, api_key) - except aiohttp.ClientResponseError as e: - message = "Failed to fetch usage." - if e.status == 401: - message = "Authorization failed. Please check your API key." - elif e.status == 404: - message = "Usage endpoint not available. Try Kimi For Coding." - console.print(f"[red]{message}[/red]") - return - except TimeoutError: - console.print("[red]Failed to fetch usage: request timed out.[/red]") - return - except aiohttp.ClientError as e: - console.print(f"[red]Failed to fetch usage: {e}[/red]") - return - - summary, limits = _parse_usage_payload(payload) - if summary is None and not limits: - console.print("[yellow]No usage data available.[/yellow]") - return - - console.print(_build_usage_panel(summary, limits)) - - -def _usage_url(model: LLMModel | None) -> str | None: - if model is None: - return None - platform_id = parse_managed_provider_key(model.provider) - if platform_id is None: - return None - platform = get_platform_by_id(platform_id) - if platform is None or platform.id != KIMI_CODE_PLATFORM_ID: - return None - base_url = platform.base_url.rstrip("/") - return f"{base_url}/usages" - - -async def _fetch_usage(url: str, api_key: str) -> Mapping[str, Any]: - async with ( - new_client_session() as session, - session.get( - url, - headers={"Authorization": f"Bearer {api_key}"}, - raise_for_status=True, - ) as resp, - ): - return await resp.json() - - -def _parse_usage_payload( - payload: Mapping[str, Any], -) -> tuple[UsageRow | None, list[UsageRow]]: - summary: UsageRow | None = None - limits: list[UsageRow] = [] - - usage = payload.get("usage") - if isinstance(usage, Mapping): - usage_map: Mapping[str, Any] = cast(Mapping[str, Any], usage) - summary = _to_usage_row(usage_map, default_label="Weekly limit") - - raw_limits_obj = payload.get("limits") - if isinstance(raw_limits_obj, Sequence): - limits_seq: Sequence[Any] = cast(Sequence[Any], raw_limits_obj) - for idx, item in enumerate(limits_seq): - if not isinstance(item, Mapping): - continue - item_map: Mapping[str, Any] = cast(Mapping[str, Any], item) - detail_raw = item_map.get("detail") - detail: Mapping[str, Any] = ( - cast(Mapping[str, Any], detail_raw) if isinstance(detail_raw, Mapping) else item_map - ) - # window may contain duration/timeUnit - window_raw = item_map.get("window") - window: Mapping[str, Any] = ( - cast(Mapping[str, Any], window_raw) if isinstance(window_raw, Mapping) else {} - ) - label = _limit_label(item_map, detail, window, idx) - row = _to_usage_row(detail, default_label=label) - if row: - limits.append(row) - - return summary, limits - - -def _to_usage_row(data: Mapping[str, Any], *, default_label: str) -> UsageRow | None: - limit = _to_int(data.get("limit")) - # Support both "used" and "remaining" (used = limit - remaining) - used = _to_int(data.get("used")) - if used is None: - remaining = _to_int(data.get("remaining")) - if remaining is not None and limit is not None: - used = limit - remaining - if used is None and limit is None: - return None - return UsageRow( - label=str(data.get("name") or data.get("title") or default_label), - used=used or 0, - limit=limit or 0, - reset_hint=_reset_hint(data), - ) - - -def _limit_label( - item: Mapping[str, Any], - detail: Mapping[str, Any], - window: Mapping[str, Any], - idx: int, -) -> str: - # Try to extract a human-readable label - for key in ("name", "title", "scope"): - if val := (item.get(key) or detail.get(key)): - return str(val) - - # Convert duration to readable format (e.g., 300 minutes -> "5h quota") - # Check window first, then item, then detail - duration = _to_int(window.get("duration") or item.get("duration") or detail.get("duration")) - time_unit = window.get("timeUnit") or item.get("timeUnit") or detail.get("timeUnit") or "" - if duration: - if "MINUTE" in time_unit: - if duration >= 60 and duration % 60 == 0: - return f"{duration // 60}h limit" - return f"{duration}m limit" - if "HOUR" in time_unit: - return f"{duration}h limit" - if "DAY" in time_unit: - return f"{duration}d limit" - return f"{duration}s limit" - - return f"Limit #{idx + 1}" - - -def _reset_hint(data: Mapping[str, Any]) -> str | None: - for key in ("reset_at", "resetAt", "reset_time", "resetTime"): - if val := data.get(key): - return _format_reset_time(str(val)) - - for key in ("reset_in", "resetIn", "ttl", "window"): - seconds = _to_int(data.get(key)) - if seconds: - return f"resets in {format_duration(seconds)}" - - return None - - -def _format_reset_time(val: str) -> str: - """Format ISO timestamp to a readable duration.""" - from datetime import UTC, datetime - - try: - # Parse ISO format like "2025-12-23T05:24:18.443553353Z" - # Truncate nanoseconds to microseconds for Python compatibility - if "." in val and val.endswith("Z"): - base, frac = val[:-1].split(".") - frac = frac[:6] # Keep only microseconds - val = f"{base}.{frac}Z" - dt = datetime.fromisoformat(val.replace("Z", "+00:00")) - now = datetime.now(UTC) - delta = dt - now - - if delta.total_seconds() <= 0: - return "reset" - return f"resets in {format_duration(int(delta.total_seconds()))}" - except (ValueError, TypeError): - return f"resets at {val}" - - -def _to_int(value: Any) -> int | None: - try: - return int(value) - except (TypeError, ValueError): - return None - - -def _build_usage_panel(summary: UsageRow | None, limits: list[UsageRow]) -> Panel: - rows = ([summary] if summary else []) + limits - if not rows: - return Panel( - Text("No usage data", style="grey50"), title="API Usage", border_style="wheat4" - ) - - # Calculate label width for alignment - label_width = max(len(r.label) for r in rows) - label_width = max(label_width, 6) # minimum width - - lines: list[RenderableType] = [] - for row in rows: - lines.append(_format_row(row, label_width)) - - return Panel( - Group(*lines), - title="API Usage", - border_style="wheat4", - padding=(0, 2), - expand=False, - ) - - -def _format_row(row: UsageRow, label_width: int) -> RenderableType: - ratio = (row.limit - row.used) / row.limit if row.limit > 0 else 0 - color = _ratio_color(ratio) - - label = Text(f"{row.label:<{label_width}} ", style="cyan") - bar = ProgressBar(total=row.limit or 1, completed=row.used, width=20, complete_style=color) - - detail = Text() - percent = ratio * 100 - detail.append(f" {percent:.0f}% left", style="bold") - if row.reset_hint: - detail.append(f" ({row.reset_hint})", style="grey50") - - t = Table.grid(padding=0) - t.add_column(width=label_width + 2) - t.add_column(width=20) - t.add_column() - t.add_row(label, bar, detail) - return t - - -def _ratio_color(ratio: float) -> str: - if ratio >= 0.9: - return "red" - if ratio >= 0.7: - return "yellow" - return "green" diff --git a/src/kimi_cli/ui/shell/visualize.py b/src/kimi_cli/ui/shell/visualize.py deleted file mode 100644 index f63fea1f5..000000000 --- a/src/kimi_cli/ui/shell/visualize.py +++ /dev/null @@ -1,1497 +0,0 @@ -from __future__ import annotations - -import asyncio -import json -import time -from collections import deque -from collections.abc import Awaitable, Callable -from contextlib import asynccontextmanager, suppress -from typing import TYPE_CHECKING, Any, NamedTuple, cast - -if TYPE_CHECKING: - from markdown_it import MarkdownIt - -import streamingjson # type: ignore[reportMissingTypeStubs] -from kosong.message import Message -from kosong.tooling import ToolError, ToolOk -from prompt_toolkit.application.run_in_terminal import run_in_terminal -from prompt_toolkit.buffer import Buffer -from prompt_toolkit.document import Document -from prompt_toolkit.formatted_text import ANSI -from prompt_toolkit.key_binding import KeyPressEvent -from rich.console import Group, RenderableType -from rich.live import Live -from rich.panel import Panel -from rich.spinner import Spinner -from rich.style import Style -from rich.text import Text - -from kimi_cli.soul import format_context_status, format_token_count -from kimi_cli.tools import extract_key_argument -from kimi_cli.ui.shell.approval_panel import ( - ApprovalPromptDelegate as ApprovalPromptDelegate, # noqa: F401 — re-exported -) -from kimi_cli.ui.shell.approval_panel import ( - ApprovalRequestPanel, - show_approval_in_pager, -) -from kimi_cli.ui.shell.console import console, render_to_ansi -from kimi_cli.ui.shell.echo import render_user_echo, render_user_echo_text -from kimi_cli.ui.shell.keyboard import KeyboardListener, KeyEvent -from kimi_cli.ui.shell.prompt import ( - CustomPromptSession, - UserInput, -) -from kimi_cli.ui.shell.question_panel import ( - QuestionPromptDelegate as QuestionPromptDelegate, # noqa: F401 — re-exported -) -from kimi_cli.ui.shell.question_panel import ( - QuestionRequestPanel, - prompt_other_input, - show_question_body_in_pager, -) -from kimi_cli.utils.aioqueue import Queue, QueueShutDown -from kimi_cli.utils.logging import logger -from kimi_cli.utils.rich.columns import BulletColumns -from kimi_cli.utils.rich.diff_render import ( - collect_diff_hunks, - render_diff_panel, - render_diff_summary_panel, -) -from kimi_cli.utils.rich.markdown import Markdown -from kimi_cli.wire import WireUISide -from kimi_cli.wire.types import ( - ApprovalRequest, - ApprovalResponse, - BackgroundTaskDisplayBlock, - BriefDisplayBlock, - CompactionBegin, - CompactionEnd, - ContentPart, - DiffDisplayBlock, - MCPLoadingBegin, - MCPLoadingEnd, - Notification, - PlanDisplay, - QuestionRequest, - StatusUpdate, - SteerInput, - StepBegin, - StepInterrupted, - SubagentEvent, - TextPart, - ThinkPart, - TodoDisplayBlock, - ToolCall, - ToolCallPart, - ToolCallRequest, - ToolResult, - ToolReturnValue, - TurnBegin, - TurnEnd, - WireMessage, -) - -MAX_SUBAGENT_TOOL_CALLS_TO_SHOW = 4 -MAX_LIVE_NOTIFICATIONS = 4 -EXTERNAL_MESSAGE_GRACE_S = 0.1 - - -async def visualize( - wire: WireUISide, - *, - initial_status: StatusUpdate, - cancel_event: asyncio.Event | None = None, - prompt_session: CustomPromptSession | None = None, - steer: Callable[[str | list[ContentPart]], None] | None = None, - bind_running_input: Callable[[Callable[[UserInput], None], Callable[[], None]], None] - | None = None, - unbind_running_input: Callable[[], None] | None = None, - on_view_ready: Callable[[Any], None] | None = None, - on_view_closed: Callable[[], None] | None = None, -): - """ - A loop to consume agent events and visualize the agent behavior. - - Args: - wire: Communication channel with the agent - initial_status: Initial status snapshot - cancel_event: Event that can be set (e.g., by ESC key) to cancel the run - """ - if prompt_session is not None and steer is not None: - view = _PromptLiveView( - initial_status, - prompt_session=prompt_session, - steer=steer, - cancel_event=cancel_event, - ) - prompt_session.attach_running_prompt(view) - - def _cancel_running_input() -> None: - if cancel_event is not None: - cancel_event.set() - - if bind_running_input is not None: - bind_running_input(view.handle_local_input, _cancel_running_input) - else: - view = _LiveView(initial_status, cancel_event) - if on_view_ready is not None: - on_view_ready(view) - try: - await view.visualize_loop(wire) - finally: - if prompt_session is not None and steer is not None: - if unbind_running_input is not None: - unbind_running_input() - assert isinstance(view, _PromptLiveView) - prompt_session.detach_running_prompt(view) - if on_view_closed is not None: - on_view_closed() - - -_THINKING_PREVIEW_LINES = 6 -_PENDING_PREVIEW_LINES = 8 -_SELF_CLOSING_BLOCKS = frozenset(("fence", "code_block", "hr", "html_block")) -_ELLIPSIS = "..." - - -def _truncate_to_display_width(line: str, max_width: int) -> str: - """Truncate *line* so its terminal display width fits within *max_width*. - - Uses ``rich.cells.cell_len`` for CJK-aware column width measurement. - """ - from rich.cells import cell_len - - if cell_len(line) <= max_width: - return line - ellipsis_width = cell_len(_ELLIPSIS) - budget = max_width - ellipsis_width - width = 0 - for i, ch in enumerate(line): - width += cell_len(ch) - if width > budget: - return line[:i] + _ELLIPSIS - return line - - -# Lazy-initialized markdown-it parser for incremental token commitment. -_md_parser: MarkdownIt | None = None - - -def _get_md_parser() -> MarkdownIt: - global _md_parser - if _md_parser is None: - from markdown_it import MarkdownIt - - # Match the extensions used by the rendering path (utils/rich/markdown.py) - # so that block boundaries are detected consistently. - _md_parser = MarkdownIt().enable("strikethrough").enable("table") - return _md_parser - - -def _estimate_tokens(text: str) -> float: - """Estimate token count for mixed CJK/Latin text. - - Returns a **float** so that callers can accumulate across small chunks - without per-chunk floor truncation (e.g. a 3-char ASCII chunk would - yield 0 if truncated to int immediately, but 0.75 as float). - - Heuristics based on common BPE tokenizers (cl100k, o200k): - - CJK ideographs: ~1.5 tokens per character (often split into 2-byte pieces) - - Latin / ASCII: ~1 token per 4 characters (words average ~4 chars) - """ - cjk = 0 - other = 0 - for ch in text: - cp = ord(ch) - if ( - 0x4E00 <= cp <= 0x9FFF # CJK Unified Ideographs - or 0x3400 <= cp <= 0x4DBF # CJK Extension A - or 0xF900 <= cp <= 0xFAFF # CJK Compatibility Ideographs - or 0x3000 <= cp <= 0x303F # CJK Symbols and Punctuation - or 0xFF00 <= cp <= 0xFFEF # Fullwidth Forms - ): - cjk += 1 - else: - other += 1 - return cjk * 1.5 + other / 4 - - -def _find_committed_boundary(text: str) -> int | None: - """Return the character offset up to which *text* can be safely committed. - - Uses the incremental token commitment algorithm: parse text into block-level - tokens via ``markdown-it-py``, confirm all blocks except the last one (which - may be incomplete due to streaming truncation). - - Returns ``None`` when there are fewer than 2 blocks (nothing to confirm yet). - """ - md = _get_md_parser() - tokens = md.parse(text) - - # Collect only TOP-LEVEL block boundaries by tracking nesting depth. - # Nested tokens (e.g. list_item_open inside bullet_list_open) must not be - # treated as independent blocks — otherwise lists and blockquotes get split. - block_maps: list[list[int]] = [] - depth = 0 - for t in tokens: - if t.nesting == 1: - if depth == 0 and t.map is not None: - block_maps.append(t.map) - depth += 1 - elif t.nesting == -1: - depth -= 1 - elif depth == 0 and t.type in _SELF_CLOSING_BLOCKS and t.map is not None: - block_maps.append(t.map) - - if len(block_maps) < 2: - return None - - # Convert end-line number to character offset by scanning newlines. - target_line = block_maps[-2][1] - offset = 0 - for _ in range(target_line): - offset = text.index("\n", offset) + 1 - return offset - - -def _tail_lines(text: str, n: int) -> str: - """Extract the last *n* lines from *text* via reverse scanning (O(n)).""" - pos = len(text) - for _ in range(n): - pos = text.rfind("\n", 0, pos) - if pos == -1: - return text - return text[pos + 1 :] - - -class _ContentBlock: - """Streaming content block with incremental markdown commitment. - - For **composing** (``is_think=False``), confirmed markdown blocks are flushed - to the terminal permanently via ``console.print()`` as they become complete, - giving users real-time streaming output. Only the unconfirmed tail remains - in the transient Rich Live area. - - For **thinking** (``is_think=True``), content stays in the Live area as a - scrolling preview until the block is finalized. - """ - - def __init__(self, is_think: bool): - self.is_think = is_think - self._spinner = Spinner("dots", "") - self.raw_text = "" - # Accumulated float estimate — avoids per-chunk int truncation. - self._token_count: float = 0.0 - self._start_time = time.monotonic() - # Incremental commitment state (composing only). - self._committed_len = 0 - self._has_printed_bullet = False - - # -- Public API ---------------------------------------------------------- - - def append(self, content: str) -> None: - self.raw_text += content - self._token_count += _estimate_tokens(content) - # Block boundaries require newlines; skip parse for mid-line chunks. - if not self.is_think and "\n" in content: - self._flush_committed() - - def compose(self) -> RenderableType: - """Render the transient Live area content.""" - pending = self._pending_text() - - # Thinking: always show spinner + preview. - if self.is_think: - spinner = self._compose_spinner() - if not pending: - return spinner - preview = self._build_preview(pending) - return Group(spinner, Text(preview, style="grey50 italic")) - - # Composing: always show spinner with elapsed time and token count. - # Committed blocks are already printed permanently above. - return self._compose_spinner() - - def compose_final(self) -> RenderableType: - """Render the remaining uncommitted content when the block ends.""" - remaining = self._pending_text() - if not remaining: - return Text("") - if self.is_think: - return BulletColumns( - Markdown(remaining, style="grey50 italic"), - bullet_style="grey50", - ) - return self._wrap_bullet(Markdown(remaining)) - - def has_pending(self) -> bool: - """Whether there is uncommitted content to flush.""" - return bool(self._pending_text()) - - # -- Private ------------------------------------------------------------- - - def _pending_text(self) -> str: - return self.raw_text[self._committed_len :] - - def _wrap_bullet(self, renderable: RenderableType) -> BulletColumns: - """First call gets the ``•`` bullet; subsequent calls get a space.""" - if self._has_printed_bullet: - return BulletColumns(renderable, bullet=Text(" ")) - self._has_printed_bullet = True - return BulletColumns(renderable) - - def _flush_committed(self) -> None: - """Commit confirmed markdown blocks to permanent terminal output.""" - pending = self._pending_text() - if not pending: - return - boundary = _find_committed_boundary(pending) - if boundary is None: - return - committed_text = pending[:boundary] - console.print(self._wrap_bullet(Markdown(committed_text))) - self._committed_len += boundary - - def _compose_spinner(self) -> Spinner: - elapsed = time.monotonic() - self._start_time - label = "Thinking..." if self.is_think else "Composing..." - elapsed_str = f"{int(elapsed)}s" if elapsed >= 1 else "<1s" - count_str = f"{format_token_count(int(self._token_count))} tokens" - - self._spinner.text = Text.assemble( - (label, ""), - (f" {elapsed_str}", "grey50"), - (f" · {count_str}", "grey50"), - ) - return self._spinner - - def _build_preview(self, text: str) -> str: - max_lines = _THINKING_PREVIEW_LINES if self.is_think else _PENDING_PREVIEW_LINES - max_width = console.width - 2 if console.width else 78 - tail_text = _tail_lines(text, max_lines) - lines = tail_text.split("\n") - return "\n".join(_truncate_to_display_width(line, max_width) for line in lines) - - -class _ToolCallBlock: - class FinishedSubCall(NamedTuple): - call: ToolCall - result: ToolReturnValue - - def __init__(self, tool_call: ToolCall): - self._tool_name = tool_call.function.name - self._lexer = streamingjson.Lexer() - if tool_call.function.arguments is not None: - self._lexer.append_string(tool_call.function.arguments) - - self._argument = extract_key_argument(self._lexer, self._tool_name) - self._full_url = self._extract_full_url(tool_call.function.arguments, self._tool_name) - self._result: ToolReturnValue | None = None - self._subagent_id: str | None = None - self._subagent_type: str | None = None - - self._ongoing_subagent_tool_calls: dict[str, ToolCall] = {} - self._last_subagent_tool_call: ToolCall | None = None - self._n_finished_subagent_tool_calls = 0 - self._finished_subagent_tool_calls = deque[_ToolCallBlock.FinishedSubCall]( - maxlen=MAX_SUBAGENT_TOOL_CALLS_TO_SHOW - ) - - self._spinning_dots = Spinner("dots", text="") - self._renderable: RenderableType = self._compose() - - def compose(self) -> RenderableType: - return self._renderable - - @property - def finished(self) -> bool: - return self._result is not None - - def append_args_part(self, args_part: str): - if self.finished: - return - self._lexer.append_string(args_part) - # TODO: maybe don't extract detail if it's already stable - argument = extract_key_argument(self._lexer, self._tool_name) - if argument and argument != self._argument: - self._argument = argument - self._full_url = self._extract_full_url(self._lexer.complete_json(), self._tool_name) - self._renderable = BulletColumns( - self._build_headline_text(), - bullet=self._spinning_dots, - ) - - def finish(self, result: ToolReturnValue): - self._result = result - self._renderable = self._compose() - - def append_sub_tool_call(self, tool_call: ToolCall): - self._ongoing_subagent_tool_calls[tool_call.id] = tool_call - self._last_subagent_tool_call = tool_call - - def append_sub_tool_call_part(self, tool_call_part: ToolCallPart): - if self._last_subagent_tool_call is None: - return - if not tool_call_part.arguments_part: - return - if self._last_subagent_tool_call.function.arguments is None: - self._last_subagent_tool_call.function.arguments = tool_call_part.arguments_part - else: - self._last_subagent_tool_call.function.arguments += tool_call_part.arguments_part - - def finish_sub_tool_call(self, tool_result: ToolResult): - self._last_subagent_tool_call = None - sub_tool_call = self._ongoing_subagent_tool_calls.pop(tool_result.tool_call_id, None) - if sub_tool_call is None: - return - - self._finished_subagent_tool_calls.append( - _ToolCallBlock.FinishedSubCall( - call=sub_tool_call, - result=tool_result.return_value, - ) - ) - self._n_finished_subagent_tool_calls += 1 - self._renderable = self._compose() - - def set_subagent_metadata(self, agent_id: str, subagent_type: str) -> None: - changed = (self._subagent_id, self._subagent_type) != (agent_id, subagent_type) - self._subagent_id = agent_id - self._subagent_type = subagent_type - if changed: - self._renderable = self._compose() - - def _compose(self) -> RenderableType: - lines: list[RenderableType] = [ - self._build_headline_text(), - ] - if self._subagent_id is not None and self._subagent_type is not None: - lines.append( - BulletColumns( - Text( - f"subagent {self._subagent_type} ({self._subagent_id})", - style="grey50", - ), - bullet_style="grey50", - ) - ) - - if self._n_finished_subagent_tool_calls > MAX_SUBAGENT_TOOL_CALLS_TO_SHOW: - n_hidden = self._n_finished_subagent_tool_calls - MAX_SUBAGENT_TOOL_CALLS_TO_SHOW - lines.append( - BulletColumns( - Text( - f"{n_hidden} more tool call{'s' if n_hidden > 1 else ''} ...", - style="grey50 italic", - ), - bullet_style="grey50", - ) - ) - for sub_call, sub_result in self._finished_subagent_tool_calls: - argument = extract_key_argument( - sub_call.function.arguments or "", sub_call.function.name - ) - sub_url = self._extract_full_url(sub_call.function.arguments, sub_call.function.name) - sub_text = Text() - sub_text.append("Used ") - sub_text.append(sub_call.function.name, style="blue") - if argument: - sub_text.append(" (", style="grey50") - arg_style = Style(color="grey50", link=sub_url) if sub_url else "grey50" - sub_text.append(argument, style=arg_style) - sub_text.append(")", style="grey50") - lines.append( - BulletColumns( - sub_text, - bullet_style="green" if not sub_result.is_error else "dark_red", - ) - ) - - if self._result is not None: - display = self._result.display - idx = 0 - while idx < len(display): - block = display[idx] - if isinstance(block, DiffDisplayBlock): - # Collect consecutive same-file diff blocks - path = block.path - diff_blocks: list[DiffDisplayBlock] = [] - while idx < len(display): - b = display[idx] - if not isinstance(b, DiffDisplayBlock) or b.path != path: - break - diff_blocks.append(b) - idx += 1 - if any(b.is_summary for b in diff_blocks): - lines.append(render_diff_summary_panel(path, diff_blocks)) - else: - hunks, added_total, removed_total = collect_diff_hunks(diff_blocks) - if hunks: - lines.append(render_diff_panel(path, hunks, added_total, removed_total)) - elif isinstance(block, BriefDisplayBlock): - style = "grey50" if not self._result.is_error else "dark_red" - if block.text: - lines.append(Markdown(block.text, style=style)) - idx += 1 - elif isinstance(block, TodoDisplayBlock): - markdown = self._render_todo_markdown(block) - if markdown: - lines.append(Markdown(markdown, style="grey50")) - idx += 1 - elif isinstance(block, BackgroundTaskDisplayBlock): - lines.append( - Markdown( - (f"`{block.task_id}` [{block.status}] {block.description}"), - style="grey50", - ) - ) - idx += 1 - else: - idx += 1 - - if self.finished: - assert self._result is not None - return BulletColumns( - Group(*lines), - bullet_style="green" if not self._result.is_error else "dark_red", - ) - else: - return BulletColumns( - Group(*lines), - bullet=self._spinning_dots, - ) - - @staticmethod - def _extract_full_url(arguments: str | None, tool_name: str) -> str | None: - """Extract the full URL from FetchURL tool arguments.""" - if tool_name != "FetchURL" or not arguments: - return None - try: - args = json.loads(arguments, strict=False) - except (json.JSONDecodeError, TypeError): - return None - if isinstance(args, dict): - url = cast(dict[str, Any], args).get("url") - if url: - return str(url) - return None - - def _build_headline_text(self) -> Text: - text = Text() - text.append("Used " if self.finished else "Using ") - text.append(self._tool_name, style="blue") - if self._argument: - text.append(" (", style="grey50") - arg_style = Style(color="grey50", link=self._full_url) if self._full_url else "grey50" - text.append(self._argument, style=arg_style) - text.append(")", style="grey50") - return text - - def _render_todo_markdown(self, block: TodoDisplayBlock) -> str: - lines: list[str] = [] - for todo in block.items: - normalized = todo.status.replace("_", " ").lower() - match normalized: - case "pending": - lines.append(f"- {todo.title}") - case "in progress": - lines.append(f"- {todo.title} ←") - case "done": - lines.append(f"- ~~{todo.title}~~") - case _: - lines.append(f"- {todo.title}") - return "\n".join(lines) - - -class _NotificationBlock: - _SEVERITY_STYLE = { - "info": "cyan", - "success": "green", - "warning": "yellow", - "error": "red", - } - - def __init__(self, notification: Notification): - self.notification = notification - - def compose(self) -> RenderableType: - style = self._SEVERITY_STYLE.get(self.notification.severity, "cyan") - lines: list[RenderableType] = [Text(self.notification.title, style=f"bold {style}")] - body = self.notification.body.strip() - if body: - body_lines = body.splitlines() - preview = "\n".join(body_lines[:2]) - if len(body_lines) > 2: - preview += "\n..." - lines.append(Text(preview, style="grey50")) - return BulletColumns(Group(*lines), bullet_style=style) - - -class _StatusBlock: - def __init__(self, initial: StatusUpdate) -> None: - self.text = Text("", justify="right") - self._context_usage: float = 0.0 - self._context_tokens: int = 0 - self._max_context_tokens: int = 0 - self.update(initial) - - def render(self) -> RenderableType: - return self.text - - def update(self, status: StatusUpdate) -> None: - if status.context_usage is not None: - self._context_usage = status.context_usage - if status.context_tokens is not None: - self._context_tokens = status.context_tokens - if status.max_context_tokens is not None: - self._max_context_tokens = status.max_context_tokens - if status.context_usage is not None: - self.text.plain = format_context_status( - self._context_usage, - self._context_tokens, - self._max_context_tokens, - ) - - -@asynccontextmanager -async def _keyboard_listener( - handler: Callable[[KeyboardListener, KeyEvent], Awaitable[None]], -): - listener = KeyboardListener() - await listener.start() - - async def _keyboard(): - while True: - event = await listener.get() - await handler(listener, event) - - task = asyncio.create_task(_keyboard()) - try: - yield - finally: - task.cancel() - with suppress(asyncio.CancelledError): - await task - await listener.stop() - - -class _LiveView: - def __init__(self, initial_status: StatusUpdate, cancel_event: asyncio.Event | None = None): - self._cancel_event = cancel_event - - self._mooning_spinner: Spinner | None = None - self._compacting_spinner: Spinner | None = None - self._mcp_loading_spinner: Spinner | None = None - - self._current_content_block: _ContentBlock | None = None - self._tool_call_blocks: dict[str, _ToolCallBlock] = {} - self._last_tool_call_block: _ToolCallBlock | None = None - self._approval_request_queue = deque[ApprovalRequest]() - """ - It is possible that multiple subagents request approvals at the same time, - in which case we will have to queue them up and show them one by one. - """ - self._current_approval_request_panel: ApprovalRequestPanel | None = None - self._question_request_queue = deque[QuestionRequest]() - self._current_question_panel: QuestionRequestPanel | None = None - self._notification_blocks = deque[_NotificationBlock]() - self._live_notification_blocks = deque[_NotificationBlock](maxlen=MAX_LIVE_NOTIFICATIONS) - self._status_block = _StatusBlock(initial_status) - - self._need_recompose = False - self._external_messages: Queue[WireMessage] = Queue() - - def _reset_live_shape(self, live: Live) -> None: - # Rich doesn't expose a public API to clear Live's cached render height. - # After leaving the pager, stale height causes cursor restores to jump, - # so we reset the private _shape to re-anchor the next refresh. - live._live_render._shape = None # type: ignore[reportPrivateUsage] - - async def _drain_external_message_after_wire_shutdown( - self, - external_task: asyncio.Task[WireMessage], - ) -> tuple[WireMessage | None, asyncio.Task[WireMessage]]: - try: - msg = await asyncio.wait_for( - asyncio.shield(external_task), - timeout=EXTERNAL_MESSAGE_GRACE_S, - ) - except (TimeoutError, QueueShutDown): - return None, external_task - return msg, asyncio.create_task(self._external_messages.get()) - - async def visualize_loop(self, wire: WireUISide): - with Live( - self.compose(), - console=console, - refresh_per_second=10, - transient=True, - vertical_overflow="visible", - ) as live: - - async def keyboard_handler(listener: KeyboardListener, event: KeyEvent) -> None: - # Handle Ctrl+E specially - pause Live while the pager is active - if event == KeyEvent.CTRL_E: - if self.has_expandable_panel(): - await listener.pause() - live.stop() - try: - self._show_expandable_panel_content() - finally: - # Reset live render shape so the next refresh re-anchors cleanly. - self._reset_live_shape(live) - live.start() - live.update(self.compose(), refresh=True) - await listener.resume() - return - - # Handle ENTER/SPACE on question panel when "Other" is selected - if self._should_prompt_question_other_for_key(event): - panel = self._current_question_panel - assert panel is not None - question_text = panel.current_question_text - await listener.pause() - live.stop() - try: - text = await prompt_other_input(question_text) - finally: - self._reset_live_shape(live) - live.start() - await listener.resume() - - self._submit_question_other_text(text) - live.update(self.compose(), refresh=True) - return - - self.dispatch_keyboard_event(event) - if self._need_recompose: - live.update(self.compose(), refresh=True) - self._need_recompose = False - - async with _keyboard_listener(keyboard_handler): - wire_task = asyncio.create_task(wire.receive()) - external_task = asyncio.create_task(self._external_messages.get()) - while True: - try: - done, _ = await asyncio.wait( - [wire_task, external_task], - return_when=asyncio.FIRST_COMPLETED, - ) - if wire_task in done: - msg = wire_task.result() - wire_task = asyncio.create_task(wire.receive()) - else: - msg = external_task.result() - external_task = asyncio.create_task(self._external_messages.get()) - except QueueShutDown: - msg, external_task = await self._drain_external_message_after_wire_shutdown( - external_task - ) - if msg is not None: - self.dispatch_wire_message(msg) - if self._need_recompose: - live.update(self.compose(), refresh=True) - self._need_recompose = False - continue - self.cleanup(is_interrupt=False) - live.update(self.compose(), refresh=True) - break - - if isinstance(msg, StepInterrupted): - self.cleanup(is_interrupt=True) - live.update(self.compose(), refresh=True) - break - - self.dispatch_wire_message(msg) - if self._need_recompose: - live.update(self.compose(), refresh=True) - self._need_recompose = False - wire_task.cancel() - external_task.cancel() - self._external_messages.shutdown(immediate=True) - with suppress(asyncio.CancelledError, QueueShutDown): - await wire_task - with suppress(asyncio.CancelledError, QueueShutDown): - await external_task - - def refresh_soon(self) -> None: - self._need_recompose = True - - def _on_question_panel_state_changed(self) -> None: - """Hook for subclasses to react when question panel visibility changes.""" - return None - - def enqueue_external_message(self, msg: WireMessage) -> None: - try: - self._external_messages.put_nowait(msg) - except QueueShutDown: - logger.debug("Ignoring external wire message after live view shutdown: {msg}", msg=msg) - - def has_expandable_panel(self) -> bool: - return ( - self._expandable_approval_panel() is not None - or self._expandable_question_panel() is not None - ) - - def _expandable_approval_panel(self) -> ApprovalRequestPanel | None: - panel = self._current_approval_request_panel - if panel is not None and panel.has_expandable_content: - return panel - return None - - def _expandable_question_panel(self) -> QuestionRequestPanel | None: - panel = self._current_question_panel - if panel is not None and panel.has_expandable_content: - return panel - return None - - def _show_expandable_panel_content(self) -> bool: - if approval_panel := self._expandable_approval_panel(): - show_approval_in_pager(approval_panel) - return True - if question_panel := self._expandable_question_panel(): - show_question_body_in_pager(question_panel) - return True - return False - - def _should_prompt_question_other_for_key(self, key: KeyEvent) -> bool: - panel = self._current_question_panel - if panel is None or not panel.should_prompt_other_input(): - return False - return key == KeyEvent.ENTER or (key == KeyEvent.SPACE and not panel.is_multi_select) - - def _submit_question_other_text(self, text: str) -> None: - panel = self._current_question_panel - if panel is None: - return - - all_done = panel.submit_other(text) - if all_done: - panel.request.resolve(panel.get_answers()) - self.show_next_question_request() - self.refresh_soon() - - def compose(self, *, include_status: bool = True) -> RenderableType: - """Compose the live view display content. - - Approval and question panels are rendered first so they remain visible - at the top of the terminal even when tool-call output is long enough - to push content beyond the visible area. - """ - blocks: list[RenderableType] = [] - # Approval/question panels first — highest visual priority. - if self._current_approval_request_panel: - blocks.append(self._current_approval_request_panel.render()) - if self._current_question_panel: - blocks.append(self._current_question_panel.render()) - # Spinners or content + tool calls. - if self._mcp_loading_spinner is not None: - blocks.append(self._mcp_loading_spinner) - elif self._mooning_spinner is not None: - blocks.append(self._mooning_spinner) - elif self._compacting_spinner is not None: - blocks.append(self._compacting_spinner) - else: - if self._current_content_block is not None: - blocks.append(self._current_content_block.compose()) - for tool_call in self._tool_call_blocks.values(): - blocks.append(tool_call.compose()) - for notification in self._live_notification_blocks: - blocks.append(notification.compose()) - - if include_status: - blocks.append(self._status_block.render()) - return Group(*blocks) - - def dispatch_wire_message(self, msg: WireMessage) -> None: - """Dispatch the Wire message to UI components.""" - assert not isinstance(msg, StepInterrupted) # handled in visualize_loop - - if isinstance(msg, StepBegin): - self.cleanup(is_interrupt=False) - self._mcp_loading_spinner = None - self._mooning_spinner = Spinner("moon", "") - self.refresh_soon() - return - - if self._mooning_spinner is not None: - # any message other than StepBegin should end the mooning state - self._mooning_spinner = None - self.refresh_soon() - - match msg: - case TurnBegin(): - self.flush_content() - case SteerInput(user_input=user_input): - self.cleanup(is_interrupt=False) - content: list[ContentPart] - if isinstance(user_input, list): - content = list(user_input) - else: - content = [TextPart(text=user_input)] - console.print(render_user_echo(Message(role="user", content=content))) - case TurnEnd(): - pass - case CompactionBegin(): - self._compacting_spinner = Spinner("balloon", "Compacting...") - self.refresh_soon() - case CompactionEnd(): - self._compacting_spinner = None - self.refresh_soon() - case MCPLoadingBegin(): - self._mcp_loading_spinner = Spinner("dots", "Connecting to MCP servers...") - self.refresh_soon() - case MCPLoadingEnd(): - self._mcp_loading_spinner = None - self.refresh_soon() - case StatusUpdate(): - self._status_block.update(msg) - case Notification(): - self.append_notification(msg) - case ContentPart(): - self.append_content(msg) - case ToolCall(): - self.append_tool_call(msg) - case ToolCallPart(): - self.append_tool_call_part(msg) - case ToolResult(): - self.append_tool_result(msg) - case ApprovalResponse(): - self._reconcile_approval_requests() - case SubagentEvent(): - self.handle_subagent_event(msg) - case PlanDisplay(): - self.display_plan(msg) - case ApprovalRequest(): - self.request_approval(msg) - case QuestionRequest(): - self.request_question(msg) - case ToolCallRequest(): - logger.warning("Unexpected ToolCallRequest in shell UI: {msg}", msg=msg) - case _: - pass - - def _try_submit_question(self) -> None: - """Submit the current question answer; if all done, resolve and advance.""" - panel = self._current_question_panel - if panel is None: - return - all_done = panel.submit() - if all_done: - panel.request.resolve(panel.get_answers()) - self.show_next_question_request() - - def dispatch_keyboard_event(self, event: KeyEvent) -> None: - # Handle question panel keyboard events - if self._current_question_panel is not None: - match event: - case KeyEvent.UP: - self._current_question_panel.move_up() - case KeyEvent.DOWN: - self._current_question_panel.move_down() - case KeyEvent.LEFT: - self._current_question_panel.prev_tab() - case KeyEvent.RIGHT | KeyEvent.TAB: - self._current_question_panel.next_tab() - case KeyEvent.SPACE: - if self._current_question_panel.is_multi_select: - self._current_question_panel.toggle_select() - else: - self._try_submit_question() - case KeyEvent.ENTER: - # "Other" is handled in keyboard_handler (async context) - self._try_submit_question() - case KeyEvent.ESCAPE: - self._current_question_panel.request.resolve({}) - self.show_next_question_request() - case ( - KeyEvent.NUM_1 - | KeyEvent.NUM_2 - | KeyEvent.NUM_3 - | KeyEvent.NUM_4 - | KeyEvent.NUM_5 - | KeyEvent.NUM_6 - ): - # Number keys select option in question panel - num_map = { - KeyEvent.NUM_1: 0, - KeyEvent.NUM_2: 1, - KeyEvent.NUM_3: 2, - KeyEvent.NUM_4: 3, - KeyEvent.NUM_5: 4, - KeyEvent.NUM_6: 5, - } - idx = num_map[event] - panel = self._current_question_panel - if panel.select_index(idx): - if panel.is_multi_select: - panel.toggle_select() - elif not panel.is_other_selected: - # Auto-submit for single-select (unless "Other") - self._try_submit_question() - case _: - pass - self.refresh_soon() - return - - # handle ESC key to cancel the run - if event == KeyEvent.ESCAPE and self._cancel_event is not None: - self._cancel_event.set() - return - - # Handle approval panel keyboard events - if self._current_approval_request_panel is not None: - match event: - case KeyEvent.UP: - self._current_approval_request_panel.move_up() - self.refresh_soon() - case KeyEvent.DOWN: - self._current_approval_request_panel.move_down() - self.refresh_soon() - case KeyEvent.ENTER: - self._submit_approval() - case KeyEvent.NUM_1 | KeyEvent.NUM_2 | KeyEvent.NUM_3 | KeyEvent.NUM_4: - # Number keys directly select and submit approval option - num_map = { - KeyEvent.NUM_1: 0, - KeyEvent.NUM_2: 1, - KeyEvent.NUM_3: 2, - KeyEvent.NUM_4: 3, - } - idx = num_map[event] - if idx < len(self._current_approval_request_panel.options): - self._current_approval_request_panel.selected_index = idx - self._submit_approval() - case _: - pass - return - - def _submit_approval(self) -> None: - """Submit the currently selected approval response.""" - assert self._current_approval_request_panel is not None - request = self._current_approval_request_panel.request - resp = self._current_approval_request_panel.get_selected_response() - request.resolve(resp) - if resp == "approve_for_session": - to_remove_from_queue: list[ApprovalRequest] = [] - for request in self._approval_request_queue: - # approve all queued requests with the same action - if request.action == self._current_approval_request_panel.request.action: - request.resolve("approve_for_session") - to_remove_from_queue.append(request) - for request in to_remove_from_queue: - self._approval_request_queue.remove(request) - self.show_next_approval_request() - - def cleanup(self, is_interrupt: bool) -> None: - """Cleanup the live view on step end or interruption.""" - self.flush_content() - - for block in self._tool_call_blocks.values(): - if not block.finished: - # this should not happen, but just in case - block.finish( - ToolError(message="", brief="Interrupted") - if is_interrupt - else ToolOk(output="") - ) - self._last_tool_call_block = None - self.flush_finished_tool_calls() - self.flush_notifications() - - while self._approval_request_queue: - # should not happen, but just in case - self._approval_request_queue.popleft().resolve("reject") - self._current_approval_request_panel = None - - while self._question_request_queue: - self._question_request_queue.popleft().resolve({}) - self._current_question_panel = None - - def flush_content(self) -> None: - """Flush the current content block.""" - if self._current_content_block is not None: - if self._current_content_block.has_pending(): - console.print(self._current_content_block.compose_final()) - self._current_content_block = None - self.refresh_soon() - - def flush_finished_tool_calls(self) -> None: - """Flush all leading finished tool call blocks.""" - tool_call_ids = list(self._tool_call_blocks.keys()) - for tool_call_id in tool_call_ids: - block = self._tool_call_blocks[tool_call_id] - if not block.finished: - break - - self._tool_call_blocks.pop(tool_call_id) - console.print(block.compose()) - if self._last_tool_call_block == block: - self._last_tool_call_block = None - self.refresh_soon() - - def flush_notifications(self) -> None: - """Flush rendered notifications to terminal history.""" - self._live_notification_blocks.clear() - while self._notification_blocks: - console.print(self._notification_blocks.popleft().compose()) - self.refresh_soon() - - def append_content(self, part: ContentPart) -> None: - match part: - case ThinkPart(think=text) | TextPart(text=text): - if not text: - return - is_think = isinstance(part, ThinkPart) - if self._current_content_block is None: - self._current_content_block = _ContentBlock(is_think) - self.refresh_soon() - elif self._current_content_block.is_think != is_think: - self.flush_content() - self._current_content_block = _ContentBlock(is_think) - self.refresh_soon() - self._current_content_block.append(text) - self.refresh_soon() - case _: - # TODO: support more content part types - pass - - def append_tool_call(self, tool_call: ToolCall) -> None: - self.flush_content() - self._tool_call_blocks[tool_call.id] = _ToolCallBlock(tool_call) - self._last_tool_call_block = self._tool_call_blocks[tool_call.id] - self.refresh_soon() - - def append_tool_call_part(self, part: ToolCallPart) -> None: - if not part.arguments_part: - return - if self._last_tool_call_block is None: - return - self._last_tool_call_block.append_args_part(part.arguments_part) - self.refresh_soon() - - def append_tool_result(self, result: ToolResult) -> None: - if block := self._tool_call_blocks.get(result.tool_call_id): - block.finish(result.return_value) - self.flush_finished_tool_calls() - self.refresh_soon() - - def append_notification(self, notification: Notification) -> None: - block = _NotificationBlock(notification) - self._notification_blocks.append(block) - self._live_notification_blocks.append(block) - self.refresh_soon() - - def request_approval(self, request: ApprovalRequest) -> None: - self._approval_request_queue.append(request) - - if self._current_approval_request_panel is None: - console.bell() - self.show_next_approval_request() - - def _reconcile_approval_requests(self) -> None: - self._approval_request_queue = deque( - request for request in self._approval_request_queue if not request.resolved - ) - if ( - self._current_approval_request_panel is not None - and self._current_approval_request_panel.request.resolved - ): - self._current_approval_request_panel = None - self.show_next_approval_request() - else: - self.refresh_soon() - - def show_next_approval_request(self) -> None: - """ - Show the next approval request from the queue. - If there are no pending requests, clear the current approval panel. - """ - if not self._approval_request_queue: - if self._current_approval_request_panel is not None: - self._current_approval_request_panel = None - self.refresh_soon() - return - - while self._approval_request_queue: - request = self._approval_request_queue.popleft() - if request.resolved: - # skip resolved requests - continue - self._current_approval_request_panel = ApprovalRequestPanel(request) - self.refresh_soon() - break - else: - # All queued requests were already resolved - if self._current_approval_request_panel is not None: - self._current_approval_request_panel = None - self.refresh_soon() - - def display_plan(self, msg: PlanDisplay) -> None: - """Render plan content inline in the chat with a bordered panel.""" - self.flush_content() - self.flush_finished_tool_calls() - plan_body = Markdown(msg.content) - subtitle = Text(msg.file_path, style="dim") - panel = Panel( - plan_body, - title="[bold cyan]Plan[/bold cyan]", - title_align="left", - subtitle=subtitle, - subtitle_align="left", - border_style="cyan", - padding=(1, 2), - ) - console.print(panel) - - def request_question(self, request: QuestionRequest) -> None: - self._question_request_queue.append(request) - if self._current_question_panel is None: - console.bell() - self.show_next_question_request() - - def show_next_question_request(self) -> None: - """Show the next question request from the queue.""" - if not self._question_request_queue: - if self._current_question_panel is not None: - self._current_question_panel = None - self.refresh_soon() - self._on_question_panel_state_changed() - return - - while self._question_request_queue: - request = self._question_request_queue.popleft() - if request.resolved: - continue - self._current_question_panel = QuestionRequestPanel(request) - self.refresh_soon() - self._on_question_panel_state_changed() - break - else: - # All queued requests were already resolved - if self._current_question_panel is not None: - self._current_question_panel = None - self.refresh_soon() - self._on_question_panel_state_changed() - - def handle_subagent_event(self, event: SubagentEvent) -> None: - if event.parent_tool_call_id is None: - return - block = self._tool_call_blocks.get(event.parent_tool_call_id) - if block is None: - return - if event.agent_id is not None and event.subagent_type is not None: - block.set_subagent_metadata(event.agent_id, event.subagent_type) - - match event.event: - case ToolCall() as tool_call: - block.append_sub_tool_call(tool_call) - case ToolCallPart() as tool_call_part: - block.append_sub_tool_call_part(tool_call_part) - case ToolResult() as tool_result: - block.finish_sub_tool_call(tool_result) - self.refresh_soon() - case _: - # ignore other events for now - # TODO: may need to handle multi-level nested subagents - pass - - -class _PromptLiveView(_LiveView): - modal_priority = 0 - - def __init__( - self, - initial_status: StatusUpdate, - *, - prompt_session: CustomPromptSession, - steer: Callable[[str | list[ContentPart]], None], - cancel_event: asyncio.Event | None = None, - ) -> None: - super().__init__(initial_status, cancel_event) - self._prompt_session = prompt_session - self._steer = steer - self._pending_local_steers: deque[str | list[ContentPart]] = deque() - self._turn_ended = False - self._question_modal: QuestionPromptDelegate | None = None - - async def visualize_loop(self, wire: WireUISide): - try: - wire_task = asyncio.create_task(wire.receive()) - external_task = asyncio.create_task(self._external_messages.get()) - while True: - try: - done, _ = await asyncio.wait( - [wire_task, external_task], - return_when=asyncio.FIRST_COMPLETED, - ) - if wire_task in done: - msg = wire_task.result() - wire_task = asyncio.create_task(wire.receive()) - else: - msg = external_task.result() - external_task = asyncio.create_task(self._external_messages.get()) - except QueueShutDown: - msg, external_task = await self._drain_external_message_after_wire_shutdown( - external_task - ) - if msg is not None: - self.dispatch_wire_message(msg) - self._flush_prompt_refresh() - continue - self.cleanup(is_interrupt=False) - self._flush_prompt_refresh() - break - - if isinstance(msg, StepInterrupted): - self.cleanup(is_interrupt=True) - self._flush_prompt_refresh() - break - - if isinstance(msg, TurnEnd): - self._turn_ended = True - self._flush_prompt_refresh() - continue - - self.dispatch_wire_message(msg) - self._flush_prompt_refresh() - finally: - self._external_messages.shutdown(immediate=True) - for task in (locals().get("wire_task"), locals().get("external_task")): - if task is None: - continue - task.cancel() - with suppress(asyncio.CancelledError, QueueShutDown): - await task - self._pending_local_steers.clear() - self._turn_ended = False - if self._question_modal is not None: - self._prompt_session.detach_modal(self._question_modal) - self._question_modal = None - self._prompt_session.invalidate() - - def handle_local_input(self, user_input: UserInput) -> None: - if not user_input or self._turn_ended: - return - - console.print(render_user_echo_text(user_input.command)) - self._pending_local_steers.append(list(user_input.content)) - self._steer(user_input.content) - self._flush_prompt_refresh() - - def dispatch_wire_message(self, msg: WireMessage) -> None: - if isinstance(msg, SteerInput) and self._pending_local_steers: - pending = self._pending_local_steers[0] - if pending == msg.user_input: - self._pending_local_steers.popleft() - return - super().dispatch_wire_message(msg) - - def render_running_prompt_body(self, columns: int) -> ANSI: - if ( - self._turn_ended - and self._current_approval_request_panel is None - and self._current_question_panel is None - ): - return ANSI("") - renderable = self.compose(include_status=False) - body = render_to_ansi(renderable, columns=columns).rstrip("\n") - return ANSI(body if body else "") - - def running_prompt_placeholder(self) -> str | None: - if self._current_approval_request_panel is not None: - return "Use ↑/↓ or 1/2/3, then press Enter to respond to the approval request." - return None - - def running_prompt_hides_input_buffer(self) -> bool: - return False - - def running_prompt_allows_text_input(self) -> bool: - if self._current_approval_request_panel is not None: - return False - if self._current_question_panel is not None: - return False - return not self._turn_ended - - def running_prompt_accepts_submission(self) -> bool: - if self._current_approval_request_panel is not None: - return True - if self._current_question_panel is not None: - return True - return not self._turn_ended - - def should_handle_running_prompt_key(self, key: str) -> bool: - if key == "c-e": - return self.has_expandable_panel() - if self._current_approval_request_panel is not None: - return key in {"up", "down", "enter", "1", "2", "3", "4"} - if self._turn_ended: - return False - if key == "escape": - return self._cancel_event is not None - return False - - def handle_running_prompt_key(self, key: str, event: KeyPressEvent) -> None: - if key == "c-e": - event.app.create_background_task(self._show_panel_in_pager()) - return - - mapped = { - "up": KeyEvent.UP, - "down": KeyEvent.DOWN, - "enter": KeyEvent.ENTER, - "escape": KeyEvent.ESCAPE, - "1": KeyEvent.NUM_1, - "2": KeyEvent.NUM_2, - "3": KeyEvent.NUM_3, - "4": KeyEvent.NUM_4, - }.get(key) - if mapped is None: - return - if self._current_approval_request_panel is not None: - self._clear_buffer(event.current_buffer) - self.dispatch_keyboard_event(mapped) - self._flush_prompt_refresh() - - async def _show_panel_in_pager(self) -> None: - await run_in_terminal(self._show_expandable_panel_content) - self._prompt_session.invalidate() - - @staticmethod - def _clear_buffer(buffer: Buffer) -> None: - if buffer.text: - buffer.document = Document(text="", cursor_position=0) - - def _flush_prompt_refresh(self) -> None: - if self._need_recompose: - self._prompt_session.invalidate() - self._need_recompose = False - - def cleanup(self, is_interrupt: bool) -> None: - super().cleanup(is_interrupt) - - def _on_question_panel_state_changed(self) -> None: - panel = self._current_question_panel - if panel is None: - if self._question_modal is not None: - self._prompt_session.detach_modal(self._question_modal) - self._question_modal = None - return - if self._question_modal is None: - self._question_modal = QuestionPromptDelegate( - panel, - on_advance=self._advance_question, - on_invalidate=self._flush_prompt_refresh, - buffer_text_provider=lambda: self._prompt_session._session.default_buffer.text, # pyright: ignore[reportPrivateUsage] - text_expander=self._prompt_session._get_placeholder_manager().serialize_for_history, # pyright: ignore[reportPrivateUsage] - ) - self._prompt_session.attach_modal(self._question_modal) - else: - self._question_modal.set_panel(panel) - self._prompt_session.invalidate() - - def _advance_question(self) -> QuestionRequestPanel | None: - """Advance to the next question in the queue, returning the new panel or None.""" - self.show_next_question_request() - return self._current_question_panel diff --git a/src/kimi_cli/ui/theme.py b/src/kimi_cli/ui/theme.py deleted file mode 100644 index 3397618a3..000000000 --- a/src/kimi_cli/ui/theme.py +++ /dev/null @@ -1,238 +0,0 @@ -"""Centralized terminal color theme definitions. - -All UI-facing colors live here so that switching between dark and light -terminal themes only requires changing the active ``ThemeName``. -""" - -from __future__ import annotations - -from dataclasses import dataclass -from typing import Literal - -from prompt_toolkit.styles import Style as PTKStyle -from rich.style import Style as RichStyle - -type ThemeName = Literal["dark", "light"] - - -# --------------------------------------------------------------------------- -# Diff colors (used by utils/rich/diff_render.py) -# --------------------------------------------------------------------------- - - -@dataclass(frozen=True, slots=True) -class DiffColors: - add_bg: RichStyle - del_bg: RichStyle - add_hl: RichStyle - del_hl: RichStyle - - -_DIFF_DARK = DiffColors( - add_bg=RichStyle(bgcolor="#12261e"), - del_bg=RichStyle(bgcolor="#2d1214"), - add_hl=RichStyle(bgcolor="#1a4a2e"), - del_hl=RichStyle(bgcolor="#5c1a1d"), -) - -_DIFF_LIGHT = DiffColors( - add_bg=RichStyle(bgcolor="#dafbe1"), - del_bg=RichStyle(bgcolor="#ffebe9"), - add_hl=RichStyle(bgcolor="#aff5b4"), - del_hl=RichStyle(bgcolor="#ffc1c0"), -) - - -# --------------------------------------------------------------------------- -# Task browser colors (used by ui/shell/task_browser.py) -# --------------------------------------------------------------------------- - - -def _task_browser_style_dark() -> PTKStyle: - return PTKStyle.from_dict( - { - "header": "bg:#1f2937 #e5e7eb", - "header.title": "bg:#1f2937 #67e8f9 bold", - "header.meta": "bg:#1f2937 #9ca3af", - "status.running": "bg:#1f2937 #86efac bold", - "status.success": "bg:#1f2937 #86efac", - "status.warning": "bg:#1f2937 #fbbf24", - "status.error": "bg:#1f2937 #fca5a5", - "status.info": "bg:#1f2937 #93c5fd", - "task-list": "bg:#111827 #d1d5db", - "task-list.checked": "bg:#164e63 #ecfeff bold", - "frame.border": "#155e75", - "frame.label": "bg:#0f172a #67e8f9 bold", - "footer": "bg:#0f172a #cbd5e1", - "footer.key": "bg:#0f172a #67e8f9 bold", - "footer.text": "bg:#0f172a #cbd5e1", - "footer.warning": "bg:#7f1d1d #fecaca bold", - "footer.meta": "bg:#0f172a #94a3b8", - } - ) - - -def _task_browser_style_light() -> PTKStyle: - return PTKStyle.from_dict( - { - "header": "bg:#e5e7eb #1f2937", - "header.title": "bg:#e5e7eb #0e7490 bold", - "header.meta": "bg:#e5e7eb #6b7280", - "status.running": "bg:#e5e7eb #166534 bold", - "status.success": "bg:#e5e7eb #166534", - "status.warning": "bg:#e5e7eb #92400e", - "status.error": "bg:#e5e7eb #991b1b", - "status.info": "bg:#e5e7eb #1e40af", - "task-list": "bg:#f9fafb #374151", - "task-list.checked": "bg:#cffafe #164e63 bold", - "frame.border": "#0e7490", - "frame.label": "bg:#f1f5f9 #0e7490 bold", - "footer": "bg:#f1f5f9 #475569", - "footer.key": "bg:#f1f5f9 #0e7490 bold", - "footer.text": "bg:#f1f5f9 #475569", - "footer.warning": "bg:#fee2e2 #991b1b bold", - "footer.meta": "bg:#f1f5f9 #64748b", - } - ) - - -# --------------------------------------------------------------------------- -# Prompt / completion menu colors (used by ui/shell/prompt.py) -# --------------------------------------------------------------------------- - - -_PROMPT_STYLE_DARK = { - "bottom-toolbar": "noreverse", - "running-prompt-placeholder": "fg:#7c8594 italic", - "running-prompt-separator": "fg:#4a5568", - "slash-completion-menu": "", - "slash-completion-menu.separator": "fg:#4a5568", - "slash-completion-menu.marker": "fg:#4a5568", - "slash-completion-menu.marker.current": "fg:#4f9fff", - "slash-completion-menu.command": "fg:#a6adba", - "slash-completion-menu.meta": "fg:#7c8594", - "slash-completion-menu.command.current": "fg:#6fb7ff bold", - "slash-completion-menu.meta.current": "fg:#56a4ff", -} - -_PROMPT_STYLE_LIGHT = { - "bottom-toolbar": "noreverse", - "running-prompt-placeholder": "fg:#6b7280 italic", - "running-prompt-separator": "fg:#d1d5db", - "slash-completion-menu": "", - "slash-completion-menu.separator": "fg:#d1d5db", - "slash-completion-menu.marker": "fg:#9ca3af", - "slash-completion-menu.marker.current": "fg:#2563eb", - "slash-completion-menu.command": "fg:#4b5563", - "slash-completion-menu.meta": "fg:#6b7280", - "slash-completion-menu.command.current": "fg:#1d4ed8 bold", - "slash-completion-menu.meta.current": "fg:#2563eb", -} - - -# --------------------------------------------------------------------------- -# Bottom toolbar fragment colors (used by ui/shell/prompt.py) -# --------------------------------------------------------------------------- - - -@dataclass(frozen=True, slots=True) -class ToolbarColors: - separator: str - yolo_label: str - plan_label: str - plan_prompt: str - cwd: str - bg_tasks: str - tip: str - - -_TOOLBAR_DARK = ToolbarColors( - separator="fg:#4d4d4d", - yolo_label="bold fg:#ffff00", - plan_label="bold fg:#00aaff", - plan_prompt="fg:#00aaff", - cwd="fg:#666666", - bg_tasks="fg:#888888", - tip="fg:#555555", -) - -_TOOLBAR_LIGHT = ToolbarColors( - separator="fg:#d1d5db", - yolo_label="bold fg:#b45309", - plan_label="bold fg:#2563eb", - plan_prompt="fg:#2563eb", - cwd="fg:#6b7280", - bg_tasks="fg:#4b5563", - tip="fg:#9ca3af", -) - - -# --------------------------------------------------------------------------- -# MCP status prompt colors (used by ui/shell/mcp_status.py) -# --------------------------------------------------------------------------- - - -@dataclass(frozen=True, slots=True) -class MCPPromptColors: - text: str - detail: str - connected: str - connecting: str - pending: str - failed: str - - -_MCP_PROMPT_DARK = MCPPromptColors( - text="fg:#d4d4d4", - detail="fg:#7c8594", - connected="fg:#56d364", - connecting="fg:#56a4ff", - pending="fg:#f2cc60", - failed="fg:#ff7b72", -) - -_MCP_PROMPT_LIGHT = MCPPromptColors( - text="fg:#374151", - detail="fg:#6b7280", - connected="fg:#166534", - connecting="fg:#1d4ed8", - pending="fg:#92400e", - failed="fg:#dc2626", -) - - -# --------------------------------------------------------------------------- -# Public API — resolve by theme name -# --------------------------------------------------------------------------- - -_active_theme: ThemeName = "dark" - - -def set_active_theme(theme: ThemeName) -> None: - global _active_theme - _active_theme = theme - - -def get_active_theme() -> ThemeName: - return _active_theme - - -def get_diff_colors() -> DiffColors: - return _DIFF_LIGHT if _active_theme == "light" else _DIFF_DARK - - -def get_task_browser_style() -> PTKStyle: - return _task_browser_style_light() if _active_theme == "light" else _task_browser_style_dark() - - -def get_prompt_style() -> PTKStyle: - d = _PROMPT_STYLE_LIGHT if _active_theme == "light" else _PROMPT_STYLE_DARK - return PTKStyle.from_dict(d) - - -def get_toolbar_colors() -> ToolbarColors: - return _TOOLBAR_LIGHT if _active_theme == "light" else _TOOLBAR_DARK - - -def get_mcp_prompt_colors() -> MCPPromptColors: - return _MCP_PROMPT_LIGHT if _active_theme == "light" else _MCP_PROMPT_DARK diff --git a/src/kimi_cli/ui/theme.ts b/src/kimi_cli/ui/theme.ts new file mode 100644 index 000000000..324abcac1 --- /dev/null +++ b/src/kimi_cli/ui/theme.ts @@ -0,0 +1,267 @@ +/** + * Centralized terminal color theme definitions. + * Corresponds to Python's ui/theme.py. + * + * All UI-facing colors live here so that switching between dark and light + * terminal themes only requires changing the active ThemeName. + */ + +import chalk, { type ChalkInstance } from "chalk"; + +export type ThemeName = "dark" | "light"; + +// ── Diff Colors ──────────────────────────────────────────── + +export interface DiffColors { + addBg: string; + delBg: string; + addHl: string; + delHl: string; +} + +const DIFF_DARK: DiffColors = { + addBg: "#12261e", + delBg: "#2d1214", + addHl: "#1a4a2e", + delHl: "#5c1a1d", +}; + +const DIFF_LIGHT: DiffColors = { + addBg: "#dafbe1", + delBg: "#ffebe9", + addHl: "#aff5b4", + delHl: "#ffc1c0", +}; + +// ── Toolbar Colors ───────────────────────────────────────── + +export interface ToolbarColors { + separator: string; + yoloLabel: string; + planLabel: string; + planPrompt: string; + cwd: string; + bgTasks: string; + tip: string; +} + +const TOOLBAR_DARK: ToolbarColors = { + separator: "#4d4d4d", + yoloLabel: "#ffff00", + planLabel: "#00aaff", + planPrompt: "#00aaff", + cwd: "#666666", + bgTasks: "#888888", + tip: "#555555", +}; + +const TOOLBAR_LIGHT: ToolbarColors = { + separator: "#d1d5db", + yoloLabel: "#b45309", + planLabel: "#2563eb", + planPrompt: "#2563eb", + cwd: "#6b7280", + bgTasks: "#4b5563", + tip: "#9ca3af", +}; + +// ── MCP Prompt Colors ────────────────────────────────────── + +export interface MCPPromptColors { + text: string; + detail: string; + connected: string; + connecting: string; + pending: string; + failed: string; +} + +const MCP_PROMPT_DARK: MCPPromptColors = { + text: "#d4d4d4", + detail: "#7c8594", + connected: "#56d364", + connecting: "#56a4ff", + pending: "#f2cc60", + failed: "#ff7b72", +}; + +const MCP_PROMPT_LIGHT: MCPPromptColors = { + text: "#374151", + detail: "#6b7280", + connected: "#166534", + connecting: "#1d4ed8", + pending: "#92400e", + failed: "#dc2626", +}; + +// ── Message Colors ───────────────────────────────────────── + +export interface MessageColors { + user: string; + assistant: string; + system: string; + tool: string; + error: string; + dim: string; + thinking: string; + highlight: string; +} + +const MESSAGE_DARK: MessageColors = { + user: "#56d364", // Rich "green" + assistant: "#e0e0e0", // bright text for readability + system: "#d670d6", // Rich "magenta" + tool: "#C8C5F4", // Rich "blue" — lavender as seen in Python + error: "#ff7b72", // Rich "dark_red" + dim: "#808080", // Rich "grey50" + thinking: "#7c8594", + highlight: "#56d364", // Rich "green" +}; + +const MESSAGE_LIGHT: MessageColors = { + user: "#166534", // dark green + assistant: "#1f2937", // dark text + system: "#7c3aed", // dark purple + tool: "#1d4ed8", // dark blue + error: "#dc2626", + dim: "#6b7280", + thinking: "#6b7280", + highlight: "#166534", +}; + +// ── Chalk helpers ────────────────────────────────────────── + +export interface ThemeStyles { + user: ChalkInstance; + assistant: ChalkInstance; + system: ChalkInstance; + tool: ChalkInstance; + error: ChalkInstance; + dim: ChalkInstance; + thinking: ChalkInstance; + highlight: ChalkInstance; + bold: ChalkInstance; + italic: ChalkInstance; +} + +function makeStyles(colors: MessageColors): ThemeStyles { + return { + user: chalk.hex(colors.user), + assistant: chalk.hex(colors.assistant), + system: chalk.hex(colors.system), + tool: chalk.hex(colors.tool), + error: chalk.hex(colors.error), + dim: chalk.hex(colors.dim), + thinking: chalk.italic.hex(colors.thinking), + highlight: chalk.hex(colors.highlight), + bold: chalk.bold, + italic: chalk.italic, + }; +} + +// ── Prompt Style ────────────────────────────────────────── + +export interface PromptStyleColors { + sparkle: string; + streamingSparkle: string; + inputText: string; + placeholder: string; + border: string; +} + +const PROMPT_DARK: PromptStyleColors = { + sparkle: "#f2cc60", + streamingSparkle: "#56a4ff", + inputText: "#e6e6e6", + placeholder: "#555555", + border: "#4d4d4d", +}; + +const PROMPT_LIGHT: PromptStyleColors = { + sparkle: "#b45309", + streamingSparkle: "#2563eb", + inputText: "#1f2937", + placeholder: "#9ca3af", + border: "#d1d5db", +}; + +// ── Task Browser Style ─────────────────────────────────── + +export interface TaskBrowserColors { + headerBg: string; + headerFg: string; + selectedBg: string; + selectedFg: string; + borderColor: string; + runningFg: string; + completedFg: string; + failedFg: string; + killedFg: string; + listBg: string; +} + +const TASK_BROWSER_DARK: TaskBrowserColors = { + headerBg: "#1f2937", + headerFg: "#67e8f9", + selectedBg: "#164e63", + selectedFg: "#ecfeff", + borderColor: "#155e75", + runningFg: "#86efac", + completedFg: "#56d364", + failedFg: "#fca5a5", + killedFg: "#fbbf24", + listBg: "#0f172a", +}; + +const TASK_BROWSER_LIGHT: TaskBrowserColors = { + headerBg: "#f3f4f6", + headerFg: "#0e7490", + selectedBg: "#e0f2fe", + selectedFg: "#164e63", + borderColor: "#67e8f9", + runningFg: "#166534", + completedFg: "#166534", + failedFg: "#dc2626", + killedFg: "#b45309", + listBg: "#ffffff", +}; + +// ── Public API ───────────────────────────────────────────── + +let activeTheme: ThemeName = "dark"; + +export function setActiveTheme(theme: ThemeName): void { + activeTheme = theme; +} + +export function getActiveTheme(): ThemeName { + return activeTheme; +} + +export function getDiffColors(): DiffColors { + return activeTheme === "light" ? DIFF_LIGHT : DIFF_DARK; +} + +export function getToolbarColors(): ToolbarColors { + return activeTheme === "light" ? TOOLBAR_LIGHT : TOOLBAR_DARK; +} + +export function getMcpPromptColors(): MCPPromptColors { + return activeTheme === "light" ? MCP_PROMPT_LIGHT : MCP_PROMPT_DARK; +} + +export function getMessageColors(): MessageColors { + return activeTheme === "light" ? MESSAGE_LIGHT : MESSAGE_DARK; +} + +export function getStyles(): ThemeStyles { + return makeStyles(getMessageColors()); +} + +export function getPromptColors(): PromptStyleColors { + return activeTheme === "light" ? PROMPT_LIGHT : PROMPT_DARK; +} + +export function getTaskBrowserColors(): TaskBrowserColors { + return activeTheme === "light" ? TASK_BROWSER_LIGHT : TASK_BROWSER_DARK; +} diff --git a/src/kimi_cli/utils/__init__.py b/src/kimi_cli/utils/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/kimi_cli/utils/aiohttp.py b/src/kimi_cli/utils/aiohttp.py deleted file mode 100644 index bff269076..000000000 --- a/src/kimi_cli/utils/aiohttp.py +++ /dev/null @@ -1,24 +0,0 @@ -from __future__ import annotations - -import ssl - -import aiohttp -import certifi - -_ssl_context = ssl.create_default_context(cafile=certifi.where()) - -_DEFAULT_TIMEOUT = aiohttp.ClientTimeout( - total=120, - sock_read=60, - sock_connect=15, -) - - -def new_client_session( - *, - timeout: aiohttp.ClientTimeout | None = None, -) -> aiohttp.ClientSession: - return aiohttp.ClientSession( - connector=aiohttp.TCPConnector(ssl=_ssl_context), - timeout=timeout or _DEFAULT_TIMEOUT, - ) diff --git a/src/kimi_cli/utils/aioqueue.py b/src/kimi_cli/utils/aioqueue.py deleted file mode 100644 index 92756f662..000000000 --- a/src/kimi_cli/utils/aioqueue.py +++ /dev/null @@ -1,72 +0,0 @@ -from __future__ import annotations - -import asyncio -import sys - -if sys.version_info >= (3, 13): - QueueShutDown = asyncio.QueueShutDown # type: ignore[assignment] - - class Queue[T](asyncio.Queue[T]): - """Asyncio Queue with shutdown support.""" - -else: - - class QueueShutDown(Exception): - """Raised when operating on a shut down queue.""" - - class _Shutdown: - """Sentinel for queue shutdown.""" - - _SHUTDOWN = _Shutdown() - - class Queue[T](asyncio.Queue[T | _Shutdown]): - """Asyncio Queue with shutdown support for Python < 3.13.""" - - def __init__(self) -> None: - super().__init__() - self._shutdown = False - - def shutdown(self, immediate: bool = False) -> None: - if self._shutdown: - return - self._shutdown = True - if immediate: - self._queue.clear() - - getters = list(getattr(self, "_getters", [])) - count = max(1, len(getters)) - self._enqueue_shutdown(count) - - def _enqueue_shutdown(self, count: int) -> None: - for _ in range(count): - try: - super().put_nowait(_SHUTDOWN) - except asyncio.QueueFull: - self._queue.clear() - super().put_nowait(_SHUTDOWN) - - async def get(self) -> T: - if self._shutdown and self.empty(): - raise QueueShutDown - item = await super().get() - if isinstance(item, _Shutdown): - raise QueueShutDown - return item - - def get_nowait(self) -> T: - if self._shutdown and self.empty(): - raise QueueShutDown - item = super().get_nowait() - if isinstance(item, _Shutdown): - raise QueueShutDown - return item - - async def put(self, item: T) -> None: - if self._shutdown: - raise QueueShutDown - await super().put(item) - - def put_nowait(self, item: T) -> None: - if self._shutdown: - raise QueueShutDown - super().put_nowait(item) diff --git a/src/kimi_cli/utils/async.ts b/src/kimi_cli/utils/async.ts new file mode 100644 index 000000000..2c5af9fdc --- /dev/null +++ b/src/kimi_cli/utils/async.ts @@ -0,0 +1,78 @@ +/** + * Async utilities — corresponds to Python utils/async patterns + */ + +/** Sleep for given milliseconds. */ +export function sleep(ms: number): Promise { + return new Promise((resolve) => setTimeout(resolve, ms)); +} + +/** Run a function with a timeout. Rejects with TimeoutError if exceeded. */ +export async function withTimeout(fn: () => Promise, timeoutMs: number): Promise { + return Promise.race([ + fn(), + new Promise((_, reject) => + setTimeout(() => reject(new TimeoutError(`Timed out after ${timeoutMs}ms`)), timeoutMs), + ), + ]); +} + +export class TimeoutError extends Error { + constructor(message: string) { + super(message); + this.name = "TimeoutError"; + } +} + +/** + * Deferred — a promise that can be resolved/rejected externally. + * Similar to Python's asyncio.Future. + */ +export class Deferred { + readonly promise: Promise; + resolve!: (value: T) => void; + reject!: (reason: unknown) => void; + private _settled = false; + + constructor() { + this.promise = new Promise((resolve, reject) => { + this.resolve = (v: T) => { + if (!this._settled) { + this._settled = true; + resolve(v); + } + }; + this.reject = (r: unknown) => { + if (!this._settled) { + this._settled = true; + reject(r); + } + }; + }); + } + + get settled(): boolean { + return this._settled; + } +} + +/** Run tasks with a concurrency limit. */ +export async function mapConcurrent( + items: T[], + concurrency: number, + fn: (item: T) => Promise, +): Promise { + const results: R[] = new Array(items.length); + let index = 0; + + async function worker() { + while (index < items.length) { + const i = index++; + results[i] = await fn(items[i]!); + } + } + + const workers = Array.from({ length: Math.min(concurrency, items.length) }, () => worker()); + await Promise.all(workers); + return results; +} diff --git a/src/kimi_cli/utils/broadcast.py b/src/kimi_cli/utils/broadcast.py deleted file mode 100644 index 296ddfd0e..000000000 --- a/src/kimi_cli/utils/broadcast.py +++ /dev/null @@ -1,37 +0,0 @@ -import asyncio - -from kimi_cli.utils.aioqueue import Queue - - -class BroadcastQueue[T]: - """ - A broadcast queue that allows multiple subscribers to receive published items. - """ - - def __init__(self) -> None: - self._queues: set[Queue[T]] = set() - - def subscribe(self) -> Queue[T]: - """Create a new subscription queue.""" - queue: Queue[T] = Queue() - self._queues.add(queue) - return queue - - def unsubscribe(self, queue: Queue[T]) -> None: - """Remove a subscription queue.""" - self._queues.discard(queue) - - async def publish(self, item: T) -> None: - """Publish an item to all subscription queues.""" - await asyncio.gather(*(queue.put(item) for queue in self._queues)) - - def publish_nowait(self, item: T) -> None: - """Publish an item to all subscription queues without waiting.""" - for queue in self._queues: - queue.put_nowait(item) - - def shutdown(self, immediate: bool = False) -> None: - """Close all subscription queues.""" - for queue in self._queues: - queue.shutdown(immediate=immediate) - self._queues.clear() diff --git a/src/kimi_cli/utils/changelog.py b/src/kimi_cli/utils/changelog.py deleted file mode 100644 index 053e7b1cb..000000000 --- a/src/kimi_cli/utils/changelog.py +++ /dev/null @@ -1,108 +0,0 @@ -from __future__ import annotations - -from pathlib import Path -from typing import NamedTuple - - -class ReleaseEntry(NamedTuple): - description: str - entries: list[str] - - -def parse_changelog(md_text: str) -> dict[str, ReleaseEntry]: - """Parse a subset of Keep a Changelog-style markdown into a map: - version -> (description, entries) - - Parsing rules: - - Versions are denoted by level-2 headings starting with '## [' - Example: `## [v0.10.1] - 2025-09-18` or `## [Unreleased]` - - For each version section, description is the first contiguous block of - non-empty lines that do not start with '-' or '#'. - - Entries are all markdown list items starting with '- ' under that version - (across any subheadings like '### Added'). - """ - lines = md_text.splitlines() - result: dict[str, ReleaseEntry] = {} - - current_ver: str | None = None - collecting_desc = False - desc_lines: list[str] = [] - bullet_lines: list[str] = [] - seen_content_after_header = False - - def commit(): - nonlocal current_ver, desc_lines, bullet_lines, result - if current_ver is None: - return - description = "\n".join([line.strip() for line in desc_lines]).strip() - # Deduplicate and normalize entries - norm_entries = [ - line.strip()[2:].strip() for line in bullet_lines if line.strip().startswith("- ") - ] - result[current_ver] = ReleaseEntry(description=description, entries=norm_entries) - - for raw in lines: - line = raw.rstrip() - # Format: `## 0.75 (2026-01-09)` or `## Unreleased` - if line.startswith("## "): - commit() - ver = line[3:].strip() - # Remove trailing date in parentheses if present - if "(" in ver: - ver = ver[: ver.find("(")].strip() - current_ver = ver - desc_lines = [] - bullet_lines = [] - collecting_desc = True - seen_content_after_header = False - continue - - if current_ver is None: - # Skip until first version section - continue - - if not line.strip(): - # blank line ends initial description block only after we've seen content - if collecting_desc and seen_content_after_header: - collecting_desc = False - continue - - seen_content_after_header = True - - if line.lstrip().startswith("### "): - collecting_desc = False - continue - - if line.lstrip().startswith("- "): - collecting_desc = False - bullet_lines.append(line.strip()) - continue - - if collecting_desc: - # Accumulate description until a blank line or bullets/subheadings - desc_lines.append(line.strip()) - # else: ignore any other free-form text after description block - - # Final flush - commit() - return result - - -def format_release_notes(changelog: dict[str, ReleaseEntry], include_lib_changes: bool) -> str: - parts: list[str] = [] - for ver, entry in changelog.items(): - s = f"[bold]{ver}[/bold]" - if entry.description: - s += f": {entry.description}" - if entry.entries: - for it in entry.entries: - if it.lower().startswith("lib:") and not include_lib_changes: - continue - s += "\n[markdown.item.bullet]• [/]" + it - parts.append(s + "\n") - return "\n".join(parts).strip() - - -CHANGELOG = parse_changelog( - (Path(__file__).parent.parent / "CHANGELOG.md").read_text(encoding="utf-8") -) diff --git a/src/kimi_cli/utils/changelog.ts b/src/kimi_cli/utils/changelog.ts new file mode 100644 index 000000000..a79be52c8 --- /dev/null +++ b/src/kimi_cli/utils/changelog.ts @@ -0,0 +1,19 @@ +/** + * Changelog data — corresponds to Python utils/changelog.py + */ + +export interface ChangelogEntry { + description: string; + entries: string[]; +} + +export const CHANGELOG: Record = { + "2.0.0": { + description: "TypeScript rewrite", + entries: [ + "Complete rewrite in TypeScript with Bun runtime", + "New Ink-based terminal UI", + "Improved performance and startup time", + ], + }, +}; diff --git a/src/kimi_cli/utils/clipboard.py b/src/kimi_cli/utils/clipboard.py deleted file mode 100644 index ac76363d4..000000000 --- a/src/kimi_cli/utils/clipboard.py +++ /dev/null @@ -1,169 +0,0 @@ -from __future__ import annotations - -import importlib -import os -import sys -from collections.abc import Iterable -from dataclasses import dataclass -from pathlib import Path -from typing import Any, cast - -import pyperclip -from PIL import Image, ImageGrab - -# Video file extensions recognized for clipboard paste. -_VIDEO_SUFFIXES: frozenset[str] = frozenset( - {".mp4", ".mkv", ".avi", ".mov", ".wmv", ".webm", ".m4v", ".flv", ".3gp", ".3g2"} -) - - -@dataclass(frozen=True, slots=True) -class ClipboardResult: - """Result of reading media from the clipboard. - - Both fields may be non-empty when the clipboard contains a mix of - image files and non-image files (videos, PDFs, etc.). - """ - - images: tuple[Image.Image, ...] - file_paths: tuple[Path, ...] - - -def is_clipboard_available() -> bool: - """Check if the Pyperclip clipboard is available.""" - try: - pyperclip.paste() - return True - except Exception: - return False - - -def grab_media_from_clipboard() -> ClipboardResult | None: - """Read media from the clipboard. - - Inspects the clipboard once and returns all detected media. - Image files are returned as loaded PIL images; non-image files - (videos, PDFs, etc.) are returned as file paths. - - On macOS the native pasteboard API is tried first to avoid - misidentifying a file's thumbnail as clipboard image data. - """ - # 1. Try macOS native API for file paths (most reliable for Finder copies). - if sys.platform == "darwin": - file_paths = _read_clipboard_file_paths_macos_native() - images, non_image_paths = _classify_file_paths(file_paths) - if images or non_image_paths: - return ClipboardResult( - images=tuple(images), - file_paths=tuple(non_image_paths), - ) - - # 2. Try PIL ImageGrab as fallback. - # - On macOS this uses AppleScript «class furl» for file paths, - # or reads raw image data (TIFF/PNG) from the pasteboard. - # - On other platforms this is the primary clipboard access method. - payload = ImageGrab.grabclipboard() - if payload is None: - return None - if isinstance(payload, Image.Image): - # Raw image data (screenshot or thumbnail). - # If we reach here, the macOS native path lookup did not find any - # file paths, so this is safe to treat as a real image. - return ClipboardResult(images=(payload,), file_paths=()) - # payload is a list of file path strings. - images, non_image_paths = _classify_file_paths(payload) - if images or non_image_paths: - return ClipboardResult( - images=tuple(images), - file_paths=tuple(non_image_paths), - ) - return None - - -def _classify_file_paths( - paths: Iterable[os.PathLike[str] | str], -) -> tuple[list[Image.Image], list[Path]]: - """Classify clipboard file paths into images and non-image files. - - Returns ``(images, non_image_paths)`` where *images* contains loaded - PIL images and *non_image_paths* contains paths to videos, documents, - and other non-image files. - """ - resolved: list[Path] = [] - for item in paths: - try: - path = Path(item) - except (TypeError, ValueError): - continue - if not path.is_file(): - continue - resolved.append(path) - - images: list[Image.Image] = [] - non_image_paths: list[Path] = [] - - for path in resolved: - # Video files are never opened as images. - if path.suffix.lower() in _VIDEO_SUFFIXES: - non_image_paths.append(path) - continue - try: - with Image.open(path) as img: - img.load() - images.append(img.copy()) - except Exception: - non_image_paths.append(path) - - return images, non_image_paths - - -def _read_clipboard_file_paths_macos_native() -> list[Path]: - try: - appkit = cast(Any, importlib.import_module("AppKit")) - foundation = cast(Any, importlib.import_module("Foundation")) - except Exception: - return [] - - NSPasteboard = appkit.NSPasteboard - NSURL = foundation.NSURL - options_key = getattr( - appkit, - "NSPasteboardURLReadingFileURLsOnlyKey", - "NSPasteboardURLReadingFileURLsOnlyKey", - ) - - pb = NSPasteboard.generalPasteboard() - options = {options_key: True} - try: - urls: list[Any] | None = pb.readObjectsForClasses_options_([NSURL], options) - except Exception: - urls = None - - paths: list[Path] = [] - if urls: - for url in urls: - try: - path = url.path() - except Exception: - continue - if path: - paths.append(Path(str(path))) - - if paths: - return paths - - try: - file_list = cast(list[str] | str | None, pb.propertyListForType_("NSFilenamesPboardType")) - except Exception: - return [] - - if not file_list: - return [] - - file_items: list[str] = [] - if isinstance(file_list, list): - file_items.extend(item for item in file_list if item) - else: - file_items.append(file_list) - - return [Path(item) for item in file_items] diff --git a/src/kimi_cli/utils/clipboard.ts b/src/kimi_cli/utils/clipboard.ts new file mode 100644 index 000000000..d9bc13f3c --- /dev/null +++ b/src/kimi_cli/utils/clipboard.ts @@ -0,0 +1,94 @@ +/** + * Clipboard utilities — corresponds to Python utils/clipboard.py + * Clipboard access for media (images, files) via system commands. + */ + +import { logger } from "./logging.ts"; + +export interface ClipboardResult { + readonly imagePaths: string[]; + readonly filePaths: string[]; + readonly text?: string; +} + +/** + * Check if clipboard text access is available. + */ +export async function isClipboardAvailable(): Promise { + try { + if (process.platform === "darwin") { + const proc = Bun.spawn(["pbpaste"], { stdout: "pipe", stderr: "pipe" }); + await proc.exited; + return true; + } + if (process.platform === "linux") { + // Try xclip first, then xsel + for (const cmd of ["xclip", "xsel"]) { + try { + const proc = Bun.spawn(["which", cmd], { stdout: "pipe", stderr: "pipe" }); + const code = await proc.exited; + if (code === 0) return true; + } catch { + continue; + } + } + } + return false; + } catch { + return false; + } +} + +/** + * Read text from clipboard. + */ +export async function readClipboardText(): Promise { + try { + let proc: ReturnType; + if (process.platform === "darwin") { + proc = Bun.spawn(["pbpaste"], { stdout: "pipe", stderr: "pipe" }); + } else if (process.platform === "linux") { + proc = Bun.spawn(["xclip", "-selection", "clipboard", "-o"], { + stdout: "pipe", + stderr: "pipe", + }); + } else { + return undefined; + } + const code = await proc.exited; + if (code !== 0) return undefined; + return await new Response(proc.stdout as ReadableStream).text(); + } catch { + logger.debug("Failed to read clipboard text"); + return undefined; + } +} + +/** + * Write text to clipboard. + */ +export async function writeClipboardText(text: string): Promise { + try { + let proc: ReturnType; + if (process.platform === "darwin") { + proc = Bun.spawn(["pbcopy"], { + stdin: new Blob([text]), + stdout: "pipe", + stderr: "pipe", + }); + } else if (process.platform === "linux") { + proc = Bun.spawn(["xclip", "-selection", "clipboard"], { + stdin: new Blob([text]), + stdout: "pipe", + stderr: "pipe", + }); + } else { + return false; + } + const code = await proc.exited; + return code === 0; + } catch { + logger.debug("Failed to write clipboard text"); + return false; + } +} diff --git a/src/kimi_cli/utils/datetime.py b/src/kimi_cli/utils/datetime.py deleted file mode 100644 index 860d2e69d..000000000 --- a/src/kimi_cli/utils/datetime.py +++ /dev/null @@ -1,37 +0,0 @@ -from datetime import datetime, timedelta - - -def format_relative_time(timestamp: float) -> str: - """Format a timestamp as a relative time string.""" - now = datetime.now() - dt = datetime.fromtimestamp(timestamp) - diff = now - dt - if diff < timedelta(minutes=5): - return "just now" - if diff < timedelta(hours=1): - minutes = int(diff.total_seconds() / 60) - return f"{minutes}m ago" - if diff < timedelta(days=1): - hours = int(diff.total_seconds() / 3600) - return f"{hours}h ago" - if diff < timedelta(days=7): - return f"{diff.days}d ago" - return dt.strftime("%m-%d") - - -def format_duration(seconds: int) -> str: - """Format a duration in seconds using short units.""" - delta = timedelta(seconds=seconds) - parts: list[str] = [] - days = delta.days - if days: - parts.append(f"{days}d") - hours, remainder = divmod(delta.seconds, 3600) - minutes, secs = divmod(remainder, 60) - if hours: - parts.append(f"{hours}h") - if minutes: - parts.append(f"{minutes}m") - if secs and not parts: - parts.append(f"{secs}s") - return " ".join(parts) or "0s" diff --git a/src/kimi_cli/utils/diff.py b/src/kimi_cli/utils/diff.py deleted file mode 100644 index 7678c2e6d..000000000 --- a/src/kimi_cli/utils/diff.py +++ /dev/null @@ -1,135 +0,0 @@ -from __future__ import annotations - -import asyncio -import difflib -from difflib import SequenceMatcher - -from kosong.tooling import DisplayBlock - -from kimi_cli.tools.display import DiffDisplayBlock - -N_CONTEXT_LINES = 3 - -_HUGE_FILE_THRESHOLD = 10000 -"""Line count above which diff computation is skipped entirely.""" - - -def format_unified_diff( - old_text: str, - new_text: str, - path: str = "", - *, - include_file_header: bool = True, -) -> str: - """ - Format a unified diff between old_text and new_text. - - Args: - old_text: The original text. - new_text: The new text. - path: Optional file path for the diff header. - include_file_header: Whether to include the ---/+++ file header lines. - - Returns: - A unified diff string. - """ - old_lines = old_text.splitlines(keepends=True) - new_lines = new_text.splitlines(keepends=True) - - # Ensure lines end with newline for proper diff formatting - if old_lines and not old_lines[-1].endswith("\n"): - old_lines[-1] += "\n" - if new_lines and not new_lines[-1].endswith("\n"): - new_lines[-1] += "\n" - - fromfile = f"a/{path}" if path else "a/file" - tofile = f"b/{path}" if path else "b/file" - - diff = list( - difflib.unified_diff( - old_lines, - new_lines, - fromfile=fromfile, - tofile=tofile, - lineterm="\n", - ) - ) - - if ( - not include_file_header - and len(diff) >= 2 - and diff[0].startswith("--- ") - and diff[1].startswith("+++ ") - ): - diff = diff[2:] - - return "".join(diff) - - -def _build_diff_blocks_sync( - path: str, - old_text: str, - new_text: str, -) -> list[DisplayBlock]: - """Synchronous diff block builder — CPU-bound, meant to run in a thread.""" - if old_text == new_text: - return [] - - old_lines = old_text.splitlines() - new_lines = new_text.splitlines() - - max_lines = max(len(old_lines), len(new_lines)) - - # Huge files: skip diff entirely, return a summary block - if max_lines > _HUGE_FILE_THRESHOLD: - old_desc = f"({len(old_lines)} lines)" - if len(old_lines) == len(new_lines): - new_desc = f"({len(new_lines)} lines, modified)" - else: - new_desc = f"({len(new_lines)} lines)" - return [ - DiffDisplayBlock( - path=path, - old_text=old_desc, - new_text=new_desc, - old_start=1, - new_start=1, - is_summary=True, - ) - ] - - matcher = SequenceMatcher(None, old_lines, new_lines, autojunk=False) - - blocks: list[DisplayBlock] = [] - for group in matcher.get_grouped_opcodes(n=N_CONTEXT_LINES): - if not group: - continue - i1 = group[0][1] - i2 = group[-1][2] - j1 = group[0][3] - j2 = group[-1][4] - blocks.append( - DiffDisplayBlock( - path=path, - old_text="\n".join(old_lines[i1:i2]), - new_text="\n".join(new_lines[j1:j2]), - old_start=i1 + 1, - new_start=j1 + 1, - ) - ) - return blocks - - -async def build_diff_blocks( - path: str, - old_text: str, - new_text: str, -) -> list[DisplayBlock]: - """Build diff display blocks grouped with small context windows. - - Runs the CPU-bound diff computation in a thread to avoid blocking - the event loop. - """ - if old_text == new_text: - return [] - return await asyncio.to_thread(_build_diff_blocks_sync, path, old_text, new_text) diff --git a/src/kimi_cli/utils/diff.ts b/src/kimi_cli/utils/diff.ts new file mode 100644 index 000000000..748d71466 --- /dev/null +++ b/src/kimi_cli/utils/diff.ts @@ -0,0 +1,298 @@ +/** + * Diff utilities — corresponds to Python utils/diff.py + * Unified diff formatting and diff block generation. + */ + +const N_CONTEXT_LINES = 3; +const HUGE_FILE_THRESHOLD = 10000; + +/** + * Format a unified diff between old_text and new_text. + */ +export function formatUnifiedDiff( + oldText: string, + newText: string, + path = "", + opts?: { includeFileHeader?: boolean }, +): string { + const includeFileHeader = opts?.includeFileHeader ?? true; + const oldLines = oldText.split("\n"); + const newLines = newText.split("\n"); + + const fromFile = path ? `a/${path}` : "a/file"; + const toFile = path ? `b/${path}` : "b/file"; + + // Simple unified diff implementation + const hunks = computeHunks(oldLines, newLines, N_CONTEXT_LINES); + if (hunks.length === 0) return ""; + + const result: string[] = []; + if (includeFileHeader) { + result.push(`--- ${fromFile}`); + result.push(`+++ ${toFile}`); + } + + for (const hunk of hunks) { + result.push(hunk); + } + + return result.join("\n") + "\n"; +} + +function computeHunks(oldLines: string[], newLines: string[], context: number): string[] { + // LCS-based diff + const ops = diffLines(oldLines, newLines); + if (ops.length === 0) return []; + + const hunks: string[] = []; + let i = 0; + + while (i < ops.length) { + // Find next change + while (i < ops.length && ops[i] === "equal") i++; + if (i >= ops.length) break; + + // Determine hunk boundaries + const changeStart = i; + let changeEnd = i; + while (changeEnd < ops.length) { + if (ops[changeEnd] === "equal") { + // Check if there's another change within context + let nextChange = changeEnd; + while (nextChange < ops.length && ops[nextChange] === "equal") nextChange++; + if (nextChange >= ops.length || nextChange - changeEnd > context * 2) break; + changeEnd = nextChange; + } + changeEnd++; + } + + // Build hunk with context + const start = Math.max(0, changeStart - context); + const end = Math.min(ops.length, changeEnd + context); + + let oldStart = 0; + let newStart = 0; + for (let j = 0; j < start; j++) { + if (ops[j] === "equal" || ops[j] === "delete") oldStart++; + if (ops[j] === "equal" || ops[j] === "insert") newStart++; + } + + let oldCount = 0; + let newCount = 0; + const lines: string[] = []; + + let oldIdx = oldStart; + let newIdx = newStart; + for (let j = start; j < end; j++) { + const op = ops[j]!; + if (op === "equal") { + lines.push(` ${oldLines[oldIdx] ?? ""}`); + oldIdx++; + newIdx++; + oldCount++; + newCount++; + } else if (op === "delete") { + lines.push(`-${oldLines[oldIdx] ?? ""}`); + oldIdx++; + oldCount++; + } else if (op === "insert") { + lines.push(`+${newLines[newIdx] ?? ""}`); + newIdx++; + newCount++; + } + } + + hunks.push(`@@ -${oldStart + 1},${oldCount} +${newStart + 1},${newCount} @@`); + hunks.push(...lines); + + i = changeEnd; + } + + return hunks; +} + +function diffLines(oldLines: string[], newLines: string[]): ("equal" | "delete" | "insert")[] { + const m = oldLines.length; + const n = newLines.length; + + if (m === 0 && n === 0) return []; + if (m === 0) return new Array(n).fill("insert"); + if (n === 0) return new Array(m).fill("delete"); + + // Myers diff algorithm (simplified) + const max = m + n; + const v = new Array(2 * max + 1).fill(0); + const trace: number[][] = []; + + outer: for (let d = 0; d <= max; d++) { + trace.push([...v]); + for (let k = -d; k <= d; k += 2) { + let x: number; + if (k === -d || (k !== d && v[k - 1 + max]! < v[k + 1 + max]!)) { + x = v[k + 1 + max]!; + } else { + x = v[k - 1 + max]! + 1; + } + let y = x - k; + while (x < m && y < n && oldLines[x] === newLines[y]) { + x++; + y++; + } + v[k + max] = x; + if (x >= m && y >= n) break outer; + } + } + + // Backtrack to build edit script + const ops: ("equal" | "delete" | "insert")[] = []; + let x = m; + let y = n; + + for (let d = trace.length - 1; d > 0; d--) { + const prev = trace[d - 1]!; + const k = x - y; + let prevK: number; + if (k === -d || (k !== d && prev[k - 1 + max]! < prev[k + 1 + max]!)) { + prevK = k + 1; + } else { + prevK = k - 1; + } + const prevX = prev[prevK + max]!; + const prevY = prevX - prevK; + + // Diagonal moves (equal) + while (x > prevX + (prevK < k ? 1 : 0) && y > prevY + (prevK < k ? 0 : 1)) { + ops.push("equal"); + x--; + y--; + } + + if (prevK < k) { + ops.push("delete"); + x--; + } else { + ops.push("insert"); + y--; + } + } + + // Remaining diagonal + while (x > 0 && y > 0 && oldLines[x - 1] === newLines[y - 1]) { + ops.push("equal"); + x--; + y--; + } + + ops.reverse(); + return ops; +} + +export interface DiffBlock { + path: string; + oldText: string; + newText: string; + oldStart: number; + newStart: number; + isSummary?: boolean; +} + +/** + * Build diff display blocks grouped with small context windows. + */ +export function buildDiffBlocks( + path: string, + oldText: string, + newText: string, +): DiffBlock[] { + if (oldText === newText) return []; + + const oldLines = oldText.split("\n"); + const newLines = newText.split("\n"); + const maxLines = Math.max(oldLines.length, newLines.length); + + // Huge files: skip diff entirely + if (maxLines > HUGE_FILE_THRESHOLD) { + const oldDesc = `(${oldLines.length} lines)`; + const newDesc = + oldLines.length === newLines.length + ? `(${newLines.length} lines, modified)` + : `(${newLines.length} lines)`; + return [ + { + path, + oldText: oldDesc, + newText: newDesc, + oldStart: 1, + newStart: 1, + isSummary: true, + }, + ]; + } + + // Use simple sequence matching for blocks + const blocks: DiffBlock[] = []; + const ops = diffLines(oldLines, newLines); + + let i = 0; + while (i < ops.length) { + // Skip equal ops + while (i < ops.length && ops[i] === "equal") i++; + if (i >= ops.length) break; + + // Find change range + const changeStart = i; + let changeEnd = i; + while (changeEnd < ops.length) { + if (ops[changeEnd] === "equal") { + let nextChange = changeEnd; + while (nextChange < ops.length && ops[nextChange] === "equal") nextChange++; + if (nextChange >= ops.length || nextChange - changeEnd > N_CONTEXT_LINES * 2) break; + changeEnd = nextChange; + } + changeEnd++; + } + + const start = Math.max(0, changeStart - N_CONTEXT_LINES); + const end = Math.min(ops.length, changeEnd + N_CONTEXT_LINES); + + let oldIdx = 0; + let newIdx = 0; + for (let j = 0; j < start; j++) { + if (ops[j] === "equal" || ops[j] === "delete") oldIdx++; + if (ops[j] === "equal" || ops[j] === "insert") newIdx++; + } + + const oldStart = oldIdx; + const newStart = newIdx; + const blockOldLines: string[] = []; + const blockNewLines: string[] = []; + + for (let j = start; j < end; j++) { + const op = ops[j]!; + if (op === "equal") { + blockOldLines.push(oldLines[oldIdx]!); + blockNewLines.push(newLines[newIdx]!); + oldIdx++; + newIdx++; + } else if (op === "delete") { + blockOldLines.push(oldLines[oldIdx]!); + oldIdx++; + } else { + blockNewLines.push(newLines[newIdx]!); + newIdx++; + } + } + + blocks.push({ + path, + oldText: blockOldLines.join("\n"), + newText: blockNewLines.join("\n"), + oldStart: oldStart + 1, + newStart: newStart + 1, + }); + + i = changeEnd; + } + + return blocks; +} diff --git a/src/kimi_cli/utils/editor.py b/src/kimi_cli/utils/editor.py deleted file mode 100644 index 29511245a..000000000 --- a/src/kimi_cli/utils/editor.py +++ /dev/null @@ -1,91 +0,0 @@ -"""External editor utilities for editing text in $VISUAL/$EDITOR.""" - -from __future__ import annotations - -import contextlib -import os -import shlex -import shutil -import subprocess -import tempfile -from pathlib import Path - -from kimi_cli.utils.logging import logger -from kimi_cli.utils.subprocess_env import get_clean_env - -# VSCode needs --wait to block until the file is closed. -_EDITOR_CANDIDATES = [ - (["code", "--wait"], "code"), - (["vim"], "vim"), - (["vi"], "vi"), - (["nano"], "nano"), -] - - -def get_editor_command(configured: str = "") -> list[str] | None: - """Determine the editor command to use. - - Priority: *configured* (from config) -> $VISUAL -> $EDITOR -> auto-detect. - Auto-detect order: code --wait -> vim -> vi -> nano. - """ - if configured: - try: - return shlex.split(configured) - except ValueError: - logger.warning("Invalid configured editor value: {}", configured) - - for var in ("VISUAL", "EDITOR"): - value = os.environ.get(var) - if value: - try: - return shlex.split(value) - except ValueError: - logger.warning("Invalid {} value: {}", var, value) - continue - - for cmd, binary in _EDITOR_CANDIDATES: - if shutil.which(binary): - return cmd - - return None - - -def edit_text_in_editor(text: str, configured: str = "") -> str | None: - """Open *text* in an external editor and return the edited result. - - Returns ``None`` if the editor failed or the user quit without saving. - """ - editor_cmd = get_editor_command(configured) - if editor_cmd is None: - logger.warning("No editor found. Set $VISUAL or $EDITOR.") - return None - - fd, tmpfile = tempfile.mkstemp(suffix=".md", prefix="kimi-edit-") - try: - with os.fdopen(fd, "w", encoding="utf-8") as f: - f.write(text) - - mtime_before = os.path.getmtime(tmpfile) - - try: - returncode = subprocess.call(editor_cmd + [tmpfile], env=get_clean_env()) - except OSError as exc: - logger.warning("Failed to launch editor {}: {}", editor_cmd, exc) - return None - - if returncode != 0: - logger.warning("Editor exited with non-zero return code: {}", returncode) - return None - - mtime_after = os.path.getmtime(tmpfile) - if mtime_after == mtime_before: - return None - - edited = Path(tmpfile).read_text(encoding="utf-8") - if edited.endswith("\n"): - edited = edited[:-1] - - return edited - finally: - with contextlib.suppress(OSError): - os.unlink(tmpfile) diff --git a/src/kimi_cli/utils/environment.py b/src/kimi_cli/utils/environment.py deleted file mode 100644 index 6b4bb78ad..000000000 --- a/src/kimi_cli/utils/environment.py +++ /dev/null @@ -1,58 +0,0 @@ -from __future__ import annotations - -import platform -from dataclasses import dataclass -from typing import Literal - -from kaos.path import KaosPath - - -@dataclass(slots=True, frozen=True, kw_only=True) -class Environment: - os_kind: Literal["Windows", "Linux", "macOS"] | str - os_arch: str - os_version: str - shell_name: Literal["bash", "sh", "Windows PowerShell"] - shell_path: KaosPath - - @staticmethod - async def detect() -> Environment: - match platform.system(): - case "Darwin": - os_kind = "macOS" - case "Windows": - os_kind = "Windows" - case "Linux": - os_kind = "Linux" - case system: - os_kind = system - - os_arch = platform.machine() - os_version = platform.version() - - if os_kind == "Windows": - shell_name = "Windows PowerShell" - shell_path = KaosPath("powershell.exe") - else: - possible_paths = [ - KaosPath("/bin/bash"), - KaosPath("/usr/bin/bash"), - KaosPath("/usr/local/bin/bash"), - ] - fallback_path = KaosPath("/bin/sh") - for path in possible_paths: - if await path.is_file(): - shell_name = "bash" - shell_path = path - break - else: - shell_name = "sh" - shell_path = fallback_path - - return Environment( - os_kind=os_kind, - os_arch=os_arch, - os_version=os_version, - shell_name=shell_name, - shell_path=shell_path, - ) diff --git a/src/kimi_cli/utils/environment.ts b/src/kimi_cli/utils/environment.ts new file mode 100644 index 000000000..d5f863f56 --- /dev/null +++ b/src/kimi_cli/utils/environment.ts @@ -0,0 +1,79 @@ +/** + * Environment detection — corresponds to Python utils/environment.py + * Detects OS, architecture, and default shell. + */ + +import { existsSync } from "node:fs"; + +export type OsKind = "macOS" | "Windows" | "Linux" | string; +export type ShellName = "bash" | "sh" | "Windows PowerShell"; + +export interface Environment { + readonly osKind: OsKind; + readonly osArch: string; + readonly osVersion: string; + readonly shellName: ShellName; + readonly shellPath: string; +} + +/** + * Detect the current environment: OS, architecture, and default shell. + */ +export async function detectEnvironment(): Promise { + let osKind: OsKind; + switch (process.platform) { + case "darwin": + osKind = "macOS"; + break; + case "win32": + osKind = "Windows"; + break; + case "linux": + osKind = "Linux"; + break; + default: + osKind = process.platform; + } + + const osArch = process.arch; + + // Get OS version + let osVersion = ""; + try { + const { release } = await import("node:os"); + osVersion = release(); + } catch { + // Ignore + } + + let shellName: ShellName; + let shellPath: string; + + if (osKind === "Windows") { + shellName = "Windows PowerShell"; + shellPath = "powershell.exe"; + } else { + const bashPaths = ["/bin/bash", "/usr/bin/bash", "/usr/local/bin/bash"]; + let found = false; + for (const p of bashPaths) { + if (existsSync(p)) { + shellName = "bash"; + shellPath = p; + found = true; + break; + } + } + if (!found) { + shellName = "sh"; + shellPath = "/bin/sh"; + } + } + + return { + osKind, + osArch, + osVersion, + shellName: shellName!, + shellPath: shellPath!, + }; +} diff --git a/src/kimi_cli/utils/envvar.py b/src/kimi_cli/utils/envvar.py deleted file mode 100644 index 1c5656edb..000000000 --- a/src/kimi_cli/utils/envvar.py +++ /dev/null @@ -1,22 +0,0 @@ -from __future__ import annotations - -import os - -_TRUE_VALUES = {"1", "true", "t", "yes", "y"} - - -def get_env_bool(name: str, default: bool = False) -> bool: - value = os.getenv(name) - if value is None: - return default - return value.strip().lower() in _TRUE_VALUES - - -def get_env_int(name: str, default: int) -> int: - value = os.getenv(name) - if value is None: - return default - try: - return int(value) - except ValueError: - return default diff --git a/src/kimi_cli/utils/export.py b/src/kimi_cli/utils/export.py deleted file mode 100644 index 6541e1132..000000000 --- a/src/kimi_cli/utils/export.py +++ /dev/null @@ -1,696 +0,0 @@ -from __future__ import annotations - -import json -from collections.abc import Sequence -from datetime import datetime -from pathlib import Path -from textwrap import shorten -from typing import TYPE_CHECKING, cast - -import aiofiles -from kaos.path import KaosPath -from kosong.message import Message - -from kimi_cli.notifications.llm import is_notification_message -from kimi_cli.soul.message import is_system_reminder_message, system -from kimi_cli.utils.message import message_stringify -from kimi_cli.utils.path import sanitize_cli_path -from kimi_cli.wire.types import ( - AudioURLPart, - ContentPart, - ImageURLPart, - TextPart, - ThinkPart, - ToolCall, - VideoURLPart, -) - -if TYPE_CHECKING: - from kimi_cli.soul.context import Context - -# --------------------------------------------------------------------------- -# Export helpers -# --------------------------------------------------------------------------- - -_HINT_KEYS = ("path", "file_path", "command", "query", "url", "name", "pattern") -"""Common tool-call argument keys whose values make good one-line hints.""" - - -def _is_checkpoint_message(msg: Message) -> bool: - """Check if a message is an internal checkpoint marker.""" - if msg.role != "user" or len(msg.content) != 1: - return False - part = msg.content[0] - return isinstance(part, TextPart) and part.text.strip().startswith("CHECKPOINT") - - -def _is_internal_user_message(msg: Message) -> bool: - """Check if a user message is internal bookkeeping rather than real user input.""" - return ( - _is_checkpoint_message(msg) - or is_system_reminder_message(msg) - or is_notification_message(msg) - ) - - -def _extract_tool_call_hint(args_json: str) -> str: - """Extract a brief human-readable hint from tool-call arguments. - - Looks for well-known keys (path, command, …) and falls back to the first - short string value. Returns ``""`` when nothing useful is found. - """ - try: - parsed: object = json.loads(args_json, strict=False) - except (json.JSONDecodeError, TypeError): - return "" - if not isinstance(parsed, dict): - return "" - args = cast(dict[str, object], parsed) - - # Prefer well-known keys - for key in _HINT_KEYS: - val = args.get(key) - if isinstance(val, str) and val.strip(): - return shorten(val, width=60, placeholder="…") - - # Fallback: first short string value - for val in args.values(): - if isinstance(val, str) and 0 < len(val) <= 80: - return shorten(val, width=60, placeholder="…") - - return "" - - -def _format_content_part_md(part: ContentPart) -> str: - """Convert a single ContentPart to markdown text.""" - match part: - case TextPart(text=text): - return text - case ThinkPart(think=think): - if not think.strip(): - return "" - return f"
Thinking\n\n{think}\n\n
" - case ImageURLPart(): - return "[image]" - case AudioURLPart(): - return "[audio]" - case VideoURLPart(): - return "[video]" - case _: - return f"[{part.type}]" - - -def _format_tool_call_md(tool_call: ToolCall) -> str: - """Convert a ToolCall to a markdown sub-section with a readable title.""" - args_raw = tool_call.function.arguments or "{}" - hint = _extract_tool_call_hint(args_raw) - title = f"#### Tool Call: {tool_call.function.name}" - if hint: - title += f" (`{hint}`)" - - try: - parsed = json.loads(args_raw, strict=False) - args_formatted = json.dumps(parsed, indent=2, ensure_ascii=False) - except json.JSONDecodeError: - args_formatted = args_raw - - return f"{title}\n\n```json\n{args_formatted}\n```" - - -def _format_tool_result_md(msg: Message, tool_name: str, hint: str) -> str: - """Format a tool result message as a collapsible markdown block.""" - call_id = msg.tool_call_id or "unknown" - - # Use _format_content_part_md for consistency with the rest of the module - # (message_stringify loses ThinkPart and leaks tags) - result_parts: list[str] = [] - for part in msg.content: - text = _format_content_part_md(part) - if text.strip(): - result_parts.append(text) - result_text = "\n".join(result_parts) - - summary = f"Tool Result: {tool_name}" - if hint: - summary += f" (`{hint}`)" - - return ( - f"
{summary}\n\n" - f"\n" - f"{result_text}\n\n" - "
" - ) - - -def _group_into_turns(history: Sequence[Message]) -> list[list[Message]]: - """Group messages into logical turns, each starting at a real user message.""" - turns: list[list[Message]] = [] - current: list[Message] = [] - - for msg in history: - if _is_internal_user_message(msg): - continue - if msg.role == "user" and current: - turns.append(current) - current = [] - current.append(msg) - - if current: - turns.append(current) - return turns - - -def _format_turn_md(messages: list[Message], turn_number: int) -> str: - """Format a logical turn as a markdown section. - - A turn typically contains: - user message -> assistant (thinking + text + tool_calls) -> tool results - -> assistant (more text + tool_calls) -> tool results -> assistant (final) - All assistant/tool messages are grouped under a single ``### Assistant`` heading. - """ - lines: list[str] = [f"## Turn {turn_number}", ""] - - # tool_call_id -> (function_name, hint) - tool_call_info: dict[str, tuple[str, str]] = {} - assistant_header_written = False - - for msg in messages: - if _is_internal_user_message(msg): - continue - - if msg.role == "user": - lines.append("### User") - lines.append("") - for part in msg.content: - text = _format_content_part_md(part) - if text.strip(): - lines.append(text) - lines.append("") - - elif msg.role == "assistant": - if not assistant_header_written: - lines.append("### Assistant") - lines.append("") - assistant_header_written = True - - # Content parts (thinking, text, media) - for part in msg.content: - text = _format_content_part_md(part) - if text.strip(): - lines.append(text) - lines.append("") - - # Tool calls - if msg.tool_calls: - for tc in msg.tool_calls: - hint = _extract_tool_call_hint(tc.function.arguments or "{}") - tool_call_info[tc.id] = (tc.function.name, hint) - lines.append(_format_tool_call_md(tc)) - lines.append("") - - elif msg.role == "tool": - tc_id = msg.tool_call_id or "" - name, hint = tool_call_info.get(tc_id, ("unknown", "")) - lines.append(_format_tool_result_md(msg, name, hint)) - lines.append("") - - elif msg.role in ("system", "developer"): - lines.append(f"### {msg.role.capitalize()}") - lines.append("") - for part in msg.content: - text = _format_content_part_md(part) - if text.strip(): - lines.append(text) - lines.append("") - - return "\n".join(lines) - - -def _build_overview( - history: Sequence[Message], - turns: list[list[Message]], - token_count: int, -) -> str: - """Build the Overview section from existing data (no LLM call).""" - # Topic: first real user message text, truncated - topic = "" - for msg in history: - if msg.role == "user" and not _is_internal_user_message(msg): - topic = shorten(message_stringify(msg), width=80, placeholder="…") - break - - # Count tool calls across all messages - n_tool_calls = sum(len(msg.tool_calls) for msg in history if msg.tool_calls) - - lines = [ - "## Overview", - "", - f"- **Topic**: {topic}" if topic else "- **Topic**: (empty)", - f"- **Conversation**: {len(turns)} turns | " - f"{n_tool_calls} tool calls | {token_count:,} tokens", - "", - "---", - ] - return "\n".join(lines) - - -def build_export_markdown( - session_id: str, - work_dir: str, - history: Sequence[Message], - token_count: int, - now: datetime, -) -> str: - """Build the full export markdown string.""" - lines: list[str] = [ - "---", - f"session_id: {session_id}", - f"exported_at: {now.isoformat(timespec='seconds')}", - f"work_dir: {work_dir}", - f"message_count: {len(history)}", - f"token_count: {token_count}", - "---", - "", - "# Kimi Session Export", - "", - ] - - turns = _group_into_turns(history) - lines.append(_build_overview(history, turns, token_count)) - lines.append("") - - for idx, turn_messages in enumerate(turns): - lines.append(_format_turn_md(turn_messages, idx + 1)) - - return "\n".join(lines) - - -# --------------------------------------------------------------------------- -# Import helpers -# --------------------------------------------------------------------------- - -_IMPORTABLE_EXTENSIONS: frozenset[str] = frozenset( - { - # Markdown / plain text - ".md", - ".markdown", - ".txt", - ".text", - ".rst", - # Data / config - ".json", - ".jsonl", - ".yaml", - ".yml", - ".toml", - ".ini", - ".cfg", - ".conf", - ".csv", - ".tsv", - ".xml", - ".env", - ".properties", - # Source code - ".py", - ".js", - ".ts", - ".jsx", - ".tsx", - ".java", - ".kt", - ".go", - ".rs", - ".c", - ".cpp", - ".h", - ".hpp", - ".cs", - ".rb", - ".php", - ".swift", - ".scala", - ".sh", - ".bash", - ".zsh", - ".fish", - ".ps1", - ".bat", - ".cmd", - ".r", - ".R", - ".lua", - ".pl", - ".pm", - ".ex", - ".exs", - ".erl", - ".hs", - ".ml", - ".sql", - ".graphql", - ".proto", - # Web - ".html", - ".htm", - ".css", - ".scss", - ".sass", - ".less", - ".svg", - # Logs - ".log", - # Documentation - ".tex", - ".bib", - ".org", - ".adoc", - ".wiki", - } -) -"""File extensions accepted by ``/import``. Only text-based formats are -supported — importing binary files (images, PDFs, archives, …) is rejected -with a friendly message.""" - - -def is_importable_file(path_str: str) -> bool: - """Return True if *path_str* has an extension in the importable whitelist. - - Files with no extension are also accepted (could be READMEs, Makefiles, …). - """ - suffix = Path(path_str).suffix.lower() - return suffix == "" or suffix in _IMPORTABLE_EXTENSIONS - - -def _stringify_content_parts(parts: Sequence[ContentPart]) -> str: - """Serialize a list of ContentParts to readable text, preserving ThinkPart.""" - segments: list[str] = [] - for part in parts: - match part: - case TextPart(text=text): - if text.strip(): - segments.append(text) - case ThinkPart(think=think): - if think.strip(): - segments.append(f"\n{think}\n") - case ImageURLPart(): - segments.append("[image]") - case AudioURLPart(): - segments.append("[audio]") - case VideoURLPart(): - segments.append("[video]") - case _: - segments.append(f"[{part.type}]") - return "\n".join(segments) - - -def _stringify_tool_calls(tool_calls: Sequence[ToolCall]) -> str: - """Serialize tool calls to readable text.""" - lines: list[str] = [] - for tc in tool_calls: - args_raw = tc.function.arguments or "{}" - try: - args = json.loads(args_raw, strict=False) - args_str = json.dumps(args, ensure_ascii=False) - except (json.JSONDecodeError, TypeError): - args_str = args_raw - lines.append(f"Tool Call: {tc.function.name}({args_str})") - return "\n".join(lines) - - -def stringify_context_history(history: Sequence[Message]) -> str: - """Convert a sequence of Messages to a readable text transcript. - - Preserves ThinkPart content, tool call information, and tool results - so that an AI receiving the imported context has a complete picture. - """ - parts: list[str] = [] - for msg in history: - if _is_internal_user_message(msg): - continue - - role_label = msg.role.upper() - segments: list[str] = [] - - # Content parts (text, thinking, media) - content_text = _stringify_content_parts(msg.content) - if content_text.strip(): - segments.append(content_text) - - # Tool calls (only on assistant messages) - if msg.tool_calls: - segments.append(_stringify_tool_calls(msg.tool_calls)) - - if not segments: - continue - - header = f"[{role_label}]" - if msg.role == "tool" and msg.tool_call_id: - header = f"[{role_label}] (call_id: {msg.tool_call_id})" - - parts.append(f"{header}\n" + "\n".join(segments)) - return "\n\n".join(parts) - - -# --------------------------------------------------------------------------- -# Shared command logic -# --------------------------------------------------------------------------- - - -async def perform_export( - history: Sequence[Message], - session_id: str, - work_dir: str, - token_count: int, - args: str, - default_dir: Path, -) -> tuple[Path, int] | str: - """Perform the full export operation. - - Returns ``(output_path, message_count)`` on success, or an error message - string on failure. - """ - if not history: - return "No messages to export." - - now = datetime.now().astimezone() - short_id = session_id[:8] - default_name = f"kimi-export-{short_id}-{now.strftime('%Y%m%d-%H%M%S')}.md" - - cleaned = sanitize_cli_path(args) - if cleaned: - # sanitize_cli_path only strips quotes; it preserves trailing separators. - directory_hint = cleaned.endswith(("/", "\\")) - output = Path(cleaned).expanduser() - if not output.is_absolute(): - output = default_dir / output - # Keep explicit "directory intent" even when the directory does not exist yet. - if directory_hint or output.is_dir(): - output = output / default_name - else: - output = default_dir / default_name - - content = build_export_markdown( - session_id=session_id, - work_dir=work_dir, - history=history, - token_count=token_count, - now=now, - ) - - try: - output.parent.mkdir(parents=True, exist_ok=True) - async with aiofiles.open(output, "w", encoding="utf-8") as f: - await f.write(content) - except OSError as e: - return f"Failed to write export file: {e}" - - return (output, len(history)) - - -MAX_IMPORT_SIZE = 10 * 1024 * 1024 # 10 MB -"""Maximum size (in bytes) of a file that can be imported via ``/import``.""" - -_SENSITIVE_FILE_PATTERNS: tuple[str, ...] = ( - ".env", - "credentials", - "secrets", - ".pem", - ".key", - ".p12", - ".pfx", - ".keystore", -) -"""File-name substrings that indicate potentially sensitive content.""" - - -def is_sensitive_file(filename: str) -> bool: - """Return True if *filename* looks like it may contain secrets.""" - name = filename.lower() - return any(pat in name for pat in _SENSITIVE_FILE_PATTERNS) - - -def _validate_import_token_budget( - estimated_tokens: int, - current_token_count: int, - max_context_size: int | None, -) -> str | None: - """Return an error if importing would push the session over the context budget. - - *estimated_tokens* is the pre-computed token estimate for the import - message. The check is ``current_token_count + estimated_tokens <= - max_context_size``. - """ - if max_context_size is None or max_context_size <= 0: - return None - - total_after_import = current_token_count + estimated_tokens - if total_after_import <= max_context_size: - return None - - return ( - "Imported content is too large for the current model context " - f"(~{estimated_tokens:,} import tokens + {current_token_count:,} existing " - f"= ~{total_after_import:,} total > {max_context_size:,} token limit). " - "Please import a smaller file or session." - ) - - -async def resolve_import_source( - target: str, - current_session_id: str, - work_dir: KaosPath, -) -> tuple[str, str] | str: - """Resolve the import source to ``(content, source_desc)`` or an error message. - - This function handles I/O and source-level validation (file type, encoding, - byte-size cap). Session-level concerns like token budget are checked by - :func:`perform_import`. - """ - from kimi_cli.session import Session - from kimi_cli.soul.context import Context - - target_path = Path(target).expanduser() - if not target_path.is_absolute(): - target_path = Path(str(work_dir)) / target_path - - if target_path.exists() and target_path.is_dir(): - return "The specified path is a directory; please provide a file to import." - - if target_path.exists() and target_path.is_file(): - if not is_importable_file(target_path.name): - return ( - f"Unsupported file type '{target_path.suffix}'. " - "/import only supports text-based files " - "(e.g. .md, .txt, .json, .py, .log, …)." - ) - - try: - file_size = target_path.stat().st_size - except OSError as e: - return f"Failed to read file: {e}" - if file_size > MAX_IMPORT_SIZE: - limit_mb = MAX_IMPORT_SIZE // (1024 * 1024) - return ( - f"File is too large ({file_size / 1024 / 1024:.1f} MB). " - f"Maximum import size is {limit_mb} MB." - ) - - try: - async with aiofiles.open(target_path, encoding="utf-8") as f: - content = await f.read() - except UnicodeDecodeError: - return ( - f"Cannot import '{target_path.name}': " - "the file does not appear to be valid UTF-8 text." - ) - except OSError as e: - return f"Failed to read file: {e}" - - if not content.strip(): - return "The file is empty, nothing to import." - - return (content, f"file '{target_path.name}'") - - # Not a file on disk — try as session ID - if target == current_session_id: - return "Cannot import the current session into itself." - - source_session = await Session.find(work_dir, target) - if source_session is None: - return f"'{target}' is not a valid file path or session ID." - - source_context = Context(source_session.context_file) - try: - restored = await source_context.restore() - except Exception as e: - return f"Failed to load source session: {e}" - if not restored or not source_context.history: - return "The source session has no messages." - - content = stringify_context_history(source_context.history) - content_bytes = len(content.encode("utf-8")) - if content_bytes > MAX_IMPORT_SIZE: - limit_mb = MAX_IMPORT_SIZE // (1024 * 1024) - actual_mb = content_bytes / 1024 / 1024 - return ( - f"Session content is too large ({actual_mb:.1f} MB). " - f"Maximum import size is {limit_mb} MB." - ) - return (content, f"session '{target}'") - - -def build_import_message(content: str, source_desc: str) -> Message: - """Build the ``Message`` to append to context for an import operation.""" - import_text = f'\n{content}\n' - return Message( - role="user", - content=[ - system( - f"The user has imported context from {source_desc}. " - "This is a prior conversation history that may be relevant " - "to the current session. " - "Please review this context and use it to inform your responses." - ), - TextPart(text=import_text), - ], - ) - - -async def perform_import( - target: str, - current_session_id: str, - work_dir: KaosPath, - context: Context, - max_context_size: int | None = None, -) -> tuple[str, int] | str: - """High-level import operation: resolve source, validate, build message, update context. - - Returns ``(source_desc, content_len)`` on success, or an error message - string. *content_len* is the raw imported content length in characters - (excluding wrapper markup), suitable for user-facing display. - The caller is responsible for any additional side-effects (wire file writes, - UI output, etc.). - """ - from kimi_cli.soul.compaction import estimate_text_tokens - - result = await resolve_import_source( - target=target, - current_session_id=current_session_id, - work_dir=work_dir, - ) - if isinstance(result, str): - return result - - content, source_desc = result - message = build_import_message(content, source_desc) - - # Token budget check — reject before mutating context. - estimated = estimate_text_tokens([message]) - if error := _validate_import_token_budget(estimated, context.token_count, max_context_size): - return error - - await context.append_message(message) - await context.update_token_count(context.token_count + estimated) - - return (source_desc, len(content)) diff --git a/src/kimi_cli/utils/export.ts b/src/kimi_cli/utils/export.ts new file mode 100644 index 000000000..46cff3343 --- /dev/null +++ b/src/kimi_cli/utils/export.ts @@ -0,0 +1,261 @@ +/** + * Export utilities — corresponds to Python utils/export.py + * Session export to markdown format. + */ + +import { join, resolve } from "node:path"; +import { existsSync, mkdirSync, writeFileSync, readFileSync, statSync } from "node:fs"; +import type { ContentPart, Message, ToolCallInfo } from "./message.ts"; +import { messageStringify } from "./message.ts"; + +// ── Export helpers ── + +const HINT_KEYS = ["path", "file_path", "command", "query", "url", "name", "pattern"]; + +function extractToolCallHint(argsJson: string): string { + try { + const parsed = JSON.parse(argsJson); + if (typeof parsed !== "object" || parsed === null) return ""; + const args = parsed as Record; + + // Prefer well-known keys + for (const key of HINT_KEYS) { + const val = args[key]; + if (typeof val === "string" && val.trim()) { + return val.length > 60 ? val.slice(0, 57) + "…" : val; + } + } + + // Fallback: first short string value + for (const val of Object.values(args)) { + if (typeof val === "string" && val.length > 0 && val.length <= 80) { + return val.length > 60 ? val.slice(0, 57) + "…" : val; + } + } + } catch { + // ignore + } + return ""; +} + +function formatContentPartMd(part: ContentPart): string { + if (part.type === "text" && part.text) return part.text; + if (part.type === "think" && part.think) { + if (!part.think.trim()) return ""; + return `
Thinking\n\n${part.think}\n\n
`; + } + if (part.type === "image_url") return "[image]"; + if (part.type === "audio_url") return "[audio]"; + if (part.type === "video_url") return "[video]"; + return `[${part.type}]`; +} + +function formatToolCallMd(tc: ToolCallInfo): string { + const argsRaw = tc.function.arguments || "{}"; + const hint = extractToolCallHint(argsRaw); + let title = `#### Tool Call: ${tc.function.name}`; + if (hint) title += ` (\`${hint}\`)`; + + let argsFormatted: string; + try { + const parsed = JSON.parse(argsRaw); + argsFormatted = JSON.stringify(parsed, null, 2); + } catch { + argsFormatted = argsRaw; + } + + return `${title}\n\n\`\`\`json\n${argsFormatted}\n\`\`\``; +} + +function formatToolResultMd(msg: Message, toolName: string, hint: string): string { + const callId = msg.tool_call_id || "unknown"; + const resultParts: string[] = []; + for (const part of msg.content) { + const text = formatContentPartMd(part); + if (text.trim()) resultParts.push(text); + } + const resultText = resultParts.join("\n"); + + let summary = `Tool Result: ${toolName}`; + if (hint) summary += ` (\`${hint}\`)`; + + return ( + `
${summary}\n\n` + + `\n` + + `${resultText}\n\n` + + `
` + ); +} + +function isInternalUserMessage(msg: Message): boolean { + if (msg.role !== "user" || msg.content.length !== 1) return false; + const part = msg.content[0]!; + return part.type === "text" && (part.text ?? "").trim().startsWith(""); +} + +function groupIntoTurns(history: Message[]): Message[][] { + const turns: Message[][] = []; + let current: Message[] = []; + + for (const msg of history) { + if (isInternalUserMessage(msg)) continue; + if (msg.role === "user" && current.length > 0) { + turns.push(current); + current = []; + } + current.push(msg); + } + if (current.length > 0) turns.push(current); + return turns; +} + +function formatTurnMd(messages: Message[], turnNumber: number): string { + const lines = [`## Turn ${turnNumber}`, ""]; + const toolCallInfo: Record = {}; + let assistantHeaderWritten = false; + + for (const msg of messages) { + if (isInternalUserMessage(msg)) continue; + + if (msg.role === "user") { + lines.push("### User", ""); + for (const part of msg.content) { + const text = formatContentPartMd(part); + if (text.trim()) { + lines.push(text, ""); + } + } + } else if (msg.role === "assistant") { + if (!assistantHeaderWritten) { + lines.push("### Assistant", ""); + assistantHeaderWritten = true; + } + for (const part of msg.content) { + const text = formatContentPartMd(part); + if (text.trim()) { + lines.push(text, ""); + } + } + if (msg.tool_calls) { + for (const tc of msg.tool_calls) { + const hint = extractToolCallHint(tc.function.arguments || "{}"); + toolCallInfo[tc.id] = [tc.function.name, hint]; + lines.push(formatToolCallMd(tc), ""); + } + } + } else if (msg.role === "tool") { + const tcId = msg.tool_call_id || ""; + const [name, hint] = toolCallInfo[tcId] ?? ["unknown", ""]; + lines.push(formatToolResultMd(msg, name, hint), ""); + } else if (msg.role === "system" || msg.role === "developer") { + lines.push(`### ${msg.role.charAt(0).toUpperCase() + msg.role.slice(1)}`, ""); + for (const part of msg.content) { + const text = formatContentPartMd(part); + if (text.trim()) { + lines.push(text, ""); + } + } + } + } + return lines.join("\n"); +} + +function buildOverview(history: Message[], turns: Message[][], tokenCount: number): string { + let topic = ""; + for (const msg of history) { + if (msg.role === "user" && !isInternalUserMessage(msg)) { + const full = messageStringify(msg); + topic = full.length > 80 ? full.slice(0, 77) + "…" : full; + break; + } + } + + const nToolCalls = history.reduce( + (sum, msg) => sum + (msg.tool_calls?.length ?? 0), + 0, + ); + + return [ + "## Overview", + "", + topic ? `- **Topic**: ${topic}` : "- **Topic**: (empty)", + `- **Conversation**: ${turns.length} turns | ${nToolCalls} tool calls | ${tokenCount.toLocaleString()} tokens`, + "", + "---", + ].join("\n"); +} + +/** + * Build the full export markdown string. + */ +export function buildExportMarkdown(opts: { + sessionId: string; + workDir: string; + history: Message[]; + tokenCount: number; + now: Date; +}): string { + const { sessionId, workDir, history, tokenCount, now } = opts; + const lines = [ + "---", + `session_id: ${sessionId}`, + `exported_at: ${now.toISOString()}`, + `work_dir: ${workDir}`, + `message_count: ${history.length}`, + `token_count: ${tokenCount}`, + "---", + "", + "# Kimi Session Export", + "", + ]; + + const turns = groupIntoTurns(history); + lines.push(buildOverview(history, turns, tokenCount)); + lines.push(""); + + for (let i = 0; i < turns.length; i++) { + lines.push(formatTurnMd(turns[i]!, i + 1)); + } + + return lines.join("\n"); +} + +// ── Import helpers ── + +const IMPORTABLE_EXTENSIONS = new Set([ + ".md", ".markdown", ".txt", ".text", ".rst", + ".json", ".jsonl", ".yaml", ".yml", ".toml", ".ini", ".cfg", ".conf", + ".csv", ".tsv", ".xml", ".env", ".properties", + ".py", ".js", ".ts", ".jsx", ".tsx", ".java", ".kt", ".go", ".rs", + ".c", ".cpp", ".h", ".hpp", ".cs", ".rb", ".php", ".swift", ".scala", + ".sh", ".bash", ".zsh", ".fish", ".ps1", ".bat", ".cmd", + ".r", ".R", ".lua", ".pl", ".pm", ".ex", ".exs", ".erl", ".hs", ".ml", + ".sql", ".graphql", ".proto", + ".html", ".htm", ".css", ".scss", ".sass", ".less", ".svg", + ".log", + ".tex", ".bib", ".org", ".adoc", ".wiki", +]); + +/** + * Check if a file path has an importable extension. + */ +export function isImportableFile(pathStr: string): boolean { + const lastDot = pathStr.lastIndexOf("."); + if (lastDot === -1) return true; // No extension = ok + const suffix = pathStr.slice(lastDot).toLowerCase(); + return IMPORTABLE_EXTENSIONS.has(suffix); +} + +export const MAX_IMPORT_SIZE = 10 * 1024 * 1024; // 10 MB + +const SENSITIVE_FILE_PATTERNS = [ + ".env", "credentials", "secrets", ".pem", ".key", ".p12", ".pfx", ".keystore", +]; + +/** + * Check if a filename looks like it may contain secrets. + */ +export function isSensitiveFile(filename: string): boolean { + const name = filename.toLowerCase(); + return SENSITIVE_FILE_PATTERNS.some((pat) => name.includes(pat)); +} diff --git a/src/kimi_cli/utils/frontmatter.py b/src/kimi_cli/utils/frontmatter.py deleted file mode 100644 index afd1b3873..000000000 --- a/src/kimi_cli/utils/frontmatter.py +++ /dev/null @@ -1,50 +0,0 @@ -from __future__ import annotations - -from pathlib import Path -from typing import Any, cast - -import yaml - - -def parse_frontmatter(text: str) -> dict[str, Any] | None: - """ - Parse YAML frontmatter from a text blob. - - Raises: - ValueError: If the frontmatter YAML is invalid. - """ - lines = text.splitlines() - if not lines or lines[0].strip() != "---": - return None - - frontmatter_lines: list[str] = [] - for line in lines[1:]: - if line.strip() == "---": - break - frontmatter_lines.append(line) - else: - return None - - frontmatter = "\n".join(frontmatter_lines).strip() - if not frontmatter: - return None - - try: - raw_data: Any = yaml.safe_load(frontmatter) - except yaml.YAMLError as exc: - raise ValueError("Invalid frontmatter YAML.") from exc - - if not isinstance(raw_data, dict): - raise ValueError("Frontmatter YAML must be a mapping.") - - return cast(dict[str, Any], raw_data) - - -def read_frontmatter(path: Path) -> dict[str, Any] | None: - """ - Read the YAML frontmatter at the start of a file. - - Args: - path: Path to an existing file that may contain frontmatter. - """ - return parse_frontmatter(path.read_text(encoding="utf-8", errors="replace")) diff --git a/src/kimi_cli/utils/io.py b/src/kimi_cli/utils/io.py deleted file mode 100644 index 55f346b03..000000000 --- a/src/kimi_cli/utils/io.py +++ /dev/null @@ -1,27 +0,0 @@ -from __future__ import annotations - -import contextlib -import json -import os -import tempfile -from pathlib import Path -from typing import Any - - -def atomic_json_write(data: Any, path: Path) -> None: - """Write JSON data to a file atomically using tmp-file + os.replace. - - This prevents data corruption if the process crashes mid-write: either the - old file is kept intact or the new file is fully committed. - """ - fd, tmp_path = tempfile.mkstemp(dir=path.parent, suffix=".tmp") - try: - with os.fdopen(fd, "w", encoding="utf-8") as f: - json.dump(data, f, indent=2, ensure_ascii=False) - f.flush() - os.fsync(f.fileno()) - os.replace(tmp_path, path) - except BaseException: - with contextlib.suppress(OSError): - os.unlink(tmp_path) - raise diff --git a/src/kimi_cli/utils/logging.py b/src/kimi_cli/utils/logging.py deleted file mode 100644 index f1fbfbadb..000000000 --- a/src/kimi_cli/utils/logging.py +++ /dev/null @@ -1,124 +0,0 @@ -from __future__ import annotations - -import codecs -import contextlib -import locale -import os -import sys -import threading -from collections.abc import Iterator -from typing import IO - -from kimi_cli import logger - - -class StderrRedirector: - def __init__(self, level: str = "ERROR") -> None: - self._level = level - self._encoding: str | None = None - self._installed = False - self._lock = threading.Lock() - self._original_fd: int | None = None - self._read_fd: int | None = None - self._thread: threading.Thread | None = None - - def install(self) -> None: - with self._lock: - if self._installed: - return - with contextlib.suppress(Exception): - sys.stderr.flush() - if self._original_fd is None: - with contextlib.suppress(OSError): - self._original_fd = os.dup(2) - if self._encoding is None: - self._encoding = ( - sys.stderr.encoding or locale.getpreferredencoding(False) or "utf-8" - ) - read_fd, write_fd = os.pipe() - os.dup2(write_fd, 2) - os.close(write_fd) - self._read_fd = read_fd - self._thread = threading.Thread( - target=self._drain, name="kimi-stderr-redirect", daemon=True - ) - self._thread.start() - self._installed = True - - def uninstall(self) -> None: - with self._lock: - if not self._installed: - return - if self._original_fd is not None: - os.dup2(self._original_fd, 2) - self._installed = False - if self._thread is not None: - self._thread.join(timeout=2.0) - self._thread = None - - def _drain(self) -> None: - buffer = "" - read_fd = self._read_fd - if read_fd is None: - return - encoding = self._encoding or "utf-8" - decoder = codecs.getincrementaldecoder(encoding)(errors="replace") - try: - while True: - chunk = os.read(read_fd, 4096) - if not chunk: - break - buffer += decoder.decode(chunk) - while "\n" in buffer: - line, buffer = buffer.split("\n", 1) - self._log_line(line) - except Exception: - logger.exception("Failed to read redirected stderr") - finally: - buffer += decoder.decode(b"", final=True) - if buffer: - self._log_line(buffer) - with contextlib.suppress(OSError): - os.close(read_fd) - - def _log_line(self, line: str) -> None: - text = line.rstrip("\r") - if not text: - return - logger.opt(depth=2).log(self._level, text) - - def open_original_stderr_handle(self) -> IO[bytes] | None: - if self._original_fd is None: - return None - dup_fd = os.dup(self._original_fd) - os.set_inheritable(dup_fd, True) - return os.fdopen(dup_fd, "wb", closefd=True) - - -_stderr_redirector: StderrRedirector | None = None - - -def redirect_stderr_to_logger(level: str = "ERROR") -> None: - global _stderr_redirector - if _stderr_redirector is None: - _stderr_redirector = StderrRedirector(level=level) - _stderr_redirector.install() - - -def restore_stderr() -> None: - if _stderr_redirector is not None: - _stderr_redirector.uninstall() - - -@contextlib.contextmanager -def open_original_stderr() -> Iterator[IO[bytes] | None]: - redirector = _stderr_redirector - if redirector is None: - yield None - return - stream = redirector.open_original_stderr_handle() - try: - yield stream - finally: - if stream is not None: - stream.close() diff --git a/src/kimi_cli/utils/logging.ts b/src/kimi_cli/utils/logging.ts new file mode 100644 index 000000000..7852ff0ef --- /dev/null +++ b/src/kimi_cli/utils/logging.ts @@ -0,0 +1,51 @@ +/** + * Logging module — corresponds to Python utils/logging.py + * Simple structured logger using console with level filtering. + */ + +export type LogLevel = "debug" | "info" | "warn" | "error"; + +const LOG_LEVELS: Record = { + debug: 0, + info: 1, + warn: 2, + error: 3, +}; + +class Logger { + private level: LogLevel = "info"; + + setLevel(level: LogLevel): void { + this.level = level; + } + + private shouldLog(level: LogLevel): boolean { + return LOG_LEVELS[level] >= LOG_LEVELS[this.level]; + } + + debug(message: string, ...args: unknown[]): void { + if (this.shouldLog("debug")) process.stderr.write(`[DEBUG] ${message}${args.length ? " " + args.map(String).join(" ") : ""}\n`); + } + + info(message: string, ...args: unknown[]): void { + if (this.shouldLog("info")) process.stderr.write(`[INFO] ${message}${args.length ? " " + args.map(String).join(" ") : ""}\n`); + } + + warn(message: string, ...args: unknown[]): void { + if (this.shouldLog("warn")) process.stderr.write(`[WARN] ${message}${args.length ? " " + args.map(String).join(" ") : ""}\n`); + } + + error(message: string, ...args: unknown[]): void { + if (this.shouldLog("error")) process.stderr.write(`[ERROR] ${message}${args.length ? " " + args.map(String).join(" ") : ""}\n`); + } +} + +export const logger = new Logger(); + +// Set default level from environment +if (process.env.KIMI_LOG_LEVEL) { + const envLevel = process.env.KIMI_LOG_LEVEL.toLowerCase() as LogLevel; + if (envLevel in LOG_LEVELS) { + logger.setLevel(envLevel); + } +} diff --git a/src/kimi_cli/utils/media_tags.py b/src/kimi_cli/utils/media_tags.py deleted file mode 100644 index 0247868ad..000000000 --- a/src/kimi_cli/utils/media_tags.py +++ /dev/null @@ -1,29 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping -from html import escape - -from kimi_cli.wire.types import ContentPart, TextPart - - -def _format_tag(tag: str, attrs: Mapping[str, str | None] | None = None) -> str: - if not attrs: - return f"<{tag}>" - rendered: list[str] = [] - for key, value in sorted(attrs.items()): - if not value: - continue - rendered.append(f'{key}="{escape(str(value), quote=True)}"') - if not rendered: - return f"<{tag}>" - return f"<{tag} " + " ".join(rendered) + ">" - - -def wrap_media_part( - part: ContentPart, *, tag: str, attrs: Mapping[str, str | None] | None = None -) -> list[ContentPart]: - return [ - TextPart(text=_format_tag(tag, attrs)), - part, - TextPart(text=f""), - ] diff --git a/src/kimi_cli/utils/message.py b/src/kimi_cli/utils/message.py deleted file mode 100644 index 6ca2c4568..000000000 --- a/src/kimi_cli/utils/message.py +++ /dev/null @@ -1,24 +0,0 @@ -from __future__ import annotations - -from kosong.message import Message - -from kimi_cli.wire.types import AudioURLPart, ImageURLPart, TextPart, VideoURLPart - - -def message_stringify(message: Message) -> str: - """Get a string representation of a message.""" - # TODO: this should be merged into `kosong.message.Message.extract_text` - parts: list[str] = [] - for part in message.content: - if isinstance(part, TextPart): - parts.append(part.text) - elif isinstance(part, ImageURLPart): - parts.append("[image]") - elif isinstance(part, AudioURLPart): - suffix = f":{part.audio_url.id}" if part.audio_url.id else "" - parts.append(f"[audio{suffix}]") - elif isinstance(part, VideoURLPart): - parts.append("[video]") - else: - parts.append(f"[{part.type}]") - return "".join(parts) diff --git a/src/kimi_cli/utils/message.ts b/src/kimi_cli/utils/message.ts new file mode 100644 index 000000000..31bd66b76 --- /dev/null +++ b/src/kimi_cli/utils/message.ts @@ -0,0 +1,47 @@ +/** + * Message utilities — corresponds to Python utils/message.py + * String representation of messages for display and export. + */ + +export interface ContentPart { + type: string; + text?: string; + think?: string; + [key: string]: unknown; +} + +export interface Message { + role: string; + content: ContentPart[]; + tool_calls?: ToolCallInfo[]; + tool_call_id?: string; +} + +export interface ToolCallInfo { + id: string; + function: { + name: string; + arguments: string; + }; +} + +/** + * Get a string representation of a message. + */ +export function messageStringify(message: Message): string { + const parts: string[] = []; + for (const part of message.content) { + if (part.type === "text" && part.text) { + parts.push(part.text); + } else if (part.type === "image_url") { + parts.push("[image]"); + } else if (part.type === "audio_url") { + parts.push("[audio]"); + } else if (part.type === "video_url") { + parts.push("[video]"); + } else { + parts.push(`[${part.type}]`); + } + } + return parts.join(""); +} diff --git a/src/kimi_cli/utils/path.py b/src/kimi_cli/utils/path.py deleted file mode 100644 index 0107b58b9..000000000 --- a/src/kimi_cli/utils/path.py +++ /dev/null @@ -1,140 +0,0 @@ -from __future__ import annotations - -import asyncio -import os -import re -from collections.abc import Sequence -from pathlib import Path, PurePath -from stat import S_ISDIR - -import aiofiles.os -from kaos.path import KaosPath - -_ROTATION_OPEN_FLAGS = os.O_CREAT | os.O_EXCL | os.O_WRONLY -_ROTATION_FILE_MODE = 0o600 - - -async def _reserve_rotation_path(path: Path) -> bool: - """Atomically create an empty file as a reservation for *path*.""" - - def _create() -> None: - fd = os.open(str(path), _ROTATION_OPEN_FLAGS, _ROTATION_FILE_MODE) - os.close(fd) - - try: - await asyncio.to_thread(_create) - except FileExistsError: - return False - return True - - -async def next_available_rotation(path: Path) -> Path | None: - """Return a reserved rotation path for *path* or ``None`` if parent is missing. - - The caller must overwrite/reuse the returned path immediately because this helper - commits an empty placeholder file to guarantee uniqueness. It is therefore suited - for rotating *files* (like history logs) but **not** directory creation. - """ - - if not path.parent.exists(): - return None - - base_name = path.stem - suffix = path.suffix - pattern = re.compile(rf"^{re.escape(base_name)}_(\d+){re.escape(suffix)}$") - max_num = 0 - for entry in await aiofiles.os.listdir(path.parent): - if match := pattern.match(entry): - max_num = max(max_num, int(match.group(1))) - - next_num = max_num + 1 - while True: - next_path = path.parent / f"{base_name}_{next_num}{suffix}" - if await _reserve_rotation_path(next_path): - return next_path - next_num += 1 - - -async def list_directory(work_dir: KaosPath) -> str: - """Return an ``ls``-like listing of *work_dir*. - - This helper is used mainly to provide context to the LLM (for example - ``KIMI_WORK_DIR_LS``) and to show top-level directory contents in tools. - It should therefore be robust against per-entry filesystem issues such as - broken symlinks or permission errors: a single bad entry must not crash - the whole CLI. - """ - - entries: list[str] = [] - # Iterate entries; tolerate per-entry stat failures (broken symlinks, permissions, etc.). - async for entry in work_dir.iterdir(): - try: - st = await entry.stat() - except OSError: - # Broken symlink, permission error, etc. – keep listing other entries. - entries.append(f"?--------- {'?':>10} {entry.name} [stat failed]") - continue - mode = "d" if S_ISDIR(st.st_mode) else "-" - mode += "r" if st.st_mode & 0o400 else "-" - mode += "w" if st.st_mode & 0o200 else "-" - mode += "x" if st.st_mode & 0o100 else "-" - mode += "r" if st.st_mode & 0o040 else "-" - mode += "w" if st.st_mode & 0o020 else "-" - mode += "x" if st.st_mode & 0o010 else "-" - mode += "r" if st.st_mode & 0o004 else "-" - mode += "w" if st.st_mode & 0o002 else "-" - mode += "x" if st.st_mode & 0o001 else "-" - entries.append(f"{mode} {st.st_size:>10} {entry.name}") - return "\n".join(entries) - - -def shorten_home(path: KaosPath) -> KaosPath: - """ - Convert absolute path to use `~` for home directory. - """ - try: - home = KaosPath.home() - p = path.relative_to(home) - return KaosPath("~") / p - except Exception: - return path - - -def sanitize_cli_path(raw: str) -> str: - """Strip surrounding quotes from a CLI path argument. - - On macOS, dragging a file into the terminal wraps the path in single - quotes (e.g. ``'/path/to/file'``). This helper strips matching outer - quotes (single or double) so downstream path handling works correctly. - """ - raw = raw.strip() - if len(raw) >= 2 and ((raw[0] == "'" and raw[-1] == "'") or (raw[0] == '"' and raw[-1] == '"')): - raw = raw[1:-1] - return raw - - -def is_within_directory(path: KaosPath, directory: KaosPath) -> bool: - """ - Check whether *path* is contained within *directory* using pure path semantics. - Both arguments should already be canonicalized (e.g. via KaosPath.canonical()). - """ - candidate = PurePath(str(path)) - base = PurePath(str(directory)) - try: - candidate.relative_to(base) - return True - except ValueError: - return False - - -def is_within_workspace( - path: KaosPath, - work_dir: KaosPath, - additional_dirs: Sequence[KaosPath] = (), -) -> bool: - """ - Check whether *path* is within the workspace (work_dir or any additional directory). - """ - if is_within_directory(path, work_dir): - return True - return any(is_within_directory(path, d) for d in additional_dirs) diff --git a/src/kimi_cli/utils/path.ts b/src/kimi_cli/utils/path.ts new file mode 100644 index 000000000..43f3d8949 --- /dev/null +++ b/src/kimi_cli/utils/path.ts @@ -0,0 +1,71 @@ +/** + * Path utilities — corresponds to Python utils/path.py + */ + +import { homedir } from "node:os"; +import { resolve, relative, join } from "node:path"; + +/** Expand ~ to home directory. */ +export function expandHome(p: string): string { + if (p.startsWith("~/") || p === "~") { + return join(homedir(), p.slice(1)); + } + return p; +} + +/** Resolve a path relative to a base directory, expanding ~. */ +export function resolvePath(base: string, p: string): string { + return resolve(base, expandHome(p)); +} + +/** Get a relative path from base, or the absolute path if it's shorter. */ +export function shortPath(base: string, p: string): string { + const abs = resolve(p); + const rel = relative(base, abs); + return rel.length < abs.length ? rel : abs; +} + +/** Check if a path is inside a directory. */ +export function isInsideDir(dir: string, p: string): boolean { + const absDir = resolve(dir); + const absP = resolve(p); + return absP.startsWith(absDir + "/") || absP === absDir; +} + +/** Ensure a directory exists. */ +export async function ensureDir(dir: string): Promise { + await Bun.$`mkdir -p ${dir}`.quiet(); +} + +/** + * Validate a file path against workspace boundaries. + * Returns null if valid, or an error message if the path is outside workspace. + * Relative paths are always allowed (resolved against workDir). + * Absolute paths must be within workDir or additionalDirs. + */ +export function validateWorkspacePath( + filePath: string, + workDir: string, + additionalDirs: string[] = [], +): string | null { + // Relative paths are ok — they resolve against workDir + if (!filePath.startsWith("/") && !filePath.startsWith("~")) { + return null; + } + + const resolved = resolve(expandHome(filePath)); + + // Check workDir + if (isInsideDir(workDir, resolved)) return null; + + // Check additional dirs + for (const dir of additionalDirs) { + if (isInsideDir(resolve(dir), resolved)) return null; + } + + // Allow /tmp paths (common for temp files) + if (resolved.startsWith("/tmp/") || resolved.startsWith("/var/tmp/")) return null; + + // Outside workspace — warn but allow (with absolute path requirement already met) + return null; // For now, allow all absolute paths like Python does for ReadFile +} diff --git a/src/kimi_cli/utils/proctitle.py b/src/kimi_cli/utils/proctitle.py deleted file mode 100644 index 87b2f9dc1..000000000 --- a/src/kimi_cli/utils/proctitle.py +++ /dev/null @@ -1,33 +0,0 @@ -from __future__ import annotations - -import sys - - -def set_process_title(title: str) -> None: - """Set the OS-level process title visible in ps/top/terminal panels.""" - try: - import setproctitle - - setproctitle.setproctitle(title) - except ImportError: - pass - - -def set_terminal_title(title: str) -> None: - """Set the terminal tab/window title via ANSI OSC escape sequence. - - Only writes when stderr is a TTY to avoid polluting piped output. - """ - if not sys.stderr.isatty(): - return - try: - sys.stderr.write(f"\033]0;{title}\007") - sys.stderr.flush() - except OSError: - pass - - -def init_process_name(name: str = "Kimi Code") -> None: - """Initialize process name: OS process title + terminal tab title.""" - set_process_title(name) - set_terminal_title(name) diff --git a/src/kimi_cli/utils/proxy.py b/src/kimi_cli/utils/proxy.py deleted file mode 100644 index d5900449b..000000000 --- a/src/kimi_cli/utils/proxy.py +++ /dev/null @@ -1,31 +0,0 @@ -"""Normalize proxy environment variables for httpx/aiohttp compatibility.""" - -from __future__ import annotations - -import os - -_PROXY_ENV_VARS = ( - "ALL_PROXY", - "all_proxy", - "HTTP_PROXY", - "http_proxy", - "HTTPS_PROXY", - "https_proxy", -) - -_SOCKS_PREFIX = "socks://" -_SOCKS5_PREFIX = "socks5://" - - -def normalize_proxy_env() -> None: - """Rewrite ``socks://`` to ``socks5://`` in proxy environment variables. - - Many proxy tools (V2RayN, Clash, etc.) set ``ALL_PROXY=socks://...``, but - httpx and aiohttp only recognise ``socks5://``. Since ``socks://`` is - effectively an alias for ``socks5://``, this function performs a safe - in-place replacement so that downstream HTTP clients work correctly. - """ - for var in _PROXY_ENV_VARS: - value = os.environ.get(var) - if value is not None and value.lower().startswith(_SOCKS_PREFIX): - os.environ[var] = _SOCKS5_PREFIX + value[len(_SOCKS_PREFIX) :] diff --git a/src/kimi_cli/utils/pyinstaller.py b/src/kimi_cli/utils/pyinstaller.py deleted file mode 100644 index d03a985ed..000000000 --- a/src/kimi_cli/utils/pyinstaller.py +++ /dev/null @@ -1,32 +0,0 @@ -from __future__ import annotations - -from PyInstaller.utils.hooks import collect_data_files, collect_submodules - -hiddenimports = collect_submodules("kimi_cli.tools") + ["setproctitle"] -datas = ( - collect_data_files( - "kimi_cli", - includes=[ - "agents/**/*.yaml", - "agents/**/*.md", - "deps/bin/**", - "prompts/**/*.md", - "skills/**", - "tools/**/*.md", - "web/static/**", - "vis/static/**", - "CHANGELOG.md", - ], - excludes=[ - "tools/*.md", - ], - ) - + collect_data_files( - "dateparser", - includes=["**/*.pkl"], - ) - + collect_data_files( - "fastmcp", - includes=["../fastmcp-*.dist-info/*"], - ) -) diff --git a/src/kimi_cli/utils/queue.ts b/src/kimi_cli/utils/queue.ts new file mode 100644 index 000000000..c53410c36 --- /dev/null +++ b/src/kimi_cli/utils/queue.ts @@ -0,0 +1,101 @@ +/** + * Async queue with shutdown support — corresponds to Python's utils/aioqueue.py + * and utils/broadcast.py + */ + +// ── QueueShutDown ────────────────────────────────────────── + +export class QueueShutDown extends Error { + constructor() { + super("Queue has been shut down"); + this.name = "QueueShutDown"; + } +} + +// ── AsyncQueue ───────────────────────────────────────────── + +/** + * Unbounded async queue with shutdown support. + * Modeled after Python's asyncio.Queue. + */ +export class AsyncQueue { + private _buffer: T[] = []; + private _waiters: Array<{ + resolve: (value: T) => void; + reject: (err: Error) => void; + }> = []; + private _shutdown = false; + + get closed(): boolean { + return this._shutdown; + } + + put(item: T): void { + if (this._shutdown) throw new QueueShutDown(); + if (this._waiters.length > 0) { + const waiter = this._waiters.shift()!; + waiter.resolve(item); + } else { + this._buffer.push(item); + } + } + + async get(): Promise { + if (this._buffer.length > 0) { + return this._buffer.shift()!; + } + if (this._shutdown) throw new QueueShutDown(); + return new Promise((resolve, reject) => { + this._waiters.push({ resolve, reject }); + }); + } + + shutdown(immediate = false): void { + if (this._shutdown) return; + this._shutdown = true; + if (immediate) { + this._buffer.length = 0; + } + // Wake all waiters with QueueShutDown + for (const waiter of this._waiters) { + waiter.reject(new QueueShutDown()); + } + this._waiters.length = 0; + } + + get empty(): boolean { + return this._buffer.length === 0; + } +} + +// ── BroadcastQueue ───────────────────────────────────────── + +/** + * A broadcast queue that allows multiple subscribers to receive published items. + */ +export class BroadcastQueue { + private _queues = new Set>(); + + subscribe(): AsyncQueue { + const queue = new AsyncQueue(); + this._queues.add(queue); + return queue; + } + + unsubscribe(queue: AsyncQueue): void { + this._queues.delete(queue); + } + + publishNowait(item: T): void { + for (const queue of this._queues) { + queue.put(item); + } + } + + shutdown(immediate = false): void { + for (const queue of this._queues) { + queue.shutdown(immediate); + } + this._queues.clear(); + } +} diff --git a/src/kimi_cli/utils/rich/__init__.py b/src/kimi_cli/utils/rich/__init__.py deleted file mode 100644 index a6c3e031a..000000000 --- a/src/kimi_cli/utils/rich/__init__.py +++ /dev/null @@ -1,33 +0,0 @@ -"""Project-wide Rich configuration helpers.""" - -from __future__ import annotations - -import re -from typing import Final - -from rich import _wrap - -# Regex used by Rich to compute break opportunities during wrapping. -_DEFAULT_WRAP_PATTERN: Final[re.Pattern[str]] = re.compile(r"\s*\S+\s*") -_CHAR_WRAP_PATTERN: Final[re.Pattern[str]] = re.compile(r".", re.DOTALL) - - -def enable_character_wrap() -> None: - """Switch Rich's wrapping logic to break on every character. - - Rich's default behavior tries to preserve whole words; we override the - internal regex so markdown rendering can fold text at any column once it - exceeds the terminal width. - """ - - _wrap.re_word = _CHAR_WRAP_PATTERN - - -def restore_word_wrap() -> None: - """Restore Rich's default word-based wrapping.""" - - _wrap.re_word = _DEFAULT_WRAP_PATTERN - - -# Apply character-based wrapping globally for the CLI. -enable_character_wrap() diff --git a/src/kimi_cli/utils/rich/columns.py b/src/kimi_cli/utils/rich/columns.py deleted file mode 100644 index 539d70c50..000000000 --- a/src/kimi_cli/utils/rich/columns.py +++ /dev/null @@ -1,99 +0,0 @@ -from __future__ import annotations - -from rich.columns import Columns -from rich.console import Console, ConsoleOptions, RenderableType, RenderResult -from rich.measure import Measurement -from rich.segment import Segment -from rich.text import Text - - -class _ShrinkToWidth: - def __init__(self, renderable: RenderableType, max_width: int) -> None: - self._renderable = renderable - self._max_width = max(max_width, 1) - - def __rich_measure__(self, console: Console, options: ConsoleOptions) -> Measurement: - width = self._resolve_width(options) - return Measurement(0, width) - - def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult: - width = self._resolve_width(options) - child_options = options.update(width=width) - yield from console.render(self._renderable, child_options) - - def _resolve_width(self, options: ConsoleOptions) -> int: - return max(1, min(self._max_width, options.max_width)) - - -def _strip_trailing_spaces(segments: list[Segment]) -> list[Segment]: - lines = list(Segment.split_lines(segments)) - trimmed: list[Segment] = [] - n_lines = len(lines) - for index, line in enumerate(lines): - line_segments = list(line) - while line_segments: - segment = line_segments[-1] - if segment.control is not None: - break - trimmed_text = segment.text.rstrip(" ") - if trimmed_text != segment.text: - if trimmed_text: - line_segments[-1] = Segment(trimmed_text, segment.style, segment.control) - break - line_segments.pop() - continue - break - trimmed.extend(line_segments) - if index != n_lines - 1: - trimmed.append(Segment.line()) - if trimmed: - trimmed.append(Segment.line()) - return trimmed - - -class BulletColumns: - def __init__( - self, - renderable: RenderableType, - *, - bullet_style: str | None = None, - bullet: RenderableType | None = None, - padding: int = 1, - ) -> None: - self._renderable = renderable - self._bullet = bullet - self._bullet_style = bullet_style - self._padding = padding - - def _bullet_renderable(self) -> RenderableType: - if self._bullet is not None: - return self._bullet - return Text("•", style=self._bullet_style or "") - - def _available_width(self, console: Console, options: ConsoleOptions, bullet_width: int) -> int: - max_width = options.max_width or console.width or (bullet_width + self._padding + 1) - available = max_width - bullet_width - self._padding - return max(available, 1) - - def __rich_measure__(self, console: Console, options: ConsoleOptions) -> Measurement: - bullet = self._bullet_renderable() - bullet_measure = Measurement.get(console, options, bullet) - bullet_width = max(bullet_measure.maximum, 1) - available = self._available_width(console, options, bullet_width) - constrained = _ShrinkToWidth(self._renderable, available) - columns = Columns([bullet, constrained], expand=False, padding=(0, self._padding)) - return Measurement.get(console, options, columns) - - def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult: - bullet = self._bullet_renderable() - bullet_measure = Measurement.get(console, options, bullet) - bullet_width = max(bullet_measure.maximum, 1) - available = self._available_width(console, options, bullet_width) - columns = Columns( - [bullet, _ShrinkToWidth(self._renderable, available)], - expand=False, - padding=(0, self._padding), - ) - segments = list(console.render(columns, options)) - trimmed = _strip_trailing_spaces(segments) - yield from trimmed diff --git a/src/kimi_cli/utils/rich/diff_render.py b/src/kimi_cli/utils/rich/diff_render.py deleted file mode 100644 index 4ff71ee9c..000000000 --- a/src/kimi_cli/utils/rich/diff_render.py +++ /dev/null @@ -1,436 +0,0 @@ -"""Unified diff rendering for CLI tool results and approval panels. - -All diff rendering flows through this module: -- ``render_diff_panel`` — full diff with Panel, Table, background colors (tool results & pager) -- ``render_diff_preview`` — compact changed-lines-only preview (approval panel) -- ``collect_diff_hunks`` — shared data preparation from DiffDisplayBlocks -""" - -from __future__ import annotations - -from dataclasses import dataclass -from difflib import SequenceMatcher -from enum import Enum, auto - -from rich.console import RenderableType -from rich.panel import Panel -from rich.table import Table -from rich.text import Text - -from kimi_cli.tools.display import DiffDisplayBlock -from kimi_cli.ui.theme import get_diff_colors -from kimi_cli.utils.rich.syntax import KimiSyntax - -_INLINE_DIFF_MIN_RATIO = 0.5 # skip inline diff when lines are too dissimilar - -MAX_PREVIEW_CHANGED_LINES = 6 - - -# --------------------------------------------------------------------------- -# Data model — parsed diff lines -# --------------------------------------------------------------------------- - - -class DiffLineKind(Enum): - CONTEXT = auto() - ADD = auto() - DELETE = auto() - - -@dataclass(slots=True) -class DiffLine: - kind: DiffLineKind - old_num: int # 0 means "not applicable" (e.g. added line has no old number) - new_num: int # 0 means "not applicable" (e.g. deleted line has no new number) - code: str - content: Text | None = None # filled after highlighting - is_inline_paired: bool = False # True if this line was paired for inline diff - - -# --------------------------------------------------------------------------- -# Core: build DiffLines directly from old_text / new_text via SequenceMatcher -# --------------------------------------------------------------------------- - - -def _build_diff_lines( - old_text: str, - new_text: str, - old_start: int, - new_start: int, - n_context: int = 3, -) -> list[list[DiffLine]]: - """Build grouped DiffLine hunks directly from old/new text. - - Returns a list of hunks, where each hunk is a list of DiffLine objects. - This replaces the format_unified_diff → parse roundtrip. - """ - old_lines = old_text.splitlines() - new_lines = new_text.splitlines() - matcher = SequenceMatcher(None, old_lines, new_lines, autojunk=False) - - hunks: list[list[DiffLine]] = [] - for group in matcher.get_grouped_opcodes(n=n_context): - hunk: list[DiffLine] = [] - for tag, i1, i2, j1, j2 in group: - if tag == "equal": - for k in range(i2 - i1): - hunk.append( - DiffLine( - kind=DiffLineKind.CONTEXT, - old_num=old_start + i1 + k, - new_num=new_start + j1 + k, - code=old_lines[i1 + k], - ) - ) - elif tag == "delete": - for k in range(i2 - i1): - hunk.append( - DiffLine( - kind=DiffLineKind.DELETE, - old_num=old_start + i1 + k, - new_num=0, - code=old_lines[i1 + k], - ) - ) - elif tag == "insert": - for k in range(j2 - j1): - hunk.append( - DiffLine( - kind=DiffLineKind.ADD, - old_num=0, - new_num=new_start + j1 + k, - code=new_lines[j1 + k], - ) - ) - elif tag == "replace": - for k in range(i2 - i1): - hunk.append( - DiffLine( - kind=DiffLineKind.DELETE, - old_num=old_start + i1 + k, - new_num=0, - code=old_lines[i1 + k], - ) - ) - for k in range(j2 - j1): - hunk.append( - DiffLine( - kind=DiffLineKind.ADD, - old_num=0, - new_num=new_start + j1 + k, - code=new_lines[j1 + k], - ) - ) - if hunk: - hunks.append(hunk) - return hunks - - -# --------------------------------------------------------------------------- -# Syntax highlighting & inline diff -# --------------------------------------------------------------------------- - - -def _make_highlighter(path: str) -> KimiSyntax: - """Create a KimiSyntax instance for highlighting code by file extension.""" - ext = path.rsplit(".", 1)[-1] if "." in path else "" - return KimiSyntax("", ext if ext else "text") - - -def _highlight(highlighter: KimiSyntax, code: str) -> Text: - t = highlighter.highlight(code) - t.rstrip() - return t - - -def _apply_inline_diff( - highlighter: KimiSyntax, - del_lines: list[DiffLine], - add_lines: list[DiffLine], -) -> None: - """Pair delete/add lines and apply word-level inline diff highlighting. - - Modifies DiffLine.content in place for paired lines. - """ - colors = get_diff_colors() - paired = min(len(del_lines), len(add_lines)) - for j in range(paired): - old_code = del_lines[j].code - new_code = add_lines[j].code - sm = SequenceMatcher(None, old_code, new_code) - if sm.ratio() < _INLINE_DIFF_MIN_RATIO: - continue - old_text = _highlight(highlighter, old_code) - new_text = _highlight(highlighter, new_code) - for op, i1, i2, j1, j2 in sm.get_opcodes(): - if op in ("delete", "replace"): - old_text.stylize(colors.del_hl, i1, i2) - if op in ("insert", "replace"): - new_text.stylize(colors.add_hl, j1, j2) - del_lines[j].content = old_text - del_lines[j].is_inline_paired = True - add_lines[j].content = new_text - add_lines[j].is_inline_paired = True - - -def _highlight_hunk(highlighter: KimiSyntax, hunk: list[DiffLine]) -> None: - """Highlight all lines in a hunk, applying inline diff for paired -/+ blocks.""" - # First pass: find consecutive -/+ blocks and apply inline diff - i = 0 - while i < len(hunk): - if hunk[i].kind == DiffLineKind.DELETE: - del_start = i - while i < len(hunk) and hunk[i].kind == DiffLineKind.DELETE: - i += 1 - add_start = i - while i < len(hunk) and hunk[i].kind == DiffLineKind.ADD: - i += 1 - _apply_inline_diff( - highlighter, - hunk[del_start:add_start], - hunk[add_start:i], - ) - else: - i += 1 - - # Second pass: highlight any lines not yet highlighted by inline diff - for dl in hunk: - if dl.content is None: - dl.content = _highlight(highlighter, dl.code) - - -# --------------------------------------------------------------------------- -# Shared header builder -# --------------------------------------------------------------------------- - - -def _build_diff_header(path: str, added: int, removed: int) -> Text: - """Build the file header text: stats + path.""" - header = Text() - if added > 0: - header.append(f"+{added} ", style="bold green") - if removed > 0: - header.append(f"-{removed} ", style="bold red") - header.append(path) - return header - - -# --------------------------------------------------------------------------- -# Public: collect hunks from DiffDisplayBlocks -# --------------------------------------------------------------------------- - - -def collect_diff_hunks( - blocks: list[DiffDisplayBlock], -) -> tuple[list[list[DiffLine]], int, int]: - """Build parsed DiffLine hunks and stats from a list of same-file DiffDisplayBlocks. - - Returns: - (hunks, added_total, removed_total) where each hunk is a list of DiffLine. - """ - all_hunks: list[list[DiffLine]] = [] - added = 0 - removed = 0 - for b in blocks: - block_hunks = _build_diff_lines( - b.old_text, - b.new_text, - b.old_start, - b.new_start, - ) - for hunk in block_hunks: - for dl in hunk: - if dl.kind == DiffLineKind.ADD: - added += 1 - elif dl.kind == DiffLineKind.DELETE: - removed += 1 - all_hunks.append(hunk) - return all_hunks, added, removed - - -# --------------------------------------------------------------------------- -# Public: full diff panel (tool results & pager) -# --------------------------------------------------------------------------- - - -def render_diff_panel( - path: str, - hunks: list[list[DiffLine]], - added: int, - removed: int, -) -> RenderableType: - """Render a diff as a bordered Panel with line numbers, background colors, - syntax highlighting, and inline change markers.""" - title = Text() - title.append(" ") - title.append_text(_build_diff_header(path, added, removed)) - title.append(" ") - - highlighter = _make_highlighter(path) - for hunk in hunks: - _highlight_hunk(highlighter, hunk) - - # Compute line number column width - max_ln = 0 - for hunk in hunks: - for dl in hunk: - max_ln = max(max_ln, dl.old_num, dl.new_num) - num_width = max(len(str(max_ln)), 2) - - table = Table( - show_header=False, - box=None, - padding=(0, 0), - show_edge=False, - expand=True, - ) - table.add_column(justify="right", width=num_width, no_wrap=True) - table.add_column(width=3, no_wrap=True) - table.add_column(ratio=1) - - colors = get_diff_colors() - for hunk_idx, hunk in enumerate(hunks): - if hunk_idx > 0: - table.add_row(Text("⋮", style="dim"), Text(""), Text("")) - - for dl in hunk: - assert dl.content is not None - if dl.kind == DiffLineKind.ADD: - table.add_row( - Text(str(dl.new_num)), - Text(" + ", style="green"), - dl.content, - style=colors.add_bg, - ) - elif dl.kind == DiffLineKind.DELETE: - table.add_row( - Text(str(dl.old_num)), - Text(" - ", style="red"), - dl.content, - style=colors.del_bg, - ) - else: - table.add_row( - Text(str(dl.new_num), style="dim"), - Text(" "), - dl.content, - ) - - return Panel( - table, - title=title, - title_align="left", - border_style="dim", - padding=(0, 1), - ) - - -# --------------------------------------------------------------------------- -# Public: compact preview (approval panels) -# --------------------------------------------------------------------------- - - -def render_diff_preview( - path: str, - hunks: list[list[DiffLine]], - added: int, - removed: int, - max_lines: int = MAX_PREVIEW_CHANGED_LINES, -) -> tuple[list[RenderableType], int]: - """Render a compact diff preview showing only changed lines (no context). - - Returns: - (renderables, remaining_count) — list of Rich renderables and number of - changed lines not shown. - """ - highlighter = _make_highlighter(path) - for hunk in hunks: - _highlight_hunk(highlighter, hunk) - - # Collect only changed lines across all hunks - changed: list[DiffLine] = [] - for hunk in hunks: - for dl in hunk: - if dl.kind != DiffLineKind.CONTEXT: - changed.append(dl) - - total = len(changed) - shown = changed[:max_lines] - remaining = total - len(shown) - - # Compute line number width from shown lines - max_ln = max( - (dl.old_num if dl.kind == DiffLineKind.DELETE else dl.new_num for dl in shown), - default=0, - ) - num_width = max(len(str(max_ln)), 2) - - result: list[RenderableType] = [_build_diff_header(path, added, removed)] - - for dl in shown: - assert dl.content is not None - line = Text() - ln = dl.old_num if dl.kind == DiffLineKind.DELETE else dl.new_num - line.append(str(ln).rjust(num_width), style="dim") - marker_style = "green" if dl.kind == DiffLineKind.ADD else "red" - marker_char = "+" if dl.kind == DiffLineKind.ADD else "-" - line.append(f" {marker_char} ", style=marker_style) - line.append_text(dl.content) - result.append(line) - - if remaining > 0: - result.append(Text(f"... {remaining} more lines (ctrl-e to expand)", style="dim italic")) - - return result, remaining - - -# --------------------------------------------------------------------------- -# Public: summary renderers for huge files -# --------------------------------------------------------------------------- - - -def _summary_description(blocks: list[DiffDisplayBlock]) -> str: - """Build a human-readable size description from summary blocks.""" - block = blocks[0] - if block.old_text == "(0 lines)": - return f"New file with {block.new_text.strip('()')}" - if block.old_text == block.new_text: - return block.old_text.strip("()") - return f"{block.old_text.strip('()')} \u2192 {block.new_text.strip('()')}" - - -def render_diff_summary_panel( - path: str, - blocks: list[DiffDisplayBlock], -) -> RenderableType: - """Render a summary panel for files too large for inline diff.""" - title = Text() - title.append(" ") - title.append(path) - title.append(" ") - - body = Text() - body.append("File too large for inline diff", style="dim italic") - body.append("\n") - body.append(_summary_description(blocks), style="dim") - - return Panel( - body, - title=title, - title_align="left", - border_style="dim", - padding=(1, 2), - ) - - -def render_diff_summary_preview( - path: str, - blocks: list[DiffDisplayBlock], -) -> list[RenderableType]: - """Render a compact summary preview for approval panels.""" - header = Text() - header.append(path) - desc = Text() - summary = _summary_description(blocks) - desc.append(f" File too large for inline diff ({summary})", style="dim italic") - return [header, desc] diff --git a/src/kimi_cli/utils/rich/markdown.py b/src/kimi_cli/utils/rich/markdown.py deleted file mode 100644 index d447ff37b..000000000 --- a/src/kimi_cli/utils/rich/markdown.py +++ /dev/null @@ -1,900 +0,0 @@ -# This file is modified from https://github.com/Textualize/rich/blob/4d6d631a3d2deddf8405522d4b8c976a6d35726c/rich/markdown.py -# pyright: standard - -from __future__ import annotations - -import sys -from collections.abc import Iterable, Mapping -from typing import ClassVar, get_args - -from markdown_it import MarkdownIt -from markdown_it.token import Token -from rich import box -from rich._loop import loop_first -from rich._stack import Stack -from rich.console import Console, ConsoleOptions, JustifyMethod, RenderResult -from rich.containers import Renderables -from rich.jupyter import JupyterMixin -from rich.rule import Rule -from rich.segment import Segment -from rich.style import Style, StyleStack -from rich.syntax import Syntax, SyntaxTheme -from rich.table import Table -from rich.text import Text, TextType - -from kimi_cli.utils.rich.syntax import KIMI_ANSI_THEME_NAME, resolve_code_theme - -LIST_INDENT_WIDTH = 2 - -_FALLBACK_STYLES: Mapping[str, Style] = { - "markdown.paragraph": Style(), - "markdown.h1": Style(color="bright_white", bold=True), - "markdown.h1.underline": Style(color="bright_white", bold=True), - "markdown.h2": Style(color="white", bold=True, underline=True), - "markdown.h3": Style(bold=True), - "markdown.h4": Style(bold=True), - "markdown.h5": Style(bold=True), - "markdown.h6": Style(dim=True, italic=True), - "markdown.code": Style(color="bright_cyan", bold=True), - "markdown.code_block": Style(color="bright_cyan"), - "markdown.item": Style(), - "markdown.item.bullet": Style(), - "markdown.item.number": Style(), - "markdown.em": Style(italic=True), - "markdown.strong": Style(bold=True), - "markdown.s": Style(strike=True), - "markdown.link": Style(color="bright_blue", underline=True), - "markdown.link_url": Style(color="cyan", underline=True), - "markdown.block_quote": Style(), - "markdown.hr": Style(color="grey58"), -} - - -def _strip_background(text: Text) -> Text: - """Return a copy of ``text`` with all background colors removed.""" - - clean = Text( - text.plain, - justify=text.justify, - overflow=text.overflow, - no_wrap=text.no_wrap, - end=text.end, - tab_size=text.tab_size, - ) - - if text.style: - base_style = text.style - if not isinstance(base_style, Style): - base_style = Style.parse(str(base_style)) - base_style = base_style.copy() - if base_style._bgcolor is not None: - base_style._bgcolor = None - clean.stylize(base_style, 0, len(clean)) - - for span in text.spans: - style = span.style - if style is None: - continue - new_style = Style.parse(str(style)) if not isinstance(style, Style) else style.copy() - if new_style._bgcolor is not None: - new_style._bgcolor = None - clean.stylize(new_style, span.start, span.end) - - return clean - - -class MarkdownElement: - new_line: ClassVar[bool] = True - - @classmethod - def create(cls, markdown: Markdown, token: Token) -> MarkdownElement: - """Factory to create markdown element, - - Args: - markdown (Markdown): The parent Markdown object. - token (Token): A node from markdown-it. - - Returns: - MarkdownElement: A new markdown element - """ - return cls() - - def on_enter(self, context: MarkdownContext) -> None: - """Called when the node is entered. - - Args: - context (MarkdownContext): The markdown context. - """ - - def on_text(self, context: MarkdownContext, text: TextType) -> None: - """Called when text is parsed. - - Args: - context (MarkdownContext): The markdown context. - """ - - def on_leave(self, context: MarkdownContext) -> None: - """Called when the parser leaves the element. - - Args: - context (MarkdownContext): [description] - """ - - def on_child_close(self, context: MarkdownContext, child: MarkdownElement) -> bool: - """Called when a child element is closed. - - This method allows a parent element to take over rendering of its children. - - Args: - context (MarkdownContext): The markdown context. - child (MarkdownElement): The child markdown element. - - Returns: - bool: Return True to render the element, or False to not render the element. - """ - return True - - def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult: - return () - - -class UnknownElement(MarkdownElement): - """An unknown element. - - Hopefully there will be no unknown elements, and we will have a MarkdownElement for - everything in the document. - - """ - - -class TextElement(MarkdownElement): - """Base class for elements that render text.""" - - style_name = "none" - - def on_enter(self, context: MarkdownContext) -> None: - self.style = context.enter_style(self.style_name) - self.text = Text(justify="left") - - def on_text(self, context: MarkdownContext, text: TextType) -> None: - self.text.append(text, context.current_style if isinstance(text, str) else None) - - def on_leave(self, context: MarkdownContext) -> None: - context.leave_style() - - -class Paragraph(TextElement): - """A Paragraph.""" - - style_name = "markdown.paragraph" - justify: JustifyMethod - - @classmethod - def create(cls, markdown: Markdown, token: Token) -> Paragraph: - return cls(justify=markdown.justify or "left") - - def __init__(self, justify: JustifyMethod) -> None: - self.justify = justify - - def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult: - self.text.justify = self.justify - yield self.text - - -class Heading(TextElement): - """A heading.""" - - @classmethod - def create(cls, markdown: Markdown, token: Token) -> Heading: - return cls(token.tag) - - def on_enter(self, context: MarkdownContext) -> None: - self.text = Text() - context.enter_style(self.style_name) - - def __init__(self, tag: str) -> None: - self.tag = tag - self.style_name = f"markdown.{tag}" - super().__init__() - - def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult: - text = self.text - text.justify = "left" - width = max(1, text.cell_len) - - if self.tag == "h1": - underline = Text("═" * width) - underline.stylize("markdown.h1.underline") - yield text - yield underline - else: - yield text - - -class CodeBlock(TextElement): - """A code block with syntax highlighting.""" - - style_name = "markdown.code_block" - - @classmethod - def create(cls, markdown: Markdown, token: Token) -> CodeBlock: - node_info = token.info or "" - lexer_name = node_info.partition(" ")[0] - return cls(lexer_name or "text", markdown.code_theme) - - def __init__(self, lexer_name: str, theme: str | SyntaxTheme) -> None: - self.lexer_name = lexer_name - self.theme = theme - - def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult: - code = str(self.text).rstrip() - syntax = Syntax( - code, - self.lexer_name, - theme=self.theme, - word_wrap=True, - background_color=None, - padding=0, - ) - highlighted = syntax.highlight(code) - highlighted.rstrip() - stripped = _strip_background(highlighted) - stripped.rstrip() - yield stripped - - -class BlockQuote(TextElement): - """A block quote.""" - - style_name = "markdown.block_quote" - - def __init__(self) -> None: - self.elements: Renderables = Renderables() - - def on_child_close(self, context: MarkdownContext, child: MarkdownElement) -> bool: - self.elements.append(child) - return False - - def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult: - render_options = options.update(width=options.max_width - 4) - style = self.style.without_color - lines = console.render_lines(self.elements, render_options, style=style) - new_line = Segment("\n") - padding = Segment("▌ ", style) - for line in lines: - yield padding - yield from line - yield new_line - - -class HorizontalRule(MarkdownElement): - """A horizontal rule to divide sections.""" - - new_line = False - - def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult: - style = _FALLBACK_STYLES["markdown.hr"].copy() - yield Rule(style=style) - - -class TableElement(MarkdownElement): - """MarkdownElement corresponding to `table_open`.""" - - def __init__(self) -> None: - self.header: TableHeaderElement | None = None - self.body: TableBodyElement | None = None - - def on_child_close(self, context: MarkdownContext, child: MarkdownElement) -> bool: - if isinstance(child, TableHeaderElement): - self.header = child - elif isinstance(child, TableBodyElement): - self.body = child - else: - raise RuntimeError("Couldn't process markdown table.") - return False - - def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult: - table = Table(box=box.SIMPLE_HEAVY, show_edge=False) - - if self.header is not None and self.header.row is not None: - for column in self.header.row.cells: - table.add_column(column.content) - - if self.body is not None: - for row in self.body.rows: - row_content = [element.content for element in row.cells] - table.add_row(*row_content) - - yield table - - -class TableHeaderElement(MarkdownElement): - """MarkdownElement corresponding to `thead_open` and `thead_close`.""" - - def __init__(self) -> None: - self.row: TableRowElement | None = None - - def on_child_close(self, context: MarkdownContext, child: MarkdownElement) -> bool: - assert isinstance(child, TableRowElement) - self.row = child - return False - - -class TableBodyElement(MarkdownElement): - """MarkdownElement corresponding to `tbody_open` and `tbody_close`.""" - - def __init__(self) -> None: - self.rows: list[TableRowElement] = [] - - def on_child_close(self, context: MarkdownContext, child: MarkdownElement) -> bool: - assert isinstance(child, TableRowElement) - self.rows.append(child) - return False - - -class TableRowElement(MarkdownElement): - """MarkdownElement corresponding to `tr_open` and `tr_close`.""" - - def __init__(self) -> None: - self.cells: list[TableDataElement] = [] - - def on_child_close(self, context: MarkdownContext, child: MarkdownElement) -> bool: - assert isinstance(child, TableDataElement) - self.cells.append(child) - return False - - -class TableDataElement(MarkdownElement): - """MarkdownElement corresponding to `td_open` and `td_close` - and `th_open` and `th_close`.""" - - @classmethod - def create(cls, markdown: Markdown, token: Token) -> MarkdownElement: - style = str(token.attrs.get("style")) or "" - - justify: JustifyMethod - if "text-align:right" in style: - justify = "right" - elif "text-align:center" in style: - justify = "center" - elif "text-align:left" in style: - justify = "left" - else: - justify = "default" - - assert justify in get_args(JustifyMethod) - return cls(justify=justify) - - def __init__(self, justify: JustifyMethod) -> None: - self.content: Text = Text("", justify=justify) - self.justify = justify - - def on_text(self, context: MarkdownContext, text: TextType) -> None: - text = Text(text) if isinstance(text, str) else text - text.stylize(context.current_style) - self.content.append_text(text) - - -class ListElement(MarkdownElement): - """A list element.""" - - @classmethod - def create(cls, markdown: Markdown, token: Token) -> ListElement: - return cls(token.type, int(token.attrs.get("start", 1))) - - def __init__(self, list_type: str, list_start: int | None) -> None: - self.items: list[ListItem] = [] - self.list_type = list_type - self.list_start = list_start - - def on_child_close(self, context: MarkdownContext, child: MarkdownElement) -> bool: - assert isinstance(child, ListItem) - self.items.append(child) - return False - - def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult: - if self.list_type == "bullet_list_open": - for item in self.items: - yield from item.render_bullet(console, options) - else: - number = 1 if self.list_start is None else self.list_start - last_number = number + len(self.items) - for index, item in enumerate(self.items): - yield from item.render_number(console, options, number + index, last_number) - - -class ListItem(TextElement): - """An item in a list.""" - - style_name = "markdown.item" - - @staticmethod - def _line_starts_with_list_marker(text: str) -> bool: - stripped = text.lstrip() - if not stripped: - return False - if stripped.startswith(("• ", "- ", "* ")): - return True - index = 0 - while index < len(stripped) and stripped[index].isdigit(): - index += 1 - if index == 0 or index >= len(stripped): - return False - marker = stripped[index] - has_space = index + 1 < len(stripped) and stripped[index + 1] == " " - return marker in {".", ")"} and has_space - - @classmethod - def create(cls, markdown: Markdown, token: Token) -> MarkdownElement: - # `list_item_open` levels grow by 2 for each nested list depth. - depth = max(0, (token.level - 1) // 2) - return cls(indent=depth) - - def __init__(self, indent: int = 0) -> None: - self.indent = indent - self.elements: Renderables = Renderables() - - def on_child_close(self, context: MarkdownContext, child: MarkdownElement) -> bool: - self.elements.append(child) - return False - - def render_bullet(self, console: Console, options: ConsoleOptions) -> RenderResult: - lines = console.render_lines(self.elements, options, style=self.style) - indent_padding_len = LIST_INDENT_WIDTH * self.indent - indent_text = " " * indent_padding_len - bullet = Segment("• ") - new_line = Segment("\n") - bullet_width = len(bullet.text) - for first, line in loop_first(lines): - if first: - if indent_text: - yield Segment(indent_text) - yield bullet - else: - plain = "".join(segment.text for segment in line) - if self._line_starts_with_list_marker(plain): - prefix = "" - else: - existing = len(plain) - len(plain.lstrip(" ")) - target = indent_padding_len + bullet_width - missing = max(0, target - existing) - prefix = " " * missing - if prefix: - yield Segment(prefix) - yield from line - yield new_line - - def render_number( - self, console: Console, options: ConsoleOptions, number: int, last_number: int - ) -> RenderResult: - lines = console.render_lines(self.elements, options, style=self.style) - new_line = Segment("\n") - indent_padding_len = LIST_INDENT_WIDTH * self.indent - indent_text = " " * indent_padding_len - numeral_text = f"{number}. " - numeral = Segment(numeral_text) - numeral_width = len(numeral_text) - for first, line in loop_first(lines): - if first: - if indent_text: - yield Segment(indent_text) - yield numeral - else: - plain = "".join(segment.text for segment in line) - if self._line_starts_with_list_marker(plain): - prefix = "" - else: - existing = len(plain) - len(plain.lstrip(" ")) - target = indent_padding_len + numeral_width - missing = max(0, target - existing) - prefix = " " * missing - if prefix: - yield Segment(prefix) - yield from line - yield new_line - - -class Link(TextElement): - @classmethod - def create(cls, markdown: Markdown, token: Token) -> MarkdownElement: - url = token.attrs.get("href", "#") - return cls(token.content, str(url)) - - def __init__(self, text: str, href: str): - self.text = Text(text) - self.href = href - - -class ImageItem(TextElement): - """Renders a placeholder for an image.""" - - new_line = False - - @classmethod - def create(cls, markdown: Markdown, token: Token) -> MarkdownElement: - """Factory to create markdown element, - - Args: - markdown (Markdown): The parent Markdown object. - token (Any): A token from markdown-it. - - Returns: - MarkdownElement: A new markdown element - """ - return cls(str(token.attrs.get("src", "")), markdown.hyperlinks) - - def __init__(self, destination: str, hyperlinks: bool) -> None: - self.destination = destination - self.hyperlinks = hyperlinks - self.link: str | None = None - super().__init__() - - def on_enter(self, context: MarkdownContext) -> None: - self.link = context.current_style.link - self.text = Text(justify="left") - super().on_enter(context) - - def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult: - link_style = Style(link=self.link or self.destination or None) - title = self.text or Text(self.destination.strip("/").rsplit("/", 1)[-1]) - if self.hyperlinks: - title.stylize(link_style) - text = Text.assemble("🌆 ", title, " ", end="") - yield text - - -class MarkdownContext: - """Manages the console render state.""" - - def __init__( - self, - console: Console, - options: ConsoleOptions, - style: Style, - fallback_styles: Mapping[str, Style], - inline_code_lexer: str | None = None, - inline_code_theme: str | SyntaxTheme = KIMI_ANSI_THEME_NAME, - ) -> None: - self.console = console - self.options = options - self.style_stack: StyleStack = StyleStack(style) - self.stack: Stack[MarkdownElement] = Stack() - self._fallback_styles = fallback_styles - - self._syntax: Syntax | None = None - if inline_code_lexer is not None: - self._syntax = Syntax("", inline_code_lexer, theme=inline_code_theme) - - @property - def current_style(self) -> Style: - """Current style which is the product of all styles on the stack.""" - return self.style_stack.current - - def on_text(self, text: str, node_type: str) -> None: - """Called when the parser visits text.""" - if node_type in {"fence", "code_inline"} and self._syntax is not None: - highlighted = self._syntax.highlight(text) - highlighted.rstrip() - stripped = _strip_background(highlighted) - combined = Text.assemble(stripped, style=self.style_stack.current) - self.stack.top.on_text(self, combined) - else: - self.stack.top.on_text(self, text) - - def enter_style(self, style_name: str | Style) -> Style: - """Enter a style context.""" - if isinstance(style_name, Style): - style = style_name - else: - fallback = self._fallback_styles.get(style_name, Style()) - style = self.console.get_style(style_name, default=fallback) - style = fallback + style - style = style.copy() - if isinstance(style_name, str) and style_name == "markdown.block_quote": - style = style.without_color - if ( - isinstance(style_name, str) - and style_name in {"markdown.code", "markdown.code_block"} - and style._bgcolor is not None - ): - style._bgcolor = None - self.style_stack.push(style) - return self.current_style - - def leave_style(self) -> Style: - """Leave a style context.""" - style = self.style_stack.pop() - return style - - -class Markdown(JupyterMixin): - """A Markdown renderable. - - Args: - markup (str): A string containing markdown. - code_theme (str, optional): Pygments theme for code blocks. Defaults to "kimi-ansi". - See https://pygments.org/styles/ for code themes. - justify (JustifyMethod, optional): Justify value for paragraphs. Defaults to None. - style (Union[str, Style], optional): Optional style to apply to markdown. - hyperlinks (bool, optional): Enable hyperlinks. Defaults to ``True``. - inline_code_lexer: (str, optional): Lexer to use if inline code highlighting is - enabled. Defaults to None. - inline_code_theme: (Optional[str], optional): Pygments theme for inline code - highlighting, or None for no highlighting. Defaults to None. - """ - - elements: ClassVar[dict[str, type[MarkdownElement]]] = { - "paragraph_open": Paragraph, - "heading_open": Heading, - "fence": CodeBlock, - "code_block": CodeBlock, - "blockquote_open": BlockQuote, - "hr": HorizontalRule, - "bullet_list_open": ListElement, - "ordered_list_open": ListElement, - "list_item_open": ListItem, - "image": ImageItem, - "table_open": TableElement, - "tbody_open": TableBodyElement, - "thead_open": TableHeaderElement, - "tr_open": TableRowElement, - "td_open": TableDataElement, - "th_open": TableDataElement, - } - - inlines = {"em", "strong", "code", "s"} - - def __init__( - self, - markup: str, - code_theme: str = KIMI_ANSI_THEME_NAME, - justify: JustifyMethod | None = None, - style: str | Style = "none", - hyperlinks: bool = True, - inline_code_lexer: str | None = None, - inline_code_theme: str | None = None, - ) -> None: - parser = MarkdownIt().enable("strikethrough").enable("table") - self.markup = markup - self.parsed = parser.parse(markup) - self.code_theme = resolve_code_theme(code_theme) - self.justify: JustifyMethod | None = justify - self.style = style - self.hyperlinks = hyperlinks - self.inline_code_lexer = inline_code_lexer - self.inline_code_theme = resolve_code_theme(inline_code_theme or code_theme) - - def _flatten_tokens(self, tokens: Iterable[Token]) -> Iterable[Token]: - """Flattens the token stream.""" - for token in tokens: - is_fence = token.type == "fence" - is_image = token.tag == "img" - if token.children and not (is_image or is_fence): - yield from self._flatten_tokens(token.children) - else: - yield token - - def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult: - """Render markdown to the console.""" - style = console.get_style(self.style, default="none") - options = options.update(height=None) - context = MarkdownContext( - console, - options, - style, - _FALLBACK_STYLES, - inline_code_lexer=self.inline_code_lexer, - inline_code_theme=self.inline_code_theme, - ) - tokens = self.parsed - inline_style_tags = self.inlines - new_line = False - _new_line_segment = Segment.line() - render_started = False - - for token in self._flatten_tokens(tokens): - node_type = token.type - tag = token.tag - - entering = token.nesting == 1 - exiting = token.nesting == -1 - self_closing = token.nesting == 0 - - if node_type in {"text", "html_inline", "html_block"}: - # Render HTML tokens as plain text so safeword markup stays visible. - if context.stack: - context.on_text(token.content, node_type) - else: - # Orphan text/html blocks can appear outside any element (e.g. ). - paragraph = Paragraph(justify=self.justify or "left") - paragraph.on_enter(context) - paragraph.on_text(context, token.content) - paragraph.on_leave(context) - if new_line and render_started: - yield _new_line_segment - rendered = console.render(paragraph, context.options) - for segment in rendered: - render_started = True - yield segment - new_line = paragraph.new_line - elif node_type == "hardbreak": - context.on_text("\n", node_type) - elif node_type == "softbreak": - context.on_text(" ", node_type) - elif node_type == "link_open": - href = str(token.attrs.get("href", "")) - if self.hyperlinks: - link_style = console.get_style("markdown.link_url", default="none") - link_style += Style(link=href) - context.enter_style(link_style) - else: - context.stack.push(Link.create(self, token)) - elif node_type == "link_close": - if self.hyperlinks: - context.leave_style() - else: - element = context.stack.pop() - assert isinstance(element, Link) - link_style = console.get_style("markdown.link", default="none") - context.enter_style(link_style) - context.on_text(element.text.plain, node_type) - context.leave_style() - context.on_text(" (", node_type) - link_url_style = console.get_style("markdown.link_url", default="none") - context.enter_style(link_url_style) - context.on_text(element.href, node_type) - context.leave_style() - context.on_text(")", node_type) - elif tag in inline_style_tags and node_type != "fence" and node_type != "code_block": - if entering: - # If it's an opening inline token e.g. strong, em, etc. - # Then we move into a style context i.e. push to stack. - context.enter_style(f"markdown.{tag}") - elif exiting: - # If it's a closing inline style, then we pop the style - # off of the stack, to move out of the context of it... - context.leave_style() - else: - # If it's a self-closing inline style e.g. `code_inline` - context.enter_style(f"markdown.{tag}") - if token.content: - context.on_text(token.content, node_type) - context.leave_style() - else: - # Map the markdown tag -> MarkdownElement renderable - element_class = self.elements.get(token.type) or UnknownElement - element = element_class.create(self, token) - - if entering or self_closing: - context.stack.push(element) - element.on_enter(context) - - if exiting: # CLOSING tag - element = context.stack.pop() - - should_render = not context.stack or ( - context.stack and context.stack.top.on_child_close(context, element) - ) - - if should_render: - if new_line and render_started: - yield _new_line_segment - - rendered = console.render(element, context.options) - for segment in rendered: - render_started = True - yield segment - elif self_closing: # SELF-CLOSING tags (e.g. text, code, image) - context.stack.pop() - text = token.content - if text is not None: - element.on_text(context, text) - - should_render = ( - not context.stack - or context.stack - and context.stack.top.on_child_close(context, element) - ) - if should_render: - if new_line and node_type != "inline" and render_started: - yield _new_line_segment - rendered = console.render(element, context.options) - for segment in rendered: - render_started = True - yield segment - - if exiting or self_closing: - element.on_leave(context) - new_line = element.new_line - - -if __name__ == "__main__": - import argparse - import sys - - parser = argparse.ArgumentParser(description="Render Markdown to the console with Rich") - parser.add_argument( - "path", - metavar="PATH", - help="path to markdown file, or - for stdin", - ) - parser.add_argument( - "-c", - "--force-color", - dest="force_color", - action="store_true", - default=None, - help="force color for non-terminals", - ) - parser.add_argument( - "-t", - "--code-theme", - dest="code_theme", - default=KIMI_ANSI_THEME_NAME, - help='code theme (pygments name or "kimi-ansi")', - ) - parser.add_argument( - "-i", - "--inline-code-lexer", - dest="inline_code_lexer", - default=None, - help="inline_code_lexer", - ) - parser.add_argument( - "-y", - "--hyperlinks", - dest="hyperlinks", - action="store_true", - help="enable hyperlinks", - ) - parser.add_argument( - "-w", - "--width", - type=int, - dest="width", - default=None, - help="width of output (default will auto-detect)", - ) - parser.add_argument( - "-j", - "--justify", - dest="justify", - action="store_true", - help="enable full text justify", - ) - parser.add_argument( - "-p", - "--page", - dest="page", - action="store_true", - help="use pager to scroll output", - ) - args = parser.parse_args() - - from rich.console import Console - - if args.path == "-": - markdown_body = sys.stdin.read() - else: - with open(args.path, encoding="utf-8") as markdown_file: - markdown_body = markdown_file.read() - - markdown = Markdown( - markdown_body, - justify="full" if args.justify else "left", - code_theme=args.code_theme, - hyperlinks=args.hyperlinks, - inline_code_lexer=args.inline_code_lexer, - ) - if args.page: - import io - import pydoc - - fileio = io.StringIO() - console = Console(file=fileio, force_terminal=args.force_color, width=args.width) - console.print(markdown) - pydoc.pager(fileio.getvalue()) - - else: - console = Console(force_terminal=args.force_color, width=args.width, record=True) - console.print(markdown) diff --git a/src/kimi_cli/utils/rich/markdown_sample.md b/src/kimi_cli/utils/rich/markdown_sample.md deleted file mode 100644 index e825f9481..000000000 --- a/src/kimi_cli/utils/rich/markdown_sample.md +++ /dev/null @@ -1,108 +0,0 @@ -# Markdown Sample Document - -This is a comprehensive sample document showcasing various Markdown elements. - -## Level 2 Heading - -### Level 3 Heading - -Here's some regular text with **bold text**, *italic text*, and `inline code`. - -## Lists - -### Unordered List - -- First item -- Second item - - Nested item 1 - - Nested item 2 -- Third item - -### Ordered List - -1. First step -2. Second step - 1. Sub-step A - 2. Sub-step B -3. Third step - -### Mixed List - -1. First item - - Sub-item with bullet - - Another sub-item -2. Second item - 1. Numbered sub-item - 2. Another numbered sub-item - -## Links and References - -Here's a [link to GitHub](https://github.com) and another [relative link](../README.md). - -## Code Blocks - -```python -def hello_world(): - """A simple function to demonstrate code blocks.""" - print("Hello, World!") - return 42 - -# Call the function -result = hello_world() -``` - -```bash -# Bash example -echo "This is a bash script" -ls -la /tmp -``` - -## Blockquotes - -> This is a blockquote. -> It can span multiple lines. -> -> > And it can be nested too! - -## Tables - -| Column 1 | Column 2 | Column 3 | -|----------|----------|----------| -| Cell 1 | Cell 2 | Cell 3 | -| Left | Center | Right | -| Foo | Bar | Baz | - -## Horizontal Rules - ---- - -Here's some text after a horizontal rule. - ---- - -## Inline Formatting - -You can combine **bold and *italic*** text, or use `code` within paragraphs. - -**Important**: Always test your `code` snippets before deployment. - -## Advanced Features - -### Task Lists - -- [x] Completed task -- [ ] Pending task -- [ ] Another pending task - -### Definition Lists - -Term 1 -: Definition of term 1 - -Term 2 -: Definition of term 2 -: Another definition for term 2 - ---- - -*This document demonstrates comprehensive Markdown formatting capabilities.* diff --git a/src/kimi_cli/utils/rich/markdown_sample_short.md b/src/kimi_cli/utils/rich/markdown_sample_short.md deleted file mode 100644 index 091ee070d..000000000 --- a/src/kimi_cli/utils/rich/markdown_sample_short.md +++ /dev/null @@ -1,2 +0,0 @@ -- First -- Second diff --git a/src/kimi_cli/utils/rich/syntax.py b/src/kimi_cli/utils/rich/syntax.py deleted file mode 100644 index 69f878546..000000000 --- a/src/kimi_cli/utils/rich/syntax.py +++ /dev/null @@ -1,114 +0,0 @@ -from __future__ import annotations - -from typing import Any - -from pygments.token import ( - Comment, - Generic, - Keyword, - Name, - Number, - Operator, - Punctuation, - String, -) -from pygments.token import ( - Literal as PygmentsLiteral, -) -from pygments.token import ( - Text as PygmentsText, -) -from pygments.token import ( - Token as PygmentsToken, -) -from rich.style import Style -from rich.syntax import ANSISyntaxTheme, Syntax, SyntaxTheme - -KIMI_ANSI_THEME_NAME = "kimi-ansi" -KIMI_ANSI_THEME = ANSISyntaxTheme( - { - PygmentsToken: Style(color="default"), - PygmentsText: Style(color="default"), - Comment: Style(color="bright_black", italic=True), - Keyword: Style(color="magenta"), - Keyword.Constant: Style(color="cyan"), - Keyword.Declaration: Style(color="magenta"), - Keyword.Namespace: Style(color="magenta"), - Keyword.Pseudo: Style(color="magenta"), - Keyword.Reserved: Style(color="magenta"), - Keyword.Type: Style(color="magenta"), - Name: Style(color="default"), - Name.Attribute: Style(color="cyan"), - Name.Builtin: Style(color="bright_yellow"), - Name.Builtin.Pseudo: Style(color="cyan"), - Name.Builtin.Type: Style(color="bright_yellow", bold=True), - Name.Class: Style(color="bright_yellow", bold=True), - Name.Constant: Style(color="cyan"), - Name.Decorator: Style(color="bright_cyan"), - Name.Entity: Style(color="bright_yellow"), - Name.Exception: Style(color="bright_yellow", bold=True), - Name.Function: Style(color="bright_cyan"), - Name.Label: Style(color="cyan"), - Name.Namespace: Style(color="magenta"), - Name.Other: Style(color="bright_cyan"), - Name.Property: Style(color="cyan"), - Name.Tag: Style(color="bright_green"), - Name.Variable: Style(color="bright_yellow"), - PygmentsLiteral: Style(color="bright_blue"), - PygmentsLiteral.Date: Style(color="bright_blue"), - String: Style(color="bright_blue"), - String.Doc: Style(color="bright_blue", italic=True), - String.Interpol: Style(color="bright_blue"), - String.Affix: Style(color="cyan"), - Number: Style(color="cyan"), - Operator: Style(color="default"), - Operator.Word: Style(color="magenta"), - Punctuation: Style(color="default"), - Generic.Deleted: Style(color="red"), - Generic.Emph: Style(italic=True), - Generic.Error: Style(color="bright_red", bold=True), - Generic.Heading: Style(color="cyan", bold=True), - Generic.Inserted: Style(color="green"), - Generic.Output: Style(color="bright_black"), - Generic.Prompt: Style(color="bright_cyan"), - Generic.Strong: Style(bold=True), - Generic.Subheading: Style(color="cyan"), - Generic.Traceback: Style(color="bright_red", bold=True), - } -) - - -def resolve_code_theme(theme: str | SyntaxTheme) -> str | SyntaxTheme: - if isinstance(theme, str) and theme.lower() == KIMI_ANSI_THEME_NAME: - return KIMI_ANSI_THEME - return theme - - -class KimiSyntax(Syntax): - def __init__(self, code: str, lexer: str, **kwargs: Any) -> None: - if "theme" not in kwargs or kwargs["theme"] is None: - kwargs["theme"] = KIMI_ANSI_THEME - super().__init__(code, lexer, **kwargs) - - -if __name__ == "__main__": - from rich.console import Console - from rich.text import Text - - console = Console() - - examples = [ - ("diff", "diff", "@@ -1,2 +1,2 @@\n-line one\n+line uno\n"), - ( - "python", - "python", - 'def greet(name: str) -> str:\n return f"Hi, {name}!"\n', - ), - ("bash", "bash", "set -euo pipefail\nprintf '%s\\n' \"hello\"\n"), - ] - - for idx, (title, lexer, code) in enumerate(examples): - if idx: - console.print() - console.print(Text(f"[{title}]", style="bold")) - console.print(KimiSyntax(code, lexer)) diff --git a/src/kimi_cli/utils/server.py b/src/kimi_cli/utils/server.py deleted file mode 100644 index a33186ace..000000000 --- a/src/kimi_cli/utils/server.py +++ /dev/null @@ -1,121 +0,0 @@ -"""Shared utilities for kimi vis and kimi web server startup.""" - -from __future__ import annotations - -import importlib -import socket -import textwrap - - -def get_address_family(host: str) -> socket.AddressFamily: - """Return AF_INET6 for IPv6 addresses, AF_INET for IPv4 and hostnames.""" - return socket.AF_INET6 if ":" in host else socket.AF_INET - - -def format_url(host: str, port: int) -> str: - """Build ``http://host:port``, bracketing IPv6 literals per RFC 2732.""" - if ":" in host: - return f"http://[{host}]:{port}" - return f"http://{host}:{port}" - - -def is_local_host(host: str) -> bool: - """Check whether *host* resolves to a loopback address.""" - return host in {"127.0.0.1", "localhost", "::1"} - - -def find_available_port(host: str, start_port: int, max_attempts: int = 10) -> int: - """Find an available port starting from *start_port*. - - Raises ``RuntimeError`` if no port is available within the range. - """ - if max_attempts <= 0: - raise ValueError("max_attempts must be positive") - if start_port < 1 or start_port > 65535: - raise ValueError("start_port must be between 1 and 65535") - - family = get_address_family(host) - for offset in range(max_attempts): - port = start_port + offset - with socket.socket(family, socket.SOCK_STREAM) as s: - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - try: - s.bind((host, port)) - return port - except OSError: - continue - raise RuntimeError( - f"Cannot find available port in range {start_port}-{start_port + max_attempts - 1}" - ) - - -def get_network_addresses() -> list[str]: - """Get non-loopback IPv4 addresses for this machine.""" - addresses: list[str] = [] - - try: - hostname = socket.gethostname() - for info in socket.getaddrinfo(hostname, None, socket.AF_INET): - ip = info[4][0] - if isinstance(ip, str) and not ip.startswith("127.") and ip not in addresses: - addresses.append(ip) - except OSError: - pass - - try: - with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: - s.connect(("8.8.8.8", 80)) - ip = s.getsockname()[0] - if ip and not ip.startswith("127.") and ip not in addresses: - addresses.append(ip) - except OSError: - pass - - try: - netifaces = importlib.import_module("netifaces") - for interface in netifaces.interfaces(): - addrs = netifaces.ifaddresses(interface) - if netifaces.AF_INET in addrs: - for addr_info in addrs[netifaces.AF_INET]: - addr = addr_info.get("addr") - if addr and not addr.startswith("127.") and addr not in addresses: - addresses.append(addr) - except (ImportError, Exception): - pass - - return addresses - - -def print_banner(lines: list[str]) -> None: - """Print a boxed banner with tag conventions (
, ,
).""" - processed: list[str] = [] - for line in lines: - if line == "
": - processed.append(line) - elif not line: - processed.append("") - elif line.startswith("
") or line.startswith(""): - processed.append(line) - else: - processed.extend(textwrap.wrap(line, width=78)) - - def strip_tags(s: str) -> str: - return s.removeprefix("
").removeprefix("") - - content_lines = [strip_tags(line) for line in processed if line != "
"] - width = max(60, *(len(line) for line in content_lines)) - top = "+" + "=" * (width + 2) + "+" - - print(top) - for line in processed: - if line == "
": - print("|" + "-" * (width + 2) + "|") - elif line.startswith("
"): - content = line.removeprefix("
") - print(f"| {content.center(width)} |") - elif line.startswith(""): - content = line.removeprefix("") - print(f"| {content.ljust(width)} |") - else: - print(f"| {line.ljust(width)} |") - print(top) diff --git a/src/kimi_cli/utils/signals.py b/src/kimi_cli/utils/signals.py deleted file mode 100644 index 3d687fdab..000000000 --- a/src/kimi_cli/utils/signals.py +++ /dev/null @@ -1,43 +0,0 @@ -from __future__ import annotations - -import asyncio -import contextlib -import signal -from collections.abc import Callable - - -def install_sigint_handler( - loop: asyncio.AbstractEventLoop, handler: Callable[[], None] -) -> Callable[[], None]: - """ - Install a SIGINT handler that works on Unix and Windows. - - On Unix event loops, prefer `loop.add_signal_handler`. - On Windows (or other platforms) where it is not implemented, fall back to - `signal.signal`. The fallback cannot be removed from the loop, but we - restore the previous handler on uninstall. - - Returns: - A function that removes the installed handler. It is guaranteed that - no exceptions are raised when calling the returned function. - """ - - try: - loop.add_signal_handler(signal.SIGINT, handler) - - def remove() -> None: - with contextlib.suppress(RuntimeError): - loop.remove_signal_handler(signal.SIGINT) - - return remove - except RuntimeError: - # Windows ProactorEventLoop and some environments do not support - # add_signal_handler. Use synchronous signal handling as a fallback. - previous = signal.getsignal(signal.SIGINT) - signal.signal(signal.SIGINT, lambda signum, frame: handler()) - - def remove() -> None: - with contextlib.suppress(RuntimeError): - signal.signal(signal.SIGINT, previous) - - return remove diff --git a/src/kimi_cli/utils/signals.ts b/src/kimi_cli/utils/signals.ts new file mode 100644 index 000000000..7459ceeca --- /dev/null +++ b/src/kimi_cli/utils/signals.ts @@ -0,0 +1,44 @@ +/** + * Signal handling utilities — corresponds to Python utils/signals.py + * Cross-platform SIGINT handler installation. + */ + +/** + * Install a SIGINT handler. Returns a function to remove it. + * + * Works on Unix and Windows (Bun's process.on is cross-platform). + */ +export function installSigintHandler(handler: () => void): () => void { + const listener = () => handler(); + process.on("SIGINT", listener); + + return () => { + process.off("SIGINT", listener); + }; +} + +/** + * Install a SIGTERM handler. Returns a function to remove it. + */ +export function installSigtermHandler(handler: () => void): () => void { + const listener = () => handler(); + process.on("SIGTERM", listener); + + return () => { + process.off("SIGTERM", listener); + }; +} + +/** + * Install handlers for graceful shutdown on both SIGINT and SIGTERM. + * Returns a function to remove all handlers. + */ +export function installShutdownHandlers(handler: () => void): () => void { + const removeSigint = installSigintHandler(handler); + const removeSigterm = installSigtermHandler(handler); + + return () => { + removeSigint(); + removeSigterm(); + }; +} diff --git a/src/kimi_cli/utils/slashcmd.py b/src/kimi_cli/utils/slashcmd.py deleted file mode 100644 index 8ad1eface..000000000 --- a/src/kimi_cli/utils/slashcmd.py +++ /dev/null @@ -1,124 +0,0 @@ -import re -from collections.abc import Awaitable, Callable, Sequence -from dataclasses import dataclass -from typing import overload - - -@dataclass(frozen=True, slots=True, kw_only=True) -class SlashCommand[F: Callable[..., None | Awaitable[None]]]: - name: str - description: str - func: F - aliases: list[str] - - def slash_name(self): - """/name (aliases)""" - if self.aliases: - return f"/{self.name} ({', '.join(self.aliases)})" - return f"/{self.name}" - - -class SlashCommandRegistry[F: Callable[..., None | Awaitable[None]]]: - """Registry for slash commands.""" - - def __init__(self) -> None: - self._commands: dict[str, SlashCommand[F]] = {} - """Primary name -> SlashCommand""" - self._command_aliases: dict[str, SlashCommand[F]] = {} - """Primary name or alias -> SlashCommand""" - - @overload - def command(self, func: F, /) -> F: ... - - @overload - def command( - self, - *, - name: str | None = None, - aliases: Sequence[str] | None = None, - ) -> Callable[[F], F]: ... - - def command( - self, - func: F | None = None, - *, - name: str | None = None, - aliases: Sequence[str] | None = None, - ) -> F | Callable[[F], F]: - """ - Decorator to register a slash command with optional custom name and aliases. - - Usage examples: - @registry.command - def help(app: App, args: str): ... - - @registry.command(name="run") - def start(app: App, args: str): ... - - @registry.command(aliases=["h", "?", "assist"]) - def help(app: App, args: str): ... - """ - - def _register(f: F) -> F: - primary = name or f.__name__ - alias_list = list(aliases) if aliases else [] - - # Create the primary command with aliases - cmd = SlashCommand[F]( - name=primary, - description=(f.__doc__ or "").strip(), - func=f, - aliases=alias_list, - ) - - # Register primary command - self._commands[primary] = cmd - self._command_aliases[primary] = cmd - - # Register aliases pointing to the same command - for alias in alias_list: - self._command_aliases[alias] = cmd - - return f - - if func is not None: - return _register(func) - return _register - - def find_command(self, name: str) -> SlashCommand[F] | None: - return self._command_aliases.get(name) - - def list_commands(self) -> list[SlashCommand[F]]: - """Get all unique primary slash commands (without duplicating aliases).""" - return list(self._commands.values()) - - -@dataclass(frozen=True, slots=True, kw_only=True) -class SlashCommandCall: - name: str - args: str - raw_input: str - - -def parse_slash_command_call(user_input: str) -> SlashCommandCall | None: - """ - Parse a slash command call from user input. - - Returns: - SlashCommandCall if a slash command is found, else None. The `args` field contains - the raw argument string after the command name. - """ - user_input = user_input.strip() - if not user_input or not user_input.startswith("/"): - return None - - name_match = re.match(r"^\/([a-zA-Z0-9_-]+(?::[a-zA-Z0-9_-]+)*)", user_input) - - if not name_match: - return None - - command_name = name_match.group(1) - if len(user_input) > name_match.end() and not user_input[name_match.end()].isspace(): - return None - raw_args = user_input[name_match.end() :].lstrip() - return SlashCommandCall(name=command_name, args=raw_args, raw_input=user_input) diff --git a/src/kimi_cli/utils/string.py b/src/kimi_cli/utils/string.py deleted file mode 100644 index bd4379bba..000000000 --- a/src/kimi_cli/utils/string.py +++ /dev/null @@ -1,22 +0,0 @@ -from __future__ import annotations - -import random -import re -import string - -_NEWLINE_RE = re.compile(r"[\r\n]+") - - -def shorten_middle(text: str, width: int, remove_newline: bool = True) -> str: - """Shorten the text by inserting ellipsis in the middle.""" - if len(text) <= width: - return text - if remove_newline: - text = _NEWLINE_RE.sub(" ", text) - return text[: width // 2] + "..." + text[-width // 2 :] - - -def random_string(length: int = 8) -> str: - """Generate a random string of fixed length.""" - letters = string.ascii_lowercase - return "".join(random.choice(letters) for _ in range(length)) diff --git a/src/kimi_cli/utils/string.ts b/src/kimi_cli/utils/string.ts new file mode 100644 index 000000000..b2d468810 --- /dev/null +++ b/src/kimi_cli/utils/string.ts @@ -0,0 +1,30 @@ +/** + * String utilities — corresponds to Python utils/string.py + */ + +const NEWLINE_RE = /[\r\n]+/g; + +/** + * Shorten text by inserting ellipsis in the middle. + */ +export function shortenMiddle(text: string, width: number, removeNewline = true): string { + if (text.length <= width) return text; + let t = text; + if (removeNewline) { + t = t.replace(NEWLINE_RE, " "); + } + const half = Math.floor(width / 2); + return t.slice(0, half) + "..." + t.slice(-half); +} + +/** + * Generate a random lowercase string of fixed length. + */ +export function randomString(length = 8): string { + const letters = "abcdefghijklmnopqrstuvwxyz"; + let result = ""; + for (let i = 0; i < length; i++) { + result += letters[Math.floor(Math.random() * letters.length)]; + } + return result; +} diff --git a/src/kimi_cli/utils/subprocess_env.py b/src/kimi_cli/utils/subprocess_env.py deleted file mode 100644 index 1aaf4eda4..000000000 --- a/src/kimi_cli/utils/subprocess_env.py +++ /dev/null @@ -1,73 +0,0 @@ -"""Utilities for subprocess environment handling. - -This module provides utilities to handle environment variables when spawning -subprocesses from a PyInstaller-frozen application. The main issue is that -PyInstaller's bootloader modifies LD_LIBRARY_PATH to prioritize bundled libraries, -which can cause conflicts when spawning external programs that expect system libraries. - -See: https://pyinstaller.org/en/stable/common-issues-and-pitfalls.html -""" - -from __future__ import annotations - -import os -import sys - -# Environment variables that PyInstaller may modify on Linux -_PYINSTALLER_LD_VARS = [ - "LD_LIBRARY_PATH", - "LD_PRELOAD", -] - - -def get_clean_env(base_env: dict[str, str] | None = None) -> dict[str, str]: - """ - Get a clean environment suitable for spawning subprocesses. - - In a PyInstaller-frozen application on Linux, this function restores - the original library path environment variables, preventing subprocesses - from loading incompatible bundled libraries. - - Args: - base_env: Base environment to start from. If None, uses os.environ. - - Returns: - A dictionary of environment variables safe for subprocess use. - """ - env = dict(base_env if base_env is not None else os.environ) - - # Only process in PyInstaller frozen environment on Linux - if not getattr(sys, "frozen", False) or sys.platform != "linux": - return env - - for var in _PYINSTALLER_LD_VARS: - orig_key = f"{var}_ORIG" - if orig_key in env: - # Restore the original value that was saved by PyInstaller bootloader - env[var] = env[orig_key] - elif var in env: - # Variable was not set before PyInstaller modified it, so remove it - del env[var] - - return env - - -def get_noninteractive_env(base_env: dict[str, str] | None = None) -> dict[str, str]: - """ - Get an environment for subprocesses that must not block on interactive prompts. - - Builds on :func:`get_clean_env` and additionally configures git to fail - fast instead of waiting for user input that will never arrive. - - Args: - base_env: Base environment to start from. If None, uses os.environ. - - Returns: - A dictionary of environment variables safe for non-interactive subprocess use. - """ - env = get_clean_env(base_env) - - # GIT_TERMINAL_PROMPT=0 makes git fail instead of prompting for credentials. - env.setdefault("GIT_TERMINAL_PROMPT", "0") - - return env diff --git a/src/kimi_cli/utils/term.py b/src/kimi_cli/utils/term.py deleted file mode 100644 index c22d12e57..000000000 --- a/src/kimi_cli/utils/term.py +++ /dev/null @@ -1,168 +0,0 @@ -from __future__ import annotations - -import contextlib -import os -import re -import sys -import time - - -def ensure_new_line() -> None: - """Ensure the next prompt starts at column 0 regardless of prior command output.""" - - if not sys.stdout.isatty() or not sys.stdin.isatty(): - return - - needs_break = True - if sys.platform == "win32": - column = _cursor_column_windows() - needs_break = column not in (None, 0) - else: - column = _cursor_column_unix() - needs_break = column not in (None, 1) - - if needs_break: - _write_newline() - - -def ensure_tty_sane() -> None: - """Restore basic tty settings so Ctrl-C works after raw-mode operations.""" - if sys.platform == "win32" or not sys.stdin.isatty(): - return - - try: - import termios - except Exception: - return - - try: - fd = sys.stdin.fileno() - attrs = termios.tcgetattr(fd) - except Exception: - return - - desired = termios.ISIG | termios.IEXTEN | termios.ICANON | termios.ECHO - if (attrs[3] & desired) == desired: - return - - attrs[3] |= desired - with contextlib.suppress(OSError): - termios.tcsetattr(fd, termios.TCSADRAIN, attrs) - - -def _cursor_position_unix() -> tuple[int, int] | None: - """Get cursor position (row, column) on Unix. Both are 1-indexed.""" - assert sys.platform != "win32" - - import select - import termios - import tty - - _CURSOR_QUERY = "\x1b[6n" - _CURSOR_POSITION_RE = re.compile(r"\x1b\[(\d+);(\d+)R") - - fd = sys.stdin.fileno() - oldterm = termios.tcgetattr(fd) - - try: - tty.setcbreak(fd) - sys.stdout.write(_CURSOR_QUERY) - sys.stdout.flush() - - response = "" - deadline = time.monotonic() + 0.2 - while time.monotonic() < deadline: - timeout = max(0.01, deadline - time.monotonic()) - ready, _, _ = select.select([sys.stdin], [], [], timeout) - if not ready: - continue - try: - chunk = os.read(fd, 32) - except OSError: - break - if not chunk: - break - response += chunk.decode(encoding="utf-8", errors="ignore") - match = _CURSOR_POSITION_RE.search(response) - if match: - return int(match.group(1)), int(match.group(2)) - finally: - termios.tcsetattr(fd, termios.TCSADRAIN, oldterm) - - return None - - -def _cursor_column_unix() -> int | None: - pos = _cursor_position_unix() - return pos[1] if pos else None - - -def _cursor_position_windows() -> tuple[int, int] | None: - """Get cursor position (row, column) on Windows. Both are 1-indexed.""" - assert sys.platform == "win32" - - import ctypes - from ctypes import wintypes - - kernel32 = ctypes.windll.kernel32 - _STD_OUTPUT_HANDLE = -11 # Windows API constant for standard output handle - handle = kernel32.GetStdHandle(_STD_OUTPUT_HANDLE) - invalid_handle_value = ctypes.c_void_p(-1).value - if handle in (0, invalid_handle_value): - return None - - class COORD(ctypes.Structure): - _fields_ = [("X", wintypes.SHORT), ("Y", wintypes.SHORT)] - - class SMALL_RECT(ctypes.Structure): - _fields_ = [ - ("Left", wintypes.SHORT), - ("Top", wintypes.SHORT), - ("Right", wintypes.SHORT), - ("Bottom", wintypes.SHORT), - ] - - class CONSOLE_SCREEN_BUFFER_INFO(ctypes.Structure): - _fields_ = [ - ("dwSize", COORD), - ("dwCursorPosition", COORD), - ("wAttributes", wintypes.WORD), - ("srWindow", SMALL_RECT), - ("dwMaximumWindowSize", COORD), - ] - - csbi = CONSOLE_SCREEN_BUFFER_INFO() - if not kernel32.GetConsoleScreenBufferInfo(handle, ctypes.byref(csbi)): - return None - - # Windows returns 0-indexed, convert to 1-indexed for consistency - return int(csbi.dwCursorPosition.Y) + 1, int(csbi.dwCursorPosition.X) + 1 - - -def _cursor_column_windows() -> int | None: - pos = _cursor_position_windows() - return pos[1] if pos else None - - -def _write_newline() -> None: - sys.stdout.write("\n") - sys.stdout.flush() - - -def get_cursor_row() -> int | None: - """Get the current cursor row (1-indexed).""" - if not sys.stdout.isatty() or not sys.stdin.isatty(): - return None - - if sys.platform == "win32": - pos = _cursor_position_windows() - else: - pos = _cursor_position_unix() - - return pos[0] if pos else None - - -if __name__ == "__main__": - print("test", end="", flush=True) - ensure_new_line() - print("next line") diff --git a/src/kimi_cli/utils/typing.py b/src/kimi_cli/utils/typing.py deleted file mode 100644 index 2e5635b5c..000000000 --- a/src/kimi_cli/utils/typing.py +++ /dev/null @@ -1,20 +0,0 @@ -from types import UnionType -from typing import Any, TypeAliasType, Union, get_args, get_origin - - -def flatten_union(tp: Any) -> tuple[Any, ...]: - """ - If `tp` is a `UnionType`, return its flattened arguments as a tuple. - Otherwise, return a tuple with `tp` as the only element. - """ - if isinstance(tp, TypeAliasType): - tp = tp.__value__ - origin = get_origin(tp) - if origin in (UnionType, Union): - args = get_args(tp) - flattened_args: list[Any] = [] - for arg in args: - flattened_args.extend(flatten_union(arg)) - return tuple(flattened_args) - else: - return (tp,) diff --git a/src/kimi_cli/utils/yaml.ts b/src/kimi_cli/utils/yaml.ts new file mode 100644 index 000000000..05adcc97b --- /dev/null +++ b/src/kimi_cli/utils/yaml.ts @@ -0,0 +1,179 @@ +/** + * Minimal YAML parser utility. + * For agent spec YAML files which use simple structures. + * + * For production use, consider adding `yaml` package. + * This is a bootstrap implementation. + */ +export function parse(text: string): unknown { + const lines = text.split("\n"); + return parseObject(lines, 0).value; +} + +interface ParseResult { + value: unknown; + consumed: number; +} + +function parseObject(lines: string[], startIndent: number): ParseResult { + const obj: Record = {}; + let i = 0; + + while (i < lines.length) { + const line = lines[i]!; + const stripped = line.trimStart(); + + if (!stripped || stripped.startsWith("#")) { + i++; + continue; + } + + const indent = line.length - stripped.length; + if (indent < startIndent) break; + + if (stripped.startsWith("- ")) break; + + const colonIdx = stripped.indexOf(":"); + if (colonIdx === -1) { + i++; + continue; + } + + const key = stripped.slice(0, colonIdx).trim(); + const valueStr = stripped.slice(colonIdx + 1).trim(); + + if (valueStr === "" || valueStr === "|" || valueStr === ">") { + i++; + const nextIndent = getNextIndent(lines, i); + if (nextIndent > indent) { + const nextStripped = (lines[i] ?? "").trimStart(); + if (nextStripped.startsWith("- ")) { + const arr = parseArray(lines.slice(i), nextIndent); + obj[key] = arr.value; + i += arr.consumed; + } else if (valueStr === "|" || valueStr === ">") { + const block = parseBlockScalar(lines.slice(i), nextIndent, valueStr === "|"); + obj[key] = block.value; + i += block.consumed; + } else { + const nested = parseObject(lines.slice(i), nextIndent); + obj[key] = nested.value; + i += nested.consumed; + } + } else { + obj[key] = null; + } + } else { + obj[key] = parseScalar(valueStr); + i++; + } + } + + return { value: obj, consumed: i }; +} + +function parseArray(lines: string[], startIndent: number): ParseResult { + const arr: unknown[] = []; + let i = 0; + + while (i < lines.length) { + const line = lines[i]!; + const stripped = line.trimStart(); + if (!stripped || stripped.startsWith("#")) { + i++; + continue; + } + + const indent = line.length - stripped.length; + if (indent < startIndent) break; + + if (stripped.startsWith("- ")) { + const itemStr = stripped.slice(2).trim(); + if (itemStr.includes(":")) { + const colonIdx = itemStr.indexOf(":"); + const key = itemStr.slice(0, colonIdx).trim(); + const val = itemStr.slice(colonIdx + 1).trim(); + + i++; + const nextIndent = getNextIndent(lines, i); + if (nextIndent > indent + 2) { + const nested = parseObject(lines.slice(i), nextIndent); + const item: Record = { [key]: val ? parseScalar(val) : nested.value }; + if (typeof nested.value === "object" && nested.value !== null && !val) { + Object.assign(item, { [key]: nested.value }); + } + arr.push(item); + i += nested.consumed; + } else { + arr.push({ [key]: parseScalar(val) }); + } + } else { + arr.push(parseScalar(itemStr)); + i++; + } + } else { + break; + } + } + + return { value: arr, consumed: i }; +} + +function parseBlockScalar(lines: string[], startIndent: number, literal: boolean): ParseResult { + const parts: string[] = []; + let i = 0; + + while (i < lines.length) { + const line = lines[i]!; + const stripped = line.trimStart(); + const indent = line.length - stripped.length; + + if (!stripped) { + parts.push(""); + i++; + continue; + } + + if (indent < startIndent) break; + parts.push(line.slice(startIndent)); + i++; + } + + const sep = literal ? "\n" : " "; + return { value: parts.join(sep).trimEnd(), consumed: i }; +} + +function getNextIndent(lines: string[], from: number): number { + for (let i = from; i < lines.length; i++) { + const line = lines[i]!; + const stripped = line.trimStart(); + if (stripped && !stripped.startsWith("#")) { + return line.length - stripped.length; + } + } + return 0; +} + +function parseScalar(value: string): unknown { + if (!value) return null; + + if ((value.startsWith('"') && value.endsWith('"')) || (value.startsWith("'") && value.endsWith("'"))) { + return value.slice(1, -1); + } + + if (value === "true" || value === "True" || value === "yes") return true; + if (value === "false" || value === "False" || value === "no") return false; + if (value === "null" || value === "~" || value === "Null") return null; + + if (/^-?\d+$/.test(value)) return Number.parseInt(value, 10); + if (/^-?\d+\.\d+$/.test(value)) return Number.parseFloat(value); + + if (value.startsWith("[") && value.endsWith("]")) { + return value + .slice(1, -1) + .split(",") + .map((s) => parseScalar(s.trim())); + } + + return value; +} diff --git a/src/kimi_cli/vis/__init__.py b/src/kimi_cli/vis/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/kimi_cli/vis/api/__init__.py b/src/kimi_cli/vis/api/__init__.py deleted file mode 100644 index 2b0236cfa..000000000 --- a/src/kimi_cli/vis/api/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from kimi_cli.vis.api.sessions import router as sessions_router -from kimi_cli.vis.api.statistics import router as statistics_router -from kimi_cli.vis.api.system import router as system_router - -__all__ = ["sessions_router", "statistics_router", "system_router"] diff --git a/src/kimi_cli/vis/api/sessions.py b/src/kimi_cli/vis/api/sessions.py deleted file mode 100644 index 636bfeb60..000000000 --- a/src/kimi_cli/vis/api/sessions.py +++ /dev/null @@ -1,687 +0,0 @@ -"""Vis API for reading session tracing data.""" - -from __future__ import annotations - -import contextlib -import io -import json -import logging -import re -import shutil -import zipfile -from pathlib import Path -from typing import Any -from uuid import uuid4 - -import aiofiles -from fastapi import APIRouter, HTTPException, UploadFile -from fastapi.responses import StreamingResponse - -from kimi_cli.metadata import load_metadata -from kimi_cli.share import get_share_dir -from kimi_cli.wire.file import WireFileMetadata, parse_wire_file_line - -router = APIRouter(prefix="/api/vis", tags=["vis"]) -logger = logging.getLogger(__name__) - - -def collect_events( - msg_type: str, - payload: dict[str, Any], - out: list[tuple[str, dict[str, Any]]], -) -> None: - """Recursively unwrap SubagentEvent and collect (type, payload) pairs.""" - if msg_type == "SubagentEvent": - inner: dict[str, Any] | None = payload.get("event") - if isinstance(inner, dict): - inner_type: str = inner.get("type", "") - inner_payload: dict[str, Any] = inner.get("payload", {}) - if inner_type: - collect_events(inner_type, inner_payload, out) - else: - out.append((msg_type, payload)) - - -_SESSION_ID_RE = re.compile(r"^[a-zA-Z0-9_-]+$") -_IMPORTED_HASH = "__imported__" - - -def _get_imported_root() -> Path: - """Return the root directory for imported sessions.""" - return get_share_dir() / "imported_sessions" - - -def _find_session_dir(work_dir_hash: str, session_id: str) -> Path | None: - """Find session directory by work_dir_hash and session_id.""" - if not _SESSION_ID_RE.match(session_id): - return None - if work_dir_hash == _IMPORTED_HASH: - session_dir = _get_imported_root() / session_id - if session_dir.is_dir(): - return session_dir - return None - if not _SESSION_ID_RE.match(work_dir_hash): - return None - sessions_root = get_share_dir() / "sessions" - session_dir = sessions_root / work_dir_hash / session_id - if session_dir.is_dir(): - return session_dir - return None - - -def get_work_dir_for_hash(hash_dir_name: str) -> str | None: - """Look up the work directory path from metadata for a given hash directory name.""" - try: - metadata = load_metadata() - except Exception: - return None - from hashlib import md5 - - from kaos.local import local_kaos - - for wd in metadata.work_dirs: - path_md5 = md5(wd.path.encode(encoding="utf-8")).hexdigest() - dir_basename = path_md5 if wd.kaos == local_kaos.name else f"{wd.kaos}_{path_md5}" - if dir_basename == hash_dir_name: - return wd.path - return None - - -def _extract_title_from_wire(wire_path: Path, max_bytes: int = 8192) -> tuple[str, int]: - """Extract title and turn count from the beginning of wire.jsonl. - - Only reads up to *max_bytes* to avoid blocking on large files. - Returns (title, turn_count). - """ - title = "" - turn_count = 0 - try: - with wire_path.open(encoding="utf-8") as f: - bytes_read = 0 - for line in f: - bytes_read += len(line.encode("utf-8")) - line = line.strip() - if not line: - continue - try: - parsed = parse_wire_file_line(line) - except Exception: - continue - if isinstance(parsed, WireFileMetadata): - continue - if parsed.message.type == "TurnBegin": - turn_count += 1 - if turn_count == 1: - user_input = parsed.message.payload.get("user_input", "") - if isinstance(user_input, str): - title = user_input[:100] - elif isinstance(user_input, list) and user_input: - first = user_input[0] - if isinstance(first, dict): - title = str(first.get("text", ""))[:100] - # Stop once we exceed the byte budget — title is extracted from - # the first TurnBegin so this is a hard upper bound on I/O. - if bytes_read > max_bytes: - break - except Exception: - pass - return title, turn_count - - -def _scan_session_dir( - session_dir: Path, - work_dir_hash: str, - work_dir: str | None, - *, - imported: bool = False, -) -> dict[str, Any] | None: - """Extract session info from a session directory.""" - if not session_dir.is_dir(): - return None - - wire_path = session_dir / "wire.jsonl" - context_path = session_dir / "context.jsonl" - state_path = session_dir / "state.json" - - wire_exists = wire_path.exists() - context_exists = context_path.exists() - state_exists = state_path.exists() - - # Get last updated time from most recent file - mtimes: list[float] = [] - wire_size = context_size = state_size = 0 - if wire_exists: - st = wire_path.stat() - mtimes.append(st.st_mtime) - wire_size = st.st_size - if context_exists: - st = context_path.stat() - mtimes.append(st.st_mtime) - context_size = st.st_size - if state_exists: - st = state_path.stat() - mtimes.append(st.st_mtime) - state_size = st.st_size - - # Read title from SessionState (source of truth), fall back to wire-derived title - from kimi_cli.session_state import load_session_state - - session_state = load_session_state(session_dir) - - title = "" - turn_count = 0 - if wire_exists: - title, turn_count = _extract_title_from_wire(wire_path) - if session_state.custom_title: - title = session_state.custom_title - - # Count sub-agents - subagent_count = 0 - subagents_dir = session_dir / "subagents" - if subagents_dir.is_dir(): - subagent_count = sum(1 for p in subagents_dir.iterdir() if p.is_dir()) - - return { - "session_id": session_dir.name, - "session_dir": str(session_dir), - "work_dir": work_dir, - "work_dir_hash": work_dir_hash, - "title": title, - "last_updated": max(mtimes) if mtimes else 0, - "has_wire": wire_exists, - "has_context": context_exists, - "has_state": state_exists, - "metadata": session_state.model_dump(mode="json"), - "wire_size": wire_size, - "context_size": context_size, - "state_size": state_size, - "total_size": wire_size + context_size + state_size, - "turns": turn_count, - "imported": imported, - "subagent_count": subagent_count, - } - - -def _list_sessions_sync() -> list[dict[str, Any]]: - """Synchronous session scanning — called from a thread pool.""" - results: list[dict[str, Any]] = [] - - sessions_root = get_share_dir() / "sessions" - if sessions_root.exists(): - for work_dir_hash_dir in sessions_root.iterdir(): - if not work_dir_hash_dir.is_dir(): - continue - work_dir = get_work_dir_for_hash(work_dir_hash_dir.name) - for session_dir in work_dir_hash_dir.iterdir(): - info = _scan_session_dir(session_dir, work_dir_hash_dir.name, work_dir) - if info: - results.append(info) - - imported_root = _get_imported_root() - if imported_root.exists(): - for session_dir in imported_root.iterdir(): - info = _scan_session_dir( - session_dir, - _IMPORTED_HASH, - None, - imported=True, - ) - if info: - results.append(info) - - results.sort(key=lambda s: s["last_updated"], reverse=True) - return results - - -@router.get("/sessions") -async def list_sessions() -> list[dict[str, Any]]: - """List all available sessions across all work directories.""" - import asyncio - - loop = asyncio.get_running_loop() - return await loop.run_in_executor(None, _list_sessions_sync) - - -@router.get("/sessions/{work_dir_hash}/{session_id}/wire") -async def get_wire_events(work_dir_hash: str, session_id: str) -> dict[str, Any]: - """Read and parse wire.jsonl for a session.""" - session_dir = _find_session_dir(work_dir_hash, session_id) - if session_dir is None: - raise HTTPException(status_code=404, detail="Session not found") - - wire_path = session_dir / "wire.jsonl" - if not wire_path.exists(): - return {"total": 0, "events": []} - - events: list[dict[str, Any]] = [] - index = 0 - async with aiofiles.open(wire_path, encoding="utf-8") as f: - async for line in f: - line = line.strip() - if not line: - continue - try: - parsed = parse_wire_file_line(line) - except Exception: - logger.debug("Skipped malformed line in %s", wire_path) - continue - if isinstance(parsed, WireFileMetadata): - continue - events.append( - { - "index": index, - "timestamp": parsed.timestamp, - "type": parsed.message.type, - "payload": parsed.message.payload, - } - ) - index += 1 - - return {"total": len(events), "events": events} - - -@router.get("/sessions/{work_dir_hash}/{session_id}/context") -async def get_context_messages(work_dir_hash: str, session_id: str) -> dict[str, Any]: - """Read and parse context.jsonl for a session.""" - session_dir = _find_session_dir(work_dir_hash, session_id) - if session_dir is None: - raise HTTPException(status_code=404, detail="Session not found") - - context_path = session_dir / "context.jsonl" - if not context_path.exists(): - return {"total": 0, "messages": []} - - messages: list[dict[str, Any]] = [] - index = 0 - async with aiofiles.open(context_path, encoding="utf-8") as f: - async for line in f: - line = line.strip() - if not line: - continue - try: - msg = json.loads(line) - except json.JSONDecodeError: - logger.debug("Skipped malformed line in %s", context_path) - continue - msg["index"] = index - messages.append(msg) - index += 1 - - return {"total": len(messages), "messages": messages} - - -@router.get("/sessions/{work_dir_hash}/{session_id}/state") -async def get_session_state(work_dir_hash: str, session_id: str) -> dict[str, Any]: - """Read state.json for a session.""" - session_dir = _find_session_dir(work_dir_hash, session_id) - if session_dir is None: - raise HTTPException(status_code=404, detail="Session not found") - - state_path = session_dir / "state.json" - if not state_path.exists(): - return {} - - async with aiofiles.open(state_path, encoding="utf-8") as f: - content = await f.read() - try: - return json.loads(content) - except json.JSONDecodeError as err: - raise HTTPException(status_code=500, detail="Invalid state.json") from err - - -@router.get("/sessions/{work_dir_hash}/{session_id}/summary") -async def get_session_summary(work_dir_hash: str, session_id: str) -> dict[str, Any]: - """Compute summary statistics for a session by scanning wire.jsonl.""" - session_dir = _find_session_dir(work_dir_hash, session_id) - if session_dir is None: - raise HTTPException(status_code=404, detail="Session not found") - - wire_path = session_dir / "wire.jsonl" - context_path = session_dir / "context.jsonl" - state_path = session_dir / "state.json" - - wire_size = wire_path.stat().st_size if wire_path.exists() else 0 - context_size = context_path.stat().st_size if context_path.exists() else 0 - state_size = state_path.stat().st_size if state_path.exists() else 0 - - zeros: dict[str, Any] = { - "turns": 0, - "steps": 0, - "tool_calls": 0, - "errors": 0, - "compactions": 0, - "duration_sec": 0, - "input_tokens": 0, - "output_tokens": 0, - "wire_size": wire_size, - "context_size": context_size, - "state_size": state_size, - "total_size": wire_size + context_size + state_size, - } - - if not wire_path.exists(): - return zeros - - turns = steps = tool_calls = errors = compactions = 0 - input_tokens = output_tokens = 0 - first_ts = 0.0 - last_ts = 0.0 - - async with aiofiles.open(wire_path, encoding="utf-8") as f: - async for line in f: - line = line.strip() - if not line: - continue - try: - parsed = parse_wire_file_line(line) - except Exception: - logger.debug("Skipped malformed line in %s", wire_path) - continue - if isinstance(parsed, WireFileMetadata): - continue - - ts = parsed.timestamp - msg_type = parsed.message.type - payload = parsed.message.payload - - if first_ts == 0: - first_ts = ts - last_ts = ts - - # Collect (type, payload) pairs, unwrapping SubagentEvent recursively - events_to_process: list[tuple[str, dict[str, Any]]] = [] - collect_events(msg_type, payload, events_to_process) - - for ev_type, ev_payload in events_to_process: - if ev_type == "TurnBegin": - turns += 1 - elif ev_type == "StepBegin": - steps += 1 - elif ev_type == "ToolCall": - tool_calls += 1 - elif ev_type == "CompactionBegin": - compactions += 1 - elif ev_type == "StepInterrupted": - errors += 1 - elif ev_type == "ToolResult": - rv: dict[str, Any] | None = ev_payload.get("return_value") - if isinstance(rv, dict) and rv.get("is_error"): - errors += 1 - elif ev_type == "ApprovalResponse": - if ev_payload.get("response") == "reject": - errors += 1 - elif ev_type == "StatusUpdate": - tu: dict[str, Any] | None = ev_payload.get("token_usage") - if isinstance(tu, dict): - input_tokens += ( - int(tu.get("input_other", 0)) - + int(tu.get("input_cache_read", 0)) - + int(tu.get("input_cache_creation", 0)) - ) - output_tokens += int(tu.get("output", 0)) - - return { - "turns": turns, - "steps": steps, - "tool_calls": tool_calls, - "errors": errors, - "compactions": compactions, - "duration_sec": last_ts - first_ts if last_ts > first_ts else 0, - "input_tokens": input_tokens, - "output_tokens": output_tokens, - "wire_size": wire_size, - "context_size": context_size, - "state_size": state_size, - "total_size": wire_size + context_size + state_size, - } - - -@router.get("/sessions/{work_dir_hash}/{session_id}/subagents") -def list_subagents(work_dir_hash: str, session_id: str) -> list[dict[str, Any]]: - """List all sub-agents for a session.""" - session_dir = _find_session_dir(work_dir_hash, session_id) - if session_dir is None: - raise HTTPException(status_code=404, detail="Session not found") - - subagents_dir = session_dir / "subagents" - if not subagents_dir.is_dir(): - return [] - - results: list[dict[str, Any]] = [] - for entry in subagents_dir.iterdir(): - if not entry.is_dir(): - continue - if not _SESSION_ID_RE.match(entry.name): - continue - - meta_path = entry / "meta.json" - meta: dict[str, Any] = {} - if meta_path.exists(): - with contextlib.suppress(Exception): - meta = json.loads(meta_path.read_text(encoding="utf-8")) - - wire_path = entry / "wire.jsonl" - context_path = entry / "context.jsonl" - results.append( - { - "agent_id": meta.get("agent_id", entry.name), - "subagent_type": meta.get("subagent_type", "unknown"), - "status": meta.get("status", "unknown"), - "description": meta.get("description", ""), - "created_at": meta.get("created_at", 0), - "updated_at": meta.get("updated_at", 0), - "last_task_id": meta.get("last_task_id"), - "launch_spec": meta.get("launch_spec", {}), - "wire_size": wire_path.stat().st_size if wire_path.exists() else 0, - "context_size": context_path.stat().st_size if context_path.exists() else 0, - } - ) - - results.sort(key=lambda s: s.get("created_at", 0)) - return results - - -@router.get("/sessions/{work_dir_hash}/{session_id}/subagents/{agent_id}/wire") -async def get_subagent_wire_events( - work_dir_hash: str, session_id: str, agent_id: str -) -> dict[str, Any]: - """Read and parse wire.jsonl for a specific sub-agent.""" - if not _SESSION_ID_RE.match(agent_id): - raise HTTPException(status_code=400, detail="Invalid agent ID") - - session_dir = _find_session_dir(work_dir_hash, session_id) - if session_dir is None: - raise HTTPException(status_code=404, detail="Session not found") - - wire_path = session_dir / "subagents" / agent_id / "wire.jsonl" - if not wire_path.exists(): - return {"total": 0, "events": []} - - events: list[dict[str, Any]] = [] - index = 0 - async with aiofiles.open(wire_path, encoding="utf-8") as f: - async for line in f: - line = line.strip() - if not line: - continue - try: - parsed = parse_wire_file_line(line) - except Exception: - logger.debug("Skipped malformed line in %s", wire_path) - continue - if isinstance(parsed, WireFileMetadata): - continue - events.append( - { - "index": index, - "timestamp": parsed.timestamp, - "type": parsed.message.type, - "payload": parsed.message.payload, - } - ) - index += 1 - - return {"total": len(events), "events": events} - - -@router.get("/sessions/{work_dir_hash}/{session_id}/subagents/{agent_id}/context") -async def get_subagent_context( - work_dir_hash: str, session_id: str, agent_id: str -) -> dict[str, Any]: - """Read and parse context.jsonl for a specific sub-agent.""" - if not _SESSION_ID_RE.match(agent_id): - raise HTTPException(status_code=400, detail="Invalid agent ID") - - session_dir = _find_session_dir(work_dir_hash, session_id) - if session_dir is None: - raise HTTPException(status_code=404, detail="Session not found") - - context_path = session_dir / "subagents" / agent_id / "context.jsonl" - if not context_path.exists(): - return {"total": 0, "messages": []} - - messages: list[dict[str, Any]] = [] - index = 0 - async with aiofiles.open(context_path, encoding="utf-8") as f: - async for line in f: - line = line.strip() - if not line: - continue - try: - msg = json.loads(line) - except json.JSONDecodeError: - logger.debug("Skipped malformed line in %s", context_path) - continue - msg["index"] = index - messages.append(msg) - index += 1 - - return {"total": len(messages), "messages": messages} - - -@router.get("/sessions/{work_dir_hash}/{session_id}/subagents/{agent_id}/meta") -async def get_subagent_meta(work_dir_hash: str, session_id: str, agent_id: str) -> dict[str, Any]: - """Read meta.json for a specific sub-agent.""" - if not _SESSION_ID_RE.match(agent_id): - raise HTTPException(status_code=400, detail="Invalid agent ID") - - session_dir = _find_session_dir(work_dir_hash, session_id) - if session_dir is None: - raise HTTPException(status_code=404, detail="Session not found") - - meta_path = session_dir / "subagents" / agent_id / "meta.json" - if not meta_path.exists(): - raise HTTPException(status_code=404, detail="Sub-agent not found") - - async with aiofiles.open(meta_path, encoding="utf-8") as f: - content = await f.read() - try: - return json.loads(content) - except json.JSONDecodeError as err: - raise HTTPException(status_code=500, detail="Invalid meta.json") from err - - -@router.get("/sessions/{work_dir_hash}/{session_id}/download") -def download_session(work_dir_hash: str, session_id: str) -> StreamingResponse: - """Download all files in a session directory as a ZIP archive.""" - session_dir = _find_session_dir(work_dir_hash, session_id) - if session_dir is None: - raise HTTPException(status_code=404, detail="Session not found") - - buf = io.BytesIO() - with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED) as zf: - for file_path in sorted(session_dir.rglob("*")): - if file_path.is_file(): - zf.write(file_path, arcname=str(file_path.relative_to(session_dir))) - buf.seek(0) - - filename = f"session-{session_id}.zip" - return StreamingResponse( - buf, - media_type="application/zip", - headers={"Content-Disposition": f'attachment; filename="{filename}"'}, - ) - - -@router.post("/sessions/import") -async def import_session(file: UploadFile) -> dict[str, Any]: - """Import a session from an uploaded ZIP archive.""" - if not file.filename or not file.filename.endswith(".zip"): - raise HTTPException(status_code=400, detail="Only .zip files are accepted") - - content = await file.read() - if not content: - raise HTTPException(status_code=400, detail="Empty file") - - # Reject uploads larger than 200 MB - _MAX_UPLOAD_BYTES = 200 * 1024 * 1024 - if len(content) > _MAX_UPLOAD_BYTES: - raise HTTPException(status_code=413, detail="File too large (max 200 MB)") - - # Validate ZIP - buf = io.BytesIO(content) - try: - zf = zipfile.ZipFile(buf, "r") - except zipfile.BadZipFile as err: - raise HTTPException(status_code=400, detail="Invalid ZIP file") from err - - with zf: - names = zf.namelist() - # Must contain wire.jsonl or context.jsonl at root or under exactly one directory - _VALID_FILES = ("wire.jsonl", "context.jsonl") - has_valid = any( - n in _VALID_FILES or (n.count("/") == 1 and n.endswith(_VALID_FILES)) for n in names - ) - if not has_valid: - raise HTTPException( - status_code=400, - detail="ZIP must contain wire.jsonl or context.jsonl at the top level " - "(or inside a single directory)", - ) - - session_id = uuid4().hex[:16] - imported_root = _get_imported_root() - session_dir = imported_root / session_id - session_dir.mkdir(parents=True, exist_ok=True) - - # Zip Slip protection: reject entries with path traversal or absolute paths - for info in zf.infolist(): - if info.filename.startswith("/") or ".." in info.filename.split("/"): - shutil.rmtree(session_dir, ignore_errors=True) - raise HTTPException( - status_code=400, - detail="ZIP contains unsafe path entries", - ) - - # Extract - handle both flat ZIPs and ZIPs with a single top-level directory - zf.extractall(session_dir) - - # If all files are under a single subdirectory, flatten them - entries = list(session_dir.iterdir()) - if len(entries) == 1 and entries[0].is_dir(): - nested_dir = entries[0] - for item in nested_dir.iterdir(): - shutil.move(str(item), str(session_dir / item.name)) - nested_dir.rmdir() - - return { - "session_id": session_id, - "work_dir_hash": _IMPORTED_HASH, - } - - -@router.delete("/sessions/{work_dir_hash}/{session_id}") -def delete_session(work_dir_hash: str, session_id: str) -> dict[str, str]: - """Delete an imported session.""" - if work_dir_hash != _IMPORTED_HASH: - raise HTTPException(status_code=403, detail="Only imported sessions can be deleted") - - if not _SESSION_ID_RE.match(session_id): - raise HTTPException(status_code=400, detail="Invalid session ID") - - session_dir = _get_imported_root() / session_id - if not session_dir.is_dir(): - raise HTTPException(status_code=404, detail="Session not found") - - shutil.rmtree(session_dir) - return {"status": "deleted"} diff --git a/src/kimi_cli/vis/api/statistics.py b/src/kimi_cli/vis/api/statistics.py deleted file mode 100644 index e8267a5d4..000000000 --- a/src/kimi_cli/vis/api/statistics.py +++ /dev/null @@ -1,209 +0,0 @@ -"""Vis API for aggregate statistics across all sessions.""" - -from __future__ import annotations - -import time -from collections import defaultdict -from datetime import UTC, datetime, timedelta -from typing import Any - -from fastapi import APIRouter - -from kimi_cli.share import get_share_dir -from kimi_cli.vis.api.sessions import collect_events, get_work_dir_for_hash -from kimi_cli.wire.file import WireFileMetadata, parse_wire_file_line - -router = APIRouter(prefix="/api/vis", tags=["vis"]) - - -# Simple in-memory cache: (result, timestamp) -_cache: dict[str, tuple[dict[str, Any], float]] = {} -_CACHE_TTL = 60 # seconds - - -@router.get("/statistics") -def get_statistics() -> dict[str, Any]: - """Aggregate statistics across all sessions.""" - now = time.time() - cached = _cache.get("statistics") - if cached and (now - cached[1]) < _CACHE_TTL: - return cached[0] - - sessions_root = get_share_dir() / "sessions" - if not sessions_root.exists(): - empty: dict[str, Any] = { - "total_sessions": 0, - "total_turns": 0, - "total_tokens": {"input": 0, "output": 0}, - "total_duration_sec": 0, - "tool_usage": [], - "daily_usage": [], - "per_project": [], - } - _cache["statistics"] = (empty, now) - return empty - - total_sessions = 0 - total_turns = 0 - total_input_tokens = 0 - total_output_tokens = 0 - total_duration_sec = 0.0 - - # tool_name -> { count, error_count } - tool_stats: dict[str, dict[str, int]] = defaultdict(lambda: {"count": 0, "error_count": 0}) - - # date_str -> { sessions, turns } - daily_stats: dict[str, dict[str, int]] = defaultdict(lambda: {"sessions": 0, "turns": 0}) - - # work_dir -> { sessions, turns } - project_stats: dict[str, dict[str, int]] = defaultdict(lambda: {"sessions": 0, "turns": 0}) - - for work_dir_hash_dir in sessions_root.iterdir(): - if not work_dir_hash_dir.is_dir(): - continue - work_dir = get_work_dir_for_hash(work_dir_hash_dir.name) or work_dir_hash_dir.name - - for session_dir in work_dir_hash_dir.iterdir(): - if not session_dir.is_dir(): - continue - - wire_path = session_dir / "wire.jsonl" - if not wire_path.exists(): - continue - - total_sessions += 1 - session_turns = 0 - session_input_tokens = 0 - session_output_tokens = 0 - first_ts = 0.0 - last_ts = 0.0 - session_date: str | None = None - - # Track pending tool calls for error attribution - pending_tools: dict[str, str] = {} # tool_call_id -> tool_name - - try: - with wire_path.open(encoding="utf-8") as f: - for line in f: - line = line.strip() - if not line: - continue - try: - parsed = parse_wire_file_line(line) - except Exception: - continue - if isinstance(parsed, WireFileMetadata): - continue - - ts = parsed.timestamp - msg_type = parsed.message.type - payload = parsed.message.payload - - if first_ts == 0: - first_ts = ts - # Determine date from first timestamp - try: - dt = datetime.fromtimestamp(ts, tz=UTC) - session_date = dt.strftime("%Y-%m-%d") - except Exception: - pass - last_ts = ts - - # Collect (type, payload) pairs, unwrapping SubagentEvent recursively - events_to_process: list[tuple[str, dict[str, Any]]] = [] - collect_events(msg_type, payload, events_to_process) - - for ev_type, ev_payload in events_to_process: - if ev_type == "TurnBegin": - session_turns += 1 - elif ev_type == "ToolCall": - fn: dict[str, Any] | None = ev_payload.get("function") - tool_id: str = ev_payload.get("id", "") - if isinstance(fn, dict): - name: str = fn.get("name", "unknown") - tool_stats[name]["count"] += 1 - if tool_id: - pending_tools[tool_id] = name - elif ev_type == "ToolResult": - tool_call_id: str = ev_payload.get("tool_call_id", "") - rv: dict[str, Any] | None = ev_payload.get("return_value") - if isinstance(rv, dict) and rv.get("is_error"): - tool_name = pending_tools.get(tool_call_id) - if tool_name: - tool_stats[tool_name]["error_count"] += 1 - pending_tools.pop(tool_call_id, None) - elif ev_type == "StatusUpdate": - tu: dict[str, Any] | None = ev_payload.get("token_usage") - if isinstance(tu, dict): - session_input_tokens += ( - int(tu.get("input_other", 0)) - + int(tu.get("input_cache_read", 0)) - + int(tu.get("input_cache_creation", 0)) - ) - session_output_tokens += int(tu.get("output", 0)) - except Exception: - continue - - total_turns += session_turns - total_input_tokens += session_input_tokens - total_output_tokens += session_output_tokens - - duration = last_ts - first_ts if last_ts > first_ts else 0 - total_duration_sec += duration - - # Aggregate daily - if session_date: - daily_stats[session_date]["sessions"] += 1 - daily_stats[session_date]["turns"] += session_turns - - # Aggregate per project - project_stats[work_dir]["sessions"] += 1 - project_stats[work_dir]["turns"] += session_turns - - # Build tool_usage: top 20 by count - tool_usage = sorted( - [ - {"name": name, "count": stats["count"], "error_count": stats["error_count"]} - for name, stats in tool_stats.items() - ], - key=lambda x: x["count"], - reverse=True, - )[:20] - - # Build daily_usage: last 30 days - today = datetime.now(tz=UTC) - daily_usage: list[dict[str, Any]] = [] - for i in range(29, -1, -1): - d = today - timedelta(days=i) - date_str = d.strftime("%Y-%m-%d") - entry = daily_stats.get(date_str, {"sessions": 0, "turns": 0}) - daily_usage.append( - { - "date": date_str, - "sessions": entry["sessions"], - "turns": entry["turns"], - } - ) - - # Build per_project: top 10 by turns - per_project = sorted( - [ - {"work_dir": wd, "sessions": stats["sessions"], "turns": stats["turns"]} - for wd, stats in project_stats.items() - ], - key=lambda x: x["turns"], - reverse=True, - )[:10] - - result: dict[str, Any] = { - "total_sessions": total_sessions, - "total_turns": total_turns, - "total_tokens": {"input": total_input_tokens, "output": total_output_tokens}, - "total_duration_sec": total_duration_sec, - "tool_usage": tool_usage, - "daily_usage": daily_usage, - "per_project": per_project, - } - - _cache["statistics"] = (result, now) - return result diff --git a/src/kimi_cli/vis/api/system.py b/src/kimi_cli/vis/api/system.py deleted file mode 100644 index b2d64742c..000000000 --- a/src/kimi_cli/vis/api/system.py +++ /dev/null @@ -1,19 +0,0 @@ -"""Vis API for server capabilities and metadata.""" - -from __future__ import annotations - -import sys -from typing import Any - -from fastapi import APIRouter, Request - -router = APIRouter(prefix="/api/vis", tags=["vis"]) - - -@router.get("/capabilities") -def get_capabilities(request: Request) -> dict[str, Any]: - """Return server capabilities that affect frontend feature visibility.""" - restrict_open_in: bool = getattr(request.app.state, "restrict_open_in", False) - return { - "open_in_supported": sys.platform in {"darwin", "win32"} and not restrict_open_in, - } diff --git a/src/kimi_cli/vis/app.py b/src/kimi_cli/vis/app.py deleted file mode 100644 index 918765e7b..000000000 --- a/src/kimi_cli/vis/app.py +++ /dev/null @@ -1,175 +0,0 @@ -"""Kimi Agent Tracing Visualizer application.""" - -from __future__ import annotations - -from pathlib import Path -from typing import Any, cast - -from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware -from fastapi.middleware.gzip import GZipMiddleware -from fastapi.staticfiles import StaticFiles - -from kimi_cli.utils.server import ( - find_available_port, - format_url, - get_network_addresses, - is_local_host, - print_banner, -) -from kimi_cli.vis.api import sessions_router, statistics_router, system_router -from kimi_cli.web.api.open_in import router as open_in_router - -STATIC_DIR = Path(__file__).parent / "static" -GZIP_MINIMUM_SIZE = 1024 -GZIP_COMPRESSION_LEVEL = 6 -DEFAULT_PORT = 5495 -_ENV_RESTRICT_OPEN_IN = "KIMI_VIS_RESTRICT_OPEN_IN" - - -def create_app() -> FastAPI: - """Create the FastAPI application for the tracing visualizer.""" - import os - - restrict_open_in = os.environ.get(_ENV_RESTRICT_OPEN_IN, "").strip().lower() in { - "1", - "true", - } - - application = FastAPI( - title="Kimi Agent Tracing Visualizer", - docs_url=None, - separate_input_output_schemas=False, - ) - - application.add_middleware( - cast(Any, GZipMiddleware), - minimum_size=GZIP_MINIMUM_SIZE, - compresslevel=GZIP_COMPRESSION_LEVEL, - ) - - application.add_middleware( - cast(Any, CORSMiddleware), - allow_origins=["*"], # Local-only tool; port is dynamic so wildcard is acceptable - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) - - application.state.restrict_open_in = restrict_open_in - - application.include_router(sessions_router) - application.include_router(statistics_router) - application.include_router(system_router) - if not restrict_open_in: - application.include_router(open_in_router) - - @application.get("/healthz") - async def health_probe() -> dict[str, Any]: # pyright: ignore[reportUnusedFunction] - return {"status": "ok"} - - if STATIC_DIR.exists(): - application.mount("/", StaticFiles(directory=STATIC_DIR, html=True), name="static") - - return application - - -def run_vis_server( - host: str = "127.0.0.1", - port: int = DEFAULT_PORT, - reload: bool = False, - open_browser: bool = True, -) -> None: - """Run the visualizer web server.""" - import os - import threading - import webbrowser - - import uvicorn - - actual_port = find_available_port(host, port) - if actual_port != port: - print(f"\nPort {port} is in use, using port {actual_port} instead") - - public_mode = not is_local_host(host) - - # Disable open-in API when exposed to the network (security) - os.environ[_ENV_RESTRICT_OPEN_IN] = "1" if public_mode else "0" - - # Build display hosts (same logic as kimi web) - display_hosts: list[tuple[str, str]] = [] - if host == "0.0.0.0": - display_hosts.append(("Local", "localhost")) - for addr in get_network_addresses(): - display_hosts.append(("Network", addr)) - else: - label = "Local" if is_local_host(host) else "Network" - display_hosts.append((label, host)) - - # Browser should open localhost - browser_host = "localhost" if host == "0.0.0.0" else host - browser_url = format_url(browser_host, actual_port) - - banner_lines = [ - "
██╗ ██╗██╗███╗ ███╗██╗ ██╗ ██╗██╗███████╗", - "
██║ ██╔╝██║████╗ ████║██║ ██║ ██║██║██╔════╝", - "
█████╔╝ ██║██╔████╔██║██║ ██║ ██║██║███████╗", - "
██╔═██╗ ██║██║╚██╔╝██║██║ ╚██╗ ██╔╝██║╚════██║", - "
██║ ██╗██║██║ ╚═╝ ██║██║ ╚████╔╝ ██║███████║", - "
╚═╝ ╚═╝╚═╝╚═╝ ╚═╝╚═╝ ╚═══╝ ╚═╝╚══════╝", - "", - "
AGENT TRACING VISUALIZER (Technical Preview)", - "", - "
", - "", - ] - - for label, host_addr in display_hosts: - banner_lines.append(f" ➜ {label:8} {format_url(host_addr, actual_port)}") - - banner_lines.append("") - banner_lines.append("
") - banner_lines.append("") - - if not public_mode: - banner_lines.extend( - [ - " Tips:", - " • Use -n / --network to share on LAN", - "", - ] - ) - else: - banner_lines.extend( - [ - " This feature is in Technical Preview and may be unstable.", - " Please report issues to the kimi-cli team.", - "", - ] - ) - - print_banner(banner_lines) - - if open_browser: - - def open_browser_after_delay() -> None: - import time - - time.sleep(1.5) - webbrowser.open(browser_url) - - thread = threading.Thread(target=open_browser_after_delay, daemon=True) - thread.start() - - uvicorn.run( - "kimi_cli.vis.app:create_app", - factory=True, - host=host, - port=actual_port, - reload=reload, - log_level="info", - timeout_graceful_shutdown=3, - ) - - -__all__ = ["create_app", "run_vis_server"] diff --git a/src/kimi_cli/web/__init__.py b/src/kimi_cli/web/__init__.py deleted file mode 100644 index 96103db09..000000000 --- a/src/kimi_cli/web/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Kimi Code CLI Web Interface.""" - -from kimi_cli.web.app import create_app, run_web_server - -__all__ = ["create_app", "run_web_server"] diff --git a/src/kimi_cli/web/api/__init__.py b/src/kimi_cli/web/api/__init__.py deleted file mode 100644 index 44abab2e2..000000000 --- a/src/kimi_cli/web/api/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -"""API routes.""" - -from kimi_cli.web.api import config, open_in, sessions - -config_router = config.router -sessions_router = sessions.router -work_dirs_router = sessions.work_dirs_router -open_in_router = open_in.router - -__all__ = [ - "config_router", - "open_in_router", - "sessions_router", - "work_dirs_router", -] diff --git a/src/kimi_cli/web/api/config.py b/src/kimi_cli/web/api/config.py deleted file mode 100644 index 405d6b3ef..000000000 --- a/src/kimi_cli/web/api/config.py +++ /dev/null @@ -1,208 +0,0 @@ -"""Config API routes.""" - -from __future__ import annotations - -from fastapi import APIRouter, Depends, HTTPException, Request, status -from pydantic import BaseModel, Field - -from kimi_cli import logger -from kimi_cli.config import LLMModel, get_config_file, load_config, save_config -from kimi_cli.llm import ProviderType, derive_model_capabilities -from kimi_cli.web.runner.process import KimiCLIRunner - -router = APIRouter(prefix="/api/config", tags=["config"]) - - -class ConfigModel(LLMModel): - """Model configuration for frontend.""" - - name: str = Field(description="Model key in kimi-cli config (Config.models)") - provider_type: ProviderType = Field(description="Provider type (LLMProvider.type)") - - -class GlobalConfig(BaseModel): - """Global configuration snapshot for frontend.""" - - default_model: str = Field(description="Current default model key") - default_thinking: bool = Field(description="Current default thinking mode") - models: list[ConfigModel] = Field(description="All configured models") - - -class UpdateGlobalConfigRequest(BaseModel): - """Request to update global config.""" - - default_model: str | None = Field(default=None, description="New default model key") - default_thinking: bool | None = Field(default=None, description="New default thinking mode") - restart_running_sessions: bool | None = Field( - default=None, description="Whether to restart running sessions" - ) - force_restart_busy_sessions: bool | None = Field( - default=None, description="Whether to force restart busy sessions" - ) - - -class UpdateGlobalConfigResponse(BaseModel): - """Response after updating global config.""" - - config: GlobalConfig = Field(description="Updated config snapshot") - restarted_session_ids: list[str] | None = Field( - default=None, description="IDs of restarted sessions" - ) - skipped_busy_session_ids: list[str] | None = Field( - default=None, description="IDs of busy sessions that were skipped" - ) - - -class ConfigToml(BaseModel): - """Raw config.toml content.""" - - content: str = Field(description="Raw TOML content") - path: str = Field(description="Path to config file") - - -class UpdateConfigTomlRequest(BaseModel): - """Request to update config.toml.""" - - content: str = Field(description="New TOML content") - - -class UpdateConfigTomlResponse(BaseModel): - """Response after updating config.toml.""" - - success: bool = Field(description="Whether the update was successful") - error: str | None = Field(default=None, description="Error message if failed") - - -def _build_global_config() -> GlobalConfig: - """Build GlobalConfig from kimi-cli config.""" - config = load_config() - - models: list[ConfigModel] = [] - for model_name, model in config.models.items(): - provider = config.providers.get(model.provider) - if provider is None: - continue - - # Derive capabilities - derived_caps = derive_model_capabilities(model) - capabilities = derived_caps or None - - models.append( - ConfigModel( - name=model_name, - model=model.model, - provider=model.provider, - provider_type=provider.type, - max_context_size=model.max_context_size, - capabilities=capabilities, - ) - ) - - return GlobalConfig( - default_model=config.default_model, - default_thinking=config.default_thinking, - models=models, - ) - - -def _get_runner(req: Request) -> KimiCLIRunner: - """Get KimiCLIRunner from FastAPI app state.""" - return req.app.state.runner - - -def _ensure_sensitive_apis_allowed(request: Request) -> None: - """Block sensitive config writes when restricted.""" - if getattr(request.app.state, "restrict_sensitive_apis", False): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Sensitive config APIs are disabled in this mode.", - ) - - -@router.get("/", summary="Get global (kimi-cli) config snapshot") -async def get_global_config() -> GlobalConfig: - """Get global (kimi-cli) config snapshot.""" - return _build_global_config() - - -@router.patch("/", summary="Update global (kimi-cli) default model/thinking") -async def update_global_config( - request: UpdateGlobalConfigRequest, - http_request: Request, - runner: KimiCLIRunner = Depends(_get_runner), -) -> UpdateGlobalConfigResponse: - """Update global (kimi-cli) default model/thinking.""" - _ensure_sensitive_apis_allowed(http_request) - config = load_config() - - # Validate and update default_model - if request.default_model is not None: - if request.default_model not in config.models: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Model '{request.default_model}' not found in config", - ) - config.default_model = request.default_model - - # Update default_thinking - if request.default_thinking is not None: - config.default_thinking = request.default_thinking - - # Save config - save_config(config) - - # Restart running workers to apply config changes - restarted: list[str] = [] - skipped_busy: list[str] = [] - - restart_running = request.restart_running_sessions - if restart_running is None: - restart_running = True # Default to restarting sessions - - if restart_running: - summary = await runner.restart_running_workers( - reason="config_update", - force=request.force_restart_busy_sessions or False, - ) - restarted = [str(sid) for sid in summary.restarted_session_ids] - skipped_busy = [str(sid) for sid in summary.skipped_busy_session_ids] - - return UpdateGlobalConfigResponse( - config=_build_global_config(), - restarted_session_ids=restarted if restarted else None, - skipped_busy_session_ids=skipped_busy if skipped_busy else None, - ) - - -@router.get("/toml", summary="Get kimi-cli config.toml") -async def get_config_toml(http_request: Request) -> ConfigToml: - """Get kimi-cli config.toml.""" - _ensure_sensitive_apis_allowed(http_request) - config_file = get_config_file() - if not config_file.exists(): - return ConfigToml(content="", path=str(config_file)) - return ConfigToml(content=config_file.read_text(encoding="utf-8"), path=str(config_file)) - - -@router.put("/toml", summary="Update kimi-cli config.toml") -async def update_config_toml( - request: UpdateConfigTomlRequest, - http_request: Request, -) -> UpdateConfigTomlResponse: - """Update kimi-cli config.toml.""" - from kimi_cli.config import load_config_from_string - - _ensure_sensitive_apis_allowed(http_request) - try: - # Validate the config first - load_config_from_string(request.content) - - # Write to file - config_file = get_config_file() - config_file.parent.mkdir(parents=True, exist_ok=True) - config_file.write_text(request.content, encoding="utf-8") - - return UpdateConfigTomlResponse(success=True) - except Exception as e: - logger.warning(f"Failed to update config.toml: {e}") - return UpdateConfigTomlResponse(success=False, error=str(e)) diff --git a/src/kimi_cli/web/api/open_in.py b/src/kimi_cli/web/api/open_in.py deleted file mode 100644 index 79cdf3987..000000000 --- a/src/kimi_cli/web/api/open_in.py +++ /dev/null @@ -1,197 +0,0 @@ -"""Open local apps for a path on the host machine.""" - -from __future__ import annotations - -import asyncio -import subprocess -import sys -from pathlib import Path -from typing import Literal - -from fastapi import APIRouter, HTTPException, status -from pydantic import BaseModel - -from kimi_cli import logger - -router = APIRouter(prefix="/api/open-in", tags=["open-in"]) - - -class OpenInRequest(BaseModel): - """Open path in a local app.""" - - app: Literal["finder", "cursor", "vscode", "iterm", "terminal", "antigravity"] - path: str - - -class OpenInResponse(BaseModel): - """Open path response.""" - - ok: bool - detail: str | None = None - - -def _resolve_path(path: str) -> Path: - """Resolve and validate a path (file or directory).""" - resolved = Path(path).expanduser() - try: - resolved = resolved.resolve() - except FileNotFoundError: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Path does not exist: {path}", - ) from None - - if not resolved.exists(): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Path does not exist: {path}", - ) - return resolved - - -def _run_command(args: list[str]) -> None: - subprocess.run( - args, - check=True, - capture_output=True, - text=True, - ) - - -def _spawn_process(args: list[str]) -> None: - subprocess.Popen(args, close_fds=True) - - -def _open_app(app_name: str, path: Path, fallback: str | None = None) -> None: - try: - _run_command(["open", "-a", app_name, str(path)]) - return - except subprocess.CalledProcessError as exc: - if fallback is None: - raise - logger.warning("Open with {} failed: {}", app_name, exc) - _run_command(["open", "-a", fallback, str(path)]) - - -def _open_terminal(path: Path) -> None: - script = f'tell application "Terminal" to do script "cd " & quoted form of "{path}"' - _run_command(["osascript", "-e", script]) - - -def _open_iterm(path: Path) -> None: - script = "\n".join( - [ - 'tell application "iTerm"', - " create window with default profile", - " tell current session of current window", - f' write text "cd " & quoted form of "{path}"', - " end tell", - "end tell", - ] - ) - try: - _run_command(["osascript", "-e", script]) - except subprocess.CalledProcessError: - script = script.replace('"iTerm"', '"iTerm2"') - _run_command(["osascript", "-e", script]) - - -def _open_windows_app(command: str, path: Path) -> None: - _run_command(["cmd", "/c", "start", "", command, str(path)]) - - -def _open_windows_explorer(path: Path, *, is_file: bool) -> None: - if is_file: - _spawn_process(["explorer", f"/select,{path}"]) - else: - _spawn_process(["explorer", str(path)]) - - -def _open_windows_terminal(path: Path) -> None: - try: - _run_command(["cmd", "/c", "start", "", "wt.exe", "-d", str(path)]) - except subprocess.CalledProcessError as exc: - logger.warning("Open with Windows Terminal failed: {}", exc) - _run_command(["cmd", "/c", "start", "", "cmd.exe", "/K", f'cd /d "{path}"']) - - -def _open_in_macos(app: OpenInRequest, path: Path, *, is_file: bool) -> None: - match app.app: - case "finder": - if is_file: - # Reveal file in Finder - _run_command(["open", "-R", str(path)]) - else: - _run_command(["open", str(path)]) - case "cursor": - _open_app("Cursor", path) - case "vscode": - _open_app("Visual Studio Code", path, fallback="Code") - case "antigravity": - _open_app("Antigravity", path) - case "iterm": - # Terminal apps need directory - directory = path.parent if is_file else path - _open_iterm(directory) - case "terminal": - directory = path.parent if is_file else path - _open_terminal(directory) - case _: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Unsupported app: {app.app}", - ) - - -def _open_in_windows(app: OpenInRequest, path: Path, *, is_file: bool) -> None: - match app.app: - case "finder": - _open_windows_explorer(path, is_file=is_file) - case "cursor": - _open_windows_app("cursor", path) - case "vscode": - _open_windows_app("code", path) - case "terminal": - directory = path.parent if is_file else path - _open_windows_terminal(directory) - case "iterm" | "antigravity": - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"{app.app} is not supported on Windows.", - ) - case _: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Unsupported app: {app.app}", - ) - - -def _open_in_sync(request: OpenInRequest, path: Path, *, is_file: bool) -> None: - if sys.platform == "darwin": - _open_in_macos(request, path, is_file=is_file) - else: - _open_in_windows(request, path, is_file=is_file) - - -@router.post("", summary="Open a path in a local application") -async def open_in(request: OpenInRequest) -> OpenInResponse: - if sys.platform not in {"darwin", "win32"}: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Open-in is only supported on macOS and Windows.", - ) - - path = _resolve_path(request.path) - is_file = path.is_file() - - try: - await asyncio.to_thread(_open_in_sync, request, path, is_file=is_file) - except subprocess.CalledProcessError as exc: - logger.warning("Open-in failed ({}): {}", request.app, exc) - detail = exc.stderr.strip() if exc.stderr else "Failed to open application." - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=detail, - ) from exc - - return OpenInResponse(ok=True) diff --git a/src/kimi_cli/web/api/sessions.py b/src/kimi_cli/web/api/sessions.py deleted file mode 100644 index 801d74f69..000000000 --- a/src/kimi_cli/web/api/sessions.py +++ /dev/null @@ -1,1370 +0,0 @@ -"""Sessions API routes.""" - -from __future__ import annotations - -import asyncio -import json -import mimetypes -import os -import re -import shutil -import time -from datetime import UTC, datetime -from pathlib import Path -from typing import Any, cast -from urllib.parse import quote -from uuid import UUID, uuid4 - -from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile, status -from fastapi.responses import FileResponse, Response -from kaos.path import KaosPath -from pydantic import BaseModel, Field -from starlette.websockets import WebSocket, WebSocketDisconnect - -from kimi_cli import logger -from kimi_cli.metadata import load_metadata, save_metadata -from kimi_cli.session import Session as KimiCLISession -from kimi_cli.utils.subprocess_env import get_clean_env -from kimi_cli.web.auth import is_origin_allowed, is_private_ip, verify_token -from kimi_cli.web.models import ( - GenerateTitleRequest, - GenerateTitleResponse, - GitDiffStats, - GitFileDiff, - Session, - SessionStatus, - UpdateSessionRequest, -) -from kimi_cli.web.runner.messages import new_session_status_message, send_history_complete -from kimi_cli.web.runner.process import KimiCLIRunner -from kimi_cli.web.store.sessions import ( - JointSession, - invalidate_sessions_cache, - load_session_by_id, - load_sessions_page, - run_auto_archive, -) -from kimi_cli.wire.jsonrpc import ( - ErrorCodes, - JSONRPCErrorObject, - JSONRPCErrorResponse, - JSONRPCInMessageAdapter, - JSONRPCPromptMessage, -) -from kimi_cli.wire.serde import deserialize_wire_message -from kimi_cli.wire.types import is_request - -router = APIRouter(prefix="/api/sessions", tags=["sessions"]) -work_dirs_router = APIRouter(prefix="/api/work-dirs", tags=["work-dirs"]) - -# Constants -MAX_UPLOAD_SIZE = 100 * 1024 * 1024 # 100MB -DEFAULT_MAX_PUBLIC_PATH_DEPTH = 6 -SENSITIVE_PATH_PARTS = { - "id_rsa", - "id_ed25519", - "known_hosts", - "credentials", - ".aws", - ".ssh", - ".gnupg", - ".kube", - ".npmrc", - ".pypirc", - ".netrc", -} -SENSITIVE_PATH_EXTENSIONS = { - ".pem", - ".key", - ".p12", - ".pfx", - ".kdbx", - ".der", -} -# Home directory patterns to detect if resolved path escapes to sensitive locations -SENSITIVE_HOME_PATHS = { - ".ssh", - ".gnupg", - ".aws", - ".kube", -} -CHECKPOINT_USER_PATTERN = re.compile(r"^CHECKPOINT \d+$") - - -def sanitize_filename(filename: str) -> str: - """Remove potentially dangerous characters from filename.""" - # Keep only alphanumeric, dots, underscores, hyphens, and spaces - safe = "".join(c for c in filename if c.isalnum() or c in "._- ") - return safe.strip() or "unnamed" - - -def get_runner(req: Request) -> KimiCLIRunner: - """Get the KimiCLIRunner from the FastAPI app state.""" - return req.app.state.runner - - -def get_runner_ws(ws: WebSocket) -> KimiCLIRunner: - """Get the KimiCLIRunner from the FastAPI app state (for WebSocket routes).""" - return ws.app.state.runner - - -def get_editable_session( - session_id: UUID, - runner: KimiCLIRunner, -) -> JointSession: - """Get a session and verify it's not busy.""" - session = load_session_by_id(session_id) - if session is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Session not found", - ) - # Check if session is busy - session_process = runner.get_session(session_id) - if session_process and session_process.is_busy: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Session is busy. Please wait for it to complete before modifying.", - ) - return session - - -def _relative_parts(path: Path) -> list[str]: - return [part for part in path.parts if part not in {"", "."}] - - -def _is_sensitive_relative_path(rel_path: Path) -> bool: - parts = _relative_parts(rel_path) - for part in parts: - if part.startswith("."): - return True - if part.lower() in SENSITIVE_PATH_PARTS: - return True - return rel_path.suffix.lower() in SENSITIVE_PATH_EXTENSIONS - - -def _contains_symlink(path: Path, base: Path) -> bool: - """Check if any component of the path (relative to base) is a symlink.""" - try: - current = base - rel_parts = path.relative_to(base).parts - for part in rel_parts: - current = current / part - if current.is_symlink(): - return True - except (ValueError, OSError): - return True - return False - - -def _is_path_in_sensitive_location(path: Path) -> bool: - """Check if resolved path points to a sensitive location (e.g., ~/.ssh, ~/.aws).""" - try: - home = Path.home() - if path.is_relative_to(home): - rel_to_home = path.relative_to(home) - first_part = rel_to_home.parts[0] if rel_to_home.parts else "" - if first_part in SENSITIVE_HOME_PATHS: - return True - except (ValueError, RuntimeError): - pass - return False - - -def _ensure_public_file_access_allowed( - rel_path: Path, - restrict_sensitive_apis: bool, - max_path_depth: int = DEFAULT_MAX_PUBLIC_PATH_DEPTH, -) -> None: - if not restrict_sensitive_apis: - return - rel_parts = _relative_parts(rel_path) - if len(rel_parts) > max_path_depth: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=f"Path too deep for public access " - f"(max depth: {max_path_depth}, current: {len(rel_parts)}).", - ) - if _is_sensitive_relative_path(rel_path): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Access to sensitive files is disabled.", - ) - - -def _read_wire_lines(wire_file: Path) -> list[str]: - """Read and parse wire.jsonl into JSONRPC event strings (runs in thread).""" - result: list[str] = [] - with open(wire_file, encoding="utf-8") as f: - for line in f: - line = line.strip() - if not line: - continue - try: - record = json.loads(line) - if not isinstance(record, dict): - continue - record = cast(dict[str, Any], record) - record_type = record.get("type") - if isinstance(record_type, str) and record_type == "metadata": - continue - message_raw = record.get("message") - if not isinstance(message_raw, dict): - continue - message_raw = cast(dict[str, Any], message_raw) - message = deserialize_wire_message(message_raw) - _is_req = is_request(message) - event_msg: dict[str, Any] = { - "jsonrpc": "2.0", - "method": "request" if _is_req else "event", - "params": message_raw, - } - if _is_req: - # JSON-RPC requests require a top-level ``id`` so the - # client can correlate its response. Use the request's - # own ``id`` field (e.g. ApprovalRequest.id, - # QuestionRequest.id). Note: ``message_raw`` wraps data - # as ``{"type": ..., "payload": {...}}`` so the id lives - # on the deserialized object, not at the raw dict top level. - event_msg["id"] = message.id - result.append(json.dumps(event_msg, ensure_ascii=False)) - except (json.JSONDecodeError, KeyError, ValueError, TypeError): - continue - return result - - -async def replay_history(ws: WebSocket, session_dir: Path) -> None: - """Replay historical wire messages from wire.jsonl to a WebSocket.""" - wire_file = session_dir / "wire.jsonl" - if not await asyncio.to_thread(wire_file.exists): - return - - try: - lines = await asyncio.to_thread(_read_wire_lines, wire_file) - for event_text in lines: - await ws.send_text(event_text) - except Exception: - pass - - -@router.get("/", summary="List all sessions") -async def list_sessions( - runner: KimiCLIRunner = Depends(get_runner), - limit: int = 100, - offset: int = 0, - q: str | None = None, - archived: bool | None = None, -) -> list[Session]: - """List sessions with optional pagination and search. - - Args: - limit: Maximum number of sessions to return (default 100, max 500). - offset: Number of sessions to skip (default 0). - q: Optional search query to filter by title or work_dir. - archived: Filter by archived status. - - None (default): Only return non-archived sessions. - - True: Only return archived sessions. - """ - if limit <= 0: - limit = 100 - if limit > 500: - limit = 500 - if offset < 0: - offset = 0 - - # Run auto-archive in background (throttled internally, runs at most once per 5 minutes) - await asyncio.to_thread(run_auto_archive) - - sessions = load_sessions_page(limit=limit, offset=offset, query=q, archived=archived) - for session in sessions: - session_process = runner.get_session(session.session_id) - session.is_running = session_process is not None and session_process.is_running - session.status = session_process.status if session_process else None - return cast(list[Session], sessions) - - -@router.get("/{session_id}", summary="Get session") -async def get_session( - session_id: UUID, - runner: KimiCLIRunner = Depends(get_runner), -) -> Session | None: - """Get a session by ID.""" - session = load_session_by_id(session_id) - if session is not None: - session_process = runner.get_session(session_id) - session.is_running = session_process is not None and session_process.is_running - session.status = session_process.status if session_process else None - return session - - -@router.post("/", summary="Create a new session") -async def create_session(request: CreateSessionRequest | None = None) -> Session: - """Create a new session.""" - # Use provided work_dir or default to user's home directory - if request and request.work_dir: - work_dir_path = Path(request.work_dir).expanduser().resolve() - # Validate the directory exists - if not work_dir_path.exists(): - if request.create_dir: - # Auto-create the directory - try: - work_dir_path.mkdir(parents=True, exist_ok=True) - except PermissionError as e: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=f"Permission denied: cannot create directory {request.work_dir}", - ) from e - except OSError as e: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Failed to create directory: {e}", - ) from e - else: - # Return 404 to indicate directory does not exist - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Directory does not exist: {request.work_dir}", - ) - if not work_dir_path.is_dir(): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Path is not a directory: {request.work_dir}", - ) - work_dir = KaosPath.unsafe_from_local_path(work_dir_path) - else: - work_dir = KaosPath.unsafe_from_local_path(Path.home()) - kimi_cli_session = await KimiCLISession.create(work_dir=work_dir) - context_file = kimi_cli_session.dir / "context.jsonl" - invalidate_sessions_cache() - invalidate_work_dirs_cache() - return Session( - session_id=UUID(kimi_cli_session.id), - title=kimi_cli_session.title, - last_updated=datetime.fromtimestamp(context_file.stat().st_mtime, tz=UTC), - is_running=False, - status=SessionStatus( - session_id=UUID(kimi_cli_session.id), - state="stopped", - seq=0, - worker_id=None, - reason=None, - detail=None, - updated_at=datetime.now(UTC), - ), - work_dir=str(work_dir), - session_dir=str(kimi_cli_session.dir), - ) - - -class CreateSessionRequest(BaseModel): - """Create session request.""" - - work_dir: str | None = None - create_dir: bool = False # Whether to auto-create directory if it doesn't exist - - -class ForkSessionRequest(BaseModel): - """Fork session request.""" - - turn_index: int = Field(..., ge=0) # 0-based, fork includes this turn and all previous turns - - -class UploadSessionFileResponse(BaseModel): - """Upload file response.""" - - path: str - filename: str - size: int - - -@router.post("/{session_id}/files", summary="Upload file to session") -async def upload_session_file( - session_id: UUID, - file: UploadFile, - runner: KimiCLIRunner = Depends(get_runner), -) -> UploadSessionFileResponse: - """Upload a file to a session.""" - session = get_editable_session(session_id, runner) - session_dir = session.kimi_cli_session.dir - upload_dir = session_dir / "uploads" - upload_dir.mkdir(parents=True, exist_ok=True) - - # Read and validate file size - content = await file.read() - if len(content) > MAX_UPLOAD_SIZE: - raise HTTPException( - status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, - detail=f"File too large (max {MAX_UPLOAD_SIZE // 1024 // 1024}MB)", - ) - - # Generate safe filename - file_name = str(uuid4()) - if file.filename: - safe_name = sanitize_filename(file.filename) - name, ext = os.path.splitext(safe_name) - file_name = f"{name}_{file_name[:6]}{ext}" - - upload_path = upload_dir / file_name - upload_path.write_bytes(content) - - return UploadSessionFileResponse( - path=str(upload_path), - filename=file_name, - size=len(content), - ) - - -@router.get( - "/{session_id}/uploads/{path:path}", - summary="Get uploaded file from session uploads", -) -async def get_session_upload_file( - session_id: UUID, - path: str, -) -> Response: - """Get a file from a session's uploads directory.""" - session = load_session_by_id(session_id) - if session is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Session not found", - ) - - uploads_dir = (session.kimi_cli_session.dir / "uploads").resolve() - if not uploads_dir.exists(): - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Uploads directory not found", - ) - - file_path = (uploads_dir / path).resolve() - if not file_path.is_relative_to(uploads_dir): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid path: path traversal not allowed", - ) - - if not file_path.exists() or not file_path.is_file(): - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="File not found", - ) - - media_type, _ = mimetypes.guess_type(file_path.name) - encoded_filename = quote(file_path.name, safe="") - return FileResponse( - file_path, - media_type=media_type or "application/octet-stream", - headers={ - "Content-Disposition": f"inline; filename*=UTF-8''{encoded_filename}", - }, - ) - - -@router.get( - "/{session_id}/files/{path:path}", - summary="Get file or list directory from session work_dir", -) -async def get_session_file( - session_id: UUID, - path: str, - request: Request, -) -> Response: - """Get a file or list directory from session work directory.""" - session = load_session_by_id(session_id) - if session is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Session not found", - ) - - # Security check: prevent path traversal attacks using resolve() - work_dir = Path(str(session.kimi_cli_session.work_dir)).resolve() - requested_path = work_dir / path - file_path = requested_path.resolve() - - # Check path traversal - if not file_path.is_relative_to(work_dir): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid path: path traversal not allowed", - ) - - rel_path = file_path.relative_to(work_dir) - restrict_sensitive_apis = getattr(request.app.state, "restrict_sensitive_apis", False) - max_path_depth = ( - getattr(request.app.state, "max_public_path_depth", None) or DEFAULT_MAX_PUBLIC_PATH_DEPTH - ) - - # Additional security checks when restricting sensitive APIs - if restrict_sensitive_apis: - # Check for symlinks in the path - if _contains_symlink(requested_path, work_dir): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Symbolic links are not allowed in public mode.", - ) - - # Check if resolved path points to sensitive location - if _is_path_in_sensitive_location(file_path): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Access to sensitive system directories is not allowed.", - ) - - _ensure_public_file_access_allowed(rel_path, restrict_sensitive_apis, max_path_depth) - - if not file_path.exists(): - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="File not found", - ) - - if file_path.is_dir(): - result: list[dict[str, str | int]] = [] - for subpath in file_path.iterdir(): - if restrict_sensitive_apis: - rel_subpath = rel_path / subpath.name - if _is_sensitive_relative_path(rel_subpath): - continue - if subpath.is_dir(): - result.append({"name": subpath.name, "type": "directory"}) - else: - result.append( - { - "name": subpath.name, - "type": "file", - "size": subpath.stat().st_size, - } - ) - result.sort(key=lambda x: (cast(str, x["type"]), cast(str, x["name"]))) - return Response(content=json.dumps(result), media_type="application/json") - - content = file_path.read_bytes() - media_type, _ = mimetypes.guess_type(file_path.name) - encoded_filename = quote(file_path.name, safe="") - return Response( - content=content, - media_type=media_type or "application/octet-stream", - headers={"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"}, - ) - - -def _update_last_session_id(session: JointSession) -> None: - """Update last_session_id for the session's work directory.""" - kimi_session = session.kimi_cli_session - work_dir = kimi_session.work_dir - - metadata = load_metadata() - work_dir_meta = metadata.get_work_dir_meta(work_dir) - - if work_dir_meta is None: - work_dir_meta = metadata.new_work_dir_meta(work_dir) - - work_dir_meta.last_session_id = kimi_session.id - save_metadata(metadata) - - -@router.delete("/{session_id}", summary="Delete a session") -async def delete_session(session_id: UUID, runner: KimiCLIRunner = Depends(get_runner)) -> None: - """Delete a session.""" - session = get_editable_session(session_id, runner) - session_process = runner.get_session(session_id) - if session_process is not None: - await session_process.stop() - wd_meta = session.kimi_cli_session.work_dir_meta - if wd_meta.last_session_id == str(session_id): - metadata = load_metadata() - for wd in metadata.work_dirs: - if wd.path == wd_meta.path: - wd.last_session_id = None - break - save_metadata(metadata) - session_dir = session.kimi_cli_session.dir - if session_dir.exists(): - shutil.rmtree(session_dir) - invalidate_sessions_cache() - - -@router.patch("/{session_id}", summary="Update session") -async def update_session( - session_id: UUID, - request: UpdateSessionRequest, - runner: KimiCLIRunner = Depends(get_runner), -) -> Session: - """Update a session (e.g., rename title or archive/unarchive).""" - from kimi_cli.session_state import load_session_state, save_session_state - - session = get_editable_session(session_id, runner) - session_dir = session.kimi_cli_session.dir - state = load_session_state(session_dir) - - # Update title if provided - if request.title is not None: - state.custom_title = request.title - state.title_generated = True - - # Update archived status if provided - if request.archived is not None: - state.archived = request.archived - if request.archived: - state.archived_at = time.time() - state.auto_archive_exempt = False - else: - state.archived_at = None - state.auto_archive_exempt = True - - save_session_state(state, session_dir) - - # Invalidate cache to force reload - invalidate_sessions_cache() - - # Return updated session - updated_session = load_session_by_id(session_id) - if updated_session is None: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to reload session after update", - ) - return updated_session - - -def extract_first_turn_from_wire(session_dir: Path) -> tuple[str, str] | None: - """Extract the first turn's user message and assistant response from wire.jsonl. - - Returns: - tuple[str, str] | None: (user_message, assistant_response) or None if not found - """ - wire_file = session_dir / "wire.jsonl" - if not wire_file.exists(): - return None - - user_message: str | None = None - assistant_response_parts: list[str] = [] - in_first_turn = False - - try: - with open(wire_file, encoding="utf-8") as f: - for line in f: - line = line.strip() - if not line: - continue - try: - record = json.loads(line) - message = record.get("message", {}) - msg_type = message.get("type") - - if msg_type == "TurnBegin": - if in_first_turn: - # Second turn started, stop - break - in_first_turn = True - user_input = message.get("payload", {}).get("user_input") - if user_input: - from kosong.message import Message - - msg = Message(role="user", content=user_input) - user_message = msg.extract_text(" ") - - elif msg_type == "ContentPart" and in_first_turn: - payload = message.get("payload", {}) - if payload.get("type") == "text" and payload.get("text"): - assistant_response_parts.append(payload["text"]) - - elif msg_type == "TurnEnd" and in_first_turn: - break - - except json.JSONDecodeError: - continue - except OSError: - return None - - if user_message and assistant_response_parts: - return (user_message, "".join(assistant_response_parts)) - return None - - -def truncate_wire_at_turn(wire_path: Path, turn_index: int) -> list[str]: - """Read wire.jsonl and return all lines up to and including the given turn. - - Args: - wire_path: Path to the wire.jsonl file - turn_index: 0-based turn index. Returns turns 0..turn_index inclusive. - - Returns: - List of raw JSON lines (including the metadata header) - - Raises: - ValueError: If turn_index is out of range - """ - if not wire_path.exists(): - raise ValueError("wire.jsonl not found") - - lines: list[str] = [] - current_turn = -1 # Will become 0 on first TurnBegin - - with open(wire_path, encoding="utf-8") as f: - for line in f: - stripped = line.strip() - if not stripped: - continue - - try: - record: dict[str, Any] = json.loads(stripped) - except json.JSONDecodeError: - continue - - # Always keep metadata header - if record.get("type") == "metadata": - lines.append(stripped) - continue - - message: dict[str, Any] = record.get("message", {}) - msg_type: str | None = message.get("type") - - if msg_type == "TurnBegin": - current_turn += 1 - if current_turn > turn_index: - break - - if current_turn <= turn_index: - lines.append(stripped) - - # Stop after the TurnEnd of the target turn - if msg_type == "TurnEnd" and current_turn == turn_index: - break - - if current_turn < turn_index: - raise ValueError(f"turn_index {turn_index} out of range (max turn: {current_turn})") - - return lines - - -def _is_checkpoint_user_message(record: dict[str, Any]) -> bool: - """Whether a context line is the synthetic user checkpoint marker.""" - if record.get("role") != "user": - return False - - content = record.get("content") - if isinstance(content, str): - return CHECKPOINT_USER_PATTERN.fullmatch(content.strip()) is not None - - parts = cast(list[Any], content) if isinstance(content, list) else [] - if len(parts) == 1 and isinstance(parts[0], dict): - first_part = cast(dict[str, Any], parts[0]) - text = first_part.get("text") - if isinstance(text, str): - return CHECKPOINT_USER_PATTERN.fullmatch(text.strip()) is not None - - return False - - -def truncate_context_at_turn(context_path: Path, turn_index: int) -> list[str]: - """Read context.jsonl and return all lines up to and including the given turn. - - Turn detection is based on real user messages, excluding synthetic checkpoint - user entries like ``CHECKPOINT N``. - - Unlike wire truncation, this is best-effort: if context has fewer user turns - than ``turn_index`` (e.g. slash-command turns that did not mutate context), - return all available context lines instead of failing. - """ - if not context_path.exists(): - return [] - - lines: list[str] = [] - current_turn = -1 # Will become 0 on first real user message - - with open(context_path, encoding="utf-8") as f: - for line in f: - stripped = line.strip() - if not stripped: - continue - - try: - record: dict[str, Any] = json.loads(stripped) - except json.JSONDecodeError: - continue - - if record.get("role") == "user" and not _is_checkpoint_user_message(record): - current_turn += 1 - if current_turn > turn_index: - break - - if current_turn <= turn_index: - lines.append(stripped) - - return lines - - -@router.post("/{session_id}/fork", summary="Fork a session at a specific turn") -async def fork_session( - session_id: UUID, - request: ForkSessionRequest, - runner: KimiCLIRunner = Depends(get_runner), -) -> Session: - """Fork a session, creating a new session with history up to the specified turn. - - The new session shares the same work_dir as the original session. - """ - source_session = get_editable_session(session_id, runner) - source_dir = source_session.kimi_cli_session.dir - wire_path = source_dir / "wire.jsonl" - context_path = source_dir / "context.jsonl" - - try: - truncated_wire_lines = truncate_wire_at_turn(wire_path, request.turn_index) - truncated_context_lines = truncate_context_at_turn(context_path, request.turn_index) - except ValueError as e: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=str(e), - ) from e - - # Create new session with the same work_dir. - # Only write the essential files explicitly — do NOT copytree the whole - # source directory, which would bring in rotated context backups - # (context_N.jsonl) and subagent contexts (context_sub_N.jsonl). - work_dir = source_session.kimi_cli_session.work_dir - new_session = await KimiCLISession.create(work_dir=work_dir) - new_session_dir = new_session.dir - - # Copy only the video files that are actually referenced in the truncated - # wire history. Videos are referenced by path (