diff --git a/internal/config/config_core.go b/internal/config/config_core.go index 9b1c6dfd..fc715def 100644 --- a/internal/config/config_core.go +++ b/internal/config/config_core.go @@ -215,6 +215,18 @@ type ServerConfig struct { // fallback uses the HTTP client's request timeout instead. Increase this for backends that are // slow to initialize. Only applies to HTTP server types. Default: 30 seconds. ConnectTimeout int `toml:"connect_timeout" json:"connect_timeout,omitempty"` + + // RateLimitThreshold is the number of consecutive rate-limit errors from this backend + // that will trip the circuit breaker (transition CLOSED → OPEN). When OPEN, requests + // are immediately rejected until the cooldown period elapses. Default: 3. + // Supported in file-based config (TOML/JSON); stdin JSON config does not currently accept this field. + RateLimitThreshold int `toml:"rate_limit_threshold" json:"rate_limit_threshold,omitempty"` + + // RateLimitCooldown is the number of seconds the circuit breaker stays OPEN before + // allowing a single probe request (transition OPEN → HALF-OPEN). If the probe + // succeeds the circuit closes; if rate-limited again it re-opens. Default: 60. + // Supported in file-based config (TOML/JSON); stdin JSON config does not currently accept this field. + RateLimitCooldown int `toml:"rate_limit_cooldown" json:"rate_limit_cooldown,omitempty"` } // GuardConfig represents a guard configuration for DIFC enforcement. 
diff --git a/internal/config/rate_limit_config_test.go b/internal/config/rate_limit_config_test.go new file mode 100644 index 00000000..ac753016 --- /dev/null +++ b/internal/config/rate_limit_config_test.go @@ -0,0 +1,36 @@ +package config + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestServerConfig_RateLimitFields(t *testing.T) { + t.Parallel() + toml := ` +[servers.github] +command = "docker" +args = ["run", "--rm", "-i", "ghcr.io/github/github-mcp-server:latest"] +rate_limit_threshold = 5 +rate_limit_cooldown = 120 +` + path := writeTempTOML(t, toml) + cfg, err := LoadFromFile(path) + require.NoError(t, err) + srv := cfg.Servers["github"] + assert.Equal(t, 5, srv.RateLimitThreshold) + assert.Equal(t, 120, srv.RateLimitCooldown) +} + +func TestServerConfig_RateLimitFieldsDefaultToZero(t *testing.T) { + t.Parallel() + toml := validDockerServerTOML + path := writeTempTOML(t, toml) + cfg, err := LoadFromFile(path) + require.NoError(t, err) + srv := cfg.Servers["github"] + assert.Equal(t, 0, srv.RateLimitThreshold) + assert.Equal(t, 0, srv.RateLimitCooldown) +} diff --git a/internal/proxy/handler.go b/internal/proxy/handler.go index 3cf420b6..e09d0424 100644 --- a/internal/proxy/handler.go +++ b/internal/proxy/handler.go @@ -7,6 +7,8 @@ import ( "fmt" "io" "net/http" + "strconv" + "time" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" @@ -350,8 +352,11 @@ func (h *proxyHandler) passthrough(w http.ResponseWriter, r *http.Request, path } // writeResponse writes an upstream response to the client. +// When the upstream signals rate-limiting (HTTP 429 or X-RateLimit-Remaining == 0), +// it injects a Retry-After header and logs the event at ERROR level. 
func (h *proxyHandler) writeResponse(w http.ResponseWriter, resp *http.Response, body []byte) { copyResponseHeaders(w, resp) + injectRetryAfterIfRateLimited(w, resp) w.WriteHeader(resp.StatusCode) w.Write(body) } @@ -416,6 +421,66 @@ func copyResponseHeaders(w http.ResponseWriter, resp *http.Response) { } } +// injectRetryAfterIfRateLimited inspects the upstream response for rate-limit signals +// (HTTP 429 or X-Ratelimit-Remaining == 0). When detected it: +// 1. Injects a Retry-After header so the client knows when to retry. +// 2. Logs the event at ERROR level so operators can monitor rate-limit incidents. +func injectRetryAfterIfRateLimited(w http.ResponseWriter, resp *http.Response) { + is429 := resp.StatusCode == http.StatusTooManyRequests + // Use Go's canonical header key form (textproto.CanonicalMIMEHeaderKey produces + // "X-Ratelimit-Remaining", matching GitHub's actual response headers). + remaining := resp.Header.Get("X-Ratelimit-Remaining") + resetHeader := resp.Header.Get("X-Ratelimit-Reset") + + isRateLimited := is429 || remaining == "0" + if !isRateLimited { + return + } + + resetAt := parseRateLimitReset(resetHeader) + retryAfter := computeRetryAfter(resetAt) + + w.Header().Set("Retry-After", strconv.Itoa(retryAfter)) + + logger.LogError("client", + "upstream rate limit hit: status=%d X-Ratelimit-Remaining=%s X-Ratelimit-Reset=%s retry-after=%ds", + resp.StatusCode, remaining, resetHeader, retryAfter) +} + +// parseRateLimitReset parses the X-RateLimit-Reset Unix-timestamp header. +// Returns zero time when absent or malformed. +func parseRateLimitReset(value string) time.Time { + if value == "" { + return time.Time{} + } + unix, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return time.Time{} + } + return time.Unix(unix, 0) +} + +// computeRetryAfter returns the number of seconds to wait before retrying. +// When resetAt is in the future the delay is clamped to [1, 3600] seconds. 
+// When resetAt is zero or in the past a default of 60 seconds is returned. +func computeRetryAfter(resetAt time.Time) int { + const ( + defaultDelay = 60 + maxDelay = 3600 + ) + if resetAt.IsZero() { + return defaultDelay + } + secs := int(time.Until(resetAt).Seconds()) + 1 // add 1s buffer + if secs < 1 { + return defaultDelay + } + if secs > maxDelay { + return maxDelay + } + return secs +} + // rewrapSearchResponse re-wraps filtered items into the original search response // envelope. GitHub search endpoints return {"total_count": N, "items": [...]}; // ToResult() returns a bare array, so we rebuild the wrapper. diff --git a/internal/proxy/rate_limit_test.go b/internal/proxy/rate_limit_test.go new file mode 100644 index 00000000..b05a19d1 --- /dev/null +++ b/internal/proxy/rate_limit_test.go @@ -0,0 +1,139 @@ +package proxy + +import ( + "net/http" + "net/http/httptest" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +// TestInjectRetryAfterIfRateLimited verifies Retry-After injection and logging for +// rate-limited upstream responses. 
+func TestInjectRetryAfterIfRateLimited(t *testing.T) { + t.Parallel() + + t.Run("HTTP 429 injects Retry-After", func(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() + future := time.Now().Add(30 * time.Second) + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{ + "X-Ratelimit-Reset": []string{strconv.FormatInt(future.Unix(), 10)}, + }, + } + injectRetryAfterIfRateLimited(w, resp) + retryAfter := w.Header().Get("Retry-After") + assert.NotEmpty(t, retryAfter, "Retry-After should be set on 429") + secs, err := strconv.Atoi(retryAfter) + assert.NoError(t, err) + assert.Greater(t, secs, 0, "Retry-After should be positive") + }) + + t.Run("X-Ratelimit-Remaining 0 injects Retry-After", func(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() + future := time.Now().Add(60 * time.Second) + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "X-Ratelimit-Remaining": []string{"0"}, + "X-Ratelimit-Reset": []string{strconv.FormatInt(future.Unix(), 10)}, + }, + } + injectRetryAfterIfRateLimited(w, resp) + assert.NotEmpty(t, w.Header().Get("Retry-After"), "Retry-After should be set when remaining=0") + }) + + t.Run("non-zero remaining does not inject Retry-After", func(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "X-Ratelimit-Remaining": []string{"100"}, + }, + } + injectRetryAfterIfRateLimited(w, resp) + assert.Empty(t, w.Header().Get("Retry-After")) + }) + + t.Run("200 with no rate-limit headers does not inject Retry-After", func(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + } + injectRetryAfterIfRateLimited(w, resp) + assert.Empty(t, w.Header().Get("Retry-After")) + }) + + t.Run("429 without reset header uses default delay", func(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() + resp 
:= &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: make(http.Header), + } + injectRetryAfterIfRateLimited(w, resp) + retryAfter := w.Header().Get("Retry-After") + assert.Equal(t, "60", retryAfter, "default delay should be 60 seconds") + }) +} + +// TestParseRateLimitReset verifies the Unix-timestamp header parser. +func TestParseRateLimitReset(t *testing.T) { + t.Parallel() + + t.Run("empty string returns zero", func(t *testing.T) { + t.Parallel() + assert.True(t, parseRateLimitReset("").IsZero()) + }) + + t.Run("invalid string returns zero", func(t *testing.T) { + t.Parallel() + assert.True(t, parseRateLimitReset("not-a-number").IsZero()) + }) + + t.Run("valid unix timestamp parses correctly", func(t *testing.T) { + t.Parallel() + ts := time.Now().Add(60 * time.Second) + got := parseRateLimitReset(strconv.FormatInt(ts.Unix(), 10)) + assert.False(t, got.IsZero()) + assert.Equal(t, ts.Unix(), got.Unix()) + }) +} + +// TestComputeRetryAfter verifies the retry-after calculation. +func TestComputeRetryAfter(t *testing.T) { + t.Parallel() + + t.Run("zero time returns default", func(t *testing.T) { + t.Parallel() + assert.Equal(t, 60, computeRetryAfter(time.Time{})) + }) + + t.Run("past time returns default", func(t *testing.T) { + t.Parallel() + assert.Equal(t, 60, computeRetryAfter(time.Now().Add(-time.Minute))) + }) + + t.Run("future time returns seconds until reset", func(t *testing.T) { + t.Parallel() + future := time.Now().Add(30 * time.Second) + secs := computeRetryAfter(future) + // Allow ±2s for timing jitter. 
+ assert.GreaterOrEqual(t, secs, 29) + assert.LessOrEqual(t, secs, 32) + }) + + t.Run("very far future is clamped to max", func(t *testing.T) { + t.Parallel() + farFuture := time.Now().Add(24 * time.Hour) + assert.Equal(t, 3600, computeRetryAfter(farFuture)) + }) +} diff --git a/internal/server/circuit_breaker.go b/internal/server/circuit_breaker.go new file mode 100644 index 00000000..4e67f8a3 --- /dev/null +++ b/internal/server/circuit_breaker.go @@ -0,0 +1,331 @@ +package server + +import ( + "fmt" + "strconv" + "strings" + "sync" + "time" + + "github.com/github/gh-aw-mcpg/internal/logger" +) + +// circuitBreakerState represents the state of a circuit breaker. +type circuitBreakerState int + +const ( + // circuitClosed is normal operation — requests pass through. + circuitClosed circuitBreakerState = iota + // circuitOpen means the circuit is tripped — requests are rejected immediately. + circuitOpen + // circuitHalfOpen means one probe request is allowed to test recovery. + circuitHalfOpen +) + +func (s circuitBreakerState) String() string { + switch s { + case circuitClosed: + return "CLOSED" + case circuitOpen: + return "OPEN" + case circuitHalfOpen: + return "HALF-OPEN" + default: + return "UNKNOWN" + } +} + +// DefaultRateLimitThreshold is the number of consecutive rate-limit errors +// before the circuit breaker opens. +const DefaultRateLimitThreshold = 3 + +// DefaultRateLimitCooldown is the number of seconds the circuit stays OPEN +// before transitioning to HALF-OPEN to probe one request. +const DefaultRateLimitCooldown = 60 * time.Second + +var logCircuitBreaker = logger.New("server:circuit_breaker") + +// circuitBreaker implements a per-backend rate-limit circuit breaker. 
+// +// State transitions: +// +// CLOSED → OPEN : after threshold consecutive rate-limit errors +// OPEN → HALF-OPEN : after cooldown period elapses +// HALF-OPEN → CLOSED : probe request succeeds +// HALF-OPEN → OPEN : probe request is rate-limited again +type circuitBreaker struct { + mu sync.Mutex + + state circuitBreakerState + consecutiveErrors int + openedAt time.Time + // resetAt is the time when the upstream rate limit resets, parsed from + // the X-RateLimit-Reset header or the tool response message. + resetAt time.Time + probeInFlight bool + serverID string + + threshold int + cooldown time.Duration + + // nowFunc returns the current time. Defaults to time.Now; overridden in tests + // to avoid flaky time.Sleep-based assertions. + nowFunc func() time.Time +} + +// newCircuitBreaker creates a circuit breaker for the given server ID. +// threshold is the number of consecutive rate-limit errors before opening; +// cooldown is how long to stay OPEN before probing. +func newCircuitBreaker(serverID string, threshold int, cooldown time.Duration) *circuitBreaker { + if threshold <= 0 { + threshold = DefaultRateLimitThreshold + } + if cooldown <= 0 { + cooldown = DefaultRateLimitCooldown + } + return &circuitBreaker{ + serverID: serverID, + state: circuitClosed, + threshold: threshold, + cooldown: cooldown, + nowFunc: time.Now, + } +} + +// ErrCircuitOpen is returned when the circuit breaker is OPEN and a request is rejected. +type ErrCircuitOpen struct { + ServerID string + ResetAt time.Time +} + +func (e *ErrCircuitOpen) Error() string { + if e.ResetAt.IsZero() { + return fmt.Sprintf("rate limit circuit breaker is OPEN for server %q — requests temporarily rejected", e.ServerID) + } + return fmt.Sprintf("rate limit circuit breaker is OPEN for server %q — rate limit resets at %s (retry after %s)", + e.ServerID, e.ResetAt.UTC().Format(time.RFC3339), time.Until(e.ResetAt).Round(time.Second)) +} + +// Allow reports whether a request should be allowed through. 
It also handles +// the OPEN → HALF-OPEN transition when the cooldown has elapsed. +// Returns an *ErrCircuitOpen error when the circuit is OPEN. +func (cb *circuitBreaker) Allow() error { + cb.mu.Lock() + defer cb.mu.Unlock() + + switch cb.state { + case circuitClosed: + return nil + + case circuitOpen: + // Check whether we should transition to HALF-OPEN. + // We use the upstream reset time when available, otherwise the cooldown. + now := cb.nowFunc() + var openUntil time.Time + if !cb.resetAt.IsZero() && cb.resetAt.After(cb.openedAt) { + openUntil = cb.resetAt + } else { + openUntil = cb.openedAt.Add(cb.cooldown) + } + if now.After(openUntil) { + logCircuitBreaker.Printf("server %q circuit breaker OPEN → HALF-OPEN after cooldown", cb.serverID) + logger.LogInfo("backend", "circuit breaker for server %q transitioning OPEN → HALF-OPEN", cb.serverID) + cb.state = circuitHalfOpen + cb.probeInFlight = true + return nil // allow the single probe + } + return &ErrCircuitOpen{ServerID: cb.serverID, ResetAt: cb.resetAt} + + case circuitHalfOpen: + // Only one probe is allowed; further requests are blocked until the probe resolves. + if cb.probeInFlight { + return &ErrCircuitOpen{ServerID: cb.serverID, ResetAt: cb.resetAt} + } + // This shouldn't normally happen (probe resolved but state wasn't updated), + // but allow through defensively. + return nil + } + + return nil +} + +// RecordSuccess records a successful (non-rate-limited) response. +// In HALF-OPEN state this closes the circuit. 
+func (cb *circuitBreaker) RecordSuccess() { + cb.mu.Lock() + defer cb.mu.Unlock() + + prev := cb.state + cb.consecutiveErrors = 0 + cb.probeInFlight = false + if cb.state == circuitHalfOpen { + cb.state = circuitClosed + cb.resetAt = time.Time{} + logCircuitBreaker.Printf("server %q circuit breaker HALF-OPEN → CLOSED (probe succeeded)", cb.serverID) + logger.LogInfo("backend", "circuit breaker for server %q recovered: HALF-OPEN → CLOSED", cb.serverID) + } else if prev != circuitClosed { + cb.state = circuitClosed + } +} + +// RecordRateLimit records a rate-limit error for the given server. +// resetAt is the time the upstream rate limit resets (may be zero if unknown). +// When the consecutive error count reaches threshold the circuit opens. +func (cb *circuitBreaker) RecordRateLimit(resetAt time.Time) { + cb.mu.Lock() + defer cb.mu.Unlock() + + cb.consecutiveErrors++ + cb.probeInFlight = false + if !resetAt.IsZero() { + cb.resetAt = resetAt + } + + switch cb.state { + case circuitClosed: + if cb.consecutiveErrors >= cb.threshold { + cb.state = circuitOpen + cb.openedAt = cb.nowFunc() + logger.LogError("backend", + "circuit breaker for server %q OPENED after %d consecutive rate-limit errors; resets at %s", + cb.serverID, cb.consecutiveErrors, formatResetAt(cb.resetAt)) + logCircuitBreaker.Printf("server %q circuit breaker CLOSED → OPEN (errors=%d)", cb.serverID, cb.consecutiveErrors) + } else { + logger.LogWarn("backend", + "rate-limit error for server %q (consecutive=%d/%d); resets at %s", + cb.serverID, cb.consecutiveErrors, cb.threshold, formatResetAt(cb.resetAt)) + } + + case circuitHalfOpen: + // Probe failed — re-open the circuit. 
+ cb.state = circuitOpen + cb.openedAt = cb.nowFunc() + logger.LogError("backend", + "circuit breaker for server %q re-OPENED after probe was rate-limited; resets at %s", + cb.serverID, formatResetAt(cb.resetAt)) + logCircuitBreaker.Printf("server %q circuit breaker HALF-OPEN → OPEN (probe rate-limited)", cb.serverID) + + case circuitOpen: + // Already open — update reset time. + logger.LogWarn("backend", "server %q circuit breaker still OPEN; resets at %s", + cb.serverID, formatResetAt(cb.resetAt)) + } +} + +// State returns the current circuit breaker state (for observability). +func (cb *circuitBreaker) State() circuitBreakerState { + cb.mu.Lock() + defer cb.mu.Unlock() + return cb.state +} + +// formatResetAt returns a human-readable representation of a reset time. +func formatResetAt(t time.Time) string { + if t.IsZero() { + return "unknown" + } + return fmt.Sprintf("%s (in %s)", t.UTC().Format(time.RFC3339), time.Until(t).Round(time.Second)) +} + +// extractRateLimitErrorText extracts the text content from a raw tool result +// that has been identified as a rate-limit error. Returns the original backend +// message so agents see the actual upstream error rather than a synthetic one. +func extractRateLimitErrorText(result interface{}) string { + m, ok := result.(map[string]interface{}) + if !ok { + return "rate limit exceeded" + } + contents, _ := m["content"].([]interface{}) + for _, c := range contents { + cm, ok := c.(map[string]interface{}) + if !ok { + continue + } + if text, ok := cm["text"].(string); ok && text != "" { + return text + } + } + return "rate limit exceeded" +} + +// isRateLimitToolResult reports whether a raw tool call result indicates +// a rate-limit error from the GitHub MCP server. It inspects the `isError` +// flag and the text content for well-known rate-limit phrases. +// +// The GitHub MCP server returns rate-limit errors as: +// +// {"content":[{"type":"text","text":"... 
403 API rate limit exceeded ..."}],"isError":true} +func isRateLimitToolResult(result interface{}) (bool, time.Time) { + m, ok := result.(map[string]interface{}) + if !ok { + return false, time.Time{} + } + + // Only inspect error results. + isErr, _ := m["isError"].(bool) + if !isErr { + return false, time.Time{} + } + + contents, _ := m["content"].([]interface{}) + for _, c := range contents { + cm, ok := c.(map[string]interface{}) + if !ok { + continue + } + text, _ := cm["text"].(string) + if isRateLimitText(text) { + resetAt := parseRateLimitResetFromText(text) + return true, resetAt + } + } + return false, time.Time{} +} + +// isRateLimitText returns true when the message indicates a GitHub rate-limit error. +func isRateLimitText(text string) bool { + lower := strings.ToLower(text) + return strings.Contains(lower, "rate limit exceeded") || + (strings.Contains(lower, "rate limit") && strings.Contains(lower, "403")) || + strings.Contains(lower, "api rate limit") || + strings.Contains(lower, "secondary rate limit") || + strings.Contains(lower, "too many requests") +} + +// parseRateLimitResetFromText attempts to extract a reset timestamp from the +// rate-limit error text. The GitHub MCP server includes messages like +// "API rate limit exceeded [rate reset in 42s]". +// Returns zero time when the value cannot be parsed or is 0 seconds. +func parseRateLimitResetFromText(text string) time.Time { + // Look for "[rate reset in Ns]" pattern. + lower := strings.ToLower(text) + idx := strings.Index(lower, "rate reset in ") + if idx < 0 { + return time.Time{} + } + rest := text[idx+len("rate reset in "):] + // Find the first non-digit character. 
+ end := strings.IndexAny(rest, "s])") + if end < 0 { + return time.Time{} + } + secs, err := strconv.ParseInt(strings.TrimSpace(rest[:end]), 10, 64) + if err != nil || secs <= 0 { + return time.Time{} + } + return time.Now().Add(time.Duration(secs) * time.Second) +} + +// parseRateLimitResetHeader parses the Unix-timestamp value of the +// X-RateLimit-Reset HTTP header into a time.Time. +// Returns zero time when the header is absent or malformed. +func parseRateLimitResetHeader(value string) time.Time { + if value == "" { + return time.Time{} + } + unix, err := strconv.ParseInt(strings.TrimSpace(value), 10, 64) + if err != nil { + return time.Time{} + } + return time.Unix(unix, 0) +} diff --git a/internal/server/circuit_breaker_test.go b/internal/server/circuit_breaker_test.go new file mode 100644 index 00000000..f31c9270 --- /dev/null +++ b/internal/server/circuit_breaker_test.go @@ -0,0 +1,435 @@ +package server + +import ( + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestCircuitBreaker_InitialStateClosed verifies new circuit breakers start CLOSED. +func TestCircuitBreaker_InitialStateClosed(t *testing.T) { + t.Parallel() + cb := newCircuitBreaker("test", 3, 60*time.Second) + assert.Equal(t, circuitClosed, cb.State()) + assert.NoError(t, cb.Allow()) +} + +// TestCircuitBreaker_OpensAfterThreshold verifies the circuit opens after N consecutive errors. 
+func TestCircuitBreaker_OpensAfterThreshold(t *testing.T) { + t.Parallel() + cb := newCircuitBreaker("test", 3, 60*time.Second) + + cb.RecordRateLimit(time.Time{}) + assert.Equal(t, circuitClosed, cb.State(), "should remain CLOSED after 1 error") + assert.NoError(t, cb.Allow()) + + cb.RecordRateLimit(time.Time{}) + assert.Equal(t, circuitClosed, cb.State(), "should remain CLOSED after 2 errors") + assert.NoError(t, cb.Allow()) + + cb.RecordRateLimit(time.Time{}) + assert.Equal(t, circuitOpen, cb.State(), "should be OPEN after 3 errors (threshold)") + + err := cb.Allow() + require.Error(t, err, "OPEN circuit should reject requests") + var openErr *ErrCircuitOpen + require.ErrorAs(t, err, &openErr) + assert.Equal(t, "test", openErr.ServerID) +} + +// TestCircuitBreaker_SuccessResetsCounter verifies that a success resets the consecutive-error counter. +func TestCircuitBreaker_SuccessResetsCounter(t *testing.T) { + t.Parallel() + cb := newCircuitBreaker("test", 3, 60*time.Second) + + cb.RecordRateLimit(time.Time{}) + cb.RecordRateLimit(time.Time{}) + assert.Equal(t, circuitClosed, cb.State(), "still CLOSED after 2 errors") + + cb.RecordSuccess() + assert.Equal(t, circuitClosed, cb.State(), "still CLOSED after success") + + // After a success the counter resets, so 2 more errors should NOT open the circuit. + cb.RecordRateLimit(time.Time{}) + cb.RecordRateLimit(time.Time{}) + assert.Equal(t, circuitClosed, cb.State(), "should be CLOSED (counter reset by success)") +} + +// TestCircuitBreaker_HalfOpenAfterCooldown verifies OPEN → HALF-OPEN transition. +func TestCircuitBreaker_HalfOpenAfterCooldown(t *testing.T) { + t.Parallel() + fakeNow := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) + cb := newCircuitBreaker("test", 1, time.Minute) + cb.nowFunc = func() time.Time { return fakeNow } + + cb.RecordRateLimit(time.Time{}) + require.Equal(t, circuitOpen, cb.State(), "should be OPEN after 1 error") + + // Before cooldown: still OPEN. 
+ fakeNow = fakeNow.Add(30 * time.Second) + require.Error(t, cb.Allow(), "should reject before cooldown elapses") + + // After cooldown: transitions to HALF-OPEN. + fakeNow = fakeNow.Add(31 * time.Second) + err := cb.Allow() + assert.NoError(t, err, "should allow probe after cooldown") + assert.Equal(t, circuitHalfOpen, cb.State(), "should be HALF-OPEN after cooldown") +} + +// TestCircuitBreaker_HalfOpenClosesOnSuccess verifies HALF-OPEN → CLOSED on probe success. +func TestCircuitBreaker_HalfOpenClosesOnSuccess(t *testing.T) { + t.Parallel() + fakeNow := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) + cb := newCircuitBreaker("test", 1, time.Minute) + cb.nowFunc = func() time.Time { return fakeNow } + + cb.RecordRateLimit(time.Time{}) + require.Equal(t, circuitOpen, cb.State()) + + fakeNow = fakeNow.Add(2 * time.Minute) + require.NoError(t, cb.Allow()) // probe allowed + + cb.RecordSuccess() + assert.Equal(t, circuitClosed, cb.State(), "should be CLOSED after probe success") + assert.NoError(t, cb.Allow(), "CLOSED circuit should allow requests") +} + +// TestCircuitBreaker_HalfOpenReOpensOnRateLimit verifies HALF-OPEN → OPEN on probe failure. +func TestCircuitBreaker_HalfOpenReOpensOnRateLimit(t *testing.T) { + t.Parallel() + fakeNow := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) + cb := newCircuitBreaker("test", 1, time.Minute) + cb.nowFunc = func() time.Time { return fakeNow } + + cb.RecordRateLimit(time.Time{}) + require.Equal(t, circuitOpen, cb.State()) + + fakeNow = fakeNow.Add(2 * time.Minute) + require.NoError(t, cb.Allow()) // probe allowed + + cb.RecordRateLimit(time.Time{}) + assert.Equal(t, circuitOpen, cb.State(), "should be OPEN again after probe is rate-limited") + + err := cb.Allow() + require.Error(t, err) + var openErr *ErrCircuitOpen + require.ErrorAs(t, err, &openErr) +} + +// TestCircuitBreaker_ResetAtFromHeader verifies the reset time from upstream is used. 
+func TestCircuitBreaker_ResetAtFromHeader(t *testing.T) { + t.Parallel() + fakeNow := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) + cb := newCircuitBreaker("test", 1, time.Hour) + cb.nowFunc = func() time.Time { return fakeNow } + + resetAt := fakeNow.Add(30 * time.Second) + cb.RecordRateLimit(resetAt) + require.Equal(t, circuitOpen, cb.State()) + + // Before the reset time: still OPEN. + fakeNow = fakeNow.Add(15 * time.Second) + require.Error(t, cb.Allow()) + + // After the reset time: transitions to HALF-OPEN (before cooldown would elapse). + fakeNow = fakeNow.Add(20 * time.Second) + err := cb.Allow() + assert.NoError(t, err, "should allow probe after reset time") + assert.Equal(t, circuitHalfOpen, cb.State()) +} + +// TestCircuitBreaker_HalfOpenBlocksConcurrentProbes verifies that only one probe is allowed in HALF-OPEN. +func TestCircuitBreaker_HalfOpenBlocksConcurrentProbes(t *testing.T) { + t.Parallel() + fakeNow := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) + cb := newCircuitBreaker("test", 1, time.Minute) + cb.nowFunc = func() time.Time { return fakeNow } + + cb.RecordRateLimit(time.Time{}) + require.Equal(t, circuitOpen, cb.State()) + + // Advance past cooldown to trigger HALF-OPEN. + fakeNow = fakeNow.Add(2 * time.Minute) + + // First Allow() should succeed (the probe). + require.NoError(t, cb.Allow()) + assert.Equal(t, circuitHalfOpen, cb.State()) + + // Second Allow() should be rejected — probe is already in flight. + err := cb.Allow() + require.Error(t, err, "concurrent requests in HALF-OPEN should be rejected") + var openErr *ErrCircuitOpen + require.ErrorAs(t, err, &openErr) + + // After the probe succeeds, requests should be allowed again. + cb.RecordSuccess() + assert.Equal(t, circuitClosed, cb.State()) + assert.NoError(t, cb.Allow()) +} + +// TestCircuitBreaker_DefaultsApplied verifies zero-value config gets sensible defaults. 
+func TestCircuitBreaker_DefaultsApplied(t *testing.T) { + t.Parallel() + cb := newCircuitBreaker("test", 0, 0) + assert.Equal(t, DefaultRateLimitThreshold, cb.threshold) + assert.Equal(t, DefaultRateLimitCooldown, cb.cooldown) +} + +// TestCircuitBreaker_ErrOpenMessage verifies ErrCircuitOpen.Error() content. +func TestCircuitBreaker_ErrOpenMessage(t *testing.T) { + t.Parallel() + + t.Run("no reset time", func(t *testing.T) { + t.Parallel() + err := &ErrCircuitOpen{ServerID: "github"} + assert.Contains(t, err.Error(), "github") + assert.Contains(t, err.Error(), "OPEN") + }) + + t.Run("with reset time", func(t *testing.T) { + t.Parallel() + reset := time.Now().Add(30 * time.Second) + err := &ErrCircuitOpen{ServerID: "github", ResetAt: reset} + assert.Contains(t, err.Error(), "github") + assert.Contains(t, err.Error(), "retry after") + }) +} + +// TestIsRateLimitToolResult verifies rate-limit detection from GitHub MCP tool results. +func TestIsRateLimitToolResult(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + result interface{} + wantHit bool + wantReset bool // whether a non-zero reset time is expected + }{ + { + name: "standard rate limit exceeded message", + result: map[string]interface{}{ + "isError": true, + "content": []interface{}{ + map[string]interface{}{ + "type": "text", + "text": "failed to search repositories: 403 API rate limit exceeded [rate reset in 42s]", + }, + }, + }, + wantHit: true, + wantReset: true, + }, + { + name: "secondary rate limit", + result: map[string]interface{}{ + "isError": true, + "content": []interface{}{ + map[string]interface{}{ + "type": "text", + "text": "secondary rate limit triggered", + }, + }, + }, + wantHit: true, + }, + { + name: "too many requests", + result: map[string]interface{}{ + "isError": true, + "content": []interface{}{ + map[string]interface{}{ + "type": "text", + "text": "too many requests", + }, + }, + }, + wantHit: true, + }, + { + name: "non-rate-limit error", + result: 
map[string]interface{}{ + "isError": true, + "content": []interface{}{ + map[string]interface{}{ + "type": "text", + "text": "repository not found", + }, + }, + }, + wantHit: false, + }, + { + name: "successful result (isError false)", + result: map[string]interface{}{ + "isError": false, + "content": []interface{}{ + map[string]interface{}{ + "type": "text", + "text": "API rate limit exceeded but isError is false", + }, + }, + }, + wantHit: false, + }, + { + name: "nil result", + result: nil, + wantHit: false, + }, + { + name: "non-map result", + result: "some string", + wantHit: false, + }, + { + name: "rate reset in 0s", + result: map[string]interface{}{ + "isError": true, + "content": []interface{}{ + map[string]interface{}{ + "type": "text", + "text": "API rate limit exceeded [rate reset in 0s]", + }, + }, + }, + wantHit: true, + wantReset: false, // 0s means no future time + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + hit, resetAt := isRateLimitToolResult(tt.result) + assert.Equal(t, tt.wantHit, hit, "isRateLimitToolResult mismatch") + if tt.wantReset { + assert.False(t, resetAt.IsZero(), "expected non-zero resetAt") + } + }) + } +} + +// TestParseRateLimitResetFromText verifies reset time parsing from error messages. 
+func TestParseRateLimitResetFromText(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + text string + wantZero bool + }{ + { + name: "42 seconds", + text: "rate limit exceeded [rate reset in 42s]", + wantZero: false, + }, + { + name: "0 seconds gives zero time", + text: "rate limit exceeded [rate reset in 0s]", + wantZero: true, + }, + { + name: "no pattern", + text: "some other error", + wantZero: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := parseRateLimitResetFromText(tt.text) + if tt.wantZero { + assert.True(t, got.IsZero(), "expected zero time, got %v", got) + } else { + assert.False(t, got.IsZero(), "expected non-zero time") + assert.True(t, got.After(time.Now()), "expected future time") + } + }) + } +} + +// TestParseRateLimitResetHeader verifies the Unix-timestamp header parsing. +func TestParseRateLimitResetHeader(t *testing.T) { + t.Parallel() + + now := time.Now() + future := now.Add(60 * time.Second) + + tests := []struct { + name string + value string + wantZero bool + wantTime time.Time + }{ + { + name: "empty", + value: "", + wantZero: true, + }, + { + name: "invalid", + value: "not-a-number", + wantZero: true, + }, + { + name: "valid unix timestamp", + value: "1000000000", + wantZero: false, + wantTime: time.Unix(1000000000, 0), + }, + { + name: "future timestamp", + value: strconv.FormatInt(future.Unix(), 10), + wantZero: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := parseRateLimitResetHeader(tt.value) + if tt.wantZero { + assert.True(t, got.IsZero(), "expected zero time") + } else { + assert.False(t, got.IsZero(), "expected non-zero time") + if !tt.wantTime.IsZero() { + assert.Equal(t, tt.wantTime.Unix(), got.Unix()) + } + } + }) + } +} + +// TestExtractRateLimitErrorText verifies extraction of error text from backend results. 
+func TestExtractRateLimitErrorText(t *testing.T) { + t.Parallel() + + t.Run("extracts text from standard rate-limit result", func(t *testing.T) { + t.Parallel() + result := map[string]interface{}{ + "isError": true, + "content": []interface{}{ + map[string]interface{}{ + "type": "text", + "text": "failed to search: 403 API rate limit exceeded [rate reset in 42s]", + }, + }, + } + assert.Equal(t, "failed to search: 403 API rate limit exceeded [rate reset in 42s]", extractRateLimitErrorText(result)) + }) + + t.Run("returns fallback for nil result", func(t *testing.T) { + t.Parallel() + assert.Equal(t, "rate limit exceeded", extractRateLimitErrorText(nil)) + }) + + t.Run("returns fallback for empty content", func(t *testing.T) { + t.Parallel() + result := map[string]interface{}{"isError": true, "content": []interface{}{}} + assert.Equal(t, "rate limit exceeded", extractRateLimitErrorText(result)) + }) +} diff --git a/internal/server/unified.go b/internal/server/unified.go index ce433ecb..d0f86074 100644 --- a/internal/server/unified.go +++ b/internal/server/unified.go @@ -94,6 +94,9 @@ type UnifiedServer struct { // means all tools are permitted for that server. allowedToolSets map[string]map[string]bool + // circuitBreakers holds a per-backend rate-limit circuit breaker keyed by server ID. 
+	circuitBreakers map[string]*circuitBreaker
+
 	// DIFC components
 	guardRegistry *guard.Registry
 	agentRegistry *difc.AgentRegistry
@@ -149,6 +152,7 @@ func NewUnified(ctx context.Context, cfg *config.Config) (*UnifiedServer, error)
 		payloadPathPrefix:    payloadPathPrefix,
 		payloadSizeThreshold: payloadSizeThreshold,
 		allowedToolSets:      buildAllowedToolSets(cfg),
+		circuitBreakers:      buildCircuitBreakers(cfg),
 
 		// Initialize DIFC components
 		guardRegistry: guard.NewRegistry(),
@@ -363,7 +367,37 @@ func newErrorCallToolResult(err error) (*sdk.CallToolResult, interface{}, error)
 	}, nil, err
 }
 
