Skip to content
This repository has been archived by the owner on Sep 12, 2023. It is now read-only.

Commit

Permalink
change type of rtype to commonv1.ReplicaType (#135)
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinForReal authored Jun 1, 2021
1 parent 4559a3d commit 31efa75
Show file tree
Hide file tree
Showing 9 changed files with 42 additions and 48 deletions.
2 changes: 1 addition & 1 deletion pkg/apis/common/v1/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ type ControllerInterface interface {
UpdateJobStatusInApiServer(job interface{}, jobStatus *JobStatus) error

// SetClusterSpec sets the cluster spec for the pod
SetClusterSpec(job interface{}, podTemplate *v1.PodTemplateSpec, rtype, index string) error
SetClusterSpec(job interface{}, podTemplate *v1.PodTemplateSpec, rtype ReplicaType, index string) error

// Returns the default container name in pod
GetDefaultContainerName() string
Expand Down
4 changes: 1 addition & 3 deletions pkg/controller.v1/common/job.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"fmt"
"reflect"
"sort"
"strings"
"time"

apiv1 "github.com/kubeflow/common/pkg/apis/common/v1"
Expand Down Expand Up @@ -368,8 +367,7 @@ func (jc *JobController) PastBackoffLimit(jobName string, runPolicy *apiv1.RunPo
continue
}
// Convert ReplicaType to lower string.
rt := strings.ToLower(string(rtype))
pods, err := jc.FilterPodsForReplicaType(pods, rt)
pods, err := jc.FilterPodsForReplicaType(pods, rtype)
if err != nil {
return false, err
}
Expand Down
29 changes: 13 additions & 16 deletions pkg/controller.v1/common/pod.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@ package common

import (
"fmt"
"reflect"
"strconv"
"strings"

"github.com/kubeflow/common/pkg/controller.v1/control"
"github.com/kubeflow/common/pkg/controller.v1/expectation"
"github.com/prometheus/client_golang/prometheus"
Expand All @@ -32,6 +28,8 @@ import (
"k8s.io/apimachinery/pkg/runtime"
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/client-go/tools/cache"
"reflect"
"strconv"

apiv1 "github.com/kubeflow/common/pkg/apis/common/v1"
commonutil "github.com/kubeflow/common/pkg/util"
Expand Down Expand Up @@ -104,7 +102,7 @@ func (jc *JobController) AddPod(obj interface{}) {
}

rtype := pod.Labels[apiv1.ReplicaTypeLabel]
expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, rtype)
expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, apiv1.ReplicaType(rtype))

jc.Expectations.CreationObserved(expectationPodsKey)
// TODO: we may need to add backoff here
Expand Down Expand Up @@ -205,7 +203,7 @@ func (jc *JobController) DeletePod(obj interface{}) {
}

rtype := pod.Labels[apiv1.ReplicaTypeLabel]
expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, rtype)
expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, apiv1.ReplicaType(rtype))

jc.Expectations.DeletionObserved(expectationPodsKey)
deletedPodsCount.Inc()
Expand Down Expand Up @@ -254,14 +252,14 @@ func (jc *JobController) GetPodsForJob(jobObject interface{}) ([]*v1.Pod, error)
}

