forked from facebookresearch/xformers
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Open-source fused sequence parallelism
ghstack-source-id: b519cc30a7e9b407c2930a9da875eb7eb481ca53 Pull Request resolved: fairinternal/xformers#1003 __original_commit__ = fairinternal/xformers@804f630
- Loading branch information
Showing
16 changed files
with
4,317 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | ||
# | ||
# This source code is licensed under the BSD license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import concurrent.futures | ||
import multiprocessing | ||
import signal | ||
import tempfile | ||
from typing import List | ||
|
||
import torch | ||
|
||
|
||
class SafeMpContext: | ||
def __init__(self) -> None: | ||
self.mp_context = multiprocessing.get_context("spawn") | ||
self.processes: List[multiprocessing.context.SpawnProcess] = [] | ||
|
||
def Process(self, *args, **kwargs) -> multiprocessing.context.SpawnProcess: | ||
p = self.mp_context.Process(*args, **kwargs) | ||
p.daemon = True | ||
self.processes.append(p) | ||
return p | ||
|
||
def kill_all_processes(self): | ||
for p in self.processes: | ||
p.terminate() | ||
p.join(1) | ||
if p.exitcode is None: | ||
p.kill() | ||
p.join() | ||
|
||
def log_bad_exit_codes(self): | ||
for rank, p in enumerate(self.processes): | ||
if p.exitcode == 0: | ||
continue | ||
if p.exitcode < 0: | ||
try: | ||
signal_desc = f" (signal {signal.Signals(-p.exitcode).name})" | ||
except ValueError: | ||
signal_desc = " (unrecognized signal)" | ||
else: | ||
signal_desc = "" | ||
print( | ||
f"Child process for rank #{rank} with PID {p.pid} exited with code {p.exitcode}{signal_desc}" | ||
) | ||
|
||
def __getattr__(self, name: str): | ||
return getattr(self.mp_context, name) | ||
|
||
def __enter__(self): | ||
return self | ||
|
||
def __exit__(self, exc_type, exc_val, exc_tb): | ||
self.kill_all_processes() | ||
self.log_bad_exit_codes() | ||
|
||
|
||
def _launch_subprocesses_fn_wrapper( | ||
init_method: str, rank: int, world_size: int, user_fn, args, kwargs | ||
): | ||
torch._C._set_print_stack_traces_on_fatal_signal(True) | ||
|
||
if torch.cuda.device_count() >= world_size: | ||
backend = "nccl" | ||
torch.cuda.set_device(rank) | ||
else: | ||
# Use Gloo instead of NCCL so that we can run on a single GPU | ||
backend = "gloo" | ||
torch.distributed.init_process_group( | ||
backend=backend, | ||
world_size=world_size, | ||
rank=rank, | ||
init_method=init_method, | ||
) | ||
return user_fn(*args, **kwargs) | ||
|
||
|
||
def launch_subprocesses(world_size: int, fn, *args, **kwargs): | ||
with SafeMpContext() as mp_context, concurrent.futures.ProcessPoolExecutor( | ||
max_workers=world_size, mp_context=mp_context | ||
) as e, tempfile.NamedTemporaryFile(mode="w+b", buffering=-1, delete=True) as rdv: | ||
futures = [ | ||
e.submit( | ||
_launch_subprocesses_fn_wrapper, | ||
init_method=f"file://{rdv.name}", | ||
rank=rank, | ||
world_size=world_size, | ||
user_fn=fn, | ||
args=args, | ||
kwargs=kwargs, | ||
) | ||
for rank in range(world_size) | ||
] | ||
done, _ = concurrent.futures.wait( | ||
futures, return_when=concurrent.futures.FIRST_EXCEPTION | ||
) | ||
for f in done: | ||
f.result() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,261 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | ||
# | ||
# This source code is licensed under the BSD license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import os | ||
import random | ||
from typing import Tuple | ||
|
||
import pytest | ||
import torch | ||
|
||
from xformers.ops import ( | ||
sequence_parallel_leading_matmul, | ||
sequence_parallel_trailing_matmul, | ||
) | ||
|
||
from .multiprocessing_utils import launch_subprocesses | ||
|
||
compute_capability = (0, 0) | ||
if torch.cuda.is_available(): | ||
compute_capability = torch.cuda.get_device_capability("cuda") | ||
cuda_sm70_only = pytest.mark.skipif( | ||
compute_capability < (7, 0), reason="requires sm70+" | ||
) | ||
at_least_2_gpus = pytest.mark.skipif( | ||
torch.cuda.device_count() < 2, reason="needs at least 2 GPUs" | ||
) | ||
|
||
|
||
def reference_leading(input_, w1, w2): | ||
hidden1 = torch.matmul(input_, w1.t()) | ||
hidden2 = torch.matmul(input_, w2.t()) | ||
return [hidden1, hidden2] | ||
|
||
|
||
def reference_trailing(hidden, w): | ||
output = torch.matmul(hidden, w.t()) | ||
return output | ||
|
||
|
||
def xformers_leading(input_, w1, w2, *, fuse, group): | ||
return sequence_parallel_leading_matmul( | ||
input_, [w1.t(), w2.t()], fuse=fuse, process_group=group | ||
) | ||
|
||
|
||
def xformers_trailing(hidden, w, *, fuse, group): | ||
return sequence_parallel_trailing_matmul( | ||
hidden, w.t(), fuse=fuse, process_group=group | ||
) | ||
|
||
|
||
def inner_seqpar( | ||
kind: str, | ||
step: str, | ||
dims: Tuple[int, ...], | ||
dtype: torch.dtype, | ||
seed: int, | ||
): | ||
my_rank = torch.distributed.get_rank() | ||
world_size = torch.distributed.get_world_size() | ||
subgroup = torch.distributed.new_group() | ||
|
||
fused = True | ||
if kind == "unfused": | ||
fused = False | ||
elif kind == "fallback": | ||
os.environ["DISABLE_FUSED_SEQUENCE_PARALLEL"] = "1" | ||
|
||
torch.random.manual_seed(seed) | ||
|
||
batch_dims = dims[:-2] | ||
outer_dim = dims[-2] | ||
inner_dim = dims[-1] | ||
|
||
# To check for correctness we want to compare the outputs but the accuracy | ||
# of matmuls, apparently, is not that great. We thus try to produce inputs | ||
# for which no rounding at all will occur. We do this by using zero or one | ||
# inputs, so their product will also be zero or one, and keep the reduction | ||
# dimension small enough so that they fit in the mantissa without overflow. | ||
max_exact_value = 2 * (1 / torch.finfo(dtype).eps) | ||
# 0.25 is the ratio of expected ones and we aim at 2/3 of the safe range | ||
assert outer_dim * 0.25 <= max_exact_value * 0.66 | ||
assert inner_dim * world_size * 0.25 <= max_exact_value * 0.66 | ||
|
||
def my_chunk(t, *, dim): | ||
return t.tensor_split(world_size, dim=dim)[my_rank] | ||
|
||
if step == "leading": | ||
input_ = torch.testing.make_tensor( | ||
batch_dims + (outer_dim,), | ||
dtype=dtype, | ||
device="cuda", | ||
low=0, | ||
high=1, | ||
).round() | ||
weight1, weight2 = [ | ||
torch.testing.make_tensor( | ||
(inner_dim, outer_dim), | ||
dtype=dtype, | ||
device="cuda", | ||
low=0, | ||
high=1, | ||
).round() | ||
for _ in range(2) | ||
] | ||
gradient1, gradient2 = [ | ||
torch.testing.make_tensor( | ||
batch_dims + (inner_dim,), | ||
dtype=dtype, | ||
device="cuda", | ||
low=0, | ||
high=1, | ||
).round() | ||
for _ in range(2) | ||
] | ||
|
||
# Non-fused reference code | ||
input_ref = input_.detach().requires_grad_() | ||
weight1_ref = weight1.detach().requires_grad_() | ||
weight2_ref = weight2.detach().requires_grad_() | ||
|
||
output1_ref, output2_ref = reference_leading( | ||
input_ref, weight1_ref, weight2_ref | ||
) | ||
torch.autograd.backward([output1_ref, output2_ref], [gradient1, gradient2]) | ||
|
||
my_output1_ref = my_chunk(output1_ref, dim=-1) | ||
my_output2_ref = my_chunk(output2_ref, dim=-1) | ||
my_weight1_grad_ref = my_chunk(weight1_ref.grad, dim=0) | ||
my_weight2_grad_ref = my_chunk(weight2_ref.grad, dim=0) | ||
my_input_grad_ref = my_chunk(input_ref.grad, dim=0) | ||
|
||
# Faster fused mode | ||
my_input_xf = my_chunk(input_, dim=0).detach().requires_grad_() | ||
my_weight1_xf = my_chunk(weight1, dim=0).detach().requires_grad_() | ||
my_weight2_xf = my_chunk(weight2, dim=0).detach().requires_grad_() | ||
my_gradient1 = my_chunk(gradient1, dim=-1) | ||
my_gradient2 = my_chunk(gradient2, dim=-1) | ||
|
||
my_output1_xf, my_output2_xf = xformers_leading( | ||
my_input_xf, my_weight1_xf, my_weight2_xf, fuse=fused, group=subgroup | ||
) | ||
torch.autograd.backward( | ||
[my_output1_xf, my_output2_xf], [my_gradient1, my_gradient2] | ||
) | ||
|
||
my_weight1_grad_xf = my_weight1_xf.grad | ||
my_weight2_grad_xf = my_weight2_xf.grad | ||
my_input_grad_xf = my_input_xf.grad | ||
|
||
# Checks | ||
torch.testing.assert_close(my_output1_ref, my_output1_xf) | ||
torch.testing.assert_close(my_output2_ref, my_output2_xf) | ||
torch.testing.assert_close(my_input_grad_ref, my_input_grad_xf) | ||
torch.testing.assert_close(my_weight1_grad_ref, my_weight1_grad_xf) | ||
torch.testing.assert_close(my_weight2_grad_ref, my_weight2_grad_xf) | ||
|
||
elif step == "trailing": | ||
input_ = torch.testing.make_tensor( | ||
batch_dims + (inner_dim,), | ||
dtype=dtype, | ||
device="cuda", | ||
low=0, | ||
high=1, | ||
).round() | ||
weight = torch.testing.make_tensor( | ||
(outer_dim, inner_dim), | ||
dtype=dtype, | ||
device="cuda", | ||
low=0, | ||
high=1, | ||
).round() | ||
gradient = torch.testing.make_tensor( | ||
batch_dims + (outer_dim,), | ||
dtype=dtype, | ||
device="cuda", | ||
low=0, | ||
high=1, | ||
).round() | ||
|
||
# Non-fused reference code | ||
input_ref = input_.detach().requires_grad_() | ||
weight_ref = weight.detach().requires_grad_() | ||
|
||
output_ref = reference_trailing(input_ref, weight_ref) | ||
torch.autograd.backward([output_ref], [gradient]) | ||
|
||
my_output_ref = my_chunk(output_ref, dim=0) | ||
my_weight_grad_ref = my_chunk(weight_ref.grad, dim=1) | ||
my_input_grad_ref = my_chunk(input_ref.grad, dim=-1) | ||
|
||
# Faster fused mode | ||
my_input_xf = my_chunk(input_, dim=-1).detach().clone().requires_grad_() | ||
my_weight_xf = my_chunk(weight, dim=1).detach().requires_grad_() | ||
my_gradient = my_chunk(gradient, dim=0) | ||
|
||
my_output_xf = xformers_trailing( | ||
my_input_xf, my_weight_xf, fuse=fused, group=subgroup | ||
) | ||
torch.autograd.backward([my_output_xf], [my_gradient]) | ||
|
||
my_weight_grad_xf = my_weight_xf.grad | ||
my_input_grad_xf = my_input_xf.grad | ||
|
||
# Checks | ||
torch.testing.assert_close(my_output_ref, my_output_xf) | ||
torch.testing.assert_close(my_input_grad_ref, my_input_grad_xf) | ||
torch.testing.assert_close(my_weight_grad_ref, my_weight_grad_xf) | ||
|
||
|
||
@cuda_sm70_only | ||
@pytest.mark.parametrize( | ||
"kind", | ||
[ | ||
"singleton", | ||
pytest.param("unfused", marks=at_least_2_gpus), | ||
pytest.param("fallback", marks=at_least_2_gpus), | ||
"fused", | ||
], | ||
) | ||
@pytest.mark.parametrize( | ||
"step", | ||
[ | ||
"leading", | ||
"trailing", | ||
], | ||
) | ||
@pytest.mark.parametrize( | ||
"dims", | ||
[ | ||
pytest.param((2, 2, 512, 512, 256), id="nice-shapes"), | ||
pytest.param((2, 1023, 511, 257), id="ugly-shapes"), | ||
], | ||
) | ||
@pytest.mark.parametrize( | ||
"dtype", | ||
[ | ||
pytest.param(torch.bfloat16, id="bf16"), | ||
pytest.param(torch.float16, id="fp16"), | ||
pytest.param(torch.float32, id="fp32"), | ||
], | ||
) | ||
def test_seqpar( | ||
kind: str, | ||
step: str, | ||
dims: Tuple[int, ...], | ||
dtype: torch.dtype, | ||
): | ||
world_size = 1 if kind == "singleton" else 2 | ||
seed = random.getrandbits(32) | ||
launch_subprocesses( | ||
world_size=world_size, | ||
fn=inner_seqpar, | ||
kind=kind, | ||
step=step, | ||
dims=dims, | ||
dtype=dtype, | ||
seed=seed, | ||
) |
Oops, something went wrong.