-// buildAllowedToolSets converts the per-server Tools lists from the config into pre-computed
+// buildCircuitBreakers creates per-backend circuit breakers from the configuration.
+func buildCircuitBreakers(cfg *config.Config) map[string]*circuitBreaker {
+	cbs := make(map[string]*circuitBreaker)
+	if cfg == nil {
+		return cbs
+	}
+	for serverID, serverCfg := range cfg.Servers {
+		threshold := serverCfg.RateLimitThreshold
+		cooldown := time.Duration(serverCfg.RateLimitCooldown) * time.Second
+		cbs[serverID] = newCircuitBreaker(serverID, threshold, cooldown)
+		logUnified.Printf("Created circuit breaker for server %s: threshold=%d, cooldown=%s",
+			serverID, threshold, cooldown)
+	}
+	return cbs
+}
+
+// getCircuitBreaker returns the circuit breaker for serverID, creating one with
+// defaults if none exists (e.g., when called from tests that bypass NewUnified).
+func (us *UnifiedServer) getCircuitBreaker(serverID string) *circuitBreaker {
+	if us.circuitBreakers == nil {
+		us.circuitBreakers = make(map[string]*circuitBreaker)
+	}
+	if cb, ok := us.circuitBreakers[serverID]; ok {
+		return cb
+	}
+	cb := newCircuitBreaker(serverID, DefaultRateLimitThreshold, DefaultRateLimitCooldown)
+	us.circuitBreakers[serverID] = cb
+	return cb
+}
+
+// buildAllowedToolSets converts the per-server Tools lists from the config into pre-computed
 // map[string]bool sets for O(1) lookup. Servers with no Tools list are not added to the map,
 // which signals that all tools are permitted. If the Tools list contains a "*" entry anywhere,
 // the server is treated the same as having no list (all tools allowed).
