diff --git a/pkg/apis/tensorflow/v1/common.go b/pkg/apis/tensorflow/v1/common.go new file mode 100644 index 0000000000..4e508ebe00 --- /dev/null +++ b/pkg/apis/tensorflow/v1/common.go @@ -0,0 +1,23 @@ +// Copyright 2020 The Kubeflow Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package v1 + +// SuccessPolicy is the success policy. +type SuccessPolicy string + +const ( + SuccessPolicyDefault SuccessPolicy = "" + SuccessPolicyAllWorkers SuccessPolicy = "AllWorkers" +) diff --git a/pkg/apis/tensorflow/v1/defaults.go b/pkg/apis/tensorflow/v1/defaults.go index 33dad4bff9..be63704a65 100644 --- a/pkg/apis/tensorflow/v1/defaults.go +++ b/pkg/apis/tensorflow/v1/defaults.go @@ -95,6 +95,11 @@ func SetDefaults_TFJob(tfjob *TFJob) { running := common.CleanPodPolicyRunning tfjob.Spec.CleanPodPolicy = &running } + // Set default success policy to "". + if tfjob.Spec.SuccessPolicy == nil { + defaultPolicy := SuccessPolicyDefault + tfjob.Spec.SuccessPolicy = &defaultPolicy + } // Update the key of TFReplicaSpecs to camel case. setTypeNamesToCamelCase(tfjob) diff --git a/pkg/apis/tensorflow/v1/defaults_test.go b/pkg/apis/tensorflow/v1/defaults_test.go index d1a3b51323..39a8c2c51d 100644 --- a/pkg/apis/tensorflow/v1/defaults_test.go +++ b/pkg/apis/tensorflow/v1/defaults_test.go @@ -51,8 +51,11 @@ func expectedTFJob(cleanPodPolicy common.CleanPodPolicy, restartPolicy common.Re ) } + defaultSuccessPolicy := SuccessPolicyDefault + return &TFJob{ Spec: TFJobSpec{ + SuccessPolicy: &defaultSuccessPolicy, CleanPodPolicy: &cleanPodPolicy, TFReplicaSpecs: map[TFReplicaType]*common.ReplicaSpec{ TFReplicaTypeWorker: &common.ReplicaSpec{ diff --git a/pkg/apis/tensorflow/v1/openapi_generated.go b/pkg/apis/tensorflow/v1/openapi_generated.go index c70f7fcfb0..bc0a939596 100644 --- a/pkg/apis/tensorflow/v1/openapi_generated.go +++ b/pkg/apis/tensorflow/v1/openapi_generated.go @@ -410,6 +410,13 @@ func schema_pkg_apis_tensorflow_v1_TFJobSpec(ref common.ReferenceCallback) commo Format: "int32", }, }, + "successPolicy": { + SchemaProps: spec.SchemaProps{ + Description: "SuccessPolicy defines the policy to mark the TFJob as succeeded. Default to \"\", using the default rules.", + Type: []string{"string"}, + Format: "", + }, + }, "cleanPodPolicy": { SchemaProps: spec.SchemaProps{ Description: "Defines the policy for cleaning up pods after the TFJob completes. Defaults to Running.", diff --git a/pkg/apis/tensorflow/v1/swagger.json b/pkg/apis/tensorflow/v1/swagger.json index 225d4a7216..3b3dfb3a84 100644 --- a/pkg/apis/tensorflow/v1/swagger.json +++ b/pkg/apis/tensorflow/v1/swagger.json @@ -181,6 +181,10 @@ "type": "integer", "format": "int32" }, + "successPolicy": { + "description": "SuccessPolicy defines the policy to mark the TFJob as succeeded. Default to \"\", using the default rules.", + "type": "string" + }, "cleanPodPolicy": { "description": "Defines the policy for cleaning up pods after the TFJob completes. Defaults to Running.", "type": "string" diff --git a/pkg/apis/tensorflow/v1/types.go b/pkg/apis/tensorflow/v1/types.go index e4217f7281..535813e44b 100644 --- a/pkg/apis/tensorflow/v1/types.go +++ b/pkg/apis/tensorflow/v1/types.go @@ -55,6 +55,11 @@ type TFJobSpec struct { // +optional BackoffLimit *int32 `json:"backoffLimit,omitempty"` + // SuccessPolicy defines the policy to mark the TFJob as succeeded. + // Default to "", using the default rules. + // +optional + SuccessPolicy *SuccessPolicy `json:"successPolicy,omitempty"` + // Defines the policy for cleaning up pods after the TFJob completes. // Defaults to Running. // +optional diff --git a/pkg/apis/tensorflow/v1/zz_generated.deepcopy.go b/pkg/apis/tensorflow/v1/zz_generated.deepcopy.go index 5becb2e654..5095de6bb6 100644 --- a/pkg/apis/tensorflow/v1/zz_generated.deepcopy.go +++ b/pkg/apis/tensorflow/v1/zz_generated.deepcopy.go @@ -97,6 +97,11 @@ func (in *TFJobSpec) DeepCopyInto(out *TFJobSpec) { *out = new(int32) **out = **in } + if in.SuccessPolicy != nil { + in, out := &in.SuccessPolicy, &out.SuccessPolicy + *out = new(SuccessPolicy) + **out = **in + } if in.CleanPodPolicy != nil { in, out := &in.CleanPodPolicy, &out.CleanPodPolicy *out = new(commonv1.CleanPodPolicy) diff --git a/pkg/common/util/v1/testutil/tfjob.go b/pkg/common/util/v1/testutil/tfjob.go index 958bb6cd77..ea27c3dc54 100644 --- a/pkg/common/util/v1/testutil/tfjob.go +++ b/pkg/common/util/v1/testutil/tfjob.go @@ -102,6 +102,12 @@ func NewTFJobWithEvaluator(worker, ps, evaluator int) *tfv1.TFJob { return tfJob } +func NewTFJobWithSuccessPolicy(worker, ps int, successPolicy tfv1.SuccessPolicy) *tfv1.TFJob { + tfJob := NewTFJob(worker, ps) + tfJob.Spec.SuccessPolicy = &successPolicy + return tfJob +} + func NewTFJob(worker, ps int) *tfv1.TFJob { tfJob := &tfv1.TFJob{ TypeMeta: metav1.TypeMeta{ diff --git a/pkg/controller.v1/tensorflow/status.go b/pkg/controller.v1/tensorflow/status.go index 114ca8d63f..89b75166ba 100644 --- a/pkg/controller.v1/tensorflow/status.go +++ b/pkg/controller.v1/tensorflow/status.go @@ -113,8 +113,10 @@ func (tc *TFController) updateStatusSingle(tfjob *tfv1.TFJob, rtype tfv1.TFRepli } } else { if rtype == tfv1.TFReplicaTypeWorker { - // All workers are succeeded or worker 0 completed, leave a succeeded condition. - if expected == 0 || worker0Completed { + // Leave a succeeded condition for the following two cases: + // 1. If default success policy is used and worker 0 has completed. + // 2. If `SuccessPolicyAllWorkers` success policy is used and all workers are succeeded. + if expected == 0 || (worker0Completed && *tfjob.Spec.SuccessPolicy != tfv1.SuccessPolicyAllWorkers) { msg := fmt.Sprintf("TFJob %s successfully completed.", tfjob.Name) tc.Recorder.Event(tfjob, v1.EventTypeNormal, tfJobSucceededReason, msg) if tfjob.Status.CompletionTime == nil { diff --git a/pkg/controller.v1/tensorflow/status_test.go b/pkg/controller.v1/tensorflow/status_test.go index 8c81b82dc8..7edc6f801f 100644 --- a/pkg/controller.v1/tensorflow/status_test.go +++ b/pkg/controller.v1/tensorflow/status_test.go @@ -270,6 +270,54 @@ func TestStatus(t *testing.T) { worker0Completed: true, expectedType: common.JobSucceeded, }, + testCase{ + description: "(No chief worker, successPolicy: AllWorkers) worker-0 are succeeded, 3 workers are active", + tfJob: testutil.NewTFJobWithSuccessPolicy(4, 0, tfv1.SuccessPolicyAllWorkers), + expectedFailedPS: 0, + expectedSucceededPS: 0, + expectedActivePS: 0, + expectedFailedWorker: 0, + expectedSucceededWorker: 1, + expectedActiveWorker: 3, + expectedFailedChief: 0, + expectedSucceededChief: 0, + expectedActiveChief: 0, + restart: false, + worker0Completed: true, + expectedType: common.JobRunning, + }, + testCase{ + description: "(No chief worker, successPolicy: AllWorkers) 4 workers are succeeded", + tfJob: testutil.NewTFJobWithSuccessPolicy(4, 0, tfv1.SuccessPolicyAllWorkers), + expectedFailedPS: 0, + expectedSucceededPS: 0, + expectedActivePS: 0, + expectedFailedWorker: 0, + expectedSucceededWorker: 4, + expectedActiveWorker: 0, + expectedFailedChief: 0, + expectedSucceededChief: 0, + expectedActiveChief: 0, + restart: false, + worker0Completed: true, + expectedType: common.JobSucceeded, + }, + testCase{ + description: "(No chief worker, successPolicy: AllWorkers) worker-0 is succeeded, 2 workers are running, 1 worker is failed", + tfJob: testutil.NewTFJobWithSuccessPolicy(4, 0, tfv1.SuccessPolicyAllWorkers), + expectedFailedPS: 0, + expectedSucceededPS: 0, + expectedActivePS: 0, + expectedFailedWorker: 1, + expectedSucceededWorker: 1, + expectedActiveWorker: 2, + expectedFailedChief: 0, + expectedSucceededChief: 0, + expectedActiveChief: 0, + restart: false, + worker0Completed: true, + expectedType: common.JobFailed, + }, testCase{ description: "Chief is running, workers are failed", tfJob: testutil.NewTFJobWithChief(4, 2),