diff --git a/README.md b/README.md
index dffe3e98..76afb457 100644
--- a/README.md
+++ b/README.md
@@ -172,6 +172,8 @@ Key settings:
| `ENABLE_PASSTHROUGH_ROUTES` | `true` | Enable provider-native passthrough routes under `/p/{provider}/...` |
| `ALLOW_PASSTHROUGH_V1_ALIAS` | `true` | Allow `/p/{provider}/v1/...` aliases while keeping `/p/{provider}/...` canonical |
| `ENABLED_PASSTHROUGH_PROVIDERS` | `openai,anthropic` | Comma-separated list of enabled passthrough providers |
+| `EXPERIMENTAL_FORWARD_PROXY_ENABLED` | `false` | Enable the experimental HTTP forward proxy wrapper for client traffic inspection |
+| `EXPERIMENTAL_FORWARD_PROXY_MITM_HOSTS` | `api.anthropic.com` | Comma-separated HTTPS hosts to inspect; other CONNECT targets are tunneled blindly |
| `CACHE_TYPE` | `local` | Cache backend (`local` or `redis`) |
| `STORAGE_TYPE` | `sqlite` | Storage backend (`sqlite`, `postgresql`, `mongodb`) |
| `METRICS_ENABLED` | `false` | Enable Prometheus metrics |
@@ -180,6 +182,8 @@ Key settings:
**Quick Start - Authentication:** By default `GOMODEL_MASTER_KEY` is unset. Without this key, API endpoints are unprotected and anyone can call them. This is insecure for production. **Strongly recommend** setting a strong secret before exposing the service. Add `GOMODEL_MASTER_KEY` to your `.env` or environment for production deployments.
+**Experimental forward proxy:** When `EXPERIMENTAL_FORWARD_PROXY_ENABLED=true`, GoModel can act as a local HTTP proxy for client traffic. To inspect HTTPS bodies, provide `EXPERIMENTAL_FORWARD_PROXY_CA_CERT_FILE` and `EXPERIMENTAL_FORWARD_PROXY_CA_KEY_FILE`, trust that CA in the client environment, and point the client at GoModel with `HTTP_PROXY` or `HTTPS_PROXY`. This mode is intended for local experiments and is not hardened as a general-purpose enterprise proxy. For Claude Code setup examples, see [`docs/guides/claude-code.mdx`](docs/guides/claude-code.mdx).
+
---
## Response Caching
diff --git a/config/config.example.yaml b/config/config.example.yaml
index c068c498..adc72b61 100644
--- a/config/config.example.yaml
+++ b/config/config.example.yaml
@@ -11,6 +11,10 @@ server:
enable_passthrough_routes: true # expose /p/{provider}/{endpoint} passthrough routes
allow_passthrough_v1_alias: true # allow /p/{provider}/v1/... while keeping /p/{provider}/... canonical
enabled_passthrough_providers: ["openai", "anthropic"] # providers enabled on /p/{provider}/...
+ experimental_forward_proxy_enabled: false # HTTP forward proxy for client traffic inspection (experimental)
+ experimental_forward_proxy_mitm_hosts: ["api.anthropic.com"] # HTTPS hosts to inspect; all others tunnel blindly
+ experimental_forward_proxy_ca_cert_file: "" # PEM CA cert used for MITM leaf cert generation
+ experimental_forward_proxy_ca_key_file: "" # PEM CA private key used for MITM leaf cert generation
cache:
model:
diff --git a/config/config.go b/config/config.go
index 983822a3..f25e39ec 100644
--- a/config/config.go
+++ b/config/config.go
@@ -434,6 +434,18 @@ type ServerConfig struct {
// EnabledPassthroughProviders lists the provider types enabled on
// /p/{provider}/... passthrough routes. Default: ["openai", "anthropic"].
EnabledPassthroughProviders []string `yaml:"enabled_passthrough_providers" env:"ENABLED_PASSTHROUGH_PROVIDERS"`
+ // ExperimentalForwardProxyEnabled enables an HTTP forward proxy entrypoint that can
+ // optionally MITM selected HTTPS hosts for traffic inspection. Default: false.
+ ExperimentalForwardProxyEnabled bool `yaml:"experimental_forward_proxy_enabled" env:"EXPERIMENTAL_FORWARD_PROXY_ENABLED"`
+ // ExperimentalForwardProxyMITMHosts lists the hosts whose HTTPS CONNECT traffic
+ // should be terminated and inspected. Other hosts are tunneled blindly.
+ ExperimentalForwardProxyMITMHosts []string `yaml:"experimental_forward_proxy_mitm_hosts" env:"EXPERIMENTAL_FORWARD_PROXY_MITM_HOSTS"`
+ // ExperimentalForwardProxyCACertFile points at the PEM-encoded CA certificate used
+ // to mint leaf certificates for inspected HTTPS hosts.
+ ExperimentalForwardProxyCACertFile string `yaml:"experimental_forward_proxy_ca_cert_file" env:"EXPERIMENTAL_FORWARD_PROXY_CA_CERT_FILE"`
+ // ExperimentalForwardProxyCAKeyFile points at the PEM-encoded CA private key used
+ // to mint leaf certificates for inspected HTTPS hosts.
+ ExperimentalForwardProxyCAKeyFile string `yaml:"experimental_forward_proxy_ca_key_file" env:"EXPERIMENTAL_FORWARD_PROXY_CA_KEY_FILE"`
}
// MetricsConfig holds observability configuration for Prometheus metrics
@@ -504,6 +516,7 @@ func buildDefaultConfig() *Config {
"openai",
"anthropic",
},
+ ExperimentalForwardProxyMITMHosts: []string{"api.anthropic.com"},
},
Cache: CacheConfig{
Model: ModelCacheConfig{
diff --git a/docs/guides/claude-code.mdx b/docs/guides/claude-code.mdx
index d2b517bb..f2501074 100644
--- a/docs/guides/claude-code.mdx
+++ b/docs/guides/claude-code.mdx
@@ -1,45 +1,34 @@
---
title: "Using GoModel with Claude Code"
-description: "Step-by-step guide for routing Claude Code through GoModel with Anthropic passthrough."
+description: "Step-by-step guide for using Claude Code with GoModel in gateway mode or subscription proxy mode."
---
-GoModel can sit in front of Claude Code so every request goes through your own
-gateway first.
+GoModel supports two different Claude Code setups:
-Flow:
-
-`Claude Code -> GoModel -> Anthropic`
+| Goal | Mode | Upstream auth |
+| --- | --- | --- |
+| Use Claude Code through an Anthropic-compatible gateway | Gateway mode | GoModel's `ANTHROPIC_API_KEY` |
+| Keep Claude.ai subscription auth and still capture traffic in GoModel | Subscription proxy mode | Claude Code's own `claude.ai` or enterprise session |
-## Before you start
+If you want company-wide Claude Code observability without moving users onto API
+keys, use **subscription proxy mode**.
-- Install Claude Code on your machine.
-- Choose a GoModel master key, for example `change-me`.
-- Make sure GoModel has an Anthropic upstream credential.
+## Mode 1: Gateway mode
-
- Claude Code can be routed through GoModel whether or not you personally use a
- Claude Code subscription. For gateway mode, Claude Code talks to GoModel with
- `ANTHROPIC_BASE_URL` and `ANTHROPIC_AUTH_TOKEN`. GoModel still needs its own
- `ANTHROPIC_API_KEY` to reach Anthropic upstream.
-
-
-## How to get `ANTHROPIC_API_KEY`
+Flow:
-1. Open the Claude Console and sign in to your API account.
-2. Go to account settings in Console, then create an API key.
-3. Copy the key once and set it for GoModel as `ANTHROPIC_API_KEY`.
+`Claude Code -> GoModel /p/anthropic -> Anthropic`
-Anthropic's API docs state that API keys are created in Console account
-settings.
+This is the standard Anthropic-compatible gateway setup. Claude Code talks to
+GoModel with `ANTHROPIC_BASE_URL` and `ANTHROPIC_AUTH_TOKEN`, and GoModel calls
+Anthropic with its own `ANTHROPIC_API_KEY`.
- Claude paid plans (Pro, Max, Team, Enterprise) and Claude API Console billing
- are separate. API key usage is billed as API usage.
+ Gateway mode does not use a user's Claude subscription allowance. It uses
+ GoModel's upstream Anthropic API credentials.
-## 1. Run GoModel
-
-Start GoModel with a master key and an Anthropic provider key:
+### 1. Run GoModel
```bash
docker run --rm -p 8080:8080 \
@@ -48,17 +37,13 @@ docker run --rm -p 8080:8080 \
enterpilot/gomodel
```
-## 2. Confirm Anthropic passthrough with curl
-
-Check that the Anthropic passthrough models endpoint responds:
+### 2. Confirm Anthropic passthrough
```bash
curl -s http://localhost:8080/p/anthropic/v1/models \
-H "Authorization: Bearer change-me"
```
-Then send one small test message:
-
```bash
curl -s http://localhost:8080/p/anthropic/v1/messages \
-H "Authorization: Bearer change-me" \
@@ -76,11 +61,7 @@ curl -s http://localhost:8080/p/anthropic/v1/messages \
}'
```
-If the gateway is wired correctly, the response will contain `ok`.
-
-## 3. Configure Claude Code to use GoModel
-
-Point Claude Code at GoModel's Anthropic passthrough:
+### 3. Point Claude Code at GoModel
```bash
export ANTHROPIC_BASE_URL=http://localhost:8080/p/anthropic
@@ -88,55 +69,142 @@ export ANTHROPIC_AUTH_TOKEN=change-me
export CLAUDE_CODE_DISABLE_EXPERIMENTAL_BETAS=1
```
-Short Claude Code doc summary: for gateway mode, set `ANTHROPIC_BASE_URL` to
-your gateway URL and `ANTHROPIC_AUTH_TOKEN` to your gateway token, then run
-Claude Code normally. See the official guide:
-[Claude Code LLM gateway docs](https://code.claude.com/docs/en/llm-gateway).
+```bash
+claude -p --output-format text --model claude-3-haiku-20240307 \
+ 'Reply with exactly ok and no punctuation.'
+```
+
+## Mode 2: Subscription proxy mode
+
+Flow:
-If GoModel is not running on your local machine, replace `localhost:8080` with
-the correct host and port.
+`Claude Code -> GoModel forward proxy -> Anthropic`
-## 4. Run a Claude Code test prompt
+This mode keeps Claude Code signed in with `claude.ai` or enterprise auth and
+places GoModel in the middle as an HTTP(S) proxy. It is the mode to use when
+you want GoModel `audit_logs` and `usage` for real Claude Code subscription
+traffic.
+
+
+ Validated on March 23, 2026 with Claude Code `2.1.81`, `authMethod:
+ claude.ai`, and GoModel's experimental forward proxy.
+
+
+
+ The forward proxy is experimental and intended for local or tightly controlled
+ internal use. Do not expose it directly to the internet in its current form.
+
+
+### 1. Generate a local MITM CA
```bash
-claude -p --output-format text --model claude-3-haiku-20240307 \
- 'Reply with exactly ok and no punctuation.'
+mkdir -p proxy-certs
+
+openssl req -x509 -newkey rsa:2048 -nodes \
+ -keyout proxy-certs/ca-key.pem \
+ -out proxy-certs/ca-cert.pem \
+ -days 30 \
+ -subj "/CN=GoModel Claude Proxy CA"
+```
+
+### 2. Run GoModel with the forward proxy enabled
+
+```bash
+docker run --rm -p 8080:8080 \
+ -v "$PWD/proxy-certs:/certs" \
+ -e OPENAI_API_KEY="dummy" \
+ -e LOGGING_ENABLED=true \
+ -e LOGGING_LOG_BODIES=true \
+ -e LOGGING_LOG_HEADERS=true \
+ -e USAGE_ENABLED=true \
+ -e EXPERIMENTAL_FORWARD_PROXY_ENABLED=true \
+ -e EXPERIMENTAL_FORWARD_PROXY_MITM_HOSTS="api.anthropic.com" \
+ -e EXPERIMENTAL_FORWARD_PROXY_CA_CERT_FILE="/certs/ca-cert.pem" \
+ -e EXPERIMENTAL_FORWARD_PROXY_CA_KEY_FILE="/certs/ca-key.pem" \
+ enterpilot/gomodel
+```
+
+
+ The forward proxy path itself does not use `OPENAI_API_KEY`. The example sets
+ `OPENAI_API_KEY=dummy` only because GoModel currently expects at least one
+ configured provider at startup. If you already run GoModel with a real
+ provider credential, keep using that instead.
+
+
+### 3. Point Claude Code at GoModel as an HTTP proxy
+
+Keep Claude Code logged in normally, then export:
+
+```bash
+export HTTPS_PROXY=http://localhost:8080
+export HTTP_PROXY=http://localhost:8080
+export NODE_EXTRA_CA_CERTS="$PWD/proxy-certs/ca-cert.pem"
+export NO_PROXY=127.0.0.1,localhost
```
-The validated result was:
+You can confirm the client is still using subscription auth with:
-```text
-ok
+```bash
+claude auth status
```
-## 5. Check the traffic in GoModel
+### 4. Run a real Claude Code prompt
-Open the GoModel dashboard audit logs:
+```bash
+claude -p --output-format json \
+ --setting-sources user \
+ --strict-mcp-config --mcp-config '{"mcpServers":{}}' \
+ --disable-slash-commands \
+ --no-chrome \
+ --no-session-persistence \
+ 'Respond with exactly the word THROUGHPROXY.'
+```
+
+A successful run returns JSON with:
+
+```json
+{
+ "result": "THROUGHPROXY"
+}
+```
+
+### 5. Verify audit and usage in GoModel
+
+Open the dashboard audit view:
[http://localhost:8080/admin/dashboard/audit](http://localhost:8080/admin/dashboard/audit)
-This is the easiest place to confirm that Claude Code is reaching GoModel and
-to inspect the full request and response trail. From the same dashboard, you
-can keep following your GoModel traffic and usage.
+For a successful proxied subscription request, you should see:
+
+- startup requests such as `/api/oauth/claude_cli/client_data`
+- a model request row for `/v1/messages`
+- `usage` data for that `/v1/messages` call
+
+The admin APIs are also available directly:
+
+- `GET /admin/api/v1/audit/log`
+- `GET /admin/api/v1/usage/log`
## References
- Anthropic Claude Code gateway docs: [LLM gateway](https://code.claude.com/docs/en/llm-gateway)
- Anthropic Claude Code settings: [Settings](https://code.claude.com/docs/en/settings)
+- Anthropic Claude Code enterprise network config: [Network configuration](https://code.claude.com/docs/en/network-config)
- Claude Help: [Managing API key environment variables in Claude Code](https://support.claude.com/en/articles/12304248-managing-api-key-environment-variables-in-claude-code)
- Claude Help: [Paid Claude plans vs API Console billing](https://support.claude.com/en/articles/9876003-i-have-a-paid-claude-subscription-pro-max-team-or-enterprise-plans-why-do-i-have-to-pay-separately-to-use-the-claude-api-and-console)
-- Claude API docs: [API overview (keys from Console account settings)](https://docs.anthropic.com/en/api/getting-started)
-## Validated on March 10, 2026
+## Validation summary
-This guide was validated against:
+Validated on March 23, 2026 against:
-- a local GoModel instance on `http://localhost:8080`
-- a GoModel branch exposing `/p/{provider}/v1/...`
-- Claude Code `2.1.72`
+- GoModel running locally on `http://localhost:8080`
+- Claude Code `2.1.81`
+- Anthropic gateway mode through `/p/anthropic`
+- subscription proxy mode through `HTTP_PROXY` / `HTTPS_PROXY`
Local validation confirmed:
-- `GET /p/anthropic/v1/models` returned `200 OK`
-- `POST /p/anthropic/v1/messages` returned `200 OK`
-- `claude` returned `ok` when pointed at GoModel
+- gateway mode returned successful Anthropic passthrough responses
+- subscription proxy mode returned `THROUGHPROXY`
+- GoModel stored `/v1/messages` in `audit_logs`
+- GoModel stored token and cost data for `/v1/messages` in `usage`
diff --git a/internal/app/app.go b/internal/app/app.go
index 775e11a5..77bb33e3 100644
--- a/internal/app/app.go
+++ b/internal/app/app.go
@@ -219,6 +219,12 @@ func New(ctx context.Context, cfg Config) (*App, error) {
EnabledPassthroughProviders: appCfg.Server.EnabledPassthroughProviders,
AllowPassthroughV1Alias: &allowPassthroughV1Alias,
SwaggerEnabled: appCfg.Server.SwaggerEnabled,
+ ExperimentalForwardProxy: &server.ForwardProxyConfig{
+ Enabled: appCfg.Server.ExperimentalForwardProxyEnabled,
+ MITMHosts: appCfg.Server.ExperimentalForwardProxyMITMHosts,
+ CACertFile: appCfg.Server.ExperimentalForwardProxyCACertFile,
+ CAKeyFile: appCfg.Server.ExperimentalForwardProxyCAKeyFile,
+ },
}
// Initialize admin API and dashboard (behind separate feature flags)
@@ -256,6 +262,9 @@ func New(ctx context.Context, cfg Config) (*App, error) {
} else {
slog.Info("provider passthrough disabled")
}
+ if appCfg.Server.ExperimentalForwardProxyEnabled {
+ slog.Info("experimental forward proxy enabled", "mitm_hosts", appCfg.Server.ExperimentalForwardProxyMITMHosts)
+ }
rcm, err := responsecache.NewResponseCacheMiddleware(appCfg.Cache.Response, cfg.AppConfig.RawProviders)
if err != nil {
diff --git a/internal/auditlog/auditlog_test.go b/internal/auditlog/auditlog_test.go
index b60cc522..9de23c43 100644
--- a/internal/auditlog/auditlog_test.go
+++ b/internal/auditlog/auditlog_test.go
@@ -850,6 +850,94 @@ func TestCreateStreamEntry(t *testing.T) {
}
}
+func TestStreamLogObserverAnthropicMessages(t *testing.T) {
+ store := &mockStore{}
+ cfg := Config{
+ Enabled: true,
+ LogBodies: true,
+ BufferSize: 10,
+ FlushInterval: 100 * time.Millisecond,
+ }
+ logger := NewLogger(store, cfg)
+
+ entry := &LogEntry{
+ ID: "anthropic-stream-entry",
+ Timestamp: time.Now(),
+ Model: "claude-sonnet-4-5",
+ Data: &LogData{},
+ }
+
+ streamData := `event: message_start
+data: {"type":"message_start","message":{"id":"msg_123","type":"message","role":"assistant","model":"claude-sonnet-4-5","content":[]}}
+
+event: content_block_delta
+data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}
+
+event: content_block_delta
+data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" world"}}
+
+event: message_delta
+data: {"type":"message_delta","delta":{"stop_reason":"end_turn"}}
+
+data: [DONE]
+
+`
+ stream := streaming.NewObservedSSEStream(
+ io.NopCloser(strings.NewReader(streamData)),
+ NewStreamLogObserver(logger, entry, "/v1/messages"),
+ )
+
+ if _, err := io.Copy(io.Discard, stream); err != nil {
+ t.Fatalf("failed to read stream: %v", err)
+ }
+ if err := stream.Close(); err != nil {
+ t.Fatalf("failed to close stream: %v", err)
+ }
+ if err := logger.Close(); err != nil {
+ t.Fatalf("failed to close logger: %v", err)
+ }
+
+ entries := store.getEntries()
+ if len(entries) != 1 {
+ t.Fatalf("expected 1 entry, got %d", len(entries))
+ }
+ logged := entries[0]
+ if logged.Data == nil {
+ t.Fatal("Data = nil")
+ }
+ body, ok := logged.Data.ResponseBody.(map[string]any)
+ if !ok {
+ t.Fatalf("ResponseBody = %#v, want map", logged.Data.ResponseBody)
+ }
+ if body["id"] != "msg_123" {
+ t.Fatalf("id = %#v, want msg_123", body["id"])
+ }
+ if body["model"] != "claude-sonnet-4-5" {
+ t.Fatalf("model = %#v, want claude-sonnet-4-5", body["model"])
+ }
+ if body["role"] != "assistant" {
+ t.Fatalf("role = %#v, want assistant", body["role"])
+ }
+ if body["stop_reason"] != "end_turn" {
+ t.Fatalf("stop_reason = %#v, want end_turn", body["stop_reason"])
+ }
+ content, ok := body["content"].([]map[string]any)
+ if ok {
+ if len(content) != 1 || content[0]["text"] != "Hello world" {
+ t.Fatalf("content = %#v, want Hello world", content)
+ }
+ return
+ }
+ contentAny, ok := body["content"].([]any)
+ if !ok || len(contentAny) != 1 {
+ t.Fatalf("content = %#v, want single text block", body["content"])
+ }
+ first, ok := contentAny[0].(map[string]any)
+ if !ok || first["text"] != "Hello world" {
+ t.Fatalf("content[0] = %#v, want Hello world", contentAny[0])
+ }
+}
+
func TestHashAPIKey(t *testing.T) {
tests := []struct {
name string
diff --git a/internal/auditlog/stream_observer.go b/internal/auditlog/stream_observer.go
index 1a93a19e..3be74b7a 100644
--- a/internal/auditlog/stream_observer.go
+++ b/internal/auditlog/stream_observer.go
@@ -24,8 +24,10 @@ func NewStreamLogObserver(logger LoggerInterface, entry *LogEntry, path string)
logBodies := logger.Config().LogBodies
var builder *streamResponseBuilder
if logBodies {
+ isResponsesAPI := strings.HasPrefix(path, "/v1/responses")
builder = &streamResponseBuilder{
- IsResponsesAPI: strings.HasPrefix(path, "/v1/responses"),
+ IsResponsesAPI: isResponsesAPI,
+ IsAnthropicMessages: !isResponsesAPI && strings.HasPrefix(path, "/v1/messages"),
}
}
@@ -46,6 +48,10 @@ func (o *StreamLogObserver) OnJSONEvent(event map[string]any) {
o.parseResponsesAPIEvent(event)
return
}
+ if o.builder.IsAnthropicMessages {
+ o.parseAnthropicMessagesEvent(event)
+ return
+ }
o.parseChatCompletionEvent(event)
}
@@ -62,6 +68,8 @@ func (o *StreamLogObserver) OnStreamClose() {
if o.logBodies && o.builder != nil && o.entry != nil && o.entry.Data != nil {
if o.builder.IsResponsesAPI {
o.entry.Data.ResponseBody = o.builder.buildResponsesAPIResponse()
+ } else if o.builder.IsAnthropicMessages {
+ o.entry.Data.ResponseBody = o.builder.buildAnthropicMessageResponse()
} else {
o.entry.Data.ResponseBody = o.builder.buildChatCompletionResponse()
}
@@ -137,6 +145,42 @@ func (o *StreamLogObserver) parseResponsesAPIEvent(event map[string]any) {
}
}
+func (o *StreamLogObserver) parseAnthropicMessagesEvent(event map[string]any) {
+ if o.builder == nil {
+ return
+ }
+
+ eventType, _ := event["type"].(string)
+ switch eventType {
+ case "message_start":
+ if message, ok := event["message"].(map[string]any); ok {
+ if id, ok := message["id"].(string); ok {
+ o.builder.ID = id
+ }
+ if model, ok := message["model"].(string); ok {
+ o.builder.Model = model
+ }
+ if role, ok := message["role"].(string); ok {
+ o.builder.Role = role
+ }
+ }
+ case "content_block_delta":
+ if delta, ok := event["delta"].(map[string]any); ok {
+ if deltaType, _ := delta["type"].(string); deltaType == "text_delta" {
+ if text, ok := delta["text"].(string); ok && text != "" {
+ o.appendContent(text)
+ }
+ }
+ }
+ case "message_delta":
+ if delta, ok := event["delta"].(map[string]any); ok {
+ if stopReason, ok := delta["stop_reason"].(string); ok && stopReason != "" {
+ o.builder.FinishReason = stopReason
+ }
+ }
+ }
+}
+
func (o *StreamLogObserver) appendContent(content string) {
if o.builder == nil || o.builder.truncated || o.builder.contentLen >= MaxContentCapture {
return
diff --git a/internal/auditlog/stream_wrapper.go b/internal/auditlog/stream_wrapper.go
index 9b8b07a0..6606dc0f 100644
--- a/internal/auditlog/stream_wrapper.go
+++ b/internal/auditlog/stream_wrapper.go
@@ -23,6 +23,9 @@ type streamResponseBuilder struct {
CreatedAt int64
Status string
+ // Anthropic Messages fields
+ IsAnthropicMessages bool
+
// Tracking
contentLen int // track content length to enforce limit
truncated bool
@@ -71,6 +74,22 @@ func (b *streamResponseBuilder) buildResponsesAPIResponse() map[string]any {
}
}
+func (b *streamResponseBuilder) buildAnthropicMessageResponse() map[string]any {
+ return map[string]any{
+ "id": b.ID,
+ "type": "message",
+ "role": b.Role,
+ "model": b.Model,
+ "content": []map[string]any{
+ {
+ "type": "text",
+ "text": b.Content.String(),
+ },
+ },
+ "stop_reason": b.FinishReason,
+ }
+}
+
// CreateStreamEntry creates a new log entry for a streaming request.
// This should be called before starting the stream.
func CreateStreamEntry(baseEntry *LogEntry) *LogEntry {
diff --git a/internal/server/forward_proxy.go b/internal/server/forward_proxy.go
new file mode 100644
index 00000000..38830456
--- /dev/null
+++ b/internal/server/forward_proxy.go
@@ -0,0 +1,859 @@
+package server
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "crypto"
+ "crypto/ecdsa"
+ "crypto/elliptic"
+ "crypto/rand"
+ "crypto/sha256"
+ "crypto/tls"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "encoding/hex"
+ "encoding/json"
+ "encoding/pem"
+ "errors"
+ "fmt"
+ "io"
+ "log/slog"
+ "math/big"
+ "net"
+ "net/http"
+ "net/url"
+ "os"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/google/uuid"
+
+ "gomodel/internal/auditlog"
+ "gomodel/internal/core"
+ "gomodel/internal/streaming"
+ "gomodel/internal/usage"
+)
+
+type ForwardProxyConfig struct {
+ Enabled bool
+ MITMHosts []string
+ CACertFile string
+ CAKeyFile string
+ AuditLogger auditlog.LoggerInterface
+ UsageLogger usage.LoggerInterface
+ PricingResolver usage.PricingResolver
+ Transport *http.Transport
+}
+
+type forwardProxyHandler struct {
+ api http.Handler
+ auditLogger auditlog.LoggerInterface
+ usageLogger usage.LoggerInterface
+ pricingResolver usage.PricingResolver
+ transport *http.Transport
+ authority *mitmCertificateAuthority
+ mitmHosts map[string]struct{}
+}
+
+type mitmCertificateAuthority struct {
+ cert *x509.Certificate
+ key crypto.Signer
+ cache sync.Map
+}
+
+var errForwardProxyCloseConnection = errors.New("forward proxy close connection")
+
+func NewForwardProxyHandler(api http.Handler, cfg *ForwardProxyConfig) (http.Handler, error) {
+ if api == nil {
+ return nil, fmt.Errorf("forward proxy requires an API handler")
+ }
+ if cfg == nil || !cfg.Enabled {
+ return api, nil
+ }
+
+ handler := &forwardProxyHandler{
+ api: api,
+ auditLogger: cfg.AuditLogger,
+ usageLogger: cfg.UsageLogger,
+ pricingResolver: cfg.PricingResolver,
+ transport: cloneProxyTransport(cfg.Transport),
+ mitmHosts: normalizeProxyHosts(cfg.MITMHosts),
+ }
+
+ if len(handler.mitmHosts) > 0 {
+ authority, err := loadMITMCertificateAuthority(cfg.CACertFile, cfg.CAKeyFile)
+ if err != nil {
+ return nil, err
+ }
+ handler.authority = authority
+ }
+
+ return handler, nil
+}
+
+func cloneProxyTransport(transport *http.Transport) *http.Transport {
+ if transport == nil {
+ base := http.DefaultTransport.(*http.Transport).Clone()
+ base.Proxy = nil
+ return base
+ }
+ cloned := transport.Clone()
+ cloned.Proxy = nil
+ return cloned
+}
+
+func normalizeProxyHosts(hosts []string) map[string]struct{} {
+ if len(hosts) == 0 {
+ return nil
+ }
+ result := make(map[string]struct{}, len(hosts))
+ for _, host := range hosts {
+ canonical := canonicalProxyHost(host)
+ if canonical == "" {
+ continue
+ }
+ result[canonical] = struct{}{}
+ }
+ return result
+}
+
+func (h *forwardProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ if !isForwardProxyRequest(r) {
+ h.api.ServeHTTP(w, r)
+ return
+ }
+
+ switch r.Method {
+ case http.MethodConnect:
+ h.handleConnect(w, r)
+ default:
+ h.handleHTTPProxy(w, r)
+ }
+}
+
+func isForwardProxyRequest(r *http.Request) bool {
+ if r == nil {
+ return false
+ }
+ return r.Method == http.MethodConnect || (r.URL != nil && r.URL.IsAbs())
+}
+
+func (h *forwardProxyHandler) handleConnect(w http.ResponseWriter, r *http.Request) {
+ targetAddr := targetAddress(r.Host, "443")
+ targetHost := canonicalProxyHost(targetAddr)
+ if _, ok := h.mitmHosts[targetHost]; ok && h.authority != nil {
+ slog.Info("forward proxy CONNECT", "target", targetAddr, "mode", "mitm")
+ h.handleMITMConnect(w, r, targetAddr, targetHost)
+ return
+ }
+ slog.Info("forward proxy CONNECT", "target", targetAddr, "mode", "tunnel")
+ h.handleTunnelConnect(w, targetAddr)
+}
+
+func (h *forwardProxyHandler) handleTunnelConnect(w http.ResponseWriter, targetAddr string) {
+ hijacker, ok := w.(http.Hijacker)
+ if !ok {
+ http.Error(w, "proxy hijacking is not supported", http.StatusInternalServerError)
+ return
+ }
+
+ clientConn, _, err := hijacker.Hijack()
+ if err != nil {
+ http.Error(w, "failed to hijack proxy connection", http.StatusInternalServerError)
+ return
+ }
+
+ upstreamConn, err := net.DialTimeout("tcp", targetAddr, 30*time.Second)
+ if err != nil {
+ _, _ = io.WriteString(clientConn, "HTTP/1.1 502 Bad Gateway\r\n\r\n")
+ _ = clientConn.Close()
+ return
+ }
+
+ _, _ = io.WriteString(clientConn, "HTTP/1.1 200 Connection Established\r\n\r\n")
+
+ go func() {
+ _, _ = io.Copy(upstreamConn, clientConn)
+ _ = upstreamConn.Close()
+ }()
+ go func() {
+ _, _ = io.Copy(clientConn, upstreamConn)
+ _ = clientConn.Close()
+ }()
+}
+
+func (h *forwardProxyHandler) handleMITMConnect(w http.ResponseWriter, r *http.Request, targetAddr, targetHost string) {
+ hijacker, ok := w.(http.Hijacker)
+ if !ok {
+ http.Error(w, "proxy hijacking is not supported", http.StatusInternalServerError)
+ return
+ }
+
+ clientConn, _, err := hijacker.Hijack()
+ if err != nil {
+ http.Error(w, "failed to hijack proxy connection", http.StatusInternalServerError)
+ return
+ }
+ defer func() { _ = clientConn.Close() }()
+
+ _, _ = io.WriteString(clientConn, "HTTP/1.1 200 Connection Established\r\n\r\n")
+
+ cert, err := h.authority.certificateForHost(targetHost)
+ if err != nil {
+ slog.Error("failed to mint MITM certificate", "host", targetHost, "error", err)
+ return
+ }
+
+ tlsConn := tls.Server(clientConn, &tls.Config{
+ Certificates: []tls.Certificate{*cert},
+ MinVersion: tls.VersionTLS12,
+ })
+ defer func() { _ = tlsConn.Close() }()
+
+ if err := tlsConn.Handshake(); err != nil {
+ slog.Debug("forward proxy TLS handshake failed", "host", targetHost, "error", err)
+ return
+ }
+ slog.Info(
+ "forward proxy TLS handshake",
+ "host", targetHost,
+ "alpn", tlsConn.ConnectionState().NegotiatedProtocol,
+ "version", tls.VersionName(tlsConn.ConnectionState().Version),
+ )
+
+ reader := bufio.NewReader(tlsConn)
+ requestCount := 0
+ for {
+ req, err := http.ReadRequest(reader)
+ if err != nil {
+ if !errors.Is(err, io.EOF) {
+ buffered := reader.Buffered()
+ prefix := ""
+ if buffered > 0 {
+ peek, peekErr := reader.Peek(min(buffered, 32))
+ if peekErr == nil {
+ prefix = sanitizeProxyPreview(peek)
+ }
+ }
+ slog.Info(
+ "forward proxy request read failed",
+ "host", targetHost,
+ "error", err,
+ "alpn", tlsConn.ConnectionState().NegotiatedProtocol,
+ "requests_served", requestCount,
+ "buffered", buffered,
+ "prefix", prefix,
+ )
+ }
+ return
+ }
+ requestCount++
+ if err := h.serveMITMRequest(tlsConn, req, targetAddr); err != nil {
+ if errors.Is(err, errForwardProxyCloseConnection) {
+ return
+ }
+ slog.Debug("forward proxy MITM request failed", "host", targetHost, "error", err)
+ return
+ }
+ }
+}
+
+func (h *forwardProxyHandler) serveMITMRequest(clientConn net.Conn, req *http.Request, targetAddr string) error {
+ defer func() {
+ if req.Body != nil {
+ _ = req.Body.Close()
+ }
+ }()
+
+ requestID := ensureProxyRequestID(req.Header)
+ start := time.Now().UTC()
+ slog.Info(
+ "forward proxy request",
+ "target", targetAddr,
+ "method", req.Method,
+ "path", req.URL.Path,
+ "content_length", req.ContentLength,
+ "expect", strings.TrimSpace(req.Header.Get("Expect")),
+ "user_agent", req.UserAgent(),
+ "request_id", requestID,
+ )
+ bodyBytes, err := io.ReadAll(req.Body)
+ if err != nil {
+ return err
+ }
+ model, streamRequested := extractProxyRequestInfo(bodyBytes)
+ if req.URL.Path == "/api/event_logging/v2/batch" {
+ return h.serveSyntheticEventLoggingSuccess(clientConn, req, start, requestID, model, bodyBytes)
+ }
+
+ upstreamReq := req.Clone(context.Background())
+ upstreamReq.URL = &url.URL{
+ Scheme: "https",
+ Host: targetAddr,
+ Path: req.URL.Path,
+ RawPath: req.URL.RawPath,
+ RawQuery: req.URL.RawQuery,
+ }
+ upstreamReq.RequestURI = ""
+ upstreamReq.Host = targetAddr
+ upstreamReq.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+ upstreamReq.ContentLength = int64(len(bodyBytes))
+ upstreamReq.Header = cloneHeader(req.Header)
+ stripProxyHeaders(upstreamReq.Header)
+ upstreamReq.Header.Del("Accept-Encoding")
+
+ resp, err := h.transport.RoundTrip(upstreamReq)
+ if err != nil {
+ entry := h.newProxyAuditEntry(start, requestID, req, req.URL.Path, model)
+ entry.StatusCode = http.StatusBadGateway
+ entry.ErrorType = "proxy_error"
+ entry.Data.ErrorMessage = err.Error()
+ h.writeAuditEntry(entry)
+ _, _ = io.WriteString(clientConn, "HTTP/1.1 502 Bad Gateway\r\nContent-Length: 0\r\n\r\n")
+ return err
+ }
+ defer func() { _ = resp.Body.Close() }()
+ slog.Info(
+ "forward proxy upstream response",
+ "target", targetAddr,
+ "method", req.Method,
+ "path", req.URL.Path,
+ "status", resp.StatusCode,
+ "request_id", requestID,
+ )
+
+ entry := h.newProxyAuditEntry(start, requestID, req, req.URL.Path, model)
+ entry.StatusCode = resp.StatusCode
+ entry.Data.RequestBody, entry.Data.RequestBodyTooBigToHandle = proxyBodyForAudit(bodyBytes)
+ entry.Data.ResponseHeaders = proxyHeaderMap(resp.Header)
+
+ if isEventStreamHeader(resp.Header) || streamRequested {
+ return h.serveMITMStreamResponse(clientConn, req, resp, entry)
+ }
+ return h.serveMITMBufferedResponse(clientConn, req, resp, entry, bodyBytes)
+}
+
+func (h *forwardProxyHandler) serveSyntheticEventLoggingSuccess(clientConn net.Conn, req *http.Request, start time.Time, requestID, model string, requestBody []byte) error {
+ entry := h.newProxyAuditEntry(start, requestID, req, req.URL.Path, model)
+ entry.StatusCode = http.StatusNoContent
+ entry.DurationNs = time.Since(start).Nanoseconds()
+ entry.Data.RequestBody, entry.Data.RequestBodyTooBigToHandle = proxyBodyForAudit(requestBody)
+ entry.Data.ResponseHeaders = map[string]string{
+ "Connection": "close",
+ "Content-Length": "0",
+ }
+ h.writeAuditEntry(entry)
+
+ if _, err := io.WriteString(clientConn, "HTTP/1.1 204 No Content\r\nConnection: close\r\nContent-Length: 0\r\n\r\n"); err != nil {
+ return err
+ }
+ return errForwardProxyCloseConnection
+}
+
+func (h *forwardProxyHandler) serveMITMBufferedResponse(clientConn net.Conn, req *http.Request, resp *http.Response, entry *auditlog.LogEntry, requestBody []byte) error {
+ respBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return err
+ }
+
+ entry.DurationNs = time.Since(entry.Timestamp).Nanoseconds()
+ entry.Data.RequestBody, entry.Data.RequestBodyTooBigToHandle = proxyBodyForAudit(requestBody)
+ entry.Data.ResponseBody, entry.Data.ResponseBodyTooBigToHandle = proxyBodyForAudit(respBody)
+
+ if usageEntry := h.extractAnthropicUsageEntry(respBody, entry.RequestID, entry.Model, req.URL.Path); usageEntry != nil {
+ h.writeUsageEntry(usageEntry)
+ if entry.Model == "" && usageEntry.Model != "" {
+ entry.Model = usageEntry.Model
+ }
+ }
+ h.writeAuditEntry(entry)
+
+ resp.Body = io.NopCloser(bytes.NewReader(respBody))
+ resp.Close = true
+ resp.Header.Set("Connection", "close")
+ normalizeResponseForProxyWrite(resp)
+ if err := resp.Write(clientConn); err != nil {
+ return err
+ }
+ return errForwardProxyCloseConnection
+}
+
+func (h *forwardProxyHandler) serveMITMStreamResponse(clientConn net.Conn, req *http.Request, resp *http.Response, entry *auditlog.LogEntry) error {
+ observers := make([]streaming.Observer, 0, 2)
+ if h.auditLogger != nil && h.auditLogger.Config().Enabled {
+ if observer := auditlog.NewStreamLogObserver(h.auditLogger, entry, req.URL.Path); observer != nil {
+ observers = append(observers, observer)
+ }
+ }
+ if h.usageLogger != nil && h.usageLogger.Config().Enabled {
+ if observer := usage.NewStreamUsageObserver(h.usageLogger, entry.Model, entry.Provider, entry.RequestID, req.URL.Path, h.pricingResolver); observer != nil {
+ observers = append(observers, observer)
+ }
+ }
+
+ wrappedStream := streaming.NewObservedSSEStream(resp.Body, observers...)
+ defer func() { _ = wrappedStream.Close() }()
+
+ resp.Body = wrappedStream
+ resp.Close = true
+ resp.Header.Set("Connection", "close")
+ normalizeResponseForProxyWrite(resp)
+ if err := resp.Write(clientConn); err != nil {
+ return err
+ }
+ return errForwardProxyCloseConnection
+}
+
+func (h *forwardProxyHandler) handleHTTPProxy(w http.ResponseWriter, r *http.Request) {
+ start := time.Now().UTC()
+ requestID := ensureProxyRequestID(r.Header)
+
+ bodyBytes, err := io.ReadAll(r.Body)
+ if err != nil {
+ http.Error(w, "failed to read proxy request body", http.StatusBadRequest)
+ return
+ }
+ model, _ := extractProxyRequestInfo(bodyBytes)
+
+ upstreamReq := r.Clone(context.Background())
+ upstreamReq.RequestURI = ""
+ upstreamReq.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+ upstreamReq.ContentLength = int64(len(bodyBytes))
+ upstreamReq.Header = cloneHeader(r.Header)
+ stripProxyHeaders(upstreamReq.Header)
+ upstreamReq.Header.Del("Accept-Encoding")
+
+ resp, err := h.transport.RoundTrip(upstreamReq)
+ if err != nil {
+ entry := h.newProxyAuditEntry(start, requestID, r, r.URL.Path, model)
+ entry.StatusCode = http.StatusBadGateway
+ entry.ErrorType = "proxy_error"
+ entry.Data.ErrorMessage = err.Error()
+ h.writeAuditEntry(entry)
+ http.Error(w, "proxy upstream request failed", http.StatusBadGateway)
+ return
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ respBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ http.Error(w, "failed to read proxy upstream response", http.StatusBadGateway)
+ return
+ }
+
+ entry := h.newProxyAuditEntry(start, requestID, r, r.URL.Path, model)
+ entry.StatusCode = resp.StatusCode
+ entry.DurationNs = time.Since(start).Nanoseconds()
+ entry.Data.RequestBody, entry.Data.RequestBodyTooBigToHandle = proxyBodyForAudit(bodyBytes)
+ entry.Data.ResponseBody, entry.Data.ResponseBodyTooBigToHandle = proxyBodyForAudit(respBody)
+ entry.Data.ResponseHeaders = proxyHeaderMap(resp.Header)
+ h.writeAuditEntry(entry)
+
+ if usageEntry := h.extractAnthropicUsageEntry(respBody, requestID, model, r.URL.Path); usageEntry != nil {
+ h.writeUsageEntry(usageEntry)
+ }
+
+ copyProxyResponseHeaders(w.Header(), resp.Header)
+ w.WriteHeader(resp.StatusCode)
+ _, _ = w.Write(respBody)
+}
+
+func (h *forwardProxyHandler) newProxyAuditEntry(start time.Time, requestID string, req *http.Request, path, model string) *auditlog.LogEntry {
+ entry := &auditlog.LogEntry{
+ ID: uuid.NewString(),
+ Timestamp: start,
+ RequestID: requestID,
+ ClientIP: proxyClientIP(req.RemoteAddr),
+ Method: req.Method,
+ Path: path,
+ Model: model,
+ Provider: "anthropic",
+ StatusCode: http.StatusOK,
+ Data: &auditlog.LogData{
+ UserAgent: req.UserAgent(),
+ APIKeyHash: proxyCredentialHash(req.Header),
+ RequestHeaders: proxyHeaderMap(req.Header),
+ },
+ }
+ return entry
+}
+
+func (h *forwardProxyHandler) writeAuditEntry(entry *auditlog.LogEntry) {
+ if h.auditLogger == nil || !h.auditLogger.Config().Enabled || entry == nil {
+ return
+ }
+ h.auditLogger.Write(entry)
+}
+
+func (h *forwardProxyHandler) writeUsageEntry(entry *usage.UsageEntry) {
+ if h.usageLogger == nil || !h.usageLogger.Config().Enabled || entry == nil {
+ return
+ }
+ h.usageLogger.Write(entry)
+}
+
+func (h *forwardProxyHandler) extractAnthropicUsageEntry(body []byte, requestID, model, endpoint string) *usage.UsageEntry {
+ var payload map[string]any
+ if err := json.Unmarshal(body, &payload); err != nil {
+ return nil
+ }
+
+ usageRaw, ok := payload["usage"].(map[string]any)
+ if !ok {
+ return nil
+ }
+
+ providerID, _ := payload["id"].(string)
+ if responseModel, _ := payload["model"].(string); responseModel != "" {
+ model = responseModel
+ }
+
+ inputTokens := int(floatFromMap(usageRaw, "input_tokens"))
+ outputTokens := int(floatFromMap(usageRaw, "output_tokens"))
+ totalTokens := int(floatFromMap(usageRaw, "total_tokens"))
+ if totalTokens == 0 && (inputTokens > 0 || outputTokens > 0) {
+ totalTokens = inputTokens + outputTokens
+ }
+ if inputTokens == 0 && outputTokens == 0 && totalTokens == 0 {
+ return nil
+ }
+
+ rawData := make(map[string]any)
+ for _, key := range []string{"cache_creation_input_tokens", "cache_read_input_tokens"} {
+ if value := int(floatFromMap(usageRaw, key)); value > 0 {
+ rawData[key] = value
+ }
+ }
+ if len(rawData) == 0 {
+ rawData = nil
+ }
+
+ var pricingArgs []*core.ModelPricing
+ if h.pricingResolver != nil {
+ if pricing := h.pricingResolver.ResolvePricing(model, "anthropic"); pricing != nil {
+ pricingArgs = append(pricingArgs, pricing)
+ }
+ }
+
+ return usage.ExtractFromSSEUsage(
+ providerID,
+ inputTokens, outputTokens, totalTokens,
+ rawData,
+ requestID, model, "anthropic", endpoint,
+ pricingArgs...,
+ )
+}
+
+func extractProxyRequestInfo(body []byte) (model string, stream bool) {
+ if len(body) == 0 {
+ return "", false
+ }
+
+ var payload map[string]any
+ if err := json.Unmarshal(body, &payload); err != nil {
+ return "", false
+ }
+ if m, ok := payload["model"].(string); ok {
+ model = strings.TrimSpace(m)
+ }
+ if s, ok := payload["stream"].(bool); ok {
+ stream = s
+ }
+ return model, stream
+}
+
+func ensureProxyRequestID(headers http.Header) string {
+ if headers == nil {
+ return uuid.NewString()
+ }
+ if requestID := strings.TrimSpace(headers.Get("X-Request-ID")); requestID != "" {
+ return requestID
+ }
+ requestID := uuid.NewString()
+ headers.Set("X-Request-ID", requestID)
+ return requestID
+}
+
+func proxyBodyForAudit(body []byte) (any, bool) {
+ if len(body) == 0 {
+ return nil, false
+ }
+ if len(body) > auditlog.MaxBodyCapture {
+ return nil, true
+ }
+
+ var parsed any
+ if err := json.Unmarshal(body, &parsed); err == nil {
+ return parsed, false
+ }
+ return string(body), false
+}
+
+func proxyHeaderMap(headers http.Header) map[string]string {
+ if len(headers) == 0 {
+ return nil
+ }
+ result := make(map[string]string, len(headers))
+ for key, values := range headers {
+ if len(values) == 0 {
+ continue
+ }
+ result[key] = values[0]
+ }
+ return auditlog.RedactHeaders(result)
+}
+
+func proxyCredentialHash(headers http.Header) string {
+ if headers == nil {
+ return ""
+ }
+ token := strings.TrimSpace(headers.Get("Authorization"))
+ if token == "" {
+ token = strings.TrimSpace(headers.Get("Cookie"))
+ }
+ if token == "" {
+ return ""
+ }
+ hash := sha256.Sum256([]byte(token))
+ return hex.EncodeToString(hash[:])[:auditlog.APIKeyHashPrefixLength]
+}
+
+func proxyClientIP(remoteAddr string) string {
+ host, _, err := net.SplitHostPort(strings.TrimSpace(remoteAddr))
+ if err != nil {
+ return strings.TrimSpace(remoteAddr)
+ }
+ return host
+}
+
+func targetAddress(hostPort, defaultPort string) string {
+ hostPort = strings.TrimSpace(hostPort)
+ if hostPort == "" {
+ return ""
+ }
+ if _, _, err := net.SplitHostPort(hostPort); err == nil {
+ return hostPort
+ }
+ return net.JoinHostPort(hostPort, defaultPort)
+}
+
+func normalizeResponseForProxyWrite(resp *http.Response) {
+ if resp == nil {
+ return
+ }
+ resp.Proto = "HTTP/1.1"
+ resp.ProtoMajor = 1
+ resp.ProtoMinor = 1
+}
+
+func canonicalProxyHost(host string) string {
+ host = strings.TrimSpace(host)
+ if host == "" {
+ return ""
+ }
+ if parsedHost, _, err := net.SplitHostPort(host); err == nil {
+ host = parsedHost
+ }
+ return strings.ToLower(strings.TrimSpace(host))
+}
+
+func cloneHeader(headers http.Header) http.Header {
+ if headers == nil {
+ return nil
+ }
+ cloned := make(http.Header, len(headers))
+ for key, values := range headers {
+ dst := make([]string, len(values))
+ copy(dst, values)
+ cloned[key] = dst
+ }
+ return cloned
+}
+
+func stripProxyHeaders(headers http.Header) {
+ if headers == nil {
+ return
+ }
+ for _, key := range []string{
+ "Proxy-Connection",
+ "Proxy-Authorization",
+ "Proxy-Authenticate",
+ "Connection",
+ } {
+ headers.Del(key)
+ }
+}
+
+func copyProxyResponseHeaders(dst, src http.Header) {
+ if dst == nil || src == nil {
+ return
+ }
+ connectionHeaders := passthroughConnectionHeaders(src)
+ for key, values := range src {
+ canonicalKey := http.CanonicalHeaderKey(strings.TrimSpace(key))
+ if len(values) == 0 {
+ continue
+ }
+ if _, hopByHop := connectionHeaders[canonicalKey]; hopByHop {
+ continue
+ }
+ switch canonicalKey {
+ case "Connection", "Keep-Alive", "Proxy-Authenticate", "Proxy-Authorization", "Te", "Trailer", "Transfer-Encoding", "Upgrade":
+ continue
+ }
+ dst.Del(canonicalKey)
+ for _, value := range values {
+ dst.Add(canonicalKey, value)
+ }
+ }
+}
+
+func sanitizeProxyPreview(data []byte) string {
+ if len(data) == 0 {
+ return ""
+ }
+ const maxLen = 32
+ if len(data) > maxLen {
+ data = data[:maxLen]
+ }
+ out := make([]byte, 0, len(data))
+ for _, b := range data {
+ if b >= 32 && b <= 126 {
+ out = append(out, b)
+ continue
+ }
+ switch b {
+ case '\r':
+ out = append(out, '\\', 'r')
+ case '\n':
+ out = append(out, '\\', 'n')
+ case '\t':
+ out = append(out, '\\', 't')
+ default:
+ out = append(out, '.')
+ }
+ }
+ return string(out)
+}
+
+func isEventStreamHeader(headers http.Header) bool {
+ for key, values := range headers {
+ if !strings.EqualFold(key, "Content-Type") {
+ continue
+ }
+ for _, value := range values {
+ if strings.Contains(strings.ToLower(value), "text/event-stream") {
+ return true
+ }
+ }
+ }
+ return false
+}
+
+func floatFromMap(values map[string]any, key string) float64 {
+ value, ok := values[key]
+ if !ok {
+ return 0
+ }
+ if number, ok := value.(float64); ok {
+ return number
+ }
+ return 0
+}
+
+func loadMITMCertificateAuthority(certPath, keyPath string) (*mitmCertificateAuthority, error) {
+ certPEM, err := os.ReadFile(certPath)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read forward proxy CA certificate: %w", err)
+ }
+ keyPEM, err := os.ReadFile(keyPath)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read forward proxy CA key: %w", err)
+ }
+
+ pair, err := tls.X509KeyPair(certPEM, keyPEM)
+ if err != nil {
+ return nil, fmt.Errorf("failed to load forward proxy CA keypair: %w", err)
+ }
+
+ leaf, err := x509.ParseCertificate(pair.Certificate[0])
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse forward proxy CA certificate: %w", err)
+ }
+
+ signer, ok := pair.PrivateKey.(crypto.Signer)
+ if !ok {
+ return nil, fmt.Errorf("forward proxy CA key does not implement crypto.Signer")
+ }
+
+ return &mitmCertificateAuthority{
+ cert: leaf,
+ key: signer,
+ }, nil
+}
+
+func (a *mitmCertificateAuthority) certificateForHost(host string) (*tls.Certificate, error) {
+ if cached, ok := a.cache.Load(host); ok {
+ if cert, ok := cached.(*tls.Certificate); ok {
+ return cert, nil
+ }
+ }
+
+ privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+ if err != nil {
+ return nil, err
+ }
+
+ serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
+ if err != nil {
+ return nil, err
+ }
+
+ template := &x509.Certificate{
+ SerialNumber: serial,
+ Subject: pkix.Name{
+ CommonName: host,
+ },
+ NotBefore: time.Now().Add(-5 * time.Minute),
+ NotAfter: minTime(a.cert.NotAfter, time.Now().Add(24*time.Hour)),
+ KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
+ ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
+ BasicConstraintsValid: true,
+ }
+ if ip := net.ParseIP(host); ip != nil {
+ template.IPAddresses = []net.IP{ip}
+ } else {
+ template.DNSNames = []string{host}
+ }
+
+ der, err := x509.CreateCertificate(rand.Reader, template, a.cert, privateKey.Public(), a.key)
+ if err != nil {
+ return nil, err
+ }
+
+ leaf, err := x509.ParseCertificate(der)
+ if err != nil {
+ return nil, err
+ }
+
+ cert := &tls.Certificate{
+ Certificate: [][]byte{der, a.cert.Raw},
+ PrivateKey: privateKey,
+ Leaf: leaf,
+ }
+ a.cache.Store(host, cert)
+ return cert, nil
+}
+
+func minTime(first, second time.Time) time.Time {
+ if first.Before(second) {
+ return first
+ }
+ return second
+}
+
+func encodePEMBlock(blockType string, der []byte) []byte {
+ return pem.EncodeToMemory(&pem.Block{Type: blockType, Bytes: der})
+}
diff --git a/internal/server/forward_proxy_test.go b/internal/server/forward_proxy_test.go
new file mode 100644
index 00000000..e806603b
--- /dev/null
+++ b/internal/server/forward_proxy_test.go
@@ -0,0 +1,470 @@
+package server
+
+import (
+ "bytes"
+ "crypto/ecdsa"
+ "crypto/elliptic"
+ "crypto/rand"
+ "crypto/tls"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "encoding/json"
+ "io"
+ "math/big"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+ "time"
+
+ "gomodel/internal/auditlog"
+ "gomodel/internal/usage"
+)
+
+func TestForwardProxyMITMAnthropicJSONUsageAndAudit(t *testing.T) {
+ upstream := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path != "/v1/messages" {
+ t.Fatalf("path = %q, want /v1/messages", r.URL.Path)
+ }
+ body, err := io.ReadAll(r.Body)
+ if err != nil {
+ t.Fatalf("failed to read upstream request body: %v", err)
+ }
+ if !strings.Contains(string(body), `"model":"claude-sonnet-4-5"`) {
+ t.Fatalf("unexpected request body: %s", body)
+ }
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"id":"msg_123","type":"message","role":"assistant","model":"claude-sonnet-4-5","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":10,"output_tokens":2,"cache_read_input_tokens":6}}`))
+ }))
+ defer upstream.Close()
+
+ caCertPath, caKeyPath, caCert := writeTestCAFiles(t)
+ proxyURL, proxyAudit, proxyUsage := startTestForwardProxy(t, upstream, caCertPath, caKeyPath)
+
+ client := newMITMHTTPClient(t, proxyURL, caCert)
+ req, err := http.NewRequest(http.MethodPost, upstream.URL+"/v1/messages", strings.NewReader(`{"model":"claude-sonnet-4-5","max_tokens":16,"messages":[{"role":"user","content":"Reply with ok"}]}`))
+ if err != nil {
+ t.Fatalf("NewRequest error: %v", err)
+ }
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("User-Agent", "claude-code-test")
+
+ resp, err := client.Do(req)
+ if err != nil {
+ t.Fatalf("client.Do error: %v", err)
+ }
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ t.Fatalf("ReadAll error: %v", err)
+ }
+ if err := resp.Body.Close(); err != nil {
+ t.Fatalf("Close error: %v", err)
+ }
+ if !strings.Contains(string(body), `"id":"msg_123"`) {
+ t.Fatalf("unexpected response body: %s", body)
+ }
+
+ auditEntries := waitForAuditEntries(t, proxyAudit, 1)
+ auditEntry := auditEntries[0]
+ if auditEntry.Provider != "anthropic" {
+ t.Fatalf("Provider = %q, want anthropic", auditEntry.Provider)
+ }
+ if auditEntry.Path != "/v1/messages" {
+ t.Fatalf("Path = %q, want /v1/messages", auditEntry.Path)
+ }
+ if auditEntry.Model != "claude-sonnet-4-5" {
+ t.Fatalf("Model = %q, want claude-sonnet-4-5", auditEntry.Model)
+ }
+ if auditEntry.Data == nil || auditEntry.Data.RequestBody == nil || auditEntry.Data.ResponseBody == nil {
+ t.Fatal("expected request and response bodies to be captured")
+ }
+
+ usageEntries := waitForUsageEntries(t, proxyUsage, 1)
+ usageEntry := usageEntries[0]
+ if usageEntry.Provider != "anthropic" {
+ t.Fatalf("Provider = %q, want anthropic", usageEntry.Provider)
+ }
+ if usageEntry.Endpoint != "/v1/messages" {
+ t.Fatalf("Endpoint = %q, want /v1/messages", usageEntry.Endpoint)
+ }
+ if usageEntry.InputTokens != 10 {
+ t.Fatalf("InputTokens = %d, want 10", usageEntry.InputTokens)
+ }
+ if usageEntry.OutputTokens != 2 {
+ t.Fatalf("OutputTokens = %d, want 2", usageEntry.OutputTokens)
+ }
+ if usageEntry.TotalTokens != 12 {
+ t.Fatalf("TotalTokens = %d, want 12", usageEntry.TotalTokens)
+ }
+ if usageEntry.RawData["cache_read_input_tokens"] != 6 {
+ t.Fatalf("RawData[cache_read_input_tokens] = %v, want 6", usageEntry.RawData["cache_read_input_tokens"])
+ }
+}
+
+func TestForwardProxyMITMAnthropicStreamingUsageAndAudit(t *testing.T) {
+ upstream := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "text/event-stream")
+ _, _ = w.Write([]byte(`event: message_start
+data: {"type":"message_start","message":{"id":"msg_stream_123","type":"message","role":"assistant","model":"claude-sonnet-4-5","content":[],"usage":{"input_tokens":10,"output_tokens":0}}}
+
+event: content_block_delta
+data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}
+
+event: content_block_delta
+data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" world"}}
+
+event: message_delta
+data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":2}}
+
+data: [DONE]
+
+`))
+ }))
+ defer upstream.Close()
+
+ caCertPath, caKeyPath, caCert := writeTestCAFiles(t)
+ proxyURL, proxyAudit, proxyUsage := startTestForwardProxy(t, upstream, caCertPath, caKeyPath)
+
+ client := newMITMHTTPClient(t, proxyURL, caCert)
+ req, err := http.NewRequest(http.MethodPost, upstream.URL+"/v1/messages", strings.NewReader(`{"model":"claude-sonnet-4-5","stream":true,"messages":[{"role":"user","content":"hi"}]}`))
+ if err != nil {
+ t.Fatalf("NewRequest error: %v", err)
+ }
+ req.Header.Set("Content-Type", "application/json")
+
+ resp, err := client.Do(req)
+ if err != nil {
+ t.Fatalf("client.Do error: %v", err)
+ }
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ t.Fatalf("ReadAll error: %v", err)
+ }
+ if err := resp.Body.Close(); err != nil {
+ t.Fatalf("Close error: %v", err)
+ }
+ if !strings.Contains(string(body), "message_start") {
+ t.Fatalf("unexpected streamed body: %s", body)
+ }
+
+ auditEntries := waitForAuditEntries(t, proxyAudit, 1)
+ auditEntry := auditEntries[0]
+ if auditEntry.Data == nil {
+ t.Fatal("Data = nil")
+ }
+ responseBody, ok := auditEntry.Data.ResponseBody.(map[string]any)
+ if !ok {
+ t.Fatalf("ResponseBody = %#v, want map", auditEntry.Data.ResponseBody)
+ }
+ if responseBody["id"] != "msg_stream_123" {
+ t.Fatalf("id = %#v, want msg_stream_123", responseBody["id"])
+ }
+ if content, ok := responseBody["content"].([]map[string]any); ok {
+ if len(content) != 1 || content[0]["text"] != "Hello world" {
+ t.Fatalf("content = %#v, want Hello world", content)
+ }
+ } else {
+ contentAny, ok := responseBody["content"].([]any)
+ if !ok || len(contentAny) != 1 {
+ t.Fatalf("content = %#v, want one text block", responseBody["content"])
+ }
+ content, ok := contentAny[0].(map[string]any)
+ if !ok || content["text"] != "Hello world" {
+ t.Fatalf("content[0] = %#v, want Hello world", contentAny[0])
+ }
+ }
+
+ usageEntries := waitForUsageEntries(t, proxyUsage, 1)
+ usageEntry := usageEntries[0]
+ if usageEntry.InputTokens != 10 || usageEntry.OutputTokens != 2 || usageEntry.TotalTokens != 12 {
+ t.Fatalf("unexpected usage entry: %+v", usageEntry)
+ }
+}
+
+func waitForAuditEntries(t *testing.T, logger *capturingAuditLogger, want int) []*auditlog.LogEntry {
+ t.Helper()
+
+ deadline := time.Now().Add(2 * time.Second)
+ for {
+ entries := logger.Entries()
+ if len(entries) == want {
+ return entries
+ }
+ if time.Now().After(deadline) {
+ t.Fatalf("audit entries = %d, want %d", len(entries), want)
+ }
+ time.Sleep(10 * time.Millisecond)
+ }
+}
+
+func waitForUsageEntries(t *testing.T, logger *collectingUsageLogger, want int) []*usage.UsageEntry {
+ t.Helper()
+
+ deadline := time.Now().Add(2 * time.Second)
+ for {
+ entries := logger.Entries()
+ if len(entries) == want {
+ return entries
+ }
+ if time.Now().After(deadline) {
+ t.Fatalf("usage entries = %d, want %d", len(entries), want)
+ }
+ time.Sleep(10 * time.Millisecond)
+ }
+}
+
+func startTestForwardProxy(t *testing.T, upstream *httptest.Server, caCertPath, caKeyPath string) (string, *capturingAuditLogger, *collectingUsageLogger) {
+ t.Helper()
+
+ upstreamURL, err := url.Parse(upstream.URL)
+ if err != nil {
+ t.Fatalf("Parse upstream URL error: %v", err)
+ }
+ mitmHost := canonicalProxyHost(upstreamURL.Host)
+
+ auditLogger := &capturingAuditLogger{
+ config: auditlog.Config{
+ Enabled: true,
+ LogBodies: true,
+ LogHeaders: true,
+ },
+ }
+ usageLogger := &collectingUsageLogger{
+ config: usage.Config{Enabled: true},
+ }
+ handler, err := NewForwardProxyHandler(http.NotFoundHandler(), &ForwardProxyConfig{
+ Enabled: true,
+ MITMHosts: []string{mitmHost},
+ CACertFile: caCertPath,
+ CAKeyFile: caKeyPath,
+ AuditLogger: auditLogger,
+ UsageLogger: usageLogger,
+ Transport: &http.Transport{
+ TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
+ },
+ })
+ if err != nil {
+ t.Fatalf("NewForwardProxyHandler error: %v", err)
+ }
+
+ server := httptest.NewServer(handler)
+ t.Cleanup(server.Close)
+ return server.URL, auditLogger, usageLogger
+}
+
+func newMITMHTTPClient(t *testing.T, proxyAddr string, caCert *x509.Certificate) *http.Client {
+ t.Helper()
+
+ proxyURL, err := url.Parse(proxyAddr)
+ if err != nil {
+ t.Fatalf("Parse proxy URL error: %v", err)
+ }
+ pool := x509.NewCertPool()
+ pool.AddCert(caCert)
+
+ return &http.Client{
+ Transport: &http.Transport{
+ Proxy: http.ProxyURL(proxyURL),
+ TLSClientConfig: &tls.Config{
+ RootCAs: pool,
+ },
+ },
+ }
+}
+
+func writeTestCAFiles(t *testing.T) (string, string, *x509.Certificate) {
+ t.Helper()
+
+ privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+ if err != nil {
+ t.Fatalf("GenerateKey error: %v", err)
+ }
+
+ template := &x509.Certificate{
+ SerialNumber: big.NewInt(1),
+ Subject: pkix.Name{
+ CommonName: "gomodel-test-ca",
+ },
+ NotBefore: time.Now().Add(-time.Hour),
+ NotAfter: time.Now().Add(24 * time.Hour),
+ KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
+ BasicConstraintsValid: true,
+ IsCA: true,
+ }
+
+ der, err := x509.CreateCertificate(rand.Reader, template, template, privateKey.Public(), privateKey)
+ if err != nil {
+ t.Fatalf("CreateCertificate error: %v", err)
+ }
+ cert, err := x509.ParseCertificate(der)
+ if err != nil {
+ t.Fatalf("ParseCertificate error: %v", err)
+ }
+
+ keyDER, err := x509.MarshalPKCS8PrivateKey(privateKey)
+ if err != nil {
+ t.Fatalf("MarshalPKCS8PrivateKey error: %v", err)
+ }
+
+ dir := t.TempDir()
+ certPath := filepath.Join(dir, "ca-cert.pem")
+ keyPath := filepath.Join(dir, "ca-key.pem")
+ if err := os.WriteFile(certPath, encodePEMBlock("CERTIFICATE", der), 0o600); err != nil {
+ t.Fatalf("WriteFile cert error: %v", err)
+ }
+ if err := os.WriteFile(keyPath, encodePEMBlock("PRIVATE KEY", keyDER), 0o600); err != nil {
+ t.Fatalf("WriteFile key error: %v", err)
+ }
+
+ return certPath, keyPath, cert
+}
+
+func TestIsForwardProxyRequest(t *testing.T) {
+ absoluteURL, _ := url.Parse("https://api.anthropic.com/v1/messages")
+ tests := []struct {
+ name string
+ req *http.Request
+ want bool
+ }{
+ {
+ name: "connect",
+ req: &http.Request{Method: http.MethodConnect, URL: &url.URL{}},
+ want: true,
+ },
+ {
+ name: "absolute URL",
+ req: &http.Request{Method: http.MethodPost, URL: absoluteURL},
+ want: true,
+ },
+ {
+ name: "normal API route",
+ req: &http.Request{Method: http.MethodPost, URL: &url.URL{Path: "/v1/chat/completions"}},
+ want: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := isForwardProxyRequest(tt.req); got != tt.want {
+ t.Fatalf("isForwardProxyRequest() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestProxyBodyForAuditParsesJSON(t *testing.T) {
+ value, tooBig := proxyBodyForAudit([]byte(`{"ok":true}`))
+ if tooBig {
+ t.Fatal("tooBig = true, want false")
+ }
+ parsed, ok := value.(map[string]any)
+ if !ok {
+ t.Fatalf("value = %#v, want map", value)
+ }
+ if parsed["ok"] != true {
+ t.Fatalf("ok = %#v, want true", parsed["ok"])
+ }
+}
+
+func TestProxyBodyForAuditRejectsLargePayload(t *testing.T) {
+ value, tooBig := proxyBodyForAudit([]byte(strings.Repeat("x", auditlog.MaxBodyCapture+1)))
+ if value != nil {
+ t.Fatalf("value = %#v, want nil", value)
+ }
+ if !tooBig {
+ t.Fatal("tooBig = false, want true")
+ }
+}
+
+func TestExtractAnthropicUsageEntryComputesTotal(t *testing.T) {
+ handler := &forwardProxyHandler{}
+ entry := handler.extractAnthropicUsageEntry(
+ []byte(`{"id":"msg_123","model":"claude-sonnet-4-5","usage":{"input_tokens":10,"output_tokens":2,"cache_read_input_tokens":6}}`),
+ "req-123",
+ "claude-sonnet-4-5",
+ "/v1/messages",
+ )
+ if entry == nil {
+ t.Fatal("entry = nil")
+ }
+ if entry.TotalTokens != 12 {
+ t.Fatalf("TotalTokens = %d, want 12", entry.TotalTokens)
+ }
+ if entry.RawData["cache_read_input_tokens"] != 6 {
+ t.Fatalf("RawData[cache_read_input_tokens] = %v, want 6", entry.RawData["cache_read_input_tokens"])
+ }
+}
+
+func TestNormalizeResponseForProxyWriteForcesHTTP11(t *testing.T) {
+ resp := &http.Response{
+ StatusCode: http.StatusUnauthorized,
+ Status: "401 Unauthorized",
+ Proto: "HTTP/2.0",
+ ProtoMajor: 2,
+ ProtoMinor: 0,
+ Body: io.NopCloser(strings.NewReader(`{"error":"unauthorized"}`)),
+ ContentLength: int64(len(`{"error":"unauthorized"}`)),
+ Header: http.Header{
+ "Content-Type": []string{"application/json"},
+ },
+ }
+
+ normalizeResponseForProxyWrite(resp)
+
+ var buf bytes.Buffer
+ if err := resp.Write(&buf); err != nil {
+ t.Fatalf("resp.Write error: %v", err)
+ }
+ if !strings.HasPrefix(buf.String(), "HTTP/1.1 401 Unauthorized\r\n") {
+ t.Fatalf("unexpected wire response prefix: %q", buf.String())
+ }
+}
+
+func TestWriteTestCAFilesProducesParseablePEM(t *testing.T) {
+ certPath, keyPath, cert := writeTestCAFiles(t)
+ if cert == nil {
+ t.Fatal("cert = nil")
+ }
+ certPEM, err := os.ReadFile(certPath)
+ if err != nil {
+ t.Fatalf("ReadFile cert error: %v", err)
+ }
+ keyPEM, err := os.ReadFile(keyPath)
+ if err != nil {
+ t.Fatalf("ReadFile key error: %v", err)
+ }
+ pair, err := tls.X509KeyPair(certPEM, keyPEM)
+ if err != nil {
+ t.Fatalf("X509KeyPair error: %v", err)
+ }
+ leaf, err := x509.ParseCertificate(pair.Certificate[0])
+ if err != nil {
+ t.Fatalf("ParseCertificate error: %v", err)
+ }
+ if leaf.Subject.CommonName != "gomodel-test-ca" {
+ t.Fatalf("CommonName = %q, want gomodel-test-ca", leaf.Subject.CommonName)
+ }
+}
+
+func TestProxyHeaderMapRedactsSensitiveHeaders(t *testing.T) {
+ headers := proxyHeaderMap(http.Header{
+ "Authorization": {"Bearer secret"},
+ "Cookie": {"session=secret"},
+ "X-Test": {"ok"},
+ })
+ data, err := json.Marshal(headers)
+ if err != nil {
+ t.Fatalf("Marshal error: %v", err)
+ }
+ if strings.Contains(string(data), "secret") {
+ t.Fatalf("expected redaction, got %s", data)
+ }
+ if headers["X-Test"] != "ok" {
+ t.Fatalf("X-Test = %q, want ok", headers["X-Test"])
+ }
+}
diff --git a/internal/server/handlers_test.go b/internal/server/handlers_test.go
index 2839ebe6..d5163e1a 100644
--- a/internal/server/handlers_test.go
+++ b/internal/server/handlers_test.go
@@ -14,6 +14,7 @@ import (
"slices"
"sort"
"strings"
+ "sync"
"testing"
"time"
@@ -229,14 +230,23 @@ func setPathParam(c *echo.Context, name, value string) {
}
type capturingAuditLogger struct {
+ mu sync.Mutex
config auditlog.Config
entries []*auditlog.LogEntry
}
func (l *capturingAuditLogger) Write(entry *auditlog.LogEntry) {
+ l.mu.Lock()
+ defer l.mu.Unlock()
l.entries = append(l.entries, entry)
}
+func (l *capturingAuditLogger) Entries() []*auditlog.LogEntry {
+ l.mu.Lock()
+ defer l.mu.Unlock()
+ return slices.Clone(l.entries)
+}
+
func (l *capturingAuditLogger) Config() auditlog.Config {
return l.config
}
@@ -3973,17 +3983,26 @@ func (c *capturingUsageLogger) Config() usage.Config { return c.config
func (c *capturingUsageLogger) Close() error { return nil }
type collectingUsageLogger struct {
+ mu sync.Mutex
config usage.Config
entries []*usage.UsageEntry
}
func (c *collectingUsageLogger) Write(entry *usage.UsageEntry) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
if entry == nil {
return
}
c.entries = append(c.entries, entry)
}
+func (c *collectingUsageLogger) Entries() []*usage.UsageEntry {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ return slices.Clone(c.entries)
+}
+
func (c *collectingUsageLogger) Config() usage.Config { return c.config }
func (c *collectingUsageLogger) Close() error { return nil }
diff --git a/internal/server/http.go b/internal/server/http.go
index 73bac7f7..2e710c35 100644
--- a/internal/server/http.go
+++ b/internal/server/http.go
@@ -29,6 +29,7 @@ import (
type Server struct {
echo *echo.Echo
handler *Handler
+ root http.Handler
responseCacheMiddleware *responsecache.ResponseCacheMiddleware
}
@@ -58,6 +59,7 @@ type Config struct {
DashboardHandler *dashboard.Handler // Dashboard UI handler (nil if disabled)
SwaggerEnabled bool // Whether to expose the Swagger UI at /swagger/index.html
ResponseCacheMiddleware *responsecache.ResponseCacheMiddleware // Optional: response cache middleware for cacheable endpoints
+ ExperimentalForwardProxy *ForwardProxyConfig // Optional: experimental forward proxy wrapper
GuardrailsHash string // Optional: SHA-256 hash of active guardrail rules; stored in context post-patch for semantic cache
}
@@ -283,9 +285,23 @@ func New(provider core.RoutableProvider, cfg *Config) *Server {
if cfg != nil {
rcm = cfg.ResponseCacheMiddleware
}
+ root := http.Handler(e)
+ if cfg != nil && cfg.ExperimentalForwardProxy != nil && cfg.ExperimentalForwardProxy.Enabled {
+ proxyCfg := *cfg.ExperimentalForwardProxy
+ proxyCfg.AuditLogger = auditLogger
+ proxyCfg.UsageLogger = usageLogger
+ proxyCfg.PricingResolver = pricingResolver
+ proxyHandler, err := NewForwardProxyHandler(root, &proxyCfg)
+ if err != nil {
+ slog.Error("failed to enable experimental forward proxy", "error", err)
+ } else {
+ root = proxyHandler
+ }
+ }
return &Server{
echo: e,
handler: handler,
+ root: root,
responseCacheMiddleware: rcm,
}
}
@@ -303,7 +319,7 @@ func (s *Server) Start(ctx context.Context, addr string) error {
Address: addr,
HideBanner: true,
}
- return sc.Start(ctx, s.echo)
+ return sc.Start(ctx, s.root)
}
// Shutdown releases server resources. The HTTP server itself is stopped by
@@ -318,7 +334,7 @@ func (s *Server) Shutdown(_ context.Context) error {
// ServeHTTP implements the http.Handler interface, allowing Server to be used with httptest
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
- s.echo.ServeHTTP(w, r)
+ s.root.ServeHTTP(w, r)
}
func parseBodySizeLimitBytes(limit string) int64 {
diff --git a/internal/usage/stream_observer.go b/internal/usage/stream_observer.go
index 1b3c8d44..4818b215 100644
--- a/internal/usage/stream_observer.go
+++ b/internal/usage/stream_observer.go
@@ -31,7 +31,7 @@ func NewStreamUsageObserver(logger LoggerInterface, model, provider, requestID,
func (o *StreamUsageObserver) OnJSONEvent(chunk map[string]any) {
entry := o.extractUsageFromEvent(chunk)
if entry != nil {
- o.cachedEntry = entry
+ o.cachedEntry = mergeUsageEntries(o.cachedEntry, entry)
}
}
@@ -54,6 +54,19 @@ func (o *StreamUsageObserver) extractUsageFromEvent(chunk map[string]any) *Usage
}
usageRaw, ok := chunk["usage"]
+ if !ok {
+ if eventType, _ := chunk["type"].(string); eventType == "message_start" {
+ if message, msgOK := chunk["message"].(map[string]any); msgOK {
+ usageRaw, ok = message["usage"]
+ if id, idOK := message["id"].(string); idOK && id != "" {
+ providerID = id
+ }
+ if m, modelOK := message["model"].(string); modelOK && m != "" {
+ model = m
+ }
+ }
+ }
+ }
if !ok {
if eventType, _ := chunk["type"].(string); eventType == "response.completed" || eventType == "response.done" {
if response, respOK := chunk["response"].(map[string]any); respOK {
@@ -138,3 +151,69 @@ func (o *StreamUsageObserver) extractUsageFromEvent(chunk map[string]any) *Usage
pricingArgs...,
)
}
+
+func mergeUsageEntries(prev, next *UsageEntry) *UsageEntry {
+ if prev == nil {
+ return next
+ }
+ if next == nil {
+ return prev
+ }
+
+ merged := *prev
+
+ if next.ProviderID != "" {
+ merged.ProviderID = next.ProviderID
+ }
+ if next.Model != "" {
+ merged.Model = next.Model
+ }
+ if next.Provider != "" {
+ merged.Provider = next.Provider
+ }
+ if next.Endpoint != "" {
+ merged.Endpoint = next.Endpoint
+ }
+ if next.RequestID != "" {
+ merged.RequestID = next.RequestID
+ }
+ if next.Timestamp.After(merged.Timestamp) {
+ merged.Timestamp = next.Timestamp
+ }
+
+ if next.InputTokens > 0 {
+ merged.InputTokens = next.InputTokens
+ }
+ if next.OutputTokens > 0 {
+ merged.OutputTokens = next.OutputTokens
+ }
+ if next.TotalTokens > 0 {
+ merged.TotalTokens = next.TotalTokens
+ } else if merged.InputTokens > 0 || merged.OutputTokens > 0 {
+ merged.TotalTokens = merged.InputTokens + merged.OutputTokens
+ }
+
+ switch {
+ case merged.RawData == nil && next.RawData != nil:
+ merged.RawData = cloneRawData(next.RawData)
+ case merged.RawData != nil && next.RawData != nil:
+ for key, value := range next.RawData {
+ merged.RawData[key] = value
+ }
+ }
+
+ if next.InputCost != nil {
+ merged.InputCost = next.InputCost
+ }
+ if next.OutputCost != nil {
+ merged.OutputCost = next.OutputCost
+ }
+ if next.TotalCost != nil {
+ merged.TotalCost = next.TotalCost
+ }
+ if next.CostsCalculationCaveat != "" {
+ merged.CostsCalculationCaveat = next.CostsCalculationCaveat
+ }
+
+ return &merged
+}
diff --git a/internal/usage/stream_observer_test.go b/internal/usage/stream_observer_test.go
index 3bfeba06..8636ceca 100644
--- a/internal/usage/stream_observer_test.go
+++ b/internal/usage/stream_observer_test.go
@@ -330,3 +330,61 @@ data: [DONE]
t.Errorf("TotalTokens = %d, want 8", entry.TotalTokens)
}
}
+
+func TestStreamUsageObserverAnthropicMessagesMergesUsageAcrossEvents(t *testing.T) {
+ streamData := `event: message_start
+data: {"type":"message_start","message":{"id":"msg_123","type":"message","role":"assistant","model":"claude-sonnet-4-5","content":[],"usage":{"input_tokens":10,"output_tokens":0,"cache_read_input_tokens":6}}}
+
+event: content_block_delta
+data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}
+
+event: message_delta
+data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":2}}
+
+data: [DONE]
+
+`
+ logger := &trackingLogger{enabled: true}
+ stream := streaming.NewObservedSSEStream(
+ io.NopCloser(strings.NewReader(streamData)),
+ NewStreamUsageObserver(logger, "claude-sonnet-4-5", "anthropic", "req-anthropic", "/v1/messages", nil),
+ )
+
+ data, err := io.ReadAll(stream)
+ if err != nil {
+ t.Fatalf("ReadAll error: %v", err)
+ }
+ if string(data) != streamData {
+ t.Fatalf("stream passthrough mismatch")
+ }
+ if err := stream.Close(); err != nil {
+ t.Fatalf("Close error: %v", err)
+ }
+
+ entries := logger.getEntries()
+ if len(entries) != 1 {
+ t.Fatalf("expected 1 entry, got %d", len(entries))
+ }
+ entry := entries[0]
+ if entry.ProviderID != "msg_123" {
+ t.Fatalf("ProviderID = %q, want msg_123", entry.ProviderID)
+ }
+ if entry.Model != "claude-sonnet-4-5" {
+ t.Fatalf("Model = %q, want claude-sonnet-4-5", entry.Model)
+ }
+ if entry.InputTokens != 10 {
+ t.Fatalf("InputTokens = %d, want 10", entry.InputTokens)
+ }
+ if entry.OutputTokens != 2 {
+ t.Fatalf("OutputTokens = %d, want 2", entry.OutputTokens)
+ }
+ if entry.TotalTokens != 12 {
+ t.Fatalf("TotalTokens = %d, want 12", entry.TotalTokens)
+ }
+ if entry.RawData == nil {
+ t.Fatal("RawData = nil")
+ }
+ if entry.RawData["cache_read_input_tokens"] != 6 {
+ t.Fatalf("RawData[cache_read_input_tokens] = %v, want 6", entry.RawData["cache_read_input_tokens"])
+ }
+}