From 9d08ea688b2790d0236fb46bb1b98ac660d914be Mon Sep 17 00:00:00 2001 From: manjari Date: Thu, 5 Mar 2026 01:24:29 +0000 Subject: [PATCH 1/3] feat(controller) Score controller scores batches using heuristic scorer --- entity/batch.go | 10 + example/server/orchestrator/BUILD.bazel | 3 + example/server/orchestrator/main.go | 24 +- extension/storage/batch_store.go | 5 + extension/storage/mysql/batch_store.go | 48 +++- extension/storage/mysql/schema/batch.sql | 1 + orchestrator/controller/score/BUILD.bazel | 5 + orchestrator/controller/score/score.go | 86 +++++- orchestrator/controller/score/score_test.go | 291 ++++++++++++++++---- 9 files changed, 415 insertions(+), 58 deletions(-) diff --git a/entity/batch.go b/entity/batch.go index 8d47c677..128e6b6d 100644 --- a/entity/batch.go +++ b/entity/batch.go @@ -57,6 +57,8 @@ type Batch struct { // - queueA/batch/3 will contain queueA/batch/1 // Dependencies []map[string]interface{} + // Score is the probability of a successful land for the batch, between 0.0 and 1.0. + Score float32 `json:"score"` // The state of the batch lifecycle this batch is in. State BatchState // Version is the version of the object. It is used for optimistic locking. @@ -69,6 +71,14 @@ func (b Batch) ToBytes() ([]byte, error) { return json.Marshal(b) } +// WithScoreAndState returns a new Batch with the given score and state, incrementing the version. +func (b Batch) WithScoreAndState(score float32, state BatchState) Batch { + b.Score = score + b.State = state + b.Version++ + return b +} + // BatchFromBytes deserializes a Batch from JSON bytes. func BatchFromBytes(data []byte) (Batch, error) { var batch Batch diff --git a/example/server/orchestrator/BUILD.bazel b/example/server/orchestrator/BUILD.bazel index 26ad6f18..33609664 100644 --- a/example/server/orchestrator/BUILD.bazel +++ b/example/server/orchestrator/BUILD.bazel @@ -12,12 +12,15 @@ go_library( visibility = ["//visibility:private"], deps = [ "//core/consumer", + "//entity", "//extension/counter", "//extension/counter/mysql", "//extension/mergechecker", "//extension/mergechecker/github", "//extension/queue", "//extension/queue/mysql", + "//extension/scorer", + "//extension/scorer/heuristic", "//extension/storage", "//extension/storage/mysql", "//orchestrator/controller", diff --git a/example/server/orchestrator/main.go b/example/server/orchestrator/main.go index f59b9eb9..aa26f9b1 100644 --- a/example/server/orchestrator/main.go +++ b/example/server/orchestrator/main.go @@ -15,12 +15,15 @@ import ( _ "github.com/go-sql-driver/mysql" "github.com/uber-go/tally/v4" "github.com/uber/submitqueue/core/consumer" + "github.com/uber/submitqueue/entity" "github.com/uber/submitqueue/extension/counter" mysqlcounter "github.com/uber/submitqueue/extension/counter/mysql" "github.com/uber/submitqueue/extension/mergechecker" githubchecker "github.com/uber/submitqueue/extension/mergechecker/github" extqueue "github.com/uber/submitqueue/extension/queue" queueMySQL "github.com/uber/submitqueue/extension/queue/mysql" + "github.com/uber/submitqueue/extension/scorer" + heuristicscorer "github.com/uber/submitqueue/extension/scorer/heuristic" mysqlstorage "github.com/uber/submitqueue/extension/storage/mysql" "github.com/uber/submitqueue/extension/storage" "github.com/uber/submitqueue/orchestrator/controller" @@ -161,8 +164,11 @@ func run() error { // Create merge checker mc := newMergeChecker(logger, scope) + // Create scorer + sc := newScorer(scope) + // Register controllers - if err := registerControllers(c, logger.Sugar(), scope, registry, mc, cnt, store); err != nil { + if err := registerControllers(c, logger.Sugar(), scope, registry, mc, cnt, store, sc); err != nil { return err } @@ -310,7 +316,7 @@ func newTopicRegistry(q extqueue.Queue, subscriberName string) (consumer.TopicRe // │ │ │ // └────────┴────────────────────────┘ -func registerControllers(c consumer.Consumer, logger *zap.SugaredLogger, scope tally.Scope, registry consumer.TopicRegistry, mc mergechecker.MergeChecker, cnt counter.Counter, store storage.Storage) error { +func registerControllers(c consumer.Consumer, logger *zap.SugaredLogger, scope tally.Scope, registry consumer.TopicRegistry, mc mergechecker.MergeChecker, cnt counter.Counter, store storage.Storage, sc scorer.Scorer) error { requestController := request.NewController( logger, scope, @@ -354,6 +360,8 @@ func registerControllers(c consumer.Consumer, logger *zap.SugaredLogger, scope t registry, consumer.TopicKeyScore, "orchestrator-score", + sc, + store, ) if err := c.Register(scoreController); err != nil { return fmt.Errorf("failed to register score controller: %w", err) @@ -442,6 +450,18 @@ func newMergeChecker(logger *zap.Logger, scope tally.Scope) mergechecker.MergeCh }) } +func newScorer(scope tally.Scope) scorer.Scorer { + return heuristicscorer.New( + []heuristicscorer.Bucket{ + {Min: 0, Max: 100, Score: 1.0}, + }, + func(_ context.Context, _ entity.Change) (int, error) { + return 0, nil + }, + scope.SubScope("scorer"), + ) +} + // bearerTransport is an http.RoundTripper that adds a Bearer token to requests. type bearerTransport struct { token string diff --git a/extension/storage/batch_store.go b/extension/storage/batch_store.go index 64e8be83..ba1ef429 100644 --- a/extension/storage/batch_store.go +++ b/extension/storage/batch_store.go @@ -21,6 +21,11 @@ type BatchStore interface { // The implementation should increment the version by 1 atomically with the state update. UpdateState(ctx context.Context, id string, version int32, newState entity.BatchState) error + // UpdateScoreAndState updates the score and state of a batch if the current version matches the expected version. + // If versions do not match, returns ErrVersionMismatch. + // The implementation should increment the version by 1 atomically with the score and state update. + UpdateScoreAndState(ctx context.Context, id string, version int32, score float32, newState entity.BatchState) error + // GetByQueueAndStates retrieves all batches that belong to the given queue and are in the given states. GetByQueueAndStates(ctx context.Context, queue string, states []entity.BatchState) ([]entity.Batch, error) } diff --git a/extension/storage/mysql/batch_store.go b/extension/storage/mysql/batch_store.go index 8ee3d9ea..68e638e4 100644 --- a/extension/storage/mysql/batch_store.go +++ b/extension/storage/mysql/batch_store.go @@ -36,9 +36,9 @@ func (s *batchStore) Get(ctx context.Context, id string) (ret entity.Batch, retE var dependenciesJSON []byte err := s.db.QueryRowContext(ctx, - "SELECT id, queue, contains, dependencies, state, version FROM batch WHERE id = ?", + "SELECT id, queue, contains, dependencies, score, state, version FROM batch WHERE id = ?", id, - ).Scan(&batch.ID, &batch.Queue, &containsJSON, &dependenciesJSON, &batch.State, &batch.Version) + ).Scan(&batch.ID, &batch.Queue, &containsJSON, &dependenciesJSON, &batch.Score, &batch.State, &batch.Version) if errors.Is(err, sql.ErrNoRows) { return entity.Batch{}, storage.WrapNotFound(err) @@ -74,8 +74,8 @@ func (s *batchStore) Create(ctx context.Context, batch entity.Batch) (retErr err } _, err = s.db.ExecContext(ctx, - "INSERT INTO batch (id, queue, contains, dependencies, state, version) VALUES (?, ?, ?, ?, ?, ?)", - batch.ID, batch.Queue, containsJSON, dependenciesJSON, batch.State, batch.Version, + "INSERT INTO batch (id, queue, contains, dependencies, score, state, version) VALUES (?, ?, ?, ?, ?, ?, ?)", + batch.ID, batch.Queue, containsJSON, dependenciesJSON, batch.Score, batch.State, batch.Version, ) if err != nil { var mysqlErr *mysql.MySQLError @@ -123,6 +123,42 @@ func (s *batchStore) UpdateState(ctx context.Context, id string, version int32, return nil } +// UpdateScoreAndState updates the score and state of a batch if the current version matches the expected version. +// If versions do not match, returns ErrVersionMismatch. +// The implementation increments the version by 1 atomically with the score and state update. +func (s *batchStore) UpdateScoreAndState(ctx context.Context, id string, version int32, score float32, newState entity.BatchState) (retErr error) { + op := metrics.Begin(s.scope, "update_score") + defer func() { op.Complete(retErr) }() + + result, err := s.db.ExecContext(ctx, + "UPDATE batch SET score = ?, state = ?, version = version + 1 WHERE id = ? AND version = ?", + score, newState, id, version, + ) + if err != nil { + return fmt.Errorf( + "failed to update batch score for id=%q version=%d score=%v newState=%v: %w", + id, version, score, newState, err, + ) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf( + "failed to get rows affected from update for id=%q version=%d score=%v newState=%v: %w", + id, version, score, newState, err, + ) + } + + if rowsAffected != 1 { + return fmt.Errorf( + "version mismatch for batch update: id=%q expected_version=%d score=%v newState=%v: %w", + id, version, score, newState, storage.ErrVersionMismatch, + ) + } + + return nil +} + // GetByQueueAndStates retrieves all batches that belong to the given queue and are in the given states. func (s *batchStore) GetByQueueAndStates(ctx context.Context, queue string, states []entity.BatchState) (ret []entity.Batch, retErr error) { op := metrics.Begin(s.scope, "get_by_queue_and_states") @@ -132,7 +168,7 @@ func (s *batchStore) GetByQueueAndStates(ctx context.Context, queue string, stat return nil, nil } - query := "SELECT id, queue, contains, dependencies, state, version FROM batch WHERE queue = ? AND state IN (?" + strings.Repeat(", ?", len(states)-1) + ")" + query := "SELECT id, queue, contains, dependencies, score, state, version FROM batch WHERE queue = ? AND state IN (?" + strings.Repeat(", ?", len(states)-1) + ")" args := make([]any, 1+len(states)) args[0] = queue @@ -152,7 +188,7 @@ func (s *batchStore) GetByQueueAndStates(ctx context.Context, queue string, stat var containsJSON []byte var dependenciesJSON []byte - if err := rows.Scan(&batch.ID, &batch.Queue, &containsJSON, &dependenciesJSON, &batch.State, &batch.Version); err != nil { + if err := rows.Scan(&batch.ID, &batch.Queue, &containsJSON, &dependenciesJSON, &batch.Score, &batch.State, &batch.Version); err != nil { return nil, fmt.Errorf("failed to scan batch entity by queue=%q states=%v from the database: %w", queue, states, err) } diff --git a/extension/storage/mysql/schema/batch.sql b/extension/storage/mysql/schema/batch.sql index 398d17ca..c79c046e 100644 --- a/extension/storage/mysql/schema/batch.sql +++ b/extension/storage/mysql/schema/batch.sql @@ -3,6 +3,7 @@ CREATE TABLE IF NOT EXISTS batch ( queue VARCHAR(255) NOT NULL, contains JSON NOT NULL, dependencies JSON NOT NULL, + score FLOAT NOT NULL DEFAULT 0, state VARCHAR(255) NOT NUll, version INT NOT NULL, PRIMARY KEY (id), diff --git a/orchestrator/controller/score/BUILD.bazel b/orchestrator/controller/score/BUILD.bazel index 43829396..d8c15942 100644 --- a/orchestrator/controller/score/BUILD.bazel +++ b/orchestrator/controller/score/BUILD.bazel @@ -10,6 +10,8 @@ go_library( "//core/errs", "//entity", "//entity/queue", + "//extension/scorer", + "//extension/storage", "@com_github_uber_go_tally_v4//:tally", "@org_uber_go_zap//:zap", ], @@ -25,6 +27,9 @@ go_test( "//entity", "//entity/queue", "//extension/queue/mock", + "//extension/scorer/mock", + "//extension/storage", + "//extension/storage/mock", "@com_github_stretchr_testify//assert", "@com_github_stretchr_testify//require", "@com_github_uber_go_tally_v4//:tally", diff --git a/orchestrator/controller/score/score.go b/orchestrator/controller/score/score.go index f6e07ed8..17b526ab 100644 --- a/orchestrator/controller/score/score.go +++ b/orchestrator/controller/score/score.go @@ -2,6 +2,7 @@ package score import ( "context" + "errors" "fmt" "github.com/uber-go/tally/v4" @@ -9,6 +10,8 @@ import ( "github.com/uber/submitqueue/core/errs" "github.com/uber/submitqueue/entity" entityqueue "github.com/uber/submitqueue/entity/queue" + "github.com/uber/submitqueue/extension/scorer" + "github.com/uber/submitqueue/extension/storage" "go.uber.org/zap" ) @@ -21,6 +24,8 @@ type Controller struct { registry consumer.TopicRegistry topicKey consumer.TopicKey consumerGroup string + scorer scorer.Scorer + store storage.Storage } // Verify Controller implements consumer.Controller interface at compile time. @@ -33,6 +38,8 @@ func NewController( registry consumer.TopicRegistry, topicKey consumer.TopicKey, consumerGroup string, + scorer scorer.Scorer, + store storage.Storage, ) *Controller { return &Controller{ logger: logger.Named("score_controller"), @@ -40,6 +47,8 @@ func NewController( registry: registry, topicKey: topicKey, consumerGroup: consumerGroup, + scorer: scorer, + store: store, } } @@ -74,12 +83,81 @@ func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) er "partition_key", msg.PartitionKey, ) - // TODO: Add scoring logic - // - Evaluate batch priority - // - Apply scoring heuristics + // Validate batch contains exactly one request + if len(batch.Contains) == 0 { + c.metricsScope.Counter("empty_batch_errors").Inc(1) + return fmt.Errorf("batch %s contains no requests", batch.ID) + } + if len(batch.Contains) > 1 { + // TODO: multi-request batches will be supported later + c.metricsScope.Counter("multi_request_batch_errors").Inc(1) + return fmt.Errorf("batch %s contains %d requests, only single-request batches are supported", batch.ID, len(batch.Contains)) + } + + // Look up the request to get its Change + request, err := c.store.GetRequestStore().Get(ctx, batch.Contains[0]) + if err != nil { + if storage.IsNotFound(err) { + c.logger.Errorw("request not found", + "batch_id", batch.ID, + "request_id", batch.Contains[0], + "error", err, + ) + c.metricsScope.Counter("request_not_found_errors").Inc(1) + return fmt.Errorf("request %s not found: %w", batch.Contains[0], err) + } + c.logger.Errorw("failed to get request", + "batch_id", batch.ID, + "request_id", batch.Contains[0], + "error", err, + ) + c.metricsScope.Counter("storage_errors").Inc(1) + return errs.NewRetryableError(fmt.Errorf("failed to get request %s: %w", batch.Contains[0], err)) + } + + // Score the change + score, err := c.scorer.Score(ctx, request.Change) + if err != nil { + c.logger.Errorw("failed to score change", + "batch_id", batch.ID, + "request_id", request.ID, + "error", err, + ) + c.metricsScope.Counter("scorer_errors").Inc(1) + return errs.NewRetryableError(fmt.Errorf("failed to score change: %w", err)) + } + + batchScore := float32(score) + + c.logger.Infow("scored batch", + "batch_id", batch.ID, + "score", batchScore, + ) + + // Update batch store with score and transition state to speculating + if err := c.store.GetBatchStore().UpdateScoreAndState(ctx, batch.ID, batch.Version, batchScore, entity.BatchStateSpeculating); err != nil { + if errors.Is(err, storage.ErrVersionMismatch) { + c.logger.Errorw("version mismatch updating batch score", + "batch_id", batch.ID, + "version", batch.Version, + "error", err, + ) + c.metricsScope.Counter("version_mismatch_errors").Inc(1) + return fmt.Errorf("version mismatch updating batch %s: %w", batch.ID, err) + } + c.logger.Errorw("failed to update batch score", + "batch_id", batch.ID, + "error", err, + ) + c.metricsScope.Counter("batch_store_errors").Inc(1) + return errs.NewRetryableError(fmt.Errorf("failed to update batch %s score: %w", batch.ID, err)) + } + + // Create new batch with updated state and version to reflect the store update + scored := batch.WithScoreAndState(batchScore, entity.BatchStateSpeculating) // Publish to speculate topic - if err := c.publish(ctx, consumer.TopicKeySpeculate, batch); err != nil { + if err := c.publish(ctx, consumer.TopicKeySpeculate, scored); err != nil { c.logger.Errorw("failed to publish output", "batch_id", batch.ID, "topic_key", consumer.TopicKeySpeculate, diff --git a/orchestrator/controller/score/score_test.go b/orchestrator/controller/score/score_test.go index e10d5006..41ba2bb7 100644 --- a/orchestrator/controller/score/score_test.go +++ b/orchestrator/controller/score/score_test.go @@ -13,12 +13,21 @@ import ( "github.com/uber/submitqueue/entity" "github.com/uber/submitqueue/entity/queue" queuemock "github.com/uber/submitqueue/extension/queue/mock" + scorermock "github.com/uber/submitqueue/extension/scorer/mock" + "github.com/uber/submitqueue/extension/storage" + storagemock "github.com/uber/submitqueue/extension/storage/mock" "go.uber.org/mock/gomock" "go.uber.org/zap/zaptest" ) // newTestController creates a controller with test dependencies. -func newTestController(t *testing.T, ctrl *gomock.Controller, publishErr error) *Controller { +func newTestController( + t *testing.T, + ctrl *gomock.Controller, + publishErr error, + mockScorer *scorermock.MockScorer, + mockStorage *storagemock.MockStorage, +) *Controller { logger := zaptest.NewLogger(t).Sugar() scope := tally.NoopScope @@ -37,12 +46,14 @@ func newTestController(t *testing.T, ctrl *gomock.Controller, publishErr error) ) require.NoError(t, err) - return NewController(logger, scope, registry, consumer.TopicKeyScore, "orchestrator-score") + return NewController(logger, scope, registry, consumer.TopicKeyScore, "orchestrator-score", mockScorer, mockStorage) } func TestNewController(t *testing.T) { ctrl := gomock.NewController(t) - controller := newTestController(t, ctrl, nil) + mockScorer := scorermock.NewMockScorer(ctrl) + mockStorage := storagemock.NewMockStorage(ctrl) + controller := newTestController(t, ctrl, nil, mockScorer, mockStorage) require.NotNil(t, controller) assert.Equal(t, consumer.TopicKeyScore, controller.TopicKey()) @@ -50,34 +61,244 @@ func TestNewController(t *testing.T) { assert.Equal(t, "score", controller.Name()) } -func TestController_Process_Success(t *testing.T) { - ctrl := gomock.NewController(t) +func TestController_Process(t *testing.T) { + tests := []struct { + name string + batch entity.Batch + setupMocks func(*scorermock.MockScorer, *storagemock.MockStorage, *storagemock.MockRequestStore, *storagemock.MockBatchStore) + publishErr error + wantErr bool + wantRetry bool + }{ + { + name: "Success", + batch: entity.Batch{ + ID: "test-queue/batch/1", + Queue: "test-queue", + Contains: []string{"test-queue/1"}, + State: entity.BatchStateCreated, + Version: 1, + }, + setupMocks: func(s *scorermock.MockScorer, st *storagemock.MockStorage, rs *storagemock.MockRequestStore, bs *storagemock.MockBatchStore) { + st.EXPECT().GetRequestStore().Return(rs) + rs.EXPECT().Get(gomock.Any(), "test-queue/1").Return(entity.Request{ + ID: "test-queue/1", + Queue: "test-queue", + Change: entity.Change{ + URIs: []string{"github://uber/submitqueue/pull/1/abc123"}, + }, + }, nil) + s.EXPECT().Score(gomock.Any(), entity.Change{ + URIs: []string{"github://uber/submitqueue/pull/1/abc123"}, + }).Return(0.85, nil) + st.EXPECT().GetBatchStore().Return(bs) + bs.EXPECT().UpdateScoreAndState(gomock.Any(), "test-queue/batch/1", int32(1), float32(0.85), entity.BatchStateSpeculating).Return(nil) + }, + wantErr: false, + }, + { + name: "EmptyBatch", + batch: entity.Batch{ + ID: "test-queue/batch/1", + Queue: "test-queue", + Contains: []string{}, + State: entity.BatchStateCreated, + Version: 1, + }, + setupMocks: func(s *scorermock.MockScorer, st *storagemock.MockStorage, rs *storagemock.MockRequestStore, bs *storagemock.MockBatchStore) { + }, + wantErr: true, + wantRetry: false, + }, + { + name: "MultiRequestBatch", + batch: entity.Batch{ + ID: "test-queue/batch/1", + Queue: "test-queue", + Contains: []string{"test-queue/1", "test-queue/2"}, + State: entity.BatchStateCreated, + Version: 1, + }, + setupMocks: func(s *scorermock.MockScorer, st *storagemock.MockStorage, rs *storagemock.MockRequestStore, bs *storagemock.MockBatchStore) { + }, + wantErr: true, + wantRetry: false, + }, + { + name: "RequestNotFound", + batch: entity.Batch{ + ID: "test-queue/batch/1", + Queue: "test-queue", + Contains: []string{"test-queue/1"}, + State: entity.BatchStateCreated, + Version: 1, + }, + setupMocks: func(s *scorermock.MockScorer, st *storagemock.MockStorage, rs *storagemock.MockRequestStore, bs *storagemock.MockBatchStore) { + st.EXPECT().GetRequestStore().Return(rs) + rs.EXPECT().Get(gomock.Any(), "test-queue/1").Return(entity.Request{}, storage.ErrNotFound) + }, + wantErr: true, + wantRetry: false, + }, + { + name: "StorageFailure", + batch: entity.Batch{ + ID: "test-queue/batch/1", + Queue: "test-queue", + Contains: []string{"test-queue/1"}, + State: entity.BatchStateCreated, + Version: 1, + }, + setupMocks: func(s *scorermock.MockScorer, st *storagemock.MockStorage, rs *storagemock.MockRequestStore, bs *storagemock.MockBatchStore) { + st.EXPECT().GetRequestStore().Return(rs) + rs.EXPECT().Get(gomock.Any(), "test-queue/1").Return(entity.Request{}, fmt.Errorf("connection refused")) + }, + wantErr: true, + wantRetry: true, + }, + { + name: "ScorerFailure", + batch: entity.Batch{ + ID: "test-queue/batch/1", + Queue: "test-queue", + Contains: []string{"test-queue/1"}, + State: entity.BatchStateCreated, + Version: 1, + }, + setupMocks: func(s *scorermock.MockScorer, st *storagemock.MockStorage, rs *storagemock.MockRequestStore, bs *storagemock.MockBatchStore) { + st.EXPECT().GetRequestStore().Return(rs) + rs.EXPECT().Get(gomock.Any(), "test-queue/1").Return(entity.Request{ + ID: "test-queue/1", + Queue: "test-queue", + Change: entity.Change{ + URIs: []string{"github://uber/submitqueue/pull/1/abc123"}, + }, + }, nil) + s.EXPECT().Score(gomock.Any(), gomock.Any()).Return(0.0, fmt.Errorf("scorer unavailable")) + }, + wantErr: true, + wantRetry: true, + }, + { + name: "BatchStoreVersionMismatch", + batch: entity.Batch{ + ID: "test-queue/batch/1", + Queue: "test-queue", + Contains: []string{"test-queue/1"}, + State: entity.BatchStateCreated, + Version: 1, + }, + setupMocks: func(s *scorermock.MockScorer, st *storagemock.MockStorage, rs *storagemock.MockRequestStore, bs *storagemock.MockBatchStore) { + st.EXPECT().GetRequestStore().Return(rs) + rs.EXPECT().Get(gomock.Any(), "test-queue/1").Return(entity.Request{ + ID: "test-queue/1", + Queue: "test-queue", + Change: entity.Change{ + URIs: []string{"github://uber/submitqueue/pull/1/abc123"}, + }, + }, nil) + s.EXPECT().Score(gomock.Any(), gomock.Any()).Return(0.9, nil) + st.EXPECT().GetBatchStore().Return(bs) + bs.EXPECT().UpdateScoreAndState(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return( + fmt.Errorf("version mismatch: %w", storage.ErrVersionMismatch), + ) + }, + wantErr: true, + wantRetry: false, + }, + { + name: "BatchStoreFailure", + batch: entity.Batch{ + ID: "test-queue/batch/1", + Queue: "test-queue", + Contains: []string{"test-queue/1"}, + State: entity.BatchStateCreated, + Version: 1, + }, + setupMocks: func(s *scorermock.MockScorer, st *storagemock.MockStorage, rs *storagemock.MockRequestStore, bs *storagemock.MockBatchStore) { + st.EXPECT().GetRequestStore().Return(rs) + rs.EXPECT().Get(gomock.Any(), "test-queue/1").Return(entity.Request{ + ID: "test-queue/1", + Queue: "test-queue", + Change: entity.Change{ + URIs: []string{"github://uber/submitqueue/pull/1/abc123"}, + }, + }, nil) + s.EXPECT().Score(gomock.Any(), gomock.Any()).Return(0.9, nil) + st.EXPECT().GetBatchStore().Return(bs) + bs.EXPECT().UpdateScoreAndState(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return( + fmt.Errorf("database unavailable"), + ) + }, + wantErr: true, + wantRetry: true, + }, + { + name: "PublishFailure", + batch: entity.Batch{ + ID: "test-queue/batch/1", + Queue: "test-queue", + Contains: []string{"test-queue/1"}, + State: entity.BatchStateCreated, + Version: 1, + }, + setupMocks: func(s *scorermock.MockScorer, st *storagemock.MockStorage, rs *storagemock.MockRequestStore, bs *storagemock.MockBatchStore) { + st.EXPECT().GetRequestStore().Return(rs) + rs.EXPECT().Get(gomock.Any(), "test-queue/1").Return(entity.Request{ + ID: "test-queue/1", + Queue: "test-queue", + Change: entity.Change{ + URIs: []string{"github://uber/submitqueue/pull/1/abc123"}, + }, + }, nil) + s.EXPECT().Score(gomock.Any(), gomock.Any()).Return(0.9, nil) + st.EXPECT().GetBatchStore().Return(bs) + bs.EXPECT().UpdateScoreAndState(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + }, + publishErr: fmt.Errorf("publish failed"), + wantErr: true, + wantRetry: true, + }, + } - controller := newTestController(t, ctrl, nil) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + mockScorer := scorermock.NewMockScorer(ctrl) + mockStorageFactory := storagemock.NewMockStorage(ctrl) + mockRequestStore := storagemock.NewMockRequestStore(ctrl) + mockBatchStore := storagemock.NewMockBatchStore(ctrl) - batch := entity.Batch{ - ID: "test-queue/batch/1", - Queue: "test-queue", - State: entity.BatchStateCreated, - Version: 1, - } + tt.setupMocks(mockScorer, mockStorageFactory, mockRequestStore, mockBatchStore) - payload, err := batch.ToBytes() - require.NoError(t, err) + controller := newTestController(t, ctrl, tt.publishErr, mockScorer, mockStorageFactory) - msg := queue.NewMessage("test-queue/batch/1", payload, "test-queue", nil) - delivery := queuemock.NewMockDelivery(ctrl) - delivery.EXPECT().Message().Return(msg).AnyTimes() - delivery.EXPECT().Attempt().Return(1).AnyTimes() + payload, err := tt.batch.ToBytes() + require.NoError(t, err) - err = controller.Process(context.Background(), delivery) - require.NoError(t, err) + msg := queue.NewMessage(tt.batch.ID, payload, tt.batch.Queue, nil) + delivery := queuemock.NewMockDelivery(ctrl) + delivery.EXPECT().Message().Return(msg).AnyTimes() + delivery.EXPECT().Attempt().Return(1).AnyTimes() + + err = controller.Process(context.Background(), delivery) + + if tt.wantErr { + require.Error(t, err) + assert.Equal(t, tt.wantRetry, errs.IsRetryable(err)) + } else { + require.NoError(t, err) + } + }) + } } func TestController_Process_InvalidJSON(t *testing.T) { ctrl := gomock.NewController(t) + mockScorer := scorermock.NewMockScorer(ctrl) + mockStorage := storagemock.NewMockStorage(ctrl) - controller := newTestController(t, ctrl, nil) + controller := newTestController(t, ctrl, nil, mockScorer, mockStorage) invalidPayload := []byte(`{"invalid": json"}`) msg := queue.NewMessage("invalid-msg", invalidPayload, "partition1", nil) @@ -91,33 +312,11 @@ func TestController_Process_InvalidJSON(t *testing.T) { assert.False(t, errs.IsRetryable(err)) } -func TestController_Process_PublishFailure(t *testing.T) { - ctrl := gomock.NewController(t) - - controller := newTestController(t, ctrl, fmt.Errorf("publish failed")) - - batch := entity.Batch{ - ID: "test-queue/batch/1", - Queue: "test-queue", - State: entity.BatchStateCreated, - Version: 1, - } - - payload, err := batch.ToBytes() - require.NoError(t, err) - - msg := queue.NewMessage(batch.ID, payload, batch.Queue, nil) - delivery := queuemock.NewMockDelivery(ctrl) - delivery.EXPECT().Message().Return(msg).AnyTimes() - delivery.EXPECT().Attempt().Return(1).AnyTimes() - - err = controller.Process(context.Background(), delivery) - assert.Error(t, err) -} - func TestController_InterfaceImplementation(t *testing.T) { ctrl := gomock.NewController(t) - controller := newTestController(t, ctrl, nil) + mockScorer := scorermock.NewMockScorer(ctrl) + mockStorage := storagemock.NewMockStorage(ctrl) + controller := newTestController(t, ctrl, nil, mockScorer, mockStorage) var _ consumer.Controller = controller } From f5a9fcbbcd8164601bb0da5621a9e064b88cb7f7 Mon Sep 17 00:00:00 2001 From: manjari Date: Thu, 5 Mar 2026 01:32:04 +0000 Subject: [PATCH 2/3] use core/metrics --- orchestrator/controller/score/BUILD.bazel | 1 + orchestrator/controller/score/score.go | 26 +++++++++++------------ 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/orchestrator/controller/score/BUILD.bazel b/orchestrator/controller/score/BUILD.bazel index d8c15942..8d7b7a11 100644 --- a/orchestrator/controller/score/BUILD.bazel +++ b/orchestrator/controller/score/BUILD.bazel @@ -8,6 +8,7 @@ go_library( deps = [ "//core/consumer", "//core/errs", + "//core/metrics", "//entity", "//entity/queue", "//extension/scorer", diff --git a/orchestrator/controller/score/score.go b/orchestrator/controller/score/score.go index 17b526ab..53def911 100644 --- a/orchestrator/controller/score/score.go +++ b/orchestrator/controller/score/score.go @@ -8,6 +8,7 @@ import ( "github.com/uber-go/tally/v4" "github.com/uber/submitqueue/core/consumer" "github.com/uber/submitqueue/core/errs" + "github.com/uber/submitqueue/core/metrics" "github.com/uber/submitqueue/entity" entityqueue "github.com/uber/submitqueue/entity/queue" "github.com/uber/submitqueue/extension/scorer" @@ -55,8 +56,9 @@ func NewController( // Process processes a score delivery from the queue. // Deserializes the batch, scores it, and publishes to the speculate topic. // Returns nil to ack (success), or error to nack (retry). -func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) error { - c.metricsScope.Counter("received").Inc(1) +func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) (retErr error) { + op := metrics.Begin(c.metricsScope, "process") + defer func() { op.Complete(retErr) }() msg := delivery.Message() @@ -69,7 +71,7 @@ func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) er "attempt", delivery.Attempt(), "error", err, ) - c.metricsScope.Counter("deserialize_errors").Inc(1) + metrics.NamedCounter(c.metricsScope, "process", "deserialize_errors", 1) // Non-retryable: malformed messages will never succeed regardless of retry count return fmt.Errorf("failed to deserialize batch: %w", err) } @@ -85,12 +87,12 @@ func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) er // Validate batch contains exactly one request if len(batch.Contains) == 0 { - c.metricsScope.Counter("empty_batch_errors").Inc(1) + metrics.NamedCounter(c.metricsScope, "process", "empty_batch_errors", 1) return fmt.Errorf("batch %s contains no requests", batch.ID) } if len(batch.Contains) > 1 { // TODO: multi-request batches will be supported later - c.metricsScope.Counter("multi_request_batch_errors").Inc(1) + metrics.NamedCounter(c.metricsScope, "process", "multi_request_batch_errors", 1) return fmt.Errorf("batch %s contains %d requests, only single-request batches are supported", batch.ID, len(batch.Contains)) } @@ -103,7 +105,7 @@ func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) er "request_id", batch.Contains[0], "error", err, ) - c.metricsScope.Counter("request_not_found_errors").Inc(1) + metrics.NamedCounter(c.metricsScope, "process", "request_not_found_errors", 1) return fmt.Errorf("request %s not found: %w", batch.Contains[0], err) } c.logger.Errorw("failed to get request", @@ -111,7 +113,7 @@ func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) er "request_id", batch.Contains[0], "error", err, ) - c.metricsScope.Counter("storage_errors").Inc(1) + metrics.NamedCounter(c.metricsScope, "process", "storage_errors", 1) return errs.NewRetryableError(fmt.Errorf("failed to get request %s: %w", batch.Contains[0], err)) } @@ -123,7 +125,7 @@ func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) er "request_id", request.ID, "error", err, ) - c.metricsScope.Counter("scorer_errors").Inc(1) + metrics.NamedCounter(c.metricsScope, "process", "scorer_errors", 1) return errs.NewRetryableError(fmt.Errorf("failed to score change: %w", err)) } @@ -142,14 +144,14 @@ func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) er "version", batch.Version, "error", err, ) - c.metricsScope.Counter("version_mismatch_errors").Inc(1) + metrics.NamedCounter(c.metricsScope, "process", "version_mismatch_errors", 1) return fmt.Errorf("version mismatch updating batch %s: %w", batch.ID, err) } c.logger.Errorw("failed to update batch score", "batch_id", batch.ID, "error", err, ) - c.metricsScope.Counter("batch_store_errors").Inc(1) + metrics.NamedCounter(c.metricsScope, "process", "batch_store_errors", 1) return errs.NewRetryableError(fmt.Errorf("failed to update batch %s score: %w", batch.ID, err)) } @@ -163,7 +165,7 @@ func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) er "topic_key", consumer.TopicKeySpeculate, "error", err, ) - c.metricsScope.Counter("publish_errors").Inc(1) + metrics.NamedCounter(c.metricsScope, "process", "publish_errors", 1) return errs.NewRetryableError(fmt.Errorf("failed to publish to speculate: %w", err)) } @@ -172,8 +174,6 @@ func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) er "topic_key", consumer.TopicKeySpeculate, ) - c.metricsScope.Counter("processed").Inc(1) - return nil // Success - message will be acked } From 33fcaab074abe6aa56a315bd0cb06be47345dd20 Mon Sep 17 00:00:00 2001 From: manjari Date: Fri, 6 Mar 2026 20:54:04 +0000 Subject: [PATCH 3/3] address comments --- entity/batch.go | 2 ++ example/server/orchestrator/main.go | 3 +++ extension/storage/mysql/schema/batch.sql | 2 +- 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/entity/batch.go b/entity/batch.go index 128e6b6d..c4ea20ae 100644 --- a/entity/batch.go +++ b/entity/batch.go @@ -72,6 +72,8 @@ func (b Batch) ToBytes() ([]byte, error) { } // WithScoreAndState returns a new Batch with the given score and state, incrementing the version. +// The version is only incremented after a successful write to the DB to reflect the update to the +// batch. func (b Batch) WithScoreAndState(score float32, state BatchState) Batch { b.Score = score b.State = state diff --git a/example/server/orchestrator/main.go b/example/server/orchestrator/main.go index aa26f9b1..014e7190 100644 --- a/example/server/orchestrator/main.go +++ b/example/server/orchestrator/main.go @@ -450,6 +450,9 @@ func newMergeChecker(logger *zap.Logger, scope tally.Scope) mergechecker.MergeCh }) } +// newScorer returns a temporary scorer that assigns a success probability of +// 1.0 to every change, regardless of its attributes. +// TODO: Replace with a better heuristic as the pipeline evolves. func newScorer(scope tally.Scope) scorer.Scorer { return heuristicscorer.New( []heuristicscorer.Bucket{ diff --git a/extension/storage/mysql/schema/batch.sql b/extension/storage/mysql/schema/batch.sql index c79c046e..90bbeb44 100644 --- a/extension/storage/mysql/schema/batch.sql +++ b/extension/storage/mysql/schema/batch.sql @@ -3,7 +3,7 @@ CREATE TABLE IF NOT EXISTS batch ( queue VARCHAR(255) NOT NULL, contains JSON NOT NULL, dependencies JSON NOT NULL, - score FLOAT NOT NULL DEFAULT 0, + score FLOAT NOT NULL, state VARCHAR(255) NOT NUll, version INT NOT NULL, PRIMARY KEY (id),