diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go index 9da1484c077..b7a1099e287 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go @@ -63,6 +63,8 @@ func GetPhaseInfo(currentCondition kubeflowv1.JobCondition, occurredAt time.Time return pluginsCore.PhaseInfoRetryableFailure(flyteerr.DownstreamSystemError, details, &taskPhaseInfo), nil case kubeflowv1.JobRestarting: return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, &taskPhaseInfo), nil + case kubeflowv1.JobSuspended: + return pluginsCore.PhaseInfoQueuedWithTaskInfo(pluginsCore.DefaultPhaseVersion, "Suspended", &taskPhaseInfo), nil } return pluginsCore.PhaseInfoUndefined, nil @@ -83,6 +85,8 @@ func GetMPIPhaseInfo(currentCondition kubeflowv1.JobCondition, occurredAt time.T return pluginsCore.PhaseInfoRetryableFailure(flyteerr.DownstreamSystemError, details, &taskPhaseInfo), nil case kubeflowv1.JobRestarting: return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, &taskPhaseInfo), nil + case kubeflowv1.JobSuspended: + return pluginsCore.PhaseInfoQueuedWithTaskInfo(pluginsCore.DefaultPhaseVersion, "Suspended", &taskPhaseInfo), nil } return pluginsCore.PhaseInfoUndefined, nil diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go index cdb5d35480d..9e0b28400cb 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go @@ -114,6 +114,15 @@ func TestGetPhaseInfo(t *testing.T) { assert.Equal(t, pluginsCore.PhaseRunning, taskPhase.Phase()) assert.NotNil(t, taskPhase.Info()) assert.Nil(t, err) + + jobSuspended := kubeflowv1.JobCondition{ + Type: kubeflowv1.JobSuspended, + } + taskPhase, err = GetPhaseInfo(jobSuspended, time.Now(), pluginsCore.TaskInfo{}) + assert.NoError(t, err) + assert.Equal(t, pluginsCore.PhaseQueued, taskPhase.Phase()) + assert.NotNil(t, taskPhase.Info()) + assert.Nil(t, err) } func TestGetMPIPhaseInfo(t *testing.T) { @@ -161,6 +170,15 @@ func TestGetMPIPhaseInfo(t *testing.T) { assert.Equal(t, pluginsCore.PhaseRunning, taskPhase.Phase()) assert.NotNil(t, taskPhase.Info()) assert.Nil(t, err) + + jobSuspended := kubeflowv1.JobCondition{ + Type: kubeflowv1.JobSuspended, + } + taskPhase, err = GetMPIPhaseInfo(jobSuspended, time.Now(), pluginsCore.TaskInfo{}) + assert.NoError(t, err) + assert.Equal(t, pluginsCore.PhaseQueued, taskPhase.Phase()) + assert.NotNil(t, taskPhase.Info()) + assert.Nil(t, err) } func TestGetLogs(t *testing.T) { diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go index 9eeaa8fa58e..bdc1187bc7d 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go @@ -556,7 +556,7 @@ func TestGetTaskPhase(t *testing.T) { assert.Equal(t, pluginsCore.PhaseInfoUndefined, taskPhase) // Training operator did not modify the job because it is suspended - mpiJobSuspended := dummyMPIJobResourceCreator(kubeflowv1.JobCreated) + mpiJobSuspended := dummyMPIJobResourceCreator(kubeflowv1.JobSuspended) mpiJobSuspended.CreationTimestamp = v1.Time{Time: time.Now().Add(-time.Hour)} mpiJobSuspended.Status.StartTime = nil suspend := true diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go index 31eb7778bd6..52631ac481c 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go @@ -677,7 +677,7 @@ func TestGetTaskPhase(t *testing.T) { assert.Equal(t, pluginsCore.PhaseInfoUndefined, taskPhase) // Training operator did not modify the job because it is suspended - pytorchJobSuspended := dummyPytorchJobResourceCreator(kubeflowv1.JobCreated) + pytorchJobSuspended := dummyPytorchJobResourceCreator(kubeflowv1.JobSuspended) pytorchJobSuspended.CreationTimestamp = v1.Time{Time: time.Now().Add(-time.Hour)} pytorchJobSuspended.Status.StartTime = nil suspend := true diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go index d76bf799d5d..e0ddb38f9ac 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go @@ -602,7 +602,7 @@ func TestGetTaskPhase(t *testing.T) { assert.Equal(t, pluginsCore.PhaseInfoUndefined, taskPhase) // Training operator did not modify the job because it is suspended - tfJobSuspended := dummyTensorFlowJobResourceCreator(kubeflowv1.JobCreated) + tfJobSuspended := dummyTensorFlowJobResourceCreator(kubeflowv1.JobSuspended) tfJobSuspended.CreationTimestamp = v1.Time{Time: time.Now().Add(-time.Hour)} tfJobSuspended.Status.StartTime = nil suspend := true