diff --git a/README.md b/README.md index 0c57d23..abcdadc 100644 --- a/README.md +++ b/README.md @@ -486,6 +486,8 @@ The constructor performs validation on all parameters and returns descriptive er - `CheckTokens(id []byte, n uint8) bool`: Checks if n tokens would be available without consuming them - `TakeToken(id []byte) bool`: Attempts to take a single token, returns true if successful - `TakeTokens(id []byte, n uint8) bool`: Attempts to take n tokens atomically, returns true if all n tokens were taken +- `SetRefillRate(refillRate float64) error`: Updates the refill rate in-place while preserving existing bucket state +- `RefillRate() float64`: Returns the current refill rate - `RotationInterval() time.Duration`: Returns the automatically calculated rotation interval #### Collision-Resistant Algorithm Explained diff --git a/bucket.go b/bucket.go index c5c958a..97b5ac0 100644 --- a/bucket.go +++ b/bucket.go @@ -65,14 +65,8 @@ func NewTokenBucketLimiter( refillRate float64, refillRateUnit time.Duration, ) (*TokenBucketLimiter, error) { - if math.IsNaN(refillRate) || math.IsInf(refillRate, 0) || refillRate <= 0 { - return nil, fmt.Errorf("refillRate must be a positive, finite number") - } - - if rate := float64(refillRateUnit.Nanoseconds()); rate <= 0 { - return nil, fmt.Errorf("refillRateUnit must represent a positive duration") - } else if rate > math.MaxFloat64/refillRate { - return nil, fmt.Errorf("refillRate per duration is too large") + if err := validateRefillRate(refillRate, refillRateUnit); err != nil { + return nil, err } n := ceilPow2(uint64(numBuckets)) @@ -92,6 +86,20 @@ func NewTokenBucketLimiter( }, nil } +func validateRefillRate(refillRate float64, refillRateUnit time.Duration) error { + if math.IsNaN(refillRate) || math.IsInf(refillRate, 0) || refillRate <= 0 { + return fmt.Errorf("refillRate must be a positive, finite number") + } + + if rate := float64(refillRateUnit.Nanoseconds()); rate <= 0 { + return fmt.Errorf("refillRateUnit must represent a positive duration") + } else if rate > math.MaxFloat64/refillRate { + return fmt.Errorf("refillRate per duration is too large") + } + + return nil +} + // CheckToken returns whether a token would be available for the given // ID without actually taking it. This is useful for preemptively // checking if an operation would be rate limited before attempting diff --git a/rotating.go b/rotating.go index d58ded1..124df2f 100644 --- a/rotating.go +++ b/rotating.go @@ -20,6 +20,12 @@ type rotatingPair struct { rotated time56.Time } +type refillState struct { + refillRate float64 + nanosPerToken int64 + nanosPerRotation int64 +} + // RotatingTokenBucketLimiter implements a collision-resistant token // bucket rate limiter. It maintains two TokenBucketLimiters with // different hash seeds and rotates between them periodically. This @@ -40,8 +46,10 @@ type rotatingPair struct { // last for the duration of the rotation period, providing better // fairness and accuracy compared to a single TokenBucketLimiter. type RotatingTokenBucketLimiter struct { - pair atomic.Pointer[rotatingPair] // Current limiter pair - nanosPerRotation int64 // Rotation interval in nanoseconds + pair atomic.Pointer[rotatingPair] // Current limiter pair + burstCapacity uint8 + refillRateUnit time.Duration + state atomic.Pointer[refillState] } // Compile-time assertion that RotatingTokenBucketLimiter implements Limiter @@ -97,18 +105,18 @@ func NewRotatingTokenBucketLimiter( refillRate float64, refillRateUnit time.Duration, ) (*RotatingTokenBucketLimiter, error) { - checked, err := NewTokenBucketLimiter( + if err := validateRefillRate(refillRate, refillRateUnit); err != nil { + return nil, err + } + + // Validation passed above, and NewTokenBucketLimiter currently has no + // additional error cases beyond parameter validation. + checked, _ := NewTokenBucketLimiter( numBuckets, burstCapacity, refillRate, refillRateUnit, ) - if err != nil { - return nil, err - } - - // validation passed for exact params above, continue w/o checking - // error for 100% coverage. ignored, _ := NewTokenBucketLimiter( numBuckets, @@ -121,13 +129,11 @@ func NewRotatingTokenBucketLimiter( // convergence of all token buckets to steady state before rotation // occurs. This guarantees correctness by eliminating state // inconsistency issues when hash mappings change during rotation. - refillTime := time.Duration(float64(burstCapacity) / refillRate * float64(refillRateUnit)) - safetyFactor := 5.0 - rotationRate := time.Duration(float64(refillTime) * safetyFactor) - limiter := &RotatingTokenBucketLimiter{ - nanosPerRotation: rotationRate.Nanoseconds(), + burstCapacity: burstCapacity, + refillRateUnit: refillRateUnit, } + limiter.setRefillRateState(refillRate) pair := &rotatingPair{ checked: checked, @@ -157,13 +163,13 @@ func NewRotatingTokenBucketLimiter( // This approach ensures that hash collisions are resolved // periodically without affecting the thread-safety or performance of // the limiter. -func (r *RotatingTokenBucketLimiter) load(nowNS int64) *rotatingPair { +func (r *RotatingTokenBucketLimiter) load(nowNS int64, state *refillState) *rotatingPair { now := time56.Unix(nowNS) for { pair := r.pair.Load() - if now.Since(pair.rotated) < r.nanosPerRotation { + if now.Since(pair.rotated) < state.nanosPerRotation { return pair } @@ -220,9 +226,11 @@ func (r *RotatingTokenBucketLimiter) CheckToken(id []byte) bool { // multiple goroutines. func (r *RotatingTokenBucketLimiter) CheckTokens(id []byte, n uint8) bool { now := nowfn() - pair := r.load(now) - pair.ignored.checkTokensWithNow(id, n, now) - return pair.checked.checkTokensWithNow(id, n, now) + state := r.loadRefillState() + pair := r.load(now, state) + rate := state.nanosPerToken + pair.ignored.checkInner(pair.ignored.index(id), rate, now, n) + return pair.checked.checkInner(pair.checked.index(id), rate, now, n) } // TakeToken attempts to take a token for the given ID. It returns @@ -273,9 +281,29 @@ func (r *RotatingTokenBucketLimiter) TakeToken(id []byte) bool { // multiple goroutines. func (r *RotatingTokenBucketLimiter) TakeTokens(id []byte, n uint8) bool { now := nowfn() - pair := r.load(now) - pair.ignored.takeTokensWithNow(id, n, now) - return pair.checked.takeTokensWithNow(id, n, now) + state := r.loadRefillState() + pair := r.load(now, state) + rate := state.nanosPerToken + pair.ignored.takeTokenInner(pair.ignored.index(id), rate, now, n) + return pair.checked.takeTokenInner(pair.checked.index(id), rate, now, n) +} + +// SetRefillRate updates the refill rate used by the rotating limiter +// without rebuilding bucket state. Existing tokens are preserved, and +// subsequent checks, takes, and rotation timing use the new rate. +func (r *RotatingTokenBucketLimiter) SetRefillRate(refillRate float64) error { + if err := validateRefillRate(refillRate, r.refillRateUnit); err != nil { + return err + } + + r.setRefillRateState(refillRate) + return nil +} + +// RefillRate returns the current refill rate in tokens per +// refillRateUnit. +func (r *RotatingTokenBucketLimiter) RefillRate() float64 { + return r.loadRefillState().refillRate } // RotationInterval returns the automatically calculated rotation @@ -290,5 +318,23 @@ func (r *RotatingTokenBucketLimiter) TakeTokens(id []byte, n uint8) bool { // This method is thread-safe and can be called concurrently from // multiple goroutines. func (r *RotatingTokenBucketLimiter) RotationInterval() time.Duration { - return time.Duration(r.nanosPerRotation) + return time.Duration(r.loadRefillState().nanosPerRotation) +} + +func (r *RotatingTokenBucketLimiter) setRefillRateState(refillRate float64) { + r.state.Store(&refillState{ + refillRate: refillRate, + nanosPerToken: nanoRate(r.refillRateUnit, refillRate), + nanosPerRotation: calculateRotationInterval(r.burstCapacity, refillRate, r.refillRateUnit).Nanoseconds(), + }) +} + +func (r *RotatingTokenBucketLimiter) loadRefillState() *refillState { + return r.state.Load() +} + +func calculateRotationInterval(burstCapacity uint8, refillRate float64, refillRateUnit time.Duration) time.Duration { + refillTime := time.Duration(float64(burstCapacity) / refillRate * float64(refillRateUnit)) + safetyFactor := 5.0 + return time.Duration(float64(refillTime) * safetyFactor) } diff --git a/rotating_test.go b/rotating_test.go index db40aa0..86af4ba 100644 --- a/rotating_test.go +++ b/rotating_test.go @@ -94,6 +94,21 @@ func TestRotatingTokenBucketLimiterCreation(t *testing.T) { } } +func TestRotatingTokenBucketLimiterSetRefillRateValidation(t *testing.T) { + limiter, err := DefaultRotatingLimiter() + if err != nil { + t.Fatalf("Failed to create limiter: %v", err) + } + + if err := limiter.SetRefillRate(0); err == nil { + t.Fatal("Expected error for zero refill rate") + } + + if got := limiter.RefillRate(); got != rotatingRatePerSecond { + t.Fatalf("Expected refill rate to remain %v after failed update, got %v", rotatingRatePerSecond, got) + } +} + // TestRotatingTokenBucketLimiterBasicFunctionality tests basic token operations func TestRotatingTokenBucketLimiterBasicFunctionality(t *testing.T) { limiter, err := DefaultRotatingLimiter() @@ -128,6 +143,69 @@ func TestRotatingTokenBucketLimiterBasicFunctionality(t *testing.T) { } } +func TestRotatingTokenBucketLimiterSetRefillRatePreservesState(t *testing.T) { + limiter, err := NewRotatingTokenBucketLimiter( + rotatingNumBuckets, + 1, + 10.0, + time.Second, + ) + if err != nil { + t.Fatalf("Failed to create limiter: %v", err) + } + + id := []byte("adjustable-rate-state") + if !limiter.TakeToken(id) { + t.Fatal("Expected first token to succeed") + } + if limiter.TakeToken(id) { + t.Fatal("Expected bucket to be empty after first token") + } + + if err := limiter.SetRefillRate(1.0); err != nil { + t.Fatalf("SetRefillRate failed: %v", err) + } + if limiter.TakeToken(id) { + t.Fatal("Expected state to be preserved after refill rate change") + } + + tick(100 * time.Millisecond) + if limiter.TakeToken(id) { + t.Fatal("Expected slower refill rate to delay token availability") + } + + tick(900 * time.Millisecond) + if !limiter.TakeToken(id) { + t.Fatal("Expected token after one second at the new refill rate") + } +} + +func TestRotatingTokenBucketLimiterSetRefillRateSpeedsRefill(t *testing.T) { + limiter, err := NewRotatingTokenBucketLimiter( + rotatingNumBuckets, + 1, + 1.0, + time.Second, + ) + if err != nil { + t.Fatalf("Failed to create limiter: %v", err) + } + + id := []byte("adjustable-rate-fast") + if !limiter.TakeToken(id) { + t.Fatal("Expected first token to succeed") + } + + if err := limiter.SetRefillRate(10.0); err != nil { + t.Fatalf("SetRefillRate failed: %v", err) + } + + tick(100 * time.Millisecond) + if !limiter.TakeToken(id) { + t.Fatal("Expected faster refill rate to make token available sooner") + } +} + // TestRotatingTokenBucketLimiterImplementsInterface verifies interface compliance func TestRotatingTokenBucketLimiterImplementsInterface(t *testing.T) { limiter, err := DefaultRotatingLimiter() @@ -152,13 +230,13 @@ func TestRotatingTokenBucketLimiterRotation(t *testing.T) { } // Load initial pair - pair1 := limiter.load(nowfn()) + pair1 := limiter.load(nowfn(), limiter.loadRefillState()) if pair1 == nil { t.Fatal("Expected non-nil pair") } // Before rotation period, should get same pair - pair2 := limiter.load(nowfn()) + pair2 := limiter.load(nowfn(), limiter.loadRefillState()) if pair1 != pair2 { t.Error("Should get same pair before rotation period") } @@ -166,7 +244,7 @@ func TestRotatingTokenBucketLimiterRotation(t *testing.T) { // After rotation period, should get new pair // With burstCapacity=10, refillRate=1.0/sec: rotation = 10/1.0 * 5 = 50 seconds tick(50*time.Second + 1*time.Millisecond) - pair3 := limiter.load(nowfn()) + pair3 := limiter.load(nowfn(), limiter.loadRefillState()) if pair1 == pair3 { t.Error("Should get different pair after rotation period") } @@ -199,7 +277,7 @@ func TestRotatingTokenBucketLimiterCollisionAvoidance(t *testing.T) { id1 := []byte("collision-test-1") id2 := []byte("collision-test-2") - pair := limiter.load(nowfn()) + pair := limiter.load(nowfn(), limiter.loadRefillState()) index1 := pair.checked.index(id1) index2 := pair.checked.index(id2) @@ -343,13 +421,13 @@ func TestRotatingTokenBucketLimiterLoadLogic(t *testing.T) { now := nowfn() // First load should return initial pair - pair1 := limiter.load(now) + pair1 := limiter.load(now, limiter.loadRefillState()) if pair1 == nil { t.Fatal("Expected non-nil pair") } // Load with same timestamp should return same pair - pair2 := limiter.load(now) + pair2 := limiter.load(now, limiter.loadRefillState()) if pair1 != pair2 { t.Error("Same timestamp should return same pair") } @@ -358,14 +436,14 @@ func TestRotatingTokenBucketLimiterLoadLogic(t *testing.T) { // With burstCapacity=10, refillRate=1.0/sec: rotation = 10/1.0 * 5 = 50 seconds rotationInterval := 50 * time.Second beforeRotation := now + rotationInterval.Nanoseconds() - 1 - pair3 := limiter.load(beforeRotation) + pair3 := limiter.load(beforeRotation, limiter.loadRefillState()) if pair1 != pair3 { t.Error("Before rotation timestamp should return same pair") } // Load with timestamp at rotation should trigger rotation atRotation := now + rotationInterval.Nanoseconds() - pair4 := limiter.load(atRotation) + pair4 := limiter.load(atRotation, limiter.loadRefillState()) if pair1 == pair4 { t.Error("At rotation timestamp should return new pair") } @@ -471,6 +549,71 @@ func TestRotatingTokenBucketLimiterRotationInterval(t *testing.T) { } } +func TestRotatingTokenBucketLimiterSetRefillRateUpdatesRotationInterval(t *testing.T) { + limiter, err := NewRotatingTokenBucketLimiter( + rotatingNumBuckets, + rotatingBurstCapacity, + 1.0, + time.Second, + ) + if err != nil { + t.Fatalf("Failed to create limiter: %v", err) + } + + pair1 := limiter.load(nowfn(), limiter.loadRefillState()) + if got := limiter.RotationInterval(); got != 50*time.Second { + t.Fatalf("Expected initial rotation interval of 50s, got %v", got) + } + + if err := limiter.SetRefillRate(100.0); err != nil { + t.Fatalf("SetRefillRate failed: %v", err) + } + + if got := limiter.RefillRate(); got != 100.0 { + t.Fatalf("Expected refill rate of 100, got %v", got) + } + if got := limiter.RotationInterval(); got != 500*time.Millisecond { + t.Fatalf("Expected updated rotation interval of 500ms, got %v", got) + } + + tick(500*time.Millisecond + time.Millisecond) + pair2 := limiter.load(nowfn(), limiter.loadRefillState()) + if pair1 == pair2 { + t.Fatal("Expected limiter to rotate using the updated interval") + } +} + +func TestRotatingTokenBucketLimiterLoadUsesStateSnapshot(t *testing.T) { + limiter, err := NewRotatingTokenBucketLimiter( + rotatingNumBuckets, + rotatingBurstCapacity, + 1.0, + time.Second, + ) + if err != nil { + t.Fatalf("Failed to create limiter: %v", err) + } + + now := nowfn() + initialState := limiter.loadRefillState() + initialPair := limiter.load(now, initialState) + + if err := limiter.SetRefillRate(100.0); err != nil { + t.Fatalf("SetRefillRate failed: %v", err) + } + + updatedState := limiter.loadRefillState() + atUpdatedRotation := now + updatedState.nanosPerRotation + + if pair := limiter.load(atUpdatedRotation, initialState); pair != initialPair { + t.Fatal("Expected old state snapshot to keep using the old rotation interval") + } + + if pair := limiter.load(atUpdatedRotation, updatedState); pair == initialPair { + t.Fatal("Expected updated state snapshot to rotate using the new interval") + } +} + // TestRotatingTokenBucketLimiterDifferentIDs tests behavior with different IDs func TestRotatingTokenBucketLimiterDifferentIDs(t *testing.T) { limiter, err := DefaultRotatingLimiter() @@ -482,7 +625,7 @@ func TestRotatingTokenBucketLimiterDifferentIDs(t *testing.T) { id2 := []byte("different-id-2") // Check if these IDs hash to the same bucket (hash collision) - pair := limiter.load(nowfn()) + pair := limiter.load(nowfn(), limiter.loadRefillState()) index1 := pair.checked.index(id1) index2 := pair.checked.index(id2)