From a0774939e8aaaac567bbc1cd8da4215654b79674 Mon Sep 17 00:00:00 2001 From: Yuvraj Date: Thu, 26 Aug 2021 19:13:24 +0530 Subject: [PATCH] Fix test and lint Signed-off-by: Yuvraj --- .../k8s/kfoperators/common/common_operator.go | 72 ++++++------------- go/tasks/plugins/k8s/kfoperators/mpi/mpi.go | 4 +- .../plugins/k8s/kfoperators/mpi/mpi_test.go | 2 +- 3 files changed, 26 insertions(+), 52 deletions(-) diff --git a/go/tasks/plugins/k8s/kfoperators/common/common_operator.go b/go/tasks/plugins/k8s/kfoperators/common/common_operator.go index c94b32dff..8567ce91c 100644 --- a/go/tasks/plugins/k8s/kfoperators/common/common_operator.go +++ b/go/tasks/plugins/k8s/kfoperators/common/common_operator.go @@ -18,7 +18,7 @@ import ( const ( TensorflowTaskType = "tensorflow" - MPITaskType = "mpi" + MPITaskType = "mpi" PytorchTaskType = "pytorch" ) @@ -131,56 +131,30 @@ func GetLogs(taskType string, name string, namespace string, } taskLogs = append(taskLogs, workerLog.TaskLogs...) } - // get all parameter servers logs - for psReplicaIndex := int32(0); psReplicaIndex < psReplicasCount; psReplicaIndex++ { - psReplicaLog, err := logPlugin.GetTaskLogs(tasklog.Input{ - PodName: name + fmt.Sprintf("-psReplica-%d", psReplicaIndex), - Namespace: namespace, - }) - if err != nil { - return nil, err - } - taskLogs = append(taskLogs, psReplicaLog.TaskLogs...) - } - // get chief worker log, and the max number of chief worker is 1 - if chiefReplicasCount != 0 { - chiefReplicaLog, err := logPlugin.GetTaskLogs(tasklog.Input{ - PodName: name + fmt.Sprintf("-chiefReplica-%d", 0), - Namespace: namespace, - }) - if err != nil { - return nil, err - } - taskLogs = append(taskLogs, chiefReplicaLog.TaskLogs...) - } - - return taskLogs, nil -} - -func GetMPILogs(name string, namespace string, - workersCount int32, launcherReplicasCount int32) ([]*core.TaskLog, error) { - taskLogs := make([]*core.TaskLog, 0, 10) - - logPlugin, err := logs.InitializeLogPlugins(logs.GetLogConfig()) - if err != nil { - return nil, err - } - - if logPlugin == nil { - return nil, nil - } - - // get all workers log - for workerIndex := int32(0); workerIndex < workersCount; workerIndex++ { - workerLog, err := logPlugin.GetTaskLogs(tasklog.Input{ - PodName: name + fmt.Sprintf("-worker-%d", workerIndex), - Namespace: namespace, - }) - if err != nil { - return nil, err + if taskType != MPITaskType { + // get all parameter servers logs + for psReplicaIndex := int32(0); psReplicaIndex < psReplicasCount; psReplicaIndex++ { + psReplicaLog, err := logPlugin.GetTaskLogs(tasklog.Input{ + PodName: name + fmt.Sprintf("-psReplica-%d", psReplicaIndex), + Namespace: namespace, + }) + if err != nil { + return nil, err + } + taskLogs = append(taskLogs, psReplicaLog.TaskLogs...) + } + // get chief worker log, and the max number of chief worker is 1 + if chiefReplicasCount != 0 { + chiefReplicaLog, err := logPlugin.GetTaskLogs(tasklog.Input{ + PodName: name + fmt.Sprintf("-chiefReplica-%d", 0), + Namespace: namespace, + }) + if err != nil { + return nil, err + } + taskLogs = append(taskLogs, chiefReplicaLog.TaskLogs...) } - taskLogs = append(taskLogs, workerLog.TaskLogs...) } return taskLogs, nil diff --git a/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go b/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go index b42ca9656..0cc7e8e0c 100644 --- a/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go +++ b/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go @@ -106,8 +106,8 @@ func (mpiOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginContext workersCount := app.Spec.MPIReplicaSpecs[mpi.MPIReplicaTypeWorker].Replicas launcherReplicasCount := app.Spec.MPIReplicaSpecs[mpi.MPIReplicaTypeLauncher].Replicas - taskLogs, err := common.GetMPILogs(app.Name, app.Namespace, - *workersCount, *launcherReplicasCount) + taskLogs, err := common.GetLogs(common.MPITaskType, app.Name, app.Namespace, + *workersCount, *launcherReplicasCount, 0) if err != nil { return pluginsCore.PhaseInfoUndefined, err } diff --git a/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go b/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go index 2186a5ebd..970c91a72 100644 --- a/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go +++ b/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go @@ -347,7 +347,7 @@ func TestGetLogs(t *testing.T) { mpiResourceHandler := mpiOperatorResourceHandler{} mpiJob := dummyMPIJobResource(mpiResourceHandler, workers, launcher, slots, commonKf.JobRunning) - jobLogs, err := common.GetMPILogs(mpiJob.Name, mpiJob.Namespace, workers, launcher) + jobLogs, err := common.GetLogs(common.MPITaskType, mpiJob.Name, mpiJob.Namespace, workers, launcher, 0) assert.NoError(t, err) assert.Equal(t, 2, len(jobLogs)) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=mpi-namespace", jobNamespace, jobName), jobLogs[0].Uri)