Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Init test
Browse files Browse the repository at this point in the history
  • Loading branch information
pingsutw committed Feb 3, 2021
1 parent 6cf435d commit ff2a042
Show file tree
Hide file tree
Showing 7 changed files with 209 additions and 118 deletions.
1 change: 0 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ require (
k8s.io/client-go v11.0.1-0.20190918222721-c0e3722d5cf0+incompatible
k8s.io/utils v0.0.0-20200124190032-861946025e34
sigs.k8s.io/controller-runtime v0.5.1
k8s.io/klog v1.0.0 // indirect
sigs.k8s.io/yaml v1.2.0 // indirect
)

Expand Down
124 changes: 124 additions & 0 deletions go/tasks/plugins/k8s/kfoperators/common/common_operator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package common

import (
"fmt"
commonOp "github.com/kubeflow/tf-operator/pkg/apis/common/v1"
logUtils "github.com/lyft/flyteidl/clients/go/coreutils/logs"
"github.com/lyft/flyteidl/gen/pb-go/flyteidl/core"
flyteerr "github.com/lyft/flyteplugins/go/tasks/errors"
"github.com/lyft/flyteplugins/go/tasks/logs"
pluginsCore "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core"
v1 "k8s.io/api/core/v1"
"sort"
"time"
)

const (
// TensorflowTaskType is the task type name for tensorflow jobs.
// NOTE(review): presumably used by the tensorflow plugin the same way the
// pytorch plugin uses PytorchTaskType — confirm against that plugin.
TensorflowTaskType = "tensorflow"
// PytorchTaskType is the task type name for pytorch jobs. GetLogs compares
// its taskType argument against this value to decide whether a "master-0"
// log link is emitted.
PytorchTaskType = "pytorch"
)

// ExtractCurrentCondition returns the most recently transitioned job condition
// whose Status is ConditionTrue.
//
// The input slice is never modified: ordering happens on a private copy, so a
// caller that hands in the Status.Conditions of a job object does not get it
// reordered as a side effect. The sort is stable, so conditions that share the
// same LastTransitionTime keep their original relative order and the result is
// deterministic.
//
// An error is returned when no condition has Status == ConditionTrue; this
// includes a nil or empty slice.
func ExtractCurrentCondition(jobConditions []commonOp.JobCondition) (commonOp.JobCondition, error) {
	ordered := make([]commonOp.JobCondition, len(jobConditions))
	copy(ordered, jobConditions)

	// Newest transition first.
	sort.SliceStable(ordered, func(i, j int) bool {
		return ordered[i].LastTransitionTime.Time.After(ordered[j].LastTransitionTime.Time)
	})

	for _, jc := range ordered {
		if jc.Status == v1.ConditionTrue {
			return jc, nil
		}
	}

	return commonOp.JobCondition{}, fmt.Errorf("found no current condition. Conditions: %+v", ordered)
}

// GetPhaseInfo maps a job condition onto a plugin PhaseInfo.
//
// Mapping by currentCondition.Type:
//   - JobCreated    -> queued (stamped with occurredAt, reason "JobCreated")
//   - JobRunning    -> running
//   - JobSucceeded  -> success
//   - JobFailed     -> retryable failure with code DownstreamSystemError
//   - JobRestarting -> retryable failure with code RuntimeFailure
//   - anything else -> PhaseInfoUndefined
//
// taskPhaseInfo is attached to every phase except queued and undefined. The
// returned error is always nil.
func GetPhaseInfo(currentCondition commonOp.JobCondition, occurredAt time.Time,
	taskPhaseInfo pluginsCore.TaskInfo) (pluginsCore.PhaseInfo, error) {
	info := &taskPhaseInfo

	// Both failure-like conditions share one message layout; build it lazily
	// so that non-failure conditions never format it.
	failureDetails := func() string {
		return fmt.Sprintf("Job failed:\n\t%v - %v", currentCondition.Reason, currentCondition.Message)
	}

	conditionType := currentCondition.Type
	if conditionType == commonOp.JobCreated {
		return pluginsCore.PhaseInfoQueued(occurredAt, pluginsCore.DefaultPhaseVersion, "JobCreated"), nil
	}
	if conditionType == commonOp.JobRunning {
		return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, info), nil
	}
	if conditionType == commonOp.JobSucceeded {
		return pluginsCore.PhaseInfoSuccess(info), nil
	}
	if conditionType == commonOp.JobFailed {
		return pluginsCore.PhaseInfoRetryableFailure(flyteerr.DownstreamSystemError, failureDetails(), info), nil
	}
	if conditionType == commonOp.JobRestarting {
		return pluginsCore.PhaseInfoRetryableFailure(flyteerr.RuntimeFailure, failureDetails(), info), nil
	}

	return pluginsCore.PhaseInfoUndefined, nil
}

