[fused seqpar] Make tests faster by disabling autotuning
ghstack-source-id: e49d0b35debd6ab5d34ba328340ac00f1f1fc0bf
Pull Request resolved: fairinternal/xformers#1021

__original_commit__ = fairinternal/xformers@24957df
lw authored and xFormers Bot committed Jan 30, 2024
1 parent e6e6695 commit cf5e6b4
Showing 2 changed files with 29 additions and 0 deletions.
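Note: the change relies on the fact that a Triton kernel decorated with @triton.autotune is wrapped in an Autotuner object whose candidate configurations live in a mutable .configs list, so a test module can trim that list down to a single entry before the kernel is launched for the first time. The sketch below is purely illustrative and not part of this commit; the kernel, block sizes, and names in it are hypothetical.

import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({"BLOCK": 64}, num_warps=2),
        triton.Config({"BLOCK": 128}, num_warps=4),
        triton.Config({"BLOCK": 256}, num_warps=8),
    ],
    key=["n_elements"],
)
@triton.jit
def _copy_kernel(x_ptr, y_ptr, n_elements, BLOCK: tl.constexpr):
    # Hypothetical autotuned kernel: copies n_elements values from x to y.
    pid = tl.program_id(0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    tl.store(y_ptr + offsets, tl.load(x_ptr + offsets, mask=mask), mask=mask)


# Keep only the first config: results are identical, but the autotuner no
# longer benchmarks every candidate on the first call, which is what slows
# the tests down.
while len(_copy_kernel.configs) > 1:
    _copy_kernel.configs.pop()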
17 changes: 17 additions & 0 deletions tests/test_seqpar.py
@@ -10,6 +10,7 @@
import pytest
import torch

from xformers import _is_triton_available
from xformers.ops import (
    sequence_parallel_leading_matmul,
    sequence_parallel_trailing_matmul,
@@ -28,6 +29,22 @@
)


# We care about correctness, not performance, hence let's "disable" the
# expensive autotuning by removing all configs except one (the first one).
if _is_triton_available():
    from xformers.ops._triton.sequence_parallel_fused_kernels import (
        _xformers_seqpar_matmul_kernel,
    )

    while len(_xformers_seqpar_matmul_kernel.configs) > 1:
        _xformers_seqpar_matmul_kernel.configs.pop()

    from xformers.ops._triton.tiled_matmul_kernels import _xformers_tiled_matmul_kernel

    while len(_xformers_tiled_matmul_kernel.configs) > 1:
        _xformers_tiled_matmul_kernel.configs.pop()


def reference_leading(input_, w1, w2):
    hidden1 = torch.matmul(input_, w1.t())
    hidden2 = torch.matmul(input_, w2.t())
12 changes: 12 additions & 0 deletions tests/test_sequence_parallel_fused_ops.py
@@ -11,6 +11,7 @@
import pytest
import torch

from xformers import _is_triton_available
from xformers.ops import fused_allgather_and_linear, fused_linear_and_reducescatter

from .multiprocessing_utils import launch_subprocesses
@@ -26,6 +27,17 @@
)


# We care about correctness, not performance, hence let's "disable" the
# expensive autotuning by removing all configs except one (the first one).
if _is_triton_available():
    from xformers.ops._triton.sequence_parallel_fused_kernels import (
        _xformers_seqpar_matmul_kernel,
    )

    while len(_xformers_seqpar_matmul_kernel.configs) > 1:
        _xformers_seqpar_matmul_kernel.configs.pop()


def inner_sequence_parallel_fused(
    seed: int,
    kind: str,
