Skip to content

Commit

Permalink
Add a field SchedulerName to TFJob for specifying a scheduler (#408)
Browse files Browse the repository at this point in the history
This commit adds a new field SchedulerName to the definition of TFJob.
The purpose of the field is to specify the scheduler name of the pods
created by tf-operator, so that the named scheduler (typically not the
default scheduler) handles them. This is convenient for letting
kube-batchd (a component of kube-arbitrator) handle the pods.
  • Loading branch information
mitake authored and gaocegege committed Mar 7, 2018
1 parent 997c583 commit a4b8031
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 0 deletions.
3 changes: 3 additions & 0 deletions pkg/apis/tensorflow/v1alpha1/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ type TFJobSpec struct {

// TerminationPolicy specifies the condition that the tfjob should be considered finished.
TerminationPolicy *TerminationPolicySpec `json:"terminationPolicy,omitempty"`

// SchedulerName specifies the name of the scheduler that should handle the pods of the TFJob.
SchedulerName string `json:"schedulerName,omitempty"`
}

type TerminationPolicySpec struct {
Expand Down
2 changes: 2 additions & 0 deletions pkg/trainer/replicas.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ func (s *TFReplicaSet) CreatePodWithIndex(index int32) (*v1.Pod, error) {
Spec: *s.Spec.Template.Spec.DeepCopy(),
}

pod.Spec.SchedulerName = s.Job.SchedulerName()

// Configure the TFCONFIG environment variable.
tfConfig := TFConfig{
Cluster: s.Job.ClusterSpec(),
Expand Down
8 changes: 8 additions & 0 deletions pkg/trainer/replicas_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"encoding/json"
"fmt"
"reflect"
"strings"
"testing"
"time"

Expand All @@ -44,6 +45,8 @@ var (
func TestTFReplicaSet(t *testing.T) {
clientSet := fake.NewSimpleClientset()

testSchedulerName := "test-scheduler"

jobSpec := &tfv1alpha1.TFJob{
ObjectMeta: meta_v1.ObjectMeta{
Name: "some-job",
Expand All @@ -67,6 +70,7 @@ func TestTFReplicaSet(t *testing.T) {
TFReplicaType: tfv1alpha1.PS,
},
},
SchedulerName: testSchedulerName,
},
}

Expand Down Expand Up @@ -169,6 +173,10 @@ func TestTFReplicaSet(t *testing.T) {
t.Fatalf("Expected 1 environment variable got %v", len(c.Env))
}

if strings.Compare(p.Spec.SchedulerName, testSchedulerName) != 0 {
t.Fatalf("p.Spec.Template.Spec.SchedulerName; Got %v; want %v", p.Spec.SchedulerName, testSchedulerName)
}

actualTFConfig := &TFConfig{}
if err := json.Unmarshal([]byte(c.Env[0].Value), actualTFConfig); err != nil {
t.Fatalf("Could not unmarshal TFConfig %v", err)
Expand Down
4 changes: 4 additions & 0 deletions pkg/trainer/training.go
Original file line number Diff line number Diff line change
Expand Up @@ -412,3 +412,7 @@ func (j *TrainingJob) name() string {
// fullname returns the job's namespace and name joined by a colon,
// i.e. "<namespace>:<name>".
func (j *TrainingJob) fullname() string {
	meta := &j.job.ObjectMeta
	namespace := meta.GetNamespace()
	name := meta.GetName()
	return namespace + ":" + name
}

// SchedulerName returns the scheduler name stored in the job's spec.
// The value is returned as-is; it is the empty string when the spec
// does not set one.
func (j *TrainingJob) SchedulerName() string {
	spec := &j.job.Spec
	return spec.SchedulerName
}

0 comments on commit a4b8031

Please sign in to comment.