diff --git a/pkg/trainer/replicas.go b/pkg/trainer/replicas.go
index f8c72b7c69..8a4ff6f932 100644
--- a/pkg/trainer/replicas.go
+++ b/pkg/trainer/replicas.go
@@ -20,9 +20,7 @@ import (
 	"fmt"
 	"strings"
 
-	"github.com/golang/protobuf/proto"
-	log "github.com/sirupsen/logrus"
-	batch "k8s.io/api/batch/v1"
+	log "github.com/golang/glog"
 	"k8s.io/api/core/v1"
 	k8s_errors "k8s.io/apimachinery/pkg/api/errors"
 	meta_v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@@ -119,36 +117,13 @@ func (s *TFReplicaSet) Labels() KubernetesLabels {
 func (s *TFReplicaSet) Create(config *tfv1alpha1.ControllerConfig) error {
 	for index := int32(0); index < *s.Spec.Replicas; index++ {
-		taskLabels := s.Labels()
-		taskLabels["task_index"] = fmt.Sprintf("%v", index)
-
-		// Create the service.
-		service := &v1.Service{
-			ObjectMeta: meta_v1.ObjectMeta{
-				Name:   s.jobName(index),
-				Labels: taskLabels,
-				OwnerReferences: []meta_v1.OwnerReference{
-					helper.AsOwner(s.Job.job),
-				},
-			},
-			Spec: v1.ServiceSpec{
-				Selector: taskLabels,
-				Ports: []v1.ServicePort{
-					{
-						Name: "tf-port",
-						Port: *s.Spec.TFPort,
-					},
-				},
-			},
-		}
+		// Create the service
+		createdService, err := s.CreateServiceWithIndex(index)
 
-		log.Infof("Creating Service: %v", service.ObjectMeta.Name)
-		createdService, err := s.ClientSet.CoreV1().Services(s.Job.job.ObjectMeta.Namespace).Create(service)
-
-		// If the job already exists do nothing.
+		// If the service already exists do nothing.
 		if err != nil {
 			if k8s_errors.IsAlreadyExists(err) {
-				log.Infof("Service %v already exists.", s.jobName(index))
+				log.Infof("Service: %v already exists.", s.genName(index))
 			} else {
 				s.recorder.Eventf(s.Job.job, v1.EventTypeWarning, FailedCreateReason, "Error creating: %v", err)
 				return k8sErrors.NewAggregate([]error{fmt.Errorf("Creating service %v returned error.", createdService.ObjectMeta.Name), err})
@@ -157,84 +132,106 @@ func (s *TFReplicaSet) Create(config *tfv1alpha1.ControllerConfig) error {
 			s.recorder.Eventf(s.Job.job, v1.EventTypeNormal, SuccessfulCreateReason, "Created service: %v", createdService.Name)
 		}
 
-		// Configure the TFCONFIG environment variable.
-		tfConfig := TFConfig{
-			Cluster: s.Job.ClusterSpec(),
-			Task: TaskSpec{
-				Type:  strings.ToLower(string(s.Spec.TFReplicaType)),
-				Index: int(index),
-			},
-			// We need to set environment to cloud otherwise it will default to local which isn't what we want.
-			Environment: "cloud",
-		}
+		// Create the pod
+		createdPod, err := s.CreatePodWithIndex(index)
 
-		tfConfigJson, err := json.Marshal(tfConfig)
+		// If the pod already exists do nothing.
 		if err != nil {
-			log.Errorf("Job: %v serializing tfConfig: %v return error; %v", s.Job.job.ObjectMeta.Name, util.Pformat(tfConfig), err)
-			return err
+			if k8s_errors.IsAlreadyExists(err) {
+				log.Infof("Pod: %v already exists.", s.genName(index))
+				continue
+			}
+			s.recorder.Eventf(s.Job.job, v1.EventTypeWarning, FailedCreateReason, "Error creating: %v", err)
+			return k8sErrors.NewAggregate([]error{fmt.Errorf("Creating Pod %v returned error.", createdPod.ObjectMeta.Name), err})
+		}
 
-		// Make a copy of the template because we will modify it below. .
-		newPodSpecTemplate := s.Spec.Template.DeepCopy()
+		s.recorder.Eventf(s.Job.job, v1.EventTypeNormal, SuccessfulCreateReason, "Created Pod: %v", createdPod.Name)
+	}
+	return nil
+}
 
-		newJ := &batch.Job{
-			ObjectMeta: meta_v1.ObjectMeta{
-				Name:   s.jobName(index),
-				Labels: taskLabels,
-				OwnerReferences: []meta_v1.OwnerReference{
-					helper.AsOwner(s.Job.job),
-				},
+// CreateServiceWithIndex will create a new service with the given index
+func (s *TFReplicaSet) CreateServiceWithIndex(index int32) (*v1.Service, error) {
+	taskLabels := s.Labels()
+	taskLabels["task_index"] = fmt.Sprintf("%v", index)
+
+	// Create the service.
+	service := &v1.Service{
+		ObjectMeta: meta_v1.ObjectMeta{
+			Name:   s.genName(index),
+			Labels: taskLabels,
+			OwnerReferences: []meta_v1.OwnerReference{
+				helper.AsOwner(s.Job.job),
 			},
-			Spec: batch.JobSpec{
-				Completions: proto.Int32(1),
-				Parallelism: proto.Int32(1),
-				Template:    *newPodSpecTemplate,
+		},
+		Spec: v1.ServiceSpec{
+			Selector: taskLabels,
+			Ports: []v1.ServicePort{
+				{
+					Name: "tf-port",
+					Port: *s.Spec.TFPort,
+				},
 			},
-		}
-
-		if newJ.Spec.Template.ObjectMeta.Labels == nil {
-			newJ.Spec.Template.ObjectMeta.Labels = make(map[string]string)
-		}
+		},
+	}
 
-		// Pods need to be tagged with the labels.
-		for k, v := range taskLabels {
-			newJ.Spec.Template.ObjectMeta.Labels[k] = v
-		}
+	log.Infof("Creating service: %v", service.ObjectMeta.Name)
+	return s.ClientSet.CoreV1().Services(s.Job.job.ObjectMeta.Namespace).Create(service)
+}
 
-		// Add TF_CONFIG environment variable.
-		for i, _ := range newJ.Spec.Template.Spec.Containers {
-			// We can't get c in the loop variable because that would be by value so our modifications
-			// wouldn't have any effect.
-			c := &newJ.Spec.Template.Spec.Containers[i]
-			if tfv1alpha1.ContainerName(c.Name) != tfv1alpha1.TENSORFLOW {
-				continue
-			}
-			if len(c.Env) == 0 {
-				c.Env = make([]v1.EnvVar, 0)
-			}
-			c.Env = append(c.Env, v1.EnvVar{
-				Name:  "TF_CONFIG",
-				Value: string(tfConfigJson),
-			})
-		}
+// CreatePodWithIndex will create a new pod with the given index
+func (s *TFReplicaSet) CreatePodWithIndex(index int32) (*v1.Pod, error) {
+	taskLabels := s.Labels()
+	taskLabels["task_index"] = fmt.Sprintf("%v", index)
+
+	pod := &v1.Pod{
+		ObjectMeta: meta_v1.ObjectMeta{
+			Name:   s.genName(index),
+			Labels: taskLabels,
+			OwnerReferences: []meta_v1.OwnerReference{
+				helper.AsOwner(s.Job.job),
+			},
+		},
+		Spec: *s.Spec.Template.Spec.DeepCopy(),
+	}
 
-		log.Infof("Creating Job: %v", newJ.ObjectMeta.Name)
-		createdJob, err := s.ClientSet.BatchV1().Jobs(s.Job.job.ObjectMeta.Namespace).Create(newJ)
+	// Configure the TFCONFIG environment variable.
+	tfConfig := TFConfig{
+		Cluster: s.Job.ClusterSpec(),
+		Task: TaskSpec{
+			Type:  strings.ToLower(string(s.Spec.TFReplicaType)),
+			Index: int(index),
+		},
+		// We need to set environment to cloud otherwise it will default to local which isn't what we want.
+		Environment: "cloud",
+	}
 
-		// If the job already exists do nothing.
-		if err != nil {
-			if k8s_errors.IsAlreadyExists(err) {
-				log.Infof("%v already exists.", s.jobName(index))
+	tfConfigJson, err := json.Marshal(tfConfig)
+	if err != nil {
+		log.Errorf("Job: %v serializing tfConfig: %v return error; %v", s.Job.job.ObjectMeta.Name, util.Pformat(tfConfig), err)
+		return nil, err
+	}
 
-			} else {
-				s.recorder.Eventf(s.Job.job, v1.EventTypeWarning, FailedCreateReason, "Error creating: %v", err)
-				return k8sErrors.NewAggregate([]error{fmt.Errorf("Creating Job %v returned error.", createdJob.ObjectMeta.Name), err})
-			}
-		} else {
-			s.recorder.Eventf(s.Job.job, v1.EventTypeNormal, SuccessfulCreateReason, "Created job: %v", createdJob.Name)
+	// Add TF_CONFIG environment variable.
+	for i, _ := range pod.Spec.Containers {
+		// We can't get c in the loop variable because that would be by value so our modifications
+		// wouldn't have any effect.
+		c := &pod.Spec.Containers[i]
+		if tfv1alpha1.ContainerName(c.Name) != tfv1alpha1.TENSORFLOW {
+			continue
+		}
+		if len(c.Env) == 0 {
+			c.Env = make([]v1.EnvVar, 0)
 		}
+		c.Env = append(c.Env, v1.EnvVar{
+			Name:  "TF_CONFIG",
+			Value: string(tfConfigJson),
+		})
 	}
-	return nil
+
+	log.Infof("Creating pod: %v", pod.ObjectMeta.Name)
+	return s.ClientSet.CoreV1().Pods(s.Job.job.ObjectMeta.Namespace).Create(pod)
 }
 
 // Delete deletes the replicas
@@ -250,8 +247,8 @@ func (s *TFReplicaSet) Delete() error {
 		LabelSelector: selector,
 	}
 
-	log.Infof("Deleting Jobs namespace=%v selector=%v", s.Job.job.ObjectMeta.Namespace, selector)
-	err = s.ClientSet.BatchV1().Jobs(s.Job.job.ObjectMeta.Namespace).DeleteCollection(&meta_v1.DeleteOptions{}, options)
+	log.V(1).Infof("Deleting Pods namespace=%v selector=%v", s.Job.job.ObjectMeta.Namespace, selector)
+	err = s.ClientSet.CoreV1().Pods(s.Job.job.ObjectMeta.Namespace).DeleteCollection(&meta_v1.DeleteOptions{}, options)
 
 	if err != nil {
 		log.Errorf("There was a problem deleting the jobs; %v", err)
@@ -270,11 +267,11 @@ func (s *TFReplicaSet) Delete() error {
 	// Services doesn't support DeleteCollection so we delete them individually.
 	// TODO(jlewi): We should check if this has changed with K8s 1.8 or other releases.
 	for index := int32(0); index < *s.Spec.Replicas; index++ {
-		log.Infof("Deleting Service %v:%v", s.Job.job.ObjectMeta.Namespace, s.jobName((index)))
-		err = s.ClientSet.CoreV1().Services(s.Job.job.ObjectMeta.Namespace).Delete(s.jobName(index), &meta_v1.DeleteOptions{})
+		log.V(1).Infof("Deleting Service %v:%v", s.Job.job.ObjectMeta.Namespace, s.genName((index)))
+		err = s.ClientSet.CoreV1().Services(s.Job.job.ObjectMeta.Namespace).Delete(s.genName(index), &meta_v1.DeleteOptions{})
 
 		if err != nil {
-			log.Errorf("Error deleting service %v; %v", s.jobName(index), err)
+			log.Errorf("Error deleting service %v; %v", s.genName(index), err)
 			failures = true
 		}
 	}
@@ -304,7 +301,6 @@ func (s *TFReplicaSet) Delete() error {
 
 // replicaStatusFromPodList returns a status from a list of pods for a job.
 func replicaStatusFromPodList(l v1.PodList, name tfv1alpha1.ContainerName) tfv1alpha1.ReplicaState {
-	log.Infof("Get replicaStatus from PodList: %v", util.Pformat(l))
 	var latest *v1.Pod
 	for _, i := range l.Items {
 		if latest == nil {
@@ -359,13 +355,13 @@ func replicaStatusFromPodList(l v1.PodList, name tfv1alpha1.ContainerName) tfv1a
 }
 
 func (s *TFReplicaSet) GetSingleReplicaStatus(index int32) tfv1alpha1.ReplicaState {
-	j, err := s.ClientSet.BatchV1().Jobs(s.Job.job.ObjectMeta.Namespace).Get(s.jobName(index), meta_v1.GetOptions{})
+	p, err := s.ClientSet.CoreV1().Pods(s.Job.job.ObjectMeta.Namespace).Get(s.genName(index), meta_v1.GetOptions{})
 
 	if err != nil {
 		return tfv1alpha1.ReplicaStateUnknown
 	}
 
-	if j.Status.Succeeded >= 1 {
+	if v1.PodSucceeded == p.Status.Phase {
 		return tfv1alpha1.ReplicaStateSucceeded
 	}
 
@@ -436,10 +432,83 @@ func (s *TFReplicaSet) GetStatus() (tfv1alpha1.TFReplicaStatus, error) {
 	return status, nil
 }
 
-func (s *TFReplicaSet) jobName(index int32) string {
+// SyncPods checks the current pods for this TFReplicaSet and tries to reconcile them to the desired state.
+func (s *TFReplicaSet) SyncPods() error {
+	for index := int32(0); index < *s.Spec.Replicas; index++ {
+		p, err := s.ClientSet.CoreV1().Pods(s.Job.job.ObjectMeta.Namespace).Get(s.genName(index), meta_v1.GetOptions{})
+		if err != nil && k8s_errors.IsNotFound(err) {
+			log.Infof("Pod: %v not found, creating a new one.", s.genName(index))
+			// Create the pod
+			createdPod, err := s.CreatePodWithIndex(index)
+
+			// If the pod already exists do nothing.
+			if err != nil {
+				if k8s_errors.IsAlreadyExists(err) {
+					log.Infof("Pod: %v already exists.", s.genName(index))
+					continue
+				}
+				s.recorder.Eventf(s.Job.job, v1.EventTypeWarning, FailedCreateReason, "Error creating: %v", err)
+				return k8sErrors.NewAggregate([]error{fmt.Errorf("Creating pod %v returned error.", createdPod.ObjectMeta.Name), err})
+			}
+
+			s.recorder.Eventf(s.Job.job, v1.EventTypeNormal, SuccessfulCreateReason, "Created pod: %v", createdPod.Name)
+			continue
+		}
+
+		if err != nil {
+			// TODO: handle this error properly
+			continue
+		}
+
+		if v1.PodFailed == p.Status.Phase && p.DeletionTimestamp == nil {
+			// TODO: check the exit code to determine whether the error is permanent
+			// The pod failed, so delete it; it will be recreated in the next sync loop.
+			err = s.ClientSet.CoreV1().Pods(p.ObjectMeta.Namespace).Delete(s.genName(index), &meta_v1.DeleteOptions{})
+			if err != nil {
+				log.Errorf("Error deleting pod %v; %v", s.genName(index), err)
+			}
+		}
+	}
+
+	return nil
+}
+
+// SyncServices checks the current services for this TFReplicaSet and tries to reconcile them to the desired state.
+func (s *TFReplicaSet) SyncServices() error {
+	for index := int32(0); index < *s.Spec.Replicas; index++ {
+		_, err := s.ClientSet.CoreV1().Services(s.Job.job.ObjectMeta.Namespace).Get(s.genName(index), meta_v1.GetOptions{})
+		if err != nil && k8s_errors.IsNotFound(err) {
+			log.Infof("Service: %v not found, creating a new one.", s.genName(index))
+			// Create the service
+			createdService, err := s.CreateServiceWithIndex(index)
+
+			// If the service already exists do nothing.
+			if err != nil {
+				if k8s_errors.IsAlreadyExists(err) {
+					log.Infof("Service: %v already exists.", s.genName(index))
+					continue
+				}
+				s.recorder.Eventf(s.Job.job, v1.EventTypeWarning, FailedCreateReason, "Error creating: %v", err)
+				return k8sErrors.NewAggregate([]error{fmt.Errorf("Creating Service %v returned error.", createdService.ObjectMeta.Name), err})
+			}
+
+			s.recorder.Eventf(s.Job.job, v1.EventTypeNormal, SuccessfulCreateReason, "Created Service: %v", createdService.Name)
+			continue
+		}
+
+		if err != nil {
+			// TODO: handle this error properly
+			continue
+		}
+	}
+
+	return nil
+}
+
+func (s *TFReplicaSet) genName(index int32) string {
 	// Truncate tfjob name to 40 characters
 	// The whole job name should be compliant with the DNS_LABEL spec, up to a max length of 63 characters
-	// Thus jobname(40 chars)-replicaType(6 chars)-runtimeId(4 chars)-index(4 chars), also leaving some spaces
+	// Thus genName(40 chars)-replicaType(6 chars)-runtimeId(4 chars)-index(4 chars), also leaving some spaces
 	// See https://github.com/kubernetes/community/blob/master/contributors/design-proposals/architecture/identifiers.md
 	return fmt.Sprintf("%v-%v-%v-%v", fmt.Sprintf("%.40s", s.Job.job.ObjectMeta.Name), strings.ToLower(string(s.Spec.TFReplicaType)), s.Job.job.Spec.RuntimeId, index)
 }
diff --git a/pkg/trainer/replicas_test.go b/pkg/trainer/replicas_test.go
index c3afb38ddd..0d43287b85 100644
--- a/pkg/trainer/replicas_test.go
+++ b/pkg/trainer/replicas_test.go
@@ -136,39 +136,39 @@ func TestTFReplicaSet(t *testing.T) {
 		t.Fatalf("Service.Metadata.OwnerReferences; Got %v; want %v", util.Pformat(s.ObjectMeta.OwnerReferences[0]), util.Pformat(expectedOwnerReference))
 	}
 
-	// Check that a job was created.
-	l, err := clientSet.BatchV1().Jobs(replica.Job.job.ObjectMeta.Namespace).List(meta_v1.ListOptions{})
+	// Check that a pod was created.
+	l, err := clientSet.CoreV1().Pods(replica.Job.job.ObjectMeta.Namespace).List(meta_v1.ListOptions{})
 	if err != nil {
-		t.Fatalf("List jobs error; %v", err)
+		t.Fatalf("List pods error; %v", err)
 	}
 	if len(l.Items) != 2 {
-		t.Fatalf("Expected 1 job got %v", len(l.Items))
+		t.Fatalf("Expected 2 pods got %v", len(l.Items))
 	}
 
-	j := l.Items[index]
+	p := l.Items[index]
 
-	if !reflect.DeepEqual(expectedLabels, j.ObjectMeta.Labels) {
-		t.Fatalf("Job Labels; Got %v Want: %v", expectedLabels, j.ObjectMeta.Labels)
+	if !reflect.DeepEqual(expectedLabels, p.ObjectMeta.Labels) {
+		t.Fatalf("Pod Labels; Got %v Want: %v", expectedLabels, p.ObjectMeta.Labels)
 	}
 
-	if j.ObjectMeta.Name != name {
-		t.Fatalf("Job.ObjectMeta.Name = %v; want %v", j.ObjectMeta.Name, name)
+	if p.ObjectMeta.Name != name {
+		t.Fatalf("Pod.ObjectMeta.Name = %v; want %v", p.ObjectMeta.Name, name)
 	}
 
-	if len(j.Spec.Template.Spec.Containers) != 1 {
-		t.Fatalf("Expected 1 container got %v", len(j.Spec.Template.Spec.Containers))
+	if len(p.Spec.Containers) != 1 {
+		t.Fatalf("Expected 1 container got %v", len(p.Spec.Containers))
 	}
 
-	if len(j.ObjectMeta.OwnerReferences) != 1 {
-		t.Fatalf("Expected 1 owner reference got %v", len(j.ObjectMeta.OwnerReferences))
+	if len(p.ObjectMeta.OwnerReferences) != 1 {
+		t.Fatalf("Expected 1 owner reference got %v", len(p.ObjectMeta.OwnerReferences))
 	}
 
-	if !reflect.DeepEqual(j.ObjectMeta.OwnerReferences[0], expectedOwnerReference) {
-		t.Fatalf("Job.Metadata.OwnerReferences; Got %v; want %v", util.Pformat(j.ObjectMeta.OwnerReferences[0]), util.Pformat(expectedOwnerReference))
+	if !reflect.DeepEqual(p.ObjectMeta.OwnerReferences[0], expectedOwnerReference) {
+		t.Fatalf("Pod.Metadata.OwnerReferences; Got %v; want %v", util.Pformat(p.ObjectMeta.OwnerReferences[0]), util.Pformat(expectedOwnerReference))
 	}
 
-	c := j.Spec.Template.Spec.Containers[0]
+	c := p.Spec.Containers[0]
 	if len(c.Env) != 1 {
 		t.Fatalf("Expected 1 environment variable got %v", len(c.Env))
 	}
diff --git a/pkg/trainer/training.go b/pkg/trainer/training.go
index b59231ff4a..67ef3b48ee 100644
--- a/pkg/trainer/training.go
+++ b/pkg/trainer/training.go
@@ -98,7 +98,7 @@ func (j *TrainingJob) ClusterSpec() ClusterSpec {
 		replicaNames := make([]string, 0, *p.Spec.Replicas)
 
 		for i := int32(0); i < *p.Spec.Replicas; i++ {
-			replicaNames = append(replicaNames, fmt.Sprintf("%v:%v", p.jobName(i), *p.Spec.TFPort))
+			replicaNames = append(replicaNames, fmt.Sprintf("%v:%v", p.genName(i), *p.Spec.TFPort))
 		}
 
 		clusterSpec[strings.ToLower(string(p.Spec.TFReplicaType))] = replicaNames
@@ -368,6 +368,22 @@ func (j *TrainingJob) Reconcile(config *tfv1alpha1.ControllerConfig) error {
 		}
 	}
 
+	// sync pods
+	for _, rc := range j.Replicas {
+		err := rc.SyncPods()
+		if err != nil {
+			log.Errorf("SyncPods error: %v", err)
+		}
+	}
+
+	// sync services
+	for _, rc := range j.Replicas {
+		err := rc.SyncServices()
+		if err != nil {
+			log.Errorf("SyncServices error: %v", err)
+		}
+	}
+
 	// If the phase changed we should update the CRD.
 	if err := j.updateCRDStatus(); err != nil {
 		log.Warningf("Job %v, failed to update CRD status error: %v", j.job.ObjectMeta.Name, err)