Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

tensorflow plugin implementation #103

Merged
merged 8 commits into from
Feb 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions copilot/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,15 @@ github.com/aws/aws-sdk-go v1.28.9/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN
github.com/aws/aws-sdk-go v1.29.23 h1:wtiGLOzxAP755OfuVTDIy/NbUIYEDxbIbBEDfNhUpeU=
github.com/aws/aws-sdk-go v1.29.23/go.mod h1:1KvfttTE3SPKMpo8g2c6jL3ZKfXtFvKscTgahTma5Xg=
github.com/aws/aws-sdk-go-v2 v0.20.0/go.mod h1:2LhT7UgHOXK3UXONKI5OMgIyoQL6zTAw/jwIeX6yqzw=
github.com/aws/aws-sdk-go-v2 v1.0.0/go.mod h1:smfAbmpW+tcRVuNUjo3MOArSZmW72t62rkCzc2i0TWM=
github.com/aws/aws-sdk-go-v2/config v1.0.0/go.mod h1:WysE/OpUgE37tjtmtJd8GXgT8s1euilE5XtUkRNUQ1w=
github.com/aws/aws-sdk-go-v2/credentials v1.0.0/go.mod h1:/SvsiqBf509hG4Bddigr3NB12MIpfHhZapyBurJe8aY=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.0.0/go.mod h1:wpMHDCXvOXZxGCRSidyepa8uJHY4vaBGfY2/+oKU/Bc=
github.com/aws/aws-sdk-go-v2/service/athena v1.0.0/go.mod h1:qY8QFbemf2ceqweXcS6hQqiiIe1z42WqTvHsK2Lb0rE=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.0.0/go.mod h1:3jExOmpbjgPnz2FJaMOfbSk1heTkZ66aD3yNtVhnjvI=
github.com/aws/aws-sdk-go-v2/service/sagemaker v1.0.0/go.mod h1:8/T2od4WQj1qKPr2ppDgjCnMFR6hfYJM4hzjH1D+HWg=
github.com/aws/aws-sdk-go-v2/service/sts v1.0.0/go.mod h1:5f+cELGATgill5Pu3/vK3Ebuigstc+qYEHW5MvGWZO4=
github.com/aws/smithy-go v1.0.0/go.mod h1:EzMw8dbp/YJL4A5/sbhGddag+NPT7q084agLbB9LgIw=
github.com/benlaurie/objecthash v0.0.0-20180202135721-d1e3d6079fc1/go.mod h1:jvdWlw8vowVGnZqSDC7yhPd7AifQeQbRDkZcQXV2nRg=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
Expand Down Expand Up @@ -233,6 +242,8 @@ github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.4.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/gofuzz v0.0.0-20161122191042-44d81051d367/go.mod h1:HP5RmnzzSNb993RKQDq4+1A4ia9nllfqcQFTQJedwGI=
github.com/google/gofuzz v0.0.0-20170612174753-24818f796faf/go.mod h1:HP5RmnzzSNb993RKQDq4+1A4ia9nllfqcQFTQJedwGI=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
Expand Down Expand Up @@ -287,6 +298,8 @@ github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af h1:pmfjZENx5i
github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k=
github.com/jmespath/go-jmespath v0.3.0 h1:OS12ieG61fsCg5+qLJ+SsW9NicxNkg3b25OyT2yCeUc=
github.com/jmespath/go-jmespath v0.3.0/go.mod h1:9QtRXoHjLGCJ5IBSaohpXITPlowMeeYCZ7fLUTSywik=
github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo=
github.com/json-iterator/go v0.0.0-20180612202835-f2b4162afba3/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
github.com/json-iterator/go v1.1.5/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
Expand Down Expand Up @@ -320,6 +333,7 @@ github.com/lyft/api v0.0.0-20191031200350-b49a72c274e0 h1:NGL46+1RYcCXb3sShp0nQq
github.com/lyft/api v0.0.0-20191031200350-b49a72c274e0/go.mod h1:/L5qH+AD540e7Cetbui1tuJeXdmNhO8jM6VkXeDdDhQ=
github.com/lyft/apimachinery v0.0.0-20191031200210-047e3ea32d7f h1:PGuAMDzAen0AulUfaEhNQMYmUpa41pAVo3zHI+GJsCM=
github.com/lyft/apimachinery v0.0.0-20191031200210-047e3ea32d7f/go.mod h1:llRdnznGEAqC3DcNm6yEj472xaFVfLM7hnYofMb12tQ=
github.com/lyft/flyteidl v0.18.9/go.mod h1:/zQXxuHO11u/saxTTZc8oYExIGEShXB+xCB1/F1Cu20=
github.com/lyft/flyteidl v0.18.11 h1:24NaFYWxANhRbwKfvkgu8axGTWUcl1tgZBqNJutKNJ8=
github.com/lyft/flyteidl v0.18.11/go.mod h1:/zQXxuHO11u/saxTTZc8oYExIGEShXB+xCB1/F1Cu20=
github.com/lyft/flytestdlib v0.3.0/go.mod h1:LJPPJlkFj+wwVWMrQT3K5JZgNhZi2mULsCG4ZYhinhU=
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ github.com/aws/aws-sdk-go-v2 v0.20.0/go.mod h1:2LhT7UgHOXK3UXONKI5OMgIyoQL6zTAw/
github.com/aws/aws-sdk-go-v2 v0.24.0/go.mod h1:2LhT7UgHOXK3UXONKI5OMgIyoQL6zTAw/jwIeX6yqzw=
github.com/aws/aws-sdk-go-v2 v1.0.0 h1:ncEVPoHArsG+HjoDe/3ex/TG1CbLwMQ4eaWj0UGdyTo=
github.com/aws/aws-sdk-go-v2 v1.0.0/go.mod h1:smfAbmpW+tcRVuNUjo3MOArSZmW72t62rkCzc2i0TWM=
github.com/aws/aws-sdk-go-v2 v1.1.0 h1:sKP6QWxdN1oRYjl+k6S3bpgBI+XUx/0mqVOLIw4lR/Q=
github.com/aws/aws-sdk-go-v2/config v1.0.0 h1:x6vSFAwqAvhYPeSu60f0ZUlGHo3PKKmwDOTL8aMXtv4=
github.com/aws/aws-sdk-go-v2/config v1.0.0/go.mod h1:WysE/OpUgE37tjtmtJd8GXgT8s1euilE5XtUkRNUQ1w=
github.com/aws/aws-sdk-go-v2/credentials v1.0.0 h1:0M7netgZ8gCV4v7z1km+Fbl7j6KQYyZL7SS0/l5Jn/4=
Expand All @@ -168,6 +169,7 @@ github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.0.0 h1:IAutMPSryn
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.0.0/go.mod h1:3jExOmpbjgPnz2FJaMOfbSk1heTkZ66aD3yNtVhnjvI=
github.com/aws/aws-sdk-go-v2/service/sagemaker v1.0.0 h1:WAKXnA5HISN6P8sbXsJ9486ThbRPnoBAtMyDSG7+jNM=
github.com/aws/aws-sdk-go-v2/service/sagemaker v1.0.0/go.mod h1:8/T2od4WQj1qKPr2ppDgjCnMFR6hfYJM4hzjH1D+HWg=
github.com/aws/aws-sdk-go-v2/service/sagemaker v1.1.0 h1:qsaGAmYqUzym7g4uaBzx5uOYoEJW0wIHhgObLqZc1mo=
github.com/aws/aws-sdk-go-v2/service/sts v1.0.0 h1:6XCgxNfE4L/Fnq+InhVNd16DKc6Ue1f3dJl3IwwJRUQ=
github.com/aws/aws-sdk-go-v2/service/sts v1.0.0/go.mod h1:5f+cELGATgill5Pu3/vK3Ebuigstc+qYEHW5MvGWZO4=
github.com/aws/smithy-go v1.0.0 h1:hkhcRKG9rJ4Fn+RbfXY7Tz7b3ITLDyolBnLLBhwbg/c=
Expand Down
138 changes: 138 additions & 0 deletions go/tasks/plugins/k8s/kfoperators/common/common_operator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
package common

