diff --git a/go.mod b/go.mod
index 27d940939..7048e3944 100644
--- a/go.mod
+++ b/go.mod
@@ -80,4 +80,4 @@ replace (
 )
 
-replace github.com/flyteorg/flyteidl => github.com/evalsocket/flyteidl v0.18.8-0.20210820144224-2cd6288a159d
+replace github.com/flyteorg/flyteidl => github.com/evalsocket/flyteidl v0.19.26-0.20210827192202-131a810ef743
diff --git a/go.sum b/go.sum
index aae1bca75..00e7fc7a7 100644
--- a/go.sum
+++ b/go.sum
@@ -285,8 +285,8 @@ github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.m
 github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
 github.com/ernesto-jimenez/gogen v0.0.0-20180125220232-d7d4131e6607/go.mod h1:Cg4fM0vhYWOZdgM7RIOSTRNIc8/VT7CXClC3Ni86lu4=
 github.com/euank/go-kmsg-parser v2.0.0+incompatible/go.mod h1:MhmAMZ8V4CYH4ybgdRwPr2TU5ThnS43puaKEMpja1uw=
-github.com/evalsocket/flyteidl v0.18.8-0.20210820144224-2cd6288a159d h1:faDhAzZI2LBWSgVA5bLc5FsN6hngPtToNxjLMPPgx8w=
-github.com/evalsocket/flyteidl v0.18.8-0.20210820144224-2cd6288a159d/go.mod h1:576W2ViEyjTpT+kEVHAGbrTP3HARNUZ/eCwrNPmdx9U=
+github.com/evalsocket/flyteidl v0.19.26-0.20210827192202-131a810ef743 h1:c0TgBrZkbs9xzz1UnqHP8bGkAnInKOxQOTSxkexIfFk=
+github.com/evalsocket/flyteidl v0.19.26-0.20210827192202-131a810ef743/go.mod h1:576W2ViEyjTpT+kEVHAGbrTP3HARNUZ/eCwrNPmdx9U=
 github.com/evanphx/json-patch v0.0.0-20180908160633-36442dbdb585/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk=
 github.com/evanphx/json-patch v0.0.0-20190203023257-5858425f7550/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk=
 github.com/evanphx/json-patch v4.2.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk=
