Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Kf operators use GetReplicaFunc (Error Handling) #4471

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -266,27 +266,27 @@
return nil
}

// ToReplicaSpec builds a baseline ReplicaSpec for a Kubeflow operator job from
// the task's pod specification.
//
// The pod spec and object metadata come from flytek8s.ToK8sPodSpec. The
// container name it reports as primary is passed, together with
// primaryContainerName, to OverridePrimaryContainerName. Annotations and labels
// are the union of the K8s plugin config defaults, the generated object
// metadata, and a copy of the task execution metadata.
//
// The returned spec has Replicas pointing at 0 and RestartPolicyNever; callers
// are expected to set the real replica count themselves.
//
// A failure to build the pod spec is returned as a BadTaskSpecification error.
func ToReplicaSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, primaryContainerName string) (*commonOp.ReplicaSpec, error) {
podSpec, objectMeta, oldPrimaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx)
if err != nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error())
}

Check warning on line 273 in flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go#L269-L273

Added lines #L269 - L273 were not covered by tests

OverridePrimaryContainerName(podSpec, oldPrimaryContainerName, primaryContainerName)

cfg := config.GetK8sPluginConfig()
objectMeta.Annotations = utils.UnionMaps(cfg.DefaultAnnotations, objectMeta.Annotations, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations()))
objectMeta.Labels = utils.UnionMaps(cfg.DefaultLabels, objectMeta.Labels, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()))

// Start from zero replicas; the count is a placeholder until a caller overrides it.
replicas := int32(0)
return &commonOp.ReplicaSpec{
Replicas: &replicas,
Template: v1.PodTemplateSpec{
ObjectMeta: *objectMeta,
Spec: *podSpec,
},
RestartPolicy: commonOp.RestartPolicyNever,
}, nil

Check warning on line 289 in flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go#L275-L289

Added lines #L275 - L289 were not covered by tests
}

type kfDistributedReplicaSpec interface {
Expand All @@ -300,48 +300,56 @@
GetCommand() []string
}

// ToReplicaSpecWithOverrides builds a ReplicaSpec via ToReplicaSpec and then
// applies the per-replica overrides carried by rs.
//
// When rs is non-nil and declares resources, they are converted with
// flytek8s.ToK8sResourceRequirements and injected into a wrapped task execution
// context before the base spec is built. After that:
//   - if isMaster is true, Replicas is forced to 1;
//   - if rs is non-nil, the primary container's image (and its command, when rs
//     implements allowsCommandOverride) is overridden via OverrideContainerSpec,
//     and the restart policy is taken from rs;
//   - if rs is non-nil and isMaster is false, Replicas is taken from rs.
//
// When rs is nil and isMaster is false, Replicas keeps the value set by
// ToReplicaSpec (0).
//
// An invalid resource specification is returned as a BadTaskSpecification
// error; errors from ToReplicaSpec and OverrideContainerSpec are returned as-is.
func ToReplicaSpecWithOverrides(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, rs kfDistributedReplicaSpec, primaryContainerName string, isMaster bool) (*commonOp.ReplicaSpec, error) {
taskCtxOptions := []flytek8s.PluginTaskExecutionContextOption{}
if rs != nil && rs.GetResources() != nil {
resources, err := flytek8s.ToK8sResourceRequirements(rs.GetResources())
if err != nil {
// NOTE(review): `resources` is the result of the failed conversion, so it is
// presumably nil/empty here; rs.GetResources() may be the more useful value
// to report. Confirm before changing the message.
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification on Resources [%v], Err: [%v]", resources, err.Error())
}
taskCtxOptions = append(taskCtxOptions, flytek8s.WithResources(resources))

Check warning on line 310 in flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go#L303-L310

Added lines #L303 - L310 were not covered by tests
}
// Wrap the task context so the options collected above are visible to ToReplicaSpec.
newTaskCtx := flytek8s.NewPluginTaskExecutionContext(taskCtx, taskCtxOptions...)
replicaSpec, err := ToReplicaSpec(ctx, newTaskCtx, primaryContainerName)
if err != nil {
return nil, err
}

Check warning on line 316 in flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go#L312-L316

Added lines #L312 - L316 were not covered by tests

// Master should have a single replica
if isMaster {
replicas := int32(1)
replicaSpec.Replicas = &replicas
}

Check warning on line 322 in flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go#L319-L322

Added lines #L319 - L322 were not covered by tests

if rs != nil {
// A command override is optional: only replica specs that implement
// allowsCommandOverride supply one; otherwise command stays nil.
var command []string
if v, ok := rs.(allowsCommandOverride); ok {
command = v.GetCommand()
}
if err := OverrideContainerSpec(
&replicaSpec.Template.Spec,
primaryContainerName,
rs.GetImage(),
command,
); err != nil {
return nil, err
}

Check warning on line 336 in flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go#L324-L336

Added lines #L324 - L336 were not covered by tests

replicaSpec.RestartPolicy = ParseRestartPolicy(rs.GetRestartPolicy())

// The master count was already pinned to 1 above; only non-master replicas
// take their count from rs.
if !isMaster {
replicas := rs.GetReplicas()
replicaSpec.Replicas = &replicas
}

Check warning on line 343 in flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go#L338-L343

Added lines #L338 - L343 were not covered by tests
}

return replicaSpec, nil

Check warning on line 346 in flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go#L346

Added line #L346 was not covered by tests
}

