diff --git a/pkg/spec/tf_job.go b/pkg/spec/tf_job.go
index a4fd7a415a..66bf160f24 100644
--- a/pkg/spec/tf_job.go
+++ b/pkg/spec/tf_job.go
@@ -127,7 +127,11 @@ type ChiefSpec struct {
 
 // Validate checks that the TfJobSpec is valid.
 func (c *TfJobSpec) Validate() error {
+	if c.TerminationPolicy == nil || c.TerminationPolicy.Chief == nil {
+		return fmt.Errorf("invalid termination policy: %v", c.TerminationPolicy)
+	}
 	// Check that each replica has a TensorFlow container.
+	chiefExists := false
 	for _, r := range c.ReplicaSpecs {
 		found := false
 		if r.Template == nil && r.TfReplicaType != PS {
@@ -138,6 +142,10 @@ func (c *TfJobSpec) Validate() error {
 			return errors.New("The MASTER must have Replicas = 1")
 		}
 
+		if r.TfReplicaType == TfReplicaType(c.TerminationPolicy.Chief.ReplicaName) {
+			chiefExists = true
+		}
+
 		if r.TfPort == nil {
 			return errors.New("tfReplicaSpec.TfPort can't be nil.")
 		}
@@ -167,14 +175,11 @@ func (c *TfJobSpec) Validate() error {
 			return fmt.Errorf("Replica type %v is missing a container named %v", r.TfReplicaType, TENSORFLOW)
 		}
 	}
-	if c.TerminationPolicy != nil {
-		if c.TerminationPolicy.Chief == nil {
-			return errors.New("invalid termination policy, Chief cannot be nil")
-		}
-		if c.TerminationPolicy.Chief.ReplicaName != "MASTER" || c.TerminationPolicy.Chief.ReplicaIndex != 0 {
-			return errors.New("invalid termination policy, Chief should have replicaName=MASTER and index=0")
-		}
+
+	if !chiefExists {
+		return fmt.Errorf("Missing ReplicaSpec for chief: %v", c.TerminationPolicy.Chief.ReplicaName)
 	}
+
 	return nil
 }
 
diff --git a/pkg/spec/tf_job_test.go b/pkg/spec/tf_job_test.go
index 4601646986..18ef0c3786 100644
--- a/pkg/spec/tf_job_test.go
+++ b/pkg/spec/tf_job_test.go
@@ -383,3 +383,89 @@ func TestSetDefaults(t *testing.T) {
 		})
 	}
 }
+
+func TestValidate(t *testing.T) {
+	type testCase struct {
+		in             *TfJobSpec
+		expectingError bool
+	}
+
+	testCases := []testCase{
+		{
+			in: &TfJobSpec{
+				ReplicaSpecs: []*TfReplicaSpec{
+					{
+						Template: &v1.PodTemplateSpec{
+							Spec: v1.PodSpec{
+								Containers: []v1.Container{
+									{
+										Name: "tensorflow",
+									},
+								},
+							},
+						},
+						TfReplicaType: MASTER,
+						Replicas:      proto.Int32(1),
+					},
+				},
+				TfImage: "tensorflow/tensorflow:1.3.0",
+			},
+			expectingError: false,
+		},
+		{
+			in: &TfJobSpec{
+				ReplicaSpecs: []*TfReplicaSpec{
+					{
+						Template: &v1.PodTemplateSpec{
+							Spec: v1.PodSpec{
+								Containers: []v1.Container{
+									{
+										Name: "tensorflow",
+									},
+								},
+							},
+						},
+						TfReplicaType: WORKER,
+						Replicas:      proto.Int32(1),
+					},
+				},
+				TfImage: "tensorflow/tensorflow:1.3.0",
+			},
+			expectingError: true,
+		},
+		{
+			in: &TfJobSpec{
+				ReplicaSpecs: []*TfReplicaSpec{
+					{
+						Template: &v1.PodTemplateSpec{
+							Spec: v1.PodSpec{
+								Containers: []v1.Container{
+									{
+										Name: "tensorflow",
+									},
+								},
+							},
+						},
+						TfReplicaType: WORKER,
+						Replicas:      proto.Int32(1),
+					},
+				},
+				TfImage: "tensorflow/tensorflow:1.3.0",
+				TerminationPolicy: &TerminationPolicySpec{
+					Chief: &ChiefSpec{
+						ReplicaName:  "WORKER",
+						ReplicaIndex: 0,
+					},
+				},
+			},
+			expectingError: false,
+		},
+	}
+
+	for _, c := range testCases {
+		c.in.SetDefaults("")
+		if err := c.in.Validate(); (err != nil) != c.expectingError {
+			t.Errorf("unexpected validation result: %v", err)
+		}
+	}
+}
diff --git a/pkg/trainer/replicas.go b/pkg/trainer/replicas.go
index 5e552a321b..7b46748ad2 100644
--- a/pkg/trainer/replicas.go
+++ b/pkg/trainer/replicas.go
@@ -406,6 +406,40 @@ func replicaStatusFromPodList(l v1.PodList, name spec.ContainerName) spec.Replic
 	return spec.ReplicaStateUnknown
 }
 