diff --git a/go/tasks/plugins/k8s/kfoperators/common/common_operator.go b/go/tasks/plugins/k8s/kfoperators/common/common_operator.go
index c94b32dff..b1a468189 100644
--- a/go/tasks/plugins/k8s/kfoperators/common/common_operator.go
+++ b/go/tasks/plugins/k8s/kfoperators/common/common_operator.go
@@ -18,11 +18,11 @@ import (
 const (
 	TensorflowTaskType = "tensorflow"
-	MPITaskType = "mpi"
+	MPITaskType        = "mpi"
 	PytorchTaskType    = "pytorch"
 )
 
-func ExtractMPIMPICurrentCondition(jobConditions []mpiOp.JobCondition) (mpiOp.JobCondition, error) {
+func ExtractMPICurrentCondition(jobConditions []mpiOp.JobCondition) (mpiOp.JobCondition, error) {
 	if jobConditions != nil {
 		sort.Slice(jobConditions, func(i, j int) bool {
 			return jobConditions[i].LastTransitionTime.Time.After(jobConditions[j].LastTransitionTime.Time)
@@ -131,56 +131,30 @@ func GetLogs(taskType string, name string, namespace string,
 		}
 		taskLogs = append(taskLogs, workerLog.TaskLogs...)
 	}
-	// get all parameter servers logs
-	for psReplicaIndex := int32(0); psReplicaIndex < psReplicasCount; psReplicaIndex++ {
-		psReplicaLog, err := logPlugin.GetTaskLogs(tasklog.Input{
-			PodName:   name + fmt.Sprintf("-psReplica-%d", psReplicaIndex),
-			Namespace: namespace,
-		})
-		if err != nil {
-			return nil, err
-		}
-		taskLogs = append(taskLogs, psReplicaLog.TaskLogs...)
-	}
-	// get chief worker log, and the max number of chief worker is 1
-	if chiefReplicasCount != 0 {
-		chiefReplicaLog, err := logPlugin.GetTaskLogs(tasklog.Input{
-			PodName:   name + fmt.Sprintf("-chiefReplica-%d", 0),
-			Namespace: namespace,
-		})
-		if err != nil {
-			return nil, err
-		}
-		taskLogs = append(taskLogs, chiefReplicaLog.TaskLogs...)
-	}
-
-	return taskLogs, nil
-}
-
-func GetMPILogs(name string, namespace string,
-	workersCount int32, launcherReplicasCount int32) ([]*core.TaskLog, error) {
-	taskLogs := make([]*core.TaskLog, 0, 10)
-
-	logPlugin, err := logs.InitializeLogPlugins(logs.GetLogConfig())
-	if err != nil {
-		return nil, err
-	}
-
-	if logPlugin == nil {
-		return nil, nil
-	}
-
-	// get all workers log
-	for workerIndex := int32(0); workerIndex < workersCount; workerIndex++ {
-		workerLog, err := logPlugin.GetTaskLogs(tasklog.Input{
-			PodName:   name + fmt.Sprintf("-worker-%d", workerIndex),
-			Namespace: namespace,
-		})
-		if err != nil {
-			return nil, err
+	if taskType != MPITaskType && taskType != PytorchTaskType {
+		// get all parameter servers logs
+		for psReplicaIndex := int32(0); psReplicaIndex < psReplicasCount; psReplicaIndex++ {
+			psReplicaLog, err := logPlugin.GetTaskLogs(tasklog.Input{
+				PodName:   name + fmt.Sprintf("-psReplica-%d", psReplicaIndex),
+				Namespace: namespace,
+			})
+			if err != nil {
+				return nil, err
+			}
+			taskLogs = append(taskLogs, psReplicaLog.TaskLogs...)
+		}
+		// get chief worker log, and the max number of chief worker is 1
+		if chiefReplicasCount != 0 {
+			chiefReplicaLog, err := logPlugin.GetTaskLogs(tasklog.Input{
+				PodName:   name + fmt.Sprintf("-chiefReplica-%d", 0),
+				Namespace: namespace,
+			})
+			if err != nil {
+				return nil, err
+			}
+			taskLogs = append(taskLogs, chiefReplicaLog.TaskLogs...)
 		}
-		taskLogs = append(taskLogs, workerLog.TaskLogs...)
 	}
 
 	return taskLogs, nil
diff --git a/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go b/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go
index 58b985d9e..776c3f644 100644
--- a/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go
+++ b/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go
@@ -1,15 +1,44 @@
 package common
 
 import (
+	"fmt"
 	"testing"
 	"time"
 
+	"github.com/flyteorg/flyteplugins/go/tasks/logs"
+
+	mpiOp "github.com/kubeflow/common/pkg/apis/common/v1"
+
 	pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
 	commonOp "github.com/kubeflow/tf-operator/pkg/apis/common/v1"
 	"github.com/stretchr/testify/assert"
 	corev1 "k8s.io/api/core/v1"
 )
 
+func TestExtractMPICurrentCondition(t *testing.T) {
+	jobCreated := mpiOp.JobCondition{
+		Type:   mpiOp.JobCreated,
+		Status: corev1.ConditionTrue,
+	}
+	jobRunningActive := mpiOp.JobCondition{
+		Type:   mpiOp.JobRunning,
+		Status: corev1.ConditionFalse,
+	}
+	jobConditions := []mpiOp.JobCondition{
+		jobCreated,
+		jobRunningActive,
+	}
+	currentCondition, err := ExtractMPICurrentCondition(jobConditions)
+	assert.NoError(t, err)
+	assert.Equal(t, currentCondition, jobCreated)
+
+	jobConditions = nil
+	currentCondition, err = ExtractMPICurrentCondition(jobConditions)
+	assert.Error(t, err)
+	assert.Equal(t, currentCondition, mpiOp.JobCondition{})
+	assert.Equal(t, err, fmt.Errorf("found no current condition. Conditions: %+v", jobConditions))
+}
+
 func TestExtractCurrentCondition(t *testing.T) {
 	jobCreated := commonOp.JobCondition{
 		Type:   commonOp.JobCreated,
@@ -26,6 +55,12 @@ func TestExtractCurrentCondition(t *testing.T) {
 	currentCondition, err := ExtractCurrentCondition(jobConditions)
 	assert.NoError(t, err)
 	assert.Equal(t, currentCondition, jobCreated)
+
+	jobConditions = nil
+	currentCondition, err = ExtractCurrentCondition(jobConditions)
+	assert.Error(t, err)
+	assert.Equal(t, currentCondition, commonOp.JobCondition{})
+	assert.Equal(t, err, fmt.Errorf("found no current condition. Conditions: %+v", jobConditions))
 }
 
 func TestGetPhaseInfo(t *testing.T) {
@@ -64,4 +99,89 @@ func TestGetPhaseInfo(t *testing.T) {
 	assert.Equal(t, pluginsCore.PhaseRunning, taskPhase.Phase())
 	assert.NotNil(t, taskPhase.Info())
 	assert.Nil(t, err)
+
+	jobRestarting = commonOp.JobCondition{
+		Type: commonOp.JobRunning,
+	}
+	taskPhase, err = GetPhaseInfo(jobRestarting, time.Now(), pluginsCore.TaskInfo{})
+	assert.NoError(t, err)
+	assert.Equal(t, pluginsCore.PhaseRunning, taskPhase.Phase())
+	assert.NotNil(t, taskPhase.Info())
+	assert.Nil(t, err)
+}
+
+func TestGetMPIPhaseInfo(t *testing.T) {
+	jobCreated := mpiOp.JobCondition{
+		Type: mpiOp.JobCreated,
+	}
+	taskPhase, err := GetMPIPhaseInfo(jobCreated, time.Now(), pluginsCore.TaskInfo{})
+	assert.NoError(t, err)
+	assert.Equal(t, pluginsCore.PhaseQueued, taskPhase.Phase())
+	assert.NotNil(t, taskPhase.Info())
+	assert.Nil(t, err)
+
+	jobSucceeded := mpiOp.JobCondition{
+		Type: mpiOp.JobSucceeded,
+	}
+	taskPhase, err = GetMPIPhaseInfo(jobSucceeded, time.Now(), pluginsCore.TaskInfo{})
+	assert.NoError(t, err)
+	assert.Equal(t, pluginsCore.PhaseSuccess, taskPhase.Phase())
+	assert.NotNil(t, taskPhase.Info())
+	assert.Nil(t, err)
+
+	jobFailed := mpiOp.JobCondition{
+		Type: mpiOp.JobFailed,
+	}
+	taskPhase, err = GetMPIPhaseInfo(jobFailed, time.Now(), pluginsCore.TaskInfo{})
+	assert.NoError(t, err)
+	assert.Equal(t, pluginsCore.PhaseRetryableFailure, taskPhase.Phase())
+	assert.NotNil(t, taskPhase.Info())
+	assert.Nil(t, err)
+
+	jobRestarting := mpiOp.JobCondition{
+		Type: mpiOp.JobRestarting,
+	}
+	taskPhase, err = GetMPIPhaseInfo(jobRestarting, time.Now(), pluginsCore.TaskInfo{})
+	assert.NoError(t, err)
+	assert.Equal(t, pluginsCore.PhaseRunning, taskPhase.Phase())
+	assert.NotNil(t, taskPhase.Info())
+	assert.Nil(t, err)
+
+	jobRestarting = mpiOp.JobCondition{
+		Type: mpiOp.JobRunning,
+	}
+	taskPhase, err = GetMPIPhaseInfo(jobRestarting, time.Now(), pluginsCore.TaskInfo{})
+	assert.NoError(t, err)
+	assert.Equal(t, pluginsCore.PhaseRunning, taskPhase.Phase())
+	assert.NotNil(t, taskPhase.Info())
+	assert.Nil(t, err)
+}
+
+func TestGetLogs(t *testing.T) {
+	assert.NoError(t, logs.SetLogConfig(&logs.LogConfig{
+		IsKubernetesEnabled: true,
+		KubernetesURL:       "k8s.com",
+	}))
+
+	workers := int32(1)
+	launcher := int32(1)
+
+	jobLogs, err := GetLogs(MPITaskType, "test", "mpi-namespace", workers, launcher, 0)
+	assert.NoError(t, err)
+	assert.Equal(t, 1, len(jobLogs))
+	assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=mpi-namespace", "mpi-namespace", "test"), jobLogs[0].Uri)
+
+	jobLogs, err = GetLogs(PytorchTaskType, "test", "pytorch-namespace", workers, launcher, 0)
+	assert.NoError(t, err)
+	assert.Equal(t, 2, len(jobLogs))
+	assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-master-0/pod?namespace=pytorch-namespace", "pytorch-namespace", "test"), jobLogs[0].Uri)
+	assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=pytorch-namespace", "pytorch-namespace", "test"), jobLogs[1].Uri)
+
+	jobLogs, err = GetLogs(TensorflowTaskType, "test", "tensorflow-namespace", workers, launcher, 1)
+	assert.NoError(t, err)
+	assert.Equal(t, 3, len(jobLogs))
+	assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=tensorflow-namespace", "tensorflow-namespace", "test"), jobLogs[0].Uri)
+	assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-psReplica-0/pod?namespace=tensorflow-namespace", "tensorflow-namespace", "test"), jobLogs[1].Uri)
+	assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-chiefReplica-0/pod?namespace=tensorflow-namespace", "tensorflow-namespace", "test"), jobLogs[2].Uri)
+}
diff --git a/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go b/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go
index b42ca9656..553e4169e 100644
--- a/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go
+++ b/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go
@@ -61,13 +61,18 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu
 	if err != nil {
 		return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error())
 	}
-	workers := mpiTaskExtraArgs.GetWorkers()
-	launcherReplicas := mpiTaskExtraArgs.GetLauncherReplicas()
+	workers := mpiTaskExtraArgs.GetNumWorkers()
+	launcherReplicas := mpiTaskExtraArgs.GetNumLauncherReplicas()
 	slots := mpiTaskExtraArgs.GetSlots()
 
+	workersPodSpec := podSpec.DeepCopy()
+	for k := range workersPodSpec.Containers {
+		workersPodSpec.Containers[k].Args = []string{}
+		workersPodSpec.Containers[k].Command = []string{}
+	}
+
 	jobSpec := mpi.MPIJobSpec{
 		SlotsPerWorker: &slots,
-		//MainContainer: "",
 		MPIReplicaSpecs: map[mpi.MPIReplicaType]*commonKf.ReplicaSpec{
 			mpi.MPIReplicaTypeLauncher: &commonKf.ReplicaSpec{
 				Replicas: &launcherReplicas,
@@ -79,7 +84,7 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu
 			mpi.MPIReplicaTypeWorker: &commonKf.ReplicaSpec{
 				Replicas: &workers,
 				Template: v1.PodTemplateSpec{
-					Spec: *podSpec,
+					Spec: *workersPodSpec,
 				},
 				RestartPolicy: commonKf.RestartPolicyNever,
 			},
@@ -103,15 +108,15 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu
 func (mpiOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) {
 	app := resource.(*mpi.MPIJob)
 
-	workersCount := app.Spec.MPIReplicaSpecs[mpi.MPIReplicaTypeWorker].Replicas
-	launcherReplicasCount := app.Spec.MPIReplicaSpecs[mpi.MPIReplicaTypeLauncher].Replicas
+	numWorkers := app.Spec.MPIReplicaSpecs[mpi.MPIReplicaTypeWorker].Replicas
+	numLauncherReplicas := app.Spec.MPIReplicaSpecs[mpi.MPIReplicaTypeLauncher].Replicas
 
-	taskLogs, err := common.GetMPILogs(app.Name, app.Namespace,
-		*workersCount, *launcherReplicasCount)
+	taskLogs, err := common.GetLogs(common.MPITaskType, app.Name, app.Namespace,
+		*numWorkers, *numLauncherReplicas, 0)
 	if err != nil {
 		return pluginsCore.PhaseInfoUndefined, err
 	}
-	currentCondition, err := common.ExtractMPIMPICurrentCondition(app.Status.Conditions)
+	currentCondition, err := common.ExtractMPICurrentCondition(app.Status.Conditions)
 	if err != nil {
 		return pluginsCore.PhaseInfoUndefined, err
 	}
diff --git a/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go b/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go
index 2186a5ebd..405f6dc21 100644
--- a/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go
+++ b/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go
@@ -66,9 +66,9 @@ var (
 
 func dummyMPICustomObj(workers int32, launcher int32, slots int32) *plugins.DistributedMPITrainingTask {
 	return &plugins.DistributedMPITrainingTask{
-		Workers:          workers,
-		LauncherReplicas: launcher,
-		Slots:            slots,
+		NumWorkers:          workers,
+		NumLauncherReplicas: launcher,
+		Slots:               slots,
 	}
 }
 
@@ -347,7 +347,7 @@ func TestGetLogs(t *testing.T) {
 	mpiResourceHandler := mpiOperatorResourceHandler{}
 	mpiJob := dummyMPIJobResource(mpiResourceHandler, workers, launcher, slots, commonKf.JobRunning)
 
-	jobLogs, err := common.GetMPILogs(mpiJob.Name, mpiJob.Namespace, workers, launcher)
+	jobLogs, err := common.GetLogs(common.MPITaskType, mpiJob.Name, mpiJob.Namespace, workers, launcher, 0)
 	assert.NoError(t, err)
 	assert.Equal(t, 2, len(jobLogs))
 	assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=mpi-namespace", jobNamespace, jobName), jobLogs[0].Uri)