// GetReplicaCount returns the replica count stored in specs for replicaType.
//
// If replicaType has no entry in specs, or the entry's Replicas field is nil,
// a pointer to a freshly allocated zero is returned instead, so the result can
// always be dereferenced.
//
// NOTE(review): an entry that is present but maps to a nil *ReplicaSpec would
// make spec.Replicas dereference a nil pointer; confirm whether callers can
// ever produce such a map.
func GetReplicaCount(specs map[commonOp.ReplicaType]*commonOp.ReplicaSpec, replicaType commonOp.ReplicaType) *int32 {
if spec, ok := specs[replicaType]; ok && spec.Replicas != nil {
return spec.Replicas
}

Check warning on line 352 in flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go#L349-L352

Added lines #L349 - L352 were not covered by tests

return new(int32) // return 0 as default value

Check warning on line 354 in flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go#L354

Added line #L354 was not covered by tests
}
4 changes: 2 additions & 2 deletions flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@

replicaSpec, err := common.ToReplicaSpec(ctx, taskCtx, kubeflowv1.MPIJobDefaultContainerName)
if err != nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create replica spec: [%v]", err.Error())
}

Check warning on line 73 in flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go#L72-L73

Added lines #L72 - L73 were not covered by tests
launcherReplicaSpec = replicaSpec.DeepCopy()
// TODO (jeev): Is this even a valid configuration? Can there be more than 1
// launcher? TaskTypeVersion 1 does not support overriding this value.
Expand Down Expand Up @@ -108,12 +108,12 @@

launcherReplicaSpec, err = common.ToReplicaSpecWithOverrides(ctx, taskCtx, kfMPITaskExtraArgs.GetLauncherReplicas(), kubeflowv1.MPIJobDefaultContainerName, true)
if err != nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create launcher replica spec: [%v]", err.Error())

Check warning on line 111 in flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go#L111

Added line #L111 was not covered by tests
}

workerReplicaSpec, err = common.ToReplicaSpecWithOverrides(ctx, taskCtx, kfMPITaskExtraArgs.GetWorkerReplicas(), kubeflowv1.MPIJobDefaultContainerName, false)
if err != nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create worker replica spec: [%v]", err.Error())

Check warning on line 116 in flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go#L116

Added line #L116 was not covered by tests
}

if kfMPITaskExtraArgs.GetRunPolicy() != nil {
Expand All @@ -129,7 +129,7 @@
return nil, fmt.Errorf("number of workers must be greater than 0")
}
if *launcherReplicaSpec.Replicas <= 0 {
return nil, fmt.Errorf("number of launchers must be greater than 0")

Check warning on line 132 in flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go#L132

Added line #L132 was not covered by tests
}

jobSpec := kubeflowv1.MPIJobSpec{
Expand Down Expand Up @@ -162,8 +162,8 @@
return pluginsCore.PhaseInfoUndefined, fmt.Errorf("failed to convert resource data type")
}

numWorkers = app.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Replicas
numLauncherReplicas = app.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher].Replicas
numWorkers = common.GetReplicaCount(app.Spec.MPIReplicaSpecs, kubeflowv1.MPIJobReplicaTypeWorker)
numLauncherReplicas = common.GetReplicaCount(app.Spec.MPIReplicaSpecs, kubeflowv1.MPIJobReplicaTypeLauncher)

taskLogs, err := common.GetLogs(pluginContext, common.MPITaskType, app.ObjectMeta, false,
*numWorkers, *numLauncherReplicas, 0, 0)
Expand Down
14 changes: 14 additions & 0 deletions flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -776,3 +776,17 @@ func TestBuildResourceMPIV1ResourceTolerations(t *testing.T) {
assert.NotContains(t, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher].Template.Spec.Tolerations, gpuToleration)
assert.Contains(t, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Template.Spec.Tolerations, gpuToleration)
}

