Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions actions/k8s/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,18 @@ func (c *ActionsClient) setupInformer(ctx context.Context) error {
}
c.dispatchEvent(taskAction, watch.Modified)
},
DeleteFunc: func(obj interface{}) {
// The informer may deliver a DeletedFinalStateUnknown tombstone
// when a delete event was missed; unwrap it first.
if tombstone, ok := obj.(toolscache.DeletedFinalStateUnknown); ok {
obj = tombstone.Obj
}
taskAction, ok := obj.(*executorv1.TaskAction)
if !ok {
return
}
c.dispatchEvent(taskAction, watch.Deleted)
},
})
if err != nil {
return fmt.Errorf("failed to add TaskAction informer handler: %w", err)
Expand Down Expand Up @@ -494,6 +506,11 @@ func buildActionUpdate(ctx context.Context, taskAction *executorv1.TaskAction, e
shortName = extractShortNameFromTemplate(taskAction.Spec.TaskTemplate)
}

phase := GetPhaseFromConditions(taskAction)
if eventType == watch.Deleted {
phase = common.ActionPhase_ACTION_PHASE_ABORTED
}

return &ActionUpdate{
ActionID: &common.ActionIdentifier{
Run: &common.RunIdentifier{
Expand All @@ -505,7 +522,7 @@ func buildActionUpdate(ctx context.Context, taskAction *executorv1.TaskAction, e
},
ParentActionName: parentName,
StateJSON: taskAction.Status.StateJSON,
Phase: GetPhaseFromConditions(taskAction),
Phase: phase,
OutputUri: buildOutputUri(ctx, taskAction),
IsDeleted: eventType == watch.Deleted,
TaskType: taskAction.Spec.TaskType,
Expand Down Expand Up @@ -536,7 +553,7 @@ func (c *ActionsClient) notifySubscribers(ctx context.Context, update *ActionUpd

// notifyRunService forwards a watch event to the internal run service.
// On ADDED events it calls RecordAction to create the DB record.
// On all events it calls UpdateActionStatus (when phase is meaningful) and RecordActionEvents.
// On all events it calls UpdateActionStatus (when phase is meaningful) to update the actions table.
func (c *ActionsClient) notifyRunService(ctx context.Context, taskAction *executorv1.TaskAction, update *ActionUpdate, eventType watch.EventType) {
if c.runClient == nil {
return
Expand Down Expand Up @@ -618,7 +635,9 @@ func (c *ActionsClient) notifyRunService(ctx context.Context, taskAction *execut
}
if _, err := c.runClient.UpdateActionStatus(ctx, connect.NewRequest(statusReq)); err != nil {
logger.Warnf(ctx, "Failed to update action status in run service for %s: %v", update.ActionID.Name, err)
} else if isTerminalPhase(update.Phase) {
} else if isTerminalPhase(update.Phase) && !update.IsDeleted {
// Skip label patching for deleted CRs — the patch would always fail
// with "not found" since the object is already gone.
if err := c.markTerminalStatusRecorded(ctx, taskAction); err != nil {
logger.Warnf(ctx, "Failed to mark terminal status recorded for %s: %v", update.ActionID.Name, err)
}
Expand Down
4 changes: 2 additions & 2 deletions buf.lock
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ deps:
commit: 62f35d8aed1149c291d606d958a7ce32
digest: b5:d66bf04adc77a0870bdc9328aaf887c7188a36fb02b83a480dc45ef9dc031b4d39fc6e9dc6435120ccf4fe5bfd5c6cb6592533c6c316595571f9a31420ab47fe
- name: buf.build/grpc-ecosystem/grpc-gateway
commit: 6467306b4f624747aaf6266762ee7a1c
digest: b5:c2caa61467d992749812c909f93c07e9a667da33c758a7c1973d63136c23b3cafcc079985b12cdf54a10049ed3297418f1eda42cdffdcf34113792dcc3a990af
commit: 4836b6d552304e1bbe47e66a523f0daa
digest: b5:c3fefd4d3dfa9b0478bbb1a4ad87d7b38146e3ce6eff4214e32f2c5834c2e4afc3be218316f0fbd53e925a001c3ed1e2fc99fb76b3121ede642989f0d0d7c71c
13 changes: 13 additions & 0 deletions executor/pkg/controller/taskaction_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,19 @@ func (r *TaskActionReconciler) handleAbortAndFinalize(ctx context.Context, taskA
}
}

abortTime := time.Now()
abortPhaseInfo := pluginsCore.PhaseInfoAborted(abortTime, pluginsCore.DefaultPhaseVersion, "aborted")
actionEvent := r.buildActionEvent(ctx, taskAction, abortPhaseInfo)
// buildActionEvent derives UpdatedTime from PhaseHistory, which doesn't include the
// abort transition. Override it so mergeEvents uses the actual abort time as end_time.
actionEvent.UpdatedTime = timestamppb.New(abortTime)
if _, err := r.eventsClient.Record(ctx, connect.NewRequest(&workflow.RecordRequest{
Events: []*workflow.ActionEvent{actionEvent},
})); err != nil {
logger.Error(err, "failed to emit abort event, will retry")
return ctrl.Result{RequeueAfter: TaskActionDefaultRequeueDuration}, nil
}

return r.removeFinalizer(ctx, taskAction)
}

Expand Down
94 changes: 94 additions & 0 deletions executor/pkg/controller/taskaction_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package controller

import (
"context"
"sync"
"time"

"connectrpc.com/connect"
Expand All @@ -34,6 +35,7 @@ import (
flyteorgv1 "github.com/flyteorg/flyte/v2/executor/api/v1"
pluginsCore "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/core"
k8sPlugin "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/k8s"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/common"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow"
)
Expand All @@ -45,6 +47,27 @@ func (f *fakeEventsClient) Record(_ context.Context, _ *connect.Request[workflow
return connect.NewResponse(&workflow.RecordResponse{}), nil
}

// recordingEventsClient is a test double for the reconciler's events client.
// It stores every ActionEvent passed to Record so that tests can assert on
// what was emitted.
type recordingEventsClient struct {
// mu guards events; both Record and RecordedEvents acquire it.
mu sync.Mutex
// events accumulates the recorded ActionEvents in the order Record received them.
events []*workflow.ActionEvent
}

// Record appends every event carried by req to the in-memory log and replies
// with an empty RecordResponse. It never returns an error.
func (r *recordingEventsClient) Record(_ context.Context, req *connect.Request[workflow.RecordRequest]) (*connect.Response[workflow.RecordResponse], error) {
	incoming := req.Msg.GetEvents()

	r.mu.Lock()
	r.events = append(r.events, incoming...)
	r.mu.Unlock()

	reply := &workflow.RecordResponse{}
	return connect.NewResponse(reply), nil
}

// RecordedEvents returns a snapshot of the events captured so far. The returned
// slice is a fresh copy, so callers may keep or modify it without affecting the
// recorder's own log.
func (r *recordingEventsClient) RecordedEvents() []*workflow.ActionEvent {
	r.mu.Lock()
	snapshot := make([]*workflow.ActionEvent, len(r.events))
	for i, ev := range r.events {
		snapshot[i] = ev
	}
	r.mu.Unlock()

	return snapshot
}

// buildTaskTemplateBytes creates a minimal protobuf-serialized TaskTemplate
// with a container spec that the pod plugin can use to build a Pod.
func buildTaskTemplateBytes(taskType, image string) []byte {
Expand Down Expand Up @@ -285,6 +308,77 @@ var _ = Describe("TaskAction Controller", func() {
})
})

Context("When a TaskAction is deleted (abort flow)", func() {
const abortResourceName = "abort-test-resource"

ctx := context.Background()

typeNamespacedName := types.NamespacedName{
Name: abortResourceName,
Namespace: "default",
}

BeforeEach(func() {
resource := &flyteorgv1.TaskAction{
ObjectMeta: metav1.ObjectMeta{
Name: abortResourceName,
Namespace: "default",
Finalizers: []string{taskActionFinalizer},
},
Spec: flyteorgv1.TaskActionSpec{
RunName: "abort-run",
Project: "abort-project",
Domain: "abort-domain",
ActionName: "abort-action",
InputURI: "/tmp/input",
RunOutputBase: "/tmp/output",
TaskType: "python-task",
TaskTemplate: buildTaskTemplateBytes("python-task", "python:3.11"),
},
}
Expect(k8sClient.Create(ctx, resource)).To(Succeed())
Expect(k8sClient.Delete(ctx, resource)).To(Succeed())
})

AfterEach(func() {
resource := &flyteorgv1.TaskAction{}
err := k8sClient.Get(ctx, typeNamespacedName, resource)
if err == nil {
resource.Finalizers = nil
Expect(k8sClient.Update(ctx, resource)).To(Succeed())
}
})

It("should emit an ACTION_PHASE_ABORTED event before removing the finalizer", func() {
recorder := &recordingEventsClient{}
reconciler := &TaskActionReconciler{
Client: k8sClient,
Scheme: k8sClient.Scheme(),
Recorder: record.NewFakeRecorder(10),
PluginRegistry: pluginRegistry,
DataStore: dataStore,
eventsClient: recorder,
}

result, err := reconciler.Reconcile(ctx, reconcile.Request{NamespacedName: typeNamespacedName})
Expect(err).NotTo(HaveOccurred())
Expect(result.RequeueAfter).To(BeZero())

// Finalizer should have been removed — object is gone.
deleted := &flyteorgv1.TaskAction{}
Expect(k8sClient.Get(ctx, typeNamespacedName, deleted)).NotTo(Succeed())

// An ABORTED event must have been emitted.
recorded := recorder.RecordedEvents()
Expect(recorded).NotTo(BeEmpty())
phases := make([]interface{}, len(recorded))
for i, e := range recorded {
phases[i] = e.GetPhase()
}
Expect(phases).To(ContainElement(common.ActionPhase_ACTION_PHASE_ABORTED))
})
})

Context("toClusterEvents", func() {
It("should include both phase reason and additional reasons", func() {
phaseOccurredAt := time.Date(2026, 4, 2, 10, 0, 0, 0, time.UTC)
Expand Down
6 changes: 6 additions & 0 deletions flyteplugins/go/tasks/pluginmachinery/core/phase.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,12 @@ func PhaseInfoSuccess(info *TaskInfo) PhaseInfo {
return phaseInfo(PhaseSuccess, DefaultPhaseVersion, nil, info, false)
}

// PhaseInfoAborted returns a PhaseInfo in PhaseAborted at the given phase
// version. The attached TaskInfo has OccurredAt set to t, and the PhaseInfo's
// reason is set to reason.
func PhaseInfoAborted(t time.Time, version uint32, reason string) PhaseInfo {
	occurred := &TaskInfo{OccurredAt: &t}

	aborted := phaseInfo(PhaseAborted, version, nil, occurred, false)
	aborted.reason = reason
	return aborted
}

// PhaseInfoSystemFailure returns a PhasePermanentFailure PhaseInfo built by
// phaseInfoFailed. It carries an ExecutionError of kind SYSTEM whose Code is
// code and whose Message is reason, together with the supplied TaskInfo.
func PhaseInfoSystemFailure(code, reason string, info *TaskInfo) PhaseInfo {
	systemErr := &core.ExecutionError{
		Code:    code,
		Message: reason,
		Kind:    core.ExecutionError_SYSTEM,
	}
	return phaseInfoFailed(PhasePermanentFailure, systemErr, info, false)
}
Expand Down
46 changes: 30 additions & 16 deletions runs/repository/impl/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,22 +136,32 @@ func (r *actionRepo) ListRuns(ctx context.Context, req *workflow.ListRunsRequest
return runs, nextToken, nil
}

// AbortRun aborts a run and all its actions
// AbortRun marks only the root action as ABORTED and sets abort_requested_at on it.
// K8s cascades CRD deletion to child actions via OwnerReferences; the action service
// informer handles marking them ABORTED in DB when their CRDs are deleted.
func (r *actionRepo) AbortRun(ctx context.Context, runID *common.RunIdentifier, reason string, abortedBy *common.EnrichedIdentity) error {
now := time.Now()

_, err := r.db.ExecContext(ctx,
`UPDATE actions SET phase = $1, updated_at = $2, abort_requested_at = $3, abort_attempt_count = $4, abort_reason = $5
WHERE project = $6 AND domain = $7 AND run_name = $8 AND parent_action_name IS NULL`,
var rootName string
err := r.db.QueryRowxContext(ctx,
`UPDATE actions SET phase = $1, updated_at = $2, abort_requested_at = $3, abort_attempt_count = $4, abort_reason = $5,
ended_at = COALESCE(ended_at, GREATEST($2, created_at)),
duration_ms = EXTRACT(EPOCH FROM (COALESCE(ended_at, GREATEST($2, created_at)) - created_at)) * 1000
WHERE project = $6 AND domain = $7 AND run_name = $8 AND parent_action_name IS NULL
RETURNING name`,
int32(common.ActionPhase_ACTION_PHASE_ABORTED), now, now, 0, reason,
runID.Project, runID.Domain, runID.Name)

runID.Project, runID.Domain, runID.Name,
).Scan(&rootName)
if errors.Is(err, sql.ErrNoRows) {
return fmt.Errorf("run not found: %w", sql.ErrNoRows)
}
if err != nil {
return fmt.Errorf("failed to abort run: %w", err)
}

// Notify run subscribers.
rootID := &common.ActionIdentifier{Run: runID, Name: rootName}
r.notifyRunUpdate(ctx, runID)
r.notifyActionUpdate(ctx, rootID)

logger.Infof(ctx, "Aborted run: %s/%s/%s", runID.Project, runID.Domain, runID.Name)
return nil
Expand Down Expand Up @@ -431,7 +441,6 @@ func (r *actionRepo) UpdateActionPhase(
return err
}
if rowsAffected > 0 {
// Notify subscribers of the action update
r.notifyActionUpdate(ctx, actionID)
}

Expand All @@ -444,24 +453,29 @@ func (r *actionRepo) UpdateActionPhase(
return nil
}

// AbortAction aborts a specific action
// AbortAction marks only the targeted action as ABORTED and sets abort_requested_at.
// K8s cascades CRD deletion to descendants via OwnerReferences; the action service
// informer handles marking them ABORTED in DB when their CRDs are deleted.
func (r *actionRepo) AbortAction(ctx context.Context, actionID *common.ActionIdentifier, reason string, abortedBy *common.EnrichedIdentity) error {
now := time.Now()

_, err := r.db.ExecContext(ctx,
`UPDATE actions SET phase = $1, updated_at = $2, abort_requested_at = $3, abort_attempt_count = $4, abort_reason = $5
result, err := r.db.ExecContext(ctx,
`UPDATE actions SET phase = $1, updated_at = $2, abort_requested_at = $3, abort_attempt_count = $4, abort_reason = $5,
ended_at = COALESCE(ended_at, GREATEST($2, created_at)),
duration_ms = EXTRACT(EPOCH FROM (COALESCE(ended_at, GREATEST($2, created_at)) - created_at)) * 1000
WHERE project = $6 AND domain = $7 AND run_name = $8 AND name = $9`,
int32(common.ActionPhase_ACTION_PHASE_ABORTED), now, now, 0, reason,
actionID.Run.Project, actionID.Run.Domain, actionID.Run.Name, actionID.Name)

if err != nil {
return fmt.Errorf("failed to abort action: %w", err)
}
if n, _ := result.RowsAffected(); n == 0 {
return fmt.Errorf("action not found: %w", sql.ErrNoRows)
}

// Notify action subscribers.
r.notifyActionUpdate(ctx, actionID)

logger.Infof(ctx, "Aborted action: %s", actionID.Name)
logger.Infof(ctx, "AbortAction: aborted %s", actionID.Name)
return nil
}

Expand Down Expand Up @@ -495,7 +509,7 @@ func (r *actionRepo) MarkAbortAttempt(ctx context.Context, actionID *common.Acti
// ClearAbortRequest clears abort_requested_at (and resets counters) once the pod is confirmed terminated.
func (r *actionRepo) ClearAbortRequest(ctx context.Context, actionID *common.ActionIdentifier) error {
_, err := r.db.ExecContext(ctx,
`UPDATE actions SET abort_requested_at = NULL, abort_attempt_count = 0, abort_reason = NULL, updated_at = $1
`UPDATE actions SET abort_requested_at = NULL, abort_attempt_count = 0, updated_at = $1
WHERE project = $2 AND domain = $3 AND run_name = $4 AND name = $5`,
time.Now(), actionID.Run.Project, actionID.Run.Domain, actionID.Run.Name, actionID.Name)
if err != nil {
Expand Down Expand Up @@ -856,7 +870,7 @@ func (r *actionRepo) processNotifications() {
select {
case ch <- notif.Extra:
default:
logger.Warnf(context.Background(), "Action subscriber channel full, dropping notification")
logger.Warnf(context.Background(), "Action subscriber channel full, dropping notification payload=%s", notif.Extra)
}
}
r.mu.RUnlock()
Expand Down
36 changes: 36 additions & 0 deletions runs/repository/impl/action_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -663,3 +663,39 @@ func TestInsertEvents_WithLogContext(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, "my-pod", deserialized.GetLogContext().GetPrimaryPodName())
}

// TestUpdateActionPhase_AbortedDoesNotInsertEvent checks that moving an action
// to ABORTED through UpdateActionPhase changes the stored phase column without
// writing any row to action_events. This repository path must not create a
// synthetic abort event.
func TestUpdateActionPhase_AbortedDoesNotInsertEvent(t *testing.T) {
	const countEventsQuery = `SELECT COUNT(*) FROM action_events WHERE project=$1 AND domain=$2 AND run_name=$3 AND name=$4`

	ctx := context.Background()
	db := setupActionDB(t)
	repo, err := NewActionRepo(db, testDbConfig)
	require.NoError(t, err)

	// Seed a single action that the phase update will target.
	runID := &common.RunIdentifier{Project: "p", Domain: "d", Name: "run-abort"}
	id := &common.ActionIdentifier{Run: runID, Name: "abort-action"}
	_, err = repo.CreateAction(ctx, models.NewActionModel(id), false)
	require.NoError(t, err)

	abortedAt := time.Now()
	err = repo.UpdateActionPhase(ctx, id, common.ActionPhase_ACTION_PHASE_ABORTED, 1, core.CatalogCacheStatus_CACHE_DISABLED, &abortedAt)
	require.NoError(t, err)

	// The phase column must reflect the abort.
	stored, err := repo.GetAction(ctx, id)
	require.NoError(t, err)
	assert.Equal(t, int32(common.ActionPhase_ACTION_PHASE_ABORTED), stored.Phase)

	// The events table must still be empty for this action.
	var eventRows int
	row := db.QueryRowContext(ctx, countEventsQuery, runID.Project, runID.Domain, runID.Name, id.Name)
	require.NoError(t, row.Scan(&eventRows))
	assert.Equal(t, 0, eventRows, "UpdateActionPhase(ABORTED) must not insert a synthetic action_events row")
}
5 changes: 5 additions & 0 deletions runs/repository/interfaces/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ type ActionRepo interface {
// Run operations
GetRun(ctx context.Context, runID *common.RunIdentifier) (*models.Run, error)
ListRuns(ctx context.Context, req *workflow.ListRunsRequest) ([]*models.Run, string, error)
// AbortRun marks only the root action as ABORTED and sets abort_requested_at on it.
// K8s cascades CRD deletion to child actions via OwnerReferences; the action service
// informer handles marking them ABORTED in DB when their CRDs are deleted.
AbortRun(ctx context.Context, runID *common.RunIdentifier, reason string, abortedBy *common.EnrichedIdentity) error

// Action operations
Expand All @@ -26,6 +29,8 @@ type ActionRepo interface {
GetAction(ctx context.Context, actionID *common.ActionIdentifier) (*models.Action, error)
ListActions(ctx context.Context, runID *common.RunIdentifier, limit int, token string) ([]*models.Action, string, error)
UpdateActionPhase(ctx context.Context, actionID *common.ActionIdentifier, phase common.ActionPhase, attempts uint32, cacheStatus core.CatalogCacheStatus, endTime *time.Time) error
// AbortAction marks only the targeted action as ABORTED and sets abort_requested_at.
// K8s cascades CRD deletion to descendants via OwnerReferences.
AbortAction(ctx context.Context, actionID *common.ActionIdentifier, reason string, abortedBy *common.EnrichedIdentity) error

// Abort reconciliation — used by the background AbortReconciler.
Expand Down
Loading
Loading