From fe3c9add15f7372193a60b88c8db7fd8026d15d1 Mon Sep 17 00:00:00 2001 From: Spyros Trigazis Date: Mon, 20 Apr 2026 11:44:34 +0300 Subject: [PATCH] Fix type error for suspended kubeflow jobs When kubeflow jobs are suspended by an external system (eg kueue), the controller error's with the following [0]. https://github.com/flyteorg/flyte/pull/6295/changes relies only on app.Spec.RunPolicy.Suspend and does not take into account the phase. [0] RuntimeExecutionError: failed during plugin execution, caused by: Invalid transition for plugin [pytorch]: transition doesn't have task info nor an execution error filled [TransitionTypeEphemeral,Phase Reason:>] Signed-off-by: Spyros Trigazis --- .../k8s/kfoperators/common/common_operator.go | 4 ++++ .../kfoperators/common/common_operator_test.go | 18 ++++++++++++++++++ .../plugins/k8s/kfoperators/mpi/mpi_test.go | 2 +- .../k8s/kfoperators/pytorch/pytorch_test.go | 2 +- .../kfoperators/tensorflow/tensorflow_test.go | 2 +- 5 files changed, 25 insertions(+), 3 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go index 9da1484c077..b7a1099e287 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go @@ -63,6 +63,8 @@ func GetPhaseInfo(currentCondition kubeflowv1.JobCondition, occurredAt time.Time return pluginsCore.PhaseInfoRetryableFailure(flyteerr.DownstreamSystemError, details, &taskPhaseInfo), nil case kubeflowv1.JobRestarting: return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, &taskPhaseInfo), nil + case kubeflowv1.JobSuspended: + return pluginsCore.PhaseInfoQueuedWithTaskInfo(pluginsCore.DefaultPhaseVersion, "Suspended", &taskPhaseInfo), nil } return pluginsCore.PhaseInfoUndefined, nil @@ -83,6 +85,8 @@ func GetMPIPhaseInfo(currentCondition kubeflowv1.JobCondition, occurredAt time.T return pluginsCore.PhaseInfoRetryableFailure(flyteerr.DownstreamSystemError, details, &taskPhaseInfo), nil case kubeflowv1.JobRestarting: return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, &taskPhaseInfo), nil + case kubeflowv1.JobSuspended: + return pluginsCore.PhaseInfoQueuedWithTaskInfo(pluginsCore.DefaultPhaseVersion, "Suspended", &taskPhaseInfo), nil } return pluginsCore.PhaseInfoUndefined, nil diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go index cdb5d35480d..9e0b28400cb 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go @@ -114,6 +114,15 @@ func TestGetPhaseInfo(t *testing.T) { assert.Equal(t, pluginsCore.PhaseRunning, taskPhase.Phase()) assert.NotNil(t, taskPhase.Info()) assert.Nil(t, err) + + jobSuspended := kubeflowv1.JobCondition{ + Type: kubeflowv1.JobSuspended, + } + taskPhase, err = GetPhaseInfo(jobSuspended, time.Now(), pluginsCore.TaskInfo{}) + assert.NoError(t, err) + assert.Equal(t, pluginsCore.PhaseQueued, taskPhase.Phase()) + assert.NotNil(t, taskPhase.Info()) + assert.Nil(t, err) } func TestGetMPIPhaseInfo(t *testing.T) { @@ -161,6 +170,15 @@ func TestGetMPIPhaseInfo(t *testing.T) { assert.Equal(t, pluginsCore.PhaseRunning, taskPhase.Phase()) assert.NotNil(t, taskPhase.Info()) assert.Nil(t, err) + + jobSuspended := kubeflowv1.JobCondition{ + Type: kubeflowv1.JobSuspended, + } + taskPhase, err = GetMPIPhaseInfo(jobSuspended, time.Now(), pluginsCore.TaskInfo{}) + assert.NoError(t, err) + assert.Equal(t, pluginsCore.PhaseQueued, taskPhase.Phase()) + assert.NotNil(t, taskPhase.Info()) + assert.Nil(t, err) } func TestGetLogs(t *testing.T) { diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go index 9eeaa8fa58e..bdc1187bc7d 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go @@ -556,7 +556,7 @@ func TestGetTaskPhase(t *testing.T) { assert.Equal(t, pluginsCore.PhaseInfoUndefined, taskPhase) // Training operator did not modify the job because it is suspended - mpiJobSuspended := dummyMPIJobResourceCreator(kubeflowv1.JobCreated) + mpiJobSuspended := dummyMPIJobResourceCreator(kubeflowv1.JobSuspended) mpiJobSuspended.CreationTimestamp = v1.Time{Time: time.Now().Add(-time.Hour)} mpiJobSuspended.Status.StartTime = nil suspend := true diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go index 31eb7778bd6..52631ac481c 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go @@ -677,7 +677,7 @@ func TestGetTaskPhase(t *testing.T) { assert.Equal(t, pluginsCore.PhaseInfoUndefined, taskPhase) // Training operator did not modify the job because it is suspended - pytorchJobSuspended := dummyPytorchJobResourceCreator(kubeflowv1.JobCreated) + pytorchJobSuspended := dummyPytorchJobResourceCreator(kubeflowv1.JobSuspended) pytorchJobSuspended.CreationTimestamp = v1.Time{Time: time.Now().Add(-time.Hour)} pytorchJobSuspended.Status.StartTime = nil suspend := true diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go index d76bf799d5d..e0ddb38f9ac 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go @@ -602,7 +602,7 @@ func TestGetTaskPhase(t *testing.T) { assert.Equal(t, pluginsCore.PhaseInfoUndefined, taskPhase) // Training operator did not modify the job because it is suspended - tfJobSuspended := dummyTensorFlowJobResourceCreator(kubeflowv1.JobCreated) + tfJobSuspended := dummyTensorFlowJobResourceCreator(kubeflowv1.JobSuspended) tfJobSuspended.CreationTimestamp = v1.Time{Time: time.Now().Add(-time.Hour)} tfJobSuspended.Status.StartTime = nil suspend := true