Improve performance of sequence parallel gather, scatter, and reduce
bclyang committed Aug 22, 2024
1 parent f26b886 commit 8e7400f
Showing 1 changed file with 43 additions and 13 deletions.
megatron/mpu/mappings.py (43 additions, 13 deletions)
@@ -101,16 +101,30 @@ def _reduce_scatter_along_seq_dim(input_, seq_dim):
     if get_fp32_allreduce():
         input_ = input_.float()
 
-    assert input_.shape[seq_dim] % world_size == 0
-    tensor_list = list(
-        torch.split(input_, input_.shape[seq_dim] // world_size, seq_dim)
-    )
-    output = torch.empty_like(tensor_list[0])
-    torch.distributed.reduce_scatter(output, tensor_list)
+    dim_size = list(input_.size())
+    assert (
+        isinstance(seq_dim, int) and seq_dim < len(dim_size) and seq_dim >= 0
+    ), "seq_dim must be a valid tensor dim"
+    assert dim_size[seq_dim] % world_size == 0
+
+    if seq_dim == 0:
+        dim_size[seq_dim] = dim_size[seq_dim] // world_size
+        output = torch.empty(
+            dim_size, dtype=input_.dtype, device=torch.cuda.current_device()
+        )
+        torch.distributed.reduce_scatter_tensor(
+            output, input_.contiguous(), group=get_model_parallel_group()
+        )
+    else:
+        tensor_list = list(
+            torch.split(input_, input_.shape[seq_dim] // world_size, seq_dim)
+        )
+        output = torch.empty_like(tensor_list[0])
+        torch.distributed.reduce_scatter(output, tensor_list)
 
     # reconvert to original Bf16/Fp16 dtype
     if get_fp32_allreduce():
-        input_ = input_.to(dt)
+        output = output.to(dt)
 
     return output
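
For context, the dim-0 fast path works because torch.distributed.reduce_scatter_tensor treats its input as world_size contiguous chunks along the leading dimension, producing the same shard as the older torch.split + reduce_scatter list path while skipping the Python-side list of views. Below is a minimal standalone sketch of the two forms (not part of the commit; the helper names are illustrative). It assumes a recent PyTorch, a CUDA machine, and an NCCL process group, e.g. launched with torchrun --nproc_per_node=2 sketch.py.

import torch
import torch.distributed as dist


def reduce_scatter_dim0_fused(input_, world_size):
    # Fast path: reduce_scatter_tensor views the input as `world_size` equal,
    # contiguous chunks along dim 0 and writes this rank's summed chunk to output.
    dim_size = list(input_.size())
    dim_size[0] //= world_size
    output = torch.empty(dim_size, dtype=input_.dtype, device=input_.device)
    dist.reduce_scatter_tensor(output, input_.contiguous())
    return output


def reduce_scatter_list_based(input_, seq_dim, world_size):
    # Older path: build a Python list of per-rank views, then reduce-scatter the list.
    chunks = list(torch.split(input_, input_.shape[seq_dim] // world_size, seq_dim))
    output = torch.empty_like(chunks[0])
    dist.reduce_scatter(output, chunks)
    return output


if __name__ == "__main__":
    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank())
    x = torch.randn(8, 4, device="cuda")  # sequence dimension is dim 0
    fused = reduce_scatter_dim0_fused(x, dist.get_world_size())
    listed = reduce_scatter_list_based(x, 0, dist.get_world_size())
    torch.testing.assert_close(fused, listed)  # both paths yield the same per-rank shard
    dist.destroy_process_group()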

@@ -123,12 +137,28 @@ def _gather_along_seq_dim(input_, seq_dim):
     if world_size == 1:
         return input_
 
-    input_ = input_.contiguous()
-    rank = get_model_parallel_rank()
-    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
-    tensor_list[rank] = input_
-    torch.distributed.all_gather(tensor_list, input_, group=get_model_parallel_group())
-    output = torch.cat(tensor_list, dim=seq_dim)
+    dim_size = list(input_.size())
+    assert (
+        isinstance(seq_dim, int) and seq_dim < len(dim_size) and seq_dim >= 0
+    ), "seq_dim must be a valid tensor dim"
+    dim_size[seq_dim] = dim_size[seq_dim] * world_size
+
+    if seq_dim == 0:
+        output = torch.empty(
+            dim_size, dtype=input_.dtype, device=torch.cuda.current_device()
+        )
+        torch.distributed.all_gather_into_tensor(
+            output, input_.contiguous(), group=get_model_parallel_group()
+        )
+    else:
+        input_ = input_.contiguous()
+        rank = get_model_parallel_rank()
+        tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
+        tensor_list[rank] = input_
+        torch.distributed.all_gather(
+            tensor_list, input_, group=get_model_parallel_group()
+        )
+        output = torch.cat(tensor_list, dim=seq_dim)
 
     return output
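
Likewise, the gather fast path is only taken when seq_dim == 0 because all_gather_into_tensor lays each rank's contribution out along the leading dimension of one flat buffer, which is exactly what torch.cat(tensor_list, dim=0) produced before, minus the per-rank list and the concatenation copy; for any other seq_dim the list-based all_gather plus torch.cat is still needed. A rough sketch of that dispatch (again illustrative, not from the commit, with the same torchrun/NCCL assumptions):

import torch
import torch.distributed as dist


def gather_along_dim(input_, seq_dim):
    world_size = dist.get_world_size()
    if seq_dim == 0:
        # Fused path: one flat output buffer, no per-rank list and no torch.cat copy.
        dim_size = list(input_.size())
        dim_size[0] *= world_size
        output = torch.empty(dim_size, dtype=input_.dtype, device=input_.device)
        dist.all_gather_into_tensor(output, input_.contiguous())
        return output
    # General path: gather into a list of per-rank buffers, then concatenate.
    tensors = [torch.empty_like(input_) for _ in range(world_size)]
    dist.all_gather(tensors, input_.contiguous())
    return torch.cat(tensors, dim=seq_dim)


if __name__ == "__main__":
    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank())
    x = torch.full((4, 2), float(dist.get_rank()), device="cuda")
    out0 = gather_along_dim(x, 0)  # shape (4 * world_size, 2)
    out1 = gather_along_dim(x, 1)  # shape (4, 2 * world_size)
    assert out0.shape[0] == 4 * dist.get_world_size()
    assert out1.shape[1] == 2 * dist.get_world_size()
    dist.destroy_process_group()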

