Enhance Distributed Adam #9037

Merged
6 commits merged on Apr 28, 2024
@@ -774,18 +774,21 @@ def get_config_arg(key: str, default_value: Optional[Any] = None) -> Any:
model_dtype = torch.float32
if self.megatron_amp_O2 and hasattr(self, 'autocast_dtype'):
model_dtype = self.autocast_dtype
optim_kwargs['param_sync_dtype'] = model_dtype
# Don't override the user's desired value
if 'param_sync_dtype' not in optim_config:
optim_kwargs['param_sync_dtype'] = model_dtype

# Determine whether to store master params in optimizer
if self.cfg.get('fp8_params', False):
optim_kwargs['store_params'] = True
elif optim_dtype == model_dtype:
optim_kwargs['store_params'] = False
elif optim_dtype == torch.float32 and model_dtype == torch.bfloat16:
optim_kwargs['store_params'] = False
optim_kwargs['store_param_remainders'] = True
else:
optim_kwargs['store_params'] = True
if 'store_params' not in optim_config:
if self.cfg.get('fp8_params', False):
optim_kwargs['store_params'] = True
elif optim_dtype == model_dtype:
optim_kwargs['store_params'] = False
elif optim_dtype == torch.float32 and model_dtype == torch.bfloat16:
optim_kwargs['store_params'] = False
optim_kwargs['store_param_remainders'] = True
else:
optim_kwargs['store_params'] = True

return super().setup_optimization(optim_config=optim_config, optim_kwargs=optim_kwargs)
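As a quick illustration of the guarded defaults above, here is a hypothetical optimizer config fragment; the two guarded keys come from the diff, while the other keys and values are illustrative. With this change, explicit user settings survive setup_optimization instead of being overwritten from model_dtype or the dtype heuristics.

optim_config = {
    'name': 'distributed_fused_adam',  # illustrative
    'lr': 1e-4,                        # illustrative
    'param_sync_dtype': 'bf16',  # respected as-is; no longer overwritten with model_dtype
    'store_params': True,        # respected as-is; the dtype-based heuristic is skipped
}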

91 changes: 68 additions & 23 deletions nemo/core/optim/distributed_adam.py
@@ -34,9 +34,62 @@
from megatron.core.dist_checkpointing.optimizer import get_param_id_to_sharded_param_map, optim_state_to_sharding_state
from transformer_engine.pytorch.cpp_extensions import cast_to_fp8

from nemo.utils import str_to_dtype
from nemo.utils import logging, str_to_dtype
from nemo.utils.te_utils import is_float8tensor

_distribute_within_nodes_pgs = {}


def create_distribute_within_nodes_pgs():
"""Create process groups for distributing with nodes.

User can reuse this function to reorder communicators for SHArP.
"""
global _distribute_within_nodes_pgs
assert torch.distributed.is_initialized()
if _distribute_within_nodes_pgs:
return _distribute_within_nodes_pgs

world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
devices = torch.cuda.device_count()
nodes = world_size // devices

if nodes * devices != world_size:
logging.warning("Expected all nodes have the same amout of devices, disable distribute_within_nodes.")
return {}

node_id = rank // devices
device_id = rank % devices

distributed_pgs = []
for i in range(nodes):
ranks = [i * devices + j for j in range(devices)]
pg = torch.distributed.new_group(ranks=ranks)
distributed_pgs.append(pg)

redundant_pgs = []
for i in range(devices):
ranks = [i + j * devices for j in range(nodes)]
pg = torch.distributed.new_group(ranks=ranks)
redundant_pgs.append(pg)

# To re-order the SHArP communicator right after distributed init,
# we have to expose redundant_process_group to the user.
# The user has to invoke an allreduce through redundant_process_group
# before all other communicators to lock the SHArP tree.
_distribute_within_nodes_pgs = {
'world_size': world_size,
'rank': rank,
'devices': devices,
'nodes': nodes,
'node_id': node_id,
'device_id': device_id,
'distributed_process_group': distributed_pgs[node_id],
'redundant_process_group': redundant_pgs[device_id],
}
return _distribute_within_nodes_pgs
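For illustration, a minimal usage sketch of the helper above, assuming a torchrun-style launch, an NCCL backend, and one GPU per rank; per the comment in the function, the first collective is issued over redundant_process_group so the SHArP tree locks onto the inter-node communicator.

import torch
from nemo.core.optim.distributed_adam import create_distribute_within_nodes_pgs

torch.distributed.init_process_group(backend='nccl')
torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())

pg_infos = create_distribute_within_nodes_pgs()  # returns {} on uneven device counts
if pg_infos:
    # Run the very first allreduce over the inter-node (redundant) group
    # before any other communicator is used, locking the SHArP tree.
    warmup = torch.zeros(1, device='cuda')
    torch.distributed.all_reduce(warmup, group=pg_infos['redundant_process_group'])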


class MegatronDistributedFusedAdam(DistributedFusedAdam):
"""Adam optimizer with ZeRO algorithm
@@ -78,27 +131,12 @@ def __init__(
kwargs['distributed_process_group'] = self_groups[rank]
kwargs['redundant_process_group'] = kwargs['process_group']
elif distribute_within_nodes:
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
devices = torch.cuda.device_count()
nodes = world_size // devices
assert nodes * devices == world_size, "Expected all nodes have teh same amout of devices."
node_id = rank // devices
device_id = rank % devices

distributed_pgs = []
for i in range(nodes):
ranks = [i * devices + j for j in range(devices)]
pg = torch.distributed.new_group(ranks=ranks)
distributed_pgs.append(pg)
kwargs['distributed_process_group'] = distributed_pgs[node_id]

redundant_pgs = []
for i in range(devices):
ranks = [i + j * devices for j in range(nodes)]
pg = torch.distributed.new_group(ranks=ranks)
redundant_pgs.append(pg)
kwargs['redundant_process_group'] = redundant_pgs[device_id]
dist_pg_infos = create_distribute_within_nodes_pgs()
if dist_pg_infos:
kwargs['distributed_process_group'] = dist_pg_infos['distributed_process_group']
kwargs['redundant_process_group'] = dist_pg_infos['redundant_process_group']
global _distribute_within_nodes_pgs
_distribute_within_nodes_pgs = {}
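A hedged construction sketch for the refactored branch above; `model` and the hyperparameters are placeholders, and torch.distributed is assumed to be initialized already. With distribute_within_nodes=True the optimizer now takes its process groups from the shared create_distribute_within_nodes_pgs() cache and then resets that cache.

optimizer = MegatronDistributedFusedAdam(
    model.parameters(),           # placeholder module
    lr=1e-4,                      # illustrative hyperparameter
    distribute_within_nodes=True,
)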

# Make sure dtypes are in right type
for keyword in ('dtype', 'grad_sync_dtype', 'param_sync_dtype'):
Expand Down Expand Up @@ -355,6 +393,8 @@ def init_param_buffer(self) -> None:
f"Attempted to change a parameter with dtype={param.dtype} "
f"into a buffer view with dtype={param_buffer_view.dtype}"
)
if param.is_contiguous(memory_format=torch.channels_last):
param = param.permute(0, 2, 3, 1)
param_flat_views.append(param.detach().view(-1))
param_buffer_views.append(param_buffer_view)

@@ -368,7 +408,12 @@ def init_param_buffer(self) -> None:
if is_float8tensor(param):
param._data = buffer_view.view(param.size())
else:
param.data = buffer_view.view(param.size())
# Preserve the memory format of param here, e.g. NHWC tensors.
# `param.data.set_()` fails to change the storage.
# `param.set_()` invalidates the bprop hook.
param.data = torch.as_strided(
buffer_view, param.size(), param.stride(), storage_offset=buffer_view.storage_offset(),
)
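To make the NHWC handling above concrete, here is a small self-contained sketch with hypothetical shapes: flattening a channels_last parameter via permute(0, 2, 3, 1) yields its physical storage order, and torch.as_strided over the flat buffer reproduces the original NCHW view with channels_last strides, which is what the assignment to param.data relies on.

import torch

# Hypothetical channels_last (NHWC) parameter.
param = torch.randn(2, 3, 4, 5).to(memory_format=torch.channels_last)

# Flatten in physical storage order, as init_param_buffer does for NHWC params.
flat = param.detach().permute(0, 2, 3, 1).reshape(-1)

# Stand-in for a slice of the contiguous parameter buffer.
buffer_view = torch.empty_like(flat)
buffer_view.copy_(flat)

# Re-view the buffer with the original sizes and strides; the result is an
# NCHW tensor that is still channels_last-contiguous.
rebuilt = torch.as_strided(
    buffer_view, param.size(), param.stride(), storage_offset=buffer_view.storage_offset(),
)
assert torch.equal(rebuilt, param)
assert rebuilt.is_contiguous(memory_format=torch.channels_last)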

def try_grad_sync(self, params: Iterable[torch.nn.Parameter]) -> None:
"""Attempt to launch gradient synchronization"""
2 changes: 1 addition & 1 deletion nemo/utils/callbacks/cuda_graph.py
@@ -270,7 +270,7 @@ def __init__(self, capture_iteration=-1):
raise Exception("Warmup must run at least 11 DDP-enabled eager iterations before capture.")
if torch.distributed.is_initialized():
raise Exception("CUDAGraphCallback should be initialized before process group.")
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"

self.state = CUDAGraphState(capture_iteration=capture_iteration)
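For reference, a hedged usage sketch with illustrative trainer arguments: since __init__ sets the async-error-handling variable, the callback must be constructed before torch.distributed is initialized, i.e. before the Lightning trainer starts its distributed setup.

import pytorch_lightning as pl
from nemo.utils.callbacks.cuda_graph import CUDAGraphCallback

# Construct the callback before any process-group initialization so the
# TORCH_NCCL_ASYNC_ERROR_HANDLING setting can take effect.
callback = CUDAGraphCallback(capture_iteration=11)  # illustrative capture step (>= 11 warmup iterations)
trainer = pl.Trainer(devices=8, accelerator='gpu', callbacks=[callback])  # illustrative args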
