Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ func main() {
rate.Limit(conf.Configuration.CachedRateLimitPerSecond),
conf.Configuration.CachedRateLimitBurstLimit,
)
limiter.StartCleanup(5*time.Minute, 10*time.Minute)

loggedRouter := middleware.LoggingMiddleware(router)
corsHandler := c.Handler(loggedRouter)
Expand Down
47 changes: 42 additions & 5 deletions middleware/ratelimit.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
package middleware

import (
"golang.org/x/time/rate"
"math"
"sync"
"time"

"golang.org/x/time/rate"
)

// LimiterPair holds both normal and cached tier limiters for an IP
type LimiterPair struct {
Normal *rate.Limiter
Cached *rate.Limiter
Normal *rate.Limiter
Cached *rate.Limiter
lastSeen time.Time
}

// GetNormalTokens returns the number of tokens available in the normal tier
Expand Down Expand Up @@ -61,8 +64,9 @@ func (i *IPRateLimiter) AddIP(ip string) *LimiterPair {
defer i.mu.Unlock()

pair := &LimiterPair{
Normal: rate.NewLimiter(i.normalRate, i.normalBurst),
Cached: rate.NewLimiter(i.cachedRate, i.cachedBurst),
Normal: rate.NewLimiter(i.normalRate, i.normalBurst),
Cached: rate.NewLimiter(i.cachedRate, i.cachedBurst),
lastSeen: time.Now(),
}

i.ips[ip] = pair
Expand All @@ -79,7 +83,40 @@ func (i *IPRateLimiter) GetLimiter(ip string) *LimiterPair {
return i.AddIP(ip)
}

limiter.lastSeen = time.Now()
i.mu.Unlock()

return limiter
}

// StartCleanup launches a background goroutine that periodically removes
// IP entries that haven't been seen within the given idle timeout.
func (i *IPRateLimiter) StartCleanup(interval, idleTimeout time.Duration) {
go func() {
ticker := time.NewTicker(interval)
defer ticker.Stop()
for range ticker.C {
i.cleanup(idleTimeout)
}
}()
}

// cleanup removes IP entries that haven't been accessed within the idle timeout.
func (i *IPRateLimiter) cleanup(idleTimeout time.Duration) {
i.mu.Lock()
defer i.mu.Unlock()

cutoff := time.Now().Add(-idleTimeout)
for ip, pair := range i.ips {
if pair.lastSeen.Before(cutoff) {
delete(i.ips, ip)
}
}
}

// Len returns the number of tracked IPs (for testing).
func (i *IPRateLimiter) Len() int {
i.mu.RLock()
defer i.mu.RUnlock()
return len(i.ips)
}
70 changes: 70 additions & 0 deletions middleware/ratelimit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,76 @@ func TestLimiterPairTokens(t *testing.T) {
}
}

// TestCleanupEvictsStaleIPs tests that the cleanup method removes stale IPs.
func TestCleanupEvictsStaleIPs(t *testing.T) {
rl := NewIPRateLimiter(1, 5, 10, 20)

// Add two IPs
rl.GetLimiter("10.0.0.1")
rl.GetLimiter("10.0.0.2")

if rl.Len() != 2 {
t.Fatalf("Expected 2 IPs, got %d", rl.Len())
}

// Manually backdate one entry so it appears stale
rl.mu.Lock()
rl.ips["10.0.0.1"].lastSeen = time.Now().Add(-20 * time.Minute)
rl.mu.Unlock()

// Run cleanup with a 10-minute idle timeout
rl.cleanup(10 * time.Minute)

if rl.Len() != 1 {
t.Fatalf("Expected 1 IP after cleanup, got %d", rl.Len())
}

// The recent IP should still exist
rl.mu.RLock()
_, exists := rl.ips["10.0.0.2"]
rl.mu.RUnlock()
if !exists {
t.Error("Expected 10.0.0.2 to survive cleanup")
}
}

