Skip to content

Commit

Permalink
Add scheduler plugin for coscheduling
Browse files (browse the repository at this point in the history)
Signed-off-by: Yuki Iwai <[email protected]>
Co-authored-by: Wang Zhang <[email protected]>
Signed-off-by: Yuki Iwai <[email protected]>
  • Loading branch information
tenzen-y and zw0610 committed Jan 15, 2023
1 parent ddf372c commit f9fa8ac
Show file tree
Hide file tree
Showing 19 changed files with 149 additions and 149 deletions.
21 changes: 17 additions & 4 deletions cmd/training-operator.v1/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"flag"
"fmt"
"os"
"strings"

"go.uber.org/zap/zapcore"
"k8s.io/apimachinery/pkg/runtime"
Expand All @@ -29,8 +30,11 @@ import (
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/healthz"
"sigs.k8s.io/controller-runtime/pkg/log/zap"
schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1"
"volcano.sh/apis/pkg/apis/scheduling/v1beta1"
volcanoclient "volcano.sh/apis/pkg/client/clientset/versioned"

"github.com/kubeflow/common/pkg/controller.v1/common"
commonutil "github.com/kubeflow/common/pkg/util"
kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
"github.com/kubeflow/training-operator/pkg/config"
Expand All @@ -47,6 +51,7 @@ func init() {
utilruntime.Must(clientgoscheme.AddToScheme(scheme))
utilruntime.Must(kubeflowv1.AddToScheme(scheme))
utilruntime.Must(v1beta1.AddToScheme(scheme))
utilruntime.Must(schedulerpluginsv1alpha1.AddToScheme(scheme))
//+kubebuilder:scaffold:scheme
}

Expand All @@ -56,7 +61,6 @@ func main() {
var leaderElectionID string
var probeAddr string
var enabledSchemes controllerv1.EnabledSchemes
var enableGangScheduling bool
var gangSchedulerName string
var namespace string
var monitoringPort int
Expand All @@ -69,8 +73,7 @@ func main() {
flag.StringVar(&leaderElectionID, "leader-election-id", "1ca428e5.training-operator.kubeflow.org", "The ID for leader election.")
flag.Var(&enabledSchemes, "enable-scheme", "Enable scheme(s) as --enable-scheme=tfjob --enable-scheme=pytorchjob, case insensitive."+
" Now supporting TFJob, PyTorchJob, MXNetJob, XGBoostJob, PaddleJob. By default, all supported schemes will be enabled.")
flag.BoolVar(&enableGangScheduling, "enable-gang-scheduling", false, "Set true to enable gang scheduling")
flag.StringVar(&gangSchedulerName, "gang-scheduler-name", "volcano", "The scheduler to gang-schedule kubeflow jobs, defaults to volcano")
flag.StringVar(&gangSchedulerName, "gang-scheduler-name", "none", "The scheduler to gang-schedule kubeflow jobs, defaults to none")
flag.StringVar(&namespace, "namespace", os.Getenv(commonutil.EnvKubeflowNamespace), "The namespace to monitor kubeflow jobs. If unset, it monitors all namespaces cluster-wide."+
"If set, it only monitors kubeflow jobs in the given namespace.")
flag.IntVar(&monitoringPort, "monitoring-port", 9443, "Endpoint port for displaying monitoring metrics. "+
Expand Down Expand Up @@ -110,6 +113,16 @@ func main() {
os.Exit(1)
}

// Prepare GangSchedulingSetupFunc
gangSchedulingSetupFunc := common.GenNonGangSchedulerSetupFunc()
if strings.EqualFold(gangSchedulerName, string(common.GangSchedulerVolcano)) {
cfg := mgr.GetConfig()
volcanoClientSet := volcanoclient.NewForConfigOrDie(cfg)
gangSchedulingSetupFunc = common.GenVolcanoSetupFunc(volcanoClientSet)
} else if strings.EqualFold(gangSchedulerName, string(common.GangSchedulerSchedulerPlugins)) {
gangSchedulingSetupFunc = common.GenSchedulerPluginsSetupFunc(mgr.GetClient())
}

// TODO: We need a general manager. all rest reconciler addsToManager
// Based on the user configuration, we start different controllers
if enabledSchemes.Empty() {
Expand All @@ -122,7 +135,7 @@ func main() {
"scheme not supported", "scheme", s)
os.Exit(1)
}
if err = setupFunc(mgr, enableGangScheduling, controllerThreads); err != nil {
if err = setupFunc(mgr, gangSchedulingSetupFunc, controllerThreads); err != nil {
setupLog.Error(err, "unable to create controller", "controller", s)
os.Exit(1)
}
Expand Down
11 changes: 7 additions & 4 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ require (
k8s.io/kube-openapi v0.0.0-20220803162953-67bda5d908f1
k8s.io/utils v0.0.0-20220728103510-ee6ede2d64ed
sigs.k8s.io/controller-runtime v0.13.0
sigs.k8s.io/scheduler-plugins v0.24.9
sigs.k8s.io/yaml v1.3.0
volcano.sh/apis v1.2.0-k8s1.19.6
)
Expand Down Expand Up @@ -68,11 +69,11 @@ require (
go.uber.org/multierr v1.8.0 // indirect
golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e // indirect
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect
golang.org/x/net v0.0.0-20220722155237-a158d28d115b // indirect
golang.org/x/net v0.3.1-0.20221206200815-1e63c2f08a10 // indirect
golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8 // indirect
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f // indirect
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 // indirect
golang.org/x/text v0.3.7 // indirect
golang.org/x/sys v0.3.0 // indirect
golang.org/x/term v0.3.0 // indirect
golang.org/x/text v0.5.0 // indirect
golang.org/x/time v0.0.0-20220609170525-579cf78fd858 // indirect
golang.org/x/tools v0.1.12 // indirect
gomodules.xyz/jsonpatch/v2 v2.2.0 // indirect
Expand All @@ -88,3 +89,5 @@ require (
sigs.k8s.io/json v0.0.0-20220713155537-f223a00ba0e2 // indirect
sigs.k8s.io/structured-merge-diff/v4 v4.2.3 // indirect
)

replace github.com/kubeflow/common v0.4.5 => github.com/tenzen-y/common v0.0.0-20230115192234-865b54409a0b
22 changes: 12 additions & 10 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,6 @@ github.com/kr/pty v1.1.5/go.mod h1:9r2w37qlBe7rQ6e1fg1S/9xpWHSnaqNdHD3WcMdbPDA=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kubeflow/common v0.4.5 h1:W7p+s/4Za1UzIgKP2Z6ormEvsUVHykeaXaOuu8+UgpI=
github.com/kubeflow/common v0.4.5/go.mod h1:di43u2m7DyuwnRDb7Kwz1nmA/nhpjnQ+K+gWCV/SPZk=
github.com/mailru/easyjson v0.0.0-20160728113105-d5b7844b561a/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
Expand Down Expand Up @@ -427,6 +425,8 @@ github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk=
github.com/tenzen-y/common v0.0.0-20230115192234-865b54409a0b h1:Dc2HDPNTzmpSOuOo1d8LeIIC0Sr0InPhdYTqYEEp2CU=
github.com/tenzen-y/common v0.0.0-20230115192234-865b54409a0b/go.mod h1:kI2gL98Ts9uJLHiKZzJZn7chVd9gUD85G+arE4M+7Lo=
github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
github.com/urfave/cli v1.20.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA=
Expand Down Expand Up @@ -555,8 +555,8 @@ golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96b
golang.org/x/net v0.0.0-20210503060351-7fd8e65b6420/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b h1:PxfKdU9lEEDYjdIzOtC4qFWgkU2rGHdKlKowJSMN9h0=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.3.1-0.20221206200815-1e63c2f08a10 h1:Frnccbp+ok2GkUS2tC84yAq/U9Vg+0sIO7aRL3T4Xnc=
golang.org/x/net v0.3.1-0.20221206200815-1e63c2f08a10/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
Expand Down Expand Up @@ -648,11 +648,11 @@ golang.org/x/sys v0.0.0-20210823070655-63515b42dcdf/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210908233432-aa78b53d3365/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f h1:v4INt8xihDGvnrfjMDVXGxw9wrfxYyCjk0KbXjhR55s=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ=
golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.3.0 h1:qoo4akIqOcDME5bhc/NgxUdovd6BSS2uMsVjB56q1xI=
golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
Expand All @@ -661,8 +661,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.5.0 h1:OLmvp0KP+FVG99Ct/qFiL/Fhk4zp4QQnZ7b2U+5piUM=
golang.org/x/text v0.5.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
Expand Down Expand Up @@ -954,6 +954,8 @@ sigs.k8s.io/controller-runtime v0.13.0 h1:iqa5RNciy7ADWnIc8QxCbOX5FEKVR3uxVxKHRM
sigs.k8s.io/controller-runtime v0.13.0/go.mod h1:Zbz+el8Yg31jubvAEyglRZGdLAjplZl+PgtYNI6WNTI=
sigs.k8s.io/json v0.0.0-20220713155537-f223a00ba0e2 h1:iXTIw73aPyC+oRdyqqvVJuloN1p0AC/kzH07hu3NE+k=
sigs.k8s.io/json v0.0.0-20220713155537-f223a00ba0e2/go.mod h1:B8JuhiUyNFVKdsE8h686QcCxMaH6HrOAZj4vswFpcB0=
sigs.k8s.io/scheduler-plugins v0.24.9 h1:9oGtwk6uh7mZMCX8+O+PipQzBiRq9d2+E3xq1cn7zbc=
sigs.k8s.io/scheduler-plugins v0.24.9/go.mod h1:0u2b/0SwY2ozDhOD/f1S3e5IbStoDFLUK8yP5dJTaQ8=
sigs.k8s.io/structured-merge-diff/v4 v4.0.1/go.mod h1:bJZC9H9iH24zzfZ/41RGcq60oK1F7G282QMXDPYydCw=
sigs.k8s.io/structured-merge-diff/v4 v4.2.3 h1:PRbqxJClWWYMNV1dhaG4NsibJbArud9kFxnAMREiWFE=
sigs.k8s.io/structured-merge-diff/v4 v4.2.3/go.mod h1:qjx8mGObPmV2aSZepjQjbmb2ihdVs8cGKBraizNC69E=
Expand Down
6 changes: 6 additions & 0 deletions manifests/base/cluster-role.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,9 @@ rules:
- horizontalpodautoscalers
verbs:
- "*"
- apiGroups:
- scheduling.sigs.k8s.io
resources:
- podgroups
verbs:
- "*"
9 changes: 0 additions & 9 deletions pkg/common/util/scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,12 @@ package util

import commonv1 "github.com/kubeflow/common/pkg/apis/common/v1"

const (
DefaultGangSchedulerName = "volcano"
)

func IsGangSchedulerSet(replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec, schedulerName string) bool {
if len(schedulerName) == 0 {
schedulerName = DefaultGangSchedulerName
}

for _, spec := range replicas {
if spec.Template.Spec.SchedulerName != "" && spec.Template.Spec.SchedulerName == schedulerName {
return true
}
}

return false
}

Expand Down
9 changes: 2 additions & 7 deletions pkg/controller.v1/mpi/mpijob.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ import (
"strings"

commonv1 "github.com/kubeflow/common/pkg/apis/common/v1"
kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"

corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/labels"

kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
)

const (
Expand Down Expand Up @@ -69,14 +69,9 @@ const (
// policy is set in pod template.
podTemplateRestartPolicyReason = "SettedPodTemplateRestartPolicy"

// gang scheduler name.
gangSchedulerName = "volcano"

// podTemplateSchedulerNameReason is the warning reason when other scheduler name is set
// in pod templates with gang-scheduling enabled
podTemplateSchedulerNameReason = "SettedPodTemplateSchedulerName"
// gangSchedulingPodGroupAnnotation is the annotation key used by batch schedulers
gangSchedulingPodGroupAnnotation = "scheduling.k8s.io/group-name"

// volcanoTaskSpecKey task spec key used in pod annotation when EnableGangScheduling is true
volcanoTaskSpecKey = "volcano.sh/task-spec"
Expand Down
49 changes: 22 additions & 27 deletions pkg/controller.v1/mpi/mpijob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,6 @@ import (
"time"

"github.com/go-logr/logr"
commonv1 "github.com/kubeflow/common/pkg/apis/common/v1"
"github.com/kubeflow/common/pkg/controller.v1/common"
"github.com/kubeflow/common/pkg/controller.v1/control"
"github.com/kubeflow/common/pkg/controller.v1/expectation"
commonutil "github.com/kubeflow/common/pkg/util"
"github.com/kubeflow/training-operator/pkg/common/util"
"github.com/sirupsen/logrus"
corev1 "k8s.io/api/core/v1"
rbacv1 "k8s.io/api/rbac/v1"
Expand All @@ -54,10 +48,15 @@ import (
"sigs.k8s.io/controller-runtime/pkg/predicate"
"sigs.k8s.io/controller-runtime/pkg/source"
"volcano.sh/apis/pkg/apis/scheduling/v1beta1"
volcanoclient "volcano.sh/apis/pkg/client/clientset/versioned"

commonv1 "github.com/kubeflow/common/pkg/apis/common/v1"
"github.com/kubeflow/common/pkg/controller.v1/common"
"github.com/kubeflow/common/pkg/controller.v1/control"
"github.com/kubeflow/common/pkg/controller.v1/expectation"
commonutil "github.com/kubeflow/common/pkg/util"
kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
trainingoperatorcommon "github.com/kubeflow/training-operator/pkg/common"
"github.com/kubeflow/training-operator/pkg/common/util"
ctlrconfig "github.com/kubeflow/training-operator/pkg/config"
)

Expand All @@ -69,7 +68,7 @@ const (
labelMPIJobName = "mpi-job-name"
)

func NewReconciler(mgr manager.Manager, enableGangScheduling bool) *MPIJobReconciler {
func NewReconciler(mgr manager.Manager, gangSchedulingSetupFunc common.GangSchedulingSetupFunc) *MPIJobReconciler {
r := &MPIJobReconciler{
Client: mgr.GetClient(),
Scheme: mgr.GetScheme(),
Expand All @@ -80,23 +79,23 @@ func NewReconciler(mgr manager.Manager, enableGangScheduling bool) *MPIJobReconc

cfg := mgr.GetConfig()
kubeClientSet := kubeclientset.NewForConfigOrDie(cfg)
volcanoClientSet := volcanoclient.NewForConfigOrDie(cfg)
sharedInformers := informers.NewSharedInformerFactory(kubeClientSet, 0)
priorityClassInformer := sharedInformers.Scheduling().V1beta1().PriorityClasses()

r.JobController = common.JobController{
Controller: r,
Expectations: expectation.NewControllerExpectations(),
Config: common.JobControllerConfiguration{EnableGangScheduling: enableGangScheduling},
WorkQueue: &util.FakeWorkQueue{},
Recorder: r.recorder,
KubeClientSet: kubeClientSet,
VolcanoClientSet: volcanoClientSet,
PriorityClassLister: priorityClassInformer.Lister(),
PriorityClassInformerSynced: priorityClassInformer.Informer().HasSynced,
PodControl: control.RealPodControl{KubeClient: kubeClientSet, Recorder: r.recorder},
ServiceControl: control.RealServiceControl{KubeClient: kubeClientSet, Recorder: r.recorder},
}

gangSchedulingSetupFunc(&r.JobController)

return r
}

Expand Down Expand Up @@ -1000,20 +999,18 @@ func (jc *MPIJobReconciler) newWorker(mpiJob *kubeflowv1.MPIJob, name string) *c
// if gang-scheduling is enabled:
// 1. if user has specified other scheduler, we report a warning without overriding any fields.
// 2. if no SchedulerName is set for pods, then we set the SchedulerName to "volcano".
if jc.Config.EnableGangScheduling {
if util.IsGangSchedulerSet(mpiJob.Spec.MPIReplicaSpecs, gangSchedulerName) {
if jc.Config.EnableGangScheduling() {
if !util.IsGangSchedulerSet(mpiJob.Spec.MPIReplicaSpecs, jc.PodGroupControl.GetSchedulerName()) {
errMsg := "Another scheduler is specified when gang-scheduling is enabled and it will not be overwritten"
logger.Warning(errMsg)
jc.Recorder.Event(mpiJob, corev1.EventTypeWarning, podTemplateSchedulerNameReason, errMsg)
} else {
podSpec.Spec.SchedulerName = gangSchedulerName
}

if podSpec.Annotations == nil {
podSpec.Annotations = map[string]string{}
rtWorker := strings.ToLower(string(kubeflowv1.MPIJobReplicaTypeWorker))
jc.PodGroupControl.DecoratePodTemplateSpec(podSpec, mpiJob, rtWorker)
if jc.PodGroupControl.GetSchedulerName() == "volcano" {
podSpec.Annotations[volcanoTaskSpecKey] = rtWorker
}
podSpec.Annotations[gangSchedulingPodGroupAnnotation] = mpiJob.GetName()
podSpec.Annotations[volcanoTaskSpecKey] = strings.ToLower(string(kubeflowv1.MPIJobReplicaTypeWorker))
}

return &corev1.Pod{
Expand Down Expand Up @@ -1054,20 +1051,18 @@ func (jc *MPIJobReconciler) newLauncher(mpiJob *kubeflowv1.MPIJob, kubectlDelive

logger := commonutil.LoggerForReplica(mpiJob, strings.ToLower(string(kubeflowv1.MPIJobReplicaTypeLauncher)))
// add SchedulerName to podSpec
if jc.Config.EnableGangScheduling {
if util.IsGangSchedulerSet(mpiJob.Spec.MPIReplicaSpecs, gangSchedulerName) {
if jc.Config.EnableGangScheduling() {
if !util.IsGangSchedulerSet(mpiJob.Spec.MPIReplicaSpecs, jc.PodGroupControl.GetSchedulerName()) {
errMsg := "Another scheduler is specified when gang-scheduling is enabled and it will not be overwritten"
logger.Warning(errMsg)
jc.Recorder.Event(mpiJob, corev1.EventTypeWarning, podTemplateSchedulerNameReason, errMsg)
} else {
podSpec.Spec.SchedulerName = gangSchedulerName
}

if podSpec.Annotations == nil {
podSpec.Annotations = map[string]string{}
rt := strings.ToLower(string(kubeflowv1.MPIJobReplicaTypeLauncher))
jc.PodGroupControl.DecoratePodTemplateSpec(podSpec, mpiJob, rt)
if jc.PodGroupControl.GetSchedulerName() == "volcano" {
podSpec.Annotations[volcanoTaskSpecKey] = rt
}
podSpec.Annotations[gangSchedulingPodGroupAnnotation] = mpiJob.GetName()
podSpec.Annotations[volcanoTaskSpecKey] = strings.ToLower(string(kubeflowv1.MPIJobReplicaTypeLauncher))
}

podSpec.Spec.ServiceAccountName = launcherName
Expand Down
7 changes: 4 additions & 3 deletions pkg/controller.v1/mpi/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,19 @@ import (
"path/filepath"
"testing"

"github.com/kubeflow/common/pkg/controller.v1/common"
kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
"github.com/kubeflow/training-operator/pkg/config"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"k8s.io/client-go/kubernetes/scheme"

ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/envtest"
logf "sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/controller-runtime/pkg/log/zap"
v1beta1 "volcano.sh/apis/pkg/apis/scheduling/v1beta1"
"volcano.sh/apis/pkg/apis/scheduling/v1beta1"
//+kubebuilder:scaffold:imports
)

Expand Down Expand Up @@ -86,7 +86,8 @@ var _ = BeforeSuite(func() {
})
Expect(err).NotTo(HaveOccurred())

reconciler = NewReconciler(mgr, false)
gangSchedulingSetupFunc := common.GenNonGangSchedulerSetupFunc()
reconciler = NewReconciler(mgr, gangSchedulingSetupFunc)
Expect(reconciler.SetupWithManager(mgr, 1)).NotTo(HaveOccurred())

go func() {
Expand Down
Loading

0 comments on commit f9fa8ac

Please sign in to comment.