diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7873138d83..8843d3114f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
 ## [0.0.24] - TBD
+### Added
+- Added components for model/sequence parallelism, as near-drop-in replacements for FairScale/Megatron's ColumnParallelLinear and RowParallelLinear modules. They support fusing communication and computation for sequence parallelism, thus making the communication effectively free.
 
 ## [0.0.23] - 2023-12-05
 Pre-built binary wheels require PyTorch 2.1.1
diff --git a/tests/multiprocessing_utils.py b/tests/multiprocessing_utils.py
new file mode 100644
index 0000000000..0036dd25e0
--- /dev/null
+++ b/tests/multiprocessing_utils.py
@@ -0,0 +1,100 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+import concurrent.futures
+import multiprocessing
+import signal
+import tempfile
+from typing import List
+
+import torch
+
+
+class SafeMpContext:
+    def __init__(self) -> None:
+        self.mp_context = multiprocessing.get_context("spawn")
+        self.processes: List[multiprocessing.context.SpawnProcess] = []
+
+    def Process(self, *args, **kwargs) -> multiprocessing.context.SpawnProcess:
+        p = self.mp_context.Process(*args, **kwargs)
+        p.daemon = True
+        self.processes.append(p)
+        return p
+
+    def kill_all_processes(self):
+        for p in self.processes:
+            p.terminate()
+            p.join(1)
+            if p.exitcode is None:
+                p.kill()
+                p.join()
+
+    def log_bad_exit_codes(self):
+        for rank, p in enumerate(self.processes):
+            if p.exitcode == 0:
+                continue
+            if p.exitcode < 0:
+                try:
+                    signal_desc = f" (signal {signal.Signals(-p.exitcode).name})"
+                except ValueError:
+                    signal_desc = " (unrecognized signal)"
+            else:
+                signal_desc = ""
+            print(
+                f"Child process for rank #{rank} with PID {p.pid} exited with code {p.exitcode}{signal_desc}"
+            )
+
+    def __getattr__(self, name: str):
+        return getattr(self.mp_context, name)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.kill_all_processes()
+        self.log_bad_exit_codes()
+
+
+def _launch_subprocesses_fn_wrapper(
+    init_method: str, rank: int, world_size: int, user_fn, args, kwargs
+):
+    torch._C._set_print_stack_traces_on_fatal_signal(True)
+
+    if torch.cuda.device_count() >= world_size:
+        backend = "nccl"
+        torch.cuda.set_device(rank)
+    else:
+        # Use Gloo instead of NCCL so that we can run on a single GPU
+        backend = "gloo"
+    torch.distributed.init_process_group(
+        backend=backend,
+        world_size=world_size,
+        rank=rank,
+        init_method=init_method,
+    )
+    return user_fn(*args, **kwargs)
+
+
+def launch_subprocesses(world_size: int, fn, *args, **kwargs):
+    with SafeMpContext() as mp_context, concurrent.futures.ProcessPoolExecutor(
+        max_workers=world_size, mp_context=mp_context
+    ) as e, tempfile.NamedTemporaryFile(mode="w+b", buffering=-1, delete=True) as rdv:
+        futures = [
+            e.submit(
+                _launch_subprocesses_fn_wrapper,
+                init_method=f"file://{rdv.name}",
+                rank=rank,
+                world_size=world_size,
+                user_fn=fn,
+                args=args,
+                kwargs=kwargs,
+            )
+            for rank in range(world_size)
+        ]
+        done, _ = concurrent.futures.wait(
+            futures, return_when=concurrent.futures.FIRST_EXCEPTION
+        )
+        for f in done:
+            f.result()
diff --git a/tests/test_seqpar.py b/tests/test_seqpar.py new file mode 100644 index
0000000000..8cc535b115 --- /dev/null +++ b/tests/test_seqpar.py @@ -0,0 +1,261 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import os +import random +from typing import Tuple + +import pytest +import torch + +from xformers.ops import ( + sequence_parallel_leading_matmul, + sequence_parallel_trailing_matmul, +) + +from .multiprocessing_utils import launch_subprocesses + +compute_capability = (0, 0) +if torch.cuda.is_available(): + compute_capability = torch.cuda.get_device_capability("cuda") +cuda_sm70_only = pytest.mark.skipif( + compute_capability < (7, 0), reason="requires sm70+" +) +at_least_2_gpus = pytest.mark.skipif( + torch.cuda.device_count() < 2, reason="needs at least 2 GPUs" +) + + +def reference_leading(input_, w1, w2): + hidden1 = torch.matmul(input_, w1.t()) + hidden2 = torch.matmul(input_, w2.t()) + return [hidden1, hidden2] + + +def reference_trailing(hidden, w): + output = torch.matmul(hidden, w.t()) + return output + + +def xformers_leading(input_, w1, w2, *, fuse, group): + return sequence_parallel_leading_matmul( + input_, [w1.t(), w2.t()], fuse=fuse, process_group=group + ) + + +def xformers_trailing(hidden, w, *, fuse, group): + return sequence_parallel_trailing_matmul( + hidden, w.t(), fuse=fuse, process_group=group + ) + + +def inner_seqpar( + kind: str, + step: str, + dims: Tuple[int, ...], + dtype: torch.dtype, + seed: int, +): + my_rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + subgroup = torch.distributed.new_group() + + fused = True + if kind == "unfused": + fused = False + elif kind == "fallback": + os.environ["DISABLE_FUSED_SEQUENCE_PARALLEL"] = "1" + + torch.random.manual_seed(seed) + + batch_dims = dims[:-2] + outer_dim = dims[-2] + inner_dim = dims[-1] + + # To check for correctness we want to compare the outputs but the accuracy + # of matmuls, apparently, is not that great. We thus try to produce inputs + # for which no rounding at all will occur. We do this by using zero or one + # inputs, so their product will also be zero or one, and keep the reduction + # dimension small enough so that they fit in the mantissa without overflow. 
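+    # Concretely, 2 * (1 / eps) below equals 2**(significand bits) of the dtype,
+    # i.e. the largest power of two up to which every integer is exactly
+    # representable: 256 for bfloat16, 2048 for float16, 2**24 for float32.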
+ max_exact_value = 2 * (1 / torch.finfo(dtype).eps) + # 0.25 is the ratio of expected ones and we aim at 2/3 of the safe range + assert outer_dim * 0.25 <= max_exact_value * 0.66 + assert inner_dim * world_size * 0.25 <= max_exact_value * 0.66 + + def my_chunk(t, *, dim): + return t.tensor_split(world_size, dim=dim)[my_rank] + + if step == "leading": + input_ = torch.testing.make_tensor( + batch_dims + (outer_dim,), + dtype=dtype, + device="cuda", + low=0, + high=1, + ).round() + weight1, weight2 = [ + torch.testing.make_tensor( + (inner_dim, outer_dim), + dtype=dtype, + device="cuda", + low=0, + high=1, + ).round() + for _ in range(2) + ] + gradient1, gradient2 = [ + torch.testing.make_tensor( + batch_dims + (inner_dim,), + dtype=dtype, + device="cuda", + low=0, + high=1, + ).round() + for _ in range(2) + ] + + # Non-fused reference code + input_ref = input_.detach().requires_grad_() + weight1_ref = weight1.detach().requires_grad_() + weight2_ref = weight2.detach().requires_grad_() + + output1_ref, output2_ref = reference_leading( + input_ref, weight1_ref, weight2_ref + ) + torch.autograd.backward([output1_ref, output2_ref], [gradient1, gradient2]) + + my_output1_ref = my_chunk(output1_ref, dim=-1) + my_output2_ref = my_chunk(output2_ref, dim=-1) + my_weight1_grad_ref = my_chunk(weight1_ref.grad, dim=0) + my_weight2_grad_ref = my_chunk(weight2_ref.grad, dim=0) + my_input_grad_ref = my_chunk(input_ref.grad, dim=0) + + # Faster fused mode + my_input_xf = my_chunk(input_, dim=0).detach().requires_grad_() + my_weight1_xf = my_chunk(weight1, dim=0).detach().requires_grad_() + my_weight2_xf = my_chunk(weight2, dim=0).detach().requires_grad_() + my_gradient1 = my_chunk(gradient1, dim=-1) + my_gradient2 = my_chunk(gradient2, dim=-1) + + my_output1_xf, my_output2_xf = xformers_leading( + my_input_xf, my_weight1_xf, my_weight2_xf, fuse=fused, group=subgroup + ) + torch.autograd.backward( + [my_output1_xf, my_output2_xf], [my_gradient1, my_gradient2] + ) + + my_weight1_grad_xf = my_weight1_xf.grad + my_weight2_grad_xf = my_weight2_xf.grad + my_input_grad_xf = my_input_xf.grad + + # Checks + torch.testing.assert_close(my_output1_ref, my_output1_xf) + torch.testing.assert_close(my_output2_ref, my_output2_xf) + torch.testing.assert_close(my_input_grad_ref, my_input_grad_xf) + torch.testing.assert_close(my_weight1_grad_ref, my_weight1_grad_xf) + torch.testing.assert_close(my_weight2_grad_ref, my_weight2_grad_xf) + + elif step == "trailing": + input_ = torch.testing.make_tensor( + batch_dims + (inner_dim,), + dtype=dtype, + device="cuda", + low=0, + high=1, + ).round() + weight = torch.testing.make_tensor( + (outer_dim, inner_dim), + dtype=dtype, + device="cuda", + low=0, + high=1, + ).round() + gradient = torch.testing.make_tensor( + batch_dims + (outer_dim,), + dtype=dtype, + device="cuda", + low=0, + high=1, + ).round() + + # Non-fused reference code + input_ref = input_.detach().requires_grad_() + weight_ref = weight.detach().requires_grad_() + + output_ref = reference_trailing(input_ref, weight_ref) + torch.autograd.backward([output_ref], [gradient]) + + my_output_ref = my_chunk(output_ref, dim=0) + my_weight_grad_ref = my_chunk(weight_ref.grad, dim=1) + my_input_grad_ref = my_chunk(input_ref.grad, dim=-1) + + # Faster fused mode + my_input_xf = my_chunk(input_, dim=-1).detach().clone().requires_grad_() + my_weight_xf = my_chunk(weight, dim=1).detach().requires_grad_() + my_gradient = my_chunk(gradient, dim=0) + + my_output_xf = xformers_trailing( + my_input_xf, my_weight_xf, fuse=fused, 
group=subgroup + ) + torch.autograd.backward([my_output_xf], [my_gradient]) + + my_weight_grad_xf = my_weight_xf.grad + my_input_grad_xf = my_input_xf.grad + + # Checks + torch.testing.assert_close(my_output_ref, my_output_xf) + torch.testing.assert_close(my_input_grad_ref, my_input_grad_xf) + torch.testing.assert_close(my_weight_grad_ref, my_weight_grad_xf) + + +@cuda_sm70_only +@pytest.mark.parametrize( + "kind", + [ + "singleton", + pytest.param("unfused", marks=at_least_2_gpus), + pytest.param("fallback", marks=at_least_2_gpus), + "fused", + ], +) +@pytest.mark.parametrize( + "step", + [ + "leading", + "trailing", + ], +) +@pytest.mark.parametrize( + "dims", + [ + pytest.param((2, 2, 512, 512, 256), id="nice-shapes"), + pytest.param((2, 1023, 511, 257), id="ugly-shapes"), + ], +) +@pytest.mark.parametrize( + "dtype", + [ + pytest.param(torch.bfloat16, id="bf16"), + pytest.param(torch.float16, id="fp16"), + pytest.param(torch.float32, id="fp32"), + ], +) +def test_seqpar( + kind: str, + step: str, + dims: Tuple[int, ...], + dtype: torch.dtype, +): + world_size = 1 if kind == "singleton" else 2 + seed = random.getrandbits(32) + launch_subprocesses( + world_size=world_size, + fn=inner_seqpar, + kind=kind, + step=step, + dims=dims, + dtype=dtype, + seed=seed, + ) diff --git a/tests/test_sequence_parallel_fused_ops.py b/tests/test_sequence_parallel_fused_ops.py new file mode 100644 index 0000000000..8e6be08314 --- /dev/null +++ b/tests/test_sequence_parallel_fused_ops.py @@ -0,0 +1,156 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import os +import random +from typing import Tuple + +import pytest +import torch + +from xformers.ops import fused_allgather_and_linear, fused_linear_and_reducescatter + +from .multiprocessing_utils import launch_subprocesses + +compute_capability = (0, 0) +if torch.cuda.is_available(): + compute_capability = torch.cuda.get_device_capability("cuda") +cuda_sm70_only = pytest.mark.skipif( + compute_capability < (7, 0), reason="requires sm70+" +) +at_least_2_gpus = pytest.mark.skipif( + torch.cuda.device_count() < 2, reason="needs at least 2 GPUs" +) + + +def inner_sequence_parallel_fused( + seed: int, + kind: str, + step: str, + dims: Tuple[int, ...], + dtype: torch.dtype, +): + my_rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + subgroup = torch.distributed.new_group() + + triton = True + if kind == "fallback": + os.environ["DISABLE_FUSED_SEQUENCE_PARALLEL"] = "1" + elif kind == "pytorch": + triton = False + + torch.random.manual_seed(seed) + + batch_dims = dims[:-2] + subbatch_dims = (batch_dims[0] // world_size,) + batch_dims[1:] + outer_dim = dims[-2] + inner_dim = dims[-1] + + # To check for correctness we want to compare the outputs but the accuracy + # of matmuls, apparently, is not that great. We thus try to produce inputs + # for which no rounding at all will occur. We do this by using zero or one + # inputs, so their product will also be zero or one, and keep the reduction + # dimension small enough so that they fit in the mantissa without overflow. 
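+    # (The 0/1 inputs come from rounding uniform samples in [0, 1), so an input
+    # entry and a weight entry are both one about a quarter of the time; that is
+    # where the 0.25 factor in the checks below comes from.)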
+ max_exact_value = 2 * (1 / torch.finfo(dtype).eps) + # 0.25 is the ratio of expected ones and we aim at 2/3 of the safe range + assert outer_dim * 0.25 <= max_exact_value * 0.66 + assert inner_dim * world_size * 0.25 <= max_exact_value * 0.66 + + if step == "all-gather": + inputs = torch.testing.make_tensor( + (world_size,) + subbatch_dims + (outer_dim,), + dtype=dtype, + device="cuda", + low=0, + high=1, + ).round() + weight = torch.testing.make_tensor( + (inner_dim, outer_dim), dtype=dtype, device="cuda", low=0, high=1 + ).round() + + # Non-fused reference code + output_reference = torch.matmul(inputs, weight.t()).flatten(0, 1) + + # Faster fused mode + output_fused = fused_allgather_and_linear( + inputs[my_rank], weight, group=subgroup, _triton=triton + ) + + elif step == "reduce-scatter": + inputs = torch.testing.make_tensor( + (world_size,) + batch_dims + (inner_dim,), + dtype=dtype, + device="cuda", + low=0, + high=1, + ).round() + weights = torch.testing.make_tensor( + (world_size, outer_dim, inner_dim), + dtype=dtype, + device="cuda", + low=0, + high=1, + ).round() + + # Non-fused reference code + staging = torch.empty( + (world_size,) + subbatch_dims + (outer_dim,), dtype=dtype, device="cuda" + ) + for rank in range(world_size): + torch.matmul( + inputs[rank].tensor_split(world_size, dim=0)[my_rank], + weights[rank].t(), + out=staging[rank], + ) + output_reference = torch.sum(staging, dim=0, dtype=dtype) + + # Faster fused mode + output_fused = fused_linear_and_reducescatter( + inputs[my_rank], weights[my_rank], group=subgroup, _triton=triton + ) + + torch.testing.assert_close(output_reference, output_fused, atol=0, rtol=0) + + +@cuda_sm70_only +@pytest.mark.parametrize( + "kind", + ["singleton", pytest.param("fallback", marks=at_least_2_gpus), "pytorch", "triton"], +) +@pytest.mark.parametrize("step", ["all-gather", "reduce-scatter"]) +@pytest.mark.parametrize( + "dims", + [ + pytest.param((2, 2, 512, 512, 256), id="nice-shapes"), + pytest.param((2, 1023, 511, 257), id="ugly-shapes"), + ], +) +@pytest.mark.parametrize( + "dtype", + [ + pytest.param(torch.bfloat16, id="bf16"), + pytest.param(torch.float16, id="fp16"), + pytest.param(torch.float32, id="fp32"), + ], +) +def test_sequence_parallel_fused( + kind: str, + step: str, + dims: Tuple[int, ...], + dtype: torch.dtype, +): + world_size = 1 if kind == "singleton" else 2 + seed = random.getrandbits(32) + launch_subprocesses( + world_size, + inner_sequence_parallel_fused, + seed=seed, + kind=kind, + step=step, + dims=dims, + dtype=dtype, + ) diff --git a/tests/test_tiled_matmul.py b/tests/test_tiled_matmul.py new file mode 100644 index 0000000000..0313556747 --- /dev/null +++ b/tests/test_tiled_matmul.py @@ -0,0 +1,149 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import random + +import pytest +import torch + +from xformers import _is_triton_available +from xformers.ops.tiled_matmul import tiled_matmul + +cuda_only = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") +compute_capability = (0, 0) +if torch.cuda.is_available(): + compute_capability = torch.cuda.get_device_capability("cuda") +cuda_sm70_only = pytest.mark.skipif( + compute_capability < (7, 0), reason="requires sm70+" +) + +# We care about correctness, not performance, hence let's "disable" the +# expensive autotuning by removing all configs except one (the first one). 
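+# (With a single entry left in its `configs` list the autotuner has nothing to
+# benchmark and simply uses that config, so the tests skip the tuning sweep.)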
+if _is_triton_available(): + from xformers.ops._triton.tiled_matmul_kernels import _xformers_tiled_matmul_kernel + + while len(_xformers_tiled_matmul_kernel.configs) > 1: + _xformers_tiled_matmul_kernel.configs.pop() + + +def generate_test_shapes(*repeats, num_shapes=5): + shapes = [] + r = random.Random(0) + for repeat in repeats: + m_num_tiles, n_num_tiles, k_num_tiles = repeat + for _ in range(num_shapes): + shapes.append( + ( + [r.randint(2, 1024 // m_num_tiles) for _ in range(m_num_tiles)], + [r.randint(2, 1024 // n_num_tiles) for _ in range(n_num_tiles)], + [r.randint(2, 1024 // k_num_tiles) for _ in range(k_num_tiles)], + ) + ) + return shapes + + +_test_shapes = generate_test_shapes((1, 1, 1), (3, 3, 3)) +_dtypes = [torch.float32, torch.bfloat16, torch.float16] + + +def ceil_of_ratio(n, k): + return (n + k - 1) // k + + +def make_operands(m, n, k, *, dtype): + """Produce lhs, rhs and reference output tensors + + To dodge numerical accuracy differences between our kernels and PyTorch's + ones, we avoid random values and construct matrices whose product is an + exact mathematical computation, specifically: the remainder! + + We do it by having the i-th row of lhs and the j-th column on rhs be like: + * lhs: i times "1", followed by "0" + * rhs: j-1 times "1", followed by "-(j-1)", then repeated + The running sum of their pointwise product will thus be: + 1, 2, 3, ..., j-1, 0, 1, 2, 3, ... and so on + And the final value will be remainder of i by j. + + If K is smaller than M and/or N, this function also takes care of repeating + some rows and/or columns in order to "fill" M and/or K. Similarly, if the + precision of the dtype is too low to store the result without losses, the + function will only use small-enough values, and repeat them as needed. + + Finally, the function permutes the rows and columns, in order to avoid a + predictable block structure. + + """ + max_value = min(k, int(1 / torch.finfo(dtype).eps) * 2) + m_perm = torch.randperm(m) + n_perm = torch.randperm(n) + + num_reps_m = ceil_of_ratio(m, max_value) + lhs = ( + torch.ones((min(m, max_value), k), dtype=dtype) + .tril() + .repeat([num_reps_m, 1])[m_perm, :] + ) + assert lhs.shape == (m, k) + + num_reps_n = ceil_of_ratio(n, max_value) + rhs = torch.ones((k, min(n, max_value)), dtype=dtype) + for i in range(2, min(n, max_value) + 2): + rhs[:, i - 2][i - 1 :: i] = -i + 1 + rhs = rhs.repeat([1, num_reps_n])[:, n_perm] + assert rhs.shape == (k, n) + + lhs_idxs = torch.arange(1, min(m, max_value) + 1).repeat([num_reps_m])[m_perm, None] + rhs_idxs = torch.arange(2, min(n, max_value) + 2).repeat([num_reps_n])[None, n_perm] + out = torch.remainder(lhs_idxs, rhs_idxs).to(dtype) + assert out.shape == (m, n) + + return lhs, rhs, out + + +@cuda_only +@cuda_sm70_only +@pytest.mark.parametrize("shape", _test_shapes, ids=[str(x) for x in _test_shapes]) +@pytest.mark.parametrize("dtype", _dtypes, ids=[str(x) for x in _dtypes]) +def test_forward_backward( + shape, + dtype, +): + m_tiles, n_tiles, k_tiles = shape + m, n, k = sum(m_tiles), sum(n_tiles), sum(k_tiles) + + torch.manual_seed(m * n * k) + + a, b, c_reference = make_operands(m, n, k, dtype=dtype) + a = a.cuda().requires_grad_() + b = b.cuda().requires_grad_() + c_reference = c_reference.cuda() + + # In one operand make each tile have its own strides, in the other use the + # same stride for all tiles. And make the two operands have the stride==1 + # in different dimensions. 
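+    # `.t().clone().t()` materializes each tile of `a` as a column-major copy
+    # (stride 1 along dim 0, with a different row stride per tile), while the
+    # tiles of `b` remain row-major views sharing their parent's strides.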
+ a_tiled = [ + [y.t().clone().t() for y in x.split(k_tiles, dim=1)] + for x in a.split(m_tiles, dim=0) + ] + b_tiled = [[y for y in x.split(n_tiles, dim=1)] for x in b.split(k_tiles, dim=0)] + + c_test_tiled = tiled_matmul(a_tiled, b_tiled) + c_test = torch.cat([torch.cat(x, dim=1) for x in c_test_tiled], dim=0) + + torch.testing.assert_close(c_test, c_reference) + + # To avoid numerical issues in the backward, set things up so that we only + # multiply by a diagonal matrix whose entries are +/- 2^{-1/0/+1} (so that + # it only changes the sign bit and the exponent). + diag = torch.tensor(random.choices([-2, -1, -0.5, 0.5, 1, 2], k=min(m, n))) + grad_c = torch.zeros_like(c_test) + torch.diag(grad_c)[:] = diag + grad_a_reference = torch.matmul(grad_c, b.detach().t()) + grad_b_reference = torch.matmul(a.detach().t(), grad_c) + + torch.autograd.backward([c_test], [grad_c], inputs=[a, b]) + + torch.testing.assert_close(a.grad, grad_a_reference) + torch.testing.assert_close(b.grad, grad_b_reference) diff --git a/xformers/benchmarks/benchmark_sequence_parallel_fused.py b/xformers/benchmarks/benchmark_sequence_parallel_fused.py new file mode 100644 index 0000000000..c3889c038d --- /dev/null +++ b/xformers/benchmarks/benchmark_sequence_parallel_fused.py @@ -0,0 +1,479 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import contextlib +import dataclasses +import enum +import multiprocessing +import os +import random +from collections import deque +from statistics import mean, stdev +from typing import Callable + +import torch + +# torch._C._set_print_stack_traces_on_fatal_signal(True) + + +@dataclasses.dataclass +class Scenario: + # The number of tokens, i.e., the batch size times the sequence length + num_samples: int + # The per-sample features outside of the MHA/FFN block, and inside of it + outer_dim: int + inner_dim: int + # Simulate this many matmuls during the all-gather step + num_ag_matrices: int + + +class Step(enum.Enum): + AllGather = "ag" + ReduceScatter = "rs" + + def __str__(self): + return self.value + + +@dataclasses.dataclass +class Bench: + ag: Callable[[], None] + rs: Callable[[], None] + + def __getitem__(self, step: Step): + if step is Step.AllGather: + return self.ag + elif step is Step.ReduceScatter: + return self.rs + else: + raise KeyError(f"{step}") + + +LLAMA_07B_SLEN = 4096 +LLAMA_07B_D = 4096 + +LLAMA_70B_SLEN = 2048 +LLAMA_70B_D = 8192 + + +def round_up_to_nearest_multiple(n: int, m: int) -> int: + return m * ((n + m - 1) // m) + + +def llama_07B_MHA(world_size: int) -> Scenario: + batch_size = 8 + return Scenario( + num_samples=batch_size * LLAMA_07B_SLEN, + outer_dim=LLAMA_07B_D, + inner_dim=LLAMA_07B_D // world_size, + num_ag_matrices=3, + ) + + +def llama_07B_FFN(world_size: int) -> Scenario: + batch_size = 8 + return Scenario( + num_samples=batch_size * LLAMA_07B_SLEN, + outer_dim=LLAMA_07B_D, + inner_dim=round_up_to_nearest_multiple(2 * (4 * LLAMA_07B_D) // 3, 256) + // world_size, + num_ag_matrices=2, + ) + + +def llama_70B_MHA(world_size: int) -> Scenario: + batch_size = world_size + return Scenario( + num_samples=batch_size * LLAMA_70B_SLEN, + outer_dim=LLAMA_70B_D, + inner_dim=LLAMA_70B_D // world_size, + num_ag_matrices=3, + ) + + +def llama_70B_FFN(world_size: int) -> Scenario: + batch_size = world_size + return Scenario( + num_samples=batch_size * LLAMA_70B_SLEN, + 
outer_dim=LLAMA_70B_D, + inner_dim=round_up_to_nearest_multiple(2 * (4 * LLAMA_70B_D) // 3, 256) + // world_size, + num_ag_matrices=2, + ) + + +SCENARIOS = { + "llama_07B_MHA": llama_07B_MHA, + "llama_07B_FFN": llama_07B_FFN, + "llama_70B_MHA": llama_70B_MHA, + "llama_70B_FFN": llama_70B_FFN, +} + +DTYPES = { + "bfloat16": torch.bfloat16, +} + + +def run_one_rank( + my_rank, + world_size, + scenario_name, + step, + dtype_str, + num_rounds, + num_warmup_iters, + num_bench_iters, + profile, + conn_from_prev, + conn_to_next, +): + print(f"RANK {my_rank} started") + + torch.cuda.set_device(my_rank) + my_device = torch.device(f"cuda:{my_rank}") + + os.environ["RANK"] = f"{my_rank}" + os.environ["WORLD_SIZE"] = f"{world_size}" + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "29500" + torch.distributed.init_process_group(backend="nccl", init_method="env://") + + subgroup = torch.distributed.new_group() + subgroup_nowait = torch.distributed.new_group() + subgroup_nowait_nomemcpy = torch.distributed.new_group() + + scenario = SCENARIOS[scenario_name](world_size) + if step is Step.AllGather: + M = scenario.num_samples + N = scenario.inner_dim + K = scenario.outer_dim + num_matrices = scenario.num_ag_matrices + elif step is Step.ReduceScatter: + M = scenario.num_samples + N = scenario.outer_dim + K = scenario.inner_dim + num_matrices = 1 + + dtype = DTYPES[dtype_str] + + scattered_input = torch.randn((M // world_size, K), dtype=dtype, device=my_device) + gathered_input = torch.randn((M, K), dtype=dtype, device=my_device) + weights = [ + torch.randn((K, N), dtype=dtype, device=my_device) for _ in range(num_matrices) + ] + gathered_outputs = [ + torch.randn((M, N), dtype=dtype, device=my_device) for _ in range(num_matrices) + ] + scattered_outputs = [ + torch.randn((M // world_size, N), dtype=dtype, device=my_device) + for _ in range(num_matrices) + ] + + gathered_outputs_nccl_reference = [ + torch.randn((M, N), dtype=dtype, device=my_device) for _ in range(num_matrices) + ] + gathered_outputs_fused = [ + torch.randn((M, N), dtype=dtype, device=my_device) for _ in range(num_matrices) + ] + scattered_outputs_nccl_reference = [ + torch.randn((M // world_size, N), dtype=dtype, device=my_device) + for _ in range(num_matrices) + ] + scattered_outputs_fused = [ + torch.randn((M // world_size, N), dtype=dtype, device=my_device) + for _ in range(num_matrices) + ] + + def run_compute_lower_bound_ag(): + for w, go in zip(weights, gathered_outputs): + torch.matmul(gathered_input, w, out=go) + + def run_compute_lower_bound_rs(): + for w, go, so in zip(weights, gathered_outputs, scattered_outputs): + torch.matmul(gathered_input, w, out=go) + torch.sum(go.view((world_size, M // world_size, N)), dim=0, out=so) + + def run_comms_lower_bound_ag(): + torch.distributed.all_gather_into_tensor(gathered_input, scattered_input) + + def run_comms_lower_bound_rs(): + for so, go in zip(scattered_outputs, gathered_outputs): + torch.distributed.reduce_scatter_tensor(so, go) + + def run_nccl_reference_ag(): + torch.distributed.all_gather_into_tensor(gathered_input, scattered_input) + for w, go in zip(weights, gathered_outputs_nccl_reference): + torch.matmul(gathered_input, w, out=go) + + def run_nccl_reference_rs(): + for w, go, so in zip( + weights, gathered_outputs, scattered_outputs_nccl_reference + ): + torch.matmul(gathered_input, w, out=go) + torch.distributed.reduce_scatter_tensor(so, go) + + def run_fused_ag(): + nonlocal gathered_outputs_fused + from xformers.ops import fused_allgather_and_linear + 
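+        # The benchmark stores the weights as (K, N), so they are transposed
+        # below into the (out_features, in_features) layout these ops take (the
+        # tests construct their weights in that layout directly).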
+ gathered_outputs_fused = fused_allgather_and_linear( + scattered_input, + [w.t() for w in weights], + group=subgroup, + num_stripes=2, + timeout_s=10, + ) + + def run_fused_rs(): + nonlocal scattered_outputs_fused + from xformers.ops import fused_linear_and_reducescatter + + scattered_outputs_fused = fused_linear_and_reducescatter( + gathered_input, + [w.t() for w in weights], + group=subgroup, + num_stripes=2, + timeout_s=10, + ) + + def run_fused_nowait_ag(): + nonlocal gathered_outputs_fused + from xformers.ops import fused_allgather_and_linear + + gathered_outputs_fused = fused_allgather_and_linear( + scattered_input, + [w.t() for w in weights], + group=subgroup_nowait, + num_stripes=2, + _wait=False, + timeout_s=10, + ) + + def run_fused_nowait_rs(): + nonlocal scattered_outputs_fused + from xformers.ops import fused_linear_and_reducescatter + + scattered_outputs_fused = fused_linear_and_reducescatter( + gathered_input, + [w.t() for w in weights], + group=subgroup_nowait, + num_stripes=2, + _wait=False, + timeout_s=10, + ) + + def run_fused_nowait_nomemcpy_ag(): + nonlocal gathered_outputs_fused + from xformers.ops import fused_allgather_and_linear + + gathered_outputs_fused = fused_allgather_and_linear( + scattered_input, + [w.t() for w in weights], + group=subgroup_nowait_nomemcpy, + num_stripes=2, + _wait=False, + _memcpy=False, + timeout_s=10, + ) + + def run_fused_nowait_nomemcpy_rs(): + nonlocal scattered_outputs_fused + from xformers.ops import fused_linear_and_reducescatter + + scattered_outputs_fused = fused_linear_and_reducescatter( + gathered_input, + [w.t() for w in weights], + group=subgroup_nowait_nomemcpy, + num_stripes=2, + _wait=False, + _memcpy=False, + timeout_s=10, + ) + + print(f"Sizes: ({world_size}x{M // world_size})x({num_matrices}x{N})x{K}") + + if step is Step.AllGather: + run_nccl_reference_ag() + run_fused_ag() + if my_rank == 0: + print("fused:") + print( + "Are equal? " + + " ".join( + str(torch.equal(ref, fus)) + for ref, fus in zip( + gathered_outputs_nccl_reference, gathered_outputs_fused + ) + ) + ) + print( + "Are allclose? " + + " ".join( + str(torch.allclose(ref, fus)) + for ref, fus in zip( + gathered_outputs_nccl_reference, gathered_outputs_fused + ) + ) + ) + + elif step is Step.ReduceScatter: + run_nccl_reference_rs() + run_fused_rs() + if my_rank == 0: + print("fused:") + print( + "Are equal? " + + " ".join( + str(torch.equal(ref, fus)) + for ref, fus in zip( + scattered_outputs_nccl_reference, scattered_outputs_fused + ) + ) + ) + print( + "Are allclose? " + + " ".join( + str(torch.allclose(ref, fus)) + for ref, fus in zip( + scattered_outputs_nccl_reference, scattered_outputs_fused + ) + ) + ) + + # The above checks might still return False for, e.g., bfloat16 because they + # have too little tolerance for its lower precision. This method, OTOH, uses + # variable tolerances based on dtype. 
+ # for ref, fus in zip(gathered_outputs_nccl_reference, gathered_outputs_fused): + # torch.testing.assert_close(ref, fus) + # for ref, fus in zip(scattered_outputs_nccl_reference, scattered_outputs_fused): + # torch.testing.assert_close(ref, fus) + + all_benchs = { + "compute_lower_bound": Bench( + ag=run_compute_lower_bound_ag, rs=run_compute_lower_bound_rs + ), + "comms_lower_bound": Bench( + ag=run_comms_lower_bound_ag, rs=run_comms_lower_bound_rs + ), + "nccl_reference": Bench(ag=run_nccl_reference_ag, rs=run_nccl_reference_rs), + "fused": Bench(ag=run_fused_ag, rs=run_fused_rs), + "fused_nowait": Bench(ag=run_fused_nowait_ag, rs=run_fused_nowait_rs), + "fused_nowait_nomemcpy": Bench( + ag=run_fused_nowait_nomemcpy_ag, rs=run_fused_nowait_nomemcpy_rs + ), + } + + unused_events = deque( + tuple(torch.cuda.Event(enable_timing=my_rank == 0) for _ in range(2)) + for f in range(len(all_benchs)) + ) + used_events = deque() + + timings = {} + + gen = random.Random(42) + + if profile: + profiler = torch.profiler.profile() + else: + profiler = contextlib.nullcontext() + + with profiler as p: + for method in gen.sample( + list(all_benchs), + k=num_rounds * len(all_benchs), + counts=[num_rounds] * len(all_benchs), + ): + fun = all_benchs[method][step] + + if unused_events: + start_ev, end_ev = unused_events.popleft() + else: + old_method, start_ev, end_ev = used_events.popleft() + end_ev.synchronize() + if my_rank == 0: + timings.setdefault(old_method, []).append( + start_ev.elapsed_time(end_ev) / num_bench_iters + ) + + for _ in range(num_warmup_iters): + fun() + start_ev.record() + for _ in range(num_bench_iters): + fun() + end_ev.record() + + used_events.append((method, start_ev, end_ev)) + + torch.cuda.synchronize() + + if profile: + p.export_chrome_trace(f"fusion_trace_{my_rank}.json") + + if my_rank == 0: + for method, start_ev, end_ev in used_events: + timings.setdefault(method, []).append( + start_ev.elapsed_time(end_ev) / num_bench_iters + ) + + for method in all_benchs: + print( + f"{method} = {mean(timings[method]):g}ms (+/- {stdev(timings[method]):g})" + ) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("scenario", choices=SCENARIOS.keys()) + parser.add_argument("step", choices=list(Step), type=Step) + parser.add_argument("--world-size", type=int, default=8) + parser.add_argument("--dtype", choices=DTYPES.keys(), default="bfloat16") + parser.add_argument("--num-rounds", type=int, default=20) + parser.add_argument("--num-warmup-iters", type=int, default=5) + parser.add_argument("--num-bench-iters", type=int, default=50) + parser.add_argument("--profile", action="store_true") + args = parser.parse_args() + + conns_from_prev = [None] * args.world_size + conns_to_next = [None] * args.world_size + for rank in range(args.world_size): + end1, end2 = multiprocessing.get_context("spawn").Pipe(duplex=True) + conns_to_next[rank] = end1 + conns_from_prev[(rank + 1) % args.world_size] = end2 + + processes = [] + for rank in range(args.world_size): + p = multiprocessing.get_context("spawn").Process( + target=run_one_rank, + args=( + rank, + args.world_size, + args.scenario, + args.step, + args.dtype, + args.num_rounds, + args.num_warmup_iters, + args.num_bench_iters, + args.profile, + conns_from_prev[rank], + conns_to_next[rank], + ), + daemon=True, + ) + p.start() + processes.append(p) + + print("LAUNCHED") + + for rank, p in enumerate(processes): + p.join() + print(f"Rank {rank} exited with {p.exitcode}") + + print("JOINED") + + +if __name__ == "__main__": + main() 
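For orientation, the reference paths that the benchmark and tests compare the fused ops against boil down to a collective plus a plain matmul. Below is a minimal sketch of that equivalence; the helper names are illustrative, while the real entry points are fused_allgather_and_linear and fused_linear_and_reducescatter from xformers.ops.

import torch
import torch.distributed as dist


def allgather_then_linear_reference(x_shard: torch.Tensor, w: torch.Tensor, group) -> torch.Tensor:
    # What fused_allgather_and_linear(x_shard, w, group=group) is checked against:
    # gather the row-sharded input from all ranks, then apply an nn.Linear-style
    # weight of shape (out_features, in_features).
    world_size = dist.get_world_size(group)
    x = x_shard.new_empty((world_size * x_shard.shape[0],) + x_shard.shape[1:])
    dist.all_gather_into_tensor(x, x_shard, group=group)
    return x @ w.t()


def linear_then_reducescatter_reference(x: torch.Tensor, w: torch.Tensor, group) -> torch.Tensor:
    # What fused_linear_and_reducescatter(x, w, group=group) is checked against:
    # every rank computes its partial matmul, then the partials are summed and
    # each rank keeps its own row-slice of the result.
    world_size = dist.get_world_size(group)
    y = x @ w.t()
    out = y.new_empty((y.shape[0] // world_size,) + y.shape[1:])
    dist.reduce_scatter_tensor(out, y, group=group)
    return out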
diff --git a/xformers/benchmarks/benchmark_tiled_matmul.py b/xformers/benchmarks/benchmark_tiled_matmul.py new file mode 100644 index 0000000000..f584d1bcbb --- /dev/null +++ b/xformers/benchmarks/benchmark_tiled_matmul.py @@ -0,0 +1,135 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import itertools + +import torch +from torch.utils import benchmark +from triton.ops.matmul import matmul as triton_matmul + +from xformers.benchmarks.utils import DTYPE2STR, benchmark_main_helper +from xformers.ops.tiled_matmul import tiled_matmul + +min_run_time = 5 + + +SHAPES = { + "llama1_65b_mha_fwd": ([16384], [1024] * 3, [8192]), + "llama1_65b_mha_bwd_input": ([16384], [8192], [1024] * 3), + "llama1_65b_mha_bwd_weight": ([8192], [1024] * 3, [16384]), + "llama1_65b_ffn_fwd": ([16384], [2752] * 2, [8192]), + "llama1_65b_ffn_bwd_input": ([16384], [8192], [2752] * 2), + "llama1_65b_ffn_bwd_weight": ([8192], [2752] * 2, [16384]), + "llama2_150b_mha_fwd": ([16384], [1536, 128, 128], [12288]), + "llama2_150b_mha_bwd_input": ([16384], [12288], [1536, 128, 128]), + "llama2_150b_mha_bwd_weight": ([12288], [1536, 128, 128], [16384]), + "llama2_150b_ffn_fwd": ([16384], [4096] * 2, [12288]), + "llama2_150b_ffn_bwd_input": ([16384], [12288], [4096] * 2), + "llama2_150b_ffn_bwd_weight": ([12288], [4096] * 2, [16384]), +} + + +def product_dict(**kwargs): + keys = kwargs.keys() + vals = kwargs.values() + for instance in itertools.product(*vals): + yield dict(zip(keys, instance)) + + +CASES = list( + product_dict( + shape_name=SHAPES.keys(), + dtype=[ + # torch.float32, + torch.bfloat16, + # torch.float16, + ], + ) +) + + +def matmul_per_tile(a, b): + c = [] + for n in range(len(a)): + c.append([]) + for m in range(len(b[0])): + c[-1].append( + sum([torch.matmul(a[n][k], b[k][m]) for k in range(len(a[0]))]) + ) + return c + + +def benchmark_tiled_matmul(shape_name, dtype): + ms, ns, ks = SHAPES[shape_name] + m, n, k = sum(ms), sum(ns), sum(ks) + + a = torch.randn((m, k), device="cuda", dtype=dtype) + b = torch.randn((k, n), device="cuda", dtype=dtype) + + a_tiles = [[y.clone() for y in x.split(ks, dim=1)] for x in a.split(ms, dim=0)] + b_tiles = [[y.clone() for y in x.split(ns, dim=1)] for x in b.split(ks, dim=0)] + + dtype_str = DTYPE2STR.get(dtype, dtype) + sub_label = ( + f"{dtype_str} {shape_name} " + f"M={'+'.join(f'{m}' for m in ms)} " + f"N={'+'.join(f'{n}' for n in ns)} " + f"K={'+'.join(f'{k}' for k in ks)}" + ) + + # Warmup (maybe not needed?) 
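+    # (At minimum, the first call to the Triton-based paths triggers their
+    # autotuning for these shapes, so running them once here keeps that one-off
+    # cost out of the timed measurements.)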
+ torch.mm(a, b) + matmul_per_tile(a_tiles, b_tiles) + triton_matmul(a, b) + tiled_matmul(a_tiles, b_tiles) + + yield benchmark.Timer( + stmt="fn(a, b)", + globals={ + "a": a, + "b": b, + "fn": torch.mm, + }, + label="tiled_matmul", + description="pytorch_fused", + sub_label=sub_label, + ) + yield benchmark.Timer( + stmt="fn(a, b)", + globals={ + "a": a_tiles, + "b": b_tiles, + "fn": matmul_per_tile, + }, + label="tiled_matmul", + description="pytorch_tiled", + sub_label=sub_label, + ) + yield benchmark.Timer( + stmt="fn(a, b)", + globals={ + "a": a, + "b": b, + "fn": triton_matmul, + }, + label="tiled_matmul", + description="triton_fused", + sub_label=sub_label, + ) + yield benchmark.Timer( + stmt="fn(a, b)", + globals={ + "a": a_tiles, + "b": b_tiles, + "fn": tiled_matmul, + }, + label="tiled_matmul", + description="xformers_tiled", + sub_label=sub_label, + ) + + +benchmark_main_helper(benchmark_tiled_matmul, CASES, min_run_time=min_run_time) diff --git a/xformers/ops/__init__.py b/xformers/ops/__init__.py index 9986aa8a6f..ef92ada940 100644 --- a/xformers/ops/__init__.py +++ b/xformers/ops/__init__.py @@ -23,8 +23,16 @@ memory_efficient_attention_forward_requires_grad, ) from .indexing import index_select_cat, scaled_index_add +from .modpar_layers import ColumnParallelLinear, RowParallelLinear from .rmsnorm import RMSNorm from .rope_padded import rope_padded +from .seqpar import sequence_parallel_leading_matmul, sequence_parallel_trailing_matmul +from .sequence_parallel_fused_ops import ( + fused_allgather_and_anything, + fused_allgather_and_linear, + fused_anything_and_reducescatter, + fused_linear_and_reducescatter, +) from .swiglu_op import ( SwiGLU, SwiGLUEagerOp, @@ -34,6 +42,7 @@ SwiGLUPackedFusedOp, swiglu, ) +from .tiled_matmul import tiled_matmul from .unbind import get_stack_strides, stack_or_none, unbind # BW compatibility @@ -63,7 +72,7 @@ def masked_matmul(a, b, mask=None): __all__ = [ - "memory_efficient_attention", + # fmha "AttentionBias", "AttentionMask", "AttentionOp", @@ -75,10 +84,30 @@ def masked_matmul(a, b, mask=None): "MemoryEfficientAttentionFlashAttentionOp", "MemoryEfficientAttentionOp", "MemoryEfficientAttentionTritonFwdFlashBwOp", + "TritonFlashAttentionOp", + "memory_efficient_attention", "memory_efficient_attention_backward", "memory_efficient_attention_forward", "memory_efficient_attention_forward_requires_grad", + # indexing + "index_select_cat", + "scaled_index_add", + # modpar_layers + "ColumnParallelLinear", + "RowParallelLinear", + # rmsnorm "RMSNorm", + # rope_padded + "rope_padded", + # seqpar + "sequence_parallel_leading_matmul", + "sequence_parallel_trailing_matmul", + # sequence_parallel_fused_ops + "fused_allgather_and_anything", + "fused_allgather_and_linear", + "fused_anything_and_reducescatter", + "fused_linear_and_reducescatter", + # swiglu_op "SwiGLU", "SwiGLUEagerOp", "SwiGLUFusedOp", @@ -86,13 +115,12 @@ def masked_matmul(a, b, mask=None): "SwiGLUOpDispatch", "SwiGLUPackedFusedOp", "swiglu", - "TritonFlashAttentionOp", - "unbind", - "stack_or_none", + # tiled_matmul + "tiled_matmul", + # unbind "get_stack_strides", + "stack_or_none", + "unbind", + # . "masked_matmul", - "scaled_index_add", - "index_select_cat", - "rope_padded", - "attn_bias", ] diff --git a/xformers/ops/_triton/__init__.py b/xformers/ops/_triton/__init__.py index 254820b126..0a8ab8e0f6 100644 --- a/xformers/ops/_triton/__init__.py +++ b/xformers/ops/_triton/__init__.py @@ -4,6 +4,9 @@ # LICENSE file in the root directory of this source tree. 
+# One reason this module is called `_triton` instead of just `triton` is this: +# https://github.com/openai/triton/commit/c6040bcbd8a046785462481b2830b3fff5fc4aab + from typing import TYPE_CHECKING import xformers diff --git a/xformers/ops/_triton/sequence_parallel_fused_kernels.py b/xformers/ops/_triton/sequence_parallel_fused_kernels.py new file mode 100644 index 0000000000..15c559b78b --- /dev/null +++ b/xformers/ops/_triton/sequence_parallel_fused_kernels.py @@ -0,0 +1,643 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import itertools +from typing import List, Optional, Set, Tuple, cast + +import torch +import triton +import triton.language as tl +from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + + +def init_to_zero(*names): + def result(nargs): + if nargs["blocks_done_counters"].numel() > 0: + nargs["blocks_done_counters"].zero_() + for name in names: + nargs[name].zero_() + + return result + + +def gen_config( + block_m: int, + block_n: int, + block_k: int, + stages: int, + warps: int, + split_k: int = 1, + group_m: int = 8, +) -> triton.Config: + """A more compact way to define a triton.Config, so it fits on one line""" + + return triton.Config( + { + "BLOCK_M": block_m, + "BLOCK_N": block_n, + "BLOCK_K": block_k, + "SPLIT_K": split_k, + "GROUP_M": group_m, + }, + num_stages=stages, + num_warps=warps, + pre_hook=init_to_zero("C1", "C2", "C3") if split_k > 1 else init_to_zero(), + ) + + +BASIC_MATMUL_CONFIGS = [ + gen_config(block_m=128, block_n=256, block_k=32, stages=3, warps=8), + gen_config(block_m=256, block_n=128, block_k=32, stages=3, warps=8), + gen_config(block_m=256, block_n=64, block_k=32, stages=4, warps=4), + gen_config(block_m=64, block_n=256, block_k=32, stages=4, warps=4), + gen_config(block_m=128, block_n=128, block_k=32, stages=4, warps=4), + gen_config(block_m=128, block_n=64, block_k=32, stages=4, warps=4), + gen_config(block_m=64, block_n=128, block_k=32, stages=4, warps=4), + gen_config(block_m=128, block_n=32, block_k=32, stages=4, warps=4), + gen_config(block_m=64, block_n=32, block_k=32, stages=5, warps=2), +] + + +INT8_MATMUL_CONFIGS = [ + gen_config(block_m=128, block_n=256, block_k=128, stages=3, warps=8), + gen_config(block_m=256, block_n=128, block_k=128, stages=3, warps=8), + gen_config(block_m=256, block_n=64, block_k=128, stages=4, warps=4), + gen_config(block_m=64, block_n=256, block_k=128, stages=4, warps=4), + gen_config(block_m=128, block_n=128, block_k=128, stages=4, warps=4), + gen_config(block_m=128, block_n=64, block_k=64, stages=4, warps=4), + gen_config(block_m=64, block_n=128, block_k=64, stages=4, warps=4), + gen_config(block_m=128, block_n=32, block_k=64, stages=4, warps=4), + gen_config(block_m=64, block_n=32, block_k=64, stages=5, warps=2), +] + + +IO_BOUND_MATMUL_CONFIGS_STAGES = [2, 3, 4, 5, 6] +IO_BOUND_MATMUL_CONFIGS_BLOCK_M = [16, 32] +IO_BOUND_MATMUL_CONFIGS_BLOCK_K = [32, 64] +IO_BOUND_MATMUL_CONFIGS_BLOCK_N = [32, 64, 128, 256] +IO_BOUND_MATMUL_CONFIGS_SPLIT_K = [1, 2, 4, 8, 16] + + +IO_BOUND_MATMUL_CONFIGS = [ + gen_config( + block_m=block_m, + block_n=block_n, + block_k=block_k, + stages=stages, + warps=2 if block_n <= 64 else 4, + split_k=split_k, + ) + for stages, block_m, block_k, block_n, split_k in itertools.product( + IO_BOUND_MATMUL_CONFIGS_STAGES, + IO_BOUND_MATMUL_CONFIGS_BLOCK_M, + IO_BOUND_MATMUL_CONFIGS_BLOCK_K, + 
IO_BOUND_MATMUL_CONFIGS_BLOCK_N, + IO_BOUND_MATMUL_CONFIGS_SPLIT_K, + ) +] + + +TRITON_CONFIGS = BASIC_MATMUL_CONFIGS + INT8_MATMUL_CONFIGS + IO_BOUND_MATMUL_CONFIGS + + +BACKWARDS_WITH_ME_FIRST = 0 +FORWARDS_WITH_ME_LAST = 1 + +NUM_SPINS_BETWEEN_TIMEOUT_CHECKS = 1000 + + +@triton.jit +def determine_tile( + A, + B1, + B2, + B3, + C1, + C2, + C3, + A_my_shard, + C1_my_shard, + C2_my_shard, + C3_my_shard, + M, + N1, + N2, + N3, + my_rank, + world_size, + direction, + stride_am, + stride_cm, + BLOCK_M, + BLOCK_N, + GROUP_M, +): + # tl.device_assert(M % world_size == 0) + M_per_rank = M // world_size + # matrix multiplication + pid = tl.program_id(0) + grid_m_per_rank = tl.cdiv(M_per_rank, BLOCK_M) + grid_n1 = tl.cdiv(N1, BLOCK_N) + grid_n2 = tl.cdiv(N2, BLOCK_N) + grid_n3 = tl.cdiv(N3, BLOCK_N) + grid_n = grid_n1 + grid_n2 + grid_n3 + + # Blocks with lower pid will be executed first (this isn't a documented + # guarantee, but seems to happen in practice, and Triton already leverages + # it for its swizzling just below). We want the first blocks to operate on + # the local rank's shard, since it's immediately available, then once that's + # all done operate on the first remote contribution to arrive (the one from + # my_rank - 1), etc. Thus we change the pointers to A and C, and the value + # of pid, as needed to operate on the input in the order we want. + blocks_per_rank = grid_m_per_rank * grid_n + if direction == BACKWARDS_WITH_ME_FIRST: + other_rank = (my_rank - (pid // blocks_per_rank) + world_size) % world_size + else: # direction == FORWARDS_WITH_ME_LAST: + other_rank = (my_rank + (pid // blocks_per_rank + 1)) % world_size + pid = pid % blocks_per_rank + + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m_per_rank - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + + B = tl.where(pid_n < grid_n1, B1, tl.where(pid_n < grid_n1 + grid_n2, B2, B3)) + C = tl.where(pid_n < grid_n1, C1, tl.where(pid_n < grid_n1 + grid_n2, C2, C3)) + C_my_shard = tl.where( + pid_n < grid_n1, + C1_my_shard, + tl.where(pid_n < grid_n1 + grid_n2, C2_my_shard, C3_my_shard), + ) + N = tl.where(pid_n < grid_n1, N1, tl.where(pid_n < grid_n1 + grid_n2, N2, N3)) + pid_n = tl.where( + pid_n < grid_n1, + pid_n, + tl.where(pid_n < grid_n1 + grid_n2, pid_n - grid_n1, pid_n - grid_n1 - grid_n2), + ) + + A = tl.where( + other_rank == my_rank, A_my_shard, A + other_rank * M_per_rank * stride_am + ) + C = tl.where( + other_rank == my_rank, C_my_shard, C + other_rank * M_per_rank * stride_cm + ) + + return A, B, C, M_per_rank, N, pid_m, pid_n, other_rank, blocks_per_rank + + +@triton.jit +def wait_for_recv( + seq_num, + wait_counters, + other_rank, + my_rank, + stripe, + num_stripes, + _wait, + do_wait, + timeout_ns, +): + if (_wait and do_wait) and other_rank != my_rank: + wait_counter = wait_counters + other_rank * num_stripes + stripe + start_time_ns = tl.extra.cuda.globaltimer() + num_spins = 0 + # There's no atomic_load, hence we simulate it with a CAS. + while tl.atomic_cas(wait_counter, 0, 0) != seq_num: + num_spins += 1 + if num_spins == NUM_SPINS_BETWEEN_TIMEOUT_CHECKS: + if tl.extra.cuda.globaltimer() - start_time_ns > timeout_ns: + tl.device_assert( + False, + "xFormers's fused kernels for sequence parallelism " + "timed out waiting for a peer GPU. 
To prevent " + "downstream computations from operating on corrupted " + "data, we're bringing the CUDA context down with us.", + ) + num_spins = 0 + + +@triton.jit +def do_matmul( + A, + B, + C, + pid_m, + pid_n, + pid_z, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_M, + BLOCK_N, + BLOCK_K, + ACC_TYPE, + SPLIT_K, + EVEN_K, +): + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + k_remaining = K - k * (BLOCK_K * SPLIT_K) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.0) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.0) + acc += tl.dot(a, b) + A += BLOCK_K * SPLIT_K * stride_ak + B += BLOCK_K * SPLIT_K * stride_bk + acc = acc.to(C.dtype.element_ty) + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(C, acc, mask=mask) + else: + tl.atomic_add(C, acc, mask=mask) + + +@triton.jit +def trigger_send( + seq_num, + blocks_done_counters, + write_counters, + other_rank, + my_rank, + num_stripes, + stripe, + num_blocks_3d, + _wait, + do_write, +): + if (_wait and do_write) and other_rank != my_rank: + num_blocks_done = ( + tl.atomic_add( + blocks_done_counters + other_rank + tl.arange(0, 1), + 1, + sem="acq_rel", + ) + + 1 + ) + tl.atomic_xchg( + write_counters + other_rank * num_stripes + stripe + tl.arange(0, 1), + seq_num, + mask=num_blocks_done == num_blocks_3d, + sem="release", + ) + + +def our_estimate_matmul_time(B1, C1, N1, N2, N3, **kwargs): + """Call into Triton's upstream cost model, with the right args + + The upstream function expects arguments to have certain names. Since we + renamed a few of them in our implementation, we rename them back. + + At the time of writing (July 2023) the arguments that Triton expects are: + M, N, K, A, B, C, BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages. 
+ + """ + return estimate_matmul_time(N=N1 + N2 + N3, B=B1, C=C1, **kwargs) + + +@triton.autotune( + configs=TRITON_CONFIGS, + key=["M", "N1", "N2", "N3", "K"], + prune_configs_by={ + "early_config_prune": early_config_prune, + "perf_model": our_estimate_matmul_time, + "top_k": 10, + }, + reset_to_zero=["blocks_done_counters"], +) +@triton.heuristics( + { + "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, + } +) +@triton.jit( + do_not_specialize=[11, 12, 13, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34], + debug=True, # To avoid stripping device asserts +) +def _xformers_seqpar_matmul_kernel( + A_my_shard, + A, # DO NOT CHANGE NAME: MUST MATCH PERF MODEL + B1, + B2, + B3, + C1, + C2, + C3, + C1_my_shard, + C2_my_shard, + C3_my_shard, + wait_counters, + blocks_done_counters, + write_counters, + M, # DO NOT CHANGE NAME: MUST MATCH PERF MODEL + N1, + N2, + N3, + K, # DO NOT CHANGE NAME: MUST MATCH PERF MODEL + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + do_wait, + do_write, + direction, + stripe, + seq_num, + num_stripes, + _wait, + my_rank, + world_size, + timeout_ns, + BLOCK_M: tl.constexpr, # DO NOT CHANGE NAME: MUST MATCH PERF MODEL + BLOCK_N: tl.constexpr, # DO NOT CHANGE NAME: MUST MATCH PERF MODEL + BLOCK_K: tl.constexpr, # DO NOT CHANGE NAME: MUST MATCH PERF MODEL + GROUP_M: tl.constexpr, + SPLIT_K: tl.constexpr, # DO NOT CHANGE NAME: MUST MATCH PERF MODEL + EVEN_K: tl.constexpr, + ACC_TYPE: tl.constexpr, +): + A, B, C, M, N, pid_m, pid_n, other_rank, num_blocks_2d = determine_tile( + A, + B1, + B2, + B3, + C1, + C2, + C3, + A_my_shard, + C1_my_shard, + C2_my_shard, + C3_my_shard, + M, + N1, + N2, + N3, + my_rank, + world_size, + direction, + stride_am, + stride_cm, + BLOCK_M, + BLOCK_N, + GROUP_M, + ) + pid_z = tl.program_id(1) + + wait_for_recv( + seq_num, + wait_counters, + other_rank, + my_rank, + stripe, + num_stripes, + _wait, + do_wait, + timeout_ns, + ) + + do_matmul( + A, + B, + C, + pid_m, + pid_n, + pid_z, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_M, + BLOCK_N, + BLOCK_K, + ACC_TYPE, + SPLIT_K, + EVEN_K, + ) + + trigger_send( + seq_num, + blocks_done_counters, + write_counters, + other_rank, + my_rank, + num_stripes, + stripe, + num_blocks_2d * SPLIT_K, + _wait, + do_write, + ) + + +AUTOTUNED_SIZES: Set[Tuple[int, Tuple[int, ...], int, torch.dtype]] = set() + + +def common_alignment(*args): + for div in [16, 8, 4, 2]: + if all(a % div == 0 for a in args): + return div + return 1 + + +def _launch_triton_matmul( + a_my_shard: Optional[torch.Tensor], + a: torch.Tensor, + bs: List[torch.Tensor], + cs: List[torch.Tensor], + cs_my_shard: Optional[List[torch.Tensor]], + my_rank: int, + world_size: int, + wait_counters: Optional[torch.Tensor], + write_counters: Optional[torch.Tensor], + direction: int, + stripe: int, + seq_num: int, + num_stripes: int, + timeout_s: int, + _wait: bool = True, +) -> None: + # checks constraints + assert 0 <= my_rank < world_size + assert 0 <= stripe < num_stripes and 0 <= seq_num < 2**8 + assert direction in (BACKWARDS_WITH_ME_FIRST, FORWARDS_WITH_ME_LAST) + + assert len(bs) == len(cs) + assert a.ndim == 2 + assert all(b.ndim == 2 for b in bs) + assert all(c.ndim == 2 for c in cs) + M, K = a.shape + Ns = [b.shape[1] for b in bs] + assert all(b.shape[0] == K for b in bs) + assert all(c.shape[0] == M for c in cs) + assert all(c.shape[1] == N for c, N in zip(cs, Ns)) + stride_am, stride_ak = cast(Tuple[int, int], a.stride()) + stride_bk, 
stride_bn = cast(Tuple[int, int], bs[0].stride()) + stride_cm, stride_cn = cast(Tuple[int, int], cs[0].stride()) + assert all(b.stride() == (stride_bk, stride_bn) for b in bs) + assert all(c.stride() == (stride_cm, stride_cn) for c in cs) + assert stride_am == 1 or stride_ak == 1 + assert stride_bk == 1 or stride_bn == 1 + assert stride_cm == 1 or stride_cn == 1 + + if a_my_shard is not None: + assert a_my_shard.ndim == 2 + assert a_my_shard.shape[0] * world_size == a.shape[0] + assert a_my_shard.shape[1] == a.shape[1] + assert a_my_shard.stride() == a.stride() + else: + assert a.shape[0] % world_size == 0 + a_my_shard = a.tensor_split(world_size)[my_rank] + + if cs_my_shard is not None: + assert len(cs_my_shard) == len(cs) + assert all(c_my_shard.ndim == 2 for c_my_shard in cs_my_shard) + assert all( + c_my_shard.shape[0] * world_size == c.shape[0] + for c, c_my_shard in zip(cs, cs_my_shard) + ) + assert all( + c_my_shard.shape[1] == c.shape[1] for c, c_my_shard in zip(cs, cs_my_shard) + ) + assert all( + c_my_shard.stride() == c.stride() for c, c_my_shard in zip(cs, cs_my_shard) + ) + else: + assert all(c.shape[0] % world_size == 0 for c in cs) + cs_my_shard = [c.tensor_split(world_size)[my_rank] for c in cs] + + if wait_counters is not None: + assert wait_counters.shape == (world_size, num_stripes) + assert wait_counters.dtype is torch.int + assert wait_counters.is_contiguous() + do_wait = True + else: + do_wait = False + wait_counters = torch.empty((0,), dtype=torch.int, device=a.device) + + if write_counters is not None: + assert write_counters.shape == (world_size, num_stripes) + assert write_counters.dtype is torch.int + assert write_counters.is_contiguous() + do_write = True + blocks_done_counters = torch.empty( + (world_size,), dtype=torch.int, device=a.device + ) + else: + do_write = False + write_counters = torch.empty((0,), dtype=torch.int, device=a.device) + blocks_done_counters = torch.empty((0,), dtype=torch.int, device=a.device) + + # accumulator types + assert all(c.dtype == cs[0].dtype for c in cs) + ACC_TYPE = ( + tl.float32 + if cs[0].dtype in [torch.float16, torch.bfloat16, torch.float32] + else tl.int32 + ) + + # launch kernel + def grid(META): + return ( + world_size + * triton.cdiv(M // world_size, META["BLOCK_M"]) + * sum(triton.cdiv(N, META["BLOCK_N"]) for N in Ns), + META["SPLIT_K"], + ) + + # Can be raised if needed. + assert len(bs) <= 3 + + # We auto-tune the kernel's tiling and other parameters for each set of + # sizes. However, auto-tuning performs a device sync (it has to retrieve + # timings), which can be problematic: the kernel may busy-wait for something + # that will only be scheduled later, and the sync would never return. Thus, + # for auto-tuning, we'd like to set _wait to False, and then set it to True + # for the real run. (We assume that the kernel is idempotent, and that it + # won't have a wildly different perf profile when it runs on garbage data + # compared to real data). + + # Define the args/kwargs corresponding to the default invocation. + # We can't just use kwargs because Triton expects some args as positional. 
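+    # When fewer than three B/C matrices were passed, the spare slots below just
+    # repeat the last tensor and the corresponding N is set to 0, so the kernel
+    # assigns no tiles to them.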
+ args = ( + a_my_shard, + a, + bs[0], + bs[min(1, len(bs) - 1)], + bs[min(2, len(bs) - 1)], + cs[0], + cs[min(1, len(cs) - 1)], + cs[min(2, len(cs) - 1)], + cs_my_shard[0], + cs_my_shard[min(1, len(cs_my_shard) - 1)], + cs_my_shard[min(2, len(cs_my_shard) - 1)], + wait_counters, + blocks_done_counters, + write_counters, + M, + Ns[0], + Ns[1] if len(Ns) >= 2 else 0, + Ns[2] if len(Ns) >= 3 else 0, + K, + ) + kwargs = dict( + stride_am=stride_am, + stride_ak=stride_ak, + stride_bk=stride_bk, + stride_bn=stride_bn, + stride_cm=stride_cm, + stride_cn=stride_cn, + do_wait=do_wait, + do_write=do_write, + direction=direction, + stripe=stripe, + seq_num=seq_num, + num_stripes=num_stripes, + _wait=_wait, + my_rank=my_rank, + world_size=world_size, + timeout_ns=timeout_s * 1_000_000_000, + ACC_TYPE=ACC_TYPE, + ) + + # Run without waiting to auto-tune this set of sizes, if needed + if (M, tuple(Ns), K, cs[0].dtype) not in AUTOTUNED_SIZES: + kwargs["_wait"] = False + _xformers_seqpar_matmul_kernel[grid](*args, **kwargs) + kwargs["_wait"] = _wait + AUTOTUNED_SIZES.add((M, tuple(Ns), K, cs[0].dtype)) + + # Run the actual kernel + _xformers_seqpar_matmul_kernel[grid](*args, **kwargs) diff --git a/xformers/ops/_triton/tiled_matmul_kernels.py b/xformers/ops/_triton/tiled_matmul_kernels.py new file mode 100644 index 0000000000..8f77a0c816 --- /dev/null +++ b/xformers/ops/_triton/tiled_matmul_kernels.py @@ -0,0 +1,430 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import itertools +from typing import List, Tuple + +import torch +import triton +import triton.language as tl +from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + + +def init_to_zero(*names): + def result(nargs): + for name in names: + nargs[name].zero_() + + return result + + +def gen_config( + block_m: int, + block_n: int, + block_k: int, + stages: int, + warps: int, + split_k: int = 1, + group_m: int = 8, +) -> triton.Config: + """A more compact way to define a triton.Config, so it fits on one line""" + + return triton.Config( + { + "BLOCK_M": block_m, + "BLOCK_N": block_n, + "BLOCK_K": block_k, + "SPLIT_K": split_k, + "GROUP_M": group_m, + }, + num_stages=stages, + num_warps=warps, + pre_hook=init_to_zero(*[f"C{i+1}{j+1}" for i in range(3) for j in range(3)]) + if split_k > 1 + else init_to_zero(), + ) + + +BASIC_MATMUL_CONFIGS = [ + gen_config(block_m=128, block_n=256, block_k=32, stages=3, warps=8), + gen_config(block_m=256, block_n=128, block_k=32, stages=3, warps=8), + gen_config(block_m=256, block_n=64, block_k=32, stages=4, warps=4), + gen_config(block_m=64, block_n=256, block_k=32, stages=4, warps=4), + gen_config(block_m=128, block_n=128, block_k=32, stages=4, warps=4), + gen_config(block_m=128, block_n=64, block_k=32, stages=4, warps=4), + gen_config(block_m=64, block_n=128, block_k=32, stages=4, warps=4), + gen_config(block_m=128, block_n=32, block_k=32, stages=4, warps=4), + gen_config(block_m=64, block_n=32, block_k=32, stages=5, warps=2), +] + + +INT8_MATMUL_CONFIGS = [ + gen_config(block_m=128, block_n=256, block_k=128, stages=3, warps=8), + gen_config(block_m=256, block_n=128, block_k=128, stages=3, warps=8), + gen_config(block_m=256, block_n=64, block_k=128, stages=4, warps=4), + gen_config(block_m=64, block_n=256, block_k=128, stages=4, warps=4), + gen_config(block_m=128, block_n=128, block_k=128, stages=4, warps=4), + 
gen_config(block_m=128, block_n=64, block_k=64, stages=4, warps=4), + gen_config(block_m=64, block_n=128, block_k=64, stages=4, warps=4), + gen_config(block_m=128, block_n=32, block_k=64, stages=4, warps=4), + gen_config(block_m=64, block_n=32, block_k=64, stages=5, warps=2), +] + + +IO_BOUND_MATMUL_CONFIGS_STAGES = [2, 3, 4, 5, 6] +IO_BOUND_MATMUL_CONFIGS_BLOCK_M = [16, 32] +IO_BOUND_MATMUL_CONFIGS_BLOCK_K = [32, 64] +IO_BOUND_MATMUL_CONFIGS_BLOCK_N = [32, 64, 128, 256] +IO_BOUND_MATMUL_CONFIGS_SPLIT_K = [1, 2, 4, 8, 16] + + +IO_BOUND_MATMUL_CONFIGS = [ + gen_config( + block_m=block_m, + block_n=block_n, + block_k=block_k, + stages=stages, + warps=2 if block_n <= 64 else 4, + split_k=split_k, + ) + for stages, block_m, block_k, block_n, split_k in itertools.product( + IO_BOUND_MATMUL_CONFIGS_STAGES, + IO_BOUND_MATMUL_CONFIGS_BLOCK_M, + IO_BOUND_MATMUL_CONFIGS_BLOCK_K, + IO_BOUND_MATMUL_CONFIGS_BLOCK_N, + IO_BOUND_MATMUL_CONFIGS_SPLIT_K, + ) +] + + +TRITON_CONFIGS = BASIC_MATMUL_CONFIGS + INT8_MATMUL_CONFIGS + IO_BOUND_MATMUL_CONFIGS + + +def our_estimate_matmul_time( + A11, B11, C11, M1, M2, M3, N1, N2, N3, K1, K2, K3, **kwargs +): + """Call into Triton's upstream cost model, with the right args + + The upstream function expects arguments to have certain names. Since we + renamed a few of them in our implementation, we rename them back. + + At the time of writing (July 2023) the arguments that Triton expects are: + M, N, K, A, B, C, BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages. + + """ + return estimate_matmul_time( + M=M1 + M2 + M3, N=N1 + N2 + N3, K=K1 + K2 + K3, A=A11, B=B11, C=C11, **kwargs + ) + + +def our_early_config_prune(config, named_args): + new_named_args = named_args.copy() + new_named_args["M"] = named_args["M1"] + named_args["M2"] + named_args["M3"] + new_named_args["N"] = named_args["N1"] + named_args["N2"] + named_args["N3"] + new_named_args["K"] = named_args["K1"] + named_args["K2"] + named_args["K3"] + new_named_args["A"] = named_args["A11"] + new_named_args["B"] = named_args["B11"] + new_named_args["C"] = named_args["C11"] + return early_config_prune(config, new_named_args) + + +@triton.autotune( + configs=TRITON_CONFIGS, + key=["M1", "M2", "M3", "N1", "N2", "N3", "K1", "K2", "K3"], + prune_configs_by={ + "early_config_prune": our_early_config_prune, + "perf_model": our_estimate_matmul_time, + "top_k": 10, + }, +) +@triton.heuristics( + { + "EVEN_K": lambda args: all( + k % (args["BLOCK_K"] * args["SPLIT_K"]) == 0 + for k in [args["K1"], args["K2"], args["K3"]] + ), + } +) +@triton.jit() +def _xformers_tiled_matmul_kernel( + A11, + A12, + A13, + A21, + A22, + A23, + A31, + A32, + A33, + B11, + B12, + B13, + B21, + B22, + B23, + B31, + B32, + B33, + C11, + C12, + C13, + C21, + C22, + C23, + C31, + C32, + C33, + M1, + M2, + M3, + N1, + N2, + N3, + K1, + K2, + K3, + stride_am1, + stride_am2, + stride_am3, + stride_ak1, + stride_ak2, + stride_ak3, + stride_bk1, + stride_bk2, + stride_bk3, + stride_bn1, + stride_bn2, + stride_bn3, + stride_cm1, + stride_cm2, + stride_cm3, + stride_cn1, + stride_cn2, + stride_cn3, + BLOCK_M: tl.constexpr, # DO NOT CHANGE NAME: MUST MATCH PERF MODEL + BLOCK_N: tl.constexpr, # DO NOT CHANGE NAME: MUST MATCH PERF MODEL + BLOCK_K: tl.constexpr, # DO NOT CHANGE NAME: MUST MATCH PERF MODEL + GROUP_M: tl.constexpr, + SPLIT_K: tl.constexpr, # DO NOT CHANGE NAME: MUST MATCH PERF MODEL + EVEN_K: tl.constexpr, + ACC_TYPE: tl.constexpr, +): + # matrix multiplication + pid = tl.program_id(0) + pid_k = tl.program_id(1) + grid_m1 = tl.cdiv(M1, 
BLOCK_M) + grid_m2 = tl.cdiv(M2, BLOCK_M) + grid_m3 = tl.cdiv(M3, BLOCK_M) + grid_n1 = tl.cdiv(N1, BLOCK_N) + grid_n2 = tl.cdiv(N2, BLOCK_N) + grid_n3 = tl.cdiv(N3, BLOCK_N) + grid_m = grid_m1 + grid_m2 + grid_m3 + grid_n = grid_n1 + grid_n2 + grid_n3 + + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + + # We use tl.where to circumvent a regression in alignment auto-detection: + # https://github.com/openai/triton/issues/1784 + + A1 = tl.where(pid_m < grid_m1, A11, tl.where(pid_m < grid_m1 + grid_m2, A21, A31)) + A2 = tl.where(pid_m < grid_m1, A12, tl.where(pid_m < grid_m1 + grid_m2, A22, A32)) + A3 = tl.where(pid_m < grid_m1, A13, tl.where(pid_m < grid_m1 + grid_m2, A23, A33)) + B1 = tl.where(pid_n < grid_n1, B11, tl.where(pid_n < grid_n1 + grid_n2, B12, B13)) + B2 = tl.where(pid_n < grid_n1, B21, tl.where(pid_n < grid_n1 + grid_n2, B22, B23)) + B3 = tl.where(pid_n < grid_n1, B31, tl.where(pid_n < grid_n1 + grid_n2, B32, B33)) + C = tl.where( + pid_m < grid_m1, + tl.where(pid_n < grid_n1, C11, tl.where(pid_n < grid_n1 + grid_n2, C12, C13)), + tl.where( + pid_m < grid_m1 + grid_m2, + tl.where( + pid_n < grid_n1, C21, tl.where(pid_n < grid_n1 + grid_n2, C22, C23) + ), + tl.where( + pid_n < grid_n1, C31, tl.where(pid_n < grid_n1 + grid_n2, C32, C33) + ), + ), + ) + M = tl.where(pid_m < grid_m1, M1, tl.where(pid_m < grid_m1 + grid_m2, M2, M3)) + N = tl.where(pid_n < grid_n1, N1, tl.where(pid_n < grid_n1 + grid_n2, N2, N3)) + stride_ak = tl.where( + pid_m < grid_m1, + stride_ak1, + tl.where(pid_m < grid_m1 + grid_m2, stride_ak2, stride_ak3), + ) + stride_bk = tl.where( + pid_n < grid_n1, + stride_bk1, + tl.where(pid_n < grid_n1 + grid_n2, stride_bk2, stride_bk3), + ) + stride_cn = tl.where( + pid_m < grid_m1, + stride_cn1, + tl.where(pid_m < grid_m1 + grid_m2, stride_cn2, stride_cn3), + ) + stride_cm = tl.where( + pid_n < grid_n1, + stride_cm1, + tl.where(pid_n < grid_n1 + grid_n2, stride_cm2, stride_cm3), + ) + pid_m = tl.where( + pid_m < grid_m1, + pid_m, + tl.where(pid_m < grid_m1 + grid_m2, pid_m - grid_m1, pid_m - grid_m1 - grid_m2), + ) + pid_n = tl.where( + pid_n < grid_n1, + pid_n, + tl.where(pid_n < grid_n1 + grid_n2, pid_n - grid_n1, pid_n - grid_n1 - grid_n2), + ) + + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + # pointers + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + grid_k1 = tl.cdiv(K1, BLOCK_K) + grid_k2 = tl.cdiv(K2, BLOCK_K) + grid_k3 = tl.cdiv(K3, BLOCK_K) + for tile in range(pid_k, grid_k1 + grid_k2 + grid_k3, SPLIT_K): + A = tl.where(tile < grid_k1, A1, tl.where(tile < grid_k1 + grid_k2, A2, A3)) + B = tl.where(tile < grid_k1, B1, tl.where(tile < grid_k1 + grid_k2, B2, B3)) + K = tl.where(tile < grid_k1, K1, tl.where(tile < grid_k1 + grid_k2, K2, K3)) + stride_am = tl.where( + tile < grid_k1, + stride_am1, + tl.where(tile < grid_k1 + grid_k2, stride_am2, stride_am3), + ) + stride_bn = tl.where( + tile < grid_k1, + stride_bn1, + tl.where(tile < grid_k1 + grid_k2, stride_bn2, stride_bn3), + ) + my_tile = tl.where( + tile < grid_k1, + tile, + tl.where( + tile < grid_k1 + grid_k2, tile - grid_k1, tile - grid_k1 - grid_k2 + ), + ) + rk = my_tile * BLOCK_K + 
tl.arange(0, BLOCK_K) + Ain = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + Bin = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + if EVEN_K: + a = tl.load(Ain) + b = tl.load(Bin) + else: + a = tl.load(Ain, mask=rk[None, :] < K, other=0.0) + b = tl.load(Bin, mask=rk[:, None] < K, other=0.0) + acc += tl.dot(a, b, allow_tf32=False) + acc = acc.to(C.dtype.element_ty) + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(C, acc, mask=mask) + else: + tl.atomic_add(C, acc, mask=mask) + + +def _check_row_or_column(row_or_col_type, row_or_col_idx, tensor_name, dim_name, vals): + assert len(vals) > 0 + for pos, val in enumerate(vals[1:]): + assert val == vals[0], ( + f"the tensors on {row_or_col_type} {row_or_col_idx} of the {tensor_name} " + f"must all have the same stride along the {dim_name} dimension, got " + f"{vals[0]} at position 0 and {val} at position {pos + 1}" + ) + return vals[0] + + +def _get_strides( + ts: List[List[torch.Tensor]], tensor_name, dim_0_name, dim_1_name +) -> Tuple[List[int], List[int]]: + strides_0 = [ + _check_row_or_column( + "column", idx, tensor_name, dim_0_name, [y.stride(0) for y in x] + ) + for idx, x in enumerate(zip(*ts)) + ] + strides_1 = [ + _check_row_or_column( + "row", idx, tensor_name, dim_1_name, [y.stride(1) for y in x] + ) + for idx, x in enumerate(ts) + ] + assert all(s == 1 for s in strides_0) or all(s == 1 for s in strides_1) + while len(strides_0) < 3: + strides_0.append(1 if strides_0[0] == 1 else 0) + while len(strides_1) < 3: + strides_1.append(1 if strides_1[0] == 1 else 0) + return strides_0, strides_1 + + +def _launch_triton_matmul( + a: List[List[torch.Tensor]], + b: List[List[torch.Tensor]], + c: List[List[torch.Tensor]], + ms: List[int], + ns: List[int], + ks: List[int], +) -> None: + strides_am, strides_ak = _get_strides(a, "first operand", "m", "k") + strides_bk, strides_bn = _get_strides(b, "second operand", "k", "n") + strides_cm, strides_cn = _get_strides(c, "output", "m", "n") + + # accumulator types + ACC_TYPE = ( + tl.float32 + if c[0][0].dtype in [torch.float16, torch.bfloat16, torch.float32] + else tl.int32 + ) + + # launch kernel + def grid(META): + return ( + sum(triton.cdiv(m, META["BLOCK_M"]) for m in ms) + * sum(triton.cdiv(n, META["BLOCK_N"]) for n in ns), + META["SPLIT_K"], + ) + + _xformers_tiled_matmul_kernel[grid]( + *[ + a[min(i, len(a) - 1)][min(j, len(a[0]) - 1)] + for i in range(3) + for j in range(3) + ], + *[ + b[min(i, len(b) - 1)][min(j, len(b[0]) - 1)] + for i in range(3) + for j in range(3) + ], + *[ + c[min(i, len(c) - 1)][min(j, len(c[0]) - 1)] + for i in range(3) + for j in range(3) + ], + *[ms[i] if len(ms) > i else 0 for i in range(3)], + *[ns[i] if len(ns) > i else 0 for i in range(3)], + *[ks[i] if len(ks) > i else 0 for i in range(3)], + *strides_am, + *strides_ak, + *strides_bk, + *strides_bn, + *strides_cm, + *strides_cn, + ACC_TYPE=ACC_TYPE, + ) diff --git a/xformers/ops/differentiable_collectives.py b/xformers/ops/differentiable_collectives.py new file mode 100644 index 0000000000..b073e2da60 --- /dev/null +++ b/xformers/ops/differentiable_collectives.py @@ -0,0 +1,178 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
+# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Optional, Tuple + +import torch +import torch.distributed + + +def all_reduce( + x: torch.Tensor, *, process_group: torch.distributed.ProcessGroup +) -> None: + assert x.is_contiguous() + + mp_size = torch.distributed.get_world_size(process_group) + if mp_size == 1: + return + + torch.distributed.all_reduce( + tensor=x, op=torch.distributed.ReduceOp.SUM, group=process_group + ) + + +def gather_along_first_dim_async( + input_: torch.Tensor, *, process_group: torch.distributed.ProcessGroup +) -> Tuple[torch.Tensor, Optional[torch.distributed.Work]]: + assert input_.is_contiguous() + mp_size = torch.distributed.get_world_size(process_group) + if mp_size == 1: + return input_, None + + output = input_.new_empty((input_.shape[0] * mp_size,) + input_.shape[1:]) + handle = torch.distributed.all_gather_into_tensor( + output_tensor=output, + input_tensor=input_, + group=process_group, + async_op=True, + ) + + return output, handle + + +def reduce_scatter_along_first_dim_async( + input_: torch.Tensor, *, process_group: torch.distributed.ProcessGroup +) -> Tuple[torch.Tensor, Optional[torch.distributed.Work]]: + assert input_.is_contiguous() + mp_size = torch.distributed.get_world_size(process_group) + if mp_size == 1: + return input_, None + + assert input_.shape[0] % mp_size == 0 + output = input_.new_empty((input_.shape[0] // mp_size,) + input_.shape[1:]) + handle = torch.distributed.reduce_scatter_tensor( + output=output, + input=input_, + op=torch.distributed.ReduceOp.SUM, + group=process_group, + async_op=True, + ) + + return output, handle + + +def gather_along_first_dim( + input_: torch.Tensor, *, process_group: torch.distributed.ProcessGroup +) -> torch.Tensor: + output, handle = gather_along_first_dim_async(input_, process_group=process_group) + if handle is not None: + handle.wait() + return output + + +def reduce_scatter_along_first_dim( + input_: torch.Tensor, *, process_group: torch.distributed.ProcessGroup +) -> torch.Tensor: + output, handle = reduce_scatter_along_first_dim_async( + input_, process_group=process_group + ) + if handle is not None: + handle.wait() + return output + + +class _CopyToModelParallelRegion(torch.autograd.Function): + @staticmethod + def forward( # type: ignore[override] + ctx, input_: torch.Tensor, process_group: torch.distributed.ProcessGroup + ) -> torch.Tensor: + ctx.process_group = process_group + return input_ + + @staticmethod + def backward( # type: ignore[override] + ctx, grad_output: torch.Tensor + ) -> Tuple[torch.Tensor, None]: + all_reduce(grad_output, process_group=ctx.process_group) + return grad_output, None + + +def copy_to_model_parallel_region( + x: torch.Tensor, process_group: torch.distributed.ProcessGroup +) -> torch.Tensor: + return _CopyToModelParallelRegion.apply(x, process_group) + + +class _ReduceFromModelParallelRegion(torch.autograd.Function): + @staticmethod + def forward( # type: ignore[override] + ctx, input_: torch.Tensor, process_group: torch.distributed.ProcessGroup + ) -> torch.Tensor: + all_reduce(input_, process_group=process_group) + ctx.mark_dirty(input_) + return input_ + + @staticmethod + def backward( # type: ignore[override] + ctx, grad_output: torch.Tensor + ) -> Tuple[torch.Tensor, None]: + return grad_output, None + + +def reduce_from_model_parallel_region( + x: torch.Tensor, process_group: torch.distributed.ProcessGroup +) -> torch.Tensor: + return 
_ReduceFromModelParallelRegion.apply(x, process_group) + + +class _GatherFromSequenceParallelRegion(torch.autograd.Function): + @staticmethod + def forward( # type: ignore[override] + ctx, x: torch.Tensor, process_group: torch.distributed.ProcessGroup + ) -> torch.Tensor: + ctx.process_group = process_group + return gather_along_first_dim(x, process_group=process_group) + + @staticmethod + def backward( # type: ignore[override] + ctx, grad_output: torch.Tensor + ) -> Tuple[torch.Tensor, None]: + return ( + reduce_scatter_along_first_dim( + grad_output, process_group=ctx.process_group + ), + None, + ) + + +def gather_from_sequence_parallel_region( + x: torch.Tensor, process_group: torch.distributed.ProcessGroup +) -> torch.Tensor: + return _GatherFromSequenceParallelRegion.apply(x, process_group) + + +class _ScatterToSequenceParallelRegion(torch.autograd.Function): + @staticmethod + def forward( # type: ignore[override] + ctx, x: torch.Tensor, process_group: torch.distributed.ProcessGroup + ) -> torch.Tensor: + ctx.process_group = process_group + return reduce_scatter_along_first_dim(x, process_group=process_group) + + @staticmethod + def backward( # type: ignore[override] + ctx, grad_output: torch.Tensor + ) -> Tuple[torch.Tensor, None]: + return ( + gather_along_first_dim(grad_output, process_group=ctx.process_group), + None, + ) + + +def scatter_to_sequence_parallel_region( + x: torch.Tensor, process_group: torch.distributed.ProcessGroup +) -> torch.Tensor: + return _ScatterToSequenceParallelRegion.apply(x, process_group) diff --git a/xformers/ops/modpar_layers.py b/xformers/ops/modpar_layers.py new file mode 100644 index 0000000000..779f017ede --- /dev/null +++ b/xformers/ops/modpar_layers.py @@ -0,0 +1,160 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable, List + +import torch +import torch.distributed + +from .differentiable_collectives import ( + copy_to_model_parallel_region, + reduce_from_model_parallel_region, +) +from .seqpar import sequence_parallel_leading_matmul, sequence_parallel_trailing_matmul + + +def _init_2d_weight( + weight: torch.Tensor, + init_method: Callable[[torch.Tensor], torch.Tensor], + process_group: torch.distributed.ProcessGroup, + partition_dim: int, +) -> None: + # Mimick FairScale's _initialize_affine_weight, for backwards compatibility. + # The reason we initialize the full unpartitioned/gathered weight is so that + # different ranks get different initial values and thus "break the symmetry" + # and in order to achieve the same init for any value of model parallelism. 
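    # Illustrative example (added for clarity, assuming all ranks share the same
    # RNG seed so init_method produces the same full_weight everywhere): with
    # world_size=2, partition_dim=0 and a local weight of shape (4, in_features),
    # full_weight has 8 rows; rank 0 copies rows 0, 2, 4, 6 and rank 1 copies
    # rows 1, 3, 5, 7, i.e. full_weight[rank::world_size, :].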
+ rank = torch.distributed.get_rank(process_group) + world_size = torch.distributed.get_world_size(process_group) + + nrows, ncols = weight.shape + if partition_dim == 0: + full_weight = weight.new_empty(nrows * world_size, ncols) + my_weight_slice = full_weight[rank::world_size, :] + else: + full_weight = weight.new_empty(nrows, ncols * world_size) + my_weight_slice = full_weight[:, rank::world_size] + + init_method(full_weight) + + with torch.no_grad(): + weight.copy_(my_weight_slice) + + +class ColumnParallelLinear(torch.nn.Module): + def __init__( + self, + in_features: int, + out_features: List[int], + *, + process_group: torch.distributed.ProcessGroup, + bias: bool = True, + gather_output: bool = True, + init_method: Callable[ + [torch.Tensor], torch.Tensor + ] = torch.nn.init.xavier_normal_, + sequence_parallel: bool = False, + fuse_sequence_parallel: bool = True, + ) -> None: + super(ColumnParallelLinear, self).__init__() + + if not isinstance(out_features, list): + raise TypeError( + "xFormers's implementation of ColumnParallelLinear requires out_features to be a list" + ) + if bias: + raise ValueError( + "xFormers's implementation of ColumnParallelLinear requires bias=False" + ) + if gather_output: + raise ValueError( + "xFormers's implementation of ColumnParallelLinear requires gather_output=False" + ) + + self.in_features = in_features + self.global_out_features = out_features + self.sequence_parallel = sequence_parallel + self.fuse_sequence_parallel = fuse_sequence_parallel + self.process_group = process_group + mp_size = torch.distributed.get_world_size(process_group) + assert all(dim % mp_size == 0 for dim in out_features) + self.my_out_features = [dim // mp_size for dim in out_features] + + self.weights = torch.nn.ParameterList( + [ + torch.nn.Parameter(torch.empty((dim, in_features))) + for dim in self.my_out_features + ] + ) + + for w in self.weights: + _init_2d_weight(w, init_method, process_group, partition_dim=0) + + def forward(self, input_: torch.Tensor) -> List[torch.Tensor]: + if self.sequence_parallel: + outputs = sequence_parallel_leading_matmul( + input_, + [w.t() for w in self.weights], + fuse=self.fuse_sequence_parallel, + process_group=self.process_group, + ) + else: + input_ = copy_to_model_parallel_region(input_, self.process_group) + outputs = [torch.matmul(input_, w.t()) for w in self.weights] + return outputs + + +class RowParallelLinear(torch.nn.Module): + def __init__( + self, + in_features: int, + out_features: int, + *, + process_group: torch.distributed.ProcessGroup, + bias: bool = True, + input_is_parallel: bool = False, + init_method: Callable[ + [torch.Tensor], torch.Tensor + ] = torch.nn.init.xavier_normal_, + sequence_parallel: bool = False, + fuse_sequence_parallel: bool = True, + ): + super(RowParallelLinear, self).__init__() + + if bias: + raise ValueError( + "xFormers's implementation of RowParallelLinear requires bias=False" + ) + if not input_is_parallel: + raise ValueError( + "xFormers's implementation of RowParallelLinear requires input_is_parallel=True" + ) + + self.global_in_features = in_features + self.out_features = out_features + self.sequence_parallel = sequence_parallel + self.fuse_sequence_parallel = fuse_sequence_parallel + self.process_group = process_group + mp_size = torch.distributed.get_world_size(process_group) + assert in_features % mp_size == 0 + self.my_in_features = in_features // mp_size + + self.weight = torch.nn.Parameter( + torch.empty((out_features, self.my_in_features)) + ) + + _init_2d_weight(self.weight, 
init_method, process_group, partition_dim=1) + + def forward(self, input_: torch.Tensor) -> torch.Tensor: + if self.sequence_parallel: + output = sequence_parallel_trailing_matmul( + input_, + self.weight.t(), + fuse=self.fuse_sequence_parallel, + process_group=self.process_group, + ) + else: + output = torch.matmul(input_, self.weight.t()) + output = reduce_from_model_parallel_region(output, self.process_group) + return output diff --git a/xformers/ops/seqpar.py b/xformers/ops/seqpar.py new file mode 100644 index 0000000000..8752faa64b --- /dev/null +++ b/xformers/ops/seqpar.py @@ -0,0 +1,286 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Callable, List, Optional, Tuple + +import torch + +from .common import make_pytorch_cuda_operator +from .differentiable_collectives import ( + gather_along_first_dim, + gather_along_first_dim_async, + reduce_scatter_along_first_dim, + reduce_scatter_along_first_dim_async, +) +from .sequence_parallel_fused_ops import ( + fused_allgather_and_anything, + fused_allgather_and_linear, + fused_anything_and_reducescatter, + fused_linear_and_reducescatter, +) +from .tiled_matmul import tiled_matmul_fwd + + +@make_pytorch_cuda_operator +def sequence_parallel_leading_matmul_fwd( + scattered_input: torch.Tensor, + weights: List[torch.Tensor], + fuse: bool, + process_group: torch.distributed.ProcessGroup, +) -> List[torch.Tensor]: + if fuse: + gathered_outputs = fused_allgather_and_linear( + scattered_input, [w.t() for w in weights], group=process_group + ) + else: + gathered_input = gather_along_first_dim( + scattered_input, process_group=process_group + ) + (gathered_outputs,) = tiled_matmul_fwd( + [[gathered_input]], + [[w for w in weights]], + ) + return gathered_outputs + + +@make_pytorch_cuda_operator +def sequence_parallel_leading_matmul_bwd( + scattered_input: torch.Tensor, + weights: List[torch.Tensor], + grad_gathered_outputs: List[torch.Tensor], + fuse: bool, + process_group: torch.distributed.ProcessGroup, +) -> Tuple[torch.Tensor, List[torch.Tensor]]: + mp_size = torch.distributed.get_world_size(process_group) + + if fuse: + grad_scattered_input = torch.empty_like(scattered_input) + grad_weights = [torch.zeros_like(w) for w in weights] + + grad_gathered_outputss = [ + grad_go.tensor_split(mp_size, dim=0) for grad_go in grad_gathered_outputs + ] + + def my_si_matmul( + grad_gathered_inputs: List[torch.Tensor], + dst_rank: int, + stream_factory: Callable[[], torch.cuda.Stream], + ) -> None: + (grad_gi,) = grad_gathered_inputs + with torch.cuda.stream(stream_factory()): + tiled_matmul_fwd( + [[grad_gos[dst_rank] for grad_gos in grad_gathered_outputss]], + [[w.t()] for w in weights], + out=[[grad_gi]], + ) + + fused_anything_and_reducescatter( + my_si_matmul, + [grad_scattered_input], + group=process_group, + ) + + # Each pair of shards of input and grad_output accumulates into the same + # grad_weight. Thus we need to make sure that the in-place addmms are + # sequenced correctly for each of the grad_weights. 
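        # Clarifying note: the per-weight CUDA events below act as lightweight
        # locks. Each callback waits on the event before its in-place addmm_
        # into a given grad_weight and records it afterwards, so contributions
        # from different source ranks are accumulated one at a time into that
        # grad_weight, while work on different grad_weights may still overlap.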
+ events = [torch.cuda.Event() for _ in weights] + + def my_w_matmul( + gathered_inputs_shard: List[torch.Tensor], + src_rank: int, + stream_factory: Callable[[], torch.cuda.Stream], + ) -> None: + (gi_shard,) = gathered_inputs_shard + for grad_gos, grad_w, event in zip( + grad_gathered_outputss, grad_weights, events + ): + with torch.cuda.stream(stream_factory()): + event.wait() + grad_w.t().addmm_(grad_gos[src_rank].t(), gi_shard) + event.record() + + fused_allgather_and_anything( + [scattered_input], + my_w_matmul, + group=process_group, + ) + else: + gathered_input, handle = gather_along_first_dim_async( + scattered_input, process_group=process_group + ) + ((grad_gathered_input,),) = tiled_matmul_fwd( + [[grad_go for grad_go in grad_gathered_outputs]], + [[w.t()] for w in weights], + ) + if handle is not None: + handle.wait() + + grad_scattered_input, handle = reduce_scatter_along_first_dim_async( + grad_gathered_input, process_group=process_group + ) + + grad_weights_tuples = tiled_matmul_fwd( + [[grad_go.t()] for grad_go in grad_gathered_outputs], + [[gathered_input]], + ) + if handle is not None: + handle.wait() + + grad_weights = [grad_w.t() for (grad_w,) in grad_weights_tuples] + + return grad_scattered_input, grad_weights + + +class _SequenceParallelLeadingMatmul(torch.autograd.Function): + @staticmethod + def forward( # type: ignore[override] + ctx, + fuse: bool, + process_group: torch.distributed.ProcessGroup, + scattered_input: torch.Tensor, + *weights: torch.Tensor, + ) -> Tuple[torch.Tensor, ...]: + ctx.save_for_backward(scattered_input, *weights) + ctx.fuse = fuse + ctx.process_group = process_group + gathered_output = sequence_parallel_leading_matmul_fwd( + scattered_input, list(weights), fuse, process_group + ) + return tuple(gathered_output) + + @staticmethod + def backward( # type: ignore[override] + ctx, *grad_gathered_outputs: torch.Tensor + ) -> Tuple[Optional[torch.Tensor], ...]: + scattered_input, *weights = ctx.saved_tensors + (grad_scattered_input, grad_weights,) = sequence_parallel_leading_matmul_bwd( + scattered_input, + list(weights), + list(grad_gathered_outputs), + ctx.fuse, + ctx.process_group, + ) + return None, None, grad_scattered_input, *grad_weights + + +def sequence_parallel_leading_matmul( + x: torch.Tensor, + ws: List[torch.Tensor], + *, + fuse: bool, + process_group: torch.distributed.ProcessGroup, +) -> List[torch.Tensor]: + os = _SequenceParallelLeadingMatmul.apply( + fuse, process_group, x.flatten(0, -2), *ws + ) + return [o.view(-1, *x.shape[1:-1], w.shape[1]) for o, w in zip(os, ws)] + + +@make_pytorch_cuda_operator +def sequence_parallel_trailing_matmul_fwd( + gathered_input: torch.Tensor, + weight: torch.Tensor, + fuse: bool, + process_group: torch.distributed.ProcessGroup, +) -> torch.Tensor: + if fuse: + scattered_output = fused_linear_and_reducescatter( + gathered_input, weight.t(), group=process_group + ) + else: + gathered_output = torch.matmul(gathered_input, weight) + scattered_output = reduce_scatter_along_first_dim( + gathered_output, process_group=process_group + ) + return scattered_output + + +@make_pytorch_cuda_operator +def sequence_parallel_trailing_matmul_bwd( + gathered_input: torch.Tensor, + weight: torch.Tensor, + grad_scattered_output: torch.Tensor, + fuse: bool, + process_group: torch.distributed.ProcessGroup, +) -> Tuple[torch.Tensor, torch.Tensor]: + mp_size = torch.distributed.get_world_size(process_group) + + if fuse: + grad_gathered_input = torch.empty_like(gathered_input) + grad_weight = torch.zeros_like(weight) + 
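        # Clarifying note: gathered_input and grad_gathered_input are viewed
        # below as one shard per rank along dim 0; as each rank's shard of the
        # output gradient arrives through the fused all-gather, that rank's
        # slice of the input gradient is computed and its contribution is
        # accumulated into grad_weight in place.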
+ gathered_inputs = gathered_input.tensor_split(mp_size, dim=0) + grad_gathered_inputs = grad_gathered_input.tensor_split(mp_size, dim=0) + + def my_gi_and_w_matmul( + grad_gathered_outputs_shard: List[torch.Tensor], + src_rank: int, + stream_factory: Callable[[], torch.cuda.Stream], + ) -> None: + (grad_go_shard,) = grad_gathered_outputs_shard + with torch.cuda.stream(stream_factory()): + torch.matmul( + grad_go_shard, weight.t(), out=grad_gathered_inputs[src_rank] + ) + with torch.cuda.stream(stream_factory()): + grad_weight.t().addmm_(grad_go_shard.t(), gathered_inputs[src_rank]) + + fused_allgather_and_anything( + [grad_scattered_output], + my_gi_and_w_matmul, + group=process_group, + ) + else: + grad_gathered_output = gather_along_first_dim( + grad_scattered_output, process_group=process_group + ) + grad_gathered_input = torch.matmul(grad_gathered_output, weight.t()) + grad_weight = torch.matmul(grad_gathered_output.t(), gathered_input).t() + + return grad_gathered_input, grad_weight + + +class _SequenceParallelTrailingMatmul(torch.autograd.Function): + @staticmethod + def forward( # type: ignore[override] + ctx, + fuse: bool, + process_group: torch.distributed.ProcessGroup, + gathered_input: torch.Tensor, + weight: torch.Tensor, + ) -> torch.Tensor: + ctx.save_for_backward(gathered_input, weight) + ctx.fuse = fuse + ctx.process_group = process_group + scattered_output = sequence_parallel_trailing_matmul_fwd( + gathered_input, weight, fuse, process_group + ) + return scattered_output + + @staticmethod + def backward( # type: ignore[override] + ctx, grad_scattered_output: torch.Tensor + ) -> Tuple[Optional[torch.Tensor], ...]: + gathered_input, weight = ctx.saved_tensors + (grad_gathered_input, grad_weight,) = sequence_parallel_trailing_matmul_bwd( + gathered_input, + weight, + grad_scattered_output, + ctx.fuse, + ctx.process_group, + ) + return None, None, grad_gathered_input, grad_weight + + +def sequence_parallel_trailing_matmul( + x: torch.Tensor, + w: torch.Tensor, + *, + fuse: bool, + process_group: torch.distributed.ProcessGroup, +) -> torch.Tensor: + o = _SequenceParallelTrailingMatmul.apply(fuse, process_group, x.flatten(0, -2), w) + return o.view(-1, *x.shape[1:-1], w.shape[1]) diff --git a/xformers/ops/sequence_parallel_fused_ops.py b/xformers/ops/sequence_parallel_fused_ops.py new file mode 100644 index 0000000000..44566fd967 --- /dev/null +++ b/xformers/ops/sequence_parallel_fused_ops.py @@ -0,0 +1,1052 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import concurrent.futures +import multiprocessing.connection +import os +from typing import Any, Callable, Dict, List, Mapping, Optional, Union, overload + +import torch +import torch.distributed as dist +import torch.multiprocessing.reductions + +from .. import _is_triton_available +from .common import BaseOperator, get_xformers_operator, register_operator + +if _is_triton_available(): + from ._triton.sequence_parallel_fused_kernels import ( + BACKWARDS_WITH_ME_FIRST, + FORWARDS_WITH_ME_LAST, + _launch_triton_matmul, + ) + + TRITON_IS_AVAILABLE = True +else: + TRITON_IS_AVAILABLE = False + + +# The sequence numbers will be communicated as 32-bit integers, due to +# limitations in both CUDA (memset can only operate on 4 bytes at a time at +# most) and Triton (scalar arguments are int32 if they fit). 32 bits are not +# enough to be sure that we'll never see overflow. 
Moreover, different parts of +# the code use signed or unsigned ints. To be safe, let's simulate overflow +# ourselves, at a value low enough so that it fits both a signed and an unsigned +# 32-bit integer. And, in fact, let's make it so low that we're sure we'll hit +# it in our tests, to avoid bugs that only manifest in long-running training. +SEQ_NUM_WRAP_AROUND = 2**8 + + +@register_operator +class WriteValues(BaseOperator): + OPERATOR = get_xformers_operator("write_values") + OPERATOR_CATEGORY = "sequence_parallel_fused" + NAME = "write_values" + + +@register_operator +class WaitValues(BaseOperator): + OPERATOR = get_xformers_operator("wait_values") + OPERATOR_CATEGORY = "sequence_parallel_fused" + NAME = "wait_values" + + +@register_operator +class Memset32bAsync(BaseOperator): + OPERATOR = get_xformers_operator("cuda_memset_32b_async") + OPERATOR_CATEGORY = "sequence_parallel_fused" + NAME = "cuda_memset_32b_async" + + +# We could just send tensors directly on mp.Connections, since PyTorch installs +# the necessary reductions to make it work. However, in the receiving process, +# PyTorch "mounts" the tensor in the CUDA context for the GPU with the **SAME +# INDEX** as on the sender. This works if all processes use CUDA_VISIBLE_DEVICES +# to limit themselves to a single GPU (which thus has index 0 everywhere) but in +# all other cases it's a mess. Hence we use our own reductions (which wrap the +# ones from PyTorch) to use the right devices. + + +def _serialize_cuda_tensor(tensor, device): + assert tensor.device == device + assert device.type == "cuda" + func, args = torch.multiprocessing.reductions.reduce_tensor(tensor) + assert func is torch.multiprocessing.reductions.rebuild_cuda_tensor + assert args[6] == device.index + return args + + +def _deserialize_cuda_tensor(args, device): + return torch.multiprocessing.reductions.rebuild_cuda_tensor( + *args[:6], device.index, *args[7:] + ) + + +# We need all processes to exchange a few strings with their addresses (in order +# to be able to connect to each other). The solution for this kind of things in +# PyTorch is a Store (TCPStore or FileStore) but we cannot create one ourselves +# (we don't know which addr/port/file to use, since the default one is already +# being used by PyTorch's global store) nor can we extract one from the +# ProcessGroup (since there's no API to do so). We thus resort to using the PG +# itself to exchange data, which is overkill (we need to store the pickled data +# into tensors and send it to the GPU). On top of that, it introduces one more +# catch: it doesn't work in inference mode because of something about modifying +# tensors inplace. I couldn't find a way to temporarily disable inference mode +# (although it's supposed to be possible) however inference mode is thread-local +# so we can dodge it by offloading the collective call to another thread. I hate +# all this so much. + + +def _exchange_addresses( + listeners: List[multiprocessing.connection.Listener], + group: dist.ProcessGroup, + device: torch.device, +) -> List[List[str]]: + world_size = dist.get_world_size(group=group) + my_addresses: List[str] = [] + for listener in listeners: + addr = listener.address + # The address could be a tuple if the listener weren't a UNIX socket + if isinstance(addr, bytes): + # Shouldn't be bytes, according to docs and typeshed, but... 
+ # https://github.com/python/typeshed/issues/10054 + addr = addr.decode("utf-8") + assert isinstance(addr, str) + my_addresses.append(addr) + all_addresses = [[""] * (world_size - 1)] * world_size + with concurrent.futures.ThreadPoolExecutor( + initializer=torch.cuda.set_device, initargs=(device,) + ) as e: + e.submit( + dist.all_gather_object, + object_list=all_addresses, + obj=my_addresses, + group=group, + ).result() + return all_addresses + + +class _FusedSequenceParallel: + """Set up a communication ring and perform fused ops on it + + Stores the persistent state needed to support a ring of connections between + processes, and the logic that can do fused comms + matmuls on it. + + We want to achieve overlap between: + - a computation which reads from the data we received from a remote GPU + - and the communication where we send some data to another GPU + And in order to do that we need some staging buffers and a way to + synchronize access to them across processes. + + To perform the communication over NVLink we make the processes exchange + their staging buffers using IPC (Inter-Process Communication) handles, which + "mounts"/"mmaps" an allocation on one GPU into the virtual address space of + another GPU: the memory remains backed by the original GPU but the other GPU + can access it as if it were local. We exchange these IPC handles using + multiprocessing Connections (and the "reductions" provided by PyTorch), + which we establish over UNIX domain sockets, whose addresses we exchange by + using a ProcessGroup. + + To synchronize accesses we use a set of counters/sequence numbers that are + also allocated in memory shared over IPC handles. Processes signal that they + completed an operation by launching a kernel that increases that value, and + they wait for anoher process to complete an operation by launching a kernel + that busy-waits for that value to increase. Currently we implement these + kernels manually, but on recent CUDA drivers (515.43.04+, corresponding to + CUDA 11.7) we could use standard stream memory operations (see + https://docs.nvidia.com/cuda/archive/11.7.0/cuda-driver-api/group__CUDA__MEMOP.html). + + We prefer to use these kernels (or the stream memory ops) over IPC events + because IPC events require signaling between processes at launch time to + ensure that the wait on one process occurs after the record on another + process. This signaling means that _launching_ our fused operation becomes a + synchronization barrier, which can increase the launch overhead. It would + also behave differently from NCCL, where launching is async and all the + synchronization happens on device in the kernels. A previous version of this + code which uses IPC events can be found here: + https://github.com/fairinternal/xformers/pull/504. + + """ + + def __init__( + self, + device: torch.device, + dtype: torch.dtype, + group: dist.ProcessGroup, + num_stripes: int, + ): + self.my_device = device + self.dtype = dtype + self.my_rank = dist.get_rank(group=group) + self.world_size = dist.get_world_size(group=group) + self.num_stripes = num_stripes + self.my_device_capability = torch.cuda.get_device_capability(self.my_device) + + # Open connections to all other processes. We exchange addresses via + # NCCL since we don't have access to a Store. 
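        # Clarifying note: each rank opens world_size - 1 UNIX-domain listeners
        # and dials every other rank once, so every ordered pair of ranks ends
        # up with a dedicated Connection (incoming_conns / outgoing_conns
        # below). These connections are only used to exchange IPC handles for
        # counters and staging buffers, never for bulk data.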
+ listeners = [ + multiprocessing.connection.Listener(family="AF_UNIX", address="", backlog=1) + for _ in range(self.world_size - 1) + ] + # If any process is late, all other ones will block here + all_addresses = _exchange_addresses(listeners, group, self.my_device) + self.outgoing_conns = [ + None + if r == self.my_rank + else multiprocessing.connection.Client( + family="AF_UNIX", + # Mypy wants it to be str, but it actually can also be bytes + # https://github.com/python/typeshed/issues/10054 + address=all_addresses[r][(r - self.my_rank) % self.world_size - 1], + ) + for r in range(self.world_size) + ] + self.incoming_conns = [ + None + if r == self.my_rank + else listeners[(self.my_rank - r) % self.world_size - 1].accept() + for r in range(self.world_size) + ] + + self.next_stripe = 0 + self.next_seq_nums = [1] * self.num_stripes + + # My staging buffers + self.staging = torch.empty((0,), device=self.my_device) + + # (Mmapped view of a handle to) buddies' staging buffers + self.buddys_staging = [ + torch.empty((0,), device=self.my_device) + ] * self.world_size + + # Allocate buffers for my inboxes + self.num_writes_into_my_staging = torch.zeros( + (self.world_size, self.num_stripes), dtype=torch.int, device=self.my_device + ) + self.num_reads_from_buddys_staging = torch.zeros( + (self.world_size, self.num_stripes), dtype=torch.int, device=self.my_device + ) + + # Send my handles to buddies + for rank, (in_conn, out_conn) in enumerate( + zip(self.incoming_conns, self.outgoing_conns) + ): + if in_conn is not None: + in_conn.send( + _serialize_cuda_tensor( + self.num_writes_into_my_staging[rank], self.my_device + ) + ) + if out_conn is not None: + out_conn.send( + _serialize_cuda_tensor( + self.num_reads_from_buddys_staging[rank], self.my_device + ) + ) + + # Open buddies' inboxes as my outboxes + self.num_writes_into_buddys_staging = [ + torch.empty((0,), device=self.my_device) + if out_conn is None + else _deserialize_cuda_tensor(out_conn.recv(), self.my_device) + for out_conn in self.outgoing_conns + ] + self.num_reads_from_my_staging = [ + torch.empty((0,), device=self.my_device) + if in_conn is None + else _deserialize_cuda_tensor(in_conn.recv(), self.my_device) + for in_conn in self.incoming_conns + ] + + self.second_stream = torch.cuda.Stream() + # CUDA can schedule the matmul and the memcpy at the same time, but it + # tends to run the matmul first and delay the memcpy, which causes a + # domino effect. We thus "encourage" it to prioritize the memcpy. + self.memcpy_stream = torch.cuda.Stream(priority=-1) + # Use dedicated streams to parallelize other operations. + self.wait_stream = torch.cuda.Stream(priority=-1) + self.write_stream = torch.cuda.Stream(priority=-1) + + self.next_stream_idx = 0 + + def _ensure_staging_is_large_enough(self, num_elements: int, random_init: bool): + # Lazily size up the staging area as needed. (If it's the first call, + # this will always trigger, since staging starts empty). Once at steady + # state, staging will be of the right (max) size and never grow again. + if self.staging.numel() < self.world_size * num_elements: + # When running with _memcpy=False (i.e., for benchmarks) we must + # ensure that the staging buffer doesn't contain all zeroes as that + # makes the matmuls go faster (better L2 compression or something). 
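            # Clarifying note: the staging area is reallocated with layout
            # (num_stripes, world_size, num_elements) whenever the current
            # buffer is too small, and every reallocation is followed by a
            # fresh exchange of IPC handles with all peers, since their mmapped
            # views of the previous allocation are no longer valid.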
+ self.staging = torch.empty( + (self.num_stripes, self.world_size, num_elements), + device=self.my_device, + dtype=self.dtype, + ) + if random_init: + self.staging.normal_() + for rank, in_conn in enumerate(self.incoming_conns): + if in_conn is not None: + in_conn.send( + _serialize_cuda_tensor(self.staging[:, rank], self.my_device) + ) + self.buddys_staging = [ + torch.empty((0,), device=self.my_device) + if out_conn is None + else _deserialize_cuda_tensor(out_conn.recv(), self.my_device) + for rank, out_conn in enumerate(self.outgoing_conns) + ] + + def _should_use_triton(self, _triton: bool): + if not int(os.getenv("XFORMERS_FUSED_SEQPAR_ENABLE_TRITON", "1")): + return False + if not TRITON_IS_AVAILABLE: + return False + # Triton seems to be having issues on P100 and V100 GPUs, such as + # https://github.com/openai/triton/issues/1609 + # https://github.com/openai/triton/issues/1610 + # https://github.com/openai/triton/issues/1257#issuecomment-1532616965 + # and, in recent Triton versions (Jan 2024), returning wrong values. + if self.my_device_capability < (8, 0): + return False + if not _triton: + return False + return True + + def make_stream_factory( + self, current_stream: torch.cuda.Stream + ) -> Callable[[], torch.cuda.Stream]: + def result(): + stream = [current_stream, self.second_stream][self.next_stream_idx] + self.next_stream_idx += 1 + self.next_stream_idx %= 2 + return stream + + return result + + def allgather_and_linear( + self, + scattered_inputs: List[torch.Tensor], + my_matmul: Callable[ + [List[torch.Tensor], int, Callable[[], torch.cuda.Stream]], None + ], + timeout_s: int, + _wait: bool = True, + _memcpy: bool = True, + _triton: bool = True, + _is_regular_matmul: bool = False, + _extra_triton_args: Mapping[str, Any] = {}, + ): + """Perform a fused all-gather followed by a linear layer""" + + assert all(si.device == self.my_device for si in scattered_inputs) + assert all(si.dtype == self.dtype for si in scattered_inputs) + + scattered_input_numels = [si.numel() for si in scattered_inputs] + total_scattered_input_numel = sum(scattered_input_numels) + self._ensure_staging_is_large_enough( + total_scattered_input_numel, random_init=_memcpy is False + ) + + stripe = self.next_stripe % self.num_stripes + self.next_stripe += 1 + + seq_num = self.next_seq_nums[stripe] % SEQ_NUM_WRAP_AROUND + prev_seq_num = (seq_num - 1) % SEQ_NUM_WRAP_AROUND + self.next_seq_nums[stripe] += 1 + + stagings = [ + s.view((self.world_size,) + si.shape) + for s, si in zip( + self.staging[stripe, :, :total_scattered_input_numel].split( + scattered_input_numels, dim=-1 + ), + scattered_inputs, + ) + ] + buddys_stagings = [ + [bs] * len(scattered_inputs) + if bs.numel() == 0 + else [ + s.view(si.shape) + for s, si in zip( + bs[stripe, :total_scattered_input_numel].split( + scattered_input_numels, dim=-1 + ), + scattered_inputs, + ) + ] + for bs in self.buddys_staging + ] + + current_stream = torch.cuda.current_stream() + + self.memcpy_stream.wait_stream(current_stream) + + # Wait for buddy to signal that it read from the data before we + # overwrite it (this wait matches up with write [B] below). 
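        # Clarifying note: the check below compares against prev_seq_num, i.e.
        # the value this stripe published on its previous use, and it runs on
        # memcpy_stream so the copies into the peers' staging buffers cannot
        # start before the peers have released them.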
+ if _wait: + WaitValues.OPERATOR( + [ + self.num_reads_from_buddys_staging[ + (self.my_rank + iter_) % self.world_size, stripe + ] + for iter_ in range(1, self.world_size) + ], + prev_seq_num, + self.memcpy_stream, + timeout_s, + ) + + for iter_ in range(1, self.world_size): + dst_rank = (self.my_rank + iter_) % self.world_size + + if _memcpy: + with torch.cuda.stream(self.memcpy_stream): + for bs, si in zip(buddys_stagings[dst_rank], scattered_inputs): + bs.copy_(si) + + self.write_stream.wait_stream(self.memcpy_stream) + + # Signal to buddy that we have written into the data so it can + # read from it (this write matches up with the wait in Triton + # or with wait [A] below). + if _wait: + Memset32bAsync.OPERATOR( + self.num_writes_into_buddys_staging[dst_rank][stripe], + seq_num, + self.write_stream, + ) + + # If we're doing a regular matmul, we have a faster fused Triton kernel! + if _is_regular_matmul and self._should_use_triton(_triton): + # Wait for buddy to signal that it wrote into the data before we + # read from it (this wait matches up with write [A] above). + _launch_triton_matmul( + a_my_shard=scattered_inputs[0].flatten(0, -2), + a=stagings[0].flatten(0, -2), + my_rank=self.my_rank, + world_size=self.world_size, + wait_counters=self.num_writes_into_my_staging, + write_counters=None, + direction=BACKWARDS_WITH_ME_FIRST, + stripe=stripe, + seq_num=seq_num, + num_stripes=self.num_stripes, + timeout_s=timeout_s, + _wait=_wait, + **_extra_triton_args, + ) + + else: + # Not needed, but it prevents the waits from starting much earlier + # than the rest of the op, which is confusing when profiling. + self.wait_stream.wait_stream(current_stream) + + self.second_stream.wait_stream(current_stream) + stream_factory = self.make_stream_factory(current_stream) + + my_matmul(scattered_inputs, self.my_rank, stream_factory) + + for iter_ in range(1, self.world_size): + src_rank = (self.my_rank - iter_) % self.world_size + + # Wait for buddy to signal that it wrote into the data before we + # read from it (this wait matches up with write [A] above). + if _wait: + WaitValues.OPERATOR( + [self.num_writes_into_my_staging[src_rank, stripe]], + seq_num, + self.wait_stream, + timeout_s, + ) + current_stream.wait_stream(self.wait_stream) + self.second_stream.wait_stream(self.wait_stream) + + my_matmul([s[src_rank] for s in stagings], src_rank, stream_factory) + + current_stream.wait_stream(self.second_stream) + + self.write_stream.wait_stream(current_stream) + + # Signal to buddy that we have read from the data so it can + # overwrite it (this write matches up with wait [B] above). 
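        # Clarifying note: write_stream has just synchronized with both compute
        # streams, so this release only becomes visible to the peers after
        # every local matmul that reads from the staging buffers has finished.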
+ if _wait: + WriteValues.OPERATOR( + [ + self.num_reads_from_my_staging[ + (self.my_rank - iter_) % self.world_size + ][stripe] + for iter_ in range(1, self.world_size) + ], + seq_num, + self.write_stream, + ) + + def linear_and_reducescatter( + self, + my_matmul: Callable[ + [List[torch.Tensor], int, Callable[[], torch.cuda.Stream]], None + ], + gathered_outputs: List[torch.Tensor], + scattered_outputs: List[torch.Tensor], + timeout_s: int, + _wait: bool = True, + _memcpy: bool = True, + _triton: bool = True, + _is_regular_matmul: bool = False, + _extra_triton_args: Mapping[str, Any] = {}, + ): + """Perform a fused linear layer followed by a reduce-scatter""" + + assert all(go.device == self.my_device for go in gathered_outputs) + assert all(go.dtype == self.dtype for go in gathered_outputs) + assert all(so.device == self.my_device for so in scattered_outputs) + assert all(so.dtype == self.dtype for so in scattered_outputs) + + scattered_output_numels = [so.numel() for so in scattered_outputs] + total_scattered_output_numel = sum(scattered_output_numels) + self._ensure_staging_is_large_enough( + total_scattered_output_numel, random_init=_memcpy is False + ) + + stripe = self.next_stripe % self.num_stripes + self.next_stripe += 1 + + seq_num = self.next_seq_nums[stripe] % SEQ_NUM_WRAP_AROUND + prev_seq_num = (seq_num - 1) % SEQ_NUM_WRAP_AROUND + self.next_seq_nums[stripe] += 1 + + stagings = [ + s.view((self.world_size,) + so.shape) + for s, so in zip( + self.staging[stripe, :, :total_scattered_output_numel].split( + scattered_output_numels, dim=-1 + ), + scattered_outputs, + ) + ] + buddys_stagings = [ + [bs] * len(scattered_outputs) + if bs.numel() == 0 + else [ + s.view(so.shape) + for s, so in zip( + bs[stripe, :total_scattered_output_numel].split( + scattered_output_numels, dim=-1 + ), + scattered_outputs, + ) + ] + for bs in self.buddys_staging + ] + + current_stream = torch.cuda.current_stream() + + self.wait_stream.wait_stream(current_stream) + + # Wait for buddy to signal that it read from the data before we + # overwrite it (this wait matches up with write [2] below). + if _wait: + WaitValues.OPERATOR( + [ + self.num_reads_from_my_staging[ + (self.my_rank + iter_) % self.world_size + ][stripe] + for iter_ in range(1, self.world_size) + ], + prev_seq_num, + current_stream, + timeout_s, + ) + + # If we're doing a regular matmul, we have a faster fused Triton kernel! + if _is_regular_matmul and self._should_use_triton(_triton): + # Signal to buddy that we have written into the data so it can + # read from it (this write matches up with wait [1] below). + _launch_triton_matmul( + cs=[s.flatten(0, -2) for s in stagings], + cs_my_shard=[ + go[self.my_rank].flatten(0, -2) for go in gathered_outputs + ], + my_rank=self.my_rank, + world_size=self.world_size, + wait_counters=None, + write_counters=self.num_writes_into_my_staging, + direction=FORWARDS_WITH_ME_LAST, + stripe=stripe, + seq_num=seq_num, + num_stripes=self.num_stripes, + timeout_s=timeout_s, + _wait=_wait, + **_extra_triton_args, + ) + + else: + self.second_stream.wait_stream(current_stream) + stream_factory = self.make_stream_factory(current_stream) + + for iter_ in range(1, self.world_size): + dst_rank = (self.my_rank + iter_) % self.world_size + + my_matmul([s[dst_rank] for s in stagings], dst_rank, stream_factory) + + # Signal to buddy that we have written into the data so it can + # read from it (this write matches up with wait [1] below). 
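                # Clarifying note: before signalling, write_stream is made to
                # wait on both compute streams, so the counter update is
                # ordered after the partial matmul that produced the data
                # destined for dst_rank.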
+ if _wait: + self.write_stream.wait_stream(current_stream) + self.write_stream.wait_stream(self.second_stream) + WriteValues.OPERATOR( + [self.num_writes_into_my_staging[dst_rank, stripe]], + seq_num, + self.write_stream, + ) + + my_matmul( + [o[self.my_rank] for o in gathered_outputs], + self.my_rank, + stream_factory, + ) + + current_stream.wait_stream(self.second_stream) + + for iter_ in range(1, self.world_size): + src_rank = (self.my_rank - iter_) % self.world_size + + # Wait for buddy to signal that it wrote into the data before we + # read from it (this wait matches up with the write in Triton + # or with write [1] above). + if _wait: + WaitValues.OPERATOR( + [self.num_writes_into_buddys_staging[src_rank][stripe]], + seq_num, + self.wait_stream, + timeout_s, + ) + + self.memcpy_stream.wait_stream(self.wait_stream) + + if _memcpy: + with torch.cuda.stream(self.memcpy_stream): + for go, bs in zip(gathered_outputs, buddys_stagings[src_rank]): + go[src_rank].copy_(bs) + + current_stream.wait_stream(self.memcpy_stream) + + for go, so in zip(gathered_outputs, scattered_outputs): + torch.sum(go, dim=0, out=so) + + self.write_stream.wait_stream(current_stream) + + # Signal to buddy that we have read from the data so it can + # overwrite it (this write matches up with wait [2] above). + if _wait: + WriteValues.OPERATOR( + [ + self.num_reads_from_buddys_staging[ + (self.my_rank - iter_) % self.world_size, stripe + ] + for iter_ in range(1, self.world_size) + ], + seq_num, + self.write_stream, + ) + + +# We'd store this as an attribute on the PG object itself, but some PGs are +# pybind-bound classes and thus don't support it, so we simulate this as an +# external cache. +CACHE: Dict[int, Optional[_FusedSequenceParallel]] = {} + + +def _can_ranks_communicate_all_to_all_over_nvlink(group: dist.ProcessGroup) -> bool: + # FIXME This is currently overly simplistic, must be improved. The following + # should be enough: + # - ensure that all ranks are running on the same machine (by exchanging + # their /proc/sys/kernel/random/boot_id value) + # - ensure there's P2P between all pairs of ranks (can_device_access_peer + # could help here but it's unclear what happens if target devices aren't + # visible? maybe just trying to exchange IPC handles and catching errors + # would work? note that in any case some ranks might succeed while some + # might fail so we need a barrier to have them all make the same decision) + return dist.get_world_size(group=group) <= 8 + + +def _lazy_init( + device: torch.device, dtype: torch.dtype, group: dist.ProcessGroup, num_stripes: int +) -> Optional[_FusedSequenceParallel]: + world_size = dist.get_world_size(group=group) + try: + obj = CACHE[id(group)] + except KeyError: + if int(os.environ.get("DISABLE_FUSED_SEQUENCE_PARALLEL", "0")): + obj = None + elif world_size == 1: + obj = None + elif not _can_ranks_communicate_all_to_all_over_nvlink(group): + obj = None + else: + obj = _FusedSequenceParallel(device, dtype, group, num_stripes) + CACHE[id(group)] = obj + return obj + + +def _default_stream_factory() -> torch.cuda.Stream: + return torch.cuda.current_stream() + + +@overload +def fused_allgather_and_linear( + scattered_input: torch.Tensor, + weight: torch.Tensor, + *, + group: dist.ProcessGroup, + out: Optional[torch.Tensor] = None, + num_stripes: int = 1, + timeout_s: int = 60 * 60, + **private_args_DO_NOT_USE, +) -> torch.Tensor: + ... 
+ + +@overload +def fused_allgather_and_linear( + scattered_input: torch.Tensor, + weight: List[torch.Tensor], + *, + group: dist.ProcessGroup, + out: Optional[List[torch.Tensor]] = None, + num_stripes: int = 1, + timeout_s: int = 60 * 60, + **private_args_DO_NOT_USE, +) -> List[torch.Tensor]: + ... + + +def fused_allgather_and_linear( + scattered_input: torch.Tensor, + weight: Union[torch.Tensor, List[torch.Tensor]], + *, + group: dist.ProcessGroup, + out: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, + num_stripes: int = 1, + timeout_s: int = 60 * 60, + **private_args_DO_NOT_USE, +) -> Union[torch.Tensor, List[torch.Tensor]]: + """Performs a fused all-gather followed by a linear op + + It is equivalent to the following plain PyTorch code: + + # like scattered_input but with first dim multiplied by group's world size + gathered_input = scattered_input.new_empty(...) + dist.all_gather_into_tensor(gathered_input, scattered_input, group=group) + return torch.nn.functional.linear(gathered_input, weight) + + It achieves this by breaking down the matmul into smaller partial ops (as + many as the world size), each needing as input a different "contribution" + to the all-gather (by a different rank), and writing to a different chunk of + the output. Then, on one stream, it sends the local contribution to all + other ranks (first one rank over, then two, ...) while, on another stream, + it launches the sub-matmuls in the order in which the remote contributions + (which are the sub-matmuls' inputs) are supposed to arrive, so that ideally + none of the sub-matmuls will ever have to wait. + + The idea comes from this paper: https://arxiv.org/abs/2302.05442 + + This method uses a staging buffer, which persists across calls, of the same + size as the all-gathered input tensor (i.e., the input's size times the + world size). If multiple inputs of multiple sizes are used, the staging + buffer will be the maximum needed by any of them. Each call, when it starts, + must first wait for the previous call to finish using the staging buffer. In + normal conditions, where there's some other operation between two calls, + this isn't an issue. However, when doing back-to-back calls (like in + benchmarks) it can introduce artificial delays. To hide them, we allow using + more than one staging buffer, which will be cycled through, thus trading + memory for speed. This can be controlled using the num_stripes argument. 
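    A minimal usage sketch (illustrative only: shapes and dtypes are arbitrary,
    and it assumes a CUDA process group has already been initialized, e.g. via
    torchrun, with one GPU selected per rank):

        import torch
        import torch.distributed as dist

        from xformers.ops.sequence_parallel_fused_ops import fused_allgather_and_linear

        group = dist.group.WORLD
        # Each rank holds its shard of the sequence: (local_tokens, in_features)
        scattered_input = torch.randn(512, 1024, device="cuda", dtype=torch.bfloat16)
        # Standard linear weight layout: (out_features, in_features)
        weight = torch.randn(4096, 1024, device="cuda", dtype=torch.bfloat16)
        out = fused_allgather_and_linear(scattered_input, weight, group=group)
        # out has shape (512 * world_size, 4096), i.e., the same result as
        # torch.nn.functional.linear applied to the all-gathered input.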
+ + """ + world_size = dist.get_world_size(group=group) + weights = weight if isinstance(weight, list) else [weight] + assert all(w.ndim == 2 for w in weights) + assert scattered_input.ndim >= 2 + assert all(scattered_input.shape[-1] == w.shape[-1] for w in weights) + assert scattered_input.is_contiguous() + gathered_input_shape = (world_size,) + scattered_input.shape + gathered_output_shapes = [gathered_input_shape[:-1] + w.shape[:-1] for w in weights] + if out is not None: + assert isinstance(out, list) == isinstance(weight, list) + gathered_outputs = out if isinstance(out, list) else [out] + assert len(gathered_outputs) == len(gathered_output_shapes) + assert all( + go.shape == gos for go, gos in zip(gathered_outputs, gathered_output_shapes) + ) + assert all(go.is_contiguous() for go in gathered_outputs) + else: + gathered_outputs = [ + scattered_input.new_empty(gos) for gos in gathered_output_shapes + ] + + def my_matmul( + inputs: List[torch.Tensor], + src_rank: int, + stream_factory: Callable[[], torch.cuda.Stream], + ) -> None: + for w, go in zip(weights, gathered_outputs): + with torch.cuda.stream(stream_factory()): + torch.matmul(inputs[0], w.t(), out=go[src_rank]) + + fused_allgather_and_anything( + [scattered_input], + my_matmul, + group=group, + num_stripes=num_stripes, + timeout_s=timeout_s, + _is_regular_matmul=True, + _extra_triton_args=dict( + bs=[w.t() for w in weights], + cs=[go.flatten(0, -2) for go in gathered_outputs], + cs_my_shard=None, + ), + **private_args_DO_NOT_USE, + ) + + if isinstance(weight, list): + return [go.flatten(0, 1) for go in gathered_outputs] + else: + return gathered_outputs[0].flatten(0, 1) + + +def fused_allgather_and_anything( + scattered_inputs: List[torch.Tensor], + my_matmul: Callable[ + [List[torch.Tensor], int, Callable[[], torch.cuda.Stream]], None + ], + *, + group: dist.ProcessGroup, + num_stripes: int = 1, + timeout_s: int = 60 * 60, + **private_args_DO_NOT_USE, +) -> None: + world_size = dist.get_world_size(group=group) + + if len(scattered_inputs) == 0: + for src_rank in range(world_size): + my_matmul([], src_rank, _default_stream_factory) + return + + assert all(si.is_contiguous() for si in scattered_inputs) + assert all(si.device == scattered_inputs[0].device for si in scattered_inputs) + assert all(si.dtype == scattered_inputs[0].dtype for si in scattered_inputs) + + gathered_input_shapes = [(world_size,) + si.shape for si in scattered_inputs] + + obj = _lazy_init( + scattered_inputs[0].device, scattered_inputs[0].dtype, group, num_stripes + ) + + if world_size == 1: + my_matmul(scattered_inputs, 0, _default_stream_factory) + + # Fallback + elif obj is None: + gathered_inputs = [ + si.new_empty(gis) + for si, gis in zip(scattered_inputs, gathered_input_shapes) + ] + for si, gi in zip(scattered_inputs, gathered_inputs): + dist.all_gather_into_tensor(output_tensor=gi, input_tensor=si, group=group) + for src_rank in range(world_size): + my_matmul( + [gi[src_rank] for gi in gathered_inputs], + src_rank, + _default_stream_factory, + ) + + # Fast path + else: + assert scattered_inputs[0].device == obj.my_device + assert scattered_inputs[0].dtype == obj.dtype + assert obj.num_stripes == num_stripes + obj.allgather_and_linear( + scattered_inputs, + my_matmul, + timeout_s=timeout_s, + _wait=private_args_DO_NOT_USE.get("_wait", True), + _memcpy=private_args_DO_NOT_USE.get("_memcpy", True), + _triton=private_args_DO_NOT_USE.get("_triton", True), + _is_regular_matmul=private_args_DO_NOT_USE.get("_is_regular_matmul", False), + 
_extra_triton_args=private_args_DO_NOT_USE.get("_extra_triton_args", {}), + ) + + +@overload +def fused_linear_and_reducescatter( + gathered_input: torch.Tensor, + weight: torch.Tensor, + *, + group: dist.ProcessGroup, + out: Optional[torch.Tensor] = None, + num_stripes: int = 1, + timeout_s: int = 60 * 60, + **private_args_DO_NOT_USE, +) -> torch.Tensor: + ... + + +@overload +def fused_linear_and_reducescatter( + gathered_input: torch.Tensor, + weight: List[torch.Tensor], + *, + group: dist.ProcessGroup, + out: Optional[List[torch.Tensor]] = None, + num_stripes: int = 1, + timeout_s: int = 60 * 60, + **private_args_DO_NOT_USE, +) -> List[torch.Tensor]: + ... + + +def fused_linear_and_reducescatter( + gathered_input: torch.Tensor, + weight: Union[torch.Tensor, List[torch.Tensor]], + *, + group: dist.ProcessGroup, + out: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, + num_stripes: int = 1, + timeout_s: int = 60 * 60, + **private_args_DO_NOT_USE, +) -> Union[torch.Tensor, List[torch.Tensor]]: + """Performs a fused linear op followed by a reduce-scatter + + It is equivalent to the following plain PyTorch code: + + gathered_output = torch.nn.functional.linear(gathered_input, weight) + # like gathered_output but with first dim divided by group's world size + scattered_output = gathered_output.new_empty(...) + dist.reduce_scatter_tensor(scattered_output, gathered_output, group=group) + + """ + world_size = dist.get_world_size(group=group) + weights = weight if isinstance(weight, list) else [weight] + assert all(w.ndim == 2 for w in weights) + assert gathered_input.ndim >= 2 + assert all(gathered_input.shape[-1] == w.shape[-1] for w in weights) + assert gathered_input.is_contiguous() + assert gathered_input.shape[0] % world_size == 0 + gathered_input = gathered_input.view( + (world_size, gathered_input.shape[0] // world_size) + gathered_input.shape[1:] + ) + gathered_output_shapes = [gathered_input.shape[:-1] + w.shape[:-1] for w in weights] + scattered_output_shapes = [gos[1:] for gos in gathered_output_shapes] + if out is not None: + assert isinstance(out, list) == isinstance(weight, list) + scattered_outputs = out if isinstance(out, list) else [out] + assert len(scattered_outputs) == len(scattered_output_shapes) + assert all(so.device == gathered_input.device for so in scattered_outputs) + assert all(so.dtype == gathered_input.dtype for so in scattered_outputs) + assert all( + so.shape == sos + for so, sos in zip(scattered_outputs, scattered_output_shapes) + ) + else: + scattered_outputs = [ + gathered_input.new_empty(sos) for sos in scattered_output_shapes + ]
+ + def my_matmul( + outputs: List[torch.Tensor], + dst_rank: int, + stream_factory: Callable[[], torch.cuda.Stream], + ) -> None: + for w, o in zip(weights, outputs): + with torch.cuda.stream(stream_factory()): + torch.matmul(gathered_input[dst_rank], w.t(), out=o) + + fused_anything_and_reducescatter( + my_matmul, + scattered_outputs, + group=group, + num_stripes=num_stripes, + timeout_s=timeout_s, + _is_regular_matmul=True, + _extra_triton_args=dict( + a_my_shard=None, + a=gathered_input.flatten(0, -2), + bs=[w.t() for w in weights], + ), + **private_args_DO_NOT_USE, + ) + + if isinstance(weight, list): + return scattered_outputs + else: + return scattered_outputs[0] + + +def fused_anything_and_reducescatter( + my_matmul: Callable[ + [List[torch.Tensor], int, Callable[[], torch.cuda.Stream]], None + ], + scattered_outputs: List[torch.Tensor], + *, + group: dist.ProcessGroup, + num_stripes: int = 1, + timeout_s: int = 60 * 60,
+ **private_args_DO_NOT_USE, +) -> None: + world_size = dist.get_world_size(group=group) + + if len(scattered_outputs) == 0: + for dst_rank in range(world_size): + my_matmul([], dst_rank, _default_stream_factory) + return + + assert all(so.is_contiguous() for so in scattered_outputs) + assert all(so.device == scattered_outputs[0].device for so in scattered_outputs) + assert all(so.dtype == scattered_outputs[0].dtype for so in scattered_outputs) + + gathered_output_shapes = [(world_size,) + so.shape for so in scattered_outputs] + + obj = _lazy_init( + scattered_outputs[0].device, scattered_outputs[0].dtype, group, num_stripes + ) + + if world_size == 1: + my_matmul(scattered_outputs, 0, _default_stream_factory) + + # Fallback + elif obj is None: + gathered_outputs = [ + so.new_empty(gos) + for so, gos in zip(scattered_outputs, gathered_output_shapes) + ] + for dst_rank in range(world_size): + my_matmul( + [go[dst_rank] for go in gathered_outputs], + dst_rank, + _default_stream_factory, + ) + for go, so in zip(gathered_outputs, scattered_outputs): + dist.reduce_scatter_tensor(output=so, input=go, group=group) + + # Fast path + else: + assert scattered_outputs[0].device == obj.my_device + assert scattered_outputs[0].dtype == obj.dtype + assert obj.num_stripes == num_stripes + gathered_outputs = [ + scattered_outputs[0].new_empty(gos) for gos in gathered_output_shapes + ] + obj.linear_and_reducescatter( + my_matmul, + gathered_outputs, + scattered_outputs, + timeout_s=timeout_s, + _wait=private_args_DO_NOT_USE.get("_wait", True), + _memcpy=private_args_DO_NOT_USE.get("_memcpy", True), + _triton=private_args_DO_NOT_USE.get("_triton", True), + _is_regular_matmul=private_args_DO_NOT_USE.get("_is_regular_matmul", False), + _extra_triton_args=private_args_DO_NOT_USE.get("_extra_triton_args", {}), + ) diff --git a/xformers/ops/tiled_matmul.py b/xformers/ops/tiled_matmul.py new file mode 100644 index 0000000000..99cf8a6081 --- /dev/null +++ b/xformers/ops/tiled_matmul.py @@ -0,0 +1,247 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import os +from typing import List, Optional + +import torch +import torch.multiprocessing.reductions +from torch.utils._pytree import tree_flatten, tree_unflatten +from typing_extensions import Annotated + +from .. import _is_triton_available +from .common import Alias, make_pytorch_operator_for_dispatch_key + +if _is_triton_available(): + from ._triton.tiled_matmul_kernels import _launch_triton_matmul + + TRITON_IS_AVAILABLE = True +else: + TRITON_IS_AVAILABLE = False + + +# Copied over from the sequence parallel fused ops. +def _should_use_triton(device: torch.device, dtype: torch.dtype) -> bool: + if not int(os.getenv("XFORMERS_TILED_MATMUL_ENABLE_TRITON", "1")): + return False + if not TRITON_IS_AVAILABLE: + return False + device_capability = torch.cuda.get_device_capability(device) + # Triton seems to be having issues on P100 and V100 GPUs, such as + # https://github.com/openai/triton/issues/1609 + # https://github.com/openai/triton/issues/1610 + # https://github.com/openai/triton/issues/1257#issuecomment-1532616965 + # and, in recent Triton versions (Jan 2024), returning wrong values. + if device_capability < (8, 0): + return False + return True + + +# We can't use make_pytorch_cuda_operator because PyTorch isn't able to inspect +# Tensor[][] args to detect they contain CUDA args. 
Thus we need to register +# this as a fallback implementation, so it gets invoked regardless of the args. +# See: https://github.com/pytorch/pytorch/issues/113022 +@make_pytorch_operator_for_dispatch_key("") +def tiled_matmul_fwd( + a: List[List[torch.Tensor]], + b: List[List[torch.Tensor]], + out: Optional[List[List[Annotated[torch.Tensor, Alias("a", write=True)]]]] = None, +) -> List[List[Annotated[torch.Tensor, Alias("a", write=True)]]]: + assert len(a) >= 1 and len(a[0]) >= 1 and all(len(row) == len(a[0]) for row in a), ( + "the first operand must be a non-empty two-dimensional regular list of lists " + "of tensors" + ) + assert len(b) >= 1 and len(b[0]) >= 1 and all(len(row) == len(b[0]) for row in b), ( + "the second operand must be a non-empty two-dimensional regular list of lists " + "of tensors" + )
+ + m_tiles = len(a) + k_tiles = len(a[0]) + assert len(b) == k_tiles, ( + "the first operand's inner dimension must match the second operand's outer " + f"dimension, got {k_tiles} and {len(b)}" + ) + n_tiles = len(b[0]) + + ms = [a[tile_m][0].shape[0] for tile_m in range(m_tiles)] + ns = [b[0][tile_n].shape[1] for tile_n in range(n_tiles)] + aks = [a[0][tile_k].shape[1] for tile_k in range(k_tiles)] + bks = [b[tile_k][0].shape[0] for tile_k in range(k_tiles)]
+ + for tile_m in range(m_tiles): + for tile_k in range(k_tiles): + assert a[tile_m][tile_k].shape[0] == ms[tile_m], ( + f"the tensors on row {tile_m} of the first operand must all have the " + f"same size along the m dimension, got {ms[tile_m]} at position 0 and " + f"{a[tile_m][tile_k].shape[0]} at position {tile_k}" + ) + assert a[tile_m][tile_k].shape[1] == aks[tile_k], ( + f"the tensors on column {tile_k} of the first operand must all have " + f"the same size along the k dimension, got {aks[tile_k]} at position 0 " + f"and {a[tile_m][tile_k].shape[1]} at position {tile_m}" + )
+ + for tile_n in range(n_tiles): + for tile_k in range(k_tiles): + assert b[tile_k][tile_n].shape[0] == bks[tile_k], ( + f"the tensors on row {tile_k} of the second operand must all have the " + f"same size along the k dimension, got {bks[tile_k]} at position 0 and " + f"{b[tile_k][tile_n].shape[0]} at position {tile_n}" + ) + assert b[tile_k][tile_n].shape[1] == ns[tile_n], ( + f"the tensors on column {tile_n} of the second operand must all have " + f"the same size along the n dimension, got {ns[tile_n]} at position 0 " + f"and {b[tile_k][tile_n].shape[1]} at position {tile_k}" + )
+ + for tile_k in range(k_tiles): + assert aks[tile_k] == bks[tile_k], ( + f"the tensors on column {tile_k} of the first operand and those on row " + f"{tile_k} of the second operand must have the same size along the k " + f"dimension, got {aks[tile_k]} and {bks[tile_k]}" + ) + ks = aks
+ + if out is not None: + assert ( + len(out) >= 1 + and len(out[0]) >= 1 + and all(len(row) == len(out[0]) for row in out) + ), "out must be a non-empty two-dimensional regular list of lists of tensors" + assert len(out) == m_tiles + assert len(out[0]) == n_tiles + cms = [out[tile_m][0].shape[0] for tile_m in range(m_tiles)] + cns = [out[0][tile_n].shape[1] for tile_n in range(n_tiles)] + for tile_m in range(m_tiles): + for tile_n in range(n_tiles): + assert out[tile_m][tile_n].shape[0] == cms[tile_m], ( + f"the tensors on row {tile_m} of out must all have the same size " + f"along the m dimension, got {cms[tile_m]} at position 0 and " + f"{out[tile_m][tile_n].shape[0]} at position {tile_n}" + ) + assert out[tile_m][tile_n].shape[1] == cns[tile_n], ( + f"the tensors on column {tile_n} of out must all have the same " + f"size along the n dimension, got {cns[tile_n]} at position 0 and " + f"{out[tile_m][tile_n].shape[1]} at position {tile_m}" + ) + for tile_m in range(m_tiles): + assert cms[tile_m] == ms[tile_m], ( + f"the tensors on row {tile_m} of out and those on row {tile_m} of the " + f"first operand must have the same size along the m dimension, got " + f"{cms[tile_m]} and {ms[tile_m]}" + ) + for tile_n in range(n_tiles): + assert cns[tile_n] == ns[tile_n], ( + f"the tensors on column {tile_n} of out and those on column {tile_n} " + f"of the second operand must have the same size along the n dimension, " + f"got {cns[tile_n]} and {ns[tile_n]}" + ) + c = out + else: + c = [[a[0][0].new_empty((m, n)) for n in ns] for m in ms]
+ + # TODO We can try merging tiles that come from contiguous memory, using + # stack_or_none, to further improve performance. + + # The Triton kernel is hardcoded for at most three tiles per dimension, + # because we aimed this at the fusion of wq/wk/wv. + if ( + m_tiles <= 3 + and k_tiles <= 3 + and n_tiles <= 3 + and _should_use_triton(a[0][0].device, a[0][0].dtype) + ): + _launch_triton_matmul(a, b, c, ms, ns, ks) + else: + for tile_m in range(len(ms)): + for tile_n in range(len(ns)): + torch.mm(a[tile_m][0], b[0][tile_n], out=c[tile_m][tile_n]) + for tile_k in range(1, len(ks)): + c[tile_m][tile_n].addmm_(a[tile_m][tile_k], b[tile_k][tile_n]) + + return c
+ + +def _transpose(x: List[List[torch.Tensor]]) -> List[List[torch.Tensor]]: + return [[t.t() for t in y] for y in zip(*x)] + + +class _TiledMatmul(torch.autograd.Function): + @staticmethod + def forward(ctx, ab_tree_spec, *ab_tree_values): + ctx.ab_tree_spec = ab_tree_spec + ctx.save_for_backward(*ab_tree_values) + a, b = tree_unflatten(list(ab_tree_values), ab_tree_spec) + + c = tiled_matmul_fwd(a, b) + + c_tree_values, c_tree_spec = tree_flatten(c) + ctx.c_tree_spec = c_tree_spec + return (c_tree_spec,) + tuple(c_tree_values) + + @staticmethod + def backward(ctx, _none, *grad_c_tree_values): + a, b = tree_unflatten(list(ctx.saved_tensors), ctx.ab_tree_spec) + grad_c = tree_unflatten(list(grad_c_tree_values), ctx.c_tree_spec) + + grad_a = tiled_matmul_fwd(grad_c, _transpose(b)) + grad_b = tiled_matmul_fwd(_transpose(a), grad_c) + + grad_ab_tree_values, grad_ab_tree_spec = tree_flatten((grad_a, grad_b)) + return (None,) + tuple(grad_ab_tree_values)
+ + +def tiled_matmul( + a: List[List[torch.Tensor]], + b: List[List[torch.Tensor]], +) -> List[List[torch.Tensor]]: + """Multiply two matrices given as grids of tiles + + It performs the matmul between A and B, which are given as two-dimensional + grids of tiles (i.e., blocks), represented as lists of lists of tensors. + The output will itself be a matrix in such a form. Formally: + + out[m][n] = sum(a[m][k] @ b[k][n] for k in range(...)) + + with the obvious constraints needed to make it work, in terms of number of + tiles and sizes of each tile. + + The purpose of this operator is to improve performance by avoiding wave + quantization effects when doing independent matrix multiplications in + series. Sometimes, when these matmuls have one operand in common, this can + also be addressed by concatenating the other operands into a single matrix, + and issuing a single matmul. However, this isn't always possible (e.g., it might + break the checkpoint format) and it's an anti-pattern, as it obscures the + logic (e.g., changing the modelling code for performance reasons).
This + tiled matmul performs the same computation as if the matrices were merged, + without merging them, simply through a smarter memory addressing scheme. + + The tiled matmul is less generic than a grouped matmul, which can also help + with wave quantization and doesn't need the matmuls to have the same lhs + or rhs operand. However, a grouped matmul will write the result of each + matmul to a separate output matrix, whereas the tiled matmul allows adding + them together into a single output. This is needed during the backward pass + of a linear layer, and it's the reason we wrote this instead of using a + grouped matmul. + + The tiled matmul is implemented using a custom Triton kernel, which puts + constraints on the strides of the tiles. All rows of A must have the same + K stride, all columns of A must have the same M stride, and so on. + + Currently the tiled matmul supports at most three tiles along each dimension, + although fewer can also be given. This is because we needed it to fuse the + query, key and value weights of an attention layer. This limit can be + increased if needed. + + This operator is differentiable. + + """ + ab_tree_values, ab_tree_spec = tree_flatten((a, b)) + c_tree_spec, *c_tree_values = _TiledMatmul.apply(ab_tree_spec, *ab_tree_values) + c = tree_unflatten(list(c_tree_values), c_tree_spec) + + return c
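
A rough usage sketch of tiled_matmul for the wq/wk/wv fusion mentioned in the docstring (the shapes, device and dtype below are illustrative assumptions):

import torch

from xformers.ops.tiled_matmul import tiled_matmul

# Activation and the three projection weights (torch.nn.Linear convention,
# i.e. [out_features, in_features]); the sizes here are made up for the sketch.
x = torch.randn(1024, 512, device="cuda", dtype=torch.bfloat16)
wq = torch.randn(512, 512, device="cuda", dtype=torch.bfloat16)
wk = torch.randn(512, 512, device="cuda", dtype=torch.bfloat16)
wv = torch.randn(512, 512, device="cuda", dtype=torch.bfloat16)

# A is a 1x1 grid (just x), B is a 1x3 grid (one column of tiles per
# projection), so the result is a 1x3 grid holding q, k and v.
out = tiled_matmul([[x]], [[wq.t(), wk.t(), wv.t()]])
q, k, v = out[0]
# Equivalent to: q = x @ wq.t(); k = x @ wk.t(); v = x @ wv.t()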