From e3bd9369d5638fcf0a262c95d1938708ceadee5f Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Thu, 22 Jan 2026 13:56:25 +0000 Subject: [PATCH 01/27] chore(lib): extract Conversation interface --- lib/httpapi/server.go | 12 +- lib/screentracker/conversation.go | 458 ++---------------- lib/screentracker/diff.go | 56 +++ lib/screentracker/diff_internal_test.go | 39 ++ lib/screentracker/pty_conversation.go | 371 ++++++++++++++ ...ation_test.go => pty_conversation_test.go} | 166 ++----- 6 files changed, 554 insertions(+), 548 deletions(-) create mode 100644 lib/screentracker/diff.go create mode 100644 lib/screentracker/diff_internal_test.go create mode 100644 lib/screentracker/pty_conversation.go rename lib/screentracker/{conversation_test.go => pty_conversation_test.go} (75%) diff --git a/lib/httpapi/server.go b/lib/httpapi/server.go index 59497873..fd0a90c5 100644 --- a/lib/httpapi/server.go +++ b/lib/httpapi/server.go @@ -40,7 +40,7 @@ type Server struct { srv *http.Server mu sync.RWMutex logger *slog.Logger - conversation *st.Conversation + conversation *st.PTYConversation agentio *termexec.Process agentType mf.AgentType emitter *EventEmitter @@ -237,7 +237,7 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { return mf.FormatToolCall(config.AgentType, message) } - conversation := st.NewConversation(ctx, st.ConversationConfig{ + conversation := st.NewPTY(ctx, st.PTYConversationConfig{ AgentType: config.AgentType, AgentIO: config.Process, GetTime: func() time.Time { @@ -331,7 +331,7 @@ func sseMiddleware(ctx huma.Context, next func(huma.Context)) { } func (s *Server) StartSnapshotLoop(ctx context.Context) { - s.conversation.StartSnapshotLoop(ctx) + s.conversation.Start(ctx) go func() { for { currentStatus := s.conversation.Status() @@ -339,7 +339,7 @@ func (s *Server) StartSnapshotLoop(ctx context.Context) { // Send initial prompt when agent becomes stable for the first time if !s.conversation.InitialPromptSent && convertStatus(currentStatus) == AgentStatusStable { - if err := s.conversation.SendMessage(FormatMessage(s.agentType, s.conversation.InitialPrompt)...); err != nil { + if err := s.conversation.Send(FormatMessage(s.agentType, s.conversation.InitialPrompt)...); err != nil { s.logger.Error("Failed to send initial prompt", "error", err) } else { s.conversation.InitialPromptSent = true @@ -350,7 +350,7 @@ func (s *Server) StartSnapshotLoop(ctx context.Context) { } s.emitter.UpdateStatusAndEmitChanges(currentStatus, s.agentType) s.emitter.UpdateMessagesAndEmitChanges(s.conversation.Messages()) - s.emitter.UpdateScreenAndEmitChanges(s.conversation.Screen()) + s.emitter.UpdateScreenAndEmitChanges(s.conversation.String()) time.Sleep(snapshotInterval) } }() @@ -449,7 +449,7 @@ func (s *Server) createMessage(ctx context.Context, input *MessageRequest) (*Mes switch input.Body.Type { case MessageTypeUser: - if err := s.conversation.SendMessage(FormatMessage(s.agentType, input.Body.Content)...); err != nil { + if err := s.conversation.Send(FormatMessage(s.agentType, input.Body.Content)...); err != nil { return nil, xerrors.Errorf("failed to send message: %w", err) } case MessageTypeRaw: diff --git a/lib/screentracker/conversation.go b/lib/screentracker/conversation.go index 97a74722..db8d82d1 100644 --- a/lib/screentracker/conversation.go +++ b/lib/screentracker/conversation.go @@ -2,55 +2,27 @@ package screentracker import ( "context" - "fmt" - "log/slog" - "strings" - "sync" "time" - "github.com/coder/agentapi/lib/msgfmt" "github.com/coder/agentapi/lib/util" "github.com/danielgtaylor/huma/v2" "golang.org/x/xerrors" ) -type screenSnapshot struct { - timestamp time.Time - screen string -} - -type AgentIO interface { - Write(data []byte) (int, error) - ReadScreen() string -} +type ConversationStatus string -type ConversationConfig struct { - AgentType msgfmt.AgentType - AgentIO AgentIO - // GetTime returns the current time - GetTime func() time.Time - // How often to take a snapshot for the stability check - SnapshotInterval time.Duration - // How long the screen should not change to be considered stable - ScreenStabilityLength time.Duration - // Function to format the messages received from the agent - // userInput is the last user message - FormatMessage func(message string, userInput string) string - // SkipWritingMessage skips the writing of a message to the agent. - // This is used in tests - SkipWritingMessage bool - // SkipSendMessageStatusCheck skips the check for whether the message can be sent. - // This is used in tests - SkipSendMessageStatusCheck bool - // ReadyForInitialPrompt detects whether the agent has initialized and is ready to accept the initial prompt - ReadyForInitialPrompt func(message string) bool - // FormatToolCall removes the coder report_task tool call from the agent message and also returns the array of removed tool calls - FormatToolCall func(message string) (string, []string) - Logger *slog.Logger -} +const ( + ConversationStatusChanging ConversationStatus = "changing" + ConversationStatusStable ConversationStatus = "stable" + ConversationStatusInitializing ConversationStatus = "initializing" +) type ConversationRole string +func (c ConversationRole) Schema(r huma.Registry) *huma.Schema { + return util.OpenAPISchema(r, "ConversationRole", ConversationRoleValues) +} + const ( ConversationRoleUser ConversationRole = "user" ConversationRoleAgent ConversationRole = "agent" @@ -61,207 +33,15 @@ var ConversationRoleValues = []ConversationRole{ ConversationRoleAgent, } -func (c ConversationRole) Schema(r huma.Registry) *huma.Schema { - return util.OpenAPISchema(r, "ConversationRole", ConversationRoleValues) -} - -type ConversationMessage struct { - Id int - Message string - Role ConversationRole - Time time.Time -} - -type Conversation struct { - cfg ConversationConfig - // How many stable snapshots are required to consider the screen stable - stableSnapshotsThreshold int - snapshotBuffer *RingBuffer[screenSnapshot] - messages []ConversationMessage - screenBeforeLastUserMessage string - lock sync.Mutex - // InitialPrompt is the initial prompt passed to the agent - InitialPrompt string - // InitialPromptSent keeps track if the InitialPrompt has been successfully sent to the agents - InitialPromptSent bool - // ReadyForInitialPrompt keeps track if the agent is ready to accept the initial prompt - ReadyForInitialPrompt bool - // toolCallMessageSet keeps track of the tool calls that have been detected & logged in the current agent message - toolCallMessageSet map[string]bool -} - -type ConversationStatus string - -const ( - ConversationStatusChanging ConversationStatus = "changing" - ConversationStatusStable ConversationStatus = "stable" - ConversationStatusInitializing ConversationStatus = "initializing" +var ( + MessageValidationErrorWhitespace = xerrors.New("message must be trimmed of leading and trailing whitespace") + MessageValidationErrorEmpty = xerrors.New("message must not be empty") + MessageValidationErrorChanging = xerrors.New("message can only be sent when the agent is waiting for user input") ) -func getStableSnapshotsThreshold(cfg ConversationConfig) int { - length := cfg.ScreenStabilityLength.Milliseconds() - interval := cfg.SnapshotInterval.Milliseconds() - threshold := int(length / interval) - if length%interval != 0 { - threshold++ - } - return threshold + 1 -} - -func NewConversation(ctx context.Context, cfg ConversationConfig, initialPrompt string) *Conversation { - threshold := getStableSnapshotsThreshold(cfg) - c := &Conversation{ - cfg: cfg, - stableSnapshotsThreshold: threshold, - snapshotBuffer: NewRingBuffer[screenSnapshot](threshold), - messages: []ConversationMessage{ - { - Message: "", - Role: ConversationRoleAgent, - Time: cfg.GetTime(), - }, - }, - InitialPrompt: initialPrompt, - InitialPromptSent: len(initialPrompt) == 0, - toolCallMessageSet: make(map[string]bool), - } - return c -} - -func (c *Conversation) StartSnapshotLoop(ctx context.Context) { - go func() { - for { - select { - case <-ctx.Done(): - return - case <-time.After(c.cfg.SnapshotInterval): - // It's important that we hold the lock while reading the screen. - // There's a race condition that occurs without it: - // 1. The screen is read - // 2. Independently, SendMessage is called and takes the lock. - // 3. AddSnapshot is called and waits on the lock. - // 4. SendMessage modifies the terminal state, releases the lock - // 5. AddSnapshot adds a snapshot from a stale screen - c.lock.Lock() - screen := c.cfg.AgentIO.ReadScreen() - c.addSnapshotInner(screen) - c.lock.Unlock() - } - } - }() -} - -func FindNewMessage(oldScreen, newScreen string, agentType msgfmt.AgentType) string { - oldLines := strings.Split(oldScreen, "\n") - newLines := strings.Split(newScreen, "\n") - oldLinesMap := make(map[string]bool) - - // -1 indicates no header - dynamicHeaderEnd := -1 - - // Skip header lines for Opencode agent type to avoid false positives - // The header contains dynamic content (token count, context percentage, cost) - // that changes between screens, causing line comparison mismatches: - // - // ┃ # Getting Started with Claude CLI ┃ - // ┃ /share to create a shareable link 12.6K/6% ($0.05) ┃ - if len(newLines) >= 2 && agentType == msgfmt.AgentTypeOpencode { - dynamicHeaderEnd = 2 - } - - for _, line := range oldLines { - oldLinesMap[line] = true - } - firstNonMatchingLine := len(newLines) - for i, line := range newLines[dynamicHeaderEnd+1:] { - if !oldLinesMap[line] { - firstNonMatchingLine = i - break - } - } - newSectionLines := newLines[firstNonMatchingLine:] - - // remove leading and trailing lines which are empty or have only whitespace - startLine := 0 - endLine := len(newSectionLines) - 1 - for i := 0; i < len(newSectionLines); i++ { - if strings.TrimSpace(newSectionLines[i]) != "" { - startLine = i - break - } - } - for i := len(newSectionLines) - 1; i >= 0; i-- { - if strings.TrimSpace(newSectionLines[i]) != "" { - endLine = i - break - } - } - return strings.Join(newSectionLines[startLine:endLine+1], "\n") -} - -func (c *Conversation) lastMessage(role ConversationRole) ConversationMessage { - for i := len(c.messages) - 1; i >= 0; i-- { - if c.messages[i].Role == role { - return c.messages[i] - } - } - return ConversationMessage{} -} - -// This function assumes that the caller holds the lock -func (c *Conversation) updateLastAgentMessage(screen string, timestamp time.Time) { - agentMessage := FindNewMessage(c.screenBeforeLastUserMessage, screen, c.cfg.AgentType) - lastUserMessage := c.lastMessage(ConversationRoleUser) - var toolCalls []string - if c.cfg.FormatMessage != nil { - agentMessage = c.cfg.FormatMessage(agentMessage, lastUserMessage.Message) - } - if c.cfg.FormatToolCall != nil { - agentMessage, toolCalls = c.cfg.FormatToolCall(agentMessage) - } - for _, toolCall := range toolCalls { - if c.toolCallMessageSet[toolCall] == false { - c.toolCallMessageSet[toolCall] = true - c.cfg.Logger.Info("Tool call detected", "toolCall", toolCall) - } - } - shouldCreateNewMessage := len(c.messages) == 0 || c.messages[len(c.messages)-1].Role == ConversationRoleUser - lastAgentMessage := c.lastMessage(ConversationRoleAgent) - if lastAgentMessage.Message == agentMessage { - return - } - conversationMessage := ConversationMessage{ - Message: agentMessage, - Role: ConversationRoleAgent, - Time: timestamp, - } - if shouldCreateNewMessage { - c.messages = append(c.messages, conversationMessage) - - // Cleanup - c.toolCallMessageSet = make(map[string]bool) - - } else { - c.messages[len(c.messages)-1] = conversationMessage - } - c.messages[len(c.messages)-1].Id = len(c.messages) - 1 -} - -// assumes the caller holds the lock -func (c *Conversation) addSnapshotInner(screen string) { - snapshot := screenSnapshot{ - timestamp: c.cfg.GetTime(), - screen: screen, - } - c.snapshotBuffer.Add(snapshot) - c.updateLastAgentMessage(screen, snapshot.timestamp) -} - -func (c *Conversation) AddSnapshot(screen string) { - c.lock.Lock() - defer c.lock.Unlock() - - c.addSnapshotInner(screen) +type AgentIO interface { + Write(data []byte) (int, error) + ReadScreen() string } type MessagePart interface { @@ -269,198 +49,18 @@ type MessagePart interface { String() string } -type MessagePartText struct { - Content string - Alias string - Hidden bool -} - -func (p MessagePartText) Do(writer AgentIO) error { - _, err := writer.Write([]byte(p.Content)) - return err -} - -func (p MessagePartText) String() string { - if p.Hidden { - return "" - } - if p.Alias != "" { - return p.Alias - } - return p.Content -} - -func PartsToString(parts ...MessagePart) string { - var sb strings.Builder - for _, part := range parts { - sb.WriteString(part.String()) - } - return sb.String() -} - -func ExecuteParts(writer AgentIO, parts ...MessagePart) error { - for _, part := range parts { - if err := part.Do(writer); err != nil { - return xerrors.Errorf("failed to write message part: %w", err) - } - } - return nil -} - -func (c *Conversation) writeMessageWithConfirmation(ctx context.Context, messageParts ...MessagePart) error { - if c.cfg.SkipWritingMessage { - return nil - } - screenBeforeMessage := c.cfg.AgentIO.ReadScreen() - if err := ExecuteParts(c.cfg.AgentIO, messageParts...); err != nil { - return xerrors.Errorf("failed to write message: %w", err) - } - // wait for the screen to stabilize after the message is written - if err := util.WaitFor(ctx, util.WaitTimeout{ - Timeout: 15 * time.Second, - MinInterval: 50 * time.Millisecond, - InitialWait: true, - }, func() (bool, error) { - screen := c.cfg.AgentIO.ReadScreen() - if screen != screenBeforeMessage { - time.Sleep(1 * time.Second) - newScreen := c.cfg.AgentIO.ReadScreen() - return newScreen == screen, nil - } - return false, nil - }); err != nil { - return xerrors.Errorf("failed to wait for screen to stabilize: %w", err) - } - - // wait for the screen to change after the carriage return is written - screenBeforeCarriageReturn := c.cfg.AgentIO.ReadScreen() - lastCarriageReturnTime := time.Time{} - if err := util.WaitFor(ctx, util.WaitTimeout{ - Timeout: 15 * time.Second, - MinInterval: 25 * time.Millisecond, - }, func() (bool, error) { - // we don't want to spam additional carriage returns because the agent may process them - // (aider does this), but we do want to retry sending one if nothing's - // happening for a while - if time.Since(lastCarriageReturnTime) >= 3*time.Second { - lastCarriageReturnTime = time.Now() - if _, err := c.cfg.AgentIO.Write([]byte("\r")); err != nil { - return false, xerrors.Errorf("failed to write carriage return: %w", err) - } - } - time.Sleep(25 * time.Millisecond) - screen := c.cfg.AgentIO.ReadScreen() - - return screen != screenBeforeCarriageReturn, nil - }); err != nil { - return xerrors.Errorf("failed to wait for processing to start: %w", err) - } - - return nil -} - -var MessageValidationErrorWhitespace = xerrors.New("message must be trimmed of leading and trailing whitespace") -var MessageValidationErrorEmpty = xerrors.New("message must not be empty") -var MessageValidationErrorChanging = xerrors.New("message can only be sent when the agent is waiting for user input") - -func (c *Conversation) SendMessage(messageParts ...MessagePart) error { - c.lock.Lock() - defer c.lock.Unlock() - - if !c.cfg.SkipSendMessageStatusCheck && c.statusInner() != ConversationStatusStable { - return MessageValidationErrorChanging - } - - message := PartsToString(messageParts...) - if message != msgfmt.TrimWhitespace(message) { - // msgfmt formatting functions assume this - return MessageValidationErrorWhitespace - } - if message == "" { - // writeMessageWithConfirmation requires a non-empty message - return MessageValidationErrorEmpty - } - - screenBeforeMessage := c.cfg.AgentIO.ReadScreen() - now := c.cfg.GetTime() - c.updateLastAgentMessage(screenBeforeMessage, now) - - if err := c.writeMessageWithConfirmation(context.Background(), messageParts...); err != nil { - return xerrors.Errorf("failed to send message: %w", err) - } - - c.screenBeforeLastUserMessage = screenBeforeMessage - c.messages = append(c.messages, ConversationMessage{ - Id: len(c.messages), - Message: message, - Role: ConversationRoleUser, - Time: now, - }) - return nil -} - -// Assumes that the caller holds the lock -func (c *Conversation) statusInner() ConversationStatus { - // sanity checks - if c.snapshotBuffer.Capacity() != c.stableSnapshotsThreshold { - panic(fmt.Sprintf("snapshot buffer capacity %d is not equal to snapshot threshold %d. can't check stability", c.snapshotBuffer.Capacity(), c.stableSnapshotsThreshold)) - } - if c.stableSnapshotsThreshold == 0 { - panic("stable snapshots threshold is 0. can't check stability") - } - - snapshots := c.snapshotBuffer.GetAll() - if len(c.messages) > 0 && c.messages[len(c.messages)-1].Role == ConversationRoleUser { - // if the last message is a user message then the snapshot loop hasn't - // been triggered since the last user message, and we should assume - // the screen is changing - return ConversationStatusChanging - } - - if len(snapshots) != c.stableSnapshotsThreshold { - return ConversationStatusInitializing - } - - for i := 1; i < len(snapshots); i++ { - if snapshots[0].screen != snapshots[i].screen { - return ConversationStatusChanging - } - } - - if !c.InitialPromptSent && !c.ReadyForInitialPrompt { - if len(snapshots) > 0 && c.cfg.ReadyForInitialPrompt(snapshots[len(snapshots)-1].screen) { - c.ReadyForInitialPrompt = true - return ConversationStatusStable - } - return ConversationStatusChanging - } - - return ConversationStatusStable -} - -func (c *Conversation) Status() ConversationStatus { - c.lock.Lock() - defer c.lock.Unlock() - - return c.statusInner() -} - -func (c *Conversation) Messages() []ConversationMessage { - c.lock.Lock() - defer c.lock.Unlock() - - result := make([]ConversationMessage, len(c.messages)) - copy(result, c.messages) - return result +// Conversation allows tracking of a conversation between a user and an agent. +type Conversation interface { + Messages() []ConversationMessage + Snapshot(string) + Start(context.Context) + Status() ConversationStatus + String() string } -func (c *Conversation) Screen() string { - c.lock.Lock() - defer c.lock.Unlock() - - snapshots := c.snapshotBuffer.GetAll() - if len(snapshots) == 0 { - return "" - } - return snapshots[len(snapshots)-1].screen +type ConversationMessage struct { + Id int + Message string + Role ConversationRole + Time time.Time } diff --git a/lib/screentracker/diff.go b/lib/screentracker/diff.go new file mode 100644 index 00000000..47c5b78c --- /dev/null +++ b/lib/screentracker/diff.go @@ -0,0 +1,56 @@ +package screentracker + +import ( + "strings" + + "github.com/coder/agentapi/lib/msgfmt" +) + +// screenDiff compares two screen states and attempts to find latest message of the given agent type. +func screenDiff(oldScreen, newScreen string, agentType msgfmt.AgentType) string { + oldLines := strings.Split(oldScreen, "\n") + newLines := strings.Split(newScreen, "\n") + oldLinesMap := make(map[string]bool) + + // -1 indicates no header + dynamicHeaderEnd := -1 + + // Skip header lines for Opencode agent type to avoid false positives + // The header contains dynamic content (token count, context percentage, cost) + // that changes between screens, causing line comparison mismatches: + // + // ┃ # Getting Started with Claude CLI ┃ + // ┃ /share to create a shareable link 12.6K/6% ($0.05) ┃ + if len(newLines) >= 2 && agentType == msgfmt.AgentTypeOpencode { + dynamicHeaderEnd = 2 + } + + for _, line := range oldLines { + oldLinesMap[line] = true + } + firstNonMatchingLine := len(newLines) + for i, line := range newLines[dynamicHeaderEnd+1:] { + if !oldLinesMap[line] { + firstNonMatchingLine = i + break + } + } + newSectionLines := newLines[firstNonMatchingLine:] + + // remove leading and trailing lines which are empty or have only whitespace + startLine := 0 + endLine := len(newSectionLines) - 1 + for i := range newSectionLines { + if strings.TrimSpace(newSectionLines[i]) != "" { + startLine = i + break + } + } + for i := len(newSectionLines) - 1; i >= 0; i-- { + if strings.TrimSpace(newSectionLines[i]) != "" { + endLine = i + break + } + } + return strings.Join(newSectionLines[startLine:endLine+1], "\n") +} diff --git a/lib/screentracker/diff_internal_test.go b/lib/screentracker/diff_internal_test.go new file mode 100644 index 00000000..d68bc36c --- /dev/null +++ b/lib/screentracker/diff_internal_test.go @@ -0,0 +1,39 @@ +package screentracker + +import ( + "embed" + "path" + "testing" + + "github.com/coder/agentapi/lib/msgfmt" + "github.com/stretchr/testify/assert" +) + +//go:embed testdata +var testdataDir embed.FS + +func TestScreenDiff(t *testing.T) { + t.Run("simple", func(t *testing.T) { + assert.Equal(t, "", screenDiff("123456", "123456", msgfmt.AgentTypeCustom)) + assert.Equal(t, "1234567", screenDiff("123456", "1234567", msgfmt.AgentTypeCustom)) + assert.Equal(t, "42", screenDiff("123", "123\n \n \n \n42", msgfmt.AgentTypeCustom)) + assert.Equal(t, "12342", screenDiff("123", "12342\n \n \n \n", msgfmt.AgentTypeCustom)) + assert.Equal(t, "42", screenDiff("123", "123\n \n \n \n42\n \n \n \n", msgfmt.AgentTypeCustom)) + assert.Equal(t, "42", screenDiff("89", "42", msgfmt.AgentTypeCustom)) + }) + + dir := "testdata/diff" + cases, err := testdataDir.ReadDir(dir) + assert.NoError(t, err) + for _, c := range cases { + t.Run(c.Name(), func(t *testing.T) { + before, err := testdataDir.ReadFile(path.Join(dir, c.Name(), "before.txt")) + assert.NoError(t, err) + after, err := testdataDir.ReadFile(path.Join(dir, c.Name(), "after.txt")) + assert.NoError(t, err) + expected, err := testdataDir.ReadFile(path.Join(dir, c.Name(), "expected.txt")) + assert.NoError(t, err) + assert.Equal(t, string(expected), screenDiff(string(before), string(after), msgfmt.AgentTypeCustom)) + }) + } +} diff --git a/lib/screentracker/pty_conversation.go b/lib/screentracker/pty_conversation.go new file mode 100644 index 00000000..91b956a1 --- /dev/null +++ b/lib/screentracker/pty_conversation.go @@ -0,0 +1,371 @@ +package screentracker + +import ( + "context" + "fmt" + "log/slog" + "strings" + "sync" + "time" + + "github.com/coder/agentapi/lib/msgfmt" + "github.com/coder/agentapi/lib/util" + "golang.org/x/xerrors" +) + +// A screenSnapshot represents a snapshot of the PTY at a specific time. +type screenSnapshot struct { + timestamp time.Time + screen string +} + +type MessagePartText struct { + Content string + Alias string + Hidden bool +} + +var _ MessagePart = &MessagePartText{} + +func (p MessagePartText) Do(writer AgentIO) error { + _, err := writer.Write([]byte(p.Content)) + return err +} + +func (p MessagePartText) String() string { + if p.Hidden { + return "" + } + if p.Alias != "" { + return p.Alias + } + return p.Content +} + +// PTYConversationConfig is the configuration for a PTYConversation. +type PTYConversationConfig struct { + AgentType msgfmt.AgentType + AgentIO AgentIO + // GetTime returns the current time + GetTime func() time.Time + // How often to take a snapshot for the stability check + SnapshotInterval time.Duration + // How long the screen should not change to be considered stable + ScreenStabilityLength time.Duration + // Function to format the messages received from the agent + // userInput is the last user message + FormatMessage func(message string, userInput string) string + // SkipWritingMessage skips the writing of a message to the agent. + // This is used in tests + SkipWritingMessage bool + // SkipSendMessageStatusCheck skips the check for whether the message can be sent. + // This is used in tests + SkipSendMessageStatusCheck bool + // ReadyForInitialPrompt detects whether the agent has initialized and is ready to accept the initial prompt + ReadyForInitialPrompt func(message string) bool + // FormatToolCall removes the coder report_task tool call from the agent message and also returns the array of removed tool calls + FormatToolCall func(message string) (string, []string) + Logger *slog.Logger +} + +func (cfg PTYConversationConfig) getStableSnapshotsThreshold() int { + length := cfg.ScreenStabilityLength.Milliseconds() + interval := cfg.SnapshotInterval.Milliseconds() + threshold := int(length / interval) + if length%interval != 0 { + threshold++ + } + return threshold + 1 +} + +// PTYConversation is a conversation that uses a pseudo-terminal (PTY) for communication. +// It uses a combination of polling and diffs to detect changes in the screen. +type PTYConversation struct { + cfg PTYConversationConfig + // How many stable snapshots are required to consider the screen stable + stableSnapshotsThreshold int + snapshotBuffer *RingBuffer[screenSnapshot] + messages []ConversationMessage + screenBeforeLastUserMessage string + lock sync.Mutex + + // InitialPrompt is the initial prompt passed to the agent + InitialPrompt string + // InitialPromptSent keeps track if the InitialPrompt has been successfully sent to the agents + InitialPromptSent bool + // ReadyForInitialPrompt keeps track if the agent is ready to accept the initial prompt + ReadyForInitialPrompt bool + // toolCallMessageSet keeps track of the tool calls that have been detected & logged in the current agent message + toolCallMessageSet map[string]bool +} + +var _ Conversation = &PTYConversation{} + +func NewPTY(ctx context.Context, cfg PTYConversationConfig, initialPrompt string) *PTYConversation { + threshold := cfg.getStableSnapshotsThreshold() + c := &PTYConversation{ + cfg: cfg, + stableSnapshotsThreshold: threshold, + snapshotBuffer: NewRingBuffer[screenSnapshot](threshold), + messages: []ConversationMessage{ + { + Message: "", + Role: ConversationRoleAgent, + Time: cfg.GetTime(), + }, + }, + InitialPrompt: initialPrompt, + InitialPromptSent: len(initialPrompt) == 0, + toolCallMessageSet: make(map[string]bool), + } + return c +} + +func (c *PTYConversation) Start(ctx context.Context) { + go func() { + for { + select { + case <-ctx.Done(): + return + case <-time.After(c.cfg.SnapshotInterval): + // It's important that we hold the lock while reading the screen. + // There's a race condition that occurs without it: + // 1. The screen is read + // 2. Independently, SendMessage is called and takes the lock. + // 3. AddSnapshot is called and waits on the lock. + // 4. SendMessage modifies the terminal state, releases the lock + // 5. AddSnapshot adds a snapshot from a stale screen + c.lock.Lock() + screen := c.cfg.AgentIO.ReadScreen() + c.snapshotLocked(screen) + c.lock.Unlock() + } + } + }() +} + +func (c *PTYConversation) lastMessage(role ConversationRole) ConversationMessage { + for i := len(c.messages) - 1; i >= 0; i-- { + if c.messages[i].Role == role { + return c.messages[i] + } + } + return ConversationMessage{} +} + +// caller MUST hold c.lock +func (c *PTYConversation) updateLastAgentMessageLocked(screen string, timestamp time.Time) { + agentMessage := screenDiff(c.screenBeforeLastUserMessage, screen, c.cfg.AgentType) + lastUserMessage := c.lastMessage(ConversationRoleUser) + var toolCalls []string + if c.cfg.FormatMessage != nil { + agentMessage = c.cfg.FormatMessage(agentMessage, lastUserMessage.Message) + } + if c.cfg.FormatToolCall != nil { + agentMessage, toolCalls = c.cfg.FormatToolCall(agentMessage) + } + for _, toolCall := range toolCalls { + if c.toolCallMessageSet[toolCall] == false { + c.toolCallMessageSet[toolCall] = true + c.cfg.Logger.Info("Tool call detected", "toolCall", toolCall) + } + } + shouldCreateNewMessage := len(c.messages) == 0 || c.messages[len(c.messages)-1].Role == ConversationRoleUser + lastAgentMessage := c.lastMessage(ConversationRoleAgent) + if lastAgentMessage.Message == agentMessage { + return + } + conversationMessage := ConversationMessage{ + Message: agentMessage, + Role: ConversationRoleAgent, + Time: timestamp, + } + if shouldCreateNewMessage { + c.messages = append(c.messages, conversationMessage) + + // Cleanup + c.toolCallMessageSet = make(map[string]bool) + + } else { + c.messages[len(c.messages)-1] = conversationMessage + } + c.messages[len(c.messages)-1].Id = len(c.messages) - 1 +} + +func (c *PTYConversation) Snapshot(screen string) { + c.lock.Lock() + defer c.lock.Unlock() + + c.snapshotLocked(screen) +} + +// caller MUST hold c.lock +func (c *PTYConversation) snapshotLocked(screen string) { + snapshot := screenSnapshot{ + timestamp: c.cfg.GetTime(), + screen: screen, + } + c.snapshotBuffer.Add(snapshot) + c.updateLastAgentMessageLocked(screen, snapshot.timestamp) +} + +func (c *PTYConversation) Send(messageParts ...MessagePart) error { + c.lock.Lock() + defer c.lock.Unlock() + + if !c.cfg.SkipSendMessageStatusCheck && c.statusLocked() != ConversationStatusStable { + return MessageValidationErrorChanging + } + + var sb strings.Builder + for _, part := range messageParts { + sb.WriteString(part.String()) + } + message := sb.String() + if message != msgfmt.TrimWhitespace(message) { + // msgfmt formatting functions assume this + return MessageValidationErrorWhitespace + } + if message == "" { + // writeMessageWithConfirmation requires a non-empty message + return MessageValidationErrorEmpty + } + + screenBeforeMessage := c.cfg.AgentIO.ReadScreen() + now := c.cfg.GetTime() + c.updateLastAgentMessageLocked(screenBeforeMessage, now) + + if err := c.writeStabilize(context.Background(), messageParts...); err != nil { + return xerrors.Errorf("failed to send message: %w", err) + } + + c.screenBeforeLastUserMessage = screenBeforeMessage + c.messages = append(c.messages, ConversationMessage{ + Id: len(c.messages), + Message: message, + Role: ConversationRoleUser, + Time: now, + }) + return nil +} + +// writeStabilize writes messageParts to the screen and waits for the screen to stabilize after the message is written. +func (c *PTYConversation) writeStabilize(ctx context.Context, messageParts ...MessagePart) error { + if c.cfg.SkipWritingMessage { + return nil + } + screenBeforeMessage := c.cfg.AgentIO.ReadScreen() + for _, part := range messageParts { + if err := part.Do(c.cfg.AgentIO); err != nil { + return xerrors.Errorf("failed to write message part: %w", err) + } + } + // wait for the screen to stabilize after the message is written + if err := util.WaitFor(ctx, util.WaitTimeout{ + Timeout: 15 * time.Second, + MinInterval: 50 * time.Millisecond, + InitialWait: true, + }, func() (bool, error) { + screen := c.cfg.AgentIO.ReadScreen() + if screen != screenBeforeMessage { + time.Sleep(1 * time.Second) + newScreen := c.cfg.AgentIO.ReadScreen() + return newScreen == screen, nil + } + return false, nil + }); err != nil { + return xerrors.Errorf("failed to wait for screen to stabilize: %w", err) + } + + // wait for the screen to change after the carriage return is written + screenBeforeCarriageReturn := c.cfg.AgentIO.ReadScreen() + lastCarriageReturnTime := time.Time{} + if err := util.WaitFor(ctx, util.WaitTimeout{ + Timeout: 15 * time.Second, + MinInterval: 25 * time.Millisecond, + }, func() (bool, error) { + // we don't want to spam additional carriage returns because the agent may process them + // (aider does this), but we do want to retry sending one if nothing's + // happening for a while + if time.Since(lastCarriageReturnTime) >= 3*time.Second { + lastCarriageReturnTime = time.Now() + if _, err := c.cfg.AgentIO.Write([]byte("\r")); err != nil { + return false, xerrors.Errorf("failed to write carriage return: %w", err) + } + } + time.Sleep(25 * time.Millisecond) + screen := c.cfg.AgentIO.ReadScreen() + + return screen != screenBeforeCarriageReturn, nil + }); err != nil { + return xerrors.Errorf("failed to wait for processing to start: %w", err) + } + + return nil +} + +func (c *PTYConversation) Status() ConversationStatus { + c.lock.Lock() + defer c.lock.Unlock() + + return c.statusLocked() +} + +// caller MUST hold c.lock +func (c *PTYConversation) statusLocked() ConversationStatus { + // sanity checks + if c.snapshotBuffer.Capacity() != c.stableSnapshotsThreshold { + panic(fmt.Sprintf("snapshot buffer capacity %d is not equal to snapshot threshold %d. can't check stability", c.snapshotBuffer.Capacity(), c.stableSnapshotsThreshold)) + } + if c.stableSnapshotsThreshold == 0 { + panic("stable snapshots threshold is 0. can't check stability") + } + + snapshots := c.snapshotBuffer.GetAll() + if len(c.messages) > 0 && c.messages[len(c.messages)-1].Role == ConversationRoleUser { + // if the last message is a user message then the snapshot loop hasn't + // been triggered since the last user message, and we should assume + // the screen is changing + return ConversationStatusChanging + } + + if len(snapshots) != c.stableSnapshotsThreshold { + return ConversationStatusInitializing + } + + for i := 1; i < len(snapshots); i++ { + if snapshots[0].screen != snapshots[i].screen { + return ConversationStatusChanging + } + } + + if !c.InitialPromptSent && !c.ReadyForInitialPrompt { + if len(snapshots) > 0 && c.cfg.ReadyForInitialPrompt(snapshots[len(snapshots)-1].screen) { + c.ReadyForInitialPrompt = true + return ConversationStatusStable + } + return ConversationStatusChanging + } + + return ConversationStatusStable +} + +func (c *PTYConversation) Messages() []ConversationMessage { + c.lock.Lock() + defer c.lock.Unlock() + + result := make([]ConversationMessage, len(c.messages)) + copy(result, c.messages) + return result +} + +func (c *PTYConversation) String() string { + c.lock.Lock() + defer c.lock.Unlock() + + snapshots := c.snapshotBuffer.GetAll() + if len(snapshots) == 0 { + return "" + } + return snapshots[len(snapshots)-1].screen +} diff --git a/lib/screentracker/conversation_test.go b/lib/screentracker/pty_conversation_test.go similarity index 75% rename from lib/screentracker/conversation_test.go rename to lib/screentracker/pty_conversation_test.go index 9b888813..6798de4d 100644 --- a/lib/screentracker/conversation_test.go +++ b/lib/screentracker/pty_conversation_test.go @@ -2,13 +2,10 @@ package screentracker_test import ( "context" - "embed" "fmt" - "path" "testing" "time" - "github.com/coder/agentapi/lib/msgfmt" "github.com/stretchr/testify/assert" st "github.com/coder/agentapi/lib/screentracker" @@ -19,7 +16,7 @@ type statusTestStep struct { status st.ConversationStatus } type statusTestParams struct { - cfg st.ConversationConfig + cfg st.PTYConversationConfig steps []statusTestStep } @@ -42,11 +39,11 @@ func statusTest(t *testing.T, params statusTestParams) { if params.cfg.GetTime == nil { params.cfg.GetTime = func() time.Time { return time.Now() } } - c := st.NewConversation(ctx, params.cfg, "") + c := st.NewPTY(ctx, params.cfg, "") assert.Equal(t, st.ConversationStatusInitializing, c.Status()) for i, step := range params.steps { - c.AddSnapshot(step.snapshot) + c.Snapshot(step.snapshot) assert.Equal(t, step.status, c.Status(), "step %d", i) } }) @@ -58,7 +55,7 @@ func TestConversation(t *testing.T) { initializing := st.ConversationStatusInitializing statusTest(t, statusTestParams{ - cfg: st.ConversationConfig{ + cfg: st.PTYConversationConfig{ SnapshotInterval: 1 * time.Second, ScreenStabilityLength: 2 * time.Second, // stability threshold: 3 @@ -76,7 +73,7 @@ func TestConversation(t *testing.T) { }) statusTest(t, statusTestParams{ - cfg: st.ConversationConfig{ + cfg: st.PTYConversationConfig{ SnapshotInterval: 2 * time.Second, ScreenStabilityLength: 3 * time.Second, // stability threshold: 3 @@ -95,7 +92,7 @@ func TestConversation(t *testing.T) { }) statusTest(t, statusTestParams{ - cfg: st.ConversationConfig{ + cfg: st.PTYConversationConfig{ SnapshotInterval: 6 * time.Second, ScreenStabilityLength: 14 * time.Second, // stability threshold: 4 @@ -133,11 +130,11 @@ func TestMessages(t *testing.T) { Time: now, } } - sendMsg := func(c *st.Conversation, msg string) error { - return c.SendMessage(st.MessagePartText{Content: msg}) + sendMsg := func(c *st.PTYConversation, msg string) error { + return c.Send(st.MessagePartText{Content: msg}) } - newConversation := func(opts ...func(*st.ConversationConfig)) *st.Conversation { - cfg := st.ConversationConfig{ + newConversation := func(opts ...func(*st.PTYConversationConfig)) *st.PTYConversation { + cfg := st.PTYConversationConfig{ GetTime: func() time.Time { return now }, SnapshotInterval: 1 * time.Second, ScreenStabilityLength: 2 * time.Second, @@ -147,7 +144,7 @@ func TestMessages(t *testing.T) { for _, opt := range opts { opt(&cfg) } - return st.NewConversation(context.Background(), cfg, "") + return st.NewPTY(context.Background(), cfg, "") } t.Run("messages are copied", func(t *testing.T) { @@ -167,7 +164,7 @@ func TestMessages(t *testing.T) { t.Run("whitespace-padding", func(t *testing.T) { c := newConversation() for _, msg := range []string{"123 ", " 123", "123\t\t", "\n123", "123\n\t", " \t123\n\t"} { - err := c.SendMessage(st.MessagePartText{Content: msg}) + err := c.Send(st.MessagePartText{Content: msg}) assert.Error(t, err, st.MessageValidationErrorWhitespace) } }) @@ -178,33 +175,33 @@ func TestMessages(t *testing.T) { }{ Time: now, } - c := newConversation(func(cfg *st.ConversationConfig) { + c := newConversation(func(cfg *st.PTYConversationConfig) { cfg.GetTime = func() time.Time { return nowWrapper.Time } }) - c.AddSnapshot("1") + c.Snapshot("1") msgs := c.Messages() assert.Equal(t, []st.ConversationMessage{ agentMsg(0, "1"), }, msgs) nowWrapper.Time = nowWrapper.Add(1 * time.Second) - c.AddSnapshot("1") + c.Snapshot("1") assert.Equal(t, msgs, c.Messages()) }) t.Run("tracking messages", func(t *testing.T) { agent := &testAgent{} - c := newConversation(func(cfg *st.ConversationConfig) { + c := newConversation(func(cfg *st.PTYConversationConfig) { cfg.AgentIO = agent }) // agent message is recorded when the first snapshot is added - c.AddSnapshot("1") + c.Snapshot("1") assert.Equal(t, []st.ConversationMessage{ agentMsg(0, "1"), }, c.Messages()) // agent message is updated when the screen changes - c.AddSnapshot("2") + c.Snapshot("2") assert.Equal(t, []st.ConversationMessage{ agentMsg(0, "2"), }, c.Messages()) @@ -218,7 +215,7 @@ func TestMessages(t *testing.T) { }, c.Messages()) // agent message is added after a user message - c.AddSnapshot("4") + c.Snapshot("4") assert.Equal(t, []st.ConversationMessage{ agentMsg(0, "2"), userMsg(1, "3"), @@ -236,9 +233,9 @@ func TestMessages(t *testing.T) { }, c.Messages()) // conversation status is changing right after a user message - c.AddSnapshot("7") - c.AddSnapshot("7") - c.AddSnapshot("7") + c.Snapshot("7") + c.Snapshot("7") + c.Snapshot("7") assert.Equal(t, st.ConversationStatusStable, c.Status()) agent.screen = "7" assert.NoError(t, sendMsg(c, "8")) @@ -254,21 +251,21 @@ func TestMessages(t *testing.T) { // conversation status is back to stable after a snapshot that // doesn't change the screen - c.AddSnapshot("7") + c.Snapshot("7") assert.Equal(t, st.ConversationStatusStable, c.Status()) }) t.Run("tracking messages overlap", func(t *testing.T) { agent := &testAgent{} - c := newConversation(func(cfg *st.ConversationConfig) { + c := newConversation(func(cfg *st.PTYConversationConfig) { cfg.AgentIO = agent }) // common overlap between screens is removed after a user message - c.AddSnapshot("1") + c.Snapshot("1") agent.screen = "1" assert.NoError(t, sendMsg(c, "2")) - c.AddSnapshot("1\n3") + c.Snapshot("1\n3") assert.Equal(t, []st.ConversationMessage{ agentMsg(0, "1"), userMsg(1, "2"), @@ -277,7 +274,7 @@ func TestMessages(t *testing.T) { agent.screen = "1\n3x" assert.NoError(t, sendMsg(c, "4")) - c.AddSnapshot("1\n3x\n5") + c.Snapshot("1\n3x\n5") assert.Equal(t, []st.ConversationMessage{ agentMsg(0, "1"), userMsg(1, "2"), @@ -289,7 +286,7 @@ func TestMessages(t *testing.T) { t.Run("format-message", func(t *testing.T) { agent := &testAgent{} - c := newConversation(func(cfg *st.ConversationConfig) { + c := newConversation(func(cfg *st.PTYConversationConfig) { cfg.AgentIO = agent cfg.FormatMessage = func(message string, userInput string) string { return message + " " + userInput @@ -302,7 +299,7 @@ func TestMessages(t *testing.T) { userMsg(1, "2"), }, c.Messages()) agent.screen = "x" - c.AddSnapshot("x") + c.Snapshot("x") assert.Equal(t, []st.ConversationMessage{ agentMsg(0, "1 "), userMsg(1, "2"), @@ -312,7 +309,7 @@ func TestMessages(t *testing.T) { t.Run("format-message", func(t *testing.T) { agent := &testAgent{} - c := newConversation(func(cfg *st.ConversationConfig) { + c := newConversation(func(cfg *st.PTYConversationConfig) { cfg.AgentIO = agent cfg.FormatMessage = func(message string, userInput string) string { return "formatted" @@ -329,7 +326,7 @@ func TestMessages(t *testing.T) { }) t.Run("send-message-status-check", func(t *testing.T) { - c := newConversation(func(cfg *st.ConversationConfig) { + c := newConversation(func(cfg *st.PTYConversationConfig) { cfg.SkipSendMessageStatusCheck = false cfg.SnapshotInterval = 1 * time.Second cfg.ScreenStabilityLength = 2 * time.Second @@ -337,10 +334,10 @@ func TestMessages(t *testing.T) { }) assert.Error(t, sendMsg(c, "1"), st.MessageValidationErrorChanging) for range 3 { - c.AddSnapshot("1") + c.Snapshot("1") } assert.NoError(t, sendMsg(c, "4")) - c.AddSnapshot("2") + c.Snapshot("2") assert.Error(t, sendMsg(c, "5"), st.MessageValidationErrorChanging) }) @@ -350,68 +347,11 @@ func TestMessages(t *testing.T) { }) } -//go:embed testdata -var testdataDir embed.FS - -func TestFindNewMessage(t *testing.T) { - assert.Equal(t, "", st.FindNewMessage("123456", "123456", msgfmt.AgentTypeCustom)) - assert.Equal(t, "1234567", st.FindNewMessage("123456", "1234567", msgfmt.AgentTypeCustom)) - assert.Equal(t, "42", st.FindNewMessage("123", "123\n \n \n \n42", msgfmt.AgentTypeCustom)) - assert.Equal(t, "12342", st.FindNewMessage("123", "12342\n \n \n \n", msgfmt.AgentTypeCustom)) - assert.Equal(t, "42", st.FindNewMessage("123", "123\n \n \n \n42\n \n \n \n", msgfmt.AgentTypeCustom)) - assert.Equal(t, "42", st.FindNewMessage("89", "42", msgfmt.AgentTypeCustom)) - - dir := "testdata/diff" - cases, err := testdataDir.ReadDir(dir) - assert.NoError(t, err) - for _, c := range cases { - t.Run(c.Name(), func(t *testing.T) { - before, err := testdataDir.ReadFile(path.Join(dir, c.Name(), "before.txt")) - assert.NoError(t, err) - after, err := testdataDir.ReadFile(path.Join(dir, c.Name(), "after.txt")) - assert.NoError(t, err) - expected, err := testdataDir.ReadFile(path.Join(dir, c.Name(), "expected.txt")) - assert.NoError(t, err) - assert.Equal(t, string(expected), st.FindNewMessage(string(before), string(after), msgfmt.AgentTypeCustom)) - }) - } -} - -func TestPartsToString(t *testing.T) { - assert.Equal(t, "123", st.PartsToString(st.MessagePartText{Content: "123"})) - assert.Equal(t, - "123", - st.PartsToString( - st.MessagePartText{Content: "1"}, - st.MessagePartText{Content: "2"}, - st.MessagePartText{Content: "3"}, - ), - ) - assert.Equal(t, - "123", - st.PartsToString( - st.MessagePartText{Content: "1"}, - st.MessagePartText{Content: "x", Hidden: true}, - st.MessagePartText{Content: "2"}, - st.MessagePartText{Content: "3"}, - st.MessagePartText{Content: "y", Hidden: true}, - ), - ) - assert.Equal(t, - "ab", - st.PartsToString( - st.MessagePartText{Content: "1", Alias: "a"}, - st.MessagePartText{Content: "2", Alias: "b"}, - st.MessagePartText{Content: "3", Alias: "c", Hidden: true}, - ), - ) -} - func TestInitialPromptReadiness(t *testing.T) { now := time.Now() t.Run("agent not ready - status remains changing", func(t *testing.T) { - cfg := st.ConversationConfig{ + cfg := st.PTYConversationConfig{ GetTime: func() time.Time { return now }, SnapshotInterval: 1 * time.Second, ScreenStabilityLength: 0, @@ -420,10 +360,10 @@ func TestInitialPromptReadiness(t *testing.T) { return message == "ready" }, } - c := st.NewConversation(context.Background(), cfg, "initial prompt here") + c := st.NewPTY(context.Background(), cfg, "initial prompt here") // Fill buffer with stable snapshots, but agent is not ready - c.AddSnapshot("loading...") + c.Snapshot("loading...") // Even though screen is stable, status should be changing because agent is not ready assert.Equal(t, st.ConversationStatusChanging, c.Status()) @@ -432,7 +372,7 @@ func TestInitialPromptReadiness(t *testing.T) { }) t.Run("agent becomes ready - status changes to stable", func(t *testing.T) { - cfg := st.ConversationConfig{ + cfg := st.PTYConversationConfig{ GetTime: func() time.Time { return now }, SnapshotInterval: 1 * time.Second, ScreenStabilityLength: 0, @@ -441,14 +381,14 @@ func TestInitialPromptReadiness(t *testing.T) { return message == "ready" }, } - c := st.NewConversation(context.Background(), cfg, "initial prompt here") + c := st.NewPTY(context.Background(), cfg, "initial prompt here") // Agent not ready initially - c.AddSnapshot("loading...") + c.Snapshot("loading...") assert.Equal(t, st.ConversationStatusChanging, c.Status()) // Agent becomes ready - c.AddSnapshot("ready") + c.Snapshot("ready") assert.Equal(t, st.ConversationStatusStable, c.Status()) assert.True(t, c.ReadyForInitialPrompt) assert.False(t, c.InitialPromptSent) @@ -456,7 +396,7 @@ func TestInitialPromptReadiness(t *testing.T) { t.Run("ready for initial prompt lifecycle: false -> true -> false", func(t *testing.T) { agent := &testAgent{screen: "loading..."} - cfg := st.ConversationConfig{ + cfg := st.PTYConversationConfig{ GetTime: func() time.Time { return now }, SnapshotInterval: 1 * time.Second, ScreenStabilityLength: 0, @@ -467,23 +407,23 @@ func TestInitialPromptReadiness(t *testing.T) { SkipWritingMessage: true, SkipSendMessageStatusCheck: true, } - c := st.NewConversation(context.Background(), cfg, "initial prompt here") + c := st.NewPTY(context.Background(), cfg, "initial prompt here") // Initial state: ReadyForInitialPrompt should be false - c.AddSnapshot("loading...") + c.Snapshot("loading...") assert.False(t, c.ReadyForInitialPrompt, "should start as false") assert.False(t, c.InitialPromptSent) assert.Equal(t, st.ConversationStatusChanging, c.Status()) // Agent becomes ready: ReadyForInitialPrompt should become true agent.screen = "ready" - c.AddSnapshot("ready") + c.Snapshot("ready") assert.Equal(t, st.ConversationStatusStable, c.Status()) assert.True(t, c.ReadyForInitialPrompt, "should become true when ready") assert.False(t, c.InitialPromptSent) // Send the initial prompt - assert.NoError(t, c.SendMessage(st.MessagePartText{Content: "initial prompt here"})) + assert.NoError(t, c.Send(st.MessagePartText{Content: "initial prompt here"})) // After sending initial prompt: ReadyForInitialPrompt should be set back to false // (simulating what happens in the actual server code) @@ -496,7 +436,7 @@ func TestInitialPromptReadiness(t *testing.T) { }) t.Run("no initial prompt - normal status logic applies", func(t *testing.T) { - cfg := st.ConversationConfig{ + cfg := st.PTYConversationConfig{ GetTime: func() time.Time { return now }, SnapshotInterval: 1 * time.Second, ScreenStabilityLength: 0, @@ -506,9 +446,9 @@ func TestInitialPromptReadiness(t *testing.T) { }, } // Empty initial prompt means no need to wait for readiness - c := st.NewConversation(context.Background(), cfg, "") + c := st.NewPTY(context.Background(), cfg, "") - c.AddSnapshot("loading...") + c.Snapshot("loading...") // Status should be stable because no initial prompt to wait for assert.Equal(t, st.ConversationStatusStable, c.Status()) @@ -518,7 +458,7 @@ func TestInitialPromptReadiness(t *testing.T) { t.Run("initial prompt sent - normal status logic applies", func(t *testing.T) { agent := &testAgent{screen: "ready"} - cfg := st.ConversationConfig{ + cfg := st.PTYConversationConfig{ GetTime: func() time.Time { return now }, SnapshotInterval: 1 * time.Second, ScreenStabilityLength: 0, @@ -529,24 +469,24 @@ func TestInitialPromptReadiness(t *testing.T) { SkipWritingMessage: true, SkipSendMessageStatusCheck: true, } - c := st.NewConversation(context.Background(), cfg, "initial prompt here") + c := st.NewPTY(context.Background(), cfg, "initial prompt here") // First, agent becomes ready - c.AddSnapshot("ready") + c.Snapshot("ready") assert.Equal(t, st.ConversationStatusStable, c.Status()) assert.True(t, c.ReadyForInitialPrompt) assert.False(t, c.InitialPromptSent) // Send the initial prompt agent.screen = "processing..." - assert.NoError(t, c.SendMessage(st.MessagePartText{Content: "initial prompt here"})) + assert.NoError(t, c.Send(st.MessagePartText{Content: "initial prompt here"})) // Mark initial prompt as sent (simulating what the server does) c.InitialPromptSent = true c.ReadyForInitialPrompt = false // Now test that status logic works normally after initial prompt is sent - c.AddSnapshot("processing...") + c.Snapshot("processing...") // Status should be stable because initial prompt was already sent // and the readiness check is bypassed From a0f8bb563bd4cde3c40752c1f2fdecb454f995b7 Mon Sep 17 00:00:00 2001 From: 35C4n0r Date: Sat, 31 Jan 2026 20:52:52 +0530 Subject: [PATCH 02/27] feat: implement state persistence --- cmd/server/server.go | 38 +++++- lib/httpapi/events.go | 2 +- lib/httpapi/server.go | 162 ++++++++++++++++++++------ lib/httpapi/setup.go | 14 --- lib/screentracker/conversation.go | 9 ++ lib/screentracker/pty_conversation.go | 116 ++++++++++++++++++ 6 files changed, 287 insertions(+), 54 deletions(-) diff --git a/cmd/server/server.go b/cmd/server/server.go index 6a7fa7f0..5a125af4 100644 --- a/cmd/server/server.go +++ b/cmd/server/server.go @@ -103,6 +103,26 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er } } + // Get the variables related to state management + stateFile := viper.GetString(StateFile) + loadState := true + saveState := true + if stateFile != "" { + if !viper.IsSet(LoadState) { + loadState = true + } else { + loadState = viper.GetBool(LoadState) + } + + if !viper.IsSet(SaveState) { + saveState = true + } else { + saveState = viper.GetBool(SaveState) + } + } + + pidFile := viper.GetString(PidFile) + printOpenAPI := viper.GetBool(FlagPrintOpenAPI) var process *termexec.Process if printOpenAPI { @@ -128,7 +148,14 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er AllowedHosts: viper.GetStringSlice(FlagAllowedHosts), AllowedOrigins: viper.GetStringSlice(FlagAllowedOrigins), InitialPrompt: initialPrompt, + StatePersistenceCfg: httpapi.StatePersistenceCfg{ + StateFile: stateFile, + LoadState: loadState, + SaveState: saveState, + PidFile: pidFile, + }, }) + if err != nil { return xerrors.Errorf("failed to create server: %w", err) } @@ -137,6 +164,7 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er return nil } srv.StartSnapshotLoop(ctx) + srv.HandleSignals(ctx, process) logger.Info("Starting server on port", "port", port) processExitCh := make(chan error, 1) go func() { @@ -152,7 +180,7 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er logger.Error("Failed to stop server", "error", err) } }() - if err := srv.Start(); err != nil && err != context.Canceled && err != http.ErrServerClosed { + if err := srv.Start(); err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, http.ErrServerClosed) { return xerrors.Errorf("failed to start server: %w", err) } select { @@ -191,6 +219,10 @@ const ( FlagAllowedOrigins = "allowed-origins" FlagExit = "exit" FlagInitialPrompt = "initial-prompt" + StateFile = "state-file" + LoadState = "load-state" + SaveState = "save-state" + PidFile = "pid-file" ) func CreateServerCmd() *cobra.Command { @@ -229,6 +261,10 @@ func CreateServerCmd() *cobra.Command { // localhost:3284 is the default origin when you open the chat interface in your browser. localhost:3000 and 3001 are used during development. {FlagAllowedOrigins, "o", []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, "HTTP allowed origins. Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_ORIGINS env var", "stringSlice"}, {FlagInitialPrompt, "I", "", "Initial prompt for the agent. Recommended only if the agent doesn't support initial prompt in interaction mode. Will be read from stdin if piped (e.g., echo 'prompt' | agentapi server -- my-agent)", "string"}, + {StateFile, "s", "", "Path to file for saving/loading server state", "string"}, + {LoadState, "", false, "Load state from state-file on startup (defaults to true when state-file is set)", "bool"}, + {SaveState, "", false, "Save state to state-file on shutdown (defaults to true when state-file is set)", "bool"}, + {PidFile, "", "", "Path to file where the server process ID will be written for shutdown scripts", "string"}, } for _, spec := range flagSpecs { diff --git a/lib/httpapi/events.go b/lib/httpapi/events.go index 73eff07b..e8cabab6 100644 --- a/lib/httpapi/events.go +++ b/lib/httpapi/events.go @@ -120,7 +120,7 @@ func (e *EventEmitter) notifyChannels(eventType EventType, payload any) { } } -// Assumes that only the last message can change or new messages can be added. +// UpdateMessagesAndEmitChanges assumes that only the last message can change or new messages can be added. // If a new message is injected between existing messages (identified by Id), the behavior is undefined. func (e *EventEmitter) UpdateMessagesAndEmitChanges(newMessages []st.ConversationMessage) { e.mu.Lock() diff --git a/lib/httpapi/server.go b/lib/httpapi/server.go index cc330c6b..965fa28f 100644 --- a/lib/httpapi/server.go +++ b/lib/httpapi/server.go @@ -11,11 +11,13 @@ import ( "net/http" "net/url" "os" + "os/signal" "path/filepath" "slices" "sort" "strings" "sync" + "syscall" "time" "unicode" @@ -34,18 +36,20 @@ import ( // Server represents the HTTP server type Server struct { - router chi.Router - api huma.API - port int - srv *http.Server - mu sync.RWMutex - logger *slog.Logger - conversation *st.PTYConversation - agentio *termexec.Process - agentType mf.AgentType - emitter *EventEmitter - chatBasePath string - tempDir string + router chi.Router + api huma.API + port int + srv *http.Server + mu sync.RWMutex + logger *slog.Logger + conversation *st.PTYConversation + agentio *termexec.Process + agentType mf.AgentType + emitter *EventEmitter + chatBasePath string + tempDir string + statePersistenceCfg StatePersistenceCfg + stateLoadComplete bool } func (s *Server) NormalizeSchema(schema any) any { @@ -94,14 +98,22 @@ func (s *Server) GetOpenAPI() string { // because the action of taking a snapshot takes time too. const snapshotInterval = 25 * time.Millisecond +type StatePersistenceCfg struct { + StateFile string + LoadState bool + SaveState bool + PidFile string +} + type ServerConfig struct { - AgentType mf.AgentType - Process *termexec.Process - Port int - ChatBasePath string - AllowedHosts []string - AllowedOrigins []string - InitialPrompt string + AgentType mf.AgentType + Process *termexec.Process + Port int + ChatBasePath string + AllowedHosts []string + AllowedOrigins []string + InitialPrompt string + StatePersistenceCfg StatePersistenceCfg } // Validate allowed hosts don't contain whitespace, commas, schemes, or ports. @@ -260,16 +272,18 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { logger.Info("Created temporary directory for uploads", "tempDir", tempDir) s := &Server{ - router: router, - api: api, - port: config.Port, - conversation: conversation, - logger: logger, - agentio: config.Process, - agentType: config.AgentType, - emitter: emitter, - chatBasePath: strings.TrimSuffix(config.ChatBasePath, "/"), - tempDir: tempDir, + router: router, + api: api, + port: config.Port, + conversation: conversation, + logger: logger, + agentio: config.Process, + agentType: config.AgentType, + emitter: emitter, + chatBasePath: strings.TrimSuffix(config.ChatBasePath, "/"), + tempDir: tempDir, + statePersistenceCfg: config.StatePersistenceCfg, + stateLoadComplete: false, } // Register API routes @@ -337,15 +351,26 @@ func (s *Server) StartSnapshotLoop(ctx context.Context) { currentStatus := s.conversation.Status() // Send initial prompt when agent becomes stable for the first time - if !s.conversation.InitialPromptSent && convertStatus(currentStatus) == AgentStatusStable { - - if err := s.conversation.Send(FormatMessage(s.agentType, s.conversation.InitialPrompt)...); err != nil { - s.logger.Error("Failed to send initial prompt", "error", err) - } else { - s.conversation.InitialPromptSent = true - s.conversation.ReadyForInitialPrompt = false - currentStatus = st.ConversationStatusChanging - s.logger.Info("Initial prompt sent successfully") + if convertStatus(currentStatus) == AgentStatusStable { + + if !s.stateLoadComplete && s.statePersistenceCfg.LoadState { + _, err := s.conversation.LoadState(s.statePersistenceCfg.StateFile) + if err != nil { + s.logger.Warn("Failed to load state file", "path", s.statePersistenceCfg.StateFile, "err", err) + } else { + s.logger.Info("Successfully loaded state", "path", s.statePersistenceCfg.StateFile) + } + s.stateLoadComplete = true + } + if !s.conversation.InitialPromptSent { + if err := s.conversation.Send(FormatMessage(s.agentType, s.conversation.InitialPrompt)...); err != nil { + s.logger.Error("Failed to send initial prompt", "error", err) + } else { + s.conversation.InitialPromptSent = true + s.conversation.ReadyForInitialPrompt = false + currentStatus = st.ConversationStatusChanging + s.logger.Info("Initial prompt sent successfully") + } } } s.emitter.UpdateStatusAndEmitChanges(currentStatus, s.agentType) @@ -592,6 +617,15 @@ func (s *Server) Start() error { // Stop gracefully stops the HTTP server func (s *Server) Stop(ctx context.Context) error { + // Save conversation state if configured + if s.statePersistenceCfg.SaveState && s.statePersistenceCfg.StateFile != "" { + if err := s.conversation.SaveState(s.conversation.Messages(), s.statePersistenceCfg.StateFile); err != nil { + s.logger.Error("Failed to save conversation state", "error", err) + } else { + s.logger.Info("Saved conversation state", "stateFile", s.statePersistenceCfg.StateFile) + } + } + // Clean up temporary directory s.cleanupTempDir() @@ -610,6 +644,58 @@ func (s *Server) cleanupTempDir() { } } +// HandleSignals sets up signal handlers for: +// - SIGTERM, SIGINT, SIGHUP: save conversation state, then close the process +// - SIGUSR1: save conversation state without exiting +func (s *Server) HandleSignals(ctx context.Context, process *termexec.Process) { + // Handle shutdown signals (SIGTERM, SIGINT, SIGHUP) + shutdownCh := make(chan os.Signal, 1) + signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP) + go func() { + sig := <-shutdownCh + s.logger.Info("Received shutdown signal, saving state before closing process", "signal", sig) + + // Save conversation state if configured (synchronously before closing process) + if s.statePersistenceCfg.SaveState && s.statePersistenceCfg.StateFile != "" { + if err := s.conversation.SaveState(s.conversation.Messages(), s.statePersistenceCfg.StateFile); err != nil { + s.logger.Error("Failed to save conversation state on signal", "signal", sig, "error", err) + } else { + s.logger.Info("Saved conversation state on signal", "signal", sig, "stateFile", s.statePersistenceCfg.StateFile) + } + } + + // Now close the process + if err := process.Close(s.logger, 5*time.Second); err != nil { + s.logger.Error("Error closing process", "signal", sig, "error", err) + } + }() + + // Handle SIGUSR1 for save without exit + saveOnlyCh := make(chan os.Signal, 1) + signal.Notify(saveOnlyCh, syscall.SIGUSR1) + go func() { + for { + select { + case <-saveOnlyCh: + s.logger.Info("Received SIGUSR1, saving state without exiting") + + // Save conversation state if configured + if s.statePersistenceCfg.SaveState && s.statePersistenceCfg.StateFile != "" { + if err := s.conversation.SaveState(s.conversation.Messages(), s.statePersistenceCfg.StateFile); err != nil { + s.logger.Error("Failed to save conversation state on SIGUSR1", "error", err) + } else { + s.logger.Info("Saved conversation state on SIGUSR1", "stateFile", s.statePersistenceCfg.StateFile) + } + } else { + s.logger.Warn("SIGUSR1 received but state saving is not configured") + } + case <-ctx.Done(): + return + } + } + }() +} + // registerStaticFileRoutes sets up routes for serving static files func (s *Server) registerStaticFileRoutes() { chatHandler := FileServerWithIndexFallback(s.chatBasePath) diff --git a/lib/httpapi/setup.go b/lib/httpapi/setup.go index 16203041..c8d95b6e 100644 --- a/lib/httpapi/setup.go +++ b/lib/httpapi/setup.go @@ -4,10 +4,7 @@ import ( "context" "fmt" "os" - "os/signal" "strings" - "syscall" - "time" "github.com/coder/agentapi/lib/logctx" mf "github.com/coder/agentapi/lib/msgfmt" @@ -45,16 +42,5 @@ func SetupProcess(ctx context.Context, config SetupProcessConfig) (*termexec.Pro return nil, err } } - - // Handle SIGINT (Ctrl+C) and send it to the process - signalCh := make(chan os.Signal, 1) - signal.Notify(signalCh, os.Interrupt, syscall.SIGTERM) - go func() { - <-signalCh - if err := process.Close(logger, 5*time.Second); err != nil { - logger.Error("Error closing process", "error", err) - } - }() - return process, nil } diff --git a/lib/screentracker/conversation.go b/lib/screentracker/conversation.go index db8d82d1..daf129a1 100644 --- a/lib/screentracker/conversation.go +++ b/lib/screentracker/conversation.go @@ -52,6 +52,8 @@ type MessagePart interface { // Conversation allows tracking of a conversation between a user and an agent. type Conversation interface { Messages() []ConversationMessage + SaveState([]ConversationMessage, string) error + LoadState(string) ([]ConversationMessage, error) Snapshot(string) Start(context.Context) Status() ConversationStatus @@ -64,3 +66,10 @@ type ConversationMessage struct { Role ConversationRole Time time.Time } + +type AgentState struct { + Version int `json:"version"` + Messages []ConversationMessage `json:"messages"` + InitialPrompt string `json:"initial_prompt"` + InitialPromptSent bool `json:"initial_prompt_sent"` +} diff --git a/lib/screentracker/pty_conversation.go b/lib/screentracker/pty_conversation.go index 91b956a1..2f8e804f 100644 --- a/lib/screentracker/pty_conversation.go +++ b/lib/screentracker/pty_conversation.go @@ -2,8 +2,11 @@ package screentracker import ( "context" + "encoding/json" "fmt" "log/slog" + "os" + "path/filepath" "strings" "sync" "time" @@ -97,6 +100,12 @@ type PTYConversation struct { ReadyForInitialPrompt bool // toolCallMessageSet keeps track of the tool calls that have been detected & logged in the current agent message toolCallMessageSet map[string]bool + // dirty tracks whether the conversation state has changed since the last save + dirty bool + // firstStableSnapshot is the conversation history rolled out by the agent in case of a resume (given that the agent supports it) + firstStableSnapshot string + // userSentMessageAfterLoadState tracks if the user has sent their first message after we load the state + userSentMessageAfterLoadState bool } var _ Conversation = &PTYConversation{} @@ -161,6 +170,7 @@ func (c *PTYConversation) updateLastAgentMessageLocked(screen string, timestamp if c.cfg.FormatMessage != nil { agentMessage = c.cfg.FormatMessage(agentMessage, lastUserMessage.Message) } + agentMessage = c.skipInitialSnapshot(agentMessage) if c.cfg.FormatToolCall != nil { agentMessage, toolCalls = c.cfg.FormatToolCall(agentMessage) } @@ -190,6 +200,7 @@ func (c *PTYConversation) updateLastAgentMessageLocked(screen string, timestamp c.messages[len(c.messages)-1] = conversationMessage } c.messages[len(c.messages)-1].Id = len(c.messages) - 1 + c.dirty = true } func (c *PTYConversation) Snapshot(screen string) { @@ -246,6 +257,9 @@ func (c *PTYConversation) Send(messageParts ...MessagePart) error { Role: ConversationRoleUser, Time: now, }) + c.dirty = true + c.userSentMessageAfterLoadState = true + return nil } @@ -369,3 +383,105 @@ func (c *PTYConversation) String() string { } return snapshots[len(snapshots)-1].screen } + +func (c *PTYConversation) SaveState(conversation []ConversationMessage, stateFile string) error { + c.lock.Lock() + defer c.lock.Unlock() + + // Skip if state file is not configured + if stateFile == "" { + return nil + } + + // Skip if not dirty + if !c.dirty { + return nil + } + + // Use atomic write: write to temp file, then rename to target path + data, err := json.MarshalIndent(AgentState{ + Version: 1, + Messages: conversation, + InitialPrompt: c.InitialPrompt, + InitialPromptSent: c.InitialPromptSent, + }, "", " ") + if err != nil { + return xerrors.Errorf("failed to marshal state: %w", err) + } + + // Create directory if it doesn't exist + dir := filepath.Dir(stateFile) + if err := os.MkdirAll(dir, 0o755); err != nil { + return xerrors.Errorf("failed to create state directory: %w", err) + } + + // Write to temp file + tempFile := stateFile + ".tmp" + if err := os.WriteFile(tempFile, data, 0o644); err != nil { + return xerrors.Errorf("failed to write temp state file: %w", err) + } + + // Atomic rename + if err := os.Rename(tempFile, stateFile); err != nil { + return xerrors.Errorf("failed to rename state file: %w", err) + } + + // Clear dirty flag after successful save + c.dirty = false + return nil +} + +func (c *PTYConversation) LoadState(stateFile string) ([]ConversationMessage, error) { + c.lock.Lock() + defer c.lock.Unlock() + + // Skip if state file is not configured + if stateFile == "" { + return nil, nil + } + + // Check if file exists + if _, err := os.Stat(stateFile); os.IsNotExist(err) { + return nil, nil + } + + // Read state file + data, err := os.ReadFile(stateFile) + if err != nil { + return nil, xerrors.Errorf("failed to read state file: %w", err) + } + + if len(data) == 0 { + return nil, xerrors.Errorf("failed to read state file: empty state file") + } + + var agentState AgentState + if err := json.Unmarshal(data, &agentState); err != nil { + return nil, xerrors.Errorf("failed to unmarshal state: %w", err) + } + + c.InitialPromptSent = agentState.InitialPromptSent + c.InitialPrompt = agentState.InitialPrompt + c.messages = agentState.Messages + + // Store the first stable snapshot for filtering later + snapshots := c.snapshotBuffer.GetAll() + if len(snapshots) > 0 { + c.firstStableSnapshot = c.cfg.FormatMessage(strings.TrimSpace(snapshots[len(snapshots)-1].screen), "") + } + + return c.messages, nil +} + +func (c *PTYConversation) skipInitialSnapshot(screen string) string { + newScreen := strings.ReplaceAll(screen, c.firstStableSnapshot, "") + + // Before the first user message after loading state, return the last message from the loaded state. + // This prevents computing incorrect diffs from the restored screen, as the agent's message should + // remain stable until the user continues the conversation. + if c.userSentMessageAfterLoadState == false { + newScreen = "\n" + c.messages[len(c.messages)-1].Message + } + + return newScreen +} From ca3cdff8c019238b02f5f38644ae96d074b0f2dc Mon Sep 17 00:00:00 2001 From: 35C4n0r Date: Sat, 31 Jan 2026 21:56:44 +0530 Subject: [PATCH 03/27] feat: pid file writing and clearing and improved error handling for load state --- lib/httpapi/server.go | 57 +++++++++++++++++++++++---- lib/screentracker/pty_conversation.go | 9 ++++- 2 files changed, 57 insertions(+), 9 deletions(-) diff --git a/lib/httpapi/server.go b/lib/httpapi/server.go index 965fa28f..e895c4ec 100644 --- a/lib/httpapi/server.go +++ b/lib/httpapi/server.go @@ -350,16 +350,11 @@ func (s *Server) StartSnapshotLoop(ctx context.Context) { for { currentStatus := s.conversation.Status() - // Send initial prompt when agent becomes stable for the first time + // Send initial prompt & load state when agent becomes stable for the first time if convertStatus(currentStatus) == AgentStatusStable { if !s.stateLoadComplete && s.statePersistenceCfg.LoadState { - _, err := s.conversation.LoadState(s.statePersistenceCfg.StateFile) - if err != nil { - s.logger.Warn("Failed to load state file", "path", s.statePersistenceCfg.StateFile, "err", err) - } else { - s.logger.Info("Successfully loaded state", "path", s.statePersistenceCfg.StateFile) - } + _, _ = s.conversation.LoadState(s.statePersistenceCfg.StateFile) s.stateLoadComplete = true } if !s.conversation.InitialPromptSent { @@ -612,6 +607,11 @@ func (s *Server) Start() error { Handler: s.router, } + // Write PID file if configured + if err := s.writePIDFile(); err != nil { + return xerrors.Errorf("failed to write PID file: %w", err) + } + return s.srv.ListenAndServe() } @@ -626,6 +626,9 @@ func (s *Server) Stop(ctx context.Context) error { } } + // Clean up PID file + s.cleanupPIDFile() + // Clean up temporary directory s.cleanupTempDir() @@ -644,6 +647,43 @@ func (s *Server) cleanupTempDir() { } } +// writePIDFile writes the current process ID to the configured PID file +func (s *Server) writePIDFile() error { + if s.statePersistenceCfg.PidFile == "" { + return nil + } + + pid := os.Getpid() + pidContent := fmt.Sprintf("%d\n", pid) + + // Create directory if it doesn't exist + dir := filepath.Dir(s.statePersistenceCfg.PidFile) + if err := os.MkdirAll(dir, 0o755); err != nil { + return xerrors.Errorf("failed to create PID file directory: %w", err) + } + + // Write PID file + if err := os.WriteFile(s.statePersistenceCfg.PidFile, []byte(pidContent), 0o644); err != nil { + return xerrors.Errorf("failed to write PID file: %w", err) + } + + s.logger.Info("Wrote PID file", "pidFile", s.statePersistenceCfg.PidFile, "pid", pid) + return nil +} + +// cleanupPIDFile removes the PID file if it exists +func (s *Server) cleanupPIDFile() { + if s.statePersistenceCfg.PidFile == "" { + return + } + + if err := os.Remove(s.statePersistenceCfg.PidFile); err != nil && !os.IsNotExist(err) { + s.logger.Error("Failed to remove PID file", "pidFile", s.statePersistenceCfg.PidFile, "error", err) + } else if err == nil { + s.logger.Info("Removed PID file", "pidFile", s.statePersistenceCfg.PidFile) + } +} + // HandleSignals sets up signal handlers for: // - SIGTERM, SIGINT, SIGHUP: save conversation state, then close the process // - SIGUSR1: save conversation state without exiting @@ -664,6 +704,9 @@ func (s *Server) HandleSignals(ctx context.Context, process *termexec.Process) { } } + // Clean up PID file + s.cleanupPIDFile() + // Now close the process if err := process.Close(s.logger, 5*time.Second); err != nil { s.logger.Error("Error closing process", "signal", sig, "error", err) diff --git a/lib/screentracker/pty_conversation.go b/lib/screentracker/pty_conversation.go index 2f8e804f..abfc886e 100644 --- a/lib/screentracker/pty_conversation.go +++ b/lib/screentracker/pty_conversation.go @@ -442,22 +442,26 @@ func (c *PTYConversation) LoadState(stateFile string) ([]ConversationMessage, er // Check if file exists if _, err := os.Stat(stateFile); os.IsNotExist(err) { + c.cfg.Logger.Info("No previous state to load (file does not exist)", "path", stateFile) return nil, nil } // Read state file data, err := os.ReadFile(stateFile) if err != nil { + c.cfg.Logger.Warn("Failed to load state file", "path", stateFile, "err", err) return nil, xerrors.Errorf("failed to read state file: %w", err) } if len(data) == 0 { - return nil, xerrors.Errorf("failed to read state file: empty state file") + c.cfg.Logger.Info("No previous state to load (file is empty)", "path", stateFile) + return nil, nil } var agentState AgentState if err := json.Unmarshal(data, &agentState); err != nil { - return nil, xerrors.Errorf("failed to unmarshal state: %w", err) + c.cfg.Logger.Warn("Failed to load state file (corrupted or invalid JSON)", "path", stateFile, "err", err) + return nil, xerrors.Errorf("failed to unmarshal state (corrupted or invalid JSON): %w", err) } c.InitialPromptSent = agentState.InitialPromptSent @@ -470,6 +474,7 @@ func (c *PTYConversation) LoadState(stateFile string) ([]ConversationMessage, er c.firstStableSnapshot = c.cfg.FormatMessage(strings.TrimSpace(snapshots[len(snapshots)-1].screen), "") } + c.cfg.Logger.Info("Successfully loaded state", "path", stateFile, "messages", len(c.messages)) return c.messages, nil } From 1c224e9235a7187f74d1d3a09a36f46423cc53cc Mon Sep 17 00:00:00 2001 From: 35C4n0r Date: Sat, 31 Jan 2026 22:19:39 +0530 Subject: [PATCH 04/27] refactor: remove redundant save logic --- lib/httpapi/server.go | 12 ------------ lib/screentracker/pty_conversation.go | 1 - 2 files changed, 13 deletions(-) diff --git a/lib/httpapi/server.go b/lib/httpapi/server.go index e895c4ec..68be4b4a 100644 --- a/lib/httpapi/server.go +++ b/lib/httpapi/server.go @@ -617,18 +617,6 @@ func (s *Server) Start() error { // Stop gracefully stops the HTTP server func (s *Server) Stop(ctx context.Context) error { - // Save conversation state if configured - if s.statePersistenceCfg.SaveState && s.statePersistenceCfg.StateFile != "" { - if err := s.conversation.SaveState(s.conversation.Messages(), s.statePersistenceCfg.StateFile); err != nil { - s.logger.Error("Failed to save conversation state", "error", err) - } else { - s.logger.Info("Saved conversation state", "stateFile", s.statePersistenceCfg.StateFile) - } - } - - // Clean up PID file - s.cleanupPIDFile() - // Clean up temporary directory s.cleanupTempDir() diff --git a/lib/screentracker/pty_conversation.go b/lib/screentracker/pty_conversation.go index abfc886e..9c603b04 100644 --- a/lib/screentracker/pty_conversation.go +++ b/lib/screentracker/pty_conversation.go @@ -257,7 +257,6 @@ func (c *PTYConversation) Send(messageParts ...MessagePart) error { Role: ConversationRoleUser, Time: now, }) - c.dirty = true c.userSentMessageAfterLoadState = true return nil From 30f82d7c4d7080a3d6b657d5bdc8b3f3e4fc01e7 Mon Sep 17 00:00:00 2001 From: 35C4n0r Date: Mon, 2 Feb 2026 16:18:40 +0530 Subject: [PATCH 05/27] feat: improve logic for first run with empty state file --- lib/screentracker/pty_conversation.go | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/lib/screentracker/pty_conversation.go b/lib/screentracker/pty_conversation.go index 9c603b04..a8f73c0a 100644 --- a/lib/screentracker/pty_conversation.go +++ b/lib/screentracker/pty_conversation.go @@ -106,6 +106,8 @@ type PTYConversation struct { firstStableSnapshot string // userSentMessageAfterLoadState tracks if the user has sent their first message after we load the state userSentMessageAfterLoadState bool + // loadStateSuccessful indicates whether conversation state was successfully restored from file. + loadStateSuccessful bool } var _ Conversation = &PTYConversation{} @@ -123,9 +125,13 @@ func NewPTY(ctx context.Context, cfg PTYConversationConfig, initialPrompt string Time: cfg.GetTime(), }, }, - InitialPrompt: initialPrompt, - InitialPromptSent: len(initialPrompt) == 0, - toolCallMessageSet: make(map[string]bool), + InitialPrompt: initialPrompt, + InitialPromptSent: len(initialPrompt) == 0, + toolCallMessageSet: make(map[string]bool), + dirty: false, + firstStableSnapshot: "", + userSentMessageAfterLoadState: false, + loadStateSuccessful: false, } return c } @@ -170,7 +176,9 @@ func (c *PTYConversation) updateLastAgentMessageLocked(screen string, timestamp if c.cfg.FormatMessage != nil { agentMessage = c.cfg.FormatMessage(agentMessage, lastUserMessage.Message) } - agentMessage = c.skipInitialSnapshot(agentMessage) + if c.loadStateSuccessful { + agentMessage = c.adjustScreenAfterStateLoad(agentMessage) + } if c.cfg.FormatToolCall != nil { agentMessage, toolCalls = c.cfg.FormatToolCall(agentMessage) } @@ -473,12 +481,13 @@ func (c *PTYConversation) LoadState(stateFile string) ([]ConversationMessage, er c.firstStableSnapshot = c.cfg.FormatMessage(strings.TrimSpace(snapshots[len(snapshots)-1].screen), "") } + c.loadStateSuccessful = true c.cfg.Logger.Info("Successfully loaded state", "path", stateFile, "messages", len(c.messages)) return c.messages, nil } -func (c *PTYConversation) skipInitialSnapshot(screen string) string { - newScreen := strings.ReplaceAll(screen, c.firstStableSnapshot, "") +func (c *PTYConversation) adjustScreenAfterStateLoad(screen string) string { + newScreen := strings.Replace(screen, c.firstStableSnapshot, "", 1) // Before the first user message after loading state, return the last message from the loaded state. // This prevents computing incorrect diffs from the restored screen, as the agent's message should From 12bed1c23355eb6a53b57cb412e9ed7458dcf29b Mon Sep 17 00:00:00 2001 From: 35C4n0r Date: Tue, 3 Feb 2026 16:27:53 +0530 Subject: [PATCH 06/27] feat: implement platform-specific signal handling --- lib/httpapi/server.go | 72 ++++++++------------------- lib/httpapi/server_signals_unix.go | 42 ++++++++++++++++ lib/httpapi/server_signals_windows.go | 26 ++++++++++ 3 files changed, 89 insertions(+), 51 deletions(-) create mode 100644 lib/httpapi/server_signals_unix.go create mode 100644 lib/httpapi/server_signals_windows.go diff --git a/lib/httpapi/server.go b/lib/httpapi/server.go index 68be4b4a..cb761189 100644 --- a/lib/httpapi/server.go +++ b/lib/httpapi/server.go @@ -11,13 +11,11 @@ import ( "net/http" "net/url" "os" - "os/signal" "path/filepath" "slices" "sort" "strings" "sync" - "syscall" "time" "unicode" @@ -672,59 +670,31 @@ func (s *Server) cleanupPIDFile() { } } -// HandleSignals sets up signal handlers for: -// - SIGTERM, SIGINT, SIGHUP: save conversation state, then close the process -// - SIGUSR1: save conversation state without exiting -func (s *Server) HandleSignals(ctx context.Context, process *termexec.Process) { - // Handle shutdown signals (SIGTERM, SIGINT, SIGHUP) - shutdownCh := make(chan os.Signal, 1) - signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP) - go func() { - sig := <-shutdownCh - s.logger.Info("Received shutdown signal, saving state before closing process", "signal", sig) - - // Save conversation state if configured (synchronously before closing process) - if s.statePersistenceCfg.SaveState && s.statePersistenceCfg.StateFile != "" { - if err := s.conversation.SaveState(s.conversation.Messages(), s.statePersistenceCfg.StateFile); err != nil { - s.logger.Error("Failed to save conversation state on signal", "signal", sig, "error", err) - } else { - s.logger.Info("Saved conversation state on signal", "signal", sig, "stateFile", s.statePersistenceCfg.StateFile) - } - } +// saveAndCleanup saves the conversation state and cleans up before shutdown +func (s *Server) saveAndCleanup(sig os.Signal, process *termexec.Process) { + // Save conversation state if configured (synchronously before closing process) + s.saveStateIfConfigured(sig.String()) - // Clean up PID file - s.cleanupPIDFile() + // Clean up PID file + s.cleanupPIDFile() - // Now close the process - if err := process.Close(s.logger, 5*time.Second); err != nil { - s.logger.Error("Error closing process", "signal", sig, "error", err) - } - }() + // Now close the process + if err := process.Close(s.logger, 5*time.Second); err != nil { + s.logger.Error("Error closing process", "signal", sig, "error", err) + } +} - // Handle SIGUSR1 for save without exit - saveOnlyCh := make(chan os.Signal, 1) - signal.Notify(saveOnlyCh, syscall.SIGUSR1) - go func() { - for { - select { - case <-saveOnlyCh: - s.logger.Info("Received SIGUSR1, saving state without exiting") - - // Save conversation state if configured - if s.statePersistenceCfg.SaveState && s.statePersistenceCfg.StateFile != "" { - if err := s.conversation.SaveState(s.conversation.Messages(), s.statePersistenceCfg.StateFile); err != nil { - s.logger.Error("Failed to save conversation state on SIGUSR1", "error", err) - } else { - s.logger.Info("Saved conversation state on SIGUSR1", "stateFile", s.statePersistenceCfg.StateFile) - } - } else { - s.logger.Warn("SIGUSR1 received but state saving is not configured") - } - case <-ctx.Done(): - return - } +// saveStateIfConfigured saves the conversation state if configured +func (s *Server) saveStateIfConfigured(source string) { + if s.statePersistenceCfg.SaveState && s.statePersistenceCfg.StateFile != "" { + if err := s.conversation.SaveState(s.conversation.Messages(), s.statePersistenceCfg.StateFile); err != nil { + s.logger.Error("Failed to save conversation state", "source", source, "error", err) + } else { + s.logger.Info("Saved conversation state", "source", source, "stateFile", s.statePersistenceCfg.StateFile) } - }() + } else { + s.logger.Warn("Save requested but state saving is not configured", "source", source) + } } // registerStaticFileRoutes sets up routes for serving static files diff --git a/lib/httpapi/server_signals_unix.go b/lib/httpapi/server_signals_unix.go new file mode 100644 index 00000000..bfc6eaa9 --- /dev/null +++ b/lib/httpapi/server_signals_unix.go @@ -0,0 +1,42 @@ +//go:build unix + +package httpapi + +import ( + "context" + "os" + "os/signal" + "syscall" + + "github.com/coder/agentapi/lib/termexec" +) + +// HandleSignals sets up signal handlers for: +// - SIGTERM, SIGINT, SIGHUP: save conversation state, then close the process +// - SIGUSR1: save conversation state without exiting +func (s *Server) HandleSignals(ctx context.Context, process *termexec.Process) { + // Handle shutdown signals (SIGTERM, SIGINT, SIGHUP) + shutdownCh := make(chan os.Signal, 1) + signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP) + go func() { + sig := <-shutdownCh + s.logger.Info("Received shutdown signal, saving state before closing process", "signal", sig) + + s.saveAndCleanup(sig, process) + }() + + // Handle SIGUSR1 for save without exit + saveOnlyCh := make(chan os.Signal, 1) + signal.Notify(saveOnlyCh, syscall.SIGUSR1) + go func() { + for { + select { + case <-saveOnlyCh: + s.logger.Info("Received SIGUSR1, saving state without exiting") + s.saveStateIfConfigured("SIGUSR1") + case <-ctx.Done(): + return + } + } + }() +} diff --git a/lib/httpapi/server_signals_windows.go b/lib/httpapi/server_signals_windows.go new file mode 100644 index 00000000..ea07c6ad --- /dev/null +++ b/lib/httpapi/server_signals_windows.go @@ -0,0 +1,26 @@ +//go:build windows + +package httpapi + +import ( + "context" + "os" + "os/signal" + "syscall" + + "github.com/coder/agentapi/lib/termexec" +) + +// HandleSignals sets up signal handlers for Windows. +// Only handles SIGTERM and SIGINT (SIGHUP and SIGUSR1 don't exist on Windows). +func (s *Server) HandleSignals(ctx context.Context, process *termexec.Process) { + // Handle shutdown signals (SIGTERM, SIGINT only on Windows) + shutdownCh := make(chan os.Signal, 1) + signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM) + go func() { + sig := <-shutdownCh + s.logger.Info("Received shutdown signal, saving state before closing process", "signal", sig) + + s.saveAndCleanup(sig, process) + }() +} From e366e8b8f512a5ffde415063aa04261f4b4f3dd0 Mon Sep 17 00:00:00 2001 From: 35C4n0r Date: Thu, 5 Feb 2026 18:11:03 +0530 Subject: [PATCH 07/27] feat: refactor cfg -> Config and move pid ops to server --- cmd/server/server.go | 42 +++++++++++++- lib/httpapi/server.go | 126 ++++++++++++++---------------------------- 2 files changed, 80 insertions(+), 88 deletions(-) diff --git a/cmd/server/server.go b/cmd/server/server.go index 5a125af4..561877e8 100644 --- a/cmd/server/server.go +++ b/cmd/server/server.go @@ -8,6 +8,7 @@ import ( "log/slog" "net/http" "os" + "path/filepath" "sort" "strings" @@ -123,6 +124,15 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er pidFile := viper.GetString(PidFile) + // Write PID file if configured + if pidFile != "" { + if err := writePIDFile(pidFile, logger); err != nil { + return xerrors.Errorf("failed to write PID file: %w", err) + } + // Ensure PID file is cleaned up on exit + defer cleanupPIDFile(pidFile, logger) + } + printOpenAPI := viper.GetBool(FlagPrintOpenAPI) var process *termexec.Process if printOpenAPI { @@ -148,11 +158,10 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er AllowedHosts: viper.GetStringSlice(FlagAllowedHosts), AllowedOrigins: viper.GetStringSlice(FlagAllowedOrigins), InitialPrompt: initialPrompt, - StatePersistenceCfg: httpapi.StatePersistenceCfg{ + StatePersistenceConfig: httpapi.StatePersistenceConfig{ StateFile: stateFile, LoadState: loadState, SaveState: saveState, - PidFile: pidFile, }, }) @@ -200,6 +209,35 @@ var agentNames = (func() []string { return names })() +// writePIDFile writes the current process ID to the specified file +func writePIDFile(pidFile string, logger *slog.Logger) error { + pid := os.Getpid() + pidContent := fmt.Sprintf("%d\n", pid) + + // Create directory if it doesn't exist + dir := filepath.Dir(pidFile) + if err := os.MkdirAll(dir, 0o755); err != nil { + return xerrors.Errorf("failed to create PID file directory: %w", err) + } + + // Write PID file + if err := os.WriteFile(pidFile, []byte(pidContent), 0o644); err != nil { + return xerrors.Errorf("failed to write PID file: %w", err) + } + + logger.Info("Wrote PID file", "pidFile", pidFile, "pid", pid) + return nil +} + +// cleanupPIDFile removes the PID file if it exists +func cleanupPIDFile(pidFile string, logger *slog.Logger) { + if err := os.Remove(pidFile); err != nil && !os.IsNotExist(err) { + logger.Error("Failed to remove PID file", "pidFile", pidFile, "error", err) + } else if err == nil { + logger.Info("Removed PID file", "pidFile", pidFile) + } +} + type flagSpec struct { name string shorthand string diff --git a/lib/httpapi/server.go b/lib/httpapi/server.go index cb761189..d234bad0 100644 --- a/lib/httpapi/server.go +++ b/lib/httpapi/server.go @@ -34,20 +34,20 @@ import ( // Server represents the HTTP server type Server struct { - router chi.Router - api huma.API - port int - srv *http.Server - mu sync.RWMutex - logger *slog.Logger - conversation *st.PTYConversation - agentio *termexec.Process - agentType mf.AgentType - emitter *EventEmitter - chatBasePath string - tempDir string - statePersistenceCfg StatePersistenceCfg - stateLoadComplete bool + router chi.Router + api huma.API + port int + srv *http.Server + mu sync.RWMutex + logger *slog.Logger + conversation *st.PTYConversation + agentio *termexec.Process + agentType mf.AgentType + emitter *EventEmitter + chatBasePath string + tempDir string + statePersistenceConfig StatePersistenceConfig + stateLoadComplete bool } func (s *Server) NormalizeSchema(schema any) any { @@ -96,22 +96,21 @@ func (s *Server) GetOpenAPI() string { // because the action of taking a snapshot takes time too. const snapshotInterval = 25 * time.Millisecond -type StatePersistenceCfg struct { +type StatePersistenceConfig struct { StateFile string LoadState bool SaveState bool - PidFile string } type ServerConfig struct { - AgentType mf.AgentType - Process *termexec.Process - Port int - ChatBasePath string - AllowedHosts []string - AllowedOrigins []string - InitialPrompt string - StatePersistenceCfg StatePersistenceCfg + AgentType mf.AgentType + Process *termexec.Process + Port int + ChatBasePath string + AllowedHosts []string + AllowedOrigins []string + InitialPrompt string + StatePersistenceConfig StatePersistenceConfig } // Validate allowed hosts don't contain whitespace, commas, schemes, or ports. @@ -270,18 +269,18 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { logger.Info("Created temporary directory for uploads", "tempDir", tempDir) s := &Server{ - router: router, - api: api, - port: config.Port, - conversation: conversation, - logger: logger, - agentio: config.Process, - agentType: config.AgentType, - emitter: emitter, - chatBasePath: strings.TrimSuffix(config.ChatBasePath, "/"), - tempDir: tempDir, - statePersistenceCfg: config.StatePersistenceCfg, - stateLoadComplete: false, + router: router, + api: api, + port: config.Port, + conversation: conversation, + logger: logger, + agentio: config.Process, + agentType: config.AgentType, + emitter: emitter, + chatBasePath: strings.TrimSuffix(config.ChatBasePath, "/"), + tempDir: tempDir, + statePersistenceConfig: config.StatePersistenceConfig, + stateLoadComplete: false, } // Register API routes @@ -351,8 +350,8 @@ func (s *Server) StartSnapshotLoop(ctx context.Context) { // Send initial prompt & load state when agent becomes stable for the first time if convertStatus(currentStatus) == AgentStatusStable { - if !s.stateLoadComplete && s.statePersistenceCfg.LoadState { - _, _ = s.conversation.LoadState(s.statePersistenceCfg.StateFile) + if !s.stateLoadComplete && s.statePersistenceConfig.LoadState { + _, _ = s.conversation.LoadState(s.statePersistenceConfig.StateFile) s.stateLoadComplete = true } if !s.conversation.InitialPromptSent { @@ -605,11 +604,6 @@ func (s *Server) Start() error { Handler: s.router, } - // Write PID file if configured - if err := s.writePIDFile(); err != nil { - return xerrors.Errorf("failed to write PID file: %w", err) - } - return s.srv.ListenAndServe() } @@ -633,51 +627,11 @@ func (s *Server) cleanupTempDir() { } } -// writePIDFile writes the current process ID to the configured PID file -func (s *Server) writePIDFile() error { - if s.statePersistenceCfg.PidFile == "" { - return nil - } - - pid := os.Getpid() - pidContent := fmt.Sprintf("%d\n", pid) - - // Create directory if it doesn't exist - dir := filepath.Dir(s.statePersistenceCfg.PidFile) - if err := os.MkdirAll(dir, 0o755); err != nil { - return xerrors.Errorf("failed to create PID file directory: %w", err) - } - - // Write PID file - if err := os.WriteFile(s.statePersistenceCfg.PidFile, []byte(pidContent), 0o644); err != nil { - return xerrors.Errorf("failed to write PID file: %w", err) - } - - s.logger.Info("Wrote PID file", "pidFile", s.statePersistenceCfg.PidFile, "pid", pid) - return nil -} - -// cleanupPIDFile removes the PID file if it exists -func (s *Server) cleanupPIDFile() { - if s.statePersistenceCfg.PidFile == "" { - return - } - - if err := os.Remove(s.statePersistenceCfg.PidFile); err != nil && !os.IsNotExist(err) { - s.logger.Error("Failed to remove PID file", "pidFile", s.statePersistenceCfg.PidFile, "error", err) - } else if err == nil { - s.logger.Info("Removed PID file", "pidFile", s.statePersistenceCfg.PidFile) - } -} - // saveAndCleanup saves the conversation state and cleans up before shutdown func (s *Server) saveAndCleanup(sig os.Signal, process *termexec.Process) { // Save conversation state if configured (synchronously before closing process) s.saveStateIfConfigured(sig.String()) - // Clean up PID file - s.cleanupPIDFile() - // Now close the process if err := process.Close(s.logger, 5*time.Second); err != nil { s.logger.Error("Error closing process", "signal", sig, "error", err) @@ -686,11 +640,11 @@ func (s *Server) saveAndCleanup(sig os.Signal, process *termexec.Process) { // saveStateIfConfigured saves the conversation state if configured func (s *Server) saveStateIfConfigured(source string) { - if s.statePersistenceCfg.SaveState && s.statePersistenceCfg.StateFile != "" { - if err := s.conversation.SaveState(s.conversation.Messages(), s.statePersistenceCfg.StateFile); err != nil { + if s.statePersistenceConfig.SaveState && s.statePersistenceConfig.StateFile != "" { + if err := s.conversation.SaveState(s.conversation.Messages(), s.statePersistenceConfig.StateFile); err != nil { s.logger.Error("Failed to save conversation state", "source", source, "error", err) } else { - s.logger.Info("Saved conversation state", "source", source, "stateFile", s.statePersistenceCfg.StateFile) + s.logger.Info("Saved conversation state", "source", source, "stateFile", s.statePersistenceConfig.StateFile) } } else { s.logger.Warn("Save requested but state saving is not configured", "source", source) From 26fdf818438c6837debb644c153829a64cb76f81 Mon Sep 17 00:00:00 2001 From: 35C4n0r Date: Thu, 5 Feb 2026 18:15:25 +0530 Subject: [PATCH 08/27] feat: unregister the signal handlers on teardown --- lib/httpapi/server_signals_unix.go | 2 ++ lib/httpapi/server_signals_windows.go | 1 + 2 files changed, 3 insertions(+) diff --git a/lib/httpapi/server_signals_unix.go b/lib/httpapi/server_signals_unix.go index bfc6eaa9..837db86c 100644 --- a/lib/httpapi/server_signals_unix.go +++ b/lib/httpapi/server_signals_unix.go @@ -19,6 +19,7 @@ func (s *Server) HandleSignals(ctx context.Context, process *termexec.Process) { shutdownCh := make(chan os.Signal, 1) signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP) go func() { + defer signal.Stop(shutdownCh) sig := <-shutdownCh s.logger.Info("Received shutdown signal, saving state before closing process", "signal", sig) @@ -29,6 +30,7 @@ func (s *Server) HandleSignals(ctx context.Context, process *termexec.Process) { saveOnlyCh := make(chan os.Signal, 1) signal.Notify(saveOnlyCh, syscall.SIGUSR1) go func() { + defer signal.Stop(saveOnlyCh) for { select { case <-saveOnlyCh: diff --git a/lib/httpapi/server_signals_windows.go b/lib/httpapi/server_signals_windows.go index ea07c6ad..503e56a9 100644 --- a/lib/httpapi/server_signals_windows.go +++ b/lib/httpapi/server_signals_windows.go @@ -18,6 +18,7 @@ func (s *Server) HandleSignals(ctx context.Context, process *termexec.Process) { shutdownCh := make(chan os.Signal, 1) signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM) go func() { + defer signal.Stop(shutdownCh) sig := <-shutdownCh s.logger.Info("Received shutdown signal, saving state before closing process", "signal", sig) From 5795db7235a1436afc8a4ecd343ad072303dab23 Mon Sep 17 00:00:00 2001 From: 35C4n0r Date: Tue, 17 Feb 2026 14:27:51 +0530 Subject: [PATCH 09/27] feat: resolve conflicts and improve shutdown sequence --- cmd/server/server.go | 26 ++++- cmd/server/signals.go | 36 ++++++ .../server/signals_unix.go | 22 ++-- .../server/signals_windows.go | 12 +- lib/httpapi/server.go | 104 +++++------------- lib/screentracker/conversation.go | 7 ++ lib/screentracker/pty_conversation.go | 97 +++++++++++----- 7 files changed, 179 insertions(+), 125 deletions(-) create mode 100644 cmd/server/signals.go rename lib/httpapi/server_signals_unix.go => cmd/server/signals_unix.go (52%) rename lib/httpapi/server_signals_windows.go => cmd/server/signals_windows.go (60%) diff --git a/cmd/server/server.go b/cmd/server/server.go index b42c232a..46b21e26 100644 --- a/cmd/server/server.go +++ b/cmd/server/server.go @@ -11,7 +11,9 @@ import ( "path/filepath" "sort" "strings" + "time" + "github.com/coder/agentapi/lib/screentracker" "github.com/mattn/go-isatty" "github.com/spf13/cobra" "github.com/spf13/viper" @@ -106,8 +108,10 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er // Get the variables related to state management stateFile := viper.GetString(StateFile) - loadState := true - saveState := true + loadState := false + saveState := false + + // Validate state file configuration if stateFile != "" { if !viper.IsSet(LoadState) { loadState = true @@ -120,6 +124,14 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er } else { saveState = viper.GetBool(SaveState) } + } else { + // No state file provided - ensure load/save flags are not explicitly set to true + if viper.IsSet(LoadState) && viper.GetBool(LoadState) { + return xerrors.Errorf("--load-state requires --state-file to be set") + } + if viper.IsSet(SaveState) && viper.GetBool(SaveState) { + return xerrors.Errorf("--save-state requires --state-file to be set") + } } pidFile := viper.GetString(PidFile) @@ -158,7 +170,7 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er AllowedHosts: viper.GetStringSlice(FlagAllowedHosts), AllowedOrigins: viper.GetStringSlice(FlagAllowedOrigins), InitialPrompt: initialPrompt, - StatePersistenceConfig: httpapi.StatePersistenceConfig{ + StatePersistenceConfig: screentracker.StatePersistenceConfig{ StateFile: stateFile, LoadState: loadState, SaveState: saveState, @@ -172,7 +184,7 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er fmt.Println(srv.GetOpenAPI()) return nil } - srv.HandleSignals(ctx, process) + handleSignals(ctx, logger, srv, process) logger.Info("Starting server on port", "port", port) processExitCh := make(chan error, 1) go func() { @@ -184,8 +196,10 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er processExitCh <- xerrors.Errorf("failed to wait for process: %w", err) } } - if err := srv.Stop(ctx); err != nil { - logger.Error("Failed to stop server", "error", err) + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := srv.Stop(shutdownCtx); err != nil { + logger.Error("Failed to stop server after process exit", "error", err) } }() if err := srv.Start(); err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, http.ErrServerClosed) { diff --git a/cmd/server/signals.go b/cmd/server/signals.go new file mode 100644 index 00000000..e66554e7 --- /dev/null +++ b/cmd/server/signals.go @@ -0,0 +1,36 @@ +package server + +import ( + "context" + "log/slog" + "os" + "time" + + "github.com/coder/agentapi/lib/httpapi" + "github.com/coder/agentapi/lib/termexec" +) + +// performGracefulShutdown handles the common shutdown logic for all platforms. +// It saves state, stops the HTTP server, closes the process, and exits. +func performGracefulShutdown(sig os.Signal, logger *slog.Logger, srv *httpapi.Server, process *termexec.Process) { + logger.Info("Received shutdown signal, initiating graceful shutdown", "signal", sig) + + // Save state + if err := srv.SaveState(sig.String()); err != nil { + logger.Error("Failed to save state during shutdown", "signal", sig, "error", err) + } + + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := srv.Stop(shutdownCtx); err != nil { + logger.Error("Failed to stop HTTP server", "signal", sig, "error", err) + } + + // Close the process + if err := process.Close(logger, 5*time.Second); err != nil { + logger.Error("Failed to close process cleanly", "signal", sig, "error", err) + } + + // Exit cleanly + os.Exit(0) +} diff --git a/lib/httpapi/server_signals_unix.go b/cmd/server/signals_unix.go similarity index 52% rename from lib/httpapi/server_signals_unix.go rename to cmd/server/signals_unix.go index 837db86c..fe6b4693 100644 --- a/lib/httpapi/server_signals_unix.go +++ b/cmd/server/signals_unix.go @@ -1,29 +1,29 @@ //go:build unix -package httpapi +package server import ( "context" + "log/slog" "os" "os/signal" "syscall" + "github.com/coder/agentapi/lib/httpapi" "github.com/coder/agentapi/lib/termexec" ) -// HandleSignals sets up signal handlers for: -// - SIGTERM, SIGINT, SIGHUP: save conversation state, then close the process +// handleSignals sets up signal handlers for: +// - SIGTERM, SIGINT, SIGHUP: save conversation state, stop server, then close the process // - SIGUSR1: save conversation state without exiting -func (s *Server) HandleSignals(ctx context.Context, process *termexec.Process) { +func handleSignals(ctx context.Context, logger *slog.Logger, srv *httpapi.Server, process *termexec.Process) { // Handle shutdown signals (SIGTERM, SIGINT, SIGHUP) shutdownCh := make(chan os.Signal, 1) - signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP) + signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGINT) go func() { defer signal.Stop(shutdownCh) sig := <-shutdownCh - s.logger.Info("Received shutdown signal, saving state before closing process", "signal", sig) - - s.saveAndCleanup(sig, process) + performGracefulShutdown(sig, logger, srv, process) }() // Handle SIGUSR1 for save without exit @@ -34,8 +34,10 @@ func (s *Server) HandleSignals(ctx context.Context, process *termexec.Process) { for { select { case <-saveOnlyCh: - s.logger.Info("Received SIGUSR1, saving state without exiting") - s.saveStateIfConfigured("SIGUSR1") + logger.Info("Received SIGUSR1, saving state without exiting") + if err := srv.SaveState("SIGUSR1"); err != nil { + logger.Error("Failed to save state on SIGUSR1", "error", err) + } case <-ctx.Done(): return } diff --git a/lib/httpapi/server_signals_windows.go b/cmd/server/signals_windows.go similarity index 60% rename from lib/httpapi/server_signals_windows.go rename to cmd/server/signals_windows.go index 503e56a9..52d90616 100644 --- a/lib/httpapi/server_signals_windows.go +++ b/cmd/server/signals_windows.go @@ -1,27 +1,27 @@ //go:build windows -package httpapi +package server import ( "context" + "log/slog" "os" "os/signal" "syscall" + "github.com/coder/agentapi/lib/httpapi" "github.com/coder/agentapi/lib/termexec" ) -// HandleSignals sets up signal handlers for Windows. +// handleSignals sets up signal handlers for Windows. // Only handles SIGTERM and SIGINT (SIGHUP and SIGUSR1 don't exist on Windows). -func (s *Server) HandleSignals(ctx context.Context, process *termexec.Process) { +func handleSignals(ctx context.Context, logger *slog.Logger, srv *httpapi.Server, process *termexec.Process) { // Handle shutdown signals (SIGTERM, SIGINT only on Windows) shutdownCh := make(chan os.Signal, 1) signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM) go func() { defer signal.Stop(shutdownCh) sig := <-shutdownCh - s.logger.Info("Received shutdown signal, saving state before closing process", "signal", sig) - - s.saveAndCleanup(sig, process) + performGracefulShutdown(sig, logger, srv, process) }() } diff --git a/lib/httpapi/server.go b/lib/httpapi/server.go index 4deb4e38..32cd64cf 100644 --- a/lib/httpapi/server.go +++ b/lib/httpapi/server.go @@ -40,6 +40,7 @@ type Server struct { port int srv *http.Server mu sync.RWMutex + stopOnce sync.Once logger *slog.Logger conversation st.Conversation agentio *termexec.Process @@ -48,8 +49,6 @@ type Server struct { chatBasePath string tempDir string clock quartz.Clock - statePersistenceConfig StatePersistenceConfig - stateLoadComplete bool } func (s *Server) NormalizeSchema(schema any) any { @@ -98,22 +97,16 @@ func (s *Server) GetOpenAPI() string { // because the action of taking a snapshot takes time too. const snapshotInterval = 25 * time.Millisecond -type StatePersistenceConfig struct { - StateFile string - LoadState bool - SaveState bool -} - type ServerConfig struct { - AgentType mf.AgentType - Process *termexec.Process - Port int - ChatBasePath string - AllowedHosts []string - AllowedOrigins []string - InitialPrompt string - Clock quartz.Clock - StatePersistenceConfig StatePersistenceConfig + AgentType mf.AgentType + Process *termexec.Process + Port int + ChatBasePath string + AllowedHosts []string + AllowedOrigins []string + InitialPrompt string + Clock quartz.Clock + StatePersistenceConfig st.StatePersistenceConfig } // Validate allowed hosts don't contain whitespace, commas, schemes, or ports. @@ -279,7 +272,8 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { emitter.UpdateMessagesAndEmitChanges(messages) emitter.UpdateScreenAndEmitChanges(screen) }, - Logger: logger, + Logger: logger, + StatePersistenceConfig: config.StatePersistenceConfig, }) // Create temporary directory for uploads @@ -301,8 +295,6 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { chatBasePath: strings.TrimSuffix(config.ChatBasePath, "/"), tempDir: tempDir, clock: config.Clock, - statePersistenceConfig: config.StatePersistenceConfig, - stateLoadComplete: false, } // Register API routes @@ -373,32 +365,6 @@ func sseMiddleware(ctx huma.Context, next func(huma.Context)) { next(ctx) } -func (s *Server) StartSnapshotLoop(ctx context.Context) { - s.conversation.StartSnapshotLoop(ctx) - go func() { - for { - currentStatus := s.conversation.Status() - - // Send initial prompt when agent becomes stable for the first time - if !s.conversation.InitialPromptSent && convertStatus(currentStatus) == AgentStatusStable { - - if err := s.conversation.SendMessage(FormatMessage(s.agentType, s.conversation.InitialPrompt)...); err != nil { - s.logger.Error("Failed to send initial prompt", "error", err) - } else { - s.conversation.InitialPromptSent = true - s.conversation.ReadyForInitialPrompt = false - currentStatus = st.ConversationStatusChanging - s.logger.Info("Initial prompt sent successfully") - } - } - s.emitter.UpdateStatusAndEmitChanges(currentStatus, s.agentType) - s.emitter.UpdateMessagesAndEmitChanges(s.conversation.Messages()) - s.emitter.UpdateScreenAndEmitChanges(s.conversation.Screen()) - time.Sleep(snapshotInterval) - } - }() -} - // registerRoutes sets up all API endpoints func (s *Server) registerRoutes() { // GET /status endpoint @@ -633,15 +599,19 @@ func (s *Server) Start() error { return s.srv.ListenAndServe() } -// Stop gracefully stops the HTTP server +// Stop gracefully stops the HTTP server. It is safe to call multiple times; +// only the first call will perform the shutdown, subsequent calls are no-ops. func (s *Server) Stop(ctx context.Context) error { - // Clean up temporary directory - s.cleanupTempDir() + var err error + s.stopOnce.Do(func() { + // Clean up temporary directory + s.cleanupTempDir() - if s.srv != nil { - return s.srv.Shutdown(ctx) - } - return nil + if s.srv != nil { + err = s.srv.Shutdown(ctx) + } + }) + return err } // cleanupTempDir removes the temporary directory and all its contents @@ -653,28 +623,14 @@ func (s *Server) cleanupTempDir() { } } -// saveAndCleanup saves the conversation state and cleans up before shutdown -func (s *Server) saveAndCleanup(sig os.Signal, process *termexec.Process) { - // Save conversation state if configured (synchronously before closing process) - s.saveStateIfConfigured(sig.String()) - - // Now close the process - if err := process.Close(s.logger, 5*time.Second); err != nil { - s.logger.Error("Error closing process", "signal", sig, "error", err) - } -} - -// saveStateIfConfigured saves the conversation state if configured -func (s *Server) saveStateIfConfigured(source string) { - if s.statePersistenceConfig.SaveState && s.statePersistenceConfig.StateFile != "" { - if err := s.conversation.SaveState(s.conversation.Messages(), s.statePersistenceConfig.StateFile); err != nil { - s.logger.Error("Failed to save conversation state", "source", source, "error", err) - } else { - s.logger.Info("Saved conversation state", "source", source, "stateFile", s.statePersistenceConfig.StateFile) - } - } else { - s.logger.Warn("Save requested but state saving is not configured", "source", source) +// SaveState saves the conversation state if configured. This can be called from signal handlers. +// The source parameter indicates what triggered the save (e.g., "SIGTERM", "SIGUSR1"). +func (s *Server) SaveState(source string) error { + if err := s.conversation.SaveState(); err != nil { + s.logger.Error("Failed to save conversation state", "source", source, "error", err) + return err } + return nil } // registerStaticFileRoutes sets up routes for serving static files diff --git a/lib/screentracker/conversation.go b/lib/screentracker/conversation.go index 9e6b856f..44e303f1 100644 --- a/lib/screentracker/conversation.go +++ b/lib/screentracker/conversation.go @@ -63,6 +63,7 @@ type Conversation interface { Start(context.Context) Status() ConversationStatus Text() string + SaveState() error } type ConversationMessage struct { @@ -71,3 +72,9 @@ type ConversationMessage struct { Role ConversationRole Time time.Time } + +type StatePersistenceConfig struct { + StateFile string + LoadState bool + SaveState bool +} diff --git a/lib/screentracker/pty_conversation.go b/lib/screentracker/pty_conversation.go index 91ed7cba..fd3dedcf 100644 --- a/lib/screentracker/pty_conversation.go +++ b/lib/screentracker/pty_conversation.go @@ -29,6 +29,13 @@ type MessagePartText struct { Hidden bool } +type AgentState struct { + Version int `json:"version"` + Messages []ConversationMessage `json:"messages"` + InitialPrompt string `json:"initial_prompt"` + //InitialPromptSent bool `json:"initial_prompt_sent"` +} + var _ MessagePart = &MessagePartText{} func (p MessagePartText) Do(writer AgentIO) error { @@ -72,8 +79,9 @@ type PTYConversationConfig struct { // InitialPrompt is the initial prompt to send to the agent once ready InitialPrompt []MessagePart // OnSnapshot is called after each snapshot with current status, messages, and screen content - OnSnapshot func(status ConversationStatus, messages []ConversationMessage, screen string) - Logger *slog.Logger + OnSnapshot func(status ConversationStatus, messages []ConversationMessage, screen string) + Logger *slog.Logger + StatePersistenceConfig StatePersistenceConfig } func (cfg PTYConversationConfig) getStableSnapshotsThreshold() int { @@ -142,9 +150,9 @@ func NewPTY(ctx context.Context, cfg PTYConversationConfig) *PTYConversation { Time: cfg.Clock.Now(), }, }, - outboundQueue: make(chan outboundMessage, 1), - stableSignal: make(chan struct{}, 1), - toolCallMessageSet: make(map[string]bool), + outboundQueue: make(chan outboundMessage, 1), + stableSignal: make(chan struct{}, 1), + toolCallMessageSet: make(map[string]bool), dirty: false, firstStableSnapshot: "", userSentMessageAfterLoadState: false, @@ -178,6 +186,12 @@ func (c *PTYConversation) Start(ctx context.Context) { if !c.initialPromptReady && c.cfg.ReadyForInitialPrompt(screen) { c.initialPromptReady = true } + + if !c.loadStateSuccessful && c.cfg.StatePersistenceConfig.LoadState { + _ = c.loadState() + c.loadStateSuccessful = true + } + if c.initialPromptReady && len(c.outboundQueue) > 0 && c.isScreenStableLocked() { select { case c.stableSignal <- struct{}{}: @@ -284,6 +298,8 @@ func (c *PTYConversation) updateLastAgentMessageLocked(screen string, timestamp c.messages[len(c.messages)-1] = conversationMessage } c.messages[len(c.messages)-1].Id = len(c.messages) - 1 + + c.dirty = true } // caller MUST hold c.lock @@ -297,10 +313,6 @@ func (c *PTYConversation) snapshotLocked(screen string) { } func (c *PTYConversation) Send(messageParts ...MessagePart) error { - if !c.cfg.SkipSendMessageStatusCheck && c.statusLocked() != ConversationStatusStable { - return MessageValidationErrorChanging - } - // Validate message content before enqueueing var sb strings.Builder for _, part := range messageParts { @@ -514,26 +526,41 @@ func (c *PTYConversation) Text() string { return snapshots[len(snapshots)-1].screen } -func (c *PTYConversation) SaveState(conversation []ConversationMessage, stateFile string) error { +func (c *PTYConversation) SaveState() error { + conversation := c.Messages() + c.lock.Lock() defer c.lock.Unlock() - // Skip if state file is not configured - if stateFile == "" { + stateFile := c.cfg.StatePersistenceConfig.StateFile + saveState := c.cfg.StatePersistenceConfig.SaveState + + if !saveState { + c.cfg.Logger.Info("") return nil } // Skip if not dirty if !c.dirty { + c.cfg.Logger.Info("Skipping state save: no changes since last save") return nil } + // Serialize initial prompt from message parts + var initialPromptStr string + if len(c.cfg.InitialPrompt) > 0 { + var sb strings.Builder + for _, part := range c.cfg.InitialPrompt { + sb.WriteString(part.String()) + } + initialPromptStr = sb.String() + } + // Use atomic write: write to temp file, then rename to target path data, err := json.MarshalIndent(AgentState{ - Version: 1, - Messages: conversation, - InitialPrompt: c.InitialPrompt, - InitialPromptSent: c.InitialPromptSent, + Version: 1, + Messages: conversation, + InitialPrompt: initialPromptStr, }, "", " ") if err != nil { return xerrors.Errorf("failed to marshal state: %w", err) @@ -558,44 +585,51 @@ func (c *PTYConversation) SaveState(conversation []ConversationMessage, stateFil // Clear dirty flag after successful save c.dirty = false + + c.cfg.Logger.Info("State saved successfully to: %s", stateFile) + return nil } -func (c *PTYConversation) LoadState(stateFile string) ([]ConversationMessage, error) { - c.lock.Lock() - defer c.lock.Unlock() +// LoadState loads the state, this method assumes that caller holds the Lock +func (c *PTYConversation) loadState() error { + stateFile := c.cfg.StatePersistenceConfig.StateFile + loadState := c.cfg.StatePersistenceConfig.LoadState - // Skip if state file is not configured - if stateFile == "" { - return nil, nil + if !loadState { + return nil } // Check if file exists if _, err := os.Stat(stateFile); os.IsNotExist(err) { c.cfg.Logger.Info("No previous state to load (file does not exist)", "path", stateFile) - return nil, nil + return nil } // Read state file data, err := os.ReadFile(stateFile) if err != nil { c.cfg.Logger.Warn("Failed to load state file", "path", stateFile, "err", err) - return nil, xerrors.Errorf("failed to read state file: %w", err) + return xerrors.Errorf("failed to read state file: %w", err) } if len(data) == 0 { c.cfg.Logger.Info("No previous state to load (file is empty)", "path", stateFile) - return nil, nil + return nil } var agentState AgentState if err := json.Unmarshal(data, &agentState); err != nil { c.cfg.Logger.Warn("Failed to load state file (corrupted or invalid JSON)", "path", stateFile, "err", err) - return nil, xerrors.Errorf("failed to unmarshal state (corrupted or invalid JSON): %w", err) + return xerrors.Errorf("failed to unmarshal state (corrupted or invalid JSON): %w", err) } - c.InitialPromptSent = agentState.InitialPromptSent - c.InitialPrompt = agentState.InitialPrompt + //c.cfg.initialPromptSent = agentState.InitialPromptSent + c.cfg.InitialPrompt = []MessagePart{MessagePartText{ + Content: agentState.InitialPrompt, + Alias: "", + Hidden: false, + }} c.messages = agentState.Messages // Store the first stable snapshot for filtering later @@ -606,10 +640,15 @@ func (c *PTYConversation) LoadState(stateFile string) ([]ConversationMessage, er c.loadStateSuccessful = true c.cfg.Logger.Info("Successfully loaded state", "path", stateFile, "messages", len(c.messages)) - return c.messages, nil + return nil } func (c *PTYConversation) adjustScreenAfterStateLoad(screen string) string { + + if c.firstStableSnapshot == "" { + return screen + } + newScreen := strings.Replace(screen, c.firstStableSnapshot, "", 1) // Before the first user message after loading state, return the last message from the loaded state. From 9deab88258e92521ca36cafac48e898e2d7ab651 Mon Sep 17 00:00:00 2001 From: 35C4n0r Date: Tue, 17 Feb 2026 15:08:43 +0530 Subject: [PATCH 10/27] feat: resolve conflicts --- lib/screentracker/pty_conversation.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/screentracker/pty_conversation.go b/lib/screentracker/pty_conversation.go index 946a07fd..d40337c5 100644 --- a/lib/screentracker/pty_conversation.go +++ b/lib/screentracker/pty_conversation.go @@ -195,7 +195,7 @@ func (c *PTYConversation) Start(ctx context.Context) { c.initialPromptReady = true } - if !c.loadStateSuccessful && c.cfg.StatePersistenceConfig.LoadState { + if c.initialPromptReady && !c.loadStateSuccessful && c.cfg.StatePersistenceConfig.LoadState { _ = c.loadState() c.loadStateSuccessful = true } @@ -596,7 +596,7 @@ func (c *PTYConversation) SaveState() error { // Clear dirty flag after successful save c.dirty = false - c.cfg.Logger.Info("State saved successfully to: %s", stateFile) + c.cfg.Logger.Info(fmt.Sprintf("State saved successfully to: %s", stateFile)) return nil } From 18fb1e4bdf1dcc332461376f518740a0ebd9d5d1 Mon Sep 17 00:00:00 2001 From: 35C4n0r Date: Tue, 17 Feb 2026 15:13:33 +0530 Subject: [PATCH 11/27] chore: not dirty after load state --- lib/screentracker/pty_conversation.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lib/screentracker/pty_conversation.go b/lib/screentracker/pty_conversation.go index d40337c5..a8ddb5ea 100644 --- a/lib/screentracker/pty_conversation.go +++ b/lib/screentracker/pty_conversation.go @@ -649,6 +649,8 @@ func (c *PTYConversation) loadState() error { } c.loadStateSuccessful = true + c.dirty = false + c.cfg.Logger.Info("Successfully loaded state", "path", stateFile, "messages", len(c.messages)) return nil } From b719dac58d869d4a9b82cac051afb058945f185d Mon Sep 17 00:00:00 2001 From: 35C4n0r Date: Tue, 17 Feb 2026 15:55:21 +0530 Subject: [PATCH 12/27] feat: add tests --- cmd/server/server_test.go | 214 +++++++++++++ lib/httpapi/server_test.go | 34 ++ lib/screentracker/pty_conversation_test.go | 353 +++++++++++++++++++++ 3 files changed, 601 insertions(+) diff --git a/cmd/server/server_test.go b/cmd/server/server_test.go index bd07fc63..4affad0d 100644 --- a/cmd/server/server_test.go +++ b/cmd/server/server_test.go @@ -2,6 +2,8 @@ package server import ( "fmt" + "io" + "log/slog" "os" "strings" "testing" @@ -477,6 +479,218 @@ func TestServerCmd_AllowedHosts(t *testing.T) { } } +func TestServerCmd_StatePersistenceFlags(t *testing.T) { + // NOTE: These tests use --exit flag to test flag parsing and defaults. + // Runtime validation that happens in runServer (e.g., "--load-state requires --state-file") + // would call os.Exit(1) which terminates the test process, so those validations + // are tested through integration/E2E tests instead. + + t.Run("state-file with defaults", func(t *testing.T) { + isolateViper(t) + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{"--state-file", "/tmp/state.json", "--exit", "dummy-command"}) + err := serverCmd.Execute() + require.NoError(t, err) + + assert.Equal(t, "/tmp/state.json", viper.GetString(StateFile)) + // load-state and save-state default to true when state-file is set (validated in runServer) + }) + + t.Run("state-file with explicit load-state=false", func(t *testing.T) { + isolateViper(t) + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{"--state-file", "/tmp/state.json", "--load-state=false", "--exit", "dummy-command"}) + err := serverCmd.Execute() + require.NoError(t, err) + + assert.Equal(t, "/tmp/state.json", viper.GetString(StateFile)) + assert.Equal(t, false, viper.GetBool(LoadState)) + }) + + t.Run("state-file with explicit save-state=false", func(t *testing.T) { + isolateViper(t) + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{"--state-file", "/tmp/state.json", "--save-state=false", "--exit", "dummy-command"}) + err := serverCmd.Execute() + require.NoError(t, err) + + assert.Equal(t, "/tmp/state.json", viper.GetString(StateFile)) + assert.Equal(t, false, viper.GetBool(SaveState)) + }) + + t.Run("state-file with explicit load-state=true and save-state=true", func(t *testing.T) { + isolateViper(t) + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{ + "--state-file", "/tmp/state.json", + "--load-state=true", + "--save-state=true", + "--exit", "dummy-command", + }) + err := serverCmd.Execute() + require.NoError(t, err) + + assert.Equal(t, "/tmp/state.json", viper.GetString(StateFile)) + assert.Equal(t, true, viper.GetBool(LoadState)) + assert.Equal(t, true, viper.GetBool(SaveState)) + }) + + t.Run("load-state flag can be parsed", func(t *testing.T) { + isolateViper(t) + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{"--load-state", "--exit", "dummy-command"}) + err := serverCmd.Execute() + require.NoError(t, err) + + // Flag is parsed correctly (validation happens in runServer) + assert.Equal(t, true, viper.GetBool(LoadState)) + }) + + t.Run("save-state flag can be parsed", func(t *testing.T) { + isolateViper(t) + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{"--save-state", "--exit", "dummy-command"}) + err := serverCmd.Execute() + require.NoError(t, err) + + // Flag is parsed correctly (validation happens in runServer) + assert.Equal(t, true, viper.GetBool(SaveState)) + }) + + t.Run("pid-file can be set independently", func(t *testing.T) { + isolateViper(t) + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{"--pid-file", "/tmp/server.pid", "--exit", "dummy-command"}) + err := serverCmd.Execute() + require.NoError(t, err) + + assert.Equal(t, "/tmp/server.pid", viper.GetString(PidFile)) + }) + + t.Run("state-file and pid-file can be set together", func(t *testing.T) { + isolateViper(t) + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{ + "--state-file", "/tmp/state.json", + "--pid-file", "/tmp/server.pid", + "--exit", "dummy-command", + }) + err := serverCmd.Execute() + require.NoError(t, err) + + assert.Equal(t, "/tmp/state.json", viper.GetString(StateFile)) + assert.Equal(t, "/tmp/server.pid", viper.GetString(PidFile)) + }) +} + +func TestPIDFileOperations(t *testing.T) { + discardLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) + + t.Run("writePIDFile creates file with process ID", func(t *testing.T) { + tmpDir := t.TempDir() + pidFile := tmpDir + "/test.pid" + + err := writePIDFile(pidFile, discardLogger) + require.NoError(t, err) + + // Verify file exists + _, err = os.Stat(pidFile) + require.NoError(t, err) + + // Verify content contains current PID + data, err := os.ReadFile(pidFile) + require.NoError(t, err) + + expectedPID := fmt.Sprintf("%d\n", os.Getpid()) + assert.Equal(t, expectedPID, string(data)) + }) + + t.Run("writePIDFile creates directory if not exists", func(t *testing.T) { + tmpDir := t.TempDir() + pidFile := tmpDir + "/nested/deep/test.pid" + + err := writePIDFile(pidFile, discardLogger) + require.NoError(t, err) + + // Verify file exists + _, err = os.Stat(pidFile) + require.NoError(t, err) + + // Verify directory was created + _, err = os.Stat(tmpDir + "/nested/deep") + require.NoError(t, err) + }) + + t.Run("writePIDFile overwrites existing file", func(t *testing.T) { + tmpDir := t.TempDir() + pidFile := tmpDir + "/test.pid" + + // Write initial PID file + err := os.WriteFile(pidFile, []byte("12345\n"), 0o644) + require.NoError(t, err) + + // Overwrite with current PID + err = writePIDFile(pidFile, discardLogger) + require.NoError(t, err) + + // Verify content is updated + data, err := os.ReadFile(pidFile) + require.NoError(t, err) + + expectedPID := fmt.Sprintf("%d\n", os.Getpid()) + assert.Equal(t, expectedPID, string(data)) + }) + + t.Run("cleanupPIDFile removes file", func(t *testing.T) { + tmpDir := t.TempDir() + pidFile := tmpDir + "/test.pid" + + // Create PID file + err := os.WriteFile(pidFile, []byte("12345\n"), 0o644) + require.NoError(t, err) + + // Cleanup + cleanupPIDFile(pidFile, discardLogger) + + // Verify file is removed + _, err = os.Stat(pidFile) + assert.True(t, os.IsNotExist(err)) + }) + + t.Run("cleanupPIDFile handles non-existent file", func(t *testing.T) { + tmpDir := t.TempDir() + pidFile := tmpDir + "/nonexistent.pid" + + // Should not panic or error + cleanupPIDFile(pidFile, discardLogger) + }) + + t.Run("cleanupPIDFile handles directory removal error gracefully", func(t *testing.T) { + // Create a file in a protected directory (this is system-dependent) + // Just verify it doesn't panic when it can't remove the file + pidFile := "/this/should/not/exist/test.pid" + + // Should not panic + cleanupPIDFile(pidFile, discardLogger) + }) +} + func TestServerCmd_AllowedOrigins(t *testing.T) { tests := []struct { name string diff --git a/lib/httpapi/server_test.go b/lib/httpapi/server_test.go index c8e8b23c..82fc6713 100644 --- a/lib/httpapi/server_test.go +++ b/lib/httpapi/server_test.go @@ -13,6 +13,7 @@ import ( "path/filepath" "strings" "testing" + "time" "github.com/coder/agentapi/lib/httpapi" "github.com/coder/agentapi/lib/logctx" @@ -956,3 +957,36 @@ func TestServer_UploadFiles_Errors(t *testing.T) { require.Contains(t, string(body), "file size exceeds 10MB limit") }) } + +func TestServer_Stop_Idempotency(t *testing.T) { + t.Parallel() + ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) + + srv, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ + AgentType: msgfmt.AgentTypeClaude, + Process: nil, + Port: 0, + ChatBasePath: "/chat", + AllowedHosts: []string{"*"}, + AllowedOrigins: []string{"*"}, + }) + require.NoError(t, err) + + // First call to Stop should succeed + stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + err = srv.Stop(stopCtx) + require.NoError(t, err) + + // Second call to Stop should also succeed (no-op) + stopCtx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel2() + err = srv.Stop(stopCtx2) + require.NoError(t, err) + + // Third call to Stop should also succeed (no-op) + stopCtx3, cancel3 := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel3() + err = srv.Stop(stopCtx3) + require.NoError(t, err) +} diff --git a/lib/screentracker/pty_conversation_test.go b/lib/screentracker/pty_conversation_test.go index 19b4511b..67ff1395 100644 --- a/lib/screentracker/pty_conversation_test.go +++ b/lib/screentracker/pty_conversation_test.go @@ -2,9 +2,11 @@ package screentracker_test import ( "context" + "encoding/json" "fmt" "io" "log/slog" + "os" "sync" "testing" "time" @@ -446,6 +448,357 @@ func TestMessages(t *testing.T) { }) } +func TestStatePersistence(t *testing.T) { + t.Run("SaveState creates file with correct structure", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + // Create temp directory for state file + tmpDir := t.TempDir() + stateFile := tmpDir + "/state.json" + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "initial"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: false, + SaveState: true, + }, + InitialPrompt: []st.MessagePart{st.MessagePartText{Content: "test prompt"}}, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + // Generate some conversation + agent.setScreen("hello") + advanceFor(ctx, t, mClock, 300*time.Millisecond) + + // Save state + err := c.SaveState() + require.NoError(t, err) + + // Read and verify the saved file + data, err := os.ReadFile(stateFile) + require.NoError(t, err) + + var agentState st.AgentState + err = json.Unmarshal(data, &agentState) + require.NoError(t, err) + + assert.Equal(t, 1, agentState.Version) + assert.Equal(t, "test prompt", agentState.InitialPrompt) + assert.NotEmpty(t, agentState.Messages) + }) + + t.Run("SaveState skips when not configured", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/state.json" + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "initial"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: false, + SaveState: false, + }, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + err := c.SaveState() + require.NoError(t, err) + + // File should not be created + _, err = os.Stat(stateFile) + assert.True(t, os.IsNotExist(err)) + }) + + t.Run("SaveState honors dirty flag", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/state.json" + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "initial"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: false, + SaveState: true, + }, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + // Generate conversation and save + agent.setScreen("hello") + advanceFor(ctx, t, mClock, 300*time.Millisecond) + err := c.SaveState() + require.NoError(t, err) + + // Get file modification time + info1, err := os.Stat(stateFile) + require.NoError(t, err) + modTime1 := info1.ModTime() + + // Save again without changes - file should not be modified + err = c.SaveState() + require.NoError(t, err) + + info2, err := os.Stat(stateFile) + require.NoError(t, err) + modTime2 := info2.ModTime() + + // File modification time should be the same (dirty flag prevents save) + assert.Equal(t, modTime1, modTime2) + }) + + t.Run("SaveState creates directory if not exists", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/nested/deep/state.json" + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "initial"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: false, + SaveState: true, + }, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + agent.setScreen("hello") + advanceFor(ctx, t, mClock, 300*time.Millisecond) + + err := c.SaveState() + require.NoError(t, err) + + // Verify file and directory were created + _, err = os.Stat(stateFile) + assert.NoError(t, err) + }) + + t.Run("LoadState restores conversation from file", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/state.json" + + // Create a state file with test data + testState := st.AgentState{ + Version: 1, + InitialPrompt: "restored prompt", + Messages: []st.ConversationMessage{ + {Id: 0, Message: "agent message 1", Role: st.ConversationRoleAgent, Time: time.Now()}, + {Id: 1, Message: "user message 1", Role: st.ConversationRoleUser, Time: time.Now()}, + {Id: 2, Message: "agent message 2", Role: st.ConversationRoleAgent, Time: time.Now()}, + }, + } + data, err := json.MarshalIndent(testState, "", " ") + require.NoError(t, err) + err = os.WriteFile(stateFile, data, 0o644) + require.NoError(t, err) + + // Create conversation with LoadState enabled + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "ready"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + FormatMessage: func(message string, userInput string) string { + return message + }, + ReadyForInitialPrompt: func(message string) bool { + return message == "ready" + }, + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: true, + SaveState: false, + }, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + // Advance until agent is ready and state is loaded + advanceFor(ctx, t, mClock, 300*time.Millisecond) + + // Verify messages were restored + messages := c.Messages() + assert.Len(t, messages, 3) + assert.Equal(t, "agent message 1", messages[0].Message) + assert.Equal(t, "user message 1", messages[1].Message) + // The last agent message may have adjustments from adjustScreenAfterStateLoad + assert.Contains(t, messages[2].Message, "agent message 2") + }) + + t.Run("LoadState handles missing file gracefully", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/nonexistent.json" + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "ready"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + FormatMessage: func(message string, userInput string) string { + return message + }, + ReadyForInitialPrompt: func(message string) bool { + return message == "ready" + }, + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: true, + SaveState: false, + }, + } + + // Should not panic or error + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + advanceFor(ctx, t, mClock, 300*time.Millisecond) + + // Should have default initial message + messages := c.Messages() + assert.Len(t, messages, 1) + }) + + t.Run("LoadState handles empty file gracefully", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/empty.json" + + // Create empty file + err := os.WriteFile(stateFile, []byte(""), 0o644) + require.NoError(t, err) + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "ready"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + FormatMessage: func(message string, userInput string) string { + return message + }, + ReadyForInitialPrompt: func(message string) bool { + return message == "ready" + }, + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: true, + SaveState: false, + }, + } + + // Should not panic or error + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + advanceFor(ctx, t, mClock, 300*time.Millisecond) + + // Should have default initial message + messages := c.Messages() + assert.Len(t, messages, 1) + }) + + t.Run("LoadState handles corrupted JSON gracefully", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/corrupted.json" + + // Create corrupted JSON file + err := os.WriteFile(stateFile, []byte("{invalid json}"), 0o644) + require.NoError(t, err) + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "ready"} + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + FormatMessage: func(message string, userInput string) string { + return message + }, + ReadyForInitialPrompt: func(message string) bool { + return message == "ready" + }, + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: true, + SaveState: false, + }, + } + + // Should not panic - logs warning and continues + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + advanceFor(ctx, t, mClock, 300*time.Millisecond) + + // Should have default initial message + messages := c.Messages() + assert.Len(t, messages, 1) + }) +} + func TestInitialPromptReadiness(t *testing.T) { discardLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) From 39590021d1e48d059dc87bd4b96d82d533544bc1 Mon Sep 17 00:00:00 2001 From: 35C4n0r Date: Tue, 17 Feb 2026 16:24:39 +0530 Subject: [PATCH 13/27] feat: remove comment --- lib/screentracker/pty_conversation.go | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/screentracker/pty_conversation.go b/lib/screentracker/pty_conversation.go index a8ddb5ea..a4b44124 100644 --- a/lib/screentracker/pty_conversation.go +++ b/lib/screentracker/pty_conversation.go @@ -33,7 +33,6 @@ type AgentState struct { Version int `json:"version"` Messages []ConversationMessage `json:"messages"` InitialPrompt string `json:"initial_prompt"` - //InitialPromptSent bool `json:"initial_prompt_sent"` } var _ MessagePart = &MessagePartText{} From 7e389d29234de38d963a70bb5dd544756a032021 Mon Sep 17 00:00:00 2001 From: 35C4n0r Date: Tue, 17 Feb 2026 16:28:11 +0530 Subject: [PATCH 14/27] feat: remove comments --- cmd/server/server.go | 1 - lib/httpapi/server.go | 23 ++++++++++------------- lib/screentracker/pty_conversation.go | 4 +--- 3 files changed, 11 insertions(+), 17 deletions(-) diff --git a/cmd/server/server.go b/cmd/server/server.go index 46b21e26..c20a833c 100644 --- a/cmd/server/server.go +++ b/cmd/server/server.go @@ -125,7 +125,6 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er saveState = viper.GetBool(SaveState) } } else { - // No state file provided - ensure load/save flags are not explicitly set to true if viper.IsSet(LoadState) && viper.GetBool(LoadState) { return xerrors.Errorf("--load-state requires --state-file to be set") } diff --git a/lib/httpapi/server.go b/lib/httpapi/server.go index 0243038e..29f0dce1 100644 --- a/lib/httpapi/server.go +++ b/lib/httpapi/server.go @@ -255,15 +255,15 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { } conversation := st.NewPTY(ctx, st.PTYConversationConfig{ - AgentType: config.AgentType, - AgentIO: config.Process, - Clock: config.Clock, - SnapshotInterval: snapshotInterval, - ScreenStabilityLength: 2 * time.Second, - FormatMessage: formatMessage, - ReadyForInitialPrompt: isAgentReadyForInitialPrompt, - FormatToolCall: formatToolCall, - InitialPrompt: initialPrompt, + AgentType: config.AgentType, + AgentIO: config.Process, + Clock: config.Clock, + SnapshotInterval: snapshotInterval, + ScreenStabilityLength: 2 * time.Second, + FormatMessage: formatMessage, + ReadyForInitialPrompt: isAgentReadyForInitialPrompt, + FormatToolCall: formatToolCall, + InitialPrompt: initialPrompt, Logger: logger, StatePersistenceConfig: config.StatePersistenceConfig, }, emitter) @@ -591,8 +591,7 @@ func (s *Server) Start() error { return s.srv.ListenAndServe() } -// Stop gracefully stops the HTTP server. It is safe to call multiple times; -// only the first call will perform the shutdown, subsequent calls are no-ops. +// Stop gracefully stops the HTTP server. It is safe to call multiple times. func (s *Server) Stop(ctx context.Context) error { var err error s.stopOnce.Do(func() { @@ -615,8 +614,6 @@ func (s *Server) cleanupTempDir() { } } -// SaveState saves the conversation state if configured. This can be called from signal handlers. -// The source parameter indicates what triggered the save (e.g., "SIGTERM", "SIGUSR1"). func (s *Server) SaveState(source string) error { if err := s.conversation.SaveState(); err != nil { s.logger.Error("Failed to save conversation state", "source", source, "error", err) diff --git a/lib/screentracker/pty_conversation.go b/lib/screentracker/pty_conversation.go index a4b44124..e5b6feb4 100644 --- a/lib/screentracker/pty_conversation.go +++ b/lib/screentracker/pty_conversation.go @@ -76,9 +76,7 @@ type PTYConversationConfig struct { // FormatToolCall removes the coder report_task tool call from the agent message and also returns the array of removed tool calls FormatToolCall func(message string) (string, []string) // InitialPrompt is the initial prompt to send to the agent once ready - InitialPrompt []MessagePart - // OnSnapshot is called after each snapshot with current status, messages, and screen content - OnSnapshot func(status ConversationStatus, messages []ConversationMessage, screen string) + InitialPrompt []MessagePart Logger *slog.Logger StatePersistenceConfig StatePersistenceConfig } From 1d7aaedc01e4723e2e01511e5b4e9942ce84fd6e Mon Sep 17 00:00:00 2001 From: 35C4n0r Date: Wed, 18 Feb 2026 22:29:57 +0530 Subject: [PATCH 15/27] wip: address comments --- cmd/server/server.go | 4 +-- cmd/server/signals.go | 7 +++- cmd/server/signals_unix.go | 4 +-- cmd/server/signals_windows.go | 4 +-- lib/httpapi/events.go | 2 +- lib/screentracker/pty_conversation.go | 50 ++++++++++++++++++--------- 6 files changed, 45 insertions(+), 26 deletions(-) diff --git a/cmd/server/server.go b/cmd/server/server.go index c20a833c..39141335 100644 --- a/cmd/server/server.go +++ b/cmd/server/server.go @@ -140,8 +140,6 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er if err := writePIDFile(pidFile, logger); err != nil { return xerrors.Errorf("failed to write PID file: %w", err) } - // Ensure PID file is cleaned up on exit - defer cleanupPIDFile(pidFile, logger) } printOpenAPI := viper.GetBool(FlagPrintOpenAPI) @@ -183,7 +181,7 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er fmt.Println(srv.GetOpenAPI()) return nil } - handleSignals(ctx, logger, srv, process) + handleSignals(ctx, logger, srv, process, pidFile) logger.Info("Starting server on port", "port", port) processExitCh := make(chan error, 1) go func() { diff --git a/cmd/server/signals.go b/cmd/server/signals.go index e66554e7..7a669faa 100644 --- a/cmd/server/signals.go +++ b/cmd/server/signals.go @@ -12,7 +12,7 @@ import ( // performGracefulShutdown handles the common shutdown logic for all platforms. // It saves state, stops the HTTP server, closes the process, and exits. -func performGracefulShutdown(sig os.Signal, logger *slog.Logger, srv *httpapi.Server, process *termexec.Process) { +func performGracefulShutdown(sig os.Signal, logger *slog.Logger, srv *httpapi.Server, process *termexec.Process, pidFile string) { logger.Info("Received shutdown signal, initiating graceful shutdown", "signal", sig) // Save state @@ -31,6 +31,11 @@ func performGracefulShutdown(sig os.Signal, logger *slog.Logger, srv *httpapi.Se logger.Error("Failed to close process cleanly", "signal", sig, "error", err) } + // Clean up PID file before exit + if pidFile != "" { + cleanupPIDFile(pidFile, logger) + } + // Exit cleanly os.Exit(0) } diff --git a/cmd/server/signals_unix.go b/cmd/server/signals_unix.go index fe6b4693..4b46a240 100644 --- a/cmd/server/signals_unix.go +++ b/cmd/server/signals_unix.go @@ -16,14 +16,14 @@ import ( // handleSignals sets up signal handlers for: // - SIGTERM, SIGINT, SIGHUP: save conversation state, stop server, then close the process // - SIGUSR1: save conversation state without exiting -func handleSignals(ctx context.Context, logger *slog.Logger, srv *httpapi.Server, process *termexec.Process) { +func handleSignals(ctx context.Context, logger *slog.Logger, srv *httpapi.Server, process *termexec.Process, pidFile string) { // Handle shutdown signals (SIGTERM, SIGINT, SIGHUP) shutdownCh := make(chan os.Signal, 1) signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGINT) go func() { defer signal.Stop(shutdownCh) sig := <-shutdownCh - performGracefulShutdown(sig, logger, srv, process) + performGracefulShutdown(sig, logger, srv, process, pidFile) }() // Handle SIGUSR1 for save without exit diff --git a/cmd/server/signals_windows.go b/cmd/server/signals_windows.go index 52d90616..45dd414b 100644 --- a/cmd/server/signals_windows.go +++ b/cmd/server/signals_windows.go @@ -15,13 +15,13 @@ import ( // handleSignals sets up signal handlers for Windows. // Only handles SIGTERM and SIGINT (SIGHUP and SIGUSR1 don't exist on Windows). -func handleSignals(ctx context.Context, logger *slog.Logger, srv *httpapi.Server, process *termexec.Process) { +func handleSignals(ctx context.Context, logger *slog.Logger, srv *httpapi.Server, process *termexec.Process, pidFile string) { // Handle shutdown signals (SIGTERM, SIGINT only on Windows) shutdownCh := make(chan os.Signal, 1) signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM) go func() { defer signal.Stop(shutdownCh) sig := <-shutdownCh - performGracefulShutdown(sig, logger, srv, process) + performGracefulShutdown(sig, logger, srv, process, pidFile) }() } diff --git a/lib/httpapi/events.go b/lib/httpapi/events.go index dac1549f..30aba036 100644 --- a/lib/httpapi/events.go +++ b/lib/httpapi/events.go @@ -137,7 +137,7 @@ func (e *EventEmitter) notifyChannels(eventType EventType, payload any) { } } -// UpdateMessagesAndEmitChanges assumes that only the last message can change or new messages can be added. +// EmitMessages assumes that only the last message can change or new messages can be added. // If a new message is injected between existing messages (identified by Id), the behavior is undefined. func (e *EventEmitter) EmitMessages(newMessages []st.ConversationMessage) { e.mu.Lock() diff --git a/lib/screentracker/pty_conversation.go b/lib/screentracker/pty_conversation.go index e5b6feb4..32454a00 100644 --- a/lib/screentracker/pty_conversation.go +++ b/lib/screentracker/pty_conversation.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "io" "log/slog" "os" "path/filepath" @@ -193,7 +194,17 @@ func (c *PTYConversation) Start(ctx context.Context) { } if c.initialPromptReady && !c.loadStateSuccessful && c.cfg.StatePersistenceConfig.LoadState { - _ = c.loadState() + if err := c.loadStateLocked(); err != nil { + // Add error message to conversation so user is aware + errorMsg := fmt.Sprintf("AgentAPI state restoration failed, the conversation history may be missing: %v", err) + c.messages = append(c.messages, ConversationMessage{ + Id: len(c.messages), + Message: errorMsg, + Role: ConversationRoleAgent, + Time: c.cfg.Clock.Now(), + }) + c.cfg.Logger.Error("Failed to load state", "error", err) + } c.loadStateSuccessful = true } @@ -534,8 +545,6 @@ func (c *PTYConversation) Text() string { } func (c *PTYConversation) SaveState() error { - conversation := c.Messages() - c.lock.Lock() defer c.lock.Unlock() @@ -543,7 +552,7 @@ func (c *PTYConversation) SaveState() error { saveState := c.cfg.StatePersistenceConfig.SaveState if !saveState { - c.cfg.Logger.Info("") + c.cfg.Logger.Info("State persistence is disabled") return nil } @@ -553,6 +562,8 @@ func (c *PTYConversation) SaveState() error { return nil } + conversation := c.messagesLocked() + // Serialize initial prompt from message parts var initialPromptStr string if len(c.cfg.InitialPrompt) > 0 { @@ -598,8 +609,8 @@ func (c *PTYConversation) SaveState() error { return nil } -// LoadState loads the state, this method assumes that caller holds the Lock -func (c *PTYConversation) loadState() error { +// loadStateLocked loads the state, this method assumes that caller holds the Lock +func (c *PTYConversation) loadStateLocked() error { stateFile := c.cfg.StatePersistenceConfig.StateFile loadState := c.cfg.StatePersistenceConfig.LoadState @@ -613,20 +624,25 @@ func (c *PTYConversation) loadState() error { return nil } - // Read state file - data, err := os.ReadFile(stateFile) + // Open state file + f, err := os.Open(stateFile) if err != nil { - c.cfg.Logger.Warn("Failed to load state file", "path", stateFile, "err", err) - return xerrors.Errorf("failed to read state file: %w", err) - } - - if len(data) == 0 { - c.cfg.Logger.Info("No previous state to load (file is empty)", "path", stateFile) - return nil + c.cfg.Logger.Warn("Failed to open state file", "path", stateFile, "err", err) + return xerrors.Errorf("failed to open state file: %w", err) } + defer func() { + if closeErr := f.Close(); closeErr != nil { + c.cfg.Logger.Warn("Failed to close state file", "path", stateFile, "err", closeErr) + } + }() var agentState AgentState - if err := json.Unmarshal(data, &agentState); err != nil { + decoder := json.NewDecoder(f) + if err := decoder.Decode(&agentState); err != nil { + if err == io.EOF { + c.cfg.Logger.Info("No previous state to load (file is empty)", "path", stateFile) + return nil + } c.cfg.Logger.Warn("Failed to load state file (corrupted or invalid JSON)", "path", stateFile, "err", err) return xerrors.Errorf("failed to unmarshal state (corrupted or invalid JSON): %w", err) } @@ -663,7 +679,7 @@ func (c *PTYConversation) adjustScreenAfterStateLoad(screen string) string { // Before the first user message after loading state, return the last message from the loaded state. // This prevents computing incorrect diffs from the restored screen, as the agent's message should // remain stable until the user continues the conversation. - if c.userSentMessageAfterLoadState == false { + if !c.userSentMessageAfterLoadState { newScreen = "\n" + c.messages[len(c.messages)-1].Message } From 058b18f6df5c68b85a5b4d79294241d44f8407eb Mon Sep 17 00:00:00 2001 From: 35C4n0r Date: Thu, 19 Feb 2026 00:18:17 +0530 Subject: [PATCH 16/27] feat: remove anti-pattern for graceful shutdown --- cmd/server/server.go | 64 ++++++++++++++++++++++++++++++----- cmd/server/signals.go | 41 ---------------------- cmd/server/signals_unix.go | 8 ++--- cmd/server/signals_windows.go | 6 ++-- lib/httpapi/server.go | 15 ++++++++ 5 files changed, 78 insertions(+), 56 deletions(-) delete mode 100644 cmd/server/signals.go diff --git a/cmd/server/server.go b/cmd/server/server.go index 39141335..8ad37219 100644 --- a/cmd/server/server.go +++ b/cmd/server/server.go @@ -181,8 +181,22 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er fmt.Println(srv.GetOpenAPI()) return nil } - handleSignals(ctx, logger, srv, process, pidFile) + + // Create a context for graceful shutdown + gracefulCtx, gracefulCancel := context.WithCancel(ctx) + defer gracefulCancel() + + // Setup signal handlers (they will call gracefulCancel) + handleSignals(gracefulCtx, gracefulCancel, logger, srv) + + // Setup PID file cleanup + if pidFile != "" { + defer cleanupPIDFile(pidFile, logger) + } + logger.Info("Starting server on port", "port", port) + + // Monitor process exit processExitCh := make(chan error, 1) go func() { defer close(processExitCh) @@ -193,18 +207,52 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er processExitCh <- xerrors.Errorf("failed to wait for process: %w", err) } } - shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - if err := srv.Stop(shutdownCtx); err != nil { - logger.Error("Failed to stop server after process exit", "error", err) + + select { + case <-gracefulCtx.Done(): + default: + gracefulCancel() + } + }() + + // Start the server + serverErrCh := make(chan error, 1) + go func() { + defer close(serverErrCh) + if err := srv.Start(); err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, http.ErrServerClosed) { + serverErrCh <- err } }() - if err := srv.Start(); err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, http.ErrServerClosed) { - return xerrors.Errorf("failed to start server: %w", err) + + select { + case err := <-serverErrCh: + if err != nil { + return xerrors.Errorf("failed to start server: %w", err) + } + case <-gracefulCtx.Done(): } + + if err := srv.SaveState("shutdown"); err != nil { + logger.Error("Failed to save state during shutdown", "error", err) + } + + // Stop the HTTP server + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := srv.Stop(shutdownCtx); err != nil { + logger.Error("Failed to stop HTTP server", "error", err) + } + + // Close the process + if err := process.Close(logger, 5*time.Second); err != nil { + logger.Error("Failed to close process cleanly", "error", err) + } + select { case err := <-processExitCh: - return xerrors.Errorf("agent exited with error: %w", err) + if err != nil { + return xerrors.Errorf("agent exited with error: %w", err) + } default: } return nil diff --git a/cmd/server/signals.go b/cmd/server/signals.go deleted file mode 100644 index 7a669faa..00000000 --- a/cmd/server/signals.go +++ /dev/null @@ -1,41 +0,0 @@ -package server - -import ( - "context" - "log/slog" - "os" - "time" - - "github.com/coder/agentapi/lib/httpapi" - "github.com/coder/agentapi/lib/termexec" -) - -// performGracefulShutdown handles the common shutdown logic for all platforms. -// It saves state, stops the HTTP server, closes the process, and exits. -func performGracefulShutdown(sig os.Signal, logger *slog.Logger, srv *httpapi.Server, process *termexec.Process, pidFile string) { - logger.Info("Received shutdown signal, initiating graceful shutdown", "signal", sig) - - // Save state - if err := srv.SaveState(sig.String()); err != nil { - logger.Error("Failed to save state during shutdown", "signal", sig, "error", err) - } - - shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - if err := srv.Stop(shutdownCtx); err != nil { - logger.Error("Failed to stop HTTP server", "signal", sig, "error", err) - } - - // Close the process - if err := process.Close(logger, 5*time.Second); err != nil { - logger.Error("Failed to close process cleanly", "signal", sig, "error", err) - } - - // Clean up PID file before exit - if pidFile != "" { - cleanupPIDFile(pidFile, logger) - } - - // Exit cleanly - os.Exit(0) -} diff --git a/cmd/server/signals_unix.go b/cmd/server/signals_unix.go index 4b46a240..b15b5b2b 100644 --- a/cmd/server/signals_unix.go +++ b/cmd/server/signals_unix.go @@ -10,20 +10,20 @@ import ( "syscall" "github.com/coder/agentapi/lib/httpapi" - "github.com/coder/agentapi/lib/termexec" ) // handleSignals sets up signal handlers for: -// - SIGTERM, SIGINT, SIGHUP: save conversation state, stop server, then close the process +// - SIGTERM, SIGINT, SIGHUP: trigger graceful shutdown by canceling the context // - SIGUSR1: save conversation state without exiting -func handleSignals(ctx context.Context, logger *slog.Logger, srv *httpapi.Server, process *termexec.Process, pidFile string) { +func handleSignals(ctx context.Context, cancel context.CancelFunc, logger *slog.Logger, srv *httpapi.Server) { // Handle shutdown signals (SIGTERM, SIGINT, SIGHUP) shutdownCh := make(chan os.Signal, 1) signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGINT) go func() { defer signal.Stop(shutdownCh) sig := <-shutdownCh - performGracefulShutdown(sig, logger, srv, process, pidFile) + logger.Info("Received shutdown signal", "signal", sig) + cancel() }() // Handle SIGUSR1 for save without exit diff --git a/cmd/server/signals_windows.go b/cmd/server/signals_windows.go index 45dd414b..b8a109c9 100644 --- a/cmd/server/signals_windows.go +++ b/cmd/server/signals_windows.go @@ -10,18 +10,18 @@ import ( "syscall" "github.com/coder/agentapi/lib/httpapi" - "github.com/coder/agentapi/lib/termexec" ) // handleSignals sets up signal handlers for Windows. // Only handles SIGTERM and SIGINT (SIGHUP and SIGUSR1 don't exist on Windows). -func handleSignals(ctx context.Context, logger *slog.Logger, srv *httpapi.Server, process *termexec.Process, pidFile string) { +func handleSignals(ctx context.Context, cancel context.CancelFunc, logger *slog.Logger, srv *httpapi.Server) { // Handle shutdown signals (SIGTERM, SIGINT only on Windows) shutdownCh := make(chan os.Signal, 1) signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM) go func() { defer signal.Stop(shutdownCh) sig := <-shutdownCh - performGracefulShutdown(sig, logger, srv, process, pidFile) + logger.Info("Received shutdown signal", "signal", sig) + cancel() }() } diff --git a/lib/httpapi/server.go b/lib/httpapi/server.go index 29f0dce1..a53b5074 100644 --- a/lib/httpapi/server.go +++ b/lib/httpapi/server.go @@ -49,6 +49,8 @@ type Server struct { chatBasePath string tempDir string clock quartz.Clock + shutdownCtx context.Context + shutdown context.CancelFunc } func (s *Server) NormalizeSchema(schema any) any { @@ -275,6 +277,8 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { } logger.Info("Created temporary directory for uploads", "tempDir", tempDir) + ctx, cancel := context.WithCancel(context.Background()) + s := &Server{ router: router, api: api, @@ -287,6 +291,8 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { chatBasePath: strings.TrimSuffix(config.ChatBasePath, "/"), tempDir: tempDir, clock: config.Clock, + shutdownCtx: ctx, + shutdown: cancel, } // Register API routes @@ -514,6 +520,7 @@ func (s *Server) uploadFiles(ctx context.Context, input *struct { func (s *Server) subscribeEvents(ctx context.Context, input *struct{}, send sse.Sender) { subscriberId, ch, stateEvents := s.emitter.Subscribe() defer s.emitter.Unsubscribe(subscriberId) + s.logger.Info("New subscriber", "subscriberId", subscriberId) for _, event := range stateEvents { if event.Type == EventTypeScreenUpdate { @@ -539,6 +546,9 @@ func (s *Server) subscribeEvents(ctx context.Context, input *struct{}, send sse. s.logger.Error("Failed to send event", "subscriberId", subscriberId, "error", err) return } + case <-s.shutdownCtx.Done(): + s.logger.Info("Server stop initiated, unsubscribing.", "subscriberId", subscriberId) + return case <-ctx.Done(): s.logger.Info("Context done", "subscriberId", subscriberId) return @@ -573,6 +583,9 @@ func (s *Server) subscribeScreen(ctx context.Context, input *struct{}, send sse. s.logger.Error("Failed to send screen event", "subscriberId", subscriberId, "error", err) return } + case <-s.shutdownCtx.Done(): + s.logger.Info("Server stop initiated, unsubscribing.", "subscriberId", subscriberId) + return case <-ctx.Done(): s.logger.Info("Screen context done", "subscriberId", subscriberId) return @@ -595,6 +608,8 @@ func (s *Server) Start() error { func (s *Server) Stop(ctx context.Context) error { var err error s.stopOnce.Do(func() { + s.shutdown() + // Clean up temporary directory s.cleanupTempDir() From 2565a3c36a0d55456bd518dd9acb23504c236fe6 Mon Sep 17 00:00:00 2001 From: 35C4n0r Date: Thu, 19 Feb 2026 00:26:01 +0530 Subject: [PATCH 17/27] feat: remove additional message upon load state fail --- lib/screentracker/pty_conversation.go | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/lib/screentracker/pty_conversation.go b/lib/screentracker/pty_conversation.go index 32454a00..7aaefbf5 100644 --- a/lib/screentracker/pty_conversation.go +++ b/lib/screentracker/pty_conversation.go @@ -195,14 +195,6 @@ func (c *PTYConversation) Start(ctx context.Context) { if c.initialPromptReady && !c.loadStateSuccessful && c.cfg.StatePersistenceConfig.LoadState { if err := c.loadStateLocked(); err != nil { - // Add error message to conversation so user is aware - errorMsg := fmt.Sprintf("AgentAPI state restoration failed, the conversation history may be missing: %v", err) - c.messages = append(c.messages, ConversationMessage{ - Id: len(c.messages), - Message: errorMsg, - Role: ConversationRoleAgent, - Time: c.cfg.Clock.Now(), - }) c.cfg.Logger.Error("Failed to load state", "error", err) } c.loadStateSuccessful = true @@ -627,7 +619,6 @@ func (c *PTYConversation) loadStateLocked() error { // Open state file f, err := os.Open(stateFile) if err != nil { - c.cfg.Logger.Warn("Failed to open state file", "path", stateFile, "err", err) return xerrors.Errorf("failed to open state file: %w", err) } defer func() { @@ -643,7 +634,6 @@ func (c *PTYConversation) loadStateLocked() error { c.cfg.Logger.Info("No previous state to load (file is empty)", "path", stateFile) return nil } - c.cfg.Logger.Warn("Failed to load state file (corrupted or invalid JSON)", "path", stateFile, "err", err) return xerrors.Errorf("failed to unmarshal state (corrupted or invalid JSON): %w", err) } From 1033cd72284fb1e29122a60624afc6da545a0133 Mon Sep 17 00:00:00 2001 From: 35C4n0r Date: Thu, 19 Feb 2026 00:46:19 +0530 Subject: [PATCH 18/27] wip: apply suggestions from cian --- cmd/server/server.go | 32 +++++++++++++-------------- lib/screentracker/pty_conversation.go | 8 +++---- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/cmd/server/server.go b/cmd/server/server.go index 8ad37219..45985d99 100644 --- a/cmd/server/server.go +++ b/cmd/server/server.go @@ -107,33 +107,33 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er } // Get the variables related to state management - stateFile := viper.GetString(StateFile) + stateFile := viper.GetString(FlagStateFile) loadState := false saveState := false // Validate state file configuration if stateFile != "" { - if !viper.IsSet(LoadState) { + if !viper.IsSet(FlagLoadState) { loadState = true } else { - loadState = viper.GetBool(LoadState) + loadState = viper.GetBool(FlagLoadState) } - if !viper.IsSet(SaveState) { + if !viper.IsSet(FlagSaveState) { saveState = true } else { - saveState = viper.GetBool(SaveState) + saveState = viper.GetBool(FlagSaveState) } } else { - if viper.IsSet(LoadState) && viper.GetBool(LoadState) { + if viper.IsSet(FlagLoadState) && viper.GetBool(FlagLoadState) { return xerrors.Errorf("--load-state requires --state-file to be set") } - if viper.IsSet(SaveState) && viper.GetBool(SaveState) { + if viper.IsSet(FlagSaveState) && viper.GetBool(FlagSaveState) { return xerrors.Errorf("--save-state requires --state-file to be set") } } - pidFile := viper.GetString(PidFile) + pidFile := viper.GetString(FlagPidFile) // Write PID file if configured if pidFile != "" { @@ -315,10 +315,10 @@ const ( FlagAllowedOrigins = "allowed-origins" FlagExit = "exit" FlagInitialPrompt = "initial-prompt" - StateFile = "state-file" - LoadState = "load-state" - SaveState = "save-state" - PidFile = "pid-file" + FlagStateFile = "state-file" + FlagLoadState = "load-state" + FlagSaveState = "save-state" + FlagPidFile = "pid-file" ) func CreateServerCmd() *cobra.Command { @@ -357,10 +357,10 @@ func CreateServerCmd() *cobra.Command { // localhost:3284 is the default origin when you open the chat interface in your browser. localhost:3000 and 3001 are used during development. {FlagAllowedOrigins, "o", []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, "HTTP allowed origins. Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_ORIGINS env var", "stringSlice"}, {FlagInitialPrompt, "I", "", "Initial prompt for the agent. Recommended only if the agent doesn't support initial prompt in interaction mode. Will be read from stdin if piped (e.g., echo 'prompt' | agentapi server -- my-agent)", "string"}, - {StateFile, "s", "", "Path to file for saving/loading server state", "string"}, - {LoadState, "", false, "Load state from state-file on startup (defaults to true when state-file is set)", "bool"}, - {SaveState, "", false, "Save state to state-file on shutdown (defaults to true when state-file is set)", "bool"}, - {PidFile, "", "", "Path to file where the server process ID will be written for shutdown scripts", "string"}, + {FlagStateFile, "s", "", "Path to file for saving/loading server state", "string"}, + {FlagLoadState, "", false, "Load state from state-file on startup (defaults to true when state-file is set)", "bool"}, + {FlagSaveState, "", false, "Save state to state-file on shutdown (defaults to true when state-file is set)", "bool"}, + {FlagPidFile, "", "", "Path to file where the server process ID will be written for shutdown scripts", "string"}, } for _, spec := range flagSpecs { diff --git a/lib/screentracker/pty_conversation.go b/lib/screentracker/pty_conversation.go index 7aaefbf5..c6b13b68 100644 --- a/lib/screentracker/pty_conversation.go +++ b/lib/screentracker/pty_conversation.go @@ -596,7 +596,7 @@ func (c *PTYConversation) SaveState() error { // Clear dirty flag after successful save c.dirty = false - c.cfg.Logger.Info(fmt.Sprintf("State saved successfully to: %s", stateFile)) + c.cfg.Logger.Info("State saved successfully", "path", stateFile) return nil } @@ -606,7 +606,7 @@ func (c *PTYConversation) loadStateLocked() error { stateFile := c.cfg.StatePersistenceConfig.StateFile loadState := c.cfg.StatePersistenceConfig.LoadState - if !loadState { + if !loadState || c.loadStateSuccessful { return nil } @@ -647,7 +647,7 @@ func (c *PTYConversation) loadStateLocked() error { // Store the first stable snapshot for filtering later snapshots := c.snapshotBuffer.GetAll() - if len(snapshots) > 0 { + if len(snapshots) > 0 && c.cfg.FormatMessage != nil { c.firstStableSnapshot = c.cfg.FormatMessage(strings.TrimSpace(snapshots[len(snapshots)-1].screen), "") } @@ -669,7 +669,7 @@ func (c *PTYConversation) adjustScreenAfterStateLoad(screen string) string { // Before the first user message after loading state, return the last message from the loaded state. // This prevents computing incorrect diffs from the restored screen, as the agent's message should // remain stable until the user continues the conversation. - if !c.userSentMessageAfterLoadState { + if !c.userSentMessageAfterLoadState && len(c.messages) > 0 { newScreen = "\n" + c.messages[len(c.messages)-1].Message } From cfb76012a32e6323c7a15b5054f22bb17c5acb24 Mon Sep 17 00:00:00 2001 From: 35C4n0r Date: Thu, 19 Feb 2026 00:52:22 +0530 Subject: [PATCH 19/27] wip: apply suggestions from cian --- cmd/server/server.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/server/server.go b/cmd/server/server.go index 45985d99..cfb9640d 100644 --- a/cmd/server/server.go +++ b/cmd/server/server.go @@ -274,12 +274,12 @@ func writePIDFile(pidFile string, logger *slog.Logger) error { // Create directory if it doesn't exist dir := filepath.Dir(pidFile) - if err := os.MkdirAll(dir, 0o755); err != nil { + if err := os.MkdirAll(dir, 0o700); err != nil { return xerrors.Errorf("failed to create PID file directory: %w", err) } // Write PID file - if err := os.WriteFile(pidFile, []byte(pidContent), 0o644); err != nil { + if err := os.WriteFile(pidFile, []byte(pidContent), 0o600); err != nil { return xerrors.Errorf("failed to write PID file: %w", err) } From 9d7eb5af8714804c9ac8606011fc4368cd787a4c Mon Sep 17 00:00:00 2001 From: 35C4n0r Date: Thu, 19 Feb 2026 00:54:19 +0530 Subject: [PATCH 20/27] feat: update tests --- cmd/server/server_test.go | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/cmd/server/server_test.go b/cmd/server/server_test.go index 4affad0d..7b9372c1 100644 --- a/cmd/server/server_test.go +++ b/cmd/server/server_test.go @@ -494,7 +494,7 @@ func TestServerCmd_StatePersistenceFlags(t *testing.T) { err := serverCmd.Execute() require.NoError(t, err) - assert.Equal(t, "/tmp/state.json", viper.GetString(StateFile)) + assert.Equal(t, "/tmp/state.json", viper.GetString(FlagStateFile)) // load-state and save-state default to true when state-file is set (validated in runServer) }) @@ -507,8 +507,8 @@ func TestServerCmd_StatePersistenceFlags(t *testing.T) { err := serverCmd.Execute() require.NoError(t, err) - assert.Equal(t, "/tmp/state.json", viper.GetString(StateFile)) - assert.Equal(t, false, viper.GetBool(LoadState)) + assert.Equal(t, "/tmp/state.json", viper.GetString(FlagStateFile)) + assert.Equal(t, false, viper.GetBool(FlagLoadState)) }) t.Run("state-file with explicit save-state=false", func(t *testing.T) { @@ -520,8 +520,8 @@ func TestServerCmd_StatePersistenceFlags(t *testing.T) { err := serverCmd.Execute() require.NoError(t, err) - assert.Equal(t, "/tmp/state.json", viper.GetString(StateFile)) - assert.Equal(t, false, viper.GetBool(SaveState)) + assert.Equal(t, "/tmp/state.json", viper.GetString(FlagStateFile)) + assert.Equal(t, false, viper.GetBool(FlagSaveState)) }) t.Run("state-file with explicit load-state=true and save-state=true", func(t *testing.T) { @@ -538,9 +538,9 @@ func TestServerCmd_StatePersistenceFlags(t *testing.T) { err := serverCmd.Execute() require.NoError(t, err) - assert.Equal(t, "/tmp/state.json", viper.GetString(StateFile)) - assert.Equal(t, true, viper.GetBool(LoadState)) - assert.Equal(t, true, viper.GetBool(SaveState)) + assert.Equal(t, "/tmp/state.json", viper.GetString(FlagStateFile)) + assert.Equal(t, true, viper.GetBool(FlagLoadState)) + assert.Equal(t, true, viper.GetBool(FlagSaveState)) }) t.Run("load-state flag can be parsed", func(t *testing.T) { @@ -553,7 +553,7 @@ func TestServerCmd_StatePersistenceFlags(t *testing.T) { require.NoError(t, err) // Flag is parsed correctly (validation happens in runServer) - assert.Equal(t, true, viper.GetBool(LoadState)) + assert.Equal(t, true, viper.GetBool(FlagLoadState)) }) t.Run("save-state flag can be parsed", func(t *testing.T) { @@ -566,7 +566,7 @@ func TestServerCmd_StatePersistenceFlags(t *testing.T) { require.NoError(t, err) // Flag is parsed correctly (validation happens in runServer) - assert.Equal(t, true, viper.GetBool(SaveState)) + assert.Equal(t, true, viper.GetBool(FlagSaveState)) }) t.Run("pid-file can be set independently", func(t *testing.T) { @@ -578,7 +578,7 @@ func TestServerCmd_StatePersistenceFlags(t *testing.T) { err := serverCmd.Execute() require.NoError(t, err) - assert.Equal(t, "/tmp/server.pid", viper.GetString(PidFile)) + assert.Equal(t, "/tmp/server.pid", viper.GetString(FlagPidFile)) }) t.Run("state-file and pid-file can be set together", func(t *testing.T) { @@ -594,8 +594,8 @@ func TestServerCmd_StatePersistenceFlags(t *testing.T) { err := serverCmd.Execute() require.NoError(t, err) - assert.Equal(t, "/tmp/state.json", viper.GetString(StateFile)) - assert.Equal(t, "/tmp/server.pid", viper.GetString(PidFile)) + assert.Equal(t, "/tmp/state.json", viper.GetString(FlagStateFile)) + assert.Equal(t, "/tmp/server.pid", viper.GetString(FlagPidFile)) }) } From 759ec531abb47f7f4c19e3f0b2a4fd551394e08c Mon Sep 17 00:00:00 2001 From: 35C4n0r Date: Thu, 19 Feb 2026 16:16:00 +0530 Subject: [PATCH 21/27] feat: improved initial prompt handling --- lib/screentracker/conversation.go | 9 + lib/screentracker/pty_conversation.go | 66 ++-- lib/screentracker/pty_conversation_test.go | 348 ++++++++++++++++++++- 3 files changed, 379 insertions(+), 44 deletions(-) diff --git a/lib/screentracker/conversation.go b/lib/screentracker/conversation.go index f921d16b..358f56d3 100644 --- a/lib/screentracker/conversation.go +++ b/lib/screentracker/conversation.go @@ -2,6 +2,7 @@ package screentracker import ( "context" + "strings" "time" "github.com/coder/agentapi/lib/util" @@ -49,6 +50,14 @@ type MessagePart interface { String() string } +func buildStringFromMessageParts(parts []MessagePart) string { + var sb strings.Builder + for _, part := range parts { + sb.WriteString(part.String()) + } + return sb.String() +} + // Conversation represents a conversation between a user and an agent. // It is intended as the primary interface for interacting with a session. // Implementations must support the following capabilities: diff --git a/lib/screentracker/pty_conversation.go b/lib/screentracker/pty_conversation.go index c6b13b68..37881d97 100644 --- a/lib/screentracker/pty_conversation.go +++ b/lib/screentracker/pty_conversation.go @@ -31,9 +31,10 @@ type MessagePartText struct { } type AgentState struct { - Version int `json:"version"` - Messages []ConversationMessage `json:"messages"` - InitialPrompt string `json:"initial_prompt"` + Version int `json:"version"` + Messages []ConversationMessage `json:"messages"` + InitialPrompt string `json:"initial_prompt"` + InitialPromptSent bool `json:"initial_prompt_sent"` } var _ MessagePart = &MessagePartText{} @@ -129,6 +130,7 @@ type PTYConversation struct { // initialPromptReady is set to true when ReadyForInitialPrompt returns true. // Checked inline in the snapshot loop on each tick. initialPromptReady bool + initialPromptSent bool } var _ Conversation = &PTYConversation{} @@ -167,10 +169,6 @@ func NewPTY(ctx context.Context, cfg PTYConversationConfig, emitter Emitter) *PT userSentMessageAfterLoadState: false, loadStateSuccessful: false, } - // If we have an initial prompt, enqueue it - if len(cfg.InitialPrompt) > 0 { - c.outboundQueue <- outboundMessage{parts: cfg.InitialPrompt, errCh: nil} - } if c.cfg.ReadyForInitialPrompt == nil { c.cfg.ReadyForInitialPrompt = func(string) bool { return true } } @@ -200,6 +198,13 @@ func (c *PTYConversation) Start(ctx context.Context) { c.loadStateSuccessful = true } + // Enqueue initial prompt once after agent is ready (and after state is potentially loaded) + if c.initialPromptReady && len(c.cfg.InitialPrompt) > 0 && !c.initialPromptSent { + c.outboundQueue <- outboundMessage{parts: c.cfg.InitialPrompt, errCh: nil} + c.initialPromptSent = true + c.dirty = true + } + if c.initialPromptReady && len(c.outboundQueue) > 0 && c.isScreenStableLocked() { select { case c.stableSignal <- struct{}{}: @@ -324,11 +329,7 @@ func (c *PTYConversation) snapshotLocked(screen string) { func (c *PTYConversation) Send(messageParts ...MessagePart) error { // Validate message content before enqueueing - var sb strings.Builder - for _, part := range messageParts { - sb.WriteString(part.String()) - } - message := sb.String() + message := buildStringFromMessageParts(messageParts) if message != msgfmt.TrimWhitespace(message) { return ErrMessageValidationWhitespace } @@ -352,11 +353,7 @@ func (c *PTYConversation) Send(messageParts ...MessagePart) error { // around the parts that access shared state, but releases it during // writeStabilize to avoid blocking the snapshot loop. func (c *PTYConversation) sendMessage(ctx context.Context, messageParts ...MessagePart) error { - var sb strings.Builder - for _, part := range messageParts { - sb.WriteString(part.String()) - } - message := sb.String() + message := buildStringFromMessageParts(messageParts) c.lock.Lock() screenBeforeMessage := c.cfg.AgentIO.ReadScreen() @@ -559,18 +556,15 @@ func (c *PTYConversation) SaveState() error { // Serialize initial prompt from message parts var initialPromptStr string if len(c.cfg.InitialPrompt) > 0 { - var sb strings.Builder - for _, part := range c.cfg.InitialPrompt { - sb.WriteString(part.String()) - } - initialPromptStr = sb.String() + initialPromptStr = buildStringFromMessageParts(c.cfg.InitialPrompt) } // Use atomic write: write to temp file, then rename to target path data, err := json.MarshalIndent(AgentState{ - Version: 1, - Messages: conversation, - InitialPrompt: initialPromptStr, + Version: 1, + Messages: conversation, + InitialPrompt: initialPromptStr, + InitialPromptSent: c.initialPromptSent, }, "", " ") if err != nil { return xerrors.Errorf("failed to marshal state: %w", err) @@ -637,12 +631,22 @@ func (c *PTYConversation) loadStateLocked() error { return xerrors.Errorf("failed to unmarshal state (corrupted or invalid JSON): %w", err) } - //c.cfg.initialPromptSent = agentState.InitialPromptSent - c.cfg.InitialPrompt = []MessagePart{MessagePartText{ - Content: agentState.InitialPrompt, - Alias: "", - Hidden: false, - }} + // Handle initial prompt restoration: + // - If a new initial prompt was provided via flags, check if it differs from the saved one. + // If different, mark as not sent (will be sent). If same, preserve sent status. + // - If no new prompt provided, restore the saved prompt and its sent status. + c.initialPromptSent = agentState.InitialPromptSent + if len(c.cfg.InitialPrompt) > 0 { + isDifferent := buildStringFromMessageParts(c.cfg.InitialPrompt) != agentState.InitialPrompt + c.initialPromptSent = !isDifferent + } else { + c.cfg.InitialPrompt = []MessagePart{MessagePartText{ + Content: agentState.InitialPrompt, + Alias: "", + Hidden: false, + }} + } + c.messages = agentState.Messages // Store the first stable snapshot for filtering later diff --git a/lib/screentracker/pty_conversation_test.go b/lib/screentracker/pty_conversation_test.go index 67ff1395..f201edc5 100644 --- a/lib/screentracker/pty_conversation_test.go +++ b/lib/screentracker/pty_conversation_test.go @@ -802,7 +802,7 @@ func TestStatePersistence(t *testing.T) { func TestInitialPromptReadiness(t *testing.T) { discardLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) - t.Run("agent not ready - status remains changing", func(t *testing.T) { + t.Run("agent not ready - status is stable until agent becomes ready", func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testTimeout) t.Cleanup(cancel) mClock := quartz.NewMock(t) @@ -825,12 +825,12 @@ func TestInitialPromptReadiness(t *testing.T) { // Take a snapshot with "loading...". Threshold is 1 (stability 0 / interval 1s = 0 + 1 = 1). advanceFor(ctx, t, mClock, 1*time.Second) - // Even though screen is stable, status should be changing because - // the initial prompt is still in the outbound queue. - assert.Equal(t, st.ConversationStatusChanging, c.Status()) + // Screen is stable and agent is not ready, so initial prompt hasn't been enqueued yet. + // Status should be stable. + assert.Equal(t, st.ConversationStatusStable, c.Status()) }) - t.Run("agent becomes ready - status stays changing until initial prompt sent", func(t *testing.T) { + t.Run("agent becomes ready - prompt enqueued and status changes to changing", func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testTimeout) t.Cleanup(cancel) mClock := quartz.NewMock(t) @@ -850,12 +850,11 @@ func TestInitialPromptReadiness(t *testing.T) { c := st.NewPTY(ctx, cfg, &testEmitter{}) c.Start(ctx) - // Agent not ready initially. + // Agent not ready initially, status should be stable advanceFor(ctx, t, mClock, 1*time.Second) - assert.Equal(t, st.ConversationStatusChanging, c.Status()) + assert.Equal(t, st.ConversationStatusStable, c.Status()) - // Agent becomes ready, but status stays "changing" because the - // initial prompt is still in the outbound queue. + // Agent becomes ready, prompt gets enqueued, status becomes "changing" agent.setScreen("ready") advanceFor(ctx, t, mClock, 1*time.Second) assert.Equal(t, st.ConversationStatusChanging, c.Status()) @@ -886,12 +885,12 @@ func TestInitialPromptReadiness(t *testing.T) { c := st.NewPTY(ctx, cfg, &testEmitter{}) c.Start(ctx) - // Status is "changing" while waiting for readiness. + // Status is "stable" while waiting for readiness (prompt not yet enqueued). advanceFor(ctx, t, mClock, 1*time.Second) - assert.Equal(t, st.ConversationStatusChanging, c.Status()) + assert.Equal(t, st.ConversationStatusStable, c.Status()) - // Agent becomes ready. The readiness loop detects this, the snapshot - // loop sees queue + stable + ready and signals the send loop. + // Agent becomes ready. The snapshot loop detects this, enqueues the prompt, + // then sees queue + stable + ready and signals the send loop. // writeStabilize runs with onWrite changing the screen, so it completes. agent.setScreen("ready") // Drive clock until the initial prompt is sent (queue drains). @@ -964,3 +963,326 @@ func TestInitialPromptReadiness(t *testing.T) { assert.Equal(t, st.ConversationStatusStable, c.Status()) }) } + +func TestInitialPromptSent(t *testing.T) { + discardLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) + + t.Run("initialPromptSent is set when initial prompt is sent", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/state.json" + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "loading..."} + writeCounter := 0 + agent.onWrite = func(data []byte) { + writeCounter++ + agent.screen = fmt.Sprintf("__write_%d", writeCounter) + } + + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 1 * time.Second, + ScreenStabilityLength: 0, + AgentIO: agent, + ReadyForInitialPrompt: func(message string) bool { + return message == "ready" + }, + InitialPrompt: []st.MessagePart{st.MessagePartText{Content: "test prompt"}}, + Logger: discardLogger, + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: false, + SaveState: true, + }, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + // Agent becomes ready and initial prompt is sent + agent.setScreen("ready") + advanceUntil(ctx, t, mClock, func() bool { + return len(c.Messages()) >= 2 + }) + + // Save state and verify initialPromptSent is persisted + agent.setScreen("response") + advanceFor(ctx, t, mClock, 2*time.Second) + + err := c.SaveState() + require.NoError(t, err) + + data, err := os.ReadFile(stateFile) + require.NoError(t, err) + + var agentState st.AgentState + err = json.Unmarshal(data, &agentState) + require.NoError(t, err) + + assert.True(t, agentState.InitialPromptSent, "initialPromptSent should be true after initial prompt is sent") + assert.Equal(t, "test prompt", agentState.InitialPrompt) + }) + + t.Run("initialPromptSent prevents re-sending prompt after state load", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/state.json" + + // Create a state file with initialPromptSent=true + testState := st.AgentState{ + Version: 1, + InitialPrompt: "test prompt", + InitialPromptSent: true, + Messages: []st.ConversationMessage{ + {Id: 0, Message: "agent message", Role: st.ConversationRoleAgent, Time: time.Now()}, + {Id: 1, Message: "test prompt", Role: st.ConversationRoleUser, Time: time.Now()}, + }, + } + data, err := json.MarshalIndent(testState, "", " ") + require.NoError(t, err) + err = os.WriteFile(stateFile, data, 0o644) + require.NoError(t, err) + + // Create conversation with same initial prompt + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "ready"} + writeCount := 0 + agent.onWrite = func(data []byte) { + writeCount++ + agent.screen = "after_write" + } + + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + ReadyForInitialPrompt: func(message string) bool { + return message == "ready" + }, + InitialPrompt: []st.MessagePart{st.MessagePartText{Content: "test prompt"}}, + Logger: discardLogger, + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: true, + SaveState: false, + }, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + // Advance until ready and state is loaded + advanceFor(ctx, t, mClock, 500*time.Millisecond) + + // Verify the prompt was NOT re-sent (no writes occurred) + assert.Equal(t, 0, writeCount, "initial prompt should not be re-sent when already sent") + + // Messages should be restored from state (at minimum, the original 2) + messages := c.Messages() + assert.GreaterOrEqual(t, len(messages), 2, "messages should be restored from state") + // Verify the first two messages match what we saved + assert.Equal(t, "agent message", messages[0].Message) + assert.Equal(t, st.ConversationRoleAgent, messages[0].Role) + assert.Equal(t, "test prompt", messages[1].Message) + assert.Equal(t, st.ConversationRoleUser, messages[1].Role) + }) + + t.Run("new initial prompt is sent if different from saved prompt", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/state.json" + + // Create a state file with old prompt + testState := st.AgentState{ + Version: 1, + InitialPrompt: "old prompt", + InitialPromptSent: true, + Messages: []st.ConversationMessage{ + {Id: 0, Message: "agent message", Role: st.ConversationRoleAgent, Time: time.Now()}, + }, + } + data, err := json.MarshalIndent(testState, "", " ") + require.NoError(t, err) + err = os.WriteFile(stateFile, data, 0o644) + require.NoError(t, err) + + // Create conversation with different initial prompt + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "loading..."} + writeCounter := 0 + agent.onWrite = func(data []byte) { + writeCounter++ + agent.screen = fmt.Sprintf("__write_%d", writeCounter) + } + + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 1 * time.Second, + ScreenStabilityLength: 0, + AgentIO: agent, + ReadyForInitialPrompt: func(message string) bool { + return message == "ready" + }, + InitialPrompt: []st.MessagePart{st.MessagePartText{Content: "new prompt"}}, + Logger: discardLogger, + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: true, + SaveState: false, + }, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + // Agent becomes ready + agent.setScreen("ready") + + // Advance until the new prompt is sent + advanceUntil(ctx, t, mClock, func() bool { + msgs := c.Messages() + // Look for the new prompt in messages + for _, msg := range msgs { + if msg.Role == st.ConversationRoleUser && msg.Message == "new prompt" { + return true + } + } + return false + }) + + // Verify the new prompt was sent + messages := c.Messages() + found := false + for _, msg := range messages { + if msg.Role == st.ConversationRoleUser && msg.Message == "new prompt" { + found = true + break + } + } + assert.True(t, found, "new prompt should be sent when different from saved prompt") + }) + + t.Run("initialPromptSent not set when no initial prompt configured", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/state.json" + + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "ready"} + + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 100 * time.Millisecond, + ScreenStabilityLength: 200 * time.Millisecond, + AgentIO: agent, + Logger: discardLogger, + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: false, + SaveState: true, + }, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + advanceFor(ctx, t, mClock, 300*time.Millisecond) + + err := c.SaveState() + require.NoError(t, err) + + data, err := os.ReadFile(stateFile) + require.NoError(t, err) + + var agentState st.AgentState + err = json.Unmarshal(data, &agentState) + require.NoError(t, err) + + assert.False(t, agentState.InitialPromptSent, "initialPromptSent should be false when no initial prompt configured") + }) + + t.Run("restored prompt used when no new prompt provided", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + t.Cleanup(cancel) + + tmpDir := t.TempDir() + stateFile := tmpDir + "/state.json" + + // Create a state file with a prompt + testState := st.AgentState{ + Version: 1, + InitialPrompt: "saved prompt", + InitialPromptSent: false, + Messages: []st.ConversationMessage{ + {Id: 0, Message: "agent message", Role: st.ConversationRoleAgent, Time: time.Now()}, + }, + } + data, err := json.MarshalIndent(testState, "", " ") + require.NoError(t, err) + err = os.WriteFile(stateFile, data, 0o644) + require.NoError(t, err) + + // Create conversation without providing an initial prompt + mClock := quartz.NewMock(t) + agent := &testAgent{screen: "loading..."} + writeCounter := 0 + agent.onWrite = func(data []byte) { + writeCounter++ + agent.screen = fmt.Sprintf("__write_%d", writeCounter) + } + + cfg := st.PTYConversationConfig{ + Clock: mClock, + SnapshotInterval: 1 * time.Second, + ScreenStabilityLength: 0, + AgentIO: agent, + ReadyForInitialPrompt: func(message string) bool { + return message == "ready" + }, + Logger: discardLogger, + StatePersistenceConfig: st.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: true, + SaveState: false, + }, + } + + c := st.NewPTY(ctx, cfg, &testEmitter{}) + c.Start(ctx) + + // Agent becomes ready + agent.setScreen("ready") + + // Advance until the saved prompt is sent + advanceUntil(ctx, t, mClock, func() bool { + msgs := c.Messages() + for _, msg := range msgs { + if msg.Role == st.ConversationRoleUser && msg.Message == "saved prompt" { + return true + } + } + return false + }) + + // Verify the saved prompt was sent + messages := c.Messages() + found := false + for _, msg := range messages { + if msg.Role == st.ConversationRoleUser && msg.Message == "saved prompt" { + found = true + break + } + } + assert.True(t, found, "saved prompt should be sent when no new prompt provided") + }) +} From 03c6f1646ca0c60be1a31e0f8346ce4601232f75 Mon Sep 17 00:00:00 2001 From: 35C4n0r Date: Thu, 19 Feb 2026 16:18:54 +0530 Subject: [PATCH 22/27] chore: comments --- lib/screentracker/pty_conversation.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/screentracker/pty_conversation.go b/lib/screentracker/pty_conversation.go index 37881d97..c779e632 100644 --- a/lib/screentracker/pty_conversation.go +++ b/lib/screentracker/pty_conversation.go @@ -130,7 +130,8 @@ type PTYConversation struct { // initialPromptReady is set to true when ReadyForInitialPrompt returns true. // Checked inline in the snapshot loop on each tick. initialPromptReady bool - initialPromptSent bool + // initialPromptSent is set to true when the initial prompt has been enqueued to the outbound queue. + initialPromptSent bool } var _ Conversation = &PTYConversation{} From bd75240a0b1c94c0b6a2d9d05c874063a46de2da Mon Sep 17 00:00:00 2001 From: 35C4n0r Date: Thu, 19 Feb 2026 16:44:58 +0530 Subject: [PATCH 23/27] chore: address cian's file permission comments --- lib/screentracker/pty_conversation.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/screentracker/pty_conversation.go b/lib/screentracker/pty_conversation.go index c779e632..c976f23f 100644 --- a/lib/screentracker/pty_conversation.go +++ b/lib/screentracker/pty_conversation.go @@ -573,13 +573,13 @@ func (c *PTYConversation) SaveState() error { // Create directory if it doesn't exist dir := filepath.Dir(stateFile) - if err := os.MkdirAll(dir, 0o755); err != nil { + if err := os.MkdirAll(dir, 0o700); err != nil { return xerrors.Errorf("failed to create state directory: %w", err) } // Write to temp file tempFile := stateFile + ".tmp" - if err := os.WriteFile(tempFile, data, 0o644); err != nil { + if err := os.WriteFile(tempFile, data, 0o600); err != nil { return xerrors.Errorf("failed to write temp state file: %w", err) } From b1ab61545bd4640a6b0c08c10b1d67cbae4cff47 Mon Sep 17 00:00:00 2001 From: 35C4n0r Date: Thu, 19 Feb 2026 19:09:20 +0530 Subject: [PATCH 24/27] feat: implement error handling for agent events --- chat/src/app/layout.tsx | 2 +- chat/src/components/chat-provider.tsx | 25 +++++++++++ lib/httpapi/events.go | 33 ++++++++++++++ lib/httpapi/server.go | 1 + lib/screentracker/conversation.go | 1 + lib/screentracker/pty_conversation.go | 22 ++++++---- lib/screentracker/pty_conversation_test.go | 1 + openapi.json | 50 ++++++++++++++++++++++ 8 files changed, 126 insertions(+), 9 deletions(-) diff --git a/chat/src/app/layout.tsx b/chat/src/app/layout.tsx index 830124ed..7c44c440 100644 --- a/chat/src/app/layout.tsx +++ b/chat/src/app/layout.tsx @@ -29,7 +29,7 @@ export default function RootLayout({ disableTransitionOnChange > {children} - + diff --git a/chat/src/components/chat-provider.tsx b/chat/src/components/chat-provider.tsx index 21a2ee3f..e8789ac0 100644 --- a/chat/src/components/chat-provider.tsx +++ b/chat/src/components/chat-provider.tsx @@ -36,6 +36,12 @@ interface StatusChangeEvent { agent_type: string; } +interface ErrorEventData { + message: string; + level: string; + time: string; +} + interface APIErrorDetail { location: string; message: string; @@ -215,6 +221,25 @@ export function ChatProvider({ children }: PropsWithChildren) { setAgentType(data.agent_type === "" ? "unknown" : data.agent_type as AgentType); }); + // Handle agent error events + eventSource.addEventListener("agent_error", (event) => { + const messageEvent = event as MessageEvent; + try { + const data: ErrorEventData = JSON.parse(messageEvent.data); + + // Display error as toast notification that persists until manually dismissed + if (data.level === "error") { + toast.error(data.message, { duration: Infinity }); + } else if (data.level === "warning") { + toast.warning(data.message, { duration: Infinity }); + } else { + toast.info(data.message, { duration: Infinity }); + } + } catch (e) { + console.error("Failed to parse agent_error event data:", e); + } + }); + // Handle connection open (server is online) eventSource.onopen = () => { // Connection is established, but we'll wait for status_change event diff --git a/lib/httpapi/events.go b/lib/httpapi/events.go index 30aba036..d2ba4d3f 100644 --- a/lib/httpapi/events.go +++ b/lib/httpapi/events.go @@ -18,6 +18,7 @@ const ( EventTypeMessageUpdate EventType = "message_update" EventTypeStatusChange EventType = "status_change" EventTypeScreenUpdate EventType = "screen_update" + EventTypeError EventType = "agent_error" ) type AgentStatus string @@ -52,6 +53,12 @@ type ScreenUpdateBody struct { Screen string `json:"screen"` } +type ErrorBody struct { + Message string `json:"message" doc:"Error message"` + Level string `json:"level" doc:"Error level: 'warning' or 'error'"` + Time time.Time `json:"time" doc:"Timestamp when the error occurred"` +} + type Event struct { Type EventType Payload any @@ -66,6 +73,7 @@ type EventEmitter struct { chanIdx int subscriptionBufSize uint screen string + errors []ErrorBody } func convertStatus(status st.ConversationStatus) AgentStatus { @@ -194,6 +202,22 @@ func (e *EventEmitter) EmitScreen(newScreen string) { e.screen = newScreen } +func (e *EventEmitter) EmitError(message string, level string) { + e.mu.Lock() + defer e.mu.Unlock() + + errorBody := ErrorBody{ + Message: message, + Level: level, + Time: time.Now(), + } + + // Store the error so new subscribers can receive all errors + e.errors = append(e.errors, errorBody) + + e.notifyChannels(EventTypeError, errorBody) +} + // Assumes the caller holds the lock. func (e *EventEmitter) currentStateAsEvents() []Event { events := make([]Event, 0, len(e.messages)+2) @@ -211,6 +235,15 @@ func (e *EventEmitter) currentStateAsEvents() []Event { Type: EventTypeScreenUpdate, Payload: ScreenUpdateBody{Screen: strings.TrimRight(e.screen, mf.WhiteSpaceChars)}, }) + + // Include all error events + for _, err := range e.errors { + events = append(events, Event{ + Type: EventTypeError, + Payload: err, + }) + } + return events } diff --git a/lib/httpapi/server.go b/lib/httpapi/server.go index a53b5074..f18ce679 100644 --- a/lib/httpapi/server.go +++ b/lib/httpapi/server.go @@ -396,6 +396,7 @@ func (s *Server) registerRoutes() { // Mapping of event type name to Go struct for that event. "message_update": MessageUpdateBody{}, "status_change": StatusChangeBody{}, + "agent_error": ErrorBody{}, }, s.subscribeEvents) sse.Register(s.api, huma.Operation{ diff --git a/lib/screentracker/conversation.go b/lib/screentracker/conversation.go index 358f56d3..8555e424 100644 --- a/lib/screentracker/conversation.go +++ b/lib/screentracker/conversation.go @@ -80,6 +80,7 @@ type Emitter interface { EmitMessages([]ConversationMessage) EmitStatus(ConversationStatus) EmitScreen(string) + EmitError(message string, level string) } type ConversationMessage struct { diff --git a/lib/screentracker/pty_conversation.go b/lib/screentracker/pty_conversation.go index c976f23f..e1ae169d 100644 --- a/lib/screentracker/pty_conversation.go +++ b/lib/screentracker/pty_conversation.go @@ -125,8 +125,8 @@ type PTYConversation struct { firstStableSnapshot string // userSentMessageAfterLoadState tracks if the user has sent their first message after we load the state userSentMessageAfterLoadState bool - // loadStateSuccessful indicates whether conversation state was successfully restored from file. - loadStateSuccessful bool + // loadStateAttempted indicates whether we have attempted to load conversation state from file (regardless of success). + loadStateAttempted bool // initialPromptReady is set to true when ReadyForInitialPrompt returns true. // Checked inline in the snapshot loop on each tick. initialPromptReady bool @@ -141,6 +141,7 @@ type noopEmitter struct{} func (noopEmitter) EmitMessages([]ConversationMessage) {} func (noopEmitter) EmitStatus(ConversationStatus) {} func (noopEmitter) EmitScreen(string) {} +func (noopEmitter) EmitError(_ string, _ string) {} func NewPTY(ctx context.Context, cfg PTYConversationConfig, emitter Emitter) *PTYConversation { if cfg.Clock == nil { @@ -168,7 +169,7 @@ func NewPTY(ctx context.Context, cfg PTYConversationConfig, emitter Emitter) *PT dirty: false, firstStableSnapshot: "", userSentMessageAfterLoadState: false, - loadStateSuccessful: false, + loadStateAttempted: false, } if c.cfg.ReadyForInitialPrompt == nil { c.cfg.ReadyForInitialPrompt = func(string) bool { return true } @@ -192,11 +193,13 @@ func (c *PTYConversation) Start(ctx context.Context) { c.initialPromptReady = true } - if c.initialPromptReady && !c.loadStateSuccessful && c.cfg.StatePersistenceConfig.LoadState { + var loadStateErr error + if c.initialPromptReady && !c.loadStateAttempted && c.cfg.StatePersistenceConfig.LoadState { if err := c.loadStateLocked(); err != nil { c.cfg.Logger.Error("Failed to load state", "error", err) + loadStateErr = err } - c.loadStateSuccessful = true + c.loadStateAttempted = true } // Enqueue initial prompt once after agent is ready (and after state is potentially loaded) @@ -219,6 +222,9 @@ func (c *PTYConversation) Start(ctx context.Context) { c.emitter.EmitStatus(status) c.emitter.EmitMessages(messages) c.emitter.EmitScreen(screen) + if loadStateErr != nil { + c.emitter.EmitError(fmt.Sprintf("Failed to restore previous session: %v", loadStateErr), "warning") + } return nil }, "snapshot") @@ -282,7 +288,7 @@ func (c *PTYConversation) updateLastAgentMessageLocked(screen string, timestamp if c.cfg.FormatMessage != nil { agentMessage = c.cfg.FormatMessage(agentMessage, lastUserMessage.Message) } - if c.loadStateSuccessful { + if c.loadStateAttempted { agentMessage = c.adjustScreenAfterStateLoad(agentMessage) } if c.cfg.FormatToolCall != nil { @@ -601,7 +607,7 @@ func (c *PTYConversation) loadStateLocked() error { stateFile := c.cfg.StatePersistenceConfig.StateFile loadState := c.cfg.StatePersistenceConfig.LoadState - if !loadState || c.loadStateSuccessful { + if !loadState || c.loadStateAttempted { return nil } @@ -656,7 +662,7 @@ func (c *PTYConversation) loadStateLocked() error { c.firstStableSnapshot = c.cfg.FormatMessage(strings.TrimSpace(snapshots[len(snapshots)-1].screen), "") } - c.loadStateSuccessful = true + c.loadStateAttempted = true c.dirty = false c.cfg.Logger.Info("Successfully loaded state", "path", stateFile, "messages", len(c.messages)) diff --git a/lib/screentracker/pty_conversation_test.go b/lib/screentracker/pty_conversation_test.go index f201edc5..c8a49c7e 100644 --- a/lib/screentracker/pty_conversation_test.go +++ b/lib/screentracker/pty_conversation_test.go @@ -56,6 +56,7 @@ type testEmitter struct{} func (testEmitter) EmitMessages([]st.ConversationMessage) {} func (testEmitter) EmitStatus(st.ConversationStatus) {} func (testEmitter) EmitScreen(string) {} +func (testEmitter) EmitError(_ string, _ string) {} // advanceFor is a shorthand for advanceUntil with a time-based condition. func advanceFor(ctx context.Context, t *testing.T, mClock *quartz.Mock, total time.Duration) { diff --git a/openapi.json b/openapi.json index dda817cc..790338a6 100644 --- a/openapi.json +++ b/openapi.json @@ -19,6 +19,30 @@ "title": "ConversationRole", "type": "string" }, + "ErrorBody": { + "additionalProperties": false, + "properties": { + "level": { + "description": "Error level: 'warning' or 'error'", + "type": "string" + }, + "message": { + "description": "Error message", + "type": "string" + }, + "time": { + "description": "Timestamp when the error occurred", + "format": "date-time", + "type": "string" + } + }, + "required": [ + "level", + "message", + "time" + ], + "type": "object" + }, "ErrorDetail": { "additionalProperties": false, "properties": { @@ -326,6 +350,32 @@ "description": "Each oneOf object in the array represents one possible Server Sent Events (SSE) message, serialized as UTF-8 text according to the SSE specification.", "items": { "oneOf": [ + { + "properties": { + "data": { + "$ref": "#/components/schemas/ErrorBody" + }, + "event": { + "const": "agent_error", + "description": "The event name.", + "type": "string" + }, + "id": { + "description": "The event ID.", + "type": "integer" + }, + "retry": { + "description": "The retry time in milliseconds.", + "type": "integer" + } + }, + "required": [ + "data", + "event" + ], + "title": "Event agent_error", + "type": "object" + }, { "properties": { "data": { From 31d27a78e0f7ec4c15fed5d636619581fac5eb2f Mon Sep 17 00:00:00 2001 From: 35C4n0r Date: Thu, 19 Feb 2026 19:56:24 +0530 Subject: [PATCH 25/27] fix: no screen adjustment in case of loadState failure --- lib/screentracker/pty_conversation.go | 35 +++++++++++++++++---------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/lib/screentracker/pty_conversation.go b/lib/screentracker/pty_conversation.go index e1ae169d..407bb10b 100644 --- a/lib/screentracker/pty_conversation.go +++ b/lib/screentracker/pty_conversation.go @@ -37,6 +37,18 @@ type AgentState struct { InitialPromptSent bool `json:"initial_prompt_sent"` } +// LoadStateStatus represents the state of loading persisted conversation state. +type LoadStateStatus int + +const ( + // LoadStatePending indicates state loading has not been attempted yet. + LoadStatePending LoadStateStatus = iota + // LoadStateSucceeded indicates state was successfully loaded. + LoadStateSucceeded + // LoadStateFailed indicates state loading was attempted but failed. + LoadStateFailed +) + var _ MessagePart = &MessagePartText{} func (p MessagePartText) Do(writer AgentIO) error { @@ -125,8 +137,8 @@ type PTYConversation struct { firstStableSnapshot string // userSentMessageAfterLoadState tracks if the user has sent their first message after we load the state userSentMessageAfterLoadState bool - // loadStateAttempted indicates whether we have attempted to load conversation state from file (regardless of success). - loadStateAttempted bool + // loadStateStatus tracks the status of loading conversation state from file. + loadStateStatus LoadStateStatus // initialPromptReady is set to true when ReadyForInitialPrompt returns true. // Checked inline in the snapshot loop on each tick. initialPromptReady bool @@ -169,7 +181,7 @@ func NewPTY(ctx context.Context, cfg PTYConversationConfig, emitter Emitter) *PT dirty: false, firstStableSnapshot: "", userSentMessageAfterLoadState: false, - loadStateAttempted: false, + loadStateStatus: LoadStatePending, } if c.cfg.ReadyForInitialPrompt == nil { c.cfg.ReadyForInitialPrompt = func(string) bool { return true } @@ -193,13 +205,14 @@ func (c *PTYConversation) Start(ctx context.Context) { c.initialPromptReady = true } - var loadStateErr error - if c.initialPromptReady && !c.loadStateAttempted && c.cfg.StatePersistenceConfig.LoadState { + if c.initialPromptReady && c.loadStateStatus == LoadStatePending && c.cfg.StatePersistenceConfig.LoadState { if err := c.loadStateLocked(); err != nil { c.cfg.Logger.Error("Failed to load state", "error", err) - loadStateErr = err + c.emitter.EmitError(fmt.Sprintf("Failed to restore previous session: %v", err), "warning") + c.loadStateStatus = LoadStateFailed + } else { + c.loadStateStatus = LoadStateSucceeded } - c.loadStateAttempted = true } // Enqueue initial prompt once after agent is ready (and after state is potentially loaded) @@ -222,9 +235,6 @@ func (c *PTYConversation) Start(ctx context.Context) { c.emitter.EmitStatus(status) c.emitter.EmitMessages(messages) c.emitter.EmitScreen(screen) - if loadStateErr != nil { - c.emitter.EmitError(fmt.Sprintf("Failed to restore previous session: %v", loadStateErr), "warning") - } return nil }, "snapshot") @@ -288,7 +298,7 @@ func (c *PTYConversation) updateLastAgentMessageLocked(screen string, timestamp if c.cfg.FormatMessage != nil { agentMessage = c.cfg.FormatMessage(agentMessage, lastUserMessage.Message) } - if c.loadStateAttempted { + if c.loadStateStatus == LoadStateSucceeded { agentMessage = c.adjustScreenAfterStateLoad(agentMessage) } if c.cfg.FormatToolCall != nil { @@ -607,7 +617,7 @@ func (c *PTYConversation) loadStateLocked() error { stateFile := c.cfg.StatePersistenceConfig.StateFile loadState := c.cfg.StatePersistenceConfig.LoadState - if !loadState || c.loadStateAttempted { + if !loadState || c.loadStateStatus != LoadStatePending { return nil } @@ -662,7 +672,6 @@ func (c *PTYConversation) loadStateLocked() error { c.firstStableSnapshot = c.cfg.FormatMessage(strings.TrimSpace(snapshots[len(snapshots)-1].screen), "") } - c.loadStateAttempted = true c.dirty = false c.cfg.Logger.Info("Successfully loaded state", "path", stateFile, "messages", len(c.messages)) From 220d360e2345771d803bfcce06ac60e61c73f711 Mon Sep 17 00:00:00 2001 From: 35C4n0r Date: Fri, 20 Feb 2026 12:57:16 +0530 Subject: [PATCH 26/27] feat: add three e2e tests for statePersistence --- e2e/echo_test.go | 251 ++++++++++++++++-- e2e/testdata/state_persistence.json | 18 ++ ..._persistence_different_initial_prompt.json | 10 + ...tence_different_initial_prompt_phase1.json | 10 + .../state_persistence_initial_prompt.json | 14 + 5 files changed, 288 insertions(+), 15 deletions(-) create mode 100644 e2e/testdata/state_persistence.json create mode 100644 e2e/testdata/state_persistence_different_initial_prompt.json create mode 100644 e2e/testdata/state_persistence_different_initial_prompt_phase1.json create mode 100644 e2e/testdata/state_persistence_initial_prompt.json diff --git a/e2e/echo_test.go b/e2e/echo_test.go index 765521cf..ac7d58e3 100644 --- a/e2e/echo_test.go +++ b/e2e/echo_test.go @@ -40,7 +40,8 @@ func TestE2E(t *testing.T) { t.Run("basic", func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() - script, apiClient := setup(ctx, t, nil) + script, apiClient, cleanup := setup(ctx, t, nil, true) + defer cleanup() messageReq := agentapisdk.PostMessageParams{ Content: "This is a test message.", Type: agentapisdk.MessageTypeUser, @@ -60,7 +61,8 @@ func TestE2E(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() - script, apiClient := setup(ctx, t, nil) + script, apiClient, cleanup := setup(ctx, t, nil, true) + defer cleanup() messageReq := agentapisdk.PostMessageParams{ Content: "What is the answer to life, the universe, and everything?", Type: agentapisdk.MessageTypeUser, @@ -86,13 +88,14 @@ func TestE2E(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() - script, apiClient := setup(ctx, t, ¶ms{ + script, apiClient, cleanup := setup(ctx, t, ¶ms{ cmdFn: func(ctx context.Context, t testing.TB, serverPort int, binaryPath, cwd, scriptFilePath string) (string, []string) { defCmd, defArgs := defaultCmdFn(ctx, t, serverPort, binaryPath, cwd, scriptFilePath) script := fmt.Sprintf(`echo "hello agent" | %s %s`, defCmd, strings.Join(defArgs, " ")) return "/bin/sh", []string{"-c", script} }, - }) + }, true) + defer cleanup() require.NoError(t, waitAgentAPIStable(ctx, t, apiClient, 5*time.Second, "stdin")) msgResp, err := apiClient.GetMessages(ctx) require.NoError(t, err, "Failed to get messages via SDK") @@ -100,27 +103,230 @@ func TestE2E(t *testing.T) { require.Equal(t, script[0].ExpectMessage, strings.TrimSpace(msgResp.Messages[1].Content)) require.Equal(t, script[0].ResponseMessage, strings.TrimSpace(msgResp.Messages[2].Content)) }) + + t.Run("state_persistence", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + // Create a temporary state file + stateFile := filepath.Join(t.TempDir(), "state.json") + scriptFilePath := filepath.Join("testdata", "state_persistence.json") + + // Step 1: Start server with state persistence enabled and send first message + script, apiClient, cleanup := setup(ctx, t, ¶ms{ + stateFile: stateFile, + scriptFilePath: scriptFilePath, + }, true) + + // Send first message + messageReq := agentapisdk.PostMessageParams{ + Content: "First message before state save.", + Type: agentapisdk.MessageTypeUser, + } + _, err := apiClient.PostMessage(ctx, messageReq) + require.NoError(t, err, "Failed to send first message") + require.NoError(t, waitAgentAPIStable(ctx, t, apiClient, operationTimeout, "first message")) + + // Verify messages before shutdown + msgResp, err := apiClient.GetMessages(ctx) + require.NoError(t, err, "Failed to get messages before shutdown") + require.Len(t, msgResp.Messages, 3, "Expected 3 messages before shutdown") + require.Equal(t, script[0].ResponseMessage, strings.TrimSpace(msgResp.Messages[0].Content)) + require.Equal(t, script[1].ExpectMessage, strings.TrimSpace(msgResp.Messages[1].Content)) + require.Equal(t, script[1].ResponseMessage, strings.TrimSpace(msgResp.Messages[2].Content)) + + // Step 2: Stop server (triggers state save) + cleanup() + + // Give filesystem a moment to sync + time.Sleep(100 * time.Millisecond) + + // Verify state file was created + require.FileExists(t, stateFile, "State file should exist after shutdown") + + // Step 3: Start new server instance and load state + // Note: We don't wait for stable here because the echo agent will try to replay + // from the beginning, which conflicts with restored state. We just verify the + // state was loaded and messages are present. + _, apiClient2, cleanup2 := setup(ctx, t, ¶ms{ + stateFile: stateFile, + scriptFilePath: scriptFilePath, + }, false) + defer cleanup2() + + // Give the server a moment to load state + time.Sleep(500 * time.Millisecond) + + // Step 4: Verify state was restored by checking messages via API + msgResp2, err := apiClient2.GetMessages(ctx) + require.NoError(t, err, "Failed to get messages after state restore") + require.Len(t, msgResp2.Messages, 3, "Expected 3 messages after state restore") + + // Verify all messages match the state before shutdown + require.Equal(t, script[0].ResponseMessage, strings.TrimSpace(msgResp2.Messages[0].Content)) + require.Equal(t, script[1].ExpectMessage, strings.TrimSpace(msgResp2.Messages[1].Content)) + require.Equal(t, script[1].ResponseMessage, strings.TrimSpace(msgResp2.Messages[2].Content)) + }) + + t.Run("state_persistence_initial_prompt", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + // Create a temporary state file + stateFile := filepath.Join(t.TempDir(), "state.json") + scriptFilePath := filepath.Join("testdata", "state_persistence_initial_prompt.json") + + // Step 1: Start server with initial prompt + initialPrompt1 := "Test initial prompt" + _, apiClient, cleanup := setup(ctx, t, ¶ms{ + stateFile: stateFile, + scriptFilePath: scriptFilePath, + initialPrompt: initialPrompt1, + }, true) + + // Verify initial prompt was sent (should have 3 messages: agent greeting + initial prompt + response) + msgResp, err := apiClient.GetMessages(ctx) + require.NoError(t, err, "Failed to get messages after initial prompt") + require.Len(t, msgResp.Messages, 3, "Expected 3 messages after initial prompt") + require.Equal(t, "Hello! I'm ready to help you.", strings.TrimSpace(msgResp.Messages[0].Content)) + require.Equal(t, initialPrompt1, strings.TrimSpace(msgResp.Messages[1].Content)) + require.Equal(t, "Echo: Test initial prompt", strings.TrimSpace(msgResp.Messages[2].Content)) + + // Step 2: Close server + cleanup() + time.Sleep(100 * time.Millisecond) + require.FileExists(t, stateFile, "State file should exist after shutdown") + + // Step 3: Restart WITHOUT an initial prompt + _, apiClient2, cleanup2 := setup(ctx, t, ¶ms{ + stateFile: stateFile, + scriptFilePath: scriptFilePath, + }, false) + defer cleanup2() + time.Sleep(500 * time.Millisecond) + + // Step 4: Verify initial prompt was NOT sent again (should still have 3 messages) + msgResp2, err := apiClient2.GetMessages(ctx) + require.NoError(t, err, "Failed to get messages after restart without initial prompt") + require.Len(t, msgResp2.Messages, 3, "Expected 3 messages (initial prompt should not be sent again)") + require.Equal(t, initialPrompt1, strings.TrimSpace(msgResp2.Messages[1].Content)) + + // Step 5: Close server + cleanup2() + time.Sleep(100 * time.Millisecond) + + // Step 6: Restart with same initial prompt + _, apiClient3, cleanup3 := setup(ctx, t, ¶ms{ + stateFile: stateFile, + scriptFilePath: scriptFilePath, + initialPrompt: initialPrompt1, + }, false) + defer cleanup3() + time.Sleep(500 * time.Millisecond) + + // Step 7: Verify same initial prompt was NOT sent again (should still have 3 messages) + msgResp3, err := apiClient3.GetMessages(ctx) + require.NoError(t, err, "Failed to get messages after restart with same initial prompt") + require.Len(t, msgResp3.Messages, 3, "Expected 3 messages (same initial prompt should not be sent again)") + + }) + + t.Run("state_persistence_different_initial_prompt", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + // Create a temporary state file + stateFile := filepath.Join(t.TempDir(), "state.json") + + // Step 1: Start server with initial prompt "Test initial prompt" using phase1 script + initialPrompt1 := "Test initial prompt" + _, apiClient, cleanup := setup(ctx, t, ¶ms{ + stateFile: stateFile, + scriptFilePath: filepath.Join("testdata", "state_persistence_different_initial_prompt_phase1.json"), + initialPrompt: initialPrompt1, + }, true) + + // Verify initial prompt was sent (3 messages: greeting + prompt + response) + msgResp, err := apiClient.GetMessages(ctx) + require.NoError(t, err, "Failed to get messages after initial prompt") + require.Len(t, msgResp.Messages, 3, "Expected 3 messages after initial prompt") + require.Equal(t, "Hello! I'm ready to help you.", strings.TrimSpace(msgResp.Messages[0].Content)) + require.Equal(t, initialPrompt1, strings.TrimSpace(msgResp.Messages[1].Content)) + require.Equal(t, "Echo: Test initial prompt", strings.TrimSpace(msgResp.Messages[2].Content)) + + // Step 2: Close server + cleanup() + time.Sleep(100 * time.Millisecond) + require.FileExists(t, stateFile, "State file should exist after shutdown") + + // Step 3: Restart with DIFFERENT initial prompt using a different script + initialPrompt2 := "Different initial prompt" + _, apiClient2, cleanup2 := setup(ctx, t, ¶ms{ + stateFile: stateFile, + scriptFilePath: filepath.Join("testdata", "state_persistence_different_initial_prompt.json"), + initialPrompt: initialPrompt2, + }, false) + defer cleanup2() + + // Wait for initial prompt to be processed + time.Sleep(1 * time.Second) + require.NoError(t, waitAgentAPIStable(ctx, t, apiClient2, operationTimeout, "after different initial prompt")) + + // Step 4: Verify new initial prompt WAS sent (5 messages: 3 previous + 2 new) + msgResp2, err := apiClient2.GetMessages(ctx) + require.NoError(t, err, "Failed to get messages after different initial prompt") + require.Len(t, msgResp2.Messages, 5, "Expected 5 messages after different initial prompt (3 previous + 2 new)") + // Verify the new initial prompt and response were added + require.Equal(t, initialPrompt2, strings.TrimSpace(msgResp2.Messages[3].Content)) + require.Equal(t, "Echo: Different initial prompt", strings.TrimSpace(msgResp2.Messages[4].Content)) + + }) } type params struct { - cmdFn func(ctx context.Context, t testing.TB, serverPort int, binaryPath, cwd, scriptFilePath string) (string, []string) + cmdFn func(ctx context.Context, t testing.TB, serverPort int, binaryPath, cwd, scriptFilePath string) (string, []string) + stateFile string + scriptFilePath string + initialPrompt string } func defaultCmdFn(ctx context.Context, t testing.TB, serverPort int, binaryPath, cwd, scriptFilePath string) (string, []string) { return binaryPath, []string{"server", fmt.Sprintf("--port=%d", serverPort), "--", "go", "run", filepath.Join(cwd, "echo.go"), scriptFilePath} } -func setup(ctx context.Context, t testing.TB, p *params) ([]ScriptEntry, *agentapisdk.Client) { +func stateCmdFn(stateFile, initialPrompt string) func(ctx context.Context, t testing.TB, serverPort int, binaryPath, cwd, scriptFilePath string) (string, []string) { + return func(ctx context.Context, t testing.TB, serverPort int, binaryPath, cwd, scriptFilePath string) (string, []string) { + args := []string{ + "server", + fmt.Sprintf("--port=%d", serverPort), + fmt.Sprintf("--state-file=%s", stateFile), + } + if initialPrompt != "" { + args = append(args, fmt.Sprintf("--initial-prompt=%s", initialPrompt)) + } + args = append(args, "--", "go", "run", filepath.Join(cwd, "echo.go"), scriptFilePath) + return binaryPath, args + } +} + +func setup(ctx context.Context, t testing.TB, p *params, waitForStable bool) ([]ScriptEntry, *agentapisdk.Client, func()) { t.Helper() if p == nil { p = ¶ms{} } if p.cmdFn == nil { - p.cmdFn = defaultCmdFn + if p.stateFile != "" { + p.cmdFn = stateCmdFn(p.stateFile, p.initialPrompt) + } else { + p.cmdFn = defaultCmdFn + } } - scriptFilePath := filepath.Join("testdata", filepath.Base(t.Name())+".json") + scriptFilePath := p.scriptFilePath + if scriptFilePath == "" { + scriptFilePath = filepath.Join("testdata", filepath.Base(t.Name())+".json") + } data, err := os.ReadFile(scriptFilePath) require.NoError(t, err, "Failed to read test script file: %s", scriptFilePath) @@ -175,22 +381,37 @@ func setup(ctx context.Context, t testing.TB, p *params) ([]ScriptEntry, *agenta logOutput(t, "SERVER-STDERR", stderr) }() - // Clean up process - t.Cleanup(func() { + // Create cleanup function + cleanup := func() { if cmd.Process != nil { - _ = cmd.Process.Kill() - _ = cmd.Wait() + // Send SIGTERM to allow graceful shutdown and state save + _ = cmd.Process.Signal(os.Interrupt) + // Wait for process to exit gracefully (with timeout) + done := make(chan error, 1) + go func() { + done <- cmd.Wait() + }() + select { + case <-done: + // Process exited gracefully + case <-time.After(5 * time.Second): + // Timeout, force kill + _ = cmd.Process.Kill() + <-done + } } wg.Wait() - }) + } serverURL := fmt.Sprintf("http://localhost:%d", serverPort) require.NoError(t, waitForServer(ctx, t, serverURL, healthCheckTimeout), "Server not ready") apiClient, err := agentapisdk.NewClient(serverURL) require.NoError(t, err, "Failed to create agentapi SDK client") - require.NoError(t, waitAgentAPIStable(ctx, t, apiClient, operationTimeout, "setup")) - return script, apiClient + if waitForStable { + require.NoError(t, waitAgentAPIStable(ctx, t, apiClient, operationTimeout, "setup")) + } + return script, apiClient, cleanup } // logOutput logs process output with prefix diff --git a/e2e/testdata/state_persistence.json b/e2e/testdata/state_persistence.json new file mode 100644 index 00000000..b7fe071d --- /dev/null +++ b/e2e/testdata/state_persistence.json @@ -0,0 +1,18 @@ +[ + { + "expectMessage": "", + "responseMessage": "Hello! I'm ready to help you." + }, + { + "expectMessage": "First message before state save.", + "responseMessage": "Echo: First message before state save." + }, + { + "expectMessage": "Test initial prompt", + "responseMessage": "Echo: Test initial prompt" + }, + { + "expectMessage": "Different initial prompt", + "responseMessage": "Echo: Different initial prompt" + } +] diff --git a/e2e/testdata/state_persistence_different_initial_prompt.json b/e2e/testdata/state_persistence_different_initial_prompt.json new file mode 100644 index 00000000..60c610fd --- /dev/null +++ b/e2e/testdata/state_persistence_different_initial_prompt.json @@ -0,0 +1,10 @@ +[ + { + "expectMessage": "", + "responseMessage": "Hello! I'm ready to help you." + }, + { + "expectMessage": "Different initial prompt", + "responseMessage": "Echo: Different initial prompt" + } +] diff --git a/e2e/testdata/state_persistence_different_initial_prompt_phase1.json b/e2e/testdata/state_persistence_different_initial_prompt_phase1.json new file mode 100644 index 00000000..685ac187 --- /dev/null +++ b/e2e/testdata/state_persistence_different_initial_prompt_phase1.json @@ -0,0 +1,10 @@ +[ + { + "expectMessage": "", + "responseMessage": "Hello! I'm ready to help you." + }, + { + "expectMessage": "Test initial prompt", + "responseMessage": "Echo: Test initial prompt" + } +] diff --git a/e2e/testdata/state_persistence_initial_prompt.json b/e2e/testdata/state_persistence_initial_prompt.json new file mode 100644 index 00000000..cdd8d767 --- /dev/null +++ b/e2e/testdata/state_persistence_initial_prompt.json @@ -0,0 +1,14 @@ +[ + { + "expectMessage": "", + "responseMessage": "Hello! I'm ready to help you." + }, + { + "expectMessage": "Test initial prompt", + "responseMessage": "Echo: Test initial prompt" + }, + { + "expectMessage": "Different initial prompt", + "responseMessage": "Echo: Different initial prompt" + } +] From eef927dd4dd2933974822b618ebb6e7c8294818c Mon Sep 17 00:00:00 2001 From: 35C4n0r Date: Fri, 20 Feb 2026 21:46:25 +0530 Subject: [PATCH 27/27] feat: address maf's review --- cmd/server/server.go | 22 ++++---------- e2e/echo_test.go | 10 ++----- lib/httpapi/server.go | 5 +++- lib/screentracker/pty_conversation.go | 41 ++++++++++++++------------- lib/termexec/termexec.go | 2 +- 5 files changed, 34 insertions(+), 46 deletions(-) diff --git a/cmd/server/server.go b/cmd/server/server.go index cfb9640d..4ddc8f98 100644 --- a/cmd/server/server.go +++ b/cmd/server/server.go @@ -140,6 +140,7 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er if err := writePIDFile(pidFile, logger); err != nil { return xerrors.Errorf("failed to write PID file: %w", err) } + defer cleanupPIDFile(pidFile, logger) } printOpenAPI := viper.GetBool(FlagPrintOpenAPI) @@ -189,17 +190,13 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er // Setup signal handlers (they will call gracefulCancel) handleSignals(gracefulCtx, gracefulCancel, logger, srv) - // Setup PID file cleanup - if pidFile != "" { - defer cleanupPIDFile(pidFile, logger) - } - logger.Info("Starting server on port", "port", port) // Monitor process exit processExitCh := make(chan error, 1) go func() { defer close(processExitCh) + defer gracefulCancel() if err := process.Wait(); err != nil { if errors.Is(err, termexec.ErrNonZeroExitCode) { processExitCh <- xerrors.Errorf("========\n%s\n========\n: %w", strings.TrimSpace(process.ReadScreen()), err) @@ -207,12 +204,6 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er processExitCh <- xerrors.Errorf("failed to wait for process: %w", err) } } - - select { - case <-gracefulCtx.Done(): - default: - gracefulCancel() - } }() // Start the server @@ -243,17 +234,16 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er logger.Error("Failed to stop HTTP server", "error", err) } - // Close the process - if err := process.Close(logger, 5*time.Second); err != nil { - logger.Error("Failed to close process cleanly", "error", err) - } - select { case err := <-processExitCh: if err != nil { return xerrors.Errorf("agent exited with error: %w", err) } default: + // Close the process + if err := process.Close(logger, 5*time.Second); err != nil { + logger.Error("Failed to close process cleanly", "error", err) + } } return nil } diff --git a/e2e/echo_test.go b/e2e/echo_test.go index ac7d58e3..1568d71f 100644 --- a/e2e/echo_test.go +++ b/e2e/echo_test.go @@ -138,9 +138,6 @@ func TestE2E(t *testing.T) { // Step 2: Stop server (triggers state save) cleanup() - // Give filesystem a moment to sync - time.Sleep(100 * time.Millisecond) - // Verify state file was created require.FileExists(t, stateFile, "State file should exist after shutdown") @@ -194,7 +191,6 @@ func TestE2E(t *testing.T) { // Step 2: Close server cleanup() - time.Sleep(100 * time.Millisecond) require.FileExists(t, stateFile, "State file should exist after shutdown") // Step 3: Restart WITHOUT an initial prompt @@ -213,7 +209,6 @@ func TestE2E(t *testing.T) { // Step 5: Close server cleanup2() - time.Sleep(100 * time.Millisecond) // Step 6: Restart with same initial prompt _, apiClient3, cleanup3 := setup(ctx, t, ¶ms{ @@ -256,7 +251,6 @@ func TestE2E(t *testing.T) { // Step 2: Close server cleanup() - time.Sleep(100 * time.Millisecond) require.FileExists(t, stateFile, "State file should exist after shutdown") // Step 3: Restart with DIFFERENT initial prompt using a different script @@ -384,7 +378,7 @@ func setup(ctx context.Context, t testing.TB, p *params, waitForStable bool) ([] // Create cleanup function cleanup := func() { if cmd.Process != nil { - // Send SIGTERM to allow graceful shutdown and state save + // Send SIGINT to allow graceful shutdown and state save _ = cmd.Process.Signal(os.Interrupt) // Wait for process to exit gracefully (with timeout) done := make(chan error, 1) @@ -394,7 +388,7 @@ func setup(ctx context.Context, t testing.TB, p *params, waitForStable bool) ([] select { case <-done: // Process exited gracefully - case <-time.After(5 * time.Second): + case <-time.After(10 * time.Second): // Timeout, force kill _ = cmd.Process.Kill() <-done diff --git a/lib/httpapi/server.go b/lib/httpapi/server.go index f18ce679..00685f04 100644 --- a/lib/httpapi/server.go +++ b/lib/httpapi/server.go @@ -5,6 +5,7 @@ import ( "crypto/sha256" "encoding/hex" "encoding/json" + "errors" "fmt" "io" "log/slog" @@ -615,7 +616,9 @@ func (s *Server) Stop(ctx context.Context) error { s.cleanupTempDir() if s.srv != nil { - err = s.srv.Shutdown(ctx) + if err = s.srv.Shutdown(ctx); errors.Is(err, http.ErrServerClosed) { + err = nil + } } }) return err diff --git a/lib/screentracker/pty_conversation.go b/lib/screentracker/pty_conversation.go index 407bb10b..38cb689f 100644 --- a/lib/screentracker/pty_conversation.go +++ b/lib/screentracker/pty_conversation.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "fmt" - "io" "log/slog" "os" "path/filepath" @@ -215,7 +214,6 @@ func (c *PTYConversation) Start(ctx context.Context) { } } - // Enqueue initial prompt once after agent is ready (and after state is potentially loaded) if c.initialPromptReady && len(c.cfg.InitialPrompt) > 0 && !c.initialPromptSent { c.outboundQueue <- outboundMessage{parts: c.cfg.InitialPrompt, errCh: nil} c.initialPromptSent = true @@ -576,27 +574,34 @@ func (c *PTYConversation) SaveState() error { initialPromptStr = buildStringFromMessageParts(c.cfg.InitialPrompt) } - // Use atomic write: write to temp file, then rename to target path - data, err := json.MarshalIndent(AgentState{ - Version: 1, - Messages: conversation, - InitialPrompt: initialPromptStr, - InitialPromptSent: c.initialPromptSent, - }, "", " ") - if err != nil { - return xerrors.Errorf("failed to marshal state: %w", err) - } - // Create directory if it doesn't exist dir := filepath.Dir(stateFile) if err := os.MkdirAll(dir, 0o700); err != nil { return xerrors.Errorf("failed to create state directory: %w", err) } - // Write to temp file + // Use atomic write: write to temp file, then rename to target path tempFile := stateFile + ".tmp" - if err := os.WriteFile(tempFile, data, 0o600); err != nil { - return xerrors.Errorf("failed to write temp state file: %w", err) + f, err := os.OpenFile(tempFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600) + if err != nil { + return xerrors.Errorf("failed to create temp state file: %w", err) + } + + // Encode directly to file to avoid loading entire JSON into memory + encoder := json.NewEncoder(f) + encoder.SetIndent("", " ") + if err := encoder.Encode(AgentState{ + Version: 1, + Messages: conversation, + InitialPrompt: initialPromptStr, + InitialPromptSent: c.initialPromptSent, + }); err != nil { + return xerrors.Errorf("failed to encode state: %w", err) + } + + // Close file before rename + if err := f.Close(); err != nil { + return xerrors.Errorf("failed to close temp state file: %w", err) } // Atomic rename @@ -641,10 +646,6 @@ func (c *PTYConversation) loadStateLocked() error { var agentState AgentState decoder := json.NewDecoder(f) if err := decoder.Decode(&agentState); err != nil { - if err == io.EOF { - c.cfg.Logger.Info("No previous state to load (file is empty)", "path", stateFile) - return nil - } return xerrors.Errorf("failed to unmarshal state (corrupted or invalid JSON): %w", err) } diff --git a/lib/termexec/termexec.go b/lib/termexec/termexec.go index edad9b13..05403690 100644 --- a/lib/termexec/termexec.go +++ b/lib/termexec/termexec.go @@ -163,7 +163,7 @@ func (p *Process) Close(logger *slog.Logger, timeout time.Duration) error { case err := <-exited: var pathErr *os.SyscallError // ECHILD is expected if the process has already exited - if err != nil && !(errors.As(err, &pathErr) && pathErr.Err == syscall.ECHILD) { + if err != nil && !(errors.As(err, &pathErr) && errors.Is(pathErr.Err, syscall.ECHILD)) { exitErr = xerrors.Errorf("process exited with error: %w", err) } }