Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions entity/batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ type Batch struct {
// - queueA/batch/3 will contain queueA/batch/1
//
Dependencies []map[string]interface{}
// Score is the probability of a successful land for the batch, between 0.0 and 1.0.
Score float32 `json:"score"`
// The state of the batch lifecycle this batch is in.
State BatchState
// Version is the version of the object. It is used for optimistic locking.
Expand All @@ -69,6 +71,16 @@ func (b Batch) ToBytes() ([]byte, error) {
return json.Marshal(b)
}

// WithScoreAndState returns a new Batch with the given score and state, incrementing the version.
// The version is only incremented after a successful write to the DB to reflect the update to the
// batch.
func (b Batch) WithScoreAndState(score float32, state BatchState) Batch {
b.Score = score
b.State = state
b.Version++
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why need to increment the version? It should only be incremented after a successful atomic persistence (i.e. a write to the database).

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now I see where it is used (after the update), the comment could reflect that.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a comment.

Copy link
Copy Markdown
Collaborator

@behinddwalls behinddwalls Mar 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i am wondering if this should be in entity at all?

to me it seems like we probably need this to be part of contract at storage layer itself which accepts, entity, oldVersion, newVresion... I am not fan of implicit things as they hide away and lead to confusion/issues down the line..

Thoughts?

return b
}

// BatchFromBytes deserializes a Batch from JSON bytes.
func BatchFromBytes(data []byte) (Batch, error) {
var batch Batch
Expand Down
3 changes: 3 additions & 0 deletions example/server/orchestrator/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@ go_library(
visibility = ["//visibility:private"],
deps = [
"//core/consumer",
"//entity",
"//extension/counter",
"//extension/counter/mysql",
"//extension/mergechecker",
"//extension/mergechecker/github",
"//extension/queue",
"//extension/queue/mysql",
"//extension/scorer",
"//extension/scorer/heuristic",
"//extension/storage",
"//extension/storage/mysql",
"//orchestrator/controller",
Expand Down
27 changes: 25 additions & 2 deletions example/server/orchestrator/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@ import (
_ "github.com/go-sql-driver/mysql"
"github.com/uber-go/tally/v4"
"github.com/uber/submitqueue/core/consumer"
"github.com/uber/submitqueue/entity"
"github.com/uber/submitqueue/extension/counter"
mysqlcounter "github.com/uber/submitqueue/extension/counter/mysql"
"github.com/uber/submitqueue/extension/mergechecker"
githubchecker "github.com/uber/submitqueue/extension/mergechecker/github"
extqueue "github.com/uber/submitqueue/extension/queue"
queueMySQL "github.com/uber/submitqueue/extension/queue/mysql"
"github.com/uber/submitqueue/extension/scorer"
heuristicscorer "github.com/uber/submitqueue/extension/scorer/heuristic"
mysqlstorage "github.com/uber/submitqueue/extension/storage/mysql"
"github.com/uber/submitqueue/extension/storage"
"github.com/uber/submitqueue/orchestrator/controller"
Expand Down Expand Up @@ -161,8 +164,11 @@ func run() error {
// Create merge checker
mc := newMergeChecker(logger, scope)

// Create scorer
sc := newScorer(scope)

// Register controllers
if err := registerControllers(c, logger.Sugar(), scope, registry, mc, cnt, store); err != nil {
if err := registerControllers(c, logger.Sugar(), scope, registry, mc, cnt, store, sc); err != nil {
return err
}

Expand Down Expand Up @@ -310,7 +316,7 @@ func newTopicRegistry(q extqueue.Queue, subscriberName string) (consumer.TopicRe
// │ │ │
// └────────┴────────────────────────┘

func registerControllers(c consumer.Consumer, logger *zap.SugaredLogger, scope tally.Scope, registry consumer.TopicRegistry, mc mergechecker.MergeChecker, cnt counter.Counter, store storage.Storage) error {
func registerControllers(c consumer.Consumer, logger *zap.SugaredLogger, scope tally.Scope, registry consumer.TopicRegistry, mc mergechecker.MergeChecker, cnt counter.Counter, store storage.Storage, sc scorer.Scorer) error {
requestController := request.NewController(
logger,
scope,
Expand Down Expand Up @@ -354,6 +360,8 @@ func registerControllers(c consumer.Consumer, logger *zap.SugaredLogger, scope t
registry,
consumer.TopicKeyScore,
"orchestrator-score",
sc,
store,
)
if err := c.Register(scoreController); err != nil {
return fmt.Errorf("failed to register score controller: %w", err)
Expand Down Expand Up @@ -442,6 +450,21 @@ func newMergeChecker(logger *zap.Logger, scope tally.Scope) mergechecker.MergeCh
})
}

// newScorer returns a temporary scorer that assigns a success probability of
// 1.0 to every change, regardless of its attributes.
// TODO: Replace with a better heuristic as the pipeline evolves.
func newScorer(scope tally.Scope) scorer.Scorer {
return heuristicscorer.New(
[]heuristicscorer.Bucket{
{Min: 0, Max: 100, Score: 1.0},
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add "TODO" for the real one?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

},
func(_ context.Context, _ entity.Change) (int, error) {
return 0, nil
},
scope.SubScope("scorer"),
)
}

// bearerTransport is an http.RoundTripper that adds a Bearer token to requests.
type bearerTransport struct {
token string
Expand Down
5 changes: 5 additions & 0 deletions extension/storage/batch_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ type BatchStore interface {
// The implementation should increment the version by 1 atomically with the state update.
UpdateState(ctx context.Context, id string, version int32, newState entity.BatchState) error

// UpdateScoreAndState updates the score and state of a batch if the current version matches the expected version.
// If versions do not match, returns ErrVersionMismatch.
// The implementation should increment the version by 1 atomically with the score and state update.
UpdateScoreAndState(ctx context.Context, id string, version int32, score float32, newState entity.BatchState) error

// GetByQueueAndStates retrieves all batches that belong to the given queue and are in the given states.
GetByQueueAndStates(ctx context.Context, queue string, states []entity.BatchState) ([]entity.Batch, error)
}
48 changes: 42 additions & 6 deletions extension/storage/mysql/batch_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ func (s *batchStore) Get(ctx context.Context, id string) (ret entity.Batch, retE
var dependenciesJSON []byte

err := s.db.QueryRowContext(ctx,
"SELECT id, queue, contains, dependencies, state, version FROM batch WHERE id = ?",
"SELECT id, queue, contains, dependencies, score, state, version FROM batch WHERE id = ?",
id,
).Scan(&batch.ID, &batch.Queue, &containsJSON, &dependenciesJSON, &batch.State, &batch.Version)
).Scan(&batch.ID, &batch.Queue, &containsJSON, &dependenciesJSON, &batch.Score, &batch.State, &batch.Version)

if errors.Is(err, sql.ErrNoRows) {
return entity.Batch{}, storage.WrapNotFound(err)
Expand Down Expand Up @@ -74,8 +74,8 @@ func (s *batchStore) Create(ctx context.Context, batch entity.Batch) (retErr err
}

_, err = s.db.ExecContext(ctx,
"INSERT INTO batch (id, queue, contains, dependencies, state, version) VALUES (?, ?, ?, ?, ?, ?)",
batch.ID, batch.Queue, containsJSON, dependenciesJSON, batch.State, batch.Version,
"INSERT INTO batch (id, queue, contains, dependencies, score, state, version) VALUES (?, ?, ?, ?, ?, ?, ?)",
batch.ID, batch.Queue, containsJSON, dependenciesJSON, batch.Score, batch.State, batch.Version,
)
if err != nil {
var mysqlErr *mysql.MySQLError
Expand Down Expand Up @@ -123,6 +123,42 @@ func (s *batchStore) UpdateState(ctx context.Context, id string, version int32,
return nil
}

// UpdateScoreAndState updates the score and state of a batch if the current version matches the expected version.
// If versions do not match, returns ErrVersionMismatch.
// The implementation increments the version by 1 atomically with the score and state update.
func (s *batchStore) UpdateScoreAndState(ctx context.Context, id string, version int32, score float32, newState entity.BatchState) (retErr error) {
op := metrics.Begin(s.scope, "update_score")
defer func() { op.Complete(retErr) }()

result, err := s.db.ExecContext(ctx,
"UPDATE batch SET score = ?, state = ?, version = version + 1 WHERE id = ? AND version = ?",
score, newState, id, version,
)
if err != nil {
return fmt.Errorf(
"failed to update batch score for id=%q version=%d score=%v newState=%v: %w",
id, version, score, newState, err,
)
}

rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf(
"failed to get rows affected from update for id=%q version=%d score=%v newState=%v: %w",
id, version, score, newState, err,
)
}

if rowsAffected != 1 {
return fmt.Errorf(
"version mismatch for batch update: id=%q expected_version=%d score=%v newState=%v: %w",
id, version, score, newState, storage.ErrVersionMismatch,
)
}

return nil
}

// GetByQueueAndStates retrieves all batches that belong to the given queue and are in the given states.
func (s *batchStore) GetByQueueAndStates(ctx context.Context, queue string, states []entity.BatchState) (ret []entity.Batch, retErr error) {
op := metrics.Begin(s.scope, "get_by_queue_and_states")
Expand All @@ -132,7 +168,7 @@ func (s *batchStore) GetByQueueAndStates(ctx context.Context, queue string, stat
return nil, nil
}

query := "SELECT id, queue, contains, dependencies, state, version FROM batch WHERE queue = ? AND state IN (?" + strings.Repeat(", ?", len(states)-1) + ")"
query := "SELECT id, queue, contains, dependencies, score, state, version FROM batch WHERE queue = ? AND state IN (?" + strings.Repeat(", ?", len(states)-1) + ")"

args := make([]any, 1+len(states))
args[0] = queue
Expand All @@ -152,7 +188,7 @@ func (s *batchStore) GetByQueueAndStates(ctx context.Context, queue string, stat
var containsJSON []byte
var dependenciesJSON []byte

if err := rows.Scan(&batch.ID, &batch.Queue, &containsJSON, &dependenciesJSON, &batch.State, &batch.Version); err != nil {
if err := rows.Scan(&batch.ID, &batch.Queue, &containsJSON, &dependenciesJSON, &batch.Score, &batch.State, &batch.Version); err != nil {
return nil, fmt.Errorf("failed to scan batch entity by queue=%q states=%v from the database: %w", queue, states, err)
}

Expand Down
1 change: 1 addition & 0 deletions extension/storage/mysql/schema/batch.sql
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ CREATE TABLE IF NOT EXISTS batch (
queue VARCHAR(255) NOT NULL,
contains JSON NOT NULL,
dependencies JSON NOT NULL,
score FLOAT NOT NULL,
state VARCHAR(255) NOT NUll,
version INT NOT NULL,
PRIMARY KEY (id),
Expand Down
6 changes: 6 additions & 0 deletions orchestrator/controller/score/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@ go_library(
deps = [
"//core/consumer",
"//core/errs",
"//core/metrics",
"//entity",
"//entity/queue",
"//extension/scorer",
"//extension/storage",
"@com_github_uber_go_tally_v4//:tally",
"@org_uber_go_zap//:zap",
],
Expand All @@ -25,6 +28,9 @@ go_test(
"//entity",
"//entity/queue",
"//extension/queue/mock",
"//extension/scorer/mock",
"//extension/storage",
"//extension/storage/mock",
"@com_github_stretchr_testify//assert",
"@com_github_stretchr_testify//require",
"@com_github_uber_go_tally_v4//:tally",
Expand Down
Loading