diff --git a/actions/k8s/client.go b/actions/k8s/client.go index fd3d5a10cd..9db00c5ef7 100644 --- a/actions/k8s/client.go +++ b/actions/k8s/client.go @@ -422,6 +422,18 @@ func (c *ActionsClient) setupInformer(ctx context.Context) error { } c.dispatchEvent(taskAction, watch.Modified) }, + DeleteFunc: func(obj interface{}) { + // The informer may deliver a DeletedFinalStateUnknown tombstone + // when a delete event was missed; unwrap it first. + if tombstone, ok := obj.(toolscache.DeletedFinalStateUnknown); ok { + obj = tombstone.Obj + } + taskAction, ok := obj.(*executorv1.TaskAction) + if !ok { + return + } + c.dispatchEvent(taskAction, watch.Deleted) + }, }) if err != nil { return fmt.Errorf("failed to add TaskAction informer handler: %w", err) @@ -494,6 +506,11 @@ func buildActionUpdate(ctx context.Context, taskAction *executorv1.TaskAction, e shortName = extractShortNameFromTemplate(taskAction.Spec.TaskTemplate) } + phase := GetPhaseFromConditions(taskAction) + if eventType == watch.Deleted { + phase = common.ActionPhase_ACTION_PHASE_ABORTED + } + return &ActionUpdate{ ActionID: &common.ActionIdentifier{ Run: &common.RunIdentifier{ @@ -505,7 +522,7 @@ func buildActionUpdate(ctx context.Context, taskAction *executorv1.TaskAction, e }, ParentActionName: parentName, StateJSON: taskAction.Status.StateJSON, - Phase: GetPhaseFromConditions(taskAction), + Phase: phase, OutputUri: buildOutputUri(ctx, taskAction), IsDeleted: eventType == watch.Deleted, TaskType: taskAction.Spec.TaskType, @@ -536,7 +553,7 @@ func (c *ActionsClient) notifySubscribers(ctx context.Context, update *ActionUpd // notifyRunService forwards a watch event to the internal run service. // On ADDED events it calls RecordAction to create the DB record. -// On all events it calls UpdateActionStatus (when phase is meaningful) and RecordActionEvents. +// On all events it calls UpdateActionStatus (when phase is meaningful) to update the actions table. func (c *ActionsClient) notifyRunService(ctx context.Context, taskAction *executorv1.TaskAction, update *ActionUpdate, eventType watch.EventType) { if c.runClient == nil { return @@ -618,7 +635,9 @@ func (c *ActionsClient) notifyRunService(ctx context.Context, taskAction *execut } if _, err := c.runClient.UpdateActionStatus(ctx, connect.NewRequest(statusReq)); err != nil { logger.Warnf(ctx, "Failed to update action status in run service for %s: %v", update.ActionID.Name, err) - } else if isTerminalPhase(update.Phase) { + } else if isTerminalPhase(update.Phase) && !update.IsDeleted { + // Skip label patching for deleted CRs — the patch would always fail + // with "not found" since the object is already gone. if err := c.markTerminalStatusRecorded(ctx, taskAction); err != nil { logger.Warnf(ctx, "Failed to mark terminal status recorded for %s: %v", update.ActionID.Name, err) } diff --git a/buf.lock b/buf.lock index c6b7e8e3b3..5e8e6b51fb 100644 --- a/buf.lock +++ b/buf.lock @@ -8,5 +8,5 @@ deps: commit: 62f35d8aed1149c291d606d958a7ce32 digest: b5:d66bf04adc77a0870bdc9328aaf887c7188a36fb02b83a480dc45ef9dc031b4d39fc6e9dc6435120ccf4fe5bfd5c6cb6592533c6c316595571f9a31420ab47fe - name: buf.build/grpc-ecosystem/grpc-gateway - commit: 6467306b4f624747aaf6266762ee7a1c - digest: b5:c2caa61467d992749812c909f93c07e9a667da33c758a7c1973d63136c23b3cafcc079985b12cdf54a10049ed3297418f1eda42cdffdcf34113792dcc3a990af + commit: 4836b6d552304e1bbe47e66a523f0daa + digest: b5:c3fefd4d3dfa9b0478bbb1a4ad87d7b38146e3ce6eff4214e32f2c5834c2e4afc3be218316f0fbd53e925a001c3ed1e2fc99fb76b3121ede642989f0d0d7c71c diff --git a/executor/pkg/controller/taskaction_controller.go b/executor/pkg/controller/taskaction_controller.go index 8c59c3f4c9..d30cbd2f79 100644 --- a/executor/pkg/controller/taskaction_controller.go +++ b/executor/pkg/controller/taskaction_controller.go @@ -363,6 +363,19 @@ func (r *TaskActionReconciler) handleAbortAndFinalize(ctx context.Context, taskA } } + abortTime := time.Now() + abortPhaseInfo := pluginsCore.PhaseInfoAborted(abortTime, pluginsCore.DefaultPhaseVersion, "aborted") + actionEvent := r.buildActionEvent(ctx, taskAction, abortPhaseInfo) + // buildActionEvent derives UpdatedTime from PhaseHistory, which doesn't include the + // abort transition. Override it so mergeEvents uses the actual abort time as end_time. + actionEvent.UpdatedTime = timestamppb.New(abortTime) + if _, err := r.eventsClient.Record(ctx, connect.NewRequest(&workflow.RecordRequest{ + Events: []*workflow.ActionEvent{actionEvent}, + })); err != nil { + logger.Error(err, "failed to emit abort event, will retry") + return ctrl.Result{RequeueAfter: TaskActionDefaultRequeueDuration}, nil + } + return r.removeFinalizer(ctx, taskAction) } diff --git a/executor/pkg/controller/taskaction_controller_test.go b/executor/pkg/controller/taskaction_controller_test.go index 5365872179..f30026cbd9 100644 --- a/executor/pkg/controller/taskaction_controller_test.go +++ b/executor/pkg/controller/taskaction_controller_test.go @@ -18,6 +18,7 @@ package controller import ( "context" + "sync" "time" "connectrpc.com/connect" @@ -34,6 +35,7 @@ import ( flyteorgv1 "github.com/flyteorg/flyte/v2/executor/api/v1" pluginsCore "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/core" k8sPlugin "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/k8s" + "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/common" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow" ) @@ -45,6 +47,27 @@ func (f *fakeEventsClient) Record(_ context.Context, _ *connect.Request[workflow return connect.NewResponse(&workflow.RecordResponse{}), nil } +// recordingEventsClient captures all recorded ActionEvents for assertion in tests. +type recordingEventsClient struct { + mu sync.Mutex + events []*workflow.ActionEvent +} + +func (r *recordingEventsClient) Record(_ context.Context, req *connect.Request[workflow.RecordRequest]) (*connect.Response[workflow.RecordResponse], error) { + r.mu.Lock() + defer r.mu.Unlock() + r.events = append(r.events, req.Msg.GetEvents()...) + return connect.NewResponse(&workflow.RecordResponse{}), nil +} + +func (r *recordingEventsClient) RecordedEvents() []*workflow.ActionEvent { + r.mu.Lock() + defer r.mu.Unlock() + out := make([]*workflow.ActionEvent, len(r.events)) + copy(out, r.events) + return out +} + // buildTaskTemplateBytes creates a minimal protobuf-serialized TaskTemplate // with a container spec that the pod plugin can use to build a Pod. func buildTaskTemplateBytes(taskType, image string) []byte { @@ -285,6 +308,77 @@ var _ = Describe("TaskAction Controller", func() { }) }) + Context("When a TaskAction is deleted (abort flow)", func() { + const abortResourceName = "abort-test-resource" + + ctx := context.Background() + + typeNamespacedName := types.NamespacedName{ + Name: abortResourceName, + Namespace: "default", + } + + BeforeEach(func() { + resource := &flyteorgv1.TaskAction{ + ObjectMeta: metav1.ObjectMeta{ + Name: abortResourceName, + Namespace: "default", + Finalizers: []string{taskActionFinalizer}, + }, + Spec: flyteorgv1.TaskActionSpec{ + RunName: "abort-run", + Project: "abort-project", + Domain: "abort-domain", + ActionName: "abort-action", + InputURI: "/tmp/input", + RunOutputBase: "/tmp/output", + TaskType: "python-task", + TaskTemplate: buildTaskTemplateBytes("python-task", "python:3.11"), + }, + } + Expect(k8sClient.Create(ctx, resource)).To(Succeed()) + Expect(k8sClient.Delete(ctx, resource)).To(Succeed()) + }) + + AfterEach(func() { + resource := &flyteorgv1.TaskAction{} + err := k8sClient.Get(ctx, typeNamespacedName, resource) + if err == nil { + resource.Finalizers = nil + Expect(k8sClient.Update(ctx, resource)).To(Succeed()) + } + }) + + It("should emit an ACTION_PHASE_ABORTED event before removing the finalizer", func() { + recorder := &recordingEventsClient{} + reconciler := &TaskActionReconciler{ + Client: k8sClient, + Scheme: k8sClient.Scheme(), + Recorder: record.NewFakeRecorder(10), + PluginRegistry: pluginRegistry, + DataStore: dataStore, + eventsClient: recorder, + } + + result, err := reconciler.Reconcile(ctx, reconcile.Request{NamespacedName: typeNamespacedName}) + Expect(err).NotTo(HaveOccurred()) + Expect(result.RequeueAfter).To(BeZero()) + + // Finalizer should have been removed — object is gone. + deleted := &flyteorgv1.TaskAction{} + Expect(k8sClient.Get(ctx, typeNamespacedName, deleted)).NotTo(Succeed()) + + // An ABORTED event must have been emitted. + recorded := recorder.RecordedEvents() + Expect(recorded).NotTo(BeEmpty()) + phases := make([]interface{}, len(recorded)) + for i, e := range recorded { + phases[i] = e.GetPhase() + } + Expect(phases).To(ContainElement(common.ActionPhase_ACTION_PHASE_ABORTED)) + }) + }) + Context("toClusterEvents", func() { It("should include both phase reason and additional reasons", func() { phaseOccurredAt := time.Date(2026, 4, 2, 10, 0, 0, 0, time.UTC) diff --git a/flyteplugins/go/tasks/pluginmachinery/core/phase.go b/flyteplugins/go/tasks/pluginmachinery/core/phase.go index 7ede39b4c2..de3f48cdbd 100644 --- a/flyteplugins/go/tasks/pluginmachinery/core/phase.go +++ b/flyteplugins/go/tasks/pluginmachinery/core/phase.go @@ -274,6 +274,12 @@ func PhaseInfoSuccess(info *TaskInfo) PhaseInfo { return phaseInfo(PhaseSuccess, DefaultPhaseVersion, nil, info, false) } +func PhaseInfoAborted(t time.Time, version uint32, reason string) PhaseInfo { + pi := phaseInfo(PhaseAborted, version, nil, &TaskInfo{OccurredAt: &t}, false) + pi.reason = reason + return pi +} + func PhaseInfoSystemFailure(code, reason string, info *TaskInfo) PhaseInfo { return phaseInfoFailed(PhasePermanentFailure, &core.ExecutionError{Code: code, Message: reason, Kind: core.ExecutionError_SYSTEM}, info, false) } diff --git a/runs/repository/impl/action.go b/runs/repository/impl/action.go index e1a56ebf7f..869aaac779 100644 --- a/runs/repository/impl/action.go +++ b/runs/repository/impl/action.go @@ -136,22 +136,32 @@ func (r *actionRepo) ListRuns(ctx context.Context, req *workflow.ListRunsRequest return runs, nextToken, nil } -// AbortRun aborts a run and all its actions +// AbortRun marks only the root action as ABORTED and sets abort_requested_at on it. +// K8s cascades CRD deletion to child actions via OwnerReferences; the action service +// informer handles marking them ABORTED in DB when their CRDs are deleted. func (r *actionRepo) AbortRun(ctx context.Context, runID *common.RunIdentifier, reason string, abortedBy *common.EnrichedIdentity) error { now := time.Now() - _, err := r.db.ExecContext(ctx, - `UPDATE actions SET phase = $1, updated_at = $2, abort_requested_at = $3, abort_attempt_count = $4, abort_reason = $5 - WHERE project = $6 AND domain = $7 AND run_name = $8 AND parent_action_name IS NULL`, + var rootName string + err := r.db.QueryRowxContext(ctx, + `UPDATE actions SET phase = $1, updated_at = $2, abort_requested_at = $3, abort_attempt_count = $4, abort_reason = $5, + ended_at = COALESCE(ended_at, GREATEST($2, created_at)), + duration_ms = EXTRACT(EPOCH FROM (COALESCE(ended_at, GREATEST($2, created_at)) - created_at)) * 1000 + WHERE project = $6 AND domain = $7 AND run_name = $8 AND parent_action_name IS NULL + RETURNING name`, int32(common.ActionPhase_ACTION_PHASE_ABORTED), now, now, 0, reason, - runID.Project, runID.Domain, runID.Name) - + runID.Project, runID.Domain, runID.Name, + ).Scan(&rootName) + if errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("run not found: %w", sql.ErrNoRows) + } if err != nil { return fmt.Errorf("failed to abort run: %w", err) } - // Notify run subscribers. + rootID := &common.ActionIdentifier{Run: runID, Name: rootName} r.notifyRunUpdate(ctx, runID) + r.notifyActionUpdate(ctx, rootID) logger.Infof(ctx, "Aborted run: %s/%s/%s", runID.Project, runID.Domain, runID.Name) return nil @@ -431,7 +441,6 @@ func (r *actionRepo) UpdateActionPhase( return err } if rowsAffected > 0 { - // Notify subscribers of the action update r.notifyActionUpdate(ctx, actionID) } @@ -444,24 +453,29 @@ func (r *actionRepo) UpdateActionPhase( return nil } -// AbortAction aborts a specific action +// AbortAction marks only the targeted action as ABORTED and sets abort_requested_at. +// K8s cascades CRD deletion to descendants via OwnerReferences; the action service +// informer handles marking them ABORTED in DB when their CRDs are deleted. func (r *actionRepo) AbortAction(ctx context.Context, actionID *common.ActionIdentifier, reason string, abortedBy *common.EnrichedIdentity) error { now := time.Now() - _, err := r.db.ExecContext(ctx, - `UPDATE actions SET phase = $1, updated_at = $2, abort_requested_at = $3, abort_attempt_count = $4, abort_reason = $5 + result, err := r.db.ExecContext(ctx, + `UPDATE actions SET phase = $1, updated_at = $2, abort_requested_at = $3, abort_attempt_count = $4, abort_reason = $5, + ended_at = COALESCE(ended_at, GREATEST($2, created_at)), + duration_ms = EXTRACT(EPOCH FROM (COALESCE(ended_at, GREATEST($2, created_at)) - created_at)) * 1000 WHERE project = $6 AND domain = $7 AND run_name = $8 AND name = $9`, int32(common.ActionPhase_ACTION_PHASE_ABORTED), now, now, 0, reason, actionID.Run.Project, actionID.Run.Domain, actionID.Run.Name, actionID.Name) - if err != nil { return fmt.Errorf("failed to abort action: %w", err) } + if n, _ := result.RowsAffected(); n == 0 { + return fmt.Errorf("action not found: %w", sql.ErrNoRows) + } - // Notify action subscribers. r.notifyActionUpdate(ctx, actionID) - logger.Infof(ctx, "Aborted action: %s", actionID.Name) + logger.Infof(ctx, "AbortAction: aborted %s", actionID.Name) return nil } @@ -495,7 +509,7 @@ func (r *actionRepo) MarkAbortAttempt(ctx context.Context, actionID *common.Acti // ClearAbortRequest clears abort_requested_at (and resets counters) once the pod is confirmed terminated. func (r *actionRepo) ClearAbortRequest(ctx context.Context, actionID *common.ActionIdentifier) error { _, err := r.db.ExecContext(ctx, - `UPDATE actions SET abort_requested_at = NULL, abort_attempt_count = 0, abort_reason = NULL, updated_at = $1 + `UPDATE actions SET abort_requested_at = NULL, abort_attempt_count = 0, updated_at = $1 WHERE project = $2 AND domain = $3 AND run_name = $4 AND name = $5`, time.Now(), actionID.Run.Project, actionID.Run.Domain, actionID.Run.Name, actionID.Name) if err != nil { @@ -856,7 +870,7 @@ func (r *actionRepo) processNotifications() { select { case ch <- notif.Extra: default: - logger.Warnf(context.Background(), "Action subscriber channel full, dropping notification") + logger.Warnf(context.Background(), "Action subscriber channel full, dropping notification payload=%s", notif.Extra) } } r.mu.RUnlock() diff --git a/runs/repository/impl/action_test.go b/runs/repository/impl/action_test.go index 5fc6a491bd..019ba8d8c0 100644 --- a/runs/repository/impl/action_test.go +++ b/runs/repository/impl/action_test.go @@ -663,3 +663,39 @@ func TestInsertEvents_WithLogContext(t *testing.T) { require.NoError(t, err) assert.Equal(t, "my-pod", deserialized.GetLogContext().GetPrimaryPodName()) } + +// TestUpdateActionPhase_AbortedDoesNotInsertEvent verifies that transitioning an +// action to ABORTED updates the phase column but does NOT insert a synthetic row +// into action_events. The abort event is now emitted by the controller via +// RecordActionEvents before the TaskAction finalizer is removed. +func TestUpdateActionPhase_AbortedDoesNotInsertEvent(t *testing.T) { + db := setupActionDB(t) + actionRepo, err := NewActionRepo(db, testDbConfig) + require.NoError(t, err) + ctx := context.Background() + + actionID := &common.ActionIdentifier{ + Run: &common.RunIdentifier{Project: "p", Domain: "d", Name: "run-abort"}, + Name: "abort-action", + } + _, err = actionRepo.CreateAction(ctx, models.NewActionModel(actionID), false) + require.NoError(t, err) + + endTime := time.Now() + err = actionRepo.UpdateActionPhase(ctx, actionID, common.ActionPhase_ACTION_PHASE_ABORTED, 1, core.CatalogCacheStatus_CACHE_DISABLED, &endTime) + require.NoError(t, err) + + // Phase column must be updated. + action, err := actionRepo.GetAction(ctx, actionID) + require.NoError(t, err) + assert.Equal(t, int32(common.ActionPhase_ACTION_PHASE_ABORTED), action.Phase) + + // No synthetic event row should have been inserted — the controller now emits the event. + var count int + err = db.QueryRowContext(ctx, + `SELECT COUNT(*) FROM action_events WHERE project=$1 AND domain=$2 AND run_name=$3 AND name=$4`, + actionID.Run.Project, actionID.Run.Domain, actionID.Run.Name, actionID.Name, + ).Scan(&count) + require.NoError(t, err) + assert.Equal(t, 0, count, "UpdateActionPhase(ABORTED) must not insert a synthetic action_events row") +} diff --git a/runs/repository/interfaces/action.go b/runs/repository/interfaces/action.go index c5fd776cb9..48f5b3d73c 100644 --- a/runs/repository/interfaces/action.go +++ b/runs/repository/interfaces/action.go @@ -15,6 +15,9 @@ type ActionRepo interface { // Run operations GetRun(ctx context.Context, runID *common.RunIdentifier) (*models.Run, error) ListRuns(ctx context.Context, req *workflow.ListRunsRequest) ([]*models.Run, string, error) + // AbortRun marks only the root action as ABORTED and sets abort_requested_at on it. + // K8s cascades CRD deletion to child actions via OwnerReferences; the action service + // informer handles marking them ABORTED in DB when their CRDs are deleted. AbortRun(ctx context.Context, runID *common.RunIdentifier, reason string, abortedBy *common.EnrichedIdentity) error // Action operations @@ -26,6 +29,8 @@ type ActionRepo interface { GetAction(ctx context.Context, actionID *common.ActionIdentifier) (*models.Action, error) ListActions(ctx context.Context, runID *common.RunIdentifier, limit int, token string) ([]*models.Action, string, error) UpdateActionPhase(ctx context.Context, actionID *common.ActionIdentifier, phase common.ActionPhase, attempts uint32, cacheStatus core.CatalogCacheStatus, endTime *time.Time) error + // AbortAction marks only the targeted action as ABORTED and sets abort_requested_at. + // K8s cascades CRD deletion to descendants via OwnerReferences. AbortAction(ctx context.Context, actionID *common.ActionIdentifier, reason string, abortedBy *common.EnrichedIdentity) error // Abort reconciliation — used by the background AbortReconciler. diff --git a/runs/service/abort_reconciler.go b/runs/service/abort_reconciler.go index adbd32ddb8..27d91e863e 100644 --- a/runs/service/abort_reconciler.go +++ b/runs/service/abort_reconciler.go @@ -2,8 +2,11 @@ package service import ( "context" + "database/sql" + "errors" "fmt" "math/rand" + "strings" "sync" "time" @@ -217,6 +220,12 @@ func (r *AbortReconciler) runWorker(ctx context.Context) { func (r *AbortReconciler) processTask(ctx context.Context, task abortTask) { attemptCount, err := r.repo.ActionRepo().MarkAbortAttempt(ctx, task.actionID) if err != nil { + if errors.Is(err, sql.ErrNoRows) { + // Action no longer exists in the DB — nothing to abort, drop it. + r.queue.remove(task.key) + logger.Warnf(ctx, "AbortReconciler: action %s not found in DB, dropping abort task", task.key) + return + } logger.Errorf(ctx, "AbortReconciler: failed to mark attempt for %s: %v", task.key, err) // Re-enqueue without counting — the DB row is authoritative; try again later. r.queue.scheduleRequeue(ctx, task, r.cfg.InitialDelay) @@ -264,6 +273,8 @@ func (r *AbortReconciler) processTask(ctx context.Context, task abortTask) { } // isAlreadyTerminated returns true for errors that indicate the action is already gone. +// The actions service may wrap a Kubernetes "not found" error as CodeInternal rather +// than CodeNotFound, so we also check for that case by inspecting the message. func isAlreadyTerminated(err error) bool { if err == nil { return false @@ -272,5 +283,14 @@ func isAlreadyTerminated(err error) bool { if !ok { return false } - return connectErr.Code() == connect.CodeNotFound + if connectErr.Code() == connect.CodeNotFound { + return true + } + // The actions service forwards Kubernetes API "not found" errors with CodeInternal. + // Treat those as "already gone" so the reconciler clears the DB entry instead of + // retrying indefinitely. + if connectErr.Code() == connect.CodeInternal && strings.Contains(connectErr.Message(), "not found") { + return true + } + return false } diff --git a/runs/service/run_service.go b/runs/service/run_service.go index abcbabeada..b9f8d7cdd2 100644 --- a/runs/service/run_service.go +++ b/runs/service/run_service.go @@ -387,14 +387,19 @@ func (s *RunService) AbortRun( reason = *req.Msg.Reason } - // Abort in database, then push to reconciler for background pod termination. + // Mark only the root action ABORTED in DB, then push it to the reconciler. + // The reconciler deletes "a0"'s CRD; K8s cascades deletion to all child CRDs + // via OwnerReferences, and the action service informer marks them ABORTED in DB. if err := s.repo.ActionRepo().AbortRun(ctx, req.Msg.RunId, reason, nil); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("run not found: %s/%s/%s", req.Msg.RunId.Project, req.Msg.RunId.Domain, req.Msg.RunId.Name)) + } logger.Errorf(ctx, "Failed to abort run: %v", err) return nil, connect.NewError(connect.CodeInternal, err) } if s.abortReconciler != nil { - s.abortReconciler.Push(ctx, &common.ActionIdentifier{Run: req.Msg.RunId, Name: req.Msg.RunId.Name}, reason) + s.abortReconciler.Push(ctx, &common.ActionIdentifier{Run: req.Msg.RunId, Name: "a0"}, reason) } return connect.NewResponse(&workflow.AbortRunResponse{}), nil @@ -553,7 +558,13 @@ func (s *RunService) buildActionDetails(ctx context.Context, model *models.Actio } } case common.ActionPhase_ACTION_PHASE_ABORTED: - // TODO: set AbortInfo + abortInfo := &workflow.AbortInfo{} + if model.AbortReason != nil { + abortInfo.Reason = *model.AbortReason + } + action.Result = &workflow.ActionDetails_AbortInfo{ + AbortInfo: abortInfo, + } } return action, nil @@ -711,6 +722,22 @@ func IsTerminalPhase(phase common.ActionPhase) bool { phase == common.ActionPhase_ACTION_PHASE_ABORTED } +// lastAttemptIsTerminal returns true when the highest-numbered attempt has reached a +// terminal phase. Used by WatchActionDetails to close the stream only after action_events +// reflects the terminal transition, not just the actions table. +func lastAttemptIsTerminal(attempts []*workflow.ActionAttempt) bool { + if len(attempts) == 0 { + return false + } + var last *workflow.ActionAttempt + for _, a := range attempts { + if last == nil || a.GetAttempt() > last.GetAttempt() { + last = a + } + } + return IsTerminalPhase(last.GetPhase()) +} + // GetActionData gets input and output data for an action by reading from storage. func (s *RunService) GetActionData( ctx context.Context, @@ -982,6 +1009,9 @@ func (s *RunService) AbortAction( // Abort in database, then push to reconciler for background pod termination. if err := s.repo.ActionRepo().AbortAction(ctx, req.Msg.ActionId, reason, nil); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("action not found: %s", req.Msg.ActionId.Name)) + } logger.Errorf(ctx, "Failed to abort action: %v", err) return nil, connect.NewError(connect.CodeInternal, err) } @@ -1055,6 +1085,12 @@ func (s *RunService) WatchActionDetails( actionID := req.Msg.ActionId logger.Infof(ctx, "Received WatchActionDetails request for: %s/%s", actionID.Run.Name, actionID.Name) + // Subscribe FIRST to avoid missing notifications that fire between the DB + // read and the subscription setup (same pattern as WatchActions). + updates := make(chan *models.Action, 50) + errs := make(chan error, 1) + go s.repo.ActionRepo().WatchActionUpdates(ctx, actionID, updates, errs) + // Step 1: Send initial state from DB details, err := s.getActionDetails(ctx, req.Msg.GetActionId()) if err != nil { @@ -1067,16 +1103,11 @@ func (s *RunService) WatchActionDetails( return err } - // If the action is already in a terminal phase, no further updates are expected. - if IsTerminalPhase(details.GetStatus().GetPhase()) { + // Close only once action_events reflects the terminal phase, not just actions table. + if lastAttemptIsTerminal(details.GetAttempts()) { return nil } - // Step 2: Watch DB for updates - updates := make(chan *models.Action, 50) - errs := make(chan error, 1) - go s.repo.ActionRepo().WatchActionUpdates(ctx, actionID, updates, errs) - for { select { case <-ctx.Done(): @@ -1101,8 +1132,8 @@ func (s *RunService) WatchActionDetails( }); err != nil { return err } - // Close the stream once the action reaches a terminal phase. - if IsTerminalPhase(details.GetStatus().GetPhase()) { + // Close once action_events reflects the terminal phase. + if lastAttemptIsTerminal(details.GetAttempts()) { return nil } }