Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 40 additions & 34 deletions core/consumer/consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ type consumer struct {
mu sync.Mutex
stopped bool
controllers []Controller
subscriptions map[Topic]*activeSubscription // topic -> subscription
subscriptions map[TopicKey]*activeSubscription // topicKey -> subscription
}

// activeSubscription tracks the state of an active subscription.
Expand All @@ -52,12 +52,12 @@ func New(logger *zap.SugaredLogger, scope tally.Scope, registry TopicRegistry) C
logger: logger,
metricsScope: scope.SubScope("consumer"),
registry: registry,
subscriptions: make(map[Topic]*activeSubscription),
subscriptions: make(map[TopicKey]*activeSubscription),
}
}

// Register adds a controller to the consumer. Must be called before Start().
// Returns error if a controller for the same topic is already registered or if the consumer is stopped.
// Returns error if a controller for the same topic key is already registered or if the consumer is stopped.
func (m *consumer) Register(controller Controller) error {
m.mu.Lock()
defer m.mu.Unlock()
Expand All @@ -66,19 +66,19 @@ func (m *consumer) Register(controller Controller) error {
return fmt.Errorf("consumer is stopped")
}

// Check for duplicate topic registration.
// Check for duplicate topic key registration.
// O(n) scan is fine here — controller count is in the single digits.
for _, c := range m.controllers {
if c.Topic() == controller.Topic() {
return fmt.Errorf("controller for topic %s already registered", controller.Topic())
if c.TopicKey() == controller.TopicKey() {
return fmt.Errorf("controller for topic key %s already registered", controller.TopicKey())
}
}

m.controllers = append(m.controllers, controller)

m.logger.Infow("registered controller",
"controller", controller.Name(),
"topic", controller.Topic(),
"topic_key", controller.TopicKey(),
"consumer_group", controller.ConsumerGroup(),
)

Expand Down Expand Up @@ -124,23 +124,29 @@ func (m *consumer) Start(ctx context.Context) error {

// subscribe subscribes a controller to its topic and spawns a consumption goroutine.
func (m *consumer) subscribe(ctx context.Context, controller Controller) error {
topic := controller.Topic()
topicKey := controller.TopicKey()
consumerGroup := controller.ConsumerGroup()

// Get subscription config from registry
config, ok := m.registry.SubscriptionConfig(topic, consumerGroup)
config, ok := m.registry.SubscriptionConfig(topicKey, consumerGroup)
if !ok {
return fmt.Errorf("no subscription config for topic %s, consumer group %s", topic, consumerGroup)
return fmt.Errorf("no subscription config for topic key %s, consumer group %s", topicKey, consumerGroup)
}

// Get queue for this topic
q, ok := m.registry.Queue(topic)
// Get queue for this topic key
q, ok := m.registry.Queue(topicKey)
if !ok {
return fmt.Errorf("no queue registered for topic %s", topic)
return fmt.Errorf("no queue registered for topic key %s", topicKey)
}

// Resolve the actual topic name for subscribing
topicName, ok := m.registry.TopicName(topicKey)
if !ok {
return fmt.Errorf("no topic name registered for topic key %s", topicKey)
}

subscriber := q.Subscriber()
deliveryChan, err := subscriber.Subscribe(ctx, topic.String(), config)
deliveryChan, err := subscriber.Subscribe(ctx, topicName, config)
if err != nil {
return fmt.Errorf("subscribe failed: %w", err)
}
Expand All @@ -155,14 +161,14 @@ func (m *consumer) subscribe(ctx context.Context, controller Controller) error {
cancelFunc: cancel,
done: done,
}
m.subscriptions[topic] = sub
m.subscriptions[topicKey] = sub

// Spawn consumption goroutine
go m.consumeLoop(controllerCtx, controller, deliveryChan, done)

m.logger.Infow("controller started",
"controller", controller.Name(),
"topic", topic,
"topic_key", topicKey,
"consumer_group", consumerGroup,
)

Expand All @@ -173,32 +179,32 @@ func (m *consumer) subscribe(ctx context.Context, controller Controller) error {
func (m *consumer) consumeLoop(ctx context.Context, controller Controller, deliveryChan <-chan queue.Delivery, done chan struct{}) {
defer close(done)

topic := controller.Topic()
topicKey := controller.TopicKey()

controllerScope := m.metricsScope.Tagged(map[string]string{
"controller": controller.Name(),
"topic": topic.String(),
"topic_key": topicKey.String(),
})

m.logger.Debugw("consume loop started",
"controller", controller.Name(),
"topic", topic,
"topic_key", topicKey,
)

for {
select {
case <-ctx.Done():
m.logger.Infow("consume loop stopped",
"controller", controller.Name(),
"topic", topic,
"topic_key", topicKey,
)
return

case delivery, ok := <-deliveryChan:
if !ok {
m.logger.Infow("delivery channel closed",
"controller", controller.Name(),
"topic", topic,
"topic_key", topicKey,
)
return
}
Expand All @@ -214,11 +220,11 @@ func (m *consumer) processDelivery(ctx context.Context, controller Controller, d
controllerScope.Counter("messages_received").Inc(1)

msg := delivery.Message()
topic := controller.Topic()
topicKey := controller.TopicKey()

m.logger.Debugw("processing delivery",
"controller", controller.Name(),
"topic", topic,
"topic_key", topicKey,
"message_id", msg.ID,
"partition_key", msg.PartitionKey,
"attempt", delivery.Attempt(),
Expand Down Expand Up @@ -248,7 +254,7 @@ func (m *consumer) processDelivery(ctx context.Context, controller Controller, d
if IsNonRetryable(err) {
m.logger.Errorw("non-retryable controller error, rejecting message",
"controller", controller.Name(),
"topic", controller.Topic(),
"topic_key", controller.TopicKey(),
"message_id", msg.ID,
"partition_key", msg.PartitionKey,
"attempt", delivery.Attempt(),
Expand All @@ -262,7 +268,7 @@ func (m *consumer) processDelivery(ctx context.Context, controller Controller, d
if rejectErr := delivery.Reject(ctx, err.Error()); rejectErr != nil {
m.logger.Errorw("failed to reject non-retryable message",
"controller", controller.Name(),
"topic", controller.Topic(),
"topic_key", controller.TopicKey(),
"message_id", msg.ID,
"error", rejectErr,
)
Expand All @@ -274,7 +280,7 @@ func (m *consumer) processDelivery(ctx context.Context, controller Controller, d
// Controller returned retryable error - nack message for retry
m.logger.Errorw("controller error, nacking message",
"controller", controller.Name(),
"topic", topic,
"topic_key", topicKey,
"message_id", msg.ID,
"partition_key", msg.PartitionKey,
"attempt", delivery.Attempt(),
Expand All @@ -289,7 +295,7 @@ func (m *consumer) processDelivery(ctx context.Context, controller Controller, d
if nackErr := delivery.Nack(ctx, 0); nackErr != nil {
m.logger.Errorw("failed to nack message",
"controller", controller.Name(),
"topic", topic,
"topic_key", topicKey,
"message_id", msg.ID,
"error", nackErr,
)
Expand All @@ -310,7 +316,7 @@ func (m *consumer) processDelivery(ctx context.Context, controller Controller, d
if ackErr := delivery.Ack(ctx); ackErr != nil {
m.logger.Errorw("failed to ack message",
"controller", controller.Name(),
"topic", topic,
"topic_key", topicKey,
"message_id", msg.ID,
"error", ackErr,
)
Expand All @@ -334,7 +340,7 @@ func (m *consumer) processDelivery(ctx context.Context, controller Controller, d

m.logger.Debugw("message processed successfully",
"controller", controller.Name(),
"topic", topic,
"topic_key", topicKey,
"message_id", msg.ID,
"partition_key", msg.PartitionKey,
"attempt", delivery.Attempt(),
Expand Down Expand Up @@ -368,26 +374,26 @@ func (m *consumer) Stop(timeoutMs int64) error {
// Returns error on timeout, nil on success.
func (m *consumer) unsubscribeAll(timeoutMs int64) error {
// Cancel all subscription contexts
for topic, sub := range m.subscriptions {
for topicKey, sub := range m.subscriptions {
m.logger.Debugw("stopping controller",
"controller", sub.controller.Name(),
"topic", topic,
"topic_key", topicKey,
)
sub.cancelFunc()
}

// Wait for each subscription to finish, splitting the timeout budget across them
remaining := time.Duration(timeoutMs) * time.Millisecond
var timedOut bool
for topic, sub := range m.subscriptions {
for topicKey, sub := range m.subscriptions {
start := time.Now()
select {
case <-sub.done:
// Controller stopped gracefully
case <-time.After(remaining):
m.logger.Errorw("timeout waiting for controller to stop",
"controller", sub.controller.Name(),
"topic", topic,
"topic_key", topicKey,
)
timedOut = true
}
Expand All @@ -399,7 +405,7 @@ func (m *consumer) unsubscribeAll(timeoutMs int64) error {
}

// Clear subscriptions
m.subscriptions = make(map[Topic]*activeSubscription)
m.subscriptions = make(map[TopicKey]*activeSubscription)

if timedOut {
return fmt.Errorf("timeout waiting for controllers to stop")
Expand Down
Loading