@@ -532,14 +566,43 @@ func (us *UnifiedServer) callBackendTool(ctx context.Context, serverID, toolName
 		oteltrace.WithSpanKind(oteltrace.SpanKindClient),
 	)
 	defer execSpan.End()
+
+	// Check the circuit breaker before calling the backend.
+	cb := us.getCircuitBreaker(serverID)
+	if err := cb.Allow(); err != nil {
+		execSpan.RecordError(err)
+		execSpan.SetStatus(codes.Error, "circuit breaker open")
+		httpStatusCode = 429
+		return newErrorCallToolResult(err)
+	}
+
 	backendResult, err := executeBackendToolCall(execCtx, us.launcher, serverID, sessionID, toolName, args)
 	if err != nil {
+		// Transport errors (connection failure, JSON parse error, etc.) are not rate-limit
+		// events and must not affect the consecutive rate-limit counter. Leave the circuit
+		// breaker state unchanged so genuine rate-limit history is preserved.
 		execSpan.RecordError(err)
 		execSpan.SetStatus(codes.Error, err.Error())
 		httpStatusCode = 500
 		return newErrorCallToolResult(err)
 	}
 
+	// Inspect the tool result for rate-limit indicators from the GitHub MCP server.
+	if rateLimited, resetAt := isRateLimitToolResult(backendResult); rateLimited {
+		cb.RecordRateLimit(resetAt)
+		execSpan.SetAttributes(attribute.Bool("rate_limit.hit", true))
+		httpStatusCode = 429
+		// Preserve the original backend error text so the agent sees the actual upstream
+		// rate-limit details. ErrCircuitOpen is only returned when cb.Allow() rejects
+		// the call before contacting the backend.
+		errText := extractRateLimitErrorText(backendResult)
+		return &sdk.CallToolResult{
+			Content: []sdk.Content{&sdk.TextContent{Text: errText}},
+			IsError: true,
+		}, backendResult, nil
+	}
+	cb.RecordSuccess()
+
 	// **Phase 4: Guard labels the response data (for fine-grained filtering)**
 	// Per spec: LabelResponse() is only called for read operations in all modes,
 	// and for read-write operations in filter/propagate modes.