diff --git a/controller/Makefile b/controller/Makefile index 08c9aa00e..bbf2dc457 100644 --- a/controller/Makefile +++ b/controller/Makefile @@ -96,7 +96,7 @@ vet: ## Run go vet against code. .PHONY: test test: manifests generate fmt vet envtest ## Run tests. - KUBEBUILDER_ASSETS="$(shell $(ENVTEST) use $(ENVTEST_K8S_VERSION) --bin-dir $(LOCALBIN) -p path)" go test $$(go list ./... | grep -v /e2e) -coverprofile cover.out + KUBEBUILDER_ASSETS="$(shell $(ENVTEST) use $(ENVTEST_K8S_VERSION) --bin-dir $(LOCALBIN) -p path)" go test $$(go list ./... | grep -v /e2e) -race -coverprofile cover.out # Utilize Kind or modify the e2e tests to load the image locally, enabling compatibility with other vendors. .PHONY: test-e2e # Run the e2e tests against a Kind k8s instance that is spun up. diff --git a/controller/internal/service/controller_service.go b/controller/internal/service/controller_service.go index f0ec66f0a..ee437c4e6 100644 --- a/controller/internal/service/controller_service.go +++ b/controller/internal/service/controller_service.go @@ -26,6 +26,7 @@ import ( "slices" "strings" "sync" + "sync/atomic" "time" "golang.org/x/exp/maps" @@ -80,6 +81,90 @@ type ControllerService struct { ServerOptions []grpc.ServerOption Router config.Router listenQueues sync.Map + leaseLocks sync.Map +} + +type listenQueue struct { + ch chan *pb.ListenResponse + done chan struct{} + closeOnce sync.Once +} + +func (q *listenQueue) closeDone() { + q.closeOnce.Do(func() { close(q.done) }) +} + +type leaseLock struct { + mu sync.Mutex + refs int32 +} + +func (s *ControllerService) acquireLeaseLock(leaseName string) *sync.Mutex { + for { + v, loaded := s.leaseLocks.LoadOrStore(leaseName, &leaseLock{refs: 1}) + ll := v.(*leaseLock) + if !loaded { + return &ll.mu + } + newRefs := atomic.AddInt32(&ll.refs, 1) + if newRefs <= 1 { + atomic.AddInt32(&ll.refs, -1) + continue + } + return &ll.mu + } +} + +func (s *ControllerService) releaseLeaseLock(leaseName string) { + v, ok := s.leaseLocks.Load(leaseName) + if !ok { + return + } + ll := v.(*leaseLock) + if atomic.AddInt32(&ll.refs, -1) == 0 { + s.leaseLocks.CompareAndDelete(leaseName, ll) + } +} + +// swapListenQueue atomically replaces the listen queue for a lease and signals +// the previous queue to stop. The per-lease lock serializes this with +// sendToListener so that Dial never sends a token to a superseded queue. +func (s *ControllerService) swapListenQueue(leaseName string, newQueue *listenQueue) { + mu := s.acquireLeaseLock(leaseName) + mu.Lock() + old, loaded := s.listenQueues.Swap(leaseName, newQueue) + if loaded { + old.(*listenQueue).closeDone() + } + mu.Unlock() + s.releaseLeaseLock(leaseName) +} + +// sendToListener delivers a response to the active listener for a lease. The +// per-lease lock guarantees that the queue loaded here cannot be superseded +// between the load and the send, eliminating the TOCTOU race between Dial and +// a reconnecting Listen. +func (s *ControllerService) sendToListener(_ context.Context, leaseName string, response *pb.ListenResponse) error { + mu := s.acquireLeaseLock(leaseName) + defer s.releaseLeaseLock(leaseName) + mu.Lock() + defer mu.Unlock() + v, ok := s.listenQueues.Load(leaseName) + if !ok { + return status.Errorf(codes.Unavailable, "exporter is not listening on lease %s", leaseName) + } + q := v.(*listenQueue) + select { + case <-q.done: + return status.Errorf(codes.Unavailable, "exporter is not listening on lease %s", leaseName) + default: + } + select { + case q.ch <- response: + return nil + default: + return status.Errorf(codes.ResourceExhausted, "listener buffer full on lease %s", leaseName) + } } type wrappedStream struct { @@ -439,12 +524,35 @@ func (s *ControllerService) Listen(req *pb.ListenRequest, stream pb.ControllerSe return err } - queue, _ := s.listenQueues.LoadOrStore(leaseName, make(chan *pb.ListenResponse, 8)) + wrapper := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + listenMu := s.acquireLeaseLock(leaseName) + s.swapListenQueue(leaseName, wrapper) + defer func() { + listenMu.Lock() + wrapper.closeDone() + listenMu.Unlock() + s.listenQueues.CompareAndDelete(leaseName, wrapper) + s.releaseLeaseLock(leaseName) + }() for { select { case <-ctx.Done(): return nil - case msg := <-queue.(chan *pb.ListenResponse): + case <-wrapper.done: + for { + select { + case msg := <-wrapper.ch: + if err := stream.Send(msg); err != nil { + return err + } + default: + return nil + } + } + case msg := <-wrapper.ch: if err := stream.Send(msg); err != nil { return err } @@ -732,11 +840,8 @@ func (s *ControllerService) Dial(ctx context.Context, req *pb.DialRequest) (*pb. RouterToken: token, } - queue, _ := s.listenQueues.LoadOrStore(leaseName, make(chan *pb.ListenResponse, 8)) - select { - case <-ctx.Done(): - return nil, ctx.Err() - case queue.(chan *pb.ListenResponse) <- response: + if err := s.sendToListener(ctx, leaseName, response); err != nil { + return nil, err } logger.Info("Client dial assigned stream", "stream", stream) diff --git a/controller/internal/service/controller_service_test.go b/controller/internal/service/controller_service_test.go index e4d21f717..9e506f543 100644 --- a/controller/internal/service/controller_service_test.go +++ b/controller/internal/service/controller_service_test.go @@ -17,7 +17,11 @@ limitations under the License. package service import ( + "context" + "strings" + "sync" "testing" + "time" jumpstarterdevv1alpha1 "github.com/jumpstarter-dev/jumpstarter-controller/api/v1alpha1" pb "github.com/jumpstarter-dev/jumpstarter-controller/internal/protocol/jumpstarter/v1" @@ -27,6 +31,20 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) +const testRouterToken = "tok" + +func drainChannel(ch <-chan *pb.ListenResponse) int { + count := 0 + for { + select { + case <-ch: + count++ + default: + return count + } + } +} + func TestProtoStatusToString(t *testing.T) { tests := []struct { name string @@ -183,7 +201,7 @@ func TestCheckExporterStatusForDriverCalls(t *testing.T) { } if tt.expectedSubstr != "" { - if !contains(st.Message(), tt.expectedSubstr) { + if !strings.Contains(st.Message(), tt.expectedSubstr) { t.Errorf("error message = %q, want to contain %q", st.Message(), tt.expectedSubstr) } } @@ -299,17 +317,1464 @@ func TestSyncOnlineConditionWithStatus(t *testing.T) { } } -// contains checks if substr is contained in s -func contains(s, substr string) bool { - return len(s) >= len(substr) && (s == substr || len(substr) == 0 || - (len(s) > 0 && len(substr) > 0 && searchSubstring(s, substr))) +func TestListenQueueCompareAndDeleteOnStreamError(t *testing.T) { + svc := &ControllerService{} + leaseName := "test-lease-stream-error" + + wrapper := &listenQueue{ch: make(chan *pb.ListenResponse, 8), done: make(chan struct{})} + svc.swapListenQueue(leaseName, wrapper) + + t.Run("queue is deleted when no reconnect replaced it", func(t *testing.T) { + svc.listenQueues.CompareAndDelete(leaseName, wrapper) + + if _, ok := svc.listenQueues.Load(leaseName); ok { + t.Fatal("queue should be deleted when it is still the same instance") + } + }) + + t.Run("queue survives when a reconnecting Listen replaced it", func(t *testing.T) { + newWrapper := &listenQueue{ch: make(chan *pb.ListenResponse, 8), done: make(chan struct{})} + svc.swapListenQueue(leaseName, newWrapper) + + svc.listenQueues.CompareAndDelete(leaseName, wrapper) + + got, ok := svc.listenQueues.Load(leaseName) + if !ok { + t.Fatal("queue was deleted even though a new Listen replaced it") + } + if got != newWrapper { + t.Fatal("queue was replaced with something unexpected") + } + }) +} + +func TestListenQueueCompareAndDeleteOnCleanShutdown(t *testing.T) { + svc := &ControllerService{} + leaseName := "test-lease-shutdown" + + wrapper := &listenQueue{ch: make(chan *pb.ListenResponse, 8), done: make(chan struct{})} + svc.swapListenQueue(leaseName, wrapper) + + svc.listenQueues.CompareAndDelete(leaseName, wrapper) + + if _, ok := svc.listenQueues.Load(leaseName); ok { + t.Fatal("queue should be removed on clean shutdown") + } +} + +func TestListenQueueReconnectCreatesNewChannel(t *testing.T) { + svc := &ControllerService{} + leaseName := "test-lease-reconnect" + + originalWrapper := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + svc.swapListenQueue(leaseName, originalWrapper) + + newWrapper := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + svc.swapListenQueue(leaseName, newWrapper) + + v, ok := svc.listenQueues.Load(leaseName) + if !ok { + t.Fatal("queue entry should still exist") + } + current := v.(*listenQueue) + if current.ch == originalWrapper.ch { + t.Fatal("reconnecting Listen must use a new channel, not the old one") + } + if current != newWrapper { + t.Fatal("queue entry should be the new wrapper") + } + + select { + case <-originalWrapper.done: + default: + t.Fatal("original wrapper done channel should be closed after swap") + } +} + +func TestListenQueueDialTokenDeliveredToNewListener(t *testing.T) { + svc := &ControllerService{} + leaseName := "test-lease-dial-token" + + g1 := &listenQueue{ch: make(chan *pb.ListenResponse, 8), done: make(chan struct{})} + svc.swapListenQueue(leaseName, g1) + + g2 := &listenQueue{ch: make(chan *pb.ListenResponse, 8), done: make(chan struct{})} + svc.swapListenQueue(leaseName, g2) + + response := &pb.ListenResponse{RouterEndpoint: "test-endpoint", RouterToken: "test-token"} + err := svc.sendToListener(context.Background(), leaseName, response) + if err != nil { + t.Fatalf("sendToListener should succeed for active queue: %v", err) + } + + select { + case got := <-g2.ch: + if got.RouterEndpoint != "test-endpoint" || got.RouterToken != "test-token" { + t.Fatal("dial token was corrupted") + } + default: + t.Fatal("dial token was not delivered to the new listener") + } + + select { + case <-g1.ch: + t.Fatal("dial token was delivered to the old listener") + default: + } +} + +func TestListenQueueReconnectPreventsStaleCleanup(t *testing.T) { + svc := &ControllerService{} + leaseName := "test-lease-stale-cleanup" + + originalWrapper := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + svc.swapListenQueue(leaseName, originalWrapper) + + reconnectWrapper := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + svc.swapListenQueue(leaseName, reconnectWrapper) + + // Original wrapper's deferred CompareAndDelete should be a no-op. + svc.listenQueues.CompareAndDelete(leaseName, originalWrapper) + + got, ok := svc.listenQueues.Load(leaseName) + if !ok { + t.Fatal("stale Listen cleanup deleted queue that reconnected Listen is using") + } + if got != reconnectWrapper { + t.Fatal("queue entry does not match the reconnected wrapper") + } + + token := &pb.ListenResponse{RouterEndpoint: "ep", RouterToken: testRouterToken} + reconnectWrapper.ch <- token + + select { + case msg := <-reconnectWrapper.ch: + if msg.RouterEndpoint != "ep" || msg.RouterToken != testRouterToken { + t.Fatal("token was corrupted after stale cleanup attempt") + } + default: + t.Fatal("token was lost after stale cleanup attempt") + } +} + +func TestListenQueueConcurrentSwapSupersedes(t *testing.T) { + svc := &ControllerService{} + leaseName := "test-lease-concurrent-swap" + + g1 := &listenQueue{ch: make(chan *pb.ListenResponse, 8), done: make(chan struct{})} + svc.swapListenQueue(leaseName, g1) + + g2 := &listenQueue{ch: make(chan *pb.ListenResponse, 8), done: make(chan struct{})} + svc.swapListenQueue(leaseName, g2) + + g3 := &listenQueue{ch: make(chan *pb.ListenResponse, 8), done: make(chan struct{})} + svc.swapListenQueue(leaseName, g3) + + // G1 and G2 should both have their done channels closed. + select { + case <-g1.done: + default: + t.Fatal("G1 done channel should be closed") + } + select { + case <-g2.done: + default: + t.Fatal("G2 done channel should be closed") + } + + // G3 should still be active. + select { + case <-g3.done: + t.Fatal("G3 done channel should not be closed") + default: + } + + // G1 and G2 deferred CompareAndDelete are no-ops. + svc.listenQueues.CompareAndDelete(leaseName, g1) + svc.listenQueues.CompareAndDelete(leaseName, g2) + + got, ok := svc.listenQueues.Load(leaseName) + if !ok { + t.Fatal("queue was deleted by stale CompareAndDelete") + } + if got != g3 { + t.Fatal("queue entry does not match G3") + } +} + +func TestListenQueueStaleReaderConsumesDialToken(t *testing.T) { + svc := &ControllerService{} + leaseName := "test-lease-stale-reader" + + g1Queue := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + svc.swapListenQueue(leaseName, g1Queue) + + g2Queue := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + svc.swapListenQueue(leaseName, g2Queue) + + token := &pb.ListenResponse{RouterEndpoint: "ep", RouterToken: testRouterToken} + err := svc.sendToListener(context.Background(), leaseName, token) + if err != nil { + t.Fatalf("sendToListener should succeed for active queue: %v", err) + } + + select { + case <-g1Queue.done: + default: + t.Fatal("G1 done channel should be closed after swap") + } + + select { + case <-g1Queue.ch: + t.Fatal("stale reader G1 consumed the dial token") + default: + } + + select { + case got := <-g2Queue.ch: + if got.RouterEndpoint != "ep" || got.RouterToken != testRouterToken { + t.Fatal("token received by G2 was corrupted") + } + default: + t.Fatal("active reader G2 did not receive the dial token") + } +} + +func TestListenQueueStaleReaderAlwaysDetectsSupersession(t *testing.T) { + iterations := 100 + + for i := 0; i < iterations; i++ { + svc := &ControllerService{} + leaseName := "test-lease-concurrent" + + g1Queue := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + svc.swapListenQueue(leaseName, g1Queue) + + g2Queue := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + svc.swapListenQueue(leaseName, g2Queue) + + err := svc.sendToListener(context.Background(), leaseName, &pb.ListenResponse{ + RouterEndpoint: "ep", RouterToken: testRouterToken, + }) + if err != nil { + t.Fatalf("iteration %d: sendToListener should succeed: %v", i, err) + } + + select { + case <-g1Queue.done: + default: + t.Fatalf("iteration %d: g1 done channel should be closed after supersession", i) + } + + select { + case <-g1Queue.ch: + t.Fatalf("iteration %d: stale reader g1 consumed a token after supersession", i) + default: + } + } +} + +func TestDialRejectsSupersededQueue(t *testing.T) { + q := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + close(q.done) + + response := &pb.ListenResponse{RouterEndpoint: "ep", RouterToken: testRouterToken} + + rejected := false + select { + case <-q.done: + rejected = true + default: + } + if !rejected { + select { + case <-q.done: + rejected = true + case q.ch <- response: + } + } + + if !rejected { + t.Fatal("dial must reject send to a queue whose done channel is closed") + } + + select { + case <-q.ch: + t.Fatal("token should not have been buffered in a superseded queue") + default: + } +} + +func TestDialWithPreSwapReferenceNeverSendsToStaleQueue(t *testing.T) { + iterations := 500 + + for i := 0; i < iterations; i++ { + svc := &ControllerService{} + leaseName := "test-lease-pre-swap-ref" + + g1 := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + svc.swapListenQueue(leaseName, g1) + + g2 := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + svc.swapListenQueue(leaseName, g2) + + response := &pb.ListenResponse{RouterEndpoint: "ep", RouterToken: testRouterToken} + + err := svc.sendToListener(context.Background(), leaseName, response) + if err != nil { + t.Fatalf("iteration %d: sendToListener should succeed for active g2: %v", i, err) + } + + select { + case <-g1.ch: + t.Fatalf("iteration %d: dial sent to stale queue g1", i) + default: + } + + select { + case got := <-g2.ch: + if got.RouterEndpoint != "ep" || got.RouterToken != testRouterToken { + t.Fatalf("iteration %d: token corrupted on g2", i) + } + default: + t.Fatalf("iteration %d: token not delivered to active g2", i) + } + } +} + +func TestDialSendsTokenViaServiceMethod(t *testing.T) { + svc := &ControllerService{} + leaseName := "test-lease-dial-method" + + q := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + svc.swapListenQueue(leaseName, q) + + response := &pb.ListenResponse{RouterEndpoint: "ep", RouterToken: testRouterToken} + + err := svc.sendToListener(context.Background(), leaseName, response) + if err != nil { + t.Fatalf("sendToListener should succeed for active queue: %v", err) + } + + select { + case got := <-q.ch: + if got.RouterEndpoint != "ep" || got.RouterToken != testRouterToken { + t.Fatal("token was corrupted") + } + default: + t.Fatal("token was not delivered") + } +} + +func TestDialSendToListenerRejectsSupersededQueue(t *testing.T) { + svc := &ControllerService{} + leaseName := "test-lease-dial-method-superseded" + + g1 := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + svc.swapListenQueue(leaseName, g1) + + g2 := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + svc.swapListenQueue(leaseName, g2) + + response := &pb.ListenResponse{RouterEndpoint: "ep", RouterToken: testRouterToken} + + err := svc.sendToListener(context.Background(), leaseName, response) + if err != nil { + t.Fatalf("sendToListener should succeed for the new active queue: %v", err) + } + + select { + case <-g1.ch: + t.Fatal("token was delivered to superseded queue g1") + default: + } + + select { + case got := <-g2.ch: + if got.RouterEndpoint != "ep" || got.RouterToken != testRouterToken { + t.Fatal("token was corrupted") + } + default: + t.Fatal("token was not delivered to active queue g2") + } } -func searchSubstring(s, substr string) bool { - for i := 0; i <= len(s)-len(substr); i++ { - if s[i:i+len(substr)] == substr { - return true +func TestDialSendToListenerRejectsNoListener(t *testing.T) { + svc := &ControllerService{} + + response := &pb.ListenResponse{RouterEndpoint: "ep", RouterToken: testRouterToken} + err := svc.sendToListener(context.Background(), "nonexistent-lease", response) + if err == nil { + t.Fatal("sendToListener should return error when no listener exists") + } +} + +func TestDialSendToListenerRejectsDoneQueue(t *testing.T) { + svc := &ControllerService{} + leaseName := "test-lease-done-queue" + + q := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + svc.swapListenQueue(leaseName, q) + q.closeDone() + + response := &pb.ListenResponse{RouterEndpoint: "ep", RouterToken: testRouterToken} + err := svc.sendToListener(context.Background(), leaseName, response) + if err == nil { + t.Fatal("sendToListener should return error for done queue") + } + + select { + case <-q.ch: + t.Fatal("token should not be buffered in a done queue") + default: + } +} + +func TestDialSendToListenerSerializesWithSwap(t *testing.T) { + // Verify that swapListenQueue followed by sendToListener always delivers + // to the new queue (or returns an error), never to the superseded queue. + // This tests the scenario where the swap completes before the send. + iterations := 500 + + for i := 0; i < iterations; i++ { + svc := &ControllerService{} + leaseName := "test-lease-serialized" + + g1 := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + svc.swapListenQueue(leaseName, g1) + + g2 := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + + svc.swapListenQueue(leaseName, g2) + + response := &pb.ListenResponse{RouterEndpoint: "ep", RouterToken: testRouterToken} + err := svc.sendToListener(context.Background(), leaseName, response) + if err != nil { + t.Fatalf("iteration %d: sendToListener should succeed for active g2: %v", i, err) + } + + select { + case <-g1.ch: + t.Fatalf("iteration %d: token delivered to superseded g1", i) + default: + } + + select { + case got := <-g2.ch: + if got.RouterEndpoint != "ep" || got.RouterToken != testRouterToken { + t.Fatalf("iteration %d: token corrupted on g2", i) + } + default: + t.Fatalf("iteration %d: token not delivered to active g2", i) + } + } +} + +func TestDialSendToListenerConcurrentWithSwapNeverLandsOnSuperseded(t *testing.T) { + // Race swapListenQueue against sendToListener using goroutines. + // The per-lease mutex guarantees that the Load+send in sendToListener + // is atomic with respect to the Swap+closeDone in swapListenQueue. + // When sendToListener acquires the lock first, it sends to g1 (which + // is still current -- a valid send). When swapListenQueue acquires + // first, sendToListener sees g2 as the current queue. + // + // The invariant: if sendToListener returns nil, the done channel of the + // queue it sent to was NOT closed at the time of the send (guaranteed by + // the lock preventing concurrent swap+closeDone). + iterations := 500 + sentToG1 := 0 + sentToG2 := 0 + rejected := 0 + + for i := 0; i < iterations; i++ { + svc := &ControllerService{} + leaseName := "test-lease-concurrent-serial" + + g1 := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), } + svc.swapListenQueue(leaseName, g1) + + g2 := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + + swapDone := make(chan struct{}) + sendResult := make(chan error, 1) + + go func() { + defer close(swapDone) + svc.swapListenQueue(leaseName, g2) + }() + go func() { + sendResult <- svc.sendToListener(context.Background(), leaseName, &pb.ListenResponse{ + RouterEndpoint: "ep", RouterToken: testRouterToken, + }) + }() + + <-swapDone + sendErr := <-sendResult + + if sendErr != nil { + rejected++ + continue + } + + onG1 := false + select { + case <-g1.ch: + onG1 = true + sentToG1++ + default: + } + onG2 := false + select { + case <-g2.ch: + onG2 = true + sentToG2++ + default: + } + + if !onG1 && !onG2 { + t.Fatalf("iteration %d: send succeeded but token is lost", i) + } + if onG1 && onG2 { + t.Fatalf("iteration %d: token duplicated across queues", i) + } + } + + if sentToG1+sentToG2+rejected != iterations { + t.Fatalf("accounting error: g1=%d g2=%d rejected=%d total=%d", + sentToG1, sentToG2, rejected, sentToG1+sentToG2+rejected) + } +} + +func TestListenQueueDoneClosedOnNormalExit(t *testing.T) { + q := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + closeOnce: sync.Once{}, + } + + q.closeDone() + + select { + case <-q.done: + default: + t.Fatal("done channel should be closed after closeDone is called") + } + + q.closeDone() + + select { + case <-q.done: + default: + t.Fatal("done channel should remain closed after duplicate closeDone call") + } +} + +func TestListenQueueSupersessionSignaling(t *testing.T) { + svc := &ControllerService{} + leaseName := "test-lease-supersession" + + g1Queue := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + svc.swapListenQueue(leaseName, g1Queue) + + g2Queue := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + svc.swapListenQueue(leaseName, g2Queue) + + // Verify G1's done channel is closed. + select { + case <-g1Queue.done: + // expected + default: + t.Fatal("G1 done channel was not closed after supersession") + } + + // Verify G2's done channel is still open. + select { + case <-g2Queue.done: + t.Fatal("G2 done channel should not be closed") + default: + // expected + } + + // CompareAndDelete by G1 should be a no-op (G2 is current). + svc.listenQueues.CompareAndDelete(leaseName, g1Queue) + v, ok := svc.listenQueues.Load(leaseName) + if !ok { + t.Fatal("G1 cleanup deleted the queue that G2 owns") + } + if v != g2Queue { + t.Fatal("queue entry does not match G2's queue") + } +} + +func TestListenQueueDoneClosedBeforeMapDelete(t *testing.T) { + svc := &ControllerService{} + leaseName := "test-lease-defer-order" + + wrapper := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + svc.swapListenQueue(leaseName, wrapper) + + v, ok := svc.listenQueues.Load(leaseName) + if !ok { + t.Fatal("queue entry should exist") + } + q := v.(*listenQueue) + + q.closeDone() + svc.listenQueues.CompareAndDelete(leaseName, wrapper) + + // The Dial that loaded q before cleanup must see done is closed. + select { + case <-q.done: + // correct: Dial detects the listener exited + default: + t.Fatal("Dial did not detect listener exit via done channel") + } + + // Map entry should be removed. + if _, ok := svc.listenQueues.Load(leaseName); ok { + t.Fatal("map entry should be removed after cleanup") + } +} + +func TestListenQueueDoneClosedBeforeMapDeleteWithConcurrentDial(t *testing.T) { + svc := &ControllerService{} + leaseName := "test-lease-defer-order-concurrent" + + wrapper := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + svc.swapListenQueue(leaseName, wrapper) + + wrapper.closeDone() + + response := &pb.ListenResponse{RouterEndpoint: "ep", RouterToken: testRouterToken} + err := svc.sendToListener(context.Background(), leaseName, response) + if err == nil { + t.Fatal("sendToListener should return error when done is closed before map delete") + } + + svc.listenQueues.CompareAndDelete(leaseName, wrapper) + + select { + case <-wrapper.ch: + t.Fatal("token should not be buffered in a queue whose done was closed first") + default: + } +} + +func TestListenQueueDialReturnsUnavailableWhenNoListener(t *testing.T) { + svc := &ControllerService{} + leaseName := "nonexistent-lease" + + err := svc.sendToListener(context.Background(), leaseName, &pb.ListenResponse{ + RouterEndpoint: "ep", RouterToken: testRouterToken, + }) + if err == nil { + t.Fatal("sendToListener should return error for nonexistent lease") + } + st, ok := status.FromError(err) + if !ok { + t.Fatalf("expected gRPC status error, got %v", err) + } + if st.Code() != codes.Unavailable { + t.Fatalf("expected codes.Unavailable, got %v", st.Code()) + } +} + +func TestListenQueueDialReturnsUnavailableWhenDoneClosed(t *testing.T) { + svc := &ControllerService{} + leaseName := "test-lease-done-closed" + + q := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + svc.swapListenQueue(leaseName, q) + q.closeDone() + + err := svc.sendToListener(context.Background(), leaseName, &pb.ListenResponse{ + RouterEndpoint: "ep", RouterToken: testRouterToken, + }) + if err == nil { + t.Fatal("sendToListener should return error for done queue") + } + st, ok := status.FromError(err) + if !ok { + t.Fatalf("expected gRPC status error, got %v", err) + } + if st.Code() != codes.Unavailable { + t.Fatalf("expected codes.Unavailable, got %v", st.Code()) + } +} + +func TestListenQueueContextCancellationExitsListenLoop(t *testing.T) { + wrapper := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + + ctx, cancel := context.WithCancel(context.Background()) + exited := make(chan struct{}) + + go func() { + defer close(exited) + for { + select { + case <-ctx.Done(): + return + case <-wrapper.done: + return + case <-wrapper.ch: + } + } + }() + + cancel() + + select { + case <-exited: + case <-time.After(time.Second): + t.Fatal("listen loop did not exit after context cancellation") + } +} + +func TestListenQueueConcurrentDialDuringReconnection(t *testing.T) { + svc := &ControllerService{} + leaseName := "test-lease-concurrent-dial" + + g1 := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + svc.swapListenQueue(leaseName, g1) + + var deliveredCount int64 + var mu sync.Mutex + + g1ListenerDone := make(chan struct{}) + go func() { + defer close(g1ListenerDone) + for { + select { + case <-g1.done: + return + case <-g1.ch: + mu.Lock() + deliveredCount++ + mu.Unlock() + } + } + }() + + dialAttempts := 50 + var dialWg sync.WaitGroup + var rejectedCount int64 + var rejectedMu sync.Mutex + var sentCount int64 + var sentMu sync.Mutex + + var g2 *listenQueue + g2ListenerDone := make(chan struct{}) + + for i := 0; i < dialAttempts; i++ { + dialWg.Add(1) + go func() { + defer dialWg.Done() + ctx := context.Background() + err := svc.sendToListener(ctx, leaseName, &pb.ListenResponse{ + RouterEndpoint: "ep", RouterToken: testRouterToken, + }) + if err != nil { + rejectedMu.Lock() + rejectedCount++ + rejectedMu.Unlock() + return + } + sentMu.Lock() + sentCount++ + sentMu.Unlock() + }() + + if i == 25 { + g2 = &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + svc.swapListenQueue(leaseName, g2) + + localG2 := g2 + go func() { + defer close(g2ListenerDone) + for { + select { + case <-localG2.done: + return + case <-localG2.ch: + mu.Lock() + deliveredCount++ + mu.Unlock() + } + } + }() + } + } + + dialWg.Wait() + + <-g1ListenerDone + + drainCount := drainChannel(g1.ch) + + if g2 != nil { + g2.closeDone() + <-g2ListenerDone + drainCount += drainChannel(g2.ch) + } + + mu.Lock() + delivered := deliveredCount + mu.Unlock() + rejectedMu.Lock() + rejected := rejectedCount + rejectedMu.Unlock() + sentMu.Lock() + sent := sentCount + sentMu.Unlock() + + totalHandled := delivered + rejected + int64(drainCount) + if totalHandled != int64(dialAttempts) { + t.Fatalf("expected %d total outcomes, got %d delivered + %d rejected + %d drained = %d", + dialAttempts, delivered, rejected, drainCount, totalHandled) + } + + if sent != delivered+int64(drainCount) { + t.Fatalf("sent count %d does not match delivered %d + drained %d", + sent, delivered, drainCount) + } + + select { + case <-g1.done: + default: + t.Fatal("g1 done channel should be closed after reconnection") + } +} + +func TestListenQueueListenLoopDeliversTokensAndExitsOnDone(t *testing.T) { + svc := &ControllerService{} + leaseName := "test-lease-listen-loop" + + wrapper := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + svc.swapListenQueue(leaseName, wrapper) + + delivered := make(chan *pb.ListenResponse, 8) + loopExited := make(chan struct{}) + + go func() { + defer close(loopExited) + defer svc.listenQueues.CompareAndDelete(leaseName, wrapper) + defer wrapper.closeDone() + for { + select { + case <-wrapper.done: + for { + select { + case msg := <-wrapper.ch: + delivered <- msg + default: + return + } + } + case msg := <-wrapper.ch: + delivered <- msg + } + } + }() + + wrapper.ch <- &pb.ListenResponse{RouterEndpoint: "ep1", RouterToken: "tok1"} + wrapper.ch <- &pb.ListenResponse{RouterEndpoint: "ep2", RouterToken: "tok2"} + + for i := 0; i < 2; i++ { + select { + case msg := <-delivered: + if msg.RouterEndpoint == "" || msg.RouterToken == "" { + t.Fatal("received empty token") + } + case <-time.After(time.Second): + t.Fatal("timed out waiting for token delivery") + } + } + + superseder := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + svc.swapListenQueue(leaseName, superseder) + + select { + case <-loopExited: + case <-time.After(time.Second): + t.Fatal("listen loop did not exit after supersession") + } + + v, ok := svc.listenQueues.Load(leaseName) + if !ok { + t.Fatal("queue entry should still exist for superseder") + } + if v != superseder { + t.Fatal("queue entry should be the superseder") + } +} + +func TestSendToListenerReturnsResourceExhaustedWithCancelledContextAndBufferFull(t *testing.T) { + svc := &ControllerService{} + leaseName := "test-lease-ctx-cancel-buffer-full" + + q := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + svc.swapListenQueue(leaseName, q) + + for i := 0; i < 8; i++ { + q.ch <- &pb.ListenResponse{RouterEndpoint: "fill", RouterToken: "fill"} + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := svc.sendToListener(ctx, leaseName, &pb.ListenResponse{ + RouterEndpoint: "ep", RouterToken: testRouterToken, + }) + if err == nil { + t.Fatal("sendToListener should return error when buffer is full") + } + + st, ok := status.FromError(err) + if !ok { + t.Fatalf("expected gRPC status error, got %v", err) + } + if st.Code() != codes.ResourceExhausted { + t.Fatalf("expected ResourceExhausted, got %v", st.Code()) + } +} + +func TestSendToListenerReturnsImmediatelyDuringBackpressure(t *testing.T) { + svc := &ControllerService{} + leaseName := "test-lease-backpressure-immediate" + + q := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + svc.swapListenQueue(leaseName, q) + + for i := 0; i < 8; i++ { + q.ch <- &pb.ListenResponse{RouterEndpoint: "fill", RouterToken: "fill"} + } + + err := svc.sendToListener(context.Background(), leaseName, &pb.ListenResponse{ + RouterEndpoint: "ep", RouterToken: testRouterToken, + }) + if err == nil { + t.Fatal("sendToListener should return error when buffer is full") + } + + st, ok := status.FromError(err) + if !ok { + t.Fatalf("expected gRPC status error, got %v", err) + } + if st.Code() != codes.ResourceExhausted { + t.Fatalf("expected ResourceExhausted, got %v", st.Code()) + } +} + +func TestListenQueueDialFlowSendsToActiveListener(t *testing.T) { + svc := &ControllerService{} + leaseName := "test-lease-dial-flow" + + wrapper := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + svc.swapListenQueue(leaseName, wrapper) + + response := &pb.ListenResponse{RouterEndpoint: "dial-ep", RouterToken: "dial-tok"} + err := svc.sendToListener(context.Background(), leaseName, response) + if err != nil { + t.Fatalf("sendToListener should succeed for active listener: %v", err) + } + + select { + case got := <-wrapper.ch: + if got.RouterEndpoint != "dial-ep" || got.RouterToken != "dial-tok" { + t.Fatal("received corrupted token") + } + default: + t.Fatal("token was not delivered to the active listener") + } +} + +func TestLeaseLockRefCountSingleListener(t *testing.T) { + svc := &ControllerService{} + leaseName := "test-lease-refcount-single" + + svc.acquireLeaseLock(leaseName) + + if _, ok := svc.leaseLocks.Load(leaseName); !ok { + t.Fatal("lease lock should exist after acquire") + } + + svc.releaseLeaseLock(leaseName) + + if _, ok := svc.leaseLocks.Load(leaseName); ok { + t.Fatal("lease lock should be removed when last reference is released") + } +} + +func TestLeaseLockRefCountOverlappingListeners(t *testing.T) { + svc := &ControllerService{} + leaseName := "test-lease-refcount-overlap" + + svc.acquireLeaseLock(leaseName) + svc.acquireLeaseLock(leaseName) + + if _, ok := svc.leaseLocks.Load(leaseName); !ok { + t.Fatal("lease lock should exist with two references") + } + + svc.releaseLeaseLock(leaseName) + + if _, ok := svc.leaseLocks.Load(leaseName); !ok { + t.Fatal("lease lock should still exist with one remaining reference") + } + + svc.releaseLeaseLock(leaseName) + + if _, ok := svc.leaseLocks.Load(leaseName); ok { + t.Fatal("lease lock should be removed when all references are released") + } +} + +func TestLeaseLockRefCountConcurrentAcquireRelease(t *testing.T) { + svc := &ControllerService{} + leaseName := "test-lease-refcount-concurrent" + + var wg sync.WaitGroup + goroutines := 100 + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + svc.acquireLeaseLock(leaseName) + svc.releaseLeaseLock(leaseName) + }() + } + + wg.Wait() + + if _, ok := svc.leaseLocks.Load(leaseName); ok { + t.Fatal("lease lock should be removed after all goroutines release") + } +} + +func TestLeaseLockRefCountConcurrentOverlappingListeners(t *testing.T) { + svc := &ControllerService{} + leaseName := "test-lease-refcount-concurrent-overlap" + goroutines := 50 + + var counter int + var wg sync.WaitGroup + allAcquired := sync.WaitGroup{} + allAcquired.Add(goroutines) + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + mu := svc.acquireLeaseLock(leaseName) + defer svc.releaseLeaseLock(leaseName) + + allAcquired.Done() + allAcquired.Wait() + + mu.Lock() + counter++ + mu.Unlock() + }() + } + + wg.Wait() + + if counter != goroutines { + t.Fatalf("expected counter=%d, got %d", goroutines, counter) + } + + if _, ok := svc.leaseLocks.Load(leaseName); ok { + t.Fatal("lease lock should be removed after all goroutines release") + } +} + +func TestLeaseLockRefCountSameInstanceForOverlap(t *testing.T) { + svc := &ControllerService{} + leaseName := "test-lease-refcount-same-instance" + + lock1 := svc.acquireLeaseLock(leaseName) + lock2 := svc.acquireLeaseLock(leaseName) + + if lock1 != lock2 { + t.Fatal("overlapping acquires must return the same mutex") + } + + svc.releaseLeaseLock(leaseName) + svc.releaseLeaseLock(leaseName) +} + +func TestLeaseLockPreservedWhenNewListenerTakesOver(t *testing.T) { + svc := &ControllerService{} + leaseName := "test-lease-lock-preserved" + + g1Mu := svc.acquireLeaseLock(leaseName) + g1 := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + svc.swapListenQueue(leaseName, g1) + + g2Mu := svc.acquireLeaseLock(leaseName) + g2 := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + svc.swapListenQueue(leaseName, g2) + + if g1Mu != g2Mu { + t.Fatal("overlapping listeners must share the same mutex") + } + + g1Mu.Lock() + g1.closeDone() + g1Mu.Unlock() + svc.listenQueues.CompareAndDelete(leaseName, g1) + svc.releaseLeaseLock(leaseName) + + if _, ok := svc.leaseLocks.Load(leaseName); !ok { + t.Fatal("lease lock should be preserved when a new listener still holds a reference") + } + + if _, ok := svc.listenQueues.Load(leaseName); !ok { + t.Fatal("queue should still exist for the new listener") + } + + g2Mu.Lock() + g2.closeDone() + g2Mu.Unlock() + svc.listenQueues.CompareAndDelete(leaseName, g2) + svc.releaseLeaseLock(leaseName) + + if _, ok := svc.leaseLocks.Load(leaseName); ok { + t.Fatal("lease lock should be cleaned up when last listener releases") + } +} + +func TestSendToListenerReturnsResourceExhaustedWhenBufferFull(t *testing.T) { + svc := &ControllerService{} + leaseName := "test-lease-buffer-full-nonblocking" + + q := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + svc.swapListenQueue(leaseName, q) + + for i := 0; i < 8; i++ { + q.ch <- &pb.ListenResponse{RouterEndpoint: "fill", RouterToken: "fill"} + } + + err := svc.sendToListener(context.Background(), leaseName, &pb.ListenResponse{ + RouterEndpoint: "ep", RouterToken: testRouterToken, + }) + if err == nil { + t.Fatal("sendToListener should return error when buffer is full") + } + + st, ok := status.FromError(err) + if !ok { + t.Fatalf("expected gRPC status error, got %v", err) + } + if st.Code() != codes.ResourceExhausted { + t.Fatalf("expected ResourceExhausted, got %v", st.Code()) + } +} + +func TestSendToListenerDoesNotBlockMutexWhenBufferFull(t *testing.T) { + svc := &ControllerService{} + leaseName := "test-lease-no-mutex-block" + + q := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + svc.swapListenQueue(leaseName, q) + + for i := 0; i < 8; i++ { + q.ch <- &pb.ListenResponse{RouterEndpoint: "fill", RouterToken: "fill"} + } + + sendDone := make(chan struct{}) + sendErr := make(chan error, 1) + go func() { + defer close(sendDone) + sendErr <- svc.sendToListener(context.Background(), leaseName, &pb.ListenResponse{ + RouterEndpoint: "ep", RouterToken: testRouterToken, + }) + }() + + select { + case <-sendDone: + if err := <-sendErr; err == nil { + t.Fatal("sendToListener should return error when buffer is full") + } + case <-time.After(time.Second): + t.Fatal("sendToListener blocked when buffer was full; mutex held too long") + } + + swapDone := make(chan struct{}) + go func() { + defer close(swapDone) + g2 := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + svc.swapListenQueue(leaseName, g2) + }() + + select { + case <-swapDone: + case <-time.After(time.Second): + t.Fatal("swapListenQueue blocked because sendToListener held the mutex on full buffer") + } +} + +func TestSwapNotBlockedWhenBufferFull(t *testing.T) { + svc := &ControllerService{} + leaseName := "test-no-deadlock-chain" + + g1 := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + svc.swapListenQueue(leaseName, g1) + + for i := 0; i < 8; i++ { + g1.ch <- &pb.ListenResponse{RouterEndpoint: "fill", RouterToken: "fill"} + } + + err := svc.sendToListener(context.Background(), leaseName, &pb.ListenResponse{ + RouterEndpoint: "ep", RouterToken: testRouterToken, + }) + if err == nil { + t.Fatal("sendToListener should return error when buffer is full") + } + st, ok := status.FromError(err) + if !ok { + t.Fatalf("expected gRPC status error, got %v", err) + } + if st.Code() != codes.ResourceExhausted { + t.Fatalf("expected ResourceExhausted, got %v", st.Code()) + } + + g2 := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + swapDone := make(chan struct{}) + go func() { + defer close(swapDone) + svc.swapListenQueue(leaseName, g2) + }() + + select { + case <-swapDone: + case <-time.After(2 * time.Second): + t.Fatal("swapListenQueue should not be blocked when sendToListener returned immediately") + } + + select { + case <-g1.done: + default: + t.Fatal("g1 done channel should be closed after swap") + } + + v, loaded := svc.listenQueues.Load(leaseName) + if !loaded { + t.Fatal("queue should exist for g2") + } + if v != g2 { + t.Fatal("active queue should be g2") + } +} + +func TestListenQueueDrainsBufferedTokensOnSupersession(t *testing.T) { + svc := &ControllerService{} + leaseName := "test-lease-drain-on-supersession" + + wrapper := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + svc.swapListenQueue(leaseName, wrapper) + + err := svc.sendToListener(context.Background(), leaseName, &pb.ListenResponse{ + RouterEndpoint: "ep1", RouterToken: "tok1", + }) + if err != nil { + t.Fatalf("first sendToListener failed: %v", err) + } + err = svc.sendToListener(context.Background(), leaseName, &pb.ListenResponse{ + RouterEndpoint: "ep2", RouterToken: "tok2", + }) + if err != nil { + t.Fatalf("second sendToListener failed: %v", err) + } + + superseder := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + svc.swapListenQueue(leaseName, superseder) + + select { + case <-wrapper.done: + default: + t.Fatal("wrapper.done should be closed after supersession") + } + + if len(wrapper.ch) != 2 { + t.Fatalf("expected 2 buffered tokens before drain, got %d", len(wrapper.ch)) + } + + drained := drainChannel(wrapper.ch) + if drained != 2 { + t.Fatalf("expected 2 tokens to drain from superseded queue, got %d", drained) + } +} + +func TestListenQueueListenLoopDrainsOnSupersession(t *testing.T) { + wrapper := &listenQueue{ + ch: make(chan *pb.ListenResponse, 8), + done: make(chan struct{}), + } + + wrapper.ch <- &pb.ListenResponse{RouterEndpoint: "ep1", RouterToken: "tok1"} + wrapper.ch <- &pb.ListenResponse{RouterEndpoint: "ep2", RouterToken: "tok2"} + + wrapper.closeDone() + + delivered := make(chan *pb.ListenResponse, 8) + loopExited := make(chan struct{}) + + go func() { + defer close(loopExited) + for { + select { + case <-wrapper.done: + for { + select { + case msg := <-wrapper.ch: + delivered <- msg + default: + return + } + } + case msg := <-wrapper.ch: + delivered <- msg + } + } + }() + + select { + case <-loopExited: + case <-time.After(time.Second): + t.Fatal("listen loop did not exit after done was closed") + } + + close(delivered) + var count int + for range delivered { + count++ + } + if count != 2 { + t.Fatalf("expected 2 drained tokens from listen loop, got %d", count) } - return false }