Skip to content

Commit

Permalink
copy labels and anotations to pod from tfjob (kubeflow#542)
Browse files Browse the repository at this point in the history
  • Loading branch information
u2takey authored and Penghui Yan committed Jun 18, 2018
1 parent a424377 commit 8656a8a
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 7 deletions.
18 changes: 16 additions & 2 deletions pkg/trainer/replicas.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,9 @@ func (s *TFReplicaSet) CreatePodWithIndex(index int32) (*v1.Pod, error) {

pod := &v1.Pod{
ObjectMeta: meta_v1.ObjectMeta{
Name: s.genPodName(index),
Labels: taskLabels,
Name: s.genPodName(index),
Labels: taskLabels,
Annotations: map[string]string{},
OwnerReferences: []meta_v1.OwnerReference{
helper.AsOwner(s.Job.job),
},
Expand All @@ -185,6 +186,19 @@ func (s *TFReplicaSet) CreatePodWithIndex(index int32) (*v1.Pod, error) {

pod.Spec.SchedulerName = s.Job.SchedulerName()

// copy labels and annotations to pod from tfjob
for k, v := range s.Spec.Template.Labels {
if _, ok := pod.Labels[k]; !ok {
pod.Labels[k] = v
}
}

for k, v := range s.Spec.Template.Annotations {
if _, ok := pod.Annotations[k]; !ok {
pod.Annotations[k] = v
}
}

// Configure the TFCONFIG environment variable.
tfConfig := TFConfig{
Cluster: s.Job.ClusterSpec(),
Expand Down
28 changes: 23 additions & 5 deletions pkg/trainer/replicas_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ func TestTFReplicaSet(t *testing.T) {
},
}

jobSpec.Spec.ReplicaSpecs[0].Template.Labels = map[string]string{"some-label": "some-value"}
jobSpec.Spec.ReplicaSpecs[0].Template.Annotations = map[string]string{"some-anno": "some-value"}
recorder := record.NewFakeRecorder(100)
job, err := initJob(clientSet, &tfJobFake.Clientset{}, recorder, jobSpec)

Expand Down Expand Up @@ -107,13 +109,25 @@ func TestTFReplicaSet(t *testing.T) {

for index := 0; index < 2; index++ {
// Expected labels
expectedLabels := map[string]string{
expectedServiceLabels := map[string]string{
"kubeflow.org": "",
"task_index": fmt.Sprintf("%v", index),
"job_type": "PS",
"runtime_id": "some-runtime",
"tf_job_name": "some-job",
}
expectedPodLabels := map[string]string{
"kubeflow.org": "",
"task_index": fmt.Sprintf("%v", index),
"job_type": "PS",
"runtime_id": "some-runtime",
"tf_job_name": "some-job",
"some-label": "some-value",
}

expectedPodAnnotations := map[string]string{
"some-anno": "some-value",
}

// Check that a service was created.
sList, err := clientSet.CoreV1().Services(replica.Job.job.ObjectMeta.Namespace).List(meta_v1.ListOptions{})
Expand All @@ -127,8 +141,8 @@ func TestTFReplicaSet(t *testing.T) {

s := sList.Items[index]

if !reflect.DeepEqual(expectedLabels, s.ObjectMeta.Labels) {
t.Fatalf("Service Labels; Got %v Want: %v", s.ObjectMeta.Labels, expectedLabels)
if !reflect.DeepEqual(expectedServiceLabels, s.ObjectMeta.Labels) {
t.Fatalf("Service Labels; Got %v Want: %v", s.ObjectMeta.Labels, expectedServiceLabels)
}

name := fmt.Sprintf("some-job-ps-some-runtime-%v", index)
Expand Down Expand Up @@ -156,8 +170,12 @@ func TestTFReplicaSet(t *testing.T) {

p := l.Items[index]

if !reflect.DeepEqual(expectedLabels, p.ObjectMeta.Labels) {
t.Fatalf("Pod Labels; Got %v Want: %v", p.ObjectMeta.Labels, expectedLabels)
if !reflect.DeepEqual(expectedPodLabels, p.ObjectMeta.Labels) {
t.Fatalf("Pod Labels; Got %v Want: %v", p.ObjectMeta.Labels, expectedPodLabels)
}

if !reflect.DeepEqual(expectedPodAnnotations, p.ObjectMeta.Annotations) {
t.Fatalf("Pod Annotations; Got %v Want: %v", p.ObjectMeta.Annotations, expectedPodAnnotations)
}

if len(p.Spec.Containers) != 1 {
Expand Down

0 comments on commit 8656a8a

Please sign in to comment.