Skip to content

Commit

Permalink
feat: Do not set TF_CONFIG for local training (#1080)
Browse files Browse the repository at this point in the history
* feat: Do not set TF_CONFIG for local training

Signed-off-by: Ce Gao <[email protected]>

* fix: Fix lint

Signed-off-by: Ce Gao <[email protected]>
  • Loading branch information
gaocegege authored and k8s-ci-robot committed Sep 12, 2019
1 parent 5c0a06b commit b96dcd7
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 14 deletions.
47 changes: 40 additions & 7 deletions pkg/controller.v1/tensorflow/pod.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,12 @@ func (tc *TFController) createNewPod(tfjob *tfv1.TFJob, rt, index string, spec *
return nil
}

// setClusterSpec generates and sets TF_CONFIG for the given podTemplateSpec.
func setClusterSpec(podTemplateSpec *v1.PodTemplateSpec, tfjob *tfv1.TFJob, rt, index string) error {
// Do not set TF_CONFIG for local training jobs.
if !isDistributed(tfjob) {
return nil
}
// Generate TF_CONFIG JSON string.
tfConfigStr, err := genTFConfigJSONStr(tfjob, rt, index)
if err != nil {
Expand All @@ -226,19 +231,47 @@ func setClusterSpec(podTemplateSpec *v1.PodTemplateSpec, tfjob *tfv1.TFJob, rt,
if tfConfigStr == "" {
return nil
}
// Add TF_CONFIG environment variable.
// Add TF_CONFIG environment variable to tensorflow container in the pod.
for i := range podTemplateSpec.Spec.Containers {
if len(podTemplateSpec.Spec.Containers[i].Env) == 0 {
podTemplateSpec.Spec.Containers[i].Env = make([]v1.EnvVar, 0)
if podTemplateSpec.Spec.Containers[i].Name == tfv1.DefaultContainerName {
if len(podTemplateSpec.Spec.Containers[i].Env) == 0 {
podTemplateSpec.Spec.Containers[i].Env = make([]v1.EnvVar, 0)
}
podTemplateSpec.Spec.Containers[i].Env = append(podTemplateSpec.Spec.Containers[i].Env, v1.EnvVar{
Name: tfConfig,
Value: tfConfigStr,
})
break
}
podTemplateSpec.Spec.Containers[i].Env = append(podTemplateSpec.Spec.Containers[i].Env, v1.EnvVar{
Name: tfConfig,
Value: tfConfigStr,
})
}
return nil
}

// isDistributed reports whether the TFJob is treated as a distributed
// training job.
//
// It adds up the replica counts of the Chief, Eval, Master, PS and Worker
// replica specs that are present on the job. A spec that is present but has
// no explicit Replicas value contributes one. The job is reported as
// non-distributed only when the total is exactly one; any other total
// (including zero) yields true.
// Ref https://github.com/kubeflow/tf-operator/issues/1078.
func isDistributed(tfjob *tfv1.TFJob) bool {
	specs := tfjob.Spec.TFReplicaSpecs
	knownTypes := []tfv1.TFReplicaType{
		tfv1.TFReplicaTypeChief,
		tfv1.TFReplicaTypeEval,
		tfv1.TFReplicaTypeMaster,
		tfv1.TFReplicaTypePS,
		tfv1.TFReplicaTypeWorker,
	}

	total := 0
	for _, rtype := range knownTypes {
		spec := specs[rtype]
		if spec == nil {
			// This replica type is not part of the job.
			continue
		}
		if spec.Replicas == nil {
			// No explicit count: treat the spec as a single replica.
			total++
			continue
		}
		total += int(*spec.Replicas)
	}

	// Exactly one replica overall means local training.
	return total != 1
}

func setRestartPolicy(podTemplateSpec *v1.PodTemplateSpec, spec *common.ReplicaSpec) {
if spec.RestartPolicy == common.RestartPolicyExitCode {
podTemplateSpec.Spec.RestartPolicy = v1.RestartPolicyNever
Expand Down
60 changes: 53 additions & 7 deletions pkg/controller.v1/tensorflow/pod_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,16 +113,14 @@ func TestClusterSpec(t *testing.T) {
rt: "worker",
index: "0",
customClusterDomain: "",
expectedClusterSpec: `{"cluster":{"worker":["` + testutil.TestTFJobName +
`-worker-0.ns0.svc:2222"]},"task":{"type":"worker","index":0},"environment":"cloud"}`,
expectedClusterSpec: "",
},
tc{
tfJob: testutil.NewTFJobWithNamespace(1, 0, "ns1"),
rt: "worker",
index: "0",
customClusterDomain: "tf.training.com",
expectedClusterSpec: `{"cluster":{"worker":["` + testutil.TestTFJobName +
`-worker-0.ns1.svc.tf.training.com:2222"]},"task":{"type":"worker","index":0},"environment":"cloud"}`,
expectedClusterSpec: "",
},
tc{
tfJob: testutil.NewTFJobWithNamespace(1, 1, "ns2"),
Expand All @@ -142,16 +140,64 @@ func TestClusterSpec(t *testing.T) {
`-ps-0.ns3.svc.tf.training.io:2222"],"worker":["` + testutil.TestTFJobName +
`-worker-0.ns3.svc.tf.training.io:2222"]},"task":{"type":"worker","index":0},"environment":"cloud"}`,
},
tc{
tfJob: testutil.NewTFJobWithEvaluatorAndNamespace(1, 1, 1, "ns3"),
rt: "worker",
index: "0",
customClusterDomain: "",
expectedClusterSpec: `{"cluster":{"ps":["` + testutil.TestTFJobName +
`-ps-0.ns3.svc:2222"],"worker":["` + testutil.TestTFJobName +
`-worker-0.ns3.svc:2222"]},"task":{"type":"worker","index":0},"environment":"cloud"}`,
},
}
for _, c := range testCase {
os.Setenv(EnvCustomClusterDomain, c.customClusterDomain)
demoTemplateSpec := c.tfJob.Spec.TFReplicaSpecs[tfv1.TFReplicaTypeWorker].Template
if err := setClusterSpec(&demoTemplateSpec, c.tfJob, c.rt, c.index); err != nil {
t.Errorf("Failed to set cluster spec: %v", err)
}
actual := demoTemplateSpec.Spec.Containers[0].Env[0].Value
if c.expectedClusterSpec != actual {
t.Errorf("Expected %s, got %s", c.expectedClusterSpec, actual)
// The expected cluster spec is the empty string, which means that TF_CONFIG should not be set.
if c.expectedClusterSpec == "" {
if len(demoTemplateSpec.Spec.Containers[0].Env) != 0 {
t.Errorf("Expected empty TF_CONFIG, got %s",
demoTemplateSpec.Spec.Containers[0].Env[0].Value)
}
} else {
actual := demoTemplateSpec.Spec.Containers[0].Env[0].Value
if c.expectedClusterSpec != actual {
t.Errorf("Expected %s, got %s", c.expectedClusterSpec, actual)
}
}
}
}

// TestIsDistributed checks the result of isDistributed for TFJobs built by
// the testutil constructors with different replica counts.
// NOTE(review): the two NewTFJob arguments are presumably the worker and PS
// replica counts, and NewTFJobWithChief presumably adds a chief replica on
// top of those; confirm against testutil.
func TestIsDistributed(t *testing.T) {
	cases := []struct {
		tfJob    *tfv1.TFJob
		expected bool
	}{
		{tfJob: testutil.NewTFJob(1, 0), expected: false},
		{tfJob: testutil.NewTFJob(1, 1), expected: true},
		{tfJob: testutil.NewTFJob(0, 1), expected: false},
		{tfJob: testutil.NewTFJobWithChief(1, 0), expected: true},
	}

	for i := range cases {
		want := cases[i].expected
		if got := isDistributed(cases[i].tfJob); got != want {
			t.Errorf("Expected %t, got %t", want, got)
		}
	}
}
Expand Down

0 comments on commit b96dcd7

Please sign in to comment.