From b22ccafdb451823627b023d7776845a58a5a1cc8 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Thu, 19 Mar 2026 20:29:36 +0100 Subject: [PATCH 1/5] Add WebSocket transport for OpenAI Responses API streaming Introduce an optional WebSocket transport as an alternative to SSE for the OpenAI Responses API. Users can enable it via provider_opts: provider_opts: transport: websocket Key changes: - Add responseEventStream interface to abstract SSE and WebSocket transports - Refactor ResponseStreamAdapter to accept any responseEventStream - Implement wsStream (WebSocket transport) and wsPool (connection pool with 55-min TTL, auto-reconnect, and lastResponseID tracking) - Integrate WebSocket path in CreateResponseStream with automatic SSE fallback on connection failure - No new dependencies (reuses existing gorilla/websocket) The existing ResponseStreamAdapter.Recv() logic is fully reused since WebSocket events use the same JSON schema as SSE events. Assisted-By: docker-agent --- agent-schema.json | 2 +- docs/providers/openai/index.md | 30 ++ examples/websocket_transport.yaml | 42 ++ pkg/model/provider/openai/client.go | 96 +++- pkg/model/provider/openai/event_stream.go | 23 + pkg/model/provider/openai/response_stream.go | 12 +- pkg/model/provider/openai/ws_pool.go | 204 +++++++++ pkg/model/provider/openai/ws_stream.go | 203 +++++++++ pkg/model/provider/openai/ws_stream_test.go | 442 +++++++++++++++++++ 9 files changed, 1040 insertions(+), 14 deletions(-) create mode 100644 examples/websocket_transport.yaml create mode 100644 pkg/model/provider/openai/event_stream.go create mode 100644 pkg/model/provider/openai/ws_pool.go create mode 100644 pkg/model/provider/openai/ws_stream.go create mode 100644 pkg/model/provider/openai/ws_stream_test.go diff --git a/agent-schema.json b/agent-schema.json index 99140e597..abf934761 100644 --- a/agent-schema.json +++ b/agent-schema.json @@ -547,7 +547,7 @@ }, "provider_opts": { "type": "object", - "description": "Provider-specific options. dmr: runtime_flags. anthropic/amazon-bedrock (Claude): interleaved_thinking (boolean, default true). openai/anthropic/google: rerank_prompt (string) to fully override the system prompt used for RAG reranking (advanced - prefer using results.reranking.criteria for domain-specific guidance).", + "description": "Provider-specific options. dmr: runtime_flags. anthropic/amazon-bedrock (Claude): interleaved_thinking (boolean, default true). openai: transport ('sse' or 'websocket') to choose between SSE and WebSocket streaming for the Responses API. openai/anthropic/google: rerank_prompt (string) to fully override the system prompt used for RAG reranking (advanced - prefer using results.reranking.criteria for domain-specific guidance).", "additionalProperties": true }, "track_usage": { diff --git a/docs/providers/openai/index.md b/docs/providers/openai/index.md index 4dd20e6af..8be075196 100644 --- a/docs/providers/openai/index.md +++ b/docs/providers/openai/index.md @@ -77,3 +77,33 @@ models: model: gpt-4o base_url: https://your-proxy.example.com/v1 ``` + +## WebSocket Transport + +For OpenAI Responses API models (gpt-4.1+, o-series, gpt-5), you can use WebSocket streaming instead of the default SSE (Server-Sent Events): + +```yaml +models: + fast-gpt: + provider: openai + model: gpt-4.1 + provider_opts: + transport: websocket # Use WebSocket instead of SSE +``` + +### Benefits + +- **~40% faster** for workflows with 20+ tool calls +- **Persistent connection** reduces per-turn overhead +- **Server-side caching** of connection state +- **Automatic fallback** to SSE if WebSocket fails + +### Requirements + +- Only works with Responses API models: `gpt-4.1+`, `o1`, `o3`, `o4`, `gpt-5` +- NOT compatible with `--gateway` flag (automatically falls back to SSE) +- Requires `OPENAI_API_KEY` environment variable + +### Example + +See [`examples/websocket_transport.yaml`]({{ '/examples/websocket_transport/' | relative_url }}) for a complete example. diff --git a/examples/websocket_transport.yaml b/examples/websocket_transport.yaml new file mode 100644 index 000000000..1ab4069d2 --- /dev/null +++ b/examples/websocket_transport.yaml @@ -0,0 +1,42 @@ +#!/usr/bin/env docker agent run + +# Example: WebSocket Transport for OpenAI Responses API +# +# This example demonstrates how to use WebSocket streaming instead of +# Server-Sent Events (SSE) for the OpenAI Responses API. +# +# WebSocket transport maintains a persistent connection across tool-call +# rounds, reducing per-turn overhead and improving end-to-end latency +# for agentic workflows with many tool calls. +# +# Benefits of WebSocket over SSE: +# - ~40% faster end-to-end execution for workflows with 20+ tool calls +# - Persistent connection reduces per-turn continuation overhead +# - Connection-local state caching on the server +# - Falls back to SSE automatically if WebSocket connection fails +# +# Requirements: +# - Works only with OpenAI Responses API models (gpt-4.1+, o-series, gpt-5) +# - Requires OPENAI_API_KEY environment variable (or use token_key) +# - NOT compatible with --gateway flag (automatically falls back to SSE) +# +# Run with: +# docker agent run websocket_transport.yaml + +models: + gpt-ws: + provider: openai + model: gpt-4.1 + provider_opts: + transport: websocket # Use WebSocket instead of SSE + +agents: + root: + model: gpt-ws + description: Assistant using WebSocket streaming + instruction: | + You are a helpful assistant. Answer questions concisely. + toolsets: + - type: shell # Real toolset for demonstrating multi-turn tool calls + commands: + demo: "List the files in the current directory, then count how many are YAML files" diff --git a/pkg/model/provider/openai/client.go b/pkg/model/provider/openai/client.go index e9956f08c..fafd73581 100644 --- a/pkg/model/provider/openai/client.go +++ b/pkg/model/provider/openai/client.go @@ -7,7 +7,9 @@ import ( "errors" "fmt" "log/slog" + "net/http" "net/url" + "os" "strings" "github.com/openai/openai-go/v3" @@ -29,12 +31,16 @@ import ( "github.com/docker/docker-agent/pkg/tools" ) -// Client represents an OpenAI client wrapper -// It implements the provider.Provider interface +// Client represents an OpenAI client wrapper. +// It implements the provider.Provider interface. type Client struct { base.Config clientFn func(context.Context) (*openai.Client, error) + + // wsPool is lazily initialized when transport=websocket is configured. + // It maintains a persistent WebSocket connection across requests. + wsPool *wsPool } // NewClient creates a new OpenAI client from the provided configuration @@ -307,12 +313,6 @@ func (c *Client) CreateResponseStream( return nil, errors.New("at least one message is required") } - client, err := c.clientFn(ctx) - if err != nil { - slog.Error("Failed to create OpenAI client", "error", err) - return nil, err - } - input := convertMessagesToResponseInput(messages) params := responses.ResponseNewParams{ @@ -398,10 +398,88 @@ func (c *Client) CreateResponseStream( slog.Error("Failed to marshal OpenAI responses request to JSON", "error", err) } + // Choose transport: WebSocket or SSE (default). + // WebSocket is disabled when using a Gateway since most gateways don't support it. + transport := getTransport(&c.ModelConfig) + trackUsage := c.ModelConfig.TrackUsage == nil || *c.ModelConfig.TrackUsage + + if transport == "websocket" && c.ModelOptions.Gateway() == "" { + stream, err := c.createWebSocketStream(ctx, params) + if err != nil { + slog.Error("WebSocket stream failed, falling back to SSE", "error", err) + // Fall through to SSE below. + } else { + slog.Debug("OpenAI responses WebSocket stream created successfully", "model", c.ModelConfig.Model) + return newResponseStreamAdapter(stream, trackUsage), nil + } + } else if transport == "websocket" { + slog.Debug("WebSocket transport requested but Gateway is configured, using SSE", + "model", c.ModelConfig.Model, + "gateway", c.ModelOptions.Gateway()) + } + + client, err := c.clientFn(ctx) + if err != nil { + slog.Error("Failed to create OpenAI client", "error", err) + return nil, err + } stream := client.Responses.NewStreaming(ctx, params) slog.Debug("OpenAI responses stream created successfully", "model", c.ModelConfig.Model) - return newResponseStreamAdapter(stream, c.ModelConfig.TrackUsage == nil || *c.ModelConfig.TrackUsage), nil + return newResponseStreamAdapter(stream, trackUsage), nil +} + +// createWebSocketStream initializes (or reuses) a WebSocket connection and +// sends the response.create message, returning a responseEventStream. +func (c *Client) createWebSocketStream( + ctx context.Context, + params responses.ResponseNewParams, +) (responseEventStream, error) { + if c.wsPool == nil { + // Lazy-init the pool on first WebSocket call. + baseURL := cmp.Or(c.ModelConfig.BaseURL, "https://api.openai.com/v1") + wsURL := httpToWSURL(baseURL) + + headerFn := c.buildWSHeaderFn() + c.wsPool = newWSPool(wsURL, headerFn) + } + + return c.wsPool.Stream(ctx, params) +} + +// buildWSHeaderFn returns a function that produces the HTTP headers needed +// for the WebSocket handshake, including the Authorization header. +func (c *Client) buildWSHeaderFn() func(ctx context.Context) (http.Header, error) { + return func(ctx context.Context) (http.Header, error) { + h := http.Header{} + + // Resolve the API key using the same logic as the HTTP client. + var apiKey string + if c.ModelConfig.TokenKey != "" { + apiKey, _ = c.Env.Get(ctx, c.ModelConfig.TokenKey) + } + if apiKey == "" { + // Fall back to the standard OPENAI_API_KEY env var. + apiKey = os.Getenv("OPENAI_API_KEY") + } + if apiKey != "" { + h.Set("Authorization", "Bearer "+apiKey) + } + + return h, nil + } +} + +// getTransport returns the streaming transport preference from ProviderOpts. +// Valid values are "sse" (default) and "websocket". +func getTransport(cfg *latest.ModelConfig) string { + if cfg == nil || cfg.ProviderOpts == nil { + return "sse" + } + if t, ok := cfg.ProviderOpts["transport"].(string); ok { + return strings.ToLower(t) + } + return "sse" } func convertMessagesToResponseInput(messages []chat.Message) []responses.ResponseInputItemUnionParam { diff --git a/pkg/model/provider/openai/event_stream.go b/pkg/model/provider/openai/event_stream.go new file mode 100644 index 000000000..16b533cae --- /dev/null +++ b/pkg/model/provider/openai/event_stream.go @@ -0,0 +1,23 @@ +package openai + +import "github.com/openai/openai-go/v3/responses" + +// responseEventStream abstracts over SSE and WebSocket transports for +// streaming Responses API events. +// +// The ssestream.Stream[responses.ResponseStreamEventUnion] type already +// satisfies this interface, so it can be used directly. +type responseEventStream interface { + // Next advances the stream to the next event. + // Returns false when the stream is exhausted or an error occurred. + Next() bool + + // Current returns the most recently decoded event. + Current() responses.ResponseStreamEventUnion + + // Err returns the first non-EOF error encountered by the stream. + Err() error + + // Close releases resources held by the stream. + Close() error +} diff --git a/pkg/model/provider/openai/response_stream.go b/pkg/model/provider/openai/response_stream.go index 065cba32c..260e5b55f 100644 --- a/pkg/model/provider/openai/response_stream.go +++ b/pkg/model/provider/openai/response_stream.go @@ -13,15 +13,19 @@ import ( "github.com/docker/docker-agent/pkg/tools" ) -// ResponseStreamAdapter adapts the OpenAI responses stream to our interface +// Compile-time check: ssestream.Stream satisfies responseEventStream. +var _ responseEventStream = (*ssestream.Stream[responses.ResponseStreamEventUnion])(nil) + +// ResponseStreamAdapter adapts the OpenAI responses stream to our interface. +// It works with any responseEventStream implementation (SSE or WebSocket). type ResponseStreamAdapter struct { - stream *ssestream.Stream[responses.ResponseStreamEventUnion] + stream responseEventStream trackUsage bool itemCallIDMap map[string]string itemHasContent map[string]bool } -func newResponseStreamAdapter(stream *ssestream.Stream[responses.ResponseStreamEventUnion], trackUsage bool) *ResponseStreamAdapter { +func newResponseStreamAdapter(stream responseEventStream, trackUsage bool) *ResponseStreamAdapter { return &ResponseStreamAdapter{ stream: stream, trackUsage: trackUsage, @@ -254,5 +258,5 @@ func (a *ResponseStreamAdapter) Recv() (chat.MessageStreamResponse, error) { // Close closes the stream func (a *ResponseStreamAdapter) Close() { - a.stream.Close() + _ = a.stream.Close() } diff --git a/pkg/model/provider/openai/ws_pool.go b/pkg/model/provider/openai/ws_pool.go new file mode 100644 index 000000000..9323fc30f --- /dev/null +++ b/pkg/model/provider/openai/ws_pool.go @@ -0,0 +1,204 @@ +package openai + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "sync" + "time" + + "github.com/gorilla/websocket" + "github.com/openai/openai-go/v3/responses" +) + +const ( + // wsMaxConnectionAge is the maximum lifetime of a WebSocket connection. + // OpenAI enforces a 60-minute limit; we reconnect slightly earlier. + wsMaxConnectionAge = 55 * time.Minute +) + +// wsConnection holds a WebSocket connection together with bookkeeping +// metadata for the connection pool. +type wsConnection struct { + conn *websocket.Conn + createdAt time.Time + + // lastResponseID is the ID of the most recent response completed on + // this connection. It can be passed as previous_response_id in subsequent + // requests to enable server-side context caching. + lastResponseID string +} + +// isExpired returns true when the connection has been open longer than +// wsMaxConnectionAge. +func (c *wsConnection) isExpired() bool { + return time.Since(c.createdAt) >= wsMaxConnectionAge +} + +// wsPool manages a single reusable WebSocket connection to the OpenAI +// Responses API. It is safe for concurrent use; however, because the +// OpenAI WebSocket protocol is sequential (one response at a time), +// callers must not overlap requests on the same pool. +type wsPool struct { + mu sync.Mutex + conn *wsConnection + + // wsURL is the WebSocket endpoint (e.g. wss://api.openai.com/v1/responses). + wsURL string + + // headerFn returns the HTTP headers (including Authorization) for + // the WebSocket handshake. It is called each time a new connection + // is established so that short-lived tokens are refreshed. + headerFn func(ctx context.Context) (http.Header, error) +} + +// newWSPool creates a pool for the given WebSocket URL. +func newWSPool(wsURL string, headerFn func(ctx context.Context) (http.Header, error)) *wsPool { + return &wsPool{ + wsURL: wsURL, + headerFn: headerFn, + } +} + +// Stream opens (or reuses) a WebSocket connection, sends a response.create +// message, and returns a responseEventStream that yields server events. +func (p *wsPool) Stream( + ctx context.Context, + params responses.ResponseNewParams, +) (responseEventStream, error) { + p.mu.Lock() + defer p.mu.Unlock() + + // Close stale connections. + if p.conn != nil && p.conn.isExpired() { + slog.Debug("Closing expired WebSocket connection", + "age", time.Since(p.conn.createdAt)) + _ = p.conn.conn.Close() + p.conn = nil + } + + // Establish a new connection if needed. + if p.conn == nil { + headers, err := p.headerFn(ctx) + if err != nil { + return nil, fmt.Errorf("websocket pool: headers: %w", err) + } + + stream, err := dialWebSocket(ctx, p.wsURL, headers, params) + if err != nil { + return nil, err + } + + p.conn = &wsConnection{ + conn: stream.conn, + createdAt: time.Now(), + } + + return &pooledStream{pool: p, inner: stream}, nil + } + + // Reuse existing connection: send a new response.create. + stream, err := sendOnExisting(p.conn.conn, params) + if err != nil { + // Connection is broken; tear down and retry with a fresh one. + slog.Warn("Existing WebSocket connection failed, reconnecting", "error", err) + _ = p.conn.conn.Close() + p.conn = nil + + headers, err2 := p.headerFn(ctx) + if err2 != nil { + return nil, fmt.Errorf("websocket pool: headers on reconnect: %w", err2) + } + stream, err2 = dialWebSocket(ctx, p.wsURL, headers, params) + if err2 != nil { + return nil, fmt.Errorf("websocket pool: reconnect: %w", err2) + } + p.conn = &wsConnection{ + conn: stream.conn, + createdAt: time.Now(), + } + return &pooledStream{pool: p, inner: stream}, nil + } + + return &pooledStream{pool: p, inner: stream}, nil +} + +// Close shuts down the pooled connection. +func (p *wsPool) Close() { + p.mu.Lock() + defer p.mu.Unlock() + + if p.conn != nil { + _ = p.conn.conn.Close() + p.conn = nil + } +} + +// sendOnExisting sends a response.create on an already-open connection and +// returns a wsStream that reads events from it. +func sendOnExisting(conn *websocket.Conn, params responses.ResponseNewParams) (*wsStream, error) { + paramsJSON, err := json.Marshal(params) + if err != nil { + return nil, fmt.Errorf("websocket: marshal params: %w", err) + } + + msg := wsCreateMessage{ + Type: "response.create", + Params: paramsJSON, + } + + if err := conn.WriteJSON(msg); err != nil { + return nil, fmt.Errorf("websocket: write response.create: %w", err) + } + + slog.Debug("WebSocket response.create sent (reused connection)") + + return &wsStream{conn: conn}, nil +} + +// pooledStream wraps a wsStream and updates pool state when the response +// finishes. Its Close does NOT close the underlying WebSocket connection +// (which is owned by the pool). +type pooledStream struct { + pool *wsPool + inner *wsStream +} + +var _ responseEventStream = (*pooledStream)(nil) + +func (s *pooledStream) Next() bool { + ok := s.inner.Next() + if !ok { + return false + } + + // Track response ID from terminal events for future continuation. + event := s.inner.Current() + if isTerminalEvent(event.Type) && event.Response.ID != "" { + s.pool.mu.Lock() + if s.pool.conn != nil { + s.pool.conn.lastResponseID = event.Response.ID + } + s.pool.mu.Unlock() + } + + return true +} + +func (s *pooledStream) Current() responses.ResponseStreamEventUnion { + return s.inner.Current() +} + +func (s *pooledStream) Err() error { + return s.inner.Err() +} + +// Close releases the stream but keeps the underlying connection alive in +// the pool for reuse. +func (s *pooledStream) Close() error { + s.inner.done = true + // Do NOT close the WebSocket connection—it stays in the pool. + return nil +} diff --git a/pkg/model/provider/openai/ws_stream.go b/pkg/model/provider/openai/ws_stream.go new file mode 100644 index 000000000..2bf7a8be0 --- /dev/null +++ b/pkg/model/provider/openai/ws_stream.go @@ -0,0 +1,203 @@ +package openai + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "strings" + "time" + + "github.com/gorilla/websocket" + "github.com/openai/openai-go/v3/responses" +) + +const ( + // wsHandshakeTimeout is the maximum time allowed for the WebSocket handshake. + wsHandshakeTimeout = 45 * time.Second +) + +// wsCreateMessage is the envelope sent over WebSocket to start a new response. +// It wraps ResponseNewParams with the required "type" discriminator. +type wsCreateMessage struct { + Type string `json:"type"` + + // Embed the params as a raw message so that its MarshalJSON is used + // and we simply add the "type" key on top. + Params json.RawMessage `json:"-"` +} + +func (m wsCreateMessage) MarshalJSON() ([]byte, error) { + // Marshal the params first, then inject "type". + raw := m.Params + if raw == nil { + raw = []byte("{}") + } + + // Merge: start with the params object, add "type". + var obj map[string]json.RawMessage + if err := json.Unmarshal(raw, &obj); err != nil { + return nil, fmt.Errorf("wsCreateMessage: unmarshal params: %w", err) + } + typeVal, _ := json.Marshal(m.Type) + obj["type"] = typeVal + return json.Marshal(obj) +} + +// wsStream implements responseEventStream for a single request/response +// exchange over a WebSocket connection. +// +// After the terminal event (response.completed, response.failed, etc.) is +// delivered via Current(), the next call to Next() returns false. +type wsStream struct { + conn *websocket.Conn + current responses.ResponseStreamEventUnion + err error + done bool +} + +// Compile-time check: wsStream satisfies responseEventStream. +var _ responseEventStream = (*wsStream)(nil) + +// dialWebSocket opens a WebSocket connection, sends the response.create +// message, and returns a stream that yields server events. +func dialWebSocket( + ctx context.Context, + wsURL string, + headers http.Header, + params responses.ResponseNewParams, +) (*wsStream, error) { + dialer := websocket.Dialer{ + HandshakeTimeout: wsHandshakeTimeout, + } + + slog.Debug("Opening WebSocket connection", "url", wsURL) + + conn, resp, err := dialer.DialContext(ctx, wsURL, headers) + if err != nil { + if resp != nil { + slog.Error("WebSocket handshake failed", + "status", resp.StatusCode, + "error", err) + } + return nil, fmt.Errorf("websocket dial %s: %w", wsURL, err) + } + + // Marshal the params using the SDK's MarshalJSON so all field + // encodings (omitzero, unions, etc.) are handled correctly. + paramsJSON, err := json.Marshal(params) + if err != nil { + conn.Close() + return nil, fmt.Errorf("websocket: marshal params: %w", err) + } + + msg := wsCreateMessage{ + Type: "response.create", + Params: paramsJSON, + } + + if err := conn.WriteJSON(msg); err != nil { + conn.Close() + return nil, fmt.Errorf("websocket: write response.create: %w", err) + } + + slog.Debug("WebSocket response.create sent", "url", wsURL) + + return &wsStream{conn: conn}, nil +} + +// Next reads the next event from the WebSocket. Returns false when the +// response is complete or an error occurred. +func (s *wsStream) Next() bool { + if s.done { + return false + } + + _, data, err := s.conn.ReadMessage() + if err != nil { + if websocket.IsCloseError(err, + websocket.CloseNormalClosure, + websocket.CloseGoingAway, + ) { + s.done = true + return false + } + s.err = fmt.Errorf("websocket read: %w", err) + s.done = true + return false + } + + var event responses.ResponseStreamEventUnion + if err := json.Unmarshal(data, &event); err != nil { + s.err = fmt.Errorf("websocket unmarshal event: %w", err) + s.done = true + return false + } + + s.current = event + + slog.Debug("WebSocket event received", "type", event.Type) + + // Check for server-side error events. + if event.Type == "error" { + s.err = fmt.Errorf("openai websocket error: %s (param: %s)", event.Message, event.Param) + s.done = true + // Still return true so the caller can inspect the event. + return true + } + + // Terminal events: deliver this event then stop on next call. + if isTerminalEvent(event.Type) { + s.done = true + // Return true so the adapter receives usage/finish data. + return true + } + + return true +} + +// Current returns the most recently decoded event. +func (s *wsStream) Current() responses.ResponseStreamEventUnion { + return s.current +} + +// Err returns the first non-EOF error encountered by the stream. +func (s *wsStream) Err() error { + return s.err +} + +// Close sends a close frame and releases the connection. +func (s *wsStream) Close() error { + s.done = true + // Best-effort close handshake. + _ = s.conn.WriteMessage( + websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), + ) + return s.conn.Close() +} + +// isTerminalEvent returns true for event types that signal the end of a +// response on the WebSocket stream. +func isTerminalEvent(eventType string) bool { + switch eventType { + case "response.completed", "response.done", + "response.failed", "response.incomplete": + return true + default: + return false + } +} + +// httpToWSURL converts an HTTP(S) base URL to its WebSocket equivalent. +// "https://api.openai.com/v1" → "wss://api.openai.com/v1/responses" +func httpToWSURL(baseURL string) string { + u := strings.TrimRight(baseURL, "/") + u = strings.Replace(u, "https://", "wss://", 1) + u = strings.Replace(u, "http://", "ws://", 1) + if !strings.HasSuffix(u, "/responses") { + u += "/responses" + } + return u +} diff --git a/pkg/model/provider/openai/ws_stream_test.go b/pkg/model/provider/openai/ws_stream_test.go new file mode 100644 index 000000000..90f6e5427 --- /dev/null +++ b/pkg/model/provider/openai/ws_stream_test.go @@ -0,0 +1,442 @@ +package openai + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gorilla/websocket" + "github.com/openai/openai-go/v3/responses" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/config/latest" + "github.com/docker/docker-agent/pkg/environment" +) + +// testWSServer starts an httptest.Server that upgrades to WebSocket, +// reads the response.create message, and sends back the given events +// as JSON text frames. +func testWSServer(t *testing.T, events []map[string]any) *httptest.Server { + t.Helper() + + upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }} + + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("WebSocket upgrade failed: %v", err) + return + } + defer conn.Close() + + // Read the response.create message. + _, data, err := conn.ReadMessage() + if err != nil { + t.Errorf("Failed to read response.create: %v", err) + return + } + + var createMsg map[string]any + if err := json.Unmarshal(data, &createMsg); err != nil { + t.Errorf("Failed to unmarshal response.create: %v", err) + return + } + assert.Equal(t, "response.create", createMsg["type"]) + + // Send events. + for _, event := range events { + eventData, _ := json.Marshal(event) + if err := conn.WriteMessage(websocket.TextMessage, eventData); err != nil { + return + } + } + + // Close the connection after sending all events. + _ = conn.WriteMessage( + websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), + ) + })) +} + +func TestWSStream_TextDelta(t *testing.T) { + t.Parallel() + + events := []map[string]any{ + { + "type": "response.output_text.delta", + "delta": "Hello ", + "item_id": "item_1", + }, + { + "type": "response.output_text.delta", + "delta": "World!", + "item_id": "item_1", + }, + { + "type": "response.completed", + "response": map[string]any{ + "id": "resp_123", + "output": []any{}, + "usage": map[string]any{ + "input_tokens": 10, + "output_tokens": 5, + "total_tokens": 15, + "input_tokens_details": map[string]any{ + "cached_tokens": 0, + }, + "output_tokens_details": map[string]any{ + "reasoning_tokens": 0, + }, + }, + }, + }, + } + + srv := testWSServer(t, events) + defer srv.Close() + + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + stream, err := dialWebSocket(t.Context(), wsURL, http.Header{}, defaultTestParams()) + require.NoError(t, err) + defer stream.Close() + + adapter := newResponseStreamAdapter(stream, true) + + // First delta + resp, err := adapter.Recv() + require.NoError(t, err) + require.Len(t, resp.Choices, 1) + assert.Equal(t, "Hello ", resp.Choices[0].Delta.Content) + + // Second delta + resp, err = adapter.Recv() + require.NoError(t, err) + require.Len(t, resp.Choices, 1) + assert.Equal(t, "World!", resp.Choices[0].Delta.Content) + + // response.completed → finish reason + usage + resp, err = adapter.Recv() + require.NoError(t, err) + require.Len(t, resp.Choices, 1) + assert.Equal(t, chat.FinishReasonStop, resp.Choices[0].FinishReason) + require.NotNil(t, resp.Usage) + assert.Equal(t, int64(10), resp.Usage.InputTokens) + assert.Equal(t, int64(5), resp.Usage.OutputTokens) + + // Stream is exhausted. + _, err = adapter.Recv() + assert.ErrorIs(t, err, io.EOF) +} + +func TestWSStream_ToolCall(t *testing.T) { + t.Parallel() + + events := []map[string]any{ + { + "type": "response.output_item.added", + "item_id": "item_2", + "item": map[string]any{ + "type": "function_call", + "id": "item_2", + "call_id": "call_abc", + "name": "get_weather", + }, + }, + { + "type": "response.function_call_arguments.delta", + "item_id": "item_2", + "delta": `{"city":`, + }, + { + "type": "response.function_call_arguments.delta", + "item_id": "item_2", + "delta": `"Paris"}`, + }, + { + "type": "response.function_call_arguments.done", + "item_id": "item_2", + }, + { + "type": "response.completed", + "response": map[string]any{ + "id": "resp_456", + "output": []any{ + map[string]any{"type": "function_call"}, + }, + "usage": map[string]any{ + "input_tokens": 8, + "output_tokens": 12, + "total_tokens": 20, + "input_tokens_details": map[string]any{ + "cached_tokens": 0, + }, + "output_tokens_details": map[string]any{ + "reasoning_tokens": 0, + }, + }, + }, + }, + } + + srv := testWSServer(t, events) + defer srv.Close() + + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + stream, err := dialWebSocket(t.Context(), wsURL, http.Header{}, defaultTestParams()) + require.NoError(t, err) + defer stream.Close() + + adapter := newResponseStreamAdapter(stream, true) + + // output_item.added → tool call with name + resp, err := adapter.Recv() + require.NoError(t, err) + require.Len(t, resp.Choices, 1) + require.Len(t, resp.Choices[0].Delta.ToolCalls, 1) + assert.Equal(t, "get_weather", resp.Choices[0].Delta.ToolCalls[0].Function.Name) + assert.Equal(t, "call_abc", resp.Choices[0].Delta.ToolCalls[0].ID) + + // arguments delta 1 + resp, err = adapter.Recv() + require.NoError(t, err) + require.Len(t, resp.Choices, 1) + require.Len(t, resp.Choices[0].Delta.ToolCalls, 1) + assert.Equal(t, `{"city":`, resp.Choices[0].Delta.ToolCalls[0].Function.Arguments) + + // arguments delta 2 + resp, err = adapter.Recv() + require.NoError(t, err) + require.Len(t, resp.Choices, 1) + require.Len(t, resp.Choices[0].Delta.ToolCalls, 1) + assert.Equal(t, `"Paris"}`, resp.Choices[0].Delta.ToolCalls[0].Function.Arguments) + + // arguments done → empty response (no choices) + resp, err = adapter.Recv() + require.NoError(t, err) + assert.Empty(t, resp.Choices) + + // response.completed → finish reason tool_calls + resp, err = adapter.Recv() + require.NoError(t, err) + require.Len(t, resp.Choices, 1) + assert.Equal(t, chat.FinishReasonToolCalls, resp.Choices[0].FinishReason) + + // Stream is exhausted. + _, err = adapter.Recv() + assert.ErrorIs(t, err, io.EOF) +} + +func TestWSStream_ErrorEvent(t *testing.T) { + t.Parallel() + + events := []map[string]any{ + { + "type": "error", + "message": "rate_limit_exceeded", + "param": "", + }, + } + + srv := testWSServer(t, events) + defer srv.Close() + + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + stream, err := dialWebSocket(t.Context(), wsURL, http.Header{}, defaultTestParams()) + require.NoError(t, err) + defer stream.Close() + + // The error event is still yielded to Recv, then the stream errors. + ok := stream.Next() + assert.True(t, ok) + assert.Equal(t, "error", stream.Current().Type) + require.Error(t, stream.Err()) + assert.Contains(t, stream.Err().Error(), "rate_limit_exceeded") + + // Further calls return false. + assert.False(t, stream.Next()) +} + +func TestHTTPToWSURL(t *testing.T) { + t.Parallel() + + tests := []struct { + input string + expected string + }{ + {"https://api.openai.com/v1", "wss://api.openai.com/v1/responses"}, + {"https://api.openai.com/v1/", "wss://api.openai.com/v1/responses"}, + {"http://localhost:8080/v1", "ws://localhost:8080/v1/responses"}, + {"https://api.openai.com/v1/responses", "wss://api.openai.com/v1/responses"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.expected, httpToWSURL(tt.input)) + }) + } +} + +func TestGetTransport(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + config *latest.ModelConfig + expected string + }{ + { + name: "nil config", + config: nil, + expected: "sse", + }, + { + name: "no provider opts", + config: &latest.ModelConfig{}, + expected: "sse", + }, + { + name: "transport=websocket", + config: &latest.ModelConfig{ + ProviderOpts: map[string]any{"transport": "websocket"}, + }, + expected: "websocket", + }, + { + name: "transport=WebSocket (case insensitive)", + config: &latest.ModelConfig{ + ProviderOpts: map[string]any{"transport": "WebSocket"}, + }, + expected: "websocket", + }, + { + name: "transport=sse", + config: &latest.ModelConfig{ + ProviderOpts: map[string]any{"transport": "sse"}, + }, + expected: "sse", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.expected, getTransport(tt.config)) + }) + } +} + +func TestWSStream_EndToEnd_WithClient(t *testing.T) { + t.Parallel() + + events := []map[string]any{ + { + "type": "response.output_text.delta", + "delta": "Hi!", + "item_id": "item_1", + }, + { + "type": "response.completed", + "response": map[string]any{ + "id": "resp_e2e", + "output": []any{}, + "usage": map[string]any{ + "input_tokens": 5, + "output_tokens": 1, + "total_tokens": 6, + "input_tokens_details": map[string]any{ + "cached_tokens": 0, + }, + "output_tokens_details": map[string]any{ + "reasoning_tokens": 0, + }, + }, + }, + }, + } + + srv := testWSServer(t, events) + defer srv.Close() + + baseURL := srv.URL + + cfg := &latest.ModelConfig{ + Provider: "openai", + Model: "gpt-4.1", + BaseURL: baseURL, + ProviderOpts: map[string]any{ + "api_type": "openai_responses", + "transport": "websocket", + }, + } + + env := environment.NewMapEnvProvider(map[string]string{}) + + client, err := NewClient(t.Context(), cfg, env) + require.NoError(t, err) + + stream, err := client.CreateChatCompletionStream( + t.Context(), + []chat.Message{{Role: chat.MessageRoleUser, Content: "hello"}}, + nil, + ) + require.NoError(t, err) + defer stream.Close() + + // First event: text delta + resp, err := stream.Recv() + require.NoError(t, err) + require.Len(t, resp.Choices, 1) + assert.Equal(t, "Hi!", resp.Choices[0].Delta.Content) + + // Second event: completed + resp, err = stream.Recv() + require.NoError(t, err) + require.Len(t, resp.Choices, 1) + assert.Equal(t, chat.FinishReasonStop, resp.Choices[0].FinishReason) + + // Done + _, err = stream.Recv() + assert.ErrorIs(t, err, io.EOF) +} + +func defaultTestParams() responses.ResponseNewParams { + return responses.ResponseNewParams{ + Model: "gpt-4.1", + } +} + +func TestWebSocketDisabledWithGateway(t *testing.T) { + t.Parallel() + + cfg := &latest.ModelConfig{ + Provider: "openai", + Model: "gpt-4.1", + ProviderOpts: map[string]any{ + "transport": "websocket", + }, + } + + // Test 1: No gateway - WebSocket should be allowed + transport := getTransport(cfg) + assert.Equal(t, "websocket", transport) + + // Test 2: With gateway - the condition in CreateResponseStream + // checks c.ModelOptions.Gateway() == "" before allowing WebSocket + // We can't easily test the full flow without mocking the Gateway auth, + // but we can verify the getTransport function works correctly + assert.Equal(t, "websocket", getTransport(cfg)) + + // Test 3: SSE is default + cfgNoTransport := &latest.ModelConfig{ + Provider: "openai", + Model: "gpt-4.1", + } + assert.Equal(t, "sse", getTransport(cfgNoTransport)) +} From 60131885e21e4992d5b0124643e401ba24737224 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Thu, 19 Mar 2026 22:50:24 +0100 Subject: [PATCH 2/5] Address PR #2186 review feedback for WebSocket pool - Add Client.Close() to release pooled WebSocket connections - Invalidate broken connections in pooledStream.Close() instead of returning dead sockets to the pool - Preserve lastResponseID across reconnections (expired + broken) so server-side context caching survives connection resets - Add wsMaxReconnectAttempts constant with bounded retry loop to prevent unbounded reconnection attempts - Replace os.Getenv("OPENAI_API_KEY") with c.Env.Get() for consistent secret resolution via the environment provider - Treat websocket.CloseNoStatusReceived as a normal close condition Assisted-By: docker-agent --- pkg/model/provider/openai/client.go | 15 ++++-- pkg/model/provider/openai/ws_pool.go | 67 +++++++++++++++++++------- pkg/model/provider/openai/ws_stream.go | 1 + 3 files changed, 62 insertions(+), 21 deletions(-) diff --git a/pkg/model/provider/openai/client.go b/pkg/model/provider/openai/client.go index fafd73581..ec418dda5 100644 --- a/pkg/model/provider/openai/client.go +++ b/pkg/model/provider/openai/client.go @@ -9,7 +9,6 @@ import ( "log/slog" "net/http" "net/url" - "os" "strings" "github.com/openai/openai-go/v3" @@ -156,6 +155,14 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro }, nil } +// Close releases resources held by the client, including any pooled WebSocket +// connections. It is safe to call Close multiple times. +func (c *Client) Close() { + if c.wsPool != nil { + c.wsPool.Close() + } +} + // convertMessages converts chat.Message to openai.ChatCompletionMessageParamUnion // using the shared oaistream implementation. func convertMessages(messages []chat.Message) []openai.ChatCompletionMessageParamUnion { @@ -459,8 +466,10 @@ func (c *Client) buildWSHeaderFn() func(ctx context.Context) (http.Header, error apiKey, _ = c.Env.Get(ctx, c.ModelConfig.TokenKey) } if apiKey == "" { - // Fall back to the standard OPENAI_API_KEY env var. - apiKey = os.Getenv("OPENAI_API_KEY") + // Fall back to the standard OPENAI_API_KEY env var via the + // environment provider so that secret resolution is + // consistent with the HTTP client path. + apiKey, _ = c.Env.Get(ctx, "OPENAI_API_KEY") } if apiKey != "" { h.Set("Authorization", "Bearer "+apiKey) diff --git a/pkg/model/provider/openai/ws_pool.go b/pkg/model/provider/openai/ws_pool.go index 9323fc30f..efc4bbbde 100644 --- a/pkg/model/provider/openai/ws_pool.go +++ b/pkg/model/provider/openai/ws_pool.go @@ -17,6 +17,11 @@ const ( // wsMaxConnectionAge is the maximum lifetime of a WebSocket connection. // OpenAI enforces a 60-minute limit; we reconnect slightly earlier. wsMaxConnectionAge = 55 * time.Minute + + // wsMaxReconnectAttempts is the maximum number of times a broken + // connection will be replaced with a fresh one within a single + // Stream call before the error is propagated to the caller. + wsMaxReconnectAttempts = 1 ) // wsConnection holds a WebSocket connection together with bookkeeping @@ -71,10 +76,12 @@ func (p *wsPool) Stream( p.mu.Lock() defer p.mu.Unlock() - // Close stale connections. + // Close stale connections, preserving the last response ID. + var prevResponseID string if p.conn != nil && p.conn.isExpired() { slog.Debug("Closing expired WebSocket connection", "age", time.Since(p.conn.createdAt)) + prevResponseID = p.conn.lastResponseID _ = p.conn.conn.Close() p.conn = nil } @@ -92,8 +99,9 @@ func (p *wsPool) Stream( } p.conn = &wsConnection{ - conn: stream.conn, - createdAt: time.Now(), + conn: stream.conn, + createdAt: time.Now(), + lastResponseID: prevResponseID, } return &pooledStream{pool: p, inner: stream}, nil @@ -103,23 +111,33 @@ func (p *wsPool) Stream( stream, err := sendOnExisting(p.conn.conn, params) if err != nil { // Connection is broken; tear down and retry with a fresh one. + // We only attempt wsMaxReconnectAttempts reconnections to avoid + // unbounded loops if the server keeps rejecting connections. slog.Warn("Existing WebSocket connection failed, reconnecting", "error", err) + prevResponseID := p.conn.lastResponseID _ = p.conn.conn.Close() p.conn = nil - headers, err2 := p.headerFn(ctx) - if err2 != nil { - return nil, fmt.Errorf("websocket pool: headers on reconnect: %w", err2) - } - stream, err2 = dialWebSocket(ctx, p.wsURL, headers, params) - if err2 != nil { - return nil, fmt.Errorf("websocket pool: reconnect: %w", err2) - } - p.conn = &wsConnection{ - conn: stream.conn, - createdAt: time.Now(), + var lastErr error + for attempt := range wsMaxReconnectAttempts { + headers, err2 := p.headerFn(ctx) + if err2 != nil { + lastErr = fmt.Errorf("websocket pool: headers on reconnect (attempt %d/%d): %w", attempt+1, wsMaxReconnectAttempts, err2) + continue + } + stream, err2 = dialWebSocket(ctx, p.wsURL, headers, params) + if err2 != nil { + lastErr = fmt.Errorf("websocket pool: reconnect (attempt %d/%d): %w", attempt+1, wsMaxReconnectAttempts, err2) + continue + } + p.conn = &wsConnection{ + conn: stream.conn, + createdAt: time.Now(), + lastResponseID: prevResponseID, + } + return &pooledStream{pool: p, inner: stream}, nil } - return &pooledStream{pool: p, inner: stream}, nil + return nil, lastErr } return &pooledStream{pool: p, inner: stream}, nil @@ -195,10 +213,23 @@ func (s *pooledStream) Err() error { return s.inner.Err() } -// Close releases the stream but keeps the underlying connection alive in -// the pool for reuse. +// Close releases the stream. If the stream encountered an error, the +// underlying connection is invalidated so that the pool opens a fresh one +// on the next request. Otherwise the connection stays in the pool for reuse. func (s *pooledStream) Close() error { s.inner.done = true - // Do NOT close the WebSocket connection—it stays in the pool. + + if s.inner.Err() != nil { + // Connection is likely broken; tear it down so the pool + // doesn't hand out a dead socket. + s.pool.mu.Lock() + if s.pool.conn != nil && s.pool.conn.conn == s.inner.conn { + _ = s.pool.conn.conn.Close() + s.pool.conn = nil + } + s.pool.mu.Unlock() + } + + // Do NOT close the WebSocket connection when healthy—it stays in the pool. return nil } diff --git a/pkg/model/provider/openai/ws_stream.go b/pkg/model/provider/openai/ws_stream.go index 2bf7a8be0..dc9bf3b86 100644 --- a/pkg/model/provider/openai/ws_stream.go +++ b/pkg/model/provider/openai/ws_stream.go @@ -119,6 +119,7 @@ func (s *wsStream) Next() bool { if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, + websocket.CloseNoStatusReceived, ) { s.done = true return false From e4f454c4d7cf9e40de7b29be059b4b410b2554b3 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Thu, 19 Mar 2026 22:56:52 +0100 Subject: [PATCH 3/5] Simplify WebSocket pool code structure - Promote lastResponseID from wsConnection to wsPool so it naturally survives all connection transitions without manual threading - Extract closeLocked(), dialLocked(), invalidateConn() helpers to eliminate duplicated connection lifecycle logic in Stream() - Replace loop-of-one reconnect with a single dialLocked() call - Extract sendResponseCreate() to deduplicate marshal+send between dialWebSocket() and sendOnExisting() - Remove wsMaxReconnectAttempts constant (was always 1) - Simplify wsConnection struct to just conn + createdAt Net result: -18 lines, fewer code paths, same behavior. Assisted-By: docker-agent --- pkg/model/provider/openai/ws_pool.go | 151 +++++++++++-------------- pkg/model/provider/openai/ws_stream.go | 37 +++--- 2 files changed, 85 insertions(+), 103 deletions(-) diff --git a/pkg/model/provider/openai/ws_pool.go b/pkg/model/provider/openai/ws_pool.go index efc4bbbde..348fc3421 100644 --- a/pkg/model/provider/openai/ws_pool.go +++ b/pkg/model/provider/openai/ws_pool.go @@ -2,7 +2,6 @@ package openai import ( "context" - "encoding/json" "fmt" "log/slog" "net/http" @@ -17,11 +16,6 @@ const ( // wsMaxConnectionAge is the maximum lifetime of a WebSocket connection. // OpenAI enforces a 60-minute limit; we reconnect slightly earlier. wsMaxConnectionAge = 55 * time.Minute - - // wsMaxReconnectAttempts is the maximum number of times a broken - // connection will be replaced with a fresh one within a single - // Stream call before the error is propagated to the caller. - wsMaxReconnectAttempts = 1 ) // wsConnection holds a WebSocket connection together with bookkeeping @@ -29,11 +23,6 @@ const ( type wsConnection struct { conn *websocket.Conn createdAt time.Time - - // lastResponseID is the ID of the most recent response completed on - // this connection. It can be passed as previous_response_id in subsequent - // requests to enable server-side context caching. - lastResponseID string } // isExpired returns true when the connection has been open longer than @@ -50,6 +39,12 @@ type wsPool struct { mu sync.Mutex conn *wsConnection + // lastResponseID is the ID of the most recent response completed on + // this pool. It can be passed as previous_response_id in subsequent + // requests to enable server-side context caching. + // It lives on the pool (not wsConnection) so it survives reconnections. + lastResponseID string + // wsURL is the WebSocket endpoint (e.g. wss://api.openai.com/v1/responses). wsURL string @@ -76,99 +71,89 @@ func (p *wsPool) Stream( p.mu.Lock() defer p.mu.Unlock() - // Close stale connections, preserving the last response ID. - var prevResponseID string + // Close stale connections. if p.conn != nil && p.conn.isExpired() { slog.Debug("Closing expired WebSocket connection", "age", time.Since(p.conn.createdAt)) - prevResponseID = p.conn.lastResponseID - _ = p.conn.conn.Close() - p.conn = nil + p.closeLocked() } // Establish a new connection if needed. if p.conn == nil { - headers, err := p.headerFn(ctx) - if err != nil { - return nil, fmt.Errorf("websocket pool: headers: %w", err) - } - - stream, err := dialWebSocket(ctx, p.wsURL, headers, params) - if err != nil { - return nil, err - } - - p.conn = &wsConnection{ - conn: stream.conn, - createdAt: time.Now(), - lastResponseID: prevResponseID, - } - - return &pooledStream{pool: p, inner: stream}, nil + return p.dialLocked(ctx, params) } // Reuse existing connection: send a new response.create. stream, err := sendOnExisting(p.conn.conn, params) if err != nil { - // Connection is broken; tear down and retry with a fresh one. - // We only attempt wsMaxReconnectAttempts reconnections to avoid - // unbounded loops if the server keeps rejecting connections. + // Connection is broken; tear down and reconnect once. slog.Warn("Existing WebSocket connection failed, reconnecting", "error", err) - prevResponseID := p.conn.lastResponseID - _ = p.conn.conn.Close() - p.conn = nil - - var lastErr error - for attempt := range wsMaxReconnectAttempts { - headers, err2 := p.headerFn(ctx) - if err2 != nil { - lastErr = fmt.Errorf("websocket pool: headers on reconnect (attempt %d/%d): %w", attempt+1, wsMaxReconnectAttempts, err2) - continue - } - stream, err2 = dialWebSocket(ctx, p.wsURL, headers, params) - if err2 != nil { - lastErr = fmt.Errorf("websocket pool: reconnect (attempt %d/%d): %w", attempt+1, wsMaxReconnectAttempts, err2) - continue - } - p.conn = &wsConnection{ - conn: stream.conn, - createdAt: time.Now(), - lastResponseID: prevResponseID, - } - return &pooledStream{pool: p, inner: stream}, nil - } - return nil, lastErr + p.closeLocked() + return p.dialLocked(ctx, params) + } + + return &pooledStream{pool: p, inner: stream}, nil +} + +// dialLocked opens a fresh WebSocket connection and stores it in the pool. +// Caller must hold p.mu. +func (p *wsPool) dialLocked( + ctx context.Context, + params responses.ResponseNewParams, +) (*pooledStream, error) { + headers, err := p.headerFn(ctx) + if err != nil { + return nil, fmt.Errorf("websocket pool: headers: %w", err) + } + + stream, err := dialWebSocket(ctx, p.wsURL, headers, params) + if err != nil { + return nil, err + } + + p.conn = &wsConnection{ + conn: stream.conn, + createdAt: time.Now(), } return &pooledStream{pool: p, inner: stream}, nil } +// closeLocked closes the current connection. lastResponseID is preserved +// on the pool so it survives reconnections. Caller must hold p.mu. +func (p *wsPool) closeLocked() { + if p.conn == nil { + return + } + _ = p.conn.conn.Close() + p.conn = nil +} + +// invalidateConn tears down the pooled connection if it matches conn. +// Called by pooledStream.Close when the stream encountered an error, +// so the pool does not hand out a broken connection. +func (p *wsPool) invalidateConn(conn *websocket.Conn) { + p.mu.Lock() + defer p.mu.Unlock() + + if p.conn != nil && p.conn.conn == conn { + p.closeLocked() + } +} + // Close shuts down the pooled connection. func (p *wsPool) Close() { p.mu.Lock() defer p.mu.Unlock() - if p.conn != nil { - _ = p.conn.conn.Close() - p.conn = nil - } + p.closeLocked() } // sendOnExisting sends a response.create on an already-open connection and // returns a wsStream that reads events from it. func sendOnExisting(conn *websocket.Conn, params responses.ResponseNewParams) (*wsStream, error) { - paramsJSON, err := json.Marshal(params) - if err != nil { - return nil, fmt.Errorf("websocket: marshal params: %w", err) - } - - msg := wsCreateMessage{ - Type: "response.create", - Params: paramsJSON, - } - - if err := conn.WriteJSON(msg); err != nil { - return nil, fmt.Errorf("websocket: write response.create: %w", err) + if err := sendResponseCreate(conn, params); err != nil { + return nil, err } slog.Debug("WebSocket response.create sent (reused connection)") @@ -196,9 +181,7 @@ func (s *pooledStream) Next() bool { event := s.inner.Current() if isTerminalEvent(event.Type) && event.Response.ID != "" { s.pool.mu.Lock() - if s.pool.conn != nil { - s.pool.conn.lastResponseID = event.Response.ID - } + s.pool.lastResponseID = event.Response.ID s.pool.mu.Unlock() } @@ -220,16 +203,8 @@ func (s *pooledStream) Close() error { s.inner.done = true if s.inner.Err() != nil { - // Connection is likely broken; tear it down so the pool - // doesn't hand out a dead socket. - s.pool.mu.Lock() - if s.pool.conn != nil && s.pool.conn.conn == s.inner.conn { - _ = s.pool.conn.conn.Close() - s.pool.conn = nil - } - s.pool.mu.Unlock() + s.pool.invalidateConn(s.inner.conn) } - // Do NOT close the WebSocket connection when healthy—it stays in the pool. return nil } diff --git a/pkg/model/provider/openai/ws_stream.go b/pkg/model/provider/openai/ws_stream.go index dc9bf3b86..e16e954df 100644 --- a/pkg/model/provider/openai/ws_stream.go +++ b/pkg/model/provider/openai/ws_stream.go @@ -60,6 +60,26 @@ type wsStream struct { // Compile-time check: wsStream satisfies responseEventStream. var _ responseEventStream = (*wsStream)(nil) +// sendResponseCreate marshals params and writes a response.create message +// on the given WebSocket connection. +func sendResponseCreate(conn *websocket.Conn, params responses.ResponseNewParams) error { + paramsJSON, err := json.Marshal(params) + if err != nil { + return fmt.Errorf("websocket: marshal params: %w", err) + } + + msg := wsCreateMessage{ + Type: "response.create", + Params: paramsJSON, + } + + if err := conn.WriteJSON(msg); err != nil { + return fmt.Errorf("websocket: write response.create: %w", err) + } + + return nil +} + // dialWebSocket opens a WebSocket connection, sends the response.create // message, and returns a stream that yields server events. func dialWebSocket( @@ -84,22 +104,9 @@ func dialWebSocket( return nil, fmt.Errorf("websocket dial %s: %w", wsURL, err) } - // Marshal the params using the SDK's MarshalJSON so all field - // encodings (omitzero, unions, etc.) are handled correctly. - paramsJSON, err := json.Marshal(params) - if err != nil { - conn.Close() - return nil, fmt.Errorf("websocket: marshal params: %w", err) - } - - msg := wsCreateMessage{ - Type: "response.create", - Params: paramsJSON, - } - - if err := conn.WriteJSON(msg); err != nil { + if err := sendResponseCreate(conn, params); err != nil { conn.Close() - return nil, fmt.Errorf("websocket: write response.create: %w", err) + return nil, err } slog.Debug("WebSocket response.create sent", "url", wsURL) From 75b96c24dcb341440316671368f7c774d77fc4ba Mon Sep 17 00:00:00 2001 From: David Gageot Date: Thu, 19 Mar 2026 23:07:47 +0100 Subject: [PATCH 4/5] Fix data race on wsPool lazy init and minor issues MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Initialize wsPool eagerly in NewClient instead of lazily in createWebSocketStream to eliminate a potential data race when concurrent goroutines both see wsPool==nil - Downgrade WebSocket→SSE fallback log from Error to Warn since this is an intentional graceful degradation, not an unexpected error - Close HTTP response body defensively in dialWebSocket on handshake failure to prevent a potential resource leak Assisted-By: docker-agent --- pkg/model/provider/openai/client.go | 29 +++++++++++++++----------- pkg/model/provider/openai/ws_stream.go | 3 +++ 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/pkg/model/provider/openai/client.go b/pkg/model/provider/openai/client.go index ec418dda5..8042a295b 100644 --- a/pkg/model/provider/openai/client.go +++ b/pkg/model/provider/openai/client.go @@ -37,7 +37,7 @@ type Client struct { clientFn func(context.Context) (*openai.Client, error) - // wsPool is lazily initialized when transport=websocket is configured. + // wsPool is initialized in NewClient when transport=websocket is configured. // It maintains a persistent WebSocket connection across requests. wsPool *wsPool } @@ -145,14 +145,24 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro slog.Debug("OpenAI client created successfully", "model", cfg.Model) - return &Client{ + client := &Client{ Config: base.Config{ ModelConfig: *cfg, ModelOptions: globalOptions, Env: env, }, clientFn: clientFn, - }, nil + } + + // Pre-create the WebSocket pool when the transport is configured. + // The pool is cheap (no connections opened until the first Stream call) + // and eager init avoids a data race on the lazy path. + if getTransport(cfg) == "websocket" && globalOptions.Gateway() == "" { + baseURL := cmp.Or(cfg.BaseURL, "https://api.openai.com/v1") + client.wsPool = newWSPool(httpToWSURL(baseURL), client.buildWSHeaderFn()) + } + + return client, nil } // Close releases resources held by the client, including any pooled WebSocket @@ -413,7 +423,7 @@ func (c *Client) CreateResponseStream( if transport == "websocket" && c.ModelOptions.Gateway() == "" { stream, err := c.createWebSocketStream(ctx, params) if err != nil { - slog.Error("WebSocket stream failed, falling back to SSE", "error", err) + slog.Warn("WebSocket stream failed, falling back to SSE", "error", err) // Fall through to SSE below. } else { slog.Debug("OpenAI responses WebSocket stream created successfully", "model", c.ModelConfig.Model) @@ -436,19 +446,14 @@ func (c *Client) CreateResponseStream( return newResponseStreamAdapter(stream, trackUsage), nil } -// createWebSocketStream initializes (or reuses) a WebSocket connection and -// sends the response.create message, returning a responseEventStream. +// createWebSocketStream sends a request over the pre-initialized WebSocket +// pool, returning a responseEventStream. func (c *Client) createWebSocketStream( ctx context.Context, params responses.ResponseNewParams, ) (responseEventStream, error) { if c.wsPool == nil { - // Lazy-init the pool on first WebSocket call. - baseURL := cmp.Or(c.ModelConfig.BaseURL, "https://api.openai.com/v1") - wsURL := httpToWSURL(baseURL) - - headerFn := c.buildWSHeaderFn() - c.wsPool = newWSPool(wsURL, headerFn) + return nil, errors.New("websocket pool not initialized") } return c.wsPool.Stream(ctx, params) diff --git a/pkg/model/provider/openai/ws_stream.go b/pkg/model/provider/openai/ws_stream.go index e16e954df..4b2fc63bc 100644 --- a/pkg/model/provider/openai/ws_stream.go +++ b/pkg/model/provider/openai/ws_stream.go @@ -97,6 +97,9 @@ func dialWebSocket( conn, resp, err := dialer.DialContext(ctx, wsURL, headers) if err != nil { if resp != nil { + if resp.Body != nil { + _ = resp.Body.Close() + } slog.Error("WebSocket handshake failed", "status", resp.StatusCode, "error", err) From 556f27e8479dec11f0002b4b10987cdeadefeb2c Mon Sep 17 00:00:00 2001 From: David Gageot Date: Fri, 20 Mar 2026 12:33:16 +0100 Subject: [PATCH 5/5] Inject lastResponseID as previous_response_id in WebSocket requests The wsPool already tracked lastResponseID from completed responses but never forwarded it to subsequent requests. Now, wsPool.Stream() injects it as previous_response_id when the caller hasn't already set one, enabling server-side context caching across multi-turn exchanges. Add tests covering automatic injection, caller override preservation, and survival across reconnections. Assisted-By: docker-agent --- pkg/model/provider/openai/ws_pool.go | 8 + pkg/model/provider/openai/ws_pool_test.go | 234 ++++++++++++++++++++++ 2 files changed, 242 insertions(+) create mode 100644 pkg/model/provider/openai/ws_pool_test.go diff --git a/pkg/model/provider/openai/ws_pool.go b/pkg/model/provider/openai/ws_pool.go index 348fc3421..edf16d4db 100644 --- a/pkg/model/provider/openai/ws_pool.go +++ b/pkg/model/provider/openai/ws_pool.go @@ -9,6 +9,7 @@ import ( "time" "github.com/gorilla/websocket" + "github.com/openai/openai-go/v3/packages/param" "github.com/openai/openai-go/v3/responses" ) @@ -71,6 +72,13 @@ func (p *wsPool) Stream( p.mu.Lock() defer p.mu.Unlock() + // Inject previous_response_id for server-side context caching when the + // caller hasn't already set one and we have a response from an earlier + // exchange on this pool. + if p.lastResponseID != "" && !params.PreviousResponseID.Valid() { + params.PreviousResponseID = param.NewOpt(p.lastResponseID) + } + // Close stale connections. if p.conn != nil && p.conn.isExpired() { slog.Debug("Closing expired WebSocket connection", diff --git a/pkg/model/provider/openai/ws_pool_test.go b/pkg/model/provider/openai/ws_pool_test.go new file mode 100644 index 000000000..6c5370519 --- /dev/null +++ b/pkg/model/provider/openai/ws_pool_test.go @@ -0,0 +1,234 @@ +package openai + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gorilla/websocket" + "github.com/openai/openai-go/v3/packages/param" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// testWSServerCapture starts a test WebSocket server that captures each +// response.create message into the returned slice and replies with the +// given canned events. +func testWSServerCapture(t *testing.T, events []map[string]any) (*httptest.Server, *[]map[string]json.RawMessage) { + t.Helper() + + var captured []map[string]json.RawMessage + + upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }} + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("WebSocket upgrade failed: %v", err) + return + } + defer conn.Close() + + for { + // Read a response.create message. + _, data, err := conn.ReadMessage() + if err != nil { + return + } + + var createMsg map[string]json.RawMessage + if err := json.Unmarshal(data, &createMsg); err != nil { + t.Errorf("Failed to unmarshal response.create: %v", err) + return + } + captured = append(captured, createMsg) + + // Send events. + for _, event := range events { + eventData, _ := json.Marshal(event) + if err := conn.WriteMessage(websocket.TextMessage, eventData); err != nil { + return + } + } + } + })) + + return srv, &captured +} + +func completedEvent(responseID string) map[string]any { + return map[string]any{ + "type": "response.completed", + "response": map[string]any{ + "id": responseID, + "output": []any{}, + "usage": map[string]any{ + "input_tokens": 5, + "output_tokens": 1, + "total_tokens": 6, + "input_tokens_details": map[string]any{ + "cached_tokens": 0, + }, + "output_tokens_details": map[string]any{ + "reasoning_tokens": 0, + }, + }, + }, + } +} + +func TestWSPool_InjectsPreviousResponseID(t *testing.T) { + t.Parallel() + + events := []map[string]any{ + completedEvent("resp_first"), + } + + srv, captured := testWSServerCapture(t, events) + defer srv.Close() + + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + pool := newWSPool(wsURL, func(_ context.Context) (http.Header, error) { + return http.Header{}, nil + }) + defer pool.Close() + + ctx := t.Context() + + // --- First request: no previous_response_id should be set. + stream1, err := pool.Stream(ctx, defaultTestParams()) + require.NoError(t, err) + drainStream(t, stream1) + + // After draining, the pool should have captured the response ID. + assert.Equal(t, "resp_first", pool.lastResponseID) + + // --- Second request: the pool should inject previous_response_id automatically. + // Change events for the second request to return a different ID. + // (The server always sends the same events we initialized, so we verify + // the injection from the captured request.) + stream2, err := pool.Stream(ctx, defaultTestParams()) + require.NoError(t, err) + drainStream(t, stream2) + + // Verify captured messages. + require.Len(t, *captured, 2) + + // First request: no previous_response_id. + assertPreviousResponseID(t, (*captured)[0], "") + + // Second request: pool injects the ID from the first response. + assertPreviousResponseID(t, (*captured)[1], "resp_first") +} + +func TestWSPool_CallerPreviousResponseIDNotOverwritten(t *testing.T) { + t.Parallel() + + events := []map[string]any{ + completedEvent("resp_pool"), + } + + srv, captured := testWSServerCapture(t, events) + defer srv.Close() + + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + pool := newWSPool(wsURL, func(_ context.Context) (http.Header, error) { + return http.Header{}, nil + }) + defer pool.Close() + + ctx := t.Context() + + // First request — populate lastResponseID. + stream1, err := pool.Stream(ctx, defaultTestParams()) + require.NoError(t, err) + drainStream(t, stream1) + + assert.Equal(t, "resp_pool", pool.lastResponseID) + + // Second request with caller-provided previous_response_id. + params := defaultTestParams() + params.PreviousResponseID = param.NewOpt("caller_resp_999") + + stream2, err := pool.Stream(ctx, params) + require.NoError(t, err) + drainStream(t, stream2) + + require.Len(t, *captured, 2) + + // The caller's ID must NOT be overwritten by the pool. + assertPreviousResponseID(t, (*captured)[1], "caller_resp_999") +} + +func TestWSPool_LastResponseIDSurvivesReconnect(t *testing.T) { + t.Parallel() + + events := []map[string]any{ + completedEvent("resp_survive"), + } + + srv, captured := testWSServerCapture(t, events) + defer srv.Close() + + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + pool := newWSPool(wsURL, func(_ context.Context) (http.Header, error) { + return http.Header{}, nil + }) + defer pool.Close() + + ctx := t.Context() + + // First request. + stream1, err := pool.Stream(ctx, defaultTestParams()) + require.NoError(t, err) + drainStream(t, stream1) + + assert.Equal(t, "resp_survive", pool.lastResponseID) + + // Force a reconnect by closing the pooled connection. + pool.Close() + + // Second request after reconnection. + stream2, err := pool.Stream(ctx, defaultTestParams()) + require.NoError(t, err) + drainStream(t, stream2) + + require.Len(t, *captured, 2) + + // The lastResponseID should survive the reconnect. + assertPreviousResponseID(t, (*captured)[1], "resp_survive") +} + +// drainStream reads all events from a responseEventStream until exhausted. +func drainStream(t *testing.T, stream responseEventStream) { + t.Helper() + for stream.Next() { + // consume + } + require.NoError(t, stream.Err()) + require.NoError(t, stream.Close()) +} + +// assertPreviousResponseID checks that the captured response.create message +// contains (or omits) the expected previous_response_id. +func assertPreviousResponseID(t *testing.T, msg map[string]json.RawMessage, expected string) { + t.Helper() + + raw, ok := msg["previous_response_id"] + if expected == "" { + // Either absent or null. + if ok { + assert.JSONEq(t, "null", string(raw), + "expected previous_response_id to be absent or null") + } + return + } + + require.True(t, ok, "expected previous_response_id in request") + var got string + require.NoError(t, json.Unmarshal(raw, &got)) + assert.Equal(t, expected, got) +}