diff --git a/CLAUDE.md b/CLAUDE.md index d1c6140c..07064c15 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -142,6 +142,7 @@ Generated proto files are committed. When modifying `.proto` files: - **Directories**: singular (`mock/`, `entity/`, not `mocks/`, `entities/`) - **Files**: `{method}.go`, `{entity}.go`, `{file}_test.go`, `BUILD.bazel` - **Proto files**: `{service}.proto` +- **README files**: Do not duplicate interface or type definitions as code blocks in READMEs. Describe behavior in prose and let readers navigate to the source. Only include code samples when explicitly instructed. ### Makefile diff --git a/core/consumer/README.md b/core/consumer/README.md index ad8cea0c..a1d72329 100644 --- a/core/consumer/README.md +++ b/core/consumer/README.md @@ -2,6 +2,22 @@ The consumer package orchestrates queue message processing. It manages subscription lifecycle, message consumption, ack/nack, and graceful shutdown. +## Architecture + +``` +Consumer + ├── Controller A (topic: "request") + │ └── consumeLoop + │ ├── processPartition("part-1") ← serial per partition + │ ├── processPartition("part-2") + │ └── processPartition("part-3") + └── Controller B (topic: "build") + └── consumeLoop + └── processPartition("part-1") +``` + +The consumer spawns one `consumeLoop` goroutine per controller. Each `consumeLoop` dispatches deliveries to per-partition goroutines, preserving ordering within each partition while processing different partitions in parallel. + ## Interfaces ### Consumer @@ -9,7 +25,11 @@ The consumer package orchestrates queue message processing. It manages subscript The top-level orchestrator. Register controllers, start consuming, and stop gracefully. ```go -c := consumer.New(logger, scope, queue, "worker-hostname") +registry, _ := consumer.NewTopicRegistry([]consumer.TopicConfig{ + {Key: consumer.TopicKeyRequest, Name: "request", Queue: q, Subscription: subConfig}, +}) + +c := consumer.New(logger, scope, registry) c.Register(myController) c.Start(ctx) @@ -28,15 +48,37 @@ Business logic for processing queue messages. Implement this interface to handle type Controller interface { Process(ctx context.Context, delivery Delivery) error Name() string - Topic() string + TopicKey() TopicKey ConsumerGroup() string - SubscriptionConfig(subscriberName string) queue.SubscriptionConfig } ``` ### Delivery -A restricted view of a queue delivery exposed to controllers. Hides Ack/Nack (handled automatically by Consumer) while exposing message data and `ExtendVisibilityTimeout`. +A restricted view of a queue delivery exposed to controllers. Hides Ack/Nack/Reject (handled automatically by Consumer) while exposing message data and `ExtendVisibilityTimeout`. + +## TopicRegistry + +The `TopicRegistry` maps topic keys to queue backends, topic names, and subscription configs. This decouples controllers from infrastructure wiring. + +```go +registry, _ := consumer.NewTopicRegistry([]consumer.TopicConfig{ + { + Key: consumer.TopicKeyRequest, + Name: "request", + Queue: q, + Subscription: extqueue.DefaultSubscriptionConfig("worker-1", "orchestrator"), + }, + { + Key: consumer.TopicKeyBuild, + Name: "build", + Queue: q, + // No Subscription — publish-only topic + }, +}) +``` + +**Topic keys** are fixed identifiers for pipeline stages (e.g., `TopicKeyRequest`, `TopicKeyBuild`). The actual queue topic name is configured separately, so library consumers can use their own naming conventions. ## Error Handling @@ -46,10 +88,26 @@ Controllers signal processing outcome via the return value of `Process()`: - **`return errs.NewRetryableError(err)`** — retryable failure, message is nacked for retry. - **`return err`** — non-retryable error (e.g. poison pill), message is rejected and removed from the queue to prevent infinite retry loops. +```go +func (c *MyController) Process(ctx context.Context, delivery consumer.Delivery) error { + msg := delivery.Message() + + result, err := c.service.Process(ctx, msg.Payload) + if err != nil { + if isTransient(err) { + return errs.NewRetryableError(err) // nack → retry + } + return err // reject → DLQ + } + + return nil // ack → done +} +``` + ## Lifecycle 1. **Register** controllers before starting. -2. **Start** subscribes to all topics and spawns consume loops. -3. **Stop** cancels all subscriptions and waits for goroutines to finish (with timeout). +2. **Start** subscribes to all topics and spawns consume loops. Startup is atomic — if any subscription fails, all started subscriptions are cleaned up. +3. **Stop** cancels all subscriptions and waits for goroutines to finish (with timeout budget split across controllers). Once stopped, the consumer cannot be restarted — `Register()` and `Start()` return errors. diff --git a/core/consumer/consumer.go b/core/consumer/consumer.go index 630b67bf..c35ecf5b 100644 --- a/core/consumer/consumer.go +++ b/core/consumer/consumer.go @@ -12,6 +12,12 @@ import ( "go.uber.org/zap" ) +const ( + // startupCleanupTimeoutMs is the timeout for cleaning up subscriptions when + // a controller fails to start during Start(). + startupCleanupTimeoutMs = 30000 +) + // Consumer orchestrates multiple queue consumers. It handles subscription lifecycle, // message consumption, ack/nack, and graceful shutdown for the entire pipeline. type Consumer interface { @@ -109,9 +115,9 @@ func (m *consumer) Start(ctx context.Context) error { for _, controller := range m.controllers { if err := m.subscribe(ctx, controller); err != nil { - // Cleanup any started controllers with short timeout (30 seconds). - // Ignore error since we're returning the subscribe error. - _ = m.unsubscribeAll(30000) + // Cleanup any started controllers. Ignore error since we're returning + // the subscribe error. + _ = m.unsubscribeAll(startupCleanupTimeoutMs) return fmt.Errorf("failed to start controller %s: %w", controller.Name(), err) } } @@ -165,7 +171,7 @@ func (m *consumer) subscribe(ctx context.Context, controller Controller) error { m.subscriptions[topicKey] = sub // Spawn consumption goroutine - go m.consumeLoop(controllerCtx, controller, deliveryChan, done) + go m.consumeLoop(controllerCtx, controller, deliveryChan, done, config.BatchSize) m.logger.Infow("controller started", "controller", controller.Name(), @@ -176,8 +182,29 @@ func (m *consumer) subscribe(ctx context.Context, controller Controller) error { return nil } -// consumeLoop processes deliveries for a controller, calling ack/nack based on controller result. -func (m *consumer) consumeLoop(ctx context.Context, controller Controller, deliveryChan <-chan queue.Delivery, done chan struct{}) { +// consumeLoop dispatches deliveries to per-partition worker goroutines. +// Each partition gets its own goroutine, so a slow message on one partition +// does not block other partitions. Per-partition ordering is preserved. +// +// Goroutine model: +// +// consumeLoop (this goroutine) ← reads from deliveryChan +// ├── processPartition("part-1") ← spawned lazily on first message +// ├── processPartition("part-2") +// └── processPartition("part-N") +// +// Shutdown sequence: +// 1. ctx is cancelled (by Stop or parent context) +// 2. consumeLoop exits the select loop and runs the deferred cleanup +// 3. All partition channels are closed, causing processPartition goroutines to +// drain remaining buffered messages and return (range loop ends) +// 4. wg.Wait() blocks until all partition goroutines have exited +// 5. close(done) signals to unsubscribeAll that this controller is fully stopped +// +// Any messages buffered in partition channels but not processed before ctx +// cancellation are safe to drop — the queue's visibility timeout will make +// them visible again for redelivery (at-least-once semantics). +func (m *consumer) consumeLoop(ctx context.Context, controller Controller, deliveryChan <-chan queue.Delivery, done chan struct{}, batchSize int) { defer close(done) topicKey := controller.TopicKey() @@ -192,6 +219,12 @@ func (m *consumer) consumeLoop(ctx context.Context, controller Controller, deliv "topic_key", topicKey, ) + // partitionChs maps partition keys to per-partition delivery channels. + // Each channel is created lazily on the first message for that partition + // and is never removed — partitions are stable for the lifetime of a subscription. + partitionChs := make(map[string]chan queue.Delivery) + var wg sync.WaitGroup + for { select { case <-ctx.Done(): @@ -199,6 +232,7 @@ func (m *consumer) consumeLoop(ctx context.Context, controller Controller, deliv "controller", controller.Name(), "topic_key", topicKey, ) + m.shutdownPartitions(partitionChs, &wg) return case delivery, ok := <-deliveryChan: @@ -207,10 +241,66 @@ func (m *consumer) consumeLoop(ctx context.Context, controller Controller, deliv "controller", controller.Name(), "topic_key", topicKey, ) + m.shutdownPartitions(partitionChs, &wg) return } - m.processDelivery(ctx, controller, delivery, controllerScope) + // Route delivery to its partition's channel, creating the channel + // and spawning a processPartition goroutine if this is the first + // message for that partition. + partitionKey := delivery.Message().PartitionKey + ch, exists := partitionChs[partitionKey] + if !exists { + ch = make(chan queue.Delivery, batchSize) + partitionChs[partitionKey] = ch + wg.Add(1) + go func(pCh <-chan queue.Delivery) { + defer wg.Done() + m.processPartition(ctx, controller, pCh, controllerScope) + }(ch) + } + + // Send to the partition channel. If ctx is cancelled while the + // channel buffer is full, we exit — the undelivered message will + // be retried after visibility timeout. + select { + case ch <- delivery: + case <-ctx.Done(): + m.shutdownPartitions(partitionChs, &wg) + return + } + } + } +} + +// shutdownPartitions closes all partition channels to signal processPartition +// goroutines to exit, then waits for them to finish draining. +func (m *consumer) shutdownPartitions(partitionChs map[string]chan queue.Delivery, wg *sync.WaitGroup) { + for _, ch := range partitionChs { + close(ch) + } + wg.Wait() +} + +// processPartition drains a per-partition channel and processes deliveries serially. +// It runs in its own goroutine (one per partition key). Deliveries within a partition +// are processed in order — the next delivery is not started until the current one +// completes (ack/nack/reject). +// +// The loop exits when either: +// - deliveryCh is closed (consumeLoop cleanup) +// - ctx is cancelled (graceful shutdown) +// +// On context cancellation, the current delivery being read from the channel is +// dropped without processing. This is safe because the queue's visibility timeout +// ensures unprocessed messages are redelivered. +func (m *consumer) processPartition(ctx context.Context, controller Controller, deliveryCh <-chan queue.Delivery, scope tally.Scope) { + for delivery := range deliveryCh { + select { + case <-ctx.Done(): + return + default: + m.processDelivery(ctx, controller, delivery, scope) } } } @@ -370,7 +460,13 @@ func (m *consumer) Stop(timeoutMs int64) error { return err } -// unsubscribeAll stops all active controllers (must be called with lock held). +// unsubscribeAll cancels all subscription contexts and waits for their consumeLoop +// goroutines to exit. +// +// The timeout budget is shared across all subscriptions — each subscription gets +// the remaining time after the previous one finishes. This ensures Stop() returns +// within the caller's specified timeoutMs even if some controllers are slow to drain. +// // timeoutMs is the maximum time in milliseconds to wait for all controllers to stop. // Returns error on timeout, nil on success. func (m *consumer) unsubscribeAll(timeoutMs int64) error { diff --git a/core/consumer/consumer_test.go b/core/consumer/consumer_test.go index d8188165..b7e5634c 100644 --- a/core/consumer/consumer_test.go +++ b/core/consumer/consumer_test.go @@ -4,6 +4,8 @@ import ( "context" "fmt" "strings" + "sync" + "sync/atomic" "testing" "time" @@ -586,3 +588,211 @@ func TestConsumer_ErrorMetrics(t *testing.T) { err = c.Stop(30000) require.NoError(t, err) } + +// TestConsumer_PerPartitionProcessing verifies that a slow message on partition A +// does not block partition B from being processed. +func TestConsumer_PerPartitionProcessing(t *testing.T) { + ctrl := gomock.NewController(t) + logger := zaptest.NewLogger(t).Sugar() + + deliveryChan := make(chan extqueue.Delivery, 10) + mockSub := queuemock.NewMockSubscriber(ctrl) + mockSub.EXPECT().Subscribe(gomock.Any(), gomock.Any(), gomock.Any()).Return(deliveryChan, nil) + + mockQ := queuemock.NewMockQueue(ctrl) + mockQ.EXPECT().Subscriber().Return(mockSub) + + reg := newRegistry(t, mockQ, consumer.TopicKeyRequest, "test-group") + + c := consumer.New(logger, tally.NoopScope, reg) + + // Track processing by partition + partBDone := make(chan struct{}) + partABlocking := make(chan struct{}) + var partBProcessed atomic.Bool + + handler := consumermock.NewMockController(ctrl) + setupController(handler, "test-handler", consumer.TopicKeyRequest, "test-group", + func(ctx context.Context, delivery consumer.Delivery) error { + pk := delivery.Message().PartitionKey + if pk == "partition-a" { + // Signal that partition A is blocking + close(partABlocking) + // Block until test is done + <-ctx.Done() + return nil + } + // Partition B processes immediately + partBProcessed.Store(true) + close(partBDone) + return nil + }, + ) + + err := c.Register(handler) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + err = c.Start(ctx) + require.NoError(t, err) + + // Send message to partition A (will block in controller) + msgA := queue.NewMessage("msg-a", []byte("payload-a"), "partition-a", nil) + mockDelA := queuemock.NewMockDelivery(ctrl) + mockDelA.EXPECT().Message().Return(msgA).AnyTimes() + mockDelA.EXPECT().Attempt().Return(1).AnyTimes() + mockDelA.EXPECT().ReceivedAt().Return(time.Now().UnixMilli()).AnyTimes() + mockDelA.EXPECT().Metadata().Return(nil).AnyTimes() + mockDelA.EXPECT().DeliveryID().Return(msgA.ID).AnyTimes() + mockDelA.EXPECT().Ack(gomock.Any()).Return(nil).MaxTimes(1) + mockDelA.EXPECT().Nack(gomock.Any(), gomock.Any()).Return(nil).MaxTimes(1) + + deliveryChan <- mockDelA + + // Wait for partition A to start blocking + <-partABlocking + + // Send message to partition B (should process despite A being blocked) + msgB := queue.NewMessage("msg-b", []byte("payload-b"), "partition-b", nil) + mockDelB := queuemock.NewMockDelivery(ctrl) + mockDelB.EXPECT().Message().Return(msgB).AnyTimes() + mockDelB.EXPECT().Attempt().Return(1).AnyTimes() + mockDelB.EXPECT().ReceivedAt().Return(time.Now().UnixMilli()).AnyTimes() + mockDelB.EXPECT().Metadata().Return(nil).AnyTimes() + mockDelB.EXPECT().DeliveryID().Return(msgB.ID).AnyTimes() + mockDelB.EXPECT().Ack(gomock.Any()).Return(nil).MaxTimes(1) + + deliveryChan <- mockDelB + + // Partition B should be processed (test timeout handles hangs) + <-partBDone + assert.True(t, partBProcessed.Load(), "partition B should have been processed") + + err = c.Stop(30000) + require.NoError(t, err) +} + +// TestConsumer_PartitionOrdering verifies that messages within a single partition +// are processed in order. +func TestConsumer_PartitionOrdering(t *testing.T) { + ctrl := gomock.NewController(t) + logger := zaptest.NewLogger(t).Sugar() + + deliveryChan := make(chan extqueue.Delivery, 10) + mockSub := queuemock.NewMockSubscriber(ctrl) + mockSub.EXPECT().Subscribe(gomock.Any(), gomock.Any(), gomock.Any()).Return(deliveryChan, nil) + + mockQ := queuemock.NewMockQueue(ctrl) + mockQ.EXPECT().Subscriber().Return(mockSub) + + reg := newRegistry(t, mockQ, consumer.TopicKeyRequest, "test-group") + + c := consumer.New(logger, tally.NoopScope, reg) + + // Mutex + shared slice captures processing order for assertion; + // a channel would only signal completion, not record the sequence. + var mu sync.Mutex + var order []string + allDone := make(chan struct{}) + + handler := consumermock.NewMockController(ctrl) + setupController(handler, "test-handler", consumer.TopicKeyRequest, "test-group", + func(ctx context.Context, delivery consumer.Delivery) error { + mu.Lock() + order = append(order, delivery.Message().ID) + if len(order) == 3 { + close(allDone) + } + mu.Unlock() + return nil + }, + ) + + err := c.Register(handler) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + err = c.Start(ctx) + require.NoError(t, err) + + // Send 3 messages to the same partition + for i, id := range []string{"msg-1", "msg-2", "msg-3"} { + msg := queue.NewMessage(id, []byte("payload"), "same-partition", nil) + mockDel := queuemock.NewMockDelivery(ctrl) + mockDel.EXPECT().Message().Return(msg).AnyTimes() + mockDel.EXPECT().Attempt().Return(1).AnyTimes() + mockDel.EXPECT().ReceivedAt().Return(time.Now().UnixMilli()).AnyTimes() + mockDel.EXPECT().Metadata().Return(nil).AnyTimes() + mockDel.EXPECT().DeliveryID().Return(fmt.Sprintf("del-%d", i)).AnyTimes() + mockDel.EXPECT().Ack(gomock.Any()).Return(nil).MaxTimes(1) + + deliveryChan <- mockDel + } + + // Wait for all messages (test timeout handles hangs) + <-allDone + mu.Lock() + assert.Equal(t, []string{"msg-1", "msg-2", "msg-3"}, order, "messages should be processed in order within a partition") + mu.Unlock() + + err = c.Stop(30000) + require.NoError(t, err) +} + +// TestConsumer_PartitionWorkerCleanup verifies that all partition goroutines +// exit cleanly on Stop(). +func TestConsumer_PartitionWorkerCleanup(t *testing.T) { + ctrl := gomock.NewController(t) + logger := zaptest.NewLogger(t).Sugar() + + deliveryChan := make(chan extqueue.Delivery, 10) + mockSub := queuemock.NewMockSubscriber(ctrl) + mockSub.EXPECT().Subscribe(gomock.Any(), gomock.Any(), gomock.Any()).Return(deliveryChan, nil) + + mockQ := queuemock.NewMockQueue(ctrl) + mockQ.EXPECT().Subscriber().Return(mockSub) + + reg := newRegistry(t, mockQ, consumer.TopicKeyRequest, "test-group") + + c := consumer.New(logger, tally.NoopScope, reg) + + processedCount := int64(0) + + handler := consumermock.NewMockController(ctrl) + setupController(handler, "test-handler", consumer.TopicKeyRequest, "test-group", + func(ctx context.Context, delivery consumer.Delivery) error { + atomic.AddInt64(&processedCount, 1) + return nil + }, + ) + + err := c.Register(handler) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + err = c.Start(ctx) + require.NoError(t, err) + + // Send messages to multiple partitions to spawn multiple goroutines + for i := 0; i < 5; i++ { + pk := fmt.Sprintf("partition-%d", i) + msg := queue.NewMessage(fmt.Sprintf("msg-%d", i), []byte("payload"), pk, nil) + mockDel := queuemock.NewMockDelivery(ctrl) + done := setupDelivery(mockDel, msg, nil, nil) + deliveryChan <- mockDel + <-done + } + + // All messages should have been processed + assert.Equal(t, int64(5), atomic.LoadInt64(&processedCount)) + + // Stop should complete cleanly (no goroutine leaks or deadlocks) + err = c.Stop(30000) + require.NoError(t, err) +} diff --git a/extension/queue/README.md b/extension/queue/README.md index 17ed31d6..26685670 100644 --- a/extension/queue/README.md +++ b/extension/queue/README.md @@ -18,11 +18,11 @@ type Publisher interface { ``` ### Subscriber -Consumes messages from topics. +Consumes messages from topics with per-subscription configuration. ```go type Subscriber interface { - Subscribe(ctx context.Context, topic string) (<-chan Delivery, error) + Subscribe(ctx context.Context, topic string, config SubscriptionConfig) (<-chan Delivery, error) Close() error } ``` @@ -35,6 +35,7 @@ type Delivery interface { Message() queue.Message Ack(ctx context.Context) error Nack(ctx context.Context, requeueAfterMillis int64) error + Reject(ctx context.Context, reason string) error ExtendVisibilityTimeout(ctx context.Context, durationMillis int64) error DeliveryID() string Attempt() int @@ -43,6 +44,26 @@ type Delivery interface { } ``` +- **Ack** — message processed successfully, remove from queue +- **Nack** — processing failed, requeue for retry after delay +- **Reject** — poison pill, move to DLQ (or ack if DLQ disabled) +- **ExtendVisibilityTimeout** — extend processing window for long-running work + +### SubscriptionConfig + +Per-subscription configuration for polling, batching, leasing, retries, and DLQ: + +```go +cfg := extqueue.DefaultSubscriptionConfig("worker-1", "consumer-group") +cfg.PollIntervalMs = 50 +cfg.BatchSize = 20 +cfg.VisibilityTimeoutMs = 60000 +cfg.Retry.MaxAttempts = 3 +cfg.DLQ.Enabled = true +``` + +See `subscription_config.go` for all fields and defaults. + ## Usage ```go @@ -51,14 +72,18 @@ defer q.Close() // Publish pub := q.Publisher() -msg := queue.NewMessage("id", []byte("payload")) +msg := queue.NewMessage("id", []byte("payload"), "partition-key", nil) pub.Publish(ctx, "topic", msg) // Subscribe sub := q.Subscriber() -deliveries, _ := sub.Subscribe(ctx, "topic") +cfg := extqueue.DefaultSubscriptionConfig("worker-1", "consumer-group") +deliveries, _ := sub.Subscribe(ctx, "topic", cfg) for delivery := range deliveries { - process(delivery.Message().Payload) + if err := process(delivery.Message().Payload); err != nil { + delivery.Nack(ctx, 0) // Retry + continue + } delivery.Ack(ctx) } ``` @@ -69,3 +94,4 @@ for delivery := range deliveries { 2. Implement `Queue`, `Publisher`, `Subscriber`, `Delivery` interfaces 3. Map `queue.Message` to backend format +See `extension/queue/mysql/` for the reference implementation. diff --git a/extension/queue/mysql/README.md b/extension/queue/mysql/README.md index 6b0af6e9..e5dd22f8 100644 --- a/extension/queue/mysql/README.md +++ b/extension/queue/mysql/README.md @@ -4,10 +4,11 @@ MySQL-based distributed queue with partition leasing, visibility timeout, and at ## Key Features -- **Partition leasing** - Workers coordinate via database leases with automatic failover -- **Visibility timeout** - Messages retry automatically if worker crashes -- **At-least-once delivery** - Offset tracking for crash recovery -- **Dead letter queue** - Failed messages moved to DLQ after max retries +- **Partition leasing** — workers coordinate via database leases with automatic failover +- **Per-partition workers** — each leased partition gets its own goroutine for isolation +- **Visibility timeout** — messages retry automatically if worker crashes +- **At-least-once delivery** — offset tracking for crash recovery +- **Dead letter queue** — failed messages moved to DLQ after max retries ## Quick Start @@ -85,9 +86,52 @@ deliveryCh, _ := q.Subscriber().Subscribe(ctx, "my-topic", subConfig) - `Retry.InitialBackoffMs`: Initial retry backoff delay (milliseconds) - `Retry.MaxBackoffMs`: Maximum retry backoff delay (milliseconds) - `Retry.BackoffMultiplier`: Multiplier for exponential backoff -- `DLQ.TopicSuffix`: Suffix appended to topic name for DLQ (e.g., "orders" → "orders_dlq") +- `DLQ.TopicSuffix`: Suffix appended to topic name for DLQ (e.g., "orders" -> "orders_dlq") + +## Architecture + +### Goroutine Model + +Each subscription has a **supervisor goroutine** (`managePartitions`) that: +1. Discovers partitions from the messages table +2. Acquires and renews partition leases +3. Reconciles **per-partition worker goroutines** based on current leases + +Each partition worker goroutine polls and delivers messages for its partition independently. This provides fault isolation — a slow or blocked partition does not affect other partitions. + +``` +Subscribe() + └── managePartitions (supervisor) + ├── partitionWorker("part-1") ← polls & delivers + ├── partitionWorker("part-2") ← polls & delivers + └── partitionWorker("part-3") ← polls & delivers ``` +### Shutdown Sequence + +Shutdown uses two `sync.WaitGroup`s to ensure correctness: +- `wg` tracks the supervisor goroutine (`managePartitions`) +- `workerWg` tracks all partition worker goroutines + +When `Close()` is called: +1. Subscription context is cancelled +2. `managePartitions` calls `stopAllWorkers` — cancels each worker and waits up to 5s per worker +3. Partition leases are released +4. `workerWg.Wait()` blocks until all workers have fully exited +5. `deliveryCh` is closed — safe because no workers can send after step 4 +6. `managePartitions` returns, `wg.Done()` fires +7. `Close()` returns + +The `workerWg.Wait()` step prevents a race where a slow worker (blocked on I/O past the 5s timeout) could send on a closed channel. + +### Worker Stop Behavior + +When a partition worker is stopped (lease lost or shutdown): +- The worker is immediately removed from the workers map and its context is cancelled +- The caller waits up to 5s for the worker to confirm exit (logging a warning on timeout) +- `workerWg` tracks the worker regardless, so `Close()` always waits for full exit +- If the worker times out, reconciliation is free to start a replacement — any brief overlap is harmless with at-least-once delivery semantics + ## How It Works **Partition Leasing:** diff --git a/extension/queue/mysql/subscriber.go b/extension/queue/mysql/subscriber.go index 67bbfe2a..7c502ea5 100644 --- a/extension/queue/mysql/subscriber.go +++ b/extension/queue/mysql/subscriber.go @@ -14,6 +14,20 @@ import ( extqueue "github.com/uber/submitqueue/extension/queue" ) +const ( + // workerStopTimeout is the maximum time to wait for a partition worker to + // exit after its context is cancelled. + workerStopTimeout = 30 * time.Second + + // leaseReleaseTimeout is the timeout for releasing partition leases during + // shutdown. Uses a fresh context since the subscription context is cancelled. + leaseReleaseTimeout = 30 * time.Second + + // subscriptionShutdownTimeout is the maximum time to wait for the + // managePartitions goroutine to finish during Close(). + subscriptionShutdownTimeout = 30 * time.Second +) + type subscriber struct { logger *zap.SugaredLogger metrics tally.Scope @@ -33,7 +47,39 @@ type subscription struct { config extqueue.SubscriptionConfig deliveryCh chan extqueue.Delivery cancelFunc context.CancelFunc - wg sync.WaitGroup + + // wg tracks the single managePartitions supervisor goroutine. + // Close() waits on this to know the entire subscription is shut down. + wg sync.WaitGroup + + // workerWg tracks all partition worker goroutines independently of wg. + // During shutdown, managePartitions waits on workerWg before closing + // deliveryCh to guarantee no worker can send on a closed channel. + workerWg sync.WaitGroup + + // workers maps partition keys to their active worker goroutines. + // Only accessed by the managePartitions goroutine for reads/reconciliation, + // but mutations are protected by workersMu since stopPartitionWorker may + // be called during shutdown. + workers map[string]*partitionWorker + workersMu sync.Mutex +} + +// partitionWorker handles polling and delivering messages for a single partition. +// Each worker runs in its own goroutine, polling the DB on a ticker and sending +// deliveries to the shared deliveryCh. +type partitionWorker struct { + partitionKey string + sub *subscription + subscriber *subscriber + // cancelFunc cancels this worker's context, causing run() to exit. + cancelFunc context.CancelFunc + // done is closed when run() returns, signaling the worker has fully stopped. + done chan struct{} + // offsetInitialized tracks whether the offset has been initialized for this + // partition. Set once on the first successful poll, avoiding repeated + // initialization calls on every tick. + offsetInitialized bool } // sqlDelivery implements extqueue.Delivery for SQL queue @@ -302,6 +348,7 @@ func (s *subscriber) Subscribe(ctx context.Context, topic string, config extqueu config: config, deliveryCh: make(chan extqueue.Delivery, config.BatchSize*2), cancelFunc: cancel, + workers: make(map[string]*partitionWorker), } s.subscriptions[subKey] = sub @@ -309,7 +356,9 @@ func (s *subscriber) Subscribe(ctx context.Context, topic string, config extqueu // Track active subscription s.metrics.Tagged(map[string]string{"topic": topic}).Gauge("active_subscriptions").Update(1) - // Start partition leasing and polling goroutine + // Start the supervisor goroutine. It will discover partitions, acquire + // leases, and spawn per-partition worker goroutines. The supervisor runs + // until the subscription context is cancelled (via Close or explicit cancel). sub.wg.Add(1) go s.managePartitions(subCtx, sub) @@ -317,13 +366,32 @@ func (s *subscriber) Subscribe(ctx context.Context, topic string, config extqueu return sub.deliveryCh, nil } -// managePartitions discovers partitions, acquires leases, and polls messages +// managePartitions is the supervisor goroutine. It discovers partitions, reconciles +// workers, and renews leases. Each partition gets its own worker goroutine. +// +// There is exactly one managePartitions goroutine per subscription, started by +// Subscribe(). It is the only goroutine that calls reconcilePartitionWorkers, +// so no concurrent reconciliation can occur. +// +// Goroutine hierarchy: +// +// managePartitions (this goroutine) ← supervisor, tracked by sub.wg +// ├── partitionWorker("part-1") ← tracked by sub.workerWg +// ├── partitionWorker("part-2") +// └── partitionWorker("part-N") +// +// Shutdown sequence (triggered by ctx cancellation): +// 1. stopAllWorkers: cancels each worker's context and removes from map +// 2. releaseAllLeases: releases DB partition leases (fresh context, not cancelled) +// 3. workerWg.Wait(): blocks until all workers have fully exited — this ensures +// no worker can send on deliveryCh after step 4 +// 4. close(deliveryCh): safe because step 3 guarantees no senders remain +// 5. managePartitions returns → wg.Done() fires → Close() unblocks func (s *subscriber) managePartitions(ctx context.Context, sub *subscription) { defer sub.wg.Done() - defer close(sub.deliveryCh) - pollTicker := time.NewTicker(time.Duration(sub.config.PollIntervalMs) * time.Millisecond) - defer pollTicker.Stop() + discoveryTicker := time.NewTicker(time.Duration(sub.config.PollIntervalMs) * time.Millisecond) + defer discoveryTicker.Stop() leaseTicker := time.NewTicker(time.Duration(sub.config.LeaseRenewalIntervalMs) * time.Millisecond) defer leaseTicker.Stop() @@ -331,120 +399,213 @@ func (s *subscriber) managePartitions(ctx context.Context, sub *subscription) { for { select { case <-ctx.Done(): + s.stopAllWorkers(sub) // Release all leases on shutdown with a fresh context - // The passed context is already cancelled, so we create a new one with timeout - // to allow graceful lease release operations to complete - cleanupCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + cleanupCtx, cancel := context.WithTimeout(context.Background(), leaseReleaseTimeout) defer cancel() s.releaseAllLeases(cleanupCtx, sub) + // Wait for all workers to fully exit, then close channel + sub.workerWg.Wait() + close(sub.deliveryCh) return case <-leaseTicker.C: - // Renew existing leases s.renewLeases(ctx, sub) - case <-pollTicker.C: - // Fetch and deliver messages from leased partitions - s.pollLeasedPartitions(ctx, sub) + case <-discoveryTicker.C: + s.discoverAndReconcileWorkers(ctx, sub) } } } -// renewLeases renews leases for all partitions owned by this worker -func (s *subscriber) renewLeases(ctx context.Context, sub *subscription) { +// discoverAndReconcileWorkers discovers new partitions and reconciles workers. +func (s *subscriber) discoverAndReconcileWorkers(ctx context.Context, sub *subscription) { cfg := sub.config + + // Discover and try to acquire leases for new partitions + acquiredCount, err := s.leaseStore.DiscoverAndAcquirePartitions(ctx, sub.topic, cfg.SubscriberName, cfg.ConsumerGroup, cfg.LeaseDurationMs) + if err == nil && acquiredCount > 0 { + s.metrics.Tagged(map[string]string{"topic": sub.topic}).Counter("leases_acquired").Inc(int64(acquiredCount)) + } + + // Get currently leased partitions leasedPartitions, err := s.leaseStore.GetLeasedPartitions(ctx, sub.topic, cfg.SubscriberName, cfg.ConsumerGroup) if err != nil { - s.logger.Errorw("failed to get leased partitions for renewal", - "topic", sub.topic, - "error", err, - ) - // Error suppressed: lease renewal is best-effort. If we can't get leases, - // they will eventually expire and be reacquired by this or another worker. - // Failing the entire renewal cycle would be worse than skipping one iteration. - s.metrics.Tagged(map[string]string{"topic": sub.topic}).Counter("lease_renewal.get_partitions_errors").Inc(1) + s.logger.Errorw("failed to get leased partitions", "topic", sub.topic, "error", err) return } - for _, partitionKey := range leasedPartitions { - if err := s.leaseStore.RenewLease(ctx, sub.topic, partitionKey, cfg.SubscriberName, cfg.ConsumerGroup, cfg.LeaseDurationMs); err != nil { - s.logger.Warnw("failed to renew lease", - "topic", sub.topic, - "partition_key", partitionKey, - "error", err, - ) - // Error suppressed: Continue trying to renew other leases even if one fails. - // The partition will eventually expire and be re-acquired by this or another worker. - // Failing fast would prevent other partitions from being renewed. - s.metrics.Tagged(map[string]string{ - "topic": sub.topic, - "partition_key": partitionKey, - }).Counter("lease_renewal.renew_errors").Inc(1) + s.reconcilePartitionWorkers(ctx, sub, leasedPartitions) +} + +// reconcilePartitionWorkers diffs the current set of workers against the current +// set of leases and starts/stops workers to match. This is the core of the +// supervisor's control loop. +// +// Thread safety: only called from the single managePartitions goroutine, so the +// snapshot of workers read under the lock does not change between unlock and the +// subsequent start/stop calls. The lock is held briefly to read state, then +// released before blocking operations (stop may wait up to workerStopTimeout). +func (s *subscriber) reconcilePartitionWorkers(ctx context.Context, sub *subscription, currentLeases []string) { + leaseSet := make(map[string]struct{}, len(currentLeases)) + for _, pk := range currentLeases { + leaseSet[pk] = struct{}{} + } + + sub.workersMu.Lock() + + // Find workers to stop (no longer leased) + var toStop []string + for pk := range sub.workers { + if _, ok := leaseSet[pk]; !ok { + toStop = append(toStop, pk) + } + } + + // Find partitions to start (newly leased) + var toStart []string + for _, pk := range currentLeases { + if _, ok := sub.workers[pk]; !ok { + toStart = append(toStart, pk) } } + + sub.workersMu.Unlock() + + // Stop workers for partitions no longer leased + for _, pk := range toStop { + s.stopPartitionWorker(sub, pk) + } + + // Start workers for newly leased partitions + for _, pk := range toStart { + s.startPartitionWorker(ctx, sub, pk) + } } -// releaseAllLeases releases all leases for a topic -func (s *subscriber) releaseAllLeases(ctx context.Context, sub *subscription) { - cfg := sub.config - leasedPartitions, err := s.leaseStore.GetLeasedPartitions(ctx, sub.topic, cfg.SubscriberName, cfg.ConsumerGroup) - if err != nil { - s.logger.Errorw("failed to get leased partitions for release", - "topic", sub.topic, - "error", err, - ) +// startPartitionWorker creates and starts a worker goroutine for a partition. +// The worker is tracked in sub.workers (for reconciliation) and sub.workerWg +// (for shutdown synchronization). +func (s *subscriber) startPartitionWorker(ctx context.Context, sub *subscription, partitionKey string) { + workerCtx, cancel := context.WithCancel(ctx) + + w := &partitionWorker{ + partitionKey: partitionKey, + sub: sub, + subscriber: s, + cancelFunc: cancel, + done: make(chan struct{}), + } + + sub.workersMu.Lock() + sub.workers[partitionKey] = w + sub.workersMu.Unlock() + + sub.workerWg.Add(1) + go w.run(workerCtx) + + s.logger.Debugw("started partition worker", + "topic", sub.topic, + "partition_key", partitionKey, + ) +} + +// stopPartitionWorker cancels a worker's context and removes it from the workers +// map. The worker is removed immediately (before confirming exit) so that +// reconciliation can start a replacement if the lease is re-acquired. The old +// worker's context is cancelled, so its DB calls will fail and it will exit +// imminently. workerWg still tracks the old goroutine, so Close() blocks until +// it fully exits — preventing sends on a closed deliveryCh. +// +// The select with workerStopTimeout is purely for observability: if the worker +// takes longer than expected to exit, a warning is logged but no action is needed +// since workerWg handles the hard guarantee. +func (s *subscriber) stopPartitionWorker(sub *subscription, partitionKey string) { + sub.workersMu.Lock() + w, ok := sub.workers[partitionKey] + if !ok { + sub.workersMu.Unlock() return } + sub.workersMu.Unlock() - for _, partitionKey := range leasedPartitions { - if err := s.leaseStore.ReleaseLease(ctx, sub.topic, partitionKey, cfg.SubscriberName, cfg.ConsumerGroup); err != nil { - s.logger.Warnw("failed to release lease", - "topic", sub.topic, - "partition_key", partitionKey, - "error", err, - ) - // Continue trying to release other leases even if one fails - } + w.cancelFunc() + + // Always remove from map so reconcile can start a replacement if needed. + // The old worker's context is cancelled so it will exit imminently. + // workerWg still tracks it for shutdown — Close() won't return until it exits. + sub.workersMu.Lock() + delete(sub.workers, partitionKey) + sub.workersMu.Unlock() + + select { + case <-w.done: + s.logger.Debugw("stopped partition worker", + "topic", sub.topic, + "partition_key", partitionKey, + ) + case <-time.After(workerStopTimeout): + s.logger.Warnw("partition worker stop timeout, worker will drain in background", + "topic", sub.topic, + "partition_key", partitionKey, + ) } } -// pollLeasedPartitions fetches and delivers messages from all leased partitions -func (s *subscriber) pollLeasedPartitions(ctx context.Context, sub *subscription) { - cfg := sub.config - // Discover and try to acquire leases for new partitions - acquiredCount, err := s.leaseStore.DiscoverAndAcquirePartitions(ctx, sub.topic, cfg.SubscriberName, cfg.ConsumerGroup, cfg.LeaseDurationMs) - if err == nil && acquiredCount > 0 { - s.metrics.Tagged(map[string]string{"topic": sub.topic}).Counter("leases_acquired").Inc(int64(acquiredCount)) +// stopAllWorkers stops all partition workers for a subscription. +func (s *subscriber) stopAllWorkers(sub *subscription) { + sub.workersMu.Lock() + keys := make([]string, 0, len(sub.workers)) + for pk := range sub.workers { + keys = append(keys, pk) } + sub.workersMu.Unlock() - // Get currently leased partitions - leasedPartitions, err := s.leaseStore.GetLeasedPartitions(ctx, sub.topic, cfg.SubscriberName, cfg.ConsumerGroup) - if err != nil { - s.logger.Errorw("failed to get leased partitions", "topic", sub.topic, "error", err) - return + for _, pk := range keys { + s.stopPartitionWorker(sub, pk) } +} - // Poll each leased partition - for _, partitionKey := range leasedPartitions { - // Check if context was cancelled before processing next partition +// run is the per-partition goroutine loop. It polls the DB on a ticker and +// sends fetched messages to the shared deliveryCh. Each partition worker runs +// independently — a slow or blocked partition does not affect other partitions. +// +// Lifecycle: +// - Started by startPartitionWorker, tracked by sub.workerWg +// - Stopped when ctx is cancelled (lease lost, shutdown, or explicit stop) +// - Closing w.done signals stopPartitionWorker that the goroutine has exited +func (w *partitionWorker) run(ctx context.Context) { + defer close(w.done) + defer w.sub.workerWg.Done() + + pollTicker := time.NewTicker(time.Duration(w.sub.config.PollIntervalMs) * time.Millisecond) + defer pollTicker.Stop() + + for { select { case <-ctx.Done(): return - default: - s.fetchAndDeliverPartition(ctx, sub, partitionKey) + case <-pollTicker.C: + w.pollAndDeliver(ctx) } } } -// fetchAndDeliverPartition fetches messages from a specific partition and delivers them -func (s *subscriber) fetchAndDeliverPartition(ctx context.Context, sub *subscription, partitionKey string) { +// pollAndDeliver fetches messages from this worker's partition and delivers them. +func (w *partitionWorker) pollAndDeliver(ctx context.Context) { start := time.Now() + s := w.subscriber + sub := w.sub cfg := sub.config + partitionKey := w.partitionKey - // Initialize offset for this partition if needed - if err := s.offsetStore.Initialize(ctx, sub.topic, partitionKey, cfg.ConsumerGroup); err != nil { - s.logger.Errorw("offset initialization failure", "topic", sub.topic, "partition_key", partitionKey, "error", err) - return + // Initialize offset for this partition once per worker lifetime + if !w.offsetInitialized { + if err := s.offsetStore.Initialize(ctx, sub.topic, partitionKey, cfg.ConsumerGroup); err != nil { + s.logger.Errorw("offset initialization failure", "topic", sub.topic, "partition_key", partitionKey, "error", err) + return + } + w.offsetInitialized = true } // Get current offset for this partition @@ -586,7 +747,65 @@ func (s *subscriber) fetchAndDeliverPartition(ctx context.Context, sub *subscrip } } -// Close gracefully shuts down the subscriber +// renewLeases renews leases for all partitions owned by this worker +func (s *subscriber) renewLeases(ctx context.Context, sub *subscription) { + cfg := sub.config + leasedPartitions, err := s.leaseStore.GetLeasedPartitions(ctx, sub.topic, cfg.SubscriberName, cfg.ConsumerGroup) + if err != nil { + s.logger.Errorw("failed to get leased partitions for renewal", + "topic", sub.topic, + "error", err, + ) + s.metrics.Tagged(map[string]string{"topic": sub.topic}).Counter("lease_renewal.get_partitions_errors").Inc(1) + return + } + + for _, partitionKey := range leasedPartitions { + if err := s.leaseStore.RenewLease(ctx, sub.topic, partitionKey, cfg.SubscriberName, cfg.ConsumerGroup, cfg.LeaseDurationMs); err != nil { + s.logger.Warnw("failed to renew lease", + "topic", sub.topic, + "partition_key", partitionKey, + "error", err, + ) + s.metrics.Tagged(map[string]string{ + "topic": sub.topic, + "partition_key": partitionKey, + }).Counter("lease_renewal.renew_errors").Inc(1) + } + } +} + +// releaseAllLeases releases all leases for a topic +func (s *subscriber) releaseAllLeases(ctx context.Context, sub *subscription) { + cfg := sub.config + leasedPartitions, err := s.leaseStore.GetLeasedPartitions(ctx, sub.topic, cfg.SubscriberName, cfg.ConsumerGroup) + if err != nil { + s.logger.Errorw("failed to get leased partitions for release", + "topic", sub.topic, + "error", err, + ) + return + } + + for _, partitionKey := range leasedPartitions { + if err := s.leaseStore.ReleaseLease(ctx, sub.topic, partitionKey, cfg.SubscriberName, cfg.ConsumerGroup); err != nil { + s.logger.Warnw("failed to release lease", + "topic", sub.topic, + "partition_key", partitionKey, + "error", err, + ) + } + } +} + +// Close gracefully shuts down the subscriber and all its subscriptions. +// +// For each subscription: +// 1. Cancels the subscription context, triggering managePartitions shutdown +// 2. Wraps sub.wg.Wait() in a goroutine with subscriptionShutdownTimeout so +// Close() does not block indefinitely if a subscription hangs +// 3. managePartitions internally handles stopping workers and closing deliveryCh +// (see managePartitions shutdown sequence) func (s *subscriber) Close() error { s.mu.Lock() defer s.mu.Unlock() @@ -605,7 +824,10 @@ func (s *subscriber) Close() error { s.logger.Debugw("closing subscription", "topic", topic) sub.cancelFunc() - // Wait for goroutine to finish with timeout + // Wait for the managePartitions goroutine to finish. We wrap the + // blocking Wait in a goroutine so we can enforce a timeout — if + // managePartitions is stuck, we log a warning and move on rather + // than blocking Close() indefinitely. done := make(chan struct{}) go func() { sub.wg.Wait() @@ -615,7 +837,7 @@ func (s *subscriber) Close() error { select { case <-done: // Graceful shutdown completed - case <-time.After(30 * time.Second): + case <-time.After(subscriptionShutdownTimeout): s.logger.Warnw("subscription shutdown timeout", "topic", topic) } diff --git a/extension/queue/mysql/subscriber_test.go b/extension/queue/mysql/subscriber_test.go index a7fa2ba4..3948bc8b 100644 --- a/extension/queue/mysql/subscriber_test.go +++ b/extension/queue/mysql/subscriber_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -254,3 +255,202 @@ func TestSubscriber_Close(t *testing.T) { }) } } + +// TestSubscriber_ReconcilePartitionWorkers tests that workers are started/stopped +// based on lease changes. +func TestSubscriber_ReconcilePartitionWorkers(t *testing.T) { + tests := []struct { + name string + initialLeases []string + updatedLeases []string + }{ + { + name: "start workers for new leases", + initialLeases: []string{}, + updatedLeases: []string{"part-1", "part-2"}, + }, + { + name: "stop workers for lost leases", + initialLeases: []string{"part-1", "part-2"}, + updatedLeases: []string{"part-1"}, + }, + { + name: "no changes when leases unchanged", + initialLeases: []string{"part-1"}, + updatedLeases: []string{"part-1"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + + mockMessageStore := NewMockmessageStore(ctrl) + mockOffsetStore := NewMockoffsetStore(ctrl) + mockLeaseStore := NewMockpartitionLeaseStore(ctrl) + + s := NewSubscriber( + zaptest.NewLogger(t).Sugar(), + tally.NoopScope, + mockMessageStore, + mockOffsetStore, + mockLeaseStore, + ) + + // Allow offset initialization and fetch calls from workers + mockOffsetStore.EXPECT().Initialize(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + mockOffsetStore.EXPECT().GetAckedOffset(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(int64(0), nil).AnyTimes() + mockMessageStore.EXPECT().FetchByOffset(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sub := &subscription{ + topic: "test_topic", + config: testSubscriptionConfig(), + deliveryCh: make(chan extqueue.Delivery, 100), + workers: make(map[string]*partitionWorker), + } + + // Start initial workers + s.reconcilePartitionWorkers(ctx, sub, tt.initialLeases) + + sub.workersMu.Lock() + assert.Equal(t, len(tt.initialLeases), len(sub.workers)) + sub.workersMu.Unlock() + + // Reconcile with updated leases + s.reconcilePartitionWorkers(ctx, sub, tt.updatedLeases) + + sub.workersMu.Lock() + assert.Equal(t, len(tt.updatedLeases), len(sub.workers)) + for _, pk := range tt.updatedLeases { + assert.Contains(t, sub.workers, pk) + } + sub.workersMu.Unlock() + + // Cleanup: stop all workers + cancel() + s.stopAllWorkers(sub) + }) + } +} + +// TestSubscriber_PartitionWorkerPollAndDeliver verifies a partition worker delivers messages. +func TestSubscriber_PartitionWorkerPollAndDeliver(t *testing.T) { + ctrl := gomock.NewController(t) + + mockMessageStore := NewMockmessageStore(ctrl) + mockOffsetStore := NewMockoffsetStore(ctrl) + mockLeaseStore := NewMockpartitionLeaseStore(ctrl) + + s := NewSubscriber( + zaptest.NewLogger(t).Sugar(), + tally.NoopScope, + mockMessageStore, + mockOffsetStore, + mockLeaseStore, + ) + + cfg := testSubscriptionConfig() + deliveryCh := make(chan extqueue.Delivery, 10) + sub := &subscription{ + topic: "test_topic", + config: cfg, + deliveryCh: deliveryCh, + workers: make(map[string]*partitionWorker), + } + + ctx := context.Background() + + mockOffsetStore.EXPECT().Initialize(gomock.Any(), "test_topic", "part-1", cfg.ConsumerGroup).Return(nil) + mockOffsetStore.EXPECT().GetAckedOffset(gomock.Any(), "test_topic", "part-1", cfg.ConsumerGroup).Return(int64(0), nil) + + row := messageRow{ + ID: "msg-1", + Offset: 1, + PartitionKey: "part-1", + Payload: []byte("payload"), + PublishedAt: time.Now().UnixMilli(), + RetryCount: 0, + } + mockMessageStore.EXPECT().FetchByOffset(gomock.Any(), "test_topic", "part-1", int64(0), cfg.BatchSize, cfg.VisibilityTimeoutMs). + Return([]messageRow{row}, nil) + + w := &partitionWorker{ + partitionKey: "part-1", + sub: sub, + subscriber: s, + done: make(chan struct{}), + } + + w.pollAndDeliver(ctx) + + // Verify message was delivered + select { + case del := <-deliveryCh: + assert.Equal(t, "msg-1", del.Message().ID) + default: + t.Fatal("expected delivery but channel was empty") + } + + // Verify offset was initialized only once + assert.True(t, w.offsetInitialized) +} + +// TestSubscriber_StopAllWorkers tests that all workers are stopped gracefully. +func TestSubscriber_StopAllWorkers(t *testing.T) { + ctrl := gomock.NewController(t) + + mockMessageStore := NewMockmessageStore(ctrl) + mockOffsetStore := NewMockoffsetStore(ctrl) + mockLeaseStore := NewMockpartitionLeaseStore(ctrl) + + s := NewSubscriber( + zaptest.NewLogger(t).Sugar(), + tally.NoopScope, + mockMessageStore, + mockOffsetStore, + mockLeaseStore, + ) + + // Allow worker polling + mockOffsetStore.EXPECT().Initialize(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + mockOffsetStore.EXPECT().GetAckedOffset(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(int64(0), nil).AnyTimes() + mockMessageStore.EXPECT().FetchByOffset(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sub := &subscription{ + topic: "test_topic", + config: testSubscriptionConfig(), + deliveryCh: make(chan extqueue.Delivery, 100), + workers: make(map[string]*partitionWorker), + } + + // Start 3 workers + s.startPartitionWorker(ctx, sub, "part-1") + s.startPartitionWorker(ctx, sub, "part-2") + s.startPartitionWorker(ctx, sub, "part-3") + + sub.workersMu.Lock() + assert.Equal(t, 3, len(sub.workers)) + sub.workersMu.Unlock() + + // Collect done channels before stopping + sub.workersMu.Lock() + var doneChans []chan struct{} + for _, w := range sub.workers { + doneChans = append(doneChans, w.done) + } + sub.workersMu.Unlock() + + // Stop all workers + s.stopAllWorkers(sub) + + // Verify all done channels are closed (test timeout handles hangs) + for _, doneCh := range doneChans { + <-doneCh + } +}