diff --git a/counter.go b/counter.go index 79a1bc8..6edda54 100644 --- a/counter.go +++ b/counter.go @@ -1,45 +1,72 @@ /* Package counter implements an advanced, fast and thread-safe counter. -It collects statstics, like current rate, min / max rate, etc. +It optionally collects statistics, like current rate, min / max rate, etc. */ package counter import ( "sync" + "sync/atomic" "time" ) // Counter is a fast, thread-safe counter. -// It collects statstics, like current rate, min / max rate, etc. +// It collects statistics, like current rate, min / max rate, etc. // The Counter can go up to `18446744073709551615` (2^64 - 1), as it uses uint64 internally. +// +// Basic usage: +// +// c := counter.NewCounter().Start() +// c.Increment() +// fmt.Println(c.Count()) // prints 1 +// c.Stop() +// rate := c.CalculateAverageRate(time.Second) // events per second type Counter struct { - mutex sync.Mutex - count uint64 - started bool - startedAt time.Time - stoppedAt time.Time - triggers []time.Time + // count is the current count, accessed atomically + count uint64 + + // mutex protects all fields except count + mutex sync.RWMutex + + started bool + startedAt time.Time + stoppedAt time.Time + + // Advanced statistics fields enableStats bool + triggers []time.Time + minDiff time.Duration // tracks minimum time between increments + maxDiff time.Duration // tracks maximum time between increments + lastTrigger time.Time // last time Increment was called } // NewCounter returns a new Counter. +// +// The counter starts in a stopped state. Call Start() to begin counting. func NewCounter() *Counter { return &Counter{ startedAt: time.Time{}, stoppedAt: time.Time{}, + minDiff: -1, // sentinel value indicating not set + maxDiff: 0, } } // WithAdvancedStats enables the calculation of advanced statistics like CalculateMinimumRate and CalculateMaximumRate. // CalculateAverageRate and CalculateCurrentRate are always enabled. +// +// Note: Enabling advanced stats will increase memory usage proportional to the number of increments. func (c *Counter) WithAdvancedStats() *Counter { cNew := NewCounter() cNew.enableStats = true + return cNew } // Start starts the counter. // It returns the counter itself, so you can chain it. +// +// If the counter is already started, this is a no-op. func (c *Counter) Start() *Counter { c.mutex.Lock() defer c.mutex.Unlock() @@ -51,10 +78,17 @@ func (c *Counter) Start() *Counter { c.started = true c.startedAt = time.Now() + if c.enableStats { + c.lastTrigger = c.startedAt + } + return c } // Stop stops the counter. +// +// This freezes the counter for rate calculations but does not reset the count. +// If the counter is already stopped, this is a no-op. func (c *Counter) Stop() { c.mutex.Lock() defer c.mutex.Unlock() @@ -68,43 +102,79 @@ func (c *Counter) Stop() { } // Increment increments the counter by 1. +// +// This method is thread-safe and can be called concurrently from multiple goroutines. func (c *Counter) Increment() { - c.mutex.Lock() - defer c.mutex.Unlock() + // Atomically increment the counter without locking + atomic.AddUint64(&c.count, 1) - c.count++ + // Only lock if advanced stats are enabled if c.enableStats { + c.mutex.Lock() + defer c.mutex.Unlock() + + if !c.started { + return + } + now := time.Now() c.triggers = append(c.triggers, now) + + // Update min/max time difference + if !c.lastTrigger.IsZero() { + diff := now.Sub(c.lastTrigger) + + // Update min diff (initialize if this is the first valid diff) + if c.minDiff == -1 || diff < c.minDiff { + c.minDiff = diff + } + + // Update max diff + if diff > c.maxDiff { + c.maxDiff = diff + } + } + + c.lastTrigger = now } } // Count returns the current count. +// +// This method is thread-safe and can be called concurrently from multiple goroutines. func (c *Counter) Count() uint64 { - c.mutex.Lock() - defer c.mutex.Unlock() - - return c.count + return atomic.LoadUint64(&c.count) } // Reset stops and resets the counter. +// +// This resets the count to 0 and clears all statistics. func (c *Counter) Reset() { c.mutex.Lock() defer c.mutex.Unlock() - c.count = 0 + atomic.StoreUint64(&c.count, 0) c.startedAt = time.Time{} c.stoppedAt = time.Now() c.started = false + c.triggers = nil + c.minDiff = -1 + c.maxDiff = 0 + c.lastTrigger = time.Time{} } // CalculateAverageRate calculates the average rate of the counter. // It returns the rate in `count / interval`. +// +// For example, to get events per second: +// +// rate := counter.CalculateAverageRate(time.Second) func (c *Counter) CalculateAverageRate(interval time.Duration) float64 { - c.mutex.Lock() - defer c.mutex.Unlock() + c.mutex.RLock() + defer c.mutex.RUnlock() - if c.count == 0 { + count := atomic.LoadUint64(&c.count) + if count == 0 { return 0 } @@ -113,59 +183,52 @@ func (c *Counter) CalculateAverageRate(interval time.Duration) float64 { untilTime = time.Now() } - return float64(c.count) / float64(untilTime.Sub(c.startedAt)) * float64(interval) + elapsed := untilTime.Sub(c.startedAt) + if elapsed <= 0 { + return 0 + } + + return float64(count) / float64(elapsed) * float64(interval) } // CalculateMaximumRate calculates the maximum rate of the counter. // It returns the rate in `count / interval`. -// It returns 0 if the counter has not been started yet. +// It returns 0 if the counter has not been started yet or has no increments. // Needs to be enabled via WithAdvancedStats. +// +// The maximum rate represents the fastest pace at which events occurred. func (c *Counter) CalculateMaximumRate(interval time.Duration) float64 { - c.mutex.Lock() - defer c.mutex.Unlock() + c.mutex.RLock() + defer c.mutex.RUnlock() if !c.enableStats { return 0 } - if len(c.triggers) == 0 { + if len(c.triggers) <= 1 || c.minDiff <= 0 { return 0 } - min := time.Duration(-1) - for i := 1; i < len(c.triggers); i++ { - diff := c.triggers[i].Sub(c.triggers[i-1]) - if diff < min || min == -1 { - min = diff - } - } - - return float64(interval) / float64(min) + return float64(interval) / float64(c.minDiff) } // CalculateMinimumRate calculates the minimum rate of the counter. // It returns the rate in `count / interval`. -// It returns 0 if the counter has not been started yet. +// It returns 0 if the counter has not been started yet or has no increments. // Needs to be enabled via WithAdvancedStats. +// +// The minimum rate represents the slowest pace at which events occurred. func (c *Counter) CalculateMinimumRate(interval time.Duration) float64 { - c.mutex.Lock() - defer c.mutex.Unlock() + c.mutex.RLock() + defer c.mutex.RUnlock() if !c.enableStats { return 0 } - if len(c.triggers) == 0 { + if len(c.triggers) <= 1 || c.maxDiff <= 0 { return 0 } - max := time.Duration(0) - for i := 1; i < len(c.triggers); i++ { - diff := c.triggers[i].Sub(c.triggers[i-1]) - if diff > max { - max = diff - } - } - - return float64(interval) / float64(max) + return float64(interval) / float64(c.maxDiff) } diff --git a/counter_benchmark_test.go b/counter_benchmark_test.go new file mode 100644 index 0000000..0f377a1 --- /dev/null +++ b/counter_benchmark_test.go @@ -0,0 +1,57 @@ +package counter + +import ( + "sync" + "testing" +) + +// basicCounter is a basic implementation of a counter. +// It's used to compare the performance to our version. +type basicCounter struct { + mutex sync.Mutex + count uint64 +} + +func (c *basicCounter) Increment() { + c.mutex.Lock() + defer c.mutex.Unlock() + + c.count++ +} + +func (c *basicCounter) Count() uint64 { + c.mutex.Lock() + defer c.mutex.Unlock() + + return c.count +} + +func BenchmarkBasicCounterImplementation(b *testing.B) { + counter := basicCounter{} + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + counter.Increment() + } +} + +func BenchmarkIncrement(b *testing.B) { + counter := NewCounter().Start() + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + counter.Increment() + } +} + +func BenchmarkIncrementWithAdvancedStats(b *testing.B) { + counter := NewCounter().WithAdvancedStats().Start() + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + counter.Increment() + } +} diff --git a/counter_test.go b/counter_test.go index 454e3c7..496567e 100644 --- a/counter_test.go +++ b/counter_test.go @@ -2,7 +2,9 @@ package counter import ( "sync" + "sync/atomic" "testing" + "time" "github.com/MarvinJWendt/testza" ) @@ -18,6 +20,7 @@ func TestCounter(t *testing.T) { for i := 0; i < 10; i++ { c.Increment() } + testza.AssertEqual(t, uint64(10), c.Count()) }) @@ -25,6 +28,7 @@ func TestCounter(t *testing.T) { for i := 0; i < 10; i++ { c.Increment() } + testza.AssertEqual(t, uint64(20), c.Count()) }) @@ -50,45 +54,97 @@ func TestCounter(t *testing.T) { }) } -// basicCounter is a basic implementation of a counter. -// It's used to compare the performance to our version. -type basicCounter struct { - mutex sync.Mutex - count uint64 -} +// TestAtomicOperations verifies that our counter works correctly +// with atomic operations, especially under concurrent access +func TestAtomicOperations(t *testing.T) { + c := NewCounter().Start() -func (c *basicCounter) Increment() { - c.mutex.Lock() - defer c.mutex.Unlock() - c.count++ -} + const numGoroutines = 10 + const incrementsPerGoroutine = 1000 -func (c *basicCounter) Count() uint64 { - c.mutex.Lock() - defer c.mutex.Unlock() - return c.count -} + var wg sync.WaitGroup + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + + go func() { + defer wg.Done() -func BenchmarkBasicCounterImplementation(b *testing.B) { - counter := basicCounter{} - b.ResetTimer() - for i := 0; i < b.N; i++ { - counter.Increment() + for j := 0; j < incrementsPerGoroutine; j++ { + c.Increment() + } + }() } + + wg.Wait() + c.Stop() + + expected := uint64(numGoroutines * incrementsPerGoroutine) + testza.AssertEqual(t, expected, c.Count(), "Count should be correct after concurrent increments") } -func BenchmarkIncrement(b *testing.B) { - counter := NewCounter().Start() - b.ResetTimer() - for i := 0; i < b.N; i++ { - counter.Increment() +// TestResetCleanup verifies that our Reset function properly cleans up all data +func TestResetCleanup(t *testing.T) { + c := NewCounter().WithAdvancedStats().Start() + + // Add some increments + for i := 0; i < 10; i++ { + c.Increment() + time.Sleep(1 * time.Millisecond) } + + // Reset the counter + c.Reset() + + // Verify count is reset to 0 + testza.AssertEqual(t, uint64(0), c.Count(), "Count should be 0 after reset") + + // Verify min/max rates are reset + testza.AssertEqual(t, 0.0, c.CalculateMinimumRate(time.Second), "Min rate should be 0 after reset") + testza.AssertEqual(t, 0.0, c.CalculateMaximumRate(time.Second), "Max rate should be 0 after reset") + + // Start and increment again to verify we can use the counter after reset + c.Start() + c.Increment() + testza.AssertEqual(t, uint64(1), c.Count(), "Count should be 1 after reset and increment") } -func BenchmarkIncrementWithAdvancedStats(b *testing.B) { - counter := NewCounter().WithAdvancedStats().Start() - b.ResetTimer() - for i := 0; i < b.N; i++ { - counter.Increment() +// TestReadWriteMutex verifies that our read-write mutex optimizations work correctly +func TestReadWriteMutex(t *testing.T) { + c := NewCounter().Start() + + // Start multiple readers and one writer + const numReaders = 100 + const readsPerGoroutine = 1000 + + // Counter for total reads completed + readsDone := int32(0) + + // Start readers + var wg sync.WaitGroup + for i := 0; i < numReaders; i++ { + wg.Add(1) + + go func() { + defer wg.Done() + + for j := 0; j < readsPerGoroutine; j++ { + c.Count() + atomic.AddInt32(&readsDone, 1) + } + }() } + + // While readers are going, increment the counter periodically + go func() { + for atomic.LoadInt32(&readsDone) < int32(numReaders*readsPerGoroutine) { + c.Increment() + time.Sleep(1 * time.Millisecond) + } + }() + + // Wait for all readers to finish + wg.Wait() + + // If we get here without deadlock, the test passes + testza.AssertTrue(t, c.Count() > 0, "Counter should have been incremented") } diff --git a/examples_test.go b/examples_test.go index 75a4803..917a57e 100644 --- a/examples_test.go +++ b/examples_test.go @@ -12,6 +12,7 @@ func ExampleCounter_Increment() { for i := 0; i < 10; i++ { c.Increment() } + c.Stop() fmt.Println(c.Count()) @@ -20,10 +21,12 @@ func ExampleCounter_Increment() { func ExampleCounter_CalculateAverageRate() { c := counter.NewCounter().Start() + for i := 0; i < 10; i++ { time.Sleep(100 * time.Millisecond) c.Increment() } + c.Stop() fmt.Println(c.CalculateAverageRate(time.Second)) @@ -32,10 +35,12 @@ func ExampleCounter_CalculateAverageRate() { func ExampleCounter_CalculateMinimumRate() { c := counter.NewCounter().WithAdvancedStats().Start() + for i := 0; i < 10; i++ { time.Sleep(100 * time.Millisecond) c.Increment() } + c.Stop() fmt.Println(c.CalculateMinimumRate(time.Second)) @@ -44,10 +49,12 @@ func ExampleCounter_CalculateMinimumRate() { func ExampleCounter_CalculateMaximumRate() { c := counter.NewCounter().WithAdvancedStats().Start() + for i := 0; i < 10; i++ { time.Sleep(100 * time.Millisecond) c.Increment() } + c.Stop() fmt.Println(c.CalculateMaximumRate(time.Second)) @@ -59,6 +66,7 @@ func ExampleCounter_Reset() { for i := 0; i < 10; i++ { c.Increment() } + c.Reset() fmt.Println(c.Count())