import (
"fmt"
"sort"
"time"

"github.com/lyft/flyteplugins/go/tasks/pluginmachinery/tasklog"

commonOp "github.com/kubeflow/tf-operator/pkg/apis/common/v1"
"github.com/lyft/flyteidl/gen/pb-go/flyteidl/core"
flyteerr "github.com/lyft/flyteplugins/go/tasks/errors"
"github.com/lyft/flyteplugins/go/tasks/logs"
pluginsCore "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core"
v1 "k8s.io/api/core/v1"
)

const (
TensorflowTaskType = "tensorflow"
PytorchTaskType = "pytorch"
)

func ExtractCurrentCondition(jobConditions []commonOp.JobCondition) (commonOp.JobCondition, error) {
if jobConditions != nil {
sort.Slice(jobConditions, func(i, j int) bool {
return jobConditions[i].LastTransitionTime.Time.After(jobConditions[j].LastTransitionTime.Time)
})

for _, jc := range jobConditions {
if jc.Status == v1.ConditionTrue {
return jc, nil
}
}
}

return commonOp.JobCondition{}, fmt.Errorf("found no current condition. Conditions: %+v", jobConditions)
}

func GetPhaseInfo(currentCondition commonOp.JobCondition, occurredAt time.Time,
taskPhaseInfo pluginsCore.TaskInfo) (pluginsCore.PhaseInfo, error) {
switch currentCondition.Type {
case commonOp.JobCreated:
return pluginsCore.PhaseInfoQueued(occurredAt, pluginsCore.DefaultPhaseVersion, "JobCreated"), nil
case commonOp.JobRunning:
return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, &taskPhaseInfo), nil
case commonOp.JobSucceeded:
return pluginsCore.PhaseInfoSuccess(&taskPhaseInfo), nil
case commonOp.JobFailed:
details := fmt.Sprintf("Job failed:\n\t%v - %v", currentCondition.Reason, currentCondition.Message)
return pluginsCore.PhaseInfoRetryableFailure(flyteerr.DownstreamSystemError, details, &taskPhaseInfo), nil
case commonOp.JobRestarting:
EngHabu marked this conversation as resolved.
Show resolved Hide resolved
return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, &taskPhaseInfo), nil
}

