From 3bc4d9222805bc0c7b1267e5719050287a73f227 Mon Sep 17 00:00:00 2001 From: "M. Adil Fayyaz" <62440954+AdilFayyaz@users.noreply.github.com> Date: Mon, 13 Apr 2026 18:25:02 -0700 Subject: [PATCH 1/2] wip Signed-off-by: M. Adil Fayyaz <62440954+AdilFayyaz@users.noreply.github.com> --- actions/k8s/client.go | 26 +++- charts/flyte-demo/Chart.lock | 7 +- runs/repository/impl/action.go | 218 ++++++++++++++++++++++++--- runs/repository/interfaces/action.go | 5 +- runs/repository/mocks/mocks.go | 29 ++-- runs/service/abort_reconciler.go | 22 ++- runs/service/run_service.go | 29 +++- runs/service/run_service_test.go | 6 +- tasks/lessons.md | 19 +++ 9 files changed, 311 insertions(+), 50 deletions(-) create mode 100644 tasks/lessons.md diff --git a/actions/k8s/client.go b/actions/k8s/client.go index fd3d5a10cd..ae35c8e15f 100644 --- a/actions/k8s/client.go +++ b/actions/k8s/client.go @@ -422,6 +422,18 @@ func (c *ActionsClient) setupInformer(ctx context.Context) error { } c.dispatchEvent(taskAction, watch.Modified) }, + DeleteFunc: func(obj interface{}) { + // The informer may deliver a DeletedFinalStateUnknown tombstone + // when a delete event was missed; unwrap it first. + if tombstone, ok := obj.(toolscache.DeletedFinalStateUnknown); ok { + obj = tombstone.Obj + } + taskAction, ok := obj.(*executorv1.TaskAction) + if !ok { + return + } + c.dispatchEvent(taskAction, watch.Deleted) + }, }) if err != nil { return fmt.Errorf("failed to add TaskAction informer handler: %w", err) @@ -494,6 +506,14 @@ func buildActionUpdate(ctx context.Context, taskAction *executorv1.TaskAction, e shortName = extractShortNameFromTemplate(taskAction.Spec.TaskTemplate) } + phase := GetPhaseFromConditions(taskAction) + // A TaskAction deleted without a terminal condition was cascade-deleted by K8s + // (e.g. its parent was aborted). Treat it as aborted so the DB and UI reflect + // the real outcome instead of leaving the phase as UNSPECIFIED. + if eventType == watch.Deleted && phase == common.ActionPhase_ACTION_PHASE_UNSPECIFIED { + phase = common.ActionPhase_ACTION_PHASE_ABORTED + } + return &ActionUpdate{ ActionID: &common.ActionIdentifier{ Run: &common.RunIdentifier{ @@ -505,7 +525,7 @@ func buildActionUpdate(ctx context.Context, taskAction *executorv1.TaskAction, e }, ParentActionName: parentName, StateJSON: taskAction.Status.StateJSON, - Phase: GetPhaseFromConditions(taskAction), + Phase: phase, OutputUri: buildOutputUri(ctx, taskAction), IsDeleted: eventType == watch.Deleted, TaskType: taskAction.Spec.TaskType, @@ -618,7 +638,9 @@ func (c *ActionsClient) notifyRunService(ctx context.Context, taskAction *execut } if _, err := c.runClient.UpdateActionStatus(ctx, connect.NewRequest(statusReq)); err != nil { logger.Warnf(ctx, "Failed to update action status in run service for %s: %v", update.ActionID.Name, err) - } else if isTerminalPhase(update.Phase) { + } else if isTerminalPhase(update.Phase) && !update.IsDeleted { + // Skip label patching for deleted CRs — the patch would always fail + // with "not found" since the object is already gone. if err := c.markTerminalStatusRecorded(ctx, taskAction); err != nil { logger.Warnf(ctx, "Failed to mark terminal status recorded for %s: %v", update.ActionID.Name, err) } diff --git a/charts/flyte-demo/Chart.lock b/charts/flyte-demo/Chart.lock index 144b182599..9003265968 100644 --- a/charts/flyte-demo/Chart.lock +++ b/charts/flyte-demo/Chart.lock @@ -8,8 +8,5 @@ dependencies: - name: minio repository: https://charts.bitnami.com/bitnami version: 12.6.7 -- name: postgresql - repository: https://charts.bitnami.com/bitnami - version: 12.8.1 -digest: sha256:0e1b539359cb8edbc8be650462d31a896712183935781cc46be3b5a3596f2fae -generated: "2026-04-01T15:01:31.417773-07:00" +digest: sha256:efdaa4b57bfdb91479c94b7a6281879df8cab34b523e8b3bf35f1ed8bb1ec5c8 +generated: "2026-04-13T15:58:58.248213-07:00" diff --git a/runs/repository/impl/action.go b/runs/repository/impl/action.go index 1ff0c7439f..4daaeef48f 100644 --- a/runs/repository/impl/action.go +++ b/runs/repository/impl/action.go @@ -15,6 +15,8 @@ import ( "github.com/jmoiron/sqlx" "github.com/lib/pq" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/timestamppb" "github.com/flyteorg/flyte/v2/flytestdlib/database" "github.com/flyteorg/flyte/v2/flytestdlib/logger" @@ -136,25 +138,100 @@ func (r *actionRepo) ListRuns(ctx context.Context, req *workflow.ListRunsRequest return runs, nextToken, nil } -// AbortRun aborts a run and all its actions -func (r *actionRepo) AbortRun(ctx context.Context, runID *common.RunIdentifier, reason string, abortedBy *common.EnrichedIdentity) error { +// AbortRun marks the root action as aborted and sets abort_requested_at on all +// non-terminal child actions so the reconciler can terminate their pods. +// Returns the ActionIdentifiers of the child actions that were marked. +func (r *actionRepo) AbortRun(ctx context.Context, runID *common.RunIdentifier, reason string, abortedBy *common.EnrichedIdentity) ([]*common.ActionIdentifier, error) { now := time.Now() + abortedPhase := int32(common.ActionPhase_ACTION_PHASE_ABORTED) + succeededPhase := int32(common.ActionPhase_ACTION_PHASE_SUCCEEDED) - _, err := r.db.ExecContext(ctx, - `UPDATE actions SET phase = $1, updated_at = $2, abort_requested_at = $3, abort_attempt_count = $4, abort_reason = $5 - WHERE project = $6 AND domain = $7 AND run_name = $8 AND parent_action_name IS NULL`, - int32(common.ActionPhase_ACTION_PHASE_ABORTED), now, now, 0, reason, - runID.Project, runID.Domain, runID.Name) - + // Mark the root action as aborted. + var rootName string + var rootAttempts uint32 + err := r.db.QueryRowxContext(ctx, + `UPDATE actions SET phase = $1, updated_at = $2, abort_requested_at = $3, abort_attempt_count = $4, abort_reason = $5, + ended_at = COALESCE(ended_at, GREATEST($2, created_at)), + duration_ms = EXTRACT(EPOCH FROM (COALESCE(ended_at, GREATEST($2, created_at)) - created_at)) * 1000 + WHERE project = $6 AND domain = $7 AND run_name = $8 AND parent_action_name IS NULL + RETURNING name, attempts`, + abortedPhase, now, now, 0, reason, + runID.Project, runID.Domain, runID.Name, + ).Scan(&rootName, &rootAttempts) + if errors.Is(err, sql.ErrNoRows) { + return nil, fmt.Errorf("run not found: %w", sql.ErrNoRows) + } + if err != nil { + return nil, fmt.Errorf("failed to abort run: %w", err) + } + + // Cascade abort to non-terminal child actions so the reconciler terminates their pods. + // Skip SUCCEEDED and ABORTED rows — they are already done. + rows, err := r.db.QueryxContext(ctx, + `UPDATE actions SET phase = $1, updated_at = $2, abort_requested_at = $3, abort_attempt_count = 0, abort_reason = $4, + ended_at = COALESCE(ended_at, GREATEST($2, created_at)), + duration_ms = EXTRACT(EPOCH FROM (COALESCE(ended_at, GREATEST($2, created_at)) - created_at)) * 1000 + WHERE project = $5 AND domain = $6 AND run_name = $7 + AND parent_action_name IS NOT NULL + AND phase != $8 AND phase != $1 + RETURNING name, attempts`, + abortedPhase, now, now, reason, + runID.Project, runID.Domain, runID.Name, + succeededPhase, + ) if err != nil { - return fmt.Errorf("failed to abort run: %w", err) + return nil, fmt.Errorf("failed to abort child actions: %w", err) + } + defer rows.Close() + + type childRow struct { + name string + attempts uint32 + } + var children []childRow + var childIDs []*common.ActionIdentifier + for rows.Next() { + var cr childRow + if scanErr := rows.Scan(&cr.name, &cr.attempts); scanErr != nil { + return nil, fmt.Errorf("failed to scan child action name: %w", scanErr) + } + children = append(children, cr) + childIDs = append(childIDs, &common.ActionIdentifier{ + Run: runID, + Name: cr.name, + }) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("failed to iterate child actions: %w", err) + } + + // Insert ABORTED phase-transition events so WatchActionDetails returns a + // phaseTransitions entry with ABORTED for each affected action. + rootID := &common.ActionIdentifier{Run: runID, Name: rootName} + if err := r.insertAbortEvent(ctx, rootID, rootAttempts, reason, now); err != nil { + logger.Warnf(ctx, "AbortRun: failed to insert abort event for root %s: %v", rootName, err) + } + for _, cr := range children { + childID := &common.ActionIdentifier{Run: runID, Name: cr.name} + if err := r.insertAbortEvent(ctx, childID, cr.attempts, reason, now); err != nil { + logger.Warnf(ctx, "AbortRun: failed to insert abort event for %s: %v", cr.name, err) + } } // Notify run subscribers. r.notifyRunUpdate(ctx, runID) - logger.Infof(ctx, "Aborted run: %s/%s/%s", runID.Project, runID.Domain, runID.Name) - return nil + // Notify action subscribers for the root and every aborted child so that + // WatchAllActionUpdates (used by the UI to show per-action status) reflects + // the phase change immediately. + r.notifyActionUpdate(ctx, rootID) + for _, id := range childIDs { + r.notifyActionUpdate(ctx, id) + } + + logger.Infof(ctx, "Aborted run: %s/%s/%s (%d child action(s) queued for termination)", + runID.Project, runID.Domain, runID.Name, len(childIDs)) + return childIDs, nil } // InsertEvents inserts a batch of action events, ignoring duplicates (same PK = idempotent). @@ -424,6 +501,14 @@ func (r *actionRepo) UpdateActionPhase( return err } if rowsAffected > 0 { + // Insert an abort event so WatchActionDetails phaseTransitions include ABORTED. + // This must happen before notifyActionUpdate so the event is visible when the + // subscriber re-fetches action events in response to the notification. + if phase == common.ActionPhase_ACTION_PHASE_ABORTED { + if err := r.insertAbortEvent(ctx, actionID, attempts, "", now); err != nil { + logger.Warnf(ctx, "UpdateActionPhase: failed to insert abort event for %s: %v", actionID.Name, err) + } + } // Notify subscribers of the action update r.notifyActionUpdate(ctx, actionID) } @@ -437,27 +522,120 @@ func (r *actionRepo) UpdateActionPhase( return nil } -// AbortAction aborts a specific action +// AbortAction aborts a specific action and all of its descendants. +// Children cannot run without their parent, so they are marked ABORTED immediately. +// K8s OwnerReferences handle pod/CR termination for task-action children; trace +// children have no pods and require only the DB update. func (r *actionRepo) AbortAction(ctx context.Context, actionID *common.ActionIdentifier, reason string, abortedBy *common.EnrichedIdentity) error { now := time.Now() + abortedPhase := int32(common.ActionPhase_ACTION_PHASE_ABORTED) + succeededPhase := int32(common.ActionPhase_ACTION_PHASE_SUCCEEDED) - _, err := r.db.ExecContext(ctx, - `UPDATE actions SET phase = $1, updated_at = $2, abort_requested_at = $3, abort_attempt_count = $4, abort_reason = $5 - WHERE project = $6 AND domain = $7 AND run_name = $8 AND name = $9`, - int32(common.ActionPhase_ACTION_PHASE_ABORTED), now, now, 0, reason, - actionID.Run.Project, actionID.Run.Domain, actionID.Run.Name, actionID.Name) - + var name string + var attempts uint32 + err := r.db.QueryRowxContext(ctx, + `UPDATE actions SET phase = $1, updated_at = $2, abort_requested_at = $3, abort_attempt_count = $4, abort_reason = $5, + ended_at = COALESCE(ended_at, GREATEST($2, created_at)), + duration_ms = EXTRACT(EPOCH FROM (COALESCE(ended_at, GREATEST($2, created_at)) - created_at)) * 1000 + WHERE project = $6 AND domain = $7 AND run_name = $8 AND name = $9 + RETURNING name, attempts`, + abortedPhase, now, now, 0, reason, + actionID.Run.Project, actionID.Run.Domain, actionID.Run.Name, actionID.Name, + ).Scan(&name, &attempts) + if errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("action not found: %w", sql.ErrNoRows) + } if err != nil { return fmt.Errorf("failed to abort action: %w", err) } + if err := r.insertAbortEvent(ctx, actionID, attempts, reason, now); err != nil { + logger.Warnf(ctx, "AbortAction: failed to insert abort event for %s: %v", actionID.Name, err) + } + // Notify action subscribers. r.notifyActionUpdate(ctx, actionID) - logger.Infof(ctx, "Aborted action: %s", actionID.Name) + // Cascade ABORTED to all descendants of this action (recursive: children, + // grandchildren, etc.) that are not yet in a terminal phase. + // We use a recursive CTE so all depths are handled in a single round-trip. + rows, err := r.db.QueryxContext(ctx, + `WITH RECURSIVE descendants AS ( + SELECT name, attempts + FROM actions + WHERE project = $1 AND domain = $2 AND run_name = $3 + AND parent_action_name = $4 + UNION ALL + SELECT a.name, a.attempts + FROM actions a + JOIN descendants d ON a.parent_action_name = d.name + WHERE a.project = $1 AND a.domain = $2 AND a.run_name = $3 + ) + UPDATE actions + SET phase = $5, updated_at = $6, abort_reason = $7, + ended_at = COALESCE(ended_at, GREATEST($6, created_at)), + duration_ms = EXTRACT(EPOCH FROM (COALESCE(ended_at, GREATEST($6, created_at)) - created_at)) * 1000 + FROM descendants + WHERE actions.project = $1 AND actions.domain = $2 AND actions.run_name = $3 + AND actions.name = descendants.name + AND actions.phase != $5 AND actions.phase != $8 + RETURNING actions.name, actions.attempts`, + actionID.Run.Project, actionID.Run.Domain, actionID.Run.Name, actionID.Name, + abortedPhase, now, reason, succeededPhase, + ) + if err != nil { + logger.Warnf(ctx, "AbortAction: failed to cascade abort to descendants of %s: %v", actionID.Name, err) + return nil + } + defer rows.Close() + + for rows.Next() { + var childName string + var childAttempts uint32 + if scanErr := rows.Scan(&childName, &childAttempts); scanErr != nil { + logger.Warnf(ctx, "AbortAction: failed to scan descendant row: %v", scanErr) + continue + } + childID := &common.ActionIdentifier{Run: actionID.Run, Name: childName} + if err := r.insertAbortEvent(ctx, childID, childAttempts, reason, now); err != nil { + logger.Warnf(ctx, "AbortAction: failed to insert abort event for descendant %s: %v", childName, err) + } + r.notifyActionUpdate(ctx, childID) + } + if err := rows.Err(); err != nil { + logger.Warnf(ctx, "AbortAction: error iterating descendant rows for %s: %v", actionID.Name, err) + } + + logger.Infof(ctx, "Aborted action %s and its descendants", actionID.Name) return nil } +// insertAbortEvent writes a single ABORTED phase-transition event into action_events so that +// WatchActionDetails returns a phaseTransitions entry with phase = ABORTED for the action. +// Failures are non-fatal — the caller should log and continue. +func (r *actionRepo) insertAbortEvent(ctx context.Context, actionID *common.ActionIdentifier, attempts uint32, reason string, now time.Time) error { + // Build the minimal ActionEvent proto that mergeEvents will pick up. + event := &workflow.ActionEvent{ + Id: actionID, + Phase: common.ActionPhase_ACTION_PHASE_ABORTED, + Attempt: attempts, + UpdatedTime: timestamppb.New(now), + } + info, err := proto.Marshal(event) + if err != nil { + return fmt.Errorf("marshal abort event: %w", err) + } + + _, err = r.db.ExecContext(ctx, + `INSERT INTO action_events (project, domain, run_name, name, attempt, phase, version, info, error_kind, created_at, updated_at) + VALUES ($1, $2, $3, $4, $5, $6, 0, $7, NULL, $8, $8) + ON CONFLICT DO NOTHING`, + actionID.Run.Project, actionID.Run.Domain, actionID.Run.Name, actionID.Name, + attempts, int32(common.ActionPhase_ACTION_PHASE_ABORTED), info, now, + ) + return err +} + // ListPendingAborts returns all actions that have abort_requested_at set (i.e. awaiting pod termination). func (r *actionRepo) ListPendingAborts(ctx context.Context) ([]*models.Action, error) { var actions []*models.Action @@ -488,7 +666,7 @@ func (r *actionRepo) MarkAbortAttempt(ctx context.Context, actionID *common.Acti // ClearAbortRequest clears abort_requested_at (and resets counters) once the pod is confirmed terminated. func (r *actionRepo) ClearAbortRequest(ctx context.Context, actionID *common.ActionIdentifier) error { _, err := r.db.ExecContext(ctx, - `UPDATE actions SET abort_requested_at = NULL, abort_attempt_count = 0, abort_reason = NULL, updated_at = $1 + `UPDATE actions SET abort_requested_at = NULL, abort_attempt_count = 0, updated_at = $1 WHERE project = $2 AND domain = $3 AND run_name = $4 AND name = $5`, time.Now(), actionID.Run.Project, actionID.Run.Domain, actionID.Run.Name, actionID.Name) if err != nil { diff --git a/runs/repository/interfaces/action.go b/runs/repository/interfaces/action.go index c5fd776cb9..d0a3f16c67 100644 --- a/runs/repository/interfaces/action.go +++ b/runs/repository/interfaces/action.go @@ -15,7 +15,10 @@ type ActionRepo interface { // Run operations GetRun(ctx context.Context, runID *common.RunIdentifier) (*models.Run, error) ListRuns(ctx context.Context, req *workflow.ListRunsRequest) ([]*models.Run, string, error) - AbortRun(ctx context.Context, runID *common.RunIdentifier, reason string, abortedBy *common.EnrichedIdentity) error + // AbortRun marks the root action as aborted and sets abort_requested_at on all + // non-terminal child actions. Returns the identifiers of those child actions so + // the caller can push them to the AbortReconciler for pod termination. + AbortRun(ctx context.Context, runID *common.RunIdentifier, reason string, abortedBy *common.EnrichedIdentity) ([]*common.ActionIdentifier, error) // Action operations CreateAction(ctx context.Context, action *models.Action, updateTriggeredAt bool) (*models.Action, error) diff --git a/runs/repository/mocks/mocks.go b/runs/repository/mocks/mocks.go index 06bc646609..26911b441b 100644 --- a/runs/repository/mocks/mocks.go +++ b/runs/repository/mocks/mocks.go @@ -113,20 +113,24 @@ func (_c *ActionRepo_AbortAction_Call) RunAndReturn(run func(ctx context.Context } // AbortRun provides a mock function for the type ActionRepo -func (_mock *ActionRepo) AbortRun(ctx context.Context, runID *common.RunIdentifier, reason string, abortedBy *common.EnrichedIdentity) error { +func (_mock *ActionRepo) AbortRun(ctx context.Context, runID *common.RunIdentifier, reason string, abortedBy *common.EnrichedIdentity) ([]*common.ActionIdentifier, error) { ret := _mock.Called(ctx, runID, reason, abortedBy) if len(ret) == 0 { panic("no return value specified for AbortRun") } - var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, *common.RunIdentifier, string, *common.EnrichedIdentity) error); ok { - r0 = returnFunc(ctx, runID, reason, abortedBy) + var r0 []*common.ActionIdentifier + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, *common.RunIdentifier, string, *common.EnrichedIdentity) ([]*common.ActionIdentifier, error)); ok { + r0, r1 = returnFunc(ctx, runID, reason, abortedBy) } else { - r0 = ret.Error(0) + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*common.ActionIdentifier) + } + r1 = ret.Error(1) } - return r0 + return r0, r1 } // ActionRepo_AbortRun_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AbortRun' @@ -161,22 +165,17 @@ func (_c *ActionRepo_AbortRun_Call) Run(run func(ctx context.Context, runID *com if args[3] != nil { arg3 = args[3].(*common.EnrichedIdentity) } - run( - arg0, - arg1, - arg2, - arg3, - ) + run(arg0, arg1, arg2, arg3) }) return _c } -func (_c *ActionRepo_AbortRun_Call) Return(err error) *ActionRepo_AbortRun_Call { - _c.Call.Return(err) +func (_c *ActionRepo_AbortRun_Call) Return(childActions []*common.ActionIdentifier, err error) *ActionRepo_AbortRun_Call { + _c.Call.Return(childActions, err) return _c } -func (_c *ActionRepo_AbortRun_Call) RunAndReturn(run func(ctx context.Context, runID *common.RunIdentifier, reason string, abortedBy *common.EnrichedIdentity) error) *ActionRepo_AbortRun_Call { +func (_c *ActionRepo_AbortRun_Call) RunAndReturn(run func(ctx context.Context, runID *common.RunIdentifier, reason string, abortedBy *common.EnrichedIdentity) ([]*common.ActionIdentifier, error)) *ActionRepo_AbortRun_Call { _c.Call.Return(run) return _c } diff --git a/runs/service/abort_reconciler.go b/runs/service/abort_reconciler.go index adbd32ddb8..27d91e863e 100644 --- a/runs/service/abort_reconciler.go +++ b/runs/service/abort_reconciler.go @@ -2,8 +2,11 @@ package service import ( "context" + "database/sql" + "errors" "fmt" "math/rand" + "strings" "sync" "time" @@ -217,6 +220,12 @@ func (r *AbortReconciler) runWorker(ctx context.Context) { func (r *AbortReconciler) processTask(ctx context.Context, task abortTask) { attemptCount, err := r.repo.ActionRepo().MarkAbortAttempt(ctx, task.actionID) if err != nil { + if errors.Is(err, sql.ErrNoRows) { + // Action no longer exists in the DB — nothing to abort, drop it. + r.queue.remove(task.key) + logger.Warnf(ctx, "AbortReconciler: action %s not found in DB, dropping abort task", task.key) + return + } logger.Errorf(ctx, "AbortReconciler: failed to mark attempt for %s: %v", task.key, err) // Re-enqueue without counting — the DB row is authoritative; try again later. r.queue.scheduleRequeue(ctx, task, r.cfg.InitialDelay) @@ -264,6 +273,8 @@ func (r *AbortReconciler) processTask(ctx context.Context, task abortTask) { } // isAlreadyTerminated returns true for errors that indicate the action is already gone. +// The actions service may wrap a Kubernetes "not found" error as CodeInternal rather +// than CodeNotFound, so we also check for that case by inspecting the message. func isAlreadyTerminated(err error) bool { if err == nil { return false @@ -272,5 +283,14 @@ func isAlreadyTerminated(err error) bool { if !ok { return false } - return connectErr.Code() == connect.CodeNotFound + if connectErr.Code() == connect.CodeNotFound { + return true + } + // The actions service forwards Kubernetes API "not found" errors with CodeInternal. + // Treat those as "already gone" so the reconciler clears the DB entry instead of + // retrying indefinitely. + if connectErr.Code() == connect.CodeInternal && strings.Contains(connectErr.Message(), "not found") { + return true + } + return false } diff --git a/runs/service/run_service.go b/runs/service/run_service.go index 7ac39a4f3b..7bf344b48c 100644 --- a/runs/service/run_service.go +++ b/runs/service/run_service.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "encoding/json" + "errors" "fmt" "sort" "strings" @@ -387,13 +388,26 @@ func (s *RunService) AbortRun( } // Abort in database, then push to reconciler for background pod termination. - if err := s.repo.ActionRepo().AbortRun(ctx, req.Msg.RunId, reason, nil); err != nil { + // AbortRun marks the root action aborted and returns identifiers for any + // non-terminal child actions so we can push them to the reconciler too. + childActions, err := s.repo.ActionRepo().AbortRun(ctx, req.Msg.RunId, reason, nil) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("run not found: %s/%s/%s", req.Msg.RunId.Project, req.Msg.RunId.Domain, req.Msg.RunId.Name)) + } logger.Errorf(ctx, "Failed to abort run: %v", err) return nil, connect.NewError(connect.CodeInternal, err) } if s.abortReconciler != nil { - s.abortReconciler.Push(ctx, &common.ActionIdentifier{Run: req.Msg.RunId, Name: req.Msg.RunId.Name}, reason) + // Push child actions first — they own the TaskAction CRDs and pods. + for _, actionID := range childActions { + s.abortReconciler.Push(ctx, actionID, reason) + } + // Also push the root action ("a0"). Its TaskAction CRD may not exist (it is + // a workflow-level composite action), but the reconciler handles that case + // gracefully via isAlreadyTerminated. + s.abortReconciler.Push(ctx, &common.ActionIdentifier{Run: req.Msg.RunId, Name: "a0"}, reason) } return connect.NewResponse(&workflow.AbortRunResponse{}), nil @@ -552,7 +566,13 @@ func (s *RunService) buildActionDetails(ctx context.Context, model *models.Actio } } case common.ActionPhase_ACTION_PHASE_ABORTED: - // TODO: set AbortInfo + abortInfo := &workflow.AbortInfo{} + if model.AbortReason != nil { + abortInfo.Reason = *model.AbortReason + } + action.Result = &workflow.ActionDetails_AbortInfo{ + AbortInfo: abortInfo, + } } return action, nil @@ -923,6 +943,9 @@ func (s *RunService) AbortAction( // Abort in database, then push to reconciler for background pod termination. if err := s.repo.ActionRepo().AbortAction(ctx, req.Msg.ActionId, reason, nil); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("action not found: %s", req.Msg.ActionId.Name)) + } logger.Errorf(ctx, "Failed to abort action: %v", err) return nil, connect.NewError(connect.CodeInternal, err) } diff --git a/runs/service/run_service_test.go b/runs/service/run_service_test.go index 728a094a4c..0f420260e5 100644 --- a/runs/service/run_service_test.go +++ b/runs/service/run_service_test.go @@ -653,7 +653,7 @@ func TestAbortRun(t *testing.T) { t.Run("success with default reason", func(t *testing.T) { actionRepo, actionsClient, svc := newTestService(t) - actionRepo.On("AbortRun", mock.Anything, runID, "User requested abort", (*common.EnrichedIdentity)(nil)).Return(nil) + actionRepo.On("AbortRun", mock.Anything, runID, "User requested abort", (*common.EnrichedIdentity)(nil)).Return(([]*common.ActionIdentifier)(nil), nil) _, err := svc.AbortRun(context.Background(), connect.NewRequest(&workflow.AbortRunRequest{RunId: runID})) assert.NoError(t, err) @@ -664,7 +664,7 @@ func TestAbortRun(t *testing.T) { actionRepo, actionsClient, svc := newTestService(t) reason := "timeout exceeded" - actionRepo.On("AbortRun", mock.Anything, runID, reason, (*common.EnrichedIdentity)(nil)).Return(nil) + actionRepo.On("AbortRun", mock.Anything, runID, reason, (*common.EnrichedIdentity)(nil)).Return(([]*common.ActionIdentifier)(nil), nil) _, err := svc.AbortRun(context.Background(), connect.NewRequest(&workflow.AbortRunRequest{RunId: runID, Reason: &reason})) assert.NoError(t, err) @@ -674,7 +674,7 @@ func TestAbortRun(t *testing.T) { t.Run("db error returns error", func(t *testing.T) { actionRepo, actionsClient, svc := newTestService(t) - actionRepo.On("AbortRun", mock.Anything, runID, mock.Anything, mock.Anything).Return(errors.New("db unavailable")) + actionRepo.On("AbortRun", mock.Anything, runID, mock.Anything, mock.Anything).Return(([]*common.ActionIdentifier)(nil), errors.New("db unavailable")) _, err := svc.AbortRun(context.Background(), connect.NewRequest(&workflow.AbortRunRequest{RunId: runID})) assert.Error(t, err) diff --git a/tasks/lessons.md b/tasks/lessons.md new file mode 100644 index 0000000000..e5ed3e44e2 --- /dev/null +++ b/tasks/lessons.md @@ -0,0 +1,19 @@ +# Claude Lessons Learned + +This file tracks patterns and corrections from past sessions. +Updated after every user correction. Reviewed at session start. + +## Code Quality + + +## Architecture Decisions + + +## Testing Patterns + + +## Communication + + +## Project-Specific Gotchas + From 7fd78d5e1f250d7f812149054875615df05d6c21 Mon Sep 17 00:00:00 2001 From: "M. Adil Fayyaz" <62440954+AdilFayyaz@users.noreply.github.com> Date: Wed, 15 Apr 2026 10:30:45 -0700 Subject: [PATCH 2/2] fix Signed-off-by: M. Adil Fayyaz <62440954+AdilFayyaz@users.noreply.github.com> --- runs/repository/impl/action.go | 5 ++++- runs/service/run_service.go | 11 ++++++----- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/runs/repository/impl/action.go b/runs/repository/impl/action.go index 4daaeef48f..2a0196137a 100644 --- a/runs/repository/impl/action.go +++ b/runs/repository/impl/action.go @@ -589,6 +589,7 @@ func (r *actionRepo) AbortAction(ctx context.Context, actionID *common.ActionIde } defer rows.Close() + var descendantCount int for rows.Next() { var childName string var childAttempts uint32 @@ -596,17 +597,19 @@ func (r *actionRepo) AbortAction(ctx context.Context, actionID *common.ActionIde logger.Warnf(ctx, "AbortAction: failed to scan descendant row: %v", scanErr) continue } + descendantCount++ childID := &common.ActionIdentifier{Run: actionID.Run, Name: childName} if err := r.insertAbortEvent(ctx, childID, childAttempts, reason, now); err != nil { logger.Warnf(ctx, "AbortAction: failed to insert abort event for descendant %s: %v", childName, err) } r.notifyActionUpdate(ctx, childID) + logger.Infof(ctx, "AbortAction: cascade-aborted descendant %s of %s", childName, actionID.Name) } if err := rows.Err(); err != nil { logger.Warnf(ctx, "AbortAction: error iterating descendant rows for %s: %v", actionID.Name, err) } - logger.Infof(ctx, "Aborted action %s and its descendants", actionID.Name) + logger.Infof(ctx, "AbortAction: aborted %s and %d descendant(s)", actionID.Name, descendantCount) return nil } diff --git a/runs/service/run_service.go b/runs/service/run_service.go index 7bf344b48c..d413adcef6 100644 --- a/runs/service/run_service.go +++ b/runs/service/run_service.go @@ -1019,6 +1019,12 @@ func (s *RunService) WatchActionDetails( actionID := req.Msg.ActionId logger.Infof(ctx, "Received WatchActionDetails request for: %s/%s", actionID.Run.Name, actionID.Name) + // Subscribe FIRST to avoid missing notifications that fire between the DB + // read and the subscription setup (same pattern as WatchActions). + updates := make(chan *models.Action, 50) + errs := make(chan error, 1) + go s.repo.ActionRepo().WatchActionUpdates(ctx, actionID, updates, errs) + // Step 1: Send initial state from DB details, err := s.getActionDetails(ctx, req.Msg.GetActionId()) if err != nil { @@ -1036,11 +1042,6 @@ func (s *RunService) WatchActionDetails( return nil } - // Step 2: Watch DB for updates - updates := make(chan *models.Action, 50) - errs := make(chan error, 1) - go s.repo.ActionRepo().WatchActionUpdates(ctx, actionID, updates, errs) - for { select { case <-ctx.Done():