diff --git a/cmd/training-operator.v1/main.go b/cmd/training-operator.v1/main.go
index c26d643799..0f849492b9 100644
--- a/cmd/training-operator.v1/main.go
+++ b/cmd/training-operator.v1/main.go
@@ -85,6 +85,10 @@ func main() {
 	flag.StringVar(&config.Config.PyTorchInitContainerTemplateFile, "pytorch-init-container-template-file",
 		config.PyTorchInitContainerTemplateFileDefault, "The template file for pytorch init container")
 
+	// MPI related flags
+	flag.StringVar(&config.Config.MPIKubectlDeliveryImage, "mpi-kubectl-delivery-image",
+		config.MPIKubectlDeliveryImageDefault, "The kubectl-delivery image used by the MPI launcher init container")
+
 	opts := zap.Options{
 		Development:     true,
 		StacktraceLevel: zapcore.DPanicLevel,
diff --git a/pkg/config/config.go b/pkg/config/config.go
index 5742729bd8..2a8ed5e813 100644
--- a/pkg/config/config.go
+++ b/pkg/config/config.go
@@ -18,6 +18,7 @@ package config
 var Config struct {
 	PyTorchInitContainerTemplateFile string
 	PyTorchInitContainerImage        string
+	MPIKubectlDeliveryImage          string
 }
 
 const (
@@ -27,4 +28,6 @@ const (
 	// PyTorchInitContainerTemplateFileDefault is the default template file for
 	// the pytorch init container.
 	PyTorchInitContainerTemplateFileDefault = "/etc/config/initContainer.yaml"
+	// MPIKubectlDeliveryImageDefault is the default kubectl-delivery image for the init container of the MPIJob launcher pod.
+	MPIKubectlDeliveryImageDefault = "mpioperator/kubectl-delivery:latest"
 )
diff --git a/pkg/controller.v1/mpi/mpijob.go b/pkg/controller.v1/mpi/mpijob.go
index bc5d95ab99..f734be00bd 100644
--- a/pkg/controller.v1/mpi/mpijob.go
+++ b/pkg/controller.v1/mpi/mpijob.go
@@ -75,8 +75,6 @@ const (
 	// gang scheduler name.
 	gangSchedulerName = "volcano"
 
-	kubectlDeliveryImage = "mpioperator/kubectl-delivery:latest"
-
 	// podTemplateSchedulerNameReason is the warning reason when other scheduler name is set
 	// in pod templates with gang-scheduling enabled
 	podTemplateSchedulerNameReason = "SettedPodTemplateSchedulerName"
diff --git a/pkg/controller.v1/mpi/mpijob_controller.go b/pkg/controller.v1/mpi/mpijob_controller.go
index 97aeb08b6c..4d95007847 100644
--- a/pkg/controller.v1/mpi/mpijob_controller.go
+++ b/pkg/controller.v1/mpi/mpijob_controller.go
@@ -58,6 +58,7 @@ import (
 	mpiv1 "github.com/kubeflow/training-operator/pkg/apis/mpi/v1"
 	"github.com/kubeflow/training-operator/pkg/apis/mpi/validation"
 	trainingoperatorcommon "github.com/kubeflow/training-operator/pkg/common"
+	ctlrconfig "github.com/kubeflow/training-operator/pkg/config"
 )
 
 const (
@@ -116,12 +117,12 @@ type MPIJobReconciler struct {
 
 // Reconcile is part of the main kubernetes reconciliation loop which aims to
 // move the current state of the cluster closer to the desired state.
-func (r *MPIJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
+func (jc *MPIJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
 	_ = log.FromContext(ctx)
-	logger := r.Log.WithValues(mpiv1.Singular, req.NamespacedName)
+	logger := jc.Log.WithValues(mpiv1.Singular, req.NamespacedName)
 
 	mpijob := &mpiv1.MPIJob{}
-	err := r.Get(ctx, req.NamespacedName, mpijob)
+	err := jc.Get(ctx, req.NamespacedName, mpijob)
 	if err != nil {
 		logger.Info(err.Error(), "unable to fetch MPIJob", req.NamespacedName.String())
 		return ctrl.Result{}, client.IgnoreNotFound(err)
@@ -141,11 +142,11 @@ func (r *MPIJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctr
 	}
 
 	// Set default priorities to mpijob
-	r.Scheme.Default(mpijob)
+	jc.Scheme.Default(mpijob)
 
 	// Use common to reconcile the job related pod and service
 	//mpijob not need service
-	err = r.ReconcileJobs(mpijob, mpijob.Spec.MPIReplicaSpecs, mpijob.Status, &mpijob.Spec.RunPolicy)
+	err = jc.ReconcileJobs(mpijob, mpijob.Spec.MPIReplicaSpecs, mpijob.Status, &mpijob.Spec.RunPolicy)
 	if err != nil {
 		logrus.Warnf("Reconcile MPIJob error %v", err)
 		return ctrl.Result{}, err
@@ -155,9 +156,9 @@ func (r *MPIJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctr
 }
 
 // SetupWithManager sets up the controller with the Manager.
-func (r *MPIJobReconciler) SetupWithManager(mgr ctrl.Manager) error {
-	c, err := controller.New(r.ControllerName(), mgr, controller.Options{
-		Reconciler: r,
+func (jc *MPIJobReconciler) SetupWithManager(mgr ctrl.Manager) error {
+	c, err := controller.New(jc.ControllerName(), mgr, controller.Options{
+		Reconciler: jc,
 	})
 
 	if err != nil {
@@ -166,7 +167,7 @@ func (r *MPIJobReconciler) SetupWithManager(mgr ctrl.Manager) error {
 
 	// using onOwnerCreateFunc is easier to set defaults
 	if err = c.Watch(&source.Kind{Type: &mpiv1.MPIJob{}}, &handler.EnqueueRequestForObject{},
-		predicate.Funcs{CreateFunc: r.onOwnerCreateFunc()},
+		predicate.Funcs{CreateFunc: jc.onOwnerCreateFunc()},
 	); err != nil {
 		return err
 	}
@@ -176,9 +177,9 @@ func (r *MPIJobReconciler) SetupWithManager(mgr ctrl.Manager) error {
 		IsController: true,
 		OwnerType:    &mpiv1.MPIJob{},
 	}, predicate.Funcs{
-		CreateFunc: util.OnDependentCreateFunc(r.Expectations),
-		UpdateFunc: util.OnDependentUpdateFunc(&r.JobController),
-		DeleteFunc: util.OnDependentDeleteFunc(r.Expectations),
+		CreateFunc: util.OnDependentCreateFunc(jc.Expectations),
+		UpdateFunc: util.OnDependentUpdateFunc(&jc.JobController),
+		DeleteFunc: util.OnDependentDeleteFunc(jc.Expectations),
 	}); err != nil {
 		return err
 	}
@@ -188,9 +189,9 @@ func (r *MPIJobReconciler) SetupWithManager(mgr ctrl.Manager) error {
 		IsController: true,
 		OwnerType:    &mpiv1.MPIJob{},
 	}, predicate.Funcs{
-		CreateFunc: util.OnDependentCreateFuncGeneric(r.Expectations),
-		UpdateFunc: util.OnDependentUpdateFuncGeneric(&r.JobController),
-		DeleteFunc: util.OnDependentDeleteFuncGeneric(r.Expectations),
+		CreateFunc: util.OnDependentCreateFuncGeneric(jc.Expectations),
+		UpdateFunc: util.OnDependentUpdateFuncGeneric(&jc.JobController),
+		DeleteFunc: util.OnDependentDeleteFuncGeneric(jc.Expectations),
 	}); err != nil {
 		return err
 	}
@@ -200,9 +201,9 @@ func (r *MPIJobReconciler) SetupWithManager(mgr ctrl.Manager) error {
 		IsController: true,
 		OwnerType:    &mpiv1.MPIJob{},
 	}, predicate.Funcs{
-		CreateFunc: util.OnDependentCreateFuncGeneric(r.Expectations),
-		UpdateFunc: util.OnDependentUpdateFuncGeneric(&r.JobController),
-		DeleteFunc: util.OnDependentDeleteFuncGeneric(r.Expectations),
+		CreateFunc: util.OnDependentCreateFuncGeneric(jc.Expectations),
+		UpdateFunc: util.OnDependentUpdateFuncGeneric(&jc.JobController),
+		DeleteFunc: util.OnDependentDeleteFuncGeneric(jc.Expectations),
 	}); err != nil {
 		return err
 	}
@@ -212,9 +213,9 @@ func (r *MPIJobReconciler) SetupWithManager(mgr ctrl.Manager) error {
 		IsController: true,
 		OwnerType:    &mpiv1.MPIJob{},
 	}, predicate.Funcs{
-		CreateFunc: util.OnDependentCreateFuncGeneric(r.Expectations),
-		UpdateFunc: util.OnDependentUpdateFuncGeneric(&r.JobController),
-		DeleteFunc: util.OnDependentDeleteFuncGeneric(r.Expectations),
+		CreateFunc: util.OnDependentCreateFuncGeneric(jc.Expectations),
+		UpdateFunc: util.OnDependentUpdateFuncGeneric(&jc.JobController),
+		DeleteFunc: util.OnDependentDeleteFuncGeneric(jc.Expectations),
 	}); err != nil {
 		return err
 	}
@@ -224,9 +225,9 @@ func (r *MPIJobReconciler) SetupWithManager(mgr ctrl.Manager) error {
 		IsController: true,
 		OwnerType:    &mpiv1.MPIJob{},
 	}, predicate.Funcs{
-		CreateFunc: util.OnDependentCreateFuncGeneric(r.Expectations),
-		UpdateFunc: util.OnDependentUpdateFuncGeneric(&r.JobController),
-		DeleteFunc: util.OnDependentDeleteFuncGeneric(r.Expectations),
+		CreateFunc: util.OnDependentCreateFuncGeneric(jc.Expectations),
+		UpdateFunc: util.OnDependentUpdateFuncGeneric(&jc.JobController),
+		DeleteFunc: util.OnDependentDeleteFuncGeneric(jc.Expectations),
 	}); err != nil {
 		return err
 	}
@@ -235,7 +236,7 @@ func (r *MPIJobReconciler) SetupWithManager(mgr ctrl.Manager) error {
 }
 
 // DeletePodsAndServices is overridden because mpi-reconciler.v1 needs not deleting services
-func (r *MPIJobReconciler) DeletePodsAndServices(runPolicy *commonv1.RunPolicy, job interface{}, pods []*corev1.Pod) error {
+func (jc *MPIJobReconciler) DeletePodsAndServices(runPolicy *commonv1.RunPolicy, job interface{}, pods []*corev1.Pod) error {
 	if len(pods) == 0 {
 		return nil
 	}
@@ -252,7 +253,7 @@ func (r *MPIJobReconciler) DeletePodsAndServices(runPolicy *commonv1.RunPolicy,
 		if *runPolicy.CleanPodPolicy == commonv1.CleanPodPolicyRunning && pod.Status.Phase != corev1.PodRunning && pod.Status.Phase != corev1.PodPending {
 			continue
 		}
-		if err := r.PodControl.DeletePod(pod.Namespace, pod.Name, job.(runtime.Object)); err != nil {
+		if err := jc.PodControl.DeletePod(pod.Namespace, pod.Name, job.(runtime.Object)); err != nil {
 			return err
 		}
 	}
@@ -268,57 +269,57 @@ func (jc *MPIJobReconciler) ReconcileServices(
 	return nil
 }
 
-func (r *MPIJobReconciler) ControllerName() string {
+func (jc *MPIJobReconciler) ControllerName() string {
 	return controllerName
 }
 
-func (r *MPIJobReconciler) GetAPIGroupVersionKind() schema.GroupVersionKind {
+func (jc *MPIJobReconciler) GetAPIGroupVersionKind() schema.GroupVersionKind {
 	return mpiv1.GroupVersion.WithKind(mpiv1.Kind)
 }
 
-func (r *MPIJobReconciler) GetAPIGroupVersion() schema.GroupVersion {
+func (jc *MPIJobReconciler) GetAPIGroupVersion() schema.GroupVersion {
 	return mpiv1.GroupVersion
 }
 
-func (r *MPIJobReconciler) GetGroupNameLabelValue() string {
+func (jc *MPIJobReconciler) GetGroupNameLabelValue() string {
 	return mpiv1.GroupVersion.Group
 }
 
 // SetClusterSpec is overridden because no cluster spec is needed for MPIJob
-func (r *MPIJobReconciler) SetClusterSpec(job interface{}, podTemplate *corev1.PodTemplateSpec, rtype, index string) error {
+func (jc *MPIJobReconciler) SetClusterSpec(job interface{}, podTemplate *corev1.PodTemplateSpec, rtype, index string) error {
 	return nil
 }
 
-func (r *MPIJobReconciler) GetDefaultContainerName() string {
+func (jc *MPIJobReconciler) GetDefaultContainerName() string {
 	return mpiv1.DefaultContainerName
 }
 
-func (r *MPIJobReconciler) GetDefaultContainerPortName() string {
+func (jc *MPIJobReconciler) GetDefaultContainerPortName() string {
 	return mpiv1.DefaultPortName
 }
 
-func (r *MPIJobReconciler) IsMasterRole(replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec,
+func (jc *MPIJobReconciler) IsMasterRole(replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec,
 	rtype commonv1.ReplicaType, index int) bool {
 	return false
 }
 
-func (r *MPIJobReconciler) GetJobFromInformerCache(namespace, name string) (metav1.Object, error) {
+func (jc *MPIJobReconciler) GetJobFromInformerCache(namespace, name string) (metav1.Object, error) {
 	mpijob := &mpiv1.MPIJob{}
-	err := r.Get(context.Background(), types.NamespacedName{
+	err := jc.Get(context.Background(), types.NamespacedName{
 		Namespace: namespace, Name: name,
 	}, mpijob)
 	return mpijob, err
 }
 
 // onOwnerCreateFunc modify creation condition.
-func (r *MPIJobReconciler) onOwnerCreateFunc() func(event.CreateEvent) bool {
+func (jc *MPIJobReconciler) onOwnerCreateFunc() func(event.CreateEvent) bool {
 	return func(e event.CreateEvent) bool {
 		mpiJob, ok := e.Object.(*mpiv1.MPIJob)
 		if !ok {
 			return true
 		}
 
-		r.Scheme.Default(mpiJob)
+		jc.Scheme.Default(mpiJob)
 		msg := fmt.Sprintf("MPIJob %s/%s is created.", mpiJob.Namespace, e.Object.GetName())
 		logrus.Info(msg)
 		trainingoperatorcommon.CreatedJobsCounterInc(mpiJob.Namespace, mpiv1.FrameworkName)
@@ -330,7 +331,7 @@ func (r *MPIJobReconciler) onOwnerCreateFunc() func(event.CreateEvent) bool {
 	}
 }
 
-func (r *MPIJobReconciler) ReconcilePods(
+func (jc *MPIJobReconciler) ReconcilePods(
 	job interface{},
 	jobStatus *commonv1.JobStatus,
 	pods []*corev1.Pod,
@@ -353,7 +354,7 @@ func (r *MPIJobReconciler) ReconcilePods(
 	initializeReplicaStatuses(jobStatus, rtype)
 
 	// Get the launcher Job for this MPIJob.
-	launcher, err := r.getLauncherJob(mpiJob)
+	launcher, err := jc.getLauncherJob(mpiJob)
 	if err != nil {
 		return err
 	}
@@ -371,57 +372,57 @@ func (r *MPIJobReconciler) ReconcilePods(
 		isGPULauncher := isGPULauncher(mpiJob)
 
 		// Get the launcher ServiceAccount for this MPIJob.
-		if sa, err := r.getOrCreateLauncherServiceAccount(mpiJob); sa == nil || err != nil {
+		if sa, err := jc.getOrCreateLauncherServiceAccount(mpiJob); sa == nil || err != nil {
 			return err
 		}
 
 		// Get the ConfigMap for this MPIJob.
-		if config, err := r.getOrCreateConfigMap(mpiJob, workerReplicas, isGPULauncher); config == nil || err != nil {
+		if config, err := jc.getOrCreateConfigMap(mpiJob, workerReplicas, isGPULauncher); config == nil || err != nil {
 			return err
 		}
 
 		// Get the launcher Role for this MPIJob.
-		if r, err := r.getOrCreateLauncherRole(mpiJob, workerReplicas); r == nil || err != nil {
+		if r, err := jc.getOrCreateLauncherRole(mpiJob, workerReplicas); r == nil || err != nil {
 			return err
 		}
 
 		// Get the launcher RoleBinding for this MPIJob.
-		if rb, err := r.getLauncherRoleBinding(mpiJob); rb == nil || err != nil {
+		if rb, err := jc.getLauncherRoleBinding(mpiJob); rb == nil || err != nil {
 			return err
 		}
 
-		worker, err = r.getOrCreateWorker(mpiJob)
+		worker, err = jc.getOrCreateWorker(mpiJob)
 		if err != nil {
 			return err
 		}
 
 		if launcher == nil {
-			launcher, err = r.KubeClientSet.CoreV1().Pods(mpiJob.Namespace).Create(context.Background(), r.newLauncher(mpiJob, kubectlDeliveryImage, isGPULauncher), metav1.CreateOptions{})
+			launcher, err = jc.KubeClientSet.CoreV1().Pods(mpiJob.Namespace).Create(context.Background(), jc.newLauncher(mpiJob, ctlrconfig.Config.MPIKubectlDeliveryImage, isGPULauncher), metav1.CreateOptions{})
 			if err != nil {
-				r.Recorder.Eventf(mpiJob, corev1.EventTypeWarning, mpiJobFailedReason, "launcher pod created failed: %v", err)
+				jc.Recorder.Eventf(mpiJob, corev1.EventTypeWarning, mpiJobFailedReason, "launcher pod created failed: %v", err)
 				return err
 			} else {
-				r.Recorder.Eventf(mpiJob, corev1.EventTypeNormal, mpiJobRunningReason, "launcher pod created success: %v", launcher.Name)
+				jc.Recorder.Eventf(mpiJob, corev1.EventTypeNormal, mpiJobRunningReason, "launcher pod created success: %v", launcher.Name)
 			}
 		}
 	}
 
 	// Finally, we update the status block of the MPIJob resource to reflect the
 	// current state of the world.
-	err = r.updateMPIJobStatus(mpiJob, launcher, worker)
+	err = jc.updateMPIJobStatus(mpiJob, launcher, worker)
 	if err != nil {
 		return err
 	}
 	return nil
 }
 
-func (r *MPIJobReconciler) updateMPIJobStatus(mpiJob *mpiv1.MPIJob, launcher *corev1.Pod, worker []*corev1.Pod) error {
+func (jc *MPIJobReconciler) updateMPIJobStatus(mpiJob *mpiv1.MPIJob, launcher *corev1.Pod, worker []*corev1.Pod) error {
 	if launcher != nil {
 		initializeMPIJobStatuses(mpiJob, mpiv1.MPIReplicaTypeLauncher)
 		if isPodSucceeded(launcher) {
-			mpiJob.Status.ReplicaStatuses[commonv1.ReplicaType(mpiv1.MPIReplicaTypeLauncher)].Succeeded = 1
+			mpiJob.Status.ReplicaStatuses[mpiv1.MPIReplicaTypeLauncher].Succeeded = 1
 			msg := fmt.Sprintf("MPIJob %s/%s successfully completed.", mpiJob.Namespace, mpiJob.Name)
-			r.Recorder.Event(mpiJob, corev1.EventTypeNormal, mpiJobSucceededReason, msg)
+			jc.Recorder.Event(mpiJob, corev1.EventTypeNormal, mpiJobSucceededReason, msg)
 			if mpiJob.Status.CompletionTime == nil {
 				now := metav1.Now()
 				mpiJob.Status.CompletionTime = &now
@@ -431,13 +432,13 @@ func (r *MPIJobReconciler) updateMPIJobStatus(mpiJob *mpiv1.MPIJob, launcher *co
 				return err
 			}
 		} else if isPodFailed(launcher) {
-			mpiJob.Status.ReplicaStatuses[commonv1.ReplicaType(mpiv1.MPIReplicaTypeLauncher)].Failed = 1
+			mpiJob.Status.ReplicaStatuses[mpiv1.MPIReplicaTypeLauncher].Failed = 1
 			msg := fmt.Sprintf("MPIJob %s/%s has failed", mpiJob.Namespace, mpiJob.Name)
 			reason := launcher.Status.Reason
 			if reason == "" {
 				reason = mpiJobFailedReason
 			}
-			r.Recorder.Event(mpiJob, corev1.EventTypeWarning, reason, msg)
+			jc.Recorder.Event(mpiJob, corev1.EventTypeWarning, reason, msg)
 			if reason == "Evicted" {
 				reason = mpiJobEvict
 			} else if !isEvicted(mpiJob.Status) && mpiJob.Status.CompletionTime == nil {
@@ -451,7 +452,7 @@ func (r *MPIJobReconciler) updateMPIJobStatus(mpiJob *mpiv1.MPIJob, launcher *co
 			}
 		} else if isPodRunning(launcher) {
-			mpiJob.Status.ReplicaStatuses[commonv1.ReplicaType(mpiv1.MPIReplicaTypeLauncher)].Active = 1
+			mpiJob.Status.ReplicaStatuses[mpiv1.MPIReplicaTypeLauncher].Active = 1
 		}
 	}
 
@@ -464,15 +465,15 @@ func (r *MPIJobReconciler) updateMPIJobStatus(mpiJob *mpiv1.MPIJob, launcher *co
 	for i := 0; i < len(worker); i++ {
 		switch worker[i].Status.Phase {
 		case corev1.PodFailed:
-			mpiJob.Status.ReplicaStatuses[commonv1.ReplicaType(mpiv1.MPIReplicaTypeWorker)].Failed += 1
+			mpiJob.Status.ReplicaStatuses[mpiv1.MPIReplicaTypeWorker].Failed += 1
 			if worker[i].Status.Reason == "Evicted" {
 				evict += 1
 			}
 		case corev1.PodSucceeded:
-			mpiJob.Status.ReplicaStatuses[commonv1.ReplicaType(mpiv1.MPIReplicaTypeWorker)].Succeeded += 1
+			mpiJob.Status.ReplicaStatuses[mpiv1.MPIReplicaTypeWorker].Succeeded += 1
 		case corev1.PodRunning:
 			running += 1
-			mpiJob.Status.ReplicaStatuses[commonv1.ReplicaType(mpiv1.MPIReplicaTypeWorker)].Active += 1
+			mpiJob.Status.ReplicaStatuses[mpiv1.MPIReplicaTypeWorker].Active += 1
 		}
 	}
 	if evict > 0 {
@@ -480,7 +481,7 @@ func (r *MPIJobReconciler) updateMPIJobStatus(mpiJob *mpiv1.MPIJob, launcher *co
 		if err := updateMPIJobConditions(mpiJob, commonv1.JobFailed, mpiJobEvict, msg); err != nil {
 			return err
 		}
-		r.Recorder.Event(mpiJob, corev1.EventTypeWarning, mpiJobEvict, msg)
+		jc.Recorder.Event(mpiJob, corev1.EventTypeWarning, mpiJobEvict, msg)
 	}
 
 	if launcher != nil && launcher.Status.Phase == corev1.PodRunning && running == len(worker) {
@@ -489,15 +490,15 @@ func (r *MPIJobReconciler) updateMPIJobStatus(mpiJob *mpiv1.MPIJob, launcher *co
 		if err != nil {
 			return err
 		}
-		r.Recorder.Eventf(mpiJob, corev1.EventTypeNormal, "MPIJobRunning", "MPIJob %s/%s is running", mpiJob.Namespace, mpiJob.Name)
+		jc.Recorder.Eventf(mpiJob, corev1.EventTypeNormal, "MPIJobRunning", "MPIJob %s/%s is running", mpiJob.Namespace, mpiJob.Name)
 	}
 
 	return nil
 }
 
-func (r *MPIJobReconciler) GetJobFromAPIClient(namespace, name string) (metav1.Object, error) {
+func (jc *MPIJobReconciler) GetJobFromAPIClient(namespace, name string) (metav1.Object, error) {
 	job := &mpiv1.MPIJob{}
-	clientReader, err := util.GetDelegatingClientFromClient(r.Client)
+	clientReader, err := util.GetDelegatingClientFromClient(jc.Client)
 	if err != nil {
 		return nil, err
 	}
@@ -516,7 +517,7 @@ func (r *MPIJobReconciler) GetJobFromAPIClient(namespace, name string) (metav1.O
 // GetPodsForJob returns the set of pods that this job should manage.
 // It also reconciles ControllerRef by adopting/orphaning.
 // Note that the returned Pods are pointers into the cache.
-func (r *MPIJobReconciler) GetPodsForJob(jobObject interface{}) ([]*corev1.Pod, error) {
+func (jc *MPIJobReconciler) GetPodsForJob(jobObject interface{}) ([]*corev1.Pod, error) {
 	job, ok := jobObject.(metav1.Object)
 	if !ok {
 		return nil, fmt.Errorf("job is not of type metav1.Object")
@@ -524,7 +525,7 @@ func (r *MPIJobReconciler) GetPodsForJob(jobObject interface{}) ([]*corev1.Pod,
 
 	// Create selector.
 	selector, err := metav1.LabelSelectorAsSelector(&metav1.LabelSelector{
-		MatchLabels: r.GenLabels(job.GetName()),
+		MatchLabels: jc.GenLabels(job.GetName()),
 	})
 
 	if err != nil {
@@ -533,7 +534,7 @@ func (r *MPIJobReconciler) GetPodsForJob(jobObject interface{}) ([]*corev1.Pod,
 	// List all pods to include those that don't match the selector anymore
 	// but have a ControllerRef pointing to this controller.
 	podlist := &corev1.PodList{}
-	err = r.List(context.Background(), podlist,
+	err = jc.List(context.Background(), podlist,
 		client.MatchingLabelsSelector{Selector: selector}, client.InNamespace(job.GetNamespace()))
 	if err != nil {
 		return nil, err
@@ -544,7 +545,7 @@ func (r *MPIJobReconciler) GetPodsForJob(jobObject interface{}) ([]*corev1.Pod,
 	// If any adoptions are attempted, we should first recheck for deletion
 	// with an uncached quorum read sometime after listing Pods (see #42639).
 	canAdoptFunc := common.RecheckDeletionTimestamp(func() (metav1.Object, error) {
-		fresh, err := r.Controller.GetJobFromAPIClient(job.GetNamespace(), job.GetName())
+		fresh, err := jc.Controller.GetJobFromAPIClient(job.GetNamespace(), job.GetName())
 		if err != nil {
 			return nil, err
 		}
@@ -553,24 +554,24 @@ func (r *MPIJobReconciler) GetPodsForJob(jobObject interface{}) ([]*corev1.Pod,
 		}
 		return fresh, nil
 	})
-	cm := control.NewPodControllerRefManager(r.PodControl, job, selector, r.Controller.GetAPIGroupVersionKind(), canAdoptFunc)
+	cm := control.NewPodControllerRefManager(jc.PodControl, job, selector, jc.Controller.GetAPIGroupVersionKind(), canAdoptFunc)
 	return cm.ClaimPods(pods)
 }
 
-func (r *MPIJobReconciler) DeleteJob(job interface{}) error {
+func (jc *MPIJobReconciler) DeleteJob(job interface{}) error {
 	mpiJob, ok := job.(*mpiv1.MPIJob)
 	if !ok {
 		return fmt.Errorf("%v is not a type of TFJob", mpiJob)
 	}
 
 	log := commonutil.LoggerForJob(mpiJob)
-	if err := r.Delete(context.Background(), mpiJob); err != nil {
-		r.Recorder.Eventf(mpiJob, corev1.EventTypeWarning, FailedDeleteJobReason, "Error deleting: %v", err)
+	if err := jc.Delete(context.Background(), mpiJob); err != nil {
+		jc.Recorder.Eventf(mpiJob, corev1.EventTypeWarning, FailedDeleteJobReason, "Error deleting: %v", err)
 		log.Errorf("failed to delete job %s/%s, %v", mpiJob.Namespace, mpiJob.Name, err)
 		return err
 	}
 
-	r.Recorder.Eventf(mpiJob, corev1.EventTypeNormal, SuccessfulDeleteJobReason, "Deleted job: %v", mpiJob.Name)
+	jc.Recorder.Eventf(mpiJob, corev1.EventTypeNormal, SuccessfulDeleteJobReason, "Deleted job: %v", mpiJob.Name)
 	log.Infof("job %s/%s has been deleted", mpiJob.Namespace, mpiJob.Name)
 	trainingoperatorcommon.DeletedJobsCounterInc(mpiJob.Namespace, mpiv1.FrameworkName)
 	return nil
@@ -579,11 +580,11 @@ func (r *MPIJobReconciler) DeleteJob(job interface{}) error {
 // GetServicesForJob returns the set of services that this job should manage.
 // It also reconciles ControllerRef by adopting/orphaning.
 // Note that the returned services are pointers into the cache.
-func (r *MPIJobReconciler) GetServicesForJob(jobObject interface{}) ([]*corev1.Service, error) {
+func (jc *MPIJobReconciler) GetServicesForJob(jobObject interface{}) ([]*corev1.Service, error) {
 	return nil, nil
 }
 
-func (r *MPIJobReconciler) UpdateJobStatus(job interface{}, replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec, jobStatus *commonv1.JobStatus) error {
+func (jc *MPIJobReconciler) UpdateJobStatus(job interface{}, replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec, jobStatus *commonv1.JobStatus) error {
 	mpiJob, ok := job.(*mpiv1.MPIJob)
 	if !ok {
 		return fmt.Errorf("%+v is not a type of MPIJob", job)
@@ -600,7 +601,7 @@ func (r *MPIJobReconciler) UpdateJobStatus(job interface{}, replicas map[commonv
 		logrus.Infof("MPIJob=%s, ReplicaType=%s expected=%d, running=%d, succeeded=%d , failed=%d",
 			mpiJob.Name, rtype, expected, running, succeeded, failed)
 
-		if rtype == commonv1.ReplicaType(mpiv1.MPIReplicaTypeLauncher) {
+		if rtype == mpiv1.MPIReplicaTypeLauncher {
 			if running > 0 {
 				msg := fmt.Sprintf("MPIJob %s is running.", mpiJob.Name)
 				err := commonutil.UpdateJobConditions(jobStatus, commonv1.JobRunning, commonutil.JobRunningReason, msg)
@@ -613,7 +614,7 @@ func (r *MPIJobReconciler) UpdateJobStatus(job interface{}, replicas map[commonv
 			if expected == 0 {
 				msg := fmt.Sprintf("MPIJob %s is successfully completed.", mpiJob.Name)
 				logrus.Info(msg)
-				r.Recorder.Event(mpiJob, corev1.EventTypeNormal, commonutil.JobSucceededReason, msg)
+				jc.Recorder.Event(mpiJob, corev1.EventTypeNormal, commonutil.JobSucceededReason, msg)
 				if jobStatus.CompletionTime == nil {
 					now := metav1.Now()
 					jobStatus.CompletionTime = &now
@@ -630,7 +631,7 @@ func (r *MPIJobReconciler) UpdateJobStatus(job interface{}, replicas map[commonv
 		if failed > 0 {
 			if spec.RestartPolicy == commonv1.RestartPolicyExitCode {
 				msg := fmt.Sprintf("MPIJob %s is restarting because %d %s replica(s) failed.", mpiJob.Name, failed, rtype)
-				r.Recorder.Event(mpiJob, corev1.EventTypeWarning, commonutil.JobRestartingReason, msg)
+				jc.Recorder.Event(mpiJob, corev1.EventTypeWarning, commonutil.JobRestartingReason, msg)
 				err := commonutil.UpdateJobConditions(jobStatus, commonv1.JobRestarting, commonutil.JobRestartingReason, msg)
 				if err != nil {
 					commonutil.LoggerForJob(mpiJob).Infof("Append job condition error: %v", err)
@@ -639,7 +640,7 @@ func (r *MPIJobReconciler) UpdateJobStatus(job interface{}, replicas map[commonv
 				trainingoperatorcommon.RestartedJobsCounterInc(mpiJob.Namespace, mpiv1.FrameworkName)
 			} else {
 				msg := fmt.Sprintf("MPIJob %s is failed because %d %s replica(s) failed.", mpiJob.Name, failed, rtype)
-				r.Recorder.Event(mpiJob, corev1.EventTypeNormal, commonutil.JobFailedReason, msg)
+				jc.Recorder.Event(mpiJob, corev1.EventTypeNormal, commonutil.JobFailedReason, msg)
 				if mpiJob.Status.CompletionTime == nil {
 					now := metav1.Now()
 					mpiJob.Status.CompletionTime = &now
@@ -666,7 +667,7 @@ func (r *MPIJobReconciler) UpdateJobStatus(job interface{}, replicas map[commonv
 	return nil
 }
 
-func (r *MPIJobReconciler) UpdateJobStatusInApiServer(job interface{}, jobStatus *commonv1.JobStatus) error {
+func (jc *MPIJobReconciler) UpdateJobStatusInApiServer(job interface{}, jobStatus *commonv1.JobStatus) error {
 	mpiJob, ok := job.(*mpiv1.MPIJob)
 	if !ok {
 		return fmt.Errorf("%v is not a type of MpiJob", mpiJob)
@@ -682,10 +683,10 @@ func (r *MPIJobReconciler) UpdateJobStatusInApiServer(job interface{}, jobStatus
 	mpiJob = mpiJob.DeepCopy()
 	mpiJob.Status = *jobStatus.DeepCopy()
 
-	result := r.Status().Update(context.Background(), mpiJob)
+	result := jc.Status().Update(context.Background(), mpiJob)
 
 	if result != nil {
-		r.Log.WithValues("mpijob", types.NamespacedName{
+		jc.Log.WithValues("mpijob", types.NamespacedName{
 			Namespace: mpiJob.GetNamespace(),
 			Name:      mpiJob.GetName(),
 		})
@@ -696,10 +697,10 @@ func (r *MPIJobReconciler) UpdateJobStatusInApiServer(job interface{}, jobStatus
 }
 
 // getLauncherJob gets the launcher Job controlled by this MPIJob.
-func (r *MPIJobReconciler) getLauncherJob(mpiJob *mpiv1.MPIJob) (*corev1.Pod, error) {
+func (jc *MPIJobReconciler) getLauncherJob(mpiJob *mpiv1.MPIJob) (*corev1.Pod, error) {
 	launcher := &corev1.Pod{}
 	NamespacedName := types.NamespacedName{Namespace: mpiJob.Namespace, Name: mpiJob.Name + launcherSuffix}
-	err := r.Get(context.Background(), NamespacedName, launcher)
+	err := jc.Get(context.Background(), NamespacedName, launcher)
 	if errors.IsNotFound(err) {
 		return nil, nil
 	}
@@ -714,7 +715,7 @@ func (r *MPIJobReconciler) getLauncherJob(mpiJob *mpiv1.MPIJob) (*corev1.Pod, er
 	// a warning to the event recorder and return.
 	if !metav1.IsControlledBy(launcher, mpiJob) {
 		msg := fmt.Sprintf(MessageResourceExists, launcher.Name, launcher.Kind)
-		r.Recorder.Event(mpiJob, corev1.EventTypeWarning, ErrResourceExists, msg)
+		jc.Recorder.Event(mpiJob, corev1.EventTypeWarning, ErrResourceExists, msg)
 		return launcher, fmt.Errorf(msg)
 	}
 	return launcher, nil
@@ -722,9 +723,9 @@ func (r *MPIJobReconciler) getLauncherJob(mpiJob *mpiv1.MPIJob) (*corev1.Pod, er
 
 // getOrCreateConfigMap gets the ConfigMap controlled by this MPIJob, or creates
 // one if it doesn't exist.
-func (r *MPIJobReconciler) getOrCreateConfigMap(mpiJob *mpiv1.MPIJob, workerReplicas int32, isGPULauncher bool) (*corev1.ConfigMap, error) {
+func (jc *MPIJobReconciler) getOrCreateConfigMap(mpiJob *mpiv1.MPIJob, workerReplicas int32, isGPULauncher bool) (*corev1.ConfigMap, error) {
 	newCM := newConfigMap(mpiJob, workerReplicas, isGPULauncher)
-	podList, err := r.getRunningWorkerPods(mpiJob)
+	podList, err := jc.getRunningWorkerPods(mpiJob)
 	if err != nil {
 		return nil, err
 	}
@@ -732,11 +733,11 @@ func (r *MPIJobReconciler) getOrCreateConfigMap(mpiJob *mpiv1.MPIJob, workerRepl
 
 	cm := &corev1.ConfigMap{}
 	NamespacedName := types.NamespacedName{Namespace: mpiJob.Namespace, Name: mpiJob.Name + configSuffix}
-	err = r.Get(context.Background(), NamespacedName, cm)
+	err = jc.Get(context.Background(), NamespacedName, cm)
 
 	// If the ConfigMap doesn't exist, we'll create it.
 	if errors.IsNotFound(err) {
-		cm, err = r.KubeClientSet.CoreV1().ConfigMaps(mpiJob.Namespace).Create(context.Background(), newCM, metav1.CreateOptions{})
+		cm, err = jc.KubeClientSet.CoreV1().ConfigMaps(mpiJob.Namespace).Create(context.Background(), newCM, metav1.CreateOptions{})
 	}
 	// If an error occurs during Get/Create, we'll requeue the item so we
 	// can attempt processing again later. This could have been caused by a
@@ -749,13 +750,13 @@ func (r *MPIJobReconciler) getOrCreateConfigMap(mpiJob *mpiv1.MPIJob, workerRepl
 	// should log a warning to the event recorder and return.
 	if !metav1.IsControlledBy(cm, mpiJob) {
 		msg := fmt.Sprintf(MessageResourceExists, cm.Name, cm.Kind)
-		r.Recorder.Event(mpiJob, corev1.EventTypeWarning, ErrResourceExists, msg)
+		jc.Recorder.Event(mpiJob, corev1.EventTypeWarning, ErrResourceExists, msg)
 		return nil, fmt.Errorf(msg)
 	}
 
 	// If the ConfigMap is changed, update it
 	if !reflect.DeepEqual(cm.Data, newCM.Data) {
-		cm, err = r.KubeClientSet.CoreV1().ConfigMaps(mpiJob.Namespace).Update(context.Background(), newCM, metav1.UpdateOptions{})
+		cm, err = jc.KubeClientSet.CoreV1().ConfigMaps(mpiJob.Namespace).Update(context.Background(), newCM, metav1.UpdateOptions{})
 		if err != nil {
 			return nil, err
 		}
@@ -766,18 +767,18 @@ func (r *MPIJobReconciler) getOrCreateConfigMap(mpiJob *mpiv1.MPIJob, workerRepl
 
 // getOrCreateLauncherServiceAccount gets the launcher ServiceAccount controlled
 // by this MPIJob, or creates one if it doesn't exist.
-func (r *MPIJobReconciler) getOrCreateLauncherServiceAccount(mpiJob *mpiv1.MPIJob) (*corev1.ServiceAccount, error) {
+func (jc *MPIJobReconciler) getOrCreateLauncherServiceAccount(mpiJob *mpiv1.MPIJob) (*corev1.ServiceAccount, error) {
 	sa := &corev1.ServiceAccount{}
 	NamespacedName := types.NamespacedName{Namespace: mpiJob.Namespace, Name: mpiJob.Name + launcherSuffix}
-	err := r.Get(context.Background(), NamespacedName, sa)
+	err := jc.Get(context.Background(), NamespacedName, sa)
 
 	if err == nil {
-		r.Recorder.Eventf(mpiJob, corev1.EventTypeNormal, "ServiceAccount is exist", "ServiceAccount: %v", sa.Name)
+		jc.Recorder.Eventf(mpiJob, corev1.EventTypeNormal, "ServiceAccount is exist", "ServiceAccount: %v", sa.Name)
 	}
 
 	if errors.IsNotFound(err) {
-		sa, err = r.KubeClientSet.CoreV1().ServiceAccounts(mpiJob.Namespace).Create(context.Background(), newLauncherServiceAccount(mpiJob), metav1.CreateOptions{})
+		sa, err = jc.KubeClientSet.CoreV1().ServiceAccounts(mpiJob.Namespace).Create(context.Background(), newLauncherServiceAccount(mpiJob), metav1.CreateOptions{})
 	}
 	// If an error occurs during Get/Create, we'll requeue the item so we
 	// can attempt processing again later. This could have been caused by a
@@ -789,7 +790,7 @@ func (r *MPIJobReconciler) getOrCreateLauncherServiceAccount(mpiJob *mpiv1.MPIJo
 	// should log a warning to the event recorder and return.
 	if !metav1.IsControlledBy(sa, mpiJob) {
 		msg := fmt.Sprintf(MessageResourceExists, sa.Name, sa.Kind)
-		r.Recorder.Event(mpiJob, corev1.EventTypeWarning, ErrResourceExists, msg)
+		jc.Recorder.Event(mpiJob, corev1.EventTypeWarning, ErrResourceExists, msg)
 		return nil, fmt.Errorf(msg)
 	}
 
@@ -797,19 +798,19 @@ func (r *MPIJobReconciler) getOrCreateLauncherServiceAccount(mpiJob *mpiv1.MPIJo
 }
 
 // getOrCreateLauncherRole gets the launcher Role controlled by this MPIJob.
-func (r *MPIJobReconciler) getOrCreateLauncherRole(mpiJob *mpiv1.MPIJob, workerReplicas int32) (*rbacv1.Role, error) {
+func (jc *MPIJobReconciler) getOrCreateLauncherRole(mpiJob *mpiv1.MPIJob, workerReplicas int32) (*rbacv1.Role, error) {
 	role := &rbacv1.Role{}
 	NamespacedName := types.NamespacedName{Namespace: mpiJob.Namespace, Name: mpiJob.Name + launcherSuffix}
-	err := r.Get(context.Background(), NamespacedName, role)
+	err := jc.Get(context.Background(), NamespacedName, role)
 
 	if err == nil {
-		r.Recorder.Eventf(mpiJob, corev1.EventTypeNormal, "LauncherRole is exist", "LauncherRole: %v", role.Name)
+		jc.Recorder.Eventf(mpiJob, corev1.EventTypeNormal, "LauncherRole is exist", "LauncherRole: %v", role.Name)
 	}
 
 	launcherRole := newLauncherRole(mpiJob, workerReplicas)
 	// If the Role doesn't exist, we'll create it.
 	if errors.IsNotFound(err) {
-		role, err = r.KubeClientSet.RbacV1().Roles(mpiJob.Namespace).Create(context.Background(), launcherRole, metav1.CreateOptions{})
+		role, err = jc.KubeClientSet.RbacV1().Roles(mpiJob.Namespace).Create(context.Background(), launcherRole, metav1.CreateOptions{})
 	}
 	// If an error occurs during Get/Create, we'll requeue the item so we
 	// can attempt processing again later. This could have been caused by a
@@ -821,12 +822,12 @@ func (r *MPIJobReconciler) getOrCreateLauncherRole(mpiJob *mpiv1.MPIJob, workerR
 	// should log a warning to the event recorder and return.
 	if !metav1.IsControlledBy(role, mpiJob) {
 		msg := fmt.Sprintf(MessageResourceExists, role.Name, role.Kind)
-		r.Recorder.Event(mpiJob, corev1.EventTypeWarning, ErrResourceExists, msg)
+		jc.Recorder.Event(mpiJob, corev1.EventTypeWarning, ErrResourceExists, msg)
 		return nil, fmt.Errorf(msg)
 	}
 
 	if !reflect.DeepEqual(role.Rules, launcherRole.Rules) {
-		role, err = r.KubeClientSet.RbacV1().Roles(mpiJob.Namespace).Update(context.Background(), launcherRole, metav1.UpdateOptions{})
+		role, err = jc.KubeClientSet.RbacV1().Roles(mpiJob.Namespace).Update(context.Background(), launcherRole, metav1.UpdateOptions{})
 		if err != nil {
 			return nil, err
 		}
@@ -837,18 +838,18 @@ func (r *MPIJobReconciler) getOrCreateLauncherRole(mpiJob *mpiv1.MPIJob, workerR
 
 // getLauncherRoleBinding gets the launcher RoleBinding controlled by this
 // MPIJob, or creates one if it doesn't exist.
-func (r *MPIJobReconciler) getLauncherRoleBinding(mpiJob *mpiv1.MPIJob) (*rbacv1.RoleBinding, error) {
+func (jc *MPIJobReconciler) getLauncherRoleBinding(mpiJob *mpiv1.MPIJob) (*rbacv1.RoleBinding, error) {
 	rb := &rbacv1.RoleBinding{}
 	NamespacedName := types.NamespacedName{Namespace: mpiJob.Namespace, Name: mpiJob.Name + launcherSuffix}
-	err := r.Get(context.Background(), NamespacedName, rb)
+	err := jc.Get(context.Background(), NamespacedName, rb)
 	// If the RoleBinding doesn't exist, we'll create it.
 
 	if err == nil {
-		r.Recorder.Eventf(mpiJob, corev1.EventTypeNormal, "RoleBinding is exist", "RoleBinding: %v", rb.Name)
+		jc.Recorder.Eventf(mpiJob, corev1.EventTypeNormal, "RoleBinding is exist", "RoleBinding: %v", rb.Name)
 	}
 
 	if errors.IsNotFound(err) {
-		rb, err = r.KubeClientSet.RbacV1().RoleBindings(mpiJob.Namespace).Create(context.Background(), newLauncherRoleBinding(mpiJob), metav1.CreateOptions{})
+		rb, err = jc.KubeClientSet.RbacV1().RoleBindings(mpiJob.Namespace).Create(context.Background(), newLauncherRoleBinding(mpiJob), metav1.CreateOptions{})
 	}
 	// If an error occurs during Get/Create, we'll requeue the item so we
 	// can attempt processing again later. This could have been caused by a
@@ -860,7 +861,7 @@ func (r *MPIJobReconciler) getLauncherRoleBinding(mpiJob *mpiv1.MPIJob) (*rbacv1
 	// should log a warning to the event recorder and return.
 	if !metav1.IsControlledBy(rb, mpiJob) {
 		msg := fmt.Sprintf(MessageResourceExists, rb.Name, rb.Kind)
-		r.Recorder.Event(mpiJob, corev1.EventTypeWarning, ErrResourceExists, msg)
+		jc.Recorder.Event(mpiJob, corev1.EventTypeWarning, ErrResourceExists, msg)
 		return nil, fmt.Errorf(msg)
 	}
 
@@ -869,7 +870,7 @@ func (r *MPIJobReconciler) getLauncherRoleBinding(mpiJob *mpiv1.MPIJob) (*rbacv1
 
 // getOrCreateWorker gets the worker Pod controlled by this
 // MPIJob, or creates one if it doesn't exist.
-func (r *MPIJobReconciler) getOrCreateWorker(mpiJob *mpiv1.MPIJob) ([]*corev1.Pod, error) {
+func (jc *MPIJobReconciler) getOrCreateWorker(mpiJob *mpiv1.MPIJob) ([]*corev1.Pod, error) {
 	var (
 		workerPrefix string        = mpiJob.Name + workerSuffix
 		workerPods   []*corev1.Pod = []*corev1.Pod{}
@@ -889,7 +890,7 @@ func (r *MPIJobReconciler) getOrCreateWorker(mpiJob *mpiv1.MPIJob) ([]*corev1.Po
 	}
 
 	podlist := &corev1.PodList{}
-	err = r.List(context.Background(), podlist, client.MatchingLabelsSelector{Selector: selector}, client.InNamespace(mpiJob.GetNamespace()))
+	err = jc.List(context.Background(), podlist, client.MatchingLabelsSelector{Selector: selector}, client.InNamespace(mpiJob.GetNamespace()))
 
 	if err != nil {
 		return nil, err
@@ -903,7 +904,7 @@ func (r *MPIJobReconciler) getOrCreateWorker(mpiJob *mpiv1.MPIJob) ([]*corev1.Po
 		index, err := strconv.Atoi(indexStr)
 		if err == nil {
 			if index >= int(*workerReplicas) {
-				err = r.KubeClientSet.CoreV1().Pods(pod.Namespace).Delete(context.Background(), pod.Name, metav1.DeleteOptions{})
+				err = jc.KubeClientSet.CoreV1().Pods(pod.Namespace).Delete(context.Background(), pod.Name, metav1.DeleteOptions{})
 				if err != nil {
 					return nil, err
 				}
@@ -917,24 +918,24 @@ func (r *MPIJobReconciler) getOrCreateWorker(mpiJob *mpiv1.MPIJob) ([]*corev1.Po
 
 		pod := &corev1.Pod{}
 		NamespacedName := types.NamespacedName{Namespace: mpiJob.Namespace, Name: name}
-		err := r.Get(context.Background(), NamespacedName, pod)
+		err := jc.Get(context.Background(), NamespacedName, pod)
 
 		// If the worker Pod doesn't exist, we'll create it.
 		if errors.IsNotFound(err) {
-			worker := r.newWorker(mpiJob, name)
+			worker := jc.newWorker(mpiJob, name)
 			if worker == nil {
 				msg := fmt.Sprintf(MessageResourceDoesNotExist, "Worker")
-				r.Recorder.Event(mpiJob, corev1.EventTypeWarning, ErrResourceDoesNotExist, msg)
+				jc.Recorder.Event(mpiJob, corev1.EventTypeWarning, ErrResourceDoesNotExist, msg)
 				err = fmt.Errorf(msg)
 				return nil, err
 			}
 			// Insert ReplicaIndexLabel
 			worker.Labels[commonv1.ReplicaIndexLabel] = strconv.Itoa(int(i))
-			pod, err = r.KubeClientSet.CoreV1().Pods(mpiJob.Namespace).Create(context.Background(), worker, metav1.CreateOptions{})
+			pod, err = jc.KubeClientSet.CoreV1().Pods(mpiJob.Namespace).Create(context.Background(), worker, metav1.CreateOptions{})
 			if err == nil {
-				r.Recorder.Eventf(mpiJob, corev1.EventTypeNormal, "SuccessfulCreatePod", "Created worker pod: %v", pod.Name)
+				jc.Recorder.Eventf(mpiJob, corev1.EventTypeNormal, "SuccessfulCreatePod", "Created worker pod: %v", pod.Name)
 			} else {
-				r.Recorder.Eventf(mpiJob, corev1.EventTypeWarning, "FailedCreatePod", "Created worker pod: %v", pod.Name)
+				jc.Recorder.Eventf(mpiJob, corev1.EventTypeWarning, "FailedCreatePod", "Created worker pod: %v", pod.Name)
 			}
 		}
 
@@ -942,14 +943,14 @@ func (r *MPIJobReconciler) getOrCreateWorker(mpiJob *mpiv1.MPIJob) ([]*corev1.Po
 		// can attempt processing again later. This could have been caused by a
 		// temporary network failure, or any other transient reason.
 		if err != nil && !errors.IsNotFound(err) {
-			r.Recorder.Eventf(mpiJob, corev1.EventTypeWarning, mpiJobFailedReason, "worker pod created failed: %v", err)
+			jc.Recorder.Eventf(mpiJob, corev1.EventTypeWarning, mpiJobFailedReason, "worker pod created failed: %v", err)
 			return nil, err
 		}
 		// If the worker is not controlled by this MPIJob resource, we should log
 		// a warning to the event recorder and return.
 		if pod != nil && !metav1.IsControlledBy(pod, mpiJob) {
 			msg := fmt.Sprintf(MessageResourceExists, pod.Name, pod.Kind)
-			r.Recorder.Event(mpiJob, corev1.EventTypeWarning, ErrResourceExists, msg)
+			jc.Recorder.Event(mpiJob, corev1.EventTypeWarning, ErrResourceExists, msg)
 			return nil, fmt.Errorf(msg)
 		}
 		workerPods = append(workerPods, pod)
@@ -961,7 +962,7 @@ func (r *MPIJobReconciler) getOrCreateWorker(mpiJob *mpiv1.MPIJob) ([]*corev1.Po
 // newWorker creates a new worker Pod for an MPIJob resource. It also
 // sets the appropriate OwnerReferences on the resource so handleObject can
 // discover the MPIJob resource that 'owns' it.
-func (r *MPIJobReconciler) newWorker(mpiJob *mpiv1.MPIJob, name string) *corev1.Pod {
+func (jc *MPIJobReconciler) newWorker(mpiJob *mpiv1.MPIJob, name string) *corev1.Pod {
 	labels := defaultWorkerLabels(mpiJob.Name)
 
 	podSpec := mpiJob.Spec.MPIReplicaSpecs[mpiv1.MPIReplicaTypeWorker].Template.DeepCopy()
@@ -1016,11 +1017,11 @@ func (r *MPIJobReconciler) newWorker(mpiJob *mpiv1.MPIJob, name string) *corev1.
 	// if gang-scheduling is enabled:
 	// 1. if user has specified other scheduler, we report a warning without overriding any fields.
 	// 2. if no SchedulerName is set for pods, then we set the SchedulerName to "volcano".
-	if r.Config.EnableGangScheduling {
+	if jc.Config.EnableGangScheduling {
 		if !util.IsGangSchedulerSet(mpiJob.Spec.MPIReplicaSpecs, gangSchedulerName) {
 			errMsg := "Another scheduler is specified when gang-scheduling is enabled and it will not be overwritten"
 			logger.Warning(errMsg)
-			r.Recorder.Event(mpiJob, corev1.EventTypeWarning, podTemplateSchedulerNameReason, errMsg)
+			jc.Recorder.Event(mpiJob, corev1.EventTypeWarning, podTemplateSchedulerNameReason, errMsg)
 		} else {
 			podSpec.Spec.SchedulerName = gangSchedulerName
 		}
@@ -1049,7 +1050,7 @@ func (r *MPIJobReconciler) newWorker(mpiJob *mpiv1.MPIJob, name string) *corev1.
 // newLauncher creates a new launcher Job for an MPIJob resource. It also sets
 // the appropriate OwnerReferences on the resource so handleObject can discover
 // the MPIJob resource that 'owns' it.
-func (r *MPIJobReconciler) newLauncher(mpiJob *mpiv1.MPIJob, kubectlDeliveryImage string, isGPULauncher bool) *corev1.Pod {
+func (jc *MPIJobReconciler) newLauncher(mpiJob *mpiv1.MPIJob, kubectlDeliveryImage string, isGPULauncher bool) *corev1.Pod {
 	launcherName := mpiJob.Name + launcherSuffix
 	labels := map[string]string{
 		labelGroupName: "kubeflow.org",
@@ -1068,11 +1069,11 @@ func (r *MPIJobReconciler) newLauncher(mpiJob *mpiv1.MPIJob, kubectlDeliveryImag
 
 	logger := commonutil.LoggerForReplica(mpiJob, strings.ToLower(string(mpiv1.MPIReplicaTypeLauncher)))
 	// add SchedulerName to podSpec
-	if r.Config.EnableGangScheduling {
+	if jc.Config.EnableGangScheduling {
 		if !util.IsGangSchedulerSet(mpiJob.Spec.MPIReplicaSpecs, gangSchedulerName) {
 			errMsg := "Another scheduler is specified when gang-scheduling is enabled and it will not be overwritten"
 			logger.Warning(errMsg)
-			r.Recorder.Event(mpiJob, corev1.EventTypeWarning, podTemplateSchedulerNameReason, errMsg)
+			jc.Recorder.Event(mpiJob, corev1.EventTypeWarning, podTemplateSchedulerNameReason, errMsg)
 		} else {
 			podSpec.Spec.SchedulerName = gangSchedulerName
 		}
@@ -1125,7 +1126,7 @@ func (r *MPIJobReconciler) newLauncher(mpiJob *mpiv1.MPIJob, kubectlDeliveryImag
 	if len(podSpec.Spec.Containers) == 0 {
 		klog.Errorln("Launcher pod does not have any containers in its spec")
 		msg := fmt.Sprintf(MessageResourceDoesNotExist, "Launcher")
-		r.Recorder.Event(mpiJob, corev1.EventTypeWarning, ErrResourceDoesNotExist, msg)
+		jc.Recorder.Event(mpiJob, corev1.EventTypeWarning, ErrResourceDoesNotExist, msg)
 		return nil
 	}
 	container := podSpec.Spec.Containers[0]
@@ -1171,7 +1172,7 @@ func (r *MPIJobReconciler) newLauncher(mpiJob *mpiv1.MPIJob, kubectlDeliveryImag
 	if podSpec.Spec.RestartPolicy != corev1.RestartPolicy("") {
 		errMsg := "Restart policy in pod template will be overwritten by restart policy in replica spec"
 		klog.Warning(errMsg)
-		r.Recorder.Event(mpiJob, corev1.EventTypeWarning, podTemplateRestartPolicyReason, errMsg)
+		jc.Recorder.Event(mpiJob, corev1.EventTypeWarning, podTemplateRestartPolicyReason, errMsg)
 	}
 	setRestartPolicy(podSpec, mpiJob.Spec.MPIReplicaSpecs[mpiv1.MPIReplicaTypeLauncher])
 
@@ -1226,14 +1227,14 @@ func (r *MPIJobReconciler) newLauncher(mpiJob *mpiv1.MPIJob, kubectlDeliveryImag
 }
 
 // getRunningWorkerPods get all worker Pods with Running phase controlled by this MPIJob.
-func (r *MPIJobReconciler) getRunningWorkerPods(mpiJob *mpiv1.MPIJob) ([]*corev1.Pod, error) {
+func (jc *MPIJobReconciler) getRunningWorkerPods(mpiJob *mpiv1.MPIJob) ([]*corev1.Pod, error) {
 	selector, err := workerSelector(mpiJob.Name)
 	if err != nil {
 		return nil, err
 	}
 
 	podFullList := &corev1.PodList{}
-	err = r.List(context.Background(), podFullList, client.MatchingLabelsSelector{Selector: selector}, client.InNamespace(mpiJob.GetNamespace()))
+	err = jc.List(context.Background(), podFullList, client.MatchingLabelsSelector{Selector: selector}, client.InNamespace(mpiJob.GetNamespace()))
 	//podFullList, err := r.PodLister.List(selector)
 	if err != nil {
 		return nil, err
diff --git a/pkg/controller.v1/mpi/suite_test.go b/pkg/controller.v1/mpi/suite_test.go
index 455d41dc1c..8b82fe5a82 100644
--- a/pkg/controller.v1/mpi/suite_test.go
+++ b/pkg/controller.v1/mpi/suite_test.go
@@ -19,6 +19,8 @@ import (
 	"path/filepath"
 	"testing"
 
+	"github.com/kubeflow/training-operator/pkg/config"
+
 	mpiv1 "github.com/kubeflow/training-operator/pkg/apis/mpi/v1"
 	. "github.com/onsi/ginkgo"
 	. "github.com/onsi/gomega"
@@ -69,6 +71,9 @@ var _ = BeforeSuite(func() {
 	err = mpiv1.AddToScheme(scheme.Scheme)
 	Expect(err).NotTo(HaveOccurred())
 
+	// Set default kubectl delivery image
+	config.Config.MPIKubectlDeliveryImage = config.MPIKubectlDeliveryImageDefault
+
 	//+kubebuilder:scaffold:scheme
 
 	testK8sClient, err = client.New(cfg, client.Options{Scheme: scheme.Scheme})
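A minimal sketch of how the new flag is expected to be consumed, for reference only — it is not part of the patch. The flag name, default constant, and config field are taken from the diff above; the standalone main package, the flag.Parse call site, and the example image reference are assumptions for illustration.

// Illustrative only; mirrors the wiring added in this diff but is not itself part of the patch.
package main

import (
	"flag"
	"fmt"

	ctlrconfig "github.com/kubeflow/training-operator/pkg/config"
)

func main() {
	// Same registration as cmd/training-operator.v1/main.go after this change.
	flag.StringVar(&ctlrconfig.Config.MPIKubectlDeliveryImage, "mpi-kubectl-delivery-image",
		ctlrconfig.MPIKubectlDeliveryImageDefault, "The kubectl-delivery image used by the MPI launcher init container")
	flag.Parse()

	// e.g. --mpi-kubectl-delivery-image=registry.example.com/kubectl-delivery:v0.3.0 (hypothetical value).
	// The MPIJob reconciler reads this value when building the launcher pod:
	// jc.newLauncher(mpiJob, ctlrconfig.Config.MPIKubectlDeliveryImage, isGPULauncher).
	fmt.Println("kubectl-delivery image:", ctlrconfig.Config.MPIKubectlDeliveryImage)
}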