Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Fix test and lint
Browse files Browse the repository at this point in the history
Signed-off-by: Yuvraj <[email protected]>
  • Loading branch information
yindia committed Aug 26, 2021
1 parent 92d9b25 commit a077493
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 52 deletions.
72 changes: 23 additions & 49 deletions go/tasks/plugins/k8s/kfoperators/common/common_operator.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (

const (
TensorflowTaskType = "tensorflow"
MPITaskType = "mpi"
MPITaskType = "mpi"
PytorchTaskType = "pytorch"
)

Expand Down Expand Up @@ -131,56 +131,30 @@ func GetLogs(taskType string, name string, namespace string,
}
taskLogs = append(taskLogs, workerLog.TaskLogs...)
}
// get all parameter servers logs
for psReplicaIndex := int32(0); psReplicaIndex < psReplicasCount; psReplicaIndex++ {
psReplicaLog, err := logPlugin.GetTaskLogs(tasklog.Input{
PodName: name + fmt.Sprintf("-psReplica-%d", psReplicaIndex),
Namespace: namespace,
})
if err != nil {
return nil, err
}
taskLogs = append(taskLogs, psReplicaLog.TaskLogs...)
}
// get chief worker log, and the max number of chief worker is 1
if chiefReplicasCount != 0 {
chiefReplicaLog, err := logPlugin.GetTaskLogs(tasklog.Input{
PodName: name + fmt.Sprintf("-chiefReplica-%d", 0),
Namespace: namespace,
})
if err != nil {
return nil, err
}
taskLogs = append(taskLogs, chiefReplicaLog.TaskLogs...)
}

return taskLogs, nil
}

func GetMPILogs(name string, namespace string,
workersCount int32, launcherReplicasCount int32) ([]*core.TaskLog, error) {
taskLogs := make([]*core.TaskLog, 0, 10)

logPlugin, err := logs.InitializeLogPlugins(logs.GetLogConfig())

if err != nil {
return nil, err
}

if logPlugin == nil {
return nil, nil
}

// get all workers log
for workerIndex := int32(0); workerIndex < workersCount; workerIndex++ {
workerLog, err := logPlugin.GetTaskLogs(tasklog.Input{
PodName: name + fmt.Sprintf("-worker-%d", workerIndex),
Namespace: namespace,
})
if err != nil {
return nil, err
if taskType != MPITaskType {
// get all parameter servers logs
for psReplicaIndex := int32(0); psReplicaIndex < psReplicasCount; psReplicaIndex++ {
psReplicaLog, err := logPlugin.GetTaskLogs(tasklog.Input{
PodName: name + fmt.Sprintf("-psReplica-%d", psReplicaIndex),
Namespace: namespace,
})
if err != nil {
return nil, err
}
taskLogs = append(taskLogs, psReplicaLog.TaskLogs...)
}
// get chief worker log, and the max number of chief worker is 1
if chiefReplicasCount != 0 {
chiefReplicaLog, err := logPlugin.GetTaskLogs(tasklog.Input{
PodName: name + fmt.Sprintf("-chiefReplica-%d", 0),
Namespace: namespace,
})
if err != nil {
return nil, err
}
taskLogs = append(taskLogs, chiefReplicaLog.TaskLogs...)
}
taskLogs = append(taskLogs, workerLog.TaskLogs...)
}

return taskLogs, nil
Expand Down
4 changes: 2 additions & 2 deletions go/tasks/plugins/k8s/kfoperators/mpi/mpi.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ func (mpiOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginContext
workersCount := app.Spec.MPIReplicaSpecs[mpi.MPIReplicaTypeWorker].Replicas
launcherReplicasCount := app.Spec.MPIReplicaSpecs[mpi.MPIReplicaTypeLauncher].Replicas

taskLogs, err := common.GetMPILogs(app.Name, app.Namespace,
*workersCount, *launcherReplicasCount)
taskLogs, err := common.GetLogs(common.MPITaskType, app.Name, app.Namespace,
*workersCount, *launcherReplicasCount, 0)
if err != nil {
return pluginsCore.PhaseInfoUndefined, err
}
Expand Down
2 changes: 1 addition & 1 deletion go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ func TestGetLogs(t *testing.T) {

mpiResourceHandler := mpiOperatorResourceHandler{}
mpiJob := dummyMPIJobResource(mpiResourceHandler, workers, launcher, slots, commonKf.JobRunning)
jobLogs, err := common.GetMPILogs(mpiJob.Name, mpiJob.Namespace, workers, launcher)
jobLogs, err := common.GetLogs(common.MPITaskType, mpiJob.Name, mpiJob.Namespace, workers, launcher, 0)
assert.NoError(t, err)
assert.Equal(t, 2, len(jobLogs))
assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=mpi-namespace", jobNamespace, jobName), jobLogs[0].Uri)
Expand Down

0 comments on commit a077493

Please sign in to comment.