diff --git a/pkg/net/credentials/alts/handshake/handshake_test.go b/pkg/net/credentials/alts/handshake/handshake_test.go index f2f06bc7..13b40175 100644 --- a/pkg/net/credentials/alts/handshake/handshake_test.go +++ b/pkg/net/credentials/alts/handshake/handshake_test.go @@ -435,7 +435,7 @@ func TestHandshakerConcurrentHandshakes(t *testing.T) { } // Verify concurrent handshake limits - require.LessOrEqual(t, stat.MaxConcurrentCalls, maxConcurrentHandshakes, + require.LessOrEqual(t, stat.GetMaxConcurrentCalls(), maxConcurrentHandshakes, "concurrent handshakes exceeded limit") }) } diff --git a/pkg/net/credentials/alts/testutil/testutil.go b/pkg/net/credentials/alts/testutil/testutil.go index bc221829..9781c236 100644 --- a/pkg/net/credentials/alts/testutil/testutil.go +++ b/pkg/net/credentials/alts/testutil/testutil.go @@ -4,7 +4,7 @@ package testutil import ( "encoding/binary" "net" - "sync" + "sync/atomic" "time" . "github.com/LumeraProtocol/supernode/pkg/net/credentials/alts/common" @@ -12,33 +12,41 @@ import ( // Stats is used to collect statistics about concurrent handshake calls. type Stats struct { - mu sync.Mutex - calls int - MaxConcurrentCalls int + calls int32 + maxConcurrentCalls int32 } // Update updates the statistics by adding one call. func (s *Stats) Update() func() { - s.mu.Lock() - s.calls++ - if s.calls > s.MaxConcurrentCalls { - s.MaxConcurrentCalls = s.calls + // Atomically increment calls + newCalls := atomic.AddInt32(&s.calls, 1) + + // Update MaxConcurrentCalls if needed + for { + currentMax := atomic.LoadInt32(&s.maxConcurrentCalls) + if newCalls <= currentMax { + break + } + if atomic.CompareAndSwapInt32(&s.maxConcurrentCalls, currentMax, newCalls) { + break + } } - s.mu.Unlock() return func() { - s.mu.Lock() - s.calls-- - s.mu.Unlock() + // Atomically decrement calls + atomic.AddInt32(&s.calls, -1) } } // Reset resets the statistics. func (s *Stats) Reset() { - s.mu.Lock() - defer s.mu.Unlock() - s.calls = 0 - s.MaxConcurrentCalls = 0 + atomic.StoreInt32(&s.calls, 0) + atomic.StoreInt32(&s.maxConcurrentCalls, 0) +} + +// GetMaxConcurrentCalls returns the maximum number of concurrent calls. +func (s *Stats) GetMaxConcurrentCalls() int { + return int(atomic.LoadInt32(&s.maxConcurrentCalls)) } // testConn mimics a net.Conn to the peer. @@ -50,7 +58,7 @@ type testLatencyConn struct { // NewTestConnWithReadLatency wraps a net.Conn with artificial read latency func NewTestConnWithReadLatency(conn net.Conn, readLatency time.Duration) net.Conn { return &testLatencyConn{ - Conn: conn, + Conn: conn, readLatency: readLatency, } } @@ -87,9 +95,9 @@ func NewUnresponsiveTestConn(delay time.Duration) net.Conn { // Read reads from the in buffer. func (c *unresponsiveTestConn) Read([]byte) (n int, err error) { - // Wait for delay to simulate network latency - time.Sleep(c.delay) - // Return empty data (success but zero bytes) + // Wait for delay to simulate network latency + time.Sleep(c.delay) + // Return empty data (success but zero bytes) return 0, nil }