-
Notifications
You must be signed in to change notification settings - Fork 806
fix: propagate abort phase to sub-actions and fix WatchActionDetails race #7200
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -422,6 +422,18 @@ func (c *ActionsClient) setupInformer(ctx context.Context) error { | |||||||||||||||||||||
| } | ||||||||||||||||||||||
| c.dispatchEvent(taskAction, watch.Modified) | ||||||||||||||||||||||
| }, | ||||||||||||||||||||||
| DeleteFunc: func(obj interface{}) { | ||||||||||||||||||||||
| // The informer may deliver a DeletedFinalStateUnknown tombstone | ||||||||||||||||||||||
| // when a delete event was missed; unwrap it first. | ||||||||||||||||||||||
| if tombstone, ok := obj.(toolscache.DeletedFinalStateUnknown); ok { | ||||||||||||||||||||||
| obj = tombstone.Obj | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| taskAction, ok := obj.(*executorv1.TaskAction) | ||||||||||||||||||||||
| if !ok { | ||||||||||||||||||||||
| return | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| c.dispatchEvent(taskAction, watch.Deleted) | ||||||||||||||||||||||
| }, | ||||||||||||||||||||||
| }) | ||||||||||||||||||||||
| if err != nil { | ||||||||||||||||||||||
| return fmt.Errorf("failed to add TaskAction informer handler: %w", err) | ||||||||||||||||||||||
|
|
@@ -494,6 +506,14 @@ func buildActionUpdate(ctx context.Context, taskAction *executorv1.TaskAction, e | |||||||||||||||||||||
| shortName = extractShortNameFromTemplate(taskAction.Spec.TaskTemplate) | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| phase := GetPhaseFromConditions(taskAction) | ||||||||||||||||||||||
| // A TaskAction deleted without a terminal condition was cascade-deleted by K8s | ||||||||||||||||||||||
| // (e.g. its parent was aborted). Treat it as aborted so the DB and UI reflect | ||||||||||||||||||||||
| // the real outcome instead of leaving the phase as UNSPECIFIED. | ||||||||||||||||||||||
| if eventType == watch.Deleted && phase == common.ActionPhase_ACTION_PHASE_UNSPECIFIED { | ||||||||||||||||||||||
| phase = common.ActionPhase_ACTION_PHASE_ABORTED | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
Comment on lines
+510
to
+516
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||
| return &ActionUpdate{ | ||||||||||||||||||||||
| ActionID: &common.ActionIdentifier{ | ||||||||||||||||||||||
| Run: &common.RunIdentifier{ | ||||||||||||||||||||||
|
|
@@ -505,7 +525,7 @@ func buildActionUpdate(ctx context.Context, taskAction *executorv1.TaskAction, e | |||||||||||||||||||||
| }, | ||||||||||||||||||||||
| ParentActionName: parentName, | ||||||||||||||||||||||
| StateJSON: taskAction.Status.StateJSON, | ||||||||||||||||||||||
| Phase: GetPhaseFromConditions(taskAction), | ||||||||||||||||||||||
| Phase: phase, | ||||||||||||||||||||||
| OutputUri: buildOutputUri(ctx, taskAction), | ||||||||||||||||||||||
| IsDeleted: eventType == watch.Deleted, | ||||||||||||||||||||||
| TaskType: taskAction.Spec.TaskType, | ||||||||||||||||||||||
|
|
@@ -618,7 +638,9 @@ func (c *ActionsClient) notifyRunService(ctx context.Context, taskAction *execut | |||||||||||||||||||||
| } | ||||||||||||||||||||||
| if _, err := c.runClient.UpdateActionStatus(ctx, connect.NewRequest(statusReq)); err != nil { | ||||||||||||||||||||||
| logger.Warnf(ctx, "Failed to update action status in run service for %s: %v", update.ActionID.Name, err) | ||||||||||||||||||||||
| } else if isTerminalPhase(update.Phase) { | ||||||||||||||||||||||
| } else if isTerminalPhase(update.Phase) && !update.IsDeleted { | ||||||||||||||||||||||
| // Skip label patching for deleted CRs — the patch would always fail | ||||||||||||||||||||||
| // with "not found" since the object is already gone. | ||||||||||||||||||||||
| if err := c.markTerminalStatusRecorded(ctx, taskAction); err != nil { | ||||||||||||||||||||||
| logger.Warnf(ctx, "Failed to mark terminal status recorded for %s: %v", update.ActionID.Name, err) | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,6 +15,8 @@ import ( | |
|
|
||
| "github.com/jmoiron/sqlx" | ||
| "github.com/lib/pq" | ||
| "google.golang.org/protobuf/proto" | ||
| "google.golang.org/protobuf/types/known/timestamppb" | ||
|
|
||
| "github.com/flyteorg/flyte/v2/flytestdlib/database" | ||
| "github.com/flyteorg/flyte/v2/flytestdlib/logger" | ||
|
|
@@ -136,25 +138,100 @@ func (r *actionRepo) ListRuns(ctx context.Context, req *workflow.ListRunsRequest | |
| return runs, nextToken, nil | ||
| } | ||
|
|
||
| // AbortRun aborts a run and all its actions | ||
| func (r *actionRepo) AbortRun(ctx context.Context, runID *common.RunIdentifier, reason string, abortedBy *common.EnrichedIdentity) error { | ||
| // AbortRun marks the root action as aborted and sets abort_requested_at on all | ||
| // non-terminal child actions so the reconciler can terminate their pods. | ||
| // Returns the ActionIdentifiers of the child actions that were marked. | ||
| func (r *actionRepo) AbortRun(ctx context.Context, runID *common.RunIdentifier, reason string, abortedBy *common.EnrichedIdentity) ([]*common.ActionIdentifier, error) { | ||
| now := time.Now() | ||
| abortedPhase := int32(common.ActionPhase_ACTION_PHASE_ABORTED) | ||
| succeededPhase := int32(common.ActionPhase_ACTION_PHASE_SUCCEEDED) | ||
|
|
||
| _, err := r.db.ExecContext(ctx, | ||
| `UPDATE actions SET phase = $1, updated_at = $2, abort_requested_at = $3, abort_attempt_count = $4, abort_reason = $5 | ||
| WHERE project = $6 AND domain = $7 AND run_name = $8 AND parent_action_name IS NULL`, | ||
| int32(common.ActionPhase_ACTION_PHASE_ABORTED), now, now, 0, reason, | ||
| runID.Project, runID.Domain, runID.Name) | ||
|
|
||
| // Mark the root action as aborted. | ||
| var rootName string | ||
| var rootAttempts uint32 | ||
| err := r.db.QueryRowxContext(ctx, | ||
| `UPDATE actions SET phase = $1, updated_at = $2, abort_requested_at = $3, abort_attempt_count = $4, abort_reason = $5, | ||
| ended_at = COALESCE(ended_at, GREATEST($2, created_at)), | ||
| duration_ms = EXTRACT(EPOCH FROM (COALESCE(ended_at, GREATEST($2, created_at)) - created_at)) * 1000 | ||
| WHERE project = $6 AND domain = $7 AND run_name = $8 AND parent_action_name IS NULL | ||
| RETURNING name, attempts`, | ||
| abortedPhase, now, now, 0, reason, | ||
| runID.Project, runID.Domain, runID.Name, | ||
| ).Scan(&rootName, &rootAttempts) | ||
| if errors.Is(err, sql.ErrNoRows) { | ||
| return nil, fmt.Errorf("run not found: %w", sql.ErrNoRows) | ||
| } | ||
| if err != nil { | ||
| return nil, fmt.Errorf("failed to abort run: %w", err) | ||
| } | ||
|
|
||
| // Cascade abort to non-terminal child actions so the reconciler terminates their pods. | ||
| // Skip SUCCEEDED and ABORTED rows — they are already done. | ||
| rows, err := r.db.QueryxContext(ctx, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think we can combine this query with the one above. |
||
| `UPDATE actions SET phase = $1, updated_at = $2, abort_requested_at = $3, abort_attempt_count = 0, abort_reason = $4, | ||
| ended_at = COALESCE(ended_at, GREATEST($2, created_at)), | ||
| duration_ms = EXTRACT(EPOCH FROM (COALESCE(ended_at, GREATEST($2, created_at)) - created_at)) * 1000 | ||
| WHERE project = $5 AND domain = $6 AND run_name = $7 | ||
| AND parent_action_name IS NOT NULL | ||
| AND phase != $8 AND phase != $1 | ||
| RETURNING name, attempts`, | ||
| abortedPhase, now, now, reason, | ||
| runID.Project, runID.Domain, runID.Name, | ||
| succeededPhase, | ||
| ) | ||
| if err != nil { | ||
| return fmt.Errorf("failed to abort run: %w", err) | ||
| return nil, fmt.Errorf("failed to abort child actions: %w", err) | ||
| } | ||
| defer rows.Close() | ||
|
|
||
| type childRow struct { | ||
| name string | ||
| attempts uint32 | ||
| } | ||
| var children []childRow | ||
| var childIDs []*common.ActionIdentifier | ||
| for rows.Next() { | ||
| var cr childRow | ||
| if scanErr := rows.Scan(&cr.name, &cr.attempts); scanErr != nil { | ||
| return nil, fmt.Errorf("failed to scan child action name: %w", scanErr) | ||
| } | ||
| children = append(children, cr) | ||
| childIDs = append(childIDs, &common.ActionIdentifier{ | ||
| Run: runID, | ||
| Name: cr.name, | ||
| }) | ||
| } | ||
| if err := rows.Err(); err != nil { | ||
| return nil, fmt.Errorf("failed to iterate child actions: %w", err) | ||
| } | ||
|
|
||
| // Insert ABORTED phase-transition events so WatchActionDetails returns a | ||
| // phaseTransitions entry with ABORTED for each affected action. | ||
| rootID := &common.ActionIdentifier{Run: runID, Name: rootName} | ||
| if err := r.insertAbortEvent(ctx, rootID, rootAttempts, reason, now); err != nil { | ||
| logger.Warnf(ctx, "AbortRun: failed to insert abort event for root %s: %v", rootName, err) | ||
| } | ||
| for _, cr := range children { | ||
| childID := &common.ActionIdentifier{Run: runID, Name: cr.name} | ||
| if err := r.insertAbortEvent(ctx, childID, cr.attempts, reason, now); err != nil { | ||
| logger.Warnf(ctx, "AbortRun: failed to insert abort event for %s: %v", cr.name, err) | ||
| } | ||
| } | ||
|
|
||
| // Notify run subscribers. | ||
| r.notifyRunUpdate(ctx, runID) | ||
|
|
||
| logger.Infof(ctx, "Aborted run: %s/%s/%s", runID.Project, runID.Domain, runID.Name) | ||
| return nil | ||
| // Notify action subscribers for the root and every aborted child so that | ||
| // WatchAllActionUpdates (used by the UI to show per-action status) reflects | ||
| // the phase change immediately. | ||
| r.notifyActionUpdate(ctx, rootID) | ||
| for _, id := range childIDs { | ||
| r.notifyActionUpdate(ctx, id) | ||
| } | ||
|
|
||
| logger.Infof(ctx, "Aborted run: %s/%s/%s (%d child action(s) queued for termination)", | ||
| runID.Project, runID.Domain, runID.Name, len(childIDs)) | ||
| return childIDs, nil | ||
| } | ||
|
|
||
| // InsertEvents inserts a batch of action events, ignoring duplicates (same PK = idempotent). | ||
|
|
@@ -424,6 +501,14 @@ func (r *actionRepo) UpdateActionPhase( | |
| return err | ||
| } | ||
| if rowsAffected > 0 { | ||
| // Insert an abort event so WatchActionDetails phaseTransitions include ABORTED. | ||
| // This must happen before notifyActionUpdate so the event is visible when the | ||
| // subscriber re-fetches action events in response to the notification. | ||
| if phase == common.ActionPhase_ACTION_PHASE_ABORTED { | ||
| if err := r.insertAbortEvent(ctx, actionID, attempts, "", now); err != nil { | ||
| logger.Warnf(ctx, "UpdateActionPhase: failed to insert abort event for %s: %v", actionID.Name, err) | ||
| } | ||
| } | ||
| // Notify subscribers of the action update | ||
| r.notifyActionUpdate(ctx, actionID) | ||
| } | ||
|
|
@@ -437,27 +522,123 @@ func (r *actionRepo) UpdateActionPhase( | |
| return nil | ||
| } | ||
|
|
||
| // AbortAction aborts a specific action | ||
| // AbortAction aborts a specific action and all of its descendants. | ||
| // Children cannot run without their parent, so they are marked ABORTED immediately. | ||
| // K8s OwnerReferences handle pod/CR termination for task-action children; trace | ||
| // children have no pods and require only the DB update. | ||
| func (r *actionRepo) AbortAction(ctx context.Context, actionID *common.ActionIdentifier, reason string, abortedBy *common.EnrichedIdentity) error { | ||
| now := time.Now() | ||
| abortedPhase := int32(common.ActionPhase_ACTION_PHASE_ABORTED) | ||
| succeededPhase := int32(common.ActionPhase_ACTION_PHASE_SUCCEEDED) | ||
|
|
||
| _, err := r.db.ExecContext(ctx, | ||
| `UPDATE actions SET phase = $1, updated_at = $2, abort_requested_at = $3, abort_attempt_count = $4, abort_reason = $5 | ||
| WHERE project = $6 AND domain = $7 AND run_name = $8 AND name = $9`, | ||
| int32(common.ActionPhase_ACTION_PHASE_ABORTED), now, now, 0, reason, | ||
| actionID.Run.Project, actionID.Run.Domain, actionID.Run.Name, actionID.Name) | ||
|
|
||
| var name string | ||
| var attempts uint32 | ||
| err := r.db.QueryRowxContext(ctx, | ||
| `UPDATE actions SET phase = $1, updated_at = $2, abort_requested_at = $3, abort_attempt_count = $4, abort_reason = $5, | ||
| ended_at = COALESCE(ended_at, GREATEST($2, created_at)), | ||
| duration_ms = EXTRACT(EPOCH FROM (COALESCE(ended_at, GREATEST($2, created_at)) - created_at)) * 1000 | ||
| WHERE project = $6 AND domain = $7 AND run_name = $8 AND name = $9 | ||
| RETURNING name, attempts`, | ||
| abortedPhase, now, now, 0, reason, | ||
| actionID.Run.Project, actionID.Run.Domain, actionID.Run.Name, actionID.Name, | ||
| ).Scan(&name, &attempts) | ||
| if errors.Is(err, sql.ErrNoRows) { | ||
| return fmt.Errorf("action not found: %w", sql.ErrNoRows) | ||
| } | ||
| if err != nil { | ||
| return fmt.Errorf("failed to abort action: %w", err) | ||
| } | ||
|
|
||
| if err := r.insertAbortEvent(ctx, actionID, attempts, reason, now); err != nil { | ||
| logger.Warnf(ctx, "AbortAction: failed to insert abort event for %s: %v", actionID.Name, err) | ||
| } | ||
|
|
||
| // Notify action subscribers. | ||
| r.notifyActionUpdate(ctx, actionID) | ||
|
|
||
| logger.Infof(ctx, "Aborted action: %s", actionID.Name) | ||
| // Cascade ABORTED to all descendants of this action (recursive: children, | ||
| // grandchildren, etc.) that are not yet in a terminal phase. | ||
| // We use a recursive CTE so all depths are handled in a single round-trip. | ||
| rows, err := r.db.QueryxContext(ctx, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Merge this query with the one above if possible. |
||
| `WITH RECURSIVE descendants AS ( | ||
| SELECT name, attempts | ||
| FROM actions | ||
| WHERE project = $1 AND domain = $2 AND run_name = $3 | ||
| AND parent_action_name = $4 | ||
| UNION ALL | ||
| SELECT a.name, a.attempts | ||
| FROM actions a | ||
| JOIN descendants d ON a.parent_action_name = d.name | ||
| WHERE a.project = $1 AND a.domain = $2 AND a.run_name = $3 | ||
| ) | ||
| UPDATE actions | ||
| SET phase = $5, updated_at = $6, abort_reason = $7, | ||
| ended_at = COALESCE(ended_at, GREATEST($6, created_at)), | ||
| duration_ms = EXTRACT(EPOCH FROM (COALESCE(ended_at, GREATEST($6, created_at)) - created_at)) * 1000 | ||
| FROM descendants | ||
| WHERE actions.project = $1 AND actions.domain = $2 AND actions.run_name = $3 | ||
| AND actions.name = descendants.name | ||
| AND actions.phase != $5 AND actions.phase != $8 | ||
| RETURNING actions.name, actions.attempts`, | ||
| actionID.Run.Project, actionID.Run.Domain, actionID.Run.Name, actionID.Name, | ||
| abortedPhase, now, reason, succeededPhase, | ||
| ) | ||
| if err != nil { | ||
| logger.Warnf(ctx, "AbortAction: failed to cascade abort to descendants of %s: %v", actionID.Name, err) | ||
| return nil | ||
| } | ||
| defer rows.Close() | ||
|
|
||
| var descendantCount int | ||
| for rows.Next() { | ||
| var childName string | ||
| var childAttempts uint32 | ||
| if scanErr := rows.Scan(&childName, &childAttempts); scanErr != nil { | ||
| logger.Warnf(ctx, "AbortAction: failed to scan descendant row: %v", scanErr) | ||
| continue | ||
| } | ||
| descendantCount++ | ||
| childID := &common.ActionIdentifier{Run: actionID.Run, Name: childName} | ||
| if err := r.insertAbortEvent(ctx, childID, childAttempts, reason, now); err != nil { | ||
| logger.Warnf(ctx, "AbortAction: failed to insert abort event for descendant %s: %v", childName, err) | ||
| } | ||
| r.notifyActionUpdate(ctx, childID) | ||
| logger.Infof(ctx, "AbortAction: cascade-aborted descendant %s of %s", childName, actionID.Name) | ||
| } | ||
| if err := rows.Err(); err != nil { | ||
| logger.Warnf(ctx, "AbortAction: error iterating descendant rows for %s: %v", actionID.Name, err) | ||
| } | ||
|
|
||
| logger.Infof(ctx, "AbortAction: aborted %s and %d descendant(s)", actionID.Name, descendantCount) | ||
| return nil | ||
| } | ||
|
|
||
| // insertAbortEvent writes a single ABORTED phase-transition event into action_events so that | ||
| // WatchActionDetails returns a phaseTransitions entry with phase = ABORTED for the action. | ||
| // Failures are non-fatal — the caller should log and continue. | ||
| func (r *actionRepo) insertAbortEvent(ctx context.Context, actionID *common.ActionIdentifier, attempts uint32, reason string, now time.Time) error { | ||
| // Build the minimal ActionEvent proto that mergeEvents will pick up. | ||
| event := &workflow.ActionEvent{ | ||
| Id: actionID, | ||
| Phase: common.ActionPhase_ACTION_PHASE_ABORTED, | ||
| Attempt: attempts, | ||
| UpdatedTime: timestamppb.New(now), | ||
| } | ||
| info, err := proto.Marshal(event) | ||
| if err != nil { | ||
| return fmt.Errorf("marshal abort event: %w", err) | ||
| } | ||
|
|
||
| _, err = r.db.ExecContext(ctx, | ||
| `INSERT INTO action_events (project, domain, run_name, name, attempt, phase, version, info, error_kind, created_at, updated_at) | ||
| VALUES ($1, $2, $3, $4, $5, $6, 0, $7, NULL, $8, $8) | ||
| ON CONFLICT DO NOTHING`, | ||
| actionID.Run.Project, actionID.Run.Domain, actionID.Run.Name, actionID.Name, | ||
| attempts, int32(common.ActionPhase_ACTION_PHASE_ABORTED), info, now, | ||
| ) | ||
| return err | ||
| } | ||
|
|
||
| // ListPendingAborts returns all actions that have abort_requested_at set (i.e. awaiting pod termination). | ||
| func (r *actionRepo) ListPendingAborts(ctx context.Context) ([]*models.Action, error) { | ||
| var actions []*models.Action | ||
|
|
@@ -488,7 +669,7 @@ func (r *actionRepo) MarkAbortAttempt(ctx context.Context, actionID *common.Acti | |
| // ClearAbortRequest clears abort_requested_at (and resets counters) once the pod is confirmed terminated. | ||
| func (r *actionRepo) ClearAbortRequest(ctx context.Context, actionID *common.ActionIdentifier) error { | ||
| _, err := r.db.ExecContext(ctx, | ||
| `UPDATE actions SET abort_requested_at = NULL, abort_attempt_count = 0, abort_reason = NULL, updated_at = $1 | ||
| `UPDATE actions SET abort_requested_at = NULL, abort_attempt_count = 0, updated_at = $1 | ||
| WHERE project = $2 AND domain = $3 AND run_name = $4 AND name = $5`, | ||
| time.Now(), actionID.Run.Project, actionID.Run.Domain, actionID.Run.Name, actionID.Name) | ||
| if err != nil { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm thinking we may not need this. We have already inserted the event in the run service, right?