diff --git a/cmd/mpi-operator/main.go b/cmd/mpi-operator/main.go
index 022dae0f..589adf0a 100644
--- a/cmd/mpi-operator/main.go
+++ b/cmd/mpi-operator/main.go
@@ -26,6 +26,7 @@ import (
 	clientset "github.com/kubeflow/mpi-operator/pkg/client/clientset/versioned"
 	informers "github.com/kubeflow/mpi-operator/pkg/client/informers/externalversions"
 	"github.com/kubeflow/mpi-operator/pkg/controllers"
+	policyinformers "k8s.io/client-go/informers/policy/v1beta1"
 )
 
 var (
@@ -70,6 +71,10 @@ func main() {
 		kubeflowInformerFactory = informers.NewSharedInformerFactoryWithOptions(kubeflowClient, 0, informers.WithNamespace(namespace), nil)
 	}
 
+	var pdbInformer policyinformers.PodDisruptionBudgetInformer
+	if enableGangScheduling {
+		pdbInformer = kubeInformerFactory.Policy().V1beta1().PodDisruptionBudgets()
+	}
 	controller := controllers.NewMPIJobController(
 		kubeClient,
 		kubeflowClient,
@@ -79,7 +84,7 @@ func main() {
 		kubeInformerFactory.Rbac().V1().RoleBindings(),
 		kubeInformerFactory.Apps().V1().StatefulSets(),
 		kubeInformerFactory.Batch().V1().Jobs(),
-		kubeInformerFactory.Policy().V1beta1().PodDisruptionBudgets(),
+		pdbInformer,
 		kubeflowInformerFactory.Kubeflow().V1alpha1().MPIJobs(),
 		gpusPerNode,
 		processingUnitsPerNode,
diff --git a/pkg/controllers/mpi_job_controller.go b/pkg/controllers/mpi_job_controller.go
index 3a29f3a8..57540306 100644
--- a/pkg/controllers/mpi_job_controller.go
+++ b/pkg/controllers/mpi_job_controller.go
@@ -173,6 +173,13 @@ func NewMPIJobController(
 	eventBroadcaster.StartLogging(glog.Infof)
 	eventBroadcaster.StartRecordingToSink(&typedcorev1.EventSinkImpl{Interface: kubeClient.CoreV1().Events("")})
 	recorder := eventBroadcaster.NewRecorder(scheme.Scheme, corev1.EventSource{Component: controllerAgentName})
+	var pdbLister policylisters.PodDisruptionBudgetLister
+	var pdbSynced cache.InformerSynced
+
+	if enableGangScheduling {
+		pdbLister = pdbInformer.Lister()
+		pdbSynced = pdbInformer.Informer().HasSynced
+	}
 
 	controller := &MPIJobController{
 		kubeClient: kubeClient,
@@ -189,8 +196,8 @@ func NewMPIJobController(
 		statefulSetSynced: statefulSetInformer.Informer().HasSynced,
 		jobLister: jobInformer.Lister(),
 		jobSynced: jobInformer.Informer().HasSynced,
-		pdbLister: pdbInformer.Lister(),
-		pdbSynced: pdbInformer.Informer().HasSynced,
+		pdbLister: pdbLister,
+		pdbSynced: pdbSynced,
 		mpiJobLister: mpiJobInformer.Lister(),
 		mpiJobSynced: mpiJobInformer.Informer().HasSynced,
 		queue: workqueue.NewNamedRateLimitingQueue(workqueue.DefaultControllerRateLimiter(), "MPIJobs"),
@@ -307,21 +314,26 @@ func NewMPIJobController(
 		},
 		DeleteFunc: controller.handleObject,
 	})
-	pdbInformer.Informer().AddEventHandler(cache.ResourceEventHandlerFuncs{
-		AddFunc: controller.handleObject,
-		UpdateFunc: func(old, new interface{}) {
-			newPolicy := new.(*policyv1beta1.PodDisruptionBudget)
-			oldPolicy := old.(*policyv1beta1.PodDisruptionBudget)
-			if newPolicy.ResourceVersion == oldPolicy.ResourceVersion {
-				// Periodic re-sync will send update events for all known PodDisruptionBudgets.
-				// Two different versions of the same Job will always have
-				// different RVs.
-				return
-			}
-			controller.handleObject(new)
-		},
-		DeleteFunc: controller.handleObject,
-	})
+
+	// pdbInformer can be nil because we only create it
+	// when gang scheduling is enabled.
+	if pdbInformer != nil {
+		pdbInformer.Informer().AddEventHandler(cache.ResourceEventHandlerFuncs{
+			AddFunc: controller.handleObject,
+			UpdateFunc: func(old, new interface{}) {
+				newPolicy := new.(*policyv1beta1.PodDisruptionBudget)
+				oldPolicy := old.(*policyv1beta1.PodDisruptionBudget)
+				if newPolicy.ResourceVersion == oldPolicy.ResourceVersion {
+					// Periodic re-sync will send update events for all known PodDisruptionBudgets.
+					// Two different versions of the same Job will always have
+					// different RVs.
+					return
+				}
+				controller.handleObject(new)
+			},
+			DeleteFunc: controller.handleObject,
+		})
+	}
 
 	return controller
 }
diff --git a/pkg/controllers/mpi_job_controller_test.go b/pkg/controllers/mpi_job_controller_test.go
index f18a1078..0e37f748 100644
--- a/pkg/controllers/mpi_job_controller_test.go
+++ b/pkg/controllers/mpi_job_controller_test.go
@@ -30,6 +30,7 @@ import (
 	"k8s.io/apimachinery/pkg/runtime/schema"
 	"k8s.io/apimachinery/pkg/util/diff"
 	kubeinformers "k8s.io/client-go/informers"
+	policyinformers "k8s.io/client-go/informers/policy/v1beta1"
 	k8sfake "k8s.io/client-go/kubernetes/fake"
 	core "k8s.io/client-go/testing"
 	"k8s.io/client-go/tools/cache"
@@ -142,13 +143,16 @@ func newMPIJobWithCustomResources(name string, replicas *int32, pusPerReplica in
 	return mpiJob
 }
 
-func (f *fixture) newController(processingResourceType string) (*MPIJobController, informers.SharedInformerFactory, kubeinformers.SharedInformerFactory) {
+func (f *fixture) newController(processingResourceType string, enableGangScheduling bool) (*MPIJobController, informers.SharedInformerFactory, kubeinformers.SharedInformerFactory) {
 	f.client = fake.NewSimpleClientset(f.objects...)
 	f.kubeClient = k8sfake.NewSimpleClientset(f.kubeObjects...)
 
 	i := informers.NewSharedInformerFactory(f.client, noResyncPeriodFunc())
 	k8sI := kubeinformers.NewSharedInformerFactory(f.kubeClient, noResyncPeriodFunc())
-
+	var pdbInformer policyinformers.PodDisruptionBudgetInformer
+	if enableGangScheduling {
+		pdbInformer = k8sI.Policy().V1beta1().PodDisruptionBudgets()
+	}
 	c := NewMPIJobController(
 		f.kubeClient,
 		f.client,
@@ -158,13 +162,13 @@ func (f *fixture) newController(processingResourceType string) (*MPIJobControlle
 		k8sI.Rbac().V1().RoleBindings(),
 		k8sI.Apps().V1().StatefulSets(),
 		k8sI.Batch().V1().Jobs(),
-		k8sI.Policy().V1beta1().PodDisruptionBudgets(),
+		pdbInformer,
 		i.Kubeflow().V1alpha1().MPIJobs(),
 		8,
 		8,
 		processingResourceType,
 		"kubectl-delivery",
-		false,
+		enableGangScheduling,
 	)
 
 	c.configMapSynced = alwaysReady
@@ -213,15 +217,20 @@ func (f *fixture) newController(processingResourceType string) (*MPIJobControlle
 }
 
 func (f *fixture) run(mpiJobName string, processingResourceType string) {
-	f.runController(mpiJobName, true, false, processingResourceType)
+	f.runController(mpiJobName, true, false, processingResourceType, false)
+}
+
+func (f *fixture) runWithGangScheduling(mpiJobName string, processingResourceType string) {
+	f.runController(mpiJobName, true, false, processingResourceType, true)
 }
 
 func (f *fixture) runExpectError(mpiJobName string, processingResourceType string) {
-	f.runController(mpiJobName, true, true, processingResourceType)
+	f.runController(mpiJobName, true, true, processingResourceType, false)
 }
 
-func (f *fixture) runController(mpiJobName string, startInformers bool, expectError bool, processingResourceType string) {
-	c, i, k8sI := f.newController(processingResourceType)
+func (f *fixture) runController(
+	mpiJobName string, startInformers bool, expectError bool, processingResourceType string, enableGangScheduling bool) {
+	c, i, k8sI := f.newController(processingResourceType, enableGangScheduling)
 	if startInformers {
 		stopCh := make(chan struct{})
 		defer close(stopCh)
@@ -492,6 +501,26 @@ func TestLauncherNotControlledByUs(t *testing.T) {
 	f.runExpectError(getKey(mpiJob, t), gpuResourceName)
 }
 
+func TestLauncherSucceededWithGang(t *testing.T) {
+	f := newFixture(t)
+
+	startTime := metav1.Now()
+	completionTime := metav1.Now()
+	mpiJob := newMPIJob("test", int32Ptr(64), &startTime, &completionTime)
+	f.setUpMPIJob(mpiJob)
+
+	launcher := newLauncher(mpiJob, "kubectl-delivery")
+	launcher.Status.Succeeded = 1
+	f.setUpLauncher(launcher)
+
+	mpiJobCopy := mpiJob.DeepCopy()
+	mpiJobCopy.Status.LauncherStatus = kubeflow.LauncherSucceeded
+	setUpMPIJobTimestamp(mpiJobCopy, &startTime, &completionTime)
+	f.expectUpdateMPIJobStatusAction(mpiJobCopy)
+
+	f.runWithGangScheduling(getKey(mpiJob, t), gpuResourceName)
+}
+
 func TestLauncherSucceeded(t *testing.T) {
 	f := newFixture(t)
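
For reviewers who want the guard in isolation: below is a minimal, self-contained sketch of the pattern this PR applies, i.e. construct the PodDisruptionBudget informer only when gang scheduling is enabled, and nil-check it before wiring listers, sync funcs, and event handlers. The helper name buildPDBInformer and the main function are illustrative only and are not part of the mpi-operator code; the client-go calls are the same policy/v1beta1 APIs the diff uses.

package main

import (
	"fmt"

	kubeinformers "k8s.io/client-go/informers"
	policyinformers "k8s.io/client-go/informers/policy/v1beta1"
	k8sfake "k8s.io/client-go/kubernetes/fake"
	policylisters "k8s.io/client-go/listers/policy/v1beta1"
	"k8s.io/client-go/tools/cache"
)

// buildPDBInformer mirrors the change in cmd/mpi-operator/main.go: the
// PodDisruptionBudget informer is only constructed when gang scheduling is
// enabled, so callers must treat a nil informer as "PDBs are not watched".
// (Hypothetical helper; the real code inlines this in main.)
func buildPDBInformer(factory kubeinformers.SharedInformerFactory, enableGangScheduling bool) policyinformers.PodDisruptionBudgetInformer {
	if !enableGangScheduling {
		return nil
	}
	return factory.Policy().V1beta1().PodDisruptionBudgets()
}

func main() {
	client := k8sfake.NewSimpleClientset()
	factory := kubeinformers.NewSharedInformerFactory(client, 0)

	pdbInformer := buildPDBInformer(factory, false) // gang scheduling disabled

	// Mirror the controller's guard: only dereference the informer when it
	// exists; otherwise leave the lister and sync func as their zero values.
	var pdbLister policylisters.PodDisruptionBudgetLister
	var pdbSynced cache.InformerSynced
	if pdbInformer != nil {
		pdbLister = pdbInformer.Lister()
		pdbSynced = pdbInformer.Informer().HasSynced
	}

	fmt.Printf("lister set: %v, synced set: %v\n", pdbLister != nil, pdbSynced != nil)
}

Leaving pdbLister and pdbSynced as zero values keeps the MPIJobController struct unchanged; the trade-off is that any code path touching PDBs must honor the same enableGangScheduling/nil guard, which is exactly what the guarded AddEventHandler block in the controller does.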