KEP-2170: Implement TrainJob Reconciler to manage objects
Signed-off-by: Yuki Iwai <[email protected]>
tenzen-y committed Oct 20, 2024
1 parent 0149eb0 commit b31fc96
Showing 13 changed files with 379 additions and 71 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -4,6 +4,7 @@ bin/
/tf-operator
vendor/
testbin/*
+dep-crds/
cover.out

# IDEs
17 changes: 16 additions & 1 deletion Makefile
@@ -77,7 +77,7 @@ test: envtest
KUBEBUILDER_ASSETS="$(shell setup-envtest use $(ENVTEST_K8S_VERSION) -p path)" go test ./... -coverprofile cover.out

.PHONY: test-integrationv2
-test-integrationv2: envtest
+test-integrationv2: envtest jobset-operator-crd scheduler-plugins-crd
KUBEBUILDER_ASSETS="$(shell setup-envtest use $(ENVTEST_K8S_VERSION) -p path)" go test ./test/... -coverprofile cover.out

.PHONY: testv2
@@ -127,3 +127,18 @@ controller-gen: ## Download controller-gen locally if necessary.
KUSTOMIZE = $(shell pwd)/bin/kustomize
kustomize: ## Download kustomize locally if necessary.
GOBIN=$(PROJECT_DIR)/bin go install sigs.k8s.io/kustomize/kustomize/[email protected]

+## Download external CRDs for the integration tests.
+EXTERNAL_CRDS_DIR ?= $(PROJECT_DIR)/dep-crds
+
+JOBSET_ROOT = $(shell go list -m -mod=readonly -f "{{.Dir}}" sigs.k8s.io/jobset)
+.PHONY: jobset-operator-crd
+jobset-operator-crd: ## Copy the CRDs from the jobset-operator to the dep-crds directory.
+mkdir -p $(EXTERNAL_CRDS_DIR)/jobset-operator/
+cp -f $(JOBSET_ROOT)/config/components/crd/bases/* $(EXTERNAL_CRDS_DIR)/jobset-operator/
+
+SCHEDULER_PLUGINS_ROOT = $(shell go list -m -f "{{.Dir}}" sigs.k8s.io/scheduler-plugins)
+.PHONY: scheduler-plugins-crd
+scheduler-plugins-crd: ## Copy the CRDs from the scheduler-plugins to the dep-crds directory.
+mkdir -p $(EXTERNAL_CRDS_DIR)/scheduler-plugins/
+cp -f $(SCHEDULER_PLUGINS_ROOT)/manifests/coscheduling/* $(EXTERNAL_CRDS_DIR)/scheduler-plugins/
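The two targets above stage third-party CRDs under dep-crds/ so envtest can install them before the integration suite runs. A minimal sketch of how a test harness might consume them, assuming the standard controller-runtime envtest API; the relative paths and helper name are illustrative, not taken from this commit:

package integration

import (
	"path/filepath"
	"testing"

	"sigs.k8s.io/controller-runtime/pkg/envtest"
)

// setupTestEnv starts an envtest control plane with the project CRDs plus
// the staged JobSet and scheduler-plugins CRDs. The paths are assumptions.
func setupTestEnv(t *testing.T) *envtest.Environment {
	testEnv := &envtest.Environment{
		CRDDirectoryPaths: []string{
			filepath.Join("..", "..", "manifests", "base", "crds"),
			filepath.Join("..", "..", "dep-crds", "jobset-operator"),
			filepath.Join("..", "..", "dep-crds", "scheduler-plugins"),
		},
		ErrorIfCRDPathMissing: true,
	}
	if _, err := testEnv.Start(); err != nil {
		t.Fatalf("failed to start envtest: %v", err)
	}
	return testEnv
}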
3 changes: 2 additions & 1 deletion pkg/controller.v2/setup.go
@@ -26,7 +26,8 @@ func SetupControllers(mgr ctrl.Manager, runtimes map[string]runtime.Runtime) (st
if err := NewTrainJobReconciler(
mgr.GetClient(),
mgr.GetEventRecorderFor("training-operator-trainjob-controller"),
-).SetupWithManager(mgr, runtimes); err != nil {
+runtimes,
+).SetupWithManager(mgr); err != nil {
return "TrainJob", err
}
return "", nil
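For context, a hedged sketch of the call site this signature change implies: the runtimes registry now reaches the reconciler through its constructor rather than through SetupWithManager. Only the SetupControllers call reflects this commit; the registry helper (runtimecore.New) and the rest of the bootstrap are assumptions for illustration.

package main

import (
	"context"
	"log"

	ctrl "sigs.k8s.io/controller-runtime"

	controllerv2 "github.com/kubeflow/training-operator/pkg/controller.v2"
	runtimecore "github.com/kubeflow/training-operator/pkg/runtime.v2/core"
)

func main() {
	mgr, err := ctrl.NewManager(ctrl.GetConfigOrDie(), ctrl.Options{})
	if err != nil {
		log.Fatal(err)
	}
	// Hypothetical registry constructor; the real signature may differ.
	runtimes, err := runtimecore.New(context.Background(), mgr.GetClient(), mgr.GetFieldIndexer())
	if err != nil {
		log.Fatal(err)
	}
	if name, err := controllerv2.SetupControllers(mgr, runtimes); err != nil {
		log.Fatalf("failed to set up %s controller: %v", name, err)
	}
	if err := mgr.Start(ctrl.SetupSignalHandler()); err != nil {
		log.Fatal(err)
	}
}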
68 changes: 64 additions & 4 deletions pkg/controller.v2/trainjob_controller.go
@@ -20,10 +20,13 @@ import (
"context"

"github.com/go-logr/logr"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/client-go/tools/record"
"k8s.io/klog/v2"
"k8s.io/utils/ptr"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/client/apiutil"

kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1"
runtime "github.com/kubeflow/training-operator/pkg/runtime.v2"
@@ -33,13 +36,15 @@ type TrainJobReconciler struct {
log logr.Logger
client client.Client
recorder record.EventRecorder
+runtimes map[string]runtime.Runtime
}

-func NewTrainJobReconciler(client client.Client, recorder record.EventRecorder) *TrainJobReconciler {
+func NewTrainJobReconciler(client client.Client, recorder record.EventRecorder, runs map[string]runtime.Runtime) *TrainJobReconciler {
return &TrainJobReconciler{
log: ctrl.Log.WithName("trainjob-controller"),
client: client,
recorder: recorder,
+runtimes: runs,
}
}

@@ -49,15 +54,70 @@ func (r *TrainJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c
return ctrl.Result{}, client.IgnoreNotFound(err)
}
log := ctrl.LoggerFrom(ctx).WithValues("trainJob", klog.KObj(&trainJob))
-ctrl.LoggerInto(ctx, log)
+ctx = ctrl.LoggerInto(ctx, log)
log.V(2).Info("Reconciling TrainJob")
+if err := r.createOrUpdateObjs(ctx, &trainJob); err != nil {
+return ctrl.Result{}, err
+}
// TODO (tenzen-y): Do update the status.
return ctrl.Result{}, nil
}

-func (r *TrainJobReconciler) SetupWithManager(mgr ctrl.Manager, runtimes map[string]runtime.Runtime) error {
+func (r *TrainJobReconciler) createOrUpdateObjs(ctx context.Context, trainJob *kubeflowv2.TrainJob) error {
+log := ctrl.LoggerFrom(ctx)
+
+// The controller assumes the runtime's existence has already been verified by the webhook on TrainJob creation.
+run := r.runtimes[runtimeRefToGroupKind(trainJob.Spec.RuntimeRef).String()]
+objs, err := run.NewObjects(ctx, trainJob)
+if err != nil {
+return err
+}
+for _, obj := range objs {
+var gvk schema.GroupVersionKind
+if gvk, err = apiutil.GVKForObject(obj.DeepCopyObject(), r.client.Scheme()); err != nil {
+return err
+}
+logKeysAndValues := []any{
+"groupVersionKind", gvk.String(),
+"namespace", obj.GetNamespace(),
+"name", obj.GetName(),
+}
+// TODO (tenzen-y): Ideally, we should use the SSA instead of checking existence.
+// A non-empty resourceVersion indicates an UPDATE operation.
+var creationErr error
+var created bool
+if obj.GetResourceVersion() == "" {
+creationErr = r.client.Create(ctx, obj)
+created = creationErr == nil
+}
+switch {
+case created:
+log.V(5).Info("Succeeded to create object", logKeysAndValues...)
+continue
+case client.IgnoreAlreadyExists(creationErr) != nil:
+return creationErr
+default:
+// This indicates that the CREATE operation was not performed or that the object already exists in the cluster.
+if err = r.client.Update(ctx, obj); err != nil {
+return err
+}
+log.V(5).Info("Succeeded to update object", logKeysAndValues...)
+}
+}
+return nil
+}
+
+func runtimeRefToGroupKind(runtimeRef kubeflowv2.RuntimeRef) schema.GroupKind {
+return schema.GroupKind{
+Group: ptr.Deref(runtimeRef.APIGroup, ""),
+Kind: ptr.Deref(runtimeRef.Kind, ""),
+}
+}
+
+func (r *TrainJobReconciler) SetupWithManager(mgr ctrl.Manager) error {
b := ctrl.NewControllerManagedBy(mgr).
For(&kubeflowv2.TrainJob{})
-for _, run := range runtimes {
+for _, run := range r.runtimes {
for _, registrar := range run.EventHandlerRegistrars() {
if registrar != nil {
b = registrar(b, mgr.GetClient())
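The TODO in createOrUpdateObjs points toward server-side apply as the eventual replacement for the existence check plus create/update branching. A hedged sketch of what that loop could look like with controller-runtime's apply patch; the field-owner string is illustrative, not from this commit:

// Sketch only: a server-side-apply variant of the object loop.
for _, obj := range objs {
	// SSA rejects objects that carry managedFields, so clear them first.
	obj.SetManagedFields(nil)
	if err := r.client.Patch(ctx, obj, client.Apply,
		client.FieldOwner("trainjob-controller"), client.ForceOwnership); err != nil {
		return err
	}
}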
4 changes: 3 additions & 1 deletion pkg/runtime.v2/core/clustertrainingruntime_test.go
@@ -46,6 +46,7 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) {
}{
"succeeded to build JobSet and PodGroup": {
trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job").
+Suspend(true).
UID("uid").
RuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.ClusterTrainingRuntimeKind), "test-runtime").
Trainer(
@@ -57,7 +58,7 @@
clusterTrainingRuntime: baseRuntime.RuntimeSpec(
testingutil.MakeTrainingRuntimeSpecWrapper(baseRuntime.Spec).
ContainerImage("test:runtime").
-PodGroupPolicySchedulingTimeout(120).
+PodGroupPolicyCoschedulingSchedulingTimeout(120).
MLPolicyNumNodes(20).
ResourceRequests(0, corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("1"),
@@ -69,6 +70,7 @@
).Obj(),
wantObjs: []client.Object{
testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job").
+Suspend(true).
PodLabel(schedulerpluginsv1alpha1.PodGroupLabel, "test-job").
ContainerImage(ptr.To("test:trainjob")).
JobCompletionMode(batchv1.IndexedCompletion).
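Suspend(true) here exercises the new suspend propagation end to end: the TrainJob wrapper sets spec.suspend, and the expected JobSet wrapper asserts it came through. A hedged sketch of what the testingutil helper presumably looks like; the wrapper internals and type name are assumptions, as only the call sites appear in this diff:

// Assumed shape of the testingutil helper used above.
func (w *TrainJobWrapper) Suspend(suspend bool) *TrainJobWrapper {
	w.Spec.Suspend = ptr.To(suspend)
	return w
}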
4 changes: 3 additions & 1 deletion pkg/runtime.v2/core/trainingruntime_test.go
@@ -46,6 +46,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
}{
"succeeded to build JobSet and PodGroup": {
trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job").
+Suspend(true).
UID("uid").
RuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), "test-runtime").
SpecLabel("conflictLabel", "override").
@@ -62,7 +63,7 @@
RuntimeSpec(
testingutil.MakeTrainingRuntimeSpecWrapper(baseRuntime.Spec).
ContainerImage("test:runtime").
-PodGroupPolicySchedulingTimeout(120).
+PodGroupPolicyCoschedulingSchedulingTimeout(120).
MLPolicyNumNodes(20).
ResourceRequests(0, corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("1"),
@@ -74,6 +75,7 @@
).Obj(),
wantObjs: []client.Object{
testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job").
+Suspend(true).
Label("conflictLabel", "override").
Annotation("conflictAnnotation", "override").
PodLabel(schedulerpluginsv1alpha1.PodGroupLabel, "test-job").
10 changes: 6 additions & 4 deletions pkg/runtime.v2/framework/core/framework_test.go
@@ -334,13 +334,12 @@ func TestRunComponentBuilderPlugins(t *testing.T) {
ResourceRequests(1, corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("1"),
corev1.ResourceMemory: resource.MustParse("2Gi"),
-}).
-Clone()
+})
jobSetWithPropagatedTrainJobParams := jobSetBase.
+Clone().
JobCompletionMode(batchv1.IndexedCompletion).
ContainerImage(ptr.To("foo:bar")).
-ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind("TrainJob"), "test-job", "uid").
-Clone()
+ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind("TrainJob"), "test-job", "uid")

cases := map[string]struct {
runtimeInfo *runtime.Info
@@ -361,6 +360,7 @@
Obj(),
runtimeInfo: &runtime.Info{
Obj: jobSetBase.
+Clone().
Obj(),
Policy: runtime.Policy{
MLPolicy: &kubeflowv2.MLPolicy{
@@ -403,10 +403,12 @@ func TestRunComponentBuilderPlugins(t *testing.T) {
ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind("TrainJob"), "test-job", "uid").
Obj(),
jobSetWithPropagatedTrainJobParams.
+Clone().
Obj(),
},
wantRuntimeInfo: &runtime.Info{
Obj: jobSetWithPropagatedTrainJobParams.
+Clone().
Obj(),
Policy: runtime.Policy{
MLPolicy: &kubeflowv2.MLPolicy{
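The reshuffled Clone() calls matter because these wrappers mutate in place: cloning at each use site, rather than baking a Clone() into the shared base definition, keeps every test case working on its own deep copy. In miniature, using the same wrapper API that appears above:

// base stays pristine; each case derives its own copy before mutating.
base := testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job")
withImage := base.Clone().ContainerImage(ptr.To("foo:bar")).Obj()
plain := base.Clone().Obj() // unaffected by the mutation above
_, _ = withImage, plain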
11 changes: 8 additions & 3 deletions pkg/runtime.v2/framework/plugins/jobset/builder.go
@@ -28,12 +28,12 @@ import (
)

type Builder struct {
-*jobsetv1alpha2.JobSet
+jobsetv1alpha2.JobSet
}

func NewBuilder(objectKey client.ObjectKey, jobSetTemplateSpec kubeflowv2.JobSetTemplateSpec) *Builder {
return &Builder{
-JobSet: &jobsetv1alpha2.JobSet{
+JobSet: jobsetv1alpha2.JobSet{
TypeMeta: metav1.TypeMeta{
APIVersion: jobsetv1alpha2.SchemeGroupVersion.String(),
Kind: "JobSet",
@@ -76,8 +76,13 @@ func (b *Builder) PodLabels(labels map[string]string) *Builder {
return b
}

+func (b *Builder) Suspend(suspend *bool) *Builder {
+b.Spec.Suspend = suspend
+return b
+}

// TODO: Need to support all TrainJob fields.

func (b *Builder) Build() *jobsetv1alpha2.JobSet {
-return b.JobSet
+return &b.JobSet
}
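Switching the embedded JobSet from a pointer to a value means copying a Builder copies the whole object, and Build simply returns the embedded value's address; that is what lets the jobset plugin below seed a Builder from *oldJobSet.DeepCopy() without aliasing the cached object. A hedged sketch of a Clone built on this layout (no such method appears in this commit):

// Sketch: value embedding makes a deep-copying Clone a one-liner.
func (b *Builder) Clone() *Builder {
	return &Builder{JobSet: *b.JobSet.DeepCopy()}
}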
45 changes: 29 additions & 16 deletions pkg/runtime.v2/framework/plugins/jobset/jobset.go
@@ -74,29 +74,37 @@ func (j *JobSet) Build(ctx context.Context, info *runtime.Info, trainJob *kubefl
if !ok {
return nil, nil
}
-jobSetBuilder := NewBuilder(client.ObjectKeyFromObject(trainJob), kubeflowv2.JobSetTemplateSpec{
-ObjectMeta: metav1.ObjectMeta{
-Labels: info.Labels,
-Annotations: info.Annotations,
-},
-Spec: raw.Spec,
-})
+var jobSetBuilder *Builder
+oldJobSet := &jobsetv1alpha2.JobSet{}
+if err := j.client.Get(ctx, client.ObjectKeyFromObject(trainJob), oldJobSet); err != nil {
+if !apierrors.IsNotFound(err) {
+return nil, err
+}
+jobSetBuilder = NewBuilder(client.ObjectKeyFromObject(trainJob), kubeflowv2.JobSetTemplateSpec{
+ObjectMeta: metav1.ObjectMeta{
+Labels: info.Labels,
+Annotations: info.Annotations,
+},
+Spec: raw.Spec,
+})
+oldJobSet = nil
+} else {
+jobSetBuilder = &Builder{
+JobSet: *oldJobSet.DeepCopy(),
+}
+}

// TODO (tenzen-y): We should support all field propagation in builder.
jobSet := jobSetBuilder.
+Suspend(trainJob.Spec.Suspend).
ContainerImage(trainJob.Spec.Trainer.Image).
JobCompletionMode(batchv1.IndexedCompletion).
PodLabels(info.PodLabels).
Build()
if err := ctrlutil.SetControllerReference(trainJob, jobSet, j.scheme); err != nil {
return nil, err
}
-oldJobSet := &jobsetv1alpha2.JobSet{}
-if err := j.client.Get(ctx, client.ObjectKeyFromObject(jobSet), oldJobSet); err != nil {
-if !apierrors.IsNotFound(err) {
-return nil, err
-}
-oldJobSet = nil
-}
if err := info.Update(jobSet); err != nil {
return nil, err
}
@@ -106,9 +114,14 @@ func (j *JobSet) Build(ctx context.Context, info *runtime.Info, trainJob *kubefl
return nil, nil
}

-func needsCreateOrUpdate(old, new *jobsetv1alpha2.JobSet, suspended bool) bool {
+func needsCreateOrUpdate(old, new *jobsetv1alpha2.JobSet, trainJobIsSuspended bool) bool {
return old == nil ||
-suspended && (!equality.Semantic.DeepEqual(old.Spec, new.Spec) || !maps.Equal(old.Labels, new.Labels) || !maps.Equal(old.Annotations, new.Annotations))
+(!trainJobIsSuspended && jobSetIsSuspended(old) && !jobSetIsSuspended(new)) ||
+(trainJobIsSuspended && (!equality.Semantic.DeepEqual(old.Spec, new.Spec) || !maps.Equal(old.Labels, new.Labels) || !maps.Equal(old.Annotations, new.Annotations)))
}

+func jobSetIsSuspended(jobSet *jobsetv1alpha2.JobSet) bool {
+return ptr.Deref(jobSet.Spec.Suspend, false)
+}

func (j *JobSet) ReconcilerBuilders() []runtime.ReconcilerBuilder {
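In words, the new predicate creates the JobSet when none exists, unsuspends it when the TrainJob resumes, and otherwise only permits mutations while the TrainJob is suspended. A small illustration of those branches against needsCreateOrUpdate as defined above (the inline JobSets are illustrative):

running := &jobsetv1alpha2.JobSet{Spec: jobsetv1alpha2.JobSetSpec{Suspend: ptr.To(false)}}
parked := &jobsetv1alpha2.JobSet{Spec: jobsetv1alpha2.JobSetSpec{Suspend: ptr.To(true)}}

_ = needsCreateOrUpdate(nil, running, false)    // true: no JobSet exists yet
_ = needsCreateOrUpdate(parked, running, false) // true: TrainJob resumed, JobSet must be unsuspended
_ = needsCreateOrUpdate(parked, parked, true)   // false: suspended, but nothing changed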