diff --git a/pkg/apis/tensorflow/v1alpha2/defaults.go b/pkg/apis/tensorflow/v1alpha2/defaults.go index 75955dc95f..0f29c25c43 100644 --- a/pkg/apis/tensorflow/v1alpha2/defaults.go +++ b/pkg/apis/tensorflow/v1alpha2/defaults.go @@ -15,6 +15,8 @@ package v1alpha2 import ( + "strings" + "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/runtime" ) @@ -60,8 +62,30 @@ func setDefaultReplicas(spec *TFReplicaSpec) { } } +// setTypeNamesToCamelCase sets the name of all replica types from any case to correct case. +func setTypeNamesToCamelCase(tfJob *TFJob) { + setTypeNameToCamelCase(tfJob, TFReplicaTypePS) + setTypeNameToCamelCase(tfJob, TFReplicaTypeWorker) + setTypeNameToCamelCase(tfJob, TFReplicaTypeChief) + setTypeNameToCamelCase(tfJob, TFReplicaTypeEval) +} + +// setTypeNameToCamelCase sets the name of the replica type from any case to correct case. +// E.g. from ps to PS; from WORKER to Worker. +func setTypeNameToCamelCase(tfJob *TFJob, typ TFReplicaType) { + for t := range tfJob.Spec.TFReplicaSpecs { + if strings.ToLower(string(t)) == strings.ToLower(string(typ)) && t != typ { + spec := tfJob.Spec.TFReplicaSpecs[t] + delete(tfJob.Spec.TFReplicaSpecs, t) + tfJob.Spec.TFReplicaSpecs[typ] = spec + return + } + } +} + // SetDefaults_TFJob sets any unspecified values to defaults. func SetDefaults_TFJob(tfjob *TFJob) { + setTypeNamesToCamelCase(tfjob) for _, spec := range tfjob.Spec.TFReplicaSpecs { setDefaultReplicas(spec) setDefaultPort(&spec.Template.Spec) diff --git a/pkg/apis/tensorflow/v1alpha2/defaults_test.go b/pkg/apis/tensorflow/v1alpha2/defaults_test.go index 4f2e70e325..e8898c9941 100644 --- a/pkg/apis/tensorflow/v1alpha2/defaults_test.go +++ b/pkg/apis/tensorflow/v1alpha2/defaults_test.go @@ -56,6 +56,45 @@ func expectedTFJob() *TFJob { } } +func TestSetTypeNames(t *testing.T) { + spec := &TFReplicaSpec{ + RestartPolicy: RestartPolicyAlways, + Template: v1.PodTemplateSpec{ + Spec: v1.PodSpec{ + Containers: []v1.Container{ + v1.Container{ + Name: DefaultContainerName, + Image: testImage, + Ports: []v1.ContainerPort{ + v1.ContainerPort{ + Name: DefaultPortName, + ContainerPort: DefaultPort, + }, + }, + }, + }, + }, + }, + } + + workerUpperCase := TFReplicaType("WORKER") + original := &TFJob{ + Spec: TFJobSpec{ + TFReplicaSpecs: map[TFReplicaType]*TFReplicaSpec{ + workerUpperCase: spec, + }, + }, + } + + setTypeNamesToCamelCase(original) + if _, ok := original.Spec.TFReplicaSpecs[workerUpperCase]; ok { + t.Errorf("Failed to delete key %s", workerUpperCase) + } + if _, ok := original.Spec.TFReplicaSpecs[TFReplicaTypeWorker]; !ok { + t.Errorf("Failed to set key %s", TFReplicaTypeWorker) + } +} + func TestSetDefaultTFJob(t *testing.T) { testCases := map[string]struct { original *TFJob