diff --git a/pkg/apis/common/v1/interface.go b/pkg/apis/common/v1/interface.go index 661a96c6..255d7b6e 100644 --- a/pkg/apis/common/v1/interface.go +++ b/pkg/apis/common/v1/interface.go @@ -44,7 +44,7 @@ type ControllerInterface interface { UpdateJobStatusInApiServer(job interface{}, jobStatus *JobStatus) error // SetClusterSpec sets the cluster spec for the pod - SetClusterSpec(job interface{}, podTemplate *v1.PodTemplateSpec, rtype ReplicaType, index string) error + SetClusterSpec(job interface{}, podTemplate *v1.PodTemplateSpec, rtype, index string) error // Returns the default container name in pod GetDefaultContainerName() string diff --git a/pkg/controller.v1/common/job.go b/pkg/controller.v1/common/job.go index 9f6cf428..7082186e 100644 --- a/pkg/controller.v1/common/job.go +++ b/pkg/controller.v1/common/job.go @@ -3,6 +3,7 @@ package common import ( "fmt" "reflect" + "strings" "time" apiv1 "github.com/kubeflow/common/pkg/apis/common/v1" @@ -300,9 +301,10 @@ func (jc *JobController) ReconcileJobs( // ResetExpectations reset the expectation for creates and deletes of pod/service to zero. func (jc *JobController) ResetExpectations(jobKey string, replicas map[apiv1.ReplicaType]*apiv1.ReplicaSpec) { for rtype := range replicas { - expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, rtype) + rt := strings.ToLower(string(rtype)) + expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, rt) jc.Expectations.SetExpectations(expectationPodsKey, 0, 0) - expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, rtype) + expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, rt) jc.Expectations.SetExpectations(expectationServicesKey, 0, 0) } } diff --git a/pkg/controller.v1/common/pod.go b/pkg/controller.v1/common/pod.go index b29b469f..f190eb59 100644 --- a/pkg/controller.v1/common/pod.go +++ b/pkg/controller.v1/common/pod.go @@ -18,6 +18,7 @@ import ( "fmt" "reflect" "strconv" + "strings" apiv1 "github.com/kubeflow/common/pkg/apis/common/v1" "github.com/kubeflow/common/pkg/controller.v1/control" @@ -113,7 +114,6 @@ func (jc *JobController) AddPod(obj interface{}) { return } - } // When a pod is updated, figure out what job is managing it and wake it up. @@ -255,7 +255,7 @@ func (jc *JobController) GetPodsForJob(jobObject interface{}) ([]*v1.Pod, error) } // FilterPodsForReplicaType returns pods belong to a replicaType. -func (jc *JobController) FilterPodsForReplicaType(pods []*v1.Pod, replicaType apiv1.ReplicaType) ([]*v1.Pod, error) { +func (jc *JobController) FilterPodsForReplicaType(pods []*v1.Pod, replicaType string) ([]*v1.Pod, error) { return core.FilterPodsForReplicaType(pods, replicaType) } @@ -288,12 +288,14 @@ func (jc *JobController) ReconcilePods( utilruntime.HandleError(fmt.Errorf("couldn't get key for job object %#v: %v", job, err)) return err } - expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, rtype) // Convert ReplicaType to lower string. - logger := commonutil.LoggerForReplica(metaObject, rtype) + rt := strings.ToLower(string(rtype)) + expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, rt) + + logger := commonutil.LoggerForReplica(metaObject, rt) // Get all pods for the type rt. - pods, err = jc.FilterPodsForReplicaType(pods, rtype) + pods, err = jc.FilterPodsForReplicaType(pods, rt) if err != nil { return err } @@ -311,13 +313,13 @@ func (jc *JobController) ReconcilePods( podSlices := jc.GetPodSlices(pods, numReplicas, logger) for index, podSlice := range podSlices { if len(podSlice) > 1 { - logger.Warningf("We have too many pods for %s %d", rtype, index) + logger.Warningf("We have too many pods for %s %d", rt, index) } else if len(podSlice) == 0 { - logger.Infof("Need to create new pod: %s-%d", rtype, index) + logger.Infof("Need to create new pod: %s-%d", rt, index) // check if this replica is the master role masterRole = jc.Controller.IsMasterRole(replicas, rtype, index) - err = jc.createNewPod(job, rtype, index, spec, masterRole, replicas) + err = jc.createNewPod(job, rt, index, spec, masterRole, replicas) if err != nil { return err } @@ -365,7 +367,7 @@ func (jc *JobController) ReconcilePods( } // createNewPod creates a new pod for the given index and type. -func (jc *JobController) createNewPod(job interface{}, rt apiv1.ReplicaType, index int, spec *apiv1.ReplicaSpec, masterRole bool, +func (jc *JobController) createNewPod(job interface{}, rt string, index int, spec *apiv1.ReplicaSpec, masterRole bool, replicas map[apiv1.ReplicaType]*apiv1.ReplicaSpec) error { metaObject, ok := job.(metav1.Object) diff --git a/pkg/controller.v1/common/service.go b/pkg/controller.v1/common/service.go index 7eb99c41..0b2293b0 100644 --- a/pkg/controller.v1/common/service.go +++ b/pkg/controller.v1/common/service.go @@ -16,6 +16,7 @@ package common import ( "fmt" "strconv" + "strings" apiv1 "github.com/kubeflow/common/pkg/apis/common/v1" "github.com/kubeflow/common/pkg/controller.v1/control" @@ -139,7 +140,7 @@ func (jc *JobController) GetServicesForJob(jobObject interface{}) ([]*v1.Service } // FilterServicesForReplicaType returns service belong to a replicaType. -func (jc *JobController) FilterServicesForReplicaType(services []*v1.Service, replicaType apiv1.ReplicaType) ([]*v1.Service, error) { +func (jc *JobController) FilterServicesForReplicaType(services []*v1.Service, replicaType string) ([]*v1.Service, error) { return core.FilterServicesForReplicaType(services, replicaType) } @@ -158,9 +159,11 @@ func (jc *JobController) ReconcileServices( rtype apiv1.ReplicaType, spec *apiv1.ReplicaSpec) error { + // Convert ReplicaType to lower string. + rt := strings.ToLower(string(rtype)) replicas := int(*spec.Replicas) // Get all services for the type rt. - services, err := jc.FilterServicesForReplicaType(services, rtype) + services, err := jc.FilterServicesForReplicaType(services, rt) if err != nil { return err } @@ -171,13 +174,13 @@ func (jc *JobController) ReconcileServices( // If replica is 4, return a slice with size 4. [[0],[1],[2],[]], a svc with replica-index 3 will be created. // // If replica is 1, return a slice with size 3. [[0],[1],[2]], svc with replica-index 1 and 2 are out of range and will be deleted. - serviceSlices := jc.GetServiceSlices(services, replicas, commonutil.LoggerForReplica(job, rtype)) + serviceSlices := jc.GetServiceSlices(services, replicas, commonutil.LoggerForReplica(job, rt)) for index, serviceSlice := range serviceSlices { if len(serviceSlice) > 1 { - commonutil.LoggerForReplica(job, rtype).Warningf("We have too many services for %s %d", rtype, index) + commonutil.LoggerForReplica(job, rt).Warningf("We have too many services for %s %d", rt, index) } else if len(serviceSlice) == 0 { - commonutil.LoggerForReplica(job, rtype).Infof("need to create new service: %s-%d", rtype, index) + commonutil.LoggerForReplica(job, rt).Infof("need to create new service: %s-%d", rt, index) err = jc.CreateNewService(job, rtype, spec, strconv.Itoa(index)) if err != nil { return err @@ -212,9 +215,12 @@ func (jc *JobController) CreateNewService(job metav1.Object, rtype apiv1.Replica return err } + // Convert ReplicaType to lower string. + rt := strings.ToLower(string(rtype)) + // Append ReplicaTypeLabelDeprecated and ReplicaIndexLabelDeprecated labels. labels := jc.GenLabels(job.GetName()) - utillabels.SetReplicaType(labels, rtype) + utillabels.SetReplicaType(labels, rt) utillabels.SetReplicaIndexStr(labels, index) ports, err := jc.GetPortsFromJob(spec) @@ -236,13 +242,13 @@ func (jc *JobController) CreateNewService(job metav1.Object, rtype apiv1.Replica service.Spec.Ports = append(service.Spec.Ports, svcPort) } - service.Name = GenGeneralName(job.GetName(), rtype, index) + service.Name = GenGeneralName(job.GetName(), rt, index) service.Labels = labels // Create OwnerReference. controllerRef := jc.GenOwnerReference(job) // Creation is expected when there is no error returned - expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, rtype) + expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, rt) jc.Expectations.RaiseExpectations(expectationServicesKey, 1, 0) err = jc.ServiceControl.CreateServicesWithControllerRef(job.GetNamespace(), service, job.(runtime.Object), controllerRef) diff --git a/pkg/controller.v1/common/util.go b/pkg/controller.v1/common/util.go index 7bd217d2..6fb97c73 100644 --- a/pkg/controller.v1/common/util.go +++ b/pkg/controller.v1/common/util.go @@ -47,8 +47,8 @@ func (p ReplicasPriority) Swap(i, j int) { p[i], p[j] = p[j], p[i] } -func GenGeneralName(jobName string, rtype apiv1.ReplicaType, index string) string { - n := jobName + "-" + strings.ToLower(string(rtype)) + "-" + index +func GenGeneralName(jobName string, rtype, index string) string { + n := jobName + "-" + strings.ToLower(rtype) + "-" + index return strings.Replace(n, "/", "-", -1) } diff --git a/pkg/controller.v1/common/util_test.go b/pkg/controller.v1/common/util_test.go index 3b87373d..168b9244 100644 --- a/pkg/controller.v1/common/util_test.go +++ b/pkg/controller.v1/common/util_test.go @@ -18,15 +18,13 @@ import ( "testing" "github.com/stretchr/testify/assert" - - apiv1 "github.com/kubeflow/common/pkg/apis/common/v1" ) func TestGenGeneralName(t *testing.T) { tcs := []struct { index string key string - replicaType apiv1.ReplicaType + replicaType string expectedName string }{ { diff --git a/pkg/controller.v1/expectation/util.go b/pkg/controller.v1/expectation/util.go index 9061000d..b7e5bcb4 100644 --- a/pkg/controller.v1/expectation/util.go +++ b/pkg/controller.v1/expectation/util.go @@ -1,16 +1,15 @@ package expectation import ( - apiv1 "github.com/kubeflow/common/pkg/apis/common/v1" "strings" ) // GenExpectationPodsKey generates an expectation key for pods of a job -func GenExpectationPodsKey(jobKey string, replicaType apiv1.ReplicaType) string { - return jobKey + "/" + strings.ToLower(string(replicaType)) + "/pods" +func GenExpectationPodsKey(jobKey string, replicaType string) string { + return jobKey + "/" + strings.ToLower(replicaType) + "/pods" } // GenExpectationPodsKey generates an expectation key for services of a job -func GenExpectationServicesKey(jobKey string, replicaType apiv1.ReplicaType) string { - return jobKey + "/" + strings.ToLower(string(replicaType)) + "/services" +func GenExpectationServicesKey(jobKey string, replicaType string) string { + return jobKey + "/" + strings.ToLower(replicaType) + "/services" } diff --git a/pkg/core/job.go b/pkg/core/job.go index f0b67cb4..ae3c4321 100644 --- a/pkg/core/job.go +++ b/pkg/core/job.go @@ -2,6 +2,7 @@ package core import ( "sort" + "strings" "time" log "github.com/sirupsen/logrus" @@ -77,7 +78,7 @@ func PastActiveDeadline(runPolicy *apiv1.RunPolicy, jobStatus apiv1.JobStatus) b // this method applies only to pods with restartPolicy == OnFailure or Always func PastBackoffLimit(jobName string, runPolicy *apiv1.RunPolicy, replicas map[apiv1.ReplicaType]*apiv1.ReplicaSpec, pods []*v1.Pod, - podFilterFunc func(pods []*v1.Pod, replicaType apiv1.ReplicaType) ([]*v1.Pod, error)) (bool, error) { + podFilterFunc func(pods []*v1.Pod, replicaType string) ([]*v1.Pod, error)) (bool, error) { if runPolicy.BackoffLimit == nil { return false, nil } @@ -88,7 +89,8 @@ func PastBackoffLimit(jobName string, runPolicy *apiv1.RunPolicy, continue } // Convert ReplicaType to lower string. - pods, err := podFilterFunc(pods, rtype) + rt := strings.ToLower(string(rtype)) + pods, err := podFilterFunc(pods, rt) if err != nil { return false, err } diff --git a/pkg/core/pod.go b/pkg/core/pod.go index 30cdd1e8..0368f5ef 100644 --- a/pkg/core/pod.go +++ b/pkg/core/pod.go @@ -10,16 +10,16 @@ import ( ) // FilterPodsForReplicaType returns pods belong to a replicaType. -func FilterPodsForReplicaType(pods []*v1.Pod, replicaType apiv1.ReplicaType) ([]*v1.Pod, error) { +func FilterPodsForReplicaType(pods []*v1.Pod, replicaType string) ([]*v1.Pod, error) { var result []*v1.Pod selector := labels.SelectorFromValidatedSet(labels.Set{ - apiv1.ReplicaTypeLabel: string(replicaType), + apiv1.ReplicaTypeLabel: replicaType, }) // TODO(#149): Remove deprecated selector. deprecatedSelector := labels.SelectorFromValidatedSet(labels.Set{ - apiv1.ReplicaTypeLabelDeprecated: string(replicaType), + apiv1.ReplicaTypeLabelDeprecated: replicaType, }) for _, pod := range pods { diff --git a/pkg/core/service.go b/pkg/core/service.go index 73739b9f..97d9c300 100644 --- a/pkg/core/service.go +++ b/pkg/core/service.go @@ -11,16 +11,16 @@ import ( ) // FilterServicesForReplicaType returns service belong to a replicaType. -func FilterServicesForReplicaType(services []*v1.Service, replicaType apiv1.ReplicaType) ([]*v1.Service, error) { +func FilterServicesForReplicaType(services []*v1.Service, replicaType string) ([]*v1.Service, error) { var result []*v1.Service selector := labels.SelectorFromValidatedSet(labels.Set{ - apiv1.ReplicaTypeLabel: string(replicaType), + apiv1.ReplicaTypeLabel: replicaType, }) // TODO(#149): Remove deprecated selector. deprecatedSelector := labels.SelectorFromValidatedSet(labels.Set{ - apiv1.ReplicaTypeLabelDeprecated: string(replicaType), + apiv1.ReplicaTypeLabelDeprecated: replicaType, }) for _, service := range services { diff --git a/pkg/core/utils.go b/pkg/core/utils.go index 0f12a3c3..6da7a134 100644 --- a/pkg/core/utils.go +++ b/pkg/core/utils.go @@ -2,8 +2,6 @@ package core import ( "strings" - - commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" ) func MaxInt(x, y int) int { @@ -13,7 +11,7 @@ func MaxInt(x, y int) int { return x } -func GenGeneralName(jobName string, rtype commonv1.ReplicaType, index string) string { - n := jobName + "-" + strings.ToLower(string(rtype)) + "-" + index +func GenGeneralName(jobName string, rtype string, index string) string { + n := jobName + "-" + strings.ToLower(rtype) + "-" + index return strings.Replace(n, "/", "-", -1) } diff --git a/pkg/reconciler.v1/common/gang_volcano.go b/pkg/reconciler.v1/common/gang_volcano.go index 637c5d2e..82276feb 100644 --- a/pkg/reconciler.v1/common/gang_volcano.go +++ b/pkg/reconciler.v1/common/gang_volcano.go @@ -165,12 +165,12 @@ func (r *VolcanoReconciler) ReconcilePodGroup( } // DecoratePodForGangScheduling decorates the podTemplate before it's used to generate a pod with information for gang-scheduling -func (r *VolcanoReconciler) DecoratePodForGangScheduling(rtype commonv1.ReplicaType, podTemplate *corev1.PodTemplateSpec, job client.Object) { +func (r *VolcanoReconciler) DecoratePodForGangScheduling(rt string, podTemplate *corev1.PodTemplateSpec, job client.Object) { if podTemplate.Spec.SchedulerName == "" || podTemplate.Spec.SchedulerName == r.GetGangSchedulerName() { podTemplate.Spec.SchedulerName = r.GetGangSchedulerName() } else { warnMsg := "Another scheduler is specified when gang-scheduling is enabled and it will not be overwritten" - commonutil.LoggerForReplica(job, rtype).Warn(warnMsg) + commonutil.LoggerForReplica(job, rt).Warn(warnMsg) r.GetRecorder().Event(job, corev1.EventTypeWarning, "PodTemplateSchedulerNameAlreadySet", warnMsg) } diff --git a/pkg/reconciler.v1/common/interface.go b/pkg/reconciler.v1/common/interface.go index 3ef70337..0e8931ff 100644 --- a/pkg/reconciler.v1/common/interface.go +++ b/pkg/reconciler.v1/common/interface.go @@ -78,7 +78,7 @@ type GangSchedulingInterface interface { // DecoratePodForGangScheduling SHOULD be overridden if gang scheduler demands Pods associated with PodGroup to be // decorated with specific requests. - DecoratePodForGangScheduling(rtype commonv1.ReplicaType, podTemplate *corev1.PodTemplateSpec, job client.Object) + DecoratePodForGangScheduling(rtype string, podTemplate *corev1.PodTemplateSpec, job client.Object) } // PodInterface defines the abstract interface for Pod related actions, such like get, create or delete Pod @@ -90,14 +90,14 @@ type PodInterface interface { GetDefaultContainerName() string // GenPodName CAN be overridden to customize Pod name. - GenPodName(jobName string, rtype commonv1.ReplicaType, index string) string + GenPodName(jobName string, rtype string, index string) string // GetPodsForJob CAN be overridden to customize how to list all pods with the job. GetPodsForJob(ctx context.Context, job client.Object) ([]*corev1.Pod, error) // FilterPodsForReplicaType CAN be overridden if the linking approach between pods and replicaType changes as this // function filters out pods for specific replica type from all pods associated with the job. - FilterPodsForReplicaType(pods []*corev1.Pod, replicaType commonv1.ReplicaType) ([]*corev1.Pod, error) + FilterPodsForReplicaType(pods []*corev1.Pod, replicaType string) ([]*corev1.Pod, error) // GetPodSlices SHOULD NOT be overridden as it generates pod slices for further pod processing. GetPodSlices(pods []*corev1.Pod, replicas int, logger *logrus.Entry) [][]*corev1.Pod @@ -121,7 +121,7 @@ type PodInterface interface { // DecoratePod CAN be overridden if customization to the pod is needed. The default implementation applies nothing // to the pod. - DecoratePod(rtype commonv1.ReplicaType, podTemplate *corev1.PodTemplateSpec, job client.Object) + DecoratePod(rtype string, podTemplate *corev1.PodTemplateSpec, job client.Object) } // ServiceInterface defines the abstract interface for Pod related actions, such like get, create or delete Service @@ -137,7 +137,7 @@ type ServiceInterface interface { // FilterServicesForReplicaType CAN be overridden to customize how to filter out services for this Replica Type. FilterServicesForReplicaType(services []*corev1.Service, - replicaType commonv1.ReplicaType) ([]*corev1.Service, error) + replicaType string) ([]*corev1.Service, error) // GetServiceSlices CAN be overridden to customize how to generate service slices. GetServiceSlices(services []*corev1.Service, replicas int, logger *logrus.Entry) [][]*corev1.Service @@ -157,7 +157,7 @@ type ServiceInterface interface { DeleteService(ns string, name string, job client.Object) error // DecorateService CAN be overridden to customize this service right before being created - DecorateService(rtype commonv1.ReplicaType, svc *corev1.Service, job client.Object) + DecorateService(rtype string, svc *corev1.Service, job client.Object) } // JobInterface defines the abstract interface for Pod related actions, such like get, create or delete TFJob, diff --git a/pkg/reconciler.v1/common/pod.go b/pkg/reconciler.v1/common/pod.go index f6f69cab..3fb31418 100644 --- a/pkg/reconciler.v1/common/pod.go +++ b/pkg/reconciler.v1/common/pod.go @@ -17,6 +17,7 @@ package common import ( "context" "strconv" + "strings" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" @@ -78,7 +79,7 @@ func BareKubeflowPodReconciler(client client.Client) *KubeflowPodReconciler { } // GenPodName returns the name of the Pod based on jobName, replicaType and its index -func (r *KubeflowPodReconciler) GenPodName(jobName string, rtype commonv1.ReplicaType, index string) string { +func (r *KubeflowPodReconciler) GenPodName(jobName string, rtype string, index string) string { return core.GenGeneralName(jobName, rtype, index) } @@ -110,7 +111,7 @@ func (r *KubeflowPodReconciler) GetPodSlices(pods []*corev1.Pod, replicas int, l } // FilterPodsForReplicaType filters out Pods for this replicaType -func (r *KubeflowPodReconciler) FilterPodsForReplicaType(pods []*corev1.Pod, replicaType commonv1.ReplicaType) ([]*corev1.Pod, error) { +func (r *KubeflowPodReconciler) FilterPodsForReplicaType(pods []*corev1.Pod, replicaType string) ([]*corev1.Pod, error) { return core.FilterPodsForReplicaType(pods, replicaType) } @@ -125,9 +126,10 @@ func (r *KubeflowPodReconciler) ReconcilePods( replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec) error { // Convert ReplicaType to lower string. - logger := commonutil.LoggerForReplica(job, rtype) + rt := strings.ToLower(string(rtype)) + logger := commonutil.LoggerForReplica(job, rt) // Get all pods for the type rt. - pods, err := r.FilterPodsForReplicaType(pods, rtype) + pods, err := r.FilterPodsForReplicaType(pods, rt) if err != nil { return err } @@ -196,9 +198,11 @@ func (r *KubeflowPodReconciler) ReconcilePods( } // CreateNewPod generate Pods for this job and submits creation request to APIServer -func (r *KubeflowPodReconciler) CreateNewPod(job client.Object, rt commonv1.ReplicaType, index string, +func (r *KubeflowPodReconciler) CreateNewPod(job client.Object, rtype commonv1.ReplicaType, index string, spec *commonv1.ReplicaSpec, masterRole bool, replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec) error { + // Convert ReplicaType to lower string. + rt := strings.ToLower(string(rtype)) logger := commonutil.LoggerForReplica(job, rt) podLabels := r.GenLabels(job.GetName()) @@ -270,7 +274,7 @@ func (r *KubeflowPodReconciler) DeletePod(ctx context.Context, ns string, name s } // DecoratePod decorates podTemplate before a Pod is submitted to the APIServer -func (r *KubeflowPodReconciler) DecoratePod(rtype commonv1.ReplicaType, podTemplate *corev1.PodTemplateSpec, job client.Object) { +func (r *KubeflowPodReconciler) DecoratePod(rtype string, podTemplate *corev1.PodTemplateSpec, job client.Object) { // Default implementation applies nothing to podTemplate return } diff --git a/pkg/reconciler.v1/common/pod_test.go b/pkg/reconciler.v1/common/pod_test.go index f2afeec3..6d651f27 100644 --- a/pkg/reconciler.v1/common/pod_test.go +++ b/pkg/reconciler.v1/common/pod_test.go @@ -15,6 +15,7 @@ package common_test import ( + "strings" "testing" commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" @@ -30,7 +31,7 @@ import ( func TestGenPodName(t *testing.T) { type tc struct { testJob *testjobv1.TestJob - testRType commonv1.ReplicaType + testRType string testIndex string expectedName string } @@ -40,7 +41,7 @@ func TestGenPodName(t *testing.T) { tj.SetName("hello-world") return tc{ testJob: tj, - testRType: commonv1.ReplicaType(testjobv1.TestReplicaTypeWorker), + testRType: strings.ToLower(string(testjobv1.TestReplicaTypeWorker)), testIndex: "1", expectedName: "hello-world-worker-1", } @@ -70,7 +71,7 @@ func PodInSlice(pod *corev1.Pod, pods []*corev1.Pod) bool { func TestFilterPodsForReplicaType(t *testing.T) { type tc struct { testPods []*corev1.Pod - testRType commonv1.ReplicaType + testRType string expectedPods []*corev1.Pod } testCase := []tc{ @@ -83,7 +84,7 @@ func TestFilterPodsForReplicaType(t *testing.T) { Name: "pod0", Namespace: "default", Labels: map[string]string{ - commonv1.ReplicaTypeLabel: string(testjobv1.TestReplicaTypeMaster), + commonv1.ReplicaTypeLabel: strings.ToLower(string(testjobv1.TestReplicaTypeMaster)), }, }, Spec: corev1.PodSpec{}, @@ -95,7 +96,7 @@ func TestFilterPodsForReplicaType(t *testing.T) { Name: "pod1", Namespace: "default", Labels: map[string]string{ - commonv1.ReplicaTypeLabel: string(testjobv1.TestReplicaTypeWorker), + commonv1.ReplicaTypeLabel: strings.ToLower(string(testjobv1.TestReplicaTypeWorker)), }, }, Spec: corev1.PodSpec{}, @@ -107,7 +108,7 @@ func TestFilterPodsForReplicaType(t *testing.T) { Name: "pod2", Namespace: "default", Labels: map[string]string{ - commonv1.ReplicaTypeLabel: string(testjobv1.TestReplicaTypeWorker), + commonv1.ReplicaTypeLabel: strings.ToLower(string(testjobv1.TestReplicaTypeWorker)), }, }, Spec: corev1.PodSpec{}, @@ -119,7 +120,7 @@ func TestFilterPodsForReplicaType(t *testing.T) { return tc{ testPods: allPods, - testRType: commonv1.ReplicaType(testjobv1.TestReplicaTypeWorker), + testRType: strings.ToLower(string(testjobv1.TestReplicaTypeWorker)), expectedPods: filteredPods, } }(), diff --git a/pkg/reconciler.v1/common/service.go b/pkg/reconciler.v1/common/service.go index 63e7791e..370ac49b 100644 --- a/pkg/reconciler.v1/common/service.go +++ b/pkg/reconciler.v1/common/service.go @@ -17,6 +17,7 @@ package common import ( "context" "strconv" + "strings" commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" "github.com/kubeflow/common/pkg/core" @@ -94,7 +95,7 @@ func (r *KubeflowServiceReconciler) GetServicesForJob(ctx context.Context, job c // FilterServicesForReplicaType returns service belong to a replicaType. func (r *KubeflowServiceReconciler) FilterServicesForReplicaType(services []*corev1.Service, - replicaType commonv1.ReplicaType) ([]*corev1.Service, error) { + replicaType string) ([]*corev1.Service, error) { return core.FilterServicesForReplicaType(services, replicaType) } @@ -110,9 +111,11 @@ func (r *KubeflowServiceReconciler) ReconcileServices( rtype commonv1.ReplicaType, spec *commonv1.ReplicaSpec) error { + // Convert ReplicaType to lower string. + rt := strings.ToLower(string(rtype)) replicas := int(*spec.Replicas) // Get all services for the type rt. - services, err := r.FilterServicesForReplicaType(services, rtype) + services, err := r.FilterServicesForReplicaType(services, rt) if err != nil { return err } @@ -123,13 +126,13 @@ func (r *KubeflowServiceReconciler) ReconcileServices( // If replica is 4, return a slice with size 4. [[0],[1],[2],[]], a svc with replica-index 3 will be created. // // If replica is 1, return a slice with size 3. [[0],[1],[2]], svc with replica-index 1 and 2 are out of range and will be deleted. - serviceSlices := r.GetServiceSlices(services, replicas, commonutil.LoggerForReplica(job, rtype)) + serviceSlices := r.GetServiceSlices(services, replicas, commonutil.LoggerForReplica(job, rt)) for index, serviceSlice := range serviceSlices { if len(serviceSlice) > 1 { - commonutil.LoggerForReplica(job, rtype).Warningf("We have too many services for %s %d", rtype, index) + commonutil.LoggerForReplica(job, rt).Warningf("We have too many services for %s %d", rt, index) } else if len(serviceSlice) == 0 { - commonutil.LoggerForReplica(job, rtype).Infof("need to create new service: %s-%d", rtype, index) + commonutil.LoggerForReplica(job, rt).Infof("need to create new service: %s-%d", rt, index) err = r.CreateNewService(job, rtype, spec, strconv.Itoa(index)) if err != nil { return err @@ -179,7 +182,9 @@ func (r *KubeflowServiceReconciler) CreateNewService(job client.Object, rtype co service.Spec.Ports = append(service.Spec.Ports, svcPort) } - service.Name = core.GenGeneralName(job.GetName(), rtype, index) + // Convert ReplicaType to lower string. + rt := strings.ToLower(string(rtype)) + service.Name = core.GenGeneralName(job.GetName(), rt, index) service.Namespace = job.GetNamespace() service.Labels = labels // Create OwnerReference. @@ -188,7 +193,7 @@ func (r *KubeflowServiceReconciler) CreateNewService(job client.Object, rtype co return err } - r.DecorateService(rtype, service, job) + r.DecorateService(rt, service, job) err = r.Create(context.Background(), service) if err != nil && errors.IsTimeout(err) { @@ -215,7 +220,7 @@ func (r *KubeflowServiceReconciler) DeleteService(ns string, name string, job cl } // DecorateService decorates the Service before it's submitted to APIServer -func (r *KubeflowServiceReconciler) DecorateService(rtype commonv1.ReplicaType, svc *corev1.Service, job client.Object) { +func (r *KubeflowServiceReconciler) DecorateService(rtype string, svc *corev1.Service, job client.Object) { // Default implementation applies nothing to podTemplate return } diff --git a/pkg/util/labels/labels.go b/pkg/util/labels/labels.go index 255884a9..a79edd2a 100644 --- a/pkg/util/labels/labels.go +++ b/pkg/util/labels/labels.go @@ -16,9 +16,8 @@ package labels import ( "errors" - "strconv" - v1 "github.com/kubeflow/common/pkg/apis/common/v1" + "strconv" ) // TODO(#149): Remove deprecated labels. @@ -43,7 +42,7 @@ func SetReplicaIndexStr(labels map[string]string, idx string) { labels[v1.ReplicaIndexLabelDeprecated] = idx } -func ReplicaType(labels map[string]string) (v1.ReplicaType, error) { +func ReplicaType(labels map[string]string) (string, error) { v, ok := labels[v1.ReplicaTypeLabel] if !ok { v, ok = labels[v1.ReplicaTypeLabelDeprecated] @@ -51,12 +50,12 @@ func ReplicaType(labels map[string]string) (v1.ReplicaType, error) { return "", errors.New("replica type label not found") } } - return v1.ReplicaType(v), nil + return v, nil } -func SetReplicaType(labels map[string]string, rt v1.ReplicaType) { - labels[v1.ReplicaTypeLabel] = string(rt) - labels[v1.ReplicaTypeLabelDeprecated] = string(rt) +func SetReplicaType(labels map[string]string, rt string) { + labels[v1.ReplicaTypeLabel] = rt + labels[v1.ReplicaTypeLabelDeprecated] = rt } func HasKnownLabels(labels map[string]string, groupName string) bool { diff --git a/pkg/util/labels/labels_test.go b/pkg/util/labels/labels_test.go index 49b2fe91..2c2d86f1 100644 --- a/pkg/util/labels/labels_test.go +++ b/pkg/util/labels/labels_test.go @@ -52,7 +52,7 @@ func TestReplicaIndex(t *testing.T) { func TestReplicaType(t *testing.T) { cases := map[string]struct { labels map[string]string - want v1.ReplicaType + want string wantErr bool }{ "new": { diff --git a/pkg/util/logger.go b/pkg/util/logger.go index a9719fce..d3ce95e6 100644 --- a/pkg/util/logger.go +++ b/pkg/util/logger.go @@ -15,7 +15,6 @@ package util import ( - apiv1 "github.com/kubeflow/common/pkg/apis/common/v1" "strings" log "github.com/sirupsen/logrus" @@ -24,7 +23,7 @@ import ( metav1unstructured "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" ) -func LoggerForReplica(job metav1.Object, rtype apiv1.ReplicaType) *log.Entry { +func LoggerForReplica(job metav1.Object, rtype string) *log.Entry { return log.WithFields(log.Fields{ // We use job to match the key used in controller.go // Its more common in K8s to use a period to indicate namespace.name. So that's what we use. diff --git a/test_job/controller.v1/test_job/test_job_controller.go b/test_job/controller.v1/test_job/test_job_controller.go index 41978ce8..e40af065 100644 --- a/test_job/controller.v1/test_job/test_job_controller.go +++ b/test_job/controller.v1/test_job/test_job_controller.go @@ -61,7 +61,7 @@ func (t *TestJobController) UpdateJobStatusInApiServer(job interface{}, jobStatu return nil } -func (t *TestJobController) SetClusterSpec(job interface{}, podTemplate *corev1.PodTemplateSpec, rtype commonv1.ReplicaType, index string) error { +func (t *TestJobController) SetClusterSpec(job interface{}, podTemplate *corev1.PodTemplateSpec, rtype, index string) error { return nil }