diff --git a/pkg/trainer/replicas.go b/pkg/trainer/replicas.go index 29c3b305f2..c8cb6407d4 100644 --- a/pkg/trainer/replicas.go +++ b/pkg/trainer/replicas.go @@ -174,8 +174,9 @@ func (s *TFReplicaSet) CreatePodWithIndex(index int32) (*v1.Pod, error) { pod := &v1.Pod{ ObjectMeta: meta_v1.ObjectMeta{ - Name: s.genPodName(index), - Labels: taskLabels, + Name: s.genPodName(index), + Labels: taskLabels, + Annotations: map[string]string{}, OwnerReferences: []meta_v1.OwnerReference{ helper.AsOwner(s.Job.job), }, @@ -185,6 +186,19 @@ func (s *TFReplicaSet) CreatePodWithIndex(index int32) (*v1.Pod, error) { pod.Spec.SchedulerName = s.Job.SchedulerName() + // copy labels and annotations to pod from tfjob + for k, v := range s.Spec.Template.Labels { + if _, ok := pod.Labels[k]; !ok { + pod.Labels[k] = v + } + } + + for k, v := range s.Spec.Template.Annotations { + if _, ok := pod.Annotations[k]; !ok { + pod.Annotations[k] = v + } + } + // Configure the TFCONFIG environment variable. tfConfig := TFConfig{ Cluster: s.Job.ClusterSpec(), diff --git a/pkg/trainer/replicas_test.go b/pkg/trainer/replicas_test.go index 2fe0345070..502a32d6f9 100644 --- a/pkg/trainer/replicas_test.go +++ b/pkg/trainer/replicas_test.go @@ -74,6 +74,8 @@ func TestTFReplicaSet(t *testing.T) { }, } + jobSpec.Spec.ReplicaSpecs[0].Template.Labels = map[string]string{"some-label": "some-value"} + jobSpec.Spec.ReplicaSpecs[0].Template.Annotations = map[string]string{"some-anno": "some-value"} recorder := record.NewFakeRecorder(100) job, err := initJob(clientSet, &tfJobFake.Clientset{}, recorder, jobSpec) @@ -107,13 +109,25 @@ func TestTFReplicaSet(t *testing.T) { for index := 0; index < 2; index++ { // Expected labels - expectedLabels := map[string]string{ + expectedServiceLabels := map[string]string{ "kubeflow.org": "", "task_index": fmt.Sprintf("%v", index), "job_type": "PS", "runtime_id": "some-runtime", "tf_job_name": "some-job", } + expectedPodLabels := map[string]string{ + "kubeflow.org": "", + "task_index": fmt.Sprintf("%v", index), + "job_type": "PS", + "runtime_id": "some-runtime", + "tf_job_name": "some-job", + "some-label": "some-value", + } + + expectedPodAnnotations := map[string]string{ + "some-anno": "some-value", + } // Check that a service was created. sList, err := clientSet.CoreV1().Services(replica.Job.job.ObjectMeta.Namespace).List(meta_v1.ListOptions{}) @@ -127,8 +141,8 @@ func TestTFReplicaSet(t *testing.T) { s := sList.Items[index] - if !reflect.DeepEqual(expectedLabels, s.ObjectMeta.Labels) { - t.Fatalf("Service Labels; Got %v Want: %v", s.ObjectMeta.Labels, expectedLabels) + if !reflect.DeepEqual(expectedServiceLabels, s.ObjectMeta.Labels) { + t.Fatalf("Service Labels; Got %v Want: %v", s.ObjectMeta.Labels, expectedServiceLabels) } name := fmt.Sprintf("some-job-ps-some-runtime-%v", index) @@ -156,8 +170,12 @@ func TestTFReplicaSet(t *testing.T) { p := l.Items[index] - if !reflect.DeepEqual(expectedLabels, p.ObjectMeta.Labels) { - t.Fatalf("Pod Labels; Got %v Want: %v", p.ObjectMeta.Labels, expectedLabels) + if !reflect.DeepEqual(expectedPodLabels, p.ObjectMeta.Labels) { + t.Fatalf("Pod Labels; Got %v Want: %v", p.ObjectMeta.Labels, expectedPodLabels) + } + + if !reflect.DeepEqual(expectedPodAnnotations, p.ObjectMeta.Annotations) { + t.Fatalf("Pod Annotations; Got %v Want: %v", p.ObjectMeta.Annotations, expectedPodAnnotations) } if len(p.Spec.Containers) != 1 {