+func (s *TFReplicaSet) GetSingleReplicaStatus(index int32) (spec.ReplicaState) {
+	j, err := s.ClientSet.BatchV1().Jobs(s.Job.job.Metadata.Namespace).Get(s.jobName(index), meta_v1.GetOptions{})
+
+	if err != nil {
+		return spec.ReplicaStateUnknown
+	}
+
+	if j.Status.Succeeded >= 1 {
+		return spec.ReplicaStateSucceeded
+	}
+
+	labels := s.Labels()
+	labels["task_index"] = fmt.Sprintf("%v", index)
+	selector, err := labels.ToSelector()
+	if err != nil {
+		log.Errorf("labels.ToSelector() error; %v", err)
+		return spec.ReplicaStateFailed
+	}
+
+	// TODO(jlewi): Handle errors. We need to get the pod and looking at recent container exits.
+	l, err := s.ClientSet.CoreV1().Pods(s.Job.job.Metadata.Namespace).List(meta_v1.ListOptions{
+		// TODO(jlewi): Why isn't the label selector working?
+		LabelSelector: selector,
+	})
+
+	if err != nil {
+		// TODO(jlewi): Are there errors that should be treated as retryable errors?
+		return spec.ReplicaStateFailed
+	}
+
+	status := replicaStatusFromPodList(*l, spec.TENSORFLOW)
+	return status
+}
+
 // Status returns the status of the replica set.
 func (s *TFReplicaSet) GetStatus() (spec.TfReplicaStatus, error) {
 
@@ -425,42 +459,7 @@ func (s *TFReplicaSet) GetStatus() (spec.TfReplicaStatus, error) {
 	}
 
 	for index := int32(0); index < *s.Spec.Replicas; index++ {
-
-		j, err := s.ClientSet.BatchV1().Jobs(s.Job.job.Metadata.Namespace).Get(s.jobName(index), meta_v1.GetOptions{})
-
-		if err != nil {
-			increment(spec.ReplicaStateUnknown)
-			continue
-		}
-
-		if j.Status.Succeeded >= 1 {
-			increment(spec.ReplicaStateSucceeded)
-			continue
-		}
-
-		labels := s.Labels()
-		labels["task_index"] = fmt.Sprintf("%v", index)
-		selector, err := labels.ToSelector()
-		if err != nil {
-			log.Errorf("labels.ToSelector() error; %v", err)
-			increment(spec.ReplicaStateFailed)
-			continue
-		}
-
-		// TODO(jlewi): Handle errors. We need to get the pod and looking at recent container exits.
-		l, err := s.ClientSet.CoreV1().Pods(s.Job.job.Metadata.Namespace).List(meta_v1.ListOptions{
-			// TODO(jlewi): Why isn't the label selector working?
-			LabelSelector: selector,
-		})
-
-		if err != nil {
-			// TODO(jlewi): Are there errors that should be treated as retryable errors?
-			increment(spec.ReplicaStateFailed)
-			continue
-		}
-
-		status := replicaStatusFromPodList(*l, spec.TENSORFLOW)
-		increment(status)
+		increment(s.GetSingleReplicaStatus(index))
 	}
 
 	// Determine the overall status for the replica set based on the status of the individual
diff --git a/pkg/trainer/training.go b/pkg/trainer/training.go
index 40a1cd5bfe..c5fcdfb39b 100644
--- a/pkg/trainer/training.go
+++ b/pkg/trainer/training.go
@@ -160,8 +160,9 @@ func (j *TrainingJob) deleteResources() error {
 	return nil
 }
 
-func (j *TrainingJob) GetStatus() (spec.State, []*spec.TfReplicaStatus, error) {
-	state := spec.StateUnknown
+func (j *TrainingJob) GetStatus() (spec.ReplicaState, []*spec.TfReplicaStatus, error) {
+	chief := j.job.Spec.TerminationPolicy.Chief
+	chiefState := spec.ReplicaStateUnknown
 	replicaStatuses := make([]*spec.TfReplicaStatus, 0)
 
 	// The state for each replica.
@@ -178,24 +179,12 @@ func (j *TrainingJob) GetStatus() (spec.State, []*spec.TfReplicaStatus, error) {
 
 		replicaStatuses = append(replicaStatuses, &rStatus)
 
-		// If any replicas are failed mark job as failed.
-		if rStatus.State == spec.ReplicaStateFailed {
-			state = spec.StateFailed
+		if string(r.Spec.TfReplicaType) == string(chief.ReplicaName) {
+			chiefState = r.GetSingleReplicaStatus(int32(chief.ReplicaIndex))
 		}
 	}
 
-	if v, ok := replicaSetStates[spec.MASTER]; ok && v == spec.ReplicaStateSucceeded {
-		state = spec.StateSucceeded
-		return state, replicaStatuses, nil
-	}
-
-	if v, ok := replicaSetStates[spec.MASTER]; ok && v == spec.ReplicaStateFailed {
-		state = spec.StateFailed
-		return state, replicaStatuses, nil
-	}
-
-	state = spec.StateRunning
-	return state, replicaStatuses, nil
+	return chiefState, replicaStatuses, nil
 }
 
 // isRetryableTerminationState returns true if a container terminated in a state
@@ -373,11 +362,11 @@ func (j *TrainingJob) reconcile(config *spec.ControllerConfig) {
 			log.Errorf("GetStatus() for job %v returned error: %v", j.job.Metadata.Name, err)
 		}
 		// TODO(jlewi): We should update the Phase if we detect the job is done.
-		if state == spec.StateFailed {
+		if state == spec.ReplicaStateFailed {
 			log.Errorf("Master failed Job: %v.", j.job.Metadata.Name)
 			j.status.SetPhase(spec.TfJobPhaseDone)
 			j.status.SetState(spec.StateFailed)
-		} else if state == spec.StateSucceeded {
+		} else if state == spec.ReplicaStateSucceeded {
 			log.Infof("Master succeeded Job: %v.", j.job.Metadata.Name)
 			j.status.SetPhase(spec.TfJobPhaseDone)
 			j.status.SetState(spec.StateSucceeded)
diff --git a/pkg/trainer/training_test.go b/pkg/trainer/training_test.go
index f98b717c71..b4cf74ffab 100644
--- a/pkg/trainer/training_test.go
+++ b/pkg/trainer/training_test.go
@@ -202,6 +202,20 @@ func TestJobSetup(t *testing.T) {
 							},
 							TfReplicaType: spec.PS,
 						},
+						{
+							Replicas: proto.Int32(1),
+							TfPort:   proto.Int32(10),
+							Template: &v1.PodTemplateSpec{
+								Spec: v1.PodSpec{
+									Containers: []v1.Container{
+										{
+											Name: "tensorflow",
+										},
+									},
+								},
+							},
+							TfReplicaType: spec.MASTER,
+						},
 					},
 				},
 			},
@@ -232,6 +246,25 @@ func TestJobSetup(t *testing.T) {
 							},
 							TfReplicaType: spec.PS,
 						},
