Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restructing common utility functions #792

Merged
merged 2 commits into from
Aug 26, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmd/tf-operator.v2/app/options/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package options

import (
"flag"

"k8s.io/api/core/v1"
)

Expand Down
47 changes: 47 additions & 0 deletions pkg/controller.v2/jobcontroller/pod.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package jobcontroller
import (
"fmt"
"reflect"
"strconv"

log "github.com/sirupsen/logrus"
"k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/labels"
Expand Down Expand Up @@ -190,3 +192,48 @@ func (jc *JobController) GetPodsForJob(job metav1.Object) ([]*v1.Pod, error) {
cm := controller.NewPodControllerRefManager(jc.PodControl, job, selector, jc.Controller.GetAPIGroupVersionKind(), canAdoptFunc)
return cm.ClaimPods(pods)
}

// FilterPodsForReplicaType returns pods belong to a replicaType.
func (jc *JobController) FilterPodsForReplicaType(pods []*v1.Pod, replicaType string) ([]*v1.Pod, error) {
var result []*v1.Pod

replicaSelector := &metav1.LabelSelector{
MatchLabels: make(map[string]string),
}

replicaSelector.MatchLabels[jc.Controller.GetReplicaTypeLabelKey()] = replicaType

for _, pod := range pods {
selector, err := metav1.LabelSelectorAsSelector(replicaSelector)
if err != nil {
return nil, err
}
if !selector.Matches(labels.Set(pod.Labels)) {
continue
}
result = append(result, pod)
}
return result, nil
}

// getPodSlices returns a slice, which element is the slice of pod.
func (jc *JobController) GetPodSlices(pods []*v1.Pod, replicas int, logger *log.Entry) [][]*v1.Pod {
podSlices := make([][]*v1.Pod, replicas)
for _, pod := range pods {
if _, ok := pod.Labels[jc.Controller.GetReplicaIndexLabelKey()]; !ok {
logger.Warning("The pod do not have the index label.")
continue
}
index, err := strconv.Atoi(pod.Labels[jc.Controller.GetReplicaIndexLabelKey()])
if err != nil {
logger.Warningf("Error when strconv.Atoi: %v", err)
continue
}
if index < 0 || index >= replicas {
logger.Warningf("The label index is not expected: %d", index)
} else {
podSlices[index] = append(podSlices[index], pod)
}
}
return podSlices
}
48 changes: 48 additions & 0 deletions pkg/controller.v2/jobcontroller/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package jobcontroller

import (
"fmt"
"strconv"

log "github.com/sirupsen/logrus"
"k8s.io/api/core/v1"
Expand Down Expand Up @@ -98,3 +99,50 @@ func (jc *JobController) GetServicesForJob(job metav1.Object) ([]*v1.Service, er
cm := control.NewServiceControllerRefManager(jc.ServiceControl, job, selector, jc.Controller.GetAPIGroupVersionKind(), canAdoptFunc)
return cm.ClaimServices(services)
}

// FilterServicesForReplicaType returns service belong to a replicaType.
func (jc *JobController) FilterServicesForReplicaType(services []*v1.Service, replicaType string) ([]*v1.Service, error) {
var result []*v1.Service

replicaSelector := &metav1.LabelSelector{
MatchLabels: make(map[string]string),
}

replicaSelector.MatchLabels[jc.Controller.GetReplicaTypeLabelKey()] = replicaType

for _, service := range services {
selector, err := metav1.LabelSelectorAsSelector(replicaSelector)
if err != nil {
return nil, err
}
if !selector.Matches(labels.Set(service.Labels)) {
continue
}
result = append(result, service)
}
return result, nil
}

// getServiceSlices returns a slice, which element is the slice of service.
// Assume the return object is serviceSlices, then serviceSlices[i] is an
// array of pointers to services corresponding to Services for replica i.
func (jc *JobController) GetServiceSlices(services []*v1.Service, replicas int, logger *log.Entry) [][]*v1.Service {
serviceSlices := make([][]*v1.Service, replicas)
for _, service := range services {
if _, ok := service.Labels[jc.Controller.GetReplicaIndexLabelKey()]; !ok {
logger.Warning("The service do not have the index label.")
continue
}
index, err := strconv.Atoi(service.Labels[jc.Controller.GetReplicaIndexLabelKey()])
if err != nil {
logger.Warningf("Error when strconv.Atoi: %v", err)
continue
}
if index < 0 || index >= replicas {
logger.Warningf("The label index is not expected: %d", index)
} else {
serviceSlices[index] = append(serviceSlices[index], service)
}
}
return serviceSlices
}
52 changes: 2 additions & 50 deletions pkg/controller.v2/tfcontroller/pod.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,8 @@ import (
"strconv"
"strings"

log "github.com/sirupsen/logrus"
"k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/labels"
utilruntime "k8s.io/apimachinery/pkg/util/runtime"

tfv1alpha2 "github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1alpha2"
Expand Down Expand Up @@ -54,7 +51,7 @@ func (tc *TFController) reconcilePods(
rt := strings.ToLower(string(rtype))
logger := tflogger.LoggerForReplica(tfjob, rt)
// Get all pods for the type rt.
pods, err := filterPodsForTFReplicaType(pods, rt)
pods, err := tc.FilterPodsForReplicaType(pods, rt)
if err != nil {
return err
}
Expand All @@ -63,7 +60,7 @@ func (tc *TFController) reconcilePods(

initializeTFReplicaStatuses(tfjob, rtype)

podSlices := getPodSlices(pods, replicas, logger)
podSlices := tc.GetPodSlices(pods, replicas, logger)
for index, podSlice := range podSlices {
if len(podSlice) > 1 {
logger.Warningf("We have too many pods for %s %d", rt, index)
Expand Down Expand Up @@ -102,28 +99,6 @@ func (tc *TFController) reconcilePods(
return updateStatusSingle(tfjob, rtype, replicas, restart)
}

// getPodSlices returns a slice, which element is the slice of pod.
func getPodSlices(pods []*v1.Pod, replicas int, logger *log.Entry) [][]*v1.Pod {
podSlices := make([][]*v1.Pod, replicas)
for _, pod := range pods {
if _, ok := pod.Labels[tfReplicaIndexLabel]; !ok {
logger.Warning("The pod do not have the index label.")
continue
}
index, err := strconv.Atoi(pod.Labels[tfReplicaIndexLabel])
if err != nil {
logger.Warningf("Error when strconv.Atoi: %v", err)
continue
}
if index < 0 || index >= replicas {
logger.Warningf("The label index is not expected: %d", index)
} else {
podSlices[index] = append(podSlices[index], pod)
}
}
return podSlices
}

// createNewPod creates a new pod for the given index and type.
func (tc *TFController) createNewPod(tfjob *tfv1alpha2.TFJob, rt, index string, spec *tfv1alpha2.TFReplicaSpec) error {
tfjobKey, err := KeyFunc(tfjob)
Expand Down Expand Up @@ -217,26 +192,3 @@ func setRestartPolicy(podTemplateSpec *v1.PodTemplateSpec, spec *tfv1alpha2.TFRe
podTemplateSpec.Spec.RestartPolicy = v1.RestartPolicy(spec.RestartPolicy)
}
}

// filterPodsForTFReplicaType returns pods belong to a TFReplicaType.
func filterPodsForTFReplicaType(pods []*v1.Pod, tfReplicaType string) ([]*v1.Pod, error) {
var result []*v1.Pod

tfReplicaSelector := &metav1.LabelSelector{
MatchLabels: make(map[string]string),
}

tfReplicaSelector.MatchLabels[tfReplicaTypeLabel] = tfReplicaType

for _, pod := range pods {
selector, err := metav1.LabelSelectorAsSelector(tfReplicaSelector)
if err != nil {
return nil, err
}
if !selector.Matches(labels.Set(pod.Labels)) {
continue
}
result = append(result, pod)
}
return result, nil
}
54 changes: 2 additions & 52 deletions pkg/controller.v2/tfcontroller/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,8 @@ import (
"strconv"
"strings"

log "github.com/sirupsen/logrus"
"k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/labels"
utilruntime "k8s.io/apimachinery/pkg/util/runtime"

tfv1alpha2 "github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1alpha2"
Expand All @@ -45,12 +42,12 @@ func (tc *TFController) reconcileServices(

replicas := int(*spec.Replicas)
// Get all services for the type rt.
services, err := filterServicesForTFReplicaType(services, rt)
services, err := tc.FilterServicesForReplicaType(services, rt)
if err != nil {
return err
}

serviceSlices := getServiceSlices(services, replicas, tflogger.LoggerForReplica(tfjob, rt))
serviceSlices := tc.GetServiceSlices(services, replicas, tflogger.LoggerForReplica(tfjob, rt))

for index, serviceSlice := range serviceSlices {
if len(serviceSlice) > 1 {
Expand All @@ -68,30 +65,6 @@ func (tc *TFController) reconcileServices(
return nil
}

// getServiceSlices returns a slice, which element is the slice of service.
// Assume the return object is serviceSlices, then serviceSlices[i] is an
// array of pointers to services corresponding to Services for replica i.
func getServiceSlices(services []*v1.Service, replicas int, logger *log.Entry) [][]*v1.Service {
serviceSlices := make([][]*v1.Service, replicas)
for _, service := range services {
if _, ok := service.Labels[tfReplicaIndexLabel]; !ok {
logger.Warning("The service do not have the index label.")
continue
}
index, err := strconv.Atoi(service.Labels[tfReplicaIndexLabel])
if err != nil {
logger.Warningf("Error when strconv.Atoi: %v", err)
continue
}
if index < 0 || index >= replicas {
logger.Warningf("The label index is not expected: %d", index)
} else {
serviceSlices[index] = append(serviceSlices[index], service)
}
}
return serviceSlices
}

// createNewService creates a new service for the given index and type.
func (tc *TFController) createNewService(tfjob *tfv1alpha2.TFJob, rtype tfv1alpha2.TFReplicaType, index string, spec *tfv1alpha2.TFReplicaSpec) error {
tfjobKey, err := KeyFunc(tfjob)
Expand Down Expand Up @@ -152,26 +125,3 @@ func (tc *TFController) createNewService(tfjob *tfv1alpha2.TFJob, rtype tfv1alph
}
return nil
}

// filterServicesForTFReplicaType returns service belong to a TFReplicaType.
func filterServicesForTFReplicaType(services []*v1.Service, tfReplicaType string) ([]*v1.Service, error) {
var result []*v1.Service

tfReplicaSelector := &metav1.LabelSelector{
MatchLabels: make(map[string]string),
}

tfReplicaSelector.MatchLabels[tfReplicaTypeLabel] = tfReplicaType

for _, service := range services {
selector, err := metav1.LabelSelectorAsSelector(tfReplicaSelector)
if err != nil {
return nil, err
}
if !selector.Matches(labels.Set(service.Labels)) {
continue
}
result = append(result, service)
}
return result, nil
}