return pluginsCore.PhaseInfoUndefined, nil
}

func GetLogs(taskType string, name string, namespace string,
workersCount int32, psReplicasCount int32, chiefReplicasCount int32) ([]*core.TaskLog, error) {
taskLogs := make([]*core.TaskLog, 0, 10)

logPlugin, err := logs.InitializeLogPlugins(logs.GetLogConfig())

if err != nil {
return nil, err
}

if logPlugin == nil {
return nil, nil
}

if taskType == PytorchTaskType {
masterTaskLog, masterErr := logPlugin.GetTaskLogs(
tasklog.Input{
PodName: name + "-master-0",
Namespace: namespace,
LogName: "master",
},
)
if masterErr != nil {
return nil, masterErr
}
taskLogs = append(taskLogs, masterTaskLog.TaskLogs...)
}

// get all workers log
for workerIndex := int32(0); workerIndex < workersCount; workerIndex++ {
workerLog, err := logPlugin.GetTaskLogs(tasklog.Input{
PodName: name + fmt.Sprintf("-worker-%d", workerIndex),
Namespace: namespace,
})
if err != nil {
return nil, err
}
taskLogs = append(taskLogs, workerLog.TaskLogs...)
}
// get all parameter servers logs
for psReplicaIndex := int32(0); psReplicaIndex < psReplicasCount; psReplicaIndex++ {
psReplicaLog, err := logPlugin.GetTaskLogs(tasklog.Input{
PodName: name + fmt.Sprintf("-psReplica-%d", psReplicaIndex),
Namespace: namespace,
})
if err != nil {
return nil, err
}
taskLogs = append(taskLogs, psReplicaLog.TaskLogs...)
}
// get chief worker log, and the max number of chief worker is 1
if chiefReplicasCount != 0 {
chiefReplicaLog, err := logPlugin.GetTaskLogs(tasklog.Input{
PodName: name + fmt.Sprintf("-chiefReplica-%d", 0),
Namespace: namespace,
})
if err != nil {
return nil, err
}
taskLogs = append(taskLogs, chiefReplicaLog.TaskLogs...)
}

return taskLogs, nil
}

