diff --git a/pkg/controller.v1/tensorflow/pod.go b/pkg/controller.v1/tensorflow/pod.go index 5df7f6c02e..c60539f341 100644 --- a/pkg/controller.v1/tensorflow/pod.go +++ b/pkg/controller.v1/tensorflow/pod.go @@ -216,7 +216,12 @@ func (tc *TFController) createNewPod(tfjob *tfv1.TFJob, rt, index string, spec * return nil } +// setClusterSpec generates and sets TF_CONFIG for the given podTemplateSpec. func setClusterSpec(podTemplateSpec *v1.PodTemplateSpec, tfjob *tfv1.TFJob, rt, index string) error { + // Do not set TF_CONFIG for local training jobs. + if !isDistributed(tfjob) { + return nil + } // Generate TF_CONFIG JSON string. tfConfigStr, err := genTFConfigJSONStr(tfjob, rt, index) if err != nil { @@ -226,19 +231,47 @@ func setClusterSpec(podTemplateSpec *v1.PodTemplateSpec, tfjob *tfv1.TFJob, rt, if tfConfigStr == "" { return nil } - // Add TF_CONFIG environment variable. + // Add TF_CONFIG environment variable to the tensorflow container in the pod. for i := range podTemplateSpec.Spec.Containers { - if len(podTemplateSpec.Spec.Containers[i].Env) == 0 { - podTemplateSpec.Spec.Containers[i].Env = make([]v1.EnvVar, 0) + if podTemplateSpec.Spec.Containers[i].Name == tfv1.DefaultContainerName { + if len(podTemplateSpec.Spec.Containers[i].Env) == 0 { + podTemplateSpec.Spec.Containers[i].Env = make([]v1.EnvVar, 0) + } + podTemplateSpec.Spec.Containers[i].Env = append(podTemplateSpec.Spec.Containers[i].Env, v1.EnvVar{ + Name: tfConfig, + Value: tfConfigStr, + }) + break } - podTemplateSpec.Spec.Containers[i].Env = append(podTemplateSpec.Spec.Containers[i].Env, v1.EnvVar{ - Name: tfConfig, - Value: tfConfigStr, - }) } return nil } +// isDistributed reports whether the TFJob is a distributed training job. +// Ref https://github.com/kubeflow/tf-operator/issues/1078. 
+func isDistributed(tfjob *tfv1.TFJob) bool { + replicas := tfjob.Spec.TFReplicaSpecs + distributionCount := 0 + allTypes := []tfv1.TFReplicaType{ + tfv1.TFReplicaTypeChief, + tfv1.TFReplicaTypeEval, + tfv1.TFReplicaTypeMaster, + tfv1.TFReplicaTypePS, + tfv1.TFReplicaTypeWorker, + } + // Count the replicas across all replica types; a total of exactly one means the job is not distributed. + for _, typ := range allTypes { + if replicas[typ] != nil { + if replicas[typ].Replicas == nil { + distributionCount++ + } else { + distributionCount += int(*replicas[typ].Replicas) + } + } + } + return distributionCount != 1 +} + func setRestartPolicy(podTemplateSpec *v1.PodTemplateSpec, spec *common.ReplicaSpec) { if spec.RestartPolicy == common.RestartPolicyExitCode { podTemplateSpec.Spec.RestartPolicy = v1.RestartPolicyNever diff --git a/pkg/controller.v1/tensorflow/pod_test.go b/pkg/controller.v1/tensorflow/pod_test.go index cd86403a77..bc9e4c43bd 100644 --- a/pkg/controller.v1/tensorflow/pod_test.go +++ b/pkg/controller.v1/tensorflow/pod_test.go @@ -113,16 +113,14 @@ func TestClusterSpec(t *testing.T) { rt: "worker", index: "0", customClusterDomain: "", - expectedClusterSpec: `{"cluster":{"worker":["` + testutil.TestTFJobName + - `-worker-0.ns0.svc:2222"]},"task":{"type":"worker","index":0},"environment":"cloud"}`, + expectedClusterSpec: "", }, tc{ tfJob: testutil.NewTFJobWithNamespace(1, 0, "ns1"), rt: "worker", index: "0", customClusterDomain: "tf.training.com", - expectedClusterSpec: `{"cluster":{"worker":["` + testutil.TestTFJobName + - `-worker-0.ns1.svc.tf.training.com:2222"]},"task":{"type":"worker","index":0},"environment":"cloud"}`, + expectedClusterSpec: "", }, tc{ tfJob: testutil.NewTFJobWithNamespace(1, 1, "ns2"), @@ -142,6 +140,15 @@ func TestClusterSpec(t *testing.T) { `-ps-0.ns3.svc.tf.training.io:2222"],"worker":["` + testutil.TestTFJobName + `-worker-0.ns3.svc.tf.training.io:2222"]},"task":{"type":"worker","index":0},"environment":"cloud"}`, }, + tc{ + tfJob: testutil.NewTFJobWithEvaluatorAndNamespace(1, 1, 1, "ns3"), 
+ rt: "worker", + index: "0", + customClusterDomain: "", + expectedClusterSpec: `{"cluster":{"ps":["` + testutil.TestTFJobName + + `-ps-0.ns3.svc:2222"],"worker":["` + testutil.TestTFJobName + + `-worker-0.ns3.svc:2222"]},"task":{"type":"worker","index":0},"environment":"cloud"}`, + }, } for _, c := range testCase { os.Setenv(EnvCustomClusterDomain, c.customClusterDomain) @@ -149,9 +156,48 @@ func TestClusterSpec(t *testing.T) { if err := setClusterSpec(&demoTemplateSpec, c.tfJob, c.rt, c.index); err != nil { t.Errorf("Failed to set cluster spec: %v", err) } - actual := demoTemplateSpec.Spec.Containers[0].Env[0].Value - if c.expectedClusterSpec != actual { - t.Errorf("Expected %s, got %s", c.expectedClusterSpec, actual) + // The expected cluster spec is empty, which means that we should not set TF_CONFIG. + if c.expectedClusterSpec == "" { + if len(demoTemplateSpec.Spec.Containers[0].Env) != 0 { + t.Errorf("Expected empty TF_CONFIG, got %s", + demoTemplateSpec.Spec.Containers[0].Env[0].Value) + } + } else { + actual := demoTemplateSpec.Spec.Containers[0].Env[0].Value + if c.expectedClusterSpec != actual { + t.Errorf("Expected %s, got %s", c.expectedClusterSpec, actual) + } + } + } +} + +func TestIsDistributed(t *testing.T) { + type tc struct { + tfJob *tfv1.TFJob + expected bool + } + testCase := []tc{ + { + tfJob: testutil.NewTFJob(1, 0), + expected: false, + }, + { + tfJob: testutil.NewTFJob(1, 1), + expected: true, + }, + { + tfJob: testutil.NewTFJob(0, 1), + expected: false, + }, + { + tfJob: testutil.NewTFJobWithChief(1, 0), + expected: true, + }, + } + for _, c := range testCase { + actual := isDistributed(c.tfJob) + if actual != c.expected { + t.Errorf("Expected %t, got %t", c.expected, actual) } } }