allow using WORKER:0 as chief #221

Merged: 19 commits, Dec 20, 2017
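
For context, a hypothetical sketch (not part of this PR) of what the change enables, assembled from the new test case in pkg/spec/tf_job_test.go below: a TfJobSpec whose termination policy names WORKER index 0 as the chief now passes Validate(), where previously only MASTER index 0 was accepted. The Example function name and the fmt call are assumptions; the types and field names come from the diff.

func ExampleWorkerAsChief() {
  s := &TfJobSpec{
    ReplicaSpecs: []*TfReplicaSpec{
      {
        Template: &v1.PodTemplateSpec{
          Spec: v1.PodSpec{
            Containers: []v1.Container{
              {Name: "tensorflow"},
            },
          },
        },
        TfReplicaType: WORKER,
        Replicas: proto.Int32(1),
      },
    },
    TfImage: "tensorflow/tensorflow:1.3.0",
    // The chief no longer has to be MASTER; it only has to match one of the
    // declared replica types.
    TerminationPolicy: &TerminationPolicySpec{
      Chief: &ChiefSpec{
        ReplicaName: "WORKER",
        ReplicaIndex: 0,
      },
    },
  }
  s.SetDefaults("")
  fmt.Println(s.Validate()) // prints <nil> once WORKER:0 is a valid chief
}
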
19 changes: 12 additions & 7 deletions pkg/spec/tf_job.go
@@ -127,7 +127,11 @@ type ChiefSpec struct {

// Validate checks that the TfJobSpec is valid.
func (c *TfJobSpec) Validate() error {
  if c.TerminationPolicy == nil || c.TerminationPolicy.Chief == nil {
    return fmt.Errorf("invalid termination policy: %v", c.TerminationPolicy)
  }
  // Check that each replica has a TensorFlow container.
  chiefExists := false
  for _, r := range c.ReplicaSpecs {
    found := false
    if r.Template == nil && r.TfReplicaType != PS {
@@ -138,6 +142,10 @@ func (c *TfJobSpec) Validate() error {
      return errors.New("The MASTER must have Replicas = 1")
    }

    if r.TfReplicaType == TfReplicaType(c.TerminationPolicy.Chief.ReplicaName) {
      chiefExists = true
    }

    if r.TfPort == nil {
      return errors.New("tfReplicaSpec.TfPort can't be nil.")
    }
@@ -167,14 +175,11 @@ func (c *TfJobSpec) Validate() error {
      return fmt.Errorf("Replica type %v is missing a container named %v", r.TfReplicaType, TENSORFLOW)
    }
  }
  if c.TerminationPolicy != nil {
    if c.TerminationPolicy.Chief == nil {
      return errors.New("invalid termination policy, Chief cannot be nil")
    }
    if c.TerminationPolicy.Chief.ReplicaName != "MASTER" || c.TerminationPolicy.Chief.ReplicaIndex != 0 {
      return errors.New("invalid termination policy, Chief should have replicaName=MASTER and index=0")
    }

  if !chiefExists {
    return fmt.Errorf("Missing ReplicaSpec for chief: %v", c.TerminationPolicy.Chief.ReplicaName)
  }

  return nil
}

86 changes: 86 additions & 0 deletions pkg/spec/tf_job_test.go
@@ -383,3 +383,89 @@ func TestSetDefaults(t *testing.T) {
    })
  }
}

func TestValidate(t *testing.T) {
  type testCase struct {
    in *TfJobSpec
    expectingError bool
  }

  testCases := []testCase{
    {
      in: &TfJobSpec{
        ReplicaSpecs: []*TfReplicaSpec{
          {
            Template: &v1.PodTemplateSpec{
              Spec: v1.PodSpec{
                Containers: []v1.Container{
                  {
                    Name: "tensorflow",
                  },
                },
              },
            },
            TfReplicaType: MASTER,
            Replicas: proto.Int32(1),
          },
        },
        TfImage: "tensorflow/tensorflow:1.3.0",
      },
      expectingError: false,
    },
    {
      in: &TfJobSpec{
        ReplicaSpecs: []*TfReplicaSpec{
          {
            Template: &v1.PodTemplateSpec{
              Spec: v1.PodSpec{
                Containers: []v1.Container{
                  {
                    Name: "tensorflow",
                  },
                },
              },
            },
            TfReplicaType: WORKER,
            Replicas: proto.Int32(1),
          },
        },
        TfImage: "tensorflow/tensorflow:1.3.0",
      },
      expectingError: true,
    },
    {
      in: &TfJobSpec{
        ReplicaSpecs: []*TfReplicaSpec{
          {
            Template: &v1.PodTemplateSpec{
              Spec: v1.PodSpec{
                Containers: []v1.Container{
                  {
                    Name: "tensorflow",
                  },
                },
              },
            },
            TfReplicaType: WORKER,
            Replicas: proto.Int32(1),
          },
        },
        TfImage: "tensorflow/tensorflow:1.3.0",
        TerminationPolicy: &TerminationPolicySpec{
          Chief: &ChiefSpec{
            ReplicaName: "WORKER",
            ReplicaIndex: 0,
          },
        },
      },
      expectingError: false,
    },
  }

  for _, c := range testCases {
    c.in.SetDefaults("")
    if err := c.in.Validate(); (err != nil) != c.expectingError {
      t.Errorf("unexpected validation result: %v", err)
    }
  }
}
71 changes: 35 additions & 36 deletions pkg/trainer/replicas.go
@@ -406,6 +406,40 @@ func replicaStatusFromPodList(l v1.PodList, name spec.ContainerName) spec.ReplicaState {
  return spec.ReplicaStateUnknown
}

func (s *TFReplicaSet) GetSingleReplicaStatus(index int32) (spec.ReplicaState) {
  j, err := s.ClientSet.BatchV1().Jobs(s.Job.job.Metadata.Namespace).Get(s.jobName(index), meta_v1.GetOptions{})

  if err != nil {
    return spec.ReplicaStateUnknown
  }

  if j.Status.Succeeded >= 1 {
    return spec.ReplicaStateSucceeded
  }

  labels := s.Labels()
  labels["task_index"] = fmt.Sprintf("%v", index)
  selector, err := labels.ToSelector()
  if err != nil {
    log.Errorf("labels.ToSelector() error; %v", err)
    return spec.ReplicaStateFailed
  }

  // TODO(jlewi): Handle errors. We need to get the pod and looking at recent container exits.
  l, err := s.ClientSet.CoreV1().Pods(s.Job.job.Metadata.Namespace).List(meta_v1.ListOptions{
    // TODO(jlewi): Why isn't the label selector working?
    LabelSelector: selector,
  })

  if err != nil {
    // TODO(jlewi): Are there errors that should be treated as retryable errors?
    return spec.ReplicaStateFailed
  }

  status := replicaStatusFromPodList(*l, spec.TENSORFLOW)
  return status
}

// Status returns the status of the replica set.
func (s *TFReplicaSet) GetStatus() (spec.TfReplicaStatus, error) {

@@ -425,42 +459,7 @@ func (s *TFReplicaSet) GetStatus() (spec.TfReplicaStatus, error) {
  }

  for index := int32(0); index < *s.Spec.Replicas; index++ {

    j, err := s.ClientSet.BatchV1().Jobs(s.Job.job.Metadata.Namespace).Get(s.jobName(index), meta_v1.GetOptions{})

    if err != nil {
      increment(spec.ReplicaStateUnknown)
      continue
    }

    if j.Status.Succeeded >= 1 {
      increment(spec.ReplicaStateSucceeded)
      continue
    }

    labels := s.Labels()
    labels["task_index"] = fmt.Sprintf("%v", index)
    selector, err := labels.ToSelector()
    if err != nil {
      log.Errorf("labels.ToSelector() error; %v", err)
      increment(spec.ReplicaStateFailed)
      continue
    }

    // TODO(jlewi): Handle errors. We need to get the pod and looking at recent container exits.
    l, err := s.ClientSet.CoreV1().Pods(s.Job.job.Metadata.Namespace).List(meta_v1.ListOptions{
      // TODO(jlewi): Why isn't the label selector working?
      LabelSelector: selector,
    })

    if err != nil {
      // TODO(jlewi): Are there errors that should be treated as retryable errors?
      increment(spec.ReplicaStateFailed)
      continue
    }

    status := replicaStatusFromPodList(*l, spec.TENSORFLOW)
    increment(status)
    increment(s.GetSingleReplicaStatus(index))
  }

  // Determine the overall status for the replica set based on the status of the individual
27 changes: 8 additions & 19 deletions pkg/trainer/training.go
@@ -160,8 +160,9 @@ func (j *TrainingJob) deleteResources() error {
  return nil
}

func (j *TrainingJob) GetStatus() (spec.State, []*spec.TfReplicaStatus, error) {
  state := spec.StateUnknown
func (j *TrainingJob) GetStatus() (spec.ReplicaState, []*spec.TfReplicaStatus, error) {
  chief := j.job.Spec.TerminationPolicy.Chief
  chiefState := spec.ReplicaStateUnknown
  replicaStatuses := make([]*spec.TfReplicaStatus, 0)

  // The state for each replica.
@@ -178,24 +179,12 @@ func (j *TrainingJob) GetStatus() (spec.State, []*spec.TfReplicaStatus, error) {

    replicaStatuses = append(replicaStatuses, &rStatus)

    // If any replicas are failed mark job as failed.
    if rStatus.State == spec.ReplicaStateFailed {
      state = spec.StateFailed
    if string(r.Spec.TfReplicaType) == string(chief.ReplicaName) {
      chiefState = r.GetSingleReplicaStatus(int32(chief.ReplicaIndex))
    }
  }

  if v, ok := replicaSetStates[spec.MASTER]; ok && v == spec.ReplicaStateSucceeded {
    state = spec.StateSucceeded
    return state, replicaStatuses, nil
  }

  if v, ok := replicaSetStates[spec.MASTER]; ok && v == spec.ReplicaStateFailed {
    state = spec.StateFailed
    return state, replicaStatuses, nil
  }

  state = spec.StateRunning
  return state, replicaStatuses, nil
  return chiefState, replicaStatuses, nil
}

// isRetryableTerminationState returns true if a container terminated in a state
@@ -373,11 +362,11 @@ func (j *TrainingJob) reconcile(config *spec.ControllerConfig) {
log.Errorf("GetStatus() for job %v returned error: %v", j.job.Metadata.Name, err)
}
// TODO(jlewi): We should update the Phase if we detect the job is done.
if state == spec.StateFailed {
if state == spec.ReplicaStateFailed {
log.Errorf("Master failed Job: %v.", j.job.Metadata.Name)
j.status.SetPhase(spec.TfJobPhaseDone)
j.status.SetState(spec.StateFailed)
} else if state == spec.StateSucceeded {
} else if state == spec.ReplicaStateSucceeded {
log.Infof("Master succeeded Job: %v.", j.job.Metadata.Name)
j.status.SetPhase(spec.TfJobPhaseDone)
j.status.SetState(spec.StateSucceeded)
52 changes: 52 additions & 0 deletions pkg/trainer/training_test.go
@@ -202,6 +202,20 @@ func TestJobSetup(t *testing.T) {
              },
              TfReplicaType: spec.PS,
            },
            {
              Replicas: proto.Int32(1),
              TfPort: proto.Int32(10),
              Template: &v1.PodTemplateSpec{
                Spec: v1.PodSpec{
                  Containers: []v1.Container{
                    {
                      Name: "tensorflow",
                    },
                  },
                },
              },
              TfReplicaType: spec.MASTER,
            },
          },
        },
      },
@@ -232,6 +246,25 @@ func TestJobSetup(t *testing.T) {
              },
              TfReplicaType: spec.PS,
            },
            {
              Replicas: proto.Int32(1),
              TfPort: proto.Int32(10),
              Template: &v1.PodTemplateSpec{
                Spec: v1.PodSpec{
                  Containers: []v1.Container{
                    {
                      Name: "tensorflow",
                      Resources: v1.ResourceRequirements{
                        Requests: map[v1.ResourceName]resource.Quantity{
                          "nvidia-gpu": resource.MustParse("1"),
                        },
                      },
                    },
                  },
                },
              },
              TfReplicaType: spec.MASTER,
            },
          },
        },
      },
@@ -263,6 +296,25 @@ func TestJobSetup(t *testing.T) {
              },
              TfReplicaType: spec.PS,
            },
            {
              Replicas: proto.Int32(1),
              TfPort: proto.Int32(10),
              Template: &v1.PodTemplateSpec{
                Spec: v1.PodSpec{
                  Containers: []v1.Container{
                    {
                      Name: "tensorflow",
                      Resources: v1.ResourceRequirements{
                        Requests: map[v1.ResourceName]resource.Quantity{
                          "nvidia-gpu": resource.MustParse("1"),
                        },
                      },
                    },
                  },
                },
              },
              TfReplicaType: spec.MASTER,
            },
          },
        },
        TensorBoard: &spec.TensorBoardSpec{},
      },