Open-source fused sequence parallelism
ghstack-source-id: b519cc30a7e9b407c2930a9da875eb7eb481ca53
Pull Request resolved: fairinternal/xformers#1003

__original_commit__ = fairinternal/xformers@804f630
lw authored and xFormers Bot committed Jan 25, 2024
1 parent d9ccf34 commit 342de87
Showing 16 changed files with 4,317 additions and 8 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.0.24] - TBD
### Added
- Added components for model/sequence parallelism, as near-drop-in replacements for the FairScale/Megatron ColumnParallelLinear and RowParallelLinear modules. They support fusing communication and computation for sequence parallelism, thus making the communication effectively free.

## [0.0.23] - 2023-12-05
Pre-built binary wheels require PyTorch 2.1.1
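For context, a minimal sketch of how the two new operators are meant to be combined, based on the calls exercised by tests/test_seqpar.py below. The convention the tests check is: the leading matmul takes a sequence-parallel input shard and column-parallel weights, overlapping the all-gather with the matmuls; the trailing matmul takes column-parallel activations and a row-parallel weight, overlapping the reduce-scatter that returns a sequence-parallel output. The gating and every name other than the two xformers.ops entry points are illustrative, and an already-initialized process group is assumed.

import torch

from xformers.ops import (
    sequence_parallel_leading_matmul,
    sequence_parallel_trailing_matmul,
)


def sharded_ffn(x_shard, w1_shard, w2_shard, w_out_shard, group):
    # x_shard: this rank's slice of the input along the sequence dimension (dim 0).
    # w1_shard / w2_shard: column shards of the two leading weights.
    # w_out_shard: row shard of the trailing weight.
    hidden1, hidden2 = sequence_parallel_leading_matmul(
        x_shard, [w1_shard, w2_shard], fuse=True, process_group=group
    )
    hidden = torch.nn.functional.silu(hidden1) * hidden2
    # Returns this rank's sequence-parallel slice of the final output.
    return sequence_parallel_trailing_matmul(
        hidden, w_out_shard, fuse=True, process_group=group
    )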
100 changes: 100 additions & 0 deletions tests/multiprocessing_utils.py
@@ -0,0 +1,100 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import concurrent.futures
import multiprocessing
import signal
import tempfile
from typing import List

import torch


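# Wrapper around the "spawn" multiprocessing context that records every process it
# creates, so that on context exit it can terminate stragglers and report ranks that
# exited with a non-zero code or were killed by a signal.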
class SafeMpContext:
def __init__(self) -> None:
self.mp_context = multiprocessing.get_context("spawn")
self.processes: List[multiprocessing.context.SpawnProcess] = []

def Process(self, *args, **kwargs) -> multiprocessing.context.SpawnProcess:
p = self.mp_context.Process(*args, **kwargs)
p.daemon = True
self.processes.append(p)
return p

def kill_all_processes(self):
for p in self.processes:
p.terminate()
p.join(1)
if p.exitcode is None:
p.kill()
p.join()

def log_bad_exit_codes(self):
for rank, p in enumerate(self.processes):
if p.exitcode == 0:
continue
if p.exitcode < 0:
try:
signal_desc = f" (signal {signal.Signals(-p.exitcode).name})"
except ValueError:
signal_desc = " (unrecognized signal)"
else:
signal_desc = ""
print(
f"Child process for rank #{rank} with PID {p.pid} exited with code {p.exitcode}{signal_desc}"
)

def __getattr__(self, name: str):
return getattr(self.mp_context, name)

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.kill_all_processes()
self.log_bad_exit_codes()


def _launch_subprocesses_fn_wrapper(
init_method: str, rank: int, world_size: int, user_fn, args, kwargs
):
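    # Make the child print stack traces if it dies from a fatal signal (private
    # PyTorch API; the exact behavior is inferred from its name).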
torch._C._set_print_stack_traces_on_fatal_signal(True)

if torch.cuda.device_count() >= world_size:
backend = "nccl"
torch.cuda.set_device(rank)
else:
# Use Gloo instead of NCCL so that we can run on a single GPU
backend = "gloo"
torch.distributed.init_process_group(
backend=backend,
world_size=world_size,
rank=rank,
init_method=init_method,
)
return user_fn(*args, **kwargs)


def launch_subprocesses(world_size: int, fn, *args, **kwargs):
with SafeMpContext() as mp_context, concurrent.futures.ProcessPoolExecutor(
max_workers=world_size, mp_context=mp_context
) as e, tempfile.NamedTemporaryFile(mode="w+b", buffering=-1, delete=True) as rdv:
futures = [
e.submit(
_launch_subprocesses_fn_wrapper,
init_method=f"file://{rdv.name}",
rank=rank,
world_size=world_size,
user_fn=fn,
args=args,
kwargs=kwargs,
)
for rank in range(world_size)
]
done, _ = concurrent.futures.wait(
futures, return_when=concurrent.futures.FIRST_EXCEPTION
)
for f in done:
f.result()
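A hypothetical caller, shown only for illustration (test_seqpar.py below follows the same pattern): launch_subprocesses spawns one process per rank, performs rendezvous through a temporary file, initializes torch.distributed with NCCL when there are enough GPUs (Gloo otherwise), runs the given function with the given arguments in every rank, and surfaces the first exception. The demo function and the top-level import path (which assumes the repository root is on sys.path) are assumptions.

import torch

from tests.multiprocessing_utils import launch_subprocesses


def _allreduce_demo(value: int):
    # Runs in every rank, with torch.distributed already initialized by the wrapper.
    t = torch.tensor([float(value)])
    if torch.distributed.get_backend() == "nccl":
        t = t.cuda()
    torch.distributed.all_reduce(t)  # defaults to a sum across ranks
    assert t.item() == value * torch.distributed.get_world_size()


if __name__ == "__main__":
    launch_subprocesses(world_size=2, fn=_allreduce_demo, value=3)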
261 changes: 261 additions & 0 deletions tests/test_seqpar.py
@@ -0,0 +1,261 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import os
import random
from typing import Tuple

import pytest
import torch

from xformers.ops import (
sequence_parallel_leading_matmul,
sequence_parallel_trailing_matmul,
)

from .multiprocessing_utils import launch_subprocesses

compute_capability = (0, 0)
if torch.cuda.is_available():
compute_capability = torch.cuda.get_device_capability("cuda")
cuda_sm70_only = pytest.mark.skipif(
compute_capability < (7, 0), reason="requires sm70+"
)
at_least_2_gpus = pytest.mark.skipif(
torch.cuda.device_count() < 2, reason="needs at least 2 GPUs"
)


def reference_leading(input_, w1, w2):
hidden1 = torch.matmul(input_, w1.t())
hidden2 = torch.matmul(input_, w2.t())
return [hidden1, hidden2]


def reference_trailing(hidden, w):
output = torch.matmul(hidden, w.t())
return output


def xformers_leading(input_, w1, w2, *, fuse, group):
return sequence_parallel_leading_matmul(
input_, [w1.t(), w2.t()], fuse=fuse, process_group=group
)


def xformers_trailing(hidden, w, *, fuse, group):
return sequence_parallel_trailing_matmul(
hidden, w.t(), fuse=fuse, process_group=group
)
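# Sharding convention exercised below: in the "leading" step the input is
# sequence-parallel (sharded along dim 0) and each weight is column-parallel, so
# every rank produces a column shard of the full-sequence outputs; in the
# "trailing" step the activation and the weight are sharded along the reduction
# dimension and the output comes back sequence-parallel (sharded along dim 0).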


def inner_seqpar(
kind: str,
step: str,
dims: Tuple[int, ...],
dtype: torch.dtype,
seed: int,
):
my_rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
subgroup = torch.distributed.new_group()

fused = True
if kind == "unfused":
fused = False
elif kind == "fallback":
os.environ["DISABLE_FUSED_SEQUENCE_PARALLEL"] = "1"

torch.random.manual_seed(seed)

batch_dims = dims[:-2]
outer_dim = dims[-2]
inner_dim = dims[-1]

# To check for correctness we want to compare the outputs but the accuracy
# of matmuls, apparently, is not that great. We thus try to produce inputs
# for which no rounding at all will occur. We do this by using zero or one
# inputs, so their product will also be zero or one, and keep the reduction
# dimension small enough so that they fit in the mantissa without overflow.
max_exact_value = 2 * (1 / torch.finfo(dtype).eps)
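    # E.g. eps is 2**-7 for bf16 and 2**-10 for fp16, so max_exact_value is 256
    # and 2048 respectively.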
    # 0.25 is the expected fraction of ones in the products being summed; we aim at 2/3 of the safe range
assert outer_dim * 0.25 <= max_exact_value * 0.66
assert inner_dim * world_size * 0.25 <= max_exact_value * 0.66

def my_chunk(t, *, dim):
return t.tensor_split(world_size, dim=dim)[my_rank]

if step == "leading":
input_ = torch.testing.make_tensor(
batch_dims + (outer_dim,),
dtype=dtype,
device="cuda",
low=0,
high=1,
).round()
weight1, weight2 = [
torch.testing.make_tensor(
(inner_dim, outer_dim),
dtype=dtype,
device="cuda",
low=0,
high=1,
).round()
for _ in range(2)
]
gradient1, gradient2 = [
torch.testing.make_tensor(
batch_dims + (inner_dim,),
dtype=dtype,
device="cuda",
low=0,
high=1,
).round()
for _ in range(2)
]

# Non-fused reference code
input_ref = input_.detach().requires_grad_()
weight1_ref = weight1.detach().requires_grad_()
weight2_ref = weight2.detach().requires_grad_()

output1_ref, output2_ref = reference_leading(
input_ref, weight1_ref, weight2_ref
)
torch.autograd.backward([output1_ref, output2_ref], [gradient1, gradient2])

my_output1_ref = my_chunk(output1_ref, dim=-1)
my_output2_ref = my_chunk(output2_ref, dim=-1)
my_weight1_grad_ref = my_chunk(weight1_ref.grad, dim=0)
my_weight2_grad_ref = my_chunk(weight2_ref.grad, dim=0)
my_input_grad_ref = my_chunk(input_ref.grad, dim=0)

# Faster fused mode
my_input_xf = my_chunk(input_, dim=0).detach().requires_grad_()
my_weight1_xf = my_chunk(weight1, dim=0).detach().requires_grad_()
my_weight2_xf = my_chunk(weight2, dim=0).detach().requires_grad_()
my_gradient1 = my_chunk(gradient1, dim=-1)
my_gradient2 = my_chunk(gradient2, dim=-1)

my_output1_xf, my_output2_xf = xformers_leading(
my_input_xf, my_weight1_xf, my_weight2_xf, fuse=fused, group=subgroup
)
torch.autograd.backward(
[my_output1_xf, my_output2_xf], [my_gradient1, my_gradient2]
)

my_weight1_grad_xf = my_weight1_xf.grad
my_weight2_grad_xf = my_weight2_xf.grad
my_input_grad_xf = my_input_xf.grad

# Checks
torch.testing.assert_close(my_output1_ref, my_output1_xf)
torch.testing.assert_close(my_output2_ref, my_output2_xf)
torch.testing.assert_close(my_input_grad_ref, my_input_grad_xf)
torch.testing.assert_close(my_weight1_grad_ref, my_weight1_grad_xf)
torch.testing.assert_close(my_weight2_grad_ref, my_weight2_grad_xf)

elif step == "trailing":
input_ = torch.testing.make_tensor(
batch_dims + (inner_dim,),
dtype=dtype,
device="cuda",
low=0,
high=1,
).round()
weight = torch.testing.make_tensor(
(outer_dim, inner_dim),
dtype=dtype,
device="cuda",
low=0,
high=1,
).round()
gradient = torch.testing.make_tensor(
batch_dims + (outer_dim,),
dtype=dtype,
device="cuda",
low=0,
high=1,
).round()

# Non-fused reference code
input_ref = input_.detach().requires_grad_()
weight_ref = weight.detach().requires_grad_()

output_ref = reference_trailing(input_ref, weight_ref)
torch.autograd.backward([output_ref], [gradient])

my_output_ref = my_chunk(output_ref, dim=0)
my_weight_grad_ref = my_chunk(weight_ref.grad, dim=1)
my_input_grad_ref = my_chunk(input_ref.grad, dim=-1)

# Faster fused mode
my_input_xf = my_chunk(input_, dim=-1).detach().clone().requires_grad_()
my_weight_xf = my_chunk(weight, dim=1).detach().requires_grad_()
my_gradient = my_chunk(gradient, dim=0)

my_output_xf = xformers_trailing(
my_input_xf, my_weight_xf, fuse=fused, group=subgroup
)
torch.autograd.backward([my_output_xf], [my_gradient])

my_weight_grad_xf = my_weight_xf.grad
my_input_grad_xf = my_input_xf.grad

# Checks
torch.testing.assert_close(my_output_ref, my_output_xf)
torch.testing.assert_close(my_input_grad_ref, my_input_grad_xf)
torch.testing.assert_close(my_weight_grad_ref, my_weight_grad_xf)


@cuda_sm70_only
@pytest.mark.parametrize(
"kind",
[
"singleton",
pytest.param("unfused", marks=at_least_2_gpus),
pytest.param("fallback", marks=at_least_2_gpus),
"fused",
],
)
@pytest.mark.parametrize(
"step",
[
"leading",
"trailing",
],
)
@pytest.mark.parametrize(
"dims",
[
pytest.param((2, 2, 512, 512, 256), id="nice-shapes"),
pytest.param((2, 1023, 511, 257), id="ugly-shapes"),
],
)
@pytest.mark.parametrize(
"dtype",
[
pytest.param(torch.bfloat16, id="bf16"),
pytest.param(torch.float16, id="fp16"),
pytest.param(torch.float32, id="fp32"),
],
)
def test_seqpar(
kind: str,
step: str,
dims: Tuple[int, ...],
dtype: torch.dtype,
):
world_size = 1 if kind == "singleton" else 2
seed = random.getrandbits(32)
launch_subprocesses(
world_size=world_size,
fn=inner_seqpar,
kind=kind,
step=step,
dims=dims,
dtype=dtype,
seed=seed,
)
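The new tests can be run locally with pytest tests/test_seqpar.py. The unfused and fallback variants require at least two visible GPUs, and the fallback variant additionally forces the non-fused code path via the DISABLE_FUSED_SEQUENCE_PARALLEL environment variable set in inner_seqpar.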