From 31efa7515e6b3d4afb4f6b8bfe411a0e60e04b89 Mon Sep 17 00:00:00 2001 From: MartinForReal Date: Tue, 1 Jun 2021 13:02:04 +0800 Subject: [PATCH] change type of rtype to commonv1.ReplicaType (#135) --- pkg/apis/common/v1/interface.go | 2 +- pkg/controller.v1/common/job.go | 4 +-- pkg/controller.v1/common/pod.go | 29 ++++++++---------- pkg/controller.v1/common/service.go | 30 ++++++++----------- pkg/controller.v1/common/util.go | 4 +-- pkg/controller.v1/common/util_test.go | 3 +- pkg/controller.v1/expectation/util.go | 13 ++++---- pkg/util/logger.go | 3 +- .../test_job/test_job_controller.go | 2 +- 9 files changed, 42 insertions(+), 48 deletions(-) diff --git a/pkg/apis/common/v1/interface.go b/pkg/apis/common/v1/interface.go index 255d7b6e..661a96c6 100644 --- a/pkg/apis/common/v1/interface.go +++ b/pkg/apis/common/v1/interface.go @@ -44,7 +44,7 @@ type ControllerInterface interface { UpdateJobStatusInApiServer(job interface{}, jobStatus *JobStatus) error // SetClusterSpec sets the cluster spec for the pod - SetClusterSpec(job interface{}, podTemplate *v1.PodTemplateSpec, rtype, index string) error + SetClusterSpec(job interface{}, podTemplate *v1.PodTemplateSpec, rtype ReplicaType, index string) error // Returns the default container name in pod GetDefaultContainerName() string diff --git a/pkg/controller.v1/common/job.go b/pkg/controller.v1/common/job.go index 00cfcae7..26f65b91 100644 --- a/pkg/controller.v1/common/job.go +++ b/pkg/controller.v1/common/job.go @@ -4,7 +4,6 @@ import ( "fmt" "reflect" "sort" - "strings" "time" apiv1 "github.com/kubeflow/common/pkg/apis/common/v1" @@ -368,8 +367,7 @@ func (jc *JobController) PastBackoffLimit(jobName string, runPolicy *apiv1.RunPo continue } // Convert ReplicaType to lower string. - rt := strings.ToLower(string(rtype)) - pods, err := jc.FilterPodsForReplicaType(pods, rt) + pods, err := jc.FilterPodsForReplicaType(pods, rtype) if err != nil { return false, err } diff --git a/pkg/controller.v1/common/pod.go b/pkg/controller.v1/common/pod.go index 409694dc..1d00700f 100644 --- a/pkg/controller.v1/common/pod.go +++ b/pkg/controller.v1/common/pod.go @@ -16,10 +16,6 @@ package common import ( "fmt" - "reflect" - "strconv" - "strings" - "github.com/kubeflow/common/pkg/controller.v1/control" "github.com/kubeflow/common/pkg/controller.v1/expectation" "github.com/prometheus/client_golang/prometheus" @@ -32,6 +28,8 @@ import ( "k8s.io/apimachinery/pkg/runtime" utilruntime "k8s.io/apimachinery/pkg/util/runtime" "k8s.io/client-go/tools/cache" + "reflect" + "strconv" apiv1 "github.com/kubeflow/common/pkg/apis/common/v1" commonutil "github.com/kubeflow/common/pkg/util" @@ -104,7 +102,7 @@ func (jc *JobController) AddPod(obj interface{}) { } rtype := pod.Labels[apiv1.ReplicaTypeLabel] - expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, rtype) + expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, apiv1.ReplicaType(rtype)) jc.Expectations.CreationObserved(expectationPodsKey) // TODO: we may need add backoff here @@ -205,7 +203,7 @@ func (jc *JobController) DeletePod(obj interface{}) { } rtype := pod.Labels[apiv1.ReplicaTypeLabel] - expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, rtype) + expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, apiv1.ReplicaType(rtype)) jc.Expectations.DeletionObserved(expectationPodsKey) deletedPodsCount.Inc() @@ -254,14 +252,14 @@ func (jc *JobController) GetPodsForJob(jobObject interface{}) ([]*v1.Pod, error) } // FilterPodsForReplicaType returns pods belong to a replicaType. -func (jc *JobController) FilterPodsForReplicaType(pods []*v1.Pod, replicaType string) ([]*v1.Pod, error) { +func (jc *JobController) FilterPodsForReplicaType(pods []*v1.Pod, replicaType apiv1.ReplicaType) ([]*v1.Pod, error) { var result []*v1.Pod replicaSelector := &metav1.LabelSelector{ MatchLabels: make(map[string]string), } - replicaSelector.MatchLabels[apiv1.ReplicaTypeLabel] = replicaType + replicaSelector.MatchLabels[apiv1.ReplicaTypeLabel] = string(replicaType) for _, pod := range pods { selector, err := metav1.LabelSelectorAsSelector(replicaSelector) @@ -337,10 +335,9 @@ func (jc *JobController) ReconcilePods( } // Convert ReplicaType to lower string. - rt := strings.ToLower(string(rtype)) - logger := commonutil.LoggerForReplica(metaObject, rt) + logger := commonutil.LoggerForReplica(metaObject, rtype) // Get all pods for the type rt. - pods, err := jc.FilterPodsForReplicaType(pods, rt) + pods, err := jc.FilterPodsForReplicaType(pods, rtype) if err != nil { return err } @@ -358,13 +355,13 @@ func (jc *JobController) ReconcilePods( podSlices := jc.GetPodSlices(pods, numReplicas, logger) for index, podSlice := range podSlices { if len(podSlice) > 1 { - logger.Warningf("We have too many pods for %s %d", rt, index) + logger.Warningf("We have too many pods for %s %d", rtype, index) } else if len(podSlice) == 0 { - logger.Infof("Need to create new pod: %s-%d", rt, index) + logger.Infof("Need to create new pod: %s-%d", rtype, index) // check if this replica is the master role masterRole = jc.Controller.IsMasterRole(replicas, rtype, index) - err = jc.createNewPod(job, rt, strconv.Itoa(index), spec, masterRole, replicas) + err = jc.createNewPod(job, rtype, strconv.Itoa(index), spec, masterRole, replicas) if err != nil { return err } @@ -408,7 +405,7 @@ func (jc *JobController) ReconcilePods( } // createNewPod creates a new pod for the given index and type. -func (jc *JobController) createNewPod(job interface{}, rt, index string, spec *apiv1.ReplicaSpec, masterRole bool, +func (jc *JobController) createNewPod(job interface{}, rt apiv1.ReplicaType, index string, spec *apiv1.ReplicaSpec, masterRole bool, replicas map[apiv1.ReplicaType]*apiv1.ReplicaSpec) error { metaObject, ok := job.(metav1.Object) @@ -433,7 +430,7 @@ func (jc *JobController) createNewPod(job interface{}, rt, index string, spec *a // Set type and index for the worker. labels := jc.GenLabels(metaObject.GetName()) - labels[apiv1.ReplicaTypeLabel] = rt + labels[apiv1.ReplicaTypeLabel] = string(rt) labels[apiv1.ReplicaIndexLabel] = index if masterRole { diff --git a/pkg/controller.v1/common/service.go b/pkg/controller.v1/common/service.go index d7ce4545..07db5b0d 100644 --- a/pkg/controller.v1/common/service.go +++ b/pkg/controller.v1/common/service.go @@ -15,9 +15,6 @@ package common import ( "fmt" - "strconv" - "strings" - apiv1 "github.com/kubeflow/common/pkg/apis/common/v1" "github.com/kubeflow/common/pkg/controller.v1/control" "github.com/kubeflow/common/pkg/controller.v1/expectation" @@ -31,6 +28,7 @@ import ( "k8s.io/apimachinery/pkg/labels" "k8s.io/apimachinery/pkg/runtime" utilruntime "k8s.io/apimachinery/pkg/util/runtime" + "strconv" ) var ( @@ -71,8 +69,8 @@ func (jc *JobController) AddService(obj interface{}) { return } - rtype := service.Labels[apiv1.ReplicaTypeLabel] - expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, rtype) + rtypeValue := service.Labels[apiv1.ReplicaTypeLabel] + expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, apiv1.ReplicaType(rtypeValue)) jc.Expectations.CreationObserved(expectationServicesKey) // TODO: we may need add backoff here @@ -137,14 +135,14 @@ func (jc *JobController) GetServicesForJob(jobObject interface{}) ([]*v1.Service } // FilterServicesForReplicaType returns service belong to a replicaType. -func (jc *JobController) FilterServicesForReplicaType(services []*v1.Service, replicaType string) ([]*v1.Service, error) { +func (jc *JobController) FilterServicesForReplicaType(services []*v1.Service, replicaType apiv1.ReplicaType) ([]*v1.Service, error) { var result []*v1.Service replicaSelector := &metav1.LabelSelector{ MatchLabels: make(map[string]string), } - replicaSelector.MatchLabels[apiv1.ReplicaTypeLabel] = replicaType + replicaSelector.MatchLabels[apiv1.ReplicaTypeLabel] = string(replicaType) for _, service := range services { selector, err := metav1.LabelSelectorAsSelector(replicaSelector) @@ -209,12 +207,9 @@ func (jc *JobController) ReconcileServices( rtype apiv1.ReplicaType, spec *apiv1.ReplicaSpec) error { - // Convert ReplicaType to lower string. - rt := strings.ToLower(string(rtype)) - replicas := int(*spec.Replicas) // Get all services for the type rt. - services, err := jc.FilterServicesForReplicaType(services, rt) + services, err := jc.FilterServicesForReplicaType(services, rtype) if err != nil { return err } @@ -225,13 +220,13 @@ func (jc *JobController) ReconcileServices( // If replica is 4, return a slice with size 4. [[0],[1],[2],[]], a svc with replica-index 3 will be created. // // If replica is 1, return a slice with size 3. [[0],[1],[2]], svc with replica-index 1 and 2 are out of range and will be deleted. - serviceSlices := jc.GetServiceSlices(services, replicas, commonutil.LoggerForReplica(job, rt)) + serviceSlices := jc.GetServiceSlices(services, replicas, commonutil.LoggerForReplica(job, rtype)) for index, serviceSlice := range serviceSlices { if len(serviceSlice) > 1 { - commonutil.LoggerForReplica(job, rt).Warningf("We have too many services for %s %d", rt, index) + commonutil.LoggerForReplica(job, rtype).Warningf("We have too many services for %s %d", rtype, index) } else if len(serviceSlice) == 0 { - commonutil.LoggerForReplica(job, rt).Infof("need to create new service: %s-%d", rt, index) + commonutil.LoggerForReplica(job, rtype).Infof("need to create new service: %s-%d", rtype, index) err = jc.CreateNewService(job, rtype, spec, strconv.Itoa(index)) if err != nil { return err @@ -283,8 +278,7 @@ func (jc *JobController) CreateNewService(job metav1.Object, rtype apiv1.Replica } // Convert ReplicaType to lower string. - rt := strings.ToLower(string(rtype)) - expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, rt) + expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, rtype) err = jc.Expectations.ExpectCreations(expectationServicesKey, 1) if err != nil { return err @@ -292,7 +286,7 @@ func (jc *JobController) CreateNewService(job metav1.Object, rtype apiv1.Replica // Append ReplicaTypeLabel and ReplicaIndexLabel labels. labels := jc.GenLabels(job.GetName()) - labels[apiv1.ReplicaTypeLabel] = rt + labels[apiv1.ReplicaTypeLabel] = string(rtype) labels[apiv1.ReplicaIndexLabel] = index ports, err := jc.GetPortsFromJob(spec) @@ -314,7 +308,7 @@ func (jc *JobController) CreateNewService(job metav1.Object, rtype apiv1.Replica service.Spec.Ports = append(service.Spec.Ports, svcPort) } - service.Name = GenGeneralName(job.GetName(), rt, index) + service.Name = GenGeneralName(job.GetName(), rtype, index) service.Labels = labels // Create OwnerReference. controllerRef := jc.GenOwnerReference(job) diff --git a/pkg/controller.v1/common/util.go b/pkg/controller.v1/common/util.go index 97535a87..fd758194 100644 --- a/pkg/controller.v1/common/util.go +++ b/pkg/controller.v1/common/util.go @@ -44,8 +44,8 @@ func (p ReplicasPriority) Swap(i, j int) { p[i], p[j] = p[j], p[i] } -func GenGeneralName(jobName, rtype, index string) string { - n := jobName + "-" + rtype + "-" + index +func GenGeneralName(jobName string, rtype apiv1.ReplicaType, index string) string { + n := jobName + "-" + string(rtype) + "-" + index return strings.Replace(n, "/", "-", -1) } diff --git a/pkg/controller.v1/common/util_test.go b/pkg/controller.v1/common/util_test.go index fb065205..b43779a0 100644 --- a/pkg/controller.v1/common/util_test.go +++ b/pkg/controller.v1/common/util_test.go @@ -16,12 +16,13 @@ package common import ( "fmt" + apiv1 "github.com/kubeflow/common/pkg/apis/common/v1" "github.com/stretchr/testify/assert" "testing" ) func TestGenGeneralName(t *testing.T) { - testRType := "worker" + var testRType apiv1.ReplicaType = "worker" testIndex := "1" testKey := "1/2/3/4/5" expectedName := fmt.Sprintf("1-2-3-4-5-%s-%s", testRType, testIndex) diff --git a/pkg/controller.v1/expectation/util.go b/pkg/controller.v1/expectation/util.go index 056847b2..9061000d 100644 --- a/pkg/controller.v1/expectation/util.go +++ b/pkg/controller.v1/expectation/util.go @@ -1,13 +1,16 @@ package expectation -import "strings" +import ( + apiv1 "github.com/kubeflow/common/pkg/apis/common/v1" + "strings" +) // GenExpectationPodsKey generates an expectation key for pods of a job -func GenExpectationPodsKey(jobKey, replicaType string) string { - return jobKey + "/" + strings.ToLower(replicaType) + "/pods" +func GenExpectationPodsKey(jobKey string, replicaType apiv1.ReplicaType) string { + return jobKey + "/" + strings.ToLower(string(replicaType)) + "/pods" } // GenExpectationPodsKey generates an expectation key for services of a job -func GenExpectationServicesKey(jobKey, replicaType string) string { - return jobKey + "/" + strings.ToLower(replicaType) + "/services" +func GenExpectationServicesKey(jobKey string, replicaType apiv1.ReplicaType) string { + return jobKey + "/" + strings.ToLower(string(replicaType)) + "/services" } diff --git a/pkg/util/logger.go b/pkg/util/logger.go index d3ce95e6..a9719fce 100644 --- a/pkg/util/logger.go +++ b/pkg/util/logger.go @@ -15,6 +15,7 @@ package util import ( + apiv1 "github.com/kubeflow/common/pkg/apis/common/v1" "strings" log "github.com/sirupsen/logrus" @@ -23,7 +24,7 @@ import ( metav1unstructured "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" ) -func LoggerForReplica(job metav1.Object, rtype string) *log.Entry { +func LoggerForReplica(job metav1.Object, rtype apiv1.ReplicaType) *log.Entry { return log.WithFields(log.Fields{ // We use job to match the key used in controller.go // Its more common in K8s to use a period to indicate namespace.name. So that's what we use. diff --git a/test_job/controller.v1/test_job/test_job_controller.go b/test_job/controller.v1/test_job/test_job_controller.go index 2e6dd210..80d37cd9 100644 --- a/test_job/controller.v1/test_job/test_job_controller.go +++ b/test_job/controller.v1/test_job/test_job_controller.go @@ -65,7 +65,7 @@ func (t *TestJobController) UpdateJobStatusInApiServer(job interface{}, jobStatu return nil } -func (t *TestJobController) SetClusterSpec(job interface{}, podTemplate *corev1.PodTemplateSpec, rtype, index string) error { +func (t *TestJobController) SetClusterSpec(job interface{}, podTemplate *corev1.PodTemplateSpec, rtype commonv1.ReplicaType, index string) error { return nil }