diff --git a/nexus-broker/internal/service/connection_health_test.go b/nexus-broker/internal/service/connection_health_test.go index f39e599..ddaafe6 100644 --- a/nexus-broker/internal/service/connection_health_test.go +++ b/nexus-broker/internal/service/connection_health_test.go @@ -17,6 +17,20 @@ import ( "github.com/Prescott-Data/nexus-framework/nexus-broker/internal/service" ) +func runWorkerUntilSignal(t *testing.T, worker *service.ConnectionHealthWorker, done <-chan struct{}) { + t.Helper() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go worker.Start(ctx) + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for health worker signal") + } +} + // Add missing mock methods to MockConnectionRepository func (m *MockConnectionRepository) GetForHealthCheck(ctx context.Context, limit int) ([]*domain.ConnectionWithProvider, error) { args := m.Called(ctx, limit) @@ -126,15 +140,14 @@ func TestConnectionHealthWorker_OAuth2_Healthy(t *testing.T) { mockSvc.On("Refresh", mock.Anything, connID).Return(&service.RefreshResponse{}, nil).Once() // Should update health to healthy - mockRepo.On("UpdateHealthStatus", mock.Anything, connID, "healthy").Return(nil).Once() + done := make(chan struct{}) + mockRepo.On("UpdateHealthStatus", mock.Anything, connID, "healthy"). + Run(func(args mock.Arguments) { close(done) }). + Return(nil).Once() worker := service.NewConnectionHealthWorker(mockRepo, mockSvc, mockHealth, 10*time.Millisecond) - - ctx, cancel := context.WithCancel(context.Background()) - go worker.Start(ctx) - - time.Sleep(50 * time.Millisecond) // Give it time to run at least once - cancel() + + runWorkerUntilSignal(t, worker, done) mockRepo.AssertExpectations(t) mockSvc.AssertExpectations(t) @@ -169,15 +182,14 @@ func TestConnectionHealthWorker_OAuth2_Expired(t *testing.T) { mockRepo.On("UpdateStatus", mock.Anything, connID, "expired").Return(nil).Once() // Should update health to expired - mockRepo.On("UpdateHealthStatus", mock.Anything, connID, "expired").Return(nil).Once() + done := make(chan struct{}) + mockRepo.On("UpdateHealthStatus", mock.Anything, connID, "expired"). + Run(func(args mock.Arguments) { close(done) }). + Return(nil).Once() worker := service.NewConnectionHealthWorker(mockRepo, mockSvc, mockHealth, 10*time.Millisecond) - - ctx, cancel := context.WithCancel(context.Background()) - go worker.Start(ctx) - - time.Sleep(50 * time.Millisecond) - cancel() + + runWorkerUntilSignal(t, worker, done) mockRepo.AssertExpectations(t) mockSvc.AssertExpectations(t) @@ -213,15 +225,14 @@ func TestConnectionHealthWorker_OAuth2_ProviderDown_ShieldsExpiration(t *testing // Should NOT call UpdateStatus (no expiration) // Should update health to "unhealthy" instead of "expired" - mockRepo.On("UpdateHealthStatus", mock.Anything, connID, "unhealthy").Return(nil).Once() + done := make(chan struct{}) + mockRepo.On("UpdateHealthStatus", mock.Anything, connID, "unhealthy"). + Run(func(args mock.Arguments) { close(done) }). + Return(nil).Once() worker := service.NewConnectionHealthWorker(mockRepo, mockSvc, mockHealth, 10*time.Millisecond) - ctx, cancel := context.WithCancel(context.Background()) - go worker.Start(ctx) - - time.Sleep(50 * time.Millisecond) - cancel() + runWorkerUntilSignal(t, worker, done) mockRepo.AssertExpectations(t) mockSvc.AssertExpectations(t) @@ -267,15 +278,14 @@ func TestConnectionHealthWorker_APIKey_Expired(t *testing.T) { mockRepo.On("UpdateStatus", mock.Anything, connID, "expired").Return(nil).Once() // Should update health to expired - mockRepo.On("UpdateHealthStatus", mock.Anything, connID, "expired").Return(nil).Once() + done := make(chan struct{}) + mockRepo.On("UpdateHealthStatus", mock.Anything, connID, "expired"). + Run(func(args mock.Arguments) { close(done) }). + Return(nil).Once() worker := service.NewConnectionHealthWorker(mockRepo, mockSvc, mockHealth, 10*time.Millisecond) - - ctx, cancel := context.WithCancel(context.Background()) - go worker.Start(ctx) - - time.Sleep(50 * time.Millisecond) - cancel() + + runWorkerUntilSignal(t, worker, done) mockRepo.AssertExpectations(t) mockSvc.AssertExpectations(t) @@ -304,15 +314,14 @@ func TestConnectionHealthWorker_OAuth2_Upstream5xx_MarksUnhealthy(t *testing.T) // Should set health_status to "unhealthy", NOT "expired" // Should NOT call UpdateStatus — connection status stays "active" - mockRepo.On("UpdateHealthStatus", mock.Anything, connID, "unhealthy").Return(nil).Once() + done := make(chan struct{}) + mockRepo.On("UpdateHealthStatus", mock.Anything, connID, "unhealthy"). + Run(func(args mock.Arguments) { close(done) }). + Return(nil).Once() worker := service.NewConnectionHealthWorker(mockRepo, mockSvc, mockHealth, 10*time.Millisecond) - ctx, cancel := context.WithCancel(context.Background()) - go worker.Start(ctx) - - time.Sleep(50 * time.Millisecond) - cancel() + runWorkerUntilSignal(t, worker, done) mockRepo.AssertExpectations(t) mockSvc.AssertExpectations(t) @@ -340,15 +349,14 @@ func TestConnectionHealthWorker_OAuth2_403_MarksDegraded(t *testing.T) { mockSvc.On("Refresh", mock.Anything, connID).Return(&service.RefreshResponse{StatusCode: 403}, errors.New("forbidden")).Once() // Should set health_status to "degraded", NOT "expired" - mockRepo.On("UpdateHealthStatus", mock.Anything, connID, "degraded").Return(nil).Once() + done := make(chan struct{}) + mockRepo.On("UpdateHealthStatus", mock.Anything, connID, "degraded"). + Run(func(args mock.Arguments) { close(done) }). + Return(nil).Once() worker := service.NewConnectionHealthWorker(mockRepo, mockSvc, mockHealth, 10*time.Millisecond) - ctx, cancel := context.WithCancel(context.Background()) - go worker.Start(ctx) - - time.Sleep(50 * time.Millisecond) - cancel() + runWorkerUntilSignal(t, worker, done) mockRepo.AssertExpectations(t) mockSvc.AssertExpectations(t) @@ -376,15 +384,14 @@ func TestConnectionHealthWorker_OAuth2_NetworkError_MarksDegraded(t *testing.T) mockSvc.On("Refresh", mock.Anything, connID).Return((*service.RefreshResponse)(nil), errors.New("connection refused")).Once() // Should set health_status to "degraded" (we don't know if credential is valid) - mockRepo.On("UpdateHealthStatus", mock.Anything, connID, "degraded").Return(nil).Once() + done := make(chan struct{}) + mockRepo.On("UpdateHealthStatus", mock.Anything, connID, "degraded"). + Run(func(args mock.Arguments) { close(done) }). + Return(nil).Once() worker := service.NewConnectionHealthWorker(mockRepo, mockSvc, mockHealth, 10*time.Millisecond) - ctx, cancel := context.WithCancel(context.Background()) - go worker.Start(ctx) - - time.Sleep(50 * time.Millisecond) - cancel() + runWorkerUntilSignal(t, worker, done) mockRepo.AssertExpectations(t) mockSvc.AssertExpectations(t)