Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
more changes
Browse files Browse the repository at this point in the history
Signed-off-by: Yuvraj <[email protected]>
  • Loading branch information
yindia committed Aug 28, 2021
1 parent a077493 commit 0b9f5b1
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 15 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -80,4 +80,4 @@ replace (

)

replace github.com/flyteorg/flyteidl => github.com/evalsocket/flyteidl v0.18.8-0.20210820144224-2cd6288a159d
replace github.com/flyteorg/flyteidl => github.com/evalsocket/flyteidl v0.19.26-0.20210827192202-131a810ef743
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -285,8 +285,8 @@ github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.m
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/ernesto-jimenez/gogen v0.0.0-20180125220232-d7d4131e6607/go.mod h1:Cg4fM0vhYWOZdgM7RIOSTRNIc8/VT7CXClC3Ni86lu4=
github.com/euank/go-kmsg-parser v2.0.0+incompatible/go.mod h1:MhmAMZ8V4CYH4ybgdRwPr2TU5ThnS43puaKEMpja1uw=
github.com/evalsocket/flyteidl v0.18.8-0.20210820144224-2cd6288a159d h1:faDhAzZI2LBWSgVA5bLc5FsN6hngPtToNxjLMPPgx8w=
github.com/evalsocket/flyteidl v0.18.8-0.20210820144224-2cd6288a159d/go.mod h1:576W2ViEyjTpT+kEVHAGbrTP3HARNUZ/eCwrNPmdx9U=
github.com/evalsocket/flyteidl v0.19.26-0.20210827192202-131a810ef743 h1:c0TgBrZkbs9xzz1UnqHP8bGkAnInKOxQOTSxkexIfFk=
github.com/evalsocket/flyteidl v0.19.26-0.20210827192202-131a810ef743/go.mod h1:576W2ViEyjTpT+kEVHAGbrTP3HARNUZ/eCwrNPmdx9U=
github.com/evanphx/json-patch v0.0.0-20180908160633-36442dbdb585/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk=
github.com/evanphx/json-patch v0.0.0-20190203023257-5858425f7550/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk=
github.com/evanphx/json-patch v4.2.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk=
Expand Down
2 changes: 1 addition & 1 deletion go/tasks/plugins/k8s/kfoperators/common/common_operator.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ const (
PytorchTaskType = "pytorch"
)

func ExtractMPIMPICurrentCondition(jobConditions []mpiOp.JobCondition) (mpiOp.JobCondition, error) {
func ExtractMPICurrentCondition(jobConditions []mpiOp.JobCondition) (mpiOp.JobCondition, error) {
if jobConditions != nil {
sort.Slice(jobConditions, func(i, j int) bool {
return jobConditions[i].LastTransitionTime.Time.After(jobConditions[j].LastTransitionTime.Time)
Expand Down
21 changes: 13 additions & 8 deletions go/tasks/plugins/k8s/kfoperators/mpi/mpi.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,18 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu
if err != nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error())
}
workers := mpiTaskExtraArgs.GetWorkers()
launcherReplicas := mpiTaskExtraArgs.GetLauncherReplicas()
workers := mpiTaskExtraArgs.GetNumWorkers()
launcherReplicas := mpiTaskExtraArgs.GetNumLauncherReplicas()
slots := mpiTaskExtraArgs.GetSlots()

workersPodSpec := podSpec.DeepCopy()
for k := range workersPodSpec.Containers {
workersPodSpec.Containers[k].Args = []string{}
workersPodSpec.Containers[k].Command = []string{}
}

jobSpec := mpi.MPIJobSpec{
SlotsPerWorker: &slots,
//MainContainer: "",
MPIReplicaSpecs: map[mpi.MPIReplicaType]*commonKf.ReplicaSpec{
mpi.MPIReplicaTypeLauncher: &commonKf.ReplicaSpec{
Replicas: &launcherReplicas,
Expand All @@ -79,7 +84,7 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu
mpi.MPIReplicaTypeWorker: &commonKf.ReplicaSpec{
Replicas: &workers,
Template: v1.PodTemplateSpec{
Spec: *podSpec,
Spec: *workersPodSpec,
},
RestartPolicy: commonKf.RestartPolicyNever,
},
Expand All @@ -103,15 +108,15 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu
func (mpiOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) {
app := resource.(*mpi.MPIJob)

workersCount := app.Spec.MPIReplicaSpecs[mpi.MPIReplicaTypeWorker].Replicas
launcherReplicasCount := app.Spec.MPIReplicaSpecs[mpi.MPIReplicaTypeLauncher].Replicas
numWorkers := app.Spec.MPIReplicaSpecs[mpi.MPIReplicaTypeWorker].Replicas
numLauncherReplicas := app.Spec.MPIReplicaSpecs[mpi.MPIReplicaTypeLauncher].Replicas

taskLogs, err := common.GetLogs(common.MPITaskType, app.Name, app.Namespace,
*workersCount, *launcherReplicasCount, 0)
*numWorkers, *numLauncherReplicas, 0)
if err != nil {
return pluginsCore.PhaseInfoUndefined, err
}
currentCondition, err := common.ExtractMPIMPICurrentCondition(app.Status.Conditions)
currentCondition, err := common.ExtractMPICurrentCondition(app.Status.Conditions)
if err != nil {
return pluginsCore.PhaseInfoUndefined, err
}
Expand Down
6 changes: 3 additions & 3 deletions go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ var (

func dummyMPICustomObj(workers int32, launcher int32, slots int32) *plugins.DistributedMPITrainingTask {
return &plugins.DistributedMPITrainingTask{
Workers: workers,
LauncherReplicas: launcher,
Slots: slots,
NumWorkers: workers,
NumLauncherReplicas: launcher,
Slots: slots,
}
}

Expand Down

0 comments on commit 0b9f5b1

Please sign in to comment.