Skip to content

Commit

Permalink
Support MPIJob managedBy feature for the MultiKueue (kubernetes-sigs#…
Browse files Browse the repository at this point in the history
…3289)

* Add managedBy field impl and unit tests

* Update MpiJob multikueue integration test

* Update e2e tests and start to use mpi-operator on management cluster

* Implement webhook defaulting

* Update after code review
  • Loading branch information
mszadkow authored and PBundyra committed Nov 5, 2024
1 parent 02e31d9 commit 70df735
Show file tree
Hide file tree
Showing 10 changed files with 610 additions and 38 deletions.
1 change: 0 additions & 1 deletion hack/e2e-common.sh
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ export KUBEFLOW_IMAGE=kubeflow/training-operator:${KUBEFLOW_IMAGE_VERSION}

export KUBEFLOW_MPI_MANIFEST="https://raw.githubusercontent.com/kubeflow/mpi-operator/${KUBEFLOW_MPI_VERSION}/deploy/v2beta1/mpi-operator.yaml"
export KUBEFLOW_MPI_IMAGE=mpioperator/mpi-operator:${KUBEFLOW_MPI_VERSION/#v}
export KUBEFLOW_MPI_CRD=${ROOT_DIR}/dep-crds/mpi-operator/kubeflow.org_mpijobs.yaml

# sleep image to use for testing.
export E2E_TEST_IMAGE=gcr.io/k8s-staging-perf-tests/sleep:v0.1.0@sha256:8d91ddf9f145b66475efda1a1b52269be542292891b5de2a7fad944052bab6ea
Expand Down
8 changes: 5 additions & 3 deletions hack/multikueue-e2e-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,16 @@ function kind_load {
# have Kubeflow Jobs admitted without execution in the manager cluster.
kubectl config use-context "kind-${MANAGER_KIND_CLUSTER_NAME}"
kubectl apply -k "${KUBEFLOW_MANIFEST_MANAGER}"
## MPI
kubectl apply --server-side -f "${KUBEFLOW_MPI_CRD}"

# WORKERS
docker pull "${KUBEFLOW_IMAGE}"
docker pull "${KUBEFLOW_MPI_IMAGE}"

install_kubeflow "$WORKER1_KIND_CLUSTER_NAME"
install_kubeflow "$WORKER2_KIND_CLUSTER_NAME"

## MPI
docker pull "${KUBEFLOW_MPI_IMAGE}"
install_mpi "$MANAGER_KIND_CLUSTER_NAME"
install_mpi "$WORKER1_KIND_CLUSTER_NAME"
install_mpi "$WORKER2_KIND_CLUSTER_NAME"

Expand Down
7 changes: 1 addition & 6 deletions pkg/controller/jobs/mpijob/mpijob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,8 @@ var (
gvk = kfmpi.SchemeGroupVersionKind

FrameworkName = "kubeflow.org/mpijob"

SetupMPIJobWebhook = jobframework.BaseWebhookFactory(NewJob(), fromObject)
)

// +kubebuilder:webhook:path=/mutate-kubeflow-org-v2beta1-mpijob,mutating=true,failurePolicy=fail,sideEffects=None,groups=kubeflow.org,resources=mpijobs,verbs=create,versions=v2beta1,name=mmpijob.kb.io,admissionReviewVersions=v1
// +kubebuilder:webhook:path=/validate-kubeflow-org-v2beta1-mpijob,mutating=false,failurePolicy=fail,sideEffects=None,groups=kubeflow.org,resources=mpijobs,verbs=create;update,versions=v2beta1,name=vmpijob.kb.io,admissionReviewVersions=v1

func init() {
utilruntime.Must(jobframework.RegisterIntegration(FrameworkName, jobframework.IntegrationCallbacks{
SetupIndexes: SetupIndexes,
Expand Down Expand Up @@ -90,7 +85,7 @@ func (j *MPIJob) Object() client.Object {
return (*kfmpi.MPIJob)(j)
}

func fromObject(o runtime.Object) jobframework.GenericJob {
// fromObject reinterprets o as this package's *MPIJob wrapper.
// The unchecked type assertion panics when o is not a *kfmpi.MPIJob.
func fromObject(o runtime.Object) *MPIJob {
	raw := o.(*kfmpi.MPIJob)
	return (*MPIJob)(raw)
}

Expand Down
15 changes: 14 additions & 1 deletion pkg/controller/jobs/mpijob/mpijob_multikueue_adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/types"
"k8s.io/klog/v2"
"k8s.io/utils/ptr"
"sigs.k8s.io/controller-runtime/pkg/client"

kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1"
Expand Down Expand Up @@ -72,6 +73,9 @@ func (b *multikueueAdapter) SyncJob(ctx context.Context, localClient client.Clie
remoteJob.Labels[constants.PrebuiltWorkloadLabel] = workloadName
remoteJob.Labels[kueue.MultiKueueOriginLabel] = origin

// clearing the managedBy enables the controller to take over
remoteJob.Spec.RunPolicy.ManagedBy = nil

return remoteClient.Create(ctx, &remoteJob)
}

Expand All @@ -86,7 +90,16 @@ func (b *multikueueAdapter) KeepAdmissionCheckPending() bool {
return false
}

func (b *multikueueAdapter) IsJobManagedByKueue(context.Context, client.Client, types.NamespacedName) (bool, string, error) {
// IsJobManagedByKueue fetches the MPIJob identified by key through c and
// reports whether its spec.runPolicy.managedBy equals the MultiKueue
// controller name.
//
// It returns (true, "", nil) on a match. On a mismatch (including an unset
// managedBy, which is treated as the empty string) it returns false together
// with a human-readable reason. A failure to fetch the job is returned as the
// error, with false and an empty reason.
func (b *multikueueAdapter) IsJobManagedByKueue(ctx context.Context, c client.Client, key types.NamespacedName) (bool, string, error) {
	var mpiJob kfmpi.MPIJob
	if err := c.Get(ctx, key, &mpiJob); err != nil {
		return false, "", err
	}

	managedBy := ptr.Deref(mpiJob.Spec.RunPolicy.ManagedBy, "")
	if managedBy == kueue.MultiKueueControllerName {
		return true, "", nil
	}

	reason := fmt.Sprintf("Expecting spec.runPolicy.managedBy to be %q not %q", kueue.MultiKueueControllerName, managedBy)
	return false, reason, nil
}

Expand Down
89 changes: 67 additions & 22 deletions pkg/controller/jobs/mpijob/mpijob_multikueue_adapter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package mpijob

import (
"context"
"errors"
"testing"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -49,39 +50,38 @@ func TestMultikueueAdapter(t *testing.T) {
mpiJobBuilder := utiltestingmpijob.MakeMPIJob("mpijob1", TestNamespace)

cases := map[string]struct {
managersJobSets []kfmpi.MPIJob
workerJobSets []kfmpi.MPIJob
managersMpiJobs []kfmpi.MPIJob
workerMpiJobs []kfmpi.MPIJob

operation func(ctx context.Context, adapter *multikueueAdapter, managerClient, workerClient client.Client) error

wantError error
wantManagersJobSets []kfmpi.MPIJob
wantWorkerJobSets []kfmpi.MPIJob
wantManagersMpiJobs []kfmpi.MPIJob
wantWorkerMpiJobs []kfmpi.MPIJob
}{

"sync creates missing remote mpijob": {
managersJobSets: []kfmpi.MPIJob{
managersMpiJobs: []kfmpi.MPIJob{
*mpiJobBuilder.DeepCopy(),
},
operation: func(ctx context.Context, adapter *multikueueAdapter, managerClient, workerClient client.Client) error {
return adapter.SyncJob(ctx, managerClient, workerClient, types.NamespacedName{Name: "mpijob1", Namespace: TestNamespace}, "wl1", "origin1")
},

wantManagersJobSets: []kfmpi.MPIJob{
wantManagersMpiJobs: []kfmpi.MPIJob{
*mpiJobBuilder.DeepCopy(),
},
wantWorkerJobSets: []kfmpi.MPIJob{
wantWorkerMpiJobs: []kfmpi.MPIJob{
*mpiJobBuilder.Clone().
Label(constants.PrebuiltWorkloadLabel, "wl1").
Label(kueue.MultiKueueOriginLabel, "origin1").
Obj(),
},
},
"sync status from remote mpijob": {
managersJobSets: []kfmpi.MPIJob{
managersMpiJobs: []kfmpi.MPIJob{
*mpiJobBuilder.DeepCopy(),
},
workerJobSets: []kfmpi.MPIJob{
workerMpiJobs: []kfmpi.MPIJob{
*mpiJobBuilder.Clone().
Label(constants.PrebuiltWorkloadLabel, "wl1").
Label(kueue.MultiKueueOriginLabel, "origin1").
Expand All @@ -92,12 +92,12 @@ func TestMultikueueAdapter(t *testing.T) {
return adapter.SyncJob(ctx, managerClient, workerClient, types.NamespacedName{Name: "mpijob1", Namespace: TestNamespace}, "wl1", "origin1")
},

wantManagersJobSets: []kfmpi.MPIJob{
wantManagersMpiJobs: []kfmpi.MPIJob{
*mpiJobBuilder.Clone().
StatusConditions(kfmpi.JobCondition{Type: kfmpi.JobSucceeded, Status: corev1.ConditionTrue}).
Obj(),
},
wantWorkerJobSets: []kfmpi.MPIJob{
wantWorkerMpiJobs: []kfmpi.MPIJob{
*mpiJobBuilder.Clone().
Label(constants.PrebuiltWorkloadLabel, "wl1").
Label(kueue.MultiKueueOriginLabel, "origin1").
Expand All @@ -106,7 +106,7 @@ func TestMultikueueAdapter(t *testing.T) {
},
},
"remote mpijob is deleted": {
workerJobSets: []kfmpi.MPIJob{
workerMpiJobs: []kfmpi.MPIJob{
*mpiJobBuilder.Clone().
Label(constants.PrebuiltWorkloadLabel, "wl1").
Label(kueue.MultiKueueOriginLabel, "origin1").
Expand All @@ -116,16 +116,61 @@ func TestMultikueueAdapter(t *testing.T) {
return adapter.DeleteRemoteObject(ctx, workerClient, types.NamespacedName{Name: "mpijob1", Namespace: TestNamespace})
},
},
"job with wrong managedBy is not considered managed": {
managersMpiJobs: []kfmpi.MPIJob{
*mpiJobBuilder.Clone().
ManagedBy("some-other-controller").
Obj(),
},
operation: func(ctx context.Context, adapter *multikueueAdapter, managerClient, workerClient client.Client) error {
if isManged, _, _ := adapter.IsJobManagedByKueue(ctx, managerClient, types.NamespacedName{Name: "mpijob1", Namespace: TestNamespace}); isManged {
return errors.New("expecting false")
}
return nil
},
wantManagersMpiJobs: []kfmpi.MPIJob{
*mpiJobBuilder.Clone().
ManagedBy("some-other-controller").
Obj(),
},
},

"job managedBy multikueue": {
managersMpiJobs: []kfmpi.MPIJob{
*mpiJobBuilder.Clone().
ManagedBy(kueue.MultiKueueControllerName).
Obj(),
},
operation: func(ctx context.Context, adapter *multikueueAdapter, managerClient, workerClient client.Client) error {
if isManged, _, _ := adapter.IsJobManagedByKueue(ctx, managerClient, types.NamespacedName{Name: "mpijob1", Namespace: TestNamespace}); !isManged {
return errors.New("expecting true")
}
return nil
},
wantManagersMpiJobs: []kfmpi.MPIJob{
*mpiJobBuilder.Clone().
ManagedBy(kueue.MultiKueueControllerName).
Obj(),
},
},
"missing job is not considered managed": {
operation: func(ctx context.Context, adapter *multikueueAdapter, managerClient, workerClient client.Client) error {
if isManged, _, _ := adapter.IsJobManagedByKueue(ctx, managerClient, types.NamespacedName{Name: "mpijob1", Namespace: TestNamespace}); isManged {
return errors.New("expecting false")
}
return nil
},
},
}
for name, tc := range cases {
t.Run(name, func(t *testing.T) {
managerBuilder := utiltesting.NewClientBuilder(kfmpi.AddToScheme).WithInterceptorFuncs(interceptor.Funcs{SubResourcePatch: utiltesting.TreatSSAAsStrategicMerge})
managerBuilder = managerBuilder.WithLists(&kfmpi.MPIJobList{Items: tc.managersJobSets})
managerBuilder = managerBuilder.WithStatusSubresource(slices.Map(tc.managersJobSets, func(w *kfmpi.MPIJob) client.Object { return w })...)
managerBuilder = managerBuilder.WithLists(&kfmpi.MPIJobList{Items: tc.managersMpiJobs})
managerBuilder = managerBuilder.WithStatusSubresource(slices.Map(tc.managersMpiJobs, func(w *kfmpi.MPIJob) client.Object { return w })...)
managerClient := managerBuilder.Build()

workerBuilder := utiltesting.NewClientBuilder(kfmpi.AddToScheme).WithInterceptorFuncs(interceptor.Funcs{SubResourcePatch: utiltesting.TreatSSAAsStrategicMerge})
workerBuilder = workerBuilder.WithLists(&kfmpi.MPIJobList{Items: tc.workerJobSets})
workerBuilder = workerBuilder.WithLists(&kfmpi.MPIJobList{Items: tc.workerMpiJobs})
workerClient := workerBuilder.Build()

ctx, _ := utiltesting.ContextWithLog(t)
Expand All @@ -138,20 +183,20 @@ func TestMultikueueAdapter(t *testing.T) {
t.Errorf("unexpected error (-want/+got):\n%s", diff)
}

gotManagersJobSets := &kfmpi.MPIJobList{}
if err := managerClient.List(ctx, gotManagersJobSets); err != nil {
gotManagersMpiJobs := &kfmpi.MPIJobList{}
if err := managerClient.List(ctx, gotManagersMpiJobs); err != nil {
t.Errorf("unexpected list manager's mpijobs error %s", err)
} else {
if diff := cmp.Diff(tc.wantManagersJobSets, gotManagersJobSets.Items, objCheckOpts...); diff != "" {
if diff := cmp.Diff(tc.wantManagersMpiJobs, gotManagersMpiJobs.Items, objCheckOpts...); diff != "" {
t.Errorf("unexpected manager's mpijobs (-want/+got):\n%s", diff)
}
}

gotWorkerJobSets := &kfmpi.MPIJobList{}
if err := workerClient.List(ctx, gotWorkerJobSets); err != nil {
gotWorkerMpiJobs := &kfmpi.MPIJobList{}
if err := workerClient.List(ctx, gotWorkerMpiJobs); err != nil {
t.Errorf("unexpected list worker's mpijobs error %s", err)
} else {
if diff := cmp.Diff(tc.wantWorkerJobSets, gotWorkerJobSets.Items, objCheckOpts...); diff != "" {
if diff := cmp.Diff(tc.wantWorkerMpiJobs, gotWorkerMpiJobs.Items, objCheckOpts...); diff != "" {
t.Errorf("unexpected worker's mpijobs (-want/+got):\n%s", diff)
}
}
Expand Down
128 changes: 128 additions & 0 deletions pkg/controller/jobs/mpijob/mpijob_webhook.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
Copyright 2024 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package mpijob

import (
"context"

"github.com/kubeflow/mpi-operator/pkg/apis/kubeflow/v2beta1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/klog/v2"
"k8s.io/utils/ptr"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"

kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1"
"sigs.k8s.io/kueue/pkg/cache"
"sigs.k8s.io/kueue/pkg/controller/constants"
"sigs.k8s.io/kueue/pkg/controller/jobframework"
"sigs.k8s.io/kueue/pkg/controller/jobframework/webhook"
"sigs.k8s.io/kueue/pkg/features"
"sigs.k8s.io/kueue/pkg/queue"
"sigs.k8s.io/kueue/pkg/util/kubeversion"
)

type MpiJobWebhook struct {
manageJobsWithoutQueueName bool
kubeServerVersion *kubeversion.ServerVersionFetcher
queues *queue.Manager
cache *cache.Cache
}

// SetupMPIJobWebhook builds an MpiJobWebhook from the supplied job framework
// options and registers it with mgr as both the mutating (lossless defaulter)
// and validating webhook for v2beta1 MPIJob objects.
// It returns whatever error the webhook builder's Complete reports.
func SetupMPIJobWebhook(mgr ctrl.Manager, opts ...jobframework.Option) error {
	cfg := jobframework.ProcessOptions(opts...)

	hook := &MpiJobWebhook{
		manageJobsWithoutQueueName: cfg.ManageJobsWithoutQueueName,
		kubeServerVersion:          cfg.KubeServerVersion,
		queues:                     cfg.Queues,
		cache:                      cfg.Cache,
	}

	target := &v2beta1.MPIJob{}
	builder := webhook.WebhookManagedBy(mgr).For(target)
	defaulter := webhook.WithLosslessDefaulter(mgr.GetScheme(), target, hook)

	return builder.
		WithMutationHandler(defaulter).
		WithValidator(hook).
		Complete()
}

// +kubebuilder:webhook:path=/mutate-kubeflow-org-v2beta1-mpijob,mutating=true,failurePolicy=fail,sideEffects=None,groups=kubeflow.org,resources=mpijobs,verbs=create,versions=v2beta1,name=mmpijob.kb.io,admissionReviewVersions=v1

// Compile-time check that *MpiJobWebhook satisfies admission.CustomDefaulter.
var _ admission.CustomDefaulter = (*MpiJobWebhook)(nil)

// Default implements admission.CustomDefaulter for MPIJob.
//
// It first delegates suspend defaulting to jobframework.ApplyDefaultForSuspend.
// Then, when canDefaultManagedBy allows it, the job carries the queue label,
// and the local queue resolves to a cluster queue that has an admission check
// owned by the MultiKueue controller, it sets spec.runPolicy.managedBy to the
// MultiKueue controller name. It always returns nil.
func (w *MpiJobWebhook) Default(ctx context.Context, obj runtime.Object) error {
	job := fromObject(obj)
	logger := ctrl.LoggerFrom(ctx).WithName("mpijob-webhook")
	logger.V(5).Info("Applying defaults", "mpijob", klog.KObj(job))

	jobframework.ApplyDefaultForSuspend(job, w.manageJobsWithoutQueueName)

	if !canDefaultManagedBy(job.Spec.RunPolicy.ManagedBy) {
		return nil
	}

	// Without a queue label there is no cluster queue to inspect.
	lqName, hasQueue := job.Labels[constants.QueueLabel]
	if !hasQueue {
		return nil
	}

	cqName, ok := w.queues.ClusterQueueFromLocalQueue(queue.QueueKey(job.ObjectMeta.Namespace, lqName))
	if !ok {
		logger.V(5).Info("Cluster queue for local queue not found", "mpijob", klog.KObj(job), "localQueue", lqName)
		return nil
	}

	for _, check := range w.cache.AdmissionChecksForClusterQueue(cqName) {
		if check.Controller != kueue.MultiKueueControllerName {
			continue
		}
		// The first MultiKueue admission check is enough; stop after defaulting.
		logger.V(5).Info("Defaulting ManagedBy", "mpijob", klog.KObj(job), "oldManagedBy", job.Spec.RunPolicy.ManagedBy, "managedBy", kueue.MultiKueueControllerName)
		job.Spec.RunPolicy.ManagedBy = ptr.To(kueue.MultiKueueControllerName)
		break
	}

	return nil
}

// canDefaultManagedBy reports whether the webhook may overwrite the job's
// managedBy field: the MultiKueue feature gate must be enabled, and the
// current value must be either unset or the Kubeflow job controller name.
func canDefaultManagedBy(mpiJobSpecManagedBy *string) bool {
	if !features.Enabled(features.MultiKueue) {
		return false
	}
	if mpiJobSpecManagedBy == nil {
		return true
	}
	return *mpiJobSpecManagedBy == v2beta1.KubeflowJobController
}

// +kubebuilder:webhook:path=/validate-kubeflow-org-v2beta1-mpijob,mutating=false,failurePolicy=fail,sideEffects=None,groups=kubeflow.org,resources=mpijobs,verbs=create;update,versions=v2beta1,name=vmpijob.kb.io,admissionReviewVersions=v1

// Compile-time check that *MpiJobWebhook satisfies admission.CustomValidator.
var _ admission.CustomValidator = (*MpiJobWebhook)(nil)

// ValidateCreate implements admission.CustomValidator. It runs
// jobframework.ValidateJobOnCreate on the incoming MPIJob and returns the
// aggregated result as the error. No warnings are produced.
func (w *MpiJobWebhook) ValidateCreate(ctx context.Context, obj runtime.Object) (admission.Warnings, error) {
	job := fromObject(obj)

	logger := ctrl.LoggerFrom(ctx).WithName("mpijob-webhook")
	logger.Info("Validating create", "mpijob", klog.KObj(job))

	errs := jobframework.ValidateJobOnCreate(job)
	return nil, errs.ToAggregate()
}

// ValidateUpdate implements admission.CustomValidator. It collects the errors
// from jobframework.ValidateJobOnUpdate (old vs. new object) followed by those
// from jobframework.ValidateJobOnCreate on the new object, and returns them
// aggregated. No warnings are produced.
func (w *MpiJobWebhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Object) (admission.Warnings, error) {
	previous, current := fromObject(oldObj), fromObject(newObj)

	logger := ctrl.LoggerFrom(ctx).WithName("mpijob-webhook")
	logger.Info("Validating update", "mpijob", klog.KObj(current))

	// Update-specific checks come first so their errors lead the list.
	errs := append(
		jobframework.ValidateJobOnUpdate(previous, current),
		jobframework.ValidateJobOnCreate(current)...,
	)
	return nil, errs.ToAggregate()
}

// ValidateDelete implements admission.CustomValidator. Deletions are never
// rejected: it ignores its arguments and returns no warnings and no error.
func (w *MpiJobWebhook) ValidateDelete(_ context.Context, _ runtime.Object) (admission.Warnings, error) {
	return nil, nil
}
Loading

0 comments on commit 70df735

Please sign in to comment.