Fix a bunch of problems in TfJob CRD that crept in while tests were broken (#308)

* In syncTfJob, when checking whether a work queue item corresponds to a TrainingJob already in the map, we need to check the UID. Otherwise we will not properly handle the case where a training job is deleted and a new job is then recreated with the same name (a minimal sketch of this check follows the list below).

* We need to make sure that the Replicas field in TrainingJob is always properly set. We were only initializing replicas in setup, which was problematic when the TfJob controller is restarted: on restart, setup won't be invoked because the job is already past that phase, so the replicas won't be reinitialized.

* test_runner needs to ignore case when checking whether the job succeeded; otherwise we conclude that successful jobs failed.

* The controller should only forget about a job after the job has been cleaned up, not when it is marked as succeeded or failed.

* Add back the code that supports termination policies where the worker, and not the master, is the chief.
  * This was added in #221 and accidentally removed in the refactor in #234.
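
The UID comparison described in the first bullet can be sketched as follows. This is a minimal, self-contained illustration of the pattern, not the controller's actual types: trainingJob, needsNewTrainingJob, and the string UIDs are hypothetical stand-ins for the real TrainingJob cache keyed by namespace-name.

package main

import "fmt"

// trainingJob stands in for the controller's cached TrainingJob; only the UID
// of the TfJob it was created from matters for this sketch.
type trainingJob struct {
	uid string
}

// UID returns the UID of the TfJob this TrainingJob was created from.
func (j *trainingJob) UID() string { return j.uid }

// needsNewTrainingJob mirrors the check added to syncTFJob: reuse a cached
// entry only if it exists and its UID matches the TfJob currently in the API
// server. A job that was deleted and recreated under the same name gets a new
// UID, so the stale cache entry is replaced instead of being reused.
func needsNewTrainingJob(jobs map[string]*trainingJob, key, currentUID string) bool {
	cached, ok := jobs[key]
	return !ok || cached.UID() != currentUID
}

func main() {
	jobs := map[string]*trainingJob{"default-mnist": {uid: "uid-old"}}

	// Same name but a new UID: the job was recreated, so do not reuse the cache.
	fmt.Println(needsNewTrainingJob(jobs, "default-mnist", "uid-new")) // true

	// Same name and same UID: the cached TrainingJob is still valid.
	fmt.Println(needsNewTrainingJob(jobs, "default-mnist", "uid-old")) // false
}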
jlewi authored Jan 16, 2018
1 parent 77e272a commit b97dfc7
Showing 8 changed files with 266 additions and 102 deletions.
1 change: 0 additions & 1 deletion pkg/apis/tensorflow/v1alpha1/types.go
@@ -136,7 +136,6 @@ type ReplicaState string

const (
ReplicaStateUnknown ReplicaState = "Unknown"
ReplicaStateStarting ReplicaState = "Starting"
ReplicaStateRunning ReplicaState = "Running"
ReplicaStateFailed ReplicaState = "Failed"
ReplicaStateSucceeded ReplicaState = "Succeeded"
24 changes: 14 additions & 10 deletions pkg/apis/tensorflow/validation/validation.go
@@ -10,15 +10,21 @@ import (

// ValidateTfJobSpec checks that the TfJobSpec is valid.
func ValidateTfJobSpec(c *tfv1.TfJobSpec) error {
// Check that each replica has a TensorFlow container.
if c.TerminationPolicy == nil || c.TerminationPolicy.Chief == nil {
return fmt.Errorf("invalid termination policy: %v", c.TerminationPolicy)
}

chiefExists := false

// Check that each replica has a TensorFlow container and a chief.
for _, r := range c.ReplicaSpecs {
found := false
if r.Template == nil && r.TfReplicaType != tfv1.PS {
return fmt.Errorf("Replica is missing Template; %v", util.Pformat(r))
}

if r.TfReplicaType == tfv1.MASTER && *r.Replicas != 1 {
return errors.New("The MASTER must have Replicas = 1")
if r.TfReplicaType == tfv1.TfReplicaType(c.TerminationPolicy.Chief.ReplicaName) {
chiefExists = true
}

if r.TfPort == nil {
@@ -51,14 +57,12 @@ func ValidateTfJobSpec(c *tfv1.TfJobSpec) error {
}
}

if c.TerminationPolicy != nil {
if c.TerminationPolicy.Chief == nil {
return errors.New("invalid termination policy, Chief cannot be nil")
}
if c.TerminationPolicy.Chief.ReplicaName != "MASTER" || c.TerminationPolicy.Chief.ReplicaIndex != 0 {
return errors.New("invalid termination policy, Chief should have replicaName=MASTER and index=0")
}
if !chiefExists {
return fmt.Errorf("Missing ReplicaSpec for chief: %v", c.TerminationPolicy.Chief.ReplicaName)
}

if c.TensorBoard != nil && c.TensorBoard.LogDir == "" {
return fmt.Errorf("tbReplicaSpec.LogDir must be specified")
}
return nil
}
99 changes: 99 additions & 0 deletions pkg/apis/tensorflow/validation/validation_test.go
@@ -0,0 +1,99 @@
package validation

import (
"testing"

tfv1 "github.com/tensorflow/k8s/pkg/apis/tensorflow/v1alpha1"

"github.com/gogo/protobuf/proto"
"k8s.io/api/core/v1"
)

func TestValidate(t *testing.T) {
type testCase struct {
in *tfv1.TfJobSpec
expectingError bool
}

testCases := []testCase{
{
in: &tfv1.TfJobSpec{
ReplicaSpecs: []*tfv1.TfReplicaSpec{
{
Template: &v1.PodTemplateSpec{
Spec: v1.PodSpec{
Containers: []v1.Container{
{
Name: "tensorflow",
},
},
},
},
TfReplicaType: tfv1.MASTER,
Replicas: proto.Int32(1),
},
},
TfImage: "tensorflow/tensorflow:1.3.0",
},
expectingError: false,
},
{
in: &tfv1.TfJobSpec{
ReplicaSpecs: []*tfv1.TfReplicaSpec{
{
Template: &v1.PodTemplateSpec{
Spec: v1.PodSpec{
Containers: []v1.Container{
{
Name: "tensorflow",
},
},
},
},
TfReplicaType: tfv1.WORKER,
Replicas: proto.Int32(1),
},
},
TfImage: "tensorflow/tensorflow:1.3.0",
},
expectingError: true,
},
{
in: &tfv1.TfJobSpec{
ReplicaSpecs: []*tfv1.TfReplicaSpec{
{
Template: &v1.PodTemplateSpec{
Spec: v1.PodSpec{
Containers: []v1.Container{
{
Name: "tensorflow",
},
},
},
},
TfReplicaType: tfv1.WORKER,
Replicas: proto.Int32(1),
},
},
TfImage: "tensorflow/tensorflow:1.3.0",
TerminationPolicy: &tfv1.TerminationPolicySpec{
Chief: &tfv1.ChiefSpec{
ReplicaName: "WORKER",
ReplicaIndex: 0,
},
},
},
expectingError: false,
},
}

for _, c := range testCases {
job := &tfv1.TfJob{
Spec: *c.in,
}
tfv1.SetObjectDefaults_TfJob(job)
if err := ValidateTfJobSpec(&job.Spec); (err != nil) != c.expectingError {
t.Errorf("unexpected validation result: %v", err)
}
}
}
20 changes: 13 additions & 7 deletions pkg/controller/controller.go
@@ -89,7 +89,7 @@ func New(kubeClient kubernetes.Interface, apiExtclient apiextensionsclient.Inter
}
},
Handler: cache.ResourceEventHandlerFuncs{
AddFunc: controller.enqueueController,
AddFunc: controller.enqueueController,
UpdateFunc: func(oldObj, newObj interface{}) {
controller.enqueueController(newObj)
},
@@ -125,8 +125,8 @@ func (c *Controller) Run(threadiness int, stopCh <-chan struct{}) error {
return fmt.Errorf("failed to wait for caches to sync")
}

glog.Info("Starting workers")
// Launch two workers to process Foo resources
glog.Info("Starting %v workers", threadiness)
// Launch workers to process TfJob resources
for i := 0; i < threadiness; i++ {
go wait.Until(c.runWorker, time.Second, stopCh)
}
@@ -169,9 +169,11 @@ func (c *Controller) processNextWorkItem() bool {
return true
}

// syncJob will sync the job with the given key if it has had its expectations fulfilled, meaning
// it did not expect to see any more of its pods created or deleted. This function is not meant to be invoked
// syncTFJob will sync the job with the given key. This function is not meant to be invoked
// concurrently with the same key.
//
// When a job is completely processed it will return true, indicating that it's ok to forget about this job since
// no more processing will occur for it.
func (c *Controller) syncTFJob(key string) (bool, error) {
startTime := time.Now()
defer func() {
@@ -196,7 +198,9 @@ func (c *Controller) syncTFJob(key string) (bool, error) {
return false, err
}

if _, ok := c.jobs[tfJob.ObjectMeta.Namespace+"-"+tfJob.ObjectMeta.Name]; !ok {
// Create a new TrainingJob if there is no TrainingJob stored for it in the jobs map or if the UIDs don't match.
// The UIDs won't match in the event we deleted the job and then recreated the job with the same name.
if cJob, ok := c.jobs[tfJob.ObjectMeta.Namespace+"-"+tfJob.ObjectMeta.Name]; !ok || cJob.UID() != tfJob.UID {
nc, err := trainer.NewJob(c.KubeClient, c.TfJobClient, tfJob, &c.config)

if err != nil {
Expand All @@ -217,7 +221,9 @@ func (c *Controller) syncTFJob(key string) (bool, error) {
return false, err
}

if tfJob.Status.State == tfv1alpha1.StateSucceeded {
// TODO(jlewi): This logic will need to change when/if we get rid of phases and move to conditions. In that
// case we should forget about a job when the appropriate condition is reached.
if tfJob.Status.Phase == tfv1alpha1.TfJobPhaseCleanUp {
return true, nil
} else {
return false, nil
106 changes: 59 additions & 47 deletions pkg/trainer/replicas.go
@@ -129,21 +129,35 @@ func (s *TFReplicaSet) Create(config *tfv1alpha1.ControllerConfig) error {
}
_, err = s.ClientSet.CoreV1().ConfigMaps(s.Job.job.ObjectMeta.Namespace).Create(cm)
if err != nil {
log.Errorf("Error creating PS ConfigMap: %v, %v", cm.ObjectMeta.Name, err)
return err
if k8s_errors.IsAlreadyExists(err) {
log.Infof("%v already exists.", cm.Name)
} else {
log.Errorf("Error creating PS ConfigMap: %v, %v", cm.ObjectMeta.Name, err)
return k8sErrors.NewAggregate([]error{fmt.Errorf("Creating PS ConfigMap %v returned error.", cm.Name), err})
}
}

// Update Volumes to include the ConfigMap containing grpc_tensorflow_server.py
s.Spec.Template.Spec.Volumes = append(s.Spec.Template.Spec.Volumes, v1.Volume{
Name: "ps-config-volume",
VolumeSource: v1.VolumeSource{
ConfigMap: &v1.ConfigMapVolumeSource{
LocalObjectReference: v1.LocalObjectReference{
Name: s.defaultPSConfigMapName(),
name := "ps-config-volume"
hasVolume := false
for _, v := range s.Spec.Template.Spec.Volumes {
if v.Name == name {
hasVolume = true
break
}
}
if !hasVolume {
s.Spec.Template.Spec.Volumes = append(s.Spec.Template.Spec.Volumes, v1.Volume{
Name: "ps-config-volume",
VolumeSource: v1.VolumeSource{
ConfigMap: &v1.ConfigMapVolumeSource{
LocalObjectReference: v1.LocalObjectReference{
Name: s.defaultPSConfigMapName(),
},
},
},
},
})
})
}
}

for index := int32(0); index < *s.Spec.Replicas; index++ {
@@ -410,9 +424,42 @@ func replicaStatusFromPodList(l v1.PodList, name tfv1alpha1.ContainerName) tfv1a
return tfv1alpha1.ReplicaStateUnknown
}

// GetSingleReplicaStatus returns the state of the single replica with the given task index.
func (s *TFReplicaSet) GetSingleReplicaStatus(index int32) tfv1alpha1.ReplicaState {
j, err := s.ClientSet.BatchV1().Jobs(s.Job.job.ObjectMeta.Namespace).Get(s.jobName(index), meta_v1.GetOptions{})

if err != nil {
return tfv1alpha1.ReplicaStateUnknown
}

if j.Status.Succeeded >= 1 {
return tfv1alpha1.ReplicaStateSucceeded
}

labels := s.Labels()
labels["task_index"] = fmt.Sprintf("%v", index)
selector, err := labels.ToSelector()
if err != nil {
log.Errorf("labels.ToSelector() error; %v", err)
return tfv1alpha1.ReplicaStateFailed
}

// TODO(jlewi): Handle errors. We need to get the pod and look at recent container exits.
l, err := s.ClientSet.CoreV1().Pods(s.Job.job.ObjectMeta.Namespace).List(meta_v1.ListOptions{
// TODO(jlewi): Why isn't the label selector working?
LabelSelector: selector,
})

if err != nil {
// TODO(jlewi): Are there errors that should be treated as retryable errors?
return tfv1alpha1.ReplicaStateFailed
}

status := replicaStatusFromPodList(*l, tfv1alpha1.TENSORFLOW)
return status
}

// Status returns the status of the replica set.
func (s *TFReplicaSet) GetStatus() (tfv1alpha1.TfReplicaStatus, error) {

status := tfv1alpha1.TfReplicaStatus{
TfReplicaType: s.Spec.TfReplicaType,
State: tfv1alpha1.ReplicaStateUnknown,
@@ -429,42 +476,7 @@ func (s *TFReplicaSet) GetStatus() (tfv1alpha1.TfReplicaStatus, error) {
}

for index := int32(0); index < *s.Spec.Replicas; index++ {

j, err := s.ClientSet.BatchV1().Jobs(s.Job.job.ObjectMeta.Namespace).Get(s.jobName(index), meta_v1.GetOptions{})

if err != nil {
increment(tfv1alpha1.ReplicaStateUnknown)
continue
}

if j.Status.Succeeded >= 1 {
increment(tfv1alpha1.ReplicaStateSucceeded)
continue
}

labels := s.Labels()
labels["task_index"] = fmt.Sprintf("%v", index)
selector, err := labels.ToSelector()
if err != nil {
log.Errorf("labels.ToSelector() error; %v", err)
increment(tfv1alpha1.ReplicaStateFailed)
continue
}

// TODO(jlewi): Handle errors. We need to get the pod and looking at recent container exits.
l, err := s.ClientSet.CoreV1().Pods(s.Job.job.ObjectMeta.Namespace).List(meta_v1.ListOptions{
// TODO(jlewi): Why isn't the label selector working?
LabelSelector: selector,
})

if err != nil {
// TODO(jlewi): Are there errors that should be treated as retryable errors?
increment(tfv1alpha1.ReplicaStateFailed)
continue
}

status := replicaStatusFromPodList(*l, tfv1alpha1.TENSORFLOW)
increment(status)
increment(s.GetSingleReplicaStatus(index))
}

// Determine the overall status for the replica set based on the status of the individual