Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 29 additions & 82 deletions runs/repository/impl/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"errors"
"fmt"
"net"
"strconv"
"strings"
"sync"
"time"
Expand All @@ -20,7 +19,6 @@ import (
"github.com/flyteorg/flyte/v2/flytestdlib/logger"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/common"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow"
"github.com/flyteorg/flyte/v2/runs/repository/interfaces"
"github.com/flyteorg/flyte/v2/runs/repository/models"
)
Expand Down Expand Up @@ -84,58 +82,6 @@ func (r *actionRepo) GetRun(ctx context.Context, runID *common.RunIdentifier) (*
return &run, nil
}

// ListRuns lists runs with pagination
func (r *actionRepo) ListRuns(ctx context.Context, req *workflow.ListRunsRequest) ([]*models.Run, string, error) {
var queryBuilder strings.Builder
var args []interface{}
argIdx := 1

queryBuilder.WriteString("SELECT * FROM actions WHERE parent_action_name IS NULL")

// Apply scope filters
switch scope := req.ScopeBy.(type) {
case *workflow.ListRunsRequest_ProjectId:
queryBuilder.WriteString(fmt.Sprintf(" AND project = $%d AND domain = $%d", argIdx, argIdx+1))
args = append(args, scope.ProjectId.Name, scope.ProjectId.Domain)
argIdx += 2
}

// Apply pagination according to token and limit from requests.
limit := 50
offset := 0
if req.Request != nil {
if req.Request.Token != "" {
parsedOffset, err := strconv.Atoi(req.Request.Token)
if err != nil {
return nil, "", fmt.Errorf("invalid pagination token: %w", err)
}
offset = parsedOffset
}

if req.Request.Limit > 0 {
limit = int(req.Request.Limit)
}
}

queryBuilder.WriteString(fmt.Sprintf(" ORDER BY created_at DESC LIMIT $%d OFFSET $%d", argIdx, argIdx+1))
args = append(args, limit+1, offset) // Fetch one extra to determine if there are more
argIdx += 2

var runs []*models.Run
if err := sqlx.SelectContext(ctx, r.db, &runs, queryBuilder.String(), args...); err != nil {
return nil, "", fmt.Errorf("failed to list runs: %w", err)
}

// Determine next token
var nextToken string
if len(runs) > limit {
runs = runs[:limit]
nextToken = fmt.Sprintf("%d", offset+limit)
}

return runs, nextToken, nil
}

// AbortRun marks only the root action as ABORTED and sets abort_requested_at on it.
// K8s cascades CRD deletion to child actions via OwnerReferences; the action service
// informer handles marking them ABORTED in DB when their CRDs are deleted.
Expand Down Expand Up @@ -352,45 +298,46 @@ func (r *actionRepo) GetAction(ctx context.Context, actionID *common.ActionIdent
return &action, nil
}

