diff --git a/bucket.go b/bucket.go index ecde1a3..5be20f2 100644 --- a/bucket.go +++ b/bucket.go @@ -103,3 +103,10 @@ func (b *bucket) nextTokensTime(executionTime ntime.Time, limit Limit, n int64) func (b *bucket) retryAfter(executionTime ntime.Time, limit Limit, n int64) time.Duration { return max(0, b.nextTokensTime(executionTime, limit, n).Sub(executionTime)) } + +// isFull checks if the bucket is full, i.e. has all the tokens it can have. +// +// ⚠️ caller is responsible for locking appropriately +func (b *bucket) isFull(executionTime ntime.Time, limit Limit) bool { + return b.time.BeforeOrEqual(executionTime.Add(-limit.period)) +} diff --git a/bucket_test.go b/bucket_test.go index dd7f8d9..b398ec3 100644 --- a/bucket_test.go +++ b/bucket_test.go @@ -576,3 +576,26 @@ func TestBucket_RetryAfter(t *testing.T) { require.Equal(t, expected, retryAfter, "retryAfter should be durationPerToken when requesting 2 tokens after exactly one durationPerToken has passed") }) } + +func TestBucket_IsFull(t *testing.T) { + t.Parallel() + + now := ntime.Now() + limit := NewLimit(11, time.Second) + bucket := newBucket(now, limit) + + // Newly created buckets are full + require.True(t, bucket.isFull(now, limit), "bucket should be full initially") + + bucket.consumeTokens(now, limit, 5) + require.Equal(t, int64(6), bucket.remainingTokens(now, limit), "bucket should have 6 tokens remaining after consuming 5 tokens") + require.False(t, bucket.isFull(now, limit), "bucket should not be full after consuming 5 tokens") + + now = now.Add(limit.durationPerToken * 4) + require.Equal(t, int64(10), bucket.remainingTokens(now, limit), "bucket should have 10 tokens remaining after consuming 5 tokens and waiting for 4") + require.False(t, bucket.isFull(now, limit), "bucket should not be full after consuming 5 tokens and waiting for 4") + + now = now.Add(limit.durationPerToken) + require.Equal(t, int64(11), bucket.remainingTokens(now, limit), "bucket should have 11 tokens remaining") + require.True(t, bucket.isFull(now, limit), "bucket should be full after consuming 1 token and waiting for 4 durationPerTokens") +} diff --git a/limiter.go b/limiter.go index 27eb522..e835a31 100644 --- a/limiter.go +++ b/limiter.go @@ -1,5 +1,7 @@ package rate +import "github.com/clipperhouse/ntime" + // Limiter is a rate limiter that can be used to limit the rate of requests to a given key. type Limiter[TInput any, TKey comparable] struct { keyFunc KeyFunc[TInput, TKey] @@ -41,3 +43,29 @@ func (r *Limiter[TInput, TKey]) getLimits(input TInput) []Limit { } return limits } + +// GC deletes buckets that are full, i.e, buckets for which enough +// time has passed that they are no longer relevant. A full bucket +// and a non-existent bucket have the same semantics. +// +// Without GC, buckets (memory) will grow unbounded. +// +// This can be a moderately expensive operation, depending +// on the number of buckets. If you want a cheaper operation, +// see [Clear]. +func (r *Limiter[TInput, TKey]) GC() (deleted int64) { + return r.buckets.gc(ntime.Now) +} + +// Clear deletes all buckets. This is semantically +// equivalent to refilling all buckets. +// +// You would use this method for garbage collection +// purposes, as the limiter's memory will grow unbounded +// otherwise. +// +// See also the [GC] method, which is more selective, and +// only deletes buckets that are no longer meaningful. +func (r *Limiter[TInput, TKey]) Clear() { + r.buckets.m.Clear() +} diff --git a/limiter_allow_test.go b/limiter_allow_test.go index d2b4242..57cbeec 100644 --- a/limiter_allow_test.go +++ b/limiter_allow_test.go @@ -151,7 +151,7 @@ func TestLimiter_Allow(t *testing.T) { } expected := buckets * len(limiter.limits) - require.Equal(t, limiter.buckets.count(), expected, "buckets should have persisted after allow") + require.Equal(t, int64(expected), limiter.buckets.count(), "buckets should have persisted after allow") }) }) diff --git a/limiter_peek_test.go b/limiter_peek_test.go index 884b11f..eb49e57 100644 --- a/limiter_peek_test.go +++ b/limiter_peek_test.go @@ -29,7 +29,7 @@ func TestLimiter_Peek_NeverPersists(t *testing.T) { } // no buckets should have been stored - require.Equal(t, limiter.buckets.count(), 0, "buckets should not persist after peeking") + require.Equal(t, int64(0), limiter.buckets.count(), "buckets should not persist after peeking") } func TestLimiter_Peek_SingleBucket(t *testing.T) { diff --git a/limiters.go b/limiters.go index 74a7fd4..18f90bd 100644 --- a/limiters.go +++ b/limiters.go @@ -12,3 +12,34 @@ func Combine[TInput any, TKey comparable](limiters ...*Limiter[TInput, TKey]) *L limiters: limiters, } } + +// GC deletes buckets that are full, i.e, buckets for which enough +// time has passed that they are no longer relevant. A full bucket +// and a non-existent bucket have the same semantics. +// +// Without GC, buckets (memory) will grow unbounded. +// +// This can be a moderately expensive operation, depending +// on the number of buckets. If you want a cheaper operation, +// see [Clear]. +func (r *Limiters[TInput, TKey]) GC() (deleted int64) { + for _, limiter := range r.limiters { + deleted += limiter.GC() + } + return deleted +} + +// Clear deletes all buckets. This is semantically +// equivalent to refilling all buckets. +// +// You would use this method for garbage collection +// purposes, as the limiter's memory will grow unbounded +// otherwise. +// +// See also the [GC] method, which is more selective, and +// only deletes buckets that are no longer meaningful. +func (r *Limiters[TInput, TKey]) Clear() { + for _, limiter := range r.limiters { + limiter.Clear() + } +} diff --git a/syncmap.go b/syncmap.go index 9f9c240..99100f3 100644 --- a/syncmap.go +++ b/syncmap.go @@ -8,7 +8,8 @@ import ( // bucketMap is a specialized sync.Map for storing buckets type bucketMap[TKey comparable] struct { - m sync.Map + m sync.Map + mu sync.RWMutex } // bucketSpec is a key for the bucket map, which includes the limit and the user key. @@ -47,11 +48,34 @@ func (bm *bucketMap[TKey]) load(userKey TKey, limit Limit) (*bucket, bool) { return nil, false } -func (bm *bucketMap[TKey]) count() int { - count := 0 +func (bm *bucketMap[TKey]) count() int64 { + count := int64(0) bm.m.Range(func(_, _ any) bool { count++ return true }) return count } + +// gc deletes buckets that are full, which has the +// same semantics as the bucket not existing +func (bm *bucketMap[TKey]) gc(timeFunc func() ntime.Time) (deleted int64) { + bm.mu.Lock() + defer bm.mu.Unlock() + + bm.m.Range(func(key, value any) bool { + // I wonder if there is a possibility of a race here, not sure. + // Thinking about the timing between getting the bucket from the map + // and locking the bucket. Maybe it's not a problem. + spec := key.(bucketSpec[TKey]) + b := value.(*bucket) + b.mu.Lock() + if b.isFull(timeFunc(), spec.limit) { + bm.m.Delete(key) + deleted++ + } + b.mu.Unlock() + return true + }) + return deleted +} diff --git a/syncmap_test.go b/syncmap_test.go index 5441a44..a2ba955 100644 --- a/syncmap_test.go +++ b/syncmap_test.go @@ -3,6 +3,7 @@ package rate import ( "fmt" "sync" + "sync/atomic" "testing" "time" @@ -108,7 +109,7 @@ func TestBucketMap_Count(t *testing.T) { executionTime := ntime.Now() // Empty map should have count 0 - require.Equal(t, 0, bm.count(), "empty map should have count 0") + require.Equal(t, int64(0), bm.count(), "empty map should have count 0") // Add buckets with different keys and limits for i := range 50 { @@ -116,7 +117,7 @@ func TestBucketMap_Count(t *testing.T) { bm.loadOrStore(key, executionTime, limit1) } - require.Equal(t, 50, bm.count(), "should have 50 buckets after adding 50 different keys") + require.Equal(t, int64(50), bm.count(), "should have 50 buckets after adding 50 different keys") // Add buckets with same keys but different limits for i := range 50 { @@ -124,7 +125,7 @@ func TestBucketMap_Count(t *testing.T) { bm.loadOrStore(key, executionTime, limit2) } - require.Equal(t, 100, bm.count(), "should have 100 buckets after adding same keys with different limits") + require.Equal(t, int64(100), bm.count(), "should have 100 buckets after adding same keys with different limits") // Adding duplicate key-limit combinations should not increase count for i := range 25 { @@ -132,7 +133,7 @@ func TestBucketMap_Count(t *testing.T) { bm.loadOrStore(key, executionTime, limit1) } - require.Equal(t, 100, bm.count(), "should still have 100 buckets after duplicate additions") + require.Equal(t, int64(100), bm.count(), "should still have 100 buckets after duplicate additions") } func TestBucketMap_ConcurrentAccess(t *testing.T) { @@ -190,5 +191,108 @@ func TestBucketMap_DifferentKeyTypes(t *testing.T) { bucket3 := bm.loadOrStore(43, executionTime, limit) require.False(t, bucket1 == bucket3, "different int keys should have different buckets") - require.Equal(t, 2, bm.count(), "should have 2 buckets for 2 different int keys") + require.Equal(t, int64(2), bm.count(), "should have 2 buckets for 2 different int keys") +} + +func TestBucketMap_GC(t *testing.T) { + t.Parallel() + + t.Run("gc deletes full buckets", func(t *testing.T) { + var bm bucketMap[string] + const count int64 = 1000 + + limit := NewLimit(10, time.Second) + executionTime := ntime.Now() + + // Create a bunch of older buckets + for i := range count { + key := fmt.Sprintf("key%d", i) + b := bm.loadOrStore(key, executionTime, limit) + // 1/4 of the buckets will not be full + if i%4 == 0 { + b.consumeTokens(executionTime, limit, 1) + } + } + + require.Equal(t, count, bm.count(), "should have 1000 buckets after creating 1000 older buckets") + + // GC should delete 750 buckets full buckets, and not the 250 that are not full + { + deleted := bm.gc(func() ntime.Time { + return executionTime + }) + require.Equal(t, int64(750), deleted, "should have 750 deleted buckets, since 1/4 of the buckets are full") + require.Equal(t, int64(250), bm.count(), "should have 250 buckets remaining after GC") + } + + // Time passes + executionTime = executionTime.Add(time.Second) + + // All remaining buckets are full now + { + deleted := bm.gc(func() ntime.Time { + return executionTime + }) + require.Equal(t, int64(250), deleted, "should have deleted the remaining 250 buckets") + require.Equal(t, int64(0), bm.count(), "should have 0 buckets remaining after all deletions") + } + }) + + t.Run("gc concurrent with reads and writes", func(t *testing.T) { + var bm bucketMap[string] + const count int64 = 1000 + + limit := NewLimit(10, time.Second) + executionTime := ntime.Now() + + // Create a bunch of buckets concurrently + + // Trying to get the timing right for a good test, + // since a slow system like GitHub Actions seems + // to take a while to launch goroutines. + signal := make(chan struct{}) + launched := int64(0) + + var wg sync.WaitGroup + for i := range count { + wg.Add(1) + go func(i int64) { + defer wg.Done() + + time.Sleep(time.Millisecond) + key := fmt.Sprintf("key%d", i) + b := bm.loadOrStore(key, executionTime, limit) + b.mu.Lock() + // 1/5 of the buckets will not be full + if i%5 == 0 { + b.consumeTokens(executionTime, limit, 1) + } + b.mu.Unlock() + + // Signal when we reach 100 launched goroutines + if atomic.AddInt64(&launched, 1) == 100 { + close(signal) + } + }(i) + } + + // Wait for 100 goroutines to launch before running GC, + // try to induce some concurrency. + + <-signal + bm.gc(func() ntime.Time { + return executionTime + }) + wg.Wait() + + // Expect that some, but not all, deletions have happened, + // since there was concurrent creation of buckets. + require.Less(t, bm.count(), count, "some deletions should have happened") + + // Now delete the remaining buckets, without concurrency + bm.gc(func() ntime.Time { + return executionTime + }) + require.Equal(t, int64(200), bm.count(), "should have 200 buckets after deletion and GC") + }) }