diff --git a/internal/runtime/controlplane/completion.go b/internal/runtime/controlplane/completion.go new file mode 100644 index 00000000..99538bef --- /dev/null +++ b/internal/runtime/controlplane/completion.go @@ -0,0 +1,41 @@ +package controlplane + +// CompletionBlockedReason 表示 completion gate 阻塞完成的原因。 +type CompletionBlockedReason string + +const ( + // CompletionBlockedReasonNone 表示当前不存在阻塞原因。 + CompletionBlockedReasonNone CompletionBlockedReason = "" + // CompletionBlockedReasonPendingTodo 表示仍存在未完成 + CompletionBlockedReasonPendingTodo CompletionBlockedReason = "pending_todo" + // CompletionBlockedReasonUnverifiedWrite 表示仍存在未验证写入。 + CompletionBlockedReasonUnverifiedWrite CompletionBlockedReason = "unverified_write" + // CompletionBlockedReasonPostExecuteClosureRequired 表示刚完成执行后仍需闭环。 + CompletionBlockedReasonPostExecuteClosureRequired CompletionBlockedReason = "post_execute_closure_required" +) + +// CompletionState 描述 completion gate 所需的运行事实。 +type CompletionState struct { + HasPendingAgentTodos bool `json:"has_pending_agent_todos"` + HasUnverifiedWrites bool `json:"has_unverified_writes"` + CompletionBlockedReason CompletionBlockedReason `json:"completion_blocked_reason,omitempty"` +} + +// EvaluateCompletion 依据当前事实计算是否允许本轮 completed。 +func EvaluateCompletion(state CompletionState, assistantHasToolCalls bool) (CompletionState, bool) { + state.CompletionBlockedReason = CompletionBlockedReasonNone + + if assistantHasToolCalls { + state.CompletionBlockedReason = CompletionBlockedReasonPostExecuteClosureRequired + return state, false + } + if state.HasPendingAgentTodos { + state.CompletionBlockedReason = CompletionBlockedReasonPendingTodo + return state, false + } + if state.HasUnverifiedWrites { + state.CompletionBlockedReason = CompletionBlockedReasonUnverifiedWrite + return state, false + } + return state, true +} diff --git a/internal/runtime/controlplane/completion_test.go b/internal/runtime/controlplane/completion_test.go new file mode 100644 index 00000000..b609140f --- /dev/null +++ b/internal/runtime/controlplane/completion_test.go @@ -0,0 +1,55 @@ +package controlplane + +import "testing" + +func TestEvaluateCompletionBlockedByPendingTodo(t *testing.T) { + t.Parallel() + + state, completed := EvaluateCompletion(CompletionState{ + HasPendingAgentTodos: true, + }, false) + if completed { + t.Fatalf("expected completion to be blocked") + } + if state.CompletionBlockedReason != CompletionBlockedReasonPendingTodo { + t.Fatalf("blocked reason = %q, want %q", state.CompletionBlockedReason, CompletionBlockedReasonPendingTodo) + } +} + +func TestEvaluateCompletionBlockedByUnverifiedWrite(t *testing.T) { + t.Parallel() + + state, completed := EvaluateCompletion(CompletionState{ + HasUnverifiedWrites: true, + }, false) + if completed { + t.Fatalf("expected completion to be blocked") + } + if state.CompletionBlockedReason != CompletionBlockedReasonUnverifiedWrite { + t.Fatalf("blocked reason = %q, want %q", state.CompletionBlockedReason, CompletionBlockedReasonUnverifiedWrite) + } +} + +func TestEvaluateCompletionBlockedAfterToolCalls(t *testing.T) { + t.Parallel() + + state, completed := EvaluateCompletion(CompletionState{}, true) + if completed { + t.Fatalf("expected completion to be blocked after tool call turn") + } + if state.CompletionBlockedReason != CompletionBlockedReasonPostExecuteClosureRequired { + t.Fatalf("blocked reason = %q, want %q", state.CompletionBlockedReason, CompletionBlockedReasonPostExecuteClosureRequired) + } +} + +func TestEvaluateCompletionAllowsSatisfiedClosure(t *testing.T) { + t.Parallel() + + state, completed := EvaluateCompletion(CompletionState{}, false) + if !completed { + t.Fatalf("expected completion to succeed") + } + if state.CompletionBlockedReason != CompletionBlockedReasonNone { + t.Fatalf("blocked reason = %q, want empty", state.CompletionBlockedReason) + } +} diff --git a/internal/runtime/controlplane/decider.go b/internal/runtime/controlplane/decider.go index 4fbe7a61..644faedf 100644 --- a/internal/runtime/controlplane/decider.go +++ b/internal/runtime/controlplane/decider.go @@ -6,26 +6,26 @@ import ( "strings" ) -// StopInput 汇总停止决议所需的信号(可多信号并存,由 DecideStopReason 按优先级表决)。 +// StopInput 汇总最终 stop 决议所需的信号。 type StopInput struct { - ContextCanceled bool - RunError error - Success bool + UserInterrupted bool + FatalError error + Completed bool } -// DecideStopReason 按固定优先级返回唯一 StopReason:取消 > 错误 > 成功。 +// DecideStopReason 按固定优先级返回唯一的最终 stop 原因。 func DecideStopReason(in StopInput) (StopReason, string) { - if in.ContextCanceled { - return StopReasonCanceled, "" + if in.UserInterrupted { + return StopReasonUserInterrupt, "" } - if in.RunError != nil { - if errors.Is(in.RunError, context.Canceled) { - return StopReasonCanceled, "" + if in.FatalError != nil { + if errors.Is(in.FatalError, context.Canceled) { + return StopReasonUserInterrupt, "" } - return StopReasonError, strings.TrimSpace(in.RunError.Error()) + return StopReasonFatalError, strings.TrimSpace(in.FatalError.Error()) } - if in.Success { - return StopReasonSuccess, "" + if in.Completed { + return StopReasonCompleted, "" } - return StopReasonError, "runtime: stop reason undetermined" + return StopReasonFatalError, "runtime: stop reason undetermined" } diff --git a/internal/runtime/controlplane/decider_test.go b/internal/runtime/controlplane/decider_test.go index 2aab317e..69c2de4a 100644 --- a/internal/runtime/controlplane/decider_test.go +++ b/internal/runtime/controlplane/decider_test.go @@ -11,38 +11,39 @@ func TestDecideStopReasonPriority(t *testing.T) { errSample := errors.New("boom") cases := []struct { - name string - in StopInput - reason StopReason + name string + in StopInput + wantReason StopReason }{ { - name: "canceled_wins_over_error", + name: "user_interrupt_wins_over_fatal", in: StopInput{ - ContextCanceled: true, - RunError: errSample, + UserInterrupted: true, + FatalError: errSample, }, - reason: StopReasonCanceled, + wantReason: StopReasonUserInterrupt, }, { - name: "error", + name: "fatal_error_wins_over_completed", in: StopInput{ - RunError: errSample, + FatalError: errSample, + Completed: true, }, - reason: StopReasonError, + wantReason: StopReasonFatalError, }, { - name: "success", + name: "completed", in: StopInput{ - Success: true, + Completed: true, }, - reason: StopReasonSuccess, + wantReason: StopReasonCompleted, }, { - name: "context_canceled_on_error_field", + name: "context_canceled_maps_to_user_interrupt", in: StopInput{ - RunError: context.Canceled, + FatalError: context.Canceled, }, - reason: StopReasonCanceled, + wantReason: StopReasonUserInterrupt, }, } @@ -50,9 +51,10 @@ func TestDecideStopReasonPriority(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() + got, _ := DecideStopReason(tc.in) - if got != tc.reason { - t.Fatalf("DecideStopReason() = %q, want %q", got, tc.reason) + if got != tc.wantReason { + t.Fatalf("DecideStopReason() = %q, want %q", got, tc.wantReason) } }) } @@ -62,8 +64,8 @@ func TestDecideStopReasonDetails(t *testing.T) { t.Parallel() reason, detail := DecideStopReason(StopInput{}) - if reason != StopReasonError { - t.Fatalf("reason = %q, want %q", reason, StopReasonError) + if reason != StopReasonFatalError { + t.Fatalf("reason = %q, want %q", reason, StopReasonFatalError) } if detail != "runtime: stop reason undetermined" { t.Fatalf("detail = %q, want undetermined detail", detail) diff --git a/internal/runtime/controlplane/envelope.go b/internal/runtime/controlplane/envelope.go index ec2006fd..be700ed7 100644 --- a/internal/runtime/controlplane/envelope.go +++ b/internal/runtime/controlplane/envelope.go @@ -1,4 +1,4 @@ package controlplane // PayloadVersion 为 runtime 事件 envelope 的当前协议版本号。 -const PayloadVersion = 1 +const PayloadVersion = 2 diff --git a/internal/runtime/controlplane/phase.go b/internal/runtime/controlplane/phase.go index e75f397c..d1f4e74a 100644 --- a/internal/runtime/controlplane/phase.go +++ b/internal/runtime/controlplane/phase.go @@ -1,15 +1,75 @@ package controlplane -// Phase 表示单轮 ReAct 内的显式阶段(plan -> execute -> dispatch -> verify)。 -type Phase string +import "fmt" + +// RunState 表示单次 Run 生命周期中的显式运行态,统一承载主链 phase 与外围治理态。 +type RunState string const ( - // PhasePlan 规划阶段:构建上下文、调用 provider 直至得到 assistant 消息(含工具调用决策)。 - PhasePlan Phase = "plan" - // PhaseExecute 执行阶段:执行本批次全部工具调用。 - PhaseExecute Phase = "execute" - // PhaseDispatch 调度阶段:执行 Todo 驱动的子代理任务派发。 - PhaseDispatch Phase = "dispatch" - // PhaseVerify 验证阶段:工具结果已回灌,等待下一轮 provider 校验或收尾。 - PhaseVerify Phase = "verify" + // RunStatePlan 表示规划阶段:构建上下文并驱动 provider 产出 assistant 决策。 + RunStatePlan RunState = "plan" + // RunStateExecute 表示执行阶段:执行本轮 assistant 产生的全部工具调用。 + RunStateExecute RunState = "execute" + // RunStateVerify 表示验证阶段:工具结果已回灌,等待下一轮模型收尾或继续推进。 + RunStateVerify RunState = "verify" + // RunStateCompacting 表示当前正在执行 compact 或 reactive compact。 + RunStateCompacting RunState = "compacting" + // RunStateWaitingPermission 表示当前正在等待权限决议,执行流被显式挂起。 + RunStateWaitingPermission RunState = "waiting_permission" + // RunStateStopped 表示本次 Run 已完成终止决议,不再继续推进生命周期。 + RunStateStopped RunState = "stopped" ) + +var allowedRunStateTransitions = map[RunState]map[RunState]struct{}{ + "": { + RunStatePlan: {}, + }, + RunStatePlan: { + RunStatePlan: {}, + RunStateExecute: {}, + RunStateCompacting: {}, + RunStateWaitingPermission: {}, + RunStateStopped: {}, + }, + RunStateExecute: { + RunStateExecute: {}, + RunStateVerify: {}, + RunStateCompacting: {}, + RunStateWaitingPermission: {}, + RunStateStopped: {}, + }, + RunStateVerify: { + RunStateVerify: {}, + RunStatePlan: {}, + RunStateCompacting: {}, + RunStateWaitingPermission: {}, + RunStateStopped: {}, + }, + RunStateCompacting: { + RunStateCompacting: {}, + RunStatePlan: {}, + RunStateWaitingPermission: {}, + RunStateStopped: {}, + }, + RunStateWaitingPermission: { + RunStateWaitingPermission: {}, + RunStatePlan: {}, + RunStateExecute: {}, + RunStateVerify: {}, + RunStateCompacting: {}, + RunStateStopped: {}, + }, + RunStateStopped: { + RunStateStopped: {}, + }, +} + +// ValidateRunStateTransition 校验生命周期迁移是否合法,避免主链 phase 与外围治理态分裂成多套规则。 +func ValidateRunStateTransition(from RunState, to RunState) error { + if nextStates, ok := allowedRunStateTransitions[from]; ok { + if _, allowed := nextStates[to]; allowed { + return nil + } + } + return fmt.Errorf("runtime: invalid run state transition %q -> %q", from, to) +} diff --git a/internal/runtime/controlplane/phase_test.go b/internal/runtime/controlplane/phase_test.go new file mode 100644 index 00000000..e1f44dbc --- /dev/null +++ b/internal/runtime/controlplane/phase_test.go @@ -0,0 +1,40 @@ +package controlplane + +import "testing" + +func TestValidateRunStateTransitionMainlineAndGovernanceStates(t *testing.T) { + t.Parallel() + + validTransitions := []struct { + from RunState + to RunState + }{ + {from: "", to: RunStatePlan}, + {from: RunStatePlan, to: RunStateExecute}, + {from: RunStateExecute, to: RunStateVerify}, + {from: RunStateVerify, to: RunStatePlan}, + {from: RunStatePlan, to: RunStateCompacting}, + {from: RunStateCompacting, to: RunStatePlan}, + {from: RunStateExecute, to: RunStateWaitingPermission}, + {from: RunStateWaitingPermission, to: RunStateExecute}, + {from: RunStateVerify, to: RunStateStopped}, + } + + for _, tc := range validTransitions { + tc := tc + t.Run(string(tc.from)+"->"+string(tc.to), func(t *testing.T) { + t.Parallel() + if err := ValidateRunStateTransition(tc.from, tc.to); err != nil { + t.Fatalf("ValidateRunStateTransition(%q,%q) error = %v", tc.from, tc.to, err) + } + }) + } +} + +func TestValidateRunStateTransitionRejectsInvalidJump(t *testing.T) { + t.Parallel() + + if err := ValidateRunStateTransition(RunStatePlan, RunStateVerify); err == nil { + t.Fatalf("expected invalid transition to return error") + } +} diff --git a/internal/runtime/controlplane/progress.go b/internal/runtime/controlplane/progress.go index 784496ce..7a74438a 100644 --- a/internal/runtime/controlplane/progress.go +++ b/internal/runtime/controlplane/progress.go @@ -1,62 +1,238 @@ package controlplane -// ProgressEvidenceKind 标识工具/适配器产出的证据类型,runtime 仅聚合不做语义推断。 +// ProgressEvidenceKind 标识 runtime 聚合得到的结构化进展证据。 type ProgressEvidenceKind string const ( - // EvidenceNewInfoNonDup 表示本轮引入了非重复的新信息(用于 streak 回归约束)。 - EvidenceNewInfoNonDup ProgressEvidenceKind = "EVIDENCE_NEW_INFO_NON_DUP" + // EvidenceTaskStateChanged 表示任务状态发生合法迁移。 + EvidenceTaskStateChanged ProgressEvidenceKind = "TASK_STATE_CHANGED" + // EvidenceTodoStateChanged 表示 todo 列表发生结构化变化。 + EvidenceTodoStateChanged ProgressEvidenceKind = "TODO_STATE_CHANGED" + // EvidenceWriteApplied 表示本轮产生了有效文件改动。 + EvidenceWriteApplied ProgressEvidenceKind = "WRITE_APPLIED" + // EvidenceVerifyPassed 表示本轮存在明确的验证成功信号(仅与写入证据组合后算业务推进)。 + EvidenceVerifyPassed ProgressEvidenceKind = "VERIFY_PASSED" + // EvidenceNewInfoNonDup 表示本轮引入了去重后的新信息。 + EvidenceNewInfoNonDup ProgressEvidenceKind = "NEW_INFO_NON_DUP" ) -// ProgressEvidenceRecord 描述一条可计分的进展证据。 +// SubgoalRelation 表示当前轮子目标与上一轮的关系。 +type SubgoalRelation string + +const ( + // SubgoalRelationSame 表示子目标可证明相同。 + SubgoalRelationSame SubgoalRelation = "same" + // SubgoalRelationDifferent 表示子目标可证明不同。 + SubgoalRelationDifferent SubgoalRelation = "different" + // SubgoalRelationUnknown 表示当前无法稳定判断子目标关系。 + SubgoalRelationUnknown SubgoalRelation = "unknown" +) + +// StalledProgressState 表示当前进展是否已进入软卡住状态。 +type StalledProgressState string + +const ( + // StalledProgressHealthy 表示当前未进入 stalled。 + StalledProgressHealthy StalledProgressState = "healthy" + // StalledProgressStalled 表示当前已进入 stalled。 + StalledProgressStalled StalledProgressState = "stalled" +) + +// ReminderKind 标识应向模型注入的纠偏提醒类型。 +type ReminderKind string + +const ( + // ReminderKindNone 表示当前轮无需注入提醒。 + ReminderKindNone ReminderKind = "" + // ReminderKindNoProgress 表示应注入无进展提醒。 + ReminderKindNoProgress ReminderKind = "REMINDER_NO_PROGRESS" + // ReminderKindRepeatCycle 表示应注入重复循环提醒。 + ReminderKindRepeatCycle ReminderKind = "REMINDER_REPEAT_CYCLE" + // ReminderKindGenericStalled 表示应注入通用 stalled 提醒。 + ReminderKindGenericStalled ReminderKind = "REMINDER_GENERIC_STALLED" +) + +// ProgressEvidenceRecord 描述一条结构化进展证据。 type ProgressEvidenceRecord struct { Kind ProgressEvidenceKind `json:"kind"` Detail string `json:"detail,omitempty"` } -// ProgressScore 表示一次评估后的分值增量与 streak 快照。 +// ProgressScore 表示一次 progress 评估后的完整快照。 type ProgressScore struct { - ScoreDelta int `json:"score_delta"` - NoProgressStreak int `json:"no_progress_streak"` - RepeatCycleStreak int `json:"repeat_cycle_streak"` + HasBusinessProgress bool `json:"has_business_progress"` + HasExplorationProgress bool `json:"has_exploration_progress"` + StrongEvidenceCount int `json:"strong_evidence_count"` + MediumEvidenceCount int `json:"medium_evidence_count"` + WeakEvidenceCount int `json:"weak_evidence_count"` + ExplorationStreak int `json:"exploration_streak"` + NoProgressStreak int `json:"no_progress_streak"` + RepeatCycleStreak int `json:"repeat_cycle_streak"` + SameToolSignature bool `json:"same_tool_signature"` + SameResultFingerprint bool `json:"same_result_fingerprint"` + SameSubgoal SubgoalRelation `json:"same_subgoal"` + StalledProgressState StalledProgressState `json:"stalled_progress_state"` + ReminderKind ReminderKind `json:"reminder_kind,omitempty"` } -// ProgressState 汇总当前运行期 progress 控制面状态。 +// ProgressState 保存跨轮 progress 判定所需的历史快照。 type ProgressState struct { - LastScore ProgressScore `json:"last_score"` - LastSignature string `json:"last_signature,omitempty"` + LastScore ProgressScore `json:"last_score"` + LastToolSignature string `json:"last_tool_signature,omitempty"` + LastResultFingerprint string `json:"last_result_fingerprint,omitempty"` + LastSubgoalFingerprint string `json:"last_subgoal_fingerprint,omitempty"` } -// ApplyProgressEvidence 根据证据更新分值与 streak。 -func ApplyProgressEvidence(state ProgressState, records []ProgressEvidenceRecord, currentSignature string) ProgressState { - next := state.LastScore - hasToolAttempt := currentSignature != "" - isRepeated := hasToolAttempt && state.LastSignature != "" && currentSignature == state.LastSignature +// ProgressInput 描述一次 progress 评估所需的事实输入。 +type ProgressInput struct { + RunState RunState + Evidence []ProgressEvidenceRecord + CurrentToolSignature string + ResultFingerprint string + SubgoalFingerprint string + NoProgressLimit int + RepeatCycleLimit int +} + +// EvaluateProgress 基于上一轮状态和本轮事实生成新的 progress 快照。 +func EvaluateProgress(state ProgressState, input ProgressInput) ProgressState { + next := ProgressScore{} + flags := summarizeEvidence(input.Evidence) + + next.StrongEvidenceCount = flags.strongCount + next.MediumEvidenceCount = flags.mediumCount + next.WeakEvidenceCount = flags.weakCount + next.HasBusinessProgress = flags.strongCount > 0 || (flags.hasWrite && flags.hasVerify) + next.HasExplorationProgress = !next.HasBusinessProgress && isExplorationProgress(input.RunState, flags) + next.SameToolSignature = input.CurrentToolSignature != "" && + state.LastToolSignature != "" && + input.CurrentToolSignature == state.LastToolSignature + next.SameResultFingerprint = input.ResultFingerprint != "" && + state.LastResultFingerprint != "" && + input.ResultFingerprint == state.LastResultFingerprint + next.SameSubgoal = compareSubgoalFingerprint(state.LastSubgoalFingerprint, input.SubgoalFingerprint) - if hasToolAttempt { - if isRepeated { - next.RepeatCycleStreak++ - } else { - next.RepeatCycleStreak = 1 + if next.HasBusinessProgress { + next.ExplorationStreak = 0 + next.NoProgressStreak = 0 + } else if next.HasExplorationProgress { + next.ExplorationStreak = state.LastScore.ExplorationStreak + 1 + next.NoProgressStreak = state.LastScore.NoProgressStreak + if next.ExplorationStreak > explorationWindowForPhase(input.RunState) { + next.NoProgressStreak++ } } else { - next.RepeatCycleStreak = 0 + next.ExplorationStreak = 0 + next.NoProgressStreak = state.LastScore.NoProgressStreak + 1 } - nextSignature := "" - if hasToolAttempt { - nextSignature = currentSignature + if next.HasBusinessProgress { + next.RepeatCycleStreak = 0 + } else if next.SameToolSignature && next.SameResultFingerprint && next.SameSubgoal == SubgoalRelationSame { + next.RepeatCycleStreak = state.LastScore.RepeatCycleStreak + 1 + } else { + next.RepeatCycleStreak = 0 } - if len(records) > 0 && !isRepeated { - next.NoProgressStreak = 0 - next.ScoreDelta++ + if shouldStall(next, input.NoProgressLimit, input.RepeatCycleLimit) { + next.StalledProgressState = StalledProgressStalled + next.ReminderKind = selectReminderKind(next) } else { - next.NoProgressStreak++ + next.StalledProgressState = StalledProgressHealthy + next.ReminderKind = ReminderKindNone } return ProgressState{ - LastScore: next, - LastSignature: nextSignature, + LastScore: next, + LastToolSignature: input.CurrentToolSignature, + LastResultFingerprint: input.ResultFingerprint, + LastSubgoalFingerprint: input.SubgoalFingerprint, + } +} + +type evidenceFlags struct { + strongCount int + mediumCount int + weakCount int + hasWrite bool + hasVerify bool +} + +// summarizeEvidence 汇总本轮 evidence 的强中弱计数与关键标记。 +func summarizeEvidence(records []ProgressEvidenceRecord) evidenceFlags { + var flags evidenceFlags + for _, record := range records { + switch record.Kind { + case EvidenceTaskStateChanged, EvidenceTodoStateChanged: + flags.strongCount++ + case EvidenceWriteApplied, EvidenceVerifyPassed: + flags.mediumCount++ + case EvidenceNewInfoNonDup: + flags.weakCount++ + } + + switch record.Kind { + case EvidenceWriteApplied: + flags.hasWrite = true + case EvidenceVerifyPassed: + flags.hasVerify = true + } + } + return flags +} + +// isExplorationProgress 判断本轮是否属于可被宽容窗口吸收的探索型推进。 +func isExplorationProgress(runState RunState, flags evidenceFlags) bool { + if runState != RunStatePlan && runState != RunStateExecute { + return false + } + return flags.weakCount > 0 +} + +// explorationWindowForPhase 返回不同阶段允许的 exploration 宽容窗口。 +func explorationWindowForPhase(runState RunState) int { + switch runState { + case RunStatePlan: + return 4 + case RunStateExecute: + return 2 + default: + return 0 + } +} + +// compareSubgoalFingerprint 判断当前轮与上一轮的子目标关系。 +func compareSubgoalFingerprint(previous string, current string) SubgoalRelation { + if previous == "" && current == "" { + return SubgoalRelationUnknown + } + if previous == "" || current == "" { + return SubgoalRelationUnknown + } + if previous == current { + return SubgoalRelationSame + } + return SubgoalRelationDifferent +} + +// shouldStall 判断当前快照是否应进入 stalled。 +func shouldStall(score ProgressScore, noProgressLimit int, repeatLimit int) bool { + if repeatLimit > 0 && score.RepeatCycleStreak >= repeatLimit { + return true + } + if noProgressLimit > 0 && score.NoProgressStreak >= noProgressLimit { + return true + } + return false +} + +// selectReminderKind 选择 stalled 场景下应注入的提醒类型。 +func selectReminderKind(score ProgressScore) ReminderKind { + if score.RepeatCycleStreak > 0 && score.SameToolSignature && score.SameResultFingerprint { + return ReminderKindRepeatCycle + } + if score.NoProgressStreak > 0 { + return ReminderKindNoProgress } + return ReminderKindGenericStalled } diff --git a/internal/runtime/controlplane/progress_test.go b/internal/runtime/controlplane/progress_test.go index f457a0be..fe450eda 100644 --- a/internal/runtime/controlplane/progress_test.go +++ b/internal/runtime/controlplane/progress_test.go @@ -2,92 +2,166 @@ package controlplane import "testing" -func TestApplyProgressEvidenceNoEvidenceIncrementsNoProgress(t *testing.T) { +func TestEvaluateProgressBusinessProgressResetsStreaks(t *testing.T) { t.Parallel() - got := ApplyProgressEvidence(ProgressState{}, nil, "") - want := ProgressState{ + + state := ProgressState{ LastScore: ProgressScore{ - NoProgressStreak: 1, - RepeatCycleStreak: 0, + ExplorationStreak: 2, + NoProgressStreak: 3, + RepeatCycleStreak: 1, + }, + } + + got := EvaluateProgress(state, ProgressInput{ + RunState: RunStateExecute, + Evidence: []ProgressEvidenceRecord{ + {Kind: EvidenceTodoStateChanged}, }, + NoProgressLimit: 3, + RepeatCycleLimit: 3, + }) + + if !got.LastScore.HasBusinessProgress { + t.Fatalf("expected business progress") } - if got != want { - t.Fatalf("expected %+v, got %+v", want, got) + if got.LastScore.NoProgressStreak != 0 { + t.Fatalf("no-progress streak = %d, want 0", got.LastScore.NoProgressStreak) + } + if got.LastScore.RepeatCycleStreak != 0 { + t.Fatalf("repeat streak = %d, want 0", got.LastScore.RepeatCycleStreak) } } -func TestApplyProgressEvidenceOnlyNonDupResetsNoProgressStreak(t *testing.T) { +func TestEvaluateProgressExplorationUsesWindow(t *testing.T) { t.Parallel() + state := ProgressState{ - LastScore: ProgressScore{NoProgressStreak: 3}, - } - got := ApplyProgressEvidence(state, []ProgressEvidenceRecord{ - {Kind: EvidenceNewInfoNonDup}, - }, "sig1") - want := ProgressState{ LastScore: ProgressScore{ - ScoreDelta: 1, - NoProgressStreak: 0, - RepeatCycleStreak: 1, + ExplorationStreak: 3, + NoProgressStreak: 1, }, - LastSignature: "sig1", } - if got != want { - t.Fatalf("expected %+v, got %+v", want, got) + + got := EvaluateProgress(state, ProgressInput{ + RunState: RunStatePlan, + Evidence: []ProgressEvidenceRecord{ + {Kind: EvidenceNewInfoNonDup}, + }, + NoProgressLimit: 3, + RepeatCycleLimit: 3, + }) + + if !got.LastScore.HasExplorationProgress { + t.Fatalf("expected exploration progress") + } + if got.LastScore.ExplorationStreak != 4 { + t.Fatalf("exploration streak = %d, want 4", got.LastScore.ExplorationStreak) + } + if got.LastScore.NoProgressStreak != 1 { + t.Fatalf("no-progress streak = %d, want unchanged 1", got.LastScore.NoProgressStreak) } } -func TestApplyProgressEvidenceMixedResetsNoProgress(t *testing.T) { +func TestEvaluateProgressExplorationExhaustionStartsNoProgress(t *testing.T) { t.Parallel() + state := ProgressState{ - LastScore: ProgressScore{NoProgressStreak: 2}, + LastScore: ProgressScore{ + ExplorationStreak: 4, + NoProgressStreak: 1, + }, } - got := ApplyProgressEvidence(state, []ProgressEvidenceRecord{ - {Kind: EvidenceNewInfoNonDup}, - {Kind: ProgressEvidenceKind("other_evidence")}, - }, "sig1") - if got.LastScore.NoProgressStreak != 0 { - t.Fatalf("expected streak reset, got %d", got.LastScore.NoProgressStreak) + + got := EvaluateProgress(state, ProgressInput{ + RunState: RunStatePlan, + Evidence: []ProgressEvidenceRecord{ + {Kind: EvidenceNewInfoNonDup}, + }, + NoProgressLimit: 3, + RepeatCycleLimit: 3, + }) + + if got.LastScore.NoProgressStreak != 2 { + t.Fatalf("no-progress streak = %d, want 2", got.LastScore.NoProgressStreak) } } -func TestApplyProgressEvidenceRepeatCycle(t *testing.T) { +func TestEvaluateProgressRepeatCycleRequiresSameResultAndSubgoal(t *testing.T) { t.Parallel() + state := ProgressState{ - LastScore: ProgressScore{NoProgressStreak: 1, RepeatCycleStreak: 2}, - LastSignature: "sig1", + LastScore: ProgressScore{RepeatCycleStreak: 2}, + LastToolSignature: "sig", + LastResultFingerprint: "result", + LastSubgoalFingerprint: "subgoal", } - got := ApplyProgressEvidence(state, []ProgressEvidenceRecord{ - {Kind: EvidenceNewInfoNonDup}, - }, "sig1") - want := ProgressState{ - LastScore: ProgressScore{ - NoProgressStreak: 2, - RepeatCycleStreak: 3, - }, - LastSignature: "sig1", + + got := EvaluateProgress(state, ProgressInput{ + RunState: RunStateExecute, + CurrentToolSignature: "sig", + ResultFingerprint: "result", + SubgoalFingerprint: "subgoal", + NoProgressLimit: 3, + RepeatCycleLimit: 3, + }) + + if got.LastScore.RepeatCycleStreak != 3 { + t.Fatalf("repeat streak = %d, want 3", got.LastScore.RepeatCycleStreak) + } + if got.LastScore.StalledProgressState != StalledProgressStalled { + t.Fatalf("stalled state = %q, want %q", got.LastScore.StalledProgressState, StalledProgressStalled) } - if got != want { - t.Fatalf("expected %+v, got %+v", want, got) + if got.LastScore.ReminderKind != ReminderKindRepeatCycle { + t.Fatalf("reminder = %q, want %q", got.LastScore.ReminderKind, ReminderKindRepeatCycle) } } -func TestApplyProgressEvidenceRepeatCycleOnFailureKeepsSignatureTracking(t *testing.T) { +func TestEvaluateProgressUnknownSubgoalDoesNotAdvanceRepeat(t *testing.T) { t.Parallel() + state := ProgressState{ - LastScore: ProgressScore{NoProgressStreak: 2, RepeatCycleStreak: 1}, - LastSignature: "sig1", + LastScore: ProgressScore{RepeatCycleStreak: 1}, + LastToolSignature: "sig", + LastResultFingerprint: "result", + LastSubgoalFingerprint: "subgoal", } - got := ApplyProgressEvidence(state, nil, "sig1") - want := ProgressState{ - LastScore: ProgressScore{ - NoProgressStreak: 3, - RepeatCycleStreak: 2, + got := EvaluateProgress(state, ProgressInput{ + RunState: RunStateExecute, + CurrentToolSignature: "sig", + ResultFingerprint: "result", + SubgoalFingerprint: "", + NoProgressLimit: 3, + RepeatCycleLimit: 3, + }) + + if got.LastScore.SameSubgoal != SubgoalRelationUnknown { + t.Fatalf("same subgoal = %q, want %q", got.LastScore.SameSubgoal, SubgoalRelationUnknown) + } + if got.LastScore.RepeatCycleStreak != 0 { + t.Fatalf("repeat streak = %d, want 0", got.LastScore.RepeatCycleStreak) + } +} + +func TestEvaluateProgressVerifyPassedAloneIsNotBusinessProgress(t *testing.T) { + t.Parallel() + + got := EvaluateProgress(ProgressState{}, ProgressInput{ + RunState: RunStateVerify, + Evidence: []ProgressEvidenceRecord{ + {Kind: EvidenceVerifyPassed}, }, - LastSignature: "sig1", + NoProgressLimit: 3, + RepeatCycleLimit: 3, + }) + if got.LastScore.HasBusinessProgress { + t.Fatalf("expected verify-passed alone to not count as business progress") + } + if got.LastScore.StrongEvidenceCount != 0 { + t.Fatalf("strong evidence = %d, want 0", got.LastScore.StrongEvidenceCount) } - if got != want { - t.Fatalf("expected %+v, got %+v", want, got) + if got.LastScore.MediumEvidenceCount != 1 { + t.Fatalf("medium evidence = %d, want 1", got.LastScore.MediumEvidenceCount) } } diff --git a/internal/runtime/controlplane/stop_reason.go b/internal/runtime/controlplane/stop_reason.go index ff51454b..3b8b0c2f 100644 --- a/internal/runtime/controlplane/stop_reason.go +++ b/internal/runtime/controlplane/stop_reason.go @@ -1,13 +1,13 @@ package controlplane -// StopReason 表示一次 Run 的最终停止原因,互斥且由决议器唯一确定。 +// StopReason 表示一次 Run 的最终硬停止原因。 type StopReason string const ( - // StopReasonSuccess 表示助手正常结束(无待执行工具调用)。 - StopReasonSuccess StopReason = "success" - // StopReasonError 表示不可恢复的运行时或 provider 错误。 - StopReasonError StopReason = "error" - // StopReasonCanceled 表示运行上下文被取消(含用户中断)。 - StopReasonCanceled StopReason = "canceled" + // StopReasonFatalError 表示出现不可恢复错误。 + StopReasonFatalError StopReason = "STOP_FATAL_ERROR" + // StopReasonCompleted 表示运行满足完成条件。 + StopReasonCompleted StopReason = "STOP_COMPLETED" + // StopReasonUserInterrupt 表示运行被用户或上层上下文中断。 + StopReasonUserInterrupt StopReason = "STOP_USER_INTERRUPT" ) diff --git a/internal/runtime/event_emitter.go b/internal/runtime/event_emitter.go index 43080dbb..67e06860 100644 --- a/internal/runtime/event_emitter.go +++ b/internal/runtime/event_emitter.go @@ -28,8 +28,8 @@ func (s *Service) emitRunScoped(ctx context.Context, kind EventType, state *runS return s.emit(ctx, kind, "", "", payload) } phase := "" - if state.phase != "" { - phase = string(state.phase) + if state.lifecycle != "" { + phase = string(state.lifecycle) } return s.emitWithEnvelope(ctx, RuntimeEvent{ Type: kind, diff --git a/internal/runtime/permission.go b/internal/runtime/permission.go index 96d7f504..79efdc04 100644 --- a/internal/runtime/permission.go +++ b/internal/runtime/permission.go @@ -11,6 +11,7 @@ import ( providertypes "neo-code/internal/provider/types" approvalflow "neo-code/internal/runtime/approval" + "neo-code/internal/runtime/controlplane" "neo-code/internal/security" "neo-code/internal/tools" ) @@ -128,10 +129,20 @@ func (s *Service) executeToolCallWithPermission(ctx context.Context, input permi // 审批等待属于用户交互阶段,不应受工具执行超时约束; // 否则用户未及时响应会被误判为工具失败并进入调度重试/失败链路。 - decision, requestID, err := s.awaitPermissionDecision(ctx, input, permissionErr) - if err != nil { + var decision approvalflow.Decision + var requestID string + if err := s.enterTemporaryRunState(ctx, input.State, controlplane.RunStateWaitingPermission); err != nil { return result, err } + defer func() { + _ = s.leaveTemporaryRunState(ctx, input.State, controlplane.RunStateWaitingPermission) + }() + resolvedDecision, resolvedRequestID, waitErr := s.awaitPermissionDecision(ctx, input, permissionErr) + if waitErr != nil { + return result, waitErr + } + decision = resolvedDecision + requestID = resolvedRequestID scope, err := rememberScopeFromDecision(decision) if err != nil { diff --git a/internal/runtime/permission_test.go b/internal/runtime/permission_test.go index 51d1c9ff..005b294a 100644 --- a/internal/runtime/permission_test.go +++ b/internal/runtime/permission_test.go @@ -452,7 +452,14 @@ func TestServiceRunMCPPermissionAllowFlow(t *testing.T) { tools: []mcp.ToolDescriptor{ {Name: "create_issue", Description: "create issue", InputSchema: map[string]any{"type": "object"}}, }, - callResult: mcp.CallResult{Content: "mcp create ok"}, + callResult: mcp.CallResult{ + Content: "mcp create ok", + Metadata: map[string]any{ + "verification_performed": true, + "verification_passed": true, + "verification_scope": "workspace", + }, + }, } if err := mcpRegistry.RegisterServer("github", "stdio", "v1", mcpClient); err != nil { t.Fatalf("register mcp server: %v", err) diff --git a/internal/runtime/run.go b/internal/runtime/run.go index fb538c0e..4d163fdc 100644 --- a/internal/runtime/run.go +++ b/internal/runtime/run.go @@ -122,7 +122,9 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { for turn := 0; ; turn++ { state.turn = turn - s.transitionRunPhase(ctx, &state, controlplane.PhasePlan) + if err := s.setBaseRunState(ctx, &state, controlplane.RunStatePlan); err != nil { + return s.handleRunError(ctx, state.runID, state.session.ID, err) + } for { if err := ctx.Err(); err != nil { @@ -167,37 +169,76 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { } s.emitTokenUsage(ctx, &state, turnResult) + state.mu.Lock() + state.completion = collectCompletionState( + &state, + turnResult.assistant, + len(turnResult.assistant.ToolCalls) > 0, + ) + completionState, completed := controlplane.EvaluateCompletion( + state.completion, + len(turnResult.assistant.ToolCalls) > 0, + ) + state.completion = completionState + state.mu.Unlock() + if len(turnResult.assistant.ToolCalls) == 0 { - s.emitRunScoped(ctx, EventAgentDone, &state, turnResult.assistant) - s.triggerMemoExtraction(state.session.ID, state.session.Messages, state.rememberedThisRun) - return nil + if completed { + s.emitRunScoped(ctx, EventAgentDone, &state, turnResult.assistant) + s.triggerMemoExtraction(state.session.ID, state.session.Messages, state.rememberedThisRun) + return nil + } + state.mu.Lock() + progressInput := collectProgressInput( + controlplane.RunStatePlan, + state.session.TaskState.Clone(), + state.session.TaskState.Clone(), + cloneTodosForPersistence(state.session.Todos), + cloneTodosForPersistence(state.session.Todos), + toolExecutionSummary{}, + snapshot.noProgressStreakLimit, + snapshot.repeatCycleStreakLimit, + ) + state.progress = controlplane.EvaluateProgress(state.progress, progressInput) + currentScore := state.progress.LastScore + state.mu.Unlock() + + s.emitRunScoped(ctx, EventProgressEvaluated, &state, ProgressEvaluatedPayload{Score: currentScore}) + break } - s.transitionRunPhase(ctx, &state, controlplane.PhaseExecute) - if err := s.executeAssistantToolCalls(ctx, &state, snapshot, turnResult.assistant); err != nil { + + beforeTask := state.session.TaskState.Clone() + beforeTodos := cloneTodosForPersistence(state.session.Todos) + if err := s.setBaseRunState(ctx, &state, controlplane.RunStateExecute); err != nil { return s.handleRunError(ctx, state.runID, state.session.ID, err) } - s.transitionRunPhase(ctx, &state, controlplane.PhaseVerify) - - var evidence []controlplane.ProgressEvidenceRecord - toolCallCount := len(turnResult.assistant.ToolCalls) - currentSignature := computeToolSignature(turnResult.assistant.ToolCalls) - - state.mu.Lock() - if len(state.session.Messages) >= toolCallCount { - for i := len(state.session.Messages) - toolCallCount; i < len(state.session.Messages); i++ { - if msg := state.session.Messages[i]; msg.Role == providertypes.RoleTool && !msg.IsError { - evidence = append(evidence, controlplane.ProgressEvidenceRecord{Kind: controlplane.EvidenceNewInfoNonDup}) - break - } - } + summary, err := s.executeAssistantToolCalls(ctx, &state, snapshot, turnResult.assistant) + if err != nil { + return s.handleRunError(ctx, state.runID, state.session.ID, err) } - state.progress = controlplane.ApplyProgressEvidence(state.progress, evidence, currentSignature) + state.mu.Lock() + state.completion = applyToolExecutionCompletion(state.completion, summary) + afterTask := state.session.TaskState.Clone() + afterTodos := cloneTodosForPersistence(state.session.Todos) + progressInput := collectProgressInput( + controlplane.RunStateExecute, + beforeTask, + afterTask, + beforeTodos, + afterTodos, + summary, + snapshot.noProgressStreakLimit, + snapshot.repeatCycleStreakLimit, + ) + state.progress = controlplane.EvaluateProgress(state.progress, progressInput) currentScore := state.progress.LastScore state.mu.Unlock() s.emitRunScoped(ctx, EventProgressEvaluated, &state, ProgressEvaluatedPayload{Score: currentScore}) - + if err := s.setBaseRunState(ctx, &state, controlplane.RunStateVerify); err != nil { + return s.handleRunError(ctx, state.runID, state.session.ID, err) + } break } } @@ -266,25 +307,22 @@ func (s *Service) prepareTurnSnapshot(ctx context.Context, state *runState) (tur } state.mu.Lock() - streak := state.progress.LastScore.NoProgressStreak - repeatStreak := state.progress.LastScore.RepeatCycleStreak + score := state.progress.LastScore state.mu.Unlock() limit := resolveNoProgressStreakLimit(cfg.Runtime) repeatLimit := resolveRepeatCycleStreakLimit(cfg.Runtime) - systemPrompt, repeatInjected := withSelfHealingRepeatReminder(builtContext.SystemPrompt, repeatStreak, repeatLimit) - if !repeatInjected { - systemPrompt = withSelfHealingReminder(systemPrompt, streak, limit) - } + systemPrompt := withProgressReminder(builtContext.SystemPrompt, score) model := strings.TrimSpace(cfg.CurrentModel) return turnSnapshot{ - config: cfg, - providerConfig: providerRuntimeCfg, - model: model, - workdir: activeWorkdir, - toolTimeout: time.Duration(cfg.ToolTimeoutSec) * time.Second, - noProgressStreakLimit: limit, + config: cfg, + providerConfig: providerRuntimeCfg, + model: model, + workdir: activeWorkdir, + toolTimeout: time.Duration(cfg.ToolTimeoutSec) * time.Second, + noProgressStreakLimit: limit, + repeatCycleStreakLimit: repeatLimit, request: providertypes.GenerateRequest{ Model: model, SystemPrompt: systemPrompt, @@ -391,17 +429,31 @@ func (s *Service) applyCompactForState( mode contextcompact.Mode, errorPolicy compactErrorPolicy, ) (bool, error) { - session, result, err := s.runCompactForSession(ctx, state.runID, state.session, cfg, mode, errorPolicy) - if err != nil { + applied := false + if err := s.enterTemporaryRunState(ctx, state, controlplane.RunStateCompacting); err != nil { return false, err } - state.session = session - if result.Applied { - state.resetTokenTotals() - state.compactApplied = true - return true, nil + defer func() { + _ = s.leaveTemporaryRunState(ctx, state, controlplane.RunStateCompacting) + }() + + err := func() error { + session, result, compactErr := s.runCompactForSession(ctx, state.runID, state.session, cfg, mode, errorPolicy) + if compactErr != nil { + return compactErr + } + state.session = session + if result.Applied { + state.resetTokenTotals() + state.compactApplied = true + applied = true + } + return nil + }() + if err != nil { + return false, err } - return false, nil + return applied, nil } // autoCompactThreshold 返回当前配置下的自动 compact 触发阈值。 @@ -517,28 +569,23 @@ func (s *Service) bindSessionLock(sessionID string) func() { } } -// withSelfHealingReminder 在无进展临界轮次注入自愈提醒,保持提示词拼接规则集中。 -func withSelfHealingReminder(systemPrompt string, streak int, limit int) string { - if streak != limit-1 { +// withProgressReminder 根据当前 progress 快照选择并注入唯一的自愈提醒。 +func withProgressReminder(systemPrompt string, score controlplane.ProgressScore) string { + var reminder string + switch score.ReminderKind { + case controlplane.ReminderKindRepeatCycle: + reminder = selfHealingRepeatReminder + case controlplane.ReminderKindNoProgress, controlplane.ReminderKindGenericStalled: + reminder = selfHealingReminder + default: return systemPrompt } - trimmed := strings.TrimSpace(systemPrompt) - if trimmed == "" { - return selfHealingReminder - } - return trimmed + "\n\n" + selfHealingReminder -} -// withSelfHealingRepeatReminder 在重复循环临界轮次注入循环自愈提醒,避免模型继续相同工具调用。 -func withSelfHealingRepeatReminder(systemPrompt string, repeatStreak int, repeatLimit int) (string, bool) { - if repeatStreak != repeatLimit-1 { - return systemPrompt, false - } trimmed := strings.TrimSpace(systemPrompt) if trimmed == "" { - return selfHealingRepeatReminder, true + return reminder } - return trimmed + "\n\n" + selfHealingRepeatReminder, true + return trimmed + "\n\n" + reminder } // autoCompactCacheKeyFromConfig 提取会影响自动压缩阈值解析的配置维度,用于 run 内缓存命中判断。 diff --git a/internal/runtime/run_lifecycle.go b/internal/runtime/run_lifecycle.go index be293c8c..28e52ea3 100644 --- a/internal/runtime/run_lifecycle.go +++ b/internal/runtime/run_lifecycle.go @@ -12,20 +12,121 @@ import ( "neo-code/internal/runtime/controlplane" ) -// transitionRunPhase 在阶段变化时发出 phase_changed 并更新 runState。 -func (s *Service) transitionRunPhase(ctx context.Context, state *runState, next controlplane.Phase) { - if state == nil || state.phase == next { - return +// setBaseRunState 更新主链生命周期状态,并触发有效运行态重计算。 +func (s *Service) setBaseRunState(ctx context.Context, state *runState, next controlplane.RunState) error { + if state == nil { + return nil + } + if !isBaseLifecycleState(next) { + return errors.New("runtime: invalid base lifecycle state") + } + state.mu.Lock() + state.baseLifecycle = next + state.mu.Unlock() + return s.refreshEffectiveRunState(ctx, state) +} + +// enterTemporaryRunState 增加临时治理态计数,并触发有效运行态重计算。 +func (s *Service) enterTemporaryRunState(ctx context.Context, state *runState, temporary controlplane.RunState) error { + if state == nil { + return nil + } + state.mu.Lock() + switch temporary { + case controlplane.RunStateWaitingPermission: + state.waitingPermissionCount++ + case controlplane.RunStateCompacting: + state.compactingCount++ + default: + state.mu.Unlock() + return errors.New("runtime: unsupported temporary lifecycle state") + } + state.mu.Unlock() + return s.refreshEffectiveRunState(ctx, state) +} + +// leaveTemporaryRunState 释放临时治理态计数,并触发有效运行态重计算。 +func (s *Service) leaveTemporaryRunState(ctx context.Context, state *runState, temporary controlplane.RunState) error { + if state == nil { + return nil + } + state.mu.Lock() + switch temporary { + case controlplane.RunStateWaitingPermission: + if state.waitingPermissionCount > 0 { + state.waitingPermissionCount-- + } + case controlplane.RunStateCompacting: + if state.compactingCount > 0 { + state.compactingCount-- + } + default: + state.mu.Unlock() + return errors.New("runtime: unsupported temporary lifecycle state") + } + state.mu.Unlock() + return s.refreshEffectiveRunState(ctx, state) +} + +// refreshEffectiveRunState 根据 base + 临时态覆盖层计算并发出统一 phase_changed 事件。 +func (s *Service) refreshEffectiveRunState(ctx context.Context, state *runState) error { + if state == nil { + return nil } - from := state.phase - state.phase = next + state.mu.Lock() + next := deriveEffectiveRunState(state) + from := state.lifecycle + if next == from { + state.mu.Unlock() + return nil + } + if err := controlplane.ValidateRunStateTransition(from, next); err != nil { + state.mu.Unlock() + return err + } + state.lifecycle = next + state.mu.Unlock() + _ = s.emitRunScoped(ctx, EventPhaseChanged, state, PhaseChangedPayload{ From: string(from), To: string(next), }) + return nil +} + +// deriveEffectiveRunState 统一推导当前有效运行态,临时治理态优先级高于 base 主链态。 +func deriveEffectiveRunState(state *runState) controlplane.RunState { + if state == nil { + return "" + } + if state.waitingPermissionCount > 0 { + return controlplane.RunStateWaitingPermission + } + if state.compactingCount > 0 { + return controlplane.RunStateCompacting + } + if state.baseLifecycle != "" { + return state.baseLifecycle + } + return state.lifecycle +} + +// isBaseLifecycleState 判断状态是否属于主链 base lifecycle 集合。 +func isBaseLifecycleState(state controlplane.RunState) bool { + switch state { + case controlplane.RunStatePlan, controlplane.RunStateExecute, controlplane.RunStateVerify, controlplane.RunStateStopped: + return true + default: + return false + } } -// emitRunTermination 在 Run 退出时决议并发出唯一 stop_reason_decided 终止事实事件。 +// transitionRunState 兼容旧调用入口,内部统一转为 base lifecycle 更新。 +func (s *Service) transitionRunState(ctx context.Context, state *runState, next controlplane.RunState) error { + return s.setBaseRunState(ctx, state, next) +} + +// emitRunTermination 在 Run 退出时决议并发出唯一的 stop_reason_decided 事件。 func (s *Service) emitRunTermination(ctx context.Context, input UserInput, state *runState, err error) { runID := strings.TrimSpace(input.RunID) sessionID := strings.TrimSpace(input.SessionID) @@ -40,17 +141,22 @@ func (s *Service) emitRunTermination(ctx context.Context, input UserInput, state return } state.stopEmitted = true + state.baseLifecycle = controlplane.RunStateStopped + state.lifecycle = controlplane.RunStateStopped + state.waitingPermissionCount = 0 + state.compactingCount = 0 } - in := controlplane.StopInput{Success: err == nil} + in := controlplane.StopInput{} if err != nil { - in.Success = false switch { case errors.Is(err, context.Canceled): - in.ContextCanceled = true + in.UserInterrupted = true default: - in.RunError = err + in.FatalError = err } + } else { + in.Completed = true } reason, detail := controlplane.DecideStopReason(in) @@ -58,10 +164,11 @@ func (s *Service) emitRunTermination(ctx context.Context, input UserInput, state phase := "" if state != nil { turn = state.turn - if state.phase != "" { - phase = string(state.phase) + if state.lifecycle != "" { + phase = string(state.lifecycle) } } + emitCtx, cancel := stopReasonEmitContext(ctx) defer cancel() _ = s.emitWithEnvelope(emitCtx, RuntimeEvent{ @@ -76,7 +183,7 @@ func (s *Service) emitRunTermination(ctx context.Context, input UserInput, state }) } -// stopReasonEmitContext 为终止事件提供可用发送窗口,避免继承已取消上下文导致事实事件丢失。 +// stopReasonEmitContext 为终止事件提供可用发送窗口,避免继承已取消上下文导致事件丢失。 func stopReasonEmitContext(ctx context.Context) (context.Context, context.CancelFunc) { if ctx != nil && ctx.Err() == nil { return context.WithTimeout(ctx, terminationEventEmitTimeout) @@ -84,7 +191,7 @@ func stopReasonEmitContext(ctx context.Context) (context.Context, context.Cancel return context.WithTimeout(context.Background(), terminationEventEmitTimeout) } -// handleRunError 负责记录 provider 错误日志并原样返回错误;终止类事件由 Run 出口统一发出。 +// handleRunError 统一转换 runtime 终止错误,保证取消语义收敛到同一路径。 func (s *Service) handleRunError(ctx context.Context, runID string, sessionID string, err error) error { _ = ctx _ = runID @@ -92,7 +199,6 @@ func (s *Service) handleRunError(ctx context.Context, runID string, sessionID st if errors.Is(err, context.Canceled) { return context.Canceled } - return err } @@ -105,7 +211,7 @@ func isRetryableProviderError(err error) bool { return providerErr.Retryable } -// providerRetryBackoff 计算 runtime 级 provider 重试等待时间。 +// providerRetryBackoff 计算 runtime 级 provider 重试等待时长。 func providerRetryBackoff(attempt int) time.Duration { wait := providerRetryBaseWait << (attempt - 1) jitter := float64(wait) * (0.5 + rand.Float64()) diff --git a/internal/runtime/run_lifecycle_test.go b/internal/runtime/run_lifecycle_test.go new file mode 100644 index 00000000..916c7598 --- /dev/null +++ b/internal/runtime/run_lifecycle_test.go @@ -0,0 +1,113 @@ +package runtime + +import ( + "context" + "testing" + + "neo-code/internal/runtime/controlplane" +) + +func TestTemporaryRunStateCountersKeepEffectiveStateStable(t *testing.T) { + t.Parallel() + + service := &Service{events: make(chan RuntimeEvent, 16)} + state := newRunState("run-temp-counter", newRuntimeSession("session-temp-counter")) + if err := service.setBaseRunState(context.Background(), &state, controlplane.RunStatePlan); err != nil { + t.Fatalf("set base run state: %v", err) + } + if err := service.setBaseRunState(context.Background(), &state, controlplane.RunStateExecute); err != nil { + t.Fatalf("set base run state: %v", err) + } + + if err := service.enterTemporaryRunState(context.Background(), &state, controlplane.RunStateWaitingPermission); err != nil { + t.Fatalf("enter waiting #1: %v", err) + } + if err := service.enterTemporaryRunState(context.Background(), &state, controlplane.RunStateWaitingPermission); err != nil { + t.Fatalf("enter waiting #2: %v", err) + } + if state.lifecycle != controlplane.RunStateWaitingPermission { + t.Fatalf("lifecycle = %q, want waiting_permission", state.lifecycle) + } + + if err := service.leaveTemporaryRunState(context.Background(), &state, controlplane.RunStateWaitingPermission); err != nil { + t.Fatalf("leave waiting #1: %v", err) + } + if state.lifecycle != controlplane.RunStateWaitingPermission { + t.Fatalf("lifecycle after first leave = %q, want waiting_permission", state.lifecycle) + } + + if err := service.leaveTemporaryRunState(context.Background(), &state, controlplane.RunStateWaitingPermission); err != nil { + t.Fatalf("leave waiting #2: %v", err) + } + if state.lifecycle != controlplane.RunStateExecute { + t.Fatalf("lifecycle after second leave = %q, want execute", state.lifecycle) + } + + events := collectRuntimeEvents(service.Events()) + assertPhaseTransitions(t, events, [][2]string{ + {"", "plan"}, + {"plan", "execute"}, + {"execute", "waiting_permission"}, + {"waiting_permission", "execute"}, + }) +} + +func TestTemporaryRunStatePriorityWaitingOverCompacting(t *testing.T) { + t.Parallel() + + service := &Service{events: make(chan RuntimeEvent, 16)} + state := newRunState("run-temp-priority", newRuntimeSession("session-temp-priority")) + if err := service.setBaseRunState(context.Background(), &state, controlplane.RunStatePlan); err != nil { + t.Fatalf("set base run state: %v", err) + } + + if err := service.enterTemporaryRunState(context.Background(), &state, controlplane.RunStateCompacting); err != nil { + t.Fatalf("enter compacting: %v", err) + } + if state.lifecycle != controlplane.RunStateCompacting { + t.Fatalf("lifecycle = %q, want compacting", state.lifecycle) + } + if err := service.enterTemporaryRunState(context.Background(), &state, controlplane.RunStateWaitingPermission); err != nil { + t.Fatalf("enter waiting: %v", err) + } + if state.lifecycle != controlplane.RunStateWaitingPermission { + t.Fatalf("lifecycle = %q, want waiting_permission", state.lifecycle) + } + + if err := service.leaveTemporaryRunState(context.Background(), &state, controlplane.RunStateWaitingPermission); err != nil { + t.Fatalf("leave waiting: %v", err) + } + if state.lifecycle != controlplane.RunStateCompacting { + t.Fatalf("lifecycle = %q, want compacting after waiting leaves", state.lifecycle) + } + if err := service.leaveTemporaryRunState(context.Background(), &state, controlplane.RunStateCompacting); err != nil { + t.Fatalf("leave compacting: %v", err) + } + if state.lifecycle != controlplane.RunStatePlan { + t.Fatalf("lifecycle = %q, want plan", state.lifecycle) + } +} + +func assertPhaseTransitions(t *testing.T, events []RuntimeEvent, expected [][2]string) { + t.Helper() + + var phases [][2]string + for _, event := range events { + if event.Type != EventPhaseChanged { + continue + } + payload, ok := event.Payload.(PhaseChangedPayload) + if !ok { + t.Fatalf("expected phase payload, got %#v", event.Payload) + } + phases = append(phases, [2]string{payload.From, payload.To}) + } + if len(phases) != len(expected) { + t.Fatalf("phase transition count = %d, want %d, got %+v", len(phases), len(expected), phases) + } + for i := range expected { + if phases[i] != expected[i] { + t.Fatalf("phase transition[%d] = %+v, want %+v", i, phases[i], expected[i]) + } + } +} diff --git a/internal/runtime/run_termination_test.go b/internal/runtime/run_termination_test.go index 1247cd9c..0cdf077b 100644 --- a/internal/runtime/run_termination_test.go +++ b/internal/runtime/run_termination_test.go @@ -32,8 +32,8 @@ func TestEmitRunTerminationEmitsStopReasonOnce(t *testing.T) { if !ok { t.Fatalf("expected StopReasonDecidedPayload, got %#v", e.Payload) } - if p.Reason != controlplane.StopReasonError { - t.Fatalf("reason = %q, want error", p.Reason) + if p.Reason != controlplane.StopReasonFatalError { + t.Fatalf("reason = %q, want fatal error", p.Reason) } } } diff --git a/internal/runtime/runtime_branch_coverage_test.go b/internal/runtime/runtime_branch_coverage_test.go index eb4a85a5..4503b738 100644 --- a/internal/runtime/runtime_branch_coverage_test.go +++ b/internal/runtime/runtime_branch_coverage_test.go @@ -16,7 +16,7 @@ func TestExecuteAssistantToolCallsReturnsNilForEmptyCalls(t *testing.T) { service := &Service{} state := &runState{} - err := service.executeAssistantToolCalls(context.Background(), state, turnSnapshot{}, providertypes.Message{}) + _, err := service.executeAssistantToolCalls(context.Background(), state, turnSnapshot{}, providertypes.Message{}) if err != nil { t.Fatalf("executeAssistantToolCalls() error = %v", err) } @@ -29,17 +29,16 @@ func TestExecuteOneToolCallStopsWhenContextCheckReturnsTrue(t *testing.T) { state := newRunState("run-stop", newRuntimeSession("session-stop")) called := false - service.executeOneToolCall( + _, _, _ = service.executeOneToolCall( context.Background(), &state, turnSnapshot{}, providertypes.ToolCall{ID: "call-1", Name: "noop"}, &sync.Mutex{}, func() bool { return true }, - func(error) { called = true }, ) if called { - t.Fatalf("rememberError should not be called when execution is short-circuited") + t.Fatalf("expected short-circuit to bypass legacy error callback path") } } @@ -91,11 +90,11 @@ func TestTransitionRunPhaseNoopBranches(t *testing.T) { t.Parallel() service := &Service{events: make(chan RuntimeEvent, 4)} - service.transitionRunPhase(context.Background(), nil, controlplane.PhasePlan) + service.transitionRunState(context.Background(), nil, controlplane.RunStatePlan) state := newRunState("run-phase", newRuntimeSession("session-phase")) - state.phase = controlplane.PhasePlan - service.transitionRunPhase(context.Background(), &state, controlplane.PhasePlan) + state.lifecycle = controlplane.RunStatePlan + service.transitionRunState(context.Background(), &state, controlplane.RunStatePlan) events := collectRuntimeEvents(service.Events()) if len(events) != 0 { diff --git a/internal/runtime/runtime_internal_helpers_test.go b/internal/runtime/runtime_internal_helpers_test.go index 87e5e52c..169828be 100644 --- a/internal/runtime/runtime_internal_helpers_test.go +++ b/internal/runtime/runtime_internal_helpers_test.go @@ -469,7 +469,7 @@ func TestExecuteAssistantToolCallsFillsErrorContent(t *testing.T) { } snapshot := turnSnapshot{workdir: t.TempDir(), toolTimeout: time.Second} - if err := service.executeAssistantToolCalls(context.Background(), &state, snapshot, assistant); err != nil { + if _, err := service.executeAssistantToolCalls(context.Background(), &state, snapshot, assistant); err != nil { t.Fatalf("executeAssistantToolCalls() error = %v", err) } if len(state.session.Messages) != 1 { @@ -509,7 +509,7 @@ func TestExecuteAssistantToolCallsCanceledSaveStillEmitsResultWhenExecErr(t *tes } snapshot := turnSnapshot{workdir: t.TempDir(), toolTimeout: time.Second} - err := service.executeAssistantToolCalls(context.Background(), &state, snapshot, assistant) + _, err := service.executeAssistantToolCalls(context.Background(), &state, snapshot, assistant) if !errors.Is(err, context.Canceled) { t.Fatalf("expected context.Canceled from save failure, got %v", err) } diff --git a/internal/runtime/runtime_progress_test.go b/internal/runtime/runtime_progress_test.go index f32d6acb..b5a7eea0 100644 --- a/internal/runtime/runtime_progress_test.go +++ b/internal/runtime/runtime_progress_test.go @@ -13,6 +13,7 @@ import ( "neo-code/internal/runtime/controlplane" agentsession "neo-code/internal/session" "neo-code/internal/tools" + todotool "neo-code/internal/tools/todo" ) func TestProgressStreakNoLongerStopsRun(t *testing.T) { @@ -83,7 +84,7 @@ func TestProgressStreakNoLongerStopsRun(t *testing.T) { } events := collectRuntimeEvents(service.Events()) - assertStopReasonDecided(t, events, controlplane.StopReasonSuccess, "") + assertStopReasonDecided(t, events, controlplane.StopReasonCompleted, "") if !promptInjected { t.Error("expected self-healing prompt to be injected before repetitive no-progress turns") @@ -165,7 +166,7 @@ func TestProgressEvidenceResetsNoProgressStreak(t *testing.T) { } events := collectRuntimeEvents(service.Events()) - assertStopReasonDecided(t, events, controlplane.StopReasonSuccess, "") + assertStopReasonDecided(t, events, controlplane.StopReasonCompleted, "") } func TestRepeatCycleStreakNoLongerStopsRunAndInjectsReminder(t *testing.T) { @@ -232,7 +233,7 @@ func TestRepeatCycleStreakNoLongerStopsRunAndInjectsReminder(t *testing.T) { } events := collectRuntimeEvents(service.Events()) - assertStopReasonDecided(t, events, controlplane.StopReasonSuccess, "") + assertStopReasonDecided(t, events, controlplane.StopReasonCompleted, "") if !promptInjected { t.Fatal("expected repeat self-healing prompt to be injected before repeat limit is reached") @@ -357,6 +358,8 @@ func TestPrepareTurnSnapshotInjectRepeatReminderWithEmptyPrompt(t *testing.T) { } state := newRunState("run-repeat-reminder-empty", newRuntimeSession("session-repeat-reminder-empty")) state.progress.LastScore.RepeatCycleStreak = 2 + state.progress.LastScore.StalledProgressState = controlplane.StalledProgressStalled + state.progress.LastScore.ReminderKind = controlplane.ReminderKindRepeatCycle snapshot, rebuilt, err := service.prepareTurnSnapshot(context.Background(), &state) if err != nil { @@ -392,6 +395,8 @@ func TestPrepareTurnSnapshotRepeatReminderTakesPriority(t *testing.T) { state := newRunState("run-reminder-priority", newRuntimeSession("session-reminder-priority")) state.progress.LastScore.NoProgressStreak = 2 state.progress.LastScore.RepeatCycleStreak = 2 + state.progress.LastScore.StalledProgressState = controlplane.StalledProgressStalled + state.progress.LastScore.ReminderKind = controlplane.ReminderKindRepeatCycle snapshot, rebuilt, err := service.prepareTurnSnapshot(context.Background(), &state) if err != nil { @@ -471,6 +476,93 @@ func TestComputeTodoStateSignature(t *testing.T) { } } +func TestNoToolIncompleteTurnStillEvaluatesProgressAndInjectsReminder(t *testing.T) { + t.Parallel() + + manager := newRuntimeConfigManager(t) + if err := manager.Update(context.Background(), func(cfg *config.Config) error { + cfg.Runtime.MaxNoProgressStreak = 1 + return nil + }); err != nil { + t.Fatalf("update config: %v", err) + } + + store := newMemoryStore() + session := newRuntimeSession("session-no-tool-reminder") + session.Todos = []agentsession.TodoItem{ + { + ID: "todo-1", + Content: "close me", + Status: agentsession.TodoStatusPending, + Executor: agentsession.TodoExecutorAgent, + Revision: 1, + }, + } + store.sessions[session.ID] = cloneSession(session) + + registry := tools.NewRegistry() + registry.Register(todotool.New()) + + providerImpl := &scriptedProvider{ + responses: []scriptedResponse{ + { + Message: providertypes.Message{ + Role: providertypes.RoleAssistant, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("done")}, + }, + FinishReason: "stop", + }, + { + Message: providertypes.Message{ + Role: providertypes.RoleAssistant, + ToolCalls: []providertypes.ToolCall{ + { + ID: "todo-close", + Name: tools.ToolNameTodoWrite, + Arguments: `{"action":"set_status","id":"todo-1","status":"canceled","expected_revision":1}`, + }, + }, + }, + FinishReason: "tool_calls", + }, + { + Message: providertypes.Message{ + Role: providertypes.RoleAssistant, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("done")}, + }, + FinishReason: "stop", + }, + }, + } + + service := NewWithFactory( + manager, + registry, + store, + &scriptedProviderFactory{provider: providerImpl}, + &stubContextBuilder{}, + ) + + if err := service.Run(context.Background(), UserInput{ + RunID: "run-no-tool-reminder", + SessionID: session.ID, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("continue")}, + }); err != nil { + t.Fatalf("Run() error = %v", err) + } + + if len(providerImpl.requests) < 2 { + t.Fatalf("expected at least 2 provider requests, got %d", len(providerImpl.requests)) + } + if !strings.Contains(providerImpl.requests[1].SystemPrompt, selfHealingReminder) { + t.Fatalf("expected stalled reminder in second provider request, got %q", providerImpl.requests[1].SystemPrompt) + } + + events := collectRuntimeEvents(service.Events()) + assertEventContains(t, events, EventProgressEvaluated) + assertStopReasonDecided(t, events, controlplane.StopReasonCompleted, "") +} + func assertStopReasonDecided(t *testing.T, events []RuntimeEvent, wantReason controlplane.StopReason, wantDetail string) { t.Helper() assertEventContains(t, events, EventStopReasonDecided) diff --git a/internal/runtime/runtime_remaining_branches_test.go b/internal/runtime/runtime_remaining_branches_test.go index a279b1d1..d1116d2c 100644 --- a/internal/runtime/runtime_remaining_branches_test.go +++ b/internal/runtime/runtime_remaining_branches_test.go @@ -481,7 +481,7 @@ func TestExecuteAssistantToolCallsRemainingBranches(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() state := newRunState("run", newRuntimeSession("session-top-cancel")) - err := service.executeAssistantToolCalls(ctx, &state, turnSnapshot{}, providertypes.Message{ToolCalls: []providertypes.ToolCall{{ID: "c", Name: "filesystem_read_file"}}}) + _, err := service.executeAssistantToolCalls(ctx, &state, turnSnapshot{}, providertypes.Message{ToolCalls: []providertypes.ToolCall{{ID: "c", Name: "filesystem_read_file"}}}) if !errors.Is(err, context.Canceled) { t.Fatalf("expected context.Canceled, got %v", err) } @@ -499,7 +499,7 @@ func TestExecuteAssistantToolCallsRemainingBranches(t *testing.T) { store.sessions[session.ID] = cloneSession(session) service := &Service{events: make(chan RuntimeEvent, 8), approvalBroker: approvalflow.NewBroker(), toolManager: manager, sessionStore: store} state := newRunState("run", session) - err := service.executeAssistantToolCalls(ctx, &state, turnSnapshot{toolTimeout: time.Second}, providertypes.Message{ToolCalls: []providertypes.ToolCall{{ID: "c", Name: "filesystem_read_file"}}}) + _, err := service.executeAssistantToolCalls(ctx, &state, turnSnapshot{toolTimeout: time.Second}, providertypes.Message{ToolCalls: []providertypes.ToolCall{{ID: "c", Name: "filesystem_read_file"}}}) if !errors.Is(err, context.Canceled) { t.Fatalf("expected context.Canceled, got %v", err) } @@ -518,7 +518,7 @@ func TestExecuteAssistantToolCallsRemainingBranches(t *testing.T) { service := &Service{events: make(chan RuntimeEvent, 8), approvalBroker: approvalflow.NewBroker(), toolManager: manager, sessionStore: store} state := newRunState("run", session) - err := service.executeAssistantToolCalls(ctx, &state, turnSnapshot{toolTimeout: time.Second}, providertypes.Message{ToolCalls: []providertypes.ToolCall{{ID: "c", Name: "filesystem_read_file"}}}) + _, err := service.executeAssistantToolCalls(ctx, &state, turnSnapshot{toolTimeout: time.Second}, providertypes.Message{ToolCalls: []providertypes.ToolCall{{ID: "c", Name: "filesystem_read_file"}}}) if !errors.Is(err, context.Canceled) { t.Fatalf("expected context.Canceled, got %v", err) } @@ -537,7 +537,7 @@ func TestExecuteAssistantToolCallsRemainingBranches(t *testing.T) { service := &Service{events: make(chan RuntimeEvent, 8), approvalBroker: approvalflow.NewBroker(), toolManager: manager, sessionStore: store} state := newRunState("run", session) - err := service.executeAssistantToolCalls(ctx, &state, turnSnapshot{toolTimeout: time.Second}, providertypes.Message{ToolCalls: []providertypes.ToolCall{{ID: "c", Name: "filesystem_read_file"}}}) + _, err := service.executeAssistantToolCalls(ctx, &state, turnSnapshot{toolTimeout: time.Second}, providertypes.Message{ToolCalls: []providertypes.ToolCall{{ID: "c", Name: "filesystem_read_file"}}}) if !errors.Is(err, context.Canceled) { t.Fatalf("expected context.Canceled, got %v", err) } diff --git a/internal/runtime/runtime_test.go b/internal/runtime/runtime_test.go index 61de9d8d..03bfa546 100644 --- a/internal/runtime/runtime_test.go +++ b/internal/runtime/runtime_test.go @@ -821,7 +821,7 @@ func TestServiceRun(t *testing.T) { // 第二轮:普通文本回复 providerStreams: [][]providertypes.StreamEvent{ { - providertypes.NewToolCallStartStreamEvent(0, "call-1", "filesystem_edit"), + providertypes.NewToolCallStartStreamEvent(0, "call-1", "filesystem_read_file"), providertypes.NewToolCallDeltaStreamEvent(0, "call-1", `{"path":"main.go"}`), }, { @@ -829,7 +829,7 @@ func TestServiceRun(t *testing.T) { }, }, registerTool: &stubTool{ - name: "filesystem_edit", + name: "filesystem_read_file", content: "tool output", }, contextBuilder: &stubContextBuilder{ @@ -864,7 +864,7 @@ func TestServiceRun(t *testing.T) { if message.Role == "tool" && message.ToolCallID == "call-1" && strings.Contains(renderPartsForTest(message.Parts), "tool result") && - strings.Contains(renderPartsForTest(message.Parts), "tool: filesystem_edit") && + strings.Contains(renderPartsForTest(message.Parts), "tool: filesystem_read_file") && strings.Contains(renderPartsForTest(message.Parts), "status: ok") && strings.Contains(renderPartsForTest(message.Parts), "content:\ntool output") { foundToolResult = true @@ -879,7 +879,7 @@ func TestServiceRun(t *testing.T) { if session.Messages[2].Role != providertypes.RoleTool || renderPartsForTest(session.Messages[2].Parts) != "tool output" { t.Fatalf("expected persisted tool message to keep raw content, got %+v", session.Messages[2]) } - if session.Messages[2].ToolMetadata["tool_name"] != "filesystem_edit" { + if session.Messages[2].ToolMetadata["tool_name"] != "filesystem_read_file" { t.Fatalf("expected persisted tool metadata to keep tool name, got %+v", session.Messages[2].ToolMetadata) } }, @@ -1125,12 +1125,12 @@ func TestServiceRunSchedulesMemoExtractionOnlyAfterFinalCompletion(t *testing.T) manager := newRuntimeConfigManager(t) store := newMemoryStore() registry := tools.NewRegistry() - registry.Register(&stubTool{name: "filesystem_edit", content: "tool output"}) + registry.Register(&stubTool{name: "filesystem_read_file", content: "tool output"}) scripted := &scriptedProvider{ streams: [][]providertypes.StreamEvent{ { - providertypes.NewToolCallStartStreamEvent(0, "call-1", "filesystem_edit"), + providertypes.NewToolCallStartStreamEvent(0, "call-1", "filesystem_read_file"), providertypes.NewToolCallDeltaStreamEvent(0, "call-1", `{"path":"main.go"}`), providertypes.NewMessageDoneStreamEvent("tool_calls", nil), }, @@ -1161,7 +1161,7 @@ func TestServiceRunMergesLateToolCallMetadata(t *testing.T) { manager := newRuntimeConfigManager(t) store := newMemoryStore() - tool := &stubTool{name: "filesystem_edit", content: "tool output"} + tool := &stubTool{name: "filesystem_read_file", content: "tool output"} registry := tools.NewRegistry() registry.Register(tool) @@ -1169,7 +1169,7 @@ func TestServiceRunMergesLateToolCallMetadata(t *testing.T) { streams: [][]providertypes.StreamEvent{ { providertypes.NewToolCallDeltaStreamEvent(0, "", `{"path":"main.go"`), - providertypes.NewToolCallStartStreamEvent(0, "call-late", "filesystem_edit"), + providertypes.NewToolCallStartStreamEvent(0, "call-late", "filesystem_read_file"), providertypes.NewToolCallDeltaStreamEvent(0, "call-late", `}`), }, {providertypes.NewTextDeltaStreamEvent("done")}, @@ -1187,8 +1187,8 @@ func TestServiceRunMergesLateToolCallMetadata(t *testing.T) { if tool.lastInput.ID != "call-late" { t.Fatalf("expected merged tool call id %q, got %q", "call-late", tool.lastInput.ID) } - if tool.lastInput.Name != "filesystem_edit" { - t.Fatalf("expected merged tool name %q, got %q", "filesystem_edit", tool.lastInput.Name) + if tool.lastInput.Name != "filesystem_read_file" { + t.Fatalf("expected merged tool name %q, got %q", "filesystem_read_file", tool.lastInput.Name) } if got := string(tool.lastInput.Arguments); got != `{"path":"main.go"}` { t.Fatalf("expected merged tool arguments %q, got %q", `{"path":"main.go"}`, got) @@ -1201,7 +1201,7 @@ func TestServiceRunMergesLateToolCallMetadata(t *testing.T) { if len(session.Messages[1].ToolCalls) != 1 { t.Fatalf("expected persisted assistant tool call, got %+v", session.Messages[1]) } - if session.Messages[1].ToolCalls[0].ID != "call-late" || session.Messages[1].ToolCalls[0].Name != "filesystem_edit" { + if session.Messages[1].ToolCalls[0].ID != "call-late" || session.Messages[1].ToolCalls[0].Name != "filesystem_read_file" { t.Fatalf("expected merged assistant tool call metadata, got %+v", session.Messages[1].ToolCalls[0]) } if session.Messages[2].ToolCallID != "call-late" { @@ -1682,17 +1682,17 @@ func TestServiceRunUsesToolManager(t *testing.T) { AgentID: "agent-run-tool-manager", IssuedAt: now.Add(-time.Minute), ExpiresAt: now.Add(time.Hour), - AllowedTools: []string{"filesystem_edit"}, + AllowedTools: []string{"filesystem_read_file"}, AllowedPaths: []string{t.TempDir()}, NetworkPolicy: security.NetworkPolicy{Mode: security.NetworkPermissionDenyAll}, WritePermission: security.WritePermissionWorkspace, } toolManager := &stubToolManager{ specs: []providertypes.ToolSpec{ - {Name: "filesystem_edit", Description: "stub", Schema: map[string]any{"type": "object"}}, + {Name: "filesystem_read_file", Description: "stub", Schema: map[string]any{"type": "object"}}, }, result: tools.ToolResult{ - Name: "filesystem_edit", + Name: "filesystem_read_file", Content: "tool manager output", Metadata: map[string]any{ "path": "main.go", @@ -1703,7 +1703,7 @@ func TestServiceRunUsesToolManager(t *testing.T) { scripted := &scriptedProvider{ streams: [][]providertypes.StreamEvent{ { - providertypes.NewToolCallStartStreamEvent(0, "call-manager", "filesystem_edit"), + providertypes.NewToolCallStartStreamEvent(0, "call-manager", "filesystem_read_file"), providertypes.NewToolCallDeltaStreamEvent(0, "call-manager", `{"path":"main.go"}`), }, {providertypes.NewTextDeltaStreamEvent("done")}, @@ -1739,7 +1739,7 @@ func TestServiceRunUsesToolManager(t *testing.T) { if toolManager.lastInput.CapabilityToken == nil || toolManager.lastInput.CapabilityToken.ID != capability.ID { t.Fatalf("expected forwarded capability token id %q, got %+v", capability.ID, toolManager.lastInput.CapabilityToken) } - if len(scripted.requests) == 0 || len(scripted.requests[0].Tools) != 1 || scripted.requests[0].Tools[0].Name != "filesystem_edit" { + if len(scripted.requests) == 0 || len(scripted.requests[0].Tools) != 1 || scripted.requests[0].Tools[0].Name != "filesystem_read_file" { t.Fatalf("expected tool specs from tool manager, got %+v", scripted.requests) } @@ -1748,7 +1748,7 @@ func TestServiceRunUsesToolManager(t *testing.T) { for _, message := range session.Messages { if message.Role == providertypes.RoleTool && renderPartsForTest(message.Parts) == "tool manager output" && - message.ToolMetadata["tool_name"] == "filesystem_edit" && + message.ToolMetadata["tool_name"] == "filesystem_read_file" && message.ToolMetadata["path"] == "main.go" { foundToolMessage = true break @@ -2122,7 +2122,7 @@ func TestServiceRunErrorPaths(t *testing.T) { ToolCalls: []providertypes.ToolCall{ { ID: fmt.Sprintf("loop-call-%d", i), - Name: "filesystem_edit", + Name: "filesystem_read_file", Arguments: fmt.Sprintf(`{"path":"x", "iteration": %d}`, i), }, }, @@ -2136,7 +2136,7 @@ func TestServiceRunErrorPaths(t *testing.T) { }) return &scriptedProvider{responses: responses} }(), - registerTool: &stubTool{name: "filesystem_edit", content: "loop tool output"}, + registerTool: &stubTool{name: "filesystem_read_file", content: "loop tool output"}, expectEvents: []EventType{EventUserMessage, EventToolStart, EventToolChunk, EventToolResult, EventAgentDone}, assert: func(t *testing.T, store *memoryStore, scripted *scriptedProvider, tool *stubTool) { t.Helper() @@ -3175,7 +3175,7 @@ func TestServiceRunUsesSessionWorkdirForContextAndTools(t *testing.T) { session := agentsession.NewWithWorkdir("Session Workdir", sessionWorkdir) store.sessions[session.ID] = cloneSession(session) - tool := &stubTool{name: "filesystem_edit", content: "ok"} + tool := &stubTool{name: "filesystem_read_file", content: "ok"} registry := tools.NewRegistry() registry.Register(tool) @@ -3183,7 +3183,7 @@ func TestServiceRunUsesSessionWorkdirForContextAndTools(t *testing.T) { scripted := &scriptedProvider{ streams: [][]providertypes.StreamEvent{ { - providertypes.NewToolCallStartStreamEvent(0, "call-session-workdir", "filesystem_edit"), + providertypes.NewToolCallStartStreamEvent(0, "call-session-workdir", "filesystem_read_file"), providertypes.NewToolCallDeltaStreamEvent(0, "call-session-workdir", `{"path":"main.go"}`), }, {providertypes.NewTextDeltaStreamEvent("done")}, @@ -5277,7 +5277,7 @@ func TestAgentDoneEventCarriesRunScopedEnvelope(t *testing.T) { if doneEvent.Turn == turnUnspecified { t.Fatalf("expected run-scoped turn, got %d", doneEvent.Turn) } - if doneEvent.Phase != string(controlplane.PhasePlan) { - t.Fatalf("expected phase=%q, got %q", controlplane.PhasePlan, doneEvent.Phase) + if doneEvent.Phase != string(controlplane.RunStatePlan) { + t.Fatalf("expected phase=%q, got %q", controlplane.RunStatePlan, doneEvent.Phase) } } diff --git a/internal/runtime/state.go b/internal/runtime/state.go index 64a03c24..5cc0e7ca 100644 --- a/internal/runtime/state.go +++ b/internal/runtime/state.go @@ -28,8 +28,12 @@ type runState struct { agentID string capabilityToken *security.CapabilityToken turn int - phase controlplane.Phase + baseLifecycle controlplane.RunState + lifecycle controlplane.RunState + waitingPermissionCount int + compactingCount int stopEmitted bool + completion controlplane.CompletionState progress controlplane.ProgressState reportedMissingSkills map[string]struct{} } @@ -91,13 +95,14 @@ func (s *runState) markSkillMissingReported(skillID string) bool { // noProgressStreakLimit 由 prepareTurnSnapshot 一次性解析并存储,确保同一轮的 // 提示词纠偏阈值来自同一配置快照,避免并发 reload 导致注入行为不一致。 type turnSnapshot struct { - config config.Config - providerConfig provider.RuntimeConfig - model string - workdir string - toolTimeout time.Duration - noProgressStreakLimit int - request providertypes.GenerateRequest + config config.Config + providerConfig provider.RuntimeConfig + model string + workdir string + toolTimeout time.Duration + noProgressStreakLimit int + repeatCycleStreakLimit int + request providertypes.GenerateRequest } // providerTurnResult 表示单轮 provider 调用成功后的结构化结果。 diff --git a/internal/runtime/subagent_tool_executor.go b/internal/runtime/subagent_tool_executor.go index 62abe53d..d5cdcfa3 100644 --- a/internal/runtime/subagent_tool_executor.go +++ b/internal/runtime/subagent_tool_executor.go @@ -66,10 +66,21 @@ func (e *subAgentRuntimeToolExecutor) ExecuteTool( agentID := strings.TrimSpace(input.AgentID) workdir := strings.TrimSpace(input.Workdir) callName := strings.TrimSpace(input.Call.Name) + capabilityToken := e.bindCapabilityTokenToExecution(e.resolveCapabilityToken(input), taskID, agentID) + effectiveTaskID := taskID + effectiveAgentID := agentID + if capabilityToken != nil { + if trimmedTaskID := strings.TrimSpace(capabilityToken.TaskID); trimmedTaskID != "" { + effectiveTaskID = trimmedTaskID + } + if trimmedAgentID := strings.TrimSpace(capabilityToken.AgentID); trimmedAgentID != "" { + effectiveAgentID = trimmedAgentID + } + } payload := SubAgentToolCallEventPayload{ Role: input.Role, - TaskID: taskID, + TaskID: effectiveTaskID, ToolName: callName, Decision: subAgentToolDecisionPending, ElapsedMS: 0, @@ -79,9 +90,9 @@ func (e *subAgentRuntimeToolExecutor) ExecuteTool( result, execErr := e.service.executeToolCallWithPermission(ctx, permissionExecutionInput{ RunID: runID, SessionID: sessionID, - TaskID: taskID, - AgentID: agentID, - Capability: e.resolveCapabilityToken(input), + TaskID: effectiveTaskID, + AgentID: effectiveAgentID, + Capability: capabilityToken, Call: input.Call, Workdir: workdir, ToolTimeout: timeout, @@ -113,7 +124,7 @@ func (e *subAgentRuntimeToolExecutor) ExecuteTool( eventPayload := SubAgentToolCallEventPayload{ Role: input.Role, - TaskID: taskID, + TaskID: effectiveTaskID, ToolName: output.Name, Decision: decision, ElapsedMS: elapsedMilliseconds(startedAt), @@ -186,6 +197,51 @@ func (e *subAgentRuntimeToolExecutor) resolveCapabilityToken(input subagent.Tool return &signed } +// bindCapabilityTokenToExecution 在真正执行前把 capability token 重新绑定到当前 task/agent,避免回退 parent token 时破坏权限校验。 +func (e *subAgentRuntimeToolExecutor) bindCapabilityTokenToExecution( + token *security.CapabilityToken, + taskID string, + agentID string, +) *security.CapabilityToken { + if token == nil { + return nil + } + normalized := token.Normalize() + boundTaskID := strings.TrimSpace(taskID) + boundAgentID := strings.TrimSpace(agentID) + if (boundTaskID == "" || normalized.TaskID == boundTaskID) && + (boundAgentID == "" || normalized.AgentID == boundAgentID) { + return &normalized + } + if e == nil || e.service == nil { + return &normalized + } + + signerProvider, ok := e.service.toolManager.(capabilitySignerProvider) + if !ok { + return &normalized + } + signer := signerProvider.CapabilitySigner() + if signer == nil { + return &normalized + } + + rebound := normalized + rebound.ID = fmt.Sprintf("subagent-bind-%d-%s", time.Now().UTC().UnixNano(), boundTaskID) + if boundTaskID != "" { + rebound.TaskID = boundTaskID + } + if boundAgentID != "" { + rebound.AgentID = boundAgentID + } + rebound.Signature = "" + signed, err := signer.Sign(rebound) + if err != nil { + return &normalized + } + return &signed +} + // tightenToolAllowlist 以 parent 为上界收敛工具白名单;未请求时继承 parent。 func tightenToolAllowlist(parent []string, requested []string) []string { parent = normalizeAllowlistToList(parent) diff --git a/internal/runtime/todo_runtime_integration_test.go b/internal/runtime/todo_runtime_integration_test.go index 8eefc8f2..2eacbd57 100644 --- a/internal/runtime/todo_runtime_integration_test.go +++ b/internal/runtime/todo_runtime_integration_test.go @@ -33,6 +33,19 @@ func TestServiceRunTodoWriteToolCall(t *testing.T) { }, FinishReason: "tool_calls", }, + { + Message: providertypes.Message{ + Role: providertypes.RoleAssistant, + ToolCalls: []providertypes.ToolCall{ + { + ID: "todo-call-2", + Name: tools.ToolNameTodoWrite, + Arguments: `{"action":"set_status","id":"todo-1","status":"canceled","expected_revision":1}`, + }, + }, + }, + FinishReason: "tool_calls", + }, { Message: providertypes.Message{ Role: providertypes.RoleAssistant, @@ -79,6 +92,9 @@ func TestServiceRunTodoWriteToolCall(t *testing.T) { if session.Todos[0].ID != "todo-1" || session.Todos[0].Content != "implement feature" { t.Fatalf("unexpected todo item: %+v", session.Todos[0]) } + if session.Todos[0].Status != "canceled" { + t.Fatalf("expected todo to be closed before completion, got %+v", session.Todos[0]) + } events := collectRuntimeEvents(service.Events()) foundTodoUpdated := false diff --git a/internal/runtime/toolexec.go b/internal/runtime/toolexec.go index 90e8eafc..686e3aa4 100644 --- a/internal/runtime/toolexec.go +++ b/internal/runtime/toolexec.go @@ -10,24 +10,31 @@ import ( "neo-code/internal/tools" ) -// executeAssistantToolCalls 并发执行 assistant 返回的全部工具调用并回写结果。 +type indexedToolCall struct { + index int + call providertypes.ToolCall +} + +// executeAssistantToolCalls 并发执行 assistant 返回的全部工具调用并返回结构化执行摘要。 func (s *Service) executeAssistantToolCalls( ctx context.Context, state *runState, snapshot turnSnapshot, assistant providertypes.Message, -) error { +) (toolExecutionSummary, error) { if len(assistant.ToolCalls) == 0 { - return nil + return toolExecutionSummary{}, nil } execCtx, cancelExec := context.WithCancel(ctx) defer cancelExec() parallelism := resolveToolParallelism(len(assistant.ToolCalls)) - orderedCalls := reorderToolCallsByNameRoundRobin(assistant.ToolCalls) toolLocks := buildToolExecutionLocks(assistant.ToolCalls) - taskCh := make(chan providertypes.ToolCall) + taskCh := make(chan indexedToolCall) + results := make([]tools.ToolResult, len(assistant.ToolCalls)) + completed := make([]bool, len(assistant.ToolCalls)) + writes := make([]bool, len(assistant.ToolCalls)) var mu sync.Mutex var firstErr error var workerWG sync.WaitGroup @@ -40,32 +47,51 @@ func (s *Service) executeAssistantToolCalls( workerWG.Add(1) go func() { defer workerWG.Done() - for call := range taskCh { - s.executeOneToolCall( + for task := range taskCh { + result, wrote, err := s.executeOneToolCall( execCtx, state, snapshot, - call, - toolLocks[normalizeToolLockKey(call.Name)], + task.call, + toolLocks[normalizeToolLockKey(task.call.Name)], checkContext, - func(err error) { - recordAndCancelOnFirstError(&mu, &firstErr, err, cancelExec) - }, ) + mu.Lock() + results[task.index] = result + completed[task.index] = true + writes[task.index] = wrote + mu.Unlock() + if err != nil { + recordAndCancelOnFirstError(&mu, &firstErr, err, cancelExec) + } } }() } - for _, call := range orderedCalls { + for index, call := range assistant.ToolCalls { if checkContext() { break } - taskCh <- call + taskCh <- indexedToolCall{index: index, call: call} } close(taskCh) workerWG.Wait() - return firstErr + + summary := toolExecutionSummary{ + Calls: append([]providertypes.ToolCall(nil), assistant.ToolCalls...), + } + for index, ok := range completed { + if !ok { + continue + } + summary.Results = append(summary.Results, results[index]) + if writes[index] { + summary.HasSuccessfulWorkspaceWrite = true + } + } + summary.HasSuccessfulVerification = hasSuccessfulVerificationResult(summary.Results) + return summary, firstErr } // executeOneToolCall 在单个 worker 中执行一次工具调用并处理结果回写与事件发射。 @@ -76,10 +102,9 @@ func (s *Service) executeOneToolCall( call providertypes.ToolCall, toolLock *sync.Mutex, checkContext func() bool, - rememberError func(error), -) { +) (tools.ToolResult, bool, error) { if checkContext() { - return + return tools.ToolResult{}, false, ctx.Err() } toolLock.Lock() @@ -100,13 +125,8 @@ func (s *Service) executeOneToolCall( }) if errors.Is(execErr, context.Canceled) { - rememberError(execErr) - return + return result, false, execErr } - if execErr == nil && checkContext() { - return - } - if execErr != nil && strings.TrimSpace(result.Content) == "" { result.Content = execErr.Error() } @@ -115,12 +135,7 @@ func (s *Service) executeOneToolCall( if execErr != nil && errors.Is(err, context.Canceled) { s.emitRunScoped(ctx, EventToolResult, state, result) } - rememberError(err) - return - } - - if execErr == nil && checkContext() { - return + return result, false, err } s.emitRunScoped(ctx, EventToolResult, state, result) @@ -132,9 +147,13 @@ func (s *Service) executeOneToolCall( state.mu.Unlock() } - if execErr != nil && checkContext() { - return + if checkContext() { + return result, hasSuccessfulWorkspaceWriteFact(result, execErr), ctx.Err() } + if execErr != nil { + return result, false, nil + } + return result, hasSuccessfulWorkspaceWriteFact(result, execErr), nil } // resolveToolParallelism 计算本轮工具执行的并发上限,避免无界 goroutine 扩散。 @@ -148,40 +167,6 @@ func resolveToolParallelism(toolCallCount int) int { return defaultToolParallelism } -// reorderToolCallsByNameRoundRobin 按工具名分组后轮询展开,降低同名批量调用导致的队头阻塞。 -func reorderToolCallsByNameRoundRobin(calls []providertypes.ToolCall) []providertypes.ToolCall { - if len(calls) <= 1 { - return append([]providertypes.ToolCall(nil), calls...) - } - grouped := make(map[string][]providertypes.ToolCall, len(calls)) - order := make([]string, 0, len(calls)) - for _, call := range calls { - key := normalizeToolLockKey(call.Name) - if _, ok := grouped[key]; !ok { - order = append(order, key) - } - grouped[key] = append(grouped[key], call) - } - - ordered := make([]providertypes.ToolCall, 0, len(calls)) - for { - progressed := false - for _, key := range order { - queue := grouped[key] - if len(queue) == 0 { - continue - } - ordered = append(ordered, queue[0]) - grouped[key] = queue[1:] - progressed = true - } - if !progressed { - break - } - } - return ordered -} - // buildToolExecutionLocks 按工具名构造互斥锁,确保同名工具调用在单轮内串行执行。 func buildToolExecutionLocks(calls []providertypes.ToolCall) map[string]*sync.Mutex { locks := make(map[string]*sync.Mutex, len(calls)) @@ -253,3 +238,11 @@ func (s *Service) emitTodoToolEvent( s.emitRunScoped(ctx, EventTodoConflict, state, TodoEventPayload{Action: action, Reason: reason}) } } + +// hasSuccessfulWorkspaceWriteFact 判断工具结果是否产出了成功写入事实。 +func hasSuccessfulWorkspaceWriteFact(result tools.ToolResult, execErr error) bool { + if execErr != nil || result.IsError { + return false + } + return result.Facts.WorkspaceWrite +} diff --git a/internal/runtime/turn_control.go b/internal/runtime/turn_control.go new file mode 100644 index 00000000..768f71ff --- /dev/null +++ b/internal/runtime/turn_control.go @@ -0,0 +1,252 @@ +package runtime + +import ( + "crypto/sha256" + "encoding/hex" + "encoding/json" + "strings" + + providertypes "neo-code/internal/provider/types" + "neo-code/internal/runtime/controlplane" + agentsession "neo-code/internal/session" + "neo-code/internal/tools" +) + +type toolExecutionSummary struct { + Calls []providertypes.ToolCall + Results []tools.ToolResult + HasSuccessfulWorkspaceWrite bool + HasSuccessfulVerification bool +} + +// collectCompletionState 基于当前运行态与本轮 assistant 行为生成 completion 输入。 +func collectCompletionState( + state *runState, + _ providertypes.Message, + _ bool, +) controlplane.CompletionState { + current := state.completion + current.HasPendingAgentTodos = hasPendingAgentTodos(state.session.Todos) + return current +} + +// applyToolExecutionCompletion 更新一轮工具执行后的 completion 事实。 +func applyToolExecutionCompletion(current controlplane.CompletionState, summary toolExecutionSummary) controlplane.CompletionState { + if len(summary.Results) == 0 { + if summary.HasSuccessfulWorkspaceWrite { + current.HasUnverifiedWrites = true + } + if summary.HasSuccessfulVerification { + current.HasUnverifiedWrites = false + } + return current + } + for _, result := range summary.Results { + if result.IsError { + continue + } + if result.Facts.WorkspaceWrite { + current.HasUnverifiedWrites = true + } + if result.Facts.VerificationPerformed && result.Facts.VerificationPassed { + current.HasUnverifiedWrites = false + } + } + return current +} + +// collectProgressInput 基于执行前后事实组装 progress 评估输入。 +func collectProgressInput( + runState controlplane.RunState, + beforeTask agentsession.TaskState, + afterTask agentsession.TaskState, + beforeTodos []agentsession.TodoItem, + afterTodos []agentsession.TodoItem, + summary toolExecutionSummary, + noProgressLimit int, + repeatLimit int, +) controlplane.ProgressInput { + evidence := deriveProgressEvidence(beforeTask, afterTask, beforeTodos, afterTodos, summary) + return controlplane.ProgressInput{ + RunState: runState, + Evidence: evidence, + CurrentToolSignature: computeToolSignature(summary.Calls), + ResultFingerprint: computeToolResultFingerprint(summary.Results), + SubgoalFingerprint: computeSubgoalFingerprint(afterTask, afterTodos, summary.Calls), + NoProgressLimit: noProgressLimit, + RepeatCycleLimit: repeatLimit, + } +} + +// deriveProgressEvidence 从本轮前后快照和工具摘要中提取结构化 evidence。 +func deriveProgressEvidence( + beforeTask agentsession.TaskState, + afterTask agentsession.TaskState, + beforeTodos []agentsession.TodoItem, + afterTodos []agentsession.TodoItem, + summary toolExecutionSummary, +) []controlplane.ProgressEvidenceRecord { + var evidence []controlplane.ProgressEvidenceRecord + + if computeTaskStateSignature(beforeTask) != computeTaskStateSignature(afterTask) { + evidence = append(evidence, controlplane.ProgressEvidenceRecord{Kind: controlplane.EvidenceTaskStateChanged}) + } + if computeTodoStateSignature(beforeTodos) != computeTodoStateSignature(afterTodos) { + evidence = append(evidence, controlplane.ProgressEvidenceRecord{Kind: controlplane.EvidenceTodoStateChanged}) + } + if summary.HasSuccessfulWorkspaceWrite { + evidence = append(evidence, controlplane.ProgressEvidenceRecord{Kind: controlplane.EvidenceWriteApplied}) + } + if summary.HasSuccessfulVerification { + evidence = append(evidence, controlplane.ProgressEvidenceRecord{Kind: controlplane.EvidenceVerifyPassed}) + } + if hasSuccessfulInformationalResult(summary.Results) { + evidence = append(evidence, controlplane.ProgressEvidenceRecord{Kind: controlplane.EvidenceNewInfoNonDup}) + } + return evidence +} + +// computeTaskStateSignature 计算 task_state 的结构化签名。 +func computeTaskStateSignature(task agentsession.TaskState) string { + encoded, err := json.Marshal(task.Clone()) + if err != nil { + return "" + } + hash := sha256.Sum256(encoded) + return hex.EncodeToString(hash[:]) +} + +// computeToolResultFingerprint 计算本轮工具结果的聚合指纹。 +func computeToolResultFingerprint(results []tools.ToolResult) string { + if len(results) == 0 { + return "" + } + type normalizedResult struct { + Name string `json:"name"` + IsError bool `json:"is_error"` + Content string `json:"content"` + ErrorClass string `json:"error_class,omitempty"` + } + + normalized := make([]normalizedResult, 0, len(results)) + for _, result := range results { + if strings.TrimSpace(result.Name) == "" { + return "" + } + entry := normalizedResult{ + Name: strings.TrimSpace(result.Name), + IsError: result.IsError, + Content: normalizeToolResultContent(result.Content), + } + if result.IsError { + entry.ErrorClass = classifyToolError(result) + } + normalized = append(normalized, entry) + } + + encoded, err := json.Marshal(normalized) + if err != nil { + return "" + } + hash := sha256.Sum256(encoded) + return hex.EncodeToString(hash[:]) +} + +// computeSubgoalFingerprint 生成当前轮子目标的轻量指纹。 +func computeSubgoalFingerprint( + task agentsession.TaskState, + todos []agentsession.TodoItem, + calls []providertypes.ToolCall, +) string { + type subgoalSnapshot struct { + NextStep string `json:"next_step,omitempty"` + OpenItems []string `json:"open_items,omitempty"` + Todos []string `json:"todos,omitempty"` + } + + snapshot := subgoalSnapshot{ + NextStep: strings.TrimSpace(task.NextStep), + OpenItems: append([]string(nil), task.OpenItems...), + } + for _, item := range todos { + if item.Status.IsTerminal() { + continue + } + snapshot.Todos = append(snapshot.Todos, strings.TrimSpace(item.Content)) + } + if snapshot.NextStep == "" && len(snapshot.OpenItems) == 0 && len(snapshot.Todos) == 0 { + return computeToolSignature(calls) + } + + encoded, err := json.Marshal(snapshot) + if err != nil { + return "" + } + hash := sha256.Sum256(encoded) + return hex.EncodeToString(hash[:]) +} + +// hasPendingAgentTodos 判断当前 session 中是否仍存在未闭合 todo。 +func hasPendingAgentTodos(items []agentsession.TodoItem) bool { + for _, item := range items { + if item.Status.IsTerminal() { + continue + } + return true + } + return false +} + +// hasSuccessfulInformationalResult 判断本轮是否至少获得一个成功的非写入工具结果。 +func hasSuccessfulInformationalResult(results []tools.ToolResult) bool { + for _, result := range results { + if result.IsError { + continue + } + switch strings.TrimSpace(result.Name) { + case tools.ToolNameFilesystemWriteFile, tools.ToolNameFilesystemEdit: + continue + default: + return true + } + } + return false +} + +// hasSuccessfulVerificationResult 判断本轮是否存在显式验证成功的结构化事实。 +func hasSuccessfulVerificationResult(results []tools.ToolResult) bool { + if len(results) == 0 { + return false + } + for _, result := range results { + if result.IsError || !result.Facts.VerificationPerformed || !result.Facts.VerificationPassed { + continue + } + return true + } + return false +} + +// normalizeToolResultContent 对工具结果文本做稳定化裁剪,避免无关差异放大指纹抖动。 +func normalizeToolResultContent(content string) string { + trimmed := strings.TrimSpace(content) + if len(trimmed) <= 256 { + return trimmed + } + return trimmed[:256] +} + +// classifyToolError 为错误结果生成轻量分类,避免直接依赖完整错误文案。 +func classifyToolError(result tools.ToolResult) string { + trimmed := strings.ToLower(strings.TrimSpace(result.Content)) + switch { + case strings.Contains(trimmed, "timeout"): + return "timeout" + case strings.Contains(trimmed, "denied"): + return "permission_denied" + case strings.Contains(trimmed, "not found"): + return "not_found" + default: + return "generic_error" + } +} diff --git a/internal/runtime/turn_control_test.go b/internal/runtime/turn_control_test.go new file mode 100644 index 00000000..93a1d7cf --- /dev/null +++ b/internal/runtime/turn_control_test.go @@ -0,0 +1,138 @@ +package runtime + +import ( + "context" + "testing" + + providertypes "neo-code/internal/provider/types" + "neo-code/internal/runtime/controlplane" + agentsession "neo-code/internal/session" + "neo-code/internal/tools" +) + +func TestCollectCompletionStateKeepsUnverifiedWrites(t *testing.T) { + t.Parallel() + + state := newRunState("run-verify-silent", newRuntimeSession("session-verify-silent")) + state.completion = controlplane.CompletionState{ + HasUnverifiedWrites: true, + } + + got := collectCompletionState(&state, providertypes.Message{Role: providertypes.RoleAssistant}, false) + if got.HasUnverifiedWrites != true { + t.Fatalf("expected unverified writes to remain blocked, got %+v", got) + } +} + +func TestApplyToolExecutionCompletionTracksWriteAndVerification(t *testing.T) { + t.Parallel() + + written := applyToolExecutionCompletion(controlplane.CompletionState{}, toolExecutionSummary{ + Results: []tools.ToolResult{ + {Facts: tools.ToolExecutionFacts{WorkspaceWrite: true}}, + }, + }) + if !written.HasUnverifiedWrites { + t.Fatalf("expected successful write to require verification, got %+v", written) + } + + verified := applyToolExecutionCompletion(written, toolExecutionSummary{ + Results: []tools.ToolResult{ + {Facts: tools.ToolExecutionFacts{VerificationPerformed: true, VerificationPassed: true}}, + }, + }) + if verified.HasUnverifiedWrites { + t.Fatalf("expected explicit verification to clear pending write, got %+v", verified) + } +} + +func TestApplyToolExecutionCompletionKeepsUnverifiedWhenVerifyBeforeWrite(t *testing.T) { + t.Parallel() + + got := applyToolExecutionCompletion(controlplane.CompletionState{}, toolExecutionSummary{ + Results: []tools.ToolResult{ + {Facts: tools.ToolExecutionFacts{VerificationPerformed: true, VerificationPassed: true}}, + {Facts: tools.ToolExecutionFacts{WorkspaceWrite: true}}, + }, + }) + if !got.HasUnverifiedWrites { + t.Fatalf("expected write after verify to remain unverified, got %+v", got) + } +} + +func TestApplyToolExecutionCompletionClearsWhenVerifyAfterWrite(t *testing.T) { + t.Parallel() + + got := applyToolExecutionCompletion(controlplane.CompletionState{}, toolExecutionSummary{ + Results: []tools.ToolResult{ + {Facts: tools.ToolExecutionFacts{WorkspaceWrite: true}}, + {Facts: tools.ToolExecutionFacts{VerificationPerformed: true, VerificationPassed: true}}, + }, + }) + if got.HasUnverifiedWrites { + t.Fatalf("expected verify after write to clear unverified flag, got %+v", got) + } +} + +func TestHasPendingAgentTodosBlocksOnAnyNonTerminalTodo(t *testing.T) { + t.Parallel() + + todos := []agentsession.TodoItem{ + { + ID: "subagent-1", + Content: "delegate", + Status: agentsession.TodoStatusPending, + Executor: agentsession.TodoExecutorSubAgent, + }, + } + if !hasPendingAgentTodos(todos) { + t.Fatalf("expected pending subagent todo to block completion") + } + + completed := []agentsession.TodoItem{ + { + ID: "subagent-2", + Content: "done", + Status: agentsession.TodoStatusCompleted, + Executor: agentsession.TodoExecutorSubAgent, + }, + } + if hasPendingAgentTodos(completed) { + t.Fatalf("expected terminal todo to not block completion") + } +} + +func TestTransitionRunPhaseInvalidTransitionReturnsError(t *testing.T) { + t.Parallel() + + service := &Service{events: make(chan RuntimeEvent, 4)} + state := newRunState("run-invalid-phase", newRuntimeSession("session-invalid-phase")) + state.lifecycle = controlplane.RunStatePlan + + err := service.transitionRunState(context.Background(), &state, controlplane.RunStateVerify) + if err == nil { + t.Fatalf("expected invalid transition to return error") + } + if state.lifecycle != controlplane.RunStatePlan { + t.Fatalf("expected lifecycle to remain unchanged, got %q", state.lifecycle) + } + if events := collectRuntimeEvents(service.Events()); len(events) != 0 { + t.Fatalf("expected no phase events on invalid transition, got %+v", events) + } +} + +func TestHasSuccessfulVerificationResultRequiresStructuredFacts(t *testing.T) { + t.Parallel() + + if !hasSuccessfulVerificationResult([]tools.ToolResult{ + {Facts: tools.ToolExecutionFacts{VerificationPerformed: true, VerificationPassed: true}}, + }) { + t.Fatalf("expected verification facts to count as verify passed") + } + if hasSuccessfulVerificationResult([]tools.ToolResult{ + {Facts: tools.ToolExecutionFacts{VerificationPerformed: true, VerificationPassed: false}}, + {Facts: tools.ToolExecutionFacts{VerificationPerformed: false, VerificationPassed: true}}, + }) { + t.Fatalf("expected incomplete verification facts to be ignored") + } +} diff --git a/internal/security/workspace.go b/internal/security/workspace.go index 4746ccb9..3e4da06a 100644 --- a/internal/security/workspace.go +++ b/internal/security/workspace.go @@ -12,6 +12,8 @@ import ( "sync" ) +var evalSymlinks = filepath.EvalSymlinks + // WorkspaceSandbox enforces workspace-relative path boundaries for tool actions. type WorkspaceSandbox struct { canonicalRoots sync.Map @@ -256,9 +258,19 @@ func resolveCanonicalWorkspaceRoot(absoluteRoot string) (string, bool, error) { return "", false, fmt.Errorf("security: workspace root %q is not a directory", absoluteRoot) } - canonicalRoot, err := filepath.EvalSymlinks(absoluteRoot) + canonicalRoot, err := evalSymlinks(absoluteRoot) if err != nil { - return "", false, fmt.Errorf("security: resolve workspace root: %w", err) + if !errors.Is(err, os.ErrPermission) { + return "", false, fmt.Errorf("security: resolve workspace root: %w", err) + } + allowed, inspectErr := canFallbackToCandidateOnPermission(absoluteRoot, absoluteRoot) + if inspectErr != nil { + return "", false, inspectErr + } + if !allowed { + return "", false, fmt.Errorf("security: resolve workspace root %q: %w", absoluteRoot, err) + } + canonicalRoot = absoluteRoot } cleanedCanonical := cleanedPathKey(canonicalRoot) @@ -317,9 +329,22 @@ func ensureNoSymlinkEscape(root string, target string, original string) (string, } func ensureResolvedPathWithinWorkspace(root string, candidate string, original string) error { - resolved, err := filepath.EvalSymlinks(candidate) + if samePathKey(root, candidate) { + return nil + } + resolved, err := evalSymlinks(candidate) if err != nil { - return fmt.Errorf("security: resolve symlink %q: %w", candidate, err) + if !errors.Is(err, os.ErrPermission) { + return fmt.Errorf("security: resolve symlink %q: %w", candidate, err) + } + fallbackAllowed, inspectErr := canFallbackToCandidateOnPermission(root, candidate) + if inspectErr != nil { + return inspectErr + } + if !fallbackAllowed { + return fmt.Errorf("security: resolve symlink %q: %w", candidate, err) + } + resolved = candidate } resolved, err = filepath.Abs(resolved) if err != nil { @@ -331,6 +356,38 @@ func ensureResolvedPathWithinWorkspace(root string, candidate string, original s return nil } +// canFallbackToCandidateOnPermission 在 EvalSymlinks 遇到权限错误时,逐段确认 root 到 candidate 的现存路径不含符号链接。 +func canFallbackToCandidateOnPermission(root string, candidate string) (bool, error) { + rootInfo, err := os.Lstat(filepath.Clean(root)) + if err != nil { + return false, fmt.Errorf("security: inspect path %q: %w", root, err) + } + if rootInfo.Mode()&os.ModeSymlink != 0 { + return false, nil + } + + relativePath, err := filepath.Rel(root, candidate) + if err != nil { + return false, fmt.Errorf("security: compare workspace target %q: %w", candidate, err) + } + if relativePath == "." { + return true, nil + } + + current := cleanedPathKey(root) + for _, segment := range splitRelativePath(relativePath) { + current = cleanedPathKey(filepath.Join(current, segment)) + info, statErr := os.Lstat(current) + if statErr != nil { + return false, fmt.Errorf("security: inspect path %q: %w", current, statErr) + } + if info.Mode()&os.ModeSymlink != 0 { + return false, nil + } + } + return true, nil +} + func capturePathSnapshot(path string) (pathSnapshot, error) { info, err := os.Lstat(path) if err != nil { diff --git a/internal/security/workspace_test.go b/internal/security/workspace_test.go index 032e33d6..c5095b54 100644 --- a/internal/security/workspace_test.go +++ b/internal/security/workspace_test.go @@ -530,6 +530,54 @@ func TestCanonicalWorkspaceRoot(t *testing.T) { } } +func TestCanonicalWorkspaceRootPermissionErrorFallsBackToAbsoluteRoot(t *testing.T) { + originalEvalSymlinks := evalSymlinks + evalSymlinks = func(path string) (string, error) { + return "", os.ErrPermission + } + defer func() { + evalSymlinks = originalEvalSymlinks + }() + + root := t.TempDir() + got, err := NewWorkspaceSandbox().canonicalWorkspaceRoot(root) + if err != nil { + t.Fatalf("expected permission fallback for workspace root, got %v", err) + } + want, err := filepath.Abs(root) + if err != nil { + t.Fatalf("filepath.Abs(root): %v", err) + } + if !samePathKey(got, want) { + t.Fatalf("canonicalWorkspaceRoot() = %q, want %q", got, want) + } +} + +func TestCanonicalWorkspaceRootPermissionErrorRejectsSymlinkRoot(t *testing.T) { + base := t.TempDir() + realRoot := filepath.Join(base, "real") + if err := os.MkdirAll(realRoot, 0o755); err != nil { + t.Fatalf("mkdir real root: %v", err) + } + linkRoot := filepath.Join(base, "root-link") + if err := os.Symlink(realRoot, linkRoot); err != nil { + t.Skipf("symlink not supported in this environment: %v", err) + } + + originalEvalSymlinks := evalSymlinks + evalSymlinks = func(path string) (string, error) { + return "", os.ErrPermission + } + defer func() { + evalSymlinks = originalEvalSymlinks + }() + + _, err := NewWorkspaceSandbox().canonicalWorkspaceRoot(linkRoot) + if err == nil || !strings.Contains(err.Error(), "permission denied") { + t.Fatalf("expected symlink root to reject permission fallback, got %v", err) + } +} + func TestAbsoluteWorkspaceTarget(t *testing.T) { t.Parallel() @@ -570,7 +618,7 @@ func TestAbsoluteWorkspaceTarget(t *testing.T) { if err != nil { t.Fatalf("filepath.Abs(%q): %v", tt.want, err) } - if got != filepath.Clean(wantAbs) { + if !samePathKey(got, wantAbs) { t.Fatalf("absoluteWorkspaceTarget() = %q, want %q", got, filepath.Clean(wantAbs)) } }) @@ -702,6 +750,72 @@ func TestEnsureNoSymlinkEscape(t *testing.T) { } } +func TestEnsureResolvedPathWithinWorkspacePermissionErrorFallsBackForPlainPath(t *testing.T) { + root := t.TempDir() + candidate := filepath.Join(root, "notes.txt") + mustWriteWorkspaceFile(t, candidate, "hello") + + originalEvalSymlinks := evalSymlinks + evalSymlinks = func(path string) (string, error) { + return "", os.ErrPermission + } + defer func() { + evalSymlinks = originalEvalSymlinks + }() + + err := ensureResolvedPathWithinWorkspace(root, candidate, "notes.txt") + if err != nil { + t.Fatalf("expected plain path permission fallback, got %v", err) + } +} + +func TestEnsureResolvedPathWithinWorkspacePermissionErrorRejectsSymlinkedPath(t *testing.T) { + root := t.TempDir() + outside := t.TempDir() + target := filepath.Join(outside, "secret.txt") + mustWriteWorkspaceFile(t, target, "secret") + + link := filepath.Join(root, "linked.txt") + mustSymlinkOrSkip(t, target, link) + + originalEvalSymlinks := evalSymlinks + evalSymlinks = func(path string) (string, error) { + return "", os.ErrPermission + } + defer func() { + evalSymlinks = originalEvalSymlinks + }() + + err := ensureResolvedPathWithinWorkspace(root, link, "linked.txt") + if err == nil || !strings.Contains(err.Error(), "resolve symlink") { + t.Fatalf("expected symlink permission error, got %v", err) + } +} + +func TestCanFallbackToCandidateOnPermissionRejectsSymlinkRoot(t *testing.T) { + base := t.TempDir() + realRoot := filepath.Join(base, "real") + if err := os.MkdirAll(realRoot, 0o755); err != nil { + t.Fatalf("mkdir real root: %v", err) + } + + symlinkRoot := filepath.Join(base, "root-link") + if err := os.Symlink(realRoot, symlinkRoot); err != nil { + t.Skipf("symlink not supported in this environment: %v", err) + } + + candidate := filepath.Join(symlinkRoot, "notes.txt") + mustWriteWorkspaceFile(t, filepath.Join(realRoot, "notes.txt"), "hello") + + allowed, err := canFallbackToCandidateOnPermission(symlinkRoot, candidate) + if err != nil { + t.Fatalf("canFallbackToCandidateOnPermission() error: %v", err) + } + if allowed { + t.Fatalf("expected symlink workspace root to reject permission fallback") + } +} + func TestWorkspaceExecutionPlanValidateForExecution(t *testing.T) { t.Parallel() diff --git a/internal/session/storage_helpers.go b/internal/session/storage_helpers.go index 637ff6d3..545f7e4d 100644 --- a/internal/session/storage_helpers.go +++ b/internal/session/storage_helpers.go @@ -40,15 +40,21 @@ func resolvePathForContainment(path string) (string, error) { if err == nil { return resolved, nil } + if errors.Is(err, os.ErrPermission) { + return "", fmt.Errorf("eval symlinks: %w", err) + } if !errors.Is(err, os.ErrNotExist) { return "", fmt.Errorf("eval symlinks: %w", err) } parent := filepath.Dir(absPath) resolvedParent, parentErr := filepath.EvalSymlinks(parent) - if parentErr != nil { + if parentErr == nil { + return filepath.Join(resolvedParent, filepath.Base(absPath)), nil + } + if errors.Is(parentErr, os.ErrPermission) { return "", fmt.Errorf("eval parent symlinks: %w", parentErr) } - return filepath.Join(resolvedParent, filepath.Base(absPath)), nil + return "", fmt.Errorf("eval parent symlinks: %w", parentErr) } // createTempFile 在目标目录中创建唯一临时文件。 diff --git a/internal/tools/bash/tool.go b/internal/tools/bash/tool.go index e02bce21..92cf5c0c 100644 --- a/internal/tools/bash/tool.go +++ b/internal/tools/bash/tool.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "errors" + "strings" "time" "neo-code/internal/tools" @@ -17,8 +18,10 @@ type Tool struct { } type input struct { - Command string `json:"command"` - Workdir string `json:"workdir,omitempty"` + Command string `json:"command"` + Workdir string `json:"workdir,omitempty"` + Verification bool `json:"verification,omitempty"` + VerificationScope string `json:"verification_scope,omitempty"` } func New(root string, shell string, timeout time.Duration) *Tool { @@ -64,6 +67,14 @@ func (t *Tool) Schema() map[string]any { "type": "string", "description": "Optional working directory relative to the workspace root.", }, + "verification": map[string]any{ + "type": "boolean", + "description": "Set true when this command is explicitly used for verification.", + }, + "verification_scope": map[string]any{ + "type": "string", + "description": "Optional verification scope. Defaults to workspace when verification=true.", + }, }, "required": []string{"command"}, } @@ -84,5 +95,41 @@ func (t *Tool) Execute(ctx context.Context, call tools.ToolCallInput) (tools.Too return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err } - return t.executor.Execute(ctx, call, in.Command, in.Workdir) + result, err := t.executor.Execute(ctx, call, in.Command, in.Workdir) + result.Metadata = withVerificationMetadata(result.Metadata, in, err == nil && !result.IsError) + result.Facts = withVerificationFacts(result.Facts, in, err == nil && !result.IsError) + return result, err +} + +// withVerificationMetadata 在 bash 调用显式声明验证意图时写入结构化验证元数据。 +func withVerificationMetadata(metadata map[string]any, in input, succeeded bool) map[string]any { + scope := in.VerificationScope + if !in.Verification && scope == "" { + return metadata + } + if metadata == nil { + metadata = make(map[string]any, 3) + } + metadata["verification_performed"] = true + metadata["verification_passed"] = succeeded + if scope == "" { + scope = "workspace" + } + metadata["verification_scope"] = scope + return metadata +} + +// withVerificationFacts 在 bash 调用显式声明验证意图时写入受信的结构化事实。 +func withVerificationFacts(facts tools.ToolExecutionFacts, in input, succeeded bool) tools.ToolExecutionFacts { + scope := strings.TrimSpace(in.VerificationScope) + if !in.Verification && scope == "" { + return facts + } + facts.VerificationPerformed = true + facts.VerificationPassed = succeeded + if scope == "" { + scope = "workspace" + } + facts.VerificationScope = scope + return facts } diff --git a/internal/tools/bash/tool_test.go b/internal/tools/bash/tool_test.go index e2202ca4..8c70dd12 100644 --- a/internal/tools/bash/tool_test.go +++ b/internal/tools/bash/tool_test.go @@ -171,6 +171,40 @@ func TestToolExecuteErrorFormattingAndTruncation(t *testing.T) { } } +func TestToolExecuteEmitsVerificationMetadataWhenExplicitlyRequested(t *testing.T) { + workspace := t.TempDir() + tool := New(workspace, defaultShell(), 3*time.Second) + + args := mustMarshalArgs(t, map[string]any{ + "command": safeEchoCommand(), + "verification": true, + "verification_scope": "workspace", + }) + result, err := tool.Execute(context.Background(), tools.ToolCallInput{ + Name: tool.Name(), + Arguments: args, + Workdir: workspace, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if performed, _ := result.Metadata["verification_performed"].(bool); !performed { + t.Fatalf("expected verification_performed=true, got %#v", result.Metadata["verification_performed"]) + } + if passed, _ := result.Metadata["verification_passed"].(bool); !passed { + t.Fatalf("expected verification_passed=true, got %#v", result.Metadata["verification_passed"]) + } + if scope, _ := result.Metadata["verification_scope"].(string); scope != "workspace" { + t.Fatalf("expected verification_scope=workspace, got %#v", result.Metadata["verification_scope"]) + } + if !result.Facts.VerificationPerformed || !result.Facts.VerificationPassed { + t.Fatalf("expected verification facts to be populated, got %+v", result.Facts) + } + if result.Facts.VerificationScope != "workspace" { + t.Fatalf("expected verification fact scope workspace, got %q", result.Facts.VerificationScope) + } +} + func mustMarshalArgs(t *testing.T, value any) []byte { t.Helper() diff --git a/internal/tools/facts.go b/internal/tools/facts.go new file mode 100644 index 00000000..060ef564 --- /dev/null +++ b/internal/tools/facts.go @@ -0,0 +1,39 @@ +package tools + +import ( + "strings" + + "neo-code/internal/security" +) + +// EnrichToolResultFacts 基于权限动作与工具本地事实补齐结构化执行事实。 +// 注意:此处不信任外部工具 metadata 中的 workspace/verification 字段,避免越过信任边界。 +func EnrichToolResultFacts(action security.Action, result ToolResult) ToolResult { + facts := result.Facts + if !facts.WorkspaceWrite { + facts.WorkspaceWrite = defaultWorkspaceWriteFromAction(action) + } + if facts.VerificationPassed { + facts.VerificationPerformed = true + } + facts.VerificationScope = strings.TrimSpace(facts.VerificationScope) + if !facts.VerificationPerformed { + facts.VerificationPassed = false + facts.VerificationScope = "" + } + + result.Facts = facts + return result +} + +// defaultWorkspaceWriteFromAction 按权限动作类型推导默认写入事实,仅明确写能力才标记为写入。 +func defaultWorkspaceWriteFromAction(action security.Action) bool { + switch action.Type { + case security.ActionTypeRead: + return false + case security.ActionTypeWrite: + return true + default: + return false + } +} diff --git a/internal/tools/facts_test.go b/internal/tools/facts_test.go new file mode 100644 index 00000000..057df4a9 --- /dev/null +++ b/internal/tools/facts_test.go @@ -0,0 +1,76 @@ +package tools + +import ( + "testing" + + "neo-code/internal/security" +) + +func TestEnrichToolResultFactsDefaultsFromAction(t *testing.T) { + t.Parallel() + + read := EnrichToolResultFacts(security.Action{Type: security.ActionTypeRead}, ToolResult{}) + if read.Facts.WorkspaceWrite { + t.Fatalf("expected read action to default workspace_write=false") + } + + bash := EnrichToolResultFacts(security.Action{Type: security.ActionTypeBash}, ToolResult{}) + if bash.Facts.WorkspaceWrite { + t.Fatalf("expected bash action to default workspace_write=false") + } + + mcp := EnrichToolResultFacts(security.Action{Type: security.ActionTypeMCP}, ToolResult{}) + if mcp.Facts.WorkspaceWrite { + t.Fatalf("expected mcp action to default workspace_write=false") + } +} + +func TestEnrichToolResultFactsIgnoresUntrustedMetadata(t *testing.T) { + t.Parallel() + + result := EnrichToolResultFacts( + security.Action{Type: security.ActionTypeMCP}, + ToolResult{ + Metadata: map[string]any{ + "workspace_write": false, + "verification_performed": true, + "verification_passed": true, + "verification_scope": "workspace", + }, + }, + ) + if result.Facts.WorkspaceWrite { + t.Fatalf("expected metadata workspace_write to be ignored") + } + if result.Facts.VerificationPerformed || result.Facts.VerificationPassed { + t.Fatalf("expected metadata verification facts to be ignored, got %+v", result.Facts) + } + if result.Facts.VerificationScope != "" { + t.Fatalf("expected empty verification scope, got %q", result.Facts.VerificationScope) + } +} + +func TestEnrichToolResultFactsRespectsTrustedFacts(t *testing.T) { + t.Parallel() + + result := EnrichToolResultFacts( + security.Action{Type: security.ActionTypeBash}, + ToolResult{ + Facts: ToolExecutionFacts{ + WorkspaceWrite: true, + VerificationPerformed: true, + VerificationPassed: true, + VerificationScope: " workspace ", + }, + }, + ) + if !result.Facts.WorkspaceWrite { + t.Fatalf("expected trusted workspace write fact to be preserved") + } + if !result.Facts.VerificationPerformed || !result.Facts.VerificationPassed { + t.Fatalf("expected trusted verification facts to be preserved, got %+v", result.Facts) + } + if result.Facts.VerificationScope != "workspace" { + t.Fatalf("verification scope = %q, want workspace", result.Facts.VerificationScope) + } +} diff --git a/internal/tools/manager.go b/internal/tools/manager.go index d27b16a3..4fc7e5dd 100644 --- a/internal/tools/manager.go +++ b/internal/tools/manager.go @@ -5,6 +5,8 @@ import ( "errors" "fmt" "log" + "path/filepath" + "runtime" "strings" "sync" "time" @@ -45,6 +47,55 @@ type microCompactSummarizerExecutor interface { MicroCompactSummarizer(name string) ContentSummarizer } +// factsEnrichingExecutor 包装底层执行器,在不信任外部 metadata 的前提下补齐受信结构化事实。 +type factsEnrichingExecutor struct { + inner Executor +} + +// newFactsEnrichingExecutor 创建带结构化事实补齐能力的执行器包装层。 +func newFactsEnrichingExecutor(inner Executor) Executor { + if inner == nil { + return nil + } + return &factsEnrichingExecutor{inner: inner} +} + +// ListAvailableSpecs 透传工具规格查询能力,不改变可见工具集。 +func (e *factsEnrichingExecutor) ListAvailableSpecs(ctx context.Context, input SpecListInput) ([]providertypes.ToolSpec, error) { + return e.inner.ListAvailableSpecs(ctx, input) +} + +// Supports 透传工具支持性判断,保证原有执行路由不受包装层影响。 +func (e *factsEnrichingExecutor) Supports(name string) bool { + return e.inner.Supports(name) +} + +// MicroCompactPolicy 透传被包装执行器的压缩策略,确保 UI/Runtime 行为与原实现一致。 +func (e *factsEnrichingExecutor) MicroCompactPolicy(name string) MicroCompactPolicy { + if source, ok := e.inner.(microCompactPolicyExecutor); ok { + return source.MicroCompactPolicy(name) + } + return MicroCompactPolicyCompact +} + +// MicroCompactSummarizer 透传被包装执行器的摘要器实现,避免包装层吞掉摘要能力。 +func (e *factsEnrichingExecutor) MicroCompactSummarizer(name string) ContentSummarizer { + if source, ok := e.inner.(microCompactSummarizerExecutor); ok { + return source.MicroCompactSummarizer(name) + } + return nil +} + +// Execute 在执行后按本地权限动作补齐可信 facts,避免运行时依赖远端 metadata。 +func (e *factsEnrichingExecutor) Execute(ctx context.Context, input ToolCallInput) (ToolResult, error) { + result, err := e.inner.Execute(ctx, input) + action, actionErr := buildPermissionAction(input) + if actionErr == nil { + result = EnrichToolResultFacts(action, result) + } + return result, err +} + // WorkspaceSandbox enforces workspace-oriented constraints before execution. type WorkspaceSandbox interface { Check(ctx context.Context, action security.Action) (*security.WorkspaceExecutionPlan, error) @@ -68,6 +119,13 @@ var ( ErrCapabilityDenied = errors.New("tools: capability denied") ) +const ( + // sandboxExternalWriteApprovalRuleID 是工作区外低风险写入的审批规则标识。 + sandboxExternalWriteApprovalRuleID = "workspace-sandbox:external-write-ask" + // sandboxExternalWriteApprovalReason 是工作区外低风险写入需要审批时的统一提示。 + sandboxExternalWriteApprovalReason = "workspace write outside workdir requires approval" +) + // PermissionDecisionError reports a non-allow permission decision. type PermissionDecisionError struct { decision security.Decision @@ -190,7 +248,7 @@ func NewManager(executor Executor, engine security.PermissionEngine, sandbox Wor } return &DefaultManager{ - executor: executor, + executor: newFactsEnrichingExecutor(executor), engine: engine, sandbox: sandbox, sessionDecisions: newSessionPermissionMemory(), @@ -322,19 +380,297 @@ func (m *DefaultManager) Execute(ctx context.Context, input ToolCallInput) (Tool return result, permissionErrorFromDecision(decision) } - plan, err := m.sandbox.Check(ctx, action) - if err != nil { - result := NewErrorResult(input.Name, "workspace sandbox rejected action", err.Error(), actionMetadata(action)) - result.ToolCallID = input.ID - return result, err - } + plan, err := m.sandbox.Check(ctx, action) + if err != nil { + if decision, decisionMatched := resolveSandboxOutsideWriteDecision(input, action, err, m.sessionDecisions); decisionMatched { + if decision.Decision != security.DecisionAllow { + result := blockedToolResult(input, decision) + return result, permissionErrorFromDecision(decision) + } + m.auditCapabilityDecision(action, string(security.DecisionAllow), decision.Reason) + return m.executor.Execute(ctx, input) + } else { + result := NewErrorResult(input.Name, "workspace sandbox rejected action", sandboxErrorDetails(action, err), actionMetadata(action)) + result.ToolCallID = input.ID + return result, err + } + } else if plan != nil { + input.WorkspacePlan = plan + } m.auditCapabilityDecision(action, string(security.DecisionAllow), "") - if plan != nil { - input.WorkspacePlan = plan + return m.executor.Execute(ctx, input) +} + +// resolveSandboxOutsideWriteDecision 将“工作区外低风险写入”沙箱拒绝收敛为 ask/remembered allow/remembered deny。 +func resolveSandboxOutsideWriteDecision( + input ToolCallInput, + action security.Action, + sandboxErr error, + sessionMemory *sessionPermissionMemory, +) (security.CheckResult, bool) { + if !isSandboxOutsideWriteApprovalCandidate(action, sandboxErr) { + return security.CheckResult{}, false } - return m.executor.Execute(ctx, input) + decision := security.CheckResult{ + Decision: security.DecisionAsk, + Action: action, + Rule: &security.Rule{ + ID: sandboxExternalWriteApprovalRuleID, + Type: action.Type, + Resource: action.Payload.Resource, + Decision: security.DecisionAsk, + Reason: sandboxExternalWriteApprovalReason, + }, + Reason: sandboxExternalWriteApprovalReason, + } + + if sessionMemory != nil { + if rememberedDecision, rememberedScope, ok := sessionMemory.resolve(input.SessionID, action); ok { + decision = security.CheckResult{ + Decision: rememberedDecision, + Action: action, + Rule: &security.Rule{ + ID: "session-memory:" + string(rememberedScope), + Type: action.Type, + Resource: action.Payload.Resource, + Decision: rememberedDecision, + Reason: sessionDecisionReason(rememberedScope), + }, + Reason: sessionDecisionReason(rememberedScope), + } + } + } + + return decision, true +} + +// isSandboxOutsideWriteApprovalCandidate 判断当前沙箱错误是否可升级为“工作区外低风险写入审批”。 +func isSandboxOutsideWriteApprovalCandidate(action security.Action, sandboxErr error) bool { + if isWorkspaceSymlinkViolationError(sandboxErr) { + return false + } + if !isWorkspaceBoundaryViolationError(sandboxErr) { + return false + } + if action.Type != security.ActionTypeWrite { + return false + } + resource := strings.TrimSpace(strings.ToLower(action.Payload.Resource)) + toolName := strings.TrimSpace(strings.ToLower(action.Payload.ToolName)) + if resource != ToolNameFilesystemWriteFile && toolName != ToolNameFilesystemWriteFile { + return false + } + + targetPath := resolveActionSandboxTargetPath(action) + if targetPath == "" { + return false + } + return isLowRiskExternalWritePath(targetPath) +} + +// isWorkspaceBoundaryViolationError 判断错误是否由工作区边界校验触发。 +func isWorkspaceBoundaryViolationError(err error) bool { + message := strings.ToLower(strings.TrimSpace(errorMessage(err))) + if message == "" { + return false + } + return strings.Contains(message, "escapes workspace root") || + strings.Contains(message, "different volume than workspace root") +} + +// isWorkspaceSymlinkViolationError 判断沙箱拒绝是否来自符号链接越界逃逸。 +func isWorkspaceSymlinkViolationError(err error) bool { + message := strings.ToLower(strings.TrimSpace(errorMessage(err))) + if message == "" { + return false + } + return strings.Contains(message, "escapes workspace root via symlink") +} + +// resolveActionSandboxTargetPath 将 action 的 sandbox target 解析为可判定风险的绝对路径。 +func resolveActionSandboxTargetPath(action security.Action) string { + target := strings.TrimSpace(action.Payload.SandboxTarget) + if target == "" { + target = strings.TrimSpace(action.Payload.Target) + } + if target == "" { + return "" + } + if !filepath.IsAbs(target) && strings.TrimSpace(action.Payload.Workdir) != "" { + target = filepath.Join(strings.TrimSpace(action.Payload.Workdir), target) + } + if absoluteTarget, err := filepath.Abs(target); err == nil { + target = absoluteTarget + } + return filepath.Clean(target) +} + +// isLowRiskExternalWritePath 判断工作区外写入目标是否属于可审批放行的低风险路径。 +func isLowRiskExternalWritePath(targetPath string) bool { + cleaned := strings.TrimSpace(filepath.Clean(targetPath)) + if cleaned == "" || cleaned == "." { + return false + } + if isSystemProtectedPath(cleaned) { + return false + } + if isUserStartupProfilePath(cleaned) { + return false + } + if isHighRiskExecutableExtension(filepath.Ext(cleaned)) { + return false + } + return true +} + +// isUserStartupProfilePath 判断路径是否命中用户级 shell/profile 启动文件,命中后必须保持硬拒绝。 +func isUserStartupProfilePath(path string) bool { + return isUserStartupProfilePathForOS(path, runtime.GOOS) +} + +// isUserStartupProfilePathForOS 按指定操作系统判定路径是否命中用户级 shell/profile 启动文件。 +func isUserStartupProfilePathForOS(path string, goos string) bool { + cleaned := strings.ToLower(strings.TrimSpace(filepath.Clean(path))) + if cleaned == "" || cleaned == "." { + return false + } + + base := filepath.Base(cleaned) + switch base { + case ".bashrc", ".bash_profile", ".bash_login", ".profile", + ".zshrc", ".zprofile", ".zlogin", ".zshenv", ".cshrc", ".tcshrc", + "profile.ps1", "microsoft.powershell_profile.ps1", + "microsoft.vscode_profile.ps1", "profile": + return true + } + + segments := splitPathSegments(cleaned) + if len(segments) == 0 { + return false + } + if strings.EqualFold(strings.TrimSpace(goos), "windows") { + for i := 0; i+2 < len(segments); i++ { + if segments[i] == "documents" && segments[i+1] == "windowspowershell" && strings.HasSuffix(base, ".ps1") { + return true + } + if segments[i] == "documents" && segments[i+1] == "powershell" && strings.HasSuffix(base, ".ps1") { + return true + } + } + return false + } + for i := 0; i+2 < len(segments); i++ { + if segments[i] == ".config" && segments[i+1] == "fish" && base == "config.fish" { + return true + } + } + return false +} + +// isSystemProtectedPath 判定路径是否命中系统受保护目录,命中后必须保持硬拒绝。 +func isSystemProtectedPath(path string) bool { + return isSystemProtectedPathForOS(path, runtime.GOOS) +} + +// isSystemProtectedPathForOS 按指定操作系统判定路径是否命中系统受保护目录。 +func isSystemProtectedPathForOS(path string, goos string) bool { + normalized := strings.ToLower(filepath.Clean(path)) + if strings.EqualFold(strings.TrimSpace(goos), "windows") { + volume := strings.ToLower(filepath.VolumeName(normalized)) + if volume == "" && len(normalized) >= 2 && normalized[1] == ':' { + volume = normalized[:2] + } + rest := strings.TrimPrefix(normalized, volume) + rest = strings.TrimLeft(rest, `\/`) + if rest == "" { + return true + } + segments := splitPathSegments(rest) + switch segments[0] { + case "windows", "program files", "program files (x86)", "programdata", + "$recycle.bin", "system volume information", "recovery", "boot": + return true + } + if len(segments) >= 3 && segments[0] == "users" && segments[2] == "appdata" { + return true + } + } else { + trimmed := strings.TrimLeft(normalized, "/") + segments := splitPathSegments(trimmed) + if len(segments) == 0 { + return true + } + switch segments[0] { + case "etc", "bin", "sbin", "usr", "var", "lib", "lib64", "boot", "proc", "sys", "dev", "run", "root": + return true + } + } + + for _, segment := range splitPathSegments(normalized) { + if segment == ".ssh" { + return true + } + } + return false +} + +// isHighRiskExecutableExtension 识别高风险可执行文件后缀,命中后不走审批放行链路。 +func isHighRiskExecutableExtension(extension string) bool { + switch strings.ToLower(strings.TrimSpace(extension)) { + case ".exe", ".dll", ".sys", ".bat", ".cmd", ".com", ".scr", ".msi", ".reg": + return true + default: + return false + } +} + +// splitPathSegments 把路径按目录分隔符拆成稳定片段,忽略空片段。 +func splitPathSegments(path string) []string { + normalized := strings.ReplaceAll(path, "\\", "/") + rawSegments := strings.Split(normalized, "/") + segments := make([]string, 0, len(rawSegments)) + for _, segment := range rawSegments { + trimmed := strings.TrimSpace(segment) + if trimmed == "" { + continue + } + segments = append(segments, trimmed) + } + return segments +} + +// sandboxErrorDetails 生成可回灌给模型的沙箱拒绝详情,便于模型正确感知失败原因。 +func sandboxErrorDetails(action security.Action, sandboxErr error) string { + securityMessage := strings.TrimSpace(errorMessage(sandboxErr)) + if securityMessage == "" { + securityMessage = "sandbox rejected action" + } + if !strings.HasPrefix(strings.ToLower(securityMessage), "security:") { + securityMessage = "security: " + securityMessage + } + parts := []string{ + securityMessage, + } + if workdir := strings.TrimSpace(action.Payload.Workdir); workdir != "" { + parts = append(parts, "workdir: "+workdir) + } + if target := strings.TrimSpace(action.Payload.Target); target != "" { + parts = append(parts, "target: "+target) + } + if sandboxTarget := strings.TrimSpace(action.Payload.SandboxTarget); sandboxTarget != "" { + parts = append(parts, "sandbox_target: "+sandboxTarget) + } + return strings.Join(parts, "\n") +} + +// errorMessage 提取错误文本,统一处理 nil 输入避免重复分支。 +func errorMessage(err error) string { + if err == nil { + return "" + } + return err.Error() } // verifyCapabilityToken 校验 capability token 的签名、绑定关系与时效性。 diff --git a/internal/tools/manager_test.go b/internal/tools/manager_test.go index 2656a98d..35103ceb 100644 --- a/internal/tools/manager_test.go +++ b/internal/tools/manager_test.go @@ -3,8 +3,10 @@ package tools import ( "context" "errors" + "fmt" "os" "path/filepath" + "runtime" "strings" "testing" "time" @@ -71,6 +73,10 @@ func (s *stubSandbox) Check(ctx context.Context, action security.Action) (*secur return s.plan, s.err } +func isWindowsRuntime() bool { + return runtime.GOOS == "windows" +} + func mustAllowEngine(t *testing.T) security.PermissionEngine { t.Helper() engine, err := security.NewStaticGateway(security.DecisionAllow, nil) @@ -234,6 +240,15 @@ func TestDefaultManagerListAvailableSpecsBoundaries(t *testing.T) { func TestDefaultManagerExecute(t *testing.T) { t.Parallel() + lowRiskOutsidePath := filepath.Join(string(filepath.Separator), "tmp", "snake_game.py") + workspaceRoot := filepath.Join(string(filepath.Separator), "workspace", "project") + protectedOutsidePath := filepath.Join(string(filepath.Separator), "etc", "hosts") + if isWindowsRuntime() { + lowRiskOutsidePath = `C:\Users\tester\Desktop\SnakeGame\snake_game.py` + workspaceRoot = `C:\workspace\project` + protectedOutsidePath = `C:\Windows\System32\drivers\etc\hosts` + } + tests := []struct { name string rules []security.Rule @@ -301,6 +316,36 @@ func TestDefaultManagerExecute(t *testing.T) { expectCalls: 0, expectSandboxRuns: 1, }, + { + name: "low risk outside workspace write becomes ask", + input: ToolCallInput{ + ID: "call-6", + Name: "filesystem_write_file", + Arguments: []byte(fmt.Sprintf(`{"path":%q,"content":"hi"}`, lowRiskOutsidePath)), + Workdir: workspaceRoot, + SessionID: "session-low-risk-outside", + }, + sandboxErr: fmt.Errorf("security: path %q escapes workspace root", lowRiskOutsidePath), + expectErr: sandboxExternalWriteApprovalReason, + expectContent: []string{"tool error", "reason: " + sandboxExternalWriteApprovalReason}, + expectDecision: "ask", + expectCalls: 0, + expectSandboxRuns: 1, + }, + { + name: "protected outside path keeps hard sandbox reject", + input: ToolCallInput{ + ID: "call-7", + Name: "filesystem_write_file", + Arguments: []byte(fmt.Sprintf(`{"path":%q,"content":"hi"}`, protectedOutsidePath)), + Workdir: workspaceRoot, + }, + sandboxErr: fmt.Errorf("security: path %q escapes workspace root", protectedOutsidePath), + expectErr: "escapes workspace root", + expectContent: []string{"tool error", "reason: workspace sandbox rejected action", "target: " + protectedOutsidePath}, + expectCalls: 0, + expectSandboxRuns: 1, + }, { name: "unknown tool uses executor error", input: ToolCallInput{ @@ -367,6 +412,319 @@ func TestDefaultManagerExecute(t *testing.T) { } } +func TestDefaultManagerSandboxOutsideWriteSessionMemory(t *testing.T) { + t.Parallel() + + outsidePath := filepath.Join(string(filepath.Separator), "tmp", "snake_game.py") + workspaceRoot := filepath.Join(string(filepath.Separator), "workspace", "project") + if isWindowsRuntime() { + outsidePath = `C:\Users\tester\Desktop\SnakeGame\snake_game.py` + workspaceRoot = `C:\workspace\project` + } + + registry := NewRegistry() + writeTool := &managerStubTool{name: "filesystem_write_file", content: "ok"} + registry.Register(writeTool) + + manager, err := NewManager(registry, mustAllowEngine(t), &stubSandbox{ + err: fmt.Errorf("security: path %q escapes workspace root", outsidePath), + }) + if err != nil { + t.Fatalf("new manager: %v", err) + } + + input := ToolCallInput{ + ID: "call-outside-ask", + Name: "filesystem_write_file", + Arguments: []byte(fmt.Sprintf(`{"path":%q,"content":"hi"}`, outsidePath)), + Workdir: workspaceRoot, + SessionID: "session-outside-ask", + } + + _, execErr := manager.Execute(context.Background(), input) + var permissionErr *PermissionDecisionError + if !errors.As(execErr, &permissionErr) || permissionErr.Decision() != "ask" { + t.Fatalf("expected initial ask decision, got %v", execErr) + } + + if rememberErr := manager.RememberSessionDecision(input.SessionID, permissionErr.Action(), SessionPermissionScopeAlways); rememberErr != nil { + t.Fatalf("remember outside write allow: %v", rememberErr) + } + + _, err = manager.Execute(context.Background(), input) + if err != nil { + t.Fatalf("expected remembered allow retry to execute, got %v", err) + } + if writeTool.callCount != 1 { + t.Fatalf("expected write tool to execute after remembered allow, got %d", writeTool.callCount) + } +} + +func TestSandboxOutsideWriteApprovalCandidate(t *testing.T) { + t.Parallel() + + workspaceRoot := filepath.Join(string(filepath.Separator), "workspace", "project") + lowRiskPath := filepath.Join(string(filepath.Separator), "tmp", "sample.py") + protectedPath := filepath.Join(string(filepath.Separator), "etc", "hosts") + highRiskExecutable := filepath.Join(string(filepath.Separator), "tmp", "sample.exe") + startupProfilePath := filepath.Join(string(filepath.Separator), "home", "tester", ".bashrc") + if isWindowsRuntime() { + workspaceRoot = `C:\workspace\project` + lowRiskPath = `C:\Users\tester\Desktop\sample.py` + protectedPath = `C:\Windows\System32\drivers\etc\hosts` + highRiskExecutable = `C:\Users\tester\Desktop\sample.exe` + startupProfilePath = `C:\Users\tester\Documents\PowerShell\Microsoft.PowerShell_profile.ps1` + } + + buildAction := func(target string, toolName string) security.Action { + return security.Action{ + Type: security.ActionTypeWrite, + Payload: security.ActionPayload{ + ToolName: toolName, + Resource: toolName, + Operation: "write_file", + Workdir: workspaceRoot, + TargetType: security.TargetTypePath, + Target: target, + SandboxTarget: target, + }, + } + } + + tests := []struct { + name string + action security.Action + sandboxErr error + want bool + }{ + { + name: "boundary violation low risk file asks approval", + action: buildAction(lowRiskPath, "filesystem_write_file"), + sandboxErr: fmt.Errorf("security: path %q escapes workspace root", lowRiskPath), + want: true, + }, + { + name: "non-boundary sandbox error keeps hard reject", + action: buildAction(lowRiskPath, "filesystem_write_file"), + sandboxErr: errors.New("workspace denied"), + want: false, + }, + { + name: "protected system path keeps hard reject", + action: buildAction(protectedPath, "filesystem_write_file"), + sandboxErr: fmt.Errorf("security: path %q escapes workspace root", protectedPath), + want: false, + }, + { + name: "high risk executable extension keeps hard reject", + action: buildAction(highRiskExecutable, "filesystem_write_file"), + sandboxErr: fmt.Errorf("security: path %q escapes workspace root", highRiskExecutable), + want: false, + }, + { + name: "write tool not in allowlist keeps hard reject", + action: buildAction(lowRiskPath, "filesystem_edit"), + sandboxErr: fmt.Errorf("security: path %q escapes workspace root", lowRiskPath), + want: false, + }, + { + name: "symlink workspace escape keeps hard reject", + action: buildAction(lowRiskPath, "filesystem_write_file"), + sandboxErr: fmt.Errorf("security: path %q escapes workspace root via symlink", filepath.Join("link", "sample.py")), + want: false, + }, + { + name: "startup profile path keeps hard reject", + action: buildAction(startupProfilePath, "filesystem_write_file"), + sandboxErr: fmt.Errorf("security: path %q escapes workspace root", startupProfilePath), + want: false, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := isSandboxOutsideWriteApprovalCandidate(tt.action, tt.sandboxErr) + if got != tt.want { + t.Fatalf("expected %v, got %v", tt.want, got) + } + }) + } +} + +func TestSandboxOutsideWriteUtilityHelpers(t *testing.T) { + t.Parallel() + + t.Run("candidate requires write action", func(t *testing.T) { + t.Parallel() + action := security.Action{ + Type: security.ActionTypeRead, + Payload: security.ActionPayload{ + ToolName: ToolNameFilesystemWriteFile, + Resource: ToolNameFilesystemWriteFile, + Workdir: "/workspace/project", + Target: "/tmp/note.txt", + SandboxTarget: "/tmp/note.txt", + }, + } + if got := isSandboxOutsideWriteApprovalCandidate(action, errors.New(`security: path "/tmp/note.txt" escapes workspace root`)); got { + t.Fatalf("expected non-write action not to be candidate") + } + }) + + t.Run("candidate requires resolvable target path", func(t *testing.T) { + t.Parallel() + action := security.Action{ + Type: security.ActionTypeWrite, + Payload: security.ActionPayload{ + ToolName: ToolNameFilesystemWriteFile, + Resource: ToolNameFilesystemWriteFile, + Workdir: "/workspace/project", + }, + } + if got := isSandboxOutsideWriteApprovalCandidate(action, errors.New(`security: path "/tmp/note.txt" escapes workspace root`)); got { + t.Fatalf("expected empty target not to be candidate") + } + }) + + t.Run("workspace error recognizers handle nil", func(t *testing.T) { + t.Parallel() + if isWorkspaceBoundaryViolationError(nil) { + t.Fatalf("expected nil error not to be workspace boundary violation") + } + if isWorkspaceSymlinkViolationError(nil) { + t.Fatalf("expected nil error not to be workspace symlink violation") + } + }) + + t.Run("resolve action sandbox target path branches", func(t *testing.T) { + t.Parallel() + if got := resolveActionSandboxTargetPath(security.Action{}); got != "" { + t.Fatalf("expected empty target path, got %q", got) + } + + actionWithTarget := security.Action{ + Payload: security.ActionPayload{ + Target: "logs/app.log", + Workdir: "/workspace/project", + }, + } + resolved := resolveActionSandboxTargetPath(actionWithTarget) + if !strings.HasSuffix(filepath.ToSlash(resolved), "/workspace/project/logs/app.log") { + t.Fatalf("expected target fallback with workdir join, got %q", resolved) + } + + actionWithSandboxTarget := security.Action{ + Payload: security.ActionPayload{ + Target: "/tmp/ignored.txt", + SandboxTarget: "/tmp/final.txt", + }, + } + if got := resolveActionSandboxTargetPath(actionWithSandboxTarget); filepath.Clean(got) != filepath.Clean("/tmp/final.txt") { + t.Fatalf("expected sandbox target to win, got %q", got) + } + }) + + t.Run("low risk path rejects empty path", func(t *testing.T) { + t.Parallel() + if isLowRiskExternalWritePath(" . ") { + t.Fatalf("expected dot path to be rejected") + } + }) + + t.Run("startup profile detector os branches", func(t *testing.T) { + t.Parallel() + if isUserStartupProfilePathForOS(".", "linux") { + t.Fatalf("expected dot path not to be startup profile") + } + if isUserStartupProfilePathForOS(" / ", "linux") { + t.Fatalf("expected root path not to be startup profile") + } + if !isUserStartupProfilePathForOS(`/Users/tester/Documents/WindowsPowerShell/custom_profile.ps1`, "windows") { + t.Fatalf("expected windows powershell profile directory to be recognized") + } + if !isUserStartupProfilePathForOS(`/Users/tester/Documents/PowerShell/custom_profile.ps1`, "windows") { + t.Fatalf("expected powershell profile directory to be recognized") + } + if isUserStartupProfilePathForOS(`/Users/tester/Documents/PowerShell/readme.txt`, "windows") { + t.Fatalf("expected non-ps1 path not to be startup profile") + } + if !isUserStartupProfilePathForOS(`/home/tester/.config/fish/config.fish`, "linux") { + t.Fatalf("expected fish config path to be startup profile") + } + }) + + t.Run("system protected path detector os branches", func(t *testing.T) { + t.Parallel() + if !isSystemProtectedPathForOS("/", "linux") { + t.Fatalf("expected linux root to be protected") + } + if !isSystemProtectedPathForOS("/home/tester/.ssh/config", "linux") { + t.Fatalf("expected .ssh path to be protected") + } + if isSystemProtectedPathForOS("/home/tester/Documents/notes.txt", "linux") { + t.Fatalf("expected regular linux user path not to be protected") + } + if !isSystemProtectedPathForOS(`C:\Windows\System32\drivers\etc\hosts`, "windows") { + t.Fatalf("expected windows system path to be protected") + } + if !isSystemProtectedPathForOS(`C:\Users\tester\AppData\Roaming\config`, "windows") { + t.Fatalf("expected appdata path to be protected") + } + if !isSystemProtectedPathForOS(`C:`, "windows") { + t.Fatalf("expected windows drive root to be protected") + } + if isSystemProtectedPathForOS(`C:\Users\tester\Desktop\note.txt`, "windows") { + t.Fatalf("expected regular windows user path not to be protected") + } + }) + + t.Run("error message handles nil", func(t *testing.T) { + t.Parallel() + if got := errorMessage(nil); got != "" { + t.Fatalf("expected empty error message for nil error, got %q", got) + } + }) +} + +func TestSandboxErrorDetailsIncludesWorkspaceContext(t *testing.T) { + t.Parallel() + + action := security.Action{ + Type: security.ActionTypeWrite, + Payload: security.ActionPayload{ + ToolName: "filesystem_write_file", + Resource: "filesystem_write_file", + Workdir: `C:\workspace\project`, + Target: `C:\Users\tester\Desktop\SnakeGame\snake_game.py`, + SandboxTarget: `C:\Users\tester\Desktop\SnakeGame\snake_game.py`, + }, + } + if !isWindowsRuntime() { + action.Payload.Workdir = "/workspace/project" + action.Payload.Target = "/tmp/snake_game.py" + action.Payload.SandboxTarget = "/tmp/snake_game.py" + } + + details := sandboxErrorDetails(action, errors.New("security: path escapes workspace root")) + for _, fragment := range []string{ + "security: path escapes workspace root", + "workdir: " + action.Payload.Workdir, + "target: " + action.Payload.Target, + "sandbox_target: " + action.Payload.SandboxTarget, + } { + if !strings.Contains(details, fragment) { + t.Fatalf("expected details containing %q, got %q", fragment, details) + } + } + + withoutPrefix := sandboxErrorDetails(action, errors.New("path escapes workspace root")) + if !strings.Contains(withoutPrefix, "security: path escapes workspace root") { + t.Fatalf("expected details to normalize security prefix, got %q", withoutPrefix) + } +} + func TestDefaultManagerExecuteBoundaries(t *testing.T) { t.Parallel() @@ -1480,6 +1838,58 @@ func TestDefaultManagerExecuteMCPRememberDoesNotBroadenAcrossTools(t *testing.T) } } +func TestDefaultManagerExecuteMCPMetadataCannotDriveTrustedFacts(t *testing.T) { + t.Parallel() + + registry := NewRegistry() + mcpRegistry := mcp.NewRegistry() + if err := mcpRegistry.RegisterServer("github", "stdio", "v1", &stubMCPClient{ + tools: []mcp.ToolDescriptor{ + {Name: "create_issue", Description: "create"}, + }, + callResult: mcp.CallResult{ + Content: "ok", + Metadata: map[string]any{ + "workspace_write": true, + "verification_performed": true, + "verification_passed": true, + "verification_scope": "workspace", + }, + }, + }); err != nil { + t.Fatalf("register mcp server: %v", err) + } + if err := mcpRegistry.RefreshServerTools(context.Background(), "github"); err != nil { + t.Fatalf("refresh mcp tools: %v", err) + } + registry.SetMCPRegistry(mcpRegistry) + + engine, err := security.NewStaticGateway(security.DecisionAllow, nil) + if err != nil { + t.Fatalf("new engine: %v", err) + } + manager, err := NewManager(registry, engine, nil) + if err != nil { + t.Fatalf("new manager: %v", err) + } + + result, execErr := manager.Execute(context.Background(), ToolCallInput{ + ID: "call-mcp-facts", + Name: "mcp.github.create_issue", + Arguments: []byte(`{"title":"hello"}`), + SessionID: "session-mcp-facts", + }) + if execErr != nil { + t.Fatalf("execute mcp: %v", execErr) + } + if result.Facts.WorkspaceWrite { + t.Fatalf("expected untrusted metadata to not mark workspace write, got %+v", result.Facts) + } + if result.Facts.VerificationPerformed || result.Facts.VerificationPassed || result.Facts.VerificationScope != "" { + t.Fatalf("expected untrusted metadata to not mark verification facts, got %+v", result.Facts) + } +} + func TestDefaultManagerExecuteMCPServerDenyUsesTraceableRule(t *testing.T) { t.Parallel() @@ -1826,6 +2236,26 @@ func TestDefaultManagerExecuteCapabilityTokenValidation(t *testing.T) { }, expectErr: "requires non-empty action agent_id", }, + { + name: "deny agent mismatch", + buildInput: func(t *testing.T, manager *DefaultManager) ToolCallInput { + t.Helper() + signed, err := manager.CapabilitySigner().Sign(baseToken) + if err != nil { + t.Fatalf("sign token: %v", err) + } + return ToolCallInput{ + ID: "call-agent-mismatch", + Name: "filesystem_read_file", + Arguments: []byte(`{"path":"README.md"}`), + Workdir: workdir, + TaskID: baseToken.TaskID, + AgentID: "agent-other", + CapabilityToken: &signed, + } + }, + expectErr: "agent_id does not match action", + }, } for _, tt := range testCases { diff --git a/internal/tools/permission_mapper.go b/internal/tools/permission_mapper.go index 8537b1ba..9626ea3e 100644 --- a/internal/tools/permission_mapper.go +++ b/internal/tools/permission_mapper.go @@ -136,7 +136,7 @@ func extractStringArgument(raw []byte, key string) string { var payload map[string]any if err := json.Unmarshal(raw, &payload); err != nil { - return "" + return extractStringArgumentFallback(string(raw), key) } value, ok := payload[key].(string) @@ -146,6 +146,30 @@ func extractStringArgument(raw []byte, key string) string { return strings.TrimSpace(value) } +// extractStringArgumentFallback 在参数不是严格合法 JSON 时做最小字符串提取,兼容未转义的 Windows 路径。 +func extractStringArgumentFallback(raw string, key string) string { + quotedKey := `"` + strings.TrimSpace(key) + `"` + start := strings.Index(raw, quotedKey) + if start < 0 { + return "" + } + rest := raw[start+len(quotedKey):] + colon := strings.Index(rest, ":") + if colon < 0 { + return "" + } + rest = strings.TrimSpace(rest[colon+1:]) + if !strings.HasPrefix(rest, `"`) { + return "" + } + rest = rest[1:] + end := strings.Index(rest, `"`) + if end < 0 { + return "" + } + return strings.TrimSpace(rest[:end]) +} + // extractSpawnSubAgentTarget 提取 spawn_subagent 的稳定权限目标,优先 items[].id,再回退 id/prompt。 func extractSpawnSubAgentTarget(raw []byte) string { if len(raw) == 0 { diff --git a/internal/tools/registry.go b/internal/tools/registry.go index 45a1e3fd..24abbb8e 100644 --- a/internal/tools/registry.go +++ b/internal/tools/registry.go @@ -218,6 +218,9 @@ func (r *Registry) Execute(ctx context.Context, input ToolCallInput) (ToolResult }, } for key, value := range callResult.Metadata { + if shouldSkipMCPMetadataKey(key, result.Metadata) { + continue + } result.Metadata[key] = value } if callErr != nil { @@ -413,3 +416,20 @@ func parseMCPToolFullName(fullName string) (string, string, bool) { } return parts[1], parts[2], true } + +// shouldSkipMCPMetadataKey 过滤 MCP 远端透传 metadata 中会影响本地安全语义或覆盖保留键的字段。 +func shouldSkipMCPMetadataKey(key string, existing map[string]any) bool { + normalized := strings.ToLower(strings.TrimSpace(key)) + if normalized == "" { + return true + } + if _, reserved := existing[normalized]; reserved { + return true + } + switch normalized { + case "workspace_write", "verification_performed", "verification_passed", "verification_scope": + return true + default: + return false + } +} diff --git a/internal/tools/registry_test.go b/internal/tools/registry_test.go index 59ddb78f..1191317c 100644 --- a/internal/tools/registry_test.go +++ b/internal/tools/registry_test.go @@ -337,7 +337,11 @@ func TestRegistryExecuteDispatchesToMCPAdapter(t *testing.T) { callResult: mcp.CallResult{ Content: "mcp ok", Metadata: map[string]any{ - "latency_ms": 12, + "latency_ms": 12, + "verification_passed": true, + "workspace_write": true, + "mcp_server_id": "override", + "verification_performed": true, }, }, }); err != nil { @@ -368,6 +372,15 @@ func TestRegistryExecuteDispatchesToMCPAdapter(t *testing.T) { if result.Metadata["mcp_server_id"] != "docs" || result.Metadata["mcp_tool_name"] != "search" { t.Fatalf("unexpected mcp metadata: %+v", result.Metadata) } + if result.Metadata["latency_ms"] != 12 { + t.Fatalf("expected safe metadata passthrough, got %+v", result.Metadata) + } + if _, exists := result.Metadata["workspace_write"]; exists { + t.Fatalf("expected workspace_write metadata to be filtered, got %+v", result.Metadata) + } + if _, exists := result.Metadata["verification_passed"]; exists { + t.Fatalf("expected verification metadata to be filtered, got %+v", result.Metadata) + } } func TestRegistryExecuteRejectsPolicyDeniedMCPTool(t *testing.T) { diff --git a/internal/tools/types.go b/internal/tools/types.go index bcb71607..68b30038 100644 --- a/internal/tools/types.go +++ b/internal/tools/types.go @@ -95,6 +95,15 @@ type ToolResult struct { Content string IsError bool Metadata map[string]any + Facts ToolExecutionFacts +} + +// ToolExecutionFacts 描述工具执行产出的结构化运行事实,供 runtime 做写入/验证控制。 +type ToolExecutionFacts struct { + WorkspaceWrite bool + VerificationPerformed bool + VerificationPassed bool + VerificationScope string } // ToolSpec 对齐 provider 层 tool schema 结构。 diff --git a/internal/tui/core/app/update.go b/internal/tui/core/app/update.go index fc6c50d6..7cb01e8d 100644 --- a/internal/tui/core/app/update.go +++ b/internal/tui/core/app/update.go @@ -24,6 +24,7 @@ import ( providertypes "neo-code/internal/provider/types" agentruntime "neo-code/internal/runtime" approvalflow "neo-code/internal/runtime/approval" + "neo-code/internal/runtime/controlplane" agentsession "neo-code/internal/session" "neo-code/internal/tools" tuistatus "neo-code/internal/tui/core/status" @@ -1076,6 +1077,12 @@ func runtimeEventPhaseChangedHandler(a *App, event agentruntime.RuntimeEvent) bo a.setRunProgress(0.6, "Running tools") case "verify": a.setRunProgress(0.82, "Verifying") + case "compacting": + a.setRunProgress(0.9, "Compacting context") + case "waiting_permission": + a.setRunProgress(0.88, "Awaiting permission") + case "stopped": + a.setRunProgress(1, "Stopped") } return false } @@ -1093,16 +1100,23 @@ func runtimeEventStopReasonDecidedHandler(a *App, event agentruntime.RuntimeEven a.pendingPermission = nil a.clearRunProgress() - reason := strings.ToLower(strings.TrimSpace(string(payload.Reason))) + reason := controlplane.StopReason(strings.ToUpper(strings.TrimSpace(string(payload.Reason)))) switch reason { - case "success": - if strings.TrimSpace(a.state.ExecutionError) == "" { - a.state.StatusText = statusReady - } - case "canceled": + case controlplane.StopReasonCompleted: + a.state.ExecutionError = "" + a.state.StatusText = statusReady + case controlplane.StopReasonUserInterrupt: a.state.ExecutionError = "" a.state.StatusText = statusCanceled a.appendActivity("run", "Canceled current run", "", false) + case controlplane.StopReasonFatalError: + detail := strings.TrimSpace(payload.Detail) + if detail == "" { + detail = "runtime stopped" + } + a.state.ExecutionError = detail + a.state.StatusText = detail + a.appendActivity("run", "Runtime stopped", detail, true) default: detail := strings.TrimSpace(payload.Detail) if detail == "" { diff --git a/internal/tui/core/app/update_runtime_events_test.go b/internal/tui/core/app/update_runtime_events_test.go index d6f7725d..d0736497 100644 --- a/internal/tui/core/app/update_runtime_events_test.go +++ b/internal/tui/core/app/update_runtime_events_test.go @@ -26,6 +26,9 @@ func TestRuntimeEventPhaseChangedHandlerBranches(t *testing.T) { {to: " plan ", wantValue: 0.3, wantLabel: "Planning"}, {to: "execute", wantValue: 0.6, wantLabel: "Running tools"}, {to: "VERIFY", wantValue: 0.82, wantLabel: "Verifying"}, + {to: "compacting", wantValue: 0.9, wantLabel: "Compacting context"}, + {to: " waiting_permission ", wantValue: 0.88, wantLabel: "Awaiting permission"}, + {to: "stopped", wantValue: 1, wantLabel: "Stopped"}, } for _, tc := range cases { app.clearRunProgress() @@ -60,7 +63,7 @@ func TestRuntimeEventStopReasonDecidedHandlerBranches(t *testing.T) { } handled := runtimeEventStopReasonDecidedHandler(&app, agentruntime.RuntimeEvent{ - Payload: agentruntime.StopReasonDecidedPayload{Reason: controlplane.StopReason(" success ")}, + Payload: agentruntime.StopReasonDecidedPayload{Reason: controlplane.StopReason(" stop_completed ")}, }) if handled { t.Fatalf("expected handler to return false") @@ -81,7 +84,7 @@ func TestRuntimeEventStopReasonDecidedHandlerBranches(t *testing.T) { app.state.ExecutionError = "" app.state.StatusText = "not-ready" runtimeEventStopReasonDecidedHandler(&app, agentruntime.RuntimeEvent{ - Payload: agentruntime.StopReasonDecidedPayload{Reason: controlplane.StopReason("success")}, + Payload: agentruntime.StopReasonDecidedPayload{Reason: controlplane.StopReason("STOP_COMPLETED")}, }) if app.state.StatusText != statusReady { t.Fatalf("expected success with empty execution error to set ready status") @@ -90,28 +93,28 @@ func TestRuntimeEventStopReasonDecidedHandlerBranches(t *testing.T) { app.state.ExecutionError = "boom" app.state.StatusText = "" runtimeEventStopReasonDecidedHandler(&app, agentruntime.RuntimeEvent{ - Payload: agentruntime.StopReasonDecidedPayload{Reason: controlplane.StopReason("success")}, + Payload: agentruntime.StopReasonDecidedPayload{Reason: controlplane.StopReason("STOP_COMPLETED")}, }) - if app.state.StatusText == statusReady { - t.Fatalf("expected success branch to keep status unchanged when execution error exists") + if app.state.StatusText != statusReady || app.state.ExecutionError != "" { + t.Fatalf("expected completed state to clear error and set ready status, got status=%q err=%q", app.state.StatusText, app.state.ExecutionError) } runtimeEventStopReasonDecidedHandler(&app, agentruntime.RuntimeEvent{ - Payload: agentruntime.StopReasonDecidedPayload{Reason: controlplane.StopReason("canceled")}, + Payload: agentruntime.StopReasonDecidedPayload{Reason: controlplane.StopReason("STOP_USER_INTERRUPT")}, }) if app.state.ExecutionError != "" || app.state.StatusText != statusCanceled { t.Fatalf("expected canceled state to clear error and set canceled status") } runtimeEventStopReasonDecidedHandler(&app, agentruntime.RuntimeEvent{ - Payload: agentruntime.StopReasonDecidedPayload{Reason: controlplane.StopReason("error"), Detail: " "}, + Payload: agentruntime.StopReasonDecidedPayload{Reason: controlplane.StopReason("STOP_FATAL_ERROR"), Detail: " "}, }) if app.state.StatusText != "runtime stopped" || app.state.ExecutionError != "runtime stopped" { t.Fatalf("expected default stop detail, got status=%q err=%q", app.state.StatusText, app.state.ExecutionError) } runtimeEventStopReasonDecidedHandler(&app, agentruntime.RuntimeEvent{ - Payload: agentruntime.StopReasonDecidedPayload{Reason: controlplane.StopReason("error"), Detail: "explicit failure"}, + Payload: agentruntime.StopReasonDecidedPayload{Reason: controlplane.StopReason("STOP_FATAL_ERROR"), Detail: "explicit failure"}, }) if app.state.StatusText != "explicit failure" || app.state.ExecutionError != "explicit failure" { t.Fatalf("expected explicit stop detail to be surfaced")