// TestGetReplicaCount builds an MPIJob through the MPI resource handler and
// checks that common.GetReplicaCount yields a non-nil count for both the
// worker and the launcher replica types.
func TestGetReplicaCount(t *testing.T) {
	handler := mpiOperatorResourceHandler{}
	customObj := dummyMPICustomObj(1, 1, 0)
	template := dummyMPITaskTemplate("the job", customObj)

	built, err := handler.BuildResource(context.TODO(), dummyMPITaskContext(template, resourceRequirements, nil))
	assert.NoError(t, err)
	assert.NotNil(t, built)

	// The handler returns a generic object; it must be an MPIJob underneath.
	job, isMPIJob := built.(*kubeflowv1.MPIJob)
	assert.True(t, isMPIJob)

	replicaSpecs := job.Spec.MPIReplicaSpecs
	assert.NotNil(t, common.GetReplicaCount(replicaSpecs, kubeflowv1.MPIJobReplicaTypeWorker))
	assert.NotNil(t, common.GetReplicaCount(replicaSpecs, kubeflowv1.MPIJobReplicaTypeLauncher))
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@

replicaSpec, err := common.ToReplicaSpec(ctx, taskCtx, kubeflowv1.PytorchJobDefaultContainerName)
if err != nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create replica spec: [%v]", err.Error())
}

Check warning on line 71 in flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go#L70-L71

Added lines #L70 - L71 were not covered by tests
masterReplicaSpec = replicaSpec.DeepCopy()
masterReplicas := int32(1)
masterReplicaSpec.Replicas = &masterReplicas
Expand All @@ -91,12 +91,12 @@

masterReplicaSpec, err = common.ToReplicaSpecWithOverrides(ctx, taskCtx, kfPytorchTaskExtraArgs.GetMasterReplicas(), kubeflowv1.PytorchJobDefaultContainerName, true)
if err != nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create master replica spec: [%v]", err.Error())

Check warning on line 94 in flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go#L94

Added line #L94 was not covered by tests
}

workerReplicaSpec, err = common.ToReplicaSpecWithOverrides(ctx, taskCtx, kfPytorchTaskExtraArgs.GetWorkerReplicas(), kubeflowv1.PytorchJobDefaultContainerName, false)
if err != nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create worker replica spec: [%v]", err.Error())

Check warning on line 99 in flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go#L99

Added line #L99 was not covered by tests
}

if kfPytorchTaskExtraArgs.GetRunPolicy() != nil {
Expand Down Expand Up @@ -171,15 +171,18 @@
// any operations that might take a long time (limits are configured system-wide) should be offloaded to the
// background.
func (pytorchOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) {
app := resource.(*kubeflowv1.PyTorchJob)
app, ok := resource.(*kubeflowv1.PyTorchJob)
if !ok {
return pluginsCore.PhaseInfoUndefined, fmt.Errorf("failed to convert resource data type")
}

Check warning on line 177 in flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go#L176-L177

Added lines #L176 - L177 were not covered by tests

// Elastic PytorchJobs don't use master replicas
hasMaster := false
if _, ok := app.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster]; ok {
hasMaster = true
}

workersCount := app.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas
workersCount := common.GetReplicaCount(app.Spec.PyTorchReplicaSpecs, kubeflowv1.PyTorchJobReplicaTypeWorker)

taskLogs, err := common.GetLogs(pluginContext, common.PytorchTaskType, app.ObjectMeta, hasMaster, *workersCount, 0, 0, 0)
if err != nil {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -959,3 +959,16 @@ func TestParseElasticConfig(t *testing.T) {
assert.Equal(t, int32(4), *elasticPolicy.NProcPerNode)
assert.Equal(t, kubeflowv1.RDZVBackend("c10d"), *elasticPolicy.RDZVBackend)
}

// TestGetReplicaCount builds a PyTorchJob through the PyTorch resource handler
// and checks that common.GetReplicaCount yields a non-nil count for the worker
// replica type.
func TestGetReplicaCount(t *testing.T) {
	handler := pytorchOperatorResourceHandler{}
	customObj := dummyPytorchCustomObj(1)
	template := dummyPytorchTaskTemplate("the job", customObj)

	built, err := handler.BuildResource(context.TODO(), dummyPytorchTaskContext(template, resourceRequirements, nil))
	assert.NoError(t, err)
	assert.NotNil(t, built)

	// The handler returns a generic object; it must be a PyTorchJob underneath.
	job, isPyTorchJob := built.(*kubeflowv1.PyTorchJob)
	assert.True(t, isPyTorchJob)

	workerCount := common.GetReplicaCount(job.Spec.PyTorchReplicaSpecs, kubeflowv1.PyTorchJobReplicaTypeWorker)
	assert.NotNil(t, workerCount)
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@

replicaSpec, err := common.ToReplicaSpec(ctx, taskCtx, kubeflowv1.TFJobDefaultContainerName)
if err != nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create replica spec: [%v]", err.Error())
}

Check warning on line 69 in flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go#L68-L69

Added lines #L68 - L69 were not covered by tests

replicaNumMap := map[commonOp.ReplicaType]int32{
kubeflowv1.TFJobReplicaTypeChief: tensorflowTaskExtraArgs.GetChiefReplicas(),
Expand Down Expand Up @@ -105,7 +105,7 @@
}
rs, err := common.ToReplicaSpecWithOverrides(ctx, taskCtx, cfg, kubeflowv1.TFJobDefaultContainerName, false)
if err != nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create replica spec: [%v]", err.Error())

Check warning on line 108 in flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go#L108

Added line #L108 was not covered by tests
}
replicaSpecMap[t] = rs
}
Expand Down Expand Up @@ -139,24 +139,19 @@
return job, nil
}

