Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Merge 6b7b974 into 4634a81
Browse files Browse the repository at this point in the history
  • Loading branch information
ByronHsu authored Feb 1, 2023
2 parents 4634a81 + 6b7b974 commit 88e2b40
Show file tree
Hide file tree
Showing 6 changed files with 216 additions and 72 deletions.
41 changes: 22 additions & 19 deletions go/tasks/plugins/k8s/kfoperators/mpi/mpi.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils"
"github.com/flyteorg/flyteplugins/go/tasks/plugins/k8s/kfoperators/common"
commonKf "github.com/kubeflow/common/pkg/apis/common/v1"
commonOp "github.com/kubeflow/common/pkg/apis/common/v1"
kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
Expand Down Expand Up @@ -92,32 +92,35 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu
}

if workers == 0 {
return nil, fmt.Errorf("number of worker should be more then 1 ")
return nil, fmt.Errorf("number of worker should be more then 0")
}
if launcherReplicas == 0 {
return nil, fmt.Errorf("number of launch worker should be more then 1")
return nil, fmt.Errorf("number of launch worker should be more then 0")
}

jobSpec := kubeflowv1.MPIJobSpec{
SlotsPerWorker: &slots,
MPIReplicaSpecs: map[commonKf.ReplicaType]*commonKf.ReplicaSpec{
kubeflowv1.MPIJobReplicaTypeLauncher: {
Replicas: &launcherReplicas,
Template: v1.PodTemplateSpec{
ObjectMeta: objectMeta,
Spec: *podSpec,
},
RestartPolicy: commonKf.RestartPolicyNever,
},
kubeflowv1.MPIJobReplicaTypeWorker: {
Replicas: &workers,
SlotsPerWorker: &slots,
MPIReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{},
}

for _, t := range []struct {
podSpec v1.PodSpec
replicaNum *int32
replicaType commonOp.ReplicaType
}{
{*podSpec, &launcherReplicas, kubeflowv1.MPIJobReplicaTypeLauncher},
{*workersPodSpec, &workers, kubeflowv1.MPIJobReplicaTypeWorker},
} {
if *t.replicaNum > 0 {
jobSpec.MPIReplicaSpecs[t.replicaType] = &commonOp.ReplicaSpec{
Replicas: t.replicaNum,
Template: v1.PodTemplateSpec{
ObjectMeta: objectMeta,
Spec: *workersPodSpec,
Spec: t.podSpec,
},
RestartPolicy: commonKf.RestartPolicyNever,
},
},
RestartPolicy: commonOp.RestartPolicyNever,
}
}
}

job := &kubeflowv1.MPIJob{
Expand Down
68 changes: 55 additions & 13 deletions go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,28 @@ import (
"testing"
"time"

"github.com/flyteorg/flyteplugins/go/tasks/plugins/k8s/kfoperators/common"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"

"github.com/flyteorg/flyteplugins/go/tasks/logs"
pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks"
pluginIOMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils"
"github.com/flyteorg/flyteplugins/go/tasks/plugins/k8s/kfoperators/common"

mpiOp "github.com/kubeflow/common/pkg/apis/common/v1"
kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"

pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks"

pluginIOMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"
"github.com/golang/protobuf/jsonpb"
structpb "github.com/golang/protobuf/ptypes/struct"
"github.com/stretchr/testify/assert"

corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

Expand Down Expand Up @@ -385,3 +384,46 @@ func TestGetProperties(t *testing.T) {
expected := k8s.PluginProperties{}
assert.Equal(t, expected, mpiResourceHandler.GetProperties())
}

func TestReplicaCounts(t *testing.T) {
for _, test := range []struct {
name string
launcherReplicaCount int32
workerReplicaCount int32
expectError bool
contains []mpiOp.ReplicaType
notContains []mpiOp.ReplicaType
}{
{"NoWorkers", 0, 1, true, nil, nil},
{"NoLaunchers", 1, 0, true, nil, nil},
{"Works", 1, 1, false, []mpiOp.ReplicaType{kubeflowv1.MPIJobReplicaTypeLauncher, kubeflowv1.MPIJobReplicaTypeWorker}, []mpiOp.ReplicaType{}},
} {
t.Run(test.name, func(t *testing.T) {
mpiResourceHandler := mpiOperatorResourceHandler{}

mpiObj := dummyMPICustomObj(test.workerReplicaCount, test.launcherReplicaCount, 1)
taskTemplate := dummyMPITaskTemplate(mpiID2, mpiObj)

resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate))
if test.expectError {
assert.Error(t, err)
assert.Nil(t, resource)
return
}

assert.NoError(t, err)
assert.NotNil(t, resource)

job, ok := resource.(*kubeflowv1.MPIJob)
assert.True(t, ok)

assert.Len(t, job.Spec.MPIReplicaSpecs, len(test.contains))
for _, replicaType := range test.contains {
assert.Contains(t, job.Spec.MPIReplicaSpecs, replicaType)
}
for _, replicaType := range test.notContains {
assert.NotContains(t, job.Spec.MPIReplicaSpecs, replicaType)
}
})
}
}
20 changes: 12 additions & 8 deletions go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,27 @@ package pytorch

