Skip to content

Commit

Permalink
fix volcano podgroup update issue
Browse files Browse the repository at this point in the history
Signed-off-by: Weiyu Yen <[email protected]>
  • Loading branch information
ckyuto committed Apr 26, 2024
1 parent e764830 commit ca04ae6
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 16 deletions.
14 changes: 8 additions & 6 deletions pkg/controller.v1/common/job.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,31 +279,33 @@ func (jc *JobController) ReconcileJobs(
var pgSpecFill FillPodGroupSpecFunc
switch jc.Config.GangScheduling {
case GangSchedulerVolcano:
pgSpecFill = func(pg metav1.Object) error {
pgSpecFill = func(pg metav1.Object) (metav1.Object, error) {
volcanoPodGroup, match := pg.(*volcanov1beta1.PodGroup)
volcanoPodGroup = volcanoPodGroup.DeepCopy()
if !match {
return fmt.Errorf("unable to recognize PodGroup: %v", klog.KObj(pg))
return nil, fmt.Errorf("unable to recognize PodGroup: %v", klog.KObj(pg))
}
volcanoPodGroup.Spec = volcanov1beta1.PodGroupSpec{
MinMember: minMember,
Queue: queue,
PriorityClassName: priorityClass,
MinResources: minResources,
}
return nil
return volcanoPodGroup, nil
}
default:
pgSpecFill = func(pg metav1.Object) error {
pgSpecFill = func(pg metav1.Object) (metav1.Object, error) {
schedulerPluginsPodGroup, match := pg.(*schedulerpluginsv1alpha1.PodGroup)
schedulerPluginsPodGroup = schedulerPluginsPodGroup.DeepCopy()
if !match {
return fmt.Errorf("unable to recognize PodGroup: %v", klog.KObj(pg))
return nil, fmt.Errorf("unable to recognize PodGroup: %v", klog.KObj(pg))
}
schedulerPluginsPodGroup.Spec = schedulerpluginsv1alpha1.PodGroupSpec{
MinMember: minMember,
MinResources: *minResources,
ScheduleTimeoutSeconds: schedulerTimeout,
}
return nil
return schedulerPluginsPodGroup, nil
}
}

Expand Down
27 changes: 17 additions & 10 deletions pkg/controller.v1/common/scheduling.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ import (
"fmt"

apiv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
volcanov1beta1 "volcano.sh/apis/pkg/apis/scheduling/v1beta1"

"github.com/google/go-cmp/cmp"
log "github.com/sirupsen/logrus"
policyapi "k8s.io/api/policy/v1beta1"
k8serrors "k8s.io/apimachinery/pkg/api/errors"
Expand All @@ -33,7 +33,7 @@ import (
"sigs.k8s.io/controller-runtime/pkg/client"
)

type FillPodGroupSpecFunc func(object metav1.Object) error
type FillPodGroupSpecFunc func(object metav1.Object) (metav1.Object, error)

func (jc *JobController) SyncPodGroup(job metav1.Object, specFunc FillPodGroupSpecFunc) (metav1.Object, error) {
pgctl := jc.PodGroupControl
Expand All @@ -42,14 +42,20 @@ func (jc *JobController) SyncPodGroup(job metav1.Object, specFunc FillPodGroupSp
podGroup, err := pgctl.GetPodGroup(job.GetNamespace(), job.GetName())
if err == nil {
// update podGroup for gang scheduling
oldPodGroup := &podGroup
if err = specFunc(podGroup); err != nil {
updatedSpecPodGroup, err := specFunc(podGroup)
if err != nil {
return nil, fmt.Errorf("unable to fill the spec of PodGroup, '%v': %v", klog.KObj(podGroup), err)
}
if diff := cmp.Diff(oldPodGroup, podGroup); len(diff) != 0 {
return podGroup, pgctl.UpdatePodGroup(podGroup.(client.Object))

existVolcanoPodGroup := podGroup.(*volcanov1beta1.PodGroup)
updatedSpecVolcanoPodGroup := updatedSpecPodGroup.(*volcanov1beta1.PodGroup)
// The hpa-controller may update the num of replicas
// https://github.com/kubeflow/common/pull/207
if existVolcanoPodGroup.Spec.MinMember != updatedSpecVolcanoPodGroup.Spec.MinMember {
// The queue name should not be changed after the pg is created
updatedSpecVolcanoPodGroup.Spec.Queue = existVolcanoPodGroup.Spec.Queue
return updatedSpecPodGroup, pgctl.UpdatePodGroup(updatedSpecPodGroup.(client.Object))
}
return podGroup, nil
} else if client.IgnoreNotFound(err) != nil {
return nil, fmt.Errorf("unable to get a PodGroup: %v", err)
} else {
Expand All @@ -59,13 +65,14 @@ func (jc *JobController) SyncPodGroup(job metav1.Object, specFunc FillPodGroupSp
newPodGroup.SetNamespace(job.GetNamespace())
newPodGroup.SetAnnotations(job.GetAnnotations())
newPodGroup.SetOwnerReferences([]metav1.OwnerReference{*jc.GenOwnerReference(job)})
if err = specFunc(newPodGroup); err != nil {
updatedSpecPodGroup, err := specFunc(newPodGroup)
if err != nil {
return nil, fmt.Errorf("unable to fill the spec of PodGroup, '%v': %v", klog.KObj(newPodGroup), err)
}

err = pgctl.CreatePodGroup(newPodGroup)
err = pgctl.CreatePodGroup(updatedSpecPodGroup.(client.Object))
if err != nil {
return podGroup, fmt.Errorf("unable to create PodGroup: %v", err)
return updatedSpecPodGroup, fmt.Errorf("unable to create PodGroup: %v", err)
}
createdPodGroupsCount.Inc()
}
Expand Down

0 comments on commit ca04ae6

Please sign in to comment.