Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Fix test and lint
Browse files Browse the repository at this point in the history
Signed-off-by: Yuvraj <[email protected]>
  • Loading branch information
yindia committed Aug 30, 2021
1 parent 92d9b25 commit 45cea17
Show file tree
Hide file tree
Showing 6 changed files with 165 additions and 66 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -80,4 +80,4 @@ replace (

)

replace github.com/flyteorg/flyteidl => github.com/evalsocket/flyteidl v0.18.8-0.20210820144224-2cd6288a159d
replace github.com/flyteorg/flyteidl => github.com/evalsocket/flyteidl v0.19.26-0.20210827192202-131a810ef743
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -285,8 +285,8 @@ github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.m
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/ernesto-jimenez/gogen v0.0.0-20180125220232-d7d4131e6607/go.mod h1:Cg4fM0vhYWOZdgM7RIOSTRNIc8/VT7CXClC3Ni86lu4=
github.com/euank/go-kmsg-parser v2.0.0+incompatible/go.mod h1:MhmAMZ8V4CYH4ybgdRwPr2TU5ThnS43puaKEMpja1uw=
github.com/evalsocket/flyteidl v0.18.8-0.20210820144224-2cd6288a159d h1:faDhAzZI2LBWSgVA5bLc5FsN6hngPtToNxjLMPPgx8w=
github.com/evalsocket/flyteidl v0.18.8-0.20210820144224-2cd6288a159d/go.mod h1:576W2ViEyjTpT+kEVHAGbrTP3HARNUZ/eCwrNPmdx9U=
github.com/evalsocket/flyteidl v0.19.26-0.20210827192202-131a810ef743 h1:c0TgBrZkbs9xzz1UnqHP8bGkAnInKOxQOTSxkexIfFk=
github.com/evalsocket/flyteidl v0.19.26-0.20210827192202-131a810ef743/go.mod h1:576W2ViEyjTpT+kEVHAGbrTP3HARNUZ/eCwrNPmdx9U=
github.com/evanphx/json-patch v0.0.0-20180908160633-36442dbdb585/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk=
github.com/evanphx/json-patch v0.0.0-20190203023257-5858425f7550/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk=
github.com/evanphx/json-patch v4.2.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk=
Expand Down
74 changes: 24 additions & 50 deletions go/tasks/plugins/k8s/kfoperators/common/common_operator.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ import (

const (
TensorflowTaskType = "tensorflow"
MPITaskType = "mpi"
MPITaskType = "mpi"
PytorchTaskType = "pytorch"
)

func ExtractMPIMPICurrentCondition(jobConditions []mpiOp.JobCondition) (mpiOp.JobCondition, error) {
func ExtractMPICurrentCondition(jobConditions []mpiOp.JobCondition) (mpiOp.JobCondition, error) {
if jobConditions != nil {
sort.Slice(jobConditions, func(i, j int) bool {
return jobConditions[i].LastTransitionTime.Time.After(jobConditions[j].LastTransitionTime.Time)
Expand Down Expand Up @@ -131,56 +131,30 @@ func GetLogs(taskType string, name string, namespace string,
}
taskLogs = append(taskLogs, workerLog.TaskLogs...)
}
// get all parameter servers logs
for psReplicaIndex := int32(0); psReplicaIndex < psReplicasCount; psReplicaIndex++ {
psReplicaLog, err := logPlugin.GetTaskLogs(tasklog.Input{
PodName: name + fmt.Sprintf("-psReplica-%d", psReplicaIndex),
Namespace: namespace,
})
if err != nil {
return nil, err
}
taskLogs = append(taskLogs, psReplicaLog.TaskLogs...)
}
// get chief worker log, and the max number of chief worker is 1
if chiefReplicasCount != 0 {
chiefReplicaLog, err := logPlugin.GetTaskLogs(tasklog.Input{
PodName: name + fmt.Sprintf("-chiefReplica-%d", 0),
Namespace: namespace,
})
if err != nil {
return nil, err
}
taskLogs = append(taskLogs, chiefReplicaLog.TaskLogs...)
}

return taskLogs, nil
}

func GetMPILogs(name string, namespace string,
workersCount int32, launcherReplicasCount int32) ([]*core.TaskLog, error) {
taskLogs := make([]*core.TaskLog, 0, 10)

logPlugin, err := logs.InitializeLogPlugins(logs.GetLogConfig())

if err != nil {
return nil, err
}

if logPlugin == nil {
return nil, nil
}

// get all workers log
for workerIndex := int32(0); workerIndex < workersCount; workerIndex++ {
workerLog, err := logPlugin.GetTaskLogs(tasklog.Input{
PodName: name + fmt.Sprintf("-worker-%d", workerIndex),
Namespace: namespace,
})
if err != nil {
return nil, err
if taskType != MPITaskType && taskType != PytorchTaskType {
// get all parameter servers logs
for psReplicaIndex := int32(0); psReplicaIndex < psReplicasCount; psReplicaIndex++ {
psReplicaLog, err := logPlugin.GetTaskLogs(tasklog.Input{
PodName: name + fmt.Sprintf("-psReplica-%d", psReplicaIndex),
Namespace: namespace,
})
if err != nil {
return nil, err
}
taskLogs = append(taskLogs, psReplicaLog.TaskLogs...)
}
// get chief worker log, and the max number of chief worker is 1
if chiefReplicasCount != 0 {
chiefReplicaLog, err := logPlugin.GetTaskLogs(tasklog.Input{
PodName: name + fmt.Sprintf("-chiefReplica-%d", 0),
Namespace: namespace,
})
if err != nil {
return nil, err
}
taskLogs = append(taskLogs, chiefReplicaLog.TaskLogs...)
}
taskLogs = append(taskLogs, workerLog.TaskLogs...)
}

return taskLogs, nil
Expand Down
120 changes: 120 additions & 0 deletions go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,44 @@
package common

import (
"fmt"
"testing"
"time"

"github.com/flyteorg/flyteplugins/go/tasks/logs"

mpiOp "github.com/kubeflow/common/pkg/apis/common/v1"

pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
commonOp "github.com/kubeflow/tf-operator/pkg/apis/common/v1"
"github.com/stretchr/testify/assert"
corev1 "k8s.io/api/core/v1"
)

// TestExtractMPICurrentCondition exercises ExtractMPICurrentCondition with
// two inputs: a list in which JobCreated is the only condition whose Status is
// ConditionTrue (that condition is expected back), and a nil list (an error
// and a zero-value condition are expected).
func TestExtractMPICurrentCondition(t *testing.T) {
	jobCreated := mpiOp.JobCondition{
		Type:   mpiOp.JobCreated,
		Status: corev1.ConditionTrue,
	}
	jobRunningActive := mpiOp.JobCondition{
		Type:   mpiOp.JobRunning,
		Status: corev1.ConditionFalse,
	}
	jobConditions := []mpiOp.JobCondition{
		jobCreated,
		jobRunningActive,
	}
	currentCondition, err := ExtractMPICurrentCondition(jobConditions)
	assert.NoError(t, err)
	// testify's Equal takes (expected, actual); keeping that order makes the
	// failure output label the two values correctly.
	assert.Equal(t, jobCreated, currentCondition)

	// A nil condition list must be reported as an error, not a panic.
	jobConditions = nil
	currentCondition, err = ExtractMPICurrentCondition(jobConditions)
	assert.Error(t, err)
	assert.Equal(t, mpiOp.JobCondition{}, currentCondition)
	assert.Equal(t, fmt.Errorf("found no current condition. Conditions: %+v", jobConditions), err)
}

func TestExtractCurrentCondition(t *testing.T) {
jobCreated := commonOp.JobCondition{
Type: commonOp.JobCreated,
Expand All @@ -26,6 +55,12 @@ func TestExtractCurrentCondition(t *testing.T) {
currentCondition, err := ExtractCurrentCondition(jobConditions)
assert.NoError(t, err)
assert.Equal(t, currentCondition, jobCreated)

jobConditions = nil
currentCondition, err = ExtractCurrentCondition(jobConditions)
assert.Error(t, err)
assert.Equal(t, currentCondition, commonOp.JobCondition{})
assert.Equal(t, err, fmt.Errorf("found no current condition. Conditions: %+v", jobConditions))
}

func TestGetPhaseInfo(t *testing.T) {
Expand Down Expand Up @@ -64,4 +99,89 @@ func TestGetPhaseInfo(t *testing.T) {
assert.Equal(t, pluginsCore.PhaseRunning, taskPhase.Phase())
assert.NotNil(t, taskPhase.Info())
assert.Nil(t, err)

jobRestarting = commonOp.JobCondition{
Type: commonOp.JobRunning,
}
taskPhase, err = GetPhaseInfo(jobRestarting, time.Now(), pluginsCore.TaskInfo{})
assert.NoError(t, err)
assert.Equal(t, pluginsCore.PhaseRunning, taskPhase.Phase())
assert.NotNil(t, taskPhase.Info())
assert.Nil(t, err)
}

// TestGetMPIPhaseInfo checks the plugin phase that GetMPIPhaseInfo reports for
// each MPI job condition type listed in the table below. Every case also
// requires a nil error and a non-nil phase info.
func TestGetMPIPhaseInfo(t *testing.T) {
	testCases := []struct {
		name      string
		condition mpiOp.JobCondition
		wantPhase pluginsCore.Phase
	}{
		{
			name:      "created",
			condition: mpiOp.JobCondition{Type: mpiOp.JobCreated},
			wantPhase: pluginsCore.PhaseQueued,
		},
		{
			name:      "succeeded",
			condition: mpiOp.JobCondition{Type: mpiOp.JobSucceeded},
			wantPhase: pluginsCore.PhaseSuccess,
		},
		{
			name:      "failed",
			condition: mpiOp.JobCondition{Type: mpiOp.JobFailed},
			wantPhase: pluginsCore.PhaseRetryableFailure,
		},
		{
			name:      "restarting",
			condition: mpiOp.JobCondition{Type: mpiOp.JobRestarting},
			wantPhase: pluginsCore.PhaseRunning,
		},
		{
			name:      "running",
			condition: mpiOp.JobCondition{Type: mpiOp.JobRunning},
			wantPhase: pluginsCore.PhaseRunning,
		},
	}

	for _, tc := range testCases {
		// Subtests run synchronously (no t.Parallel), so capturing tc in the
		// closure is safe on every Go version.
		t.Run(tc.name, func(t *testing.T) {
			taskPhase, err := GetMPIPhaseInfo(tc.condition, time.Now(), pluginsCore.TaskInfo{})
			assert.NoError(t, err)
			assert.Equal(t, tc.wantPhase, taskPhase.Phase())
			assert.NotNil(t, taskPhase.Info())
		})
	}
}

func TestGetLogs(t *testing.T) {
assert.NoError(t, logs.SetLogConfig(&logs.LogConfig{
IsKubernetesEnabled: true,
KubernetesURL: "k8s.com",
}))

workers := int32(1)
launcher := int32(1)

jobLogs, err := GetLogs(MPITaskType, "test", "mpi-namespace", workers, launcher, 0)
assert.NoError(t, err)
assert.Equal(t, 1, len(jobLogs))
assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=mpi-namespace", "mpi-namespace", "test"), jobLogs[0].Uri)

jobLogs, err = GetLogs(PytorchTaskType, "test", "pytorch-namespace", workers, launcher, 0)
assert.NoError(t, err)
assert.Equal(t, 2, len(jobLogs))
assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-master-0/pod?namespace=pytorch-namespace", "pytorch-namespace", "test"), jobLogs[0].Uri)
assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=pytorch-namespace", "pytorch-namespace", "test"), jobLogs[1].Uri)

jobLogs, err = GetLogs(TensorflowTaskType, "test", "tensorflow-namespace", workers, launcher, 1)
assert.NoError(t, err)
assert.Equal(t, 3, len(jobLogs))
assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=tensorflow-namespace", "tensorflow-namespace", "test"), jobLogs[0].Uri)
assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-psReplica-0/pod?namespace=tensorflow-namespace", "tensorflow-namespace", "test"), jobLogs[1].Uri)
assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-chiefReplica-0/pod?namespace=tensorflow-namespace", "tensorflow-namespace", "test"), jobLogs[2].Uri)

}
23 changes: 14 additions & 9 deletions go/tasks/plugins/k8s/kfoperators/mpi/mpi.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,18 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu
if err != nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error())
}
workers := mpiTaskExtraArgs.GetWorkers()
launcherReplicas := mpiTaskExtraArgs.GetLauncherReplicas()
workers := mpiTaskExtraArgs.GetNumWorkers()
launcherReplicas := mpiTaskExtraArgs.GetNumLauncherReplicas()
slots := mpiTaskExtraArgs.GetSlots()

workersPodSpec := podSpec.DeepCopy()
for k := range workersPodSpec.Containers {
workersPodSpec.Containers[k].Args = []string{}
workersPodSpec.Containers[k].Command = []string{}
}

jobSpec := mpi.MPIJobSpec{
SlotsPerWorker: &slots,
//MainContainer: "",
MPIReplicaSpecs: map[mpi.MPIReplicaType]*commonKf.ReplicaSpec{
mpi.MPIReplicaTypeLauncher: &commonKf.ReplicaSpec{
Replicas: &launcherReplicas,
Expand All @@ -79,7 +84,7 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu
mpi.MPIReplicaTypeWorker: &commonKf.ReplicaSpec{
Replicas: &workers,
Template: v1.PodTemplateSpec{
Spec: *podSpec,
Spec: *workersPodSpec,
},
RestartPolicy: commonKf.RestartPolicyNever,
},
Expand All @@ -103,15 +108,15 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu
func (mpiOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) {
app := resource.(*mpi.MPIJob)

workersCount := app.Spec.MPIReplicaSpecs[mpi.MPIReplicaTypeWorker].Replicas
launcherReplicasCount := app.Spec.MPIReplicaSpecs[mpi.MPIReplicaTypeLauncher].Replicas
numWorkers := app.Spec.MPIReplicaSpecs[mpi.MPIReplicaTypeWorker].Replicas
numLauncherReplicas := app.Spec.MPIReplicaSpecs[mpi.MPIReplicaTypeLauncher].Replicas

taskLogs, err := common.GetMPILogs(app.Name, app.Namespace,
*workersCount, *launcherReplicasCount)
taskLogs, err := common.GetLogs(common.MPITaskType, app.Name, app.Namespace,
*numWorkers, *numLauncherReplicas, 0)
if err != nil {
return pluginsCore.PhaseInfoUndefined, err
}
currentCondition, err := common.ExtractMPIMPICurrentCondition(app.Status.Conditions)
currentCondition, err := common.ExtractMPICurrentCondition(app.Status.Conditions)
if err != nil {
return pluginsCore.PhaseInfoUndefined, err
}
Expand Down
8 changes: 4 additions & 4 deletions go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ var (

func dummyMPICustomObj(workers int32, launcher int32, slots int32) *plugins.DistributedMPITrainingTask {
return &plugins.DistributedMPITrainingTask{
Workers: workers,
LauncherReplicas: launcher,
Slots: slots,
NumWorkers: workers,
NumLauncherReplicas: launcher,
Slots: slots,
}
}

Expand Down Expand Up @@ -347,7 +347,7 @@ func TestGetLogs(t *testing.T) {

mpiResourceHandler := mpiOperatorResourceHandler{}
mpiJob := dummyMPIJobResource(mpiResourceHandler, workers, launcher, slots, commonKf.JobRunning)
jobLogs, err := common.GetMPILogs(mpiJob.Name, mpiJob.Namespace, workers, launcher)
jobLogs, err := common.GetLogs(common.MPITaskType, mpiJob.Name, mpiJob.Namespace, workers, launcher, 0)
assert.NoError(t, err)
assert.Equal(t, 2, len(jobLogs))
assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=mpi-namespace", jobNamespace, jobName), jobLogs[0].Uri)
Expand Down

0 comments on commit 45cea17

Please sign in to comment.