From f2bbf41f59e31f1fd5b9504931581001ca52a6a6 Mon Sep 17 00:00:00 2001 From: Ce Gao Date: Fri, 15 Jun 2018 12:42:28 +0800 Subject: [PATCH] *: Move test util to separate package (#666) * *: Refactor Signed-off-by: Ce Gao * travis: Remove util from coverage test Signed-off-by: Ce Gao * test: Add copyright holder Signed-off-by: Ce Gao * *: Fix errors Signed-off-by: Ce Gao * generator: Fix test Signed-off-by: Ce Gao * pods: Add error handler Signed-off-by: Ce Gao * testutil: Fix linting errors Signed-off-by: Ce Gao --- .travis.yml | 2 +- pkg/controller.v2/controller_pod.go | 9 +- pkg/controller.v2/controller_pod_test.go | 111 +++++------------ pkg/controller.v2/controller_service.go | 11 +- pkg/controller.v2/controller_service_test.go | 57 ++------- pkg/controller.v2/controller_status.go | 3 +- pkg/controller.v2/controller_status_test.go | 33 ++--- pkg/controller.v2/controller_tensorflow.go | 5 +- pkg/controller.v2/controller_test.go | 46 +++---- pkg/controller.v2/controller_tfjob.go | 4 + pkg/controller.v2/controller_tfjob_test.go | 113 +++--------------- pkg/controller.v2/pod_control_test.go | 11 +- pkg/controller.v2/service_control_test.go | 18 +-- pkg/controller.v2/service_ref_manager_test.go | 65 +++++----- .../generator.go} | 28 ++--- .../generator_test.go} | 14 +-- pkg/util/testutil/const.go | 33 +++++ pkg/util/testutil/pod.go | 93 ++++++++++++++ pkg/util/testutil/service.go | 63 ++++++++++ pkg/util/testutil/tfjob.go | 84 +++++++++++++ pkg/util/testutil/util.go | 49 ++++++++ 21 files changed, 509 insertions(+), 343 deletions(-) rename pkg/{controller.v2/controller_helper.go => generator/generator.go} (75%) rename pkg/{controller.v2/controller_helper_test.go => generator/generator_test.go} (86%) create mode 100644 pkg/util/testutil/const.go create mode 100644 pkg/util/testutil/pod.go create mode 100644 pkg/util/testutil/service.go create mode 100644 pkg/util/testutil/tfjob.go create mode 100644 pkg/util/testutil/util.go diff --git a/.travis.yml b/.travis.yml index 7566657428..4261c23e68 100644 --- a/.travis.yml +++ b/.travis.yml @@ -30,7 +30,7 @@ script: # For now though we just run all tests in pkg. # And we can not use ** because goveralls uses filepath.Match # to match ignore files and it does not support it. - - goveralls -service=travis-ci -v -package ./pkg/... -ignore "pkg/client/*/*.go,pkg/client/*/*/*.go,pkg/client/*/*/*/*.go,pkg/client/*/*/*/*/*.go,pkg/client/*/*/*/*/*/*.go,pkg/client/*/*/*/*/*/*/*.go,pkg/apis/tensorflow/*/zz_generated.*.go,pkg/apis/tensorflow/*/*_generated.go" + - goveralls -service=travis-ci -v -package ./pkg/... -ignore "pkg/client/*/*.go,pkg/client/*/*/*.go,pkg/client/*/*/*/*.go,pkg/client/*/*/*/*/*.go,pkg/client/*/*/*/*/*/*.go,pkg/client/*/*/*/*/*/*/*.go,pkg/util/testutil/*.go,pkg/apis/tensorflow/*/zz_generated.*.go,pkg/apis/tensorflow/*/*_generated.go" notifications: webhooks: https://www.travisbuddy.com/ diff --git a/pkg/controller.v2/controller_pod.go b/pkg/controller.v2/controller_pod.go index 3ed2eb47e0..0fc874ccac 100644 --- a/pkg/controller.v2/controller_pod.go +++ b/pkg/controller.v2/controller_pod.go @@ -31,6 +31,7 @@ import ( "k8s.io/kubernetes/pkg/controller" tfv1alpha2 "github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1alpha2" + "github.com/kubeflow/tf-operator/pkg/generator" train_util "github.com/kubeflow/tf-operator/pkg/util/train" ) @@ -133,17 +134,17 @@ func (tc *TFJobController) createNewPod(tfjob *tfv1alpha2.TFJob, rt, index strin } // Create OwnerReference. 
- controllerRef := genOwnerReference(tfjob) + controllerRef := generator.GenOwnerReference(tfjob) // Set type and index for the worker. - labels := genLabels(tfjobKey) + labels := generator.GenLabels(tfjobKey) labels[tfReplicaTypeLabel] = rt labels[tfReplicaIndexLabel] = index podTemplate := spec.Template.DeepCopy() // Set name for the template. - podTemplate.Name = genGeneralName(tfjob.Name, rt, index) + podTemplate.Name = generator.GenGeneralName(tfjob.Name, rt, index) if podTemplate.Labels == nil { podTemplate.Labels = make(map[string]string) @@ -228,7 +229,7 @@ func (tc *TFJobController) getPodsForTFJob(tfjob *tfv1alpha2.TFJob) ([]*v1.Pod, // Create selector. selector, err := metav1.LabelSelectorAsSelector(&metav1.LabelSelector{ - MatchLabels: genLabels(tfjobKey), + MatchLabels: generator.GenLabels(tfjobKey), }) if err != nil { diff --git a/pkg/controller.v2/controller_pod_test.go b/pkg/controller.v2/controller_pod_test.go index 5ab47e09c3..b3e48bc17d 100644 --- a/pkg/controller.v2/controller_pod_test.go +++ b/pkg/controller.v2/controller_pod_test.go @@ -16,68 +16,19 @@ package controller import ( - "fmt" "testing" "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" kubeclientset "k8s.io/client-go/kubernetes" "k8s.io/client-go/rest" - "k8s.io/client-go/tools/cache" "k8s.io/kubernetes/pkg/controller" tfv1alpha2 "github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1alpha2" tfjobclientset "github.com/kubeflow/tf-operator/pkg/client/clientset/versioned" + "github.com/kubeflow/tf-operator/pkg/generator" + "github.com/kubeflow/tf-operator/pkg/util/testutil" ) -func newBasePod(name string, tfJob *tfv1alpha2.TFJob, t *testing.T) *v1.Pod { - return &v1.Pod{ - ObjectMeta: metav1.ObjectMeta{ - Name: name, - Labels: genLabels(getKey(tfJob, t)), - Namespace: tfJob.Namespace, - OwnerReferences: []metav1.OwnerReference{*metav1.NewControllerRef(tfJob, controllerKind)}, - }, - } -} - -func newPod(tfJob *tfv1alpha2.TFJob, typ string, index int, t *testing.T) *v1.Pod { - pod := newBasePod(fmt.Sprintf("%s-%d", typ, index), tfJob, t) - pod.Labels[tfReplicaTypeLabel] = typ - pod.Labels[tfReplicaIndexLabel] = fmt.Sprintf("%d", index) - return pod -} - -// create count pods with the given phase for the given tfJob -func newPodList(count int32, status v1.PodPhase, tfJob *tfv1alpha2.TFJob, typ string, start int32, t *testing.T) []*v1.Pod { - pods := []*v1.Pod{} - for i := int32(0); i < count; i++ { - newPod := newPod(tfJob, typ, int(start+i), t) - newPod.Status = v1.PodStatus{Phase: status} - pods = append(pods, newPod) - } - return pods -} - -func setPodsStatuses(podIndexer cache.Indexer, tfJob *tfv1alpha2.TFJob, typ string, pendingPods, activePods, succeededPods, failedPods int32, t *testing.T) { - var index int32 - for _, pod := range newPodList(pendingPods, v1.PodPending, tfJob, typ, index, t) { - podIndexer.Add(pod) - } - index += pendingPods - for _, pod := range newPodList(activePods, v1.PodRunning, tfJob, typ, index, t) { - podIndexer.Add(pod) - } - index += activePods - for _, pod := range newPodList(succeededPods, v1.PodSucceeded, tfJob, typ, index, t) { - podIndexer.Add(pod) - } - index += succeededPods - for _, pod := range newPodList(failedPods, v1.PodFailed, tfJob, typ, index, t) { - podIndexer.Add(pod) - } -} - func TestAddPod(t *testing.T) { // Prepare the clientset and controller for the test. 
kubeClientSet := kubeclientset.NewForConfigOrDie(&rest.Config{ @@ -95,14 +46,14 @@ func TestAddPod(t *testing.T) { } tfJobClientSet := tfjobclientset.NewForConfigOrDie(config) ctr, _, _ := newTFJobController(config, kubeClientSet, tfJobClientSet, controller.NoResyncPeriodFunc) - ctr.tfJobInformerSynced = alwaysReady - ctr.podInformerSynced = alwaysReady - ctr.serviceInformerSynced = alwaysReady + ctr.tfJobInformerSynced = testutil.AlwaysReady + ctr.podInformerSynced = testutil.AlwaysReady + ctr.serviceInformerSynced = testutil.AlwaysReady tfJobIndexer := ctr.tfJobInformer.GetIndexer() stopCh := make(chan struct{}) run := func(<-chan struct{}) { - ctr.Run(threadCount, stopCh) + ctr.Run(testutil.ThreadCount, stopCh) } go run(stopCh) @@ -114,8 +65,8 @@ func TestAddPod(t *testing.T) { return true, nil } - tfJob := newTFJob(1, 0) - unstructured, err := convertTFJobToUnstructured(tfJob) + tfJob := testutil.NewTFJob(1, 0) + unstructured, err := generator.ConvertTFJobToUnstructured(tfJob) if err != nil { t.Errorf("Failed to convert the TFJob to Unstructured: %v", err) } @@ -123,12 +74,12 @@ func TestAddPod(t *testing.T) { if err := tfJobIndexer.Add(unstructured); err != nil { t.Errorf("Failed to add tfjob to tfJobIndexer: %v", err) } - pod := newPod(tfJob, labelWorker, 0, t) + pod := testutil.NewPod(tfJob, testutil.LabelWorker, 0, t) ctr.addPod(pod) syncChan <- "sync" - if key != getKey(tfJob, t) { - t.Errorf("Failed to enqueue the TFJob %s: expected %s, got %s", tfJob.Name, getKey(tfJob, t), key) + if key != testutil.GetKey(tfJob, t) { + t.Errorf("Failed to enqueue the TFJob %s: expected %s, got %s", tfJob.Name, testutil.GetKey(tfJob, t), key) } close(stopCh) } @@ -142,18 +93,18 @@ func TestClusterSpec(t *testing.T) { } testCase := []tc{ tc{ - tfJob: newTFJob(1, 0), + tfJob: testutil.NewTFJob(1, 0), rt: "worker", index: "0", - expectedClusterSpec: `{"cluster":{"worker":["` + testTFJobName + + expectedClusterSpec: `{"cluster":{"worker":["` + testutil.TestTFJobName + `-worker-0.default.svc.cluster.local:2222"]},"task":{"type":"worker","index":0}}`, }, tc{ - tfJob: newTFJob(1, 1), + tfJob: testutil.NewTFJob(1, 1), rt: "worker", index: "0", - expectedClusterSpec: `{"cluster":{"ps":["` + testTFJobName + - `-ps-0.default.svc.cluster.local:2222"],"worker":["` + testTFJobName + + expectedClusterSpec: `{"cluster":{"ps":["` + testutil.TestTFJobName + + `-ps-0.default.svc.cluster.local:2222"],"worker":["` + testutil.TestTFJobName + `-worker-0.default.svc.cluster.local:2222"]},"task":{"type":"worker","index":0}}`, }, } @@ -177,7 +128,7 @@ func TestRestartPolicy(t *testing.T) { } testCase := []tc{ func() tc { - tfJob := newTFJob(1, 0) + tfJob := testutil.NewTFJob(1, 0) specRestartPolicy := tfv1alpha2.RestartPolicyExitCode tfJob.Spec.TFReplicaSpecs[tfv1alpha2.TFReplicaTypeWorker].RestartPolicy = specRestartPolicy return tc{ @@ -187,7 +138,7 @@ func TestRestartPolicy(t *testing.T) { } }(), func() tc { - tfJob := newTFJob(1, 0) + tfJob := testutil.NewTFJob(1, 0) specRestartPolicy := tfv1alpha2.RestartPolicyNever tfJob.Spec.TFReplicaSpecs[tfv1alpha2.TFReplicaTypeWorker].RestartPolicy = specRestartPolicy return tc{ @@ -197,7 +148,7 @@ func TestRestartPolicy(t *testing.T) { } }(), func() tc { - tfJob := newTFJob(1, 0) + tfJob := testutil.NewTFJob(1, 0) specRestartPolicy := tfv1alpha2.RestartPolicyAlways tfJob.Spec.TFReplicaSpecs[tfv1alpha2.TFReplicaTypeWorker].RestartPolicy = specRestartPolicy return tc{ @@ -207,7 +158,7 @@ func TestRestartPolicy(t *testing.T) { } }(), func() tc { - tfJob := newTFJob(1, 0) + tfJob 
:= testutil.NewTFJob(1, 0) specRestartPolicy := tfv1alpha2.RestartPolicyOnFailure tfJob.Spec.TFReplicaSpecs[tfv1alpha2.TFReplicaTypeWorker].RestartPolicy = specRestartPolicy return tc{ @@ -217,7 +168,7 @@ func TestRestartPolicy(t *testing.T) { } }(), func() tc { - tfJob := newTFJob(1, 0) + tfJob := testutil.NewTFJob(1, 0) specRestartPolicy := tfv1alpha2.RestartPolicy("") tfJob.Spec.TFReplicaSpecs[tfv1alpha2.TFReplicaTypeWorker].RestartPolicy = specRestartPolicy return tc{ @@ -256,15 +207,15 @@ func TestExitCode(t *testing.T) { ctr, kubeInformerFactory, _ := newTFJobController(config, kubeClientSet, tfJobClientSet, controller.NoResyncPeriodFunc) fakePodControl := &controller.FakePodControl{} ctr.podControl = fakePodControl - ctr.tfJobInformerSynced = alwaysReady - ctr.podInformerSynced = alwaysReady - ctr.serviceInformerSynced = alwaysReady + ctr.tfJobInformerSynced = testutil.AlwaysReady + ctr.podInformerSynced = testutil.AlwaysReady + ctr.serviceInformerSynced = testutil.AlwaysReady tfJobIndexer := ctr.tfJobInformer.GetIndexer() podIndexer := kubeInformerFactory.Core().V1().Pods().Informer().GetIndexer() stopCh := make(chan struct{}) run := func(<-chan struct{}) { - ctr.Run(threadCount, stopCh) + ctr.Run(testutil.ThreadCount, stopCh) } go run(stopCh) @@ -272,9 +223,9 @@ func TestExitCode(t *testing.T) { return nil } - tfJob := newTFJob(1, 0) + tfJob := testutil.NewTFJob(1, 0) tfJob.Spec.TFReplicaSpecs[tfv1alpha2.TFReplicaTypeWorker].RestartPolicy = tfv1alpha2.RestartPolicyExitCode - unstructured, err := convertTFJobToUnstructured(tfJob) + unstructured, err := generator.ConvertTFJobToUnstructured(tfJob) if err != nil { t.Errorf("Failed to convert the TFJob to Unstructured: %v", err) } @@ -282,7 +233,7 @@ func TestExitCode(t *testing.T) { if err := tfJobIndexer.Add(unstructured); err != nil { t.Errorf("Failed to add tfjob to tfJobIndexer: %v", err) } - pod := newPod(tfJob, labelWorker, 0, t) + pod := testutil.NewPod(tfJob, testutil.LabelWorker, 0, t) pod.Status.Phase = v1.PodFailed pod.Spec.Containers = append(pod.Spec.Containers, v1.Container{}) pod.Status.ContainerStatuses = append(pod.Status.ContainerStatuses, v1.ContainerStatus{ @@ -294,8 +245,10 @@ func TestExitCode(t *testing.T) { }, }) - podIndexer.Add(pod) - _, err = ctr.syncTFJob(getKey(tfJob, t)) + if err := podIndexer.Add(pod); err != nil { + t.Errorf("%s: unexpected error when adding pod %v", tfJob.Name, err) + } + _, err = ctr.syncTFJob(testutil.GetKey(tfJob, t)) if err != nil { t.Errorf("%s: unexpected error when syncing jobs %v", tfJob.Name, err) } diff --git a/pkg/controller.v2/controller_service.go b/pkg/controller.v2/controller_service.go index 9ca61d6668..7f4cac456b 100644 --- a/pkg/controller.v2/controller_service.go +++ b/pkg/controller.v2/controller_service.go @@ -28,6 +28,7 @@ import ( utilruntime "k8s.io/apimachinery/pkg/util/runtime" tfv1alpha2 "github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1alpha2" + "github.com/kubeflow/tf-operator/pkg/generator" ) // reconcileServices checks and updates services for each given TFReplicaSpec. @@ -104,14 +105,14 @@ func (tc *TFJobController) createNewService(tfjob *tfv1alpha2.TFJob, rtype tfv1a } // Create OwnerReference. - controllerRef := genOwnerReference(tfjob) + controllerRef := generator.GenOwnerReference(tfjob) // Append tfReplicaTypeLabel and tfReplicaIndexLabel labels. 
- labels := genLabels(tfjobKey) + labels := generator.GenLabels(tfjobKey) labels[tfReplicaTypeLabel] = rt labels[tfReplicaIndexLabel] = index - port, err := getPortFromTFJob(tfjob, rtype) + port, err := generator.GetPortFromTFJob(tfjob, rtype) if err != nil { return err } @@ -129,7 +130,7 @@ func (tc *TFJobController) createNewService(tfjob *tfv1alpha2.TFJob, rtype tfv1a }, } - service.Name = genGeneralName(tfjob.Name, rt, index) + service.Name = generator.GenGeneralName(tfjob.Name, rt, index) service.Labels = labels err = tc.serviceControl.CreateServicesWithControllerRef(tfjob.Namespace, service, tfjob, controllerRef) @@ -160,7 +161,7 @@ func (tc *TFJobController) getServicesForTFJob(tfjob *tfv1alpha2.TFJob) ([]*v1.S // Create selector selector, err := metav1.LabelSelectorAsSelector(&metav1.LabelSelector{ - MatchLabels: genLabels(tfjobKey), + MatchLabels: generator.GenLabels(tfjobKey), }) if err != nil { diff --git a/pkg/controller.v2/controller_service_test.go b/pkg/controller.v2/controller_service_test.go index ae05750839..7cd371b6dc 100644 --- a/pkg/controller.v2/controller_service_test.go +++ b/pkg/controller.v2/controller_service_test.go @@ -16,54 +16,19 @@ package controller import ( - "fmt" "testing" "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" kubeclientset "k8s.io/client-go/kubernetes" "k8s.io/client-go/rest" - "k8s.io/client-go/tools/cache" "k8s.io/kubernetes/pkg/controller" tfv1alpha2 "github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1alpha2" tfjobclientset "github.com/kubeflow/tf-operator/pkg/client/clientset/versioned" + "github.com/kubeflow/tf-operator/pkg/generator" + "github.com/kubeflow/tf-operator/pkg/util/testutil" ) -func newBaseService(name string, tfJob *tfv1alpha2.TFJob, t *testing.T) *v1.Service { - return &v1.Service{ - ObjectMeta: metav1.ObjectMeta{ - Name: name, - Labels: genLabels(getKey(tfJob, t)), - Namespace: tfJob.Namespace, - OwnerReferences: []metav1.OwnerReference{*metav1.NewControllerRef(tfJob, controllerKind)}, - }, - } -} - -func newService(tfJob *tfv1alpha2.TFJob, typ string, index int, t *testing.T) *v1.Service { - service := newBaseService(fmt.Sprintf("%s-%d", typ, index), tfJob, t) - service.Labels[tfReplicaTypeLabel] = typ - service.Labels[tfReplicaIndexLabel] = fmt.Sprintf("%d", index) - return service -} - -// create count pods with the given phase for the given tfJob -func newServiceList(count int32, tfJob *tfv1alpha2.TFJob, typ string, t *testing.T) []*v1.Service { - services := []*v1.Service{} - for i := int32(0); i < count; i++ { - newService := newService(tfJob, typ, int(i), t) - services = append(services, newService) - } - return services -} - -func setServices(serviceIndexer cache.Indexer, tfJob *tfv1alpha2.TFJob, typ string, activeWorkerServices int32, t *testing.T) { - for _, service := range newServiceList(activeWorkerServices, tfJob, typ, t) { - serviceIndexer.Add(service) - } -} - func TestAddService(t *testing.T) { // Prepare the clientset and controller for the test. 
kubeClientSet := kubeclientset.NewForConfigOrDie(&rest.Config{ @@ -81,14 +46,14 @@ func TestAddService(t *testing.T) { } tfJobClientSet := tfjobclientset.NewForConfigOrDie(config) ctr, _, _ := newTFJobController(config, kubeClientSet, tfJobClientSet, controller.NoResyncPeriodFunc) - ctr.tfJobInformerSynced = alwaysReady - ctr.podInformerSynced = alwaysReady - ctr.serviceInformerSynced = alwaysReady + ctr.tfJobInformerSynced = testutil.AlwaysReady + ctr.podInformerSynced = testutil.AlwaysReady + ctr.serviceInformerSynced = testutil.AlwaysReady tfJobIndexer := ctr.tfJobInformer.GetIndexer() stopCh := make(chan struct{}) run := func(<-chan struct{}) { - ctr.Run(threadCount, stopCh) + ctr.Run(testutil.ThreadCount, stopCh) } go run(stopCh) @@ -100,8 +65,8 @@ func TestAddService(t *testing.T) { return true, nil } - tfJob := newTFJob(1, 0) - unstructured, err := convertTFJobToUnstructured(tfJob) + tfJob := testutil.NewTFJob(1, 0) + unstructured, err := generator.ConvertTFJobToUnstructured(tfJob) if err != nil { t.Errorf("Failed to convert the TFJob to Unstructured: %v", err) } @@ -109,12 +74,12 @@ func TestAddService(t *testing.T) { if err := tfJobIndexer.Add(unstructured); err != nil { t.Errorf("Failed to add tfjob to tfJobIndexer: %v", err) } - service := newService(tfJob, labelWorker, 0, t) + service := testutil.NewService(tfJob, testutil.LabelWorker, 0, t) ctr.addService(service) syncChan <- "sync" - if key != getKey(tfJob, t) { - t.Errorf("Failed to enqueue the TFJob %s: expected %s, got %s", tfJob.Name, getKey(tfJob, t), key) + if key != testutil.GetKey(tfJob, t) { + t.Errorf("Failed to enqueue the TFJob %s: expected %s, got %s", tfJob.Name, testutil.GetKey(tfJob, t), key) } close(stopCh) } diff --git a/pkg/controller.v2/controller_status.go b/pkg/controller.v2/controller_status.go index 14505c83aa..5c3ec573d6 100644 --- a/pkg/controller.v2/controller_status.go +++ b/pkg/controller.v2/controller_status.go @@ -22,6 +22,7 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" tfv1alpha2 "github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1alpha2" + "github.com/kubeflow/tf-operator/pkg/generator" ) const ( @@ -48,7 +49,7 @@ func updateStatus(tfjob *tfv1alpha2.TFJob, rtype tfv1alpha2.TFReplicaType, repli tfjob.Status.StartTime = &now } - if containChiefSpec(tfjob) { + if generator.ContainChiefSpec(tfjob) { if rtype == tfv1alpha2.TFReplicaTypeChief { if running > 0 { msg := fmt.Sprintf("TFJob %s is running.", tfjob.Name) diff --git a/pkg/controller.v2/controller_status_test.go b/pkg/controller.v2/controller_status_test.go index 6330083e63..24c10becd3 100644 --- a/pkg/controller.v2/controller_status_test.go +++ b/pkg/controller.v2/controller_status_test.go @@ -21,12 +21,13 @@ import ( "k8s.io/api/core/v1" tfv1alpha2 "github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1alpha2" + "github.com/kubeflow/tf-operator/pkg/util/testutil" ) func TestFailed(t *testing.T) { - tfJob := newTFJob(3, 0) + tfJob := testutil.NewTFJob(3, 0) initializeTFReplicaStatuses(tfJob, tfv1alpha2.TFReplicaTypeWorker) - pod := newBasePod("pod", tfJob, t) + pod := testutil.NewBasePod("pod", tfJob, t) pod.Status.Phase = v1.PodFailed updateTFJobReplicaStatuses(tfJob, tfv1alpha2.TFReplicaTypeWorker, pod) if tfJob.Status.TFReplicaStatuses[tfv1alpha2.TFReplicaTypeWorker].Failed != 1 { @@ -70,7 +71,7 @@ func TestStatus(t *testing.T) { testCases := []testCase{ testCase{ description: "Chief worker is succeeded", - tfJob: newTFJobWithChief(1, 0), + tfJob: testutil.NewTFJobWithChief(1, 0), expectedFailedPS: 0, expectedSucceededPS: 
0, expectedActivePS: 0, @@ -84,7 +85,7 @@ func TestStatus(t *testing.T) { }, testCase{ description: "Chief worker is running", - tfJob: newTFJobWithChief(1, 0), + tfJob: testutil.NewTFJobWithChief(1, 0), expectedFailedPS: 0, expectedSucceededPS: 0, expectedActivePS: 0, @@ -98,7 +99,7 @@ func TestStatus(t *testing.T) { }, testCase{ description: "Chief worker is failed", - tfJob: newTFJobWithChief(1, 0), + tfJob: testutil.NewTFJobWithChief(1, 0), expectedFailedPS: 0, expectedSucceededPS: 0, expectedActivePS: 0, @@ -112,7 +113,7 @@ func TestStatus(t *testing.T) { }, testCase{ description: "(No chief worker) Worker is failed", - tfJob: newTFJob(1, 0), + tfJob: testutil.NewTFJob(1, 0), expectedFailedPS: 0, expectedSucceededPS: 0, expectedActivePS: 0, @@ -126,7 +127,7 @@ func TestStatus(t *testing.T) { }, testCase{ description: "(No chief worker) Worker is succeeded", - tfJob: newTFJob(1, 0), + tfJob: testutil.NewTFJob(1, 0), expectedFailedPS: 0, expectedSucceededPS: 0, expectedActivePS: 0, @@ -140,7 +141,7 @@ func TestStatus(t *testing.T) { }, testCase{ description: "(No chief worker) Worker is running", - tfJob: newTFJob(1, 0), + tfJob: testutil.NewTFJob(1, 0), expectedFailedPS: 0, expectedSucceededPS: 0, expectedActivePS: 0, @@ -154,7 +155,7 @@ func TestStatus(t *testing.T) { }, testCase{ description: "(No chief worker) 2 workers are succeeded, 2 workers are active", - tfJob: newTFJob(4, 2), + tfJob: testutil.NewTFJob(4, 2), expectedFailedPS: 0, expectedSucceededPS: 0, expectedActivePS: 2, @@ -168,7 +169,7 @@ func TestStatus(t *testing.T) { }, testCase{ description: "(No chief worker) 2 workers are running, 2 workers are failed", - tfJob: newTFJob(4, 2), + tfJob: testutil.NewTFJob(4, 2), expectedFailedPS: 0, expectedSucceededPS: 0, expectedActivePS: 2, @@ -182,7 +183,7 @@ func TestStatus(t *testing.T) { }, testCase{ description: "(No chief worker) 2 workers are succeeded, 2 workers are failed", - tfJob: newTFJob(4, 2), + tfJob: testutil.NewTFJob(4, 2), expectedFailedPS: 0, expectedSucceededPS: 0, expectedActivePS: 2, @@ -196,7 +197,7 @@ func TestStatus(t *testing.T) { }, testCase{ description: "Chief is running, workers are failed", - tfJob: newTFJobWithChief(4, 2), + tfJob: testutil.NewTFJobWithChief(4, 2), expectedFailedPS: 0, expectedSucceededPS: 0, expectedActivePS: 2, @@ -210,7 +211,7 @@ func TestStatus(t *testing.T) { }, testCase{ description: "Chief is running, workers are succeeded", - tfJob: newTFJobWithChief(4, 2), + tfJob: testutil.NewTFJobWithChief(4, 2), expectedFailedPS: 0, expectedSucceededPS: 0, expectedActivePS: 2, @@ -224,7 +225,7 @@ func TestStatus(t *testing.T) { }, testCase{ description: "Chief is failed, workers are succeeded", - tfJob: newTFJobWithChief(4, 2), + tfJob: testutil.NewTFJobWithChief(4, 2), expectedFailedPS: 0, expectedSucceededPS: 0, expectedActivePS: 2, @@ -238,7 +239,7 @@ func TestStatus(t *testing.T) { }, testCase{ description: "Chief is succeeded, workers are failed", - tfJob: newTFJobWithChief(4, 2), + tfJob: testutil.NewTFJobWithChief(4, 2), expectedFailedPS: 0, expectedSucceededPS: 0, expectedActivePS: 2, @@ -286,7 +287,7 @@ func TestStatus(t *testing.T) { } func setStatusForTest(tfJob *tfv1alpha2.TFJob, typ tfv1alpha2.TFReplicaType, failed, succeeded, active int32, t *testing.T) { - pod := newBasePod("pod", tfJob, t) + pod := testutil.NewBasePod("pod", tfJob, t) var i int32 for i = 0; i < failed; i++ { pod.Status.Phase = v1.PodFailed diff --git a/pkg/controller.v2/controller_tensorflow.go b/pkg/controller.v2/controller_tensorflow.go index 
cf9379fa3e..1e46805b82 100644 --- a/pkg/controller.v2/controller_tensorflow.go +++ b/pkg/controller.v2/controller_tensorflow.go @@ -22,6 +22,7 @@ import ( "strings" tfv1alpha2 "github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1alpha2" + "github.com/kubeflow/tf-operator/pkg/generator" ) // TFConfig is a struct representing the distributed TensorFlow config. @@ -96,12 +97,12 @@ func genClusterSpec(tfjob *tfv1alpha2.TFJob) (ClusterSpec, error) { rt := strings.ToLower(string(rtype)) replicaNames := make([]string, 0, *spec.Replicas) - port, err := getPortFromTFJob(tfjob, rtype) + port, err := generator.GetPortFromTFJob(tfjob, rtype) if err != nil { return nil, err } for i := int32(0); i < *spec.Replicas; i++ { - host := fmt.Sprintf("%s:%d", genDNSRecord(tfjob.Name, rt, fmt.Sprintf("%d", i), tfjob.ObjectMeta.Namespace), port) + host := fmt.Sprintf("%s:%d", generator.GenDNSRecord(tfjob.Name, rt, fmt.Sprintf("%d", i), tfjob.ObjectMeta.Namespace), port) replicaNames = append(replicaNames, host) } diff --git a/pkg/controller.v2/controller_test.go b/pkg/controller.v2/controller_test.go index a9b000a9fd..b787e172f4 100644 --- a/pkg/controller.v2/controller_test.go +++ b/pkg/controller.v2/controller_test.go @@ -28,21 +28,11 @@ import ( tfv1alpha2 "github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1alpha2" tfjobclientset "github.com/kubeflow/tf-operator/pkg/client/clientset/versioned" tfjobinformers "github.com/kubeflow/tf-operator/pkg/client/informers/externalversions" -) - -const ( - testImageName = "test-image-for-kubeflow-tf-operator:latest" - testTFJobName = "test-tfjob" - labelWorker = "worker" - labelPS = "ps" - - sleepInterval = 500 * time.Millisecond - threadCount = 1 + "github.com/kubeflow/tf-operator/pkg/generator" + "github.com/kubeflow/tf-operator/pkg/util/testutil" ) var ( - alwaysReady = func() bool { return true } - tfJobRunning = tfv1alpha2.TFJobRunning tfJobSucceeded = tfv1alpha2.TFJobSucceeded ) @@ -224,9 +214,9 @@ func TestNormalPath(t *testing.T) { } tfJobClientSet := tfjobclientset.NewForConfigOrDie(config) ctr, kubeInformerFactory, _ := newTFJobController(config, kubeClientSet, tfJobClientSet, controller.NoResyncPeriodFunc) - ctr.tfJobInformerSynced = alwaysReady - ctr.podInformerSynced = alwaysReady - ctr.serviceInformerSynced = alwaysReady + ctr.tfJobInformerSynced = testutil.AlwaysReady + ctr.podInformerSynced = testutil.AlwaysReady + ctr.serviceInformerSynced = testutil.AlwaysReady tfJobIndexer := ctr.tfJobInformer.GetIndexer() var actual *tfv1alpha2.TFJob @@ -236,8 +226,8 @@ func TestNormalPath(t *testing.T) { } // Run the test logic. 
- tfJob := newTFJob(tc.worker, tc.ps) - unstructured, err := convertTFJobToUnstructured(tfJob) + tfJob := testutil.NewTFJob(tc.worker, tc.ps) + unstructured, err := generator.ConvertTFJobToUnstructured(tfJob) if err != nil { t.Errorf("Failed to convert the TFJob to Unstructured: %v", err) } @@ -247,14 +237,14 @@ func TestNormalPath(t *testing.T) { } podIndexer := kubeInformerFactory.Core().V1().Pods().Informer().GetIndexer() - setPodsStatuses(podIndexer, tfJob, labelWorker, tc.pendingWorkerPods, tc.activeWorkerPods, tc.succeededWorkerPods, tc.failedWorkerPods, t) - setPodsStatuses(podIndexer, tfJob, labelPS, tc.pendingPSPods, tc.activePSPods, tc.succeededPSPods, tc.failedPSPods, t) + testutil.SetPodsStatuses(podIndexer, tfJob, testutil.LabelWorker, tc.pendingWorkerPods, tc.activeWorkerPods, tc.succeededWorkerPods, tc.failedWorkerPods, t) + testutil.SetPodsStatuses(podIndexer, tfJob, testutil.LabelPS, tc.pendingPSPods, tc.activePSPods, tc.succeededPSPods, tc.failedPSPods, t) serviceIndexer := kubeInformerFactory.Core().V1().Services().Informer().GetIndexer() - setServices(serviceIndexer, tfJob, labelWorker, tc.activeWorkerServices, t) - setServices(serviceIndexer, tfJob, labelPS, tc.activePSServices, t) + testutil.SetServices(serviceIndexer, tfJob, testutil.LabelWorker, tc.activeWorkerServices, t) + testutil.SetServices(serviceIndexer, tfJob, testutil.LabelPS, tc.activePSServices, t) - forget, err := ctr.syncTFJob(getKey(tfJob, t)) + forget, err := ctr.syncTFJob(testutil.GetKey(tfJob, t)) // We need requeue syncJob task if podController error if tc.ControllerError != nil { if err == nil { @@ -331,7 +321,7 @@ func TestNormalPath(t *testing.T) { t.Errorf("%s: StartTime was not set", name) } // Validate conditions. - if tc.expectedCondition != nil && !checkCondition(actual, *tc.expectedCondition, tc.expectedConditionReason) { + if tc.expectedCondition != nil && !testutil.CheckCondition(actual, *tc.expectedCondition, tc.expectedConditionReason) { t.Errorf("%s: expected condition %#v, got %#v", name, *tc.expectedCondition, actual.Status.Conditions) } } @@ -354,19 +344,19 @@ func TestRun(t *testing.T) { } tfJobClientSet := tfjobclientset.NewForConfigOrDie(config) ctr, _, _ := newTFJobController(config, kubeClientSet, tfJobClientSet, controller.NoResyncPeriodFunc) - ctr.tfJobInformerSynced = alwaysReady - ctr.podInformerSynced = alwaysReady - ctr.serviceInformerSynced = alwaysReady + ctr.tfJobInformerSynced = testutil.AlwaysReady + ctr.podInformerSynced = testutil.AlwaysReady + ctr.serviceInformerSynced = testutil.AlwaysReady stopCh := make(chan struct{}) go func() { // It is a hack to let the controller stop to run without errors. // We can not just send a struct to stopCh because there are multiple // receivers in controller.Run. - time.Sleep(sleepInterval) + time.Sleep(testutil.SleepInterval) stopCh <- struct{}{} }() - err := ctr.Run(threadCount, stopCh) + err := ctr.Run(testutil.ThreadCount, stopCh) if err != nil { t.Errorf("Failed to run: %v", err) } diff --git a/pkg/controller.v2/controller_tfjob.go b/pkg/controller.v2/controller_tfjob.go index c4571770ed..d27aba327b 100644 --- a/pkg/controller.v2/controller_tfjob.go +++ b/pkg/controller.v2/controller_tfjob.go @@ -10,6 +10,10 @@ import ( tfv1alpha2 "github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1alpha2" ) +const ( + failedMarshalTFJobReason = "FailedMarshalTFJob" +) + // When a pod is added, set the defaults and enqueue the current tfjob. func (tc *TFJobController) addTFJob(obj interface{}) { // Convert from unstructured object. 
diff --git a/pkg/controller.v2/controller_tfjob_test.go b/pkg/controller.v2/controller_tfjob_test.go index ecc6b160a2..186b2f9af2 100644 --- a/pkg/controller.v2/controller_tfjob_test.go +++ b/pkg/controller.v2/controller_tfjob_test.go @@ -18,13 +18,14 @@ import ( "testing" "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" kubeclientset "k8s.io/client-go/kubernetes" "k8s.io/client-go/rest" "k8s.io/kubernetes/pkg/controller" tfv1alpha2 "github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1alpha2" tfjobclientset "github.com/kubeflow/tf-operator/pkg/client/clientset/versioned" + "github.com/kubeflow/tf-operator/pkg/generator" + "github.com/kubeflow/tf-operator/pkg/util/testutil" ) func TestAddTFJob(t *testing.T) { @@ -44,14 +45,14 @@ func TestAddTFJob(t *testing.T) { } tfJobClientSet := tfjobclientset.NewForConfigOrDie(config) ctr, _, _ := newTFJobController(config, kubeClientSet, tfJobClientSet, controller.NoResyncPeriodFunc) - ctr.tfJobInformerSynced = alwaysReady - ctr.podInformerSynced = alwaysReady - ctr.serviceInformerSynced = alwaysReady + ctr.tfJobInformerSynced = testutil.AlwaysReady + ctr.podInformerSynced = testutil.AlwaysReady + ctr.serviceInformerSynced = testutil.AlwaysReady tfJobIndexer := ctr.tfJobInformer.GetIndexer() stopCh := make(chan struct{}) run := func(<-chan struct{}) { - ctr.Run(threadCount, stopCh) + ctr.Run(testutil.ThreadCount, stopCh) } go run(stopCh) @@ -66,8 +67,8 @@ func TestAddTFJob(t *testing.T) { return nil } - tfJob := newTFJob(1, 0) - unstructured, err := convertTFJobToUnstructured(tfJob) + tfJob := testutil.NewTFJob(1, 0) + unstructured, err := generator.ConvertTFJobToUnstructured(tfJob) if err != nil { t.Errorf("Failed to convert the TFJob to Unstructured: %v", err) } @@ -77,8 +78,8 @@ func TestAddTFJob(t *testing.T) { ctr.addTFJob(unstructured) syncChan <- "sync" - if key != getKey(tfJob, t) { - t.Errorf("Failed to enqueue the TFJob %s: expected %s, got %s", tfJob.Name, getKey(tfJob, t), key) + if key != testutil.GetKey(tfJob, t) { + t.Errorf("Failed to enqueue the TFJob %s: expected %s, got %s", tfJob.Name, testutil.GetKey(tfJob, t), key) } close(stopCh) } @@ -102,14 +103,14 @@ func TestCopyLabelsAndAnnotation(t *testing.T) { ctr, _, _ := newTFJobController(config, kubeClientSet, tfJobClientSet, controller.NoResyncPeriodFunc) fakePodControl := &controller.FakePodControl{} ctr.podControl = fakePodControl - ctr.tfJobInformerSynced = alwaysReady - ctr.podInformerSynced = alwaysReady - ctr.serviceInformerSynced = alwaysReady + ctr.tfJobInformerSynced = testutil.AlwaysReady + ctr.podInformerSynced = testutil.AlwaysReady + ctr.serviceInformerSynced = testutil.AlwaysReady tfJobIndexer := ctr.tfJobInformer.GetIndexer() stopCh := make(chan struct{}) run := func(<-chan struct{}) { - ctr.Run(threadCount, stopCh) + ctr.Run(testutil.ThreadCount, stopCh) } go run(stopCh) @@ -117,7 +118,7 @@ func TestCopyLabelsAndAnnotation(t *testing.T) { return nil } - tfJob := newTFJob(1, 0) + tfJob := testutil.NewTFJob(1, 0) annotations := map[string]string{ "annotation1": "1", } @@ -126,7 +127,7 @@ func TestCopyLabelsAndAnnotation(t *testing.T) { } tfJob.Spec.TFReplicaSpecs[tfv1alpha2.TFReplicaTypeWorker].Template.Labels = labels tfJob.Spec.TFReplicaSpecs[tfv1alpha2.TFReplicaTypeWorker].Template.Annotations = annotations - unstructured, err := convertTFJobToUnstructured(tfJob) + unstructured, err := generator.ConvertTFJobToUnstructured(tfJob) if err != nil { t.Errorf("Failed to convert the TFJob to Unstructured: %v", err) } @@ -135,7 +136,7 @@ func 
TestCopyLabelsAndAnnotation(t *testing.T) { t.Errorf("Failed to add tfjob to tfJobIndexer: %v", err) } - _, err = ctr.syncTFJob(getKey(tfJob, t)) + _, err = ctr.syncTFJob(testutil.GetKey(tfJob, t)) if err != nil { t.Errorf("%s: unexpected error when syncing jobs %v", tfJob.Name, err) } @@ -162,83 +163,3 @@ func TestCopyLabelsAndAnnotation(t *testing.T) { close(stopCh) } - -func newTFJobWithChief(worker, ps int) *tfv1alpha2.TFJob { - tfJob := newTFJob(worker, ps) - tfJob.Spec.TFReplicaSpecs[tfv1alpha2.TFReplicaTypeChief] = &tfv1alpha2.TFReplicaSpec{ - Template: newTFReplicaSpecTemplate(), - } - return tfJob -} - -func newTFJob(worker, ps int) *tfv1alpha2.TFJob { - tfJob := &tfv1alpha2.TFJob{ - TypeMeta: metav1.TypeMeta{ - Kind: tfv1alpha2.Kind, - }, - ObjectMeta: metav1.ObjectMeta{ - Name: testTFJobName, - Namespace: metav1.NamespaceDefault, - }, - Spec: tfv1alpha2.TFJobSpec{ - TFReplicaSpecs: make(map[tfv1alpha2.TFReplicaType]*tfv1alpha2.TFReplicaSpec), - }, - } - - if worker > 0 { - worker := int32(worker) - workerReplicaSpec := &tfv1alpha2.TFReplicaSpec{ - Replicas: &worker, - Template: newTFReplicaSpecTemplate(), - } - tfJob.Spec.TFReplicaSpecs[tfv1alpha2.TFReplicaTypeWorker] = workerReplicaSpec - } - - if ps > 0 { - ps := int32(ps) - psReplicaSpec := &tfv1alpha2.TFReplicaSpec{ - Replicas: &ps, - Template: newTFReplicaSpecTemplate(), - } - tfJob.Spec.TFReplicaSpecs[tfv1alpha2.TFReplicaTypePS] = psReplicaSpec - } - return tfJob -} - -func getKey(tfJob *tfv1alpha2.TFJob, t *testing.T) string { - key, err := KeyFunc(tfJob) - if err != nil { - t.Errorf("Unexpected error getting key for job %v: %v", tfJob.Name, err) - return "" - } - return key -} - -func checkCondition(tfJob *tfv1alpha2.TFJob, condition tfv1alpha2.TFJobConditionType, reason string) bool { - for _, v := range tfJob.Status.Conditions { - if v.Type == condition && v.Status == v1.ConditionTrue && v.Reason == reason { - return true - } - } - return false -} - -func newTFReplicaSpecTemplate() v1.PodTemplateSpec { - return v1.PodTemplateSpec{ - Spec: v1.PodSpec{ - Containers: []v1.Container{ - v1.Container{ - Name: tfv1alpha2.DefaultContainerName, - Image: testImageName, - Args: []string{"Fake", "Fake"}, - Ports: []v1.ContainerPort{ - v1.ContainerPort{ - Name: tfv1alpha2.DefaultPortName, - ContainerPort: tfv1alpha2.DefaultPort, - }, - }, - }, - }, - }, - } -} diff --git a/pkg/controller.v2/pod_control_test.go b/pkg/controller.v2/pod_control_test.go index 88df7d7398..9cb3333cd3 100644 --- a/pkg/controller.v2/pod_control_test.go +++ b/pkg/controller.v2/pod_control_test.go @@ -32,6 +32,9 @@ import ( utiltesting "k8s.io/client-go/util/testing" "k8s.io/kubernetes/pkg/api/legacyscheme" "k8s.io/kubernetes/pkg/api/testapi" + + "github.com/kubeflow/tf-operator/pkg/generator" + "github.com/kubeflow/tf-operator/pkg/util/testutil" ) func TestCreatePods(t *testing.T) { @@ -50,12 +53,12 @@ func TestCreatePods(t *testing.T) { Recorder: &record.FakeRecorder{}, } - tfJob := newTFJob(1, 0) + tfJob := testutil.NewTFJob(1, 0) testName := "pod-name" - podTemplate := newTFReplicaSpecTemplate() + podTemplate := testutil.NewTFReplicaSpecTemplate() podTemplate.Name = testName - podTemplate.Labels = genLabels(getKey(tfJob, t)) + podTemplate.Labels = generator.GenLabels(testutil.GetKey(tfJob, t)) podTemplate.SetOwnerReferences([]metav1.OwnerReference{}) // Make sure createReplica sends a POST to the apiserver with a pod from the controllers pod template @@ -64,7 +67,7 @@ func TestCreatePods(t *testing.T) { expectedPod := v1.Pod{ ObjectMeta: 
metav1.ObjectMeta{ - Labels: genLabels(getKey(tfJob, t)), + Labels: generator.GenLabels(testutil.GetKey(tfJob, t)), Name: testName, }, Spec: podTemplate.Spec, diff --git a/pkg/controller.v2/service_control_test.go b/pkg/controller.v2/service_control_test.go index f6ed600c4b..edaab49007 100644 --- a/pkg/controller.v2/service_control_test.go +++ b/pkg/controller.v2/service_control_test.go @@ -19,6 +19,7 @@ import ( "net/http/httptest" "testing" + "github.com/stretchr/testify/assert" "k8s.io/api/core/v1" apiequality "k8s.io/apimachinery/pkg/api/equality" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -29,7 +30,8 @@ import ( utiltesting "k8s.io/client-go/util/testing" "k8s.io/kubernetes/pkg/api/testapi" - "github.com/stretchr/testify/assert" + "github.com/kubeflow/tf-operator/pkg/generator" + "github.com/kubeflow/tf-operator/pkg/util/testutil" ) func TestCreateService(t *testing.T) { @@ -53,10 +55,10 @@ func TestCreateService(t *testing.T) { Recorder: &record.FakeRecorder{}, } - tfJob := newTFJob(1, 0) + tfJob := testutil.NewTFJob(1, 0) testName := "service-name" - service := newBaseService(testName, tfJob, t) + service := testutil.NewBaseService(testName, tfJob, t) service.SetOwnerReferences([]metav1.OwnerReference{}) // Make sure createReplica sends a POST to the apiserver with a pod from the controllers pod template @@ -65,7 +67,7 @@ func TestCreateService(t *testing.T) { expectedService := v1.Service{ ObjectMeta: metav1.ObjectMeta{ - Labels: genLabels(getKey(tfJob, t)), + Labels: generator.GenLabels(testutil.GetKey(tfJob, t)), Name: testName, Namespace: ns, }, @@ -99,13 +101,13 @@ func TestCreateServicesWithControllerRef(t *testing.T) { Recorder: &record.FakeRecorder{}, } - tfJob := newTFJob(1, 0) + tfJob := testutil.NewTFJob(1, 0) testName := "service-name" - service := newBaseService(testName, tfJob, t) + service := testutil.NewBaseService(testName, tfJob, t) service.SetOwnerReferences([]metav1.OwnerReference{}) - ownerRef := genOwnerReference(tfJob) + ownerRef := generator.GenOwnerReference(tfJob) // Make sure createReplica sends a POST to the apiserver with a pod from the controllers pod template err := serviceControl.CreateServicesWithControllerRef(ns, service, tfJob, ownerRef) @@ -113,7 +115,7 @@ func TestCreateServicesWithControllerRef(t *testing.T) { expectedService := v1.Service{ ObjectMeta: metav1.ObjectMeta{ - Labels: genLabels(getKey(tfJob, t)), + Labels: generator.GenLabels(testutil.GetKey(tfJob, t)), Name: testName, Namespace: ns, OwnerReferences: []metav1.OwnerReference{*ownerRef}, diff --git a/pkg/controller.v2/service_ref_manager_test.go b/pkg/controller.v2/service_ref_manager_test.go index 0d0ffc6e70..a07894a43e 100644 --- a/pkg/controller.v2/service_ref_manager_test.go +++ b/pkg/controller.v2/service_ref_manager_test.go @@ -21,6 +21,9 @@ import ( "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" + + "github.com/kubeflow/tf-operator/pkg/generator" + "github.com/kubeflow/tf-operator/pkg/util/testutil" ) func TestClaimServices(t *testing.T) { @@ -36,15 +39,15 @@ func TestClaimServices(t *testing.T) { } var tests = []test{ func() test { - tfJob := newTFJob(1, 0) + tfJob := testutil.NewTFJob(1, 0) tfJobLabelSelector, err := metav1.LabelSelectorAsSelector(&metav1.LabelSelector{ - MatchLabels: genLabels(getKey(tfJob, t)), + MatchLabels: generator.GenLabels(testutil.GetKey(tfJob, t)), }) if err != nil { t.Errorf("Unexpected error: %v", err) } - testService := newBaseService("service2", tfJob, nil) - 
testService.Labels[labelGroupName] = "testing" + testService := testutil.NewBaseService("service2", tfJob, nil) + testService.Labels[generator.LabelGroupName] = "testing" return test{ name: "Claim services with correct label", @@ -53,14 +56,14 @@ func TestClaimServices(t *testing.T) { tfJobLabelSelector, controllerKind, func() error { return nil }), - services: []*v1.Service{newBaseService("service1", tfJob, t), testService}, - claimed: []*v1.Service{newBaseService("service1", tfJob, t)}, + services: []*v1.Service{testutil.NewBaseService("service1", tfJob, t), testService}, + claimed: []*v1.Service{testutil.NewBaseService("service1", tfJob, t)}, } }(), func() test { - controller := newTFJob(1, 0) + controller := testutil.NewTFJob(1, 0) controllerLabelSelector, err := metav1.LabelSelectorAsSelector(&metav1.LabelSelector{ - MatchLabels: genLabels(getKey(controller, t)), + MatchLabels: generator.GenLabels(testutil.GetKey(controller, t)), }) if err != nil { t.Errorf("Unexpected error: %v", err) @@ -68,9 +71,9 @@ func TestClaimServices(t *testing.T) { controller.UID = types.UID(controllerUID) now := metav1.Now() controller.DeletionTimestamp = &now - testService1 := newBaseService("service1", controller, t) + testService1 := testutil.NewBaseService("service1", controller, t) testService1.SetOwnerReferences([]metav1.OwnerReference{}) - testService2 := newBaseService("service2", controller, t) + testService2 := testutil.NewBaseService("service2", controller, t) testService2.SetOwnerReferences([]metav1.OwnerReference{}) return test{ name: "Controller marked for deletion can not claim services", @@ -84,9 +87,9 @@ func TestClaimServices(t *testing.T) { } }(), func() test { - controller := newTFJob(1, 0) + controller := testutil.NewTFJob(1, 0) controllerLabelSelector, err := metav1.LabelSelectorAsSelector(&metav1.LabelSelector{ - MatchLabels: genLabels(getKey(controller, t)), + MatchLabels: generator.GenLabels(testutil.GetKey(controller, t)), }) if err != nil { t.Errorf("Unexpected error: %v", err) @@ -94,7 +97,7 @@ func TestClaimServices(t *testing.T) { controller.UID = types.UID(controllerUID) now := metav1.Now() controller.DeletionTimestamp = &now - testService2 := newBaseService("service2", controller, t) + testService2 := testutil.NewBaseService("service2", controller, t) testService2.SetOwnerReferences([]metav1.OwnerReference{}) return test{ name: "Controller marked for deletion can not claim new services", @@ -103,19 +106,19 @@ func TestClaimServices(t *testing.T) { controllerLabelSelector, controllerKind, func() error { return nil }), - services: []*v1.Service{newBaseService("service1", controller, t), testService2}, - claimed: []*v1.Service{newBaseService("service1", controller, t)}, + services: []*v1.Service{testutil.NewBaseService("service1", controller, t), testService2}, + claimed: []*v1.Service{testutil.NewBaseService("service1", controller, t)}, } }(), func() test { - controller := newTFJob(1, 0) + controller := testutil.NewTFJob(1, 0) controllerLabelSelector, err := metav1.LabelSelectorAsSelector(&metav1.LabelSelector{ - MatchLabels: genLabels(getKey(controller, t)), + MatchLabels: generator.GenLabels(testutil.GetKey(controller, t)), }) if err != nil { t.Errorf("Unexpected error: %v", err) } - controller2 := newTFJob(1, 0) + controller2 := testutil.NewTFJob(1, 0) controller.UID = types.UID(controllerUID) controller2.UID = types.UID("AAAAA") return test{ @@ -125,21 +128,21 @@ func TestClaimServices(t *testing.T) { controllerLabelSelector, controllerKind, func() error { return nil }), - 
services: []*v1.Service{newBaseService("service1", controller, t), newBaseService("service2", controller2, t)}, - claimed: []*v1.Service{newBaseService("service1", controller, t)}, + services: []*v1.Service{testutil.NewBaseService("service1", controller, t), testutil.NewBaseService("service2", controller2, t)}, + claimed: []*v1.Service{testutil.NewBaseService("service1", controller, t)}, } }(), func() test { - controller := newTFJob(1, 0) + controller := testutil.NewTFJob(1, 0) controllerLabelSelector, err := metav1.LabelSelectorAsSelector(&metav1.LabelSelector{ - MatchLabels: genLabels(getKey(controller, t)), + MatchLabels: generator.GenLabels(testutil.GetKey(controller, t)), }) if err != nil { t.Errorf("Unexpected error: %v", err) } controller.UID = types.UID(controllerUID) - testService2 := newBaseService("service2", controller, t) - testService2.Labels[labelGroupName] = "testing" + testService2 := testutil.NewBaseService("service2", controller, t) + testService2.Labels[generator.LabelGroupName] = "testing" return test{ name: "Controller releases claimed services when selector doesn't match", manager: NewServiceControllerRefManager(&FakeServiceControl{}, @@ -147,22 +150,22 @@ func TestClaimServices(t *testing.T) { controllerLabelSelector, controllerKind, func() error { return nil }), - services: []*v1.Service{newBaseService("service1", controller, t), testService2}, - claimed: []*v1.Service{newBaseService("service1", controller, t)}, + services: []*v1.Service{testutil.NewBaseService("service1", controller, t), testService2}, + claimed: []*v1.Service{testutil.NewBaseService("service1", controller, t)}, } }(), func() test { - controller := newTFJob(1, 0) + controller := testutil.NewTFJob(1, 0) controllerLabelSelector, err := metav1.LabelSelectorAsSelector(&metav1.LabelSelector{ - MatchLabels: genLabels(getKey(controller, t)), + MatchLabels: generator.GenLabels(testutil.GetKey(controller, t)), }) if err != nil { t.Errorf("Unexpected error: %v", err) } controller.UID = types.UID(controllerUID) - testService1 := newBaseService("service1", controller, t) - testService2 := newBaseService("service2", controller, t) - testService2.Labels[labelGroupName] = "testing" + testService1 := testutil.NewBaseService("service1", controller, t) + testService2 := testutil.NewBaseService("service2", controller, t) + testService2.Labels[generator.LabelGroupName] = "testing" now := metav1.Now() testService1.DeletionTimestamp = &now testService2.DeletionTimestamp = &now diff --git a/pkg/controller.v2/controller_helper.go b/pkg/generator/generator.go similarity index 75% rename from pkg/controller.v2/controller_helper.go rename to pkg/generator/generator.go index c828a897c1..6403625860 100644 --- a/pkg/controller.v2/controller_helper.go +++ b/pkg/generator/generator.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-package controller +package generator import ( "encoding/json" @@ -26,17 +26,15 @@ import ( ) const ( - labelGroupName = "group_name" + LabelGroupName = "group_name" labelTFJobKey = "tf_job_key" - - failedMarshalTFJobReason = "FailedMarshalTFJob" ) var ( errPortNotFound = fmt.Errorf("Failed to found the port") ) -func genOwnerReference(tfjob *tfv1alpha2.TFJob) *metav1.OwnerReference { +func GenOwnerReference(tfjob *tfv1alpha2.TFJob) *metav1.OwnerReference { boolPtr := func(b bool) *bool { return &b } controllerRef := &metav1.OwnerReference{ APIVersion: tfv1alpha2.SchemeGroupVersion.String(), @@ -50,24 +48,24 @@ func genOwnerReference(tfjob *tfv1alpha2.TFJob) *metav1.OwnerReference { return controllerRef } -func genLabels(tfjobKey string) map[string]string { +func GenLabels(tfjobKey string) map[string]string { return map[string]string{ - labelGroupName: tfv1alpha2.GroupName, + LabelGroupName: tfv1alpha2.GroupName, labelTFJobKey: strings.Replace(tfjobKey, "/", "-", -1), } } -func genGeneralName(tfJobName, rtype, index string) string { +func GenGeneralName(tfJobName, rtype, index string) string { n := tfJobName + "-" + rtype + "-" + index return strings.Replace(n, "/", "-", -1) } -func genDNSRecord(tfJobName, rtype, index, namespace string) string { - return fmt.Sprintf("%s.%s.svc.cluster.local", genGeneralName(tfJobName, rtype, index), namespace) +func GenDNSRecord(tfJobName, rtype, index, namespace string) string { + return fmt.Sprintf("%s.%s.svc.cluster.local", GenGeneralName(tfJobName, rtype, index), namespace) } -// convertTFJobToUnstructured uses JSON to convert TFJob to Unstructured. -func convertTFJobToUnstructured(tfJob *tfv1alpha2.TFJob) (*unstructured.Unstructured, error) { +// ConvertTFJobToUnstructured uses JSON to convert TFJob to Unstructured. +func ConvertTFJobToUnstructured(tfJob *tfv1alpha2.TFJob) (*unstructured.Unstructured, error) { var unstructured unstructured.Unstructured b, err := json.Marshal(tfJob) if err != nil { @@ -80,8 +78,8 @@ func convertTFJobToUnstructured(tfJob *tfv1alpha2.TFJob) (*unstructured.Unstruct return &unstructured, nil } -// getPortFromTFJob gets the port of tensorflow container. -func getPortFromTFJob(tfJob *tfv1alpha2.TFJob, rtype tfv1alpha2.TFReplicaType) (int32, error) { +// GetPortFromTFJob gets the port of tensorflow container. +func GetPortFromTFJob(tfJob *tfv1alpha2.TFJob, rtype tfv1alpha2.TFReplicaType) (int32, error) { containers := tfJob.Spec.TFReplicaSpecs[rtype].Template.Spec.Containers for _, container := range containers { if container.Name == tfv1alpha2.DefaultContainerName { @@ -96,7 +94,7 @@ func getPortFromTFJob(tfJob *tfv1alpha2.TFJob, rtype tfv1alpha2.TFReplicaType) ( return -1, errPortNotFound } -func containChiefSpec(tfJob *tfv1alpha2.TFJob) bool { +func ContainChiefSpec(tfJob *tfv1alpha2.TFJob) bool { if _, ok := tfJob.Spec.TFReplicaSpecs[tfv1alpha2.TFReplicaTypeChief]; ok { return true } diff --git a/pkg/controller.v2/controller_helper_test.go b/pkg/generator/generator_test.go similarity index 86% rename from pkg/controller.v2/controller_helper_test.go rename to pkg/generator/generator_test.go index a49d5575b0..e4684db85e 100644 --- a/pkg/controller.v2/controller_helper_test.go +++ b/pkg/generator/generator_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-package controller +package generator import ( "fmt" @@ -34,7 +34,7 @@ func TestGenOwnerReference(t *testing.T) { }, } - ref := genOwnerReference(tfJob) + ref := GenOwnerReference(tfJob) if ref.UID != testUID { t.Errorf("Expected UID %s, got %s", testUID, ref.UID) } @@ -50,13 +50,13 @@ func TestGenLabels(t *testing.T) { testKey := "test/key" expctedKey := "test-key" - labels := genLabels(testKey) + labels := GenLabels(testKey) if labels[labelTFJobKey] != expctedKey { t.Errorf("Expected %s %s, got %s", labelTFJobKey, expctedKey, labels[labelTFJobKey]) } - if labels[labelGroupName] != tfv1alpha2.GroupName { - t.Errorf("Expected %s %s, got %s", labelGroupName, tfv1alpha2.GroupName, labels[labelGroupName]) + if labels[LabelGroupName] != tfv1alpha2.GroupName { + t.Errorf("Expected %s %s, got %s", LabelGroupName, tfv1alpha2.GroupName, labels[LabelGroupName]) } } @@ -66,7 +66,7 @@ func TestGenGeneralName(t *testing.T) { testKey := "1/2/3/4/5" expectedName := fmt.Sprintf("1-2-3-4-5-%s-%s", testRType, testIndex) - name := genGeneralName(testKey, testRType, testIndex) + name := GenGeneralName(testKey, testRType, testIndex) if name != expectedName { t.Errorf("Expected name %s, got %s", expectedName, name) } @@ -85,7 +85,7 @@ func TestConvertTFJobToUnstructured(t *testing.T) { }, } - _, err := convertTFJobToUnstructured(tfJob) + _, err := ConvertTFJobToUnstructured(tfJob) if err != nil { t.Errorf("Expected error to be nil while got %v", err) } diff --git a/pkg/util/testutil/const.go b/pkg/util/testutil/const.go new file mode 100644 index 0000000000..b485f28c77 --- /dev/null +++ b/pkg/util/testutil/const.go @@ -0,0 +1,33 @@ +// Copyright 2018 The Kubeflow Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testutil + +import ( + "time" +) + +const ( + TestImageName = "test-image-for-kubeflow-tf-operator:latest" + TestTFJobName = "test-tfjob" + LabelWorker = "worker" + LabelPS = "ps" + + SleepInterval = 500 * time.Millisecond + ThreadCount = 1 +) + +var ( + AlwaysReady = func() bool { return true } +) diff --git a/pkg/util/testutil/pod.go b/pkg/util/testutil/pod.go new file mode 100644 index 0000000000..5d9805abc6 --- /dev/null +++ b/pkg/util/testutil/pod.go @@ -0,0 +1,93 @@ +// Copyright 2018 The Kubeflow Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package testutil + +import ( + "fmt" + "testing" + + "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/tools/cache" + + tfv1alpha2 "github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1alpha2" + "github.com/kubeflow/tf-operator/pkg/generator" +) + +const ( + // labels for pods and servers. + tfReplicaTypeLabel = "tf-replica-type" + tfReplicaIndexLabel = "tf-replica-index" +) + +var ( + controllerKind = tfv1alpha2.SchemeGroupVersionKind +) + +func NewBasePod(name string, tfJob *tfv1alpha2.TFJob, t *testing.T) *v1.Pod { + return &v1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Labels: generator.GenLabels(GetKey(tfJob, t)), + Namespace: tfJob.Namespace, + OwnerReferences: []metav1.OwnerReference{*metav1.NewControllerRef(tfJob, controllerKind)}, + }, + } +} + +func NewPod(tfJob *tfv1alpha2.TFJob, typ string, index int, t *testing.T) *v1.Pod { + pod := NewBasePod(fmt.Sprintf("%s-%d", typ, index), tfJob, t) + pod.Labels[tfReplicaTypeLabel] = typ + pod.Labels[tfReplicaIndexLabel] = fmt.Sprintf("%d", index) + return pod +} + +// create count pods with the given phase for the given tfJob +func NewPodList(count int32, status v1.PodPhase, tfJob *tfv1alpha2.TFJob, typ string, start int32, t *testing.T) []*v1.Pod { + pods := []*v1.Pod{} + for i := int32(0); i < count; i++ { + newPod := NewPod(tfJob, typ, int(start+i), t) + newPod.Status = v1.PodStatus{Phase: status} + pods = append(pods, newPod) + } + return pods +} + +func SetPodsStatuses(podIndexer cache.Indexer, tfJob *tfv1alpha2.TFJob, typ string, pendingPods, activePods, succeededPods, failedPods int32, t *testing.T) { + var index int32 + for _, pod := range NewPodList(pendingPods, v1.PodPending, tfJob, typ, index, t) { + if err := podIndexer.Add(pod); err != nil { + t.Errorf("%s: unexpected error when adding pod %v", tfJob.Name, err) + } + } + index += pendingPods + for _, pod := range NewPodList(activePods, v1.PodRunning, tfJob, typ, index, t) { + if err := podIndexer.Add(pod); err != nil { + t.Errorf("%s: unexpected error when adding pod %v", tfJob.Name, err) + } + } + index += activePods + for _, pod := range NewPodList(succeededPods, v1.PodSucceeded, tfJob, typ, index, t) { + if err := podIndexer.Add(pod); err != nil { + t.Errorf("%s: unexpected error when adding pod %v", tfJob.Name, err) + } + } + index += succeededPods + for _, pod := range NewPodList(failedPods, v1.PodFailed, tfJob, typ, index, t) { + if err := podIndexer.Add(pod); err != nil { + t.Errorf("%s: unexpected error when adding pod %v", tfJob.Name, err) + } + } +} diff --git a/pkg/util/testutil/service.go b/pkg/util/testutil/service.go new file mode 100644 index 0000000000..6accd5ff1d --- /dev/null +++ b/pkg/util/testutil/service.go @@ -0,0 +1,63 @@ +// Copyright 2018 The Kubeflow Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package testutil + +import ( + "fmt" + "testing" + + "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/tools/cache" + + tfv1alpha2 "github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1alpha2" + "github.com/kubeflow/tf-operator/pkg/generator" +) + +func NewBaseService(name string, tfJob *tfv1alpha2.TFJob, t *testing.T) *v1.Service { + return &v1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Labels: generator.GenLabels(GetKey(tfJob, t)), + Namespace: tfJob.Namespace, + OwnerReferences: []metav1.OwnerReference{*metav1.NewControllerRef(tfJob, controllerKind)}, + }, + } +} + +func NewService(tfJob *tfv1alpha2.TFJob, typ string, index int, t *testing.T) *v1.Service { + service := NewBaseService(fmt.Sprintf("%s-%d", typ, index), tfJob, t) + service.Labels[tfReplicaTypeLabel] = typ + service.Labels[tfReplicaIndexLabel] = fmt.Sprintf("%d", index) + return service +} + +// NewServiceList creates count pods with the given phase for the given tfJob +func NewServiceList(count int32, tfJob *tfv1alpha2.TFJob, typ string, t *testing.T) []*v1.Service { + services := []*v1.Service{} + for i := int32(0); i < count; i++ { + newService := NewService(tfJob, typ, int(i), t) + services = append(services, newService) + } + return services +} + +func SetServices(serviceIndexer cache.Indexer, tfJob *tfv1alpha2.TFJob, typ string, activeWorkerServices int32, t *testing.T) { + for _, service := range NewServiceList(activeWorkerServices, tfJob, typ, t) { + if err := serviceIndexer.Add(service); err != nil { + t.Errorf("unexpected error when adding service %v", err) + } + } +} diff --git a/pkg/util/testutil/tfjob.go b/pkg/util/testutil/tfjob.go new file mode 100644 index 0000000000..9f068f618b --- /dev/null +++ b/pkg/util/testutil/tfjob.go @@ -0,0 +1,84 @@ +// Copyright 2018 The Kubeflow Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package testutil + +import ( + "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + tfv1alpha2 "github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1alpha2" +) + +func NewTFJobWithChief(worker, ps int) *tfv1alpha2.TFJob { + tfJob := NewTFJob(worker, ps) + tfJob.Spec.TFReplicaSpecs[tfv1alpha2.TFReplicaTypeChief] = &tfv1alpha2.TFReplicaSpec{ + Template: NewTFReplicaSpecTemplate(), + } + return tfJob +} + +func NewTFJob(worker, ps int) *tfv1alpha2.TFJob { + tfJob := &tfv1alpha2.TFJob{ + TypeMeta: metav1.TypeMeta{ + Kind: tfv1alpha2.Kind, + }, + ObjectMeta: metav1.ObjectMeta{ + Name: TestTFJobName, + Namespace: metav1.NamespaceDefault, + }, + Spec: tfv1alpha2.TFJobSpec{ + TFReplicaSpecs: make(map[tfv1alpha2.TFReplicaType]*tfv1alpha2.TFReplicaSpec), + }, + } + + if worker > 0 { + worker := int32(worker) + workerReplicaSpec := &tfv1alpha2.TFReplicaSpec{ + Replicas: &worker, + Template: NewTFReplicaSpecTemplate(), + } + tfJob.Spec.TFReplicaSpecs[tfv1alpha2.TFReplicaTypeWorker] = workerReplicaSpec + } + + if ps > 0 { + ps := int32(ps) + psReplicaSpec := &tfv1alpha2.TFReplicaSpec{ + Replicas: &ps, + Template: NewTFReplicaSpecTemplate(), + } + tfJob.Spec.TFReplicaSpecs[tfv1alpha2.TFReplicaTypePS] = psReplicaSpec + } + return tfJob +} + +func NewTFReplicaSpecTemplate() v1.PodTemplateSpec { + return v1.PodTemplateSpec{ + Spec: v1.PodSpec{ + Containers: []v1.Container{ + v1.Container{ + Name: tfv1alpha2.DefaultContainerName, + Image: TestImageName, + Args: []string{"Fake", "Fake"}, + Ports: []v1.ContainerPort{ + v1.ContainerPort{ + Name: tfv1alpha2.DefaultPortName, + ContainerPort: tfv1alpha2.DefaultPort, + }, + }, + }, + }, + }, + } +} diff --git a/pkg/util/testutil/util.go b/pkg/util/testutil/util.go new file mode 100644 index 0000000000..cdfc602849 --- /dev/null +++ b/pkg/util/testutil/util.go @@ -0,0 +1,49 @@ +// Copyright 2018 The Kubeflow Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testutil + +import ( + "testing" + + "k8s.io/api/core/v1" + "k8s.io/client-go/tools/cache" + + tfv1alpha2 "github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1alpha2" +) + +var ( + // KeyFunc is the short name to DeletionHandlingMetaNamespaceKeyFunc. + // IndexerInformer uses a delta queue, therefore for deletes we have to use this + // key function but it should be just fine for non delete events. + KeyFunc = cache.DeletionHandlingMetaNamespaceKeyFunc +) + +func GetKey(tfJob *tfv1alpha2.TFJob, t *testing.T) string { + key, err := KeyFunc(tfJob) + if err != nil { + t.Errorf("Unexpected error getting key for job %v: %v", tfJob.Name, err) + return "" + } + return key +} + +func CheckCondition(tfJob *tfv1alpha2.TFJob, condition tfv1alpha2.TFJobConditionType, reason string) bool { + for _, v := range tfJob.Status.Conditions { + if v.Type == condition && v.Status == v1.ConditionTrue && v.Reason == reason { + return true + } + } + return false +}
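
As a quick illustration (not part of the patch; the test name below is hypothetical), the relocated helpers can be consumed from any package by importing the pkg/generator and pkg/util/testutil packages introduced above. A minimal sketch, assuming only the exported helpers shown in this patch:

package controller_test

import (
	"testing"

	"github.com/kubeflow/tf-operator/pkg/generator"
	"github.com/kubeflow/tf-operator/pkg/util/testutil"
)

func TestUseRelocatedHelpers(t *testing.T) {
	// Build a TFJob fixture with one worker and one PS replica.
	tfJob := testutil.NewTFJob(1, 1)

	// Label and key generation now live in the generator package.
	labels := generator.GenLabels(testutil.GetKey(tfJob, t))
	if labels[generator.LabelGroupName] == "" {
		t.Errorf("expected label %s to be set", generator.LabelGroupName)
	}

	// Pod fixtures are created through the exported testutil helpers.
	pod := testutil.NewPod(tfJob, testutil.LabelWorker, 0, t)
	if pod.Namespace != tfJob.Namespace {
		t.Errorf("expected pod namespace %s, got %s", tfJob.Namespace, pod.Namespace)
	}
}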