Skip to content

Commit

Permalink
Revert ReplicaType to string (kubeflow#173)
Browse files Browse the repository at this point in the history
  • Loading branch information
zw0610 authored Nov 22, 2021
1 parent bbdf282 commit 34276e9
Show file tree
Hide file tree
Showing 21 changed files with 107 additions and 95 deletions.
4 changes: 2 additions & 2 deletions pkg/apis/common/v1/interface.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package v1

import (
"k8s.io/api/core/v1"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime/schema"
)
Expand Down Expand Up @@ -44,7 +44,7 @@ type ControllerInterface interface {
UpdateJobStatusInApiServer(job interface{}, jobStatus *JobStatus) error

// SetClusterSpec sets the cluster spec for the pod
SetClusterSpec(job interface{}, podTemplate *v1.PodTemplateSpec, rtype ReplicaType, index string) error
SetClusterSpec(job interface{}, podTemplate *v1.PodTemplateSpec, rtype, index string) error

// Returns the default container name in pod
GetDefaultContainerName() string
Expand Down
4 changes: 2 additions & 2 deletions pkg/controller.v1/common/job.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,9 +300,9 @@ func (jc *JobController) ReconcileJobs(
// ResetExpectations resets the expectations for creates and deletes of pods/services to zero.
func (jc *JobController) ResetExpectations(jobKey string, replicas map[apiv1.ReplicaType]*apiv1.ReplicaSpec) {
for rtype := range replicas {
expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, rtype)
expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, string(rtype))
jc.Expectations.SetExpectations(expectationPodsKey, 0, 0)
expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, rtype)
expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, string(rtype))
jc.Expectations.SetExpectations(expectationServicesKey, 0, 0)
}
}
Expand Down
30 changes: 16 additions & 14 deletions pkg/controller.v1/common/pod.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"fmt"
"reflect"
"strconv"
"strings"

apiv1 "github.com/kubeflow/common/pkg/apis/common/v1"
"github.com/kubeflow/common/pkg/controller.v1/control"
Expand Down Expand Up @@ -105,7 +106,7 @@ func (jc *JobController) AddPod(obj interface{}) {
return
}

expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, rType)
expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, string(rType))

jc.Expectations.CreationObserved(expectationPodsKey)
// TODO: we may need add backoff here
Expand Down Expand Up @@ -206,7 +207,7 @@ func (jc *JobController) DeletePod(obj interface{}) {
return
}

expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, rType)
expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, string(rType))

jc.Expectations.DeletionObserved(expectationPodsKey)
deletedPodsCount.Inc()
Expand Down Expand Up @@ -255,7 +256,7 @@ func (jc *JobController) GetPodsForJob(jobObject interface{}) ([]*v1.Pod, error)
}

// FilterPodsForReplicaType returns pods that belong to a replicaType.
func (jc *JobController) FilterPodsForReplicaType(pods []*v1.Pod, replicaType apiv1.ReplicaType) ([]*v1.Pod, error) {
func (jc *JobController) FilterPodsForReplicaType(pods []*v1.Pod, replicaType string) ([]*v1.Pod, error) {
return core.FilterPodsForReplicaType(pods, replicaType)
}

Expand All @@ -271,10 +272,11 @@ func (jc *JobController) ReconcilePods(
job interface{},
jobStatus *apiv1.JobStatus,
pods []*v1.Pod,
rtype apiv1.ReplicaType,
rType apiv1.ReplicaType,
spec *apiv1.ReplicaSpec,
replicas map[apiv1.ReplicaType]*apiv1.ReplicaSpec) error {

rt := strings.ToLower(string(rType))
metaObject, ok := job.(metav1.Object)
if !ok {
return fmt.Errorf("job is not a metav1.Object type")
Expand All @@ -288,19 +290,19 @@ func (jc *JobController) ReconcilePods(
utilruntime.HandleError(fmt.Errorf("couldn't get key for job object %#v: %v", job, err))
return err
}
expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, rtype)
expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, rt)

// Convert ReplicaType to lower string.
logger := commonutil.LoggerForReplica(metaObject, rtype)
logger := commonutil.LoggerForReplica(metaObject, rt)
// Get all pods for the type rt.
pods, err = jc.FilterPodsForReplicaType(pods, rtype)
pods, err = jc.FilterPodsForReplicaType(pods, rt)
if err != nil {
return err
}
numReplicas := int(*spec.Replicas)
var masterRole bool

initializeReplicaStatuses(jobStatus, rtype)
initializeReplicaStatuses(jobStatus, rType)

// GetPodSlices will return enough information here to make decision to add/remove/update resources.
//
Expand All @@ -311,13 +313,13 @@ func (jc *JobController) ReconcilePods(
podSlices := jc.GetPodSlices(pods, numReplicas, logger)
for index, podSlice := range podSlices {
if len(podSlice) > 1 {
logger.Warningf("We have too many pods for %s %d", rtype, index)
logger.Warningf("We have too many pods for %s %d", rt, index)
} else if len(podSlice) == 0 {
logger.Infof("Need to create new pod: %s-%d", rtype, index)
logger.Infof("Need to create new pod: %s-%d", rt, index)

// check if this replica is the master role
masterRole = jc.Controller.IsMasterRole(replicas, rtype, index)
err = jc.createNewPod(job, rtype, index, spec, masterRole, replicas)
masterRole = jc.Controller.IsMasterRole(replicas, rType, index)
err = jc.createNewPod(job, rt, index, spec, masterRole, replicas)
if err != nil {
return err
}
Expand Down Expand Up @@ -358,14 +360,14 @@ func (jc *JobController) ReconcilePods(
}
}

updateJobReplicaStatuses(jobStatus, rtype, pod)
updateJobReplicaStatuses(jobStatus, rType, pod)
}
}
return nil
}

