diff --git a/pkg/apis/common/v1/interface.go b/pkg/apis/common/v1/interface.go index 661a96c6..39e95d31 100644 --- a/pkg/apis/common/v1/interface.go +++ b/pkg/apis/common/v1/interface.go @@ -1,7 +1,7 @@ package v1 import ( - "k8s.io/api/core/v1" + v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime/schema" ) @@ -44,7 +44,7 @@ type ControllerInterface interface { UpdateJobStatusInApiServer(job interface{}, jobStatus *JobStatus) error // SetClusterSpec sets the cluster spec for the pod - SetClusterSpec(job interface{}, podTemplate *v1.PodTemplateSpec, rtype ReplicaType, index string) error + SetClusterSpec(job interface{}, podTemplate *v1.PodTemplateSpec, rtype, index string) error // Returns the default container name in pod GetDefaultContainerName() string diff --git a/pkg/controller.v1/common/job.go b/pkg/controller.v1/common/job.go index 083f6eea..b3533f02 100644 --- a/pkg/controller.v1/common/job.go +++ b/pkg/controller.v1/common/job.go @@ -300,9 +300,9 @@ func (jc *JobController) ReconcileJobs( // ResetExpectations reset the expectation for creates and deletes of pod/service to zero. func (jc *JobController) ResetExpectations(jobKey string, replicas map[apiv1.ReplicaType]*apiv1.ReplicaSpec) { for rtype := range replicas { - expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, rtype) + expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, string(rtype)) jc.Expectations.SetExpectations(expectationPodsKey, 0, 0) - expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, rtype) + expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, string(rtype)) jc.Expectations.SetExpectations(expectationServicesKey, 0, 0) } } diff --git a/pkg/controller.v1/common/pod.go b/pkg/controller.v1/common/pod.go index b29b469f..c878c71b 100644 --- a/pkg/controller.v1/common/pod.go +++ b/pkg/controller.v1/common/pod.go @@ -18,6 +18,7 @@ import ( "fmt" "reflect" "strconv" + "strings" apiv1 "github.com/kubeflow/common/pkg/apis/common/v1" "github.com/kubeflow/common/pkg/controller.v1/control" @@ -105,7 +106,7 @@ func (jc *JobController) AddPod(obj interface{}) { return } - expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, rType) + expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, string(rType)) jc.Expectations.CreationObserved(expectationPodsKey) // TODO: we may need add backoff here @@ -206,7 +207,7 @@ func (jc *JobController) DeletePod(obj interface{}) { return } - expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, rType) + expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, string(rType)) jc.Expectations.DeletionObserved(expectationPodsKey) deletedPodsCount.Inc() @@ -255,7 +256,7 @@ func (jc *JobController) GetPodsForJob(jobObject interface{}) ([]*v1.Pod, error) } // FilterPodsForReplicaType returns pods belong to a replicaType. -func (jc *JobController) FilterPodsForReplicaType(pods []*v1.Pod, replicaType apiv1.ReplicaType) ([]*v1.Pod, error) { +func (jc *JobController) FilterPodsForReplicaType(pods []*v1.Pod, replicaType string) ([]*v1.Pod, error) { return core.FilterPodsForReplicaType(pods, replicaType) } @@ -271,10 +272,11 @@ func (jc *JobController) ReconcilePods( job interface{}, jobStatus *apiv1.JobStatus, pods []*v1.Pod, - rtype apiv1.ReplicaType, + rType apiv1.ReplicaType, spec *apiv1.ReplicaSpec, replicas map[apiv1.ReplicaType]*apiv1.ReplicaSpec) error { + rt := strings.ToLower(string(rType)) metaObject, ok := job.(metav1.Object) if !ok { return fmt.Errorf("job is not a metav1.Object type") @@ -288,19 +290,19 @@ func (jc *JobController) ReconcilePods( utilruntime.HandleError(fmt.Errorf("couldn't get key for job object %#v: %v", job, err)) return err } - expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, rtype) + expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, rt) // Convert ReplicaType to lower string. - logger := commonutil.LoggerForReplica(metaObject, rtype) + logger := commonutil.LoggerForReplica(metaObject, rt) // Get all pods for the type rt. - pods, err = jc.FilterPodsForReplicaType(pods, rtype) + pods, err = jc.FilterPodsForReplicaType(pods, rt) if err != nil { return err } numReplicas := int(*spec.Replicas) var masterRole bool - initializeReplicaStatuses(jobStatus, rtype) + initializeReplicaStatuses(jobStatus, rType) // GetPodSlices will return enough information here to make decision to add/remove/update resources. // @@ -311,13 +313,13 @@ func (jc *JobController) ReconcilePods( podSlices := jc.GetPodSlices(pods, numReplicas, logger) for index, podSlice := range podSlices { if len(podSlice) > 1 { - logger.Warningf("We have too many pods for %s %d", rtype, index) + logger.Warningf("We have too many pods for %s %d", rt, index) } else if len(podSlice) == 0 { - logger.Infof("Need to create new pod: %s-%d", rtype, index) + logger.Infof("Need to create new pod: %s-%d", rt, index) // check if this replica is the master role - masterRole = jc.Controller.IsMasterRole(replicas, rtype, index) - err = jc.createNewPod(job, rtype, index, spec, masterRole, replicas) + masterRole = jc.Controller.IsMasterRole(replicas, rType, index) + err = jc.createNewPod(job, rt, index, spec, masterRole, replicas) if err != nil { return err } @@ -358,14 +360,14 @@ func (jc *JobController) ReconcilePods( } } - updateJobReplicaStatuses(jobStatus, rtype, pod) + updateJobReplicaStatuses(jobStatus, rType, pod) } } return nil } // createNewPod creates a new pod for the given index and type. -func (jc *JobController) createNewPod(job interface{}, rt apiv1.ReplicaType, index int, spec *apiv1.ReplicaSpec, masterRole bool, +func (jc *JobController) createNewPod(job interface{}, rt string, index int, spec *apiv1.ReplicaSpec, masterRole bool, replicas map[apiv1.ReplicaType]*apiv1.ReplicaSpec) error { metaObject, ok := job.(metav1.Object) diff --git a/pkg/controller.v1/common/service.go b/pkg/controller.v1/common/service.go index 7eb99c41..812d0445 100644 --- a/pkg/controller.v1/common/service.go +++ b/pkg/controller.v1/common/service.go @@ -16,6 +16,7 @@ package common import ( "fmt" "strconv" + "strings" apiv1 "github.com/kubeflow/common/pkg/apis/common/v1" "github.com/kubeflow/common/pkg/controller.v1/control" @@ -74,7 +75,7 @@ func (jc *JobController) AddService(obj interface{}) { return } - expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, rType) + expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, string(rType)) jc.Expectations.CreationObserved(expectationServicesKey) // TODO: we may need add backoff here @@ -139,7 +140,7 @@ func (jc *JobController) GetServicesForJob(jobObject interface{}) ([]*v1.Service } // FilterServicesForReplicaType returns service belong to a replicaType. -func (jc *JobController) FilterServicesForReplicaType(services []*v1.Service, replicaType apiv1.ReplicaType) ([]*v1.Service, error) { +func (jc *JobController) FilterServicesForReplicaType(services []*v1.Service, replicaType string) ([]*v1.Service, error) { return core.FilterServicesForReplicaType(services, replicaType) } @@ -158,9 +159,11 @@ func (jc *JobController) ReconcileServices( rtype apiv1.ReplicaType, spec *apiv1.ReplicaSpec) error { + // Convert ReplicaType to lower string. + rt := strings.ToLower(string(rtype)) replicas := int(*spec.Replicas) // Get all services for the type rt. - services, err := jc.FilterServicesForReplicaType(services, rtype) + services, err := jc.FilterServicesForReplicaType(services, rt) if err != nil { return err } @@ -171,13 +174,13 @@ func (jc *JobController) ReconcileServices( // If replica is 4, return a slice with size 4. [[0],[1],[2],[]], a svc with replica-index 3 will be created. // // If replica is 1, return a slice with size 3. [[0],[1],[2]], svc with replica-index 1 and 2 are out of range and will be deleted. - serviceSlices := jc.GetServiceSlices(services, replicas, commonutil.LoggerForReplica(job, rtype)) + serviceSlices := jc.GetServiceSlices(services, replicas, commonutil.LoggerForReplica(job, rt)) for index, serviceSlice := range serviceSlices { if len(serviceSlice) > 1 { - commonutil.LoggerForReplica(job, rtype).Warningf("We have too many services for %s %d", rtype, index) + commonutil.LoggerForReplica(job, rt).Warningf("We have too many services for %s %d", rtype, index) } else if len(serviceSlice) == 0 { - commonutil.LoggerForReplica(job, rtype).Infof("need to create new service: %s-%d", rtype, index) + commonutil.LoggerForReplica(job, rt).Infof("need to create new service: %s-%d", rtype, index) err = jc.CreateNewService(job, rtype, spec, strconv.Itoa(index)) if err != nil { return err @@ -212,9 +215,10 @@ func (jc *JobController) CreateNewService(job metav1.Object, rtype apiv1.Replica return err } + rt := strings.ToLower(string(rtype)) // Append ReplicaTypeLabelDeprecated and ReplicaIndexLabelDeprecated labels. labels := jc.GenLabels(job.GetName()) - utillabels.SetReplicaType(labels, rtype) + utillabels.SetReplicaType(labels, rt) utillabels.SetReplicaIndexStr(labels, index) ports, err := jc.GetPortsFromJob(spec) @@ -236,13 +240,13 @@ func (jc *JobController) CreateNewService(job metav1.Object, rtype apiv1.Replica service.Spec.Ports = append(service.Spec.Ports, svcPort) } - service.Name = GenGeneralName(job.GetName(), rtype, index) + service.Name = GenGeneralName(job.GetName(), rt, index) service.Labels = labels // Create OwnerReference. controllerRef := jc.GenOwnerReference(job) // Creation is expected when there is no error returned - expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, rtype) + expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, rt) jc.Expectations.RaiseExpectations(expectationServicesKey, 1, 0) err = jc.ServiceControl.CreateServicesWithControllerRef(job.GetNamespace(), service, job.(runtime.Object), controllerRef) diff --git a/pkg/controller.v1/common/util.go b/pkg/controller.v1/common/util.go index 7bd217d2..99229034 100644 --- a/pkg/controller.v1/common/util.go +++ b/pkg/controller.v1/common/util.go @@ -47,8 +47,8 @@ func (p ReplicasPriority) Swap(i, j int) { p[i], p[j] = p[j], p[i] } -func GenGeneralName(jobName string, rtype apiv1.ReplicaType, index string) string { - n := jobName + "-" + strings.ToLower(string(rtype)) + "-" + index +func GenGeneralName(jobName string, rtype string, index string) string { + n := jobName + "-" + strings.ToLower(rtype) + "-" + index return strings.Replace(n, "/", "-", -1) } diff --git a/pkg/controller.v1/common/util_test.go b/pkg/controller.v1/common/util_test.go index 3b87373d..ecaab7aa 100644 --- a/pkg/controller.v1/common/util_test.go +++ b/pkg/controller.v1/common/util_test.go @@ -44,7 +44,7 @@ func TestGenGeneralName(t *testing.T) { } for _, tc := range tcs { - actual := GenGeneralName(tc.key, tc.replicaType, tc.index) + actual := GenGeneralName(tc.key, string(tc.replicaType), tc.index) if actual != tc.expectedName { t.Errorf("Expected name %s, got %s", tc.expectedName, actual) } diff --git a/pkg/controller.v1/expectation/util.go b/pkg/controller.v1/expectation/util.go index 9061000d..b7e5bcb4 100644 --- a/pkg/controller.v1/expectation/util.go +++ b/pkg/controller.v1/expectation/util.go @@ -1,16 +1,15 @@ package expectation import ( - apiv1 "github.com/kubeflow/common/pkg/apis/common/v1" "strings" ) // GenExpectationPodsKey generates an expectation key for pods of a job -func GenExpectationPodsKey(jobKey string, replicaType apiv1.ReplicaType) string { - return jobKey + "/" + strings.ToLower(string(replicaType)) + "/pods" +func GenExpectationPodsKey(jobKey string, replicaType string) string { + return jobKey + "/" + strings.ToLower(replicaType) + "/pods" } // GenExpectationPodsKey generates an expectation key for services of a job -func GenExpectationServicesKey(jobKey string, replicaType apiv1.ReplicaType) string { - return jobKey + "/" + strings.ToLower(string(replicaType)) + "/services" +func GenExpectationServicesKey(jobKey string, replicaType string) string { + return jobKey + "/" + strings.ToLower(replicaType) + "/services" } diff --git a/pkg/core/job.go b/pkg/core/job.go index 577c9b18..45d932a0 100644 --- a/pkg/core/job.go +++ b/pkg/core/job.go @@ -2,6 +2,7 @@ package core import ( "sort" + "strings" "time" log "github.com/sirupsen/logrus" @@ -77,7 +78,7 @@ func PastActiveDeadline(runPolicy *apiv1.RunPolicy, jobStatus apiv1.JobStatus) b // this method applies only to pods when restartPolicy is one of OnFailure, Always or ExitCode func PastBackoffLimit(jobName string, runPolicy *apiv1.RunPolicy, replicas map[apiv1.ReplicaType]*apiv1.ReplicaSpec, pods []*v1.Pod, - podFilterFunc func(pods []*v1.Pod, replicaType apiv1.ReplicaType) ([]*v1.Pod, error)) (bool, error) { + podFilterFunc func(pods []*v1.Pod, replicaType string) ([]*v1.Pod, error)) (bool, error) { if runPolicy.BackoffLimit == nil { return false, nil } @@ -88,7 +89,8 @@ func PastBackoffLimit(jobName string, runPolicy *apiv1.RunPolicy, continue } // Convert ReplicaType to lower string. - pods, err := podFilterFunc(pods, rtype) + rt := strings.ToLower(string(rtype)) + pods, err := podFilterFunc(pods, rt) if err != nil { return false, err } diff --git a/pkg/core/pod.go b/pkg/core/pod.go index 30cdd1e8..0368f5ef 100644 --- a/pkg/core/pod.go +++ b/pkg/core/pod.go @@ -10,16 +10,16 @@ import ( ) // FilterPodsForReplicaType returns pods belong to a replicaType. -func FilterPodsForReplicaType(pods []*v1.Pod, replicaType apiv1.ReplicaType) ([]*v1.Pod, error) { +func FilterPodsForReplicaType(pods []*v1.Pod, replicaType string) ([]*v1.Pod, error) { var result []*v1.Pod selector := labels.SelectorFromValidatedSet(labels.Set{ - apiv1.ReplicaTypeLabel: string(replicaType), + apiv1.ReplicaTypeLabel: replicaType, }) // TODO(#149): Remove deprecated selector. deprecatedSelector := labels.SelectorFromValidatedSet(labels.Set{ - apiv1.ReplicaTypeLabelDeprecated: string(replicaType), + apiv1.ReplicaTypeLabelDeprecated: replicaType, }) for _, pod := range pods { diff --git a/pkg/core/service.go b/pkg/core/service.go index 73739b9f..97d9c300 100644 --- a/pkg/core/service.go +++ b/pkg/core/service.go @@ -11,16 +11,16 @@ import ( ) // FilterServicesForReplicaType returns service belong to a replicaType. -func FilterServicesForReplicaType(services []*v1.Service, replicaType apiv1.ReplicaType) ([]*v1.Service, error) { +func FilterServicesForReplicaType(services []*v1.Service, replicaType string) ([]*v1.Service, error) { var result []*v1.Service selector := labels.SelectorFromValidatedSet(labels.Set{ - apiv1.ReplicaTypeLabel: string(replicaType), + apiv1.ReplicaTypeLabel: replicaType, }) // TODO(#149): Remove deprecated selector. deprecatedSelector := labels.SelectorFromValidatedSet(labels.Set{ - apiv1.ReplicaTypeLabelDeprecated: string(replicaType), + apiv1.ReplicaTypeLabelDeprecated: replicaType, }) for _, service := range services { diff --git a/pkg/core/utils.go b/pkg/core/utils.go index 0f12a3c3..6da7a134 100644 --- a/pkg/core/utils.go +++ b/pkg/core/utils.go @@ -2,8 +2,6 @@ package core import ( "strings" - - commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" ) func MaxInt(x, y int) int { @@ -13,7 +11,7 @@ func MaxInt(x, y int) int { return x } -func GenGeneralName(jobName string, rtype commonv1.ReplicaType, index string) string { - n := jobName + "-" + strings.ToLower(string(rtype)) + "-" + index +func GenGeneralName(jobName string, rtype string, index string) string { + n := jobName + "-" + strings.ToLower(rtype) + "-" + index return strings.Replace(n, "/", "-", -1) } diff --git a/pkg/reconciler.v1/common/gang_volcano.go b/pkg/reconciler.v1/common/gang_volcano.go index 637c5d2e..f9cd6ab1 100644 --- a/pkg/reconciler.v1/common/gang_volcano.go +++ b/pkg/reconciler.v1/common/gang_volcano.go @@ -165,7 +165,7 @@ func (r *VolcanoReconciler) ReconcilePodGroup( } // DecoratePodForGangScheduling decorates the podTemplate before it's used to generate a pod with information for gang-scheduling -func (r *VolcanoReconciler) DecoratePodForGangScheduling(rtype commonv1.ReplicaType, podTemplate *corev1.PodTemplateSpec, job client.Object) { +func (r *VolcanoReconciler) DecoratePodForGangScheduling(rtype string, podTemplate *corev1.PodTemplateSpec, job client.Object) { if podTemplate.Spec.SchedulerName == "" || podTemplate.Spec.SchedulerName == r.GetGangSchedulerName() { podTemplate.Spec.SchedulerName = r.GetGangSchedulerName() } else { diff --git a/pkg/reconciler.v1/common/interface.go b/pkg/reconciler.v1/common/interface.go index 3ef70337..6942d16b 100644 --- a/pkg/reconciler.v1/common/interface.go +++ b/pkg/reconciler.v1/common/interface.go @@ -78,7 +78,7 @@ type GangSchedulingInterface interface { // DecoratePodForGangScheduling SHOULD be overridden if gang scheduler demands Pods associated with PodGroup to be // decorated with specific requests. - DecoratePodForGangScheduling(rtype commonv1.ReplicaType, podTemplate *corev1.PodTemplateSpec, job client.Object) + DecoratePodForGangScheduling(rtype string, podTemplate *corev1.PodTemplateSpec, job client.Object) } // PodInterface defines the abstract interface for Pod related actions, such like get, create or delete Pod @@ -90,14 +90,14 @@ type PodInterface interface { GetDefaultContainerName() string // GenPodName CAN be overridden to customize Pod name. - GenPodName(jobName string, rtype commonv1.ReplicaType, index string) string + GenPodName(jobName string, rtype string, index string) string // GetPodsForJob CAN be overridden to customize how to list all pods with the job. GetPodsForJob(ctx context.Context, job client.Object) ([]*corev1.Pod, error) // FilterPodsForReplicaType CAN be overridden if the linking approach between pods and replicaType changes as this // function filters out pods for specific replica type from all pods associated with the job. - FilterPodsForReplicaType(pods []*corev1.Pod, replicaType commonv1.ReplicaType) ([]*corev1.Pod, error) + FilterPodsForReplicaType(pods []*corev1.Pod, replicaType string) ([]*corev1.Pod, error) // GetPodSlices SHOULD NOT be overridden as it generates pod slices for further pod processing. GetPodSlices(pods []*corev1.Pod, replicas int, logger *logrus.Entry) [][]*corev1.Pod @@ -113,7 +113,7 @@ type PodInterface interface { replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec) error // CreateNewPod CAN be overridden to customize how to create a new pod. - CreateNewPod(job client.Object, rt commonv1.ReplicaType, index string, + CreateNewPod(job client.Object, rt string, index string, spec *commonv1.ReplicaSpec, masterRole bool, replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec) error // DeletePod CAN be overridden to customize how to delete a pod of {name} in namespace {ns}. @@ -121,7 +121,7 @@ type PodInterface interface { // DecoratePod CAN be overridden if customization to the pod is needed. The default implementation applies nothing // to the pod. - DecoratePod(rtype commonv1.ReplicaType, podTemplate *corev1.PodTemplateSpec, job client.Object) + DecoratePod(rtype string, podTemplate *corev1.PodTemplateSpec, job client.Object) } // ServiceInterface defines the abstract interface for Pod related actions, such like get, create or delete Service @@ -136,8 +136,7 @@ type ServiceInterface interface { GetServicesForJob(ctx context.Context, job client.Object) ([]*corev1.Service, error) // FilterServicesForReplicaType CAN be overridden to customize how to filter out services for this Replica Type. - FilterServicesForReplicaType(services []*corev1.Service, - replicaType commonv1.ReplicaType) ([]*corev1.Service, error) + FilterServicesForReplicaType(services []*corev1.Service, replicaType string) ([]*corev1.Service, error) // GetServiceSlices CAN be overridden to customize how to generate service slices. GetServiceSlices(services []*corev1.Service, replicas int, logger *logrus.Entry) [][]*corev1.Service @@ -157,7 +156,7 @@ type ServiceInterface interface { DeleteService(ns string, name string, job client.Object) error // DecorateService CAN be overridden to customize this service right before being created - DecorateService(rtype commonv1.ReplicaType, svc *corev1.Service, job client.Object) + DecorateService(rtype string, svc *corev1.Service, job client.Object) } // JobInterface defines the abstract interface for Pod related actions, such like get, create or delete TFJob, @@ -222,7 +221,7 @@ type JobInterface interface { // IsFlagReplicaTypeForJobStatus CAN be overridden to customize how to determine if this ReplicaType is the // flag ReplicaType for the status of this kind of job - IsFlagReplicaTypeForJobStatus(rtype commonv1.ReplicaType) bool + IsFlagReplicaTypeForJobStatus(rtype string) bool // IsJobSucceeded CAN be overridden to customize how to determine if this job is succeeded. IsJobSucceeded(status commonv1.JobStatus) bool diff --git a/pkg/reconciler.v1/common/job.go b/pkg/reconciler.v1/common/job.go index 54ae2d30..45b1bc43 100644 --- a/pkg/reconciler.v1/common/job.go +++ b/pkg/reconciler.v1/common/job.go @@ -18,10 +18,11 @@ import ( "context" "fmt" "reflect" - ctrl "sigs.k8s.io/controller-runtime" "strings" "time" + ctrl "sigs.k8s.io/controller-runtime" + commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" "github.com/kubeflow/common/pkg/core" commonutil "github.com/kubeflow/common/pkg/util" @@ -306,7 +307,7 @@ func (r *KubeflowJobReconciler) UpdateJobStatus( logrus.Infof("%s=%s, ReplicaType=%s expected=%d, running=%d, succeeded=%d , failed=%d", jobKind, jobNamespacedName, rtype, expected, running, succeeded, failed) - if r.IsFlagReplicaTypeForJobStatus(rtype) { + if r.IsFlagReplicaTypeForJobStatus(string(rtype)) { if running > 0 { msg := fmt.Sprintf("%s %s is running.", jobKind, jobNamespacedName) err := commonutil.UpdateJobConditions(jobStatus, commonv1.JobRunning, commonutil.JobRunningReason, msg) @@ -447,7 +448,7 @@ func (r *KubeflowJobReconciler) CleanupJob(runPolicy *commonv1.RunPolicy, status } // IsFlagReplicaTypeForJobStatus checks if this replicaType is the flag replicaType for the status of KubeflowJob -func (r *KubeflowJobReconciler) IsFlagReplicaTypeForJobStatus(rtype commonv1.ReplicaType) bool { +func (r *KubeflowJobReconciler) IsFlagReplicaTypeForJobStatus(rtype string) bool { logrus.Warnf(WarnDefaultImplementationTemplate, "IsFlagReplicaTypeForJobStatus") return true } diff --git a/pkg/reconciler.v1/common/pod.go b/pkg/reconciler.v1/common/pod.go index f6f69cab..65ed538d 100644 --- a/pkg/reconciler.v1/common/pod.go +++ b/pkg/reconciler.v1/common/pod.go @@ -17,6 +17,7 @@ package common import ( "context" "strconv" + "strings" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" @@ -78,7 +79,7 @@ func BareKubeflowPodReconciler(client client.Client) *KubeflowPodReconciler { } // GenPodName returns the name of the Pod based on jobName, replicaType and its index -func (r *KubeflowPodReconciler) GenPodName(jobName string, rtype commonv1.ReplicaType, index string) string { +func (r *KubeflowPodReconciler) GenPodName(jobName string, rtype string, index string) string { return core.GenGeneralName(jobName, rtype, index) } @@ -110,7 +111,7 @@ func (r *KubeflowPodReconciler) GetPodSlices(pods []*corev1.Pod, replicas int, l } // FilterPodsForReplicaType filters out Pods for this replicaType -func (r *KubeflowPodReconciler) FilterPodsForReplicaType(pods []*corev1.Pod, replicaType commonv1.ReplicaType) ([]*corev1.Pod, error) { +func (r *KubeflowPodReconciler) FilterPodsForReplicaType(pods []*corev1.Pod, replicaType string) ([]*corev1.Pod, error) { return core.FilterPodsForReplicaType(pods, replicaType) } @@ -120,21 +121,22 @@ func (r *KubeflowPodReconciler) ReconcilePods( job client.Object, jobStatus *commonv1.JobStatus, pods []*corev1.Pod, - rtype commonv1.ReplicaType, + rType commonv1.ReplicaType, spec *commonv1.ReplicaSpec, replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec) error { + rt := strings.ToLower(string(rType)) // Convert ReplicaType to lower string. - logger := commonutil.LoggerForReplica(job, rtype) + logger := commonutil.LoggerForReplica(job, rt) // Get all pods for the type rt. - pods, err := r.FilterPodsForReplicaType(pods, rtype) + pods, err := r.FilterPodsForReplicaType(pods, rt) if err != nil { return err } numReplicas := int(*spec.Replicas) var masterRole bool - core.InitializeReplicaStatuses(jobStatus, rtype) + core.InitializeReplicaStatuses(jobStatus, rType) // GetPodSlices will return enough information here to make decision to add/remove/update resources. // @@ -145,13 +147,13 @@ func (r *KubeflowPodReconciler) ReconcilePods( podSlices := r.GetPodSlices(pods, numReplicas, logger) for index, podSlice := range podSlices { if len(podSlice) > 1 { - logger.Warningf("We have too many pods for %s %d", rtype, index) + logger.Warningf("We have too many pods for %s %d", rt, index) } else if len(podSlice) == 0 { - logger.Infof("Need to create new pod: %s-%d", rtype, index) + logger.Infof("Need to create new pod: %s-%d", rt, index) // check if this replica is the master role - masterRole = r.IsMasterRole(replicas, rtype, index) - err = r.CreateNewPod(job, rtype, strconv.Itoa(index), spec, masterRole, replicas) + masterRole = r.IsMasterRole(replicas, commonv1.ReplicaType(rt), index) + err = r.CreateNewPod(job, rt, strconv.Itoa(index), spec, masterRole, replicas) if err != nil { return err } @@ -188,7 +190,7 @@ func (r *KubeflowPodReconciler) ReconcilePods( } } - core.UpdateJobReplicaStatuses(jobStatus, rtype, pod) + core.UpdateJobReplicaStatuses(jobStatus, rType, pod) } } return nil @@ -196,13 +198,13 @@ func (r *KubeflowPodReconciler) ReconcilePods( } // CreateNewPod generate Pods for this job and submits creation request to APIServer -func (r *KubeflowPodReconciler) CreateNewPod(job client.Object, rt commonv1.ReplicaType, index string, +func (r *KubeflowPodReconciler) CreateNewPod(job client.Object, rt string, index string, spec *commonv1.ReplicaSpec, masterRole bool, replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec) error { logger := commonutil.LoggerForReplica(job, rt) podLabels := r.GenLabels(job.GetName()) - podLabels[commonv1.ReplicaTypeLabel] = string(rt) + podLabels[commonv1.ReplicaTypeLabel] = rt podLabels[commonv1.ReplicaIndexLabel] = index if masterRole { podLabels[commonv1.JobRoleLabel] = "master" @@ -270,7 +272,7 @@ func (r *KubeflowPodReconciler) DeletePod(ctx context.Context, ns string, name s } // DecoratePod decorates podTemplate before a Pod is submitted to the APIServer -func (r *KubeflowPodReconciler) DecoratePod(rtype commonv1.ReplicaType, podTemplate *corev1.PodTemplateSpec, job client.Object) { +func (r *KubeflowPodReconciler) DecoratePod(rtype string, podTemplate *corev1.PodTemplateSpec, job client.Object) { // Default implementation applies nothing to podTemplate return } diff --git a/pkg/reconciler.v1/common/pod_test.go b/pkg/reconciler.v1/common/pod_test.go index f2afeec3..bf985410 100644 --- a/pkg/reconciler.v1/common/pod_test.go +++ b/pkg/reconciler.v1/common/pod_test.go @@ -30,7 +30,7 @@ import ( func TestGenPodName(t *testing.T) { type tc struct { testJob *testjobv1.TestJob - testRType commonv1.ReplicaType + testRType string testIndex string expectedName string } @@ -40,7 +40,7 @@ func TestGenPodName(t *testing.T) { tj.SetName("hello-world") return tc{ testJob: tj, - testRType: commonv1.ReplicaType(testjobv1.TestReplicaTypeWorker), + testRType: string(testjobv1.TestReplicaTypeWorker), testIndex: "1", expectedName: "hello-world-worker-1", } @@ -70,7 +70,7 @@ func PodInSlice(pod *corev1.Pod, pods []*corev1.Pod) bool { func TestFilterPodsForReplicaType(t *testing.T) { type tc struct { testPods []*corev1.Pod - testRType commonv1.ReplicaType + testRType string expectedPods []*corev1.Pod } testCase := []tc{ @@ -119,7 +119,7 @@ func TestFilterPodsForReplicaType(t *testing.T) { return tc{ testPods: allPods, - testRType: commonv1.ReplicaType(testjobv1.TestReplicaTypeWorker), + testRType: string(testjobv1.TestReplicaTypeWorker), expectedPods: filteredPods, } }(), diff --git a/pkg/reconciler.v1/common/service.go b/pkg/reconciler.v1/common/service.go index 63e7791e..d92f71ce 100644 --- a/pkg/reconciler.v1/common/service.go +++ b/pkg/reconciler.v1/common/service.go @@ -17,6 +17,7 @@ package common import ( "context" "strconv" + "strings" commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" "github.com/kubeflow/common/pkg/core" @@ -94,7 +95,7 @@ func (r *KubeflowServiceReconciler) GetServicesForJob(ctx context.Context, job c // FilterServicesForReplicaType returns service belong to a replicaType. func (r *KubeflowServiceReconciler) FilterServicesForReplicaType(services []*corev1.Service, - replicaType commonv1.ReplicaType) ([]*corev1.Service, error) { + replicaType string) ([]*corev1.Service, error) { return core.FilterServicesForReplicaType(services, replicaType) } @@ -110,9 +111,11 @@ func (r *KubeflowServiceReconciler) ReconcileServices( rtype commonv1.ReplicaType, spec *commonv1.ReplicaSpec) error { + // Convert ReplicaType to lower string. + rt := strings.ToLower(string(rtype)) replicas := int(*spec.Replicas) // Get all services for the type rt. - services, err := r.FilterServicesForReplicaType(services, rtype) + services, err := r.FilterServicesForReplicaType(services, rt) if err != nil { return err } @@ -123,13 +126,13 @@ func (r *KubeflowServiceReconciler) ReconcileServices( // If replica is 4, return a slice with size 4. [[0],[1],[2],[]], a svc with replica-index 3 will be created. // // If replica is 1, return a slice with size 3. [[0],[1],[2]], svc with replica-index 1 and 2 are out of range and will be deleted. - serviceSlices := r.GetServiceSlices(services, replicas, commonutil.LoggerForReplica(job, rtype)) + serviceSlices := r.GetServiceSlices(services, replicas, commonutil.LoggerForReplica(job, rt)) for index, serviceSlice := range serviceSlices { if len(serviceSlice) > 1 { - commonutil.LoggerForReplica(job, rtype).Warningf("We have too many services for %s %d", rtype, index) + commonutil.LoggerForReplica(job, rt).Warningf("We have too many services for %s %d", rtype, index) } else if len(serviceSlice) == 0 { - commonutil.LoggerForReplica(job, rtype).Infof("need to create new service: %s-%d", rtype, index) + commonutil.LoggerForReplica(job, rt).Infof("need to create new service: %s-%d", rtype, index) err = r.CreateNewService(job, rtype, spec, strconv.Itoa(index)) if err != nil { return err @@ -155,9 +158,11 @@ func (r *KubeflowServiceReconciler) ReconcileServices( func (r *KubeflowServiceReconciler) CreateNewService(job client.Object, rtype commonv1.ReplicaType, spec *commonv1.ReplicaSpec, index string) error { + // Convert ReplicaType to lower string. + rt := strings.ToLower(string(rtype)) // Append ReplicaTypeLabel and ReplicaIndexLabel labels. labels := r.GenLabels(job.GetName()) - labels[commonv1.ReplicaTypeLabel] = string(rtype) + labels[commonv1.ReplicaTypeLabel] = rt labels[commonv1.ReplicaIndexLabel] = index ports, err := r.GetPortsFromJob(spec) @@ -179,7 +184,7 @@ func (r *KubeflowServiceReconciler) CreateNewService(job client.Object, rtype co service.Spec.Ports = append(service.Spec.Ports, svcPort) } - service.Name = core.GenGeneralName(job.GetName(), rtype, index) + service.Name = core.GenGeneralName(job.GetName(), rt, index) service.Namespace = job.GetNamespace() service.Labels = labels // Create OwnerReference. @@ -188,7 +193,7 @@ func (r *KubeflowServiceReconciler) CreateNewService(job client.Object, rtype co return err } - r.DecorateService(rtype, service, job) + r.DecorateService(rt, service, job) err = r.Create(context.Background(), service) if err != nil && errors.IsTimeout(err) { @@ -215,7 +220,7 @@ func (r *KubeflowServiceReconciler) DeleteService(ns string, name string, job cl } // DecorateService decorates the Service before it's submitted to APIServer -func (r *KubeflowServiceReconciler) DecorateService(rtype commonv1.ReplicaType, svc *corev1.Service, job client.Object) { +func (r *KubeflowServiceReconciler) DecorateService(rtype string, svc *corev1.Service, job client.Object) { // Default implementation applies nothing to podTemplate return } diff --git a/pkg/reconciler.v1/common/service_test.go b/pkg/reconciler.v1/common/service_test.go index b6d743c9..0585233a 100644 --- a/pkg/reconciler.v1/common/service_test.go +++ b/pkg/reconciler.v1/common/service_test.go @@ -16,6 +16,7 @@ package common_test import ( "reflect" + "strings" "testing" commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" @@ -60,7 +61,7 @@ func TestCreateNewService(t *testing.T) { commonv1.OperatorNameLabel: "Test Reconciler", commonv1.JobNameLabelDeprecated: jobName, commonv1.JobNameLabel: jobName, - commonv1.ReplicaTypeLabel: string(testjobv1.TestReplicaTypeWorker), + commonv1.ReplicaTypeLabel: strings.ToLower(string(testjobv1.TestReplicaTypeWorker)), commonv1.ReplicaIndexLabel: idx, }, }, diff --git a/pkg/util/labels/labels.go b/pkg/util/labels/labels.go index 255884a9..8520c0c3 100644 --- a/pkg/util/labels/labels.go +++ b/pkg/util/labels/labels.go @@ -54,9 +54,9 @@ func ReplicaType(labels map[string]string) (v1.ReplicaType, error) { return v1.ReplicaType(v), nil } -func SetReplicaType(labels map[string]string, rt v1.ReplicaType) { - labels[v1.ReplicaTypeLabel] = string(rt) - labels[v1.ReplicaTypeLabelDeprecated] = string(rt) +func SetReplicaType(labels map[string]string, rt string) { + labels[v1.ReplicaTypeLabel] = rt + labels[v1.ReplicaTypeLabelDeprecated] = rt } func HasKnownLabels(labels map[string]string, groupName string) bool { diff --git a/pkg/util/logger.go b/pkg/util/logger.go index a9719fce..8d523fe1 100644 --- a/pkg/util/logger.go +++ b/pkg/util/logger.go @@ -15,16 +15,15 @@ package util import ( - apiv1 "github.com/kubeflow/common/pkg/apis/common/v1" "strings" log "github.com/sirupsen/logrus" - "k8s.io/api/core/v1" + v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1unstructured "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" ) -func LoggerForReplica(job metav1.Object, rtype apiv1.ReplicaType) *log.Entry { +func LoggerForReplica(job metav1.Object, rtype string) *log.Entry { return log.WithFields(log.Fields{ // We use job to match the key used in controller.go // Its more common in K8s to use a period to indicate namespace.name. So that's what we use. diff --git a/test_job/controller.v1/test_job/test_job_controller.go b/test_job/controller.v1/test_job/test_job_controller.go index 41978ce8..e40af065 100644 --- a/test_job/controller.v1/test_job/test_job_controller.go +++ b/test_job/controller.v1/test_job/test_job_controller.go @@ -61,7 +61,7 @@ func (t *TestJobController) UpdateJobStatusInApiServer(job interface{}, jobStatu return nil } -func (t *TestJobController) SetClusterSpec(job interface{}, podTemplate *corev1.PodTemplateSpec, rtype commonv1.ReplicaType, index string) error { +func (t *TestJobController) SetClusterSpec(job interface{}, podTemplate *corev1.PodTemplateSpec, rtype, index string) error { return nil }