func OverrideDefaultContainerName(taskCtx pluginsCore.TaskExecutionContext, podSpec *v1.PodSpec,
defaultContainerName string) {
// Pytorch operator forces pod to have container named 'pytorch'
// https://github.com/kubeflow/pytorch-operator/blob/037cd1b18eb77f657f2a4bc8a8334f2a06324b57/pkg/apis/pytorch/validation/validation.go#L54-L62
// Tensorflow operator forces pod to have container named 'tensorflow'
// https://github.com/kubeflow/tf-operator/blob/984adc287e6fe82841e4ca282dc9a2cbb71e2d4a/pkg/apis/tensorflow/validation/validation.go#L55-L63
// hence we have to override the name set here
// https://github.com/lyft/flyteplugins/blob/209c52d002b4e6a39be5d175bc1046b7e631c153/go/tasks/pluginmachinery/flytek8s/container_helper.go#L116
flyteDefaultContainerName := taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()
for idx, c := range podSpec.Containers {
if c.Name == flyteDefaultContainerName {
podSpec.Containers[idx].Name = defaultContainerName
return
}
}
}
67 changes: 67 additions & 0 deletions go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package common

import (
"testing"
"time"

commonOp "github.com/kubeflow/tf-operator/pkg/apis/common/v1"
pluginsCore "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/stretchr/testify/assert"
corev1 "k8s.io/api/core/v1"
)

func TestExtractCurrentCondition(t *testing.T) {
jobCreated := commonOp.JobCondition{
Type: commonOp.JobCreated,
Status: corev1.ConditionTrue,
}
jobRunningActive := commonOp.JobCondition{
Type: commonOp.JobRunning,
Status: corev1.ConditionFalse,
}
jobConditions := []commonOp.JobCondition{
jobCreated,
jobRunningActive,
}
currentCondition, err := ExtractCurrentCondition(jobConditions)
assert.NoError(t, err)
assert.Equal(t, currentCondition, jobCreated)
}

func TestGetPhaseInfo(t *testing.T) {
jobCreated := commonOp.JobCondition{
Type: commonOp.JobCreated,
}
taskPhase, err := GetPhaseInfo(jobCreated, time.Now(), pluginsCore.TaskInfo{})
assert.NoError(t, err)
assert.Equal(t, pluginsCore.PhaseQueued, taskPhase.Phase())
assert.NotNil(t, taskPhase.Info())
assert.Nil(t, err)

jobSucceeded := commonOp.JobCondition{
Type: commonOp.JobSucceeded,
}
taskPhase, err = GetPhaseInfo(jobSucceeded, time.Now(), pluginsCore.TaskInfo{})
assert.NoError(t, err)
assert.Equal(t, pluginsCore.PhaseSuccess, taskPhase.Phase())
assert.NotNil(t, taskPhase.Info())
assert.Nil(t, err)

jobFailed := commonOp.JobCondition{
Type: commonOp.JobFailed,
}
taskPhase, err = GetPhaseInfo(jobFailed, time.Now(), pluginsCore.TaskInfo{})
assert.NoError(t, err)
assert.Equal(t, pluginsCore.PhaseRetryableFailure, taskPhase.Phase())
assert.NotNil(t, taskPhase.Info())
assert.Nil(t, err)

jobRestarting := commonOp.JobCondition{
Type: commonOp.JobRestarting,
}
taskPhase, err = GetPhaseInfo(jobRestarting, time.Now(), pluginsCore.TaskInfo{})
assert.NoError(t, err)
assert.Equal(t, pluginsCore.PhaseRunning, taskPhase.Phase())
assert.NotNil(t, taskPhase.Info())
assert.Nil(t, err)
}
Loading