// createNewPod creates a new pod for the given index and type.
func (jc *JobController) createNewPod(job interface{}, rt apiv1.ReplicaType, index int, spec *apiv1.ReplicaSpec, masterRole bool,
func (jc *JobController) createNewPod(job interface{}, rt string, index int, spec *apiv1.ReplicaSpec, masterRole bool,
replicas map[apiv1.ReplicaType]*apiv1.ReplicaSpec) error {

metaObject, ok := job.(metav1.Object)
Expand Down
22 changes: 13 additions & 9 deletions pkg/controller.v1/common/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package common
import (
"fmt"
"strconv"
"strings"

apiv1 "github.com/kubeflow/common/pkg/apis/common/v1"
"github.com/kubeflow/common/pkg/controller.v1/control"
Expand Down Expand Up @@ -74,7 +75,7 @@ func (jc *JobController) AddService(obj interface{}) {
return
}

expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, rType)
expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, string(rType))

jc.Expectations.CreationObserved(expectationServicesKey)
// TODO: we may need add backoff here
Expand Down Expand Up @@ -139,7 +140,7 @@ func (jc *JobController) GetServicesForJob(jobObject interface{}) ([]*v1.Service
}

// FilterServicesForReplicaType returns services that belong to a replicaType.
func (jc *JobController) FilterServicesForReplicaType(services []*v1.Service, replicaType apiv1.ReplicaType) ([]*v1.Service, error) {
func (jc *JobController) FilterServicesForReplicaType(services []*v1.Service, replicaType string) ([]*v1.Service, error) {
return core.FilterServicesForReplicaType(services, replicaType)
}

Expand All @@ -158,9 +159,11 @@ func (jc *JobController) ReconcileServices(
rtype apiv1.ReplicaType,
spec *apiv1.ReplicaSpec) error {

// Convert ReplicaType to lower string.
rt := strings.ToLower(string(rtype))
replicas := int(*spec.Replicas)
// Get all services for the type rt.
services, err := jc.FilterServicesForReplicaType(services, rtype)
services, err := jc.FilterServicesForReplicaType(services, rt)
if err != nil {
return err
}
Expand All @@ -171,13 +174,13 @@ func (jc *JobController) ReconcileServices(
// If replica is 4, return a slice with size 4. [[0],[1],[2],[]], a svc with replica-index 3 will be created.
//
// If replica is 1, return a slice with size 3. [[0],[1],[2]], svc with replica-index 1 and 2 are out of range and will be deleted.
serviceSlices := jc.GetServiceSlices(services, replicas, commonutil.LoggerForReplica(job, rtype))
serviceSlices := jc.GetServiceSlices(services, replicas, commonutil.LoggerForReplica(job, rt))

for index, serviceSlice := range serviceSlices {
if len(serviceSlice) > 1 {
commonutil.LoggerForReplica(job, rtype).Warningf("We have too many services for %s %d", rtype, index)
commonutil.LoggerForReplica(job, rt).Warningf("We have too many services for %s %d", rtype, index)
} else if len(serviceSlice) == 0 {
commonutil.LoggerForReplica(job, rtype).Infof("need to create new service: %s-%d", rtype, index)
commonutil.LoggerForReplica(job, rt).Infof("need to create new service: %s-%d", rtype, index)
err = jc.CreateNewService(job, rtype, spec, strconv.Itoa(index))
if err != nil {
return err
Expand Down Expand Up @@ -212,9 +215,10 @@ func (jc *JobController) CreateNewService(job metav1.Object, rtype apiv1.Replica
return err
}

rt := strings.ToLower(string(rtype))
// Append ReplicaTypeLabelDeprecated and ReplicaIndexLabelDeprecated labels.
labels := jc.GenLabels(job.GetName())
utillabels.SetReplicaType(labels, rtype)
utillabels.SetReplicaType(labels, rt)
utillabels.SetReplicaIndexStr(labels, index)

ports, err := jc.GetPortsFromJob(spec)
Expand All @@ -236,13 +240,13 @@ func (jc *JobController) CreateNewService(job metav1.Object, rtype apiv1.Replica
service.Spec.Ports = append(service.Spec.Ports, svcPort)
}

service.Name = GenGeneralName(job.GetName(), rtype, index)
service.Name = GenGeneralName(job.GetName(), rt, index)
service.Labels = labels
// Create OwnerReference.
controllerRef := jc.GenOwnerReference(job)

// Creation is expected when there is no error returned
expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, rtype)
expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, rt)
jc.Expectations.RaiseExpectations(expectationServicesKey, 1, 0)

err = jc.ServiceControl.CreateServicesWithControllerRef(job.GetNamespace(), service, job.(runtime.Object), controllerRef)
Expand Down
4 changes: 2 additions & 2 deletions pkg/controller.v1/common/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ func (p ReplicasPriority) Swap(i, j int) {
p[i], p[j] = p[j], p[i]
}

func GenGeneralName(jobName string, rtype apiv1.ReplicaType, index string) string {
n := jobName + "-" + strings.ToLower(string(rtype)) + "-" + index
func GenGeneralName(jobName string, rtype string, index string) string {
n := jobName + "-" + strings.ToLower(rtype) + "-" + index
return strings.Replace(n, "/", "-", -1)
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/controller.v1/common/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func TestGenGeneralName(t *testing.T) {
}

for _, tc := range tcs {
actual := GenGeneralName(tc.key, tc.replicaType, tc.index)
actual := GenGeneralName(tc.key, string(tc.replicaType), tc.index)
if actual != tc.expectedName {
t.Errorf("Expected name %s, got %s", tc.expectedName, actual)
}
Expand Down
9 changes: 4 additions & 5 deletions pkg/controller.v1/expectation/util.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
package expectation

import (
apiv1 "github.com/kubeflow/common/pkg/apis/common/v1"
"strings"
)

// GenExpectationPodsKey generates an expectation key for pods of a job
func GenExpectationPodsKey(jobKey string, replicaType apiv1.ReplicaType) string {
return jobKey + "/" + strings.ToLower(string(replicaType)) + "/pods"
func GenExpectationPodsKey(jobKey string, replicaType string) string {
return jobKey + "/" + strings.ToLower(replicaType) + "/pods"
}

// GenExpectationServicesKey generates an expectation key for services of a job
func GenExpectationServicesKey(jobKey string, replicaType apiv1.ReplicaType) string {
return jobKey + "/" + strings.ToLower(string(replicaType)) + "/services"
func GenExpectationServicesKey(jobKey string, replicaType string) string {
return jobKey + "/" + strings.ToLower(replicaType) + "/services"
}
6 changes: 4 additions & 2 deletions pkg/core/job.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package core

import (
"sort"
"strings"
"time"

log "github.com/sirupsen/logrus"
Expand Down Expand Up @@ -77,7 +78,7 @@ func PastActiveDeadline(runPolicy *apiv1.RunPolicy, jobStatus apiv1.JobStatus) b
// this method applies only to pods when restartPolicy is one of OnFailure, Always or ExitCode
func PastBackoffLimit(jobName string, runPolicy *apiv1.RunPolicy,
replicas map[apiv1.ReplicaType]*apiv1.ReplicaSpec, pods []*v1.Pod,
podFilterFunc func(pods []*v1.Pod, replicaType apiv1.ReplicaType) ([]*v1.Pod, error)) (bool, error) {
podFilterFunc func(pods []*v1.Pod, replicaType string) ([]*v1.Pod, error)) (bool, error) {
if runPolicy.BackoffLimit == nil {
return false, nil
}
Expand All @@ -88,7 +89,8 @@ func PastBackoffLimit(jobName string, runPolicy *apiv1.RunPolicy,
continue
}
// Convert ReplicaType to lower string.
pods, err := podFilterFunc(pods, rtype)
rt := strings.ToLower(string(rtype))
pods, err := podFilterFunc(pods, rt)
if err != nil {
return false, err
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/core/pod.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@ import (
)

// FilterPodsForReplicaType returns pods that belong to a replicaType.
func FilterPodsForReplicaType(pods []*v1.Pod, replicaType apiv1.ReplicaType) ([]*v1.Pod, error) {
func FilterPodsForReplicaType(pods []*v1.Pod, replicaType string) ([]*v1.Pod, error) {
var result []*v1.Pod

selector := labels.SelectorFromValidatedSet(labels.Set{
apiv1.ReplicaTypeLabel: string(replicaType),
apiv1.ReplicaTypeLabel: replicaType,
})

// TODO(#149): Remove deprecated selector.
deprecatedSelector := labels.SelectorFromValidatedSet(labels.Set{
apiv1.ReplicaTypeLabelDeprecated: string(replicaType),
apiv1.ReplicaTypeLabelDeprecated: replicaType,
})

for _, pod := range pods {
Expand Down
6 changes: 3 additions & 3 deletions pkg/core/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@ import (
)

// FilterServicesForReplicaType returns services that belong to a replicaType.
func FilterServicesForReplicaType(services []*v1.Service, replicaType apiv1.ReplicaType) ([]*v1.Service, error) {
func FilterServicesForReplicaType(services []*v1.Service, replicaType string) ([]*v1.Service, error) {
var result []*v1.Service

selector := labels.SelectorFromValidatedSet(labels.Set{
apiv1.ReplicaTypeLabel: string(replicaType),
apiv1.ReplicaTypeLabel: replicaType,
})

// TODO(#149): Remove deprecated selector.
deprecatedSelector := labels.SelectorFromValidatedSet(labels.Set{
apiv1.ReplicaTypeLabelDeprecated: string(replicaType),
apiv1.ReplicaTypeLabelDeprecated: replicaType,
})

for _, service := range services {
Expand Down
6 changes: 2 additions & 4 deletions pkg/core/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ package core

import (
"strings"

commonv1 "github.com/kubeflow/common/pkg/apis/common/v1"
)

func MaxInt(x, y int) int {
Expand All @@ -13,7 +11,7 @@ func MaxInt(x, y int) int {
return x
}

func GenGeneralName(jobName string, rtype commonv1.ReplicaType, index string) string {
n := jobName + "-" + strings.ToLower(string(rtype)) + "-" + index
func GenGeneralName(jobName string, rtype string, index string) string {
n := jobName + "-" + strings.ToLower(rtype) + "-" + index
return strings.Replace(n, "/", "-", -1)
}
2 changes: 1 addition & 1 deletion pkg/reconciler.v1/common/gang_volcano.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ func (r *VolcanoReconciler) ReconcilePodGroup(
}

// DecoratePodForGangScheduling decorates the podTemplate before it's used to generate a pod with information for gang-scheduling
func (r *VolcanoReconciler) DecoratePodForGangScheduling(rtype commonv1.ReplicaType, podTemplate *corev1.PodTemplateSpec, job client.Object) {
func (r *VolcanoReconciler) DecoratePodForGangScheduling(rtype string, podTemplate *corev1.PodTemplateSpec, job client.Object) {
if podTemplate.Spec.SchedulerName == "" || podTemplate.Spec.SchedulerName == r.GetGangSchedulerName() {
podTemplate.Spec.SchedulerName = r.GetGangSchedulerName()
} else {
Expand Down
Loading

0 comments on commit 34276e9

Please sign in to comment.