import (
"context"
"fmt"
"time"

"sigs.k8s.io/controller-runtime/pkg/client"

"github.com/flyteorg/flyteplugins/go/tasks/plugins/k8s/kfoperators/common"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"

flyteerr "github.com/flyteorg/flyteplugins/go/tasks/errors"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s"
v1 "k8s.io/api/core/v1"
"k8s.io/client-go/kubernetes/scheme"

pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils"
"github.com/flyteorg/flyteplugins/go/tasks/plugins/k8s/kfoperators/common"

commonOp "github.com/kubeflow/common/pkg/apis/common/v1"
kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"

v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes/scheme"

"sigs.k8s.io/controller-runtime/pkg/client"
)

type pytorchOperatorResourceHandler struct {
Expand Down Expand Up @@ -82,6 +83,9 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx
}

workers := pytorchTaskExtraArgs.GetWorkers()
if workers == 0 {
return nil, fmt.Errorf("number of worker should be more then 0")
}

jobSpec := kubeflowv1.PyTorchJobSpec{
PyTorchReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{
Expand Down
41 changes: 41 additions & 0 deletions go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -374,3 +374,44 @@ func TestGetProperties(t *testing.T) {
expected := k8s.PluginProperties{}
assert.Equal(t, expected, pytorchResourceHandler.GetProperties())
}

func TestReplicaCounts(t *testing.T) {
for _, test := range []struct {
name string
workerReplicaCount int32
expectError bool
contains []commonOp.ReplicaType
notContains []commonOp.ReplicaType
}{
{"NoWorkers", 0, true, nil, nil},
{"Works", 1, false, []commonOp.ReplicaType{kubeflowv1.PyTorchJobReplicaTypeMaster, kubeflowv1.PyTorchJobReplicaTypeWorker}, []commonOp.ReplicaType{}},
} {
t.Run(test.name, func(t *testing.T) {
pytorchResourceHandler := pytorchOperatorResourceHandler{}

ptObj := dummyPytorchCustomObj(test.workerReplicaCount)
taskTemplate := dummyPytorchTaskTemplate("the job", ptObj)

resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate))
if test.expectError {
assert.Error(t, err)
assert.Nil(t, resource)
return
}

assert.NoError(t, err)
assert.NotNil(t, resource)

job, ok := resource.(*kubeflowv1.PyTorchJob)
assert.True(t, ok)

assert.Len(t, job.Spec.PyTorchReplicaSpecs, len(test.contains))
for _, replicaType := range test.contains {
assert.Contains(t, job.Spec.PyTorchReplicaSpecs, replicaType)
}
for _, replicaType := range test.notContains {
assert.NotContains(t, job.Spec.PyTorchReplicaSpecs, replicaType)
}
})
}
}
63 changes: 34 additions & 29 deletions go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,27 @@ package tensorflow

import (
"context"
"fmt"
"time"

"sigs.k8s.io/controller-runtime/pkg/client"

"github.com/flyteorg/flyteplugins/go/tasks/plugins/k8s/kfoperators/common"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"

flyteerr "github.com/flyteorg/flyteplugins/go/tasks/errors"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s"
v1 "k8s.io/api/core/v1"
"k8s.io/client-go/kubernetes/scheme"

pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils"
"github.com/flyteorg/flyteplugins/go/tasks/plugins/k8s/kfoperators/common"

commonOp "github.com/kubeflow/common/pkg/apis/common/v1"
kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"

v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes/scheme"

"sigs.k8s.io/controller-runtime/pkg/client"
)

type tensorflowOperatorResourceHandler struct {
Expand Down Expand Up @@ -85,32 +86,36 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task
psReplicas := tensorflowTaskExtraArgs.GetPsReplicas()
chiefReplicas := tensorflowTaskExtraArgs.GetChiefReplicas()

if workers == 0 {
return nil, fmt.Errorf("number of worker should be more then 0")
}
if psReplicas == 0 && chiefReplicas == 0 {
return nil, fmt.Errorf("either number of chief or parameter servers needs to be be more then 0")
}

jobSpec := kubeflowv1.TFJobSpec{
TFReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{
kubeflowv1.TFJobReplicaTypePS: {
Replicas: &psReplicas,
Template: v1.PodTemplateSpec{
ObjectMeta: objectMeta,
Spec: *podSpec,
},
RestartPolicy: commonOp.RestartPolicyNever,
},
kubeflowv1.TFJobReplicaTypeChief: {
Replicas: &chiefReplicas,
TFReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{},
}

for _, t := range []struct {
podSpec v1.PodSpec
replicaNum *int32
replicaType commonOp.ReplicaType
}{
{*podSpec, &workers, kubeflowv1.TFJobReplicaTypeWorker},
{*podSpec, &psReplicas, kubeflowv1.TFJobReplicaTypePS},
{*podSpec, &chiefReplicas, kubeflowv1.TFJobReplicaTypeChief},
} {
if *t.replicaNum > 0 {
jobSpec.TFReplicaSpecs[t.replicaType] = &commonOp.ReplicaSpec{
Replicas: t.replicaNum,
Template: v1.PodTemplateSpec{
ObjectMeta: objectMeta,
Spec: *podSpec,
Spec: t.podSpec,
},
RestartPolicy: commonOp.RestartPolicyNever,
},
kubeflowv1.TFJobReplicaTypeWorker: {
Replicas: &workers,
Template: v1.PodTemplateSpec{
Spec: *podSpec,
},
RestartPolicy: commonOp.RestartPolicyNever,
},
},
}
}
}

job := &kubeflowv1.TFJob{
Expand Down
Loading

0 comments on commit 88e2b40

Please sign in to comment.