diff --git a/integration_tests/BUILD.bazel b/integration_tests/BUILD.bazel index 8e0a1e40..12ef2e5b 100644 --- a/integration_tests/BUILD.bazel +++ b/integration_tests/BUILD.bazel @@ -5,7 +5,6 @@ load("@rules_go//go:def.bzl", "go_test") go_test( name = "integration_test", srcs = [ - "mysql.go", "servers.go", "suite_test.go", ], @@ -18,6 +17,7 @@ go_test( ], tags = ["integration"], deps = [ + "//integration_tests/testutil", "//gateway/protopb", "//orchestrator/protopb", "//speculator/protopb", diff --git a/integration_tests/mysql.go b/integration_tests/mysql.go deleted file mode 100644 index b1e2933b..00000000 --- a/integration_tests/mysql.go +++ /dev/null @@ -1,123 +0,0 @@ -package integration_tests - -import ( - "context" - "database/sql" - "os" - "path/filepath" - "sort" - "testing" - "time" - - _ "github.com/go-sql-driver/mysql" - "github.com/stretchr/testify/require" - "github.com/testcontainers/testcontainers-go" - "github.com/testcontainers/testcontainers-go/modules/mysql" - "github.com/testcontainers/testcontainers-go/network" -) - -// testLogger is a simple test-aware logger that records elapsed time between logs. -type testLogger struct { - t *testing.T // The testing object to report logs to. - last time.Time // Timestamp of the last log, for elapsed calculation. -} - -// newTestLogger creates a testLogger for the current test. -func newTestLogger(t *testing.T) *testLogger { - t.Helper() - return &testLogger{t: t} -} - -// logf prints a formatted log message with timestamp and elapsed time since last log. -func (l *testLogger) logf(format string, args ...any) { - l.t.Helper() - now := time.Now() - delta := "" - if !l.last.IsZero() { - delta = " +" + now.Sub(l.last).Truncate(time.Millisecond).String() - } - l.last = now - l.t.Logf("[%s%s] "+format, append([]any{now.Format(time.RFC3339Nano), delta}, args...)...) -} - -// schemaDirs returns the paths to all schema directories. -// It checks for both Bazel runfiles and direct go test paths. -func schemaDirs() []string { - dirs := []string{ - "extensions/storage/mysql/schema", - "extensions/counter/mysql/schema", - } - - if srcDir := os.Getenv("TEST_SRCDIR"); srcDir != "" { - workspace := os.Getenv("TEST_WORKSPACE") - result := make([]string, len(dirs)) - for i, d := range dirs { - result[i] = filepath.Join(srcDir, workspace, d) - } - return result - } - - return dirs -} - -// applySchema reads all .sql files from the schema directories and executes them on the database. -func applySchema(t *testing.T, log *testLogger, db *sql.DB) { - t.Helper() - - for _, dir := range schemaDirs() { - files, err := filepath.Glob(filepath.Join(dir, "*.sql")) - require.NoError(t, err, "failed to glob schema files in %s", dir) - require.NotEmpty(t, files, "no .sql schema files found in %s", dir) - - // Sort files to ensure deterministic schema application order. - sort.Strings(files) - - for _, f := range files { - name := filepath.Base(f) - log.logf("Applying schema: %s", name) - - content, err := os.ReadFile(f) - require.NoError(t, err, "failed to read schema file %s", name) - - _, err = db.ExecContext(context.Background(), string(content)) - require.NoError(t, err, "failed to execute schema file %s", name) - - log.logf("Schema applied: %s", name) - } - } -} - -// setupMySQL starts a MySQL container on the given Docker network, applies the schema, -// and registers cleanup. The container is reachable by other containers on the network at "mysql:3306". -func setupMySQL(t *testing.T, log *testLogger, nw *testcontainers.DockerNetwork) { - t.Helper() - - ctx := context.Background() - - log.logf("Starting MySQL container") - mysqlContainer, err := mysql.Run(ctx, "mysql:8.0", - mysql.WithDatabase("submitqueue"), - mysql.WithUsername("root"), - mysql.WithPassword("root"), - network.WithNetwork([]string{"mysql"}, nw), - ) - require.NoError(t, err, "failed to start MySQL container") - log.logf("MySQL container started") - t.Cleanup(func() { - log.logf("Terminating MySQL container") - require.NoError(t, mysqlContainer.Terminate(ctx), "failed to terminate MySQL container") - log.logf("MySQL container terminated") - }) - - dsn, err := mysqlContainer.ConnectionString(ctx, "parseTime=true") - require.NoError(t, err, "failed to get MySQL connection string") - log.logf("MySQL DSN obtained: %s", dsn) - - log.logf("Opening MySQL connection") - db, err := sql.Open("mysql", dsn) - require.NoError(t, err, "failed to open MySQL connection") - log.logf("MySQL connection opened") - defer db.Close() - - applySchema(t, log, db) -} diff --git a/integration_tests/queue/BUILD.bazel b/integration_tests/queue/BUILD.bazel new file mode 100644 index 00000000..af3da33f --- /dev/null +++ b/integration_tests/queue/BUILD.bazel @@ -0,0 +1,25 @@ +load("@rules_go//go:def.bzl", "go_test") + +go_test( + name = "queue_test", + srcs = ["queue_test.go"], + data = [ + "//extensions/queue/sql/schema", + ], + tags = ["integration"], + deps = [ + "//entities/queue", + "//extensions/queue", + "//extensions/queue/sql", + "//integration_tests/testutil", + "@com_github_go_sql_driver_mysql//:mysql", + "@com_github_stretchr_testify//assert", + "@com_github_stretchr_testify//require", + "@com_github_stretchr_testify//suite", + "@com_github_testcontainers_testcontainers_go//:testcontainers-go", + "@com_github_testcontainers_testcontainers_go//network", + "@com_github_testcontainers_testcontainers_go_modules_mysql//:mysql", + "@com_github_uber_go_tally_v4//:tally", + "@org_uber_go_zap//zaptest", + ], +) diff --git a/integration_tests/queue/queue_test.go b/integration_tests/queue/queue_test.go new file mode 100644 index 00000000..0c2db7f4 --- /dev/null +++ b/integration_tests/queue/queue_test.go @@ -0,0 +1,1196 @@ +package queue_test + +import ( + "context" + "database/sql" + "fmt" + "strconv" + "sync" + "testing" + "time" + + _ "github.com/go-sql-driver/mysql" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/modules/mysql" + "github.com/uber-go/tally/v4" + "go.uber.org/zap/zaptest" + + "github.com/uber/submitqueue/entities/queue" + extqueue "github.com/uber/submitqueue/extensions/queue" + queueSQL "github.com/uber/submitqueue/extensions/queue/sql" + "github.com/uber/submitqueue/integration_tests/testutil" +) + +type QueueIntegrationSuite struct { + suite.Suite + ctx context.Context + db *sql.DB + container *mysql.MySQLContainer + network *testcontainers.DockerNetwork + dsn string + log *testutil.TestLogger +} + +func TestQueueIntegration(t *testing.T) { + suite.Run(t, new(QueueIntegrationSuite)) +} + +func (s *QueueIntegrationSuite) SetupSuite() { + t := s.T() + s.ctx = context.Background() + s.log = testutil.NewTestLogger(t) + + // Setup Docker environment and network + s.network, s.ctx = testutil.SetupDockerEnv(t, s.log, s.ctx) + + // Setup MySQL using shared helper + s.container, s.db, s.dsn = testutil.SetupMySQL(t, s.log, s.network, "extensions/queue/sql/schema") +} + +func (s *QueueIntegrationSuite) TearDownSuite() { + if s.db != nil { + s.db.Close() + } + if s.container != nil { + require.NoError(s.T(), s.container.Terminate(s.ctx)) + } +} + +// receiveWithTimeout receives a single delivery from the channel with a timeout. +// Returns the delivery or fails the test on timeout. +func receiveWithTimeout(t *testing.T, deliveryChan <-chan extqueue.Delivery, timeout time.Duration) extqueue.Delivery { + t.Helper() + select { + case delivery := <-deliveryChan: + return delivery + case <-time.After(timeout): + t.Fatalf("Timeout waiting for delivery after %v", timeout) + return nil + } +} + +// receiveNWithTimeout receives N deliveries from the channel with a timeout. +// Calls the provided handler for each delivery. +func receiveNWithTimeout( + t *testing.T, + deliveryChan <-chan extqueue.Delivery, + count int, + timeout time.Duration, + handler func(delivery extqueue.Delivery, index int), +) { + t.Helper() + deadline := time.After(timeout) + for i := 0; i < count; i++ { + select { + case delivery := <-deliveryChan: + handler(delivery, i) + case <-deadline: + t.Fatalf("Timeout waiting for message %d/%d after %v", i+1, count, timeout) + } + } +} + +func (s *QueueIntegrationSuite) TestPublishAndSubscribe() { + t := s.T() + + // Create queue factory + config := queueSQL.DefaultConfig("test-consumer", "test-worker-1") + factory, err := queueSQL.NewQueue(queueSQL.Params{ + DB: s.db, + Logger: zaptest.NewLogger(t), + MetricsScope: tally.NoopScope, + Config: config, + }) + require.NoError(t, err) + defer factory.Close() + + publisher := factory.Publisher() + subscriber := factory.Subscriber() + + topic := "test_topic" + + // Subscribe first + deliveryChan, err := subscriber.Subscribe(s.ctx, topic) + require.NoError(t, err) + + // Publish messages with various metadata scenarios + msg1 := queue.NewMessage("msg-1", []byte("hello"), "partition-1", map[string]string{ + "key1": "value1", + "key2": "value2", + "trace_id": "abc123", + }) + + msg2 := queue.NewMessage("msg-2", []byte("world"), "partition-1", nil) + + err = publisher.Publish(s.ctx, topic, msg1) + require.NoError(t, err) + + err = publisher.Publish(s.ctx, topic, msg2) + require.NoError(t, err) + + t.Logf("Published 2 messages") + + // Receive and ack messages + receiveNWithTimeout(t, deliveryChan, 2, 5*time.Second, func(delivery extqueue.Delivery, index int) { + msg := delivery.Message() + t.Logf("Received message: id=%s payload=%s", msg.ID, string(msg.Payload)) + + if index == 0 { + // Verify message content + assert.Equal(t, "msg-1", msg.ID) + assert.Equal(t, []byte("hello"), msg.Payload) + assert.Equal(t, "partition-1", msg.PartitionKey) + + // Verify metadata round-trip (published metadata preserved exactly) + assert.Equal(t, 3, len(msg.Metadata), "Should have 3 metadata keys") + assert.Equal(t, "value1", msg.Metadata["key1"]) + assert.Equal(t, "value2", msg.Metadata["key2"]) + assert.Equal(t, "abc123", msg.Metadata["trace_id"]) + } else { + // Verify message with nil metadata + assert.Equal(t, "msg-2", msg.ID) + assert.Equal(t, []byte("world"), msg.Payload) + assert.NotNil(t, msg.Metadata, "Metadata should be initialized (not nil)") + assert.Equal(t, 0, len(msg.Metadata), "Empty metadata should have 0 keys") + } + + // Ack the message + err := delivery.Ack(s.ctx) + require.NoError(t, err) + }) + + t.Logf("Successfully received and acked 2 messages with metadata verified") +} + +func (s *QueueIntegrationSuite) TestMultiplePartitions() { + t := s.T() + + config := queueSQL.DefaultConfig("multi-partition-consumer", "worker-1") + factory, err := queueSQL.NewQueue(queueSQL.Params{ + DB: s.db, + Logger: zaptest.NewLogger(t), + MetricsScope: tally.NoopScope, + Config: config, + }) + require.NoError(t, err) + defer factory.Close() + + publisher := factory.Publisher() + subscriber := factory.Subscriber() + + topic := "multi_partition_topic" + + // Subscribe + deliveryChan, err := subscriber.Subscribe(s.ctx, topic) + require.NoError(t, err) + + // Publish messages to different partitions + partitions := []string{"part-A", "part-B", "part-C"} + expectedCount := len(partitions) * 2 // 2 messages per partition + + for _, partition := range partitions { + msg1 := queue.NewMessage(partition+"-msg-1", []byte("data-1"), partition, nil) + msg2 := queue.NewMessage(partition+"-msg-2", []byte("data-2"), partition, nil) + + require.NoError(t, publisher.Publish(s.ctx, topic, msg1)) + require.NoError(t, publisher.Publish(s.ctx, topic, msg2)) + } + + t.Logf("Published %d messages across %d partitions", expectedCount, len(partitions)) + + // Receive all messages + receiveNWithTimeout(t, deliveryChan, expectedCount, 10*time.Second, func(delivery extqueue.Delivery, index int) { + msg := delivery.Message() + t.Logf("Received: partition=%s id=%s", msg.PartitionKey, msg.ID) + require.NoError(t, delivery.Ack(s.ctx)) + }) + + t.Logf("Successfully processed all %d messages", expectedCount) +} + +func (s *QueueIntegrationSuite) TestVisibilityTimeoutAndRetry() { + t := s.T() + + // Use short visibility timeout for faster test + config := queueSQL.DefaultConfig("retry-consumer", "worker-1") + config.VisibilityTimeout = 2 * time.Second + config.PollInterval = 100 * time.Millisecond + + factory, err := queueSQL.NewQueue(queueSQL.Params{ + DB: s.db, + Logger: zaptest.NewLogger(t), + MetricsScope: tally.NoopScope, + Config: config, + }) + require.NoError(t, err) + defer factory.Close() + + publisher := factory.Publisher() + subscriber := factory.Subscriber() + + topic := "retry_topic" + + // Subscribe + deliveryChan, err := subscriber.Subscribe(s.ctx, topic) + require.NoError(t, err) + + // Publish a message + msg := queue.NewMessage("retry-msg", []byte("test"), "retry-partition", nil) + require.NoError(t, publisher.Publish(s.ctx, topic, msg)) + + t.Logf("Published message, expecting visibility timeout retry") + + // Test 1: ExtendVisibilityTimeout allows longer processing time + t.Logf("Test 1: ExtendVisibilityTimeout") + firstDelivery := receiveWithTimeout(t, deliveryChan, 5*time.Second) + t.Logf("First delivery: attempt=%d", firstDelivery.Attempt()) + assert.Equal(t, 1, firstDelivery.Attempt()) + + // Extend visibility timeout by 3 seconds + extensionDuration := 3 * time.Second + t.Logf("Extending visibility timeout by %v", extensionDuration) + err = firstDelivery.ExtendVisibilityTimeout(s.ctx, extensionDuration.Milliseconds()) + require.NoError(t, err) + + // Wait for original visibility timeout to expire (but not the extended timeout) + t.Logf("Waiting for original visibility timeout (%v) - message should NOT reappear", config.VisibilityTimeout) + time.Sleep(config.VisibilityTimeout + 200*time.Millisecond) + + // Message should NOT be redelivered yet (visibility was extended) + select { + case <-deliveryChan: + t.Fatal("Message should not be redelivered yet - visibility was extended") + case <-time.After(500 * time.Millisecond): + t.Logf("✓ Confirmed: message not redelivered during extended visibility") + } + + // Now ack the message successfully + t.Logf("Acking message after extended processing time") + require.NoError(t, firstDelivery.Ack(s.ctx)) + + // Test 2: Visibility timeout retry when not acked + t.Logf("Test 2: Visibility timeout retry") + + // Publish another message + msg2 := queue.NewMessage("retry-msg-2", []byte("test2"), "retry-partition", nil) + require.NoError(t, publisher.Publish(s.ctx, topic, msg2)) + + // Receive first time + secondDelivery := receiveWithTimeout(t, deliveryChan, 5*time.Second) + t.Logf("Second message delivery: attempt=%d", secondDelivery.Attempt()) + assert.Equal(t, 1, secondDelivery.Attempt()) + // Don't ack - let it become visible again + + // Wait for visibility timeout to expire + t.Logf("Waiting for visibility timeout to expire...") + time.Sleep(config.VisibilityTimeout + 500*time.Millisecond) + + // Receive second time (retry) + thirdDelivery := receiveWithTimeout(t, deliveryChan, 5*time.Second) + t.Logf("Retry delivery: attempt=%d", thirdDelivery.Attempt()) + assert.Greater(t, thirdDelivery.Attempt(), 1, "retry count should increase") + assert.Equal(t, "retry-msg-2", thirdDelivery.Message().ID) + // Ack this time + require.NoError(t, thirdDelivery.Ack(s.ctx)) + + t.Logf("Successfully tested ExtendVisibilityTimeout and visibility timeout retry") +} + +func (s *QueueIntegrationSuite) TestNackWithDelay() { + t := s.T() + + config := queueSQL.DefaultConfig("nack-consumer", "worker-1") + config.PollInterval = 100 * time.Millisecond + + factory, err := queueSQL.NewQueue(queueSQL.Params{ + DB: s.db, + Logger: zaptest.NewLogger(t), + MetricsScope: tally.NoopScope, + Config: config, + }) + require.NoError(t, err) + defer factory.Close() + + publisher := factory.Publisher() + subscriber := factory.Subscriber() + + topic := "nack_topic" + + // Subscribe + deliveryChan, err := subscriber.Subscribe(s.ctx, topic) + require.NoError(t, err) + + // Publish message + msg := queue.NewMessage("nack-msg", []byte("test"), "nack-partition", nil) + require.NoError(t, publisher.Publish(s.ctx, topic, msg)) + + // Receive and Nack with delay + nackDelay := 2 * time.Second + + delivery := receiveWithTimeout(t, deliveryChan, 5*time.Second) + t.Logf("Received message, nacking with %s delay", nackDelay) + nackErr := delivery.Nack(s.ctx, int64(nackDelay.Milliseconds())) + require.NoError(t, nackErr) + + // Should NOT receive immediately + select { + case <-deliveryChan: + t.Fatal("Message should not be visible immediately after Nack") + case <-time.After(500 * time.Millisecond): + t.Logf("Confirmed message is not visible immediately") + } + + // Wait for nack delay to expire + time.Sleep(nackDelay) + + // Should receive again now + delivery2 := receiveWithTimeout(t, deliveryChan, 5*time.Second) + t.Logf("Received message again after nack delay") + assert.Equal(t, "nack-msg", delivery2.Message().ID) + require.NoError(t, delivery2.Ack(s.ctx)) +} + +func (s *QueueIntegrationSuite) TestIdempotentPublish() { + t := s.T() + + config := queueSQL.DefaultConfig("idempotent-consumer", "worker-1") + factory, err := queueSQL.NewQueue(queueSQL.Params{ + DB: s.db, + Logger: zaptest.NewLogger(t), + MetricsScope: tally.NoopScope, + Config: config, + }) + require.NoError(t, err) + defer factory.Close() + + publisher := factory.Publisher() + subscriber := factory.Subscriber() + + topic := "idempotent_topic" + + // Subscribe + deliveryChan, err := subscriber.Subscribe(s.ctx, topic) + require.NoError(t, err) + + // Publish same message twice + msg := queue.NewMessage("same-id", []byte("duplicate"), "same-partition", nil) + + err1 := publisher.Publish(s.ctx, topic, msg) + require.NoError(t, err1) + + err2 := publisher.Publish(s.ctx, topic, msg) + // Second publish should fail with duplicate key error since message already exists + require.Error(t, err2, "duplicate publish should return error") + + t.Logf("Published same message twice - second attempt correctly rejected") + + // Should only receive once + delivery := receiveWithTimeout(t, deliveryChan, 5*time.Second) + t.Logf("Received message: %s", delivery.Message().ID) + require.NoError(t, delivery.Ack(s.ctx)) + + // Verify no second message arrives + select { + case <-deliveryChan: + t.Fatal("Received duplicate message - idempotency check failed") + case <-time.After(1 * time.Second): + t.Logf("Confirmed: only received message once (idempotency works)") + } +} + +func (s *QueueIntegrationSuite) TestConcurrentPublishers() { + t := s.T() + + config := queueSQL.DefaultConfig("concurrent-consumer", "worker-1") + factory, err := queueSQL.NewQueue(queueSQL.Params{ + DB: s.db, + Logger: zaptest.NewLogger(t), + MetricsScope: tally.NoopScope, + Config: config, + }) + require.NoError(t, err) + defer factory.Close() + + publisher := factory.Publisher() + subscriber := factory.Subscriber() + + topic := "concurrent_topic" + + // Subscribe + deliveryChan, err := subscriber.Subscribe(s.ctx, topic) + require.NoError(t, err) + + // Publish from multiple goroutines + numPublishers := 5 + messagesPerPublisher := 3 + totalMessages := numPublishers * messagesPerPublisher + + errChan := make(chan error, totalMessages) + for i := 0; i < numPublishers; i++ { + go func(publisherID int) { + for j := 0; j < messagesPerPublisher; j++ { + msg := queue.NewMessage( + t.Name()+"-"+string(rune(publisherID))+"-"+string(rune(j)), + []byte("concurrent"), + "concurrent-partition", + nil, + ) + errChan <- publisher.Publish(s.ctx, topic, msg) + } + }(i) + } + + // Check all publishes succeeded + for i := 0; i < totalMessages; i++ { + require.NoError(t, <-errChan) + } + + t.Logf("Published %d messages concurrently", totalMessages) + + // Receive all messages + receiveNWithTimeout(t, deliveryChan, totalMessages, 10*time.Second, func(delivery extqueue.Delivery, index int) { + require.NoError(t, delivery.Ack(s.ctx)) + }) + + t.Logf("Received all %d concurrent messages", totalMessages) +} + +func (s *QueueIntegrationSuite) TestCrashRecovery() { + t := s.T() + + // Use short timeouts for faster test + config := queueSQL.DefaultConfig("crash-consumer", "worker-1") + config.VisibilityTimeout = 2 * time.Second + config.PollInterval = 100 * time.Millisecond + config.LeaseDuration = 3 * time.Second // Short lease for testing crash recovery + config.LeaseRenewalInterval = 1 * time.Second // Must be less than LeaseDuration + + factory1, err := queueSQL.NewQueue(queueSQL.Params{ + DB: s.db, + Logger: zaptest.NewLogger(t), + MetricsScope: tally.NoopScope, + Config: config, + }) + require.NoError(t, err) + + publisher := factory1.Publisher() + subscriber1 := factory1.Subscriber() + + topic := "crash_topic" + + // Subscribe with first worker + deliveryChan1, err := subscriber1.Subscribe(s.ctx, topic) + require.NoError(t, err) + + // Publish message + msg := queue.NewMessage("crash-msg", []byte("test-crash"), "crash-partition", nil) + require.NoError(t, publisher.Publish(s.ctx, topic, msg)) + + // Worker 1 receives but doesn't ack (simulating crash) + delivery1 := receiveWithTimeout(t, deliveryChan1, 5*time.Second) + t.Logf("Worker 1 received message but crashing without ack") + assert.Equal(t, "crash-msg", delivery1.Message().ID) + + // Simulate crash by closing factory1 + factory1.Close() + t.Logf("Worker 1 crashed (factory closed)") + + // Wait for both visibility timeout AND partition lease to expire + waitTime := config.LeaseDuration + config.VisibilityTimeout + time.Second + t.Logf("Waiting %v for lease and visibility timeout to expire", waitTime) + time.Sleep(waitTime) + + // Start worker 2 with same consumer group + config2 := queueSQL.DefaultConfig("crash-consumer", "worker-2") + config2.VisibilityTimeout = 2 * time.Second + config2.PollInterval = 100 * time.Millisecond + config2.LeaseDuration = 3 * time.Second + config2.LeaseRenewalInterval = 1 * time.Second + + factory2, err := queueSQL.NewQueue(queueSQL.Params{ + DB: s.db, + Logger: zaptest.NewLogger(t), + MetricsScope: tally.NoopScope, + Config: config2, + }) + require.NoError(t, err) + defer factory2.Close() + + subscriber2 := factory2.Subscriber() + deliveryChan2, err := subscriber2.Subscribe(s.ctx, topic) + require.NoError(t, err) + + // Worker 2 should receive the same message (recovery) + delivery2 := receiveWithTimeout(t, deliveryChan2, 5*time.Second) + t.Logf("Worker 2 recovered message: attempt=%d", delivery2.Attempt()) + assert.Equal(t, "crash-msg", delivery2.Message().ID) + assert.Greater(t, delivery2.Attempt(), 1, "should be a retry after crash") + + // Worker 2 successfully acks + require.NoError(t, delivery2.Ack(s.ctx)) + t.Logf("Crash recovery successful: message processed by worker 2") +} + +func (s *QueueIntegrationSuite) TestMultipleConsumerGroups() { + t := s.T() + + topic := "multi_group_topic" + + // Create two different consumer groups + config1 := queueSQL.DefaultConfig("group-A", "worker-1") + factory1, err := queueSQL.NewQueue(queueSQL.Params{ + DB: s.db, + Logger: zaptest.NewLogger(t), + MetricsScope: tally.NoopScope, + Config: config1, + }) + require.NoError(t, err) + defer factory1.Close() + + config2 := queueSQL.DefaultConfig("group-B", "worker-1") + factory2, err := queueSQL.NewQueue(queueSQL.Params{ + DB: s.db, + Logger: zaptest.NewLogger(t), + MetricsScope: tally.NoopScope, + Config: config2, + }) + require.NoError(t, err) + defer factory2.Close() + + publisher := factory1.Publisher() + subscriber1 := factory1.Subscriber() + subscriber2 := factory2.Subscriber() + + // Subscribe both groups + deliveryChan1, err := subscriber1.Subscribe(s.ctx, topic) + require.NoError(t, err) + + deliveryChan2, err := subscriber2.Subscribe(s.ctx, topic) + require.NoError(t, err) + + // Publish messages + numMessages := 3 + messageIDs := make([]string, numMessages) + for i := 0; i < numMessages; i++ { + msgID := fmt.Sprintf("msg-%d", i) + messageIDs[i] = msgID + msg := queue.NewMessage(msgID, []byte(fmt.Sprintf("data-%d", i)), "partition-1", nil) + require.NoError(t, publisher.Publish(s.ctx, topic, msg)) + } + t.Logf("Published %d messages to topic", numMessages) + + // Both groups should receive all messages + group1Messages := make(map[string]bool) + group2Messages := make(map[string]bool) + + // Receive from group A + receiveNWithTimeout(t, deliveryChan1, numMessages, 10*time.Second, func(delivery extqueue.Delivery, index int) { + msgID := delivery.Message().ID + t.Logf("Group A received: %s", msgID) + group1Messages[msgID] = true + require.NoError(t, delivery.Ack(s.ctx)) + }) + + // Receive from group B + receiveNWithTimeout(t, deliveryChan2, numMessages, 10*time.Second, func(delivery extqueue.Delivery, index int) { + msgID := delivery.Message().ID + t.Logf("Group B received: %s", msgID) + group2Messages[msgID] = true + require.NoError(t, delivery.Ack(s.ctx)) + }) + + // Verify both groups got all messages + assert.Equal(t, numMessages, len(group1Messages), "Group A should receive all messages") + assert.Equal(t, numMessages, len(group2Messages), "Group B should receive all messages") + + for _, msgID := range messageIDs { + assert.True(t, group1Messages[msgID], "Group A missing message: %s", msgID) + assert.True(t, group2Messages[msgID], "Group B missing message: %s", msgID) + } + + t.Logf("Both consumer groups independently received all %d messages", numMessages) +} + +func (s *QueueIntegrationSuite) TestMultipleWorkersInConsumerGroup() { + t := s.T() + + topic := "multi_worker_topic" + consumerGroup := "shared-group" + + // Create two workers in same consumer group + config1 := queueSQL.DefaultConfig(consumerGroup, "worker-1") + factory1, err := queueSQL.NewQueue(queueSQL.Params{ + DB: s.db, + Logger: zaptest.NewLogger(t), + MetricsScope: tally.NoopScope, + Config: config1, + }) + require.NoError(t, err) + defer factory1.Close() + + config2 := queueSQL.DefaultConfig(consumerGroup, "worker-2") + factory2, err := queueSQL.NewQueue(queueSQL.Params{ + DB: s.db, + Logger: zaptest.NewLogger(t), + MetricsScope: tally.NoopScope, + Config: config2, + }) + require.NoError(t, err) + defer factory2.Close() + + publisher := factory1.Publisher() + subscriber1 := factory1.Subscriber() + subscriber2 := factory2.Subscriber() + + // Subscribe both workers + deliveryChan1, err := subscriber1.Subscribe(s.ctx, topic) + require.NoError(t, err) + + deliveryChan2, err := subscriber2.Subscribe(s.ctx, topic) + require.NoError(t, err) + + // Publish messages to different partitions so they can be distributed + numMessages := 10 + messageIDs := make([]string, numMessages) + for i := 0; i < numMessages; i++ { + msgID := fmt.Sprintf("msg-%d", i) + messageIDs[i] = msgID + // Use different partition keys to allow distribution + partitionKey := fmt.Sprintf("partition-%d", i%3) + msg := queue.NewMessage(msgID, []byte(fmt.Sprintf("data-%d", i)), partitionKey, nil) + require.NoError(t, publisher.Publish(s.ctx, topic, msg)) + } + t.Logf("Published %d messages to topic across multiple partitions", numMessages) + + // Collect messages from both workers concurrently + allMessages := make(map[string]int) // msgID -> count (should be 1 for each) + var mu sync.Mutex + var wg sync.WaitGroup + + wg.Add(2) + + // Worker 1 receiver + go func() { + defer wg.Done() + for { + select { + case delivery := <-deliveryChan1: + msgID := delivery.Message().ID + mu.Lock() + allMessages[msgID]++ + mu.Unlock() + t.Logf("Worker 1 received: %s (total received: %d)", msgID, len(allMessages)) + require.NoError(t, delivery.Ack(s.ctx)) + + if len(allMessages) == numMessages { + return + } + case <-time.After(10 * time.Second): + return + } + } + }() + + // Worker 2 receiver + go func() { + defer wg.Done() + for { + select { + case delivery := <-deliveryChan2: + msgID := delivery.Message().ID + mu.Lock() + allMessages[msgID]++ + mu.Unlock() + t.Logf("Worker 2 received: %s (total received: %d)", msgID, len(allMessages)) + require.NoError(t, delivery.Ack(s.ctx)) + + if len(allMessages) == numMessages { + return + } + case <-time.After(10 * time.Second): + return + } + } + }() + + wg.Wait() + + // Verify all messages received exactly once + assert.Equal(t, numMessages, len(allMessages), "Should receive all messages") + + for _, msgID := range messageIDs { + count, exists := allMessages[msgID] + assert.True(t, exists, "Missing message: %s", msgID) + assert.Equal(t, 1, count, "Message %s received %d times (expected 1)", msgID, count) + } + + t.Logf("Load balanced: %d messages distributed across 2 workers with no duplicates", numMessages) +} + +func (s *QueueIntegrationSuite) TestConcurrentSubscribers() { + t := s.T() + + topic := "concurrent_subscribers_topic" + consumerGroup := "concurrent-group" + numSubscribers := 3 + messagesPerSubscriber := 5 + totalMessages := numSubscribers * messagesPerSubscriber + + // Create publisher + publisherConfig := queueSQL.DefaultConfig(consumerGroup, "publisher") + publisherFactory, err := queueSQL.NewQueue(queueSQL.Params{ + DB: s.db, + Logger: zaptest.NewLogger(t), + MetricsScope: tally.NoopScope, + Config: publisherConfig, + }) + require.NoError(t, err) + defer publisherFactory.Close() + + publisher := publisherFactory.Publisher() + + // Create multiple concurrent subscribers + var factories []extqueue.Queue + var deliveryChans []<-chan extqueue.Delivery + + for i := 0; i < numSubscribers; i++ { + config := queueSQL.DefaultConfig(consumerGroup, fmt.Sprintf("worker-%d", i)) + factory, err := queueSQL.NewQueue(queueSQL.Params{ + DB: s.db, + Logger: zaptest.NewLogger(t), + MetricsScope: tally.NoopScope, + Config: config, + }) + require.NoError(t, err) + factories = append(factories, factory) + + subscriber := factory.Subscriber() + deliveryChan, err := subscriber.Subscribe(s.ctx, topic) + require.NoError(t, err) + deliveryChans = append(deliveryChans, deliveryChan) + } + + // Cleanup all factories + defer func() { + for _, f := range factories { + f.Close() + } + }() + + t.Logf("Started %d concurrent subscribers", numSubscribers) + + // Publish messages to multiple partitions + for i := 0; i < totalMessages; i++ { + msgID := fmt.Sprintf("concurrent-msg-%d", i) + partitionKey := fmt.Sprintf("partition-%d", i%5) + msg := queue.NewMessage(msgID, []byte(fmt.Sprintf("data-%d", i)), partitionKey, nil) + require.NoError(t, publisher.Publish(s.ctx, topic, msg)) + } + t.Logf("Published %d messages", totalMessages) + + // Collect messages from all subscribers concurrently + allMessages := make(map[string]int) // msgID -> count + var mu sync.Mutex + var wg sync.WaitGroup + + for i, deliveryChan := range deliveryChans { + wg.Add(1) + go func(workerID int, ch <-chan extqueue.Delivery) { + defer wg.Done() + workerMessages := 0 + for { + select { + case delivery := <-ch: + msgID := delivery.Message().ID + mu.Lock() + allMessages[msgID]++ + totalReceived := len(allMessages) + mu.Unlock() + + t.Logf("Worker %d received: %s (total unique: %d)", workerID, msgID, totalReceived) + require.NoError(t, delivery.Ack(s.ctx)) + workerMessages++ + + if totalReceived >= totalMessages { + t.Logf("Worker %d processed %d messages", workerID, workerMessages) + return + } + case <-time.After(10 * time.Second): + t.Logf("Worker %d timeout after processing %d messages", workerID, workerMessages) + return + } + } + }(i, deliveryChan) + } + + wg.Wait() + + // Verify all messages received exactly once + assert.Equal(t, totalMessages, len(allMessages), "Should receive all messages") + + duplicates := 0 + for msgID, count := range allMessages { + if count > 1 { + t.Errorf("Message %s received %d times (duplicate!)", msgID, count) + duplicates++ + } + } + + assert.Equal(t, 0, duplicates, "Should have no duplicate messages") + t.Logf("Concurrent subscribers test: %d messages processed by %d workers with no duplicates", totalMessages, numSubscribers) +} + +func (s *QueueIntegrationSuite) TestDeadLetterQueue() { + t := s.T() + + topic := "dlq_topic" + + // Configure with low max attempts and DLQ enabled + config := queueSQL.DefaultConfig("dlq-consumer", "worker-1") + config.PollInterval = 100 * time.Millisecond + config.VisibilityTimeout = 1 * time.Second + config.Retry.MaxAttempts = 2 // Only 2 attempts before DLQ + config.DLQ.Enabled = true + + factory, err := queueSQL.NewQueue(queueSQL.Params{ + DB: s.db, + Logger: zaptest.NewLogger(t), + MetricsScope: tally.NoopScope, + Config: config, + }) + require.NoError(t, err) + defer factory.Close() + + publisher := factory.Publisher() + subscriber := factory.Subscriber() + + // Subscribe to main topic + deliveryChan, err := subscriber.Subscribe(s.ctx, topic) + require.NoError(t, err) + + // Publish a message that will fail + msg := queue.NewMessage("poison-msg", []byte("poison"), "partition-1", nil) + require.NoError(t, publisher.Publish(s.ctx, topic, msg)) + + t.Logf("Published poison message, will nack repeatedly") + + // Receive and nack the message MaxAttempts times + for attempt := 1; attempt <= config.Retry.MaxAttempts; attempt++ { + delivery := receiveWithTimeout(t, deliveryChan, 10*time.Second) + t.Logf("Attempt %d: received message, nacking", delivery.Attempt()) + assert.Equal(t, attempt, delivery.Attempt()) + assert.Equal(t, "poison-msg", delivery.Message().ID) + + // Nack without delay to retry immediately + require.NoError(t, delivery.Nack(s.ctx, 0)) + + // Wait a bit for visibility timeout + time.Sleep(config.VisibilityTimeout + 200*time.Millisecond) + } + + // After MaxAttempts, message should be moved to DLQ topic + t.Logf("Message should be moved to DLQ after %d failed attempts", config.Retry.MaxAttempts) + + // Should NOT receive on main topic anymore (message moved to DLQ) + select { + case <-deliveryChan: + t.Fatal("Should not receive message on main topic after max retries (should be in DLQ)") + case <-time.After(3 * time.Second): + t.Logf("Confirmed: message removed from main topic") + } + + // Subscribe to DLQ topic to consume the failed message + dlqTopic := topic + config.DLQ.TopicSuffix + t.Logf("Subscribing to DLQ topic: %s", dlqTopic) + + dlqDeliveryChan, err := subscriber.Subscribe(s.ctx, dlqTopic) + require.NoError(t, err) + + // Receive the message from DLQ + dlqDelivery := receiveWithTimeout(t, dlqDeliveryChan, 10*time.Second) + assert.Equal(t, "poison-msg", dlqDelivery.Message().ID) + assert.Equal(t, []byte("poison"), dlqDelivery.Message().Payload) + assert.Equal(t, "partition-1", dlqDelivery.Message().PartitionKey) + + // Verify DLQ-specific metadata is exposed through delivery metadata + metadata := dlqDelivery.Metadata() + assert.Contains(t, metadata, "dlq.failed_at") + assert.Contains(t, metadata, "dlq.failure_count") + assert.Contains(t, metadata, "dlq.last_error") + assert.Contains(t, metadata, "dlq.original_topic") + + // Verify values + assert.Equal(t, topic, metadata["dlq.original_topic"]) + assert.Equal(t, fmt.Sprintf("%d", config.Retry.MaxAttempts), metadata["dlq.failure_count"]) + assert.Equal(t, "exceeded retry limit", metadata["dlq.last_error"]) + + failedAt := metadata["dlq.failed_at"] + failedAtInt, err := strconv.ParseInt(failedAt, 10, 64) + require.NoError(t, err) + assert.Greater(t, failedAtInt, int64(0), "dlq.failed_at should be a valid timestamp") + + // Acknowledge the DLQ message + require.NoError(t, dlqDelivery.Ack(s.ctx)) + + t.Logf("DLQ test successful: poison message consumed from DLQ topic '%s' with metadata: %+v", dlqTopic, metadata) +} + +func (s *QueueIntegrationSuite) TestMessageOrderingWithinPartition() { + t := s.T() + + topic := "ordering_topic" + partitionKey := "ordered-partition" + + config := queueSQL.DefaultConfig("ordering-consumer", "worker-1") + factory, err := queueSQL.NewQueue(queueSQL.Params{ + DB: s.db, + Logger: zaptest.NewLogger(t), + MetricsScope: tally.NoopScope, + Config: config, + }) + require.NoError(t, err) + defer factory.Close() + + publisher := factory.Publisher() + subscriber := factory.Subscriber() + + // Subscribe first + deliveryChan, err := subscriber.Subscribe(s.ctx, topic) + require.NoError(t, err) + + // Publish messages with same partition key (should be ordered) + numMessages := 10 + messageIDs := make([]string, numMessages) + for i := 0; i < numMessages; i++ { + msgID := fmt.Sprintf("msg-%03d", i) + messageIDs[i] = msgID + msg := queue.NewMessage(msgID, []byte(fmt.Sprintf("order-%d", i)), partitionKey, nil) + require.NoError(t, publisher.Publish(s.ctx, topic, msg)) + } + t.Logf("Published %d messages to same partition: %s", numMessages, partitionKey) + + // Receive and verify ordering + receivedOrder := make([]string, 0, numMessages) + receiveNWithTimeout(t, deliveryChan, numMessages, 10*time.Second, func(delivery extqueue.Delivery, index int) { + msgID := delivery.Message().ID + receivedOrder = append(receivedOrder, msgID) + t.Logf("Received in order: %s", msgID) + require.NoError(t, delivery.Ack(s.ctx)) + }) + + // Verify messages received in exact publish order + for i := 0; i < numMessages; i++ { + assert.Equal(t, messageIDs[i], receivedOrder[i], + "Message at position %d out of order: expected %s, got %s", + i, messageIDs[i], receivedOrder[i]) + } + + t.Logf("FIFO ordering verified: all %d messages received in exact publish order", numMessages) +} + +func (s *QueueIntegrationSuite) TestLateSubscriber() { + t := s.T() + + topic := "late_subscriber_topic" + + config := queueSQL.DefaultConfig("late-consumer", "worker-1") + factory, err := queueSQL.NewQueue(queueSQL.Params{ + DB: s.db, + Logger: zaptest.NewLogger(t), + MetricsScope: tally.NoopScope, + Config: config, + }) + require.NoError(t, err) + defer factory.Close() + + publisher := factory.Publisher() + + // Publish messages BEFORE subscribing + numMessages := 5 + messageIDs := make([]string, numMessages) + for i := 0; i < numMessages; i++ { + msgID := fmt.Sprintf("early-msg-%d", i) + messageIDs[i] = msgID + msg := queue.NewMessage(msgID, []byte(fmt.Sprintf("data-%d", i)), "partition-1", nil) + require.NoError(t, publisher.Publish(s.ctx, topic, msg)) + } + t.Logf("Published %d messages BEFORE subscribing", numMessages) + + // Now subscribe (late subscriber) + subscriber := factory.Subscriber() + deliveryChan, err := subscriber.Subscribe(s.ctx, topic) + require.NoError(t, err) + t.Logf("Late subscriber joined after messages published") + + // Late subscriber should receive all messages + receivedMessages := make(map[string]bool) + receiveNWithTimeout(t, deliveryChan, numMessages, 10*time.Second, func(delivery extqueue.Delivery, index int) { + msgID := delivery.Message().ID + receivedMessages[msgID] = true + t.Logf("Late subscriber received: %s", msgID) + require.NoError(t, delivery.Ack(s.ctx)) + }) + + // Verify all messages received + assert.Equal(t, numMessages, len(receivedMessages), "Should receive all pre-published messages") + for _, msgID := range messageIDs { + assert.True(t, receivedMessages[msgID], "Missing message: %s", msgID) + } + + t.Logf("Late subscriber successfully received all %d pre-published messages", numMessages) +} + +func (s *QueueIntegrationSuite) TestEmptyTopicSubscribe() { + t := s.T() + + topic := "empty_topic" + + config := queueSQL.DefaultConfig("empty-consumer", "worker-1") + config.PollInterval = 100 * time.Millisecond + factory, err := queueSQL.NewQueue(queueSQL.Params{ + DB: s.db, + Logger: zaptest.NewLogger(t), + MetricsScope: tally.NoopScope, + Config: config, + }) + require.NoError(t, err) + defer factory.Close() + + subscriber := factory.Subscriber() + + // Subscribe to empty topic (no messages published yet) + deliveryChan, err := subscriber.Subscribe(s.ctx, topic) + require.NoError(t, err) + t.Logf("Subscribed to empty topic") + + // Should not receive anything immediately + select { + case <-deliveryChan: + t.Fatal("Should not receive any messages from empty topic") + case <-time.After(1 * time.Second): + t.Logf("Confirmed: no messages on empty topic") + } + + // Now publish a message + publisher := factory.Publisher() + msg := queue.NewMessage("late-msg", []byte("data"), "partition-1", nil) + require.NoError(t, publisher.Publish(s.ctx, topic, msg)) + t.Logf("Published message to previously-empty topic") + + // Should now receive the message + delivery := receiveWithTimeout(t, deliveryChan, 5*time.Second) + assert.Equal(t, "late-msg", delivery.Message().ID) + require.NoError(t, delivery.Ack(s.ctx)) + + t.Logf("Successfully received message published after subscription to empty topic") +} + +func (s *QueueIntegrationSuite) TestGracefulShutdownDuringProcessing() { + t := s.T() + + topic := "shutdown_topic" + + config := queueSQL.DefaultConfig("shutdown-consumer", "worker-1") + config.PollInterval = 100 * time.Millisecond + factory, err := queueSQL.NewQueue(queueSQL.Params{ + DB: s.db, + Logger: zaptest.NewLogger(t), + MetricsScope: tally.NoopScope, + Config: config, + }) + require.NoError(t, err) + + publisher := factory.Publisher() + subscriber := factory.Subscriber() + + // Subscribe + deliveryChan, err := subscriber.Subscribe(s.ctx, topic) + require.NoError(t, err) + + // Publish messages + numMessages := 5 + for i := 0; i < numMessages; i++ { + msg := queue.NewMessage(fmt.Sprintf("msg-%d", i), []byte("data"), "partition-1", nil) + require.NoError(t, publisher.Publish(s.ctx, topic, msg)) + } + t.Logf("Published %d messages", numMessages) + + // Receive one message but don't ack yet (in-flight) + delivery := receiveWithTimeout(t, deliveryChan, 5*time.Second) + inFlightMsgID := delivery.Message().ID + t.Logf("Received in-flight message: %s (not acked yet)", inFlightMsgID) + + // Close the factory while message is in-flight + t.Logf("Closing factory with in-flight message...") + err = factory.Close() + require.NoError(t, err) + t.Logf("Factory closed successfully") + + // Drain any buffered messages from the channel without acking them + // These messages were already fetched and marked invisible + drained := 0 +drainLoop: + for { + select { + case msg, ok := <-deliveryChan: + if !ok { + // Channel closed - this is expected + t.Logf("✓ Delivery channel closed after draining %d buffered messages (not acked)", drained) + break drainLoop + } + drained++ + // Don't ack - let them become visible again after timeout + t.Logf("Drained buffered message (not acked): %s", msg.Message().ID) + case <-time.After(1 * time.Second): + t.Logf("Delivery channel not closed after draining %d messages, may still be open", drained) + break drainLoop + } + } + + // Don't try to ack the in-flight message - we want it to be redelivered + // (Acking after close might succeed and delete the message) + + // Wait for visibility timeout to expire so messages become visible again + // All messages (in-flight + buffered) were fetched and marked invisible + t.Logf("Waiting for visibility timeout to expire (%v) so messages become visible again...", config.VisibilityTimeout) + time.Sleep(config.VisibilityTimeout + 500*time.Millisecond) + + // Start new subscriber to verify all messages are redelivered + t.Logf("Starting new subscriber to verify message recovery...") + factory2, err := queueSQL.NewQueue(queueSQL.Params{ + DB: s.db, + Logger: zaptest.NewLogger(t), + MetricsScope: tally.NoopScope, + Config: config, + }) + require.NoError(t, err) + defer factory2.Close() + + subscriber2 := factory2.Subscriber() + deliveryChan2, err := subscriber2.Subscribe(s.ctx, topic) + require.NoError(t, err) + + // Receive all unprocessed messages (all should be redelivered after visibility timeout) + receivedIDs := make(map[string]bool) + expectedMessages := 1 + drained // in-flight + drained buffered messages + if expectedMessages == 0 { + expectedMessages = numMessages // fallback if nothing was drained + } + + for i := 0; i < expectedMessages; i++ { + delivery := receiveWithTimeout(t, deliveryChan2, 10*time.Second) + msgID := delivery.Message().ID + receivedIDs[msgID] = true + t.Logf("Recovered message %d/%d: %s", i+1, expectedMessages, msgID) + require.NoError(t, delivery.Ack(s.ctx)) + } + + // Verify the in-flight message was redelivered + assert.True(t, receivedIDs[inFlightMsgID], "In-flight message should be redelivered") + assert.GreaterOrEqual(t, len(receivedIDs), 1, "Should receive at least the in-flight message") + + t.Logf("Graceful shutdown test successful: %d messages recovered (including in-flight)", len(receivedIDs)) +} diff --git a/integration_tests/servers.go b/integration_tests/servers.go index b10a5be1..0708c81a 100644 --- a/integration_tests/servers.go +++ b/integration_tests/servers.go @@ -12,6 +12,7 @@ import ( "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/network" "github.com/testcontainers/testcontainers-go/wait" + "github.com/uber/submitqueue/integration_tests/testutil" ) const serverPort = "8080" @@ -29,7 +30,7 @@ func serverBinaryPath(name string) string { func startServerContainer( ctx context.Context, t *testing.T, - log *testLogger, + log *testutil.TestLogger, name string, env map[string]string, nw *testcontainers.DockerNetwork, @@ -37,7 +38,7 @@ func startServerContainer( t.Helper() binaryPath := serverBinaryPath(name) - log.logf("Resolved %s binary: %s", name, binaryPath) + log.Logf("Resolved %s binary: %s", name, binaryPath) // Create temp build context with binary and Dockerfile. tmpDir := t.TempDir() @@ -48,7 +49,7 @@ func startServerContainer( env["PORT"] = ":" + serverPort - log.logf("Starting %s container", name) + log.Logf("Starting %s container", name) ctr, err := testcontainers.Run(ctx, "", testcontainers.WithDockerfile(testcontainers.FromDockerfile{ Context: tmpDir, @@ -56,16 +57,25 @@ func startServerContainer( }), testcontainers.WithExposedPorts(serverPort+"/tcp"), testcontainers.WithEnv(env), - testcontainers.WithWaitStrategy(wait.ForLog("server is running")), + testcontainers.WithWaitStrategy(wait.ForLog("gRPC server is running")), network.WithNetwork([]string{name}, nw), ) - require.NoError(t, err, "failed to start %s container", name) + if err != nil { + // Print container logs on failure if container was created + if ctr != nil { + if logs, logErr := ctr.Logs(ctx); logErr == nil { + logBytes, _ := io.ReadAll(logs) + log.Logf("%s container logs:\n%s", name, string(logBytes)) + } + } + require.NoError(t, err, "failed to start %s container", name) + } t.Cleanup(func() { - log.logf("Terminating %s container", name) + log.Logf("Terminating %s container", name) if err := ctr.Terminate(ctx); err != nil { t.Logf("failed to terminate %s container: %v", name, err) } - log.logf("%s container terminated", name) + log.Logf("%s container terminated", name) }) mappedPort, err := ctr.MappedPort(ctx, serverPort+"/tcp") @@ -73,12 +83,12 @@ func startServerContainer( host, err := ctr.Host(ctx) require.NoError(t, err, "failed to get host for %s", name) addr := fmt.Sprintf("%s:%s", host, mappedPort.Port()) - log.logf("%s container started on %s", name, addr) + log.Logf("%s container started on %s", name, addr) return ctr, addr } -func startGatewayContainer(ctx context.Context, t *testing.T, log *testLogger, nw *testcontainers.DockerNetwork) string { +func startGatewayContainer(ctx context.Context, t *testing.T, log *testutil.TestLogger, nw *testcontainers.DockerNetwork) string { t.Helper() _, addr := startServerContainer(ctx, t, log, "gateway", map[string]string{ "MYSQL_DSN": "root:root@tcp(mysql:3306)/submitqueue?parseTime=true", @@ -86,13 +96,13 @@ func startGatewayContainer(ctx context.Context, t *testing.T, log *testLogger, n return addr } -func startOrchestratorContainer(ctx context.Context, t *testing.T, log *testLogger, nw *testcontainers.DockerNetwork) string { +func startOrchestratorContainer(ctx context.Context, t *testing.T, log *testutil.TestLogger, nw *testcontainers.DockerNetwork) string { t.Helper() _, addr := startServerContainer(ctx, t, log, "orchestrator", map[string]string{}, nw) return addr } -func startSpeculatorContainer(ctx context.Context, t *testing.T, log *testLogger, nw *testcontainers.DockerNetwork) string { +func startSpeculatorContainer(ctx context.Context, t *testing.T, log *testutil.TestLogger, nw *testcontainers.DockerNetwork) string { t.Helper() _, addr := startServerContainer(ctx, t, log, "speculator", map[string]string{}, nw) return addr diff --git a/integration_tests/suite_test.go b/integration_tests/suite_test.go index c1ffd4d0..d3e370ad 100644 --- a/integration_tests/suite_test.go +++ b/integration_tests/suite_test.go @@ -2,15 +2,14 @@ package integration_tests import ( "context" - "os" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "github.com/testcontainers/testcontainers-go" - "github.com/testcontainers/testcontainers-go/network" gatewaypb "github.com/uber/submitqueue/gateway/protopb" + "github.com/uber/submitqueue/integration_tests/testutil" orchestratorpb "github.com/uber/submitqueue/orchestrator/protopb" speculatorpb "github.com/uber/submitqueue/speculator/protopb" "google.golang.org/grpc" @@ -19,7 +18,7 @@ import ( type IntegrationSuite struct { suite.Suite - log *testLogger + log *testutil.TestLogger nw *testcontainers.DockerNetwork @@ -37,28 +36,27 @@ func TestIntegration(t *testing.T) { func (s *IntegrationSuite) SetupSuite() { t := s.T() ctx := context.Background() - s.log = newTestLogger(t) - - // Disable Ryuk reaper container which may not work in Docker-in-Docker environments. - t.Setenv("TESTCONTAINERS_RYUK_DISABLED", "true") - - // Ensure HOME is set for Docker config resolution in Bazel sandbox. - if os.Getenv("HOME") == "" { - t.Setenv("HOME", t.TempDir()) - } - - // Create Docker network for inter-container communication. - nw, err := network.New(ctx) - require.NoError(t, err, "failed to create Docker network") - s.nw = nw - t.Cleanup(func() { - s.log.logf("Removing Docker network") - require.NoError(t, nw.Remove(ctx), "failed to remove Docker network") + s.log = testutil.NewTestLogger(t) + + // Setup Docker environment and network + s.nw, ctx = testutil.SetupDockerEnv(t, s.log, ctx) + + // Start MySQL container on the network and apply schemas. + mysqlContainer, db, _ := testutil.SetupMySQL(t, s.log, s.nw, "extensions/storage/mysql/schema") + testutil.ApplySchema(t, s.log, db, testutil.SchemaDir("extensions/counter/mysql/schema")) + + // Register MySQL cleanup + s.addCleanup(func() { + s.log.Logf("Closing MySQL connection") + if err := db.Close(); err != nil { + s.log.Logf("Failed to close MySQL connection: %v", err) + } + s.log.Logf("Terminating MySQL container") + if err := mysqlContainer.Terminate(context.Background()); err != nil { + s.log.Logf("Failed to terminate MySQL container: %v", err) + } + s.log.Logf("MySQL container terminated") }) - s.log.logf("Docker network created: %s", nw.Name) - - // Start MySQL container on the network and apply schema. - setupMySQL(t, s.log, s.nw) // Start all server containers. gatewayAddr := startGatewayContainer(ctx, t, s.log, s.nw) @@ -71,7 +69,7 @@ func (s *IntegrationSuite) SetupSuite() { s.orchestratorClient = orchestratorpb.NewSubmitQueueOrchestratorClient(s.dial(orchestratorAddr, opts)) s.speculatorClient = speculatorpb.NewSubmitQueueSpeculatorClient(s.dial(speculatorAddr, opts)) - s.log.logf("All containers started and clients connected") + s.log.Logf("All containers started and clients connected") } func (s *IntegrationSuite) TearDownSuite() { @@ -96,7 +94,7 @@ func (s *IntegrationSuite) TestPingGateway() { resp, err := s.gatewayClient.Ping(ctx, &gatewaypb.PingRequest{Message: "integration test"}) require.NoError(s.T(), err, "Gateway Ping failed") assert.Equal(s.T(), "gateway", resp.ServiceName) - s.log.logf("Gateway ping: %s", resp.Message) + s.log.Logf("Gateway ping: %s", resp.Message) } func (s *IntegrationSuite) TestPingOrchestrator() { @@ -104,7 +102,7 @@ func (s *IntegrationSuite) TestPingOrchestrator() { resp, err := s.orchestratorClient.Ping(ctx, &orchestratorpb.PingRequest{Message: "integration test"}) require.NoError(s.T(), err, "Orchestrator Ping failed") assert.Equal(s.T(), "orchestrator", resp.ServiceName) - s.log.logf("Orchestrator ping: %s", resp.Message) + s.log.Logf("Orchestrator ping: %s", resp.Message) } func (s *IntegrationSuite) TestPingSpeculator() { @@ -112,7 +110,7 @@ func (s *IntegrationSuite) TestPingSpeculator() { resp, err := s.speculatorClient.Ping(ctx, &speculatorpb.PingRequest{Message: "integration test"}) require.NoError(s.T(), err, "Speculator Ping failed") assert.Equal(s.T(), "speculator", resp.ServiceName) - s.log.logf("Speculator ping: %s", resp.Message) + s.log.Logf("Speculator ping: %s", resp.Message) } func (s *IntegrationSuite) TestLandRequest() { @@ -123,9 +121,9 @@ func (s *IntegrationSuite) TestLandRequest() { Strategy: gatewaypb.Strategy_REBASE, } - s.log.logf("Sending Land request for queue=%s", req.Queue) + s.log.Logf("Sending Land request for queue=%s", req.Queue) resp, err := s.gatewayClient.Land(ctx, req) require.NoError(s.T(), err, "Land request failed") require.NotEmpty(s.T(), resp.Sqid, "SQID should not be empty") - s.log.logf("Land request succeeded: sqid=%s", resp.Sqid) + s.log.Logf("Land request succeeded: sqid=%s", resp.Sqid) } diff --git a/integration_tests/testutil/BUILD.bazel b/integration_tests/testutil/BUILD.bazel new file mode 100644 index 00000000..3d909e9f --- /dev/null +++ b/integration_tests/testutil/BUILD.bazel @@ -0,0 +1,18 @@ +load("@rules_go//go:def.bzl", "go_library") + +go_library( + name = "testutil", + srcs = [ + "docker.go", + "mysql.go", + ], + importpath = "github.com/uber/submitqueue/integration_tests/testutil", + visibility = ["//visibility:public"], + deps = [ + "@com_github_go_sql_driver_mysql//:mysql", + "@com_github_stretchr_testify//require", + "@com_github_testcontainers_testcontainers_go//:testcontainers-go", + "@com_github_testcontainers_testcontainers_go//network", + "@com_github_testcontainers_testcontainers_go_modules_mysql//:mysql", + ], +) diff --git a/integration_tests/testutil/docker.go b/integration_tests/testutil/docker.go new file mode 100644 index 00000000..e5c4460e --- /dev/null +++ b/integration_tests/testutil/docker.go @@ -0,0 +1,41 @@ +package testutil + +import ( + "context" + "os" + "testing" + + "github.com/stretchr/testify/require" + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/network" +) + +// SetupDockerEnv configures Docker environment for testcontainers and creates a network. +// Automatically registers cleanup to remove the network on test completion. +// Returns the Docker network and the context to use for container operations. +func SetupDockerEnv(t *testing.T, log *TestLogger, ctx context.Context) (*testcontainers.DockerNetwork, context.Context) { + t.Helper() + + // Disable Ryuk reaper for Docker-in-Docker environments + t.Setenv("TESTCONTAINERS_RYUK_DISABLED", "true") + + // Ensure HOME is set for Docker config + if os.Getenv("HOME") == "" { + t.Setenv("HOME", t.TempDir()) + } + + // Create Docker network + nw, err := network.New(ctx) + require.NoError(t, err, "failed to create Docker network") + + log.Logf("Docker network created: %s", nw.Name) + + // Register cleanup + t.Cleanup(func() { + log.Logf("Removing Docker network") + require.NoError(t, nw.Remove(ctx), "failed to remove Docker network") + log.Logf("Docker network removed") + }) + + return nw, ctx +} diff --git a/integration_tests/testutil/mysql.go b/integration_tests/testutil/mysql.go new file mode 100644 index 00000000..51c70d6b --- /dev/null +++ b/integration_tests/testutil/mysql.go @@ -0,0 +1,112 @@ +package testutil + +import ( + "context" + "database/sql" + "os" + "path/filepath" + "sort" + "testing" + "time" + + _ "github.com/go-sql-driver/mysql" + "github.com/stretchr/testify/require" + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/modules/mysql" + "github.com/testcontainers/testcontainers-go/network" +) + +// TestLogger is a simple test-aware logger that records elapsed time between logs. +type TestLogger struct { + t *testing.T // The testing object to report logs to. + last time.Time // Timestamp of the last log, for elapsed calculation. +} + +// NewTestLogger creates a TestLogger for the current test. +func NewTestLogger(t *testing.T) *TestLogger { + t.Helper() + return &TestLogger{t: t} +} + +// Logf prints a formatted log message with timestamp and elapsed time since last log. +func (l *TestLogger) Logf(format string, args ...any) { + l.t.Helper() + now := time.Now() + delta := "" + if !l.last.IsZero() { + delta = " +" + now.Sub(l.last).Truncate(time.Millisecond).String() + } + l.last = now + l.t.Logf("[%s%s] "+format, append([]any{now.Format(time.RFC3339Nano), delta}, args...)...) +} + +// SchemaDir returns the path to a schema directory. +// It checks for both Bazel runfiles and direct go test paths. +// relativePath should be like "extensions/storage/mysql/schema" or "extensions/queue/sql/schema" +func SchemaDir(relativePath string) string { + // Bazel runfiles path + if dir := os.Getenv("TEST_SRCDIR"); dir != "" { + return filepath.Join(dir, os.Getenv("TEST_WORKSPACE"), relativePath) + } + // Direct go test path (run from repo root) + return relativePath +} + +// ApplySchema reads all .sql files from the schema directory and executes them on the database. +func ApplySchema(t *testing.T, log *TestLogger, db *sql.DB, schemaDirectory string) { + t.Helper() + + files, err := filepath.Glob(filepath.Join(schemaDirectory, "*.sql")) + require.NoError(t, err, "failed to glob schema files") + require.NotEmpty(t, files, "no .sql schema files found in %s", schemaDirectory) + + // Sort files to ensure deterministic schema application order. + sort.Strings(files) + + for _, f := range files { + name := filepath.Base(f) + log.Logf("Applying schema: %s", name) + + content, err := os.ReadFile(f) + require.NoError(t, err, "failed to read schema file %s", name) + + _, err = db.ExecContext(context.Background(), string(content)) + require.NoError(t, err, "failed to execute schema file %s", name) + + log.Logf("Schema applied: %s", name) + } +} + +// SetupMySQL starts a MySQL container on the given Docker network, applies the schema, +// and returns the container, db connection, and DSN for use in tests. +// The caller is responsible for cleanup (closing db, terminating container). +// schemaPath is the relative path to the schema directory (e.g., "extensions/storage/mysql/schema"). +func SetupMySQL(t *testing.T, log *TestLogger, nw *testcontainers.DockerNetwork, schemaPath string) (*mysql.MySQLContainer, *sql.DB, string) { + t.Helper() + + ctx := context.Background() + + log.Logf("Starting MySQL container") + mysqlContainer, err := mysql.Run(ctx, "mysql:8.0", + mysql.WithDatabase("submitqueue"), + mysql.WithUsername("root"), + mysql.WithPassword("root"), + network.WithNetwork([]string{"mysql"}, nw), + ) + require.NoError(t, err, "failed to start MySQL container") + log.Logf("MySQL container started") + + dsn, err := mysqlContainer.ConnectionString(ctx, "parseTime=true") + require.NoError(t, err, "failed to get MySQL connection string") + log.Logf("MySQL DSN obtained: %s", dsn) + + log.Logf("Opening MySQL connection") + db, err := sql.Open("mysql", dsn) + require.NoError(t, err, "failed to open MySQL connection") + log.Logf("MySQL connection opened") + + dir := SchemaDir(schemaPath) + ApplySchema(t, log, db, dir) + + return mysqlContainer, db, dsn +}