Destroy process group in atexit handler (#19931)
awaelchli authored Jun 5, 2024
1 parent b9f215d commit 1a6786d
Showing 6 changed files with 29 additions and 6 deletions.
2 changes: 2 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
@@ -17,6 +17,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `ModelParallelStrategy` to support 2D parallelism ([#19846](https://github.com/Lightning-AI/pytorch-lightning/pull/19846), [#19852](https://github.com/Lightning-AI/pytorch-lightning/pull/19852), [#19870](https://github.com/Lightning-AI/pytorch-lightning/pull/19870), [#19872](https://github.com/Lightning-AI/pytorch-lightning/pull/19872))

- Added a call to `torch.distributed.destroy_process_group` in atexit handler if process group needs destruction ([#19931](https://github.com/Lightning-AI/pytorch-lightning/pull/19931))


### Changed

10 changes: 10 additions & 0 deletions src/lightning/fabric/utilities/distributed.py
@@ -1,3 +1,4 @@
import atexit
import contextlib
import logging
import os
@@ -291,6 +292,10 @@ def _init_dist_connection(
log.info(f"Initializing distributed: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
torch.distributed.init_process_group(torch_distributed_backend, rank=global_rank, world_size=world_size, **kwargs)

if torch_distributed_backend == "nccl":
# PyTorch >= 2.4 warns about undestroyed NCCL process group, so we need to do it at program exit
atexit.register(_destroy_dist_connection)

# On rank=0 let everyone know training is starting
rank_zero_info(
f"{'-' * 100}\n"
@@ -300,6 +305,11 @@ def _init_dist_connection(
    )


def _destroy_dist_connection() -> None:
    if _distributed_is_initialized():
        torch.distributed.destroy_process_group()


def _get_default_process_group_backend_for_device(device: torch.device) -> str:
    return "nccl" if device.type == "cuda" else "gloo"

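For context, here is a minimal, self-contained sketch (not part of this commit) of the pattern the change above introduces: register a guarded destroy function with `atexit` right after `init_process_group`, so the process group is torn down at interpreter exit. A single-process `gloo` group is used so the sketch runs on CPU; the actual change only registers the handler for the `nccl` backend, where PyTorch >= 2.4 emits the warning.

```python
import atexit
import os

import torch.distributed as dist


def destroy_if_initialized() -> None:
    # Mirrors _destroy_dist_connection: a no-op if the group was never created
    # or has already been destroyed, so it is safe to call more than once.
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    # env:// rendezvous for a single local process (values are illustrative).
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    # Registering at init time guarantees cleanup at program exit even if the
    # caller never tears the group down explicitly.
    atexit.register(destroy_if_initialized)
```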
1 change: 1 addition & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `ModelParallelStrategy` to support 2D parallelism ([#19878](https://github.com/Lightning-AI/pytorch-lightning/pull/19878), [#19888](https://github.com/Lightning-AI/pytorch-lightning/pull/19888))

- Added a call to `torch.distributed.destroy_process_group` in atexit handler if process group needs destruction ([#19931](https://github.com/Lightning-AI/pytorch-lightning/pull/19931))


### Changed
5 changes: 2 additions & 3 deletions tests/tests_fabric/conftest.py
@@ -23,7 +23,7 @@
import torch.distributed
from lightning.fabric.accelerators import XLAAccelerator
from lightning.fabric.strategies.launchers.subprocess_script import _ChildProcessObserver
from lightning.fabric.utilities.distributed import _distributed_is_initialized
from lightning.fabric.utilities.distributed import _destroy_dist_connection

if sys.version_info >= (3, 9):
    from concurrent.futures.process import _ExecutorManagerThread
@@ -78,8 +78,7 @@ def restore_env_variables():
def teardown_process_group():
    """Ensures that the distributed process group gets closed before the next test runs."""
    yield
    if _distributed_is_initialized():
        torch.distributed.destroy_process_group()
    _destroy_dist_connection()


@pytest.fixture(autouse=True)
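A short illustration (not part of the diff) of why the fixture above can now call `_destroy_dist_connection()` unconditionally: the guard inside it makes the call a no-op when no process group is initialized, so repeated teardown is safe.

```python
from lightning.fabric.utilities.distributed import _destroy_dist_connection

# Safe even though no process group exists yet: the internal check short-circuits.
_destroy_dist_connection()
# Calling it again (e.g. from both the atexit handler and a test fixture) is also fine.
_destroy_dist_connection()
```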
12 changes: 12 additions & 0 deletions tests/tests_fabric/utilities/test_distributed.py
@@ -11,8 +11,10 @@
from lightning.fabric.strategies import DDPStrategy, SingleDeviceStrategy
from lightning.fabric.strategies.launchers.multiprocessing import _MultiProcessingLauncher
from lightning.fabric.utilities.distributed import (
    _destroy_dist_connection,
    _gather_all_tensors,
    _InfiniteBarrier,
    _init_dist_connection,
    _set_num_threads_if_needed,
    _suggested_max_num_threads,
    _sync_ddp,
@@ -217,3 +219,13 @@ def test_infinite_barrier():
    barrier.__exit__(None, None, None)
    assert barrier.barrier.call_count == 2
    dist_mock.destroy_process_group.assert_called_once()


@mock.patch("lightning.fabric.utilities.distributed.atexit")
@mock.patch("lightning.fabric.utilities.distributed.torch.distributed.init_process_group")
def test_init_dist_connection_registers_destruction_handler(_, atexit_mock):
_init_dist_connection(LightningEnvironment(), "nccl")
atexit_mock.register.assert_called_once_with(_destroy_dist_connection)
atexit_mock.reset_mock()
_init_dist_connection(LightningEnvironment(), "gloo")
atexit_mock.register.assert_not_called()
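As an aside on the test above, a small sketch (an illustration, not from this commit) of why the patch targets the `atexit` name inside `lightning.fabric.utilities.distributed` rather than the global `atexit` module: the mock replaces the reference the module under test actually uses, so `register` calls made by `_init_dist_connection` are intercepted.

```python
import atexit
from unittest import mock

import lightning.fabric.utilities.distributed as dist_utils

with mock.patch("lightning.fabric.utilities.distributed.atexit") as atexit_mock:
    # The module-level name is swapped out for the mock...
    assert dist_utils.atexit is atexit_mock
    # ...while the atexit module imported here is left untouched.
    assert atexit is not atexit_mock
```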
5 changes: 2 additions & 3 deletions tests/tests_pytorch/conftest.py
@@ -27,7 +27,7 @@
import torch.distributed
from lightning.fabric.plugins.environments.lightning import find_free_network_port
from lightning.fabric.strategies.launchers.subprocess_script import _ChildProcessObserver
from lightning.fabric.utilities.distributed import _distributed_is_initialized
from lightning.fabric.utilities.distributed import _destroy_dist_connection, _distributed_is_initialized
from lightning.fabric.utilities.imports import _IS_WINDOWS
from lightning.pytorch.accelerators import XLAAccelerator
from lightning.pytorch.trainer.connectors.signal_connector import _SignalConnector
@@ -123,8 +123,7 @@ def restore_signal_handlers():
def teardown_process_group():
    """Ensures that the distributed process group gets closed before the next test runs."""
    yield
    if _distributed_is_initialized():
        torch.distributed.destroy_process_group()
    _destroy_dist_connection()


@pytest.fixture(autouse=True)
