KEP-2170: Implement TrainJob Reconciler to manage objects #2295

Merged
1 change: 1 addition & 0 deletions .gitignore
@@ -4,6 +4,7 @@ bin/
/tf-operator
vendor/
testbin/*
manifests/external-crds/
cover.out

# IDEs
22 changes: 19 additions & 3 deletions Makefile
@@ -74,15 +74,16 @@ HAS_SETUP_ENVTEST := $(shell command -v setup-envtest;)
testall: manifests generate fmt vet golangci-lint test ## Run tests.

test: envtest
KUBEBUILDER_ASSETS="$(shell setup-envtest use $(ENVTEST_K8S_VERSION) -p path)" go test ./... -coverprofile cover.out
KUBEBUILDER_ASSETS="$(shell setup-envtest use $(ENVTEST_K8S_VERSION) -p path)" \
go test ./pkg/apis/kubeflow.org/v1/... ./pkg/cert/... ./pkg/common/... ./pkg/config/... ./pkg/controller.v1/... ./pkg/core/... ./pkg/util/... ./pkg/webhooks/... -coverprofile cover.out

.PHONY: test-integrationv2
test-integrationv2: envtest
test-integrationv2: envtest jobset-operator-crd scheduler-plugins-crd
KUBEBUILDER_ASSETS="$(shell setup-envtest use $(ENVTEST_K8S_VERSION) -p path)" go test ./test/... -coverprofile cover.out

.PHONY: testv2
testv2:
go test ./pkg/controller.v2/... ./pkg/runtime.v2/... ./pkg/webhook.v2/... ./pkg/util.v2/... -coverprofile cover.out
go test ./pkg/apis/kubeflow.org/v2alpha1/... ./pkg/controller.v2/... ./pkg/runtime.v2/... ./pkg/webhook.v2/... ./pkg/util.v2/... -coverprofile cover.out

envtest:
ifndef HAS_SETUP_ENVTEST
@@ -127,3 +128,18 @@ controller-gen: ## Download controller-gen locally if necessary.
KUSTOMIZE = $(shell pwd)/bin/kustomize
kustomize: ## Download kustomize locally if necessary.
GOBIN=$(PROJECT_DIR)/bin go install sigs.k8s.io/kustomize/kustomize/[email protected]

## Download external CRDs for the integration tests.
EXTERNAL_CRDS_DIR ?= $(PROJECT_DIR)/manifests/external-crds

JOBSET_ROOT = $(shell go list -m -mod=readonly -f "{{.Dir}}" sigs.k8s.io/jobset)
.PHONY: jobset-operator-crd
jobset-operator-crd: ## Copy the CRDs from the jobset-operator to the manifests/external-crds directory.
mkdir -p $(EXTERNAL_CRDS_DIR)/jobset-operator/
cp -f $(JOBSET_ROOT)/config/components/crd/bases/* $(EXTERNAL_CRDS_DIR)/jobset-operator/

SCHEDULER_PLUGINS_ROOT = $(shell go list -m -f "{{.Dir}}" sigs.k8s.io/scheduler-plugins)
.PHONY: scheduler-plugins-crd
scheduler-plugins-crd:
mkdir -p $(EXTERNAL_CRDS_DIR)/scheduler-plugins/
cp -f $(SCHEDULER_PLUGINS_ROOT)/manifests/coscheduling/* $(EXTERNAL_CRDS_DIR)/scheduler-plugins
3 changes: 2 additions & 1 deletion pkg/controller.v2/setup.go
@@ -26,7 +26,8 @@ func SetupControllers(mgr ctrl.Manager, runtimes map[string]runtime.Runtime) (st
if err := NewTrainJobReconciler(
mgr.GetClient(),
mgr.GetEventRecorderFor("training-operator-trainjob-controller"),
).SetupWithManager(mgr, runtimes); err != nil {
runtimes,
).SetupWithManager(mgr); err != nil {
return "TrainJob", err
}
return "", nil
79 changes: 73 additions & 6 deletions pkg/controller.v2/trainjob_controller.go
@@ -18,28 +18,37 @@ package controllerv2

import (
"context"
"errors"
"fmt"

"github.com/go-logr/logr"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/client-go/tools/record"
"k8s.io/klog/v2"
"k8s.io/utils/ptr"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/client/apiutil"

kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1"
runtime "github.com/kubeflow/training-operator/pkg/runtime.v2"
jobruntimes "github.com/kubeflow/training-operator/pkg/runtime.v2"
)

var errorUnsupportedRuntime = errors.New("the specified runtime is not supported")

type TrainJobReconciler struct {
log logr.Logger
client client.Client
recorder record.EventRecorder
runtimes map[string]jobruntimes.Runtime
}

func NewTrainJobReconciler(client client.Client, recorder record.EventRecorder) *TrainJobReconciler {
func NewTrainJobReconciler(client client.Client, recorder record.EventRecorder, runtimes map[string]jobruntimes.Runtime) *TrainJobReconciler {
return &TrainJobReconciler{
log: ctrl.Log.WithName("trainjob-controller"),
client: client,
recorder: recorder,
runtimes: runtimes,
}
}

@@ -49,16 +58,74 @@ func (r *TrainJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c
return ctrl.Result{}, client.IgnoreNotFound(err)
}
log := ctrl.LoggerFrom(ctx).WithValues("trainJob", klog.KObj(&trainJob))
ctrl.LoggerInto(ctx, log)
ctx = ctrl.LoggerInto(ctx, log)
log.V(2).Info("Reconciling TrainJob")
if err := r.createOrUpdateObjs(ctx, &trainJob); err != nil {
return ctrl.Result{}, err
}
// TODO (tenzen-y): Do update the status.
return ctrl.Result{}, nil
}

func (r *TrainJobReconciler) SetupWithManager(mgr ctrl.Manager, runtimes map[string]runtime.Runtime) error {
func (r *TrainJobReconciler) createOrUpdateObjs(ctx context.Context, trainJob *kubeflowv2.TrainJob) error {
log := ctrl.LoggerFrom(ctx)

runtimeRefGK := runtimeRefToGroupKind(trainJob.Spec.RuntimeRef).String()
runtime, ok := r.runtimes[runtimeRefGK]
if !ok {
return fmt.Errorf("%w: %s", errorUnsupportedRuntime, runtimeRefGK)
}
objs, err := runtime.NewObjects(ctx, trainJob)
if err != nil {
return err
}
for _, obj := range objs {
var gvk schema.GroupVersionKind
if gvk, err = apiutil.GVKForObject(obj.DeepCopyObject(), r.client.Scheme()); err != nil {
return err
}
logKeysAndValues := []any{
"groupVersionKind", gvk.String(),
"namespace", obj.GetNamespace(),
"name", obj.GetName(),
Comment on lines +89 to +90
Member:
Do we set the name and namespace for the built objects?
E.g. here I can see that we only set the TypeMeta and Spec:

info := runtime.NewInfo(&jobsetv1alpha2.JobSet{
TypeMeta: metav1.TypeMeta{
APIVersion: jobsetv1alpha2.SchemeGroupVersion.String(),
Kind: "JobSet",
},
Spec: *jobSetTemplateSpec.Spec.DeepCopy(),
}, opts...)

Member Author:
Yes, we do. The actual resource names are set in each plugin, as shown in the following:

JobSet:

Namespace: objectKey.Namespace,
Name: objectKey.Name,

PodGroup:
Name: trainJob.Name,
Namespace: trainJob.Namespace,

Member:
Why does the JobSet plugin have a separate Builder, while the co-scheduling plugin creates the PodGroup object directly in the Build() API?

Member Author:
Why does the JobSet plugin have a separate Builder, while the co-scheduling plugin creates the PodGroup object directly in the Build() API?

The reason is simple: the JobSet has many fields and deep nesting, while the PodGroup has only a few fields and shallow nesting. We could certainly provide a builder for the PodGroup as well, but a dedicated PodGroup builder seems like slight overkill.

Member:
Sounds good!
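
For illustration, here is a minimal sketch of that direct-construction style, assuming the scheduler-plugins v1alpha1 API; the helper name and field values are hypothetical and not part of this PR:

import (
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/utils/ptr"
	schedulingv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1"
)

// buildPodGroup assembles the small PodGroup object inline, without a dedicated builder type.
func buildPodGroup(name, namespace string, minMember, timeoutSeconds int32) *schedulingv1alpha1.PodGroup {
	return &schedulingv1alpha1.PodGroup{
		TypeMeta: metav1.TypeMeta{
			APIVersion: schedulingv1alpha1.SchemeGroupVersion.String(),
			Kind:       "PodGroup",
		},
		ObjectMeta: metav1.ObjectMeta{
			Name:      name,
			Namespace: namespace,
		},
		Spec: schedulingv1alpha1.PodGroupSpec{
			MinMember:              minMember,
			ScheduleTimeoutSeconds: ptr.To(timeoutSeconds),
		},
	}
}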

}
// TODO (tenzen-y): Ideally, we should use the SSA instead of checking existence.
// Non-empty resourceVersion indicates UPDATE operation.
var creationErr error
var created bool
if obj.GetResourceVersion() == "" {
creationErr = r.client.Create(ctx, obj)
created = creationErr == nil
}
switch {
case created:
log.V(5).Info("Succeeded to create object", logKeysAndValues)
continue
case client.IgnoreAlreadyExists(creationErr) != nil:
return creationErr
default:
// This indicates CREATE operation has not been performed or the object has already existed in the cluster.
if err = r.client.Update(ctx, obj); err != nil {
return err
}
log.V(5).Info("Succeeded to update object", logKeysAndValues)
}
Comment on lines +100 to +112
Member:
Would it be possible to make this condition simpler? E.g. if GetResourceVersion != "" we know the object exists and should perform an UPDATE operation; otherwise we should just perform a CREATE operation.
If the CREATE or UPDATE action fails, we return the error.

Member Author:
I believe we need to check whether creationErr is an AlreadyExists error, because we do not know whether the runtime interface returns objects with a resourceVersion for the UPDATE operation, and we cannot guarantee at the implementation level that the resourceVersion is set.

After we migrate to the SSA patch (#2297), we can avoid this complexity and just issue the SSA patch for CREATE and UPDATE.

Member:
I see, that makes sense!
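
For reference, a rough sketch of what the SSA-based path might look like after #2297, using controller-runtime's apply patch; the field-owner name is an illustrative assumption:

// Instead of branching on resourceVersion and AlreadyExists, a single
// server-side apply patch would cover both the CREATE and UPDATE cases.
if err := r.client.Patch(ctx, obj, client.Apply,
	client.FieldOwner("trainjob-controller"), client.ForceOwnership); err != nil {
	return err
}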

}
return nil
}

func runtimeRefToGroupKind(runtimeRef kubeflowv2.RuntimeRef) schema.GroupKind {
return schema.GroupKind{
Group: ptr.Deref(runtimeRef.APIGroup, ""),
Kind: ptr.Deref(runtimeRef.Kind, ""),
Comment on lines +119 to +120
Member:
Do we need to use Deref here if APIGroup and Kind can't be empty at this stage?

Member Author:
I don't think we can restrict these fields to be non-empty. In particular, an empty value should be accepted so that the core API group can be specified.

Member:
In particular, an empty value should be accepted so that the core API group can be specified.

What is the use case for referencing an object from Kubernetes core, like a Service or a Pod?
I thought users could only reference runtimes that are registered within our Training Operator manager.

Member Author:
I expect some users to register Pod as a runtime; I have occasionally heard this use case from users who run plain Pods with many special annotations in production clusters to mitigate CRD migration costs. I understand that a plain Pod with special annotations is not an ideal approach, but it is better not to impose restrictions at the runtime registration level.

Member:
Alright, should the Kind always be non-empty?

Member Author:
Alright, should the Kind always be non-empty?

Per the runtime specification, Kind should always be non-empty.
However, there are ways to bypass the validations (e.g. via failurePolicy), so I would recommend keeping this Deref.
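
As a small illustration (example values are hypothetical, not from this PR), the Deref with an empty default still produces a usable runtime registry key:

// An empty APIGroup maps to the core group, so the key is just the Kind.
schema.GroupKind{Group: "", Kind: "Pod"}.String() // "Pod"
// A non-empty APIGroup yields the usual "Kind.group" form.
schema.GroupKind{Group: "kubeflow.org", Kind: "ClusterTrainingRuntime"}.String() // "ClusterTrainingRuntime.kubeflow.org"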

}
}

func (r *TrainJobReconciler) SetupWithManager(mgr ctrl.Manager) error {
b := ctrl.NewControllerManagedBy(mgr).
For(&kubeflowv2.TrainJob{})
for _, run := range runtimes {
for _, registrar := range run.EventHandlerRegistrars() {
for _, runtime := range r.runtimes {
for _, registrar := range runtime.EventHandlerRegistrars() {
if registrar != nil {
b = registrar(b, mgr.GetClient())
}
4 changes: 3 additions & 1 deletion pkg/runtime.v2/core/clustertrainingruntime_test.go
@@ -46,6 +46,7 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) {
}{
"succeeded to build JobSet and PodGroup": {
trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job").
Suspend(true).
UID("uid").
RuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.ClusterTrainingRuntimeKind), "test-runtime").
Trainer(
@@ -57,7 +58,7 @@
clusterTrainingRuntime: baseRuntime.RuntimeSpec(
testingutil.MakeTrainingRuntimeSpecWrapper(baseRuntime.Spec).
ContainerImage("test:runtime").
PodGroupPolicySchedulingTimeout(120).
PodGroupPolicyCoschedulingSchedulingTimeout(120).
MLPolicyNumNodes(20).
ResourceRequests(0, corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("1"),
@@ -69,6 +70,7 @@
).Obj(),
wantObjs: []client.Object{
testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job").
Suspend(true).
PodLabel(schedulerpluginsv1alpha1.PodGroupLabel, "test-job").
ContainerImage(ptr.To("test:trainjob")).
JobCompletionMode(batchv1.IndexedCompletion).
4 changes: 3 additions & 1 deletion pkg/runtime.v2/core/trainingruntime_test.go
@@ -46,6 +46,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
}{
"succeeded to build JobSet and PodGroup": {
trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job").
Suspend(true).
UID("uid").
RuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), "test-runtime").
SpecLabel("conflictLabel", "override").
@@ -62,7 +63,7 @@
RuntimeSpec(
testingutil.MakeTrainingRuntimeSpecWrapper(baseRuntime.Spec).
ContainerImage("test:runtime").
PodGroupPolicySchedulingTimeout(120).
PodGroupPolicyCoschedulingSchedulingTimeout(120).
MLPolicyNumNodes(20).
ResourceRequests(0, corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("1"),
@@ -74,6 +75,7 @@
).Obj(),
wantObjs: []client.Object{
testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job").
Suspend(true).
Label("conflictLabel", "override").
Annotation("conflictAnnotation", "override").
PodLabel(schedulerpluginsv1alpha1.PodGroupLabel, "test-job").
10 changes: 6 additions & 4 deletions pkg/runtime.v2/framework/core/framework_test.go
@@ -334,13 +334,12 @@ func TestRunComponentBuilderPlugins(t *testing.T) {
ResourceRequests(1, corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("1"),
corev1.ResourceMemory: resource.MustParse("2Gi"),
}).
Clone()
})
jobSetWithPropagatedTrainJobParams := jobSetBase.
Clone().
JobCompletionMode(batchv1.IndexedCompletion).
ContainerImage(ptr.To("foo:bar")).
ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind("TrainJob"), "test-job", "uid").
Clone()
ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind("TrainJob"), "test-job", "uid")

cases := map[string]struct {
runtimeInfo *runtime.Info
@@ -361,6 +360,7 @@
Obj(),
runtimeInfo: &runtime.Info{
Obj: jobSetBase.
Clone().
Obj(),
Policy: runtime.Policy{
MLPolicy: &kubeflowv2.MLPolicy{
@@ -403,10 +403,12 @@
ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind("TrainJob"), "test-job", "uid").
Obj(),
jobSetWithPropagatedTrainJobParams.
Clone().
Obj(),
},
wantRuntimeInfo: &runtime.Info{
Obj: jobSetWithPropagatedTrainJobParams.
Clone().
Obj(),
Policy: runtime.Policy{
MLPolicy: &kubeflowv2.MLPolicy{
11 changes: 8 additions & 3 deletions pkg/runtime.v2/framework/plugins/jobset/builder.go
@@ -28,12 +28,12 @@ import (
)

type Builder struct {
*jobsetv1alpha2.JobSet
jobsetv1alpha2.JobSet
}

func NewBuilder(objectKey client.ObjectKey, jobSetTemplateSpec kubeflowv2.JobSetTemplateSpec) *Builder {
return &Builder{
JobSet: &jobsetv1alpha2.JobSet{
JobSet: jobsetv1alpha2.JobSet{
TypeMeta: metav1.TypeMeta{
APIVersion: jobsetv1alpha2.SchemeGroupVersion.String(),
Kind: "JobSet",
@@ -76,8 +76,13 @@ func (b *Builder) PodLabels(labels map[string]string) *Builder {
return b
}

func (b *Builder) Suspend(suspend *bool) *Builder {
b.Spec.Suspend = suspend
return b
}

// TODO: Need to support all TrainJob fields.

func (b *Builder) Build() *jobsetv1alpha2.JobSet {
return b.JobSet
return &b.JobSet
}
45 changes: 29 additions & 16 deletions pkg/runtime.v2/framework/plugins/jobset/jobset.go
@@ -74,29 +74,37 @@ func (j *JobSet) Build(ctx context.Context, info *runtime.Info, trainJob *kubefl
if !ok {
return nil, nil
}
jobSetBuilder := NewBuilder(client.ObjectKeyFromObject(trainJob), kubeflowv2.JobSetTemplateSpec{
ObjectMeta: metav1.ObjectMeta{
Labels: info.Labels,
Annotations: info.Annotations,
},
Spec: raw.Spec,
})

var jobSetBuilder *Builder
oldJobSet := &jobsetv1alpha2.JobSet{}
if err := j.client.Get(ctx, client.ObjectKeyFromObject(trainJob), oldJobSet); err != nil {
if !apierrors.IsNotFound(err) {
return nil, err
}
jobSetBuilder = NewBuilder(client.ObjectKeyFromObject(trainJob), kubeflowv2.JobSetTemplateSpec{
ObjectMeta: metav1.ObjectMeta{
Labels: info.Labels,
Annotations: info.Annotations,
},
Spec: raw.Spec,
})
oldJobSet = nil
Member:
Do you need this?

Member Author:
Yes, we need this for the needsCreateOrUpdate function.
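
A brief illustration of why the reset matters (hypothetical call, not from this PR):

// With oldJobSet reset to nil on NotFound, needsCreateOrUpdate reports that
// the JobSet still has to be emitted, regardless of the suspend flags.
needsCreateOrUpdate(nil, newJobSet, trainJobIsSuspended) // true, because old == nil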

} else {
jobSetBuilder = &Builder{
JobSet: *oldJobSet.DeepCopy(),
}
}

// TODO (tenzen-y): We should support all field propagation in builder.
jobSet := jobSetBuilder.
Suspend(trainJob.Spec.Suspend).
ContainerImage(trainJob.Spec.Trainer.Image).
JobCompletionMode(batchv1.IndexedCompletion).
PodLabels(info.PodLabels).
Build()
if err := ctrlutil.SetControllerReference(trainJob, jobSet, j.scheme); err != nil {
return nil, err
}
oldJobSet := &jobsetv1alpha2.JobSet{}
if err := j.client.Get(ctx, client.ObjectKeyFromObject(jobSet), oldJobSet); err != nil {
if !apierrors.IsNotFound(err) {
return nil, err
}
oldJobSet = nil
}
if err := info.Update(jobSet); err != nil {
return nil, err
}
@@ -106,9 +114,14 @@
return nil, nil
}

func needsCreateOrUpdate(old, new *jobsetv1alpha2.JobSet, suspended bool) bool {
func needsCreateOrUpdate(old, new *jobsetv1alpha2.JobSet, trainJobIsSuspended bool) bool {
return old == nil ||
suspended && (!equality.Semantic.DeepEqual(old.Spec, new.Spec) || !maps.Equal(old.Labels, new.Labels) || !maps.Equal(old.Annotations, new.Annotations))
(!trainJobIsSuspended && jobSetIsSuspended(old) && !jobSetIsSuspended(new)) ||
(trainJobIsSuspended && (!equality.Semantic.DeepEqual(old.Spec, new.Spec) || !maps.Equal(old.Labels, new.Labels) || !maps.Equal(old.Annotations, new.Annotations)))
}

func jobSetIsSuspended(jobSet *jobsetv1alpha2.JobSet) bool {
return ptr.Deref(jobSet.Spec.Suspend, false)
}

func (j *JobSet) ReconcilerBuilders() []runtime.ReconcilerBuilder {