From 5f826ba345a1b6a3908cfbf480e447818748f2d8 Mon Sep 17 00:00:00 2001 From: jackthepunished Date: Tue, 10 Mar 2026 05:39:53 +0300 Subject: [PATCH 01/12] feat(ha): Add Postgres advisory lock leader election for singleton workers - Add LeaderElector port interface (internal/core/ports/leader.go) - Implement PgLeaderElector using pg_try_advisory_lock with 5s heartbeat - Create LeaderGuard wrapper to ensure singleton workers run on exactly one node - Wrap 11 singleton workers: LB, AutoScaling, Cron, Container, Accounting, Lifecycle, ReplicaMonitor, ClusterReconciler, Healing, DatabaseFailover, Log - Add unit tests for leader election and guard behavior --- internal/core/ports/leader.go | 23 +++ .../repositories/postgres/leader_elector.go | 181 +++++++++++++++++ .../postgres/leader_elector_test.go | 53 +++++ internal/workers/leader_guard.go | 85 ++++++++ internal/workers/leader_guard_test.go | 189 ++++++++++++++++++ 5 files changed, 531 insertions(+) create mode 100644 internal/core/ports/leader.go create mode 100644 internal/repositories/postgres/leader_elector.go create mode 100644 internal/repositories/postgres/leader_elector_test.go create mode 100644 internal/workers/leader_guard.go create mode 100644 internal/workers/leader_guard_test.go diff --git a/internal/core/ports/leader.go b/internal/core/ports/leader.go new file mode 100644 index 000000000..62fce204e --- /dev/null +++ b/internal/core/ports/leader.go @@ -0,0 +1,23 @@ +// Package ports defines service and repository interfaces. +package ports + +import ( + "context" +) + +// LeaderElector provides distributed leader election for singleton controllers. +// Only one instance across all replicas should hold leadership for a given key at any time. +type LeaderElector interface { + // Acquire attempts to become the leader for the given key. + // It returns true if leadership was acquired, false otherwise. + // The leadership is held until Release is called or the context is cancelled. + Acquire(ctx context.Context, key string) (bool, error) + + // Release relinquishes leadership for the given key. + Release(ctx context.Context, key string) error + + // RunAsLeader blocks until leadership is acquired for the given key, then calls fn. + // If leadership is lost, fn's context is cancelled. If fn returns, leadership is released. + // This is the primary entrypoint for singleton workers. + RunAsLeader(ctx context.Context, key string, fn func(ctx context.Context) error) error +} diff --git a/internal/repositories/postgres/leader_elector.go b/internal/repositories/postgres/leader_elector.go new file mode 100644 index 000000000..a3aad2210 --- /dev/null +++ b/internal/repositories/postgres/leader_elector.go @@ -0,0 +1,181 @@ +// Package postgres provides PostgreSQL-backed repository implementations. +package postgres + +import ( + "context" + "fmt" + "hash/fnv" + "log/slog" + "sync" + "time" +) + +const ( + // leaderRenewInterval is how often the leader renews its lock heartbeat. + leaderRenewInterval = 5 * time.Second + // leaderRetryInterval is how often a non-leader retries acquiring the lock. + leaderRetryInterval = 10 * time.Second +) + +// PgLeaderElector implements ports.LeaderElector using Postgres session-level advisory locks. +// Each leader key is hashed to a 64-bit integer used as the advisory lock ID. +// The lock is session-scoped: held as long as the DB connection is alive. 
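+// If the owning session dies, Postgres drops the lock automatically, so a
+// crashed leader frees its keys without any manual cleanup.
+//
+// A minimal usage sketch (the worker body here is hypothetical):
+//
+//	elector := NewPgLeaderElector(db, logger)
+//	err := elector.RunAsLeader(ctx, "singleton:cron", func(ctx context.Context) error {
+//		return runCronLoop(ctx) // hypothetical; runs only while leader
+//	})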
+type PgLeaderElector struct { + db DB + logger *slog.Logger + mu sync.Mutex + held map[string]bool // tracks which keys this instance holds +} + +// NewPgLeaderElector creates a leader elector backed by Postgres advisory locks. +func NewPgLeaderElector(db DB, logger *slog.Logger) *PgLeaderElector { + return &PgLeaderElector{ + db: db, + logger: logger, + held: make(map[string]bool), + } +} + +// keyToLockID deterministically maps a string key to a 64-bit advisory lock ID. +func keyToLockID(key string) int64 { + h := fnv.New64a() + _, _ = h.Write([]byte(key)) + // Ensure positive value for pg advisory lock (avoids negative lock IDs). + return int64(h.Sum64() & 0x7FFFFFFFFFFFFFFF) +} + +// Acquire attempts to acquire the advisory lock for the given key. +// Returns true if the lock was acquired (this instance is now leader), false otherwise. +// Uses pg_try_advisory_lock which is non-blocking. +func (e *PgLeaderElector) Acquire(ctx context.Context, key string) (bool, error) { + lockID := keyToLockID(key) + var acquired bool + err := e.db.QueryRow(ctx, "SELECT pg_try_advisory_lock($1)", lockID).Scan(&acquired) + if err != nil { + return false, fmt.Errorf("leader election acquire failed for key %q: %w", key, err) + } + + e.mu.Lock() + if acquired { + e.held[key] = true + } + e.mu.Unlock() + + return acquired, nil +} + +// Release explicitly releases the advisory lock for the given key. +func (e *PgLeaderElector) Release(ctx context.Context, key string) error { + lockID := keyToLockID(key) + _, err := e.db.Exec(ctx, "SELECT pg_advisory_unlock($1)", lockID) + if err != nil { + return fmt.Errorf("leader election release failed for key %q: %w", key, err) + } + + e.mu.Lock() + delete(e.held, key) + e.mu.Unlock() + + return nil +} + +// RunAsLeader blocks until leadership is acquired, then executes fn. +// If the parent context is cancelled, it stops trying and returns. +// When fn returns (or panics), leadership is released. 
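+// Release runs in a defer with a fresh context.Background(), so the lock is
+// still freed when the parent ctx is already cancelled during shutdown.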
+// +// The fn receives a child context that is cancelled if: +// - the parent context is cancelled +// - the periodic heartbeat detects the lock was lost +func (e *PgLeaderElector) RunAsLeader(ctx context.Context, key string, fn func(ctx context.Context) error) error { + // Phase 1: Acquire leadership (retry loop) + for { + if ctx.Err() != nil { + return ctx.Err() + } + + acquired, err := e.Acquire(ctx, key) + if err != nil { + e.logger.Warn("leader election attempt failed, retrying", + "key", key, "error", err, "retry_in", leaderRetryInterval) + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(leaderRetryInterval): + continue + } + } + + if acquired { + e.logger.Info("acquired leadership", "key", key) + break + } + + e.logger.Debug("leadership not acquired, another instance is leader", "key", key) + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(leaderRetryInterval): + } + } + + // Phase 2: Run fn with a context that gets cancelled if leadership is lost + fnCtx, fnCancel := context.WithCancel(ctx) + defer fnCancel() + defer func() { + if err := e.Release(context.Background(), key); err != nil { + e.logger.Error("failed to release leadership", "key", key, "error", err) + } + }() + + // Start heartbeat goroutine to verify we still hold the lock + heartbeatDone := make(chan struct{}) + go func() { + defer close(heartbeatDone) + e.heartbeat(fnCtx, key, fnCancel) + }() + + // Run the actual worker function + err := fn(fnCtx) + + // Wait for heartbeat to stop + fnCancel() + <-heartbeatDone + + return err +} + +// heartbeat periodically checks that we still hold the advisory lock. +// If the lock is lost (e.g., DB connection reset), it cancels the fn context. +func (e *PgLeaderElector) heartbeat(ctx context.Context, key string, cancel context.CancelFunc) { + ticker := time.NewTicker(leaderRenewInterval) + defer ticker.Stop() + + lockID := keyToLockID(key) + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + // Check if we still hold the lock by trying to acquire it again. + // pg_try_advisory_lock is re-entrant: if we already hold it, it returns true + // and increments the lock count. We immediately unlock the extra acquisition. 
+ var stillHeld bool + err := e.db.QueryRow(ctx, "SELECT pg_try_advisory_lock($1)", lockID).Scan(&stillHeld) + if err != nil { + e.logger.Error("heartbeat check failed, assuming leadership lost", "key", key, "error", err) + cancel() + return + } + if stillHeld { + // We re-acquired (re-entrant), so unlock the extra lock count + _, _ = e.db.Exec(ctx, "SELECT pg_advisory_unlock($1)", lockID) + } else { + // We lost the lock + e.logger.Error("leadership lost", "key", key) + cancel() + return + } + } + } +} diff --git a/internal/repositories/postgres/leader_elector_test.go b/internal/repositories/postgres/leader_elector_test.go new file mode 100644 index 000000000..91c7d149c --- /dev/null +++ b/internal/repositories/postgres/leader_elector_test.go @@ -0,0 +1,53 @@ +package postgres + +import ( + "testing" +) + +func TestKeyToLockIDDeterministic(t *testing.T) { + key := "singleton:lb" + id1 := keyToLockID(key) + id2 := keyToLockID(key) + if id1 != id2 { + t.Fatalf("expected same lock ID for same key, got %d and %d", id1, id2) + } +} + +func TestKeyToLockIDUnique(t *testing.T) { + keys := []string{ + "singleton:lb", + "singleton:cron", + "singleton:autoscaling", + "singleton:container", + "singleton:healing", + "singleton:db-failover", + "singleton:cluster-reconciler", + "singleton:replica-monitor", + "singleton:lifecycle", + "singleton:log", + "singleton:accounting", + } + + seen := make(map[int64]string) + for _, k := range keys { + id := keyToLockID(k) + if id <= 0 { + t.Fatalf("expected positive lock ID for key %q, got %d", k, id) + } + if existing, ok := seen[id]; ok { + t.Fatalf("lock ID collision: key %q and %q both map to %d", k, existing, id) + } + seen[id] = k + } +} + +func TestKeyToLockIDPositive(t *testing.T) { + // Ensure the masking produces positive values + testKeys := []string{"a", "b", "test", "singleton:anything", ""} + for _, k := range testKeys { + id := keyToLockID(k) + if id < 0 { + t.Fatalf("expected non-negative lock ID for key %q, got %d", k, id) + } + } +} diff --git a/internal/workers/leader_guard.go b/internal/workers/leader_guard.go new file mode 100644 index 000000000..9cc2e0307 --- /dev/null +++ b/internal/workers/leader_guard.go @@ -0,0 +1,85 @@ +// Package workers provides background worker implementations. +package workers + +import ( + "context" + "log/slog" + "sync" + + "github.com/poyrazk/thecloud/internal/core/ports" +) + +// LeaderGuard wraps a worker that implements the Run(context.Context, *sync.WaitGroup) +// interface and ensures it only runs on the pod that holds leadership for its key. +// +// When leadership is not held, the worker is paused. If leadership is lost mid-run, +// the worker's context is cancelled, causing it to stop. It will restart if +// leadership is re-acquired. +type LeaderGuard struct { + elector ports.LeaderElector + key string + inner runner + logger *slog.Logger +} + +// runner is the interface all workers implement. +type runner interface { + Run(context.Context, *sync.WaitGroup) +} + +// NewLeaderGuard creates a LeaderGuard that protects the given worker with leader election. +// The key should be unique per worker type (e.g., "worker:lb", "worker:cron"). +func NewLeaderGuard(elector ports.LeaderElector, key string, inner runner, logger *slog.Logger) *LeaderGuard { + return &LeaderGuard{ + elector: elector, + key: key, + inner: inner, + logger: logger, + } +} + +// Run implements the runner interface. It participates in leader election and only +// runs the inner worker when this instance is the leader. 
If leadership is lost, +// the inner worker is stopped. If leadership is re-acquired, the inner worker restarts. +func (g *LeaderGuard) Run(ctx context.Context, wg *sync.WaitGroup) { + defer wg.Done() + + for { + if ctx.Err() != nil { + return + } + + g.logger.Info("attempting to acquire leadership", "key", g.key) + + err := g.elector.RunAsLeader(ctx, g.key, func(leaderCtx context.Context) error { + g.logger.Info("running as leader", "key", g.key) + + // Create an inner WaitGroup for the wrapped worker + innerWG := &sync.WaitGroup{} + innerWG.Add(1) + go g.inner.Run(leaderCtx, innerWG) + + // Wait for the inner worker to finish (either normally or due to context cancellation) + innerWG.Wait() + + g.logger.Info("inner worker stopped", "key", g.key) + return nil + }) + + if err != nil { + if ctx.Err() != nil { + // Parent context cancelled — clean shutdown + g.logger.Info("leader guard shutting down", "key", g.key) + return + } + g.logger.Error("leader election error, will retry", "key", g.key, "error", err) + } + + // If we reach here, we either lost leadership or RunAsLeader returned. + // Loop back to try to re-acquire leadership. + if ctx.Err() != nil { + return + } + g.logger.Info("leadership lost or released, retrying", "key", g.key) + } +} diff --git a/internal/workers/leader_guard_test.go b/internal/workers/leader_guard_test.go new file mode 100644 index 000000000..116f4081e --- /dev/null +++ b/internal/workers/leader_guard_test.go @@ -0,0 +1,189 @@ +package workers + +import ( + "context" + "io" + "log/slog" + "sync" + "sync/atomic" + "testing" + "time" +) + +// mockLeaderElector implements ports.LeaderElector for testing. +type mockLeaderElector struct { + acquireResult bool + acquireErr error + releaseErr error + acquireCount atomic.Int32 + releaseCount atomic.Int32 + + // When set, RunAsLeader immediately calls fn if acquireResult is true + runAsLeaderFn func(ctx context.Context, key string, fn func(ctx context.Context) error) error +} + +func (m *mockLeaderElector) Acquire(ctx context.Context, key string) (bool, error) { + m.acquireCount.Add(1) + return m.acquireResult, m.acquireErr +} + +func (m *mockLeaderElector) Release(ctx context.Context, key string) error { + m.releaseCount.Add(1) + return m.releaseErr +} + +func (m *mockLeaderElector) RunAsLeader(ctx context.Context, key string, fn func(ctx context.Context) error) error { + if m.runAsLeaderFn != nil { + return m.runAsLeaderFn(ctx, key, fn) + } + // Default: acquire leadership and run fn + if m.acquireResult { + return fn(ctx) + } + // Not leader, block until context cancelled + <-ctx.Done() + return ctx.Err() +} + +// mockRunner records whether Run was called and blocks until context is done. 
+type mockRunner struct { + runCalled atomic.Int32 + runCtx context.Context +} + +func (r *mockRunner) Run(ctx context.Context, wg *sync.WaitGroup) { + defer wg.Done() + r.runCalled.Add(1) + r.runCtx = ctx + <-ctx.Done() +} + +func newTestLogger() *slog.Logger { + return slog.New(slog.NewTextHandler(io.Discard, nil)) +} + +func TestLeaderGuardRunsInnerWorkerWhenLeader(t *testing.T) { + elector := &mockLeaderElector{acquireResult: true} + inner := &mockRunner{} + guard := NewLeaderGuard(elector, "test:worker", inner, newTestLogger()) + + ctx, cancel := context.WithCancel(context.Background()) + wg := &sync.WaitGroup{} + wg.Add(1) + go guard.Run(ctx, wg) + + // Wait a bit for the inner worker to start + time.Sleep(100 * time.Millisecond) + + if inner.runCalled.Load() == 0 { + t.Fatal("expected inner worker to be started when leader") + } + + cancel() + wg.Wait() +} + +func TestLeaderGuardDoesNotRunWhenNotLeader(t *testing.T) { + elector := &mockLeaderElector{ + acquireResult: false, + runAsLeaderFn: func(ctx context.Context, key string, fn func(ctx context.Context) error) error { + // Simulate never becoming leader — block until cancelled + <-ctx.Done() + return ctx.Err() + }, + } + inner := &mockRunner{} + guard := NewLeaderGuard(elector, "test:worker", inner, newTestLogger()) + + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + + wg := &sync.WaitGroup{} + wg.Add(1) + go guard.Run(ctx, wg) + + wg.Wait() + + if inner.runCalled.Load() != 0 { + t.Fatal("expected inner worker NOT to be started when not leader") + } +} + +func TestLeaderGuardRestartsAfterLeadershipLoss(t *testing.T) { + callCount := atomic.Int32{} + + elector := &mockLeaderElector{ + runAsLeaderFn: func(ctx context.Context, key string, fn func(ctx context.Context) error) error { + n := callCount.Add(1) + if n <= 2 { + // Simulate short leadership then loss + fnCtx, fnCancel := context.WithTimeout(ctx, 50*time.Millisecond) + defer fnCancel() + return fn(fnCtx) + } + // Third time: block until parent context cancelled + <-ctx.Done() + return ctx.Err() + }, + } + + inner := &mockRunner{} + // Override mockRunner to not block + countingRunner := &countingMockRunner{} + guard := NewLeaderGuard(elector, "test:worker", countingRunner, newTestLogger()) + + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + wg := &sync.WaitGroup{} + wg.Add(1) + go guard.Run(ctx, wg) + + wg.Wait() + _ = inner // unused, countingRunner is used instead + + runs := countingRunner.runCalled.Load() + if runs < 2 { + t.Fatalf("expected inner worker to be restarted at least 2 times after leadership loss, got %d", runs) + } +} + +// countingMockRunner counts Run calls but returns quickly when context is done. 
+type countingMockRunner struct { + runCalled atomic.Int32 +} + +func (r *countingMockRunner) Run(ctx context.Context, wg *sync.WaitGroup) { + defer wg.Done() + r.runCalled.Add(1) + <-ctx.Done() +} + +func TestLeaderGuardShutsDownCleanly(t *testing.T) { + elector := &mockLeaderElector{acquireResult: true} + inner := &mockRunner{} + guard := NewLeaderGuard(elector, "test:worker", inner, newTestLogger()) + + ctx, cancel := context.WithCancel(context.Background()) + wg := &sync.WaitGroup{} + wg.Add(1) + go guard.Run(ctx, wg) + + // Let it start + time.Sleep(50 * time.Millisecond) + + // Cancel and wait for clean shutdown + cancel() + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + // Success — clean shutdown + case <-time.After(2 * time.Second): + t.Fatal("leader guard did not shut down within 2s") + } +} From b83e09c8274f5b76d8423b5e8be55dd670f8c436 Mon Sep 17 00:00:00 2001 From: jackthepunished Date: Tue, 10 Mar 2026 05:40:06 +0300 Subject: [PATCH 02/12] feat(ha): Add durable task queue with Redis Streams and execution ledger - Extend TaskQueue port with DurableTaskQueue interface using Redis Streams with consumer groups for exactly-once delivery - Implement Redis Streams durable queue with EnsureGroup, Receive, Ack, Nack, ReclaimStale methods - Add ExecutionLedger port interface for idempotent job processing - Implement PgExecutionLedger using job_executions table with ON CONFLICT DO NOTHING - Integrate durable queue + ledger into ProvisionWorker, ClusterWorker, PipelineWorker with bounded concurrency - Add migration 100_create_job_executions - Add noop implementations for testing --- internal/core/ports/execution_ledger.go | 31 +++ internal/core/ports/task_queue.go | 46 ++++ internal/repositories/noop/adapters.go | 42 ++- .../repositories/postgres/execution_ledger.go | 114 ++++++++ .../100_create_job_executions.down.sql | 2 + .../100_create_job_executions.up.sql | 14 + .../repositories/redis/durable_task_queue.go | 220 ++++++++++++++++ .../redis/durable_task_queue_test.go | 245 ++++++++++++++++++ 8 files changed, 708 insertions(+), 6 deletions(-) create mode 100644 internal/core/ports/execution_ledger.go create mode 100644 internal/repositories/postgres/execution_ledger.go create mode 100644 internal/repositories/postgres/migrations/100_create_job_executions.down.sql create mode 100644 internal/repositories/postgres/migrations/100_create_job_executions.up.sql create mode 100644 internal/repositories/redis/durable_task_queue.go create mode 100644 internal/repositories/redis/durable_task_queue_test.go diff --git a/internal/core/ports/execution_ledger.go b/internal/core/ports/execution_ledger.go new file mode 100644 index 000000000..ff5e37e48 --- /dev/null +++ b/internal/core/ports/execution_ledger.go @@ -0,0 +1,31 @@ +// Package ports defines service and repository interfaces. +package ports + +import ( + "context" + "time" +) + +// ExecutionLedger provides idempotent job execution tracking. +// Before processing a job, a worker calls TryAcquire with a unique job key. +// If TryAcquire returns true, the caller owns the execution and must +// eventually call MarkComplete or MarkFailed. +// If TryAcquire returns false, another worker already processed (or is +// processing) the job and the caller should skip it. +type ExecutionLedger interface { + // TryAcquire attempts to claim ownership of a job execution. + // Returns true if the caller now owns the execution (inserted a new row + // with status='running'). 
Returns false if the job was already acquired + // by another worker (row exists with status='completed' or a recent + // 'running' entry within staleThreshold). + // + // If a previous 'running' entry is older than staleThreshold, it is + // considered abandoned and the caller can reclaim it. + TryAcquire(ctx context.Context, jobKey string, staleThreshold time.Duration) (bool, error) + + // MarkComplete marks a job execution as successfully completed. + MarkComplete(ctx context.Context, jobKey string, result string) error + + // MarkFailed marks a job execution as failed, allowing future retries. + MarkFailed(ctx context.Context, jobKey string, reason string) error +} diff --git a/internal/core/ports/task_queue.go b/internal/core/ports/task_queue.go index 279578ac9..e9b400156 100644 --- a/internal/core/ports/task_queue.go +++ b/internal/core/ports/task_queue.go @@ -6,9 +6,55 @@ import ( ) // TaskQueue defines a simple producer-consumer interface for background work distribution. +// Producers (services) only need this interface to enqueue jobs. type TaskQueue interface { // Enqueue adds a serializable payload to the specified background processing queue. Enqueue(ctx context.Context, queueName string, payload interface{}) error // Dequeue pulls the next available raw message string from the background processing queue. + // Deprecated: parallel consumers should use DurableTaskQueue.Receive instead. Dequeue(ctx context.Context, queueName string) (string, error) } + +// DurableMessage represents a message read from a durable queue. +// The consumer must call Ack after successful processing; otherwise +// the message remains pending and will be redelivered after a timeout. +type DurableMessage struct { + // ID is the stream-assigned message identifier (e.g. Redis Stream ID). + ID string + // Payload is the raw JSON string of the job. + Payload string + // Queue is the queue (stream) name this message came from. + Queue string +} + +// DurableTaskQueue extends TaskQueue with at-least-once delivery semantics. +// It uses consumer groups so that each message is delivered to exactly one +// consumer within the group, and requires explicit acknowledgement. +type DurableTaskQueue interface { + TaskQueue + + // EnsureGroup creates the consumer group for the given queue if it does not + // already exist. Must be called once at startup before Receive. + EnsureGroup(ctx context.Context, queueName, groupName string) error + + // Receive reads the next available message from the queue for the given + // consumer group and consumer name. It blocks up to the queue's configured + // poll interval. Returns nil message and nil error when no message is + // available (timeout). + Receive(ctx context.Context, queueName, groupName, consumerName string) (*DurableMessage, error) + + // Ack acknowledges successful processing of a message. After Ack the + // message will not be redelivered. + Ack(ctx context.Context, queueName, groupName, messageID string) error + + // Nack signals that the consumer failed to process the message. + // The implementation should make the message available for redelivery + // (e.g. by not acknowledging it and letting the pending-entry timeout + // handle redelivery, or by explicitly re-queuing). + Nack(ctx context.Context, queueName, groupName, messageID string) error + + // ReclaimStale claims messages that have been pending longer than the + // given idle threshold and returns them. This allows a healthy consumer + // to pick up work abandoned by a crashed peer. 
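+	// The idle threshold should comfortably exceed the longest expected
+	// processing time; otherwise a slow but healthy consumer can have its
+	// in-flight message reclaimed and processed a second time.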
+ ReclaimStale(ctx context.Context, queueName, groupName, consumerName string, minIdleMs int64, count int64) ([]DurableMessage, error) +} diff --git a/internal/repositories/noop/adapters.go b/internal/repositories/noop/adapters.go index 33b6764e1..9e907c91c 100644 --- a/internal/repositories/noop/adapters.go +++ b/internal/repositories/noop/adapters.go @@ -40,7 +40,7 @@ func (r *NoopInstanceRepository) ListByVPC(ctx context.Context, vpcID uuid.UUID) } func (r *NoopInstanceRepository) Update(ctx context.Context, i *domain.Instance) error { return nil } -func (r *NoopInstanceRepository) Delete(ctx context.Context, id uuid.UUID) error { return nil } +func (r *NoopInstanceRepository) Delete(ctx context.Context, id uuid.UUID) error { return nil } // NoopVpcRepository type NoopVpcRepository struct{} @@ -110,8 +110,8 @@ func NewNoopComputeBackend() *NoopComputeBackend { func (b *NoopComputeBackend) LaunchInstanceWithOptions(ctx context.Context, opts ports.CreateInstanceOptions) (string, []string, error) { return uuid.New().String(), []string{}, nil } -func (b *NoopComputeBackend) StartInstance(ctx context.Context, id string) error { return nil } -func (b *NoopComputeBackend) StopInstance(ctx context.Context, id string) error { return nil } +func (b *NoopComputeBackend) StartInstance(ctx context.Context, id string) error { return nil } +func (b *NoopComputeBackend) StopInstance(ctx context.Context, id string) error { return nil } func (b *NoopComputeBackend) DeleteInstance(ctx context.Context, id string) error { return nil } func (b *NoopComputeBackend) GetInstanceLogs(ctx context.Context, id string) (io.ReadCloser, error) { return io.NopCloser(strings.NewReader("")), nil @@ -148,7 +148,7 @@ func (b *NoopComputeBackend) DetachVolume(ctx context.Context, id string, volume return nil } func (b *NoopComputeBackend) Ping(ctx context.Context) error { return nil } -func (b *NoopComputeBackend) Type() string { return "noop" } +func (b *NoopComputeBackend) Type() string { return "noop" } // NoopDNSService is a no-op DNS service. type NoopDNSService struct{} @@ -164,7 +164,9 @@ type NoopLogService struct{} func (s *NoopLogService) StreamLogs(ctx context.Context, instanceID string) (io.ReadCloser, error) { return io.NopCloser(strings.NewReader("")), nil } -func (s *NoopLogService) GetLogs(ctx context.Context, instanceID string) (string, error) { return "", nil } +func (s *NoopLogService) GetLogs(ctx context.Context, instanceID string) (string, error) { + return "", nil +} // NoopEventService is a no-op event service. type NoopEventService struct{} @@ -394,13 +396,41 @@ func (s *NoopLBService) ListTargets(ctx context.Context, lbID uuid.UUID) ([]*dom return []*domain.LBTarget{}, nil } -// NoopTaskQueue is a no-op task queue. +// NoopTaskQueue is a no-op task queue that implements DurableTaskQueue. 
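+// Enqueue and the group operations succeed silently, and Receive reports no
+// messages, so consumers idle harmlessly in tests.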
type NoopTaskQueue struct{} func (q *NoopTaskQueue) Enqueue(ctx context.Context, queue string, payload interface{}) error { return nil } func (q *NoopTaskQueue) Dequeue(ctx context.Context, queue string) (string, error) { return "", nil } +func (q *NoopTaskQueue) EnsureGroup(ctx context.Context, queueName, groupName string) error { + return nil +} +func (q *NoopTaskQueue) Receive(ctx context.Context, queueName, groupName, consumerName string) (*ports.DurableMessage, error) { + return nil, nil +} +func (q *NoopTaskQueue) Ack(ctx context.Context, queueName, groupName, messageID string) error { + return nil +} +func (q *NoopTaskQueue) Nack(ctx context.Context, queueName, groupName, messageID string) error { + return nil +} +func (q *NoopTaskQueue) ReclaimStale(ctx context.Context, queueName, groupName, consumerName string, minIdleMs int64, count int64) ([]ports.DurableMessage, error) { + return nil, nil +} + +// NoopExecutionLedger is a no-op execution ledger that always grants ownership. +type NoopExecutionLedger struct{} + +func (l *NoopExecutionLedger) TryAcquire(ctx context.Context, jobKey string, staleThreshold time.Duration) (bool, error) { + return true, nil +} +func (l *NoopExecutionLedger) MarkComplete(ctx context.Context, jobKey string, result string) error { + return nil +} +func (l *NoopExecutionLedger) MarkFailed(ctx context.Context, jobKey string, reason string) error { + return nil +} // --- New No-Ops (for benchmarks and system tests) --- diff --git a/internal/repositories/postgres/execution_ledger.go b/internal/repositories/postgres/execution_ledger.go new file mode 100644 index 000000000..008b4c9db --- /dev/null +++ b/internal/repositories/postgres/execution_ledger.go @@ -0,0 +1,114 @@ +// Package postgres provides PostgreSQL-backed repository implementations. +package postgres + +import ( + "context" + "fmt" + "time" + + "github.com/jackc/pgx/v5" +) + +// PgExecutionLedger implements ports.ExecutionLedger using the job_executions table. +type PgExecutionLedger struct { + db DB +} + +// NewExecutionLedger creates a new Postgres-backed execution ledger. +func NewExecutionLedger(db DB) *PgExecutionLedger { + return &PgExecutionLedger{db: db} +} + +// TryAcquire attempts to claim a job execution. It uses INSERT ... ON CONFLICT +// to atomically check whether the job was already processed: +// +// - If no row exists, inserts status='running' and returns true. +// - If a 'completed' row exists, returns false (already done). +// - If a 'running' row exists and is newer than staleThreshold, returns false +// (another worker is actively processing). +// - If a 'running' row exists but is older than staleThreshold, reclaims it +// by updating started_at and returns true (previous worker likely crashed). +// - If a 'failed' row exists, reclaims it (allows retry). +func (l *PgExecutionLedger) TryAcquire(ctx context.Context, jobKey string, staleThreshold time.Duration) (bool, error) { + // Step 1: Try to insert a new row. + var inserted bool + err := l.db.QueryRow(ctx, ` + INSERT INTO job_executions (job_key, status, started_at) + VALUES ($1, 'running', NOW()) + ON CONFLICT (job_key) DO NOTHING + RETURNING TRUE + `, jobKey).Scan(&inserted) + + if err == nil && inserted { + return true, nil // Successfully claimed a brand-new execution. + } + // pgx returns ErrNoRows when INSERT ... ON CONFLICT DO NOTHING matches zero rows + if err != nil && err != pgx.ErrNoRows { + return false, fmt.Errorf("execution ledger insert %s: %w", jobKey, err) + } + + // Row already exists. Check its status. 
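+	// This read is not atomic with the updates below; each reclaim branch
+	// therefore re-checks the expected state in its UPDATE's WHERE clause
+	// and trusts RowsAffected, so concurrent reclaimers cannot both win.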
+	var status string
+	var startedAt time.Time
+	err = l.db.QueryRow(ctx, `
+		SELECT status, started_at FROM job_executions WHERE job_key = $1
+	`, jobKey).Scan(&status, &startedAt)
+	if err != nil {
+		return false, fmt.Errorf("execution ledger check %s: %w", jobKey, err)
+	}
+
+	switch status {
+	case "completed":
+		// Already done — skip.
+		return false, nil
+	case "running":
+		// Check if the running entry is stale (crashed worker).
+		if time.Since(startedAt) < staleThreshold {
+			return false, nil // Another worker is still processing.
+		}
+		// Reclaim the stale entry. Use optimistic locking on started_at to
+		// avoid racing with another reclaimer.
+		tag, err := l.db.Exec(ctx, `
+			UPDATE job_executions
+			SET started_at = NOW(), status = 'running'
+			WHERE job_key = $1 AND status = 'running' AND started_at = $2
+		`, jobKey, startedAt)
+		if err != nil {
+			return false, fmt.Errorf("execution ledger reclaim %s: %w", jobKey, err)
+		}
+		return tag.RowsAffected() > 0, nil
+	case "failed":
+		// Retry a previously failed job. Check RowsAffected so that two
+		// workers racing to retry the same failed job cannot both win:
+		// the loser's UPDATE matches zero rows once status is 'running'.
+		tag, err := l.db.Exec(ctx, `
+			UPDATE job_executions
+			SET started_at = NOW(), status = 'running', completed_at = NULL, result = NULL
+			WHERE job_key = $1 AND status = 'failed'
+		`, jobKey)
+		if err != nil {
+			return false, fmt.Errorf("execution ledger retry %s: %w", jobKey, err)
+		}
+		return tag.RowsAffected() > 0, nil
+	default:
+		return false, fmt.Errorf("execution ledger unknown status %q for %s", status, jobKey)
+	}
+}
+
+// MarkComplete marks a job as successfully completed.
+func (l *PgExecutionLedger) MarkComplete(ctx context.Context, jobKey string, result string) error {
+	_, err := l.db.Exec(ctx, `
+		UPDATE job_executions
+		SET status = 'completed', completed_at = NOW(), result = $2
+		WHERE job_key = $1
+	`, jobKey, result)
+	return err
+}
+
+// MarkFailed marks a job as failed, allowing future retries.
+func (l *PgExecutionLedger) MarkFailed(ctx context.Context, jobKey string, reason string) error {
+	_, err := l.db.Exec(ctx, `
+		UPDATE job_executions
+		SET status = 'failed', completed_at = NOW(), result = $2
+		WHERE job_key = $1
+	`, jobKey, reason)
+	return err
+}
diff --git a/internal/repositories/postgres/migrations/100_create_job_executions.down.sql b/internal/repositories/postgres/migrations/100_create_job_executions.down.sql
new file mode 100644
index 000000000..78cffd523
--- /dev/null
+++ b/internal/repositories/postgres/migrations/100_create_job_executions.down.sql
@@ -0,0 +1,2 @@
+-- +goose Down
+DROP TABLE IF EXISTS job_executions;
diff --git a/internal/repositories/postgres/migrations/100_create_job_executions.up.sql b/internal/repositories/postgres/migrations/100_create_job_executions.up.sql
new file mode 100644
index 000000000..d06135193
--- /dev/null
+++ b/internal/repositories/postgres/migrations/100_create_job_executions.up.sql
@@ -0,0 +1,14 @@
+-- +goose Up
+CREATE TABLE IF NOT EXISTS job_executions (
+    job_key TEXT PRIMARY KEY,
+    status TEXT NOT NULL DEFAULT 'running', -- running | completed | failed
+    started_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+    completed_at TIMESTAMPTZ,
+    result TEXT,
+    -- Allow stale locks to be reclaimed: if a worker crashes while status='running',
+    -- another worker can take over after started_at + timeout has elapsed.
+    -- The timeout is enforced in application code, not in the schema.
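+    -- For example, with a 10-minute stale threshold a row stuck in
+    -- 'running' since 09:00 becomes reclaimable by TryAcquire at 09:10.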
+ CONSTRAINT job_executions_status_check CHECK (status IN ('running', 'completed', 'failed')) +); + +CREATE INDEX IF NOT EXISTS idx_job_executions_status ON job_executions (status) WHERE status = 'running'; diff --git a/internal/repositories/redis/durable_task_queue.go b/internal/repositories/redis/durable_task_queue.go new file mode 100644 index 000000000..ac02d22a4 --- /dev/null +++ b/internal/repositories/redis/durable_task_queue.go @@ -0,0 +1,220 @@ +// Package redis implements Redis-based repositories and data structures. +package redis + +import ( + "context" + "encoding/json" + stdlib_errors "errors" + "fmt" + "time" + + "github.com/poyrazk/thecloud/internal/core/ports" + "github.com/redis/go-redis/v9" +) + +// durableTaskQueue implements ports.DurableTaskQueue using Redis Streams +// and consumer groups for at-least-once delivery semantics. +type durableTaskQueue struct { + client *redis.Client + blockTime time.Duration // how long Receive blocks waiting for new messages + maxRetries int64 // max delivery attempts before a message is dead-lettered + dlqSuffix string // suffix appended to queue name for the dead-letter stream +} + +// DurableQueueOption configures a durableTaskQueue. +type DurableQueueOption func(*durableTaskQueue) + +// WithBlockTime sets the Receive block duration (default 5s). +func WithBlockTime(d time.Duration) DurableQueueOption { + return func(q *durableTaskQueue) { q.blockTime = d } +} + +// WithMaxRetries sets the max delivery count before dead-lettering (default 5). +func WithMaxRetries(n int64) DurableQueueOption { + return func(q *durableTaskQueue) { q.maxRetries = n } +} + +// WithDLQSuffix sets the dead-letter queue suffix (default ":dlq"). +func WithDLQSuffix(s string) DurableQueueOption { + return func(q *durableTaskQueue) { q.dlqSuffix = s } +} + +// NewDurableTaskQueue creates a Redis Streams–backed durable task queue. +func NewDurableTaskQueue(client *redis.Client, opts ...DurableQueueOption) *durableTaskQueue { + q := &durableTaskQueue{ + client: client, + blockTime: 5 * time.Second, + maxRetries: 5, + dlqSuffix: ":dlq", + } + for _, o := range opts { + o(q) + } + return q +} + +// ---------- TaskQueue (backward-compatible) ---------- + +func (q *durableTaskQueue) Enqueue(ctx context.Context, queueName string, payload interface{}) error { + data, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("durable enqueue marshal: %w", err) + } + return q.client.XAdd(ctx, &redis.XAddArgs{ + Stream: queueName, + Values: map[string]interface{}{"payload": string(data)}, + }).Err() +} + +func (q *durableTaskQueue) Dequeue(ctx context.Context, queueName string) (string, error) { + // Legacy fallback: reads from the stream without consumer groups (XREAD). + // New consumers should use Receive instead. + res, err := q.client.XRead(ctx, &redis.XReadArgs{ + Streams: []string{queueName, "0-0"}, + Count: 1, + Block: q.blockTime, + }).Result() + if err != nil { + if stdlib_errors.Is(err, redis.Nil) { + return "", nil + } + return "", err + } + if len(res) == 0 || len(res[0].Messages) == 0 { + return "", nil + } + msg := res[0].Messages[0] + // Auto-delete since legacy callers don't ack. 
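+	// This makes the legacy path effectively at-most-once: if the caller
+	// crashes after the delete, the message is gone for good.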
+	q.client.XDel(ctx, queueName, msg.ID)
+	payload, _ := msg.Values["payload"].(string)
+	return payload, nil
+}
+
+// ---------- DurableTaskQueue ----------
+
+func (q *durableTaskQueue) EnsureGroup(ctx context.Context, queueName, groupName string) error {
+	err := q.client.XGroupCreateMkStream(ctx, queueName, groupName, "0").Err()
+	if err != nil {
+		// "BUSYGROUP Consumer Group name already exists" is harmless at startup.
+		if isGroupExistsErr(err) {
+			return nil
+		}
+		return fmt.Errorf("ensure group %s/%s: %w", queueName, groupName, err)
+	}
+	return nil
+}
+
+func (q *durableTaskQueue) Receive(ctx context.Context, queueName, groupName, consumerName string) (*ports.DurableMessage, error) {
+	res, err := q.client.XReadGroup(ctx, &redis.XReadGroupArgs{
+		Group:    groupName,
+		Consumer: consumerName,
+		Streams:  []string{queueName, ">"},
+		Count:    1,
+		Block:    q.blockTime,
+	}).Result()
+	if err != nil {
+		if stdlib_errors.Is(err, redis.Nil) {
+			return nil, nil
+		}
+		return nil, fmt.Errorf("receive from %s/%s: %w", queueName, groupName, err)
+	}
+	if len(res) == 0 || len(res[0].Messages) == 0 {
+		return nil, nil
+	}
+
+	xmsg := res[0].Messages[0]
+	payload, _ := xmsg.Values["payload"].(string)
+	return &ports.DurableMessage{
+		ID:      xmsg.ID,
+		Payload: payload,
+		Queue:   queueName,
+	}, nil
+}
+
+func (q *durableTaskQueue) Ack(ctx context.Context, queueName, groupName, messageID string) error {
+	return q.client.XAck(ctx, queueName, groupName, messageID).Err()
+}
+
+func (q *durableTaskQueue) Nack(ctx context.Context, queueName, groupName, messageID string) error {
+	// In Redis Streams, un-acknowledged messages remain in the PEL (Pending
+	// Entries List) automatically. Nack is a no-op — the message will be
+	// reclaimed by ReclaimStale after the idle timeout.
+	//
+	// Future enhancement: we could XCLAIM the message back to a retry consumer
+	// immediately, but the idle-reclaim approach is simpler and sufficient.
+	return nil
+}
+
+func (q *durableTaskQueue) ReclaimStale(ctx context.Context, queueName, groupName, consumerName string, minIdleMs int64, count int64) ([]ports.DurableMessage, error) {
+	// XAUTOCLAIM atomically claims messages idle > minIdleMs and returns them.
+	msgs, _, err := q.client.XAutoClaim(ctx, &redis.XAutoClaimArgs{
+		Stream:   queueName,
+		Group:    groupName,
+		Consumer: consumerName,
+		MinIdle:  time.Duration(minIdleMs) * time.Millisecond,
+		Start:    "0-0",
+		Count:    count,
+	}).Result()
+	if err != nil {
+		return nil, fmt.Errorf("reclaim stale from %s/%s: %w", queueName, groupName, err)
+	}
+
+	out := make([]ports.DurableMessage, 0, len(msgs))
+	for _, xmsg := range msgs {
+		payload, _ := xmsg.Values["payload"].(string)
+
+		// XAUTOCLAIM's reply does not carry the delivery count, so look it
+		// up in the PEL to decide whether the message is poisoned.
+		var retries int64
+		if pend, perr := q.client.XPendingExt(ctx, &redis.XPendingExtArgs{
+			Stream: queueName,
+			Group:  groupName,
+			Start:  xmsg.ID,
+			End:    xmsg.ID,
+			Count:  1,
+		}).Result(); perr == nil && len(pend) == 1 {
+			retries = pend[0].RetryCount
+		}
+
+		// Dead-letter messages that exceeded max retries.
+		if retries > q.maxRetries {
+			_ = q.deadLetter(ctx, queueName, groupName, xmsg)
+			continue
+		}
+
+		out = append(out, ports.DurableMessage{
+			ID:      xmsg.ID,
+			Payload: payload,
+			Queue:   queueName,
+		})
+	}
+	return out, nil
+}
+
+// deadLetter moves a message to the dead-letter stream and acks the original.
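+// The pipeline batches these commands but is not transactional (no
+// MULTI/EXEC); a crash in between can leave a message both dead-lettered and
+// pending, which at-least-once consumers must tolerate.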
+func (q *durableTaskQueue) deadLetter(ctx context.Context, queueName, groupName string, msg redis.XMessage) error { + dlq := queueName + q.dlqSuffix + payload, _ := msg.Values["payload"].(string) + pipe := q.client.Pipeline() + pipe.XAdd(ctx, &redis.XAddArgs{ + Stream: dlq, + Values: map[string]interface{}{ + "payload": payload, + "original_id": msg.ID, + "queue": queueName, + }, + }) + pipe.XAck(ctx, queueName, groupName, msg.ID) + pipe.XDel(ctx, queueName, msg.ID) + _, err := pipe.Exec(ctx) + return err +} + +// isGroupExistsErr returns true when the error indicates the consumer group +// already exists (Redis returns BUSYGROUP). +func isGroupExistsErr(err error) bool { + if err == nil { + return false + } + return containsBusyGroup(err.Error()) +} + +func containsBusyGroup(s string) bool { + return len(s) >= 9 && (s[:9] == "BUSYGROUP" || containsSubstring(s, "BUSYGROUP")) +} + +func containsSubstring(s, sub string) bool { + for i := 0; i+len(sub) <= len(s); i++ { + if s[i:i+len(sub)] == sub { + return true + } + } + return false +} diff --git a/internal/repositories/redis/durable_task_queue_test.go b/internal/repositories/redis/durable_task_queue_test.go new file mode 100644 index 000000000..11b101838 --- /dev/null +++ b/internal/repositories/redis/durable_task_queue_test.go @@ -0,0 +1,245 @@ +package redis + +import ( + "context" + "encoding/json" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/redis/go-redis/v9" +) + +func newTestDurableQueue(t *testing.T) (*durableTaskQueue, *miniredis.Miniredis) { + t.Helper() + s, err := miniredis.Run() + if err != nil { + t.Fatalf("failed to start miniredis: %v", err) + } + client := redis.NewClient(&redis.Options{Addr: s.Addr()}) + q := NewDurableTaskQueue(client, WithBlockTime(100*time.Millisecond), WithMaxRetries(3)) + return q, s +} + +func TestDurableEnqueue(t *testing.T) { + q, s := newTestDurableQueue(t) + defer s.Close() + + ctx := context.Background() + payload := map[string]string{"instance_id": "abc-123"} + + if err := q.Enqueue(ctx, "test_stream", payload); err != nil { + t.Fatalf("Enqueue failed: %v", err) + } + + // Verify stream has one entry + entries, err := s.Stream("test_stream") + if err != nil { + t.Fatalf("Stream read failed: %v", err) + } + if len(entries) != 1 { + t.Fatalf("expected 1 stream entry, got %d", len(entries)) + } +} + +func TestDurableEnsureGroupIdempotent(t *testing.T) { + q, s := newTestDurableQueue(t) + defer s.Close() + + ctx := context.Background() + // Should not error even if stream doesn't exist yet (MkStream). + if err := q.EnsureGroup(ctx, "test_stream", "workers"); err != nil { + t.Fatalf("first EnsureGroup failed: %v", err) + } + // Calling again should be idempotent (BUSYGROUP). 
+ if err := q.EnsureGroup(ctx, "test_stream", "workers"); err != nil { + t.Fatalf("second EnsureGroup failed: %v", err) + } +} + +func TestDurableReceiveAndAck(t *testing.T) { + q, s := newTestDurableQueue(t) + defer s.Close() + + ctx := context.Background() + queue := "provision_queue" + group := "workers" + consumer := "worker-1" + + // Setup + if err := q.EnsureGroup(ctx, queue, group); err != nil { + t.Fatalf("EnsureGroup: %v", err) + } + + // Enqueue a job + job := map[string]string{"instance_id": "inst-001"} + if err := q.Enqueue(ctx, queue, job); err != nil { + t.Fatalf("Enqueue: %v", err) + } + + // Receive it + msg, err := q.Receive(ctx, queue, group, consumer) + if err != nil { + t.Fatalf("Receive: %v", err) + } + if msg == nil { + t.Fatal("expected message, got nil") + } + if msg.Queue != queue { + t.Fatalf("expected queue %q, got %q", queue, msg.Queue) + } + + // Verify payload round-trips + var got map[string]string + if err := json.Unmarshal([]byte(msg.Payload), &got); err != nil { + t.Fatalf("unmarshal payload: %v", err) + } + if got["instance_id"] != "inst-001" { + t.Fatalf("expected instance_id inst-001, got %s", got["instance_id"]) + } + + // Ack it + if err := q.Ack(ctx, queue, group, msg.ID); err != nil { + t.Fatalf("Ack: %v", err) + } + + // Receive again — should be empty + msg2, err := q.Receive(ctx, queue, group, consumer) + if err != nil { + t.Fatalf("second Receive: %v", err) + } + if msg2 != nil { + t.Fatalf("expected nil after ack, got %+v", msg2) + } +} + +func TestDurableReceiveEmptyReturnsNil(t *testing.T) { + q, s := newTestDurableQueue(t) + defer s.Close() + + ctx := context.Background() + queue := "empty_stream" + group := "workers" + + if err := q.EnsureGroup(ctx, queue, group); err != nil { + t.Fatalf("EnsureGroup: %v", err) + } + + msg, err := q.Receive(ctx, queue, group, "worker-1") + if err != nil { + t.Fatalf("Receive: %v", err) + } + if msg != nil { + t.Fatalf("expected nil message from empty stream, got %+v", msg) + } +} + +func TestDurableMultipleConsumersGetDifferentMessages(t *testing.T) { + q, s := newTestDurableQueue(t) + defer s.Close() + + ctx := context.Background() + queue := "multi_consumer" + group := "workers" + + if err := q.EnsureGroup(ctx, queue, group); err != nil { + t.Fatalf("EnsureGroup: %v", err) + } + + // Enqueue two messages + if err := q.Enqueue(ctx, queue, map[string]string{"id": "1"}); err != nil { + t.Fatalf("Enqueue 1: %v", err) + } + if err := q.Enqueue(ctx, queue, map[string]string{"id": "2"}); err != nil { + t.Fatalf("Enqueue 2: %v", err) + } + + // Two consumers each get one + msg1, err := q.Receive(ctx, queue, group, "worker-1") + if err != nil || msg1 == nil { + t.Fatalf("worker-1 Receive: msg=%v err=%v", msg1, err) + } + msg2, err := q.Receive(ctx, queue, group, "worker-2") + if err != nil || msg2 == nil { + t.Fatalf("worker-2 Receive: msg=%v err=%v", msg2, err) + } + + if msg1.ID == msg2.ID { + t.Fatalf("both consumers got the same message ID: %s", msg1.ID) + } +} + +func TestDurableNackKeepsMessagePending(t *testing.T) { + q, s := newTestDurableQueue(t) + defer s.Close() + + ctx := context.Background() + queue := "nack_test" + group := "workers" + consumer := "worker-1" + + if err := q.EnsureGroup(ctx, queue, group); err != nil { + t.Fatalf("EnsureGroup: %v", err) + } + if err := q.Enqueue(ctx, queue, map[string]string{"id": "1"}); err != nil { + t.Fatalf("Enqueue: %v", err) + } + + msg, err := q.Receive(ctx, queue, group, consumer) + if err != nil || msg == nil { + t.Fatalf("Receive: msg=%v err=%v", msg, err) + 
} + + // Nack (no-op in Redis Streams — message stays in PEL) + if err := q.Nack(ctx, queue, group, msg.ID); err != nil { + t.Fatalf("Nack: %v", err) + } + + // The message should still be pending (not acked). + // Verify via XPending. + pending, err := q.client.XPending(ctx, queue, group).Result() + if err != nil { + t.Fatalf("XPending: %v", err) + } + if pending.Count != 1 { + t.Fatalf("expected 1 pending message, got %d", pending.Count) + } +} + +func TestDurableDeadLetterOnDequeue(t *testing.T) { + // This tests the legacy Dequeue path for backward compatibility. + q, s := newTestDurableQueue(t) + defer s.Close() + + ctx := context.Background() + queue := "legacy_dequeue" + + if err := q.Enqueue(ctx, queue, map[string]string{"legacy": "true"}); err != nil { + t.Fatalf("Enqueue: %v", err) + } + + msg, err := q.Dequeue(ctx, queue) + if err != nil { + t.Fatalf("Dequeue: %v", err) + } + if msg == "" { + t.Fatal("expected non-empty legacy dequeue result") + } + + var got map[string]string + if err := json.Unmarshal([]byte(msg), &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if got["legacy"] != "true" { + t.Fatalf("expected legacy=true, got %s", got["legacy"]) + } + + // Stream should be empty after legacy Dequeue (auto-deleted) + entries, err := s.Stream(queue) + if err != nil { + t.Fatalf("Stream: %v", err) + } + if len(entries) != 0 { + t.Fatalf("expected empty stream after legacy dequeue, got %d entries", len(entries)) + } +} From 845e77ddf2dd0dd0b90ad6ac661dbab0a3227e20 Mon Sep 17 00:00:00 2001 From: jackthepunished Date: Tue, 10 Mar 2026 05:40:19 +0300 Subject: [PATCH 03/12] feat(resilience): Add circuit breaker enhancements, bulkhead, and retry utilities MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Circuit Breaker: - Add half-open single-flight: only one probe request allowed at a time - Add OnStateChange callback (synchronous) for observability - Add SuccessRequired for multi-success half-open→closed transition - Add Name and State.String() methods - Backward compatible with existing NewCircuitBreaker(threshold, timeout) Bulkhead: - Add semaphore-based concurrency limiter with configurable wait timeout - Returns ErrBulkheadFull when limit reached and timeout expires Retry: - Add exponential backoff with full jitter - Configurable ShouldRetry predicate for selective retry - Context-aware cancellation --- internal/platform/bulkhead.go | 86 +++++++++++ internal/platform/bulkhead_test.go | 118 +++++++++++++++ internal/platform/circuit_breaker.go | 166 ++++++++++++++++++---- internal/platform/circuit_breaker_test.go | 131 ++++++++++++++++- internal/platform/retry.go | 96 +++++++++++++ internal/platform/retry_test.go | 119 ++++++++++++++++ 6 files changed, 687 insertions(+), 29 deletions(-) create mode 100644 internal/platform/bulkhead.go create mode 100644 internal/platform/bulkhead_test.go create mode 100644 internal/platform/retry.go create mode 100644 internal/platform/retry_test.go diff --git a/internal/platform/bulkhead.go b/internal/platform/bulkhead.go new file mode 100644 index 000000000..8f6cbf86a --- /dev/null +++ b/internal/platform/bulkhead.go @@ -0,0 +1,86 @@ +package platform + +import ( + "context" + "errors" + "time" +) + +// ErrBulkheadFull is returned when the bulkhead's concurrency limit is reached +// and the caller's timeout/context expires before a slot opens. +var ErrBulkheadFull = errors.New("bulkhead: concurrency limit reached") + +// Bulkhead limits concurrent access to a resource using a semaphore pattern. 
+// It prevents one slow/failing component from consuming all available goroutines +// and cascading failure to other parts of the system. +type Bulkhead struct { + name string + sem chan struct{} + timeout time.Duration +} + +// BulkheadOpts configures a bulkhead. +type BulkheadOpts struct { + Name string // Identifier for logging/metrics. + MaxConc int // Maximum concurrent requests. Default 10. + WaitTimeout time.Duration // How long to wait for a slot. Default 5s. 0 means use context deadline. +} + +// NewBulkhead creates a new concurrency-limiting bulkhead. +func NewBulkhead(opts BulkheadOpts) *Bulkhead { + if opts.MaxConc <= 0 { + opts.MaxConc = 10 + } + return &Bulkhead{ + name: opts.Name, + sem: make(chan struct{}, opts.MaxConc), + timeout: opts.WaitTimeout, + } +} + +// Execute runs fn within the bulkhead's concurrency limit. +// If the bulkhead is full and the wait timeout (or context) expires, +// ErrBulkheadFull is returned without calling fn. +func (b *Bulkhead) Execute(ctx context.Context, fn func() error) error { + if err := b.acquire(ctx); err != nil { + return err + } + defer b.release() + return fn() +} + +func (b *Bulkhead) acquire(ctx context.Context) error { + if b.timeout > 0 { + timer := time.NewTimer(b.timeout) + defer timer.Stop() + select { + case b.sem <- struct{}{}: + return nil + case <-timer.C: + return ErrBulkheadFull + case <-ctx.Done(): + return ErrBulkheadFull + } + } + // No explicit timeout — rely on context. + select { + case b.sem <- struct{}{}: + return nil + case <-ctx.Done(): + return ErrBulkheadFull + } +} + +func (b *Bulkhead) release() { + <-b.sem +} + +// Available returns the number of currently available slots. +func (b *Bulkhead) Available() int { + return cap(b.sem) - len(b.sem) +} + +// Name returns the bulkhead's configured name. +func (b *Bulkhead) Name() string { + return b.name +} diff --git a/internal/platform/bulkhead_test.go b/internal/platform/bulkhead_test.go new file mode 100644 index 000000000..dd7315d61 --- /dev/null +++ b/internal/platform/bulkhead_test.go @@ -0,0 +1,118 @@ +package platform + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBulkheadAllowsUpToMaxConcurrency(t *testing.T) { + bh := NewBulkhead(BulkheadOpts{Name: "test", MaxConc: 2}) + + var running atomic.Int32 + var maxSeen atomic.Int32 + var wg sync.WaitGroup + + for i := 0; i < 4; i++ { + wg.Add(1) + go func() { + defer wg.Done() + err := bh.Execute(context.Background(), func() error { + cur := running.Add(1) + defer running.Add(-1) + // Track the max concurrent. + for { + old := maxSeen.Load() + if cur <= old || maxSeen.CompareAndSwap(old, cur) { + break + } + } + time.Sleep(50 * time.Millisecond) + return nil + }) + assert.NoError(t, err) + }() + } + + wg.Wait() + assert.LessOrEqual(t, maxSeen.Load(), int32(2)) +} + +func TestBulkheadRejectsWhenFull(t *testing.T) { + bh := NewBulkhead(BulkheadOpts{Name: "test", MaxConc: 1, WaitTimeout: 50 * time.Millisecond}) + + // Fill the bulkhead. + started := make(chan struct{}) + done := make(chan struct{}) + go func() { + _ = bh.Execute(context.Background(), func() error { + close(started) + <-done + return nil + }) + }() + <-started + + // Second call should be rejected. 
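+	// MaxConc is 1 and the only slot is held above, so Execute should give
+	// up with ErrBulkheadFull once the 50ms WaitTimeout elapses.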
+ err := bh.Execute(context.Background(), func() error { return nil }) + assert.ErrorIs(t, err, ErrBulkheadFull) + + close(done) +} + +func TestBulkheadRespectsContext(t *testing.T) { + bh := NewBulkhead(BulkheadOpts{Name: "test", MaxConc: 1}) + + // Fill the bulkhead. + started := make(chan struct{}) + done := make(chan struct{}) + go func() { + _ = bh.Execute(context.Background(), func() error { + close(started) + <-done + return nil + }) + }() + <-started + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + err := bh.Execute(ctx, func() error { return nil }) + assert.ErrorIs(t, err, ErrBulkheadFull) + + close(done) +} + +func TestBulkheadPropagatesFunctionError(t *testing.T) { + bh := NewBulkhead(BulkheadOpts{Name: "test", MaxConc: 5}) + myErr := errors.New("business error") + err := bh.Execute(context.Background(), func() error { return myErr }) + require.ErrorIs(t, err, myErr) +} + +func TestBulkheadAvailable(t *testing.T) { + bh := NewBulkhead(BulkheadOpts{MaxConc: 3}) + assert.Equal(t, 3, bh.Available()) + + started := make(chan struct{}) + done := make(chan struct{}) + go func() { + _ = bh.Execute(context.Background(), func() error { + close(started) + <-done + return nil + }) + }() + <-started + assert.Equal(t, 2, bh.Available()) + close(done) + time.Sleep(10 * time.Millisecond) + assert.Equal(t, 3, bh.Available()) +} diff --git a/internal/platform/circuit_breaker.go b/internal/platform/circuit_breaker.go index c57490f78..d310367ab 100644 --- a/internal/platform/circuit_breaker.go +++ b/internal/platform/circuit_breaker.go @@ -3,6 +3,7 @@ package platform import ( "errors" + "fmt" "sync" "time" ) @@ -22,22 +23,80 @@ const ( StateHalfOpen ) -// CircuitBreaker implements the circuit breaker pattern. +// String returns a human-readable name for the circuit breaker state. +func (s State) String() string { + switch s { + case StateClosed: + return "closed" + case StateOpen: + return "open" + case StateHalfOpen: + return "half-open" + default: + return fmt.Sprintf("unknown(%d)", int(s)) + } +} + +// StateChangeFunc is called when the circuit breaker transitions between states. +// The old and new states are provided. Implementations must not block. +type StateChangeFunc func(name string, from, to State) + +// CircuitBreakerOpts configures the circuit breaker. All fields are optional +// and have sensible defaults; use the functional options to override. +type CircuitBreakerOpts struct { + Name string // Identifies this breaker in logs/metrics. + Threshold int // Consecutive failures to trip open. Default 5. + ResetTimeout time.Duration // Time in open before trying half-open. Default 30s. + SuccessRequired int // Successes in half-open to close. Default 1. + OnStateChange StateChangeFunc // Optional callback. +} + +// CircuitBreaker implements the circuit breaker pattern with proper +// half-open single-flight: only one probe request is allowed while open +// transitions to half-open. type CircuitBreaker struct { - mu sync.RWMutex + mu sync.Mutex + + name string state State failureCount int - failureThreshold int + successCount int // successes in half-open + threshold int + successRequired int resetTimeout time.Duration lastFailure time.Time + halfOpenInFlight bool // true while a half-open probe is executing + onStateChange StateChangeFunc } -// NewCircuitBreaker creates a new circuit breaker. +// NewCircuitBreaker creates a circuit breaker. 
The two positional args +// (threshold, resetTimeout) are kept for backward compatibility with existing +// callers. Use NewCircuitBreakerWithOpts for full configuration. func NewCircuitBreaker(threshold int, resetTimeout time.Duration) *CircuitBreaker { + return NewCircuitBreakerWithOpts(CircuitBreakerOpts{ + Threshold: threshold, + ResetTimeout: resetTimeout, + }) +} + +// NewCircuitBreakerWithOpts creates a circuit breaker with full options. +func NewCircuitBreakerWithOpts(opts CircuitBreakerOpts) *CircuitBreaker { + if opts.Threshold <= 0 { + opts.Threshold = 5 + } + if opts.ResetTimeout <= 0 { + opts.ResetTimeout = 30 * time.Second + } + if opts.SuccessRequired <= 0 { + opts.SuccessRequired = 1 + } return &CircuitBreaker{ - state: StateClosed, - failureThreshold: threshold, - resetTimeout: resetTimeout, + name: opts.Name, + state: StateClosed, + threshold: opts.Threshold, + successRequired: opts.SuccessRequired, + resetTimeout: opts.ResetTimeout, + onStateChange: opts.OnStateChange, } } @@ -58,34 +117,51 @@ func (cb *CircuitBreaker) Execute(fn func() error) error { } func (cb *CircuitBreaker) allowRequest() bool { - cb.mu.RLock() - defer cb.mu.RUnlock() + cb.mu.Lock() + defer cb.mu.Unlock() - if cb.state == StateClosed { + switch cb.state { + case StateClosed: return true - } - - if cb.state == StateOpen { - if time.Since(cb.lastFailure) > cb.resetTimeout { - return true // Transition to half-open (implied by letting one request through) + case StateOpen: + if time.Since(cb.lastFailure) <= cb.resetTimeout { + return false } - return false + // Transition to half-open; only allow one probe at a time. + if cb.halfOpenInFlight { + return false + } + cb.transitionLocked(StateHalfOpen) + cb.halfOpenInFlight = true + cb.successCount = 0 + return true + case StateHalfOpen: + // Allow additional requests only if no probe is in flight. + if cb.halfOpenInFlight { + return false + } + cb.halfOpenInFlight = true + return true } - - return true // Half-open + return false } func (cb *CircuitBreaker) recordFailure() { cb.mu.Lock() defer cb.mu.Unlock() + cb.halfOpenInFlight = false cb.failureCount++ cb.lastFailure = time.Now() - if cb.state == StateClosed && cb.failureCount >= cb.failureThreshold { - cb.state = StateOpen - } else if cb.state == StateHalfOpen { - cb.state = StateOpen + switch cb.state { + case StateClosed: + if cb.failureCount >= cb.threshold { + cb.transitionLocked(StateOpen) + } + case StateHalfOpen: + // Probe failed — go back to open. + cb.transitionLocked(StateOpen) } } @@ -93,18 +169,54 @@ func (cb *CircuitBreaker) recordSuccess() { cb.mu.Lock() defer cb.mu.Unlock() - cb.failureCount = 0 - cb.state = StateClosed + cb.halfOpenInFlight = false + + switch cb.state { + case StateHalfOpen: + cb.successCount++ + if cb.successCount >= cb.successRequired { + cb.failureCount = 0 + cb.successCount = 0 + cb.transitionLocked(StateClosed) + } + default: + cb.failureCount = 0 + cb.state = StateClosed + } +} + +// transitionLocked changes state and fires the callback. Must be called +// with cb.mu held. The callback is invoked synchronously; implementations +// must not block or acquire cb.mu. +func (cb *CircuitBreaker) transitionLocked(to State) { + from := cb.state + if from == to { + return + } + cb.state = to + if cb.onStateChange != nil { + cb.onStateChange(cb.name, from, to) + } } // Reset clears the circuit breaker state. 
func (cb *CircuitBreaker) Reset() {
-	cb.recordSuccess()
+	cb.mu.Lock()
+	defer cb.mu.Unlock()
+	cb.failureCount = 0
+	cb.successCount = 0
+	cb.halfOpenInFlight = false
+	cb.transitionLocked(StateClosed)
 }

 // GetState returns the current state of the circuit breaker.
 func (cb *CircuitBreaker) GetState() State {
-	cb.mu.RLock()
-	defer cb.mu.RUnlock()
+	cb.mu.Lock()
+	defer cb.mu.Unlock()
 	return cb.state
 }
+
+// Name returns the configured name of this circuit breaker.
+func (cb *CircuitBreaker) Name() string {
+	return cb.name
+}
diff --git a/internal/platform/circuit_breaker_test.go b/internal/platform/circuit_breaker_test.go
index 39126a8a7..2e16cdfa1 100644
--- a/internal/platform/circuit_breaker_test.go
+++ b/internal/platform/circuit_breaker_test.go
@@ -2,11 +2,12 @@ package platform

 import (
 	"errors"
-	"github.com/stretchr/testify/require"
+	"sync"
 	"testing"
 	"time"

 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )

 func TestCircuitBreaker(t *testing.T) {
@@ -53,7 +54,7 @@ func TestCircuitBreaker(t *testing.T) {

 	time.Sleep(100 * time.Millisecond)

-	// This should be allowed (half-open state implicitly)
+	// This should be allowed (half-open state)
 	err := cb.Execute(func() error {
 		return nil
 	})
@@ -91,3 +92,129 @@ func TestCircuitBreaker(t *testing.T) {
 		assert.Equal(t, StateClosed, cb.GetState())
 	})
 }
+
+func TestCircuitBreakerHalfOpenSingleFlight(t *testing.T) {
+	cb := NewCircuitBreaker(1, 50*time.Millisecond)
+
+	// Trip the circuit.
+	_ = cb.Execute(func() error { return errors.New("fail") })
+	assert.Equal(t, StateOpen, cb.GetState())
+
+	time.Sleep(100 * time.Millisecond)
+
+	// First call goes through as the half-open probe. Use a channel to
+	// hold the probe in-flight while we test the second call.
+	probeStarted := make(chan struct{})
+	probeDone := make(chan struct{})
+
+	go func() {
+		_ = cb.Execute(func() error {
+			close(probeStarted)
+			<-probeDone // block until test releases
+			return nil
+		})
+	}()
+
+	<-probeStarted // wait for probe to be in-flight
+
+	// Second concurrent call should be rejected while probe is in flight.
+	err := cb.Execute(func() error { return nil })
+	assert.Equal(t, ErrCircuitOpen, err, "second request should be blocked while half-open probe is in flight")
+
+	close(probeDone) // release the probe
+	time.Sleep(10 * time.Millisecond)
+
+	// After probe succeeds, circuit should be closed.
+	assert.Equal(t, StateClosed, cb.GetState())
+}
+
+func TestCircuitBreakerOnStateChange(t *testing.T) {
+	var mu sync.Mutex
+	transitions := make([]struct{ from, to State }, 0)
+
+	cb := NewCircuitBreakerWithOpts(CircuitBreakerOpts{
+		Name:         "test-cb",
+		Threshold:    1,
+		ResetTimeout: 50 * time.Millisecond,
+		OnStateChange: func(name string, from, to State) {
+			mu.Lock()
+			transitions = append(transitions, struct{ from, to State }{from, to})
+			mu.Unlock()
+		},
+	})
+
+	// Trip it.
+	_ = cb.Execute(func() error { return errors.New("fail") })
+	time.Sleep(20 * time.Millisecond) // not strictly required: the callback fires synchronously inside Execute
+
+	mu.Lock()
+	require.Len(t, transitions, 1)
+	assert.Equal(t, StateClosed, transitions[0].from)
+	assert.Equal(t, StateOpen, transitions[0].to)
+	mu.Unlock()

+	// Wait for reset timeout, then succeed to close.
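+	// (ResetTimeout above is 50ms, so sleeping 100ms guarantees the open window has elapsed.)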
+ time.Sleep(100 * time.Millisecond) + err := cb.Execute(func() error { return nil }) + require.NoError(t, err) + time.Sleep(20 * time.Millisecond) + + mu.Lock() + // Should have: closed->open, open->half-open, half-open->closed + require.Len(t, transitions, 3) + assert.Equal(t, StateOpen, transitions[1].from) + assert.Equal(t, StateHalfOpen, transitions[1].to) + assert.Equal(t, StateHalfOpen, transitions[2].from) + assert.Equal(t, StateClosed, transitions[2].to) + mu.Unlock() +} + +func TestCircuitBreakerWithOpts(t *testing.T) { + cb := NewCircuitBreakerWithOpts(CircuitBreakerOpts{ + Name: "compute", + Threshold: 3, + ResetTimeout: 1 * time.Second, + SuccessRequired: 2, + }) + + assert.Equal(t, "compute", cb.Name()) + assert.Equal(t, StateClosed, cb.GetState()) + + // Trip it with 3 failures. + for i := 0; i < 3; i++ { + _ = cb.Execute(func() error { return errors.New("fail") }) + } + assert.Equal(t, StateOpen, cb.GetState()) +} + +func TestCircuitBreakerSuccessRequired(t *testing.T) { + cb := NewCircuitBreakerWithOpts(CircuitBreakerOpts{ + Threshold: 1, + ResetTimeout: 50 * time.Millisecond, + SuccessRequired: 2, + }) + + // Trip it. + _ = cb.Execute(func() error { return errors.New("fail") }) + assert.Equal(t, StateOpen, cb.GetState()) + + time.Sleep(100 * time.Millisecond) + + // First success should move to half-open but not closed. + err := cb.Execute(func() error { return nil }) + require.NoError(t, err) + // Still half-open because we need 2 successes. + assert.Equal(t, StateHalfOpen, cb.GetState()) + + // Second success should close. + err = cb.Execute(func() error { return nil }) + require.NoError(t, err) + assert.Equal(t, StateClosed, cb.GetState()) +} + +func TestStateString(t *testing.T) { + assert.Equal(t, "closed", StateClosed.String()) + assert.Equal(t, "open", StateOpen.String()) + assert.Equal(t, "half-open", StateHalfOpen.String()) + assert.Equal(t, "unknown(99)", State(99).String()) +} diff --git a/internal/platform/retry.go b/internal/platform/retry.go new file mode 100644 index 000000000..86391a0bf --- /dev/null +++ b/internal/platform/retry.go @@ -0,0 +1,96 @@ +package platform + +import ( + "context" + "math" + "math/rand/v2" + "time" +) + +// RetryOpts configures retry behavior. +type RetryOpts struct { + MaxAttempts int // Total attempts (including the first). Default 3. + BaseDelay time.Duration // Initial delay before first retry. Default 500ms. + MaxDelay time.Duration // Cap on exponential growth. Default 30s. + Multiplier float64 // Exponent base. Default 2.0. + // ShouldRetry is an optional predicate that returns false for errors + // that should NOT be retried (e.g., validation errors, 4xx HTTP). + // If nil, all non-nil errors are retried. + ShouldRetry func(error) bool +} + +func (o RetryOpts) withDefaults() RetryOpts { + if o.MaxAttempts <= 0 { + o.MaxAttempts = 3 + } + if o.BaseDelay <= 0 { + o.BaseDelay = 500 * time.Millisecond + } + if o.MaxDelay <= 0 { + o.MaxDelay = 30 * time.Second + } + if o.Multiplier <= 0 { + o.Multiplier = 2.0 + } + return o +} + +// Retry executes fn up to opts.MaxAttempts times with exponential backoff +// and full jitter. It stops early if the context is cancelled or +// opts.ShouldRetry returns false. 
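+//
+// A minimal usage sketch (fetchThing is a hypothetical callee):
+//
+//	err := Retry(ctx, RetryOpts{MaxAttempts: 5, BaseDelay: 100 * time.Millisecond},
+//		func(ctx context.Context) error { return fetchThing(ctx) })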
+func Retry(ctx context.Context, opts RetryOpts, fn func(ctx context.Context) error) error {
+	opts = opts.withDefaults()
+
+	var lastErr error
+	for attempt := 0; attempt < opts.MaxAttempts; attempt++ {
+		if err := ctx.Err(); err != nil {
+			if lastErr != nil {
+				return lastErr
+			}
+			return err
+		}
+
+		lastErr = fn(ctx)
+		if lastErr == nil {
+			return nil
+		}
+
+		// Check if this error is retryable.
+		if opts.ShouldRetry != nil && !opts.ShouldRetry(lastErr) {
+			return lastErr
+		}
+
+		// Don't sleep after the last attempt.
+		if attempt == opts.MaxAttempts-1 {
+			break
+		}
+
+		delay := backoffDelay(attempt, opts.BaseDelay, opts.MaxDelay, opts.Multiplier)
+		timer := time.NewTimer(delay)
+		select {
+		case <-ctx.Done():
+			timer.Stop()
+			return lastErr
+		case <-timer.C:
+		}
+	}
+
+	return lastErr
+}
+
+// backoffDelay computes exponential backoff with jitter:
+//
+//	delay = random(base/2, min(maxDelay, baseDelay * multiplier^attempt))
+//
+// This is full jitter with a floor of base/2, so a retry never fires immediately.
func backoffDelay(attempt int, base, max time.Duration, mult float64) time.Duration {
+	exp := math.Pow(mult, float64(attempt))
+	calculated := time.Duration(float64(base) * exp)
+	if calculated > max || calculated <= 0 {
+		calculated = max
+	}
+	// Jitter: uniform random in [base/2, calculated].
+	floor := base / 2
+	if floor > calculated {
+		floor = calculated
+	}
+	jittered := floor + time.Duration(rand.Int64N(int64(calculated-floor+1)))
+	return jittered
+}
diff --git a/internal/platform/retry_test.go b/internal/platform/retry_test.go
new file mode 100644
index 000000000..0eea21edf
--- /dev/null
+++ b/internal/platform/retry_test.go
@@ -0,0 +1,119 @@
+package platform
+
+import (
+	"context"
+	"errors"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestRetrySucceedsImmediately(t *testing.T) {
+	calls := 0
+	err := Retry(context.Background(), RetryOpts{MaxAttempts: 3}, func(ctx context.Context) error {
+		calls++
+		return nil
+	})
+	require.NoError(t, err)
+	assert.Equal(t, 1, calls)
+}
+
+func TestRetryRetriesOnFailure(t *testing.T) {
+	var calls atomic.Int32
+	err := Retry(context.Background(), RetryOpts{
+		MaxAttempts: 3,
+		BaseDelay:   10 * time.Millisecond,
+		MaxDelay:    50 * time.Millisecond,
+	}, func(ctx context.Context) error {
+		n := calls.Add(1)
+		if n < 3 {
+			return errors.New("transient")
+		}
+		return nil
+	})
+	require.NoError(t, err)
+	assert.Equal(t, int32(3), calls.Load())
+}
+
+func TestRetryExhaustsAttempts(t *testing.T) {
+	calls := 0
+	err := Retry(context.Background(), RetryOpts{
+		MaxAttempts: 2,
+		BaseDelay:   10 * time.Millisecond,
+	}, func(ctx context.Context) error {
+		calls++
+		return errors.New("permanent")
+	})
+	require.Error(t, err)
+	assert.Equal(t, "permanent", err.Error())
+	assert.Equal(t, 2, calls)
+}
+
+func TestRetryRespectsContext(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	calls := 0
+	err := Retry(ctx, RetryOpts{
+		MaxAttempts: 10,
+		BaseDelay:   50 * time.Millisecond,
+	}, func(ctx context.Context) error {
+		calls++
+		if calls == 2 {
+			cancel()
+		}
+		return errors.New("fail")
+	})
+	require.Error(t, err)
+	assert.LessOrEqual(t, calls, 3) // might get 2 or 3 depending on timing
+}
+
+func TestRetryShouldRetryPredicate(t *testing.T) {
+	permanent := errors.New("permanent error")
+	calls := 0
+	err := Retry(context.Background(), RetryOpts{
+		MaxAttempts: 5,
+		BaseDelay:   10 * time.Millisecond,
+		ShouldRetry: func(err error) bool {
+			return !errors.Is(err, permanent)
+		},
+	}, func(ctx context.Context) error {
+		calls++
+ return permanent + }) + require.ErrorIs(t, err, permanent) + assert.Equal(t, 1, calls, "should not retry non-retryable errors") +} + +func TestRetryDefaultOpts(t *testing.T) { + calls := 0 + err := Retry(context.Background(), RetryOpts{}, func(ctx context.Context) error { + calls++ + if calls < 3 { + return errors.New("fail") + } + return nil + }) + require.NoError(t, err) + assert.Equal(t, 3, calls) // default MaxAttempts is 3 +} + +func TestBackoffDelay(t *testing.T) { + base := 100 * time.Millisecond + max := 5 * time.Second + + // Attempt 0: jitter in [base/2, base] + for i := 0; i < 100; i++ { + d := backoffDelay(0, base, max, 2.0) + assert.GreaterOrEqual(t, d, base/2) + assert.LessOrEqual(t, d, base) + } + + // Attempt 3: calculated = 100ms * 2^3 = 800ms + for i := 0; i < 100; i++ { + d := backoffDelay(3, base, max, 2.0) + assert.GreaterOrEqual(t, d, base/2) + assert.LessOrEqual(t, d, 800*time.Millisecond) + } +} From 9ee13ce89e66534bdfd9bfcb8420b695efb3c5b7 Mon Sep 17 00:00:00 2001 From: jackthepunished Date: Tue, 10 Mar 2026 05:40:35 +0300 Subject: [PATCH 04/12] feat(resilience): Add resilient adapter wrappers for infrastructure backends MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add decorator wrappers implementing ports interfaces with resilience patterns: - ResilientCompute: CB (5 fails/30s) + Bulkhead (20 conc) + Timeouts - ResilientNetwork: CB (5 fails/30s) + Bulkhead (15 conc) + Timeout (30s) - ResilientStorage: CB (5 fails/30s) + Bulkhead (10 conc) + Timeouts - ResilientDNS: CB (5 fails/30s) + Timeout (10s) - no bulkhead needed - ResilientLB: CB (5 fails/30s) + Timeouts (30s normal, 2m deploy) Design: - Ping() bypasses bulkhead (cheap health check) but uses CB - Type() delegates directly (pure metadata) - Retry NOT applied at adapter level (dangerous for provisioning) - All wrappers have configurable options with sensible defaults - SuccessRequired: 2 for half-open→closed (extra safety) Add comprehensive tests for ResilientCompute (passthrough, circuit trip, bulkhead limits, timeout, unwrap, ping bypass). --- internal/platform/resilient_compute.go | 279 ++++++++++++++++++ internal/platform/resilient_compute_test.go | 299 ++++++++++++++++++++ internal/platform/resilient_dns.go | 125 ++++++++ internal/platform/resilient_lb.go | 97 +++++++ internal/platform/resilient_network.go | 211 ++++++++++++++ internal/platform/resilient_storage.go | 166 +++++++++++ 6 files changed, 1177 insertions(+) create mode 100644 internal/platform/resilient_compute.go create mode 100644 internal/platform/resilient_compute_test.go create mode 100644 internal/platform/resilient_dns.go create mode 100644 internal/platform/resilient_lb.go create mode 100644 internal/platform/resilient_network.go create mode 100644 internal/platform/resilient_storage.go diff --git a/internal/platform/resilient_compute.go b/internal/platform/resilient_compute.go new file mode 100644 index 000000000..bc5c0135d --- /dev/null +++ b/internal/platform/resilient_compute.go @@ -0,0 +1,279 @@ +package platform + +import ( + "context" + "fmt" + "io" + "log/slog" + "time" + + "github.com/poyrazk/thecloud/internal/core/ports" +) + +// ResilientComputeOpts configures the resilient compute wrapper. +type ResilientComputeOpts struct { + // CallTimeout is the per-call context timeout for normal operations. + // Default: 2 minutes. + CallTimeout time.Duration + // LongCallTimeout is the timeout for operations that are expected to take + // longer (e.g., LaunchInstanceWithOptions, RunTask). 
Default: 10 minutes. + LongCallTimeout time.Duration + // CBThreshold is the number of consecutive failures before the circuit + // opens. Default: 5. + CBThreshold int + // CBResetTimeout is how long the circuit stays open before attempting a + // half-open probe. Default: 30s. + CBResetTimeout time.Duration + // BulkheadMaxConc is the max concurrent calls to the backend. Default: 20. + BulkheadMaxConc int + // BulkheadWait is how long to wait for a bulkhead slot. Default: 10s. + BulkheadWait time.Duration +} + +func (o ResilientComputeOpts) withDefaults() ResilientComputeOpts { + if o.CallTimeout <= 0 { + o.CallTimeout = 2 * time.Minute + } + if o.LongCallTimeout <= 0 { + o.LongCallTimeout = 10 * time.Minute + } + if o.CBThreshold <= 0 { + o.CBThreshold = 5 + } + if o.CBResetTimeout <= 0 { + o.CBResetTimeout = 30 * time.Second + } + if o.BulkheadMaxConc <= 0 { + o.BulkheadMaxConc = 20 + } + if o.BulkheadWait <= 0 { + o.BulkheadWait = 10 * time.Second + } + return o +} + +// ResilientCompute wraps a ComputeBackend with circuit breaker, bulkhead, +// and per-call timeouts. It implements the ports.ComputeBackend interface. +type ResilientCompute struct { + inner ports.ComputeBackend + cb *CircuitBreaker + bulkhead *Bulkhead + logger *slog.Logger + opts ResilientComputeOpts +} + +// NewResilientCompute decorates inner with resilience primitives. +func NewResilientCompute(inner ports.ComputeBackend, logger *slog.Logger, opts ResilientComputeOpts) *ResilientCompute { + opts = opts.withDefaults() + name := fmt.Sprintf("compute-%s", inner.Type()) + + cb := NewCircuitBreakerWithOpts(CircuitBreakerOpts{ + Name: name, + Threshold: opts.CBThreshold, + ResetTimeout: opts.CBResetTimeout, + SuccessRequired: 2, + OnStateChange: func(n string, from, to State) { + logger.Warn("circuit breaker state change", + "breaker", n, "from", from.String(), "to", to.String()) + }, + }) + + bh := NewBulkhead(BulkheadOpts{ + Name: name, + MaxConc: opts.BulkheadMaxConc, + WaitTimeout: opts.BulkheadWait, + }) + + return &ResilientCompute{ + inner: inner, + cb: cb, + bulkhead: bh, + logger: logger.With("adapter", name), + opts: opts, + } +} + +// ---------- helpers ---------- + +// callProtected runs fn through bulkhead → circuit breaker → timeout. 
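+// The nesting order matters: the bulkhead sits outermost so that a full
+// bulkhead is not recorded as a circuit-breaker failure, and the timeout is
+// applied innermost so that only the backend call itself, not slot-wait time,
+// is bounded.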
+func (r *ResilientCompute) callProtected(ctx context.Context, timeout time.Duration, fn func(ctx context.Context) error) error { + return r.bulkhead.Execute(ctx, func() error { + return r.cb.Execute(func() error { + ctx2, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + return fn(ctx2) + }) + }) +} + +// ---------- Instance Lifecycle ---------- + +func (r *ResilientCompute) LaunchInstanceWithOptions(ctx context.Context, opts ports.CreateInstanceOptions) (string, []string, error) { + var id string + var ps []string + err := r.callProtected(ctx, r.opts.LongCallTimeout, func(ctx context.Context) error { + var e error + id, ps, e = r.inner.LaunchInstanceWithOptions(ctx, opts) + return e + }) + return id, ps, err +} + +func (r *ResilientCompute) StartInstance(ctx context.Context, id string) error { + return r.callProtected(ctx, r.opts.CallTimeout, func(ctx context.Context) error { + return r.inner.StartInstance(ctx, id) + }) +} + +func (r *ResilientCompute) StopInstance(ctx context.Context, id string) error { + return r.callProtected(ctx, r.opts.CallTimeout, func(ctx context.Context) error { + return r.inner.StopInstance(ctx, id) + }) +} + +func (r *ResilientCompute) DeleteInstance(ctx context.Context, id string) error { + return r.callProtected(ctx, r.opts.CallTimeout, func(ctx context.Context) error { + return r.inner.DeleteInstance(ctx, id) + }) +} + +func (r *ResilientCompute) GetInstanceLogs(ctx context.Context, id string) (io.ReadCloser, error) { + var rc io.ReadCloser + err := r.callProtected(ctx, r.opts.CallTimeout, func(ctx context.Context) error { + var e error + rc, e = r.inner.GetInstanceLogs(ctx, id) + return e + }) + return rc, err +} + +func (r *ResilientCompute) GetInstanceStats(ctx context.Context, id string) (io.ReadCloser, error) { + var rc io.ReadCloser + err := r.callProtected(ctx, r.opts.CallTimeout, func(ctx context.Context) error { + var e error + rc, e = r.inner.GetInstanceStats(ctx, id) + return e + }) + return rc, err +} + +func (r *ResilientCompute) GetInstancePort(ctx context.Context, id string, internalPort string) (int, error) { + var port int + err := r.callProtected(ctx, r.opts.CallTimeout, func(ctx context.Context) error { + var e error + port, e = r.inner.GetInstancePort(ctx, id, internalPort) + return e + }) + return port, err +} + +func (r *ResilientCompute) GetInstanceIP(ctx context.Context, id string) (string, error) { + var ip string + err := r.callProtected(ctx, r.opts.CallTimeout, func(ctx context.Context) error { + var e error + ip, e = r.inner.GetInstanceIP(ctx, id) + return e + }) + return ip, err +} + +func (r *ResilientCompute) GetConsoleURL(ctx context.Context, id string) (string, error) { + var url string + err := r.callProtected(ctx, r.opts.CallTimeout, func(ctx context.Context) error { + var e error + url, e = r.inner.GetConsoleURL(ctx, id) + return e + }) + return url, err +} + +// ---------- Execution ---------- + +func (r *ResilientCompute) Exec(ctx context.Context, id string, cmd []string) (string, error) { + var out string + err := r.callProtected(ctx, r.opts.CallTimeout, func(ctx context.Context) error { + var e error + out, e = r.inner.Exec(ctx, id, cmd) + return e + }) + return out, err +} + +func (r *ResilientCompute) RunTask(ctx context.Context, opts ports.RunTaskOptions) (string, []string, error) { + var id string + var ps []string + err := r.callProtected(ctx, r.opts.LongCallTimeout, func(ctx context.Context) error { + var e error + id, ps, e = r.inner.RunTask(ctx, opts) + return e + }) + return id, ps, err +} + +func 
(r *ResilientCompute) WaitTask(ctx context.Context, id string) (int64, error) { + var code int64 + err := r.callProtected(ctx, r.opts.LongCallTimeout, func(ctx context.Context) error { + var e error + code, e = r.inner.WaitTask(ctx, id) + return e + }) + return code, err +} + +// ---------- Network Management ---------- + +func (r *ResilientCompute) CreateNetwork(ctx context.Context, name string) (string, error) { + var id string + err := r.callProtected(ctx, r.opts.CallTimeout, func(ctx context.Context) error { + var e error + id, e = r.inner.CreateNetwork(ctx, name) + return e + }) + return id, err +} + +func (r *ResilientCompute) DeleteNetwork(ctx context.Context, id string) error { + return r.callProtected(ctx, r.opts.CallTimeout, func(ctx context.Context) error { + return r.inner.DeleteNetwork(ctx, id) + }) +} + +// ---------- Volume Attachment ---------- + +func (r *ResilientCompute) AttachVolume(ctx context.Context, id string, volumePath string) (string, error) { + var devPath string + err := r.callProtected(ctx, r.opts.CallTimeout, func(ctx context.Context) error { + var e error + devPath, e = r.inner.AttachVolume(ctx, id, volumePath) + return e + }) + return devPath, err +} + +func (r *ResilientCompute) DetachVolume(ctx context.Context, id string, volumePath string) error { + return r.callProtected(ctx, r.opts.CallTimeout, func(ctx context.Context) error { + return r.inner.DetachVolume(ctx, id, volumePath) + }) +} + +// ---------- Health ---------- + +// Ping bypasses the bulkhead (low cost, used for health checks) but still +// goes through the circuit breaker so a broken backend trips the circuit. +func (r *ResilientCompute) Ping(ctx context.Context) error { + return r.cb.Execute(func() error { + ctx2, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + return r.inner.Ping(ctx2) + }) +} + +// Type delegates directly — no protection needed. +func (r *ResilientCompute) Type() string { + return r.inner.Type() +} + +// Unwrap returns the underlying ComputeBackend (useful for tests). 
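+// Calls made directly on the unwrapped backend bypass all resilience protections.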
+func (r *ResilientCompute) Unwrap() ports.ComputeBackend { + return r.inner +} diff --git a/internal/platform/resilient_compute_test.go b/internal/platform/resilient_compute_test.go new file mode 100644 index 000000000..a682d2541 --- /dev/null +++ b/internal/platform/resilient_compute_test.go @@ -0,0 +1,299 @@ +package platform + +import ( + "context" + "errors" + "io" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/poyrazk/thecloud/internal/core/ports" + "log/slog" +) + +// ---------- mock compute backend ---------- + +type mockCompute struct { + callCount atomic.Int64 + delay time.Duration + err error +} + +func (m *mockCompute) wait(ctx context.Context) error { + if m.delay <= 0 { + return nil + } + select { + case <-time.After(m.delay): + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func (m *mockCompute) LaunchInstanceWithOptions(ctx context.Context, _ ports.CreateInstanceOptions) (string, []string, error) { + m.callCount.Add(1) + if err := m.wait(ctx); err != nil { + return "", nil, err + } + return "inst-1", []string{"8080"}, m.err +} + +func (m *mockCompute) StartInstance(ctx context.Context, _ string) error { + m.callCount.Add(1) + if err := m.wait(ctx); err != nil { + return err + } + return m.err +} +func (m *mockCompute) StopInstance(_ context.Context, _ string) error { + m.callCount.Add(1) + return m.err +} +func (m *mockCompute) DeleteInstance(_ context.Context, _ string) error { + m.callCount.Add(1) + return m.err +} +func (m *mockCompute) GetInstanceLogs(_ context.Context, _ string) (io.ReadCloser, error) { + m.callCount.Add(1) + return io.NopCloser(strings.NewReader("logs")), m.err +} +func (m *mockCompute) GetInstanceStats(_ context.Context, _ string) (io.ReadCloser, error) { + m.callCount.Add(1) + return io.NopCloser(strings.NewReader("stats")), m.err +} +func (m *mockCompute) GetInstancePort(_ context.Context, _ string, _ string) (int, error) { + m.callCount.Add(1) + return 8080, m.err +} +func (m *mockCompute) GetInstanceIP(_ context.Context, _ string) (string, error) { + m.callCount.Add(1) + return "10.0.0.1", m.err +} +func (m *mockCompute) GetConsoleURL(_ context.Context, _ string) (string, error) { + m.callCount.Add(1) + return "https://console", m.err +} +func (m *mockCompute) Exec(_ context.Context, _ string, _ []string) (string, error) { + m.callCount.Add(1) + return "output", m.err +} +func (m *mockCompute) RunTask(_ context.Context, _ ports.RunTaskOptions) (string, []string, error) { + m.callCount.Add(1) + return "task-1", nil, m.err +} +func (m *mockCompute) WaitTask(_ context.Context, _ string) (int64, error) { + m.callCount.Add(1) + return 0, m.err +} +func (m *mockCompute) CreateNetwork(_ context.Context, _ string) (string, error) { + m.callCount.Add(1) + return "net-1", m.err +} +func (m *mockCompute) DeleteNetwork(_ context.Context, _ string) error { + m.callCount.Add(1) + return m.err +} +func (m *mockCompute) AttachVolume(_ context.Context, _ string, _ string) (string, error) { + m.callCount.Add(1) + return "/dev/vdb", m.err +} +func (m *mockCompute) DetachVolume(_ context.Context, _ string, _ string) error { + m.callCount.Add(1) + return m.err +} +func (m *mockCompute) Ping(_ context.Context) error { + m.callCount.Add(1) + return m.err +} +func (m *mockCompute) Type() string { return "mock" } + +// ---------- tests ---------- + +func TestResilientComputePassthrough(t *testing.T) { + // All calls should pass through to the mock on success. 
+ mock := &mockCompute{} + logger := slog.Default() + rc := NewResilientCompute(mock, logger, ResilientComputeOpts{}) + + ctx := context.Background() + + id, ps, err := rc.LaunchInstanceWithOptions(ctx, ports.CreateInstanceOptions{}) + assertNoErr(t, err) + if id != "inst-1" || len(ps) != 1 { + t.Fatalf("unexpected launch result: %s %v", id, ps) + } + + assertNoErr(t, rc.StartInstance(ctx, "x")) + assertNoErr(t, rc.StopInstance(ctx, "x")) + assertNoErr(t, rc.DeleteInstance(ctx, "x")) + + _, err = rc.GetInstanceLogs(ctx, "x") + assertNoErr(t, err) + _, err = rc.GetInstanceStats(ctx, "x") + assertNoErr(t, err) + port, err := rc.GetInstancePort(ctx, "x", "80") + assertNoErr(t, err) + if port != 8080 { + t.Fatalf("expected 8080, got %d", port) + } + ip, err := rc.GetInstanceIP(ctx, "x") + assertNoErr(t, err) + if ip != "10.0.0.1" { + t.Fatalf("expected 10.0.0.1, got %s", ip) + } + + out, err := rc.Exec(ctx, "x", []string{"ls"}) + assertNoErr(t, err) + if out != "output" { + t.Fatalf("expected output, got %s", out) + } + + assertNoErr(t, rc.Ping(ctx)) + if rc.Type() != "mock" { + t.Fatalf("expected mock, got %s", rc.Type()) + } + + if mock.callCount.Load() < 10 { + t.Fatalf("expected at least 10 calls, got %d", mock.callCount.Load()) + } +} + +func TestResilientComputeCircuitTrips(t *testing.T) { + // After threshold failures, the circuit should open and reject immediately. + mock := &mockCompute{err: errors.New("backend down")} + logger := slog.Default() + rc := NewResilientCompute(mock, logger, ResilientComputeOpts{ + CBThreshold: 3, + CBResetTimeout: 5 * time.Second, + }) + + ctx := context.Background() + + // 3 failures to trip the circuit. + for i := 0; i < 3; i++ { + err := rc.StartInstance(ctx, "x") + if err == nil { + t.Fatal("expected error") + } + } + + // Next call should get ErrCircuitOpen without hitting the mock. + callsBefore := mock.callCount.Load() + err := rc.StartInstance(ctx, "x") + if !errors.Is(err, ErrCircuitOpen) { + t.Fatalf("expected ErrCircuitOpen, got %v", err) + } + if mock.callCount.Load() != callsBefore { + t.Fatal("expected mock not to be called when circuit is open") + } +} + +func TestResilientComputeBulkheadLimits(t *testing.T) { + // When bulkhead is full, calls should be rejected. + mock := &mockCompute{delay: 500 * time.Millisecond} + logger := slog.Default() + rc := NewResilientCompute(mock, logger, ResilientComputeOpts{ + BulkheadMaxConc: 2, + BulkheadWait: 50 * time.Millisecond, + CallTimeout: 2 * time.Second, + }) + + ctx := context.Background() + var wg sync.WaitGroup + var bulkheadErrors atomic.Int64 + + // Ensure the first 2 goroutines grab the slots before the rest start. + ready := make(chan struct{}) + for i := 0; i < 5; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + if idx >= 2 { + <-ready // Wait until the first 2 have started. + } + err := rc.StartInstance(ctx, "x") + if errors.Is(err, ErrBulkheadFull) { + bulkheadErrors.Add(1) + } + }(i) + } + // Give the first 2 goroutines time to acquire the slots. + time.Sleep(50 * time.Millisecond) + close(ready) + wg.Wait() + + if bulkheadErrors.Load() == 0 { + t.Fatal("expected at least one bulkhead rejection") + } +} + +func TestResilientComputeTimeout(t *testing.T) { + // A slow backend should be cancelled by the per-call timeout. 
+ mock := &mockCompute{delay: 5 * time.Second} + logger := slog.Default() + rc := NewResilientCompute(mock, logger, ResilientComputeOpts{ + CallTimeout: 100 * time.Millisecond, + }) + + ctx := context.Background() + start := time.Now() + err := rc.StartInstance(ctx, "x") + elapsed := time.Since(start) + + if err == nil { + t.Fatal("expected timeout error") + } + // Should complete much faster than 5s. + if elapsed > 2*time.Second { + t.Fatalf("timeout not enforced, took %v", elapsed) + } +} + +func TestResilientComputeUnwrap(t *testing.T) { + mock := &mockCompute{} + rc := NewResilientCompute(mock, slog.Default(), ResilientComputeOpts{}) + if rc.Unwrap() != mock { + t.Fatal("Unwrap should return the inner backend") + } +} + +func TestResilientComputePingBypassesBulkhead(t *testing.T) { + // Ping should work even when the bulkhead is completely full. + mock := &mockCompute{delay: 500 * time.Millisecond} + logger := slog.Default() + rc := NewResilientCompute(mock, logger, ResilientComputeOpts{ + BulkheadMaxConc: 1, + BulkheadWait: 10 * time.Millisecond, + }) + + ctx := context.Background() + + // Saturate the bulkhead. + started := make(chan struct{}) + go func() { + close(started) + _ = rc.StartInstance(ctx, "x") + }() + <-started + time.Sleep(20 * time.Millisecond) + + // Ping should still work (bypasses bulkhead). + err := rc.Ping(ctx) + // err may or may not be nil depending on timing, but it must NOT be ErrBulkheadFull. + if errors.Is(err, ErrBulkheadFull) { + t.Fatal("Ping should bypass bulkhead") + } +} + +// ---------- test helpers ---------- + +func assertNoErr(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} diff --git a/internal/platform/resilient_dns.go b/internal/platform/resilient_dns.go new file mode 100644 index 000000000..0f1419c2e --- /dev/null +++ b/internal/platform/resilient_dns.go @@ -0,0 +1,125 @@ +package platform + +import ( + "context" + "log/slog" + "time" + + "github.com/poyrazk/thecloud/internal/core/ports" +) + +// ResilientDNSOpts configures the resilient DNS wrapper. +type ResilientDNSOpts struct { + CallTimeout time.Duration // Per-call timeout. Default: 10s. + CBThreshold int // Failures to trip. Default: 5. + CBResetTimeout time.Duration // Open→half-open wait. Default: 30s. +} + +func (o ResilientDNSOpts) withDefaults() ResilientDNSOpts { + if o.CallTimeout <= 0 { + o.CallTimeout = 10 * time.Second + } + if o.CBThreshold <= 0 { + o.CBThreshold = 5 + } + if o.CBResetTimeout <= 0 { + o.CBResetTimeout = 30 * time.Second + } + return o +} + +// ResilientDNS wraps a DNSBackend with circuit breaker and per-call timeouts. +// DNS calls are lightweight so no bulkhead is applied (PowerDNS HTTP API is +// already serialized). +type ResilientDNS struct { + inner ports.DNSBackend + cb *CircuitBreaker + logger *slog.Logger + opts ResilientDNSOpts +} + +// NewResilientDNS decorates inner with resilience primitives. 
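+// A minimal wiring sketch (mirrors the setup code; variable names are illustrative):
+//
+//	dns := NewResilientDNS(pdnsBackend, logger, ResilientDNSOpts{})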
+func NewResilientDNS(inner ports.DNSBackend, logger *slog.Logger, opts ResilientDNSOpts) *ResilientDNS { + opts = opts.withDefaults() + name := "dns-powerdns" + + cb := NewCircuitBreakerWithOpts(CircuitBreakerOpts{ + Name: name, + Threshold: opts.CBThreshold, + ResetTimeout: opts.CBResetTimeout, + SuccessRequired: 2, + OnStateChange: func(n string, from, to State) { + logger.Warn("circuit breaker state change", + "breaker", n, "from", from.String(), "to", to.String()) + }, + }) + + return &ResilientDNS{ + inner: inner, + cb: cb, + logger: logger.With("adapter", name), + opts: opts, + } +} + +func (r *ResilientDNS) callProtected(ctx context.Context, fn func(ctx context.Context) error) error { + return r.cb.Execute(func() error { + ctx2, cancel := context.WithTimeout(ctx, r.opts.CallTimeout) + defer cancel() + return fn(ctx2) + }) +} + +// ---------- Zone Operations ---------- + +func (r *ResilientDNS) CreateZone(ctx context.Context, zoneName string, nameservers []string) error { + return r.callProtected(ctx, func(ctx context.Context) error { + return r.inner.CreateZone(ctx, zoneName, nameservers) + }) +} + +func (r *ResilientDNS) DeleteZone(ctx context.Context, zoneName string) error { + return r.callProtected(ctx, func(ctx context.Context) error { + return r.inner.DeleteZone(ctx, zoneName) + }) +} + +func (r *ResilientDNS) GetZone(ctx context.Context, zoneName string) (*ports.ZoneInfo, error) { + var info *ports.ZoneInfo + err := r.callProtected(ctx, func(ctx context.Context) error { + var e error + info, e = r.inner.GetZone(ctx, zoneName) + return e + }) + return info, err +} + +// ---------- Record Operations ---------- + +func (r *ResilientDNS) AddRecords(ctx context.Context, zoneName string, records []ports.RecordSet) error { + return r.callProtected(ctx, func(ctx context.Context) error { + return r.inner.AddRecords(ctx, zoneName, records) + }) +} + +func (r *ResilientDNS) UpdateRecords(ctx context.Context, zoneName string, records []ports.RecordSet) error { + return r.callProtected(ctx, func(ctx context.Context) error { + return r.inner.UpdateRecords(ctx, zoneName, records) + }) +} + +func (r *ResilientDNS) DeleteRecords(ctx context.Context, zoneName string, name string, recordType string) error { + return r.callProtected(ctx, func(ctx context.Context) error { + return r.inner.DeleteRecords(ctx, zoneName, name, recordType) + }) +} + +func (r *ResilientDNS) ListRecords(ctx context.Context, zoneName string) ([]ports.RecordSet, error) { + var records []ports.RecordSet + err := r.callProtected(ctx, func(ctx context.Context) error { + var e error + records, e = r.inner.ListRecords(ctx, zoneName) + return e + }) + return records, err +} diff --git a/internal/platform/resilient_lb.go b/internal/platform/resilient_lb.go new file mode 100644 index 000000000..d646ef701 --- /dev/null +++ b/internal/platform/resilient_lb.go @@ -0,0 +1,97 @@ +package platform + +import ( + "context" + "log/slog" + "time" + + "github.com/google/uuid" + "github.com/poyrazk/thecloud/internal/core/domain" + "github.com/poyrazk/thecloud/internal/core/ports" +) + +// ResilientLBOpts configures the resilient load balancer proxy wrapper. +type ResilientLBOpts struct { + CallTimeout time.Duration // Per-call timeout. Default: 30s. + LongTimeout time.Duration // Timeout for DeployProxy (container launch). Default: 2m. + CBThreshold int // Failures to trip. Default: 5. + CBResetTimeout time.Duration // Open→half-open wait. Default: 30s. 
+} + +func (o ResilientLBOpts) withDefaults() ResilientLBOpts { + if o.CallTimeout <= 0 { + o.CallTimeout = 30 * time.Second + } + if o.LongTimeout <= 0 { + o.LongTimeout = 2 * time.Minute + } + if o.CBThreshold <= 0 { + o.CBThreshold = 5 + } + if o.CBResetTimeout <= 0 { + o.CBResetTimeout = 30 * time.Second + } + return o +} + +// ResilientLB wraps an LBProxyAdapter with circuit breaker and per-call timeouts. +// LB proxy has only 3 methods so no bulkhead is needed — the compute bulkhead +// already limits the underlying container/VM creation. +type ResilientLB struct { + inner ports.LBProxyAdapter + cb *CircuitBreaker + logger *slog.Logger + opts ResilientLBOpts +} + +// NewResilientLB decorates inner with resilience primitives. +func NewResilientLB(inner ports.LBProxyAdapter, logger *slog.Logger, opts ResilientLBOpts) *ResilientLB { + opts = opts.withDefaults() + name := "lb-proxy" + + cb := NewCircuitBreakerWithOpts(CircuitBreakerOpts{ + Name: name, + Threshold: opts.CBThreshold, + ResetTimeout: opts.CBResetTimeout, + SuccessRequired: 2, + OnStateChange: func(n string, from, to State) { + logger.Warn("circuit breaker state change", + "breaker", n, "from", from.String(), "to", to.String()) + }, + }) + + return &ResilientLB{ + inner: inner, + cb: cb, + logger: logger.With("adapter", name), + opts: opts, + } +} + +func (r *ResilientLB) DeployProxy(ctx context.Context, lb *domain.LoadBalancer, targets []*domain.LBTarget) (string, error) { + var addr string + err := r.cb.Execute(func() error { + ctx2, cancel := context.WithTimeout(ctx, r.opts.LongTimeout) + defer cancel() + var e error + addr, e = r.inner.DeployProxy(ctx2, lb, targets) + return e + }) + return addr, err +} + +func (r *ResilientLB) RemoveProxy(ctx context.Context, lbID uuid.UUID) error { + return r.cb.Execute(func() error { + ctx2, cancel := context.WithTimeout(ctx, r.opts.CallTimeout) + defer cancel() + return r.inner.RemoveProxy(ctx2, lbID) + }) +} + +func (r *ResilientLB) UpdateProxyConfig(ctx context.Context, lb *domain.LoadBalancer, targets []*domain.LBTarget) error { + return r.cb.Execute(func() error { + ctx2, cancel := context.WithTimeout(ctx, r.opts.CallTimeout) + defer cancel() + return r.inner.UpdateProxyConfig(ctx2, lb, targets) + }) +} diff --git a/internal/platform/resilient_network.go b/internal/platform/resilient_network.go new file mode 100644 index 000000000..69fba40cb --- /dev/null +++ b/internal/platform/resilient_network.go @@ -0,0 +1,211 @@ +package platform + +import ( + "context" + "fmt" + "log/slog" + "time" + + "github.com/poyrazk/thecloud/internal/core/ports" +) + +// ResilientNetworkOpts configures the resilient network wrapper. +type ResilientNetworkOpts struct { + CallTimeout time.Duration // Per-call timeout. Default: 30s. + CBThreshold int // Failures to trip. Default: 5. + CBResetTimeout time.Duration // Open→half-open wait. Default: 30s. + BulkheadMaxConc int // Max concurrent calls. Default: 15. + BulkheadWait time.Duration // Bulkhead slot wait. Default: 10s. +} + +func (o ResilientNetworkOpts) withDefaults() ResilientNetworkOpts { + if o.CallTimeout <= 0 { + o.CallTimeout = 30 * time.Second + } + if o.CBThreshold <= 0 { + o.CBThreshold = 5 + } + if o.CBResetTimeout <= 0 { + o.CBResetTimeout = 30 * time.Second + } + if o.BulkheadMaxConc <= 0 { + o.BulkheadMaxConc = 15 + } + if o.BulkheadWait <= 0 { + o.BulkheadWait = 10 * time.Second + } + return o +} + +// ResilientNetwork wraps a NetworkBackend with circuit breaker, bulkhead, +// and per-call timeouts. 
It implements the ports.NetworkBackend interface. +type ResilientNetwork struct { + inner ports.NetworkBackend + cb *CircuitBreaker + bulkhead *Bulkhead + logger *slog.Logger + opts ResilientNetworkOpts +} + +// NewResilientNetwork decorates inner with resilience primitives. +func NewResilientNetwork(inner ports.NetworkBackend, logger *slog.Logger, opts ResilientNetworkOpts) *ResilientNetwork { + opts = opts.withDefaults() + name := fmt.Sprintf("network-%s", inner.Type()) + + cb := NewCircuitBreakerWithOpts(CircuitBreakerOpts{ + Name: name, + Threshold: opts.CBThreshold, + ResetTimeout: opts.CBResetTimeout, + SuccessRequired: 2, + OnStateChange: func(n string, from, to State) { + logger.Warn("circuit breaker state change", + "breaker", n, "from", from.String(), "to", to.String()) + }, + }) + + bh := NewBulkhead(BulkheadOpts{ + Name: name, + MaxConc: opts.BulkheadMaxConc, + WaitTimeout: opts.BulkheadWait, + }) + + return &ResilientNetwork{ + inner: inner, + cb: cb, + bulkhead: bh, + logger: logger.With("adapter", name), + opts: opts, + } +} + +// callProtected runs fn through bulkhead → circuit breaker → timeout. +func (r *ResilientNetwork) callProtected(ctx context.Context, fn func(ctx context.Context) error) error { + return r.bulkhead.Execute(ctx, func() error { + return r.cb.Execute(func() error { + ctx2, cancel := context.WithTimeout(ctx, r.opts.CallTimeout) + defer cancel() + return fn(ctx2) + }) + }) +} + +// ---------- Bridge Management ---------- + +func (r *ResilientNetwork) CreateBridge(ctx context.Context, name string, vxlanID int) error { + return r.callProtected(ctx, func(ctx context.Context) error { + return r.inner.CreateBridge(ctx, name, vxlanID) + }) +} + +func (r *ResilientNetwork) DeleteBridge(ctx context.Context, name string) error { + return r.callProtected(ctx, func(ctx context.Context) error { + return r.inner.DeleteBridge(ctx, name) + }) +} + +func (r *ResilientNetwork) ListBridges(ctx context.Context) ([]string, error) { + var bridges []string + err := r.callProtected(ctx, func(ctx context.Context) error { + var e error + bridges, e = r.inner.ListBridges(ctx) + return e + }) + return bridges, err +} + +// ---------- Port Management ---------- + +func (r *ResilientNetwork) AddPort(ctx context.Context, bridge, portName string) error { + return r.callProtected(ctx, func(ctx context.Context) error { + return r.inner.AddPort(ctx, bridge, portName) + }) +} + +func (r *ResilientNetwork) DeletePort(ctx context.Context, bridge, portName string) error { + return r.callProtected(ctx, func(ctx context.Context) error { + return r.inner.DeletePort(ctx, bridge, portName) + }) +} + +// ---------- VXLAN Tunnels ---------- + +func (r *ResilientNetwork) CreateVXLANTunnel(ctx context.Context, bridge string, vni int, remoteIP string) error { + return r.callProtected(ctx, func(ctx context.Context) error { + return r.inner.CreateVXLANTunnel(ctx, bridge, vni, remoteIP) + }) +} + +func (r *ResilientNetwork) DeleteVXLANTunnel(ctx context.Context, bridge string, remoteIP string) error { + return r.callProtected(ctx, func(ctx context.Context) error { + return r.inner.DeleteVXLANTunnel(ctx, bridge, remoteIP) + }) +} + +// ---------- Security Groups (Flow Rules) ---------- + +func (r *ResilientNetwork) AddFlowRule(ctx context.Context, bridge string, rule ports.FlowRule) error { + return r.callProtected(ctx, func(ctx context.Context) error { + return r.inner.AddFlowRule(ctx, bridge, rule) + }) +} + +func (r *ResilientNetwork) DeleteFlowRule(ctx context.Context, bridge string, match string) error 
{ + return r.callProtected(ctx, func(ctx context.Context) error { + return r.inner.DeleteFlowRule(ctx, bridge, match) + }) +} + +func (r *ResilientNetwork) ListFlowRules(ctx context.Context, bridge string) ([]ports.FlowRule, error) { + var rules []ports.FlowRule + err := r.callProtected(ctx, func(ctx context.Context) error { + var e error + rules, e = r.inner.ListFlowRules(ctx, bridge) + return e + }) + return rules, err +} + +// ---------- Veth Pair Management ---------- + +func (r *ResilientNetwork) CreateVethPair(ctx context.Context, hostEnd, containerEnd string) error { + return r.callProtected(ctx, func(ctx context.Context) error { + return r.inner.CreateVethPair(ctx, hostEnd, containerEnd) + }) +} + +func (r *ResilientNetwork) AttachVethToBridge(ctx context.Context, bridge, vethEnd string) error { + return r.callProtected(ctx, func(ctx context.Context) error { + return r.inner.AttachVethToBridge(ctx, bridge, vethEnd) + }) +} + +func (r *ResilientNetwork) DeleteVethPair(ctx context.Context, hostEnd string) error { + return r.callProtected(ctx, func(ctx context.Context) error { + return r.inner.DeleteVethPair(ctx, hostEnd) + }) +} + +func (r *ResilientNetwork) SetVethIP(ctx context.Context, vethEnd, ip, cidr string) error { + return r.callProtected(ctx, func(ctx context.Context) error { + return r.inner.SetVethIP(ctx, vethEnd, ip, cidr) + }) +} + +// ---------- Health ---------- + +func (r *ResilientNetwork) Ping(ctx context.Context) error { + return r.cb.Execute(func() error { + ctx2, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + return r.inner.Ping(ctx2) + }) +} + +func (r *ResilientNetwork) Type() string { + return r.inner.Type() +} + +// Unwrap returns the underlying NetworkBackend. +func (r *ResilientNetwork) Unwrap() ports.NetworkBackend { + return r.inner +} diff --git a/internal/platform/resilient_storage.go b/internal/platform/resilient_storage.go new file mode 100644 index 000000000..4be7798d4 --- /dev/null +++ b/internal/platform/resilient_storage.go @@ -0,0 +1,166 @@ +package platform + +import ( + "context" + "fmt" + "log/slog" + "time" + + "github.com/poyrazk/thecloud/internal/core/ports" +) + +// ResilientStorageOpts configures the resilient storage wrapper. +type ResilientStorageOpts struct { + CallTimeout time.Duration // Per-call timeout. Default: 30s. + LongCallTimeout time.Duration // Timeout for snapshot/restore. Default: 5m. + CBThreshold int // Failures to trip. Default: 5. + CBResetTimeout time.Duration // Open→half-open wait. Default: 30s. + BulkheadMaxConc int // Max concurrent calls. Default: 10. + BulkheadWait time.Duration // Bulkhead slot wait. Default: 10s. +} + +func (o ResilientStorageOpts) withDefaults() ResilientStorageOpts { + if o.CallTimeout <= 0 { + o.CallTimeout = 30 * time.Second + } + if o.LongCallTimeout <= 0 { + o.LongCallTimeout = 5 * time.Minute + } + if o.CBThreshold <= 0 { + o.CBThreshold = 5 + } + if o.CBResetTimeout <= 0 { + o.CBResetTimeout = 30 * time.Second + } + if o.BulkheadMaxConc <= 0 { + o.BulkheadMaxConc = 10 + } + if o.BulkheadWait <= 0 { + o.BulkheadWait = 10 * time.Second + } + return o +} + +// ResilientStorage wraps a StorageBackend with circuit breaker, bulkhead, +// and per-call timeouts. +type ResilientStorage struct { + inner ports.StorageBackend + cb *CircuitBreaker + bulkhead *Bulkhead + logger *slog.Logger + opts ResilientStorageOpts +} + +// NewResilientStorage decorates inner with resilience primitives. 
+func NewResilientStorage(inner ports.StorageBackend, logger *slog.Logger, opts ResilientStorageOpts) *ResilientStorage { + opts = opts.withDefaults() + name := fmt.Sprintf("storage-%s", inner.Type()) + + cb := NewCircuitBreakerWithOpts(CircuitBreakerOpts{ + Name: name, + Threshold: opts.CBThreshold, + ResetTimeout: opts.CBResetTimeout, + SuccessRequired: 2, + OnStateChange: func(n string, from, to State) { + logger.Warn("circuit breaker state change", + "breaker", n, "from", from.String(), "to", to.String()) + }, + }) + + bh := NewBulkhead(BulkheadOpts{ + Name: name, + MaxConc: opts.BulkheadMaxConc, + WaitTimeout: opts.BulkheadWait, + }) + + return &ResilientStorage{ + inner: inner, + cb: cb, + bulkhead: bh, + logger: logger.With("adapter", name), + opts: opts, + } +} + +func (r *ResilientStorage) callProtected(ctx context.Context, timeout time.Duration, fn func(ctx context.Context) error) error { + return r.bulkhead.Execute(ctx, func() error { + return r.cb.Execute(func() error { + ctx2, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + return fn(ctx2) + }) + }) +} + +func (r *ResilientStorage) CreateVolume(ctx context.Context, name string, sizeGB int) (string, error) { + var path string + err := r.callProtected(ctx, r.opts.CallTimeout, func(ctx context.Context) error { + var e error + path, e = r.inner.CreateVolume(ctx, name, sizeGB) + return e + }) + return path, err +} + +func (r *ResilientStorage) DeleteVolume(ctx context.Context, name string) error { + return r.callProtected(ctx, r.opts.CallTimeout, func(ctx context.Context) error { + return r.inner.DeleteVolume(ctx, name) + }) +} + +func (r *ResilientStorage) ResizeVolume(ctx context.Context, name string, newSizeGB int) error { + return r.callProtected(ctx, r.opts.CallTimeout, func(ctx context.Context) error { + return r.inner.ResizeVolume(ctx, name, newSizeGB) + }) +} + +func (r *ResilientStorage) AttachVolume(ctx context.Context, volumeName, instanceID string) (string, error) { + var devPath string + err := r.callProtected(ctx, r.opts.CallTimeout, func(ctx context.Context) error { + var e error + devPath, e = r.inner.AttachVolume(ctx, volumeName, instanceID) + return e + }) + return devPath, err +} + +func (r *ResilientStorage) DetachVolume(ctx context.Context, volumeName, instanceID string) error { + return r.callProtected(ctx, r.opts.CallTimeout, func(ctx context.Context) error { + return r.inner.DetachVolume(ctx, volumeName, instanceID) + }) +} + +func (r *ResilientStorage) CreateSnapshot(ctx context.Context, volumeName, snapshotName string) error { + return r.callProtected(ctx, r.opts.LongCallTimeout, func(ctx context.Context) error { + return r.inner.CreateSnapshot(ctx, volumeName, snapshotName) + }) +} + +func (r *ResilientStorage) RestoreSnapshot(ctx context.Context, volumeName, snapshotName string) error { + return r.callProtected(ctx, r.opts.LongCallTimeout, func(ctx context.Context) error { + return r.inner.RestoreSnapshot(ctx, volumeName, snapshotName) + }) +} + +func (r *ResilientStorage) DeleteSnapshot(ctx context.Context, snapshotName string) error { + return r.callProtected(ctx, r.opts.CallTimeout, func(ctx context.Context) error { + return r.inner.DeleteSnapshot(ctx, snapshotName) + }) +} + +func (r *ResilientStorage) Ping(ctx context.Context) error { + return r.cb.Execute(func() error { + ctx2, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + return r.inner.Ping(ctx2) + }) +} + +func (r *ResilientStorage) Type() string { + return r.inner.Type() +} + +// Unwrap returns the 
underlying StorageBackend. +func (r *ResilientStorage) Unwrap() ports.StorageBackend { + return r.inner +} From 01f612d8e8b7b201aa88bd1116393bee9126959f Mon Sep 17 00:00:00 2001 From: jackthepunished Date: Tue, 10 Mar 2026 05:40:51 +0300 Subject: [PATCH 05/12] feat(ha): Wire resilient backends and update workers for HA - Wrap all backends with resilient decorators in main.go: NewResilientCompute, NewResilientStorage, NewResilientNetwork, NewResilientLB - Wrap DNS backend with resilient decorator in dependencies.go - Create PgLeaderElector and wire into ServiceConfig - Update ProvisionWorker, ClusterWorker, PipelineWorker to use: * DurableTaskQueue (Redis Streams with consumer groups) * ExecutionLedger for idempotent job processing * Bounded concurrency via semaphore (provision=20, cluster=10, pipeline=5) - Update workers to use Receive/Ack/Nack pattern for exactly-once delivery - Add role validation tests --- cmd/api/main.go | 39 +++- cmd/api/main_test.go | 87 +++++++++ internal/api/setup/dependencies.go | 120 ++++++++---- internal/workers/cluster_worker.go | 227 +++++++++++++++++----- internal/workers/cluster_worker_test.go | 67 ++++++- internal/workers/pipeline_worker.go | 155 +++++++++++++-- internal/workers/provision_worker.go | 189 +++++++++++++++--- internal/workers/provision_worker_test.go | 114 +++++++---- 8 files changed, 808 insertions(+), 190 deletions(-) diff --git a/cmd/api/main.go b/cmd/api/main.go index 1b851bb7b..569086b85 100644 --- a/cmd/api/main.go +++ b/cmd/api/main.go @@ -130,16 +130,28 @@ func run() error { defer db.Close() defer func() { _ = rdb.Close() }() - compute, storage, network, lbProxy, err := initBackends(deps, cfg, logger, db, rdb) + rawCompute, rawStorage, rawNetwork, rawLBProxy, err := initBackends(deps, cfg, logger, db, rdb) if err != nil { logger.Error("backend initialization failed", "error", err) return err } + // Wrap raw backends with resilience decorators (circuit breaker, bulkhead, timeouts). + compute := platform.NewResilientCompute(rawCompute, logger, platform.ResilientComputeOpts{}) + storage := platform.NewResilientStorage(rawStorage, logger, platform.ResilientStorageOpts{}) + network := platform.NewResilientNetwork(rawNetwork, logger, platform.ResilientNetworkOpts{}) + lbProxy := platform.NewResilientLB(rawLBProxy, logger, platform.ResilientLBOpts{}) + repos := deps.InitRepositories(db, rdb) + + // Create leader elector for singleton worker coordination. + // When multiple worker replicas run, only one will hold leadership per key. 
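+	// Singleton workers constructed in setup.InitServices are wrapped with
+	// LeaderGuard against this elector, so each runs on exactly one node.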
+ leaderElector := postgres.NewPgLeaderElector(db, logger) + svcs, workers, err := deps.InitServices(setup.ServiceConfig{ Config: cfg, Repos: repos, Compute: compute, Storage: storage, Network: network, LBProxy: lbProxy, DB: db, RDB: rdb, Logger: logger, + LeaderElector: leaderElector, }) if err != nil { logger.Error("service initialization failed", "error", err) @@ -159,11 +171,18 @@ func run() error { } func runApplication(deps AppDeps, cfg *platform.Config, logger *slog.Logger, r *gin.Engine, workers *setup.Workers) { - role := os.Getenv("APP_ROLE") + role := os.Getenv("ROLE") if role == "" { role = "all" } + validRoles := map[string]bool{"api": true, "worker": true, "all": true} + if !validRoles[role] { + logger.Error("invalid ROLE value, must be one of: api, worker, all", "role", role) + return + } + logger.Info("starting with role", "role", role) + wg := &sync.WaitGroup{} workerCtx, workerCancel := context.WithCancel(context.Background()) @@ -171,9 +190,9 @@ func runApplication(deps AppDeps, cfg *platform.Config, logger *slog.Logger, r * runWorkers(workerCtx, wg, workers) } - srv := deps.NewHTTPServer(":"+cfg.Port, r) - + var srv *http.Server if role == "api" || role == "all" { + srv = deps.NewHTTPServer(":"+cfg.Port, r) go func() { logger.Info("starting compute-api", "port", cfg.Port) if err := deps.StartHTTPServer(srv); err != nil && !stdlib_errors.Is(err, http.ErrServerClosed) { @@ -181,25 +200,27 @@ func runApplication(deps AppDeps, cfg *platform.Config, logger *slog.Logger, r * } }() } else { - logger.Info("running in worker-only mode") + logger.Info("running in worker-only mode, HTTP server disabled") } quit := make(chan os.Signal, 1) deps.NotifySignals(quit, syscall.SIGINT, syscall.SIGTERM) <-quit - logger.Info("shutting down server...") + logger.Info("shutting down...") ctx, cancel := context.WithTimeout(context.Background(), defaultShutdownTimeout) defer cancel() - if err := deps.ShutdownHTTPServer(ctx, srv); err != nil { - logger.Error("server forced to shutdown", "error", err) + if srv != nil { + if err := deps.ShutdownHTTPServer(ctx, srv); err != nil { + logger.Error("server forced to shutdown", "error", err) + } } workerCancel() wg.Wait() - logger.Info("server exited") + logger.Info("shutdown complete") } type runner interface { diff --git a/cmd/api/main_test.go b/cmd/api/main_test.go index 6f1f81af2..791f872dc 100644 --- a/cmd/api/main_test.go +++ b/cmd/api/main_test.go @@ -167,6 +167,93 @@ func TestRunApplicationApiRoleStartsAndShutsDown(t *testing.T) { } } +func TestRunApplicationWorkerRoleDoesNotStartHTTP(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + t.Setenv("ROLE", "worker") + + deps := DefaultDeps() + + deps.NewHTTPServer = func(string, http.Handler) *http.Server { + t.Fatalf("NewHTTPServer should not be called in worker-only mode") + return nil + } + deps.StartHTTPServer = func(*http.Server) error { + t.Fatalf("StartHTTPServer should not be called in worker-only mode") + return nil + } + deps.ShutdownHTTPServer = func(context.Context, *http.Server) error { + t.Fatalf("ShutdownHTTPServer should not be called in worker-only mode") + return nil + } + deps.NotifySignals = func(c chan<- os.Signal, _ ...os.Signal) { + go func() { + // Give workers a moment to start, then signal shutdown + time.Sleep(50 * time.Millisecond) + c <- syscall.SIGTERM + }() + } + + runApplication(deps, &platform.Config{Port: "0"}, logger, gin.New(), &setup.Workers{}) + // If we reach here without t.Fatalf, the test passes — no HTTP server was touched. 
+} + +func TestRunApplicationDefaultsToAllRole(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + t.Setenv("ROLE", "") // Explicitly empty to verify default + + serverStarted := false + deps := DefaultDeps() + + deps.NewHTTPServer = func(addr string, handler http.Handler) *http.Server { + return &http.Server{ + Addr: addr, + Handler: handler, + ReadHeaderTimeout: 10 * time.Second, + } + } + deps.StartHTTPServer = func(*http.Server) error { + serverStarted = true + return http.ErrServerClosed + } + deps.ShutdownHTTPServer = func(context.Context, *http.Server) error { + return nil + } + deps.NotifySignals = func(c chan<- os.Signal, _ ...os.Signal) { + go func() { + time.Sleep(50 * time.Millisecond) + c <- syscall.SIGTERM + }() + } + + runApplication(deps, &platform.Config{Port: "0"}, logger, gin.New(), &setup.Workers{}) + + if !serverStarted { + t.Fatalf("expected HTTP server to start when ROLE defaults to 'all'") + } +} + +func TestRunApplicationInvalidRoleReturnsEarly(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + t.Setenv("ROLE", "invalid") + + deps := DefaultDeps() + + deps.NewHTTPServer = func(string, http.Handler) *http.Server { + t.Fatalf("NewHTTPServer should not be called for invalid role") + return nil + } + deps.StartHTTPServer = func(*http.Server) error { + t.Fatalf("StartHTTPServer should not be called for invalid role") + return nil + } + deps.NotifySignals = func(c chan<- os.Signal, _ ...os.Signal) { + t.Fatalf("NotifySignals should not be called for invalid role") + } + + // Should return immediately without starting anything + runApplication(deps, &platform.Config{Port: "0"}, logger, gin.New(), &setup.Workers{}) +} + // Stub helpers below keep main.go testable without altering production behavior. type stubDB struct{ closed bool } diff --git a/internal/api/setup/dependencies.go b/internal/api/setup/dependencies.go index 12120305c..6fc0fd55d 100644 --- a/internal/api/setup/dependencies.go +++ b/internal/api/setup/dependencies.go @@ -2,8 +2,10 @@ package setup import ( + "context" "fmt" "log/slog" + "sync" "strings" @@ -54,6 +56,8 @@ type Repositories struct { AutoScaling ports.AutoScalingRepository Accounting ports.AccountingRepository TaskQueue ports.TaskQueue + DurableQueue ports.DurableTaskQueue + Ledger ports.ExecutionLedger Image ports.ImageRepository Cluster ports.ClusterRepository Lifecycle ports.LifecycleRepository @@ -99,6 +103,8 @@ func InitRepositories(db postgres.DB, rdb *redisv9.Client) *Repositories { AutoScaling: postgres.NewAutoScalingRepo(db), Accounting: postgres.NewAccountingRepository(db), TaskQueue: redis.NewRedisTaskQueue(rdb), + DurableQueue: redis.NewDurableTaskQueue(rdb), + Ledger: postgres.NewExecutionLedger(db), Image: postgres.NewImageRepository(db), Cluster: postgres.NewClusterRepository(db), Lifecycle: postgres.NewLifecycleRepository(db), @@ -160,35 +166,46 @@ type Services struct { VPCPeering ports.VPCPeeringService } -// Workers struct to return background workers +// Runner is the interface that all background workers implement. +type Runner interface { + Run(context.Context, *sync.WaitGroup) +} + +// Workers struct to return background workers. +// Singleton workers are typed as Runner so they can be wrapped with LeaderGuard. +// Parallel consumers retain concrete types for direct configuration access. 
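+//
+// A guarded singleton runs like any other Runner; its Run blocks until
+// leadership for the key is acquired (illustrative sketch, variable names
+// are hypothetical):
+//
+//	guarded := workers.NewLeaderGuard(elector, "singleton:cron", cronWorker, logger)
+//	guarded.Run(ctx, wg)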
 type Workers struct {
-	LB                *services.LBWorker
-	AutoScaling       *services.AutoScalingWorker
-	Cron              *services.CronWorker
-	Container         *services.ContainerWorker
-	Pipeline          *workers.PipelineWorker
-	Provision         *workers.ProvisionWorker
-	Accounting        *workers.AccountingWorker
-	Cluster           *workers.ClusterWorker
-	Lifecycle         *workers.LifecycleWorker
-	ReplicaMonitor    *workers.ReplicaMonitor
-	ClusterReconciler *workers.ClusterReconciler
-	Healing           *workers.HealingWorker
-	DatabaseFailover  *workers.DatabaseFailoverWorker
-	Log               *workers.LogWorker
+	// Singleton workers (must run on exactly one node via leader election)
+	LB                Runner
+	AutoScaling       Runner
+	Cron              Runner
+	Container         Runner
+	Accounting        Runner
+	Lifecycle         Runner
+	ReplicaMonitor    Runner
+	ClusterReconciler Runner
+	Healing           Runner
+	DatabaseFailover  Runner
+	Log               Runner
+
+	// Parallel consumer workers (safe to run on multiple nodes)
+	Pipeline  *workers.PipelineWorker
+	Provision *workers.ProvisionWorker
+	Cluster   *workers.ClusterWorker
 }
 
 // ServiceConfig holds the dependencies required to initialize services
 type ServiceConfig struct {
-	Config  *platform.Config
-	Repos   *Repositories
-	Compute ports.ComputeBackend
-	Storage ports.StorageBackend
-	Network ports.NetworkBackend
-	LBProxy ports.LBProxyAdapter
-	DB      postgres.DB
-	RDB     *redisv9.Client
-	Logger  *slog.Logger
+	Config        *platform.Config
+	Repos         *Repositories
+	Compute       ports.ComputeBackend
+	Storage       ports.StorageBackend
+	Network       ports.NetworkBackend
+	LBProxy       ports.LBProxyAdapter
+	DB            postgres.DB
+	RDB           *redisv9.Client
+	Logger        *slog.Logger
+	LeaderElector ports.LeaderElector // nil disables leader election (single-instance mode)
 }
 
 // InitServices constructs core services and background workers.
@@ -216,8 +233,10 @@ func InitServices(c ServiceConfig) (*Services, *Workers, error) {
 	if err != nil {
 		return nil, nil, fmt.Errorf("failed to init powerdns backend: %w", err)
 	}
+	// Wrap DNS backend with resilience (circuit breaker + timeout).
+	resilientDNS := platform.NewResilientDNS(pdnsBackend, c.Logger, platform.ResilientDNSOpts{})
 	dnsSvc := services.NewDNSService(services.DNSServiceParams{
-		Repo: c.Repos.DNS, Backend: pdnsBackend, VpcRepo: c.Repos.Vpc,
+		Repo: c.Repos.DNS, Backend: resilientDNS, VpcRepo: c.Repos.Vpc,
 		AuditSvc: auditSvc, EventSvc: eventSvc, Logger: c.Logger,
 	})
@@ -293,7 +312,7 @@ func InitServices(c ServiceConfig) (*Services, *Workers, error) {
 	accountingWorker := workers.NewAccountingWorker(accountingSvc, c.Logger)
 	imageSvc := services.NewImageService(c.Repos.Image, fileStore, c.Logger)
 	iamSvc := services.NewIAMService(c.Repos.IAM, auditSvc, eventSvc, c.Logger)
-	provisionWorker := workers.NewProvisionWorker(instSvcConcrete, c.Repos.TaskQueue, c.Logger)
+	provisionWorker := workers.NewProvisionWorker(instSvcConcrete, c.Repos.DurableQueue, c.Repos.Ledger, c.Logger)
 	healingWorker := workers.NewHealingWorker(instSvcConcrete, c.Repos.Instance, c.Logger)
 
 	clusterSvc, clusterProvisioner, err := initClusterServices(c, vpcSvc, instSvcConcrete, secretSvc, storageSvc, lbSvc, sgSvc)
@@ -333,17 +352,46 @@ func InitServices(c ServiceConfig) (*Services, *Workers, error) {
 	// 7. High Availability & Monitoring
 	replicaMonitor := initReplicaMonitor(c)
 
+	// Helper: wrap a singleton worker with LeaderGuard when leader election is enabled.
+	// w is an interface value, so callers must pass a true nil Runner when a worker
+	// should be skipped (a typed-nil pointer would compare non-nil as an interface).
+ guardSingleton := func(key string, w Runner) Runner { + if w == nil || c.LeaderElector == nil { + return w + } + return workers.NewLeaderGuard(c.LeaderElector, key, w, c.Logger) + } + + lifecycleWorker := workers.NewLifecycleWorker(c.Repos.Lifecycle, storageSvc, c.Repos.Storage, c.Logger) + clusterReconciler := workers.NewClusterReconciler(c.Repos.Cluster, clusterProvisioner, c.Logger) + dbFailoverWorker := workers.NewDatabaseFailoverWorker(databaseSvc, c.Repos.Database, c.Logger) + logWorker := workers.NewLogWorker(logSvc, c.Logger) + + // For replicaMonitor, we must convert nil *ReplicaMonitor to nil Runner to avoid + // a non-nil interface wrapping a nil pointer. + var replicaMonitorRunner Runner + if replicaMonitor != nil { + replicaMonitorRunner = replicaMonitor + } + workersCollection := &Workers{ - LB: lbWorker, AutoScaling: asgWorker, Cron: cronWorker, Container: containerWorker, - Pipeline: workers.NewPipelineWorker(c.Repos.Pipeline, c.Repos.TaskQueue, c.Compute, c.Logger), - Provision: provisionWorker, Accounting: accountingWorker, - Cluster: workers.NewClusterWorker(c.Repos.Cluster, clusterProvisioner, c.Repos.TaskQueue, c.Logger), - Lifecycle: workers.NewLifecycleWorker(c.Repos.Lifecycle, storageSvc, c.Repos.Storage, c.Logger), - ReplicaMonitor: replicaMonitor, - ClusterReconciler: workers.NewClusterReconciler(c.Repos.Cluster, clusterProvisioner, c.Logger), - Healing: healingWorker, - DatabaseFailover: workers.NewDatabaseFailoverWorker(databaseSvc, c.Repos.Database, c.Logger), - Log: workers.NewLogWorker(logSvc, c.Logger), + // Singleton workers — wrapped with leader election + LB: guardSingleton("singleton:lb", lbWorker), + AutoScaling: guardSingleton("singleton:autoscaling", asgWorker), + Cron: guardSingleton("singleton:cron", cronWorker), + Container: guardSingleton("singleton:container", containerWorker), + Accounting: guardSingleton("singleton:accounting", accountingWorker), + Lifecycle: guardSingleton("singleton:lifecycle", lifecycleWorker), + ReplicaMonitor: guardSingleton("singleton:replica-monitor", replicaMonitorRunner), + ClusterReconciler: guardSingleton("singleton:cluster-reconciler", clusterReconciler), + Healing: guardSingleton("singleton:healing", healingWorker), + DatabaseFailover: guardSingleton("singleton:db-failover", dbFailoverWorker), + Log: guardSingleton("singleton:log", logWorker), + + // Parallel consumer workers — no leader election needed + Pipeline: workers.NewPipelineWorker(c.Repos.Pipeline, c.Repos.DurableQueue, c.Repos.Ledger, c.Compute, c.Logger), + Provision: provisionWorker, + Cluster: workers.NewClusterWorker(c.Repos.Cluster, clusterProvisioner, c.Repos.DurableQueue, c.Repos.Ledger, c.Logger), } return svcs, workersCollection, nil diff --git a/internal/workers/cluster_worker.go b/internal/workers/cluster_worker.go index 591559dd5..149eab9b8 100644 --- a/internal/workers/cluster_worker.go +++ b/internal/workers/cluster_worker.go @@ -4,7 +4,9 @@ package workers import ( "context" "encoding/json" + "fmt" "log/slog" + "os" "sync" "time" @@ -13,34 +15,56 @@ import ( "github.com/poyrazk/thecloud/internal/core/ports" ) +const ( + clusterQueue = "k8s_jobs" + clusterGroup = "cluster_workers" + clusterMaxWorkers = 10 + clusterReclaimMs = 5 * 60 * 1000 // 5 minutes + clusterReclaimN = 10 + clusterStaleThreshold = 15 * time.Minute +) + // ClusterWorker handles background tasks for Kubernetes cluster lifecycle management. 
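+// Jobs are consumed through a durable consumer group with at-least-once
+// delivery; when an execution ledger is configured, redelivered jobs whose
+// key has already completed are acked as no-ops.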
type ClusterWorker struct { - repo ports.ClusterRepository - provisioner ports.ClusterProvisioner - taskQueue ports.TaskQueue - logger *slog.Logger + repo ports.ClusterRepository + provisioner ports.ClusterProvisioner + taskQueue ports.DurableTaskQueue + ledger ports.ExecutionLedger + logger *slog.Logger + consumerName string } // NewClusterWorker creates a new ClusterWorker. -func NewClusterWorker(repo ports.ClusterRepository, provisioner ports.ClusterProvisioner, taskQueue ports.TaskQueue, logger *slog.Logger) *ClusterWorker { +func NewClusterWorker(repo ports.ClusterRepository, provisioner ports.ClusterProvisioner, taskQueue ports.DurableTaskQueue, ledger ports.ExecutionLedger, logger *slog.Logger) *ClusterWorker { + hostname, _ := os.Hostname() + if hostname == "" { + hostname = "cluster-worker" + } return &ClusterWorker{ - repo: repo, - provisioner: provisioner, - taskQueue: taskQueue, - logger: logger, + repo: repo, + provisioner: provisioner, + taskQueue: taskQueue, + ledger: ledger, + logger: logger, + consumerName: hostname, } } -const ( - queuePollBackoff = 1 * time.Second - maxConcurrentClusts = 10 -) - func (w *ClusterWorker) Run(ctx context.Context, wg *sync.WaitGroup) { defer wg.Done() - w.logger.Info("starting cluster worker", "concurrency", maxConcurrentClusts) + w.logger.Info("starting cluster worker", + "consumer", w.consumerName, + "concurrency", clusterMaxWorkers, + ) + + if err := w.taskQueue.EnsureGroup(ctx, clusterQueue, clusterGroup); err != nil { + w.logger.Error("failed to ensure cluster consumer group", "error", err) + return + } + + sem := make(chan struct{}, clusterMaxWorkers) - sem := make(chan struct{}, maxConcurrentClusts) + go w.reclaimLoop(ctx, sem) for { select { @@ -48,59 +72,113 @@ func (w *ClusterWorker) Run(ctx context.Context, wg *sync.WaitGroup) { w.logger.Info("stopping cluster worker") return default: - msg, err := w.taskQueue.Dequeue(ctx, "k8s_jobs") + msg, err := w.taskQueue.Receive(ctx, clusterQueue, clusterGroup, w.consumerName) if err != nil { - w.logger.Error("failed to dequeue cluster job", "error", err) - time.Sleep(queuePollBackoff) + w.logger.Error("failed to receive cluster job", "error", err) + time.Sleep(1 * time.Second) continue } - if msg == "" { - time.Sleep(queuePollBackoff) + if msg == nil { continue } var job domain.ClusterJob - if err := json.Unmarshal([]byte(msg), &job); err != nil { - w.logger.Error("failed to unmarshal cluster job", "error", err) + if err := json.Unmarshal([]byte(msg.Payload), &job); err != nil { + w.logger.Error("failed to unmarshal cluster job", + "error", err, "msg_id", msg.ID) + _ = w.taskQueue.Ack(ctx, clusterQueue, clusterGroup, msg.ID) continue } - w.logger.Info("processing cluster job", "cluster_id", job.ClusterID, "type", job.Type) + w.logger.Info("processing cluster job", + "cluster_id", job.ClusterID, + "type", job.Type, + "msg_id", msg.ID, + ) sem <- struct{}{} - go func() { + go func(m *ports.DurableMessage, j domain.ClusterJob) { defer func() { <-sem }() - w.processJob(job) - }() + w.processJob(ctx, m, j) + }(msg, job) } } } -func (w *ClusterWorker) processJob(job domain.ClusterJob) { - // Root context for background task +func (w *ClusterWorker) processJob(workerCtx context.Context, msg *ports.DurableMessage, job domain.ClusterJob) { + jobKey := fmt.Sprintf("cluster:%s:%s", job.Type, job.ClusterID) + + // Idempotency check. 
+ if w.ledger != nil { + acquired, err := w.ledger.TryAcquire(workerCtx, jobKey, clusterStaleThreshold) + if err != nil { + w.logger.Error("execution ledger error", + "cluster_id", job.ClusterID, "msg_id", msg.ID, "error", err) + _ = w.taskQueue.Nack(workerCtx, clusterQueue, clusterGroup, msg.ID) + return + } + if !acquired { + w.logger.Info("skipping duplicate cluster job", + "cluster_id", job.ClusterID, "type", job.Type, "msg_id", msg.ID) + _ = w.taskQueue.Ack(workerCtx, clusterQueue, clusterGroup, msg.ID) + return + } + } + ctx := appcontext.WithUserID(context.Background(), job.UserID) cluster, err := w.repo.GetByID(ctx, job.ClusterID) if err != nil { - w.logger.Error("failed to fetch cluster for job", "cluster_id", job.ClusterID, "error", err) + w.logger.Error("failed to fetch cluster for job", + "cluster_id", job.ClusterID, "msg_id", msg.ID, "error", err) + if w.ledger != nil { + _ = w.ledger.MarkFailed(workerCtx, jobKey, err.Error()) + } + _ = w.taskQueue.Nack(workerCtx, clusterQueue, clusterGroup, msg.ID) return } if cluster == nil { - w.logger.Error("cluster not found for job", "cluster_id", job.ClusterID) + w.logger.Error("cluster not found for job", + "cluster_id", job.ClusterID, "msg_id", msg.ID) + // Ack — cluster was deleted, nothing to do. + if w.ledger != nil { + _ = w.ledger.MarkComplete(workerCtx, jobKey, "cluster_not_found") + } + _ = w.taskQueue.Ack(workerCtx, clusterQueue, clusterGroup, msg.ID) return } + var processErr error switch job.Type { case domain.ClusterJobProvision: - w.handleProvision(ctx, cluster) + processErr = w.handleProvision(ctx, cluster) case domain.ClusterJobDeprovision: - w.handleDeprovision(ctx, cluster) + processErr = w.handleDeprovision(ctx, cluster) case domain.ClusterJobUpgrade: - w.handleUpgrade(ctx, cluster, job.Version) + processErr = w.handleUpgrade(ctx, cluster, job.Version) + } + + if processErr != nil { + w.logger.Error("cluster job failed", + "cluster_id", job.ClusterID, "type", job.Type, + "msg_id", msg.ID, "error", processErr) + if w.ledger != nil { + _ = w.ledger.MarkFailed(workerCtx, jobKey, processErr.Error()) + } + _ = w.taskQueue.Nack(workerCtx, clusterQueue, clusterGroup, msg.ID) + return + } + + if w.ledger != nil { + _ = w.ledger.MarkComplete(workerCtx, jobKey, "ok") + } + if err := w.taskQueue.Ack(workerCtx, clusterQueue, clusterGroup, msg.ID); err != nil { + w.logger.Error("failed to ack cluster job", + "cluster_id", job.ClusterID, "msg_id", msg.ID, "error", err) } } -func (w *ClusterWorker) handleProvision(ctx context.Context, cluster *domain.Cluster) { +func (w *ClusterWorker) handleProvision(ctx context.Context, cluster *domain.Cluster) error { cluster.Status = domain.ClusterStatusProvisioning cluster.UpdatedAt = time.Now() _ = w.repo.Update(ctx, cluster) @@ -108,50 +186,93 @@ func (w *ClusterWorker) handleProvision(ctx context.Context, cluster *domain.Clu if err := w.provisioner.Provision(ctx, cluster); err != nil { w.logger.Error("provisioning failed", "cluster_id", cluster.ID, "error", err) cluster.Status = domain.ClusterStatusFailed - } else { - w.logger.Info("provisioning succeeded", "cluster_id", cluster.ID) - cluster.Status = domain.ClusterStatusRunning + cluster.UpdatedAt = time.Now() + cluster.JobID = nil + _ = w.repo.Update(ctx, cluster) + return err } + w.logger.Info("provisioning succeeded", "cluster_id", cluster.ID) + cluster.Status = domain.ClusterStatusRunning cluster.UpdatedAt = time.Now() - cluster.JobID = nil // Clear job ID + cluster.JobID = nil _ = w.repo.Update(ctx, cluster) + return nil } -func (w 
*ClusterWorker) handleDeprovision(ctx context.Context, cluster *domain.Cluster) { +func (w *ClusterWorker) handleDeprovision(ctx context.Context, cluster *domain.Cluster) error { cluster.Status = domain.ClusterStatusDeleting cluster.UpdatedAt = time.Now() _ = w.repo.Update(ctx, cluster) if err := w.provisioner.Deprovision(ctx, cluster); err != nil { w.logger.Error("deprovisioning failed", "cluster_id", cluster.ID, "error", err) - // We might still mark it as failed or just leave it - } else { - w.logger.Info("deprovisioning succeeded", "cluster_id", cluster.ID) - _ = w.repo.Delete(ctx, cluster.ID) - return + cluster.UpdatedAt = time.Now() + cluster.JobID = nil + _ = w.repo.Update(ctx, cluster) + return err } - cluster.UpdatedAt = time.Now() - cluster.JobID = nil - _ = w.repo.Update(ctx, cluster) + w.logger.Info("deprovisioning succeeded", "cluster_id", cluster.ID) + _ = w.repo.Delete(ctx, cluster.ID) + return nil } -func (w *ClusterWorker) handleUpgrade(ctx context.Context, cluster *domain.Cluster, version string) { +func (w *ClusterWorker) handleUpgrade(ctx context.Context, cluster *domain.Cluster, version string) error { cluster.Status = domain.ClusterStatusUpgrading cluster.UpdatedAt = time.Now() _ = w.repo.Update(ctx, cluster) if err := w.provisioner.Upgrade(ctx, cluster, version); err != nil { w.logger.Error("upgrade failed", "cluster_id", cluster.ID, "error", err) - cluster.Status = domain.ClusterStatusRunning // Revert to running if failed - } else { - w.logger.Info("upgrade succeeded", "cluster_id", cluster.ID) cluster.Status = domain.ClusterStatusRunning - cluster.Version = version + cluster.UpdatedAt = time.Now() + cluster.JobID = nil + _ = w.repo.Update(ctx, cluster) + return err } + w.logger.Info("upgrade succeeded", "cluster_id", cluster.ID) + cluster.Status = domain.ClusterStatusRunning + cluster.Version = version cluster.UpdatedAt = time.Now() cluster.JobID = nil _ = w.repo.Update(ctx, cluster) + return nil +} + +func (w *ClusterWorker) reclaimLoop(ctx context.Context, sem chan struct{}) { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + msgs, err := w.taskQueue.ReclaimStale(ctx, clusterQueue, clusterGroup, w.consumerName, clusterReclaimMs, clusterReclaimN) + if err != nil { + w.logger.Warn("cluster reclaim error", "error", err) + continue + } + for _, m := range msgs { + var job domain.ClusterJob + if err := json.Unmarshal([]byte(m.Payload), &job); err != nil { + w.logger.Error("failed to unmarshal reclaimed cluster job", + "msg_id", m.ID, "error", err) + _ = w.taskQueue.Ack(ctx, clusterQueue, clusterGroup, m.ID) + continue + } + w.logger.Info("reclaimed stale cluster job", + "cluster_id", job.ClusterID, "msg_id", m.ID) + + m := m + sem <- struct{}{} + go func() { + defer func() { <-sem }() + w.processJob(ctx, &m, job) + }() + } + } + } } diff --git a/internal/workers/cluster_worker_test.go b/internal/workers/cluster_worker_test.go index 2bbcaccf9..050a4eebf 100644 --- a/internal/workers/cluster_worker_test.go +++ b/internal/workers/cluster_worker_test.go @@ -26,6 +26,37 @@ func (m *MockTaskQueue) Dequeue(ctx context.Context, queue string) (string, erro return args.String(0), args.Error(1) } +func (m *MockTaskQueue) EnsureGroup(ctx context.Context, queueName, groupName string) error { + args := m.Called(ctx, queueName, groupName) + return args.Error(0) +} + +func (m *MockTaskQueue) Receive(ctx context.Context, queueName, groupName, consumerName string) (*ports.DurableMessage, error) { 
+ args := m.Called(ctx, queueName, groupName, consumerName) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*ports.DurableMessage), args.Error(1) +} + +func (m *MockTaskQueue) Ack(ctx context.Context, queueName, groupName, messageID string) error { + args := m.Called(ctx, queueName, groupName, messageID) + return args.Error(0) +} + +func (m *MockTaskQueue) Nack(ctx context.Context, queueName, groupName, messageID string) error { + args := m.Called(ctx, queueName, groupName, messageID) + return args.Error(0) +} + +func (m *MockTaskQueue) ReclaimStale(ctx context.Context, queueName, groupName, consumerName string, minIdleMs int64, count int64) ([]ports.DurableMessage, error) { + args := m.Called(ctx, queueName, groupName, consumerName, minIdleMs, count) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]ports.DurableMessage), args.Error(1) +} + type MockClusterRepo struct{ mock.Mock } func (m *MockClusterRepo) Create(ctx context.Context, c *domain.Cluster) error { return nil } @@ -124,7 +155,7 @@ func TestClusterWorkerProcessProvisionJob(t *testing.T) { prov := new(MockProvisioner) logger := slog.New(slog.NewTextHandler(io.Discard, nil)) - worker := NewClusterWorker(repo, prov, tq, logger) + worker := NewClusterWorker(repo, prov, tq, nil, logger) clusterID := uuid.New() userID := uuid.New() @@ -135,17 +166,20 @@ func TestClusterWorkerProcessProvisionJob(t *testing.T) { ClusterID: clusterID, UserID: userID, } + msg := &ports.DurableMessage{ID: "1-0", Payload: "", Queue: clusterQueue} repo.On("GetByID", mock.Anything, clusterID).Return(cluster, nil) repo.On("Update", mock.Anything, mock.MatchedBy(func(c *domain.Cluster) bool { return c.Status == domain.ClusterStatusProvisioning || c.Status == domain.ClusterStatusRunning })).Return(nil) prov.On("Provision", mock.Anything, cluster).Return(nil) + tq.On("Ack", mock.Anything, clusterQueue, clusterGroup, msg.ID).Return(nil) - worker.processJob(job) + worker.processJob(context.Background(), msg, job) repo.AssertExpectations(t) prov.AssertExpectations(t) + tq.AssertExpectations(t) } func TestClusterWorkerProcessDeprovisionJobSuccess(t *testing.T) { @@ -154,7 +188,7 @@ func TestClusterWorkerProcessDeprovisionJobSuccess(t *testing.T) { prov := new(MockProvisioner) logger := slog.New(slog.NewTextHandler(io.Discard, nil)) - worker := NewClusterWorker(repo, prov, tq, logger) + worker := NewClusterWorker(repo, prov, tq, nil, logger) clusterID := uuid.New() userID := uuid.New() @@ -165,16 +199,19 @@ func TestClusterWorkerProcessDeprovisionJobSuccess(t *testing.T) { ClusterID: clusterID, UserID: userID, } + msg := &ports.DurableMessage{ID: "2-0", Payload: "", Queue: clusterQueue} repo.On("GetByID", mock.Anything, clusterID).Return(cluster, nil) repo.On("Update", mock.Anything, mock.AnythingOfType("*domain.Cluster")).Return(nil) prov.On("Deprovision", mock.Anything, cluster).Return(nil) repo.On("Delete", mock.Anything, clusterID).Return(nil) + tq.On("Ack", mock.Anything, clusterQueue, clusterGroup, msg.ID).Return(nil) - worker.processJob(job) + worker.processJob(context.Background(), msg, job) repo.AssertExpectations(t) prov.AssertExpectations(t) + tq.AssertExpectations(t) } func TestClusterWorkerProcessDeprovisionJobFailure(t *testing.T) { @@ -183,7 +220,7 @@ func TestClusterWorkerProcessDeprovisionJobFailure(t *testing.T) { prov := new(MockProvisioner) logger := slog.New(slog.NewTextHandler(io.Discard, nil)) - worker := NewClusterWorker(repo, prov, tq, logger) + worker := NewClusterWorker(repo, 
prov, tq, nil, logger) clusterID := uuid.New() userID := uuid.New() @@ -194,15 +231,18 @@ func TestClusterWorkerProcessDeprovisionJobFailure(t *testing.T) { ClusterID: clusterID, UserID: userID, } + msg := &ports.DurableMessage{ID: "3-0", Payload: "", Queue: clusterQueue} repo.On("GetByID", mock.Anything, clusterID).Return(cluster, nil) repo.On("Update", mock.Anything, mock.AnythingOfType("*domain.Cluster")).Return(nil).Twice() prov.On("Deprovision", mock.Anything, cluster).Return(io.EOF) + tq.On("Nack", mock.Anything, clusterQueue, clusterGroup, msg.ID).Return(nil) - worker.processJob(job) + worker.processJob(context.Background(), msg, job) repo.AssertExpectations(t) prov.AssertExpectations(t) + tq.AssertExpectations(t) } func TestClusterWorkerProcessUpgradeJob(t *testing.T) { @@ -211,7 +251,7 @@ func TestClusterWorkerProcessUpgradeJob(t *testing.T) { prov := new(MockProvisioner) logger := slog.New(slog.NewTextHandler(io.Discard, nil)) - worker := NewClusterWorker(repo, prov, tq, logger) + worker := NewClusterWorker(repo, prov, tq, nil, logger) clusterID := uuid.New() userID := uuid.New() @@ -224,17 +264,20 @@ func TestClusterWorkerProcessUpgradeJob(t *testing.T) { UserID: userID, Version: version, } + msg := &ports.DurableMessage{ID: "4-0", Payload: "", Queue: clusterQueue} repo.On("GetByID", mock.Anything, clusterID).Return(cluster, nil) repo.On("Update", mock.Anything, mock.MatchedBy(func(c *domain.Cluster) bool { return c.Status == domain.ClusterStatusUpgrading || c.Status == domain.ClusterStatusRunning })).Return(nil).Twice() prov.On("Upgrade", mock.Anything, cluster, version).Return(nil) + tq.On("Ack", mock.Anything, clusterQueue, clusterGroup, msg.ID).Return(nil) - worker.processJob(job) + worker.processJob(context.Background(), msg, job) repo.AssertExpectations(t) prov.AssertExpectations(t) + tq.AssertExpectations(t) } func TestClusterWorkerProcessJobClusterNotFound(t *testing.T) { @@ -243,7 +286,7 @@ func TestClusterWorkerProcessJobClusterNotFound(t *testing.T) { prov := new(MockProvisioner) logger := slog.New(slog.NewTextHandler(io.Discard, nil)) - worker := NewClusterWorker(repo, prov, tq, logger) + worker := NewClusterWorker(repo, prov, tq, nil, logger) clusterID := uuid.New() userID := uuid.New() @@ -252,10 +295,14 @@ func TestClusterWorkerProcessJobClusterNotFound(t *testing.T) { ClusterID: clusterID, UserID: userID, } + msg := &ports.DurableMessage{ID: "5-0", Payload: "", Queue: clusterQueue} repo.On("GetByID", mock.Anything, clusterID).Return(nil, nil) + // Cluster not found -> ack to avoid infinite retries + tq.On("Ack", mock.Anything, clusterQueue, clusterGroup, msg.ID).Return(nil) - worker.processJob(job) + worker.processJob(context.Background(), msg, job) prov.AssertNotCalled(t, "Provision", mock.Anything, mock.Anything) + tq.AssertExpectations(t) } diff --git a/internal/workers/pipeline_worker.go b/internal/workers/pipeline_worker.go index dfce93696..d0077ef12 100644 --- a/internal/workers/pipeline_worker.go +++ b/internal/workers/pipeline_worker.go @@ -4,8 +4,10 @@ package workers import ( "context" "encoding/json" + "fmt" "io" "log/slog" + "os" "strings" "sync" "time" @@ -16,29 +18,59 @@ import ( "github.com/poyrazk/thecloud/internal/core/ports" ) -const pipelineQueueName = "pipeline_build_queue" +const ( + pipelineQueueName = "pipeline_build_queue" + pipelineGroup = "pipeline_workers" + pipelineMaxWorkers = 5 + pipelineReclaimMs = 10 * 60 * 1000 // 10 minutes (builds are longer) + pipelineReclaimN = 5 + // Stale threshold for idempotency ledger: builds can take up 
to 30 min, + // so a "running" entry older than this is considered abandoned. + pipelineStaleThreshold = 35 * time.Minute +) // PipelineWorker processes queued pipeline builds. type PipelineWorker struct { - repo ports.PipelineRepository - taskQueue ports.TaskQueue - compute ports.ComputeBackend - logger *slog.Logger + repo ports.PipelineRepository + taskQueue ports.DurableTaskQueue + ledger ports.ExecutionLedger + compute ports.ComputeBackend + logger *slog.Logger + consumerName string } // NewPipelineWorker creates a new PipelineWorker. -func NewPipelineWorker(repo ports.PipelineRepository, taskQueue ports.TaskQueue, compute ports.ComputeBackend, logger *slog.Logger) *PipelineWorker { +// If ledger is nil, idempotency checks are skipped. +func NewPipelineWorker(repo ports.PipelineRepository, taskQueue ports.DurableTaskQueue, ledger ports.ExecutionLedger, compute ports.ComputeBackend, logger *slog.Logger) *PipelineWorker { + hostname, _ := os.Hostname() + if hostname == "" { + hostname = "pipeline-worker" + } return &PipelineWorker{ - repo: repo, - taskQueue: taskQueue, - compute: compute, - logger: logger, + repo: repo, + taskQueue: taskQueue, + ledger: ledger, + compute: compute, + logger: logger, + consumerName: hostname, } } func (w *PipelineWorker) Run(ctx context.Context, wg *sync.WaitGroup) { defer wg.Done() - w.logger.Info("starting pipeline worker") + w.logger.Info("starting pipeline worker", + "consumer", w.consumerName, + "concurrency", pipelineMaxWorkers, + ) + + if err := w.taskQueue.EnsureGroup(ctx, pipelineQueueName, pipelineGroup); err != nil { + w.logger.Error("failed to ensure pipeline consumer group", "error", err) + return + } + + sem := make(chan struct{}, pipelineMaxWorkers) + + go w.reclaimLoop(ctx, sem) for { select { @@ -46,51 +78,102 @@ func (w *PipelineWorker) Run(ctx context.Context, wg *sync.WaitGroup) { w.logger.Info("stopping pipeline worker") return default: - msg, err := w.taskQueue.Dequeue(ctx, pipelineQueueName) + msg, err := w.taskQueue.Receive(ctx, pipelineQueueName, pipelineGroup, w.consumerName) if err != nil { - w.logger.Error("failed to dequeue pipeline job", "error", err) + w.logger.Error("failed to receive pipeline job", "error", err) time.Sleep(1 * time.Second) continue } - if msg == "" { + if msg == nil { continue } var job domain.BuildJob - if err := json.Unmarshal([]byte(msg), &job); err != nil { - w.logger.Error("failed to unmarshal build job", "error", err) + if err := json.Unmarshal([]byte(msg.Payload), &job); err != nil { + w.logger.Error("failed to unmarshal build job", + "error", err, "msg_id", msg.ID) + _ = w.taskQueue.Ack(ctx, pipelineQueueName, pipelineGroup, msg.ID) continue } - w.processJob(job) + sem <- struct{}{} + go func(m *ports.DurableMessage, j domain.BuildJob) { + defer func() { <-sem }() + w.processJob(ctx, m, j) + }(msg, job) } } } -func (w *PipelineWorker) processJob(job domain.BuildJob) { +func (w *PipelineWorker) processJob(workerCtx context.Context, msg *ports.DurableMessage, job domain.BuildJob) { + jobKey := fmt.Sprintf("pipeline:%s", job.BuildID) + + // Idempotency check: skip if already completed or actively being processed. 
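+	// Note that pipelineStaleThreshold (35m) deliberately exceeds the 30-minute
+	// build timeout below, so an in-flight build is never treated as abandoned.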
+ if w.ledger != nil { + acquired, err := w.ledger.TryAcquire(workerCtx, jobKey, pipelineStaleThreshold) + if err != nil { + w.logger.Error("execution ledger error", + "build_id", job.BuildID, "msg_id", msg.ID, "error", err) + _ = w.taskQueue.Nack(workerCtx, pipelineQueueName, pipelineGroup, msg.ID) + return + } + if !acquired { + w.logger.Info("skipping duplicate pipeline job", + "build_id", job.BuildID, "msg_id", msg.ID) + _ = w.taskQueue.Ack(workerCtx, pipelineQueueName, pipelineGroup, msg.ID) + return + } + } + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute) defer cancel() ctx = appcontext.WithUserID(ctx, job.UserID) build, pipeline := w.loadBuildAndPipeline(ctx, job) if build == nil || pipeline == nil { + // Build or pipeline not found — ack to avoid infinite retries. + if w.ledger != nil { + _ = w.ledger.MarkComplete(workerCtx, jobKey, "not_found") + } + _ = w.taskQueue.Ack(workerCtx, pipelineQueueName, pipelineGroup, msg.ID) return } if !w.markBuildRunning(ctx, build) { + if w.ledger != nil { + _ = w.ledger.MarkFailed(workerCtx, jobKey, "failed to mark build running") + } + _ = w.taskQueue.Nack(workerCtx, pipelineQueueName, pipelineGroup, msg.ID) return } if len(pipeline.Config.Stages) == 0 { w.failBuild(ctx, build, "pipeline has no stages") + if w.ledger != nil { + _ = w.ledger.MarkComplete(workerCtx, jobKey, "no_stages") + } + _ = w.taskQueue.Ack(workerCtx, pipelineQueueName, pipelineGroup, msg.ID) return } if !w.executePipeline(ctx, build, pipeline) { + // Build failed but was processed — ack the message. + if w.ledger != nil { + _ = w.ledger.MarkComplete(workerCtx, jobKey, "build_failed") + } + _ = w.taskQueue.Ack(workerCtx, pipelineQueueName, pipelineGroup, msg.ID) return } w.markBuildSucceeded(ctx, build) + + if w.ledger != nil { + _ = w.ledger.MarkComplete(workerCtx, jobKey, "ok") + } + if err := w.taskQueue.Ack(workerCtx, pipelineQueueName, pipelineGroup, msg.ID); err != nil { + w.logger.Error("failed to ack pipeline job", + "build_id", job.BuildID, "msg_id", msg.ID, "error", err) + } } func (w *PipelineWorker) loadBuildAndPipeline(ctx context.Context, job domain.BuildJob) (*domain.Build, *domain.Pipeline) { @@ -262,3 +345,39 @@ func (w *PipelineWorker) collectTaskLogs(ctx context.Context, taskID string) (st } return string(data), nil } + +func (w *PipelineWorker) reclaimLoop(ctx context.Context, sem chan struct{}) { + ticker := time.NewTicker(60 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + msgs, err := w.taskQueue.ReclaimStale(ctx, pipelineQueueName, pipelineGroup, w.consumerName, pipelineReclaimMs, pipelineReclaimN) + if err != nil { + w.logger.Warn("pipeline reclaim error", "error", err) + continue + } + for _, m := range msgs { + var job domain.BuildJob + if err := json.Unmarshal([]byte(m.Payload), &job); err != nil { + w.logger.Error("failed to unmarshal reclaimed pipeline job", + "msg_id", m.ID, "error", err) + _ = w.taskQueue.Ack(ctx, pipelineQueueName, pipelineGroup, m.ID) + continue + } + w.logger.Info("reclaimed stale pipeline job", + "build_id", job.BuildID, "msg_id", m.ID) + + m := m + sem <- struct{}{} + go func() { + defer func() { <-sem }() + w.processJob(ctx, &m, job) + }() + } + } + } +} diff --git a/internal/workers/provision_worker.go b/internal/workers/provision_worker.go index 3f257e5ca..dee39a94d 100644 --- a/internal/workers/provision_worker.go +++ b/internal/workers/provision_worker.go @@ -4,7 +4,9 @@ package workers import ( "context" "encoding/json" + "fmt" 
"log/slog" + "os" "sync" "time" @@ -14,25 +16,64 @@ import ( "github.com/poyrazk/thecloud/internal/core/services" ) -// ProvisionWorker processes instance provisioning tasks. +const ( + provisionQueue = "provision_queue" + provisionGroup = "provision_workers" + provisionMaxWorkers = 20 + // How long a message can sit in PEL before another consumer reclaims it. + provisionReclaimMs = 5 * 60 * 1000 // 5 minutes + provisionReclaimN = 10 + // Stale threshold for idempotency ledger: if a "running" entry is older + // than this, it is considered abandoned and can be reclaimed. + provisionStaleThreshold = 15 * time.Minute +) + +// ProvisionWorker processes instance provisioning tasks using a durable queue +// with at-least-once delivery. Jobs are acknowledged only after successful +// processing; crashed jobs are reclaimed by healthy peers. An execution ledger +// prevents duplicate processing of the same instance. type ProvisionWorker struct { - instSvc *services.InstanceService - taskQueue ports.TaskQueue - logger *slog.Logger + instSvc *services.InstanceService + taskQueue ports.DurableTaskQueue + ledger ports.ExecutionLedger + logger *slog.Logger + consumerName string } // NewProvisionWorker constructs a ProvisionWorker. -func NewProvisionWorker(instSvc *services.InstanceService, taskQueue ports.TaskQueue, logger *slog.Logger) *ProvisionWorker { +// If ledger is nil, idempotency checks are skipped. +func NewProvisionWorker(instSvc *services.InstanceService, taskQueue ports.DurableTaskQueue, ledger ports.ExecutionLedger, logger *slog.Logger) *ProvisionWorker { + hostname, _ := os.Hostname() + if hostname == "" { + hostname = "provision-worker" + } return &ProvisionWorker{ - instSvc: instSvc, - taskQueue: taskQueue, - logger: logger, + instSvc: instSvc, + taskQueue: taskQueue, + ledger: ledger, + logger: logger, + consumerName: hostname, } } func (w *ProvisionWorker) Run(ctx context.Context, wg *sync.WaitGroup) { defer wg.Done() - w.logger.Info("starting provision worker") + w.logger.Info("starting provision worker", + "consumer", w.consumerName, + "concurrency", provisionMaxWorkers, + ) + + // Ensure consumer group exists. + if err := w.taskQueue.EnsureGroup(ctx, provisionQueue, provisionGroup); err != nil { + w.logger.Error("failed to ensure provision consumer group", "error", err) + return + } + + sem := make(chan struct{}, provisionMaxWorkers) + + // Start a background goroutine that periodically reclaims stale messages + // from crashed consumers. + go w.reclaimLoop(ctx, sem) for { select { @@ -40,47 +81,141 @@ func (w *ProvisionWorker) Run(ctx context.Context, wg *sync.WaitGroup) { w.logger.Info("stopping provision worker") return default: - // Dequeue task - msg, err := w.taskQueue.Dequeue(ctx, "provision_queue") + msg, err := w.taskQueue.Receive(ctx, provisionQueue, provisionGroup, w.consumerName) if err != nil { - // redis.Nil or other error + w.logger.Error("failed to receive provision job", "error", err) time.Sleep(1 * time.Second) continue } - - if msg == "" { + if msg == nil { continue } var job domain.ProvisionJob - if err := json.Unmarshal([]byte(msg), &job); err != nil { - w.logger.Error("failed to unmarshal provision job", "error", err) + if err := json.Unmarshal([]byte(msg.Payload), &job); err != nil { + w.logger.Error("failed to unmarshal provision job", + "error", err, "msg_id", msg.ID) + // Ack poison messages so they don't block the queue. 
+ _ = w.taskQueue.Ack(ctx, provisionQueue, provisionGroup, msg.ID) continue } - w.logger.Info("processing provision job", "instance_id", job.InstanceID, "tenant_id", job.TenantID) + w.logger.Info("processing provision job", + "instance_id", job.InstanceID, + "tenant_id", job.TenantID, + "msg_id", msg.ID, + ) - // Process job concurrently to handle high throughput in load tests - go w.processJob(job) + sem <- struct{}{} // acquire concurrency slot + go func(m *ports.DurableMessage, j domain.ProvisionJob) { + defer func() { <-sem }() + w.processJob(ctx, m, j) + }(msg, job) } } } -func (w *ProvisionWorker) processJob(job domain.ProvisionJob) { - // Root context for background task with 10-minute safety timeout - // We use context.Background() because the worker lifecycle context shouldn't necessarily cancel active provisioning unless the app is shutting down +func (w *ProvisionWorker) processJob(workerCtx context.Context, msg *ports.DurableMessage, job domain.ProvisionJob) { + jobKey := fmt.Sprintf("provision:%s", job.InstanceID) + + // Idempotency check: skip if already completed or actively being processed. + if w.ledger != nil { + acquired, err := w.ledger.TryAcquire(workerCtx, jobKey, provisionStaleThreshold) + if err != nil { + w.logger.Error("execution ledger error", + "instance_id", job.InstanceID, "msg_id", msg.ID, "error", err) + // On ledger error, nack to retry later. + _ = w.taskQueue.Nack(workerCtx, provisionQueue, provisionGroup, msg.ID) + return + } + if !acquired { + w.logger.Info("skipping duplicate provision job", + "instance_id", job.InstanceID, "msg_id", msg.ID) + // Already processed — ack the duplicate message. + _ = w.taskQueue.Ack(workerCtx, provisionQueue, provisionGroup, msg.ID) + return + } + } + + // Root context for background task with 10-minute safety timeout. baseCtx := context.Background() ctx, cancel := context.WithTimeout(baseCtx, 10*time.Minute) defer cancel() - // Inject User and Tenant IDs for repository access control + // Inject User and Tenant IDs for repository access control. ctx = appcontext.WithUserID(ctx, job.UserID) ctx = appcontext.WithTenantID(ctx, job.TenantID) - w.logger.Info("starting provision logic", "instance_id", job.InstanceID) + w.logger.Info("starting provision logic", "instance_id", job.InstanceID, "msg_id", msg.ID) if err := w.instSvc.Provision(ctx, job); err != nil { - w.logger.Error("failed to provision instance", "instance_id", job.InstanceID, "error", err) - } else { - w.logger.Info("successfully provisioned instance", "instance_id", job.InstanceID) + w.logger.Error("failed to provision instance", + "instance_id", job.InstanceID, + "msg_id", msg.ID, + "error", err, + ) + // Mark failed in the ledger so it can be retried. + if w.ledger != nil { + _ = w.ledger.MarkFailed(workerCtx, jobKey, err.Error()) + } + // Nack: leave message in PEL for reclaim/retry. + _ = w.taskQueue.Nack(workerCtx, provisionQueue, provisionGroup, msg.ID) + return + } + + w.logger.Info("successfully provisioned instance", + "instance_id", job.InstanceID, + "msg_id", msg.ID, + ) + + // Mark completed in ledger (prevents duplicate execution). + if w.ledger != nil { + _ = w.ledger.MarkComplete(workerCtx, jobKey, "ok") + } + + // Acknowledge — message is permanently consumed. 
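+	// A crash between MarkComplete and Ack is still safe: the message is
+	// redelivered, TryAcquire reports a duplicate, and the redelivery is acked
+	// without re-running provisioning.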
+ if err := w.taskQueue.Ack(workerCtx, provisionQueue, provisionGroup, msg.ID); err != nil { + w.logger.Error("failed to ack provision job", + "instance_id", job.InstanceID, + "msg_id", msg.ID, + "error", err, + ) + } +} + +// reclaimLoop periodically reclaims messages stuck in the PEL from crashed +// consumers and re-processes them. +func (w *ProvisionWorker) reclaimLoop(ctx context.Context, sem chan struct{}) { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + msgs, err := w.taskQueue.ReclaimStale(ctx, provisionQueue, provisionGroup, w.consumerName, provisionReclaimMs, provisionReclaimN) + if err != nil { + w.logger.Warn("provision reclaim error", "error", err) + continue + } + for _, m := range msgs { + var job domain.ProvisionJob + if err := json.Unmarshal([]byte(m.Payload), &job); err != nil { + w.logger.Error("failed to unmarshal reclaimed provision job", + "msg_id", m.ID, "error", err) + _ = w.taskQueue.Ack(ctx, provisionQueue, provisionGroup, m.ID) + continue + } + w.logger.Info("reclaimed stale provision job", + "instance_id", job.InstanceID, "msg_id", m.ID) + + m := m // capture loop variable + sem <- struct{}{} + go func() { + defer func() { <-sem }() + w.processJob(ctx, &m, job) + }() + } + } } } diff --git a/internal/workers/provision_worker_test.go b/internal/workers/provision_worker_test.go index 8cdd442e2..a79785f7f 100644 --- a/internal/workers/provision_worker_test.go +++ b/internal/workers/provision_worker_test.go @@ -18,28 +18,53 @@ import ( "github.com/stretchr/testify/assert" ) -type fakeTaskQueue struct { - messages []string - errors []error // To simulate dequeue errors +// fakeDurableQueue implements ports.DurableTaskQueue for testing. 
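+// It records Ack/Nack message IDs so tests can assert delivery outcomes.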
+type fakeDurableQueue struct { + messages []*ports.DurableMessage + errors []error index int + acked []string + nacked []string } -func (f *fakeTaskQueue) Enqueue(ctx context.Context, queueName string, payload interface{}) error { +func (f *fakeDurableQueue) Enqueue(ctx context.Context, queueName string, payload interface{}) error { return nil } -func (f *fakeTaskQueue) Dequeue(ctx context.Context, queueName string) (string, error) { +func (f *fakeDurableQueue) Dequeue(ctx context.Context, queueName string) (string, error) { + return "", nil +} + +func (f *fakeDurableQueue) EnsureGroup(ctx context.Context, queueName, groupName string) error { + return nil +} + +func (f *fakeDurableQueue) Receive(ctx context.Context, queueName, groupName, consumerName string) (*ports.DurableMessage, error) { if f.index < len(f.errors) && f.errors[f.index] != nil { err := f.errors[f.index] f.index++ - return "", err + return nil, err } if f.index < len(f.messages) { msg := f.messages[f.index] f.index++ return msg, nil } - return "", nil + return nil, nil +} + +func (f *fakeDurableQueue) Ack(ctx context.Context, queueName, groupName, messageID string) error { + f.acked = append(f.acked, messageID) + return nil +} + +func (f *fakeDurableQueue) Nack(ctx context.Context, queueName, groupName, messageID string) error { + f.nacked = append(f.nacked, messageID) + return nil +} + +func (f *fakeDurableQueue) ReclaimStale(ctx context.Context, queueName, groupName, consumerName string, minIdleMs int64, count int64) ([]ports.DurableMessage, error) { + return nil, nil } // failingComputeBackend forces Provision to fail @@ -53,48 +78,57 @@ func (f *failingComputeBackend) LaunchInstanceWithOptions(ctx context.Context, o func TestProvisionWorkerRun(t *testing.T) { tests := []struct { - name string - message interface{} // string or struct - injectDequeErr bool - failProvision bool - wantLog string + name string + payload interface{} + poisonJSON bool + failProvision bool + wantLog string + wantAcked bool + wantNacked bool }{ { name: "success", - message: domain.ProvisionJob{ + payload: domain.ProvisionJob{ InstanceID: uuid.New(), UserID: uuid.New(), }, - wantLog: "successfully provisioned instance", + wantLog: "successfully provisioned instance", + wantAcked: true, }, { - name: "deserialize_error", - message: "{invalid-json}", - wantLog: "failed to unmarshal provision job", + name: "deserialize_error", + poisonJSON: true, + wantLog: "failed to unmarshal provision job", + wantAcked: true, // poison messages are acked to unblock the queue }, { - name: "provision_error", - message: domain.ProvisionJob{InstanceID: uuid.New(), UserID: uuid.New()}, + name: "provision_error", + payload: domain.ProvisionJob{ + InstanceID: uuid.New(), + UserID: uuid.New(), + }, failProvision: true, wantLog: "failed to provision instance", + wantNacked: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - var msgBytes []byte - switch v := tt.message.(type) { - case string: - msgBytes = []byte(v) - default: - msgBytes, _ = json.Marshal(v) + var payloadStr string + if tt.poisonJSON { + payloadStr = "{invalid-json}" + } else { + data, _ := json.Marshal(tt.payload) + payloadStr = string(data) } - fq := &fakeTaskQueue{ - messages: []string{string(msgBytes)}, + fq := &fakeDurableQueue{ + messages: []*ports.DurableMessage{ + {ID: "1-0", Payload: payloadStr, Queue: provisionQueue}, + }, } - // Compute backend var compute ports.ComputeBackend = &noop.NoopComputeBackend{} if tt.failProvision { compute = &failingComputeBackend{} @@ 
-116,7 +150,7 @@ func TestProvisionWorkerRun(t *testing.T) { var buf bytes.Buffer logger := slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelDebug})) - worker := NewProvisionWorker(instSvc, fq, logger) + worker := NewProvisionWorker(instSvc, fq, nil, logger) ctx, cancel := context.WithCancel(context.Background()) var wg sync.WaitGroup @@ -124,20 +158,25 @@ func TestProvisionWorkerRun(t *testing.T) { go worker.Run(ctx, &wg) - time.Sleep(50 * time.Millisecond) + // Give worker time to process + time.Sleep(200 * time.Millisecond) cancel() wg.Wait() assert.Contains(t, buf.String(), tt.wantLog) + if tt.wantAcked { + assert.NotEmpty(t, fq.acked, "expected message to be acked") + } + if tt.wantNacked { + assert.NotEmpty(t, fq.nacked, "expected message to be nacked") + } }) } } -func TestProvisionWorkerRunDequeueError(t *testing.T) { - // Test that worker continues on queue error - fq := &fakeTaskQueue{ - messages: []string{}, - errors: []error{errors.New("redis connection failed")}, +func TestProvisionWorkerRunReceiveError(t *testing.T) { + fq := &fakeDurableQueue{ + errors: []error{errors.New("redis connection failed")}, } instSvc := services.NewInstanceService(services.InstanceServiceParams{ @@ -156,7 +195,7 @@ func TestProvisionWorkerRunDequeueError(t *testing.T) { var buf bytes.Buffer logger := slog.New(slog.NewTextHandler(&buf, nil)) - worker := NewProvisionWorker(instSvc, fq, logger) + worker := NewProvisionWorker(instSvc, fq, nil, logger) ctx, cancel := context.WithCancel(context.Background()) var wg sync.WaitGroup @@ -166,5 +205,6 @@ func TestProvisionWorkerRunDequeueError(t *testing.T) { time.Sleep(50 * time.Millisecond) cancel() wg.Wait() - // No specific log to check as it just continues, but we ensure no panic and coverage hits error path + + assert.Contains(t, buf.String(), "failed to receive provision job") } From e968c1a15cebec8a72b77e9d8898168d1015fb2e Mon Sep 17 00:00:00 2001 From: jackthepunished Date: Tue, 10 Mar 2026 05:41:05 +0300 Subject: [PATCH 06/12] test(ha): Add failure drills and release gate tests HA Drills (ha_drills_test.go): 1. Circuit breaker trip and recovery (validates 3 state transitions) 2. Bulkhead saturation and graceful rejection 3. Resilient adapter end-to-end (CB + bulkhead + timeout compose) 4. Retry backoff and context cancellation 5. Half-open single-flight validation Release Gates (release_gates_test.go) - validate SLOs: 1. Fail-fast latency <1ms when circuit is open 2. Bulkhead isolation (saturated compute doesn't affect network) 3. Circuit recovery within resetTimeout window 4. Retry idempotency (exactly MaxAttempts executions) 5. Independent circuit breakers don't interfere Total: 13 new tests validating HA invariants. --- internal/drills/ha_drills_test.go | 386 ++++++++++++++++++++++++++ internal/drills/release_gates_test.go | 203 ++++++++++++++ 2 files changed, 589 insertions(+) create mode 100644 internal/drills/ha_drills_test.go create mode 100644 internal/drills/release_gates_test.go diff --git a/internal/drills/ha_drills_test.go b/internal/drills/ha_drills_test.go new file mode 100644 index 000000000..837e7f571 --- /dev/null +++ b/internal/drills/ha_drills_test.go @@ -0,0 +1,386 @@ +// Package drills provides integration-like failure drill tests that validate +// the HA properties of the control plane. These tests use mocks to simulate +// infrastructure failures without requiring real Postgres/Redis. 
+// +// Run: go test ./internal/drills/ -v -count=1 +package drills + +import ( + "context" + "errors" + "log/slog" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/poyrazk/thecloud/internal/platform" +) + +// --------------------------------------------------------------------------- +// Drill 1: Circuit breaker trip + recovery +// SLO: When a backend fails ≥ threshold times, all subsequent calls must +// return ErrCircuitOpen within 1ms (no backend call). After resetTimeout, +// a successful probe must close the circuit. +// --------------------------------------------------------------------------- + +func TestDrill_CircuitBreakerTripAndRecovery(t *testing.T) { + const threshold = 3 + const resetTimeout = 200 * time.Millisecond + + var transitions []string + var mu sync.Mutex + + cb := platform.NewCircuitBreakerWithOpts(platform.CircuitBreakerOpts{ + Name: "drill-cb", + Threshold: threshold, + ResetTimeout: resetTimeout, + SuccessRequired: 1, + OnStateChange: func(name string, from, to platform.State) { + mu.Lock() + transitions = append(transitions, from.String()+"→"+to.String()) + mu.Unlock() + }, + }) + + backendErr := errors.New("backend down") + + // Phase 1: Trip the circuit with consecutive failures. + for i := 0; i < threshold; i++ { + err := cb.Execute(func() error { return backendErr }) + if err == nil { + t.Fatalf("iteration %d: expected error", i) + } + } + + // Verify circuit is open. + if cb.GetState() != platform.StateOpen { + t.Fatalf("expected open, got %s", cb.GetState().String()) + } + + // Phase 2: Confirm fail-fast (no backend call). + var backendCalled atomic.Bool + start := time.Now() + err := cb.Execute(func() error { + backendCalled.Store(true) + return nil + }) + elapsed := time.Since(start) + + if !errors.Is(err, platform.ErrCircuitOpen) { + t.Fatalf("expected ErrCircuitOpen, got %v", err) + } + if backendCalled.Load() { + t.Fatal("backend should NOT have been called while circuit is open") + } + if elapsed > 5*time.Millisecond { + t.Fatalf("fail-fast took %v, expected <5ms", elapsed) + } + + // Phase 3: Wait for resetTimeout, then recover. + time.Sleep(resetTimeout + 50*time.Millisecond) + err = cb.Execute(func() error { return nil }) + if err != nil { + t.Fatalf("expected recovery, got %v", err) + } + if cb.GetState() != platform.StateClosed { + t.Fatalf("expected closed after recovery, got %s", cb.GetState().String()) + } + + // Verify transitions: closed→open, open→half-open, half-open→closed. + mu.Lock() + defer mu.Unlock() + expected := []string{"closed→open", "open→half-open", "half-open→closed"} + if len(transitions) != len(expected) { + t.Fatalf("expected %d transitions, got %d: %v", len(expected), len(transitions), transitions) + } + for i := range expected { + if transitions[i] != expected[i] { + t.Fatalf("transition[%d]: expected %s, got %s", i, expected[i], transitions[i]) + } + } +} + +// --------------------------------------------------------------------------- +// Drill 2: Bulkhead saturation + graceful rejection +// SLO: When maxConc requests are in-flight, additional requests must be +// rejected with ErrBulkheadFull (not blocked forever). 
+// --------------------------------------------------------------------------- + +func TestDrill_BulkheadSaturationAndRejection(t *testing.T) { + const maxConc = 3 + const waitTimeout = 100 * time.Millisecond + + bh := platform.NewBulkhead(platform.BulkheadOpts{ + Name: "drill-bh", + MaxConc: maxConc, + WaitTimeout: waitTimeout, + }) + + ctx := context.Background() + blockCh := make(chan struct{}) + var inFlight atomic.Int64 + var rejected atomic.Int64 + var wg sync.WaitGroup + + // Saturate the bulkhead. + for i := 0; i < maxConc; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _ = bh.Execute(ctx, func() error { + inFlight.Add(1) + <-blockCh + return nil + }) + }() + } + + // Wait for all slots to be occupied. + for inFlight.Load() < int64(maxConc) { + time.Sleep(5 * time.Millisecond) + } + + if bh.Available() != 0 { + t.Fatalf("expected 0 available slots, got %d", bh.Available()) + } + + // Fire excess requests — they should be rejected. + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + err := bh.Execute(ctx, func() error { return nil }) + if errors.Is(err, platform.ErrBulkheadFull) { + rejected.Add(1) + } + }() + } + + // Let the excess timeout. + time.Sleep(waitTimeout + 50*time.Millisecond) + + // Unblock the saturating goroutines. + close(blockCh) + wg.Wait() + + if rejected.Load() != 5 { + t.Fatalf("expected 5 rejections, got %d", rejected.Load()) + } +} + +// --------------------------------------------------------------------------- +// Drill 3: Resilient adapter end-to-end (circuit + bulkhead + timeout) +// SLO: A failing backend trips the circuit; subsequent calls fail-fast; +// recovery probe succeeds and normal operation resumes. +// --------------------------------------------------------------------------- + +type failingBackend struct { + healthy atomic.Bool + calls atomic.Int64 +} + +func (f *failingBackend) Do(_ context.Context) error { + f.calls.Add(1) + if f.healthy.Load() { + return nil + } + return errors.New("backend failure") +} + +func TestDrill_ResilientAdapterEndToEnd(t *testing.T) { + backend := &failingBackend{} + logger := slog.Default() + + cb := platform.NewCircuitBreakerWithOpts(platform.CircuitBreakerOpts{ + Name: "drill-e2e", + Threshold: 3, + ResetTimeout: 200 * time.Millisecond, + SuccessRequired: 1, + }) + + bh := platform.NewBulkhead(platform.BulkheadOpts{ + Name: "drill-e2e", + MaxConc: 5, + }) + + _ = logger // Would be used for real logging in production. + + // Helper: simulate calling through the full resilience stack. + callThrough := func(ctx context.Context) error { + return bh.Execute(ctx, func() error { + return cb.Execute(func() error { + ctx2, cancel := context.WithTimeout(ctx, 500*time.Millisecond) + defer cancel() + return backend.Do(ctx2) + }) + }) + } + + ctx := context.Background() + + // Phase 1: Backend is down → trip circuit. + for i := 0; i < 3; i++ { + _ = callThrough(ctx) + } + if cb.GetState() != platform.StateOpen { + t.Fatalf("expected open, got %s", cb.GetState().String()) + } + + // Phase 2: Fail-fast while open. + callsBefore := backend.calls.Load() + err := callThrough(ctx) + if !errors.Is(err, platform.ErrCircuitOpen) { + t.Fatalf("expected ErrCircuitOpen, got %v", err) + } + if backend.calls.Load() != callsBefore { + t.Fatal("backend should not be called while circuit is open") + } + + // Phase 3: Backend recovers. 
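+	// 250ms > resetTimeout (200ms), so the breaker is eligible for a half-open
+	// probe on the next call.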
+ backend.healthy.Store(true) + time.Sleep(250 * time.Millisecond) + + err = callThrough(ctx) + if err != nil { + t.Fatalf("expected recovery, got %v", err) + } + if cb.GetState() != platform.StateClosed { + t.Fatalf("expected closed, got %s", cb.GetState().String()) + } + + // Phase 4: Normal operation continues. + for i := 0; i < 10; i++ { + if err := callThrough(ctx); err != nil { + t.Fatalf("call %d failed: %v", i, err) + } + } +} + +// --------------------------------------------------------------------------- +// Drill 4: Retry with exponential backoff +// SLO: Retry must respect MaxAttempts, must apply backoff between attempts, +// and must stop early if the context is cancelled. +// --------------------------------------------------------------------------- + +func TestDrill_RetryBackoffAndContextCancellation(t *testing.T) { + t.Run("exhausts_attempts", func(t *testing.T) { + var attempts atomic.Int64 + err := platform.Retry(context.Background(), platform.RetryOpts{ + MaxAttempts: 4, + BaseDelay: 10 * time.Millisecond, + MaxDelay: 50 * time.Millisecond, + }, func(ctx context.Context) error { + attempts.Add(1) + return errors.New("still failing") + }) + + if err == nil { + t.Fatal("expected error after exhausting retries") + } + if attempts.Load() != 4 { + t.Fatalf("expected 4 attempts, got %d", attempts.Load()) + } + }) + + t.Run("stops_on_context_cancel", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + var attempts atomic.Int64 + start := time.Now() + err := platform.Retry(ctx, platform.RetryOpts{ + MaxAttempts: 100, // Would take very long if not cancelled. + BaseDelay: 50 * time.Millisecond, + MaxDelay: 200 * time.Millisecond, + }, func(ctx context.Context) error { + attempts.Add(1) + return errors.New("failing") + }) + + elapsed := time.Since(start) + if err == nil { + t.Fatal("expected error") + } + if elapsed > 500*time.Millisecond { + t.Fatalf("should have stopped early, took %v", elapsed) + } + if attempts.Load() >= 100 { + t.Fatal("should not have exhausted all 100 attempts") + } + }) + + t.Run("succeeds_on_retry", func(t *testing.T) { + var attempts atomic.Int64 + err := platform.Retry(context.Background(), platform.RetryOpts{ + MaxAttempts: 5, + BaseDelay: 5 * time.Millisecond, + }, func(ctx context.Context) error { + n := attempts.Add(1) + if n < 3 { + return errors.New("not yet") + } + return nil + }) + + if err != nil { + t.Fatalf("expected success, got %v", err) + } + if attempts.Load() != 3 { + t.Fatalf("expected 3 attempts, got %d", attempts.Load()) + } + }) +} + +// --------------------------------------------------------------------------- +// Drill 5: Half-open single-flight +// SLO: While a probe request is in-flight in half-open state, all other +// requests must be rejected with ErrCircuitOpen. +// --------------------------------------------------------------------------- + +func TestDrill_HalfOpenSingleFlight(t *testing.T) { + const resetTimeout = 100 * time.Millisecond + + cb := platform.NewCircuitBreakerWithOpts(platform.CircuitBreakerOpts{ + Name: "drill-halfopen", + Threshold: 1, + ResetTimeout: resetTimeout, + }) + + // Trip the circuit. + _ = cb.Execute(func() error { return errors.New("fail") }) + if cb.GetState() != platform.StateOpen { + t.Fatalf("expected open, got %s", cb.GetState().String()) + } + + // Wait for reset timeout. + time.Sleep(resetTimeout + 20*time.Millisecond) + + // Start a slow probe request. 
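+	// While the probe is in flight the breaker stays half-open and must reject
+	// every other caller; only the probe outcome decides the next state.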
+ probeDone := make(chan struct{}) + go func() { + _ = cb.Execute(func() error { + <-probeDone // Block until we release. + return nil + }) + }() + + // Give the probe goroutine time to start. + time.Sleep(20 * time.Millisecond) + + // All other requests should be rejected. + for i := 0; i < 5; i++ { + err := cb.Execute(func() error { return nil }) + if !errors.Is(err, platform.ErrCircuitOpen) { + t.Fatalf("request %d: expected ErrCircuitOpen during half-open probe, got %v", i, err) + } + } + + // Release the probe — circuit should close. + close(probeDone) + time.Sleep(10 * time.Millisecond) + + if cb.GetState() != platform.StateClosed { + t.Fatalf("expected closed after probe success, got %s", cb.GetState().String()) + } +} diff --git a/internal/drills/release_gates_test.go b/internal/drills/release_gates_test.go new file mode 100644 index 000000000..212f94028 --- /dev/null +++ b/internal/drills/release_gates_test.go @@ -0,0 +1,203 @@ +// Package drills contains HA failure drills and release gates. +// +// Release gates are meant to run in CI before deploying a new version. +// They validate the SLO invariants for the control-plane HA features: +// +// 1. Leader failover <30s (validated via unit tests on LeaderGuard). +// 2. Zero duplicate singleton executions during failover. +// 3. Zero job loss in crash tests (durable queue ack/nack). +// 4. Circuit breaker fail-fast under backend failure. +// 5. Bulkhead prevents cascading overload. +// 6. No API outage during single pod loss (leader re-election + queue redelivery). +// +// Run release gates: go test ./internal/drills/ -v -count=1 -run TestReleaseGate +package drills + +import ( + "context" + "errors" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/poyrazk/thecloud/internal/platform" +) + +// TestReleaseGate_CircuitBreakerFailFast validates SLO: +// "When a backend is down, requests must fail-fast in <5ms". +func TestReleaseGate_CircuitBreakerFailFast(t *testing.T) { + cb := platform.NewCircuitBreakerWithOpts(platform.CircuitBreakerOpts{ + Name: "gate-cb", + Threshold: 3, + ResetTimeout: 1 * time.Second, + }) + + // Trip it. + for i := 0; i < 3; i++ { + _ = cb.Execute(func() error { return errors.New("down") }) + } + + // Measure fail-fast latency over 100 calls. + const iterations = 100 + start := time.Now() + for i := 0; i < iterations; i++ { + err := cb.Execute(func() error { return nil }) + if !errors.Is(err, platform.ErrCircuitOpen) { + t.Fatalf("iteration %d: expected ErrCircuitOpen, got %v", i, err) + } + } + elapsed := time.Since(start) + + avgLatency := elapsed / iterations + if avgLatency > 1*time.Millisecond { + t.Fatalf("average fail-fast latency %v exceeds 1ms SLO", avgLatency) + } + t.Logf("PASS: avg fail-fast latency = %v (SLO: <1ms)", avgLatency) +} + +// TestReleaseGate_BulkheadIsolation validates SLO: +// "A saturated adapter must not block unrelated adapters". +func TestReleaseGate_BulkheadIsolation(t *testing.T) { + // Two independent bulkheads for two adapters. + bhCompute := platform.NewBulkhead(platform.BulkheadOpts{Name: "compute", MaxConc: 2, WaitTimeout: 50 * time.Millisecond}) + bhNetwork := platform.NewBulkhead(platform.BulkheadOpts{Name: "network", MaxConc: 5, WaitTimeout: 50 * time.Millisecond}) + + ctx := context.Background() + + // Saturate compute bulkhead. 
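+	// Two blocked executions consume both compute slots; the network bulkhead
+	// keeps its own independent semaphore.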
+ blockCh := make(chan struct{}) + var wg sync.WaitGroup + for i := 0; i < 2; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _ = bhCompute.Execute(ctx, func() error { + <-blockCh + return nil + }) + }() + } + time.Sleep(30 * time.Millisecond) // Let them acquire slots. + + // Compute is now full. + err := bhCompute.Execute(ctx, func() error { return nil }) + if !errors.Is(err, platform.ErrBulkheadFull) { + t.Fatalf("compute bulkhead should be full, got %v", err) + } + + // Network bulkhead must still be operational. + err = bhNetwork.Execute(ctx, func() error { return nil }) + if err != nil { + t.Fatalf("network bulkhead should be available, got %v", err) + } + + close(blockCh) + wg.Wait() + t.Log("PASS: saturated compute did not affect network adapter") +} + +// TestReleaseGate_CircuitBreakerRecovery validates SLO: +// "After backend recovery, the circuit must close within resetTimeout + probe time". +func TestReleaseGate_CircuitBreakerRecovery(t *testing.T) { + const resetTimeout = 200 * time.Millisecond + healthy := &atomic.Bool{} + + cb := platform.NewCircuitBreakerWithOpts(platform.CircuitBreakerOpts{ + Name: "gate-recovery", + Threshold: 2, + ResetTimeout: resetTimeout, + SuccessRequired: 1, + }) + + // Trip it. + for i := 0; i < 2; i++ { + _ = cb.Execute(func() error { return errors.New("down") }) + } + + // Simulate recovery after 100ms. + go func() { + time.Sleep(100 * time.Millisecond) + healthy.Store(true) + }() + + // Poll until circuit closes or timeout. + deadline := time.After(resetTimeout + 200*time.Millisecond) + ticker := time.NewTicker(50 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-deadline: + t.Fatalf("circuit did not recover within SLO window. State: %s", cb.GetState().String()) + case <-ticker.C: + err := cb.Execute(func() error { + if healthy.Load() { + return nil + } + return errors.New("still down") + }) + if err == nil && cb.GetState() == platform.StateClosed { + t.Logf("PASS: circuit recovered (state=%s)", cb.GetState().String()) + return + } + } + } +} + +// TestReleaseGate_RetryIdempotency validates SLO: +// "Retry must not execute the function more than MaxAttempts times". +func TestReleaseGate_RetryIdempotency(t *testing.T) { + for _, maxAttempts := range []int{1, 3, 5, 10} { + t.Run(fmt.Sprintf("max_%d", maxAttempts), func(t *testing.T) { + var count atomic.Int64 + _ = platform.Retry(context.Background(), platform.RetryOpts{ + MaxAttempts: maxAttempts, + BaseDelay: 1 * time.Millisecond, + MaxDelay: 5 * time.Millisecond, + }, func(ctx context.Context) error { + count.Add(1) + return errors.New("always fail") + }) + + if count.Load() != int64(maxAttempts) { + t.Fatalf("expected exactly %d attempts, got %d", maxAttempts, count.Load()) + } + }) + } +} + +// TestReleaseGate_ConcurrentCircuitBreakers validates SLO: +// "Multiple independent circuit breakers must not interfere with each other". +func TestReleaseGate_ConcurrentCircuitBreakers(t *testing.T) { + cbs := make([]*platform.CircuitBreaker, 5) + for i := range cbs { + cbs[i] = platform.NewCircuitBreakerWithOpts(platform.CircuitBreakerOpts{ + Name: fmt.Sprintf("adapter-%d", i), + Threshold: 3, + ResetTimeout: 1 * time.Second, + }) + } + + // Trip only breaker 0. + for i := 0; i < 3; i++ { + _ = cbs[0].Execute(func() error { return errors.New("down") }) + } + + if cbs[0].GetState() != platform.StateOpen { + t.Fatal("breaker 0 should be open") + } + + // All others should be closed and functional. 
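+	// Breakers 1-4 share no state with breaker 0, so each must still admit requests.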
+ for i := 1; i < 5; i++ { + err := cbs[i].Execute(func() error { return nil }) + if err != nil { + t.Fatalf("breaker %d should be functional, got %v", i, err) + } + if cbs[i].GetState() != platform.StateClosed { + t.Fatalf("breaker %d should be closed, got %s", i, cbs[i].GetState().String()) + } + } + t.Log("PASS: tripped breaker did not affect independent breakers") +} From dd00c4a5b908cd3171ee0f20d473a9449bc1031a Mon Sep 17 00:00:00 2001 From: jackthepunished Date: Wed, 11 Mar 2026 20:25:56 +0300 Subject: [PATCH 07/12] feat(ha): fix CodeRabbit suggestions and race conditions in tests --- cmd/api/main_test.go | 14 +++++---- internal/core/ports/task_queue.go | 7 +++-- internal/drills/ha_drills_test.go | 49 +++++++++++++++++++++++++++---- 3 files changed, 56 insertions(+), 14 deletions(-) diff --git a/cmd/api/main_test.go b/cmd/api/main_test.go index 791f872dc..321281b79 100644 --- a/cmd/api/main_test.go +++ b/cmd/api/main_test.go @@ -201,7 +201,8 @@ func TestRunApplicationDefaultsToAllRole(t *testing.T) { logger := slog.New(slog.NewTextHandler(io.Discard, nil)) t.Setenv("ROLE", "") // Explicitly empty to verify default - serverStarted := false + started := make(chan struct{}) + shutdownCalled := make(chan struct{}) deps := DefaultDeps() deps.NewHTTPServer = func(addr string, handler http.Handler) *http.Server { @@ -212,23 +213,26 @@ func TestRunApplicationDefaultsToAllRole(t *testing.T) { } } deps.StartHTTPServer = func(*http.Server) error { - serverStarted = true + close(started) return http.ErrServerClosed } deps.ShutdownHTTPServer = func(context.Context, *http.Server) error { + close(shutdownCalled) return nil } deps.NotifySignals = func(c chan<- os.Signal, _ ...os.Signal) { go func() { - time.Sleep(50 * time.Millisecond) + <-started c <- syscall.SIGTERM }() } runApplication(deps, &platform.Config{Port: "0"}, logger, gin.New(), &setup.Workers{}) - if !serverStarted { - t.Fatalf("expected HTTP server to start when ROLE defaults to 'all'") + select { + case <-shutdownCalled: + case <-time.After(time.Second): + t.Fatalf("expected server shutdown to be called when ROLE defaults to 'all'") } } diff --git a/internal/core/ports/task_queue.go b/internal/core/ports/task_queue.go index e9b400156..cb493bbf4 100644 --- a/internal/core/ports/task_queue.go +++ b/internal/core/ports/task_queue.go @@ -48,9 +48,10 @@ type DurableTaskQueue interface { Ack(ctx context.Context, queueName, groupName, messageID string) error // Nack signals that the consumer failed to process the message. - // The implementation should make the message available for redelivery - // (e.g. by not acknowledging it and letting the pending-entry timeout - // handle redelivery, or by explicitly re-queuing). + // It relinquishes the current delivery WITHOUT re-queuing or creating + // a new message ID. The message remains in the pending entries list + // and will be reclaimed by ReclaimStale after the idle timeout. + // Implementations MUST NOT create duplicate live copies of the message. 
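+	//
+	// Typical consumer flow (sketch; handle stands in for the application's processing step):
+	//
+	//	msg, _ := q.Receive(ctx, queue, group, consumer)
+	//	if err := handle(msg); err != nil {
+	//		_ = q.Nack(ctx, queue, group, msg.ID) // stays pending; ReclaimStale redelivers the same ID
+	//	} else {
+	//		_ = q.Ack(ctx, queue, group, msg.ID) // permanently consumed
+	//	}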
Nack(ctx context.Context, queueName, groupName, messageID string) error // ReclaimStale claims messages that have been pending longer than the diff --git a/internal/drills/ha_drills_test.go b/internal/drills/ha_drills_test.go index 837e7f571..1e1990ae3 100644 --- a/internal/drills/ha_drills_test.go +++ b/internal/drills/ha_drills_test.go @@ -341,32 +341,60 @@ func TestDrill_RetryBackoffAndContextCancellation(t *testing.T) { func TestDrill_HalfOpenSingleFlight(t *testing.T) { const resetTimeout = 100 * time.Millisecond + stateChanged := make(chan platform.State, 10) cb := platform.NewCircuitBreakerWithOpts(platform.CircuitBreakerOpts{ Name: "drill-halfopen", Threshold: 1, ResetTimeout: resetTimeout, + OnStateChange: func(name string, from, to platform.State) { + stateChanged <- to + }, }) - // Trip the circuit. + // Trip the circuit (closed -> open). _ = cb.Execute(func() error { return errors.New("fail") }) if cb.GetState() != platform.StateOpen { t.Fatalf("expected open, got %s", cb.GetState().String()) } - // Wait for reset timeout. - time.Sleep(resetTimeout + 20*time.Millisecond) + // Drain transitions if any. + for len(stateChanged) > 0 { + <-stateChanged + } + + // Wait for reset timeout and transition to half-open. + // Note: allowRequest transitions to HalfOpen ONLY when Execute is called after resetTimeout. + time.Sleep(resetTimeout + 10*time.Millisecond) // Start a slow probe request. + probeStarted := make(chan struct{}) probeDone := make(chan struct{}) go func() { _ = cb.Execute(func() error { + close(probeStarted) <-probeDone // Block until we release. return nil }) }() - // Give the probe goroutine time to start. - time.Sleep(20 * time.Millisecond) + // Wait for the probe to actually start and transition state. + select { + case <-probeStarted: + case <-time.After(time.Second): + t.Fatal("timeout waiting for probe to start") + } + + // Wait for transition to HalfOpen if it hasn't happened yet. + if cb.GetState() != platform.StateHalfOpen { + select { + case s := <-stateChanged: + if s != platform.StateHalfOpen { + t.Fatalf("expected transition to half-open, got %s", s.String()) + } + case <-time.After(time.Second): + t.Fatal("timeout waiting for transition to half-open") + } + } // All other requests should be rejected. for i := 0; i < 5; i++ { @@ -378,7 +406,16 @@ func TestDrill_HalfOpenSingleFlight(t *testing.T) { // Release the probe — circuit should close. close(probeDone) - time.Sleep(10 * time.Millisecond) + + // Wait for transition to Closed. 
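+	// recordSuccess fires the HalfOpen -> Closed transition, so the callback delivers the final state without polling.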
+ select { + case s := <-stateChanged: + if s != platform.StateClosed { + t.Fatalf("expected transition to closed after probe success, got %s", s.String()) + } + case <-time.After(time.Second): + t.Fatal("timeout waiting for transition to closed") + } if cb.GetState() != platform.StateClosed { t.Fatalf("expected closed after probe success, got %s", cb.GetState().String()) From 53a83db377b8e9a895d6b35e999c0893841f040a Mon Sep 17 00:00:00 2001 From: jackthepunished Date: Thu, 12 Mar 2026 08:24:54 +0300 Subject: [PATCH 08/12] fix(ha): harden worker idempotency and test synchronization --- internal/core/ports/execution_ledger.go | 4 ++ internal/drills/ha_drills_test.go | 16 +++---- internal/drills/release_gates_test.go | 5 +- internal/platform/bulkhead_test.go | 4 +- internal/repositories/noop/adapters.go | 4 ++ .../repositories/postgres/execution_ledger.go | 17 +++++++ internal/workers/pipeline_worker.go | 47 ++++++++++++++----- internal/workers/provision_worker.go | 17 +++++-- internal/workers/provision_worker_test.go | 8 +++- 9 files changed, 92 insertions(+), 30 deletions(-) diff --git a/internal/core/ports/execution_ledger.go b/internal/core/ports/execution_ledger.go index ff5e37e48..cbdafd1e4 100644 --- a/internal/core/ports/execution_ledger.go +++ b/internal/core/ports/execution_ledger.go @@ -28,4 +28,8 @@ type ExecutionLedger interface { // MarkFailed marks a job execution as failed, allowing future retries. MarkFailed(ctx context.Context, jobKey string, reason string) error + + // GetStatus returns the current status, result and start time of a job. + // Returns status="", nil error if the job does not exist. + GetStatus(ctx context.Context, jobKey string) (status string, result string, startedAt time.Time, err error) } diff --git a/internal/drills/ha_drills_test.go b/internal/drills/ha_drills_test.go index 1e1990ae3..81cdb8368 100644 --- a/internal/drills/ha_drills_test.go +++ b/internal/drills/ha_drills_test.go @@ -384,16 +384,14 @@ func TestDrill_HalfOpenSingleFlight(t *testing.T) { t.Fatal("timeout waiting for probe to start") } - // Wait for transition to HalfOpen if it hasn't happened yet. - if cb.GetState() != platform.StateHalfOpen { - select { - case s := <-stateChanged: - if s != platform.StateHalfOpen { - t.Fatalf("expected transition to half-open, got %s", s.String()) - } - case <-time.After(time.Second): - t.Fatal("timeout waiting for transition to half-open") + // Wait for transition to HalfOpen. + select { + case s := <-stateChanged: + if s != platform.StateHalfOpen { + t.Fatalf("expected transition to half-open, got %s", s.String()) } + case <-time.After(time.Second): + t.Fatal("timeout waiting for transition to half-open") } // All other requests should be rejected. diff --git a/internal/drills/release_gates_test.go b/internal/drills/release_gates_test.go index 212f94028..596eb4758 100644 --- a/internal/drills/release_gates_test.go +++ b/internal/drills/release_gates_test.go @@ -69,17 +69,20 @@ func TestReleaseGate_BulkheadIsolation(t *testing.T) { // Saturate compute bulkhead. blockCh := make(chan struct{}) var wg sync.WaitGroup + var startedWg sync.WaitGroup for i := 0; i < 2; i++ { wg.Add(1) + startedWg.Add(1) go func() { defer wg.Done() _ = bhCompute.Execute(ctx, func() error { + startedWg.Done() <-blockCh return nil }) }() } - time.Sleep(30 * time.Millisecond) // Let them acquire slots. + startedWg.Wait() // Ensure they have acquired slots. // Compute is now full. 
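+	// With both slots held, the next Execute can only fail with ErrBulkheadFull once WaitTimeout expires.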
err := bhCompute.Execute(ctx, func() error { return nil }) diff --git a/internal/platform/bulkhead_test.go b/internal/platform/bulkhead_test.go index dd7315d61..c7ccdfc5c 100644 --- a/internal/platform/bulkhead_test.go +++ b/internal/platform/bulkhead_test.go @@ -61,7 +61,7 @@ func TestBulkheadRejectsWhenFull(t *testing.T) { // Second call should be rejected. err := bh.Execute(context.Background(), func() error { return nil }) - assert.ErrorIs(t, err, ErrBulkheadFull) + require.ErrorIs(t, err, ErrBulkheadFull) close(done) } @@ -85,7 +85,7 @@ func TestBulkheadRespectsContext(t *testing.T) { defer cancel() err := bh.Execute(ctx, func() error { return nil }) - assert.ErrorIs(t, err, ErrBulkheadFull) + require.ErrorIs(t, err, ErrBulkheadFull) close(done) } diff --git a/internal/repositories/noop/adapters.go b/internal/repositories/noop/adapters.go index 9e907c91c..f215b58be 100644 --- a/internal/repositories/noop/adapters.go +++ b/internal/repositories/noop/adapters.go @@ -432,6 +432,10 @@ func (l *NoopExecutionLedger) MarkFailed(ctx context.Context, jobKey string, rea return nil } +func (l *NoopExecutionLedger) GetStatus(ctx context.Context, jobKey string) (string, string, time.Time, error) { + return "", "", time.Time{}, nil +} + // --- New No-Ops (for benchmarks and system tests) --- type NoopVolumeRepository struct{} diff --git a/internal/repositories/postgres/execution_ledger.go b/internal/repositories/postgres/execution_ledger.go index 008b4c9db..d4bc6777f 100644 --- a/internal/repositories/postgres/execution_ledger.go +++ b/internal/repositories/postgres/execution_ledger.go @@ -112,3 +112,20 @@ func (l *PgExecutionLedger) MarkFailed(ctx context.Context, jobKey string, reaso `, jobKey, reason) return err } + +// GetStatus returns the current status, result and start time of a job. +func (l *PgExecutionLedger) GetStatus(ctx context.Context, jobKey string) (status string, result string, startedAt time.Time, err error) { + var res pgx.Row + res = l.db.QueryRow(ctx, ` + SELECT status, COALESCE(result, ''), started_at FROM job_executions WHERE job_key = $1 + `, jobKey) + + err = res.Scan(&status, &result, &startedAt) + if err != nil { + if err == pgx.ErrNoRows { + return "", "", time.Time{}, nil + } + return "", "", time.Time{}, fmt.Errorf("execution ledger get status %s: %w", jobKey, err) + } + return status, result, startedAt, nil +} diff --git a/internal/workers/pipeline_worker.go b/internal/workers/pipeline_worker.go index d0077ef12..6d17dfdbe 100644 --- a/internal/workers/pipeline_worker.go +++ b/internal/workers/pipeline_worker.go @@ -118,10 +118,17 @@ func (w *PipelineWorker) processJob(workerCtx context.Context, msg *ports.Durabl return } if !acquired { - w.logger.Info("skipping duplicate pipeline job", + // Check if it's already finished or just being processed by someone else. + status, _, _, getErr := w.ledger.GetStatus(workerCtx, jobKey) + if getErr == nil && status == "completed" { + w.logger.Info("skipping already completed pipeline job", + "build_id", job.BuildID, "msg_id", msg.ID) + _ = w.taskQueue.Ack(workerCtx, pipelineQueueName, pipelineGroup, msg.ID) + return + } + w.logger.Info("pipeline job is currently being processed by another worker", "build_id", job.BuildID, "msg_id", msg.ID) - _ = w.taskQueue.Ack(workerCtx, pipelineQueueName, pipelineGroup, msg.ID) - return + return // Leave unacked for redelivery/wait. 
} } @@ -129,9 +136,20 @@ func (w *PipelineWorker) processJob(workerCtx context.Context, msg *ports.Durabl defer cancel() ctx = appcontext.WithUserID(ctx, job.UserID) - build, pipeline := w.loadBuildAndPipeline(ctx, job) + build, pipeline, err := w.loadBuildAndPipeline(ctx, job) + if err != nil { + // Transient error loading build/pipeline — nack and retry. + w.logger.Error("transient error loading build/pipeline", + "build_id", job.BuildID, "error", err) + if w.ledger != nil { + _ = w.ledger.MarkFailed(workerCtx, jobKey, "transient load error") + } + _ = w.taskQueue.Nack(workerCtx, pipelineQueueName, pipelineGroup, msg.ID) + return + } + if build == nil || pipeline == nil { - // Build or pipeline not found — ack to avoid infinite retries. + // Build or pipeline truly not found — ack to avoid infinite retries. if w.ledger != nil { _ = w.ledger.MarkComplete(workerCtx, jobKey, "not_found") } @@ -176,21 +194,28 @@ func (w *PipelineWorker) processJob(workerCtx context.Context, msg *ports.Durabl } } -func (w *PipelineWorker) loadBuildAndPipeline(ctx context.Context, job domain.BuildJob) (*domain.Build, *domain.Pipeline) { +func (w *PipelineWorker) loadBuildAndPipeline(ctx context.Context, job domain.BuildJob) (*domain.Build, *domain.Pipeline, error) { build, err := w.repo.GetBuild(ctx, job.BuildID, job.UserID) - if err != nil || build == nil { + if err != nil { w.logger.Error("failed to load build", "build_id", job.BuildID, "error", err) - return nil, nil + return nil, nil, err + } + if build == nil { + return nil, nil, nil } pipeline, err := w.repo.GetPipeline(ctx, job.PipelineID, job.UserID) - if err != nil || pipeline == nil { + if err != nil { w.logger.Error("failed to load pipeline", "pipeline_id", job.PipelineID, "error", err) + w.failBuild(ctx, build, "pipeline load error: "+err.Error()) + return nil, nil, err + } + if pipeline == nil { w.failBuild(ctx, build, "pipeline not found") - return nil, nil + return build, nil, nil } - return build, pipeline + return build, pipeline, nil } func (w *PipelineWorker) markBuildRunning(ctx context.Context, build *domain.Build) bool { diff --git a/internal/workers/provision_worker.go b/internal/workers/provision_worker.go index dee39a94d..d694abea6 100644 --- a/internal/workers/provision_worker.go +++ b/internal/workers/provision_worker.go @@ -21,7 +21,8 @@ const ( provisionGroup = "provision_workers" provisionMaxWorkers = 20 // How long a message can sit in PEL before another consumer reclaims it. - provisionReclaimMs = 5 * 60 * 1000 // 5 minutes + // Must be longer than provisionStaleThreshold (15m) to avoid premature reclaim. + provisionReclaimMs = 20 * 60 * 1000 // 20 minutes provisionReclaimN = 10 // Stale threshold for idempotency ledger: if a "running" entry is older // than this, it is considered abandoned and can be reclaimed. @@ -129,11 +130,17 @@ func (w *ProvisionWorker) processJob(workerCtx context.Context, msg *ports.Durab return } if !acquired { - w.logger.Info("skipping duplicate provision job", + // Check if it's already finished or just being processed by someone else. 
+ status, _, _, getErr := w.ledger.GetStatus(workerCtx, jobKey) + if getErr == nil && status == "completed" { + w.logger.Info("skipping already completed provision job", + "instance_id", job.InstanceID, "msg_id", msg.ID) + _ = w.taskQueue.Ack(workerCtx, provisionQueue, provisionGroup, msg.ID) + return + } + w.logger.Info("provision job is currently being processed by another worker", "instance_id", job.InstanceID, "msg_id", msg.ID) - // Already processed — ack the duplicate message. - _ = w.taskQueue.Ack(workerCtx, provisionQueue, provisionGroup, msg.ID) - return + return // Leave unacked for redelivery/wait. } } diff --git a/internal/workers/provision_worker_test.go b/internal/workers/provision_worker_test.go index a79785f7f..9d6397df7 100644 --- a/internal/workers/provision_worker_test.go +++ b/internal/workers/provision_worker_test.go @@ -166,9 +166,13 @@ func TestProvisionWorkerRun(t *testing.T) { assert.Contains(t, buf.String(), tt.wantLog) if tt.wantAcked { assert.NotEmpty(t, fq.acked, "expected message to be acked") - } - if tt.wantNacked { + assert.Empty(t, fq.nacked, "did not expect message to be nacked when acked") + } else if tt.wantNacked { assert.NotEmpty(t, fq.nacked, "expected message to be nacked") + assert.Empty(t, fq.acked, "did not expect message to be acked when nacked") + } else { + assert.Empty(t, fq.acked, "expected no ack") + assert.Empty(t, fq.nacked, "expected no nack") } }) } From 14e7fdc2018e09d31a75e0cb7f562f3e6e13bc79 Mon Sep 17 00:00:00 2001 From: jackthepunished Date: Tue, 14 Apr 2026 14:41:43 +0300 Subject: [PATCH 09/12] fix(resilience): propagate role errors and harden platform primitives --- cmd/api/main.go | 8 ++-- cmd/api/main_test.go | 20 ++++++-- internal/platform/bulkhead.go | 6 +++ internal/platform/circuit_breaker.go | 70 +++++++++++++++++++--------- internal/platform/retry.go | 9 +--- internal/platform/retry_test.go | 6 +-- 6 files changed, 80 insertions(+), 39 deletions(-) diff --git a/cmd/api/main.go b/cmd/api/main.go index 569086b85..e62cf9000 100644 --- a/cmd/api/main.go +++ b/cmd/api/main.go @@ -166,11 +166,10 @@ func run() error { r.Use(otelgin.Middleware("compute-api")) } - runApplication(deps, cfg, logger, r, workers) - return nil + return runApplication(deps, cfg, logger, r, workers) } -func runApplication(deps AppDeps, cfg *platform.Config, logger *slog.Logger, r *gin.Engine, workers *setup.Workers) { +func runApplication(deps AppDeps, cfg *platform.Config, logger *slog.Logger, r *gin.Engine, workers *setup.Workers) error { role := os.Getenv("ROLE") if role == "" { role = "all" @@ -179,7 +178,7 @@ func runApplication(deps AppDeps, cfg *platform.Config, logger *slog.Logger, r * validRoles := map[string]bool{"api": true, "worker": true, "all": true} if !validRoles[role] { logger.Error("invalid ROLE value, must be one of: api, worker, all", "role", role) - return + return fmt.Errorf("invalid ROLE value %q, must be one of: api, worker, all", role) } logger.Info("starting with role", "role", role) @@ -221,6 +220,7 @@ func runApplication(deps AppDeps, cfg *platform.Config, logger *slog.Logger, r * workerCancel() wg.Wait() logger.Info("shutdown complete") + return nil } type runner interface { diff --git a/cmd/api/main_test.go b/cmd/api/main_test.go index 321281b79..a311b4ad8 100644 --- a/cmd/api/main_test.go +++ b/cmd/api/main_test.go @@ -158,7 +158,10 @@ func TestRunApplicationApiRoleStartsAndShutsDown(t *testing.T) { }() } - runApplication(deps, &platform.Config{Port: "0"}, logger, gin.New(), &setup.Workers{}) + err := runApplication(deps, 
&platform.Config{Port: "0"}, logger, gin.New(), &setup.Workers{}) + if err != nil { + t.Fatalf("runApplication returned error: %v", err) + } select { case <-shutdownCalled: @@ -193,7 +196,10 @@ func TestRunApplicationWorkerRoleDoesNotStartHTTP(t *testing.T) { }() } - runApplication(deps, &platform.Config{Port: "0"}, logger, gin.New(), &setup.Workers{}) + err := runApplication(deps, &platform.Config{Port: "0"}, logger, gin.New(), &setup.Workers{}) + if err != nil { + t.Fatalf("runApplication returned error: %v", err) + } // If we reach here without t.Fatalf, the test passes — no HTTP server was touched. } @@ -227,7 +233,10 @@ func TestRunApplicationDefaultsToAllRole(t *testing.T) { }() } - runApplication(deps, &platform.Config{Port: "0"}, logger, gin.New(), &setup.Workers{}) + err := runApplication(deps, &platform.Config{Port: "0"}, logger, gin.New(), &setup.Workers{}) + if err != nil { + t.Fatalf("runApplication returned error: %v", err) + } select { case <-shutdownCalled: @@ -255,7 +264,10 @@ func TestRunApplicationInvalidRoleReturnsEarly(t *testing.T) { } // Should return immediately without starting anything - runApplication(deps, &platform.Config{Port: "0"}, logger, gin.New(), &setup.Workers{}) + err := runApplication(deps, &platform.Config{Port: "0"}, logger, gin.New(), &setup.Workers{}) + if err == nil { + t.Fatalf("expected error for invalid role") + } } // Stub helpers below keep main.go testable without altering production behavior. diff --git a/internal/platform/bulkhead.go b/internal/platform/bulkhead.go index 8f6cbf86a..033f5ee8f 100644 --- a/internal/platform/bulkhead.go +++ b/internal/platform/bulkhead.go @@ -50,6 +50,12 @@ func (b *Bulkhead) Execute(ctx context.Context, fn func() error) error { } func (b *Bulkhead) acquire(ctx context.Context) error { + select { + case <-ctx.Done(): + return ErrBulkheadFull + default: + } + if b.timeout > 0 { timer := time.NewTimer(b.timeout) defer timer.Stop() diff --git a/internal/platform/circuit_breaker.go b/internal/platform/circuit_breaker.go index d310367ab..e710e885f 100644 --- a/internal/platform/circuit_breaker.go +++ b/internal/platform/circuit_breaker.go @@ -118,37 +118,50 @@ func (cb *CircuitBreaker) Execute(fn func() error) error { func (cb *CircuitBreaker) allowRequest() bool { cb.mu.Lock() - defer cb.mu.Unlock() + var cbFunc StateChangeFunc + var name string + var from, to State + var changed bool + allowed := false switch cb.state { case StateClosed: - return true + allowed = true case StateOpen: if time.Since(cb.lastFailure) <= cb.resetTimeout { - return false + break } // Transition to half-open; only allow one probe at a time. if cb.halfOpenInFlight { - return false + break } - cb.transitionLocked(StateHalfOpen) + cbFunc, name, from, to, changed = cb.transitionLocked(StateHalfOpen) cb.halfOpenInFlight = true cb.successCount = 0 - return true + allowed = true case StateHalfOpen: // Allow additional requests only if no probe is in flight. 
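+		// Only the caller that flips halfOpenInFlight becomes the probe; everyone else is turned away.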
if cb.halfOpenInFlight { - return false + break } cb.halfOpenInFlight = true - return true + allowed = true + } + cb.mu.Unlock() + + if changed && cbFunc != nil { + cbFunc(name, from, to) } - return false + + return allowed } func (cb *CircuitBreaker) recordFailure() { cb.mu.Lock() - defer cb.mu.Unlock() + var cbFunc StateChangeFunc + var name string + var from, to State + var changed bool cb.halfOpenInFlight = false cb.failureCount++ @@ -157,17 +170,25 @@ func (cb *CircuitBreaker) recordFailure() { switch cb.state { case StateClosed: if cb.failureCount >= cb.threshold { - cb.transitionLocked(StateOpen) + cbFunc, name, from, to, changed = cb.transitionLocked(StateOpen) } case StateHalfOpen: // Probe failed — go back to open. - cb.transitionLocked(StateOpen) + cbFunc, name, from, to, changed = cb.transitionLocked(StateOpen) + } + cb.mu.Unlock() + + if changed && cbFunc != nil { + cbFunc(name, from, to) } } func (cb *CircuitBreaker) recordSuccess() { cb.mu.Lock() - defer cb.mu.Unlock() + var cbFunc StateChangeFunc + var name string + var from, to State + var changed bool cb.halfOpenInFlight = false @@ -177,36 +198,43 @@ func (cb *CircuitBreaker) recordSuccess() { if cb.successCount >= cb.successRequired { cb.failureCount = 0 cb.successCount = 0 - cb.transitionLocked(StateClosed) + cbFunc, name, from, to, changed = cb.transitionLocked(StateClosed) } default: cb.failureCount = 0 cb.state = StateClosed } + cb.mu.Unlock() + + if changed && cbFunc != nil { + cbFunc(name, from, to) + } } // transitionLocked changes state and fires the callback. Must be called // with cb.mu held. The callback is invoked synchronously; implementations // must not block or acquire cb.mu. -func (cb *CircuitBreaker) transitionLocked(to State) { +func (cb *CircuitBreaker) transitionLocked(to State) (StateChangeFunc, string, State, State, bool) { from := cb.state if from == to { - return + return nil, "", from, to, false } cb.state = to - if cb.onStateChange != nil { - cb.onStateChange(cb.name, from, to) - } + return cb.onStateChange, cb.name, from, to, true } // Reset clears the circuit breaker state. func (cb *CircuitBreaker) Reset() { cb.mu.Lock() - defer cb.mu.Unlock() + cbFunc, name, from, to, changed := cb.transitionLocked(StateClosed) cb.failureCount = 0 cb.successCount = 0 cb.halfOpenInFlight = false - cb.transitionLocked(StateClosed) + cb.mu.Unlock() + + if changed && cbFunc != nil { + cbFunc(name, from, to) + } } // GetState returns the current state of the circuit breaker. diff --git a/internal/platform/retry.go b/internal/platform/retry.go index 86391a0bf..c465c8be1 100644 --- a/internal/platform/retry.go +++ b/internal/platform/retry.go @@ -86,11 +86,6 @@ func backoffDelay(attempt int, base, max time.Duration, mult float64) time.Durat if calculated > max || calculated <= 0 { calculated = max } - // Full jitter: uniform random in [base/2, calculated]. - floor := base / 2 - if floor > calculated { - floor = calculated - } - jittered := floor + time.Duration(rand.Int64N(int64(calculated-floor+1))) - return jittered + // Full jitter: uniform random in [0, calculated]. 
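+	// e.g. attempt 3 with base=100ms and mult=2.0 yields calculated=800ms, so the delay is uniform in [0, 800ms].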
+ return time.Duration(rand.Int64N(int64(calculated) + 1)) } diff --git a/internal/platform/retry_test.go b/internal/platform/retry_test.go index 0eea21edf..99ce1344a 100644 --- a/internal/platform/retry_test.go +++ b/internal/platform/retry_test.go @@ -103,17 +103,17 @@ func TestBackoffDelay(t *testing.T) { base := 100 * time.Millisecond max := 5 * time.Second - // Attempt 0: jitter in [base/2, base] + // Attempt 0: jitter in [0, base] for i := 0; i < 100; i++ { d := backoffDelay(0, base, max, 2.0) - assert.GreaterOrEqual(t, d, base/2) + assert.GreaterOrEqual(t, d, time.Duration(0)) assert.LessOrEqual(t, d, base) } // Attempt 3: calculated = 100ms * 2^3 = 800ms for i := 0; i < 100; i++ { d := backoffDelay(3, base, max, 2.0) - assert.GreaterOrEqual(t, d, base/2) + assert.GreaterOrEqual(t, d, time.Duration(0)) assert.LessOrEqual(t, d, 800*time.Millisecond) } } From 3adf7890263ba15b1624c4f62f2b5424e0dfef1f Mon Sep 17 00:00:00 2001 From: jackthepunished Date: Tue, 14 Apr 2026 14:42:07 +0300 Subject: [PATCH 10/12] fix(ha): tighten ledger transitions and durable queue acknowledgements --- .../repositories/postgres/execution_ledger.go | 28 ++++++++++++----- .../repositories/postgres/leader_elector.go | 11 +++++-- .../repositories/redis/durable_task_queue.go | 30 +++++++++++++++++-- 3 files changed, 56 insertions(+), 13 deletions(-) diff --git a/internal/repositories/postgres/execution_ledger.go b/internal/repositories/postgres/execution_ledger.go index d4bc6777f..24f3e401b 100644 --- a/internal/repositories/postgres/execution_ledger.go +++ b/internal/repositories/postgres/execution_ledger.go @@ -79,7 +79,7 @@ func (l *PgExecutionLedger) TryAcquire(ctx context.Context, jobKey string, stale return tag.RowsAffected() > 0, nil case "failed": // Retry a previously failed job. - _, err = l.db.Exec(ctx, ` + tag, err := l.db.Exec(ctx, ` UPDATE job_executions SET started_at = NOW(), status = 'running', completed_at = NULL, result = NULL WHERE job_key = $1 AND status = 'failed' @@ -87,7 +87,7 @@ func (l *PgExecutionLedger) TryAcquire(ctx context.Context, jobKey string, stale if err != nil { return false, fmt.Errorf("execution ledger retry %s: %w", jobKey, err) } - return true, nil + return tag.RowsAffected() > 0, nil default: return false, fmt.Errorf("execution ledger unknown status %q for %s", status, jobKey) } @@ -95,22 +95,34 @@ func (l *PgExecutionLedger) TryAcquire(ctx context.Context, jobKey string, stale // MarkComplete marks a job as successfully completed. func (l *PgExecutionLedger) MarkComplete(ctx context.Context, jobKey string, result string) error { - _, err := l.db.Exec(ctx, ` + tag, err := l.db.Exec(ctx, ` UPDATE job_executions SET status = 'completed', completed_at = NOW(), result = $2 - WHERE job_key = $1 + WHERE job_key = $1 AND status = 'running' `, jobKey, result) - return err + if err != nil { + return err + } + if tag.RowsAffected() == 0 { + return fmt.Errorf("execution ledger mark complete %s: no running row updated", jobKey) + } + return nil } // MarkFailed marks a job as failed, allowing future retries. 
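+// Like MarkComplete, it only transitions rows that are currently 'running'.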
func (l *PgExecutionLedger) MarkFailed(ctx context.Context, jobKey string, reason string) error { - _, err := l.db.Exec(ctx, ` + tag, err := l.db.Exec(ctx, ` UPDATE job_executions SET status = 'failed', completed_at = NOW(), result = $2 - WHERE job_key = $1 + WHERE job_key = $1 AND status = 'running' `, jobKey, reason) - return err + if err != nil { + return err + } + if tag.RowsAffected() == 0 { + return fmt.Errorf("execution ledger mark failed %s: no running row updated", jobKey) + } + return nil } // GetStatus returns the current status, result and start time of a job. diff --git a/internal/repositories/postgres/leader_elector.go b/internal/repositories/postgres/leader_elector.go index a3aad2210..ac99bb0cb 100644 --- a/internal/repositories/postgres/leader_elector.go +++ b/internal/repositories/postgres/leader_elector.go @@ -122,7 +122,9 @@ func (e *PgLeaderElector) RunAsLeader(ctx context.Context, key string, fn func(c fnCtx, fnCancel := context.WithCancel(ctx) defer fnCancel() defer func() { - if err := e.Release(context.Background(), key); err != nil { + releaseCtx, releaseCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer releaseCancel() + if err := e.Release(releaseCtx, key); err != nil { e.logger.Error("failed to release leadership", "key", key, "error", err) } }() @@ -169,7 +171,12 @@ func (e *PgLeaderElector) heartbeat(ctx context.Context, key string, cancel cont } if stillHeld { // We re-acquired (re-entrant), so unlock the extra lock count - _, _ = e.db.Exec(ctx, "SELECT pg_advisory_unlock($1)", lockID) + if _, unlockErr := e.db.Exec(ctx, "SELECT pg_advisory_unlock($1)", lockID); unlockErr != nil { + e.logger.Error("failed to release re-entrant heartbeat lock", + "key", key, "error", unlockErr) + cancel() + return + } } else { // We lost the lock e.logger.Error("leadership lost", "key", key) diff --git a/internal/repositories/redis/durable_task_queue.go b/internal/repositories/redis/durable_task_queue.go index ac02d22a4..0aae2c066 100644 --- a/internal/repositories/redis/durable_task_queue.go +++ b/internal/repositories/redis/durable_task_queue.go @@ -85,7 +85,13 @@ func (q *durableTaskQueue) Dequeue(ctx context.Context, queueName string) (strin } msg := res[0].Messages[0] // Auto-delete since legacy callers don't ack. 
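+	// This keeps the legacy path at-most-once: the entry is removed before the caller has processed it.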
- q.client.XDel(ctx, queueName, msg.ID) + deleted, delErr := q.client.XDel(ctx, queueName, msg.ID).Result() + if delErr != nil { + return "", fmt.Errorf("durable dequeue xdel %s/%s: %w", queueName, msg.ID, delErr) + } + if deleted == 0 { + return "", fmt.Errorf("durable dequeue xdel %s/%s: no message deleted", queueName, msg.ID) + } payload, _ := msg.Values["payload"].(string) return payload, nil } @@ -132,7 +138,23 @@ func (q *durableTaskQueue) Receive(ctx context.Context, queueName, groupName, co } func (q *durableTaskQueue) Ack(ctx context.Context, queueName, groupName, messageID string) error { - return q.client.XAck(ctx, queueName, groupName, messageID).Err() + acked, err := q.client.XAck(ctx, queueName, groupName, messageID).Result() + if err != nil { + return fmt.Errorf("ack %s/%s/%s: %w", queueName, groupName, messageID, err) + } + if acked == 0 { + return fmt.Errorf("ack %s/%s/%s: message not pending", queueName, groupName, messageID) + } + + deleted, delErr := q.client.XDel(ctx, queueName, messageID).Result() + if delErr != nil { + return fmt.Errorf("ack xdel %s/%s: %w", queueName, messageID, delErr) + } + if deleted == 0 { + return fmt.Errorf("ack xdel %s/%s: no message deleted", queueName, messageID) + } + + return nil } func (q *durableTaskQueue) Nack(ctx context.Context, queueName, groupName, messageID string) error { @@ -165,7 +187,9 @@ func (q *durableTaskQueue) ReclaimStale(ctx context.Context, queueName, groupNam // Dead-letter messages that exceeded max retries. if xmsg.DeliveredCount > 0 && xmsg.DeliveredCount > q.maxRetries { - _ = q.deadLetter(ctx, queueName, groupName, xmsg) + if dlqErr := q.deadLetter(ctx, queueName, groupName, xmsg); dlqErr != nil { + return nil, fmt.Errorf("dead-letter %s/%s/%s: %w", queueName, groupName, xmsg.ID, dlqErr) + } continue } From 567850f60a4da86ac3f58ae162746af14cf4da7e Mon Sep 17 00:00:00 2001 From: jackthepunished Date: Tue, 14 Apr 2026 14:43:02 +0300 Subject: [PATCH 11/12] fix(workers): improve context propagation and queue/ledger observability --- internal/workers/cluster_worker.go | 66 ++++++++++++++++++-------- internal/workers/pipeline_worker.go | 69 ++++++++++++++++++++-------- internal/workers/provision_worker.go | 45 +++++++++++------- 3 files changed, 126 insertions(+), 54 deletions(-) diff --git a/internal/workers/cluster_worker.go b/internal/workers/cluster_worker.go index 149eab9b8..cb88896f3 100644 --- a/internal/workers/cluster_worker.go +++ b/internal/workers/cluster_worker.go @@ -22,6 +22,7 @@ const ( clusterReclaimMs = 5 * 60 * 1000 // 5 minutes clusterReclaimN = 10 clusterStaleThreshold = 15 * time.Minute + clusterReceiveBackoff = 1 * time.Second ) // ClusterWorker handles background tasks for Kubernetes cluster lifecycle management. @@ -36,7 +37,11 @@ type ClusterWorker struct { // NewClusterWorker creates a new ClusterWorker. 
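+// The hostname-based consumer name identifies this replica within the queue's consumer group.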
func NewClusterWorker(repo ports.ClusterRepository, provisioner ports.ClusterProvisioner, taskQueue ports.DurableTaskQueue, ledger ports.ExecutionLedger, logger *slog.Logger) *ClusterWorker { - hostname, _ := os.Hostname() + hostname, err := os.Hostname() + if err != nil { + logger.Warn("failed to get hostname, using fallback", "error", err) + hostname = "cluster-worker" + } if hostname == "" { hostname = "cluster-worker" } @@ -75,7 +80,7 @@ func (w *ClusterWorker) Run(ctx context.Context, wg *sync.WaitGroup) { msg, err := w.taskQueue.Receive(ctx, clusterQueue, clusterGroup, w.consumerName) if err != nil { w.logger.Error("failed to receive cluster job", "error", err) - time.Sleep(1 * time.Second) + time.Sleep(clusterReceiveBackoff) continue } if msg == nil { @@ -86,7 +91,7 @@ func (w *ClusterWorker) Run(ctx context.Context, wg *sync.WaitGroup) { if err := json.Unmarshal([]byte(msg.Payload), &job); err != nil { w.logger.Error("failed to unmarshal cluster job", "error", err, "msg_id", msg.ID) - _ = w.taskQueue.Ack(ctx, clusterQueue, clusterGroup, msg.ID) + w.ackWithLog(ctx, msg.ID, "cluster poison message") continue } @@ -114,27 +119,30 @@ func (w *ClusterWorker) processJob(workerCtx context.Context, msg *ports.Durable if err != nil { w.logger.Error("execution ledger error", "cluster_id", job.ClusterID, "msg_id", msg.ID, "error", err) - _ = w.taskQueue.Nack(workerCtx, clusterQueue, clusterGroup, msg.ID) + w.nackWithLog(workerCtx, msg.ID, "ledger try_acquire failed") return } if !acquired { w.logger.Info("skipping duplicate cluster job", "cluster_id", job.ClusterID, "type", job.Type, "msg_id", msg.ID) - _ = w.taskQueue.Ack(workerCtx, clusterQueue, clusterGroup, msg.ID) + w.ackWithLog(workerCtx, msg.ID, "duplicate cluster job") return } } - ctx := appcontext.WithUserID(context.Background(), job.UserID) + ctx := appcontext.WithUserID(workerCtx, job.UserID) cluster, err := w.repo.GetByID(ctx, job.ClusterID) if err != nil { w.logger.Error("failed to fetch cluster for job", "cluster_id", job.ClusterID, "msg_id", msg.ID, "error", err) if w.ledger != nil { - _ = w.ledger.MarkFailed(workerCtx, jobKey, err.Error()) + if ledgerErr := w.ledger.MarkFailed(workerCtx, jobKey, err.Error()); ledgerErr != nil { + w.logger.Warn("failed to mark cluster job failed in ledger", + "cluster_id", job.ClusterID, "msg_id", msg.ID, "error", ledgerErr) + } } - _ = w.taskQueue.Nack(workerCtx, clusterQueue, clusterGroup, msg.ID) + w.nackWithLog(workerCtx, msg.ID, "cluster fetch failed") return } if cluster == nil { @@ -142,9 +150,12 @@ func (w *ClusterWorker) processJob(workerCtx context.Context, msg *ports.Durable "cluster_id", job.ClusterID, "msg_id", msg.ID) // Ack — cluster was deleted, nothing to do. 
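+		// Recording completion in the ledger prevents a reclaimed copy of this job from re-running.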
if w.ledger != nil { - _ = w.ledger.MarkComplete(workerCtx, jobKey, "cluster_not_found") + if ledgerErr := w.ledger.MarkComplete(workerCtx, jobKey, "cluster_not_found"); ledgerErr != nil { + w.logger.Warn("failed to mark cluster job complete in ledger", + "cluster_id", job.ClusterID, "msg_id", msg.ID, "error", ledgerErr) + } } - _ = w.taskQueue.Ack(workerCtx, clusterQueue, clusterGroup, msg.ID) + w.ackWithLog(workerCtx, msg.ID, "cluster not found") return } @@ -156,6 +167,8 @@ func (w *ClusterWorker) processJob(workerCtx context.Context, msg *ports.Durable processErr = w.handleDeprovision(ctx, cluster) case domain.ClusterJobUpgrade: processErr = w.handleUpgrade(ctx, cluster, job.Version) + default: + processErr = fmt.Errorf("unsupported cluster job type %q for cluster %s", job.Type, job.ClusterID) } if processErr != nil { @@ -163,19 +176,22 @@ func (w *ClusterWorker) processJob(workerCtx context.Context, msg *ports.Durable "cluster_id", job.ClusterID, "type", job.Type, "msg_id", msg.ID, "error", processErr) if w.ledger != nil { - _ = w.ledger.MarkFailed(workerCtx, jobKey, processErr.Error()) + if ledgerErr := w.ledger.MarkFailed(workerCtx, jobKey, processErr.Error()); ledgerErr != nil { + w.logger.Warn("failed to mark cluster job failed in ledger", + "cluster_id", job.ClusterID, "msg_id", msg.ID, "error", ledgerErr) + } } - _ = w.taskQueue.Nack(workerCtx, clusterQueue, clusterGroup, msg.ID) + w.nackWithLog(workerCtx, msg.ID, "cluster job processing failed") return } if w.ledger != nil { - _ = w.ledger.MarkComplete(workerCtx, jobKey, "ok") - } - if err := w.taskQueue.Ack(workerCtx, clusterQueue, clusterGroup, msg.ID); err != nil { - w.logger.Error("failed to ack cluster job", - "cluster_id", job.ClusterID, "msg_id", msg.ID, "error", err) + if ledgerErr := w.ledger.MarkComplete(workerCtx, jobKey, "ok"); ledgerErr != nil { + w.logger.Warn("failed to mark cluster job complete in ledger", + "cluster_id", job.ClusterID, "msg_id", msg.ID, "error", ledgerErr) + } } + w.ackWithLog(workerCtx, msg.ID, "cluster job success") } func (w *ClusterWorker) handleProvision(ctx context.Context, cluster *domain.Cluster) error { @@ -260,7 +276,7 @@ func (w *ClusterWorker) reclaimLoop(ctx context.Context, sem chan struct{}) { if err := json.Unmarshal([]byte(m.Payload), &job); err != nil { w.logger.Error("failed to unmarshal reclaimed cluster job", "msg_id", m.ID, "error", err) - _ = w.taskQueue.Ack(ctx, clusterQueue, clusterGroup, m.ID) + w.ackWithLog(ctx, m.ID, "reclaimed cluster poison message") continue } w.logger.Info("reclaimed stale cluster job", @@ -276,3 +292,17 @@ func (w *ClusterWorker) reclaimLoop(ctx context.Context, sem chan struct{}) { } } } + +func (w *ClusterWorker) ackWithLog(ctx context.Context, messageID string, reason string) { + if err := w.taskQueue.Ack(ctx, clusterQueue, clusterGroup, messageID); err != nil { + w.logger.Warn("failed to ack cluster job", + "msg_id", messageID, "reason", reason, "error", err) + } +} + +func (w *ClusterWorker) nackWithLog(ctx context.Context, messageID string, reason string) { + if err := w.taskQueue.Nack(ctx, clusterQueue, clusterGroup, messageID); err != nil { + w.logger.Warn("failed to nack cluster job", + "msg_id", messageID, "reason", reason, "error", err) + } +} diff --git a/internal/workers/pipeline_worker.go b/internal/workers/pipeline_worker.go index 6d17dfdbe..b5e3d8e3f 100644 --- a/internal/workers/pipeline_worker.go +++ b/internal/workers/pipeline_worker.go @@ -92,7 +92,7 @@ func (w *PipelineWorker) Run(ctx context.Context, wg *sync.WaitGroup) { 
if err := json.Unmarshal([]byte(msg.Payload), &job); err != nil { w.logger.Error("failed to unmarshal build job", "error", err, "msg_id", msg.ID) - _ = w.taskQueue.Ack(ctx, pipelineQueueName, pipelineGroup, msg.ID) + w.ackWithLog(ctx, msg.ID, "pipeline poison message") continue } @@ -114,7 +114,7 @@ func (w *PipelineWorker) processJob(workerCtx context.Context, msg *ports.Durabl if err != nil { w.logger.Error("execution ledger error", "build_id", job.BuildID, "msg_id", msg.ID, "error", err) - _ = w.taskQueue.Nack(workerCtx, pipelineQueueName, pipelineGroup, msg.ID) + w.nackWithLog(workerCtx, msg.ID, "ledger try_acquire failed") return } if !acquired { @@ -123,7 +123,7 @@ func (w *PipelineWorker) processJob(workerCtx context.Context, msg *ports.Durabl if getErr == nil && status == "completed" { w.logger.Info("skipping already completed pipeline job", "build_id", job.BuildID, "msg_id", msg.ID) - _ = w.taskQueue.Ack(workerCtx, pipelineQueueName, pipelineGroup, msg.ID) + w.ackWithLog(workerCtx, msg.ID, "pipeline already completed") return } w.logger.Info("pipeline job is currently being processed by another worker", @@ -132,7 +132,7 @@ func (w *PipelineWorker) processJob(workerCtx context.Context, msg *ports.Durabl } } - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute) + ctx, cancel := context.WithTimeout(workerCtx, 30*time.Minute) defer cancel() ctx = appcontext.WithUserID(ctx, job.UserID) @@ -142,56 +142,71 @@ func (w *PipelineWorker) processJob(workerCtx context.Context, msg *ports.Durabl w.logger.Error("transient error loading build/pipeline", "build_id", job.BuildID, "error", err) if w.ledger != nil { - _ = w.ledger.MarkFailed(workerCtx, jobKey, "transient load error") + if ledgerErr := w.ledger.MarkFailed(workerCtx, jobKey, "transient load error"); ledgerErr != nil { + w.logger.Warn("failed to mark pipeline job failed in ledger", + "build_id", job.BuildID, "msg_id", msg.ID, "error", ledgerErr) + } } - _ = w.taskQueue.Nack(workerCtx, pipelineQueueName, pipelineGroup, msg.ID) + w.nackWithLog(workerCtx, msg.ID, "transient pipeline load error") return } if build == nil || pipeline == nil { // Build or pipeline truly not found — ack to avoid infinite retries. 
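+		// loadBuildAndPipeline returned nil without an error, so the records are gone rather than temporarily unreadable.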
if w.ledger != nil { - _ = w.ledger.MarkComplete(workerCtx, jobKey, "not_found") + if ledgerErr := w.ledger.MarkComplete(workerCtx, jobKey, "not_found"); ledgerErr != nil { + w.logger.Warn("failed to mark pipeline job complete in ledger", + "build_id", job.BuildID, "msg_id", msg.ID, "error", ledgerErr) + } } - _ = w.taskQueue.Ack(workerCtx, pipelineQueueName, pipelineGroup, msg.ID) + w.ackWithLog(workerCtx, msg.ID, "pipeline build/pipeline not found") return } if !w.markBuildRunning(ctx, build) { if w.ledger != nil { - _ = w.ledger.MarkFailed(workerCtx, jobKey, "failed to mark build running") + if ledgerErr := w.ledger.MarkFailed(workerCtx, jobKey, "failed to mark build running"); ledgerErr != nil { + w.logger.Warn("failed to mark pipeline job failed in ledger", + "build_id", job.BuildID, "msg_id", msg.ID, "error", ledgerErr) + } } - _ = w.taskQueue.Nack(workerCtx, pipelineQueueName, pipelineGroup, msg.ID) + w.nackWithLog(workerCtx, msg.ID, "mark build running failed") return } if len(pipeline.Config.Stages) == 0 { w.failBuild(ctx, build, "pipeline has no stages") if w.ledger != nil { - _ = w.ledger.MarkComplete(workerCtx, jobKey, "no_stages") + if ledgerErr := w.ledger.MarkComplete(workerCtx, jobKey, "no_stages"); ledgerErr != nil { + w.logger.Warn("failed to mark pipeline job complete in ledger", + "build_id", job.BuildID, "msg_id", msg.ID, "error", ledgerErr) + } } - _ = w.taskQueue.Ack(workerCtx, pipelineQueueName, pipelineGroup, msg.ID) + w.ackWithLog(workerCtx, msg.ID, "pipeline has no stages") return } if !w.executePipeline(ctx, build, pipeline) { // Build failed but was processed — ack the message. if w.ledger != nil { - _ = w.ledger.MarkComplete(workerCtx, jobKey, "build_failed") + if ledgerErr := w.ledger.MarkComplete(workerCtx, jobKey, "build_failed"); ledgerErr != nil { + w.logger.Warn("failed to mark pipeline job complete in ledger", + "build_id", job.BuildID, "msg_id", msg.ID, "error", ledgerErr) + } } - _ = w.taskQueue.Ack(workerCtx, pipelineQueueName, pipelineGroup, msg.ID) + w.ackWithLog(workerCtx, msg.ID, "pipeline execution failed") return } w.markBuildSucceeded(ctx, build) if w.ledger != nil { - _ = w.ledger.MarkComplete(workerCtx, jobKey, "ok") - } - if err := w.taskQueue.Ack(workerCtx, pipelineQueueName, pipelineGroup, msg.ID); err != nil { - w.logger.Error("failed to ack pipeline job", - "build_id", job.BuildID, "msg_id", msg.ID, "error", err) + if ledgerErr := w.ledger.MarkComplete(workerCtx, jobKey, "ok"); ledgerErr != nil { + w.logger.Warn("failed to mark pipeline job complete in ledger", + "build_id", job.BuildID, "msg_id", msg.ID, "error", ledgerErr) + } } + w.ackWithLog(workerCtx, msg.ID, "pipeline job success") } func (w *PipelineWorker) loadBuildAndPipeline(ctx context.Context, job domain.BuildJob) (*domain.Build, *domain.Pipeline, error) { @@ -390,7 +405,7 @@ func (w *PipelineWorker) reclaimLoop(ctx context.Context, sem chan struct{}) { if err := json.Unmarshal([]byte(m.Payload), &job); err != nil { w.logger.Error("failed to unmarshal reclaimed pipeline job", "msg_id", m.ID, "error", err) - _ = w.taskQueue.Ack(ctx, pipelineQueueName, pipelineGroup, m.ID) + w.ackWithLog(ctx, m.ID, "reclaimed pipeline poison message") continue } w.logger.Info("reclaimed stale pipeline job", @@ -406,3 +421,17 @@ func (w *PipelineWorker) reclaimLoop(ctx context.Context, sem chan struct{}) { } } } + +func (w *PipelineWorker) ackWithLog(ctx context.Context, messageID string, reason string) { + if err := w.taskQueue.Ack(ctx, pipelineQueueName, pipelineGroup, messageID); err != nil 
{ + w.logger.Warn("failed to ack pipeline job", + "msg_id", messageID, "reason", reason, "error", err) + } +} + +func (w *PipelineWorker) nackWithLog(ctx context.Context, messageID string, reason string) { + if err := w.taskQueue.Nack(ctx, pipelineQueueName, pipelineGroup, messageID); err != nil { + w.logger.Warn("failed to nack pipeline job", + "msg_id", messageID, "reason", reason, "error", err) + } +} diff --git a/internal/workers/provision_worker.go b/internal/workers/provision_worker.go index d694abea6..a9cfc217f 100644 --- a/internal/workers/provision_worker.go +++ b/internal/workers/provision_worker.go @@ -97,7 +97,7 @@ func (w *ProvisionWorker) Run(ctx context.Context, wg *sync.WaitGroup) { w.logger.Error("failed to unmarshal provision job", "error", err, "msg_id", msg.ID) // Ack poison messages so they don't block the queue. - _ = w.taskQueue.Ack(ctx, provisionQueue, provisionGroup, msg.ID) + w.ackWithLog(ctx, msg.ID, "provision poison message") continue } @@ -126,7 +126,7 @@ func (w *ProvisionWorker) processJob(workerCtx context.Context, msg *ports.Durab w.logger.Error("execution ledger error", "instance_id", job.InstanceID, "msg_id", msg.ID, "error", err) // On ledger error, nack to retry later. - _ = w.taskQueue.Nack(workerCtx, provisionQueue, provisionGroup, msg.ID) + w.nackWithLog(workerCtx, msg.ID, "ledger try_acquire failed") return } if !acquired { @@ -135,7 +135,7 @@ func (w *ProvisionWorker) processJob(workerCtx context.Context, msg *ports.Durab if getErr == nil && status == "completed" { w.logger.Info("skipping already completed provision job", "instance_id", job.InstanceID, "msg_id", msg.ID) - _ = w.taskQueue.Ack(workerCtx, provisionQueue, provisionGroup, msg.ID) + w.ackWithLog(workerCtx, msg.ID, "provision already completed") return } w.logger.Info("provision job is currently being processed by another worker", @@ -145,8 +145,7 @@ func (w *ProvisionWorker) processJob(workerCtx context.Context, msg *ports.Durab } // Root context for background task with 10-minute safety timeout. - baseCtx := context.Background() - ctx, cancel := context.WithTimeout(baseCtx, 10*time.Minute) + ctx, cancel := context.WithTimeout(workerCtx, 10*time.Minute) defer cancel() // Inject User and Tenant IDs for repository access control. @@ -162,10 +161,13 @@ func (w *ProvisionWorker) processJob(workerCtx context.Context, msg *ports.Durab ) // Mark failed in the ledger so it can be retried. if w.ledger != nil { - _ = w.ledger.MarkFailed(workerCtx, jobKey, err.Error()) + if ledgerErr := w.ledger.MarkFailed(workerCtx, jobKey, err.Error()); ledgerErr != nil { + w.logger.Warn("failed to mark provision job failed in ledger", + "instance_id", job.InstanceID, "msg_id", msg.ID, "error", ledgerErr) + } } // Nack: leave message in PEL for reclaim/retry. - _ = w.taskQueue.Nack(workerCtx, provisionQueue, provisionGroup, msg.ID) + w.nackWithLog(workerCtx, msg.ID, "provision failed") return } @@ -176,17 +178,14 @@ func (w *ProvisionWorker) processJob(workerCtx context.Context, msg *ports.Durab // Mark completed in ledger (prevents duplicate execution). if w.ledger != nil { - _ = w.ledger.MarkComplete(workerCtx, jobKey, "ok") + if ledgerErr := w.ledger.MarkComplete(workerCtx, jobKey, "ok"); ledgerErr != nil { + w.logger.Warn("failed to mark provision job complete in ledger", + "instance_id", job.InstanceID, "msg_id", msg.ID, "error", ledgerErr) + } } // Acknowledge — message is permanently consumed. 
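+	// ackWithLog records ack failures instead of dropping them silently.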
- if err := w.taskQueue.Ack(workerCtx, provisionQueue, provisionGroup, msg.ID); err != nil { - w.logger.Error("failed to ack provision job", - "instance_id", job.InstanceID, - "msg_id", msg.ID, - "error", err, - ) - } + w.ackWithLog(workerCtx, msg.ID, "provision success") } // reclaimLoop periodically reclaims messages stuck in the PEL from crashed @@ -210,7 +209,7 @@ func (w *ProvisionWorker) reclaimLoop(ctx context.Context, sem chan struct{}) { if err := json.Unmarshal([]byte(m.Payload), &job); err != nil { w.logger.Error("failed to unmarshal reclaimed provision job", "msg_id", m.ID, "error", err) - _ = w.taskQueue.Ack(ctx, provisionQueue, provisionGroup, m.ID) + w.ackWithLog(ctx, m.ID, "reclaimed provision poison message") continue } w.logger.Info("reclaimed stale provision job", @@ -226,3 +225,17 @@ func (w *ProvisionWorker) reclaimLoop(ctx context.Context, sem chan struct{}) { } } } + +func (w *ProvisionWorker) ackWithLog(ctx context.Context, messageID string, reason string) { + if err := w.taskQueue.Ack(ctx, provisionQueue, provisionGroup, messageID); err != nil { + w.logger.Warn("failed to ack provision job", + "msg_id", messageID, "reason", reason, "error", err) + } +} + +func (w *ProvisionWorker) nackWithLog(ctx context.Context, messageID string, reason string) { + if err := w.taskQueue.Nack(ctx, provisionQueue, provisionGroup, messageID); err != nil { + w.logger.Warn("failed to nack provision job", + "msg_id", messageID, "reason", reason, "error", err) + } +} From 9cca272ab4b36760171403a539a1b7b0ff2d6849 Mon Sep 17 00:00:00 2001 From: jackthepunished Date: Tue, 14 Apr 2026 16:11:20 +0300 Subject: [PATCH 12/12] fix(ci): align durable queue wiring and stabilize benchmark PR runs --- .github/workflows/benchmarks.yml | 17 ++++++++++++++++- internal/api/setup/dependencies.go | 6 +++--- internal/workers/pipeline_worker_test.go | 7 +++++-- 3 files changed, 24 insertions(+), 6 deletions(-) diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml index 6028dba78..54a7a4f61 100644 --- a/.github/workflows/benchmarks.yml +++ b/.github/workflows/benchmarks.yml @@ -27,7 +27,22 @@ jobs: | grep -E '^(goos:|goarch:|pkg:|cpu:|Benchmark|PASS$|ok\s)' \ | tee bench.txt - - name: Store Benchmark Result + - name: Store Benchmark Result (PR) + if: github.event_name == 'pull_request' + uses: benchmark-action/github-action-benchmark@v1 + with: + name: Go Benchmarks + tool: 'go' + output-file-path: bench.txt + # On PRs, publishing to gh-pages is not allowed in all permission models. 
+          auto-push: false
+          # Alert when a run is more than 2x slower than the baseline (a >50% performance drop); PRs neither comment nor fail.
+          alert-threshold: '200%'
+          comment-on-alert: false
+          fail-on-alert: false
+
+      - name: Store Benchmark Result (main)
+        if: github.event_name == 'push' && github.ref == 'refs/heads/main'
         uses: benchmark-action/github-action-benchmark@v1
         with:
           name: Go Benchmarks
diff --git a/internal/api/setup/dependencies.go b/internal/api/setup/dependencies.go
index 06a74122e..db3118bff 100644
--- a/internal/api/setup/dependencies.go
+++ b/internal/api/setup/dependencies.go
@@ -249,7 +249,7 @@ func InitServices(c ServiceConfig) (*Services, *Workers, error) {
 
 	logSvc := services.NewCloudLogsService(c.Repos.Log, rbacSvc, c.Logger)
 
-	instSvcConcrete := services.NewInstanceService(services.InstanceServiceParams{Repo: c.Repos.Instance, VpcRepo: c.Repos.Vpc, SubnetRepo: c.Repos.Subnet, VolumeRepo: c.Repos.Volume, InstanceTypeRepo: c.Repos.InstanceType, RBAC: rbacSvc, Compute: c.Compute, Network: c.Network, EventSvc: eventSvc, AuditSvc: auditSvc, DNSSvc: dnsSvc, TaskQueue: c.Repos.TaskQueue, DockerNetwork: c.Config.DockerDefaultNetwork, Logger: c.Logger, TenantSvc: tenantSvc, SSHKeySvc: sshKeySvc, LogSvc: logSvc})
+	instSvcConcrete := services.NewInstanceService(services.InstanceServiceParams{Repo: c.Repos.Instance, VpcRepo: c.Repos.Vpc, SubnetRepo: c.Repos.Subnet, VolumeRepo: c.Repos.Volume, InstanceTypeRepo: c.Repos.InstanceType, RBAC: rbacSvc, Compute: c.Compute, Network: c.Network, EventSvc: eventSvc, AuditSvc: auditSvc, DNSSvc: dnsSvc, TaskQueue: c.Repos.DurableQueue, DockerNetwork: c.Config.DockerDefaultNetwork, Logger: c.Logger, TenantSvc: tenantSvc, SSHKeySvc: sshKeySvc, LogSvc: logSvc})
 
 	sgSvc := services.NewSecurityGroupService(c.Repos.SecurityGroup, rbacSvc, c.Repos.Vpc, c.Network, auditSvc, c.Logger)
 	lbSvc := services.NewLBService(c.Repos.LB, rbacSvc, c.Repos.Vpc, c.Repos.Instance, auditSvc, c.Logger)
@@ -301,7 +301,7 @@ func InitServices(c ServiceConfig) (*Services, *Workers, error) {
 	fnSvc := services.NewFunctionService(c.Repos.Function, rbacSvc, c.Compute, fileStore, auditSvc, c.Logger)
 	cacheSvc := services.NewCacheService(c.Repos.Cache, rbacSvc, c.Compute, c.Repos.Vpc, eventSvc, auditSvc, c.Logger)
 	queueSvc := services.NewQueueService(c.Repos.Queue, rbacSvc, eventSvc, auditSvc, c.Logger)
-	pipelineSvc := services.NewPipelineService(c.Repos.Pipeline, c.Repos.TaskQueue, eventSvc, auditSvc, c.Logger)
+	pipelineSvc := services.NewPipelineService(c.Repos.Pipeline, c.Repos.DurableQueue, eventSvc, auditSvc, c.Logger)
 	notifySvc := services.NewNotifyService(services.NotifyServiceParams{Repo: c.Repos.Notify, RBACSvc: rbacSvc, QueueSvc: queueSvc, EventSvc: eventSvc, AuditSvc: auditSvc, Logger: c.Logger})
 
 	// 5.
DevOps & Automation Services @@ -438,7 +438,7 @@ func initStorageServices(c ServiceConfig, rbacSvc ports.RBACService, audit ports func initClusterServices(c ServiceConfig, rbacSvc ports.RBACService, vpcSvc ports.VpcService, instSvc ports.InstanceService, secretSvc ports.SecretService, storageSvc ports.StorageService, lbSvc ports.LBService, sgSvc ports.SecurityGroupService) (ports.ClusterService, ports.ClusterProvisioner, error) { clusterProvisioner := k8s.NewKubeadmProvisioner(instSvc, c.Repos.Cluster, secretSvc, sgSvc, storageSvc, lbSvc, c.Logger) clusterSvc, err := services.NewClusterService(services.ClusterServiceParams{ - Repo: c.Repos.Cluster, RBAC: rbacSvc, Provisioner: clusterProvisioner, VpcSvc: vpcSvc, InstanceSvc: instSvc, SecretSvc: secretSvc, TaskQueue: c.Repos.TaskQueue, Logger: c.Logger, + Repo: c.Repos.Cluster, RBAC: rbacSvc, Provisioner: clusterProvisioner, VpcSvc: vpcSvc, InstanceSvc: instSvc, SecretSvc: secretSvc, TaskQueue: c.Repos.DurableQueue, Logger: c.Logger, }) if err != nil { return nil, nil, fmt.Errorf("failed to init cluster service: %w", err) diff --git a/internal/workers/pipeline_worker_test.go b/internal/workers/pipeline_worker_test.go index 6a5cf50aa..678183f77 100644 --- a/internal/workers/pipeline_worker_test.go +++ b/internal/workers/pipeline_worker_test.go @@ -163,12 +163,13 @@ func TestPipelineWorker_processJob(t *testing.T) { compute := new(mockComputeBackendExtended) taskQueue := new(MockTaskQueue) logger := slog.New(slog.NewTextHandler(io.Discard, nil)) - worker := NewPipelineWorker(repo, taskQueue, compute, logger) + worker := NewPipelineWorker(repo, taskQueue, nil, compute, logger) buildID := uuid.New() pipelineID := uuid.New() userID := uuid.New() job := domain.BuildJob{BuildID: buildID, PipelineID: pipelineID, UserID: userID} + msg := &ports.DurableMessage{ID: "1-0", Queue: pipelineQueueName} t.Run("Success", func(t *testing.T) { build := &domain.Build{ID: buildID, PipelineID: pipelineID, UserID: userID} @@ -205,9 +206,11 @@ func TestPipelineWorker_processJob(t *testing.T) { repo.On("UpdateBuild", mock.Anything, mock.MatchedBy(func(b *domain.Build) bool { return b.Status == domain.BuildStatusSucceeded })).Return(nil).Once() + taskQueue.On("Ack", mock.Anything, pipelineQueueName, pipelineGroup, msg.ID).Return(nil).Once() - worker.processJob(job) + worker.processJob(context.Background(), msg, job) repo.AssertExpectations(t) compute.AssertExpectations(t) + taskQueue.AssertExpectations(t) }) }