Skip to content

Commit

Permalink
support specifying the tfjob role sequence when scheduling
Browse files Browse the repository at this point in the history
  • Loading branch information
happy2048 committed Feb 1, 2021
1 parent d746bde commit 61b5919
Showing 1 changed file with 48 additions and 8 deletions.
56 changes: 48 additions & 8 deletions pkg/controller.v1/tensorflow/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@ package tensorflow

import (
"fmt"
"os"
"reflect"
"strings"
"time"

kubebatchclient "github.com/kubernetes-sigs/kube-batch/pkg/client/clientset/versioned"
log "github.com/sirupsen/logrus"
"k8s.io/api/core/v1"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/apimachinery/pkg/util/wait"
Expand Down Expand Up @@ -52,11 +53,13 @@ const (
controllerName = "tf-operator"

// labels for pods and servers.
tfReplicaTypeLabel = "tf-replica-type"
tfReplicaIndexLabel = "tf-replica-index"
labelGroupName = "group-name"
labelTFJobName = "tf-job-name"
labelTFJobRole = "tf-job-role"
tfReplicaTypeLabel = "tf-replica-type"
tfReplicaIndexLabel = "tf-replica-index"
labelGroupName = "group-name"
labelTFJobName = "tf-job-name"
labelTFJobRole = "tf-job-role"
roleSequenceEnvKey = "ROLE_SEQUENCE"
roleSequenceAnnotationKey = "job-role-sequence"
)

var (
Expand Down Expand Up @@ -445,9 +448,16 @@ func (tc *TFController) reconcileTFJobs(tfjob *tfv1.TFJob) error {

// Save the current state of the replicas
replicasStatus := make(map[string]v1.PodPhase)

// get the custom role sequence.
roles := sortTFJobRoles(tfjob.Spec.TFReplicaSpecs, pods)
logger.Infof("the Role Sequence of job %v is: %v", tfjob.Name, roles)
// Diff current active pods/services with replicas.
for rtype, spec := range tfjob.Spec.TFReplicaSpecs {
for _, rtype := range roles {
spec := tfjob.Spec.TFReplicaSpecs[rtype]
if spec == nil {
logger.Infof("this job has no role: %v,skip to handle it", rtype)
continue
}
err = tc.reconcilePods(tfjob, pods, rtype, spec, replicasStatus)
if err != nil {
logger.Warnf("reconcilePods error %v", err)
Expand Down Expand Up @@ -588,3 +598,33 @@ func (tc *TFController) GetJobRoleKey() string {
// ControllerName returns the name of this controller, i.e. the value of the
// package-level controllerName constant.
func (tc *TFController) ControllerName() string {
	name := controllerName
	return name
}

// sortTFJobRoles returns the order in which the replica types of a TFJob
// should be processed.
//
// The order comes from the first of these sources that is set:
//  1. the roleSequenceAnnotationKey annotation of the first pod in pods that
//     carries a non-blank value;
//  2. the roleSequenceEnvKey environment variable of the operator process;
//  3. otherwise no explicit order is requested.
//
// A sequence is a comma-separated list of replica type names. Whitespace
// around each name is trimmed, empty entries are dropped, and a name that is
// repeated keeps only its first position, so no replica type is returned
// twice. Replica types present in roleSpecs but absent from the sequence are
// appended afterwards in map iteration order, which Go leaves unspecified.
//
// Names in the sequence that do not exist in roleSpecs are kept in the
// result; the caller is expected to skip roles it has no spec for.
// The returned slice is never nil.
func sortTFJobRoles(roleSpecs map[tfv1.TFReplicaType]*common.ReplicaSpec, pods []*v1.Pod) []tfv1.TFReplicaType {
	// buildOrder turns a raw comma-separated sequence into the final,
	// duplicate-free role order.
	buildOrder := func(rawSeq string) []tfv1.TFReplicaType {
		var names []string
		if rawSeq != "" {
			names = strings.Split(rawSeq, ",")
		}
		capacity := len(names) + len(roleSpecs)
		roles := make([]tfv1.TFReplicaType, 0, capacity)
		seen := make(map[tfv1.TFReplicaType]struct{}, capacity)

		// Explicitly requested roles come first, in the requested order.
		for _, name := range names {
			name = strings.TrimSpace(name)
			if name == "" {
				// Tolerate stray commas such as "Worker,,PS" or "Worker,".
				continue
			}
			rtype := tfv1.TFReplicaType(name)
			if _, dup := seen[rtype]; dup {
				// Listing a role twice must not make it appear twice.
				continue
			}
			seen[rtype] = struct{}{}
			roles = append(roles, rtype)
		}

		// Every remaining role of the job follows, so that none is left out.
		for rtype := range roleSpecs {
			if _, dup := seen[rtype]; dup {
				continue
			}
			seen[rtype] = struct{}{}
			roles = append(roles, rtype)
		}
		return roles
	}

	// A role sequence defined on a pod annotation takes precedence.
	for _, pod := range pods {
		if pod == nil {
			continue
		}
		if seq := strings.TrimSpace(pod.Annotations[roleSequenceAnnotationKey]); seq != "" {
			return buildOrder(seq)
		}
	}

	// Otherwise fall back to the operator-wide sequence from the environment.
	if seq := strings.TrimSpace(os.Getenv(roleSequenceEnvKey)); seq != "" {
		return buildOrder(seq)
	}

	// No sequence configured: the order is whatever map iteration yields.
	return buildOrder("")
}

0 comments on commit 61b5919

Please sign in to comment.