feat: Do not set TF_CONFIG for local training #1080

Merged · 2 commits · Sep 12, 2019
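Context: for distributed training the operator injects a TF_CONFIG environment variable into the TensorFlow container of each replica's pod, a JSON document with "cluster", "task", and "environment" fields (its exact shape appears in the expectedClusterSpec strings in pod_test.go below). A TFJob whose replica specs total a single replica has no cluster to describe, so with this change setClusterSpec returns early and the pod gets no TF_CONFIG at all; the motivation is tracked in kubeflow/tf-operator issue #1078, which the new isDistributed helper references.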
47 changes: 40 additions & 7 deletions pkg/controller.v1/tensorflow/pod.go
@@ -216,7 +216,12 @@ func (tc *TFController) createNewPod(tfjob *tfv1.TFJob, rt, index string, spec *
 	return nil
 }
 
+// setClusterSpec generates and sets TF_CONFIG for the given podTemplateSpec.
 func setClusterSpec(podTemplateSpec *v1.PodTemplateSpec, tfjob *tfv1.TFJob, rt, index string) error {
+	// Do not set TF_CONFIG for local training jobs.
+	if !isDistributed(tfjob) {
+		return nil
+	}
 	// Generate TF_CONFIG JSON string.
 	tfConfigStr, err := genTFConfigJSONStr(tfjob, rt, index)
 	if err != nil {
@@ -226,19 +231,47 @@ func setClusterSpec(podTemplateSpec *v1.PodTemplateSpec, tfjob *tfv1.TFJob, rt,
 	if tfConfigStr == "" {
 		return nil
 	}
-	// Add TF_CONFIG environment variable.
+	// Add TF_CONFIG environment variable to tensorflow container in the pod.
Review comment (Member): This patch might be needed in Pytorch also.

Reply (Member Author): I remember that you said it does not have a side effect in PyTorch. Do we need it?

 	for i := range podTemplateSpec.Spec.Containers {
-		if len(podTemplateSpec.Spec.Containers[i].Env) == 0 {
-			podTemplateSpec.Spec.Containers[i].Env = make([]v1.EnvVar, 0)
+		if podTemplateSpec.Spec.Containers[i].Name == tfv1.DefaultContainerName {
+			if len(podTemplateSpec.Spec.Containers[i].Env) == 0 {
+				podTemplateSpec.Spec.Containers[i].Env = make([]v1.EnvVar, 0)
+			}
+			podTemplateSpec.Spec.Containers[i].Env = append(podTemplateSpec.Spec.Containers[i].Env, v1.EnvVar{
+				Name: tfConfig,
+				Value: tfConfigStr,
+			})
+			break
 		}
-		podTemplateSpec.Spec.Containers[i].Env = append(podTemplateSpec.Spec.Containers[i].Env, v1.EnvVar{
-			Name: tfConfig,
-			Value: tfConfigStr,
-		})
 	}
 	return nil
 }

+// isDistributed returns if the TFJob is a distributed training job.
+// Ref https://github.com/kubeflow/tf-operator/issues/1078.
+func isDistributed(tfjob *tfv1.TFJob) bool {
+	replicas := tfjob.Spec.TFReplicaSpecs
+	distributionCount := 0
+	allTypes := []tfv1.TFReplicaType{
+		tfv1.TFReplicaTypeChief,
+		tfv1.TFReplicaTypeEval,
+		tfv1.TFReplicaTypeMaster,
+		tfv1.TFReplicaTypePS,
+		tfv1.TFReplicaTypeWorker,
+	}
+	// Check if there is only one replica.
+	for _, typ := range allTypes {
+		if replicas[typ] != nil {
+			if replicas[typ].Replicas == nil {
+				distributionCount++
+			} else {
+				distributionCount += int(*replicas[typ].Replicas)
+			}
+		}
+	}
+	return distributionCount != 1
+}
+
 func setRestartPolicy(podTemplateSpec *v1.PodTemplateSpec, spec *common.ReplicaSpec) {
 	if spec.RestartPolicy == common.RestartPolicyExitCode {
 		podTemplateSpec.Spec.RestartPolicy = v1.RestartPolicyNever
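To make the new check concrete, the following is a minimal, self-contained Go sketch of the counting rule that isDistributed applies: every declared replica spec contributes its replica count, a nil count is treated as one, and only a total other than exactly one counts as distributed. The replicaSpec type and the plain string-keyed map below are simplified stand-ins for the tfv1 API, not the operator's actual types.

package main

import "fmt"

// replicaSpec is a simplified stand-in for a TFJob replica spec; only the
// optional replica count matters for the distribution check.
type replicaSpec struct {
	Replicas *int32
}

// isDistributed mirrors the rule added in pod.go: sum the replicas of every
// declared replica type, counting a nil Replicas field as one, and treat any
// total other than exactly one as distributed training.
func isDistributed(replicas map[string]*replicaSpec) bool {
	count := 0
	for _, spec := range replicas {
		if spec == nil {
			continue
		}
		if spec.Replicas == nil {
			count++ // a defaulted (nil) replica count counts as a single replica
		} else {
			count += int(*spec.Replicas)
		}
	}
	return count != 1
}

func main() {
	one := int32(1)
	local := map[string]*replicaSpec{"Worker": {Replicas: &one}}
	distributed := map[string]*replicaSpec{
		"Worker": {Replicas: &one},
		"PS":     {Replicas: &one},
	}
	fmt.Println(isDistributed(local))       // false: a single worker trains locally, no TF_CONFIG
	fmt.Println(isDistributed(distributed)) // true: worker plus PS form a cluster, TF_CONFIG is set
}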
60 changes: 53 additions & 7 deletions pkg/controller.v1/tensorflow/pod_test.go
@@ -113,16 +113,14 @@ func TestClusterSpec(t *testing.T) {
 			rt: "worker",
 			index: "0",
 			customClusterDomain: "",
-			expectedClusterSpec: `{"cluster":{"worker":["` + testutil.TestTFJobName +
-				`-worker-0.ns0.svc:2222"]},"task":{"type":"worker","index":0},"environment":"cloud"}`,
+			expectedClusterSpec: "",
 		},
 		tc{
 			tfJob: testutil.NewTFJobWithNamespace(1, 0, "ns1"),
 			rt: "worker",
 			index: "0",
 			customClusterDomain: "tf.training.com",
-			expectedClusterSpec: `{"cluster":{"worker":["` + testutil.TestTFJobName +
-				`-worker-0.ns1.svc.tf.training.com:2222"]},"task":{"type":"worker","index":0},"environment":"cloud"}`,
+			expectedClusterSpec: "",
 		},
 		tc{
 			tfJob: testutil.NewTFJobWithNamespace(1, 1, "ns2"),
@@ -142,16 +140,64 @@
 				`-ps-0.ns3.svc.tf.training.io:2222"],"worker":["` + testutil.TestTFJobName +
 				`-worker-0.ns3.svc.tf.training.io:2222"]},"task":{"type":"worker","index":0},"environment":"cloud"}`,
 		},
+		tc{
+			tfJob: testutil.NewTFJobWithEvaluatorAndNamespace(1, 1, 1, "ns3"),
+			rt: "worker",
+			index: "0",
+			customClusterDomain: "",
+			expectedClusterSpec: `{"cluster":{"ps":["` + testutil.TestTFJobName +
+				`-ps-0.ns3.svc:2222"],"worker":["` + testutil.TestTFJobName +
+				`-worker-0.ns3.svc:2222"]},"task":{"type":"worker","index":0},"environment":"cloud"}`,
+		},
 	}
 	for _, c := range testCase {
 		os.Setenv(EnvCustomClusterDomain, c.customClusterDomain)
 		demoTemplateSpec := c.tfJob.Spec.TFReplicaSpecs[tfv1.TFReplicaTypeWorker].Template
 		if err := setClusterSpec(&demoTemplateSpec, c.tfJob, c.rt, c.index); err != nil {
 			t.Errorf("Failed to set cluster spec: %v", err)
 		}
-		actual := demoTemplateSpec.Spec.Containers[0].Env[0].Value
-		if c.expectedClusterSpec != actual {
-			t.Errorf("Expected %s, got %s", c.expectedClusterSpec, actual)
+		// The expected cluster spec is nil, which means that we should not set TF_CONFIG.
+		if c.expectedClusterSpec == "" {
+			if len(demoTemplateSpec.Spec.Containers[0].Env) != 0 {
+				t.Errorf("Expected empty TF_CONFIG, got %s",
+					demoTemplateSpec.Spec.Containers[0].Env[0].Value)
+			}
+		} else {
+			actual := demoTemplateSpec.Spec.Containers[0].Env[0].Value
+			if c.expectedClusterSpec != actual {
+				t.Errorf("Expected %s, got %s", c.expectedClusterSpec, actual)
+			}
 		}
 	}
 }

+func TestIsDistributed(t *testing.T) {
+	type tc struct {
+		tfJob *tfv1.TFJob
+		expected bool
+	}
+	testCase := []tc{
+		{
+			tfJob: testutil.NewTFJob(1, 0),
+			expected: false,
+		},
+		{
+			tfJob: testutil.NewTFJob(1, 1),
+			expected: true,
+		},
+		{
+			tfJob: testutil.NewTFJob(0, 1),
+			expected: false,
+		},
+		{
+			tfJob: testutil.NewTFJobWithChief(1, 0),
+			expected: true,
+		},
+	}
+	for _, c := range testCase {
+		actual := isDistributed(c.tfJob)
+		if actual != c.expected {
+			t.Errorf("Expected %t, got %t", c.expected, actual)
+		}
+	}
+}
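Read together, the TestIsDistributed cases assert the counting rule above: a job whose replica specs total exactly one replica is treated as local, while any combination totaling more than one (for example a worker plus a PS, or a chief plus a worker) is treated as distributed.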