From f41eb6ba6831e914097f52a4d9bd3eab6490b356 Mon Sep 17 00:00:00 2001 From: sergeyb Date: Fri, 20 Feb 2026 07:34:37 +0000 Subject: [PATCH] feat(counter): Implement global atomic counter for sequential ID generation --- entities/BUILD.bazel | 12 +- entities/request.go | 33 +---- entities/request_test.go | 112 ----------------- examples/server/gateway/BUILD.bazel | 2 + examples/server/gateway/main.go | 13 +- extensions/counter/BUILD.bazel | 8 ++ extensions/counter/README.md | 35 ++++++ extensions/counter/counter.go | 14 +++ extensions/counter/mysql/BUILD.bazel | 11 ++ extensions/counter/mysql/counter.go | 37 ++++++ extensions/counter/mysql/schema/BUILD.bazel | 5 + extensions/counter/mysql/schema/counter.sql | 5 + extensions/storage/mysql/request_store.go | 90 ++++--------- extensions/storage/mysql/schema/request.sql | 4 +- extensions/storage/request_store.go | 8 +- extensions/storage/storage.go | 3 + gateway/controller/BUILD.bazel | 1 + gateway/controller/land.go | 27 +++- gateway/controller/land_test.go | 132 +++++++++++++------- integration_tests/BUILD.bazel | 1 + integration_tests/mysql.go | 54 ++++---- 21 files changed, 311 insertions(+), 296 deletions(-) delete mode 100644 entities/request_test.go create mode 100644 extensions/counter/BUILD.bazel create mode 100644 extensions/counter/README.md create mode 100644 extensions/counter/counter.go create mode 100644 extensions/counter/mysql/BUILD.bazel create mode 100644 extensions/counter/mysql/counter.go create mode 100644 extensions/counter/mysql/schema/BUILD.bazel create mode 100644 extensions/counter/mysql/schema/counter.sql diff --git a/entities/BUILD.bazel b/entities/BUILD.bazel index 583cdd26..627f0fa5 100644 --- a/entities/BUILD.bazel +++ b/entities/BUILD.bazel @@ -1,4 +1,4 @@ -load("@rules_go//go:def.bzl", "go_library", "go_test") +load("@rules_go//go:def.bzl", "go_library") go_library( name = "entities", @@ -6,13 +6,3 @@ go_library( importpath = "github.com/uber/submitqueue/entities", visibility = ["//visibility:public"], ) - -go_test( - name = "entities_test", - srcs = ["request_test.go"], - embed = [":entities"], - deps = [ - "@com_github_stretchr_testify//assert", - "@com_github_stretchr_testify//require", - ], -) diff --git a/entities/request.go b/entities/request.go index 487c2799..02c1cf71 100644 --- a/entities/request.go +++ b/entities/request.go @@ -1,10 +1,5 @@ package entities -import ( - "fmt" - "strconv" - "strings" -) // RequestLandStrategy defines the possible source control integration methods. type RequestLandStrategy string @@ -52,10 +47,10 @@ type Request struct { // Immutable fields, fixed at request entity creation // **************** + // ID is the globally unique identifier for the land request. Format: "/". + ID string // Queue is the name of the queue processing the land request. Queue name is defined in the configuration and should be unique within the system. Queue string - // Seq is an autoincrementing integer identifier for the land request. It is unique within the queue. - Seq int64 // Change is a number of code changes (such as pull requests) to land into the target branch. Target branch is defined by the queue configuration. Change Change // LandStrategy is the source control integration strategy to use for this land operation. @@ -71,27 +66,3 @@ type Request struct { // Versioning starts at 1 and is incremented for each change to the object. Version int32 } - -// GetID returns the globally unique identifier for the land request. -func (r *Request) GetID() string { - return fmt.Sprintf("%s/%d", r.Queue, r.Seq) -} - -// ParseRequestID parses the globally unique identifier for the land request and returns the queue name and sequence number. -func ParseRequestID(id string) (queue string, seq int64, err error) { - parts := strings.Split(id, "/") - if len(parts) != 2 { - return "", 0, fmt.Errorf("invalid format of the request ID: %s; expected format: /", id) - } - - seq, err = strconv.ParseInt(parts[1], 10, 64) - if err != nil { - return "", 0, fmt.Errorf("invalid sequence number in the request ID: %s; expected format: /; parsing error: %w", id, err) - } - - if seq <= 0 { - return "", 0, fmt.Errorf("invalid sequence number in the request ID: %s; expected format: /; sequence number must be greater than 0 but got %d", id, seq) - } - - return parts[0], seq, nil -} diff --git a/entities/request_test.go b/entities/request_test.go deleted file mode 100644 index 7627ea1d..00000000 --- a/entities/request_test.go +++ /dev/null @@ -1,112 +0,0 @@ -package entities - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestRequest_GetID(t *testing.T) { - tests := []struct { - name string - request Request - expected string - }{ - { - name: "standard ID", - request: Request{Queue: "my-queue", Seq: 42}, - expected: "my-queue/42", - }, - { - name: "seq 1", - request: Request{Queue: "q", Seq: 1}, - expected: "q/1", - }, - { - name: "large seq", - request: Request{Queue: "prod", Seq: 9999999}, - expected: "prod/9999999", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - assert.Equal(t, tt.expected, tt.request.GetID()) - }) - } -} - -func TestParseRequestID(t *testing.T) { - tests := []struct { - name string - id string - wantQueue string - wantSeq int64 - expectError bool - }{ - { - name: "valid ID", - id: "my-queue/42", - wantQueue: "my-queue", - wantSeq: 42, - }, - { - name: "seq 1", - id: "q/1", - wantQueue: "q", - wantSeq: 1, - }, - { - name: "missing separator", - id: "no-separator", - expectError: true, - }, - { - name: "too many separators", - id: "a/b/c", - expectError: true, - }, - { - name: "empty string", - id: "", - expectError: true, - }, - { - name: "non-numeric seq", - id: "queue/abc", - expectError: true, - }, - { - name: "zero seq", - id: "queue/0", - expectError: true, - }, - { - name: "negative seq", - id: "queue/-1", - expectError: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - queue, seq, err := ParseRequestID(tt.id) - if tt.expectError { - require.Error(t, err) - return - } - require.NoError(t, err) - assert.Equal(t, tt.wantQueue, queue) - assert.Equal(t, tt.wantSeq, seq) - }) - } -} - -func TestGetID_ParseRequestID_Roundtrip(t *testing.T) { - req := &Request{Queue: "test-queue", Seq: 123} - queue, seq, err := ParseRequestID(req.GetID()) - require.NoError(t, err) - assert.Equal(t, req.Queue, queue) - assert.Equal(t, req.Seq, seq) -} diff --git a/examples/server/gateway/BUILD.bazel b/examples/server/gateway/BUILD.bazel index 46c4bf02..1708aae1 100644 --- a/examples/server/gateway/BUILD.bazel +++ b/examples/server/gateway/BUILD.bazel @@ -6,9 +6,11 @@ go_library( importpath = "github.com/uber/submitqueue/examples/server/gateway", visibility = ["//visibility:private"], deps = [ + "//extensions/counter/mysql", "//extensions/storage/mysql", "//gateway/controller", "//gateway/protopb", + "@com_github_go_sql_driver_mysql//:mysql", "@com_github_uber_go_tally_v4//:tally", "@org_golang_google_grpc//:grpc", "@org_golang_google_grpc//reflection", diff --git a/examples/server/gateway/main.go b/examples/server/gateway/main.go index a27a0030..fa6fe0b4 100644 --- a/examples/server/gateway/main.go +++ b/examples/server/gateway/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "database/sql" "fmt" "net" "os" @@ -10,7 +11,9 @@ import ( "syscall" "time" + _ "github.com/go-sql-driver/mysql" "github.com/uber-go/tally/v4" + mysqlcounter "github.com/uber/submitqueue/extensions/counter/mysql" "github.com/uber/submitqueue/extensions/storage/mysql" "github.com/uber/submitqueue/gateway/controller" pb "github.com/uber/submitqueue/gateway/protopb" @@ -98,12 +101,20 @@ func run() error { } defer storeFactory.Close() + // Initialize MySQL counter + counterDB, err := sql.Open("mysql", mysqlDSN) + if err != nil { + return fmt.Errorf("failed to open MySQL connection for counter: %w", err) + } + defer counterDB.Close() + cnt := mysqlcounter.NewCounter(counterDB) + // Create gRPC server grpcServer := grpc.NewServer() // Create controllers and wrap them for gRPC pingController := controller.NewPingController(logger, scope) - landController := controller.NewLandController(logger, scope, storeFactory) + landController := controller.NewLandController(logger, scope, storeFactory, cnt) gatewayServer := &GatewayServer{ pingController: pingController, landController: landController, diff --git a/extensions/counter/BUILD.bazel b/extensions/counter/BUILD.bazel new file mode 100644 index 00000000..cca4640d --- /dev/null +++ b/extensions/counter/BUILD.bazel @@ -0,0 +1,8 @@ +load("@rules_go//go:def.bzl", "go_library") + +go_library( + name = "counter", + srcs = ["counter.go"], + importpath = "github.com/uber/submitqueue/extensions/counter", + visibility = ["//visibility:public"], +) diff --git a/extensions/counter/README.md b/extensions/counter/README.md new file mode 100644 index 00000000..b7464b4e --- /dev/null +++ b/extensions/counter/README.md @@ -0,0 +1,35 @@ +# Counter + +Vendor-agnostic interface for atomic sequential number generation. + +## Interface + +### Counter + +Generates unique, sequential values scoped to a domain string. + +```go +type Counter interface { + Next(ctx context.Context, domain string) (int64, error) +} +``` + +- **domain**: A string key that scopes the counter (max 255 characters). Each domain maintains its own independent sequence. +- **Next**: Atomically increments and returns the next value. The first call for a new domain returns 1. Safe for concurrent use; values are unique but ordering is not guaranteed. + +## Usage + +```go +cnt := mysqlcounter.NewCounter(db) + +// Generate sequential IDs for different domains +val, err := cnt.Next(ctx, "request/my-queue") // returns 1 +val, err = cnt.Next(ctx, "request/my-queue") // returns 2 +val, err = cnt.Next(ctx, "request/other") // returns 1 +``` + +## Implementing a Backend + +1. Create `extensions/counter/{backend}/` directory +2. Implement the `Counter` interface +3. Add a schema file under `extensions/counter/{backend}/schema/` if the backend requires it diff --git a/extensions/counter/counter.go b/extensions/counter/counter.go new file mode 100644 index 00000000..000846eb --- /dev/null +++ b/extensions/counter/counter.go @@ -0,0 +1,14 @@ +package counter + +import "context" + +// Counter provides atomic sequential number generation for a given domain. +// Each call to Next returns the next value in the sequence for the specified domain. +// The value is guaranteed to be unique within the domain throughout the system and persisted accordingly. +type Counter interface { + // Next atomically increments the counter for the given domain and returns the new value. + // The first call for a new domain returns 1. + // The implementation should support at least 255 length domains. + // The function is safe to be called concurrently and will give unique results, but the order of the values is not guaranteed. + Next(ctx context.Context, domain string) (int64, error) +} diff --git a/extensions/counter/mysql/BUILD.bazel b/extensions/counter/mysql/BUILD.bazel new file mode 100644 index 00000000..5fff3464 --- /dev/null +++ b/extensions/counter/mysql/BUILD.bazel @@ -0,0 +1,11 @@ +load("@rules_go//go:def.bzl", "go_library") + +go_library( + name = "mysql", + srcs = ["counter.go"], + importpath = "github.com/uber/submitqueue/extensions/counter/mysql", + visibility = ["//visibility:public"], + deps = [ + "//extensions/counter", + ], +) diff --git a/extensions/counter/mysql/counter.go b/extensions/counter/mysql/counter.go new file mode 100644 index 00000000..73c471f4 --- /dev/null +++ b/extensions/counter/mysql/counter.go @@ -0,0 +1,37 @@ +package mysql + +import ( + "context" + "database/sql" + "fmt" + + "github.com/uber/submitqueue/extensions/counter" +) + +type mysqlCounter struct { + db *sql.DB +} + +// NewCounter creates a new MySQL-backed Counter. +func NewCounter(db *sql.DB) counter.Counter { + return &mysqlCounter{db: db} +} + +// Next atomically increments the counter for the given domain and returns the new value. +// Uses MySQL's LAST_INSERT_ID() to set the value atomically and read the incremented value. +func (c *mysqlCounter) Next(ctx context.Context, domain string) (int64, error) { + result, err := c.db.ExecContext(ctx, + "INSERT INTO counter (domain, value) VALUES (?, LAST_INSERT_ID(1)) ON DUPLICATE KEY UPDATE value = LAST_INSERT_ID(value + 1)", + domain, + ) + if err != nil { + return 0, fmt.Errorf("failed to increment counter for domain=%s: %w", domain, err) + } + + value, err := result.LastInsertId() + if err != nil { + return 0, fmt.Errorf("failed to get counter value for domain=%s: %w", domain, err) + } + + return value, nil +} diff --git a/extensions/counter/mysql/schema/BUILD.bazel b/extensions/counter/mysql/schema/BUILD.bazel new file mode 100644 index 00000000..3412d773 --- /dev/null +++ b/extensions/counter/mysql/schema/BUILD.bazel @@ -0,0 +1,5 @@ +filegroup( + name = "schema", + srcs = glob(["*.sql"]), + visibility = ["//visibility:public"], +) diff --git a/extensions/counter/mysql/schema/counter.sql b/extensions/counter/mysql/schema/counter.sql new file mode 100644 index 00000000..754d4d79 --- /dev/null +++ b/extensions/counter/mysql/schema/counter.sql @@ -0,0 +1,5 @@ +CREATE TABLE IF NOT EXISTS counter ( + domain VARCHAR(255) NOT NULL, + value BIGINT NOT NULL, + PRIMARY KEY (domain) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; diff --git a/extensions/storage/mysql/request_store.go b/extensions/storage/mysql/request_store.go index 3a67b8fa..a38afde2 100644 --- a/extensions/storage/mysql/request_store.go +++ b/extensions/storage/mysql/request_store.go @@ -13,8 +13,6 @@ import ( "github.com/uber/submitqueue/extensions/storage" ) -const maxCreateRetries = 1000 - type requestStore struct { db *sql.DB } @@ -26,18 +24,13 @@ func NewRequestStore(db *sql.DB) storage.RequestStore { // Get retrieves a land request by ID. Returns ErrNotFound if the request is not found. func (r *requestStore) Get(ctx context.Context, id string) (entities.Request, error) { - queue, seq, err := entities.ParseRequestID(id) - if err != nil { - return entities.Request{}, fmt.Errorf("failed to parse request ID %s: %w", id, err) - } - var req entities.Request var changeIDsJSON []byte - err = r.db.QueryRowContext(ctx, - "SELECT queue, seq, change_source, change_ids, land_strategy, state, version FROM request WHERE queue = ? AND seq = ?", - queue, seq, - ).Scan(&req.Queue, &req.Seq, &req.Change.Source, &changeIDsJSON, &req.LandStrategy, &req.State, &req.Version) + err := r.db.QueryRowContext(ctx, + "SELECT id, queue, change_source, change_ids, land_strategy, state, version FROM request WHERE id = ?", + id, + ).Scan(&req.ID, &req.Queue, &req.Change.Source, &changeIDsJSON, &req.LandStrategy, &req.State, &req.Version) if errors.Is(err, sql.ErrNoRows) { return entities.Request{}, storage.WrapNotFound(err) @@ -53,89 +46,56 @@ func (r *requestStore) Get(ctx context.Context, id string) (entities.Request, er return req, nil } -// Create creates a new land request. Returns the created request object with generated sequence number. -// It uses optimistic locking: obtains the current max sequence number, attempts to insert with seq+1, -// and retries with an incremented sequence number on primary key conflict. -func (r *requestStore) Create(ctx context.Context, queue string, change entities.Change, strategy entities.RequestLandStrategy, state entities.RequestState) (entities.Request, error) { - changeIDsJSON, err := json.Marshal(change.IDs) +// Create creates a new land request. The request must have a unique ID already assigned. Returns ErrAlreadyExists if the request ID already exists. +func (r *requestStore) Create(ctx context.Context, request entities.Request) error { + changeIDsJSON, err := json.Marshal(request.Change.IDs) if err != nil { - return entities.Request{}, fmt.Errorf("failed to marshal change IDs=%v queue=%s for Create request entity: %w", change.IDs, queue, err) + return fmt.Errorf("failed to marshal change IDs=%v id=%s for Create request entity: %w", request.Change.IDs, request.ID, err) } - var seq int64 - err = r.db.QueryRowContext(ctx, - "SELECT COALESCE(MAX(seq), 0) + 1 FROM request WHERE queue = ?", - queue, - ).Scan(&seq) + _, err = r.db.ExecContext(ctx, + "INSERT INTO request (id, queue, change_source, change_ids, land_strategy, state, version) VALUES (?, ?, ?, ?, ?, ?, ?)", + request.ID, request.Queue, request.Change.Source, changeIDsJSON, request.LandStrategy, request.State, request.Version, + ) if err != nil { - return entities.Request{}, fmt.Errorf("failed to get next sequence number for queue=%s: %w", queue, err) - } - - // Version always start from 1 as per protocol. - version := int32(1) - - // retry up to maxCreateRetries times to insert the request entity, incrementing the sequence number on primary key conflict - for attempt := 0; attempt < maxCreateRetries; attempt++ { - _, err = r.db.ExecContext(ctx, - "INSERT INTO request (queue, seq, change_source, change_ids, land_strategy, state, version) VALUES (?, ?, ?, ?, ?, ?, ?)", - queue, seq, change.Source, changeIDsJSON, strategy, state, version, - ) - if err == nil { - return entities.Request{ - Queue: queue, - Seq: seq, - Change: change, - LandStrategy: strategy, - State: state, - Version: version, - }, nil - } - - // if the error is a MySQL primary key conflict error, increment the sequence number and retry - // It relies on MySQL-specific error code 1062 for primary key conflict. Hopefully this will not change in the future. var mysqlErr *mysql.MySQLError if errors.As(err, &mysqlErr) && mysqlErr.Number == 1062 { - seq++ - continue + // MySQL error code 1062 is "Duplicate entry". Hopefully it will never change with new versions of MySQL. + // Also it requires to have a single unique index on the table. + return fmt.Errorf("request entity id=%s: %w", request.ID, storage.ErrAlreadyExists) } - - return entities.Request{}, fmt.Errorf("failed to insert request entity queue=%s seq=%d: %w", queue, seq, err) + return fmt.Errorf("failed to insert request entity id=%s: %w", request.ID, err) } - return entities.Request{}, fmt.Errorf("failed to insert request entity queue=%s change=%v: exceeded %d retry attempts due to primary key conflicts", queue, change, maxCreateRetries) + return nil } // UpdateState updates the state of a land request if the current version matches the expected version. If versions do not match, returns ErrVersionMismatch. // The implementation increments the version by 1 atomically with the state update. func (r *requestStore) UpdateState(ctx context.Context, id string, version int32, newState entities.RequestState) error { - queue, seq, err := entities.ParseRequestID(id) - if err != nil { - return fmt.Errorf("failed to parse request ID=%q: %w", id, err) - } - result, err := r.db.ExecContext(ctx, - "UPDATE request SET state = ?, version = version + 1 WHERE queue = ? AND seq = ? AND version = ?", - newState, queue, seq, version, + "UPDATE request SET state = ?, version = version + 1 WHERE id = ? AND version = ?", + newState, id, version, ) if err != nil { return fmt.Errorf( - "failed to update request state for queue=%q seq=%d version=%d newState=%v: %w", - queue, seq, version, newState, err, + "failed to update request state for id=%q version=%d newState=%v: %w", + id, version, newState, err, ) } rowsAffected, err := result.RowsAffected() if err != nil { return fmt.Errorf( - "failed to get rows affected from update for queue=%q seq=%d version=%d newState=%v: %w", - queue, seq, version, newState, err, + "failed to get rows affected from update for id=%q version=%d newState=%v: %w", + id, version, newState, err, ) } if rowsAffected != 1 { return fmt.Errorf( - "version mismatch for request update: queue=%q seq=%d expected_version=%d newState=%v: %w", - queue, seq, version, newState, storage.ErrVersionMismatch, + "version mismatch for request update: id=%q expected_version=%d newState=%v: %w", + id, version, newState, storage.ErrVersionMismatch, ) } diff --git a/extensions/storage/mysql/schema/request.sql b/extensions/storage/mysql/schema/request.sql index a02c4f43..5011f094 100644 --- a/extensions/storage/mysql/schema/request.sql +++ b/extensions/storage/mysql/schema/request.sql @@ -1,10 +1,10 @@ CREATE TABLE IF NOT EXISTS request ( + id VARCHAR(255) NOT NULL, queue VARCHAR(255) NOT NULL, - seq BIGINT NOT NULL, change_source VARCHAR(255) NOT NULL, change_ids JSON NOT NULL, land_strategy VARCHAR(64) NOT NULL, state VARCHAR(64) NOT NULL, version INT NOT NULL, - PRIMARY KEY (queue, seq) + PRIMARY KEY (id) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; diff --git a/extensions/storage/request_store.go b/extensions/storage/request_store.go index 06e8dab8..84424ae2 100644 --- a/extensions/storage/request_store.go +++ b/extensions/storage/request_store.go @@ -8,12 +8,12 @@ import ( // RequestStore is an interface that defines methods for managing land requests in the database. type RequestStore interface { - // Get retrieves a land request by ID (queue/seq). Returns ErrNotFound if the request is not found. + // Get retrieves a land request by ID. Returns ErrNotFound if the request is not found. Get(ctx context.Context, id string) (entities.Request, error) - // Create creates a new land request. Returns the created request object with generated sequence number. - // The implementation must ensure that the sequence number is unique within the queue. - Create(ctx context.Context, queue string, change entities.Change, strategy entities.RequestLandStrategy, state entities.RequestState) (entities.Request, error) + // Create creates a new land request. The request must have a unique ID already assigned. + // Returns ErrAlreadyExists if a request with the same ID already exists. + Create(ctx context.Context, request entities.Request) error // UpdateState updates the state of a land request if the current version matches the expected version. If versions do not match, returns ErrVersionMismatch. // The implementation should increment the version by 1 atomically with the state update. diff --git a/extensions/storage/storage.go b/extensions/storage/storage.go index 253da634..d10a16e2 100644 --- a/extensions/storage/storage.go +++ b/extensions/storage/storage.go @@ -16,6 +16,9 @@ func WrapNotFound(err error) error { return fmt.Errorf("%w: %w", ErrNotFound, err) } +// ErrAlreadyExists is returned by storage implementations when attempting to create a record with an ID that already exists. +var ErrAlreadyExists = errors.New("record already exists") + // ErrVersionMismatch is returned by storage implementations when the expected entity version does not match the current version of the object. // This is used to implement an optimistic locking mechanism, allowing multiple clients to update the same entity concurrently // and either retry or implement idempotent operations. diff --git a/gateway/controller/BUILD.bazel b/gateway/controller/BUILD.bazel index 3b6f682e..916c3cd2 100644 --- a/gateway/controller/BUILD.bazel +++ b/gateway/controller/BUILD.bazel @@ -10,6 +10,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//entities", + "//extensions/counter", "//extensions/storage", "//gateway/protopb", "@com_github_uber_go_tally_v4//:tally", diff --git a/gateway/controller/land.go b/gateway/controller/land.go index d911d3ab..bff23d9a 100644 --- a/gateway/controller/land.go +++ b/gateway/controller/land.go @@ -7,6 +7,7 @@ import ( "github.com/uber-go/tally/v4" "github.com/uber/submitqueue/entities" + "github.com/uber/submitqueue/extensions/counter" "github.com/uber/submitqueue/extensions/storage" pb "github.com/uber/submitqueue/gateway/protopb" "go.uber.org/zap" @@ -17,14 +18,16 @@ type LandController struct { logger *zap.Logger metricsScope tally.Scope storeFactory storage.StoreFactory + counter counter.Counter } // NewLandController creates a new instance of the gateway land controller -func NewLandController(logger *zap.Logger, scope tally.Scope, storeFactory storage.StoreFactory) *LandController { +func NewLandController(logger *zap.Logger, scope tally.Scope, storeFactory storage.StoreFactory, counter counter.Counter) *LandController { return &LandController{ logger: logger, metricsScope: scope, storeFactory: storeFactory, + counter: counter, } } @@ -51,20 +54,32 @@ func (c *LandController) Land(ctx context.Context, req *pb.LandRequest) (*pb.Lan return nil, fmt.Errorf("LandController failed to map strategy for queue=%s: %w", req.Queue, err) } - request, err := c.storeFactory.GetRequestStore().Create(ctx, queue, change, strategy, entities.RequestStateNew) + // Generate a globally unique request ID for the land request. + seq, err := c.counter.Next(ctx, "request/"+queue) if err != nil { - return nil, fmt.Errorf("LandController failed to create request for queue=%s: %w", req.Queue, err) + return nil, fmt.Errorf("LandController failed to generate request ID for queue=%s: %w", queue, err) + } + + request := entities.Request{ + ID: fmt.Sprintf("%s/%d", queue, seq), + Queue: queue, + Change: change, + LandStrategy: strategy, + State: entities.RequestStateNew, + Version: 1, } - sqid := request.GetID() + if err := c.storeFactory.GetRequestStore().Create(ctx, request); err != nil { + return nil, fmt.Errorf("LandController failed to create request for queue=%s: %w", req.Queue, err) + } c.logger.Debug("land request received", zap.String("queue", req.Queue), - zap.String("sqid", sqid), + zap.String("sqid", request.ID), ) return &pb.LandResponse{ - Sqid: sqid, + Sqid: request.ID, }, nil } diff --git a/gateway/controller/land_test.go b/gateway/controller/land_test.go index d589ffed..bc2662d9 100644 --- a/gateway/controller/land_test.go +++ b/gateway/controller/land_test.go @@ -14,16 +14,24 @@ import ( "go.uber.org/zap" ) +type mockCounter struct { + nextFunc func(ctx context.Context, domain string) (int64, error) +} + +func (m *mockCounter) Next(ctx context.Context, domain string) (int64, error) { + return m.nextFunc(ctx, domain) +} + type mockRequestStore struct { - createFunc func(ctx context.Context, queue string, change entities.Change, strategy entities.RequestLandStrategy, state entities.RequestState) (entities.Request, error) + createFunc func(ctx context.Context, request entities.Request) error } func (m *mockRequestStore) Get(ctx context.Context, id string) (entities.Request, error) { return entities.Request{}, nil } -func (m *mockRequestStore) Create(ctx context.Context, queue string, change entities.Change, strategy entities.RequestLandStrategy, state entities.RequestState) (entities.Request, error) { - return m.createFunc(ctx, queue, change, strategy, state) +func (m *mockRequestStore) Create(ctx context.Context, request entities.Request) error { + return m.createFunc(ctx, request) } func (m *mockRequestStore) UpdateState(ctx context.Context, id string, version int32, newState entities.RequestState) error { @@ -44,28 +52,27 @@ func (m *mockStoreFactory) Close() error { func TestNewLandController(t *testing.T) { factory := &mockStoreFactory{requestStore: &mockRequestStore{ - createFunc: func(ctx context.Context, queue string, change entities.Change, strategy entities.RequestLandStrategy, state entities.RequestState) (entities.Request, error) { - return entities.Request{}, nil + createFunc: func(ctx context.Context, request entities.Request) error { + return nil }, }} - controller := NewLandController(zap.NewNop(), tally.NoopScope, factory) + cnt := &mockCounter{nextFunc: func(ctx context.Context, domain string) (int64, error) { + return 1, nil + }} + controller := NewLandController(zap.NewNop(), tally.NoopScope, factory, cnt) require.NotNil(t, controller) } func TestLand_ReturnsSqid(t *testing.T) { factory := &mockStoreFactory{requestStore: &mockRequestStore{ - createFunc: func(ctx context.Context, queue string, change entities.Change, strategy entities.RequestLandStrategy, state entities.RequestState) (entities.Request, error) { - return entities.Request{ - Queue: queue, - Seq: 1, - Change: change, - LandStrategy: strategy, - State: state, - Version: 1, - }, nil + createFunc: func(ctx context.Context, request entities.Request) error { + return nil }, }} - controller := NewLandController(zap.NewNop(), tally.NoopScope, factory) + cnt := &mockCounter{nextFunc: func(ctx context.Context, domain string) (int64, error) { + return 1, nil + }} + controller := NewLandController(zap.NewNop(), tally.NoopScope, factory, cnt) ctx := context.Background() req := &pb.LandRequest{ @@ -79,28 +86,18 @@ func TestLand_ReturnsSqid(t *testing.T) { } func TestLand_PassesCorrectParametersToStore(t *testing.T) { - var capturedQueue string - var capturedChange entities.Change - var capturedStrategy entities.RequestLandStrategy - var capturedState entities.RequestState + var capturedRequest entities.Request factory := &mockStoreFactory{requestStore: &mockRequestStore{ - createFunc: func(ctx context.Context, queue string, change entities.Change, strategy entities.RequestLandStrategy, state entities.RequestState) (entities.Request, error) { - capturedQueue = queue - capturedChange = change - capturedStrategy = strategy - capturedState = state - return entities.Request{ - Queue: queue, - Seq: 42, - Change: change, - LandStrategy: strategy, - State: state, - Version: 1, - }, nil + createFunc: func(ctx context.Context, request entities.Request) error { + capturedRequest = request + return nil }, }} - controller := NewLandController(zap.NewNop(), tally.NoopScope, factory) + cnt := &mockCounter{nextFunc: func(ctx context.Context, domain string) (int64, error) { + return 42, nil + }} + controller := NewLandController(zap.NewNop(), tally.NoopScope, factory, cnt) ctx := context.Background() req := &pb.LandRequest{ @@ -111,21 +108,47 @@ func TestLand_PassesCorrectParametersToStore(t *testing.T) { resp, err := controller.Land(ctx, req) require.NoError(t, err) - assert.Equal(t, "my-queue", capturedQueue) - assert.Equal(t, "github", capturedChange.Source) - assert.Equal(t, []string{"pr-1", "pr-2"}, capturedChange.IDs) - assert.Equal(t, entities.RequestLandStrategyRebase, capturedStrategy) - assert.Equal(t, entities.RequestStateNew, capturedState) + assert.Equal(t, "my-queue/42", capturedRequest.ID) + assert.Equal(t, "my-queue", capturedRequest.Queue) + assert.Equal(t, "github", capturedRequest.Change.Source) + assert.Equal(t, []string{"pr-1", "pr-2"}, capturedRequest.Change.IDs) + assert.Equal(t, entities.RequestLandStrategyRebase, capturedRequest.LandStrategy) + assert.Equal(t, entities.RequestStateNew, capturedRequest.State) + assert.Equal(t, int32(1), capturedRequest.Version) assert.Equal(t, "my-queue/42", resp.Sqid) } func TestLand_ReturnsErrorOnStorageFailure(t *testing.T) { factory := &mockStoreFactory{requestStore: &mockRequestStore{ - createFunc: func(ctx context.Context, queue string, change entities.Change, strategy entities.RequestLandStrategy, state entities.RequestState) (entities.Request, error) { - return entities.Request{}, fmt.Errorf("database connection failed") + createFunc: func(ctx context.Context, request entities.Request) error { + return fmt.Errorf("database connection failed") + }, + }} + cnt := &mockCounter{nextFunc: func(ctx context.Context, domain string) (int64, error) { + return 1, nil + }} + controller := NewLandController(zap.NewNop(), tally.NoopScope, factory, cnt) + ctx := context.Background() + + req := &pb.LandRequest{ + Queue: "test-queue", + Change: &pb.Change{Source: "github", Ids: []string{"123"}}, + } + _, err := controller.Land(ctx, req) + + require.Error(t, err) +} + +func TestLand_ReturnsErrorOnCounterFailure(t *testing.T) { + factory := &mockStoreFactory{requestStore: &mockRequestStore{ + createFunc: func(ctx context.Context, request entities.Request) error { + return nil }, }} - controller := NewLandController(zap.NewNop(), tally.NoopScope, factory) + cnt := &mockCounter{nextFunc: func(ctx context.Context, domain string) (int64, error) { + return 0, fmt.Errorf("counter unavailable") + }} + controller := NewLandController(zap.NewNop(), tally.NoopScope, factory, cnt) ctx := context.Background() req := &pb.LandRequest{ @@ -136,3 +159,28 @@ func TestLand_ReturnsErrorOnStorageFailure(t *testing.T) { require.Error(t, err) } + +func TestLand_CounterDomainIncludesQueue(t *testing.T) { + var capturedDomain string + + factory := &mockStoreFactory{requestStore: &mockRequestStore{ + createFunc: func(ctx context.Context, request entities.Request) error { + return nil + }, + }} + cnt := &mockCounter{nextFunc: func(ctx context.Context, domain string) (int64, error) { + capturedDomain = domain + return 1, nil + }} + controller := NewLandController(zap.NewNop(), tally.NoopScope, factory, cnt) + ctx := context.Background() + + req := &pb.LandRequest{ + Queue: "my-queue", + Change: &pb.Change{Source: "github", Ids: []string{"123"}}, + } + _, err := controller.Land(ctx, req) + + require.NoError(t, err) + assert.Equal(t, "request/my-queue", capturedDomain) +} diff --git a/integration_tests/BUILD.bazel b/integration_tests/BUILD.bazel index 046a4b7b..8e0a1e40 100644 --- a/integration_tests/BUILD.bazel +++ b/integration_tests/BUILD.bazel @@ -10,6 +10,7 @@ go_test( "suite_test.go", ], data = [ + "//extensions/counter/mysql/schema", "//extensions/storage/mysql/schema", "//examples/server/gateway", "//examples/server/orchestrator", diff --git a/integration_tests/mysql.go b/integration_tests/mysql.go index f64677f0..b1e2933b 100644 --- a/integration_tests/mysql.go +++ b/integration_tests/mysql.go @@ -40,40 +40,50 @@ func (l *testLogger) logf(format string, args ...any) { l.t.Logf("[%s%s] "+format, append([]any{now.Format(time.RFC3339Nano), delta}, args...)...) } -// schemaDir returns the path to the MySQL schema directory. +// schemaDirs returns the paths to all schema directories. // It checks for both Bazel runfiles and direct go test paths. -func schemaDir() string { - // Bazel runfiles path - if dir := os.Getenv("TEST_SRCDIR"); dir != "" { - return filepath.Join(dir, os.Getenv("TEST_WORKSPACE"), "extensions/storage/mysql/schema") +func schemaDirs() []string { + dirs := []string{ + "extensions/storage/mysql/schema", + "extensions/counter/mysql/schema", } - // Direct go test path (run from repo root) - return "extensions/storage/mysql/schema" + + if srcDir := os.Getenv("TEST_SRCDIR"); srcDir != "" { + workspace := os.Getenv("TEST_WORKSPACE") + result := make([]string, len(dirs)) + for i, d := range dirs { + result[i] = filepath.Join(srcDir, workspace, d) + } + return result + } + + return dirs } -// applySchema reads all .sql files from the schema directory and executes them on the database. +// applySchema reads all .sql files from the schema directories and executes them on the database. func applySchema(t *testing.T, log *testLogger, db *sql.DB) { t.Helper() - dir := schemaDir() - files, err := filepath.Glob(filepath.Join(dir, "*.sql")) - require.NoError(t, err, "failed to glob schema files") - require.NotEmpty(t, files, "no .sql schema files found in %s", dir) + for _, dir := range schemaDirs() { + files, err := filepath.Glob(filepath.Join(dir, "*.sql")) + require.NoError(t, err, "failed to glob schema files in %s", dir) + require.NotEmpty(t, files, "no .sql schema files found in %s", dir) - // Sort files to ensure deterministic schema application order. - sort.Strings(files) + // Sort files to ensure deterministic schema application order. + sort.Strings(files) - for _, f := range files { - name := filepath.Base(f) - log.logf("Applying schema: %s", name) + for _, f := range files { + name := filepath.Base(f) + log.logf("Applying schema: %s", name) - content, err := os.ReadFile(f) - require.NoError(t, err, "failed to read schema file %s", name) + content, err := os.ReadFile(f) + require.NoError(t, err, "failed to read schema file %s", name) - _, err = db.ExecContext(context.Background(), string(content)) - require.NoError(t, err, "failed to execute schema file %s", name) + _, err = db.ExecContext(context.Background(), string(content)) + require.NoError(t, err, "failed to execute schema file %s", name) - log.logf("Schema applied: %s", name) + log.logf("Schema applied: %s", name) + } } }