diff --git a/pkg/controller.v1/common/job.go b/pkg/controller.v1/common/job.go index b4e7df83cc..3ace17c768 100644 --- a/pkg/controller.v1/common/job.go +++ b/pkg/controller.v1/common/job.go @@ -279,10 +279,11 @@ func (jc *JobController) ReconcileJobs( var pgSpecFill FillPodGroupSpecFunc switch jc.Config.GangScheduling { case GangSchedulerVolcano: - pgSpecFill = func(pg metav1.Object) error { + pgSpecFill = func(pg metav1.Object) (metav1.Object, error) { volcanoPodGroup, match := pg.(*volcanov1beta1.PodGroup) + volcanoPodGroup = volcanoPodGroup.DeepCopy() if !match { - return fmt.Errorf("unable to recognize PodGroup: %v", klog.KObj(pg)) + return nil, fmt.Errorf("unable to recognize PodGroup: %v", klog.KObj(pg)) } volcanoPodGroup.Spec = volcanov1beta1.PodGroupSpec{ MinMember: minMember, @@ -290,20 +291,21 @@ func (jc *JobController) ReconcileJobs( PriorityClassName: priorityClass, MinResources: minResources, } - return nil + return volcanoPodGroup, nil } default: - pgSpecFill = func(pg metav1.Object) error { + pgSpecFill = func(pg metav1.Object) (metav1.Object, error) { schedulerPluginsPodGroup, match := pg.(*schedulerpluginsv1alpha1.PodGroup) + schedulerPluginsPodGroup = schedulerPluginsPodGroup.DeepCopy() if !match { - return fmt.Errorf("unable to recognize PodGroup: %v", klog.KObj(pg)) + return nil, fmt.Errorf("unable to recognize PodGroup: %v", klog.KObj(pg)) } schedulerPluginsPodGroup.Spec = schedulerpluginsv1alpha1.PodGroupSpec{ MinMember: minMember, MinResources: *minResources, ScheduleTimeoutSeconds: schedulerTimeout, } - return nil + return schedulerPluginsPodGroup, nil } } diff --git a/pkg/controller.v1/common/scheduling.go b/pkg/controller.v1/common/scheduling.go index 3c3d40b813..bed9f0a0b3 100644 --- a/pkg/controller.v1/common/scheduling.go +++ b/pkg/controller.v1/common/scheduling.go @@ -22,8 +22,8 @@ import ( "fmt" apiv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" + volcanov1beta1 "volcano.sh/apis/pkg/apis/scheduling/v1beta1" - "github.com/google/go-cmp/cmp" log "github.com/sirupsen/logrus" policyapi "k8s.io/api/policy/v1beta1" k8serrors "k8s.io/apimachinery/pkg/api/errors" @@ -33,7 +33,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" ) -type FillPodGroupSpecFunc func(object metav1.Object) error +type FillPodGroupSpecFunc func(object metav1.Object) (metav1.Object, error) func (jc *JobController) SyncPodGroup(job metav1.Object, specFunc FillPodGroupSpecFunc) (metav1.Object, error) { pgctl := jc.PodGroupControl @@ -42,14 +42,20 @@ func (jc *JobController) SyncPodGroup(job metav1.Object, specFunc FillPodGroupSp podGroup, err := pgctl.GetPodGroup(job.GetNamespace(), job.GetName()) if err == nil { // update podGroup for gang scheduling - oldPodGroup := &podGroup - if err = specFunc(podGroup); err != nil { + updatedSpecPodGroup, err := specFunc(podGroup) + if err != nil { return nil, fmt.Errorf("unable to fill the spec of PodGroup, '%v': %v", klog.KObj(podGroup), err) } - if diff := cmp.Diff(oldPodGroup, podGroup); len(diff) != 0 { - return podGroup, pgctl.UpdatePodGroup(podGroup.(client.Object)) + + existVolcanoPodGroup := podGroup.(*volcanov1beta1.PodGroup) + updatedSpecVolcanoPodGroup := updatedSpecPodGroup.(*volcanov1beta1.PodGroup) + // The hpa-controller may update the num of replicas + // https://github.com/kubeflow/common/pull/207 + if existVolcanoPodGroup.Spec.MinMember != updatedSpecVolcanoPodGroup.Spec.MinMember { + // The queue name should not be changed after the pg is created + updatedSpecVolcanoPodGroup.Spec.Queue = existVolcanoPodGroup.Spec.Queue + return updatedSpecPodGroup, pgctl.UpdatePodGroup(updatedSpecPodGroup.(client.Object)) } - return podGroup, nil } else if client.IgnoreNotFound(err) != nil { return nil, fmt.Errorf("unable to get a PodGroup: %v", err) } else { @@ -59,13 +65,14 @@ func (jc *JobController) SyncPodGroup(job metav1.Object, specFunc FillPodGroupSp newPodGroup.SetNamespace(job.GetNamespace()) newPodGroup.SetAnnotations(job.GetAnnotations()) newPodGroup.SetOwnerReferences([]metav1.OwnerReference{*jc.GenOwnerReference(job)}) - if err = specFunc(newPodGroup); err != nil { + updatedSpecPodGroup, err := specFunc(newPodGroup) + if err != nil { return nil, fmt.Errorf("unable to fill the spec of PodGroup, '%v': %v", klog.KObj(newPodGroup), err) } - err = pgctl.CreatePodGroup(newPodGroup) + err = pgctl.CreatePodGroup(updatedSpecPodGroup.(client.Object)) if err != nil { - return podGroup, fmt.Errorf("unable to create PodGroup: %v", err) + return updatedSpecPodGroup, fmt.Errorf("unable to create PodGroup: %v", err) } createdPodGroupsCount.Inc() }