diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 8dc24f54..37187dbc 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -52,6 +52,18 @@ archives: {{- else if eq .Arch "386" }}i386 {{- else }}{{ .Arch }}{{ end }} {{- if .Arm }}v{{ .Arm }}{{ end }} + files: + - web/package.json + - web/package-lock.json + - web/index.html + - web/components.json + - web/tsconfig.json + - web/tsconfig.app.json + - web/tsconfig.node.json + - web/vite.config.ts + - web/src/**/* + - web/scripts/**/* + - web/vite-plugins/**/* - id: neocode-gateway ids: diff --git a/README.md b/README.md index 56e6d347..2a32173c 100644 --- a/README.md +++ b/README.md @@ -113,6 +113,13 @@ $env:OPENAI_API_KEY = "your_key_here" neocode --workdir /path/to/your/project ``` +如果你希望使用浏览器 Web UI,可以直接运行: +```bash +neocode web +``` + +标签发布版会在缺少 `web/dist` 时自动使用发布包内的 `web/` 源码执行 `npm install` 和 `npm run build`。这要求用户机器已安装 Node.js 和 npm;如果你使用源码仓库运行,也保留相同的自动构建行为。 + ### 4. 常用命令 ```text diff --git a/internal/checkpoint/per_edit_snapshot.go b/internal/checkpoint/per_edit_snapshot.go index 3270ab79..9efa78d2 100644 --- a/internal/checkpoint/per_edit_snapshot.go +++ b/internal/checkpoint/per_edit_snapshot.go @@ -47,9 +47,10 @@ type FileVersionMeta struct { // CheckpointMeta 是 cp_.json 的内容。 type CheckpointMeta struct { - CheckpointID string `json:"checkpoint_id"` - CreatedAt time.Time `json:"created_at"` - FileVersions map[string]int `json:"file_versions"` + CheckpointID string `json:"checkpoint_id"` + CreatedAt time.Time `json:"created_at"` + FileVersions map[string]int `json:"file_versions"` + ExactFileVersions map[string]int `json:"exact_file_versions,omitempty"` } // perEditIndexEntry 是 index.jsonl 的单行结构,进程重启时用于重建内存索引。 @@ -227,6 +228,16 @@ func (s *PerEditSnapshotStore) CapturePostDelete(absPaths []string) error { // pathToVersions 为空时返回 (false, nil) 表示目前无文件被追踪过,无需写入。 // 调用方在 Finalize 后应调用 Reset 清空 pending。 func (s *PerEditSnapshotStore) Finalize(checkpointID string) (bool, error) { + return s.finalizeCheckpoint(checkpointID, false) +} + +// FinalizeWithExactState 在写入 pre-write 基线的同时,固化 checkpoint 结束时的精确文件版本。 +func (s *PerEditSnapshotStore) FinalizeWithExactState(checkpointID string) (bool, error) { + return s.finalizeCheckpoint(checkpointID, true) +} + +// finalizeCheckpoint 根据需要把 pending 基线与 checkpoint 末状态一并落盘。 +func (s *PerEditSnapshotStore) finalizeCheckpoint(checkpointID string, captureExactState bool) (bool, error) { if checkpointID == "" { return false, fmt.Errorf("per-edit: empty checkpointID") } @@ -265,10 +276,20 @@ func (s *PerEditSnapshotStore) Finalize(checkpointID string) (bool, error) { } } + var exactSnapshot map[string]int + if captureExactState { + var err error + exactSnapshot, err = s.captureExactStateSnapshot(snapshot) + if err != nil { + return false, err + } + } + meta := CheckpointMeta{ - CheckpointID: checkpointID, - CreatedAt: time.Now().UTC(), - FileVersions: snapshot, + CheckpointID: checkpointID, + CreatedAt: time.Now().UTC(), + FileVersions: snapshot, + ExactFileVersions: exactSnapshot, } if err := s.writeCheckpointMeta(meta); err != nil { return false, err @@ -305,6 +326,36 @@ func (s *PerEditSnapshotStore) FinalizePending(checkpointID string) (bool, error return true, nil } +// captureExactStateSnapshot 为当前 pending 里的每个文件追加一个“checkpoint 结束态”精确版本。 +func (s *PerEditSnapshotStore) captureExactStateSnapshot(baseVersions map[string]int) (map[string]int, error) { + s.indexMu.Lock() + defer s.indexMu.Unlock() + + hashes := make([]string, 0, len(baseVersions)) + for hash := range baseVersions { + hashes = append(hashes, hash) + } + sort.Strings(hashes) + + exactVersions := make(map[string]int, len(hashes)) + for _, hash := range hashes { + display := s.resolveDisplayPathLocked(hash, "") + if display == "" { + meta, err := s.readVersionMeta(hash, baseVersions[hash]) + if err != nil { + return nil, fmt.Errorf("per-edit: read baseline meta for %s: %w", hash, err) + } + display = meta.DisplayPath + } + version, err := s.captureExactCurrentVersionLocked(hash, display) + if err != nil { + return nil, err + } + exactVersions[hash] = version + } + return exactVersions, nil +} + // Reset 清空 pending 映射,每轮 turn 开始时调用,避免跨轮残留。 func (s *PerEditSnapshotStore) Reset() { s.pendingMu.Lock() @@ -781,6 +832,149 @@ const ( // FileChangeEntry 是 repository.FileChangeEntry 的别名,保留以维持向后兼容。 type FileChangeEntry = repository.FileChangeEntry +// DiffCheckpointsToWorkdir 按多个 checkpoint 的首次触碰版本作为基线,对比当前工作区最终状态。 +func (s *PerEditSnapshotStore) DiffCheckpointsToWorkdir(ctx context.Context, checkpointIDs []string) (string, []FileChangeEntry, error) { + if s == nil { + return "", nil, fmt.Errorf("per-edit: store not available") + } + baseVersions := make(map[string]int) + for _, checkpointID := range checkpointIDs { + cp, err := s.readCheckpointMeta(checkpointID) + if err != nil { + return "", nil, err + } + for hash, version := range cp.FileVersions { + if _, exists := baseVersions[hash]; !exists { + baseVersions[hash] = version + } + } + } + if len(baseVersions) == 0 { + return "", nil, nil + } + + s.indexMu.Lock() + defer s.indexMu.Unlock() + + hashes := make([]string, 0, len(baseVersions)) + for hash := range baseVersions { + hashes = append(hashes, hash) + } + sort.Strings(hashes) + + var patch bytes.Buffer + changes := make([]FileChangeEntry, 0, len(hashes)) + for _, hash := range hashes { + if err := ctx.Err(); err != nil { + return "", nil, err + } + fromContent, fromIsDir, fromExists, fromDisplay, err := s.contentAtExactVersionLocked(hash, baseVersions[hash]) + if err != nil { + return "", nil, err + } + display := s.resolveDisplayPathLocked(hash, fromDisplay) + toContent, toIsDir, toExists := readWorkdirContent(display) + if fromIsDir && toIsDir { + continue + } + if bytes.Equal(fromContent, toContent) && fromExists == toExists && fromIsDir == toIsDir { + continue + } + rel := filepath.ToSlash(s.relativeDisplay(display)) + kind := classifyFileChange(fromContent, fromIsDir, fromExists, toContent, toIsDir, toExists) + if kind != "" { + changes = append(changes, FileChangeEntry{Path: rel, Kind: kind}) + } + out, err := unifiedDiffForContents(rel, fromContent, toContent) + if err != nil { + return "", nil, err + } + patch.WriteString(out) + } + return strings.TrimRight(patch.String(), "\n"), changes, nil +} + +// DiffCheckpointsToCheckpoint 汇总多个 checkpoint 的首触碰基线,并对比目标 checkpoint 的精确结束态。 +func (s *PerEditSnapshotStore) DiffCheckpointsToCheckpoint( + ctx context.Context, + checkpointIDs []string, + targetCheckpointID string, +) (string, []FileChangeEntry, error) { + if s == nil { + return "", nil, fmt.Errorf("per-edit: store not available") + } + if strings.TrimSpace(targetCheckpointID) == "" { + return "", nil, fmt.Errorf("per-edit: target checkpoint id required") + } + + baseVersions := make(map[string]int) + for _, checkpointID := range checkpointIDs { + cp, err := s.readCheckpointMeta(checkpointID) + if err != nil { + return "", nil, err + } + for hash, version := range cp.FileVersions { + if _, exists := baseVersions[hash]; !exists { + baseVersions[hash] = version + } + } + } + if len(baseVersions) == 0 { + return "", nil, nil + } + + targetMeta, err := s.readCheckpointMeta(targetCheckpointID) + if err != nil { + return "", nil, err + } + + s.indexMu.Lock() + defer s.indexMu.Unlock() + + hashes := make([]string, 0, len(baseVersions)) + for hash := range baseVersions { + hashes = append(hashes, hash) + } + sort.Strings(hashes) + + var patch bytes.Buffer + changes := make([]FileChangeEntry, 0, len(hashes)) + for _, hash := range hashes { + if err := ctx.Err(); err != nil { + return "", nil, err + } + fromContent, fromIsDir, fromExists, fromDisplay, err := s.contentAtExactVersionLocked(hash, baseVersions[hash]) + if err != nil { + return "", nil, err + } + toContent, toIsDir, toExists, toDisplay, err := s.contentAtCheckpointTargetLocked(hash, targetMeta, fromDisplay) + if err != nil { + return "", nil, err + } + if fromIsDir && toIsDir { + continue + } + if bytes.Equal(fromContent, toContent) && fromExists == toExists && fromIsDir == toIsDir { + continue + } + display := toDisplay + if display == "" { + display = fromDisplay + } + rel := filepath.ToSlash(s.relativeDisplay(display)) + kind := classifyFileChange(fromContent, fromIsDir, fromExists, toContent, toIsDir, toExists) + if kind != "" { + changes = append(changes, FileChangeEntry{Path: rel, Kind: kind}) + } + out, err := unifiedDiffForContents(rel, fromContent, toContent) + if err != nil { + return "", nil, err + } + patch.WriteString(out) + } + return strings.TrimRight(patch.String(), "\n"), changes, nil +} + // ChangedFiles 端到端比较两个 checkpoint,返回 path → 变更类别的列表(按 path 字典序)。 // 不返回内容差异,仅用于 UI 分组(添加/删除/修改)。完整 patch 仍由 Diff 生成。 func (s *PerEditSnapshotStore) ChangedFiles(ctx context.Context, fromID, toID string) ([]FileChangeEntry, error) { @@ -1023,6 +1217,9 @@ func (s *PerEditSnapshotStore) readCheckpointMeta(checkpointID string) (Checkpoi if meta.FileVersions == nil { meta.FileVersions = map[string]int{} } + if meta.ExactFileVersions == nil { + meta.ExactFileVersions = map[string]int{} + } return meta, nil } @@ -1040,6 +1237,49 @@ func (s *PerEditSnapshotStore) readVersionBin(hash string, version int) ([]byte, return os.ReadFile(s.versionBinPath(hash, version)) } +// captureExactCurrentVersionLocked 读取当前工作区状态,并为同一路径追加一个精确版本。 +func (s *PerEditSnapshotStore) captureExactCurrentVersionLocked(hash, displayPath string) (int, error) { + cleanPath := filepath.Clean(displayPath) + if cleanPath == "" || cleanPath == "." { + return 0, fmt.Errorf("per-edit: missing display path for exact snapshot") + } + + versions := s.pathToVersions[hash] + nextVersion := 1 + if len(versions) > 0 { + nextVersion = versions[len(versions)-1] + 1 + } + + content, existed, isDir, mode, err := readFileForCapture(cleanPath) + if err != nil { + return 0, fmt.Errorf("per-edit: read exact state %s: %w", cleanPath, err) + } + + meta := FileVersionMeta{ + PathHash: hash, + DisplayPath: cleanPath, + Version: nextVersion, + Existed: existed, + IsDir: isDir, + Mode: mode, + CreatedAt: time.Now().UTC(), + } + if err := s.writeVersionFiles(meta, content); err != nil { + return 0, err + } + if err := s.appendIndex(perEditIndexEntry{ + PathHash: hash, + DisplayPath: cleanPath, + Version: nextVersion, + }); err != nil { + return 0, fmt.Errorf("per-edit: append exact index: %w", err) + } + + s.pathToVersions[hash] = append(versions, nextVersion) + s.displayPaths[hash] = cleanPath + return nextVersion, nil +} + // findNextVersionLocked 返回 hash 下大于 vA 的最小版本号,没有则返回 0。indexMu 必须被持有。 func (s *PerEditSnapshotStore) findNextVersionLocked(hash string, vA int) int { versions := s.pathToVersions[hash] @@ -1110,6 +1350,77 @@ func (s *PerEditSnapshotStore) contentAtCheckpointLocked(hash string, cpVersions return content, false, true, nextMeta.Mode, display, nil } +// contentAtExactVersionLocked 读取指定 hash/version 保存的精确内容,调用方必须持有 indexMu。 +func (s *PerEditSnapshotStore) contentAtExactVersionLocked(hash string, version int) ([]byte, bool, bool, string, error) { + meta, err := s.readVersionMeta(hash, version) + if err != nil { + return nil, false, false, "", fmt.Errorf("per-edit: read exact meta v%d for %s: %w", version, hash, err) + } + display := s.resolveDisplayPathLocked(hash, meta.DisplayPath) + if !meta.Existed { + return nil, false, false, display, nil + } + if meta.IsDir { + return nil, true, true, display, nil + } + content, err := s.readVersionBin(hash, version) + if err != nil { + return nil, false, false, display, fmt.Errorf("per-edit: read exact bin v%d for %s: %w", version, hash, err) + } + return content, false, true, display, nil +} + +// contentAtCheckpointTargetLocked 优先读取 checkpoint 记录的精确结束态,缺失时兼容回退到当前工作区。 +func (s *PerEditSnapshotStore) contentAtCheckpointTargetLocked( + hash string, + cp CheckpointMeta, + fallbackDisplay string, +) ([]byte, bool, bool, string, error) { + if version, ok := cp.ExactFileVersions[hash]; ok { + return s.contentAtExactVersionLocked(hash, version) + } + display := s.resolveDisplayPathLocked(hash, fallbackDisplay) + content, isDir, exists := readWorkdirContent(display) + return content, isDir, exists, display, nil +} + +// classifyFileChange 将端点状态归类为 UI 可展示的 added/deleted/modified。 +func classifyFileChange( + fromContent []byte, + fromIsDir bool, + fromExists bool, + toContent []byte, + toIsDir bool, + toExists bool, +) FileChangeKind { + switch { + case !fromExists && toExists: + return FileChangeAdded + case fromExists && !toExists: + return FileChangeDeleted + case fromIsDir != toIsDir || !bytes.Equal(fromContent, toContent): + return FileChangeModified + default: + return "" + } +} + +// unifiedDiffForContents 生成单个文件的 unified diff 片段,路径已按工作区相对路径传入。 +func unifiedDiffForContents(rel string, fromContent, toContent []byte) (string, error) { + diff := difflib.UnifiedDiff{ + A: difflib.SplitLines(string(fromContent)), + B: difflib.SplitLines(string(toContent)), + FromFile: "a/" + filepath.ToSlash(rel), + ToFile: "b/" + filepath.ToSlash(rel), + Context: 3, + } + out, err := difflib.GetUnifiedDiffString(diff) + if err != nil { + return "", fmt.Errorf("per-edit: diff %s: %w", rel, err) + } + return out, nil +} + func readWorkdirContent(absPath string) ([]byte, bool, bool) { if absPath == "" { return nil, false, false diff --git a/internal/checkpoint/per_edit_snapshot_test.go b/internal/checkpoint/per_edit_snapshot_test.go index 47b5e93a..27b7d94b 100644 --- a/internal/checkpoint/per_edit_snapshot_test.go +++ b/internal/checkpoint/per_edit_snapshot_test.go @@ -1124,6 +1124,252 @@ func TestChangedFiles_NewFileDetectedAsAdded(t *testing.T) { } } +// ──────────────────────── DiffCheckpointsToWorkdir tests ──────────────────────── + +func TestDiffCheckpointsToWorkdir_AggregatesRepeatedEdits(t *testing.T) { + store, workdir := newTestStore(t) + abs := writeWorkdirFile(t, workdir, "a.txt", "A\n") + + if _, err := store.CapturePreWrite(abs); err != nil { + t.Fatalf("capture cp1: %v", err) + } + if err := os.WriteFile(abs, []byte("B\n"), 0o644); err != nil { + t.Fatalf("write B: %v", err) + } + if _, err := store.Finalize("cp1"); err != nil { + t.Fatalf("finalize cp1: %v", err) + } + store.Reset() + + if _, err := store.CapturePreWrite(abs); err != nil { + t.Fatalf("capture cp2: %v", err) + } + if err := os.WriteFile(abs, []byte("C\n"), 0o644); err != nil { + t.Fatalf("write C: %v", err) + } + if _, err := store.Finalize("cp2"); err != nil { + t.Fatalf("finalize cp2: %v", err) + } + store.Reset() + + patch, changes, err := store.DiffCheckpointsToWorkdir(context.Background(), []string{"cp1", "cp2"}) + if err != nil { + t.Fatalf("DiffCheckpointsToWorkdir() error = %v", err) + } + if len(changes) != 1 || changes[0].Path != "a.txt" || changes[0].Kind != FileChangeModified { + t.Fatalf("changes = %+v, want a.txt modified", changes) + } + if !strings.Contains(patch, "-A") || !strings.Contains(patch, "+C") || strings.Contains(patch, "-B") { + t.Fatalf("patch should compare A to C only, got:\n%s", patch) + } +} + +func TestDiffCheckpointsToWorkdir_ElidesRevertedAndAddDelete(t *testing.T) { + store, workdir := newTestStore(t) + reverted := writeWorkdirFile(t, workdir, "reverted.txt", "A\n") + transient := filepath.Join(workdir, "transient.txt") + + if _, err := store.CapturePreWrite(reverted); err != nil { + t.Fatalf("capture reverted cp1: %v", err) + } + if err := os.WriteFile(reverted, []byte("B\n"), 0o644); err != nil { + t.Fatalf("write reverted B: %v", err) + } + if _, err := store.CapturePreWrite(transient); err != nil { + t.Fatalf("capture transient cp1: %v", err) + } + if err := os.WriteFile(transient, []byte("created\n"), 0o644); err != nil { + t.Fatalf("write transient: %v", err) + } + if _, err := store.Finalize("cp1"); err != nil { + t.Fatalf("finalize cp1: %v", err) + } + store.Reset() + + if _, err := store.CapturePreWrite(reverted); err != nil { + t.Fatalf("capture reverted cp2: %v", err) + } + if err := os.WriteFile(reverted, []byte("A\n"), 0o644); err != nil { + t.Fatalf("restore reverted A: %v", err) + } + if _, err := store.CapturePreWrite(transient); err != nil { + t.Fatalf("capture transient cp2: %v", err) + } + if err := os.Remove(transient); err != nil { + t.Fatalf("remove transient: %v", err) + } + if _, err := store.Finalize("cp2"); err != nil { + t.Fatalf("finalize cp2: %v", err) + } + store.Reset() + + patch, changes, err := store.DiffCheckpointsToWorkdir(context.Background(), []string{"cp1", "cp2"}) + if err != nil { + t.Fatalf("DiffCheckpointsToWorkdir() error = %v", err) + } + if patch != "" || len(changes) != 0 { + t.Fatalf("expected empty aggregate diff, patch=%q changes=%+v", patch, changes) + } +} + +func TestDiffCheckpointsToWorkdir_DeletedExistingFile(t *testing.T) { + store, workdir := newTestStore(t) + abs := writeWorkdirFile(t, workdir, "gone.txt", "old\n") + + if _, err := store.CapturePreWrite(abs); err != nil { + t.Fatalf("capture: %v", err) + } + if err := os.Remove(abs); err != nil { + t.Fatalf("remove: %v", err) + } + if _, err := store.Finalize("cp1"); err != nil { + t.Fatalf("finalize: %v", err) + } + store.Reset() + + patch, changes, err := store.DiffCheckpointsToWorkdir(context.Background(), []string{"cp1"}) + if err != nil { + t.Fatalf("DiffCheckpointsToWorkdir() error = %v", err) + } + if len(changes) != 1 || changes[0].Path != "gone.txt" || changes[0].Kind != FileChangeDeleted { + t.Fatalf("changes = %+v, want gone.txt deleted", changes) + } + if !strings.Contains(patch, "-old") { + t.Fatalf("patch should contain deleted content, got:\n%s", patch) + } +} + +func TestDiffCheckpointsToCheckpoint_UsesExactTargetState(t *testing.T) { + store, workdir := newTestStore(t) + abs := writeWorkdirFile(t, workdir, "tracked.txt", "A\n") + + if _, err := store.CapturePreWrite(abs); err != nil { + t.Fatalf("capture cp1: %v", err) + } + if err := os.WriteFile(abs, []byte("B\n"), 0o644); err != nil { + t.Fatalf("write B: %v", err) + } + if _, err := store.FinalizeWithExactState("cp1"); err != nil { + t.Fatalf("FinalizeWithExactState(cp1): %v", err) + } + store.Reset() + + if _, err := store.CapturePreWrite(abs); err != nil { + t.Fatalf("capture cp2: %v", err) + } + if err := os.WriteFile(abs, []byte("C\n"), 0o644); err != nil { + t.Fatalf("write C: %v", err) + } + if _, err := store.FinalizeWithExactState("cp2"); err != nil { + t.Fatalf("FinalizeWithExactState(cp2): %v", err) + } + store.Reset() + + if err := os.WriteFile(abs, []byte("D\n"), 0o644); err != nil { + t.Fatalf("write D drift: %v", err) + } + + patch, changes, err := store.DiffCheckpointsToCheckpoint(context.Background(), []string{"cp1", "cp2"}, "cp2") + if err != nil { + t.Fatalf("DiffCheckpointsToCheckpoint() error = %v", err) + } + if len(changes) != 1 || changes[0].Path != "tracked.txt" || changes[0].Kind != FileChangeModified { + t.Fatalf("changes = %+v, want tracked.txt modified", changes) + } + if !strings.Contains(patch, "-A") || !strings.Contains(patch, "+C") { + t.Fatalf("patch should compare A to C, got:\n%s", patch) + } + if strings.Contains(patch, "+D") || strings.Contains(patch, "-B") { + t.Fatalf("patch should ignore later workdir drift, got:\n%s", patch) + } +} + +func TestDiffCheckpointsToCheckpoint_ElidesRevertedAndTransientFiles(t *testing.T) { + store, workdir := newTestStore(t) + reverted := writeWorkdirFile(t, workdir, "reverted.txt", "A\n") + transient := filepath.Join(workdir, "transient.txt") + + if _, err := store.CapturePreWrite(reverted); err != nil { + t.Fatalf("capture reverted cp1: %v", err) + } + if err := os.WriteFile(reverted, []byte("B\n"), 0o644); err != nil { + t.Fatalf("write reverted B: %v", err) + } + if _, err := store.CapturePreWrite(transient); err != nil { + t.Fatalf("capture transient cp1: %v", err) + } + if err := os.WriteFile(transient, []byte("created\n"), 0o644); err != nil { + t.Fatalf("write transient: %v", err) + } + if _, err := store.FinalizeWithExactState("cp1"); err != nil { + t.Fatalf("FinalizeWithExactState(cp1): %v", err) + } + store.Reset() + + if _, err := store.CapturePreWrite(reverted); err != nil { + t.Fatalf("capture reverted cp2: %v", err) + } + if err := os.WriteFile(reverted, []byte("A\n"), 0o644); err != nil { + t.Fatalf("restore reverted A: %v", err) + } + if _, err := store.CapturePreWrite(transient); err != nil { + t.Fatalf("capture transient cp2: %v", err) + } + if err := os.Remove(transient); err != nil { + t.Fatalf("remove transient: %v", err) + } + if _, err := store.FinalizeWithExactState("cp2"); err != nil { + t.Fatalf("FinalizeWithExactState(cp2): %v", err) + } + store.Reset() + + patch, changes, err := store.DiffCheckpointsToCheckpoint(context.Background(), []string{"cp1", "cp2"}, "cp2") + if err != nil { + t.Fatalf("DiffCheckpointsToCheckpoint() error = %v", err) + } + if patch != "" || len(changes) != 0 { + t.Fatalf("expected empty aggregate diff, patch=%q changes=%+v", patch, changes) + } +} + +func TestDiffCheckpointsToCheckpoint_FallsBackWhenExactStateMissing(t *testing.T) { + store, workdir := newTestStore(t) + abs := writeWorkdirFile(t, workdir, "tracked.txt", "A\n") + + if _, err := store.CapturePreWrite(abs); err != nil { + t.Fatalf("capture cp1: %v", err) + } + if err := os.WriteFile(abs, []byte("B\n"), 0o644); err != nil { + t.Fatalf("write B: %v", err) + } + if _, err := store.Finalize("cp1"); err != nil { + t.Fatalf("Finalize(cp1): %v", err) + } + store.Reset() + + if _, err := store.CapturePreWrite(abs); err != nil { + t.Fatalf("capture cp2: %v", err) + } + if err := os.WriteFile(abs, []byte("C\n"), 0o644); err != nil { + t.Fatalf("write C: %v", err) + } + if _, err := store.Finalize("cp2"); err != nil { + t.Fatalf("Finalize(cp2): %v", err) + } + store.Reset() + + patch, changes, err := store.DiffCheckpointsToCheckpoint(context.Background(), []string{"cp1", "cp2"}, "cp2") + if err != nil { + t.Fatalf("DiffCheckpointsToCheckpoint() error = %v", err) + } + if len(changes) != 1 || changes[0].Path != "tracked.txt" || changes[0].Kind != FileChangeModified { + t.Fatalf("changes = %+v, want tracked.txt modified", changes) + } + if !strings.Contains(patch, "-A") || !strings.Contains(patch, "+C") { + t.Fatalf("patch should fall back to workdir and compare A to C, got:\n%s", patch) + } +} + // ──────────────────────── RunAggregateDiff tests ──────────────────────── func TestRunAggregateDiff_ModifiedFileAcrossCheckpoints(t *testing.T) { diff --git a/internal/cli/gateway_runtime_bridge.go b/internal/cli/gateway_runtime_bridge.go index 3cbef499..46455254 100644 --- a/internal/cli/gateway_runtime_bridge.go +++ b/internal/cli/gateway_runtime_bridge.go @@ -605,24 +605,27 @@ func (b *gatewayRuntimePortBridge) ListFiles(ctx context.Context, input gateway. return result, nil } -// ListModels 列出可用模型(仅返回当前选中 provider 的模型)。 +// ListModels 列出可用模型;有会话时按会话有效 provider 返回,无会话时按全局默认 provider 返回。 func (b *gatewayRuntimePortBridge) ListModels(ctx context.Context, input gateway.ListModelsInput) ([]gateway.ModelEntry, error) { if err := b.ensureRuntimeAccess(input.SubjectID); err != nil { return nil, err } - - // 复用 ListProviders 的 Selected 标记逻辑,避免与 cfg.SelectedProvider 单独比较产生不一致 - providers, err := b.ListProviders(ctx, gateway.ListProvidersInput{SubjectID: input.SubjectID}) + options, err := b.listProviderOptions(ctx) + if err != nil { + return nil, err + } + providerID, _, err := b.resolveEffectiveProviderModel(ctx, strings.TrimSpace(input.SessionID), options) if err != nil { return nil, err } models := make([]gateway.ModelEntry, 0) - for _, p := range providers { - if !p.Selected { + for _, option := range options { + optionID := strings.TrimSpace(option.ID) + if providerID != "" && !strings.EqualFold(providerID, optionID) { continue } - for _, model := range p.Models { + for _, model := range option.Models { id := strings.TrimSpace(model.ID) if id == "" { continue @@ -634,7 +637,7 @@ func (b *gatewayRuntimePortBridge) ListModels(ctx context.Context, input gateway models = append(models, gateway.ModelEntry{ ID: id, Name: name, - Provider: strings.TrimSpace(p.ID), + Provider: optionID, CapabilityHints: model.CapabilityHints, }) } @@ -674,27 +677,18 @@ func (b *gatewayRuntimePortBridge) GetSessionModel(ctx context.Context, input ga if err := b.ensureRuntimeAccess(input.SubjectID); err != nil { return gateway.SessionModelResult{}, err } - if b.sessionStore == nil { - return gateway.SessionModelResult{}, fmt.Errorf("gateway runtime bridge: session store is unavailable") - } - session, err := b.loadStoredSession(ctx, strings.TrimSpace(input.SessionID)) + options, err := b.listProviderOptions(ctx) if err != nil { return gateway.SessionModelResult{}, err } - providerID, modelID := strings.TrimSpace(session.Provider), strings.TrimSpace(session.Model) - if providerID == "" || modelID == "" { - cfg := b.currentConfig() - if providerID == "" { - providerID = strings.TrimSpace(cfg.SelectedProvider) - } - if modelID == "" { - modelID = strings.TrimSpace(cfg.CurrentModel) - } + providerID, modelID, err := b.resolveEffectiveProviderModel(ctx, strings.TrimSpace(input.SessionID), options) + if err != nil { + return gateway.SessionModelResult{}, err } return gateway.SessionModelResult{ ProviderID: providerID, ModelID: modelID, - ModelName: b.modelDisplayName(ctx, providerID, modelID), + ModelName: b.modelDisplayNameFromOptions(providerID, modelID, options), Provider: providerID, }, nil } @@ -1412,6 +1406,132 @@ func (b *gatewayRuntimePortBridge) resolveProviderModelForSession( return "", "", fmt.Errorf("gateway runtime bridge: model %q not found", modelID) } +// listProviderOptions 读取当前 provider 选项快照,供模型选择相关逻辑复用。 +func (b *gatewayRuntimePortBridge) listProviderOptions(ctx context.Context) ([]configstate.ProviderOption, error) { + if b.providerSelection == nil { + return nil, fmt.Errorf("gateway runtime bridge: provider selection is unavailable") + } + return b.providerSelection.ListProviderOptions(ctx) +} + +// resolveEffectiveProviderModel 解析当前会话或全局默认的有效 provider/model,不会回写会话状态。 +func (b *gatewayRuntimePortBridge) resolveEffectiveProviderModel( + ctx context.Context, + sessionID string, + options []configstate.ProviderOption, +) (string, string, error) { + sessionID = strings.TrimSpace(sessionID) + sessionProviderID := "" + sessionModelID := "" + if sessionID != "" { + if b.sessionStore == nil { + return "", "", fmt.Errorf("gateway runtime bridge: session store is unavailable") + } + session, err := b.loadStoredSession(ctx, sessionID) + if err != nil { + return "", "", err + } + sessionProviderID = strings.TrimSpace(session.Provider) + sessionModelID = strings.TrimSpace(session.Model) + } + + cfg := b.currentConfig() + defaultProviderID := strings.TrimSpace(cfg.SelectedProvider) + defaultModelID := strings.TrimSpace(cfg.CurrentModel) + + selection, ok := resolveEffectiveProviderModelSelection( + options, + sessionProviderID, + sessionModelID, + defaultProviderID, + defaultModelID, + ) + if !ok { + return "", "", fmt.Errorf("gateway runtime bridge: no available provider/model selection") + } + return selection.ProviderID, selection.ModelID, nil +} + +type effectiveProviderModelSelection struct { + ProviderID string + ModelID string +} + +// resolveEffectiveProviderModelSelection 按“会话优先、全局兜底”规则解析有效 provider/model。 +func resolveEffectiveProviderModelSelection( + options []configstate.ProviderOption, + sessionProviderID string, + sessionModelID string, + defaultProviderID string, + defaultModelID string, +) (effectiveProviderModelSelection, bool) { + findProvider := func(providerID string) *configstate.ProviderOption { + providerID = strings.TrimSpace(providerID) + if providerID == "" { + return nil + } + for i := range options { + if strings.EqualFold(strings.TrimSpace(options[i].ID), providerID) { + return &options[i] + } + } + return nil + } + firstModelID := func(option *configstate.ProviderOption) string { + if option == nil { + return "" + } + for _, model := range option.Models { + if id := strings.TrimSpace(model.ID); id != "" { + return id + } + } + return "" + } + resolveModelID := func(option *configstate.ProviderOption, preferredModelID string) string { + preferredModelID = strings.TrimSpace(preferredModelID) + if option == nil { + return "" + } + if preferredModelID != "" { + for _, model := range option.Models { + if strings.EqualFold(strings.TrimSpace(model.ID), preferredModelID) { + return strings.TrimSpace(model.ID) + } + } + } + return firstModelID(option) + } + firstAvailable := func() (effectiveProviderModelSelection, bool) { + for _, option := range options { + providerID := strings.TrimSpace(option.ID) + modelID := firstModelID(&option) + if providerID != "" && modelID != "" { + return effectiveProviderModelSelection{ProviderID: providerID, ModelID: modelID}, true + } + } + return effectiveProviderModelSelection{}, false + } + + if option := findProvider(sessionProviderID); option != nil { + if modelID := resolveModelID(option, sessionModelID); modelID != "" { + return effectiveProviderModelSelection{ + ProviderID: strings.TrimSpace(option.ID), + ModelID: modelID, + }, true + } + } + if option := findProvider(defaultProviderID); option != nil { + if modelID := resolveModelID(option, defaultModelID); modelID != "" { + return effectiveProviderModelSelection{ + ProviderID: strings.TrimSpace(option.ID), + ModelID: modelID, + }, true + } + } + return firstAvailable() +} + // modelDisplayName 从 provider 候选中查找模型展示名,找不到时回退模型 ID。 func (b *gatewayRuntimePortBridge) modelDisplayName(ctx context.Context, providerID string, modelID string) string { modelID = strings.TrimSpace(modelID) @@ -1438,6 +1558,32 @@ func (b *gatewayRuntimePortBridge) modelDisplayName(ctx context.Context, provide return modelID } +// modelDisplayNameFromOptions 基于 provider 选项快照查找模型展示名,避免重复读取 provider 列表。 +func (b *gatewayRuntimePortBridge) modelDisplayNameFromOptions( + providerID string, + modelID string, + options []configstate.ProviderOption, +) string { + modelID = strings.TrimSpace(modelID) + if modelID == "" { + return modelID + } + for _, option := range options { + if providerID != "" && !strings.EqualFold(strings.TrimSpace(option.ID), strings.TrimSpace(providerID)) { + continue + } + for _, model := range option.Models { + if strings.EqualFold(strings.TrimSpace(model.ID), modelID) { + if name := strings.TrimSpace(model.Name); name != "" { + return name + } + return strings.TrimSpace(model.ID) + } + } + } + return modelID +} + // ReloadConfig 从磁盘重新加载内存配置快照,使管理端口的写入对其他工作区可见。 func (b *gatewayRuntimePortBridge) ReloadConfig(ctx context.Context) error { if b == nil || b.configManager == nil { diff --git a/internal/cli/gateway_runtime_bridge_test.go b/internal/cli/gateway_runtime_bridge_test.go index 8a4b4ef8..94da5148 100644 --- a/internal/cli/gateway_runtime_bridge_test.go +++ b/internal/cli/gateway_runtime_bridge_test.go @@ -1883,7 +1883,12 @@ func TestGatewayRuntimePortBridgeSetSessionModelNotFound(t *testing.T) { func TestGatewayRuntimePortBridgeGetSessionModelStoreNil(t *testing.T) { stub := &runtimeStub{eventsCh: make(chan agentruntime.RuntimeEvent, 1)} - bridge, _ := newGatewayRuntimePortBridge(context.Background(), stub, nil) + ps := &providerSelectionStub{ + listOptions: []configstate.ProviderOption{ + {ID: "openai", Models: []providertypes.ModelDescriptor{{ID: "gpt-4"}}}, + }, + } + bridge, _ := newGatewayRuntimePortBridge(context.Background(), stub, nil, nil, ps) defer bridge.Close() _, err := bridge.GetSessionModel(context.Background(), gateway.GetSessionModelInput{ @@ -1898,7 +1903,12 @@ func TestGatewayRuntimePortBridgeGetSessionModelStoreNil(t *testing.T) { func TestGatewayRuntimePortBridgeGetSessionModelLoadFail(t *testing.T) { store := &bridgeSessionStoreWithLoader{loadErr: errors.New("load failed")} stub := &runtimeStub{eventsCh: make(chan agentruntime.RuntimeEvent, 1)} - bridge, _ := newGatewayRuntimePortBridge(context.Background(), stub, store) + ps := &providerSelectionStub{ + listOptions: []configstate.ProviderOption{ + {ID: "openai", Models: []providertypes.ModelDescriptor{{ID: "gpt-4"}}}, + }, + } + bridge, _ := newGatewayRuntimePortBridge(context.Background(), stub, store, nil, ps) defer bridge.Close() _, err := bridge.GetSessionModel(context.Background(), gateway.GetSessionModelInput{ @@ -1942,7 +1952,11 @@ func TestGatewayRuntimePortBridgeGetSessionModelDisplayNameNotFound(t *testing.T store := &bridgeSessionStoreWithLoader{ session: agentsession.Session{ID: "s-1", Provider: "openai", Model: "gpt-4"}, } - ps := &providerSelectionStub{listOptions: []configstate.ProviderOption{}} + ps := &providerSelectionStub{ + listOptions: []configstate.ProviderOption{ + {ID: "openai", Models: []providertypes.ModelDescriptor{{ID: "gpt-4", Name: ""}}}, + }, + } stub := &runtimeStub{eventsCh: make(chan agentruntime.RuntimeEvent, 1)} bridge, _ := newGatewayRuntimePortBridge(context.Background(), stub, store, nil, ps) defer bridge.Close() @@ -2187,6 +2201,124 @@ func TestResolveProviderModelForSession(t *testing.T) { }) } +func TestResolveEffectiveProviderModelSelection(t *testing.T) { + t.Run("session provider wins and falls back within provider", func(t *testing.T) { + selection, ok := resolveEffectiveProviderModelSelection( + []configstate.ProviderOption{ + {ID: "openai", Models: []providertypes.ModelDescriptor{{ID: "gpt-4.1"}, {ID: "gpt-4o"}}}, + {ID: "gemini", Models: []providertypes.ModelDescriptor{{ID: "gemini-2.5-pro"}}}, + }, + "openai", + "missing-model", + "gemini", + "gemini-2.5-pro", + ) + if !ok { + t.Fatal("expected effective selection") + } + if selection.ProviderID != "openai" || selection.ModelID != "gpt-4.1" { + t.Fatalf("selection = %+v, want openai/gpt-4.1", selection) + } + }) + + t.Run("invalid session provider falls back to global default", func(t *testing.T) { + selection, ok := resolveEffectiveProviderModelSelection( + []configstate.ProviderOption{ + {ID: "openai", Models: []providertypes.ModelDescriptor{{ID: "gpt-4.1"}}}, + {ID: "gemini", Models: []providertypes.ModelDescriptor{{ID: "gemini-2.5-pro"}}}, + }, + "unknown", + "unknown-model", + "gemini", + "gemini-2.5-pro", + ) + if !ok { + t.Fatal("expected effective selection") + } + if selection.ProviderID != "gemini" || selection.ModelID != "gemini-2.5-pro" { + t.Fatalf("selection = %+v, want gemini/gemini-2.5-pro", selection) + } + }) +} + +func TestGatewayRuntimePortBridgeListModelsUsesSessionProvider(t *testing.T) { + store := &bridgeSessionStoreWithLoader{ + bridgeSessionStoreStub: bridgeSessionStoreStub{}, + session: agentsession.Session{ + ID: "session-1", + Provider: "openai", + Model: "gpt-4.1", + }, + } + cfgMgr := &configManagerStub{ + cfg: config.Config{ + SelectedProvider: "gemini", + CurrentModel: "gemini-2.5-pro", + }, + } + ps := &providerSelectionStub{ + listOptions: []configstate.ProviderOption{ + {ID: "openai", Models: []providertypes.ModelDescriptor{{ID: "gpt-4.1", Name: "GPT-4.1"}}}, + {ID: "gemini", Models: []providertypes.ModelDescriptor{{ID: "gemini-2.5-pro", Name: "Gemini 2.5 Pro"}}}, + }, + } + bridge, _ := newGatewayRuntimePortBridge(context.Background(), &runtimeStub{eventsCh: make(chan agentruntime.RuntimeEvent, 1)}, store, cfgMgr, ps) + defer bridge.Close() + + models, err := bridge.ListModels(context.Background(), gateway.ListModelsInput{ + SubjectID: testBridgeSubjectID, + SessionID: "session-1", + }) + if err != nil { + t.Fatalf("ListModels() error = %v", err) + } + if len(models) != 1 { + t.Fatalf("models len = %d, want 1", len(models)) + } + if models[0].Provider != "openai" || models[0].ID != "gpt-4.1" { + t.Fatalf("models = %+v, want openai/gpt-4.1 only", models) + } +} + +func TestGatewayRuntimePortBridgeGetSessionModelFallsBackToEffectiveSelection(t *testing.T) { + store := &bridgeSessionStoreWithLoader{ + bridgeSessionStoreStub: bridgeSessionStoreStub{}, + session: agentsession.Session{ + ID: "session-1", + Provider: "openai", + Model: "missing-model", + }, + } + cfgMgr := &configManagerStub{ + cfg: config.Config{ + SelectedProvider: "gemini", + CurrentModel: "gemini-2.5-pro", + }, + } + ps := &providerSelectionStub{ + listOptions: []configstate.ProviderOption{ + {ID: "openai", Models: []providertypes.ModelDescriptor{{ID: "gpt-4.1", Name: "GPT-4.1"}}}, + {ID: "gemini", Models: []providertypes.ModelDescriptor{{ID: "gemini-2.5-pro", Name: "Gemini 2.5 Pro"}}}, + }, + } + bridge, _ := newGatewayRuntimePortBridge(context.Background(), &runtimeStub{eventsCh: make(chan agentruntime.RuntimeEvent, 1)}, store, cfgMgr, ps) + defer bridge.Close() + + result, err := bridge.GetSessionModel(context.Background(), gateway.GetSessionModelInput{ + SubjectID: testBridgeSubjectID, + SessionID: "session-1", + }) + if err != nil { + t.Fatalf("GetSessionModel() error = %v", err) + } + if result.ProviderID != "openai" || result.ModelID != "gpt-4.1" { + t.Fatalf("result = %+v, want openai/gpt-4.1", result) + } + if result.ModelName != "GPT-4.1" { + t.Fatalf("model name = %q, want %q", result.ModelName, "GPT-4.1") + } +} + func TestModelDisplayName(t *testing.T) { t.Run("provider filter", func(t *testing.T) { ps := &providerSelectionStub{ diff --git a/internal/cli/web_command.go b/internal/cli/web_command.go index 4164eac9..39fcee87 100644 --- a/internal/cli/web_command.go +++ b/internal/cli/web_command.go @@ -21,14 +21,20 @@ import ( "neo-code/internal/webassets" ) +var ( + webCommandStartGatewayServer = startGatewayServer + webCommandBuildFrontend = buildFrontend + webCommandLookPath = exec.LookPath +) + type webCommandOptions struct { - HTTPAddress string - LogLevel string - StaticDir string - OpenBrowser bool - SkipBuild bool - Workdir string - TokenFile string + HTTPAddress string + LogLevel string + StaticDir string + OpenBrowser bool + SkipBuild bool + Workdir string + TokenFile string } // newWebCommand 创建并返回根命令下的 web 子命令,负责构建前端并启动带 Web UI 的 Gateway。 @@ -36,9 +42,9 @@ func newWebCommand() *cobra.Command { options := &webCommandOptions{} cmd := &cobra.Command{ - Use: "web", - Short: "Start NeoCode with Web UI", - Long: "Build frontend assets (if needed) and start the gateway with an integrated web UI.\nOpen http://127.0.0.1:8080 in your browser to use the interactive coding agent.", + Use: "web", + Short: "Start NeoCode with Web UI", + Long: "Build frontend assets (if needed) and start the gateway with an integrated web UI.\nOpen http://127.0.0.1:8080 in your browser to use the interactive coding agent.", SilenceUsage: true, Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { @@ -88,12 +94,17 @@ func runWebCommand(ctx context.Context, options webCommandOptions) error { webDir := findWebSourceDir() if webDir == "" { if err != nil { - return fmt.Errorf("frontend not found: %w", err) + return fmt.Errorf( + "frontend assets unavailable: %w; release packages must include the web/ source directory, or source builds must run from the project root or use --static-dir", + err, + ) } - return fmt.Errorf("web source directory not found; run from project root or set --static-dir") + return fmt.Errorf( + "web source directory not found; release packages must include web/, or source builds must run from project root or set --static-dir", + ) } - if buildErr := buildFrontend(webDir, logger); buildErr != nil { - return fmt.Errorf("frontend build failed: %w", buildErr) + if buildErr := webCommandBuildFrontend(webDir, logger); buildErr != nil { + return fmt.Errorf("frontend build failed on this machine after detecting bundled web source: %w", buildErr) } // 构建后重新解析 staticDir, err = resolveWebStaticDir(options.StaticDir) @@ -131,7 +142,7 @@ func runWebCommand(ctx context.Context, options webCommandOptions) error { } } - return startGatewayServer(ctx, gatewayOpts, staticDir, staticFileFS, onNetworkReady) + return webCommandStartGatewayServer(ctx, gatewayOpts, staticDir, staticFileFS, onNetworkReady) } // resolveWebStaticDir 按 --static-dir → /web/dist → /web/dist 顺序查找前端静态文件。 @@ -146,7 +157,7 @@ func resolveWebStaticDir(override string) (string, error) { } // 相对于可执行文件(适用于安装的二进制) - if exe, err := os.Executable(); err == nil { + if exe, err := resolveExecutablePath(); err == nil { exeDir := filepath.Dir(exe) if dir, err := validateStaticDir(filepath.Join(exeDir, "web", "dist")); err == nil { return dir, nil @@ -180,7 +191,7 @@ func findWebSourceDir() string { candidates := []string{ filepath.Join(".", "web"), } - if exe, err := os.Executable(); err == nil { + if exe, err := resolveExecutablePath(); err == nil { exeDir := filepath.Dir(exe) candidates = append(candidates, filepath.Join(exeDir, "web"), @@ -270,9 +281,11 @@ func findNPMBinary() (string, error) { if runtime.GOOS == "windows" { name = "npm.cmd" } - path, err := exec.LookPath(name) + path, err := webCommandLookPath(name) if err != nil { - return "", fmt.Errorf("npm not found on PATH; install Node.js to build the frontend, or use --static-dir to specify pre-built assets") + return "", fmt.Errorf( + "npm not found on PATH; install Node.js and npm on this machine so `neocode web` can build the bundled frontend automatically, or use --static-dir to specify pre-built assets", + ) } return path, nil } diff --git a/internal/cli/web_command_test.go b/internal/cli/web_command_test.go new file mode 100644 index 00000000..0f2cde78 --- /dev/null +++ b/internal/cli/web_command_test.go @@ -0,0 +1,188 @@ +package cli + +import ( + "context" + "errors" + "io/fs" + "log" + "os" + "path/filepath" + "strings" + "testing" +) + +// writeWebCommandTestFile 写入 web 命令测试所需的最小文件内容,避免各测试重复拼装目录。 +func writeWebCommandTestFile(t *testing.T, path, content string) { + t.Helper() + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + t.Fatalf("mkdir %s: %v", path, err) + } + if err := os.WriteFile(path, []byte(content), 0o644); err != nil { + t.Fatalf("write %s: %v", path, err) + } +} + +// chdirForWebCommandTest 切换当前工作目录,并在测试结束后恢复。 +func chdirForWebCommandTest(t *testing.T, dir string) { + t.Helper() + original, err := os.Getwd() + if err != nil { + t.Fatalf("getwd: %v", err) + } + if err := os.Chdir(dir); err != nil { + t.Fatalf("chdir %s: %v", dir, err) + } + t.Cleanup(func() { + if err := os.Chdir(original); err != nil { + t.Fatalf("restore cwd: %v", err) + } + }) +} + +// stubResolveExecutablePath 替换可执行文件路径解析,便于覆盖发布包布局分支。 +func stubResolveExecutablePath(t *testing.T, fn func() (string, error)) { + t.Helper() + original := resolveExecutablePath + resolveExecutablePath = fn + t.Cleanup(func() { + resolveExecutablePath = original + }) +} + +// stubWebCommandHooks 替换 web 命令测试中的可注入执行点,并在结束后恢复。 +func stubWebCommandHooks( + t *testing.T, + startGateway func(context.Context, gatewayCommandOptions, string, fs.FS, func(string)) error, + build func(string, *log.Logger) error, + lookPath func(string) (string, error), +) { + t.Helper() + originalStart := webCommandStartGatewayServer + originalBuild := webCommandBuildFrontend + originalLookPath := webCommandLookPath + if startGateway != nil { + webCommandStartGatewayServer = startGateway + } + if build != nil { + webCommandBuildFrontend = build + } + if lookPath != nil { + webCommandLookPath = lookPath + } + t.Cleanup(func() { + webCommandStartGatewayServer = originalStart + webCommandBuildFrontend = originalBuild + webCommandLookPath = originalLookPath + }) +} + +func TestFindWebSourceDirUsesCurrentWorkdir(t *testing.T) { + tempDir := t.TempDir() + chdirForWebCommandTest(t, tempDir) + stubResolveExecutablePath(t, func() (string, error) { + return "", errors.New("skip executable lookup") + }) + + writeWebCommandTestFile(t, filepath.Join(tempDir, "web", "package.json"), "{}") + + got := findWebSourceDir() + want := filepath.Join(tempDir, "web") + if got != want { + t.Fatalf("findWebSourceDir() = %q, want %q", got, want) + } +} + +func TestFindWebSourceDirFallsBackToExecutableDir(t *testing.T) { + tempDir := t.TempDir() + chdirForWebCommandTest(t, tempDir) + + releaseDir := filepath.Join(tempDir, "release") + writeWebCommandTestFile(t, filepath.Join(releaseDir, "web", "package.json"), "{}") + stubResolveExecutablePath(t, func() (string, error) { + return filepath.Join(releaseDir, "neocode.exe"), nil + }) + + got := findWebSourceDir() + want := filepath.Join(releaseDir, "web") + if got != want { + t.Fatalf("findWebSourceDir() = %q, want %q", got, want) + } +} + +func TestResolveWebStaticDirFallsBackToExecutableDir(t *testing.T) { + tempDir := t.TempDir() + chdirForWebCommandTest(t, tempDir) + + releaseDir := filepath.Join(tempDir, "release") + writeWebCommandTestFile(t, filepath.Join(releaseDir, "web", "dist", "index.html"), "") + stubResolveExecutablePath(t, func() (string, error) { + return filepath.Join(releaseDir, "neocode.exe"), nil + }) + + got, err := resolveWebStaticDir("") + if err != nil { + t.Fatalf("resolveWebStaticDir returned error: %v", err) + } + want := filepath.Join(releaseDir, "web", "dist") + if got != want { + t.Fatalf("resolveWebStaticDir() = %q, want %q", got, want) + } +} + +func TestFindNPMBinaryMissingMessage(t *testing.T) { + stubWebCommandHooks(t, nil, nil, func(string) (string, error) { + return "", errors.New("not found") + }) + + _, err := findNPMBinary() + if err == nil { + t.Fatal("findNPMBinary() error = nil, want error") + } + message := err.Error() + if !strings.Contains(message, "Node.js and npm") { + t.Fatalf("findNPMBinary() error = %q, want Node.js/npm guidance", message) + } + if !strings.Contains(message, "`neocode web`") { + t.Fatalf("findNPMBinary() error = %q, want neocode web guidance", message) + } +} + +func TestRunWebCommandBuildsFrontendWhenDistMissing(t *testing.T) { + tempDir := t.TempDir() + chdirForWebCommandTest(t, tempDir) + writeWebCommandTestFile(t, filepath.Join(tempDir, "web", "package.json"), "{}") + + buildCalled := false + var capturedStaticDir string + sentinelErr := errors.New("stop after start") + stubWebCommandHooks( + t, + func(_ context.Context, _ gatewayCommandOptions, staticDir string, _ fs.FS, _ func(string)) error { + capturedStaticDir = staticDir + return sentinelErr + }, + func(webDir string, _ *log.Logger) error { + buildCalled = true + writeWebCommandTestFile(t, filepath.Join(webDir, "dist", "index.html"), "") + return nil + }, + nil, + ) + + err := runWebCommand(context.Background(), webCommandOptions{ + HTTPAddress: "127.0.0.1:8080", + LogLevel: "info", + OpenBrowser: false, + Workdir: tempDir, + }) + if !errors.Is(err, sentinelErr) { + t.Fatalf("runWebCommand() error = %v, want sentinel error %v", err, sentinelErr) + } + if !buildCalled { + t.Fatal("runWebCommand() did not invoke frontend build when dist was missing") + } + wantStaticDir := filepath.Join(tempDir, "web", "dist") + if capturedStaticDir != wantStaticDir { + t.Fatalf("startGatewayServer staticDir = %q, want %q", capturedStaticDir, wantStaticDir) + } +} diff --git a/internal/gateway/bootstrap.go b/internal/gateway/bootstrap.go index 780cca73..d9538d84 100644 --- a/internal/gateway/bootstrap.go +++ b/internal/gateway/bootstrap.go @@ -905,8 +905,9 @@ func handleListModelsFrame(ctx context.Context, frame MessageFrame, runtimePort RequestID: frame.RequestID, SessionID: strings.TrimSpace(frame.SessionID), Payload: map[string]any{ - "models": models, - "selected_model_id": sessionModel.ModelID, + "models": models, + "selected_provider_id": sessionModel.ProviderID, + "selected_model_id": sessionModel.ModelID, }, } } @@ -2122,8 +2123,8 @@ func decodeCheckpointDiffPayload(payload any) CheckpointDiffInput { var decoded struct { SessionID string `json:"session_id"` CheckpointID string `json:"checkpoint_id"` - Scope string `json:"scope"` RunID string `json:"run_id"` + Scope string `json:"scope"` } _ = json.Unmarshal(raw, &decoded) return CheckpointDiffInput{ diff --git a/internal/gateway/bootstrap_test.go b/internal/gateway/bootstrap_test.go index 67d71ddb..4312ee4b 100644 --- a/internal/gateway/bootstrap_test.go +++ b/internal/gateway/bootstrap_test.go @@ -623,16 +623,20 @@ func TestDecodeCheckpointDiffPayloadBranches(t *testing.T) { params := decodeCheckpointDiffPayload(map[string]any{ "session_id": " session-1 ", "checkpoint_id": " cp-1 ", + "run_id": " run-1 ", + "scope": " run ", }) - if params.SessionID != "session-1" || params.CheckpointID != "cp-1" { + if params.SessionID != "session-1" || params.CheckpointID != "cp-1" || params.RunID != "run-1" || params.Scope != "run" { t.Fatalf("decode map payload = %#v", params) } params = decodeCheckpointDiffPayload(CheckpointDiffInput{ SessionID: "session-2", CheckpointID: "cp-2", + RunID: "run-2", + Scope: "run", }) - if params.SessionID != "session-2" || params.CheckpointID != "cp-2" { + if params.SessionID != "session-2" || params.CheckpointID != "cp-2" || params.RunID != "run-2" || params.Scope != "run" { t.Fatalf("decode struct payload = %#v", params) } diff --git a/internal/gateway/protocol/jsonrpc_test.go b/internal/gateway/protocol/jsonrpc_test.go index 7bfa80b6..67f155f2 100644 --- a/internal/gateway/protocol/jsonrpc_test.go +++ b/internal/gateway/protocol/jsonrpc_test.go @@ -273,7 +273,7 @@ func TestNormalizeJSONRPCRequestCheckpointMethods(t *testing.T) { JSONRPC: JSONRPCVersion, ID: json.RawMessage(`"checkpoint-diff-1"`), Method: MethodGatewayCheckpointDiff, - Params: json.RawMessage(`{"session_id":" s-1 ","checkpoint_id":" cp-1 "}`), + Params: json.RawMessage(`{"session_id":" s-1 ","checkpoint_id":" cp-1 ","run_id":" run-1 ","scope":" run "}`), }) if rpcErr != nil { t.Fatalf("normalize checkpoint.diff request: %v", rpcErr) @@ -285,7 +285,7 @@ func TestNormalizeJSONRPCRequestCheckpointMethods(t *testing.T) { if !ok { t.Fatalf("payload type = %T, want CheckpointDiffParams", normalized.Payload) } - if params.SessionID != "s-1" || params.CheckpointID != "cp-1" { + if params.SessionID != "s-1" || params.CheckpointID != "cp-1" || params.RunID != "run-1" || params.Scope != "run" { t.Fatalf("checkpoint.diff params = %#v, want trimmed params", params) } }) diff --git a/internal/runtime/checkpoint_flow_test.go b/internal/runtime/checkpoint_flow_test.go index 8af6718c..3075ac75 100644 --- a/internal/runtime/checkpoint_flow_test.go +++ b/internal/runtime/checkpoint_flow_test.go @@ -681,8 +681,8 @@ func TestCheckpointDiffSelectsLatestCodeCheckpointAndRejectsSessionOnlyTarget(t if err := os.WriteFile(target, []byte("two\n"), 0o644); err != nil { t.Fatalf("WriteFile(cp1 next) error = %v", err) } - if _, err := perEditStore.Finalize("cp-1"); err != nil { - t.Fatalf("Finalize(cp-1) error = %v", err) + if _, err := perEditStore.FinalizeWithExactState("cp-1"); err != nil { + t.Fatalf("FinalizeWithExactState(cp-1) error = %v", err) } perEditStore.Reset() @@ -692,11 +692,15 @@ func TestCheckpointDiffSelectsLatestCodeCheckpointAndRejectsSessionOnlyTarget(t if err := os.WriteFile(target, []byte("three\n"), 0o644); err != nil { t.Fatalf("WriteFile(cp2 next) error = %v", err) } - if _, err := perEditStore.Finalize("cp-2"); err != nil { - t.Fatalf("Finalize(cp-2) error = %v", err) + if _, err := perEditStore.FinalizeWithExactState("cp-2"); err != nil { + t.Fatalf("FinalizeWithExactState(cp-2) error = %v", err) } perEditStore.Reset() + if err := os.WriteFile(target, []byte("four\n"), 0o644); err != nil { + t.Fatalf("WriteFile(four) error = %v", err) + } + spy := &checkpointStoreSpy{ listRecords: []agentsession.CheckpointRecord{ { @@ -847,6 +851,87 @@ func TestFindPreviousEndOfTurnCheckpoint_SkipsNonPerEditRef(t *testing.T) { } } +func TestCheckpointDiffRunScopeAggregatesCurrentRun(t *testing.T) { + now := time.Now().UTC() + workdir := t.TempDir() + projectDir := t.TempDir() + perEditStore := checkpoint.NewPerEditSnapshotStore(projectDir, workdir) + target := filepath.Join(workdir, "tracked.txt") + if err := os.WriteFile(target, []byte("one\n"), 0o644); err != nil { + t.Fatalf("WriteFile(base) error = %v", err) + } + if _, err := perEditStore.CapturePreWrite(target); err != nil { + t.Fatalf("CapturePreWrite(cp1) error = %v", err) + } + if err := os.WriteFile(target, []byte("two\n"), 0o644); err != nil { + t.Fatalf("WriteFile(two) error = %v", err) + } + if _, err := perEditStore.Finalize("cp-1"); err != nil { + t.Fatalf("Finalize(cp-1) error = %v", err) + } + perEditStore.Reset() + + if _, err := perEditStore.CapturePreWrite(target); err != nil { + t.Fatalf("CapturePreWrite(cp2) error = %v", err) + } + if err := os.WriteFile(target, []byte("three\n"), 0o644); err != nil { + t.Fatalf("WriteFile(three) error = %v", err) + } + if _, err := perEditStore.Finalize("cp-2"); err != nil { + t.Fatalf("Finalize(cp-2) error = %v", err) + } + perEditStore.Reset() + + spy := &checkpointStoreSpy{ + listRecords: []agentsession.CheckpointRecord{ + { + CheckpointID: "cp-other", + SessionID: "session-1", + RunID: "run-other", + CreatedAt: now.Add(3 * time.Second), + CodeCheckpointRef: checkpoint.RefForPerEditCheckpoint("cp-other"), + }, + { + CheckpointID: "cp-2", + SessionID: "session-1", + RunID: "run-1", + CreatedAt: now.Add(2 * time.Second), + CodeCheckpointRef: checkpoint.RefForPerEditCheckpoint("cp-2"), + }, + { + CheckpointID: "cp-1", + SessionID: "session-1", + RunID: "run-1", + CreatedAt: now.Add(time.Second), + CodeCheckpointRef: checkpoint.RefForPerEditCheckpoint("cp-1"), + }, + }, + } + service := &Service{ + checkpointStore: spy, + perEditStore: perEditStore, + } + + result, err := service.CheckpointDiff(context.Background(), CheckpointDiffInput{ + SessionID: "session-1", + CheckpointID: "cp-2", + RunID: "run-1", + Scope: "run", + }) + if err != nil { + t.Fatalf("CheckpointDiff(run) error = %v", err) + } + if result.CheckpointID != "cp-2" { + t.Fatalf("CheckpointID = %q, want cp-2", result.CheckpointID) + } + if len(result.Files.Modified) != 1 || result.Files.Modified[0] != "tracked.txt" { + t.Fatalf("modified files = %+v, want tracked.txt", result.Files.Modified) + } + if !strings.Contains(result.Patch, "-one") || !strings.Contains(result.Patch, "+three") || strings.Contains(result.Patch, "-two") || strings.Contains(result.Patch, "+four") { + t.Fatalf("run patch should compare one to three only, got:\n%s", result.Patch) + } +} + func mustReadRuntimeFile(t *testing.T, path string) []byte { t.Helper() data, err := os.ReadFile(path) @@ -941,9 +1026,9 @@ func TestCheckpointDiff_ScopeRun_ReturnsAggregateDiff(t *testing.T) { if len(modifiedPaths) != 1 || modifiedPaths[0] != "a.txt" { t.Fatalf("expected a.txt modified, got added=%v modified=%v", addedPaths, modifiedPaths) } - // CheckpointID should be set to the first checkpoint in the run - if result.CheckpointID != "cp-1" { - t.Fatalf("CheckpointID = %q, want cp-1", result.CheckpointID) + // 当前 run-scope diff 默认返回目标 checkpoint(未显式指定时为最新 checkpoint)。 + if result.CheckpointID != "cp-2" { + t.Fatalf("CheckpointID = %q, want cp-2", result.CheckpointID) } } @@ -987,8 +1072,8 @@ func TestCheckpointDiff_ScopeRun_NoCheckpointsForRunID(t *testing.T) { if err == nil { t.Fatal("expected error for run_id with no code checkpoints") } - if !strings.Contains(err.Error(), "no code checkpoints found") { - t.Fatalf("error = %v, want 'no code checkpoints found'", err) + if !strings.Contains(err.Error(), "no code checkpoint found") { + t.Fatalf("error = %v, want 'no code checkpoint found'", err) } } diff --git a/internal/runtime/checkpoint_gate.go b/internal/runtime/checkpoint_gate.go index 10050a40..bae5f42f 100644 --- a/internal/runtime/checkpoint_gate.go +++ b/internal/runtime/checkpoint_gate.go @@ -55,7 +55,7 @@ func (s *Service) createEndOfTurnCheckpoint(ctx context.Context, state *runState state.mu.Unlock() checkpointID := agentsession.NewID("checkpoint") - written, err := s.perEditStore.Finalize(checkpointID) + written, err := s.perEditStore.FinalizeWithExactState(checkpointID) if err != nil { log.Printf("checkpoint: end-of-turn finalize: %v", err) return diff --git a/internal/runtime/checkpoint_restore.go b/internal/runtime/checkpoint_restore.go index 37e39723..abdc7217 100644 --- a/internal/runtime/checkpoint_restore.go +++ b/internal/runtime/checkpoint_restore.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "sort" "strings" "time" @@ -70,8 +71,7 @@ func (s *Service) restoreCheckpointCore(ctx context.Context, sessionID, checkpoi records, listErr := s.checkpointStore.ListCheckpoints(ctx, sessionID, checkpoint.ListCheckpointOpts{Limit: 5}) if listErr == nil { for _, r := range records { - if r.Reason == agentsession.CheckpointReasonEndOfTurn && - checkpoint.IsPerEditRef(r.CodeCheckpointRef) { + if r.Reason == agentsession.CheckpointReasonEndOfTurn && checkpoint.IsPerEditRef(r.CodeCheckpointRef) { fallbackRef = r.CodeCheckpointRef break } @@ -338,11 +338,11 @@ func (s *Service) CheckpointDiff(ctx context.Context, input CheckpointDiffInput) return CheckpointDiffResult{}, fmt.Errorf("checkpoint: session_id required") } - if strings.TrimSpace(input.Scope) == "run" { - return s.runDiff(ctx, sessionID, strings.TrimSpace(input.RunID)) + if strings.EqualFold(strings.TrimSpace(input.Scope), "run") { + return s.checkpointDiffForRun(ctx, input, sessionID) } - records, err := s.checkpointStore.ListCheckpoints(ctx, sessionID, checkpoint.ListCheckpointOpts{Limit: 50}) + records, err := s.checkpointStore.ListCheckpoints(ctx, sessionID, checkpoint.ListCheckpointOpts{Limit: 20}) if err != nil { return CheckpointDiffResult{}, fmt.Errorf("checkpoint: list for diff: %w", err) } @@ -425,68 +425,74 @@ func (s *Service) CheckpointDiff(ctx context.Context, input CheckpointDiffInput) return result, nil } -// runDiff 按 run_id 收集该 run 内所有 per-edit checkpoint, -// 以每个文件首次被触碰前的精确版本(v1.bin)作为 baseline, -// 与当前 workdir 状态作端到端对比。 -func (s *Service) runDiff(ctx context.Context, sessionID, runID string) (CheckpointDiffResult, error) { - if runID == "" { - return CheckpointDiffResult{}, fmt.Errorf("checkpoint: run_id required for scope=run") - } - - records, err := s.checkpointStore.ListCheckpoints(ctx, sessionID, checkpoint.ListCheckpointOpts{ - RunID: runID, - }) +// checkpointDiffForRun 汇总指定 run 内的代码 checkpoint,返回本次请求初始状态到当前工作区的净变更。 +func (s *Service) checkpointDiffForRun(ctx context.Context, input CheckpointDiffInput, sessionID string) (CheckpointDiffResult, error) { + records, err := s.checkpointStore.ListCheckpoints(ctx, sessionID, checkpoint.ListCheckpointOpts{}) if err != nil { return CheckpointDiffResult{}, fmt.Errorf("checkpoint: list for run diff: %w", err) } - var ( - firstPerEditCheckpointID string - perEditIDs []string - ) - // ListCheckpoints 返回 DESC 顺序(最新在前),因此倒序遍历以获取最早的 per-edit checkpoint。 - for i := len(records) - 1; i >= 0; i-- { - r := records[i] - if r.RunID != runID { + targetID := strings.TrimSpace(input.CheckpointID) + runID := strings.TrimSpace(input.RunID) + var targetRecord *agentsession.CheckpointRecord + if targetID != "" { + for i := range records { + if records[i].CheckpointID == targetID { + targetRecord = &records[i] + break + } + } + if targetRecord == nil || !checkpoint.IsPerEditRef(targetRecord.CodeCheckpointRef) { + return CheckpointDiffResult{}, fmt.Errorf("checkpoint: %s not found or has no code snapshot", targetID) + } + if runID == "" { + runID = strings.TrimSpace(targetRecord.RunID) + } + } + if runID == "" { + return CheckpointDiffResult{}, fmt.Errorf("checkpoint: run_id required for run scope diff") + } + + codeRecords := make([]agentsession.CheckpointRecord, 0) + for _, record := range records { + if strings.TrimSpace(record.RunID) != runID { continue } - if !checkpoint.IsPerEditRef(r.CodeCheckpointRef) { + if !checkpoint.IsPerEditRef(record.CodeCheckpointRef) { continue } - perEditID := checkpoint.PerEditCheckpointIDFromRef(r.CodeCheckpointRef) - perEditIDs = append(perEditIDs, perEditID) - if firstPerEditCheckpointID == "" { - firstPerEditCheckpointID = r.CheckpointID + if record.Reason == agentsession.CheckpointReasonGuard { + continue } + if targetRecord != nil && record.CreatedAt.After(targetRecord.CreatedAt) { + continue + } + codeRecords = append(codeRecords, record) } - - if len(perEditIDs) == 0 { - return CheckpointDiffResult{}, fmt.Errorf("checkpoint: no code checkpoints found for run_id %s", runID) + if len(codeRecords) == 0 { + return CheckpointDiffResult{}, fmt.Errorf("checkpoint: no code checkpoint found for run %s", runID) + } + sort.Slice(codeRecords, func(i, j int) bool { + return codeRecords[i].CreatedAt.Before(codeRecords[j].CreatedAt) + }) + if targetRecord == nil { + targetRecord = &codeRecords[len(codeRecords)-1] } - // 查找上一个 run 最后一个 checkpoint 的 FileVersions,用于版本号比较过滤历史文件。 - var prevFileVersions map[string]int - if allRecords, listErr := s.checkpointStore.ListCheckpoints(ctx, sessionID, checkpoint.ListCheckpointOpts{}); listErr == nil { - for _, r := range allRecords { - if r.RunID != runID && checkpoint.IsPerEditRef(r.CodeCheckpointRef) { - prevPerEditID := checkpoint.PerEditCheckpointIDFromRef(r.CodeCheckpointRef) - if fv, fvErr := s.perEditStore.GetCheckpointFileVersions(prevPerEditID); fvErr == nil { - prevFileVersions = fv - } - break - } + perEditIDs := make([]string, 0, len(codeRecords)) + for _, record := range codeRecords { + perEditID := checkpoint.PerEditCheckpointIDFromRef(record.CodeCheckpointRef) + if perEditID != "" { + perEditIDs = append(perEditIDs, perEditID) } } - - patch, changes, err := s.perEditStore.RunAggregateDiff(ctx, perEditIDs, prevFileVersions) + targetPerEditID := checkpoint.PerEditCheckpointIDFromRef(targetRecord.CodeCheckpointRef) + patch, changes, err := s.perEditStore.DiffCheckpointsToCheckpoint(ctx, perEditIDs, targetPerEditID) if err != nil { - return CheckpointDiffResult{}, fmt.Errorf("checkpoint: run aggregate diff: %w", err) + return CheckpointDiffResult{}, fmt.Errorf("checkpoint: per-edit run diff: %w", err) } - result := CheckpointDiffResult{ - CheckpointID: firstPerEditCheckpointID, - Patch: patch, - } + result := CheckpointDiffResult{CheckpointID: targetRecord.CheckpointID, Patch: patch} for _, c := range changes { switch c.Kind { case checkpoint.FileChangeAdded: @@ -497,6 +503,5 @@ func (s *Service) runDiff(ctx context.Context, sessionID, runID string) (Checkpo result.Files.Modified = append(result.Files.Modified, c.Path) } } - return result, nil } diff --git a/internal/runtime/compact.go b/internal/runtime/compact.go index 55b001c3..141c88b7 100644 --- a/internal/runtime/compact.go +++ b/internal/runtime/compact.go @@ -286,7 +286,7 @@ func (s *Service) createCompactCheckpoint(ctx context.Context, runID string, ses // Per-edit snapshot if pending writes exist this turn. if s.perEditStore != nil { - if written, err := s.perEditStore.Finalize(checkpointID); err == nil && written { + if written, err := s.perEditStore.FinalizeWithExactState(checkpointID); err == nil && written { record.CodeCheckpointRef = checkpoint.RefForPerEditCheckpoint(checkpointID) s.perEditStore.Reset() } diff --git a/internal/runtime/run.go b/internal/runtime/run.go index 901245a6..2294faf4 100644 --- a/internal/runtime/run.go +++ b/internal/runtime/run.go @@ -104,10 +104,13 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { } if statePtr != nil && s.perEditStore != nil && statePtr.baselineCheckpointID != "" && statePtr.lastEndOfTurnCheckpointID != "" { runEndCtx := context.Background() - records, listErr := s.checkpointStore.ListCheckpoints(runEndCtx, statePtr.session.ID, checkpoint.ListCheckpointOpts{RunID: statePtr.runID}) + records, listErr := s.checkpointStore.ListCheckpoints(runEndCtx, statePtr.session.ID, checkpoint.ListCheckpointOpts{}) if listErr == nil { var perEditIDs []string for _, r := range records { + if strings.TrimSpace(r.RunID) != statePtr.runID { + continue + } if checkpoint.IsPerEditRef(r.CodeCheckpointRef) { perEditIDs = append(perEditIDs, checkpoint.PerEditCheckpointIDFromRef(r.CodeCheckpointRef)) } @@ -564,6 +567,10 @@ func (s *Service) prepareTurnBudgetSnapshot(ctx context.Context, state *runState stage := resolvePlanningStageForState(state) readOnly := isReadOnlyPlanningStage(stage) injectFullPlan := planningNeedsFullPlan(state) + resolvedProvider, model, err := resolveCompactProviderSelection(state.session, cfg) + if err != nil { + return TurnBudgetSnapshot{}, false, err + } builtContext, err := s.contextBuilder.Build(ctx, agentcontext.BuildInput{ Messages: state.session.Messages, @@ -580,8 +587,8 @@ func (s *Service) prepareTurnBudgetSnapshot(ctx context.Context, state *runState ProjectRoot: cfg.Workdir, Workdir: activeWorkdir, Shell: cfg.Shell, - Provider: cfg.SelectedProvider, - Model: cfg.CurrentModel, + Provider: resolvedProvider.Name, + Model: model, SessionInputTokens: state.session.TokenInputTotal, SessionOutputTokens: state.session.TokenOutputTotal, }, @@ -608,10 +615,6 @@ func (s *Service) prepareTurnBudgetSnapshot(ctx context.Context, state *runState } toolSpecs = prioritizeToolSpecsBySkillHints(toolSpecs, activeSkills) - resolvedProvider, model, err := resolveCompactProviderSelection(state.session, cfg) - if err != nil { - return TurnBudgetSnapshot{}, false, err - } providerRuntimeCfg, err := resolvedProvider.ToRuntimeConfig() if err != nil { return TurnBudgetSnapshot{}, false, err @@ -627,7 +630,10 @@ func (s *Service) prepareTurnBudgetSnapshot(ctx context.Context, state *runState if notificationHint := strings.TrimSpace(s.drainHookNotificationsForTurn(state)); notificationHint != "" { systemPrompt = mergeEphemeralHookNotificationIntoSystemPrompt(systemPrompt, notificationHint) } - promptBudget, budgetSource, contextWindow := s.resolvePromptBudget(ctx, cfg) + budgetCfg := cfg + budgetCfg.SelectedProvider = resolvedProvider.Name + budgetCfg.CurrentModel = model + promptBudget, budgetSource, contextWindow := s.resolvePromptBudget(ctx, budgetCfg) requestMessages := append([]providertypes.Message(nil), builtContext.Messages...) thinkingCfg, thinkingErr := resolveThinkingConfig( modelCapabilityHintsForRequest(model, resolvedProvider.Models), diff --git a/internal/runtime/runtime_test.go b/internal/runtime/runtime_test.go index 4dd5b06e..a5608684 100644 --- a/internal/runtime/runtime_test.go +++ b/internal/runtime/runtime_test.go @@ -1555,6 +1555,80 @@ func TestServiceRunDelegatesToContextBuilder(t *testing.T) { } } +func TestServiceRunUsesSessionSelectionForMetadataAndBudget(t *testing.T) { + t.Parallel() + + manager := newRuntimeConfigManager(t) + if err := manager.Update(context.Background(), func(cfg *config.Config) error { + cfg.SelectedProvider = config.GeminiName + cfg.CurrentModel = "gemini-current-model" + return nil + }); err != nil { + t.Fatalf("update config: %v", err) + } + + store := newMemoryStore() + session := agentsession.New("session selection") + session.ID = "session-selection" + session.Provider = config.OpenAIName + session.Model = "openai-session-model" + store.sessions[session.ID] = cloneSession(session) + + registry := tools.NewRegistry() + registry.Register(&stubTool{name: "filesystem_read_file", content: "default"}) + + builder := &stubContextBuilder{ + buildFn: func(ctx context.Context, input agentcontext.BuildInput) (agentcontext.BuildResult, error) { + return agentcontext.BuildResult{ + SystemPrompt: "delegated prompt", + Messages: []providertypes.Message{ + {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("delegated message")}}, + }, + }, nil + }, + } + scripted := &scriptedProvider{ + streams: [][]providertypes.StreamEvent{ + {providertypes.NewTextDeltaStreamEvent("done")}, + }, + } + factory := &scriptedProviderFactory{provider: scripted} + service := NewWithFactory(manager, registry, store, factory, builder) + + var resolvedBudgetCfg config.Config + service.SetBudgetResolver(budgetResolverFunc(func(ctx context.Context, cfg config.Config) (BudgetResolution, error) { + resolvedBudgetCfg = cfg + return BudgetResolution{PromptBudget: 12345, Source: "derived", ContextWindow: 200000}, nil + })) + + if err := service.Run(context.Background(), UserInput{ + SessionID: session.ID, + RunID: "run-session-selection", + Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}, + }); err != nil { + t.Fatalf("Run() error = %v", err) + } + + if builder.lastInput.Metadata.Provider != config.OpenAIName { + t.Fatalf("builder provider = %q, want %q", builder.lastInput.Metadata.Provider, config.OpenAIName) + } + if builder.lastInput.Metadata.Model != "openai-session-model" { + t.Fatalf("builder model = %q, want %q", builder.lastInput.Metadata.Model, "openai-session-model") + } + if resolvedBudgetCfg.SelectedProvider != config.OpenAIName { + t.Fatalf("budget provider = %q, want %q", resolvedBudgetCfg.SelectedProvider, config.OpenAIName) + } + if resolvedBudgetCfg.CurrentModel != "openai-session-model" { + t.Fatalf("budget model = %q, want %q", resolvedBudgetCfg.CurrentModel, "openai-session-model") + } + if len(factory.configs) != 1 || factory.configs[0].Name != config.OpenAIName { + t.Fatalf("factory configs = %+v, want one openai config", factory.configs) + } + if len(scripted.requests) != 1 || scripted.requests[0].Model != "openai-session-model" { + t.Fatalf("requests = %+v, want one openai-session-model request", scripted.requests) + } +} + func TestServiceRunCanDisableMicroCompactViaConfig(t *testing.T) { t.Parallel() diff --git a/web/src/api/protocol.ts b/web/src/api/protocol.ts index 150a556f..d485e7ed 100644 --- a/web/src/api/protocol.ts +++ b/web/src/api/protocol.ts @@ -248,6 +248,8 @@ export interface UndoRestoreParams { export interface CheckpointDiffParams { session_id: string checkpoint_id?: string + run_id?: string + scope?: 'run' | string } /** gateway.resolvePermission 参数 */ @@ -515,7 +517,7 @@ export interface ModelEntry { } /** gateway.listModels 响应 */ -export type ListModelsResult = RPCResult<{ models: ModelEntry[]; selected_model_id?: string }> +export type ListModelsResult = RPCResult<{ models: ModelEntry[]; selected_provider_id?: string; selected_model_id?: string }> /** gateway.setSessionModel 参数 */ export interface SetSessionModelParams { diff --git a/web/src/components/chat/ChatInput.test.tsx b/web/src/components/chat/ChatInput.test.tsx new file mode 100644 index 00000000..f9de008e --- /dev/null +++ b/web/src/components/chat/ChatInput.test.tsx @@ -0,0 +1,49 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { fireEvent, render, screen } from '@testing-library/react' +import ChatInput from './ChatInput' +import { useChatStore } from '@/stores/useChatStore' +import { useComposerStore } from '@/stores/useComposerStore' +import { useSessionStore } from '@/stores/useSessionStore' + +vi.mock('@/context/RuntimeProvider', () => ({ + useGatewayAPI: () => null, +})) + +describe('ChatInput', () => { + beforeEach(() => { + useComposerStore.setState({ composerText: '' }) + useSessionStore.setState({ currentSessionId: '' } as any) + useChatStore.setState({ + isGenerating: false, + messages: [], + permissionRequests: [], + agentMode: 'build', + permissionMode: 'default', + } as any) + }) + + it('shows the default/bypass selector in build mode', () => { + render() + + expect(screen.getByRole('button', { name: 'Build' })).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'default' })).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'bypass' })).toBeInTheDocument() + }) + + it('hides the permission selector after switching to plan mode', () => { + render() + + fireEvent.click(screen.getByRole('button', { name: 'Build' })) + + expect(screen.getByRole('button', { name: 'Plan' })).toBeInTheDocument() + expect(screen.queryByRole('button', { name: 'default' })).not.toBeInTheDocument() + expect(screen.queryByRole('button', { name: 'bypass' })).not.toBeInTheDocument() + }) + + it('does not render the unimplemented attachment and mention buttons', () => { + render() + + expect(screen.queryByTitle('附加文件')).not.toBeInTheDocument() + expect(screen.queryByTitle('引用上下文')).not.toBeInTheDocument() + }) +}) diff --git a/web/src/components/chat/ChatInput.tsx b/web/src/components/chat/ChatInput.tsx index 5e62d9ff..4b9df9d2 100644 --- a/web/src/components/chat/ChatInput.tsx +++ b/web/src/components/chat/ChatInput.tsx @@ -17,7 +17,7 @@ import { } from '@/utils/slashCommands' import SlashCommandMenu from './SlashCommandMenu' import SkillPicker from './SkillPicker' -import { Send, Square, Paperclip, AtSign } from 'lucide-react' +import { Send, Square } from 'lucide-react' /** 聊天输入框 */ export default function ChatInput() { @@ -35,6 +35,8 @@ export default function ChatInput() { const sessionId = useSessionStore((s) => s.currentSessionId) const agentMode = useChatStore((s) => s.agentMode) const setAgentMode = useChatStore((s) => s.setAgentMode) + const permissionMode = useChatStore((s) => s.permissionMode) + const setPermissionMode = useChatStore((s) => s.setPermissionMode) // Slash command 菜单状态 const [showSlashMenu, setShowSlashMenu] = useState(false) @@ -409,12 +411,6 @@ export default function ChatInput() { />
- - + {agentMode === 'build' && ( +
+ + +
+ )}
-
-
- {change.diff?.map((line, index) => ( - - ))} +
+ {displayHunks.length === 0 ? ( +
当前文件没有可展示的 diff
+ ) : ( + displayHunks.map((hunk, index) => ( +
+ {hunk.lines.map((line, lineIndex) => ( + + ))} +
+ )) + )}
)} @@ -182,8 +156,8 @@ export default function FileChangePanel() {
- 文件更改 -
@@ -198,7 +172,7 @@ export default function FileChangePanel() {
{fileChanges.length === 0 ? ( -
当前会话暂无文件更改
+
当前会话暂无文件变更
) : ( fileChanges.map((change) => ( = { +const styles: Record = { container: { display: 'flex', flexDirection: 'column', @@ -365,11 +339,18 @@ const styles: Record = { fontSize: 11, fontFamily: 'var(--font-ui)', }, - diffBlock: { - maxHeight: 260, - overflow: 'auto', + diffScroller: { + overflowX: 'auto', + overflowY: 'visible', background: 'var(--code-bg)', - padding: '6px 0', + padding: 8, + }, + hunkBlock: { + border: '1px solid rgba(148, 163, 184, 0.18)', + borderRadius: 'var(--radius-sm)', + overflow: 'hidden', + background: 'rgba(15, 23, 42, 0.22)', + marginBottom: 10, }, diffLine: { display: 'flex', @@ -389,4 +370,9 @@ const styles: Record = { diffText: { whiteSpace: 'pre', }, + emptyDiff: { + padding: '10px 12px', + fontSize: 11, + color: 'var(--text-tertiary)', + }, } diff --git a/web/src/stores/useChatStore.test.ts b/web/src/stores/useChatStore.test.ts index 35ca432b..6325e9d3 100644 --- a/web/src/stores/useChatStore.test.ts +++ b/web/src/stores/useChatStore.test.ts @@ -6,10 +6,14 @@ beforeEach(() => { messages: [], isGenerating: false, streamingMessageId: '', + streamingThinkingMessageId: '', permissionRequests: [], tokenUsage: null, phase: '', stopReason: '', + isTransitioning: false, + agentMode: 'build', + permissionMode: 'default', } as any) }) @@ -118,4 +122,20 @@ describe('useChatStore', () => { useChatStore.getState().setGenerating(false) expect(useChatStore.getState().isGenerating).toBe(false) }) + + it('starts with default permission mode', () => { + expect(useChatStore.getState().permissionMode).toBe('default') + }) + + it('setPermissionMode updates the permission mode', () => { + useChatStore.getState().setPermissionMode('bypass') + expect(useChatStore.getState().permissionMode).toBe('bypass') + }) + + it('clearMessages resets permission mode to default', () => { + const store = useChatStore.getState() + store.setPermissionMode('bypass') + store.clearMessages() + expect(useChatStore.getState().permissionMode).toBe('default') + }) }) diff --git a/web/src/stores/useChatStore.ts b/web/src/stores/useChatStore.ts index 16c5834d..62644582 100644 --- a/web/src/stores/useChatStore.ts +++ b/web/src/stores/useChatStore.ts @@ -60,6 +60,8 @@ interface ChatState { isTransitioning: boolean /** 当前 Agent 工作模式 */ agentMode: 'build' | 'plan' + /** Build 模式下的权限审批策略 */ + permissionMode: 'default' | 'bypass' // Actions addMessage: (msg: ChatMessage) => void @@ -101,6 +103,7 @@ interface ChatState { clearMessages: () => void addSystemMessage: (content: string) => void setAgentMode: (mode: 'build' | 'plan') => void + setPermissionMode: (mode: 'default' | 'bypass') => void } let msgIdCounter = 0 @@ -181,6 +184,7 @@ export const useChatStore = create((set) => ({ stopReason: '', isTransitioning: false, agentMode: 'build', + permissionMode: 'default', addMessage: (msg) => set((s) => ({ messages: [...s.messages, msg] })), removeMessage: (id) => set((s) => ({ messages: s.messages.filter((m) => m.id !== id) })), @@ -367,6 +371,7 @@ export const useChatStore = create((set) => ({ set((s) => ({ messages: [...s.messages, createSystemMessage(content)] })), setAgentMode: (agentMode) => set({ agentMode }), + setPermissionMode: (permissionMode) => set({ permissionMode }), /** 清理全部聊天状态,包括权限请求、token用量等。同时重置 eventBridge 模块级游标,避免跨会话泄漏。 */ clearMessages: () => { @@ -382,6 +387,7 @@ export const useChatStore = create((set) => ({ stopReason: '', isTransitioning: false, agentMode: 'build', + permissionMode: 'default', }) }, })) diff --git a/web/src/stores/useUIStore.ts b/web/src/stores/useUIStore.ts index 1cf0fa87..f1c1131f 100644 --- a/web/src/stores/useUIStore.ts +++ b/web/src/stores/useUIStore.ts @@ -1,4 +1,5 @@ import { create } from 'zustand' +import type { DiffHunk, DiffLine } from '@/utils/patchParser' /** Toast 通知 */ export interface Toast { @@ -14,7 +15,8 @@ export interface FileChange { status: 'added' | 'modified' | 'deleted' | 'accepted' | 'rejected' additions: number deletions: number - diff?: { type: 'add' | 'del' | 'header'; content: string }[] + diff?: DiffLine[] + hunks?: DiffHunk[] checkpoint_id?: string } @@ -54,6 +56,7 @@ interface UIState { setTheme: (theme: 'light' | 'dark') => void setSearchQuery: (q: string) => void addFileChange: (change: FileChange) => void + replaceFileChanges: (changes: FileChange[]) => void acceptFileChange: (id: string) => void rejectFileChange: (id: string) => void clearFileChanges: () => void @@ -90,6 +93,7 @@ export const useUIStore = create((set) => ({ set((s) => ({ fileChanges: [...s.fileChanges, change], })), + replaceFileChanges: (fileChanges) => set({ fileChanges }), acceptFileChange: (id) => set((s) => ({ fileChanges: s.fileChanges.map((c) => (c.id === id ? { ...c, status: 'accepted' as const } : c)), diff --git a/web/src/utils/eventBridge.test.ts b/web/src/utils/eventBridge.test.ts index ebe9fb40..1038071f 100644 --- a/web/src/utils/eventBridge.test.ts +++ b/web/src/utils/eventBridge.test.ts @@ -1,4 +1,4 @@ -import { describe, it, expect, beforeEach } from 'vitest' +import { describe, it, expect, beforeEach, vi } from 'vitest' import { useChatStore } from '@/stores/useChatStore' import { useGatewayStore } from '@/stores/useGatewayStore' import { useSessionStore } from '@/stores/useSessionStore' @@ -7,11 +7,13 @@ import { useUIStore } from '@/stores/useUIStore' import { handleGatewayEvent, resetEventBridgeCursors } from './eventBridge' import { EventType } from '@/api/protocol' -function createMockGatewayAPI() { +function createMockGatewayAPI(overrides: Record = {}) { return { listSessions: async () => ({ payload: { sessions: [] } }), loadSession: async () => ({ payload: { messages: [] } }), bindStream: async () => ({}), + checkpointDiff: async () => ({ payload: { checkpoint_id: 'cp', files: {}, patch: '' } }), + ...overrides, } as any } @@ -440,4 +442,191 @@ describe('eventBridge', () => { expect(newChange).toBeDefined() expect(newChange?.checkpoint_id).toBeUndefined() }) + + it('replaces transient tool diffs with run-scoped checkpoint diff on end-of-turn checkpoint', async () => { + const checkpointDiff = vi.fn(async () => ({ + payload: { + checkpoint_id: 'cp2', + files: { modified: ['a.txt'] }, + patch: '--- a/a.txt\n+++ b/a.txt\n@@ -1,3 +1,3 @@\n line 1\n-A\n+C\n line 3\n@@ -10,3 +10,3 @@\n line 10\n-B\n+D\n line 12\n', + }, + })) + const api = createMockGatewayAPI({ checkpointDiff }) + + handleGatewayEvent({ + type: EventType.InputNormalized, + payload: { payload: { runtime_event_type: EventType.InputNormalized, payload: { session_id: 'sess-1', run_id: 'run-1' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + + handleGatewayEvent({ + type: EventType.ToolStart, + payload: { payload: { runtime_event_type: EventType.ToolStart, payload: { name: 'filesystem_write_file', id: 'tc1', arguments: '{"path":"a.txt"}' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + handleGatewayEvent({ + type: EventType.ToolDiff, + payload: { payload: { runtime_event_type: EventType.ToolDiff, payload: { tool_call_id: 'tc1', tool_name: 'filesystem_write_file', file_path: 'a.txt', diff: '--- a/a.txt\n+++ b/a.txt\n@@ -1 +1 @@\n-A\n+B\n' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + handleGatewayEvent({ + type: EventType.ToolDiff, + payload: { payload: { runtime_event_type: EventType.ToolDiff, payload: { tool_call_id: 'tc2', tool_name: 'filesystem_write_file', file_path: 'a.txt', diff: '--- a/a.txt\n+++ b/a.txt\n@@ -1 +1 @@\n-B\n+C\n' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + + expect(useUIStore.getState().fileChanges[0]?.hunks?.[0]?.lines.map((line) => line.content)).toEqual([ + '@@ -1 +1 @@', + 'B', + 'C', + ]) + + handleGatewayEvent({ + type: EventType.CheckpointCreated, + payload: { payload: { runtime_event_type: EventType.CheckpointCreated, payload: { checkpoint_id: 'cp2', code_checkpoint_ref: 'c', session_checkpoint_ref: 's', commit_hash: '', reason: 'end_of_turn' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + await Promise.resolve() + await Promise.resolve() + + expect(checkpointDiff).toHaveBeenCalledWith({ + session_id: 'sess-1', + run_id: 'run-1', + checkpoint_id: 'cp2', + scope: 'run', + }) + const changes = useUIStore.getState().fileChanges + expect(changes).toHaveLength(1) + expect(changes[0]).toMatchObject({ path: 'a.txt', status: 'modified', additions: 2, deletions: 2 }) + expect(changes[0].hunks).toHaveLength(2) + expect(changes[0].hunks?.[0]?.lines.map((line) => line.content)).toEqual([ + '@@ -1,3 +1,3 @@', + 'line 1', + 'A', + 'C', + 'line 3', + ]) + expect(changes[0].hunks?.[1]?.lines.map((line) => line.content)).toEqual([ + '@@ -10,3 +10,3 @@', + 'line 10', + 'B', + 'D', + 'line 12', + ]) + }) + + it('stores hunk structure for transient tool diffs before aggregate checkpoint diff arrives', () => { + const api = createMockGatewayAPI() + + handleGatewayEvent({ + type: EventType.ToolStart, + payload: { payload: { runtime_event_type: EventType.ToolStart, payload: { name: 'filesystem_write_file', id: 'tc1', arguments: '{"path":"a.txt"}' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + handleGatewayEvent({ + type: EventType.ToolDiff, + payload: { + payload: { + runtime_event_type: EventType.ToolDiff, + payload: { + tool_call_id: 'tc1', + tool_name: 'filesystem_write_file', + file_path: 'a.txt', + diff: '--- a/a.txt\n+++ b/a.txt\n@@ -1,3 +1,3 @@\n line 1\n-old\n+new\n line 3\n@@ -10,2 +10,3 @@\n line 10\n+line 11\n line 12\n', + }, + }, + }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + + const change = useUIStore.getState().fileChanges.find((entry) => entry.path === 'a.txt') + expect(change?.hunks).toHaveLength(2) + expect(change?.hunks?.[0]?.lines.map((line) => line.type)).toEqual(['header', 'context', 'del', 'add', 'context']) + expect(change?.hunks?.[1]?.lines.map((line) => line.content)).toEqual([ + '@@ -10,2 +10,3 @@', + 'line 10', + 'line 11', + 'line 12', + ]) + }) + + it('keeps transient tool diffs visible when backend sends simplified diff without @@ header', () => { + const api = createMockGatewayAPI() + + handleGatewayEvent({ + type: EventType.ToolStart, + payload: { payload: { runtime_event_type: EventType.ToolStart, payload: { name: 'filesystem_write_file', id: 'tc1', arguments: '{"path":"a.txt"}' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + handleGatewayEvent({ + type: EventType.ToolDiff, + payload: { + payload: { + runtime_event_type: EventType.ToolDiff, + payload: { + tool_call_id: 'tc1', + tool_name: 'filesystem_write_file', + file_path: 'a.txt', + diff: '--- a/a.txt\n+++ b/a.txt\n-old\n+new\n', + }, + }, + }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + + const change = useUIStore.getState().fileChanges.find((entry) => entry.path === 'a.txt') + expect(change).toMatchObject({ additions: 1, deletions: 1 }) + expect(change?.hunks).toHaveLength(1) + expect(change?.hunks?.[0]?.lines.map((line) => line.content)).toEqual(['old', 'new']) + }) + + it('filters final run-scoped modified entries that have no renderable patch', async () => { + const checkpointDiff = vi.fn(async () => ({ + payload: { + checkpoint_id: 'cp2', + files: { modified: ['a.txt', 'b.txt'] }, + patch: '--- a/a.txt\n+++ b/a.txt\n@@ -1 +1 @@\n-old\n+new\n', + }, + })) + const api = createMockGatewayAPI({ checkpointDiff }) + + handleGatewayEvent({ + type: EventType.InputNormalized, + payload: { payload: { runtime_event_type: EventType.InputNormalized, payload: { session_id: 'sess-1', run_id: 'run-1' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + + handleGatewayEvent({ + type: EventType.ToolStart, + payload: { payload: { runtime_event_type: EventType.ToolStart, payload: { name: 'filesystem_write_file', id: 'tc1', arguments: '{"path":"b.txt"}' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + + expect(useUIStore.getState().fileChanges.find((entry) => entry.path === 'b.txt')).toBeDefined() + + handleGatewayEvent({ + type: EventType.CheckpointCreated, + payload: { payload: { runtime_event_type: EventType.CheckpointCreated, payload: { checkpoint_id: 'cp2', code_checkpoint_ref: 'c', session_checkpoint_ref: 's', commit_hash: '', reason: 'end_of_turn' } } }, + session_id: 'sess-1', + run_id: 'run-1', + }, api) + await Promise.resolve() + await Promise.resolve() + + const changes = useUIStore.getState().fileChanges + expect(changes).toHaveLength(1) + expect(changes[0]).toMatchObject({ path: 'a.txt', additions: 1, deletions: 1 }) + expect(changes.find((entry) => entry.path === 'b.txt')).toBeUndefined() + }) }) diff --git a/web/src/utils/eventBridge.ts b/web/src/utils/eventBridge.ts index e02f7c2f..95b8cf96 100644 --- a/web/src/utils/eventBridge.ts +++ b/web/src/utils/eventBridge.ts @@ -4,6 +4,7 @@ import { type BudgetCheckedPayload, type BudgetEstimateFailedPayload, type CheckpointCreatedPayload, + type CheckpointDiffResultPayload, type CheckpointRestoredPayload, type CheckpointUndoRestorePayload, type CheckpointWarningPayload, @@ -26,7 +27,7 @@ import { useGatewayStore } from '@/stores/useGatewayStore' import { useSessionStore } from '@/stores/useSessionStore' import { useRuntimeInsightStore } from '@/stores/useRuntimeInsightStore' import { useWorkspaceStore } from '@/stores/useWorkspaceStore' -import { parseSingleFileDiff, type ParsedFileDiff } from '@/utils/patchParser' +import { parseSingleFileDiff, parseUnifiedPatch, type ParsedFileDiff } from '@/utils/patchParser' type PayloadRecord = Record | undefined @@ -35,14 +36,16 @@ type PayloadRecord = Record | undefined let _latestVerificationMsgId: string | undefined let _latestDoneToolCallId: string | undefined -// 模块级缓存最新的 checkpoint_id,供 reject 回退使用 +// 模块级缓存最新的 checkpoint_id,用于工具占位条目关联后续端到端 diff。 let _latestCheckpointId: string | undefined +let _latestRunDiffRequestId = 0 /** 重置模块级游标 —— 在截断聊天历史 / 切换会话等场景调用,避免后续事件挂到已被移除的消息上 */ export function resetEventBridgeCursors() { _latestVerificationMsgId = undefined _latestDoneToolCallId = undefined _latestCheckpointId = undefined + _latestRunDiffRequestId += 1 } /** @@ -85,6 +88,7 @@ function _upsertFileChange( additions: parsed?.additions ?? c.additions, deletions: parsed?.deletions ?? c.deletions, diff: parsed?.lines ?? c.diff, + hunks: parsed?.hunks ?? c.hunks, checkpoint_id: _latestCheckpointId ?? c.checkpoint_id, } : c, @@ -98,6 +102,7 @@ function _upsertFileChange( additions: parsed?.additions ?? 0, deletions: parsed?.deletions ?? 0, diff: parsed?.lines, + hunks: parsed?.hunks, checkpoint_id: _latestCheckpointId, }) } @@ -164,6 +169,73 @@ function _applyToolDiff(payload: ToolDiffPayload) { } } +function _fileChangesFromCheckpointDiff(diff: CheckpointDiffResultPayload) { + const parsed = diff.patch ? parseUnifiedPatch(diff.patch) : {} + const parsedByPath = new Map() + for (const [path, parsedDiff] of Object.entries(parsed)) { + const normalized = normalizeFilePath(path) + if (normalized) parsedByPath.set(normalized, parsedDiff) + } + const byPath = new Map() + for (const path of diff.files?.added ?? []) byPath.set(normalizeFilePath(path), 'added') + for (const path of diff.files?.modified ?? []) byPath.set(normalizeFilePath(path), 'modified') + for (const path of diff.files?.deleted ?? []) byPath.set(normalizeFilePath(path), 'deleted') + + for (const path of parsedByPath.keys()) { + if (!byPath.has(path)) byPath.set(path, 'modified') + } + + return Array.from(byPath.entries()) + .filter(([path]) => path) + .filter(([path]) => { + const parsedDiff = parsedByPath.get(path) + return Boolean( + parsedDiff && + (parsedDiff.additions > 0 || + parsedDiff.deletions > 0 || + (parsedDiff.hunks?.length ?? 0) > 0 || + (parsedDiff.lines?.length ?? 0) > 0), + ) + }) + .sort(([a], [b]) => a.localeCompare(b)) + .map(([path, status]) => { + const parsedDiff = parsedByPath.get(path) + return { + id: `fc_${path}`, + path, + status, + additions: parsedDiff?.additions ?? 0, + deletions: parsedDiff?.deletions ?? 0, + diff: parsedDiff?.lines, + hunks: parsedDiff?.hunks, + checkpoint_id: diff.checkpoint_id, + } + }) +} + +function _refreshRunFileChanges( + gatewayAPI: GatewayAPI, + sessionId: string, + runId: string, + checkpointId: string, +) { + const requestId = ++_latestRunDiffRequestId + gatewayAPI.checkpointDiff({ + session_id: sessionId, + run_id: runId, + checkpoint_id: checkpointId, + scope: 'run', + }).then((result) => { + if (requestId !== _latestRunDiffRequestId) return + if (runId !== useGatewayStore.getState().currentRunId) return + if (sessionId !== useSessionStore.getState().currentSessionId) return + if (!result?.payload) return + useUIStore.getState().replaceFileChanges(_fileChangesFromCheckpointDiff(result.payload)) + }).catch((error) => { + console.warn('[eventBridge] checkpoint.diff run scope failed:', error) + }) +} + function normalizePermissionPayload(raw: unknown): PermissionRequestPayload | null { const r = raw as Record | undefined if (!r || typeof r !== 'object') return null @@ -317,6 +389,7 @@ export function handleGatewayEvent(frame: MessageFrame, gatewayAPI: GatewayAPI) const sessionId = strField(eventPayload, 'session_id') || frameSessionId || '' const runId = strField(eventPayload, 'run_id') || frameRunId || '' if (runId) gwStore.setCurrentRunId(runId) + useUIStore.getState().clearFileChanges() if (sessionId && sessionId !== useSessionStore.getState().currentSessionId) { useSessionStore.getState().setCurrentSessionId(sessionId) } @@ -565,6 +638,9 @@ export function handleGatewayEvent(frame: MessageFrame, gatewayAPI: GatewayAPI) chatStore.attachCheckpointToToolCall(_latestDoneToolCallId, payload.checkpoint_id) } _latestCheckpointId = payload.checkpoint_id + if (payload.reason === 'end_of_turn' && gatewayAPI && frameSessionId && frameRunId) { + _refreshRunFileChanges(gatewayAPI, frameSessionId, frameRunId, payload.checkpoint_id) + } } break } diff --git a/web/src/utils/patchParser.test.ts b/web/src/utils/patchParser.test.ts new file mode 100644 index 00000000..86fc6e56 --- /dev/null +++ b/web/src/utils/patchParser.test.ts @@ -0,0 +1,107 @@ +import { describe, expect, it } from 'vitest' +import { parseSingleFileDiff, parseUnifiedPatch } from './patchParser' + +describe('patchParser', () => { + it('parses multiple hunks with context lines for a single file diff', () => { + const parsed = parseSingleFileDiff( + [ + '--- a/src/a.txt', + '+++ b/src/a.txt', + '@@ -1,3 +1,3 @@', + ' line 1', + '-line 2 old', + '+line 2 new', + ' line 3', + '@@ -10,3 +10,4 @@', + ' line 10', + '-line 11 old', + '+line 11 new', + '+line 12 added', + ].join('\n'), + ) + + expect(parsed.additions).toBe(3) + expect(parsed.deletions).toBe(2) + expect(parsed.hunks).toHaveLength(2) + expect(parsed.hunks[0]?.header).toBe('@@ -1,3 +1,3 @@') + expect(parsed.hunks[0]?.lines.map((line) => line.type)).toEqual(['header', 'context', 'del', 'add', 'context']) + expect(parsed.hunks[1]?.lines.map((line) => line.content)).toEqual([ + '@@ -10,3 +10,4 @@', + 'line 10', + 'line 11 old', + 'line 11 new', + 'line 12 added', + ]) + }) + + it('parses unified patch for modified, added, and deleted files', () => { + const parsed = parseUnifiedPatch( + [ + 'diff --git a/src/a.txt b/src/a.txt', + '--- a/src/a.txt', + '+++ b/src/a.txt', + '@@ -1 +1 @@', + '-before', + '+after', + 'diff --git a/src/new.txt b/src/new.txt', + '--- /dev/null', + '+++ b/src/new.txt', + '@@ -0,0 +1,2 @@', + '+new 1', + '+new 2', + 'diff --git a/src/old.txt b/src/old.txt', + '--- a/src/old.txt', + '+++ /dev/null', + '@@ -1,2 +0,0 @@', + '-old 1', + '-old 2', + ].join('\n'), + ) + + expect(Object.keys(parsed)).toEqual(['src/a.txt', 'src/new.txt', 'src/old.txt']) + expect(parsed['src/a.txt']?.hunks).toHaveLength(1) + expect(parsed['src/a.txt']?.additions).toBe(1) + expect(parsed['src/a.txt']?.deletions).toBe(1) + expect(parsed['src/new.txt']?.hunks[0]?.lines.map((line) => line.type)).toEqual(['header', 'add', 'add']) + expect(parsed['src/old.txt']?.hunks[0]?.lines.map((line) => line.type)).toEqual(['header', 'del', 'del']) + }) + + it('falls back to an implicit hunk when diff has no @@ header', () => { + const parsed = parseSingleFileDiff( + [ + '--- a/src/a.txt', + '+++ b/src/a.txt', + '-before', + '+after', + ].join('\n'), + ) + + expect(parsed.additions).toBe(1) + expect(parsed.deletions).toBe(1) + expect(parsed.hunks).toHaveLength(1) + expect(parsed.hunks[0]?.header).toBe('') + expect(parsed.hunks[0]?.lines.map((line) => line.type)).toEqual(['del', 'add']) + expect(parsed.lines.map((line) => line.content)).toEqual(['before', 'after']) + }) + + it('parses unified patch without @@ headers by creating implicit hunks per file', () => { + const parsed = parseUnifiedPatch( + [ + 'diff --git a/src/a.txt b/src/a.txt', + '--- a/src/a.txt', + '+++ b/src/a.txt', + '-before', + '+after', + 'diff --git a/src/b.txt b/src/b.txt', + '--- /dev/null', + '+++ b/src/b.txt', + '+new file', + ].join('\n'), + ) + + expect(parsed['src/a.txt']?.hunks).toHaveLength(1) + expect(parsed['src/a.txt']?.hunks[0]?.lines.map((line) => line.content)).toEqual(['before', 'after']) + expect(parsed['src/b.txt']?.hunks).toHaveLength(1) + expect(parsed['src/b.txt']?.hunks[0]?.lines.map((line) => line.type)).toEqual(['add']) + }) +}) diff --git a/web/src/utils/patchParser.ts b/web/src/utils/patchParser.ts index f2bf0f27..623f59e8 100644 --- a/web/src/utils/patchParser.ts +++ b/web/src/utils/patchParser.ts @@ -1,120 +1,162 @@ /** - * 解析 unified diff patch 字符串,按文件路径分组返回 diff 行。 + * 解析 unified diff patch,保留文件级统计与 hunk 结构。 */ +export type DiffLineType = 'add' | 'del' | 'header' | 'context' + +export interface DiffLine { + type: DiffLineType + content: string +} + +export interface DiffHunk { + header: string + lines: DiffLine[] + additions: number + deletions: number +} + export interface ParsedFileDiff { additions: number deletions: number - lines: { type: 'add' | 'del' | 'header'; content: string }[] + lines: DiffLine[] + hunks: DiffHunk[] +} + +function createParsedFileDiff(): ParsedFileDiff { + return { additions: 0, deletions: 0, lines: [], hunks: [] } +} + +function pushLine(target: ParsedFileDiff, hunk: DiffHunk | null, line: DiffLine) { + target.lines.push(line) + if (hunk) hunk.lines.push(line) +} + +function startHunk(target: ParsedFileDiff, header: string): DiffHunk { + const hunk: DiffHunk = { + header, + lines: [{ type: 'header', content: header }], + additions: 0, + deletions: 0, + } + target.hunks.push(hunk) + target.lines.push(hunk.lines[0]) + return hunk +} + +function startImplicitHunk(target: ParsedFileDiff): DiffHunk { + const hunk: DiffHunk = { + header: '', + lines: [], + additions: 0, + deletions: 0, + } + target.hunks.push(hunk) + return hunk +} + +function parseDiffLine(target: ParsedFileDiff, currentHunk: DiffHunk | null, line: string): DiffHunk | null { + if (line.startsWith('@@')) { + return startHunk(target, line) + } + const hunk = currentHunk ?? startImplicitHunk(target) + if (line.startsWith('+')) { + const nextLine: DiffLine = { type: 'add', content: line.slice(1) } + target.additions += 1 + hunk.additions += 1 + pushLine(target, hunk, nextLine) + return hunk + } + if (line.startsWith('-')) { + const nextLine: DiffLine = { type: 'del', content: line.slice(1) } + target.deletions += 1 + hunk.deletions += 1 + pushLine(target, hunk, nextLine) + return hunk + } + if (line.startsWith(' ')) { + pushLine(target, hunk, { type: 'context', content: line.slice(1) }) + } + return hunk } /** - * parseSingleFileDiff 解析单文件 diff 内容,返回 additions/deletions/lines。 - * 跳过 --- / +++ 文件头行,只处理 @@ hunk 头和 +/- 内容行。 + * parseSingleFileDiff 解析单文件 diff 内容,返回 additions/deletions/lines/hunks。 + * 跳过 `---` / `+++` 文件头,仅保留 hunk 头、上下文行和增删行。 */ export function parseSingleFileDiff(diff: string): ParsedFileDiff { - const result: ParsedFileDiff = { additions: 0, deletions: 0, lines: [] } + const result = createParsedFileDiff() if (!diff) return result + let currentHunk: DiffHunk | null = null for (const rawLine of diff.split('\n')) { const line = rawLine.endsWith('\r') ? rawLine.slice(0, -1) : rawLine - - // 跳过文件头行 - if (line.startsWith('--- ') || line.startsWith('+++ ')) continue - - // hunk header - if (line.startsWith('@@')) { - result.lines.push({ type: 'header', content: line }) - continue - } - - // 上下文行 — 跳过 - if (line.startsWith(' ')) continue - - // 新增行 - if (line.startsWith('+')) { - result.additions++ - result.lines.push({ type: 'add', content: line.slice(1) }) - continue - } - - // 删除行 - if (line.startsWith('-')) { - result.deletions++ - result.lines.push({ type: 'del', content: line.slice(1) }) - } + if (line.startsWith('--- ') || line.startsWith('+++ ') || line.startsWith('\\ ')) continue + currentHunk = parseDiffLine(result, currentHunk, line) } return result } +function ensureParsedFile( + result: Record, + currentPath: string, +): ParsedFileDiff | null { + if (!currentPath) return null + const existing = result[currentPath] + if (existing) return existing + const created = createParsedFileDiff() + result[currentPath] = created + return created +} + /** * parseUnifiedPatch 将标准 unified diff 拆为按文件索引的结构。 - * 支持 `--- a/path` / `+++ b/path` 或 `diff --git a/path b/path` 两种分隔方式。 + * 支持 `--- a/path` / `+++ b/path` 与 `diff --git a/path b/path` 两种文件边界。 */ export function parseUnifiedPatch(patch: string): Record { const result: Record = {} if (!patch) return result - const lines = patch.split('\n') let currentPath = '' let current: ParsedFileDiff | null = null + let currentHunk: DiffHunk | null = null - for (const rawLine of lines) { + for (const rawLine of patch.split('\n')) { const line = rawLine.endsWith('\r') ? rawLine.slice(0, -1) : rawLine - // 文件边界:--- a/path 或 --- path (go-difflib 不带 a/ 前缀) - const fromMatch = line.match(/^--- (?:a\/)?(.+)$/) - if (fromMatch && fromMatch[1] !== '/dev/null') { - currentPath = fromMatch[1] - current = { additions: 0, deletions: 0, lines: [] } - result[currentPath] = current - continue - } - - // 文件边界:+++ b/path 或 +++ path - const toMatch = line.match(/^\+\+\+ (?:b\/)?(.+)$/) - if (toMatch && toMatch[1] !== '/dev/null') { - if (!current) { - currentPath = toMatch[1] - current = { additions: 0, deletions: 0, lines: [] } - result[currentPath] = current - } - continue - } - - // 文件边界:diff --git a/path b/path const gitMatch = line.match(/^diff --git a\/\S+ b\/(.+)$/) if (gitMatch) { currentPath = gitMatch[1] - current = result[currentPath] ?? { additions: 0, deletions: 0, lines: [] } - result[currentPath] = current + current = ensureParsedFile(result, currentPath) + currentHunk = null continue } - if (!current) continue - - // hunk header - if (line.startsWith('@@')) { - current.lines.push({ type: 'header', content: line }) + const fromMatch = line.match(/^--- (?:a\/)?(.+)$/) + if (fromMatch) { + if (fromMatch[1] !== '/dev/null') { + currentPath = fromMatch[1] + current = ensureParsedFile(result, currentPath) + } + currentHunk = null continue } - // 纯上下文行 — 跳过 - if (line.startsWith(' ')) continue - - // 新增行 - if (line.startsWith('+')) { - current.additions++ - current.lines.push({ type: 'add', content: line.slice(1) }) + const toMatch = line.match(/^\+\+\+ (?:b\/)?(.+)$/) + if (toMatch) { + if (toMatch[1] !== '/dev/null') { + currentPath = toMatch[1] + current = ensureParsedFile(result, currentPath) + } + currentHunk = null continue } - // 删除行 - if (line.startsWith('-')) { - current.deletions++ - current.lines.push({ type: 'del', content: line.slice(1) }) - } + if (line.startsWith('\\ ')) continue + if (!current) continue + + currentHunk = parseDiffLine(current, currentHunk, line) } return result diff --git a/www/guide/install.md b/www/guide/install.md index e6e7666c..6c4f1563 100644 --- a/www/guide/install.md +++ b/www/guide/install.md @@ -73,6 +73,8 @@ neocode --workdir /path/to/your/project neocode web ``` +标签发布版执行 `neocode web` 时,如果本地还没有 `web/dist`,会自动使用发布包内的 `web/` 源码执行 `npm install` 和 `npm run build`,然后启动 Web UI。该流程要求当前机器已安装 Node.js 和 npm。 + ## 5. 第一次对话 可以先让 NeoCode 读项目结构: @@ -116,6 +118,8 @@ go build ./... go run ./cmd/neocode ``` +如果你希望从源码仓库直接验证 Web UI,也可以运行 `go run ./cmd/neocode web`。当 `web/dist` 缺失时,命令会自动尝试构建前端;若构建机没有 Node.js/npm,会直接报出依赖缺失提示。 + 如果你只想稳定使用,优先使用一键安装方式。源码构建更适合阅读代码、调试功能或参与开发。 ## 下一步 diff --git a/www/guide/web-ui.md b/www/guide/web-ui.md index ca6cd34f..d4d83ec1 100644 --- a/www/guide/web-ui.md +++ b/www/guide/web-ui.md @@ -14,6 +14,7 @@ neocode web ``` 启动后浏览器会自动打开 `http://127.0.0.1:8080`。如果端口被占用,会自动尝试 8081 ~ 8090。 +如果当前目录或发布包内存在 `web/` 源码但还没有 `web/dist`,命令会自动执行 `npm install` 和 `npm run build`。标签发布版使用该能力时,用户机器必须预先安装 Node.js 和 npm。 ### 常用参数 @@ -21,7 +22,7 @@ neocode web |------|--------|------| | `--http-listen` | `127.0.0.1:8080` | 监听地址(仅允许回环地址) | | `--open-browser` | `true` | 启动后自动打开浏览器 | -| `--skip-build` | `false` | 跳过前端构建(dist/ 缺失时会报错) | +| `--skip-build` | `false` | 跳过前端构建(dist/ 缺失时会报错;仅在你已准备好预构建资源时使用) | | `--static-dir` | — | 指定前端静态文件目录 | | `--log-level` | `info` | 日志级别:debug / info / warn / error | | `--token-file` | — | 自定义认证 token 文件路径 | @@ -56,6 +57,7 @@ Web UI 支持两种运行模式,根据启动方式自动选择: ### 浏览器模式 通过 `neocode web` 或直接在浏览器中访问 Gateway 地址时使用。首次连接需要输入 Gateway URL 和 token,配置保存在 sessionStorage 中。 +如果你使用标签发布版,首次运行 `neocode web` 可能先看到前端依赖安装与构建日志;构建完成后会继续启动 Web UI。若机器缺少 Node.js/npm,命令会直接提示安装依赖。 ## 核心功能