diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml index 6028dba78..54a7a4f61 100644 --- a/.github/workflows/benchmarks.yml +++ b/.github/workflows/benchmarks.yml @@ -27,7 +27,22 @@ jobs: | grep -E '^(goos:|goarch:|pkg:|cpu:|Benchmark|PASS$|ok\s)' \ | tee bench.txt - - name: Store Benchmark Result + - name: Store Benchmark Result (PR) + if: github.event_name == 'pull_request' + uses: benchmark-action/github-action-benchmark@v1 + with: + name: Go Benchmarks + tool: 'go' + output-file-path: bench.txt + # On PRs, publishing to gh-pages is not allowed in all permission models. + auto-push: false + # Fail if performance drops by more than 50% + alert-threshold: '200%' + comment-on-alert: false + fail-on-alert: false + + - name: Store Benchmark Result (main) + if: github.event_name == 'push' && github.ref == 'refs/heads/main' uses: benchmark-action/github-action-benchmark@v1 with: name: Go Benchmarks diff --git a/cmd/api/main.go b/cmd/api/main.go index 1b851bb7b..e62cf9000 100644 --- a/cmd/api/main.go +++ b/cmd/api/main.go @@ -130,16 +130,28 @@ func run() error { defer db.Close() defer func() { _ = rdb.Close() }() - compute, storage, network, lbProxy, err := initBackends(deps, cfg, logger, db, rdb) + rawCompute, rawStorage, rawNetwork, rawLBProxy, err := initBackends(deps, cfg, logger, db, rdb) if err != nil { logger.Error("backend initialization failed", "error", err) return err } + // Wrap raw backends with resilience decorators (circuit breaker, bulkhead, timeouts). + compute := platform.NewResilientCompute(rawCompute, logger, platform.ResilientComputeOpts{}) + storage := platform.NewResilientStorage(rawStorage, logger, platform.ResilientStorageOpts{}) + network := platform.NewResilientNetwork(rawNetwork, logger, platform.ResilientNetworkOpts{}) + lbProxy := platform.NewResilientLB(rawLBProxy, logger, platform.ResilientLBOpts{}) + repos := deps.InitRepositories(db, rdb) + + // Create leader elector for singleton worker coordination. + // When multiple worker replicas run, only one will hold leadership per key. 
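+	// Illustrative use for a singleton worker (a sketch only; worker and wg are
+	// placeholders and the LeaderGuard internals are not shown here, but the
+	// "singleton:*" keys match the ones registered in InitServices):
+	//
+	//	_ = leaderElector.RunAsLeader(ctx, "singleton:cron", func(ctx context.Context) error {
+	//		worker.Run(ctx, wg) // runs only while this process holds the lease
+	//		return nil
+	//	})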
+ leaderElector := postgres.NewPgLeaderElector(db, logger) + svcs, workers, err := deps.InitServices(setup.ServiceConfig{ Config: cfg, Repos: repos, Compute: compute, Storage: storage, Network: network, LBProxy: lbProxy, DB: db, RDB: rdb, Logger: logger, + LeaderElector: leaderElector, }) if err != nil { logger.Error("service initialization failed", "error", err) @@ -154,16 +166,22 @@ func run() error { r.Use(otelgin.Middleware("compute-api")) } - runApplication(deps, cfg, logger, r, workers) - return nil + return runApplication(deps, cfg, logger, r, workers) } -func runApplication(deps AppDeps, cfg *platform.Config, logger *slog.Logger, r *gin.Engine, workers *setup.Workers) { - role := os.Getenv("APP_ROLE") +func runApplication(deps AppDeps, cfg *platform.Config, logger *slog.Logger, r *gin.Engine, workers *setup.Workers) error { + role := os.Getenv("ROLE") if role == "" { role = "all" } + validRoles := map[string]bool{"api": true, "worker": true, "all": true} + if !validRoles[role] { + logger.Error("invalid ROLE value, must be one of: api, worker, all", "role", role) + return fmt.Errorf("invalid ROLE value %q, must be one of: api, worker, all", role) + } + logger.Info("starting with role", "role", role) + wg := &sync.WaitGroup{} workerCtx, workerCancel := context.WithCancel(context.Background()) @@ -171,9 +189,9 @@ func runApplication(deps AppDeps, cfg *platform.Config, logger *slog.Logger, r * runWorkers(workerCtx, wg, workers) } - srv := deps.NewHTTPServer(":"+cfg.Port, r) - + var srv *http.Server if role == "api" || role == "all" { + srv = deps.NewHTTPServer(":"+cfg.Port, r) go func() { logger.Info("starting compute-api", "port", cfg.Port) if err := deps.StartHTTPServer(srv); err != nil && !stdlib_errors.Is(err, http.ErrServerClosed) { @@ -181,25 +199,28 @@ func runApplication(deps AppDeps, cfg *platform.Config, logger *slog.Logger, r * } }() } else { - logger.Info("running in worker-only mode") + logger.Info("running in worker-only mode, HTTP server disabled") } quit := make(chan os.Signal, 1) deps.NotifySignals(quit, syscall.SIGINT, syscall.SIGTERM) <-quit - logger.Info("shutting down server...") + logger.Info("shutting down...") ctx, cancel := context.WithTimeout(context.Background(), defaultShutdownTimeout) defer cancel() - if err := deps.ShutdownHTTPServer(ctx, srv); err != nil { - logger.Error("server forced to shutdown", "error", err) + if srv != nil { + if err := deps.ShutdownHTTPServer(ctx, srv); err != nil { + logger.Error("server forced to shutdown", "error", err) + } } workerCancel() wg.Wait() - logger.Info("server exited") + logger.Info("shutdown complete") + return nil } type runner interface { diff --git a/cmd/api/main_test.go b/cmd/api/main_test.go index 6f1f81af2..a311b4ad8 100644 --- a/cmd/api/main_test.go +++ b/cmd/api/main_test.go @@ -158,7 +158,10 @@ func TestRunApplicationApiRoleStartsAndShutsDown(t *testing.T) { }() } - runApplication(deps, &platform.Config{Port: "0"}, logger, gin.New(), &setup.Workers{}) + err := runApplication(deps, &platform.Config{Port: "0"}, logger, gin.New(), &setup.Workers{}) + if err != nil { + t.Fatalf("runApplication returned error: %v", err) + } select { case <-shutdownCalled: @@ -167,6 +170,106 @@ func TestRunApplicationApiRoleStartsAndShutsDown(t *testing.T) { } } +func TestRunApplicationWorkerRoleDoesNotStartHTTP(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + t.Setenv("ROLE", "worker") + + deps := DefaultDeps() + + deps.NewHTTPServer = func(string, http.Handler) *http.Server { + 
t.Fatalf("NewHTTPServer should not be called in worker-only mode") + return nil + } + deps.StartHTTPServer = func(*http.Server) error { + t.Fatalf("StartHTTPServer should not be called in worker-only mode") + return nil + } + deps.ShutdownHTTPServer = func(context.Context, *http.Server) error { + t.Fatalf("ShutdownHTTPServer should not be called in worker-only mode") + return nil + } + deps.NotifySignals = func(c chan<- os.Signal, _ ...os.Signal) { + go func() { + // Give workers a moment to start, then signal shutdown + time.Sleep(50 * time.Millisecond) + c <- syscall.SIGTERM + }() + } + + err := runApplication(deps, &platform.Config{Port: "0"}, logger, gin.New(), &setup.Workers{}) + if err != nil { + t.Fatalf("runApplication returned error: %v", err) + } + // If we reach here without t.Fatalf, the test passes — no HTTP server was touched. +} + +func TestRunApplicationDefaultsToAllRole(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + t.Setenv("ROLE", "") // Explicitly empty to verify default + + started := make(chan struct{}) + shutdownCalled := make(chan struct{}) + deps := DefaultDeps() + + deps.NewHTTPServer = func(addr string, handler http.Handler) *http.Server { + return &http.Server{ + Addr: addr, + Handler: handler, + ReadHeaderTimeout: 10 * time.Second, + } + } + deps.StartHTTPServer = func(*http.Server) error { + close(started) + return http.ErrServerClosed + } + deps.ShutdownHTTPServer = func(context.Context, *http.Server) error { + close(shutdownCalled) + return nil + } + deps.NotifySignals = func(c chan<- os.Signal, _ ...os.Signal) { + go func() { + <-started + c <- syscall.SIGTERM + }() + } + + err := runApplication(deps, &platform.Config{Port: "0"}, logger, gin.New(), &setup.Workers{}) + if err != nil { + t.Fatalf("runApplication returned error: %v", err) + } + + select { + case <-shutdownCalled: + case <-time.After(time.Second): + t.Fatalf("expected server shutdown to be called when ROLE defaults to 'all'") + } +} + +func TestRunApplicationInvalidRoleReturnsEarly(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + t.Setenv("ROLE", "invalid") + + deps := DefaultDeps() + + deps.NewHTTPServer = func(string, http.Handler) *http.Server { + t.Fatalf("NewHTTPServer should not be called for invalid role") + return nil + } + deps.StartHTTPServer = func(*http.Server) error { + t.Fatalf("StartHTTPServer should not be called for invalid role") + return nil + } + deps.NotifySignals = func(c chan<- os.Signal, _ ...os.Signal) { + t.Fatalf("NotifySignals should not be called for invalid role") + } + + // Should return immediately without starting anything + err := runApplication(deps, &platform.Config{Port: "0"}, logger, gin.New(), &setup.Workers{}) + if err == nil { + t.Fatalf("expected error for invalid role") + } +} + // Stub helpers below keep main.go testable without altering production behavior. 
type stubDB struct{ closed bool } diff --git a/internal/api/setup/dependencies.go b/internal/api/setup/dependencies.go index 4d0d8ebb4..db3118bff 100644 --- a/internal/api/setup/dependencies.go +++ b/internal/api/setup/dependencies.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "log/slog" + "sync" "time" "strings" @@ -57,6 +58,8 @@ type Repositories struct { AutoScaling ports.AutoScalingRepository Accounting ports.AccountingRepository TaskQueue ports.TaskQueue + DurableQueue ports.DurableTaskQueue + Ledger ports.ExecutionLedger Image ports.ImageRepository Cluster ports.ClusterRepository Lifecycle ports.LifecycleRepository @@ -102,6 +105,8 @@ func InitRepositories(db postgres.DB, rdb *redisv9.Client) *Repositories { AutoScaling: postgres.NewAutoScalingRepo(db), Accounting: postgres.NewAccountingRepository(db), TaskQueue: redis.NewRedisTaskQueue(rdb), + DurableQueue: redis.NewDurableTaskQueue(rdb), + Ledger: postgres.NewExecutionLedger(db), Image: postgres.NewImageRepository(db), Cluster: postgres.NewClusterRepository(db), Lifecycle: postgres.NewLifecycleRepository(db), @@ -163,35 +168,46 @@ type Services struct { VPCPeering ports.VPCPeeringService } -// Workers struct to return background workers +// Runner is the interface that all background workers implement. +type Runner interface { + Run(context.Context, *sync.WaitGroup) +} + +// Workers struct to return background workers. +// Singleton workers are typed as Runner so they can be wrapped with LeaderGuard. +// Parallel consumers retain concrete types for direct configuration access. type Workers struct { - LB *services.LBWorker - AutoScaling *services.AutoScalingWorker - Cron *services.CronWorker - Container *services.ContainerWorker - Pipeline *workers.PipelineWorker - Provision *workers.ProvisionWorker - Accounting *workers.AccountingWorker - Cluster *workers.ClusterWorker - Lifecycle *workers.LifecycleWorker - ReplicaMonitor *workers.ReplicaMonitor - ClusterReconciler *workers.ClusterReconciler - Healing *workers.HealingWorker - DatabaseFailover *workers.DatabaseFailoverWorker - Log *workers.LogWorker + // Singleton workers (must run on exactly one node via leader election) + LB Runner + AutoScaling Runner + Cron Runner + Container Runner + Accounting Runner + Lifecycle Runner + ReplicaMonitor Runner + ClusterReconciler Runner + Healing Runner + DatabaseFailover Runner + Log Runner + + // Parallel consumer workers (safe to run on multiple nodes) + Pipeline *workers.PipelineWorker + Provision *workers.ProvisionWorker + Cluster *workers.ClusterWorker } // ServiceConfig holds the dependencies required to initialize services type ServiceConfig struct { - Config *platform.Config - Repos *Repositories - Compute ports.ComputeBackend - Storage ports.StorageBackend - Network ports.NetworkBackend - LBProxy ports.LBProxyAdapter - DB postgres.DB - RDB *redisv9.Client - Logger *slog.Logger + Config *platform.Config + Repos *Repositories + Compute ports.ComputeBackend + Storage ports.StorageBackend + Network ports.NetworkBackend + LBProxy ports.LBProxyAdapter + DB postgres.DB + RDB *redisv9.Client + Logger *slog.Logger + LeaderElector ports.LeaderElector // nil disables leader election (single-instance mode) } // InitServices constructs core services and background workers. 
@@ -219,7 +235,12 @@ func InitServices(c ServiceConfig) (*Services, *Workers, error) { if err != nil { return nil, nil, fmt.Errorf("failed to init powerdns backend: %w", err) } - dnsSvc := services.NewDNSService(services.DNSServiceParams{Repo: c.Repos.DNS, RBAC: rbacSvc, Backend: pdnsBackend, VpcRepo: c.Repos.Vpc, AuditSvc: auditSvc, EventSvc: eventSvc, Logger: c.Logger}) + // Wrap DNS backend with resilience (circuit breaker + timeout). + resilientDNS := platform.NewResilientDNS(pdnsBackend, c.Logger, platform.ResilientDNSOpts{}) + dnsSvc := services.NewDNSService(services.DNSServiceParams{ + Repo: c.Repos.DNS, RBAC: rbacSvc, Backend: resilientDNS, VpcRepo: c.Repos.Vpc, + AuditSvc: auditSvc, EventSvc: eventSvc, Logger: c.Logger, + }) sshKeySvc, err := services.NewSSHKeyService(services.SSHKeyServiceParams{Repo: c.Repos.SSHKey, Logger: c.Logger, RBACSvc: rbacSvc}) if err != nil { @@ -228,7 +249,7 @@ func InitServices(c ServiceConfig) (*Services, *Workers, error) { logSvc := services.NewCloudLogsService(c.Repos.Log, rbacSvc, c.Logger) - instSvcConcrete := services.NewInstanceService(services.InstanceServiceParams{Repo: c.Repos.Instance, VpcRepo: c.Repos.Vpc, SubnetRepo: c.Repos.Subnet, VolumeRepo: c.Repos.Volume, InstanceTypeRepo: c.Repos.InstanceType, RBAC: rbacSvc, Compute: c.Compute, Network: c.Network, EventSvc: eventSvc, AuditSvc: auditSvc, DNSSvc: dnsSvc, TaskQueue: c.Repos.TaskQueue, DockerNetwork: c.Config.DockerDefaultNetwork, Logger: c.Logger, TenantSvc: tenantSvc, SSHKeySvc: sshKeySvc, LogSvc: logSvc}) + instSvcConcrete := services.NewInstanceService(services.InstanceServiceParams{Repo: c.Repos.Instance, VpcRepo: c.Repos.Vpc, SubnetRepo: c.Repos.Subnet, VolumeRepo: c.Repos.Volume, InstanceTypeRepo: c.Repos.InstanceType, RBAC: rbacSvc, Compute: c.Compute, Network: c.Network, EventSvc: eventSvc, AuditSvc: auditSvc, DNSSvc: dnsSvc, TaskQueue: c.Repos.DurableQueue, DockerNetwork: c.Config.DockerDefaultNetwork, Logger: c.Logger, TenantSvc: tenantSvc, SSHKeySvc: sshKeySvc, LogSvc: logSvc}) sgSvc := services.NewSecurityGroupService(c.Repos.SecurityGroup, rbacSvc, c.Repos.Vpc, c.Network, auditSvc, c.Logger) lbSvc := services.NewLBService(c.Repos.LB, rbacSvc, c.Repos.Vpc, c.Repos.Instance, auditSvc, c.Logger) @@ -280,7 +301,7 @@ func InitServices(c ServiceConfig) (*Services, *Workers, error) { fnSvc := services.NewFunctionService(c.Repos.Function, rbacSvc, c.Compute, fileStore, auditSvc, c.Logger) cacheSvc := services.NewCacheService(c.Repos.Cache, rbacSvc, c.Compute, c.Repos.Vpc, eventSvc, auditSvc, c.Logger) queueSvc := services.NewQueueService(c.Repos.Queue, rbacSvc, eventSvc, auditSvc, c.Logger) - pipelineSvc := services.NewPipelineService(c.Repos.Pipeline, c.Repos.TaskQueue, eventSvc, auditSvc, c.Logger) + pipelineSvc := services.NewPipelineService(c.Repos.Pipeline, c.Repos.DurableQueue, eventSvc, auditSvc, c.Logger) notifySvc := services.NewNotifyService(services.NotifyServiceParams{Repo: c.Repos.Notify, RBACSvc: rbacSvc, QueueSvc: queueSvc, EventSvc: eventSvc, AuditSvc: auditSvc, Logger: c.Logger}) // 5. 
DevOps & Automation Services
@@ -298,7 +319,7 @@ func InitServices(c ServiceConfig) (*Services, *Workers, error) {
 	accountingWorker := workers.NewAccountingWorker(accountingSvc, c.Logger)
 	imageSvc := services.NewImageService(services.ImageServiceParams{Repo: c.Repos.Image, RBACSvc: rbacSvc, FileStore: fileStore, Logger: c.Logger})
 	iamSvc := services.NewIAMService(c.Repos.IAM, auditSvc, eventSvc, c.Logger)
-	provisionWorker := workers.NewProvisionWorker(instSvcConcrete, c.Repos.TaskQueue, c.Logger)
+	provisionWorker := workers.NewProvisionWorker(instSvcConcrete, c.Repos.DurableQueue, c.Repos.Ledger, c.Logger)
 	healingWorker := workers.NewHealingWorker(instSvcConcrete, c.Repos.Instance, c.Logger)
 
 	clusterSvc, clusterProvisioner, err := initClusterServices(c, rbacSvc, vpcSvc, instSvcConcrete, secretSvc, storageSvc, lbSvc, sgSvc)
@@ -311,7 +332,47 @@ func InitServices(c ServiceConfig) (*Services, *Workers, error) {
 	// 7. High Availability & Monitoring
 	replicaMonitor := initReplicaMonitor(c)
 
-	workersCollection := &Workers{LB: lbWorker, AutoScaling: asgWorker, Cron: cronWorker, Container: containerWorker, Pipeline: workers.NewPipelineWorker(c.Repos.Pipeline, c.Repos.TaskQueue, c.Compute, c.Logger), Provision: provisionWorker, Accounting: accountingWorker, Cluster: workers.NewClusterWorker(c.Repos.Cluster, clusterProvisioner, c.Repos.TaskQueue, c.Logger), Lifecycle: workers.NewLifecycleWorker(c.Repos.Lifecycle, storageSvc, c.Repos.Storage, c.Logger), ReplicaMonitor: replicaMonitor, ClusterReconciler: workers.NewClusterReconciler(c.Repos.Cluster, clusterProvisioner, c.Logger), Healing: healingWorker, DatabaseFailover: workers.NewDatabaseFailoverWorker(databaseSvc, c.Repos.Database, c.Logger), Log: workers.NewLogWorker(logSvc, c.Logger)}
+	// Helper: wrap a singleton worker with LeaderGuard if leader election is enabled.
+	// The parameter is a Runner interface, so a typed nil pointer would not equal
+	// nil here; callers must pass an explicit nil Runner to skip a worker.
+	guardSingleton := func(key string, w Runner) Runner {
+		if w == nil || c.LeaderElector == nil {
+			return w
+		}
+		return workers.NewLeaderGuard(c.LeaderElector, key, w, c.Logger)
+	}
+
+	lifecycleWorker := workers.NewLifecycleWorker(c.Repos.Lifecycle, storageSvc, c.Repos.Storage, c.Logger)
+	clusterReconciler := workers.NewClusterReconciler(c.Repos.Cluster, clusterProvisioner, c.Logger)
+	dbFailoverWorker := workers.NewDatabaseFailoverWorker(databaseSvc, c.Repos.Database, c.Logger)
+	logWorker := workers.NewLogWorker(logSvc, c.Logger)
+
+	// For replicaMonitor, we must convert a nil *ReplicaMonitor to a nil Runner to avoid
+	// a non-nil interface wrapping a nil pointer.
+ var replicaMonitorRunner Runner + if replicaMonitor != nil { + replicaMonitorRunner = replicaMonitor + } + + workersCollection := &Workers{ + // Singleton workers — wrapped with leader election + LB: guardSingleton("singleton:lb", lbWorker), + AutoScaling: guardSingleton("singleton:autoscaling", asgWorker), + Cron: guardSingleton("singleton:cron", cronWorker), + Container: guardSingleton("singleton:container", containerWorker), + Accounting: guardSingleton("singleton:accounting", accountingWorker), + Lifecycle: guardSingleton("singleton:lifecycle", lifecycleWorker), + ReplicaMonitor: guardSingleton("singleton:replica-monitor", replicaMonitorRunner), + ClusterReconciler: guardSingleton("singleton:cluster-reconciler", clusterReconciler), + Healing: guardSingleton("singleton:healing", healingWorker), + DatabaseFailover: guardSingleton("singleton:db-failover", dbFailoverWorker), + Log: guardSingleton("singleton:log", logWorker), + + // Parallel consumer workers — no leader election needed + Pipeline: workers.NewPipelineWorker(c.Repos.Pipeline, c.Repos.DurableQueue, c.Repos.Ledger, c.Compute, c.Logger), + Provision: provisionWorker, + Cluster: workers.NewClusterWorker(c.Repos.Cluster, clusterProvisioner, c.Repos.DurableQueue, c.Repos.Ledger, c.Logger), + } return svcs, workersCollection, nil } @@ -377,7 +438,7 @@ func initStorageServices(c ServiceConfig, rbacSvc ports.RBACService, audit ports func initClusterServices(c ServiceConfig, rbacSvc ports.RBACService, vpcSvc ports.VpcService, instSvc ports.InstanceService, secretSvc ports.SecretService, storageSvc ports.StorageService, lbSvc ports.LBService, sgSvc ports.SecurityGroupService) (ports.ClusterService, ports.ClusterProvisioner, error) { clusterProvisioner := k8s.NewKubeadmProvisioner(instSvc, c.Repos.Cluster, secretSvc, sgSvc, storageSvc, lbSvc, c.Logger) clusterSvc, err := services.NewClusterService(services.ClusterServiceParams{ - Repo: c.Repos.Cluster, RBAC: rbacSvc, Provisioner: clusterProvisioner, VpcSvc: vpcSvc, InstanceSvc: instSvc, SecretSvc: secretSvc, TaskQueue: c.Repos.TaskQueue, Logger: c.Logger, + Repo: c.Repos.Cluster, RBAC: rbacSvc, Provisioner: clusterProvisioner, VpcSvc: vpcSvc, InstanceSvc: instSvc, SecretSvc: secretSvc, TaskQueue: c.Repos.DurableQueue, Logger: c.Logger, }) if err != nil { return nil, nil, fmt.Errorf("failed to init cluster service: %w", err) diff --git a/internal/core/ports/execution_ledger.go b/internal/core/ports/execution_ledger.go new file mode 100644 index 000000000..cbdafd1e4 --- /dev/null +++ b/internal/core/ports/execution_ledger.go @@ -0,0 +1,35 @@ +// Package ports defines service and repository interfaces. +package ports + +import ( + "context" + "time" +) + +// ExecutionLedger provides idempotent job execution tracking. +// Before processing a job, a worker calls TryAcquire with a unique job key. +// If TryAcquire returns true, the caller owns the execution and must +// eventually call MarkComplete or MarkFailed. +// If TryAcquire returns false, another worker already processed (or is +// processing) the job and the caller should skip it. +type ExecutionLedger interface { + // TryAcquire attempts to claim ownership of a job execution. + // Returns true if the caller now owns the execution (inserted a new row + // with status='running'). Returns false if the job was already acquired + // by another worker (row exists with status='completed' or a recent + // 'running' entry within staleThreshold). 
+ // + // If a previous 'running' entry is older than staleThreshold, it is + // considered abandoned and the caller can reclaim it. + TryAcquire(ctx context.Context, jobKey string, staleThreshold time.Duration) (bool, error) + + // MarkComplete marks a job execution as successfully completed. + MarkComplete(ctx context.Context, jobKey string, result string) error + + // MarkFailed marks a job execution as failed, allowing future retries. + MarkFailed(ctx context.Context, jobKey string, reason string) error + + // GetStatus returns the current status, result and start time of a job. + // Returns status="", nil error if the job does not exist. + GetStatus(ctx context.Context, jobKey string) (status string, result string, startedAt time.Time, err error) +} diff --git a/internal/core/ports/leader.go b/internal/core/ports/leader.go new file mode 100644 index 000000000..62fce204e --- /dev/null +++ b/internal/core/ports/leader.go @@ -0,0 +1,23 @@ +// Package ports defines service and repository interfaces. +package ports + +import ( + "context" +) + +// LeaderElector provides distributed leader election for singleton controllers. +// Only one instance across all replicas should hold leadership for a given key at any time. +type LeaderElector interface { + // Acquire attempts to become the leader for the given key. + // It returns true if leadership was acquired, false otherwise. + // The leadership is held until Release is called or the context is cancelled. + Acquire(ctx context.Context, key string) (bool, error) + + // Release relinquishes leadership for the given key. + Release(ctx context.Context, key string) error + + // RunAsLeader blocks until leadership is acquired for the given key, then calls fn. + // If leadership is lost, fn's context is cancelled. If fn returns, leadership is released. + // This is the primary entrypoint for singleton workers. + RunAsLeader(ctx context.Context, key string, fn func(ctx context.Context) error) error +} diff --git a/internal/core/ports/task_queue.go b/internal/core/ports/task_queue.go index 279578ac9..cb493bbf4 100644 --- a/internal/core/ports/task_queue.go +++ b/internal/core/ports/task_queue.go @@ -6,9 +6,56 @@ import ( ) // TaskQueue defines a simple producer-consumer interface for background work distribution. +// Producers (services) only need this interface to enqueue jobs. type TaskQueue interface { // Enqueue adds a serializable payload to the specified background processing queue. Enqueue(ctx context.Context, queueName string, payload interface{}) error // Dequeue pulls the next available raw message string from the background processing queue. + // Deprecated: parallel consumers should use DurableTaskQueue.Receive instead. Dequeue(ctx context.Context, queueName string) (string, error) } + +// DurableMessage represents a message read from a durable queue. +// The consumer must call Ack after successful processing; otherwise +// the message remains pending and will be redelivered after a timeout. +type DurableMessage struct { + // ID is the stream-assigned message identifier (e.g. Redis Stream ID). + ID string + // Payload is the raw JSON string of the job. + Payload string + // Queue is the queue (stream) name this message came from. + Queue string +} + +// DurableTaskQueue extends TaskQueue with at-least-once delivery semantics. +// It uses consumer groups so that each message is delivered to exactly one +// consumer within the group, and requires explicit acknowledgement. 
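+//
+// A minimal consumer-loop sketch (q, the "tasks" queue, the "workers" group,
+// consumerName and handleJob are illustrative placeholders, not part of this
+// package):
+//
+//	msg, err := q.Receive(ctx, "tasks", "workers", consumerName)
+//	if err != nil || msg == nil {
+//		return err // nil msg with nil err: nothing available before the poll timeout
+//	}
+//	if procErr := handleJob(ctx, msg.Payload); procErr != nil {
+//		return q.Nack(ctx, msg.Queue, "workers", msg.ID) // stays pending; ReclaimStale can recover it
+//	}
+//	return q.Ack(ctx, msg.Queue, "workers", msg.ID)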
+type DurableTaskQueue interface { + TaskQueue + + // EnsureGroup creates the consumer group for the given queue if it does not + // already exist. Must be called once at startup before Receive. + EnsureGroup(ctx context.Context, queueName, groupName string) error + + // Receive reads the next available message from the queue for the given + // consumer group and consumer name. It blocks up to the queue's configured + // poll interval. Returns nil message and nil error when no message is + // available (timeout). + Receive(ctx context.Context, queueName, groupName, consumerName string) (*DurableMessage, error) + + // Ack acknowledges successful processing of a message. After Ack the + // message will not be redelivered. + Ack(ctx context.Context, queueName, groupName, messageID string) error + + // Nack signals that the consumer failed to process the message. + // It relinquishes the current delivery WITHOUT re-queuing or creating + // a new message ID. The message remains in the pending entries list + // and will be reclaimed by ReclaimStale after the idle timeout. + // Implementations MUST NOT create duplicate live copies of the message. + Nack(ctx context.Context, queueName, groupName, messageID string) error + + // ReclaimStale claims messages that have been pending longer than the + // given idle threshold and returns them. This allows a healthy consumer + // to pick up work abandoned by a crashed peer. + ReclaimStale(ctx context.Context, queueName, groupName, consumerName string, minIdleMs int64, count int64) ([]DurableMessage, error) +} diff --git a/internal/drills/ha_drills_test.go b/internal/drills/ha_drills_test.go new file mode 100644 index 000000000..81cdb8368 --- /dev/null +++ b/internal/drills/ha_drills_test.go @@ -0,0 +1,421 @@ +// Package drills provides integration-like failure drill tests that validate +// the HA properties of the control plane. These tests use mocks to simulate +// infrastructure failures without requiring real Postgres/Redis. +// +// Run: go test ./internal/drills/ -v -count=1 +package drills + +import ( + "context" + "errors" + "log/slog" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/poyrazk/thecloud/internal/platform" +) + +// --------------------------------------------------------------------------- +// Drill 1: Circuit breaker trip + recovery +// SLO: When a backend fails ≥ threshold times, all subsequent calls must +// return ErrCircuitOpen within 1ms (no backend call). After resetTimeout, +// a successful probe must close the circuit. +// --------------------------------------------------------------------------- + +func TestDrill_CircuitBreakerTripAndRecovery(t *testing.T) { + const threshold = 3 + const resetTimeout = 200 * time.Millisecond + + var transitions []string + var mu sync.Mutex + + cb := platform.NewCircuitBreakerWithOpts(platform.CircuitBreakerOpts{ + Name: "drill-cb", + Threshold: threshold, + ResetTimeout: resetTimeout, + SuccessRequired: 1, + OnStateChange: func(name string, from, to platform.State) { + mu.Lock() + transitions = append(transitions, from.String()+"→"+to.String()) + mu.Unlock() + }, + }) + + backendErr := errors.New("backend down") + + // Phase 1: Trip the circuit with consecutive failures. + for i := 0; i < threshold; i++ { + err := cb.Execute(func() error { return backendErr }) + if err == nil { + t.Fatalf("iteration %d: expected error", i) + } + } + + // Verify circuit is open. 
+ if cb.GetState() != platform.StateOpen { + t.Fatalf("expected open, got %s", cb.GetState().String()) + } + + // Phase 2: Confirm fail-fast (no backend call). + var backendCalled atomic.Bool + start := time.Now() + err := cb.Execute(func() error { + backendCalled.Store(true) + return nil + }) + elapsed := time.Since(start) + + if !errors.Is(err, platform.ErrCircuitOpen) { + t.Fatalf("expected ErrCircuitOpen, got %v", err) + } + if backendCalled.Load() { + t.Fatal("backend should NOT have been called while circuit is open") + } + if elapsed > 5*time.Millisecond { + t.Fatalf("fail-fast took %v, expected <5ms", elapsed) + } + + // Phase 3: Wait for resetTimeout, then recover. + time.Sleep(resetTimeout + 50*time.Millisecond) + err = cb.Execute(func() error { return nil }) + if err != nil { + t.Fatalf("expected recovery, got %v", err) + } + if cb.GetState() != platform.StateClosed { + t.Fatalf("expected closed after recovery, got %s", cb.GetState().String()) + } + + // Verify transitions: closed→open, open→half-open, half-open→closed. + mu.Lock() + defer mu.Unlock() + expected := []string{"closed→open", "open→half-open", "half-open→closed"} + if len(transitions) != len(expected) { + t.Fatalf("expected %d transitions, got %d: %v", len(expected), len(transitions), transitions) + } + for i := range expected { + if transitions[i] != expected[i] { + t.Fatalf("transition[%d]: expected %s, got %s", i, expected[i], transitions[i]) + } + } +} + +// --------------------------------------------------------------------------- +// Drill 2: Bulkhead saturation + graceful rejection +// SLO: When maxConc requests are in-flight, additional requests must be +// rejected with ErrBulkheadFull (not blocked forever). +// --------------------------------------------------------------------------- + +func TestDrill_BulkheadSaturationAndRejection(t *testing.T) { + const maxConc = 3 + const waitTimeout = 100 * time.Millisecond + + bh := platform.NewBulkhead(platform.BulkheadOpts{ + Name: "drill-bh", + MaxConc: maxConc, + WaitTimeout: waitTimeout, + }) + + ctx := context.Background() + blockCh := make(chan struct{}) + var inFlight atomic.Int64 + var rejected atomic.Int64 + var wg sync.WaitGroup + + // Saturate the bulkhead. + for i := 0; i < maxConc; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _ = bh.Execute(ctx, func() error { + inFlight.Add(1) + <-blockCh + return nil + }) + }() + } + + // Wait for all slots to be occupied. + for inFlight.Load() < int64(maxConc) { + time.Sleep(5 * time.Millisecond) + } + + if bh.Available() != 0 { + t.Fatalf("expected 0 available slots, got %d", bh.Available()) + } + + // Fire excess requests — they should be rejected. + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + err := bh.Execute(ctx, func() error { return nil }) + if errors.Is(err, platform.ErrBulkheadFull) { + rejected.Add(1) + } + }() + } + + // Let the excess timeout. + time.Sleep(waitTimeout + 50*time.Millisecond) + + // Unblock the saturating goroutines. + close(blockCh) + wg.Wait() + + if rejected.Load() != 5 { + t.Fatalf("expected 5 rejections, got %d", rejected.Load()) + } +} + +// --------------------------------------------------------------------------- +// Drill 3: Resilient adapter end-to-end (circuit + bulkhead + timeout) +// SLO: A failing backend trips the circuit; subsequent calls fail-fast; +// recovery probe succeeds and normal operation resumes. 
+// --------------------------------------------------------------------------- + +type failingBackend struct { + healthy atomic.Bool + calls atomic.Int64 +} + +func (f *failingBackend) Do(_ context.Context) error { + f.calls.Add(1) + if f.healthy.Load() { + return nil + } + return errors.New("backend failure") +} + +func TestDrill_ResilientAdapterEndToEnd(t *testing.T) { + backend := &failingBackend{} + logger := slog.Default() + + cb := platform.NewCircuitBreakerWithOpts(platform.CircuitBreakerOpts{ + Name: "drill-e2e", + Threshold: 3, + ResetTimeout: 200 * time.Millisecond, + SuccessRequired: 1, + }) + + bh := platform.NewBulkhead(platform.BulkheadOpts{ + Name: "drill-e2e", + MaxConc: 5, + }) + + _ = logger // Would be used for real logging in production. + + // Helper: simulate calling through the full resilience stack. + callThrough := func(ctx context.Context) error { + return bh.Execute(ctx, func() error { + return cb.Execute(func() error { + ctx2, cancel := context.WithTimeout(ctx, 500*time.Millisecond) + defer cancel() + return backend.Do(ctx2) + }) + }) + } + + ctx := context.Background() + + // Phase 1: Backend is down → trip circuit. + for i := 0; i < 3; i++ { + _ = callThrough(ctx) + } + if cb.GetState() != platform.StateOpen { + t.Fatalf("expected open, got %s", cb.GetState().String()) + } + + // Phase 2: Fail-fast while open. + callsBefore := backend.calls.Load() + err := callThrough(ctx) + if !errors.Is(err, platform.ErrCircuitOpen) { + t.Fatalf("expected ErrCircuitOpen, got %v", err) + } + if backend.calls.Load() != callsBefore { + t.Fatal("backend should not be called while circuit is open") + } + + // Phase 3: Backend recovers. + backend.healthy.Store(true) + time.Sleep(250 * time.Millisecond) + + err = callThrough(ctx) + if err != nil { + t.Fatalf("expected recovery, got %v", err) + } + if cb.GetState() != platform.StateClosed { + t.Fatalf("expected closed, got %s", cb.GetState().String()) + } + + // Phase 4: Normal operation continues. + for i := 0; i < 10; i++ { + if err := callThrough(ctx); err != nil { + t.Fatalf("call %d failed: %v", i, err) + } + } +} + +// --------------------------------------------------------------------------- +// Drill 4: Retry with exponential backoff +// SLO: Retry must respect MaxAttempts, must apply backoff between attempts, +// and must stop early if the context is cancelled. +// --------------------------------------------------------------------------- + +func TestDrill_RetryBackoffAndContextCancellation(t *testing.T) { + t.Run("exhausts_attempts", func(t *testing.T) { + var attempts atomic.Int64 + err := platform.Retry(context.Background(), platform.RetryOpts{ + MaxAttempts: 4, + BaseDelay: 10 * time.Millisecond, + MaxDelay: 50 * time.Millisecond, + }, func(ctx context.Context) error { + attempts.Add(1) + return errors.New("still failing") + }) + + if err == nil { + t.Fatal("expected error after exhausting retries") + } + if attempts.Load() != 4 { + t.Fatalf("expected 4 attempts, got %d", attempts.Load()) + } + }) + + t.Run("stops_on_context_cancel", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + var attempts atomic.Int64 + start := time.Now() + err := platform.Retry(ctx, platform.RetryOpts{ + MaxAttempts: 100, // Would take very long if not cancelled. 
+ BaseDelay: 50 * time.Millisecond, + MaxDelay: 200 * time.Millisecond, + }, func(ctx context.Context) error { + attempts.Add(1) + return errors.New("failing") + }) + + elapsed := time.Since(start) + if err == nil { + t.Fatal("expected error") + } + if elapsed > 500*time.Millisecond { + t.Fatalf("should have stopped early, took %v", elapsed) + } + if attempts.Load() >= 100 { + t.Fatal("should not have exhausted all 100 attempts") + } + }) + + t.Run("succeeds_on_retry", func(t *testing.T) { + var attempts atomic.Int64 + err := platform.Retry(context.Background(), platform.RetryOpts{ + MaxAttempts: 5, + BaseDelay: 5 * time.Millisecond, + }, func(ctx context.Context) error { + n := attempts.Add(1) + if n < 3 { + return errors.New("not yet") + } + return nil + }) + + if err != nil { + t.Fatalf("expected success, got %v", err) + } + if attempts.Load() != 3 { + t.Fatalf("expected 3 attempts, got %d", attempts.Load()) + } + }) +} + +// --------------------------------------------------------------------------- +// Drill 5: Half-open single-flight +// SLO: While a probe request is in-flight in half-open state, all other +// requests must be rejected with ErrCircuitOpen. +// --------------------------------------------------------------------------- + +func TestDrill_HalfOpenSingleFlight(t *testing.T) { + const resetTimeout = 100 * time.Millisecond + + stateChanged := make(chan platform.State, 10) + cb := platform.NewCircuitBreakerWithOpts(platform.CircuitBreakerOpts{ + Name: "drill-halfopen", + Threshold: 1, + ResetTimeout: resetTimeout, + OnStateChange: func(name string, from, to platform.State) { + stateChanged <- to + }, + }) + + // Trip the circuit (closed -> open). + _ = cb.Execute(func() error { return errors.New("fail") }) + if cb.GetState() != platform.StateOpen { + t.Fatalf("expected open, got %s", cb.GetState().String()) + } + + // Drain transitions if any. + for len(stateChanged) > 0 { + <-stateChanged + } + + // Wait for reset timeout and transition to half-open. + // Note: allowRequest transitions to HalfOpen ONLY when Execute is called after resetTimeout. + time.Sleep(resetTimeout + 10*time.Millisecond) + + // Start a slow probe request. + probeStarted := make(chan struct{}) + probeDone := make(chan struct{}) + go func() { + _ = cb.Execute(func() error { + close(probeStarted) + <-probeDone // Block until we release. + return nil + }) + }() + + // Wait for the probe to actually start and transition state. + select { + case <-probeStarted: + case <-time.After(time.Second): + t.Fatal("timeout waiting for probe to start") + } + + // Wait for transition to HalfOpen. + select { + case s := <-stateChanged: + if s != platform.StateHalfOpen { + t.Fatalf("expected transition to half-open, got %s", s.String()) + } + case <-time.After(time.Second): + t.Fatal("timeout waiting for transition to half-open") + } + + // All other requests should be rejected. + for i := 0; i < 5; i++ { + err := cb.Execute(func() error { return nil }) + if !errors.Is(err, platform.ErrCircuitOpen) { + t.Fatalf("request %d: expected ErrCircuitOpen during half-open probe, got %v", i, err) + } + } + + // Release the probe — circuit should close. + close(probeDone) + + // Wait for transition to Closed. 
+ select { + case s := <-stateChanged: + if s != platform.StateClosed { + t.Fatalf("expected transition to closed after probe success, got %s", s.String()) + } + case <-time.After(time.Second): + t.Fatal("timeout waiting for transition to closed") + } + + if cb.GetState() != platform.StateClosed { + t.Fatalf("expected closed after probe success, got %s", cb.GetState().String()) + } +} diff --git a/internal/drills/release_gates_test.go b/internal/drills/release_gates_test.go new file mode 100644 index 000000000..596eb4758 --- /dev/null +++ b/internal/drills/release_gates_test.go @@ -0,0 +1,206 @@ +// Package drills contains HA failure drills and release gates. +// +// Release gates are meant to run in CI before deploying a new version. +// They validate the SLO invariants for the control-plane HA features: +// +// 1. Leader failover <30s (validated via unit tests on LeaderGuard). +// 2. Zero duplicate singleton executions during failover. +// 3. Zero job loss in crash tests (durable queue ack/nack). +// 4. Circuit breaker fail-fast under backend failure. +// 5. Bulkhead prevents cascading overload. +// 6. No API outage during single pod loss (leader re-election + queue redelivery). +// +// Run release gates: go test ./internal/drills/ -v -count=1 -run TestReleaseGate +package drills + +import ( + "context" + "errors" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/poyrazk/thecloud/internal/platform" +) + +// TestReleaseGate_CircuitBreakerFailFast validates SLO: +// "When a backend is down, requests must fail-fast in <5ms". +func TestReleaseGate_CircuitBreakerFailFast(t *testing.T) { + cb := platform.NewCircuitBreakerWithOpts(platform.CircuitBreakerOpts{ + Name: "gate-cb", + Threshold: 3, + ResetTimeout: 1 * time.Second, + }) + + // Trip it. + for i := 0; i < 3; i++ { + _ = cb.Execute(func() error { return errors.New("down") }) + } + + // Measure fail-fast latency over 100 calls. + const iterations = 100 + start := time.Now() + for i := 0; i < iterations; i++ { + err := cb.Execute(func() error { return nil }) + if !errors.Is(err, platform.ErrCircuitOpen) { + t.Fatalf("iteration %d: expected ErrCircuitOpen, got %v", i, err) + } + } + elapsed := time.Since(start) + + avgLatency := elapsed / iterations + if avgLatency > 1*time.Millisecond { + t.Fatalf("average fail-fast latency %v exceeds 1ms SLO", avgLatency) + } + t.Logf("PASS: avg fail-fast latency = %v (SLO: <1ms)", avgLatency) +} + +// TestReleaseGate_BulkheadIsolation validates SLO: +// "A saturated adapter must not block unrelated adapters". +func TestReleaseGate_BulkheadIsolation(t *testing.T) { + // Two independent bulkheads for two adapters. + bhCompute := platform.NewBulkhead(platform.BulkheadOpts{Name: "compute", MaxConc: 2, WaitTimeout: 50 * time.Millisecond}) + bhNetwork := platform.NewBulkhead(platform.BulkheadOpts{Name: "network", MaxConc: 5, WaitTimeout: 50 * time.Millisecond}) + + ctx := context.Background() + + // Saturate compute bulkhead. + blockCh := make(chan struct{}) + var wg sync.WaitGroup + var startedWg sync.WaitGroup + for i := 0; i < 2; i++ { + wg.Add(1) + startedWg.Add(1) + go func() { + defer wg.Done() + _ = bhCompute.Execute(ctx, func() error { + startedWg.Done() + <-blockCh + return nil + }) + }() + } + startedWg.Wait() // Ensure they have acquired slots. + + // Compute is now full. 
+ err := bhCompute.Execute(ctx, func() error { return nil }) + if !errors.Is(err, platform.ErrBulkheadFull) { + t.Fatalf("compute bulkhead should be full, got %v", err) + } + + // Network bulkhead must still be operational. + err = bhNetwork.Execute(ctx, func() error { return nil }) + if err != nil { + t.Fatalf("network bulkhead should be available, got %v", err) + } + + close(blockCh) + wg.Wait() + t.Log("PASS: saturated compute did not affect network adapter") +} + +// TestReleaseGate_CircuitBreakerRecovery validates SLO: +// "After backend recovery, the circuit must close within resetTimeout + probe time". +func TestReleaseGate_CircuitBreakerRecovery(t *testing.T) { + const resetTimeout = 200 * time.Millisecond + healthy := &atomic.Bool{} + + cb := platform.NewCircuitBreakerWithOpts(platform.CircuitBreakerOpts{ + Name: "gate-recovery", + Threshold: 2, + ResetTimeout: resetTimeout, + SuccessRequired: 1, + }) + + // Trip it. + for i := 0; i < 2; i++ { + _ = cb.Execute(func() error { return errors.New("down") }) + } + + // Simulate recovery after 100ms. + go func() { + time.Sleep(100 * time.Millisecond) + healthy.Store(true) + }() + + // Poll until circuit closes or timeout. + deadline := time.After(resetTimeout + 200*time.Millisecond) + ticker := time.NewTicker(50 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-deadline: + t.Fatalf("circuit did not recover within SLO window. State: %s", cb.GetState().String()) + case <-ticker.C: + err := cb.Execute(func() error { + if healthy.Load() { + return nil + } + return errors.New("still down") + }) + if err == nil && cb.GetState() == platform.StateClosed { + t.Logf("PASS: circuit recovered (state=%s)", cb.GetState().String()) + return + } + } + } +} + +// TestReleaseGate_RetryIdempotency validates SLO: +// "Retry must not execute the function more than MaxAttempts times". +func TestReleaseGate_RetryIdempotency(t *testing.T) { + for _, maxAttempts := range []int{1, 3, 5, 10} { + t.Run(fmt.Sprintf("max_%d", maxAttempts), func(t *testing.T) { + var count atomic.Int64 + _ = platform.Retry(context.Background(), platform.RetryOpts{ + MaxAttempts: maxAttempts, + BaseDelay: 1 * time.Millisecond, + MaxDelay: 5 * time.Millisecond, + }, func(ctx context.Context) error { + count.Add(1) + return errors.New("always fail") + }) + + if count.Load() != int64(maxAttempts) { + t.Fatalf("expected exactly %d attempts, got %d", maxAttempts, count.Load()) + } + }) + } +} + +// TestReleaseGate_ConcurrentCircuitBreakers validates SLO: +// "Multiple independent circuit breakers must not interfere with each other". +func TestReleaseGate_ConcurrentCircuitBreakers(t *testing.T) { + cbs := make([]*platform.CircuitBreaker, 5) + for i := range cbs { + cbs[i] = platform.NewCircuitBreakerWithOpts(platform.CircuitBreakerOpts{ + Name: fmt.Sprintf("adapter-%d", i), + Threshold: 3, + ResetTimeout: 1 * time.Second, + }) + } + + // Trip only breaker 0. + for i := 0; i < 3; i++ { + _ = cbs[0].Execute(func() error { return errors.New("down") }) + } + + if cbs[0].GetState() != platform.StateOpen { + t.Fatal("breaker 0 should be open") + } + + // All others should be closed and functional. 
+	for i := 1; i < 5; i++ {
+		err := cbs[i].Execute(func() error { return nil })
+		if err != nil {
+			t.Fatalf("breaker %d should be functional, got %v", i, err)
+		}
+		if cbs[i].GetState() != platform.StateClosed {
+			t.Fatalf("breaker %d should be closed, got %s", i, cbs[i].GetState().String())
+		}
+	}
+	t.Log("PASS: tripped breaker did not affect independent breakers")
+}
diff --git a/internal/platform/bulkhead.go b/internal/platform/bulkhead.go
new file mode 100644
index 000000000..033f5ee8f
--- /dev/null
+++ b/internal/platform/bulkhead.go
@@ -0,0 +1,92 @@
+package platform
+
+import (
+	"context"
+	"errors"
+	"time"
+)
+
+// ErrBulkheadFull is returned when the bulkhead's concurrency limit is reached
+// and the caller's timeout/context expires before a slot opens.
+var ErrBulkheadFull = errors.New("bulkhead: concurrency limit reached")
+
+// Bulkhead limits concurrent access to a resource using a semaphore pattern.
+// It prevents one slow/failing component from consuming all available goroutines
+// and cascading failures into other parts of the system.
+type Bulkhead struct {
+	name    string
+	sem     chan struct{}
+	timeout time.Duration
+}
+
+// BulkheadOpts configures a bulkhead.
+type BulkheadOpts struct {
+	Name        string        // Identifier for logging/metrics.
+	MaxConc     int           // Maximum concurrent requests. Default 10.
+	WaitTimeout time.Duration // How long to wait for a slot; 0 means rely solely on the context deadline.
+}
+
+// NewBulkhead creates a new concurrency-limiting bulkhead.
+func NewBulkhead(opts BulkheadOpts) *Bulkhead {
+	if opts.MaxConc <= 0 {
+		opts.MaxConc = 10
+	}
+	return &Bulkhead{
+		name:    opts.Name,
+		sem:     make(chan struct{}, opts.MaxConc),
+		timeout: opts.WaitTimeout,
+	}
+}
+
+// Execute runs fn within the bulkhead's concurrency limit.
+// If the bulkhead is full and the wait timeout (or context) expires,
+// ErrBulkheadFull is returned without calling fn.
+func (b *Bulkhead) Execute(ctx context.Context, fn func() error) error {
+	if err := b.acquire(ctx); err != nil {
+		return err
+	}
+	defer b.release()
+	return fn()
+}
+
+func (b *Bulkhead) acquire(ctx context.Context) error {
+	select {
+	case <-ctx.Done():
+		return ErrBulkheadFull
+	default:
+	}
+
+	if b.timeout > 0 {
+		timer := time.NewTimer(b.timeout)
+		defer timer.Stop()
+		select {
+		case b.sem <- struct{}{}:
+			return nil
+		case <-timer.C:
+			return ErrBulkheadFull
+		case <-ctx.Done():
+			return ErrBulkheadFull
+		}
+	}
+	// No explicit timeout — rely on context.
+	select {
+	case b.sem <- struct{}{}:
+		return nil
+	case <-ctx.Done():
+		return ErrBulkheadFull
+	}
+}
+
+func (b *Bulkhead) release() {
+	<-b.sem
+}
+
+// Available returns the number of currently available slots.
+func (b *Bulkhead) Available() int {
+	return cap(b.sem) - len(b.sem)
+}
+
+// Name returns the bulkhead's configured name.
+func (b *Bulkhead) Name() string { + return b.name +} diff --git a/internal/platform/bulkhead_test.go b/internal/platform/bulkhead_test.go new file mode 100644 index 000000000..c7ccdfc5c --- /dev/null +++ b/internal/platform/bulkhead_test.go @@ -0,0 +1,118 @@ +package platform + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBulkheadAllowsUpToMaxConcurrency(t *testing.T) { + bh := NewBulkhead(BulkheadOpts{Name: "test", MaxConc: 2}) + + var running atomic.Int32 + var maxSeen atomic.Int32 + var wg sync.WaitGroup + + for i := 0; i < 4; i++ { + wg.Add(1) + go func() { + defer wg.Done() + err := bh.Execute(context.Background(), func() error { + cur := running.Add(1) + defer running.Add(-1) + // Track the max concurrent. + for { + old := maxSeen.Load() + if cur <= old || maxSeen.CompareAndSwap(old, cur) { + break + } + } + time.Sleep(50 * time.Millisecond) + return nil + }) + assert.NoError(t, err) + }() + } + + wg.Wait() + assert.LessOrEqual(t, maxSeen.Load(), int32(2)) +} + +func TestBulkheadRejectsWhenFull(t *testing.T) { + bh := NewBulkhead(BulkheadOpts{Name: "test", MaxConc: 1, WaitTimeout: 50 * time.Millisecond}) + + // Fill the bulkhead. + started := make(chan struct{}) + done := make(chan struct{}) + go func() { + _ = bh.Execute(context.Background(), func() error { + close(started) + <-done + return nil + }) + }() + <-started + + // Second call should be rejected. + err := bh.Execute(context.Background(), func() error { return nil }) + require.ErrorIs(t, err, ErrBulkheadFull) + + close(done) +} + +func TestBulkheadRespectsContext(t *testing.T) { + bh := NewBulkhead(BulkheadOpts{Name: "test", MaxConc: 1}) + + // Fill the bulkhead. + started := make(chan struct{}) + done := make(chan struct{}) + go func() { + _ = bh.Execute(context.Background(), func() error { + close(started) + <-done + return nil + }) + }() + <-started + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + err := bh.Execute(ctx, func() error { return nil }) + require.ErrorIs(t, err, ErrBulkheadFull) + + close(done) +} + +func TestBulkheadPropagatesFunctionError(t *testing.T) { + bh := NewBulkhead(BulkheadOpts{Name: "test", MaxConc: 5}) + myErr := errors.New("business error") + err := bh.Execute(context.Background(), func() error { return myErr }) + require.ErrorIs(t, err, myErr) +} + +func TestBulkheadAvailable(t *testing.T) { + bh := NewBulkhead(BulkheadOpts{MaxConc: 3}) + assert.Equal(t, 3, bh.Available()) + + started := make(chan struct{}) + done := make(chan struct{}) + go func() { + _ = bh.Execute(context.Background(), func() error { + close(started) + <-done + return nil + }) + }() + <-started + assert.Equal(t, 2, bh.Available()) + close(done) + time.Sleep(10 * time.Millisecond) + assert.Equal(t, 3, bh.Available()) +} diff --git a/internal/platform/circuit_breaker.go b/internal/platform/circuit_breaker.go index c57490f78..e710e885f 100644 --- a/internal/platform/circuit_breaker.go +++ b/internal/platform/circuit_breaker.go @@ -3,6 +3,7 @@ package platform import ( "errors" + "fmt" "sync" "time" ) @@ -22,22 +23,80 @@ const ( StateHalfOpen ) -// CircuitBreaker implements the circuit breaker pattern. +// String returns a human-readable name for the circuit breaker state. 
+func (s State) String() string { + switch s { + case StateClosed: + return "closed" + case StateOpen: + return "open" + case StateHalfOpen: + return "half-open" + default: + return fmt.Sprintf("unknown(%d)", int(s)) + } +} + +// StateChangeFunc is called when the circuit breaker transitions between states. +// The old and new states are provided. Implementations must not block. +type StateChangeFunc func(name string, from, to State) + +// CircuitBreakerOpts configures the circuit breaker. All fields are optional +// and have sensible defaults; use the functional options to override. +type CircuitBreakerOpts struct { + Name string // Identifies this breaker in logs/metrics. + Threshold int // Consecutive failures to trip open. Default 5. + ResetTimeout time.Duration // Time in open before trying half-open. Default 30s. + SuccessRequired int // Successes in half-open to close. Default 1. + OnStateChange StateChangeFunc // Optional callback. +} + +// CircuitBreaker implements the circuit breaker pattern with proper +// half-open single-flight: only one probe request is allowed while open +// transitions to half-open. type CircuitBreaker struct { - mu sync.RWMutex + mu sync.Mutex + + name string state State failureCount int - failureThreshold int + successCount int // successes in half-open + threshold int + successRequired int resetTimeout time.Duration lastFailure time.Time + halfOpenInFlight bool // true while a half-open probe is executing + onStateChange StateChangeFunc } -// NewCircuitBreaker creates a new circuit breaker. +// NewCircuitBreaker creates a circuit breaker. The two positional args +// (threshold, resetTimeout) are kept for backward compatibility with existing +// callers. Use NewCircuitBreakerWithOpts for full configuration. func NewCircuitBreaker(threshold int, resetTimeout time.Duration) *CircuitBreaker { + return NewCircuitBreakerWithOpts(CircuitBreakerOpts{ + Threshold: threshold, + ResetTimeout: resetTimeout, + }) +} + +// NewCircuitBreakerWithOpts creates a circuit breaker with full options. +func NewCircuitBreakerWithOpts(opts CircuitBreakerOpts) *CircuitBreaker { + if opts.Threshold <= 0 { + opts.Threshold = 5 + } + if opts.ResetTimeout <= 0 { + opts.ResetTimeout = 30 * time.Second + } + if opts.SuccessRequired <= 0 { + opts.SuccessRequired = 1 + } return &CircuitBreaker{ - state: StateClosed, - failureThreshold: threshold, - resetTimeout: resetTimeout, + name: opts.Name, + state: StateClosed, + threshold: opts.Threshold, + successRequired: opts.SuccessRequired, + resetTimeout: opts.ResetTimeout, + onStateChange: opts.OnStateChange, } } @@ -58,53 +117,134 @@ func (cb *CircuitBreaker) Execute(fn func() error) error { } func (cb *CircuitBreaker) allowRequest() bool { - cb.mu.RLock() - defer cb.mu.RUnlock() - - if cb.state == StateClosed { - return true + cb.mu.Lock() + var cbFunc StateChangeFunc + var name string + var from, to State + var changed bool + allowed := false + + switch cb.state { + case StateClosed: + allowed = true + case StateOpen: + if time.Since(cb.lastFailure) <= cb.resetTimeout { + break + } + // Transition to half-open; only allow one probe at a time. + if cb.halfOpenInFlight { + break + } + cbFunc, name, from, to, changed = cb.transitionLocked(StateHalfOpen) + cb.halfOpenInFlight = true + cb.successCount = 0 + allowed = true + case StateHalfOpen: + // Allow additional requests only if no probe is in flight. 
+ if cb.halfOpenInFlight { + break + } + cb.halfOpenInFlight = true + allowed = true } + cb.mu.Unlock() - if cb.state == StateOpen { - if time.Since(cb.lastFailure) > cb.resetTimeout { - return true // Transition to half-open (implied by letting one request through) - } - return false + if changed && cbFunc != nil { + cbFunc(name, from, to) } - return true // Half-open + return allowed } func (cb *CircuitBreaker) recordFailure() { cb.mu.Lock() - defer cb.mu.Unlock() + var cbFunc StateChangeFunc + var name string + var from, to State + var changed bool + cb.halfOpenInFlight = false cb.failureCount++ cb.lastFailure = time.Now() - if cb.state == StateClosed && cb.failureCount >= cb.failureThreshold { - cb.state = StateOpen - } else if cb.state == StateHalfOpen { - cb.state = StateOpen + switch cb.state { + case StateClosed: + if cb.failureCount >= cb.threshold { + cbFunc, name, from, to, changed = cb.transitionLocked(StateOpen) + } + case StateHalfOpen: + // Probe failed — go back to open. + cbFunc, name, from, to, changed = cb.transitionLocked(StateOpen) + } + cb.mu.Unlock() + + if changed && cbFunc != nil { + cbFunc(name, from, to) } } func (cb *CircuitBreaker) recordSuccess() { cb.mu.Lock() - defer cb.mu.Unlock() + var cbFunc StateChangeFunc + var name string + var from, to State + var changed bool + + cb.halfOpenInFlight = false + + switch cb.state { + case StateHalfOpen: + cb.successCount++ + if cb.successCount >= cb.successRequired { + cb.failureCount = 0 + cb.successCount = 0 + cbFunc, name, from, to, changed = cb.transitionLocked(StateClosed) + } + default: + cb.failureCount = 0 + cb.state = StateClosed + } + cb.mu.Unlock() - cb.failureCount = 0 - cb.state = StateClosed + if changed && cbFunc != nil { + cbFunc(name, from, to) + } +} + +// transitionLocked changes state and fires the callback. Must be called +// with cb.mu held. The callback is invoked synchronously; implementations +// must not block or acquire cb.mu. +func (cb *CircuitBreaker) transitionLocked(to State) (StateChangeFunc, string, State, State, bool) { + from := cb.state + if from == to { + return nil, "", from, to, false + } + cb.state = to + return cb.onStateChange, cb.name, from, to, true } // Reset clears the circuit breaker state. func (cb *CircuitBreaker) Reset() { - cb.recordSuccess() + cb.mu.Lock() + cbFunc, name, from, to, changed := cb.transitionLocked(StateClosed) + cb.failureCount = 0 + cb.successCount = 0 + cb.halfOpenInFlight = false + cb.mu.Unlock() + + if changed && cbFunc != nil { + cbFunc(name, from, to) + } } // GetState returns the current state of the circuit breaker. func (cb *CircuitBreaker) GetState() State { - cb.mu.RLock() - defer cb.mu.RUnlock() + cb.mu.Lock() + defer cb.mu.Unlock() return cb.state } + +// Name returns the configured name of this circuit breaker. 
+func (cb *CircuitBreaker) Name() string { + return cb.name +} diff --git a/internal/platform/circuit_breaker_test.go b/internal/platform/circuit_breaker_test.go index 39126a8a7..2e16cdfa1 100644 --- a/internal/platform/circuit_breaker_test.go +++ b/internal/platform/circuit_breaker_test.go @@ -2,11 +2,12 @@ package platform import ( "errors" - "github.com/stretchr/testify/require" + "sync" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestCircuitBreaker(t *testing.T) { @@ -53,7 +54,7 @@ func TestCircuitBreaker(t *testing.T) { time.Sleep(100 * time.Millisecond) - // This should be allowed (half-open state implicitly) + // This should be allowed (half-open state) err := cb.Execute(func() error { return nil }) @@ -91,3 +92,129 @@ func TestCircuitBreaker(t *testing.T) { assert.Equal(t, StateClosed, cb.GetState()) }) } + +func TestCircuitBreakerHalfOpenSingleFlight(t *testing.T) { + cb := NewCircuitBreaker(1, 50*time.Millisecond) + + // Trip the circuit. + _ = cb.Execute(func() error { return errors.New("fail") }) + assert.Equal(t, StateOpen, cb.GetState()) + + time.Sleep(100 * time.Millisecond) + + // First call goes through as the half-open probe. Use a channel to + // hold the probe in-flight while we test the second call. + probeStarted := make(chan struct{}) + probeDone := make(chan struct{}) + + go func() { + _ = cb.Execute(func() error { + close(probeStarted) + <-probeDone // block until test releases + return nil + }) + }() + + <-probeStarted // wait for probe to be in-flight + + // Second concurrent call should be rejected while probe is in flight. + err := cb.Execute(func() error { return nil }) + assert.Equal(t, ErrCircuitOpen, err, "second request should be blocked while half-open probe is in flight") + + close(probeDone) // release the probe + time.Sleep(10 * time.Millisecond) + + // After probe succeeds, circuit should be closed. + assert.Equal(t, StateClosed, cb.GetState()) +} + +func TestCircuitBreakerOnStateChange(t *testing.T) { + var mu sync.Mutex + transitions := make([]struct{ from, to State }, 0) + + cb := NewCircuitBreakerWithOpts(CircuitBreakerOpts{ + Name: "test-cb", + Threshold: 1, + ResetTimeout: 50 * time.Millisecond, + OnStateChange: func(name string, from, to State) { + mu.Lock() + transitions = append(transitions, struct{ from, to State }{from, to}) + mu.Unlock() + }, + }) + + // Trip it. + _ = cb.Execute(func() error { return errors.New("fail") }) + time.Sleep(20 * time.Millisecond) // let async callback fire + + mu.Lock() + require.Len(t, transitions, 1) + assert.Equal(t, StateClosed, transitions[0].from) + assert.Equal(t, StateOpen, transitions[0].to) + mu.Unlock() + + // Wait for reset timeout, then succeed to close. 
+ time.Sleep(100 * time.Millisecond) + err := cb.Execute(func() error { return nil }) + require.NoError(t, err) + time.Sleep(20 * time.Millisecond) + + mu.Lock() + // Should have: closed->open, open->half-open, half-open->closed + require.Len(t, transitions, 3) + assert.Equal(t, StateOpen, transitions[1].from) + assert.Equal(t, StateHalfOpen, transitions[1].to) + assert.Equal(t, StateHalfOpen, transitions[2].from) + assert.Equal(t, StateClosed, transitions[2].to) + mu.Unlock() +} + +func TestCircuitBreakerWithOpts(t *testing.T) { + cb := NewCircuitBreakerWithOpts(CircuitBreakerOpts{ + Name: "compute", + Threshold: 3, + ResetTimeout: 1 * time.Second, + SuccessRequired: 2, + }) + + assert.Equal(t, "compute", cb.Name()) + assert.Equal(t, StateClosed, cb.GetState()) + + // Trip it with 3 failures. + for i := 0; i < 3; i++ { + _ = cb.Execute(func() error { return errors.New("fail") }) + } + assert.Equal(t, StateOpen, cb.GetState()) +} + +func TestCircuitBreakerSuccessRequired(t *testing.T) { + cb := NewCircuitBreakerWithOpts(CircuitBreakerOpts{ + Threshold: 1, + ResetTimeout: 50 * time.Millisecond, + SuccessRequired: 2, + }) + + // Trip it. + _ = cb.Execute(func() error { return errors.New("fail") }) + assert.Equal(t, StateOpen, cb.GetState()) + + time.Sleep(100 * time.Millisecond) + + // First success should move to half-open but not closed. + err := cb.Execute(func() error { return nil }) + require.NoError(t, err) + // Still half-open because we need 2 successes. + assert.Equal(t, StateHalfOpen, cb.GetState()) + + // Second success should close. + err = cb.Execute(func() error { return nil }) + require.NoError(t, err) + assert.Equal(t, StateClosed, cb.GetState()) +} + +func TestStateString(t *testing.T) { + assert.Equal(t, "closed", StateClosed.String()) + assert.Equal(t, "open", StateOpen.String()) + assert.Equal(t, "half-open", StateHalfOpen.String()) + assert.Equal(t, "unknown(99)", State(99).String()) +} diff --git a/internal/platform/resilient_compute.go b/internal/platform/resilient_compute.go new file mode 100644 index 000000000..bc5c0135d --- /dev/null +++ b/internal/platform/resilient_compute.go @@ -0,0 +1,279 @@ +package platform + +import ( + "context" + "fmt" + "io" + "log/slog" + "time" + + "github.com/poyrazk/thecloud/internal/core/ports" +) + +// ResilientComputeOpts configures the resilient compute wrapper. +type ResilientComputeOpts struct { + // CallTimeout is the per-call context timeout for normal operations. + // Default: 2 minutes. + CallTimeout time.Duration + // LongCallTimeout is the timeout for operations that are expected to take + // longer (e.g., LaunchInstanceWithOptions, RunTask). Default: 10 minutes. + LongCallTimeout time.Duration + // CBThreshold is the number of consecutive failures before the circuit + // opens. Default: 5. + CBThreshold int + // CBResetTimeout is how long the circuit stays open before attempting a + // half-open probe. Default: 30s. + CBResetTimeout time.Duration + // BulkheadMaxConc is the max concurrent calls to the backend. Default: 20. + BulkheadMaxConc int + // BulkheadWait is how long to wait for a bulkhead slot. Default: 10s. 
+ BulkheadWait time.Duration +} + +func (o ResilientComputeOpts) withDefaults() ResilientComputeOpts { + if o.CallTimeout <= 0 { + o.CallTimeout = 2 * time.Minute + } + if o.LongCallTimeout <= 0 { + o.LongCallTimeout = 10 * time.Minute + } + if o.CBThreshold <= 0 { + o.CBThreshold = 5 + } + if o.CBResetTimeout <= 0 { + o.CBResetTimeout = 30 * time.Second + } + if o.BulkheadMaxConc <= 0 { + o.BulkheadMaxConc = 20 + } + if o.BulkheadWait <= 0 { + o.BulkheadWait = 10 * time.Second + } + return o +} + +// ResilientCompute wraps a ComputeBackend with circuit breaker, bulkhead, +// and per-call timeouts. It implements the ports.ComputeBackend interface. +type ResilientCompute struct { + inner ports.ComputeBackend + cb *CircuitBreaker + bulkhead *Bulkhead + logger *slog.Logger + opts ResilientComputeOpts +} + +// NewResilientCompute decorates inner with resilience primitives. +func NewResilientCompute(inner ports.ComputeBackend, logger *slog.Logger, opts ResilientComputeOpts) *ResilientCompute { + opts = opts.withDefaults() + name := fmt.Sprintf("compute-%s", inner.Type()) + + cb := NewCircuitBreakerWithOpts(CircuitBreakerOpts{ + Name: name, + Threshold: opts.CBThreshold, + ResetTimeout: opts.CBResetTimeout, + SuccessRequired: 2, + OnStateChange: func(n string, from, to State) { + logger.Warn("circuit breaker state change", + "breaker", n, "from", from.String(), "to", to.String()) + }, + }) + + bh := NewBulkhead(BulkheadOpts{ + Name: name, + MaxConc: opts.BulkheadMaxConc, + WaitTimeout: opts.BulkheadWait, + }) + + return &ResilientCompute{ + inner: inner, + cb: cb, + bulkhead: bh, + logger: logger.With("adapter", name), + opts: opts, + } +} + +// ---------- helpers ---------- + +// callProtected runs fn through bulkhead → circuit breaker → timeout. +func (r *ResilientCompute) callProtected(ctx context.Context, timeout time.Duration, fn func(ctx context.Context) error) error { + return r.bulkhead.Execute(ctx, func() error { + return r.cb.Execute(func() error { + ctx2, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + return fn(ctx2) + }) + }) +} + +// ---------- Instance Lifecycle ---------- + +func (r *ResilientCompute) LaunchInstanceWithOptions(ctx context.Context, opts ports.CreateInstanceOptions) (string, []string, error) { + var id string + var ps []string + err := r.callProtected(ctx, r.opts.LongCallTimeout, func(ctx context.Context) error { + var e error + id, ps, e = r.inner.LaunchInstanceWithOptions(ctx, opts) + return e + }) + return id, ps, err +} + +func (r *ResilientCompute) StartInstance(ctx context.Context, id string) error { + return r.callProtected(ctx, r.opts.CallTimeout, func(ctx context.Context) error { + return r.inner.StartInstance(ctx, id) + }) +} + +func (r *ResilientCompute) StopInstance(ctx context.Context, id string) error { + return r.callProtected(ctx, r.opts.CallTimeout, func(ctx context.Context) error { + return r.inner.StopInstance(ctx, id) + }) +} + +func (r *ResilientCompute) DeleteInstance(ctx context.Context, id string) error { + return r.callProtected(ctx, r.opts.CallTimeout, func(ctx context.Context) error { + return r.inner.DeleteInstance(ctx, id) + }) +} + +func (r *ResilientCompute) GetInstanceLogs(ctx context.Context, id string) (io.ReadCloser, error) { + var rc io.ReadCloser + err := r.callProtected(ctx, r.opts.CallTimeout, func(ctx context.Context) error { + var e error + rc, e = r.inner.GetInstanceLogs(ctx, id) + return e + }) + return rc, err +} + +func (r *ResilientCompute) GetInstanceStats(ctx context.Context, id string) (io.ReadCloser, 
error) { + var rc io.ReadCloser + err := r.callProtected(ctx, r.opts.CallTimeout, func(ctx context.Context) error { + var e error + rc, e = r.inner.GetInstanceStats(ctx, id) + return e + }) + return rc, err +} + +func (r *ResilientCompute) GetInstancePort(ctx context.Context, id string, internalPort string) (int, error) { + var port int + err := r.callProtected(ctx, r.opts.CallTimeout, func(ctx context.Context) error { + var e error + port, e = r.inner.GetInstancePort(ctx, id, internalPort) + return e + }) + return port, err +} + +func (r *ResilientCompute) GetInstanceIP(ctx context.Context, id string) (string, error) { + var ip string + err := r.callProtected(ctx, r.opts.CallTimeout, func(ctx context.Context) error { + var e error + ip, e = r.inner.GetInstanceIP(ctx, id) + return e + }) + return ip, err +} + +func (r *ResilientCompute) GetConsoleURL(ctx context.Context, id string) (string, error) { + var url string + err := r.callProtected(ctx, r.opts.CallTimeout, func(ctx context.Context) error { + var e error + url, e = r.inner.GetConsoleURL(ctx, id) + return e + }) + return url, err +} + +// ---------- Execution ---------- + +func (r *ResilientCompute) Exec(ctx context.Context, id string, cmd []string) (string, error) { + var out string + err := r.callProtected(ctx, r.opts.CallTimeout, func(ctx context.Context) error { + var e error + out, e = r.inner.Exec(ctx, id, cmd) + return e + }) + return out, err +} + +func (r *ResilientCompute) RunTask(ctx context.Context, opts ports.RunTaskOptions) (string, []string, error) { + var id string + var ps []string + err := r.callProtected(ctx, r.opts.LongCallTimeout, func(ctx context.Context) error { + var e error + id, ps, e = r.inner.RunTask(ctx, opts) + return e + }) + return id, ps, err +} + +func (r *ResilientCompute) WaitTask(ctx context.Context, id string) (int64, error) { + var code int64 + err := r.callProtected(ctx, r.opts.LongCallTimeout, func(ctx context.Context) error { + var e error + code, e = r.inner.WaitTask(ctx, id) + return e + }) + return code, err +} + +// ---------- Network Management ---------- + +func (r *ResilientCompute) CreateNetwork(ctx context.Context, name string) (string, error) { + var id string + err := r.callProtected(ctx, r.opts.CallTimeout, func(ctx context.Context) error { + var e error + id, e = r.inner.CreateNetwork(ctx, name) + return e + }) + return id, err +} + +func (r *ResilientCompute) DeleteNetwork(ctx context.Context, id string) error { + return r.callProtected(ctx, r.opts.CallTimeout, func(ctx context.Context) error { + return r.inner.DeleteNetwork(ctx, id) + }) +} + +// ---------- Volume Attachment ---------- + +func (r *ResilientCompute) AttachVolume(ctx context.Context, id string, volumePath string) (string, error) { + var devPath string + err := r.callProtected(ctx, r.opts.CallTimeout, func(ctx context.Context) error { + var e error + devPath, e = r.inner.AttachVolume(ctx, id, volumePath) + return e + }) + return devPath, err +} + +func (r *ResilientCompute) DetachVolume(ctx context.Context, id string, volumePath string) error { + return r.callProtected(ctx, r.opts.CallTimeout, func(ctx context.Context) error { + return r.inner.DetachVolume(ctx, id, volumePath) + }) +} + +// ---------- Health ---------- + +// Ping bypasses the bulkhead (low cost, used for health checks) but still +// goes through the circuit breaker so a broken backend trips the circuit. 
+func (r *ResilientCompute) Ping(ctx context.Context) error { + return r.cb.Execute(func() error { + ctx2, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + return r.inner.Ping(ctx2) + }) +} + +// Type delegates directly — no protection needed. +func (r *ResilientCompute) Type() string { + return r.inner.Type() +} + +// Unwrap returns the underlying ComputeBackend (useful for tests). +func (r *ResilientCompute) Unwrap() ports.ComputeBackend { + return r.inner +} diff --git a/internal/platform/resilient_compute_test.go b/internal/platform/resilient_compute_test.go new file mode 100644 index 000000000..a682d2541 --- /dev/null +++ b/internal/platform/resilient_compute_test.go @@ -0,0 +1,299 @@ +package platform + +import ( + "context" + "errors" + "io" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/poyrazk/thecloud/internal/core/ports" + "log/slog" +) + +// ---------- mock compute backend ---------- + +type mockCompute struct { + callCount atomic.Int64 + delay time.Duration + err error +} + +func (m *mockCompute) wait(ctx context.Context) error { + if m.delay <= 0 { + return nil + } + select { + case <-time.After(m.delay): + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func (m *mockCompute) LaunchInstanceWithOptions(ctx context.Context, _ ports.CreateInstanceOptions) (string, []string, error) { + m.callCount.Add(1) + if err := m.wait(ctx); err != nil { + return "", nil, err + } + return "inst-1", []string{"8080"}, m.err +} + +func (m *mockCompute) StartInstance(ctx context.Context, _ string) error { + m.callCount.Add(1) + if err := m.wait(ctx); err != nil { + return err + } + return m.err +} +func (m *mockCompute) StopInstance(_ context.Context, _ string) error { + m.callCount.Add(1) + return m.err +} +func (m *mockCompute) DeleteInstance(_ context.Context, _ string) error { + m.callCount.Add(1) + return m.err +} +func (m *mockCompute) GetInstanceLogs(_ context.Context, _ string) (io.ReadCloser, error) { + m.callCount.Add(1) + return io.NopCloser(strings.NewReader("logs")), m.err +} +func (m *mockCompute) GetInstanceStats(_ context.Context, _ string) (io.ReadCloser, error) { + m.callCount.Add(1) + return io.NopCloser(strings.NewReader("stats")), m.err +} +func (m *mockCompute) GetInstancePort(_ context.Context, _ string, _ string) (int, error) { + m.callCount.Add(1) + return 8080, m.err +} +func (m *mockCompute) GetInstanceIP(_ context.Context, _ string) (string, error) { + m.callCount.Add(1) + return "10.0.0.1", m.err +} +func (m *mockCompute) GetConsoleURL(_ context.Context, _ string) (string, error) { + m.callCount.Add(1) + return "https://console", m.err +} +func (m *mockCompute) Exec(_ context.Context, _ string, _ []string) (string, error) { + m.callCount.Add(1) + return "output", m.err +} +func (m *mockCompute) RunTask(_ context.Context, _ ports.RunTaskOptions) (string, []string, error) { + m.callCount.Add(1) + return "task-1", nil, m.err +} +func (m *mockCompute) WaitTask(_ context.Context, _ string) (int64, error) { + m.callCount.Add(1) + return 0, m.err +} +func (m *mockCompute) CreateNetwork(_ context.Context, _ string) (string, error) { + m.callCount.Add(1) + return "net-1", m.err +} +func (m *mockCompute) DeleteNetwork(_ context.Context, _ string) error { + m.callCount.Add(1) + return m.err +} +func (m *mockCompute) AttachVolume(_ context.Context, _ string, _ string) (string, error) { + m.callCount.Add(1) + return "/dev/vdb", m.err +} +func (m *mockCompute) DetachVolume(_ context.Context, _ string, _ string) error { + 
m.callCount.Add(1) + return m.err +} +func (m *mockCompute) Ping(_ context.Context) error { + m.callCount.Add(1) + return m.err +} +func (m *mockCompute) Type() string { return "mock" } + +// ---------- tests ---------- + +func TestResilientComputePassthrough(t *testing.T) { + // All calls should pass through to the mock on success. + mock := &mockCompute{} + logger := slog.Default() + rc := NewResilientCompute(mock, logger, ResilientComputeOpts{}) + + ctx := context.Background() + + id, ps, err := rc.LaunchInstanceWithOptions(ctx, ports.CreateInstanceOptions{}) + assertNoErr(t, err) + if id != "inst-1" || len(ps) != 1 { + t.Fatalf("unexpected launch result: %s %v", id, ps) + } + + assertNoErr(t, rc.StartInstance(ctx, "x")) + assertNoErr(t, rc.StopInstance(ctx, "x")) + assertNoErr(t, rc.DeleteInstance(ctx, "x")) + + _, err = rc.GetInstanceLogs(ctx, "x") + assertNoErr(t, err) + _, err = rc.GetInstanceStats(ctx, "x") + assertNoErr(t, err) + port, err := rc.GetInstancePort(ctx, "x", "80") + assertNoErr(t, err) + if port != 8080 { + t.Fatalf("expected 8080, got %d", port) + } + ip, err := rc.GetInstanceIP(ctx, "x") + assertNoErr(t, err) + if ip != "10.0.0.1" { + t.Fatalf("expected 10.0.0.1, got %s", ip) + } + + out, err := rc.Exec(ctx, "x", []string{"ls"}) + assertNoErr(t, err) + if out != "output" { + t.Fatalf("expected output, got %s", out) + } + + assertNoErr(t, rc.Ping(ctx)) + if rc.Type() != "mock" { + t.Fatalf("expected mock, got %s", rc.Type()) + } + + if mock.callCount.Load() < 10 { + t.Fatalf("expected at least 10 calls, got %d", mock.callCount.Load()) + } +} + +func TestResilientComputeCircuitTrips(t *testing.T) { + // After threshold failures, the circuit should open and reject immediately. + mock := &mockCompute{err: errors.New("backend down")} + logger := slog.Default() + rc := NewResilientCompute(mock, logger, ResilientComputeOpts{ + CBThreshold: 3, + CBResetTimeout: 5 * time.Second, + }) + + ctx := context.Background() + + // 3 failures to trip the circuit. + for i := 0; i < 3; i++ { + err := rc.StartInstance(ctx, "x") + if err == nil { + t.Fatal("expected error") + } + } + + // Next call should get ErrCircuitOpen without hitting the mock. + callsBefore := mock.callCount.Load() + err := rc.StartInstance(ctx, "x") + if !errors.Is(err, ErrCircuitOpen) { + t.Fatalf("expected ErrCircuitOpen, got %v", err) + } + if mock.callCount.Load() != callsBefore { + t.Fatal("expected mock not to be called when circuit is open") + } +} + +func TestResilientComputeBulkheadLimits(t *testing.T) { + // When bulkhead is full, calls should be rejected. + mock := &mockCompute{delay: 500 * time.Millisecond} + logger := slog.Default() + rc := NewResilientCompute(mock, logger, ResilientComputeOpts{ + BulkheadMaxConc: 2, + BulkheadWait: 50 * time.Millisecond, + CallTimeout: 2 * time.Second, + }) + + ctx := context.Background() + var wg sync.WaitGroup + var bulkheadErrors atomic.Int64 + + // Ensure the first 2 goroutines grab the slots before the rest start. + ready := make(chan struct{}) + for i := 0; i < 5; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + if idx >= 2 { + <-ready // Wait until the first 2 have started. + } + err := rc.StartInstance(ctx, "x") + if errors.Is(err, ErrBulkheadFull) { + bulkheadErrors.Add(1) + } + }(i) + } + // Give the first 2 goroutines time to acquire the slots. 
+ time.Sleep(50 * time.Millisecond) + close(ready) + wg.Wait() + + if bulkheadErrors.Load() == 0 { + t.Fatal("expected at least one bulkhead rejection") + } +} + +func TestResilientComputeTimeout(t *testing.T) { + // A slow backend should be cancelled by the per-call timeout. + mock := &mockCompute{delay: 5 * time.Second} + logger := slog.Default() + rc := NewResilientCompute(mock, logger, ResilientComputeOpts{ + CallTimeout: 100 * time.Millisecond, + }) + + ctx := context.Background() + start := time.Now() + err := rc.StartInstance(ctx, "x") + elapsed := time.Since(start) + + if err == nil { + t.Fatal("expected timeout error") + } + // Should complete much faster than 5s. + if elapsed > 2*time.Second { + t.Fatalf("timeout not enforced, took %v", elapsed) + } +} + +func TestResilientComputeUnwrap(t *testing.T) { + mock := &mockCompute{} + rc := NewResilientCompute(mock, slog.Default(), ResilientComputeOpts{}) + if rc.Unwrap() != mock { + t.Fatal("Unwrap should return the inner backend") + } +} + +func TestResilientComputePingBypassesBulkhead(t *testing.T) { + // Ping should work even when the bulkhead is completely full. + mock := &mockCompute{delay: 500 * time.Millisecond} + logger := slog.Default() + rc := NewResilientCompute(mock, logger, ResilientComputeOpts{ + BulkheadMaxConc: 1, + BulkheadWait: 10 * time.Millisecond, + }) + + ctx := context.Background() + + // Saturate the bulkhead. + started := make(chan struct{}) + go func() { + close(started) + _ = rc.StartInstance(ctx, "x") + }() + <-started + time.Sleep(20 * time.Millisecond) + + // Ping should still work (bypasses bulkhead). + err := rc.Ping(ctx) + // err may or may not be nil depending on timing, but it must NOT be ErrBulkheadFull. + if errors.Is(err, ErrBulkheadFull) { + t.Fatal("Ping should bypass bulkhead") + } +} + +// ---------- test helpers ---------- + +func assertNoErr(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} diff --git a/internal/platform/resilient_dns.go b/internal/platform/resilient_dns.go new file mode 100644 index 000000000..0f1419c2e --- /dev/null +++ b/internal/platform/resilient_dns.go @@ -0,0 +1,125 @@ +package platform + +import ( + "context" + "log/slog" + "time" + + "github.com/poyrazk/thecloud/internal/core/ports" +) + +// ResilientDNSOpts configures the resilient DNS wrapper. +type ResilientDNSOpts struct { + CallTimeout time.Duration // Per-call timeout. Default: 10s. + CBThreshold int // Failures to trip. Default: 5. + CBResetTimeout time.Duration // Open→half-open wait. Default: 30s. +} + +func (o ResilientDNSOpts) withDefaults() ResilientDNSOpts { + if o.CallTimeout <= 0 { + o.CallTimeout = 10 * time.Second + } + if o.CBThreshold <= 0 { + o.CBThreshold = 5 + } + if o.CBResetTimeout <= 0 { + o.CBResetTimeout = 30 * time.Second + } + return o +} + +// ResilientDNS wraps a DNSBackend with circuit breaker and per-call timeouts. +// DNS calls are lightweight so no bulkhead is applied (PowerDNS HTTP API is +// already serialized). +type ResilientDNS struct { + inner ports.DNSBackend + cb *CircuitBreaker + logger *slog.Logger + opts ResilientDNSOpts +} + +// NewResilientDNS decorates inner with resilience primitives. 
+func NewResilientDNS(inner ports.DNSBackend, logger *slog.Logger, opts ResilientDNSOpts) *ResilientDNS { + opts = opts.withDefaults() + name := "dns-powerdns" + + cb := NewCircuitBreakerWithOpts(CircuitBreakerOpts{ + Name: name, + Threshold: opts.CBThreshold, + ResetTimeout: opts.CBResetTimeout, + SuccessRequired: 2, + OnStateChange: func(n string, from, to State) { + logger.Warn("circuit breaker state change", + "breaker", n, "from", from.String(), "to", to.String()) + }, + }) + + return &ResilientDNS{ + inner: inner, + cb: cb, + logger: logger.With("adapter", name), + opts: opts, + } +} + +func (r *ResilientDNS) callProtected(ctx context.Context, fn func(ctx context.Context) error) error { + return r.cb.Execute(func() error { + ctx2, cancel := context.WithTimeout(ctx, r.opts.CallTimeout) + defer cancel() + return fn(ctx2) + }) +} + +// ---------- Zone Operations ---------- + +func (r *ResilientDNS) CreateZone(ctx context.Context, zoneName string, nameservers []string) error { + return r.callProtected(ctx, func(ctx context.Context) error { + return r.inner.CreateZone(ctx, zoneName, nameservers) + }) +} + +func (r *ResilientDNS) DeleteZone(ctx context.Context, zoneName string) error { + return r.callProtected(ctx, func(ctx context.Context) error { + return r.inner.DeleteZone(ctx, zoneName) + }) +} + +func (r *ResilientDNS) GetZone(ctx context.Context, zoneName string) (*ports.ZoneInfo, error) { + var info *ports.ZoneInfo + err := r.callProtected(ctx, func(ctx context.Context) error { + var e error + info, e = r.inner.GetZone(ctx, zoneName) + return e + }) + return info, err +} + +// ---------- Record Operations ---------- + +func (r *ResilientDNS) AddRecords(ctx context.Context, zoneName string, records []ports.RecordSet) error { + return r.callProtected(ctx, func(ctx context.Context) error { + return r.inner.AddRecords(ctx, zoneName, records) + }) +} + +func (r *ResilientDNS) UpdateRecords(ctx context.Context, zoneName string, records []ports.RecordSet) error { + return r.callProtected(ctx, func(ctx context.Context) error { + return r.inner.UpdateRecords(ctx, zoneName, records) + }) +} + +func (r *ResilientDNS) DeleteRecords(ctx context.Context, zoneName string, name string, recordType string) error { + return r.callProtected(ctx, func(ctx context.Context) error { + return r.inner.DeleteRecords(ctx, zoneName, name, recordType) + }) +} + +func (r *ResilientDNS) ListRecords(ctx context.Context, zoneName string) ([]ports.RecordSet, error) { + var records []ports.RecordSet + err := r.callProtected(ctx, func(ctx context.Context) error { + var e error + records, e = r.inner.ListRecords(ctx, zoneName) + return e + }) + return records, err +} diff --git a/internal/platform/resilient_lb.go b/internal/platform/resilient_lb.go new file mode 100644 index 000000000..d646ef701 --- /dev/null +++ b/internal/platform/resilient_lb.go @@ -0,0 +1,97 @@ +package platform + +import ( + "context" + "log/slog" + "time" + + "github.com/google/uuid" + "github.com/poyrazk/thecloud/internal/core/domain" + "github.com/poyrazk/thecloud/internal/core/ports" +) + +// ResilientLBOpts configures the resilient load balancer proxy wrapper. +type ResilientLBOpts struct { + CallTimeout time.Duration // Per-call timeout. Default: 30s. + LongTimeout time.Duration // Timeout for DeployProxy (container launch). Default: 2m. + CBThreshold int // Failures to trip. Default: 5. + CBResetTimeout time.Duration // Open→half-open wait. Default: 30s. 
+} + +func (o ResilientLBOpts) withDefaults() ResilientLBOpts { + if o.CallTimeout <= 0 { + o.CallTimeout = 30 * time.Second + } + if o.LongTimeout <= 0 { + o.LongTimeout = 2 * time.Minute + } + if o.CBThreshold <= 0 { + o.CBThreshold = 5 + } + if o.CBResetTimeout <= 0 { + o.CBResetTimeout = 30 * time.Second + } + return o +} + +// ResilientLB wraps an LBProxyAdapter with circuit breaker and per-call timeouts. +// LB proxy has only 3 methods so no bulkhead is needed — the compute bulkhead +// already limits the underlying container/VM creation. +type ResilientLB struct { + inner ports.LBProxyAdapter + cb *CircuitBreaker + logger *slog.Logger + opts ResilientLBOpts +} + +// NewResilientLB decorates inner with resilience primitives. +func NewResilientLB(inner ports.LBProxyAdapter, logger *slog.Logger, opts ResilientLBOpts) *ResilientLB { + opts = opts.withDefaults() + name := "lb-proxy" + + cb := NewCircuitBreakerWithOpts(CircuitBreakerOpts{ + Name: name, + Threshold: opts.CBThreshold, + ResetTimeout: opts.CBResetTimeout, + SuccessRequired: 2, + OnStateChange: func(n string, from, to State) { + logger.Warn("circuit breaker state change", + "breaker", n, "from", from.String(), "to", to.String()) + }, + }) + + return &ResilientLB{ + inner: inner, + cb: cb, + logger: logger.With("adapter", name), + opts: opts, + } +} + +func (r *ResilientLB) DeployProxy(ctx context.Context, lb *domain.LoadBalancer, targets []*domain.LBTarget) (string, error) { + var addr string + err := r.cb.Execute(func() error { + ctx2, cancel := context.WithTimeout(ctx, r.opts.LongTimeout) + defer cancel() + var e error + addr, e = r.inner.DeployProxy(ctx2, lb, targets) + return e + }) + return addr, err +} + +func (r *ResilientLB) RemoveProxy(ctx context.Context, lbID uuid.UUID) error { + return r.cb.Execute(func() error { + ctx2, cancel := context.WithTimeout(ctx, r.opts.CallTimeout) + defer cancel() + return r.inner.RemoveProxy(ctx2, lbID) + }) +} + +func (r *ResilientLB) UpdateProxyConfig(ctx context.Context, lb *domain.LoadBalancer, targets []*domain.LBTarget) error { + return r.cb.Execute(func() error { + ctx2, cancel := context.WithTimeout(ctx, r.opts.CallTimeout) + defer cancel() + return r.inner.UpdateProxyConfig(ctx2, lb, targets) + }) +} diff --git a/internal/platform/resilient_network.go b/internal/platform/resilient_network.go new file mode 100644 index 000000000..69fba40cb --- /dev/null +++ b/internal/platform/resilient_network.go @@ -0,0 +1,211 @@ +package platform + +import ( + "context" + "fmt" + "log/slog" + "time" + + "github.com/poyrazk/thecloud/internal/core/ports" +) + +// ResilientNetworkOpts configures the resilient network wrapper. +type ResilientNetworkOpts struct { + CallTimeout time.Duration // Per-call timeout. Default: 30s. + CBThreshold int // Failures to trip. Default: 5. + CBResetTimeout time.Duration // Open→half-open wait. Default: 30s. + BulkheadMaxConc int // Max concurrent calls. Default: 15. + BulkheadWait time.Duration // Bulkhead slot wait. Default: 10s. +} + +func (o ResilientNetworkOpts) withDefaults() ResilientNetworkOpts { + if o.CallTimeout <= 0 { + o.CallTimeout = 30 * time.Second + } + if o.CBThreshold <= 0 { + o.CBThreshold = 5 + } + if o.CBResetTimeout <= 0 { + o.CBResetTimeout = 30 * time.Second + } + if o.BulkheadMaxConc <= 0 { + o.BulkheadMaxConc = 15 + } + if o.BulkheadWait <= 0 { + o.BulkheadWait = 10 * time.Second + } + return o +} + +// ResilientNetwork wraps a NetworkBackend with circuit breaker, bulkhead, +// and per-call timeouts. 
It implements the ports.NetworkBackend interface. +type ResilientNetwork struct { + inner ports.NetworkBackend + cb *CircuitBreaker + bulkhead *Bulkhead + logger *slog.Logger + opts ResilientNetworkOpts +} + +// NewResilientNetwork decorates inner with resilience primitives. +func NewResilientNetwork(inner ports.NetworkBackend, logger *slog.Logger, opts ResilientNetworkOpts) *ResilientNetwork { + opts = opts.withDefaults() + name := fmt.Sprintf("network-%s", inner.Type()) + + cb := NewCircuitBreakerWithOpts(CircuitBreakerOpts{ + Name: name, + Threshold: opts.CBThreshold, + ResetTimeout: opts.CBResetTimeout, + SuccessRequired: 2, + OnStateChange: func(n string, from, to State) { + logger.Warn("circuit breaker state change", + "breaker", n, "from", from.String(), "to", to.String()) + }, + }) + + bh := NewBulkhead(BulkheadOpts{ + Name: name, + MaxConc: opts.BulkheadMaxConc, + WaitTimeout: opts.BulkheadWait, + }) + + return &ResilientNetwork{ + inner: inner, + cb: cb, + bulkhead: bh, + logger: logger.With("adapter", name), + opts: opts, + } +} + +// callProtected runs fn through bulkhead → circuit breaker → timeout. +func (r *ResilientNetwork) callProtected(ctx context.Context, fn func(ctx context.Context) error) error { + return r.bulkhead.Execute(ctx, func() error { + return r.cb.Execute(func() error { + ctx2, cancel := context.WithTimeout(ctx, r.opts.CallTimeout) + defer cancel() + return fn(ctx2) + }) + }) +} + +// ---------- Bridge Management ---------- + +func (r *ResilientNetwork) CreateBridge(ctx context.Context, name string, vxlanID int) error { + return r.callProtected(ctx, func(ctx context.Context) error { + return r.inner.CreateBridge(ctx, name, vxlanID) + }) +} + +func (r *ResilientNetwork) DeleteBridge(ctx context.Context, name string) error { + return r.callProtected(ctx, func(ctx context.Context) error { + return r.inner.DeleteBridge(ctx, name) + }) +} + +func (r *ResilientNetwork) ListBridges(ctx context.Context) ([]string, error) { + var bridges []string + err := r.callProtected(ctx, func(ctx context.Context) error { + var e error + bridges, e = r.inner.ListBridges(ctx) + return e + }) + return bridges, err +} + +// ---------- Port Management ---------- + +func (r *ResilientNetwork) AddPort(ctx context.Context, bridge, portName string) error { + return r.callProtected(ctx, func(ctx context.Context) error { + return r.inner.AddPort(ctx, bridge, portName) + }) +} + +func (r *ResilientNetwork) DeletePort(ctx context.Context, bridge, portName string) error { + return r.callProtected(ctx, func(ctx context.Context) error { + return r.inner.DeletePort(ctx, bridge, portName) + }) +} + +// ---------- VXLAN Tunnels ---------- + +func (r *ResilientNetwork) CreateVXLANTunnel(ctx context.Context, bridge string, vni int, remoteIP string) error { + return r.callProtected(ctx, func(ctx context.Context) error { + return r.inner.CreateVXLANTunnel(ctx, bridge, vni, remoteIP) + }) +} + +func (r *ResilientNetwork) DeleteVXLANTunnel(ctx context.Context, bridge string, remoteIP string) error { + return r.callProtected(ctx, func(ctx context.Context) error { + return r.inner.DeleteVXLANTunnel(ctx, bridge, remoteIP) + }) +} + +// ---------- Security Groups (Flow Rules) ---------- + +func (r *ResilientNetwork) AddFlowRule(ctx context.Context, bridge string, rule ports.FlowRule) error { + return r.callProtected(ctx, func(ctx context.Context) error { + return r.inner.AddFlowRule(ctx, bridge, rule) + }) +} + +func (r *ResilientNetwork) DeleteFlowRule(ctx context.Context, bridge string, match string) error 
{ + return r.callProtected(ctx, func(ctx context.Context) error { + return r.inner.DeleteFlowRule(ctx, bridge, match) + }) +} + +func (r *ResilientNetwork) ListFlowRules(ctx context.Context, bridge string) ([]ports.FlowRule, error) { + var rules []ports.FlowRule + err := r.callProtected(ctx, func(ctx context.Context) error { + var e error + rules, e = r.inner.ListFlowRules(ctx, bridge) + return e + }) + return rules, err +} + +// ---------- Veth Pair Management ---------- + +func (r *ResilientNetwork) CreateVethPair(ctx context.Context, hostEnd, containerEnd string) error { + return r.callProtected(ctx, func(ctx context.Context) error { + return r.inner.CreateVethPair(ctx, hostEnd, containerEnd) + }) +} + +func (r *ResilientNetwork) AttachVethToBridge(ctx context.Context, bridge, vethEnd string) error { + return r.callProtected(ctx, func(ctx context.Context) error { + return r.inner.AttachVethToBridge(ctx, bridge, vethEnd) + }) +} + +func (r *ResilientNetwork) DeleteVethPair(ctx context.Context, hostEnd string) error { + return r.callProtected(ctx, func(ctx context.Context) error { + return r.inner.DeleteVethPair(ctx, hostEnd) + }) +} + +func (r *ResilientNetwork) SetVethIP(ctx context.Context, vethEnd, ip, cidr string) error { + return r.callProtected(ctx, func(ctx context.Context) error { + return r.inner.SetVethIP(ctx, vethEnd, ip, cidr) + }) +} + +// ---------- Health ---------- + +func (r *ResilientNetwork) Ping(ctx context.Context) error { + return r.cb.Execute(func() error { + ctx2, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + return r.inner.Ping(ctx2) + }) +} + +func (r *ResilientNetwork) Type() string { + return r.inner.Type() +} + +// Unwrap returns the underlying NetworkBackend. +func (r *ResilientNetwork) Unwrap() ports.NetworkBackend { + return r.inner +} diff --git a/internal/platform/resilient_storage.go b/internal/platform/resilient_storage.go new file mode 100644 index 000000000..4be7798d4 --- /dev/null +++ b/internal/platform/resilient_storage.go @@ -0,0 +1,166 @@ +package platform + +import ( + "context" + "fmt" + "log/slog" + "time" + + "github.com/poyrazk/thecloud/internal/core/ports" +) + +// ResilientStorageOpts configures the resilient storage wrapper. +type ResilientStorageOpts struct { + CallTimeout time.Duration // Per-call timeout. Default: 30s. + LongCallTimeout time.Duration // Timeout for snapshot/restore. Default: 5m. + CBThreshold int // Failures to trip. Default: 5. + CBResetTimeout time.Duration // Open→half-open wait. Default: 30s. + BulkheadMaxConc int // Max concurrent calls. Default: 10. + BulkheadWait time.Duration // Bulkhead slot wait. Default: 10s. +} + +func (o ResilientStorageOpts) withDefaults() ResilientStorageOpts { + if o.CallTimeout <= 0 { + o.CallTimeout = 30 * time.Second + } + if o.LongCallTimeout <= 0 { + o.LongCallTimeout = 5 * time.Minute + } + if o.CBThreshold <= 0 { + o.CBThreshold = 5 + } + if o.CBResetTimeout <= 0 { + o.CBResetTimeout = 30 * time.Second + } + if o.BulkheadMaxConc <= 0 { + o.BulkheadMaxConc = 10 + } + if o.BulkheadWait <= 0 { + o.BulkheadWait = 10 * time.Second + } + return o +} + +// ResilientStorage wraps a StorageBackend with circuit breaker, bulkhead, +// and per-call timeouts. +type ResilientStorage struct { + inner ports.StorageBackend + cb *CircuitBreaker + bulkhead *Bulkhead + logger *slog.Logger + opts ResilientStorageOpts +} + +// NewResilientStorage decorates inner with resilience primitives. 
+func NewResilientStorage(inner ports.StorageBackend, logger *slog.Logger, opts ResilientStorageOpts) *ResilientStorage { + opts = opts.withDefaults() + name := fmt.Sprintf("storage-%s", inner.Type()) + + cb := NewCircuitBreakerWithOpts(CircuitBreakerOpts{ + Name: name, + Threshold: opts.CBThreshold, + ResetTimeout: opts.CBResetTimeout, + SuccessRequired: 2, + OnStateChange: func(n string, from, to State) { + logger.Warn("circuit breaker state change", + "breaker", n, "from", from.String(), "to", to.String()) + }, + }) + + bh := NewBulkhead(BulkheadOpts{ + Name: name, + MaxConc: opts.BulkheadMaxConc, + WaitTimeout: opts.BulkheadWait, + }) + + return &ResilientStorage{ + inner: inner, + cb: cb, + bulkhead: bh, + logger: logger.With("adapter", name), + opts: opts, + } +} + +func (r *ResilientStorage) callProtected(ctx context.Context, timeout time.Duration, fn func(ctx context.Context) error) error { + return r.bulkhead.Execute(ctx, func() error { + return r.cb.Execute(func() error { + ctx2, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + return fn(ctx2) + }) + }) +} + +func (r *ResilientStorage) CreateVolume(ctx context.Context, name string, sizeGB int) (string, error) { + var path string + err := r.callProtected(ctx, r.opts.CallTimeout, func(ctx context.Context) error { + var e error + path, e = r.inner.CreateVolume(ctx, name, sizeGB) + return e + }) + return path, err +} + +func (r *ResilientStorage) DeleteVolume(ctx context.Context, name string) error { + return r.callProtected(ctx, r.opts.CallTimeout, func(ctx context.Context) error { + return r.inner.DeleteVolume(ctx, name) + }) +} + +func (r *ResilientStorage) ResizeVolume(ctx context.Context, name string, newSizeGB int) error { + return r.callProtected(ctx, r.opts.CallTimeout, func(ctx context.Context) error { + return r.inner.ResizeVolume(ctx, name, newSizeGB) + }) +} + +func (r *ResilientStorage) AttachVolume(ctx context.Context, volumeName, instanceID string) (string, error) { + var devPath string + err := r.callProtected(ctx, r.opts.CallTimeout, func(ctx context.Context) error { + var e error + devPath, e = r.inner.AttachVolume(ctx, volumeName, instanceID) + return e + }) + return devPath, err +} + +func (r *ResilientStorage) DetachVolume(ctx context.Context, volumeName, instanceID string) error { + return r.callProtected(ctx, r.opts.CallTimeout, func(ctx context.Context) error { + return r.inner.DetachVolume(ctx, volumeName, instanceID) + }) +} + +func (r *ResilientStorage) CreateSnapshot(ctx context.Context, volumeName, snapshotName string) error { + return r.callProtected(ctx, r.opts.LongCallTimeout, func(ctx context.Context) error { + return r.inner.CreateSnapshot(ctx, volumeName, snapshotName) + }) +} + +func (r *ResilientStorage) RestoreSnapshot(ctx context.Context, volumeName, snapshotName string) error { + return r.callProtected(ctx, r.opts.LongCallTimeout, func(ctx context.Context) error { + return r.inner.RestoreSnapshot(ctx, volumeName, snapshotName) + }) +} + +func (r *ResilientStorage) DeleteSnapshot(ctx context.Context, snapshotName string) error { + return r.callProtected(ctx, r.opts.CallTimeout, func(ctx context.Context) error { + return r.inner.DeleteSnapshot(ctx, snapshotName) + }) +} + +func (r *ResilientStorage) Ping(ctx context.Context) error { + return r.cb.Execute(func() error { + ctx2, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + return r.inner.Ping(ctx2) + }) +} + +func (r *ResilientStorage) Type() string { + return r.inner.Type() +} + +// Unwrap returns the 
underlying StorageBackend. +func (r *ResilientStorage) Unwrap() ports.StorageBackend { + return r.inner +} diff --git a/internal/platform/retry.go b/internal/platform/retry.go new file mode 100644 index 000000000..c465c8be1 --- /dev/null +++ b/internal/platform/retry.go @@ -0,0 +1,91 @@ +package platform + +import ( + "context" + "math" + "math/rand/v2" + "time" +) + +// RetryOpts configures retry behavior. +type RetryOpts struct { + MaxAttempts int // Total attempts (including the first). Default 3. + BaseDelay time.Duration // Initial delay before first retry. Default 500ms. + MaxDelay time.Duration // Cap on exponential growth. Default 30s. + Multiplier float64 // Exponent base. Default 2.0. + // ShouldRetry is an optional predicate that returns false for errors + // that should NOT be retried (e.g., validation errors, 4xx HTTP). + // If nil, all non-nil errors are retried. + ShouldRetry func(error) bool +} + +func (o RetryOpts) withDefaults() RetryOpts { + if o.MaxAttempts <= 0 { + o.MaxAttempts = 3 + } + if o.BaseDelay <= 0 { + o.BaseDelay = 500 * time.Millisecond + } + if o.MaxDelay <= 0 { + o.MaxDelay = 30 * time.Second + } + if o.Multiplier <= 0 { + o.Multiplier = 2.0 + } + return o +} + +// Retry executes fn up to opts.MaxAttempts times with exponential backoff +// and full jitter. It stops early if the context is cancelled or +// opts.ShouldRetry returns false. +func Retry(ctx context.Context, opts RetryOpts, fn func(ctx context.Context) error) error { + opts = opts.withDefaults() + + var lastErr error + for attempt := 0; attempt < opts.MaxAttempts; attempt++ { + if err := ctx.Err(); err != nil { + if lastErr != nil { + return lastErr + } + return err + } + + lastErr = fn(ctx) + if lastErr == nil { + return nil + } + + // Check if this error is retryable. + if opts.ShouldRetry != nil && !opts.ShouldRetry(lastErr) { + return lastErr + } + + // Don't sleep after the last attempt. + if attempt == opts.MaxAttempts-1 { + break + } + + delay := backoffDelay(attempt, opts.BaseDelay, opts.MaxDelay, opts.Multiplier) + timer := time.NewTimer(delay) + select { + case <-ctx.Done(): + timer.Stop() + return lastErr + case <-timer.C: + } + } + + return lastErr +} + +// backoffDelay computes exponential backoff with full jitter: +// delay = random(0, min(maxDelay, baseDelay * multiplier^attempt)) +func backoffDelay(attempt int, base, max time.Duration, mult float64) time.Duration { + exp := math.Pow(mult, float64(attempt)) + calculated := time.Duration(float64(base) * exp) + if calculated > max || calculated <= 0 { + calculated = max + } + // Full jitter: uniform random in [0, calculated]. 
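+	// e.g. base=100ms, mult=2.0, attempt=3: calculated = 100ms * 2^3 = 800ms, so the delay is drawn uniformly from [0, 800ms].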
+ return time.Duration(rand.Int64N(int64(calculated) + 1)) +} diff --git a/internal/platform/retry_test.go b/internal/platform/retry_test.go new file mode 100644 index 000000000..99ce1344a --- /dev/null +++ b/internal/platform/retry_test.go @@ -0,0 +1,119 @@ +package platform + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRetrySucceedsImmediately(t *testing.T) { + calls := 0 + err := Retry(context.Background(), RetryOpts{MaxAttempts: 3}, func(ctx context.Context) error { + calls++ + return nil + }) + require.NoError(t, err) + assert.Equal(t, 1, calls) +} + +func TestRetryRetriesOnFailure(t *testing.T) { + var calls atomic.Int32 + err := Retry(context.Background(), RetryOpts{ + MaxAttempts: 3, + BaseDelay: 10 * time.Millisecond, + MaxDelay: 50 * time.Millisecond, + }, func(ctx context.Context) error { + n := calls.Add(1) + if n < 3 { + return errors.New("transient") + } + return nil + }) + require.NoError(t, err) + assert.Equal(t, int32(3), calls.Load()) +} + +func TestRetryExhaustsAttempts(t *testing.T) { + calls := 0 + err := Retry(context.Background(), RetryOpts{ + MaxAttempts: 2, + BaseDelay: 10 * time.Millisecond, + }, func(ctx context.Context) error { + calls++ + return errors.New("permanent") + }) + require.Error(t, err) + assert.Equal(t, "permanent", err.Error()) + assert.Equal(t, 2, calls) +} + +func TestRetryRespectsContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + calls := 0 + err := Retry(ctx, RetryOpts{ + MaxAttempts: 10, + BaseDelay: 50 * time.Millisecond, + }, func(ctx context.Context) error { + calls++ + if calls == 2 { + cancel() + } + return errors.New("fail") + }) + require.Error(t, err) + assert.LessOrEqual(t, calls, 3) // might get 2 or 3 depending on timing +} + +func TestRetryShouldRetryPredicate(t *testing.T) { + permanent := errors.New("permanent error") + calls := 0 + err := Retry(context.Background(), RetryOpts{ + MaxAttempts: 5, + BaseDelay: 10 * time.Millisecond, + ShouldRetry: func(err error) bool { + return !errors.Is(err, permanent) + }, + }, func(ctx context.Context) error { + calls++ + return permanent + }) + require.ErrorIs(t, err, permanent) + assert.Equal(t, 1, calls, "should not retry non-retryable errors") +} + +func TestRetryDefaultOpts(t *testing.T) { + calls := 0 + err := Retry(context.Background(), RetryOpts{}, func(ctx context.Context) error { + calls++ + if calls < 3 { + return errors.New("fail") + } + return nil + }) + require.NoError(t, err) + assert.Equal(t, 3, calls) // default MaxAttempts is 3 +} + +func TestBackoffDelay(t *testing.T) { + base := 100 * time.Millisecond + max := 5 * time.Second + + // Attempt 0: jitter in [0, base] + for i := 0; i < 100; i++ { + d := backoffDelay(0, base, max, 2.0) + assert.GreaterOrEqual(t, d, time.Duration(0)) + assert.LessOrEqual(t, d, base) + } + + // Attempt 3: calculated = 100ms * 2^3 = 800ms + for i := 0; i < 100; i++ { + d := backoffDelay(3, base, max, 2.0) + assert.GreaterOrEqual(t, d, time.Duration(0)) + assert.LessOrEqual(t, d, 800*time.Millisecond) + } +} diff --git a/internal/repositories/noop/adapters.go b/internal/repositories/noop/adapters.go index a916014c8..479bb2829 100644 --- a/internal/repositories/noop/adapters.go +++ b/internal/repositories/noop/adapters.go @@ -40,7 +40,7 @@ func (r *NoopInstanceRepository) ListByVPC(ctx context.Context, vpcID uuid.UUID) } func (r *NoopInstanceRepository) Update(ctx context.Context, i 
*domain.Instance) error { return nil } -func (r *NoopInstanceRepository) Delete(ctx context.Context, id uuid.UUID) error { return nil } +func (r *NoopInstanceRepository) Delete(ctx context.Context, id uuid.UUID) error { return nil } // NoopVpcRepository type NoopVpcRepository struct{} @@ -110,8 +110,8 @@ func NewNoopComputeBackend() *NoopComputeBackend { func (b *NoopComputeBackend) LaunchInstanceWithOptions(ctx context.Context, opts ports.CreateInstanceOptions) (string, []string, error) { return uuid.New().String(), []string{}, nil } -func (b *NoopComputeBackend) StartInstance(ctx context.Context, id string) error { return nil } -func (b *NoopComputeBackend) StopInstance(ctx context.Context, id string) error { return nil } +func (b *NoopComputeBackend) StartInstance(ctx context.Context, id string) error { return nil } +func (b *NoopComputeBackend) StopInstance(ctx context.Context, id string) error { return nil } func (b *NoopComputeBackend) DeleteInstance(ctx context.Context, id string) error { return nil } func (b *NoopComputeBackend) GetInstanceLogs(ctx context.Context, id string) (io.ReadCloser, error) { return io.NopCloser(strings.NewReader("")), nil @@ -148,7 +148,7 @@ func (b *NoopComputeBackend) DetachVolume(ctx context.Context, id string, volume return nil } func (b *NoopComputeBackend) Ping(ctx context.Context) error { return nil } -func (b *NoopComputeBackend) Type() string { return "noop" } +func (b *NoopComputeBackend) Type() string { return "noop" } // NoopDNSService is a no-op DNS service. type NoopDNSService struct{} @@ -164,7 +164,9 @@ type NoopLogService struct{} func (s *NoopLogService) StreamLogs(ctx context.Context, instanceID string) (io.ReadCloser, error) { return io.NopCloser(strings.NewReader("")), nil } -func (s *NoopLogService) GetLogs(ctx context.Context, instanceID string) (string, error) { return "", nil } +func (s *NoopLogService) GetLogs(ctx context.Context, instanceID string) (string, error) { + return "", nil +} // NoopEventService is a no-op event service. type NoopEventService struct{} @@ -397,13 +399,45 @@ func (s *NoopLBService) ListTargets(ctx context.Context, lbID uuid.UUID) ([]*dom return []*domain.LBTarget{}, nil } -// NoopTaskQueue is a no-op task queue. +// NoopTaskQueue is a no-op task queue that implements DurableTaskQueue. type NoopTaskQueue struct{} func (q *NoopTaskQueue) Enqueue(ctx context.Context, queue string, payload interface{}) error { return nil } func (q *NoopTaskQueue) Dequeue(ctx context.Context, queue string) (string, error) { return "", nil } +func (q *NoopTaskQueue) EnsureGroup(ctx context.Context, queueName, groupName string) error { + return nil +} +func (q *NoopTaskQueue) Receive(ctx context.Context, queueName, groupName, consumerName string) (*ports.DurableMessage, error) { + return nil, nil +} +func (q *NoopTaskQueue) Ack(ctx context.Context, queueName, groupName, messageID string) error { + return nil +} +func (q *NoopTaskQueue) Nack(ctx context.Context, queueName, groupName, messageID string) error { + return nil +} +func (q *NoopTaskQueue) ReclaimStale(ctx context.Context, queueName, groupName, consumerName string, minIdleMs int64, count int64) ([]ports.DurableMessage, error) { + return nil, nil +} + +// NoopExecutionLedger is a no-op execution ledger that always grants ownership. 
+type NoopExecutionLedger struct{} + +func (l *NoopExecutionLedger) TryAcquire(ctx context.Context, jobKey string, staleThreshold time.Duration) (bool, error) { + return true, nil +} +func (l *NoopExecutionLedger) MarkComplete(ctx context.Context, jobKey string, result string) error { + return nil +} +func (l *NoopExecutionLedger) MarkFailed(ctx context.Context, jobKey string, reason string) error { + return nil +} + +func (l *NoopExecutionLedger) GetStatus(ctx context.Context, jobKey string) (string, string, time.Time, error) { + return "", "", time.Time{}, nil +} // --- New No-Ops (for benchmarks and system tests) --- diff --git a/internal/repositories/postgres/execution_ledger.go b/internal/repositories/postgres/execution_ledger.go new file mode 100644 index 000000000..24f3e401b --- /dev/null +++ b/internal/repositories/postgres/execution_ledger.go @@ -0,0 +1,143 @@ +// Package postgres provides PostgreSQL-backed repository implementations. +package postgres + +import ( + "context" + "fmt" + "time" + + "github.com/jackc/pgx/v5" +) + +// PgExecutionLedger implements ports.ExecutionLedger using the job_executions table. +type PgExecutionLedger struct { + db DB +} + +// NewExecutionLedger creates a new Postgres-backed execution ledger. +func NewExecutionLedger(db DB) *PgExecutionLedger { + return &PgExecutionLedger{db: db} +} + +// TryAcquire attempts to claim a job execution. It uses INSERT ... ON CONFLICT +// to atomically check whether the job was already processed: +// +// - If no row exists, inserts status='running' and returns true. +// - If a 'completed' row exists, returns false (already done). +// - If a 'running' row exists and is newer than staleThreshold, returns false +// (another worker is actively processing). +// - If a 'running' row exists but is older than staleThreshold, reclaims it +// by updating started_at and returns true (previous worker likely crashed). +// - If a 'failed' row exists, reclaims it (allows retry). +func (l *PgExecutionLedger) TryAcquire(ctx context.Context, jobKey string, staleThreshold time.Duration) (bool, error) { + // Step 1: Try to insert a new row. + var inserted bool + err := l.db.QueryRow(ctx, ` + INSERT INTO job_executions (job_key, status, started_at) + VALUES ($1, 'running', NOW()) + ON CONFLICT (job_key) DO NOTHING + RETURNING TRUE + `, jobKey).Scan(&inserted) + + if err == nil && inserted { + return true, nil // Successfully claimed a brand-new execution. + } + // pgx returns ErrNoRows when INSERT ... ON CONFLICT DO NOTHING matches zero rows + if err != nil && err != pgx.ErrNoRows { + return false, fmt.Errorf("execution ledger insert %s: %w", jobKey, err) + } + + // Row already exists. Check its status. + var status string + var startedAt time.Time + err = l.db.QueryRow(ctx, ` + SELECT status, started_at FROM job_executions WHERE job_key = $1 + `, jobKey).Scan(&status, &startedAt) + if err != nil { + return false, fmt.Errorf("execution ledger check %s: %w", jobKey, err) + } + + switch status { + case "completed": + // Already done — skip. + return false, nil + case "running": + // Check if the running entry is stale (crashed worker). + if time.Since(startedAt) < staleThreshold { + return false, nil // Another worker is still processing. + } + // Reclaim the stale entry. Use optimistic locking on started_at to + // avoid racing with another reclaimer. 
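+		// Only one of several racing reclaimers can match the old started_at value, so RowsAffected below reports whether this worker won the claim.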
+ tag, err := l.db.Exec(ctx, ` + UPDATE job_executions + SET started_at = NOW(), status = 'running' + WHERE job_key = $1 AND status = 'running' AND started_at = $2 + `, jobKey, startedAt) + if err != nil { + return false, fmt.Errorf("execution ledger reclaim %s: %w", jobKey, err) + } + return tag.RowsAffected() > 0, nil + case "failed": + // Retry a previously failed job. + tag, err := l.db.Exec(ctx, ` + UPDATE job_executions + SET started_at = NOW(), status = 'running', completed_at = NULL, result = NULL + WHERE job_key = $1 AND status = 'failed' + `, jobKey) + if err != nil { + return false, fmt.Errorf("execution ledger retry %s: %w", jobKey, err) + } + return tag.RowsAffected() > 0, nil + default: + return false, fmt.Errorf("execution ledger unknown status %q for %s", status, jobKey) + } +} + +// MarkComplete marks a job as successfully completed. +func (l *PgExecutionLedger) MarkComplete(ctx context.Context, jobKey string, result string) error { + tag, err := l.db.Exec(ctx, ` + UPDATE job_executions + SET status = 'completed', completed_at = NOW(), result = $2 + WHERE job_key = $1 AND status = 'running' + `, jobKey, result) + if err != nil { + return err + } + if tag.RowsAffected() == 0 { + return fmt.Errorf("execution ledger mark complete %s: no running row updated", jobKey) + } + return nil +} + +// MarkFailed marks a job as failed, allowing future retries. +func (l *PgExecutionLedger) MarkFailed(ctx context.Context, jobKey string, reason string) error { + tag, err := l.db.Exec(ctx, ` + UPDATE job_executions + SET status = 'failed', completed_at = NOW(), result = $2 + WHERE job_key = $1 AND status = 'running' + `, jobKey, reason) + if err != nil { + return err + } + if tag.RowsAffected() == 0 { + return fmt.Errorf("execution ledger mark failed %s: no running row updated", jobKey) + } + return nil +} + +// GetStatus returns the current status, result and start time of a job. +func (l *PgExecutionLedger) GetStatus(ctx context.Context, jobKey string) (status string, result string, startedAt time.Time, err error) { + var res pgx.Row + res = l.db.QueryRow(ctx, ` + SELECT status, COALESCE(result, ''), started_at FROM job_executions WHERE job_key = $1 + `, jobKey) + + err = res.Scan(&status, &result, &startedAt) + if err != nil { + if err == pgx.ErrNoRows { + return "", "", time.Time{}, nil + } + return "", "", time.Time{}, fmt.Errorf("execution ledger get status %s: %w", jobKey, err) + } + return status, result, startedAt, nil +} diff --git a/internal/repositories/postgres/leader_elector.go b/internal/repositories/postgres/leader_elector.go new file mode 100644 index 000000000..ac99bb0cb --- /dev/null +++ b/internal/repositories/postgres/leader_elector.go @@ -0,0 +1,188 @@ +// Package postgres provides PostgreSQL-backed repository implementations. +package postgres + +import ( + "context" + "fmt" + "hash/fnv" + "log/slog" + "sync" + "time" +) + +const ( + // leaderRenewInterval is how often the leader renews its lock heartbeat. + leaderRenewInterval = 5 * time.Second + // leaderRetryInterval is how often a non-leader retries acquiring the lock. + leaderRetryInterval = 10 * time.Second +) + +// PgLeaderElector implements ports.LeaderElector using Postgres session-level advisory locks. +// Each leader key is hashed to a 64-bit integer used as the advisory lock ID. +// The lock is session-scoped: held as long as the DB connection is alive. 
+type PgLeaderElector struct { + db DB + logger *slog.Logger + mu sync.Mutex + held map[string]bool // tracks which keys this instance holds +} + +// NewPgLeaderElector creates a leader elector backed by Postgres advisory locks. +func NewPgLeaderElector(db DB, logger *slog.Logger) *PgLeaderElector { + return &PgLeaderElector{ + db: db, + logger: logger, + held: make(map[string]bool), + } +} + +// keyToLockID deterministically maps a string key to a 64-bit advisory lock ID. +func keyToLockID(key string) int64 { + h := fnv.New64a() + _, _ = h.Write([]byte(key)) + // Ensure positive value for pg advisory lock (avoids negative lock IDs). + return int64(h.Sum64() & 0x7FFFFFFFFFFFFFFF) +} + +// Acquire attempts to acquire the advisory lock for the given key. +// Returns true if the lock was acquired (this instance is now leader), false otherwise. +// Uses pg_try_advisory_lock which is non-blocking. +func (e *PgLeaderElector) Acquire(ctx context.Context, key string) (bool, error) { + lockID := keyToLockID(key) + var acquired bool + err := e.db.QueryRow(ctx, "SELECT pg_try_advisory_lock($1)", lockID).Scan(&acquired) + if err != nil { + return false, fmt.Errorf("leader election acquire failed for key %q: %w", key, err) + } + + e.mu.Lock() + if acquired { + e.held[key] = true + } + e.mu.Unlock() + + return acquired, nil +} + +// Release explicitly releases the advisory lock for the given key. +func (e *PgLeaderElector) Release(ctx context.Context, key string) error { + lockID := keyToLockID(key) + _, err := e.db.Exec(ctx, "SELECT pg_advisory_unlock($1)", lockID) + if err != nil { + return fmt.Errorf("leader election release failed for key %q: %w", key, err) + } + + e.mu.Lock() + delete(e.held, key) + e.mu.Unlock() + + return nil +} + +// RunAsLeader blocks until leadership is acquired, then executes fn. +// If the parent context is cancelled, it stops trying and returns. +// When fn returns (or panics), leadership is released. 
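+//
+// Illustrative usage (runCronWorker is a hypothetical worker loop; the
+// "singleton:cron" key matches the keys exercised in the tests):
+//
+//	go func() {
+//		_ = elector.RunAsLeader(ctx, "singleton:cron", func(ctx context.Context) error {
+//			return runCronWorker(ctx)
+//		})
+//	}()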
+// +// The fn receives a child context that is cancelled if: +// - the parent context is cancelled +// - the periodic heartbeat detects the lock was lost +func (e *PgLeaderElector) RunAsLeader(ctx context.Context, key string, fn func(ctx context.Context) error) error { + // Phase 1: Acquire leadership (retry loop) + for { + if ctx.Err() != nil { + return ctx.Err() + } + + acquired, err := e.Acquire(ctx, key) + if err != nil { + e.logger.Warn("leader election attempt failed, retrying", + "key", key, "error", err, "retry_in", leaderRetryInterval) + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(leaderRetryInterval): + continue + } + } + + if acquired { + e.logger.Info("acquired leadership", "key", key) + break + } + + e.logger.Debug("leadership not acquired, another instance is leader", "key", key) + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(leaderRetryInterval): + } + } + + // Phase 2: Run fn with a context that gets cancelled if leadership is lost + fnCtx, fnCancel := context.WithCancel(ctx) + defer fnCancel() + defer func() { + releaseCtx, releaseCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer releaseCancel() + if err := e.Release(releaseCtx, key); err != nil { + e.logger.Error("failed to release leadership", "key", key, "error", err) + } + }() + + // Start heartbeat goroutine to verify we still hold the lock + heartbeatDone := make(chan struct{}) + go func() { + defer close(heartbeatDone) + e.heartbeat(fnCtx, key, fnCancel) + }() + + // Run the actual worker function + err := fn(fnCtx) + + // Wait for heartbeat to stop + fnCancel() + <-heartbeatDone + + return err +} + +// heartbeat periodically checks that we still hold the advisory lock. +// If the lock is lost (e.g., DB connection reset), it cancels the fn context. +func (e *PgLeaderElector) heartbeat(ctx context.Context, key string, cancel context.CancelFunc) { + ticker := time.NewTicker(leaderRenewInterval) + defer ticker.Stop() + + lockID := keyToLockID(key) + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + // Check if we still hold the lock by trying to acquire it again. + // pg_try_advisory_lock is re-entrant: if we already hold it, it returns true + // and increments the lock count. We immediately unlock the extra acquisition. 
+ var stillHeld bool + err := e.db.QueryRow(ctx, "SELECT pg_try_advisory_lock($1)", lockID).Scan(&stillHeld) + if err != nil { + e.logger.Error("heartbeat check failed, assuming leadership lost", "key", key, "error", err) + cancel() + return + } + if stillHeld { + // We re-acquired (re-entrant), so unlock the extra lock count + if _, unlockErr := e.db.Exec(ctx, "SELECT pg_advisory_unlock($1)", lockID); unlockErr != nil { + e.logger.Error("failed to release re-entrant heartbeat lock", + "key", key, "error", unlockErr) + cancel() + return + } + } else { + // We lost the lock + e.logger.Error("leadership lost", "key", key) + cancel() + return + } + } + } +} diff --git a/internal/repositories/postgres/leader_elector_test.go b/internal/repositories/postgres/leader_elector_test.go new file mode 100644 index 000000000..91c7d149c --- /dev/null +++ b/internal/repositories/postgres/leader_elector_test.go @@ -0,0 +1,53 @@ +package postgres + +import ( + "testing" +) + +func TestKeyToLockIDDeterministic(t *testing.T) { + key := "singleton:lb" + id1 := keyToLockID(key) + id2 := keyToLockID(key) + if id1 != id2 { + t.Fatalf("expected same lock ID for same key, got %d and %d", id1, id2) + } +} + +func TestKeyToLockIDUnique(t *testing.T) { + keys := []string{ + "singleton:lb", + "singleton:cron", + "singleton:autoscaling", + "singleton:container", + "singleton:healing", + "singleton:db-failover", + "singleton:cluster-reconciler", + "singleton:replica-monitor", + "singleton:lifecycle", + "singleton:log", + "singleton:accounting", + } + + seen := make(map[int64]string) + for _, k := range keys { + id := keyToLockID(k) + if id <= 0 { + t.Fatalf("expected positive lock ID for key %q, got %d", k, id) + } + if existing, ok := seen[id]; ok { + t.Fatalf("lock ID collision: key %q and %q both map to %d", k, existing, id) + } + seen[id] = k + } +} + +func TestKeyToLockIDPositive(t *testing.T) { + // Ensure the masking produces positive values + testKeys := []string{"a", "b", "test", "singleton:anything", ""} + for _, k := range testKeys { + id := keyToLockID(k) + if id < 0 { + t.Fatalf("expected non-negative lock ID for key %q, got %d", k, id) + } + } +} diff --git a/internal/repositories/postgres/migrations/100_create_job_executions.down.sql b/internal/repositories/postgres/migrations/100_create_job_executions.down.sql new file mode 100644 index 000000000..78cffd523 --- /dev/null +++ b/internal/repositories/postgres/migrations/100_create_job_executions.down.sql @@ -0,0 +1,2 @@ +-- +goose Down +DROP TABLE IF EXISTS job_executions; diff --git a/internal/repositories/postgres/migrations/100_create_job_executions.up.sql b/internal/repositories/postgres/migrations/100_create_job_executions.up.sql new file mode 100644 index 000000000..d06135193 --- /dev/null +++ b/internal/repositories/postgres/migrations/100_create_job_executions.up.sql @@ -0,0 +1,14 @@ +-- +goose Up +CREATE TABLE IF NOT EXISTS job_executions ( + job_key TEXT PRIMARY KEY, + status TEXT NOT NULL DEFAULT 'running', -- running | completed | failed + started_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + completed_at TIMESTAMPTZ, + result TEXT, + -- Allow stale locks to be reclaimed: if a worker crashes while status='running', + -- another worker can take over after started_at + timeout has elapsed. + -- The timeout is enforced in application code, not in the schema. 
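+ -- For illustration, a claimer could reclaim a stale entry with something like
+ -- (the actual statement lives in the ExecutionLedger implementation, not here):
+ --   UPDATE job_executions SET status = 'running', started_at = NOW()
+ --   WHERE job_key = $1 AND status = 'running' AND started_at < NOW() - INTERVAL '15 minutes';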
+ CONSTRAINT job_executions_status_check CHECK (status IN ('running', 'completed', 'failed')) +); + +CREATE INDEX IF NOT EXISTS idx_job_executions_status ON job_executions (status) WHERE status = 'running'; diff --git a/internal/repositories/redis/durable_task_queue.go b/internal/repositories/redis/durable_task_queue.go new file mode 100644 index 000000000..0aae2c066 --- /dev/null +++ b/internal/repositories/redis/durable_task_queue.go @@ -0,0 +1,244 @@ +// Package redis implements Redis-based repositories and data structures. +package redis + +import ( + "context" + "encoding/json" + stdlib_errors "errors" + "fmt" + "time" + + "github.com/poyrazk/thecloud/internal/core/ports" + "github.com/redis/go-redis/v9" +) + +// durableTaskQueue implements ports.DurableTaskQueue using Redis Streams +// and consumer groups for at-least-once delivery semantics. +type durableTaskQueue struct { + client *redis.Client + blockTime time.Duration // how long Receive blocks waiting for new messages + maxRetries int64 // max delivery attempts before a message is dead-lettered + dlqSuffix string // suffix appended to queue name for the dead-letter stream +} + +// DurableQueueOption configures a durableTaskQueue. +type DurableQueueOption func(*durableTaskQueue) + +// WithBlockTime sets the Receive block duration (default 5s). +func WithBlockTime(d time.Duration) DurableQueueOption { + return func(q *durableTaskQueue) { q.blockTime = d } +} + +// WithMaxRetries sets the max delivery count before dead-lettering (default 5). +func WithMaxRetries(n int64) DurableQueueOption { + return func(q *durableTaskQueue) { q.maxRetries = n } +} + +// WithDLQSuffix sets the dead-letter queue suffix (default ":dlq"). +func WithDLQSuffix(s string) DurableQueueOption { + return func(q *durableTaskQueue) { q.dlqSuffix = s } +} + +// NewDurableTaskQueue creates a Redis Streams–backed durable task queue. +func NewDurableTaskQueue(client *redis.Client, opts ...DurableQueueOption) *durableTaskQueue { + q := &durableTaskQueue{ + client: client, + blockTime: 5 * time.Second, + maxRetries: 5, + dlqSuffix: ":dlq", + } + for _, o := range opts { + o(q) + } + return q +} + +// ---------- TaskQueue (backward-compatible) ---------- + +func (q *durableTaskQueue) Enqueue(ctx context.Context, queueName string, payload interface{}) error { + data, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("durable enqueue marshal: %w", err) + } + return q.client.XAdd(ctx, &redis.XAddArgs{ + Stream: queueName, + Values: map[string]interface{}{"payload": string(data)}, + }).Err() +} + +func (q *durableTaskQueue) Dequeue(ctx context.Context, queueName string) (string, error) { + // Legacy fallback: reads from the stream without consumer groups (XREAD). + // New consumers should use Receive instead. + res, err := q.client.XRead(ctx, &redis.XReadArgs{ + Streams: []string{queueName, "0-0"}, + Count: 1, + Block: q.blockTime, + }).Result() + if err != nil { + if stdlib_errors.Is(err, redis.Nil) { + return "", nil + } + return "", err + } + if len(res) == 0 || len(res[0].Messages) == 0 { + return "", nil + } + msg := res[0].Messages[0] + // Auto-delete since legacy callers don't ack. 
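+ // This keeps the legacy path at-most-once: once XDel succeeds the entry is
+ // gone, even if the caller later fails while processing it.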
+ deleted, delErr := q.client.XDel(ctx, queueName, msg.ID).Result() + if delErr != nil { + return "", fmt.Errorf("durable dequeue xdel %s/%s: %w", queueName, msg.ID, delErr) + } + if deleted == 0 { + return "", fmt.Errorf("durable dequeue xdel %s/%s: no message deleted", queueName, msg.ID) + } + payload, _ := msg.Values["payload"].(string) + return payload, nil +} + +// ---------- DurableTaskQueue ---------- + +func (q *durableTaskQueue) EnsureGroup(ctx context.Context, queueName, groupName string) error { + err := q.client.XGroupCreateMkStream(ctx, queueName, groupName, "0").Err() + if err != nil { + // "BUSYGROUP Consumer Group name already exists" is harmless at startup. + if isGroupExistsErr(err) { + return nil + } + return fmt.Errorf("ensure group %s/%s: %w", queueName, groupName, err) + } + return nil +} + +func (q *durableTaskQueue) Receive(ctx context.Context, queueName, groupName, consumerName string) (*ports.DurableMessage, error) { + res, err := q.client.XReadGroup(ctx, &redis.XReadGroupArgs{ + Group: groupName, + Consumer: consumerName, + Streams: []string{queueName, ">"}, + Count: 1, + Block: q.blockTime, + }).Result() + if err != nil { + if stdlib_errors.Is(err, redis.Nil) { + return nil, nil + } + return nil, fmt.Errorf("receive from %s/%s: %w", queueName, groupName, err) + } + if len(res) == 0 || len(res[0].Messages) == 0 { + return nil, nil + } + + xmsg := res[0].Messages[0] + payload, _ := xmsg.Values["payload"].(string) + return &ports.DurableMessage{ + ID: xmsg.ID, + Payload: payload, + Queue: queueName, + }, nil +} + +func (q *durableTaskQueue) Ack(ctx context.Context, queueName, groupName, messageID string) error { + acked, err := q.client.XAck(ctx, queueName, groupName, messageID).Result() + if err != nil { + return fmt.Errorf("ack %s/%s/%s: %w", queueName, groupName, messageID, err) + } + if acked == 0 { + return fmt.Errorf("ack %s/%s/%s: message not pending", queueName, groupName, messageID) + } + + deleted, delErr := q.client.XDel(ctx, queueName, messageID).Result() + if delErr != nil { + return fmt.Errorf("ack xdel %s/%s: %w", queueName, messageID, delErr) + } + if deleted == 0 { + return fmt.Errorf("ack xdel %s/%s: no message deleted", queueName, messageID) + } + + return nil +} + +func (q *durableTaskQueue) Nack(ctx context.Context, queueName, groupName, messageID string) error { + // In Redis Streams, un-acknowledged messages remain in the PEL (Pending + // Entries List) automatically. Nack is a no-op — the message will be + // reclaimed by ReclaimStale after the idle timeout. + // + // Future enhancement: we could XCLAIM the message back to a retry consumer + // immediately, but the idle-reclaim approach is simpler and sufficient. + return nil +} + +func (q *durableTaskQueue) ReclaimStale(ctx context.Context, queueName, groupName, consumerName string, minIdleMs int64, count int64) ([]ports.DurableMessage, error) { + // XAUTOCLAIM atomically claims messages idle > minIdleMs and returns them. + msgs, _, err := q.client.XAutoClaim(ctx, &redis.XAutoClaimArgs{ + Stream: queueName, + Group: groupName, + Consumer: consumerName, + MinIdle: time.Duration(minIdleMs) * time.Millisecond, + Start: "0-0", + Count: count, + }).Result() + if err != nil { + return nil, fmt.Errorf("reclaim stale from %s/%s: %w", queueName, groupName, err) + } + + out := make([]ports.DurableMessage, 0, len(msgs)) + for _, xmsg := range msgs { + payload, _ := xmsg.Values["payload"].(string) + + // Dead-letter messages that exceeded max retries. 
+ // go-redis's XMessage does not expose a delivery count, so read the retry
+ // count for this entry from the group's pending entries list.
+ var retryCount int64
+ pending, pendErr := q.client.XPendingExt(ctx, &redis.XPendingExtArgs{
+ Stream: queueName,
+ Group: groupName,
+ Start: xmsg.ID,
+ End: xmsg.ID,
+ Count: 1,
+ }).Result()
+ if pendErr != nil {
+ return nil, fmt.Errorf("reclaim pending lookup %s/%s/%s: %w", queueName, groupName, xmsg.ID, pendErr)
+ }
+ if len(pending) > 0 {
+ retryCount = pending[0].RetryCount
+ }
+ if retryCount > q.maxRetries {
+ if dlqErr := q.deadLetter(ctx, queueName, groupName, xmsg); dlqErr != nil {
+ return nil, fmt.Errorf("dead-letter %s/%s/%s: %w", queueName, groupName, xmsg.ID, dlqErr)
+ }
+ continue
+ }
+
+ out = append(out, ports.DurableMessage{
+ ID: xmsg.ID,
+ Payload: payload,
+ Queue: queueName,
+ })
+ }
+ return out, nil
+}
+
+// deadLetter moves a message to the dead-letter stream and acks the original.
+func (q *durableTaskQueue) deadLetter(ctx context.Context, queueName, groupName string, msg redis.XMessage) error {
+ dlq := queueName + q.dlqSuffix
+ payload, _ := msg.Values["payload"].(string)
+ pipe := q.client.Pipeline()
+ pipe.XAdd(ctx, &redis.XAddArgs{
+ Stream: dlq,
+ Values: map[string]interface{}{
+ "payload": payload,
+ "original_id": msg.ID,
+ "queue": queueName,
+ },
+ })
+ pipe.XAck(ctx, queueName, groupName, msg.ID)
+ pipe.XDel(ctx, queueName, msg.ID)
+ _, err := pipe.Exec(ctx)
+ return err
+}
+
+// isGroupExistsErr returns true when the error indicates the consumer group
+// already exists (Redis returns BUSYGROUP).
+func isGroupExistsErr(err error) bool {
+ if err == nil {
+ return false
+ }
+ return containsBusyGroup(err.Error())
+}
+
+// containsBusyGroup reports whether the error text mentions BUSYGROUP.
+func containsBusyGroup(s string) bool {
+ return containsSubstring(s, "BUSYGROUP")
+}
+
+// containsSubstring is a tiny strings.Contains stand-in that avoids an extra import.
+func containsSubstring(s, sub string) bool {
+ for i := 0; i+len(sub) <= len(s); i++ {
+ if s[i:i+len(sub)] == sub {
+ return true
+ }
+ }
+ return false
+}
diff --git a/internal/repositories/redis/durable_task_queue_test.go b/internal/repositories/redis/durable_task_queue_test.go
new file mode 100644
index 000000000..11b101838
--- /dev/null
+++ b/internal/repositories/redis/durable_task_queue_test.go
@@ -0,0 +1,245 @@
+package redis
+
+import (
+ "context"
+ "encoding/json"
+ "testing"
+ "time"
+
+ "github.com/alicebob/miniredis/v2"
+ "github.com/redis/go-redis/v9"
+)
+
+func newTestDurableQueue(t *testing.T) (*durableTaskQueue, *miniredis.Miniredis) {
+ t.Helper()
+ s, err := miniredis.Run()
+ if err != nil {
+ t.Fatalf("failed to start miniredis: %v", err)
+ }
+ client := redis.NewClient(&redis.Options{Addr: s.Addr()})
+ q := NewDurableTaskQueue(client, WithBlockTime(100*time.Millisecond), WithMaxRetries(3))
+ return q, s
+}
+
+func TestDurableEnqueue(t *testing.T) {
+ q, s := newTestDurableQueue(t)
+ defer s.Close()
+
+ ctx := context.Background()
+ payload := map[string]string{"instance_id": "abc-123"}
+
+ if err := q.Enqueue(ctx, "test_stream", payload); err != nil {
+ t.Fatalf("Enqueue failed: %v", err)
+ }
+
+ // Verify stream has one entry
+ entries, err := s.Stream("test_stream")
+ if err != nil {
+ t.Fatalf("Stream read failed: %v", err)
+ }
+ if len(entries) != 1 {
+ t.Fatalf("expected 1 stream entry, got %d", len(entries))
+ }
+}
+
+func TestDurableEnsureGroupIdempotent(t *testing.T) {
+ q, s := newTestDurableQueue(t)
+ defer s.Close()
+
+ ctx := context.Background()
+ // Should not error even if stream doesn't exist yet (MkStream).
+ if err := q.EnsureGroup(ctx, "test_stream", "workers"); err != nil {
+ t.Fatalf("first EnsureGroup failed: %v", err)
+ }
+ // Calling again should be idempotent (BUSYGROUP).
+ if err := q.EnsureGroup(ctx, "test_stream", "workers"); err != nil { + t.Fatalf("second EnsureGroup failed: %v", err) + } +} + +func TestDurableReceiveAndAck(t *testing.T) { + q, s := newTestDurableQueue(t) + defer s.Close() + + ctx := context.Background() + queue := "provision_queue" + group := "workers" + consumer := "worker-1" + + // Setup + if err := q.EnsureGroup(ctx, queue, group); err != nil { + t.Fatalf("EnsureGroup: %v", err) + } + + // Enqueue a job + job := map[string]string{"instance_id": "inst-001"} + if err := q.Enqueue(ctx, queue, job); err != nil { + t.Fatalf("Enqueue: %v", err) + } + + // Receive it + msg, err := q.Receive(ctx, queue, group, consumer) + if err != nil { + t.Fatalf("Receive: %v", err) + } + if msg == nil { + t.Fatal("expected message, got nil") + } + if msg.Queue != queue { + t.Fatalf("expected queue %q, got %q", queue, msg.Queue) + } + + // Verify payload round-trips + var got map[string]string + if err := json.Unmarshal([]byte(msg.Payload), &got); err != nil { + t.Fatalf("unmarshal payload: %v", err) + } + if got["instance_id"] != "inst-001" { + t.Fatalf("expected instance_id inst-001, got %s", got["instance_id"]) + } + + // Ack it + if err := q.Ack(ctx, queue, group, msg.ID); err != nil { + t.Fatalf("Ack: %v", err) + } + + // Receive again — should be empty + msg2, err := q.Receive(ctx, queue, group, consumer) + if err != nil { + t.Fatalf("second Receive: %v", err) + } + if msg2 != nil { + t.Fatalf("expected nil after ack, got %+v", msg2) + } +} + +func TestDurableReceiveEmptyReturnsNil(t *testing.T) { + q, s := newTestDurableQueue(t) + defer s.Close() + + ctx := context.Background() + queue := "empty_stream" + group := "workers" + + if err := q.EnsureGroup(ctx, queue, group); err != nil { + t.Fatalf("EnsureGroup: %v", err) + } + + msg, err := q.Receive(ctx, queue, group, "worker-1") + if err != nil { + t.Fatalf("Receive: %v", err) + } + if msg != nil { + t.Fatalf("expected nil message from empty stream, got %+v", msg) + } +} + +func TestDurableMultipleConsumersGetDifferentMessages(t *testing.T) { + q, s := newTestDurableQueue(t) + defer s.Close() + + ctx := context.Background() + queue := "multi_consumer" + group := "workers" + + if err := q.EnsureGroup(ctx, queue, group); err != nil { + t.Fatalf("EnsureGroup: %v", err) + } + + // Enqueue two messages + if err := q.Enqueue(ctx, queue, map[string]string{"id": "1"}); err != nil { + t.Fatalf("Enqueue 1: %v", err) + } + if err := q.Enqueue(ctx, queue, map[string]string{"id": "2"}); err != nil { + t.Fatalf("Enqueue 2: %v", err) + } + + // Two consumers each get one + msg1, err := q.Receive(ctx, queue, group, "worker-1") + if err != nil || msg1 == nil { + t.Fatalf("worker-1 Receive: msg=%v err=%v", msg1, err) + } + msg2, err := q.Receive(ctx, queue, group, "worker-2") + if err != nil || msg2 == nil { + t.Fatalf("worker-2 Receive: msg=%v err=%v", msg2, err) + } + + if msg1.ID == msg2.ID { + t.Fatalf("both consumers got the same message ID: %s", msg1.ID) + } +} + +func TestDurableNackKeepsMessagePending(t *testing.T) { + q, s := newTestDurableQueue(t) + defer s.Close() + + ctx := context.Background() + queue := "nack_test" + group := "workers" + consumer := "worker-1" + + if err := q.EnsureGroup(ctx, queue, group); err != nil { + t.Fatalf("EnsureGroup: %v", err) + } + if err := q.Enqueue(ctx, queue, map[string]string{"id": "1"}); err != nil { + t.Fatalf("Enqueue: %v", err) + } + + msg, err := q.Receive(ctx, queue, group, consumer) + if err != nil || msg == nil { + t.Fatalf("Receive: msg=%v err=%v", msg, err) + 
} + + // Nack (no-op in Redis Streams — message stays in PEL) + if err := q.Nack(ctx, queue, group, msg.ID); err != nil { + t.Fatalf("Nack: %v", err) + } + + // The message should still be pending (not acked). + // Verify via XPending. + pending, err := q.client.XPending(ctx, queue, group).Result() + if err != nil { + t.Fatalf("XPending: %v", err) + } + if pending.Count != 1 { + t.Fatalf("expected 1 pending message, got %d", pending.Count) + } +} + +func TestDurableDeadLetterOnDequeue(t *testing.T) { + // This tests the legacy Dequeue path for backward compatibility. + q, s := newTestDurableQueue(t) + defer s.Close() + + ctx := context.Background() + queue := "legacy_dequeue" + + if err := q.Enqueue(ctx, queue, map[string]string{"legacy": "true"}); err != nil { + t.Fatalf("Enqueue: %v", err) + } + + msg, err := q.Dequeue(ctx, queue) + if err != nil { + t.Fatalf("Dequeue: %v", err) + } + if msg == "" { + t.Fatal("expected non-empty legacy dequeue result") + } + + var got map[string]string + if err := json.Unmarshal([]byte(msg), &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if got["legacy"] != "true" { + t.Fatalf("expected legacy=true, got %s", got["legacy"]) + } + + // Stream should be empty after legacy Dequeue (auto-deleted) + entries, err := s.Stream(queue) + if err != nil { + t.Fatalf("Stream: %v", err) + } + if len(entries) != 0 { + t.Fatalf("expected empty stream after legacy dequeue, got %d entries", len(entries)) + } +} diff --git a/internal/workers/cluster_worker.go b/internal/workers/cluster_worker.go index 591559dd5..cb88896f3 100644 --- a/internal/workers/cluster_worker.go +++ b/internal/workers/cluster_worker.go @@ -4,7 +4,9 @@ package workers import ( "context" "encoding/json" + "fmt" "log/slog" + "os" "sync" "time" @@ -13,34 +15,61 @@ import ( "github.com/poyrazk/thecloud/internal/core/ports" ) +const ( + clusterQueue = "k8s_jobs" + clusterGroup = "cluster_workers" + clusterMaxWorkers = 10 + clusterReclaimMs = 5 * 60 * 1000 // 5 minutes + clusterReclaimN = 10 + clusterStaleThreshold = 15 * time.Minute + clusterReceiveBackoff = 1 * time.Second +) + // ClusterWorker handles background tasks for Kubernetes cluster lifecycle management. type ClusterWorker struct { - repo ports.ClusterRepository - provisioner ports.ClusterProvisioner - taskQueue ports.TaskQueue - logger *slog.Logger + repo ports.ClusterRepository + provisioner ports.ClusterProvisioner + taskQueue ports.DurableTaskQueue + ledger ports.ExecutionLedger + logger *slog.Logger + consumerName string } // NewClusterWorker creates a new ClusterWorker. 
-func NewClusterWorker(repo ports.ClusterRepository, provisioner ports.ClusterProvisioner, taskQueue ports.TaskQueue, logger *slog.Logger) *ClusterWorker { +func NewClusterWorker(repo ports.ClusterRepository, provisioner ports.ClusterProvisioner, taskQueue ports.DurableTaskQueue, ledger ports.ExecutionLedger, logger *slog.Logger) *ClusterWorker { + hostname, err := os.Hostname() + if err != nil { + logger.Warn("failed to get hostname, using fallback", "error", err) + hostname = "cluster-worker" + } + if hostname == "" { + hostname = "cluster-worker" + } return &ClusterWorker{ - repo: repo, - provisioner: provisioner, - taskQueue: taskQueue, - logger: logger, + repo: repo, + provisioner: provisioner, + taskQueue: taskQueue, + ledger: ledger, + logger: logger, + consumerName: hostname, } } -const ( - queuePollBackoff = 1 * time.Second - maxConcurrentClusts = 10 -) - func (w *ClusterWorker) Run(ctx context.Context, wg *sync.WaitGroup) { defer wg.Done() - w.logger.Info("starting cluster worker", "concurrency", maxConcurrentClusts) + w.logger.Info("starting cluster worker", + "consumer", w.consumerName, + "concurrency", clusterMaxWorkers, + ) + + if err := w.taskQueue.EnsureGroup(ctx, clusterQueue, clusterGroup); err != nil { + w.logger.Error("failed to ensure cluster consumer group", "error", err) + return + } - sem := make(chan struct{}, maxConcurrentClusts) + sem := make(chan struct{}, clusterMaxWorkers) + + go w.reclaimLoop(ctx, sem) for { select { @@ -48,59 +77,124 @@ func (w *ClusterWorker) Run(ctx context.Context, wg *sync.WaitGroup) { w.logger.Info("stopping cluster worker") return default: - msg, err := w.taskQueue.Dequeue(ctx, "k8s_jobs") + msg, err := w.taskQueue.Receive(ctx, clusterQueue, clusterGroup, w.consumerName) if err != nil { - w.logger.Error("failed to dequeue cluster job", "error", err) - time.Sleep(queuePollBackoff) + w.logger.Error("failed to receive cluster job", "error", err) + time.Sleep(clusterReceiveBackoff) continue } - if msg == "" { - time.Sleep(queuePollBackoff) + if msg == nil { continue } var job domain.ClusterJob - if err := json.Unmarshal([]byte(msg), &job); err != nil { - w.logger.Error("failed to unmarshal cluster job", "error", err) + if err := json.Unmarshal([]byte(msg.Payload), &job); err != nil { + w.logger.Error("failed to unmarshal cluster job", + "error", err, "msg_id", msg.ID) + w.ackWithLog(ctx, msg.ID, "cluster poison message") continue } - w.logger.Info("processing cluster job", "cluster_id", job.ClusterID, "type", job.Type) + w.logger.Info("processing cluster job", + "cluster_id", job.ClusterID, + "type", job.Type, + "msg_id", msg.ID, + ) sem <- struct{}{} - go func() { + go func(m *ports.DurableMessage, j domain.ClusterJob) { defer func() { <-sem }() - w.processJob(job) - }() + w.processJob(ctx, m, j) + }(msg, job) } } } -func (w *ClusterWorker) processJob(job domain.ClusterJob) { - // Root context for background task - ctx := appcontext.WithUserID(context.Background(), job.UserID) +func (w *ClusterWorker) processJob(workerCtx context.Context, msg *ports.DurableMessage, job domain.ClusterJob) { + jobKey := fmt.Sprintf("cluster:%s:%s", job.Type, job.ClusterID) + + // Idempotency check. 
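+ // TryAcquire is expected to return false when this jobKey has already been
+ // completed, or is still marked running (and not yet stale) by another worker.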
+ if w.ledger != nil { + acquired, err := w.ledger.TryAcquire(workerCtx, jobKey, clusterStaleThreshold) + if err != nil { + w.logger.Error("execution ledger error", + "cluster_id", job.ClusterID, "msg_id", msg.ID, "error", err) + w.nackWithLog(workerCtx, msg.ID, "ledger try_acquire failed") + return + } + if !acquired { + w.logger.Info("skipping duplicate cluster job", + "cluster_id", job.ClusterID, "type", job.Type, "msg_id", msg.ID) + w.ackWithLog(workerCtx, msg.ID, "duplicate cluster job") + return + } + } + + ctx := appcontext.WithUserID(workerCtx, job.UserID) cluster, err := w.repo.GetByID(ctx, job.ClusterID) if err != nil { - w.logger.Error("failed to fetch cluster for job", "cluster_id", job.ClusterID, "error", err) + w.logger.Error("failed to fetch cluster for job", + "cluster_id", job.ClusterID, "msg_id", msg.ID, "error", err) + if w.ledger != nil { + if ledgerErr := w.ledger.MarkFailed(workerCtx, jobKey, err.Error()); ledgerErr != nil { + w.logger.Warn("failed to mark cluster job failed in ledger", + "cluster_id", job.ClusterID, "msg_id", msg.ID, "error", ledgerErr) + } + } + w.nackWithLog(workerCtx, msg.ID, "cluster fetch failed") return } if cluster == nil { - w.logger.Error("cluster not found for job", "cluster_id", job.ClusterID) + w.logger.Error("cluster not found for job", + "cluster_id", job.ClusterID, "msg_id", msg.ID) + // Ack — cluster was deleted, nothing to do. + if w.ledger != nil { + if ledgerErr := w.ledger.MarkComplete(workerCtx, jobKey, "cluster_not_found"); ledgerErr != nil { + w.logger.Warn("failed to mark cluster job complete in ledger", + "cluster_id", job.ClusterID, "msg_id", msg.ID, "error", ledgerErr) + } + } + w.ackWithLog(workerCtx, msg.ID, "cluster not found") return } + var processErr error switch job.Type { case domain.ClusterJobProvision: - w.handleProvision(ctx, cluster) + processErr = w.handleProvision(ctx, cluster) case domain.ClusterJobDeprovision: - w.handleDeprovision(ctx, cluster) + processErr = w.handleDeprovision(ctx, cluster) case domain.ClusterJobUpgrade: - w.handleUpgrade(ctx, cluster, job.Version) + processErr = w.handleUpgrade(ctx, cluster, job.Version) + default: + processErr = fmt.Errorf("unsupported cluster job type %q for cluster %s", job.Type, job.ClusterID) + } + + if processErr != nil { + w.logger.Error("cluster job failed", + "cluster_id", job.ClusterID, "type", job.Type, + "msg_id", msg.ID, "error", processErr) + if w.ledger != nil { + if ledgerErr := w.ledger.MarkFailed(workerCtx, jobKey, processErr.Error()); ledgerErr != nil { + w.logger.Warn("failed to mark cluster job failed in ledger", + "cluster_id", job.ClusterID, "msg_id", msg.ID, "error", ledgerErr) + } + } + w.nackWithLog(workerCtx, msg.ID, "cluster job processing failed") + return } + + if w.ledger != nil { + if ledgerErr := w.ledger.MarkComplete(workerCtx, jobKey, "ok"); ledgerErr != nil { + w.logger.Warn("failed to mark cluster job complete in ledger", + "cluster_id", job.ClusterID, "msg_id", msg.ID, "error", ledgerErr) + } + } + w.ackWithLog(workerCtx, msg.ID, "cluster job success") } -func (w *ClusterWorker) handleProvision(ctx context.Context, cluster *domain.Cluster) { +func (w *ClusterWorker) handleProvision(ctx context.Context, cluster *domain.Cluster) error { cluster.Status = domain.ClusterStatusProvisioning cluster.UpdatedAt = time.Now() _ = w.repo.Update(ctx, cluster) @@ -108,50 +202,107 @@ func (w *ClusterWorker) handleProvision(ctx context.Context, cluster *domain.Clu if err := w.provisioner.Provision(ctx, cluster); err != nil { 
w.logger.Error("provisioning failed", "cluster_id", cluster.ID, "error", err) cluster.Status = domain.ClusterStatusFailed - } else { - w.logger.Info("provisioning succeeded", "cluster_id", cluster.ID) - cluster.Status = domain.ClusterStatusRunning + cluster.UpdatedAt = time.Now() + cluster.JobID = nil + _ = w.repo.Update(ctx, cluster) + return err } + w.logger.Info("provisioning succeeded", "cluster_id", cluster.ID) + cluster.Status = domain.ClusterStatusRunning cluster.UpdatedAt = time.Now() - cluster.JobID = nil // Clear job ID + cluster.JobID = nil _ = w.repo.Update(ctx, cluster) + return nil } -func (w *ClusterWorker) handleDeprovision(ctx context.Context, cluster *domain.Cluster) { +func (w *ClusterWorker) handleDeprovision(ctx context.Context, cluster *domain.Cluster) error { cluster.Status = domain.ClusterStatusDeleting cluster.UpdatedAt = time.Now() _ = w.repo.Update(ctx, cluster) if err := w.provisioner.Deprovision(ctx, cluster); err != nil { w.logger.Error("deprovisioning failed", "cluster_id", cluster.ID, "error", err) - // We might still mark it as failed or just leave it - } else { - w.logger.Info("deprovisioning succeeded", "cluster_id", cluster.ID) - _ = w.repo.Delete(ctx, cluster.ID) - return + cluster.UpdatedAt = time.Now() + cluster.JobID = nil + _ = w.repo.Update(ctx, cluster) + return err } - cluster.UpdatedAt = time.Now() - cluster.JobID = nil - _ = w.repo.Update(ctx, cluster) + w.logger.Info("deprovisioning succeeded", "cluster_id", cluster.ID) + _ = w.repo.Delete(ctx, cluster.ID) + return nil } -func (w *ClusterWorker) handleUpgrade(ctx context.Context, cluster *domain.Cluster, version string) { +func (w *ClusterWorker) handleUpgrade(ctx context.Context, cluster *domain.Cluster, version string) error { cluster.Status = domain.ClusterStatusUpgrading cluster.UpdatedAt = time.Now() _ = w.repo.Update(ctx, cluster) if err := w.provisioner.Upgrade(ctx, cluster, version); err != nil { w.logger.Error("upgrade failed", "cluster_id", cluster.ID, "error", err) - cluster.Status = domain.ClusterStatusRunning // Revert to running if failed - } else { - w.logger.Info("upgrade succeeded", "cluster_id", cluster.ID) cluster.Status = domain.ClusterStatusRunning - cluster.Version = version + cluster.UpdatedAt = time.Now() + cluster.JobID = nil + _ = w.repo.Update(ctx, cluster) + return err } + w.logger.Info("upgrade succeeded", "cluster_id", cluster.ID) + cluster.Status = domain.ClusterStatusRunning + cluster.Version = version cluster.UpdatedAt = time.Now() cluster.JobID = nil _ = w.repo.Update(ctx, cluster) + return nil +} + +func (w *ClusterWorker) reclaimLoop(ctx context.Context, sem chan struct{}) { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + msgs, err := w.taskQueue.ReclaimStale(ctx, clusterQueue, clusterGroup, w.consumerName, clusterReclaimMs, clusterReclaimN) + if err != nil { + w.logger.Warn("cluster reclaim error", "error", err) + continue + } + for _, m := range msgs { + var job domain.ClusterJob + if err := json.Unmarshal([]byte(m.Payload), &job); err != nil { + w.logger.Error("failed to unmarshal reclaimed cluster job", + "msg_id", m.ID, "error", err) + w.ackWithLog(ctx, m.ID, "reclaimed cluster poison message") + continue + } + w.logger.Info("reclaimed stale cluster job", + "cluster_id", job.ClusterID, "msg_id", m.ID) + + m := m + sem <- struct{}{} + go func() { + defer func() { <-sem }() + w.processJob(ctx, &m, job) + }() + } + } + } +} + +func (w *ClusterWorker) ackWithLog(ctx 
context.Context, messageID string, reason string) { + if err := w.taskQueue.Ack(ctx, clusterQueue, clusterGroup, messageID); err != nil { + w.logger.Warn("failed to ack cluster job", + "msg_id", messageID, "reason", reason, "error", err) + } +} + +func (w *ClusterWorker) nackWithLog(ctx context.Context, messageID string, reason string) { + if err := w.taskQueue.Nack(ctx, clusterQueue, clusterGroup, messageID); err != nil { + w.logger.Warn("failed to nack cluster job", + "msg_id", messageID, "reason", reason, "error", err) + } } diff --git a/internal/workers/cluster_worker_test.go b/internal/workers/cluster_worker_test.go index 2bbcaccf9..050a4eebf 100644 --- a/internal/workers/cluster_worker_test.go +++ b/internal/workers/cluster_worker_test.go @@ -26,6 +26,37 @@ func (m *MockTaskQueue) Dequeue(ctx context.Context, queue string) (string, erro return args.String(0), args.Error(1) } +func (m *MockTaskQueue) EnsureGroup(ctx context.Context, queueName, groupName string) error { + args := m.Called(ctx, queueName, groupName) + return args.Error(0) +} + +func (m *MockTaskQueue) Receive(ctx context.Context, queueName, groupName, consumerName string) (*ports.DurableMessage, error) { + args := m.Called(ctx, queueName, groupName, consumerName) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*ports.DurableMessage), args.Error(1) +} + +func (m *MockTaskQueue) Ack(ctx context.Context, queueName, groupName, messageID string) error { + args := m.Called(ctx, queueName, groupName, messageID) + return args.Error(0) +} + +func (m *MockTaskQueue) Nack(ctx context.Context, queueName, groupName, messageID string) error { + args := m.Called(ctx, queueName, groupName, messageID) + return args.Error(0) +} + +func (m *MockTaskQueue) ReclaimStale(ctx context.Context, queueName, groupName, consumerName string, minIdleMs int64, count int64) ([]ports.DurableMessage, error) { + args := m.Called(ctx, queueName, groupName, consumerName, minIdleMs, count) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]ports.DurableMessage), args.Error(1) +} + type MockClusterRepo struct{ mock.Mock } func (m *MockClusterRepo) Create(ctx context.Context, c *domain.Cluster) error { return nil } @@ -124,7 +155,7 @@ func TestClusterWorkerProcessProvisionJob(t *testing.T) { prov := new(MockProvisioner) logger := slog.New(slog.NewTextHandler(io.Discard, nil)) - worker := NewClusterWorker(repo, prov, tq, logger) + worker := NewClusterWorker(repo, prov, tq, nil, logger) clusterID := uuid.New() userID := uuid.New() @@ -135,17 +166,20 @@ func TestClusterWorkerProcessProvisionJob(t *testing.T) { ClusterID: clusterID, UserID: userID, } + msg := &ports.DurableMessage{ID: "1-0", Payload: "", Queue: clusterQueue} repo.On("GetByID", mock.Anything, clusterID).Return(cluster, nil) repo.On("Update", mock.Anything, mock.MatchedBy(func(c *domain.Cluster) bool { return c.Status == domain.ClusterStatusProvisioning || c.Status == domain.ClusterStatusRunning })).Return(nil) prov.On("Provision", mock.Anything, cluster).Return(nil) + tq.On("Ack", mock.Anything, clusterQueue, clusterGroup, msg.ID).Return(nil) - worker.processJob(job) + worker.processJob(context.Background(), msg, job) repo.AssertExpectations(t) prov.AssertExpectations(t) + tq.AssertExpectations(t) } func TestClusterWorkerProcessDeprovisionJobSuccess(t *testing.T) { @@ -154,7 +188,7 @@ func TestClusterWorkerProcessDeprovisionJobSuccess(t *testing.T) { prov := new(MockProvisioner) logger := slog.New(slog.NewTextHandler(io.Discard, nil)) - 
worker := NewClusterWorker(repo, prov, tq, logger) + worker := NewClusterWorker(repo, prov, tq, nil, logger) clusterID := uuid.New() userID := uuid.New() @@ -165,16 +199,19 @@ func TestClusterWorkerProcessDeprovisionJobSuccess(t *testing.T) { ClusterID: clusterID, UserID: userID, } + msg := &ports.DurableMessage{ID: "2-0", Payload: "", Queue: clusterQueue} repo.On("GetByID", mock.Anything, clusterID).Return(cluster, nil) repo.On("Update", mock.Anything, mock.AnythingOfType("*domain.Cluster")).Return(nil) prov.On("Deprovision", mock.Anything, cluster).Return(nil) repo.On("Delete", mock.Anything, clusterID).Return(nil) + tq.On("Ack", mock.Anything, clusterQueue, clusterGroup, msg.ID).Return(nil) - worker.processJob(job) + worker.processJob(context.Background(), msg, job) repo.AssertExpectations(t) prov.AssertExpectations(t) + tq.AssertExpectations(t) } func TestClusterWorkerProcessDeprovisionJobFailure(t *testing.T) { @@ -183,7 +220,7 @@ func TestClusterWorkerProcessDeprovisionJobFailure(t *testing.T) { prov := new(MockProvisioner) logger := slog.New(slog.NewTextHandler(io.Discard, nil)) - worker := NewClusterWorker(repo, prov, tq, logger) + worker := NewClusterWorker(repo, prov, tq, nil, logger) clusterID := uuid.New() userID := uuid.New() @@ -194,15 +231,18 @@ func TestClusterWorkerProcessDeprovisionJobFailure(t *testing.T) { ClusterID: clusterID, UserID: userID, } + msg := &ports.DurableMessage{ID: "3-0", Payload: "", Queue: clusterQueue} repo.On("GetByID", mock.Anything, clusterID).Return(cluster, nil) repo.On("Update", mock.Anything, mock.AnythingOfType("*domain.Cluster")).Return(nil).Twice() prov.On("Deprovision", mock.Anything, cluster).Return(io.EOF) + tq.On("Nack", mock.Anything, clusterQueue, clusterGroup, msg.ID).Return(nil) - worker.processJob(job) + worker.processJob(context.Background(), msg, job) repo.AssertExpectations(t) prov.AssertExpectations(t) + tq.AssertExpectations(t) } func TestClusterWorkerProcessUpgradeJob(t *testing.T) { @@ -211,7 +251,7 @@ func TestClusterWorkerProcessUpgradeJob(t *testing.T) { prov := new(MockProvisioner) logger := slog.New(slog.NewTextHandler(io.Discard, nil)) - worker := NewClusterWorker(repo, prov, tq, logger) + worker := NewClusterWorker(repo, prov, tq, nil, logger) clusterID := uuid.New() userID := uuid.New() @@ -224,17 +264,20 @@ func TestClusterWorkerProcessUpgradeJob(t *testing.T) { UserID: userID, Version: version, } + msg := &ports.DurableMessage{ID: "4-0", Payload: "", Queue: clusterQueue} repo.On("GetByID", mock.Anything, clusterID).Return(cluster, nil) repo.On("Update", mock.Anything, mock.MatchedBy(func(c *domain.Cluster) bool { return c.Status == domain.ClusterStatusUpgrading || c.Status == domain.ClusterStatusRunning })).Return(nil).Twice() prov.On("Upgrade", mock.Anything, cluster, version).Return(nil) + tq.On("Ack", mock.Anything, clusterQueue, clusterGroup, msg.ID).Return(nil) - worker.processJob(job) + worker.processJob(context.Background(), msg, job) repo.AssertExpectations(t) prov.AssertExpectations(t) + tq.AssertExpectations(t) } func TestClusterWorkerProcessJobClusterNotFound(t *testing.T) { @@ -243,7 +286,7 @@ func TestClusterWorkerProcessJobClusterNotFound(t *testing.T) { prov := new(MockProvisioner) logger := slog.New(slog.NewTextHandler(io.Discard, nil)) - worker := NewClusterWorker(repo, prov, tq, logger) + worker := NewClusterWorker(repo, prov, tq, nil, logger) clusterID := uuid.New() userID := uuid.New() @@ -252,10 +295,14 @@ func TestClusterWorkerProcessJobClusterNotFound(t *testing.T) { ClusterID: clusterID, 
UserID: userID, } + msg := &ports.DurableMessage{ID: "5-0", Payload: "", Queue: clusterQueue} repo.On("GetByID", mock.Anything, clusterID).Return(nil, nil) + // Cluster not found -> ack to avoid infinite retries + tq.On("Ack", mock.Anything, clusterQueue, clusterGroup, msg.ID).Return(nil) - worker.processJob(job) + worker.processJob(context.Background(), msg, job) prov.AssertNotCalled(t, "Provision", mock.Anything, mock.Anything) + tq.AssertExpectations(t) } diff --git a/internal/workers/leader_guard.go b/internal/workers/leader_guard.go new file mode 100644 index 000000000..9cc2e0307 --- /dev/null +++ b/internal/workers/leader_guard.go @@ -0,0 +1,85 @@ +// Package workers provides background worker implementations. +package workers + +import ( + "context" + "log/slog" + "sync" + + "github.com/poyrazk/thecloud/internal/core/ports" +) + +// LeaderGuard wraps a worker that implements the Run(context.Context, *sync.WaitGroup) +// interface and ensures it only runs on the pod that holds leadership for its key. +// +// When leadership is not held, the worker is paused. If leadership is lost mid-run, +// the worker's context is cancelled, causing it to stop. It will restart if +// leadership is re-acquired. +type LeaderGuard struct { + elector ports.LeaderElector + key string + inner runner + logger *slog.Logger +} + +// runner is the interface all workers implement. +type runner interface { + Run(context.Context, *sync.WaitGroup) +} + +// NewLeaderGuard creates a LeaderGuard that protects the given worker with leader election. +// The key should be unique per worker type (e.g., "worker:lb", "worker:cron"). +func NewLeaderGuard(elector ports.LeaderElector, key string, inner runner, logger *slog.Logger) *LeaderGuard { + return &LeaderGuard{ + elector: elector, + key: key, + inner: inner, + logger: logger, + } +} + +// Run implements the runner interface. It participates in leader election and only +// runs the inner worker when this instance is the leader. If leadership is lost, +// the inner worker is stopped. If leadership is re-acquired, the inner worker restarts. +func (g *LeaderGuard) Run(ctx context.Context, wg *sync.WaitGroup) { + defer wg.Done() + + for { + if ctx.Err() != nil { + return + } + + g.logger.Info("attempting to acquire leadership", "key", g.key) + + err := g.elector.RunAsLeader(ctx, g.key, func(leaderCtx context.Context) error { + g.logger.Info("running as leader", "key", g.key) + + // Create an inner WaitGroup for the wrapped worker + innerWG := &sync.WaitGroup{} + innerWG.Add(1) + go g.inner.Run(leaderCtx, innerWG) + + // Wait for the inner worker to finish (either normally or due to context cancellation) + innerWG.Wait() + + g.logger.Info("inner worker stopped", "key", g.key) + return nil + }) + + if err != nil { + if ctx.Err() != nil { + // Parent context cancelled — clean shutdown + g.logger.Info("leader guard shutting down", "key", g.key) + return + } + g.logger.Error("leader election error, will retry", "key", g.key, "error", err) + } + + // If we reach here, we either lost leadership or RunAsLeader returned. + // Loop back to try to re-acquire leadership. 
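+ // RunAsLeader blocks while waiting to become leader, so this loop normally
+ // only iterates when leadership changes hands or an error occurs.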
+ if ctx.Err() != nil { + return + } + g.logger.Info("leadership lost or released, retrying", "key", g.key) + } +} diff --git a/internal/workers/leader_guard_test.go b/internal/workers/leader_guard_test.go new file mode 100644 index 000000000..116f4081e --- /dev/null +++ b/internal/workers/leader_guard_test.go @@ -0,0 +1,189 @@ +package workers + +import ( + "context" + "io" + "log/slog" + "sync" + "sync/atomic" + "testing" + "time" +) + +// mockLeaderElector implements ports.LeaderElector for testing. +type mockLeaderElector struct { + acquireResult bool + acquireErr error + releaseErr error + acquireCount atomic.Int32 + releaseCount atomic.Int32 + + // When set, RunAsLeader immediately calls fn if acquireResult is true + runAsLeaderFn func(ctx context.Context, key string, fn func(ctx context.Context) error) error +} + +func (m *mockLeaderElector) Acquire(ctx context.Context, key string) (bool, error) { + m.acquireCount.Add(1) + return m.acquireResult, m.acquireErr +} + +func (m *mockLeaderElector) Release(ctx context.Context, key string) error { + m.releaseCount.Add(1) + return m.releaseErr +} + +func (m *mockLeaderElector) RunAsLeader(ctx context.Context, key string, fn func(ctx context.Context) error) error { + if m.runAsLeaderFn != nil { + return m.runAsLeaderFn(ctx, key, fn) + } + // Default: acquire leadership and run fn + if m.acquireResult { + return fn(ctx) + } + // Not leader, block until context cancelled + <-ctx.Done() + return ctx.Err() +} + +// mockRunner records whether Run was called and blocks until context is done. +type mockRunner struct { + runCalled atomic.Int32 + runCtx context.Context +} + +func (r *mockRunner) Run(ctx context.Context, wg *sync.WaitGroup) { + defer wg.Done() + r.runCalled.Add(1) + r.runCtx = ctx + <-ctx.Done() +} + +func newTestLogger() *slog.Logger { + return slog.New(slog.NewTextHandler(io.Discard, nil)) +} + +func TestLeaderGuardRunsInnerWorkerWhenLeader(t *testing.T) { + elector := &mockLeaderElector{acquireResult: true} + inner := &mockRunner{} + guard := NewLeaderGuard(elector, "test:worker", inner, newTestLogger()) + + ctx, cancel := context.WithCancel(context.Background()) + wg := &sync.WaitGroup{} + wg.Add(1) + go guard.Run(ctx, wg) + + // Wait a bit for the inner worker to start + time.Sleep(100 * time.Millisecond) + + if inner.runCalled.Load() == 0 { + t.Fatal("expected inner worker to be started when leader") + } + + cancel() + wg.Wait() +} + +func TestLeaderGuardDoesNotRunWhenNotLeader(t *testing.T) { + elector := &mockLeaderElector{ + acquireResult: false, + runAsLeaderFn: func(ctx context.Context, key string, fn func(ctx context.Context) error) error { + // Simulate never becoming leader — block until cancelled + <-ctx.Done() + return ctx.Err() + }, + } + inner := &mockRunner{} + guard := NewLeaderGuard(elector, "test:worker", inner, newTestLogger()) + + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + + wg := &sync.WaitGroup{} + wg.Add(1) + go guard.Run(ctx, wg) + + wg.Wait() + + if inner.runCalled.Load() != 0 { + t.Fatal("expected inner worker NOT to be started when not leader") + } +} + +func TestLeaderGuardRestartsAfterLeadershipLoss(t *testing.T) { + callCount := atomic.Int32{} + + elector := &mockLeaderElector{ + runAsLeaderFn: func(ctx context.Context, key string, fn func(ctx context.Context) error) error { + n := callCount.Add(1) + if n <= 2 { + // Simulate short leadership then loss + fnCtx, fnCancel := context.WithTimeout(ctx, 50*time.Millisecond) + defer fnCancel() + return 
fn(fnCtx) + } + // Third time: block until parent context cancelled + <-ctx.Done() + return ctx.Err() + }, + } + + inner := &mockRunner{} + // Override mockRunner to not block + countingRunner := &countingMockRunner{} + guard := NewLeaderGuard(elector, "test:worker", countingRunner, newTestLogger()) + + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + wg := &sync.WaitGroup{} + wg.Add(1) + go guard.Run(ctx, wg) + + wg.Wait() + _ = inner // unused, countingRunner is used instead + + runs := countingRunner.runCalled.Load() + if runs < 2 { + t.Fatalf("expected inner worker to be restarted at least 2 times after leadership loss, got %d", runs) + } +} + +// countingMockRunner counts Run calls but returns quickly when context is done. +type countingMockRunner struct { + runCalled atomic.Int32 +} + +func (r *countingMockRunner) Run(ctx context.Context, wg *sync.WaitGroup) { + defer wg.Done() + r.runCalled.Add(1) + <-ctx.Done() +} + +func TestLeaderGuardShutsDownCleanly(t *testing.T) { + elector := &mockLeaderElector{acquireResult: true} + inner := &mockRunner{} + guard := NewLeaderGuard(elector, "test:worker", inner, newTestLogger()) + + ctx, cancel := context.WithCancel(context.Background()) + wg := &sync.WaitGroup{} + wg.Add(1) + go guard.Run(ctx, wg) + + // Let it start + time.Sleep(50 * time.Millisecond) + + // Cancel and wait for clean shutdown + cancel() + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + // Success — clean shutdown + case <-time.After(2 * time.Second): + t.Fatal("leader guard did not shut down within 2s") + } +} diff --git a/internal/workers/pipeline_worker.go b/internal/workers/pipeline_worker.go index dfce93696..b5e3d8e3f 100644 --- a/internal/workers/pipeline_worker.go +++ b/internal/workers/pipeline_worker.go @@ -4,8 +4,10 @@ package workers import ( "context" "encoding/json" + "fmt" "io" "log/slog" + "os" "strings" "sync" "time" @@ -16,29 +18,59 @@ import ( "github.com/poyrazk/thecloud/internal/core/ports" ) -const pipelineQueueName = "pipeline_build_queue" +const ( + pipelineQueueName = "pipeline_build_queue" + pipelineGroup = "pipeline_workers" + pipelineMaxWorkers = 5 + pipelineReclaimMs = 10 * 60 * 1000 // 10 minutes (builds are longer) + pipelineReclaimN = 5 + // Stale threshold for idempotency ledger: builds can take up to 30 min, + // so a "running" entry older than this is considered abandoned. + pipelineStaleThreshold = 35 * time.Minute +) // PipelineWorker processes queued pipeline builds. type PipelineWorker struct { - repo ports.PipelineRepository - taskQueue ports.TaskQueue - compute ports.ComputeBackend - logger *slog.Logger + repo ports.PipelineRepository + taskQueue ports.DurableTaskQueue + ledger ports.ExecutionLedger + compute ports.ComputeBackend + logger *slog.Logger + consumerName string } // NewPipelineWorker creates a new PipelineWorker. -func NewPipelineWorker(repo ports.PipelineRepository, taskQueue ports.TaskQueue, compute ports.ComputeBackend, logger *slog.Logger) *PipelineWorker { +// If ledger is nil, idempotency checks are skipped. 
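+// The Redis Streams consumer name is derived from the pod hostname, so each
+// replica reads from the consumer group under its own identity.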
+func NewPipelineWorker(repo ports.PipelineRepository, taskQueue ports.DurableTaskQueue, ledger ports.ExecutionLedger, compute ports.ComputeBackend, logger *slog.Logger) *PipelineWorker { + hostname, _ := os.Hostname() + if hostname == "" { + hostname = "pipeline-worker" + } return &PipelineWorker{ - repo: repo, - taskQueue: taskQueue, - compute: compute, - logger: logger, + repo: repo, + taskQueue: taskQueue, + ledger: ledger, + compute: compute, + logger: logger, + consumerName: hostname, } } func (w *PipelineWorker) Run(ctx context.Context, wg *sync.WaitGroup) { defer wg.Done() - w.logger.Info("starting pipeline worker") + w.logger.Info("starting pipeline worker", + "consumer", w.consumerName, + "concurrency", pipelineMaxWorkers, + ) + + if err := w.taskQueue.EnsureGroup(ctx, pipelineQueueName, pipelineGroup); err != nil { + w.logger.Error("failed to ensure pipeline consumer group", "error", err) + return + } + + sem := make(chan struct{}, pipelineMaxWorkers) + + go w.reclaimLoop(ctx, sem) for { select { @@ -46,68 +78,159 @@ func (w *PipelineWorker) Run(ctx context.Context, wg *sync.WaitGroup) { w.logger.Info("stopping pipeline worker") return default: - msg, err := w.taskQueue.Dequeue(ctx, pipelineQueueName) + msg, err := w.taskQueue.Receive(ctx, pipelineQueueName, pipelineGroup, w.consumerName) if err != nil { - w.logger.Error("failed to dequeue pipeline job", "error", err) + w.logger.Error("failed to receive pipeline job", "error", err) time.Sleep(1 * time.Second) continue } - if msg == "" { + if msg == nil { continue } var job domain.BuildJob - if err := json.Unmarshal([]byte(msg), &job); err != nil { - w.logger.Error("failed to unmarshal build job", "error", err) + if err := json.Unmarshal([]byte(msg.Payload), &job); err != nil { + w.logger.Error("failed to unmarshal build job", + "error", err, "msg_id", msg.ID) + w.ackWithLog(ctx, msg.ID, "pipeline poison message") continue } - w.processJob(job) + sem <- struct{}{} + go func(m *ports.DurableMessage, j domain.BuildJob) { + defer func() { <-sem }() + w.processJob(ctx, m, j) + }(msg, job) } } } -func (w *PipelineWorker) processJob(job domain.BuildJob) { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute) +func (w *PipelineWorker) processJob(workerCtx context.Context, msg *ports.DurableMessage, job domain.BuildJob) { + jobKey := fmt.Sprintf("pipeline:%s", job.BuildID) + + // Idempotency check: skip if already completed or actively being processed. + if w.ledger != nil { + acquired, err := w.ledger.TryAcquire(workerCtx, jobKey, pipelineStaleThreshold) + if err != nil { + w.logger.Error("execution ledger error", + "build_id", job.BuildID, "msg_id", msg.ID, "error", err) + w.nackWithLog(workerCtx, msg.ID, "ledger try_acquire failed") + return + } + if !acquired { + // Check if it's already finished or just being processed by someone else. + status, _, _, getErr := w.ledger.GetStatus(workerCtx, jobKey) + if getErr == nil && status == "completed" { + w.logger.Info("skipping already completed pipeline job", + "build_id", job.BuildID, "msg_id", msg.ID) + w.ackWithLog(workerCtx, msg.ID, "pipeline already completed") + return + } + w.logger.Info("pipeline job is currently being processed by another worker", + "build_id", job.BuildID, "msg_id", msg.ID) + return // Leave unacked for redelivery/wait. 
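+ // The pending entry stays assigned to this consumer; once it has been idle
+ // longer than pipelineReclaimMs, reclaimLoop can hand it to another worker.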
+ } + } + + ctx, cancel := context.WithTimeout(workerCtx, 30*time.Minute) defer cancel() ctx = appcontext.WithUserID(ctx, job.UserID) - build, pipeline := w.loadBuildAndPipeline(ctx, job) + build, pipeline, err := w.loadBuildAndPipeline(ctx, job) + if err != nil { + // Transient error loading build/pipeline — nack and retry. + w.logger.Error("transient error loading build/pipeline", + "build_id", job.BuildID, "error", err) + if w.ledger != nil { + if ledgerErr := w.ledger.MarkFailed(workerCtx, jobKey, "transient load error"); ledgerErr != nil { + w.logger.Warn("failed to mark pipeline job failed in ledger", + "build_id", job.BuildID, "msg_id", msg.ID, "error", ledgerErr) + } + } + w.nackWithLog(workerCtx, msg.ID, "transient pipeline load error") + return + } + if build == nil || pipeline == nil { + // Build or pipeline truly not found — ack to avoid infinite retries. + if w.ledger != nil { + if ledgerErr := w.ledger.MarkComplete(workerCtx, jobKey, "not_found"); ledgerErr != nil { + w.logger.Warn("failed to mark pipeline job complete in ledger", + "build_id", job.BuildID, "msg_id", msg.ID, "error", ledgerErr) + } + } + w.ackWithLog(workerCtx, msg.ID, "pipeline build/pipeline not found") return } if !w.markBuildRunning(ctx, build) { + if w.ledger != nil { + if ledgerErr := w.ledger.MarkFailed(workerCtx, jobKey, "failed to mark build running"); ledgerErr != nil { + w.logger.Warn("failed to mark pipeline job failed in ledger", + "build_id", job.BuildID, "msg_id", msg.ID, "error", ledgerErr) + } + } + w.nackWithLog(workerCtx, msg.ID, "mark build running failed") return } if len(pipeline.Config.Stages) == 0 { w.failBuild(ctx, build, "pipeline has no stages") + if w.ledger != nil { + if ledgerErr := w.ledger.MarkComplete(workerCtx, jobKey, "no_stages"); ledgerErr != nil { + w.logger.Warn("failed to mark pipeline job complete in ledger", + "build_id", job.BuildID, "msg_id", msg.ID, "error", ledgerErr) + } + } + w.ackWithLog(workerCtx, msg.ID, "pipeline has no stages") return } if !w.executePipeline(ctx, build, pipeline) { + // Build failed but was processed — ack the message. 
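+ // A failed build is a terminal state recorded on the build itself, so
+ // retrying the queue message would not change the outcome.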
+ if w.ledger != nil { + if ledgerErr := w.ledger.MarkComplete(workerCtx, jobKey, "build_failed"); ledgerErr != nil { + w.logger.Warn("failed to mark pipeline job complete in ledger", + "build_id", job.BuildID, "msg_id", msg.ID, "error", ledgerErr) + } + } + w.ackWithLog(workerCtx, msg.ID, "pipeline execution failed") return } w.markBuildSucceeded(ctx, build) + + if w.ledger != nil { + if ledgerErr := w.ledger.MarkComplete(workerCtx, jobKey, "ok"); ledgerErr != nil { + w.logger.Warn("failed to mark pipeline job complete in ledger", + "build_id", job.BuildID, "msg_id", msg.ID, "error", ledgerErr) + } + } + w.ackWithLog(workerCtx, msg.ID, "pipeline job success") } -func (w *PipelineWorker) loadBuildAndPipeline(ctx context.Context, job domain.BuildJob) (*domain.Build, *domain.Pipeline) { +func (w *PipelineWorker) loadBuildAndPipeline(ctx context.Context, job domain.BuildJob) (*domain.Build, *domain.Pipeline, error) { build, err := w.repo.GetBuild(ctx, job.BuildID, job.UserID) - if err != nil || build == nil { + if err != nil { w.logger.Error("failed to load build", "build_id", job.BuildID, "error", err) - return nil, nil + return nil, nil, err + } + if build == nil { + return nil, nil, nil } pipeline, err := w.repo.GetPipeline(ctx, job.PipelineID, job.UserID) - if err != nil || pipeline == nil { + if err != nil { w.logger.Error("failed to load pipeline", "pipeline_id", job.PipelineID, "error", err) + w.failBuild(ctx, build, "pipeline load error: "+err.Error()) + return nil, nil, err + } + if pipeline == nil { w.failBuild(ctx, build, "pipeline not found") - return nil, nil + return build, nil, nil } - return build, pipeline + return build, pipeline, nil } func (w *PipelineWorker) markBuildRunning(ctx context.Context, build *domain.Build) bool { @@ -262,3 +385,53 @@ func (w *PipelineWorker) collectTaskLogs(ctx context.Context, taskID string) (st } return string(data), nil } + +func (w *PipelineWorker) reclaimLoop(ctx context.Context, sem chan struct{}) { + ticker := time.NewTicker(60 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + msgs, err := w.taskQueue.ReclaimStale(ctx, pipelineQueueName, pipelineGroup, w.consumerName, pipelineReclaimMs, pipelineReclaimN) + if err != nil { + w.logger.Warn("pipeline reclaim error", "error", err) + continue + } + for _, m := range msgs { + var job domain.BuildJob + if err := json.Unmarshal([]byte(m.Payload), &job); err != nil { + w.logger.Error("failed to unmarshal reclaimed pipeline job", + "msg_id", m.ID, "error", err) + w.ackWithLog(ctx, m.ID, "reclaimed pipeline poison message") + continue + } + w.logger.Info("reclaimed stale pipeline job", + "build_id", job.BuildID, "msg_id", m.ID) + + m := m + sem <- struct{}{} + go func() { + defer func() { <-sem }() + w.processJob(ctx, &m, job) + }() + } + } + } +} + +func (w *PipelineWorker) ackWithLog(ctx context.Context, messageID string, reason string) { + if err := w.taskQueue.Ack(ctx, pipelineQueueName, pipelineGroup, messageID); err != nil { + w.logger.Warn("failed to ack pipeline job", + "msg_id", messageID, "reason", reason, "error", err) + } +} + +func (w *PipelineWorker) nackWithLog(ctx context.Context, messageID string, reason string) { + if err := w.taskQueue.Nack(ctx, pipelineQueueName, pipelineGroup, messageID); err != nil { + w.logger.Warn("failed to nack pipeline job", + "msg_id", messageID, "reason", reason, "error", err) + } +} diff --git a/internal/workers/pipeline_worker_test.go b/internal/workers/pipeline_worker_test.go index 
6a5cf50aa..678183f77 100644 --- a/internal/workers/pipeline_worker_test.go +++ b/internal/workers/pipeline_worker_test.go @@ -163,12 +163,13 @@ func TestPipelineWorker_processJob(t *testing.T) { compute := new(mockComputeBackendExtended) taskQueue := new(MockTaskQueue) logger := slog.New(slog.NewTextHandler(io.Discard, nil)) - worker := NewPipelineWorker(repo, taskQueue, compute, logger) + worker := NewPipelineWorker(repo, taskQueue, nil, compute, logger) buildID := uuid.New() pipelineID := uuid.New() userID := uuid.New() job := domain.BuildJob{BuildID: buildID, PipelineID: pipelineID, UserID: userID} + msg := &ports.DurableMessage{ID: "1-0", Queue: pipelineQueueName} t.Run("Success", func(t *testing.T) { build := &domain.Build{ID: buildID, PipelineID: pipelineID, UserID: userID} @@ -205,9 +206,11 @@ func TestPipelineWorker_processJob(t *testing.T) { repo.On("UpdateBuild", mock.Anything, mock.MatchedBy(func(b *domain.Build) bool { return b.Status == domain.BuildStatusSucceeded })).Return(nil).Once() + taskQueue.On("Ack", mock.Anything, pipelineQueueName, pipelineGroup, msg.ID).Return(nil).Once() - worker.processJob(job) + worker.processJob(context.Background(), msg, job) repo.AssertExpectations(t) compute.AssertExpectations(t) + taskQueue.AssertExpectations(t) }) } diff --git a/internal/workers/provision_worker.go b/internal/workers/provision_worker.go index 3f257e5ca..a9cfc217f 100644 --- a/internal/workers/provision_worker.go +++ b/internal/workers/provision_worker.go @@ -4,7 +4,9 @@ package workers import ( "context" "encoding/json" + "fmt" "log/slog" + "os" "sync" "time" @@ -14,25 +16,65 @@ import ( "github.com/poyrazk/thecloud/internal/core/services" ) -// ProvisionWorker processes instance provisioning tasks. +const ( + provisionQueue = "provision_queue" + provisionGroup = "provision_workers" + provisionMaxWorkers = 20 + // How long a message can sit in PEL before another consumer reclaims it. + // Must be longer than provisionStaleThreshold (15m) to avoid premature reclaim. + provisionReclaimMs = 20 * 60 * 1000 // 20 minutes + provisionReclaimN = 10 + // Stale threshold for idempotency ledger: if a "running" entry is older + // than this, it is considered abandoned and can be reclaimed. + provisionStaleThreshold = 15 * time.Minute +) + +// ProvisionWorker processes instance provisioning tasks using a durable queue +// with at-least-once delivery. Jobs are acknowledged only after successful +// processing; crashed jobs are reclaimed by healthy peers. An execution ledger +// prevents duplicate processing of the same instance. type ProvisionWorker struct { - instSvc *services.InstanceService - taskQueue ports.TaskQueue - logger *slog.Logger + instSvc *services.InstanceService + taskQueue ports.DurableTaskQueue + ledger ports.ExecutionLedger + logger *slog.Logger + consumerName string } // NewProvisionWorker constructs a ProvisionWorker. -func NewProvisionWorker(instSvc *services.InstanceService, taskQueue ports.TaskQueue, logger *slog.Logger) *ProvisionWorker { +// If ledger is nil, idempotency checks are skipped. 
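+// Ledger entries are keyed "provision:<instance_id>" and follow the
+// TryAcquire, then MarkComplete or MarkFailed lifecycle used in processJob.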
+func NewProvisionWorker(instSvc *services.InstanceService, taskQueue ports.DurableTaskQueue, ledger ports.ExecutionLedger, logger *slog.Logger) *ProvisionWorker { + hostname, _ := os.Hostname() + if hostname == "" { + hostname = "provision-worker" + } return &ProvisionWorker{ - instSvc: instSvc, - taskQueue: taskQueue, - logger: logger, + instSvc: instSvc, + taskQueue: taskQueue, + ledger: ledger, + logger: logger, + consumerName: hostname, } } func (w *ProvisionWorker) Run(ctx context.Context, wg *sync.WaitGroup) { defer wg.Done() - w.logger.Info("starting provision worker") + w.logger.Info("starting provision worker", + "consumer", w.consumerName, + "concurrency", provisionMaxWorkers, + ) + + // Ensure consumer group exists. + if err := w.taskQueue.EnsureGroup(ctx, provisionQueue, provisionGroup); err != nil { + w.logger.Error("failed to ensure provision consumer group", "error", err) + return + } + + sem := make(chan struct{}, provisionMaxWorkers) + + // Start a background goroutine that periodically reclaims stale messages + // from crashed consumers. + go w.reclaimLoop(ctx, sem) for { select { @@ -40,47 +82,160 @@ func (w *ProvisionWorker) Run(ctx context.Context, wg *sync.WaitGroup) { w.logger.Info("stopping provision worker") return default: - // Dequeue task - msg, err := w.taskQueue.Dequeue(ctx, "provision_queue") + msg, err := w.taskQueue.Receive(ctx, provisionQueue, provisionGroup, w.consumerName) if err != nil { - // redis.Nil or other error + w.logger.Error("failed to receive provision job", "error", err) time.Sleep(1 * time.Second) continue } - - if msg == "" { + if msg == nil { continue } var job domain.ProvisionJob - if err := json.Unmarshal([]byte(msg), &job); err != nil { - w.logger.Error("failed to unmarshal provision job", "error", err) + if err := json.Unmarshal([]byte(msg.Payload), &job); err != nil { + w.logger.Error("failed to unmarshal provision job", + "error", err, "msg_id", msg.ID) + // Ack poison messages so they don't block the queue. + w.ackWithLog(ctx, msg.ID, "provision poison message") continue } - w.logger.Info("processing provision job", "instance_id", job.InstanceID, "tenant_id", job.TenantID) + w.logger.Info("processing provision job", + "instance_id", job.InstanceID, + "tenant_id", job.TenantID, + "msg_id", msg.ID, + ) - // Process job concurrently to handle high throughput in load tests - go w.processJob(job) + sem <- struct{}{} // acquire concurrency slot + go func(m *ports.DurableMessage, j domain.ProvisionJob) { + defer func() { <-sem }() + w.processJob(ctx, m, j) + }(msg, job) } } } -func (w *ProvisionWorker) processJob(job domain.ProvisionJob) { - // Root context for background task with 10-minute safety timeout - // We use context.Background() because the worker lifecycle context shouldn't necessarily cancel active provisioning unless the app is shutting down - baseCtx := context.Background() - ctx, cancel := context.WithTimeout(baseCtx, 10*time.Minute) +func (w *ProvisionWorker) processJob(workerCtx context.Context, msg *ports.DurableMessage, job domain.ProvisionJob) { + jobKey := fmt.Sprintf("provision:%s", job.InstanceID) + + // Idempotency check: skip if already completed or actively being processed. + if w.ledger != nil { + acquired, err := w.ledger.TryAcquire(workerCtx, jobKey, provisionStaleThreshold) + if err != nil { + w.logger.Error("execution ledger error", + "instance_id", job.InstanceID, "msg_id", msg.ID, "error", err) + // On ledger error, nack to retry later. 
+ w.nackWithLog(workerCtx, msg.ID, "ledger try_acquire failed") + return + } + if !acquired { + // Check if it's already finished or just being processed by someone else. + status, _, _, getErr := w.ledger.GetStatus(workerCtx, jobKey) + if getErr == nil && status == "completed" { + w.logger.Info("skipping already completed provision job", + "instance_id", job.InstanceID, "msg_id", msg.ID) + w.ackWithLog(workerCtx, msg.ID, "provision already completed") + return + } + w.logger.Info("provision job is currently being processed by another worker", + "instance_id", job.InstanceID, "msg_id", msg.ID) + return // Leave unacked for redelivery/wait. + } + } + + // Root context for background task with 10-minute safety timeout. + ctx, cancel := context.WithTimeout(workerCtx, 10*time.Minute) defer cancel() - // Inject User and Tenant IDs for repository access control + // Inject User and Tenant IDs for repository access control. ctx = appcontext.WithUserID(ctx, job.UserID) ctx = appcontext.WithTenantID(ctx, job.TenantID) - w.logger.Info("starting provision logic", "instance_id", job.InstanceID) + w.logger.Info("starting provision logic", "instance_id", job.InstanceID, "msg_id", msg.ID) if err := w.instSvc.Provision(ctx, job); err != nil { - w.logger.Error("failed to provision instance", "instance_id", job.InstanceID, "error", err) - } else { - w.logger.Info("successfully provisioned instance", "instance_id", job.InstanceID) + w.logger.Error("failed to provision instance", + "instance_id", job.InstanceID, + "msg_id", msg.ID, + "error", err, + ) + // Mark failed in the ledger so it can be retried. + if w.ledger != nil { + if ledgerErr := w.ledger.MarkFailed(workerCtx, jobKey, err.Error()); ledgerErr != nil { + w.logger.Warn("failed to mark provision job failed in ledger", + "instance_id", job.InstanceID, "msg_id", msg.ID, "error", ledgerErr) + } + } + // Nack: leave message in PEL for reclaim/retry. + w.nackWithLog(workerCtx, msg.ID, "provision failed") + return + } + + w.logger.Info("successfully provisioned instance", + "instance_id", job.InstanceID, + "msg_id", msg.ID, + ) + + // Mark completed in ledger (prevents duplicate execution). + if w.ledger != nil { + if ledgerErr := w.ledger.MarkComplete(workerCtx, jobKey, "ok"); ledgerErr != nil { + w.logger.Warn("failed to mark provision job complete in ledger", + "instance_id", job.InstanceID, "msg_id", msg.ID, "error", ledgerErr) + } + } + + // Acknowledge — message is permanently consumed. + w.ackWithLog(workerCtx, msg.ID, "provision success") +} + +// reclaimLoop periodically reclaims messages stuck in the PEL from crashed +// consumers and re-processes them. 
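
The diff does not include a ledger implementation, so purely as an illustration of the rules processJob applies above, here is a minimal in-memory sketch (assuming the usual context, sync, and time imports; the names and status strings are assumptions): a "completed" entry is never handed out again, a "failed" entry can be re-acquired, and a "running" entry counts as abandoned once it is older than the stale threshold.

// memLedger is a toy, process-local ledger used only to illustrate the
// idempotency rules above; a real implementation would live in shared storage.
type memLedger struct {
	mu      sync.Mutex
	entries map[string]ledgerEntry
}

type ledgerEntry struct {
	status    string // "running", "completed", or "failed"
	reason    string
	updatedAt time.Time
}

func newMemLedger() *memLedger {
	return &memLedger{entries: make(map[string]ledgerEntry)}
}

func (l *memLedger) TryAcquire(_ context.Context, key string, staleAfter time.Duration) (bool, error) {
	l.mu.Lock()
	defer l.mu.Unlock()
	if e, ok := l.entries[key]; ok {
		switch {
		case e.status == "completed":
			return false, nil // duplicate delivery of finished work
		case e.status == "running" && time.Since(e.updatedAt) < staleAfter:
			return false, nil // another worker is (probably) still on it
		}
		// "failed" or a stale "running" entry: fall through and claim it.
	}
	l.entries[key] = ledgerEntry{status: "running", updatedAt: time.Now()}
	return true, nil
}

func (l *memLedger) GetStatus(_ context.Context, key string) (string, string, time.Time, error) {
	l.mu.Lock()
	defer l.mu.Unlock()
	e := l.entries[key]
	return e.status, e.reason, e.updatedAt, nil
}

func (l *memLedger) MarkComplete(_ context.Context, key, reason string) error {
	l.mu.Lock()
	defer l.mu.Unlock()
	l.entries[key] = ledgerEntry{status: "completed", reason: reason, updatedAt: time.Now()}
	return nil
}

func (l *memLedger) MarkFailed(_ context.Context, key, reason string) error {
	l.mu.Lock()
	defer l.mu.Unlock()
	l.entries[key] = ledgerEntry{status: "failed", reason: reason, updatedAt: time.Now()}
	return nil
}
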
+func (w *ProvisionWorker) reclaimLoop(ctx context.Context, sem chan struct{}) { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + msgs, err := w.taskQueue.ReclaimStale(ctx, provisionQueue, provisionGroup, w.consumerName, provisionReclaimMs, provisionReclaimN) + if err != nil { + w.logger.Warn("provision reclaim error", "error", err) + continue + } + for _, m := range msgs { + var job domain.ProvisionJob + if err := json.Unmarshal([]byte(m.Payload), &job); err != nil { + w.logger.Error("failed to unmarshal reclaimed provision job", + "msg_id", m.ID, "error", err) + w.ackWithLog(ctx, m.ID, "reclaimed provision poison message") + continue + } + w.logger.Info("reclaimed stale provision job", + "instance_id", job.InstanceID, "msg_id", m.ID) + + m := m // capture loop variable + sem <- struct{}{} + go func() { + defer func() { <-sem }() + w.processJob(ctx, &m, job) + }() + } + } + } +} + +func (w *ProvisionWorker) ackWithLog(ctx context.Context, messageID string, reason string) { + if err := w.taskQueue.Ack(ctx, provisionQueue, provisionGroup, messageID); err != nil { + w.logger.Warn("failed to ack provision job", + "msg_id", messageID, "reason", reason, "error", err) + } +} + +func (w *ProvisionWorker) nackWithLog(ctx context.Context, messageID string, reason string) { + if err := w.taskQueue.Nack(ctx, provisionQueue, provisionGroup, messageID); err != nil { + w.logger.Warn("failed to nack provision job", + "msg_id", messageID, "reason", reason, "error", err) } } diff --git a/internal/workers/provision_worker_test.go b/internal/workers/provision_worker_test.go index 8cdd442e2..9d6397df7 100644 --- a/internal/workers/provision_worker_test.go +++ b/internal/workers/provision_worker_test.go @@ -18,28 +18,53 @@ import ( "github.com/stretchr/testify/assert" ) -type fakeTaskQueue struct { - messages []string - errors []error // To simulate dequeue errors +// fakeDurableQueue implements ports.DurableTaskQueue for testing. 
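
The fake defined next stands in for whatever adapter backs ports.DurableTaskQueue in production; that adapter is not part of this diff. For reference only, the four operations the workers rely on map naturally onto Redis Streams consumer groups, roughly as in the go-redis v9 sketch below (the library version, the single "payload" field per entry, and the type name are assumptions, and imports are omitted).

// streamQueue is an illustrative Redis Streams implementation of the queue
// operations used by Run and reclaimLoop. Not the repository's adapter.
type streamQueue struct {
	rdb *redis.Client // github.com/redis/go-redis/v9
}

// EnsureGroup creates the stream and consumer group if they do not exist yet.
func (q *streamQueue) EnsureGroup(ctx context.Context, stream, group string) error {
	err := q.rdb.XGroupCreateMkStream(ctx, stream, group, "$").Err()
	if err != nil && strings.Contains(err.Error(), "BUSYGROUP") {
		return nil // group already exists
	}
	return err
}

// Receive blocks briefly for one new entry; (nil, nil) means "nothing yet".
// The entry stays in the group's pending-entries list until Ack.
func (q *streamQueue) Receive(ctx context.Context, stream, group, consumer string) (*ports.DurableMessage, error) {
	res, err := q.rdb.XReadGroup(ctx, &redis.XReadGroupArgs{
		Group:    group,
		Consumer: consumer,
		Streams:  []string{stream, ">"},
		Count:    1,
		Block:    time.Second,
	}).Result()
	if errors.Is(err, redis.Nil) {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	if len(res) == 0 || len(res[0].Messages) == 0 {
		return nil, nil
	}
	m := res[0].Messages[0]
	payload, _ := m.Values["payload"].(string)
	return &ports.DurableMessage{ID: m.ID, Queue: stream, Payload: payload}, nil
}

// Ack removes the entry from the pending list; it will never be redelivered.
func (q *streamQueue) Ack(ctx context.Context, stream, group, id string) error {
	return q.rdb.XAck(ctx, stream, group, id).Err()
}

// ReclaimStale transfers entries that have sat unacknowledged in any consumer's
// pending list for at least minIdleMs over to this consumer, which is what lets
// a healthy worker pick up jobs from a crashed peer.
func (q *streamQueue) ReclaimStale(ctx context.Context, stream, group, consumer string, minIdleMs, count int64) ([]ports.DurableMessage, error) {
	msgs, _, err := q.rdb.XAutoClaim(ctx, &redis.XAutoClaimArgs{
		Stream:   stream,
		Group:    group,
		Consumer: consumer,
		MinIdle:  time.Duration(minIdleMs) * time.Millisecond,
		Start:    "0-0",
		Count:    count,
	}).Result()
	if err != nil {
		return nil, err
	}
	out := make([]ports.DurableMessage, 0, len(msgs))
	for _, m := range msgs {
		payload, _ := m.Values["payload"].(string)
		out = append(out, ports.DurableMessage{ID: m.ID, Queue: stream, Payload: payload})
	}
	return out, nil
}
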
+type fakeDurableQueue struct { + messages []*ports.DurableMessage + errors []error index int + acked []string + nacked []string } -func (f *fakeTaskQueue) Enqueue(ctx context.Context, queueName string, payload interface{}) error { +func (f *fakeDurableQueue) Enqueue(ctx context.Context, queueName string, payload interface{}) error { return nil } -func (f *fakeTaskQueue) Dequeue(ctx context.Context, queueName string) (string, error) { +func (f *fakeDurableQueue) Dequeue(ctx context.Context, queueName string) (string, error) { + return "", nil +} + +func (f *fakeDurableQueue) EnsureGroup(ctx context.Context, queueName, groupName string) error { + return nil +} + +func (f *fakeDurableQueue) Receive(ctx context.Context, queueName, groupName, consumerName string) (*ports.DurableMessage, error) { if f.index < len(f.errors) && f.errors[f.index] != nil { err := f.errors[f.index] f.index++ - return "", err + return nil, err } if f.index < len(f.messages) { msg := f.messages[f.index] f.index++ return msg, nil } - return "", nil + return nil, nil +} + +func (f *fakeDurableQueue) Ack(ctx context.Context, queueName, groupName, messageID string) error { + f.acked = append(f.acked, messageID) + return nil +} + +func (f *fakeDurableQueue) Nack(ctx context.Context, queueName, groupName, messageID string) error { + f.nacked = append(f.nacked, messageID) + return nil +} + +func (f *fakeDurableQueue) ReclaimStale(ctx context.Context, queueName, groupName, consumerName string, minIdleMs int64, count int64) ([]ports.DurableMessage, error) { + return nil, nil } // failingComputeBackend forces Provision to fail @@ -53,48 +78,57 @@ func (f *failingComputeBackend) LaunchInstanceWithOptions(ctx context.Context, o func TestProvisionWorkerRun(t *testing.T) { tests := []struct { - name string - message interface{} // string or struct - injectDequeErr bool - failProvision bool - wantLog string + name string + payload interface{} + poisonJSON bool + failProvision bool + wantLog string + wantAcked bool + wantNacked bool }{ { name: "success", - message: domain.ProvisionJob{ + payload: domain.ProvisionJob{ InstanceID: uuid.New(), UserID: uuid.New(), }, - wantLog: "successfully provisioned instance", + wantLog: "successfully provisioned instance", + wantAcked: true, }, { - name: "deserialize_error", - message: "{invalid-json}", - wantLog: "failed to unmarshal provision job", + name: "deserialize_error", + poisonJSON: true, + wantLog: "failed to unmarshal provision job", + wantAcked: true, // poison messages are acked to unblock the queue }, { - name: "provision_error", - message: domain.ProvisionJob{InstanceID: uuid.New(), UserID: uuid.New()}, + name: "provision_error", + payload: domain.ProvisionJob{ + InstanceID: uuid.New(), + UserID: uuid.New(), + }, failProvision: true, wantLog: "failed to provision instance", + wantNacked: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - var msgBytes []byte - switch v := tt.message.(type) { - case string: - msgBytes = []byte(v) - default: - msgBytes, _ = json.Marshal(v) + var payloadStr string + if tt.poisonJSON { + payloadStr = "{invalid-json}" + } else { + data, _ := json.Marshal(tt.payload) + payloadStr = string(data) } - fq := &fakeTaskQueue{ - messages: []string{string(msgBytes)}, + fq := &fakeDurableQueue{ + messages: []*ports.DurableMessage{ + {ID: "1-0", Payload: payloadStr, Queue: provisionQueue}, + }, } - // Compute backend var compute ports.ComputeBackend = &noop.NoopComputeBackend{} if tt.failProvision { compute = &failingComputeBackend{} @@ 
-116,7 +150,7 @@ func TestProvisionWorkerRun(t *testing.T) { var buf bytes.Buffer logger := slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelDebug})) - worker := NewProvisionWorker(instSvc, fq, logger) + worker := NewProvisionWorker(instSvc, fq, nil, logger) ctx, cancel := context.WithCancel(context.Background()) var wg sync.WaitGroup @@ -124,20 +158,29 @@ func TestProvisionWorkerRun(t *testing.T) { go worker.Run(ctx, &wg) - time.Sleep(50 * time.Millisecond) + // Give worker time to process + time.Sleep(200 * time.Millisecond) cancel() wg.Wait() assert.Contains(t, buf.String(), tt.wantLog) + if tt.wantAcked { + assert.NotEmpty(t, fq.acked, "expected message to be acked") + assert.Empty(t, fq.nacked, "did not expect message to be nacked when acked") + } else if tt.wantNacked { + assert.NotEmpty(t, fq.nacked, "expected message to be nacked") + assert.Empty(t, fq.acked, "did not expect message to be acked when nacked") + } else { + assert.Empty(t, fq.acked, "expected no ack") + assert.Empty(t, fq.nacked, "expected no nack") + } }) } } -func TestProvisionWorkerRunDequeueError(t *testing.T) { - // Test that worker continues on queue error - fq := &fakeTaskQueue{ - messages: []string{}, - errors: []error{errors.New("redis connection failed")}, +func TestProvisionWorkerRunReceiveError(t *testing.T) { + fq := &fakeDurableQueue{ + errors: []error{errors.New("redis connection failed")}, } instSvc := services.NewInstanceService(services.InstanceServiceParams{ @@ -156,7 +199,7 @@ func TestProvisionWorkerRunDequeueError(t *testing.T) { var buf bytes.Buffer logger := slog.New(slog.NewTextHandler(&buf, nil)) - worker := NewProvisionWorker(instSvc, fq, logger) + worker := NewProvisionWorker(instSvc, fq, nil, logger) ctx, cancel := context.WithCancel(context.Background()) var wg sync.WaitGroup @@ -166,5 +209,6 @@ func TestProvisionWorkerRunDequeueError(t *testing.T) { time.Sleep(50 * time.Millisecond) cancel() wg.Wait() - // No specific log to check as it just continues, but we ensure no panic and coverage hits error path + + assert.Contains(t, buf.String(), "failed to receive provision job") }
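
The producer side is not shown in this diff. To round out the at-least-once flow these tests exercise, a plausible enqueue call from the service that requests provisioning might look like the sketch below; it assumes Enqueue JSON-encodes its payload so that the worker's json.Unmarshal of msg.Payload round-trips, which the diff does not confirm.

// Hypothetical producer-side sketch; the function name and error wrapping are
// illustrative, not repository code.
func enqueueProvision(ctx context.Context, q ports.DurableTaskQueue, job domain.ProvisionJob) error {
	// The worker consumes from "provision_queue" (the provisionQueue constant
	// above) and decodes each payload into a domain.ProvisionJob.
	if err := q.Enqueue(ctx, "provision_queue", job); err != nil {
		return fmt.Errorf("enqueue provision job for instance %s: %w", job.InstanceID, err)
	}
	return nil
}
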