Skip to content

Commit

Permalink
[Feats]: Add smddp dist backend option (open-mmlab#579)
Browse files Browse the repository at this point in the history
* Add smddp dist backend option

* [Dev]: Upgrade pre commit hooks (open-mmlab#576)

* Upgrade the versions of pre-commit-hooks

* update zh-cn.yaml

* [Docs] Fix the docstring of model sub-package (open-mmlab#573)

* [Doc]: Update config.md (open-mmlab#562)

* Update config.md

* Update config.md

* [Doc] delete the error comment  in docs (open-mmlab#514)

Co-authored-by: Zaida Zhou <[email protected]>
Co-authored-by: Zhengfei-0311 <[email protected]>
Co-authored-by: vansin <[email protected]>
  • Loading branch information
4 people authored Oct 8, 2022
1 parent e73c4bf commit 8864bd8
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions mmengine/dist/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,15 @@ def _init_dist_mpi(backend, **kwargs) -> None:
'nccl', 'gloo' and 'mpi'. Defaults to 'nccl'.
**kwargs: keyword arguments are passed to ``init_process_group``.
"""
if backend == 'smddp':
try:
import smdistributed.dataparallel.torch.torch_smddp # noqa: F401
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
'Please use an Amazon SageMaker DLC to access smdistributed: '
'https://github.com/aws/deep-learning-containers/blob/master'
'/available_images.md#sagemaker-framework-containers'
'-sm-support-only') from e
local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
torch.cuda.set_device(local_rank)
if 'MASTER_PORT' not in os.environ:
Expand Down Expand Up @@ -433,6 +442,8 @@ def get_comm_device(group: Optional[ProcessGroup] = None) -> torch.device:
elif backend == 'cncl':
import torch_mlu # noqa: F401
return torch.device('mlu', torch.mlu.current_device())
elif backend == 'smddp':
return torch.device('cuda', torch.cuda.current_device())
else:
# GLOO and MPI backends use cpu device by default
return torch.device('cpu')
Expand Down

0 comments on commit 8864bd8

Please sign in to comment.