diff --git a/PROJECT b/PROJECT index 8f321f6b86..4aea9cdea0 100644 --- a/PROJECT +++ b/PROJECT @@ -27,4 +27,12 @@ resources: kind: TFJob path: github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1 version: v1 +- api: + crdVersion: v1 + namespaced: true + controller: true + group: kubeflow.org + kind: JAXJob + path: github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1 + version: v1 version: "3" diff --git a/cmd/training-operator.v1/main.go b/cmd/training-operator.v1/main.go index bec5cd6b55..e008c05d4e 100644 --- a/cmd/training-operator.v1/main.go +++ b/cmd/training-operator.v1/main.go @@ -87,7 +87,7 @@ func main() { "Enabling this will ensure there is only one active controller manager.") flag.StringVar(&leaderElectionID, "leader-election-id", "1ca428e5.training-operator.kubeflow.org", "The ID for leader election.") flag.Var(&enabledSchemes, "enable-scheme", "Enable scheme(s) as --enable-scheme=tfjob --enable-scheme=pytorchjob, case insensitive."+ - " Now supporting TFJob, PyTorchJob, XGBoostJob, PaddleJob. By default, all supported schemes will be enabled.") + " Now supporting TFJob, PyTorchJob, XGBoostJob, PaddleJob, JAXJob. By default, all supported schemes will be enabled.") flag.StringVar(&gangSchedulerName, "gang-scheduler-name", "", "Now Supporting volcano and scheduler-plugins."+ " Note: If you set another scheduler name, the training-operator assumes it's the scheduler-plugins.") flag.StringVar(&namespace, "namespace", os.Getenv(EnvKubeflowNamespace), "The namespace to monitor kubeflow jobs. If unset, it monitors all namespaces cluster-wide."+ diff --git a/manifests/base/webhook/manifests.yaml b/manifests/base/webhook/manifests.yaml index c8a69845b7..2c381d0cd1 100644 --- a/manifests/base/webhook/manifests.yaml +++ b/manifests/base/webhook/manifests.yaml @@ -4,6 +4,26 @@ kind: ValidatingWebhookConfiguration metadata: name: validating-webhook-configuration webhooks: +- admissionReviewVersions: + - v1 + clientConfig: + service: + name: webhook-service + namespace: system + path: /validate-kubeflow-org-v1-jaxjob + failurePolicy: Fail + name: validator.jaxjob.training-operator.kubeflow.org + rules: + - apiGroups: + - kubeflow.org + apiVersions: + - v1 + operations: + - CREATE + - UPDATE + resources: + - jaxjobs + sideEffects: None - admissionReviewVersions: - v1 clientConfig: diff --git a/pkg/controller.v1/jax/jaxjob_controller_suite_test.go b/pkg/controller.v1/jax/jaxjob_controller_suite_test.go new file mode 100644 index 0000000000..01ce56b28a --- /dev/null +++ b/pkg/controller.v1/jax/jaxjob_controller_suite_test.go @@ -0,0 +1,128 @@ +// Copyright 2024 The Kubeflow Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package jax + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "path/filepath" + "testing" + "time" + + kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" + "github.com/kubeflow/training-operator/pkg/controller.v1/common" + jaxwebhook "github.com/kubeflow/training-operator/pkg/webhooks/jax" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "k8s.io/client-go/kubernetes/scheme" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/envtest" + logf "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/controller-runtime/pkg/log/zap" + metricsserver "sigs.k8s.io/controller-runtime/pkg/metrics/server" + "sigs.k8s.io/controller-runtime/pkg/webhook" + "volcano.sh/apis/pkg/apis/scheduling/v1beta1" + //+kubebuilder:scaffold:imports +) + +// These tests use Ginkgo (BDD-style Go testing framework). Refer to +// http://onsi.github.io/ginkgo/ to learn more about Ginkgo. + +var ( + testK8sClient client.Client + testEnv *envtest.Environment + testCtx context.Context + testCancel context.CancelFunc +) + +func TestAPIs(t *testing.T) { + RegisterFailHandler(Fail) + + RunSpecs(t, "Controller Suite") +} + +var _ = BeforeSuite(func() { + logf.SetLogger(zap.New(zap.WriteTo(GinkgoWriter), zap.UseDevMode(true))) + + testCtx, testCancel = context.WithCancel(context.TODO()) + + By("bootstrapping test environment") + testEnv = &envtest.Environment{ + CRDDirectoryPaths: []string{filepath.Join("..", "..", "..", "manifests", "base", "crds")}, + ErrorIfCRDPathMissing: true, + WebhookInstallOptions: envtest.WebhookInstallOptions{ + Paths: []string{filepath.Join("..", "..", "..", "manifests", "base", "webhook", "manifests.yaml")}, + }, + } + + cfg, err := testEnv.Start() + Expect(err).NotTo(HaveOccurred()) + Expect(cfg).NotTo(BeNil()) + + err = v1beta1.AddToScheme(scheme.Scheme) + Expect(err).NotTo(HaveOccurred()) + err = kubeflowv1.AddToScheme(scheme.Scheme) + Expect(err).NotTo(HaveOccurred()) + + //+kubebuilder:scaffold:scheme + + testK8sClient, err = client.New(cfg, client.Options{Scheme: scheme.Scheme}) + Expect(err).NotTo(HaveOccurred()) + Expect(testK8sClient).NotTo(BeNil()) + + mgr, err := ctrl.NewManager(cfg, ctrl.Options{ + Metrics: metricsserver.Options{ + BindAddress: "0", + }, + WebhookServer: webhook.NewServer( + webhook.Options{ + Host: testEnv.WebhookInstallOptions.LocalServingHost, + Port: testEnv.WebhookInstallOptions.LocalServingPort, + CertDir: testEnv.WebhookInstallOptions.LocalServingCertDir, + }), + }) + Expect(err).NotTo(HaveOccurred()) + + gangSchedulingSetupFunc := common.GenNonGangSchedulerSetupFunc() + r := NewReconciler(mgr, gangSchedulingSetupFunc) + + Expect(r.SetupWithManager(mgr, 1)).NotTo(HaveOccurred()) + Expect(jaxwebhook.SetupWebhook(mgr)).NotTo(HaveOccurred()) + + go func() { + defer GinkgoRecover() + err = mgr.Start(testCtx) + Expect(err).ToNot(HaveOccurred(), "failed to run manager") + }() + + dialer := &net.Dialer{Timeout: time.Second} + addrPort := fmt.Sprintf("%s:%d", testEnv.WebhookInstallOptions.LocalServingHost, testEnv.WebhookInstallOptions.LocalServingPort) + Eventually(func(g Gomega) { + conn, err := tls.DialWithDialer(dialer, "tcp", addrPort, &tls.Config{InsecureSkipVerify: true}) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(conn.Close()).NotTo(HaveOccurred()) + }).Should(Succeed()) +}) + +var _ = AfterSuite(func() { + By("tearing down the test environment") + testCancel() + err := testEnv.Stop() + Expect(err).NotTo(HaveOccurred()) +}) diff --git a/pkg/controller.v1/jax/jaxjob_controller_test.go b/pkg/controller.v1/jax/jaxjob_controller_test.go new file mode 100644 index 0000000000..289d5a3efe --- /dev/null +++ b/pkg/controller.v1/jax/jaxjob_controller_test.go @@ -0,0 +1,320 @@ +// Copyright 2024 The Kubeflow Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package jax + +import ( + "context" + "fmt" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "k8s.io/utils/ptr" + "sigs.k8s.io/controller-runtime/pkg/client" + + kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" + commonutil "github.com/kubeflow/training-operator/pkg/util" + "github.com/kubeflow/training-operator/pkg/util/testutil" +) + +var _ = Describe("JAXJob controller", func() { + // Define utility constants for object names. + const ( + expectedPort = int32(6666) + ) + + Context("When creating the JAXJob", func() { + const name = "test-job" + var ( + ns *corev1.Namespace + job *kubeflowv1.JAXJob + jobKey types.NamespacedName + worker0Key types.NamespacedName + ctx = context.Background() + ) + BeforeEach(func() { + ns = &corev1.Namespace{ + ObjectMeta: metav1.ObjectMeta{ + GenerateName: "jax-test-", + }, + } + Expect(testK8sClient.Create(ctx, ns)).Should(Succeed()) + + job = newJAXJobForTest(name, ns.Name) + jobKey = client.ObjectKeyFromObject(job) + + worker0Key = types.NamespacedName{ + Name: fmt.Sprintf("%s-worker-0", name), + Namespace: ns.Name, + } + job.Spec.JAXReplicaSpecs = map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec{ + kubeflowv1.JAXJobReplicaTypeWorker: { + Replicas: ptr.To[int32](2), + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Image: "test-image", + Name: kubeflowv1.JAXJobDefaultContainerName, + Ports: []corev1.ContainerPort{ + { + Name: kubeflowv1.JAXJobDefaultPortName, + ContainerPort: expectedPort, + Protocol: corev1.ProtocolTCP, + }, + }, + }, + }, + }, + }, + }, + } + }) + AfterEach(func() { + Expect(testK8sClient.Delete(ctx, job)).Should(Succeed()) + Expect(testK8sClient.Delete(ctx, ns)).Should(Succeed()) + }) + + It("Shouldn't create resources if JAXJob is suspended", func() { + By("By creating a new JAXJob with suspend=true") + job.Spec.RunPolicy.Suspend = ptr.To(true) + job.Spec.JAXReplicaSpecs[kubeflowv1.JAXJobReplicaTypeWorker].Replicas = ptr.To[int32](1) + Expect(testK8sClient.Create(ctx, job)).Should(Succeed()) + + created := &kubeflowv1.JAXJob{} + workerPod := &corev1.Pod{} + workerSvc := &corev1.Service{} + + By("Checking created JAXJob") + Eventually(func() bool { + err := testK8sClient.Get(ctx, jobKey, created) + return err == nil + }, testutil.Timeout, testutil.Interval).Should(BeTrue()) + By("Checking created JAXJob has a nil startTime") + Consistently(func() *metav1.Time { + Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed()) + return created.Status.StartTime + }, testutil.ConsistentDuration, testutil.Interval).Should(BeNil()) + + By("Checking if the pods and services aren't created") + Consistently(func() bool { + errWorkerPod := testK8sClient.Get(ctx, worker0Key, workerPod) + errWorkerSvc := testK8sClient.Get(ctx, worker0Key, workerSvc) + return errors.IsNotFound(errWorkerPod) && + errors.IsNotFound(errWorkerSvc) + }, testutil.ConsistentDuration, testutil.Interval).Should(BeTrue()) + + By("Checking if the JAXJob has suspended condition") + Eventually(func() []kubeflowv1.JobCondition { + Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed()) + return created.Status.Conditions + }, testutil.ConsistentDuration, testutil.Interval).Should(BeComparableTo([]kubeflowv1.JobCondition{ + { + Type: kubeflowv1.JobCreated, + Status: corev1.ConditionTrue, + Reason: commonutil.NewReason(kubeflowv1.JAXJobKind, commonutil.JobCreatedReason), + Message: fmt.Sprintf("JAXJob %s is created.", name), + }, + { + Type: kubeflowv1.JobSuspended, + Status: corev1.ConditionTrue, + Reason: commonutil.NewReason(kubeflowv1.JAXJobKind, commonutil.JobSuspendedReason), + Message: fmt.Sprintf("JAXJob %s is suspended.", name), + }, + }, testutil.IgnoreJobConditionsTimes)) + }) + + It("Should delete resources after JAXJob is suspended; Should resume JAXJob after JAXJob is unsuspended", func() { + By("By creating a new JAXJob") + job.Spec.JAXReplicaSpecs[kubeflowv1.JAXJobReplicaTypeWorker].Replicas = ptr.To[int32](1) + Expect(testK8sClient.Create(ctx, job)).Should(Succeed()) + + created := &kubeflowv1.JAXJob{} + workerPod := &corev1.Pod{} + workerSvc := &corev1.Service{} + + // We'll need to retry getting this newly created JAXJob, given that creation may not immediately happen. + By("Checking created JAXJob") + Eventually(func() bool { + err := testK8sClient.Get(ctx, jobKey, created) + return err == nil + }, testutil.Timeout, testutil.Interval).Should(BeTrue()) + + var startTimeBeforeSuspended *metav1.Time + Eventually(func() *metav1.Time { + Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed()) + startTimeBeforeSuspended = created.Status.StartTime + return startTimeBeforeSuspended + }, testutil.Timeout, testutil.Interval).ShouldNot(BeNil()) + + By("Checking the created pods and services") + Eventually(func() bool { + errWorker := testK8sClient.Get(ctx, worker0Key, workerPod) + return errWorker == nil + }, testutil.Timeout, testutil.Interval).Should(BeTrue()) + Eventually(func() bool { + errWorker := testK8sClient.Get(ctx, worker0Key, workerSvc) + return errWorker == nil + }, testutil.Timeout, testutil.Interval).Should(BeTrue()) + + By("Updating the pod's phase with Running") + Eventually(func() error { + Expect(testK8sClient.Get(ctx, worker0Key, workerPod)).Should(Succeed()) + workerPod.Status.Phase = corev1.PodRunning + return testK8sClient.Status().Update(ctx, workerPod) + }, testutil.Timeout, testutil.Interval).Should(Succeed()) + + By("Checking the JAXJob's condition") + Eventually(func() []kubeflowv1.JobCondition { + Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed()) + return created.Status.Conditions + }, testutil.Timeout, testutil.Interval).Should(BeComparableTo([]kubeflowv1.JobCondition{ + { + Type: kubeflowv1.JobCreated, + Status: corev1.ConditionTrue, + Reason: commonutil.NewReason(kubeflowv1.JAXJobKind, commonutil.JobCreatedReason), + Message: fmt.Sprintf("JAXJob %s is created.", name), + }, + { + Type: kubeflowv1.JobRunning, + Status: corev1.ConditionTrue, + Reason: commonutil.NewReason(kubeflowv1.JAXJobKind, commonutil.JobRunningReason), + Message: fmt.Sprintf("JAXJob %s is running.", name), + }, + }, testutil.IgnoreJobConditionsTimes)) + + By("Updating the JAXJob with suspend=true") + Eventually(func() error { + Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed()) + created.Spec.RunPolicy.Suspend = ptr.To(true) + return testK8sClient.Update(ctx, created) + }, testutil.Timeout, testutil.Interval).Should(Succeed()) + + By("Checking if the pods and services are removed") + Eventually(func() bool { + errWorker := testK8sClient.Get(ctx, worker0Key, workerPod) + return errors.IsNotFound(errWorker) + }, testutil.Timeout, testutil.Interval).Should(BeTrue()) + Eventually(func() bool { + errWorker := testK8sClient.Get(ctx, worker0Key, workerSvc) + return errors.IsNotFound(errWorker) + }, testutil.Timeout, testutil.Interval).Should(BeTrue()) + Consistently(func() bool { + errWorkerPod := testK8sClient.Get(ctx, worker0Key, workerPod) + errWorkerSvc := testK8sClient.Get(ctx, worker0Key, workerSvc) + return errors.IsNotFound(errWorkerPod) && + errors.IsNotFound(errWorkerSvc) + }, testutil.ConsistentDuration, testutil.Interval).Should(BeTrue()) + + By("Checking if the JAXJob has a suspended condition") + Eventually(func() bool { + Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed()) + return created.Status.ReplicaStatuses[kubeflowv1.JAXJobReplicaTypeWorker].Active == 0 && + created.Status.StartTime.Equal(startTimeBeforeSuspended) + }, testutil.Timeout, testutil.Interval).Should(BeTrue()) + Consistently(func() bool { + Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed()) + return created.Status.ReplicaStatuses[kubeflowv1.JAXJobReplicaTypeWorker].Active == 0 && + created.Status.StartTime.Equal(startTimeBeforeSuspended) + }, testutil.ConsistentDuration, testutil.Interval).Should(BeTrue()) + Expect(created.Status.Conditions).Should(BeComparableTo([]kubeflowv1.JobCondition{ + { + Type: kubeflowv1.JobCreated, + Status: corev1.ConditionTrue, + Reason: commonutil.NewReason(kubeflowv1.JAXJobKind, commonutil.JobCreatedReason), + Message: fmt.Sprintf("JAXJob %s is created.", name), + }, + { + Type: kubeflowv1.JobRunning, + Status: corev1.ConditionFalse, + Reason: commonutil.NewReason(kubeflowv1.JAXJobKind, commonutil.JobSuspendedReason), + Message: fmt.Sprintf("JAXJob %s is suspended.", name), + }, + { + Type: kubeflowv1.JobSuspended, + Reason: commonutil.NewReason(kubeflowv1.JAXJobKind, commonutil.JobSuspendedReason), + Message: fmt.Sprintf("JAXJob %s is suspended.", name), + Status: corev1.ConditionTrue, + }, + }, testutil.IgnoreJobConditionsTimes)) + + By("Unsuspending the JAXJob") + Eventually(func() error { + Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed()) + created.Spec.RunPolicy.Suspend = ptr.To(false) + return testK8sClient.Update(ctx, created) + }, testutil.Timeout, testutil.Interval).Should(Succeed()) + Eventually(func() *metav1.Time { + Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed()) + return created.Status.StartTime + }, testutil.Timeout, testutil.Interval).ShouldNot(BeNil()) + + By("Check if the pods and services are created") + Eventually(func() error { + return testK8sClient.Get(ctx, worker0Key, workerPod) + }, testutil.Timeout, testutil.Interval).Should(BeNil()) + Eventually(func() error { + return testK8sClient.Get(ctx, worker0Key, workerSvc) + }, testutil.Timeout, testutil.Interval).Should(BeNil()) + + By("Updating Pod's condition with running") + Eventually(func() error { + Expect(testK8sClient.Get(ctx, worker0Key, workerPod)).Should(Succeed()) + workerPod.Status.Phase = corev1.PodRunning + return testK8sClient.Status().Update(ctx, workerPod) + }, testutil.Timeout, testutil.Interval).Should(Succeed()) + + By("Checking if the JAXJob has resumed conditions") + Eventually(func() []kubeflowv1.JobCondition { + Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed()) + return created.Status.Conditions + }, testutil.Timeout, testutil.Interval).Should(BeComparableTo([]kubeflowv1.JobCondition{ + { + Type: kubeflowv1.JobCreated, + Status: corev1.ConditionTrue, + Reason: commonutil.NewReason(kubeflowv1.JAXJobKind, commonutil.JobCreatedReason), + Message: fmt.Sprintf("JAXJob %s is created.", name), + }, + { + Type: kubeflowv1.JobSuspended, + Reason: commonutil.NewReason(kubeflowv1.JAXJobKind, commonutil.JobResumedReason), + Message: fmt.Sprintf("JAXJob %s is resumed.", name), + Status: corev1.ConditionFalse, + }, + { + Type: kubeflowv1.JobRunning, + Status: corev1.ConditionTrue, + Reason: commonutil.NewReason(kubeflowv1.JAXJobKind, commonutil.JobRunningReason), + Message: fmt.Sprintf("JAXJob %s is running.", name), + }, + }, testutil.IgnoreJobConditionsTimes)) + + By("Checking if the startTime is updated") + Expect(created.Status.StartTime).ShouldNot(Equal(startTimeBeforeSuspended)) + }) + }) +}) + +func newJAXJobForTest(name, namespace string) *kubeflowv1.JAXJob { + return &kubeflowv1.JAXJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: namespace, + }, + } +} diff --git a/pkg/webhooks/jax/jaxjob_webhook.go b/pkg/webhooks/jax/jaxjob_webhook.go new file mode 100644 index 0000000000..12888b3d3c --- /dev/null +++ b/pkg/webhooks/jax/jaxjob_webhook.go @@ -0,0 +1,124 @@ +/* +Copyright 2024 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package jax + +import ( + "context" + "fmt" + "slices" + "strings" + + apimachineryvalidation "k8s.io/apimachinery/pkg/api/validation" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/util/validation/field" + "k8s.io/klog/v2" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/webhook" + "sigs.k8s.io/controller-runtime/pkg/webhook/admission" + + trainingoperator "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" +) + +var ( + specPath = field.NewPath("spec") + jaxReplicaSpecPath = specPath.Child("jaxReplicaSpecs") +) + +type Webhook struct{} + +func SetupWebhook(mgr ctrl.Manager) error { + return ctrl.NewWebhookManagedBy(mgr). + For(&trainingoperator.JAXJob{}). + WithValidator(&Webhook{}). + Complete() +} + +// +kubebuilder:webhook:path=/validate-kubeflow-org-v1-jaxjob,mutating=false,failurePolicy=fail,sideEffects=None,groups=kubeflow.org,resources=jaxjobs,verbs=create;update,versions=v1,name=validator.jaxjob.training-operator.kubeflow.org,admissionReviewVersions=v1 + +var _ webhook.CustomValidator = &Webhook{} + +func (w *Webhook) ValidateCreate(ctx context.Context, obj runtime.Object) (admission.Warnings, error) { + job := obj.(*trainingoperator.JAXJob) + log := ctrl.LoggerFrom(ctx).WithName("jaxjob-webhook") + log.V(5).Info("Validating create", "jaxJob", klog.KObj(job)) + return nil, validateJAXJob(job).ToAggregate() +} + +func (w *Webhook) ValidateUpdate(ctx context.Context, _ runtime.Object, newObj runtime.Object) (admission.Warnings, error) { + job := newObj.(*trainingoperator.JAXJob) + log := ctrl.LoggerFrom(ctx).WithName("jaxjob-webhook") + log.V(5).Info("Validating update", "jaxJob", klog.KObj(job)) + return nil, validateJAXJob(job).ToAggregate() +} + +func (w *Webhook) ValidateDelete(context.Context, runtime.Object) (admission.Warnings, error) { + return nil, nil +} + +func validateJAXJob(job *trainingoperator.JAXJob) field.ErrorList { + var allErrs field.ErrorList + if errors := apimachineryvalidation.NameIsDNS1035Label(job.ObjectMeta.Name, false); len(errors) != 0 { + allErrs = append(allErrs, field.Invalid(field.NewPath("metadata").Child("name"), job.Name, fmt.Sprintf("should match: %v", strings.Join(errors, ",")))) + } + + allErrs = append(allErrs, validateSpec(job.Spec)...) + return allErrs +} + +func validateSpec(spec trainingoperator.JAXJobSpec) field.ErrorList { + return validateJAXReplicaSpecs(spec.JAXReplicaSpecs) +} + +func validateJAXReplicaSpecs(rSpecs map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec) field.ErrorList { + var allErrs field.ErrorList + + if rSpecs == nil { + allErrs = append(allErrs, field.Required(jaxReplicaSpecPath, "must be required")) + } + for rType, rSpec := range rSpecs { + rolePath := jaxReplicaSpecPath.Key(string(rType)) + containersPath := rolePath.Child("template").Child("spec").Child("containers") + + // Make sure the replica type is valid. + validRoleTypes := []trainingoperator.ReplicaType{ + trainingoperator.JAXJobReplicaTypeWorker, + } + if !slices.Contains(validRoleTypes, rType) { + allErrs = append(allErrs, field.NotSupported(rolePath, rType, validRoleTypes)) + } + + if rSpec == nil || len(rSpec.Template.Spec.Containers) == 0 { + allErrs = append(allErrs, field.Required(containersPath, "must be specified")) + } + + // Make sure the image is defined in the container + defaultContainerPresent := false + for idx, container := range rSpec.Template.Spec.Containers { + if container.Image == "" { + allErrs = append(allErrs, field.Required(containersPath.Index(idx).Child("image"), "must be required")) + } + if container.Name == trainingoperator.JAXJobDefaultContainerName { + defaultContainerPresent = true + } + } + // Make sure there has at least one container named "jax" + if !defaultContainerPresent { + allErrs = append(allErrs, field.Required(containersPath, fmt.Sprintf("must have at least one container with name %s", trainingoperator.JAXJobDefaultContainerName))) + } + } + return allErrs +} diff --git a/pkg/webhooks/webhooks.go b/pkg/webhooks/webhooks.go index 29ad08e2fd..d1dd2b2f8e 100644 --- a/pkg/webhooks/webhooks.go +++ b/pkg/webhooks/webhooks.go @@ -20,6 +20,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/manager" trainingoperator "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" + "github.com/kubeflow/training-operator/pkg/webhooks/jax" "github.com/kubeflow/training-operator/pkg/webhooks/paddlepaddle" "github.com/kubeflow/training-operator/pkg/webhooks/pytorch" "github.com/kubeflow/training-operator/pkg/webhooks/tensorflow" @@ -35,6 +36,7 @@ var ( trainingoperator.XGBoostJobKind: xgboost.SetupWebhook, trainingoperator.MPIJobKind: scaffold, trainingoperator.PaddleJobKind: paddlepaddle.SetupWebhook, + trainingoperator.JAXJobKind: jax.SetupWebhook, } ) diff --git a/sdk/python/test/e2e/test_e2e_jaxjob.py b/sdk/python/test/e2e/test_e2e_jaxjob.py new file mode 100644 index 0000000000..483ac011f4 --- /dev/null +++ b/sdk/python/test/e2e/test_e2e_jaxjob.py @@ -0,0 +1,162 @@ +# Copyright 2024 kubeflow.org. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import logging +import pytest +from typing import Optional + +from kubernetes.client import V1PodTemplateSpec +from kubernetes.client import V1ObjectMeta +from kubernetes.client import V1PodSpec +from kubernetes.client import V1Container +from kubernetes.client import V1ResourceRequirements + +from kubeflow.training import TrainingClient +from kubeflow.training import KubeflowOrgV1ReplicaSpec +from kubeflow.training import KubeflowOrgV1JAXJob +from kubeflow.training import KubeflowOrgV1JAXJobSpec +from kubeflow.training import KubeflowOrgV1RunPolicy +from kubeflow.training import KubeflowOrgV1SchedulingPolicy +from kubeflow.training.constants import constants + +import test.e2e.utils as utils +from test.e2e.constants import TEST_GANG_SCHEDULER_NAME_ENV_KEY +from test.e2e.constants import GANG_SCHEDULERS, NONE_GANG_SCHEDULERS + +logging.basicConfig(format="%(message)s") +logging.getLogger("kubeflow.training.api.training_client").setLevel(logging.DEBUG) + +TRAINING_CLIENT = TrainingClient(job_kind=constants.JAXJOB_KIND) +JOB_NAME = "jaxjob-cpu-ci-test" +CONTAINER_NAME = "jax" +GANG_SCHEDULER_NAME = os.getenv(TEST_GANG_SCHEDULER_NAME_ENV_KEY, "") + + +@pytest.mark.skipif( + GANG_SCHEDULER_NAME in NONE_GANG_SCHEDULERS, + reason="For gang-scheduling", +) +def test_sdk_e2e_with_gang_scheduling(job_namespace): + container = generate_container() + + worker = KubeflowOrgV1ReplicaSpec( + replicas=2, + restart_policy="OnFailure", + template=V1PodTemplateSpec( + metadata=V1ObjectMeta( + annotations={constants.ISTIO_SIDECAR_INJECTION: "false"} + ), + spec=V1PodSpec( + scheduler_name=utils.get_pod_spec_scheduler_name(GANG_SCHEDULER_NAME), + containers=[container], + ), + ), + ) + + unschedulable_jaxjob = generate_jaxjob( + job_namespace, worker, KubeflowOrgV1SchedulingPolicy(min_available=10) + ) + schedulable_jaxjob = generate_jaxjob( + job_namespace, worker, KubeflowOrgV1SchedulingPolicy(min_available=2) + ) + + TRAINING_CLIENT.create_job(job=unschedulable_jaxjob, namespace=job_namespace) + logging.info(f"List of created {TRAINING_CLIENT.job_kind}s") + logging.info(TRAINING_CLIENT.list_jobs(job_namespace)) + + try: + utils.verify_unschedulable_job_e2e(TRAINING_CLIENT, JOB_NAME, job_namespace) + except Exception as e: + utils.print_job_results(TRAINING_CLIENT, JOB_NAME, job_namespace) + TRAINING_CLIENT.delete_job(JOB_NAME, job_namespace) + raise Exception(f"JAXJob E2E fails. Exception: {e}") + + TRAINING_CLIENT.update_job(schedulable_jaxjob, JOB_NAME, job_namespace) + logging.info(f"List of updated {TRAINING_CLIENT.job_kind}s") + logging.info(TRAINING_CLIENT.list_jobs(job_namespace)) + + try: + utils.verify_job_e2e(TRAINING_CLIENT, JOB_NAME, job_namespace, wait_timeout=900) + except Exception as e: + utils.print_job_results(TRAINING_CLIENT, JOB_NAME, job_namespace) + TRAINING_CLIENT.delete_job(JOB_NAME, job_namespace) + raise Exception(f"JAXJob E2E fails. Exception: {e}") + + utils.print_job_results(TRAINING_CLIENT, JOB_NAME, job_namespace) + TRAINING_CLIENT.delete_job(JOB_NAME, job_namespace) + + +@pytest.mark.skipif( + GANG_SCHEDULER_NAME in GANG_SCHEDULERS, + reason="For plain scheduling", +) +def test_sdk_e2e(job_namespace): + container = generate_container() + + worker = KubeflowOrgV1ReplicaSpec( + replicas=2, + restart_policy="OnFailure", + template=V1PodTemplateSpec( + metadata=V1ObjectMeta( + annotations={constants.ISTIO_SIDECAR_INJECTION: "false"} + ), + spec=V1PodSpec(containers=[container]), + ), + ) + + jaxjob = generate_jaxjob(job_namespace, worker) + + TRAINING_CLIENT.create_job(job=jaxjob, namespace=job_namespace) + logging.info(f"List of created {TRAINING_CLIENT.job_kind}s") + logging.info(TRAINING_CLIENT.list_jobs(job_namespace)) + + try: + utils.verify_job_e2e(TRAINING_CLIENT, JOB_NAME, job_namespace, wait_timeout=900) + except Exception as e: + utils.print_job_results(TRAINING_CLIENT, JOB_NAME, job_namespace) + TRAINING_CLIENT.delete_job(JOB_NAME, job_namespace) + raise Exception(f"JAXJob E2E fails. Exception: {e}") + + utils.print_job_results(TRAINING_CLIENT, JOB_NAME, job_namespace) + TRAINING_CLIENT.delete_job(JOB_NAME, job_namespace) + + +def generate_jaxjob( + job_namespace: str, + worker: KubeflowOrgV1ReplicaSpec, + scheduling_policy: Optional[KubeflowOrgV1SchedulingPolicy] = None, +) -> KubeflowOrgV1JAXJob: + return KubeflowOrgV1JAXJob( + api_version=constants.API_VERSION, + kind=constants.JAXJOB_KIND, + metadata=V1ObjectMeta(name=JOB_NAME, namespace=job_namespace), + spec=KubeflowOrgV1JAXJobSpec( + run_policy=KubeflowOrgV1RunPolicy( + scheduling_policy=scheduling_policy, + clean_pod_policy="None", + ), + jax_replica_specs={"Worker": worker}, + ), + ) + + +# def generate_container() -> V1Container: +# return V1Container( +# name=CONTAINER_NAME, +# image="docker.io/sandipanify/jaxgloo", +# command=["python"], +# args=["-m", "", ""], +# resources=V1ResourceRequirements(limits={"memory": "2Gi", "cpu": "0.8"}), +# )