From 6c77fca1a300a42f575967bae12452ddfec150c0 Mon Sep 17 00:00:00 2001 From: Ce Gao Date: Wed, 13 Jun 2018 11:59:32 +0800 Subject: [PATCH] pod: Add test for restart policy Signed-off-by: Ce Gao --- pkg/controller.v2/controller_pod.go | 20 ++++--- pkg/controller.v2/controller_pod_test.go | 68 ++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 8 deletions(-) diff --git a/pkg/controller.v2/controller_pod.go b/pkg/controller.v2/controller_pod.go index 0656bce4ef..a8b3e76292 100644 --- a/pkg/controller.v2/controller_pod.go +++ b/pkg/controller.v2/controller_pod.go @@ -164,14 +164,7 @@ func (tc *TFJobController) createNewPod(tfjob *tfv1alpha2.TFJob, rt, index strin loggerForReplica(tfjob, rt).Warning(errMsg) tc.recorder.Event(tfjob, v1.EventTypeWarning, podTemplateRestartPolicyReason, errMsg) } - if spec.RestartPolicy == tfv1alpha2.RestartPolicyExitCode { - podTemplate.Spec.RestartPolicy = v1.RestartPolicyNever - } else if spec.RestartPolicy == tfv1alpha2.RestartPolicy("") { - // Set default to Never. - podTemplate.Spec.RestartPolicy = v1.RestartPolicyNever - } else { - podTemplate.Spec.RestartPolicy = v1.RestartPolicy(spec.RestartPolicy) - } + setRestartPolicy(podTemplate, spec) err = tc.podControl.CreatePodsWithControllerRef(tfjob.Namespace, podTemplate, tfjob, controllerRef) if err != nil && errors.IsTimeout(err) { @@ -212,6 +205,17 @@ func setClusterSpec(podTemplateSpec *v1.PodTemplateSpec, tfjob *tfv1alpha2.TFJob return nil } +func setRestartPolicy(podTemplateSpec *v1.PodTemplateSpec, spec *tfv1alpha2.TFReplicaSpec) { + if spec.RestartPolicy == tfv1alpha2.RestartPolicyExitCode { + podTemplateSpec.Spec.RestartPolicy = v1.RestartPolicyNever + } else if spec.RestartPolicy == tfv1alpha2.RestartPolicy("") { + // Set default to Never. + podTemplateSpec.Spec.RestartPolicy = v1.RestartPolicyNever + } else { + podTemplateSpec.Spec.RestartPolicy = v1.RestartPolicy(spec.RestartPolicy) + } +} + // getPodsForTFJob returns the set of pods that this tfjob should manage. // It also reconciles ControllerRef by adopting/orphaning. // Note that the returned Pods are pointers into the cache. diff --git a/pkg/controller.v2/controller_pod_test.go b/pkg/controller.v2/controller_pod_test.go index 8955d00c24..7affcb513b 100644 --- a/pkg/controller.v2/controller_pod_test.go +++ b/pkg/controller.v2/controller_pod_test.go @@ -168,3 +168,71 @@ func TestClusterSpec(t *testing.T) { } } } + +func TestRestartPolicy(t *testing.T) { + type tc struct { + tfJob *tfv1alpha2.TFJob + expectedRestartPolicy v1.RestartPolicy + expectedType tfv1alpha2.TFReplicaType + } + testCase := []tc{ + func() tc { + tfJob := newTFJob(1, 0) + specRestartPolicy := tfv1alpha2.RestartPolicyExitCode + tfJob.Spec.TFReplicaSpecs[tfv1alpha2.TFReplicaTypeWorker].RestartPolicy = specRestartPolicy + return tc{ + tfJob: tfJob, + expectedRestartPolicy: v1.RestartPolicyNever, + expectedType: tfv1alpha2.TFReplicaTypeWorker, + } + }(), + func() tc { + tfJob := newTFJob(1, 0) + specRestartPolicy := tfv1alpha2.RestartPolicyNever + tfJob.Spec.TFReplicaSpecs[tfv1alpha2.TFReplicaTypeWorker].RestartPolicy = specRestartPolicy + return tc{ + tfJob: tfJob, + expectedRestartPolicy: v1.RestartPolicyNever, + expectedType: tfv1alpha2.TFReplicaTypeWorker, + } + }(), + func() tc { + tfJob := newTFJob(1, 0) + specRestartPolicy := tfv1alpha2.RestartPolicyAlways + tfJob.Spec.TFReplicaSpecs[tfv1alpha2.TFReplicaTypeWorker].RestartPolicy = specRestartPolicy + return tc{ + tfJob: tfJob, + expectedRestartPolicy: v1.RestartPolicyAlways, + expectedType: tfv1alpha2.TFReplicaTypeWorker, + } + }(), + func() tc { + tfJob := newTFJob(1, 0) + specRestartPolicy := tfv1alpha2.RestartPolicyOnFailure + tfJob.Spec.TFReplicaSpecs[tfv1alpha2.TFReplicaTypeWorker].RestartPolicy = specRestartPolicy + return tc{ + tfJob: tfJob, + expectedRestartPolicy: v1.RestartPolicyOnFailure, + expectedType: tfv1alpha2.TFReplicaTypeWorker, + } + }(), + func() tc { + tfJob := newTFJob(1, 0) + specRestartPolicy := tfv1alpha2.RestartPolicy("") + tfJob.Spec.TFReplicaSpecs[tfv1alpha2.TFReplicaTypeWorker].RestartPolicy = specRestartPolicy + return tc{ + tfJob: tfJob, + expectedRestartPolicy: v1.RestartPolicyNever, + expectedType: tfv1alpha2.TFReplicaTypeWorker, + } + }(), + } + for _, c := range testCase { + spec := c.tfJob.Spec.TFReplicaSpecs[c.expectedType] + podTemplate := spec.Template + setRestartPolicy(&podTemplate, spec) + if podTemplate.Spec.RestartPolicy != c.expectedRestartPolicy { + t.Errorf("Expected %s, got %s", c.expectedRestartPolicy, podTemplate.Spec.RestartPolicy) + } + } +}