Improving Hydra+DDP support (#11617)
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: rohitgr7 <[email protected]>
Co-authored-by: Jirka <[email protected]>
4 people committed Sep 24, 2022
1 parent 2d6f70f commit 5163df6
Showing 2 changed files with 213 additions and 37 deletions.
89 changes: 52 additions & 37 deletions src/pytorch_lightning/strategies/launchers/subprocess_script.py
@@ -15,7 +15,7 @@
import subprocess
import sys
from time import sleep
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Sequence

import __main__
import numpy as np
@@ -25,7 +25,7 @@
from lightning_lite.plugins import ClusterEnvironment
from lightning_lite.strategies.launchers.base import _Launcher

_HYDRA_AVAILABLE = RequirementCache("hydra")
_HYDRA_AVAILABLE = RequirementCache("hydra-core")


class _SubprocessScriptLauncher(_Launcher):
@@ -101,32 +101,6 @@ def _call_children_scripts(self) -> None:
        # allow the user to pass the node rank
        os.environ["NODE_RANK"] = str(self.cluster_environment.node_rank())
        os.environ["LOCAL_RANK"] = str(self.cluster_environment.local_rank())

        # Check if the current calling command looked like `python a/b/c.py` or `python -m a.b.c`
        # See https://docs.python.org/3/reference/import.html#main-spec
        if __main__.__spec__ is None:  # pragma: no-cover
            # Script called as `python a/b/c.py`
            if _HYDRA_AVAILABLE:
                # when user is using hydra find the absolute path
                from hydra.utils import to_absolute_path

                to_abs_path = to_absolute_path
            else:
                to_abs_path = os.path.abspath

            # pull out the commands used to run the script and resolve the absolute file path
            command = sys.argv
            try:
                full_path = to_abs_path(command[0])
            except Exception:
                full_path = os.path.abspath(command[0])

            command[0] = full_path
            # use the same python interpreter and actually running
            command = [sys.executable] + command
        else:  # Script called as `python -m a.b.c`
            command = [sys.executable, "-m", __main__.__spec__.name] + sys.argv[1:]

        os.environ["WORLD_SIZE"] = f"{self.num_processes * self.num_nodes}"

        for local_rank in range(1, self.num_processes):
@@ -137,18 +111,18 @@ def _call_children_scripts(self) -> None:
            if os.environ.get("PL_GLOBAL_SEED") is None and "PL_GLOBAL_SEED" in env_copy:
                del env_copy["PL_GLOBAL_SEED"]

            # start process
            # if hydra is available and initialized, make sure to set the cwd correctly
            cwd: Optional[str] = None
            hydra_in_use = False
            if _HYDRA_AVAILABLE:
                from hydra.core.hydra_config import HydraConfig
                from hydra.utils import get_original_cwd

                if HydraConfig.initialized():
                    cwd = get_original_cwd()
                    os_cwd = f'"{os.getcwd()}"'
                    command += [f"hydra.run.dir={os_cwd}", f"hydra.job.name=train_ddp_process_{local_rank}"]
            subprocess.Popen(command, env=env_copy, cwd=cwd)
                hydra_in_use = HydraConfig.initialized()

            if hydra_in_use:
                command = _hydra_subprocess_cmd(local_rank)
            else:
                command = _basic_subprocess_cmd(local_rank)

            subprocess.Popen(command, env=env_copy)

            # starting all processes at once can cause issues
            # with dataloaders delay between 1-10 seconds
@@ -162,3 +136,44 @@ def _check_can_spawn_children(self) -> None:
" Possible reasons: 1) LOCAL_RANK environment variable was incorrectly modified by the user,"
" 2) `ClusterEnvironment.creates_processes_externally` incorrectly implemented."
)


def _basic_subprocess_cmd(local_rank: int) -> Sequence[str]:
    if __main__.__spec__ is None:  # pragma: no-cover
        return [sys.executable, os.path.abspath(sys.argv[0])] + sys.argv[1:]
    else:
        return [sys.executable, "-m", __main__.__spec__.name] + sys.argv[1:]


def _hydra_subprocess_cmd(local_rank: int) -> Sequence[str]:
    from hydra.core.hydra_config import HydraConfig
    from hydra.utils import to_absolute_path

    # when the user is running under Hydra, resolve the absolute path of the script
    if __main__.__spec__ is None:  # pragma: no-cover
        command = [sys.executable, to_absolute_path(sys.argv[0])]
    else:
        command = [sys.executable, "-m", __main__.__spec__.name]

    # extract the hydra config
    hydra_cfg = HydraConfig.get()

    # the location of the hydra configuration files saved for the current job
    hydra_output = hydra_cfg.runtime.output_dir
    if hydra_cfg.output_subdir is not None:
        hydra_output = os.path.join(hydra_output, hydra_cfg.output_subdir)

    # check if the experimental re-run capability exists,
    # otherwise use the existing config.yaml which may have issues
    pickled_config = os.path.join(hydra_output, "config.pickle")
    if os.path.exists(pickled_config):
        command += ["--experimental-rerun", pickled_config]
    else:
        command += ["-cp", hydra_output, "-cn", "config.yaml"]
        command += [
            f"hydra.output_subdir=.pl_ddp_hydra_{local_rank}",
            f"hydra.run.dir={hydra_cfg.runtime.output_dir}",
        ]

    return command
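
For orientation, and not part of the commit itself: with the helpers above, a worker for a plain (non-Hydra) script is launched with essentially the parent's own command line, with the script path resolved to an absolute path. A minimal sketch, assuming the parent process was started as `python train.py --max_epochs=1` (hypothetical script name and argument):

import os
import sys

# Hypothetical stand-in for the parent process's sys.argv when launched as
# `python train.py --max_epochs=1`; in the real launcher this is the live sys.argv.
argv = ["train.py", "--max_epochs=1"]

# Same construction as _basic_subprocess_cmd uses for the `python a/b/c.py` case.
command = [sys.executable, os.path.abspath(argv[0])] + argv[1:]
print(command)  # e.g. ['/usr/bin/python3', '/abs/path/train.py', '--max_epochs=1']

When Hydra is initialized, `_hydra_subprocess_cmd` instead points each worker at the parent job's saved configuration: it prefers the pickled config via `--experimental-rerun` when Hydra's re-run support has written one, and otherwise falls back to `-cp <hydra output dir> -cn config.yaml` with a per-rank `hydra.output_subdir`.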
161 changes: 161 additions & 0 deletions tests/tests_pytorch/strategies/launchers/test_subprocess_script.py
@@ -0,0 +1,161 @@
import logging
import os
import sys
from pathlib import Path

import pytest
from lightning_utilities.core.imports import RequirementCache

from pytorch_lightning.strategies.launchers.subprocess_script import _HYDRA_AVAILABLE
from tests_pytorch.helpers.runif import RunIf

_HYDRA_WITH_RERUN = RequirementCache("hydra-core>=1.2")
_HYDRA_WITH_RUN_PROCESS = RequirementCache("hydra-core>=1.0.7")

if _HYDRA_AVAILABLE:
    from omegaconf import OmegaConf
if _HYDRA_WITH_RUN_PROCESS:
    from hydra.test_utils.test_utils import run_process


# fixture to run hydra jobs in a clean temporary directory
# Hydra creates its own output directories and logs
@pytest.fixture
def cleandir(tmp_path):
    """Run function in a temporary directory."""
    old_dir = os.getcwd()  # get current working directory (cwd)
    os.chdir(tmp_path)  # change cwd to the temp-directory
    yield tmp_path  # yields control to the test to be run
    os.chdir(old_dir)
    logging.shutdown()


# Script to run from command line
script = """
import hydra
import os
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel
class BoringModelGPU(BoringModel):
    def on_train_start(self) -> None:
        # make sure that the model is on GPU when training
        assert self.device == torch.device(f"cuda:{self.trainer.strategy.local_rank}")
@hydra.main(config_path=None, version_base="1.1")
def task_fn(cfg):
    trainer = Trainer(accelerator="auto", devices=cfg.devices, strategy=cfg.strategy, fast_dev_run=True)
    model = BoringModelGPU()
    trainer.fit(model)
    trainer.test(model)
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
    os.environ.pop("LOCAL_RANK", None)
if __name__ == "__main__":
    task_fn()
"""


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.skipif(not _HYDRA_WITH_RUN_PROCESS, reason=str(_HYDRA_WITH_RUN_PROCESS))
@pytest.mark.parametrize("subdir", [None, "dksa", ".hello"])
def test_ddp_with_hydra_runjob(cleandir, subdir):
    # Save script locally
    with open("temp.py", "w") as fn:
        fn.write(script)

    # Run CLI
    devices = 2
    cmd = [sys.executable, "temp.py", f"+devices={devices}", '+strategy="ddp"']
    if subdir is not None:
        cmd += [f"hydra.output_subdir={subdir}"]
    run_process(cmd)

    # Make sure config.yaml was created for the additional processes
    logs = list(Path.cwd().glob("**/config.yaml"))
    assert len(logs) == devices

    # Make sure the parameter was set and used
    cfg = OmegaConf.load(logs[0])
    assert cfg.devices == devices

    # Make sure PL spawned a job that is logged by Hydra
    logs = list(Path.cwd().glob("**/*.log"))
    assert len(logs) == 1


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.skipif(not _HYDRA_WITH_RUN_PROCESS, reason=str(_HYDRA_WITH_RUN_PROCESS))
@pytest.mark.parametrize("num_jobs", [1, 2])
def test_ddp_with_hydra_multirunjob(cleandir, num_jobs):
    # Save script locally
    with open("temp.py", "w") as fn:
        fn.write(script)

    # create fake multirun params based on `num_jobs`
    fake_param = "+foo=" + ",".join(str(i) for i in range(num_jobs))

    # Run CLI
    run_process([sys.executable, "temp.py", "+devices=2", '+strategy="ddp"', fake_param, "--multirun"])

    # Make sure config.yaml was created for each job
    configs = sorted(Path.cwd().glob("**/.pl_ddp_hydra_*/config.yaml"))
    assert len(configs) == num_jobs

    # Make sure the parameter was set and used for each job
    for i, config in enumerate(configs):
        cfg = OmegaConf.load(config)
        local_rank = int(config.parent.parent.parts[-1])
        assert cfg.devices == 2
        assert cfg.foo == local_rank

    logs = list(Path.cwd().glob("**/*.log"))
    assert len(logs) == num_jobs


yaml_file = """
hydra:
  callbacks:
    save_job_info:
      _target_: hydra.experimental.callbacks.PickleJobInfoCallback
"""


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.skipif(not _HYDRA_WITH_RERUN, reason=str(_HYDRA_WITH_RERUN))
@pytest.mark.parametrize("num_jobs", [1, 2])
def test_ddp_with_hydra_multirunjob_rerun(cleandir, num_jobs):
    # Save script locally
    with open("temp.py", "w") as fn:
        fn.write(script)

    with open("config.yaml", "w") as fn:
        fn.write(yaml_file)

    # create fake multirun params based on `num_jobs`
    fake_param = "+foo=" + ",".join(str(i) for i in range(num_jobs))

    # Run CLI
    run_process(
        [
            sys.executable,
            "temp.py",
            "-cp",
            ".",
            "-cn",
            "config.yaml",
            "+devices=2",
            '+strategy="ddp"',
            fake_param,
            "--multirun",
        ]
    )

    pickles = sorted(Path.cwd().glob("**/.hydra/config.pickle"))
    assert len(pickles) == num_jobs
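
For reference, a command of the same shape as the one the re-run test assembles could be issued by hand. This is a sketch assuming temp.py and config.yaml have already been written to the current directory (as the test does) and two multirun jobs are requested:

import subprocess
import sys

# Manual equivalent of the command built in test_ddp_with_hydra_multirunjob_rerun
# for num_jobs=2 (fake_param == "+foo=0,1").
subprocess.run(
    [sys.executable, "temp.py", "-cp", ".", "-cn", "config.yaml",
     "+devices=2", '+strategy="ddp"', "+foo=0,1", "--multirun"],
    check=True,
)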
