Skip to content

Commit

Permalink
cmd changes
Browse files Browse the repository at this point in the history
Signed-off-by: Yuvraj <[email protected]>
  • Loading branch information
yindia committed Oct 26, 2021
1 parent 2e43df6 commit a06b88f
Showing 1 changed file with 2 additions and 10 deletions.
12 changes: 2 additions & 10 deletions plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from flytekit import PythonFunctionTask
from flytekit.extend import SerializationSettings, TaskPlugins
from flytekit.models import common as _common
from flytekit.models import task as _task_model


class MPIJobModel(_common.FlyteIdlEntity):
Expand Down Expand Up @@ -82,8 +81,6 @@ class MPIJob(object):
slots: int
num_launcher_replicas: int = 1
num_workers: int = 1
per_replica_requests: Optional[_task_model.Resources] = None
per_replica_limits: Optional[_task_model.Resources] = None


class MPIFunctionTask(PythonFunctionTask[MPIJob]):
Expand Down Expand Up @@ -119,18 +116,13 @@ def __init__(self, task_config: MPIJob, task_function: Callable, **kwargs):
task_config=task_config,
task_function=task_function,
task_type=self._MPI_JOB_TASK_TYPE,
**{**kwargs, "requests": task_config.per_replica_requests, "limits": task_config.per_replica_limits},
**kwargs,
)

def get_command(self, settings: SerializationSettings) -> List[str]:
cmd = super().get_command(settings)
num_procs = self.task_config.num_workers * self.task_config.slots
mpi_cmd = (
self._MPI_BASE_COMMAND
+ ["-np", f"{num_procs}"]
+ ["python", settings.entrypoint_settings.path]
+ cmd[1:]
)
mpi_cmd = self._MPI_BASE_COMMAND + ["-np", f"{num_procs}"] + ["python", settings.entrypoint_settings.path] + cmd
# the hostfile is set automatically by MPIOperator using env variable OMPI_MCA_orte_default_hostfile
return mpi_cmd

Expand Down

0 comments on commit a06b88f

Please sign in to comment.