// ListActions lists actions for a run
func (r *actionRepo) ListActions(ctx context.Context, runID *common.RunIdentifier, limit int, token string) ([]*models.Action, string, error) {
if limit == 0 {
limit = 100
}

// ListActions lists actions matching the given input filter, sort, and pagination.
func (r *actionRepo) ListActions(ctx context.Context, input interfaces.ListResourceInput) ([]*models.Action, error) {
var queryBuilder strings.Builder
var args []interface{}
argIdx := 1

queryBuilder.WriteString(fmt.Sprintf("SELECT * FROM actions WHERE project = $%d AND domain = $%d AND run_name = $%d", argIdx, argIdx+1, argIdx+2))
args = append(args, runID.Project, runID.Domain, runID.Name)
argIdx += 3
queryBuilder.WriteString("SELECT * FROM actions")

// Apply pagination token (encoded as RFC3339Nano created_at cursor)
if token != "" {
if t, err := time.Parse(time.RFC3339Nano, token); err == nil {
queryBuilder.WriteString(fmt.Sprintf(" AND created_at > $%d", argIdx))
args = append(args, t)
argIdx++
if input.Filter != nil {
expr, err := input.Filter.QueryExpression("")
if err != nil {
return nil, fmt.Errorf("failed to build filter expression: %w", err)
}
queryBuilder.WriteString(" WHERE ")
queryBuilder.WriteString(expr.Query)
args = append(args, expr.Args...)
}

queryBuilder.WriteString(fmt.Sprintf(" ORDER BY created_at ASC LIMIT $%d", argIdx))
args = append(args, limit+1)

var actions []*models.Action
if err := sqlx.SelectContext(ctx, r.db, &actions, queryBuilder.String(), args...); err != nil {
return nil, "", fmt.Errorf("failed to list actions: %w", err)
if len(input.SortParameters) > 0 {
queryBuilder.WriteString(" ORDER BY ")
for i, sp := range input.SortParameters {
if i > 0 {
queryBuilder.WriteString(", ")
}
queryBuilder.WriteString(sp.GetOrderExpr())
}
} else {
queryBuilder.WriteString(" ORDER BY created_at ASC")
}

// Determine next token
var nextToken string
if len(actions) > limit {
actions = actions[:limit]
nextToken = actions[len(actions)-1].CreatedAt.UTC().Format(time.RFC3339Nano)
queryBuilder.WriteString(" LIMIT ? OFFSET ?")
args = append(args, input.Limit, input.Offset)

query := sqlx.Rebind(sqlx.DOLLAR, queryBuilder.String())

var actions []*models.Action
if err := sqlx.SelectContext(ctx, r.db, &actions, query, args...); err != nil {
return nil, fmt.Errorf("failed to list actions: %w", err)
}

return actions, nextToken, nil
return actions, nil
}

// UpdateActionPhase updates the phase of an action.
Expand Down
89 changes: 31 additions & 58 deletions runs/repository/impl/action_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/common"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow"
"github.com/flyteorg/flyte/v2/runs/repository/interfaces"
"github.com/flyteorg/flyte/v2/runs/repository/models"
"github.com/jmoiron/sqlx"
"github.com/lib/pq"
Expand Down Expand Up @@ -329,67 +330,40 @@ func TestListRuns(t *testing.T) {
require.NoError(t, err)
}

// Table-driven list tests for ListRuns
type listTestCase struct {
name string
req *workflow.ListRunsRequest
expectLen int
expectTokenNil bool
verify func(t *testing.T, runs []*models.Run)
}

listTests := []listTestCase{
{
name: "List by org should return 3 runs",
req: &workflow.ListRunsRequest{ScopeBy: &workflow.ListRunsRequest_Org{Org: "org1"}},
expectLen: 3,
expectTokenNil: true,
verify: func(t *testing.T, runs []*models.Run) {
runNames := map[string]bool{}
for _, r := range runs {
runNames[r.RunName] = true
}
assert.True(t, runNames["run-1"])
assert.True(t, runNames["run-2"])
assert.True(t, runNames["run-3"])
},
},
}

for _, tt := range listTests {
t.Run(tt.name, func(t *testing.T) {
runs, nextToken, err := actionRepo.ListRuns(ctx, tt.req)
require.NoError(t, err)
assert.Len(t, runs, tt.expectLen)
if tt.expectTokenNil {
assert.Empty(t, nextToken)
} else {
assert.NotEmpty(t, nextToken)
}
if tt.verify != nil {
tt.verify(t, runs)
}
})
// List all runs (root actions only)
runs, err := actionRepo.ListActions(ctx, interfaces.ListResourceInput{
Filter: NewIsRootActionFilter(),
Limit: 50,
})
require.NoError(t, err)
assert.Len(t, runs, 3)
runNames := map[string]bool{}
for _, r := range runs {
runNames[r.RunName] = true
}

// Pagination with limit and token results
runsPage1, token1, err := actionRepo.ListRuns(ctx, &workflow.ListRunsRequest{
Request: &common.ListRequest{Limit: 2},
ScopeBy: &workflow.ListRunsRequest_Org{Org: "org1"},
assert.True(t, runNames["run-1"])
assert.True(t, runNames["run-2"])
assert.True(t, runNames["run-3"])

// Pagination: page 1
runsPage1, err := actionRepo.ListActions(ctx, interfaces.ListResourceInput{
Filter: NewIsRootActionFilter(),
Limit: 2,
Offset: 0,
})
require.NoError(t, err)
assert.Len(t, runsPage1, 2)
require.NotEmpty(t, token1)

runsPage2, token2, err := actionRepo.ListRuns(ctx, &workflow.ListRunsRequest{
Request: &common.ListRequest{Token: token1, Limit: 2},
ScopeBy: &workflow.ListRunsRequest_Org{Org: "org1"},
// Pagination: page 2
runsPage2, err := actionRepo.ListActions(ctx, interfaces.ListResourceInput{
Filter: NewIsRootActionFilter(),
Limit: 2,
Offset: 2,
})
require.NoError(t, err)
assert.Len(t, runsPage2, 1)
assert.Empty(t, token2)

// Test project scope filtering doesn't include other org/project/domain
// Test project scope filtering doesn't include other project
_, err = actionRepo.CreateAction(ctx, &models.Run{
Project: "other-proj",
Domain: "domain1",
Expand All @@ -399,12 +373,11 @@ func TestListRuns(t *testing.T) {
}, false)
require.NoError(t, err)

runsFiltered, _, err := actionRepo.ListRuns(ctx, &workflow.ListRunsRequest{
ScopeBy: &workflow.ListRunsRequest_ProjectId{ProjectId: &common.ProjectIdentifier{
Organization: "org1",
Name: "proj1",
Domain: "domain1",
}},
runsFiltered, err := actionRepo.ListActions(ctx, interfaces.ListResourceInput{
Filter: NewIsRootActionFilter().
And(NewEqualFilter("project", "proj1")).
And(NewEqualFilter("domain", "domain1")),
Limit: 50,
})
require.NoError(t, err)
assert.Len(t, runsFiltered, 3)
Expand Down
71 changes: 66 additions & 5 deletions runs/repository/impl/filters.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,18 @@ import (
"github.com/flyteorg/flyte/v2/runs/repository/interfaces"
)

// NewIsRootActionFilter creates a filter for root actions (runs) only.
func NewIsRootActionFilter() interfaces.Filter {
return &nullFilter{field: "parent_action_name", isNull: true}
}

// NewRunActionsFilter creates a filter for all actions belonging to a specific run.
func NewRunActionsFilter(runID *common.RunIdentifier) interfaces.Filter {
return NewEqualFilter("project", runID.GetProject()).
And(NewEqualFilter("domain", runID.GetDomain())).
And(NewEqualFilter("run_name", runID.GetName()))
}

// basicFilter implements the Filter interface for simple field comparisons
type basicFilter struct {
field string
Expand Down Expand Up @@ -79,6 +91,34 @@ func (f *basicFilter) Or(filter interfaces.Filter) interfaces.Filter {
}
}

// nullFilter implements the Filter interface for IS NULL / IS NOT NULL checks
type nullFilter struct {
field string
isNull bool
}

func (f *nullFilter) QueryExpression(table string) (interfaces.QueryExpr, error) {
column := f.field
if table != "" {
column = table + "." + f.field
}
op := "IS NULL"
if !f.isNull {
op = "IS NOT NULL"
}
return interfaces.QueryExpr{
Query: fmt.Sprintf("%s %s", column, op),
}, nil
}

func (f *nullFilter) And(filter interfaces.Filter) interfaces.Filter {
return &compositeFilter{left: f, right: filter, operator: "AND"}
}

func (f *nullFilter) Or(filter interfaces.Filter) interfaces.Filter {
return &compositeFilter{left: f, right: filter, operator: "OR"}
}

// compositeFilter implements the Filter interface for AND/OR operations
type compositeFilter struct {
left interfaces.Filter
Expand Down Expand Up @@ -150,13 +190,34 @@ func NewProjectIdFilter(projectId *common.ProjectIdentifier) interfaces.Filter {
return projectFilter.And(domainFilter)
}

// NewTaskNameFilter creates a filter for task name (project, domain, name)
// NewTaskNameFilter creates a filter for task name on the tasks table (project, domain, name).
func NewTaskNameFilter(taskName *task.TaskName) interfaces.Filter {
projectFilter := NewEqualFilter("project", taskName.GetProject())
domainFilter := NewEqualFilter("domain", taskName.GetDomain())
nameFilter := NewEqualFilter("name", taskName.GetName())
return NewEqualFilter("project", taskName.GetProject()).
And(NewEqualFilter("domain", taskName.GetDomain())).
And(NewEqualFilter("name", taskName.GetName()))
}

// NewRunTaskNameFilter creates a filter matching runs by task name columns on the actions table.
func NewRunTaskNameFilter(taskName *task.TaskName) interfaces.Filter {
return NewEqualFilter("task_project", taskName.GetProject()).
And(NewEqualFilter("task_domain", taskName.GetDomain())).
And(NewEqualFilter("task_name", taskName.GetName()))
}

// NewRunTaskIdFilter creates a filter matching runs by full task identifier on the actions table.
func NewRunTaskIdFilter(taskId *task.TaskIdentifier) interfaces.Filter {
return NewEqualFilter("task_project", taskId.GetProject()).
And(NewEqualFilter("task_domain", taskId.GetDomain())).
And(NewEqualFilter("task_name", taskId.GetName())).
And(NewEqualFilter("task_version", taskId.GetVersion()))
}

return projectFilter.And(domainFilter).And(nameFilter)
// NewTriggerNameFilter creates a filter matching runs by trigger_name on the actions table.
func NewTriggerNameFilter(triggerName *common.TriggerName) interfaces.Filter {
return NewEqualFilter("project", triggerName.GetProject()).
And(NewEqualFilter("domain", triggerName.GetDomain())).
And(NewEqualFilter("trigger_task_name", triggerName.GetTaskName())).
And(NewEqualFilter("trigger_name", triggerName.GetName()))
}

// NewDeployedByFilter creates a filter for deployed_by = value
Expand Down
4 changes: 1 addition & 3 deletions runs/repository/interfaces/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,13 @@ import (

"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/common"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow"
"github.com/flyteorg/flyte/v2/runs/repository/models"
)

// ActionRepo defines the interface for actions/runs data access
type ActionRepo interface {
// Run operations
GetRun(ctx context.Context, runID *common.RunIdentifier) (*models.Run, error)
ListRuns(ctx context.Context, req *workflow.ListRunsRequest) ([]*models.Run, string, error)
// AbortRun marks only the root action as ABORTED and sets abort_requested_at on it.
// K8s cascades CRD deletion to child actions via OwnerReferences; the action service
// informer handles marking them ABORTED in DB when their CRDs are deleted.
Expand All @@ -27,7 +25,7 @@ type ActionRepo interface {
ListEventsSince(ctx context.Context, actionID *common.ActionIdentifier, attempt uint32, since time.Time, offset, limit int) ([]*models.ActionEvent, error)
GetLatestEventByAttempt(ctx context.Context, actionID *common.ActionIdentifier, attempt uint32) (*models.ActionEvent, error)
GetAction(ctx context.Context, actionID *common.ActionIdentifier) (*models.Action, error)
ListActions(ctx context.Context, runID *common.RunIdentifier, limit int, token string) ([]*models.Action, string, error)
ListActions(ctx context.Context, input ListResourceInput) ([]*models.Action, error)
UpdateActionPhase(ctx context.Context, actionID *common.ActionIdentifier, phase common.ActionPhase, attempts uint32, cacheStatus core.CatalogCacheStatus, endTime *time.Time) error
// AbortAction marks only the targeted action as ABORTED and sets abort_requested_at.
// K8s cascades CRD deletion to descendants via OwnerReferences.
Expand Down
Loading
Loading