diff --git a/internal/runtime/checkpoint_flow_test.go b/internal/runtime/checkpoint_flow_test.go index 3ff76a91..b5f86d8e 100644 --- a/internal/runtime/checkpoint_flow_test.go +++ b/internal/runtime/checkpoint_flow_test.go @@ -159,6 +159,25 @@ func (f runtimeCheckpointFixture) captureFile(t *testing.T, relPath string, cont return abs } +func readCheckpointRestoredPayload(t *testing.T, events <-chan RuntimeEvent) CheckpointRestoredPayload { + t.Helper() + for { + select { + case evt := <-events: + if evt.Type != EventCheckpointRestored { + continue + } + payload, ok := evt.Payload.(CheckpointRestoredPayload) + if !ok { + t.Fatalf("checkpoint restored payload type = %T", evt.Payload) + } + return payload + default: + t.Fatal("expected checkpoint restored event") + } + } +} + func TestCreateStartOfTurnCheckpoint_PendingWrite(t *testing.T) { fixture := newRuntimeCheckpointFixture(t) fixture.captureFile(t, "main.go", []byte("package main\nconst v = 1\n")) @@ -392,6 +411,134 @@ func TestRestoreCheckpoint_RecoversCapturedFile(t *testing.T) { if string(got) != "version one" { t.Fatalf("restored content = %q, want %q", string(got), "version one") } + + payload := readCheckpointRestoredPayload(t, fixture.service.events) + if payload.Mode != "exact" { + t.Fatalf("restore payload mode = %q, want exact", payload.Mode) + } + if len(payload.Paths) != 0 { + t.Fatalf("restore payload paths = %#v, want empty", payload.Paths) + } + if payload.GuardCheckpointID == "" { + t.Fatal("restore payload guard checkpoint id is empty") + } +} + +func TestRestoreCheckpointBaselineEmitsModeAndPaths(t *testing.T) { + fixture := newRuntimeCheckpointFixture(t) + target := filepath.Join(fixture.workdir, "baseline.txt") + if err := os.WriteFile(target, []byte("before baseline"), 0o644); err != nil { + t.Fatalf("WriteFile(before baseline) error = %v", err) + } + if _, err := fixture.perEditStore.CapturePreWrite(target); err != nil { + t.Fatalf("CapturePreWrite() error = %v", err) + } + + state := newRunState("run-baseline", fixture.session) + if err := fixture.service.createStartOfTurnCheckpoint(context.Background(), &state); err != nil { + t.Fatalf("createStartOfTurnCheckpoint() error = %v", err) + } + records, err := fixture.checkpointStore.ListCheckpoints(context.Background(), fixture.session.ID, checkpoint.ListCheckpointOpts{}) + if err != nil { + t.Fatalf("ListCheckpoints() error = %v", err) + } + if len(records) != 1 { + t.Fatalf("records = %#v, want 1", records) + } + cpRecord := records[0] + if err := fixture.checkpointStore.UpdateCheckpointStatus(context.Background(), cpRecord.CheckpointID, agentsession.CheckpointStatusAvailable); err != nil { + t.Fatalf("UpdateCheckpointStatus() error = %v", err) + } + if err := os.WriteFile(target, []byte("after baseline"), 0o644); err != nil { + t.Fatalf("WriteFile(after baseline) error = %v", err) + } + + if _, err := fixture.service.RestoreCheckpoint(context.Background(), GatewayRestoreInput{ + SessionID: fixture.session.ID, + CheckpointID: cpRecord.CheckpointID, + Mode: "baseline", + Paths: []string{"./baseline.txt", "baseline.txt"}, + }); err != nil { + t.Fatalf("RestoreCheckpoint(baseline) error = %v", err) + } + if got := string(mustReadRuntimeFile(t, target)); got != "before baseline" { + t.Fatalf("baseline restored content = %q, want before baseline", got) + } + + payload := readCheckpointRestoredPayload(t, fixture.service.events) + if payload.Mode != "baseline" { + t.Fatalf("restore payload mode = %q, want baseline", payload.Mode) + } + if len(payload.Paths) != 1 || payload.Paths[0] != "baseline.txt" { + t.Fatalf("restore payload paths = %#v, want [baseline.txt]", payload.Paths) + } + if payload.GuardCheckpointID != "" { + t.Fatalf("baseline restore guard checkpoint id = %q, want empty", payload.GuardCheckpointID) + } +} + +func TestRestoreCheckpointBaselineRejectsPathsThatNormalizeEmpty(t *testing.T) { + fixture := newRuntimeCheckpointFixture(t) + target := filepath.Join(fixture.workdir, "baseline.txt") + if err := os.WriteFile(target, []byte("before baseline"), 0o644); err != nil { + t.Fatalf("WriteFile(before baseline) error = %v", err) + } + if _, err := fixture.perEditStore.CapturePreWrite(target); err != nil { + t.Fatalf("CapturePreWrite() error = %v", err) + } + + state := newRunState("run-baseline-empty-paths", fixture.session) + if err := fixture.service.createStartOfTurnCheckpoint(context.Background(), &state); err != nil { + t.Fatalf("createStartOfTurnCheckpoint() error = %v", err) + } + records, err := fixture.checkpointStore.ListCheckpoints(context.Background(), fixture.session.ID, checkpoint.ListCheckpointOpts{}) + if err != nil { + t.Fatalf("ListCheckpoints() error = %v", err) + } + if len(records) != 1 { + t.Fatalf("records = %#v, want 1", records) + } + cpRecord := records[0] + if err := fixture.checkpointStore.UpdateCheckpointStatus(context.Background(), cpRecord.CheckpointID, agentsession.CheckpointStatusAvailable); err != nil { + t.Fatalf("UpdateCheckpointStatus() error = %v", err) + } + + _, err = fixture.service.restoreCheckpointBaseline(context.Background(), fixture.session.ID, cpRecord.CheckpointID, []string{"./", " . "}) + if err == nil || !strings.Contains(err.Error(), "baseline restore paths required") { + t.Fatalf("restoreCheckpointBaseline() error = %v, want baseline restore paths required", err) + } +} + +func TestRestoreCheckpointBaselineWrapsRestoreBaselineError(t *testing.T) { + fixture := newRuntimeCheckpointFixture(t) + target := filepath.Join(fixture.workdir, "baseline.txt") + if err := os.WriteFile(target, []byte("before baseline"), 0o644); err != nil { + t.Fatalf("WriteFile(before baseline) error = %v", err) + } + if _, err := fixture.perEditStore.CapturePreWrite(target); err != nil { + t.Fatalf("CapturePreWrite() error = %v", err) + } + + state := newRunState("run-baseline-missing-path", fixture.session) + if err := fixture.service.createStartOfTurnCheckpoint(context.Background(), &state); err != nil { + t.Fatalf("createStartOfTurnCheckpoint() error = %v", err) + } + records, err := fixture.checkpointStore.ListCheckpoints(context.Background(), fixture.session.ID, checkpoint.ListCheckpointOpts{}) + if err != nil { + t.Fatalf("ListCheckpoints() error = %v", err) + } + if len(records) != 1 { + t.Fatalf("records = %#v, want 1", records) + } + cpRecord := records[0] + if err := fixture.checkpointStore.UpdateCheckpointStatus(context.Background(), cpRecord.CheckpointID, agentsession.CheckpointStatusAvailable); err != nil { + t.Fatalf("UpdateCheckpointStatus() error = %v", err) + } + + _, err = fixture.service.restoreCheckpointBaseline(context.Background(), fixture.session.ID, cpRecord.CheckpointID, []string{"missing.txt"}) + if err == nil || !strings.Contains(err.Error(), "baseline restore code") || !strings.Contains(err.Error(), "missing.txt") { + t.Fatalf("restoreCheckpointBaseline() error = %v, want wrapped missing baseline path error", err) + } } func TestUndoRestoreCheckpoint_RestoresGuardState(t *testing.T) { diff --git a/internal/runtime/checkpoint_restore.go b/internal/runtime/checkpoint_restore.go index 358227aa..17038559 100644 --- a/internal/runtime/checkpoint_restore.go +++ b/internal/runtime/checkpoint_restore.go @@ -191,11 +191,13 @@ func (s *Service) RestoreCheckpoint(ctx context.Context, input GatewayRestoreInp CheckpointID: result.CheckpointID, SessionID: result.SessionID, GuardCheckpointID: guardRecord.CheckpointID, + Mode: "exact", }) return result, nil } if mode == "baseline" { - result, err := s.restoreCheckpointBaseline(ctx, input.SessionID, input.CheckpointID, input.Paths) + paths := normalizeBaselineRestorePaths(input.Paths) + result, err := s.restoreCheckpointBaseline(ctx, input.SessionID, input.CheckpointID, paths) if err != nil { return RestoreResult{}, err } @@ -203,6 +205,8 @@ func (s *Service) RestoreCheckpoint(ctx context.Context, input GatewayRestoreInp CheckpointID: result.CheckpointID, SessionID: result.SessionID, GuardCheckpointID: "", + Mode: "baseline", + Paths: paths, }) return result, nil } @@ -351,7 +355,20 @@ func (s *Service) restoreCheckpointBaseline( if perEditID == "" { return RestoreResult{}, fmt.Errorf("checkpoint: %s has no code snapshot", checkpointID) } + relPaths := normalizeBaselineRestorePaths(paths) + if len(relPaths) == 0 { + return RestoreResult{}, fmt.Errorf("checkpoint: baseline restore paths required") + } + if err := s.perEditStore.RestoreBaseline(ctx, perEditID, relPaths); err != nil { + return RestoreResult{}, fmt.Errorf("checkpoint: baseline restore code: %w", err) + } + return RestoreResult{CheckpointID: checkpointID, SessionID: sessionID}, nil +} + +// normalizeBaselineRestorePaths 统一 baseline 文件回退路径,保证恢复执行与事件通知使用同一组相对路径。 +func normalizeBaselineRestorePaths(paths []string) []string { relPaths := make([]string, 0, len(paths)) + seen := make(map[string]struct{}, len(paths)) for _, path := range paths { clean := filepath.ToSlash(strings.TrimSpace(path)) for strings.HasPrefix(clean, "./") { @@ -360,15 +377,13 @@ func (s *Service) restoreCheckpointBaseline( if clean == "" || clean == "." { continue } + if _, ok := seen[clean]; ok { + continue + } + seen[clean] = struct{}{} relPaths = append(relPaths, clean) } - if len(relPaths) == 0 { - return RestoreResult{}, fmt.Errorf("checkpoint: baseline restore paths required") - } - if err := s.perEditStore.RestoreBaseline(ctx, perEditID, relPaths); err != nil { - return RestoreResult{}, fmt.Errorf("checkpoint: baseline restore code: %w", err) - } - return RestoreResult{CheckpointID: checkpointID, SessionID: sessionID}, nil + return relPaths } // ListCheckpoints 查询指定会话的 checkpoint 列表。 diff --git a/internal/runtime/events.go b/internal/runtime/events.go index d4b3e1ee..8b6c0a63 100644 --- a/internal/runtime/events.go +++ b/internal/runtime/events.go @@ -470,9 +470,11 @@ type CheckpointWarningPayload struct { // CheckpointRestoredPayload 描述 checkpoint 恢复成功事件。 type CheckpointRestoredPayload struct { - CheckpointID string `json:"checkpoint_id"` - SessionID string `json:"session_id"` - GuardCheckpointID string `json:"guard_checkpoint_id"` + CheckpointID string `json:"checkpoint_id"` + SessionID string `json:"session_id"` + GuardCheckpointID string `json:"guard_checkpoint_id"` + Mode string `json:"mode,omitempty"` + Paths []string `json:"paths,omitempty"` } // CheckpointUndoRestorePayload 描述 restore 撤销事件。 diff --git a/internal/tui/services/gateway_stream_client_additional_test.go b/internal/tui/services/gateway_stream_client_additional_test.go index 97012e73..c049d1aa 100644 --- a/internal/tui/services/gateway_stream_client_additional_test.go +++ b/internal/tui/services/gateway_stream_client_additional_test.go @@ -448,6 +448,82 @@ func TestRestoreRuntimePayloadAdditionalBranches(t *testing.T) { } } +func TestRestoreRuntimePayloadCheckpointRestoredKeepsBaselineFields(t *testing.T) { + t.Parallel() + + payload, err := restoreRuntimePayload(EventCheckpointRestored, map[string]any{ + "checkpoint_id": "cp-baseline", + "session_id": "session-1", + "guard_checkpoint_id": "", + "mode": "baseline", + "paths": []any{"a.txt", "nested/b.txt"}, + }) + if err != nil { + t.Fatalf("restoreRuntimePayload(CheckpointRestored) error = %v", err) + } + restored, ok := payload.(CheckpointRestoredPayload) + if !ok { + t.Fatalf("payload type = %T, want CheckpointRestoredPayload", payload) + } + if restored.CheckpointID != "cp-baseline" || restored.SessionID != "session-1" { + t.Fatalf("unexpected restored identity payload: %+v", restored) + } + if restored.Mode != "baseline" { + t.Fatalf("restored.Mode = %q, want baseline", restored.Mode) + } + if !reflect.DeepEqual(restored.Paths, []string{"a.txt", "nested/b.txt"}) { + t.Fatalf("restored.Paths = %#v, want [a.txt nested/b.txt]", restored.Paths) + } + if restored.GuardCheckpointID != "" { + t.Fatalf("restored.GuardCheckpointID = %q, want empty for baseline restore", restored.GuardCheckpointID) + } +} + +func TestDecodeRuntimeEventCheckpointRestoredKeepsModeAndPaths(t *testing.T) { + t.Parallel() + + notification := buildGatewayEventNotification(t, gateway.MessageFrame{ + Type: gateway.FrameTypeEvent, + Action: gateway.FrameActionRun, + RunID: "run-1", + SessionID: "session-1", + Payload: map[string]any{ + "event_type": "run_progress", + "payload": map[string]any{ + "runtime_event_type": string(EventCheckpointRestored), + "payload_version": runtimeEventPayloadVersion, + "turn": 2, + "phase": "restore", + "payload": map[string]any{ + "checkpoint_id": "cp-baseline", + "session_id": "session-1", + "guard_checkpoint_id": "", + "mode": "baseline", + "paths": []string{"a.txt", "nested/b.txt"}, + }, + }, + }, + }) + + event, err := decodeRuntimeEventFromGatewayNotification(notification) + if err != nil { + t.Fatalf("decodeRuntimeEventFromGatewayNotification() error = %v", err) + } + if event.Type != EventCheckpointRestored || event.RunID != "run-1" || event.SessionID != "session-1" || event.Turn != 2 || event.Phase != "restore" { + t.Fatalf("unexpected event metadata: %+v", event) + } + restored, ok := event.Payload.(CheckpointRestoredPayload) + if !ok { + t.Fatalf("event.Payload type = %T, want CheckpointRestoredPayload", event.Payload) + } + if restored.CheckpointID != "cp-baseline" || restored.Mode != "baseline" { + t.Fatalf("unexpected checkpoint restored payload: %+v", restored) + } + if !reflect.DeepEqual(restored.Paths, []string{"a.txt", "nested/b.txt"}) { + t.Fatalf("restored.Paths = %#v, want [a.txt nested/b.txt]", restored.Paths) + } +} + func TestStreamHelperBranches(t *testing.T) { t.Parallel() diff --git a/internal/tui/services/runtime_contract.go b/internal/tui/services/runtime_contract.go index 6947c0b8..8ab9b21a 100644 --- a/internal/tui/services/runtime_contract.go +++ b/internal/tui/services/runtime_contract.go @@ -618,9 +618,11 @@ type CheckpointWarningPayload struct { // CheckpointRestoredPayload 描述 checkpoint 恢复成功事件。 type CheckpointRestoredPayload struct { - CheckpointID string `json:"checkpoint_id"` - SessionID string `json:"session_id"` - GuardCheckpointID string `json:"guard_checkpoint_id"` + CheckpointID string `json:"checkpoint_id"` + SessionID string `json:"session_id"` + GuardCheckpointID string `json:"guard_checkpoint_id"` + Mode string `json:"mode,omitempty"` + Paths []string `json:"paths,omitempty"` } // CheckpointUndoRestorePayload 描述 restore 撤销事件。 diff --git a/web/src/api/protocol.ts b/web/src/api/protocol.ts index db1836b2..c34b08b8 100644 --- a/web/src/api/protocol.ts +++ b/web/src/api/protocol.ts @@ -521,6 +521,8 @@ export interface CheckpointRestoredPayload { checkpoint_id: string; session_id: string; guard_checkpoint_id: string; + mode?: "exact" | "baseline" | string; + paths?: string[]; } export interface CheckpointUndoRestorePayload { diff --git a/web/src/components/chat/ChatInput.test.tsx b/web/src/components/chat/ChatInput.test.tsx index bf39b1b2..599af2f3 100644 --- a/web/src/components/chat/ChatInput.test.tsx +++ b/web/src/components/chat/ChatInput.test.tsx @@ -4,6 +4,7 @@ import ChatInput from './ChatInput' import { useChatStore } from '@/stores/useChatStore' import { useComposerStore } from '@/stores/useComposerStore' import { useSessionStore } from '@/stores/useSessionStore' +import { useRuntimeInsightStore } from '@/stores/useRuntimeInsightStore' const mockGatewayAPI = { listAvailableSkills: vi.fn(), @@ -31,6 +32,21 @@ async function submitSlashCommand(command: string) { fireEvent.keyDown(textarea, { key: 'Enter' }) } +function renderWithBudget(input: { + action: string + estimated_input_tokens: number + prompt_budget: number + context_window?: number +}) { + useRuntimeInsightStore.getState().setBudgetChecked({ + attempt_seq: 1, + request_hash: 'budget-test', + ...input, + }) + render() + return screen.getByTestId('budget-token-ring') +} + describe('ChatInput', () => { beforeEach(() => { vi.clearAllMocks() @@ -54,12 +70,17 @@ describe('ChatInput', () => { useComposerStore.setState({ composerText: '' }) useSessionStore.setState({ currentSessionId: '' } as never) + useRuntimeInsightStore.getState().reset() useChatStore.setState({ isGenerating: false, + isCompacting: false, + compactMode: '', + compactMessage: '', messages: [], permissionRequests: [], agentMode: 'build', permissionMode: 'default', + tokenUsage: null, } as never) }) @@ -140,6 +161,60 @@ describe('ChatInput', () => { expect(screen.queryByTitle('附件文件')).not.toBeInTheDocument() expect(screen.queryByTitle('引用上下文')).not.toBeInTheDocument() }) + it('shows inline compact status while compaction is running', () => { + useChatStore.getState().startCompacting('manual', 'Compacting context...') + + render() + + expect(screen.getByRole('status')).toHaveTextContent('Compacting context...') + }) + + it('blocks normal sends while compaction is running', async () => { + useChatStore.getState().startCompacting('manual', 'Compacting context...') + render() + + const textarea = screen.getByRole('textbox') + fireEvent.change(textarea, { target: { value: 'hello' } }) + fireEvent.keyDown(textarea, { key: 'Enter' }) + + await waitFor(() => { + expect(mockGatewayAPI.run).not.toHaveBeenCalled() + }) + expect(useChatStore.getState().messages).toHaveLength(0) + }) + + it('blocks duplicate compact commands while compaction is running', async () => { + useSessionStore.setState({ currentSessionId: 'session-1' } as never) + useChatStore.getState().startCompacting('manual', 'Compacting context...') + render() + + await submitSlashCommand('/compact') + + await waitFor(() => { + expect(mockGatewayAPI.compact).not.toHaveBeenCalled() + }) + }) + + it('sets compact state immediately when running /compact', async () => { + useSessionStore.setState({ currentSessionId: 'session-1' } as never) + let resolveCompact: (value: unknown) => void = () => {} + mockGatewayAPI.compact.mockReturnValueOnce(new Promise((resolve) => { + resolveCompact = resolve + })) + render() + + await submitSlashCommand('/compact') + + await waitFor(() => { + expect(useChatStore.getState().isCompacting).toBe(true) + }) + + resolveCompact({}) + await waitFor(() => { + expect(useChatStore.getState().isCompacting).toBe(false) + }) + }) + it('executes /memo without session id and shows payload.Content', async () => { mockGatewayAPI.executeSystemTool.mockResolvedValueOnce({ payload: { @@ -203,4 +278,69 @@ describe('ChatInput', () => { expect(mockGatewayAPI.executeSystemTool).not.toHaveBeenCalled() }) }) + + it('shows a green budget ring below the warning threshold', () => { + const ring = renderWithBudget({ + action: 'allow', + estimated_input_tokens: 80, + prompt_budget: 100, + context_window: 200, + }) + + expect(ring).toHaveAttribute('stroke', 'var(--success)') + }) + + it('shows a yellow budget ring near the automatic compact threshold', () => { + const ring = renderWithBudget({ + action: 'allow', + estimated_input_tokens: 90, + prompt_budget: 100, + context_window: 200, + }) + + expect(ring).toHaveAttribute('stroke', 'var(--warning)') + }) + + it('shows a red budget ring near the context window limit', () => { + const ring = renderWithBudget({ + action: 'allow', + estimated_input_tokens: 190, + prompt_budget: 100, + context_window: 200, + }) + + expect(ring).toHaveAttribute('stroke', 'var(--error)') + }) + + it('falls back to prompt budget as the limit when context window is missing', () => { + const ring = renderWithBudget({ + action: 'allow', + estimated_input_tokens: 100, + prompt_budget: 100, + }) + + expect(ring).toHaveAttribute('stroke', 'var(--error)') + }) + + it('honors compact budget action as a yellow color override', () => { + const ring = renderWithBudget({ + action: 'compact', + estimated_input_tokens: 20, + prompt_budget: 100, + context_window: 200, + }) + + expect(ring).toHaveAttribute('stroke', 'var(--warning)') + }) + + it('honors stop budget action as a red color override', () => { + const ring = renderWithBudget({ + action: 'stop', + estimated_input_tokens: 20, + prompt_budget: 100, + context_window: 200, + }) + + expect(ring).toHaveAttribute('stroke', 'var(--error)') + }) }) diff --git a/web/src/components/chat/ChatInput.tsx b/web/src/components/chat/ChatInput.tsx index d5a4bb2e..5669029f 100644 --- a/web/src/components/chat/ChatInput.tsx +++ b/web/src/components/chat/ChatInput.tsx @@ -19,7 +19,7 @@ import { import SlashCommandMenu from './SlashCommandMenu' import SkillPicker from './SkillPicker' import ModelSelector from './ModelSelector' -import { Send, Square } from 'lucide-react' +import { LoaderCircle, Send, Square } from 'lucide-react' const slashMenuAnchorStyle: React.CSSProperties = { position: 'absolute', @@ -28,6 +28,9 @@ const slashMenuAnchorStyle: React.CSSProperties = { zIndex: 100, } +const budgetWarningThresholdRatio = 0.9 +const budgetDangerThresholdRatio = 0.95 + /** 将网关返回的技能列表转换成输入框使用的 slash 命令结构。 */ function buildSkillSlashCommands( skills: Array<{ descriptor: { id: string; description?: string }; active?: boolean }>, @@ -64,6 +67,59 @@ function extractSystemToolContent(result: unknown, fallback: string): string { return content || fallback } +/** 将预算事件转换为输入框圆环的语义状态,保持阈值和颜色判断集中。 */ +function resolveBudgetRingState( + budgetChecked: ReturnType['budgetChecked'], + budgetEstimateFailed: ReturnType['budgetEstimateFailed'], +) { + if (budgetEstimateFailed) { + return { + color: 'var(--error)', + label: '预算估算失败', + ratio: 0, + } + } + if (!budgetChecked) { + return { + color: 'var(--text-tertiary)', + label: '暂无预算数据', + ratio: 0, + } + } + + const estimatedTokens = Math.max(0, budgetChecked.estimated_input_tokens) + const promptBudget = Math.max(0, budgetChecked.prompt_budget) + const contextLimit = Math.max(0, budgetChecked.context_window || promptBudget) + const ringRatio = contextLimit > 0 ? Math.min(estimatedTokens / contextLimit, 1) : 0 + + if ( + budgetChecked.action === 'stop' || + (contextLimit > 0 && estimatedTokens >= contextLimit * budgetDangerThresholdRatio) || + (!budgetChecked.context_window && promptBudget > 0 && estimatedTokens >= promptBudget) + ) { + return { + color: 'var(--error)', + label: '接近上下文上限', + ratio: ringRatio, + } + } + if ( + budgetChecked.action === 'compact' || + (promptBudget > 0 && estimatedTokens >= promptBudget * budgetWarningThresholdRatio) + ) { + return { + color: 'var(--warning)', + label: '接近自动压缩阈值', + ratio: ringRatio, + } + } + return { + color: 'var(--success)', + label: '正常', + ratio: ringRatio, + } +} + export default function ChatInput() { const gatewayAPI = useGatewayAPI() const text = useComposerStore((state) => state.composerText) @@ -73,6 +129,8 @@ export default function ChatInput() { const runCancelledRef = useRef(false) const composingRef = useRef(false) const isGenerating = useChatStore((state) => state.isGenerating) + const isCompacting = useChatStore((state) => state.isCompacting) + const compactMessage = useChatStore((state) => state.compactMessage) const addMessage = useChatStore((state) => state.addMessage) const addSystemMessage = useChatStore((state) => state.addSystemMessage) const setGenerating = useChatStore((state) => state.setGenerating) @@ -132,6 +190,10 @@ export default function ChatInput() { const { command, argument } = parsed const currentSessionId = sessionId const api = gatewayAPI + if (isCompacting) { + useUIStore.getState().showToast('Context compaction is still running', 'info') + return true + } if (!api) { useUIStore.getState().showToast('Gateway not connected', 'error') return true @@ -147,11 +209,17 @@ export default function ChatInput() { useUIStore.getState().showToast('Send a message first to start a session', 'error') return true } + useChatStore.getState().startCompacting('manual', 'Compacting context...') try { await api.compact(currentSessionId, '') } catch (err) { console.error('Compact failed:', err) - useUIStore.getState().showToast('Compaction failed', 'error') + if (useChatStore.getState().isCompacting) { + useChatStore.getState().finishCompacting() + useUIStore.getState().showToast('Compaction failed', 'error') + } + } finally { + useChatStore.getState().finishCompacting() } return true } @@ -205,8 +273,8 @@ export default function ChatInput() { return true } default: { - if (isGenerating) { - useUIStore.getState().showToast('Cannot toggle skill while generating', 'info') + if (isGenerating || isCompacting) { + useUIStore.getState().showToast(isCompacting ? 'Context compaction is still running' : 'Cannot toggle skill while generating', 'info') return true } const skillCommand = availableSkillCommands.find((skill) => skill.usage === command) @@ -231,12 +299,17 @@ export default function ChatInput() { return false } } - }, [gatewayAPI, sessionId, addSystemMessage, availableSkillCommands, isGenerating, allSlashCommands]) + }, [gatewayAPI, sessionId, addSystemMessage, availableSkillCommands, isGenerating, isCompacting, allSlashCommands]) async function handleSubmit() { const input = text.trim() if (!input) return + if (isCompacting) { + useUIStore.getState().showToast('Context compaction is still running', 'info') + return + } + if (isGenerating) { if (isSlashCommand(input)) useUIStore.getState().showToast('Cannot run commands while generating', 'info') return @@ -363,6 +436,7 @@ export default function ChatInput() { } const isEmpty = !text.trim() + const controlsLocked = isGenerating || isCompacting return ( <> @@ -382,7 +456,13 @@ export default function ChatInput() { /> )} -
+
+ {isCompacting && ( +
+ + {compactMessage || 'Compacting context...'} +
+ )}