// TestCleanupKeepsActiveIPs tests that active IPs are not evicted.
func TestCleanupKeepsActiveIPs(t *testing.T) {
rl := NewIPRateLimiter(1, 5, 10, 20)

rl.GetLimiter("10.0.0.1")
rl.GetLimiter("10.0.0.2")

// Cleanup with a 10-minute timeout should not evict anything
rl.cleanup(10 * time.Minute)

if rl.Len() != 2 {
t.Fatalf("Expected 2 IPs after cleanup (all active), got %d", rl.Len())
}
}

// TestGetLimiterRefreshesLastSeen tests that GetLimiter updates lastSeen.
func TestGetLimiterRefreshesLastSeen(t *testing.T) {
rl := NewIPRateLimiter(1, 5, 10, 20)

rl.GetLimiter("10.0.0.1")

// Backdate it
rl.mu.Lock()
rl.ips["10.0.0.1"].lastSeen = time.Now().Add(-20 * time.Minute)
rl.mu.Unlock()

// Access it again — should refresh lastSeen
rl.GetLimiter("10.0.0.1")

// Now cleanup should NOT evict it
rl.cleanup(10 * time.Minute)

if rl.Len() != 1 {
t.Fatalf("Expected 1 IP (refreshed by GetLimiter), got %d", rl.Len())
}
}

// TestGetLimits tests the limit getter methods.
func TestGetLimits(t *testing.T) {
rl := NewIPRateLimiter(rate.Limit(2), 5, rate.Limit(10), 20)
Expand Down
45 changes: 41 additions & 4 deletions stats/stats.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ type Stats struct {

// User agent tracking
userAgentUsage sync.Map // map[string]*atomic.Int64
uniqueUACount atomic.Int64
uaMu sync.Mutex
}

// Global stats instance
Expand Down Expand Up @@ -99,7 +101,9 @@ func (s *Stats) RecordRequest(endpoint string) {
return s.requestTimes[i].After(cutoff)
})
if idx > 0 {
s.requestTimes = s.requestTimes[idx:]
remaining := make([]time.Time, len(s.requestTimes)-idx)
copy(remaining, s.requestTimes[idx:])
s.requestTimes = remaining
}
s.requestTimesMu.Unlock()
}
Expand Down Expand Up @@ -145,13 +149,46 @@ func (s *Stats) AccountUsageSnapshot() map[string]int64 {
return result
}

// RecordUserAgent records a request from a specific user agent
// maxUniqueUserAgents is the cap on distinct user agent strings tracked.
const maxUniqueUserAgents = 1000

// RecordUserAgent records a request from a specific user agent.
// After maxUniqueUserAgents distinct agents, new ones are bucketed as "(other)".
func (s *Stats) RecordUserAgent(userAgent string) {
if userAgent == "" {
userAgent = "(empty)"
}
counter, _ := s.userAgentUsage.LoadOrStore(userAgent, &atomic.Int64{})
counter.(*atomic.Int64).Add(1)

// Fast path: UA already tracked
if counter, ok := s.userAgentUsage.Load(userAgent); ok {
counter.(*atomic.Int64).Add(1)
return
}

// Slow path: new UA — acquire lock for cap-safe insertion
s.uaMu.Lock()

// Re-check after lock (another goroutine may have added this UA)
if counter, ok := s.userAgentUsage.Load(userAgent); ok {
s.uaMu.Unlock()
counter.(*atomic.Int64).Add(1)
return
}

// Check cap under lock — no TOCTOU possible
if s.uniqueUACount.Load() >= maxUniqueUserAgents {
s.uaMu.Unlock()
counter, _ := s.userAgentUsage.LoadOrStore("(other)", &atomic.Int64{})
counter.(*atomic.Int64).Add(1)
return
}

// Store new UA and increment count atomically (under lock)
counter := &atomic.Int64{}
s.userAgentUsage.Store(userAgent, counter)
s.uniqueUACount.Add(1)
s.uaMu.Unlock()
counter.Add(1)
}

// UserAgentSnapshot returns a map of user agents to request counts
Expand Down
Loading
Loading