diff --git a/README.md b/README.md index dffe3e98..76afb457 100644 --- a/README.md +++ b/README.md @@ -172,6 +172,8 @@ Key settings: | `ENABLE_PASSTHROUGH_ROUTES` | `true` | Enable provider-native passthrough routes under `/p/{provider}/...` | | `ALLOW_PASSTHROUGH_V1_ALIAS` | `true` | Allow `/p/{provider}/v1/...` aliases while keeping `/p/{provider}/...` canonical | | `ENABLED_PASSTHROUGH_PROVIDERS` | `openai,anthropic` | Comma-separated list of enabled passthrough providers | +| `EXPERIMENTAL_FORWARD_PROXY_ENABLED` | `false` | Enable the experimental HTTP forward proxy wrapper for client traffic inspection | +| `EXPERIMENTAL_FORWARD_PROXY_MITM_HOSTS` | `api.anthropic.com` | Comma-separated HTTPS hosts to inspect; other CONNECT targets are tunneled blindly | | `CACHE_TYPE` | `local` | Cache backend (`local` or `redis`) | | `STORAGE_TYPE` | `sqlite` | Storage backend (`sqlite`, `postgresql`, `mongodb`) | | `METRICS_ENABLED` | `false` | Enable Prometheus metrics | @@ -180,6 +182,8 @@ Key settings: **Quick Start - Authentication:** By default `GOMODEL_MASTER_KEY` is unset. Without this key, API endpoints are unprotected and anyone can call them. This is insecure for production. **Strongly recommend** setting a strong secret before exposing the service. Add `GOMODEL_MASTER_KEY` to your `.env` or environment for production deployments. +**Experimental forward proxy:** When `EXPERIMENTAL_FORWARD_PROXY_ENABLED=true`, GoModel can act as a local HTTP proxy for client traffic. To inspect HTTPS bodies, provide `EXPERIMENTAL_FORWARD_PROXY_CA_CERT_FILE` and `EXPERIMENTAL_FORWARD_PROXY_CA_KEY_FILE`, trust that CA in the client environment, and point the client at GoModel with `HTTP_PROXY` or `HTTPS_PROXY`. This mode is intended for local experiments and is not hardened as a general-purpose enterprise proxy. For Claude Code setup examples, see [`docs/guides/claude-code.mdx`](docs/guides/claude-code.mdx). 
+ --- ## Response Caching diff --git a/config/config.example.yaml b/config/config.example.yaml index c068c498..adc72b61 100644 --- a/config/config.example.yaml +++ b/config/config.example.yaml @@ -11,6 +11,10 @@ server: enable_passthrough_routes: true # expose /p/{provider}/{endpoint} passthrough routes allow_passthrough_v1_alias: true # allow /p/{provider}/v1/... while keeping /p/{provider}/... canonical enabled_passthrough_providers: ["openai", "anthropic"] # providers enabled on /p/{provider}/... + experimental_forward_proxy_enabled: false # HTTP forward proxy for client traffic inspection (experimental) + experimental_forward_proxy_mitm_hosts: ["api.anthropic.com"] # HTTPS hosts to inspect; all others tunnel blindly + experimental_forward_proxy_ca_cert_file: "" # PEM CA cert used for MITM leaf cert generation + experimental_forward_proxy_ca_key_file: "" # PEM CA private key used for MITM leaf cert generation cache: model: diff --git a/config/config.go b/config/config.go index 983822a3..f25e39ec 100644 --- a/config/config.go +++ b/config/config.go @@ -434,6 +434,18 @@ type ServerConfig struct { // EnabledPassthroughProviders lists the provider types enabled on // /p/{provider}/... passthrough routes. Default: ["openai", "anthropic"]. EnabledPassthroughProviders []string `yaml:"enabled_passthrough_providers" env:"ENABLED_PASSTHROUGH_PROVIDERS"` + // ExperimentalForwardProxyEnabled enables an HTTP forward proxy entrypoint that can + // optionally MITM selected HTTPS hosts for traffic inspection. Default: false. + ExperimentalForwardProxyEnabled bool `yaml:"experimental_forward_proxy_enabled" env:"EXPERIMENTAL_FORWARD_PROXY_ENABLED"` + // ExperimentalForwardProxyMITMHosts lists the hosts whose HTTPS CONNECT traffic + // should be terminated and inspected. Other hosts are tunneled blindly. 
+ ExperimentalForwardProxyMITMHosts []string `yaml:"experimental_forward_proxy_mitm_hosts" env:"EXPERIMENTAL_FORWARD_PROXY_MITM_HOSTS"` + // ExperimentalForwardProxyCACertFile points at the PEM-encoded CA certificate used + // to mint leaf certificates for inspected HTTPS hosts. + ExperimentalForwardProxyCACertFile string `yaml:"experimental_forward_proxy_ca_cert_file" env:"EXPERIMENTAL_FORWARD_PROXY_CA_CERT_FILE"` + // ExperimentalForwardProxyCAKeyFile points at the PEM-encoded CA private key used + // to mint leaf certificates for inspected HTTPS hosts. + ExperimentalForwardProxyCAKeyFile string `yaml:"experimental_forward_proxy_ca_key_file" env:"EXPERIMENTAL_FORWARD_PROXY_CA_KEY_FILE"` } // MetricsConfig holds observability configuration for Prometheus metrics @@ -504,6 +516,7 @@ func buildDefaultConfig() *Config { "openai", "anthropic", }, + ExperimentalForwardProxyMITMHosts: []string{"api.anthropic.com"}, }, Cache: CacheConfig{ Model: ModelCacheConfig{ diff --git a/docs/guides/claude-code.mdx b/docs/guides/claude-code.mdx index d2b517bb..f2501074 100644 --- a/docs/guides/claude-code.mdx +++ b/docs/guides/claude-code.mdx @@ -1,45 +1,34 @@ --- title: "Using GoModel with Claude Code" -description: "Step-by-step guide for routing Claude Code through GoModel with Anthropic passthrough." +description: "Step-by-step guide for using Claude Code with GoModel in gateway mode or subscription proxy mode." --- -GoModel can sit in front of Claude Code so every request goes through your own -gateway first. 
+GoModel supports two different Claude Code setups: -Flow: - -`Claude Code -> GoModel -> Anthropic` +| Goal | Mode | Upstream auth | +| --- | --- | --- | +| Use Claude Code through an Anthropic-compatible gateway | Gateway mode | GoModel's `ANTHROPIC_API_KEY` | +| Keep Claude.ai subscription auth and still capture traffic in GoModel | Subscription proxy mode | Claude Code's own `claude.ai` or enterprise session | -## Before you start +If you want company-wide Claude Code observability without moving users onto API +keys, use **subscription proxy mode**. -- Install Claude Code on your machine. -- Choose a GoModel master key, for example `change-me`. -- Make sure GoModel has an Anthropic upstream credential. +## Mode 1: Gateway mode - - Claude Code can be routed through GoModel whether or not you personally use a - Claude Code subscription. For gateway mode, Claude Code talks to GoModel with - `ANTHROPIC_BASE_URL` and `ANTHROPIC_AUTH_TOKEN`. GoModel still needs its own - `ANTHROPIC_API_KEY` to reach Anthropic upstream. - - -## How to get `ANTHROPIC_API_KEY` +Flow: -1. Open the Claude Console and sign in to your API account. -2. Go to account settings in Console, then create an API key. -3. Copy the key once and set it for GoModel as `ANTHROPIC_API_KEY`. +`Claude Code -> GoModel /p/anthropic -> Anthropic` -Anthropic's API docs state that API keys are created in Console account -settings. +This is the standard Anthropic-compatible gateway setup. Claude Code talks to +GoModel with `ANTHROPIC_BASE_URL` and `ANTHROPIC_AUTH_TOKEN`, and GoModel calls +Anthropic with its own `ANTHROPIC_API_KEY`. - Claude paid plans (Pro, Max, Team, Enterprise) and Claude API Console billing - are separate. API key usage is billed as API usage. + Gateway mode does not use a user's Claude subscription allowance. It uses + GoModel's upstream Anthropic API credentials. -## 1. Run GoModel - -Start GoModel with a master key and an Anthropic provider key: +### 1. 
Run GoModel ```bash docker run --rm -p 8080:8080 \ @@ -48,17 +37,13 @@ docker run --rm -p 8080:8080 \ enterpilot/gomodel ``` -## 2. Confirm Anthropic passthrough with curl - -Check that the Anthropic passthrough models endpoint responds: +### 2. Confirm Anthropic passthrough ```bash curl -s http://localhost:8080/p/anthropic/v1/models \ -H "Authorization: Bearer change-me" ``` -Then send one small test message: - ```bash curl -s http://localhost:8080/p/anthropic/v1/messages \ -H "Authorization: Bearer change-me" \ @@ -76,11 +61,7 @@ curl -s http://localhost:8080/p/anthropic/v1/messages \ }' ``` -If the gateway is wired correctly, the response will contain `ok`. - -## 3. Configure Claude Code to use GoModel - -Point Claude Code at GoModel's Anthropic passthrough: +### 3. Point Claude Code at GoModel ```bash export ANTHROPIC_BASE_URL=http://localhost:8080/p/anthropic @@ -88,55 +69,142 @@ export ANTHROPIC_AUTH_TOKEN=change-me export CLAUDE_CODE_DISABLE_EXPERIMENTAL_BETAS=1 ``` -Short Claude Code doc summary: for gateway mode, set `ANTHROPIC_BASE_URL` to -your gateway URL and `ANTHROPIC_AUTH_TOKEN` to your gateway token, then run -Claude Code normally. See the official guide: -[Claude Code LLM gateway docs](https://code.claude.com/docs/en/llm-gateway). +```bash +claude -p --output-format text --model claude-3-haiku-20240307 \ + 'Reply with exactly ok and no punctuation.' +``` + +## Mode 2: Subscription proxy mode + +Flow: -If GoModel is not running on your local machine, replace `localhost:8080` with -the correct host and port. +`Claude Code -> GoModel forward proxy -> Anthropic` -## 4. Run a Claude Code test prompt +This mode keeps Claude Code signed in with `claude.ai` or enterprise auth and +places GoModel in the middle as an HTTP(S) proxy. It is the mode to use when +you want GoModel `audit_logs` and `usage` for real Claude Code subscription +traffic. 
+ + + Validated on March 23, 2026 with Claude Code `2.1.81`, `authMethod: + claude.ai`, and GoModel's experimental forward proxy. + + + + The forward proxy is experimental and intended for local or tightly controlled + internal use. Do not expose it directly to the internet in its current form. + + +### 1. Generate a local MITM CA ```bash -claude -p --output-format text --model claude-3-haiku-20240307 \ - 'Reply with exactly ok and no punctuation.' +mkdir -p proxy-certs + +openssl req -x509 -newkey rsa:2048 -nodes \ + -keyout proxy-certs/ca-key.pem \ + -out proxy-certs/ca-cert.pem \ + -days 30 \ + -subj "/CN=GoModel Claude Proxy CA" +``` + +### 2. Run GoModel with the forward proxy enabled + +```bash +docker run --rm -p 8080:8080 \ + -v "$PWD/proxy-certs:/certs" \ + -e OPENAI_API_KEY="dummy" \ + -e LOGGING_ENABLED=true \ + -e LOGGING_LOG_BODIES=true \ + -e LOGGING_LOG_HEADERS=true \ + -e USAGE_ENABLED=true \ + -e EXPERIMENTAL_FORWARD_PROXY_ENABLED=true \ + -e EXPERIMENTAL_FORWARD_PROXY_MITM_HOSTS="api.anthropic.com" \ + -e EXPERIMENTAL_FORWARD_PROXY_CA_CERT_FILE="/certs/ca-cert.pem" \ + -e EXPERIMENTAL_FORWARD_PROXY_CA_KEY_FILE="/certs/ca-key.pem" \ + enterpilot/gomodel +``` + + + The forward proxy path itself does not use `OPENAI_API_KEY`. The example sets + `OPENAI_API_KEY=dummy` only because GoModel currently expects at least one + configured provider at startup. If you already run GoModel with a real + provider credential, keep using that instead. + + +### 3. Point Claude Code at GoModel as an HTTP proxy + +Keep Claude Code logged in normally, then export: + +```bash +export HTTPS_PROXY=http://localhost:8080 +export HTTP_PROXY=http://localhost:8080 +export NODE_EXTRA_CA_CERTS="$PWD/proxy-certs/ca-cert.pem" +export NO_PROXY=127.0.0.1,localhost ``` -The validated result was: +You can confirm the client is still using subscription auth with: -```text -ok +```bash +claude auth status ``` -## 5. Check the traffic in GoModel +### 4. 
Run a real Claude Code prompt -Open the GoModel dashboard audit logs: +```bash +claude -p --output-format json \ + --setting-sources user \ + --strict-mcp-config --mcp-config '{"mcpServers":{}}' \ + --disable-slash-commands \ + --no-chrome \ + --no-session-persistence \ + 'Respond with exactly the word THROUGHPROXY.' +``` + +A successful run returns JSON with: + +```json +{ + "result": "THROUGHPROXY" +} +``` + +### 5. Verify audit and usage in GoModel + +Open the dashboard audit view: [http://localhost:8080/admin/dashboard/audit](http://localhost:8080/admin/dashboard/audit) -This is the easiest place to confirm that Claude Code is reaching GoModel and -to inspect the full request and response trail. From the same dashboard, you -can keep following your GoModel traffic and usage. +For a successful proxied subscription request, you should see: + +- startup requests such as `/api/oauth/claude_cli/client_data` +- a model request row for `/v1/messages` +- `usage` data for that `/v1/messages` call + +The admin APIs are also available directly: + +- `GET /admin/api/v1/audit/log` +- `GET /admin/api/v1/usage/log` ## References - Anthropic Claude Code gateway docs: [LLM gateway](https://code.claude.com/docs/en/llm-gateway) - Anthropic Claude Code settings: [Settings](https://code.claude.com/docs/en/settings) +- Anthropic Claude Code enterprise network config: [Network configuration](https://code.claude.com/docs/en/network-config) - Claude Help: [Managing API key environment variables in Claude Code](https://support.claude.com/en/articles/12304248-managing-api-key-environment-variables-in-claude-code) - Claude Help: [Paid Claude plans vs API Console billing](https://support.claude.com/en/articles/9876003-i-have-a-paid-claude-subscription-pro-max-team-or-enterprise-plans-why-do-i-have-to-pay-separately-to-use-the-claude-api-and-console) -- Claude API docs: [API overview (keys from Console account settings)](https://docs.anthropic.com/en/api/getting-started) -## Validated on 
March 10, 2026 +## Validation summary -This guide was validated against: +Validated on March 23, 2026 against: -- a local GoModel instance on `http://localhost:8080` -- a GoModel branch exposing `/p/{provider}/v1/...` -- Claude Code `2.1.72` +- GoModel running locally on `http://localhost:8080` +- Claude Code `2.1.81` +- Anthropic gateway mode through `/p/anthropic` +- subscription proxy mode through `HTTP_PROXY` / `HTTPS_PROXY` Local validation confirmed: -- `GET /p/anthropic/v1/models` returned `200 OK` -- `POST /p/anthropic/v1/messages` returned `200 OK` -- `claude` returned `ok` when pointed at GoModel +- gateway mode returned successful Anthropic passthrough responses +- subscription proxy mode returned `THROUGHPROXY` +- GoModel stored `/v1/messages` in `audit_logs` +- GoModel stored token and cost data for `/v1/messages` in `usage` diff --git a/internal/app/app.go b/internal/app/app.go index 775e11a5..77bb33e3 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -219,6 +219,12 @@ func New(ctx context.Context, cfg Config) (*App, error) { EnabledPassthroughProviders: appCfg.Server.EnabledPassthroughProviders, AllowPassthroughV1Alias: &allowPassthroughV1Alias, SwaggerEnabled: appCfg.Server.SwaggerEnabled, + ExperimentalForwardProxy: &server.ForwardProxyConfig{ + Enabled: appCfg.Server.ExperimentalForwardProxyEnabled, + MITMHosts: appCfg.Server.ExperimentalForwardProxyMITMHosts, + CACertFile: appCfg.Server.ExperimentalForwardProxyCACertFile, + CAKeyFile: appCfg.Server.ExperimentalForwardProxyCAKeyFile, + }, } // Initialize admin API and dashboard (behind separate feature flags) @@ -256,6 +262,9 @@ func New(ctx context.Context, cfg Config) (*App, error) { } else { slog.Info("provider passthrough disabled") } + if appCfg.Server.ExperimentalForwardProxyEnabled { + slog.Info("experimental forward proxy enabled", "mitm_hosts", appCfg.Server.ExperimentalForwardProxyMITMHosts) + } rcm, err := responsecache.NewResponseCacheMiddleware(appCfg.Cache.Response, 
cfg.AppConfig.RawProviders) if err != nil { diff --git a/internal/auditlog/auditlog_test.go b/internal/auditlog/auditlog_test.go index b60cc522..9de23c43 100644 --- a/internal/auditlog/auditlog_test.go +++ b/internal/auditlog/auditlog_test.go @@ -850,6 +850,94 @@ func TestCreateStreamEntry(t *testing.T) { } } +func TestStreamLogObserverAnthropicMessages(t *testing.T) { + store := &mockStore{} + cfg := Config{ + Enabled: true, + LogBodies: true, + BufferSize: 10, + FlushInterval: 100 * time.Millisecond, + } + logger := NewLogger(store, cfg) + + entry := &LogEntry{ + ID: "anthropic-stream-entry", + Timestamp: time.Now(), + Model: "claude-sonnet-4-5", + Data: &LogData{}, + } + + streamData := `event: message_start +data: {"type":"message_start","message":{"id":"msg_123","type":"message","role":"assistant","model":"claude-sonnet-4-5","content":[]}} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" world"}} + +event: message_delta +data: {"type":"message_delta","delta":{"stop_reason":"end_turn"}} + +data: [DONE] + +` + stream := streaming.NewObservedSSEStream( + io.NopCloser(strings.NewReader(streamData)), + NewStreamLogObserver(logger, entry, "/v1/messages"), + ) + + if _, err := io.Copy(io.Discard, stream); err != nil { + t.Fatalf("failed to read stream: %v", err) + } + if err := stream.Close(); err != nil { + t.Fatalf("failed to close stream: %v", err) + } + if err := logger.Close(); err != nil { + t.Fatalf("failed to close logger: %v", err) + } + + entries := store.getEntries() + if len(entries) != 1 { + t.Fatalf("expected 1 entry, got %d", len(entries)) + } + logged := entries[0] + if logged.Data == nil { + t.Fatal("Data = nil") + } + body, ok := logged.Data.ResponseBody.(map[string]any) + if !ok { + t.Fatalf("ResponseBody = %#v, want map", logged.Data.ResponseBody) + } + if 
body["id"] != "msg_123" { + t.Fatalf("id = %#v, want msg_123", body["id"]) + } + if body["model"] != "claude-sonnet-4-5" { + t.Fatalf("model = %#v, want claude-sonnet-4-5", body["model"]) + } + if body["role"] != "assistant" { + t.Fatalf("role = %#v, want assistant", body["role"]) + } + if body["stop_reason"] != "end_turn" { + t.Fatalf("stop_reason = %#v, want end_turn", body["stop_reason"]) + } + content, ok := body["content"].([]map[string]any) + if ok { + if len(content) != 1 || content[0]["text"] != "Hello world" { + t.Fatalf("content = %#v, want Hello world", content) + } + return + } + contentAny, ok := body["content"].([]any) + if !ok || len(contentAny) != 1 { + t.Fatalf("content = %#v, want single text block", body["content"]) + } + first, ok := contentAny[0].(map[string]any) + if !ok || first["text"] != "Hello world" { + t.Fatalf("content[0] = %#v, want Hello world", contentAny[0]) + } +} + func TestHashAPIKey(t *testing.T) { tests := []struct { name string diff --git a/internal/auditlog/stream_observer.go b/internal/auditlog/stream_observer.go index 1a93a19e..3be74b7a 100644 --- a/internal/auditlog/stream_observer.go +++ b/internal/auditlog/stream_observer.go @@ -24,8 +24,10 @@ func NewStreamLogObserver(logger LoggerInterface, entry *LogEntry, path string) logBodies := logger.Config().LogBodies var builder *streamResponseBuilder if logBodies { + isResponsesAPI := strings.HasPrefix(path, "/v1/responses") builder = &streamResponseBuilder{ - IsResponsesAPI: strings.HasPrefix(path, "/v1/responses"), + IsResponsesAPI: isResponsesAPI, + IsAnthropicMessages: !isResponsesAPI && strings.HasPrefix(path, "/v1/messages"), } } @@ -46,6 +48,10 @@ func (o *StreamLogObserver) OnJSONEvent(event map[string]any) { o.parseResponsesAPIEvent(event) return } + if o.builder.IsAnthropicMessages { + o.parseAnthropicMessagesEvent(event) + return + } o.parseChatCompletionEvent(event) } @@ -62,6 +68,8 @@ func (o *StreamLogObserver) OnStreamClose() { if o.logBodies && o.builder != nil 
&& o.entry != nil && o.entry.Data != nil { if o.builder.IsResponsesAPI { o.entry.Data.ResponseBody = o.builder.buildResponsesAPIResponse() + } else if o.builder.IsAnthropicMessages { + o.entry.Data.ResponseBody = o.builder.buildAnthropicMessageResponse() } else { o.entry.Data.ResponseBody = o.builder.buildChatCompletionResponse() } @@ -137,6 +145,42 @@ func (o *StreamLogObserver) parseResponsesAPIEvent(event map[string]any) { } } +func (o *StreamLogObserver) parseAnthropicMessagesEvent(event map[string]any) { + if o.builder == nil { + return + } + + eventType, _ := event["type"].(string) + switch eventType { + case "message_start": + if message, ok := event["message"].(map[string]any); ok { + if id, ok := message["id"].(string); ok { + o.builder.ID = id + } + if model, ok := message["model"].(string); ok { + o.builder.Model = model + } + if role, ok := message["role"].(string); ok { + o.builder.Role = role + } + } + case "content_block_delta": + if delta, ok := event["delta"].(map[string]any); ok { + if deltaType, _ := delta["type"].(string); deltaType == "text_delta" { + if text, ok := delta["text"].(string); ok && text != "" { + o.appendContent(text) + } + } + } + case "message_delta": + if delta, ok := event["delta"].(map[string]any); ok { + if stopReason, ok := delta["stop_reason"].(string); ok && stopReason != "" { + o.builder.FinishReason = stopReason + } + } + } +} + func (o *StreamLogObserver) appendContent(content string) { if o.builder == nil || o.builder.truncated || o.builder.contentLen >= MaxContentCapture { return diff --git a/internal/auditlog/stream_wrapper.go b/internal/auditlog/stream_wrapper.go index 9b8b07a0..6606dc0f 100644 --- a/internal/auditlog/stream_wrapper.go +++ b/internal/auditlog/stream_wrapper.go @@ -23,6 +23,9 @@ type streamResponseBuilder struct { CreatedAt int64 Status string + // Anthropic Messages fields + IsAnthropicMessages bool + // Tracking contentLen int // track content length to enforce limit truncated bool @@ -71,6 +74,22 
@@ func (b *streamResponseBuilder) buildResponsesAPIResponse() map[string]any { } } +func (b *streamResponseBuilder) buildAnthropicMessageResponse() map[string]any { + return map[string]any{ + "id": b.ID, + "type": "message", + "role": b.Role, + "model": b.Model, + "content": []map[string]any{ + { + "type": "text", + "text": b.Content.String(), + }, + }, + "stop_reason": b.FinishReason, + } +} + // CreateStreamEntry creates a new log entry for a streaming request. // This should be called before starting the stream. func CreateStreamEntry(baseEntry *LogEntry) *LogEntry { diff --git a/internal/server/forward_proxy.go b/internal/server/forward_proxy.go new file mode 100644 index 00000000..38830456 --- /dev/null +++ b/internal/server/forward_proxy.go @@ -0,0 +1,859 @@ +package server + +import ( + "bufio" + "bytes" + "context" + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/hex" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "io" + "log/slog" + "math/big" + "net" + "net/http" + "net/url" + "os" + "strings" + "sync" + "time" + + "github.com/google/uuid" + + "gomodel/internal/auditlog" + "gomodel/internal/core" + "gomodel/internal/streaming" + "gomodel/internal/usage" +) + +type ForwardProxyConfig struct { + Enabled bool + MITMHosts []string + CACertFile string + CAKeyFile string + AuditLogger auditlog.LoggerInterface + UsageLogger usage.LoggerInterface + PricingResolver usage.PricingResolver + Transport *http.Transport +} + +type forwardProxyHandler struct { + api http.Handler + auditLogger auditlog.LoggerInterface + usageLogger usage.LoggerInterface + pricingResolver usage.PricingResolver + transport *http.Transport + authority *mitmCertificateAuthority + mitmHosts map[string]struct{} +} + +type mitmCertificateAuthority struct { + cert *x509.Certificate + key crypto.Signer + cache sync.Map +} + +var errForwardProxyCloseConnection = errors.New("forward proxy close 
connection") + +func NewForwardProxyHandler(api http.Handler, cfg *ForwardProxyConfig) (http.Handler, error) { + if api == nil { + return nil, fmt.Errorf("forward proxy requires an API handler") + } + if cfg == nil || !cfg.Enabled { + return api, nil + } + + handler := &forwardProxyHandler{ + api: api, + auditLogger: cfg.AuditLogger, + usageLogger: cfg.UsageLogger, + pricingResolver: cfg.PricingResolver, + transport: cloneProxyTransport(cfg.Transport), + mitmHosts: normalizeProxyHosts(cfg.MITMHosts), + } + + if len(handler.mitmHosts) > 0 { + authority, err := loadMITMCertificateAuthority(cfg.CACertFile, cfg.CAKeyFile) + if err != nil { + return nil, err + } + handler.authority = authority + } + + return handler, nil +} + +func cloneProxyTransport(transport *http.Transport) *http.Transport { + if transport == nil { + base := http.DefaultTransport.(*http.Transport).Clone() + base.Proxy = nil + return base + } + cloned := transport.Clone() + cloned.Proxy = nil + return cloned +} + +func normalizeProxyHosts(hosts []string) map[string]struct{} { + if len(hosts) == 0 { + return nil + } + result := make(map[string]struct{}, len(hosts)) + for _, host := range hosts { + canonical := canonicalProxyHost(host) + if canonical == "" { + continue + } + result[canonical] = struct{}{} + } + return result +} + +func (h *forwardProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if !isForwardProxyRequest(r) { + h.api.ServeHTTP(w, r) + return + } + + switch r.Method { + case http.MethodConnect: + h.handleConnect(w, r) + default: + h.handleHTTPProxy(w, r) + } +} + +func isForwardProxyRequest(r *http.Request) bool { + if r == nil { + return false + } + return r.Method == http.MethodConnect || (r.URL != nil && r.URL.IsAbs()) +} + +func (h *forwardProxyHandler) handleConnect(w http.ResponseWriter, r *http.Request) { + targetAddr := targetAddress(r.Host, "443") + targetHost := canonicalProxyHost(targetAddr) + if _, ok := h.mitmHosts[targetHost]; ok && h.authority != nil { + 
slog.Info("forward proxy CONNECT", "target", targetAddr, "mode", "mitm") + h.handleMITMConnect(w, r, targetAddr, targetHost) + return + } + slog.Info("forward proxy CONNECT", "target", targetAddr, "mode", "tunnel") + h.handleTunnelConnect(w, targetAddr) +} + +func (h *forwardProxyHandler) handleTunnelConnect(w http.ResponseWriter, targetAddr string) { + hijacker, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "proxy hijacking is not supported", http.StatusInternalServerError) + return + } + + clientConn, _, err := hijacker.Hijack() + if err != nil { + http.Error(w, "failed to hijack proxy connection", http.StatusInternalServerError) + return + } + + upstreamConn, err := net.DialTimeout("tcp", targetAddr, 30*time.Second) + if err != nil { + _, _ = io.WriteString(clientConn, "HTTP/1.1 502 Bad Gateway\r\n\r\n") + _ = clientConn.Close() + return + } + + _, _ = io.WriteString(clientConn, "HTTP/1.1 200 Connection Established\r\n\r\n") + + go func() { + _, _ = io.Copy(upstreamConn, clientConn) + _ = upstreamConn.Close() + }() + go func() { + _, _ = io.Copy(clientConn, upstreamConn) + _ = clientConn.Close() + }() +} + +func (h *forwardProxyHandler) handleMITMConnect(w http.ResponseWriter, r *http.Request, targetAddr, targetHost string) { + hijacker, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "proxy hijacking is not supported", http.StatusInternalServerError) + return + } + + clientConn, _, err := hijacker.Hijack() + if err != nil { + http.Error(w, "failed to hijack proxy connection", http.StatusInternalServerError) + return + } + defer func() { _ = clientConn.Close() }() + + _, _ = io.WriteString(clientConn, "HTTP/1.1 200 Connection Established\r\n\r\n") + + cert, err := h.authority.certificateForHost(targetHost) + if err != nil { + slog.Error("failed to mint MITM certificate", "host", targetHost, "error", err) + return + } + + tlsConn := tls.Server(clientConn, &tls.Config{ + Certificates: []tls.Certificate{*cert}, + MinVersion: tls.VersionTLS12, + }) + defer 
func() { _ = tlsConn.Close() }() + + if err := tlsConn.Handshake(); err != nil { + slog.Debug("forward proxy TLS handshake failed", "host", targetHost, "error", err) + return + } + slog.Info( + "forward proxy TLS handshake", + "host", targetHost, + "alpn", tlsConn.ConnectionState().NegotiatedProtocol, + "version", tls.VersionName(tlsConn.ConnectionState().Version), + ) + + reader := bufio.NewReader(tlsConn) + requestCount := 0 + for { + req, err := http.ReadRequest(reader) + if err != nil { + if !errors.Is(err, io.EOF) { + buffered := reader.Buffered() + prefix := "" + if buffered > 0 { + peek, peekErr := reader.Peek(min(buffered, 32)) + if peekErr == nil { + prefix = sanitizeProxyPreview(peek) + } + } + slog.Info( + "forward proxy request read failed", + "host", targetHost, + "error", err, + "alpn", tlsConn.ConnectionState().NegotiatedProtocol, + "requests_served", requestCount, + "buffered", buffered, + "prefix", prefix, + ) + } + return + } + requestCount++ + if err := h.serveMITMRequest(tlsConn, req, targetAddr); err != nil { + if errors.Is(err, errForwardProxyCloseConnection) { + return + } + slog.Debug("forward proxy MITM request failed", "host", targetHost, "error", err) + return + } + } +} + +func (h *forwardProxyHandler) serveMITMRequest(clientConn net.Conn, req *http.Request, targetAddr string) error { + defer func() { + if req.Body != nil { + _ = req.Body.Close() + } + }() + + requestID := ensureProxyRequestID(req.Header) + start := time.Now().UTC() + slog.Info( + "forward proxy request", + "target", targetAddr, + "method", req.Method, + "path", req.URL.Path, + "content_length", req.ContentLength, + "expect", strings.TrimSpace(req.Header.Get("Expect")), + "user_agent", req.UserAgent(), + "request_id", requestID, + ) + bodyBytes, err := io.ReadAll(req.Body) + if err != nil { + return err + } + model, streamRequested := extractProxyRequestInfo(bodyBytes) + if req.URL.Path == "/api/event_logging/v2/batch" { + return 
h.serveSyntheticEventLoggingSuccess(clientConn, req, start, requestID, model, bodyBytes) + } + + upstreamReq := req.Clone(context.Background()) + upstreamReq.URL = &url.URL{ + Scheme: "https", + Host: targetAddr, + Path: req.URL.Path, + RawPath: req.URL.RawPath, + RawQuery: req.URL.RawQuery, + } + upstreamReq.RequestURI = "" + upstreamReq.Host = targetAddr + upstreamReq.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + upstreamReq.ContentLength = int64(len(bodyBytes)) + upstreamReq.Header = cloneHeader(req.Header) + stripProxyHeaders(upstreamReq.Header) + upstreamReq.Header.Del("Accept-Encoding") + + resp, err := h.transport.RoundTrip(upstreamReq) + if err != nil { + entry := h.newProxyAuditEntry(start, requestID, req, req.URL.Path, model) + entry.StatusCode = http.StatusBadGateway + entry.ErrorType = "proxy_error" + entry.Data.ErrorMessage = err.Error() + h.writeAuditEntry(entry) + _, _ = io.WriteString(clientConn, "HTTP/1.1 502 Bad Gateway\r\nContent-Length: 0\r\n\r\n") + return err + } + defer func() { _ = resp.Body.Close() }() + slog.Info( + "forward proxy upstream response", + "target", targetAddr, + "method", req.Method, + "path", req.URL.Path, + "status", resp.StatusCode, + "request_id", requestID, + ) + + entry := h.newProxyAuditEntry(start, requestID, req, req.URL.Path, model) + entry.StatusCode = resp.StatusCode + entry.Data.RequestBody, entry.Data.RequestBodyTooBigToHandle = proxyBodyForAudit(bodyBytes) + entry.Data.ResponseHeaders = proxyHeaderMap(resp.Header) + + if isEventStreamHeader(resp.Header) || streamRequested { + return h.serveMITMStreamResponse(clientConn, req, resp, entry) + } + return h.serveMITMBufferedResponse(clientConn, req, resp, entry, bodyBytes) +} + +func (h *forwardProxyHandler) serveSyntheticEventLoggingSuccess(clientConn net.Conn, req *http.Request, start time.Time, requestID, model string, requestBody []byte) error { + entry := h.newProxyAuditEntry(start, requestID, req, req.URL.Path, model) + entry.StatusCode = 
http.StatusNoContent + entry.DurationNs = time.Since(start).Nanoseconds() + entry.Data.RequestBody, entry.Data.RequestBodyTooBigToHandle = proxyBodyForAudit(requestBody) + entry.Data.ResponseHeaders = map[string]string{ + "Connection": "close", + "Content-Length": "0", + } + h.writeAuditEntry(entry) + + if _, err := io.WriteString(clientConn, "HTTP/1.1 204 No Content\r\nConnection: close\r\nContent-Length: 0\r\n\r\n"); err != nil { + return err + } + return errForwardProxyCloseConnection +} + +func (h *forwardProxyHandler) serveMITMBufferedResponse(clientConn net.Conn, req *http.Request, resp *http.Response, entry *auditlog.LogEntry, requestBody []byte) error { + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return err + } + + entry.DurationNs = time.Since(entry.Timestamp).Nanoseconds() + entry.Data.RequestBody, entry.Data.RequestBodyTooBigToHandle = proxyBodyForAudit(requestBody) + entry.Data.ResponseBody, entry.Data.ResponseBodyTooBigToHandle = proxyBodyForAudit(respBody) + + if usageEntry := h.extractAnthropicUsageEntry(respBody, entry.RequestID, entry.Model, req.URL.Path); usageEntry != nil { + h.writeUsageEntry(usageEntry) + if entry.Model == "" && usageEntry.Model != "" { + entry.Model = usageEntry.Model + } + } + h.writeAuditEntry(entry) + + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + resp.Close = true + resp.Header.Set("Connection", "close") + normalizeResponseForProxyWrite(resp) + if err := resp.Write(clientConn); err != nil { + return err + } + return errForwardProxyCloseConnection +} + +func (h *forwardProxyHandler) serveMITMStreamResponse(clientConn net.Conn, req *http.Request, resp *http.Response, entry *auditlog.LogEntry) error { + observers := make([]streaming.Observer, 0, 2) + if h.auditLogger != nil && h.auditLogger.Config().Enabled { + if observer := auditlog.NewStreamLogObserver(h.auditLogger, entry, req.URL.Path); observer != nil { + observers = append(observers, observer) + } + } + if h.usageLogger != nil && 
h.usageLogger.Config().Enabled { + if observer := usage.NewStreamUsageObserver(h.usageLogger, entry.Model, entry.Provider, entry.RequestID, req.URL.Path, h.pricingResolver); observer != nil { + observers = append(observers, observer) + } + } + + wrappedStream := streaming.NewObservedSSEStream(resp.Body, observers...) + defer func() { _ = wrappedStream.Close() }() + + resp.Body = wrappedStream + resp.Close = true + resp.Header.Set("Connection", "close") + normalizeResponseForProxyWrite(resp) + if err := resp.Write(clientConn); err != nil { + return err + } + return errForwardProxyCloseConnection +} + +func (h *forwardProxyHandler) handleHTTPProxy(w http.ResponseWriter, r *http.Request) { + start := time.Now().UTC() + requestID := ensureProxyRequestID(r.Header) + + bodyBytes, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "failed to read proxy request body", http.StatusBadRequest) + return + } + model, _ := extractProxyRequestInfo(bodyBytes) + + upstreamReq := r.Clone(context.Background()) + upstreamReq.RequestURI = "" + upstreamReq.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + upstreamReq.ContentLength = int64(len(bodyBytes)) + upstreamReq.Header = cloneHeader(r.Header) + stripProxyHeaders(upstreamReq.Header) + upstreamReq.Header.Del("Accept-Encoding") + + resp, err := h.transport.RoundTrip(upstreamReq) + if err != nil { + entry := h.newProxyAuditEntry(start, requestID, r, r.URL.Path, model) + entry.StatusCode = http.StatusBadGateway + entry.ErrorType = "proxy_error" + entry.Data.ErrorMessage = err.Error() + h.writeAuditEntry(entry) + http.Error(w, "proxy upstream request failed", http.StatusBadGateway) + return + } + defer func() { _ = resp.Body.Close() }() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + http.Error(w, "failed to read proxy upstream response", http.StatusBadGateway) + return + } + + entry := h.newProxyAuditEntry(start, requestID, r, r.URL.Path, model) + entry.StatusCode = resp.StatusCode + entry.DurationNs = 
time.Since(start).Nanoseconds() + entry.Data.RequestBody, entry.Data.RequestBodyTooBigToHandle = proxyBodyForAudit(bodyBytes) + entry.Data.ResponseBody, entry.Data.ResponseBodyTooBigToHandle = proxyBodyForAudit(respBody) + entry.Data.ResponseHeaders = proxyHeaderMap(resp.Header) + h.writeAuditEntry(entry) + + if usageEntry := h.extractAnthropicUsageEntry(respBody, requestID, model, r.URL.Path); usageEntry != nil { + h.writeUsageEntry(usageEntry) + } + + copyProxyResponseHeaders(w.Header(), resp.Header) + w.WriteHeader(resp.StatusCode) + _, _ = w.Write(respBody) +} + +func (h *forwardProxyHandler) newProxyAuditEntry(start time.Time, requestID string, req *http.Request, path, model string) *auditlog.LogEntry { + entry := &auditlog.LogEntry{ + ID: uuid.NewString(), + Timestamp: start, + RequestID: requestID, + ClientIP: proxyClientIP(req.RemoteAddr), + Method: req.Method, + Path: path, + Model: model, + Provider: "anthropic", + StatusCode: http.StatusOK, + Data: &auditlog.LogData{ + UserAgent: req.UserAgent(), + APIKeyHash: proxyCredentialHash(req.Header), + RequestHeaders: proxyHeaderMap(req.Header), + }, + } + return entry +} + +func (h *forwardProxyHandler) writeAuditEntry(entry *auditlog.LogEntry) { + if h.auditLogger == nil || !h.auditLogger.Config().Enabled || entry == nil { + return + } + h.auditLogger.Write(entry) +} + +func (h *forwardProxyHandler) writeUsageEntry(entry *usage.UsageEntry) { + if h.usageLogger == nil || !h.usageLogger.Config().Enabled || entry == nil { + return + } + h.usageLogger.Write(entry) +} + +func (h *forwardProxyHandler) extractAnthropicUsageEntry(body []byte, requestID, model, endpoint string) *usage.UsageEntry { + var payload map[string]any + if err := json.Unmarshal(body, &payload); err != nil { + return nil + } + + usageRaw, ok := payload["usage"].(map[string]any) + if !ok { + return nil + } + + providerID, _ := payload["id"].(string) + if responseModel, _ := payload["model"].(string); responseModel != "" { + model = responseModel 
// extractProxyRequestInfo pulls the "model" and "stream" fields out of a JSON
// request body.
//
// It returns ("", false) when body is empty or is not a JSON object. A "model"
// value is used only when it is a JSON string (surrounding whitespace is
// trimmed); a "stream" value is used only when it is a JSON boolean.
func extractProxyRequestInfo(body []byte) (model string, stream bool) {
	var fields map[string]any
	if len(body) == 0 || json.Unmarshal(body, &fields) != nil {
		return "", false
	}

	if name, isString := fields["model"].(string); isString {
		model = strings.TrimSpace(name)
	}
	// A failed assertion leaves stream at its zero value, false.
	stream, _ = fields["stream"].(bool)
	return model, stream
}
// proxyClientIP returns the host part of a "host:port" remote address. When
// the address cannot be split, the trimmed input is returned unchanged.
func proxyClientIP(remoteAddr string) string {
	addr := strings.TrimSpace(remoteAddr)
	host, _, err := net.SplitHostPort(addr)
	if err != nil {
		return addr
	}
	return host
}

// targetAddress returns hostPort as a dialable "host:port" address, appending
// defaultPort when hostPort carries no port (or an empty one, as in "host:").
// Empty or whitespace-only input yields "".
//
// A bracketed IPv6 literal without a port ("[::1]") is unwrapped before the
// port is joined; net.JoinHostPort would otherwise add a second pair of
// brackets and produce "[[::1]]:443".
func targetAddress(hostPort, defaultPort string) string {
	hostPort = strings.TrimSpace(hostPort)
	if hostPort == "" {
		return ""
	}
	if host, port, err := net.SplitHostPort(hostPort); err == nil {
		if port == "" {
			// "host:" parses successfully but is not dialable.
			return net.JoinHostPort(host, defaultPort)
		}
		return hostPort
	}
	return net.JoinHostPort(trimIPv6Brackets(hostPort), defaultPort)
}

// normalizeResponseForProxyWrite forces resp to be serialized as HTTP/1.1.
// A nil resp is ignored.
func normalizeResponseForProxyWrite(resp *http.Response) {
	if resp == nil {
		return
	}
	resp.Proto = "HTTP/1.1"
	resp.ProtoMajor = 1
	resp.ProtoMinor = 1
}

// canonicalProxyHost reduces a host or "host:port" string to a lower-cased
// host with no port, suitable for comparing hosts. IPv6 literals lose their
// brackets whether or not a port is present, so "[::1]" and "[::1]:443" both
// canonicalize to "::1".
func canonicalProxyHost(host string) string {
	host = strings.TrimSpace(host)
	if host == "" {
		return ""
	}
	if parsedHost, _, err := net.SplitHostPort(host); err == nil {
		host = parsedHost
	} else {
		host = trimIPv6Brackets(host)
	}
	return strings.ToLower(strings.TrimSpace(host))
}

// trimIPv6Brackets strips one enclosing "[...]" pair from host, if present.
func trimIPv6Brackets(host string) string {
	if len(host) >= 2 && host[0] == '[' && host[len(host)-1] == ']' {
		return host[1 : len(host)-1]
	}
	return host
}

// cloneHeader returns a deep copy of headers, or nil when headers is nil.
func cloneHeader(headers http.Header) http.Header {
	return headers.Clone()
}
// sanitizeProxyPreview renders at most the first 32 bytes of data as a
// single-line, printable ASCII string. Printable bytes are kept as-is; CR, LF
// and TAB become the two-character escapes \r, \n and \t; every other byte is
// shown as '.'.
func sanitizeProxyPreview(data []byte) string {
	const previewLimit = 32
	if len(data) == 0 {
		return ""
	}
	if len(data) > previewLimit {
		data = data[:previewLimit]
	}

	var preview strings.Builder
	preview.Grow(len(data))
	for _, c := range data {
		switch {
		case c >= ' ' && c <= '~':
			preview.WriteByte(c)
		case c == '\r':
			preview.WriteString(`\r`)
		case c == '\n':
			preview.WriteString(`\n`)
		case c == '\t':
			preview.WriteString(`\t`)
		default:
			preview.WriteByte('.')
		}
	}
	return preview.String()
}

// isEventStreamHeader reports whether any Content-Type value in headers
// mentions "text/event-stream". Both the header name and the media type are
// matched case-insensitively.
func isEventStreamHeader(headers http.Header) bool {
	for name, contentTypes := range headers {
		if strings.EqualFold(name, "Content-Type") {
			for _, contentType := range contentTypes {
				if strings.Contains(strings.ToLower(contentType), "text/event-stream") {
					return true
				}
			}
		}
	}
	return false
}

// floatFromMap returns values[key] when it holds a float64, and 0 when the key
// is absent or holds any other type.
func floatFromMap(values map[string]any, key string) float64 {
	// A missing key yields a nil interface, which fails the assertion and
	// leaves number at zero.
	number, _ := values[key].(float64)
	return number
}
certificate: %w", err) + } + keyPEM, err := os.ReadFile(keyPath) + if err != nil { + return nil, fmt.Errorf("failed to read forward proxy CA key: %w", err) + } + + pair, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + return nil, fmt.Errorf("failed to load forward proxy CA keypair: %w", err) + } + + leaf, err := x509.ParseCertificate(pair.Certificate[0]) + if err != nil { + return nil, fmt.Errorf("failed to parse forward proxy CA certificate: %w", err) + } + + signer, ok := pair.PrivateKey.(crypto.Signer) + if !ok { + return nil, fmt.Errorf("forward proxy CA key does not implement crypto.Signer") + } + + return &mitmCertificateAuthority{ + cert: leaf, + key: signer, + }, nil +} + +func (a *mitmCertificateAuthority) certificateForHost(host string) (*tls.Certificate, error) { + if cached, ok := a.cache.Load(host); ok { + if cert, ok := cached.(*tls.Certificate); ok { + return cert, nil + } + } + + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, err + } + + serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + if err != nil { + return nil, err + } + + template := &x509.Certificate{ + SerialNumber: serial, + Subject: pkix.Name{ + CommonName: host, + }, + NotBefore: time.Now().Add(-5 * time.Minute), + NotAfter: minTime(a.cert.NotAfter, time.Now().Add(24*time.Hour)), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + if ip := net.ParseIP(host); ip != nil { + template.IPAddresses = []net.IP{ip} + } else { + template.DNSNames = []string{host} + } + + der, err := x509.CreateCertificate(rand.Reader, template, a.cert, privateKey.Public(), a.key) + if err != nil { + return nil, err + } + + leaf, err := x509.ParseCertificate(der) + if err != nil { + return nil, err + } + + cert := &tls.Certificate{ + Certificate: [][]byte{der, a.cert.Raw}, + PrivateKey: privateKey, + Leaf: 
leaf, + } + a.cache.Store(host, cert) + return cert, nil +} + +func minTime(first, second time.Time) time.Time { + if first.Before(second) { + return first + } + return second +} + +func encodePEMBlock(blockType string, der []byte) []byte { + return pem.EncodeToMemory(&pem.Block{Type: blockType, Bytes: der}) +} diff --git a/internal/server/forward_proxy_test.go b/internal/server/forward_proxy_test.go new file mode 100644 index 00000000..e806603b --- /dev/null +++ b/internal/server/forward_proxy_test.go @@ -0,0 +1,470 @@ +package server + +import ( + "bytes" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/json" + "io" + "math/big" + "net/http" + "net/http/httptest" + "net/url" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "gomodel/internal/auditlog" + "gomodel/internal/usage" +) + +func TestForwardProxyMITMAnthropicJSONUsageAndAudit(t *testing.T) { + upstream := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/messages" { + t.Fatalf("path = %q, want /v1/messages", r.URL.Path) + } + body, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("failed to read upstream request body: %v", err) + } + if !strings.Contains(string(body), `"model":"claude-sonnet-4-5"`) { + t.Fatalf("unexpected request body: %s", body) + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"msg_123","type":"message","role":"assistant","model":"claude-sonnet-4-5","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":10,"output_tokens":2,"cache_read_input_tokens":6}}`)) + })) + defer upstream.Close() + + caCertPath, caKeyPath, caCert := writeTestCAFiles(t) + proxyURL, proxyAudit, proxyUsage := startTestForwardProxy(t, upstream, caCertPath, caKeyPath) + + client := newMITMHTTPClient(t, proxyURL, caCert) + req, err := http.NewRequest(http.MethodPost, upstream.URL+"/v1/messages", 
strings.NewReader(`{"model":"claude-sonnet-4-5","max_tokens":16,"messages":[{"role":"user","content":"Reply with ok"}]}`)) + if err != nil { + t.Fatalf("NewRequest error: %v", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", "claude-code-test") + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("client.Do error: %v", err) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("ReadAll error: %v", err) + } + if err := resp.Body.Close(); err != nil { + t.Fatalf("Close error: %v", err) + } + if !strings.Contains(string(body), `"id":"msg_123"`) { + t.Fatalf("unexpected response body: %s", body) + } + + auditEntries := waitForAuditEntries(t, proxyAudit, 1) + auditEntry := auditEntries[0] + if auditEntry.Provider != "anthropic" { + t.Fatalf("Provider = %q, want anthropic", auditEntry.Provider) + } + if auditEntry.Path != "/v1/messages" { + t.Fatalf("Path = %q, want /v1/messages", auditEntry.Path) + } + if auditEntry.Model != "claude-sonnet-4-5" { + t.Fatalf("Model = %q, want claude-sonnet-4-5", auditEntry.Model) + } + if auditEntry.Data == nil || auditEntry.Data.RequestBody == nil || auditEntry.Data.ResponseBody == nil { + t.Fatal("expected request and response bodies to be captured") + } + + usageEntries := waitForUsageEntries(t, proxyUsage, 1) + usageEntry := usageEntries[0] + if usageEntry.Provider != "anthropic" { + t.Fatalf("Provider = %q, want anthropic", usageEntry.Provider) + } + if usageEntry.Endpoint != "/v1/messages" { + t.Fatalf("Endpoint = %q, want /v1/messages", usageEntry.Endpoint) + } + if usageEntry.InputTokens != 10 { + t.Fatalf("InputTokens = %d, want 10", usageEntry.InputTokens) + } + if usageEntry.OutputTokens != 2 { + t.Fatalf("OutputTokens = %d, want 2", usageEntry.OutputTokens) + } + if usageEntry.TotalTokens != 12 { + t.Fatalf("TotalTokens = %d, want 12", usageEntry.TotalTokens) + } + if usageEntry.RawData["cache_read_input_tokens"] != 6 { + 
t.Fatalf("RawData[cache_read_input_tokens] = %v, want 6", usageEntry.RawData["cache_read_input_tokens"]) + } +} + +func TestForwardProxyMITMAnthropicStreamingUsageAndAudit(t *testing.T) { + upstream := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + _, _ = w.Write([]byte(`event: message_start +data: {"type":"message_start","message":{"id":"msg_stream_123","type":"message","role":"assistant","model":"claude-sonnet-4-5","content":[],"usage":{"input_tokens":10,"output_tokens":0}}} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" world"}} + +event: message_delta +data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":2}} + +data: [DONE] + +`)) + })) + defer upstream.Close() + + caCertPath, caKeyPath, caCert := writeTestCAFiles(t) + proxyURL, proxyAudit, proxyUsage := startTestForwardProxy(t, upstream, caCertPath, caKeyPath) + + client := newMITMHTTPClient(t, proxyURL, caCert) + req, err := http.NewRequest(http.MethodPost, upstream.URL+"/v1/messages", strings.NewReader(`{"model":"claude-sonnet-4-5","stream":true,"messages":[{"role":"user","content":"hi"}]}`)) + if err != nil { + t.Fatalf("NewRequest error: %v", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("client.Do error: %v", err) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("ReadAll error: %v", err) + } + if err := resp.Body.Close(); err != nil { + t.Fatalf("Close error: %v", err) + } + if !strings.Contains(string(body), "message_start") { + t.Fatalf("unexpected streamed body: %s", body) + } + + auditEntries := waitForAuditEntries(t, proxyAudit, 1) + auditEntry := auditEntries[0] + if auditEntry.Data 
== nil { + t.Fatal("Data = nil") + } + responseBody, ok := auditEntry.Data.ResponseBody.(map[string]any) + if !ok { + t.Fatalf("ResponseBody = %#v, want map", auditEntry.Data.ResponseBody) + } + if responseBody["id"] != "msg_stream_123" { + t.Fatalf("id = %#v, want msg_stream_123", responseBody["id"]) + } + if content, ok := responseBody["content"].([]map[string]any); ok { + if len(content) != 1 || content[0]["text"] != "Hello world" { + t.Fatalf("content = %#v, want Hello world", content) + } + } else { + contentAny, ok := responseBody["content"].([]any) + if !ok || len(contentAny) != 1 { + t.Fatalf("content = %#v, want one text block", responseBody["content"]) + } + content, ok := contentAny[0].(map[string]any) + if !ok || content["text"] != "Hello world" { + t.Fatalf("content[0] = %#v, want Hello world", contentAny[0]) + } + } + + usageEntries := waitForUsageEntries(t, proxyUsage, 1) + usageEntry := usageEntries[0] + if usageEntry.InputTokens != 10 || usageEntry.OutputTokens != 2 || usageEntry.TotalTokens != 12 { + t.Fatalf("unexpected usage entry: %+v", usageEntry) + } +} + +func waitForAuditEntries(t *testing.T, logger *capturingAuditLogger, want int) []*auditlog.LogEntry { + t.Helper() + + deadline := time.Now().Add(2 * time.Second) + for { + entries := logger.Entries() + if len(entries) == want { + return entries + } + if time.Now().After(deadline) { + t.Fatalf("audit entries = %d, want %d", len(entries), want) + } + time.Sleep(10 * time.Millisecond) + } +} + +func waitForUsageEntries(t *testing.T, logger *collectingUsageLogger, want int) []*usage.UsageEntry { + t.Helper() + + deadline := time.Now().Add(2 * time.Second) + for { + entries := logger.Entries() + if len(entries) == want { + return entries + } + if time.Now().After(deadline) { + t.Fatalf("usage entries = %d, want %d", len(entries), want) + } + time.Sleep(10 * time.Millisecond) + } +} + +func startTestForwardProxy(t *testing.T, upstream *httptest.Server, caCertPath, caKeyPath string) (string, 
*capturingAuditLogger, *collectingUsageLogger) { + t.Helper() + + upstreamURL, err := url.Parse(upstream.URL) + if err != nil { + t.Fatalf("Parse upstream URL error: %v", err) + } + mitmHost := canonicalProxyHost(upstreamURL.Host) + + auditLogger := &capturingAuditLogger{ + config: auditlog.Config{ + Enabled: true, + LogBodies: true, + LogHeaders: true, + }, + } + usageLogger := &collectingUsageLogger{ + config: usage.Config{Enabled: true}, + } + handler, err := NewForwardProxyHandler(http.NotFoundHandler(), &ForwardProxyConfig{ + Enabled: true, + MITMHosts: []string{mitmHost}, + CACertFile: caCertPath, + CAKeyFile: caKeyPath, + AuditLogger: auditLogger, + UsageLogger: usageLogger, + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + }, + }) + if err != nil { + t.Fatalf("NewForwardProxyHandler error: %v", err) + } + + server := httptest.NewServer(handler) + t.Cleanup(server.Close) + return server.URL, auditLogger, usageLogger +} + +func newMITMHTTPClient(t *testing.T, proxyAddr string, caCert *x509.Certificate) *http.Client { + t.Helper() + + proxyURL, err := url.Parse(proxyAddr) + if err != nil { + t.Fatalf("Parse proxy URL error: %v", err) + } + pool := x509.NewCertPool() + pool.AddCert(caCert) + + return &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + TLSClientConfig: &tls.Config{ + RootCAs: pool, + }, + }, + } +} + +func writeTestCAFiles(t *testing.T) (string, string, *x509.Certificate) { + t.Helper() + + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("GenerateKey error: %v", err) + } + + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: "gomodel-test-ca", + }, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(24 * time.Hour), + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + BasicConstraintsValid: true, + IsCA: true, + } + + der, err := 
x509.CreateCertificate(rand.Reader, template, template, privateKey.Public(), privateKey) + if err != nil { + t.Fatalf("CreateCertificate error: %v", err) + } + cert, err := x509.ParseCertificate(der) + if err != nil { + t.Fatalf("ParseCertificate error: %v", err) + } + + keyDER, err := x509.MarshalPKCS8PrivateKey(privateKey) + if err != nil { + t.Fatalf("MarshalPKCS8PrivateKey error: %v", err) + } + + dir := t.TempDir() + certPath := filepath.Join(dir, "ca-cert.pem") + keyPath := filepath.Join(dir, "ca-key.pem") + if err := os.WriteFile(certPath, encodePEMBlock("CERTIFICATE", der), 0o600); err != nil { + t.Fatalf("WriteFile cert error: %v", err) + } + if err := os.WriteFile(keyPath, encodePEMBlock("PRIVATE KEY", keyDER), 0o600); err != nil { + t.Fatalf("WriteFile key error: %v", err) + } + + return certPath, keyPath, cert +} + +func TestIsForwardProxyRequest(t *testing.T) { + absoluteURL, _ := url.Parse("https://api.anthropic.com/v1/messages") + tests := []struct { + name string + req *http.Request + want bool + }{ + { + name: "connect", + req: &http.Request{Method: http.MethodConnect, URL: &url.URL{}}, + want: true, + }, + { + name: "absolute URL", + req: &http.Request{Method: http.MethodPost, URL: absoluteURL}, + want: true, + }, + { + name: "normal API route", + req: &http.Request{Method: http.MethodPost, URL: &url.URL{Path: "/v1/chat/completions"}}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := isForwardProxyRequest(tt.req); got != tt.want { + t.Fatalf("isForwardProxyRequest() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestProxyBodyForAuditParsesJSON(t *testing.T) { + value, tooBig := proxyBodyForAudit([]byte(`{"ok":true}`)) + if tooBig { + t.Fatal("tooBig = true, want false") + } + parsed, ok := value.(map[string]any) + if !ok { + t.Fatalf("value = %#v, want map", value) + } + if parsed["ok"] != true { + t.Fatalf("ok = %#v, want true", parsed["ok"]) + } +} + +func 
TestProxyBodyForAuditRejectsLargePayload(t *testing.T) { + value, tooBig := proxyBodyForAudit([]byte(strings.Repeat("x", auditlog.MaxBodyCapture+1))) + if value != nil { + t.Fatalf("value = %#v, want nil", value) + } + if !tooBig { + t.Fatal("tooBig = false, want true") + } +} + +func TestExtractAnthropicUsageEntryComputesTotal(t *testing.T) { + handler := &forwardProxyHandler{} + entry := handler.extractAnthropicUsageEntry( + []byte(`{"id":"msg_123","model":"claude-sonnet-4-5","usage":{"input_tokens":10,"output_tokens":2,"cache_read_input_tokens":6}}`), + "req-123", + "claude-sonnet-4-5", + "/v1/messages", + ) + if entry == nil { + t.Fatal("entry = nil") + } + if entry.TotalTokens != 12 { + t.Fatalf("TotalTokens = %d, want 12", entry.TotalTokens) + } + if entry.RawData["cache_read_input_tokens"] != 6 { + t.Fatalf("RawData[cache_read_input_tokens] = %v, want 6", entry.RawData["cache_read_input_tokens"]) + } +} + +func TestNormalizeResponseForProxyWriteForcesHTTP11(t *testing.T) { + resp := &http.Response{ + StatusCode: http.StatusUnauthorized, + Status: "401 Unauthorized", + Proto: "HTTP/2.0", + ProtoMajor: 2, + ProtoMinor: 0, + Body: io.NopCloser(strings.NewReader(`{"error":"unauthorized"}`)), + ContentLength: int64(len(`{"error":"unauthorized"}`)), + Header: http.Header{ + "Content-Type": []string{"application/json"}, + }, + } + + normalizeResponseForProxyWrite(resp) + + var buf bytes.Buffer + if err := resp.Write(&buf); err != nil { + t.Fatalf("resp.Write error: %v", err) + } + if !strings.HasPrefix(buf.String(), "HTTP/1.1 401 Unauthorized\r\n") { + t.Fatalf("unexpected wire response prefix: %q", buf.String()) + } +} + +func TestWriteTestCAFilesProducesParseablePEM(t *testing.T) { + certPath, keyPath, cert := writeTestCAFiles(t) + if cert == nil { + t.Fatal("cert = nil") + } + certPEM, err := os.ReadFile(certPath) + if err != nil { + t.Fatalf("ReadFile cert error: %v", err) + } + keyPEM, err := os.ReadFile(keyPath) + if err != nil { + t.Fatalf("ReadFile key 
error: %v", err) + } + pair, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + t.Fatalf("X509KeyPair error: %v", err) + } + leaf, err := x509.ParseCertificate(pair.Certificate[0]) + if err != nil { + t.Fatalf("ParseCertificate error: %v", err) + } + if leaf.Subject.CommonName != "gomodel-test-ca" { + t.Fatalf("CommonName = %q, want gomodel-test-ca", leaf.Subject.CommonName) + } +} + +func TestProxyHeaderMapRedactsSensitiveHeaders(t *testing.T) { + headers := proxyHeaderMap(http.Header{ + "Authorization": {"Bearer secret"}, + "Cookie": {"session=secret"}, + "X-Test": {"ok"}, + }) + data, err := json.Marshal(headers) + if err != nil { + t.Fatalf("Marshal error: %v", err) + } + if strings.Contains(string(data), "secret") { + t.Fatalf("expected redaction, got %s", data) + } + if headers["X-Test"] != "ok" { + t.Fatalf("X-Test = %q, want ok", headers["X-Test"]) + } +} diff --git a/internal/server/handlers_test.go b/internal/server/handlers_test.go index 2839ebe6..d5163e1a 100644 --- a/internal/server/handlers_test.go +++ b/internal/server/handlers_test.go @@ -14,6 +14,7 @@ import ( "slices" "sort" "strings" + "sync" "testing" "time" @@ -229,14 +230,23 @@ func setPathParam(c *echo.Context, name, value string) { } type capturingAuditLogger struct { + mu sync.Mutex config auditlog.Config entries []*auditlog.LogEntry } func (l *capturingAuditLogger) Write(entry *auditlog.LogEntry) { + l.mu.Lock() + defer l.mu.Unlock() l.entries = append(l.entries, entry) } +func (l *capturingAuditLogger) Entries() []*auditlog.LogEntry { + l.mu.Lock() + defer l.mu.Unlock() + return slices.Clone(l.entries) +} + func (l *capturingAuditLogger) Config() auditlog.Config { return l.config } @@ -3973,17 +3983,26 @@ func (c *capturingUsageLogger) Config() usage.Config { return c.config func (c *capturingUsageLogger) Close() error { return nil } type collectingUsageLogger struct { + mu sync.Mutex config usage.Config entries []*usage.UsageEntry } func (c *collectingUsageLogger) Write(entry 
*usage.UsageEntry) { + c.mu.Lock() + defer c.mu.Unlock() if entry == nil { return } c.entries = append(c.entries, entry) } +func (c *collectingUsageLogger) Entries() []*usage.UsageEntry { + c.mu.Lock() + defer c.mu.Unlock() + return slices.Clone(c.entries) +} + func (c *collectingUsageLogger) Config() usage.Config { return c.config } func (c *collectingUsageLogger) Close() error { return nil } diff --git a/internal/server/http.go b/internal/server/http.go index 73bac7f7..2e710c35 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -29,6 +29,7 @@ import ( type Server struct { echo *echo.Echo handler *Handler + root http.Handler responseCacheMiddleware *responsecache.ResponseCacheMiddleware } @@ -58,6 +59,7 @@ type Config struct { DashboardHandler *dashboard.Handler // Dashboard UI handler (nil if disabled) SwaggerEnabled bool // Whether to expose the Swagger UI at /swagger/index.html ResponseCacheMiddleware *responsecache.ResponseCacheMiddleware // Optional: response cache middleware for cacheable endpoints + ExperimentalForwardProxy *ForwardProxyConfig // Optional: experimental forward proxy wrapper GuardrailsHash string // Optional: SHA-256 hash of active guardrail rules; stored in context post-patch for semantic cache } @@ -283,9 +285,23 @@ func New(provider core.RoutableProvider, cfg *Config) *Server { if cfg != nil { rcm = cfg.ResponseCacheMiddleware } + root := http.Handler(e) + if cfg != nil && cfg.ExperimentalForwardProxy != nil && cfg.ExperimentalForwardProxy.Enabled { + proxyCfg := *cfg.ExperimentalForwardProxy + proxyCfg.AuditLogger = auditLogger + proxyCfg.UsageLogger = usageLogger + proxyCfg.PricingResolver = pricingResolver + proxyHandler, err := NewForwardProxyHandler(root, &proxyCfg) + if err != nil { + slog.Error("failed to enable experimental forward proxy", "error", err) + } else { + root = proxyHandler + } + } return &Server{ echo: e, handler: handler, + root: root, responseCacheMiddleware: rcm, } } @@ -303,7 +319,7 @@ func (s 
*Server) Start(ctx context.Context, addr string) error { Address: addr, HideBanner: true, } - return sc.Start(ctx, s.echo) + return sc.Start(ctx, s.root) } // Shutdown releases server resources. The HTTP server itself is stopped by @@ -318,7 +334,7 @@ func (s *Server) Shutdown(_ context.Context) error { // ServeHTTP implements the http.Handler interface, allowing Server to be used with httptest func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - s.echo.ServeHTTP(w, r) + s.root.ServeHTTP(w, r) } func parseBodySizeLimitBytes(limit string) int64 { diff --git a/internal/usage/stream_observer.go b/internal/usage/stream_observer.go index 1b3c8d44..4818b215 100644 --- a/internal/usage/stream_observer.go +++ b/internal/usage/stream_observer.go @@ -31,7 +31,7 @@ func NewStreamUsageObserver(logger LoggerInterface, model, provider, requestID, func (o *StreamUsageObserver) OnJSONEvent(chunk map[string]any) { entry := o.extractUsageFromEvent(chunk) if entry != nil { - o.cachedEntry = entry + o.cachedEntry = mergeUsageEntries(o.cachedEntry, entry) } } @@ -54,6 +54,19 @@ func (o *StreamUsageObserver) extractUsageFromEvent(chunk map[string]any) *Usage } usageRaw, ok := chunk["usage"] + if !ok { + if eventType, _ := chunk["type"].(string); eventType == "message_start" { + if message, msgOK := chunk["message"].(map[string]any); msgOK { + usageRaw, ok = message["usage"] + if id, idOK := message["id"].(string); idOK && id != "" { + providerID = id + } + if m, modelOK := message["model"].(string); modelOK && m != "" { + model = m + } + } + } + } if !ok { if eventType, _ := chunk["type"].(string); eventType == "response.completed" || eventType == "response.done" { if response, respOK := chunk["response"].(map[string]any); respOK { @@ -138,3 +151,69 @@ func (o *StreamUsageObserver) extractUsageFromEvent(chunk map[string]any) *Usage pricingArgs..., ) } + +func mergeUsageEntries(prev, next *UsageEntry) *UsageEntry { + if prev == nil { + return next + } + if next == nil { 
+ return prev + } + + merged := *prev + + if next.ProviderID != "" { + merged.ProviderID = next.ProviderID + } + if next.Model != "" { + merged.Model = next.Model + } + if next.Provider != "" { + merged.Provider = next.Provider + } + if next.Endpoint != "" { + merged.Endpoint = next.Endpoint + } + if next.RequestID != "" { + merged.RequestID = next.RequestID + } + if next.Timestamp.After(merged.Timestamp) { + merged.Timestamp = next.Timestamp + } + + if next.InputTokens > 0 { + merged.InputTokens = next.InputTokens + } + if next.OutputTokens > 0 { + merged.OutputTokens = next.OutputTokens + } + if next.TotalTokens > 0 { + merged.TotalTokens = next.TotalTokens + } else if merged.InputTokens > 0 || merged.OutputTokens > 0 { + merged.TotalTokens = merged.InputTokens + merged.OutputTokens + } + + switch { + case merged.RawData == nil && next.RawData != nil: + merged.RawData = cloneRawData(next.RawData) + case merged.RawData != nil && next.RawData != nil: + for key, value := range next.RawData { + merged.RawData[key] = value + } + } + + if next.InputCost != nil { + merged.InputCost = next.InputCost + } + if next.OutputCost != nil { + merged.OutputCost = next.OutputCost + } + if next.TotalCost != nil { + merged.TotalCost = next.TotalCost + } + if next.CostsCalculationCaveat != "" { + merged.CostsCalculationCaveat = next.CostsCalculationCaveat + } + + return &merged +} diff --git a/internal/usage/stream_observer_test.go b/internal/usage/stream_observer_test.go index 3bfeba06..8636ceca 100644 --- a/internal/usage/stream_observer_test.go +++ b/internal/usage/stream_observer_test.go @@ -330,3 +330,61 @@ data: [DONE] t.Errorf("TotalTokens = %d, want 8", entry.TotalTokens) } } + +func TestStreamUsageObserverAnthropicMessagesMergesUsageAcrossEvents(t *testing.T) { + streamData := `event: message_start +data: 
{"type":"message_start","message":{"id":"msg_123","type":"message","role":"assistant","model":"claude-sonnet-4-5","content":[],"usage":{"input_tokens":10,"output_tokens":0,"cache_read_input_tokens":6}}} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}} + +event: message_delta +data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":2}} + +data: [DONE] + +` + logger := &trackingLogger{enabled: true} + stream := streaming.NewObservedSSEStream( + io.NopCloser(strings.NewReader(streamData)), + NewStreamUsageObserver(logger, "claude-sonnet-4-5", "anthropic", "req-anthropic", "/v1/messages", nil), + ) + + data, err := io.ReadAll(stream) + if err != nil { + t.Fatalf("ReadAll error: %v", err) + } + if string(data) != streamData { + t.Fatalf("stream passthrough mismatch") + } + if err := stream.Close(); err != nil { + t.Fatalf("Close error: %v", err) + } + + entries := logger.getEntries() + if len(entries) != 1 { + t.Fatalf("expected 1 entry, got %d", len(entries)) + } + entry := entries[0] + if entry.ProviderID != "msg_123" { + t.Fatalf("ProviderID = %q, want msg_123", entry.ProviderID) + } + if entry.Model != "claude-sonnet-4-5" { + t.Fatalf("Model = %q, want claude-sonnet-4-5", entry.Model) + } + if entry.InputTokens != 10 { + t.Fatalf("InputTokens = %d, want 10", entry.InputTokens) + } + if entry.OutputTokens != 2 { + t.Fatalf("OutputTokens = %d, want 2", entry.OutputTokens) + } + if entry.TotalTokens != 12 { + t.Fatalf("TotalTokens = %d, want 12", entry.TotalTokens) + } + if entry.RawData == nil { + t.Fatal("RawData = nil") + } + if entry.RawData["cache_read_input_tokens"] != 6 { + t.Fatalf("RawData[cache_read_input_tokens] = %v, want 6", entry.RawData["cache_read_input_tokens"]) + } +}