Use patch instead of update on jobframework. (kubernetes-sigs#2553)
mbobrovskyi authored and kannon92 committed Nov 19, 2024
1 parent 394d4f4 commit c47d782
Showing 7 changed files with 129 additions and 76 deletions.
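In short: object writes in the job framework that previously went through client.Update (a full-object write that resends fields the controller does not own) now go through the new clientutil.Patch/PatchStatus helpers, which run a mutation callback against the live object and send only the resulting merge patch. A minimal sketch of the before/after call shape, using a hypothetical suspend helper (not code from this commit):

```go
package example

import (
	"context"

	batchv1 "k8s.io/api/batch/v1"
	"k8s.io/utils/ptr"
	"sigs.k8s.io/controller-runtime/pkg/client"

	clientutil "sigs.k8s.io/kueue/pkg/util/client"
)

// Old pattern: mutate in place, then write the whole object back.
func suspendWithUpdate(ctx context.Context, c client.Client, job *batchv1.Job) error {
	job.Spec.Suspend = ptr.To(true)
	return c.Update(ctx, job)
}

// New pattern: the helper deep-copies the object, runs the callback, and
// sends only the fields the callback changed as a merge patch. strict=true
// keeps optimistic locking by embedding the resourceVersion in the patch.
func suspendWithPatch(ctx context.Context, c client.Client, job *batchv1.Job) error {
	return clientutil.Patch(ctx, c, job, true, func() (bool, error) {
		job.Spec.Suspend = ptr.To(true)
		return true, nil // true: something changed, issue the patch
	})
}
```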
2 changes: 2 additions & 0 deletions charts/kueue/templates/rbac/role.yaml
@@ -135,6 +135,7 @@ rules:
   - jobs/status
   verbs:
   - get
+  - patch
   - update
 - apiGroups:
   - flowcontrol.apiserver.k8s.io
@@ -179,6 +180,7 @@ rules:
   - jobsets/status
   verbs:
   - get
+  - patch
   - update
 - apiGroups:
   - kubeflow.org
2 changes: 2 additions & 0 deletions config/components/rbac/role.yaml
@@ -134,6 +134,7 @@ rules:
   - jobs/status
   verbs:
   - get
+  - patch
   - update
 - apiGroups:
   - flowcontrol.apiserver.k8s.io
@@ -178,6 +179,7 @@ rules:
   - jobsets/status
   verbs:
   - get
+  - patch
   - update
 - apiGroups:
   - kubeflow.org
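Both role.yaml hunks above are identical because the files are generated; the hand-edited source is the set of kubebuilder RBAC markers on the controllers, changed further down in this diff. For reference, the jobs/status rule is produced by the marker below, and in kubebuilder-style projects the YAML is then regenerated (commonly via a make manifests target; the exact target is an assumption, not shown in this commit):

```go
// +kubebuilder:rbac:groups=batch,resources=jobs/status,verbs=get;update;patch
```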
28 changes: 15 additions & 13 deletions pkg/controller/jobframework/reconciler.go
@@ -48,6 +48,7 @@ import (
 	"sigs.k8s.io/kueue/pkg/podset"
 	"sigs.k8s.io/kueue/pkg/queue"
 	"sigs.k8s.io/kueue/pkg/util/admissioncheck"
+	clientutil "sigs.k8s.io/kueue/pkg/util/client"
 	"sigs.k8s.io/kueue/pkg/util/equality"
 	"sigs.k8s.io/kueue/pkg/util/kubeversion"
 	"sigs.k8s.io/kueue/pkg/util/maps"
@@ -296,9 +297,10 @@ func (r *JobReconciler) ReconcileGenericJob(ctx context.Context, req ctrl.Reques
 		log.Error(err, "couldn't get the parent job workload")
 		return ctrl.Result{}, err
 	} else if parentWorkload == nil || !workload.IsAdmitted(parentWorkload) {
-		// suspend it
-		job.Suspend()
-		if err := r.client.Update(ctx, object); err != nil {
+		if err := clientutil.Patch(ctx, r.client, object, true, func() (bool, error) {
+			job.Suspend()
+			return true, nil
+		}); err != nil {
 			log.Error(err, "suspending child job failed")
 			return ctrl.Result{}, err
 		}
@@ -801,11 +803,9 @@ func (r *JobReconciler) startJob(ctx context.Context, job GenericJob, object cli
 			return err
 		}
 	} else {
-		if runErr := job.RunWithPodSetsInfo(info); runErr != nil {
-			return runErr
-		}
-
-		if err := r.client.Update(ctx, object); err != nil {
+		if err := clientutil.Patch(ctx, r.client, object, true, func() (bool, error) {
+			return true, job.RunWithPodSetsInfo(info)
+		}); err != nil {
 			return err
 		}
 		r.record.Event(object, corev1.EventTypeNormal, ReasonStarted, msg)
@@ -847,11 +847,13 @@ func (r *JobReconciler) stopJob(ctx context.Context, job GenericJob, wl *kueue.W
 		return nil
 	}
 
-	job.Suspend()
-	if info != nil {
-		job.RestorePodSetsInfo(info)
-	}
-	if err := r.client.Update(ctx, object); err != nil {
+	if err := clientutil.Patch(ctx, r.client, object, true, func() (bool, error) {
+		job.Suspend()
+		if info != nil {
+			job.RestorePodSetsInfo(info)
+		}
+		return true, nil
+	}); err != nil {
 		return err
 	}
 
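The callback's boolean return value lets a call site skip the write entirely: if it returns false (and no error), clientutil.Patch never contacts the API server, and if the mutation itself fails, as RunWithPodSetsInfo can above, the error is returned without sending a patch. A sketch of that idempotent usage (the IsSuspended-style guard is illustrative, not part of this diff):

```go
package example

import (
	"context"

	batchv1 "k8s.io/api/batch/v1"
	"k8s.io/utils/ptr"
	"sigs.k8s.io/controller-runtime/pkg/client"

	clientutil "sigs.k8s.io/kueue/pkg/util/client"
)

// ensureSuspended only patches when there is something to change, so
// repeated reconciles of an already-suspended Job cost no API requests.
func ensureSuspended(ctx context.Context, c client.Client, job *batchv1.Job) error {
	return clientutil.Patch(ctx, c, job, true, func() (bool, error) {
		if ptr.Deref(job.Spec.Suspend, false) {
			return false, nil // already suspended: no patch is sent
		}
		job.Spec.Suspend = ptr.To(true)
		return true, nil
	})
}
```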
38 changes: 24 additions & 14 deletions pkg/controller/jobs/job/job_controller.go
@@ -41,6 +41,7 @@ import (
 	"sigs.k8s.io/kueue/pkg/controller/core/indexer"
 	"sigs.k8s.io/kueue/pkg/controller/jobframework"
 	"sigs.k8s.io/kueue/pkg/podset"
+	clientutil "sigs.k8s.io/kueue/pkg/util/client"
 )
 
 var (
@@ -69,7 +70,7 @@ func init() {
 // +kubebuilder:rbac:groups=scheduling.k8s.io,resources=priorityclasses,verbs=list;get;watch
 // +kubebuilder:rbac:groups="",resources=events,verbs=create;watch;update;patch
 // +kubebuilder:rbac:groups=batch,resources=jobs,verbs=get;list;watch;update;patch
-// +kubebuilder:rbac:groups=batch,resources=jobs/status,verbs=get;update
+// +kubebuilder:rbac:groups=batch,resources=jobs/status,verbs=get;update;patch
 // +kubebuilder:rbac:groups=batch,resources=jobs/finalizers,verbs=get;update;patch
 // +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=workloads,verbs=get;list;watch;create;update;patch;delete
 // +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=workloads/status,verbs=get;update;patch
@@ -163,34 +164,43 @@ func (j *Job) Suspend() {
 	j.Spec.Suspend = ptr.To(true)
 }
 
-func (j *Job) Stop(ctx context.Context, c client.Client, podSetsInfo []podset.PodSetInfo, _ jobframework.StopReason, eventMsg string) (bool, error) {
+func (j *Job) Stop(ctx context.Context, c client.Client, podSetsInfo []podset.PodSetInfo, _ jobframework.StopReason, _ string) (bool, error) {
+	object := j.Object()
 	stoppedNow := false
 
 	if !j.IsSuspended() {
-		j.Suspend()
-		if j.ObjectMeta.Annotations == nil {
-			j.ObjectMeta.Annotations = map[string]string{}
-		}
-		// We are using annotation to be sure that all updates finished successfully.
-		j.ObjectMeta.Annotations[StoppingAnnotation] = "true"
-		if err := c.Update(ctx, j.Object()); err != nil {
+		if err := clientutil.Patch(ctx, c, object, true, func() (bool, error) {
+			j.Suspend()
+			if j.ObjectMeta.Annotations == nil {
+				j.ObjectMeta.Annotations = map[string]string{}
+			}
+			// We are using annotation to be sure that all updates finished successfully.
+			j.ObjectMeta.Annotations[StoppingAnnotation] = "true"
+			return true, nil
+		}); err != nil {
			return false, fmt.Errorf("suspend: %w", err)
 		}
 		stoppedNow = true
 	}
 
 	// Reset start time if necessary, so we can update the scheduling directives.
 	if j.Status.StartTime != nil {
-		j.Status.StartTime = nil
-		if err := c.Status().Update(ctx, j.Object()); err != nil {
+		if err := clientutil.PatchStatus(ctx, c, object, func() (bool, error) {
+			j.Status.StartTime = nil
+			return true, nil
+		}); err != nil {
 			return stoppedNow, fmt.Errorf("reset status: %w", err)
 		}
 	}
 
-	j.RestorePodSetsInfo(podSetsInfo)
-	delete(j.ObjectMeta.Annotations, StoppingAnnotation)
-	if err := c.Update(ctx, j.Object()); err != nil {
+	if err := clientutil.Patch(ctx, c, object, true, func() (bool, error) {
+		j.RestorePodSetsInfo(podSetsInfo)
+		delete(j.ObjectMeta.Annotations, StoppingAnnotation)
+		return true, nil
+	}); err != nil {
 		return false, fmt.Errorf("restore info: %w", err)
 	}
 
 	return stoppedNow, nil
 }
 
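Stop is the first caller of PatchStatus: clearing status.startTime has to go through the status subresource, which is exactly why the jobs/status and jobsets/status RBAC rules above gain the patch verb. A minimal sketch of such a status write (assuming only the helper signature shown in this diff):

```go
package example

import (
	"context"

	batchv1 "k8s.io/api/batch/v1"
	"sigs.k8s.io/controller-runtime/pkg/client"

	clientutil "sigs.k8s.io/kueue/pkg/util/client"
)

// resetStartTime clears status.startTime via a patch to the status
// subresource; PatchStatus always includes the resourceVersion, so the call
// fails cleanly if the Job was modified concurrently.
func resetStartTime(ctx context.Context, c client.Client, job *batchv1.Job) error {
	return clientutil.PatchStatus(ctx, c, job, func() (bool, error) {
		if job.Status.StartTime == nil {
			return false, nil // nothing to reset, skip the request
		}
		job.Status.StartTime = nil
		return true, nil
	})
}
```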
2 changes: 1 addition & 1 deletion pkg/controller/jobs/jobset/jobset_controller.go
@@ -56,7 +56,7 @@ func init() {
 // +kubebuilder:rbac:groups=scheduling.k8s.io,resources=priorityclasses,verbs=list;get;watch
 // +kubebuilder:rbac:groups="",resources=events,verbs=create;watch;update;patch
 // +kubebuilder:rbac:groups=jobset.x-k8s.io,resources=jobsets,verbs=get;list;watch;update;patch
-// +kubebuilder:rbac:groups=jobset.x-k8s.io,resources=jobsets/status,verbs=get;update
+// +kubebuilder:rbac:groups=jobset.x-k8s.io,resources=jobsets/status,verbs=get;update;patch
 // +kubebuilder:rbac:groups=jobset.x-k8s.io,resources=jobsets/finalizers,verbs=get;update
 // +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=workloads,verbs=get;list;watch;create;update;patch;delete
 // +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=workloads/status,verbs=get;update;patch
94 changes: 49 additions & 45 deletions pkg/controller/jobs/pod/pod_controller.go
@@ -253,54 +253,58 @@ func (p *Pod) Run(ctx context.Context, c client.Client, podSetsInfo []podset.Pod
 		return fmt.Errorf("%w: expecting 1 pod set got %d", podset.ErrInvalidPodsetInfo, len(podSetsInfo))
 	}
 
-	podOriginal := p.pod.DeepCopy()
-
-	if ungated := ungatePod(&p.pod); !ungated {
+	if gateIndex(&p.pod) == gateNotFound {
 		return nil
 	}
 
-	if err := podset.Merge(&p.pod.ObjectMeta, &p.pod.Spec, podSetsInfo[0]); err != nil {
+	if err := clientutil.Patch(ctx, c, &p.pod, true, func() (bool, error) {
+		ungatePod(&p.pod)
+		return true, podset.Merge(&p.pod.ObjectMeta, &p.pod.Spec, podSetsInfo[0])
+	}); err != nil {
 		return err
 	}
 
-	if err := clientutil.Patch(ctx, c, podOriginal, &p.pod); err != nil {
-		return err
-	}
 	if recorder != nil {
 		recorder.Event(&p.pod, corev1.EventTypeNormal, jobframework.ReasonStarted, msg)
 	}
 
 	return nil
 	}
 
 	return parallelize.Until(ctx, len(p.list.Items), func(i int) error {
 		pod := &p.list.Items[i]
-		podOriginal := pod.DeepCopy()
 
-		if ungated := ungatePod(pod); !ungated {
+		if gateIndex(pod) == gateNotFound {
 			return nil
 		}
 
-		roleHash, err := getRoleHash(*pod)
-		if err != nil {
-			return err
-		}
-
-		podSetIndex := slices.IndexFunc(podSetsInfo, func(info podset.PodSetInfo) bool {
-			return info.Name == roleHash
-		})
-		if podSetIndex == -1 {
-			return fmt.Errorf("%w: podSetInfo with the name '%s' is not found", podset.ErrInvalidPodsetInfo, roleHash)
-		}
-
-		err = podset.Merge(&pod.ObjectMeta, &pod.Spec, podSetsInfo[podSetIndex])
-		if err != nil {
-			return err
-		}
-
-		log.V(3).Info("Starting pod in group", "podInGroup", klog.KObj(pod))
-		if err := clientutil.Patch(ctx, c, podOriginal, pod); err != nil {
+		if err := clientutil.Patch(ctx, c, pod, true, func() (bool, error) {
+			ungatePod(pod)
+
+			roleHash, err := getRoleHash(*pod)
+			if err != nil {
+				return false, err
+			}
+
+			podSetIndex := slices.IndexFunc(podSetsInfo, func(info podset.PodSetInfo) bool {
+				return info.Name == roleHash
+			})
+			if podSetIndex == -1 {
+				return false, fmt.Errorf("%w: podSetInfo with the name '%s' is not found", podset.ErrInvalidPodsetInfo, roleHash)
+			}
+
+			err = podset.Merge(&pod.ObjectMeta, &pod.Spec, podSetsInfo[podSetIndex])
+			if err != nil {
+				return false, err
+			}
+
+			log.V(3).Info("Starting pod in group", "podInGroup", klog.KObj(pod))
+
+			return true, nil
+		}); err != nil {
 			return err
 		}
 
 		if recorder != nil {
 			recorder.Event(pod, corev1.EventTypeNormal, jobframework.ReasonStarted, msg)
 		}
@@ -530,11 +534,9 @@ func (p *Pod) Finalize(ctx context.Context, c client.Client) error {
 
 	return parallelize.Until(ctx, len(podsInGroup.Items), func(i int) error {
 		pod := &podsInGroup.Items[i]
-		podOriginal := pod.DeepCopy()
-		if controllerutil.RemoveFinalizer(pod, PodFinalizer) {
-			return clientutil.Patch(ctx, c, podOriginal, pod)
-		}
-		return nil
+		return clientutil.Patch(ctx, c, pod, false, func() (bool, error) {
+			return controllerutil.RemoveFinalizer(pod, PodFinalizer), nil
+		})
 	})
 }
 
@@ -840,14 +842,14 @@ func (p *Pod) removeExcessPods(ctx context.Context, c client.Client, r record.Ev
 	// Finalize and delete the active pods created last
 	err := parallelize.Until(ctx, len(extraPods), func(i int) error {
 		pod := extraPods[i]
-		podOriginal := pod.DeepCopy()
-		if controllerutil.RemoveFinalizer(&pod, PodFinalizer) {
-			log.V(3).Info("Finalizing excess pod in group", "excessPod", klog.KObj(&pod))
-			if err := clientutil.Patch(ctx, c, podOriginal, &pod); err != nil {
-				// We won't observe this cleanup in the event handler.
-				p.excessPodExpectations.ObservedUID(log, p.key, pod.UID)
-				return err
-			}
-		}
+		if err := clientutil.Patch(ctx, c, &pod, false, func() (bool, error) {
+			removed := controllerutil.RemoveFinalizer(&pod, PodFinalizer)
+			log.V(3).Info("Finalizing excess pod in group", "excessPod", klog.KObj(&pod))
+			return removed, nil
+		}); err != nil {
+			// We won't observe this cleanup in the event handler.
+			p.excessPodExpectations.ObservedUID(log, p.key, pod.UID)
+			return err
+		}
 		if pod.DeletionTimestamp.IsZero() {
 			log.V(3).Info("Deleting excess pod in group", "excessPod", klog.KObj(&pod))
@@ -879,15 +881,17 @@ func (p *Pod) finalizePods(ctx context.Context, c client.Client, extraPods []cor
 
 	err := parallelize.Until(ctx, len(extraPods), func(i int) error {
 		pod := extraPods[i]
-		podOriginal := pod.DeepCopy()
-		if controllerutil.RemoveFinalizer(&pod, PodFinalizer) {
-			log.V(3).Info("Finalizing pod in group", "Pod", klog.KObj(&pod))
-			if err := clientutil.Patch(ctx, c, podOriginal, &pod); err != nil {
-				// We won't observe this cleanup in the event handler.
-				p.excessPodExpectations.ObservedUID(log, p.key, pod.UID)
-				return err
-			}
-		} else {
+		var removed bool
+		if err := clientutil.Patch(ctx, c, &pod, false, func() (bool, error) {
+			removed = controllerutil.RemoveFinalizer(&pod, PodFinalizer)
+			log.V(3).Info("Finalizing pod in group", "Pod", klog.KObj(&pod))
+			return removed, nil
+		}); err != nil {
+			// We won't observe this cleanup in the event handler.
+			p.excessPodExpectations.ObservedUID(log, p.key, pod.UID)
+			return err
+		}
+		if !removed {
 			// We don't expect an event in this case.
 			p.excessPodExpectations.ObservedUID(log, p.key, pod.UID)
 		}
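Note that the pod writes above pass strict=false: finalizer removal does not need optimistic locking, so omitting the resourceVersion from the patch means a concurrent write to the Pod cannot fail the call with a conflict. The callback's boolean doubles as the RemoveFinalizer result, so nothing is sent when the finalizer is already gone. A sketch under those assumptions (the finalizer name is illustrative):

```go
package example

import (
	"context"

	corev1 "k8s.io/api/core/v1"
	"sigs.k8s.io/controller-runtime/pkg/client"
	"sigs.k8s.io/controller-runtime/pkg/controller/controllerutil"

	clientutil "sigs.k8s.io/kueue/pkg/util/client"
)

const podFinalizer = "kueue.x-k8s.io/managed" // illustrative name

// removeFinalizer issues a non-strict patch: no resourceVersion in the patch
// body, so it succeeds even if the Pod changed since it was read. If the
// finalizer was not present, RemoveFinalizer returns false and no request
// is made.
func removeFinalizer(ctx context.Context, c client.Client, pod *corev1.Pod) error {
	return clientutil.Patch(ctx, c, pod, false, func() (bool, error) {
		return controllerutil.RemoveFinalizer(pod, podFinalizer), nil
	})
}
```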
39 changes: 36 additions & 3 deletions pkg/util/client/client.go
@@ -31,12 +31,45 @@ func CreatePatch(before, after client.Object) (client.Patch, error) {
 	return client.RawPatch(patchBase.Type(), patchBytes), nil
 }
 
-func Patch(ctx context.Context, c client.Client, before, after client.Object) error {
-	patch, err := CreatePatch(before, after)
+// Patch applies a merge patch to the client.Object.
+// If strict is true, the resourceVersion is included in the patch, making
+// this call fail if the client.Object was changed concurrently.
+func Patch(ctx context.Context, c client.Client, obj client.Object, strict bool, update func() (bool, error)) error {
+	objOriginal := obj.DeepCopyObject().(client.Object)
+	if strict {
+		// Clear the ResourceVersion on the copy so that the current resourceVersion is included in the generated patch.
+		objOriginal.SetResourceVersion("")
+	}
+	updated, err := update()
+	if err != nil || !updated {
+		return err
+	}
+	patch, err := CreatePatch(objOriginal, obj)
+	if err != nil {
+		return err
+	}
+	if err = c.Patch(ctx, obj, patch); err != nil {
+		return err
+	}
+	return nil
+}
+
+// PatchStatus applies a merge patch to the client.Object status.
+// The resourceVersion is included in the patch, making this call fail if
+// the client.Object was changed concurrently.
+func PatchStatus(ctx context.Context, c client.Client, obj client.Object, update func() (bool, error)) error {
+	objOriginal := obj.DeepCopyObject().(client.Object)
+	// Clear the ResourceVersion on the copy so that the current resourceVersion is included in the generated patch.
+	objOriginal.SetResourceVersion("")
+	updated, err := update()
+	if err != nil || !updated {
+		return err
+	}
+	patch, err := CreatePatch(objOriginal, obj)
 	if err != nil {
 		return err
 	}
-	if err = c.Patch(ctx, before, patch); err != nil {
+	if err = c.Status().Patch(ctx, obj, patch); err != nil {
 		return err
 	}
 	return nil
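The strict-mode trick, clearing resourceVersion on the "before" copy so the generated diff carries the live object's resourceVersion, is the same idea controller-runtime exposes as MergeFromWithOptimisticLock. A sketch of the equivalent done with controller-runtime primitives directly (an illustration of the mechanism, not code from this commit):

```go
package example

import (
	"context"

	corev1 "k8s.io/api/core/v1"
	"sigs.k8s.io/controller-runtime/pkg/client"
)

// strictPatch mirrors what Patch(..., strict=true, ...) does above: diff the
// mutated object against a deep copy, and embed metadata.resourceVersion in
// the patch so the API server rejects it if the object changed in between.
func strictPatch(ctx context.Context, c client.Client, pod *corev1.Pod, mutate func()) error {
	original := pod.DeepCopy()
	mutate()
	patch := client.MergeFromWithOptions(original, client.MergeFromWithOptimisticLock{})
	return c.Patch(ctx, pod, patch)
}
```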
