From f5a1f23e614117814503551ef964cb8de906fd42 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 14 Apr 2026 18:55:02 +0000 Subject: [PATCH 1/4] Initial plan From 99dea91ecab1fa88780f919d9174f8fe80a6aab8 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 14 Apr 2026 19:07:46 +0000 Subject: [PATCH 2/4] Add rate-limit circuit breaker for GitHub MCP backend tool calls - Add circuit_breaker.go with CLOSED/OPEN/HALF-OPEN state machine (Phase 2) - Add isRateLimitToolResult to detect GitHub MCP rate-limit errors by inspecting isError flag and text content patterns - Integrate circuit breaker into callBackendTool in unified.go: check before each backend call, record rate-limit hits and successes, return descriptive errors when OPEN - Add RateLimitThreshold and RateLimitCooldown fields to ServerConfig (TOML/JSON) - Add injectRetryAfterIfRateLimited in proxy/handler.go: detects HTTP 429 or X-RateLimit-Remaining=0, injects Retry-After header, logs ERROR (Phase 1 proxy) - Tests: circuit_breaker_test.go covers all state transitions; rate_limit_test.go covers proxy Retry-After injection; config rate_limit fields test" Agent-Logs-Url: https://github.com/github/gh-aw-mcpg/sessions/29c289a0-52c0-4c27-98f3-8eef3b574559 Co-authored-by: lpcox <15877973+lpcox@users.noreply.github.com> --- internal/config/config_core.go | 10 + internal/config/rate_limit_config_test.go | 35 +++ internal/proxy/handler.go | 63 ++++ internal/proxy/rate_limit_test.go | 139 +++++++++ internal/server/circuit_breaker.go | 296 ++++++++++++++++++ internal/server/circuit_breaker_test.go | 364 ++++++++++++++++++++++ internal/server/unified.go | 54 +++- 7 files changed, 960 insertions(+), 1 deletion(-) create mode 100644 internal/config/rate_limit_config_test.go create mode 100644 internal/proxy/rate_limit_test.go create mode 100644 internal/server/circuit_breaker.go create mode 100644 internal/server/circuit_breaker_test.go diff --git a/internal/config/config_core.go b/internal/config/config_core.go index 9b1c6dfd..337a9ea2 100644 --- a/internal/config/config_core.go +++ b/internal/config/config_core.go @@ -215,6 +215,16 @@ type ServerConfig struct { // fallback uses the HTTP client's request timeout instead. Increase this for backends that are // slow to initialize. Only applies to HTTP server types. Default: 30 seconds. ConnectTimeout int `toml:"connect_timeout" json:"connect_timeout,omitempty"` + + // RateLimitThreshold is the number of consecutive rate-limit errors from this backend + // that will trip the circuit breaker (transition CLOSED → OPEN). When OPEN, requests + // are immediately rejected until the cooldown period elapses. Default: 3. + RateLimitThreshold int `toml:"rate_limit_threshold" json:"rate_limit_threshold,omitempty"` + + // RateLimitCooldown is the number of seconds the circuit breaker stays OPEN before + // allowing a single probe request (transition OPEN → HALF-OPEN). If the probe + // succeeds the circuit closes; if rate-limited again it re-opens. Default: 60. + RateLimitCooldown int `toml:"rate_limit_cooldown" json:"rate_limit_cooldown,omitempty"` } // GuardConfig represents a guard configuration for DIFC enforcement. diff --git a/internal/config/rate_limit_config_test.go b/internal/config/rate_limit_config_test.go new file mode 100644 index 00000000..1dbdfe5f --- /dev/null +++ b/internal/config/rate_limit_config_test.go @@ -0,0 +1,35 @@ +package config + +import ( + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "testing" +) + +func TestServerConfig_RateLimitFields(t *testing.T) { + t.Parallel() + toml := ` +[servers.github] +command = "docker" +args = ["run", "--rm", "-i", "ghcr.io/github/github-mcp-server:latest"] +rate_limit_threshold = 5 +rate_limit_cooldown = 120 +` + path := writeTempTOML(t, toml) + cfg, err := LoadFromFile(path) + require.NoError(t, err) + srv := cfg.Servers["github"] + assert.Equal(t, 5, srv.RateLimitThreshold) + assert.Equal(t, 120, srv.RateLimitCooldown) +} + +func TestServerConfig_RateLimitFieldsDefaultToZero(t *testing.T) { + t.Parallel() + toml := validDockerServerTOML + path := writeTempTOML(t, toml) + cfg, err := LoadFromFile(path) + require.NoError(t, err) + srv := cfg.Servers["github"] + assert.Equal(t, 0, srv.RateLimitThreshold) + assert.Equal(t, 0, srv.RateLimitCooldown) +} diff --git a/internal/proxy/handler.go b/internal/proxy/handler.go index 3cf420b6..136091a7 100644 --- a/internal/proxy/handler.go +++ b/internal/proxy/handler.go @@ -7,6 +7,8 @@ import ( "fmt" "io" "net/http" + "strconv" + "time" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" @@ -350,8 +352,11 @@ func (h *proxyHandler) passthrough(w http.ResponseWriter, r *http.Request, path } // writeResponse writes an upstream response to the client. +// When the upstream signals rate-limiting (HTTP 429 or X-RateLimit-Remaining == 0), +// it injects a Retry-After header and logs the event at ERROR level. func (h *proxyHandler) writeResponse(w http.ResponseWriter, resp *http.Response, body []byte) { copyResponseHeaders(w, resp) + injectRetryAfterIfRateLimited(w, resp) w.WriteHeader(resp.StatusCode) w.Write(body) } @@ -416,6 +421,64 @@ func copyResponseHeaders(w http.ResponseWriter, resp *http.Response) { } } +// injectRetryAfterIfRateLimited inspects the upstream response for rate-limit signals +// (HTTP 429 or X-RateLimit-Remaining == 0). When detected it: +// 1. Injects a Retry-After header so the client knows when to retry. +// 2. Logs the event at ERROR level so operators can monitor rate-limit incidents. +func injectRetryAfterIfRateLimited(w http.ResponseWriter, resp *http.Response) { + is429 := resp.StatusCode == http.StatusTooManyRequests + remaining := resp.Header.Get("X-RateLimit-Remaining") + resetHeader := resp.Header.Get("X-RateLimit-Reset") + + isRateLimited := is429 || remaining == "0" + if !isRateLimited { + return + } + + resetAt := parseRateLimitReset(resetHeader) + retryAfter := computeRetryAfter(resetAt) + + w.Header().Set("Retry-After", strconv.Itoa(retryAfter)) + + logger.LogError("client", + "upstream rate limit hit: status=%d X-RateLimit-Remaining=%s X-RateLimit-Reset=%s retry-after=%ds", + resp.StatusCode, remaining, resetHeader, retryAfter) +} + +// parseRateLimitReset parses the X-RateLimit-Reset Unix-timestamp header. +// Returns zero time when absent or malformed. +func parseRateLimitReset(value string) time.Time { + if value == "" { + return time.Time{} + } + unix, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return time.Time{} + } + return time.Unix(unix, 0) +} + +// computeRetryAfter returns the number of seconds to wait before retrying. +// When resetAt is in the future the delay is clamped to [1, 3600] seconds. +// When resetAt is zero or in the past a default of 60 seconds is returned. +func computeRetryAfter(resetAt time.Time) int { + const ( + defaultDelay = 60 + maxDelay = 3600 + ) + if resetAt.IsZero() { + return defaultDelay + } + secs := int(time.Until(resetAt).Seconds()) + 1 // add 1s buffer + if secs < 1 { + return defaultDelay + } + if secs > maxDelay { + return maxDelay + } + return secs +} + // rewrapSearchResponse re-wraps filtered items into the original search response // envelope. GitHub search endpoints return {"total_count": N, "items": [...]}; // ToResult() returns a bare array, so we rebuild the wrapper. diff --git a/internal/proxy/rate_limit_test.go b/internal/proxy/rate_limit_test.go new file mode 100644 index 00000000..d020a13c --- /dev/null +++ b/internal/proxy/rate_limit_test.go @@ -0,0 +1,139 @@ +package proxy + +import ( + "net/http" + "net/http/httptest" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +// TestInjectRetryAfterIfRateLimited verifies Retry-After injection and logging for +// rate-limited upstream responses. +func TestInjectRetryAfterIfRateLimited(t *testing.T) { + t.Parallel() + + t.Run("HTTP 429 injects Retry-After", func(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() + future := time.Now().Add(30 * time.Second) + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{ + "X-Ratelimit-Reset": []string{strconv.FormatInt(future.Unix(), 10)}, + }, + } + injectRetryAfterIfRateLimited(w, resp) + retryAfter := w.Header().Get("Retry-After") + assert.NotEmpty(t, retryAfter, "Retry-After should be set on 429") + secs, err := strconv.Atoi(retryAfter) + assert.NoError(t, err) + assert.Greater(t, secs, 0, "Retry-After should be positive") + }) + + t.Run("X-RateLimit-Remaining 0 injects Retry-After", func(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() + future := time.Now().Add(60 * time.Second) + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "X-Ratelimit-Remaining": []string{"0"}, + "X-Ratelimit-Reset": []string{strconv.FormatInt(future.Unix(), 10)}, + }, + } + injectRetryAfterIfRateLimited(w, resp) + assert.NotEmpty(t, w.Header().Get("Retry-After"), "Retry-After should be set when remaining=0") + }) + + t.Run("non-zero remaining does not inject Retry-After", func(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "X-Ratelimit-Remaining": []string{"100"}, + }, + } + injectRetryAfterIfRateLimited(w, resp) + assert.Empty(t, w.Header().Get("Retry-After")) + }) + + t.Run("200 with no rate-limit headers does not inject Retry-After", func(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + } + injectRetryAfterIfRateLimited(w, resp) + assert.Empty(t, w.Header().Get("Retry-After")) + }) + + t.Run("429 without reset header uses default delay", func(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: make(http.Header), + } + injectRetryAfterIfRateLimited(w, resp) + retryAfter := w.Header().Get("Retry-After") + assert.Equal(t, "60", retryAfter, "default delay should be 60 seconds") + }) +} + +// TestParseRateLimitReset verifies the Unix-timestamp header parser. +func TestParseRateLimitReset(t *testing.T) { + t.Parallel() + + t.Run("empty string returns zero", func(t *testing.T) { + t.Parallel() + assert.True(t, parseRateLimitReset("").IsZero()) + }) + + t.Run("invalid string returns zero", func(t *testing.T) { + t.Parallel() + assert.True(t, parseRateLimitReset("not-a-number").IsZero()) + }) + + t.Run("valid unix timestamp parses correctly", func(t *testing.T) { + t.Parallel() + ts := time.Now().Add(60 * time.Second) + got := parseRateLimitReset(strconv.FormatInt(ts.Unix(), 10)) + assert.False(t, got.IsZero()) + assert.Equal(t, ts.Unix(), got.Unix()) + }) +} + +// TestComputeRetryAfter verifies the retry-after calculation. +func TestComputeRetryAfter(t *testing.T) { + t.Parallel() + + t.Run("zero time returns default", func(t *testing.T) { + t.Parallel() + assert.Equal(t, 60, computeRetryAfter(time.Time{})) + }) + + t.Run("past time returns default", func(t *testing.T) { + t.Parallel() + assert.Equal(t, 60, computeRetryAfter(time.Now().Add(-time.Minute))) + }) + + t.Run("future time returns seconds until reset", func(t *testing.T) { + t.Parallel() + future := time.Now().Add(30 * time.Second) + secs := computeRetryAfter(future) + // Allow ±2s for timing jitter. + assert.GreaterOrEqual(t, secs, 29) + assert.LessOrEqual(t, secs, 32) + }) + + t.Run("very far future is clamped to max", func(t *testing.T) { + t.Parallel() + farFuture := time.Now().Add(24 * time.Hour) + assert.Equal(t, 3600, computeRetryAfter(farFuture)) + }) +} diff --git a/internal/server/circuit_breaker.go b/internal/server/circuit_breaker.go new file mode 100644 index 00000000..2cf2f6ba --- /dev/null +++ b/internal/server/circuit_breaker.go @@ -0,0 +1,296 @@ +package server + +import ( + "fmt" + "strconv" + "strings" + "sync" + "time" + + "github.com/github/gh-aw-mcpg/internal/logger" +) + +// circuitBreakerState represents the state of a circuit breaker. +type circuitBreakerState int + +const ( + // circuitClosed is normal operation — requests pass through. + circuitClosed circuitBreakerState = iota + // circuitOpen means the circuit is tripped — requests are rejected immediately. + circuitOpen + // circuitHalfOpen means one probe request is allowed to test recovery. + circuitHalfOpen +) + +func (s circuitBreakerState) String() string { + switch s { + case circuitClosed: + return "CLOSED" + case circuitOpen: + return "OPEN" + case circuitHalfOpen: + return "HALF-OPEN" + default: + return "UNKNOWN" + } +} + +// DefaultRateLimitThreshold is the number of consecutive rate-limit errors +// before the circuit breaker opens. +const DefaultRateLimitThreshold = 3 + +// DefaultRateLimitCooldown is the number of seconds the circuit stays OPEN +// before transitioning to HALF-OPEN to probe one request. +const DefaultRateLimitCooldown = 60 * time.Second + +var logCircuitBreaker = logger.New("server:circuit_breaker") + +// circuitBreaker implements a per-backend rate-limit circuit breaker. +// +// State transitions: +// +// CLOSED → OPEN : after threshold consecutive rate-limit errors +// OPEN → HALF-OPEN : after cooldown period elapses +// HALF-OPEN → CLOSED : probe request succeeds +// HALF-OPEN → OPEN : probe request is rate-limited again +type circuitBreaker struct { + mu sync.Mutex + + state circuitBreakerState + consecutiveErrors int + openedAt time.Time + // resetAt is the time when the upstream rate limit resets, parsed from + // the X-RateLimit-Reset header or the tool response message. + resetAt time.Time + serverID string + + threshold int + cooldown time.Duration +} + +// newCircuitBreaker creates a circuit breaker for the given server ID. +// threshold is the number of consecutive rate-limit errors before opening; +// cooldown is how long to stay OPEN before probing. +func newCircuitBreaker(serverID string, threshold int, cooldown time.Duration) *circuitBreaker { + if threshold <= 0 { + threshold = DefaultRateLimitThreshold + } + if cooldown <= 0 { + cooldown = DefaultRateLimitCooldown + } + return &circuitBreaker{ + serverID: serverID, + state: circuitClosed, + threshold: threshold, + cooldown: cooldown, + } +} + +// ErrCircuitOpen is returned when the circuit breaker is OPEN and a request is rejected. +type ErrCircuitOpen struct { + ServerID string + ResetAt time.Time +} + +func (e *ErrCircuitOpen) Error() string { + if e.ResetAt.IsZero() { + return fmt.Sprintf("rate limit circuit breaker is OPEN for server %q — requests temporarily rejected", e.ServerID) + } + return fmt.Sprintf("rate limit circuit breaker is OPEN for server %q — rate limit resets at %s (retry after %s)", + e.ServerID, e.ResetAt.UTC().Format(time.RFC3339), time.Until(e.ResetAt).Round(time.Second)) +} + +// Allow reports whether a request should be allowed through. It also handles +// the OPEN → HALF-OPEN transition when the cooldown has elapsed. +// Returns an *ErrCircuitOpen error when the circuit is OPEN. +func (cb *circuitBreaker) Allow() error { + cb.mu.Lock() + defer cb.mu.Unlock() + + switch cb.state { + case circuitClosed: + return nil + + case circuitOpen: + // Check whether we should transition to HALF-OPEN. + // We use the upstream reset time when available, otherwise the cooldown. + now := time.Now() + var openUntil time.Time + if !cb.resetAt.IsZero() && cb.resetAt.After(cb.openedAt) { + openUntil = cb.resetAt + } else { + openUntil = cb.openedAt.Add(cb.cooldown) + } + if now.After(openUntil) { + logCircuitBreaker.Printf("server %q circuit breaker OPEN → HALF-OPEN after cooldown", cb.serverID) + logger.LogInfo("backend", "circuit breaker for server %q transitioning OPEN → HALF-OPEN", cb.serverID) + cb.state = circuitHalfOpen + return nil // allow the probe + } + return &ErrCircuitOpen{ServerID: cb.serverID, ResetAt: cb.resetAt} + + case circuitHalfOpen: + // One probe is allowed through; further requests are blocked. + return nil + } + + return nil +} + +// RecordSuccess records a successful (non-rate-limited) response. +// In HALF-OPEN state this closes the circuit. +func (cb *circuitBreaker) RecordSuccess() { + cb.mu.Lock() + defer cb.mu.Unlock() + + prev := cb.state + cb.consecutiveErrors = 0 + if cb.state == circuitHalfOpen { + cb.state = circuitClosed + cb.resetAt = time.Time{} + logCircuitBreaker.Printf("server %q circuit breaker HALF-OPEN → CLOSED (probe succeeded)", cb.serverID) + logger.LogInfo("backend", "circuit breaker for server %q recovered: HALF-OPEN → CLOSED", cb.serverID) + } else if prev != circuitClosed { + cb.state = circuitClosed + } +} + +// RecordRateLimit records a rate-limit error for the given server. +// resetAt is the time the upstream rate limit resets (may be zero if unknown). +// When the consecutive error count reaches threshold the circuit opens. +func (cb *circuitBreaker) RecordRateLimit(resetAt time.Time) { + cb.mu.Lock() + defer cb.mu.Unlock() + + cb.consecutiveErrors++ + if !resetAt.IsZero() { + cb.resetAt = resetAt + } + + switch cb.state { + case circuitClosed: + if cb.consecutiveErrors >= cb.threshold { + cb.state = circuitOpen + cb.openedAt = time.Now() + logger.LogError("backend", + "circuit breaker for server %q OPENED after %d consecutive rate-limit errors; resets at %s", + cb.serverID, cb.consecutiveErrors, formatResetAt(cb.resetAt)) + logCircuitBreaker.Printf("server %q circuit breaker CLOSED → OPEN (errors=%d)", cb.serverID, cb.consecutiveErrors) + } else { + logger.LogWarn("backend", + "rate-limit error for server %q (consecutive=%d/%d); resets at %s", + cb.serverID, cb.consecutiveErrors, cb.threshold, formatResetAt(cb.resetAt)) + } + + case circuitHalfOpen: + // Probe failed — re-open the circuit. + cb.state = circuitOpen + cb.openedAt = time.Now() + logger.LogError("backend", + "circuit breaker for server %q re-OPENED after probe was rate-limited; resets at %s", + cb.serverID, formatResetAt(cb.resetAt)) + logCircuitBreaker.Printf("server %q circuit breaker HALF-OPEN → OPEN (probe rate-limited)", cb.serverID) + + case circuitOpen: + // Already open — update reset time. + logger.LogWarn("backend", "server %q circuit breaker still OPEN; resets at %s", + cb.serverID, formatResetAt(cb.resetAt)) + } +} + +// State returns the current circuit breaker state (for observability). +func (cb *circuitBreaker) State() circuitBreakerState { + cb.mu.Lock() + defer cb.mu.Unlock() + return cb.state +} + +// formatResetAt returns a human-readable representation of a reset time. +func formatResetAt(t time.Time) string { + if t.IsZero() { + return "unknown" + } + return fmt.Sprintf("%s (in %s)", t.UTC().Format(time.RFC3339), time.Until(t).Round(time.Second)) +} + +// isRateLimitToolResult reports whether a raw tool call result indicates +// a rate-limit error from the GitHub MCP server. It inspects the `isError` +// flag and the text content for well-known rate-limit phrases. +// +// The GitHub MCP server returns rate-limit errors as: +// +// {"content":[{"type":"text","text":"... 403 API rate limit exceeded ..."}],"isError":true} +func isRateLimitToolResult(result interface{}) (bool, time.Time) { + m, ok := result.(map[string]interface{}) + if !ok { + return false, time.Time{} + } + + // Only inspect error results. + isErr, _ := m["isError"].(bool) + if !isErr { + return false, time.Time{} + } + + contents, _ := m["content"].([]interface{}) + for _, c := range contents { + cm, ok := c.(map[string]interface{}) + if !ok { + continue + } + text, _ := cm["text"].(string) + if isRateLimitText(text) { + resetAt := parseRateLimitResetFromText(text) + return true, resetAt + } + } + return false, time.Time{} +} + +// isRateLimitText returns true when the message indicates a GitHub rate-limit error. +func isRateLimitText(text string) bool { + lower := strings.ToLower(text) + return strings.Contains(lower, "rate limit exceeded") || + strings.Contains(lower, "rate limit") && strings.Contains(lower, "403") || + strings.Contains(lower, "api rate limit") || + strings.Contains(lower, "secondary rate limit") || + strings.Contains(lower, "too many requests") +} + +// parseRateLimitResetFromText attempts to extract a reset timestamp from the +// rate-limit error text. The GitHub MCP server includes messages like +// "API rate limit exceeded [rate reset in 42s]". +// Returns zero time when the value cannot be parsed or is 0 seconds. +func parseRateLimitResetFromText(text string) time.Time { + // Look for "[rate reset in Ns]" pattern. + lower := strings.ToLower(text) + idx := strings.Index(lower, "rate reset in ") + if idx < 0 { + return time.Time{} + } + rest := text[idx+len("rate reset in "):] + // Find the first non-digit character. + end := strings.IndexAny(rest, "s])") + if end < 0 { + return time.Time{} + } + secs, err := strconv.ParseInt(strings.TrimSpace(rest[:end]), 10, 64) + if err != nil || secs <= 0 { + return time.Time{} + } + return time.Now().Add(time.Duration(secs) * time.Second) +} + +// parseRateLimitResetHeader parses the Unix-timestamp value of the +// X-RateLimit-Reset HTTP header into a time.Time. +// Returns zero time when the header is absent or malformed. +func parseRateLimitResetHeader(value string) time.Time { + if value == "" { + return time.Time{} + } + unix, err := strconv.ParseInt(strings.TrimSpace(value), 10, 64) + if err != nil { + return time.Time{} + } + return time.Unix(unix, 0) +} diff --git a/internal/server/circuit_breaker_test.go b/internal/server/circuit_breaker_test.go new file mode 100644 index 00000000..0292cd44 --- /dev/null +++ b/internal/server/circuit_breaker_test.go @@ -0,0 +1,364 @@ +package server + +import ( + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestCircuitBreaker_InitialStateClosed verifies new circuit breakers start CLOSED. +func TestCircuitBreaker_InitialStateClosed(t *testing.T) { + t.Parallel() + cb := newCircuitBreaker("test", 3, 60*time.Second) + assert.Equal(t, circuitClosed, cb.State()) + assert.NoError(t, cb.Allow()) +} + +// TestCircuitBreaker_OpensAfterThreshold verifies the circuit opens after N consecutive errors. +func TestCircuitBreaker_OpensAfterThreshold(t *testing.T) { + t.Parallel() + cb := newCircuitBreaker("test", 3, 60*time.Second) + + cb.RecordRateLimit(time.Time{}) + assert.Equal(t, circuitClosed, cb.State(), "should remain CLOSED after 1 error") + assert.NoError(t, cb.Allow()) + + cb.RecordRateLimit(time.Time{}) + assert.Equal(t, circuitClosed, cb.State(), "should remain CLOSED after 2 errors") + assert.NoError(t, cb.Allow()) + + cb.RecordRateLimit(time.Time{}) + assert.Equal(t, circuitOpen, cb.State(), "should be OPEN after 3 errors (threshold)") + + err := cb.Allow() + require.Error(t, err, "OPEN circuit should reject requests") + var openErr *ErrCircuitOpen + require.ErrorAs(t, err, &openErr) + assert.Equal(t, "test", openErr.ServerID) +} + +// TestCircuitBreaker_SuccessResetsCounter verifies that a success resets the consecutive-error counter. +func TestCircuitBreaker_SuccessResetsCounter(t *testing.T) { + t.Parallel() + cb := newCircuitBreaker("test", 3, 60*time.Second) + + cb.RecordRateLimit(time.Time{}) + cb.RecordRateLimit(time.Time{}) + assert.Equal(t, circuitClosed, cb.State(), "still CLOSED after 2 errors") + + cb.RecordSuccess() + assert.Equal(t, circuitClosed, cb.State(), "still CLOSED after success") + + // After a success the counter resets, so 2 more errors should NOT open the circuit. + cb.RecordRateLimit(time.Time{}) + cb.RecordRateLimit(time.Time{}) + assert.Equal(t, circuitClosed, cb.State(), "should be CLOSED (counter reset by success)") +} + +// TestCircuitBreaker_HalfOpenAfterCooldown verifies OPEN → HALF-OPEN transition. +func TestCircuitBreaker_HalfOpenAfterCooldown(t *testing.T) { + t.Parallel() + // Use a very short cooldown so the test doesn't sleep long. + cb := newCircuitBreaker("test", 1, 10*time.Millisecond) + + cb.RecordRateLimit(time.Time{}) + require.Equal(t, circuitOpen, cb.State(), "should be OPEN after 1 error") + + // Wait for cooldown. + time.Sleep(20 * time.Millisecond) + + err := cb.Allow() + assert.NoError(t, err, "should allow probe after cooldown") + assert.Equal(t, circuitHalfOpen, cb.State(), "should be HALF-OPEN after cooldown") +} + +// TestCircuitBreaker_HalfOpenClosesOnSuccess verifies HALF-OPEN → CLOSED on probe success. +func TestCircuitBreaker_HalfOpenClosesOnSuccess(t *testing.T) { + t.Parallel() + cb := newCircuitBreaker("test", 1, 10*time.Millisecond) + + cb.RecordRateLimit(time.Time{}) + require.Equal(t, circuitOpen, cb.State()) + + time.Sleep(20 * time.Millisecond) + require.NoError(t, cb.Allow()) // probe allowed + + cb.RecordSuccess() + assert.Equal(t, circuitClosed, cb.State(), "should be CLOSED after probe success") + assert.NoError(t, cb.Allow(), "CLOSED circuit should allow requests") +} + +// TestCircuitBreaker_HalfOpenReOpensOnRateLimit verifies HALF-OPEN → OPEN on probe failure. +func TestCircuitBreaker_HalfOpenReOpensOnRateLimit(t *testing.T) { + t.Parallel() + cb := newCircuitBreaker("test", 1, 10*time.Millisecond) + + cb.RecordRateLimit(time.Time{}) + require.Equal(t, circuitOpen, cb.State()) + + time.Sleep(20 * time.Millisecond) + require.NoError(t, cb.Allow()) // probe allowed + + cb.RecordRateLimit(time.Time{}) + assert.Equal(t, circuitOpen, cb.State(), "should be OPEN again after probe is rate-limited") + + err := cb.Allow() + require.Error(t, err) + var openErr *ErrCircuitOpen + require.ErrorAs(t, err, &openErr) +} + +// TestCircuitBreaker_ResetAtFromHeader verifies the reset time from upstream is used. +func TestCircuitBreaker_ResetAtFromHeader(t *testing.T) { + t.Parallel() + cb := newCircuitBreaker("test", 1, 60*time.Second) + future := time.Now().Add(5 * time.Millisecond) + cb.RecordRateLimit(future) + require.Equal(t, circuitOpen, cb.State()) + + // Before the reset time: still OPEN. + require.Error(t, cb.Allow()) + + // After the reset time: transitions to HALF-OPEN. + time.Sleep(10 * time.Millisecond) + err := cb.Allow() + assert.NoError(t, err, "should allow probe after reset time") + assert.Equal(t, circuitHalfOpen, cb.State()) +} + +// TestCircuitBreaker_DefaultsApplied verifies zero-value config gets sensible defaults. +func TestCircuitBreaker_DefaultsApplied(t *testing.T) { + t.Parallel() + cb := newCircuitBreaker("test", 0, 0) + assert.Equal(t, DefaultRateLimitThreshold, cb.threshold) + assert.Equal(t, DefaultRateLimitCooldown, cb.cooldown) +} + +// TestCircuitBreaker_ErrOpenMessage verifies ErrCircuitOpen.Error() content. +func TestCircuitBreaker_ErrOpenMessage(t *testing.T) { + t.Parallel() + + t.Run("no reset time", func(t *testing.T) { + t.Parallel() + err := &ErrCircuitOpen{ServerID: "github"} + assert.Contains(t, err.Error(), "github") + assert.Contains(t, err.Error(), "OPEN") + }) + + t.Run("with reset time", func(t *testing.T) { + t.Parallel() + reset := time.Now().Add(30 * time.Second) + err := &ErrCircuitOpen{ServerID: "github", ResetAt: reset} + assert.Contains(t, err.Error(), "github") + assert.Contains(t, err.Error(), "retry after") + }) +} + +// TestIsRateLimitToolResult verifies rate-limit detection from GitHub MCP tool results. +func TestIsRateLimitToolResult(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + result interface{} + wantHit bool + wantReset bool // whether a non-zero reset time is expected + }{ + { + name: "standard rate limit exceeded message", + result: map[string]interface{}{ + "isError": true, + "content": []interface{}{ + map[string]interface{}{ + "type": "text", + "text": "failed to search repositories: 403 API rate limit exceeded [rate reset in 42s]", + }, + }, + }, + wantHit: true, + wantReset: true, + }, + { + name: "secondary rate limit", + result: map[string]interface{}{ + "isError": true, + "content": []interface{}{ + map[string]interface{}{ + "type": "text", + "text": "secondary rate limit triggered", + }, + }, + }, + wantHit: true, + }, + { + name: "too many requests", + result: map[string]interface{}{ + "isError": true, + "content": []interface{}{ + map[string]interface{}{ + "type": "text", + "text": "too many requests", + }, + }, + }, + wantHit: true, + }, + { + name: "non-rate-limit error", + result: map[string]interface{}{ + "isError": true, + "content": []interface{}{ + map[string]interface{}{ + "type": "text", + "text": "repository not found", + }, + }, + }, + wantHit: false, + }, + { + name: "successful result (isError false)", + result: map[string]interface{}{ + "isError": false, + "content": []interface{}{ + map[string]interface{}{ + "type": "text", + "text": "API rate limit exceeded but isError is false", + }, + }, + }, + wantHit: false, + }, + { + name: "nil result", + result: nil, + wantHit: false, + }, + { + name: "non-map result", + result: "some string", + wantHit: false, + }, + { + name: "rate reset in 0s", + result: map[string]interface{}{ + "isError": true, + "content": []interface{}{ + map[string]interface{}{ + "type": "text", + "text": "API rate limit exceeded [rate reset in 0s]", + }, + }, + }, + wantHit: true, + wantReset: false, // 0s means no future time + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + hit, resetAt := isRateLimitToolResult(tt.result) + assert.Equal(t, tt.wantHit, hit, "isRateLimitToolResult mismatch") + if tt.wantReset { + assert.False(t, resetAt.IsZero(), "expected non-zero resetAt") + } + }) + } +} + +// TestParseRateLimitResetFromText verifies reset time parsing from error messages. +func TestParseRateLimitResetFromText(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + text string + wantZero bool + }{ + { + name: "42 seconds", + text: "rate limit exceeded [rate reset in 42s]", + wantZero: false, + }, + { + name: "0 seconds gives zero time", + text: "rate limit exceeded [rate reset in 0s]", + wantZero: true, + }, + { + name: "no pattern", + text: "some other error", + wantZero: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := parseRateLimitResetFromText(tt.text) + if tt.wantZero { + assert.True(t, got.IsZero(), "expected zero time, got %v", got) + } else { + assert.False(t, got.IsZero(), "expected non-zero time") + assert.True(t, got.After(time.Now()), "expected future time") + } + }) + } +} + +// TestParseRateLimitResetHeader verifies the Unix-timestamp header parsing. +func TestParseRateLimitResetHeader(t *testing.T) { + t.Parallel() + + now := time.Now() + future := now.Add(60 * time.Second) + + tests := []struct { + name string + value string + wantZero bool + wantTime time.Time + }{ + { + name: "empty", + value: "", + wantZero: true, + }, + { + name: "invalid", + value: "not-a-number", + wantZero: true, + }, + { + name: "valid unix timestamp", + value: "1000000000", + wantZero: false, + wantTime: time.Unix(1000000000, 0), + }, + { + name: "future timestamp", + value: strconv.FormatInt(future.Unix(), 10), + wantZero: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := parseRateLimitResetHeader(tt.value) + if tt.wantZero { + assert.True(t, got.IsZero(), "expected zero time") + } else { + assert.False(t, got.IsZero(), "expected non-zero time") + if !tt.wantTime.IsZero() { + assert.Equal(t, tt.wantTime.Unix(), got.Unix()) + } + } + }) + } +} diff --git a/internal/server/unified.go b/internal/server/unified.go index ce433ecb..fbea042d 100644 --- a/internal/server/unified.go +++ b/internal/server/unified.go @@ -94,6 +94,9 @@ type UnifiedServer struct { // means all tools are permitted for that server. allowedToolSets map[string]map[string]bool + // circuitBreakers holds a per-backend rate-limit circuit breaker keyed by server ID. + circuitBreakers map[string]*circuitBreaker + // DIFC components guardRegistry *guard.Registry agentRegistry *difc.AgentRegistry @@ -149,6 +152,7 @@ func NewUnified(ctx context.Context, cfg *config.Config) (*UnifiedServer, error) payloadPathPrefix: payloadPathPrefix, payloadSizeThreshold: payloadSizeThreshold, allowedToolSets: buildAllowedToolSets(cfg), + circuitBreakers: buildCircuitBreakers(cfg), // Initialize DIFC components guardRegistry: guard.NewRegistry(), @@ -363,7 +367,33 @@ func newErrorCallToolResult(err error) (*sdk.CallToolResult, interface{}, error) }, nil, err } -// buildAllowedToolSets converts the per-server Tools lists from the config into pre-computed +// buildCircuitBreakers creates per-backend circuit breakers from the configuration. +func buildCircuitBreakers(cfg *config.Config) map[string]*circuitBreaker { + cbs := make(map[string]*circuitBreaker) + if cfg == nil { + return cbs + } + for serverID, serverCfg := range cfg.Servers { + threshold := serverCfg.RateLimitThreshold + cooldown := time.Duration(serverCfg.RateLimitCooldown) * time.Second + cbs[serverID] = newCircuitBreaker(serverID, threshold, cooldown) + logUnified.Printf("Created circuit breaker for server %s: threshold=%d, cooldown=%s", + serverID, threshold, cooldown) + } + return cbs +} + +// getCircuitBreaker returns the circuit breaker for serverID, creating one with +// defaults if none exists (e.g., when called from tests that bypass NewUnified). +func (us *UnifiedServer) getCircuitBreaker(serverID string) *circuitBreaker { + if cb, ok := us.circuitBreakers[serverID]; ok { + return cb + } + cb := newCircuitBreaker(serverID, DefaultRateLimitThreshold, DefaultRateLimitCooldown) + us.circuitBreakers[serverID] = cb + return cb +} + // map[string]bool sets for O(1) lookup. Servers with no Tools list are not added to the map, // which signals that all tools are permitted. If the Tools list contains a "*" entry anywhere, // the server is treated the same as having no list (all tools allowed). @@ -532,14 +562,36 @@ func (us *UnifiedServer) callBackendTool(ctx context.Context, serverID, toolName oteltrace.WithSpanKind(oteltrace.SpanKindClient), ) defer execSpan.End() + + // Check the circuit breaker before calling the backend. + cb := us.getCircuitBreaker(serverID) + if err := cb.Allow(); err != nil { + execSpan.RecordError(err) + execSpan.SetStatus(codes.Error, "circuit breaker open") + httpStatusCode = 429 + return newErrorCallToolResult(err) + } + backendResult, err := executeBackendToolCall(execCtx, us.launcher, serverID, sessionID, toolName, args) if err != nil { + cb.RecordSuccess() // transport error ≠ rate limit; reset consecutive counter execSpan.RecordError(err) execSpan.SetStatus(codes.Error, err.Error()) httpStatusCode = 500 return newErrorCallToolResult(err) } + // Inspect the tool result for rate-limit indicators from the GitHub MCP server. + if rateLimited, resetAt := isRateLimitToolResult(backendResult); rateLimited { + cb.RecordRateLimit(resetAt) + execSpan.SetAttributes(attribute.Bool("rate_limit.hit", true)) + httpStatusCode = 429 + // Return the original error message so the agent can see it. + return newErrorCallToolResult(fmt.Errorf("backend server %q rate-limited: %w", + serverID, &ErrCircuitOpen{ServerID: serverID, ResetAt: resetAt})) + } + cb.RecordSuccess() + // **Phase 4: Guard labels the response data (for fine-grained filtering)** // Per spec: LabelResponse() is only called for read operations in all modes, // and for read-write operations in filter/propagate modes. From 4b0ba32295fbab3defb1ff836804df0b562869eb Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 14 Apr 2026 19:11:20 +0000 Subject: [PATCH 3/4] Address code review: fix transport error handling, header casing, import grouping Agent-Logs-Url: https://github.com/github/gh-aw-mcpg/sessions/29c289a0-52c0-4c27-98f3-8eef3b574559 Co-authored-by: lpcox <15877973+lpcox@users.noreply.github.com> --- internal/config/rate_limit_config_test.go | 3 ++- internal/proxy/handler.go | 10 ++++++---- internal/proxy/rate_limit_test.go | 2 +- internal/server/unified.go | 4 +++- 4 files changed, 12 insertions(+), 7 deletions(-) diff --git a/internal/config/rate_limit_config_test.go b/internal/config/rate_limit_config_test.go index 1dbdfe5f..ac753016 100644 --- a/internal/config/rate_limit_config_test.go +++ b/internal/config/rate_limit_config_test.go @@ -1,9 +1,10 @@ package config import ( + "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "testing" ) func TestServerConfig_RateLimitFields(t *testing.T) { diff --git a/internal/proxy/handler.go b/internal/proxy/handler.go index 136091a7..e09d0424 100644 --- a/internal/proxy/handler.go +++ b/internal/proxy/handler.go @@ -422,13 +422,15 @@ func copyResponseHeaders(w http.ResponseWriter, resp *http.Response) { } // injectRetryAfterIfRateLimited inspects the upstream response for rate-limit signals -// (HTTP 429 or X-RateLimit-Remaining == 0). When detected it: +// (HTTP 429 or X-Ratelimit-Remaining == 0). When detected it: // 1. Injects a Retry-After header so the client knows when to retry. // 2. Logs the event at ERROR level so operators can monitor rate-limit incidents. func injectRetryAfterIfRateLimited(w http.ResponseWriter, resp *http.Response) { is429 := resp.StatusCode == http.StatusTooManyRequests - remaining := resp.Header.Get("X-RateLimit-Remaining") - resetHeader := resp.Header.Get("X-RateLimit-Reset") + // Use Go's canonical header key form (textproto.CanonicalMIMEHeaderKey produces + // "X-Ratelimit-Remaining", matching GitHub's actual response headers). + remaining := resp.Header.Get("X-Ratelimit-Remaining") + resetHeader := resp.Header.Get("X-Ratelimit-Reset") isRateLimited := is429 || remaining == "0" if !isRateLimited { @@ -441,7 +443,7 @@ func injectRetryAfterIfRateLimited(w http.ResponseWriter, resp *http.Response) { w.Header().Set("Retry-After", strconv.Itoa(retryAfter)) logger.LogError("client", - "upstream rate limit hit: status=%d X-RateLimit-Remaining=%s X-RateLimit-Reset=%s retry-after=%ds", + "upstream rate limit hit: status=%d X-Ratelimit-Remaining=%s X-Ratelimit-Reset=%s retry-after=%ds", resp.StatusCode, remaining, resetHeader, retryAfter) } diff --git a/internal/proxy/rate_limit_test.go b/internal/proxy/rate_limit_test.go index d020a13c..b05a19d1 100644 --- a/internal/proxy/rate_limit_test.go +++ b/internal/proxy/rate_limit_test.go @@ -33,7 +33,7 @@ func TestInjectRetryAfterIfRateLimited(t *testing.T) { assert.Greater(t, secs, 0, "Retry-After should be positive") }) - t.Run("X-RateLimit-Remaining 0 injects Retry-After", func(t *testing.T) { + t.Run("X-Ratelimit-Remaining 0 injects Retry-After", func(t *testing.T) { t.Parallel() w := httptest.NewRecorder() future := time.Now().Add(60 * time.Second) diff --git a/internal/server/unified.go b/internal/server/unified.go index fbea042d..787bd1fc 100644 --- a/internal/server/unified.go +++ b/internal/server/unified.go @@ -574,7 +574,9 @@ func (us *UnifiedServer) callBackendTool(ctx context.Context, serverID, toolName backendResult, err := executeBackendToolCall(execCtx, us.launcher, serverID, sessionID, toolName, args) if err != nil { - cb.RecordSuccess() // transport error ≠ rate limit; reset consecutive counter + // Transport errors (connection failure, JSON parse error, etc.) are not rate-limit + // events and must not affect the consecutive rate-limit counter. Leave the circuit + // breaker state unchanged so genuine rate-limit history is preserved. execSpan.RecordError(err) execSpan.SetStatus(codes.Error, err.Error()) httpStatusCode = 500 From fea1da42d3fb3e23d97c43a4f9016f5cc5a54dea Mon Sep 17 00:00:00 2001 From: Landon Cox Date: Tue, 14 Apr 2026 12:58:53 -0700 Subject: [PATCH 4/4] review: enforce single HALF-OPEN probe, injectable clock, preserve backend errors Address all 5 existing review comments plus 2 additional findings: 1. HALF-OPEN probe tracking: add probeInFlight flag so only one probe passes in HALF-OPEN state; concurrent requests get ErrCircuitOpen. Add TestCircuitBreaker_HalfOpenBlocksConcurrentProbes. 2. Nil map guard: getCircuitBreaker initializes circuitBreakers map when nil, preventing panic in test constructors that bypass NewUnified. 3. Preserve backend error text: rate-limit detection now returns the original upstream error message via extractRateLimitErrorText instead of wrapping in ErrCircuitOpen. ErrCircuitOpen is only returned when cb.Allow() rejects the call before contacting the backend. 4. TOML-only doc comments: clarify that rate_limit_threshold and rate_limit_cooldown are not wired for stdin JSON config. 5. Injectable clock: replace time.Now() in circuit breaker with nowFunc field (default time.Now). Tests use deterministic fake time instead of flaky time.Sleep-based assertions. 6. Operator precedence: add explicit parentheses in isRateLimitText for the compound rate limit + 403 condition. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- internal/config/config_core.go | 2 + internal/server/circuit_breaker.go | 51 +++++++++++-- internal/server/circuit_breaker_test.go | 97 +++++++++++++++++++++---- internal/server/unified.go | 14 +++- 4 files changed, 140 insertions(+), 24 deletions(-) diff --git a/internal/config/config_core.go b/internal/config/config_core.go index 337a9ea2..fc715def 100644 --- a/internal/config/config_core.go +++ b/internal/config/config_core.go @@ -219,11 +219,13 @@ type ServerConfig struct { // RateLimitThreshold is the number of consecutive rate-limit errors from this backend // that will trip the circuit breaker (transition CLOSED → OPEN). When OPEN, requests // are immediately rejected until the cooldown period elapses. Default: 3. + // Supported in file-based config (TOML/JSON); stdin JSON config does not currently accept this field. RateLimitThreshold int `toml:"rate_limit_threshold" json:"rate_limit_threshold,omitempty"` // RateLimitCooldown is the number of seconds the circuit breaker stays OPEN before // allowing a single probe request (transition OPEN → HALF-OPEN). If the probe // succeeds the circuit closes; if rate-limited again it re-opens. Default: 60. + // Supported in file-based config (TOML/JSON); stdin JSON config does not currently accept this field. RateLimitCooldown int `toml:"rate_limit_cooldown" json:"rate_limit_cooldown,omitempty"` } diff --git a/internal/server/circuit_breaker.go b/internal/server/circuit_breaker.go index 2cf2f6ba..4e67f8a3 100644 --- a/internal/server/circuit_breaker.go +++ b/internal/server/circuit_breaker.go @@ -61,11 +61,16 @@ type circuitBreaker struct { openedAt time.Time // resetAt is the time when the upstream rate limit resets, parsed from // the X-RateLimit-Reset header or the tool response message. - resetAt time.Time - serverID string + resetAt time.Time + probeInFlight bool + serverID string threshold int cooldown time.Duration + + // nowFunc returns the current time. Defaults to time.Now; overridden in tests + // to avoid flaky time.Sleep-based assertions. + nowFunc func() time.Time } // newCircuitBreaker creates a circuit breaker for the given server ID. @@ -83,6 +88,7 @@ func newCircuitBreaker(serverID string, threshold int, cooldown time.Duration) * state: circuitClosed, threshold: threshold, cooldown: cooldown, + nowFunc: time.Now, } } @@ -114,7 +120,7 @@ func (cb *circuitBreaker) Allow() error { case circuitOpen: // Check whether we should transition to HALF-OPEN. // We use the upstream reset time when available, otherwise the cooldown. - now := time.Now() + now := cb.nowFunc() var openUntil time.Time if !cb.resetAt.IsZero() && cb.resetAt.After(cb.openedAt) { openUntil = cb.resetAt @@ -125,12 +131,18 @@ func (cb *circuitBreaker) Allow() error { logCircuitBreaker.Printf("server %q circuit breaker OPEN → HALF-OPEN after cooldown", cb.serverID) logger.LogInfo("backend", "circuit breaker for server %q transitioning OPEN → HALF-OPEN", cb.serverID) cb.state = circuitHalfOpen - return nil // allow the probe + cb.probeInFlight = true + return nil // allow the single probe } return &ErrCircuitOpen{ServerID: cb.serverID, ResetAt: cb.resetAt} case circuitHalfOpen: - // One probe is allowed through; further requests are blocked. + // Only one probe is allowed; further requests are blocked until the probe resolves. + if cb.probeInFlight { + return &ErrCircuitOpen{ServerID: cb.serverID, ResetAt: cb.resetAt} + } + // This shouldn't normally happen (probe resolved but state wasn't updated), + // but allow through defensively. return nil } @@ -145,6 +157,7 @@ func (cb *circuitBreaker) RecordSuccess() { prev := cb.state cb.consecutiveErrors = 0 + cb.probeInFlight = false if cb.state == circuitHalfOpen { cb.state = circuitClosed cb.resetAt = time.Time{} @@ -163,6 +176,7 @@ func (cb *circuitBreaker) RecordRateLimit(resetAt time.Time) { defer cb.mu.Unlock() cb.consecutiveErrors++ + cb.probeInFlight = false if !resetAt.IsZero() { cb.resetAt = resetAt } @@ -171,7 +185,7 @@ func (cb *circuitBreaker) RecordRateLimit(resetAt time.Time) { case circuitClosed: if cb.consecutiveErrors >= cb.threshold { cb.state = circuitOpen - cb.openedAt = time.Now() + cb.openedAt = cb.nowFunc() logger.LogError("backend", "circuit breaker for server %q OPENED after %d consecutive rate-limit errors; resets at %s", cb.serverID, cb.consecutiveErrors, formatResetAt(cb.resetAt)) @@ -185,7 +199,7 @@ func (cb *circuitBreaker) RecordRateLimit(resetAt time.Time) { case circuitHalfOpen: // Probe failed — re-open the circuit. cb.state = circuitOpen - cb.openedAt = time.Now() + cb.openedAt = cb.nowFunc() logger.LogError("backend", "circuit breaker for server %q re-OPENED after probe was rate-limited; resets at %s", cb.serverID, formatResetAt(cb.resetAt)) @@ -213,6 +227,27 @@ func formatResetAt(t time.Time) string { return fmt.Sprintf("%s (in %s)", t.UTC().Format(time.RFC3339), time.Until(t).Round(time.Second)) } +// extractRateLimitErrorText extracts the text content from a raw tool result +// that has been identified as a rate-limit error. Returns the original backend +// message so agents see the actual upstream error rather than a synthetic one. +func extractRateLimitErrorText(result interface{}) string { + m, ok := result.(map[string]interface{}) + if !ok { + return "rate limit exceeded" + } + contents, _ := m["content"].([]interface{}) + for _, c := range contents { + cm, ok := c.(map[string]interface{}) + if !ok { + continue + } + if text, ok := cm["text"].(string); ok && text != "" { + return text + } + } + return "rate limit exceeded" +} + // isRateLimitToolResult reports whether a raw tool call result indicates // a rate-limit error from the GitHub MCP server. It inspects the `isError` // flag and the text content for well-known rate-limit phrases. @@ -251,7 +286,7 @@ func isRateLimitToolResult(result interface{}) (bool, time.Time) { func isRateLimitText(text string) bool { lower := strings.ToLower(text) return strings.Contains(lower, "rate limit exceeded") || - strings.Contains(lower, "rate limit") && strings.Contains(lower, "403") || + (strings.Contains(lower, "rate limit") && strings.Contains(lower, "403")) || strings.Contains(lower, "api rate limit") || strings.Contains(lower, "secondary rate limit") || strings.Contains(lower, "too many requests") diff --git a/internal/server/circuit_breaker_test.go b/internal/server/circuit_breaker_test.go index 0292cd44..f31c9270 100644 --- a/internal/server/circuit_breaker_test.go +++ b/internal/server/circuit_breaker_test.go @@ -61,15 +61,19 @@ func TestCircuitBreaker_SuccessResetsCounter(t *testing.T) { // TestCircuitBreaker_HalfOpenAfterCooldown verifies OPEN → HALF-OPEN transition. func TestCircuitBreaker_HalfOpenAfterCooldown(t *testing.T) { t.Parallel() - // Use a very short cooldown so the test doesn't sleep long. - cb := newCircuitBreaker("test", 1, 10*time.Millisecond) + fakeNow := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) + cb := newCircuitBreaker("test", 1, time.Minute) + cb.nowFunc = func() time.Time { return fakeNow } cb.RecordRateLimit(time.Time{}) require.Equal(t, circuitOpen, cb.State(), "should be OPEN after 1 error") - // Wait for cooldown. - time.Sleep(20 * time.Millisecond) + // Before cooldown: still OPEN. + fakeNow = fakeNow.Add(30 * time.Second) + require.Error(t, cb.Allow(), "should reject before cooldown elapses") + // After cooldown: transitions to HALF-OPEN. + fakeNow = fakeNow.Add(31 * time.Second) err := cb.Allow() assert.NoError(t, err, "should allow probe after cooldown") assert.Equal(t, circuitHalfOpen, cb.State(), "should be HALF-OPEN after cooldown") @@ -78,12 +82,14 @@ func TestCircuitBreaker_HalfOpenAfterCooldown(t *testing.T) { // TestCircuitBreaker_HalfOpenClosesOnSuccess verifies HALF-OPEN → CLOSED on probe success. func TestCircuitBreaker_HalfOpenClosesOnSuccess(t *testing.T) { t.Parallel() - cb := newCircuitBreaker("test", 1, 10*time.Millisecond) + fakeNow := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) + cb := newCircuitBreaker("test", 1, time.Minute) + cb.nowFunc = func() time.Time { return fakeNow } cb.RecordRateLimit(time.Time{}) require.Equal(t, circuitOpen, cb.State()) - time.Sleep(20 * time.Millisecond) + fakeNow = fakeNow.Add(2 * time.Minute) require.NoError(t, cb.Allow()) // probe allowed cb.RecordSuccess() @@ -94,12 +100,14 @@ func TestCircuitBreaker_HalfOpenClosesOnSuccess(t *testing.T) { // TestCircuitBreaker_HalfOpenReOpensOnRateLimit verifies HALF-OPEN → OPEN on probe failure. func TestCircuitBreaker_HalfOpenReOpensOnRateLimit(t *testing.T) { t.Parallel() - cb := newCircuitBreaker("test", 1, 10*time.Millisecond) + fakeNow := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) + cb := newCircuitBreaker("test", 1, time.Minute) + cb.nowFunc = func() time.Time { return fakeNow } cb.RecordRateLimit(time.Time{}) require.Equal(t, circuitOpen, cb.State()) - time.Sleep(20 * time.Millisecond) + fakeNow = fakeNow.Add(2 * time.Minute) require.NoError(t, cb.Allow()) // probe allowed cb.RecordRateLimit(time.Time{}) @@ -114,21 +122,54 @@ func TestCircuitBreaker_HalfOpenReOpensOnRateLimit(t *testing.T) { // TestCircuitBreaker_ResetAtFromHeader verifies the reset time from upstream is used. func TestCircuitBreaker_ResetAtFromHeader(t *testing.T) { t.Parallel() - cb := newCircuitBreaker("test", 1, 60*time.Second) - future := time.Now().Add(5 * time.Millisecond) - cb.RecordRateLimit(future) + fakeNow := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) + cb := newCircuitBreaker("test", 1, time.Hour) + cb.nowFunc = func() time.Time { return fakeNow } + + resetAt := fakeNow.Add(30 * time.Second) + cb.RecordRateLimit(resetAt) require.Equal(t, circuitOpen, cb.State()) // Before the reset time: still OPEN. + fakeNow = fakeNow.Add(15 * time.Second) require.Error(t, cb.Allow()) - // After the reset time: transitions to HALF-OPEN. - time.Sleep(10 * time.Millisecond) + // After the reset time: transitions to HALF-OPEN (before cooldown would elapse). + fakeNow = fakeNow.Add(20 * time.Second) err := cb.Allow() assert.NoError(t, err, "should allow probe after reset time") assert.Equal(t, circuitHalfOpen, cb.State()) } +// TestCircuitBreaker_HalfOpenBlocksConcurrentProbes verifies that only one probe is allowed in HALF-OPEN. +func TestCircuitBreaker_HalfOpenBlocksConcurrentProbes(t *testing.T) { + t.Parallel() + fakeNow := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) + cb := newCircuitBreaker("test", 1, time.Minute) + cb.nowFunc = func() time.Time { return fakeNow } + + cb.RecordRateLimit(time.Time{}) + require.Equal(t, circuitOpen, cb.State()) + + // Advance past cooldown to trigger HALF-OPEN. + fakeNow = fakeNow.Add(2 * time.Minute) + + // First Allow() should succeed (the probe). + require.NoError(t, cb.Allow()) + assert.Equal(t, circuitHalfOpen, cb.State()) + + // Second Allow() should be rejected — probe is already in flight. + err := cb.Allow() + require.Error(t, err, "concurrent requests in HALF-OPEN should be rejected") + var openErr *ErrCircuitOpen + require.ErrorAs(t, err, &openErr) + + // After the probe succeeds, requests should be allowed again. + cb.RecordSuccess() + assert.Equal(t, circuitClosed, cb.State()) + assert.NoError(t, cb.Allow()) +} + // TestCircuitBreaker_DefaultsApplied verifies zero-value config gets sensible defaults. func TestCircuitBreaker_DefaultsApplied(t *testing.T) { t.Parallel() @@ -362,3 +403,33 @@ func TestParseRateLimitResetHeader(t *testing.T) { }) } } + +// TestExtractRateLimitErrorText verifies extraction of error text from backend results. +func TestExtractRateLimitErrorText(t *testing.T) { + t.Parallel() + + t.Run("extracts text from standard rate-limit result", func(t *testing.T) { + t.Parallel() + result := map[string]interface{}{ + "isError": true, + "content": []interface{}{ + map[string]interface{}{ + "type": "text", + "text": "failed to search: 403 API rate limit exceeded [rate reset in 42s]", + }, + }, + } + assert.Equal(t, "failed to search: 403 API rate limit exceeded [rate reset in 42s]", extractRateLimitErrorText(result)) + }) + + t.Run("returns fallback for nil result", func(t *testing.T) { + t.Parallel() + assert.Equal(t, "rate limit exceeded", extractRateLimitErrorText(nil)) + }) + + t.Run("returns fallback for empty content", func(t *testing.T) { + t.Parallel() + result := map[string]interface{}{"isError": true, "content": []interface{}{}} + assert.Equal(t, "rate limit exceeded", extractRateLimitErrorText(result)) + }) +} diff --git a/internal/server/unified.go b/internal/server/unified.go index 787bd1fc..d0f86074 100644 --- a/internal/server/unified.go +++ b/internal/server/unified.go @@ -386,6 +386,9 @@ func buildCircuitBreakers(cfg *config.Config) map[string]*circuitBreaker { // getCircuitBreaker returns the circuit breaker for serverID, creating one with // defaults if none exists (e.g., when called from tests that bypass NewUnified). func (us *UnifiedServer) getCircuitBreaker(serverID string) *circuitBreaker { + if us.circuitBreakers == nil { + us.circuitBreakers = make(map[string]*circuitBreaker) + } if cb, ok := us.circuitBreakers[serverID]; ok { return cb } @@ -588,9 +591,14 @@ func (us *UnifiedServer) callBackendTool(ctx context.Context, serverID, toolName cb.RecordRateLimit(resetAt) execSpan.SetAttributes(attribute.Bool("rate_limit.hit", true)) httpStatusCode = 429 - // Return the original error message so the agent can see it. - return newErrorCallToolResult(fmt.Errorf("backend server %q rate-limited: %w", - serverID, &ErrCircuitOpen{ServerID: serverID, ResetAt: resetAt})) + // Preserve the original backend error text so the agent sees the actual upstream + // rate-limit details. ErrCircuitOpen is only returned when cb.Allow() rejects + // the call before contacting the backend. + errText := extractRateLimitErrorText(backendResult) + return &sdk.CallToolResult{ + Content: []sdk.Content{&sdk.TextContent{Text: errText}}, + IsError: true, + }, backendResult, nil } cb.RecordSuccess()