From 6548e14b69f6ed567febf7e21b23f9d850c0b7ac Mon Sep 17 00:00:00 2001 From: phantom5099 <1011668688@qq.com> Date: Wed, 22 Apr 2026 23:27:02 +0800 Subject: [PATCH 1/9] =?UTF-8?q?refactor(runtime):=20=E5=BC=95=E5=85=A5?= =?UTF-8?q?=E4=B8=8A=E4=B8=8B=E6=96=87=E9=A2=84=E7=AE=97=E6=8E=A7=E5=88=B6?= =?UTF-8?q?=E9=9D=A2=E5=B9=B6=E6=94=B6=E6=95=9B=E8=BF=90=E8=A1=8C=E6=97=B6?= =?UTF-8?q?=E4=BA=8B=E4=BB=B6=E5=8D=8F=E8=AE=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CLAUDE.md | 179 ++-- docs/context-compact.md | 145 ++- docs/guides/configuration.md | 202 ++-- docs/runtime-provider-event-flow.md | 213 ++-- internal/app/bootstrap.go | 25 +- internal/app/bootstrap_test.go | 52 +- internal/cli/gateway_commands.go | 7 +- internal/cli/migrate_command.go | 67 ++ internal/cli/migrate_command_test.go | 96 ++ internal/cli/root.go | 19 +- internal/cli/root_test.go | 110 ++- internal/config/config_test.go | 278 +++--- internal/config/context.go | 89 +- internal/config/context_budget_migration.go | 147 +++ .../config/context_budget_migration_test.go | 95 ++ internal/config/context_test.go | 55 +- internal/config/loader.go | 45 +- internal/config/loader_test.go | 164 ++- internal/config/manager.go | 4 - internal/config/provider.go | 69 +- internal/config/provider_test.go | 52 +- .../config/state/auto_compact_threshold.go | 90 -- .../state/auto_compact_threshold_test.go | 158 --- internal/config/state/budget.go | 85 ++ internal/config/state/budget_test.go | 143 +++ .../config/state/model_additional_test.go | 2 +- internal/config/state/model_test.go | 15 +- .../config/state/service_provider_create.go | 4 +- internal/config/state/service_test.go | 36 +- internal/context/builder.go | 15 +- internal/context/builder_test.go | 80 -- internal/context/compact/runner.go | 6 +- internal/context/compact/runner_test.go | 12 +- internal/context/types.go | 6 +- internal/memo/store.go | 166 +--- internal/memo/store_test.go | 125 --- internal/provider/anthropic/driver.go | 14 +- internal/provider/anthropic/driver_test.go | 10 +- internal/provider/anthropic/provider.go | 41 +- internal/provider/anthropic/provider_test.go | 12 +- internal/provider/catalog/service_test.go | 17 + .../provider/conformance/conformance_test.go | 27 +- internal/provider/contracts.go | 51 +- internal/provider/estimate.go | 29 + internal/provider/gemini/driver.go | 6 +- internal/provider/gemini/driver_test.go | 10 +- internal/provider/gemini/provider.go | 50 +- internal/provider/gemini/provider_test.go | 30 +- internal/provider/generate_test.go | 20 + internal/provider/openaicompat/common_test.go | 16 +- .../provider/openaicompat/discovery_http.go | 7 +- .../openaicompat/discovery_http_test.go | 9 +- .../openaicompat/driver_internal_test.go | 23 +- .../provider/openaicompat/generate_sdk.go | 25 +- .../openaicompat/openaicompat_test.go | 41 +- internal/provider/openaicompat/provider.go | 45 +- internal/provider/registry_test.go | 22 +- internal/provider/types/usage.go | 7 + internal/runtime/budget_models.go | 116 +++ internal/runtime/compact.go | 3 + internal/runtime/compact_generator.go | 4 +- internal/runtime/controlplane/budget.go | 61 ++ internal/runtime/controlplane/decider.go | 4 + internal/runtime/controlplane/stop_reason.go | 2 + internal/runtime/events.go | 60 +- internal/runtime/events_subagent.go | 6 - internal/runtime/input_prepare_test.go | 2 +- internal/runtime/provider_stream.go | 2 + internal/runtime/run.go | 332 ++++--- internal/runtime/run_lifecycle.go | 9 +- internal/runtime/runtime.go | 36 +- .../runtime/runtime_branch_coverage_test.go | 8 +- internal/runtime/runtime_gap_coverage_test.go | 16 +- .../runtime/runtime_internal_helpers_test.go | 28 +- internal/runtime/runtime_progress_test.go | 22 +- .../runtime_remaining_branches_test.go | 28 +- internal/runtime/runtime_test.go | 930 +++++++----------- internal/runtime/session_mutation.go | 76 +- internal/runtime/skills_test.go | 111 +-- internal/runtime/state.go | 53 +- internal/runtime/todo_mutator_test.go | 36 +- internal/runtime/toolexec.go | 8 +- internal/runtime/turn_control_test.go | 2 +- internal/session/asset_store_test.go | 2 +- internal/session/id.go | 5 - internal/session/id_test.go | 12 - internal/session/input_preparer.go | 17 +- internal/session/input_preparer_test.go | 65 +- internal/session/skill_activation_test.go | 10 +- internal/session/sqlite_store.go | 138 ++- .../session/sqlite_store_additional_test.go | 31 +- internal/session/store.go | 128 ++- internal/session/store_test.go | 271 ++++- internal/session/test_helpers_test.go | 2 +- internal/session/todo.go | 30 +- internal/session/todo_test.go | 32 +- internal/tui/core/app/update.go | 29 +- .../core/app/update_runtime_events_test.go | 34 +- internal/tui/core/app/update_test.go | 32 + .../gateway_rpc_client_additional_test.go | 17 +- .../tui/services/gateway_rpc_client_test.go | 27 +- .../tui/services/gateway_stream_client.go | 2 + .../gateway_stream_client_additional_test.go | 26 +- internal/tui/services/runtime_bridge.go | 11 + internal/tui/services/runtime_contract.go | 22 + internal/tui/tui.go | 29 +- internal/tui/tui_test.go | 35 +- scripts/migrate_context_budget/main.go | 54 + scripts/migrate_context_budget/main_test.go | 13 + 109 files changed, 3783 insertions(+), 2814 deletions(-) create mode 100644 internal/cli/migrate_command.go create mode 100644 internal/cli/migrate_command_test.go create mode 100644 internal/config/context_budget_migration.go create mode 100644 internal/config/context_budget_migration_test.go delete mode 100644 internal/config/state/auto_compact_threshold.go delete mode 100644 internal/config/state/auto_compact_threshold_test.go create mode 100644 internal/config/state/budget.go create mode 100644 internal/config/state/budget_test.go create mode 100644 internal/provider/estimate.go create mode 100644 internal/runtime/budget_models.go create mode 100644 internal/runtime/controlplane/budget.go create mode 100644 scripts/migrate_context_budget/main.go create mode 100644 scripts/migrate_context_budget/main_test.go diff --git a/CLAUDE.md b/CLAUDE.md index 8821a29a..3cca0fd9 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -1,69 +1,120 @@ -# CLAUDE.md — Claude Code 项目规则 - -本文件是 Claude Code 在此仓库中工作时的行为指引。完整的项目协作规则请参见 `AGENTS.md`。 - -## 项目概览 - -NeoCode Coding Agent MVP — 一个 Go 实现的 AI 编码助手,主链路为: -`用户输入 -> Agent 推理 -> 调用工具 -> 获取结果 -> 继续推理 -> UI 展示` - -## 关键目录 - -| 目录 | 职责 | -|------|------| -| `cmd/neocode` | CLI 入口 | -| `internal/app` | 应用装配与 bootstrap | -| `internal/config` | 配置模型、YAML 加载、校验 | -| `internal/provider` | 模型厂商适配器(差异收敛在此层) | -| `internal/runtime` | ReAct 主循环、事件流、Prompt 编排 | -| `internal/session` | 会话领域模型与 JSON 持久化 | -| `internal/tools` | 工具契约、注册表与具体实现 | -| `internal/tui` | Bubble Tea TUI 状态机与渲染 | -| `internal/context` | 上下文构建、压缩决策 | -| `docs` | 架构与设计文档 | - -## 必须遵守的规则 - -### 架构边界 -- **不跨层直连**:遵循 `TUI -> Runtime -> Provider / Tool Manager` 主链路 -- **不泄漏厂商差异**:模型协议差异收敛在 `internal/provider` 内 -- **不内嵌工具逻辑**:所有可被模型调用的能力必须进入 `internal/tools` -- **不散落状态**:会话状态、消息历史、工具调用记录由 `runtime` 管理 - -### 编码规范 -- Go 惯用风格,制表符缩进,单行约 120 字符 -- `PascalCase` 导出,`camelCase` 未导出 -- 新增函数必须附带中文注释,说明职责与关键行为 -- 不硬编码路径、URL、模型名、超时等,通过配置或常量注入 -- 不硬编码业务语义字符串,收敛到共享常量或类型定义 - -### 安全 -- 不写入明文 API Key -- 配置只保存环境变量名 -- `filesystem` 工具限制在工作目录内 -- `bash` 工具限制超时与输出长度 -- 本地运行数据、会话数据不入库 - -### 测试 -- 整体测试覆盖率以 **100%** 为硬性目标 -- 改动必须同步补齐测试:正常路径 + 边界条件 + 异常分支 + 回归场景 -- 优先覆盖:配置校验、provider 转换、tool 参数校验、runtime 停止条件、事件派发 - -### 文档 -- 沿用目标文档已有语言(中文为主则续用中文) -- 实现与文档冲突时必须修正至少一个 - -### 提交前检查 -- 确认职责分工未被破坏 -- 确认新增能力接到正确层级 -- `go build ./...` && `go test ./...` && `gofmt -w ./cmd ./internal` -- 检查 `git status`,确保无密钥或临时文件混入 +# CLAUDE.md — Claude Code 项目指引 + +本文件是 Claude Code 在本仓库中工作的快速入口。完整且最高优先级的协作规则以 +`AGENTS.md` 为准;如果两份文档出现冲突,先遵守 `AGENTS.md`,再同步修正本文。 + +## 项目定位 + +NeoCode 是一个 Go 实现的编码 Agent。当前架构已经完成控制面与数据面拆分,主链路必须保持为: + +```text +用户输入(TUI) -> 网关中继(Gateway) -> Agent 推理(Runtime) -> 工具调用(Tools) -> 结果回传 -> UI 展示 +``` + +做任何修改时,优先确认主链路仍可运行、职责边界仍清晰、行为可通过测试验证。 + +## 当前实现重点 + +- `internal/tui` 只负责 Bubble Tea 状态机、渲染和事件消费,不直接调用 provider,不执行工具。 +- `internal/gateway` 负责协议路由、JSON-RPC 归一化、权限与事件中继,是 TUI 与 Runtime 的边界。 +- `internal/runtime` 负责 ReAct 循环、事件派发、工具调用编排、token ledger、compact 触发和停止条件。 +- `internal/context` 负责 system prompt、AGENTS.md、Task State、Todo State、Skills、Memo、消息裁剪和 micro compact。 +- `internal/provider` 只收敛模型厂商差异,包括请求组装、流式解析、usage 读取和输入 token 估算。 +- `internal/tools` 负责工具 schema、注册表、参数校验、执行和结果格式。 +- `internal/session` 负责会话领域模型以及 JSON / SQLite 持久化。 +- `internal/config` 负责配置加载、校验、provider/model 选择、环境变量名和 context budget 配置。 +- `internal/app` 只做应用装配与依赖注入,不承载业务规则。 + +## 架构边界 + +- 不要跨层直连;新增能力默认沿 `TUI -> Gateway -> Runtime -> Provider / Tool Manager` 接入。 +- 不要把 provider 厂商字段、错误格式或协议细节泄漏到 runtime、tui 或上层调用方。 +- 不要在 TUI 或 Runtime 中内嵌可被模型调用的具体工具逻辑;工具能力必须进入 `internal/tools`。 +- 不要把会话状态、消息历史、工具调用记录散落到 UI;这些状态由 runtime/session 管理。 +- 新设计已确定时,不要为了“可能兼容旧版本”保留旧分支、旧协议兜底或碎片化实现。 + +## Budget 与 Compact + +当前上下文预算统一使用 `context.budget`,不再使用旧的 `context.auto_compact` 运行时语义。 + +- `context.compact` 只描述 compact 策略和 read-time micro compact 行为。 +- `context.budget` 描述 prompt budget、reserve tokens、fallback budget 和 reactive compact 次数。 +- runtime 在发送 provider 请求前冻结 turn snapshot,并调用 provider 的 `EstimateInputTokens`。 +- budget 控制面的闭环是: + +```text +BuildRequest -> FreezeSnapshot -> EstimateInput -> DecideBudget -> allow | compact | stop +``` + +- 首次超预算时执行 `proactive` compact;compact 后仍超预算时停止本次 run,并发出 + `STOP_BUDGET_EXCEEDED`。 +- provider 返回 `context_too_long` 时进入 `reactive` compact,再重新进入预算闭环。 +- context builder 不负责预算判断,也不返回旧的 auto compact 建议。 + +## Runtime 事件协议 + +TUI 只消费当前 runtime/gateway 事件协议。停止原因使用: + +- `STOP_COMPLETED` +- `STOP_USER_INTERRUPT` +- `STOP_FATAL_ERROR` +- `STOP_BUDGET_EXCEEDED` + +预算和 token 相关事件包括: + +- `budget_checked` +- `ledger_reconciled` +- `token_usage` + +`token_usage` 需要携带本轮 usage、来源标签、会话累计值和 `has_unknown_usage`。不要复用旧的 +`usage` 协议,也不要在 TUI 中猜测 provider usage 语义。 + +## 配置与安全 + +- 主配置路径为 `~/.neocode/config.yaml`。 +- API Key 只通过环境变量读取,配置文件只保存环境变量名,不保存明文密钥。 +- `workdir` 通过启动参数或运行时上下文传入,不写入主配置。 +- custom provider 使用 `~/.neocode/providers//provider.yaml`。 +- `context.auto_compact` 只允许被迁移到 `context.budget`;主解析器仍只接受当前结构。 +- `filesystem` 工具必须限制在工作目录内。 +- `bash` 工具必须限制超时、输出长度,并避免交互式阻塞命令。 +- `webfetch` 工具必须限制协议范围、响应大小和内容类型。 + +## 修改代码时 + +- 先判断改动属于 config、context、runtime、provider、tools、session、gateway、tui 还是 app。 +- 优先做最小闭环改动,不做无关重构。 +- 涉及 provider 协议差异时,只改 provider 层或其明确契约。 +- 涉及模型可调用能力时,先补 tools 契约,再由 runtime 接入。 +- 涉及配置结构、事件协议、目录职责或命令时,同步更新 `docs/`、`README.md` 或本文。 +- 非测试文件新增函数时,紧邻函数定义写中文注释,说明职责和关键行为。 +- 所有中文文件按 UTF-8 无 BOM 读写,发现乱码先确认编码再修改。 + +## 测试重点 + +改动应同步补测试,优先覆盖: + +- 配置默认值、校验、迁移和错误包装。 +- provider 请求组装、stream 解析、tool call 解析、usage 与错误映射。 +- runtime 最大轮数、停止原因、tool result 回灌、compact、budget、token ledger 和事件派发。 +- context build 输入输出契约、AGENTS.md 加载、micro compact、消息裁剪边界。 +- tools 参数校验、权限、超时、输出裁剪和错误格式。 +- session JSON / SQLite 持久化、schema migration、token totals 和 `HasUnknownUsage`。 +- TUI 对 gateway/runtime 当前事件协议的映射和展示状态。 ## 常用命令 ```bash -go run ./cmd/neocode # 启动应用 -go build ./... # 编译 -go test ./... # 运行测试 -gofmt -w ./cmd ./internal # 格式化 +go run ./cmd/neocode +go build ./... +go test ./... +gofmt -w ./cmd ./internal ``` + +## 提交前检查 + +- 主链路仍是 `TUI -> Gateway -> Runtime -> Tools -> UI`。 +- 新增能力没有跨层接线,也没有把厂商差异泄漏到上层。 +- 配置、事件、schema 和文档使用当前结构,不保留旧协议兜底。 +- 测试覆盖本次改动的正常路径、边界条件、异常分支和回归场景。 +- `git status` 中没有密钥、本地配置、临时目录或无关文件混入。 diff --git a/docs/context-compact.md b/docs/context-compact.md index 5799073e..c5380c6a 100644 --- a/docs/context-compact.md +++ b/docs/context-compact.md @@ -1,25 +1,19 @@ # Context Compact -## Auto Compact Failure Fallback +本文说明 NeoCode 当前的上下文压缩策略、预算触发链路和 compact 协议。 -- 当 `context.auto_compact.input_token_threshold <= 0` 时,系统会尝试基于当前模型的 `ContextWindow` 自动推导阈值。 -- 若当前 provider 选择无效、catalog snapshot 查询失败,或模型窗口元数据缺失,系统会直接回退到 `fallback_input_token_threshold`。 -- 自动推导失败不会静默关闭 auto compact;runtime 仍会拿到一个可用的保底阈值。 +## 总览 -本文档说明 NeoCode 中 context compact 的配置、执行链路和摘要约定。 +当前 compact 只承担“压缩 transcript”的职责,不再负责预算判断。预算控制已独立为 `context.budget`: -## 概览 +- `manual`:用户通过 `/compact` 主动触发 +- `proactive`:发送前输入预算超限时触发 +- `reactive`:provider 返回 `context_too_long` 时触发 -- runtime 已接入手动 compact、基于 token 阈值的自动 compact,以及 provider 上下文过长后的 `reactive` compact 自动恢复。 -- `internal/context/compact` 支持 `manual`、`auto` 与 `reactive` 三种 mode。 -- 用户通过 `/compact` 对当前会话执行一次上下文压缩。 -- compact 前会先写入完整 transcript,随后生成并校验新的 durable `TaskState` 与 display summary,再回写会话消息。 -- compact 的 system prompt 静态说明模板由 `internal/promptasset` 通过 `go:embed` 提供,但 compact user prompt 的元数据块、消息边界和 transcript 渲染仍由代码拼装。 +三种模式共用同一条 compact 执行管线,但触发源不同。 ## 配置 -compact 相关配置位于: - ```yaml context: compact: @@ -29,69 +23,80 @@ context: read_time_max_message_spans: 24 max_summary_chars: 1200 micro_compact_disabled: false - auto_compact: - enabled: false - input_token_threshold: 0 + budget: + prompt_budget: 0 reserve_tokens: 13000 - fallback_input_token_threshold: 100000 + fallback_prompt_budget: 100000 + max_reactive_compacts: 3 ``` +### `context.compact` + - `manual_strategy` 控制手动 compact 的策略,支持 `keep_recent` 和 `full_replace`。 - `manual_keep_recent_messages` - 在 `keep_recent` 模式下保留最近消息数量,并按 tool call 与 tool result 的原子块整体保留。 + 在 `keep_recent` 模式下保留的最近消息数,并按 tool call / tool result 的原子块整体保留。 - `read_time_max_message_spans` - 控制 `context.Builder` 读时 trim 可保留的 message span 上限;该值越大,普通“继续”续跑时越不容易在未触发 compact 前丢掉较早的文件读取结果。 + 控制 `context.Builder` 读时 trim 可保留的 message span 上限。 - `micro_compact_retained_tool_spans` - 控制 read-time micro compact 默认保留原始内容的最近可压缩工具块数量;默认值为 `6`,显式配置为更小值时可更积极地回收旧工具结果。 + 控制 read-time micro compact 默认保留原始内容的最近可压缩工具块数量。 - `max_summary_chars` 控制 compact summary 的最大字符数。 - `micro_compact_disabled` - 控制是否关闭默认启用的读时 micro compact;设为 `true` 时会回退为仅 trim、不清理旧 tool result。 -- `auto_compact.enabled` - 控制是否启用基于 token 阈值的自动压缩;默认关闭。 -- `auto_compact.input_token_threshold` - 当会话累计输入 token 数达到此阈值时触发自动压缩;默认 `0`(自动推导),推导失败时回退到 `fallback_input_token_threshold`(默认 `100000`)。 + 控制是否关闭默认启用的 read-time micro compact。 -## 自动压缩 +### `context.budget` -当 `auto_compact.enabled` 为 `true` 时,runtime 在每次调用 `context.Builder.Build()` 时将当前 token 累计值传入 Metadata,context 模块通过比较累计值与阈值在 `BuildResult.AutoCompactSuggested` 中返回压缩建议。runtime 读取建议后调用现有 compact 管线执行压缩;token 计数的重置与持久化语义统一见 [Session 持久化设计](./session-persistence-design.md)。 +- `prompt_budget` + 显式输入预算;`> 0` 时直接使用,`0` 表示自动推导。 +- `reserve_tokens` + 自动推导预算时,为输出、tool call、system prompt 预留的缓冲。 +- `fallback_prompt_budget` + 模型窗口不可用时的保底输入预算。 +- `max_reactive_compacts` + 单次 run 内 reactive compact 的最大尝试次数。 -设计原则: -- **context 拥有压缩决策权**,runtime 只做编排执行。 -- 每次 `Run()` 调用最多触发一次自动压缩,避免无限循环。 -- 压缩成功后 token 计数器重置为零,下一轮不会立即重复触发。 +## 预算闭环 -新增工具时,micro compact 策略不再由 `context` 层静态白名单维护,而是由 `internal/tools` 中的工具实现声明。 -默认情况下,已注册工具都会参与 micro compact;只有显式声明保留历史结果的工具才会跳过旧结果清理。 -默认 pin 仅对 `filesystem_write_file` 与 `filesystem_edit` 这类文件内容修改工具生效,用于保留 `README`、spec/schema、`go.mod`、`package.json` 等关键产物的最近结果;`.env*` 不参与默认 pin,避免敏感内容在上下文中滞留更久。 -但 micro compact 只有在当前会话已经建立非空 `TaskState` 时才会生效;没有 durable task state 时,context 仅做 trim,不清理旧 tool result。 +当前发送链路固定为: -## 执行链路 +```text +BuildRequest -> FreezeSnapshot -> EstimateInput -> DecideBudget -> (allow | compact | stop) +``` -1. TUI 识别 `/compact` 并调用 `runtime.Compact(...)`。 -2. runtime 发出 `compact_start` 事件。 -3. compact runner 将原始消息写入 transcript(JSONL)。 -4. compact runner 根据策略构造归档消息与保留消息,并过滤旧的 `[compact_summary]` 展示摘要,避免“摘要的摘要”。 -5. runtime 选择用于生成 summary 的 provider 和 model: - 优先复用会话记录的 `provider` / `model`,缺失时回退到当前配置。 -6. summary generator 调用模型生成完整 `task_state` 与 display summary。 -7. runner 校验 display summary 结构与长度,必要时截断,并写入 `task_state.last_updated_at`。 -8. compact 成功时回写 `session.TaskState` 与会话消息并发出 `compact_applied`;失败时发出 `compact_error`。 +关键规则: + +- `context.Builder` 只构建 provider-facing request,不再返回旧的 builder 压缩建议布尔值。 +- provider 发送前一定先做输入 token estimate。 +- estimate 首次超预算时,runtime 执行一次 `proactive` compact,然后重建 request 并重新估算。 +- compact 后仍超预算时,runtime 直接停止本次 run,并返回 `STOP_BUDGET_EXCEEDED`。 +- provider 返回 `context_too_long` 时,runtime 触发 `reactive` compact,并重新进入同一预算闭环。 + +## compact 如何压缩 + +compact runner 会先写入完整 transcript,再生成 durable `TaskState` 与面向人类阅读的 `display_summary`。 -其中 `reactive` mode 在 context 包内与 `manual` 复用同一条压缩管线: +自动链路下的保留规则固定为: -1. 先写 transcript。 -2. 默认按 `keep_recent` 裁剪可归档历史。 -3. 生成并校验 display summary,同时更新 durable `TaskState`。 -4. 返回压缩后的消息、`TaskState` 与 transcript 元信息。 +- 最近一条显式用户消息所在 span 永远保留原文 +- 最近尾部消息原样保留 +- 更早历史归档为一条 `[compact_summary]` -当 provider 返回“上下文过长”错误时,runtime 会: +这意味着: -1. 识别 provider 归一化后的 typed error,必要时回退到错误文本匹配。 -2. 触发 `compact.Run(mode=reactive)`,并在仍然命中“上下文过长”时继续做逐步降级恢复。 -3. 继续复用 `compact_start`、`compact_applied`、`compact_error` 事件,并通过 `trigger_mode=reactive` 区分来源。 -4. 每次 `Run()` 最多执行 3 次 reactive compact 降级尝试;每次尝试都会进一步收缩 `manual_keep_recent_messages`,超过上限后返回最后一次 provider 错误。 +- 当前轮用户刚输入的问题不会被摘要替换 +- 被压缩的是更早的历史消息,而不是当前交互的最近尾部 + +## 执行链路 + +1. TUI 识别 `/compact` 并调用 `runtime.Compact(...)`。 +2. runtime 发出 `compact_start` 事件。 +3. compact runner 写入原始 transcript。 +4. compact runner 根据策略构造归档消息与保留消息,并过滤旧 `[compact_summary]`,避免“摘要的摘要”。 +5. summary generator 调用模型生成完整 `task_state` 与 `display_summary`。 +6. runner 校验 summary 结构与长度,必要时截断,并更新 `task_state.last_updated_at`。 +7. compact 成功后,runtime 回写会话消息与 `TaskState`,重置 token totals 和 `HasUnknownUsage`,并发出 `compact_applied`。 +8. compact 失败时发出 `compact_error`。 ## 生成协议 @@ -113,13 +118,13 @@ compact generator 必须只返回一个 JSON 对象,顶层固定包含: } ``` -- `task_state` 表示 compact 之后的完整 durable task state,而不是增量 patch。 -- `task_state` 只允许包含固定字段,不允许混入模型自定义键。 -- `display_summary` 仍然必须使用 `[compact_summary]` 协议,供人类阅读和后续轮次参考。 +约束: -上述 JSON 契约与 `[compact_summary]` 格式模板仍由代码注入到 compact system prompt 中,避免在模板文件里复制一份会随实现演进的协议定义。 +- `task_state` 表示 compact 后的完整 durable task state,不是增量 patch +- `task_state` 只允许固定字段 +- `display_summary` 必须使用 `[compact_summary]` 协议 -`display_summary` 必须以如下结构返回: +`display_summary` 结构如下: ```text [compact_summary] @@ -140,17 +145,6 @@ constraints: - ... ``` -- 必须包含固定起始标记 `[compact_summary]`。 -- 必须包含 `done`、`in_progress`、`decisions`、`code_changes`、`constraints` 五个 section。 -- 每个 section 至少包含一条非空 bullet。 - -## 保留原则 - -- durable truth 优先进入 `TaskState`,而不是散落在聊天消息里。 -- `TaskState` 重点保留目标、已完成进展、未完成事项、下一步、阻塞点、关键工件、决策、用户约束。 -- `display_summary` 只保留继续工作最少需要的人类可读信息。 -- 默认忽略工具详细输出、重复背景、已解决错误的排查细节。 - ## 事件 compact 相关 runtime 事件包括: @@ -169,12 +163,3 @@ compact 相关 runtime 事件包括: - `trigger_mode` - `transcript_id` - `transcript_path` - -## Auto Compact 阈值解析 - -- `context.auto_compact.input_token_threshold > 0` 时,直接使用显式手动阈值。 -- `context.auto_compact.input_token_threshold <= 0` 时,系统会对当前选中的 provider/model 做自动推导。 -- 自动推导公式为 `resolved_threshold = context_window - reserve_tokens`。 -- `reserve_tokens` 默认 `13000`,用于给输出、tool call 和 system prompt 预留缓冲。 -- 如果当前模型没有可用的 `ContextWindow`,或窗口值小于等于 `reserve_tokens`,则回退到 `fallback_input_token_threshold`。 -- `fallback_input_token_threshold` 默认 `100000`,用于保证主链路在缺少模型窗口元数据时仍可稳定运行。 diff --git a/docs/guides/configuration.md b/docs/guides/configuration.md index 4385cb2c..b9a18ec8 100644 --- a/docs/guides/configuration.md +++ b/docs/guides/configuration.md @@ -1,42 +1,29 @@ # 配置指南 -本文说明 NeoCode 当前真实生效的配置规则。 - -## 总原则 - -- `config.yaml` 只保存最小运行时状态 -- provider 元数据来自代码内置定义或 custom provider 文件 -- API Key 只从环境变量读取 -- YAML 采用严格解析,未知字段直接报错 - -这意味着 NeoCode 当前不会: - -- 自动清理旧版 `providers` / `provider_overrides` -- 自动兼容 `workdir`、`default_workdir` 等旧字段 +本文说明 NeoCode 当前真实生效的配置结构与约束。 ## 配置文件位置 -主配置文件路径: +主配置文件: ```text ~/.neocode/config.yaml ``` -custom provider 目录: +自定义 provider 目录: ```text ~/.neocode/providers//provider.yaml ``` -## `config.yaml` 可写字段 - -当前支持的主配置示例: +## `config.yaml` 示例 ```yaml selected_provider: openai current_model: gpt-5.4 shell: bash tool_timeout_sec: 20 + runtime: max_no_progress_streak: 3 max_repeat_cycle_streak: 3 @@ -60,77 +47,92 @@ context: read_time_max_message_spans: 24 max_summary_chars: 1200 micro_compact_disabled: false - auto_compact: - enabled: false - input_token_threshold: 0 + budget: + prompt_budget: 0 reserve_tokens: 13000 - fallback_input_token_threshold: 100000 + fallback_prompt_budget: 100000 + max_reactive_compacts: 3 ``` -### 基础字段 +## 基础字段 | 字段 | 说明 | |------|------| | `selected_provider` | 当前选中的 provider 名称 | | `current_model` | 当前选中的模型 ID | -| `shell` | 默认 shell,Windows 默认 `powershell`,其他平台默认 `bash` | -| `tool_timeout_sec` | 工具执行超时(秒) | +| `shell` | 默认 shell;Windows 默认 `powershell`,其他平台默认 `bash` | +| `tool_timeout_sec` | 工具执行超时秒数 | + +## `context` 字段 -### `context` 字段 +### `context.compact` | 字段 | 说明 | |------|------| | `context.compact.manual_strategy` | `/compact` 手动压缩策略,支持 `keep_recent` / `full_replace` | -| `context.compact.manual_keep_recent_messages` | `keep_recent` 策略下保留的最近消息数 | -| `context.compact.micro_compact_retained_tool_spans` | read-time micro compact 默认保留原始内容的最近可压缩工具块数量,默认 `6` | -| `context.compact.read_time_max_message_spans` | context 读时保留的 message span 上限,用于降低“继续”时较早文件读取结果被过早裁掉的风险 | +| `context.compact.manual_keep_recent_messages` | `keep_recent` 下保留的最近消息数 | +| `context.compact.micro_compact_retained_tool_spans` | read-time micro compact 默认保留原始内容的最近工具块数量 | +| `context.compact.read_time_max_message_spans` | context 构建时保留的 message span 上限 | | `context.compact.max_summary_chars` | compact summary 最大字符数 | | `context.compact.micro_compact_disabled` | 是否关闭默认启用的 micro compact | -| `context.auto_compact.enabled` | 是否启用自动压缩 | -| `context.auto_compact.input_token_threshold` | 自动压缩输入 token 阈值 | -| `context.auto_compact.reserve_tokens` | 自动阈值推导时预留 token 缓冲(`resolved_threshold = context_window - reserve_tokens`) | -| `context.auto_compact.fallback_input_token_threshold` | 自动推导失败时使用的保底阈值 | -默认 pin 仅对 `filesystem_write_file` 与 `filesystem_edit` 这类文件修改工具生效,用于保留关键产物文件的最近结果;`.env*` 不参与默认 pin,避免敏感内容在上下文中保留更久。 - -### `runtime` 字段 +### `context.budget` | 字段 | 说明 | |------|------| -| `runtime.max_no_progress_streak` | 连续”无进展”轮次熔断阈值,默认 `3`;streak 达到 `limit-1`(默认第 2 轮)时向模型注入一次系统级纠偏提示,达到 `limit`(默认第 3 轮)时终止运行 | -| `runtime.max_repeat_cycle_streak` | 连续“重复调用同一工具参数”轮次熔断阈值,默认 `3`;达到阈值后终止运行 | -| `runtime.assets.max_session_asset_bytes` | 单个 `session_asset` 最大原始字节数,默认 `20971520`(20 MiB);`0` 或未配置时回退默认值 | -| `runtime.assets.max_session_assets_total_bytes` | 单次请求可携带的 `session_asset` 原始总字节上限,默认 `20971520`(20 MiB);`0` 或未配置时回退默认值 | +| `context.budget.prompt_budget` | 显式输入预算;`> 0` 时直接使用,`0` 表示自动推导 | +| `context.budget.reserve_tokens` | 自动推导输入预算时,从模型窗口中预留给输出、tool call、system prompt 的缓冲 | +| `context.budget.fallback_prompt_budget` | 模型窗口不可用或推导失败时使用的保底输入预算 | +| `context.budget.max_reactive_compacts` | 单次 `Run()` 内允许的 reactive compact 最大次数 | -### `tools` 字段 +## Budget 解析规则 -| 字段 | 说明 | -|------|------| -| `tools.webfetch.max_response_bytes` | WebFetch 最大响应字节数 | -| `tools.webfetch.supported_content_types` | WebFetch 允许的内容类型 | -| `tools.mcp.servers` | MCP server 列表 | +NeoCode 已不再使用旧的 `auto_compact` 阈值语义,当前统一使用 `context.budget`: -## 不写入 `config.yaml` 的字段 +1. `context.budget.prompt_budget > 0` 时,直接使用显式预算。 +2. `context.budget.prompt_budget <= 0` 时,系统尝试基于当前 provider/model 的 `ContextWindow` 自动推导。 +3. 自动推导公式为: -以下内容不允许写入主配置文件: +```text +prompt_budget = context_window - reserve_tokens +``` -- `providers` -- `provider_overrides` -- `workdir` -- `default_workdir` -- `base_url` -- `api_key_env` -- `models` +4. 如果当前 provider 选择无效、catalog snapshot 查询失败、模型缺少可用 `ContextWindow`,或 `ContextWindow <= reserve_tokens`,则回退到 `fallback_prompt_budget`。 + +## 配置结构升级 + +启动时会在严格解析 `config.yaml` 前执行一次结构升级: + +- 仅当检测到 `context.auto_compact` 时,自动迁移为 `context.budget`。 +- 迁移前会写入 `config.yaml.bak`,原配置内容保留在备份中。 +- 如果 `context.auto_compact` 与 `context.budget` 同时存在,程序会直接报错,避免猜测覆盖用户配置。 +- 主解析器仍只接受当前结构;迁移完成后不会在运行时兼容旧字段。 + +打包用户不需要额外执行迁移命令。`neocode migrate context-budget` 仅用于提前检查或手动修复配置文件。 -如果这些字段出现在 `config.yaml` 中,加载会直接失败,而不是被“自动迁移”或“悄悄清理”。 +## 预算闭环 + +当前发送链路采用固定闭环: + +```text +BuildRequest -> FreezeSnapshot -> EstimateInput -> DecideBudget -> (allow | compact | stop) +``` + +规则如下: + +- provider 发送前一定先做输入 token estimate。 +- 如果 estimate 没超过 `prompt_budget`,本轮允许发送。 +- 如果 estimate 首次超预算,先执行一次 `proactive` compact,然后重建请求并重新估算。 +- 如果 compact 后仍超预算,直接停止当前 run,并产出 `STOP_BUDGET_EXCEEDED`。 +- 如果 provider 返回 `context_too_long`,runtime 会进入 `reactive` compact 恢复链路,并重新进入同一预算闭环。 ## provider 策略 -NeoCode 采用“builtin provider + custom provider”双来源模型。 +NeoCode 采用 “builtin provider + custom provider” 双来源模型。 ### builtin provider -builtin provider 由代码内置,集中定义在: +内置 provider 定义于: ```text internal/config/builtin_providers.go @@ -145,7 +147,7 @@ internal/config/builtin_providers.go ### custom provider -custom provider 通过单独文件声明,而不是写进 `config.yaml`: +自定义 provider 通过单独文件声明,而不是写入 `config.yaml`: ```yaml name: company-gateway @@ -158,51 +160,19 @@ chat_endpoint_path: /chat/completions discovery_endpoint_path: /models ``` -`model_source` 语义如下: - -- `discover`(默认):通过 discovery(如 `/models`)拉取模型列表。 -- `manual`:不触发 discovery,优先使用 `models` 中声明的模型列表。 - -`chat_api_mode`(仅 `openaicompat` 生效)语义如下: - -- `chat_completions`:按 Chat Completions 协议发送请求。 -- `responses`:按 Responses 协议发送请求。 -- 省略时按默认 `chat_completions` 处理;`chat_endpoint_path` 仅负责路由,不再决定协议模式。 - -`manual` 模式示例: - -```yaml -name: company-gateway-manual -driver: openaicompat -api_key_env: COMPANY_GATEWAY_API_KEY -model_source: manual -base_url: https://llm.example.com/v1 -chat_endpoint_path: /chat/completions -models: - - id: gpt-4o-mini - name: GPT-4o Mini - context_window: 128000 -``` - -迁移与兼容性说明: - -- 老配置未声明 `model_source` 时,默认按 `discover` 处理。 -- `manual` 模式下必须提供 `models`,否则会在加载/创建阶段报错。 -- `manual` 模式会忽略 discovery 相关字段(如 `discovery_endpoint_path`)。 -- `provider.yaml` 仅支持平铺字段:`name/driver/base_url/api_key_env/model_source/chat_api_mode/chat_endpoint_path/discovery_endpoint_path/models`。 - -## Auto Compact 失败与校验补充 +## 不写入 `config.yaml` 的字段 -- 当 `context.auto_compact.input_token_threshold <= 0` 时,如果当前 provider 选择无效、catalog snapshot 查询失败、模型缺少可用的 `ContextWindow`,或 `ContextWindow <= reserve_tokens`,系统会回退到 `fallback_input_token_threshold`,不会静默关闭 auto compact。 -- `~/.neocode/providers//provider.yaml` 中的 `models[].id` 必须非空。 -- `models[].context_window` 和 `models[].max_output_tokens` 如果显式配置,必须大于 `0`。 -- `models` 中重复的模型 `id` 会在加载 `provider.yaml` 时直接报错。 +以下内容不允许写入主配置文件: -文件路径: +- `providers` +- `provider_overrides` +- `workdir` +- `default_workdir` +- `base_url` +- `api_key_env` +- `models` -```text -~/.neocode/providers/company-gateway/provider.yaml -``` +如果这些字段出现在 `config.yaml` 中,加载会直接失败,不会自动迁移或清理。 ## 环境变量 @@ -231,19 +201,6 @@ $env:OPENAI_API_KEY = "sk-..." $env:GEMINI_API_KEY = "AI..." ``` -## 启动时的选择修正 - -`config.yaml` 里的 `selected_provider/current_model` 表达的是“用户上次保存的选择状态”。 - -启动时系统还会进行选择校验与必要修正;若 driver 不受支持会报错并中止。因此需要区分两件事: - -- 配置快照结构合法 -- 当前选择已经可直接运行 - -前者由 `config.ValidateSnapshot()` 保证,后者由 `internal/config/state.Service.EnsureSelection()` 保证。 - -不要把这两层职责混在一起理解。 - ## CLI 运行参数覆盖 工作目录不写入 `config.yaml`,只通过启动参数覆盖: @@ -254,12 +211,9 @@ go run ./cmd/neocode --workdir /path/to/workspace 说明: -- `--workdir` 只影响本次进程 +- `--workdir` 只影响当前进程 - 不会回写到 `config.yaml` - 工具根目录与 session 隔离都会使用该工作区 -- TUI 默认通过本地 Gateway(优先 IPC)转发 runtime 请求 -- 启动时会先探测本地网关;若未运行会自动后台拉起并等待就绪 -- 若自动拉起后仍连接或握手失败会直接退出(Fail Fast) ## 常见错误 @@ -272,7 +226,9 @@ go run ./cmd/neocode --workdir /path/to/workspace - `providers` - `provider_overrides` -当前版本会直接报未知字段错误。处理方式是手动删除这些字段,而不是等待程序自动迁移。 +当前版本会直接报未知字段或结构不匹配错误。处理方式是手动删除旧字段,而不是等待程序自动兼容。 + +`context.auto_compact` 是例外:如果配置中只存在旧预算块,启动时会自动迁移为 `context.budget`;如果新旧预算块同时存在,则需要手动合并后再启动。 ### API Key 未设置 @@ -282,10 +238,4 @@ go run ./cmd/neocode --workdir /path/to/workspace config: environment variable OPENAI_API_KEY is empty ``` -处理方式:先在当前 shell 中设置对应环境变量,再启动 NeoCode。 - -## 相关文档 - -- [添加 Provider](./adding-providers.md) -- [配置管理详细设计](../config-management-detail-design.md) -- [Context Compact](../context-compact.md) +先在当前 shell 中设置对应环境变量,再启动 NeoCode。 diff --git a/docs/runtime-provider-event-flow.md b/docs/runtime-provider-event-flow.md index d08da197..de876971 100644 --- a/docs/runtime-provider-event-flow.md +++ b/docs/runtime-provider-event-flow.md @@ -2,7 +2,7 @@ ## Runtime 事件类型 -当前 runtime 对外暴露一组小而稳定的事件(1A 硬切后不再保留旧事件镜像): +当前 runtime 对外暴露一组稳定事件: - `user_message` - `agent_chunk` @@ -17,6 +17,8 @@ - `provider_retry` - `permission_requested` - `permission_resolved` +- `budget_checked` +- `ledger_reconciled` - `token_usage` - `skill_activated` - `skill_deactivated` @@ -25,124 +27,121 @@ - `compact_applied` - `compact_error` -这三类 compact 事件同时用于 `manual`、`auto` 与 `reactive` 三种来源,调用方可通过 payload 中的 `trigger_mode` 区分。 - ## ReAct 主循环 -1. 加载目标会话或创建新会话。 -2. 追加最新的用户消息。 -3. 读取最新配置快照。 -4. 调用 `context.Builder` 生成本轮请求使用的 `system prompt` 和消息上下文。 -5. 如命中 token 阈值自动压缩建议,则先执行一次 compact,再在同一轮内重建请求。 -6. 冻结当前 turn 的 `provider / model / tools / workdir / request` 快照。 -7. 调用 `Provider.Generate`,并把流式事件桥接给 TUI。 -8. 如 provider 返回“上下文过长”错误,则触发 `reactive` compact,并在同一 run 内最多做 3 次逐步降级的恢复尝试。 -9. 保存 assistant 完整回复。 -10. 执行返回的工具调用,并保存每一个工具结果。 -11. 如果最终 assistant 回复后没有后续工具调用,则在 runtime 收口处安排一次后台 memo 自动提取。 -12. 如果仍需继续推理,则进入下一轮;否则结束。 - -补充说明: -- runtime 不再设置内部 `max_loops` 停止条件;单次 run 仅在拿到最终 assistant 回复、遇到错误或收到外部取消时结束。 -- 由于 session 锁覆盖整个 run 生命周期,同一会话如果持续陷入工具调用循环,会一直占用该会话直到模型自行收口、报错或被取消。 - -### Memo 自动提取调度 - -- 自动提取只在最终 assistant 回复完成且当前轮没有后续工具调用时调度。 -- 如果本次 `Run` 已成功调用 `memo_remember`,则不再安排自动提取,避免与显式写入重复。 -- runtime 只负责在结束点调度,不直接执行提取逻辑;实际 debounce、尾随执行与持久化去重由 `internal/memo` 内部处理。 -- 调度时会绑定当次 provider/model 快照,后台任务不会重新读取全局当前配置,避免把历史会话消息发送到后续切换后的 provider。 -- 自动提取失败只记日志,不额外发出 TUI 事件,也不影响主链路完成。 -- `memo` 的最近消息窗口会复用 `internal/context` 的只读投影规则,只保留 provider-safe 的消息序列。 -- assistant 含 `tool_calls` 时,只有在窗口内能同时保留对应 `tool` 响应时才会注入;缺响应、空内容或已被 micro compact 清空的 assistant/tool 片段会整组丢弃,保留项会先投影为模型可消费的结构化文本。 -- recent window 的总消息数有硬预算:`min(limit*2, 24)`;超过预算的整段 tool span 会被跳过,避免窗口体积失控。 -- 进入 `memo` 提取前,tool 文本会二次收敛为 `content_excerpt`,并按 `600` rune 上限截断。 - -补充约束: -- 同一 turn 内的 provider retry 只重放冻结后的 turn 快照,不会重新读取配置。 -- `auto compact` 与 `reactive compact` 都不额外消耗 reasoning turn。 -- 权限审批等待由 `internal/runtime/approval` 负责 request 生命周期,runtime 自己负责事件发射与 tool 重试编排。 - -### Context Builder 输入与职责 - -- `runtime` 只向 `context.Builder` 传递本轮所需元数据: - - 历史消息 - - `workdir` - - `shell` - - 当前 `provider` - - 当前 `model` - - 会话累计输入 token 数(`SessionInputTokens`) - - 会话累计输出 token 数(`SessionOutputTokens`) - - 自动压缩阈值(`AutoCompactThreshold`) -- `context.Builder` 负责统一组装: - - 固定核心 system prompt sections(静态模板由 `internal/promptasset` 通过 `go:embed` 提供) - - 从 `workdir` 向上发现的 `AGENTS.md` - - `Task State` - - `Todo State` - - `Skills` - - 可选 `Memo` - - 系统状态摘要(`workdir` / `shell` / `provider` / `model` / git branch / git dirty) - - 裁剪后的历史消息 - - 自动压缩决策(`BuildResult.AutoCompactSuggested`) -- `runtime` 不直接读取规则文件,也不直接查询 git 状态。 -- `provider` 只消费最终生成的 `SystemPrompt`、消息列表和工具 schema,不感知上下文来源。 - -### 静态模板与动态拼装边界 - -- `internal/promptasset` 负责承载受版本管理的静态 prompt 模板资产,并通过 `go:embed` 编译进程序。 -- `context` 继续负责主会话 system prompt 的 section 顺序、动态 section 注入与最终渲染。 -- `runtime` 继续负责在特定条件下注入 reminder,但静态 reminder 文案本身来自模板资产。 -- `subagent` 继续负责角色策略、工具约束与输出契约,只有角色基础 prompt 抽离为模板资产。 - -### System Prompt 注入顺序 - -当前 `system prompt` 按以下顺序拼装: - -1. 固定核心 sections -2. `Project Rules` section -3. `Task State` section -4. `Todo State` section -5. `Skills` section -6. 可选 `Memo` section -7. `System State` section - -其中: - -- 规则文件只支持大写文件名 `AGENTS.md` -- 多份命中结果按“从全局到局部”的顺序注入 -- git 只注入摘要,不注入完整 `git status` -- 各 section 统一由 `internal/context` 内部的 `renderPromptSection` 和 `composeSystemPrompt` 渲染,`runtime` 仍只消费最终字符串 +单次 run 的主链路为: + +1. 加载或创建会话 +2. 追加最新用户消息 +3. 读取当前配置快照 +4. 调用 `context.Builder` 构建 provider-facing request +5. 冻结当前 turn 的 `provider / model / tools / workdir / request / prompt_budget` +6. 调用 provider 的 `EstimateInputTokens` +7. 由 budget control plane 输出 `allow | compact | stop` +8. 若为 `compact`,先执行 `proactive` compact,再在同一 run 内重建 request +9. 若为 `allow`,调用 `Provider.Generate` +10. 若 provider 返回 `context_too_long`,触发 `reactive` compact,并重新进入预算闭环 +11. 正常返回后,对 usage 做 reconcile +12. 追加 assistant 消息 +13. 执行工具调用并写回 tool result +14. 如仍需继续推理,进入下一轮;否则结束 + +## Budget 控制面 + +runtime 不再消费旧的 builder 压缩建议,而是使用冻结快照上的显式预算决策。 + +### `budget_checked` + +每次发送前预算判定都会发出 `budget_checked`: + +- `attempt_seq` +- `request_hash` +- `action` +- `reason` +- `estimated_input_tokens` +- `prompt_budget` +- `estimate_source` + +语义: + +- `allow`:本轮请求在预算内 +- `compact`:首次超预算,需要先压缩 +- `stop`:压缩后仍超预算,停止当前 run + +## Context Builder 职责 + +`runtime` 向 `context.Builder` 传入: + +- 历史消息 +- `workdir` +- `shell` +- 当前 `provider` +- 当前 `model` +- 会话累计输入 token +- 会话累计输出 token + +`context.Builder` 只负责: + +- 组装 `system prompt` +- 读取 `AGENTS.md` +- 注入 `Task State` / `Todo State` / `Skills` / `Memo` +- 执行 read-time trim 和 micro compact +- 输出最终 `SystemPrompt` 与消息列表 + +`context.Builder` 不再负责: + +- token budget 判断 +- proactive compact 触发 +- 旧的 builder 压缩建议布尔值 ## 流式桥接 -- Provider 发出 `StreamEvent` -- `internal/provider` 根包仅保留最小事件发送 helper;协议流解析留在各自 driver 子包 +- provider 发出 `StreamEvent` +- `internal/provider` 只处理协议差异 - `internal/runtime/streaming` 统一累积文本、tool call 增量和 `message_done` -- runtime 将累积过程映射成 `RuntimeEvent` -- TUI 使用 Bubble Tea `Cmd` 监听事件,并在处理完成后继续订阅 -- `provider.GenerateText` 只在上游 `Generate` 成功返回时,才把缺失 `message_done` 视为流式中断。 -- 如果 provider 在真正开始流式输出前直接返回 HTTP/ProviderError,则优先保留原始错误,不再额外包装成 `message_done` 缺失。 +- runtime 将结果映射成 `RuntimeEvent` +- TUI 通过 Bubble Tea `Cmd` 监听这些事件 + +## Usage 对账 + +provider 返回后,runtime 会执行显式的账本调和。 + +### `ledger_reconciled` + +每轮 provider 调用完成后都会发出: + +- `attempt_seq` +- `request_hash` +- `input_tokens` +- `input_source` +- `output_tokens` +- `output_source` +- `has_unknown_usage` -同一套流式累积逻辑同时复用于: -- 普通 `Run()` 的 assistant 回复收敛 -- compact summary 生成阶段的 provider 输出消费 +规则: -## Token 计量 +- provider 返回 usage 时,`input_source=observed`,`output_source=observed` +- provider usage 缺失时,输入侧回退到发送前 estimate,因此 `input_source=estimated` +- provider usage 缺失时,输出侧不伪装成观测值,因此 `output_source=unknown` +- 只要出现过未知 output usage,会话级 `HasUnknownUsage` 会被置为 `true` -runtime 在转发 provider 流式事件时,从 `MessageDone` 事件中提取 `Usage`(`InputTokens`、`OutputTokens`),累积到会话级计数器,并发出 `token_usage` 事件供 TUI 消费。 +### `token_usage` -`token_usage` payload 包含: +`token_usage` 继续面向 TUI 提供单轮和会话累计数据,并新增来源标签: -- `input_tokens`:本次调用输入 token -- `output_tokens`:本次调用输出 token -- `session_input_tokens`:会话累计输入 token -- `session_output_tokens`:会话累计输出 token +- `input_tokens` +- `output_tokens` +- `input_source` +- `output_source` +- `has_unknown_usage` +- `session_input_tokens` +- `session_output_tokens` ## 持久化时机 -- 用户消息提交后保存 -- assistant 完整回复后保存 -- 每个工具结果完成后保存 -- 避免在高频 UI 刷新路径中做磁盘 I/O +- 用户消息提交后立即持久化 +- assistant 完整回复后立即持久化 +- 每个工具结果完成后立即持久化 +- compact 成功后通过 `ReplaceTranscript` 原子重写 transcript -会话 JSON 结构、工作区分桶以及 token 计数持久化约束统一见 [Session 持久化设计](./session-persistence-design.md)。 +会话级 token totals 和 `HasUnknownUsage` 由 `runtime` 统一维护,并在持久化层落盘。 diff --git a/internal/app/bootstrap.go b/internal/app/bootstrap.go index 325bc8c9..9e2496ce 100644 --- a/internal/app/bootstrap.go +++ b/internal/app/bootstrap.go @@ -29,7 +29,7 @@ import ( "neo-code/internal/tools/spawnsubagent" "neo-code/internal/tools/todo" "neo-code/internal/tools/webfetch" - "neo-code/internal/tui" + tuiapp "neo-code/internal/tui/core/app" "neo-code/internal/tui/services" ) @@ -40,7 +40,7 @@ var ( setConsoleInputCodePage = platformSetConsoleInputCodePage buildToolManagerFunc = buildToolManager newRemoteRuntimeAdapter = defaultNewRemoteRuntimeAdapter - newTUIWithMemo = tui.NewWithMemo + newTUIWithMemo = tuiapp.NewWithMemo cleanupExpiredSessions = func( ctx context.Context, store agentsession.Store, @@ -159,7 +159,7 @@ func BuildGatewayServerDeps(ctx context.Context, opts BootstrapOptions) (Runtime // Session Store 绑定到启动时的 workdir 哈希分桶,整个应用生命周期内不可变。 // 这意味着所有会话都归属到启动时指定的项目目录下,运行时不会因配置变更而迁移存储位置。 - sessionStore = agentsession.NewStore(sharedDeps.ConfigManager.BaseDir(), cfg.Workdir) + sessionStore = agentsession.NewSQLiteStore(sharedDeps.ConfigManager.BaseDir(), cfg.Workdir) // 启动时自动清理过期会话,避免数据库无限膨胀。 if _, err := cleanupExpiredSessions(ctx, sessionStore, agentsession.DefaultSessionMaxAge); err != nil { @@ -196,13 +196,13 @@ func BuildGatewayServerDeps(ctx context.Context, opts BootstrapOptions) (Runtime runtimeSvc.SetSessionAssetStore(sessionStore) runtimeSvc.SetUserInputPreparer(agentruntime.NewSessionInputPreparer(sessionStore, sessionStore)) runtimeSvc.SetSkillsRegistry(buildSkillsRegistry(ctx, sharedDeps.ConfigManager.BaseDir())) - runtimeSvc.SetAutoCompactThresholdResolver(runtimeAutoCompactThresholdResolverFunc( - func(ctx context.Context, cfg config.Config) (int, error) { - resolution, err := configstate.ResolveAutoCompactThreshold(ctx, cfg, modelCatalogs) + runtimeSvc.SetBudgetResolver(runtimeBudgetResolverFunc( + func(ctx context.Context, cfg config.Config) (int, string, error) { + resolution, err := configstate.ResolvePromptBudget(ctx, cfg, modelCatalogs) if err != nil { - return 0, err + return 0, "", err } - return resolution.Threshold, nil + return resolution.PromptBudget, string(resolution.Source), nil }, )) @@ -232,11 +232,6 @@ func BuildGatewayServerDeps(ctx context.Context, opts BootstrapOptions) (Runtime }, nil } -// BuildRuntime 兼容旧入口,内部转发到 BuildGatewayServerDeps。 -func BuildRuntime(ctx context.Context, opts BootstrapOptions) (RuntimeBundle, error) { - return BuildGatewayServerDeps(ctx, opts) -} - // NewProgram 基于共享运行时依赖构建并返回 TUI 程序,同时返回退出时应调用的资源清理函数。 func NewProgram(ctx context.Context, opts BootstrapOptions) (*tea.Program, func() error, error) { bundle, err := BuildTUIClientDeps(ctx, opts) @@ -423,9 +418,9 @@ func (f textGenAdapter) Generate(ctx context.Context, prompt string, msgs []prov return f(ctx, prompt, msgs) } -type runtimeAutoCompactThresholdResolverFunc func(ctx context.Context, cfg config.Config) (int, error) +type runtimeBudgetResolverFunc func(ctx context.Context, cfg config.Config) (int, string, error) -func (f runtimeAutoCompactThresholdResolverFunc) ResolveAutoCompactThreshold(ctx context.Context, cfg config.Config) (int, error) { +func (f runtimeBudgetResolverFunc) ResolvePromptBudget(ctx context.Context, cfg config.Config) (int, string, error) { return f(ctx, cfg) } diff --git a/internal/app/bootstrap_test.go b/internal/app/bootstrap_test.go index 6e56f1dd..04fa0a9d 100644 --- a/internal/app/bootstrap_test.go +++ b/internal/app/bootstrap_test.go @@ -134,7 +134,7 @@ func TestBuildRuntimeRejectsUnsupportedSelectedProviderDriverOnStartup(t *testin t.Fatalf("write provider config: %v", err) } - _, err := BuildRuntime(context.Background(), BootstrapOptions{}) + _, err := BuildGatewayServerDeps(context.Background(), BootstrapOptions{}) if !errors.Is(err, configstate.ErrDriverUnsupported) { t.Fatalf("expected ErrDriverUnsupported, got %v", err) } @@ -762,9 +762,9 @@ func TestBuildRuntimeUsesWorkdirOverride(t *testing.T) { t.Fatalf("mkdir override workdir: %v", err) } - bundle, err := BuildRuntime(context.Background(), BootstrapOptions{Workdir: override}) + bundle, err := BuildGatewayServerDeps(context.Background(), BootstrapOptions{Workdir: override}) if err != nil { - t.Fatalf("BuildRuntime() error = %v", err) + t.Fatalf("BuildGatewayServerDeps() error = %v", err) } if bundle.Config.Workdir != filepath.Clean(override) { t.Fatalf("expected workdir %q, got %q", filepath.Clean(override), bundle.Config.Workdir) @@ -786,9 +786,9 @@ func TestBuildRuntimeSucceedsWhenSkillsRootMissing(t *testing.T) { t.Setenv("HOME", home) t.Setenv("USERPROFILE", home) - bundle, err := BuildRuntime(context.Background(), BootstrapOptions{}) + bundle, err := BuildGatewayServerDeps(context.Background(), BootstrapOptions{}) if err != nil { - t.Fatalf("BuildRuntime() error = %v", err) + t.Fatalf("BuildGatewayServerDeps() error = %v", err) } if bundle.Close != nil { t.Cleanup(func() { @@ -808,7 +808,7 @@ func TestBuildRuntimeSucceedsWhenSkillsRootMissing(t *testing.T) { t.Fatalf("expected runtime to expose ActivateSessionSkill") } - store := agentsession.NewStore(bundle.ConfigManager.BaseDir(), bundle.Config.Workdir) + store := agentsession.NewSQLiteStore(bundle.ConfigManager.BaseDir(), bundle.Config.Workdir) t.Cleanup(func() { if err := store.Close(); err != nil { t.Fatalf("store.Close() error = %v", err) @@ -820,7 +820,7 @@ func TestBuildRuntimeSucceedsWhenSkillsRootMissing(t *testing.T) { Title: session.Title, CreatedAt: session.CreatedAt, UpdatedAt: session.UpdatedAt, - Workdir: session.Workdir, + Head: session.HeadSnapshot(), }) if err != nil { t.Fatalf("save session: %v", err) @@ -856,9 +856,9 @@ func TestBuildRuntimeInjectsSkillsRegistryWhenRootExists(t *testing.T) { t.Fatalf("write SKILL.md: %v", err) } - bundle, err := BuildRuntime(context.Background(), BootstrapOptions{}) + bundle, err := BuildGatewayServerDeps(context.Background(), BootstrapOptions{}) if err != nil { - t.Fatalf("BuildRuntime() error = %v", err) + t.Fatalf("BuildGatewayServerDeps() error = %v", err) } if bundle.Close != nil { t.Cleanup(func() { @@ -868,7 +868,7 @@ func TestBuildRuntimeInjectsSkillsRegistryWhenRootExists(t *testing.T) { }) } - store := agentsession.NewStore(bundle.ConfigManager.BaseDir(), bundle.Config.Workdir) + store := agentsession.NewSQLiteStore(bundle.ConfigManager.BaseDir(), bundle.Config.Workdir) t.Cleanup(func() { if err := store.Close(); err != nil { t.Fatalf("store.Close() error = %v", err) @@ -880,7 +880,7 @@ func TestBuildRuntimeInjectsSkillsRegistryWhenRootExists(t *testing.T) { Title: session.Title, CreatedAt: session.CreatedAt, UpdatedAt: session.UpdatedAt, - Workdir: session.Workdir, + Head: session.HeadSnapshot(), }) if err != nil { t.Fatalf("save session: %v", err) @@ -930,7 +930,7 @@ func TestBuildRuntimeRejectsInvalidWorkdirOverride(t *testing.T) { t.Setenv("USERPROFILE", home) invalid := filepath.Join(t.TempDir(), "missing", "中文") - _, err := BuildRuntime(context.Background(), BootstrapOptions{Workdir: invalid}) + _, err := BuildGatewayServerDeps(context.Background(), BootstrapOptions{Workdir: invalid}) if err == nil || !strings.Contains(strings.ToLower(err.Error()), "resolve workdir") { t.Fatalf("expected resolve workdir error, got %v", err) } @@ -952,7 +952,7 @@ func TestBuildRuntimeRejectsInvalidConfigFile(t *testing.T) { t.Fatalf("write config: %v", err) } - _, err := BuildRuntime(context.Background(), BootstrapOptions{}) + _, err := BuildGatewayServerDeps(context.Background(), BootstrapOptions{}) if err == nil || !strings.Contains(err.Error(), "workdir not found") { t.Fatalf("expected legacy config error, got %v", err) } @@ -985,7 +985,7 @@ func TestBuildRuntimeRejectsUnsupportedMCPSource(t *testing.T) { t.Fatalf("write config: %v", err) } - _, err := BuildRuntime(context.Background(), BootstrapOptions{}) + _, err := BuildGatewayServerDeps(context.Background(), BootstrapOptions{}) if err == nil || !strings.Contains(strings.ToLower(err.Error()), "not supported") { t.Fatalf("expected unsupported mcp source validation error, got %v", err) } @@ -1034,7 +1034,7 @@ func TestBuildRuntimeCleansResourcesWhenToolManagerBuildFails(t *testing.T) { return nil, errors.New("build tool manager failed") } - _, err := BuildRuntime(context.Background(), BootstrapOptions{}) + _, err := BuildGatewayServerDeps(context.Background(), BootstrapOptions{}) if err == nil || !strings.Contains(err.Error(), "build tool manager failed") { t.Fatalf("expected tool manager build error, got %v", err) } @@ -1065,9 +1065,9 @@ func TestBuildRuntimeLogsSessionCleanupWarningAndContinues(t *testing.T) { log.SetOutput(&logBuffer) t.Cleanup(func() { log.SetOutput(originalLogWriter) }) - bundle, err := BuildRuntime(context.Background(), BootstrapOptions{}) + bundle, err := BuildGatewayServerDeps(context.Background(), BootstrapOptions{}) if err != nil { - t.Fatalf("BuildRuntime() error = %v", err) + t.Fatalf("BuildGatewayServerDeps() error = %v", err) } if bundle.Close != nil { defer bundle.Close() @@ -1531,6 +1531,9 @@ func TestBuildTUIClientDepsSkipsLocalRuntimeStack(t *testing.T) { func TestNewProgramUsesRemoteRuntimeAdapter(t *testing.T) { disableBuiltinProviderAPIKeys(t) + home := t.TempDir() + t.Setenv("HOME", home) + t.Setenv("USERPROFILE", home) originalFactory := newRemoteRuntimeAdapter t.Cleanup(func() { newRemoteRuntimeAdapter = originalFactory }) @@ -1562,6 +1565,9 @@ func TestNewProgramUsesRemoteRuntimeAdapter(t *testing.T) { func TestNewProgramFailsFastWhenRemoteAdapterInitFails(t *testing.T) { disableBuiltinProviderAPIKeys(t) + home := t.TempDir() + t.Setenv("HOME", home) + t.Setenv("USERPROFILE", home) originalFactory := newRemoteRuntimeAdapter t.Cleanup(func() { newRemoteRuntimeAdapter = originalFactory }) @@ -1823,6 +1829,18 @@ type stubMemoProvider struct { generate func(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error } +func (s *stubMemoProvider) EstimateInputTokens( + ctx context.Context, + req providertypes.GenerateRequest, +) (providertypes.BudgetEstimate, error) { + _ = ctx + return providertypes.BudgetEstimate{ + EstimatedInputTokens: provider.EstimateTextTokens(req.SystemPrompt), + EstimateSource: provider.EstimateSourceLocal, + Accurate: false, + }, nil +} + func (s *stubMemoProvider) Generate( ctx context.Context, req providertypes.GenerateRequest, diff --git a/internal/cli/gateway_commands.go b/internal/cli/gateway_commands.go index e7e92bfb..a359ce9c 100644 --- a/internal/cli/gateway_commands.go +++ b/internal/cli/gateway_commands.go @@ -34,7 +34,7 @@ var ( newGatewayServer = defaultNewGatewayServer newGatewayNetwork = defaultNewGatewayNetworkServer dispatchURLThroughIPC = urlscheme.Dispatch - newAuthManager = gatewayauth.NewManager + newAuthManager = defaultNewAuthManager loadAuthToken = loadGatewayAuthToken exitProcess = os.Exit writeDispatchError = writeURLDispatchErrorOutput @@ -97,6 +97,11 @@ type gatewayNetworkServer interface { Close(ctx context.Context) error } +// defaultNewAuthManager 创建默认网关认证器,并把具体持久化实现收敛在 CLI 装配层内部。 +func defaultNewAuthManager(path string) (gateway.TokenAuthenticator, error) { + return gatewayauth.NewManager(path) +} + // newGatewayCommand 创建并返回网关子命令,负责启动本地 Gateway 进程。 func newGatewayCommand() *cobra.Command { options := &gatewayCommandOptions{} diff --git a/internal/cli/migrate_command.go b/internal/cli/migrate_command.go new file mode 100644 index 00000000..544ae758 --- /dev/null +++ b/internal/cli/migrate_command.go @@ -0,0 +1,67 @@ +package cli + +import ( + "fmt" + "strings" + + "github.com/spf13/cobra" + + "neo-code/internal/config" +) + +type migrateContextBudgetOptions struct { + ConfigPath string + DryRun bool +} + +// newMigrateCommand 构建一次性迁移命令集合,迁移逻辑不接入主配置加载路径。 +func newMigrateCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "migrate", + Short: "Run one-time local data migrations", + SilenceUsage: true, + Args: cobra.NoArgs, + } + cmd.AddCommand(newMigrateContextBudgetCommand()) + return cmd +} + +// newMigrateContextBudgetCommand 构建 context.auto_compact 到 context.budget 的显式迁移命令。 +func newMigrateContextBudgetCommand() *cobra.Command { + options := &migrateContextBudgetOptions{} + cmd := &cobra.Command{ + Use: "context-budget", + Short: "Migrate context.auto_compact to context.budget", + SilenceUsage: true, + Args: cobra.NoArgs, + Annotations: map[string]string{ + commandAnnotationSkipGlobalPreload: "true", + commandAnnotationSkipSilentUpdateCheck: "true", + }, + RunE: func(cmd *cobra.Command, args []string) error { + result, err := config.MigrateContextBudgetConfigFile(strings.TrimSpace(options.ConfigPath), options.DryRun) + if err != nil { + return err + } + printContextBudgetMigrationResult(cmd, result, options.DryRun) + return nil + }, + } + cmd.Flags().StringVar(&options.ConfigPath, "config", "", "config.yaml path (default ~/.neocode/config.yaml)") + cmd.Flags().BoolVar(&options.DryRun, "dry-run", false, "check migration without writing files") + return cmd +} + +// printContextBudgetMigrationResult 输出迁移结果,确保 dry-run 和真实写入提示保持一致。 +func printContextBudgetMigrationResult(cmd *cobra.Command, result config.ContextBudgetMigrationResult, dryRun bool) { + writer := cmd.OutOrStdout() + if !result.Changed { + _, _ = fmt.Fprintf(writer, "跳过: %s (%s)\n", result.Path, result.Reason) + return + } + if dryRun { + _, _ = fmt.Fprintf(writer, "[DRY-RUN] 将迁移 %s\n", result.Path) + return + } + _, _ = fmt.Fprintf(writer, "已迁移 %s (备份: %s)\n", result.Path, result.Backup) +} diff --git a/internal/cli/migrate_command_test.go b/internal/cli/migrate_command_test.go new file mode 100644 index 00000000..75880fa1 --- /dev/null +++ b/internal/cli/migrate_command_test.go @@ -0,0 +1,96 @@ +package cli + +import ( + "bytes" + "context" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestMigrateContextBudgetCommandDryRunSkipsGlobalHooks(t *testing.T) { + originalPreload := runGlobalPreload + originalSilentCheck := runSilentUpdateCheck + t.Cleanup(func() { + runGlobalPreload = originalPreload + runSilentUpdateCheck = originalSilentCheck + }) + + runGlobalPreload = func(context.Context) error { + t.Fatal("migrate command must not run global preload") + return nil + } + runSilentUpdateCheck = func(context.Context) { + t.Fatal("migrate command must not run silent update check") + } + + dir := t.TempDir() + target := filepath.Join(dir, "config.yaml") + original := "context:\n auto_compact:\n input_token_threshold: 120000\n" + if err := os.WriteFile(target, []byte(original), 0o644); err != nil { + t.Fatalf("write config: %v", err) + } + + var stdout bytes.Buffer + cmd := NewRootCommand() + cmd.SetOut(&stdout) + cmd.SetArgs([]string{"migrate", "context-budget", "--config", target, "--dry-run"}) + if err := cmd.ExecuteContext(context.Background()); err != nil { + t.Fatalf("ExecuteContext() error = %v", err) + } + if !strings.Contains(stdout.String(), "[DRY-RUN]") { + t.Fatalf("expected dry-run output, got %q", stdout.String()) + } + content, err := os.ReadFile(target) + if err != nil { + t.Fatalf("read config: %v", err) + } + if string(content) != original { + t.Fatalf("dry-run mutated config:\n%s", content) + } +} + +func TestMigrateContextBudgetCommandWritesBackup(t *testing.T) { + originalPreload := runGlobalPreload + originalSilentCheck := runSilentUpdateCheck + t.Cleanup(func() { + runGlobalPreload = originalPreload + runSilentUpdateCheck = originalSilentCheck + }) + runGlobalPreload = func(context.Context) error { return nil } + runSilentUpdateCheck = func(context.Context) {} + + dir := t.TempDir() + target := filepath.Join(dir, "config.yaml") + original := "context:\n auto_compact:\n reserve_tokens: 13000\n" + if err := os.WriteFile(target, []byte(original), 0o644); err != nil { + t.Fatalf("write config: %v", err) + } + + var stdout bytes.Buffer + cmd := NewRootCommand() + cmd.SetOut(&stdout) + cmd.SetArgs([]string{"migrate", "context-budget", "--config", target}) + if err := cmd.ExecuteContext(context.Background()); err != nil { + t.Fatalf("ExecuteContext() error = %v", err) + } + if !strings.Contains(stdout.String(), "已迁移") { + t.Fatalf("expected migrated output, got %q", stdout.String()) + } + + backup, err := os.ReadFile(target + ".bak") + if err != nil { + t.Fatalf("read backup: %v", err) + } + if string(backup) != original { + t.Fatalf("backup content mismatch:\n%s", backup) + } + migrated, err := os.ReadFile(target) + if err != nil { + t.Fatalf("read migrated config: %v", err) + } + if strings.Contains(string(migrated), "auto_compact") || !strings.Contains(string(migrated), "budget:") { + t.Fatalf("unexpected migrated config:\n%s", migrated) + } +} diff --git a/internal/cli/root.go b/internal/cli/root.go index e52b60b8..fc8f291a 100644 --- a/internal/cli/root.go +++ b/internal/cli/root.go @@ -26,8 +26,10 @@ var checkLatestRelease = updater.CheckLatest const silentUpdateCheckTimeout = 3 * time.Second const silentUpdateCheckDrainTimeout = 300 * time.Millisecond +const commandAnnotationSkipGlobalPreload = "neocode.skip_global_preload" +const commandAnnotationSkipSilentUpdateCheck = "neocode.skip_silent_update_check" -var ansiEscapeSequencePattern = regexp.MustCompile(`\x1b(?:\[[0-?]*[ -/]*[@-~]|\][^\x07]*(?:\x07|\x1b\\)|[@-Z\\-_])`) +var ansiEscapeSequencePattern = regexp.MustCompile(`\x1b(?:\[[0-?]*[ -/]*[@-~]|][^\x07]*(?:\x07|\x1b\\)|[@-Z\\-_])`) var ( silentUpdateCheckMu sync.Mutex @@ -83,6 +85,7 @@ func NewRootCommand() *cobra.Command { _ = settings.BindPFlag("workdir", cmd.PersistentFlags().Lookup("workdir")) cmd.AddCommand( newGatewayCommand(), + newMigrateCommand(), newURLDispatchCommand(), newUpdateCommand(), ) @@ -156,7 +159,7 @@ func defaultSilentUpdateCheck(ctx context.Context) { // shouldSkipGlobalPreload 判断当前子命令是否跳过全局预加载。 func shouldSkipGlobalPreload(cmd *cobra.Command) bool { - return normalizedCommandName(cmd) == "url-dispatch" + return normalizedCommandName(cmd) == "url-dispatch" || commandAnnotationEnabled(cmd, commandAnnotationSkipGlobalPreload) } // shouldSkipSilentUpdateCheck 判断当前子命令是否跳过静默更新检查。 @@ -165,7 +168,7 @@ func shouldSkipSilentUpdateCheck(cmd *cobra.Command) bool { case "url-dispatch", "update": return true default: - return false + return commandAnnotationEnabled(cmd, commandAnnotationSkipSilentUpdateCheck) } } @@ -190,6 +193,16 @@ func normalizedCommandName(cmd *cobra.Command) string { return strings.ToLower(strings.TrimSpace(cmd.Name())) } +// commandAnnotationEnabled 沿当前命令链查找布尔注解,供轻量命令跳过全局启动副作用。 +func commandAnnotationEnabled(cmd *cobra.Command, key string) bool { + for current := cmd; current != nil; current = current.Parent() { + if strings.EqualFold(strings.TrimSpace(current.Annotations[key]), "true") { + return true + } + } + return false +} + // setSilentUpdateCheckDone 原子地更新静默检查完成信号通道。 func setSilentUpdateCheckDone(done <-chan struct{}) { silentUpdateCheckMu.Lock() diff --git a/internal/cli/root_test.go b/internal/cli/root_test.go index 540ab369..69c3e01c 100644 --- a/internal/cli/root_test.go +++ b/internal/cli/root_test.go @@ -19,7 +19,6 @@ import ( "neo-code/internal/config" "neo-code/internal/gateway" "neo-code/internal/gateway/adapters/urlscheme" - gatewayauth "neo-code/internal/gateway/auth" "neo-code/internal/updater" ) @@ -295,8 +294,15 @@ func TestMustReadInheritedWorkdirBranches(t *testing.T) { func TestDefaultGatewayCommandRunnerSuccess(t *testing.T) { originalNewGatewayServer := newGatewayServer originalNewGatewayNetwork := newGatewayNetwork + originalBuildGatewayRuntimePort := buildGatewayRuntimePort + originalNewAuthManager := newAuthManager t.Cleanup(func() { newGatewayServer = originalNewGatewayServer }) t.Cleanup(func() { newGatewayNetwork = originalNewGatewayNetwork }) + t.Cleanup(func() { buildGatewayRuntimePort = originalBuildGatewayRuntimePort }) + t.Cleanup(func() { newAuthManager = originalNewAuthManager }) + prepareGatewayCommandRunnerTestEnv(t) + buildGatewayRuntimePort = stubGatewayRuntimePortBuilder() + newAuthManager = stubGatewayAuthManagerBuilder() server := &stubGatewayServer{listenAddress: "stub://gateway"} newGatewayServer = func(options gateway.ServerOptions) (gatewayServer, error) { @@ -328,7 +334,11 @@ func TestDefaultGatewayCommandRunnerSuccess(t *testing.T) { func TestDefaultGatewayCommandRunnerReturnsBuildRuntimePortError(t *testing.T) { originalBuildGatewayRuntimePort := buildGatewayRuntimePort + originalNewAuthManager := newAuthManager t.Cleanup(func() { buildGatewayRuntimePort = originalBuildGatewayRuntimePort }) + t.Cleanup(func() { newAuthManager = originalNewAuthManager }) + prepareGatewayCommandRunnerTestEnv(t) + newAuthManager = stubGatewayAuthManagerBuilder() buildGatewayRuntimePort = func(context.Context, string) (gateway.RuntimePort, func() error, error) { return nil, nil, errors.New("build runtime port failed") @@ -350,8 +360,15 @@ func TestDefaultGatewayCommandRunnerReturnsBuildRuntimePortError(t *testing.T) { func TestDefaultGatewayCommandRunnerReturnsConstructorError(t *testing.T) { originalNewGatewayServer := newGatewayServer originalNewGatewayNetwork := newGatewayNetwork + originalBuildGatewayRuntimePort := buildGatewayRuntimePort + originalNewAuthManager := newAuthManager t.Cleanup(func() { newGatewayServer = originalNewGatewayServer }) t.Cleanup(func() { newGatewayNetwork = originalNewGatewayNetwork }) + t.Cleanup(func() { buildGatewayRuntimePort = originalBuildGatewayRuntimePort }) + t.Cleanup(func() { newAuthManager = originalNewAuthManager }) + prepareGatewayCommandRunnerTestEnv(t) + buildGatewayRuntimePort = stubGatewayRuntimePortBuilder() + newAuthManager = stubGatewayAuthManagerBuilder() expected := errors.New("new gateway server failed") newGatewayServer = func(options gateway.ServerOptions) (gatewayServer, error) { @@ -372,6 +389,7 @@ func TestDefaultGatewayCommandRunnerReturnsConstructorError(t *testing.T) { } func TestDefaultGatewayCommandRunnerReturnsLoadConfigError(t *testing.T) { + prepareGatewayCommandRunnerTestEnv(t) ctx, cancel := context.WithCancel(context.Background()) cancel() err := defaultGatewayCommandRunner(ctx, gatewayCommandOptions{ @@ -388,11 +406,15 @@ func TestDefaultGatewayCommandRunnerReturnsAuthManagerError(t *testing.T) { originalNewGatewayServer := newGatewayServer originalNewGatewayNetwork := newGatewayNetwork originalNewAuthManager := newAuthManager + originalBuildGatewayRuntimePort := buildGatewayRuntimePort t.Cleanup(func() { newGatewayServer = originalNewGatewayServer }) t.Cleanup(func() { newGatewayNetwork = originalNewGatewayNetwork }) t.Cleanup(func() { newAuthManager = originalNewAuthManager }) + t.Cleanup(func() { buildGatewayRuntimePort = originalBuildGatewayRuntimePort }) + prepareGatewayCommandRunnerTestEnv(t) + buildGatewayRuntimePort = stubGatewayRuntimePortBuilder() - newAuthManager = func(string) (*gatewayauth.Manager, error) { + newAuthManager = func(string) (gateway.TokenAuthenticator, error) { return nil, errors.New("auth manager failed") } newGatewayServer = func(options gateway.ServerOptions) (gatewayServer, error) { @@ -415,8 +437,15 @@ func TestDefaultGatewayCommandRunnerReturnsAuthManagerError(t *testing.T) { func TestDefaultGatewayCommandRunnerReturnsServeError(t *testing.T) { originalNewGatewayServer := newGatewayServer originalNewGatewayNetwork := newGatewayNetwork + originalBuildGatewayRuntimePort := buildGatewayRuntimePort + originalNewAuthManager := newAuthManager t.Cleanup(func() { newGatewayServer = originalNewGatewayServer }) t.Cleanup(func() { newGatewayNetwork = originalNewGatewayNetwork }) + t.Cleanup(func() { buildGatewayRuntimePort = originalBuildGatewayRuntimePort }) + t.Cleanup(func() { newAuthManager = originalNewAuthManager }) + prepareGatewayCommandRunnerTestEnv(t) + buildGatewayRuntimePort = stubGatewayRuntimePortBuilder() + newAuthManager = stubGatewayAuthManagerBuilder() expected := errors.New("serve failed") server := &stubGatewayServer{ @@ -450,8 +479,15 @@ func TestDefaultGatewayCommandRunnerReturnsServeError(t *testing.T) { func TestDefaultGatewayCommandRunnerDegradesWhenNetworkServeFails(t *testing.T) { originalNewGatewayServer := newGatewayServer originalNewGatewayNetwork := newGatewayNetwork + originalBuildGatewayRuntimePort := buildGatewayRuntimePort + originalNewAuthManager := newAuthManager t.Cleanup(func() { newGatewayServer = originalNewGatewayServer }) t.Cleanup(func() { newGatewayNetwork = originalNewGatewayNetwork }) + t.Cleanup(func() { buildGatewayRuntimePort = originalBuildGatewayRuntimePort }) + t.Cleanup(func() { newAuthManager = originalNewAuthManager }) + prepareGatewayCommandRunnerTestEnv(t) + buildGatewayRuntimePort = stubGatewayRuntimePortBuilder() + newAuthManager = stubGatewayAuthManagerBuilder() ipcServer := &stubGatewayServer{listenAddress: "stub://gateway"} newGatewayServer = func(options gateway.ServerOptions) (gatewayServer, error) { @@ -487,8 +523,15 @@ func TestDefaultGatewayCommandRunnerDegradesWhenNetworkServeFails(t *testing.T) func TestDefaultGatewayCommandRunnerReturnsNetworkConstructorError(t *testing.T) { originalNewGatewayServer := newGatewayServer originalNewGatewayNetwork := newGatewayNetwork + originalBuildGatewayRuntimePort := buildGatewayRuntimePort + originalNewAuthManager := newAuthManager t.Cleanup(func() { newGatewayServer = originalNewGatewayServer }) t.Cleanup(func() { newGatewayNetwork = originalNewGatewayNetwork }) + t.Cleanup(func() { buildGatewayRuntimePort = originalBuildGatewayRuntimePort }) + t.Cleanup(func() { newAuthManager = originalNewAuthManager }) + prepareGatewayCommandRunnerTestEnv(t) + buildGatewayRuntimePort = stubGatewayRuntimePortBuilder() + newAuthManager = stubGatewayAuthManagerBuilder() networkErr := errors.New("new network server failed") ipcServer := &stubGatewayServer{listenAddress: "stub://gateway"} @@ -513,6 +556,10 @@ func TestDefaultGatewayCommandRunnerReturnsNetworkConstructorError(t *testing.T) } func TestDefaultGatewayCommandRunnerRejectsInvalidACLMode(t *testing.T) { + originalNewAuthManager := newAuthManager + t.Cleanup(func() { newAuthManager = originalNewAuthManager }) + prepareGatewayCommandRunnerTestEnv(t) + newAuthManager = stubGatewayAuthManagerBuilder() err := defaultGatewayCommandRunner(context.Background(), gatewayCommandOptions{ ListenAddress: "stub://gateway", HTTPAddress: "127.0.0.1:8080", @@ -1555,6 +1602,45 @@ type stubGatewayServer struct { closeCalled bool } +type stubRuntimePort struct{} + +type stubGatewayAuthenticator struct{} + +func (stubGatewayAuthenticator) ValidateToken(token string) bool { + return strings.TrimSpace(token) == "test-token" +} + +func (stubGatewayAuthenticator) ResolveSubjectID(token string) (string, bool) { + if strings.TrimSpace(token) != "test-token" { + return "", false + } + return "local_admin", true +} + +func (stubRuntimePort) Run(context.Context, gateway.RunInput) error { return nil } + +func (stubRuntimePort) Compact(context.Context, gateway.CompactInput) (gateway.CompactResult, error) { + return gateway.CompactResult{}, nil +} + +func (stubRuntimePort) ResolvePermission(context.Context, gateway.PermissionResolutionInput) error { + return nil +} + +func (stubRuntimePort) CancelRun(context.Context, gateway.CancelInput) (bool, error) { + return false, nil +} + +func (stubRuntimePort) Events() <-chan gateway.RuntimeEvent { return nil } + +func (stubRuntimePort) ListSessions(context.Context) ([]gateway.SessionSummary, error) { + return nil, nil +} + +func (stubRuntimePort) LoadSession(context.Context, gateway.LoadSessionInput) (gateway.Session, error) { + return gateway.Session{}, nil +} + func (s *stubGatewayServer) ListenAddress() string { return s.listenAddress } @@ -1592,3 +1678,23 @@ func captureEnvForRootTest(t *testing.T, key string) func() { _ = os.Unsetenv(key) } } + +func prepareGatewayCommandRunnerTestEnv(t *testing.T) { + t.Helper() + homeDir := t.TempDir() + t.Setenv("HOME", homeDir) + t.Setenv("USERPROFILE", homeDir) + t.Setenv("XDG_CONFIG_HOME", homeDir) +} + +func stubGatewayRuntimePortBuilder() func(context.Context, string) (gateway.RuntimePort, func() error, error) { + return func(context.Context, string) (gateway.RuntimePort, func() error, error) { + return stubRuntimePort{}, func() error { return nil }, nil + } +} + +func stubGatewayAuthManagerBuilder() func(string) (gateway.TokenAuthenticator, error) { + return func(string) (gateway.TokenAuthenticator, error) { + return stubGatewayAuthenticator{}, nil + } +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 0bf581d1..9d421a1c 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -264,13 +264,21 @@ func TestConfigMethodErrorPaths(t *testing.T) { restoreEnv(t, "MISSING_PROVIDER_KEY") _ = os.Unsetenv("MISSING_PROVIDER_KEY") - _, err := (ProviderConfig{ + resolved, err := (ProviderConfig{ Name: "custom", Driver: "custom", BaseURL: "https://example.com", Model: "custom-model", APIKeyEnv: "MISSING_PROVIDER_KEY", }).Resolve() + if err != nil { + t.Fatalf("Resolve() error = %v", err) + } + runtimeConfig, err := resolved.ToRuntimeConfig() + if err != nil { + t.Fatalf("ToRuntimeConfig() error = %v", err) + } + _, err = runtimeConfig.ResolveAPIKeyValue() if err == nil || !strings.Contains(err.Error(), "MISSING_PROVIDER_KEY") { t.Fatalf("expected missing env resolve error, got %v", err) } @@ -739,8 +747,16 @@ func TestProviderLookupAndResolveSelectedProvider(t *testing.T) { if err != nil { t.Fatalf("Resolve() error = %v", err) } - if resolved.APIKey != "lookup-key" { - t.Fatalf("expected resolved key %q, got %q", "lookup-key", resolved.APIKey) + runtimeConfig, err := resolved.ToRuntimeConfig() + if err != nil { + t.Fatalf("ToRuntimeConfig() error = %v", err) + } + apiKey, err := runtimeConfig.ResolveAPIKeyValue() + if err != nil { + t.Fatalf("ResolveAPIKeyValue() error = %v", err) + } + if apiKey != "lookup-key" { + t.Fatalf("expected resolved key %q, got %q", "lookup-key", apiKey) } } @@ -1029,8 +1045,8 @@ func TestManagerHelperMethodsAndReloads(t *testing.T) { if err := manager.Save(context.Background()); err != nil { t.Fatalf("Save() error = %v", err) } - if _, err := manager.Reload(context.Background()); err != nil { - t.Fatalf("Reload() error = %v", err) + if _, err := manager.Load(context.Background()); err != nil { + t.Fatalf("Load() error = %v", err) } if got := manager.ConfigPath(); got != filepath.Join(tempDir, configName) { t.Fatalf("expected config path %q, got %q", filepath.Join(tempDir, configName), got) @@ -1141,7 +1157,7 @@ func TestCompactConfigDefaultsAndRoundTrip(t *testing.T) { reloaded, err := loader.Load(context.Background()) if err != nil { - t.Fatalf("Reload() error = %v", err) + t.Fatalf("Load() error = %v", err) } if reloaded.Context.Compact.ManualStrategy != CompactManualStrategyFullReplace { t.Fatalf("expected manual strategy to persist, got %q", reloaded.Context.Compact.ManualStrategy) @@ -1243,93 +1259,101 @@ func restoreEnv(t *testing.T, key string) { }) } -func TestAutoCompactConfigDefaults(t *testing.T) { +func TestBudgetConfigDefaults(t *testing.T) { t.Parallel() cfg := StaticDefaults() - if cfg.Context.AutoCompact.InputTokenThreshold != DefaultAutoCompactInputTokenThreshold { - t.Fatalf("expected input_token_threshold=%d, got %d", - DefaultAutoCompactInputTokenThreshold, cfg.Context.AutoCompact.InputTokenThreshold) + if cfg.Context.Budget.PromptBudget != DefaultBudgetPromptBudget { + t.Fatalf("expected prompt_budget=%d, got %d", DefaultBudgetPromptBudget, cfg.Context.Budget.PromptBudget) } - if cfg.Context.AutoCompact.ReserveTokens != DefaultAutoCompactReserveTokens { - t.Fatalf("expected reserve_tokens=%d, got %d", - DefaultAutoCompactReserveTokens, cfg.Context.AutoCompact.ReserveTokens) + if cfg.Context.Budget.ReserveTokens != DefaultBudgetReserveTokens { + t.Fatalf("expected reserve_tokens=%d, got %d", DefaultBudgetReserveTokens, cfg.Context.Budget.ReserveTokens) } - if cfg.Context.AutoCompact.FallbackInputTokenThreshold != DefaultAutoCompactFallbackInputTokenThreshold { - t.Fatalf("expected fallback_input_token_threshold=%d, got %d", - DefaultAutoCompactFallbackInputTokenThreshold, cfg.Context.AutoCompact.FallbackInputTokenThreshold) + if cfg.Context.Budget.FallbackPromptBudget != DefaultBudgetFallbackPromptBudget { + t.Fatalf("expected fallback_prompt_budget=%d, got %d", + DefaultBudgetFallbackPromptBudget, cfg.Context.Budget.FallbackPromptBudget) } - - if cfg.Context.AutoCompact.Enabled != false { - t.Fatalf("expected enabled=false, got %v", cfg.Context.AutoCompact.Enabled) + if cfg.Context.Budget.MaxReactiveCompacts != DefaultBudgetMaxReactiveCompacts { + t.Fatalf("expected max_reactive_compacts=%d, got %d", + DefaultBudgetMaxReactiveCompacts, cfg.Context.Budget.MaxReactiveCompacts) } } -func TestAutoCompactConfigApplyDefaults(t *testing.T) { +func TestBudgetConfigApplyDefaults(t *testing.T) { t.Parallel() - cfg := AutoCompactConfig{} - defaults := AutoCompactConfig{ - ReserveTokens: 13000, - FallbackInputTokenThreshold: 100000, + cfg := BudgetConfig{} + defaults := BudgetConfig{ + ReserveTokens: 13000, + FallbackPromptBudget: 100000, + MaxReactiveCompacts: 3, } cfg.ApplyDefaults(defaults) - if cfg.InputTokenThreshold != 0 { - t.Fatalf("expected threshold to remain implicit 0, got %d", cfg.InputTokenThreshold) + if cfg.PromptBudget != 0 { + t.Fatalf("expected prompt budget to remain implicit 0, got %d", cfg.PromptBudget) } if cfg.ReserveTokens != 13000 { t.Fatalf("expected reserve_tokens=13000, got %d", cfg.ReserveTokens) } - if cfg.FallbackInputTokenThreshold != 100000 { - t.Fatalf("expected fallback_input_token_threshold=100000, got %d", cfg.FallbackInputTokenThreshold) + if cfg.FallbackPromptBudget != 100000 { + t.Fatalf("expected fallback_prompt_budget=100000, got %d", cfg.FallbackPromptBudget) + } + if cfg.MaxReactiveCompacts != 3 { + t.Fatalf("expected max_reactive_compacts=3, got %d", cfg.MaxReactiveCompacts) } } -func TestAutoCompactConfigApplyDefaultsPreservesExplicit(t *testing.T) { +func TestBudgetConfigApplyDefaultsPreservesExplicit(t *testing.T) { t.Parallel() - cfg := AutoCompactConfig{ - InputTokenThreshold: 200000, - ReserveTokens: 5000, - FallbackInputTokenThreshold: 80000, + cfg := BudgetConfig{ + PromptBudget: 200000, + ReserveTokens: 5000, + FallbackPromptBudget: 80000, + MaxReactiveCompacts: 5, } - defaults := AutoCompactConfig{ - InputTokenThreshold: 50000, - ReserveTokens: 13000, - FallbackInputTokenThreshold: 100000, + defaults := BudgetConfig{ + PromptBudget: 50000, + ReserveTokens: 13000, + FallbackPromptBudget: 100000, + MaxReactiveCompacts: 3, } cfg.ApplyDefaults(defaults) - if cfg.InputTokenThreshold != 200000 { - t.Fatalf("expected explicit threshold=200000 to be preserved, got %d", cfg.InputTokenThreshold) + if cfg.PromptBudget != 200000 { + t.Fatalf("expected explicit prompt_budget=200000 to be preserved, got %d", cfg.PromptBudget) } if cfg.ReserveTokens != 5000 { t.Fatalf("expected explicit reserve_tokens=5000 to be preserved, got %d", cfg.ReserveTokens) } - if cfg.FallbackInputTokenThreshold != 80000 { - t.Fatalf("expected explicit fallback_input_token_threshold=80000 to be preserved, got %d", cfg.FallbackInputTokenThreshold) + if cfg.FallbackPromptBudget != 80000 { + t.Fatalf("expected explicit fallback_prompt_budget=80000 to be preserved, got %d", cfg.FallbackPromptBudget) + } + if cfg.MaxReactiveCompacts != 5 { + t.Fatalf("expected explicit max_reactive_compacts=5 to be preserved, got %d", cfg.MaxReactiveCompacts) } } -func TestAutoCompactConfigApplyDefaultsNilReceiver(t *testing.T) { +func TestBudgetConfigApplyDefaultsNilReceiver(t *testing.T) { t.Parallel() - var cfg *AutoCompactConfig - cfg.ApplyDefaults(AutoCompactConfig{ReserveTokens: 13000, FallbackInputTokenThreshold: 100000}) + var cfg *BudgetConfig + cfg.ApplyDefaults(BudgetConfig{ReserveTokens: 13000, FallbackPromptBudget: 100000, MaxReactiveCompacts: 3}) } -func TestContextConfigApplyDefaultsPropagatesAutoCompactDefaults(t *testing.T) { +func TestContextConfigApplyDefaultsPropagatesBudgetDefaults(t *testing.T) { t.Parallel() cfg := ContextConfig{} cfg.ApplyDefaults(ContextConfig{ - AutoCompact: AutoCompactConfig{ - ReserveTokens: 13000, - FallbackInputTokenThreshold: 100000, + Budget: BudgetConfig{ + ReserveTokens: 13000, + FallbackPromptBudget: 100000, + MaxReactiveCompacts: 3, }, Compact: CompactConfig{ ManualStrategy: CompactManualStrategyKeepRecent, @@ -1339,73 +1363,76 @@ func TestContextConfigApplyDefaultsPropagatesAutoCompactDefaults(t *testing.T) { }, }) - if cfg.AutoCompact.InputTokenThreshold != 0 { - t.Fatalf("expected auto compact threshold to remain implicit 0, got %d", cfg.AutoCompact.InputTokenThreshold) + if cfg.Budget.PromptBudget != 0 { + t.Fatalf("expected prompt budget to remain implicit 0, got %d", cfg.Budget.PromptBudget) + } + if cfg.Budget.ReserveTokens != 13000 { + t.Fatalf("expected reserve_tokens=13000, got %d", cfg.Budget.ReserveTokens) } - if cfg.AutoCompact.ReserveTokens != 13000 { - t.Fatalf("expected reserve_tokens=13000, got %d", cfg.AutoCompact.ReserveTokens) + if cfg.Budget.FallbackPromptBudget != 100000 { + t.Fatalf("expected fallback_prompt_budget=100000, got %d", cfg.Budget.FallbackPromptBudget) } - if cfg.AutoCompact.FallbackInputTokenThreshold != 100000 { - t.Fatalf("expected fallback_input_token_threshold=100000, got %d", cfg.AutoCompact.FallbackInputTokenThreshold) + if cfg.Budget.MaxReactiveCompacts != 3 { + t.Fatalf("expected max_reactive_compacts=3, got %d", cfg.Budget.MaxReactiveCompacts) } } -func TestAutoCompactConfigValidateEnabledWithoutThreshold(t *testing.T) { +func TestBudgetConfigValidateAllowsImplicitPromptBudget(t *testing.T) { t.Parallel() - cfg := AutoCompactConfig{ - Enabled: true, - InputTokenThreshold: 0, - ReserveTokens: 13000, - FallbackInputTokenThreshold: 100000, + cfg := BudgetConfig{ + PromptBudget: 0, + ReserveTokens: 13000, + FallbackPromptBudget: 100000, + MaxReactiveCompacts: 3, } err := cfg.Validate() if err != nil { - t.Fatalf("expected validation to allow implicit threshold, got %v", err) + t.Fatalf("expected validation to allow implicit prompt budget, got %v", err) } } -func TestAutoCompactConfigValidateDisabledWithoutThreshold(t *testing.T) { +func TestBudgetConfigValidateRejectsNegativePromptBudget(t *testing.T) { t.Parallel() - cfg := AutoCompactConfig{ - Enabled: false, - InputTokenThreshold: 0, - ReserveTokens: 0, - FallbackInputTokenThreshold: 0, + cfg := BudgetConfig{ + PromptBudget: -1, + ReserveTokens: 13000, + FallbackPromptBudget: 100000, + MaxReactiveCompacts: 3, } err := cfg.Validate() - if err != nil { - t.Fatalf("expected no error for disabled auto compact, got %v", err) + if err == nil || !strings.Contains(err.Error(), "prompt_budget") { + t.Fatalf("expected prompt_budget validation error, got %v", err) } } -func TestAutoCompactConfigValidateEnabledWithThreshold(t *testing.T) { +func TestBudgetConfigValidateWithExplicitPromptBudget(t *testing.T) { t.Parallel() - cfg := AutoCompactConfig{ - Enabled: true, - InputTokenThreshold: 50000, - ReserveTokens: 13000, - FallbackInputTokenThreshold: 100000, + cfg := BudgetConfig{ + PromptBudget: 50000, + ReserveTokens: 13000, + FallbackPromptBudget: 100000, + MaxReactiveCompacts: 3, } err := cfg.Validate() if err != nil { - t.Fatalf("expected validation to pass, got %v", err) + t.Fatalf("expected budget validation to pass, got %v", err) } } -func TestAutoCompactConfigValidateRejectsNonPositiveReserveTokens(t *testing.T) { +func TestBudgetConfigValidateRejectsNonPositiveReserveTokens(t *testing.T) { t.Parallel() - cfg := AutoCompactConfig{ - Enabled: true, - InputTokenThreshold: 0, - ReserveTokens: 0, - FallbackInputTokenThreshold: 100000, + cfg := BudgetConfig{ + PromptBudget: 0, + ReserveTokens: 0, + FallbackPromptBudget: 100000, + MaxReactiveCompacts: 3, } err := cfg.Validate() @@ -1414,76 +1441,49 @@ func TestAutoCompactConfigValidateRejectsNonPositiveReserveTokens(t *testing.T) } } -func TestAutoCompactConfigValidateRejectsNonPositiveFallbackThreshold(t *testing.T) { +func TestBudgetConfigValidateRejectsNonPositiveFallbackPromptBudget(t *testing.T) { t.Parallel() - cfg := AutoCompactConfig{ - Enabled: true, - InputTokenThreshold: 0, - ReserveTokens: 13000, - FallbackInputTokenThreshold: 0, + cfg := BudgetConfig{ + PromptBudget: 0, + ReserveTokens: 13000, + FallbackPromptBudget: 0, + MaxReactiveCompacts: 3, } err := cfg.Validate() - if err == nil || !strings.Contains(err.Error(), "fallback_input_token_threshold") { - t.Fatalf("expected fallback_input_token_threshold validation error, got %v", err) + if err == nil || !strings.Contains(err.Error(), "fallback_prompt_budget") { + t.Fatalf("expected fallback_prompt_budget validation error, got %v", err) } } -func TestAutoCompactConfigClone(t *testing.T) { +func TestBudgetConfigValidateRejectsNonPositiveMaxReactiveCompacts(t *testing.T) { t.Parallel() - cfg := AutoCompactConfig{ - Enabled: true, - InputTokenThreshold: 75000, - ReserveTokens: 13000, - FallbackInputTokenThreshold: 100000, + cfg := BudgetConfig{ + PromptBudget: 0, + ReserveTokens: 13000, + FallbackPromptBudget: 100000, + MaxReactiveCompacts: 0, } - - cloned := cfg.Clone() - - if cfg.Enabled != cloned.Enabled { - t.Fatalf("expected enabled=%v to be cloned, got %v", cfg.Enabled, cloned.Enabled) - } - if cfg.InputTokenThreshold != cloned.InputTokenThreshold { - t.Fatalf("expected threshold=%d to be cloned, got %d", - cfg.InputTokenThreshold, cloned.InputTokenThreshold) - } - if cfg.ReserveTokens != cloned.ReserveTokens { - t.Fatalf("expected reserve_tokens=%d to be cloned, got %d", cfg.ReserveTokens, cloned.ReserveTokens) - } - if cfg.FallbackInputTokenThreshold != cloned.FallbackInputTokenThreshold { - t.Fatalf("expected fallback_input_token_threshold=%d to be cloned, got %d", - cfg.FallbackInputTokenThreshold, cloned.FallbackInputTokenThreshold) - } - - cloned.InputTokenThreshold = 100000 - if cfg.InputTokenThreshold == cloned.InputTokenThreshold { - t.Fatalf("clone should be independent from source") + err := cfg.Validate() + if err == nil || !strings.Contains(err.Error(), "max_reactive_compacts") { + t.Fatalf("expected max_reactive_compacts validation error, got %v", err) } } -func TestAutoCompactConfigContextConfigValidate(t *testing.T) { +func TestBudgetConfigClone(t *testing.T) { t.Parallel() - ctx := ContextConfig{ - AutoCompact: AutoCompactConfig{ - Enabled: true, - InputTokenThreshold: 0, - ReserveTokens: 13000, - FallbackInputTokenThreshold: 100000, - }, - Compact: CompactConfig{ - ManualStrategy: CompactManualStrategyKeepRecent, - ManualKeepRecentMessages: 10, - MaxSummaryChars: 1200, - ReadTimeMaxMessageSpans: 24, - }, + cfg := BudgetConfig{ + PromptBudget: 75000, + ReserveTokens: 13000, + FallbackPromptBudget: 100000, + MaxReactiveCompacts: 3, } - - err := ctx.Validate() - if err != nil { - t.Fatalf("expected context validation to allow implicit threshold, got %v", err) + cloned := cfg.Clone() + if cfg != cloned { + t.Fatalf("expected equal config clone, got %+v vs %+v", cfg, cloned) } } @@ -1918,7 +1918,6 @@ func TestToRuntimeConfigMapsAllFields(t *testing.T) { Model: "gemini-2.5-flash", APIKeyEnv: "TEST_ENV_KEY", }, - APIKey: "resolved-secret-key", } got, err := resolved.ToRuntimeConfig() @@ -1934,8 +1933,11 @@ func TestToRuntimeConfigMapsAllFields(t *testing.T) { if got.DefaultModel != "gemini-2.5-flash" { t.Fatalf("expected DefaultModel=gemini-2.5-flash, got %q", got.DefaultModel) } - if got.APIKey != "resolved-secret-key" { - t.Fatalf("expected APIKey=resolved-secret-key, got %q", got.APIKey) + if got.APIKeyEnv != "TEST_ENV_KEY" { + t.Fatalf("expected APIKeyEnv=TEST_ENV_KEY, got %q", got.APIKeyEnv) + } + if got.APIKeyResolver == nil { + t.Fatal("expected APIKeyResolver to be set") } } diff --git a/internal/config/context.go b/internal/config/context.go index 32cb3e10..8263a036 100644 --- a/internal/config/context.go +++ b/internal/config/context.go @@ -7,21 +7,22 @@ import ( ) const ( - DefaultCompactManualKeepRecentMessages = 10 - DefaultCompactMaxSummaryChars = 1200 - DefaultAutoCompactInputTokenThreshold = 0 - DefaultAutoCompactReserveTokens = 13000 - DefaultAutoCompactFallbackInputTokenThreshold = 100000 - DefaultMicroCompactRetainedToolSpans = 6 - DefaultCompactReadTimeMaxMessageSpans = 24 + DefaultCompactManualKeepRecentMessages = 10 + DefaultCompactMaxSummaryChars = 1200 + DefaultBudgetPromptBudget = 0 + DefaultBudgetReserveTokens = 13000 + DefaultBudgetFallbackPromptBudget = 100000 + DefaultBudgetMaxReactiveCompacts = 3 + DefaultMicroCompactRetainedToolSpans = 6 + DefaultCompactReadTimeMaxMessageSpans = 24 CompactManualStrategyKeepRecent = "keep_recent" CompactManualStrategyFullReplace = "full_replace" ) type ContextConfig struct { - Compact CompactConfig `yaml:"compact,omitempty"` - AutoCompact AutoCompactConfig `yaml:"auto_compact,omitempty"` + Compact CompactConfig `yaml:"compact,omitempty"` + Budget BudgetConfig `yaml:"budget,omitempty"` } type CompactConfig struct { @@ -34,27 +35,29 @@ type CompactConfig struct { MaxArchivedPromptChars int `yaml:"max_archived_prompt_chars,omitempty"` } -// AutoCompactConfig controls automatic context compression triggered by token thresholds. -type AutoCompactConfig struct { - Enabled bool `yaml:"enabled"` - InputTokenThreshold int `yaml:"input_token_threshold,omitempty"` - ReserveTokens int `yaml:"reserve_tokens,omitempty"` - FallbackInputTokenThreshold int `yaml:"fallback_input_token_threshold,omitempty"` +// BudgetConfig 定义上下文预算控制面的配置。 +type BudgetConfig struct { + PromptBudget int `yaml:"prompt_budget,omitempty"` + ReserveTokens int `yaml:"reserve_tokens,omitempty"` + FallbackPromptBudget int `yaml:"fallback_prompt_budget,omitempty"` + MaxReactiveCompacts int `yaml:"max_reactive_compacts,omitempty"` } // defaultContextConfig 返回上下文压缩相关配置的默认值。 func defaultContextConfig() ContextConfig { return ContextConfig{ - Compact: defaultCompactConfig(), - AutoCompact: defaultAutoCompactConfig(), + Compact: defaultCompactConfig(), + Budget: defaultBudgetConfig(), } } -func defaultAutoCompactConfig() AutoCompactConfig { - return AutoCompactConfig{ - InputTokenThreshold: DefaultAutoCompactInputTokenThreshold, - ReserveTokens: DefaultAutoCompactReserveTokens, - FallbackInputTokenThreshold: DefaultAutoCompactFallbackInputTokenThreshold, +// defaultBudgetConfig 返回预算控制面的默认配置。 +func defaultBudgetConfig() BudgetConfig { + return BudgetConfig{ + PromptBudget: DefaultBudgetPromptBudget, + ReserveTokens: DefaultBudgetReserveTokens, + FallbackPromptBudget: DefaultBudgetFallbackPromptBudget, + MaxReactiveCompacts: DefaultBudgetMaxReactiveCompacts, } } @@ -72,8 +75,8 @@ func defaultCompactConfig() CompactConfig { // Clone 返回上下文配置的独立副本,避免后续修改污染原值。 func (c ContextConfig) Clone() ContextConfig { return ContextConfig{ - Compact: c.Compact.Clone(), - AutoCompact: c.AutoCompact.Clone(), + Compact: c.Compact.Clone(), + Budget: c.Budget.Clone(), } } @@ -82,8 +85,8 @@ func (c CompactConfig) Clone() CompactConfig { return c } -// Clone 返回 auto_compact 配置的值副本。 -func (c AutoCompactConfig) Clone() AutoCompactConfig { +// Clone 返回 budget 配置的值副本。 +func (c BudgetConfig) Clone() BudgetConfig { return c } @@ -94,7 +97,7 @@ func (c *ContextConfig) ApplyDefaults(defaults ContextConfig) { } c.Compact.ApplyDefaults(defaults.Compact) - c.AutoCompact.ApplyDefaults(defaults.AutoCompact) + c.Budget.ApplyDefaults(defaults.Budget) } // ApplyDefaults 为 compact 配置填充缺省策略和阈值。 @@ -120,16 +123,19 @@ func (c *CompactConfig) ApplyDefaults(defaults CompactConfig) { } } -// ApplyDefaults 为 auto_compact 配置填充缺省阈值。 -func (c *AutoCompactConfig) ApplyDefaults(defaults AutoCompactConfig) { +// ApplyDefaults 为 budget 配置填充缺省值。 +func (c *BudgetConfig) ApplyDefaults(defaults BudgetConfig) { if c == nil { return } if c.ReserveTokens <= 0 { c.ReserveTokens = defaults.ReserveTokens } - if c.FallbackInputTokenThreshold <= 0 { - c.FallbackInputTokenThreshold = defaults.FallbackInputTokenThreshold + if c.FallbackPromptBudget <= 0 { + c.FallbackPromptBudget = defaults.FallbackPromptBudget + } + if c.MaxReactiveCompacts <= 0 { + c.MaxReactiveCompacts = defaults.MaxReactiveCompacts } } @@ -138,8 +144,8 @@ func (c ContextConfig) Validate() error { if err := c.Compact.Validate(); err != nil { return fmt.Errorf("compact: %w", err) } - if err := c.AutoCompact.Validate(); err != nil { - return fmt.Errorf("auto_compact: %w", err) + if err := c.Budget.Validate(); err != nil { + return fmt.Errorf("budget: %w", err) } return nil } @@ -164,16 +170,19 @@ func (c CompactConfig) Validate() error { } } -// Validate 校验 auto_compact 配置是否合法。 -func (c AutoCompactConfig) Validate() error { - if !c.Enabled { - return nil +// Validate 校验 budget 配置是否合法。 +func (c BudgetConfig) Validate() error { + if c.PromptBudget < 0 { + return errors.New("prompt_budget must be greater than or equal to 0") } if c.ReserveTokens <= 0 { - return errors.New("reserve_tokens must be greater than 0 when enabled") + return errors.New("reserve_tokens must be greater than 0") + } + if c.FallbackPromptBudget <= 0 { + return errors.New("fallback_prompt_budget must be greater than 0") } - if c.FallbackInputTokenThreshold <= 0 { - return errors.New("fallback_input_token_threshold must be greater than 0 when enabled") + if c.MaxReactiveCompacts <= 0 { + return errors.New("max_reactive_compacts must be greater than 0") } return nil } diff --git a/internal/config/context_budget_migration.go b/internal/config/context_budget_migration.go new file mode 100644 index 00000000..cb0f2c2d --- /dev/null +++ b/internal/config/context_budget_migration.go @@ -0,0 +1,147 @@ +package config + +import ( + "bytes" + "errors" + "fmt" + "os" + "path/filepath" + + "gopkg.in/yaml.v3" +) + +// ContextBudgetMigrationResult 汇总 config.yaml 预算配置迁移的执行结果。 +type ContextBudgetMigrationResult struct { + Path string + Changed bool + Backup string + Reason string +} + +// DefaultConfigPath 返回当前用户环境下的默认主配置文件路径。 +func DefaultConfigPath() string { + return filepath.Join(defaultBaseDir(), configName) +} + +// UpgradeConfigSchemaBeforeLoad 在严格解析配置前执行一次磁盘结构升级。 +func UpgradeConfigSchemaBeforeLoad(path string) error { + _, err := MigrateContextBudgetConfigFile(path, false) + return err +} + +// MigrateContextBudgetConfigFile 将 config.yaml 中的 context.auto_compact 迁移到 context.budget。 +func MigrateContextBudgetConfigFile(path string, dryRun bool) (ContextBudgetMigrationResult, error) { + if path == "" { + path = DefaultConfigPath() + } + if filepath.Base(path) != configName { + return ContextBudgetMigrationResult{}, fmt.Errorf("config: migration target must be %s", configName) + } + + result := ContextBudgetMigrationResult{Path: path} + raw, err := os.ReadFile(path) + if err != nil { + return result, fmt.Errorf("config: read migration target %s: %w", path, err) + } + + migrated, changed, err := MigrateContextBudgetConfigContent(raw) + if err != nil { + return result, fmt.Errorf("config: migrate %s: %w", path, err) + } + if !changed { + result.Reason = "未检测到 context.auto_compact" + return result, nil + } + + result.Changed = true + if dryRun { + return result, nil + } + + backup := path + ".bak" + if err := os.WriteFile(backup, raw, 0o644); err != nil { + return result, fmt.Errorf("config: write migration backup %s: %w", backup, err) + } + if err := os.WriteFile(path, migrated, 0o644); err != nil { + return result, fmt.Errorf("config: write migrated config %s: %w", path, err) + } + result.Backup = backup + return result, nil +} + +// MigrateContextBudgetConfigContent 将旧预算 YAML 块替换为当前预算 YAML 块。 +func MigrateContextBudgetConfigContent(raw []byte) ([]byte, bool, error) { + if len(bytes.TrimSpace(raw)) == 0 { + return raw, false, nil + } + if !bytes.Contains(raw, []byte("auto_compact")) { + return raw, false, nil + } + + var doc map[string]any + if err := yaml.Unmarshal(raw, &doc); err != nil { + return nil, false, err + } + contextValue, ok := doc["context"] + if !ok { + return raw, false, nil + } + contextMap, ok := migrationStringMap(contextValue) + if !ok { + return nil, false, errors.New("context must be a mapping") + } + + autoValue, hasAutoCompact := contextMap["auto_compact"] + if !hasAutoCompact { + return raw, false, nil + } + if _, hasBudget := contextMap["budget"]; hasBudget { + return nil, false, errors.New("context.auto_compact and context.budget cannot both exist") + } + + autoMap, ok := migrationStringMap(autoValue) + if !ok { + return nil, false, errors.New("context.auto_compact must be a mapping") + } + budgetMap := make(map[string]any) + migrationMoveField(autoMap, budgetMap, "input_token_threshold", "prompt_budget") + migrationMoveField(autoMap, budgetMap, "reserve_tokens", "reserve_tokens") + migrationMoveField(autoMap, budgetMap, "fallback_input_token_threshold", "fallback_prompt_budget") + + delete(contextMap, "auto_compact") + contextMap["budget"] = budgetMap + doc["context"] = contextMap + + out, err := yaml.Marshal(doc) + if err != nil { + return nil, false, err + } + return out, true, nil +} + +// migrationMoveField 在两个 YAML map 之间迁移字段名,不修改字段值。 +func migrationMoveField(src map[string]any, dst map[string]any, oldName string, newName string) { + if value, ok := src[oldName]; ok { + dst[newName] = value + } +} + +// migrationStringMap 将 YAML map 统一转为 map[string]any。 +func migrationStringMap(value any) (map[string]any, bool) { + switch typed := value.(type) { + case map[string]any: + return typed, true + case map[any]any: + result := make(map[string]any, len(typed)) + for key, value := range typed { + keyString, ok := key.(string) + if !ok { + return nil, false + } + result[keyString] = value + } + return result, true + default: + return nil, false + } +} diff --git a/internal/config/context_budget_migration_test.go b/internal/config/context_budget_migration_test.go new file mode 100644 index 00000000..153a498b --- /dev/null +++ b/internal/config/context_budget_migration_test.go @@ -0,0 +1,95 @@ +package config + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func TestMigrateContextBudgetConfigContentMovesAutoCompactToBudget(t *testing.T) { + t.Parallel() + + input := []byte(strings.TrimSpace(` +selected_provider: openai +context: + compact: + manual_strategy: keep_recent + auto_compact: + input_token_threshold: 120000 + reserve_tokens: 13000 + fallback_input_token_threshold: 100000 +`) + "\n") + + out, changed, err := MigrateContextBudgetConfigContent(input) + if err != nil { + t.Fatalf("MigrateContextBudgetConfigContent() error = %v", err) + } + if !changed { + t.Fatal("expected migration change") + } + text := string(out) + if strings.Contains(text, "auto_compact:") { + t.Fatalf("expected auto_compact removed, got:\n%s", text) + } + for _, want := range []string{ + "budget:", + "prompt_budget: 120000", + "reserve_tokens: 13000", + "fallback_prompt_budget: 100000", + } { + if !strings.Contains(text, want) { + t.Fatalf("expected migrated YAML to contain %q, got:\n%s", want, text) + } + } +} + +func TestMigrateContextBudgetConfigContentRejectsMixedBudgetBlocks(t *testing.T) { + t.Parallel() + + input := []byte(strings.TrimSpace(` +context: + budget: + prompt_budget: 100000 + auto_compact: + input_token_threshold: 120000 +`) + "\n") + + _, _, err := MigrateContextBudgetConfigContent(input) + if err == nil || !strings.Contains(err.Error(), "cannot both exist") { + t.Fatalf("expected mixed block error, got %v", err) + } +} + +func TestMigrateContextBudgetConfigFileCreatesBackup(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + target := filepath.Join(dir, configName) + original := strings.TrimSpace(` +context: + auto_compact: + input_token_threshold: 120000 +`) + "\n" + if err := os.WriteFile(target, []byte(original), 0o644); err != nil { + t.Fatalf("write target: %v", err) + } + + result, err := MigrateContextBudgetConfigFile(target, false) + if err != nil { + t.Fatalf("MigrateContextBudgetConfigFile() error = %v", err) + } + if !result.Changed { + t.Fatal("expected changed result") + } + if result.Backup == "" { + t.Fatal("expected backup path") + } + backup, err := os.ReadFile(result.Backup) + if err != nil { + t.Fatalf("read backup: %v", err) + } + if string(backup) != original { + t.Fatalf("expected backup to keep original content, got:\n%s", backup) + } +} diff --git a/internal/config/context_test.go b/internal/config/context_test.go index 8dd9186e..6ba82ef7 100644 --- a/internal/config/context_test.go +++ b/internal/config/context_test.go @@ -15,41 +15,42 @@ func TestContextConfigCloneIndependence(t *testing.T) { MaxSummaryChars: 1200, ReadTimeMaxMessageSpans: 24, }, - AutoCompact: AutoCompactConfig{ - Enabled: true, - InputTokenThreshold: 50000, + Budget: BudgetConfig{ + PromptBudget: 50000, + ReserveTokens: 9000, + FallbackPromptBudget: 88000, + MaxReactiveCompacts: 2, }, } cloned := original.Clone() cloned.Compact.ManualStrategy = CompactManualStrategyFullReplace cloned.Compact.ManualKeepRecentMessages = 5 - cloned.AutoCompact.Enabled = false - cloned.AutoCompact.InputTokenThreshold = 100000 + cloned.Budget.PromptBudget = 100000 + cloned.Budget.MaxReactiveCompacts = 4 if original.Compact.ManualStrategy == cloned.Compact.ManualStrategy { - t.Fatal("expected Compact Clone to be independent") + t.Fatal("expected Compact clone to be independent") } if original.Compact.ManualKeepRecentMessages == cloned.Compact.ManualKeepRecentMessages { - t.Fatal("expected ManualKeepRecentMessages Clone to be independent") + t.Fatal("expected ManualKeepRecentMessages clone to be independent") } - if original.AutoCompact.Enabled == cloned.AutoCompact.Enabled { - t.Fatal("expected AutoCompact Enabled clone to be independent") + if original.Budget.PromptBudget == cloned.Budget.PromptBudget { + t.Fatal("expected Budget PromptBudget clone to be independent") } - if original.AutoCompact.InputTokenThreshold == cloned.AutoCompact.InputTokenThreshold { - t.Fatal("expected AutoCompact InputTokenThreshold clone to be independent") + if original.Budget.MaxReactiveCompacts == cloned.Budget.MaxReactiveCompacts { + t.Fatal("expected Budget MaxReactiveCompacts clone to be independent") } } -func TestCompactConfigCloneValueSemantics(t *testing.T) { +func TestBudgetConfigCloneValueSemantics(t *testing.T) { t.Parallel() - original := CompactConfig{ - ManualStrategy: CompactManualStrategyFullReplace, - ManualKeepRecentMessages: 5, - MaxSummaryChars: 800, - MicroCompactDisabled: true, - ReadTimeMaxMessageSpans: 24, + original := BudgetConfig{ + PromptBudget: 75000, + ReserveTokens: 13000, + FallbackPromptBudget: 100000, + MaxReactiveCompacts: 3, } cloned := original.Clone() if original != cloned { @@ -57,16 +58,6 @@ func TestCompactConfigCloneValueSemantics(t *testing.T) { } } -func TestAutoCompactConfigCloneValueSemantics(t *testing.T) { - t.Parallel() - - original := AutoCompactConfig{Enabled: true, InputTokenThreshold: 75000} - cloned := original.Clone() - if original != cloned { - t.Fatalf("expected equal configs, got %+v vs %+v", original, cloned) - } -} - func TestContextConfigValidatePropagatesCompactError(t *testing.T) { t.Parallel() @@ -87,8 +78,12 @@ func TestContextConfigApplyDefaultsNilReceiver(t *testing.T) { var ctxCfg *ContextConfig ctxCfg.ApplyDefaults(ContextConfig{ - Compact: CompactConfig{ManualStrategy: CompactManualStrategyFullReplace}, - AutoCompact: AutoCompactConfig{InputTokenThreshold: 50000}, + Compact: CompactConfig{ManualStrategy: CompactManualStrategyFullReplace}, + Budget: BudgetConfig{ + ReserveTokens: 13000, + FallbackPromptBudget: 100000, + MaxReactiveCompacts: 3, + }, }) } diff --git a/internal/config/loader.go b/internal/config/loader.go index e0eaa68c..fc2e0262 100644 --- a/internal/config/loader.go +++ b/internal/config/loader.go @@ -34,8 +34,8 @@ type persistedConfig struct { } type persistedContextConfig struct { - Compact persistedCompactConfig `yaml:"compact,omitempty"` - AutoCompact persistedAutoCompactConfig `yaml:"auto_compact,omitempty"` + Compact persistedCompactConfig `yaml:"compact,omitempty"` + Budget persistedBudgetConfig `yaml:"budget,omitempty"` } type persistedCompactConfig struct { @@ -48,11 +48,11 @@ type persistedCompactConfig struct { MaxArchivedPromptChars int `yaml:"max_archived_prompt_chars,omitempty"` } -type persistedAutoCompactConfig struct { - Enabled bool `yaml:"enabled"` - InputTokenThreshold int `yaml:"input_token_threshold,omitempty"` - ReserveTokens int `yaml:"reserve_tokens,omitempty"` - FallbackInputTokenThreshold int `yaml:"fallback_input_token_threshold,omitempty"` +type persistedBudgetConfig struct { + PromptBudget int `yaml:"prompt_budget,omitempty"` + ReserveTokens int `yaml:"reserve_tokens,omitempty"` + FallbackPromptBudget int `yaml:"fallback_prompt_budget,omitempty"` + MaxReactiveCompacts int `yaml:"max_reactive_compacts,omitempty"` } type persistedMemoConfig struct { @@ -62,7 +62,6 @@ type persistedMemoConfig struct { MaxIndexBytes *int `yaml:"max_index_bytes,omitempty"` ExtractTimeoutSec *int `yaml:"extract_timeout_sec,omitempty"` ExtractRecentMessages *int `yaml:"extract_recent_messages,omitempty"` - MaxIndexLines *int `yaml:"max_index_lines,omitempty"` } func NewLoader(baseDir string, defaults *Config) *Loader { @@ -115,6 +114,12 @@ func (l *Loader) Load(ctx context.Context) (*Config, error) { return nil, err } } + if err := ctx.Err(); err != nil { + return nil, err + } + if err := UpgradeConfigSchemaBeforeLoad(l.ConfigPath()); err != nil { + return nil, err + } data, err := os.ReadFile(l.ConfigPath()) if err != nil { @@ -257,11 +262,11 @@ func newPersistedContextConfig(cfg ContextConfig) persistedContextConfig { ReadTimeMaxMessageSpans: cfg.Compact.ReadTimeMaxMessageSpans, MaxArchivedPromptChars: cfg.Compact.MaxArchivedPromptChars, }, - AutoCompact: persistedAutoCompactConfig{ - Enabled: cfg.AutoCompact.Enabled, - InputTokenThreshold: cfg.AutoCompact.InputTokenThreshold, - ReserveTokens: cfg.AutoCompact.ReserveTokens, - FallbackInputTokenThreshold: cfg.AutoCompact.FallbackInputTokenThreshold, + Budget: persistedBudgetConfig{ + PromptBudget: cfg.Budget.PromptBudget, + ReserveTokens: cfg.Budget.ReserveTokens, + FallbackPromptBudget: cfg.Budget.FallbackPromptBudget, + MaxReactiveCompacts: cfg.Budget.MaxReactiveCompacts, }, } } @@ -278,15 +283,15 @@ func fromPersistedContextConfig(file persistedContextConfig, defaults ContextCon ReadTimeMaxMessageSpans: file.Compact.ReadTimeMaxMessageSpans, MaxArchivedPromptChars: file.Compact.MaxArchivedPromptChars, }, - AutoCompact: AutoCompactConfig{ - Enabled: file.AutoCompact.Enabled, - InputTokenThreshold: file.AutoCompact.InputTokenThreshold, - ReserveTokens: file.AutoCompact.ReserveTokens, - FallbackInputTokenThreshold: file.AutoCompact.FallbackInputTokenThreshold, + Budget: BudgetConfig{ + PromptBudget: file.Budget.PromptBudget, + ReserveTokens: file.Budget.ReserveTokens, + FallbackPromptBudget: file.Budget.FallbackPromptBudget, + MaxReactiveCompacts: file.Budget.MaxReactiveCompacts, }, } out.Compact.ApplyDefaults(defaults.Compact) - out.AutoCompact.ApplyDefaults(defaults.AutoCompact) + out.Budget.ApplyDefaults(defaults.Budget) return out } @@ -362,8 +367,6 @@ func fromPersistedMemoConfig(file persistedMemoConfig, defaults MemoConfig) Memo } if file.MaxEntries != nil { out.MaxEntries = *file.MaxEntries - } else if file.MaxIndexLines != nil { - out.MaxEntries = *file.MaxIndexLines } if file.MaxIndexBytes != nil { out.MaxIndexBytes = *file.MaxIndexBytes diff --git a/internal/config/loader_test.go b/internal/config/loader_test.go index 2ecaef0f..121a975f 100644 --- a/internal/config/loader_test.go +++ b/internal/config/loader_test.go @@ -117,6 +117,79 @@ shell: powershell } } +func TestLoaderUpgradesContextBudgetBeforeStrictParse(t *testing.T) { + t.Parallel() + + loader := NewLoader(t.TempDir(), testDefaultConfig()) + raw := ` +selected_provider: openai +current_model: gpt-4.1 +shell: powershell +context: + auto_compact: + input_token_threshold: 120000 + reserve_tokens: 13000 + fallback_input_token_threshold: 100000 +` + writeLoaderConfig(t, loader, raw) + + cfg, err := loader.Load(context.Background()) + if err != nil { + t.Fatalf("Load() error = %v", err) + } + if cfg.Context.Budget.PromptBudget != 120000 { + t.Fatalf("expected prompt_budget migrated, got %d", cfg.Context.Budget.PromptBudget) + } + if cfg.Context.Budget.ReserveTokens != 13000 { + t.Fatalf("expected reserve_tokens migrated, got %d", cfg.Context.Budget.ReserveTokens) + } + if cfg.Context.Budget.FallbackPromptBudget != 100000 { + t.Fatalf("expected fallback_prompt_budget migrated, got %d", cfg.Context.Budget.FallbackPromptBudget) + } + + data, err := os.ReadFile(loader.ConfigPath()) + if err != nil { + t.Fatalf("read migrated config: %v", err) + } + text := string(data) + if strings.Contains(text, "auto_compact:") { + t.Fatalf("expected loader migration to remove auto_compact, got:\n%s", text) + } + if !strings.Contains(text, "budget:") { + t.Fatalf("expected loader migration to persist budget block, got:\n%s", text) + } + + backup, err := os.ReadFile(loader.ConfigPath() + ".bak") + if err != nil { + t.Fatalf("read migration backup: %v", err) + } + if !strings.Contains(string(backup), "auto_compact:") { + t.Fatalf("expected backup to preserve original config, got:\n%s", backup) + } +} + +func TestLoaderRejectsAmbiguousContextBudgetMigration(t *testing.T) { + t.Parallel() + + loader := NewLoader(t.TempDir(), testDefaultConfig()) + raw := ` +selected_provider: openai +current_model: gpt-4.1 +shell: powershell +context: + budget: + prompt_budget: 110000 + auto_compact: + input_token_threshold: 120000 +` + writeLoaderConfig(t, loader, raw) + + _, err := loader.Load(context.Background()) + if err == nil || !strings.Contains(err.Error(), "context.auto_compact and context.budget cannot both exist") { + t.Fatalf("expected ambiguous migration error, got %v", err) + } +} + func TestLoaderLoadInvalidBaseDir(t *testing.T) { t.Parallel() @@ -755,7 +828,7 @@ models: } } -func TestLoaderParsesAutoCompactDerivedFields(t *testing.T) { +func TestLoaderParsesBudgetFields(t *testing.T) { t.Parallel() loader := NewLoader(t.TempDir(), testDefaultConfig()) @@ -764,38 +837,41 @@ selected_provider: openai current_model: gpt-5.4 shell: powershell context: - auto_compact: - enabled: true - input_token_threshold: 0 - reserve_tokens: 9000 - fallback_input_token_threshold: 88000 -` + budget: + prompt_budget: 0 + reserve_tokens: 9000 + fallback_prompt_budget: 88000 + max_reactive_compacts: 4 + ` writeLoaderConfig(t, loader, raw) cfg, err := loader.Load(context.Background()) if err != nil { t.Fatalf("Load() error = %v", err) } - if cfg.Context.AutoCompact.InputTokenThreshold != 0 { - t.Fatalf("expected implicit threshold 0, got %d", cfg.Context.AutoCompact.InputTokenThreshold) + if cfg.Context.Budget.PromptBudget != 0 { + t.Fatalf("expected implicit prompt budget 0, got %d", cfg.Context.Budget.PromptBudget) } - if cfg.Context.AutoCompact.ReserveTokens != 9000 { - t.Fatalf("expected reserve_tokens=9000, got %d", cfg.Context.AutoCompact.ReserveTokens) + if cfg.Context.Budget.ReserveTokens != 9000 { + t.Fatalf("expected reserve_tokens=9000, got %d", cfg.Context.Budget.ReserveTokens) } - if cfg.Context.AutoCompact.FallbackInputTokenThreshold != 88000 { - t.Fatalf("expected fallback_input_token_threshold=88000, got %d", cfg.Context.AutoCompact.FallbackInputTokenThreshold) + if cfg.Context.Budget.FallbackPromptBudget != 88000 { + t.Fatalf("expected fallback_prompt_budget=88000, got %d", cfg.Context.Budget.FallbackPromptBudget) + } + if cfg.Context.Budget.MaxReactiveCompacts != 4 { + t.Fatalf("expected max_reactive_compacts=4, got %d", cfg.Context.Budget.MaxReactiveCompacts) } } -func TestLoaderSavePersistsAutoCompactDerivedFields(t *testing.T) { +func TestLoaderSavePersistsBudgetFields(t *testing.T) { t.Parallel() loader := NewLoader(t.TempDir(), testDefaultConfig()) cfg := testDefaultConfig().Clone() - cfg.Context.AutoCompact.Enabled = true - cfg.Context.AutoCompact.InputTokenThreshold = 0 - cfg.Context.AutoCompact.ReserveTokens = 9000 - cfg.Context.AutoCompact.FallbackInputTokenThreshold = 88000 + cfg.Context.Budget.PromptBudget = 0 + cfg.Context.Budget.ReserveTokens = 9000 + cfg.Context.Budget.FallbackPromptBudget = 88000 + cfg.Context.Budget.MaxReactiveCompacts = 4 if err := loader.Save(context.Background(), &cfg); err != nil { t.Fatalf("Save() error = %v", err) @@ -806,14 +882,17 @@ func TestLoaderSavePersistsAutoCompactDerivedFields(t *testing.T) { t.Fatalf("read config: %v", err) } text := string(data) - if strings.Contains(text, "input_token_threshold: 100000") { - t.Fatalf("expected implicit threshold to avoid legacy default, got:\n%s", text) + if strings.Contains(text, "prompt_budget: 100000") { + t.Fatalf("expected implicit prompt budget to avoid default serialization, got:\n%s", text) } if !strings.Contains(text, "reserve_tokens: 9000") { t.Fatalf("expected reserve_tokens to persist, got:\n%s", text) } - if !strings.Contains(text, "fallback_input_token_threshold: 88000") { - t.Fatalf("expected fallback_input_token_threshold to persist, got:\n%s", text) + if !strings.Contains(text, "fallback_prompt_budget: 88000") { + t.Fatalf("expected fallback_prompt_budget to persist, got:\n%s", text) + } + if !strings.Contains(text, "max_reactive_compacts: 4") { + t.Fatalf("expected max_reactive_compacts to persist, got:\n%s", text) } } @@ -1677,28 +1756,6 @@ shell: powershell } } -func TestLoaderSupportsLegacyMemoMaxIndexLinesField(t *testing.T) { - t.Parallel() - - loader := NewLoader(t.TempDir(), testDefaultConfig()) - raw := ` -selected_provider: openai -current_model: gpt-4.1 -shell: powershell -memo: - max_index_lines: 123 -` - writeLoaderConfig(t, loader, raw) - - cfg, err := loader.Load(context.Background()) - if err != nil { - t.Fatalf("expected legacy memo field to be accepted, got %v", err) - } - if cfg.Memo.MaxEntries != 123 { - t.Fatalf("expected legacy max_index_lines mapped to memo.max_entries=123, got %d", cfg.Memo.MaxEntries) - } -} - func TestLoaderRejectsExplicitInvalidMemoNumbers(t *testing.T) { t.Parallel() @@ -1752,6 +1809,25 @@ memo: } } +func TestLoaderRejectsLegacyMemoMaxIndexLinesField(t *testing.T) { + t.Parallel() + + loader := NewLoader(t.TempDir(), testDefaultConfig()) + raw := ` +selected_provider: openai +current_model: gpt-4.1 +shell: powershell +memo: + max_index_lines: 123 +` + writeLoaderConfig(t, loader, raw) + + _, err := loader.Load(context.Background()) + if err == nil || !strings.Contains(err.Error(), "field max_index_lines not found") { + t.Fatalf("expected legacy memo field rejection, got %v", err) + } +} + func TestLoaderLoadsCompactExtendedFields(t *testing.T) { t.Parallel() diff --git a/internal/config/manager.go b/internal/config/manager.go index 22cbdb6f..2c825b73 100644 --- a/internal/config/manager.go +++ b/internal/config/manager.go @@ -42,10 +42,6 @@ func (m *Manager) Load(ctx context.Context) (Config, error) { return snapshot, nil } -func (m *Manager) Reload(ctx context.Context) (Config, error) { - return m.Load(ctx) -} - func (m *Manager) Get() Config { m.mu.RLock() defer m.mu.RUnlock() diff --git a/internal/config/provider.go b/internal/config/provider.go index f5a47129..2bf428aa 100644 --- a/internal/config/provider.go +++ b/internal/config/provider.go @@ -35,12 +35,11 @@ type ProviderConfig struct { type ResolvedProviderConfig struct { ProviderConfig - APIKey string `yaml:"-"` SessionAssetPolicy session.AssetPolicy `yaml:"-"` RequestAssetBudget provider.RequestAssetBudget `yaml:"-"` } -// ResolveSelectedProvider 解析当前配置中选中的 provider,并补全运行时所需的密钥信息。 +// ResolveSelectedProvider 解析当前配置中选中的 provider,并补全运行时所需的运行时策略。 func ResolveSelectedProvider(cfg Config) (ResolvedProviderConfig, error) { providerName := strings.TrimSpace(cfg.SelectedProvider) if providerName == "" { @@ -51,10 +50,7 @@ func ResolveSelectedProvider(cfg Config) (ResolvedProviderConfig, error) { if err != nil { return ResolvedProviderConfig{}, err } - resolved, err := providerCfg.Resolve() - if err != nil { - return ResolvedProviderConfig{}, err - } + resolved := ResolvedProviderConfig{ProviderConfig: providerCfg} resolved.SessionAssetPolicy = cfg.Runtime.ResolveSessionAssetPolicy() resolved.RequestAssetBudget = cfg.Runtime.ResolveRequestAssetBudget() return resolved, nil @@ -117,42 +113,15 @@ func (p ProviderConfig) Identity() (provider.ProviderIdentity, error) { } func (p ProviderConfig) ResolveAPIKey() (string, error) { - envName := strings.TrimSpace(p.APIKeyEnv) - if envName == "" { + if strings.TrimSpace(p.APIKeyEnv) == "" { return "", fmt.Errorf("config: provider %q api_key_env is empty", p.Name) } - - value := strings.TrimSpace(os.Getenv(envName)) - if value != "" { - return value, nil - } - - // 进程环境未命中时回退读取用户级环境变量(Windows 为注册表持久化), - // 并回填到当前进程环境,避免后续链路重复出现“变量为空”的假阴性。 - userValue, exists, err := LookupUserEnvVar(envName) - if err != nil { - return "", fmt.Errorf("config: lookup user environment variable %s: %w", envName, err) - } - if exists { - trimmedUserValue := strings.TrimSpace(userValue) - if trimmedUserValue != "" { - _ = os.Setenv(envName, trimmedUserValue) - return trimmedUserValue, nil - } - } - - return "", fmt.Errorf("config: environment variable %s is empty", envName) + return resolveRuntimeAPIKey(p.APIKeyEnv) } func (p ProviderConfig) Resolve() (ResolvedProviderConfig, error) { - apiKey, err := p.ResolveAPIKey() - if err != nil { - return ResolvedProviderConfig{}, err - } - return ResolvedProviderConfig{ ProviderConfig: p, - APIKey: apiKey, }, nil } @@ -263,7 +232,8 @@ func (p ResolvedProviderConfig) ToRuntimeConfig() (provider.RuntimeConfig, error Driver: p.Driver, BaseURL: baseURL, DefaultModel: p.Model, - APIKey: p.APIKey, + APIKeyEnv: p.APIKeyEnv, + APIKeyResolver: resolveRuntimeAPIKey, SessionAssetPolicy: p.SessionAssetPolicy, RequestAssetBudget: p.RequestAssetBudget, ChatAPIMode: chatAPIMode, @@ -272,6 +242,33 @@ func (p ResolvedProviderConfig) ToRuntimeConfig() (provider.RuntimeConfig, error }, nil } +// resolveRuntimeAPIKey 在 provider 真正发起请求前解析 API Key,并在需要时补回当前进程环境。 +func resolveRuntimeAPIKey(envName string) (string, error) { + envName = strings.TrimSpace(envName) + if envName == "" { + return "", errors.New("config: provider api_key_env is empty") + } + + value := strings.TrimSpace(os.Getenv(envName)) + if value != "" { + return value, nil + } + + userValue, exists, err := LookupUserEnvVar(envName) + if err != nil { + return "", fmt.Errorf("config: lookup user environment variable %s: %w", envName, err) + } + if exists { + trimmedUserValue := strings.TrimSpace(userValue) + if trimmedUserValue != "" { + _ = os.Setenv(envName, trimmedUserValue) + return trimmedUserValue, nil + } + } + + return "", fmt.Errorf("config: environment variable %s is empty", envName) +} + // normalizeProviderDiscoverySettingsFromConfig 归一化 discovery 所需的最小路径配置。 func normalizeProviderDiscoverySettingsFromConfig(cfg ProviderConfig) (string, error) { return provider.NormalizeProviderDiscoverySettings(cfg.Driver, cfg.DiscoveryEndpointPath) diff --git a/internal/config/provider_test.go b/internal/config/provider_test.go index 21b74a50..c1c7f9a6 100644 --- a/internal/config/provider_test.go +++ b/internal/config/provider_test.go @@ -293,7 +293,15 @@ func TestProviderConfigResolveWrapsAPIKeyError(t *testing.T) { BaseURL: "https://example.com", APIKeyEnv: "UNRESOLVABLE_API_KEY_FOR_TEST", } - _, err := cfg.Resolve() + resolved, err := cfg.Resolve() + if err != nil { + t.Fatalf("Resolve() error = %v", err) + } + runtimeConfig, err := resolved.ToRuntimeConfig() + if err != nil { + t.Fatalf("ToRuntimeConfig() error = %v", err) + } + _, err = runtimeConfig.ResolveAPIKeyValue() if err == nil || !strings.Contains(err.Error(), "UNRESOLVABLE_API_KEY_FOR_TEST") { t.Fatalf("expected unresolved API key error, got %v", err) } @@ -652,12 +660,12 @@ func TestResolvedProviderConfigToRuntimeConfig(t *testing.T) { resolved := ResolvedProviderConfig{ ProviderConfig: ProviderConfig{ - Name: "company-gateway", - Driver: "openaicompat", - BaseURL: "https://llm.example.com/v1", - Model: "server-default", + Name: "company-gateway", + Driver: "openaicompat", + BaseURL: "https://llm.example.com/v1", + Model: "server-default", + APIKeyEnv: "COMPANY_GATEWAY_KEY", }, - APIKey: "secret-key", SessionAssetPolicy: session.AssetPolicy{ MaxSessionAssetBytes: 1024, }, @@ -675,7 +683,7 @@ func TestResolvedProviderConfigToRuntimeConfig(t *testing.T) { Driver: "openaicompat", BaseURL: "https://llm.example.com/v1", DefaultModel: "server-default", - APIKey: "secret-key", + APIKeyEnv: "COMPANY_GATEWAY_KEY", SessionAssetPolicy: session.AssetPolicy{ MaxSessionAssetBytes: 1024, }, @@ -687,7 +695,20 @@ func TestResolvedProviderConfigToRuntimeConfig(t *testing.T) { DiscoveryEndpointPath: providerpkg.DiscoveryEndpointPathModels, } - if got != want { + if got.APIKeyResolver == nil { + t.Fatal("expected APIKeyResolver to be set") + } + got.APIKeyResolver = nil + if got.Name != want.Name || + got.Driver != want.Driver || + got.BaseURL != want.BaseURL || + got.DefaultModel != want.DefaultModel || + got.APIKeyEnv != want.APIKeyEnv || + got.SessionAssetPolicy != want.SessionAssetPolicy || + got.RequestAssetBudget != want.RequestAssetBudget || + got.ChatAPIMode != want.ChatAPIMode || + got.ChatEndpointPath != want.ChatEndpointPath || + got.DiscoveryEndpointPath != want.DiscoveryEndpointPath { t.Fatalf("ToRuntimeConfig() = %+v, want %+v", got, want) } } @@ -697,12 +718,12 @@ func TestResolvedProviderConfigToRuntimeConfigStripsBaseURLUserinfo(t *testing.T resolved := ResolvedProviderConfig{ ProviderConfig: ProviderConfig{ - Name: "company-gateway", - Driver: "openaicompat", - BaseURL: "https://token@llm.example.com/v1", - Model: "server-default", + Name: "company-gateway", + Driver: "openaicompat", + BaseURL: "https://token@llm.example.com/v1", + Model: "server-default", + APIKeyEnv: "COMPANY_GATEWAY_KEY", }, - APIKey: "secret-key", } got, err := resolved.ToRuntimeConfig() @@ -728,7 +749,6 @@ func TestResolvedProviderConfigToRuntimeConfigReturnsProtocolNormalizationError( APIKeyEnv: "TEST_KEY", ChatEndpointPath: "https://llm.example.com/chat/completions", }, - APIKey: "secret-key", } _, err := resolved.ToRuntimeConfig() @@ -749,10 +769,10 @@ func TestResolvedProviderConfigToRuntimeConfigUsesNormalizedOpenAICompatPaths(t Driver: "openaicompat", BaseURL: "https://llm.example.com/v1", Model: "gpt-5.4", + APIKeyEnv: "RESPONSES_GATEWAY_KEY", ChatAPIMode: providerpkg.ChatAPIModeResponses, ChatEndpointPath: "/responses", }, - APIKey: "secret-key", } got, err := resolved.ToRuntimeConfig() @@ -776,10 +796,10 @@ func TestResolvedProviderConfigToRuntimeConfigStripsSDKChatEndpointPath(t *testi Driver: providerpkg.DriverGemini, BaseURL: GeminiDefaultBaseURL, Model: GeminiDefaultModel, + APIKeyEnv: "GEMINI_KEY", ChatEndpointPath: "/models", DiscoveryEndpointPath: providerpkg.DiscoveryEndpointPathModels, }, - APIKey: "secret-key", } got, err := resolved.ToRuntimeConfig() diff --git a/internal/config/state/auto_compact_threshold.go b/internal/config/state/auto_compact_threshold.go deleted file mode 100644 index 5196d445..00000000 --- a/internal/config/state/auto_compact_threshold.go +++ /dev/null @@ -1,90 +0,0 @@ -package state - -import ( - "context" - "strings" - - "neo-code/internal/config" - "neo-code/internal/provider" -) - -// AutoCompactThresholdSource 标识自动压缩阈值最终采用的来源。 -type AutoCompactThresholdSource string - -const ( - AutoCompactThresholdSourceDisabled AutoCompactThresholdSource = "disabled" - AutoCompactThresholdSourceExplicit AutoCompactThresholdSource = "explicit" - AutoCompactThresholdSourceDerived AutoCompactThresholdSource = "derived" - AutoCompactThresholdSourceFallback AutoCompactThresholdSource = "fallback" -) - -// AutoCompactThresholdResolution 描述自动压缩阈值的解析结果,供 runtime 直接消费。 -type AutoCompactThresholdResolution struct { - Threshold int - Source AutoCompactThresholdSource - ContextWindow int - ModelID string -} - -// fallbackAutoCompactThresholdResolution 构造自动推导失败时使用的保底阈值结果。 -func fallbackAutoCompactThresholdResolution(cfg config.Config) AutoCompactThresholdResolution { - return AutoCompactThresholdResolution{ - Threshold: cfg.Context.AutoCompact.FallbackInputTokenThreshold, - Source: AutoCompactThresholdSourceFallback, - ModelID: strings.TrimSpace(cfg.CurrentModel), - } -} - -// ResolveAutoCompactThreshold 基于当前选择的 provider/model 和模型目录快照解析最终阈值。 -func ResolveAutoCompactThreshold( - ctx context.Context, - cfg config.Config, - catalogs ModelCatalog, -) (AutoCompactThresholdResolution, error) { - autoCompact := cfg.Context.AutoCompact - if !autoCompact.Enabled { - return AutoCompactThresholdResolution{Source: AutoCompactThresholdSourceDisabled}, nil - } - - if autoCompact.InputTokenThreshold > 0 { - return AutoCompactThresholdResolution{ - Threshold: autoCompact.InputTokenThreshold, - Source: AutoCompactThresholdSourceExplicit, - ModelID: strings.TrimSpace(cfg.CurrentModel), - }, nil - } - - resolution := fallbackAutoCompactThresholdResolution(cfg) - providerCfg, err := selectedProviderConfig(cfg) - if err != nil { - return resolution, nil - } - if catalogs == nil { - return resolution, nil - } - - input, err := catalogInputFromProvider(providerCfg) - if err != nil { - return resolution, nil - } - - models, err := catalogs.ListProviderModelsSnapshot(ctx, input) - if err != nil { - return resolution, err - } - - modelID := provider.NormalizeKey(cfg.CurrentModel) - for _, model := range models { - if provider.NormalizeKey(model.ID) != modelID { - continue - } - resolution.ContextWindow = model.ContextWindow - if model.ContextWindow > autoCompact.ReserveTokens { - resolution.Threshold = model.ContextWindow - autoCompact.ReserveTokens - resolution.Source = AutoCompactThresholdSourceDerived - } - return resolution, nil - } - - return resolution, nil -} diff --git a/internal/config/state/auto_compact_threshold_test.go b/internal/config/state/auto_compact_threshold_test.go deleted file mode 100644 index 1d318c48..00000000 --- a/internal/config/state/auto_compact_threshold_test.go +++ /dev/null @@ -1,158 +0,0 @@ -package state - -import ( - "context" - "errors" - "testing" - - configpkg "neo-code/internal/config" - providertypes "neo-code/internal/provider/types" -) - -func assertAutoCompactResolution(t *testing.T, got AutoCompactThresholdResolution, wantThreshold int, wantSource AutoCompactThresholdSource) { - t.Helper() - - if got.Threshold != wantThreshold || got.Source != wantSource { - t.Fatalf("expected threshold=%d source=%s, got %+v", wantThreshold, wantSource, got) - } -} - -func TestResolveAutoCompactThresholdDisabled(t *testing.T) { - t.Parallel() - - cfg := configpkg.StaticDefaults().Clone() - cfg.Context.AutoCompact.Enabled = false - - resolution, err := ResolveAutoCompactThreshold(context.Background(), cfg, nil) - if err != nil { - t.Fatalf("ResolveAutoCompactThreshold() error = %v", err) - } - assertAutoCompactResolution(t, resolution, 0, AutoCompactThresholdSourceDisabled) -} - -func TestResolveAutoCompactThresholdExplicitWins(t *testing.T) { - t.Parallel() - - cfg := configpkg.StaticDefaults().Clone() - cfg.Context.AutoCompact.Enabled = true - cfg.Context.AutoCompact.InputTokenThreshold = 42000 - - resolution, err := ResolveAutoCompactThreshold(context.Background(), cfg, nil) - if err != nil { - t.Fatalf("ResolveAutoCompactThreshold() error = %v", err) - } - assertAutoCompactResolution(t, resolution, 42000, AutoCompactThresholdSourceExplicit) -} - -func TestResolveAutoCompactThresholdDerivedFromContextWindow(t *testing.T) { - t.Parallel() - - cfg := testDefaultConfig().Clone() - cfg.Context.AutoCompact.Enabled = true - cfg.Context.AutoCompact.InputTokenThreshold = 0 - cfg.Context.AutoCompact.ReserveTokens = 13000 - cfg.CurrentModel = "deepseek-coder" - cfg.Providers[0].Model = "deepseek-coder" - cfg.Providers[0].Models = []providertypes.ModelDescriptor{{ - ID: "deepseek-coder", - ContextWindow: 131072, - }} - - resolution, err := ResolveAutoCompactThreshold(context.Background(), cfg, catalogMethodsStub{ - snapshotModels: cfg.Providers[0].Models, - }) - if err != nil { - t.Fatalf("ResolveAutoCompactThreshold() error = %v", err) - } - assertAutoCompactResolution(t, resolution, 118072, AutoCompactThresholdSourceDerived) -} - -func TestResolveAutoCompactThresholdFallsBackWhenWindowTooSmall(t *testing.T) { - t.Parallel() - - cfg := testDefaultConfig().Clone() - cfg.Context.AutoCompact.Enabled = true - cfg.Context.AutoCompact.InputTokenThreshold = 0 - cfg.Context.AutoCompact.ReserveTokens = 13000 - cfg.Context.AutoCompact.FallbackInputTokenThreshold = 88000 - cfg.CurrentModel = "small-model" - cfg.Providers[0].Model = "small-model" - - resolution, err := ResolveAutoCompactThreshold(context.Background(), cfg, catalogMethodsStub{ - snapshotModels: []providertypes.ModelDescriptor{{ - ID: "small-model", - ContextWindow: 8000, - }}, - }) - if err != nil { - t.Fatalf("ResolveAutoCompactThreshold() error = %v", err) - } - assertAutoCompactResolution(t, resolution, 88000, AutoCompactThresholdSourceFallback) -} - -func TestResolveAutoCompactThresholdFallsBackWhenModelMissing(t *testing.T) { - t.Parallel() - - cfg := testDefaultConfig().Clone() - cfg.Context.AutoCompact.Enabled = true - cfg.Context.AutoCompact.InputTokenThreshold = 0 - cfg.Context.AutoCompact.FallbackInputTokenThreshold = 88000 - cfg.CurrentModel = "missing-model" - - resolution, err := ResolveAutoCompactThreshold(context.Background(), cfg, catalogMethodsStub{ - snapshotModels: []providertypes.ModelDescriptor{{ID: "other-model", ContextWindow: 131072}}, - }) - if err != nil { - t.Fatalf("ResolveAutoCompactThreshold() error = %v", err) - } - assertAutoCompactResolution(t, resolution, 88000, AutoCompactThresholdSourceFallback) -} - -func TestResolveAutoCompactThresholdFallsBackWhenSelectedProviderInvalid(t *testing.T) { - t.Parallel() - - cfg := testDefaultConfig().Clone() - cfg.Context.AutoCompact.Enabled = true - cfg.Context.AutoCompact.InputTokenThreshold = 0 - cfg.Context.AutoCompact.FallbackInputTokenThreshold = 88000 - cfg.SelectedProvider = "missing-provider" - - resolution, err := ResolveAutoCompactThreshold(context.Background(), cfg, catalogMethodsStub{}) - if err != nil { - t.Fatalf("ResolveAutoCompactThreshold() error = %v", err) - } - assertAutoCompactResolution(t, resolution, 88000, AutoCompactThresholdSourceFallback) -} - -func TestResolveAutoCompactThresholdFallsBackWhenCatalogInputResolutionFails(t *testing.T) { - t.Parallel() - - cfg := testDefaultConfig().Clone() - cfg.Context.AutoCompact.Enabled = true - cfg.Context.AutoCompact.InputTokenThreshold = 0 - cfg.Context.AutoCompact.FallbackInputTokenThreshold = 88000 - cfg.Providers[0].BaseURL = "" - - resolution, err := ResolveAutoCompactThreshold(context.Background(), cfg, catalogMethodsStub{}) - if err != nil { - t.Fatalf("ResolveAutoCompactThreshold() error = %v", err) - } - assertAutoCompactResolution(t, resolution, 88000, AutoCompactThresholdSourceFallback) -} - -func TestResolveAutoCompactThresholdFallsBackWhenSnapshotLookupFails(t *testing.T) { - t.Parallel() - - cfg := testDefaultConfig().Clone() - cfg.Context.AutoCompact.Enabled = true - cfg.Context.AutoCompact.InputTokenThreshold = 0 - cfg.Context.AutoCompact.FallbackInputTokenThreshold = 88000 - - resolution, err := ResolveAutoCompactThreshold(context.Background(), cfg, catalogMethodsStub{ - snapshotErr: errors.New("snapshot failed"), - }) - if err == nil { - t.Fatalf("ResolveAutoCompactThreshold() error = nil, want non-nil") - } - assertAutoCompactResolution(t, resolution, 88000, AutoCompactThresholdSourceFallback) -} diff --git a/internal/config/state/budget.go b/internal/config/state/budget.go new file mode 100644 index 00000000..afdf8f32 --- /dev/null +++ b/internal/config/state/budget.go @@ -0,0 +1,85 @@ +package state + +import ( + "context" + "strings" + + "neo-code/internal/config" + "neo-code/internal/provider" +) + +// PromptBudgetSource 标识 prompt budget 最终采用的来源。 +type PromptBudgetSource string + +const ( + PromptBudgetSourceExplicit PromptBudgetSource = "explicit" + PromptBudgetSourceDerived PromptBudgetSource = "derived" + PromptBudgetSourceFallback PromptBudgetSource = "fallback" +) + +// PromptBudgetResolution 描述 prompt budget 的解析结果,供 runtime 直接消费。 +type PromptBudgetResolution struct { + PromptBudget int + Source PromptBudgetSource + ContextWindow int + ModelID string +} + +// fallbackPromptBudgetResolution 构造自动推导失败时使用的保底预算结果。 +func fallbackPromptBudgetResolution(cfg config.Config) PromptBudgetResolution { + return PromptBudgetResolution{ + PromptBudget: cfg.Context.Budget.FallbackPromptBudget, + Source: PromptBudgetSourceFallback, + ModelID: strings.TrimSpace(cfg.CurrentModel), + } +} + +// ResolvePromptBudget 基于当前选择的 provider/model 和模型目录快照解析最终输入预算。 +func ResolvePromptBudget( + ctx context.Context, + cfg config.Config, + catalogs ModelCatalog, +) (PromptBudgetResolution, error) { + budget := cfg.Context.Budget + if budget.PromptBudget > 0 { + return PromptBudgetResolution{ + PromptBudget: budget.PromptBudget, + Source: PromptBudgetSourceExplicit, + ModelID: strings.TrimSpace(cfg.CurrentModel), + }, nil + } + + resolution := fallbackPromptBudgetResolution(cfg) + providerCfg, err := selectedProviderConfig(cfg) + if err != nil { + return resolution, nil + } + if catalogs == nil { + return resolution, nil + } + + input, err := catalogInputFromProvider(providerCfg) + if err != nil { + return resolution, nil + } + + models, err := catalogs.ListProviderModelsSnapshot(ctx, input) + if err != nil { + return resolution, err + } + + modelID := provider.NormalizeKey(cfg.CurrentModel) + for _, model := range models { + if provider.NormalizeKey(model.ID) != modelID { + continue + } + resolution.ContextWindow = model.ContextWindow + if model.ContextWindow > budget.ReserveTokens { + resolution.PromptBudget = model.ContextWindow - budget.ReserveTokens + resolution.Source = PromptBudgetSourceDerived + } + return resolution, nil + } + + return resolution, nil +} diff --git a/internal/config/state/budget_test.go b/internal/config/state/budget_test.go new file mode 100644 index 00000000..e0fa1eae --- /dev/null +++ b/internal/config/state/budget_test.go @@ -0,0 +1,143 @@ +package state + +import ( + "context" + "errors" + "testing" + + configpkg "neo-code/internal/config" + providertypes "neo-code/internal/provider/types" +) + +func assertPromptBudgetResolution( + t *testing.T, + got PromptBudgetResolution, + wantBudget int, + wantSource PromptBudgetSource, +) { + t.Helper() + + if got.PromptBudget != wantBudget || got.Source != wantSource { + t.Fatalf("expected budget=%d source=%s, got %+v", wantBudget, wantSource, got) + } +} + +func TestResolvePromptBudgetExplicitWins(t *testing.T) { + t.Parallel() + + cfg := configpkg.StaticDefaults().Clone() + cfg.Context.Budget.PromptBudget = 42000 + + resolution, err := ResolvePromptBudget(context.Background(), cfg, nil) + if err != nil { + t.Fatalf("ResolvePromptBudget() error = %v", err) + } + assertPromptBudgetResolution(t, resolution, 42000, PromptBudgetSourceExplicit) +} + +func TestResolvePromptBudgetDerivedFromContextWindow(t *testing.T) { + t.Parallel() + + cfg := testDefaultConfig().Clone() + cfg.Context.Budget.PromptBudget = 0 + cfg.Context.Budget.ReserveTokens = 13000 + cfg.CurrentModel = "deepseek-coder" + cfg.Providers[0].Model = "deepseek-coder" + cfg.Providers[0].Models = []providertypes.ModelDescriptor{{ + ID: "deepseek-coder", + ContextWindow: 131072, + }} + + resolution, err := ResolvePromptBudget(context.Background(), cfg, catalogMethodsStub{ + snapshotModels: cfg.Providers[0].Models, + }) + if err != nil { + t.Fatalf("ResolvePromptBudget() error = %v", err) + } + assertPromptBudgetResolution(t, resolution, 118072, PromptBudgetSourceDerived) +} + +func TestResolvePromptBudgetFallsBackWhenWindowTooSmall(t *testing.T) { + t.Parallel() + + cfg := testDefaultConfig().Clone() + cfg.Context.Budget.PromptBudget = 0 + cfg.Context.Budget.ReserveTokens = 13000 + cfg.Context.Budget.FallbackPromptBudget = 88000 + cfg.CurrentModel = "small-model" + cfg.Providers[0].Model = "small-model" + + resolution, err := ResolvePromptBudget(context.Background(), cfg, catalogMethodsStub{ + snapshotModels: []providertypes.ModelDescriptor{{ + ID: "small-model", + ContextWindow: 8000, + }}, + }) + if err != nil { + t.Fatalf("ResolvePromptBudget() error = %v", err) + } + assertPromptBudgetResolution(t, resolution, 88000, PromptBudgetSourceFallback) +} + +func TestResolvePromptBudgetFallsBackWhenModelMissing(t *testing.T) { + t.Parallel() + + cfg := testDefaultConfig().Clone() + cfg.Context.Budget.PromptBudget = 0 + cfg.Context.Budget.FallbackPromptBudget = 88000 + cfg.CurrentModel = "missing-model" + + resolution, err := ResolvePromptBudget(context.Background(), cfg, catalogMethodsStub{ + snapshotModels: []providertypes.ModelDescriptor{{ID: "other-model", ContextWindow: 131072}}, + }) + if err != nil { + t.Fatalf("ResolvePromptBudget() error = %v", err) + } + assertPromptBudgetResolution(t, resolution, 88000, PromptBudgetSourceFallback) +} + +func TestResolvePromptBudgetFallsBackWhenSelectedProviderInvalid(t *testing.T) { + t.Parallel() + + cfg := testDefaultConfig().Clone() + cfg.Context.Budget.PromptBudget = 0 + cfg.Context.Budget.FallbackPromptBudget = 88000 + cfg.SelectedProvider = "missing-provider" + + resolution, err := ResolvePromptBudget(context.Background(), cfg, catalogMethodsStub{}) + if err != nil { + t.Fatalf("ResolvePromptBudget() error = %v", err) + } + assertPromptBudgetResolution(t, resolution, 88000, PromptBudgetSourceFallback) +} + +func TestResolvePromptBudgetFallsBackWhenCatalogInputResolutionFails(t *testing.T) { + t.Parallel() + + cfg := testDefaultConfig().Clone() + cfg.Context.Budget.PromptBudget = 0 + cfg.Context.Budget.FallbackPromptBudget = 88000 + cfg.Providers[0].BaseURL = "" + + resolution, err := ResolvePromptBudget(context.Background(), cfg, catalogMethodsStub{}) + if err != nil { + t.Fatalf("ResolvePromptBudget() error = %v", err) + } + assertPromptBudgetResolution(t, resolution, 88000, PromptBudgetSourceFallback) +} + +func TestResolvePromptBudgetFallsBackWhenSnapshotLookupFails(t *testing.T) { + t.Parallel() + + cfg := testDefaultConfig().Clone() + cfg.Context.Budget.PromptBudget = 0 + cfg.Context.Budget.FallbackPromptBudget = 88000 + + resolution, err := ResolvePromptBudget(context.Background(), cfg, catalogMethodsStub{ + snapshotErr: errors.New("snapshot failed"), + }) + if err == nil { + t.Fatalf("ResolvePromptBudget() error = nil, want non-nil") + } + assertPromptBudgetResolution(t, resolution, 88000, PromptBudgetSourceFallback) +} diff --git a/internal/config/state/model_additional_test.go b/internal/config/state/model_additional_test.go index 89772ca1..95d4ab1e 100644 --- a/internal/config/state/model_additional_test.go +++ b/internal/config/state/model_additional_test.go @@ -325,7 +325,7 @@ func TestSelectionServiceEnsureSelectionBootstrapInitialWhenNoSelection(t *testi t.Fatal("expected non-empty ProviderID after bootstrap") } - reloaded, _ := manager.Reload(context.Background()) + reloaded, _ := manager.Load(context.Background()) if reloaded.SelectedProvider == "" { t.Fatal("expected persisted selection after bootstrap") } diff --git a/internal/config/state/model_test.go b/internal/config/state/model_test.go index fc1a6b90..de9e7e22 100644 --- a/internal/config/state/model_test.go +++ b/internal/config/state/model_test.go @@ -57,9 +57,16 @@ func TestCatalogInputFromProviderBuiltinIncludesDefaultsAndLazyDiscovery(t *test if err != nil { t.Fatalf("ResolveDiscoveryConfig() error = %v", err) } - if runtimeConfig.DefaultModel != "server-default" || runtimeConfig.APIKey != "secret-key" { + if runtimeConfig.DefaultModel != "server-default" { t.Fatalf("expected runtime config to resolve model and api key, got %+v", runtimeConfig) } + apiKey, err := runtimeConfig.ResolveAPIKeyValue() + if err != nil { + t.Fatalf("ResolveAPIKeyValue() error = %v", err) + } + if apiKey != "secret-key" { + t.Fatalf("expected resolved api key secret-key, got %q", apiKey) + } } func TestCatalogInputFromProviderDefaultsOpenAICompatibleIdentityPaths(t *testing.T) { @@ -137,7 +144,11 @@ func TestCatalogInputFromProviderResolveDiscoveryConfigPropagatesResolveError(t t.Fatalf("catalogInputFromProvider() error = %v", err) } - _, err = input.ResolveDiscoveryConfig() + runtimeConfig, err := input.ResolveDiscoveryConfig() + if err != nil { + t.Fatalf("ResolveDiscoveryConfig() error = %v", err) + } + _, err = runtimeConfig.ResolveAPIKeyValue() if err == nil || !strings.Contains(err.Error(), "environment variable MISSING_PROVIDER_API_KEY is empty") { t.Fatalf("expected resolve api key error, got %v", err) } diff --git a/internal/config/state/service_provider_create.go b/internal/config/state/service_provider_create.go index ad66342a..158b94ab 100644 --- a/internal/config/state/service_provider_create.go +++ b/internal/config/state/service_provider_create.go @@ -123,7 +123,7 @@ func (s *Service) CreateCustomProvider(ctx context.Context, input CreateCustomPr if providerSaved { reloadCtx, cancel := context.WithTimeout(context.Background(), providerCreateRollbackReloadTimeout) defer cancel() - if _, reloadErr := s.manager.Reload(reloadCtx); reloadErr != nil { + if _, reloadErr := s.manager.Load(reloadCtx); reloadErr != nil { return fmt.Errorf("%w (post-rollback reload failed: %v)", rolledErr, reloadErr) } } @@ -156,7 +156,7 @@ func (s *Service) CreateCustomProvider(ctx context.Context, input CreateCustomPr } processEnvApplied = true - if _, err := s.manager.Reload(ctx); err != nil { + if _, err := s.manager.Load(ctx); err != nil { return Selection{}, rollback(fmt.Errorf("selection: reload config snapshot: %w", err)) } diff --git a/internal/config/state/service_test.go b/internal/config/state/service_test.go index 9e30b338..8559c907 100644 --- a/internal/config/state/service_test.go +++ b/internal/config/state/service_test.go @@ -104,9 +104,9 @@ func TestSelectionServiceBuiltinUnsupportedAPIStyleNoLongerFailsAcrossSnapshotPa t.Fatalf("expected EnsureSelection() to remain available, got %v", err) } - reloaded, err := manager.Reload(context.Background()) + reloaded, err := manager.Load(context.Background()) if err != nil { - t.Fatalf("Reload() error = %v", err) + t.Fatalf("Load() error = %v", err) } if reloaded.SelectedProvider != providerCfg.Name { t.Fatalf("expected selected provider to stay on %q, got %q", providerCfg.Name, reloaded.SelectedProvider) @@ -328,9 +328,9 @@ func TestSelectionServiceEnsureSelectionRepairsInvalidCurrentModel(t *testing.T) t.Fatalf("unexpected normalized selection: %+v", selection) } - reloaded, err := manager.Reload(context.Background()) + reloaded, err := manager.Load(context.Background()) if err != nil { - t.Fatalf("Reload() error = %v", err) + t.Fatalf("Load() error = %v", err) } if reloaded.CurrentModel != OpenAIDefaultModel { t.Fatalf("expected rewritten current model %q, got %q", OpenAIDefaultModel, reloaded.CurrentModel) @@ -423,9 +423,9 @@ func TestSelectionServiceEnsureSelectionFallsBackToBuiltinDefaultModelWhenSnapsh t.Fatalf("expected builtin ensure to use snapshot catalog only, got %+v", *tracker) } - reloaded, err := manager.Reload(context.Background()) + reloaded, err := manager.Load(context.Background()) if err != nil { - t.Fatalf("Reload() error = %v", err) + t.Fatalf("Load() error = %v", err) } if reloaded.CurrentModel != OpenAIDefaultModel { t.Fatalf("expected builtin fallback to persist %q, got %q", OpenAIDefaultModel, reloaded.CurrentModel) @@ -462,9 +462,9 @@ func TestSelectionServiceEnsureSelectionKeepsCustomSelectionWhenSnapshotMissing( t.Fatalf("expected custom ensure to use snapshot catalog only, got %+v", *tracker) } - reloaded, err := manager.Reload(context.Background()) + reloaded, err := manager.Load(context.Background()) if err != nil { - t.Fatalf("Reload() error = %v", err) + t.Fatalf("Load() error = %v", err) } if reloaded.CurrentModel != "unknown-model" { t.Fatalf("expected custom current model to stay unchanged, got %q", reloaded.CurrentModel) @@ -507,9 +507,9 @@ func TestSelectionServiceEnsureSelectionBackfillsEmptyCustomModelFromSynchronous t.Fatalf("expected custom ensure to try snapshot first and then sync discovery, got %+v", *tracker) } - reloaded, err := manager.Reload(context.Background()) + reloaded, err := manager.Load(context.Background()) if err != nil { - t.Fatalf("Reload() error = %v", err) + t.Fatalf("Load() error = %v", err) } if reloaded.CurrentModel != "server-coder" { t.Fatalf("expected discovered current model to persist, got %q", reloaded.CurrentModel) @@ -549,9 +549,9 @@ func TestSelectionServiceEnsureSelectionKeepsEmptyCustomModelWhenSynchronousDisc t.Fatalf("expected custom ensure to attempt sync discovery once, got %+v", *tracker) } - reloaded, err := manager.Reload(context.Background()) + reloaded, err := manager.Load(context.Background()) if err != nil { - t.Fatalf("Reload() error = %v", err) + t.Fatalf("Load() error = %v", err) } if reloaded.CurrentModel != "" { t.Fatalf("expected empty current model to stay unchanged after failed discovery, got %q", reloaded.CurrentModel) @@ -593,9 +593,9 @@ func TestSelectionServiceEnsureSelectionReturnsBootstrappedSelectionWhenCustomDi t.Fatalf("expected custom ensure to attempt snapshot and one sync discovery, got %+v", *tracker) } - reloaded, err := manager.Reload(context.Background()) + reloaded, err := manager.Load(context.Background()) if err != nil { - t.Fatalf("Reload() error = %v", err) + t.Fatalf("Load() error = %v", err) } if reloaded.SelectedProvider != "company-gateway" { t.Fatalf("expected selected provider to persist as company-gateway, got %q", reloaded.SelectedProvider) @@ -626,9 +626,9 @@ func TestSelectionServiceEnsureSelectionRetriesWhenProviderDriftsDuringUpdate(t t.Fatalf("expected retried selection to use drifted provider snapshot, got %+v", selection) } - reloaded, err := manager.Reload(context.Background()) + reloaded, err := manager.Load(context.Background()) if err != nil { - t.Fatalf("Reload() error = %v", err) + t.Fatalf("Load() error = %v", err) } if reloaded.SelectedProvider != QiniuName { t.Fatalf("expected selected provider to persist as %q, got %q", QiniuName, reloaded.SelectedProvider) @@ -659,9 +659,9 @@ func TestSelectionServiceEnsureSelectionRetriesWhenProviderDriftsBeforeEarlyRetu t.Fatalf("expected drifted provider selection after retry, got %+v", selection) } - reloaded, err := manager.Reload(context.Background()) + reloaded, err := manager.Load(context.Background()) if err != nil { - t.Fatalf("Reload() error = %v", err) + t.Fatalf("Load() error = %v", err) } if reloaded.SelectedProvider != QiniuName { t.Fatalf("expected selected provider to persist as %q, got %q", QiniuName, reloaded.SelectedProvider) diff --git a/internal/context/builder.go b/internal/context/builder.go index b9c7f269..8d0415b1 100644 --- a/internal/context/builder.go +++ b/internal/context/builder.go @@ -96,13 +96,16 @@ func (b *DefaultBuilder) Build(ctx context.Context, input BuildInput) (BuildResu pinChecker = NewDefaultPinChecker() } - shouldAutoCompact := input.Compact.AutoCompactThreshold > 0 && - input.Metadata.SessionInputTokens >= input.Compact.AutoCompactThreshold - return BuildResult{ - SystemPrompt: composeSystemPrompt(sections...), - Messages: applyReadTimeContextProjection(trimPolicy.Trim(input.Messages, input.Compact), input.TaskState, input.Compact, b.microCompactPolicies, b.microCompactSummarizers, pinChecker), - AutoCompactSuggested: shouldAutoCompact, + SystemPrompt: composeSystemPrompt(sections...), + Messages: applyReadTimeContextProjection( + trimPolicy.Trim(input.Messages, input.Compact), + input.TaskState, + input.Compact, + b.microCompactPolicies, + b.microCompactSummarizers, + pinChecker, + ), }, nil } diff --git a/internal/context/builder_test.go b/internal/context/builder_test.go index bb3cdc04..e53cb976 100644 --- a/internal/context/builder_test.go +++ b/internal/context/builder_test.go @@ -839,86 +839,6 @@ func TestTrimMessagesBoundaries(t *testing.T) { } } -func TestBuildAutoCompactSuggestedDisabled(t *testing.T) { - t.Parallel() - - builder := NewBuilder() - input := BuildInput{ - Messages: []providertypes.Message{{Role: "user", Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}}}, - Metadata: testMetadata(t.TempDir()), - Compact: CompactOptions{AutoCompactThreshold: 0}, - } - input.Metadata.SessionInputTokens = 100 - - result, err := builder.Build(stdcontext.Background(), input) - if err != nil { - t.Fatalf("Build() error = %v", err) - } - if result.AutoCompactSuggested { - t.Fatalf("expected AutoCompactSuggested false when threshold is 0") - } -} - -func TestBuildAutoCompactSuggestedBelowThreshold(t *testing.T) { - t.Parallel() - - builder := NewBuilder() - input := BuildInput{ - Messages: []providertypes.Message{{Role: "user", Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}}}, - Metadata: testMetadata(t.TempDir()), - Compact: CompactOptions{AutoCompactThreshold: 100}, - } - input.Metadata.SessionInputTokens = 99 - - result, err := builder.Build(stdcontext.Background(), input) - if err != nil { - t.Fatalf("Build() error = %v", err) - } - if result.AutoCompactSuggested { - t.Fatalf("expected AutoCompactSuggested false when tokens below threshold") - } -} - -func TestBuildAutoCompactSuggestedAtThreshold(t *testing.T) { - t.Parallel() - - builder := NewBuilder() - input := BuildInput{ - Messages: []providertypes.Message{{Role: "user", Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}}}, - Metadata: testMetadata(t.TempDir()), - Compact: CompactOptions{AutoCompactThreshold: 100}, - } - input.Metadata.SessionInputTokens = 100 - - result, err := builder.Build(stdcontext.Background(), input) - if err != nil { - t.Fatalf("Build() error = %v", err) - } - if !result.AutoCompactSuggested { - t.Fatalf("expected AutoCompactSuggested true when tokens equal threshold") - } -} - -func TestBuildAutoCompactSuggestedAboveThreshold(t *testing.T) { - t.Parallel() - - builder := NewBuilder() - input := BuildInput{ - Messages: []providertypes.Message{{Role: "user", Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}}}, - Metadata: testMetadata(t.TempDir()), - Compact: CompactOptions{AutoCompactThreshold: 100}, - } - input.Metadata.SessionInputTokens = 200 - - result, err := builder.Build(stdcontext.Background(), input) - if err != nil { - t.Fatalf("Build() error = %v", err) - } - if !result.AutoCompactSuggested { - t.Fatalf("expected AutoCompactSuggested true when tokens above threshold") - } -} - func TestNewBuilderWithMemo(t *testing.T) { t.Parallel() diff --git a/internal/context/compact/runner.go b/internal/context/compact/runner.go index bfa6aafd..9d680ec4 100644 --- a/internal/context/compact/runner.go +++ b/internal/context/compact/runner.go @@ -19,8 +19,8 @@ type Mode string const ( // ModeManual runs the explicit user-triggered compact flow. ModeManual Mode = "manual" - // ModeAuto runs the token-threshold-triggered compact flow. - ModeAuto Mode = "auto" + // ModeProactive runs the budget-triggered compact flow before provider send. + ModeProactive Mode = "proactive" // ModeReactive runs the provider-error-triggered compact flow. ModeReactive Mode = "reactive" ) @@ -127,7 +127,7 @@ func (s *Service) Run(ctx context.Context, input Input) (Result, error) { return Result{}, err } - if input.Mode != ModeManual && input.Mode != ModeAuto && input.Mode != ModeReactive { + if input.Mode != ModeManual && input.Mode != ModeProactive && input.Mode != ModeReactive { return Result{}, fmt.Errorf("compact: unsupported mode %q", input.Mode) } diff --git a/internal/context/compact/runner_test.go b/internal/context/compact/runner_test.go index 8e56764d..5aad2712 100644 --- a/internal/context/compact/runner_test.go +++ b/internal/context/compact/runner_test.go @@ -258,7 +258,7 @@ func TestReactiveCompactUsesKeepRecentAndReportsReactiveMode(t *testing.T) { } } -func TestAutoCompactUsesManualStrategyAndReportsAutoMode(t *testing.T) { +func TestProactiveCompactUsesManualStrategyAndReportsProactiveMode(t *testing.T) { t.Parallel() generator := &stubSummaryGenerator{output: validSummaryOutput()} @@ -276,7 +276,7 @@ func TestAutoCompactUsesManualStrategyAndReportsAutoMode(t *testing.T) { } result, err := runner.Run(context.Background(), Input{ - Mode: ModeAuto, + Mode: ModeProactive, SessionID: "session-auto", Workdir: t.TempDir(), Messages: messages, @@ -292,14 +292,14 @@ func TestAutoCompactUsesManualStrategyAndReportsAutoMode(t *testing.T) { if !result.Applied { t.Fatalf("expected auto compact applied") } - if result.Metrics.TriggerMode != string(ModeAuto) { - t.Fatalf("expected trigger mode %q, got %q", ModeAuto, result.Metrics.TriggerMode) + if result.Metrics.TriggerMode != string(ModeProactive) { + t.Fatalf("expected trigger mode %q, got %q", ModeProactive, result.Metrics.TriggerMode) } if len(generator.calls) != 1 { t.Fatalf("expected generator to run once, got %d", len(generator.calls)) } - if generator.calls[0].Mode != ModeAuto { - t.Fatalf("expected summary input mode %q, got %q", ModeAuto, generator.calls[0].Mode) + if generator.calls[0].Mode != ModeProactive { + t.Fatalf("expected summary input mode %q, got %q", ModeProactive, generator.calls[0].Mode) } if generator.calls[0].Config.ManualStrategy != config.CompactManualStrategyKeepRecent { t.Fatalf("expected auto compact to retain manual strategy, got %q", generator.calls[0].Config.ManualStrategy) diff --git a/internal/context/types.go b/internal/context/types.go index 14ab09e1..0e8a6dec 100644 --- a/internal/context/types.go +++ b/internal/context/types.go @@ -26,9 +26,8 @@ type BuildInput struct { // BuildResult is the provider-facing context produced for a single round. type BuildResult struct { - SystemPrompt string - Messages []providertypes.Message - AutoCompactSuggested bool + SystemPrompt string + Messages []providertypes.Message } // MicroCompactPolicySource 定义 context 读取工具 micro compact 策略的最小依赖。 @@ -49,7 +48,6 @@ type MicroCompactPinChecker interface { // CompactOptions controls read-time compact behavior inside the context builder. type CompactOptions struct { DisableMicroCompact bool - AutoCompactThreshold int MicroCompactRetainedToolSpans int ReadTimeMaxMessageSpans int } diff --git a/internal/memo/store.go b/internal/memo/store.go index d7ee2036..ca5f4c63 100644 --- a/internal/memo/store.go +++ b/internal/memo/store.go @@ -69,7 +69,7 @@ func (s *FileStore) LoadIndex(ctx context.Context, scope Scope) (*Index, error) s.mu.RLock() defer s.mu.RUnlock() - data, err := readFirstExistingFile(s.indexPaths(scope)) + data, err := os.ReadFile(indexPath) if errors.Is(err, os.ErrNotExist) { return &Index{}, nil } @@ -101,7 +101,7 @@ func (s *FileStore) LoadIndex(ctx context.Context, scope Scope) (*Index, error) return cloneIndex(cachedContent), nil } -// SaveIndex 将索引写入指定分层下的 MEMO.md 文件,采用临时文件 + 原子替换策略。 +// SaveIndex 将索引写入指定分层下的 MEMO.md 文件,采用临时文件加原子替换策略。 func (s *FileStore) SaveIndex(ctx context.Context, scope Scope, index *Index) error { if err := ctx.Err(); err != nil { return err @@ -116,10 +116,6 @@ func (s *FileStore) SaveIndex(ctx context.Context, scope Scope, index *Index) er s.mu.Lock() defer s.mu.Unlock() - if err := s.migrateLegacyProjectScopeLocked(scope); err != nil { - return err - } - dir := s.scopeDir(scope) if err := os.MkdirAll(dir, 0o755); err != nil { return fmt.Errorf("memo: create memo dir: %w", err) @@ -169,14 +165,14 @@ func (s *FileStore) LoadTopic(ctx context.Context, scope Scope, filename string) s.mu.RLock() defer s.mu.RUnlock() - data, err := readFirstExistingFile(s.topicPaths(scope, filename)) + data, err := os.ReadFile(s.topicPath(scope, filename)) if err != nil { return "", fmt.Errorf("memo: read topic %s: %w", filename, err) } return string(data), nil } -// SaveTopic 将内容写入指定分层下的 topic 文件,采用临时文件 + 原子替换策略。 +// SaveTopic 将内容写入指定分层下的 topic 文件,采用临时文件加原子替换策略。 func (s *FileStore) SaveTopic(ctx context.Context, scope Scope, filename string, content string) error { if err := ctx.Err(); err != nil { return err @@ -188,10 +184,6 @@ func (s *FileStore) SaveTopic(ctx context.Context, scope Scope, filename string, s.mu.Lock() defer s.mu.Unlock() - if err := s.migrateLegacyProjectScopeLocked(scope); err != nil { - return err - } - dir := s.topicsDir(scope) if err := os.MkdirAll(dir, 0o755); err != nil { return fmt.Errorf("memo: create topics dir: %w", err) @@ -223,10 +215,6 @@ func (s *FileStore) DeleteTopic(ctx context.Context, scope Scope, filename strin s.mu.Lock() defer s.mu.Unlock() - if err := s.migrateLegacyProjectScopeLocked(scope); err != nil { - return err - } - path := s.topicPath(scope, filename) if err := os.Remove(path); err != nil && !errors.Is(err, os.ErrNotExist) { return fmt.Errorf("memo: delete topic %s: %w", filename, err) @@ -234,7 +222,7 @@ func (s *FileStore) DeleteTopic(ctx context.Context, scope Scope, filename strin return nil } -// ListTopics 列出指定分层下 topics 目录中的所有 .md 文件名。 +// ListTopics 列出指定分层 topics 目录中的全部 .md 文件名。 func (s *FileStore) ListTopics(ctx context.Context, scope Scope) ([]string, error) { if err := ctx.Err(); err != nil { return nil, err @@ -246,22 +234,22 @@ func (s *FileStore) ListTopics(ctx context.Context, scope Scope) ([]string, erro s.mu.RLock() defer s.mu.RUnlock() - seen := make(map[string]struct{}) - for _, dir := range s.topicsDirs(scope) { - entries, err := os.ReadDir(dir) - if err != nil { - if errors.Is(err, os.ErrNotExist) { - continue - } - return nil, fmt.Errorf("memo: list topics: %w", err) - } - for _, name := range collectTopicNames(entries) { - seen[name] = struct{}{} + entries, err := os.ReadDir(s.topicsDir(scope)) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil, nil } + return nil, fmt.Errorf("memo: list topics: %w", err) + } + + seen := make(map[string]struct{}) + for _, name := range collectTopicNames(entries) { + seen[name] = struct{}{} } if len(seen) == 0 { return nil, nil } + names := make([]string, 0, len(seen)) for name := range seen { names = append(names, name) @@ -282,20 +270,6 @@ func collectTopicNames(entries []os.DirEntry) []string { return names } -// readFirstExistingFile 按顺序读取候选路径,返回首个存在文件内容;若均不存在则返回 os.ErrNotExist。 -func readFirstExistingFile(paths []string) ([]byte, error) { - for _, path := range paths { - data, err := os.ReadFile(path) - if err == nil { - return data, nil - } - if !errors.Is(err, os.ErrNotExist) { - return nil, err - } - } - return nil, os.ErrNotExist -} - // scopeDir 返回指定 memo 分层的根目录。 func (s *FileStore) scopeDir(scope Scope) string { if scope == ScopeUser { @@ -304,120 +278,16 @@ func (s *FileStore) scopeDir(scope Scope) string { return filepath.Join(projectMemoDirectory(s.baseDir, s.workspaceRoot), string(scope)) } -// scopeDirLegacy 返回旧版本 project scope 的根目录,仅用于兼容迁移。 -func (s *FileStore) scopeDirLegacy(scope Scope) string { - if scope == ScopeProject { - return projectMemoDirectory(s.baseDir, s.workspaceRoot) - } - return "" -} - -// indexPaths 返回读取索引时的候选路径,顺序为新路径优先、旧路径兜底。 -func (s *FileStore) indexPaths(scope Scope) []string { - paths := []string{filepath.Join(s.scopeDir(scope), memoFileName)} - if legacy := s.scopeDirLegacy(scope); legacy != "" { - paths = append(paths, filepath.Join(legacy, memoFileName)) - } - return paths -} - // topicsDir 返回指定 memo 分层的 topics 目录。 func (s *FileStore) topicsDir(scope Scope) string { return filepath.Join(s.scopeDir(scope), topicsDirName) } -// topicsDirs 返回读取 topics 时的候选目录,顺序为新路径优先、旧路径兜底。 -func (s *FileStore) topicsDirs(scope Scope) []string { - dirs := []string{s.topicsDir(scope)} - if legacy := s.scopeDirLegacy(scope); legacy != "" { - dirs = append(dirs, filepath.Join(legacy, topicsDirName)) - } - return dirs -} - -// topicPath 生成指定分层下 topic 文件的安全路径,防止目录穿越。 +// topicPath 生成指定分层中 topic 文件的安全路径,防止目录穿越。 func (s *FileStore) topicPath(scope Scope, filename string) string { return filepath.Join(s.topicsDir(scope), filepath.Base(filename)) } -// topicPaths 返回读取 topic 时的候选路径,顺序为新路径优先、旧路径兜底。 -func (s *FileStore) topicPaths(scope Scope, filename string) []string { - base := filepath.Base(filename) - paths := []string{filepath.Join(s.topicsDir(scope), base)} - if legacy := s.scopeDirLegacy(scope); legacy != "" { - paths = append(paths, filepath.Join(legacy, topicsDirName, base)) - } - return paths -} - -// migrateLegacyProjectScopeLocked 在首次写入前把旧版 project 目录迁移到新目录,避免历史数据不可见。 -func (s *FileStore) migrateLegacyProjectScopeLocked(scope Scope) error { - if scope != ScopeProject { - return nil - } - - legacyDir := s.scopeDirLegacy(scope) - if legacyDir == "" { - return nil - } - - if err := os.MkdirAll(s.scopeDir(scope), 0o755); err != nil { - return fmt.Errorf("memo: create scoped dir for migration: %w", err) - } - - legacyMemo := filepath.Join(legacyDir, memoFileName) - targetMemo := filepath.Join(s.scopeDir(scope), memoFileName) - if err := moveFileIfDstMissing(legacyMemo, targetMemo); err != nil { - return fmt.Errorf("memo: migrate legacy index: %w", err) - } - - legacyTopics := filepath.Join(legacyDir, topicsDirName) - legacyEntries, err := os.ReadDir(legacyTopics) - if err != nil { - if errors.Is(err, os.ErrNotExist) { - return nil - } - return fmt.Errorf("memo: list legacy topics: %w", err) - } - if len(legacyEntries) == 0 { - return nil - } - - newTopics := s.topicsDir(scope) - if err := os.MkdirAll(newTopics, 0o755); err != nil { - return fmt.Errorf("memo: create scoped topics dir for migration: %w", err) - } - for _, entry := range legacyEntries { - if entry.IsDir() || filepath.Ext(entry.Name()) != ".md" { - continue - } - oldPath := filepath.Join(legacyTopics, entry.Name()) - newPath := filepath.Join(newTopics, entry.Name()) - if err := moveFileIfDstMissing(oldPath, newPath); err != nil { - return fmt.Errorf("memo: migrate legacy topic %s: %w", entry.Name(), err) - } - } - - return nil -} - -// moveFileIfDstMissing 在源文件存在且目标文件不存在时执行迁移重命名。 -func moveFileIfDstMissing(src string, dst string) error { - if _, err := os.Stat(src); err != nil { - if errors.Is(err, os.ErrNotExist) { - return nil - } - return err - } - if _, err := os.Stat(dst); err == nil { - return nil - } else if !errors.Is(err, os.ErrNotExist) { - return err - } - return os.Rename(src, dst) -} - -// memoDirectory 根据工作区根目录计算记忆分桶目录,复用 session 包的工作区哈希。 // globalMemoDirectory 返回全局 memo 根目录,用于存放 user 层记忆。 func globalMemoDirectory(baseDir string) string { return filepath.Join(baseDir, memoDirName) @@ -428,7 +298,7 @@ func projectMemoDirectory(baseDir string, workspaceRoot string) string { return filepath.Join(baseDir, "projects", agentsession.HashWorkspaceRoot(workspaceRoot), memoDirName) } -// validateStorageScope 校验当前 scope 是否是允许落盘的 memo 分层。 +// validateStorageScope 校验当前 scope 是否允许落盘。 func validateStorageScope(scope Scope) error { switch scope { case ScopeUser, ScopeProject: diff --git a/internal/memo/store_test.go b/internal/memo/store_test.go index de58dbf5..61015d9f 100644 --- a/internal/memo/store_test.go +++ b/internal/memo/store_test.go @@ -2,7 +2,6 @@ package memo import ( "context" - "errors" "os" "path/filepath" "strings" @@ -307,127 +306,3 @@ func TestFileStoreWritesScopesToExpectedDirectories(t *testing.T) { t.Fatalf("expected project memo to exist: %v", err) } } - -func TestFileStoreLoadIndexFallsBackToLegacyProjectPath(t *testing.T) { - store, legacyDir := newLegacyProjectStore(t) - if err := os.MkdirAll(legacyDir, 0o755); err != nil { - t.Fatalf("MkdirAll(legacy) error = %v", err) - } - index := &Index{Entries: []Entry{{Type: TypeProject, Title: "legacy"}}} - if err := os.WriteFile(filepath.Join(legacyDir, memoFileName), []byte(RenderIndex(index)), 0o644); err != nil { - t.Fatalf("WriteFile(legacy index) error = %v", err) - } - - loaded, err := store.LoadIndex(context.Background(), ScopeProject) - if err != nil { - t.Fatalf("LoadIndex() error = %v", err) - } - if len(loaded.Entries) != 1 || loaded.Entries[0].Title != "legacy" { - t.Fatalf("loaded entries = %#v", loaded.Entries) - } -} - -func TestFileStoreLoadTopicAndListTopicsFallbackToLegacyProjectPath(t *testing.T) { - store, legacyDir := newLegacyProjectStore(t) - legacyTopicsDir := filepath.Join(legacyDir, topicsDirName) - if err := os.MkdirAll(legacyTopicsDir, 0o755); err != nil { - t.Fatalf("MkdirAll(legacy topics) error = %v", err) - } - if err := os.WriteFile(filepath.Join(legacyTopicsDir, "legacy.md"), []byte("legacy"), 0o644); err != nil { - t.Fatalf("WriteFile(legacy topic) error = %v", err) - } - - content, err := store.LoadTopic(context.Background(), ScopeProject, "legacy.md") - if err != nil { - t.Fatalf("LoadTopic() error = %v", err) - } - if content != "legacy" { - t.Fatalf("LoadTopic() = %q, want %q", content, "legacy") - } - - topics, err := store.ListTopics(context.Background(), ScopeProject) - if err != nil { - t.Fatalf("ListTopics() error = %v", err) - } - if len(topics) != 1 || topics[0] != "legacy.md" { - t.Fatalf("ListTopics() = %#v, want [legacy.md]", topics) - } -} - -func TestFileStoreListTopicsMergesScopedAndLegacyProjectTopics(t *testing.T) { - store, legacyDir := newLegacyProjectStore(t) - scopedTopicsDir := store.topicsDir(ScopeProject) - legacyTopicsDir := filepath.Join(legacyDir, topicsDirName) - if err := os.MkdirAll(scopedTopicsDir, 0o755); err != nil { - t.Fatalf("MkdirAll(scoped topics) error = %v", err) - } - if err := os.MkdirAll(legacyTopicsDir, 0o755); err != nil { - t.Fatalf("MkdirAll(legacy topics) error = %v", err) - } - if err := os.WriteFile(filepath.Join(scopedTopicsDir, "scoped.md"), []byte("scoped"), 0o644); err != nil { - t.Fatalf("WriteFile(scoped topic) error = %v", err) - } - if err := os.WriteFile(filepath.Join(legacyTopicsDir, "legacy.md"), []byte("legacy"), 0o644); err != nil { - t.Fatalf("WriteFile(legacy topic) error = %v", err) - } - if err := os.WriteFile(filepath.Join(legacyTopicsDir, "scoped.md"), []byte("legacy dup"), 0o644); err != nil { - t.Fatalf("WriteFile(legacy duplicate topic) error = %v", err) - } - - topics, err := store.ListTopics(context.Background(), ScopeProject) - if err != nil { - t.Fatalf("ListTopics() error = %v", err) - } - want := []string{"legacy.md", "scoped.md"} - if len(topics) != len(want) { - t.Fatalf("len(topics) = %d, want %d, topics = %#v", len(topics), len(want), topics) - } - for i := range want { - if topics[i] != want[i] { - t.Fatalf("topics[%d] = %q, want %q (topics=%#v)", i, topics[i], want[i], topics) - } - } -} - -func TestFileStoreSaveIndexMigratesLegacyProjectData(t *testing.T) { - store, legacyDir := newLegacyProjectStore(t) - legacyTopicsDir := filepath.Join(legacyDir, topicsDirName) - if err := os.MkdirAll(legacyTopicsDir, 0o755); err != nil { - t.Fatalf("MkdirAll(legacy topics) error = %v", err) - } - if err := os.WriteFile(filepath.Join(legacyDir, memoFileName), []byte(RenderIndex(&Index{ - Entries: []Entry{{Type: TypeProject, Title: "legacy index"}}, - })), 0o644); err != nil { - t.Fatalf("WriteFile(legacy index) error = %v", err) - } - if err := os.WriteFile(filepath.Join(legacyTopicsDir, "legacy.md"), []byte("legacy topic"), 0o644); err != nil { - t.Fatalf("WriteFile(legacy topic) error = %v", err) - } - - if err := store.SaveIndex(context.Background(), ScopeProject, &Index{ - Entries: []Entry{{Type: TypeProject, Title: "new index"}}, - }); err != nil { - t.Fatalf("SaveIndex() error = %v", err) - } - - newScopeDir := store.scopeDir(ScopeProject) - if _, err := os.Stat(filepath.Join(newScopeDir, memoFileName)); err != nil { - t.Fatalf("expected scoped index after migration: %v", err) - } - if _, err := os.Stat(filepath.Join(newScopeDir, topicsDirName, "legacy.md")); err != nil { - t.Fatalf("expected scoped topic after migration: %v", err) - } - if _, err := os.Stat(filepath.Join(legacyDir, memoFileName)); !errors.Is(err, os.ErrNotExist) { - t.Fatalf("expected legacy index to be migrated, stat err = %v", err) - } - if _, err := os.Stat(filepath.Join(legacyTopicsDir, "legacy.md")); !errors.Is(err, os.ErrNotExist) { - t.Fatalf("expected legacy topic to be migrated, stat err = %v", err) - } -} - -func newLegacyProjectStore(t *testing.T) (*FileStore, string) { - t.Helper() - baseDir := t.TempDir() - workspaceRoot := "/workspace/project" - return NewFileStore(baseDir, workspaceRoot), projectMemoDirectory(baseDir, workspaceRoot) -} diff --git a/internal/provider/anthropic/driver.go b/internal/provider/anthropic/driver.go index b04e1a3c..a5ec554e 100644 --- a/internal/provider/anthropic/driver.go +++ b/internal/provider/anthropic/driver.go @@ -25,7 +25,10 @@ func Driver() provider.DriverDefinition { return New(cfg) }, Discover: func(ctx context.Context, cfg provider.RuntimeConfig) ([]providertypes.ModelDescriptor, error) { - client := newSDKClient(cfg) + client, err := newSDKClient(cfg) + if err != nil { + return nil, err + } descriptors := make([]providertypes.ModelDescriptor, 0, 64) pager := client.Models.ListAutoPaging(ctx, anthropic.ModelListParams{}) @@ -56,8 +59,11 @@ func Driver() provider.DriverDefinition { } // newSDKClient 构造 Anthropic SDK 客户端,供生成与模型发现链路共享连接配置。 -func newSDKClient(cfg provider.RuntimeConfig) anthropic.Client { - apiKey := strings.TrimSpace(cfg.APIKey) +func newSDKClient(cfg provider.RuntimeConfig) (anthropic.Client, error) { + apiKey, err := cfg.ResolveAPIKeyValue() + if err != nil { + return anthropic.Client{}, err + } httpClient := &http.Client{ Timeout: 90 * time.Second, @@ -69,7 +75,7 @@ func newSDKClient(cfg provider.RuntimeConfig) anthropic.Client { if strings.TrimSpace(cfg.BaseURL) != "" { options = append(options, anthroption.WithBaseURL(strings.TrimSpace(cfg.BaseURL))) } - return anthropic.NewClient(options...) + return anthropic.NewClient(options...), nil } // validateCatalogIdentity 在 SDK 模式下不再限制 endpoint 相关字段。 diff --git a/internal/provider/anthropic/driver_test.go b/internal/provider/anthropic/driver_test.go index 3c5d866c..467336b6 100644 --- a/internal/provider/anthropic/driver_test.go +++ b/internal/provider/anthropic/driver_test.go @@ -15,9 +15,10 @@ func TestDriverBuild(t *testing.T) { driver := Driver() p, err := driver.Build(context.Background(), provider.RuntimeConfig{ - Driver: DriverName, - BaseURL: "https://api.anthropic.com/v1", - APIKey: "test-key", + Driver: DriverName, + BaseURL: "https://api.anthropic.com/v1", + APIKeyEnv: "TEST_ANTHROPIC_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), }) if err != nil { t.Fatalf("Build() error = %v", err) @@ -56,7 +57,8 @@ func TestDriverDiscover(t *testing.T) { models, err := driver.Discover(context.Background(), provider.RuntimeConfig{ Driver: DriverName, BaseURL: server.URL, - APIKey: "test-key", + APIKeyEnv: "TEST_ANTHROPIC_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), DiscoveryEndpointPath: "/models", }) if err != nil { diff --git a/internal/provider/anthropic/provider.go b/internal/provider/anthropic/provider.go index 1d3a4763..ffbb548a 100644 --- a/internal/provider/anthropic/provider.go +++ b/internal/provider/anthropic/provider.go @@ -24,22 +24,35 @@ type toolCallState struct { // Provider 封装 Anthropic messages 协议的请求发送与流式解析。 type Provider struct { - cfg provider.RuntimeConfig - client anthropic.Client + cfg provider.RuntimeConfig +} + +// EstimateInputTokens 基于 Anthropic 最终请求结构做本地输入 token 估算。 +func (p *Provider) EstimateInputTokens( + ctx context.Context, + req providertypes.GenerateRequest, +) (providertypes.BudgetEstimate, error) { + params, err := BuildRequest(ctx, p.cfg, req) + if err != nil { + return providertypes.BudgetEstimate{}, err + } + tokens, err := provider.EstimateSerializedPayloadTokens(params) + if err != nil { + return providertypes.BudgetEstimate{}, err + } + return providertypes.BudgetEstimate{ + EstimatedInputTokens: tokens, + EstimateSource: provider.EstimateSourceLocal, + Accurate: false, + }, nil } // New 创建 Anthropic provider 实例,并初始化官方 SDK 客户端。 func New(cfg provider.RuntimeConfig) (*Provider, error) { - if strings.TrimSpace(cfg.APIKey) == "" { - return nil, errors.New(errorPrefix + "api key is empty") + if strings.TrimSpace(cfg.APIKeyEnv) == "" { + return nil, errors.New(errorPrefix + "api_key_env is empty") } - - client := newSDKClient(cfg) - - return &Provider{ - cfg: cfg, - client: client, - }, nil + return &Provider{cfg: cfg}, nil } // Generate 发起 Anthropic 流式请求,并将 typed stream 转为统一事件。 @@ -49,7 +62,11 @@ func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateReque return err } - streamReader := p.client.Messages.NewStreaming(ctx, params) + client, err := newSDKClient(p.cfg) + if err != nil { + return err + } + streamReader := client.Messages.NewStreaming(ctx, params) defer func() { _ = streamReader.Close() }() var ( diff --git a/internal/provider/anthropic/provider_test.go b/internal/provider/anthropic/provider_test.go index 6fb7b06b..49c36ac9 100644 --- a/internal/provider/anthropic/provider_test.go +++ b/internal/provider/anthropic/provider_test.go @@ -38,10 +38,11 @@ func TestProviderGenerate(t *testing.T) { defer server.Close() p, err := New(provider.RuntimeConfig{ - Driver: provider.DriverAnthropic, - BaseURL: server.URL, - DefaultModel: "claude-3-7-sonnet", - APIKey: "test-key", + Driver: provider.DriverAnthropic, + BaseURL: server.URL, + DefaultModel: "claude-3-7-sonnet", + APIKeyEnv: "ANTHROPIC_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), }) if err != nil { t.Fatalf("New() error = %v", err) @@ -105,7 +106,8 @@ func TestNewAcceptsCustomChatEndpointPath(t *testing.T) { Driver: provider.DriverAnthropic, BaseURL: "https://api.anthropic.com/v1", DefaultModel: "claude-3-7-sonnet", - APIKey: "test-key", + APIKeyEnv: "ANTHROPIC_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), ChatEndpointPath: "/custom/messages", }) if err != nil { diff --git a/internal/provider/catalog/service_test.go b/internal/provider/catalog/service_test.go index 07ebe084..d8a89351 100644 --- a/internal/provider/catalog/service_test.go +++ b/internal/provider/catalog/service_test.go @@ -186,6 +186,9 @@ func TestListProviderModelsReturnsDiscoveryErrorOnCacheMiss(t *testing.T) { t.Setenv(testAPIKeyEnv, "") service := NewService("", newRegistry(t, "openaicompat", func(ctx context.Context, cfg provider.RuntimeConfig) ([]providertypes.ModelDescriptor, error) { + if _, err := cfg.ResolveAPIKeyValue(); err != nil { + return nil, err + } return nil, nil }), newMemoryStore()) @@ -433,6 +436,9 @@ func TestDiscoverAndPersistFailurePaths(t *testing.T) { t.Run("resolve provider config failure", func(t *testing.T) { service := NewService("", newRegistry(t, openaicompat.DriverName, func(ctx context.Context, cfg provider.RuntimeConfig) ([]providertypes.ModelDescriptor, error) { + if _, err := cfg.ResolveAPIKeyValue(); err != nil { + return nil, err + } return nil, nil }), newMemoryStore()) @@ -701,6 +707,17 @@ func containsModelDescriptorID(models []providertypes.ModelDescriptor, modelID s type catalogTestProvider struct{} +func (catalogTestProvider) EstimateInputTokens( + ctx context.Context, + req providertypes.GenerateRequest, +) (providertypes.BudgetEstimate, error) { + _ = ctx + _ = req + return providertypes.BudgetEstimate{ + EstimateSource: provider.EstimateSourceLocal, + }, nil +} + func (catalogTestProvider) Generate(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error { return nil } diff --git a/internal/provider/conformance/conformance_test.go b/internal/provider/conformance/conformance_test.go index 4c065525..bf4235d9 100644 --- a/internal/provider/conformance/conformance_test.go +++ b/internal/provider/conformance/conformance_test.go @@ -37,7 +37,8 @@ func TestGenerateContractAcrossDrivers(t *testing.T) { Driver: provider.DriverOpenAICompat, BaseURL: baseURL, DefaultModel: "gpt-4.1", - APIKey: "test-key", + APIKeyEnv: "OPENAI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), ChatEndpointPath: "/chat/completions", } }, @@ -59,7 +60,8 @@ func TestGenerateContractAcrossDrivers(t *testing.T) { Driver: provider.DriverGemini, BaseURL: baseURL, DefaultModel: "gemini-2.5-flash", - APIKey: "test-key", + APIKeyEnv: "GEMINI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), ChatEndpointPath: "/models", } }, @@ -79,7 +81,8 @@ func TestGenerateContractAcrossDrivers(t *testing.T) { Driver: provider.DriverAnthropic, BaseURL: baseURL, DefaultModel: "claude-3-7-sonnet", - APIKey: "test-key", + APIKeyEnv: "ANTHROPIC_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), ChatEndpointPath: "/messages", } }, @@ -174,7 +177,8 @@ func TestDiscoverContractAcrossDrivers(t *testing.T) { Name: "openai", Driver: provider.DriverOpenAICompat, BaseURL: baseURL, - APIKey: "test-key", + APIKeyEnv: "OPENAI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), DiscoveryEndpointPath: "/models", } }, @@ -190,7 +194,8 @@ func TestDiscoverContractAcrossDrivers(t *testing.T) { Name: "gemini", Driver: provider.DriverGemini, BaseURL: baseURL, - APIKey: "test-key", + APIKeyEnv: "GEMINI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), DiscoveryEndpointPath: "/models", } }, @@ -206,7 +211,8 @@ func TestDiscoverContractAcrossDrivers(t *testing.T) { Name: "anthropic", Driver: provider.DriverAnthropic, BaseURL: baseURL, - APIKey: "test-key", + APIKeyEnv: "ANTHROPIC_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), DiscoveryEndpointPath: "/models", } }, @@ -258,7 +264,8 @@ func TestGenerateErrorClassificationAcrossDrivers(t *testing.T) { Driver: provider.DriverOpenAICompat, BaseURL: baseURL, DefaultModel: "gpt-4.1", - APIKey: "test-key", + APIKeyEnv: "OPENAI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), ChatEndpointPath: "/chat/completions", } }, @@ -273,7 +280,8 @@ func TestGenerateErrorClassificationAcrossDrivers(t *testing.T) { Driver: provider.DriverGemini, BaseURL: baseURL, DefaultModel: "gemini-2.5-flash", - APIKey: "test-key", + APIKeyEnv: "GEMINI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), ChatEndpointPath: "/models", } }, @@ -288,7 +296,8 @@ func TestGenerateErrorClassificationAcrossDrivers(t *testing.T) { Driver: provider.DriverAnthropic, BaseURL: baseURL, DefaultModel: "claude-3-7-sonnet", - APIKey: "test-key", + APIKeyEnv: "ANTHROPIC_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), ChatEndpointPath: "/messages", } }, diff --git a/internal/provider/contracts.go b/internal/provider/contracts.go index e5d9ef1f..790bc2b3 100644 --- a/internal/provider/contracts.go +++ b/internal/provider/contracts.go @@ -2,18 +2,26 @@ package provider import ( "context" + "errors" + "fmt" + "os" + "strings" providertypes "neo-code/internal/provider/types" "neo-code/internal/session" ) +// APIKeyResolver 定义 provider 在真正发请求前解析 API Key 的能力。 +type APIKeyResolver func(envName string) (string, error) + // RuntimeConfig 表示 provider 构建与模型发现使用的最小运行时输入。 type RuntimeConfig struct { Name string Driver string BaseURL string DefaultModel string - APIKey string + APIKeyEnv string + APIKeyResolver APIKeyResolver SessionAssetPolicy session.AssetPolicy RequestAssetBudget RequestAssetBudget ChatAPIMode string @@ -21,8 +29,49 @@ type RuntimeConfig struct { DiscoveryEndpointPath string } +// ResolveAPIKeyValue 在 provider 即将发起请求前解析当前配置引用的 API Key。 +func (c RuntimeConfig) ResolveAPIKeyValue() (string, error) { + envName := strings.TrimSpace(c.APIKeyEnv) + if envName == "" { + if strings.TrimSpace(c.Name) == "" { + return "", errors.New("provider runtime config: api_key_env is empty") + } + return "", fmt.Errorf("provider runtime config: provider %q api_key_env is empty", strings.TrimSpace(c.Name)) + } + + if c.APIKeyResolver != nil { + value, err := c.APIKeyResolver(envName) + if err != nil { + return "", err + } + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return "", fmt.Errorf("provider runtime config: environment variable %s is empty", envName) + } + return trimmed, nil + } + + value := strings.TrimSpace(os.Getenv(envName)) + if value == "" { + return "", fmt.Errorf("provider runtime config: environment variable %s is empty", envName) + } + return value, nil +} + +// StaticAPIKeyResolver 返回一个仅供测试和受控注入场景使用的固定密钥解析器。 +func StaticAPIKeyResolver(apiKey string) APIKeyResolver { + trimmed := strings.TrimSpace(apiKey) + return func(_ string) (string, error) { + if trimmed == "" { + return "", errors.New("provider runtime config: static api key is empty") + } + return trimmed, nil + } +} + // Provider 定义模型生成能力,通过 channel 推送流式事件给上层消费。 type Provider interface { + EstimateInputTokens(ctx context.Context, req providertypes.GenerateRequest) (providertypes.BudgetEstimate, error) Generate(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error } diff --git a/internal/provider/estimate.go b/internal/provider/estimate.go new file mode 100644 index 00000000..07e0c9d8 --- /dev/null +++ b/internal/provider/estimate.go @@ -0,0 +1,29 @@ +package provider + +import ( + "encoding/json" + "math" +) + +const ( + EstimateSourceNative = "native" + EstimateSourceLocal = "local" + localEstimateSlack = 1.15 +) + +// EstimateSerializedPayloadTokens 基于最终协议载荷的序列化结果估算输入 token 数。 +func EstimateSerializedPayloadTokens(payload any) (int, error) { + encoded, err := json.Marshal(payload) + if err != nil { + return 0, err + } + return EstimateTextTokens(string(encoded)), nil +} + +// EstimateTextTokens 对文本做保守放大的本地 token 估算,供 provider 预算预检复用。 +func EstimateTextTokens(text string) int { + if text == "" { + return 0 + } + return int(math.Ceil(float64(len([]byte(text))) / 4.0 * localEstimateSlack)) +} diff --git a/internal/provider/gemini/driver.go b/internal/provider/gemini/driver.go index bd0272f0..0c627970 100644 --- a/internal/provider/gemini/driver.go +++ b/internal/provider/gemini/driver.go @@ -66,11 +66,15 @@ func Driver() provider.DriverDefinition { // newSDKClient 构造 Gemini SDK 客户端,供生成与模型发现链路共享连接配置。 func newSDKClient(ctx context.Context, cfg provider.RuntimeConfig) (*genai.Client, error) { + apiKey, err := cfg.ResolveAPIKeyValue() + if err != nil { + return nil, err + } httpClient := &http.Client{ Timeout: 90 * time.Second, } clientConfig := &genai.ClientConfig{ - APIKey: strings.TrimSpace(cfg.APIKey), + APIKey: apiKey, Backend: genai.BackendGeminiAPI, HTTPClient: httpClient, } diff --git a/internal/provider/gemini/driver_test.go b/internal/provider/gemini/driver_test.go index bfe59747..f0ebee9b 100644 --- a/internal/provider/gemini/driver_test.go +++ b/internal/provider/gemini/driver_test.go @@ -30,7 +30,8 @@ func TestDriverDiscover(t *testing.T) { models, err := driver.Discover(context.Background(), provider.RuntimeConfig{ Driver: DriverName, BaseURL: server.URL, - APIKey: "test-key", + APIKeyEnv: "TEST_GEMINI_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), DiscoveryEndpointPath: "/models", }) if err != nil { @@ -46,9 +47,10 @@ func TestDriverBuild(t *testing.T) { driver := Driver() p, err := driver.Build(context.Background(), provider.RuntimeConfig{ - Driver: DriverName, - BaseURL: "https://generativelanguage.googleapis.com/v1beta", - APIKey: "test-key", + Driver: DriverName, + BaseURL: "https://generativelanguage.googleapis.com/v1beta", + APIKeyEnv: "TEST_GEMINI_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), }) if err != nil { t.Fatalf("Build() error = %v", err) diff --git a/internal/provider/gemini/provider.go b/internal/provider/gemini/provider.go index c06d08a1..789fd8ca 100644 --- a/internal/provider/gemini/provider.go +++ b/internal/provider/gemini/provider.go @@ -18,26 +18,46 @@ const errorPrefix = "gemini provider: " // Provider 封装 Gemini native 协议的请求发送与流式响应解析。 type Provider struct { - cfg provider.RuntimeConfig - client *genai.Client + cfg provider.RuntimeConfig } -// New 创建 Gemini native provider 实例,并初始化官方 SDK 客户端。 -func New(cfg provider.RuntimeConfig) (*Provider, error) { - if strings.TrimSpace(cfg.APIKey) == "" { - return nil, errors.New(errorPrefix + "api key is empty") +// EstimateInputTokens 基于 Gemini 最终请求结构做本地输入 token 估算。 +func (p *Provider) EstimateInputTokens( + ctx context.Context, + req providertypes.GenerateRequest, +) (providertypes.BudgetEstimate, error) { + model, contents, genConfig, err := BuildRequest(ctx, p.cfg, req) + if err != nil { + return providertypes.BudgetEstimate{}, err + } + payload := struct { + Model string `json:"model"` + Contents []*genai.Content `json:"contents"` + Config *genai.GenerateContentConfig `json:"config,omitempty"` + }{ + Model: model, + Contents: contents, + Config: genConfig, } - client, err := newSDKClient(context.Background(), cfg) + tokens, err := provider.EstimateSerializedPayloadTokens(payload) if err != nil { - return nil, err + return providertypes.BudgetEstimate{}, err } - - return &Provider{ - cfg: cfg, - client: client, + return providertypes.BudgetEstimate{ + EstimatedInputTokens: tokens, + EstimateSource: provider.EstimateSourceLocal, + Accurate: false, }, nil } +// New 创建 Gemini native provider 实例,并初始化官方 SDK 客户端。 +func New(cfg provider.RuntimeConfig) (*Provider, error) { + if strings.TrimSpace(cfg.APIKeyEnv) == "" { + return nil, errors.New(errorPrefix + "api_key_env is empty") + } + return &Provider{cfg: cfg}, nil +} + // Generate 发起 Gemini 流式请求,并将 SDK chunk 转为统一流式事件。 func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error { model, contents, config, err := BuildRequest(ctx, p.cfg, req) @@ -48,6 +68,10 @@ func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateReque if normalizedModel == "" { return errors.New(errorPrefix + "model is empty") } + client, err := newSDKClient(ctx, p.cfg) + if err != nil { + return err + } var ( finishReason string @@ -55,7 +79,7 @@ func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateReque hasPayload bool callSeq int ) - for chunk, streamErr := range p.client.Models.GenerateContentStream(ctx, normalizedModel, contents, config) { + for chunk, streamErr := range client.Models.GenerateContentStream(ctx, normalizedModel, contents, config) { if streamErr != nil { if ctxErr := ctx.Err(); ctxErr != nil { return ctxErr diff --git a/internal/provider/gemini/provider_test.go b/internal/provider/gemini/provider_test.go index b41c4057..37eaedd9 100644 --- a/internal/provider/gemini/provider_test.go +++ b/internal/provider/gemini/provider_test.go @@ -29,10 +29,11 @@ func TestProviderGenerate(t *testing.T) { defer server.Close() p, err := New(provider.RuntimeConfig{ - Driver: provider.DriverGemini, - BaseURL: server.URL, - DefaultModel: "gemini-2.5-flash", - APIKey: "test-key", + Driver: provider.DriverGemini, + BaseURL: server.URL, + DefaultModel: "gemini-2.5-flash", + APIKeyEnv: "GEMINI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), }) if err != nil { t.Fatalf("New() error = %v", err) @@ -96,7 +97,8 @@ func TestNewAcceptsCustomChatEndpointPath(t *testing.T) { Driver: provider.DriverGemini, BaseURL: "https://generativelanguage.googleapis.com/v1beta", DefaultModel: "gemini-2.5-flash", - APIKey: "test-key", + APIKeyEnv: "GEMINI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), ChatEndpointPath: "/custom/models", }) if err != nil { @@ -111,10 +113,11 @@ func TestBuildRequestSupportsImageParts(t *testing.T) { t.Parallel() cfg := provider.RuntimeConfig{ - Driver: provider.DriverGemini, - BaseURL: "https://generativelanguage.googleapis.com/v1beta", - DefaultModel: "gemini-2.5-flash", - APIKey: "test-key", + Driver: provider.DriverGemini, + BaseURL: "https://generativelanguage.googleapis.com/v1beta", + DefaultModel: "gemini-2.5-flash", + APIKeyEnv: "GEMINI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), } model, contents, requestConfig, err := BuildRequest(context.Background(), cfg, providertypes.GenerateRequest{ Messages: []providertypes.Message{ @@ -164,10 +167,11 @@ func TestBuildRequestRejectsSessionAssetWithoutReader(t *testing.T) { t.Parallel() cfg := provider.RuntimeConfig{ - Driver: provider.DriverGemini, - BaseURL: "https://generativelanguage.googleapis.com/v1beta", - DefaultModel: "gemini-2.5-flash", - APIKey: "test-key", + Driver: provider.DriverGemini, + BaseURL: "https://generativelanguage.googleapis.com/v1beta", + DefaultModel: "gemini-2.5-flash", + APIKeyEnv: "GEMINI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), } _, _, _, err := BuildRequest(context.Background(), cfg, providertypes.GenerateRequest{ Messages: []providertypes.Message{ diff --git a/internal/provider/generate_test.go b/internal/provider/generate_test.go index 73649d12..0658e187 100644 --- a/internal/provider/generate_test.go +++ b/internal/provider/generate_test.go @@ -15,6 +15,18 @@ type stubTextGenProvider struct { generate func(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error } +func (s *stubTextGenProvider) EstimateInputTokens( + ctx context.Context, + req providertypes.GenerateRequest, +) (providertypes.BudgetEstimate, error) { + _ = ctx + return providertypes.BudgetEstimate{ + EstimatedInputTokens: provider.EstimateTextTokens(req.SystemPrompt + renderEstimateMessages(req.Messages)), + EstimateSource: provider.EstimateSourceLocal, + Accurate: false, + }, nil +} + func (s *stubTextGenProvider) Generate( ctx context.Context, req providertypes.GenerateRequest, @@ -27,6 +39,14 @@ func (s *stubTextGenProvider) Generate( return nil } +func renderEstimateMessages(messages []providertypes.Message) string { + var builder strings.Builder + for _, message := range messages { + builder.WriteString(provider.RenderMessageText(message.Parts)) + } + return builder.String() +} + func TestGenerateTextSuccess(t *testing.T) { providerStub := &stubTextGenProvider{ generate: func(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error { diff --git a/internal/provider/openaicompat/common_test.go b/internal/provider/openaicompat/common_test.go index 71a1704a..350e726a 100644 --- a/internal/provider/openaicompat/common_test.go +++ b/internal/provider/openaicompat/common_test.go @@ -12,8 +12,9 @@ func TestValidateRuntimeConfig(t *testing.T) { t.Run("empty base url", func(t *testing.T) { t.Parallel() err := validateRuntimeConfig(provider.RuntimeConfig{ - BaseURL: "", - APIKey: "test-key", + BaseURL: "", + APIKeyEnv: "TEST_OPENAI_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), }) if err == nil || err.Error() != errorPrefix+"base url is empty" { t.Fatalf("expected base url error, got %v", err) @@ -23,10 +24,10 @@ func TestValidateRuntimeConfig(t *testing.T) { t.Run("empty api key", func(t *testing.T) { t.Parallel() err := validateRuntimeConfig(provider.RuntimeConfig{ - BaseURL: "https://api.example.com/v1", - APIKey: " ", + BaseURL: "https://api.example.com/v1", + APIKeyEnv: " ", }) - if err == nil || err.Error() != errorPrefix+"api key is empty" { + if err == nil || err.Error() != errorPrefix+"api_key_env is empty" { t.Fatalf("expected api key error, got %v", err) } }) @@ -34,8 +35,9 @@ func TestValidateRuntimeConfig(t *testing.T) { t.Run("valid config", func(t *testing.T) { t.Parallel() err := validateRuntimeConfig(provider.RuntimeConfig{ - BaseURL: " https://api.example.com/v1 ", - APIKey: " test-key ", + BaseURL: " https://api.example.com/v1 ", + APIKeyEnv: "TEST_OPENAI_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver(" test-key "), }) if err != nil { t.Fatalf("expected valid config, got %v", err) diff --git a/internal/provider/openaicompat/discovery_http.go b/internal/provider/openaicompat/discovery_http.go index 295c024b..191f236f 100644 --- a/internal/provider/openaicompat/discovery_http.go +++ b/internal/provider/openaicompat/discovery_http.go @@ -46,11 +46,16 @@ func RequestConfigFromRuntime(cfg provider.RuntimeConfig) (RequestConfig, error) return RequestConfig{}, provider.NewDiscoveryConfigError(err.Error()) } + apiKey, err := cfg.ResolveAPIKeyValue() + if err != nil { + return RequestConfig{}, err + } + return RequestConfig{ Driver: cfg.Driver, BaseURL: cfg.BaseURL, EndpointPath: discoveryEndpointPath, - APIKey: cfg.APIKey, + APIKey: apiKey, }, nil } diff --git a/internal/provider/openaicompat/discovery_http_test.go b/internal/provider/openaicompat/discovery_http_test.go index 4dc985c2..510fbd90 100644 --- a/internal/provider/openaicompat/discovery_http_test.go +++ b/internal/provider/openaicompat/discovery_http_test.go @@ -362,7 +362,8 @@ func TestRequestConfigFromRuntime(t *testing.T) { cfg, err := RequestConfigFromRuntime(provider.RuntimeConfig{ Driver: provider.DriverOpenAICompat, BaseURL: "https://api.openai.com/v1", - APIKey: "test-key", + APIKeyEnv: "OPENAI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), DiscoveryEndpointPath: "/models", }) if err != nil { @@ -399,8 +400,10 @@ func TestRequestConfigFromRuntimeDefaultsEmptyDiscoveryPath(t *testing.T) { t.Parallel() cfg, err := RequestConfigFromRuntime(provider.RuntimeConfig{ - Driver: provider.DriverOpenAICompat, - BaseURL: "https://api.example.com/v1", + Driver: provider.DriverOpenAICompat, + BaseURL: "https://api.example.com/v1", + APIKeyEnv: "OPENAI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), }) if err != nil { t.Fatalf("RequestConfigFromRuntime() error = %v", err) diff --git a/internal/provider/openaicompat/driver_internal_test.go b/internal/provider/openaicompat/driver_internal_test.go index 8084192a..da6ab742 100644 --- a/internal/provider/openaicompat/driver_internal_test.go +++ b/internal/provider/openaicompat/driver_internal_test.go @@ -31,7 +31,8 @@ func TestDriverClosuresAndSupportedProtocol(t *testing.T) { Driver: DriverName, BaseURL: server.URL, DefaultModel: "gpt-4.1", - APIKey: "test-key", + APIKeyEnv: "OPENAI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), DiscoveryEndpointPath: "/models", } driver := Driver() @@ -101,7 +102,8 @@ func TestFetchModelsAndGenerateExtraBranches(t *testing.T) { Name: DriverName, Driver: DriverName, BaseURL: "://bad", - APIKey: "test-key", + APIKeyEnv: "OPENAI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), DiscoveryEndpointPath: "/models", }, client: &http.Client{}, @@ -116,7 +118,8 @@ func TestFetchModelsAndGenerateExtraBranches(t *testing.T) { Name: DriverName, Driver: DriverName, BaseURL: "https://api.example.com/v1", - APIKey: "test-key", + APIKeyEnv: "OPENAI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), DiscoveryEndpointPath: "https://api.example.com/models", }, client: &http.Client{}, @@ -139,7 +142,8 @@ func TestFetchModelsAndGenerateExtraBranches(t *testing.T) { Name: DriverName, Driver: DriverName, BaseURL: server.URL, - APIKey: "test-key", + APIKeyEnv: "OPENAI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), DiscoveryEndpointPath: "/models", }, client: server.Client(), @@ -154,11 +158,12 @@ func TestFetchModelsAndGenerateExtraBranches(t *testing.T) { } p2, err := New(provider.RuntimeConfig{ - Name: DriverName, - Driver: provider.DriverAnthropic, - BaseURL: "https://api.example.com/v1", - DefaultModel: "gpt-4.1", - APIKey: "test-key", + Name: DriverName, + Driver: provider.DriverAnthropic, + BaseURL: "https://api.example.com/v1", + DefaultModel: "gpt-4.1", + APIKeyEnv: "OPENAI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), }) if err != nil { t.Fatalf("New() error = %v", err) diff --git a/internal/provider/openaicompat/generate_sdk.go b/internal/provider/openaicompat/generate_sdk.go index d094e85e..890997d6 100644 --- a/internal/provider/openaicompat/generate_sdk.go +++ b/internal/provider/openaicompat/generate_sdk.go @@ -29,7 +29,10 @@ func (p *Provider) generateSDKChatCompletions( return err } - client := p.newSDKClient() + client, err := p.newSDKClient() + if err != nil { + return err + } params := convertToChatCompletionParams(payload) stream := client.Chat.Completions.NewStreaming(ctx, params) @@ -249,7 +252,10 @@ func (p *Provider) generateChatCompletionsWithCompatibleStream( return fmt.Errorf("%sinvalid chat endpoint configuration: %w", errorPrefix, err) } - client := p.newSDKClient() + client, err := p.newSDKClient() + if err != nil { + return err + } var resp *http.Response err = client.Post( ctx, @@ -286,7 +292,10 @@ func (p *Provider) generateSDKResponses( return err } - client := p.newSDKClient() + client, err := p.newSDKClient() + if err != nil { + return err + } var resp *http.Response err = client.Post( ctx, @@ -316,12 +325,16 @@ func wrapSDKRequestError(err error, action string) error { return fmt.Errorf("%s%s: %w", errorPrefix, strings.TrimSpace(action), err) } -func (p *Provider) newSDKClient() openai.Client { +func (p *Provider) newSDKClient() (openai.Client, error) { + apiKey, err := p.cfg.ResolveAPIKeyValue() + if err != nil { + return openai.Client{}, err + } return openai.NewClient( option.WithHTTPClient(p.client), - option.WithAPIKey(strings.TrimSpace(p.cfg.APIKey)), + option.WithAPIKey(apiKey), option.WithBaseURL(strings.TrimRight(strings.TrimSpace(p.cfg.BaseURL), "/")), - ) + ), nil } func resolveChatEndpoint(cfg provider.RuntimeConfig) (string, error) { diff --git a/internal/provider/openaicompat/openaicompat_test.go b/internal/provider/openaicompat/openaicompat_test.go index cb839d51..34a00001 100644 --- a/internal/provider/openaicompat/openaicompat_test.go +++ b/internal/provider/openaicompat/openaicompat_test.go @@ -53,12 +53,13 @@ func TestNewValidationErrors(t *testing.T) { t.Run("empty api key returns error", func(t *testing.T) { t.Parallel() cfg := resolvedConfig("", "") - cfg.APIKey = "" + cfg.APIKeyEnv = "" + cfg.APIKeyResolver = nil _, err := New(cfg) if err == nil { t.Fatal("expected error for empty api key") } - if !strings.Contains(err.Error(), "api key is empty") { + if !strings.Contains(err.Error(), "api_key_env is empty") { t.Fatalf("expected api key error, got: %v", err) } }) @@ -66,7 +67,8 @@ func TestNewValidationErrors(t *testing.T) { t.Run("whitespace-only api key returns error", func(t *testing.T) { t.Parallel() cfg := resolvedConfig("", "") - cfg.APIKey = " " + cfg.APIKeyEnv = " " + cfg.APIKeyResolver = nil _, err := New(cfg) if err == nil { t.Fatal("expected error for whitespace-only api key") @@ -76,11 +78,12 @@ func TestNewValidationErrors(t *testing.T) { t.Run("invalid config validate fails", func(t *testing.T) { t.Parallel() cfg := provider.RuntimeConfig{ - Name: DriverName, - Driver: DriverName, - BaseURL: "", - DefaultModel: config.OpenAIDefaultModel, - APIKey: "test-key", + Name: DriverName, + Driver: DriverName, + BaseURL: "", + DefaultModel: config.OpenAIDefaultModel, + APIKeyEnv: "OPENAI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), } _, err := New(cfg) if err == nil { @@ -391,11 +394,12 @@ func TestBuildRequest_EmptyModelReturnsError(t *testing.T) { // to test BuildRequest's own empty-model check. p := &Provider{ cfg: provider.RuntimeConfig{ - Name: DriverName, - Driver: DriverName, - BaseURL: config.OpenAIDefaultBaseURL, - DefaultModel: "", - APIKey: "test-key", + Name: DriverName, + Driver: DriverName, + BaseURL: config.OpenAIDefaultBaseURL, + DefaultModel: "", + APIKeyEnv: "OPENAI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), }, client: &http.Client{}, } @@ -619,11 +623,12 @@ func resolvedConfig(baseURL, model string) provider.RuntimeConfig { model = config.OpenAIDefaultModel } return provider.RuntimeConfig{ - Name: DriverName, - Driver: DriverName, - BaseURL: baseURL, - DefaultModel: model, - APIKey: "test-key", + Name: DriverName, + Driver: DriverName, + BaseURL: baseURL, + DefaultModel: model, + APIKeyEnv: "OPENAI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), } } diff --git a/internal/provider/openaicompat/provider.go b/internal/provider/openaicompat/provider.go index 172fbbb9..2dea2883 100644 --- a/internal/provider/openaicompat/provider.go +++ b/internal/provider/openaicompat/provider.go @@ -9,6 +9,8 @@ import ( "time" "neo-code/internal/provider" + "neo-code/internal/provider/openaicompat/chatcompletions" + "neo-code/internal/provider/openaicompat/responses" providertypes "neo-code/internal/provider/types" ) @@ -26,8 +28,8 @@ func validateRuntimeConfig(cfg provider.RuntimeConfig) error { if strings.TrimSpace(cfg.BaseURL) == "" { return errors.New(errorPrefix + "base url is empty") } - if strings.TrimSpace(cfg.APIKey) == "" { - return errors.New(errorPrefix + "api key is empty") + if strings.TrimSpace(cfg.APIKeyEnv) == "" { + return errors.New(errorPrefix + "api_key_env is empty") } return nil } @@ -38,6 +40,45 @@ type Provider struct { client *http.Client } +// EstimateInputTokens 基于 OpenAI-compatible 最终请求结构做本地输入 token 估算。 +func (p *Provider) EstimateInputTokens( + ctx context.Context, + req providertypes.GenerateRequest, +) (providertypes.BudgetEstimate, error) { + mode, err := resolveExecutionMode(p.cfg) + if err != nil { + return providertypes.BudgetEstimate{}, err + } + + var tokens int + switch mode { + case executionModeCompletions: + payload, buildErr := chatcompletions.BuildRequest(ctx, p.cfg, req) + if buildErr != nil { + return providertypes.BudgetEstimate{}, buildErr + } + tokens, err = provider.EstimateSerializedPayloadTokens(payload) + case executionModeResponses: + payload, buildErr := responses.BuildRequest(ctx, p.cfg, req) + if buildErr != nil { + return providertypes.BudgetEstimate{}, buildErr + } + tokens, err = provider.EstimateSerializedPayloadTokens(payload) + default: + return providertypes.BudgetEstimate{}, provider.NewDiscoveryConfigError( + fmt.Sprintf("openaicompat provider: driver %q resolved unsupported execution mode %q", p.cfg.Driver, mode), + ) + } + if err != nil { + return providertypes.BudgetEstimate{}, err + } + return providertypes.BudgetEstimate{ + EstimatedInputTokens: tokens, + EstimateSource: provider.EstimateSourceLocal, + Accurate: false, + }, nil +} + // buildOptions 控制 provider 构建时的可选注入项。 type buildOptions struct { transport http.RoundTripper diff --git a/internal/provider/registry_test.go b/internal/provider/registry_test.go index 08d83bed..446b03eb 100644 --- a/internal/provider/registry_test.go +++ b/internal/provider/registry_test.go @@ -13,6 +13,17 @@ import ( type stubProvider struct{} +func (stubProvider) EstimateInputTokens( + ctx context.Context, + req providertypes.GenerateRequest, +) (providertypes.BudgetEstimate, error) { + _ = ctx + _ = req + return providertypes.BudgetEstimate{ + EstimateSource: provider.EstimateSourceLocal, + }, nil +} + func (stubProvider) Generate(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error { return nil } @@ -44,11 +55,12 @@ func TestRegistryBuildsRegisteredDriverCaseInsensitively(t *testing.T) { registry := newTestRegistry(t) got, err := registry.Build(context.Background(), provider.RuntimeConfig{ - Name: "openai-main", - Driver: "OPENAICOMPAT", - BaseURL: config.OpenAIDefaultBaseURL, - DefaultModel: config.OpenAIDefaultModel, - APIKey: "test-key", + Name: "openai-main", + Driver: "OPENAICOMPAT", + BaseURL: config.OpenAIDefaultBaseURL, + DefaultModel: config.OpenAIDefaultModel, + APIKeyEnv: "TEST_OPENAI_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), }) if err != nil { t.Fatalf("Build() error = %v", err) diff --git a/internal/provider/types/usage.go b/internal/provider/types/usage.go index 3a433baf..4250c6d1 100644 --- a/internal/provider/types/usage.go +++ b/internal/provider/types/usage.go @@ -6,3 +6,10 @@ type Usage struct { OutputTokens int `json:"output_tokens"` TotalTokens int `json:"total_tokens"` } + +// BudgetEstimate 描述 provider 对冻结请求输入 token 的估算结果。 +type BudgetEstimate struct { + EstimatedInputTokens int `json:"estimated_input_tokens"` + EstimateSource string `json:"estimate_source"` + Accurate bool `json:"accurate"` +} diff --git a/internal/runtime/budget_models.go b/internal/runtime/budget_models.go new file mode 100644 index 00000000..58fd03ac --- /dev/null +++ b/internal/runtime/budget_models.go @@ -0,0 +1,116 @@ +package runtime + +import ( + "time" + + "neo-code/internal/config" + "neo-code/internal/provider" + providertypes "neo-code/internal/provider/types" + "neo-code/internal/runtime/controlplane" +) + +// TurnBudgetSnapshot 冻结单次预算尝试需要的 request、provider 配置与预算事实上下文。 +type TurnBudgetSnapshot struct { + ID controlplane.TurnBudgetID + Config config.Config + ProviderConfig provider.RuntimeConfig + Model string + Workdir string + ToolTimeout time.Duration + PromptBudget int + BudgetSource string + CompactCount int + NoProgressStreakLimit int + RepeatCycleStreakLimit int + Request providertypes.GenerateRequest +} + +// TurnBudgetUsageObservation 描述 provider 调用完成后可被证明的 usage 观察结果。 +type TurnBudgetUsageObservation struct { + ID controlplane.TurnBudgetID + InputTokens int + OutputTokens int + InputObserved bool + OutputObserved bool +} + +// turnProviderOutput 汇总一次 provider 调用返回的 assistant 消息与预算 usage observation。 +type turnProviderOutput struct { + assistant providertypes.Message + usageObservation TurnBudgetUsageObservation +} + +// ledgerReconcileResult 描述单轮 usage 调和后写入账本与事件的结果。 +type ledgerReconcileResult struct { + inputTokens int + inputSource string + outputTokens int + outputSource string + hasUnknownUsage bool +} + +// newTurnBudgetSnapshot 构造本次发送尝试的冻结预算快照。 +func newTurnBudgetSnapshot( + attemptSeq int, + cfg config.Config, + providerConfig provider.RuntimeConfig, + model string, + workdir string, + toolTimeout time.Duration, + promptBudget int, + budgetSource string, + compactCount int, + noProgressStreakLimit int, + repeatCycleStreakLimit int, + request providertypes.GenerateRequest, +) TurnBudgetSnapshot { + if attemptSeq <= 0 { + attemptSeq = 1 + } + return TurnBudgetSnapshot{ + ID: controlplane.TurnBudgetID{ + AttemptSeq: attemptSeq, + RequestHash: computeRequestHash(request), + }, + Config: cfg, + ProviderConfig: providerConfig, + Model: model, + Workdir: workdir, + ToolTimeout: toolTimeout, + PromptBudget: promptBudget, + BudgetSource: budgetSource, + CompactCount: compactCount, + NoProgressStreakLimit: noProgressStreakLimit, + RepeatCycleStreakLimit: repeatCycleStreakLimit, + Request: request, + } +} + +// newTurnBudgetEstimate 将 provider signal 包装为 runtime 预算主干估算对象。 +func newTurnBudgetEstimate( + id controlplane.TurnBudgetID, + estimate providertypes.BudgetEstimate, +) controlplane.TurnBudgetEstimate { + return controlplane.TurnBudgetEstimate{ + ID: id, + EstimatedInputTokens: estimate.EstimatedInputTokens, + EstimateSource: estimate.EstimateSource, + Accurate: estimate.Accurate, + } +} + +// newTurnBudgetUsageObservation 构造单次 provider 调用对应的 usage observation。 +func newTurnBudgetUsageObservation( + id controlplane.TurnBudgetID, + inputTokens int, + outputTokens int, + observed bool, +) TurnBudgetUsageObservation { + return TurnBudgetUsageObservation{ + ID: id, + InputTokens: inputTokens, + OutputTokens: outputTokens, + InputObserved: observed, + OutputObserved: observed, + } +} diff --git a/internal/runtime/compact.go b/internal/runtime/compact.go index 2de89682..afb3bc36 100644 --- a/internal/runtime/compact.go +++ b/internal/runtime/compact.go @@ -122,6 +122,7 @@ func (s *Service) runCompactForSession( originalTaskState := session.TaskState.Clone() originalTokenInputTotal := session.TokenInputTotal originalTokenOutputTotal := session.TokenOutputTotal + originalHasUnknownUsage := session.HasUnknownUsage originalUpdatedAt := session.UpdatedAt s.emit(ctx, EventCompactStart, runID, session.ID, string(mode)) @@ -143,12 +144,14 @@ func (s *Service) runCompactForSession( session.TaskState = result.TaskState.Clone() session.TokenInputTotal = 0 session.TokenOutputTotal = 0 + session.HasUnknownUsage = false session.UpdatedAt = time.Now() if err := s.sessionStore.ReplaceTranscript(ctx, replaceTranscriptInputFromSession(session)); err != nil { session.Messages = originalMessages session.TaskState = originalTaskState session.TokenInputTotal = originalTokenInputTotal session.TokenOutputTotal = originalTokenOutputTotal + session.HasUnknownUsage = originalHasUnknownUsage session.UpdatedAt = originalUpdatedAt return failCompact(err) } diff --git a/internal/runtime/compact_generator.go b/internal/runtime/compact_generator.go index 2695a99a..9ef19984 100644 --- a/internal/runtime/compact_generator.go +++ b/internal/runtime/compact_generator.go @@ -60,9 +60,7 @@ func (g *compactSummaryGenerator) Generate( if g.providerFactory == nil { return contextcompact.SummaryOutput{}, errors.New("runtime: compact summary generator provider factory is nil") } - if strings.TrimSpace(g.providerConfig.Driver) == "" || - strings.TrimSpace(g.providerConfig.BaseURL) == "" || - strings.TrimSpace(g.providerConfig.APIKey) == "" { + if strings.TrimSpace(g.providerConfig.Driver) == "" { return contextcompact.SummaryOutput{}, errors.New("runtime: compact summary generator provider config is incomplete") } diff --git a/internal/runtime/controlplane/budget.go b/internal/runtime/controlplane/budget.go new file mode 100644 index 00000000..d777b929 --- /dev/null +++ b/internal/runtime/controlplane/budget.go @@ -0,0 +1,61 @@ +package controlplane + +// TurnBudgetAction 表示预算控制面对单次发送尝试做出的唯一动作。 +type TurnBudgetAction string + +const ( + TurnBudgetActionAllow TurnBudgetAction = "allow" + TurnBudgetActionCompact TurnBudgetAction = "compact" + TurnBudgetActionStop TurnBudgetAction = "stop" +) + +// TurnBudgetID 标识一次冻结预算尝试,避免 estimate、decision 与 usage observation 串用。 +type TurnBudgetID struct { + AttemptSeq int `json:"attempt_seq"` + RequestHash string `json:"request_hash"` +} + +// TurnBudgetEstimate 描述 runtime 对冻结请求输入 token 的主干估算事实。 +type TurnBudgetEstimate struct { + ID TurnBudgetID `json:"id"` + EstimatedInputTokens int `json:"estimated_input_tokens"` + EstimateSource string `json:"estimate_source,omitempty"` + Accurate bool `json:"accurate"` +} + +// TurnBudgetDecision 描述冻结请求在当前预算事实下的决策结果。 +type TurnBudgetDecision struct { + ID TurnBudgetID `json:"id"` + Action TurnBudgetAction `json:"action"` + Reason string `json:"reason,omitempty"` + EstimatedInputTokens int `json:"estimated_input_tokens"` + PromptBudget int `json:"prompt_budget"` + EstimateSource string `json:"estimate_source,omitempty"` +} + +// DecideTurnBudget 根据输入预算事实输出 allow、compact 或 stop 三种动作。 +func DecideTurnBudget( + estimate TurnBudgetEstimate, + promptBudget int, + compactCount int, +) TurnBudgetDecision { + decision := TurnBudgetDecision{ + ID: estimate.ID, + EstimatedInputTokens: estimate.EstimatedInputTokens, + PromptBudget: promptBudget, + EstimateSource: estimate.EstimateSource, + } + if estimate.EstimatedInputTokens <= promptBudget { + decision.Action = TurnBudgetActionAllow + decision.Reason = "within_budget" + return decision + } + if compactCount == 0 { + decision.Action = TurnBudgetActionCompact + decision.Reason = "exceeds_budget_first_time" + return decision + } + decision.Action = TurnBudgetActionStop + decision.Reason = "exceeds_budget_after_compact" + return decision +} diff --git a/internal/runtime/controlplane/decider.go b/internal/runtime/controlplane/decider.go index 644faedf..cd72bec8 100644 --- a/internal/runtime/controlplane/decider.go +++ b/internal/runtime/controlplane/decider.go @@ -9,6 +9,7 @@ import ( // StopInput 汇总最终 stop 决议所需的信号。 type StopInput struct { UserInterrupted bool + BudgetExceeded bool FatalError error Completed bool } @@ -18,6 +19,9 @@ func DecideStopReason(in StopInput) (StopReason, string) { if in.UserInterrupted { return StopReasonUserInterrupt, "" } + if in.BudgetExceeded { + return StopReasonBudgetExceeded, "" + } if in.FatalError != nil { if errors.Is(in.FatalError, context.Canceled) { return StopReasonUserInterrupt, "" diff --git a/internal/runtime/controlplane/stop_reason.go b/internal/runtime/controlplane/stop_reason.go index 3b8b0c2f..2aa6d9e5 100644 --- a/internal/runtime/controlplane/stop_reason.go +++ b/internal/runtime/controlplane/stop_reason.go @@ -6,6 +6,8 @@ type StopReason string const ( // StopReasonFatalError 表示出现不可恢复错误。 StopReasonFatalError StopReason = "STOP_FATAL_ERROR" + // StopReasonBudgetExceeded 表示预算闭环判定本轮请求无法继续发送。 + StopReasonBudgetExceeded StopReason = "STOP_BUDGET_EXCEEDED" // StopReasonCompleted 表示运行满足完成条件。 StopReasonCompleted StopReason = "STOP_COMPLETED" // StopReasonUserInterrupt 表示运行被用户或上层上下文中断。 diff --git a/internal/runtime/events.go b/internal/runtime/events.go index eb41ad79..befceb31 100644 --- a/internal/runtime/events.go +++ b/internal/runtime/events.go @@ -29,7 +29,13 @@ type PhaseChangedPayload struct { // BudgetCheckedPayload 为预算检查预留负载。 type BudgetCheckedPayload struct { - Note string `json:"note,omitempty"` + AttemptSeq int `json:"attempt_seq"` + RequestHash string `json:"request_hash"` + Action string `json:"action"` + Reason string `json:"reason,omitempty"` + EstimatedInputTokens int `json:"estimated_input_tokens"` + PromptBudget int `json:"prompt_budget"` + EstimateSource string `json:"estimate_source,omitempty"` } // ProgressEvaluatedPayload 汇总 progress 控制面的评估结果。 @@ -45,7 +51,42 @@ type StopReasonDecidedPayload struct { // LedgerReconciledPayload 为账本对账预留负载。 type LedgerReconciledPayload struct { - Note string `json:"note,omitempty"` + AttemptSeq int `json:"attempt_seq"` + RequestHash string `json:"request_hash"` + InputTokens int `json:"input_tokens"` + InputSource string `json:"input_source"` + OutputTokens int `json:"output_tokens"` + OutputSource string `json:"output_source"` + HasUnknownUsage bool `json:"has_unknown_usage"` +} + +// newBudgetCheckedPayload 将预算决策对象展开为对外事件 payload,保持可观测字段稳定。 +func newBudgetCheckedPayload(decision controlplane.TurnBudgetDecision) BudgetCheckedPayload { + return BudgetCheckedPayload{ + AttemptSeq: decision.ID.AttemptSeq, + RequestHash: decision.ID.RequestHash, + Action: string(decision.Action), + Reason: decision.Reason, + EstimatedInputTokens: decision.EstimatedInputTokens, + PromptBudget: decision.PromptBudget, + EstimateSource: decision.EstimateSource, + } +} + +// newLedgerReconciledPayload 将 usage observation 与调和结果拼装为对外事件 payload。 +func newLedgerReconciledPayload( + observation TurnBudgetUsageObservation, + result ledgerReconcileResult, +) LedgerReconciledPayload { + return LedgerReconciledPayload{ + AttemptSeq: observation.ID.AttemptSeq, + RequestHash: observation.ID.RequestHash, + InputTokens: result.inputTokens, + InputSource: result.inputSource, + OutputTokens: result.outputTokens, + OutputSource: result.outputSource, + HasUnknownUsage: result.hasUnknownUsage, + } } // PermissionRequestPayload 描述一次权限请求。 @@ -155,10 +196,14 @@ const ( EventSkillMissing EventType = "skill_missing" // EventPhaseChanged 表示运行 phase 迁移。 EventPhaseChanged EventType = "phase_changed" + // EventBudgetChecked 表示预算控制面对冻结请求完成一次预算决策。 + EventBudgetChecked EventType = "budget_checked" // EventProgressEvaluated 表示 progress 评估完成。 EventProgressEvaluated EventType = "progress_evaluated" // EventStopReasonDecided 表示 stop reason 已决议。 EventStopReasonDecided EventType = "stop_reason_decided" + // EventLedgerReconciled 表示本轮 usage 已按新账本语义完成调和。 + EventLedgerReconciled EventType = "ledger_reconciled" // EventTodoUpdated 表示 todo_write 成功更新。 EventTodoUpdated EventType = "todo_updated" // EventTodoConflict 表示 todo_write 触发冲突类错误。 @@ -175,8 +220,11 @@ const ( // TokenUsagePayload 承载单轮 token 用量统计。 type TokenUsagePayload struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - SessionInputTokens int `json:"session_input_tokens"` - SessionOutputTokens int `json:"session_output_tokens"` + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + InputSource string `json:"input_source,omitempty"` + OutputSource string `json:"output_source,omitempty"` + HasUnknownUsage bool `json:"has_unknown_usage,omitempty"` + SessionInputTokens int `json:"session_input_tokens"` + SessionOutputTokens int `json:"session_output_tokens"` } diff --git a/internal/runtime/events_subagent.go b/internal/runtime/events_subagent.go index c25c9021..c2cb8c61 100644 --- a/internal/runtime/events_subagent.go +++ b/internal/runtime/events_subagent.go @@ -2,12 +2,6 @@ package runtime import "neo-code/internal/subagent" -// EventPermissionRequest 为兼容旧事件名保留,语义等同 EventPermissionRequested。 -const EventPermissionRequest EventType = EventPermissionRequested - -// EventCompactDone 为兼容旧事件名保留,语义等同 EventCompactApplied。 -const EventCompactDone EventType = EventCompactApplied - // SubAgentEventPayload 描述子代理执行生命周期的事件载荷。 type SubAgentEventPayload struct { Role subagent.Role `json:"role"` diff --git a/internal/runtime/input_prepare_test.go b/internal/runtime/input_prepare_test.go index f12a1bdc..46478a9e 100644 --- a/internal/runtime/input_prepare_test.go +++ b/internal/runtime/input_prepare_test.go @@ -193,7 +193,7 @@ func newPrepareTestServiceWithRuntimeConfig( t.Fatalf("load config: %v", err) } - store := agentsession.NewStore(t.TempDir(), workdir) + store := agentsession.NewSQLiteStore(t.TempDir(), workdir) t.Cleanup(func() { _ = store.Close() }) diff --git a/internal/runtime/provider_stream.go b/internal/runtime/provider_stream.go index 1ad5ecb4..b120e4a4 100644 --- a/internal/runtime/provider_stream.go +++ b/internal/runtime/provider_stream.go @@ -14,6 +14,7 @@ type streamGenerateResult struct { message providertypes.Message inputTokens int outputTokens int + usagePresent bool err error } @@ -40,6 +41,7 @@ func generateStreamingMessage( if payload.Usage != nil { outcome.inputTokens = payload.Usage.InputTokens outcome.outputTokens = payload.Usage.OutputTokens + outcome.usagePresent = true } if userOnMessageDone != nil { userOnMessageDone(payload) diff --git a/internal/runtime/run.go b/internal/runtime/run.go index be90441b..681d3d23 100644 --- a/internal/runtime/run.go +++ b/internal/runtime/run.go @@ -26,6 +26,12 @@ var selfHealingReminder = promptasset.NoProgressReminder() var selfHealingRepeatReminder = promptasset.RepeatCycleReminder() +const ( + usageSourceObserved = "observed" + usageSourceEstimated = "estimated" + usageSourceUnknown = "unknown" +) + // computeToolSignature 计算单轮执行的工具签名,用于循环检测。 func computeToolSignature(calls []providertypes.ToolCall) string { if len(calls) == 0 { @@ -122,6 +128,8 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { for turn := 0; ; turn++ { state.turn = turn + state.compactCount = 0 + state.nextAttemptSeq = 1 if err := s.setBaseRunState(ctx, &state, controlplane.RunStatePlan); err != nil { return s.handleRunError(ctx, state.runID, state.session.ID, err) } @@ -131,7 +139,7 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { return s.handleRunError(ctx, state.runID, state.session.ID, err) } - snapshot, rebuilt, err := s.prepareTurnSnapshot(ctx, &state) + snapshot, rebuilt, err := s.prepareTurnBudgetSnapshot(ctx, &state) if err != nil { return s.handleRunError(ctx, state.runID, state.session.ID, err) } @@ -139,13 +147,35 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { continue } - turnResult, err := s.callProviderWithRetry(ctx, &state, snapshot) + decision, err := s.evaluateTurnBudget(ctx, &state, snapshot) if err != nil { - if provider.IsContextTooLong(err) && state.reactiveCompactAttempts < maxReactiveCompactAttempts { + return s.handleRunError(ctx, state.runID, state.session.ID, err) + } + switch decision.Action { + case controlplane.TurnBudgetActionCompact: + if _, err := s.applyCompactForState( + ctx, + &state, + snapshot.Config, + contextcompact.ModeProactive, + compactErrorBestEffort, + ); err != nil { + return s.handleRunError(ctx, state.runID, state.session.ID, err) + } + continue + case controlplane.TurnBudgetActionStop: + state.budgetExceeded = true + return nil + } + + turnOutput, err := s.callProviderWithRetry(ctx, &state, snapshot) + if err != nil { + if provider.IsContextTooLong(err) && + state.reactiveCompactAttempts < snapshot.Config.Context.Budget.MaxReactiveCompacts { state.reactiveCompactAttempts++ - degradedCfg := snapshot.config + degradedCfg := snapshot.Config degradedCfg.Context.Compact.ManualKeepRecentMessages = degradeKeepRecentMessages( - snapshot.config.Context.Compact.ManualKeepRecentMessages, + snapshot.Config.Context.Compact.ManualKeepRecentMessages, state.reactiveCompactAttempts, ) _, _ = s.applyCompactForState(ctx, &state, degradedCfg, contextcompact.ModeReactive, compactErrorBestEffort) @@ -154,37 +184,42 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { return s.handleRunError(ctx, state.runID, state.session.ID, err) } - if strings.TrimSpace(turnResult.assistant.Role) == "" { - turnResult.assistant.Role = providertypes.RoleAssistant + if strings.TrimSpace(turnOutput.assistant.Role) == "" { + turnOutput.assistant.Role = providertypes.RoleAssistant + } + reconciled, err := s.reconcileLedger(&state, decision, turnOutput.usageObservation) + if err != nil { + return s.handleRunError(ctx, state.runID, state.session.ID, err) } if err := s.appendAssistantMessageAndSave( ctx, &state, snapshot, - turnResult.assistant, - turnResult.inputTokens, - turnResult.outputTokens, + turnOutput.assistant, + reconciled.inputTokens, + reconciled.outputTokens, ); err != nil { return s.handleRunError(ctx, state.runID, state.session.ID, err) } - s.emitTokenUsage(ctx, &state, turnResult) + s.emitLedgerReconciled(ctx, &state, turnOutput.usageObservation, reconciled) + s.emitTokenUsage(ctx, &state, reconciled) state.mu.Lock() state.completion = collectCompletionState( &state, - turnResult.assistant, - len(turnResult.assistant.ToolCalls) > 0, + turnOutput.assistant, + len(turnOutput.assistant.ToolCalls) > 0, ) completionState, completed := controlplane.EvaluateCompletion( state.completion, - len(turnResult.assistant.ToolCalls) > 0, + len(turnOutput.assistant.ToolCalls) > 0, ) state.completion = completionState state.mu.Unlock() - if len(turnResult.assistant.ToolCalls) == 0 { + if len(turnOutput.assistant.ToolCalls) == 0 { if completed { - s.emitRunScoped(ctx, EventAgentDone, &state, turnResult.assistant) + s.emitRunScoped(ctx, EventAgentDone, &state, turnOutput.assistant) s.triggerMemoExtraction(state.session.ID, state.session.Messages, state.rememberedThisRun) return nil } @@ -196,8 +231,8 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { cloneTodosForPersistence(state.session.Todos), cloneTodosForPersistence(state.session.Todos), toolExecutionSummary{}, - snapshot.noProgressStreakLimit, - snapshot.repeatCycleStreakLimit, + snapshot.NoProgressStreakLimit, + snapshot.RepeatCycleStreakLimit, ) state.progress = controlplane.EvaluateProgress(state.progress, progressInput) currentScore := state.progress.LastScore @@ -212,7 +247,7 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { if err := s.setBaseRunState(ctx, &state, controlplane.RunStateExecute); err != nil { return s.handleRunError(ctx, state.runID, state.session.ID, err) } - summary, err := s.executeAssistantToolCalls(ctx, &state, snapshot, turnResult.assistant) + summary, err := s.executeAssistantToolCalls(ctx, &state, snapshot, turnOutput.assistant) if err != nil { return s.handleRunError(ctx, state.runID, state.session.ID, err) } @@ -228,8 +263,8 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { beforeTodos, afterTodos, summary, - snapshot.noProgressStreakLimit, - snapshot.repeatCycleStreakLimit, + snapshot.NoProgressStreakLimit, + snapshot.RepeatCycleStreakLimit, ) state.progress = controlplane.EvaluateProgress(state.progress, progressInput) currentScore := state.progress.LastScore @@ -244,13 +279,13 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { } } -// prepareTurnSnapshot 基于当前会话状态冻结一轮推理所需的请求快照。 -func (s *Service) prepareTurnSnapshot(ctx context.Context, state *runState) (turnSnapshot, bool, error) { +// prepareTurnBudgetSnapshot 基于当前会话状态冻结一次预算尝试所需的 request 与预算事实。 +func (s *Service) prepareTurnBudgetSnapshot(ctx context.Context, state *runState) (TurnBudgetSnapshot, bool, error) { cfg := s.configManager.Get() activeWorkdir := agentsession.EffectiveWorkdir(state.session.Workdir, cfg.Workdir) activeSkills, err := s.resolveActiveSkills(ctx, state) if err != nil { - return turnSnapshot{}, false, err + return TurnBudgetSnapshot{}, false, err } builtContext, err := s.contextBuilder.Build(ctx, agentcontext.BuildInput{ @@ -268,43 +303,32 @@ func (s *Service) prepareTurnSnapshot(ctx context.Context, state *runState) (tur }, Compact: agentcontext.CompactOptions{ DisableMicroCompact: cfg.Context.Compact.MicroCompactDisabled, - AutoCompactThreshold: s.autoCompactThresholdForState(ctx, cfg, state), MicroCompactRetainedToolSpans: cfg.Context.Compact.MicroCompactRetainedToolSpans, ReadTimeMaxMessageSpans: cfg.Context.Compact.ReadTimeMaxMessageSpans, }, }) if err != nil { - return turnSnapshot{}, false, err + return TurnBudgetSnapshot{}, false, err } if strings.Contains(builtContext.SystemPrompt, "## Todo State") { s.emitRunScoped(ctx, EventTodoSummaryInjected, state, TodoEventPayload{}) } - if builtContext.AutoCompactSuggested && !state.compactApplied { - applied, err := s.applyCompactForState(ctx, state, cfg, contextcompact.ModeAuto, compactErrorBestEffort) - if err != nil { - return turnSnapshot{}, false, err - } - if applied { - return turnSnapshot{}, true, nil - } - } - toolSpecs, err := s.toolManager.ListAvailableSpecs(ctx, tools.SpecListInput{ SessionID: state.session.ID, }) if err != nil { - return turnSnapshot{}, false, err + return TurnBudgetSnapshot{}, false, err } toolSpecs = prioritizeToolSpecsBySkillHints(toolSpecs, activeSkills) resolvedProvider, err := config.ResolveSelectedProvider(cfg) if err != nil { - return turnSnapshot{}, false, err + return TurnBudgetSnapshot{}, false, err } providerRuntimeCfg, err := resolvedProvider.ToRuntimeConfig() if err != nil { - return turnSnapshot{}, false, err + return TurnBudgetSnapshot{}, false, err } state.mu.Lock() @@ -314,24 +338,33 @@ func (s *Service) prepareTurnSnapshot(ctx context.Context, state *runState) (tur limit := resolveNoProgressStreakLimit(cfg.Runtime) repeatLimit := resolveRepeatCycleStreakLimit(cfg.Runtime) systemPrompt := withProgressReminder(builtContext.SystemPrompt, score) - + promptBudget, budgetSource := s.resolvePromptBudget(ctx, cfg) model := strings.TrimSpace(cfg.CurrentModel) - return turnSnapshot{ - config: cfg, - providerConfig: providerRuntimeCfg, - model: model, - workdir: activeWorkdir, - toolTimeout: time.Duration(cfg.ToolTimeoutSec) * time.Second, - noProgressStreakLimit: limit, - repeatCycleStreakLimit: repeatLimit, - request: providertypes.GenerateRequest{ - Model: model, - SystemPrompt: systemPrompt, - Messages: builtContext.Messages, - Tools: toolSpecs, - SessionAssetReader: s.buildSessionAssetReader(ctx, state.session.ID), - }, - }, false, nil + request := providertypes.GenerateRequest{ + Model: model, + SystemPrompt: systemPrompt, + Messages: builtContext.Messages, + Tools: toolSpecs, + SessionAssetReader: s.buildSessionAssetReader(ctx, state.session.ID), + } + attemptSeq := state.nextAttemptSeq + if attemptSeq <= 0 { + attemptSeq = 1 + } + return newTurnBudgetSnapshot( + attemptSeq, + cfg, + providerRuntimeCfg, + model, + activeWorkdir, + time.Duration(cfg.ToolTimeoutSec)*time.Second, + promptBudget, + budgetSource, + state.compactCount, + limit, + repeatLimit, + request, + ), false, nil } // resolveNoProgressStreakLimit 统一解析熔断阈值,避免运行期出现无效值导致分支行为不一致。 @@ -350,12 +383,12 @@ func resolveRepeatCycleStreakLimit(rc config.RuntimeConfig) int { return rc.MaxRepeatCycleStreak } -// callProviderWithRetry 使用冻结后的 turnSnapshot 执行 provider 调用与必要重试。 +// callProviderWithRetry 使用冻结后的 TurnBudgetSnapshot 执行 provider 调用与必要重试。 func (s *Service) callProviderWithRetry( ctx context.Context, state *runState, - snapshot turnSnapshot, -) (providerTurnResult, error) { + snapshot TurnBudgetSnapshot, +) (turnProviderOutput, error) { var lastErr error for retryAttempt := 0; retryAttempt <= defaultProviderRetryMax; retryAttempt++ { @@ -367,17 +400,17 @@ func (s *Service) callProviderWithRetry( select { case <-ctx.Done(): - return providerTurnResult{}, ctx.Err() + return turnProviderOutput{}, ctx.Err() case <-time.After(wait): } } - modelProvider, err := s.providerFactory.Build(ctx, snapshot.providerConfig) + modelProvider, err := s.providerFactory.Build(ctx, snapshot.ProviderConfig) if err != nil { - return providerTurnResult{}, err + return turnProviderOutput{}, err } - streamOutcome := generateStreamingMessage(ctx, modelProvider, snapshot.request, streaming.Hooks{ + streamOutcome := generateStreamingMessage(ctx, modelProvider, snapshot.Request, streaming.Hooks{ OnTextDelta: func(text string) { s.emitRunScoped(ctx, EventAgentChunk, state, text) }, @@ -388,35 +421,42 @@ func (s *Service) callProviderWithRetry( if streamOutcome.err != nil { lastErr = streamOutcome.err if !isRetryableProviderError(lastErr) { - return providerTurnResult{}, lastErr + return turnProviderOutput{}, lastErr } if ctx.Err() != nil { - return providerTurnResult{}, ctx.Err() + return turnProviderOutput{}, ctx.Err() } continue } - return providerTurnResult{ - assistant: streamOutcome.message, - inputTokens: streamOutcome.inputTokens, - outputTokens: streamOutcome.outputTokens, + return turnProviderOutput{ + assistant: streamOutcome.message, + usageObservation: newTurnBudgetUsageObservation( + snapshot.ID, + streamOutcome.inputTokens, + streamOutcome.outputTokens, + streamOutcome.usagePresent, + ), }, nil } if lastErr == nil { lastErr = errors.New("max retries exceeded") } - return providerTurnResult{}, fmt.Errorf("runtime: max retries exhausted, last error: %w", lastErr) + return turnProviderOutput{}, fmt.Errorf("runtime: max retries exhausted, last error: %w", lastErr) } // emitTokenUsage 在单轮 provider 调用成功后发出 token_usage 事件。 -func (s *Service) emitTokenUsage(ctx context.Context, state *runState, result providerTurnResult) { - if result.inputTokens == 0 && result.outputTokens == 0 { +func (s *Service) emitTokenUsage(ctx context.Context, state *runState, result ledgerReconcileResult) { + if result.inputTokens == 0 && result.outputTokens == 0 && !result.hasUnknownUsage { return } s.emitRunScoped(ctx, EventTokenUsage, state, TokenUsagePayload{ InputTokens: result.inputTokens, OutputTokens: result.outputTokens, + InputSource: result.inputSource, + OutputSource: result.outputSource, + HasUnknownUsage: result.hasUnknownUsage, SessionInputTokens: state.session.TokenInputTotal, SessionOutputTokens: state.session.TokenOutputTotal, }) @@ -443,10 +483,13 @@ func (s *Service) applyCompactForState( if compactErr != nil { return compactErr } + if mode == contextcompact.ModeProactive || mode == contextcompact.ModeReactive { + state.compactCount++ + } state.session = session if result.Applied { state.resetTokenTotals() - state.compactApplied = true + state.nextAttemptSeq++ applied = true } return nil @@ -457,43 +500,87 @@ func (s *Service) applyCompactForState( return applied, nil } -// autoCompactThreshold 返回当前配置下的自动 compact 触发阈值。 -func (s *Service) autoCompactThreshold(ctx context.Context, cfg config.Config) int { - return s.autoCompactThresholdForState(ctx, cfg, nil) +// resolvePromptBudget 解析当前请求链路使用的 prompt budget 与来源标签。 +func (s *Service) resolvePromptBudget(ctx context.Context, cfg config.Config) (int, string) { + if cfg.Context.Budget.PromptBudget > 0 { + return cfg.Context.Budget.PromptBudget, "explicit" + } + promptBudget := cfg.Context.Budget.FallbackPromptBudget + source := "fallback" + if s != nil && s.budgetResolver != nil { + resolvedBudget, resolvedSource, err := s.budgetResolver.ResolvePromptBudget(ctx, cfg) + if err == nil && resolvedBudget > 0 { + promptBudget = resolvedBudget + if strings.TrimSpace(resolvedSource) != "" { + source = resolvedSource + } + } + } + return promptBudget, source } -// autoCompactThresholdForState 返回当前配置下的自动 compact 触发阈值,并在单次 run 内按关键输入缓存结果。 -func (s *Service) autoCompactThresholdForState(ctx context.Context, cfg config.Config, state *runState) int { - if !cfg.Context.AutoCompact.Enabled { - return 0 - } - if cfg.Context.AutoCompact.InputTokenThreshold > 0 { - return cfg.Context.AutoCompact.InputTokenThreshold +// evaluateTurnBudget 对冻结请求执行发送前输入 token 估算,并产出唯一预算动作。 +func (s *Service) evaluateTurnBudget( + ctx context.Context, + state *runState, + snapshot TurnBudgetSnapshot, +) (controlplane.TurnBudgetDecision, error) { + modelProvider, err := s.providerFactory.Build(ctx, snapshot.ProviderConfig) + if err != nil { + return controlplane.TurnBudgetDecision{}, err } + providerEstimate, err := modelProvider.EstimateInputTokens(ctx, snapshot.Request) + if err != nil { + return controlplane.TurnBudgetDecision{}, err + } + estimate := newTurnBudgetEstimate(snapshot.ID, providerEstimate) + decision := controlplane.DecideTurnBudget( + estimate, + snapshot.PromptBudget, + snapshot.CompactCount, + ) + s.emitRunScoped(ctx, EventBudgetChecked, state, newBudgetCheckedPayload(decision)) + return decision, nil +} - key := autoCompactCacheKeyFromConfig(cfg) - if state != nil && state.autoCompactCache.valid && state.autoCompactCache.key == key { - return state.autoCompactCache.threshold - } +// reconcileLedger 根据 observed usage 或发送前 estimate 生成本轮账本写入结果。 +func (s *Service) reconcileLedger( + state *runState, + decision controlplane.TurnBudgetDecision, + observation TurnBudgetUsageObservation, +) (ledgerReconcileResult, error) { + if decision.ID != observation.ID { + return ledgerReconcileResult{}, fmt.Errorf("runtime: turn budget id mismatch between decision and usage observation") + } + reconciled := ledgerReconcileResult{ + inputTokens: observation.InputTokens, + inputSource: usageSourceObserved, + outputTokens: observation.OutputTokens, + outputSource: usageSourceObserved, + } + if observation.InputObserved && observation.OutputObserved { + return reconciled, nil + } + reconciled.inputTokens = decision.EstimatedInputTokens + reconciled.inputSource = usageSourceEstimated + reconciled.outputTokens = 0 + reconciled.outputSource = usageSourceUnknown + reconciled.hasUnknownUsage = true + if state != nil { + state.session.HasUnknownUsage = true + state.hasUnknownUsage = true + } + return reconciled, nil +} - threshold := fallbackAutoCompactThreshold(cfg) - cacheable := true - if s != nil && s.autoCompactThresholdResolver != nil { - resolvedThreshold, err := s.autoCompactThresholdResolver.ResolveAutoCompactThreshold(ctx, cfg) - if err != nil { - cacheable = false - } else if resolvedThreshold > 0 { - threshold = resolvedThreshold - } - } - if state != nil && cacheable { - state.autoCompactCache = autoCompactThresholdCache{ - key: key, - threshold: threshold, - valid: true, - } - } - return threshold +// emitLedgerReconciled 发出本轮 usage 调和结果,便于区分 observed 与估算值。 +func (s *Service) emitLedgerReconciled( + ctx context.Context, + state *runState, + observation TurnBudgetUsageObservation, + result ledgerReconcileResult, +) { + s.emitRunScoped(ctx, EventLedgerReconciled, state, newLedgerReconciledPayload(observation, result)) } // degradeKeepRecentMessages 根据 reactive compact 尝试次数逐步减少保留消息数。 @@ -548,14 +635,6 @@ func sessionTitleFromParts(parts []providertypes.ContentPart) string { return "Image Message" } -// fallbackAutoCompactThreshold 返回自动推导失败时仍可继续使用的保底阈值。 -func fallbackAutoCompactThreshold(cfg config.Config) int { - if cfg.Context.AutoCompact.FallbackInputTokenThreshold > 0 { - return cfg.Context.AutoCompact.FallbackInputTokenThreshold - } - return 0 -} - // bindSessionLock 获取并持有指定会话锁,返回对应的释放函数。 func (s *Service) bindSessionLock(sessionID string) func() { id := strings.TrimSpace(sessionID) @@ -589,14 +668,23 @@ func withProgressReminder(systemPrompt string, score controlplane.ProgressScore) return trimmed + "\n\n" + reminder } -// autoCompactCacheKeyFromConfig 提取会影响自动压缩阈值解析的配置维度,用于 run 内缓存命中判断。 -func autoCompactCacheKeyFromConfig(cfg config.Config) autoCompactThresholdCacheKey { - return autoCompactThresholdCacheKey{ - provider: strings.TrimSpace(cfg.SelectedProvider), - model: strings.TrimSpace(cfg.CurrentModel), - autoCompactEnabled: cfg.Context.AutoCompact.Enabled, - autoCompactInputThreshold: cfg.Context.AutoCompact.InputTokenThreshold, - autoCompactReserveTokens: cfg.Context.AutoCompact.ReserveTokens, - autoCompactFallback: cfg.Context.AutoCompact.FallbackInputTokenThreshold, +// computeRequestHash 计算冻结请求的稳定指纹,避免 compact 前后的估算结果串用。 +func computeRequestHash(req providertypes.GenerateRequest) string { + hashInput := struct { + Model string `json:"model"` + SystemPrompt string `json:"system_prompt"` + Messages []providertypes.Message `json:"messages"` + Tools []tools.ToolSpec `json:"tools"` + }{ + Model: req.Model, + SystemPrompt: req.SystemPrompt, + Messages: cloneMessages(req.Messages), + Tools: append([]tools.ToolSpec(nil), req.Tools...), + } + encoded, err := json.Marshal(hashInput) + if err != nil { + return "" } + sum := sha256.Sum256(encoded) + return hex.EncodeToString(sum[:]) } diff --git a/internal/runtime/run_lifecycle.go b/internal/runtime/run_lifecycle.go index 28e52ea3..9bf04024 100644 --- a/internal/runtime/run_lifecycle.go +++ b/internal/runtime/run_lifecycle.go @@ -121,11 +121,6 @@ func isBaseLifecycleState(state controlplane.RunState) bool { } } -// transitionRunState 兼容旧调用入口,内部统一转为 base lifecycle 更新。 -func (s *Service) transitionRunState(ctx context.Context, state *runState, next controlplane.RunState) error { - return s.setBaseRunState(ctx, state, next) -} - // emitRunTermination 在 Run 退出时决议并发出唯一的 stop_reason_decided 事件。 func (s *Service) emitRunTermination(ctx context.Context, input UserInput, state *runState, err error) { runID := strings.TrimSpace(input.RunID) @@ -148,7 +143,9 @@ func (s *Service) emitRunTermination(ctx context.Context, input UserInput, state } in := controlplane.StopInput{} - if err != nil { + if state != nil && state.budgetExceeded { + in.BudgetExceeded = true + } else if err != nil { switch { case errors.Is(err, context.Canceled): in.UserInterrupted = true diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index b1384eee..b1a58cea 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -107,24 +107,24 @@ type MemoExtractor interface { Schedule(sessionID string, messages []providertypes.Message) } -// Service 是 runtime 的默认实现,负责组织一次完整的 agent 运行闭环。 -type AutoCompactThresholdResolver interface { - ResolveAutoCompactThreshold(ctx context.Context, cfg config.Config) (int, error) +// BudgetResolver 定义 prompt budget 解析能力,避免 runtime 直接处理模型目录细节。 +type BudgetResolver interface { + ResolvePromptBudget(ctx context.Context, cfg config.Config) (int, string, error) } type Service struct { - configManager *config.Manager - sessionStore agentsession.Store - sessionAssetStore agentsession.AssetStore - userInputPreparer UserInputPreparer - toolManager tools.Manager - providerFactory ProviderFactory - contextBuilder agentcontext.Builder - compactRunner contextcompact.Runner - approvalBroker *approval.Broker - memoExtractor MemoExtractor - skillsRegistry skills.Registry - autoCompactThresholdResolver AutoCompactThresholdResolver + configManager *config.Manager + sessionStore agentsession.Store + sessionAssetStore agentsession.AssetStore + userInputPreparer UserInputPreparer + toolManager tools.Manager + providerFactory ProviderFactory + contextBuilder agentcontext.Builder + compactRunner contextcompact.Runner + approvalBroker *approval.Broker + memoExtractor MemoExtractor + skillsRegistry skills.Registry + budgetResolver BudgetResolver events chan RuntimeEvent sessionMu sync.Mutex @@ -321,7 +321,7 @@ func isRuntimeSessionAlreadyExistsError(err error) bool { return errors.Is(err, agentsession.ErrSessionAlreadyExists) || errors.Is(err, os.ErrExist) } -// SetAutoCompactThresholdResolver 注入自动压缩阈值解析能力,避免 runtime 直接处理模型目录细节。 -func (s *Service) SetAutoCompactThresholdResolver(resolver AutoCompactThresholdResolver) { - s.autoCompactThresholdResolver = resolver +// SetBudgetResolver 注入 prompt budget 解析能力,避免 runtime 直接感知模型目录细节。 +func (s *Service) SetBudgetResolver(resolver BudgetResolver) { + s.budgetResolver = resolver } diff --git a/internal/runtime/runtime_branch_coverage_test.go b/internal/runtime/runtime_branch_coverage_test.go index 4503b738..1a0eb69f 100644 --- a/internal/runtime/runtime_branch_coverage_test.go +++ b/internal/runtime/runtime_branch_coverage_test.go @@ -16,7 +16,7 @@ func TestExecuteAssistantToolCallsReturnsNilForEmptyCalls(t *testing.T) { service := &Service{} state := &runState{} - _, err := service.executeAssistantToolCalls(context.Background(), state, turnSnapshot{}, providertypes.Message{}) + _, err := service.executeAssistantToolCalls(context.Background(), state, TurnBudgetSnapshot{}, providertypes.Message{}) if err != nil { t.Fatalf("executeAssistantToolCalls() error = %v", err) } @@ -32,7 +32,7 @@ func TestExecuteOneToolCallStopsWhenContextCheckReturnsTrue(t *testing.T) { _, _, _ = service.executeOneToolCall( context.Background(), &state, - turnSnapshot{}, + TurnBudgetSnapshot{}, providertypes.ToolCall{ID: "call-1", Name: "noop"}, &sync.Mutex{}, func() bool { return true }, @@ -90,11 +90,11 @@ func TestTransitionRunPhaseNoopBranches(t *testing.T) { t.Parallel() service := &Service{events: make(chan RuntimeEvent, 4)} - service.transitionRunState(context.Background(), nil, controlplane.RunStatePlan) + service.setBaseRunState(context.Background(), nil, controlplane.RunStatePlan) state := newRunState("run-phase", newRuntimeSession("session-phase")) state.lifecycle = controlplane.RunStatePlan - service.transitionRunState(context.Background(), &state, controlplane.RunStatePlan) + service.setBaseRunState(context.Background(), &state, controlplane.RunStatePlan) events := collectRuntimeEvents(service.Events()) if len(events) != 0 { diff --git a/internal/runtime/runtime_gap_coverage_test.go b/internal/runtime/runtime_gap_coverage_test.go index d99bf560..6a156682 100644 --- a/internal/runtime/runtime_gap_coverage_test.go +++ b/internal/runtime/runtime_gap_coverage_test.go @@ -180,7 +180,13 @@ func TestCompactSummaryGeneratorErrorBranches(t *testing.T) { g = &compactSummaryGenerator{ providerFactory: &scriptedProviderFactory{err: errors.New("build failed")}, - providerConfig: provider.RuntimeConfig{Name: "openai", Driver: "openai", BaseURL: "https://example.com", APIKey: "k"}, + providerConfig: provider.RuntimeConfig{ + Name: "openai", + Driver: "openai", + BaseURL: "https://example.com", + APIKeyEnv: "OPENAI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("k"), + }, } if _, err := g.Generate(context.Background(), contextcompact.SummaryInput{}); err == nil { t.Fatalf("expected provider build error") @@ -188,7 +194,13 @@ func TestCompactSummaryGeneratorErrorBranches(t *testing.T) { g = &compactSummaryGenerator{ providerFactory: &scriptedProviderFactory{provider: &scriptedProvider{streams: [][]providertypes.StreamEvent{{providertypes.NewTextDeltaStreamEvent(" ")}}}}, - providerConfig: provider.RuntimeConfig{Name: "openai", Driver: "openai", BaseURL: "https://example.com", APIKey: "k"}, + providerConfig: provider.RuntimeConfig{ + Name: "openai", + Driver: "openai", + BaseURL: "https://example.com", + APIKeyEnv: "OPENAI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("k"), + }, } if _, err := g.Generate(context.Background(), contextcompact.SummaryInput{}); err == nil { t.Fatalf("expected empty summary error") diff --git a/internal/runtime/runtime_internal_helpers_test.go b/internal/runtime/runtime_internal_helpers_test.go index dd0dbbdd..4a31e7c8 100644 --- a/internal/runtime/runtime_internal_helpers_test.go +++ b/internal/runtime/runtime_internal_helpers_test.go @@ -140,6 +140,8 @@ func TestRunStateMutationsAndSync(t *testing.T) { state := newRunState("run-1", session) state.recordUsage(10, 20) + state.session.HasUnknownUsage = true + state.hasUnknownUsage = true if state.session.TokenInputTotal != 11 || state.session.TokenOutputTotal != 22 { t.Fatalf("unexpected token totals: in=%d out=%d", state.session.TokenInputTotal, state.session.TokenOutputTotal) } @@ -148,6 +150,9 @@ func TestRunStateMutationsAndSync(t *testing.T) { if state.session.TokenInputTotal != 0 || state.session.TokenOutputTotal != 0 { t.Fatalf("expected reset totals to be zero, got in=%d out=%d", state.session.TokenInputTotal, state.session.TokenOutputTotal) } + if state.session.HasUnknownUsage || state.hasUnknownUsage { + t.Fatalf("expected resetTokenTotals to clear unknown usage flags") + } before := state.session.UpdatedAt state.recordUsage(1, 2) @@ -191,9 +196,9 @@ func TestAppendAssistantMessageAndSaveMetadataBranches(t *testing.T) { service := &Service{sessionStore: store} state := newRunState("run-append-assistant", session) - snapshot := turnSnapshot{ - providerConfig: providerRuntimeConfigForTest("openai"), - model: "gpt-4.1", + snapshot := TurnBudgetSnapshot{ + ProviderConfig: providerRuntimeConfigForTest("openai"), + Model: "gpt-4.1", } if err := service.appendAssistantMessageAndSave( @@ -211,8 +216,8 @@ func TestAppendAssistantMessageAndSaveMetadataBranches(t *testing.T) { } store.saves = 0 - state.session.Provider = snapshot.providerConfig.Name - state.session.Model = snapshot.model + state.session.Provider = snapshot.ProviderConfig.Name + state.session.Model = snapshot.Model if err := service.appendAssistantMessageAndSave( context.Background(), &state, @@ -601,14 +606,19 @@ func TestEmitTokenUsageSkipsZeroUsage(t *testing.T) { service := &Service{events: make(chan RuntimeEvent, 8)} state := &runState{runID: "run-token", session: newRuntimeSession("session-token")} - service.emitTokenUsage(context.Background(), state, providerTurnResult{}) + service.emitTokenUsage(context.Background(), state, ledgerReconcileResult{}) events := collectRuntimeEvents(service.Events()) if len(events) != 0 { t.Fatalf("expected no token event for zero usage, got %+v", events) } state.recordUsage(5, 7) - service.emitTokenUsage(context.Background(), state, providerTurnResult{inputTokens: 5, outputTokens: 7}) + service.emitTokenUsage(context.Background(), state, ledgerReconcileResult{ + inputTokens: 5, + inputSource: usageSourceObserved, + outputTokens: 7, + outputSource: usageSourceObserved, + }) events = collectRuntimeEvents(service.Events()) if len(events) != 1 || events[0].Type != EventTokenUsage { t.Fatalf("expected one token usage event, got %+v", events) @@ -637,7 +647,7 @@ func TestExecuteAssistantToolCallsFillsErrorContent(t *testing.T) { {ID: "call-err", Name: "filesystem_read_file", Arguments: `{"path":"a.txt"}`}, }, } - snapshot := turnSnapshot{workdir: t.TempDir(), toolTimeout: time.Second} + snapshot := TurnBudgetSnapshot{Workdir: t.TempDir(), ToolTimeout: time.Second} if _, err := service.executeAssistantToolCalls(context.Background(), &state, snapshot, assistant); err != nil { t.Fatalf("executeAssistantToolCalls() error = %v", err) @@ -677,7 +687,7 @@ func TestExecuteAssistantToolCallsCanceledSaveStillEmitsResultWhenExecErr(t *tes {ID: "call-1", Name: "filesystem_read_file", Arguments: `{"path":"a.txt"}`}, }, } - snapshot := turnSnapshot{workdir: t.TempDir(), toolTimeout: time.Second} + snapshot := TurnBudgetSnapshot{Workdir: t.TempDir(), ToolTimeout: time.Second} _, err := service.executeAssistantToolCalls(context.Background(), &state, snapshot, assistant) if !errors.Is(err, context.Canceled) { diff --git a/internal/runtime/runtime_progress_test.go b/internal/runtime/runtime_progress_test.go index b5a7eea0..e81fbd1d 100644 --- a/internal/runtime/runtime_progress_test.go +++ b/internal/runtime/runtime_progress_test.go @@ -361,19 +361,19 @@ func TestPrepareTurnSnapshotInjectRepeatReminderWithEmptyPrompt(t *testing.T) { state.progress.LastScore.StalledProgressState = controlplane.StalledProgressStalled state.progress.LastScore.ReminderKind = controlplane.ReminderKindRepeatCycle - snapshot, rebuilt, err := service.prepareTurnSnapshot(context.Background(), &state) + snapshot, rebuilt, err := service.prepareTurnBudgetSnapshot(context.Background(), &state) if err != nil { - t.Fatalf("prepareTurnSnapshot() error = %v", err) + t.Fatalf("prepareTurnBudgetSnapshot() error = %v", err) } if rebuilt { t.Fatal("expected rebuilt=false") } - if snapshot.request.SystemPrompt != selfHealingRepeatReminder { - t.Fatalf("expected repeat reminder only, got %q", snapshot.request.SystemPrompt) + if snapshot.Request.SystemPrompt != selfHealingRepeatReminder { + t.Fatalf("expected repeat reminder only, got %q", snapshot.Request.SystemPrompt) } } -func TestPrepareTurnSnapshotRepeatReminderTakesPriority(t *testing.T) { +func TestPrepareTurnBudgetSnapshotRepeatReminderTakesPriority(t *testing.T) { manager := newRuntimeConfigManager(t) if err := manager.Update(context.Background(), func(cfg *config.Config) error { cfg.Runtime.MaxNoProgressStreak = 3 @@ -398,18 +398,18 @@ func TestPrepareTurnSnapshotRepeatReminderTakesPriority(t *testing.T) { state.progress.LastScore.StalledProgressState = controlplane.StalledProgressStalled state.progress.LastScore.ReminderKind = controlplane.ReminderKindRepeatCycle - snapshot, rebuilt, err := service.prepareTurnSnapshot(context.Background(), &state) + snapshot, rebuilt, err := service.prepareTurnBudgetSnapshot(context.Background(), &state) if err != nil { - t.Fatalf("prepareTurnSnapshot() error = %v", err) + t.Fatalf("prepareTurnBudgetSnapshot() error = %v", err) } if rebuilt { t.Fatal("expected rebuilt=false") } - if !strings.Contains(snapshot.request.SystemPrompt, selfHealingRepeatReminder) { - t.Fatalf("expected prompt to contain repeat reminder, got %q", snapshot.request.SystemPrompt) + if !strings.Contains(snapshot.Request.SystemPrompt, selfHealingRepeatReminder) { + t.Fatalf("expected prompt to contain repeat reminder, got %q", snapshot.Request.SystemPrompt) } - if strings.Contains(snapshot.request.SystemPrompt, selfHealingReminder) { - t.Fatalf("expected no-progress reminder to be skipped when repeat reminder is injected, got %q", snapshot.request.SystemPrompt) + if strings.Contains(snapshot.Request.SystemPrompt, selfHealingReminder) { + t.Fatalf("expected no-progress reminder to be skipped when repeat reminder is injected, got %q", snapshot.Request.SystemPrompt) } } diff --git a/internal/runtime/runtime_remaining_branches_test.go b/internal/runtime/runtime_remaining_branches_test.go index d1116d2c..dd1c3aab 100644 --- a/internal/runtime/runtime_remaining_branches_test.go +++ b/internal/runtime/runtime_remaining_branches_test.go @@ -170,7 +170,15 @@ func TestResolveCompactProviderSelectionResolveErrorBranch(t *testing.T) { _ = os.Unsetenv(apiEnv) session := agentsession.Session{Provider: cfg.SelectedProvider, Model: "m1"} - if _, _, err := resolveCompactProviderSelection(session, cfg); err == nil { + resolved, _, err := resolveCompactProviderSelection(session, cfg) + if err != nil { + t.Fatalf("resolveCompactProviderSelection() error = %v", err) + } + runtimeConfig, err := resolved.ToRuntimeConfig() + if err != nil { + t.Fatalf("ToRuntimeConfig() error = %v", err) + } + if _, err := runtimeConfig.ResolveAPIKeyValue(); err == nil { t.Fatalf("expected resolve API key error") } } @@ -281,7 +289,7 @@ func TestGenerateStreamingMessageDrainEventsAfterContextCanceled(t *testing.T) { } } -func TestPrepareTurnSnapshotErrorBranches(t *testing.T) { +func TestPrepareTurnBudgetSnapshotErrorBranches(t *testing.T) { t.Parallel() manager := newRuntimeConfigManager(t) @@ -293,7 +301,7 @@ func TestPrepareTurnSnapshotErrorBranches(t *testing.T) { toolManager: &stubToolManager{}, } state := newRunState("run-snapshot", newRuntimeSession("session-snapshot")) - if _, _, err := service.prepareTurnSnapshot(context.Background(), &state); err == nil { + if _, _, err := service.prepareTurnBudgetSnapshot(context.Background(), &state); err == nil { t.Fatalf("expected build error") } @@ -306,7 +314,7 @@ func TestPrepareTurnSnapshotErrorBranches(t *testing.T) { service.contextBuilder = &stubContextBuilder{buildFn: func(ctx context.Context, input agentcontext.BuildInput) (agentcontext.BuildResult, error) { return agentcontext.BuildResult{Messages: input.Messages}, nil }} - if _, _, err := service.prepareTurnSnapshot(context.Background(), &state); err == nil { + if _, _, err := service.prepareTurnBudgetSnapshot(context.Background(), &state); err == nil { t.Fatalf("expected resolve selected provider error") } } @@ -481,7 +489,7 @@ func TestExecuteAssistantToolCallsRemainingBranches(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() state := newRunState("run", newRuntimeSession("session-top-cancel")) - _, err := service.executeAssistantToolCalls(ctx, &state, turnSnapshot{}, providertypes.Message{ToolCalls: []providertypes.ToolCall{{ID: "c", Name: "filesystem_read_file"}}}) + _, err := service.executeAssistantToolCalls(ctx, &state, TurnBudgetSnapshot{}, providertypes.Message{ToolCalls: []providertypes.ToolCall{{ID: "c", Name: "filesystem_read_file"}}}) if !errors.Is(err, context.Canceled) { t.Fatalf("expected context.Canceled, got %v", err) } @@ -499,7 +507,7 @@ func TestExecuteAssistantToolCallsRemainingBranches(t *testing.T) { store.sessions[session.ID] = cloneSession(session) service := &Service{events: make(chan RuntimeEvent, 8), approvalBroker: approvalflow.NewBroker(), toolManager: manager, sessionStore: store} state := newRunState("run", session) - _, err := service.executeAssistantToolCalls(ctx, &state, turnSnapshot{toolTimeout: time.Second}, providertypes.Message{ToolCalls: []providertypes.ToolCall{{ID: "c", Name: "filesystem_read_file"}}}) + _, err := service.executeAssistantToolCalls(ctx, &state, TurnBudgetSnapshot{ToolTimeout: time.Second}, providertypes.Message{ToolCalls: []providertypes.ToolCall{{ID: "c", Name: "filesystem_read_file"}}}) if !errors.Is(err, context.Canceled) { t.Fatalf("expected context.Canceled, got %v", err) } @@ -518,7 +526,7 @@ func TestExecuteAssistantToolCallsRemainingBranches(t *testing.T) { service := &Service{events: make(chan RuntimeEvent, 8), approvalBroker: approvalflow.NewBroker(), toolManager: manager, sessionStore: store} state := newRunState("run", session) - _, err := service.executeAssistantToolCalls(ctx, &state, turnSnapshot{toolTimeout: time.Second}, providertypes.Message{ToolCalls: []providertypes.ToolCall{{ID: "c", Name: "filesystem_read_file"}}}) + _, err := service.executeAssistantToolCalls(ctx, &state, TurnBudgetSnapshot{ToolTimeout: time.Second}, providertypes.Message{ToolCalls: []providertypes.ToolCall{{ID: "c", Name: "filesystem_read_file"}}}) if !errors.Is(err, context.Canceled) { t.Fatalf("expected context.Canceled, got %v", err) } @@ -537,7 +545,7 @@ func TestExecuteAssistantToolCallsRemainingBranches(t *testing.T) { service := &Service{events: make(chan RuntimeEvent, 8), approvalBroker: approvalflow.NewBroker(), toolManager: manager, sessionStore: store} state := newRunState("run", session) - _, err := service.executeAssistantToolCalls(ctx, &state, turnSnapshot{toolTimeout: time.Second}, providertypes.Message{ToolCalls: []providertypes.ToolCall{{ID: "c", Name: "filesystem_read_file"}}}) + _, err := service.executeAssistantToolCalls(ctx, &state, TurnBudgetSnapshot{ToolTimeout: time.Second}, providertypes.Message{ToolCalls: []providertypes.ToolCall{{ID: "c", Name: "filesystem_read_file"}}}) if !errors.Is(err, context.Canceled) { t.Fatalf("expected context.Canceled, got %v", err) } @@ -598,7 +606,7 @@ func TestRunAndProviderRetryRemainingBranches(t *testing.T) { }() service := &Service{providerFactory: &scriptedProviderFactory{provider: providerRetry}, events: make(chan RuntimeEvent, 8)} state := newRunState("run-retry-backoff", newRuntimeSession("session-retry-backoff")) - _, err := service.callProviderWithRetry(ctx, &state, turnSnapshot{providerConfig: provider.RuntimeConfig{Name: "x"}}) + _, err := service.callProviderWithRetry(ctx, &state, TurnBudgetSnapshot{ProviderConfig: provider.RuntimeConfig{Name: "x"}}) if !errors.Is(err, context.Canceled) { t.Fatalf("expected context.Canceled, got %v", err) } @@ -613,7 +621,7 @@ func TestRunAndProviderRetryRemainingBranches(t *testing.T) { }} service := &Service{providerFactory: &scriptedProviderFactory{provider: providerRetry}, events: make(chan RuntimeEvent, 8)} state := newRunState("run-retry-ctx-check", newRuntimeSession("session-retry-ctx-check")) - _, err := service.callProviderWithRetry(ctx, &state, turnSnapshot{providerConfig: provider.RuntimeConfig{Name: "x"}}) + _, err := service.callProviderWithRetry(ctx, &state, TurnBudgetSnapshot{ProviderConfig: provider.RuntimeConfig{Name: "x"}}) if !errors.Is(err, context.Canceled) { t.Fatalf("expected context.Canceled, got %v", err) } diff --git a/internal/runtime/runtime_test.go b/internal/runtime/runtime_test.go index 03bfa546..50c1a025 100644 --- a/internal/runtime/runtime_test.go +++ b/internal/runtime/runtime_test.go @@ -40,9 +40,9 @@ type failingStore struct { ignoreContextErr bool } -type autoCompactThresholdResolverFunc func(ctx context.Context, cfg config.Config) (int, error) +type budgetResolverFunc func(ctx context.Context, cfg config.Config) (int, string, error) -func (f autoCompactThresholdResolverFunc) ResolveAutoCompactThreshold(ctx context.Context, cfg config.Config) (int, error) { +func (f budgetResolverFunc) ResolvePromptBudget(ctx context.Context, cfg config.Config) (int, string, error) { return f(ctx, cfg) } @@ -70,7 +70,7 @@ func (s *memoryStore) CreateSession(ctx context.Context, input agentsession.Crea if err := ctx.Err(); err != nil { return agentsession.Session{}, err } - session := agentsession.NewWithWorkdir(input.Title, input.Workdir) + session := agentsession.NewWithWorkdir(input.Title, input.Head.Workdir) if strings.TrimSpace(input.ID) != "" { session.ID = input.ID } @@ -80,13 +80,15 @@ func (s *memoryStore) CreateSession(ctx context.Context, input agentsession.Crea if !input.UpdatedAt.IsZero() { session.UpdatedAt = input.UpdatedAt } - session.Provider = input.Provider - session.Model = input.Model - session.TaskState = input.TaskState.Clone() - session.ActivatedSkills = agentsessionCloneSkillActivations(input.ActivatedSkills) - session.Todos = cloneTodosForPersistence(input.Todos) - session.TokenInputTotal = input.TokenInputTotal - session.TokenOutputTotal = input.TokenOutputTotal + head := input.Head + session.Provider = head.Provider + session.Model = head.Model + session.TaskState = head.TaskState.Clone() + session.ActivatedSkills = agentsessionCloneSkillActivations(head.ActivatedSkills) + session.Todos = cloneTodosForPersistence(head.Todos) + session.TokenInputTotal = head.TokenInputTotal + session.TokenOutputTotal = head.TokenOutputTotal + session.HasUnknownUsage = head.HasUnknownUsage session.Messages = []providertypes.Message{} s.mu.Lock() @@ -155,6 +157,7 @@ func (s *memoryStore) AppendMessages(ctx context.Context, input agentsession.App session.Workdir = input.Workdir session.TokenInputTotal += input.TokenInputDelta session.TokenOutputTotal += input.TokenOutputDelta + session.HasUnknownUsage = session.HasUnknownUsage || input.HasUnknownUsage s.saves++ s.sessions[input.SessionID] = cloneSession(session) return nil @@ -197,14 +200,16 @@ func (s *memoryStore) UpdateSessionState(ctx context.Context, input agentsession if !input.UpdatedAt.IsZero() { session.UpdatedAt = input.UpdatedAt } - session.Provider = input.Provider - session.Model = input.Model - session.Workdir = input.Workdir - session.TaskState = input.TaskState.Clone() - session.ActivatedSkills = agentsessionCloneSkillActivations(input.ActivatedSkills) - session.Todos = cloneTodosForPersistence(input.Todos) - session.TokenInputTotal = input.TokenInputTotal - session.TokenOutputTotal = input.TokenOutputTotal + head := input.Head + session.Provider = head.Provider + session.Model = head.Model + session.Workdir = head.Workdir + session.TaskState = head.TaskState.Clone() + session.ActivatedSkills = agentsessionCloneSkillActivations(head.ActivatedSkills) + session.Todos = cloneTodosForPersistence(head.Todos) + session.TokenInputTotal = head.TokenInputTotal + session.TokenOutputTotal = head.TokenOutputTotal + session.HasUnknownUsage = head.HasUnknownUsage s.saves++ s.sessions[input.SessionID] = cloneSession(session) return nil @@ -226,14 +231,16 @@ func (s *memoryStore) ReplaceTranscript(ctx context.Context, input agentsession. if !input.UpdatedAt.IsZero() { session.UpdatedAt = input.UpdatedAt } - session.Provider = input.Provider - session.Model = input.Model - session.Workdir = input.Workdir - session.TaskState = input.TaskState.Clone() - session.ActivatedSkills = agentsessionCloneSkillActivations(input.ActivatedSkills) - session.Todos = cloneTodosForPersistence(input.Todos) - session.TokenInputTotal = input.TokenInputTotal - session.TokenOutputTotal = input.TokenOutputTotal + head := input.Head + session.Provider = head.Provider + session.Model = head.Model + session.Workdir = head.Workdir + session.TaskState = head.TaskState.Clone() + session.ActivatedSkills = agentsessionCloneSkillActivations(head.ActivatedSkills) + session.Todos = cloneTodosForPersistence(head.Todos) + session.TokenInputTotal = head.TokenInputTotal + session.TokenOutputTotal = head.TokenOutputTotal + session.HasUnknownUsage = head.HasUnknownUsage s.saves++ s.sessions[input.SessionID] = cloneSession(session) return nil @@ -243,7 +250,7 @@ func (s *memoryStore) CleanupExpiredSessions(ctx context.Context, maxAge time.Du return 0, nil } -// CreateSession 转发到底层 Store,并按旧 save 计数规则注入失败。 +// CreateSession 转发到底层 Store,并按当前 save 计数规则注入失败。 func (s *failingStore) CreateSession(ctx context.Context, input agentsession.CreateSessionInput) (agentsession.Session, error) { if err := s.nextSaveError(ctx); err != nil { return agentsession.Session{}, err @@ -337,7 +344,7 @@ func (s *blockingLoadStore) CreateSession(ctx context.Context, input agentsessio if err := ctx.Err(); err != nil { return agentsession.Session{}, err } - session := agentsession.NewWithWorkdir(input.Title, input.Workdir) + session := agentsession.NewWithWorkdir(input.Title, input.Head.Workdir) if strings.TrimSpace(input.ID) != "" { session.ID = input.ID } @@ -347,13 +354,15 @@ func (s *blockingLoadStore) CreateSession(ctx context.Context, input agentsessio if !input.UpdatedAt.IsZero() { session.UpdatedAt = input.UpdatedAt } - session.Provider = input.Provider - session.Model = input.Model - session.TaskState = input.TaskState.Clone() - session.ActivatedSkills = agentsessionCloneSkillActivations(input.ActivatedSkills) - session.Todos = cloneTodosForPersistence(input.Todos) - session.TokenInputTotal = input.TokenInputTotal - session.TokenOutputTotal = input.TokenOutputTotal + head := input.Head + session.Provider = head.Provider + session.Model = head.Model + session.TaskState = head.TaskState.Clone() + session.ActivatedSkills = agentsessionCloneSkillActivations(head.ActivatedSkills) + session.Todos = cloneTodosForPersistence(head.Todos) + session.TokenInputTotal = head.TokenInputTotal + session.TokenOutputTotal = head.TokenOutputTotal + session.HasUnknownUsage = head.HasUnknownUsage s.mu.Lock() s.sessions[session.ID] = cloneSession(session) s.mu.Unlock() @@ -407,6 +416,7 @@ func (s *blockingLoadStore) AppendMessages(ctx context.Context, input agentsessi session.Workdir = input.Workdir session.TokenInputTotal += input.TokenInputDelta session.TokenOutputTotal += input.TokenOutputDelta + session.HasUnknownUsage = session.HasUnknownUsage || input.HasUnknownUsage s.sessions[input.SessionID] = cloneSession(session) return nil } @@ -445,14 +455,16 @@ func (s *blockingLoadStore) UpdateSessionState(ctx context.Context, input agents if !input.UpdatedAt.IsZero() { session.UpdatedAt = input.UpdatedAt } - session.Provider = input.Provider - session.Model = input.Model - session.Workdir = input.Workdir - session.TaskState = input.TaskState.Clone() - session.ActivatedSkills = agentsessionCloneSkillActivations(input.ActivatedSkills) - session.Todos = cloneTodosForPersistence(input.Todos) - session.TokenInputTotal = input.TokenInputTotal - session.TokenOutputTotal = input.TokenOutputTotal + head := input.Head + session.Provider = head.Provider + session.Model = head.Model + session.Workdir = head.Workdir + session.TaskState = head.TaskState.Clone() + session.ActivatedSkills = agentsessionCloneSkillActivations(head.ActivatedSkills) + session.Todos = cloneTodosForPersistence(head.Todos) + session.TokenInputTotal = head.TokenInputTotal + session.TokenOutputTotal = head.TokenOutputTotal + session.HasUnknownUsage = head.HasUnknownUsage s.sessions[input.SessionID] = cloneSession(session) return nil } @@ -472,14 +484,16 @@ func (s *blockingLoadStore) ReplaceTranscript(ctx context.Context, input agentse if !input.UpdatedAt.IsZero() { session.UpdatedAt = input.UpdatedAt } - session.Provider = input.Provider - session.Model = input.Model - session.Workdir = input.Workdir - session.TaskState = input.TaskState.Clone() - session.ActivatedSkills = agentsessionCloneSkillActivations(input.ActivatedSkills) - session.Todos = cloneTodosForPersistence(input.Todos) - session.TokenInputTotal = input.TokenInputTotal - session.TokenOutputTotal = input.TokenOutputTotal + head := input.Head + session.Provider = head.Provider + session.Model = head.Model + session.Workdir = head.Workdir + session.TaskState = head.TaskState.Clone() + session.ActivatedSkills = agentsessionCloneSkillActivations(head.ActivatedSkills) + session.Todos = cloneTodosForPersistence(head.Todos) + session.TokenInputTotal = head.TokenInputTotal + session.TokenOutputTotal = head.TokenOutputTotal + session.HasUnknownUsage = head.HasUnknownUsage s.sessions[input.SessionID] = cloneSession(session) return nil } @@ -507,12 +521,28 @@ func (s *blockingLoadStore) CleanupExpiredSessions(ctx context.Context, maxAge t } type scriptedProvider struct { - name string - streams [][]providertypes.StreamEvent - responses []scriptedResponse - requests []providertypes.GenerateRequest - callCount int - chatFn func(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error + name string + streams [][]providertypes.StreamEvent + responses []scriptedResponse + requests []providertypes.GenerateRequest + callCount int + estimateFn func(ctx context.Context, req providertypes.GenerateRequest) (providertypes.BudgetEstimate, error) + chatFn func(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error +} + +func (p *scriptedProvider) EstimateInputTokens( + ctx context.Context, + req providertypes.GenerateRequest, +) (providertypes.BudgetEstimate, error) { + if p.estimateFn != nil { + return p.estimateFn(ctx, req) + } + _ = ctx + return providertypes.BudgetEstimate{ + EstimatedInputTokens: provider.EstimateTextTokens(req.SystemPrompt + renderMessagesForEstimate(req.Messages)), + EstimateSource: provider.EstimateSourceLocal, + Accurate: false, + }, nil } type scriptedResponse struct { @@ -587,6 +617,14 @@ func streamContainsMessageDone(events []providertypes.StreamEvent) bool { return false } +func renderMessagesForEstimate(messages []providertypes.Message) string { + var builder strings.Builder + for _, message := range messages { + builder.WriteString(provider.RenderMessageText(message.Parts)) + } + return builder.String() +} + type scriptedProviderFactory struct { provider provider.Provider calls int @@ -984,8 +1022,9 @@ func TestServiceRun(t *testing.T) { t.Fatalf("Run() error = %v", err) } - if factory.calls != tt.expectProviderCalls { - t.Fatalf("expected %d provider builds, got %d", tt.expectProviderCalls, factory.calls) + expectedProviderBuilds := tt.expectProviderCalls * 2 + if factory.calls != expectedProviderBuilds { + t.Fatalf("expected %d provider builds, got %d", expectedProviderBuilds, factory.calls) } if registeredTool != nil && registeredTool.callCount != tt.expectToolCalls { t.Fatalf("expected %d tool executes, got %d", tt.expectToolCalls, registeredTool.callCount) @@ -3154,7 +3193,7 @@ func TestServiceConstructorsAndDelegates(t *testing.T) { t.Fatalf("expected loaded session %q, got %q", session.ID, loaded.ID) } - sessionStore := agentsession.NewStore(t.TempDir(), t.TempDir()) + sessionStore := agentsession.NewSQLiteStore(t.TempDir(), t.TempDir()) if sessionStore == nil { t.Fatalf("expected JSON session store") } @@ -3385,7 +3424,7 @@ func collectRuntimeEvents(events <-chan RuntimeEvent) []RuntimeEvent { } } -// isPermissionRequestEvent 判断是否为权限请求类事件(含 1A 主事件与兼容旧名)。 +// isPermissionRequestEvent 判断是否为权限请求类事件。 func isPermissionRequestEvent(typ EventType) bool { return typ == EventPermissionRequested } @@ -3745,9 +3784,9 @@ func TestCallProviderWithRetryReturnsCombinedForwardError(t *testing.T) { service := NewWithFactory(manager, registry, store, &scriptedProviderFactory{provider: scripted}, nil) state := newRunState("run-forward-error", agentsession.Session{ID: "session-forward-error"}) - snapshot := turnSnapshot{ - providerConfig: provider.RuntimeConfig{}, - request: providertypes.GenerateRequest{ + snapshot := TurnBudgetSnapshot{ + ProviderConfig: provider.RuntimeConfig{}, + Request: providertypes.GenerateRequest{ Model: "test-model", SystemPrompt: "prompt", Messages: []providertypes.Message{{Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}}}, @@ -3900,286 +3939,6 @@ func TestServiceRunPersistsAndRestoresTokenUsage(t *testing.T) { } } -func TestServiceRunAutoCompactsAndResetsSessionTokens(t *testing.T) { - t.Parallel() - - manager := newRuntimeConfigManager(t) - if err := manager.Update(context.Background(), func(cfg *config.Config) error { - cfg.Context.AutoCompact.Enabled = true - cfg.Context.AutoCompact.InputTokenThreshold = 100 - return nil - }); err != nil { - t.Fatalf("update config: %v", err) - } - - store := newMemoryStore() - session := agentsession.New("auto-compact") - session.ID = "session-auto-compact" - session.TokenInputTotal = 100 - session.TokenOutputTotal = 40 - session.Messages = []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("older request")}}, - {Role: providertypes.RoleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("older answer")}}, - } - store.sessions[session.ID] = cloneSession(session) - - registry := tools.NewRegistry() - tool := &stubTool{name: "filesystem_read_file", content: "file content"} - registry.Register(tool) - - builder := &stubContextBuilder{ - buildFn: func(ctx context.Context, input agentcontext.BuildInput) (agentcontext.BuildResult, error) { - return agentcontext.BuildResult{ - SystemPrompt: "auto compact prompt", - Messages: append([]providertypes.Message(nil), input.Messages...), - AutoCompactSuggested: input.Metadata.SessionInputTokens >= input.Compact.AutoCompactThreshold, - }, nil - }, - } - scripted := &scriptedProvider{ - responses: []scriptedResponse{ - { - Message: providertypes.Message{ - ToolCalls: []providertypes.ToolCall{ - {ID: "call-1", Name: "filesystem_read_file", Arguments: `{"path":"main.go"}`}, - }, - }, - FinishReason: "tool_calls", - }, - { - Message: providertypes.Message{Parts: []providertypes.ContentPart{providertypes.NewTextPart("done")}}, - FinishReason: "stop", - }, - }, - } - - service := NewWithFactory(manager, registry, store, &scriptedProviderFactory{provider: scripted}, builder) - compactRunner := &stubCompactRunner{ - result: contextcompact.Result{ - Messages: []providertypes.Message{ - {Role: providertypes.RoleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("[compact_summary]\ndone:\n- archived\n\nin_progress:\n- continue")}}, - {Role: providertypes.RoleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest answer")}}, - }, - Applied: true, - Metrics: contextcompact.Metrics{ - BeforeChars: 60, - AfterChars: 24, - SavedRatio: 0.6, - TriggerMode: string(contextcompact.ModeAuto), - }, - TranscriptID: "transcript_auto", - TranscriptPath: "/tmp/auto.jsonl", - }, - } - service.compactRunner = compactRunner - - if err := service.Run(context.Background(), UserInput{ - SessionID: session.ID, - RunID: "run-auto-compact", - Parts: []providertypes.ContentPart{providertypes.NewTextPart("continue")}, - }); err != nil { - t.Fatalf("Run() error = %v", err) - } - - if len(compactRunner.calls) != 1 { - t.Fatalf("expected auto compact to run once, got %d", len(compactRunner.calls)) - } - if compactRunner.calls[0].Mode != contextcompact.ModeAuto { - t.Fatalf("expected compact mode %q, got %q", contextcompact.ModeAuto, compactRunner.calls[0].Mode) - } - if len(builder.builds) != 3 { - t.Fatalf("expected 3 build attempts, got %d", len(builder.builds)) - } - if builder.builds[0].Metadata.SessionInputTokens != 100 { - t.Fatalf("expected first build to see pre-compact tokens, got %d", builder.builds[0].Metadata.SessionInputTokens) - } - if builder.builds[0].Metadata.SessionOutputTokens != 40 { - t.Fatalf("expected first build to see pre-compact output tokens, got %d", builder.builds[0].Metadata.SessionOutputTokens) - } - if builder.builds[0].Compact.AutoCompactThreshold != 100 { - t.Fatalf("expected auto compact threshold 100, got %d", builder.builds[0].Compact.AutoCompactThreshold) - } - if builder.builds[1].Metadata.SessionInputTokens != 0 { - t.Fatalf("expected second build to see reset input tokens, got %d", builder.builds[1].Metadata.SessionInputTokens) - } - if builder.builds[1].Metadata.SessionOutputTokens != 0 { - t.Fatalf("expected second build to see reset output tokens, got %d", builder.builds[1].Metadata.SessionOutputTokens) - } - if len(scripted.requests) != 2 { - t.Fatalf("expected 2 provider requests after tool follow-up, got %d", len(scripted.requests)) - } - if len(scripted.requests[0].Messages) != 2 { - t.Fatalf("expected rebuilt compacted context to be sent, got %+v", scripted.requests[0].Messages) - } - if renderPartsForTest(scripted.requests[0].Messages[0].Parts) != "[compact_summary]\ndone:\n- archived\n\nin_progress:\n- continue" { - t.Fatalf("expected first provider request to use compact summary, got %+v", scripted.requests[0].Messages) - } - if renderPartsForTest(scripted.requests[0].Messages[1].Parts) != "latest answer" { - t.Fatalf("expected first provider request to use compacted latest answer, got %+v", scripted.requests[0].Messages) - } - - saved, err := store.Load(context.Background(), session.ID) - if err != nil { - t.Fatalf("load compacted session: %v", err) - } - if saved.TokenInputTotal != 0 { - t.Fatalf("expected persisted input tokens to reset, got %d", saved.TokenInputTotal) - } - if saved.TokenOutputTotal != 0 { - t.Fatalf("expected persisted output tokens to reset, got %d", saved.TokenOutputTotal) - } - if tool.callCount != 1 { - t.Fatalf("expected tool to execute once, got %d", tool.callCount) - } - - events := collectRuntimeEvents(service.Events()) - assertEventSequence(t, events, []EventType{ - EventUserMessage, - EventCompactStart, - EventCompactApplied, - EventToolStart, - EventToolResult, - EventAgentDone, - }) - assertNoEventType(t, events, EventCompactError) - - foundAutoDone := false - for _, event := range events { - if event.Type != EventCompactApplied { - continue - } - payload, ok := event.Payload.(CompactResult) - if !ok { - t.Fatalf("expected CompactResult, got %T", event.Payload) - } - if payload.TriggerMode != string(contextcompact.ModeAuto) { - t.Fatalf("expected trigger mode %q, got %q", contextcompact.ModeAuto, payload.TriggerMode) - } - foundAutoDone = true - } - if !foundAutoDone { - t.Fatalf("expected auto compact_done event in %+v", events) - } -} - -func TestServiceRunAutoCompactNoopDoesNotDisableReactiveRetry(t *testing.T) { - t.Parallel() - - manager := newRuntimeConfigManager(t) - if err := manager.Update(context.Background(), func(cfg *config.Config) error { - cfg.Context.AutoCompact.Enabled = true - cfg.Context.AutoCompact.InputTokenThreshold = 100 - return nil - }); err != nil { - t.Fatalf("update config: %v", err) - } - - store := newMemoryStore() - session := agentsession.New("auto-noop-reactive") - session.ID = "session-auto-noop-reactive" - session.TokenInputTotal = 100 - session.Messages = []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("older request")}}, - {Role: providertypes.RoleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("older answer")}}, - } - store.sessions[session.ID] = cloneSession(session) - - registry := tools.NewRegistry() - registry.Register(&stubTool{name: "filesystem_read_file", content: "default"}) - - builder := &stubContextBuilder{ - buildFn: func(ctx context.Context, input agentcontext.BuildInput) (agentcontext.BuildResult, error) { - return agentcontext.BuildResult{ - SystemPrompt: "auto compact prompt", - Messages: append([]providertypes.Message(nil), input.Messages...), - AutoCompactSuggested: input.Metadata.SessionInputTokens >= input.Compact.AutoCompactThreshold, - }, nil - }, - } - - callCount := 0 - scripted := &scriptedProvider{ - chatFn: func(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error { - callCount++ - if callCount == 1 { - return &provider.ProviderError{ - StatusCode: 400, - Code: provider.ErrorCodeContextTooLong, - Message: "maximum context length exceeded", - } - } - select { - case events <- providertypes.NewTextDeltaStreamEvent("recovered after reactive compact"): - case <-ctx.Done(): - return ctx.Err() - } - select { - case events <- providertypes.NewMessageDoneStreamEvent("stop", nil): - case <-ctx.Done(): - return ctx.Err() - } - return nil - }, - } - - service := NewWithFactory(manager, registry, store, &scriptedProviderFactory{provider: scripted}, builder) - compactRunner := &stubCompactRunner{ - runFn: func(ctx context.Context, input contextcompact.Input) (contextcompact.Result, error) { - switch input.Mode { - case contextcompact.ModeAuto: - return contextcompact.Result{ - Messages: append([]providertypes.Message(nil), input.Messages...), - Applied: false, - Metrics: contextcompact.Metrics{ - BeforeChars: 40, - AfterChars: 40, - TriggerMode: string(contextcompact.ModeAuto), - }, - }, nil - case contextcompact.ModeReactive: - return contextcompact.Result{ - Messages: []providertypes.Message{ - {Role: providertypes.RoleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("[compact_summary]\ndone:\n- archived\n\nin_progress:\n- continue")}}, - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("continue")}}, - }, - Applied: true, - Metrics: contextcompact.Metrics{ - BeforeChars: 80, - AfterChars: 30, - SavedRatio: 0.625, - TriggerMode: string(contextcompact.ModeReactive), - }, - }, nil - default: - t.Fatalf("unexpected compact mode %q", input.Mode) - return contextcompact.Result{}, nil - } - }, - } - service.compactRunner = compactRunner - - if err := service.Run(context.Background(), UserInput{ - SessionID: session.ID, - RunID: "run-auto-noop-reactive", - Parts: []providertypes.ContentPart{providertypes.NewTextPart("continue")}, - }); err != nil { - t.Fatalf("Run() error = %v", err) - } - - if len(compactRunner.calls) != 2 { - t.Fatalf("expected auto noop then reactive compact, got %d calls", len(compactRunner.calls)) - } - if compactRunner.calls[0].Mode != contextcompact.ModeAuto { - t.Fatalf("expected first compact mode %q, got %q", contextcompact.ModeAuto, compactRunner.calls[0].Mode) - } - if compactRunner.calls[1].Mode != contextcompact.ModeReactive { - t.Fatalf("expected second compact mode %q, got %q", contextcompact.ModeReactive, compactRunner.calls[1].Mode) - } - if scripted.callCount != 2 { - t.Fatalf("expected provider to be called twice, got %d", scripted.callCount) - } -} - func TestServiceRunReactivelyCompactsOnContextTooLong(t *testing.T) { t.Parallel() @@ -4284,8 +4043,11 @@ func TestServiceRunReactivelyCompactsOnContextTooLong(t *testing.T) { if err != nil { t.Fatalf("load compacted session: %v", err) } - if saved.TokenInputTotal != 0 || saved.TokenOutputTotal != 0 { - t.Fatalf("expected persisted token totals to reset, got input=%d output=%d", saved.TokenInputTotal, saved.TokenOutputTotal) + if saved.TokenInputTotal == 0 || saved.TokenOutputTotal != 0 { + t.Fatalf("expected post-compact run to persist estimated input and zero output, got input=%d output=%d", saved.TokenInputTotal, saved.TokenOutputTotal) + } + if !saved.HasUnknownUsage { + t.Fatalf("expected missing post-compact usage to mark HasUnknownUsage") } if len(saved.Messages) != 3 { t.Fatalf("expected compacted transcript plus final assistant reply, got %+v", saved.Messages) @@ -4402,6 +4164,95 @@ func TestServiceRunReactiveCompactRetriesWithinSameRun(t *testing.T) { assertNoEventType(t, events, EventError) } +func TestServiceRunReactiveCompactLimitAppliesAcrossTurns(t *testing.T) { + t.Parallel() + + manager := newRuntimeConfigManager(t) + if err := manager.Update(context.Background(), func(cfg *config.Config) error { + cfg.Context.Budget.MaxReactiveCompacts = 1 + return nil + }); err != nil { + t.Fatalf("update config: %v", err) + } + + store := newMemoryStore() + session := agentsession.New("reactive-run-limit") + session.ID = "session-reactive-run-limit" + store.sessions[session.ID] = cloneSession(session) + + registry := tools.NewRegistry() + registry.Register(&stubTool{name: "filesystem_read_file", content: "tool output"}) + + callCount := 0 + contextTooLongErr := &provider.ProviderError{ + StatusCode: 400, + Code: provider.ErrorCodeContextTooLong, + Message: "maximum context length exceeded", + } + scripted := &scriptedProvider{ + chatFn: func(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error { + callCount++ + switch callCount { + case 1: + return contextTooLongErr + case 2: + toolCall := providertypes.ToolCall{ + ID: "call-read", + Name: "filesystem_read_file", + Arguments: `{}`, + } + select { + case events <- providertypes.NewToolCallStartStreamEvent(0, toolCall.ID, toolCall.Name): + case <-ctx.Done(): + return ctx.Err() + } + select { + case events <- providertypes.NewToolCallDeltaStreamEvent(0, toolCall.ID, toolCall.Arguments): + case <-ctx.Done(): + return ctx.Err() + } + select { + case events <- providertypes.NewMessageDoneStreamEvent("tool_calls", nil): + case <-ctx.Done(): + return ctx.Err() + } + return nil + default: + return contextTooLongErr + } + }, + } + + service := NewWithFactory(manager, registry, store, &scriptedProviderFactory{provider: scripted}, &stubContextBuilder{}) + service.compactRunner = &stubCompactRunner{ + result: contextcompact.Result{ + Applied: true, + Messages: []providertypes.Message{ + {Role: providertypes.RoleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("[compact_summary]\ndone:\n- archived\n\nin_progress:\n- continue")}}, + {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("continue")}}, + }, + Metrics: contextcompact.Metrics{TriggerMode: string(contextcompact.ModeReactive)}, + }, + } + + err := service.Run(context.Background(), UserInput{ + SessionID: session.ID, + RunID: "run-reactive-run-limit", + Parts: []providertypes.ContentPart{providertypes.NewTextPart("continue")}, + }) + if !provider.IsContextTooLong(err) { + t.Fatalf("expected second turn context-too-long error after run-level limit, got %v", err) + } + + compactRunner := service.compactRunner.(*stubCompactRunner) + if len(compactRunner.calls) != 1 { + t.Fatalf("expected reactive compact limit to allow one compact for the whole run, got %d", len(compactRunner.calls)) + } + if callCount != 3 { + t.Fatalf("expected provider to stop on second context-too-long turn, got %d calls", callCount) + } +} + func TestServiceRunReactiveCompactDegradesUpToMaxAttempts(t *testing.T) { t.Parallel() @@ -4550,302 +4401,255 @@ func TestRestoreSessionTokensNewSession(t *testing.T) { } } -func TestAutoCompactThresholdEnabled(t *testing.T) { +func TestResolvePromptBudgetUsesExplicitConfig(t *testing.T) { t.Parallel() service := &Service{} cfg := config.Config{ Context: config.ContextConfig{ - AutoCompact: config.AutoCompactConfig{ - Enabled: true, - InputTokenThreshold: 50000, + Budget: config.BudgetConfig{ + PromptBudget: 50000, + ReserveTokens: 13000, + FallbackPromptBudget: 88000, + MaxReactiveCompacts: 3, }, }, } - threshold := service.autoCompactThreshold(context.Background(), cfg) - if threshold != 50000 { - t.Fatalf("expected threshold == 50000, got %d", threshold) + promptBudget, source := service.resolvePromptBudget(context.Background(), cfg) + if promptBudget != 50000 || source != "explicit" { + t.Fatalf("expected prompt budget 50000/explicit, got %d/%s", promptBudget, source) } } -func TestAutoCompactThresholdDisabled(t *testing.T) { +func TestResolvePromptBudgetUsesResolver(t *testing.T) { t.Parallel() service := &Service{} - cfg := config.Config{ - Context: config.ContextConfig{ - AutoCompact: config.AutoCompactConfig{ - Enabled: false, - InputTokenThreshold: 50000, - }, + service.SetBudgetResolver(budgetResolverFunc( + func(ctx context.Context, cfg config.Config) (int, string, error) { + return 88000, "derived", nil }, - } - - threshold := service.autoCompactThreshold(context.Background(), cfg) - if threshold != 0 { - t.Fatalf("expected threshold == 0, got %d", threshold) - } -} - -func TestAutoCompactThresholdZeroValue(t *testing.T) { - t.Parallel() + )) - service := &Service{} cfg := config.Config{ Context: config.ContextConfig{ - AutoCompact: config.AutoCompactConfig{ - Enabled: true, - InputTokenThreshold: 0, + Budget: config.BudgetConfig{ + PromptBudget: 0, + ReserveTokens: 13000, + FallbackPromptBudget: 76000, + MaxReactiveCompacts: 3, }, }, } - threshold := service.autoCompactThreshold(context.Background(), cfg) - if threshold != 0 { - t.Fatalf("expected threshold == 0, got %d", threshold) + promptBudget, source := service.resolvePromptBudget(context.Background(), cfg) + if promptBudget != 88000 || source != "derived" { + t.Fatalf("expected prompt budget 88000/derived, got %d/%s", promptBudget, source) } } -func TestAutoCompactThresholdUsesResolver(t *testing.T) { +func TestResolvePromptBudgetFallsBackWhenResolverErrors(t *testing.T) { t.Parallel() service := &Service{} - service.SetAutoCompactThresholdResolver(autoCompactThresholdResolverFunc( - func(ctx context.Context, cfg config.Config) (int, error) { - return 88000, nil + service.SetBudgetResolver(budgetResolverFunc( + func(ctx context.Context, cfg config.Config) (int, string, error) { + return 0, "", errors.New("resolver failed") }, )) cfg := config.Config{ Context: config.ContextConfig{ - AutoCompact: config.AutoCompactConfig{ - Enabled: true, - InputTokenThreshold: 0, + Budget: config.BudgetConfig{ + PromptBudget: 0, + ReserveTokens: 13000, + FallbackPromptBudget: 88000, + MaxReactiveCompacts: 3, }, }, } - threshold := service.autoCompactThreshold(context.Background(), cfg) - if threshold != 88000 { - t.Fatalf("expected resolver threshold == 88000, got %d", threshold) + promptBudget, source := service.resolvePromptBudget(context.Background(), cfg) + if promptBudget != 88000 || source != "fallback" { + t.Fatalf("expected prompt budget 88000/fallback, got %d/%s", promptBudget, source) } } -func TestAutoCompactThresholdFallsBackWhenResolverErrors(t *testing.T) { +func TestServiceRunStopsWhenBudgetStillExceededAfterProactiveCompact(t *testing.T) { t.Parallel() - service := &Service{} - service.SetAutoCompactThresholdResolver(autoCompactThresholdResolverFunc( - func(ctx context.Context, cfg config.Config) (int, error) { - return 0, errors.New("resolver failed") - }, - )) - - cfg := config.Config{ - Context: config.ContextConfig{ - AutoCompact: config.AutoCompactConfig{ - Enabled: true, - InputTokenThreshold: 0, - FallbackInputTokenThreshold: 88000, - }, - }, - } - - threshold := service.autoCompactThreshold(context.Background(), cfg) - if threshold != 88000 { - t.Fatalf("expected fallback threshold == 88000, got %d", threshold) + manager := newRuntimeConfigManager(t) + if err := manager.Update(context.Background(), func(cfg *config.Config) error { + cfg.Context.Budget.PromptBudget = 10 + cfg.Context.Budget.FallbackPromptBudget = 10 + return nil + }); err != nil { + t.Fatalf("update config: %v", err) } -} -func TestAutoCompactThresholdFallsBackWhenResolverReturnsZeroWithoutError(t *testing.T) { - t.Parallel() - - service := &Service{} - service.SetAutoCompactThresholdResolver(autoCompactThresholdResolverFunc( - func(ctx context.Context, cfg config.Config) (int, error) { - return 0, nil + store := newMemoryStore() + registry := tools.NewRegistry() + scripted := &scriptedProvider{ + estimateFn: func(ctx context.Context, req providertypes.GenerateRequest) (providertypes.BudgetEstimate, error) { + _ = ctx + _ = req + return providertypes.BudgetEstimate{ + EstimatedInputTokens: 99, + EstimateSource: provider.EstimateSourceLocal, + Accurate: false, + }, nil }, - )) - - cfg := config.Config{ - Context: config.ContextConfig{ - AutoCompact: config.AutoCompactConfig{ - Enabled: true, - InputTokenThreshold: 0, - FallbackInputTokenThreshold: 88000, - }, + chatFn: func(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error { + t.Fatalf("Generate should not be called when budget decision stops before send") + return nil }, } - threshold := service.autoCompactThreshold(context.Background(), cfg) - if threshold != 88000 { - t.Fatalf("expected fallback threshold == 88000, got %d", threshold) - } -} - -func TestAutoCompactThresholdFallsBackWhenResolverReturnsNegativeWithoutError(t *testing.T) { - t.Parallel() - - service := &Service{} - service.SetAutoCompactThresholdResolver(autoCompactThresholdResolverFunc( - func(ctx context.Context, cfg config.Config) (int, error) { - return -1, nil - }, - )) - - cfg := config.Config{ - Context: config.ContextConfig{ - AutoCompact: config.AutoCompactConfig{ - Enabled: true, - InputTokenThreshold: 0, - FallbackInputTokenThreshold: 88000, + service := NewWithFactory(manager, registry, store, &scriptedProviderFactory{provider: scripted}, &stubContextBuilder{}) + service.compactRunner = &stubCompactRunner{ + result: contextcompact.Result{ + Applied: true, + Messages: []providertypes.Message{ + { + Role: providertypes.RoleAssistant, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("[compact_summary]\ndone:\n- archived\n\nin_progress:\n- continue")}, + }, + { + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("continue")}, + }, + }, + Metrics: contextcompact.Metrics{ + TriggerMode: string(contextcompact.ModeProactive), }, }, } - threshold := service.autoCompactThreshold(context.Background(), cfg) - if threshold != 88000 { - t.Fatalf("expected fallback threshold == 88000, got %d", threshold) + if err := service.Run(context.Background(), UserInput{ + RunID: "run-budget-stop", + Parts: []providertypes.ContentPart{providertypes.NewTextPart("continue")}, + }); err != nil { + t.Fatalf("Run() error = %v", err) } -} - -func TestAutoCompactThresholdImplicitModeWithoutResolverUsesFallback(t *testing.T) { - t.Parallel() - service := &Service{} - cfg := config.Config{ - Context: config.ContextConfig{ - AutoCompact: config.AutoCompactConfig{ - Enabled: true, - InputTokenThreshold: 0, - FallbackInputTokenThreshold: 88000, - }, - }, + compactRunner := service.compactRunner.(*stubCompactRunner) + if len(compactRunner.calls) != 1 { + t.Fatalf("expected one proactive compact, got %d", len(compactRunner.calls)) } - - threshold := service.autoCompactThreshold(context.Background(), cfg) - if threshold != 88000 { - t.Fatalf("expected implicit mode fallback threshold == 88000, got %d", threshold) + if compactRunner.calls[0].Mode != contextcompact.ModeProactive { + t.Fatalf("expected compact mode %q, got %q", contextcompact.ModeProactive, compactRunner.calls[0].Mode) } -} - -func TestAutoCompactThresholdForStateCachesResolverResultWithinRun(t *testing.T) { - t.Parallel() - - service := &Service{} - resolveCalls := 0 - service.SetAutoCompactThresholdResolver(autoCompactThresholdResolverFunc( - func(ctx context.Context, cfg config.Config) (int, error) { - resolveCalls++ - return 88000, nil - }, - )) - - cfg := config.Config{ - SelectedProvider: "openai", - CurrentModel: "gpt-5", - Context: config.ContextConfig{ - AutoCompact: config.AutoCompactConfig{ - Enabled: true, - InputTokenThreshold: 0, - ReserveTokens: 10000, - FallbackInputTokenThreshold: 76000, - }, - }, + if scripted.callCount != 0 { + t.Fatalf("expected provider Generate to be skipped, got %d calls", scripted.callCount) } - state := newRunState("run-cache-hit", newRuntimeSession("session-cache-hit")) - threshold1 := service.autoCompactThresholdForState(context.Background(), cfg, &state) - threshold2 := service.autoCompactThresholdForState(context.Background(), cfg, &state) + events := collectRuntimeEvents(service.Events()) + var budgetActions []string + var stopPayload StopReasonDecidedPayload + for _, event := range events { + switch event.Type { + case EventBudgetChecked: + payload, ok := event.Payload.(BudgetCheckedPayload) + if !ok { + t.Fatalf("expected BudgetCheckedPayload, got %T", event.Payload) + } + budgetActions = append(budgetActions, payload.Action) + case EventStopReasonDecided: + payload, ok := event.Payload.(StopReasonDecidedPayload) + if !ok { + t.Fatalf("expected StopReasonDecidedPayload, got %T", event.Payload) + } + stopPayload = payload + } + } - if threshold1 != 88000 || threshold2 != 88000 { - t.Fatalf("expected cached resolver threshold == 88000, got %d and %d", threshold1, threshold2) + if len(budgetActions) != 2 || budgetActions[0] != "compact" || budgetActions[1] != "stop" { + t.Fatalf("expected budget actions [compact stop], got %v", budgetActions) } - if resolveCalls != 1 { - t.Fatalf("expected resolver to be called once, got %d", resolveCalls) + if stopPayload.Reason != controlplane.StopReasonBudgetExceeded { + t.Fatalf("expected stop reason %q, got %q", controlplane.StopReasonBudgetExceeded, stopPayload.Reason) } } -func TestAutoCompactThresholdForStateRecomputesWhenCacheKeyChanges(t *testing.T) { +func TestServiceRunReconcilesUnknownOutputUsage(t *testing.T) { t.Parallel() - service := &Service{} - resolveCalls := 0 - service.SetAutoCompactThresholdResolver(autoCompactThresholdResolverFunc( - func(ctx context.Context, cfg config.Config) (int, error) { - resolveCalls++ - if strings.TrimSpace(cfg.CurrentModel) == "gpt-5.1" { - return 99000, nil - } - return 88000, nil + manager := newRuntimeConfigManager(t) + store := newMemoryStore() + registry := tools.NewRegistry() + scripted := &scriptedProvider{ + estimateFn: func(ctx context.Context, req providertypes.GenerateRequest) (providertypes.BudgetEstimate, error) { + _ = ctx + _ = req + return providertypes.BudgetEstimate{ + EstimatedInputTokens: 17, + EstimateSource: provider.EstimateSourceLocal, + Accurate: false, + }, nil }, - )) - - cfg := config.Config{ - SelectedProvider: "openai", - CurrentModel: "gpt-5", - Context: config.ContextConfig{ - AutoCompact: config.AutoCompactConfig{ - Enabled: true, - InputTokenThreshold: 0, - ReserveTokens: 10000, - FallbackInputTokenThreshold: 76000, + responses: []scriptedResponse{ + { + Message: providertypes.Message{ + Role: providertypes.RoleAssistant, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("done")}, + }, + FinishReason: "stop", }, }, } - state := newRunState("run-cache-miss", newRuntimeSession("session-cache-miss")) - threshold1 := service.autoCompactThresholdForState(context.Background(), cfg, &state) - cfg.CurrentModel = "gpt-5.1" - threshold2 := service.autoCompactThresholdForState(context.Background(), cfg, &state) + service := NewWithFactory(manager, registry, store, &scriptedProviderFactory{provider: scripted}, &stubContextBuilder{}) + if err := service.Run(context.Background(), UserInput{ + RunID: "run-unknown-usage", + Parts: []providertypes.ContentPart{providertypes.NewTextPart("continue")}, + }); err != nil { + t.Fatalf("Run() error = %v", err) + } - if threshold1 != 88000 || threshold2 != 99000 { - t.Fatalf("expected thresholds [88000, 99000], got [%d, %d]", threshold1, threshold2) + saved := onlySession(t, store) + if saved.TokenInputTotal != 17 || saved.TokenOutputTotal != 0 { + t.Fatalf("expected estimated input / zero persisted output, got in=%d out=%d", saved.TokenInputTotal, saved.TokenOutputTotal) } - if resolveCalls != 2 { - t.Fatalf("expected resolver to be called twice after key change, got %d", resolveCalls) + if !saved.HasUnknownUsage { + t.Fatalf("expected session HasUnknownUsage to be persisted") } -} - -func TestAutoCompactThresholdForStateDoesNotCacheResolverErrorFallback(t *testing.T) { - t.Parallel() - service := &Service{} - resolveCalls := 0 - service.SetAutoCompactThresholdResolver(autoCompactThresholdResolverFunc( - func(ctx context.Context, cfg config.Config) (int, error) { - resolveCalls++ - if resolveCalls == 1 { - return 0, errors.New("snapshot unavailable") + events := collectRuntimeEvents(service.Events()) + var ledgerPayload LedgerReconciledPayload + var tokenPayload TokenUsagePayload + foundLedger := false + foundUsage := false + for _, event := range events { + switch event.Type { + case EventLedgerReconciled: + payload, ok := event.Payload.(LedgerReconciledPayload) + if !ok { + t.Fatalf("expected LedgerReconciledPayload, got %T", event.Payload) } - return 91000, nil - }, - )) - - cfg := config.Config{ - SelectedProvider: "openai", - CurrentModel: "gpt-5", - Context: config.ContextConfig{ - AutoCompact: config.AutoCompactConfig{ - Enabled: true, - InputTokenThreshold: 0, - ReserveTokens: 10000, - FallbackInputTokenThreshold: 76000, - }, - }, + ledgerPayload = payload + foundLedger = true + case EventTokenUsage: + payload, ok := event.Payload.(TokenUsagePayload) + if !ok { + t.Fatalf("expected TokenUsagePayload, got %T", event.Payload) + } + tokenPayload = payload + foundUsage = true + } } - state := newRunState("run-cache-error", newRuntimeSession("session-cache-error")) - threshold1 := service.autoCompactThresholdForState(context.Background(), cfg, &state) - threshold2 := service.autoCompactThresholdForState(context.Background(), cfg, &state) - threshold3 := service.autoCompactThresholdForState(context.Background(), cfg, &state) - - if threshold1 != 76000 || threshold2 != 91000 || threshold3 != 91000 { - t.Fatalf("expected thresholds [76000, 91000, 91000], got [%d, %d, %d]", threshold1, threshold2, threshold3) + if !foundLedger { + t.Fatalf("expected ledger_reconciled event") + } + if ledgerPayload.InputSource != usageSourceEstimated || ledgerPayload.OutputSource != usageSourceUnknown || !ledgerPayload.HasUnknownUsage { + t.Fatalf("unexpected ledger payload: %+v", ledgerPayload) + } + if !foundUsage { + t.Fatalf("expected token_usage event") } - if resolveCalls != 2 { - t.Fatalf("expected resolver to be called twice, got %d", resolveCalls) + if tokenPayload.InputSource != usageSourceEstimated || tokenPayload.OutputSource != usageSourceUnknown || !tokenPayload.HasUnknownUsage { + t.Fatalf("unexpected token payload: %+v", tokenPayload) } } diff --git a/internal/runtime/session_mutation.go b/internal/runtime/session_mutation.go index 6673bf60..78044495 100644 --- a/internal/runtime/session_mutation.go +++ b/internal/runtime/session_mutation.go @@ -39,15 +39,20 @@ func (s *Service) appendUserMessageAndSave(ctx context.Context, state *runState, func (s *Service) appendAssistantMessageAndSave( ctx context.Context, state *runState, - snapshot turnSnapshot, + snapshot TurnBudgetSnapshot, assistant providertypes.Message, inputTokens int, outputTokens int, ) error { - metadataChanged := state.session.Provider != snapshot.providerConfig.Name || state.session.Model != snapshot.model - state.session.Provider = snapshot.providerConfig.Name - state.session.Model = snapshot.model + metadataChanged := state.session.Provider != snapshot.ProviderConfig.Name || state.session.Model != snapshot.Model + unknownUsageChanged := false + state.session.Provider = snapshot.ProviderConfig.Name + state.session.Model = snapshot.Model + previousUnknownUsage := state.session.HasUnknownUsage state.recordUsage(inputTokens, outputTokens) + if state.session.HasUnknownUsage != previousUnknownUsage { + unknownUsageChanged = true + } if !assistant.IsEmpty() { state.session.Messages = append(state.session.Messages, assistant) @@ -61,10 +66,11 @@ func (s *Service) appendAssistantMessageAndSave( Workdir: state.session.Workdir, TokenInputDelta: inputTokens, TokenOutputDelta: outputTokens, + HasUnknownUsage: state.session.HasUnknownUsage, }) } - if metadataChanged || inputTokens != 0 || outputTokens != 0 { + if metadataChanged || unknownUsageChanged || inputTokens != 0 || outputTokens != 0 { state.touchSession() return s.sessionStore.UpdateSessionState(ctx, sessionStateInputFromSession(state.session)) } @@ -83,12 +89,13 @@ func (s *Service) appendToolMessageAndSave( state.session.Messages = append(state.session.Messages, toolMessage) state.touchSession() input := agentsession.AppendMessagesInput{ - SessionID: state.session.ID, - Messages: []providertypes.Message{toolMessage}, - UpdatedAt: state.session.UpdatedAt, - Provider: state.session.Provider, - Model: state.session.Model, - Workdir: state.session.Workdir, + SessionID: state.session.ID, + Messages: []providertypes.Message{toolMessage}, + UpdatedAt: state.session.UpdatedAt, + Provider: state.session.Provider, + Model: state.session.Model, + Workdir: state.session.Workdir, + HasUnknownUsage: state.session.HasUnknownUsage, } state.mu.Unlock() return s.sessionStore.AppendMessages(ctx, input) @@ -255,52 +262,31 @@ func parseFloatExitCode(value float64) (int, bool) { // createSessionInputFromSession 将运行态 session 转为建库时使用的会话头输入。 func createSessionInputFromSession(session agentsession.Session) agentsession.CreateSessionInput { return agentsession.CreateSessionInput{ - ID: session.ID, - Title: session.Title, - CreatedAt: session.CreatedAt, - UpdatedAt: session.UpdatedAt, - Provider: session.Provider, - Model: session.Model, - Workdir: session.Workdir, - TaskState: session.TaskState.Clone(), - ActivatedSkills: agentsessionCloneSkillActivations(session.ActivatedSkills), - Todos: cloneTodosForPersistence(session.Todos), - TokenInputTotal: session.TokenInputTotal, - TokenOutputTotal: session.TokenOutputTotal, + ID: session.ID, + Title: session.Title, + CreatedAt: session.CreatedAt, + UpdatedAt: session.UpdatedAt, + Head: session.HeadSnapshot(), } } // sessionStateInputFromSession 将运行态 session 映射为只更新会话头的持久化输入。 func sessionStateInputFromSession(session agentsession.Session) agentsession.UpdateSessionStateInput { return agentsession.UpdateSessionStateInput{ - SessionID: session.ID, - Title: session.Title, - UpdatedAt: session.UpdatedAt, - Provider: session.Provider, - Model: session.Model, - Workdir: session.Workdir, - TaskState: session.TaskState.Clone(), - ActivatedSkills: agentsessionCloneSkillActivations(session.ActivatedSkills), - Todos: cloneTodosForPersistence(session.Todos), - TokenInputTotal: session.TokenInputTotal, - TokenOutputTotal: session.TokenOutputTotal, + SessionID: session.ID, + Title: session.Title, + UpdatedAt: session.UpdatedAt, + Head: session.HeadSnapshot(), } } // replaceTranscriptInputFromSession 将完整 session 映射为 transcript 原子替换输入。 func replaceTranscriptInputFromSession(session agentsession.Session) agentsession.ReplaceTranscriptInput { return agentsession.ReplaceTranscriptInput{ - SessionID: session.ID, - Messages: cloneMessagesForPersistence(session.Messages), - UpdatedAt: session.UpdatedAt, - Provider: session.Provider, - Model: session.Model, - Workdir: session.Workdir, - TaskState: session.TaskState.Clone(), - ActivatedSkills: agentsessionCloneSkillActivations(session.ActivatedSkills), - Todos: cloneTodosForPersistence(session.Todos), - TokenInputTotal: session.TokenInputTotal, - TokenOutputTotal: session.TokenOutputTotal, + SessionID: session.ID, + Messages: cloneMessagesForPersistence(session.Messages), + UpdatedAt: session.UpdatedAt, + Head: session.HeadSnapshot(), } } diff --git a/internal/runtime/skills_test.go b/internal/runtime/skills_test.go index d28d8662..d181ce26 100644 --- a/internal/runtime/skills_test.go +++ b/internal/runtime/skills_test.go @@ -9,13 +9,9 @@ import ( "strings" "testing" - "neo-code/internal/config" - agentcontext "neo-code/internal/context" - contextcompact "neo-code/internal/context/compact" providertypes "neo-code/internal/provider/types" agentsession "neo-code/internal/session" "neo-code/internal/skills" - "neo-code/internal/tools" ) type stubSkillsRegistry struct { @@ -215,8 +211,8 @@ func TestPrepareTurnSnapshotPassesResolvedSkillsToContextBuilder(t *testing.T) { }) state := newRunState("run-build-skill", session) - if _, rebuilt, err := service.prepareTurnSnapshot(context.Background(), &state); err != nil { - t.Fatalf("prepareTurnSnapshot() error = %v", err) + if _, rebuilt, err := service.prepareTurnBudgetSnapshot(context.Background(), &state); err != nil { + t.Fatalf("prepareTurnBudgetSnapshot() error = %v", err) } else if rebuilt { t.Fatalf("did not expect snapshot rebuild") } @@ -225,7 +221,7 @@ func TestPrepareTurnSnapshotPassesResolvedSkillsToContextBuilder(t *testing.T) { } } -func TestPrepareTurnSnapshotEmitsSkillMissingAndContinues(t *testing.T) { +func TestPrepareTurnBudgetSnapshotEmitsSkillMissingAndContinues(t *testing.T) { t.Parallel() manager := newRuntimeConfigManager(t) @@ -238,8 +234,8 @@ func TestPrepareTurnSnapshotEmitsSkillMissingAndContinues(t *testing.T) { service := NewWithFactory(manager, &stubToolManager{}, store, &scriptedProviderFactory{provider: &scriptedProvider{}}, builder) state := newRunState("run-missing-skill", session) - if _, rebuilt, err := service.prepareTurnSnapshot(context.Background(), &state); err != nil { - t.Fatalf("prepareTurnSnapshot() error = %v", err) + if _, rebuilt, err := service.prepareTurnBudgetSnapshot(context.Background(), &state); err != nil { + t.Fatalf("prepareTurnBudgetSnapshot() error = %v", err) } else if rebuilt { t.Fatalf("did not expect snapshot rebuild") } @@ -253,7 +249,7 @@ func TestPrepareTurnSnapshotEmitsSkillMissingAndContinues(t *testing.T) { } } -func TestPrepareTurnSnapshotDeduplicatesSkillMissingPerRun(t *testing.T) { +func TestPrepareTurnBudgetSnapshotDeduplicatesSkillMissingPerRun(t *testing.T) { t.Parallel() manager := newRuntimeConfigManager(t) @@ -266,13 +262,13 @@ func TestPrepareTurnSnapshotDeduplicatesSkillMissingPerRun(t *testing.T) { service := NewWithFactory(manager, &stubToolManager{}, store, &scriptedProviderFactory{provider: &scriptedProvider{}}, builder) state := newRunState("run-missing-skill-dedupe", session) - if _, rebuilt, err := service.prepareTurnSnapshot(context.Background(), &state); err != nil { - t.Fatalf("first prepareTurnSnapshot() error = %v", err) + if _, rebuilt, err := service.prepareTurnBudgetSnapshot(context.Background(), &state); err != nil { + t.Fatalf("first prepareTurnBudgetSnapshot() error = %v", err) } else if rebuilt { t.Fatalf("did not expect first snapshot rebuild") } - if _, rebuilt, err := service.prepareTurnSnapshot(context.Background(), &state); err != nil { - t.Fatalf("second prepareTurnSnapshot() error = %v", err) + if _, rebuilt, err := service.prepareTurnBudgetSnapshot(context.Background(), &state); err != nil { + t.Fatalf("second prepareTurnBudgetSnapshot() error = %v", err) } else if rebuilt { t.Fatalf("did not expect second snapshot rebuild") } @@ -287,7 +283,7 @@ func TestPrepareTurnSnapshotDeduplicatesSkillMissingPerRun(t *testing.T) { } } -func TestPrepareTurnSnapshotPropagatesRegistryFailure(t *testing.T) { +func TestPrepareTurnBudgetSnapshotPropagatesRegistryFailure(t *testing.T) { t.Parallel() manager := newRuntimeConfigManager(t) @@ -301,7 +297,7 @@ func TestPrepareTurnSnapshotPropagatesRegistryFailure(t *testing.T) { service.SetSkillsRegistry(&stubSkillsRegistry{getErr: os.ErrPermission}) state := newRunState("run-skill-registry-failure", session) - if _, _, err := service.prepareTurnSnapshot(context.Background(), &state); !errors.Is(err, os.ErrPermission) { + if _, _, err := service.prepareTurnBudgetSnapshot(context.Background(), &state); !errors.Is(err, os.ErrPermission) { t.Fatalf("expected registry failure to propagate, got %v", err) } if len(collectRuntimeEvents(service.Events())) != 0 { @@ -568,18 +564,18 @@ func TestPrepareTurnSnapshotPrioritizesToolsByActiveSkillHints(t *testing.T) { }) state := newRunState("run-skill-tool-priority", session) - snapshot, rebuilt, err := service.prepareTurnSnapshot(context.Background(), &state) + snapshot, rebuilt, err := service.prepareTurnBudgetSnapshot(context.Background(), &state) if err != nil { - t.Fatalf("prepareTurnSnapshot() error = %v", err) + t.Fatalf("prepareTurnBudgetSnapshot() error = %v", err) } if rebuilt { t.Fatalf("did not expect snapshot rebuild") } - if len(snapshot.request.Tools) != 2 { - t.Fatalf("expected 2 tools, got %d", len(snapshot.request.Tools)) + if len(snapshot.Request.Tools) != 2 { + t.Fatalf("expected 2 tools, got %d", len(snapshot.Request.Tools)) } - if snapshot.request.Tools[0].Name != "bash" { - t.Fatalf("expected hinted tool first, got %q", snapshot.request.Tools[0].Name) + if snapshot.Request.Tools[0].Name != "bash" { + t.Fatalf("expected hinted tool first, got %q", snapshot.Request.Tools[0].Name) } } @@ -752,74 +748,3 @@ func TestSkillHelperFunctionsBranches(t *testing.T) { t.Fatalf("expected nil for empty active skills") } } - -func TestServiceRunReinjectsSkillsAfterAutoCompact(t *testing.T) { - t.Parallel() - - manager := newRuntimeConfigManager(t) - if err := manager.Update(context.Background(), func(cfg *config.Config) error { - cfg.Context.AutoCompact.Enabled = true - cfg.Context.AutoCompact.InputTokenThreshold = 1 - return nil - }); err != nil { - t.Fatalf("update config: %v", err) - } - - store := newMemoryStore() - session := newRuntimeSession("session-auto-compact-skills") - session.ActivateSkill("go-review") - session.TokenInputTotal = 3 - store.sessions[session.ID] = cloneSession(session) - - builder := &stubContextBuilder{ - buildFn: func(ctx context.Context, input agentcontext.BuildInput) (agentcontext.BuildResult, error) { - return agentcontext.BuildResult{ - SystemPrompt: "prompt", - Messages: append([]providertypes.Message(nil), input.Messages...), - AutoCompactSuggested: input.Metadata.SessionInputTokens >= 1, - }, nil - }, - } - compactRunner := &stubCompactRunner{ - runFn: func(ctx context.Context, input contextcompact.Input) (contextcompact.Result, error) { - return contextcompact.Result{ - Applied: true, - Messages: append([]providertypes.Message(nil), input.Messages...), - TaskState: input.TaskState.Clone(), - Metrics: contextcompact.Metrics{ - TriggerMode: string(contextcompact.ModeAuto), - }, - }, nil - }, - } - scripted := &scriptedProvider{ - streams: [][]providertypes.StreamEvent{ - {providertypes.NewTextDeltaStreamEvent("done")}, - }, - } - registry := tools.NewRegistry() - registry.Register(&stubTool{name: "filesystem_read_file", content: "default"}) - - service := NewWithFactory(manager, registry, store, &scriptedProviderFactory{provider: scripted}, builder) - service.compactRunner = compactRunner - service.SetSkillsRegistry(&stubSkillsRegistry{ - skills: map[string]skills.Skill{ - "go-review": { - Descriptor: skills.Descriptor{ID: "go-review", Name: "Go Review"}, - Content: skills.Content{Instruction: "review code"}, - }, - }, - }) - - if err := service.Run(context.Background(), UserInput{SessionID: session.ID, RunID: "run-auto-compact-skills", Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}}); err != nil { - t.Fatalf("Run() error = %v", err) - } - if len(builder.builds) < 2 { - t.Fatalf("expected context builder to run before and after compact, got %d", len(builder.builds)) - } - for idx, build := range builder.builds[:2] { - if len(build.ActiveSkills) != 1 || build.ActiveSkills[0].Descriptor.ID != "go-review" { - t.Fatalf("expected active skill on build %d, got %+v", idx, build.ActiveSkills) - } - } -} diff --git a/internal/runtime/state.go b/internal/runtime/state.go index 5cc0e7ca..59259f70 100644 --- a/internal/runtime/state.go +++ b/internal/runtime/state.go @@ -4,35 +4,31 @@ import ( "sync" "time" - "neo-code/internal/config" - "neo-code/internal/provider" - providertypes "neo-code/internal/provider/types" "neo-code/internal/runtime/controlplane" "neo-code/internal/security" agentsession "neo-code/internal/session" ) -// maxReactiveCompactAttempts 限制 reactive compact 最大尝试次数,超出后放弃降级并返回错误。 -const maxReactiveCompactAttempts = 3 - // runState 汇总单次 Run 生命周期内会变化的会话与计量状态。 type runState struct { mu sync.Mutex runID string session agentsession.Session - compactApplied bool + compactCount int reactiveCompactAttempts int - autoCompactCache autoCompactThresholdCache rememberedThisRun bool taskID string agentID string capabilityToken *security.CapabilityToken + nextAttemptSeq int turn int baseLifecycle controlplane.RunState lifecycle controlplane.RunState waitingPermissionCount int compactingCount int stopEmitted bool + budgetExceeded bool + hasUnknownUsage bool completion controlplane.CompletionState progress controlplane.ProgressState reportedMissingSkills map[string]struct{} @@ -43,6 +39,7 @@ func newRunState(runID string, session agentsession.Session) runState { return runState{ runID: runID, session: session, + nextAttemptSeq: 1, reportedMissingSkills: make(map[string]struct{}), } } @@ -63,6 +60,8 @@ func (s *runState) resetTokenTotals() { } s.session.TokenInputTotal = 0 s.session.TokenOutputTotal = 0 + s.session.HasUnknownUsage = false + s.hasUnknownUsage = false } // touchSession 更新会话修改时间。 @@ -90,41 +89,3 @@ func (s *runState) markSkillMissingReported(skillID string) bool { s.reportedMissingSkills[normalized] = struct{}{} return true } - -// turnSnapshot 冻结单轮推理所需的配置、上下文与 provider 请求。 -// noProgressStreakLimit 由 prepareTurnSnapshot 一次性解析并存储,确保同一轮的 -// 提示词纠偏阈值来自同一配置快照,避免并发 reload 导致注入行为不一致。 -type turnSnapshot struct { - config config.Config - providerConfig provider.RuntimeConfig - model string - workdir string - toolTimeout time.Duration - noProgressStreakLimit int - repeatCycleStreakLimit int - request providertypes.GenerateRequest -} - -// providerTurnResult 表示单轮 provider 调用成功后的结构化结果。 -type providerTurnResult struct { - assistant providertypes.Message - inputTokens int - outputTokens int -} - -// autoCompactThresholdCache 保存当前 run 已解析过的自动压缩阈值,避免热路径重复解析。 -type autoCompactThresholdCache struct { - key autoCompactThresholdCacheKey - threshold int - valid bool -} - -// autoCompactThresholdCacheKey 描述自动压缩阈值解析输入的关键维度。 -type autoCompactThresholdCacheKey struct { - provider string - model string - autoCompactEnabled bool - autoCompactInputThreshold int - autoCompactReserveTokens int - autoCompactFallback int -} diff --git a/internal/runtime/todo_mutator_test.go b/internal/runtime/todo_mutator_test.go index 9377af46..a2fe949e 100644 --- a/internal/runtime/todo_mutator_test.go +++ b/internal/runtime/todo_mutator_test.go @@ -22,7 +22,7 @@ func (s *mutatorStore) CreateSession(ctx context.Context, input agentsession.Cre if s.err != nil { return agentsession.Session{}, s.err } - session := agentsession.NewWithWorkdir(input.Title, input.Workdir) + session := agentsession.NewWithWorkdir(input.Title, input.Head.Workdir) if input.ID != "" { session.ID = input.ID } @@ -97,14 +97,15 @@ func (s *mutatorStore) UpdateSessionState(ctx context.Context, input agentsessio s.last.ID = input.SessionID s.last.Title = input.Title s.last.UpdatedAt = input.UpdatedAt - s.last.Provider = input.Provider - s.last.Model = input.Model - s.last.Workdir = input.Workdir - s.last.TaskState = input.TaskState.Clone() - s.last.ActivatedSkills = agentsessionCloneSkillActivations(input.ActivatedSkills) - s.last.Todos = cloneTodosForPersistence(input.Todos) - s.last.TokenInputTotal = input.TokenInputTotal - s.last.TokenOutputTotal = input.TokenOutputTotal + head := input.Head + s.last.Provider = head.Provider + s.last.Model = head.Model + s.last.Workdir = head.Workdir + s.last.TaskState = head.TaskState.Clone() + s.last.ActivatedSkills = agentsessionCloneSkillActivations(head.ActivatedSkills) + s.last.Todos = cloneTodosForPersistence(head.Todos) + s.last.TokenInputTotal = head.TokenInputTotal + s.last.TokenOutputTotal = head.TokenOutputTotal return nil } @@ -121,14 +122,15 @@ func (s *mutatorStore) ReplaceTranscript(ctx context.Context, input agentsession s.last.ID = input.SessionID s.last.Messages = cloneMessagesForPersistence(input.Messages) s.last.UpdatedAt = input.UpdatedAt - s.last.Provider = input.Provider - s.last.Model = input.Model - s.last.Workdir = input.Workdir - s.last.TaskState = input.TaskState.Clone() - s.last.ActivatedSkills = agentsessionCloneSkillActivations(input.ActivatedSkills) - s.last.Todos = cloneTodosForPersistence(input.Todos) - s.last.TokenInputTotal = input.TokenInputTotal - s.last.TokenOutputTotal = input.TokenOutputTotal + head := input.Head + s.last.Provider = head.Provider + s.last.Model = head.Model + s.last.Workdir = head.Workdir + s.last.TaskState = head.TaskState.Clone() + s.last.ActivatedSkills = agentsessionCloneSkillActivations(head.ActivatedSkills) + s.last.Todos = cloneTodosForPersistence(head.Todos) + s.last.TokenInputTotal = head.TokenInputTotal + s.last.TokenOutputTotal = head.TokenOutputTotal return nil } diff --git a/internal/runtime/toolexec.go b/internal/runtime/toolexec.go index 686e3aa4..757d4af7 100644 --- a/internal/runtime/toolexec.go +++ b/internal/runtime/toolexec.go @@ -19,7 +19,7 @@ type indexedToolCall struct { func (s *Service) executeAssistantToolCalls( ctx context.Context, state *runState, - snapshot turnSnapshot, + snapshot TurnBudgetSnapshot, assistant providertypes.Message, ) (toolExecutionSummary, error) { if len(assistant.ToolCalls) == 0 { @@ -98,7 +98,7 @@ func (s *Service) executeAssistantToolCalls( func (s *Service) executeOneToolCall( ctx context.Context, state *runState, - snapshot turnSnapshot, + snapshot TurnBudgetSnapshot, call providertypes.ToolCall, toolLock *sync.Mutex, checkContext func() bool, @@ -120,8 +120,8 @@ func (s *Service) executeOneToolCall( Capability: state.capabilityToken, State: state, Call: call, - Workdir: snapshot.workdir, - ToolTimeout: snapshot.toolTimeout, + Workdir: snapshot.Workdir, + ToolTimeout: snapshot.ToolTimeout, }) if errors.Is(execErr, context.Canceled) { diff --git a/internal/runtime/turn_control_test.go b/internal/runtime/turn_control_test.go index 93a1d7cf..c3347231 100644 --- a/internal/runtime/turn_control_test.go +++ b/internal/runtime/turn_control_test.go @@ -109,7 +109,7 @@ func TestTransitionRunPhaseInvalidTransitionReturnsError(t *testing.T) { state := newRunState("run-invalid-phase", newRuntimeSession("session-invalid-phase")) state.lifecycle = controlplane.RunStatePlan - err := service.transitionRunState(context.Background(), &state, controlplane.RunStateVerify) + err := service.setBaseRunState(context.Background(), &state, controlplane.RunStateVerify) if err == nil { t.Fatalf("expected invalid transition to return error") } diff --git a/internal/session/asset_store_test.go b/internal/session/asset_store_test.go index 27ed93bc..b964e65b 100644 --- a/internal/session/asset_store_test.go +++ b/internal/session/asset_store_test.go @@ -106,7 +106,7 @@ func TestSQLiteStoreOpenReturnsFileErrorWhenPayloadMissing(t *testing.T) { if err != nil { t.Fatalf("MkdirTemp() workspaceRoot error = %v", err) } - store := NewStore(baseDir, workspaceRoot) + store := NewSQLiteStore(baseDir, workspaceRoot) t.Cleanup(func() { _ = store.Close() _ = os.RemoveAll(baseDir) diff --git a/internal/session/id.go b/internal/session/id.go index dc2e84e9..0108a6b7 100644 --- a/internal/session/id.go +++ b/internal/session/id.go @@ -11,8 +11,3 @@ func NewID(prefix string) string { _, _ = rand.Read(buf) return prefix + "_" + hex.EncodeToString(buf) } - -// newID 保留为内部兼容入口,后续代码请优先使用 NewID。 -func newID(prefix string) string { - return NewID(prefix) -} diff --git a/internal/session/id_test.go b/internal/session/id_test.go index 396779f8..8ef1bb51 100644 --- a/internal/session/id_test.go +++ b/internal/session/id_test.go @@ -39,15 +39,3 @@ func TestNewIDAllowsEmptyPrefix(t *testing.T) { t.Fatalf("expected format _<16hex>, got %q", id) } } - -func TestNewIDCompatibilityWrapper(t *testing.T) { - t.Parallel() - - id := newID("session") - if !strings.HasPrefix(id, "session_") { - t.Fatalf("expected compatibility wrapper to preserve prefix, got %q", id) - } - if len(strings.TrimPrefix(id, "session_")) != 16 { - t.Fatalf("expected compatibility wrapper to return 16 hex chars, got %q", id) - } -} diff --git a/internal/session/input_preparer.go b/internal/session/input_preparer.go index 68451347..e2cc3ca6 100644 --- a/internal/session/input_preparer.go +++ b/internal/session/input_preparer.go @@ -338,18 +338,11 @@ func (p *InputPreparer) loadOrCreateSession( } session := NewWithWorkdir(title, sessionWorkdir) created, err := p.store.CreateSession(ctx, CreateSessionInput{ - ID: session.ID, - Title: session.Title, - CreatedAt: session.CreatedAt, - UpdatedAt: session.UpdatedAt, - Provider: session.Provider, - Model: session.Model, - Workdir: session.Workdir, - TaskState: session.TaskState, - ActivatedSkills: session.ActivatedSkills, - Todos: session.Todos, - TokenInputTotal: session.TokenInputTotal, - TokenOutputTotal: session.TokenOutputTotal, + ID: session.ID, + Title: session.Title, + CreatedAt: session.CreatedAt, + UpdatedAt: session.UpdatedAt, + Head: session.HeadSnapshot(), }) if err != nil { return Session{}, false, sessionWorkdirUpdate{}, err diff --git a/internal/session/input_preparer_test.go b/internal/session/input_preparer_test.go index e7b96c95..ffaebe2d 100644 --- a/internal/session/input_preparer_test.go +++ b/internal/session/input_preparer_test.go @@ -533,22 +533,24 @@ func TestInputPreparerPrepareWorkdirUpdatePreservesConcurrentSessionHeadChanges( SessionID: session.ID, Title: session.Title, UpdatedAt: session.UpdatedAt.Add(time.Minute), - Provider: "provider-after", - Model: "model-after", - Workdir: currentWorkdir, - TaskState: TaskState{ - Goal: "newer task state", - NextStep: "must survive workdir update", + Head: SessionHead{ + Provider: "provider-after", + Model: "model-after", + Workdir: currentWorkdir, + TaskState: TaskState{ + Goal: "newer task state", + NextStep: "must survive workdir update", + }, + Todos: []TodoItem{{ + ID: "todo-newer", + Content: "written by concurrent run", + Status: TodoStatusCompleted, + CreatedAt: session.CreatedAt, + UpdatedAt: session.UpdatedAt.Add(time.Minute), + }}, + TokenInputTotal: 55, + TokenOutputTotal: 89, }, - Todos: []TodoItem{{ - ID: "todo-newer", - Content: "written by concurrent run", - Status: TodoStatusCompleted, - CreatedAt: session.CreatedAt, - UpdatedAt: session.UpdatedAt.Add(time.Minute), - }}, - TokenInputTotal: 55, - TokenOutputTotal: 89, } preparerStore := &workdirRaceStore{ @@ -579,40 +581,33 @@ func TestInputPreparerPrepareWorkdirUpdatePreservesConcurrentSessionHeadChanges( if loaded.Workdir != targetWorkdir { t.Fatalf("expected persisted workdir %q, got %q", targetWorkdir, loaded.Workdir) } - if loaded.Provider != concurrentState.Provider || loaded.Model != concurrentState.Model { + if loaded.Provider != concurrentState.Head.Provider || loaded.Model != concurrentState.Head.Model { t.Fatalf("expected provider/model %q/%q, got %q/%q", - concurrentState.Provider, concurrentState.Model, loaded.Provider, loaded.Model) + concurrentState.Head.Provider, concurrentState.Head.Model, loaded.Provider, loaded.Model) } - if loaded.TokenInputTotal != concurrentState.TokenInputTotal || loaded.TokenOutputTotal != concurrentState.TokenOutputTotal { + if loaded.TokenInputTotal != concurrentState.Head.TokenInputTotal || loaded.TokenOutputTotal != concurrentState.Head.TokenOutputTotal { t.Fatalf("expected token totals %d/%d, got %d/%d", - concurrentState.TokenInputTotal, - concurrentState.TokenOutputTotal, + concurrentState.Head.TokenInputTotal, + concurrentState.Head.TokenOutputTotal, loaded.TokenInputTotal, loaded.TokenOutputTotal, ) } - if loaded.TaskState.Goal != concurrentState.TaskState.Goal || loaded.TaskState.NextStep != concurrentState.TaskState.NextStep { + if loaded.TaskState.Goal != concurrentState.Head.TaskState.Goal || loaded.TaskState.NextStep != concurrentState.Head.TaskState.NextStep { t.Fatalf("expected newer task state to survive, got %+v", loaded.TaskState) } - if len(loaded.Todos) != 1 || loaded.Todos[0].ID != concurrentState.Todos[0].ID || loaded.Todos[0].Status != concurrentState.Todos[0].Status { + if len(loaded.Todos) != 1 || loaded.Todos[0].ID != concurrentState.Head.Todos[0].ID || loaded.Todos[0].Status != concurrentState.Head.Todos[0].Status { t.Fatalf("expected newer todos to survive, got %+v", loaded.Todos) } } func createSessionForPreparerTest(ctx context.Context, store *SQLiteStore, session Session) error { _, err := store.CreateSession(ctx, CreateSessionInput{ - ID: session.ID, - Title: session.Title, - CreatedAt: session.CreatedAt, - UpdatedAt: session.UpdatedAt, - Provider: session.Provider, - Model: session.Model, - Workdir: session.Workdir, - TaskState: session.TaskState, - ActivatedSkills: session.ActivatedSkills, - Todos: session.Todos, - TokenInputTotal: session.TokenInputTotal, - TokenOutputTotal: session.TokenOutputTotal, + ID: session.ID, + Title: session.Title, + CreatedAt: session.CreatedAt, + UpdatedAt: session.UpdatedAt, + Head: session.HeadSnapshot(), }) return err } @@ -620,7 +615,7 @@ func createSessionForPreparerTest(ctx context.Context, store *SQLiteStore, sessi func newInputPreparerTestStore(t *testing.T, workdir string) *SQLiteStore { t.Helper() - store := NewStore(t.TempDir(), workdir) + store := NewSQLiteStore(t.TempDir(), workdir) t.Cleanup(func() { _ = store.Close() }) diff --git a/internal/session/skill_activation_test.go b/internal/session/skill_activation_test.go index f34d106c..3780812b 100644 --- a/internal/session/skill_activation_test.go +++ b/internal/session/skill_activation_test.go @@ -41,10 +41,12 @@ func TestSQLiteStoreRoundTripActivatedSkills(t *testing.T) { Title: "Skills Round Trip", CreatedAt: time.Now().Add(-time.Minute), UpdatedAt: time.Now(), - ActivatedSkills: []SkillActivation{ - {SkillID: " zeta "}, - {SkillID: "go_review"}, - {SkillID: "go-review"}, + Head: SessionHead{ + ActivatedSkills: []SkillActivation{ + {SkillID: " zeta "}, + {SkillID: "go_review"}, + {SkillID: "go-review"}, + }, }, }) if err != nil { diff --git a/internal/session/sqlite_store.go b/internal/session/sqlite_store.go index 88f8b7c8..0e34dbdb 100644 --- a/internal/session/sqlite_store.go +++ b/internal/session/sqlite_store.go @@ -32,6 +32,7 @@ type sqliteSessionRow struct { TodosJSON string TokenInputTotal int TokenOutputTotal int + HasUnknownUsage bool } type sqliteMessageRow struct { @@ -160,8 +161,8 @@ func (s *SQLiteStore) CreateSession(ctx context.Context, input CreateSessionInpu INSERT INTO sessions ( id, title, created_at_ms, updated_at_ms, provider, model, workdir, task_state_json, todos_json, activated_skills_json, - token_input_total, token_output_total, last_seq, message_count -) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 0, 0) + token_input_total, token_output_total, has_unknown_usage, last_seq, message_count +) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 0, 0) `, session.ID, session.Title, @@ -175,6 +176,7 @@ INSERT INTO sessions ( mustJSONString(session.ActivatedSkills), session.TokenInputTotal, session.TokenOutputTotal, + session.HasUnknownUsage, ) if err != nil { if isSQLiteSessionUniqueConstraintError(err) { @@ -318,6 +320,7 @@ SET updated_at_ms = ?, workdir = ?, token_input_total = token_input_total + ?, token_output_total = token_output_total + ?, + has_unknown_usage = CASE WHEN ? THEN 1 ELSE has_unknown_usage END, last_seq = ?, message_count = message_count + ? WHERE id = ? @@ -328,6 +331,7 @@ WHERE id = ? stringsTrimSpace(input.Workdir), input.TokenInputDelta, input.TokenOutputDelta, + input.HasUnknownUsage, lastSeq, len(normalizedMessages), input.SessionID, @@ -398,7 +402,8 @@ SET title = ?, todos_json = ?, activated_skills_json = ?, token_input_total = ?, - token_output_total = ? + token_output_total = ?, + has_unknown_usage = ? WHERE id = ? `, row.Title, @@ -411,6 +416,7 @@ WHERE id = ? row.ActivatedJSON, row.TokenInputTotal, row.TokenOutputTotal, + row.HasUnknownUsage, row.ID, ) if err != nil { @@ -465,6 +471,7 @@ SET updated_at_ms = ?, activated_skills_json = ?, token_input_total = ?, token_output_total = ?, + has_unknown_usage = ?, last_seq = ?, message_count = ? WHERE id = ? @@ -478,6 +485,7 @@ WHERE id = ? row.ActivatedJSON, row.TokenInputTotal, row.TokenOutputTotal, + row.HasUnknownUsage, lastSeq, len(messages), row.ID, @@ -795,7 +803,13 @@ func initializeSQLiteSchema(ctx context.Context, db *sql.DB) error { if err := db.QueryRowContext(ctx, `PRAGMA user_version`).Scan(&userVersion); err != nil { return fmt.Errorf("session: read sqlite user_version: %w", err) } - if userVersion != 0 && userVersion != sqliteSchemaVersion { + switch userVersion { + case 0, sqliteSchemaVersion: + case 1: + if err := migrateSQLiteSchemaV1ToV2(ctx, db); err != nil { + return err + } + default: return fmt.Errorf("session: unsupported sqlite schema version %d", userVersion) } @@ -819,6 +833,7 @@ func initializeSQLiteSchema(ctx context.Context, db *sql.DB) error { activated_skills_json TEXT NOT NULL, token_input_total INTEGER NOT NULL DEFAULT 0, token_output_total INTEGER NOT NULL DEFAULT 0, + has_unknown_usage INTEGER NOT NULL DEFAULT 0, last_seq INTEGER NOT NULL DEFAULT 0, message_count INTEGER NOT NULL DEFAULT 0 )`, @@ -860,20 +875,72 @@ func initializeSQLiteSchema(ctx context.Context, db *sql.DB) error { return nil } +// migrateSQLiteSchemaV1ToV2 将 v1 会话库升级到当前 v2 schema,仅补齐当前版本新增字段。 +func migrateSQLiteSchemaV1ToV2(ctx context.Context, db *sql.DB) error { + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("session: begin schema migration tx: %w", err) + } + defer rollbackTx(tx) + + hasColumn, err := sqliteTableHasColumn(ctx, tx, "sessions", "has_unknown_usage") + if err != nil { + return err + } + if !hasColumn { + if _, err := tx.ExecContext( + ctx, + `ALTER TABLE sessions ADD COLUMN has_unknown_usage INTEGER NOT NULL DEFAULT 0`, + ); err != nil { + return fmt.Errorf("session: migrate sqlite schema v1 to v2: %w", err) + } + } + if _, err := tx.ExecContext(ctx, fmt.Sprintf(`PRAGMA user_version=%d`, sqliteSchemaVersion)); err != nil { + return fmt.Errorf("session: set migrated sqlite schema version: %w", err) + } + if err := tx.Commit(); err != nil { + return fmt.Errorf("session: commit schema migration tx: %w", err) + } + return nil +} + +// sqliteTableHasColumn 检查指定表是否包含字段,供明确版本迁移保持幂等。 +func sqliteTableHasColumn(ctx context.Context, tx *sql.Tx, table string, column string) (bool, error) { + rows, err := tx.QueryContext(ctx, `PRAGMA table_info(`+table+`)`) + if err != nil { + return false, fmt.Errorf("session: inspect sqlite table %s: %w", table, err) + } + defer rows.Close() + + for rows.Next() { + var ( + cid int + name string + columnType string + notNull int + defaultVal sql.NullString + primaryKey int + ) + if err := rows.Scan(&cid, &name, &columnType, ¬Null, &defaultVal, &primaryKey); err != nil { + return false, fmt.Errorf("session: scan sqlite table %s info: %w", table, err) + } + if stringsTrimSpace(name) == column { + return true, nil + } + } + if err := rows.Err(); err != nil { + return false, fmt.Errorf("session: iterate sqlite table %s info: %w", table, err) + } + return false, nil +} + // normalizeCreateSessionInput 规范化创建会话输入并生成最终会话头。 func normalizeCreateSessionInput(input CreateSessionInput) (Session, error) { session := Session{ - ID: stringsTrimSpace(input.ID), - Title: sanitizeTitle(input.Title), - Provider: stringsTrimSpace(input.Provider), - Model: stringsTrimSpace(input.Model), - CreatedAt: input.CreatedAt, - UpdatedAt: input.UpdatedAt, - Workdir: stringsTrimSpace(input.Workdir), - TaskState: normalizeAndClampTaskState(input.TaskState), - ActivatedSkills: normalizeSkillActivations(input.ActivatedSkills), - TokenInputTotal: input.TokenInputTotal, - TokenOutputTotal: input.TokenOutputTotal, + ID: stringsTrimSpace(input.ID), + Title: sanitizeTitle(input.Title), + CreatedAt: input.CreatedAt, + UpdatedAt: input.UpdatedAt, } if session.ID == "" { session.ID = NewID("session") @@ -888,7 +955,9 @@ func normalizeCreateSessionInput(input CreateSessionInput) (Session, error) { if session.UpdatedAt.IsZero() { session.UpdatedAt = session.CreatedAt } - todos, err := normalizeAndValidateTodos(input.Todos) + head := input.Head.clone() + head.applyToSession(&session) + todos, err := normalizeAndValidateTodos(head.Todos) if err != nil { return Session{}, err } @@ -904,39 +973,34 @@ func normalizeUpdateSessionStateInput(input UpdateSessionStateInput) (sqliteSess if err := validateStorageID("session id", input.SessionID); err != nil { return sqliteSessionRow{}, fmt.Errorf("session: %w", err) } - todos, err := normalizeAndValidateTodos(input.Todos) + head := input.Head.clone() + todos, err := normalizeAndValidateTodos(head.Todos) if err != nil { return sqliteSessionRow{}, err } return sqliteSessionRow{ ID: stringsTrimSpace(input.SessionID), Title: sanitizeTitle(input.Title), - Provider: stringsTrimSpace(input.Provider), - Model: stringsTrimSpace(input.Model), + Provider: stringsTrimSpace(head.Provider), + Model: stringsTrimSpace(head.Model), UpdatedAtMS: toUnixMillis(resolveUpdatedAt(input.UpdatedAt)), - Workdir: stringsTrimSpace(input.Workdir), - TaskStateJSON: mustJSONString(normalizeAndClampTaskState(input.TaskState)), + Workdir: stringsTrimSpace(head.Workdir), + TaskStateJSON: mustJSONString(normalizeAndClampTaskState(head.TaskState)), TodosJSON: mustJSONString(todos), - ActivatedJSON: mustJSONString(normalizeSkillActivations(input.ActivatedSkills)), - TokenInputTotal: input.TokenInputTotal, - TokenOutputTotal: input.TokenOutputTotal, + ActivatedJSON: mustJSONString(normalizeSkillActivations(head.ActivatedSkills)), + TokenInputTotal: head.TokenInputTotal, + TokenOutputTotal: head.TokenOutputTotal, + HasUnknownUsage: head.HasUnknownUsage, }, nil } // normalizeReplaceTranscriptInput 规范化 compact 后的 transcript 替换输入。 func normalizeReplaceTranscriptInput(input ReplaceTranscriptInput) (sqliteSessionRow, []providertypes.Message, error) { row, err := normalizeUpdateSessionStateInput(UpdateSessionStateInput{ - SessionID: input.SessionID, - Title: "", - UpdatedAt: input.UpdatedAt, - Provider: input.Provider, - Model: input.Model, - Workdir: input.Workdir, - TaskState: input.TaskState, - ActivatedSkills: input.ActivatedSkills, - Todos: input.Todos, - TokenInputTotal: input.TokenInputTotal, - TokenOutputTotal: input.TokenOutputTotal, + SessionID: input.SessionID, + Title: "", + UpdatedAt: input.UpdatedAt, + Head: input.Head, }) if err != nil { return sqliteSessionRow{}, nil, err @@ -968,7 +1032,7 @@ func loadSessionRow(ctx context.Context, tx *sql.Tx, sessionID string) (sqliteSe var row sqliteSessionRow err := tx.QueryRowContext(ctx, ` SELECT id, title, provider, model, created_at_ms, updated_at_ms, workdir, - task_state_json, activated_skills_json, todos_json, token_input_total, token_output_total + task_state_json, activated_skills_json, todos_json, token_input_total, token_output_total, has_unknown_usage FROM sessions WHERE id = ? `, @@ -986,6 +1050,7 @@ WHERE id = ? &row.TodosJSON, &row.TokenInputTotal, &row.TokenOutputTotal, + &row.HasUnknownUsage, ) if err != nil { if errors.Is(err, sql.ErrNoRows) { @@ -1064,6 +1129,7 @@ func buildSessionFromRow(row sqliteSessionRow, messages []sqliteMessageRow) (Ses Todos: normalizedTodos, TokenInputTotal: row.TokenInputTotal, TokenOutputTotal: row.TokenOutputTotal, + HasUnknownUsage: row.HasUnknownUsage, } if len(result.Todos) > 0 { result.TodoVersion = CurrentTodoVersion diff --git a/internal/session/sqlite_store_additional_test.go b/internal/session/sqlite_store_additional_test.go index e747cb82..27b204fe 100644 --- a/internal/session/sqlite_store_additional_test.go +++ b/internal/session/sqlite_store_additional_test.go @@ -213,7 +213,9 @@ func TestNormalizeCreateSessionInputDefaultsGeneratedID(t *testing.T) { session, err := normalizeCreateSessionInput(CreateSessionInput{ Title: " test ", - Todos: []TodoItem{{ID: "todo-1", Content: "a"}}, + Head: SessionHead{ + Todos: []TodoItem{{ID: "todo-1", Content: "a"}}, + }, }) if err != nil { t.Fatalf("normalizeCreateSessionInput() error = %v", err) @@ -473,7 +475,7 @@ func TestSQLiteStoreInitializeTightensExistingDirectoryPermissions(t *testing.T) baseDir := t.TempDir() workspaceRoot := t.TempDir() - store := NewStore(baseDir, workspaceRoot) + store := NewSQLiteStore(baseDir, workspaceRoot) t.Cleanup(func() { _ = store.Close() }) for _, dir := range []string{store.projectDir, store.assetsDir} { @@ -739,7 +741,7 @@ func TestCleanupExpiredSessionAssetsStopsOnCanceledContext(t *testing.T) { } } -func TestBuildSessionFromRowInfersLegacySubAgentExecutor(t *testing.T) { +func TestBuildSessionFromRowDefaultsTodoMissingExecutorToAgent(t *testing.T) { t.Parallel() nowMS := toUnixMillis(time.Now().UTC()) @@ -755,20 +757,14 @@ func TestBuildSessionFromRowInfersLegacySubAgentExecutor(t *testing.T) { session, err := buildSessionFromRow(row, nil) if err != nil { - t.Fatalf("buildSessionFromRow() error = %v", err) - } - if len(session.Todos) != 1 { - t.Fatalf("todos len = %d, want 1", len(session.Todos)) + t.Fatalf("expected missing executor to default, got session=%+v err=%v", session, err) } - if session.Todos[0].Executor != TodoExecutorSubAgent { - t.Fatalf("legacy todo executor = %q, want %q", session.Todos[0].Executor, TodoExecutorSubAgent) - } - if session.TodoVersion != CurrentTodoVersion { - t.Fatalf("todo_version = %d, want %d", session.TodoVersion, CurrentTodoVersion) + if len(session.Todos) != 1 || session.Todos[0].Executor != TodoExecutorAgent { + t.Fatalf("expected default executor %q, got %+v", TodoExecutorAgent, session.Todos) } } -func TestBuildSessionFromRowInfersLegacySubAgentExecutorByRetrySignals(t *testing.T) { +func TestBuildSessionFromRowDefaultsRetryTodoMissingExecutorToAgent(t *testing.T) { t.Parallel() now := time.Now().UTC() @@ -788,12 +784,9 @@ func TestBuildSessionFromRowInfersLegacySubAgentExecutorByRetrySignals(t *testin session, err := buildSessionFromRow(row, nil) if err != nil { - t.Fatalf("buildSessionFromRow() error = %v", err) - } - if len(session.Todos) != 1 { - t.Fatalf("todos len = %d, want 1", len(session.Todos)) + t.Fatalf("expected retry todo missing executor to default, got session=%+v err=%v", session, err) } - if session.Todos[0].Executor != TodoExecutorSubAgent { - t.Fatalf("legacy retry todo executor = %q, want %q", session.Todos[0].Executor, TodoExecutorSubAgent) + if len(session.Todos) != 1 || session.Todos[0].Executor != TodoExecutorAgent { + t.Fatalf("expected default executor %q, got %+v", TodoExecutorAgent, session.Todos) } } diff --git a/internal/session/store.go b/internal/session/store.go index c21cc8c7..95c74232 100644 --- a/internal/session/store.go +++ b/internal/session/store.go @@ -14,7 +14,7 @@ import ( const ( sessionDatabaseFileName = "session.db" assetsDirName = "assets" - sqliteSchemaVersion = 1 + sqliteSchemaVersion = 2 // MaxSessionMessages 定义单个会话允许持久化的最大消息数,超出时自动裁剪最旧消息。 MaxSessionMessages = 8192 @@ -47,6 +47,20 @@ type Session struct { Messages []providertypes.Message TokenInputTotal int TokenOutputTotal int + HasUnknownUsage bool +} + +// SessionHead 表示可独立持久化、可整体替换的会话头状态快照。 +type SessionHead struct { + Provider string + Model string + Workdir string + TaskState TaskState + ActivatedSkills []SkillActivation + Todos []TodoItem + TokenInputTotal int + TokenOutputTotal int + HasUnknownUsage bool } // Summary 表示会话列表视图需要的轻量摘要。 @@ -59,18 +73,11 @@ type Summary struct { // CreateSessionInput 描述新建空会话头时需要写入的字段。 type CreateSessionInput struct { - ID string - Title string - CreatedAt time.Time - UpdatedAt time.Time - Provider string - Model string - Workdir string - TaskState TaskState - ActivatedSkills []SkillActivation - Todos []TodoItem - TokenInputTotal int - TokenOutputTotal int + ID string + Title string + CreatedAt time.Time + UpdatedAt time.Time + Head SessionHead } // AppendMessagesInput 描述一次原子追加消息及会话头增量更新。 @@ -83,21 +90,15 @@ type AppendMessagesInput struct { Workdir string TokenInputDelta int TokenOutputDelta int + HasUnknownUsage bool } // UpdateSessionStateInput 描述一次只更新会话头状态的写入。 type UpdateSessionStateInput struct { - SessionID string - Title string - UpdatedAt time.Time - Provider string - Model string - Workdir string - TaskState TaskState - ActivatedSkills []SkillActivation - Todos []TodoItem - TokenInputTotal int - TokenOutputTotal int + SessionID string + Title string + UpdatedAt time.Time + Head SessionHead } // UpdateSessionWorkdirInput 描述一次仅更新会话 workdir 的最小粒度写入。 @@ -109,17 +110,10 @@ type UpdateSessionWorkdirInput struct { // ReplaceTranscriptInput 描述 compact 后整段 transcript 的原子替换。 type ReplaceTranscriptInput struct { - SessionID string - Messages []providertypes.Message - UpdatedAt time.Time - Provider string - Model string - Workdir string - TaskState TaskState - ActivatedSkills []SkillActivation - Todos []TodoItem - TokenInputTotal int - TokenOutputTotal int + SessionID string + Messages []providertypes.Message + UpdatedAt time.Time + Head SessionHead } // Store 定义会话持久化的意图型接口。 @@ -145,11 +139,6 @@ func NewSQLiteStore(baseDir string, workspaceRoot string) *SQLiteStore { } } -// NewStore 返回默认会话存储实现。 -func NewStore(baseDir string, workspaceRoot string) *SQLiteStore { - return NewSQLiteStore(baseDir, workspaceRoot) -} - // New 创建一个默认标题策略的新会话对象。 func New(title string) Session { return NewWithWorkdir(title, "") @@ -171,6 +160,65 @@ func NewWithWorkdir(title string, workdir string) Session { } } +// HeadSnapshot 返回当前会话头状态的深拷贝,用于持久化输入与跨层传递。 +func (s Session) HeadSnapshot() SessionHead { + return SessionHead{ + Provider: strings.TrimSpace(s.Provider), + Model: strings.TrimSpace(s.Model), + Workdir: strings.TrimSpace(s.Workdir), + TaskState: s.TaskState.Clone(), + ActivatedSkills: cloneSkillActivations(s.ActivatedSkills), + Todos: cloneTodoItems(s.Todos), + TokenInputTotal: s.TokenInputTotal, + TokenOutputTotal: s.TokenOutputTotal, + HasUnknownUsage: s.HasUnknownUsage, + } +} + +// clone 返回会话头状态的深拷贝,避免跨层共享底层切片。 +func (h SessionHead) clone() SessionHead { + return SessionHead{ + Provider: strings.TrimSpace(h.Provider), + Model: strings.TrimSpace(h.Model), + Workdir: strings.TrimSpace(h.Workdir), + TaskState: h.TaskState.Clone(), + ActivatedSkills: cloneSkillActivations(h.ActivatedSkills), + Todos: cloneTodoItems(h.Todos), + TokenInputTotal: h.TokenInputTotal, + TokenOutputTotal: h.TokenOutputTotal, + HasUnknownUsage: h.HasUnknownUsage, + } +} + +// applyToSession 将会话头状态整体写回会话对象,避免调用方逐字段手动拼装。 +func (h SessionHead) applyToSession(session *Session) { + if session == nil { + return + } + cloned := h.clone() + session.Provider = cloned.Provider + session.Model = cloned.Model + session.Workdir = cloned.Workdir + session.TaskState = cloned.TaskState + session.ActivatedSkills = cloned.ActivatedSkills + session.Todos = cloned.Todos + session.TokenInputTotal = cloned.TokenInputTotal + session.TokenOutputTotal = cloned.TokenOutputTotal + session.HasUnknownUsage = cloned.HasUnknownUsage +} + +// cloneTodoItems 深拷贝 Todo 列表,避免会话头快照共享底层切片。 +func cloneTodoItems(items []TodoItem) []TodoItem { + if len(items) == 0 { + return nil + } + cloned := make([]TodoItem, len(items)) + for idx, item := range items { + cloned[idx] = item.Clone() + } + return cloned +} + // sanitizeTitle 规范化会话标题,保证空标题和超长标题都有稳定表现。 func sanitizeTitle(title string) string { title = strings.TrimSpace(title) diff --git a/internal/session/store_test.go b/internal/session/store_test.go index e9a76862..f08411cd 100644 --- a/internal/session/store_test.go +++ b/internal/session/store_test.go @@ -24,19 +24,21 @@ func TestSQLiteStoreLifecycleRoundTrip(t *testing.T) { Title: " Session Roundtrip ", CreatedAt: createdAt, UpdatedAt: updatedAt, - Provider: "openai", - Model: "gpt-5", - Workdir: "/repo", - TaskState: TaskState{ - Goal: "ship sqlite migration", - Progress: []string{"draft plan"}, - }, - ActivatedSkills: []SkillActivation{{SkillID: "go_review"}, {SkillID: "go-review"}}, - Todos: []TodoItem{ - {ID: "todo-1", Content: "implement store"}, + Head: SessionHead{ + Provider: "openai", + Model: "gpt-5", + Workdir: "/repo", + TaskState: TaskState{ + Goal: "ship sqlite migration", + Progress: []string{"draft plan"}, + }, + ActivatedSkills: []SkillActivation{{SkillID: "go_review"}, {SkillID: "go-review"}}, + Todos: []TodoItem{ + {ID: "todo-1", Content: "implement store"}, + }, + TokenInputTotal: 11, + TokenOutputTotal: 7, }, - TokenInputTotal: 11, - TokenOutputTotal: 7, }) if err != nil { t.Fatalf("CreateSession() error = %v", err) @@ -74,20 +76,23 @@ func TestSQLiteStoreLifecycleRoundTrip(t *testing.T) { SessionID: session.ID, Title: "SQLite Ready", UpdatedAt: updatedAt.Add(2 * time.Minute), - Provider: "openai", - Model: "gpt-5.1", - Workdir: "/repo/final", - TaskState: TaskState{ - Goal: "ship sqlite migration", - Progress: []string{"draft plan", "replace store"}, - UserConstraints: []string{"no legacy compatibility"}, - }, - ActivatedSkills: []SkillActivation{{SkillID: "go-review"}}, - Todos: []TodoItem{ - {ID: "todo-1", Content: "implement store", Status: TodoStatusInProgress}, + Head: SessionHead{ + Provider: "openai", + Model: "gpt-5.1", + Workdir: "/repo/final", + TaskState: TaskState{ + Goal: "ship sqlite migration", + Progress: []string{"draft plan", "replace store"}, + UserConstraints: []string{"no legacy compatibility"}, + }, + ActivatedSkills: []SkillActivation{{SkillID: "go-review"}}, + Todos: []TodoItem{ + {ID: "todo-1", Content: "implement store", Status: TodoStatusInProgress}, + }, + TokenInputTotal: 99, + TokenOutputTotal: 42, + HasUnknownUsage: true, }, - TokenInputTotal: 99, - TokenOutputTotal: 42, }); err != nil { t.Fatalf("UpdateSessionState() error = %v", err) } @@ -105,6 +110,9 @@ func TestSQLiteStoreLifecycleRoundTrip(t *testing.T) { if loaded.TokenInputTotal != 99 || loaded.TokenOutputTotal != 42 { t.Fatalf("unexpected token totals: in=%d out=%d", loaded.TokenInputTotal, loaded.TokenOutputTotal) } + if !loaded.HasUnknownUsage { + t.Fatalf("expected HasUnknownUsage to round-trip") + } if got := loaded.ActiveSkillIDs(); len(got) != 1 || got[0] != "go-review" { t.Fatalf("unexpected active skills: %+v", got) } @@ -154,7 +162,7 @@ func TestSQLiteStoreListSummariesSortedAndLegacyJSONIgnored(t *testing.T) { if err != nil { t.Fatalf("MkdirTemp() workspaceRoot error = %v", err) } - store := NewStore(baseDir, workspaceRoot) + store := NewSQLiteStore(baseDir, workspaceRoot) t.Cleanup(func() { _ = store.Close() _ = os.RemoveAll(baseDir) @@ -210,18 +218,21 @@ func TestSQLiteStoreReplaceTranscriptAndPragmas(t *testing.T) { if err := store.ReplaceTranscript(ctx, ReplaceTranscriptInput{ SessionID: session.ID, UpdatedAt: time.Now().UTC(), - Provider: "openai", - Model: "gpt-5.2", - Workdir: "/repo", - TaskState: TaskState{Goal: "after compact"}, - Todos: []TodoItem{ - {ID: "todo-1", Content: "after compact"}, - }, Messages: []providertypes.Message{ {Role: providertypes.RoleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("after")}}, }, - TokenInputTotal: 0, - TokenOutputTotal: 0, + Head: SessionHead{ + Provider: "openai", + Model: "gpt-5.2", + Workdir: "/repo", + TaskState: TaskState{Goal: "after compact"}, + Todos: []TodoItem{ + {ID: "todo-1", Content: "after compact"}, + }, + TokenInputTotal: 0, + TokenOutputTotal: 0, + HasUnknownUsage: false, + }, }); err != nil { t.Fatalf("ReplaceTranscript() error = %v", err) } @@ -239,6 +250,9 @@ func TestSQLiteStoreReplaceTranscriptAndPragmas(t *testing.T) { if loaded.TaskState.Goal != "after compact" { t.Fatalf("unexpected task state after replace: %+v", loaded.TaskState) } + if loaded.HasUnknownUsage { + t.Fatalf("expected replace transcript to clear HasUnknownUsage") + } db, err := store.ensureDB(ctx) if err != nil { @@ -414,9 +428,11 @@ func TestSQLiteStoreAppendReplaceAndSchemaErrors(t *testing.T) { if err := store.UpdateSessionState(ctx, UpdateSessionStateInput{ SessionID: session.ID, Title: "x", - Todos: []TodoItem{ - {ID: "dup", Content: "a"}, - {ID: "dup", Content: "b"}, + Head: SessionHead{ + Todos: []TodoItem{ + {ID: "dup", Content: "a"}, + {ID: "dup", Content: "b"}, + }, }, }); err == nil { t.Fatalf("expected invalid todos error") @@ -449,7 +465,7 @@ func TestSQLiteStoreInitializationRejectsUnsupportedSchemaVersion(t *testing.T) if err != nil { t.Fatalf("MkdirTemp() workspaceRoot error = %v", err) } - store := NewStore(baseDir, workspaceRoot) + store := NewSQLiteStore(baseDir, workspaceRoot) t.Cleanup(func() { _ = store.Close() _ = os.RemoveAll(baseDir) @@ -476,6 +492,51 @@ func TestSQLiteStoreInitializationRejectsUnsupportedSchemaVersion(t *testing.T) } } +func TestSQLiteStoreMigratesSchemaV1ToV2(t *testing.T) { + ctx := context.Background() + baseDir, workspaceRoot, store := newMigrationTestStore(t) + + createLegacyV1SessionDB(t, ctx, baseDir, workspaceRoot, false) + loaded, err := store.LoadSession(ctx, "session_v1") + if err != nil { + t.Fatalf("LoadSession() after migration error = %v", err) + } + if loaded.ID != "session_v1" || loaded.Title != "Legacy V1" { + t.Fatalf("unexpected migrated session: %+v", loaded) + } + if loaded.HasUnknownUsage { + t.Fatalf("expected migrated HasUnknownUsage to default false") + } + + db, err := store.ensureDB(ctx) + if err != nil { + t.Fatalf("ensureDB() error = %v", err) + } + assertPragmaInt(t, db, "user_version", sqliteSchemaVersion) + assertSQLiteColumnExists(t, db, "sessions", "has_unknown_usage") +} + +func TestSQLiteStoreMigratesSchemaV1ToV2WhenColumnAlreadyExists(t *testing.T) { + ctx := context.Background() + baseDir, workspaceRoot, store := newMigrationTestStore(t) + + createLegacyV1SessionDB(t, ctx, baseDir, workspaceRoot, true) + summaries, err := store.ListSummaries(ctx) + if err != nil { + t.Fatalf("ListSummaries() after migration error = %v", err) + } + if len(summaries) != 1 || summaries[0].ID != "session_v1" { + t.Fatalf("unexpected summaries after migration: %+v", summaries) + } + + db, err := store.ensureDB(ctx) + if err != nil { + t.Fatalf("ensureDB() error = %v", err) + } + assertPragmaInt(t, db, "user_version", sqliteSchemaVersion) + assertSQLiteColumnExists(t, db, "sessions", "has_unknown_usage") +} + func assertPragmaString(t *testing.T, db *sql.DB, name string, want string) { t.Helper() var got string @@ -498,6 +559,138 @@ func assertPragmaInt(t *testing.T, db *sql.DB, name string, want int) { } } +func assertSQLiteColumnExists(t *testing.T, db *sql.DB, table string, column string) { + t.Helper() + rows, err := db.Query(`PRAGMA table_info(` + table + `)`) + if err != nil { + t.Fatalf("PRAGMA table_info(%s) error = %v", table, err) + } + defer rows.Close() + + for rows.Next() { + var ( + cid int + name string + columnType string + notNull int + defaultVal sql.NullString + primaryKey int + ) + if err := rows.Scan(&cid, &name, &columnType, ¬Null, &defaultVal, &primaryKey); err != nil { + t.Fatalf("scan table info: %v", err) + } + if name == column { + return + } + } + if err := rows.Err(); err != nil { + t.Fatalf("iterate table info: %v", err) + } + t.Fatalf("expected column %s.%s to exist", table, column) +} + +func newMigrationTestStore(t *testing.T) (string, string, *SQLiteStore) { + t.Helper() + baseDir, err := os.MkdirTemp("", "session-base-") + if err != nil { + t.Fatalf("MkdirTemp() baseDir error = %v", err) + } + workspaceRoot, err := os.MkdirTemp("", "session-workspace-") + if err != nil { + t.Fatalf("MkdirTemp() workspaceRoot error = %v", err) + } + store := NewSQLiteStore(baseDir, workspaceRoot) + t.Cleanup(func() { + _ = store.Close() + _ = os.RemoveAll(baseDir) + _ = os.RemoveAll(workspaceRoot) + }) + return baseDir, workspaceRoot, store +} + +func createLegacyV1SessionDB( + t *testing.T, + ctx context.Context, + baseDir string, + workspaceRoot string, + includeUnknownUsageColumn bool, +) { + t.Helper() + projectDir := projectDirectory(baseDir, workspaceRoot) + if err := os.MkdirAll(projectDir, 0o755); err != nil { + t.Fatalf("MkdirAll(projectDir) error = %v", err) + } + db, err := sql.Open("sqlite", databasePath(baseDir, workspaceRoot)) + if err != nil { + t.Fatalf("sql.Open() error = %v", err) + } + defer db.Close() + + unknownUsageColumn := "" + unknownUsageInsertColumn := "" + unknownUsageInsertValue := "" + if includeUnknownUsageColumn { + unknownUsageColumn = ", has_unknown_usage INTEGER NOT NULL DEFAULT 0" + unknownUsageInsertColumn = ", has_unknown_usage" + unknownUsageInsertValue = ", 0" + } + statements := []string{ + `CREATE TABLE sessions ( + id TEXT PRIMARY KEY, + title TEXT NOT NULL, + created_at_ms INTEGER NOT NULL, + updated_at_ms INTEGER NOT NULL, + provider TEXT NOT NULL DEFAULT '', + model TEXT NOT NULL DEFAULT '', + workdir TEXT NOT NULL DEFAULT '', + task_state_json TEXT NOT NULL, + todos_json TEXT NOT NULL, + activated_skills_json TEXT NOT NULL, + token_input_total INTEGER NOT NULL DEFAULT 0, + token_output_total INTEGER NOT NULL DEFAULT 0` + unknownUsageColumn + `, + last_seq INTEGER NOT NULL DEFAULT 0, + message_count INTEGER NOT NULL DEFAULT 0 + )`, + `CREATE TABLE messages ( + session_id TEXT NOT NULL, + seq INTEGER NOT NULL, + role TEXT NOT NULL, + parts_json TEXT NOT NULL, + tool_calls_json TEXT NOT NULL DEFAULT '', + tool_call_id TEXT NOT NULL DEFAULT '', + is_error INTEGER NOT NULL DEFAULT 0, + tool_metadata_json TEXT NOT NULL DEFAULT '', + created_at_ms INTEGER NOT NULL, + PRIMARY KEY(session_id, seq), + FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE + )`, + `CREATE TABLE session_assets ( + id TEXT PRIMARY KEY, + session_id TEXT NOT NULL, + mime_type TEXT NOT NULL, + size_bytes INTEGER NOT NULL, + relative_path TEXT NOT NULL, + created_at_ms INTEGER NOT NULL, + FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE + )`, + `INSERT INTO sessions ( + id, title, created_at_ms, updated_at_ms, provider, model, workdir, + task_state_json, todos_json, activated_skills_json, + token_input_total, token_output_total` + unknownUsageInsertColumn + `, + last_seq, message_count + ) VALUES ( + 'session_v1', 'Legacy V1', 1000, 1000, 'openai', 'gpt-5', '/repo', + '{}', '[]', '[]', 11, 7` + unknownUsageInsertValue + `, 0, 0 + )`, + `PRAGMA user_version=1`, + } + for _, statement := range statements { + if _, err := db.ExecContext(ctx, statement); err != nil { + t.Fatalf("exec legacy schema statement: %v\n%s", err, statement) + } + } +} + func renderSessionMessageParts(message providertypes.Message) string { if len(message.Parts) == 0 { return "" diff --git a/internal/session/test_helpers_test.go b/internal/session/test_helpers_test.go index 1b78de3b..c8581112 100644 --- a/internal/session/test_helpers_test.go +++ b/internal/session/test_helpers_test.go @@ -20,7 +20,7 @@ func newTestStore(t *testing.T) *SQLiteStore { if err != nil { t.Fatalf("MkdirTemp() error = %v", err) } - store := NewStore(baseDir, workspaceRoot) + store := NewSQLiteStore(baseDir, workspaceRoot) t.Cleanup(func() { _ = store.Close() _ = os.RemoveAll(baseDir) diff --git a/internal/session/todo.go b/internal/session/todo.go index 88c51572..b3ac4f3d 100644 --- a/internal/session/todo.go +++ b/internal/session/todo.go @@ -207,11 +207,6 @@ func (s *Session) AddTodo(item TodoItem) error { return nil } -// UpdateTodoStatus 按 ID 更新 Todo 状态(兼容旧调用,无 revision 约束)。 -func (s *Session) UpdateTodoStatus(id string, status TodoStatus) error { - return s.SetTodoStatus(id, status, 0) -} - // SetTodoStatus 按 ID 更新 Todo 状态并执行 revision 检查。 func (s *Session) SetTodoStatus(id string, status TodoStatus, expectedRevision int64) error { patch := TodoPatch{ @@ -412,9 +407,6 @@ func normalizeTodoItem(item TodoItem) (TodoItem, error) { item.Content = strings.TrimSpace(item.Content) item.Dependencies = normalizeTodoDependencies(item.Dependencies) item.Executor = normalizeTodoExecutor(item.Executor) - if item.Executor == "" { - item.Executor = inferLegacyTodoExecutor(item) - } item.OwnerType = normalizeTodoOwnerType(item.OwnerType) item.OwnerID = strings.TrimSpace(item.OwnerID) item.Acceptance = normalizeTodoTextList(item.Acceptance) @@ -458,22 +450,6 @@ func normalizeTodoItem(item TodoItem) (TodoItem, error) { return item, nil } -// inferLegacyTodoExecutor 基于旧字段推断缺失 executor 的历史任务执行归属,避免升级后改变既有调度行为。 -func inferLegacyTodoExecutor(item TodoItem) string { - if normalizeTodoOwnerType(item.OwnerType) == TodoOwnerTypeSubAgent { - return TodoExecutorSubAgent - } - if item.RetryCount > 0 || item.RetryLimit > 0 { - return TodoExecutorSubAgent - } - if item.Status == TodoStatusBlocked || item.Status == TodoStatusInProgress || item.Status == TodoStatusFailed { - if strings.TrimSpace(item.FailureReason) != "" || !item.NextRetryAt.IsZero() { - return TodoExecutorSubAgent - } - } - return TodoExecutorAgent -} - // normalizeTodoDependencies 对依赖列表做去空白、去重并保持顺序。 func normalizeTodoDependencies(dependencies []string) []string { return normalizeTodoTextList(dependencies) @@ -617,7 +593,11 @@ func normalizeTodoOwnerType(ownerType string) string { // normalizeTodoExecutor 规范化 executor 字段。 func normalizeTodoExecutor(executor string) string { - return strings.ToLower(strings.TrimSpace(executor)) + normalized := strings.ToLower(strings.TrimSpace(executor)) + if normalized == "" { + return TodoExecutorAgent + } + return normalized } // isValidTodoExecutor 判断 executor 是否受支持。 diff --git a/internal/session/todo_test.go b/internal/session/todo_test.go index e7a242cb..2106875f 100644 --- a/internal/session/todo_test.go +++ b/internal/session/todo_test.go @@ -301,8 +301,8 @@ func TestSessionReplaceTodosAndUpdateTodoStatusCompatibility(t *testing.T) { t.Fatalf("unexpected session after replace: %+v", session) } - if err := session.UpdateTodoStatus("b", TodoStatusInProgress); err != nil { - t.Fatalf("UpdateTodoStatus(b,in_progress) error = %v", err) + if err := session.SetTodoStatus("b", TodoStatusInProgress, 0); err != nil { + t.Fatalf("SetTodoStatus(b,in_progress,0) error = %v", err) } b, _ := session.FindTodo("b") if b.Status != TodoStatusInProgress { @@ -393,32 +393,16 @@ func TestTodoInternalHelpers(t *testing.T) { t.Fatalf("negative retry fields should be normalized to 0, got count=%d limit=%d", normalized.RetryCount, normalized.RetryLimit) } - legacySubAgent, err := normalizeTodoItem(TodoItem{ - ID: "legacy-subagent", - Content: "legacy", + normalizedDefaultExecutor, err := normalizeTodoItem(TodoItem{ + ID: "missing-executor", + Content: "legacy payload", OwnerType: TodoOwnerTypeSubAgent, }) if err != nil { - t.Fatalf("normalizeTodoItem(legacy-subagent) error = %v", err) + t.Fatalf("expected missing executor to default to agent, got %v", err) } - if legacySubAgent.Executor != TodoExecutorSubAgent { - t.Fatalf("legacy executor = %q, want %q", legacySubAgent.Executor, TodoExecutorSubAgent) - } - - legacyRetrySubAgent, err := normalizeTodoItem(TodoItem{ - ID: "legacy-retry-subagent", - Content: "legacy retry", - Status: TodoStatusBlocked, - RetryCount: 1, - OwnerType: "", - OwnerID: "", - NextRetryAt: time.Now().UTC().Add(time.Minute), - }) - if err != nil { - t.Fatalf("normalizeTodoItem(legacy-retry-subagent) error = %v", err) - } - if legacyRetrySubAgent.Executor != TodoExecutorSubAgent { - t.Fatalf("legacy retry executor = %q, want %q", legacyRetrySubAgent.Executor, TodoExecutorSubAgent) + if normalizedDefaultExecutor.Executor != TodoExecutorAgent { + t.Fatalf("default executor = %q, want %q", normalizedDefaultExecutor.Executor, TodoExecutorAgent) } } diff --git a/internal/tui/core/app/update.go b/internal/tui/core/app/update.go index caddc057..54ed35d6 100644 --- a/internal/tui/core/app/update.go +++ b/internal/tui/core/app/update.go @@ -1086,6 +1086,7 @@ var runtimeEventHandlerRegistry = map[tuiservices.EventType]func(*App, tuiservic tuiservices.EventPermissionResolved: runtimeEventPermissionResolvedHandler, tuiservices.EventCompactApplied: runtimeEventCompactDoneHandler, tuiservices.EventCompactError: runtimeEventCompactErrorHandler, + tuiservices.EventTokenUsage: runtimeEventTokenUsageHandler, tuiservices.EventPhaseChanged: runtimeEventPhaseChangedHandler, tuiservices.EventStopReasonDecided: runtimeEventStopReasonDecidedHandler, tuiservices.EventTodoUpdated: runtimeEventTodoUpdatedHandler, @@ -1126,15 +1127,23 @@ func runtimeEventStopReasonDecidedHandler(a *App, event tuiservices.RuntimeEvent reason := strings.ToLower(strings.TrimSpace(string(payload.Reason))) switch reason { - case "success": + case strings.ToLower(string(tuiservices.StopReasonCompleted)): if strings.TrimSpace(a.state.ExecutionError) == "" { a.state.StatusText = statusReady } - case "canceled": + case strings.ToLower(string(tuiservices.StopReasonUserInterrupt)): a.state.ExecutionError = "" a.state.StatusText = statusCanceled a.appendActivity("run", "Canceled current run", "", false) - default: + case strings.ToLower(string(tuiservices.StopReasonBudgetExceeded)): + detail := strings.TrimSpace(payload.Detail) + if detail == "" { + detail = "Context budget exceeded" + } + a.state.ExecutionError = "" + a.state.StatusText = detail + a.appendActivity("run", "Context budget exceeded", detail, false) + case strings.ToLower(string(tuiservices.StopReasonFatalError)): detail := strings.TrimSpace(payload.Detail) if detail == "" { detail = "runtime stopped" @@ -1142,6 +1151,11 @@ func runtimeEventStopReasonDecidedHandler(a *App, event tuiservices.RuntimeEvent a.state.ExecutionError = detail a.state.StatusText = detail a.appendActivity("run", "Runtime stopped", detail, true) + default: + detail := "unknown stop reason: " + strings.TrimSpace(string(payload.Reason)) + a.state.ExecutionError = detail + a.state.StatusText = detail + a.appendActivity("run", "Runtime stopped", detail, true) } return false } @@ -1444,6 +1458,15 @@ func runtimeEventUsageHandler(a *App, event tuiservices.RuntimeEvent) bool { return false } +func runtimeEventTokenUsageHandler(a *App, event tuiservices.RuntimeEvent) bool { + payload, ok := event.Payload.(tuiservices.TokenUsagePayload) + if !ok { + return false + } + a.state.TokenUsage = tuiservices.MapTokenUsagePayload(payload, a.state.TokenUsage) + return false +} + // runtimeEventToolCallThinkingHandler 在工具调用进入思考阶段时同步当前工具与进度提示。 func runtimeEventToolCallThinkingHandler(a *App, event tuiservices.RuntimeEvent) bool { if payload, ok := event.Payload.(string); ok && strings.TrimSpace(payload) != "" { diff --git a/internal/tui/core/app/update_runtime_events_test.go b/internal/tui/core/app/update_runtime_events_test.go index 2c3d8c4e..f97d9e59 100644 --- a/internal/tui/core/app/update_runtime_events_test.go +++ b/internal/tui/core/app/update_runtime_events_test.go @@ -66,7 +66,7 @@ func TestRuntimeEventStopReasonDecidedHandlerBranches(t *testing.T) { } handled := runtimeEventStopReasonDecidedHandler(&app, agentruntime.RuntimeEvent{ - Payload: agentruntime.StopReasonDecidedPayload{Reason: agentruntime.StopReason(" success ")}, + Payload: agentruntime.StopReasonDecidedPayload{Reason: agentruntime.StopReason(" STOP_COMPLETED ")}, }) if handled { t.Fatalf("expected handler to return false") @@ -87,40 +87,54 @@ func TestRuntimeEventStopReasonDecidedHandlerBranches(t *testing.T) { app.state.ExecutionError = "" app.state.StatusText = "not-ready" runtimeEventStopReasonDecidedHandler(&app, agentruntime.RuntimeEvent{ - Payload: agentruntime.StopReasonDecidedPayload{Reason: agentruntime.StopReason("success")}, + Payload: agentruntime.StopReasonDecidedPayload{Reason: agentruntime.StopReasonCompleted}, }) if app.state.StatusText != statusReady { - t.Fatalf("expected success with empty execution error to set ready status") + t.Fatalf("expected completed with empty execution error to set ready status") } app.state.ExecutionError = "boom" app.state.StatusText = "" runtimeEventStopReasonDecidedHandler(&app, agentruntime.RuntimeEvent{ - Payload: agentruntime.StopReasonDecidedPayload{Reason: agentruntime.StopReason("success")}, + Payload: agentruntime.StopReasonDecidedPayload{Reason: agentruntime.StopReasonCompleted}, }) if app.state.StatusText == statusReady { - t.Fatalf("expected success branch to keep status unchanged when execution error exists") + t.Fatalf("expected completed branch to keep status unchanged when execution error exists") } runtimeEventStopReasonDecidedHandler(&app, agentruntime.RuntimeEvent{ - Payload: agentruntime.StopReasonDecidedPayload{Reason: agentruntime.StopReason("canceled")}, + Payload: agentruntime.StopReasonDecidedPayload{Reason: agentruntime.StopReasonUserInterrupt}, }) if app.state.ExecutionError != "" || app.state.StatusText != statusCanceled { t.Fatalf("expected canceled state to clear error and set canceled status") } runtimeEventStopReasonDecidedHandler(&app, agentruntime.RuntimeEvent{ - Payload: agentruntime.StopReasonDecidedPayload{Reason: agentruntime.StopReason("error"), Detail: " "}, + Payload: agentruntime.StopReasonDecidedPayload{Reason: agentruntime.StopReasonBudgetExceeded}, + }) + if app.state.ExecutionError != "" || app.state.StatusText != "Context budget exceeded" { + t.Fatalf("expected budget stop without execution error, got status=%q err=%q", app.state.StatusText, app.state.ExecutionError) + } + + runtimeEventStopReasonDecidedHandler(&app, agentruntime.RuntimeEvent{ + Payload: agentruntime.StopReasonDecidedPayload{Reason: agentruntime.StopReasonFatalError, Detail: " "}, }) if app.state.StatusText != "runtime stopped" || app.state.ExecutionError != "runtime stopped" { - t.Fatalf("expected default stop detail, got status=%q err=%q", app.state.StatusText, app.state.ExecutionError) + t.Fatalf("expected fatal stop default detail, got status=%q err=%q", app.state.StatusText, app.state.ExecutionError) } runtimeEventStopReasonDecidedHandler(&app, agentruntime.RuntimeEvent{ - Payload: agentruntime.StopReasonDecidedPayload{Reason: agentruntime.StopReason("error"), Detail: "explicit failure"}, + Payload: agentruntime.StopReasonDecidedPayload{Reason: agentruntime.StopReasonFatalError, Detail: "explicit failure"}, }) if app.state.StatusText != "explicit failure" || app.state.ExecutionError != "explicit failure" { - t.Fatalf("expected explicit stop detail to be surfaced") + t.Fatalf("expected explicit fatal stop detail to be surfaced") + } + + runtimeEventStopReasonDecidedHandler(&app, agentruntime.RuntimeEvent{ + Payload: agentruntime.StopReasonDecidedPayload{Reason: agentruntime.StopReason("STOP_UNKNOWN")}, + }) + if !strings.Contains(app.state.ExecutionError, "unknown stop reason") { + t.Fatalf("expected unknown stop reason error, got %q", app.state.ExecutionError) } } diff --git a/internal/tui/core/app/update_test.go b/internal/tui/core/app/update_test.go index 19caf9e8..715550d8 100644 --- a/internal/tui/core/app/update_test.go +++ b/internal/tui/core/app/update_test.go @@ -1670,6 +1670,38 @@ func TestRuntimeEventUsageHandler(t *testing.T) { } } +func TestRuntimeEventTokenUsageHandler(t *testing.T) { + app, _ := newTestApp(t) + app.state.TokenUsage.RunInputTokens = 2 + app.state.TokenUsage.RunOutputTokens = 3 + app.state.TokenUsage.RunTotalTokens = 5 + + payload := agentruntime.TokenUsagePayload{ + InputTokens: 7, + OutputTokens: 11, + SessionInputTokens: 17, + SessionOutputTokens: 19, + HasUnknownUsage: true, + } + handled := runtimeEventTokenUsageHandler(&app, agentruntime.RuntimeEvent{Payload: payload}) + if handled { + t.Fatalf("expected false") + } + if app.state.TokenUsage.RunInputTokens != 9 || + app.state.TokenUsage.RunOutputTokens != 14 || + app.state.TokenUsage.RunTotalTokens != 23 { + t.Fatalf("unexpected run token usage: %+v", app.state.TokenUsage) + } + if app.state.TokenUsage.SessionInputTokens != 17 || + app.state.TokenUsage.SessionOutputTokens != 19 || + app.state.TokenUsage.SessionTotalTokens != 36 { + t.Fatalf("unexpected session token usage: %+v", app.state.TokenUsage) + } + if runtimeEventTokenUsageHandler(&app, agentruntime.RuntimeEvent{Payload: "invalid"}) { + t.Fatalf("invalid token usage payload should return false") + } +} + func TestRuntimeEventToolCallThinkingHandler(t *testing.T) { app, _ := newTestApp(t) handled := runtimeEventToolCallThinkingHandler(&app, agentruntime.RuntimeEvent{Payload: "bash"}) diff --git a/internal/tui/services/gateway_rpc_client_additional_test.go b/internal/tui/services/gateway_rpc_client_additional_test.go index a9baf1c5..678cf5e2 100644 --- a/internal/tui/services/gateway_rpc_client_additional_test.go +++ b/internal/tui/services/gateway_rpc_client_additional_test.go @@ -19,7 +19,6 @@ import ( "time" "neo-code/internal/gateway" - gatewayauth "neo-code/internal/gateway/auth" "neo-code/internal/gateway/protocol" ) @@ -913,7 +912,7 @@ func TestGatewayAutoSpawnOutputFallbackAndPath(t *testing.T) { if err != nil { t.Fatalf("resolveGatewayAutoSpawnLogPath() error = %v", err) } - if !strings.HasSuffix(path, defaultGatewayAutoSpawnLogRelativePath) { + if !strings.HasSuffix(filepath.Clean(path), filepath.Clean(filepath.FromSlash(defaultGatewayAutoSpawnLogRelativePath))) { t.Fatalf("log path = %q", path) } }) @@ -1083,13 +1082,7 @@ func TestGatewayRPCClientAuthenticateLoadsTokenAfterGatewayAutoSpawn(t *testing. ListenAddress: "test://gateway", TokenFile: tokenFile, AutoSpawnGateway: func(_ context.Context, _ string, _ func(address string) (net.Conn, error)) (*exec.Cmd, error) { - manager, createErr := gatewayauth.NewManager(tokenFile) - if createErr != nil { - return nil, createErr - } - if strings.TrimSpace(manager.Token()) == "" { - return nil, errors.New("created token is empty") - } + writeTestAuthTokenFile(t, tokenFile, "auto-spawn-token") return nil, nil }, Dial: func(_ string) (net.Conn, error) { @@ -1203,6 +1196,9 @@ func TestGatewayAutoSpawnLogErrorBranches(t *testing.T) { }) t.Run("open log file returns open error", func(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("directory permission assertions are not reliable on Windows") + } base := t.TempDir() readonlyDir := filepath.Join(base, "ro") if err := os.MkdirAll(readonlyDir, 0o700); err != nil { @@ -1220,6 +1216,9 @@ func TestGatewayAutoSpawnLogErrorBranches(t *testing.T) { }) t.Run("rotate stat error", func(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("directory permission assertions are not reliable on Windows") + } base := t.TempDir() locked := filepath.Join(base, "locked") if err := os.MkdirAll(locked, 0o700); err != nil { diff --git a/internal/tui/services/gateway_rpc_client_test.go b/internal/tui/services/gateway_rpc_client_test.go index a747b2f7..97238a1c 100644 --- a/internal/tui/services/gateway_rpc_client_test.go +++ b/internal/tui/services/gateway_rpc_client_test.go @@ -6,6 +6,7 @@ import ( "errors" "io" "net" + "os" "path/filepath" "strings" "sync/atomic" @@ -13,7 +14,6 @@ import ( "time" "neo-code/internal/gateway" - gatewayauth "neo-code/internal/gateway/auth" "neo-code/internal/gateway/protocol" ) @@ -396,11 +396,30 @@ func TestGatewayRPCClientReadLoopSustainsBackpressureWhenNotificationsAreConsume func createTestAuthTokenFile(t *testing.T) (string, string) { t.Helper() path := filepath.Join(t.TempDir(), "auth.json") - manager, err := gatewayauth.NewManager(path) + token := "test-token" + writeTestAuthTokenFile(t, path, token) + return path, token +} + +func writeTestAuthTokenFile(t *testing.T, path string, token string) { + t.Helper() + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + t.Fatalf("create auth dir: %v", err) + } + payload := map[string]any{ + "version": 1, + "token": token, + "created_at": time.Now().UTC(), + "updated_at": time.Now().UTC(), + } + data, err := json.MarshalIndent(payload, "", " ") if err != nil { - t.Fatalf("gatewayauth.NewManager() error = %v", err) + t.Fatalf("marshal auth token: %v", err) + } + data = append(data, '\n') + if err := os.WriteFile(path, data, 0o600); err != nil { + t.Fatalf("write auth token: %v", err) } - return path, manager.Token() } func readRPCRequestOrFail(t *testing.T, decoder *json.Decoder) protocol.JSONRPCRequest { diff --git a/internal/tui/services/gateway_stream_client.go b/internal/tui/services/gateway_stream_client.go index ff7b5512..492dc46a 100644 --- a/internal/tui/services/gateway_stream_client.go +++ b/internal/tui/services/gateway_stream_client.go @@ -189,6 +189,8 @@ func restoreRuntimePayload(eventType EventType, payload any) (any, error) { return decodeRuntimePayload[CompactResult](payload) case EventCompactError: return decodeRuntimePayload[CompactErrorPayload](payload) + case EventTokenUsage: + return decodeRuntimePayload[TokenUsagePayload](payload) case EventPhaseChanged: return decodeRuntimePayload[PhaseChangedPayload](payload) case EventStopReasonDecided: diff --git a/internal/tui/services/gateway_stream_client_additional_test.go b/internal/tui/services/gateway_stream_client_additional_test.go index 94316eed..02def539 100644 --- a/internal/tui/services/gateway_stream_client_additional_test.go +++ b/internal/tui/services/gateway_stream_client_additional_test.go @@ -123,14 +123,14 @@ func TestRestoreRuntimePayloadCoversSpecializedTypes(t *testing.T) { { name: "stop reason", eventType: EventStopReasonDecided, - payload: map[string]any{"reason": " max_rounds "}, + payload: map[string]any{"reason": " STOP_COMPLETED "}, assertFn: func(t *testing.T, got any) { t.Helper() value, ok := got.(StopReasonDecidedPayload) if !ok { t.Fatalf("payload type = %T", got) } - if value.Reason != StopReason("max_rounds") { + if value.Reason != StopReasonCompleted { t.Fatalf("reason = %q", value.Reason) } }, @@ -146,6 +146,28 @@ func TestRestoreRuntimePayloadCoversSpecializedTypes(t *testing.T) { } }, }, + { + name: "token usage payload", + eventType: EventTokenUsage, + payload: map[string]any{ + "input_tokens": 3, + "output_tokens": 5, + "session_input_tokens": 13, + "session_output_tokens": 21, + "has_unknown_usage": true, + }, + assertFn: func(t *testing.T, got any) { + t.Helper() + value, ok := got.(TokenUsagePayload) + if !ok { + t.Fatalf("payload type = %T", got) + } + if value.InputTokens != 3 || value.OutputTokens != 5 || + value.SessionInputTokens != 13 || value.SessionOutputTokens != 21 || !value.HasUnknownUsage { + t.Fatalf("payload = %#v", value) + } + }, + }, { name: "string payload", eventType: EventToolChunk, diff --git a/internal/tui/services/runtime_bridge.go b/internal/tui/services/runtime_bridge.go index 95b09034..88e41b13 100644 --- a/internal/tui/services/runtime_bridge.go +++ b/internal/tui/services/runtime_bridge.go @@ -330,6 +330,17 @@ func MapUsagePayload(payload RuntimeUsagePayload) TokenUsageVM { } } +// MapTokenUsagePayload 将当前 token_usage 事件累计进 TUI token 视图。 +func MapTokenUsagePayload(payload TokenUsagePayload, current TokenUsageVM) TokenUsageVM { + current.RunInputTokens += payload.InputTokens + current.RunOutputTokens += payload.OutputTokens + current.RunTotalTokens = current.RunInputTokens + current.RunOutputTokens + current.SessionInputTokens = payload.SessionInputTokens + current.SessionOutputTokens = payload.SessionOutputTokens + current.SessionTotalTokens = payload.SessionInputTokens + payload.SessionOutputTokens + return current +} + // MapUsageSnapshot 将 usage 快照映射为 TokenUsageVM(保留当前 run 统计不变)。 func MapUsageSnapshot(snapshot RuntimeUsageSnapshot, current TokenUsageVM) TokenUsageVM { current.SessionInputTokens = snapshot.InputTokens diff --git a/internal/tui/services/runtime_contract.go b/internal/tui/services/runtime_contract.go index 003f968e..520b19ad 100644 --- a/internal/tui/services/runtime_contract.go +++ b/internal/tui/services/runtime_contract.go @@ -184,12 +184,34 @@ type PhaseChangedPayload struct { // StopReason 表示运行终止原因。 type StopReason string +const ( + // StopReasonCompleted 表示 runtime 当前协议中的正常完成原因。 + StopReasonCompleted StopReason = "STOP_COMPLETED" + // StopReasonUserInterrupt 表示 runtime 当前协议中的用户中断原因。 + StopReasonUserInterrupt StopReason = "STOP_USER_INTERRUPT" + // StopReasonFatalError 表示 runtime 当前协议中的不可恢复错误原因。 + StopReasonFatalError StopReason = "STOP_FATAL_ERROR" + // StopReasonBudgetExceeded 表示 runtime 当前协议中的预算超限停止原因。 + StopReasonBudgetExceeded StopReason = "STOP_BUDGET_EXCEEDED" +) + // StopReasonDecidedPayload 描述停止原因决策结果。 type StopReasonDecidedPayload struct { Reason StopReason `json:"reason"` Detail string `json:"detail,omitempty"` } +// TokenUsagePayload 描述 runtime 当前 token_usage 事件载荷。 +type TokenUsagePayload struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + InputSource string `json:"input_source,omitempty"` + OutputSource string `json:"output_source,omitempty"` + HasUnknownUsage bool `json:"has_unknown_usage,omitempty"` + SessionInputTokens int `json:"session_input_tokens"` + SessionOutputTokens int `json:"session_output_tokens"` +} + // TodoEventPayload 描述 todo 相关事件载荷。 type TodoEventPayload struct { Action string `json:"action"` diff --git a/internal/tui/tui.go b/internal/tui/tui.go index d7f18092..c508fcf6 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -1,33 +1,6 @@ package tui -import ( - "neo-code/internal/config" - "neo-code/internal/memo" - tuibootstrap "neo-code/internal/tui/bootstrap" - tuiapp "neo-code/internal/tui/core/app" - tuiservices "neo-code/internal/tui/services" -) +import tuiapp "neo-code/internal/tui/core/app" type App = tuiapp.App type ProviderController = tuiapp.ProviderController - -// New 保留 internal/tui 对外入口,内部实现转发到分层后的 core/app。 -func New(cfg *config.Config, configManager *config.Manager, runtime tuiservices.Runtime, providerSvc ProviderController) (App, error) { - return tuiapp.New(cfg, configManager, runtime, providerSvc) -} - -// NewWithMemo 创建带 memo 服务的 TUI App。 -func NewWithMemo( - cfg *config.Config, - configManager *config.Manager, - runtime tuiservices.Runtime, - providerSvc ProviderController, - memoSvc *memo.Service, -) (App, error) { - return tuiapp.NewWithMemo(cfg, configManager, runtime, providerSvc, memoSvc) -} - -// NewWithBootstrap 保留对外注入入口,内部转发到 core/app。 -func NewWithBootstrap(options tuibootstrap.Options) (App, error) { - return tuiapp.NewWithBootstrap(options) -} diff --git a/internal/tui/tui_test.go b/internal/tui/tui_test.go index bb04f7bc..2c80f4f6 100644 --- a/internal/tui/tui_test.go +++ b/internal/tui/tui_test.go @@ -1,12 +1,6 @@ package tui -import ( - "testing" - - "neo-code/internal/config" - "neo-code/internal/memo" - tuibootstrap "neo-code/internal/tui/bootstrap" -) +import "testing" func TestAppTypeAlias(t *testing.T) { var _ App = App{} @@ -15,30 +9,3 @@ func TestAppTypeAlias(t *testing.T) { func TestProviderControllerTypeAlias(t *testing.T) { var _ ProviderController = ProviderController(nil) } - -func TestNewForwardsToCore(t *testing.T) { - t.Run("nil config", func(t *testing.T) { - _, err := New(nil, &config.Manager{}, nil, nil) - if err == nil { - t.Error("expected error for nil runtime") - } - }) -} - -func TestNewWithBootstrapForwardsToCore(t *testing.T) { - t.Run("empty options", func(t *testing.T) { - _, err := NewWithBootstrap(tuibootstrap.Options{}) - if err == nil { - t.Error("expected error for empty options") - } - }) -} - -func TestNewWithMemoForwardsToCore(t *testing.T) { - t.Run("nil runtime", func(t *testing.T) { - _, err := NewWithMemo(nil, &config.Manager{}, nil, nil, &memo.Service{}) - if err == nil { - t.Error("expected error for nil runtime") - } - }) -} diff --git a/scripts/migrate_context_budget/main.go b/scripts/migrate_context_budget/main.go new file mode 100644 index 00000000..97c9cfc4 --- /dev/null +++ b/scripts/migrate_context_budget/main.go @@ -0,0 +1,54 @@ +package main + +import ( + "flag" + "fmt" + "os" + "path/filepath" + "strings" + + "neo-code/internal/config" +) + +const defaultNeoCodeDirName = ".neocode" + +// main 解析命令行参数并调用正式配置迁移实现。 +func main() { + baseDir := flag.String("base-dir", defaultBaseDir(), "NeoCode 配置根目录,默认为 ~/.neocode") + target := flag.String("target", "", "指定要迁移的 config.yaml;为空时使用 /config.yaml") + dryRun := flag.Bool("dry-run", false, "只检查是否需要迁移,不写入文件") + flag.Parse() + + path := strings.TrimSpace(*target) + if path == "" { + path = filepath.Join(strings.TrimSpace(*baseDir), "config.yaml") + } + result, err := config.MigrateContextBudgetConfigFile(path, *dryRun) + if err != nil { + _, _ = fmt.Fprintf(os.Stderr, "迁移失败: %v\n", err) + os.Exit(1) + } + printMigrationResult(result, *dryRun) +} + +// printMigrationResult 输出迁移结果,保持脚本与打包 CLI 的用户提示一致。 +func printMigrationResult(result config.ContextBudgetMigrationResult, dryRun bool) { + if !result.Changed { + fmt.Printf("跳过: %s (%s)\n", result.Path, result.Reason) + return + } + if dryRun { + fmt.Printf("[DRY-RUN] 将迁移 %s\n", result.Path) + return + } + fmt.Printf("已迁移 %s (备份: %s)\n", result.Path, result.Backup) +} + +// defaultBaseDir 返回当前用户目录下的默认 NeoCode 配置目录。 +func defaultBaseDir() string { + home, err := os.UserHomeDir() + if err != nil || strings.TrimSpace(home) == "" { + return filepath.Join("~", defaultNeoCodeDirName) + } + return filepath.Join(home, defaultNeoCodeDirName) +} diff --git a/scripts/migrate_context_budget/main_test.go b/scripts/migrate_context_budget/main_test.go new file mode 100644 index 00000000..dff132e4 --- /dev/null +++ b/scripts/migrate_context_budget/main_test.go @@ -0,0 +1,13 @@ +package main + +import ( + "testing" +) + +func TestDefaultBaseDirReturnsPath(t *testing.T) { + t.Parallel() + + if got := defaultBaseDir(); got == "" { + t.Fatal("expected non-empty default base dir") + } +} From 6b291649e8d49112e720450d01858649196a0ab9 Mon Sep 17 00:00:00 2001 From: xgopilot Date: Wed, 22 Apr 2026 16:09:53 +0000 Subject: [PATCH 2/9] refactor(runtime): complete budget preflight and payload v3 contract - move config migration side effects from Loader to bootstrap preflight - add migration notes for deprecated auto_compact.enabled:false - wire budget decision to estimate accuracy with reason constants - add estimate_accurate to budget_checked payload - enforce payload_version=3 in TUI stream decode - update docs and tests for new behavior Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: phantom5099 <245659304+phantom5099@users.noreply.github.com> --- docs/context-compact.md | 3 +- docs/guides/configuration.md | 9 +- docs/runtime-provider-event-flow.md | 6 +- internal/app/bootstrap.go | 49 ++++++++-- internal/app/bootstrap_test.go | 76 ++++++++++++++++ internal/cli/migrate_command.go | 5 +- internal/cli/migrate_command_test.go | 32 +++++++ internal/config/context_budget_migration.go | 62 +++++++++---- .../config/context_budget_migration_test.go | 90 ++++++++++++++++++- internal/config/loader.go | 3 - internal/config/loader_test.go | 47 +++------- internal/runtime/controlplane/budget.go | 32 ++++++- internal/runtime/controlplane/budget_test.go | 86 ++++++++++++++++++ internal/runtime/controlplane/envelope.go | 2 +- internal/runtime/events.go | 2 + internal/runtime/runtime_test.go | 39 +++++--- .../tui/services/gateway_stream_client.go | 9 ++ .../gateway_stream_client_additional_test.go | 35 ++++++++ .../services/gateway_stream_client_test.go | 4 +- scripts/migrate_context_budget/main.go | 3 + 20 files changed, 510 insertions(+), 84 deletions(-) create mode 100644 internal/runtime/controlplane/budget_test.go diff --git a/docs/context-compact.md b/docs/context-compact.md index c5380c6a..0da33c6c 100644 --- a/docs/context-compact.md +++ b/docs/context-compact.md @@ -69,7 +69,8 @@ BuildRequest -> FreezeSnapshot -> EstimateInput -> DecideBudget -> (allow | comp - `context.Builder` 只构建 provider-facing request,不再返回旧的 builder 压缩建议布尔值。 - provider 发送前一定先做输入 token estimate。 - estimate 首次超预算时,runtime 执行一次 `proactive` compact,然后重建 request 并重新估算。 -- compact 后仍超预算时,runtime 直接停止本次 run,并返回 `STOP_BUDGET_EXCEEDED`。 +- compact 后仍超预算且估算高置信(`accurate=true`)时,runtime 停止本次 run,并返回 `STOP_BUDGET_EXCEEDED`。 +- compact 后仍超预算但估算低置信(`accurate=false`)时,runtime 继续发送请求,不因低置信估算直接硬停。 - provider 返回 `context_too_long` 时,runtime 触发 `reactive` compact,并重新进入同一预算闭环。 ## compact 如何压缩 diff --git a/docs/guides/configuration.md b/docs/guides/configuration.md index b9a18ec8..cb9ed7f0 100644 --- a/docs/guides/configuration.md +++ b/docs/guides/configuration.md @@ -101,10 +101,12 @@ prompt_budget = context_window - reserve_tokens ## 配置结构升级 -启动时会在严格解析 `config.yaml` 前执行一次结构升级: +启动装配阶段会在严格解析 `config.yaml` 前执行一次 preflight 结构升级: - 仅当检测到 `context.auto_compact` 时,自动迁移为 `context.budget`。 - 迁移前会写入 `config.yaml.bak`,原配置内容保留在备份中。 +- 如果旧配置显式 `context.auto_compact.enabled: false`,迁移仍会执行,并记录说明: + `旧 context.auto_compact.enabled 已废弃,新预算门禁不可关闭`。 - 如果 `context.auto_compact` 与 `context.budget` 同时存在,程序会直接报错,避免猜测覆盖用户配置。 - 主解析器仍只接受当前结构;迁移完成后不会在运行时兼容旧字段。 @@ -123,7 +125,8 @@ BuildRequest -> FreezeSnapshot -> EstimateInput -> DecideBudget -> (allow | comp - provider 发送前一定先做输入 token estimate。 - 如果 estimate 没超过 `prompt_budget`,本轮允许发送。 - 如果 estimate 首次超预算,先执行一次 `proactive` compact,然后重建请求并重新估算。 -- 如果 compact 后仍超预算,直接停止当前 run,并产出 `STOP_BUDGET_EXCEEDED`。 +- 如果 compact 后仍超预算且估算为高置信(`accurate=true`),停止当前 run,并产出 `STOP_BUDGET_EXCEEDED`。 +- 如果 compact 后仍超预算但估算为低置信(`accurate=false`),不直接硬停,继续发送请求。 - 如果 provider 返回 `context_too_long`,runtime 会进入 `reactive` compact 恢复链路,并重新进入同一预算闭环。 ## provider 策略 @@ -228,7 +231,7 @@ go run ./cmd/neocode --workdir /path/to/workspace 当前版本会直接报未知字段或结构不匹配错误。处理方式是手动删除旧字段,而不是等待程序自动兼容。 -`context.auto_compact` 是例外:如果配置中只存在旧预算块,启动时会自动迁移为 `context.budget`;如果新旧预算块同时存在,则需要手动合并后再启动。 +`context.auto_compact` 是例外:如果配置中只存在旧预算块,启动 preflight 会自动迁移为 `context.budget`;如果新旧预算块同时存在,则需要手动合并后再启动。 ### API Key 未设置 diff --git a/docs/runtime-provider-event-flow.md b/docs/runtime-provider-event-flow.md index de876971..bcea98b4 100644 --- a/docs/runtime-provider-event-flow.md +++ b/docs/runtime-provider-event-flow.md @@ -27,6 +27,8 @@ - `compact_applied` - `compact_error` +当前事件 envelope 的唯一有效 `payload_version` 为 `3`。 + ## ReAct 主循环 单次 run 的主链路为: @@ -61,12 +63,14 @@ runtime 不再消费旧的 builder 压缩建议,而是使用冻结快照上的 - `estimated_input_tokens` - `prompt_budget` - `estimate_source` +- `estimate_accurate` 语义: - `allow`:本轮请求在预算内 - `compact`:首次超预算,需要先压缩 -- `stop`:压缩后仍超预算,停止当前 run +- `stop`:压缩后仍超预算且估算高置信,停止当前 run +- `allow` + `reason=exceeds_budget_inaccurate_after_compact_allow`:压缩后仍超预算但估算低置信,继续放行 ## Context Builder 职责 diff --git a/internal/app/bootstrap.go b/internal/app/bootstrap.go index 9e2496ce..31f8f232 100644 --- a/internal/app/bootstrap.go +++ b/internal/app/bootstrap.go @@ -3,6 +3,7 @@ package app import ( "context" "log" + "os" "path/filepath" "strings" "time" @@ -36,12 +37,14 @@ import ( const utf8CodePage = 65001 var ( - setConsoleOutputCodePage = platformSetConsoleOutputCodePage - setConsoleInputCodePage = platformSetConsoleInputCodePage - buildToolManagerFunc = buildToolManager - newRemoteRuntimeAdapter = defaultNewRemoteRuntimeAdapter - newTUIWithMemo = tuiapp.NewWithMemo - cleanupExpiredSessions = func( + setConsoleOutputCodePage = platformSetConsoleOutputCodePage + setConsoleInputCodePage = platformSetConsoleInputCodePage + buildToolManagerFunc = buildToolManager + newRemoteRuntimeAdapter = defaultNewRemoteRuntimeAdapter + newTUIWithMemo = tuiapp.NewWithMemo + runConfigMigrationPreflight = defaultRunConfigMigrationPreflight + bootstrapLogf = log.Printf + cleanupExpiredSessions = func( ctx context.Context, store agentsession.Store, maxAge time.Duration, @@ -274,6 +277,9 @@ func BuildSharedConfigDeps( loader := config.NewLoader("", defaultCfg) manager := config.NewManager(loader) + if err := runConfigMigrationPreflight(ctx, manager.ConfigPath()); err != nil { + return bootstrapSharedBundle{}, nil, nil, err + } if _, err := manager.Load(ctx); err != nil { return bootstrapSharedBundle{}, nil, nil, err } @@ -295,6 +301,37 @@ func BuildSharedConfigDeps( }, providerRegistry, modelCatalogs, nil } +// defaultRunConfigMigrationPreflight 在启动装配阶段执行 schema 迁移,并记录一次迁移结果。 +func defaultRunConfigMigrationPreflight(ctx context.Context, configPath string) error { + if err := ctx.Err(); err != nil { + return err + } + if _, err := os.Stat(configPath); err != nil { + if os.IsNotExist(err) { + return nil + } + return err + } + result, err := config.UpgradeConfigSchema(configPath) + if err != nil { + return err + } + if !result.Changed && len(result.Notes) == 0 { + return nil + } + if result.Changed { + if result.Backup != "" { + bootstrapLogf("config migration: migrated %s (backup: %s)", result.Path, result.Backup) + } else { + bootstrapLogf("config migration: migrated %s", result.Path) + } + } + for _, note := range result.Notes { + bootstrapLogf("config migration: note: %s", strings.TrimSpace(note)) + } + return nil +} + // BuildTUIClientDeps 构建 TUI 客户端依赖,仅保留配置与 Provider 选择,不创建本地 runtime/tool 栈。 func BuildTUIClientDeps(ctx context.Context, opts BootstrapOptions) (RuntimeBundle, error) { sharedDeps, _, _, err := BuildSharedConfigDeps(ctx, opts) diff --git a/internal/app/bootstrap_test.go b/internal/app/bootstrap_test.go index 04fa0a9d..878f88aa 100644 --- a/internal/app/bootstrap_test.go +++ b/internal/app/bootstrap_test.go @@ -60,6 +60,82 @@ func TestNewProgram(t *testing.T) { } } +func TestBuildSharedConfigDepsRunsConfigMigrationPreflight(t *testing.T) { + disableBuiltinProviderAPIKeys(t) + + home := t.TempDir() + t.Setenv("HOME", home) + t.Setenv("USERPROFILE", home) + t.Setenv("OPENAI_API_KEY", "test-key") + + configDir := filepath.Join(home, ".neocode") + if err := os.MkdirAll(configDir, 0o755); err != nil { + t.Fatalf("mkdir config dir: %v", err) + } + configPath := filepath.Join(configDir, "config.yaml") + raw := strings.TrimSpace(` +selected_provider: openai +current_model: gpt-5.4 +shell: powershell +context: + auto_compact: + enabled: false + reserve_tokens: 14000 +`) + "\n" + if err := os.WriteFile(configPath, []byte(raw), 0o644); err != nil { + t.Fatalf("write config: %v", err) + } + + originalLogf := bootstrapLogf + t.Cleanup(func() { bootstrapLogf = originalLogf }) + var logs []string + bootstrapLogf = func(format string, args ...any) { + logs = append(logs, fmt.Sprintf(format, args...)) + } + + shared, _, _, err := BuildSharedConfigDeps(context.Background(), BootstrapOptions{}) + if err != nil { + t.Fatalf("BuildSharedConfigDeps() error = %v", err) + } + if shared.Config.Context.Budget.ReserveTokens != 14000 { + t.Fatalf("expected migrated reserve_tokens=14000, got %d", shared.Config.Context.Budget.ReserveTokens) + } + + migrated, err := os.ReadFile(configPath) + if err != nil { + t.Fatalf("read migrated config: %v", err) + } + text := string(migrated) + if strings.Contains(text, "auto_compact:") || !strings.Contains(text, "budget:") { + t.Fatalf("expected migrated budget block, got:\n%s", text) + } + if _, err := os.Stat(configPath + ".bak"); err != nil { + t.Fatalf("expected migration backup file: %v", err) + } + if len(logs) == 0 { + t.Fatalf("expected preflight migration logs") + } + joined := strings.Join(logs, "\n") + if !strings.Contains(joined, config.ContextBudgetMigrationNoteEnabledDeprecated) { + t.Fatalf("expected migration note log, got:\n%s", joined) + } +} + +func TestBuildSharedConfigDepsReturnsPreflightError(t *testing.T) { + disableBuiltinProviderAPIKeys(t) + + originalPreflight := runConfigMigrationPreflight + t.Cleanup(func() { runConfigMigrationPreflight = originalPreflight }) + runConfigMigrationPreflight = func(context.Context, string) error { + return errors.New("preflight failed") + } + + _, _, _, err := BuildSharedConfigDeps(context.Background(), BootstrapOptions{}) + if err == nil || !strings.Contains(err.Error(), "preflight failed") { + t.Fatalf("expected preflight error, got %v", err) + } +} + func TestNewProgramNormalizesInvalidCurrentModelOnStartup(t *testing.T) { disableBuiltinProviderAPIKeys(t) originalFactory := newRemoteRuntimeAdapter diff --git a/internal/cli/migrate_command.go b/internal/cli/migrate_command.go index 544ae758..f9f5c026 100644 --- a/internal/cli/migrate_command.go +++ b/internal/cli/migrate_command.go @@ -14,7 +14,7 @@ type migrateContextBudgetOptions struct { DryRun bool } -// newMigrateCommand 构建一次性迁移命令集合,迁移逻辑不接入主配置加载路径。 +// newMigrateCommand 构建一次性迁移命令集合,命令可手动触发,启动 preflight 也会自动执行迁移。 func newMigrateCommand() *cobra.Command { cmd := &cobra.Command{ Use: "migrate", @@ -55,6 +55,9 @@ func newMigrateContextBudgetCommand() *cobra.Command { // printContextBudgetMigrationResult 输出迁移结果,确保 dry-run 和真实写入提示保持一致。 func printContextBudgetMigrationResult(cmd *cobra.Command, result config.ContextBudgetMigrationResult, dryRun bool) { writer := cmd.OutOrStdout() + for _, note := range result.Notes { + _, _ = fmt.Fprintf(writer, "说明: %s\n", strings.TrimSpace(note)) + } if !result.Changed { _, _ = fmt.Fprintf(writer, "跳过: %s (%s)\n", result.Path, result.Reason) return diff --git a/internal/cli/migrate_command_test.go b/internal/cli/migrate_command_test.go index 75880fa1..acdcad72 100644 --- a/internal/cli/migrate_command_test.go +++ b/internal/cli/migrate_command_test.go @@ -7,6 +7,8 @@ import ( "path/filepath" "strings" "testing" + + "neo-code/internal/config" ) func TestMigrateContextBudgetCommandDryRunSkipsGlobalHooks(t *testing.T) { @@ -94,3 +96,33 @@ func TestMigrateContextBudgetCommandWritesBackup(t *testing.T) { t.Fatalf("unexpected migrated config:\n%s", migrated) } } + +func TestMigrateContextBudgetCommandPrintsMigrationNotes(t *testing.T) { + originalPreload := runGlobalPreload + originalSilentCheck := runSilentUpdateCheck + t.Cleanup(func() { + runGlobalPreload = originalPreload + runSilentUpdateCheck = originalSilentCheck + }) + runGlobalPreload = func(context.Context) error { return nil } + runSilentUpdateCheck = func(context.Context) {} + + dir := t.TempDir() + target := filepath.Join(dir, "config.yaml") + original := "context:\n auto_compact:\n enabled: false\n reserve_tokens: 13000\n" + if err := os.WriteFile(target, []byte(original), 0o644); err != nil { + t.Fatalf("write config: %v", err) + } + + var stdout bytes.Buffer + cmd := NewRootCommand() + cmd.SetOut(&stdout) + cmd.SetArgs([]string{"migrate", "context-budget", "--config", target}) + if err := cmd.ExecuteContext(context.Background()); err != nil { + t.Fatalf("ExecuteContext() error = %v", err) + } + out := stdout.String() + if !strings.Contains(out, "说明: "+config.ContextBudgetMigrationNoteEnabledDeprecated) { + t.Fatalf("expected migration note in output, got %q", out) + } +} diff --git a/internal/config/context_budget_migration.go b/internal/config/context_budget_migration.go index cb0f2c2d..e053c252 100644 --- a/internal/config/context_budget_migration.go +++ b/internal/config/context_budget_migration.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "path/filepath" + "strings" "gopkg.in/yaml.v3" ) @@ -16,17 +17,22 @@ type ContextBudgetMigrationResult struct { Changed bool Backup string Reason string + Notes []string } +const ( + // ContextBudgetMigrationNoteEnabledDeprecated 标记旧开关被废弃且预算门禁不可关闭。 + ContextBudgetMigrationNoteEnabledDeprecated = "旧 context.auto_compact.enabled 已废弃,新预算门禁不可关闭" +) + // DefaultConfigPath 返回当前用户环境下的默认主配置文件路径。 func DefaultConfigPath() string { return filepath.Join(defaultBaseDir(), configName) } -// UpgradeConfigSchemaBeforeLoad 在严格解析配置前执行一次磁盘结构升级。 -func UpgradeConfigSchemaBeforeLoad(path string) error { - _, err := MigrateContextBudgetConfigFile(path, false) - return err +// UpgradeConfigSchema 执行配置 schema 升级并返回迁移结果。 +func UpgradeConfigSchema(path string) (ContextBudgetMigrationResult, error) { + return MigrateContextBudgetConfigFile(path, false) } // MigrateContextBudgetConfigFile 将 config.yaml 中的 context.auto_compact 迁移到 context.budget。 @@ -44,10 +50,11 @@ func MigrateContextBudgetConfigFile(path string, dryRun bool) (ContextBudgetMigr return result, fmt.Errorf("config: read migration target %s: %w", path, err) } - migrated, changed, err := MigrateContextBudgetConfigContent(raw) + migrated, changed, notes, err := MigrateContextBudgetConfigContent(raw) if err != nil { return result, fmt.Errorf("config: migrate %s: %w", path, err) } + result.Notes = append(result.Notes, notes...) if !changed { result.Reason = "未检测到 context.auto_compact" return result, nil @@ -69,44 +76,45 @@ func MigrateContextBudgetConfigFile(path string, dryRun bool) (ContextBudgetMigr return result, nil } -// MigrateContextBudgetConfigContent 将旧预算 YAML 块替换为当前预算 YAML 块。 -func MigrateContextBudgetConfigContent(raw []byte) ([]byte, bool, error) { +// MigrateContextBudgetConfigContent 将旧预算 YAML 块替换为当前预算 YAML 块,并返回迁移说明。 +func MigrateContextBudgetConfigContent(raw []byte) ([]byte, bool, []string, error) { if len(bytes.TrimSpace(raw)) == 0 { - return raw, false, nil + return raw, false, nil, nil } if !bytes.Contains(raw, []byte("auto_compact")) { - return raw, false, nil + return raw, false, nil, nil } var doc map[string]any if err := yaml.Unmarshal(raw, &doc); err != nil { - return nil, false, err + return nil, false, nil, err } contextValue, ok := doc["context"] if !ok { - return raw, false, nil + return raw, false, nil, nil } contextMap, ok := migrationStringMap(contextValue) if !ok { - return nil, false, errors.New("context must be a mapping") + return nil, false, nil, errors.New("context must be a mapping") } autoValue, hasAutoCompact := contextMap["auto_compact"] if !hasAutoCompact { - return raw, false, nil + return raw, false, nil, nil } if _, hasBudget := contextMap["budget"]; hasBudget { - return nil, false, errors.New("context.auto_compact and context.budget cannot both exist") + return nil, false, nil, errors.New("context.auto_compact and context.budget cannot both exist") } autoMap, ok := migrationStringMap(autoValue) if !ok { - return nil, false, errors.New("context.auto_compact must be a mapping") + return nil, false, nil, errors.New("context.auto_compact must be a mapping") } budgetMap := make(map[string]any) migrationMoveField(autoMap, budgetMap, "input_token_threshold", "prompt_budget") migrationMoveField(autoMap, budgetMap, "reserve_tokens", "reserve_tokens") migrationMoveField(autoMap, budgetMap, "fallback_input_token_threshold", "fallback_prompt_budget") + notes := collectContextBudgetMigrationNotes(autoMap) delete(contextMap, "auto_compact") contextMap["budget"] = budgetMap @@ -114,9 +122,29 @@ func MigrateContextBudgetConfigContent(raw []byte) ([]byte, bool, error) { out, err := yaml.Marshal(doc) if err != nil { - return nil, false, err + return nil, false, nil, err + } + return out, true, notes, nil +} + +// collectContextBudgetMigrationNotes 汇总迁移过程中需要提示给用户的行为变化说明。 +func collectContextBudgetMigrationNotes(autoCompact map[string]any) []string { + if value, ok := autoCompact["enabled"]; ok && migrationExplicitFalse(value) { + return []string{ContextBudgetMigrationNoteEnabledDeprecated} + } + return nil +} + +// migrationExplicitFalse 判断迁移字段是否显式配置为 false。 +func migrationExplicitFalse(value any) bool { + switch typed := value.(type) { + case bool: + return !typed + case string: + return strings.EqualFold(strings.TrimSpace(typed), "false") + default: + return false } - return out, true, nil } // migrationMoveField 在两个 YAML map 之间迁移字段名,不修改字段值。 diff --git a/internal/config/context_budget_migration_test.go b/internal/config/context_budget_migration_test.go index 153a498b..763463a2 100644 --- a/internal/config/context_budget_migration_test.go +++ b/internal/config/context_budget_migration_test.go @@ -21,13 +21,16 @@ context: fallback_input_token_threshold: 100000 `) + "\n") - out, changed, err := MigrateContextBudgetConfigContent(input) + out, changed, notes, err := MigrateContextBudgetConfigContent(input) if err != nil { t.Fatalf("MigrateContextBudgetConfigContent() error = %v", err) } if !changed { t.Fatal("expected migration change") } + if len(notes) != 0 { + t.Fatalf("expected no migration notes, got %v", notes) + } text := string(out) if strings.Contains(text, "auto_compact:") { t.Fatalf("expected auto_compact removed, got:\n%s", text) @@ -55,12 +58,65 @@ context: input_token_threshold: 120000 `) + "\n") - _, _, err := MigrateContextBudgetConfigContent(input) + _, _, _, err := MigrateContextBudgetConfigContent(input) if err == nil || !strings.Contains(err.Error(), "cannot both exist") { t.Fatalf("expected mixed block error, got %v", err) } } +func TestMigrateContextBudgetConfigContentAddsNoteWhenEnabledExplicitlyFalse(t *testing.T) { + t.Parallel() + + input := []byte(strings.TrimSpace(` +context: + auto_compact: + enabled: false + input_token_threshold: 120000 +`) + "\n") + + _, changed, notes, err := MigrateContextBudgetConfigContent(input) + if err != nil { + t.Fatalf("MigrateContextBudgetConfigContent() error = %v", err) + } + if !changed { + t.Fatal("expected migration change") + } + if len(notes) != 1 || notes[0] != ContextBudgetMigrationNoteEnabledDeprecated { + t.Fatalf("expected notes [%q], got %v", ContextBudgetMigrationNoteEnabledDeprecated, notes) + } +} + +func TestMigrateContextBudgetConfigContentNoNoteWhenEnabledTrueOrMissing(t *testing.T) { + t.Parallel() + + cases := []string{ + strings.TrimSpace(` +context: + auto_compact: + enabled: true + reserve_tokens: 13000 +`) + "\n", + strings.TrimSpace(` +context: + auto_compact: + reserve_tokens: 13000 +`) + "\n", + } + + for _, input := range cases { + _, changed, notes, err := MigrateContextBudgetConfigContent([]byte(input)) + if err != nil { + t.Fatalf("MigrateContextBudgetConfigContent() error = %v", err) + } + if !changed { + t.Fatal("expected migration change") + } + if len(notes) != 0 { + t.Fatalf("expected no notes, got %v", notes) + } + } +} + func TestMigrateContextBudgetConfigFileCreatesBackup(t *testing.T) { t.Parallel() @@ -82,6 +138,9 @@ context: if !result.Changed { t.Fatal("expected changed result") } + if len(result.Notes) != 0 { + t.Fatalf("expected no notes, got %v", result.Notes) + } if result.Backup == "" { t.Fatal("expected backup path") } @@ -93,3 +152,30 @@ context: t.Fatalf("expected backup to keep original content, got:\n%s", backup) } } + +func TestUpgradeConfigSchemaReturnsNotes(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + target := filepath.Join(dir, configName) + original := strings.TrimSpace(` +context: + auto_compact: + enabled: false + reserve_tokens: 13000 +`) + "\n" + if err := os.WriteFile(target, []byte(original), 0o644); err != nil { + t.Fatalf("write target: %v", err) + } + + result, err := UpgradeConfigSchema(target) + if err != nil { + t.Fatalf("UpgradeConfigSchema() error = %v", err) + } + if !result.Changed { + t.Fatal("expected changed result") + } + if len(result.Notes) != 1 || result.Notes[0] != ContextBudgetMigrationNoteEnabledDeprecated { + t.Fatalf("expected note %q, got %v", ContextBudgetMigrationNoteEnabledDeprecated, result.Notes) + } +} diff --git a/internal/config/loader.go b/internal/config/loader.go index fc2e0262..f559d9fb 100644 --- a/internal/config/loader.go +++ b/internal/config/loader.go @@ -117,9 +117,6 @@ func (l *Loader) Load(ctx context.Context) (*Config, error) { if err := ctx.Err(); err != nil { return nil, err } - if err := UpgradeConfigSchemaBeforeLoad(l.ConfigPath()); err != nil { - return nil, err - } data, err := os.ReadFile(l.ConfigPath()) if err != nil { diff --git a/internal/config/loader_test.go b/internal/config/loader_test.go index 121a975f..ad987c94 100644 --- a/internal/config/loader_test.go +++ b/internal/config/loader_test.go @@ -117,7 +117,7 @@ shell: powershell } } -func TestLoaderUpgradesContextBudgetBeforeStrictParse(t *testing.T) { +func TestLoaderDoesNotMigrateLegacyContextBudgetOnLoad(t *testing.T) { t.Parallel() loader := NewLoader(t.TempDir(), testDefaultConfig()) @@ -133,42 +133,23 @@ context: ` writeLoaderConfig(t, loader, raw) - cfg, err := loader.Load(context.Background()) - if err != nil { - t.Fatalf("Load() error = %v", err) - } - if cfg.Context.Budget.PromptBudget != 120000 { - t.Fatalf("expected prompt_budget migrated, got %d", cfg.Context.Budget.PromptBudget) - } - if cfg.Context.Budget.ReserveTokens != 13000 { - t.Fatalf("expected reserve_tokens migrated, got %d", cfg.Context.Budget.ReserveTokens) - } - if cfg.Context.Budget.FallbackPromptBudget != 100000 { - t.Fatalf("expected fallback_prompt_budget migrated, got %d", cfg.Context.Budget.FallbackPromptBudget) - } - - data, err := os.ReadFile(loader.ConfigPath()) - if err != nil { - t.Fatalf("read migrated config: %v", err) - } - text := string(data) - if strings.Contains(text, "auto_compact:") { - t.Fatalf("expected loader migration to remove auto_compact, got:\n%s", text) + _, err := loader.Load(context.Background()) + if err == nil || !strings.Contains(err.Error(), "field auto_compact not found") { + t.Fatalf("expected legacy auto_compact parse error, got %v", err) } - if !strings.Contains(text, "budget:") { - t.Fatalf("expected loader migration to persist budget block, got:\n%s", text) + if _, statErr := os.Stat(loader.ConfigPath() + ".bak"); !os.IsNotExist(statErr) { + t.Fatalf("expected no backup file written by loader, got %v", statErr) } - - backup, err := os.ReadFile(loader.ConfigPath() + ".bak") - if err != nil { - t.Fatalf("read migration backup: %v", err) + data, readErr := os.ReadFile(loader.ConfigPath()) + if readErr != nil { + t.Fatalf("read config: %v", readErr) } - if !strings.Contains(string(backup), "auto_compact:") { - t.Fatalf("expected backup to preserve original config, got:\n%s", backup) + if string(data) != strings.TrimSpace(raw)+"\n" { + t.Fatalf("loader should not rewrite config, got:\n%s", data) } } -func TestLoaderRejectsAmbiguousContextBudgetMigration(t *testing.T) { +func TestLoaderRejectsLegacyAndCurrentContextBudgetMixWithoutPreflight(t *testing.T) { t.Parallel() loader := NewLoader(t.TempDir(), testDefaultConfig()) @@ -185,8 +166,8 @@ context: writeLoaderConfig(t, loader, raw) _, err := loader.Load(context.Background()) - if err == nil || !strings.Contains(err.Error(), "context.auto_compact and context.budget cannot both exist") { - t.Fatalf("expected ambiguous migration error, got %v", err) + if err == nil || !strings.Contains(err.Error(), "field auto_compact not found") { + t.Fatalf("expected legacy auto_compact parse error, got %v", err) } } diff --git a/internal/runtime/controlplane/budget.go b/internal/runtime/controlplane/budget.go index d777b929..16a1915a 100644 --- a/internal/runtime/controlplane/budget.go +++ b/internal/runtime/controlplane/budget.go @@ -9,6 +9,19 @@ const ( TurnBudgetActionStop TurnBudgetAction = "stop" ) +const ( + // BudgetDecisionReasonWithinBudget 表示估算在预算范围内。 + BudgetDecisionReasonWithinBudget = "within_budget" + // BudgetDecisionReasonExceedsBudgetFirstTime 表示首次超预算,需要先 compact。 + BudgetDecisionReasonExceedsBudgetFirstTime = "exceeds_budget_first_time" + // BudgetDecisionReasonExceedsBudgetAfterCompact 表示高置信估算在 compact 后仍超预算,需要停止。 + BudgetDecisionReasonExceedsBudgetAfterCompact = "exceeds_budget_after_compact" + // BudgetDecisionReasonExceedsBudgetInaccurateFirstTime 表示低置信估算首次超预算,先 compact 再验证。 + BudgetDecisionReasonExceedsBudgetInaccurateFirstTime = "exceeds_budget_inaccurate_first_time" + // BudgetDecisionReasonExceedsBudgetInaccurateAfterCompactAllow 表示低置信估算 compact 后仍超预算但允许放行。 + BudgetDecisionReasonExceedsBudgetInaccurateAfterCompactAllow = "exceeds_budget_inaccurate_after_compact_allow" +) + // TurnBudgetID 标识一次冻结预算尝试,避免 estimate、decision 与 usage observation 串用。 type TurnBudgetID struct { AttemptSeq int `json:"attempt_seq"` @@ -31,6 +44,7 @@ type TurnBudgetDecision struct { EstimatedInputTokens int `json:"estimated_input_tokens"` PromptBudget int `json:"prompt_budget"` EstimateSource string `json:"estimate_source,omitempty"` + EstimateAccurate bool `json:"estimate_accurate"` } // DecideTurnBudget 根据输入预算事实输出 allow、compact 或 stop 三种动作。 @@ -44,18 +58,28 @@ func DecideTurnBudget( EstimatedInputTokens: estimate.EstimatedInputTokens, PromptBudget: promptBudget, EstimateSource: estimate.EstimateSource, + EstimateAccurate: estimate.Accurate, } if estimate.EstimatedInputTokens <= promptBudget { decision.Action = TurnBudgetActionAllow - decision.Reason = "within_budget" + decision.Reason = BudgetDecisionReasonWithinBudget return decision } if compactCount == 0 { decision.Action = TurnBudgetActionCompact - decision.Reason = "exceeds_budget_first_time" + if estimate.Accurate { + decision.Reason = BudgetDecisionReasonExceedsBudgetFirstTime + } else { + decision.Reason = BudgetDecisionReasonExceedsBudgetInaccurateFirstTime + } + return decision + } + if estimate.Accurate { + decision.Action = TurnBudgetActionStop + decision.Reason = BudgetDecisionReasonExceedsBudgetAfterCompact return decision } - decision.Action = TurnBudgetActionStop - decision.Reason = "exceeds_budget_after_compact" + decision.Action = TurnBudgetActionAllow + decision.Reason = BudgetDecisionReasonExceedsBudgetInaccurateAfterCompactAllow return decision } diff --git a/internal/runtime/controlplane/budget_test.go b/internal/runtime/controlplane/budget_test.go new file mode 100644 index 00000000..64b5380f --- /dev/null +++ b/internal/runtime/controlplane/budget_test.go @@ -0,0 +1,86 @@ +package controlplane + +import "testing" + +func TestDecideTurnBudgetAccurateBranches(t *testing.T) { + t.Parallel() + + baseEstimate := TurnBudgetEstimate{ + ID: TurnBudgetID{ + AttemptSeq: 1, + RequestHash: "hash-1", + }, + EstimatedInputTokens: 120, + EstimateSource: "provider", + Accurate: true, + } + + within := DecideTurnBudget(baseEstimate, 120, 0) + if within.Action != TurnBudgetActionAllow { + t.Fatalf("within.Action = %q", within.Action) + } + if within.Reason != BudgetDecisionReasonWithinBudget { + t.Fatalf("within.Reason = %q", within.Reason) + } + if !within.EstimateAccurate { + t.Fatalf("within.EstimateAccurate = false, want true") + } + + firstExceed := DecideTurnBudget(baseEstimate, 100, 0) + if firstExceed.Action != TurnBudgetActionCompact { + t.Fatalf("firstExceed.Action = %q", firstExceed.Action) + } + if firstExceed.Reason != BudgetDecisionReasonExceedsBudgetFirstTime { + t.Fatalf("firstExceed.Reason = %q", firstExceed.Reason) + } + if !firstExceed.EstimateAccurate { + t.Fatalf("firstExceed.EstimateAccurate = false, want true") + } + + afterCompact := DecideTurnBudget(baseEstimate, 100, 1) + if afterCompact.Action != TurnBudgetActionStop { + t.Fatalf("afterCompact.Action = %q", afterCompact.Action) + } + if afterCompact.Reason != BudgetDecisionReasonExceedsBudgetAfterCompact { + t.Fatalf("afterCompact.Reason = %q", afterCompact.Reason) + } + if !afterCompact.EstimateAccurate { + t.Fatalf("afterCompact.EstimateAccurate = false, want true") + } +} + +func TestDecideTurnBudgetInaccurateBranches(t *testing.T) { + t.Parallel() + + estimate := TurnBudgetEstimate{ + ID: TurnBudgetID{ + AttemptSeq: 2, + RequestHash: "hash-2", + }, + EstimatedInputTokens: 200, + EstimateSource: "local", + Accurate: false, + } + + firstExceed := DecideTurnBudget(estimate, 100, 0) + if firstExceed.Action != TurnBudgetActionCompact { + t.Fatalf("firstExceed.Action = %q", firstExceed.Action) + } + if firstExceed.Reason != BudgetDecisionReasonExceedsBudgetInaccurateFirstTime { + t.Fatalf("firstExceed.Reason = %q", firstExceed.Reason) + } + if firstExceed.EstimateAccurate { + t.Fatalf("firstExceed.EstimateAccurate = true, want false") + } + + afterCompact := DecideTurnBudget(estimate, 100, 1) + if afterCompact.Action != TurnBudgetActionAllow { + t.Fatalf("afterCompact.Action = %q", afterCompact.Action) + } + if afterCompact.Reason != BudgetDecisionReasonExceedsBudgetInaccurateAfterCompactAllow { + t.Fatalf("afterCompact.Reason = %q", afterCompact.Reason) + } + if afterCompact.EstimateAccurate { + t.Fatalf("afterCompact.EstimateAccurate = true, want false") + } +} diff --git a/internal/runtime/controlplane/envelope.go b/internal/runtime/controlplane/envelope.go index be700ed7..a1c65626 100644 --- a/internal/runtime/controlplane/envelope.go +++ b/internal/runtime/controlplane/envelope.go @@ -1,4 +1,4 @@ package controlplane // PayloadVersion 为 runtime 事件 envelope 的当前协议版本号。 -const PayloadVersion = 2 +const PayloadVersion = 3 diff --git a/internal/runtime/events.go b/internal/runtime/events.go index befceb31..0ff7adf8 100644 --- a/internal/runtime/events.go +++ b/internal/runtime/events.go @@ -36,6 +36,7 @@ type BudgetCheckedPayload struct { EstimatedInputTokens int `json:"estimated_input_tokens"` PromptBudget int `json:"prompt_budget"` EstimateSource string `json:"estimate_source,omitempty"` + EstimateAccurate bool `json:"estimate_accurate"` } // ProgressEvaluatedPayload 汇总 progress 控制面的评估结果。 @@ -70,6 +71,7 @@ func newBudgetCheckedPayload(decision controlplane.TurnBudgetDecision) BudgetChe EstimatedInputTokens: decision.EstimatedInputTokens, PromptBudget: decision.PromptBudget, EstimateSource: decision.EstimateSource, + EstimateAccurate: decision.EstimateAccurate, } } diff --git a/internal/runtime/runtime_test.go b/internal/runtime/runtime_test.go index 50c1a025..bbf48ac9 100644 --- a/internal/runtime/runtime_test.go +++ b/internal/runtime/runtime_test.go @@ -4476,7 +4476,7 @@ func TestResolvePromptBudgetFallsBackWhenResolverErrors(t *testing.T) { } } -func TestServiceRunStopsWhenBudgetStillExceededAfterProactiveCompact(t *testing.T) { +func TestServiceRunAllowsAfterProactiveCompactWhenEstimateInaccurate(t *testing.T) { t.Parallel() manager := newRuntimeConfigManager(t) @@ -4500,9 +4500,14 @@ func TestServiceRunStopsWhenBudgetStillExceededAfterProactiveCompact(t *testing. Accurate: false, }, nil }, - chatFn: func(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error { - t.Fatalf("Generate should not be called when budget decision stops before send") - return nil + responses: []scriptedResponse{ + { + Message: providertypes.Message{ + Role: providertypes.RoleAssistant, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("继续执行")}, + }, + FinishReason: "stop", + }, }, } @@ -4527,7 +4532,7 @@ func TestServiceRunStopsWhenBudgetStillExceededAfterProactiveCompact(t *testing. } if err := service.Run(context.Background(), UserInput{ - RunID: "run-budget-stop", + RunID: "run-budget-inaccurate-allow", Parts: []providertypes.ContentPart{providertypes.NewTextPart("continue")}, }); err != nil { t.Fatalf("Run() error = %v", err) @@ -4540,12 +4545,14 @@ func TestServiceRunStopsWhenBudgetStillExceededAfterProactiveCompact(t *testing. if compactRunner.calls[0].Mode != contextcompact.ModeProactive { t.Fatalf("expected compact mode %q, got %q", contextcompact.ModeProactive, compactRunner.calls[0].Mode) } - if scripted.callCount != 0 { - t.Fatalf("expected provider Generate to be skipped, got %d calls", scripted.callCount) + if scripted.callCount != 1 { + t.Fatalf("expected provider Generate to be called once, got %d calls", scripted.callCount) } events := collectRuntimeEvents(service.Events()) var budgetActions []string + var budgetReasons []string + var budgetAccuracies []bool var stopPayload StopReasonDecidedPayload for _, event := range events { switch event.Type { @@ -4555,6 +4562,8 @@ func TestServiceRunStopsWhenBudgetStillExceededAfterProactiveCompact(t *testing. t.Fatalf("expected BudgetCheckedPayload, got %T", event.Payload) } budgetActions = append(budgetActions, payload.Action) + budgetReasons = append(budgetReasons, payload.Reason) + budgetAccuracies = append(budgetAccuracies, payload.EstimateAccurate) case EventStopReasonDecided: payload, ok := event.Payload.(StopReasonDecidedPayload) if !ok { @@ -4564,11 +4573,19 @@ func TestServiceRunStopsWhenBudgetStillExceededAfterProactiveCompact(t *testing. } } - if len(budgetActions) != 2 || budgetActions[0] != "compact" || budgetActions[1] != "stop" { - t.Fatalf("expected budget actions [compact stop], got %v", budgetActions) + if len(budgetActions) != 2 || budgetActions[0] != "compact" || budgetActions[1] != "allow" { + t.Fatalf("expected budget actions [compact allow], got %v", budgetActions) + } + if len(budgetReasons) != 2 || + budgetReasons[0] != controlplane.BudgetDecisionReasonExceedsBudgetInaccurateFirstTime || + budgetReasons[1] != controlplane.BudgetDecisionReasonExceedsBudgetInaccurateAfterCompactAllow { + t.Fatalf("unexpected budget reasons %v", budgetReasons) + } + if len(budgetAccuracies) != 2 || budgetAccuracies[0] || budgetAccuracies[1] { + t.Fatalf("expected inaccurate estimates, got %v", budgetAccuracies) } - if stopPayload.Reason != controlplane.StopReasonBudgetExceeded { - t.Fatalf("expected stop reason %q, got %q", controlplane.StopReasonBudgetExceeded, stopPayload.Reason) + if stopPayload.Reason != controlplane.StopReasonCompleted { + t.Fatalf("expected stop reason %q, got %q", controlplane.StopReasonCompleted, stopPayload.Reason) } } diff --git a/internal/tui/services/gateway_stream_client.go b/internal/tui/services/gateway_stream_client.go index 492dc46a..afdf7c57 100644 --- a/internal/tui/services/gateway_stream_client.go +++ b/internal/tui/services/gateway_stream_client.go @@ -14,6 +14,8 @@ import ( "neo-code/internal/tools" ) +const runtimeEventPayloadVersion = 3 + // GatewayStreamClient 负责消费 gateway.event 并恢复为 TUI 事件。 type GatewayStreamClient struct { source <-chan gatewayRPCNotification @@ -122,6 +124,13 @@ func decodeRuntimeEventFromGatewayNotification(notification gatewayRPCNotificati if event.Timestamp.IsZero() { event.Timestamp = time.Now().UTC() } + if event.PayloadVersion != runtimeEventPayloadVersion { + return RuntimeEvent{}, fmt.Errorf( + "unsupported runtime payload_version: got %d want %d", + event.PayloadVersion, + runtimeEventPayloadVersion, + ) + } rawPayload, _ := streamReadMapValue(envelope, "payload") restoredPayload, err := restoreRuntimePayload(event.Type, rawPayload) diff --git a/internal/tui/services/gateway_stream_client_additional_test.go b/internal/tui/services/gateway_stream_client_additional_test.go index 02def539..25a6a660 100644 --- a/internal/tui/services/gateway_stream_client_additional_test.go +++ b/internal/tui/services/gateway_stream_client_additional_test.go @@ -3,6 +3,7 @@ package services import ( "encoding/json" "reflect" + "strings" "testing" "time" @@ -53,6 +54,7 @@ func TestDecodeRuntimeEventFromGatewayNotificationUsesCurrentTimeWhenTimestampMi Action: gateway.FrameActionRun, Payload: map[string]any{ "runtime_event_type": string(EventError), + "payload_version": runtimeEventPayloadVersion, "payload": "boom", }, }) @@ -67,6 +69,28 @@ func TestDecodeRuntimeEventFromGatewayNotificationUsesCurrentTimeWhenTimestampMi } } +func TestDecodeRuntimeEventFromGatewayNotificationRejectsPayloadVersionMismatch(t *testing.T) { + t.Parallel() + + notification := buildGatewayEventNotification(t, gateway.MessageFrame{ + Type: gateway.FrameTypeEvent, + Action: gateway.FrameActionRun, + Payload: map[string]any{ + "runtime_event_type": string(EventError), + "payload_version": runtimeEventPayloadVersion - 1, + "payload": "boom", + }, + }) + + _, err := decodeRuntimeEventFromGatewayNotification(notification) + if err == nil { + t.Fatalf("expected payload version mismatch error") + } + if got := err.Error(); got == "" || !containsAll(got, "payload_version", "want") { + t.Fatalf("unexpected error: %v", err) + } +} + func TestExtractRuntimeEnvelopeFallbackMarshalling(t *testing.T) { t.Parallel() @@ -271,6 +295,7 @@ func TestGatewayStreamClientRunSkipsNonGatewayEventsAndStopsOnClose(t *testing.T Action: gateway.FrameActionRun, Payload: map[string]any{ "runtime_event_type": string(EventAgentChunk), + "payload_version": runtimeEventPayloadVersion, "payload": "ok", }, }) @@ -292,6 +317,15 @@ func TestGatewayStreamClientRunSkipsNonGatewayEventsAndStopsOnClose(t *testing.T } } +func containsAll(input string, subs ...string) bool { + for _, sub := range subs { + if !strings.Contains(input, sub) { + return false + } + } + return true +} + func TestGatewayStreamClientRunStopsWhenSourceClosed(t *testing.T) { t.Parallel() @@ -436,6 +470,7 @@ func TestGatewayStreamDecodeAndEnvelopeExtraBranches(t *testing.T) { Action: gateway.FrameActionRun, Payload: map[string]any{ "runtime_event_type": string(EventToolResult), + "payload_version": runtimeEventPayloadVersion, "payload": "not-an-object", }, }) diff --git a/internal/tui/services/gateway_stream_client_test.go b/internal/tui/services/gateway_stream_client_test.go index 9656ccfa..bb37697f 100644 --- a/internal/tui/services/gateway_stream_client_test.go +++ b/internal/tui/services/gateway_stream_client_test.go @@ -22,7 +22,7 @@ func TestDecodeRuntimeEventFromGatewayNotificationRestoresStringPayload(t *testi "turn": 2, "phase": "thinking", "timestamp": timestamp.Format(time.RFC3339Nano), - "payload_version": 1, + "payload_version": runtimeEventPayloadVersion, "payload": "hello", }, }) @@ -57,6 +57,7 @@ func TestDecodeRuntimeEventFromGatewayNotificationRestoresToolResultPayload(t *t RunID: "run-2", Payload: map[string]any{ "runtime_event_type": string(EventToolResult), + "payload_version": runtimeEventPayloadVersion, "payload": map[string]any{ "ToolCallID": "call-1", "Name": "bash", @@ -89,6 +90,7 @@ func TestDecodeRuntimeEventFromGatewayNotificationSupportsNestedEnvelope(t *test "type": "run_progress", "payload": map[string]any{ "runtime_event_type": string(EventError), + "payload_version": runtimeEventPayloadVersion, "payload": "boom", }, }, diff --git a/scripts/migrate_context_budget/main.go b/scripts/migrate_context_budget/main.go index 97c9cfc4..55a9ae03 100644 --- a/scripts/migrate_context_budget/main.go +++ b/scripts/migrate_context_budget/main.go @@ -33,6 +33,9 @@ func main() { // printMigrationResult 输出迁移结果,保持脚本与打包 CLI 的用户提示一致。 func printMigrationResult(result config.ContextBudgetMigrationResult, dryRun bool) { + for _, note := range result.Notes { + fmt.Printf("说明: %s\n", strings.TrimSpace(note)) + } if !result.Changed { fmt.Printf("跳过: %s (%s)\n", result.Path, result.Reason) return From 82ac9222857025da627ff78d2b6d430033d1cece Mon Sep 17 00:00:00 2001 From: xgopilot Date: Wed, 22 Apr 2026 17:39:08 +0000 Subject: [PATCH 3/9] refactor(runtime): gate budget by policy and bump payload v4 - replace Accurate with GatePolicy in budget estimate and decisions - set builtin provider estimates to local+gateable - switch budget_checked payload to estimate_gate_policy - bump runtime/tui payload_version to 4 - add atomic config write helper and use it in migration/loader - add migration failure safety tests and provider/runtime regressions Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: phantom5099 <245659304+phantom5099@users.noreply.github.com> --- docs/context-compact.md | 4 +- docs/guides/configuration.md | 4 +- docs/runtime-provider-event-flow.md | 8 +- internal/app/bootstrap_test.go | 2 +- internal/config/atomic_write.go | 78 ++++++++++ internal/config/context_budget_migration.go | 4 +- .../config/context_budget_migration_test.go | 128 ++++++++++++++++ internal/config/loader.go | 2 +- internal/provider/anthropic/provider.go | 2 +- internal/provider/anthropic/provider_test.go | 34 +++++ internal/provider/estimate.go | 2 + internal/provider/gemini/provider.go | 2 +- internal/provider/gemini/provider_test.go | 34 +++++ internal/provider/generate_test.go | 2 +- .../openaicompat/openaicompat_test.go | 27 ++++ internal/provider/openaicompat/provider.go | 2 +- internal/provider/types/usage.go | 2 +- internal/runtime/budget_models.go | 6 +- internal/runtime/controlplane/budget.go | 35 ++--- internal/runtime/controlplane/budget_test.go | 32 ++-- internal/runtime/controlplane/envelope.go | 2 +- internal/runtime/events.go | 4 +- internal/runtime/runtime_test.go | 139 ++++++++++++++++-- .../tui/services/gateway_stream_client.go | 2 +- 24 files changed, 491 insertions(+), 66 deletions(-) create mode 100644 internal/config/atomic_write.go diff --git a/docs/context-compact.md b/docs/context-compact.md index 0da33c6c..c93d7c57 100644 --- a/docs/context-compact.md +++ b/docs/context-compact.md @@ -69,8 +69,8 @@ BuildRequest -> FreezeSnapshot -> EstimateInput -> DecideBudget -> (allow | comp - `context.Builder` 只构建 provider-facing request,不再返回旧的 builder 压缩建议布尔值。 - provider 发送前一定先做输入 token estimate。 - estimate 首次超预算时,runtime 执行一次 `proactive` compact,然后重建 request 并重新估算。 -- compact 后仍超预算且估算高置信(`accurate=true`)时,runtime 停止本次 run,并返回 `STOP_BUDGET_EXCEEDED`。 -- compact 后仍超预算但估算低置信(`accurate=false`)时,runtime 继续发送请求,不因低置信估算直接硬停。 +- compact 后仍超预算且 `gate_policy=gateable` 时,runtime 停止本次 run,并返回 `STOP_BUDGET_EXCEEDED`。 +- compact 后仍超预算但 `gate_policy=advisory` 时,runtime 继续发送请求,不直接硬停。 - provider 返回 `context_too_long` 时,runtime 触发 `reactive` compact,并重新进入同一预算闭环。 ## compact 如何压缩 diff --git a/docs/guides/configuration.md b/docs/guides/configuration.md index cb9ed7f0..6a4241f3 100644 --- a/docs/guides/configuration.md +++ b/docs/guides/configuration.md @@ -125,8 +125,8 @@ BuildRequest -> FreezeSnapshot -> EstimateInput -> DecideBudget -> (allow | comp - provider 发送前一定先做输入 token estimate。 - 如果 estimate 没超过 `prompt_budget`,本轮允许发送。 - 如果 estimate 首次超预算,先执行一次 `proactive` compact,然后重建请求并重新估算。 -- 如果 compact 后仍超预算且估算为高置信(`accurate=true`),停止当前 run,并产出 `STOP_BUDGET_EXCEEDED`。 -- 如果 compact 后仍超预算但估算为低置信(`accurate=false`),不直接硬停,继续发送请求。 +- 如果 compact 后仍超预算且 `gate_policy=gateable`,停止当前 run,并产出 `STOP_BUDGET_EXCEEDED`。 +- 如果 compact 后仍超预算但 `gate_policy=advisory`,不直接硬停,继续发送请求。 - 如果 provider 返回 `context_too_long`,runtime 会进入 `reactive` compact 恢复链路,并重新进入同一预算闭环。 ## provider 策略 diff --git a/docs/runtime-provider-event-flow.md b/docs/runtime-provider-event-flow.md index bcea98b4..d0966bd1 100644 --- a/docs/runtime-provider-event-flow.md +++ b/docs/runtime-provider-event-flow.md @@ -27,7 +27,7 @@ - `compact_applied` - `compact_error` -当前事件 envelope 的唯一有效 `payload_version` 为 `3`。 +当前事件 envelope 的唯一有效 `payload_version` 为 `4`。 ## ReAct 主循环 @@ -63,14 +63,14 @@ runtime 不再消费旧的 builder 压缩建议,而是使用冻结快照上的 - `estimated_input_tokens` - `prompt_budget` - `estimate_source` -- `estimate_accurate` +- `estimate_gate_policy` 语义: - `allow`:本轮请求在预算内 - `compact`:首次超预算,需要先压缩 -- `stop`:压缩后仍超预算且估算高置信,停止当前 run -- `allow` + `reason=exceeds_budget_inaccurate_after_compact_allow`:压缩后仍超预算但估算低置信,继续放行 +- `stop` + `reason=exceeds_budget_after_compact_stop`:压缩后仍超预算且估算可门禁(`gateable`),停止当前 run +- `allow` + `reason=exceeds_budget_after_compact_allow_advisory`:压缩后仍超预算但估算仅 advisory,继续放行 ## Context Builder 职责 diff --git a/internal/app/bootstrap_test.go b/internal/app/bootstrap_test.go index 878f88aa..adc9a182 100644 --- a/internal/app/bootstrap_test.go +++ b/internal/app/bootstrap_test.go @@ -1913,7 +1913,7 @@ func (s *stubMemoProvider) EstimateInputTokens( return providertypes.BudgetEstimate{ EstimatedInputTokens: provider.EstimateTextTokens(req.SystemPrompt), EstimateSource: provider.EstimateSourceLocal, - Accurate: false, + GatePolicy: provider.EstimateGateGateable, }, nil } diff --git a/internal/config/atomic_write.go b/internal/config/atomic_write.go new file mode 100644 index 00000000..5a1af5b6 --- /dev/null +++ b/internal/config/atomic_write.go @@ -0,0 +1,78 @@ +package config + +import ( + "bytes" + "errors" + "fmt" + "os" + "path/filepath" + "syscall" +) + +var ( + atomicCreateTemp = os.CreateTemp + atomicReadFile = os.ReadFile + atomicRename = os.Rename +) + +// writeFileAtomically 通过同目录临时文件与原子替换写入目标文件,并在写后做回读校验。 +func writeFileAtomically(path string, data []byte, perm os.FileMode) error { + dir := filepath.Dir(path) + pattern := "." + filepath.Base(path) + ".tmp-*" + tempFile, err := atomicCreateTemp(dir, pattern) + if err != nil { + return fmt.Errorf("create temp file: %w", err) + } + + tempPath := tempFile.Name() + cleanupTemp := true + defer func() { + if cleanupTemp { + _ = os.Remove(tempPath) + } + }() + + if _, err := tempFile.Write(data); err != nil { + _ = tempFile.Close() + return fmt.Errorf("write temp file: %w", err) + } + if err := tempFile.Sync(); err != nil { + _ = tempFile.Close() + return fmt.Errorf("sync temp file: %w", err) + } + if err := tempFile.Close(); err != nil { + return fmt.Errorf("close temp file: %w", err) + } + if err := os.Chmod(tempPath, perm); err != nil { + return fmt.Errorf("chmod temp file: %w", err) + } + if err := atomicRename(tempPath, path); err != nil { + return fmt.Errorf("rename temp file: %w", err) + } + cleanupTemp = false + + written, err := atomicReadFile(path) + if err != nil { + return fmt.Errorf("read back written file: %w", err) + } + if !bytes.Equal(written, data) { + return errors.New("read back mismatch") + } + if err := fsyncDirectory(dir); err != nil { + return fmt.Errorf("sync target directory: %w", err) + } + return nil +} + +// fsyncDirectory 尝试同步目录元数据,确保 rename 后的目录项在支持的平台尽快落盘。 +func fsyncDirectory(dir string) error { + handle, err := os.Open(dir) + if err != nil { + return err + } + defer handle.Close() + if err := handle.Sync(); err != nil && !errors.Is(err, syscall.EINVAL) && !errors.Is(err, os.ErrInvalid) { + return err + } + return nil +} diff --git a/internal/config/context_budget_migration.go b/internal/config/context_budget_migration.go index e053c252..2e5139c5 100644 --- a/internal/config/context_budget_migration.go +++ b/internal/config/context_budget_migration.go @@ -66,10 +66,10 @@ func MigrateContextBudgetConfigFile(path string, dryRun bool) (ContextBudgetMigr } backup := path + ".bak" - if err := os.WriteFile(backup, raw, 0o644); err != nil { + if err := writeFileAtomically(backup, raw, 0o644); err != nil { return result, fmt.Errorf("config: write migration backup %s: %w", backup, err) } - if err := os.WriteFile(path, migrated, 0o644); err != nil { + if err := writeFileAtomically(path, migrated, 0o644); err != nil { return result, fmt.Errorf("config: write migrated config %s: %w", path, err) } result.Backup = backup diff --git a/internal/config/context_budget_migration_test.go b/internal/config/context_budget_migration_test.go index 763463a2..2bbd38ef 100644 --- a/internal/config/context_budget_migration_test.go +++ b/internal/config/context_budget_migration_test.go @@ -1,6 +1,7 @@ package config import ( + "errors" "os" "path/filepath" "strings" @@ -179,3 +180,130 @@ context: t.Fatalf("expected note %q, got %v", ContextBudgetMigrationNoteEnabledDeprecated, result.Notes) } } + +func TestMigrateContextBudgetConfigFileKeepsOriginalWhenBackupWriteFails(t *testing.T) { + dir := t.TempDir() + target := filepath.Join(dir, configName) + original := strings.TrimSpace(` +context: + auto_compact: + input_token_threshold: 120000 +`) + "\n" + if err := os.WriteFile(target, []byte(original), 0o644); err != nil { + t.Fatalf("write target: %v", err) + } + + restore := stubAtomicWriteOps(t) + defer restore() + atomicCreateTemp = func(dir string, pattern string) (*os.File, error) { + return nil, errors.New("create temp failed") + } + + _, err := MigrateContextBudgetConfigFile(target, false) + if err == nil || !strings.Contains(err.Error(), "write migration backup") { + t.Fatalf("expected backup write error, got %v", err) + } + raw, readErr := os.ReadFile(target) + if readErr != nil { + t.Fatalf("read target: %v", readErr) + } + if string(raw) != original { + t.Fatalf("expected original config to stay unchanged, got:\n%s", raw) + } +} + +func TestMigrateContextBudgetConfigFileKeepsOriginalWhenTargetReplaceFails(t *testing.T) { + dir := t.TempDir() + target := filepath.Join(dir, configName) + original := strings.TrimSpace(` +context: + auto_compact: + input_token_threshold: 120000 +`) + "\n" + if err := os.WriteFile(target, []byte(original), 0o644); err != nil { + t.Fatalf("write target: %v", err) + } + + restore := stubAtomicWriteOps(t) + defer restore() + renameCount := 0 + atomicRename = func(oldpath string, newpath string) error { + renameCount++ + if renameCount == 2 { + return errors.New("rename target failed") + } + return os.Rename(oldpath, newpath) + } + + _, err := MigrateContextBudgetConfigFile(target, false) + if err == nil || !strings.Contains(err.Error(), "write migrated config") { + t.Fatalf("expected migrated config write error, got %v", err) + } + if renameCount < 2 { + t.Fatalf("expected second rename to fail, got renameCount=%d", renameCount) + } + + raw, readErr := os.ReadFile(target) + if readErr != nil { + t.Fatalf("read target: %v", readErr) + } + if string(raw) != original { + t.Fatalf("expected original config to stay unchanged, got:\n%s", raw) + } + + backupRaw, backupErr := os.ReadFile(target + ".bak") + if backupErr != nil { + t.Fatalf("read backup: %v", backupErr) + } + if string(backupRaw) != original { + t.Fatalf("expected backup to keep original content, got:\n%s", backupRaw) + } +} + +func TestMigrateContextBudgetConfigFileKeepsOriginalWhenBackupVerificationFails(t *testing.T) { + dir := t.TempDir() + target := filepath.Join(dir, configName) + original := strings.TrimSpace(` +context: + auto_compact: + input_token_threshold: 120000 +`) + "\n" + if err := os.WriteFile(target, []byte(original), 0o644); err != nil { + t.Fatalf("write target: %v", err) + } + + restore := stubAtomicWriteOps(t) + defer restore() + readCount := 0 + atomicReadFile = func(path string) ([]byte, error) { + readCount++ + if readCount == 1 { + return []byte("corrupted"), nil + } + return os.ReadFile(path) + } + + _, err := MigrateContextBudgetConfigFile(target, false) + if err == nil || !strings.Contains(err.Error(), "read back mismatch") { + t.Fatalf("expected read back mismatch error, got %v", err) + } + raw, readErr := os.ReadFile(target) + if readErr != nil { + t.Fatalf("read target: %v", readErr) + } + if string(raw) != original { + t.Fatalf("expected original config to stay unchanged, got:\n%s", raw) + } +} + +func stubAtomicWriteOps(t *testing.T) func() { + t.Helper() + prevCreateTemp := atomicCreateTemp + prevReadFile := atomicReadFile + prevRename := atomicRename + return func() { + atomicCreateTemp = prevCreateTemp + atomicReadFile = prevReadFile + atomicRename = prevRename + } +} diff --git a/internal/config/loader.go b/internal/config/loader.go index f559d9fb..ead5f0f0 100644 --- a/internal/config/loader.go +++ b/internal/config/loader.go @@ -165,7 +165,7 @@ func (l *Loader) Save(ctx context.Context, cfg *Config) error { return err } - if err := os.WriteFile(l.ConfigPath(), data, 0o644); err != nil { + if err := writeFileAtomically(l.ConfigPath(), data, 0o644); err != nil { return fmt.Errorf("config: write config file: %w", err) } diff --git a/internal/provider/anthropic/provider.go b/internal/provider/anthropic/provider.go index ffbb548a..f7db2c17 100644 --- a/internal/provider/anthropic/provider.go +++ b/internal/provider/anthropic/provider.go @@ -43,7 +43,7 @@ func (p *Provider) EstimateInputTokens( return providertypes.BudgetEstimate{ EstimatedInputTokens: tokens, EstimateSource: provider.EstimateSourceLocal, - Accurate: false, + GatePolicy: provider.EstimateGateGateable, }, nil } diff --git a/internal/provider/anthropic/provider_test.go b/internal/provider/anthropic/provider_test.go index 49c36ac9..32b3e43f 100644 --- a/internal/provider/anthropic/provider_test.go +++ b/internal/provider/anthropic/provider_test.go @@ -199,6 +199,40 @@ func TestBuildRequestRejectsSessionAssetWithoutReader(t *testing.T) { } } +func TestEstimateInputTokensReturnsGateableLocalEstimate(t *testing.T) { + t.Parallel() + + p, err := New(provider.RuntimeConfig{ + Driver: provider.DriverAnthropic, + BaseURL: "https://api.anthropic.com/v1", + DefaultModel: "claude-3-7-sonnet", + APIKeyEnv: "ANTHROPIC_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + }) + if err != nil { + t.Fatalf("New() error = %v", err) + } + + estimate, err := p.EstimateInputTokens(context.Background(), providertypes.GenerateRequest{ + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("hi")}, + }}, + }) + if err != nil { + t.Fatalf("EstimateInputTokens() error = %v", err) + } + if estimate.EstimateSource != provider.EstimateSourceLocal { + t.Fatalf("estimate source = %q, want %q", estimate.EstimateSource, provider.EstimateSourceLocal) + } + if estimate.GatePolicy != provider.EstimateGateGateable { + t.Fatalf("gate policy = %q, want %q", estimate.GatePolicy, provider.EstimateGateGateable) + } + if estimate.EstimatedInputTokens <= 0 { + t.Fatalf("expected positive estimate tokens, got %d", estimate.EstimatedInputTokens) + } +} + func drainEvents(events <-chan providertypes.StreamEvent) []providertypes.StreamEvent { var drained []providertypes.StreamEvent for { diff --git a/internal/provider/estimate.go b/internal/provider/estimate.go index 07e0c9d8..8467a62b 100644 --- a/internal/provider/estimate.go +++ b/internal/provider/estimate.go @@ -8,6 +8,8 @@ import ( const ( EstimateSourceNative = "native" EstimateSourceLocal = "local" + EstimateGateAdvisory = "advisory" + EstimateGateGateable = "gateable" localEstimateSlack = 1.15 ) diff --git a/internal/provider/gemini/provider.go b/internal/provider/gemini/provider.go index 789fd8ca..af1a9b5d 100644 --- a/internal/provider/gemini/provider.go +++ b/internal/provider/gemini/provider.go @@ -46,7 +46,7 @@ func (p *Provider) EstimateInputTokens( return providertypes.BudgetEstimate{ EstimatedInputTokens: tokens, EstimateSource: provider.EstimateSourceLocal, - Accurate: false, + GatePolicy: provider.EstimateGateGateable, }, nil } diff --git a/internal/provider/gemini/provider_test.go b/internal/provider/gemini/provider_test.go index 37eaedd9..e9e866fc 100644 --- a/internal/provider/gemini/provider_test.go +++ b/internal/provider/gemini/provider_test.go @@ -186,6 +186,40 @@ func TestBuildRequestRejectsSessionAssetWithoutReader(t *testing.T) { } } +func TestEstimateInputTokensReturnsGateableLocalEstimate(t *testing.T) { + t.Parallel() + + p, err := New(provider.RuntimeConfig{ + Driver: provider.DriverGemini, + BaseURL: "https://generativelanguage.googleapis.com/v1beta", + DefaultModel: "gemini-2.5-flash", + APIKeyEnv: "GEMINI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + }) + if err != nil { + t.Fatalf("New() error = %v", err) + } + + estimate, err := p.EstimateInputTokens(context.Background(), providertypes.GenerateRequest{ + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("hi")}, + }}, + }) + if err != nil { + t.Fatalf("EstimateInputTokens() error = %v", err) + } + if estimate.EstimateSource != provider.EstimateSourceLocal { + t.Fatalf("estimate source = %q, want %q", estimate.EstimateSource, provider.EstimateSourceLocal) + } + if estimate.GatePolicy != provider.EstimateGateGateable { + t.Fatalf("gate policy = %q, want %q", estimate.GatePolicy, provider.EstimateGateGateable) + } + if estimate.EstimatedInputTokens <= 0 { + t.Fatalf("expected positive estimate tokens, got %d", estimate.EstimatedInputTokens) + } +} + func drainEvents(events <-chan providertypes.StreamEvent) []providertypes.StreamEvent { var drained []providertypes.StreamEvent for { diff --git a/internal/provider/generate_test.go b/internal/provider/generate_test.go index 0658e187..b35e2717 100644 --- a/internal/provider/generate_test.go +++ b/internal/provider/generate_test.go @@ -23,7 +23,7 @@ func (s *stubTextGenProvider) EstimateInputTokens( return providertypes.BudgetEstimate{ EstimatedInputTokens: provider.EstimateTextTokens(req.SystemPrompt + renderEstimateMessages(req.Messages)), EstimateSource: provider.EstimateSourceLocal, - Accurate: false, + GatePolicy: provider.EstimateGateGateable, }, nil } diff --git a/internal/provider/openaicompat/openaicompat_test.go b/internal/provider/openaicompat/openaicompat_test.go index 34a00001..3cb176be 100644 --- a/internal/provider/openaicompat/openaicompat_test.go +++ b/internal/provider/openaicompat/openaicompat_test.go @@ -245,6 +245,33 @@ func TestDiscoverModelsParsesNestedContainerAndAliasFields(t *testing.T) { } } +func TestEstimateInputTokensReturnsGateableLocalEstimate(t *testing.T) { + t.Parallel() + + p, err := New(resolvedConfig("", "")) + if err != nil { + t.Fatalf("New() error = %v", err) + } + estimate, err := p.EstimateInputTokens(context.Background(), providertypes.GenerateRequest{ + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("hi")}, + }}, + }) + if err != nil { + t.Fatalf("EstimateInputTokens() error = %v", err) + } + if estimate.EstimateSource != provider.EstimateSourceLocal { + t.Fatalf("estimate source = %q, want %q", estimate.EstimateSource, provider.EstimateSourceLocal) + } + if estimate.GatePolicy != provider.EstimateGateGateable { + t.Fatalf("gate policy = %q, want %q", estimate.GatePolicy, provider.EstimateGateGateable) + } + if estimate.EstimatedInputTokens <= 0 { + t.Fatalf("expected positive estimate tokens, got %d", estimate.EstimatedInputTokens) + } +} + func TestDiscoverModelsOpenAIProfileFallsBackToGenericListKeys(t *testing.T) { t.Parallel() diff --git a/internal/provider/openaicompat/provider.go b/internal/provider/openaicompat/provider.go index 2dea2883..6227f9a2 100644 --- a/internal/provider/openaicompat/provider.go +++ b/internal/provider/openaicompat/provider.go @@ -75,7 +75,7 @@ func (p *Provider) EstimateInputTokens( return providertypes.BudgetEstimate{ EstimatedInputTokens: tokens, EstimateSource: provider.EstimateSourceLocal, - Accurate: false, + GatePolicy: provider.EstimateGateGateable, }, nil } diff --git a/internal/provider/types/usage.go b/internal/provider/types/usage.go index 4250c6d1..a605919c 100644 --- a/internal/provider/types/usage.go +++ b/internal/provider/types/usage.go @@ -11,5 +11,5 @@ type Usage struct { type BudgetEstimate struct { EstimatedInputTokens int `json:"estimated_input_tokens"` EstimateSource string `json:"estimate_source"` - Accurate bool `json:"accurate"` + GatePolicy string `json:"gate_policy"` } diff --git a/internal/runtime/budget_models.go b/internal/runtime/budget_models.go index 58fd03ac..6283931e 100644 --- a/internal/runtime/budget_models.go +++ b/internal/runtime/budget_models.go @@ -91,11 +91,15 @@ func newTurnBudgetEstimate( id controlplane.TurnBudgetID, estimate providertypes.BudgetEstimate, ) controlplane.TurnBudgetEstimate { + gatePolicy := controlplane.TurnBudgetGatePolicyAdvisory + if estimate.GatePolicy == provider.EstimateGateGateable { + gatePolicy = controlplane.TurnBudgetGatePolicyGateable + } return controlplane.TurnBudgetEstimate{ ID: id, EstimatedInputTokens: estimate.EstimatedInputTokens, EstimateSource: estimate.EstimateSource, - Accurate: estimate.Accurate, + GatePolicy: gatePolicy, } } diff --git a/internal/runtime/controlplane/budget.go b/internal/runtime/controlplane/budget.go index 16a1915a..872496e5 100644 --- a/internal/runtime/controlplane/budget.go +++ b/internal/runtime/controlplane/budget.go @@ -9,17 +9,22 @@ const ( TurnBudgetActionStop TurnBudgetAction = "stop" ) +const ( + // TurnBudgetGatePolicyGateable 表示估算可作为预算硬停门禁依据。 + TurnBudgetGatePolicyGateable = "gateable" + // TurnBudgetGatePolicyAdvisory 表示估算仅用于提示或触发 compact,不能硬停。 + TurnBudgetGatePolicyAdvisory = "advisory" +) + const ( // BudgetDecisionReasonWithinBudget 表示估算在预算范围内。 BudgetDecisionReasonWithinBudget = "within_budget" // BudgetDecisionReasonExceedsBudgetFirstTime 表示首次超预算,需要先 compact。 BudgetDecisionReasonExceedsBudgetFirstTime = "exceeds_budget_first_time" - // BudgetDecisionReasonExceedsBudgetAfterCompact 表示高置信估算在 compact 后仍超预算,需要停止。 - BudgetDecisionReasonExceedsBudgetAfterCompact = "exceeds_budget_after_compact" - // BudgetDecisionReasonExceedsBudgetInaccurateFirstTime 表示低置信估算首次超预算,先 compact 再验证。 - BudgetDecisionReasonExceedsBudgetInaccurateFirstTime = "exceeds_budget_inaccurate_first_time" - // BudgetDecisionReasonExceedsBudgetInaccurateAfterCompactAllow 表示低置信估算 compact 后仍超预算但允许放行。 - BudgetDecisionReasonExceedsBudgetInaccurateAfterCompactAllow = "exceeds_budget_inaccurate_after_compact_allow" + // BudgetDecisionReasonExceedsBudgetAfterCompactStop 表示 compact 后仍超预算且可门禁,必须停止。 + BudgetDecisionReasonExceedsBudgetAfterCompactStop = "exceeds_budget_after_compact_stop" + // BudgetDecisionReasonExceedsBudgetAfterCompactAllowAdvisory 表示 compact 后仍超预算但仅 advisory,允许放行。 + BudgetDecisionReasonExceedsBudgetAfterCompactAllowAdvisory = "exceeds_budget_after_compact_allow_advisory" ) // TurnBudgetID 标识一次冻结预算尝试,避免 estimate、decision 与 usage observation 串用。 @@ -33,7 +38,7 @@ type TurnBudgetEstimate struct { ID TurnBudgetID `json:"id"` EstimatedInputTokens int `json:"estimated_input_tokens"` EstimateSource string `json:"estimate_source,omitempty"` - Accurate bool `json:"accurate"` + GatePolicy string `json:"gate_policy,omitempty"` } // TurnBudgetDecision 描述冻结请求在当前预算事实下的决策结果。 @@ -44,7 +49,7 @@ type TurnBudgetDecision struct { EstimatedInputTokens int `json:"estimated_input_tokens"` PromptBudget int `json:"prompt_budget"` EstimateSource string `json:"estimate_source,omitempty"` - EstimateAccurate bool `json:"estimate_accurate"` + EstimateGatePolicy string `json:"estimate_gate_policy,omitempty"` } // DecideTurnBudget 根据输入预算事实输出 allow、compact 或 stop 三种动作。 @@ -58,7 +63,7 @@ func DecideTurnBudget( EstimatedInputTokens: estimate.EstimatedInputTokens, PromptBudget: promptBudget, EstimateSource: estimate.EstimateSource, - EstimateAccurate: estimate.Accurate, + EstimateGatePolicy: estimate.GatePolicy, } if estimate.EstimatedInputTokens <= promptBudget { decision.Action = TurnBudgetActionAllow @@ -67,19 +72,15 @@ func DecideTurnBudget( } if compactCount == 0 { decision.Action = TurnBudgetActionCompact - if estimate.Accurate { - decision.Reason = BudgetDecisionReasonExceedsBudgetFirstTime - } else { - decision.Reason = BudgetDecisionReasonExceedsBudgetInaccurateFirstTime - } + decision.Reason = BudgetDecisionReasonExceedsBudgetFirstTime return decision } - if estimate.Accurate { + if estimate.GatePolicy == TurnBudgetGatePolicyGateable { decision.Action = TurnBudgetActionStop - decision.Reason = BudgetDecisionReasonExceedsBudgetAfterCompact + decision.Reason = BudgetDecisionReasonExceedsBudgetAfterCompactStop return decision } decision.Action = TurnBudgetActionAllow - decision.Reason = BudgetDecisionReasonExceedsBudgetInaccurateAfterCompactAllow + decision.Reason = BudgetDecisionReasonExceedsBudgetAfterCompactAllowAdvisory return decision } diff --git a/internal/runtime/controlplane/budget_test.go b/internal/runtime/controlplane/budget_test.go index 64b5380f..680f8856 100644 --- a/internal/runtime/controlplane/budget_test.go +++ b/internal/runtime/controlplane/budget_test.go @@ -12,7 +12,7 @@ func TestDecideTurnBudgetAccurateBranches(t *testing.T) { }, EstimatedInputTokens: 120, EstimateSource: "provider", - Accurate: true, + GatePolicy: TurnBudgetGatePolicyGateable, } within := DecideTurnBudget(baseEstimate, 120, 0) @@ -22,8 +22,8 @@ func TestDecideTurnBudgetAccurateBranches(t *testing.T) { if within.Reason != BudgetDecisionReasonWithinBudget { t.Fatalf("within.Reason = %q", within.Reason) } - if !within.EstimateAccurate { - t.Fatalf("within.EstimateAccurate = false, want true") + if within.EstimateGatePolicy != TurnBudgetGatePolicyGateable { + t.Fatalf("within.EstimateGatePolicy = %q, want %q", within.EstimateGatePolicy, TurnBudgetGatePolicyGateable) } firstExceed := DecideTurnBudget(baseEstimate, 100, 0) @@ -33,23 +33,23 @@ func TestDecideTurnBudgetAccurateBranches(t *testing.T) { if firstExceed.Reason != BudgetDecisionReasonExceedsBudgetFirstTime { t.Fatalf("firstExceed.Reason = %q", firstExceed.Reason) } - if !firstExceed.EstimateAccurate { - t.Fatalf("firstExceed.EstimateAccurate = false, want true") + if firstExceed.EstimateGatePolicy != TurnBudgetGatePolicyGateable { + t.Fatalf("firstExceed.EstimateGatePolicy = %q, want %q", firstExceed.EstimateGatePolicy, TurnBudgetGatePolicyGateable) } afterCompact := DecideTurnBudget(baseEstimate, 100, 1) if afterCompact.Action != TurnBudgetActionStop { t.Fatalf("afterCompact.Action = %q", afterCompact.Action) } - if afterCompact.Reason != BudgetDecisionReasonExceedsBudgetAfterCompact { + if afterCompact.Reason != BudgetDecisionReasonExceedsBudgetAfterCompactStop { t.Fatalf("afterCompact.Reason = %q", afterCompact.Reason) } - if !afterCompact.EstimateAccurate { - t.Fatalf("afterCompact.EstimateAccurate = false, want true") + if afterCompact.EstimateGatePolicy != TurnBudgetGatePolicyGateable { + t.Fatalf("afterCompact.EstimateGatePolicy = %q, want %q", afterCompact.EstimateGatePolicy, TurnBudgetGatePolicyGateable) } } -func TestDecideTurnBudgetInaccurateBranches(t *testing.T) { +func TestDecideTurnBudgetAdvisoryBranches(t *testing.T) { t.Parallel() estimate := TurnBudgetEstimate{ @@ -59,28 +59,28 @@ func TestDecideTurnBudgetInaccurateBranches(t *testing.T) { }, EstimatedInputTokens: 200, EstimateSource: "local", - Accurate: false, + GatePolicy: TurnBudgetGatePolicyAdvisory, } firstExceed := DecideTurnBudget(estimate, 100, 0) if firstExceed.Action != TurnBudgetActionCompact { t.Fatalf("firstExceed.Action = %q", firstExceed.Action) } - if firstExceed.Reason != BudgetDecisionReasonExceedsBudgetInaccurateFirstTime { + if firstExceed.Reason != BudgetDecisionReasonExceedsBudgetFirstTime { t.Fatalf("firstExceed.Reason = %q", firstExceed.Reason) } - if firstExceed.EstimateAccurate { - t.Fatalf("firstExceed.EstimateAccurate = true, want false") + if firstExceed.EstimateGatePolicy != TurnBudgetGatePolicyAdvisory { + t.Fatalf("firstExceed.EstimateGatePolicy = %q, want %q", firstExceed.EstimateGatePolicy, TurnBudgetGatePolicyAdvisory) } afterCompact := DecideTurnBudget(estimate, 100, 1) if afterCompact.Action != TurnBudgetActionAllow { t.Fatalf("afterCompact.Action = %q", afterCompact.Action) } - if afterCompact.Reason != BudgetDecisionReasonExceedsBudgetInaccurateAfterCompactAllow { + if afterCompact.Reason != BudgetDecisionReasonExceedsBudgetAfterCompactAllowAdvisory { t.Fatalf("afterCompact.Reason = %q", afterCompact.Reason) } - if afterCompact.EstimateAccurate { - t.Fatalf("afterCompact.EstimateAccurate = true, want false") + if afterCompact.EstimateGatePolicy != TurnBudgetGatePolicyAdvisory { + t.Fatalf("afterCompact.EstimateGatePolicy = %q, want %q", afterCompact.EstimateGatePolicy, TurnBudgetGatePolicyAdvisory) } } diff --git a/internal/runtime/controlplane/envelope.go b/internal/runtime/controlplane/envelope.go index a1c65626..f0d08da4 100644 --- a/internal/runtime/controlplane/envelope.go +++ b/internal/runtime/controlplane/envelope.go @@ -1,4 +1,4 @@ package controlplane // PayloadVersion 为 runtime 事件 envelope 的当前协议版本号。 -const PayloadVersion = 3 +const PayloadVersion = 4 diff --git a/internal/runtime/events.go b/internal/runtime/events.go index 0ff7adf8..03416ba2 100644 --- a/internal/runtime/events.go +++ b/internal/runtime/events.go @@ -36,7 +36,7 @@ type BudgetCheckedPayload struct { EstimatedInputTokens int `json:"estimated_input_tokens"` PromptBudget int `json:"prompt_budget"` EstimateSource string `json:"estimate_source,omitempty"` - EstimateAccurate bool `json:"estimate_accurate"` + EstimateGatePolicy string `json:"estimate_gate_policy,omitempty"` } // ProgressEvaluatedPayload 汇总 progress 控制面的评估结果。 @@ -71,7 +71,7 @@ func newBudgetCheckedPayload(decision controlplane.TurnBudgetDecision) BudgetChe EstimatedInputTokens: decision.EstimatedInputTokens, PromptBudget: decision.PromptBudget, EstimateSource: decision.EstimateSource, - EstimateAccurate: decision.EstimateAccurate, + EstimateGatePolicy: decision.EstimateGatePolicy, } } diff --git a/internal/runtime/runtime_test.go b/internal/runtime/runtime_test.go index bbf48ac9..c06059b1 100644 --- a/internal/runtime/runtime_test.go +++ b/internal/runtime/runtime_test.go @@ -541,7 +541,7 @@ func (p *scriptedProvider) EstimateInputTokens( return providertypes.BudgetEstimate{ EstimatedInputTokens: provider.EstimateTextTokens(req.SystemPrompt + renderMessagesForEstimate(req.Messages)), EstimateSource: provider.EstimateSourceLocal, - Accurate: false, + GatePolicy: provider.EstimateGateGateable, }, nil } @@ -4476,7 +4476,7 @@ func TestResolvePromptBudgetFallsBackWhenResolverErrors(t *testing.T) { } } -func TestServiceRunAllowsAfterProactiveCompactWhenEstimateInaccurate(t *testing.T) { +func TestServiceRunStopsAfterProactiveCompactWhenEstimateGateable(t *testing.T) { t.Parallel() manager := newRuntimeConfigManager(t) @@ -4497,7 +4497,7 @@ func TestServiceRunAllowsAfterProactiveCompactWhenEstimateInaccurate(t *testing. return providertypes.BudgetEstimate{ EstimatedInputTokens: 99, EstimateSource: provider.EstimateSourceLocal, - Accurate: false, + GatePolicy: provider.EstimateGateGateable, }, nil }, responses: []scriptedResponse{ @@ -4532,7 +4532,122 @@ func TestServiceRunAllowsAfterProactiveCompactWhenEstimateInaccurate(t *testing. } if err := service.Run(context.Background(), UserInput{ - RunID: "run-budget-inaccurate-allow", + RunID: "run-budget-gateable-stop", + Parts: []providertypes.ContentPart{providertypes.NewTextPart("continue")}, + }); err != nil { + t.Fatalf("Run() error = %v", err) + } + + compactRunner := service.compactRunner.(*stubCompactRunner) + if len(compactRunner.calls) != 1 { + t.Fatalf("expected one proactive compact, got %d", len(compactRunner.calls)) + } + if compactRunner.calls[0].Mode != contextcompact.ModeProactive { + t.Fatalf("expected compact mode %q, got %q", contextcompact.ModeProactive, compactRunner.calls[0].Mode) + } + if scripted.callCount != 0 { + t.Fatalf("expected provider Generate to be skipped after budget stop, got %d calls", scripted.callCount) + } + + events := collectRuntimeEvents(service.Events()) + var budgetActions []string + var budgetReasons []string + var budgetGatePolicies []string + var stopPayload StopReasonDecidedPayload + for _, event := range events { + switch event.Type { + case EventBudgetChecked: + payload, ok := event.Payload.(BudgetCheckedPayload) + if !ok { + t.Fatalf("expected BudgetCheckedPayload, got %T", event.Payload) + } + budgetActions = append(budgetActions, payload.Action) + budgetReasons = append(budgetReasons, payload.Reason) + budgetGatePolicies = append(budgetGatePolicies, payload.EstimateGatePolicy) + case EventStopReasonDecided: + payload, ok := event.Payload.(StopReasonDecidedPayload) + if !ok { + t.Fatalf("expected StopReasonDecidedPayload, got %T", event.Payload) + } + stopPayload = payload + } + } + + if len(budgetActions) != 2 || budgetActions[0] != "compact" || budgetActions[1] != "stop" { + t.Fatalf("expected budget actions [compact stop], got %v", budgetActions) + } + if len(budgetReasons) != 2 || + budgetReasons[0] != controlplane.BudgetDecisionReasonExceedsBudgetFirstTime || + budgetReasons[1] != controlplane.BudgetDecisionReasonExceedsBudgetAfterCompactStop { + t.Fatalf("unexpected budget reasons %v", budgetReasons) + } + if len(budgetGatePolicies) != 2 || + budgetGatePolicies[0] != provider.EstimateGateGateable || + budgetGatePolicies[1] != provider.EstimateGateGateable { + t.Fatalf("expected gateable estimates, got %v", budgetGatePolicies) + } + if stopPayload.Reason != controlplane.StopReasonBudgetExceeded { + t.Fatalf("expected stop reason %q, got %q", controlplane.StopReasonBudgetExceeded, stopPayload.Reason) + } +} + +func TestServiceRunAllowsAfterProactiveCompactWhenEstimateAdvisory(t *testing.T) { + t.Parallel() + + manager := newRuntimeConfigManager(t) + if err := manager.Update(context.Background(), func(cfg *config.Config) error { + cfg.Context.Budget.PromptBudget = 10 + cfg.Context.Budget.FallbackPromptBudget = 10 + return nil + }); err != nil { + t.Fatalf("update config: %v", err) + } + + store := newMemoryStore() + registry := tools.NewRegistry() + scripted := &scriptedProvider{ + estimateFn: func(ctx context.Context, req providertypes.GenerateRequest) (providertypes.BudgetEstimate, error) { + _ = ctx + _ = req + return providertypes.BudgetEstimate{ + EstimatedInputTokens: 99, + EstimateSource: provider.EstimateSourceLocal, + GatePolicy: provider.EstimateGateAdvisory, + }, nil + }, + responses: []scriptedResponse{ + { + Message: providertypes.Message{ + Role: providertypes.RoleAssistant, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("继续执行")}, + }, + FinishReason: "stop", + }, + }, + } + + service := NewWithFactory(manager, registry, store, &scriptedProviderFactory{provider: scripted}, &stubContextBuilder{}) + service.compactRunner = &stubCompactRunner{ + result: contextcompact.Result{ + Applied: true, + Messages: []providertypes.Message{ + { + Role: providertypes.RoleAssistant, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("[compact_summary]\ndone:\n- archived\n\nin_progress:\n- continue")}, + }, + { + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("continue")}, + }, + }, + Metrics: contextcompact.Metrics{ + TriggerMode: string(contextcompact.ModeProactive), + }, + }, + } + + if err := service.Run(context.Background(), UserInput{ + RunID: "run-budget-advisory-allow", Parts: []providertypes.ContentPart{providertypes.NewTextPart("continue")}, }); err != nil { t.Fatalf("Run() error = %v", err) @@ -4552,7 +4667,7 @@ func TestServiceRunAllowsAfterProactiveCompactWhenEstimateInaccurate(t *testing. events := collectRuntimeEvents(service.Events()) var budgetActions []string var budgetReasons []string - var budgetAccuracies []bool + var budgetGatePolicies []string var stopPayload StopReasonDecidedPayload for _, event := range events { switch event.Type { @@ -4563,7 +4678,7 @@ func TestServiceRunAllowsAfterProactiveCompactWhenEstimateInaccurate(t *testing. } budgetActions = append(budgetActions, payload.Action) budgetReasons = append(budgetReasons, payload.Reason) - budgetAccuracies = append(budgetAccuracies, payload.EstimateAccurate) + budgetGatePolicies = append(budgetGatePolicies, payload.EstimateGatePolicy) case EventStopReasonDecided: payload, ok := event.Payload.(StopReasonDecidedPayload) if !ok { @@ -4577,12 +4692,14 @@ func TestServiceRunAllowsAfterProactiveCompactWhenEstimateInaccurate(t *testing. t.Fatalf("expected budget actions [compact allow], got %v", budgetActions) } if len(budgetReasons) != 2 || - budgetReasons[0] != controlplane.BudgetDecisionReasonExceedsBudgetInaccurateFirstTime || - budgetReasons[1] != controlplane.BudgetDecisionReasonExceedsBudgetInaccurateAfterCompactAllow { + budgetReasons[0] != controlplane.BudgetDecisionReasonExceedsBudgetFirstTime || + budgetReasons[1] != controlplane.BudgetDecisionReasonExceedsBudgetAfterCompactAllowAdvisory { t.Fatalf("unexpected budget reasons %v", budgetReasons) } - if len(budgetAccuracies) != 2 || budgetAccuracies[0] || budgetAccuracies[1] { - t.Fatalf("expected inaccurate estimates, got %v", budgetAccuracies) + if len(budgetGatePolicies) != 2 || + budgetGatePolicies[0] != provider.EstimateGateAdvisory || + budgetGatePolicies[1] != provider.EstimateGateAdvisory { + t.Fatalf("expected advisory estimates, got %v", budgetGatePolicies) } if stopPayload.Reason != controlplane.StopReasonCompleted { t.Fatalf("expected stop reason %q, got %q", controlplane.StopReasonCompleted, stopPayload.Reason) @@ -4602,7 +4719,7 @@ func TestServiceRunReconcilesUnknownOutputUsage(t *testing.T) { return providertypes.BudgetEstimate{ EstimatedInputTokens: 17, EstimateSource: provider.EstimateSourceLocal, - Accurate: false, + GatePolicy: provider.EstimateGateGateable, }, nil }, responses: []scriptedResponse{ diff --git a/internal/tui/services/gateway_stream_client.go b/internal/tui/services/gateway_stream_client.go index afdf7c57..34219c1d 100644 --- a/internal/tui/services/gateway_stream_client.go +++ b/internal/tui/services/gateway_stream_client.go @@ -14,7 +14,7 @@ import ( "neo-code/internal/tools" ) -const runtimeEventPayloadVersion = 3 +const runtimeEventPayloadVersion = 4 // GatewayStreamClient 负责消费 gateway.event 并恢复为 TUI 事件。 type GatewayStreamClient struct { From a9b0f9ae93c52323a2573adeeed316d1a3ad855f Mon Sep 17 00:00:00 2001 From: xgopilot Date: Thu, 23 Apr 2026 03:33:17 +0000 Subject: [PATCH 4/9] fix(runtime/provider/tui): address budget estimate gating, request reuse and stream fail-fast Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: phantom5099 <245659304+phantom5099@users.noreply.github.com> --- internal/provider/anthropic/provider.go | 47 +++++++++- internal/provider/anthropic/provider_test.go | 71 ++++++++++++-- internal/provider/estimate.go | 14 +++ internal/provider/gemini/provider.go | 56 ++++++++++- internal/provider/gemini/provider_test.go | 65 +++++++++++-- .../provider/openaicompat/generate_sdk.go | 13 +-- .../openaicompat/openaicompat_test.go | 81 +++++++++++++++- internal/provider/openaicompat/provider.go | 93 ++++++++++++++++++- internal/runtime/run.go | 25 +++-- .../runtime_remaining_branches_test.go | 14 ++- internal/runtime/runtime_test.go | 3 +- .../tui/services/gateway_stream_client.go | 13 ++- .../gateway_stream_client_additional_test.go | 51 ++++++++++ 13 files changed, 496 insertions(+), 50 deletions(-) diff --git a/internal/provider/anthropic/provider.go b/internal/provider/anthropic/provider.go index f7db2c17..1e176a16 100644 --- a/internal/provider/anthropic/provider.go +++ b/internal/provider/anthropic/provider.go @@ -6,6 +6,7 @@ import ( "fmt" "strconv" "strings" + "sync" anthropic "github.com/anthropics/anthropic-sdk-go" @@ -25,6 +26,14 @@ type toolCallState struct { // Provider 封装 Anthropic messages 协议的请求发送与流式解析。 type Provider struct { cfg provider.RuntimeConfig + + mu sync.Mutex + prepared *preparedRequest +} + +type preparedRequest struct { + signature string + params anthropic.MessageNewParams } // EstimateInputTokens 基于 Anthropic 最终请求结构做本地输入 token 估算。 @@ -40,10 +49,11 @@ func (p *Provider) EstimateInputTokens( if err != nil { return providertypes.BudgetEstimate{}, err } + p.storePreparedRequest(provider.BuildGenerateRequestSignature(req), params) return providertypes.BudgetEstimate{ EstimatedInputTokens: tokens, EstimateSource: provider.EstimateSourceLocal, - GatePolicy: provider.EstimateGateGateable, + GatePolicy: provider.EstimateGateAdvisory, }, nil } @@ -57,9 +67,13 @@ func New(cfg provider.RuntimeConfig) (*Provider, error) { // Generate 发起 Anthropic 流式请求,并将 typed stream 转为统一事件。 func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error { - params, err := BuildRequest(ctx, p.cfg, req) - if err != nil { - return err + params, ok := p.takePreparedRequest(provider.BuildGenerateRequestSignature(req)) + if !ok { + var err error + params, err = BuildRequest(ctx, p.cfg, req) + if err != nil { + return err + } } client, err := newSDKClient(p.cfg) @@ -185,6 +199,31 @@ func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateReque return provider.EmitMessageDone(ctx, events, finishReason, &usage) } +// storePreparedRequest 缓存估算阶段已构建的 Anthropic 请求,供同轮发送复用。 +func (p *Provider) storePreparedRequest(signature string, params anthropic.MessageNewParams) { + p.mu.Lock() + defer p.mu.Unlock() + p.prepared = &preparedRequest{ + signature: strings.TrimSpace(signature), + params: params, + } +} + +// takePreparedRequest 读取并消费匹配签名的预构建请求,避免跨请求误复用。 +func (p *Provider) takePreparedRequest(signature string) (anthropic.MessageNewParams, bool) { + p.mu.Lock() + defer p.mu.Unlock() + if p.prepared == nil { + return anthropic.MessageNewParams{}, false + } + current := p.prepared + p.prepared = nil + if strings.TrimSpace(signature) == "" || current.signature != strings.TrimSpace(signature) { + return anthropic.MessageNewParams{}, false + } + return current.params, true +} + // mapAnthropicSDKError 统一映射 SDK 错误为 provider 领域错误。 func mapAnthropicSDKError(err error) error { var apiErr *anthropic.Error diff --git a/internal/provider/anthropic/provider_test.go b/internal/provider/anthropic/provider_test.go index 32b3e43f..63a2ecdd 100644 --- a/internal/provider/anthropic/provider_test.go +++ b/internal/provider/anthropic/provider_test.go @@ -154,7 +154,7 @@ func TestBuildRequestSupportsImageParts(t *testing.T) { }, }, }, - SessionAssetReader: stubSessionAssetReader{ + SessionAssetReader: &stubSessionAssetReader{ assets: map[string]stubSessionAsset{ "asset-1": {data: []byte("image-bytes"), mime: "image/png"}, }, @@ -199,7 +199,7 @@ func TestBuildRequestRejectsSessionAssetWithoutReader(t *testing.T) { } } -func TestEstimateInputTokensReturnsGateableLocalEstimate(t *testing.T) { +func TestEstimateInputTokensReturnsAdvisoryLocalEstimate(t *testing.T) { t.Parallel() p, err := New(provider.RuntimeConfig{ @@ -225,14 +225,67 @@ func TestEstimateInputTokensReturnsGateableLocalEstimate(t *testing.T) { if estimate.EstimateSource != provider.EstimateSourceLocal { t.Fatalf("estimate source = %q, want %q", estimate.EstimateSource, provider.EstimateSourceLocal) } - if estimate.GatePolicy != provider.EstimateGateGateable { - t.Fatalf("gate policy = %q, want %q", estimate.GatePolicy, provider.EstimateGateGateable) + if estimate.GatePolicy != provider.EstimateGateAdvisory { + t.Fatalf("gate policy = %q, want %q", estimate.GatePolicy, provider.EstimateGateAdvisory) } if estimate.EstimatedInputTokens <= 0 { t.Fatalf("expected positive estimate tokens, got %d", estimate.EstimatedInputTokens) } } +func TestEstimateThenGenerateReusesPreparedRequest(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + _, _ = fmt.Fprint(w, "event: message_start\n") + _, _ = fmt.Fprint(w, "data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":4}}}\n\n") + _, _ = fmt.Fprint(w, "event: content_block_start\n") + _, _ = fmt.Fprint(w, "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"ok\"}}\n\n") + _, _ = fmt.Fprint(w, "event: message_delta\n") + _, _ = fmt.Fprint(w, "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":1}}\n\n") + _, _ = fmt.Fprint(w, "event: message_stop\n") + _, _ = fmt.Fprint(w, "data: {\"type\":\"message_stop\"}\n\n") + })) + defer server.Close() + + p, err := New(provider.RuntimeConfig{ + Driver: provider.DriverAnthropic, + BaseURL: server.URL, + DefaultModel: "claude-3-7-sonnet", + APIKeyEnv: "ANTHROPIC_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + }) + if err != nil { + t.Fatalf("New() error = %v", err) + } + + reader := &stubSessionAssetReader{ + maxOpen: 1, + assets: map[string]stubSessionAsset{ + "asset-1": {data: []byte("image-bytes"), mime: "image/png"}, + }, + } + request := providertypes.GenerateRequest{ + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewSessionAssetImagePart("asset-1", "image/png")}, + }}, + SessionAssetReader: reader, + } + if _, err := p.EstimateInputTokens(context.Background(), request); err != nil { + t.Fatalf("EstimateInputTokens() error = %v", err) + } + + events := make(chan providertypes.StreamEvent, 8) + if err := p.Generate(context.Background(), request, events); err != nil { + t.Fatalf("Generate() error = %v", err) + } + if reader.openCount != 1 { + t.Fatalf("expected session asset to be opened once, got %d", reader.openCount) + } +} + func drainEvents(events <-chan providertypes.StreamEvent) []providertypes.StreamEvent { var drained []providertypes.StreamEvent for { @@ -252,10 +305,16 @@ type stubSessionAsset struct { } type stubSessionAssetReader struct { - assets map[string]stubSessionAsset + assets map[string]stubSessionAsset + openCount int + maxOpen int } -func (r stubSessionAssetReader) Open(_ context.Context, assetID string) (io.ReadCloser, string, error) { +func (r *stubSessionAssetReader) Open(_ context.Context, assetID string) (io.ReadCloser, string, error) { + if r.maxOpen > 0 && r.openCount >= r.maxOpen { + return nil, "", fmt.Errorf("open limit exceeded for asset: %s", assetID) + } + r.openCount++ asset, ok := r.assets[assetID] if !ok { return nil, "", fmt.Errorf("asset not found: %s", assetID) diff --git a/internal/provider/estimate.go b/internal/provider/estimate.go index 8467a62b..2f0499e8 100644 --- a/internal/provider/estimate.go +++ b/internal/provider/estimate.go @@ -1,8 +1,12 @@ package provider import ( + "crypto/sha256" + "encoding/hex" "encoding/json" "math" + + providertypes "neo-code/internal/provider/types" ) const ( @@ -29,3 +33,13 @@ func EstimateTextTokens(text string) int { } return int(math.Ceil(float64(len([]byte(text))) / 4.0 * localEstimateSlack)) } + +// BuildGenerateRequestSignature 生成 GenerateRequest 的稳定签名,用于估算与发送阶段的请求复用匹配。 +func BuildGenerateRequestSignature(req providertypes.GenerateRequest) string { + encoded, err := json.Marshal(req) + if err != nil { + return "" + } + hash := sha256.Sum256(encoded) + return hex.EncodeToString(hash[:]) +} diff --git a/internal/provider/gemini/provider.go b/internal/provider/gemini/provider.go index af1a9b5d..5a6ea83e 100644 --- a/internal/provider/gemini/provider.go +++ b/internal/provider/gemini/provider.go @@ -7,6 +7,7 @@ import ( "fmt" "net/http" "strings" + "sync" "google.golang.org/genai" @@ -19,6 +20,16 @@ const errorPrefix = "gemini provider: " // Provider 封装 Gemini native 协议的请求发送与流式响应解析。 type Provider struct { cfg provider.RuntimeConfig + + mu sync.Mutex + prepared *preparedRequest +} + +type preparedRequest struct { + signature string + model string + contents []*genai.Content + config *genai.GenerateContentConfig } // EstimateInputTokens 基于 Gemini 最终请求结构做本地输入 token 估算。 @@ -43,10 +54,11 @@ func (p *Provider) EstimateInputTokens( if err != nil { return providertypes.BudgetEstimate{}, err } + p.storePreparedRequest(provider.BuildGenerateRequestSignature(req), model, contents, genConfig) return providertypes.BudgetEstimate{ EstimatedInputTokens: tokens, EstimateSource: provider.EstimateSourceLocal, - GatePolicy: provider.EstimateGateGateable, + GatePolicy: provider.EstimateGateAdvisory, }, nil } @@ -60,9 +72,13 @@ func New(cfg provider.RuntimeConfig) (*Provider, error) { // Generate 发起 Gemini 流式请求,并将 SDK chunk 转为统一流式事件。 func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateRequest, events chan<- providertypes.StreamEvent) error { - model, contents, config, err := BuildRequest(ctx, p.cfg, req) - if err != nil { - return err + model, contents, config, ok := p.takePreparedRequest(provider.BuildGenerateRequestSignature(req)) + if !ok { + var err error + model, contents, config, err = BuildRequest(ctx, p.cfg, req) + if err != nil { + return err + } } normalizedModel := normalizeGeminiModelName(model) if normalizedModel == "" { @@ -144,6 +160,38 @@ func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateReque return provider.EmitMessageDone(ctx, events, finishReason, &usage) } +// storePreparedRequest 缓存估算阶段的 Gemini 构建结果,供同轮发送直接复用。 +func (p *Provider) storePreparedRequest( + signature string, + model string, + contents []*genai.Content, + config *genai.GenerateContentConfig, +) { + p.mu.Lock() + defer p.mu.Unlock() + p.prepared = &preparedRequest{ + signature: strings.TrimSpace(signature), + model: model, + contents: contents, + config: config, + } +} + +// takePreparedRequest 读取并消费签名匹配的预构建请求,避免跨请求误复用。 +func (p *Provider) takePreparedRequest(signature string) (string, []*genai.Content, *genai.GenerateContentConfig, bool) { + p.mu.Lock() + defer p.mu.Unlock() + if p.prepared == nil { + return "", nil, nil, false + } + current := p.prepared + p.prepared = nil + if strings.TrimSpace(signature) == "" || current.signature != strings.TrimSpace(signature) { + return "", nil, nil, false + } + return current.model, current.contents, current.config, true +} + // normalizeGeminiModelName 统一清洗 Gemini 模型名,兼容 discover 返回的 "models/{id}" 形式。 func normalizeGeminiModelName(model string) string { trimmed := strings.TrimSpace(model) diff --git a/internal/provider/gemini/provider_test.go b/internal/provider/gemini/provider_test.go index e9e866fc..3e8296bb 100644 --- a/internal/provider/gemini/provider_test.go +++ b/internal/provider/gemini/provider_test.go @@ -135,7 +135,7 @@ func TestBuildRequestSupportsImageParts(t *testing.T) { }, }, }, - SessionAssetReader: stubSessionAssetReader{ + SessionAssetReader: &stubSessionAssetReader{ assets: map[string]stubSessionAsset{ "asset-1": {data: []byte("image-bytes"), mime: "image/png"}, }, @@ -186,7 +186,7 @@ func TestBuildRequestRejectsSessionAssetWithoutReader(t *testing.T) { } } -func TestEstimateInputTokensReturnsGateableLocalEstimate(t *testing.T) { +func TestEstimateInputTokensReturnsAdvisoryLocalEstimate(t *testing.T) { t.Parallel() p, err := New(provider.RuntimeConfig{ @@ -212,14 +212,61 @@ func TestEstimateInputTokensReturnsGateableLocalEstimate(t *testing.T) { if estimate.EstimateSource != provider.EstimateSourceLocal { t.Fatalf("estimate source = %q, want %q", estimate.EstimateSource, provider.EstimateSourceLocal) } - if estimate.GatePolicy != provider.EstimateGateGateable { - t.Fatalf("gate policy = %q, want %q", estimate.GatePolicy, provider.EstimateGateGateable) + if estimate.GatePolicy != provider.EstimateGateAdvisory { + t.Fatalf("gate policy = %q, want %q", estimate.GatePolicy, provider.EstimateGateAdvisory) } if estimate.EstimatedInputTokens <= 0 { t.Fatalf("expected positive estimate tokens, got %d", estimate.EstimatedInputTokens) } } +func TestEstimateThenGenerateReusesPreparedRequest(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + _, _ = fmt.Fprint(w, "data: {\"candidates\":[{\"index\":0,\"content\":{\"parts\":[{\"text\":\"ok\"}]}}],\"usageMetadata\":{\"promptTokenCount\":5,\"candidatesTokenCount\":2,\"totalTokenCount\":7}}\n\n") + _, _ = fmt.Fprint(w, "data: {\"candidates\":[{\"index\":0,\"finishReason\":\"STOP\",\"content\":{\"parts\":[]}}]}\n\n") + })) + defer server.Close() + + p, err := New(provider.RuntimeConfig{ + Driver: provider.DriverGemini, + BaseURL: server.URL, + DefaultModel: "gemini-2.5-flash", + APIKeyEnv: "GEMINI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + }) + if err != nil { + t.Fatalf("New() error = %v", err) + } + + reader := &stubSessionAssetReader{ + maxOpen: 1, + assets: map[string]stubSessionAsset{ + "asset-1": {data: []byte("image-bytes"), mime: "image/png"}, + }, + } + request := providertypes.GenerateRequest{ + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewSessionAssetImagePart("asset-1", "image/png")}, + }}, + SessionAssetReader: reader, + } + if _, err := p.EstimateInputTokens(context.Background(), request); err != nil { + t.Fatalf("EstimateInputTokens() error = %v", err) + } + + events := make(chan providertypes.StreamEvent, 8) + if err := p.Generate(context.Background(), request, events); err != nil { + t.Fatalf("Generate() error = %v", err) + } + if reader.openCount != 1 { + t.Fatalf("expected session asset to be opened once, got %d", reader.openCount) + } +} + func drainEvents(events <-chan providertypes.StreamEvent) []providertypes.StreamEvent { var drained []providertypes.StreamEvent for { @@ -239,10 +286,16 @@ type stubSessionAsset struct { } type stubSessionAssetReader struct { - assets map[string]stubSessionAsset + assets map[string]stubSessionAsset + openCount int + maxOpen int } -func (r stubSessionAssetReader) Open(_ context.Context, assetID string) (io.ReadCloser, string, error) { +func (r *stubSessionAssetReader) Open(_ context.Context, assetID string) (io.ReadCloser, string, error) { + if r.maxOpen > 0 && r.openCount >= r.maxOpen { + return nil, "", fmt.Errorf("open limit exceeded for asset: %s", assetID) + } + r.openCount++ asset, ok := r.assets[assetID] if !ok { return nil, "", fmt.Errorf("asset not found: %s", assetID) diff --git a/internal/provider/openaicompat/generate_sdk.go b/internal/provider/openaicompat/generate_sdk.go index 890997d6..4bcc78de 100644 --- a/internal/provider/openaicompat/generate_sdk.go +++ b/internal/provider/openaicompat/generate_sdk.go @@ -21,14 +21,9 @@ import ( // generateSDKChatCompletions 走 SDK chat/completions 发送请求 func (p *Provider) generateSDKChatCompletions( ctx context.Context, - req providertypes.GenerateRequest, + payload chatcompletions.Request, events chan<- providertypes.StreamEvent, ) error { - payload, err := chatcompletions.BuildRequest(ctx, p.cfg, req) - if err != nil { - return err - } - client, err := p.newSDKClient() if err != nil { return err @@ -280,13 +275,9 @@ func (p *Provider) generateChatCompletionsWithCompatibleStream( // generateSDKResponses 走 SDK responses 发送请求,复用本地流事件映射。 func (p *Provider) generateSDKResponses( ctx context.Context, - req providertypes.GenerateRequest, + payload responses.Request, events chan<- providertypes.StreamEvent, ) error { - payload, err := responses.BuildRequest(ctx, p.cfg, req) - if err != nil { - return err - } endpoint, err := resolveChatEndpoint(p.cfg) if err != nil { return err diff --git a/internal/provider/openaicompat/openaicompat_test.go b/internal/provider/openaicompat/openaicompat_test.go index 3cb176be..a45196c9 100644 --- a/internal/provider/openaicompat/openaicompat_test.go +++ b/internal/provider/openaicompat/openaicompat_test.go @@ -3,6 +3,7 @@ package openaicompat import ( "context" "encoding/json" + "fmt" "io" "net/http" "net/http/httptest" @@ -245,7 +246,7 @@ func TestDiscoverModelsParsesNestedContainerAndAliasFields(t *testing.T) { } } -func TestEstimateInputTokensReturnsGateableLocalEstimate(t *testing.T) { +func TestEstimateInputTokensReturnsAdvisoryLocalEstimate(t *testing.T) { t.Parallel() p, err := New(resolvedConfig("", "")) @@ -264,14 +265,61 @@ func TestEstimateInputTokensReturnsGateableLocalEstimate(t *testing.T) { if estimate.EstimateSource != provider.EstimateSourceLocal { t.Fatalf("estimate source = %q, want %q", estimate.EstimateSource, provider.EstimateSourceLocal) } - if estimate.GatePolicy != provider.EstimateGateGateable { - t.Fatalf("gate policy = %q, want %q", estimate.GatePolicy, provider.EstimateGateGateable) + if estimate.GatePolicy != provider.EstimateGateAdvisory { + t.Fatalf("gate policy = %q, want %q", estimate.GatePolicy, provider.EstimateGateAdvisory) } if estimate.EstimatedInputTokens <= 0 { t.Fatalf("expected positive estimate tokens, got %d", estimate.EstimatedInputTokens) } } +func TestEstimateThenGenerateReusesPreparedRequest(t *testing.T) { + t.Setenv(config.OpenAIDefaultAPIKeyEnv, "test-key") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + _, _ = w.Write([]byte(`data: {"choices":[{"delta":{"content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}} +data: [DONE] + +`)) + })) + defer server.Close() + + p, err := New(resolvedConfig(server.URL, "gpt-4.1")) + if err != nil { + t.Fatalf("New() error = %v", err) + } + p.client = server.Client() + + reader := &singleUseSessionAssetReader{ + maxOpen: 1, + assets: map[string]sessionAsset{ + "asset-1": {data: []byte("image-bytes"), mime: "image/png"}, + }, + } + request := providertypes.GenerateRequest{ + Model: "gpt-4.1", + Messages: []providertypes.Message{ + { + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewSessionAssetImagePart("asset-1", "image/png")}, + }, + }, + SessionAssetReader: reader, + } + if _, err := p.EstimateInputTokens(context.Background(), request); err != nil { + t.Fatalf("EstimateInputTokens() error = %v", err) + } + + events := make(chan providertypes.StreamEvent, 8) + if err := p.Generate(context.Background(), request, events); err != nil { + t.Fatalf("Generate() error = %v", err) + } + if reader.openCount != 1 { + t.Fatalf("expected session asset to be opened once, got %d", reader.openCount) + } +} + func TestDiscoverModelsOpenAIProfileFallsBackToGenericListKeys(t *testing.T) { t.Parallel() @@ -733,3 +781,30 @@ func (r *cancelAfterDoneReader) Read(p []byte) (int, error) { r.cancel() return 0, r.err } + +type sessionAsset struct { + data []byte + mime string + err error +} + +type singleUseSessionAssetReader struct { + assets map[string]sessionAsset + openCount int + maxOpen int +} + +func (r *singleUseSessionAssetReader) Open(_ context.Context, assetID string) (io.ReadCloser, string, error) { + if r.maxOpen > 0 && r.openCount >= r.maxOpen { + return nil, "", fmt.Errorf("open limit exceeded for asset: %s", assetID) + } + r.openCount++ + asset, ok := r.assets[assetID] + if !ok { + return nil, "", fmt.Errorf("asset not found: %s", assetID) + } + if asset.err != nil { + return nil, "", asset.err + } + return io.NopCloser(strings.NewReader(string(asset.data))), asset.mime, nil +} diff --git a/internal/provider/openaicompat/provider.go b/internal/provider/openaicompat/provider.go index 6227f9a2..db841b81 100644 --- a/internal/provider/openaicompat/provider.go +++ b/internal/provider/openaicompat/provider.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" "strings" + "sync" "time" "neo-code/internal/provider" @@ -38,6 +39,15 @@ func validateRuntimeConfig(cfg provider.RuntimeConfig) error { type Provider struct { cfg provider.RuntimeConfig client *http.Client + + mu sync.Mutex + prepared *preparedRequest +} + +type preparedRequest struct { + mode string + signature string + payload any } // EstimateInputTokens 基于 OpenAI-compatible 最终请求结构做本地输入 token 估算。 @@ -58,12 +68,18 @@ func (p *Provider) EstimateInputTokens( return providertypes.BudgetEstimate{}, buildErr } tokens, err = provider.EstimateSerializedPayloadTokens(payload) + if err == nil { + p.storePreparedRequest(mode, provider.BuildGenerateRequestSignature(req), payload) + } case executionModeResponses: payload, buildErr := responses.BuildRequest(ctx, p.cfg, req) if buildErr != nil { return providertypes.BudgetEstimate{}, buildErr } tokens, err = provider.EstimateSerializedPayloadTokens(payload) + if err == nil { + p.storePreparedRequest(mode, provider.BuildGenerateRequestSignature(req), payload) + } default: return providertypes.BudgetEstimate{}, provider.NewDiscoveryConfigError( fmt.Sprintf("openaicompat provider: driver %q resolved unsupported execution mode %q", p.cfg.Driver, mode), @@ -75,7 +91,7 @@ func (p *Provider) EstimateInputTokens( return providertypes.BudgetEstimate{ EstimatedInputTokens: tokens, EstimateSource: provider.EstimateSourceLocal, - GatePolicy: provider.EstimateGateGateable, + GatePolicy: provider.EstimateGateAdvisory, }, nil } @@ -139,9 +155,25 @@ func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateReque switch mode { case executionModeCompletions: - return p.generateSDKChatCompletions(ctx, req, events) + signature := provider.BuildGenerateRequestSignature(req) + if payload, ok := p.takePreparedChatCompletionsRequest(mode, signature); ok { + return p.generateSDKChatCompletions(ctx, payload, events) + } + payload, buildErr := chatcompletions.BuildRequest(ctx, p.cfg, req) + if buildErr != nil { + return buildErr + } + return p.generateSDKChatCompletions(ctx, payload, events) case executionModeResponses: - return p.generateSDKResponses(ctx, req, events) + signature := provider.BuildGenerateRequestSignature(req) + if payload, ok := p.takePreparedResponsesRequest(mode, signature); ok { + return p.generateSDKResponses(ctx, payload, events) + } + payload, buildErr := responses.BuildRequest(ctx, p.cfg, req) + if buildErr != nil { + return buildErr + } + return p.generateSDKResponses(ctx, payload, events) default: return provider.NewDiscoveryConfigError( fmt.Sprintf("openaicompat provider: driver %q resolved unsupported execution mode %q", p.cfg.Driver, mode), @@ -149,6 +181,61 @@ func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateReque } } +// storePreparedRequest 缓存估算阶段已构建请求,供同轮发送复用以避免重复构建。 +func (p *Provider) storePreparedRequest(mode string, signature string, payload any) { + p.mu.Lock() + defer p.mu.Unlock() + p.prepared = &preparedRequest{ + mode: mode, + signature: strings.TrimSpace(signature), + payload: payload, + } +} + +// takePreparedChatCompletionsRequest 读取并消费 chat/completions 预构建请求,仅在签名匹配时命中。 +func (p *Provider) takePreparedChatCompletionsRequest(mode string, signature string) (chatcompletions.Request, bool) { + raw, ok := p.takePreparedRequest(mode, signature) + if !ok { + return chatcompletions.Request{}, false + } + payload, ok := raw.(chatcompletions.Request) + if !ok { + return chatcompletions.Request{}, false + } + return payload, true +} + +// takePreparedResponsesRequest 读取并消费 responses 预构建请求,仅在签名匹配时命中。 +func (p *Provider) takePreparedResponsesRequest(mode string, signature string) (responses.Request, bool) { + raw, ok := p.takePreparedRequest(mode, signature) + if !ok { + return responses.Request{}, false + } + payload, ok := raw.(responses.Request) + if !ok { + return responses.Request{}, false + } + return payload, true +} + +// takePreparedRequest 读取并消费缓存请求,避免跨请求误复用。 +func (p *Provider) takePreparedRequest(mode string, signature string) (any, bool) { + p.mu.Lock() + defer p.mu.Unlock() + if p.prepared == nil { + return nil, false + } + current := p.prepared + p.prepared = nil + if current.mode != mode { + return nil, false + } + if strings.TrimSpace(signature) == "" || current.signature != strings.TrimSpace(signature) { + return nil, false + } + return current.payload, true +} + // resolveExecutionMode 解析当前配置对应的 OpenAI-compatible 执行模式。 func resolveExecutionMode(cfg provider.RuntimeConfig) (string, error) { if provider.NormalizeProviderDriver(cfg.Driver) != DriverName { diff --git a/internal/runtime/run.go b/internal/runtime/run.go index 681d3d23..76316f24 100644 --- a/internal/runtime/run.go +++ b/internal/runtime/run.go @@ -147,7 +147,12 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { continue } - decision, err := s.evaluateTurnBudget(ctx, &state, snapshot) + modelProvider, err := s.providerFactory.Build(ctx, snapshot.ProviderConfig) + if err != nil { + return s.handleRunError(ctx, state.runID, state.session.ID, err) + } + + decision, err := s.evaluateTurnBudget(ctx, &state, snapshot, modelProvider) if err != nil { return s.handleRunError(ctx, state.runID, state.session.ID, err) } @@ -168,7 +173,7 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { return nil } - turnOutput, err := s.callProviderWithRetry(ctx, &state, snapshot) + turnOutput, err := s.callProviderWithRetry(ctx, &state, snapshot, modelProvider) if err != nil { if provider.IsContextTooLong(err) && state.reactiveCompactAttempts < snapshot.Config.Context.Budget.MaxReactiveCompacts { @@ -388,6 +393,7 @@ func (s *Service) callProviderWithRetry( ctx context.Context, state *runState, snapshot TurnBudgetSnapshot, + initialProvider provider.Provider, ) (turnProviderOutput, error) { var lastErr error @@ -405,9 +411,13 @@ func (s *Service) callProviderWithRetry( } } - modelProvider, err := s.providerFactory.Build(ctx, snapshot.ProviderConfig) - if err != nil { - return turnProviderOutput{}, err + modelProvider := initialProvider + if retryAttempt > 0 { + var err error + modelProvider, err = s.providerFactory.Build(ctx, snapshot.ProviderConfig) + if err != nil { + return turnProviderOutput{}, err + } } streamOutcome := generateStreamingMessage(ctx, modelProvider, snapshot.Request, streaming.Hooks{ @@ -524,11 +534,8 @@ func (s *Service) evaluateTurnBudget( ctx context.Context, state *runState, snapshot TurnBudgetSnapshot, + modelProvider provider.Provider, ) (controlplane.TurnBudgetDecision, error) { - modelProvider, err := s.providerFactory.Build(ctx, snapshot.ProviderConfig) - if err != nil { - return controlplane.TurnBudgetDecision{}, err - } providerEstimate, err := modelProvider.EstimateInputTokens(ctx, snapshot.Request) if err != nil { return controlplane.TurnBudgetDecision{}, err diff --git a/internal/runtime/runtime_remaining_branches_test.go b/internal/runtime/runtime_remaining_branches_test.go index dd1c3aab..2eeabc2c 100644 --- a/internal/runtime/runtime_remaining_branches_test.go +++ b/internal/runtime/runtime_remaining_branches_test.go @@ -606,7 +606,12 @@ func TestRunAndProviderRetryRemainingBranches(t *testing.T) { }() service := &Service{providerFactory: &scriptedProviderFactory{provider: providerRetry}, events: make(chan RuntimeEvent, 8)} state := newRunState("run-retry-backoff", newRuntimeSession("session-retry-backoff")) - _, err := service.callProviderWithRetry(ctx, &state, TurnBudgetSnapshot{ProviderConfig: provider.RuntimeConfig{Name: "x"}}) + _, err := service.callProviderWithRetry( + ctx, + &state, + TurnBudgetSnapshot{ProviderConfig: provider.RuntimeConfig{Name: "x"}}, + providerRetry, + ) if !errors.Is(err, context.Canceled) { t.Fatalf("expected context.Canceled, got %v", err) } @@ -621,7 +626,12 @@ func TestRunAndProviderRetryRemainingBranches(t *testing.T) { }} service := &Service{providerFactory: &scriptedProviderFactory{provider: providerRetry}, events: make(chan RuntimeEvent, 8)} state := newRunState("run-retry-ctx-check", newRuntimeSession("session-retry-ctx-check")) - _, err := service.callProviderWithRetry(ctx, &state, TurnBudgetSnapshot{ProviderConfig: provider.RuntimeConfig{Name: "x"}}) + _, err := service.callProviderWithRetry( + ctx, + &state, + TurnBudgetSnapshot{ProviderConfig: provider.RuntimeConfig{Name: "x"}}, + providerRetry, + ) if !errors.Is(err, context.Canceled) { t.Fatalf("expected context.Canceled, got %v", err) } diff --git a/internal/runtime/runtime_test.go b/internal/runtime/runtime_test.go index c06059b1..091ed67a 100644 --- a/internal/runtime/runtime_test.go +++ b/internal/runtime/runtime_test.go @@ -1022,7 +1022,7 @@ func TestServiceRun(t *testing.T) { t.Fatalf("Run() error = %v", err) } - expectedProviderBuilds := tt.expectProviderCalls * 2 + expectedProviderBuilds := tt.expectProviderCalls if factory.calls != expectedProviderBuilds { t.Fatalf("expected %d provider builds, got %d", expectedProviderBuilds, factory.calls) } @@ -3797,6 +3797,7 @@ func TestCallProviderWithRetryReturnsCombinedForwardError(t *testing.T) { context.Background(), &state, snapshot, + scripted, ) if err == nil || !containsError(err, "provider stream handling failed after provider error") { t.Fatalf("expected combined forward/provider error, got %v", err) diff --git a/internal/tui/services/gateway_stream_client.go b/internal/tui/services/gateway_stream_client.go index 34219c1d..72d137ba 100644 --- a/internal/tui/services/gateway_stream_client.go +++ b/internal/tui/services/gateway_stream_client.go @@ -71,15 +71,19 @@ func (c *GatewayStreamClient) run() { event, err := decodeRuntimeEventFromGatewayNotification(notification) if err != nil { + errMessage := fmt.Sprintf("gateway stream decode error: %v", err) select { case <-c.closeCh: return case c.events <- RuntimeEvent{ Type: EventError, Timestamp: time.Now().UTC(), - Payload: fmt.Sprintf("gateway stream decode error: %v", err), + Payload: errMessage, }: } + if isRuntimePayloadVersionMismatch(errMessage) { + return + } continue } @@ -92,6 +96,13 @@ func (c *GatewayStreamClient) run() { } } +// isRuntimePayloadVersionMismatch 判断错误是否由 runtime 事件版本不匹配触发,用于快速停止消费避免噪声洪泛。 +func isRuntimePayloadVersionMismatch(errMessage string) bool { + normalized := strings.ToLower(strings.TrimSpace(errMessage)) + return strings.Contains(normalized, "payload_version") && + strings.Contains(normalized, "unsupported") +} + // decodeRuntimeEventFromGatewayNotification 将 gateway.event 通知还原为事件。 func decodeRuntimeEventFromGatewayNotification(notification gatewayRPCNotification) (RuntimeEvent, error) { var frame gateway.MessageFrame diff --git a/internal/tui/services/gateway_stream_client_additional_test.go b/internal/tui/services/gateway_stream_client_additional_test.go index 25a6a660..da167c07 100644 --- a/internal/tui/services/gateway_stream_client_additional_test.go +++ b/internal/tui/services/gateway_stream_client_additional_test.go @@ -317,6 +317,57 @@ func TestGatewayStreamClientRunSkipsNonGatewayEventsAndStopsOnClose(t *testing.T } } +func TestGatewayStreamClientRunStopsOnPayloadVersionMismatch(t *testing.T) { + t.Parallel() + + source := make(chan gatewayRPCNotification, 3) + client := NewGatewayStreamClient(source) + + source <- buildGatewayEventNotification(t, gateway.MessageFrame{ + Type: gateway.FrameTypeEvent, + Action: gateway.FrameActionRun, + Payload: map[string]any{ + "runtime_event_type": string(EventAgentChunk), + "payload_version": runtimeEventPayloadVersion - 1, + "payload": "legacy", + }, + }) + source <- buildGatewayEventNotification(t, gateway.MessageFrame{ + Type: gateway.FrameTypeEvent, + Action: gateway.FrameActionRun, + Payload: map[string]any{ + "runtime_event_type": string(EventAgentChunk), + "payload_version": runtimeEventPayloadVersion, + "payload": "ok", + }, + }) + + select { + case event, ok := <-client.Events(): + if !ok { + t.Fatalf("events channel closed before decode error event") + } + if event.Type != EventError { + t.Fatalf("event.Type = %q, want %q", event.Type, EventError) + } + payload, payloadOK := event.Payload.(string) + if !payloadOK || !containsAll(payload, "payload_version", "want") { + t.Fatalf("event.Payload = %#v", event.Payload) + } + case <-time.After(2 * time.Second): + t.Fatalf("timed out waiting for decode error event") + } + + select { + case _, ok := <-client.Events(): + if ok { + t.Fatalf("expected stream to stop after payload version mismatch") + } + case <-time.After(2 * time.Second): + t.Fatalf("timed out waiting for events channel close") + } +} + func containsAll(input string, subs ...string) bool { for _, sub := range subs { if !strings.Contains(input, sub) { From ebfd4b266e020c52a374283f266da1b6b4a5d9c5 Mon Sep 17 00:00:00 2001 From: xgopilot Date: Thu, 23 Apr 2026 04:16:56 +0000 Subject: [PATCH 5/9] fix(runtime): degrade when estimate tokens fails Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: phantom5099 <245659304+phantom5099@users.noreply.github.com> --- internal/runtime/controlplane/budget.go | 2 + internal/runtime/events.go | 21 +++++ internal/runtime/run.go | 14 ++- internal/runtime/runtime_test.go | 117 ++++++++++++++++++++++++ 4 files changed, 153 insertions(+), 1 deletion(-) diff --git a/internal/runtime/controlplane/budget.go b/internal/runtime/controlplane/budget.go index 872496e5..6ac5a291 100644 --- a/internal/runtime/controlplane/budget.go +++ b/internal/runtime/controlplane/budget.go @@ -19,6 +19,8 @@ const ( const ( // BudgetDecisionReasonWithinBudget 表示估算在预算范围内。 BudgetDecisionReasonWithinBudget = "within_budget" + // BudgetDecisionReasonEstimateFailedBypass 表示估算失败后跳过预算门禁并放行。 + BudgetDecisionReasonEstimateFailedBypass = "estimate_failed_bypass" // BudgetDecisionReasonExceedsBudgetFirstTime 表示首次超预算,需要先 compact。 BudgetDecisionReasonExceedsBudgetFirstTime = "exceeds_budget_first_time" // BudgetDecisionReasonExceedsBudgetAfterCompactStop 表示 compact 后仍超预算且可门禁,必须停止。 diff --git a/internal/runtime/events.go b/internal/runtime/events.go index 03416ba2..2b42cbd9 100644 --- a/internal/runtime/events.go +++ b/internal/runtime/events.go @@ -39,6 +39,13 @@ type BudgetCheckedPayload struct { EstimateGatePolicy string `json:"estimate_gate_policy,omitempty"` } +// BudgetEstimateFailedPayload 描述预算估算失败时的降级诊断信息。 +type BudgetEstimateFailedPayload struct { + AttemptSeq int `json:"attempt_seq"` + RequestHash string `json:"request_hash"` + Message string `json:"message"` +} + // ProgressEvaluatedPayload 汇总 progress 控制面的评估结果。 type ProgressEvaluatedPayload struct { Score controlplane.ProgressScore `json:"score"` @@ -75,6 +82,18 @@ func newBudgetCheckedPayload(decision controlplane.TurnBudgetDecision) BudgetChe } } +// newBudgetEstimateFailedPayload 将估算失败错误转换为 runtime 诊断事件 payload。 +func newBudgetEstimateFailedPayload(id controlplane.TurnBudgetID, err error) BudgetEstimateFailedPayload { + payload := BudgetEstimateFailedPayload{ + AttemptSeq: id.AttemptSeq, + RequestHash: id.RequestHash, + } + if err != nil { + payload.Message = err.Error() + } + return payload +} + // newLedgerReconciledPayload 将 usage observation 与调和结果拼装为对外事件 payload。 func newLedgerReconciledPayload( observation TurnBudgetUsageObservation, @@ -200,6 +219,8 @@ const ( EventPhaseChanged EventType = "phase_changed" // EventBudgetChecked 表示预算控制面对冻结请求完成一次预算决策。 EventBudgetChecked EventType = "budget_checked" + // EventBudgetEstimateFailed 表示预算估算失败并进入降级放行。 + EventBudgetEstimateFailed EventType = "budget_estimate_failed" // EventProgressEvaluated 表示 progress 评估完成。 EventProgressEvaluated EventType = "progress_evaluated" // EventStopReasonDecided 表示 stop reason 已决议。 diff --git a/internal/runtime/run.go b/internal/runtime/run.go index 76316f24..17197798 100644 --- a/internal/runtime/run.go +++ b/internal/runtime/run.go @@ -538,7 +538,19 @@ func (s *Service) evaluateTurnBudget( ) (controlplane.TurnBudgetDecision, error) { providerEstimate, err := modelProvider.EstimateInputTokens(ctx, snapshot.Request) if err != nil { - return controlplane.TurnBudgetDecision{}, err + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return controlplane.TurnBudgetDecision{}, err + } + s.emitRunScoped(ctx, EventBudgetEstimateFailed, state, newBudgetEstimateFailedPayload(snapshot.ID, err)) + decision := controlplane.TurnBudgetDecision{ + ID: snapshot.ID, + Action: controlplane.TurnBudgetActionAllow, + Reason: controlplane.BudgetDecisionReasonEstimateFailedBypass, + PromptBudget: snapshot.PromptBudget, + EstimateGatePolicy: provider.EstimateGateAdvisory, + } + s.emitRunScoped(ctx, EventBudgetChecked, state, newBudgetCheckedPayload(decision)) + return decision, nil } estimate := newTurnBudgetEstimate(snapshot.ID, providerEstimate) decision := controlplane.DecideTurnBudget( diff --git a/internal/runtime/runtime_test.go b/internal/runtime/runtime_test.go index 091ed67a..3d6fea52 100644 --- a/internal/runtime/runtime_test.go +++ b/internal/runtime/runtime_test.go @@ -4707,6 +4707,123 @@ func TestServiceRunAllowsAfterProactiveCompactWhenEstimateAdvisory(t *testing.T) } } +func TestServiceRunBypassesBudgetGateWhenEstimateFails(t *testing.T) { + t.Parallel() + + manager := newRuntimeConfigManager(t) + if err := manager.Update(context.Background(), func(cfg *config.Config) error { + cfg.Context.Budget.PromptBudget = 10 + cfg.Context.Budget.FallbackPromptBudget = 10 + return nil + }); err != nil { + t.Fatalf("update config: %v", err) + } + + store := newMemoryStore() + registry := tools.NewRegistry() + scripted := &scriptedProvider{ + estimateFn: func(ctx context.Context, req providertypes.GenerateRequest) (providertypes.BudgetEstimate, error) { + _ = ctx + _ = req + return providertypes.BudgetEstimate{}, errors.New("estimate unavailable") + }, + responses: []scriptedResponse{ + { + Message: providertypes.Message{ + Role: providertypes.RoleAssistant, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("继续执行")}, + }, + FinishReason: "stop", + }, + }, + } + + service := NewWithFactory(manager, registry, store, &scriptedProviderFactory{provider: scripted}, &stubContextBuilder{}) + + if err := service.Run(context.Background(), UserInput{ + RunID: "run-budget-estimate-failed-bypass", + Parts: []providertypes.ContentPart{providertypes.NewTextPart("continue")}, + }); err != nil { + t.Fatalf("Run() error = %v", err) + } + if scripted.callCount != 1 { + t.Fatalf("expected provider Generate to be called once, got %d calls", scripted.callCount) + } + + events := collectRuntimeEvents(service.Events()) + foundDiagnostic := false + foundBudgetChecked := false + var stopPayload StopReasonDecidedPayload + for _, event := range events { + switch event.Type { + case EventBudgetEstimateFailed: + payload, ok := event.Payload.(BudgetEstimateFailedPayload) + if !ok { + t.Fatalf("expected BudgetEstimateFailedPayload, got %T", event.Payload) + } + if payload.Message == "" { + t.Fatalf("expected non-empty estimate failure message") + } + foundDiagnostic = true + case EventBudgetChecked: + payload, ok := event.Payload.(BudgetCheckedPayload) + if !ok { + t.Fatalf("expected BudgetCheckedPayload, got %T", event.Payload) + } + if payload.Action != string(controlplane.TurnBudgetActionAllow) { + t.Fatalf("expected budget action allow, got %q", payload.Action) + } + if payload.Reason != controlplane.BudgetDecisionReasonEstimateFailedBypass { + t.Fatalf("expected reason %q, got %q", controlplane.BudgetDecisionReasonEstimateFailedBypass, payload.Reason) + } + foundBudgetChecked = true + case EventStopReasonDecided: + payload, ok := event.Payload.(StopReasonDecidedPayload) + if !ok { + t.Fatalf("expected StopReasonDecidedPayload, got %T", event.Payload) + } + stopPayload = payload + } + } + if !foundDiagnostic { + t.Fatalf("expected budget_estimate_failed event") + } + if !foundBudgetChecked { + t.Fatalf("expected budget_checked event") + } + if stopPayload.Reason != controlplane.StopReasonCompleted { + t.Fatalf("expected stop reason %q, got %q", controlplane.StopReasonCompleted, stopPayload.Reason) + } + assertNoEventType(t, events, EventError) +} + +func TestServiceRunFailsWhenEstimateContextCanceled(t *testing.T) { + t.Parallel() + + manager := newRuntimeConfigManager(t) + store := newMemoryStore() + registry := tools.NewRegistry() + scripted := &scriptedProvider{ + estimateFn: func(ctx context.Context, req providertypes.GenerateRequest) (providertypes.BudgetEstimate, error) { + _ = ctx + _ = req + return providertypes.BudgetEstimate{}, context.Canceled + }, + } + + service := NewWithFactory(manager, registry, store, &scriptedProviderFactory{provider: scripted}, &stubContextBuilder{}) + err := service.Run(context.Background(), UserInput{ + RunID: "run-budget-estimate-canceled", + Parts: []providertypes.ContentPart{providertypes.NewTextPart("continue")}, + }) + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context.Canceled, got %v", err) + } + + events := collectRuntimeEvents(service.Events()) + assertNoEventType(t, events, EventBudgetEstimateFailed) +} + func TestServiceRunReconcilesUnknownOutputUsage(t *testing.T) { t.Parallel() From eb360d310ad91c9e256f7f289a6c4d0fafdeccba Mon Sep 17 00:00:00 2001 From: xgopilot Date: Thu, 23 Apr 2026 04:38:20 +0000 Subject: [PATCH 6/9] fix(runtime): reconcile token ledger with partial usage observation Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: phantom5099 <245659304+phantom5099@users.noreply.github.com> --- internal/provider/anthropic/provider.go | 9 ++- internal/provider/anthropic/provider_test.go | 59 ++++++++++++++++ internal/provider/gemini/provider.go | 5 ++ internal/provider/gemini/provider_test.go | 55 +++++++++++++++ .../openaicompat/chatcompletions/adapter.go | 40 +++++++---- .../chatcompletions/adapter_test.go | 33 +++++++++ .../openaicompat/responses/adapter.go | 19 +++-- .../openaicompat/responses/adapter_test.go | 33 +++++++++ internal/provider/types/usage.go | 8 ++- internal/runtime/budget_models.go | 7 +- internal/runtime/provider_stream.go | 14 ++-- internal/runtime/run.go | 24 ++++--- .../runtime/runtime_internal_helpers_test.go | 69 +++++++++++++++++++ internal/runtime/runtime_test.go | 8 ++- 14 files changed, 343 insertions(+), 40 deletions(-) diff --git a/internal/provider/anthropic/provider.go b/internal/provider/anthropic/provider.go index 1e176a16..4c594286 100644 --- a/internal/provider/anthropic/provider.go +++ b/internal/provider/anthropic/provider.go @@ -98,9 +98,11 @@ func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateReque case anthropic.MessageStartEvent: if variant.Message.Usage.InputTokens > 0 { usage.InputTokens = int(variant.Message.Usage.InputTokens) + usage.InputObserved = true } if variant.Message.Usage.OutputTokens > 0 { usage.OutputTokens = int(variant.Message.Usage.OutputTokens) + usage.OutputObserved = true } case anthropic.ContentBlockStartEvent: switch block := variant.ContentBlock.AsAny().(type) { @@ -167,9 +169,11 @@ func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateReque } if variant.Usage.OutputTokens > 0 { usage.OutputTokens = int(variant.Usage.OutputTokens) + usage.OutputObserved = true } if variant.Usage.InputTokens > 0 { usage.InputTokens = int(variant.Usage.InputTokens) + usage.InputObserved = true } } } @@ -193,9 +197,12 @@ func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateReque return fmt.Errorf("%sinvalid tool_use stream at index %d: missing tool name", errorPrefix, index) } } - if usage.TotalTokens <= 0 { + if usage.TotalTokens <= 0 && (usage.InputObserved || usage.OutputObserved) { usage.TotalTokens = usage.InputTokens + usage.OutputTokens } + if !usage.InputObserved && !usage.OutputObserved { + return provider.EmitMessageDone(ctx, events, finishReason, nil) + } return provider.EmitMessageDone(ctx, events, finishReason, &usage) } diff --git a/internal/provider/anthropic/provider_test.go b/internal/provider/anthropic/provider_test.go index 63a2ecdd..a7219d20 100644 --- a/internal/provider/anthropic/provider_test.go +++ b/internal/provider/anthropic/provider_test.go @@ -92,6 +92,9 @@ func TestProviderGenerate(t *testing.T) { if payload.Usage == nil || payload.Usage.TotalTokens != 10 { t.Fatalf("expected usage total tokens 10, got %+v", payload.Usage) } + if !payload.Usage.InputObserved || !payload.Usage.OutputObserved { + t.Fatalf("expected usage observed flags true, got %+v", payload.Usage) + } } } if !foundText || !foundToolStart || !foundToolDelta || !foundDone { @@ -99,6 +102,62 @@ func TestProviderGenerate(t *testing.T) { } } +func TestProviderGenerateOmitsUsageWhenProviderDidNotReturnUsage(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + _, _ = fmt.Fprint(w, "event: content_block_start\n") + _, _ = fmt.Fprint(w, "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"Hello\"}}\n\n") + _, _ = fmt.Fprint(w, "event: message_delta\n") + _, _ = fmt.Fprint(w, "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"}}\n\n") + _, _ = fmt.Fprint(w, "event: message_stop\n") + _, _ = fmt.Fprint(w, "data: {\"type\":\"message_stop\"}\n\n") + })) + defer server.Close() + + p, err := New(provider.RuntimeConfig{ + Driver: provider.DriverAnthropic, + BaseURL: server.URL, + DefaultModel: "claude-3-7-sonnet", + APIKeyEnv: "ANTHROPIC_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + }) + if err != nil { + t.Fatalf("New() error = %v", err) + } + + events := make(chan providertypes.StreamEvent, 8) + if err := p.Generate(context.Background(), providertypes.GenerateRequest{ + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("hi")}, + }}, + }, events); err != nil { + t.Fatalf("Generate() error = %v", err) + } + + drained := drainEvents(events) + var done *providertypes.MessageDonePayload + for i := range drained { + if drained[i].Type != providertypes.StreamEventMessageDone { + continue + } + payload, payloadErr := drained[i].MessageDoneValue() + if payloadErr != nil { + t.Fatalf("MessageDoneValue() error = %v", payloadErr) + } + done = &payload + break + } + if done == nil { + t.Fatalf("expected message_done event, got %+v", drained) + } + if done.Usage != nil { + t.Fatalf("expected nil usage when provider does not report usage, got %+v", done.Usage) + } +} + func TestNewAcceptsCustomChatEndpointPath(t *testing.T) { t.Parallel() diff --git a/internal/provider/gemini/provider.go b/internal/provider/gemini/provider.go index 5a6ea83e..017ce282 100644 --- a/internal/provider/gemini/provider.go +++ b/internal/provider/gemini/provider.go @@ -157,6 +157,9 @@ func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateReque if !hasPayload { return fmt.Errorf("%w: empty gemini stream payload", provider.ErrStreamInterrupted) } + if !usage.InputObserved && !usage.OutputObserved { + return provider.EmitMessageDone(ctx, events, finishReason, nil) + } return provider.EmitMessageDone(ctx, events, finishReason, &usage) } @@ -209,6 +212,8 @@ func extractUsage(usage *providertypes.Usage, raw *genai.GenerateContentResponse usage.InputTokens = int(raw.PromptTokenCount) usage.OutputTokens = int(raw.CandidatesTokenCount) usage.TotalTokens = int(raw.TotalTokenCount) + usage.InputObserved = true + usage.OutputObserved = true } // encodeArguments 将函数参数对象编码为 JSON 字符串,供统一 tool_call_delta 事件复用。 diff --git a/internal/provider/gemini/provider_test.go b/internal/provider/gemini/provider_test.go index 3e8296bb..8109fe8c 100644 --- a/internal/provider/gemini/provider_test.go +++ b/internal/provider/gemini/provider_test.go @@ -80,6 +80,9 @@ func TestProviderGenerate(t *testing.T) { if payload.Usage == nil || payload.Usage.TotalTokens != 7 { t.Fatalf("expected usage total tokens 7, got %+v", payload.Usage) } + if !payload.Usage.InputObserved || !payload.Usage.OutputObserved { + t.Fatalf("expected usage observed flags true, got %+v", payload.Usage) + } if payload.FinishReason != "stop" { t.Fatalf("expected finish reason stop, got %q", payload.FinishReason) } @@ -90,6 +93,58 @@ func TestProviderGenerate(t *testing.T) { } } +func TestProviderGenerateOmitsUsageWhenProviderDidNotReturnUsage(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + _, _ = fmt.Fprint(w, "data: {\"candidates\":[{\"index\":0,\"content\":{\"parts\":[{\"text\":\"Hello \"}]}}]}\n\n") + _, _ = fmt.Fprint(w, "data: {\"candidates\":[{\"index\":0,\"finishReason\":\"STOP\",\"content\":{\"parts\":[{\"text\":\"done\"}]}}]}\n\n") + })) + defer server.Close() + + p, err := New(provider.RuntimeConfig{ + Driver: provider.DriverGemini, + BaseURL: server.URL, + DefaultModel: "gemini-2.5-flash", + APIKeyEnv: "GEMINI_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + }) + if err != nil { + t.Fatalf("New() error = %v", err) + } + + events := make(chan providertypes.StreamEvent, 8) + if err := p.Generate(context.Background(), providertypes.GenerateRequest{ + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("hi")}, + }}, + }, events); err != nil { + t.Fatalf("Generate() error = %v", err) + } + + drained := drainEvents(events) + var done *providertypes.MessageDonePayload + for i := range drained { + if drained[i].Type != providertypes.StreamEventMessageDone { + continue + } + payload, payloadErr := drained[i].MessageDoneValue() + if payloadErr != nil { + t.Fatalf("MessageDoneValue() error = %v", payloadErr) + } + done = &payload + break + } + if done == nil { + t.Fatalf("expected message_done event, got %+v", drained) + } + if done.Usage != nil { + t.Fatalf("expected nil usage when provider does not report usage, got %+v", done.Usage) + } +} + func TestNewAcceptsCustomChatEndpointPath(t *testing.T) { t.Parallel() diff --git a/internal/provider/openaicompat/chatcompletions/adapter.go b/internal/provider/openaicompat/chatcompletions/adapter.go index e1fe0fe7..ba3ffc5c 100644 --- a/internal/provider/openaicompat/chatcompletions/adapter.go +++ b/internal/provider/openaicompat/chatcompletions/adapter.go @@ -72,6 +72,9 @@ func EmitFromSDKStream( return fmt.Errorf("SDK stream error: %w", err) } + if !usage.InputObserved && !usage.OutputObserved { + return provider.EmitMessageDone(ctx, events, finishReason, nil) + } return provider.EmitMessageDone(ctx, events, finishReason, &usage) } @@ -183,13 +186,13 @@ func ConsumeStream( if flushErr := flushDataLines(); flushErr != nil { return flushErr } - return provider.EmitMessageDone(ctx, events, finishReason, &usage) + return provider.EmitMessageDone(ctx, events, finishReason, doneUsagePtr(usage)) } if flushErr := flushDataLines(); flushErr != nil { return flushErr } if strings.TrimSpace(finishReason) != "" { - return provider.EmitMessageDone(ctx, events, finishReason, &usage) + return provider.EmitMessageDone(ctx, events, finishReason, doneUsagePtr(usage)) } return fmt.Errorf("%w: %w", provider.ErrStreamInterrupted, err) } @@ -206,7 +209,7 @@ func ConsumeStream( return flushErr } done = true - return provider.EmitMessageDone(ctx, events, finishReason, &usage) + return provider.EmitMessageDone(ctx, events, finishReason, doneUsagePtr(usage)) } else { dataLines = append(dataLines, data) } @@ -215,7 +218,7 @@ func ConsumeStream( return flushErr } if done { - return provider.EmitMessageDone(ctx, events, finishReason, &usage) + return provider.EmitMessageDone(ctx, events, finishReason, doneUsagePtr(usage)) } default: if len(dataLines) == 0 { @@ -225,7 +228,7 @@ func ConsumeStream( return flushErr } if done { - return provider.EmitMessageDone(ctx, events, finishReason, &usage) + return provider.EmitMessageDone(ctx, events, finishReason, doneUsagePtr(usage)) } } @@ -234,7 +237,7 @@ func ConsumeStream( return flushErr } if done || strings.TrimSpace(finishReason) != "" { - return provider.EmitMessageDone(ctx, events, finishReason, &usage) + return provider.EmitMessageDone(ctx, events, finishReason, doneUsagePtr(usage)) } return fmt.Errorf("%w: missing [DONE] marker before EOF", provider.ErrStreamInterrupted) } @@ -247,19 +250,32 @@ func extractLegacyStreamUsage(usage *providertypes.Usage, raw *streamUsage) { return } *usage = providertypes.Usage{ - InputTokens: raw.PromptTokens, - OutputTokens: raw.CompletionTokens, - TotalTokens: raw.TotalTokens, + InputTokens: raw.PromptTokens, + OutputTokens: raw.CompletionTokens, + TotalTokens: raw.TotalTokens, + InputObserved: true, + OutputObserved: true, } } // extractStreamUsage 将 OpenAI usage 覆盖到统一 token 统计。 func extractStreamUsage(usage *providertypes.Usage, raw openai.CompletionUsage) { *usage = providertypes.Usage{ - InputTokens: int(raw.PromptTokens), - OutputTokens: int(raw.CompletionTokens), - TotalTokens: int(raw.TotalTokens), + InputTokens: int(raw.PromptTokens), + OutputTokens: int(raw.CompletionTokens), + TotalTokens: int(raw.TotalTokens), + InputObserved: true, + OutputObserved: true, + } +} + +// doneUsagePtr 在 message_done 事件中按 usage 观测状态返回 payload,未观测时返回 nil。 +func doneUsagePtr(usage providertypes.Usage) *providertypes.Usage { + if !usage.InputObserved && !usage.OutputObserved { + return nil } + copy := usage + return © } // mergeToolCallDeltaFromSDK 将单个 SDK tool call 增量合并到累积状态,并在必要时发出起始/增量事件。 diff --git a/internal/provider/openaicompat/chatcompletions/adapter_test.go b/internal/provider/openaicompat/chatcompletions/adapter_test.go index 5d9c2da8..d1eb195c 100644 --- a/internal/provider/openaicompat/chatcompletions/adapter_test.go +++ b/internal/provider/openaicompat/chatcompletions/adapter_test.go @@ -43,6 +43,36 @@ func TestConsumeStreamSupportsWeakSSEFormat(t *testing.T) { if done.Usage == nil || done.Usage.TotalTokens != 3 { t.Fatalf("unexpected usage: %+v", done.Usage) } + if !done.Usage.InputObserved || !done.Usage.OutputObserved { + t.Fatalf("expected usage observed flags to be true, got %+v", done.Usage) + } +} + +func TestConsumeStreamEmitsNilUsageWhenProviderDidNotReturnUsage(t *testing.T) { + t.Parallel() + + body := strings.Join([]string{ + `data: {"choices":[{"delta":{"content":"ok"},"finish_reason":"stop"}]}`, + `data: [DONE]`, + "", + }, "\n") + + events := make(chan providertypes.StreamEvent, 4) + if err := ConsumeStream(context.Background(), strings.NewReader(body), events); err != nil { + t.Fatalf("ConsumeStream() error = %v", err) + } + + drained := drainEvents(events) + if len(drained) != 2 { + t.Fatalf("expected 2 events, got %d", len(drained)) + } + done, err := drained[1].MessageDoneValue() + if err != nil { + t.Fatalf("expected message done, got err=%v", err) + } + if done.Usage != nil { + t.Fatalf("expected nil usage when stream carries no usage, got %+v", done.Usage) + } } func TestConsumeStreamParsesMultilineDataEvent(t *testing.T) { @@ -197,6 +227,9 @@ func TestEmitFromSDKStream(t *testing.T) { if done.Usage == nil || done.Usage.TotalTokens != 3 { t.Fatalf("expected usage total tokens 3, got %+v", done.Usage) } + if !done.Usage.InputObserved || !done.Usage.OutputObserved { + t.Fatalf("expected usage observed flags to be true, got %+v", done.Usage) + } } func TestEmitFromSDKStreamErrors(t *testing.T) { diff --git a/internal/provider/openaicompat/responses/adapter.go b/internal/provider/openaicompat/responses/adapter.go index f146207d..6a2c0cf5 100644 --- a/internal/provider/openaicompat/responses/adapter.go +++ b/internal/provider/openaicompat/responses/adapter.go @@ -36,7 +36,7 @@ func EmitFromStream( if reason == "" { reason = "stop" } - return provider.EmitMessageDone(ctx, events, reason, &usage) + return provider.EmitMessageDone(ctx, events, reason, doneUsagePtr(usage)) } processPayload := func(payload string) error { if strings.TrimSpace(payload) == "[DONE]" { @@ -392,9 +392,11 @@ func extractUsage(usage *providertypes.Usage, response *streamResponse) { return } *usage = providertypes.Usage{ - InputTokens: response.Usage.InputTokens, - OutputTokens: response.Usage.OutputTokens, - TotalTokens: response.Usage.TotalTokens, + InputTokens: response.Usage.InputTokens, + OutputTokens: response.Usage.OutputTokens, + TotalTokens: response.Usage.TotalTokens, + InputObserved: true, + OutputObserved: true, } } @@ -435,3 +437,12 @@ func resolveFinishReason(eventType string, response *streamResponse) string { return "" } } + +// doneUsagePtr 在 message_done 事件中按 usage 观测状态返回 payload,未观测时返回 nil。 +func doneUsagePtr(usage providertypes.Usage) *providertypes.Usage { + if !usage.InputObserved && !usage.OutputObserved { + return nil + } + copy := usage + return © +} diff --git a/internal/provider/openaicompat/responses/adapter_test.go b/internal/provider/openaicompat/responses/adapter_test.go index a75cd0b3..f922b752 100644 --- a/internal/provider/openaicompat/responses/adapter_test.go +++ b/internal/provider/openaicompat/responses/adapter_test.go @@ -47,6 +47,39 @@ func TestEmitFromStreamSupportsMultilineSSEData(t *testing.T) { if done.Usage == nil || done.Usage.TotalTokens != 3 { t.Fatalf("unexpected usage in done event: %+v", done.Usage) } + if !done.Usage.InputObserved || !done.Usage.OutputObserved { + t.Fatalf("expected usage observed flags to be true, got %+v", done.Usage) + } +} + +func TestEmitFromStreamEmitsNilUsageWhenProviderDidNotReturnUsage(t *testing.T) { + t.Parallel() + + body := strings.Join([]string{ + `data: {"type":"response.output_text.delta","delta":"hello"}`, + "", + `data: {"type":"response.completed","response":{"status":"completed"}}`, + "", + `data: [DONE]`, + "", + }, "\n") + + events := make(chan providertypes.StreamEvent, 4) + if err := EmitFromStream(context.Background(), strings.NewReader(body), events); err != nil { + t.Fatalf("EmitFromStream() error = %v", err) + } + + drained := drainResponseEvents(events) + if len(drained) != 2 { + t.Fatalf("expected 2 events, got %d (%+v)", len(drained), drained) + } + done, err := drained[1].MessageDoneValue() + if err != nil { + t.Fatalf("expected message done event, got err=%v", err) + } + if done.Usage != nil { + t.Fatalf("expected nil usage when stream carries no usage, got %+v", done.Usage) + } } func TestEmitFromStreamSupportsLongDataLine(t *testing.T) { diff --git a/internal/provider/types/usage.go b/internal/provider/types/usage.go index a605919c..c8c76e26 100644 --- a/internal/provider/types/usage.go +++ b/internal/provider/types/usage.go @@ -2,9 +2,11 @@ package types // Usage 记录本次请求的 token 使用统计。 type Usage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - TotalTokens int `json:"total_tokens"` + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + TotalTokens int `json:"total_tokens"` + InputObserved bool `json:"input_observed"` + OutputObserved bool `json:"output_observed"` } // BudgetEstimate 描述 provider 对冻结请求输入 token 的估算结果。 diff --git a/internal/runtime/budget_models.go b/internal/runtime/budget_models.go index 6283931e..b573eeb2 100644 --- a/internal/runtime/budget_models.go +++ b/internal/runtime/budget_models.go @@ -108,13 +108,14 @@ func newTurnBudgetUsageObservation( id controlplane.TurnBudgetID, inputTokens int, outputTokens int, - observed bool, + inputObserved bool, + outputObserved bool, ) TurnBudgetUsageObservation { return TurnBudgetUsageObservation{ ID: id, InputTokens: inputTokens, OutputTokens: outputTokens, - InputObserved: observed, - OutputObserved: observed, + InputObserved: inputObserved, + OutputObserved: outputObserved, } } diff --git a/internal/runtime/provider_stream.go b/internal/runtime/provider_stream.go index b120e4a4..75d4d7ab 100644 --- a/internal/runtime/provider_stream.go +++ b/internal/runtime/provider_stream.go @@ -11,11 +11,12 @@ import ( // streamGenerateResult 统一承载一次流式生成的消息、用量与消费错误。 type streamGenerateResult struct { - message providertypes.Message - inputTokens int - outputTokens int - usagePresent bool - err error + message providertypes.Message + inputTokens int + outputTokens int + inputObserved bool + outputObserved bool + err error } // generateStreamingMessage 负责执行一次基于流式事件的生成调用,并收敛最终 assistant 消息与 usage。 @@ -41,7 +42,8 @@ func generateStreamingMessage( if payload.Usage != nil { outcome.inputTokens = payload.Usage.InputTokens outcome.outputTokens = payload.Usage.OutputTokens - outcome.usagePresent = true + outcome.inputObserved = payload.Usage.InputObserved + outcome.outputObserved = payload.Usage.OutputObserved } if userOnMessageDone != nil { userOnMessageDone(payload) diff --git a/internal/runtime/run.go b/internal/runtime/run.go index 17197798..55a86ab5 100644 --- a/internal/runtime/run.go +++ b/internal/runtime/run.go @@ -445,7 +445,8 @@ func (s *Service) callProviderWithRetry( snapshot.ID, streamOutcome.inputTokens, streamOutcome.outputTokens, - streamOutcome.usagePresent, + streamOutcome.inputObserved, + streamOutcome.outputObserved, ), }, nil } @@ -572,18 +573,23 @@ func (s *Service) reconcileLedger( return ledgerReconcileResult{}, fmt.Errorf("runtime: turn budget id mismatch between decision and usage observation") } reconciled := ledgerReconcileResult{ - inputTokens: observation.InputTokens, - inputSource: usageSourceObserved, - outputTokens: observation.OutputTokens, - outputSource: usageSourceObserved, + inputSource: usageSourceUnknown, + outputSource: usageSourceUnknown, + } + if observation.InputObserved { + reconciled.inputTokens = observation.InputTokens + reconciled.inputSource = usageSourceObserved + } else { + reconciled.inputTokens = decision.EstimatedInputTokens + reconciled.inputSource = usageSourceEstimated + } + if observation.OutputObserved { + reconciled.outputTokens = observation.OutputTokens + reconciled.outputSource = usageSourceObserved } if observation.InputObserved && observation.OutputObserved { return reconciled, nil } - reconciled.inputTokens = decision.EstimatedInputTokens - reconciled.inputSource = usageSourceEstimated - reconciled.outputTokens = 0 - reconciled.outputSource = usageSourceUnknown reconciled.hasUnknownUsage = true if state != nil { state.session.HasUnknownUsage = true diff --git a/internal/runtime/runtime_internal_helpers_test.go b/internal/runtime/runtime_internal_helpers_test.go index 4a31e7c8..375a79d1 100644 --- a/internal/runtime/runtime_internal_helpers_test.go +++ b/internal/runtime/runtime_internal_helpers_test.go @@ -11,6 +11,7 @@ import ( "neo-code/internal/provider" providertypes "neo-code/internal/provider/types" "neo-code/internal/runtime/approval" + "neo-code/internal/runtime/controlplane" agentsession "neo-code/internal/session" "neo-code/internal/tools" ) @@ -625,6 +626,74 @@ func TestEmitTokenUsageSkipsZeroUsage(t *testing.T) { } } +func TestReconcileLedgerSupportsPartialObservation(t *testing.T) { + t.Parallel() + + service := &Service{} + state := &runState{session: newRuntimeSession("session-partial-observed")} + id := controlplane.TurnBudgetID{AttemptSeq: 2, RequestHash: "hash-partial-observed"} + decision := controlplane.TurnBudgetDecision{ + ID: id, + EstimatedInputTokens: 37, + } + observation := TurnBudgetUsageObservation{ + ID: id, + InputTokens: 13, + OutputTokens: 0, + InputObserved: true, + OutputObserved: false, + } + + result, err := service.reconcileLedger(state, decision, observation) + if err != nil { + t.Fatalf("reconcileLedger() error = %v", err) + } + if result.inputTokens != 13 || result.inputSource != usageSourceObserved { + t.Fatalf("expected observed input reconciliation, got %+v", result) + } + if result.outputTokens != 0 || result.outputSource != usageSourceUnknown { + t.Fatalf("expected unknown output reconciliation, got %+v", result) + } + if !result.hasUnknownUsage { + t.Fatalf("expected hasUnknownUsage=true for partial observation") + } + if !state.session.HasUnknownUsage || !state.hasUnknownUsage { + t.Fatalf("expected unknown usage flag to propagate to run state") + } +} + +func TestReconcileLedgerUsesEstimateWhenInputNotObserved(t *testing.T) { + t.Parallel() + + service := &Service{} + id := controlplane.TurnBudgetID{AttemptSeq: 3, RequestHash: "hash-no-input-observed"} + decision := controlplane.TurnBudgetDecision{ + ID: id, + EstimatedInputTokens: 41, + } + observation := TurnBudgetUsageObservation{ + ID: id, + InputTokens: 0, + OutputTokens: 7, + InputObserved: false, + OutputObserved: true, + } + + result, err := service.reconcileLedger(nil, decision, observation) + if err != nil { + t.Fatalf("reconcileLedger() error = %v", err) + } + if result.inputTokens != 41 || result.inputSource != usageSourceEstimated { + t.Fatalf("expected estimated input reconciliation, got %+v", result) + } + if result.outputTokens != 7 || result.outputSource != usageSourceObserved { + t.Fatalf("expected observed output reconciliation, got %+v", result) + } + if !result.hasUnknownUsage { + t.Fatalf("expected hasUnknownUsage=true when any side is unobserved") + } +} + func TestExecuteAssistantToolCallsFillsErrorContent(t *testing.T) { t.Parallel() diff --git a/internal/runtime/runtime_test.go b/internal/runtime/runtime_test.go index 3d6fea52..df63dda9 100644 --- a/internal/runtime/runtime_test.go +++ b/internal/runtime/runtime_test.go @@ -3823,6 +3823,8 @@ func TestServiceRunPersistsAndRestoresTokenUsage(t *testing.T) { usage.InputTokens = 25 usage.OutputTokens = 10 } + usage.InputObserved = true + usage.OutputObserved = true select { case events <- providertypes.NewTextDeltaStreamEvent("assistant reply"): @@ -4917,8 +4919,10 @@ func TestTokenUsageRecordedOnMessageDone(t *testing.T) { // Create a MessageDone stream event with token usage messageDoneEvent := providertypes.NewMessageDoneStreamEvent("stop", &providertypes.Usage{ - InputTokens: 100, - OutputTokens: 50, + InputTokens: 100, + OutputTokens: 50, + InputObserved: true, + OutputObserved: true, }) // 使用与运行时相同的流式事件处理器验证 usage 累积行为。 From 6ab4ba79cd4cec927b616c1547c31812408f0522 Mon Sep 17 00:00:00 2001 From: xgopilot Date: Thu, 23 Apr 2026 04:51:08 +0000 Subject: [PATCH 7/9] fix(runtime): increment compact counter only when compact applied Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: phantom5099 <245659304+phantom5099@users.noreply.github.com> --- internal/runtime/run.go | 6 ++-- .../runtime_remaining_branches_test.go | 30 +++++++++++++++++++ 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/internal/runtime/run.go b/internal/runtime/run.go index 55a86ab5..7c05dee5 100644 --- a/internal/runtime/run.go +++ b/internal/runtime/run.go @@ -494,11 +494,11 @@ func (s *Service) applyCompactForState( if compactErr != nil { return compactErr } - if mode == contextcompact.ModeProactive || mode == contextcompact.ModeReactive { - state.compactCount++ - } state.session = session if result.Applied { + if mode == contextcompact.ModeProactive || mode == contextcompact.ModeReactive { + state.compactCount++ + } state.resetTokenTotals() state.nextAttemptSeq++ applied = true diff --git a/internal/runtime/runtime_remaining_branches_test.go b/internal/runtime/runtime_remaining_branches_test.go index 2eeabc2c..53400e1f 100644 --- a/internal/runtime/runtime_remaining_branches_test.go +++ b/internal/runtime/runtime_remaining_branches_test.go @@ -16,6 +16,7 @@ import ( providertypes "neo-code/internal/provider/types" approvalflow "neo-code/internal/runtime/approval" + "neo-code/internal/runtime/controlplane" "neo-code/internal/runtime/streaming" "neo-code/internal/security" agentsession "neo-code/internal/session" @@ -329,6 +330,35 @@ func TestApplyCompactForStateStrictErrorBranch(t *testing.T) { } } +func TestApplyCompactForStateDoesNotIncreaseCompactCountWhenNotApplied(t *testing.T) { + t.Parallel() + + service := &Service{ + events: make(chan RuntimeEvent, 8), + compactRunner: &stubCompactRunner{ + result: contextcompact.Result{ + Applied: false, + }, + }, + } + state := newRunState("run-apply-compact-not-applied", newRuntimeSession("session-apply-compact-not-applied")) + state.compactCount = 1 + if err := service.setBaseRunState(context.Background(), &state, controlplane.RunStatePlan); err != nil { + t.Fatalf("set base run state: %v", err) + } + + applied, err := service.applyCompactForState(context.Background(), &state, config.Config{}, contextcompact.ModeProactive, compactErrorStrict) + if err != nil { + t.Fatalf("applyCompactForState() error = %v", err) + } + if applied { + t.Fatalf("expected applied=false when compact runner result is not applied") + } + if state.compactCount != 1 { + t.Fatalf("expected compactCount to stay 1 when compact not applied, got %d", state.compactCount) + } +} + func TestExecuteToolCallWithPermissionRemainingBranches(t *testing.T) { t.Parallel() From 0f259c18dd75e59fa0fb243498eddd81532a9bc0 Mon Sep 17 00:00:00 2001 From: xgopilot Date: Thu, 23 Apr 2026 05:00:14 +0000 Subject: [PATCH 8/9] test(provider,scripts): improve coverage for estimator and migration output Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: phantom5099 <245659304+phantom5099@users.noreply.github.com> --- internal/provider/estimate_test.go | 71 ++++++++++++++++ scripts/migrate_context_budget/main_test.go | 90 +++++++++++++++++++++ 2 files changed, 161 insertions(+) create mode 100644 internal/provider/estimate_test.go diff --git a/internal/provider/estimate_test.go b/internal/provider/estimate_test.go new file mode 100644 index 00000000..5fecb135 --- /dev/null +++ b/internal/provider/estimate_test.go @@ -0,0 +1,71 @@ +package provider + +import ( + "testing" + + providertypes "neo-code/internal/provider/types" +) + +func TestEstimateSerializedPayloadTokens(t *testing.T) { + t.Parallel() + + tokens, err := EstimateSerializedPayloadTokens(map[string]any{ + "model": "x", + "input": "hello", + }) + if err != nil { + t.Fatalf("EstimateSerializedPayloadTokens() error = %v", err) + } + if tokens <= 0 { + t.Fatalf("EstimateSerializedPayloadTokens() = %d, want > 0", tokens) + } +} + +func TestEstimateSerializedPayloadTokensMarshalError(t *testing.T) { + t.Parallel() + + if _, err := EstimateSerializedPayloadTokens(make(chan int)); err == nil { + t.Fatal("EstimateSerializedPayloadTokens() expected marshal error, got nil") + } +} + +func TestEstimateTextTokens(t *testing.T) { + t.Parallel() + + if got := EstimateTextTokens(""); got != 0 { + t.Fatalf("EstimateTextTokens(\"\") = %d, want 0", got) + } + if got := EstimateTextTokens("1234"); got != 2 { + t.Fatalf("EstimateTextTokens(\"1234\") = %d, want 2", got) + } +} + +func TestBuildGenerateRequestSignature(t *testing.T) { + t.Parallel() + + reqA := providertypes.GenerateRequest{ + Model: "gpt", + Messages: []providertypes.Message{ + { + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}, + }, + }, + } + reqB := reqA + reqC := reqA + reqC.Model = "gpt-2" + + sigA := BuildGenerateRequestSignature(reqA) + sigB := BuildGenerateRequestSignature(reqB) + sigC := BuildGenerateRequestSignature(reqC) + if sigA == "" { + t.Fatal("BuildGenerateRequestSignature(reqA) returned empty signature") + } + if sigA != sigB { + t.Fatalf("same request should have same signature: %q != %q", sigA, sigB) + } + if sigA == sigC { + t.Fatalf("different requests should have different signatures: %q == %q", sigA, sigC) + } +} diff --git a/scripts/migrate_context_budget/main_test.go b/scripts/migrate_context_budget/main_test.go index dff132e4..7c2a37ef 100644 --- a/scripts/migrate_context_budget/main_test.go +++ b/scripts/migrate_context_budget/main_test.go @@ -1,7 +1,14 @@ package main import ( + "bytes" + "io" + "os" + "path/filepath" + "strings" "testing" + + "neo-code/internal/config" ) func TestDefaultBaseDirReturnsPath(t *testing.T) { @@ -11,3 +18,86 @@ func TestDefaultBaseDirReturnsPath(t *testing.T) { t.Fatal("expected non-empty default base dir") } } + +func TestPrintMigrationResultChangedDryRun(t *testing.T) { + output := captureStdout(t, func() { + printMigrationResult(config.ContextBudgetMigrationResult{ + Path: "/tmp/config.yaml", + Changed: true, + }, true) + }) + + if !strings.Contains(output, "[DRY-RUN] 将迁移 /tmp/config.yaml") { + t.Fatalf("unexpected output: %q", output) + } +} + +func TestPrintMigrationResultChangedWithBackup(t *testing.T) { + output := captureStdout(t, func() { + printMigrationResult(config.ContextBudgetMigrationResult{ + Path: "/tmp/config.yaml", + Changed: true, + Backup: "/tmp/config.yaml.bak", + }, false) + }) + + if !strings.Contains(output, "已迁移 /tmp/config.yaml (备份: /tmp/config.yaml.bak)") { + t.Fatalf("unexpected output: %q", output) + } +} + +func TestPrintMigrationResultNotChangedWithNotes(t *testing.T) { + output := captureStdout(t, func() { + printMigrationResult(config.ContextBudgetMigrationResult{ + Path: "/tmp/config.yaml", + Reason: "未检测到 context.auto_compact", + Notes: []string{" note-a ", "note-b"}, + }, false) + }) + + if !strings.Contains(output, "说明: note-a") { + t.Fatalf("missing note-a in output: %q", output) + } + if !strings.Contains(output, "说明: note-b") { + t.Fatalf("missing note-b in output: %q", output) + } + if !strings.Contains(output, "跳过: /tmp/config.yaml (未检测到 context.auto_compact)") { + t.Fatalf("missing skip line in output: %q", output) + } +} + +func TestDefaultBaseDirUsesHome(t *testing.T) { + tempHome := t.TempDir() + t.Setenv("HOME", tempHome) + want := filepath.Join(tempHome, ".neocode") + if got := defaultBaseDir(); got != want { + t.Fatalf("defaultBaseDir() = %q, want %q", got, want) + } +} + +func captureStdout(t *testing.T, fn func()) string { + t.Helper() + + originalStdout := os.Stdout + reader, writer, err := os.Pipe() + if err != nil { + t.Fatalf("os.Pipe() error = %v", err) + } + os.Stdout = writer + defer func() { + os.Stdout = originalStdout + }() + + done := make(chan string, 1) + go func() { + var buf bytes.Buffer + _, _ = io.Copy(&buf, reader) + done <- buf.String() + }() + + fn() + _ = writer.Close() + output := <-done + _ = reader.Close() + return output +} From 62e569bcccd37545547313d7b97899b377ce581d Mon Sep 17 00:00:00 2001 From: xgopilot Date: Thu, 23 Apr 2026 05:08:55 +0000 Subject: [PATCH 9/9] fix(runtime,anthropic): tighten estimate failure gate and preserve zero usage observation Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: phantom5099 <245659304+phantom5099@users.noreply.github.com> --- internal/provider/anthropic/provider.go | 8 +-- internal/provider/anthropic/provider_test.go | 64 ++++++++++++++++++++ internal/runtime/run.go | 9 +++ internal/runtime/runtime_test.go | 46 +++++++++++++- 4 files changed, 122 insertions(+), 5 deletions(-) diff --git a/internal/provider/anthropic/provider.go b/internal/provider/anthropic/provider.go index 4c594286..777fcba6 100644 --- a/internal/provider/anthropic/provider.go +++ b/internal/provider/anthropic/provider.go @@ -96,11 +96,11 @@ func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateReque event := streamReader.Current() switch variant := event.AsAny().(type) { case anthropic.MessageStartEvent: - if variant.Message.Usage.InputTokens > 0 { + if variant.Message.Usage.JSON.InputTokens.Valid() { usage.InputTokens = int(variant.Message.Usage.InputTokens) usage.InputObserved = true } - if variant.Message.Usage.OutputTokens > 0 { + if variant.Message.Usage.JSON.OutputTokens.Valid() { usage.OutputTokens = int(variant.Message.Usage.OutputTokens) usage.OutputObserved = true } @@ -167,11 +167,11 @@ func (p *Provider) Generate(ctx context.Context, req providertypes.GenerateReque if reason := strings.TrimSpace(string(variant.Delta.StopReason)); reason != "" { finishReason = reason } - if variant.Usage.OutputTokens > 0 { + if variant.Usage.JSON.OutputTokens.Valid() { usage.OutputTokens = int(variant.Usage.OutputTokens) usage.OutputObserved = true } - if variant.Usage.InputTokens > 0 { + if variant.Usage.JSON.InputTokens.Valid() { usage.InputTokens = int(variant.Usage.InputTokens) usage.InputObserved = true } diff --git a/internal/provider/anthropic/provider_test.go b/internal/provider/anthropic/provider_test.go index a7219d20..5e8ed159 100644 --- a/internal/provider/anthropic/provider_test.go +++ b/internal/provider/anthropic/provider_test.go @@ -102,6 +102,70 @@ func TestProviderGenerate(t *testing.T) { } } +func TestProviderGenerateMarksZeroUsageAsObserved(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + _, _ = fmt.Fprint(w, "event: message_start\n") + _, _ = fmt.Fprint(w, "data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":0,\"output_tokens\":0}}}\n\n") + _, _ = fmt.Fprint(w, "event: content_block_start\n") + _, _ = fmt.Fprint(w, "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"ok\"}}\n\n") + _, _ = fmt.Fprint(w, "event: message_delta\n") + _, _ = fmt.Fprint(w, "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"input_tokens\":0,\"output_tokens\":0}}\n\n") + _, _ = fmt.Fprint(w, "event: message_stop\n") + _, _ = fmt.Fprint(w, "data: {\"type\":\"message_stop\"}\n\n") + })) + defer server.Close() + + p, err := New(provider.RuntimeConfig{ + Driver: provider.DriverAnthropic, + BaseURL: server.URL, + DefaultModel: "claude-3-7-sonnet", + APIKeyEnv: "ANTHROPIC_TEST_KEY", + APIKeyResolver: provider.StaticAPIKeyResolver("test-key"), + }) + if err != nil { + t.Fatalf("New() error = %v", err) + } + + events := make(chan providertypes.StreamEvent, 8) + if err := p.Generate(context.Background(), providertypes.GenerateRequest{ + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("hi")}, + }}, + }, events); err != nil { + t.Fatalf("Generate() error = %v", err) + } + + drained := drainEvents(events) + var done *providertypes.MessageDonePayload + for i := range drained { + if drained[i].Type != providertypes.StreamEventMessageDone { + continue + } + payload, payloadErr := drained[i].MessageDoneValue() + if payloadErr != nil { + t.Fatalf("MessageDoneValue() error = %v", payloadErr) + } + done = &payload + break + } + if done == nil { + t.Fatalf("expected message_done event, got %+v", drained) + } + if done.Usage == nil { + t.Fatalf("expected usage to be present when zero usage is observed") + } + if !done.Usage.InputObserved || !done.Usage.OutputObserved { + t.Fatalf("expected observed flags true, got %+v", done.Usage) + } + if done.Usage.InputTokens != 0 || done.Usage.OutputTokens != 0 || done.Usage.TotalTokens != 0 { + t.Fatalf("expected zero usage, got %+v", done.Usage) + } +} + func TestProviderGenerateOmitsUsageWhenProviderDidNotReturnUsage(t *testing.T) { t.Parallel() diff --git a/internal/runtime/run.go b/internal/runtime/run.go index 7c05dee5..55bfe45d 100644 --- a/internal/runtime/run.go +++ b/internal/runtime/run.go @@ -542,6 +542,9 @@ func (s *Service) evaluateTurnBudget( if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return controlplane.TurnBudgetDecision{}, err } + if !shouldBypassEstimateFailure(err) { + return controlplane.TurnBudgetDecision{}, fmt.Errorf("runtime: estimate input tokens: %w", err) + } s.emitRunScoped(ctx, EventBudgetEstimateFailed, state, newBudgetEstimateFailedPayload(snapshot.ID, err)) decision := controlplane.TurnBudgetDecision{ ID: snapshot.ID, @@ -563,6 +566,12 @@ func (s *Service) evaluateTurnBudget( return decision, nil } +// shouldBypassEstimateFailure 判断估算失败是否允许降级放行,仅对可恢复 provider 错误放行。 +func shouldBypassEstimateFailure(err error) bool { + var providerErr *provider.ProviderError + return errors.As(err, &providerErr) && providerErr.Retryable +} + // reconcileLedger 根据 observed usage 或发送前 estimate 生成本轮账本写入结果。 func (s *Service) reconcileLedger( state *runState, diff --git a/internal/runtime/runtime_test.go b/internal/runtime/runtime_test.go index df63dda9..2c6b8902 100644 --- a/internal/runtime/runtime_test.go +++ b/internal/runtime/runtime_test.go @@ -4727,7 +4727,12 @@ func TestServiceRunBypassesBudgetGateWhenEstimateFails(t *testing.T) { estimateFn: func(ctx context.Context, req providertypes.GenerateRequest) (providertypes.BudgetEstimate, error) { _ = ctx _ = req - return providertypes.BudgetEstimate{}, errors.New("estimate unavailable") + return providertypes.BudgetEstimate{}, &provider.ProviderError{ + StatusCode: 503, + Code: provider.ErrorCodeServer, + Message: "estimate unavailable", + Retryable: true, + } }, responses: []scriptedResponse{ { @@ -4799,6 +4804,45 @@ func TestServiceRunBypassesBudgetGateWhenEstimateFails(t *testing.T) { assertNoEventType(t, events, EventError) } +func TestServiceRunFailsWhenEstimateFailsWithDeterministicError(t *testing.T) { + t.Parallel() + + manager := newRuntimeConfigManager(t) + if err := manager.Update(context.Background(), func(cfg *config.Config) error { + cfg.Context.Budget.PromptBudget = 10 + cfg.Context.Budget.FallbackPromptBudget = 10 + return nil + }); err != nil { + t.Fatalf("update config: %v", err) + } + + store := newMemoryStore() + registry := tools.NewRegistry() + scripted := &scriptedProvider{ + estimateFn: func(ctx context.Context, req providertypes.GenerateRequest) (providertypes.BudgetEstimate, error) { + _ = ctx + _ = req + return providertypes.BudgetEstimate{}, errors.New("invalid provider config") + }, + } + + service := NewWithFactory(manager, registry, store, &scriptedProviderFactory{provider: scripted}, &stubContextBuilder{}) + err := service.Run(context.Background(), UserInput{ + RunID: "run-budget-estimate-failed-hard-stop", + Parts: []providertypes.ContentPart{providertypes.NewTextPart("continue")}, + }) + if err == nil || !containsError(err, "estimate input tokens") { + t.Fatalf("expected estimate input tokens error, got %v", err) + } + if scripted.callCount != 0 { + t.Fatalf("expected provider Generate not to be called, got %d calls", scripted.callCount) + } + + events := collectRuntimeEvents(service.Events()) + assertNoEventType(t, events, EventBudgetEstimateFailed) + assertNoEventType(t, events, EventBudgetChecked) +} + func TestServiceRunFailsWhenEstimateContextCanceled(t *testing.T) { t.Parallel()