diff --git a/chat/src/components/chat-provider.tsx b/chat/src/components/chat-provider.tsx index 21a2ee3f..55db1e2b 100644 --- a/chat/src/components/chat-provider.tsx +++ b/chat/src/components/chat-provider.tsx @@ -303,10 +303,9 @@ export function ChatProvider({ children }: PropsWithChildren) { description: message, }); } finally { + // Remove optimistic draft message if still present (may have been replaced by server response via SSE). + setMessages((prev) => prev.filter((m) => !isDraftMessage(m))); if (type === "user") { - setMessages((prevMessages) => - prevMessages.filter((m) => !isDraftMessage(m)) - ); setLoading(false); } } diff --git a/cmd/attach/attach.go b/cmd/attach/attach.go index 17516398..747b0a5c 100644 --- a/cmd/attach/attach.go +++ b/cmd/attach/attach.go @@ -129,7 +129,42 @@ func WriteRawInputOverHTTP(ctx context.Context, url string, msg string) error { return nil } +// statusResponse is used to parse the /status endpoint response. +type statusResponse struct { + Status string `json:"status"` + AgentType string `json:"agent_type"` + Backend string `json:"backend"` +} + +func checkACPMode(remoteUrl string) error { + resp, err := http.Get(remoteUrl + "/status") + if err != nil { + return xerrors.Errorf("failed to check server status: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return xerrors.Errorf("unexpected %d response from server: %s", resp.StatusCode, resp.Status) + } + + var status statusResponse + if err := json.NewDecoder(resp.Body).Decode(&status); err != nil { + return xerrors.Errorf("failed to decode server status: %w", err) + } + + if status.Backend == "acp" { + return xerrors.New("attach is not supported in ACP mode. The server is running with --experimental-acp which uses JSON-RPC instead of terminal emulation.") + } + + return nil +} + func runAttach(remoteUrl string) error { + // Check if server is running in ACP mode (attach not supported) + if err := checkACPMode(remoteUrl); err != nil { + return err + } + ctx, cancel := context.WithCancel(context.Background()) defer cancel() stdin := int(os.Stdin.Fd()) diff --git a/cmd/server/server.go b/cmd/server/server.go index 6d5cdec3..15c14af7 100644 --- a/cmd/server/server.go +++ b/cmd/server/server.go @@ -19,6 +19,7 @@ import ( "github.com/coder/agentapi/lib/httpapi" "github.com/coder/agentapi/lib/logctx" "github.com/coder/agentapi/lib/msgfmt" + st "github.com/coder/agentapi/lib/screentracker" "github.com/coder/agentapi/lib/termexec" ) @@ -104,11 +105,33 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er } printOpenAPI := viper.GetBool(FlagPrintOpenAPI) + experimentalACP := viper.GetBool(FlagExperimentalACP) + + if printOpenAPI && experimentalACP { + return xerrors.Errorf("flags --%s and --%s are mutually exclusive", FlagPrintOpenAPI, FlagExperimentalACP) + } + + var agentIO st.AgentIO + var transport = "pty" var process *termexec.Process + var acpResult *httpapi.SetupACPResult + if printOpenAPI { - process = nil + agentIO = nil + } else if experimentalACP { + var err error + acpResult, err = httpapi.SetupACP(ctx, httpapi.SetupACPConfig{ + Program: agent, + ProgramArgs: argsToPass[1:], + }) + if err != nil { + return xerrors.Errorf("failed to setup ACP: %w", err) + } + acpIO := acpResult.AgentIO + agentIO = acpIO + transport = "acp" } else { - process, err = httpapi.SetupProcess(ctx, httpapi.SetupProcessConfig{ + proc, err := httpapi.SetupProcess(ctx, httpapi.SetupProcessConfig{ Program: agent, ProgramArgs: argsToPass[1:], TerminalWidth: termWidth, @@ -118,11 +141,14 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er if err != nil { return xerrors.Errorf("failed to setup process: %w", err) } + process = proc + agentIO = proc } port := viper.GetInt(FlagPort) srv, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ AgentType: agentType, - Process: process, + AgentIO: agentIO, + Transport: transport, Port: port, ChatBasePath: viper.GetString(FlagChatBasePath), AllowedHosts: viper.GetStringSlice(FlagAllowedHosts), @@ -138,19 +164,35 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er } logger.Info("Starting server on port", "port", port) processExitCh := make(chan error, 1) - go func() { - defer close(processExitCh) - if err := process.Wait(); err != nil { - if errors.Is(err, termexec.ErrNonZeroExitCode) { - processExitCh <- xerrors.Errorf("========\n%s\n========\n: %w", strings.TrimSpace(process.ReadScreen()), err) - } else { - processExitCh <- xerrors.Errorf("failed to wait for process: %w", err) + // Wait for process exit in PTY mode + if process != nil { + go func() { + defer close(processExitCh) + if err := process.Wait(); err != nil { + if errors.Is(err, termexec.ErrNonZeroExitCode) { + processExitCh <- xerrors.Errorf("========\n%s\n========\n: %w", strings.TrimSpace(process.ReadScreen()), err) + } else { + processExitCh <- xerrors.Errorf("failed to wait for process: %w", err) + } } - } - if err := srv.Stop(ctx); err != nil { - logger.Error("Failed to stop server", "error", err) - } - }() + if err := srv.Stop(ctx); err != nil { + logger.Error("Failed to stop server", "error", err) + } + }() + } + // Wait for process exit in ACP mode + if acpResult != nil { + go func() { + defer close(processExitCh) + defer close(acpResult.Done) // Signal cleanup goroutine to exit + if err := acpResult.Wait(); err != nil { + processExitCh <- xerrors.Errorf("ACP process exited: %w", err) + } + if err := srv.Stop(ctx); err != nil { + logger.Error("Failed to stop server", "error", err) + } + }() + } if err := srv.Start(); err != nil && err != context.Canceled && err != http.ErrServerClosed { return xerrors.Errorf("failed to start server: %w", err) } @@ -180,16 +222,17 @@ type flagSpec struct { } const ( - FlagType = "type" - FlagPort = "port" - FlagPrintOpenAPI = "print-openapi" - FlagChatBasePath = "chat-base-path" - FlagTermWidth = "term-width" - FlagTermHeight = "term-height" - FlagAllowedHosts = "allowed-hosts" - FlagAllowedOrigins = "allowed-origins" - FlagExit = "exit" - FlagInitialPrompt = "initial-prompt" + FlagType = "type" + FlagPort = "port" + FlagPrintOpenAPI = "print-openapi" + FlagChatBasePath = "chat-base-path" + FlagTermWidth = "term-width" + FlagTermHeight = "term-height" + FlagAllowedHosts = "allowed-hosts" + FlagAllowedOrigins = "allowed-origins" + FlagExit = "exit" + FlagInitialPrompt = "initial-prompt" + FlagExperimentalACP = "experimental-acp" ) func CreateServerCmd() *cobra.Command { @@ -228,6 +271,7 @@ func CreateServerCmd() *cobra.Command { // localhost:3284 is the default origin when you open the chat interface in your browser. localhost:3000 and 3001 are used during development. {FlagAllowedOrigins, "o", []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, "HTTP allowed origins. Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_ORIGINS env var", "stringSlice"}, {FlagInitialPrompt, "I", "", "Initial prompt for the agent. Recommended only if the agent doesn't support initial prompt in interaction mode. Will be read from stdin if piped (e.g., echo 'prompt' | agentapi server -- my-agent)", "string"}, + {FlagExperimentalACP, "", false, "Use experimental ACP transport instead of PTY", "bool"}, } for _, spec := range flagSpecs { diff --git a/lib/httpapi/models.go b/lib/httpapi/models.go index 7ed52c43..76ba824d 100644 --- a/lib/httpapi/models.go +++ b/lib/httpapi/models.go @@ -38,6 +38,7 @@ type StatusResponse struct { Body struct { Status AgentStatus `json:"status" doc:"Current agent status. 'running' means that the agent is processing a message, 'stable' means that the agent is idle and waiting for input."` AgentType mf.AgentType `json:"agent_type" doc:"Type of the agent being used by the server."` + Backend string `json:"backend" doc:"Backend transport being used ('acp' or 'pty')."` } } diff --git a/lib/httpapi/server.go b/lib/httpapi/server.go index 956cfb8a..1f255101 100644 --- a/lib/httpapi/server.go +++ b/lib/httpapi/server.go @@ -24,6 +24,7 @@ import ( mf "github.com/coder/agentapi/lib/msgfmt" st "github.com/coder/agentapi/lib/screentracker" "github.com/coder/agentapi/lib/termexec" + "github.com/coder/agentapi/x/acpio" "github.com/coder/quartz" "github.com/danielgtaylor/huma/v2" "github.com/danielgtaylor/huma/v2/adapters/humachi" @@ -42,12 +43,13 @@ type Server struct { mu sync.RWMutex logger *slog.Logger conversation st.Conversation - agentio *termexec.Process + agentio st.AgentIO agentType mf.AgentType emitter *EventEmitter chatBasePath string tempDir string clock quartz.Clock + transport string } func (s *Server) NormalizeSchema(schema any) any { @@ -98,7 +100,8 @@ const snapshotInterval = 25 * time.Millisecond type ServerConfig struct { AgentType mf.AgentType - Process *termexec.Process + AgentIO st.AgentIO + Transport string Port int ChatBasePath string AllowedHosts []string @@ -252,18 +255,34 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { initialPrompt = FormatMessage(config.AgentType, config.InitialPrompt) } - conversation := st.NewPTY(ctx, st.PTYConversationConfig{ - AgentType: config.AgentType, - AgentIO: config.Process, - Clock: config.Clock, - SnapshotInterval: snapshotInterval, - ScreenStabilityLength: 2 * time.Second, - FormatMessage: formatMessage, - ReadyForInitialPrompt: isAgentReadyForInitialPrompt, - FormatToolCall: formatToolCall, - InitialPrompt: initialPrompt, - Logger: logger, - }, emitter) + // Create appropriate conversation based on transport type + var conversation st.Conversation + if config.Transport == "acp" { + // For ACP, cast AgentIO to *acpio.ACPAgentIO + acpIO, ok := config.AgentIO.(*acpio.ACPAgentIO) + if !ok { + return nil, fmt.Errorf("ACP transport requires ACPAgentIO") + } + conversation = acpio.NewACPConversation(ctx, acpIO, logger, initialPrompt, emitter, config.Clock) + } else { + // Default to PTY transport + proc, ok := config.AgentIO.(*termexec.Process) + if !ok && config.AgentIO != nil { + return nil, fmt.Errorf("PTY transport requires termexec.Process") + } + conversation = st.NewPTY(ctx, st.PTYConversationConfig{ + AgentType: config.AgentType, + AgentIO: proc, + Clock: config.Clock, + SnapshotInterval: snapshotInterval, + ScreenStabilityLength: 2 * time.Second, + FormatMessage: formatMessage, + ReadyForInitialPrompt: isAgentReadyForInitialPrompt, + FormatToolCall: formatToolCall, + InitialPrompt: initialPrompt, + Logger: logger, + }, emitter) + } // Create temporary directory for uploads tempDir, err := os.MkdirTemp("", "agentapi-uploads-") @@ -278,24 +297,25 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { port: config.Port, conversation: conversation, logger: logger, - agentio: config.Process, + agentio: config.AgentIO, agentType: config.AgentType, emitter: emitter, chatBasePath: strings.TrimSuffix(config.ChatBasePath, "/"), tempDir: tempDir, clock: config.Clock, + transport: config.Transport, } // Register API routes s.registerRoutes() - // Start the conversation polling loop if we have a process. - // Process is nil only when --print-openapi is used (no agent runs). - // The process is already running at this point - termexec.StartProcess() - // blocks until the PTY is created and the process is active. Agent - // readiness (waiting for the prompt) is handled asynchronously inside - // conversation.Start() via ReadyForInitialPrompt. - if config.Process != nil { + // Start the conversation polling loop if we have an agent IO. + // AgentIO is nil only when --print-openapi is used (no agent runs). + // For PTY transport, the process is already running at this point - + // termexec.StartProcess() blocks until the PTY is created and the process + // is active. Agent readiness (waiting for the prompt) is handled + // asynchronously inside conversation.Start() via ReadyForInitialPrompt. + if config.AgentIO != nil { s.conversation.Start(ctx) } @@ -417,6 +437,7 @@ func (s *Server) getStatus(ctx context.Context, input *struct{}) (*StatusRespons resp := &StatusResponse{} resp.Body.Status = agentStatus resp.Body.AgentType = s.agentType + resp.Body.Backend = s.transport return resp, nil } diff --git a/lib/httpapi/server_test.go b/lib/httpapi/server_test.go index c8e8b23c..712df95e 100644 --- a/lib/httpapi/server_test.go +++ b/lib/httpapi/server_test.go @@ -29,7 +29,7 @@ func TestOpenAPISchema(t *testing.T) { ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) srv, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ AgentType: msgfmt.AgentTypeClaude, - Process: nil, + AgentIO: nil, Port: 0, ChatBasePath: "/chat", AllowedHosts: []string{"*"}, @@ -78,7 +78,7 @@ func TestServer_redirectToChat(t *testing.T) { tCtx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) s, err := httpapi.NewServer(tCtx, httpapi.ServerConfig{ AgentType: msgfmt.AgentTypeClaude, - Process: nil, + AgentIO: nil, Port: 0, ChatBasePath: tc.chatBasePath, AllowedHosts: []string{"*"}, @@ -242,7 +242,7 @@ func TestServer_AllowedHosts(t *testing.T) { ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) s, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ AgentType: msgfmt.AgentTypeClaude, - Process: nil, + AgentIO: nil, Port: 0, ChatBasePath: "/chat", AllowedHosts: tc.allowedHosts, @@ -325,7 +325,7 @@ func TestServer_CORSPreflightWithHosts(t *testing.T) { ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) s, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ AgentType: msgfmt.AgentTypeClaude, - Process: nil, + AgentIO: nil, Port: 0, ChatBasePath: "/chat", AllowedHosts: tc.allowedHosts, @@ -484,7 +484,7 @@ func TestServer_CORSOrigins(t *testing.T) { ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) s, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ AgentType: msgfmt.AgentTypeClaude, - Process: nil, + AgentIO: nil, Port: 0, ChatBasePath: "/chat", AllowedHosts: []string{"*"}, // Set wildcard to isolate CORS testing @@ -564,7 +564,7 @@ func TestServer_CORSPreflightOrigins(t *testing.T) { ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) s, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ AgentType: msgfmt.AgentTypeClaude, - Process: nil, + AgentIO: nil, Port: 0, ChatBasePath: "/chat", AllowedHosts: []string{"*"}, // Set wildcard to isolate CORS testing @@ -615,7 +615,7 @@ func TestServer_SSEMiddleware_Events(t *testing.T) { ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) srv, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ AgentType: msgfmt.AgentTypeClaude, - Process: nil, + AgentIO: nil, Port: 0, ChatBasePath: "/chat", AllowedHosts: []string{"*"}, @@ -662,7 +662,7 @@ func TestServer_UploadFiles(t *testing.T) { ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) srv, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ AgentType: msgfmt.AgentTypeClaude, - Process: nil, + AgentIO: nil, Port: 0, ChatBasePath: "/chat", AllowedHosts: []string{"*"}, @@ -817,7 +817,7 @@ func TestServer_UploadFiles_Errors(t *testing.T) { ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) srv, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ AgentType: msgfmt.AgentTypeClaude, - Process: nil, + AgentIO: nil, Port: 0, ChatBasePath: "/chat", AllowedHosts: []string{"*"}, diff --git a/lib/httpapi/setup.go b/lib/httpapi/setup.go index 16203041..447b10b7 100644 --- a/lib/httpapi/setup.go +++ b/lib/httpapi/setup.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "os" + "os/exec" "os/signal" "strings" "syscall" @@ -12,6 +13,8 @@ import ( "github.com/coder/agentapi/lib/logctx" mf "github.com/coder/agentapi/lib/msgfmt" "github.com/coder/agentapi/lib/termexec" + "github.com/coder/agentapi/x/acpio" + "github.com/coder/quartz" ) type SetupProcessConfig struct { @@ -58,3 +61,76 @@ func SetupProcess(ctx context.Context, config SetupProcessConfig) (*termexec.Pro return process, nil } + +type SetupACPConfig struct { + Program string + ProgramArgs []string + Clock quartz.Clock +} + +// SetupACPResult contains the result of setting up an ACP process. +type SetupACPResult struct { + AgentIO *acpio.ACPAgentIO + Wait func() error // Calls cmd.Wait() and returns exit error + Done chan struct{} // Close this when Wait() returns to clean up goroutine +} + +func SetupACP(ctx context.Context, config SetupACPConfig) (*SetupACPResult, error) { + logger := logctx.From(ctx) + + if config.Clock == nil { + config.Clock = quartz.NewReal() + } + + args := config.ProgramArgs + logger.Info(fmt.Sprintf("Running (ACP): %s %s", config.Program, strings.Join(args, " "))) + + cmd := exec.CommandContext(ctx, config.Program, args...) + stdin, err := cmd.StdinPipe() + if err != nil { + return nil, fmt.Errorf("failed to create stdin pipe: %w", err) + } + stdout, err := cmd.StdoutPipe() + if err != nil { + return nil, fmt.Errorf("failed to create stdout pipe: %w", err) + } + cmd.Stderr = os.Stderr + + if err := cmd.Start(); err != nil { + return nil, fmt.Errorf("failed to start process: %w", err) + } + + agentIO, err := acpio.NewWithPipes(ctx, stdin, stdout, logger, os.Getwd) + if err != nil { + _ = cmd.Process.Kill() + return nil, fmt.Errorf("failed to initialize ACP connection: %w", err) + } + + done := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + logger.Info("Context done, closing ACP agent") + // Try graceful shutdown first + _ = cmd.Process.Signal(syscall.SIGTERM) + // Then close pipes + _ = stdin.Close() + _ = stdout.Close() + // Force kill after timeout + config.Clock.AfterFunc(5*time.Second, func() { + _ = cmd.Process.Kill() + }) + return + case <-done: + // Process exited normally, nothing to clean up + return + } + }() + + return &SetupACPResult{ + AgentIO: agentIO, + Wait: cmd.Wait, + Done: done, + }, nil +} + diff --git a/openapi.json b/openapi.json index dda817cc..12c4590b 100644 --- a/openapi.json +++ b/openapi.json @@ -269,6 +269,10 @@ "description": "Type of the agent being used by the server.", "type": "string" }, + "backend": { + "description": "Backend transport being used ('acp' or 'pty').", + "type": "string" + }, "status": { "$ref": "#/components/schemas/AgentStatus", "description": "Current agent status. 'running' means that the agent is processing a message, 'stable' means that the agent is idle and waiting for input." @@ -276,6 +280,7 @@ }, "required": [ "agent_type", + "backend", "status" ], "type": "object"