From f9fa8ac0585701d336099d799fb2a6e554dfa9bf Mon Sep 17 00:00:00 2001 From: Yuki Iwai Date: Sun, 15 Jan 2023 07:53:51 +0900 Subject: [PATCH] Add scheduler plugin for coscheduling Signed-off-by: Yuki Iwai Co-authored-by: Wang Zhang Signed-off-by: Yuki Iwai --- cmd/training-operator.v1/main.go | 21 ++++++-- go.mod | 11 +++-- go.sum | 22 +++++---- manifests/base/cluster-role.yaml | 6 +++ pkg/common/util/scheduler.go | 9 ---- pkg/controller.v1/mpi/mpijob.go | 9 +--- pkg/controller.v1/mpi/mpijob_controller.go | 49 +++++++++---------- pkg/controller.v1/mpi/suite_test.go | 7 +-- pkg/controller.v1/mxnet/mxjob_controller.go | 18 +++---- pkg/controller.v1/mxnet/mxnet.go | 4 +- pkg/controller.v1/mxnet/suite_test.go | 2 +- .../paddlepaddle/paddlepaddle_controller.go | 18 +++---- .../paddlepaddle_controller_suite_test.go | 6 ++- .../pytorch/pytorchjob_controller.go | 18 +++---- .../pytorchjob_controller_suite_test.go | 6 ++- pkg/controller.v1/register_controller.go | 31 ++++++------ pkg/controller.v1/tensorflow/suite_test.go | 10 ++-- .../tensorflow/tfjob_controller.go | 33 ++++++------- .../xgboost/xgboostjob_controller.go | 18 +++---- 19 files changed, 149 insertions(+), 149 deletions(-) diff --git a/cmd/training-operator.v1/main.go b/cmd/training-operator.v1/main.go index 9b338bf5ac..4b09586cdb 100644 --- a/cmd/training-operator.v1/main.go +++ b/cmd/training-operator.v1/main.go @@ -20,6 +20,7 @@ import ( "flag" "fmt" "os" + "strings" "go.uber.org/zap/zapcore" "k8s.io/apimachinery/pkg/runtime" @@ -29,8 +30,11 @@ import ( ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/healthz" "sigs.k8s.io/controller-runtime/pkg/log/zap" + schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1" "volcano.sh/apis/pkg/apis/scheduling/v1beta1" + volcanoclient "volcano.sh/apis/pkg/client/clientset/versioned" + "github.com/kubeflow/common/pkg/controller.v1/common" commonutil "github.com/kubeflow/common/pkg/util" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" "github.com/kubeflow/training-operator/pkg/config" @@ -47,6 +51,7 @@ func init() { utilruntime.Must(clientgoscheme.AddToScheme(scheme)) utilruntime.Must(kubeflowv1.AddToScheme(scheme)) utilruntime.Must(v1beta1.AddToScheme(scheme)) + utilruntime.Must(schedulerpluginsv1alpha1.AddToScheme(scheme)) //+kubebuilder:scaffold:scheme } @@ -56,7 +61,6 @@ func main() { var leaderElectionID string var probeAddr string var enabledSchemes controllerv1.EnabledSchemes - var enableGangScheduling bool var gangSchedulerName string var namespace string var monitoringPort int @@ -69,8 +73,7 @@ func main() { flag.StringVar(&leaderElectionID, "leader-election-id", "1ca428e5.training-operator.kubeflow.org", "The ID for leader election.") flag.Var(&enabledSchemes, "enable-scheme", "Enable scheme(s) as --enable-scheme=tfjob --enable-scheme=pytorchjob, case insensitive."+ " Now supporting TFJob, PyTorchJob, MXNetJob, XGBoostJob, PaddleJob. By default, all supported schemes will be enabled.") - flag.BoolVar(&enableGangScheduling, "enable-gang-scheduling", false, "Set true to enable gang scheduling") - flag.StringVar(&gangSchedulerName, "gang-scheduler-name", "volcano", "The scheduler to gang-schedule kubeflow jobs, defaults to volcano") + flag.StringVar(&gangSchedulerName, "gang-scheduler-name", "none", "The scheduler to gang-schedule kubeflow jobs, defaults to none") flag.StringVar(&namespace, "namespace", os.Getenv(commonutil.EnvKubeflowNamespace), "The namespace to monitor kubeflow jobs. If unset, it monitors all namespaces cluster-wide."+ "If set, it only monitors kubeflow jobs in the given namespace.") flag.IntVar(&monitoringPort, "monitoring-port", 9443, "Endpoint port for displaying monitoring metrics. "+ @@ -110,6 +113,16 @@ func main() { os.Exit(1) } + // Prepare GangSchedulingSetupFunc + gangSchedulingSetupFunc := common.GenNonGangSchedulerSetupFunc() + if strings.EqualFold(gangSchedulerName, string(common.GangSchedulerVolcano)) { + cfg := mgr.GetConfig() + volcanoClientSet := volcanoclient.NewForConfigOrDie(cfg) + gangSchedulingSetupFunc = common.GenVolcanoSetupFunc(volcanoClientSet) + } else if strings.EqualFold(gangSchedulerName, string(common.GangSchedulerSchedulerPlugins)) { + gangSchedulingSetupFunc = common.GenSchedulerPluginsSetupFunc(mgr.GetClient()) + } + // TODO: We need a general manager. all rest reconciler addsToManager // Based on the user configuration, we start different controllers if enabledSchemes.Empty() { @@ -122,7 +135,7 @@ func main() { "scheme not supported", "scheme", s) os.Exit(1) } - if err = setupFunc(mgr, enableGangScheduling, controllerThreads); err != nil { + if err = setupFunc(mgr, gangSchedulingSetupFunc, controllerThreads); err != nil { setupLog.Error(err, "unable to create controller", "controller", s) os.Exit(1) } diff --git a/go.mod b/go.mod index d7373d8a6a..5a3f1ae76a 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( k8s.io/kube-openapi v0.0.0-20220803162953-67bda5d908f1 k8s.io/utils v0.0.0-20220728103510-ee6ede2d64ed sigs.k8s.io/controller-runtime v0.13.0 + sigs.k8s.io/scheduler-plugins v0.24.9 sigs.k8s.io/yaml v1.3.0 volcano.sh/apis v1.2.0-k8s1.19.6 ) @@ -68,11 +69,11 @@ require ( go.uber.org/multierr v1.8.0 // indirect golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e // indirect golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect - golang.org/x/net v0.0.0-20220722155237-a158d28d115b // indirect + golang.org/x/net v0.3.1-0.20221206200815-1e63c2f08a10 // indirect golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8 // indirect - golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f // indirect - golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 // indirect - golang.org/x/text v0.3.7 // indirect + golang.org/x/sys v0.3.0 // indirect + golang.org/x/term v0.3.0 // indirect + golang.org/x/text v0.5.0 // indirect golang.org/x/time v0.0.0-20220609170525-579cf78fd858 // indirect golang.org/x/tools v0.1.12 // indirect gomodules.xyz/jsonpatch/v2 v2.2.0 // indirect @@ -88,3 +89,5 @@ require ( sigs.k8s.io/json v0.0.0-20220713155537-f223a00ba0e2 // indirect sigs.k8s.io/structured-merge-diff/v4 v4.2.3 // indirect ) + +replace github.com/kubeflow/common v0.4.5 => github.com/tenzen-y/common v0.0.0-20230115192234-865b54409a0b diff --git a/go.sum b/go.sum index d060cbe281..2463792221 100644 --- a/go.sum +++ b/go.sum @@ -327,8 +327,6 @@ github.com/kr/pty v1.1.5/go.mod h1:9r2w37qlBe7rQ6e1fg1S/9xpWHSnaqNdHD3WcMdbPDA= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/kubeflow/common v0.4.5 h1:W7p+s/4Za1UzIgKP2Z6ormEvsUVHykeaXaOuu8+UgpI= -github.com/kubeflow/common v0.4.5/go.mod h1:di43u2m7DyuwnRDb7Kwz1nmA/nhpjnQ+K+gWCV/SPZk= github.com/mailru/easyjson v0.0.0-20160728113105-d5b7844b561a/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= @@ -427,6 +425,8 @@ github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5 github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= +github.com/tenzen-y/common v0.0.0-20230115192234-865b54409a0b h1:Dc2HDPNTzmpSOuOo1d8LeIIC0Sr0InPhdYTqYEEp2CU= +github.com/tenzen-y/common v0.0.0-20230115192234-865b54409a0b/go.mod h1:kI2gL98Ts9uJLHiKZzJZn7chVd9gUD85G+arE4M+7Lo= github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/urfave/cli v1.20.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA= @@ -555,8 +555,8 @@ golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96b golang.org/x/net v0.0.0-20210503060351-7fd8e65b6420/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b h1:PxfKdU9lEEDYjdIzOtC4qFWgkU2rGHdKlKowJSMN9h0= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.3.1-0.20221206200815-1e63c2f08a10 h1:Frnccbp+ok2GkUS2tC84yAq/U9Vg+0sIO7aRL3T4Xnc= +golang.org/x/net v0.3.1-0.20221206200815-1e63c2f08a10/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -648,11 +648,11 @@ golang.org/x/sys v0.0.0-20210823070655-63515b42dcdf/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210908233432-aa78b53d3365/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f h1:v4INt8xihDGvnrfjMDVXGxw9wrfxYyCjk0KbXjhR55s= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ= +golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.3.0 h1:qoo4akIqOcDME5bhc/NgxUdovd6BSS2uMsVjB56q1xI= +golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -661,8 +661,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.5.0 h1:OLmvp0KP+FVG99Ct/qFiL/Fhk4zp4QQnZ7b2U+5piUM= +golang.org/x/text v0.5.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -954,6 +954,8 @@ sigs.k8s.io/controller-runtime v0.13.0 h1:iqa5RNciy7ADWnIc8QxCbOX5FEKVR3uxVxKHRM sigs.k8s.io/controller-runtime v0.13.0/go.mod h1:Zbz+el8Yg31jubvAEyglRZGdLAjplZl+PgtYNI6WNTI= sigs.k8s.io/json v0.0.0-20220713155537-f223a00ba0e2 h1:iXTIw73aPyC+oRdyqqvVJuloN1p0AC/kzH07hu3NE+k= sigs.k8s.io/json v0.0.0-20220713155537-f223a00ba0e2/go.mod h1:B8JuhiUyNFVKdsE8h686QcCxMaH6HrOAZj4vswFpcB0= +sigs.k8s.io/scheduler-plugins v0.24.9 h1:9oGtwk6uh7mZMCX8+O+PipQzBiRq9d2+E3xq1cn7zbc= +sigs.k8s.io/scheduler-plugins v0.24.9/go.mod h1:0u2b/0SwY2ozDhOD/f1S3e5IbStoDFLUK8yP5dJTaQ8= sigs.k8s.io/structured-merge-diff/v4 v4.0.1/go.mod h1:bJZC9H9iH24zzfZ/41RGcq60oK1F7G282QMXDPYydCw= sigs.k8s.io/structured-merge-diff/v4 v4.2.3 h1:PRbqxJClWWYMNV1dhaG4NsibJbArud9kFxnAMREiWFE= sigs.k8s.io/structured-merge-diff/v4 v4.2.3/go.mod h1:qjx8mGObPmV2aSZepjQjbmb2ihdVs8cGKBraizNC69E= diff --git a/manifests/base/cluster-role.yaml b/manifests/base/cluster-role.yaml index ae3732dafa..eb6004ebb0 100644 --- a/manifests/base/cluster-role.yaml +++ b/manifests/base/cluster-role.yaml @@ -91,3 +91,9 @@ rules: - horizontalpodautoscalers verbs: - "*" + - apiGroups: + - scheduling.sigs.k8s.io + resources: + - podgroups + verbs: + - "*" diff --git a/pkg/common/util/scheduler.go b/pkg/common/util/scheduler.go index c8863fe714..754d18b810 100644 --- a/pkg/common/util/scheduler.go +++ b/pkg/common/util/scheduler.go @@ -16,21 +16,12 @@ package util import commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" -const ( - DefaultGangSchedulerName = "volcano" -) - func IsGangSchedulerSet(replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec, schedulerName string) bool { - if len(schedulerName) == 0 { - schedulerName = DefaultGangSchedulerName - } - for _, spec := range replicas { if spec.Template.Spec.SchedulerName != "" && spec.Template.Spec.SchedulerName == schedulerName { return true } } - return false } diff --git a/pkg/controller.v1/mpi/mpijob.go b/pkg/controller.v1/mpi/mpijob.go index 47fe863140..36dae7a1d8 100644 --- a/pkg/controller.v1/mpi/mpijob.go +++ b/pkg/controller.v1/mpi/mpijob.go @@ -18,11 +18,11 @@ import ( "strings" commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" + kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" + corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/labels" - - kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" ) const ( @@ -69,14 +69,9 @@ const ( // policy is set in pod template. podTemplateRestartPolicyReason = "SettedPodTemplateRestartPolicy" - // gang scheduler name. - gangSchedulerName = "volcano" - // podTemplateSchedulerNameReason is the warning reason when other scheduler name is set // in pod templates with gang-scheduling enabled podTemplateSchedulerNameReason = "SettedPodTemplateSchedulerName" - // gangSchedulingPodGroupAnnotation is the annotation key used by batch schedulers - gangSchedulingPodGroupAnnotation = "scheduling.k8s.io/group-name" // volcanoTaskSpecKey task spec key used in pod annotation when EnableGangScheduling is true volcanoTaskSpecKey = "volcano.sh/task-spec" diff --git a/pkg/controller.v1/mpi/mpijob_controller.go b/pkg/controller.v1/mpi/mpijob_controller.go index be2a60acbe..6c30f1e4be 100644 --- a/pkg/controller.v1/mpi/mpijob_controller.go +++ b/pkg/controller.v1/mpi/mpijob_controller.go @@ -25,12 +25,6 @@ import ( "time" "github.com/go-logr/logr" - commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" - "github.com/kubeflow/common/pkg/controller.v1/common" - "github.com/kubeflow/common/pkg/controller.v1/control" - "github.com/kubeflow/common/pkg/controller.v1/expectation" - commonutil "github.com/kubeflow/common/pkg/util" - "github.com/kubeflow/training-operator/pkg/common/util" "github.com/sirupsen/logrus" corev1 "k8s.io/api/core/v1" rbacv1 "k8s.io/api/rbac/v1" @@ -54,10 +48,15 @@ import ( "sigs.k8s.io/controller-runtime/pkg/predicate" "sigs.k8s.io/controller-runtime/pkg/source" "volcano.sh/apis/pkg/apis/scheduling/v1beta1" - volcanoclient "volcano.sh/apis/pkg/client/clientset/versioned" + commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" + "github.com/kubeflow/common/pkg/controller.v1/common" + "github.com/kubeflow/common/pkg/controller.v1/control" + "github.com/kubeflow/common/pkg/controller.v1/expectation" + commonutil "github.com/kubeflow/common/pkg/util" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" trainingoperatorcommon "github.com/kubeflow/training-operator/pkg/common" + "github.com/kubeflow/training-operator/pkg/common/util" ctlrconfig "github.com/kubeflow/training-operator/pkg/config" ) @@ -69,7 +68,7 @@ const ( labelMPIJobName = "mpi-job-name" ) -func NewReconciler(mgr manager.Manager, enableGangScheduling bool) *MPIJobReconciler { +func NewReconciler(mgr manager.Manager, gangSchedulingSetupFunc common.GangSchedulingSetupFunc) *MPIJobReconciler { r := &MPIJobReconciler{ Client: mgr.GetClient(), Scheme: mgr.GetScheme(), @@ -80,23 +79,23 @@ func NewReconciler(mgr manager.Manager, enableGangScheduling bool) *MPIJobReconc cfg := mgr.GetConfig() kubeClientSet := kubeclientset.NewForConfigOrDie(cfg) - volcanoClientSet := volcanoclient.NewForConfigOrDie(cfg) sharedInformers := informers.NewSharedInformerFactory(kubeClientSet, 0) priorityClassInformer := sharedInformers.Scheduling().V1beta1().PriorityClasses() r.JobController = common.JobController{ Controller: r, Expectations: expectation.NewControllerExpectations(), - Config: common.JobControllerConfiguration{EnableGangScheduling: enableGangScheduling}, WorkQueue: &util.FakeWorkQueue{}, Recorder: r.recorder, KubeClientSet: kubeClientSet, - VolcanoClientSet: volcanoClientSet, PriorityClassLister: priorityClassInformer.Lister(), PriorityClassInformerSynced: priorityClassInformer.Informer().HasSynced, PodControl: control.RealPodControl{KubeClient: kubeClientSet, Recorder: r.recorder}, ServiceControl: control.RealServiceControl{KubeClient: kubeClientSet, Recorder: r.recorder}, } + + gangSchedulingSetupFunc(&r.JobController) + return r } @@ -1000,20 +999,18 @@ func (jc *MPIJobReconciler) newWorker(mpiJob *kubeflowv1.MPIJob, name string) *c // if gang-scheduling is enabled: // 1. if user has specified other scheduler, we report a warning without overriding any fields. // 2. if no SchedulerName is set for pods, then we set the SchedulerName to "volcano". - if jc.Config.EnableGangScheduling { - if util.IsGangSchedulerSet(mpiJob.Spec.MPIReplicaSpecs, gangSchedulerName) { + if jc.Config.EnableGangScheduling() { + if !util.IsGangSchedulerSet(mpiJob.Spec.MPIReplicaSpecs, jc.PodGroupControl.GetSchedulerName()) { errMsg := "Another scheduler is specified when gang-scheduling is enabled and it will not be overwritten" logger.Warning(errMsg) jc.Recorder.Event(mpiJob, corev1.EventTypeWarning, podTemplateSchedulerNameReason, errMsg) - } else { - podSpec.Spec.SchedulerName = gangSchedulerName } - if podSpec.Annotations == nil { - podSpec.Annotations = map[string]string{} + rtWorker := strings.ToLower(string(kubeflowv1.MPIJobReplicaTypeWorker)) + jc.PodGroupControl.DecoratePodTemplateSpec(podSpec, mpiJob, rtWorker) + if jc.PodGroupControl.GetSchedulerName() == "volcano" { + podSpec.Annotations[volcanoTaskSpecKey] = rtWorker } - podSpec.Annotations[gangSchedulingPodGroupAnnotation] = mpiJob.GetName() - podSpec.Annotations[volcanoTaskSpecKey] = strings.ToLower(string(kubeflowv1.MPIJobReplicaTypeWorker)) } return &corev1.Pod{ @@ -1054,20 +1051,18 @@ func (jc *MPIJobReconciler) newLauncher(mpiJob *kubeflowv1.MPIJob, kubectlDelive logger := commonutil.LoggerForReplica(mpiJob, strings.ToLower(string(kubeflowv1.MPIJobReplicaTypeLauncher))) // add SchedulerName to podSpec - if jc.Config.EnableGangScheduling { - if util.IsGangSchedulerSet(mpiJob.Spec.MPIReplicaSpecs, gangSchedulerName) { + if jc.Config.EnableGangScheduling() { + if !util.IsGangSchedulerSet(mpiJob.Spec.MPIReplicaSpecs, jc.PodGroupControl.GetSchedulerName()) { errMsg := "Another scheduler is specified when gang-scheduling is enabled and it will not be overwritten" logger.Warning(errMsg) jc.Recorder.Event(mpiJob, corev1.EventTypeWarning, podTemplateSchedulerNameReason, errMsg) - } else { - podSpec.Spec.SchedulerName = gangSchedulerName } - if podSpec.Annotations == nil { - podSpec.Annotations = map[string]string{} + rt := strings.ToLower(string(kubeflowv1.MPIJobReplicaTypeLauncher)) + jc.PodGroupControl.DecoratePodTemplateSpec(podSpec, mpiJob, rt) + if jc.PodGroupControl.GetSchedulerName() == "volcano" { + podSpec.Annotations[volcanoTaskSpecKey] = rt } - podSpec.Annotations[gangSchedulingPodGroupAnnotation] = mpiJob.GetName() - podSpec.Annotations[volcanoTaskSpecKey] = strings.ToLower(string(kubeflowv1.MPIJobReplicaTypeLauncher)) } podSpec.Spec.ServiceAccountName = launcherName diff --git a/pkg/controller.v1/mpi/suite_test.go b/pkg/controller.v1/mpi/suite_test.go index 1f25f00af0..01cd941e43 100644 --- a/pkg/controller.v1/mpi/suite_test.go +++ b/pkg/controller.v1/mpi/suite_test.go @@ -19,19 +19,19 @@ import ( "path/filepath" "testing" + "github.com/kubeflow/common/pkg/controller.v1/common" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" "github.com/kubeflow/training-operator/pkg/config" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "k8s.io/client-go/kubernetes/scheme" - ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/envtest" logf "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/log/zap" - v1beta1 "volcano.sh/apis/pkg/apis/scheduling/v1beta1" + "volcano.sh/apis/pkg/apis/scheduling/v1beta1" //+kubebuilder:scaffold:imports ) @@ -86,7 +86,8 @@ var _ = BeforeSuite(func() { }) Expect(err).NotTo(HaveOccurred()) - reconciler = NewReconciler(mgr, false) + gangSchedulingSetupFunc := common.GenNonGangSchedulerSetupFunc() + reconciler = NewReconciler(mgr, gangSchedulingSetupFunc) Expect(reconciler.SetupWithManager(mgr, 1)).NotTo(HaveOccurred()) go func() { diff --git a/pkg/controller.v1/mxnet/mxjob_controller.go b/pkg/controller.v1/mxnet/mxjob_controller.go index 84cb1c6e2f..eb0c89250e 100644 --- a/pkg/controller.v1/mxnet/mxjob_controller.go +++ b/pkg/controller.v1/mxnet/mxjob_controller.go @@ -20,12 +20,16 @@ import ( "reflect" "time" - "github.com/go-logr/logr" commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" "github.com/kubeflow/common/pkg/controller.v1/common" "github.com/kubeflow/common/pkg/controller.v1/control" "github.com/kubeflow/common/pkg/controller.v1/expectation" commonutil "github.com/kubeflow/common/pkg/util" + kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" + trainingoperatorcommon "github.com/kubeflow/training-operator/pkg/common" + "github.com/kubeflow/training-operator/pkg/common/util" + + "github.com/go-logr/logr" "github.com/sirupsen/logrus" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/errors" @@ -48,11 +52,6 @@ import ( "sigs.k8s.io/controller-runtime/pkg/predicate" "sigs.k8s.io/controller-runtime/pkg/source" "volcano.sh/apis/pkg/apis/scheduling/v1beta1" - volcanoclient "volcano.sh/apis/pkg/client/clientset/versioned" - - kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" - trainingoperatorcommon "github.com/kubeflow/training-operator/pkg/common" - "github.com/kubeflow/training-operator/pkg/common/util" ) const ( @@ -69,7 +68,7 @@ const ( ) // NewReconciler creates a MXJob Reconciler -func NewReconciler(mgr manager.Manager, enableGangScheduling bool) *MXJobReconciler { +func NewReconciler(mgr manager.Manager, gangSchedulingSetupFunc common.GangSchedulingSetupFunc) *MXJobReconciler { r := &MXJobReconciler{ Client: mgr.GetClient(), Scheme: mgr.GetScheme(), @@ -81,7 +80,6 @@ func NewReconciler(mgr manager.Manager, enableGangScheduling bool) *MXJobReconci // Create clients. cfg := mgr.GetConfig() kubeClientSet := kubeclientset.NewForConfigOrDie(cfg) - volcanoClientSet := volcanoclient.NewForConfigOrDie(cfg) sharedInformers := informers.NewSharedInformerFactory(kubeClientSet, 0) priorityClassInformer := sharedInformers.Scheduling().V1beta1().PriorityClasses() @@ -89,17 +87,17 @@ func NewReconciler(mgr manager.Manager, enableGangScheduling bool) *MXJobReconci r.JobController = common.JobController{ Controller: r, Expectations: expectation.NewControllerExpectations(), - Config: common.JobControllerConfiguration{EnableGangScheduling: enableGangScheduling}, WorkQueue: &util.FakeWorkQueue{}, Recorder: r.Recorder, KubeClientSet: kubeClientSet, - VolcanoClientSet: volcanoClientSet, PriorityClassLister: priorityClassInformer.Lister(), PriorityClassInformerSynced: priorityClassInformer.Informer().HasSynced, PodControl: control.RealPodControl{KubeClient: kubeClientSet, Recorder: r.Recorder}, ServiceControl: control.RealServiceControl{KubeClient: kubeClientSet, Recorder: r.Recorder}, } + gangSchedulingSetupFunc(&r.JobController) + return r } diff --git a/pkg/controller.v1/mxnet/mxnet.go b/pkg/controller.v1/mxnet/mxnet.go index 91b26119c6..26e4240388 100644 --- a/pkg/controller.v1/mxnet/mxnet.go +++ b/pkg/controller.v1/mxnet/mxnet.go @@ -22,9 +22,9 @@ import ( commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" "github.com/kubeflow/common/pkg/controller.v1/common" - corev1 "k8s.io/api/core/v1" - kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" + + corev1 "k8s.io/api/core/v1" ) const ( diff --git a/pkg/controller.v1/mxnet/suite_test.go b/pkg/controller.v1/mxnet/suite_test.go index e78479fe7d..df09b331c3 100644 --- a/pkg/controller.v1/mxnet/suite_test.go +++ b/pkg/controller.v1/mxnet/suite_test.go @@ -27,7 +27,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/envtest" logf "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/log/zap" - v1beta1 "volcano.sh/apis/pkg/apis/scheduling/v1beta1" + "volcano.sh/apis/pkg/apis/scheduling/v1beta1" //+kubebuilder:scaffold:imports ) diff --git a/pkg/controller.v1/paddlepaddle/paddlepaddle_controller.go b/pkg/controller.v1/paddlepaddle/paddlepaddle_controller.go index 547911d01d..8b3d7e6e96 100644 --- a/pkg/controller.v1/paddlepaddle/paddlepaddle_controller.go +++ b/pkg/controller.v1/paddlepaddle/paddlepaddle_controller.go @@ -20,12 +20,16 @@ import ( "strings" "time" - "github.com/go-logr/logr" commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" "github.com/kubeflow/common/pkg/controller.v1/common" "github.com/kubeflow/common/pkg/controller.v1/control" "github.com/kubeflow/common/pkg/controller.v1/expectation" commonutil "github.com/kubeflow/common/pkg/util" + kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" + trainingoperatorcommon "github.com/kubeflow/training-operator/pkg/common" + "github.com/kubeflow/training-operator/pkg/common/util" + + "github.com/go-logr/logr" "github.com/sirupsen/logrus" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/equality" @@ -49,11 +53,6 @@ import ( "sigs.k8s.io/controller-runtime/pkg/predicate" "sigs.k8s.io/controller-runtime/pkg/source" "volcano.sh/apis/pkg/apis/scheduling/v1beta1" - volcanoclient "volcano.sh/apis/pkg/client/clientset/versioned" - - kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" - trainingoperatorcommon "github.com/kubeflow/training-operator/pkg/common" - "github.com/kubeflow/training-operator/pkg/common/util" ) const ( @@ -61,7 +60,7 @@ const ( ) // NewReconciler creates a PaddleJob Reconciler -func NewReconciler(mgr manager.Manager, enableGangScheduling bool) *PaddleJobReconciler { +func NewReconciler(mgr manager.Manager, gangSchedulingSetupFunc common.GangSchedulingSetupFunc) *PaddleJobReconciler { r := &PaddleJobReconciler{ Client: mgr.GetClient(), Scheme: mgr.GetScheme(), @@ -73,7 +72,6 @@ func NewReconciler(mgr manager.Manager, enableGangScheduling bool) *PaddleJobRec // Create clients cfg := mgr.GetConfig() kubeClientSet := kubeclientset.NewForConfigOrDie(cfg) - volcanoClientSet := volcanoclient.NewForConfigOrDie(cfg) sharedInformers := informers.NewSharedInformerFactory(kubeClientSet, 0) priorityClassInformer := sharedInformers.Scheduling().V1beta1().PriorityClasses() @@ -81,17 +79,17 @@ func NewReconciler(mgr manager.Manager, enableGangScheduling bool) *PaddleJobRec r.JobController = common.JobController{ Controller: r, Expectations: expectation.NewControllerExpectations(), - Config: common.JobControllerConfiguration{EnableGangScheduling: enableGangScheduling}, WorkQueue: &util.FakeWorkQueue{}, Recorder: r.recorder, KubeClientSet: kubeClientSet, - VolcanoClientSet: volcanoClientSet, PriorityClassLister: priorityClassInformer.Lister(), PriorityClassInformerSynced: priorityClassInformer.Informer().HasSynced, PodControl: control.RealPodControl{KubeClient: kubeClientSet, Recorder: r.recorder}, ServiceControl: control.RealServiceControl{KubeClient: kubeClientSet, Recorder: r.recorder}, } + gangSchedulingSetupFunc(&r.JobController) + return r } diff --git a/pkg/controller.v1/paddlepaddle/paddlepaddle_controller_suite_test.go b/pkg/controller.v1/paddlepaddle/paddlepaddle_controller_suite_test.go index 6c21b028b7..aebb63e3c9 100644 --- a/pkg/controller.v1/paddlepaddle/paddlepaddle_controller_suite_test.go +++ b/pkg/controller.v1/paddlepaddle/paddlepaddle_controller_suite_test.go @@ -19,6 +19,7 @@ import ( "path/filepath" "testing" + "github.com/kubeflow/common/pkg/controller.v1/common" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" . "github.com/onsi/ginkgo/v2" @@ -30,7 +31,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/envtest" logf "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/log/zap" - v1beta1 "volcano.sh/apis/pkg/apis/scheduling/v1beta1" + "volcano.sh/apis/pkg/apis/scheduling/v1beta1" //+kubebuilder:scaffold:imports ) @@ -81,7 +82,8 @@ var _ = BeforeSuite(func() { }) Expect(err).NotTo(gomega.HaveOccurred()) - r := NewReconciler(mgr, false) + gangSchedulingSetupFunc := common.GenNonGangSchedulerSetupFunc() + r := NewReconciler(mgr, gangSchedulingSetupFunc) Expect(r.SetupWithManager(mgr, 1)).NotTo(gomega.HaveOccurred()) diff --git a/pkg/controller.v1/pytorch/pytorchjob_controller.go b/pkg/controller.v1/pytorch/pytorchjob_controller.go index a19de42aca..6212e36c19 100644 --- a/pkg/controller.v1/pytorch/pytorchjob_controller.go +++ b/pkg/controller.v1/pytorch/pytorchjob_controller.go @@ -20,12 +20,16 @@ import ( "strings" "time" - "github.com/go-logr/logr" commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" "github.com/kubeflow/common/pkg/controller.v1/common" "github.com/kubeflow/common/pkg/controller.v1/control" "github.com/kubeflow/common/pkg/controller.v1/expectation" commonutil "github.com/kubeflow/common/pkg/util" + kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" + trainingoperatorcommon "github.com/kubeflow/training-operator/pkg/common" + "github.com/kubeflow/training-operator/pkg/common/util" + + "github.com/go-logr/logr" "github.com/sirupsen/logrus" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/equality" @@ -49,11 +53,6 @@ import ( "sigs.k8s.io/controller-runtime/pkg/predicate" "sigs.k8s.io/controller-runtime/pkg/source" "volcano.sh/apis/pkg/apis/scheduling/v1beta1" - volcanoclient "volcano.sh/apis/pkg/client/clientset/versioned" - - kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" - trainingoperatorcommon "github.com/kubeflow/training-operator/pkg/common" - "github.com/kubeflow/training-operator/pkg/common/util" ) const ( @@ -61,7 +60,7 @@ const ( ) // NewReconciler creates a PyTorchJob Reconciler -func NewReconciler(mgr manager.Manager, enableGangScheduling bool) *PyTorchJobReconciler { +func NewReconciler(mgr manager.Manager, gangSchedulingSetupFunc common.GangSchedulingSetupFunc) *PyTorchJobReconciler { r := &PyTorchJobReconciler{ Client: mgr.GetClient(), Scheme: mgr.GetScheme(), @@ -73,7 +72,6 @@ func NewReconciler(mgr manager.Manager, enableGangScheduling bool) *PyTorchJobRe // Create clients cfg := mgr.GetConfig() kubeClientSet := kubeclientset.NewForConfigOrDie(cfg) - volcanoClientSet := volcanoclient.NewForConfigOrDie(cfg) sharedInformers := informers.NewSharedInformerFactory(kubeClientSet, 0) priorityClassInformer := sharedInformers.Scheduling().V1beta1().PriorityClasses() @@ -81,17 +79,17 @@ func NewReconciler(mgr manager.Manager, enableGangScheduling bool) *PyTorchJobRe r.JobController = common.JobController{ Controller: r, Expectations: expectation.NewControllerExpectations(), - Config: common.JobControllerConfiguration{EnableGangScheduling: enableGangScheduling}, WorkQueue: &util.FakeWorkQueue{}, Recorder: r.recorder, KubeClientSet: kubeClientSet, - VolcanoClientSet: volcanoClientSet, PriorityClassLister: priorityClassInformer.Lister(), PriorityClassInformerSynced: priorityClassInformer.Informer().HasSynced, PodControl: control.RealPodControl{KubeClient: kubeClientSet, Recorder: r.recorder}, ServiceControl: control.RealServiceControl{KubeClient: kubeClientSet, Recorder: r.recorder}, } + gangSchedulingSetupFunc(&r.JobController) + return r } diff --git a/pkg/controller.v1/pytorch/pytorchjob_controller_suite_test.go b/pkg/controller.v1/pytorch/pytorchjob_controller_suite_test.go index 6111d04e55..301f869cde 100644 --- a/pkg/controller.v1/pytorch/pytorchjob_controller_suite_test.go +++ b/pkg/controller.v1/pytorch/pytorchjob_controller_suite_test.go @@ -19,6 +19,7 @@ import ( "path/filepath" "testing" + "github.com/kubeflow/common/pkg/controller.v1/common" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" "github.com/kubeflow/training-operator/pkg/config" @@ -31,7 +32,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/envtest" logf "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/log/zap" - v1beta1 "volcano.sh/apis/pkg/apis/scheduling/v1beta1" + "volcano.sh/apis/pkg/apis/scheduling/v1beta1" //+kubebuilder:scaffold:imports ) @@ -86,7 +87,8 @@ var _ = BeforeSuite(func() { }) Expect(err).NotTo(gomega.HaveOccurred()) - r := NewReconciler(mgr, false) + gangSchedulingSetupFunc := common.GenNonGangSchedulerSetupFunc() + r := NewReconciler(mgr, gangSchedulingSetupFunc) Expect(r.SetupWithManager(mgr, 1)).NotTo(gomega.HaveOccurred()) diff --git a/pkg/controller.v1/register_controller.go b/pkg/controller.v1/register_controller.go index bbdd0af108..2d8e446367 100644 --- a/pkg/controller.v1/register_controller.go +++ b/pkg/controller.v1/register_controller.go @@ -18,8 +18,7 @@ import ( "fmt" "strings" - "sigs.k8s.io/controller-runtime/pkg/manager" - + "github.com/kubeflow/common/pkg/controller.v1/common" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" mpicontroller "github.com/kubeflow/training-operator/pkg/controller.v1/mpi" mxnetcontroller "github.com/kubeflow/training-operator/pkg/controller.v1/mxnet" @@ -27,30 +26,32 @@ import ( pytorchcontroller "github.com/kubeflow/training-operator/pkg/controller.v1/pytorch" tensorflowcontroller "github.com/kubeflow/training-operator/pkg/controller.v1/tensorflow" xgboostcontroller "github.com/kubeflow/training-operator/pkg/controller.v1/xgboost" + + "sigs.k8s.io/controller-runtime/pkg/manager" ) const ErrTemplateSchemeNotSupported = "scheme %s is not supported yet" -type ReconcilerSetupFunc func(manager manager.Manager, enableGangScheduling bool, controllerThreads int) error +type ReconcilerSetupFunc func(manager manager.Manager, gangSchedulingSetupFunc common.GangSchedulingSetupFunc, controllerThreads int) error var SupportedSchemeReconciler = map[string]ReconcilerSetupFunc{ - kubeflowv1.TFJobKind: func(mgr manager.Manager, enableGangScheduling bool, controllerThreads int) error { - return tensorflowcontroller.NewReconciler(mgr, enableGangScheduling).SetupWithManager(mgr, controllerThreads) + kubeflowv1.TFJobKind: func(mgr manager.Manager, gangSchedulingSetupFunc common.GangSchedulingSetupFunc, controllerThreads int) error { + return tensorflowcontroller.NewReconciler(mgr, gangSchedulingSetupFunc).SetupWithManager(mgr, controllerThreads) }, - kubeflowv1.PytorchJobKind: func(mgr manager.Manager, enableGangScheduling bool, controllerThreads int) error { - return pytorchcontroller.NewReconciler(mgr, enableGangScheduling).SetupWithManager(mgr, controllerThreads) + kubeflowv1.PytorchJobKind: func(mgr manager.Manager, gangSchedulingSetupFunc common.GangSchedulingSetupFunc, controllerThreads int) error { + return pytorchcontroller.NewReconciler(mgr, gangSchedulingSetupFunc).SetupWithManager(mgr, controllerThreads) }, - kubeflowv1.MXJobKind: func(mgr manager.Manager, enableGangScheduling bool, controllerThreads int) error { - return mxnetcontroller.NewReconciler(mgr, enableGangScheduling).SetupWithManager(mgr, controllerThreads) + kubeflowv1.MXJobKind: func(mgr manager.Manager, gangSchedulingSetupFunc common.GangSchedulingSetupFunc, controllerThreads int) error { + return mxnetcontroller.NewReconciler(mgr, gangSchedulingSetupFunc).SetupWithManager(mgr, controllerThreads) }, - kubeflowv1.XGBoostJobKind: func(mgr manager.Manager, enableGangScheduling bool, controllerThreads int) error { - return xgboostcontroller.NewReconciler(mgr, enableGangScheduling).SetupWithManager(mgr, controllerThreads) + kubeflowv1.XGBoostJobKind: func(mgr manager.Manager, gangSchedulingSetupFunc common.GangSchedulingSetupFunc, controllerThreads int) error { + return xgboostcontroller.NewReconciler(mgr, gangSchedulingSetupFunc).SetupWithManager(mgr, controllerThreads) }, - kubeflowv1.MPIJobKind: func(mgr manager.Manager, enableGangScheduling bool, controllerThreads int) error { - return mpicontroller.NewReconciler(mgr, enableGangScheduling).SetupWithManager(mgr, controllerThreads) + kubeflowv1.MPIJobKind: func(mgr manager.Manager, gangSchedulingSetupFunc common.GangSchedulingSetupFunc, controllerThreads int) error { + return mpicontroller.NewReconciler(mgr, gangSchedulingSetupFunc).SetupWithManager(mgr, controllerThreads) }, - kubeflowv1.PaddleJobKind: func(mgr manager.Manager, enableGangScheduling bool, controllerThreads int) error { - return paddlecontroller.NewReconciler(mgr, enableGangScheduling).SetupWithManager(mgr, controllerThreads) + kubeflowv1.PaddleJobKind: func(mgr manager.Manager, gangSchedulingSetupFunc common.GangSchedulingSetupFunc, controllerThreads int) error { + return paddlecontroller.NewReconciler(mgr, gangSchedulingSetupFunc).SetupWithManager(mgr, controllerThreads) }, } diff --git a/pkg/controller.v1/tensorflow/suite_test.go b/pkg/controller.v1/tensorflow/suite_test.go index e5a7cc9c3e..e7b69e3446 100644 --- a/pkg/controller.v1/tensorflow/suite_test.go +++ b/pkg/controller.v1/tensorflow/suite_test.go @@ -21,6 +21,9 @@ import ( "testing" "time" + "github.com/kubeflow/common/pkg/controller.v1/common" + kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" + . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" corev1 "k8s.io/api/core/v1" @@ -30,9 +33,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/envtest" logf "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/log/zap" - v1beta1 "volcano.sh/apis/pkg/apis/scheduling/v1beta1" - - kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" + "volcano.sh/apis/pkg/apis/scheduling/v1beta1" //+kubebuilder:scaffold:imports ) @@ -88,7 +89,8 @@ var _ = BeforeSuite(func() { }) Expect(err).NotTo(HaveOccurred()) - reconciler = NewReconciler(mgr, false) + gangSchedulingSetupFunc := common.GenNonGangSchedulerSetupFunc() + reconciler = NewReconciler(mgr, gangSchedulingSetupFunc) Expect(reconciler.SetupWithManager(mgr, 1)).NotTo(HaveOccurred()) go func() { diff --git a/pkg/controller.v1/tensorflow/tfjob_controller.go b/pkg/controller.v1/tensorflow/tfjob_controller.go index d985179c5a..bd1623138f 100644 --- a/pkg/controller.v1/tensorflow/tfjob_controller.go +++ b/pkg/controller.v1/tensorflow/tfjob_controller.go @@ -21,13 +21,17 @@ import ( "strings" "time" - "github.com/go-logr/logr" commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" "github.com/kubeflow/common/pkg/controller.v1/common" "github.com/kubeflow/common/pkg/controller.v1/control" "github.com/kubeflow/common/pkg/controller.v1/expectation" commonutil "github.com/kubeflow/common/pkg/util" train_util "github.com/kubeflow/common/pkg/util/train" + kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" + trainingoperatorcommon "github.com/kubeflow/training-operator/pkg/common" + "github.com/kubeflow/training-operator/pkg/common/util" + + "github.com/go-logr/logr" "github.com/sirupsen/logrus" corev1 "k8s.io/api/core/v1" v1 "k8s.io/api/core/v1" @@ -50,11 +54,6 @@ import ( "sigs.k8s.io/controller-runtime/pkg/predicate" "sigs.k8s.io/controller-runtime/pkg/source" "volcano.sh/apis/pkg/apis/scheduling/v1beta1" - volcanoclient "volcano.sh/apis/pkg/client/clientset/versioned" - - kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" - trainingoperatorcommon "github.com/kubeflow/training-operator/pkg/common" - "github.com/kubeflow/training-operator/pkg/common/util" ) const ( @@ -75,8 +74,6 @@ const ( // volcanoTaskSpecKey task spec key used in pod annotation when EnableGangScheduling is true volcanoTaskSpecKey = "volcano.sh/task-spec" - // gang scheduler name. - gangSchedulerName = "volcano" // tfConfig is the environment variable name of TensorFlow cluster spec. tfConfig = "TF_CONFIG" // exitedWithCodeReason is the normal reason when the pod is exited because of the exit code. @@ -87,11 +84,9 @@ const ( // podTemplateSchedulerNameReason is the warning reason when other scheduler name is set // in pod templates with gang-scheduling enabled podTemplateSchedulerNameReason = "SettedPodTemplateSchedulerName" - // gangSchedulingPodGroupAnnotation is the annotation key used by batch schedulers - gangSchedulingPodGroupAnnotation = "scheduling.k8s.io/group-name" ) -func NewReconciler(mgr manager.Manager, enableGangScheduling bool) *TFJobReconciler { +func NewReconciler(mgr manager.Manager, gangSchedulingSetupFunc common.GangSchedulingSetupFunc) *TFJobReconciler { r := &TFJobReconciler{ Client: mgr.GetClient(), Scheme: mgr.GetScheme(), @@ -102,24 +97,23 @@ func NewReconciler(mgr manager.Manager, enableGangScheduling bool) *TFJobReconci cfg := mgr.GetConfig() kubeClientSet := kubeclientset.NewForConfigOrDie(cfg) - volcanoClientSet := volcanoclient.NewForConfigOrDie(cfg) sharedInformers := informers.NewSharedInformerFactory(kubeClientSet, 0) priorityClassInformer := sharedInformers.Scheduling().V1beta1().PriorityClasses() r.JobController = common.JobController{ Controller: r, Expectations: expectation.NewControllerExpectations(), - Config: common.JobControllerConfiguration{EnableGangScheduling: enableGangScheduling}, WorkQueue: &util.FakeWorkQueue{}, Recorder: r.recorder, KubeClientSet: kubeClientSet, - VolcanoClientSet: volcanoClientSet, PriorityClassLister: priorityClassInformer.Lister(), PriorityClassInformerSynced: priorityClassInformer.Informer().HasSynced, PodControl: control.RealPodControl{KubeClient: kubeClientSet, Recorder: r.recorder}, ServiceControl: control.RealServiceControl{KubeClient: kubeClientSet, Recorder: r.recorder}, } + gangSchedulingSetupFunc(&r.JobController) + return r } @@ -868,8 +862,10 @@ func (r *TFJobReconciler) createNewPod(tfjob *kubeflowv1.TFJob, rt, index string // if gang-scheduling is enabled: // 1. if user has specified other scheduler, we report a warning without overriding any fields. // 2. if no SchedulerName is set for pods, then we set the SchedulerName to "volcano". - if r.Config.EnableGangScheduling { + if r.Config.EnableGangScheduling() { podSchedulerName := util.GetSchedulerName(replicas) + gangSchedulerName := r.PodGroupControl.GetSchedulerName() + if len(podSchedulerName) == 0 { podTemplate.Spec.SchedulerName = gangSchedulerName } else if strings.Compare(podSchedulerName, gangSchedulerName) != 0 { @@ -878,11 +874,10 @@ func (r *TFJobReconciler) createNewPod(tfjob *kubeflowv1.TFJob, rt, index string r.Recorder.Event(tfjob, v1.EventTypeWarning, podTemplateSchedulerNameReason, errMsg) } - if podTemplate.Annotations == nil { - podTemplate.Annotations = map[string]string{} + r.PodGroupControl.DecoratePodTemplateSpec(podTemplate, tfjob, rt) + if gangSchedulerName == "volcano" { + podTemplate.Annotations[volcanoTaskSpecKey] = rt } - podTemplate.Annotations[gangSchedulingPodGroupAnnotation] = tfjob.GetName() - podTemplate.Annotations[volcanoTaskSpecKey] = rt } err = r.PodControl.CreatePodsWithControllerRef(tfjob.Namespace, podTemplate, tfjob, controllerRef) diff --git a/pkg/controller.v1/xgboost/xgboostjob_controller.go b/pkg/controller.v1/xgboost/xgboostjob_controller.go index 26680e018d..52f7e47b65 100644 --- a/pkg/controller.v1/xgboost/xgboostjob_controller.go +++ b/pkg/controller.v1/xgboost/xgboostjob_controller.go @@ -20,13 +20,17 @@ import ( "reflect" "time" - "github.com/go-logr/logr" commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" "github.com/kubeflow/common/pkg/controller.v1/common" "github.com/kubeflow/common/pkg/controller.v1/control" "github.com/kubeflow/common/pkg/controller.v1/expectation" commonutil "github.com/kubeflow/common/pkg/util" logger "github.com/kubeflow/common/pkg/util" + kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" + trainingoperatorcommon "github.com/kubeflow/training-operator/pkg/common" + "github.com/kubeflow/training-operator/pkg/common/util" + + "github.com/go-logr/logr" "github.com/sirupsen/logrus" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/errors" @@ -50,11 +54,6 @@ import ( "sigs.k8s.io/controller-runtime/pkg/reconcile" "sigs.k8s.io/controller-runtime/pkg/source" "volcano.sh/apis/pkg/apis/scheduling/v1beta1" - volcanoclient "volcano.sh/apis/pkg/client/clientset/versioned" - - kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" - trainingoperatorcommon "github.com/kubeflow/training-operator/pkg/common" - "github.com/kubeflow/training-operator/pkg/common/util" ) const ( @@ -76,7 +75,7 @@ const ( ) // NewReconciler creates a XGBoostJob Reconciler -func NewReconciler(mgr manager.Manager, scheduling bool) *XGBoostJobReconciler { +func NewReconciler(mgr manager.Manager, gangSchedulingSetupFunc common.GangSchedulingSetupFunc) *XGBoostJobReconciler { r := &XGBoostJobReconciler{ Client: mgr.GetClient(), Scheme: mgr.GetScheme(), @@ -88,7 +87,6 @@ func NewReconciler(mgr manager.Manager, scheduling bool) *XGBoostJobReconciler { // Create clients cfg := mgr.GetConfig() kubeClientSet := kubeclientset.NewForConfigOrDie(cfg) - volcanoClientSet := volcanoclient.NewForConfigOrDie(cfg) sharedInformers := informers.NewSharedInformerFactory(kubeClientSet, 0) priorityClassInformer := sharedInformers.Scheduling().V1beta1().PriorityClasses() @@ -96,17 +94,17 @@ func NewReconciler(mgr manager.Manager, scheduling bool) *XGBoostJobReconciler { r.JobController = common.JobController{ Controller: r, Expectations: expectation.NewControllerExpectations(), - Config: common.JobControllerConfiguration{EnableGangScheduling: false}, WorkQueue: &util.FakeWorkQueue{}, Recorder: r.recorder, KubeClientSet: kubeClientSet, - VolcanoClientSet: volcanoClientSet, PriorityClassLister: priorityClassInformer.Lister(), PriorityClassInformerSynced: priorityClassInformer.Informer().HasSynced, PodControl: control.RealPodControl{KubeClient: kubeClientSet, Recorder: r.recorder}, ServiceControl: control.RealServiceControl{KubeClient: kubeClientSet, Recorder: r.recorder}, } + gangSchedulingSetupFunc(&r.JobController) + return r }