diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go index 74021df2be..7e4b32e25e 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go @@ -345,3 +345,11 @@ func ToReplicaSpecWithOverrides(ctx context.Context, taskCtx pluginsCore.TaskExe return replicaSpec, nil } + +func GetReplicaCount(specs map[commonOp.ReplicaType]*commonOp.ReplicaSpec, replicaType commonOp.ReplicaType) *int32 { + if spec, ok := specs[replicaType]; ok && spec.Replicas != nil { + return spec.Replicas + } + + return new(int32) // return 0 as default value +} diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go index 826e83b671..97199025a7 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go @@ -162,8 +162,8 @@ func (mpiOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginContext return pluginsCore.PhaseInfoUndefined, fmt.Errorf("failed to convert resource data type") } - numWorkers = app.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Replicas - numLauncherReplicas = app.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher].Replicas + numWorkers = common.GetReplicaCount(app.Spec.MPIReplicaSpecs, kubeflowv1.MPIJobReplicaTypeWorker) + numLauncherReplicas = common.GetReplicaCount(app.Spec.MPIReplicaSpecs, kubeflowv1.MPIJobReplicaTypeLauncher) taskLogs, err := common.GetLogs(pluginContext, common.MPITaskType, app.ObjectMeta, false, *numWorkers, *numLauncherReplicas, 0, 0) diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go index 1cb6e9d826..d009c7c887 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go @@ -776,3 +776,17 @@ func TestBuildResourceMPIV1ResourceTolerations(t *testing.T) { assert.NotContains(t, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher].Template.Spec.Tolerations, gpuToleration) assert.Contains(t, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Template.Spec.Tolerations, gpuToleration) } + +func TestGetReplicaCount(t *testing.T) { + mpiResourceHandler := mpiOperatorResourceHandler{} + tfObj := dummyMPICustomObj(1, 1, 0) + taskTemplate := dummyMPITaskTemplate("the job", tfObj) + resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil)) + assert.NoError(t, err) + assert.NotNil(t, resource) + MPIJob, ok := resource.(*kubeflowv1.MPIJob) + assert.True(t, ok) + + assert.NotNil(t, common.GetReplicaCount(MPIJob.Spec.MPIReplicaSpecs, kubeflowv1.MPIJobReplicaTypeWorker)) + assert.NotNil(t, common.GetReplicaCount(MPIJob.Spec.MPIReplicaSpecs, kubeflowv1.MPIJobReplicaTypeLauncher)) +} diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go index 81c8e16cd5..6d0bad4ecd 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go @@ -171,7 +171,10 @@ func ParseElasticConfig(elasticConfig ElasticConfig) *kubeflowv1.ElasticPolicy { // any operations that might take a long time (limits are configured system-wide) should be offloaded to the // background. func (pytorchOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) { - app := resource.(*kubeflowv1.PyTorchJob) + app, ok := resource.(*kubeflowv1.PyTorchJob) + if !ok { + return pluginsCore.PhaseInfoUndefined, fmt.Errorf("failed to convert resource data type") + } // Elastic PytorchJobs don't use master replicas hasMaster := false @@ -179,7 +182,7 @@ func (pytorchOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginCont hasMaster = true } - workersCount := app.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas + workersCount := common.GetReplicaCount(app.Spec.PyTorchReplicaSpecs, kubeflowv1.PyTorchJobReplicaTypeWorker) taskLogs, err := common.GetLogs(pluginContext, common.PytorchTaskType, app.ObjectMeta, hasMaster, *workersCount, 0, 0, 0) if err != nil { diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go index f0e215f262..e0606b1020 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go @@ -959,3 +959,16 @@ func TestParseElasticConfig(t *testing.T) { assert.Equal(t, int32(4), *elasticPolicy.NProcPerNode) assert.Equal(t, kubeflowv1.RDZVBackend("c10d"), *elasticPolicy.RDZVBackend) } + +func TestGetReplicaCount(t *testing.T) { + pytorchResourceHandler := pytorchOperatorResourceHandler{} + tfObj := dummyPytorchCustomObj(1) + taskTemplate := dummyPytorchTaskTemplate("the job", tfObj) + resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil)) + assert.NoError(t, err) + assert.NotNil(t, resource) + PytorchJob, ok := resource.(*kubeflowv1.PyTorchJob) + assert.True(t, ok) + + assert.NotNil(t, common.GetReplicaCount(PytorchJob.Spec.PyTorchReplicaSpecs, kubeflowv1.PyTorchJobReplicaTypeWorker)) +} diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go index f44b550ede..db5fe6a83a 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go @@ -139,24 +139,19 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task return job, nil } -func getReplicaCount(specs map[commonOp.ReplicaType]*commonOp.ReplicaSpec, replicaType commonOp.ReplicaType) *int32 { - if spec, ok := specs[replicaType]; ok && spec.Replicas != nil { - return spec.Replicas - } - - return new(int32) // return 0 as default value -} - // Analyses the k8s resource and reports the status as TaskPhase. This call is expected to be relatively fast, // any operations that might take a long time (limits are configured system-wide) should be offloaded to the // background. func (tensorflowOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) { - app := resource.(*kubeflowv1.TFJob) + app, ok := resource.(*kubeflowv1.TFJob) + if !ok { + return pluginsCore.PhaseInfoUndefined, fmt.Errorf("failed to convert resource data type") + } - workersCount := getReplicaCount(app.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeWorker) - psReplicasCount := getReplicaCount(app.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypePS) - chiefCount := getReplicaCount(app.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeChief) - evaluatorReplicasCount := getReplicaCount(app.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeEval) + workersCount := common.GetReplicaCount(app.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeWorker) + psReplicasCount := common.GetReplicaCount(app.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypePS) + chiefCount := common.GetReplicaCount(app.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeChief) + evaluatorReplicasCount := common.GetReplicaCount(app.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeEval) taskLogs, err := common.GetLogs(pluginContext, common.TensorflowTaskType, app.ObjectMeta, false, *workersCount, *psReplicasCount, *chiefCount, *evaluatorReplicasCount) diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go index 7b694ad9d9..c3252183fc 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go @@ -305,10 +305,10 @@ func TestGetReplicaCount(t *testing.T) { tensorflowJob, ok := resource.(*kubeflowv1.TFJob) assert.True(t, ok) - assert.NotNil(t, getReplicaCount(tensorflowJob.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeWorker)) - assert.NotNil(t, getReplicaCount(tensorflowJob.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypePS)) - assert.NotNil(t, getReplicaCount(tensorflowJob.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeChief)) - assert.NotNil(t, getReplicaCount(tensorflowJob.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeEval)) + assert.NotNil(t, common.GetReplicaCount(tensorflowJob.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeWorker)) + assert.NotNil(t, common.GetReplicaCount(tensorflowJob.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypePS)) + assert.NotNil(t, common.GetReplicaCount(tensorflowJob.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeChief)) + assert.NotNil(t, common.GetReplicaCount(tensorflowJob.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeEval)) } func TestBuildResourceTensorFlow(t *testing.T) {