// getReplicaCount returns the replica count configured for replicaType in specs.
//
// The result is never nil: when replicaType has no entry, the entry is a nil
// spec, or the spec leaves Replicas unset, a pointer to a fresh zero is
// returned so callers can dereference the result unconditionally.
func getReplicaCount(specs map[commonOp.ReplicaType]*commonOp.ReplicaSpec, replicaType commonOp.ReplicaType) *int32 {
	// Indexing a nil map or a missing key yields a nil spec, so one nil check
	// covers absent keys as well as keys explicitly mapped to a nil spec. The
	// previous `ok`-only check would panic on the latter.
	if spec := specs[replicaType]; spec != nil && spec.Replicas != nil {
		return spec.Replicas
	}

	return new(int32) // return 0 as default value
}

// Analyses the k8s resource and reports the status as TaskPhase. This call is expected to be relatively fast,
// any operations that might take a long time (limits are configured system-wide) should be offloaded to the
// background.
func (tensorflowOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) {
app := resource.(*kubeflowv1.TFJob)
app, ok := resource.(*kubeflowv1.TFJob)
if !ok {
return pluginsCore.PhaseInfoUndefined, fmt.Errorf("failed to convert resource data type")
}

Check warning on line 149 in flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go#L148-L149

Added lines #L148 - L149 were not covered by tests

workersCount := getReplicaCount(app.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeWorker)
psReplicasCount := getReplicaCount(app.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypePS)
chiefCount := getReplicaCount(app.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeChief)
evaluatorReplicasCount := getReplicaCount(app.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeEval)
workersCount := common.GetReplicaCount(app.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeWorker)
psReplicasCount := common.GetReplicaCount(app.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypePS)
chiefCount := common.GetReplicaCount(app.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeChief)
evaluatorReplicasCount := common.GetReplicaCount(app.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeEval)

taskLogs, err := common.GetLogs(pluginContext, common.TensorflowTaskType, app.ObjectMeta, false,
*workersCount, *psReplicasCount, *chiefCount, *evaluatorReplicasCount)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,10 +305,10 @@ func TestGetReplicaCount(t *testing.T) {
tensorflowJob, ok := resource.(*kubeflowv1.TFJob)
assert.True(t, ok)

assert.NotNil(t, getReplicaCount(tensorflowJob.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeWorker))
assert.NotNil(t, getReplicaCount(tensorflowJob.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypePS))
assert.NotNil(t, getReplicaCount(tensorflowJob.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeChief))
assert.NotNil(t, getReplicaCount(tensorflowJob.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeEval))
assert.NotNil(t, common.GetReplicaCount(tensorflowJob.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeWorker))
assert.NotNil(t, common.GetReplicaCount(tensorflowJob.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypePS))
assert.NotNil(t, common.GetReplicaCount(tensorflowJob.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeChief))
assert.NotNil(t, common.GetReplicaCount(tensorflowJob.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeEval))
}

func TestBuildResourceTensorFlow(t *testing.T) {
Expand Down
Loading