Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions actions/k8s/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,18 @@ func (c *ActionsClient) setupInformer(ctx context.Context) error {
}
c.dispatchEvent(taskAction, watch.Modified)
},
DeleteFunc: func(obj interface{}) {
// The informer may deliver a DeletedFinalStateUnknown tombstone
// when a delete event was missed; unwrap it first.
if tombstone, ok := obj.(toolscache.DeletedFinalStateUnknown); ok {
obj = tombstone.Obj
}
taskAction, ok := obj.(*executorv1.TaskAction)
if !ok {
return
}
c.dispatchEvent(taskAction, watch.Deleted)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm thinking maybe we don't need it. we have inserted event in the run service, right?

},
})
if err != nil {
return fmt.Errorf("failed to add TaskAction informer handler: %w", err)
Expand Down Expand Up @@ -494,6 +506,14 @@ func buildActionUpdate(ctx context.Context, taskAction *executorv1.TaskAction, e
shortName = extractShortNameFromTemplate(taskAction.Spec.TaskTemplate)
}

phase := GetPhaseFromConditions(taskAction)
// A TaskAction deleted without a terminal condition was cascade-deleted by K8s
// (e.g. its parent was aborted). Treat it as aborted so the DB and UI reflect
// the real outcome instead of leaving the phase as UNSPECIFIED.
if eventType == watch.Deleted && phase == common.ActionPhase_ACTION_PHASE_UNSPECIFIED {
phase = common.ActionPhase_ACTION_PHASE_ABORTED
}

Comment on lines +510 to +516
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// A TaskAction deleted without a terminal condition was cascade-deleted by K8s
// (e.g. its parent was aborted). Treat it as aborted so the DB and UI reflect
// the real outcome instead of leaving the phase as UNSPECIFIED.
if eventType == watch.Deleted && phase == common.ActionPhase_ACTION_PHASE_UNSPECIFIED {
phase = common.ActionPhase_ACTION_PHASE_ABORTED
}
if eventType == watch.Deleted {
phase = common.ActionPhase_ACTION_PHASE_ABORTED
}

return &ActionUpdate{
ActionID: &common.ActionIdentifier{
Run: &common.RunIdentifier{
Expand All @@ -505,7 +525,7 @@ func buildActionUpdate(ctx context.Context, taskAction *executorv1.TaskAction, e
},
ParentActionName: parentName,
StateJSON: taskAction.Status.StateJSON,
Phase: GetPhaseFromConditions(taskAction),
Phase: phase,
OutputUri: buildOutputUri(ctx, taskAction),
IsDeleted: eventType == watch.Deleted,
TaskType: taskAction.Spec.TaskType,
Expand Down Expand Up @@ -618,7 +638,9 @@ func (c *ActionsClient) notifyRunService(ctx context.Context, taskAction *execut
}
if _, err := c.runClient.UpdateActionStatus(ctx, connect.NewRequest(statusReq)); err != nil {
logger.Warnf(ctx, "Failed to update action status in run service for %s: %v", update.ActionID.Name, err)
} else if isTerminalPhase(update.Phase) {
} else if isTerminalPhase(update.Phase) && !update.IsDeleted {
// Skip label patching for deleted CRs — the patch would always fail
// with "not found" since the object is already gone.
if err := c.markTerminalStatusRecorded(ctx, taskAction); err != nil {
logger.Warnf(ctx, "Failed to mark terminal status recorded for %s: %v", update.ActionID.Name, err)
}
Expand Down
7 changes: 2 additions & 5 deletions charts/flyte-demo/Chart.lock
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,5 @@ dependencies:
- name: minio
repository: https://charts.bitnami.com/bitnami
version: 12.6.7
- name: postgresql
repository: https://charts.bitnami.com/bitnami
version: 12.8.1
digest: sha256:0e1b539359cb8edbc8be650462d31a896712183935781cc46be3b5a3596f2fae
generated: "2026-04-01T15:01:31.417773-07:00"
digest: sha256:efdaa4b57bfdb91479c94b7a6281879df8cab34b523e8b3bf35f1ed8bb1ec5c8
generated: "2026-04-13T15:58:58.248213-07:00"
221 changes: 201 additions & 20 deletions runs/repository/impl/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import (

"github.com/jmoiron/sqlx"
"github.com/lib/pq"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"

"github.com/flyteorg/flyte/v2/flytestdlib/database"
"github.com/flyteorg/flyte/v2/flytestdlib/logger"
Expand Down Expand Up @@ -136,25 +138,100 @@ func (r *actionRepo) ListRuns(ctx context.Context, req *workflow.ListRunsRequest
return runs, nextToken, nil
}

// AbortRun aborts a run and all its actions
func (r *actionRepo) AbortRun(ctx context.Context, runID *common.RunIdentifier, reason string, abortedBy *common.EnrichedIdentity) error {
// AbortRun marks the root action as aborted and sets abort_requested_at on all
// non-terminal child actions so the reconciler can terminate their pods.
// Returns the ActionIdentifiers of the child actions that were marked.
func (r *actionRepo) AbortRun(ctx context.Context, runID *common.RunIdentifier, reason string, abortedBy *common.EnrichedIdentity) ([]*common.ActionIdentifier, error) {
now := time.Now()
abortedPhase := int32(common.ActionPhase_ACTION_PHASE_ABORTED)
succeededPhase := int32(common.ActionPhase_ACTION_PHASE_SUCCEEDED)

_, err := r.db.ExecContext(ctx,
`UPDATE actions SET phase = $1, updated_at = $2, abort_requested_at = $3, abort_attempt_count = $4, abort_reason = $5
WHERE project = $6 AND domain = $7 AND run_name = $8 AND parent_action_name IS NULL`,
int32(common.ActionPhase_ACTION_PHASE_ABORTED), now, now, 0, reason,
runID.Project, runID.Domain, runID.Name)

// Mark the root action as aborted.
var rootName string
var rootAttempts uint32
err := r.db.QueryRowxContext(ctx,
`UPDATE actions SET phase = $1, updated_at = $2, abort_requested_at = $3, abort_attempt_count = $4, abort_reason = $5,
ended_at = COALESCE(ended_at, GREATEST($2, created_at)),
duration_ms = EXTRACT(EPOCH FROM (COALESCE(ended_at, GREATEST($2, created_at)) - created_at)) * 1000
WHERE project = $6 AND domain = $7 AND run_name = $8 AND parent_action_name IS NULL
RETURNING name, attempts`,
abortedPhase, now, now, 0, reason,
runID.Project, runID.Domain, runID.Name,
).Scan(&rootName, &rootAttempts)
if errors.Is(err, sql.ErrNoRows) {
return nil, fmt.Errorf("run not found: %w", sql.ErrNoRows)
}
if err != nil {
return nil, fmt.Errorf("failed to abort run: %w", err)
}

// Cascade abort to non-terminal child actions so the reconciler terminates their pods.
// Skip SUCCEEDED and ABORTED rows — they are already done.
rows, err := r.db.QueryxContext(ctx,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can combine above query together

`UPDATE actions SET phase = $1, updated_at = $2, abort_requested_at = $3, abort_attempt_count = 0, abort_reason = $4,
ended_at = COALESCE(ended_at, GREATEST($2, created_at)),
duration_ms = EXTRACT(EPOCH FROM (COALESCE(ended_at, GREATEST($2, created_at)) - created_at)) * 1000
WHERE project = $5 AND domain = $6 AND run_name = $7
AND parent_action_name IS NOT NULL
AND phase != $8 AND phase != $1
RETURNING name, attempts`,
abortedPhase, now, now, reason,
runID.Project, runID.Domain, runID.Name,
succeededPhase,
)
if err != nil {
return fmt.Errorf("failed to abort run: %w", err)
return nil, fmt.Errorf("failed to abort child actions: %w", err)
}
defer rows.Close()

type childRow struct {
name string
attempts uint32
}
var children []childRow
var childIDs []*common.ActionIdentifier
for rows.Next() {
var cr childRow
if scanErr := rows.Scan(&cr.name, &cr.attempts); scanErr != nil {
return nil, fmt.Errorf("failed to scan child action name: %w", scanErr)
}
children = append(children, cr)
childIDs = append(childIDs, &common.ActionIdentifier{
Run: runID,
Name: cr.name,
})
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("failed to iterate child actions: %w", err)
}

// Insert ABORTED phase-transition events so WatchActionDetails returns a
// phaseTransitions entry with ABORTED for each affected action.
rootID := &common.ActionIdentifier{Run: runID, Name: rootName}
if err := r.insertAbortEvent(ctx, rootID, rootAttempts, reason, now); err != nil {
logger.Warnf(ctx, "AbortRun: failed to insert abort event for root %s: %v", rootName, err)
}
for _, cr := range children {
childID := &common.ActionIdentifier{Run: runID, Name: cr.name}
if err := r.insertAbortEvent(ctx, childID, cr.attempts, reason, now); err != nil {
logger.Warnf(ctx, "AbortRun: failed to insert abort event for %s: %v", cr.name, err)
}
}

// Notify run subscribers.
r.notifyRunUpdate(ctx, runID)

logger.Infof(ctx, "Aborted run: %s/%s/%s", runID.Project, runID.Domain, runID.Name)
return nil
// Notify action subscribers for the root and every aborted child so that
// WatchAllActionUpdates (used by the UI to show per-action status) reflects
// the phase change immediately.
r.notifyActionUpdate(ctx, rootID)
for _, id := range childIDs {
r.notifyActionUpdate(ctx, id)
}

logger.Infof(ctx, "Aborted run: %s/%s/%s (%d child action(s) queued for termination)",
runID.Project, runID.Domain, runID.Name, len(childIDs))
return childIDs, nil
}

// InsertEvents inserts a batch of action events, ignoring duplicates (same PK = idempotent).
Expand Down Expand Up @@ -424,6 +501,14 @@ func (r *actionRepo) UpdateActionPhase(
return err
}
if rowsAffected > 0 {
// Insert an abort event so WatchActionDetails phaseTransitions include ABORTED.
// This must happen before notifyActionUpdate so the event is visible when the
// subscriber re-fetches action events in response to the notification.
if phase == common.ActionPhase_ACTION_PHASE_ABORTED {
if err := r.insertAbortEvent(ctx, actionID, attempts, "", now); err != nil {
logger.Warnf(ctx, "UpdateActionPhase: failed to insert abort event for %s: %v", actionID.Name, err)
}
}
// Notify subscribers of the action update
r.notifyActionUpdate(ctx, actionID)
}
Expand All @@ -437,27 +522,123 @@ func (r *actionRepo) UpdateActionPhase(
return nil
}

// AbortAction aborts a specific action
// AbortAction aborts a specific action and all of its descendants.
// Children cannot run without their parent, so they are marked ABORTED immediately.
// K8s OwnerReferences handle pod/CR termination for task-action children; trace
// children have no pods and require only the DB update.
func (r *actionRepo) AbortAction(ctx context.Context, actionID *common.ActionIdentifier, reason string, abortedBy *common.EnrichedIdentity) error {
now := time.Now()
abortedPhase := int32(common.ActionPhase_ACTION_PHASE_ABORTED)
succeededPhase := int32(common.ActionPhase_ACTION_PHASE_SUCCEEDED)

_, err := r.db.ExecContext(ctx,
`UPDATE actions SET phase = $1, updated_at = $2, abort_requested_at = $3, abort_attempt_count = $4, abort_reason = $5
WHERE project = $6 AND domain = $7 AND run_name = $8 AND name = $9`,
int32(common.ActionPhase_ACTION_PHASE_ABORTED), now, now, 0, reason,
actionID.Run.Project, actionID.Run.Domain, actionID.Run.Name, actionID.Name)

var name string
var attempts uint32
err := r.db.QueryRowxContext(ctx,
`UPDATE actions SET phase = $1, updated_at = $2, abort_requested_at = $3, abort_attempt_count = $4, abort_reason = $5,
ended_at = COALESCE(ended_at, GREATEST($2, created_at)),
duration_ms = EXTRACT(EPOCH FROM (COALESCE(ended_at, GREATEST($2, created_at)) - created_at)) * 1000
WHERE project = $6 AND domain = $7 AND run_name = $8 AND name = $9
RETURNING name, attempts`,
abortedPhase, now, now, 0, reason,
actionID.Run.Project, actionID.Run.Domain, actionID.Run.Name, actionID.Name,
).Scan(&name, &attempts)
if errors.Is(err, sql.ErrNoRows) {
return fmt.Errorf("action not found: %w", sql.ErrNoRows)
}
if err != nil {
return fmt.Errorf("failed to abort action: %w", err)
}

if err := r.insertAbortEvent(ctx, actionID, attempts, reason, now); err != nil {
logger.Warnf(ctx, "AbortAction: failed to insert abort event for %s: %v", actionID.Name, err)
}

// Notify action subscribers.
r.notifyActionUpdate(ctx, actionID)

logger.Infof(ctx, "Aborted action: %s", actionID.Name)
// Cascade ABORTED to all descendants of this action (recursive: children,
// grandchildren, etc.) that are not yet in a terminal phase.
// We use a recursive CTE so all depths are handled in a single round-trip.
rows, err := r.db.QueryxContext(ctx,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

merge above query if possible

`WITH RECURSIVE descendants AS (
SELECT name, attempts
FROM actions
WHERE project = $1 AND domain = $2 AND run_name = $3
AND parent_action_name = $4
UNION ALL
SELECT a.name, a.attempts
FROM actions a
JOIN descendants d ON a.parent_action_name = d.name
WHERE a.project = $1 AND a.domain = $2 AND a.run_name = $3
)
UPDATE actions
SET phase = $5, updated_at = $6, abort_reason = $7,
ended_at = COALESCE(ended_at, GREATEST($6, created_at)),
duration_ms = EXTRACT(EPOCH FROM (COALESCE(ended_at, GREATEST($6, created_at)) - created_at)) * 1000
FROM descendants
WHERE actions.project = $1 AND actions.domain = $2 AND actions.run_name = $3
AND actions.name = descendants.name
AND actions.phase != $5 AND actions.phase != $8
RETURNING actions.name, actions.attempts`,
actionID.Run.Project, actionID.Run.Domain, actionID.Run.Name, actionID.Name,
abortedPhase, now, reason, succeededPhase,
)
if err != nil {
logger.Warnf(ctx, "AbortAction: failed to cascade abort to descendants of %s: %v", actionID.Name, err)
return nil
}
defer rows.Close()

var descendantCount int
for rows.Next() {
var childName string
var childAttempts uint32
if scanErr := rows.Scan(&childName, &childAttempts); scanErr != nil {
logger.Warnf(ctx, "AbortAction: failed to scan descendant row: %v", scanErr)
continue
}
descendantCount++
childID := &common.ActionIdentifier{Run: actionID.Run, Name: childName}
if err := r.insertAbortEvent(ctx, childID, childAttempts, reason, now); err != nil {
logger.Warnf(ctx, "AbortAction: failed to insert abort event for descendant %s: %v", childName, err)
}
r.notifyActionUpdate(ctx, childID)
logger.Infof(ctx, "AbortAction: cascade-aborted descendant %s of %s", childName, actionID.Name)
}
if err := rows.Err(); err != nil {
logger.Warnf(ctx, "AbortAction: error iterating descendant rows for %s: %v", actionID.Name, err)
}

logger.Infof(ctx, "AbortAction: aborted %s and %d descendant(s)", actionID.Name, descendantCount)
return nil
}

// insertAbortEvent writes a single ABORTED phase-transition event into action_events so that
// WatchActionDetails returns a phaseTransitions entry with phase = ABORTED for the action.
// Failures are non-fatal — the caller should log and continue.
func (r *actionRepo) insertAbortEvent(ctx context.Context, actionID *common.ActionIdentifier, attempts uint32, reason string, now time.Time) error {
// Build the minimal ActionEvent proto that mergeEvents will pick up.
event := &workflow.ActionEvent{
Id: actionID,
Phase: common.ActionPhase_ACTION_PHASE_ABORTED,
Attempt: attempts,
UpdatedTime: timestamppb.New(now),
}
info, err := proto.Marshal(event)
if err != nil {
return fmt.Errorf("marshal abort event: %w", err)
}

_, err = r.db.ExecContext(ctx,
`INSERT INTO action_events (project, domain, run_name, name, attempt, phase, version, info, error_kind, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, 0, $7, NULL, $8, $8)
ON CONFLICT DO NOTHING`,
actionID.Run.Project, actionID.Run.Domain, actionID.Run.Name, actionID.Name,
attempts, int32(common.ActionPhase_ACTION_PHASE_ABORTED), info, now,
)
return err
}

// ListPendingAborts returns all actions that have abort_requested_at set (i.e. awaiting pod termination).
func (r *actionRepo) ListPendingAborts(ctx context.Context) ([]*models.Action, error) {
var actions []*models.Action
Expand Down Expand Up @@ -488,7 +669,7 @@ func (r *actionRepo) MarkAbortAttempt(ctx context.Context, actionID *common.Acti
// ClearAbortRequest clears abort_requested_at (and resets counters) once the pod is confirmed terminated.
func (r *actionRepo) ClearAbortRequest(ctx context.Context, actionID *common.ActionIdentifier) error {
_, err := r.db.ExecContext(ctx,
`UPDATE actions SET abort_requested_at = NULL, abort_attempt_count = 0, abort_reason = NULL, updated_at = $1
`UPDATE actions SET abort_requested_at = NULL, abort_attempt_count = 0, updated_at = $1
WHERE project = $2 AND domain = $3 AND run_name = $4 AND name = $5`,
time.Now(), actionID.Run.Project, actionID.Run.Domain, actionID.Run.Name, actionID.Name)
if err != nil {
Expand Down
5 changes: 4 additions & 1 deletion runs/repository/interfaces/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ type ActionRepo interface {
// Run operations
GetRun(ctx context.Context, runID *common.RunIdentifier) (*models.Run, error)
ListRuns(ctx context.Context, req *workflow.ListRunsRequest) ([]*models.Run, string, error)
AbortRun(ctx context.Context, runID *common.RunIdentifier, reason string, abortedBy *common.EnrichedIdentity) error
// AbortRun marks the root action as aborted and sets abort_requested_at on all
// non-terminal child actions. Returns the identifiers of those child actions so
// the caller can push them to the AbortReconciler for pod termination.
AbortRun(ctx context.Context, runID *common.RunIdentifier, reason string, abortedBy *common.EnrichedIdentity) ([]*common.ActionIdentifier, error)

// Action operations
CreateAction(ctx context.Context, action *models.Action, updateTriggeredAt bool) (*models.Action, error)
Expand Down
Loading