Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/net/credentials/alts/handshake/handshake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ func TestHandshakerConcurrentHandshakes(t *testing.T) {
}

// Verify concurrent handshake limits
require.LessOrEqual(t, stat.MaxConcurrentCalls, maxConcurrentHandshakes,
require.LessOrEqual(t, stat.GetMaxConcurrentCalls(), maxConcurrentHandshakes,
"concurrent handshakes exceeded limit")
})
}
Expand Down
48 changes: 28 additions & 20 deletions pkg/net/credentials/alts/testutil/testutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,41 +4,49 @@ package testutil
import (
"encoding/binary"
"net"
"sync"
"sync/atomic"
"time"

. "github.com/LumeraProtocol/supernode/pkg/net/credentials/alts/common"
)

// Stats is used to collect statistics about concurrent handshake calls.
type Stats struct {
mu sync.Mutex
calls int
MaxConcurrentCalls int
calls int32
maxConcurrentCalls int32
}

// Update updates the statistics by adding one call.
func (s *Stats) Update() func() {
s.mu.Lock()
s.calls++
if s.calls > s.MaxConcurrentCalls {
s.MaxConcurrentCalls = s.calls
// Atomically increment calls
newCalls := atomic.AddInt32(&s.calls, 1)

// Update MaxConcurrentCalls if needed
for {
currentMax := atomic.LoadInt32(&s.maxConcurrentCalls)
if newCalls <= currentMax {
break
}
if atomic.CompareAndSwapInt32(&s.maxConcurrentCalls, currentMax, newCalls) {
break
}
}
s.mu.Unlock()

return func() {
s.mu.Lock()
s.calls--
s.mu.Unlock()
// Atomically decrement calls
atomic.AddInt32(&s.calls, -1)
}
}

// Reset resets the statistics.
func (s *Stats) Reset() {
s.mu.Lock()
defer s.mu.Unlock()
s.calls = 0
s.MaxConcurrentCalls = 0
atomic.StoreInt32(&s.calls, 0)
atomic.StoreInt32(&s.maxConcurrentCalls, 0)
}

// GetMaxConcurrentCalls returns the maximum number of concurrent calls.
func (s *Stats) GetMaxConcurrentCalls() int {
return int(atomic.LoadInt32(&s.maxConcurrentCalls))
}

// testConn mimics a net.Conn to the peer.
Expand All @@ -50,7 +58,7 @@ type testLatencyConn struct {
// NewTestConnWithReadLatency wraps a net.Conn with artificial read latency
func NewTestConnWithReadLatency(conn net.Conn, readLatency time.Duration) net.Conn {
return &testLatencyConn{
Conn: conn,
Conn: conn,
readLatency: readLatency,
}
}
Expand Down Expand Up @@ -87,9 +95,9 @@ func NewUnresponsiveTestConn(delay time.Duration) net.Conn {

// Read reads from the in buffer.
func (c *unresponsiveTestConn) Read([]byte) (n int, err error) {
// Wait for delay to simulate network latency
time.Sleep(c.delay)
// Return empty data (success but zero bytes)
// Wait for delay to simulate network latency
time.Sleep(c.delay)
// Return empty data (success but zero bytes)
return 0, nil
}

Expand Down