// FilterPodsForReplicaType returns the pods that belong to a replicaType.
func (jc *JobController) FilterPodsForReplicaType(pods []*v1.Pod, replicaType string) ([]*v1.Pod, error) {
func (jc *JobController) FilterPodsForReplicaType(pods []*v1.Pod, replicaType apiv1.ReplicaType) ([]*v1.Pod, error) {
var result []*v1.Pod

replicaSelector := &metav1.LabelSelector{
MatchLabels: make(map[string]string),
}

replicaSelector.MatchLabels[apiv1.ReplicaTypeLabel] = replicaType
replicaSelector.MatchLabels[apiv1.ReplicaTypeLabel] = string(replicaType)

for _, pod := range pods {
selector, err := metav1.LabelSelectorAsSelector(replicaSelector)
Expand Down Expand Up @@ -337,10 +335,9 @@ func (jc *JobController) ReconcilePods(
}

// Convert ReplicaType to lower string.
rt := strings.ToLower(string(rtype))
logger := commonutil.LoggerForReplica(metaObject, rt)
logger := commonutil.LoggerForReplica(metaObject, rtype)
// Get all pods for the type rt.
pods, err := jc.FilterPodsForReplicaType(pods, rt)
pods, err := jc.FilterPodsForReplicaType(pods, rtype)
if err != nil {
return err
}
Expand All @@ -358,13 +355,13 @@ func (jc *JobController) ReconcilePods(
podSlices := jc.GetPodSlices(pods, numReplicas, logger)
for index, podSlice := range podSlices {
if len(podSlice) > 1 {
logger.Warningf("We have too many pods for %s %d", rt, index)
logger.Warningf("We have too many pods for %s %d", rtype, index)
} else if len(podSlice) == 0 {
logger.Infof("Need to create new pod: %s-%d", rt, index)
logger.Infof("Need to create new pod: %s-%d", rtype, index)

// check if this replica is the master role
masterRole = jc.Controller.IsMasterRole(replicas, rtype, index)
err = jc.createNewPod(job, rt, strconv.Itoa(index), spec, masterRole, replicas)
err = jc.createNewPod(job, rtype, strconv.Itoa(index), spec, masterRole, replicas)
if err != nil {
return err
}
Expand Down Expand Up @@ -408,7 +405,7 @@ func (jc *JobController) ReconcilePods(
}

// createNewPod creates a new pod for the given index and type.
func (jc *JobController) createNewPod(job interface{}, rt, index string, spec *apiv1.ReplicaSpec, masterRole bool,
func (jc *JobController) createNewPod(job interface{}, rt apiv1.ReplicaType, index string, spec *apiv1.ReplicaSpec, masterRole bool,
replicas map[apiv1.ReplicaType]*apiv1.ReplicaSpec) error {

metaObject, ok := job.(metav1.Object)
Expand All @@ -433,7 +430,7 @@ func (jc *JobController) createNewPod(job interface{}, rt, index string, spec *a

// Set type and index for the worker.
labels := jc.GenLabels(metaObject.GetName())
labels[apiv1.ReplicaTypeLabel] = rt
labels[apiv1.ReplicaTypeLabel] = string(rt)
labels[apiv1.ReplicaIndexLabel] = index

if masterRole {
Expand Down
30 changes: 12 additions & 18 deletions pkg/controller.v1/common/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ package common

import (
"fmt"
"strconv"
"strings"

apiv1 "github.com/kubeflow/common/pkg/apis/common/v1"
"github.com/kubeflow/common/pkg/controller.v1/control"
"github.com/kubeflow/common/pkg/controller.v1/expectation"
Expand All @@ -31,6 +28,7 @@ import (
"k8s.io/apimachinery/pkg/labels"
"k8s.io/apimachinery/pkg/runtime"
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
"strconv"
)

var (
Expand Down Expand Up @@ -71,8 +69,8 @@ func (jc *JobController) AddService(obj interface{}) {
return
}

rtype := service.Labels[apiv1.ReplicaTypeLabel]
expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, rtype)
rtypeValue := service.Labels[apiv1.ReplicaTypeLabel]
expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, apiv1.ReplicaType(rtypeValue))

jc.Expectations.CreationObserved(expectationServicesKey)
// TODO: we may need to add backoff here
Expand Down Expand Up @@ -137,14 +135,14 @@ func (jc *JobController) GetServicesForJob(jobObject interface{}) ([]*v1.Service
}

// FilterServicesForReplicaType returns the services that belong to a replicaType.
func (jc *JobController) FilterServicesForReplicaType(services []*v1.Service, replicaType string) ([]*v1.Service, error) {
func (jc *JobController) FilterServicesForReplicaType(services []*v1.Service, replicaType apiv1.ReplicaType) ([]*v1.Service, error) {
var result []*v1.Service

replicaSelector := &metav1.LabelSelector{
MatchLabels: make(map[string]string),
}

replicaSelector.MatchLabels[apiv1.ReplicaTypeLabel] = replicaType
replicaSelector.MatchLabels[apiv1.ReplicaTypeLabel] = string(replicaType)

for _, service := range services {
selector, err := metav1.LabelSelectorAsSelector(replicaSelector)
Expand Down Expand Up @@ -209,12 +207,9 @@ func (jc *JobController) ReconcileServices(
rtype apiv1.ReplicaType,
spec *apiv1.ReplicaSpec) error {

// Convert ReplicaType to lower string.
rt := strings.ToLower(string(rtype))

replicas := int(*spec.Replicas)
// Get all services for the type rt.
services, err := jc.FilterServicesForReplicaType(services, rt)
services, err := jc.FilterServicesForReplicaType(services, rtype)
if err != nil {
return err
}
Expand All @@ -225,13 +220,13 @@ func (jc *JobController) ReconcileServices(
// If replica is 4, return a slice with size 4. [[0],[1],[2],[]], a svc with replica-index 3 will be created.
//
// If replica is 1, return a slice with size 3. [[0],[1],[2]], svc with replica-index 1 and 2 are out of range and will be deleted.
serviceSlices := jc.GetServiceSlices(services, replicas, commonutil.LoggerForReplica(job, rt))
serviceSlices := jc.GetServiceSlices(services, replicas, commonutil.LoggerForReplica(job, rtype))

for index, serviceSlice := range serviceSlices {
if len(serviceSlice) > 1 {
commonutil.LoggerForReplica(job, rt).Warningf("We have too many services for %s %d", rt, index)
commonutil.LoggerForReplica(job, rtype).Warningf("We have too many services for %s %d", rtype, index)
} else if len(serviceSlice) == 0 {
commonutil.LoggerForReplica(job, rt).Infof("need to create new service: %s-%d", rt, index)
commonutil.LoggerForReplica(job, rtype).Infof("need to create new service: %s-%d", rtype, index)
err = jc.CreateNewService(job, rtype, spec, strconv.Itoa(index))
if err != nil {
return err
Expand Down Expand Up @@ -283,16 +278,15 @@ func (jc *JobController) CreateNewService(job metav1.Object, rtype apiv1.Replica
}

// Convert ReplicaType to lower string.
rt := strings.ToLower(string(rtype))
expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, rt)
expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, rtype)
err = jc.Expectations.ExpectCreations(expectationServicesKey, 1)
if err != nil {
return err
}

// Append ReplicaTypeLabel and ReplicaIndexLabel labels.
labels := jc.GenLabels(job.GetName())
labels[apiv1.ReplicaTypeLabel] = rt
labels[apiv1.ReplicaTypeLabel] = string(rtype)
labels[apiv1.ReplicaIndexLabel] = index

ports, err := jc.GetPortsFromJob(spec)
Expand All @@ -314,7 +308,7 @@ func (jc *JobController) CreateNewService(job metav1.Object, rtype apiv1.Replica
service.Spec.Ports = append(service.Spec.Ports, svcPort)
}

service.Name = GenGeneralName(job.GetName(), rt, index)
service.Name = GenGeneralName(job.GetName(), rtype, index)
service.Labels = labels
// Create OwnerReference.
controllerRef := jc.GenOwnerReference(job)
Expand Down
4 changes: 2 additions & 2 deletions pkg/controller.v1/common/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ func (p ReplicasPriority) Swap(i, j int) {
p[i], p[j] = p[j], p[i]
}

func GenGeneralName(jobName, rtype, index string) string {
n := jobName + "-" + rtype + "-" + index
func GenGeneralName(jobName string, rtype apiv1.ReplicaType, index string) string {
n := jobName + "-" + string(rtype) + "-" + index
return strings.Replace(n, "/", "-", -1)
}

Expand Down
3 changes: 2 additions & 1 deletion pkg/controller.v1/common/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@ package common

import (
"fmt"
apiv1 "github.com/kubeflow/common/pkg/apis/common/v1"
"github.com/stretchr/testify/assert"
"testing"
)

func TestGenGeneralName(t *testing.T) {
testRType := "worker"
var testRType apiv1.ReplicaType = "worker"
testIndex := "1"
testKey := "1/2/3/4/5"
expectedName := fmt.Sprintf("1-2-3-4-5-%s-%s", testRType, testIndex)
Expand Down
13 changes: 8 additions & 5 deletions pkg/controller.v1/expectation/util.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package expectation

import "strings"
import (
apiv1 "github.com/kubeflow/common/pkg/apis/common/v1"
"strings"
)

// GenExpectationPodsKey generates an expectation key for pods of a job
func GenExpectationPodsKey(jobKey, replicaType string) string {
return jobKey + "/" + strings.ToLower(replicaType) + "/pods"
func GenExpectationPodsKey(jobKey string, replicaType apiv1.ReplicaType) string {
return jobKey + "/" + strings.ToLower(string(replicaType)) + "/pods"
}

// GenExpectationServicesKey generates an expectation key for services of a job
func GenExpectationServicesKey(jobKey, replicaType string) string {
return jobKey + "/" + strings.ToLower(replicaType) + "/services"
func GenExpectationServicesKey(jobKey string, replicaType apiv1.ReplicaType) string {
return jobKey + "/" + strings.ToLower(string(replicaType)) + "/services"
}
3 changes: 2 additions & 1 deletion pkg/util/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package util

import (
apiv1 "github.com/kubeflow/common/pkg/apis/common/v1"
"strings"

log "github.com/sirupsen/logrus"
Expand All @@ -23,7 +24,7 @@ import (
metav1unstructured "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
)

func LoggerForReplica(job metav1.Object, rtype string) *log.Entry {
func LoggerForReplica(job metav1.Object, rtype apiv1.ReplicaType) *log.Entry {
return log.WithFields(log.Fields{
// We use job to match the key used in controller.go
// It's more common in K8s to use a period to indicate namespace.name, so that's what we use.
Expand Down
2 changes: 1 addition & 1 deletion test_job/controller.v1/test_job/test_job_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (t *TestJobController) UpdateJobStatusInApiServer(job interface{}, jobStatu
return nil
}

func (t *TestJobController) SetClusterSpec(job interface{}, podTemplate *corev1.PodTemplateSpec, rtype, index string) error {
func (t *TestJobController) SetClusterSpec(job interface{}, podTemplate *corev1.PodTemplateSpec, rtype commonv1.ReplicaType, index string) error {
return nil
}

Expand Down

0 comments on commit 31efa75

Please sign in to comment.