// GetLogs builds the list of Kubernetes log links for the pods of a job.
//
// When Kubernetes logging is disabled in the log config, a nil slice and a nil
// error are returned. Otherwise links are produced in this order:
//  1. "master-0", only when taskType equals PytorchTaskType
//  2. "worker-<i>" for i in [0, workersCount)
//  3. "psReplica-<i>" for i in [0, psReplicasCount)
//  4. "chiefReplica-0", only when chiefReplicasCount is non-zero
//
// Each link targets the pod named "<name>-<suffix>" in namespace. The first
// error from the log plugin aborts the call and is returned with a nil slice.
//
// NOTE(review): the suffixes have to match the pod names the operator actually
// creates — confirm against the operator's pod naming scheme.
func GetLogs(taskType string, name string, namespace string,
	workersCount int32, psReplicasCount int32, chiefReplicasCount int32) ([]*core.TaskLog, error) {
	var taskLogs []*core.TaskLog

	logConfig := logs.GetLogConfig()
	if !logConfig.IsKubernetesEnabled {
		return taskLogs, nil
	}

	// If kubeClient was available, it would be better to use
	// https://github.com/lyft/flyteplugins/blob/209c52d002b4e6a39be5d175bc1046b7e631c153/go/tasks/logs/logging_utils.go#L12
	appendPodLog := func(suffix string) error {
		podLog, err := logUtils.NewKubernetesLogPlugin(logConfig.KubernetesURL).GetTaskLog(
			name+"-"+suffix,
			namespace,
			"",
			"",
			suffix+" logs (via Kubernetes)")
		if err != nil {
			return err
		}
		taskLogs = append(taskLogs, &podLog)
		return nil
	}

	if taskType == PytorchTaskType {
		if err := appendPodLog("master-0"); err != nil {
			return nil, err
		}
	}

	// One link per worker replica.
	for i := int32(0); i < workersCount; i++ {
		if err := appendPodLog(fmt.Sprintf("worker-%d", i)); err != nil {
			return nil, err
		}
	}

	// One link per parameter server replica.
	for i := int32(0); i < psReplicasCount; i++ {
		if err := appendPodLog(fmt.Sprintf("psReplica-%d", i)); err != nil {
			return nil, err
		}
	}

	// There is at most one chief replica, so only index 0 is linked.
	if chiefReplicasCount != 0 {
		if err := appendPodLog(fmt.Sprintf("chiefReplica-%d", 0)); err != nil {
			return nil, err
		}
	}

	return taskLogs, nil
}

// OverrideDefaultContainerName renames the first container in podSpec whose
// name equals the task execution's generated name to defaultContainerName.
// If no container carries that name, podSpec is left unchanged.
//
// Background:
// Pytorch operator forces pod to have container named 'pytorch'
// https://github.com/kubeflow/pytorch-operator/blob/037cd1b18eb77f657f2a4bc8a8334f2a06324b57/pkg/apis/pytorch/validation/validation.go#L54-L62
// Tensorflow operator forces pod to have container named 'tensorflow'
// https://github.com/kubeflow/tf-operator/blob/984adc287e6fe82841e4ca282dc9a2cbb71e2d4a/pkg/apis/tensorflow/validation/validation.go#L55-L63
// hence we have to override the name set here
// https://github.com/lyft/flyteplugins/blob/209c52d002b4e6a39be5d175bc1046b7e631c153/go/tasks/pluginmachinery/flytek8s/container_helper.go#L116
func OverrideDefaultContainerName(taskCtx pluginsCore.TaskExecutionContext, podSpec *v1.PodSpec,
	defaultContainerName string) {
	generatedName := taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()

	containers := podSpec.Containers
	for i := range containers {
		if containers[i].Name != generatedName {
			continue
		}
		// Only the first match is renamed.
		containers[i].Name = defaultContainerName
		return
	}
}
66 changes: 66 additions & 0 deletions go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package common

import (
commonOp "github.com/kubeflow/tf-operator/pkg/apis/common/v1"
pluginsCore "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/stretchr/testify/assert"
corev1 "k8s.io/api/core/v1"
"testing"
"time"
)

// TestExtractCurrentCondition checks that ExtractCurrentCondition returns the
// condition whose Status is ConditionTrue, and reports an error when no such
// condition exists.
func TestExtractCurrentCondition(t *testing.T) {
	jobCreated := commonOp.JobCondition{
		Type:   commonOp.JobCreated,
		Status: corev1.ConditionTrue,
	}
	jobRunningActive := commonOp.JobCondition{
		Type:   commonOp.JobRunning,
		Status: corev1.ConditionFalse,
	}
	jobConditions := []commonOp.JobCondition{
		jobCreated,
		jobRunningActive,
	}

	currentCondition, err := ExtractCurrentCondition(jobConditions)
	assert.NoError(t, err)
	// testify's signature is Equal(t, expected, actual); keeping that order
	// makes failure output label the two values correctly.
	assert.Equal(t, jobCreated, currentCondition)

	// No condition is true, so there is no current condition to report.
	_, err = ExtractCurrentCondition([]commonOp.JobCondition{jobRunningActive})
	assert.Error(t, err)

	// An empty condition list has no current condition either.
	_, err = ExtractCurrentCondition(nil)
	assert.Error(t, err)
}

