diff --git a/pkg/apis/common/v1/interface.go b/pkg/apis/common/v1/interface.go index 661a96c6..255d7b6e 100644 --- a/pkg/apis/common/v1/interface.go +++ b/pkg/apis/common/v1/interface.go @@ -44,7 +44,7 @@ type ControllerInterface interface { UpdateJobStatusInApiServer(job interface{}, jobStatus *JobStatus) error // SetClusterSpec sets the cluster spec for the pod - SetClusterSpec(job interface{}, podTemplate *v1.PodTemplateSpec, rtype ReplicaType, index string) error + SetClusterSpec(job interface{}, podTemplate *v1.PodTemplateSpec, rtype, index string) error // Returns the default container name in pod GetDefaultContainerName() string diff --git a/pkg/controller.v1/common/job.go b/pkg/controller.v1/common/job.go index 606f0dc6..ad03d690 100644 --- a/pkg/controller.v1/common/job.go +++ b/pkg/controller.v1/common/job.go @@ -4,6 +4,7 @@ import ( "fmt" "reflect" "sort" + "strings" "time" apiv1 "github.com/kubeflow/common/pkg/apis/common/v1" @@ -326,9 +327,10 @@ func (jc *JobController) ReconcileJobs( // ResetExpectations reset the expectation for creates and deletes of pod/service to zero. func (jc *JobController) ResetExpectations(jobKey string, replicas map[apiv1.ReplicaType]*apiv1.ReplicaSpec) { for rtype := range replicas { - expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, rtype) + rt := strings.ToLower(string(rtype)) + expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, rt) jc.Expectations.SetExpectations(expectationPodsKey, 0, 0) - expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, rtype) + expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, rt) jc.Expectations.SetExpectations(expectationServicesKey, 0, 0) } } @@ -359,7 +361,8 @@ func (jc *JobController) PastBackoffLimit(jobName string, runPolicy *apiv1.RunPo continue } // Convert ReplicaType to lower string. - pods, err := jc.FilterPodsForReplicaType(pods, rtype) + rt := strings.ToLower(string(rtype)) + pods, err := jc.FilterPodsForReplicaType(pods, rt) if err != nil { return false, err } diff --git a/pkg/controller.v1/common/pod.go b/pkg/controller.v1/common/pod.go index 507b33f7..f5243a82 100644 --- a/pkg/controller.v1/common/pod.go +++ b/pkg/controller.v1/common/pod.go @@ -18,6 +18,7 @@ import ( "fmt" "reflect" "strconv" + "strings" apiv1 "github.com/kubeflow/common/pkg/apis/common/v1" "github.com/kubeflow/common/pkg/controller.v1/control" @@ -103,7 +104,7 @@ func (jc *JobController) AddPod(obj interface{}) { } rtype := pod.Labels[apiv1.ReplicaTypeLabel] - expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, apiv1.ReplicaType(rtype)) + expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, rtype) jc.Expectations.CreationObserved(expectationPodsKey) // TODO: we may need add backoff here @@ -204,7 +205,7 @@ func (jc *JobController) DeletePod(obj interface{}) { } rtype := pod.Labels[apiv1.ReplicaTypeLabel] - expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, apiv1.ReplicaType(rtype)) + expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, rtype) jc.Expectations.DeletionObserved(expectationPodsKey) deletedPodsCount.Inc() @@ -253,14 +254,14 @@ func (jc *JobController) GetPodsForJob(jobObject interface{}) ([]*v1.Pod, error) } // FilterPodsForReplicaType returns pods belong to a replicaType. -func (jc *JobController) FilterPodsForReplicaType(pods []*v1.Pod, replicaType apiv1.ReplicaType) ([]*v1.Pod, error) { +func (jc *JobController) FilterPodsForReplicaType(pods []*v1.Pod, replicaType string) ([]*v1.Pod, error) { var result []*v1.Pod replicaSelector := &metav1.LabelSelector{ MatchLabels: make(map[string]string), } - replicaSelector.MatchLabels[apiv1.ReplicaTypeLabel] = string(replicaType) + replicaSelector.MatchLabels[apiv1.ReplicaTypeLabel] = replicaType for _, pod := range pods { selector, err := metav1.LabelSelectorAsSelector(replicaSelector) @@ -339,12 +340,13 @@ func (jc *JobController) ReconcilePods( utilruntime.HandleError(fmt.Errorf("couldn't get key for job object %#v: %v", job, err)) return err } - expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, rtype) - // Convert ReplicaType to lower string. - logger := commonutil.LoggerForReplica(metaObject, rtype) + rt := strings.ToLower(string(rtype)) + expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, rt) + + logger := commonutil.LoggerForReplica(metaObject, rt) // Get all pods for the type rt. - pods, err = jc.FilterPodsForReplicaType(pods, rtype) + pods, err = jc.FilterPodsForReplicaType(pods, rt) if err != nil { return err } @@ -362,13 +364,13 @@ func (jc *JobController) ReconcilePods( podSlices := jc.GetPodSlices(pods, numReplicas, logger) for index, podSlice := range podSlices { if len(podSlice) > 1 { - logger.Warningf("We have too many pods for %s %d", rtype, index) + logger.Warningf("We have too many pods for %s %d", rt, index) } else if len(podSlice) == 0 { - logger.Infof("Need to create new pod: %s-%d", rtype, index) + logger.Infof("Need to create new pod: %s-%d", rt, index) // check if this replica is the master role masterRole = jc.Controller.IsMasterRole(replicas, rtype, index) - err = jc.createNewPod(job, rtype, strconv.Itoa(index), spec, masterRole, replicas) + err = jc.createNewPod(job, rt, strconv.Itoa(index), spec, masterRole, replicas) if err != nil { return err } @@ -416,7 +418,7 @@ func (jc *JobController) ReconcilePods( } // createNewPod creates a new pod for the given index and type. -func (jc *JobController) createNewPod(job interface{}, rt apiv1.ReplicaType, index string, spec *apiv1.ReplicaSpec, masterRole bool, +func (jc *JobController) createNewPod(job interface{}, rt, index string, spec *apiv1.ReplicaSpec, masterRole bool, replicas map[apiv1.ReplicaType]*apiv1.ReplicaSpec) error { metaObject, ok := job.(metav1.Object) @@ -436,7 +438,7 @@ func (jc *JobController) createNewPod(job interface{}, rt apiv1.ReplicaType, ind // Set type and index for the worker. labels := jc.GenLabels(metaObject.GetName()) - labels[apiv1.ReplicaTypeLabel] = string(rt) + labels[apiv1.ReplicaTypeLabel] = rt labels[apiv1.ReplicaIndexLabel] = index if masterRole { diff --git a/pkg/controller.v1/common/service.go b/pkg/controller.v1/common/service.go index 9d993f34..ccf5aca5 100644 --- a/pkg/controller.v1/common/service.go +++ b/pkg/controller.v1/common/service.go @@ -16,6 +16,7 @@ package common import ( "fmt" "strconv" + "strings" apiv1 "github.com/kubeflow/common/pkg/apis/common/v1" "github.com/kubeflow/common/pkg/controller.v1/control" @@ -71,8 +72,8 @@ func (jc *JobController) AddService(obj interface{}) { return } - rtypeValue := service.Labels[apiv1.ReplicaTypeLabel] - expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, apiv1.ReplicaType(rtypeValue)) + rtype := service.Labels[apiv1.ReplicaTypeLabel] + expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, rtype) jc.Expectations.CreationObserved(expectationServicesKey) // TODO: we may need add backoff here @@ -137,14 +138,14 @@ func (jc *JobController) GetServicesForJob(jobObject interface{}) ([]*v1.Service } // FilterServicesForReplicaType returns service belong to a replicaType. -func (jc *JobController) FilterServicesForReplicaType(services []*v1.Service, replicaType apiv1.ReplicaType) ([]*v1.Service, error) { +func (jc *JobController) FilterServicesForReplicaType(services []*v1.Service, replicaType string) ([]*v1.Service, error) { var result []*v1.Service replicaSelector := &metav1.LabelSelector{ MatchLabels: make(map[string]string), } - replicaSelector.MatchLabels[apiv1.ReplicaTypeLabel] = string(replicaType) + replicaSelector.MatchLabels[apiv1.ReplicaTypeLabel] = replicaType for _, service := range services { selector, err := metav1.LabelSelectorAsSelector(replicaSelector) @@ -209,9 +210,11 @@ func (jc *JobController) ReconcileServices( rtype apiv1.ReplicaType, spec *apiv1.ReplicaSpec) error { + // Convert ReplicaType to lower string. + rt := strings.ToLower(string(rtype)) replicas := int(*spec.Replicas) // Get all services for the type rt. - services, err := jc.FilterServicesForReplicaType(services, rtype) + services, err := jc.FilterServicesForReplicaType(services, rt) if err != nil { return err } @@ -222,13 +225,13 @@ func (jc *JobController) ReconcileServices( // If replica is 4, return a slice with size 4. [[0],[1],[2],[]], a svc with replica-index 3 will be created. // // If replica is 1, return a slice with size 3. [[0],[1],[2]], svc with replica-index 1 and 2 are out of range and will be deleted. - serviceSlices := jc.GetServiceSlices(services, replicas, commonutil.LoggerForReplica(job, rtype)) + serviceSlices := jc.GetServiceSlices(services, replicas, commonutil.LoggerForReplica(job, rt)) for index, serviceSlice := range serviceSlices { if len(serviceSlice) > 1 { - commonutil.LoggerForReplica(job, rtype).Warningf("We have too many services for %s %d", rtype, index) + commonutil.LoggerForReplica(job, rt).Warningf("We have too many services for %s %d", rt, index) } else if len(serviceSlice) == 0 { - commonutil.LoggerForReplica(job, rtype).Infof("need to create new service: %s-%d", rtype, index) + commonutil.LoggerForReplica(job, rt).Infof("need to create new service: %s-%d", rt, index) err = jc.CreateNewService(job, rtype, spec, strconv.Itoa(index)) if err != nil { return err @@ -279,9 +282,12 @@ func (jc *JobController) CreateNewService(job metav1.Object, rtype apiv1.Replica return err } + // Convert ReplicaType to lower string. + rt := strings.ToLower(string(rtype)) + // Append ReplicaTypeLabel and ReplicaIndexLabel labels. labels := jc.GenLabels(job.GetName()) - labels[apiv1.ReplicaTypeLabel] = string(rtype) + labels[apiv1.ReplicaTypeLabel] = rt labels[apiv1.ReplicaIndexLabel] = index ports, err := jc.GetPortsFromJob(spec) @@ -303,13 +309,13 @@ func (jc *JobController) CreateNewService(job metav1.Object, rtype apiv1.Replica service.Spec.Ports = append(service.Spec.Ports, svcPort) } - service.Name = GenGeneralName(job.GetName(), rtype, index) + service.Name = GenGeneralName(job.GetName(), rt, index) service.Labels = labels // Create OwnerReference. controllerRef := jc.GenOwnerReference(job) // Creation is expected when there is no error returned - expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, rtype) + expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, rt) jc.Expectations.RaiseExpectations(expectationServicesKey, 1, 0) err = jc.ServiceControl.CreateServicesWithControllerRef(job.GetNamespace(), service, job.(runtime.Object), controllerRef) diff --git a/pkg/controller.v1/common/util.go b/pkg/controller.v1/common/util.go index f1800210..55a03739 100644 --- a/pkg/controller.v1/common/util.go +++ b/pkg/controller.v1/common/util.go @@ -44,8 +44,8 @@ func (p ReplicasPriority) Swap(i, j int) { p[i], p[j] = p[j], p[i] } -func GenGeneralName(jobName string, rtype apiv1.ReplicaType, index string) string { - n := jobName + "-" + strings.ToLower(string(rtype)) + "-" + index +func GenGeneralName(jobName string, rtype, index string) string { + n := jobName + "-" + strings.ToLower(rtype) + "-" + index return strings.Replace(n, "/", "-", -1) } diff --git a/pkg/controller.v1/common/util_test.go b/pkg/controller.v1/common/util_test.go index 3b87373d..168b9244 100644 --- a/pkg/controller.v1/common/util_test.go +++ b/pkg/controller.v1/common/util_test.go @@ -18,15 +18,13 @@ import ( "testing" "github.com/stretchr/testify/assert" - - apiv1 "github.com/kubeflow/common/pkg/apis/common/v1" ) func TestGenGeneralName(t *testing.T) { tcs := []struct { index string key string - replicaType apiv1.ReplicaType + replicaType string expectedName string }{ { diff --git a/pkg/controller.v1/expectation/util.go b/pkg/controller.v1/expectation/util.go index 9061000d..4c57c40f 100644 --- a/pkg/controller.v1/expectation/util.go +++ b/pkg/controller.v1/expectation/util.go @@ -1,16 +1,15 @@ package expectation import ( - apiv1 "github.com/kubeflow/common/pkg/apis/common/v1" "strings" ) // GenExpectationPodsKey generates an expectation key for pods of a job -func GenExpectationPodsKey(jobKey string, replicaType apiv1.ReplicaType) string { - return jobKey + "/" + strings.ToLower(string(replicaType)) + "/pods" +func GenExpectationPodsKey(jobKey, replicaType string) string { + return jobKey + "/" + strings.ToLower(replicaType) + "/pods" } // GenExpectationPodsKey generates an expectation key for services of a job -func GenExpectationServicesKey(jobKey string, replicaType apiv1.ReplicaType) string { - return jobKey + "/" + strings.ToLower(string(replicaType)) + "/services" +func GenExpectationServicesKey(jobKey, replicaType string) string { + return jobKey + "/" + strings.ToLower(replicaType) + "/services" } diff --git a/pkg/util/logger.go b/pkg/util/logger.go index a9719fce..d3ce95e6 100644 --- a/pkg/util/logger.go +++ b/pkg/util/logger.go @@ -15,7 +15,6 @@ package util import ( - apiv1 "github.com/kubeflow/common/pkg/apis/common/v1" "strings" log "github.com/sirupsen/logrus" @@ -24,7 +23,7 @@ import ( metav1unstructured "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" ) -func LoggerForReplica(job metav1.Object, rtype apiv1.ReplicaType) *log.Entry { +func LoggerForReplica(job metav1.Object, rtype string) *log.Entry { return log.WithFields(log.Fields{ // We use job to match the key used in controller.go // Its more common in K8s to use a period to indicate namespace.name. So that's what we use. diff --git a/test_job/controller.v1/test_job/test_job_controller.go b/test_job/controller.v1/test_job/test_job_controller.go index 80d37cd9..2e6dd210 100644 --- a/test_job/controller.v1/test_job/test_job_controller.go +++ b/test_job/controller.v1/test_job/test_job_controller.go @@ -65,7 +65,7 @@ func (t *TestJobController) UpdateJobStatusInApiServer(job interface{}, jobStatu return nil } -func (t *TestJobController) SetClusterSpec(job interface{}, podTemplate *corev1.PodTemplateSpec, rtype commonv1.ReplicaType, index string) error { +func (t *TestJobController) SetClusterSpec(job interface{}, podTemplate *corev1.PodTemplateSpec, rtype, index string) error { return nil }