diff --git a/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go b/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go index ec42029b4..1b7624888 100644 --- a/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go +++ b/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go @@ -13,7 +13,7 @@ import ( "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils" "github.com/flyteorg/flyteplugins/go/tasks/plugins/k8s/kfoperators/common" - commonKf "github.com/kubeflow/common/pkg/apis/common/v1" + commonOp "github.com/kubeflow/common/pkg/apis/common/v1" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -92,32 +92,35 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu } if workers == 0 { - return nil, fmt.Errorf("number of worker should be more then 1 ") + return nil, fmt.Errorf("number of worker should be more then 0") } if launcherReplicas == 0 { - return nil, fmt.Errorf("number of launch worker should be more then 1") + return nil, fmt.Errorf("number of launch worker should be more then 0") } jobSpec := kubeflowv1.MPIJobSpec{ - SlotsPerWorker: &slots, - MPIReplicaSpecs: map[commonKf.ReplicaType]*commonKf.ReplicaSpec{ - kubeflowv1.MPIJobReplicaTypeLauncher: { - Replicas: &launcherReplicas, - Template: v1.PodTemplateSpec{ - ObjectMeta: objectMeta, - Spec: *podSpec, - }, - RestartPolicy: commonKf.RestartPolicyNever, - }, - kubeflowv1.MPIJobReplicaTypeWorker: { - Replicas: &workers, + SlotsPerWorker: &slots, + MPIReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{}, + } + + for _, t := range []struct { + podSpec v1.PodSpec + replicaNum *int32 + replicaType commonOp.ReplicaType + }{ + {*podSpec, &launcherReplicas, kubeflowv1.MPIJobReplicaTypeLauncher}, + {*workersPodSpec, &workers, kubeflowv1.MPIJobReplicaTypeWorker}, + } { + if *t.replicaNum > 0 { + jobSpec.MPIReplicaSpecs[t.replicaType] = &commonOp.ReplicaSpec{ + Replicas: t.replicaNum, Template: v1.PodTemplateSpec{ ObjectMeta: objectMeta, - Spec: *workersPodSpec, + Spec: t.podSpec, }, - RestartPolicy: commonKf.RestartPolicyNever, - }, - }, + RestartPolicy: commonOp.RestartPolicyNever, + } + } } job := &kubeflowv1.MPIJob{ diff --git a/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go b/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go index b9effbe99..29fefd9ca 100644 --- a/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go +++ b/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go @@ -6,29 +6,28 @@ import ( "testing" "time" - "github.com/flyteorg/flyteplugins/go/tasks/plugins/k8s/kfoperators/common" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" "github.com/flyteorg/flyteplugins/go/tasks/logs" + pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" + pluginIOMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils" + "github.com/flyteorg/flyteplugins/go/tasks/plugins/k8s/kfoperators/common" + mpiOp "github.com/kubeflow/common/pkg/apis/common/v1" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" - corev1 "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/api/resource" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" - pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils" - - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" - - pluginIOMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" - - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" "github.com/golang/protobuf/jsonpb" structpb "github.com/golang/protobuf/ptypes/struct" - "github.com/stretchr/testify/assert" + + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) @@ -385,3 +384,46 @@ func TestGetProperties(t *testing.T) { expected := k8s.PluginProperties{} assert.Equal(t, expected, mpiResourceHandler.GetProperties()) } + +func TestReplicaCounts(t *testing.T) { + for _, test := range []struct { + name string + launcherReplicaCount int32 + workerReplicaCount int32 + expectError bool + contains []mpiOp.ReplicaType + notContains []mpiOp.ReplicaType + }{ + {"NoWorkers", 0, 1, true, nil, nil}, + {"NoLaunchers", 1, 0, true, nil, nil}, + {"Works", 1, 1, false, []mpiOp.ReplicaType{kubeflowv1.MPIJobReplicaTypeLauncher, kubeflowv1.MPIJobReplicaTypeWorker}, []mpiOp.ReplicaType{}}, + } { + t.Run(test.name, func(t *testing.T) { + mpiResourceHandler := mpiOperatorResourceHandler{} + + mpiObj := dummyMPICustomObj(test.workerReplicaCount, test.launcherReplicaCount, 1) + taskTemplate := dummyMPITaskTemplate(mpiID2, mpiObj) + + resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate)) + if test.expectError { + assert.Error(t, err) + assert.Nil(t, resource) + return + } + + assert.NoError(t, err) + assert.NotNil(t, resource) + + job, ok := resource.(*kubeflowv1.MPIJob) + assert.True(t, ok) + + assert.Len(t, job.Spec.MPIReplicaSpecs, len(test.contains)) + for _, replicaType := range test.contains { + assert.Contains(t, job.Spec.MPIReplicaSpecs, replicaType) + } + for _, replicaType := range test.notContains { + assert.NotContains(t, job.Spec.MPIReplicaSpecs, replicaType) + } + }) + } +} diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go index ce36fd2c8..01421f1a9 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go @@ -2,26 +2,27 @@ package pytorch import ( "context" + "fmt" "time" - "sigs.k8s.io/controller-runtime/pkg/client" - - "github.com/flyteorg/flyteplugins/go/tasks/plugins/k8s/kfoperators/common" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" + flyteerr "github.com/flyteorg/flyteplugins/go/tasks/errors" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" - v1 "k8s.io/api/core/v1" - "k8s.io/client-go/kubernetes/scheme" - pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils" + "github.com/flyteorg/flyteplugins/go/tasks/plugins/k8s/kfoperators/common" commonOp "github.com/kubeflow/common/pkg/apis/common/v1" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" + + v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes/scheme" + + "sigs.k8s.io/controller-runtime/pkg/client" ) type pytorchOperatorResourceHandler struct { @@ -82,6 +83,9 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx } workers := pytorchTaskExtraArgs.GetWorkers() + if workers == 0 { + return nil, fmt.Errorf("number of worker should be more then 0") + } jobSpec := kubeflowv1.PyTorchJobSpec{ PyTorchReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{ diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go index 0c4dcb3b1..f1979dbf2 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go @@ -374,3 +374,44 @@ func TestGetProperties(t *testing.T) { expected := k8s.PluginProperties{} assert.Equal(t, expected, pytorchResourceHandler.GetProperties()) } + +func TestReplicaCounts(t *testing.T) { + for _, test := range []struct { + name string + workerReplicaCount int32 + expectError bool + contains []commonOp.ReplicaType + notContains []commonOp.ReplicaType + }{ + {"NoWorkers", 0, true, nil, nil}, + {"Works", 1, false, []commonOp.ReplicaType{kubeflowv1.PyTorchJobReplicaTypeMaster, kubeflowv1.PyTorchJobReplicaTypeWorker}, []commonOp.ReplicaType{}}, + } { + t.Run(test.name, func(t *testing.T) { + pytorchResourceHandler := pytorchOperatorResourceHandler{} + + ptObj := dummyPytorchCustomObj(test.workerReplicaCount) + taskTemplate := dummyPytorchTaskTemplate("the job", ptObj) + + resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate)) + if test.expectError { + assert.Error(t, err) + assert.Nil(t, resource) + return + } + + assert.NoError(t, err) + assert.NotNil(t, resource) + + job, ok := resource.(*kubeflowv1.PyTorchJob) + assert.True(t, ok) + + assert.Len(t, job.Spec.PyTorchReplicaSpecs, len(test.contains)) + for _, replicaType := range test.contains { + assert.Contains(t, job.Spec.PyTorchReplicaSpecs, replicaType) + } + for _, replicaType := range test.notContains { + assert.NotContains(t, job.Spec.PyTorchReplicaSpecs, replicaType) + } + }) + } +} diff --git a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go index d2370cf94..a5c813ca4 100644 --- a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go +++ b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go @@ -2,26 +2,27 @@ package tensorflow import ( "context" + "fmt" "time" - "sigs.k8s.io/controller-runtime/pkg/client" - - "github.com/flyteorg/flyteplugins/go/tasks/plugins/k8s/kfoperators/common" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" + flyteerr "github.com/flyteorg/flyteplugins/go/tasks/errors" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" - v1 "k8s.io/api/core/v1" - "k8s.io/client-go/kubernetes/scheme" - pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils" + "github.com/flyteorg/flyteplugins/go/tasks/plugins/k8s/kfoperators/common" commonOp "github.com/kubeflow/common/pkg/apis/common/v1" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" + + v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes/scheme" + + "sigs.k8s.io/controller-runtime/pkg/client" ) type tensorflowOperatorResourceHandler struct { @@ -85,32 +86,36 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task psReplicas := tensorflowTaskExtraArgs.GetPsReplicas() chiefReplicas := tensorflowTaskExtraArgs.GetChiefReplicas() + if workers == 0 { + return nil, fmt.Errorf("number of worker should be more then 0") + } + if psReplicas == 0 && chiefReplicas == 0 { + return nil, fmt.Errorf("either number of chief or parameter servers needs to be be more then 0") + } + jobSpec := kubeflowv1.TFJobSpec{ - TFReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{ - kubeflowv1.TFJobReplicaTypePS: { - Replicas: &psReplicas, - Template: v1.PodTemplateSpec{ - ObjectMeta: objectMeta, - Spec: *podSpec, - }, - RestartPolicy: commonOp.RestartPolicyNever, - }, - kubeflowv1.TFJobReplicaTypeChief: { - Replicas: &chiefReplicas, + TFReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{}, + } + + for _, t := range []struct { + podSpec v1.PodSpec + replicaNum *int32 + replicaType commonOp.ReplicaType + }{ + {*podSpec, &workers, kubeflowv1.TFJobReplicaTypeWorker}, + {*podSpec, &psReplicas, kubeflowv1.TFJobReplicaTypePS}, + {*podSpec, &chiefReplicas, kubeflowv1.TFJobReplicaTypeChief}, + } { + if *t.replicaNum > 0 { + jobSpec.TFReplicaSpecs[t.replicaType] = &commonOp.ReplicaSpec{ + Replicas: t.replicaNum, Template: v1.PodTemplateSpec{ ObjectMeta: objectMeta, - Spec: *podSpec, + Spec: t.podSpec, }, RestartPolicy: commonOp.RestartPolicyNever, - }, - kubeflowv1.TFJobReplicaTypeWorker: { - Replicas: &workers, - Template: v1.PodTemplateSpec{ - Spec: *podSpec, - }, - RestartPolicy: commonOp.RestartPolicyNever, - }, - }, + } + } } job := &kubeflowv1.TFJob{ diff --git a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go index 2785459f4..2145f839c 100644 --- a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go +++ b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go @@ -70,7 +70,7 @@ func dummyTensorFlowCustomObj(workers int32, psReplicas int32, chiefReplicas int } } -func dummySparkTaskTemplate(id string, tensorflowCustomObj *plugins.DistributedTensorflowTrainingTask) *core.TaskTemplate { +func dummyTensorFlowTaskTemplate(id string, tensorflowCustomObj *plugins.DistributedTensorflowTrainingTask) *core.TaskTemplate { tfObjJSON, err := utils.MarshalToString(tensorflowCustomObj) if err != nil { @@ -251,7 +251,7 @@ func dummyTensorFlowJobResource(tensorflowResourceHandler tensorflowOperatorReso } tfObj := dummyTensorFlowCustomObj(workers, psReplicas, chiefReplicas) - taskTemplate := dummySparkTaskTemplate("the job", tfObj) + taskTemplate := dummyTensorFlowTaskTemplate("the job", tfObj) resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate)) if err != nil { panic(err) @@ -277,7 +277,7 @@ func TestBuildResourceTensorFlow(t *testing.T) { tensorflowResourceHandler := tensorflowOperatorResourceHandler{} tfObj := dummyTensorFlowCustomObj(100, 50, 1) - taskTemplate := dummySparkTaskTemplate("the job", tfObj) + taskTemplate := dummyTensorFlowTaskTemplate("the job", tfObj) resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate)) assert.NoError(t, err) @@ -371,3 +371,52 @@ func TestGetProperties(t *testing.T) { expected := k8s.PluginProperties{} assert.Equal(t, expected, tensorflowResourceHandler.GetProperties()) } + +func TestReplicaCounts(t *testing.T) { + for _, test := range []struct { + name string + chiefReplicaCount int32 + psReplicaCount int32 + workerReplicaCount int32 + expectError bool + contains []commonOp.ReplicaType + notContains []commonOp.ReplicaType + }{ + {"NoWorkers", 1, 1, 0, true, nil, nil}, + {"NoChiefOrPS", 0, 0, 1, true, nil, nil}, + {"SingleChief", 1, 0, 1, false, + []commonOp.ReplicaType{kubeflowv1.TFJobReplicaTypeChief, kubeflowv1.TFJobReplicaTypeWorker}, + []commonOp.ReplicaType{kubeflowv1.TFJobReplicaTypePS}}, + {"SinglePS", 0, 1, 1, false, + []commonOp.ReplicaType{kubeflowv1.TFJobReplicaTypePS, kubeflowv1.TFJobReplicaTypeWorker}, + []commonOp.ReplicaType{kubeflowv1.TFJobReplicaTypeChief}}, + } { + t.Run(test.name, func(t *testing.T) { + tensorflowResourceHandler := tensorflowOperatorResourceHandler{} + + tfObj := dummyTensorFlowCustomObj(test.workerReplicaCount, test.psReplicaCount, test.chiefReplicaCount) + taskTemplate := dummyTensorFlowTaskTemplate("the job", tfObj) + + resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate)) + if test.expectError { + assert.Error(t, err) + assert.Nil(t, resource) + return + } + + assert.NoError(t, err) + assert.NotNil(t, resource) + + job, ok := resource.(*kubeflowv1.TFJob) + assert.True(t, ok) + + assert.Len(t, job.Spec.TFReplicaSpecs, len(test.contains)) + for _, replicaType := range test.contains { + assert.Contains(t, job.Spec.TFReplicaSpecs, replicaType) + } + for _, replicaType := range test.notContains { + assert.NotContains(t, job.Spec.TFReplicaSpecs, replicaType) + } + }) + } +}