From ff2a04207f2aecfe18921407195a8b631b1f7e3f Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 19 Aug 2020 23:06:03 +0800 Subject: [PATCH] Init test --- go.mod | 1 - .../k8s/kfoperators/common/common_operator.go | 124 ++++++++++++++++++ .../common/common_operator_test.go | 66 ++++++++++ .../k8s/kfoperators/pytorch/pytorch.go | 16 +-- .../k8s/kfoperators/pytorch/pytorch_test.go | 4 +- .../k8s/kfoperators/tensorflow/tensorflow.go | 111 +--------------- .../kfoperators/tensorflow/tensorflow_test.go | 5 +- 7 files changed, 209 insertions(+), 118 deletions(-) create mode 100644 go/tasks/plugins/k8s/kfoperators/common/common_operator.go create mode 100644 go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go diff --git a/go.mod b/go.mod index c33d111b3..7d9a5163a 100644 --- a/go.mod +++ b/go.mod @@ -41,7 +41,6 @@ require ( k8s.io/client-go v11.0.1-0.20190918222721-c0e3722d5cf0+incompatible k8s.io/utils v0.0.0-20200124190032-861946025e34 sigs.k8s.io/controller-runtime v0.5.1 - k8s.io/klog v1.0.0 // indirect sigs.k8s.io/yaml v1.2.0 // indirect ) diff --git a/go/tasks/plugins/k8s/kfoperators/common/common_operator.go b/go/tasks/plugins/k8s/kfoperators/common/common_operator.go new file mode 100644 index 000000000..488b0953c --- /dev/null +++ b/go/tasks/plugins/k8s/kfoperators/common/common_operator.go @@ -0,0 +1,124 @@ +package common + +import ( + "fmt" + commonOp "github.com/kubeflow/tf-operator/pkg/apis/common/v1" + logUtils "github.com/lyft/flyteidl/clients/go/coreutils/logs" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + flyteerr "github.com/lyft/flyteplugins/go/tasks/errors" + "github.com/lyft/flyteplugins/go/tasks/logs" + pluginsCore "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" + v1 "k8s.io/api/core/v1" + "sort" + "time" +) + +const ( + TensorflowTaskType = "tensorflow" + PytorchTaskType = "pytorch" +) + +func ExtractCurrentCondition(jobConditions []commonOp.JobCondition) (commonOp.JobCondition, error) { + sort.Slice(jobConditions[:], func(i, j int) bool { + return jobConditions[i].LastTransitionTime.Time.After(jobConditions[j].LastTransitionTime.Time) + }) + + for _, jc := range jobConditions { + if jc.Status == v1.ConditionTrue { + return jc, nil + } + } + + return commonOp.JobCondition{}, fmt.Errorf("found no current condition. Conditions: %+v", jobConditions) +} + +func GetPhaseInfo(currentCondition commonOp.JobCondition, occurredAt time.Time, + taskPhaseInfo pluginsCore.TaskInfo) (pluginsCore.PhaseInfo, error){ + switch currentCondition.Type { + case commonOp.JobCreated: + return pluginsCore.PhaseInfoQueued(occurredAt, pluginsCore.DefaultPhaseVersion, "JobCreated"), nil + case commonOp.JobRunning: + return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, &taskPhaseInfo), nil + case commonOp.JobSucceeded: + return pluginsCore.PhaseInfoSuccess(&taskPhaseInfo), nil + case commonOp.JobFailed: + details := fmt.Sprintf("Job failed:\n\t%v - %v", currentCondition.Reason, currentCondition.Message) + return pluginsCore.PhaseInfoRetryableFailure(flyteerr.DownstreamSystemError, details, &taskPhaseInfo), nil + case commonOp.JobRestarting: + details := fmt.Sprintf("Job failed:\n\t%v - %v", currentCondition.Reason, currentCondition.Message) + return pluginsCore.PhaseInfoRetryableFailure(flyteerr.RuntimeFailure, details, &taskPhaseInfo), nil + } + + return pluginsCore.PhaseInfoUndefined, nil +} + +func GetLogs(taskType string, name string, namespace string, + workersCount int32, psReplicasCount int32, chiefReplicasCount int32) ([]*core.TaskLog, error) { + // If kubeClient was available, it would be better to use + // https://github.com/lyft/flyteplugins/blob/209c52d002b4e6a39be5d175bc1046b7e631c153/go/tasks/logs/logging_utils.go#L12 + makeTaskLog := func(appName, appNamespace, suffix, url string) (core.TaskLog, error) { + return logUtils.NewKubernetesLogPlugin(url).GetTaskLog( + appName+"-"+suffix, + appNamespace, + "", + "", + suffix+" logs (via Kubernetes)") + } + + var taskLogs []*core.TaskLog + + logConfig := logs.GetLogConfig() + if logConfig.IsKubernetesEnabled { + + if taskType == PytorchTaskType { + masterTaskLog, masterErr := makeTaskLog(name, namespace, "master-0", logConfig.KubernetesURL) + if masterErr != nil { + return nil, masterErr + } + taskLogs = append(taskLogs, &masterTaskLog) + } + + // get all workers log + for workerIndex := int32(0); workerIndex < workersCount; workerIndex++ { + workerLog, err := makeTaskLog(name, namespace, fmt.Sprintf("worker-%d", workerIndex), logConfig.KubernetesURL) + if err != nil { + return nil, err + } + taskLogs = append(taskLogs, &workerLog) + } + // get all parameter servers logs + for psReplicaIndex := int32(0); psReplicaIndex < psReplicasCount; psReplicaIndex++ { + psReplicaLog, err := makeTaskLog(name, namespace, fmt.Sprintf("psReplica-%d", psReplicaIndex), logConfig.KubernetesURL) + if err != nil { + return nil, err + } + taskLogs = append(taskLogs, &psReplicaLog) + } + if chiefReplicasCount != 0 { + // get chief worker log, and the max number of chief worker is 1 + chiefReplicaLog, err := makeTaskLog(name, namespace, fmt.Sprintf("chiefReplica-%d", 0), logConfig.KubernetesURL) + if err != nil { + return nil, err + } + taskLogs = append(taskLogs, &chiefReplicaLog) + } + } + return taskLogs, nil +} + +func OverrideDefaultContainerName(taskCtx pluginsCore.TaskExecutionContext, podSpec *v1.PodSpec, + defaultContainerName string) { + // Pytorch operator forces pod to have container named 'pytorch' + // https://github.com/kubeflow/pytorch-operator/blob/037cd1b18eb77f657f2a4bc8a8334f2a06324b57/pkg/apis/pytorch/validation/validation.go#L54-L62 + // Tensorflow operator forces pod to have container named 'tensorflow' + // https://github.com/kubeflow/tf-operator/blob/984adc287e6fe82841e4ca282dc9a2cbb71e2d4a/pkg/apis/tensorflow/validation/validation.go#L55-L63 + // hence we have to override the name set here + // https://github.com/lyft/flyteplugins/blob/209c52d002b4e6a39be5d175bc1046b7e631c153/go/tasks/pluginmachinery/flytek8s/container_helper.go#L116 + flyteDefaultContainerName := taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() + for idx, c := range podSpec.Containers { + if c.Name == flyteDefaultContainerName { + podSpec.Containers[idx].Name = defaultContainerName + return + } + } +} diff --git a/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go b/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go new file mode 100644 index 000000000..30c68868f --- /dev/null +++ b/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go @@ -0,0 +1,66 @@ +package common + +import ( + commonOp "github.com/kubeflow/tf-operator/pkg/apis/common/v1" + pluginsCore "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/stretchr/testify/assert" + corev1 "k8s.io/api/core/v1" + "testing" + "time" +) + +func TestExtractCurrentCondition(t *testing.T) { + jobCreated := commonOp.JobCondition{ + Type: commonOp.JobCreated, + Status: corev1.ConditionTrue, + } + jobRunningActive := commonOp.JobCondition{ + Type: commonOp.JobRunning, + Status: corev1.ConditionFalse, + } + jobConditions := []commonOp.JobCondition{ + jobCreated, + jobRunningActive, + } + currentCondition, err := ExtractCurrentCondition(jobConditions) + assert.NoError(t, err) + assert.Equal(t, currentCondition, jobCreated) +} + +func TestGetPhaseInfo(t *testing.T) { + jobCreated := commonOp.JobCondition{ + Type: commonOp.JobCreated, + } + taskPhase, err := GetPhaseInfo(jobCreated, time.Now(), pluginsCore.TaskInfo{}) + assert.NoError(t, err) + assert.Equal(t, pluginsCore.PhaseQueued, taskPhase.Phase()) + assert.NotNil(t, taskPhase.Info()) + assert.Nil(t, err) + + jobSucceeded := commonOp.JobCondition{ + Type: commonOp.JobSucceeded, + } + taskPhase, err = GetPhaseInfo(jobSucceeded, time.Now(), pluginsCore.TaskInfo{}) + assert.NoError(t, err) + assert.Equal(t, pluginsCore.PhaseSuccess, taskPhase.Phase()) + assert.NotNil(t, taskPhase.Info()) + assert.Nil(t, err) + + jobFailed := commonOp.JobCondition{ + Type: commonOp.JobFailed, + } + taskPhase, err = GetPhaseInfo(jobFailed, time.Now(), pluginsCore.TaskInfo{}) + assert.NoError(t, err) + assert.Equal(t, pluginsCore.PhaseRetryableFailure, taskPhase.Phase()) + assert.NotNil(t, taskPhase.Info()) + assert.Nil(t, err) + + jobRestarting := commonOp.JobCondition{ + Type: commonOp.JobRestarting, + } + taskPhase, err = GetPhaseInfo(jobRestarting, time.Now(), pluginsCore.TaskInfo{}) + assert.NoError(t, err) + assert.Equal(t, pluginsCore.PhaseRetryableFailure, taskPhase.Phase()) + assert.NotNil(t, taskPhase.Info()) + assert.Nil(t, err) +} diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go index 898f87da6..51cdd2ee8 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go @@ -3,8 +3,8 @@ package pytorch import ( "context" "fmt" - "sort" + "github.com/lyft/flyteplugins/go/tasks/plugins/k8s/kfoperators/common" "time" "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/tasklog" @@ -29,10 +29,6 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) -const ( - pytorchTaskType = "pytorch" -) - type pytorchOperatorResourceHandler struct { } @@ -71,7 +67,7 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error()) } - overrideDefaultContainerName(taskCtx, podSpec) + common.OverrideDefaultContainerName(taskCtx, podSpec, ptOp.DefaultContainerName) workers := pytorchTaskExtraArgs.GetWorkers() @@ -113,12 +109,12 @@ func (pytorchOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginCont workersCount := app.Spec.PyTorchReplicaSpecs[ptOp.PyTorchReplicaTypeWorker].Replicas - taskLogs, err := getLogs(app, *workersCount) + taskLogs, err := common.GetLogs(common.PytorchTaskType, app.Name, app.Namespace, *workersCount, 0, 0) if err != nil { return pluginsCore.PhaseInfoUndefined, err } - currentCondition, err := extractCurrentCondition(app.Status.Conditions) + currentCondition, err := common.ExtractCurrentCondition(app.Status.Conditions) if err != nil { return pluginsCore.PhaseInfoUndefined, err } @@ -223,8 +219,8 @@ func init() { pluginmachinery.PluginRegistry().RegisterK8sPlugin( k8s.PluginEntry{ - ID: pytorchTaskType, - RegisteredTaskTypes: []pluginsCore.TaskType{pytorchTaskType}, + ID: common.PytorchTaskType, + RegisteredTaskTypes: []pluginsCore.TaskType{common.PytorchTaskType}, ResourceToWatch: &ptOp.PyTorchJob{}, Plugin: pytorchOperatorResourceHandler{}, IsDefault: false, diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go index 67d4fb4aa..facc785d7 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go @@ -3,6 +3,7 @@ package pytorch import ( "context" "fmt" + "github.com/lyft/flyteplugins/go/tasks/plugins/k8s/kfoperators/common" "testing" "time" @@ -345,7 +346,8 @@ func TestGetLogs(t *testing.T) { workers := int32(2) pytorchResourceHandler := pytorchOperatorResourceHandler{} - jobLogs, err := getLogs(dummyPytorchJobResource(pytorchResourceHandler, workers, commonOp.JobRunning), workers) + pytorchJob := dummyPytorchJobResource(pytorchResourceHandler, workers, commonOp.JobRunning) + jobLogs, err := common.GetLogs(common.PytorchTaskType, pytorchJob.Name, pytorchJob.Namespace, workers, 0, 0) assert.NoError(t, err) assert.Equal(t, 3, len(jobLogs)) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-master-0/pod?namespace=pytorch-namespace", jobNamespace, jobName), jobLogs[0].Uri) diff --git a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go index 563f008fc..8545801fb 100644 --- a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go +++ b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go @@ -2,8 +2,7 @@ package tensorflow import ( "context" - "fmt" - "sort" + "github.com/lyft/flyteplugins/go/tasks/plugins/k8s/kfoperators/common" "time" "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins" @@ -17,20 +16,12 @@ import ( "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/utils" - logUtils "github.com/lyft/flyteidl/clients/go/coreutils/logs" - "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" - "github.com/lyft/flyteplugins/go/tasks/logs" - //commonOp "github.com/kubeflow/common/pkg/apis/common/v1" // switch to real 'common' once https://github.com/kubeflow/pytorch-operator/issues/263 resolved commonOp "github.com/kubeflow/tf-operator/pkg/apis/common/v1" tfOp "github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) -const ( - tensorflowTaskType = "tensorflow" -) - type tensorflowOperatorResourceHandler struct { } @@ -69,7 +60,7 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error()) } - overrideDefaultContainerName(taskCtx, podSpec) + common.OverrideDefaultContainerName(taskCtx, podSpec, tfOp.DefaultContainerName) workers := tensorflowTaskExtraArgs.GetWorkers() psReplicas := tensorflowTaskExtraArgs.GetPsReplicas() @@ -123,12 +114,12 @@ func (tensorflowOperatorResourceHandler) GetTaskPhase(ctx context.Context, plugi psReplicasCount := app.Spec.TFReplicaSpecs[tfOp.TFReplicaTypePS].Replicas chiefCount := app.Spec.TFReplicaSpecs[tfOp.TFReplicaTypeChief].Replicas - taskLogs, err := getLogs(app, *workersCount, *psReplicasCount, *chiefCount) + taskLogs, err := common.GetLogs(common.TensorflowTaskType, app.Name, app.Namespace, *workersCount, *psReplicasCount, *chiefCount) if err != nil { return pluginsCore.PhaseInfoUndefined, err } - currentCondition, err := extractCurrentCondition(app.Status.Conditions) + currentCondition, err := common.ExtractCurrentCondition(app.Status.Conditions) if err != nil { return pluginsCore.PhaseInfoUndefined, err } @@ -141,95 +132,7 @@ func (tensorflowOperatorResourceHandler) GetTaskPhase(ctx context.Context, plugi CustomInfo: statusDetails, } - switch currentCondition.Type { - case commonOp.JobCreated: - return pluginsCore.PhaseInfoQueued(occurredAt, pluginsCore.DefaultPhaseVersion, "JobCreated"), nil - case commonOp.JobRunning: - return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, &taskPhaseInfo), nil - case commonOp.JobSucceeded: - return pluginsCore.PhaseInfoSuccess(&taskPhaseInfo), nil - case commonOp.JobFailed: - details := fmt.Sprintf("Job failed:\n\t%v - %v", currentCondition.Reason, currentCondition.Message) - return pluginsCore.PhaseInfoRetryableFailure(flyteerr.DownstreamSystemError, details, &taskPhaseInfo), nil - case commonOp.JobRestarting: - details := fmt.Sprintf("Job failed:\n\t%v - %v", currentCondition.Reason, currentCondition.Message) - return pluginsCore.PhaseInfoRetryableFailure(flyteerr.RuntimeFailure, details, &taskPhaseInfo), nil - } - - return pluginsCore.PhaseInfoUndefined, nil -} - -func getLogs(app *tfOp.TFJob, workersCount int32, psReplicasCount int32, chiefReplicasCount int32) ([]*core.TaskLog, error) { - // If kubeClient was available, it would be better to use - // https://github.com/lyft/flyteplugins/blob/209c52d002b4e6a39be5d175bc1046b7e631c153/go/tasks/logs/logging_utils.go#L12 - makeTaskLog := func(appName, appNamespace, suffix, url string) (core.TaskLog, error) { - return logUtils.NewKubernetesLogPlugin(url).GetTaskLog( - appName+"-"+suffix, - appNamespace, - "", - "", - suffix+" logs (via Kubernetes)") - } - - var taskLogs []*core.TaskLog - - logConfig := logs.GetLogConfig() - if logConfig.IsKubernetesEnabled { - - // get all workers log - for workerIndex := int32(0); workerIndex < workersCount; workerIndex++ { - workerLog, err := makeTaskLog(app.Name, app.Namespace, fmt.Sprintf("worker-%d", workerIndex), logConfig.KubernetesURL) - if err != nil { - return nil, err - } - taskLogs = append(taskLogs, &workerLog) - } - // get all parameter servers logs - for psReplicaIndex := int32(0); psReplicaIndex < psReplicasCount; psReplicaIndex++ { - psReplicaLog, err := makeTaskLog(app.Name, app.Namespace, fmt.Sprintf("psReplica-%d", psReplicaIndex), logConfig.KubernetesURL) - if err != nil { - return nil, err - } - taskLogs = append(taskLogs, &psReplicaLog) - } - if chiefReplicasCount != 0 { - // get chief worker log, and the max number of chief worker is 1 - chiefReplicaLog, err := makeTaskLog(app.Name, app.Namespace, fmt.Sprintf("chiefReplica-%d", 0), logConfig.KubernetesURL) - if err != nil { - return nil, err - } - taskLogs = append(taskLogs, &chiefReplicaLog) - } - } - return taskLogs, nil -} - -func extractCurrentCondition(jobConditions []commonOp.JobCondition) (commonOp.JobCondition, error) { - sort.Slice(jobConditions[:], func(i, j int) bool { - return jobConditions[i].LastTransitionTime.Time.After(jobConditions[j].LastTransitionTime.Time) - }) - - for _, jc := range jobConditions { - if jc.Status == v1.ConditionTrue { - return jc, nil - } - } - - return commonOp.JobCondition{}, fmt.Errorf("found no current condition. Conditions: %+v", jobConditions) -} - -func overrideDefaultContainerName(taskCtx pluginsCore.TaskExecutionContext, podSpec *v1.PodSpec) { - // Pytorch operator forces pod to have container named 'pytorch' - // https://github.com/kubeflow/pytorch-operator/blob/037cd1b18eb77f657f2a4bc8a8334f2a06324b57/pkg/apis/pytorch/validation/validation.go#L54-L62 - // hence we have to override the name set here - // https://github.com/lyft/flyteplugins/blob/209c52d002b4e6a39be5d175bc1046b7e631c153/go/tasks/pluginmachinery/flytek8s/container_helper.go#L116 - flyteDefaultContainerName := taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() - for idx, c := range podSpec.Containers { - if c.Name == flyteDefaultContainerName { - podSpec.Containers[idx].Name = tfOp.DefaultContainerName - return - } - } + return common.GetPhaseInfo(currentCondition, occurredAt, taskPhaseInfo) } func init() { @@ -239,8 +142,8 @@ func init() { pluginmachinery.PluginRegistry().RegisterK8sPlugin( k8s.PluginEntry{ - ID: tensorflowTaskType, - RegisteredTaskTypes: []pluginsCore.TaskType{tensorflowTaskType}, + ID: common.TensorflowTaskType, + RegisteredTaskTypes: []pluginsCore.TaskType{common.TensorflowTaskType}, ResourceToWatch: &tfOp.TFJob{}, Plugin: tensorflowOperatorResourceHandler{}, IsDefault: false, diff --git a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go index 9c667d3db..c75ef7871 100644 --- a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go +++ b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go @@ -9,6 +9,7 @@ import ( commonOp "github.com/kubeflow/tf-operator/pkg/apis/common/v1" "github.com/lyft/flyteplugins/go/tasks/logs" "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/flytek8s" + "github.com/lyft/flyteplugins/go/tasks/plugins/k8s/kfoperators/common" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" @@ -351,8 +352,8 @@ func TestGetLogs(t *testing.T) { chiefReplicas := int32(1) tensorflowResourceHandler := tensorflowOperatorResourceHandler{} - jobLogs, err := getLogs(dummyTensorflowJobResource(tensorflowResourceHandler, workers, psReplicas, chiefReplicas, commonOp.JobRunning), - workers, psReplicas, chiefReplicas) + tensorflowJob := dummyTensorflowJobResource(tensorflowResourceHandler, workers, psReplicas, chiefReplicas, commonOp.JobRunning) + jobLogs, err := common.GetLogs(common.TensorflowTaskType, tensorflowJob.Name,tensorflowJob.Namespace, workers, psReplicas, chiefReplicas) assert.NoError(t, err) assert.Equal(t, 4, len(jobLogs)) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=tensorflow-namespace", jobNamespace, jobName), jobLogs[0].Uri)