From 9beb4e8b47f768387a0a4cf1b3122918bc80b521 Mon Sep 17 00:00:00 2001 From: machichima Date: Thu, 9 Apr 2026 14:34:27 +0800 Subject: [PATCH 1/6] feat: support filtering in list runs Signed-off-by: machichima --- runs/repository/impl/action.go | 87 ++++++++++++++++++++++---------- runs/repository/impl/filters.go | 56 ++++++++++++++++++-- runs/repository/models/action.go | 21 ++++++++ 3 files changed, 132 insertions(+), 32 deletions(-) diff --git a/runs/repository/impl/action.go b/runs/repository/impl/action.go index 5253a8f4d3..bd76bfeded 100644 --- a/runs/repository/impl/action.go +++ b/runs/repository/impl/action.go @@ -8,7 +8,6 @@ import ( "errors" "fmt" "net" - "strconv" "strings" "sync" "time" @@ -216,45 +215,79 @@ func (r *actionRepo) GetRun(ctx context.Context, runID *common.RunIdentifier) (* return &run, nil } -// ListRuns lists runs with pagination +// ListRuns lists runs with pagination and optional filters. func (r *actionRepo) ListRuns(ctx context.Context, req *workflow.ListRunsRequest) ([]*models.Run, string, error) { + // Build scope filter: always restrict to root actions (runs). + scopeFilter := interfaces.Filter(NewIsRootActionFilter()) + + switch scope := req.ScopeBy.(type) { + case *workflow.ListRunsRequest_ProjectId: + scopeFilter = scopeFilter.And(NewProjectIdFilter(scope.ProjectId)) + case *workflow.ListRunsRequest_TaskName: + scopeFilter = scopeFilter.And(NewRunTaskNameFilter(scope.TaskName)) + case *workflow.ListRunsRequest_TaskId: + scopeFilter = scopeFilter.And(NewRunTaskIdFilter(scope.TaskId)) + } + + // Parse pagination, sort, and user-supplied filters from the common ListRequest. + listInput, err := NewListResourceInputFromProto(req.Request, models.ActionColumnsSet) + if err != nil { + return nil, "", fmt.Errorf("invalid list request: %w", err) + } + + // Merge: scope filter always takes precedence, user filters are ANDed on top. + if listInput.Filter == nil { + listInput.Filter = scopeFilter + } else { + listInput.Filter = scopeFilter.And(listInput.Filter) + } + + return r.listRunsByInput(ctx, listInput) +} + +// listRunsByInput executes the query described by listInput against the actions table. +func (r *actionRepo) listRunsByInput(ctx context.Context, listInput interfaces.ListResourceInput) ([]*models.Run, string, error) { var queryBuilder strings.Builder var args []interface{} - argIdx := 1 - queryBuilder.WriteString("SELECT * FROM actions WHERE parent_action_name IS NULL") + queryBuilder.WriteString("SELECT * FROM actions") - // Apply scope filters - switch scope := req.ScopeBy.(type) { - case *workflow.ListRunsRequest_ProjectId: - queryBuilder.WriteString(fmt.Sprintf(" AND project = $%d AND domain = $%d", argIdx, argIdx+1)) - args = append(args, scope.ProjectId.Name, scope.ProjectId.Domain) - argIdx += 2 + // Apply filters — QueryExpression emits ? placeholders; Rebind converts to $N for PostgreSQL. + if listInput.Filter != nil { + expr, err := listInput.Filter.QueryExpression("") + if err != nil { + return nil, "", fmt.Errorf("failed to build filter expression: %w", err) + } + queryBuilder.WriteString(" WHERE ") + queryBuilder.WriteString(expr.Query) + args = append(args, expr.Args...) } - // Apply pagination according to token and limit from requests. - limit := 50 - offset := 0 - if req.Request != nil { - if req.Request.Token != "" { - parsedOffset, err := strconv.Atoi(req.Request.Token) - if err != nil { - return nil, "", fmt.Errorf("invalid pagination token: %w", err) + // Apply sort parameters; default to created_at DESC. + if len(listInput.SortParameters) > 0 { + queryBuilder.WriteString(" ORDER BY ") + for i, sp := range listInput.SortParameters { + if i > 0 { + queryBuilder.WriteString(", ") } - offset = parsedOffset + queryBuilder.WriteString(sp.GetOrderExpr()) } + } else { + queryBuilder.WriteString(" ORDER BY created_at DESC") + } - if req.Request.Limit > 0 { - limit = int(req.Request.Limit) - } + // Append LIMIT/OFFSET as ? placeholders, then rebind the entire query at once. + limit := listInput.Limit + if limit <= 0 { + limit = 50 } + queryBuilder.WriteString(" LIMIT ? OFFSET ?") + args = append(args, limit+1, listInput.Offset) // fetch one extra to detect next page - queryBuilder.WriteString(fmt.Sprintf(" ORDER BY created_at DESC LIMIT $%d OFFSET $%d", argIdx, argIdx+1)) - args = append(args, limit+1, offset) // Fetch one extra to determine if there are more - argIdx += 2 + query := sqlx.Rebind(sqlx.DOLLAR, queryBuilder.String()) var runs []*models.Run - if err := sqlx.SelectContext(ctx, r.db, &runs, queryBuilder.String(), args...); err != nil { + if err := sqlx.SelectContext(ctx, r.db, &runs, query, args...); err != nil { return nil, "", fmt.Errorf("failed to list runs: %w", err) } @@ -262,7 +295,7 @@ func (r *actionRepo) ListRuns(ctx context.Context, req *workflow.ListRunsRequest var nextToken string if len(runs) > limit { runs = runs[:limit] - nextToken = fmt.Sprintf("%d", offset+limit) + nextToken = fmt.Sprintf("%d", listInput.Offset+limit) } return runs, nextToken, nil diff --git a/runs/repository/impl/filters.go b/runs/repository/impl/filters.go index 7ccd253a29..901548202e 100644 --- a/runs/repository/impl/filters.go +++ b/runs/repository/impl/filters.go @@ -11,6 +11,11 @@ import ( "github.com/flyteorg/flyte/v2/runs/repository/interfaces" ) +// NewIsRootActionFilter creates a filter for root actions (runs) only. +func NewIsRootActionFilter() interfaces.Filter { + return &nullFilter{field: "parent_action_name", isNull: true} +} + // basicFilter implements the Filter interface for simple field comparisons type basicFilter struct { field string @@ -79,6 +84,34 @@ func (f *basicFilter) Or(filter interfaces.Filter) interfaces.Filter { } } +// nullFilter implements the Filter interface for IS NULL / IS NOT NULL checks +type nullFilter struct { + field string + isNull bool +} + +func (f *nullFilter) QueryExpression(table string) (interfaces.QueryExpr, error) { + column := f.field + if table != "" { + column = table + "." + f.field + } + op := "IS NULL" + if !f.isNull { + op = "IS NOT NULL" + } + return interfaces.QueryExpr{ + Query: fmt.Sprintf("%s %s", column, op), + }, nil +} + +func (f *nullFilter) And(filter interfaces.Filter) interfaces.Filter { + return &compositeFilter{left: f, right: filter, operator: "AND"} +} + +func (f *nullFilter) Or(filter interfaces.Filter) interfaces.Filter { + return &compositeFilter{left: f, right: filter, operator: "OR"} +} + // compositeFilter implements the Filter interface for AND/OR operations type compositeFilter struct { left interfaces.Filter @@ -150,13 +183,26 @@ func NewProjectIdFilter(projectId *common.ProjectIdentifier) interfaces.Filter { return projectFilter.And(domainFilter) } -// NewTaskNameFilter creates a filter for task name (project, domain, name) +// NewTaskNameFilter creates a filter for task name on the tasks table (project, domain, name). func NewTaskNameFilter(taskName *task.TaskName) interfaces.Filter { - projectFilter := NewEqualFilter("project", taskName.GetProject()) - domainFilter := NewEqualFilter("domain", taskName.GetDomain()) - nameFilter := NewEqualFilter("name", taskName.GetName()) + return NewEqualFilter("project", taskName.GetProject()). + And(NewEqualFilter("domain", taskName.GetDomain())). + And(NewEqualFilter("name", taskName.GetName())) +} + +// NewRunTaskNameFilter creates a filter matching runs by task name columns on the actions table. +func NewRunTaskNameFilter(taskName *task.TaskName) interfaces.Filter { + return NewEqualFilter("task_project", taskName.GetProject()). + And(NewEqualFilter("task_domain", taskName.GetDomain())). + And(NewEqualFilter("task_name", taskName.GetName())) +} - return projectFilter.And(domainFilter).And(nameFilter) +// NewRunTaskIdFilter creates a filter matching runs by full task identifier on the actions table. +func NewRunTaskIdFilter(taskId *task.TaskIdentifier) interfaces.Filter { + return NewEqualFilter("task_project", taskId.GetProject()). + And(NewEqualFilter("task_domain", taskId.GetDomain())). + And(NewEqualFilter("task_name", taskId.GetName())). + And(NewEqualFilter("task_version", taskId.GetVersion())) } // NewDeployedByFilter creates a filter for deployed_by = value diff --git a/runs/repository/models/action.go b/runs/repository/models/action.go index de79a0d471..133d7e6f6c 100644 --- a/runs/repository/models/action.go +++ b/runs/repository/models/action.go @@ -4,10 +4,31 @@ import ( "database/sql" "time" + "k8s.io/apimachinery/pkg/util/sets" + "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/common" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core" ) +// ActionColumnsSet is the allowlist of columns that can be used in filters/sort for the actions table. +// This prevents SQL injection via user-supplied field names. +var ActionColumnsSet = sets.New( + "project", + "domain", + "run_name", + "phase", + "run_source", + "task_project", + "task_domain", + "task_name", + "task_version", + "function_name", + "created_at", + "updated_at", + "ended_at", + "duration_ms", +) + // Action represents a workflow action in the database // Root actions (where ParentActionName is NULL) represent runs type Action struct { From 720f178a38147874b8c6dce7d650c362b7fefe2a Mon Sep 17 00:00:00 2001 From: machichima Date: Fri, 10 Apr 2026 16:27:20 +0800 Subject: [PATCH 2/6] fix: merge repo impl ListRuns and ListActions and pass ListResourceInput as args Signed-off-by: machichima --- runs/repository/impl/action.go | 144 ++++++--------------------- runs/repository/impl/action_test.go | 89 ++++++----------- runs/repository/impl/filters.go | 15 +++ runs/repository/interfaces/action.go | 4 +- runs/repository/mocks/mocks.go | 133 ++++--------------------- runs/repository/models/action.go | 4 + runs/service/run_service.go | 129 +++++++++++++----------- 7 files changed, 173 insertions(+), 345 deletions(-) diff --git a/runs/repository/impl/action.go b/runs/repository/impl/action.go index db1ee5f897..9e99fd43a1 100644 --- a/runs/repository/impl/action.go +++ b/runs/repository/impl/action.go @@ -19,7 +19,6 @@ import ( "github.com/flyteorg/flyte/v2/flytestdlib/logger" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/common" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core" - "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow" "github.com/flyteorg/flyte/v2/runs/repository/interfaces" "github.com/flyteorg/flyte/v2/runs/repository/models" ) @@ -83,92 +82,6 @@ func (r *actionRepo) GetRun(ctx context.Context, runID *common.RunIdentifier) (* return &run, nil } -// ListRuns lists runs with pagination and optional filters. -func (r *actionRepo) ListRuns(ctx context.Context, req *workflow.ListRunsRequest) ([]*models.Run, string, error) { - // Build scope filter: always restrict to root actions (runs). - scopeFilter := interfaces.Filter(NewIsRootActionFilter()) - - switch scope := req.ScopeBy.(type) { - case *workflow.ListRunsRequest_ProjectId: - scopeFilter = scopeFilter.And(NewProjectIdFilter(scope.ProjectId)) - case *workflow.ListRunsRequest_TaskName: - scopeFilter = scopeFilter.And(NewRunTaskNameFilter(scope.TaskName)) - case *workflow.ListRunsRequest_TaskId: - scopeFilter = scopeFilter.And(NewRunTaskIdFilter(scope.TaskId)) - } - - // Parse pagination, sort, and user-supplied filters from the common ListRequest. - listInput, err := NewListResourceInputFromProto(req.Request, models.ActionColumnsSet) - if err != nil { - return nil, "", fmt.Errorf("invalid list request: %w", err) - } - - // Merge: scope filter always takes precedence, user filters are ANDed on top. - if listInput.Filter == nil { - listInput.Filter = scopeFilter - } else { - listInput.Filter = scopeFilter.And(listInput.Filter) - } - - return r.listRunsByInput(ctx, listInput) -} - -// listRunsByInput executes the query described by listInput against the actions table. -func (r *actionRepo) listRunsByInput(ctx context.Context, listInput interfaces.ListResourceInput) ([]*models.Run, string, error) { - var queryBuilder strings.Builder - var args []interface{} - - queryBuilder.WriteString("SELECT * FROM actions") - - // Apply filters — QueryExpression emits ? placeholders; Rebind converts to $N for PostgreSQL. - if listInput.Filter != nil { - expr, err := listInput.Filter.QueryExpression("") - if err != nil { - return nil, "", fmt.Errorf("failed to build filter expression: %w", err) - } - queryBuilder.WriteString(" WHERE ") - queryBuilder.WriteString(expr.Query) - args = append(args, expr.Args...) - } - - // Apply sort parameters; default to created_at DESC. - if len(listInput.SortParameters) > 0 { - queryBuilder.WriteString(" ORDER BY ") - for i, sp := range listInput.SortParameters { - if i > 0 { - queryBuilder.WriteString(", ") - } - queryBuilder.WriteString(sp.GetOrderExpr()) - } - } else { - queryBuilder.WriteString(" ORDER BY created_at DESC") - } - - // Append LIMIT/OFFSET as ? placeholders, then rebind the entire query at once. - limit := listInput.Limit - if limit <= 0 { - limit = 50 - } - queryBuilder.WriteString(" LIMIT ? OFFSET ?") - args = append(args, limit+1, listInput.Offset) // fetch one extra to detect next page - - query := sqlx.Rebind(sqlx.DOLLAR, queryBuilder.String()) - - var runs []*models.Run - if err := sqlx.SelectContext(ctx, r.db, &runs, query, args...); err != nil { - return nil, "", fmt.Errorf("failed to list runs: %w", err) - } - - // Determine next token - var nextToken string - if len(runs) > limit { - runs = runs[:limit] - nextToken = fmt.Sprintf("%d", listInput.Offset+limit) - } - - return runs, nextToken, nil -} - // AbortRun aborts a run and all its actions func (r *actionRepo) AbortRun(ctx context.Context, runID *common.RunIdentifier, reason string, abortedBy *common.EnrichedIdentity) error { now := time.Now() @@ -368,45 +281,46 @@ func (r *actionRepo) GetAction(ctx context.Context, actionID *common.ActionIdent return &action, nil } -// ListActions lists actions for a run -func (r *actionRepo) ListActions(ctx context.Context, runID *common.RunIdentifier, limit int, token string) ([]*models.Action, string, error) { - if limit == 0 { - limit = 100 - } - +// ListActions lists actions matching the given input filter, sort, and pagination. +func (r *actionRepo) ListActions(ctx context.Context, input interfaces.ListResourceInput) ([]*models.Action, error) { var queryBuilder strings.Builder var args []interface{} - argIdx := 1 - queryBuilder.WriteString(fmt.Sprintf("SELECT * FROM actions WHERE project = $%d AND domain = $%d AND run_name = $%d", argIdx, argIdx+1, argIdx+2)) - args = append(args, runID.Project, runID.Domain, runID.Name) - argIdx += 3 + queryBuilder.WriteString("SELECT * FROM actions") - // Apply pagination token (encoded as RFC3339Nano created_at cursor) - if token != "" { - if t, err := time.Parse(time.RFC3339Nano, token); err == nil { - queryBuilder.WriteString(fmt.Sprintf(" AND created_at > $%d", argIdx)) - args = append(args, t) - argIdx++ + if input.Filter != nil { + expr, err := input.Filter.QueryExpression("") + if err != nil { + return nil, fmt.Errorf("failed to build filter expression: %w", err) } + queryBuilder.WriteString(" WHERE ") + queryBuilder.WriteString(expr.Query) + args = append(args, expr.Args...) } - queryBuilder.WriteString(fmt.Sprintf(" ORDER BY created_at ASC LIMIT $%d", argIdx)) - args = append(args, limit+1) - - var actions []*models.Action - if err := sqlx.SelectContext(ctx, r.db, &actions, queryBuilder.String(), args...); err != nil { - return nil, "", fmt.Errorf("failed to list actions: %w", err) + if len(input.SortParameters) > 0 { + queryBuilder.WriteString(" ORDER BY ") + for i, sp := range input.SortParameters { + if i > 0 { + queryBuilder.WriteString(", ") + } + queryBuilder.WriteString(sp.GetOrderExpr()) + } + } else { + queryBuilder.WriteString(" ORDER BY created_at ASC") } - // Determine next token - var nextToken string - if len(actions) > limit { - actions = actions[:limit] - nextToken = actions[len(actions)-1].CreatedAt.UTC().Format(time.RFC3339Nano) + queryBuilder.WriteString(" LIMIT ? OFFSET ?") + args = append(args, input.Limit, input.Offset) + + query := sqlx.Rebind(sqlx.DOLLAR, queryBuilder.String()) + + var actions []*models.Action + if err := sqlx.SelectContext(ctx, r.db, &actions, query, args...); err != nil { + return nil, fmt.Errorf("failed to list actions: %w", err) } - return actions, nextToken, nil + return actions, nil } // UpdateActionPhase updates the phase of an action. diff --git a/runs/repository/impl/action_test.go b/runs/repository/impl/action_test.go index 5fc6a491bd..a22bfae7ba 100644 --- a/runs/repository/impl/action_test.go +++ b/runs/repository/impl/action_test.go @@ -14,6 +14,7 @@ import ( "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/common" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow" + "github.com/flyteorg/flyte/v2/runs/repository/interfaces" "github.com/flyteorg/flyte/v2/runs/repository/models" "github.com/jmoiron/sqlx" "github.com/lib/pq" @@ -311,67 +312,40 @@ func TestListRuns(t *testing.T) { require.NoError(t, err) } - // Table-driven list tests for ListRuns - type listTestCase struct { - name string - req *workflow.ListRunsRequest - expectLen int - expectTokenNil bool - verify func(t *testing.T, runs []*models.Run) - } - - listTests := []listTestCase{ - { - name: "List by org should return 3 runs", - req: &workflow.ListRunsRequest{ScopeBy: &workflow.ListRunsRequest_Org{Org: "org1"}}, - expectLen: 3, - expectTokenNil: true, - verify: func(t *testing.T, runs []*models.Run) { - runNames := map[string]bool{} - for _, r := range runs { - runNames[r.RunName] = true - } - assert.True(t, runNames["run-1"]) - assert.True(t, runNames["run-2"]) - assert.True(t, runNames["run-3"]) - }, - }, - } - - for _, tt := range listTests { - t.Run(tt.name, func(t *testing.T) { - runs, nextToken, err := actionRepo.ListRuns(ctx, tt.req) - require.NoError(t, err) - assert.Len(t, runs, tt.expectLen) - if tt.expectTokenNil { - assert.Empty(t, nextToken) - } else { - assert.NotEmpty(t, nextToken) - } - if tt.verify != nil { - tt.verify(t, runs) - } - }) + // List all runs (root actions only) + runs, err := actionRepo.ListActions(ctx, interfaces.ListResourceInput{ + Filter: NewIsRootActionFilter(), + Limit: 50, + }) + require.NoError(t, err) + assert.Len(t, runs, 3) + runNames := map[string]bool{} + for _, r := range runs { + runNames[r.RunName] = true } - - // Pagination with limit and token results - runsPage1, token1, err := actionRepo.ListRuns(ctx, &workflow.ListRunsRequest{ - Request: &common.ListRequest{Limit: 2}, - ScopeBy: &workflow.ListRunsRequest_Org{Org: "org1"}, + assert.True(t, runNames["run-1"]) + assert.True(t, runNames["run-2"]) + assert.True(t, runNames["run-3"]) + + // Pagination: page 1 + runsPage1, err := actionRepo.ListActions(ctx, interfaces.ListResourceInput{ + Filter: NewIsRootActionFilter(), + Limit: 2, + Offset: 0, }) require.NoError(t, err) assert.Len(t, runsPage1, 2) - require.NotEmpty(t, token1) - runsPage2, token2, err := actionRepo.ListRuns(ctx, &workflow.ListRunsRequest{ - Request: &common.ListRequest{Token: token1, Limit: 2}, - ScopeBy: &workflow.ListRunsRequest_Org{Org: "org1"}, + // Pagination: page 2 + runsPage2, err := actionRepo.ListActions(ctx, interfaces.ListResourceInput{ + Filter: NewIsRootActionFilter(), + Limit: 2, + Offset: 2, }) require.NoError(t, err) assert.Len(t, runsPage2, 1) - assert.Empty(t, token2) - // Test project scope filtering doesn't include other org/project/domain + // Test project scope filtering doesn't include other project _, err = actionRepo.CreateAction(ctx, &models.Run{ Project: "other-proj", Domain: "domain1", @@ -381,12 +355,11 @@ func TestListRuns(t *testing.T) { }, false) require.NoError(t, err) - runsFiltered, _, err := actionRepo.ListRuns(ctx, &workflow.ListRunsRequest{ - ScopeBy: &workflow.ListRunsRequest_ProjectId{ProjectId: &common.ProjectIdentifier{ - Organization: "org1", - Name: "proj1", - Domain: "domain1", - }}, + runsFiltered, err := actionRepo.ListActions(ctx, interfaces.ListResourceInput{ + Filter: NewIsRootActionFilter(). + And(NewEqualFilter("project", "proj1")). + And(NewEqualFilter("domain", "domain1")), + Limit: 50, }) require.NoError(t, err) assert.Len(t, runsFiltered, 3) diff --git a/runs/repository/impl/filters.go b/runs/repository/impl/filters.go index 901548202e..e4c996f4ca 100644 --- a/runs/repository/impl/filters.go +++ b/runs/repository/impl/filters.go @@ -16,6 +16,13 @@ func NewIsRootActionFilter() interfaces.Filter { return &nullFilter{field: "parent_action_name", isNull: true} } +// NewRunActionsFilter creates a filter for all actions belonging to a specific run. +func NewRunActionsFilter(runID *common.RunIdentifier) interfaces.Filter { + return NewEqualFilter("project", runID.GetProject()). + And(NewEqualFilter("domain", runID.GetDomain())). + And(NewEqualFilter("run_name", runID.GetName())) +} + // basicFilter implements the Filter interface for simple field comparisons type basicFilter struct { field string @@ -205,6 +212,14 @@ func NewRunTaskIdFilter(taskId *task.TaskIdentifier) interfaces.Filter { And(NewEqualFilter("task_version", taskId.GetVersion())) } +// NewTriggerNameFilter creates a filter matching runs by trigger_name on the actions table. +func NewTriggerNameFilter(triggerName *common.TriggerName) interfaces.Filter { + return NewEqualFilter("project", triggerName.GetProject()). + And(NewEqualFilter("domain", triggerName.GetDomain())). + And(NewEqualFilter("trigger_task_name", triggerName.GetTaskName())). + And(NewEqualFilter("trigger_name", triggerName.GetName())) +} + // NewDeployedByFilter creates a filter for deployed_by = value func NewDeployedByFilter(deployedBy string) interfaces.Filter { return NewEqualFilter("deployed_by", deployedBy) diff --git a/runs/repository/interfaces/action.go b/runs/repository/interfaces/action.go index c5fd776cb9..d8f9859d89 100644 --- a/runs/repository/interfaces/action.go +++ b/runs/repository/interfaces/action.go @@ -6,7 +6,6 @@ import ( "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/common" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core" - "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow" "github.com/flyteorg/flyte/v2/runs/repository/models" ) @@ -14,7 +13,6 @@ import ( type ActionRepo interface { // Run operations GetRun(ctx context.Context, runID *common.RunIdentifier) (*models.Run, error) - ListRuns(ctx context.Context, req *workflow.ListRunsRequest) ([]*models.Run, string, error) AbortRun(ctx context.Context, runID *common.RunIdentifier, reason string, abortedBy *common.EnrichedIdentity) error // Action operations @@ -24,7 +22,7 @@ type ActionRepo interface { ListEventsSince(ctx context.Context, actionID *common.ActionIdentifier, attempt uint32, since time.Time, offset, limit int) ([]*models.ActionEvent, error) GetLatestEventByAttempt(ctx context.Context, actionID *common.ActionIdentifier, attempt uint32) (*models.ActionEvent, error) GetAction(ctx context.Context, actionID *common.ActionIdentifier) (*models.Action, error) - ListActions(ctx context.Context, runID *common.RunIdentifier, limit int, token string) ([]*models.Action, string, error) + ListActions(ctx context.Context, input ListResourceInput) ([]*models.Action, error) UpdateActionPhase(ctx context.Context, actionID *common.ActionIdentifier, phase common.ActionPhase, attempts uint32, cacheStatus core.CatalogCacheStatus, endTime *time.Time) error AbortAction(ctx context.Context, actionID *common.ActionIdentifier, reason string, abortedBy *common.EnrichedIdentity) error diff --git a/runs/repository/mocks/mocks.go b/runs/repository/mocks/mocks.go index 06bc646609..3b3b8d3af7 100644 --- a/runs/repository/mocks/mocks.go +++ b/runs/repository/mocks/mocks.go @@ -10,7 +10,6 @@ import ( "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/common" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core" - "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow" "github.com/flyteorg/flyte/v2/runs/repository/interfaces" "github.com/flyteorg/flyte/v2/runs/repository/models" mock "github.com/stretchr/testify/mock" @@ -646,37 +645,31 @@ func (_c *ActionRepo_InsertEvents_Call) RunAndReturn(run func(ctx context.Contex } // ListActions provides a mock function for the type ActionRepo -func (_mock *ActionRepo) ListActions(ctx context.Context, runID *common.RunIdentifier, limit int, token string) ([]*models.Action, string, error) { - ret := _mock.Called(ctx, runID, limit, token) +func (_mock *ActionRepo) ListActions(ctx context.Context, input interfaces.ListResourceInput) ([]*models.Action, error) { + ret := _mock.Called(ctx, input) if len(ret) == 0 { panic("no return value specified for ListActions") } var r0 []*models.Action - var r1 string - var r2 error - if returnFunc, ok := ret.Get(0).(func(context.Context, *common.RunIdentifier, int, string) ([]*models.Action, string, error)); ok { - return returnFunc(ctx, runID, limit, token) + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, interfaces.ListResourceInput) ([]*models.Action, error)); ok { + return returnFunc(ctx, input) } - if returnFunc, ok := ret.Get(0).(func(context.Context, *common.RunIdentifier, int, string) []*models.Action); ok { - r0 = returnFunc(ctx, runID, limit, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, interfaces.ListResourceInput) []*models.Action); ok { + r0 = returnFunc(ctx, input) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Action) } } - if returnFunc, ok := ret.Get(1).(func(context.Context, *common.RunIdentifier, int, string) string); ok { - r1 = returnFunc(ctx, runID, limit, token) - } else { - r1 = ret.Get(1).(string) - } - if returnFunc, ok := ret.Get(2).(func(context.Context, *common.RunIdentifier, int, string) error); ok { - r2 = returnFunc(ctx, runID, limit, token) + if returnFunc, ok := ret.Get(1).(func(context.Context, interfaces.ListResourceInput) error); ok { + r1 = returnFunc(ctx, input) } else { - r2 = ret.Error(2) + r1 = ret.Error(1) } - return r0, r1, r2 + return r0, r1 } // ActionRepo_ListActions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListActions' @@ -686,47 +679,35 @@ type ActionRepo_ListActions_Call struct { // ListActions is a helper method to define mock.On call // - ctx context.Context -// - runID *common.RunIdentifier -// - limit int -// - token string -func (_e *ActionRepo_Expecter) ListActions(ctx interface{}, runID interface{}, limit interface{}, token interface{}) *ActionRepo_ListActions_Call { - return &ActionRepo_ListActions_Call{Call: _e.mock.On("ListActions", ctx, runID, limit, token)} +// - input interfaces.ListResourceInput +func (_e *ActionRepo_Expecter) ListActions(ctx interface{}, input interface{}) *ActionRepo_ListActions_Call { + return &ActionRepo_ListActions_Call{Call: _e.mock.On("ListActions", ctx, input)} } -func (_c *ActionRepo_ListActions_Call) Run(run func(ctx context.Context, runID *common.RunIdentifier, limit int, token string)) *ActionRepo_ListActions_Call { +func (_c *ActionRepo_ListActions_Call) Run(run func(ctx context.Context, input interfaces.ListResourceInput)) *ActionRepo_ListActions_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { arg0 = args[0].(context.Context) } - var arg1 *common.RunIdentifier + var arg1 interfaces.ListResourceInput if args[1] != nil { - arg1 = args[1].(*common.RunIdentifier) - } - var arg2 int - if args[2] != nil { - arg2 = args[2].(int) - } - var arg3 string - if args[3] != nil { - arg3 = args[3].(string) + arg1 = args[1].(interfaces.ListResourceInput) } run( arg0, arg1, - arg2, - arg3, ) }) return _c } -func (_c *ActionRepo_ListActions_Call) Return(actions []*models.Action, s string, err error) *ActionRepo_ListActions_Call { - _c.Call.Return(actions, s, err) +func (_c *ActionRepo_ListActions_Call) Return(actions []*models.Action, err error) *ActionRepo_ListActions_Call { + _c.Call.Return(actions, err) return _c } -func (_c *ActionRepo_ListActions_Call) RunAndReturn(run func(ctx context.Context, runID *common.RunIdentifier, limit int, token string) ([]*models.Action, string, error)) *ActionRepo_ListActions_Call { +func (_c *ActionRepo_ListActions_Call) RunAndReturn(run func(ctx context.Context, input interfaces.ListResourceInput) ([]*models.Action, error)) *ActionRepo_ListActions_Call { _c.Call.Return(run) return _c } @@ -1051,80 +1032,6 @@ func (_c *ActionRepo_ListRootActions_Call) RunAndReturn(run func(ctx context.Con return _c } -// ListRuns provides a mock function for the type ActionRepo -func (_mock *ActionRepo) ListRuns(ctx context.Context, req *workflow.ListRunsRequest) ([]*models.Run, string, error) { - ret := _mock.Called(ctx, req) - - if len(ret) == 0 { - panic("no return value specified for ListRuns") - } - - var r0 []*models.Run - var r1 string - var r2 error - if returnFunc, ok := ret.Get(0).(func(context.Context, *workflow.ListRunsRequest) ([]*models.Run, string, error)); ok { - return returnFunc(ctx, req) - } - if returnFunc, ok := ret.Get(0).(func(context.Context, *workflow.ListRunsRequest) []*models.Run); ok { - r0 = returnFunc(ctx, req) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]*models.Run) - } - } - if returnFunc, ok := ret.Get(1).(func(context.Context, *workflow.ListRunsRequest) string); ok { - r1 = returnFunc(ctx, req) - } else { - r1 = ret.Get(1).(string) - } - if returnFunc, ok := ret.Get(2).(func(context.Context, *workflow.ListRunsRequest) error); ok { - r2 = returnFunc(ctx, req) - } else { - r2 = ret.Error(2) - } - return r0, r1, r2 -} - -// ActionRepo_ListRuns_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListRuns' -type ActionRepo_ListRuns_Call struct { - *mock.Call -} - -// ListRuns is a helper method to define mock.On call -// - ctx context.Context -// - req *workflow.ListRunsRequest -func (_e *ActionRepo_Expecter) ListRuns(ctx interface{}, req interface{}) *ActionRepo_ListRuns_Call { - return &ActionRepo_ListRuns_Call{Call: _e.mock.On("ListRuns", ctx, req)} -} - -func (_c *ActionRepo_ListRuns_Call) Run(run func(ctx context.Context, req *workflow.ListRunsRequest)) *ActionRepo_ListRuns_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - var arg1 *workflow.ListRunsRequest - if args[1] != nil { - arg1 = args[1].(*workflow.ListRunsRequest) - } - run( - arg0, - arg1, - ) - }) - return _c -} - -func (_c *ActionRepo_ListRuns_Call) Return(vs []*models.Run, s string, err error) *ActionRepo_ListRuns_Call { - _c.Call.Return(vs, s, err) - return _c -} - -func (_c *ActionRepo_ListRuns_Call) RunAndReturn(run func(ctx context.Context, req *workflow.ListRunsRequest) ([]*models.Run, string, error)) *ActionRepo_ListRuns_Call { - _c.Call.Return(run) - return _c -} - // MarkAbortAttempt provides a mock function for the type ActionRepo func (_mock *ActionRepo) MarkAbortAttempt(ctx context.Context, actionID *common.ActionIdentifier) (int, error) { ret := _mock.Called(ctx, actionID) diff --git a/runs/repository/models/action.go b/runs/repository/models/action.go index a57282aedb..9fce2d117d 100644 --- a/runs/repository/models/action.go +++ b/runs/repository/models/action.go @@ -18,6 +18,10 @@ var ActionColumnsSet = sets.New( "run_name", "phase", "run_source", + "trigger_task_name", + "trigger_name", + "trigger_revision", + "trigger_type", "task_project", "task_domain", "task_name", diff --git a/runs/service/run_service.go b/runs/service/run_service.go index 7ac39a4f3b..df94cc17a1 100644 --- a/runs/service/run_service.go +++ b/runs/service/run_service.go @@ -27,6 +27,7 @@ import ( "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/task" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow/workflowconnect" + "github.com/flyteorg/flyte/v2/runs/repository/impl" "github.com/flyteorg/flyte/v2/runs/repository/interfaces" "github.com/flyteorg/flyte/v2/runs/repository/models" "github.com/flyteorg/flyte/v2/runs/repository/transformers" @@ -842,26 +843,53 @@ func (s *RunService) ListRuns( return nil, connect.NewError(connect.CodeInvalidArgument, err) } - // List runs from database - runs, nextToken, err := s.repo.ActionRepo().ListRuns(ctx, req.Msg) + // Build scope filter: always restrict to root actions (runs). + scopeFilter := interfaces.Filter(impl.NewIsRootActionFilter()) + switch scope := req.Msg.ScopeBy.(type) { + case *workflow.ListRunsRequest_ProjectId: + scopeFilter = scopeFilter.And(impl.NewProjectIdFilter(scope.ProjectId)) + case *workflow.ListRunsRequest_TriggerName: + scopeFilter = scopeFilter.And(impl.NewTriggerNameFilter(scope.TriggerName)) + case *workflow.ListRunsRequest_TaskName: + scopeFilter = scopeFilter.And(impl.NewRunTaskNameFilter(scope.TaskName)) + case *workflow.ListRunsRequest_TaskId: + scopeFilter = scopeFilter.And(impl.NewRunTaskIdFilter(scope.TaskId)) + } + + // Parse pagination, sort, and user-supplied filters from the common ListRequest. + listInput, err := impl.NewListResourceInputFromProto(req.Msg.Request, models.ActionColumnsSet) + if err != nil { + return nil, connect.NewError(connect.CodeInvalidArgument, err) + } + + // Merge: scope filter always takes precedence, user filters are ANDed on top. + if listInput.Filter == nil { + listInput.Filter = scopeFilter + } else { + listInput.Filter = scopeFilter.And(listInput.Filter) + } + + actions, err := s.repo.ActionRepo().ListActions(ctx, listInput) if err != nil { logger.Errorf(ctx, "Failed to list runs: %v", err) return nil, connect.NewError(connect.CodeInternal, err) } - // Convert to proto format - protoRuns := make([]*workflow.Run, len(runs)) - for i, run := range runs { + var nextToken string + if len(actions) > 0 && len(actions) == listInput.Limit { + nextToken = fmt.Sprintf("%d", listInput.Offset+listInput.Limit) + } + + protoRuns := make([]*workflow.Run, len(actions)) + for i, run := range actions { protoRuns[i] = s.convertRunToProto(run) } - resp := &workflow.ListRunsResponse{ + logger.Debugf(ctx, "Listed %d runs", len(actions)) + return connect.NewResponse(&workflow.ListRunsResponse{ Runs: protoRuns, Token: nextToken, - } - - logger.Debugf(ctx, "Listed %d runs", len(runs)) - return connect.NewResponse(resp), nil + }), nil } // ListActions lists actions for a run @@ -871,37 +899,38 @@ func (s *RunService) ListActions( ) (*connect.Response[workflow.ListActionsResponse], error) { logger.Infof(ctx, "Received ListActions request") - // Validate request if err := req.Msg.Validate(); err != nil { return nil, connect.NewError(connect.CodeInvalidArgument, err) } - // List actions from database - limit := 50 - if req.Msg.Request != nil && req.Msg.Request.Limit > 0 { - limit = int(req.Msg.Request.Limit) + listInput, err := impl.NewListResourceInputFromProto(req.Msg.Request, models.ActionColumnsSet) + if err != nil { + return nil, connect.NewError(connect.CodeInvalidArgument, err) } + listInput.Filter = impl.NewRunActionsFilter(req.Msg.RunId) - actions, nextToken, err := s.repo.ActionRepo().ListActions(ctx, req.Msg.RunId, limit, "") + actions, err := s.repo.ActionRepo().ListActions(ctx, listInput) if err != nil { logger.Errorf(ctx, "Failed to list actions: %v", err) return nil, connect.NewError(connect.CodeInternal, err) } - // Convert to proto format + var nextToken string + if len(actions) > 0 && len(actions) == listInput.Limit { + nextToken = fmt.Sprintf("%d", listInput.Offset+listInput.Limit) + } + protoActions := make([]*workflow.Action, len(actions)) for i, action := range actions { ea := s.convertActionToEnrichedProto(action) protoActions[i] = ea.Action } - resp := &workflow.ListActionsResponse{ + logger.Infof(ctx, "Listed %d actions", len(actions)) + return connect.NewResponse(&workflow.ListActionsResponse{ Actions: protoActions, Token: nextToken, - } - - logger.Infof(ctx, "Listed %d actions", len(actions)) - return connect.NewResponse(resp), nil + }), nil } // AbortAction aborts a specific action @@ -1065,9 +1094,9 @@ func (s *RunService) WatchRuns( go s.repo.ActionRepo().WatchAllRunUpdates(ctx, updatesCh, errsCh) // Step 2: Send existing runs that match filter - listReq := s.convertWatchRequestToListRequest(req.Msg) + listInput := s.convertWatchRequestToListInput(req.Msg) - runs, _, err := s.repo.ActionRepo().ListRuns(ctx, listReq) + runs, err := s.repo.ActionRepo().ListActions(ctx, listInput) if err != nil { logger.Errorf(ctx, "Failed to list runs: %v", err) // Continue even if list fails - still watch for new updates @@ -1159,27 +1188,20 @@ func (s *RunService) listAndSendAllActions( rsm *runStateManager, stream *connect.ServerStream[workflow.WatchActionsResponse], ) error { - token := "" - for { - batch, nextToken, err := s.repo.ActionRepo().ListActions(ctx, runID, 100, token) - if err != nil { - return err - } - - updates, err := rsm.upsertActions(ctx, batch) - if err != nil { - return err - } - - if err := s.sendChangedActions(runID, updates, stream); err != nil { - return err - } + batch, err := s.repo.ActionRepo().ListActions(ctx, interfaces.ListResourceInput{ + Filter: impl.NewRunActionsFilter(runID), + Limit: 10000, + }) + if err != nil { + return err + } - if nextToken == "" { - return nil - } - token = nextToken + updates, err := rsm.upsertActions(ctx, batch) + if err != nil { + return err } + + return s.sendChangedActions(runID, updates, stream) } func (s *RunService) sendChangedActions( @@ -1937,22 +1959,17 @@ func taskIdFromTaskSpec(taskSpec *task.TaskSpec) *task.TaskIdentifier { } } -// convertWatchRequestToListRequest converts a WatchRunsRequest to a ListRunsRequest -func (s *RunService) convertWatchRequestToListRequest(req *workflow.WatchRunsRequest) *workflow.ListRunsRequest { - listReq := &workflow.ListRunsRequest{ - Request: &common.ListRequest{ - Limit: 100, - }, - } - +// convertWatchRequestToListInput converts a WatchRunsRequest to a ListResourceInput for the initial run snapshot. +func (s *RunService) convertWatchRequestToListInput(req *workflow.WatchRunsRequest) interfaces.ListResourceInput { + scopeFilter := interfaces.Filter(impl.NewIsRootActionFilter()) switch target := req.Target.(type) { case *workflow.WatchRunsRequest_ProjectId: - listReq.ScopeBy = &workflow.ListRunsRequest_ProjectId{ - ProjectId: target.ProjectId, - } + scopeFilter = scopeFilter.And(impl.NewProjectIdFilter(target.ProjectId)) + } + return interfaces.ListResourceInput{ + Limit: 100, + Filter: scopeFilter, } - - return listReq } // runMatchesFilter checks if a run matches the WatchRunsRequest filter criteria From c6dae13ab338b972d9c83c47823126ee000400f3 Mon Sep 17 00:00:00 2001 From: machichima Date: Mon, 20 Apr 2026 17:46:01 +0800 Subject: [PATCH 3/6] feat: pagination for listAndSendAllActions Signed-off-by: machichima --- runs/service/run_service.go | 36 ++++++++++++++++++++++++------------ 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/runs/service/run_service.go b/runs/service/run_service.go index fd1483978a..7c760a1bf3 100644 --- a/runs/service/run_service.go +++ b/runs/service/run_service.go @@ -1278,20 +1278,32 @@ func (s *RunService) listAndSendAllActions( rsm *runStateManager, stream *connect.ServerStream[workflow.WatchActionsResponse], ) error { - batch, err := s.repo.ActionRepo().ListActions(ctx, interfaces.ListResourceInput{ - Filter: impl.NewRunActionsFilter(runID), - Limit: 10000, - }) - if err != nil { - return err - } + const pageSize = 100 + offset := 0 + for { + batch, err := s.repo.ActionRepo().ListActions(ctx, interfaces.ListResourceInput{ + Filter: impl.NewRunActionsFilter(runID), + Limit: pageSize, + Offset: offset, + }) + if err != nil { + return err + } - updates, err := rsm.upsertActions(ctx, batch) - if err != nil { - return err - } + updates, err := rsm.upsertActions(ctx, batch) + if err != nil { + return err + } + + if err := s.sendChangedActions(runID, updates, stream); err != nil { + return err + } - return s.sendChangedActions(runID, updates, stream) + if len(batch) < pageSize { + return nil + } + offset += pageSize + } } func (s *RunService) sendChangedActions( From 9c87857a97cbade52e0e1fac91e1eefe25f6a37a Mon Sep 17 00:00:00 2001 From: machichima Date: Mon, 20 Apr 2026 17:51:00 +0800 Subject: [PATCH 4/6] fix: merge filter instead of override Signed-off-by: machichima --- runs/service/run_service.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/runs/service/run_service.go b/runs/service/run_service.go index 7c760a1bf3..78e0c89b8e 100644 --- a/runs/service/run_service.go +++ b/runs/service/run_service.go @@ -935,7 +935,12 @@ func (s *RunService) ListActions( if err != nil { return nil, connect.NewError(connect.CodeInvalidArgument, err) } - listInput.Filter = impl.NewRunActionsFilter(req.Msg.RunId) + runFilter := interfaces.Filter(impl.NewRunActionsFilter(req.Msg.RunId)) + if listInput.Filter == nil { + listInput.Filter = runFilter + } else { + listInput.Filter = runFilter.And(listInput.Filter) + } actions, err := s.repo.ActionRepo().ListActions(ctx, listInput) if err != nil { From 7523c5026e6c4aa6792b61c03d840ac967315d3a Mon Sep 17 00:00:00 2001 From: machichima Date: Mon, 20 Apr 2026 18:10:38 +0800 Subject: [PATCH 5/6] test: fix failed tests Signed-off-by: machichima --- runs/service/run_service_test.go | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/runs/service/run_service_test.go b/runs/service/run_service_test.go index 624b8b15a9..cb36c34842 100644 --- a/runs/service/run_service_test.go +++ b/runs/service/run_service_test.go @@ -772,9 +772,8 @@ func TestListRuns(t *testing.T) { }) } type mockListRes struct { - runs []*models.Run - token string - err error + runs []*models.Run + err error } testCases := []struct { name string @@ -791,20 +790,20 @@ func TestListRuns(t *testing.T) { { "list with limit 2 and token", &common.ListRequest{Limit: 2, Token: "5"}, - mockListRes{runs: sqlRes[5:7], token: "7", err: nil}, + mockListRes{runs: sqlRes[5:7], err: nil}, &workflow.ListRunsResponse{Runs: runs[5:7], Token: "7"}, }, { "list with limit 3 and token", &common.ListRequest{Limit: 3, Token: "8"}, - mockListRes{runs: sqlRes[8:10], token: "", err: nil}, + mockListRes{runs: sqlRes[8:10], err: nil}, &workflow.ListRunsResponse{Runs: runs[8:10], Token: ""}, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { req := connect.NewRequest(&workflow.ListRunsRequest{Request: tc.req}) - actionRepo.On("ListRuns", mock.Anything, req.Msg).Return(tc.mockRes.runs, tc.mockRes.token, tc.mockRes.err) + actionRepo.On("ListActions", mock.Anything, mock.Anything).Return(tc.mockRes.runs, tc.mockRes.err).Once() got, err := svc.ListRuns(context.Background(), req) assert.NoError(t, err) assert.Equal(t, len(tc.expect.Runs), len(got.Msg.Runs)) From e7b168c13938657fb38a711ad16fede48bc1d26c Mon Sep 17 00:00:00 2001 From: machichima Date: Wed, 22 Apr 2026 17:24:38 +0800 Subject: [PATCH 6/6] feat: add cursor based token Signed-off-by: machichima --- runs/repository/impl/action.go | 17 +++++++++++++++-- runs/repository/impl/requests.go | 14 +------------- runs/repository/interfaces/common.go | 9 ++++++++- runs/service/run_service.go | 18 ++++++++++++++---- 4 files changed, 38 insertions(+), 20 deletions(-) diff --git a/runs/repository/impl/action.go b/runs/repository/impl/action.go index 99711a8ab2..9824f1ff88 100644 --- a/runs/repository/impl/action.go +++ b/runs/repository/impl/action.go @@ -315,6 +315,19 @@ func (r *actionRepo) ListActions(ctx context.Context, input interfaces.ListResou args = append(args, expr.Args...) } + if input.CursorToken != "" { + if t, err := time.Parse(time.RFC3339Nano, input.CursorToken); err == nil { + // If a filter was already applied above, the WHERE clause is already open + // and we extend it with AND. Otherwise we open a new WHERE clause. + if input.Filter != nil { + queryBuilder.WriteString(" AND created_at > ?") + } else { + queryBuilder.WriteString(" WHERE created_at > ?") + } + args = append(args, t) + } + } + if len(input.SortParameters) > 0 { queryBuilder.WriteString(" ORDER BY ") for i, sp := range input.SortParameters { @@ -327,8 +340,8 @@ func (r *actionRepo) ListActions(ctx context.Context, input interfaces.ListResou queryBuilder.WriteString(" ORDER BY created_at ASC") } - queryBuilder.WriteString(" LIMIT ? OFFSET ?") - args = append(args, input.Limit, input.Offset) + queryBuilder.WriteString(" LIMIT ?") + args = append(args, input.Limit+1) query := sqlx.Rebind(sqlx.DOLLAR, queryBuilder.String()) diff --git a/runs/repository/impl/requests.go b/runs/repository/impl/requests.go index 04111d761e..c0c5aa1102 100644 --- a/runs/repository/impl/requests.go +++ b/runs/repository/impl/requests.go @@ -32,18 +32,6 @@ func NewListResourceInputFromProto(request *common.ListRequest, allowedColumns s return interfaces.ListResourceInput{}, fmt.Errorf("invalid limit: %d (exceeds maximum of %d)", limit, maxLimit) } - // Parse token as offset - var offset int - if request.Token != "" { - _, err := fmt.Sscanf(request.Token, "%d", &offset) - if err != nil { - return interfaces.ListResourceInput{}, fmt.Errorf("invalid token format: %s", request.Token) - } - if offset < 0 { - return interfaces.ListResourceInput{}, fmt.Errorf("invalid offset: %d (must be non-negative)", offset) - } - } - sortParameters, err := GetSortByFieldsV2(request, allowedColumns) if err != nil { return interfaces.ListResourceInput{}, err @@ -57,7 +45,7 @@ func NewListResourceInputFromProto(request *common.ListRequest, allowedColumns s return interfaces.ListResourceInput{ Limit: limit, Filter: combinedFilter, - Offset: offset, + CursorToken: request.Token, SortParameters: sortParameters, }, nil } diff --git a/runs/repository/interfaces/common.go b/runs/repository/interfaces/common.go index b2f6f2fc9e..ae1d146c89 100644 --- a/runs/repository/interfaces/common.go +++ b/runs/repository/interfaces/common.go @@ -2,7 +2,14 @@ package interfaces // ListResourceInput contains parameters for querying collections of resources. type ListResourceInput struct { - Limit int + Limit int + + // CursorToken is a keyset pagination cursor encoded as a RFC3339Nano timestamp. + // When set, the query returns rows with created_at strictly greater than the cursor value. + // Mutually exclusive with Offset. + CursorToken string + + // Offset is an integer offset for offset-based pagination. Offset int Filter Filter diff --git a/runs/service/run_service.go b/runs/service/run_service.go index 78e0c89b8e..cf4bf029f9 100644 --- a/runs/service/run_service.go +++ b/runs/service/run_service.go @@ -903,9 +903,14 @@ func (s *RunService) ListRuns( return nil, connect.NewError(connect.CodeInternal, err) } + // We fetch Limit+1 rows to detect whether a next page exists without a + // separate COUNT query. If we got more than Limit rows back, there is at + // least one more page: trim the slice and encode the last returned row's + // created_at as the keyset cursor for the next request. var nextToken string - if len(actions) > 0 && len(actions) == listInput.Limit { - nextToken = fmt.Sprintf("%d", listInput.Offset+listInput.Limit) + if len(actions) > listInput.Limit { + actions = actions[:listInput.Limit] + nextToken = actions[len(actions)-1].CreatedAt.UTC().Format(time.RFC3339Nano) } protoRuns := make([]*workflow.Run, len(actions)) @@ -948,9 +953,14 @@ func (s *RunService) ListActions( return nil, connect.NewError(connect.CodeInternal, err) } + // We fetch Limit+1 rows to detect whether a next page exists without a + // separate COUNT query. If we got more than Limit rows back, there is at + // least one more page: trim the slice and encode the last returned row's + // created_at as the keyset cursor for the next request. var nextToken string - if len(actions) > 0 && len(actions) == listInput.Limit { - nextToken = fmt.Sprintf("%d", listInput.Offset+listInput.Limit) + if len(actions) > listInput.Limit { + actions = actions[:listInput.Limit] + nextToken = actions[len(actions)-1].CreatedAt.UTC().Format(time.RFC3339Nano) } protoActions := make([]*workflow.Action, len(actions))