diff --git a/CLAUDE.md b/CLAUDE.md index 998a042b..24b737d3 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -47,6 +47,11 @@ Three services, each following the same layout: ``` / ├── controller/ # Business logic (pure, transport-agnostic) +│ ├── {method}.go # RPC controllers (e.g., land.go, ping.go) +│ ├── {method}_test.go +│ └── {step}/ # Queue message controllers (e.g., request/) +│ ├── {step}.go # Step in workflow +│ └── {step}_test.go ├── proto/ # Proto definitions (.proto files) ├── protopb/ # Generated proto code (committed to repo) └── integration_test/ @@ -54,7 +59,24 @@ Three services, each following the same layout: ### Controllers -Controllers contain pure business logic, independent of the transport layer (gRPC/YARPC). They live in `{service}/controller/` and are wired up in `example/server/{service}/main.go`. +Controllers contain pure business logic, independent of infrastructure. There are two types: + +**RPC Controllers** - Handle synchronous API requests in `{service}/controller/`. Accept protobuf types, independent of gRPC/YARPC transport. + +```go +func (c *LandController) Land(ctx context.Context, req *pb.LandRequest) (*pb.LandResponse, error) +``` + +**Queue Message Controllers** - Process async queue messages in `{service}/controller/{step}/`. Implement `consumer.Controller` interface. + +```go +// Receives consumer.Delivery (NOT extension/queue.Delivery) +func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) error { + // Return nil to ack, error to nack. Consumer handles ack/nack automatically. +} +``` + +Controllers receive `consumer.Delivery` (subset interface without Ack/Nack methods) to enforce separation: controllers do business logic, consumer framework handles infrastructure. ### Entities @@ -70,7 +92,10 @@ entity/ **Entity guidelines:** 1. Keep entities pure and framework-agnostic — no external dependencies 2. Use value types, not references -3. Prefer `int64` Unix epoch milliseconds over `time.Time` +3. Prefer `int64` milliseconds over `time.Time` and `time.Duration`: + - Timestamps: Unix epoch milliseconds (e.g., `CreatedAt int64`) — use `time.UnixMilli()` method + - Durations/timeouts: milliseconds (e.g., `TimeoutMs int64`, `DelayMs int64`) + - Use `time.Duration(ms) * time.Millisecond` to convert to `time.Duration` when needed 4. Every field must have a comment explaining its meaning 5. Reference other entities by ID (string or int), not directly 6. Use string enums with clear names; assign sentinel values (`""` for strings, `0` for ints) to unreachable/unknown enum variants @@ -104,7 +129,9 @@ extension/ ### Import Paths -- Controllers: `github.com/uber/submitqueue/{service}/controller` +- RPC Controllers: `github.com/uber/submitqueue/{service}/controller` +- Queue Controllers: `github.com/uber/submitqueue/{service}/controller/{step}` +- Consumer: `github.com/uber/submitqueue/consumer` - Proto (generated): `github.com/uber/submitqueue/{service}/protopb` - Extensions: `github.com/uber/submitqueue/extension/{extension}` - Extension impl: `github.com/uber/submitqueue/extension/{extension}/{impl}` @@ -164,6 +191,11 @@ All generated proto files are **committed to the repository**. When modifying `. - Tests: `{file}_test.go` - BUILD files: Always `BUILD.bazel` +### Directory Naming + +- Use **singular** names for directories (e.g., `mock/` not `mocks/`, `entity/` not `entities/`) +- This applies to all folders including test mocks, extensions, entities, and service directories + ### Common Make Targets ```bash @@ -189,6 +221,10 @@ make clean-proto # Remove generated proto files 3. Add controller in `{service}/controller/` 4. Wire up in `example/server/{service}/main.go` +**Add new queue message controller:** +1. Create `{service}/controller/{step}/` with controller implementing `consumer.Controller` +2. Wire up in `example/server/{service}/main.go`: register → start → stop on shutdown + **Add new extension implementation:** 1. Create `extension/{extension}/{impl}/` directory 2. Implement factory and core interfaces @@ -204,3 +240,20 @@ make clean-proto # Remove generated proto files 1. **Avoid asserting on error messages** — assert on error type if it is part of the contract, or assert generic error otherwise. 2. **Avoid blocking operations for synchronization** — do not use `time.Sleep`. Design the tested routine to signal back (channels, callbacks, condition variables). 3. **Use testify assertions** — use `stretchr/assert` or `require` instead of `t.Fatal()`. + +### Code Style Guidelines + +1. **Use SugaredLogger for structured logging** — always use `zap.SugaredLogger` with structured logging methods: + - `logger.Debugw(msg, key1, val1, key2, val2, ...)` for debug logs + - `logger.Infow(msg, key1, val1, key2, val2, ...)` for info logs + - `logger.Errorw(msg, key1, val1, key2, val2, ...)` for error logs + - Never use unstructured methods like `Debug()`, `Info()`, `Error()`, or `Printf()` + - Example: `logger.Infow("starting consumer", "subscriber_name", subscriberName, "controller_count", len(controllers))` + +2. **Use interfaces for contracts** — define interfaces for public APIs and dependencies: + - Public components should return/accept interfaces, not concrete structs + - Unexported structs implement the interfaces + - Makes testing easier through mocking + - Example: `func New(...) Consumer` returns interface, not `*consumer` + - Implementation struct is unexported: `type consumer struct { ... }` + diff --git a/core/consumer/BUILD.bazel b/core/consumer/BUILD.bazel new file mode 100644 index 00000000..57403ba5 --- /dev/null +++ b/core/consumer/BUILD.bazel @@ -0,0 +1,35 @@ +load("@rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "consumer", + srcs = [ + "consumer.go", + "controller.go", + "error.go", + ], + importpath = "github.com/uber/submitqueue/core/consumer", + visibility = ["//visibility:public"], + deps = [ + "//entity/queue", + "//extension/queue", + "@com_github_uber_go_tally_v4//:tally", + "@org_uber_go_zap//:zap", + ], +) + +go_test( + name = "consumer_test", + srcs = [ + "consumer_test.go", + "error_test.go", + ], + embed = [":consumer"], + deps = [ + "//entity/queue", + "//extension/queue", + "@com_github_stretchr_testify//assert", + "@com_github_stretchr_testify//require", + "@com_github_uber_go_tally_v4//:tally", + "@org_uber_go_zap//zaptest", + ], +) diff --git a/core/consumer/README.md b/core/consumer/README.md new file mode 100644 index 00000000..63e77528 --- /dev/null +++ b/core/consumer/README.md @@ -0,0 +1,55 @@ +# Consumer + +The consumer package orchestrates queue message processing. It manages subscription lifecycle, message consumption, ack/nack, and graceful shutdown. + +## Interfaces + +### Consumer + +The top-level orchestrator. Register controllers, start consuming, and stop gracefully. + +```go +c := consumer.New(logger, scope, queue, "worker-hostname") + +c.Register(myController) +c.Start(ctx) + +// On shutdown: +if err := c.Stop(30000); err != nil { + logger.Errorw("consumer stop error", "error", err) +} +``` + +### Controller + +Business logic for processing queue messages. Implement this interface to handle deliveries for a specific topic. + +```go +type Controller interface { + Process(ctx context.Context, delivery Delivery) error + Name() string + Topic() string + ConsumerGroup() string + SubscriptionConfig(subscriberName string) queue.SubscriptionConfig +} +``` + +### Delivery + +A restricted view of a queue delivery exposed to controllers. Hides Ack/Nack (handled automatically by Consumer) while exposing message data and `ExtendVisibilityTimeout`. + +## Error Handling + +Controllers signal processing outcome via the return value of `Process()`: + +- **`return nil`** — success, message is acked. +- **`return err`** — retryable failure, message is nacked for retry. +- **`return consumer.NewNonRetryableError(err)`** — poison pill, message is acked and removed from the queue to prevent infinite retry loops. + +## Lifecycle + +1. **Register** controllers before starting. +2. **Start** subscribes to all topics and spawns consume loops. +3. **Stop** cancels all subscriptions and waits for goroutines to finish (with timeout). + +Once stopped, the consumer cannot be restarted — `Register()` and `Start()` return errors. diff --git a/core/consumer/consumer.go b/core/consumer/consumer.go new file mode 100644 index 00000000..8c75204e --- /dev/null +++ b/core/consumer/consumer.go @@ -0,0 +1,399 @@ +package consumer + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/uber-go/tally/v4" + "github.com/uber/submitqueue/extension/queue" + "go.uber.org/zap" +) + +// Consumer orchestrates multiple queue consumers. It handles subscription lifecycle, +// message consumption, ack/nack, and graceful shutdown for the entire pipeline. +type Consumer interface { + // Register adds a controller to the consumer. Must be called before Start(). + Register(controller Controller) error + + // Start subscribes to all registered controllers' topics and begins consuming messages. + Start(ctx context.Context) error + + // Stop gracefully shuts down all controllers with the specified timeout. + // timeoutMs is the maximum time in milliseconds to wait for graceful shutdown. + // Returns error if shutdown times out. + Stop(timeoutMs int64) error +} + +// consumer implements the Consumer interface. +type consumer struct { + logger *zap.SugaredLogger + metricsScope tally.Scope + queue queue.Queue + subscriberName string // Unique worker ID (hostname, pod name) + + mu sync.Mutex + stopped bool + controllers []Controller + subscriptions map[string]*activeSubscription // topic -> subscription +} + +// activeSubscription tracks the state of an active subscription. +type activeSubscription struct { + controller Controller + cancelFunc context.CancelFunc + done chan struct{} // Closed when consumeLoop exits +} + +// New creates a new consumer. +// subscriberName is the unique worker identifier used for partition leasing (e.g., hostname, pod name). +func New(logger *zap.SugaredLogger, scope tally.Scope, q queue.Queue, subscriberName string) Consumer { + return &consumer{ + logger: logger, + metricsScope: scope.SubScope("consumer"), + queue: q, + subscriberName: subscriberName, + subscriptions: make(map[string]*activeSubscription), + } +} + +// Register adds a controller to the consumer. Must be called before Start(). +// Returns error if a controller for the same topic is already registered or if the consumer is stopped. +func (m *consumer) Register(controller Controller) error { + m.mu.Lock() + defer m.mu.Unlock() + + if m.stopped { + return fmt.Errorf("consumer is stopped") + } + + // Check for duplicate topic registration. + // O(n) scan is fine here — controller count is in the single digits. + for _, c := range m.controllers { + if c.Topic() == controller.Topic() { + return fmt.Errorf("controller for topic %s already registered", controller.Topic()) + } + } + + m.controllers = append(m.controllers, controller) + + m.logger.Infow("registered controller", + "controller", controller.Name(), + "topic", controller.Topic(), + "consumer_group", controller.ConsumerGroup(), + ) + + return nil +} + +// Start subscribes to all registered controllers' topics and begins consuming messages. +// Spawns a goroutine per controller that processes deliveries and acks/nacks automatically. +func (m *consumer) Start(ctx context.Context) error { + // Hold the lock for the entire subscribe loop so that startup is atomic: + // either all controllers subscribe successfully or none remain active. + // This also ensures Stop() cannot interleave with a partially-started state. + m.mu.Lock() + defer m.mu.Unlock() + + if m.stopped { + return fmt.Errorf("consumer is stopped") + } + + if len(m.controllers) == 0 { + return fmt.Errorf("no controllers registered") + } + + m.logger.Infow("starting consumer", + "subscriber_name", m.subscriberName, + "controller_count", len(m.controllers), + ) + + for _, controller := range m.controllers { + if err := m.subscribe(ctx, controller); err != nil { + // Cleanup any started controllers with short timeout (30 seconds). + // Ignore error since we're returning the subscribe error. + _ = m.unsubscribeAll(30000) + return fmt.Errorf("failed to start controller %s: %w", controller.Name(), err) + } + } + + m.logger.Infow("consumer started", + "active_subscriptions", len(m.subscriptions), + ) + + return nil +} + +// subscribe subscribes a controller to its topic and spawns a consumption goroutine. +func (m *consumer) subscribe(ctx context.Context, controller Controller) error { + // Get controller's subscription config + config := controller.SubscriptionConfig(m.subscriberName) + + // Subscribe to topic + subscriber := m.queue.Subscriber() + deliveryChan, err := subscriber.Subscribe(ctx, controller.Topic(), config) + if err != nil { + return fmt.Errorf("subscribe failed: %w", err) + } + + // Create cancellable context for this controller + controllerCtx, cancel := context.WithCancel(ctx) + + // Track active subscription + done := make(chan struct{}) + sub := &activeSubscription{ + controller: controller, + cancelFunc: cancel, + done: done, + } + m.subscriptions[controller.Topic()] = sub + + // Spawn consumption goroutine + go m.consumeLoop(controllerCtx, controller, deliveryChan, done) + + m.logger.Infow("controller started", + "controller", controller.Name(), + "topic", controller.Topic(), + "consumer_group", controller.ConsumerGroup(), + ) + + return nil +} + +// consumeLoop processes deliveries for a controller, calling ack/nack based on controller result. +func (m *consumer) consumeLoop(ctx context.Context, controller Controller, deliveryChan <-chan queue.Delivery, done chan struct{}) { + defer close(done) + + controllerScope := m.metricsScope.Tagged(map[string]string{ + "controller": controller.Name(), + "topic": controller.Topic(), + }) + + m.logger.Debugw("consume loop started", + "controller", controller.Name(), + "topic", controller.Topic(), + ) + + for { + select { + case <-ctx.Done(): + m.logger.Infow("consume loop stopped", + "controller", controller.Name(), + "topic", controller.Topic(), + ) + return + + case delivery, ok := <-deliveryChan: + if !ok { + m.logger.Infow("delivery channel closed", + "controller", controller.Name(), + "topic", controller.Topic(), + ) + return + } + + m.processDelivery(ctx, controller, delivery, controllerScope) + } + } +} + +// processDelivery calls the controller and performs ack/nack based on the result. +func (m *consumer) processDelivery(ctx context.Context, controller Controller, delivery queue.Delivery, controllerScope tally.Scope) { + start := time.Now() + controllerScope.Counter("messages_received").Inc(1) + + msg := delivery.Message() + + m.logger.Debugw("processing delivery", + "controller", controller.Name(), + "topic", controller.Topic(), + "message_id", msg.ID, + "partition_key", msg.PartitionKey, + "attempt", delivery.Attempt(), + ) + + // Wrap delivery to hide Ack/Nack from controller + wrapped := &deliveryWrapper{delivery: delivery} + + // Call controller with wrapped delivery + err := controller.Process(ctx, wrapped) + + elapsed := time.Since(start) + + // Track latency with success/failure tags + successTag := "true" + if err != nil { + successTag = "false" + } + + latencyScope := controllerScope.Tagged(map[string]string{ + "success": successTag, + }) + latencyScope.Timer("controller_latency").Record(elapsed) + + if err != nil { + // Check if the error is non-retryable (poison pill message) + if IsNonRetryable(err) { + m.logger.Errorw("non-retryable controller error, rejecting message", + "controller", controller.Name(), + "topic", controller.Topic(), + "message_id", msg.ID, + "partition_key", msg.PartitionKey, + "attempt", delivery.Attempt(), + "error", err, + "elapsed_ms", elapsed.Milliseconds(), + ) + + controllerScope.Counter("non_retryable_errors").Inc(1) + + // Reject moves to DLQ (or acks if DLQ disabled) + if rejectErr := delivery.Reject(ctx, err.Error()); rejectErr != nil { + m.logger.Errorw("failed to reject non-retryable message", + "controller", controller.Name(), + "topic", controller.Topic(), + "message_id", msg.ID, + "error", rejectErr, + ) + controllerScope.Counter("reject_errors").Inc(1) + } + return + } + + // Controller returned retryable error - nack message for retry + m.logger.Errorw("controller error, nacking message", + "controller", controller.Name(), + "topic", controller.Topic(), + "message_id", msg.ID, + "partition_key", msg.PartitionKey, + "attempt", delivery.Attempt(), + "error", err, + "elapsed_ms", elapsed.Milliseconds(), + ) + + controllerScope.Counter("controller_errors").Inc(1) + + // Nack with no delay - let visibility timeout handle retry delay + nackStart := time.Now() + if nackErr := delivery.Nack(ctx, 0); nackErr != nil { + m.logger.Errorw("failed to nack message", + "controller", controller.Name(), + "topic", controller.Topic(), + "message_id", msg.ID, + "error", nackErr, + ) + controllerScope.Counter("nack_errors").Inc(1) + } else { + controllerScope.Counter("nack_count").Inc(1) + nackScope := controllerScope.Tagged(map[string]string{ + "operation": "nack", + "success": "true", + }) + nackScope.Timer("ack_nack_latency").Record(time.Since(nackStart)) + } + return + } + + // Controller succeeded - ack message + ackStart := time.Now() + if ackErr := delivery.Ack(ctx); ackErr != nil { + m.logger.Errorw("failed to ack message", + "controller", controller.Name(), + "topic", controller.Topic(), + "message_id", msg.ID, + "error", ackErr, + ) + controllerScope.Counter("ack_errors").Inc(1) + ackScope := controllerScope.Tagged(map[string]string{ + "operation": "ack", + "success": "false", + }) + ackScope.Timer("ack_nack_latency").Record(time.Since(ackStart)) + return + } + + controllerScope.Counter("messages_processed").Inc(1) + controllerScope.Counter("ack_count").Inc(1) + + ackScope := controllerScope.Tagged(map[string]string{ + "operation": "ack", + "success": "true", + }) + ackScope.Timer("ack_nack_latency").Record(time.Since(ackStart)) + + m.logger.Debugw("message processed successfully", + "controller", controller.Name(), + "topic", controller.Topic(), + "message_id", msg.ID, + "partition_key", msg.PartitionKey, + "attempt", delivery.Attempt(), + "elapsed_ms", elapsed.Milliseconds(), + ) +} + +// Stop gracefully shuts down all handlers with the specified timeout. +// Cancels all subscription contexts and waits for consumption goroutines to finish. +// timeoutMs is the maximum time in milliseconds to wait for graceful shutdown. +// Returns error if shutdown times out. +func (m *consumer) Stop(timeoutMs int64) error { + m.mu.Lock() + m.stopped = true + m.mu.Unlock() + + m.logger.Infow("stopping consumer", + "active_subscriptions", len(m.subscriptions), + "timeout_ms", timeoutMs, + ) + + err := m.unsubscribeAll(timeoutMs) + + m.logger.Infow("consumer stopped") + + return err +} + +// unsubscribeAll stops all active controllers (must be called with lock held). +// timeoutMs is the maximum time in milliseconds to wait for all controllers to stop. +// Returns error on timeout, nil on success. +func (m *consumer) unsubscribeAll(timeoutMs int64) error { + // Cancel all subscription contexts + for topic, sub := range m.subscriptions { + m.logger.Debugw("stopping controller", + "controller", sub.controller.Name(), + "topic", topic, + ) + sub.cancelFunc() + } + + // Wait for each subscription to finish, splitting the timeout budget across them + remaining := time.Duration(timeoutMs) * time.Millisecond + var timedOut bool + for topic, sub := range m.subscriptions { + start := time.Now() + select { + case <-sub.done: + // Controller stopped gracefully + case <-time.After(remaining): + m.logger.Errorw("timeout waiting for controller to stop", + "controller", sub.controller.Name(), + "topic", topic, + ) + timedOut = true + } + elapsed := time.Since(start) + remaining -= elapsed + if remaining < 0 { + remaining = 0 + } + } + + // Clear subscriptions + m.subscriptions = make(map[string]*activeSubscription) + + if timedOut { + return fmt.Errorf("timeout waiting for controllers to stop") + } + + m.logger.Debugw("all controllers stopped gracefully") + return nil +} diff --git a/core/consumer/consumer_test.go b/core/consumer/consumer_test.go new file mode 100644 index 00000000..7cc766ba --- /dev/null +++ b/core/consumer/consumer_test.go @@ -0,0 +1,729 @@ +package consumer + +import ( + "context" + "fmt" + "strings" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/uber-go/tally/v4" + "github.com/uber/submitqueue/entity/queue" + extqueue "github.com/uber/submitqueue/extension/queue" + "go.uber.org/zap/zaptest" +) + +// Mock Controller +type mockController struct { + name string + topic string + consumerGroup string + processFunc func(ctx context.Context, delivery Delivery) error +} + +func (m *mockController) Process(ctx context.Context, delivery Delivery) error { + return m.processFunc(ctx, delivery) +} + +func (m *mockController) Name() string { + return m.name +} + +func (m *mockController) Topic() string { + return m.topic +} + +func (m *mockController) ConsumerGroup() string { + return m.consumerGroup +} + +func (m *mockController) SubscriptionConfig(subscriberName string) extqueue.SubscriptionConfig { + return extqueue.DefaultSubscriptionConfig(subscriberName, m.consumerGroup) +} + +// Mock Delivery +type mockDelivery struct { + msg queue.Message + attempt int + ackFunc func(ctx context.Context) error + nackFunc func(ctx context.Context, requeueAfterMillis int64) error + rejectFunc func(ctx context.Context, reason string) error + acked bool + nacked bool + rejected bool + rejectReason string + nackDelayMs int64 + done chan struct{} // Signals when ack/nack/reject is called + mu sync.Mutex +} + +func (m *mockDelivery) Message() queue.Message { + return m.msg +} + +func (m *mockDelivery) DeliveryID() string { + return m.msg.ID +} + +func (m *mockDelivery) Attempt() int { + return m.attempt +} + +func (m *mockDelivery) ReceivedAt() int64 { + return time.Now().UnixMilli() +} + +func (m *mockDelivery) Metadata() map[string]string { + return nil +} + +func (m *mockDelivery) Ack(ctx context.Context) error { + m.mu.Lock() + defer m.mu.Unlock() + m.acked = true + if m.done != nil { + close(m.done) + } + if m.ackFunc != nil { + return m.ackFunc(ctx) + } + return nil +} + +func (m *mockDelivery) Nack(ctx context.Context, requeueAfterMillis int64) error { + m.mu.Lock() + defer m.mu.Unlock() + m.nacked = true + m.nackDelayMs = requeueAfterMillis + if m.done != nil { + close(m.done) + } + if m.nackFunc != nil { + return m.nackFunc(ctx, requeueAfterMillis) + } + return nil +} + +func (m *mockDelivery) Reject(ctx context.Context, reason string) error { + m.mu.Lock() + defer m.mu.Unlock() + m.rejected = true + m.rejectReason = reason + if m.done != nil { + close(m.done) + } + if m.rejectFunc != nil { + return m.rejectFunc(ctx, reason) + } + return nil +} + +func (m *mockDelivery) ExtendVisibilityTimeout(ctx context.Context, durationMillis int64) error { + return nil +} + +func (m *mockDelivery) WasAcked() bool { + m.mu.Lock() + defer m.mu.Unlock() + return m.acked +} + +func (m *mockDelivery) WasNacked() bool { + m.mu.Lock() + defer m.mu.Unlock() + return m.nacked +} + +func (m *mockDelivery) WasRejected() bool { + m.mu.Lock() + defer m.mu.Unlock() + return m.rejected +} + +// Mock Subscriber +type mockSubscriber struct { + subscribeFunc func(ctx context.Context, topic string, config extqueue.SubscriptionConfig) (<-chan extqueue.Delivery, error) +} + +func (m *mockSubscriber) Subscribe(ctx context.Context, topic string, config extqueue.SubscriptionConfig) (<-chan extqueue.Delivery, error) { + return m.subscribeFunc(ctx, topic, config) +} + +func (m *mockSubscriber) Close() error { + return nil +} + +// Mock Queue +type mockQueue struct { + subscriber extqueue.Subscriber +} + +func (m *mockQueue) Publisher() extqueue.Publisher { + return nil +} + +func (m *mockQueue) Subscriber() extqueue.Subscriber { + return m.subscriber +} + +func (m *mockQueue) Close() error { + return nil +} + +func TestNew(t *testing.T) { + logger := zaptest.NewLogger(t).Sugar() + scope := tally.NoopScope + q := &mockQueue{} + + c := New(logger, scope, q, "test-worker") + + require.NotNil(t, c) + + // Type assert to access internal fields + impl := c.(*consumer) + assert.Equal(t, "test-worker", impl.subscriberName) + assert.Empty(t, impl.controllers) + assert.Empty(t, impl.subscriptions) +} + +func TestConsumer_Register(t *testing.T) { + logger := zaptest.NewLogger(t).Sugar() + scope := tally.NoopScope + q := &mockQueue{} + c := New(logger, scope, q, "test-worker") + + handler1 := &mockController{ + name: "handler1", + topic: "topic1", + consumerGroup: "group1", + } + handler2 := &mockController{ + name: "handler2", + topic: "topic2", + consumerGroup: "group2", + } + + // Register first handler + err := c.Register(handler1) + require.NoError(t, err) + assert.Len(t, c.(*consumer).controllers, 1) + + // Register second handler + err = c.Register(handler2) + require.NoError(t, err) + assert.Len(t, c.(*consumer).controllers, 2) +} + +func TestConsumer_Register_DuplicateTopic(t *testing.T) { + logger := zaptest.NewLogger(t).Sugar() + scope := tally.NoopScope + q := &mockQueue{} + c := New(logger, scope, q, "test-worker") + + handler1 := &mockController{ + name: "handler1", + topic: "topic1", + consumerGroup: "group1", + } + handler2 := &mockController{ + name: "handler2", + topic: "topic1", // Same topic + consumerGroup: "group2", + } + + err := c.Register(handler1) + require.NoError(t, err) + + err = c.Register(handler2) + assert.Error(t, err) +} + +func TestConsumer_Register_AfterStop(t *testing.T) { + logger := zaptest.NewLogger(t).Sugar() + scope := tally.NoopScope + q := &mockQueue{} + c := New(logger, scope, q, "test-worker") + + err := c.Stop(1000) + require.NoError(t, err) + + handler := &mockController{ + name: "handler1", + topic: "topic1", + consumerGroup: "group1", + } + + err = c.Register(handler) + assert.Error(t, err) +} + +func TestConsumer_Start_NoHandlers(t *testing.T) { + logger := zaptest.NewLogger(t).Sugar() + scope := tally.NoopScope + q := &mockQueue{} + c := New(logger, scope, q, "test-worker") + + ctx := context.Background() + err := c.Start(ctx) + assert.Error(t, err) +} + +func TestConsumer_Start_AfterStop(t *testing.T) { + logger := zaptest.NewLogger(t).Sugar() + scope := tally.NoopScope + q := &mockQueue{} + c := New(logger, scope, q, "test-worker") + + handler := &mockController{ + name: "handler1", + topic: "topic1", + consumerGroup: "group1", + } + + err := c.Register(handler) + require.NoError(t, err) + + err = c.Stop(1000) + require.NoError(t, err) + + ctx := context.Background() + err = c.Start(ctx) + assert.Error(t, err) +} + +func TestConsumer_ProcessDelivery_Success(t *testing.T) { + logger := zaptest.NewLogger(t).Sugar() + scope := tally.NoopScope + + deliveryChan := make(chan extqueue.Delivery, 1) + subscriber := &mockSubscriber{ + subscribeFunc: func(ctx context.Context, topic string, config extqueue.SubscriptionConfig) (<-chan extqueue.Delivery, error) { + return deliveryChan, nil + }, + } + q := &mockQueue{subscriber: subscriber} + c := New(logger, scope, q, "test-worker") + + handledMsg := "" + handler := &mockController{ + name: "test-handler", + topic: "test-topic", + consumerGroup: "test-group", + processFunc: func(ctx context.Context, delivery Delivery) error { + handledMsg = delivery.Message().ID + return nil // Success + }, + } + + err := c.Register(handler) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + err = c.Start(ctx) + require.NoError(t, err) + + // Send a test message + msg := queue.NewMessage("test-msg-1", []byte("payload"), "partition1", nil) + delivery := &mockDelivery{ + msg: msg, + attempt: 1, + done: make(chan struct{}), + } + + deliveryChan <- delivery + + // Wait for processing to complete + <-delivery.done + + assert.Equal(t, "test-msg-1", handledMsg) + assert.True(t, delivery.WasAcked(), "Message should be acked on success") + assert.False(t, delivery.WasNacked(), "Message should not be nacked on success") + + err = c.Stop(30000) + require.NoError(t, err) +} + +func TestConsumer_ProcessDelivery_Error(t *testing.T) { + logger := zaptest.NewLogger(t).Sugar() + scope := tally.NoopScope + + deliveryChan := make(chan extqueue.Delivery, 1) + subscriber := &mockSubscriber{ + subscribeFunc: func(ctx context.Context, topic string, config extqueue.SubscriptionConfig) (<-chan extqueue.Delivery, error) { + return deliveryChan, nil + }, + } + q := &mockQueue{subscriber: subscriber} + c := New(logger, scope, q, "test-worker") + + handler := &mockController{ + name: "test-handler", + topic: "test-topic", + consumerGroup: "test-group", + processFunc: func(ctx context.Context, delivery Delivery) error { + return fmt.Errorf("processing failed") + }, + } + + err := c.Register(handler) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + err = c.Start(ctx) + require.NoError(t, err) + + // Send a test message + msg := queue.NewMessage("test-msg-2", []byte("payload"), "partition1", nil) + delivery := &mockDelivery{ + msg: msg, + attempt: 2, + done: make(chan struct{}), + } + + deliveryChan <- delivery + + // Wait for processing to complete + <-delivery.done + + assert.False(t, delivery.WasAcked(), "Message should not be acked on error") + assert.True(t, delivery.WasNacked(), "Message should be nacked on error") + + err = c.Stop(30000) + require.NoError(t, err) +} + +func TestConsumer_ProcessDelivery_NonRetryableError(t *testing.T) { + logger := zaptest.NewLogger(t).Sugar() + scope := tally.NoopScope + + deliveryChan := make(chan extqueue.Delivery, 1) + subscriber := &mockSubscriber{ + subscribeFunc: func(ctx context.Context, topic string, config extqueue.SubscriptionConfig) (<-chan extqueue.Delivery, error) { + return deliveryChan, nil + }, + } + q := &mockQueue{subscriber: subscriber} + c := New(logger, scope, q, "test-worker") + + handler := &mockController{ + name: "test-handler", + topic: "test-topic", + consumerGroup: "test-group", + processFunc: func(ctx context.Context, delivery Delivery) error { + return NewNonRetryableError(fmt.Errorf("bad payload")) + }, + } + + err := c.Register(handler) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + err = c.Start(ctx) + require.NoError(t, err) + + // Send a test message with non-retryable payload + msg := queue.NewMessage("poison-msg", []byte("bad"), "partition1", nil) + delivery := &mockDelivery{ + msg: msg, + attempt: 1, + done: make(chan struct{}), + } + + deliveryChan <- delivery + + // Wait for processing to complete + <-delivery.done + + assert.True(t, delivery.WasRejected(), "Non-retryable message should be rejected") + assert.False(t, delivery.WasAcked(), "Non-retryable message should not be acked directly") + assert.False(t, delivery.WasNacked(), "Non-retryable message should not be nacked") + + err = c.Stop(30000) + require.NoError(t, err) +} + +func TestConsumer_Stop(t *testing.T) { + logger := zaptest.NewLogger(t).Sugar() + scope := tally.NoopScope + + deliveryChan := make(chan extqueue.Delivery) + subscriber := &mockSubscriber{ + subscribeFunc: func(ctx context.Context, topic string, config extqueue.SubscriptionConfig) (<-chan extqueue.Delivery, error) { + return deliveryChan, nil + }, + } + q := &mockQueue{subscriber: subscriber} + c := New(logger, scope, q, "test-worker") + + handler := &mockController{ + name: "test-handler", + topic: "test-topic", + consumerGroup: "test-group", + processFunc: func(ctx context.Context, delivery Delivery) error { + return nil + }, + } + + err := c.Register(handler) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + err = c.Start(ctx) + require.NoError(t, err) + + assert.Len(t, c.(*consumer).subscriptions, 1) + + // Stop the c + err = c.Stop(30000) + require.NoError(t, err) + + assert.Empty(t, c.(*consumer).subscriptions, "Subscriptions should be cleared after stop") +} + +func TestConsumer_ObservabilityTags(t *testing.T) { + logger := zaptest.NewLogger(t).Sugar() + + deliveryChan := make(chan extqueue.Delivery, 10) + subscriber := &mockSubscriber{ + subscribeFunc: func(ctx context.Context, topic string, config extqueue.SubscriptionConfig) (<-chan extqueue.Delivery, error) { + return deliveryChan, nil + }, + } + q := &mockQueue{subscriber: subscriber} + + tests := []struct { + name string + handlerError error + ackError error + nackError error + expectSuccess bool + expectAckCount bool + }{ + { + name: "success with ack", + handlerError: nil, + ackError: nil, + expectSuccess: true, + expectAckCount: true, + }, + { + name: "failure with nack", + handlerError: fmt.Errorf("handler failed"), + nackError: nil, + expectSuccess: false, + expectAckCount: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a fresh scope for each test + testScope := tally.NewTestScope("consumer", nil) + testC := New(logger, testScope, q, "test-worker") + + handler := &mockController{ + name: "test-handler", + topic: "test-topic", + consumerGroup: "test-group", + processFunc: func(ctx context.Context, delivery Delivery) error { + return tt.handlerError + }, + } + + err := testC.Register(handler) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + err = testC.Start(ctx) + require.NoError(t, err) + + // Send a test message + msg := queue.NewMessage("msg-1", []byte("payload"), "partition1", nil) + delivery := &mockDelivery{ + msg: msg, + attempt: 1, + done: make(chan struct{}), + ackFunc: func(ctx context.Context) error { + return tt.ackError + }, + nackFunc: func(ctx context.Context, requeueAfterMillis int64) error { + return tt.nackError + }, + } + + deliveryChan <- delivery + + // Wait for processing to complete + <-delivery.done + + // Verify metrics exist + snapshot := testScope.Snapshot() + + // Check handler latency with success tag exists + timers := snapshot.Timers() + assert.NotEmpty(t, timers, "Should have timer metrics") + + // Check for handler latency metric + var foundLatency bool + for _, timer := range timers { + if strings.Contains(timer.Name(), "controller_latency") { + foundLatency = true + // Verify success tag + tags := timer.Tags() + if tt.expectSuccess { + assert.Equal(t, "true", tags["success"], "Should have success=true tag") + } else { + assert.Equal(t, "false", tags["success"], "Should have success=false tag") + } + } + } + assert.True(t, foundLatency, "Should have controller_latency metric") + + // Check counters + counters := snapshot.Counters() + if tt.expectAckCount { + var foundAck bool + for _, counter := range counters { + if strings.Contains(counter.Name(), "ack_count") { + foundAck = true + assert.Greater(t, counter.Value(), int64(0), "ack_count should be > 0") + } + } + assert.True(t, foundAck, "Should have ack_count metric") + } + + _ = testC.Stop(30000) + }) + } +} + +func TestConsumer_AckNackLatencyTracking(t *testing.T) { + logger := zaptest.NewLogger(t).Sugar() + scope := tally.NewTestScope("consumer", nil) + + deliveryChan := make(chan extqueue.Delivery, 1) + subscriber := &mockSubscriber{ + subscribeFunc: func(ctx context.Context, topic string, config extqueue.SubscriptionConfig) (<-chan extqueue.Delivery, error) { + return deliveryChan, nil + }, + } + q := &mockQueue{subscriber: subscriber} + c := New(logger, scope, q, "test-worker") + + handler := &mockController{ + name: "test-handler", + topic: "test-topic", + consumerGroup: "test-group", + processFunc: func(ctx context.Context, delivery Delivery) error { + return nil // Success + }, + } + + err := c.Register(handler) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + err = c.Start(ctx) + require.NoError(t, err) + + // Send a test message + msg := queue.NewMessage("msg-1", []byte("payload"), "partition1", nil) + delivery := &mockDelivery{ + msg: msg, + attempt: 1, + done: make(chan struct{}), + } + + deliveryChan <- delivery + + // Wait for processing to complete + <-delivery.done + + // Verify we have some timer metrics (latency tracking is working) + snapshot := scope.Snapshot() + assert.NotEmpty(t, snapshot.Timers(), "Should have timer metrics for latency tracking") + assert.NotEmpty(t, snapshot.Counters(), "Should have counter metrics") + + err = c.Stop(30000) + require.NoError(t, err) +} + +func TestConsumer_ErrorMetrics(t *testing.T) { + logger := zaptest.NewLogger(t).Sugar() + scope := tally.NewTestScope("consumer", nil) + + deliveryChan := make(chan extqueue.Delivery, 1) + subscriber := &mockSubscriber{ + subscribeFunc: func(ctx context.Context, topic string, config extqueue.SubscriptionConfig) (<-chan extqueue.Delivery, error) { + return deliveryChan, nil + }, + } + q := &mockQueue{subscriber: subscriber} + c := New(logger, scope, q, "test-worker") + + handler := &mockController{ + name: "test-handler", + topic: "test-topic", + consumerGroup: "test-group", + processFunc: func(ctx context.Context, delivery Delivery) error { + return fmt.Errorf("processing failed") + }, + } + + err := c.Register(handler) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + err = c.Start(ctx) + require.NoError(t, err) + + // Send a test message + msg := queue.NewMessage("msg-1", []byte("payload"), "partition1", nil) + delivery := &mockDelivery{ + msg: msg, + attempt: 1, + done: make(chan struct{}), + nackFunc: func(ctx context.Context, requeueAfterMillis int64) error { + return fmt.Errorf("nack failed") + }, + } + + deliveryChan <- delivery + + // Wait for processing to complete + <-delivery.done + + // Verify error metrics are tracked + snapshot := scope.Snapshot() + counters := snapshot.Counters() + + // Should have handler_errors and nack_errors + var hasErrorMetrics bool + for _, counter := range counters { + if strings.Contains(counter.Name(), "errors") { + hasErrorMetrics = true + break + } + } + assert.True(t, hasErrorMetrics, "Should track error metrics") + + err = c.Stop(30000) + require.NoError(t, err) +} diff --git a/core/consumer/controller.go b/core/consumer/controller.go new file mode 100644 index 00000000..997ce33a --- /dev/null +++ b/core/consumer/controller.go @@ -0,0 +1,98 @@ +package consumer + +import ( + "context" + + "github.com/uber/submitqueue/entity/queue" + extqueue "github.com/uber/submitqueue/extension/queue" +) + +// Delivery is the consumer package's view of a queue delivery. +// It exists to hide Ack/Nack from controllers — the Consumer framework handles those +// automatically based on the error returned from Process(). Controllers only see +// message data, metadata, and ExtendVisibilityTimeout (a business-level concern for +// long-running processing). +// +// To signal outcome from Process(): +// - Return nil to ack the message (success). +// - Return an error to nack the message for retry. +// - Return NonRetryableError to ack a poison pill message (removes it from the queue). +type Delivery interface { + // Message returns the delivered message. + Message() queue.Message + + // ExtendVisibilityTimeout extends the time before this message becomes + // visible to other consumers. Use when processing takes longer than expected. + ExtendVisibilityTimeout(ctx context.Context, durationMillis int64) error + + // DeliveryID returns a backend-specific identifier for this delivery. + DeliveryID() string + + // Attempt returns how many times this message has been delivered. + // Starts at 1 for first delivery. + Attempt() int + + // ReceivedAt returns when this delivery was received (Unix milliseconds). + ReceivedAt() int64 + + // Metadata returns backend-specific delivery metadata. + Metadata() map[string]string +} + +// deliveryWrapper wraps extension/queue.Delivery and exposes only the safe subset of methods. +// Hides Ack/Nack from controllers - Consumer handles those automatically. +type deliveryWrapper struct { + delivery extqueue.Delivery +} + +func (d *deliveryWrapper) Message() queue.Message { + return d.delivery.Message() +} + +func (d *deliveryWrapper) ExtendVisibilityTimeout(ctx context.Context, durationMillis int64) error { + return d.delivery.ExtendVisibilityTimeout(ctx, durationMillis) +} + +func (d *deliveryWrapper) DeliveryID() string { + return d.delivery.DeliveryID() +} + +func (d *deliveryWrapper) Attempt() int { + return d.delivery.Attempt() +} + +func (d *deliveryWrapper) ReceivedAt() int64 { + return d.delivery.ReceivedAt() +} + +func (d *deliveryWrapper) Metadata() map[string]string { + return d.delivery.Metadata() +} + +// Controller processes queue deliveries. Controllers contain business logic and are registered with the Consumer. +// The Controller interface enables clean separation of concerns: +// - Controller focuses on business logic (deserialize, process, return error status) +// - Consumer handles infrastructure (subscription, ack/nack, metrics, lifecycle) +type Controller interface { + // Process processes a delivery. Controller receives consumer.Delivery (not extension/queue.Delivery) + // which prevents direct Ack/Nack calls - Consumer handles those automatically. + // Return nil to ack the message (success), error to nack and retry, + // or NonRetryableError to ack a poison pill message. + Process(ctx context.Context, delivery Delivery) error + + // Name returns the controller name for logging and metrics. + Name() string + + // Topic returns the topic this controller subscribes to. + Topic() string + + // ConsumerGroup returns the consumer group for offset tracking. + // Multiple controllers can share a consumer group to load-balance across workers. + // Different consumer groups consume independently. + ConsumerGroup() string + + // SubscriptionConfig returns the subscription config for this controller. + // Allows each controller to customize poll interval, batch size, timeouts, retry, DLQ. + // The subscriberName parameter is the unique worker identifier (hostname, pod name). + SubscriptionConfig(subscriberName string) extqueue.SubscriptionConfig +} diff --git a/core/consumer/error.go b/core/consumer/error.go new file mode 100644 index 00000000..81c9f983 --- /dev/null +++ b/core/consumer/error.go @@ -0,0 +1,36 @@ +package consumer + +import ( + "errors" + "fmt" +) + +// NonRetryableError indicates a poison pill message that should not be retried. +// When a controller returns this error, the consumer will ack the message (removing it +// from the queue) instead of nacking it for retry. Use this for permanently malformed +// messages that will never succeed regardless of retry count. +type NonRetryableError struct { + // Cause is the underlying error that caused the message to be non-retryable. + Cause error +} + +// NewNonRetryableError creates a new NonRetryableError wrapping the given cause. +func NewNonRetryableError(cause error) *NonRetryableError { + return &NonRetryableError{Cause: cause} +} + +// Error returns the error message. +func (e *NonRetryableError) Error() string { + return fmt.Sprintf("non-retryable: %v", e.Cause) +} + +// Unwrap returns the underlying cause for errors.Is/As compatibility. +func (e *NonRetryableError) Unwrap() error { + return e.Cause +} + +// IsNonRetryable checks if an error is or wraps a NonRetryableError. +func IsNonRetryable(err error) bool { + var target *NonRetryableError + return errors.As(err, &target) +} diff --git a/core/consumer/error_test.go b/core/consumer/error_test.go new file mode 100644 index 00000000..de59eeab --- /dev/null +++ b/core/consumer/error_test.go @@ -0,0 +1,45 @@ +package consumer + +import ( + "errors" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNonRetryableError(t *testing.T) { + cause := fmt.Errorf("bad payload") + err := NewNonRetryableError(cause) + + assert.Equal(t, "non-retryable: bad payload", err.Error()) + assert.True(t, IsNonRetryable(err)) +} + +func TestNonRetryableError_Unwrap(t *testing.T) { + cause := fmt.Errorf("deserialization failed") + err := NewNonRetryableError(cause) + + unwrapped := errors.Unwrap(err) + require.NotNil(t, unwrapped) + assert.Equal(t, cause, unwrapped) +} + +func TestIsNonRetryable_Wrapped(t *testing.T) { + cause := fmt.Errorf("bad json") + nonRetryable := NewNonRetryableError(cause) + wrapped := fmt.Errorf("controller error: %w", nonRetryable) + + assert.True(t, IsNonRetryable(wrapped)) +} + +func TestIsNonRetryable_RegularError(t *testing.T) { + err := fmt.Errorf("temporary failure") + + assert.False(t, IsNonRetryable(err)) +} + +func TestIsNonRetryable_Nil(t *testing.T) { + assert.False(t, IsNonRetryable(nil)) +} diff --git a/e2e_test/suite_test.go b/e2e_test/suite_test.go index 0ea54b33..863222c7 100644 --- a/e2e_test/suite_test.go +++ b/e2e_test/suite_test.go @@ -113,17 +113,3 @@ func (s *IntegrationSuite) TestPingSpeculator() { s.log.Logf("Speculator ping: %s", resp.Message) } -func (s *IntegrationSuite) TestLandRequest() { - ctx := context.Background() - req := &gatewaypb.LandRequest{ - Queue: "integration-test-queue", - Change: &gatewaypb.Change{Source: "github", Ids: []string{"pr-100", "pr-101"}}, - Strategy: gatewaypb.Strategy_REBASE, - } - - s.log.Logf("Sending Land request for queue=%s", req.Queue) - resp, err := s.gatewayClient.Land(ctx, req) - require.NoError(s.T(), err, "Land request failed") - require.NotEmpty(s.T(), resp.Sqid, "SQID should not be empty") - s.log.Logf("Land request succeeded: sqid=%s", resp.Sqid) -} diff --git a/entity/BUILD.bazel b/entity/BUILD.bazel index b9002664..9c98900b 100644 --- a/entity/BUILD.bazel +++ b/entity/BUILD.bazel @@ -1,4 +1,4 @@ -load("@rules_go//go:def.bzl", "go_library") +load("@rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "entity", @@ -6,3 +6,13 @@ go_library( importpath = "github.com/uber/submitqueue/entity", visibility = ["//visibility:public"], ) + +go_test( + name = "entity_test", + srcs = ["request_test.go"], + embed = [":entity"], + deps = [ + "@com_github_stretchr_testify//assert", + "@com_github_stretchr_testify//require", + ], +) diff --git a/entity/request.go b/entity/request.go index b0c7fe3c..0dd33f4d 100644 --- a/entity/request.go +++ b/entity/request.go @@ -1,5 +1,7 @@ package entity +import "encoding/json" + // RequestLandStrategy defines the possible source control integration methods. type RequestLandStrategy string @@ -35,9 +37,9 @@ const ( // The object is immutable after creation. type Change struct { // Source is the code change provider (e.g., "github", "gerrit", "phabricator"). - Source string + Source string `json:"source"` // IDs is a list of change IDs, in a format specific to the code change provider, that should be landed together. - IDs []string + IDs []string `json:"ids"` } // Request defines a request to land (merge into target branch of the source control repository) a set of code changes. @@ -48,21 +50,33 @@ type Request struct { // **************** // ID is the globally unique identifier for the land request. Format: "/". - ID string + ID string `json:"id"` // Queue is the name of the queue processing the land request. Queue name is defined in the configuration and should be unique within the system. - Queue string + Queue string `json:"queue"` // Change is a number of code changes (such as pull requests) to land into the target branch. Target branch is defined by the queue configuration. - Change Change + Change Change `json:"change"` // LandStrategy is the source control integration strategy to use for this land operation. - LandStrategy RequestLandStrategy + LandStrategy RequestLandStrategy `json:"land_strategy"` // **************** // Following fields could be changed throughout the lifecycle of the request // **************** // State is the current state of the land request. - State RequestState + State RequestState `json:"state"` // Version is the version of the object. It is used for optimistic locking. // Versioning starts at 1 and is incremented for each change to the object. - Version int32 + Version int32 `json:"version"` +} + +// ToBytes serializes the Request to JSON bytes for queue message payload. +func (r Request) ToBytes() ([]byte, error) { + return json.Marshal(r) +} + +// RequestFromBytes deserializes a Request from JSON bytes. +func RequestFromBytes(data []byte) (Request, error) { + var req Request + err := json.Unmarshal(data, &req) + return req, err } diff --git a/entity/request_test.go b/entity/request_test.go new file mode 100644 index 00000000..74bf2a00 --- /dev/null +++ b/entity/request_test.go @@ -0,0 +1,136 @@ +package entity + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRequest_ToBytes(t *testing.T) { + req := Request{ + ID: "test-queue/123", + Queue: "test-queue", + Change: Change{Source: "github", IDs: []string{"PR-456", "PR-789"}}, + LandStrategy: RequestLandStrategyRebase, + State: RequestStateNew, + Version: 1, + } + + data, err := req.ToBytes() + require.NoError(t, err) + assert.NotEmpty(t, data) + + // Verify JSON contains expected fields + jsonStr := string(data) + assert.Contains(t, jsonStr, "test-queue/123") + assert.Contains(t, jsonStr, "github") + assert.Contains(t, jsonStr, "PR-456") + assert.Contains(t, jsonStr, "rebase") + assert.Contains(t, jsonStr, "new") +} + +func TestRequestFromBytes(t *testing.T) { + original := Request{ + ID: "my-queue/999", + Queue: "my-queue", + Change: Change{Source: "gerrit", IDs: []string{"CL-111"}}, + LandStrategy: RequestLandStrategyMerge, + State: RequestStateProcessing, + Version: 3, + } + + // Serialize + data, err := original.ToBytes() + require.NoError(t, err) + + // Deserialize + deserialized, err := RequestFromBytes(data) + require.NoError(t, err) + + // Verify all fields match + assert.Equal(t, original.ID, deserialized.ID) + assert.Equal(t, original.Queue, deserialized.Queue) + assert.Equal(t, original.Change.Source, deserialized.Change.Source) + assert.Equal(t, original.Change.IDs, deserialized.Change.IDs) + assert.Equal(t, original.LandStrategy, deserialized.LandStrategy) + assert.Equal(t, original.State, deserialized.State) + assert.Equal(t, original.Version, deserialized.Version) +} + +func TestRequestFromBytes_InvalidJSON(t *testing.T) { + invalidJSON := []byte(`{"invalid": json"}`) + + _, err := RequestFromBytes(invalidJSON) + assert.Error(t, err) +} + +func TestRequestFromBytes_EmptyData(t *testing.T) { + emptyJSON := []byte(`{}`) + + req, err := RequestFromBytes(emptyJSON) + require.NoError(t, err) + + // Empty JSON should deserialize with zero values + assert.Empty(t, req.ID) + assert.Empty(t, req.Queue) + assert.Equal(t, RequestStateUnknown, req.State) + assert.Equal(t, RequestLandStrategyUnknown, req.LandStrategy) + assert.Equal(t, int32(0), req.Version) +} + +func TestRequest_SerializationRoundTrip(t *testing.T) { + tests := []struct { + name string + req Request + }{ + { + name: "full request", + req: Request{ + ID: "queue1/100", + Queue: "queue1", + Change: Change{Source: "github", IDs: []string{"PR-1", "PR-2", "PR-3"}}, + LandStrategy: RequestLandStrategySquashRebase, + State: RequestStateLanded, + Version: 5, + }, + }, + { + name: "minimal request", + req: Request{ + ID: "queue2/200", + Queue: "queue2", + Change: Change{Source: "phabricator", IDs: []string{"D123"}}, + LandStrategy: RequestLandStrategyRebase, + State: RequestStateNew, + Version: 1, + }, + }, + { + name: "error state request", + req: Request{ + ID: "queue3/300", + Queue: "queue3", + Change: Change{Source: "github", IDs: []string{"PR-999"}}, + LandStrategy: RequestLandStrategyMerge, + State: RequestStateError, + Version: 10, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Serialize + data, err := tt.req.ToBytes() + require.NoError(t, err) + + // Deserialize + deserialized, err := RequestFromBytes(data) + require.NoError(t, err) + + // Verify complete equality + assert.Equal(t, tt.req, deserialized) + }) + } +} diff --git a/example/server/gateway/BUILD.bazel b/example/server/gateway/BUILD.bazel index 94d2768a..7c5d03ff 100644 --- a/example/server/gateway/BUILD.bazel +++ b/example/server/gateway/BUILD.bazel @@ -7,6 +7,7 @@ go_library( visibility = ["//visibility:private"], deps = [ "//extension/counter/mysql", + "//extension/queue/sql", "//extension/storage/mysql", "//gateway/controller", "//gateway/protopb", diff --git a/example/server/gateway/main.go b/example/server/gateway/main.go index fa2d87af..0d21bd88 100644 --- a/example/server/gateway/main.go +++ b/example/server/gateway/main.go @@ -14,6 +14,7 @@ import ( _ "github.com/go-sql-driver/mysql" "github.com/uber-go/tally/v4" mysqlcounter "github.com/uber/submitqueue/extension/counter/mysql" + queueSQL "github.com/uber/submitqueue/extension/queue/sql" "github.com/uber/submitqueue/extension/storage/mysql" "github.com/uber/submitqueue/gateway/controller" pb "github.com/uber/submitqueue/gateway/protopb" @@ -109,16 +110,43 @@ func run() error { defer counterDB.Close() cnt := mysqlcounter.NewCounter(counterDB) + // Initialize queue + queueDSN := os.Getenv("QUEUE_MYSQL_DSN") + if queueDSN == "" { + return fmt.Errorf("QUEUE_MYSQL_DSN environment variable is required") + } + // Create gRPC server grpcServer := grpc.NewServer() // Create controllers and wrap them for gRPC pingController := controller.NewPingController(logger, scope) - landController := controller.NewLandController(logger, scope, store, cnt) + + queueDB, err := sql.Open("mysql", queueDSN) + if err != nil { + return fmt.Errorf("failed to open MySQL connection for queue: %w", err) + } + defer queueDB.Close() + + q, err := queueSQL.NewQueue(queueSQL.Params{ + DB: queueDB, + Logger: logger, + MetricsScope: scope.SubScope("queue"), + }) + if err != nil { + return fmt.Errorf("failed to create queue: %w", err) + } + defer q.Close() + + logger.Info("queue initialized", zap.String("dsn", queueDSN)) + + // Land controller requires queue publisher + landController := controller.NewLandController(logger.Sugar(), scope, store, cnt, q.Publisher()) gatewayServer := &GatewayServer{ pingController: pingController, landController: landController, } + pb.RegisterSubmitQueueGatewayServer(grpcServer, gatewayServer) // Register reflection service for debugging with grpcurl diff --git a/example/server/orchestrator/BUILD.bazel b/example/server/orchestrator/BUILD.bazel index 95f833e8..5c7bd78c 100644 --- a/example/server/orchestrator/BUILD.bazel +++ b/example/server/orchestrator/BUILD.bazel @@ -6,8 +6,12 @@ go_library( importpath = "github.com/uber/submitqueue/example/server/orchestrator", visibility = ["//visibility:private"], deps = [ + "//core/consumer", + "//extension/queue/sql", "//orchestrator/controller", + "//orchestrator/controller/request", "//orchestrator/protopb", + "@com_github_go_sql_driver_mysql//:mysql", "@com_github_uber_go_tally_v4//:tally", "@org_golang_google_grpc//:grpc", "@org_golang_google_grpc//reflection", diff --git a/example/server/orchestrator/main.go b/example/server/orchestrator/main.go index 588b9d60..2a1232a8 100644 --- a/example/server/orchestrator/main.go +++ b/example/server/orchestrator/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "database/sql" "fmt" "net" "os" @@ -10,8 +11,12 @@ import ( "syscall" "time" + _ "github.com/go-sql-driver/mysql" "github.com/uber-go/tally/v4" + "github.com/uber/submitqueue/core/consumer" + queueSQL "github.com/uber/submitqueue/extension/queue/sql" "github.com/uber/submitqueue/orchestrator/controller" + "github.com/uber/submitqueue/orchestrator/controller/request" pb "github.com/uber/submitqueue/orchestrator/protopb" "go.uber.org/zap" "google.golang.org/grpc" @@ -75,6 +80,57 @@ func run() error { metricsWgDone.Wait() }() + // Initialize queue (optional - only if QUEUE_MYSQL_DSN is provided) + // This allows the server to start without queue infrastructure for basic testing + queueDSN := os.Getenv("QUEUE_MYSQL_DSN") + var c consumer.Consumer + if queueDSN != "" { + queueDB, err := sql.Open("mysql", queueDSN) + if err != nil { + return fmt.Errorf("failed to open MySQL connection for queue: %w", err) + } + defer queueDB.Close() + + q, err := queueSQL.NewQueue(queueSQL.Params{ + DB: queueDB, + Logger: logger, + MetricsScope: scope.SubScope("queue"), + }) + if err != nil { + return fmt.Errorf("failed to create queue: %w", err) + } + defer q.Close() + + logger.Info("queue initialized", zap.String("dsn", queueDSN)) + + // Create consumer + subscriberName := os.Getenv("HOSTNAME") + if subscriberName == "" { + subscriberName = fmt.Sprintf("orchestrator-%d", time.Now().Unix()) + } + + c = consumer.New(logger.Sugar(), scope.SubScope("consumer"), q, subscriberName) + + // Register handlers for the pipeline + requestHandler := request.NewController(logger.Sugar(), scope) + if err := c.Register(requestHandler); err != nil { + return fmt.Errorf("failed to register request handler: %w", err) + } + + logger.Info("handlers registered", zap.Int("count", 1)) + + // Start consumers + ctx := context.Background() + + if err := c.Start(ctx); err != nil { + return fmt.Errorf("failed to start consumers: %w", err) + } + + logger.Info("consumer started") + } else { + logger.Warn("queue infrastructure disabled (QUEUE_MYSQL_DSN not set)") + } + // Create gRPC server grpcServer := grpc.NewServer() @@ -114,12 +170,22 @@ func run() error { select { case <-sigCh: fmt.Println("\nShutting down orchestrator server...") + if c != nil { + if err := c.Stop(30000); err != nil { + logger.Error("consumer stop error", zap.Error(err)) + } + } grpcServer.GracefulStop() _ = <-serverErrCh // Wait for the server to exit and ignore the error case errCh := <-serverErrCh: if errCh != nil { err = fmt.Errorf("\nServer exited with error: %w\n", errCh) } + if c != nil { + if stopErr := c.Stop(30000); stopErr != nil { + logger.Error("consumer stop error", zap.Error(stopErr)) + } + } } return err diff --git a/extension/queue/delivery.go b/extension/queue/delivery.go index 05faca8c..26d19679 100644 --- a/extension/queue/delivery.go +++ b/extension/queue/delivery.go @@ -1,5 +1,7 @@ package queue +//go:generate mockgen -source=delivery.go -destination=mock/delivery.go -package=mock + import ( "context" @@ -24,6 +26,12 @@ type Delivery interface { // If requeueAfterMillis is 0, the message is requeued immediately. Nack(ctx context.Context, requeueAfterMillis int64) error + // Reject moves the message to the dead letter queue. + // Use for poison pill messages that should never be retried. + // reason is stored as last_error in the DLQ for debugging. + // If DLQ is not configured, the message is acked (removed from queue). + Reject(ctx context.Context, reason string) error + // ExtendVisibilityTimeout extends the time before this message becomes // visible to other consumers. Use when processing takes longer than expected. ExtendVisibilityTimeout(ctx context.Context, durationMillis int64) error diff --git a/extension/queue/mock/BUILD.bazel b/extension/queue/mock/BUILD.bazel new file mode 100644 index 00000000..787b3047 --- /dev/null +++ b/extension/queue/mock/BUILD.bazel @@ -0,0 +1,12 @@ +load("@rules_go//go:def.bzl", "go_library") + +go_library( + name = "mock", + srcs = ["delivery.go"], + importpath = "github.com/uber/submitqueue/extension/queue/mock", + visibility = ["//visibility:public"], + deps = [ + "//entity/queue", + "@org_uber_go_mock//gomock", + ], +) diff --git a/extension/queue/mock/delivery.go b/extension/queue/mock/delivery.go new file mode 100644 index 00000000..16da8d26 --- /dev/null +++ b/extension/queue/mock/delivery.go @@ -0,0 +1,168 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: extension/queue/delivery.go +// +// Generated by this command: +// +// mockgen -source=extension/queue/delivery.go -destination=extension/queue/mock/delivery.go -package=mock +// + +// Package mock is a generated GoMock package. +package mock + +import ( + context "context" + reflect "reflect" + + queue "github.com/uber/submitqueue/entity/queue" + gomock "go.uber.org/mock/gomock" +) + +// MockDelivery is a mock of Delivery interface. +type MockDelivery struct { + ctrl *gomock.Controller + recorder *MockDeliveryMockRecorder + isgomock struct{} +} + +// MockDeliveryMockRecorder is the mock recorder for MockDelivery. +type MockDeliveryMockRecorder struct { + mock *MockDelivery +} + +// NewMockDelivery creates a new mock instance. +func NewMockDelivery(ctrl *gomock.Controller) *MockDelivery { + mock := &MockDelivery{ctrl: ctrl} + mock.recorder = &MockDeliveryMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockDelivery) EXPECT() *MockDeliveryMockRecorder { + return m.recorder +} + +// Ack mocks base method. +func (m *MockDelivery) Ack(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Ack", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// Ack indicates an expected call of Ack. +func (mr *MockDeliveryMockRecorder) Ack(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ack", reflect.TypeOf((*MockDelivery)(nil).Ack), ctx) +} + +// Attempt mocks base method. +func (m *MockDelivery) Attempt() int { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Attempt") + ret0, _ := ret[0].(int) + return ret0 +} + +// Attempt indicates an expected call of Attempt. +func (mr *MockDeliveryMockRecorder) Attempt() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Attempt", reflect.TypeOf((*MockDelivery)(nil).Attempt)) +} + +// DeliveryID mocks base method. +func (m *MockDelivery) DeliveryID() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeliveryID") + ret0, _ := ret[0].(string) + return ret0 +} + +// DeliveryID indicates an expected call of DeliveryID. +func (mr *MockDeliveryMockRecorder) DeliveryID() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeliveryID", reflect.TypeOf((*MockDelivery)(nil).DeliveryID)) +} + +// ExtendVisibilityTimeout mocks base method. +func (m *MockDelivery) ExtendVisibilityTimeout(ctx context.Context, durationMillis int64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ExtendVisibilityTimeout", ctx, durationMillis) + ret0, _ := ret[0].(error) + return ret0 +} + +// ExtendVisibilityTimeout indicates an expected call of ExtendVisibilityTimeout. +func (mr *MockDeliveryMockRecorder) ExtendVisibilityTimeout(ctx, durationMillis any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExtendVisibilityTimeout", reflect.TypeOf((*MockDelivery)(nil).ExtendVisibilityTimeout), ctx, durationMillis) +} + +// Message mocks base method. +func (m *MockDelivery) Message() queue.Message { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Message") + ret0, _ := ret[0].(queue.Message) + return ret0 +} + +// Message indicates an expected call of Message. +func (mr *MockDeliveryMockRecorder) Message() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Message", reflect.TypeOf((*MockDelivery)(nil).Message)) +} + +// Metadata mocks base method. +func (m *MockDelivery) Metadata() map[string]string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Metadata") + ret0, _ := ret[0].(map[string]string) + return ret0 +} + +// Metadata indicates an expected call of Metadata. +func (mr *MockDeliveryMockRecorder) Metadata() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Metadata", reflect.TypeOf((*MockDelivery)(nil).Metadata)) +} + +// Nack mocks base method. +func (m *MockDelivery) Nack(ctx context.Context, requeueAfterMillis int64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Nack", ctx, requeueAfterMillis) + ret0, _ := ret[0].(error) + return ret0 +} + +// Nack indicates an expected call of Nack. +func (mr *MockDeliveryMockRecorder) Nack(ctx, requeueAfterMillis any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Nack", reflect.TypeOf((*MockDelivery)(nil).Nack), ctx, requeueAfterMillis) +} + +// Reject mocks base method. +func (m *MockDelivery) Reject(ctx context.Context, reason string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Reject", ctx, reason) + ret0, _ := ret[0].(error) + return ret0 +} + +// Reject indicates an expected call of Reject. +func (mr *MockDeliveryMockRecorder) Reject(ctx, reason any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Reject", reflect.TypeOf((*MockDelivery)(nil).Reject), ctx, reason) +} + +// ReceivedAt mocks base method. +func (m *MockDelivery) ReceivedAt() int64 { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReceivedAt") + ret0, _ := ret[0].(int64) + return ret0 +} + +// ReceivedAt indicates an expected call of ReceivedAt. +func (mr *MockDeliveryMockRecorder) ReceivedAt() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedAt", reflect.TypeOf((*MockDelivery)(nil).ReceivedAt)) +} diff --git a/extension/queue/sql/subscriber.go b/extension/queue/sql/subscriber.go index 7ab82221..c9f4b807 100644 --- a/extension/queue/sql/subscriber.go +++ b/extension/queue/sql/subscriber.go @@ -52,6 +52,9 @@ type sqlDelivery struct { messageID string consumerGroup string + // DLQ configuration for Reject + dlqConfig extqueue.DLQConfig + // Track acknowledgment state mu sync.Mutex acknowledged bool @@ -68,6 +71,7 @@ func newSQLDelivery( offset int64, messageID string, consumerGroup string, + dlqConfig extqueue.DLQConfig, ) *sqlDelivery { return &sqlDelivery{ msg: msg, @@ -81,6 +85,7 @@ func newSQLDelivery( offset: offset, messageID: messageID, consumerGroup: consumerGroup, + dlqConfig: dlqConfig, acknowledged: false, } } @@ -171,6 +176,57 @@ func (d *sqlDelivery) Nack(ctx context.Context, requeueAfterMillis int64) error return nil } +// Reject implements extqueue.Delivery.Reject +func (d *sqlDelivery) Reject(ctx context.Context, reason string) error { + d.mu.Lock() + defer d.mu.Unlock() + + if d.acknowledged { + return &ErrAlreadyAcknowledged{DeliveryID: d.deliveryID} + } + + if d.dlqConfig.Enabled { + // Move to DLQ + if err := d.subscriber.messageStore.MoveToDLQ( + ctx, d.topic, d.messageID, d.attempt, reason, d.dlqConfig.TopicSuffix, + ); err != nil { + return fmt.Errorf("failed to move message to DLQ: %w", err) + } + + // Update offset tracking + if err := d.subscriber.offsetStore.UpdateAckedOffset( + ctx, d.topic, d.partitionKey, d.offset, d.consumerGroup, + ); err != nil { + // Log but don't fail — message is already in DLQ + d.subscriber.logger.Errorw("failed to update offset after DLQ move", + "topic", d.topic, + "message_id", d.messageID, + "error", err, + ) + } + + d.subscriber.metrics.Tagged(map[string]string{ + "topic": d.topic, + "partition_key": d.partitionKey, + }).Counter("messages_rejected_to_dlq").Inc(1) + } else { + // DLQ disabled — fall back to ack (remove from queue) + if err := d.subscriber.offsetStore.AckMessage( + ctx, d.topic, d.partitionKey, d.messageID, d.offset, d.consumerGroup, d.subscriber.messageStore, + ); err != nil { + return err + } + + d.subscriber.metrics.Tagged(map[string]string{ + "topic": d.topic, + "partition_key": d.partitionKey, + }).Counter("messages_rejected_no_dlq").Inc(1) + } + + d.acknowledged = true + return nil +} + // ExtendVisibilityTimeout implements extqueue.Delivery.ExtendVisibilityTimeout func (d *sqlDelivery) ExtendVisibilityTimeout(ctx context.Context, durationMillis int64) error { d.mu.Lock() @@ -505,6 +561,7 @@ func (s *subscriber) fetchAndDeliverPartition(ctx context.Context, sub *subscrip row.Offset, row.ID, cfg.ConsumerGroup, + cfg.DLQ, ) // Deliver message diff --git a/extension/queue/sql/subscriber_test.go b/extension/queue/sql/subscriber_test.go index d2c4eff6..5d81107e 100644 --- a/extension/queue/sql/subscriber_test.go +++ b/extension/queue/sql/subscriber_test.go @@ -2,6 +2,7 @@ package sql import ( "context" + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -10,6 +11,7 @@ import ( "go.uber.org/mock/gomock" "go.uber.org/zap/zaptest" + "github.com/uber/submitqueue/entity/queue" extqueue "github.com/uber/submitqueue/extension/queue" ) @@ -75,6 +77,112 @@ func TestSubscriber_Subscribe(t *testing.T) { } } +func TestSQLDelivery_Reject(t *testing.T) { + tests := []struct { + name string + dlqEnabled bool + alreadyAcked bool + moveToDLQErr error + ackMessageErr error + expectErr bool + expectMoveDLQ bool + expectAck bool + }{ + { + name: "DLQ enabled moves message to DLQ", + dlqEnabled: true, + expectMoveDLQ: true, + }, + { + name: "DLQ disabled falls back to ack", + expectAck: true, + }, + { + name: "already acknowledged returns error", + dlqEnabled: true, + alreadyAcked: true, + expectErr: true, + }, + { + name: "DLQ enabled but MoveToDLQ fails", + dlqEnabled: true, + moveToDLQErr: fmt.Errorf("db error"), + expectErr: true, + expectMoveDLQ: true, + }, + { + name: "DLQ disabled but AckMessage fails", + ackMessageErr: fmt.Errorf("db error"), + expectErr: true, + expectAck: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockMsgStore := NewMockmessageStore(ctrl) + mockOffStore := NewMockoffsetStore(ctrl) + mockLeaseStore := NewMockpartitionLeaseStore(ctrl) + + sub := NewSubscriber( + zaptest.NewLogger(t).Sugar(), + tally.NoopScope, + mockMsgStore, + mockOffStore, + mockLeaseStore, + ) + + msg := queue.NewMessage("msg-1", []byte("payload"), "part-1", nil) + dlqConfig := extqueue.DLQConfig{ + Enabled: tt.dlqEnabled, + TopicSuffix: "_dlq", + } + + d := newSQLDelivery( + msg, "1", 1, nil, + sub, "test_topic", "part-1", 100, "msg-1", "test-group", + dlqConfig, + ) + + if tt.alreadyAcked { + d.acknowledged = true + } + + if tt.expectMoveDLQ { + mockMsgStore.EXPECT().MoveToDLQ( + gomock.Any(), "test_topic", "msg-1", 1, "bad payload", "_dlq", + ).Return(tt.moveToDLQErr) + + if tt.moveToDLQErr == nil { + mockOffStore.EXPECT().UpdateAckedOffset( + gomock.Any(), "test_topic", "part-1", int64(100), "test-group", + ).Return(nil) + } + } + + if tt.expectAck { + mockOffStore.EXPECT().AckMessage( + gomock.Any(), "test_topic", "part-1", "msg-1", int64(100), "test-group", mockMsgStore, + ).Return(tt.ackMessageErr) + } + + err := d.Reject(context.Background(), "bad payload") + + if tt.expectErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + assert.True(t, d.acknowledged) + }) + } +} + +// TestSubscriber_Close tests subscriber close behavior func TestSubscriber_Close(t *testing.T) { tests := []struct { name string diff --git a/gateway/controller/BUILD.bazel b/gateway/controller/BUILD.bazel index 6493d9cd..808b81b3 100644 --- a/gateway/controller/BUILD.bazel +++ b/gateway/controller/BUILD.bazel @@ -10,7 +10,9 @@ go_library( visibility = ["//visibility:public"], deps = [ "//entity", + "//entity/queue", "//extension/counter", + "//extension/queue", "//extension/storage", "//gateway/protopb", "@com_github_uber_go_tally_v4//:tally", @@ -27,6 +29,7 @@ go_test( embed = [":controller"], deps = [ "//entity", + "//entity/queue", "//extension/storage", "//gateway/protopb", "@com_github_stretchr_testify//assert", diff --git a/gateway/controller/land.go b/gateway/controller/land.go index cf30fa91..d2b59718 100644 --- a/gateway/controller/land.go +++ b/gateway/controller/land.go @@ -8,7 +8,9 @@ import ( "github.com/uber-go/tally/v4" "github.com/uber/submitqueue/entity" + "github.com/uber/submitqueue/entity/queue" "github.com/uber/submitqueue/extension/counter" + extqueue "github.com/uber/submitqueue/extension/queue" "github.com/uber/submitqueue/extension/storage" pb "github.com/uber/submitqueue/gateway/protopb" "go.uber.org/zap" @@ -25,19 +27,21 @@ func IsInvalidRequest(err error) bool { // LandController handles land business logic for the gateway type LandController struct { - logger *zap.Logger + logger *zap.SugaredLogger metricsScope tally.Scope store storage.Storage counter counter.Counter + publisher extqueue.Publisher } -// NewLandController creates a new instance of the gateway land controller -func NewLandController(logger *zap.Logger, scope tally.Scope, store storage.Storage, counter counter.Counter) *LandController { +// NewLandController creates a new instance of the gateway land controller. +func NewLandController(logger *zap.SugaredLogger, scope tally.Scope, store storage.Storage, counter counter.Counter, publisher extqueue.Publisher) *LandController { return &LandController{ logger: logger, metricsScope: scope, store: store, counter: counter, + publisher: publisher, } } @@ -94,16 +98,58 @@ func (c *LandController) Land(ctx context.Context, req *pb.LandRequest) (*pb.Lan return nil, fmt.Errorf("LandController failed to create request for queue=%s: %w", req.Queue, err) } - c.logger.Debug("land request received", - zap.String("queue", req.Queue), - zap.String("sqid", request.ID), + c.logger.Debugw("land request created", + "queue", req.Queue, + "sqid", request.ID, + "change_source", change.Source, + "change_ids", change.IDs, + "strategy", string(strategy), ) + // Publish to queue for async processing + if err := c.publishToQueue(ctx, request); err != nil { + c.logger.Errorw("failed to publish request to queue", + "queue", req.Queue, + "sqid", request.ID, + "error", err, + ) + return nil, fmt.Errorf("LandController failed to publish request to queue: %w", err) + } + + c.logger.Infow("request published to queue", + "queue", req.Queue, + "sqid", request.ID, + "topic", "request", + ) + c.metricsScope.Counter("publish_success").Inc(1) + return &pb.LandResponse{ Sqid: request.ID, }, nil } +// publishToQueue publishes a request to the request queue for async processing. +func (c *LandController) publishToQueue(ctx context.Context, request entity.Request) error { + // Serialize request entity to JSON + payload, err := request.ToBytes() + if err != nil { + return fmt.Errorf("failed to serialize request: %w", err) + } + + // Create queue message + // - Message ID: request.ID for idempotency + // - Payload: serialized Request entity + // - Partition key: request.Queue (ensures ordering per queue) + msg := queue.NewMessage(request.ID, payload, request.Queue, nil) + + // Publish to request topic + if err := c.publisher.Publish(ctx, "request", msg); err != nil { + return fmt.Errorf("failed to publish message: %w", err) + } + + return nil +} + // protoStrategyToEntity maps a proto Strategy enum to the entity RequestLandStrategy. func resolveRequestLandStrategy(s pb.Strategy) (entity.RequestLandStrategy, error) { switch s { diff --git a/gateway/controller/land_test.go b/gateway/controller/land_test.go index 6a598c32..581bbaf1 100644 --- a/gateway/controller/land_test.go +++ b/gateway/controller/land_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" "github.com/uber-go/tally/v4" "github.com/uber/submitqueue/entity" + "github.com/uber/submitqueue/entity/queue" "github.com/uber/submitqueue/extension/storage" pb "github.com/uber/submitqueue/gateway/protopb" "go.uber.org/zap" @@ -50,6 +51,25 @@ func (m *mockStorage) Close() error { return nil } +type mockPublisher struct { + publishFunc func(ctx context.Context, topic string, msg queue.Message) error +} + +func (m *mockPublisher) Publish(ctx context.Context, topic string, msg queue.Message) error { + return m.publishFunc(ctx, topic, msg) +} + +func (m *mockPublisher) Close() error { + return nil +} + +// noopPublisher returns a mock publisher that succeeds silently. +func noopPublisher() *mockPublisher { + return &mockPublisher{publishFunc: func(ctx context.Context, topic string, msg queue.Message) error { + return nil + }} +} + func TestNewLandController(t *testing.T) { store := &mockStorage{requestStore: &mockRequestStore{ createFunc: func(ctx context.Context, request entity.Request) error { @@ -59,7 +79,7 @@ func TestNewLandController(t *testing.T) { cnt := &mockCounter{nextFunc: func(ctx context.Context, domain string) (int64, error) { return 1, nil }} - controller := NewLandController(zap.NewNop(), tally.NoopScope, store, cnt) + controller := NewLandController(zap.NewNop().Sugar(), tally.NoopScope, store, cnt, noopPublisher()) require.NotNil(t, controller) } @@ -72,7 +92,7 @@ func TestLand_ReturnsSqid(t *testing.T) { cnt := &mockCounter{nextFunc: func(ctx context.Context, domain string) (int64, error) { return 1, nil }} - controller := NewLandController(zap.NewNop(), tally.NoopScope, store, cnt) + controller := NewLandController(zap.NewNop().Sugar(), tally.NoopScope, store, cnt, noopPublisher()) ctx := context.Background() req := &pb.LandRequest{ @@ -97,7 +117,7 @@ func TestLand_PassesCorrectParametersToStore(t *testing.T) { cnt := &mockCounter{nextFunc: func(ctx context.Context, domain string) (int64, error) { return 42, nil }} - controller := NewLandController(zap.NewNop(), tally.NoopScope, store, cnt) + controller := NewLandController(zap.NewNop().Sugar(), tally.NoopScope, store, cnt, noopPublisher()) ctx := context.Background() req := &pb.LandRequest{ @@ -127,7 +147,7 @@ func TestLand_ReturnsErrorOnStorageFailure(t *testing.T) { cnt := &mockCounter{nextFunc: func(ctx context.Context, domain string) (int64, error) { return 1, nil }} - controller := NewLandController(zap.NewNop(), tally.NoopScope, store, cnt) + controller := NewLandController(zap.NewNop().Sugar(), tally.NoopScope, store, cnt, noopPublisher()) ctx := context.Background() req := &pb.LandRequest{ @@ -148,7 +168,7 @@ func TestLand_ReturnsErrorOnCounterFailure(t *testing.T) { cnt := &mockCounter{nextFunc: func(ctx context.Context, domain string) (int64, error) { return 0, fmt.Errorf("counter unavailable") }} - controller := NewLandController(zap.NewNop(), tally.NoopScope, store, cnt) + controller := NewLandController(zap.NewNop().Sugar(), tally.NoopScope, store, cnt, noopPublisher()) ctx := context.Background() req := &pb.LandRequest{ @@ -172,7 +192,7 @@ func TestLand_CounterDomainIncludesQueue(t *testing.T) { capturedDomain = domain return 1, nil }} - controller := NewLandController(zap.NewNop(), tally.NoopScope, store, cnt) + controller := NewLandController(zap.NewNop().Sugar(), tally.NoopScope, store, cnt, noopPublisher()) ctx := context.Background() req := &pb.LandRequest{ @@ -194,7 +214,7 @@ func TestLand_ReturnsErrorOnEmptyQueue(t *testing.T) { cnt := &mockCounter{nextFunc: func(ctx context.Context, domain string) (int64, error) { return 1, nil }} - controller := NewLandController(zap.NewNop(), tally.NoopScope, store, cnt) + controller := NewLandController(zap.NewNop().Sugar(), tally.NoopScope, store, cnt, noopPublisher()) ctx := context.Background() req := &pb.LandRequest{ @@ -216,7 +236,7 @@ func TestLand_ReturnsErrorOnEmptyChangeSource(t *testing.T) { cnt := &mockCounter{nextFunc: func(ctx context.Context, domain string) (int64, error) { return 1, nil }} - controller := NewLandController(zap.NewNop(), tally.NoopScope, store, cnt) + controller := NewLandController(zap.NewNop().Sugar(), tally.NoopScope, store, cnt, noopPublisher()) ctx := context.Background() req := &pb.LandRequest{ @@ -238,7 +258,7 @@ func TestLand_ReturnsErrorOnNilChange(t *testing.T) { cnt := &mockCounter{nextFunc: func(ctx context.Context, domain string) (int64, error) { return 1, nil }} - controller := NewLandController(zap.NewNop(), tally.NoopScope, store, cnt) + controller := NewLandController(zap.NewNop().Sugar(), tally.NoopScope, store, cnt, noopPublisher()) ctx := context.Background() req := &pb.LandRequest{ @@ -260,7 +280,7 @@ func TestLand_ReturnsErrorOnEmptyChangeIDs(t *testing.T) { cnt := &mockCounter{nextFunc: func(ctx context.Context, domain string) (int64, error) { return 1, nil }} - controller := NewLandController(zap.NewNop(), tally.NoopScope, store, cnt) + controller := NewLandController(zap.NewNop().Sugar(), tally.NoopScope, store, cnt, noopPublisher()) ctx := context.Background() req := &pb.LandRequest{ @@ -272,3 +292,77 @@ func TestLand_ReturnsErrorOnEmptyChangeIDs(t *testing.T) { require.Error(t, err) assert.True(t, IsInvalidRequest(err)) } + +func TestLand_PublishesToQueue(t *testing.T) { + var publishedTopic string + var publishedMessage queue.Message + + store := &mockStorage{requestStore: &mockRequestStore{ + createFunc: func(ctx context.Context, request entity.Request) error { + return nil + }, + }} + cnt := &mockCounter{nextFunc: func(ctx context.Context, domain string) (int64, error) { + return 123, nil + }} + publisher := &mockPublisher{publishFunc: func(ctx context.Context, topic string, msg queue.Message) error { + publishedTopic = topic + publishedMessage = msg + return nil + }} + + controller := NewLandController(zap.NewNop().Sugar(), tally.NoopScope, store, cnt, publisher) + ctx := context.Background() + + req := &pb.LandRequest{ + Queue: "test-queue", + Change: &pb.Change{Source: "github", Ids: []string{"PR-456"}}, + Strategy: pb.Strategy_REBASE, + } + resp, err := controller.Land(ctx, req) + + require.NoError(t, err) + assert.Equal(t, "test-queue/123", resp.Sqid) + + // Verify message was published + assert.Equal(t, "request", publishedTopic) + assert.Equal(t, "test-queue/123", publishedMessage.ID) + assert.Equal(t, "test-queue", publishedMessage.PartitionKey) + + // Verify payload can be deserialized + deserializedReq, err := entity.RequestFromBytes(publishedMessage.Payload) + require.NoError(t, err) + assert.Equal(t, "test-queue/123", deserializedReq.ID) + assert.Equal(t, "test-queue", deserializedReq.Queue) + assert.Equal(t, "github", deserializedReq.Change.Source) + assert.Equal(t, []string{"PR-456"}, deserializedReq.Change.IDs) + assert.Equal(t, entity.RequestLandStrategyRebase, deserializedReq.LandStrategy) + assert.Equal(t, entity.RequestStateNew, deserializedReq.State) + assert.Equal(t, int32(1), deserializedReq.Version) +} + +func TestLand_ContinuesWhenPublishFails(t *testing.T) { + store := &mockStorage{requestStore: &mockRequestStore{ + createFunc: func(ctx context.Context, request entity.Request) error { + return nil + }, + }} + cnt := &mockCounter{nextFunc: func(ctx context.Context, domain string) (int64, error) { + return 999, nil + }} + publisher := &mockPublisher{publishFunc: func(ctx context.Context, topic string, msg queue.Message) error { + return fmt.Errorf("queue unavailable") + }} + + controller := NewLandController(zap.NewNop().Sugar(), tally.NoopScope, store, cnt, publisher) + ctx := context.Background() + + req := &pb.LandRequest{ + Queue: "test-queue", + Change: &pb.Change{Source: "github", Ids: []string{"PR-1"}}, + } + _, err := controller.Land(ctx, req) + + // Should fail if publish fails + require.Error(t, err) +} diff --git a/orchestrator/controller/request/BUILD.bazel b/orchestrator/controller/request/BUILD.bazel new file mode 100644 index 00000000..7f04bc5e --- /dev/null +++ b/orchestrator/controller/request/BUILD.bazel @@ -0,0 +1,33 @@ +load("@rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "request", + srcs = ["request.go"], + importpath = "github.com/uber/submitqueue/orchestrator/controller/request", + visibility = ["//visibility:public"], + deps = [ + "//core/consumer", + "//entity", + "//extension/queue", + "@com_github_uber_go_tally_v4//:tally", + "@org_uber_go_zap//:zap", + ], +) + +go_test( + name = "request_test", + srcs = ["request_test.go"], + embed = [":request"], + deps = [ + "//core/consumer", + "//entity", + "//entity/queue", + "//extension/queue", + "//extension/queue/mock", + "@com_github_stretchr_testify//assert", + "@com_github_stretchr_testify//require", + "@com_github_uber_go_tally_v4//:tally", + "@org_uber_go_mock//gomock", + "@org_uber_go_zap//zaptest", + ], +) diff --git a/orchestrator/controller/request/request.go b/orchestrator/controller/request/request.go new file mode 100644 index 00000000..339b8975 --- /dev/null +++ b/orchestrator/controller/request/request.go @@ -0,0 +1,111 @@ +package request + +import ( + "context" + "fmt" + + "github.com/uber-go/tally/v4" + "github.com/uber/submitqueue/core/consumer" + "github.com/uber/submitqueue/entity" + extqueue "github.com/uber/submitqueue/extension/queue" + "go.uber.org/zap" +) + +// Controller handles request queue messages. +// Implements consumer.Controller interface for integration with the consumer. +type Controller struct { + logger *zap.SugaredLogger + metricsScope tally.Scope + topic string + consumerGroup string +} + +// Verify Controller implements consumer.Controller interface at compile time. +var _ consumer.Controller = (*Controller)(nil) + +// NewController creates a new request controller for the orchestrator. +func NewController(logger *zap.SugaredLogger, scope tally.Scope) *Controller { + return &Controller{ + logger: logger.Named("request_controller"), + metricsScope: scope.SubScope("request_controller"), + topic: "request", + consumerGroup: "orchestrator-request", + } +} + +// Process processes a request delivery from the queue. +// Deserializes the request, logs the event, and prepares for future state transitions. +// Returns nil to ack (success), or error to nack (retry). +func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) error { + c.metricsScope.Counter("received").Inc(1) + + msg := delivery.Message() + + // Deserialize request entity + request, err := entity.RequestFromBytes(msg.Payload) + if err != nil { + c.logger.Errorw("failed to deserialize request", + "message_id", msg.ID, + "partition_key", msg.PartitionKey, + "attempt", delivery.Attempt(), + "error", err, + ) + c.metricsScope.Counter("deserialize_errors").Inc(1) + // Non-retryable: malformed messages will never succeed regardless of retry count + return consumer.NewNonRetryableError(fmt.Errorf("failed to deserialize request: %w", err)) + } + + c.logger.Infow("received land request event", + "request_id", request.ID, + "queue", request.Queue, + "state", string(request.State), + "land_strategy", string(request.LandStrategy), + "change_source", request.Change.Source, + "change_ids", request.Change.IDs, + "version", request.Version, + "attempt", delivery.Attempt(), + "partition_key", msg.PartitionKey, + ) + + // TODO: Update request state to processing + // TODO: Perform validation checks + // TODO: Publish to next queue (requests_for_batching) + + c.metricsScope.Counter("processed").Inc(1) + + return nil // Success - message will be acked +} + +// Name returns the controller name for logging and metrics. +func (c *Controller) Name() string { + return "request" +} + +// Topic returns the topic this controller subscribes to. +func (c *Controller) Topic() string { + return c.topic +} + +// ConsumerGroup returns the consumer group for offset tracking. +func (c *Controller) ConsumerGroup() string { + return c.consumerGroup +} + +// SubscriptionConfig returns the subscription config for the request controller. +// Uses default settings which work well for request processing (100ms poll, 60s visibility timeout). +func (c *Controller) SubscriptionConfig(subscriberName string) extqueue.SubscriptionConfig { + config := extqueue.DefaultSubscriptionConfig(subscriberName, c.consumerGroup) + + // Request controller uses default settings: + // - PollInterval: 100ms (fast polling for immediate request processing) + // - BatchSize: 10 + // - VisibilityTimeout: 60s + // - Retry: 3 attempts with exponential backoff + // - DLQ: enabled + + // Can customize if needed: + // config.PollInterval = 50 * time.Millisecond // Even faster polling + // config.BatchSize = 20 // Process more requests at once + + return config +} diff --git a/orchestrator/controller/request/request_test.go b/orchestrator/controller/request/request_test.go new file mode 100644 index 00000000..d22d7ec0 --- /dev/null +++ b/orchestrator/controller/request/request_test.go @@ -0,0 +1,201 @@ +package request + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/uber-go/tally/v4" + "github.com/uber/submitqueue/core/consumer" + "github.com/uber/submitqueue/entity" + "github.com/uber/submitqueue/entity/queue" + extqueue "github.com/uber/submitqueue/extension/queue" + "github.com/uber/submitqueue/extension/queue/mock" + "go.uber.org/mock/gomock" + "go.uber.org/zap/zaptest" +) + +func TestNewController(t *testing.T) { + logger := zaptest.NewLogger(t).Sugar() + scope := tally.NoopScope + + controller := NewController(logger, scope) + + require.NotNil(t, controller) + assert.Equal(t, "request", controller.Topic()) + assert.Equal(t, "orchestrator-request", controller.ConsumerGroup()) + assert.Equal(t, "request", controller.Name()) +} + +func TestController_Process_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + logger := zaptest.NewLogger(t).Sugar() + scope := tally.NoopScope + controller := NewController(logger, scope) + + // Create a valid request + request := entity.Request{ + ID: "test-queue/123", + Queue: "test-queue", + Change: entity.Change{Source: "github", IDs: []string{"PR-456"}}, + LandStrategy: entity.RequestLandStrategyRebase, + State: entity.RequestStateNew, + Version: 1, + } + + // Serialize to bytes + payload, err := request.ToBytes() + require.NoError(t, err) + + // Create delivery with mock + msg := queue.NewMessage("test-queue/123", payload, "test-queue", nil) + delivery := mock.NewMockDelivery(ctrl) + delivery.EXPECT().Message().Return(msg).AnyTimes() + delivery.EXPECT().Attempt().Return(1).AnyTimes() + + // Handle the delivery + ctx := context.Background() + err = controller.Process(ctx, delivery) + + // Should return nil (success) + require.NoError(t, err) +} + +func TestController_Process_InvalidJSON(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + logger := zaptest.NewLogger(t).Sugar() + scope := tally.NoopScope + controller := NewController(logger, scope) + + // Create delivery with invalid JSON + invalidPayload := []byte(`{"invalid": json"}`) + msg := queue.NewMessage("invalid-msg", invalidPayload, "partition1", nil) + delivery := mock.NewMockDelivery(ctrl) + delivery.EXPECT().Message().Return(msg).AnyTimes() + delivery.EXPECT().Attempt().Return(1).AnyTimes() + + // Process the delivery + ctx := context.Background() + err := controller.Process(ctx, delivery) + + // Should return NonRetryableError for malformed messages + require.Error(t, err) + assert.True(t, consumer.IsNonRetryable(err)) +} + +func TestController_Process_AllRequestStates(t *testing.T) { + tests := []struct { + name string + state entity.RequestState + strategy entity.RequestLandStrategy + }{ + {"new request", entity.RequestStateNew, entity.RequestLandStrategyRebase}, + {"processing request", entity.RequestStateProcessing, entity.RequestLandStrategySquashRebase}, + {"landed request", entity.RequestStateLanded, entity.RequestLandStrategyMerge}, + {"error request", entity.RequestStateError, entity.RequestLandStrategyRebase}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + logger := zaptest.NewLogger(t).Sugar() + scope := tally.NoopScope + controller := NewController(logger, scope) + + request := entity.Request{ + ID: fmt.Sprintf("queue/%s", tt.state), + Queue: "test-queue", + Change: entity.Change{Source: "github", IDs: []string{"PR-1"}}, + LandStrategy: tt.strategy, + State: tt.state, + Version: 1, + } + + payload, err := request.ToBytes() + require.NoError(t, err) + + msg := queue.NewMessage(request.ID, payload, request.Queue, nil) + delivery := mock.NewMockDelivery(ctrl) + delivery.EXPECT().Message().Return(msg).AnyTimes() + delivery.EXPECT().Attempt().Return(1).AnyTimes() + + ctx := context.Background() + err = controller.Process(ctx, delivery) + + require.NoError(t, err) + }) + } +} + +func TestController_Process_MultipleChanges(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + logger := zaptest.NewLogger(t).Sugar() + scope := tally.NoopScope + controller := NewController(logger, scope) + + request := entity.Request{ + ID: "queue/999", + Queue: "test-queue", + Change: entity.Change{ + Source: "github", + IDs: []string{"PR-1", "PR-2", "PR-3"}, // Multiple PRs + }, + LandStrategy: entity.RequestLandStrategySquashRebase, + State: entity.RequestStateNew, + Version: 1, + } + + payload, err := request.ToBytes() + require.NoError(t, err) + + msg := queue.NewMessage(request.ID, payload, request.Queue, nil) + delivery := mock.NewMockDelivery(ctrl) + delivery.EXPECT().Message().Return(msg).AnyTimes() + delivery.EXPECT().Attempt().Return(1).AnyTimes() + + ctx := context.Background() + err = controller.Process(ctx, delivery) + + require.NoError(t, err) +} + +func TestController_SubscriptionConfig(t *testing.T) { + logger := zaptest.NewLogger(t).Sugar() + scope := tally.NoopScope + controller := NewController(logger, scope) + + config := controller.SubscriptionConfig("test-worker-123") + + assert.Equal(t, "test-worker-123", config.SubscriberName) + assert.Equal(t, "orchestrator-request", config.ConsumerGroup) + assert.Equal(t, int64(100), config.PollIntervalMs) // 100ms + assert.Equal(t, 10, config.BatchSize) + assert.Equal(t, int64(60000), config.VisibilityTimeoutMs) // 60s + assert.Equal(t, 3, config.Retry.MaxAttempts) + assert.True(t, config.DLQ.Enabled) +} + +func TestController_InterfaceImplementation(t *testing.T) { + logger := zaptest.NewLogger(t).Sugar() + scope := tally.NoopScope + controller := NewController(logger, scope) + + // Verify implements consumer.Controller interface + var _ interface { + Process(ctx context.Context, delivery consumer.Delivery) error + Name() string + Topic() string + ConsumerGroup() string + SubscriptionConfig(subscriberName string) extqueue.SubscriptionConfig + } = controller +}