Skip to content

Commit

Permalink
Fix ImportErrors on Multinode if package not present (#15963)
Browse files Browse the repository at this point in the history
(cherry picked from commit cbd4dd6)
  • Loading branch information
justusschock authored and Borda committed Dec 8, 2022
1 parent 7deac02 commit f130c54
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 10 deletions.
1 change: 1 addition & 0 deletions src/lightning_app/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed PythonServer generating noise on M1 ([#15949](https://github.com/Lightning-AI/lightning/pull/15949))
- Fixed multiprocessing breakpoint ([#15950](https://github.com/Lightning-AI/lightning/pull/15950))
- Fixed detection of a Lightning App running in debug mode ([#15951](https://github.com/Lightning-AI/lightning/pull/15951))
- Fixed `ImportError` on Multinode if package not present ([#15963](https://github.com/Lightning-AI/lightning/pull/15963))

## [1.8.3] - 2022-11-22

Expand Down
13 changes: 8 additions & 5 deletions src/lightning_app/components/multi_node/lite.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,14 @@ def run(
mps_accelerators = []

for pkg_name in ("lightning.lite", "lightning_" + "lite"):
pkg = importlib.import_module(pkg_name)
lites.append(pkg.LightningLite)
strategies.append(pkg.strategies.DDPSpawnShardedStrategy)
strategies.append(pkg.strategies.DDPSpawnStrategy)
mps_accelerators.append(pkg.accelerators.MPSAccelerator)
try:
pkg = importlib.import_module(pkg_name)
lites.append(pkg.LightningLite)
strategies.append(pkg.strategies.DDPSpawnShardedStrategy)
strategies.append(pkg.strategies.DDPSpawnStrategy)
mps_accelerators.append(pkg.accelerators.MPSAccelerator)
except (ImportError, ModuleNotFoundError):
continue

# Used to configure PyTorch progress group
os.environ["MASTER_ADDR"] = main_address
Expand Down
13 changes: 8 additions & 5 deletions src/lightning_app/components/multi_node/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,14 @@ def run(
mps_accelerators = []

for pkg_name in ("lightning.pytorch", "pytorch_" + "lightning"):
pkg = importlib.import_module(pkg_name)
trainers.append(pkg.Trainer)
strategies.append(pkg.strategies.DDPSpawnShardedStrategy)
strategies.append(pkg.strategies.DDPSpawnStrategy)
mps_accelerators.append(pkg.accelerators.MPSAccelerator)
try:
pkg = importlib.import_module(pkg_name)
trainers.append(pkg.Trainer)
strategies.append(pkg.strategies.DDPSpawnShardedStrategy)
strategies.append(pkg.strategies.DDPSpawnStrategy)
mps_accelerators.append(pkg.accelerators.MPSAccelerator)
except (ImportError, ModuleNotFoundError):
continue

# Used to configure PyTorch progress group
os.environ["MASTER_ADDR"] = main_address
Expand Down

0 comments on commit f130c54

Please sign in to comment.