Skip to content

Commit

Permalink
Don't add master replica log link when doing elastic pytorch training (
Browse files Browse the repository at this point in the history
…flyteorg#356)

* Don't add master log link when doing elastic pytorch training

Signed-off-by: Fabio Graetz <[email protected]>

* Lint

Signed-off-by: Fabio Graetz <[email protected]>

---------

Signed-off-by: Fabio Graetz <[email protected]>
  • Loading branch information
fg91 authored Jun 7, 2023
1 parent 7a28eb2 commit 0a239c4
Show file tree
Hide file tree
Showing 8 changed files with 36 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func GetMPIPhaseInfo(currentCondition commonOp.JobCondition, occurredAt time.Tim
}

// GetLogs will return the logs for kubeflow job
func GetLogs(taskType string, name string, namespace string,
func GetLogs(taskType string, name string, namespace string, hasMaster bool,
workersCount int32, psReplicasCount int32, chiefReplicasCount int32) ([]*core.TaskLog, error) {
taskLogs := make([]*core.TaskLog, 0, 10)

Expand All @@ -118,7 +118,7 @@ func GetLogs(taskType string, name string, namespace string,
return nil, nil
}

if taskType == PytorchTaskType {
if taskType == PytorchTaskType && hasMaster {
masterTaskLog, masterErr := logPlugin.GetTaskLogs(
tasklog.Input{
PodName: name + "-master-0",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,18 +167,18 @@ func TestGetLogs(t *testing.T) {
workers := int32(1)
launcher := int32(1)

jobLogs, err := GetLogs(MPITaskType, "test", "mpi-namespace", workers, launcher, 0)
jobLogs, err := GetLogs(MPITaskType, "test", "mpi-namespace", false, workers, launcher, 0)
assert.NoError(t, err)
assert.Equal(t, 1, len(jobLogs))
assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=mpi-namespace", "mpi-namespace", "test"), jobLogs[0].Uri)

jobLogs, err = GetLogs(PytorchTaskType, "test", "pytorch-namespace", workers, launcher, 0)
jobLogs, err = GetLogs(PytorchTaskType, "test", "pytorch-namespace", true, workers, launcher, 0)
assert.NoError(t, err)
assert.Equal(t, 2, len(jobLogs))
assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-master-0/pod?namespace=pytorch-namespace", "pytorch-namespace", "test"), jobLogs[0].Uri)
assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=pytorch-namespace", "pytorch-namespace", "test"), jobLogs[1].Uri)

jobLogs, err = GetLogs(TensorflowTaskType, "test", "tensorflow-namespace", workers, launcher, 1)
jobLogs, err = GetLogs(TensorflowTaskType, "test", "tensorflow-namespace", false, workers, launcher, 1)
assert.NoError(t, err)
assert.Equal(t, 3, len(jobLogs))
assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=tensorflow-namespace", "tensorflow-namespace", "test"), jobLogs[0].Uri)
Expand Down
2 changes: 1 addition & 1 deletion flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ func (mpiOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginContext
numWorkers = app.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Replicas
numLauncherReplicas = app.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher].Replicas

taskLogs, err := common.GetLogs(common.MPITaskType, app.Name, app.Namespace,
taskLogs, err := common.GetLogs(common.MPITaskType, app.Name, app.Namespace, false,
*numWorkers, *numLauncherReplicas, 0)
if err != nil {
return pluginsCore.PhaseInfoUndefined, err
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ func TestGetLogs(t *testing.T) {

mpiResourceHandler := mpiOperatorResourceHandler{}
mpiJob := dummyMPIJobResource(mpiResourceHandler, workers, launcher, slots, mpiOp.JobRunning)
jobLogs, err := common.GetLogs(common.MPITaskType, mpiJob.Name, mpiJob.Namespace, workers, launcher, 0)
jobLogs, err := common.GetLogs(common.MPITaskType, mpiJob.Name, mpiJob.Namespace, false, workers, launcher, 0)
assert.NoError(t, err)
assert.Equal(t, 2, len(jobLogs))
assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=mpi-namespace", jobNamespace, jobName), jobLogs[0].Uri)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,15 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx
func (pytorchOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) {
app := resource.(*kubeflowv1.PyTorchJob)

// Elastic PytorchJobs don't use master replicas
hasMaster := false
if _, ok := app.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster]; ok {
hasMaster = true
}

workersCount := app.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas

taskLogs, err := common.GetLogs(common.PytorchTaskType, app.Name, app.Namespace, *workersCount, 0, 0)
taskLogs, err := common.GetLogs(common.PytorchTaskType, app.Name, app.Namespace, hasMaster, *workersCount, 0, 0)
if err != nil {
return pluginsCore.PhaseInfoUndefined, err
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -416,18 +416,37 @@ func TestGetLogs(t *testing.T) {
KubernetesURL: "k8s.com",
}))

hasMaster := true
workers := int32(2)

pytorchResourceHandler := pytorchOperatorResourceHandler{}
pytorchJob := dummyPytorchJobResource(pytorchResourceHandler, workers, commonOp.JobRunning)
jobLogs, err := common.GetLogs(common.PytorchTaskType, pytorchJob.Name, pytorchJob.Namespace, workers, 0, 0)
jobLogs, err := common.GetLogs(common.PytorchTaskType, pytorchJob.Name, pytorchJob.Namespace, hasMaster, workers, 0, 0)
assert.NoError(t, err)
assert.Equal(t, 3, len(jobLogs))
assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-master-0/pod?namespace=pytorch-namespace", jobNamespace, jobName), jobLogs[0].Uri)
assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=pytorch-namespace", jobNamespace, jobName), jobLogs[1].Uri)
assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-1/pod?namespace=pytorch-namespace", jobNamespace, jobName), jobLogs[2].Uri)
}

func TestGetLogsElastic(t *testing.T) {
assert.NoError(t, logs.SetLogConfig(&logs.LogConfig{
IsKubernetesEnabled: true,
KubernetesURL: "k8s.com",
}))

hasMaster := false
workers := int32(2)

pytorchResourceHandler := pytorchOperatorResourceHandler{}
pytorchJob := dummyPytorchJobResource(pytorchResourceHandler, workers, commonOp.JobRunning)
jobLogs, err := common.GetLogs(common.PytorchTaskType, pytorchJob.Name, pytorchJob.Namespace, hasMaster, workers, 0, 0)
assert.NoError(t, err)
assert.Equal(t, 2, len(jobLogs))
assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=pytorch-namespace", jobNamespace, jobName), jobLogs[0].Uri)
assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-1/pod?namespace=pytorch-namespace", jobNamespace, jobName), jobLogs[1].Uri)
}

func TestGetProperties(t *testing.T) {
pytorchResourceHandler := pytorchOperatorResourceHandler{}
expected := k8s.PluginProperties{}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ func (tensorflowOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginC
psReplicasCount := app.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypePS].Replicas
chiefCount := app.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeChief].Replicas

taskLogs, err := common.GetLogs(common.TensorflowTaskType, app.Name, app.Namespace,
taskLogs, err := common.GetLogs(common.TensorflowTaskType, app.Name, app.Namespace, false,
*workersCount, *psReplicasCount, *chiefCount)
if err != nil {
return pluginsCore.PhaseInfoUndefined, err
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ func TestGetLogs(t *testing.T) {

tensorflowResourceHandler := tensorflowOperatorResourceHandler{}
tensorFlowJob := dummyTensorFlowJobResource(tensorflowResourceHandler, workers, psReplicas, chiefReplicas, commonOp.JobRunning)
jobLogs, err := common.GetLogs(common.TensorflowTaskType, tensorFlowJob.Name, tensorFlowJob.Namespace,
jobLogs, err := common.GetLogs(common.TensorflowTaskType, tensorFlowJob.Name, tensorFlowJob.Namespace, false,
workers, psReplicas, chiefReplicas)
assert.NoError(t, err)
assert.Equal(t, 4, len(jobLogs))
Expand Down

0 comments on commit 0a239c4

Please sign in to comment.