diff --git a/pkg/controller.v1/common/job.go b/pkg/controller.v1/common/job.go index b3533f02..2b52a9d8 100644 --- a/pkg/controller.v1/common/job.go +++ b/pkg/controller.v1/common/job.go @@ -3,6 +3,7 @@ package common import ( "fmt" "reflect" + "strings" "time" apiv1 "github.com/kubeflow/common/pkg/apis/common/v1" @@ -316,7 +317,41 @@ func (jc *JobController) PastActiveDeadline(runPolicy *apiv1.RunPolicy, jobStatu // this method applies only to pods when restartPolicy is one of OnFailure, Always or ExitCode func (jc *JobController) PastBackoffLimit(jobName string, runPolicy *apiv1.RunPolicy, replicas map[apiv1.ReplicaType]*apiv1.ReplicaSpec, pods []*v1.Pod) (bool, error) { - return core.PastBackoffLimit(jobName, runPolicy, replicas, pods, jc.FilterPodsForReplicaType) + if runPolicy.BackoffLimit == nil { + return false, nil + } + result := int32(0) + for rtype, spec := range replicas { + if spec.RestartPolicy != apiv1.RestartPolicyOnFailure && spec.RestartPolicy != apiv1.RestartPolicyAlways { + log.Warnf("The restart policy of replica %v of the job %v is not OnFailure or Always. Not counted in backoff limit.", rtype, jobName) + continue + } + // Convert ReplicaType to lower string. + rt := strings.ToLower(string(rtype)) + pods, err := jc.FilterPodsForReplicaType(pods, rt) + if err != nil { + return false, err + } + for i := range pods { + po := pods[i] + if po.Status.Phase != v1.PodRunning { + continue + } + for j := range po.Status.InitContainerStatuses { + stat := po.Status.InitContainerStatuses[j] + result += stat.RestartCount + } + for j := range po.Status.ContainerStatuses { + stat := po.Status.ContainerStatuses[j] + result += stat.RestartCount + } + } + } + + if *runPolicy.BackoffLimit == 0 { + return result > 0, nil + } + return result >= *runPolicy.BackoffLimit, nil } func (jc *JobController) CleanupJob(runPolicy *apiv1.RunPolicy, jobStatus apiv1.JobStatus, job interface{}) error { diff --git a/pkg/controller.v1/common/util_test.go b/pkg/controller.v1/common/util_test.go index ecaab7aa..4d0ec17e 100644 --- a/pkg/controller.v1/common/util_test.go +++ b/pkg/controller.v1/common/util_test.go @@ -18,15 +18,13 @@ import ( "testing" "github.com/stretchr/testify/assert" - - apiv1 "github.com/kubeflow/common/pkg/apis/common/v1" ) func TestGenGeneralName(t *testing.T) { tcs := []struct { index string key string - replicaType apiv1.ReplicaType + replicaType string expectedName string }{ {