// TestGetPhaseInfo checks the job-condition-to-phase mapping of GetPhaseInfo.
// Each condition runs as its own subtest so a failure names the condition
// that broke.
func TestGetPhaseInfo(t *testing.T) {
	testCases := []struct {
		name          string
		condition     commonOp.JobCondition
		expectedPhase pluginsCore.Phase
	}{
		{
			name:          "JobCreated",
			condition:     commonOp.JobCondition{Type: commonOp.JobCreated},
			expectedPhase: pluginsCore.PhaseQueued,
		},
		{
			name:          "JobSucceeded",
			condition:     commonOp.JobCondition{Type: commonOp.JobSucceeded},
			expectedPhase: pluginsCore.PhaseSuccess,
		},
		{
			name:          "JobFailed",
			condition:     commonOp.JobCondition{Type: commonOp.JobFailed},
			expectedPhase: pluginsCore.PhaseRetryableFailure,
		},
		{
			name:          "JobRestarting",
			condition:     commonOp.JobCondition{Type: commonOp.JobRestarting},
			expectedPhase: pluginsCore.PhaseRetryableFailure,
		},
	}

	for _, tc := range testCases {
		tc := tc // capture per-iteration copy for the subtest closure
		t.Run(tc.name, func(t *testing.T) {
			taskPhase, err := GetPhaseInfo(tc.condition, time.Now(), pluginsCore.TaskInfo{})
			assert.NoError(t, err)
			assert.Equal(t, tc.expectedPhase, taskPhase.Phase())
			assert.NotNil(t, taskPhase.Info())
		})
	}
}
16 changes: 6 additions & 10 deletions go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ package pytorch
import (
"context"
"fmt"

"sort"
"github.com/lyft/flyteplugins/go/tasks/plugins/k8s/kfoperators/common"
"time"

"github.com/lyft/flyteplugins/go/tasks/pluginmachinery/tasklog"
Expand All @@ -29,10 +29,6 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

const (
pytorchTaskType = "pytorch"
)

type pytorchOperatorResourceHandler struct {
}

Expand Down Expand Up @@ -71,7 +67,7 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error())
}

overrideDefaultContainerName(taskCtx, podSpec)
common.OverrideDefaultContainerName(taskCtx, podSpec, ptOp.DefaultContainerName)

workers := pytorchTaskExtraArgs.GetWorkers()

Expand Down Expand Up @@ -113,12 +109,12 @@ func (pytorchOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginCont

workersCount := app.Spec.PyTorchReplicaSpecs[ptOp.PyTorchReplicaTypeWorker].Replicas

taskLogs, err := getLogs(app, *workersCount)
taskLogs, err := common.GetLogs(common.PytorchTaskType, app.Name, app.Namespace, *workersCount, 0, 0)
if err != nil {
return pluginsCore.PhaseInfoUndefined, err
}

currentCondition, err := extractCurrentCondition(app.Status.Conditions)
currentCondition, err := common.ExtractCurrentCondition(app.Status.Conditions)
if err != nil {
return pluginsCore.PhaseInfoUndefined, err
}
Expand Down Expand Up @@ -223,8 +219,8 @@ func init() {

pluginmachinery.PluginRegistry().RegisterK8sPlugin(
k8s.PluginEntry{
ID: pytorchTaskType,
RegisteredTaskTypes: []pluginsCore.TaskType{pytorchTaskType},
ID: common.PytorchTaskType,
RegisteredTaskTypes: []pluginsCore.TaskType{common.PytorchTaskType},
ResourceToWatch: &ptOp.PyTorchJob{},
Plugin: pytorchOperatorResourceHandler{},
IsDefault: false,
Expand Down
4 changes: 3 additions & 1 deletion go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package pytorch
import (
"context"
"fmt"
"github.com/lyft/flyteplugins/go/tasks/plugins/k8s/kfoperators/common"
"testing"
"time"

Expand Down Expand Up @@ -345,7 +346,8 @@ func TestGetLogs(t *testing.T) {
workers := int32(2)

pytorchResourceHandler := pytorchOperatorResourceHandler{}
jobLogs, err := getLogs(dummyPytorchJobResource(pytorchResourceHandler, workers, commonOp.JobRunning), workers)
pytorchJob := dummyPytorchJobResource(pytorchResourceHandler, workers, commonOp.JobRunning)
jobLogs, err := common.GetLogs(common.PytorchTaskType, pytorchJob.Name, pytorchJob.Namespace, workers, 0, 0)
assert.NoError(t, err)
assert.Equal(t, 3, len(jobLogs))
assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-master-0/pod?namespace=pytorch-namespace", jobNamespace, jobName), jobLogs[0].Uri)
Expand Down
Loading

0 comments on commit ff2a042

Please sign in to comment.