+						{
+							Replicas: proto.Int32(1),
+							TfPort:   proto.Int32(10),
+							Template: &v1.PodTemplateSpec{
+								Spec: v1.PodSpec{
+									Containers: []v1.Container{
+										{
+											Name: "tensorflow",
+											Resources: v1.ResourceRequirements{
+												Requests: map[v1.ResourceName]resource.Quantity{
+													"nvidia-gpu": resource.MustParse("1"),
+												},
+											},
+										},
+									},
+								},
+							},
+							TfReplicaType: spec.MASTER,
+						},
 					},
 				},
 			},
@@ -263,6 +296,25 @@ func TestJobSetup(t *testing.T) {
 							},
 							TfReplicaType: spec.PS,
 						},
+						{
+							Replicas: proto.Int32(1),
+							TfPort:   proto.Int32(10),
+							Template: &v1.PodTemplateSpec{
+								Spec: v1.PodSpec{
+									Containers: []v1.Container{
+										{
+											Name: "tensorflow",
+											Resources: v1.ResourceRequirements{
+												Requests: map[v1.ResourceName]resource.Quantity{
+													"nvidia-gpu": resource.MustParse("1"),
+												},
+											},
+										},
+									},
+								},
+							},
+							TfReplicaType: spec.MASTER,
+						},
 					},
 					TensorBoard: &spec.TensorBoardSpec{},
 				},