Enhance Distributed Adam (#9051)
* Enhance Distributed Adam (#9037)

* Fix deprecated env.

Signed-off-by: Wil Kong <[email protected]>

* Use the user-desired value for distributed adam.

Signed-off-by: Wil Kong <[email protected]>

* Preserve memory format in parameter buffer of distributed adam.

Signed-off-by: Wil Kong <[email protected]>

* Fix the contiguous_param_buffer bug with bprop overlap and the redundant copy after all-gather.

Signed-off-by: Wil Kong <[email protected]>

* Provide API to lock SHArP tree for distributed adam within nodes.

Signed-off-by: Wil Kong <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Wil Kong <[email protected]>

---------

Signed-off-by: Wil Kong <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Apply isort and black reformatting

Signed-off-by: ericharper <[email protected]>

---------

Signed-off-by: Wil Kong <[email protected]>
Signed-off-by: ericharper <[email protected]>
Co-authored-by: Wil Kong <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Harper <[email protected]>
Co-authored-by: ericharper <[email protected]>
Signed-off-by: Jan Lasek <[email protected]>
5 people authored and janekl committed Jun 12, 2024
1 parent 72dcde7 commit ab0dd15
Showing 3 changed files with 100 additions and 37 deletions.
@@ -779,18 +779,21 @@ def get_config_arg(key: str, default_value: Optional[Any] = None) -> Any:
model_dtype = torch.float32
if self.megatron_amp_O2 and hasattr(self, 'autocast_dtype'):
model_dtype = self.autocast_dtype
optim_kwargs['param_sync_dtype'] = model_dtype
# Don't override the user-desired value
if 'param_sync_dtype' not in optim_config:
optim_kwargs['param_sync_dtype'] = model_dtype

# Determine whether to store master params in optimizer
if self.cfg.get('fp8_params', False):
optim_kwargs['store_params'] = True
elif optim_dtype == model_dtype:
optim_kwargs['store_params'] = False
elif optim_dtype == torch.float32 and model_dtype == torch.bfloat16:
optim_kwargs['store_params'] = False
optim_kwargs['store_param_remainders'] = True
else:
optim_kwargs['store_params'] = True
if 'store_params' not in optim_config:
if self.cfg.get('fp8_params', False):
optim_kwargs['store_params'] = True
elif optim_dtype == model_dtype:
optim_kwargs['store_params'] = False
elif optim_dtype == torch.float32 and model_dtype == torch.bfloat16:
optim_kwargs['store_params'] = False
optim_kwargs['store_param_remainders'] = True
else:
optim_kwargs['store_params'] = True

return super().setup_optimization(optim_config=optim_config, optim_kwargs=optim_kwargs)
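
For reference, a minimal sketch of the new precedence rule, assuming a plain dict-style optim_config; the fp8_params branch is omitted and the key 'lr' is illustrative:

import torch

# Hypothetical user config that pins both values explicitly.
optim_config = {
    'lr': 1e-4,
    'param_sync_dtype': torch.bfloat16,
    'store_params': True,
}
optim_kwargs = {}
optim_dtype, model_dtype = torch.float32, torch.bfloat16

# Same guarded defaults as in the hunk above: only fill in what the user left unset.
if 'param_sync_dtype' not in optim_config:
    optim_kwargs['param_sync_dtype'] = model_dtype
if 'store_params' not in optim_config:
    if optim_dtype == model_dtype:
        optim_kwargs['store_params'] = False
    elif optim_dtype == torch.float32 and model_dtype == torch.bfloat16:
        optim_kwargs['store_params'] = False
        optim_kwargs['store_param_remainders'] = True
    else:
        optim_kwargs['store_params'] = True

assert optim_kwargs == {}  # the user-desired values are left untouched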

94 changes: 71 additions & 23 deletions nemo/core/optim/distributed_adam.py
@@ -34,9 +34,62 @@
from megatron.core.dist_checkpointing.optimizer import get_param_id_to_sharded_param_map, optim_state_to_sharding_state
from transformer_engine.pytorch.cpp_extensions import cast_to_fp8

from nemo.utils import str_to_dtype
from nemo.utils import logging, str_to_dtype
from nemo.utils.te_utils import is_float8tensor

_distribute_within_nodes_pgs = {}


def create_distribute_within_nodes_pgs():
"""Create process groups for distributing with nodes.
User can reuse this function to reorder communicators for SHArP.
"""
global _distribute_within_nodes_pgs
assert torch.distributed.is_initialized()
if _distribute_within_nodes_pgs:
return _distribute_within_nodes_pgs

world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
devices = torch.cuda.device_count()
nodes = world_size // devices

if nodes * devices != world_size:
logging.warning("Expected all nodes to have the same amount of devices; disabling distribute_within_nodes.")
return {}

node_id = rank // devices
device_id = rank % devices

distributed_pgs = []
for i in range(nodes):
ranks = [i * devices + j for j in range(devices)]
pg = torch.distributed.new_group(ranks=ranks)
distributed_pgs.append(pg)

redundant_pgs = []
for i in range(devices):
ranks = [i + j * devices for j in range(nodes)]
pg = torch.distributed.new_group(ranks=ranks)
redundant_pgs.append(pg)

# To re-order the SHArP communicator right after distributed init,
# we have to expose redundant_process_group to the user.
# The user has to invoke an allreduce through redundant_process_group
# before all other communicators to lock the SHArP tree.
_distribute_within_nodes_pgs = {
'world_size': world_size,
'rank': rank,
'devices': devices,
'nodes': nodes,
'node_id': node_id,
'device_id': device_id,
'distributed_process_group': distributed_pgs[node_id],
'redundant_process_group': redundant_pgs[device_id],
}
return _distribute_within_nodes_pgs
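
A hedged usage sketch of the new API, based on the comment above; the dummy all-reduce and its placement are an assumption about how the SHArP tree gets locked to the inter-node communicator:

import torch

# Assumes torch.distributed is already initialized with the NCCL backend.
pgs = create_distribute_within_nodes_pgs()
if pgs:
    # Run one collective on the redundant (inter-node) group before any other
    # communicator so that SHArP binds to it first.
    dummy = torch.ones(1, device='cuda')
    torch.distributed.all_reduce(dummy, group=pgs['redundant_process_group'])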


class MegatronDistributedFusedAdam(DistributedFusedAdam):
"""Adam optimizer with ZeRO algorithm
@@ -78,27 +78,12 @@ def __init__(
kwargs['distributed_process_group'] = self_groups[rank]
kwargs['redundant_process_group'] = kwargs['process_group']
elif distribute_within_nodes:
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
devices = torch.cuda.device_count()
nodes = world_size // devices
assert nodes * devices == world_size, "Expected all nodes to have the same amount of devices."
node_id = rank // devices
device_id = rank % devices

distributed_pgs = []
for i in range(nodes):
ranks = [i * devices + j for j in range(devices)]
pg = torch.distributed.new_group(ranks=ranks)
distributed_pgs.append(pg)
kwargs['distributed_process_group'] = distributed_pgs[node_id]

redundant_pgs = []
for i in range(devices):
ranks = [i + j * devices for j in range(nodes)]
pg = torch.distributed.new_group(ranks=ranks)
redundant_pgs.append(pg)
kwargs['redundant_process_group'] = redundant_pgs[device_id]
dist_pg_infos = create_distribute_within_nodes_pgs()
if dist_pg_infos:
kwargs['distributed_process_group'] = dist_pg_infos['distributed_process_group']
kwargs['redundant_process_group'] = dist_pg_infos['redundant_process_group']
global _distribute_within_nodes_pgs
_distribute_within_nodes_pgs = {}

# Make sure dtypes are in right type
for keyword in ('dtype', 'grad_sync_dtype', 'param_sync_dtype'):
@@ -380,6 +418,8 @@ def init_param_buffer(self) -> None:
f"Attempted to change a parameter with dtype={param.dtype} "
f"into a buffer view with dtype={param_buffer_view.dtype}"
)
if param.is_contiguous(memory_format=torch.channels_last):
param = param.permute(0, 2, 3, 1)
param_flat_views.append(param.detach().view(-1))
param_buffer_views.append(param_buffer_view)

@@ -395,7 +435,15 @@ def init_param_buffer(self) -> None:
if is_float8tensor(param):
param._data = buffer_view.view(param.size())
else:
param.data = buffer_view.view(param.size())
# Preserve the memory format of param here, e.g. NHWC tensors.
# `param.data.set_()` fails to change the storage.
# `param.set_()` invalidates the bprop hook.
param.data = torch.as_strided(
buffer_view,
param.size(),
param.stride(),
storage_offset=buffer_view.storage_offset(),
)

def try_grad_sync(self, params: Iterable[torch.nn.Parameter]) -> None:
"""Attempt to launch gradient synchronization"""
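
As a standalone illustration of the as_strided trick in init_param_buffer above, a minimal sketch with no NeMo dependencies (the shapes are arbitrary):

import torch

# A channels-last (NHWC) parameter and a flat buffer laid out in its memory order.
param = torch.nn.Parameter(torch.randn(2, 8, 4, 4).to(memory_format=torch.channels_last))
flat_buffer = param.detach().permute(0, 2, 3, 1).reshape(-1).clone()

# `flat_buffer.view(param.size())` would yield a default-contiguous (NCHW) tensor
# with permuted values; re-striding keeps the NHWC layout and the data intact.
restored = torch.as_strided(
    flat_buffer,
    param.size(),
    param.stride(),
    storage_offset=flat_buffer.storage_offset(),
)
assert restored.is_contiguous(memory_format=torch.channels_last)
assert torch.equal(restored, param.detach())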
20 changes: 16 additions & 4 deletions nemo/utils/callbacks/cuda_graph.py
@@ -159,7 +159,13 @@ def update_metrics(self, key, value, batch_size):


def get_optimizer_step(state):
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure=None,) -> None:
def optimizer_step(
self,
epoch,
batch_idx,
optimizer,
optimizer_closure=None,
) -> None:
# Not all optimizers support set_to_none.
if not hasattr(optimizer, "support_set_to_none"):
optimizer.support_set_to_none = is_param_in_hook_signature(
@@ -175,7 +181,10 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure=None,) -
with torch.cuda.stream(state.stream):
optimizer.zero_grad(**zero_grad_kwargs)
self.__orig_optimizer_step__(
epoch, batch_idx, optimizer, optimizer_closure=optimizer_closure,
epoch,
batch_idx,
optimizer,
optimizer_closure=optimizer_closure,
)
torch.cuda.current_stream().wait_stream(state.stream)

@@ -194,7 +203,10 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure=None,) -
# `zero_grad()` being not captured.
optimizer.zero_grad(**zero_grad_kwargs)
self.__orig_optimizer_step__(
epoch, batch_idx, optimizer, optimizer_closure=optimizer_closure,
epoch,
batch_idx,
optimizer,
optimizer_closure=optimizer_closure,
)
torch.cuda.synchronize()

@@ -270,7 +282,7 @@ def __init__(self, capture_iteration=-1):
raise Exception("Warmup must run at least 11 DDP-enabled eager iterations before capture.")
if torch.distributed.is_initialized():
raise Exception("CUDAGraphCallback should be initialized before process group.")
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"

self.state = CUDAGraphState(capture_iteration=capture_iteration)

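
For context, a small sketch of the intended ordering: PyTorch deprecated the unprefixed NCCL_ASYNC_ERROR_HANDLING variable in favor of TORCH_NCCL_ASYNC_ERROR_HANDLING, and the callback requires the variable to be set before the process group exists; the init_process_group arguments below are illustrative:

import os
import torch

# Must happen before init_process_group, mirroring the check in CUDAGraphCallback.
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"

# Illustrative single-process init; real jobs take rank/world_size from the launcher.
torch.distributed.init_process_group(
    backend="nccl",
    init_method="tcp://127.0.0.1:29500",
    rank=0,
    world_size=1,
)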
