Fast Multi-head Attention support on AMD ROCM (#978)
* add option to build a standalone runner for splitk decoder; debugging numerics in reduction

* fix a few bugs

* fix an indexing bug

* stash changes

* Add benchmark_mem_eff_attn_mqa_gqa_ck_tiled.py to benchmark mqa/gqa performance on ck-tiled fmha

* Synchronize with latest update in composable_kernel_tiled feature/fmha-pad-support branch

* Tiny fix in benchmark_mem_eff_attn_mqa_gqa_ck_tiled.py

* Synchronize with latest update in composable_kernel_tiled and make all unit_tests pass

* Switch to new branch for composable_kernel_tiled submodule

* Add bfp16 instances for ck-tiled inference

* Update to test and benchmark scripts to include bfloat16

* Tiny update to ck_tiled kernel

* Change to benchmark_mem_eff_attn_mqa_gqa_ck_tiled benchmark cases

* stash changes

* Use Async pipeline for no M/N0K1 padding cases

* Add CK_FMHA_FWD_FAST_EXP2 to the build

* Add Triton FA2 forward op

* Add Triton Flash Attention 2 to benchmarks

* Synchronize with latest third_party/composable_kernel and remove the inner_product bhalf_t overloading in ck_attention_forward_decoder.h

* stash split attention testing wip

* Synchronize with latest third_party/composable_kernel again

* Synchronize with latest third_party/composable_kernel_tiled

* Change to make ck decoder buildable with both ck tiled or non-tiled fmha kernel

* Change to make ck decoder buildable with both ck tiled or non-tiled fmha kernel

* fix gqa for split-k=1

* Skip backward tests, fix import

* fix the mask for decoding; row max and lse are computed correctly; debugging must go on

* make libtorch split-1 decoder implementation pass numerical correctness

* Disable CK kernel for large shapes, better catch OOMs

* Actually remove submodule composable_kernel_tiled from the branch

* Change the domain for the repo of composable_kernel submodule to ROCm

* Update to validate_inputs() in common.py to support 4d mqa/gqa

* synchronize test_mem_eff_attention_ck.py with test_mem_eff_attention.py

* Tiny update in benchmark_mem_eff_attn_decoder_ck.py

* Synchronize benchmark_mem_eff_attention_ck.py with benchmark_mem_eff_attention.py

* Remove benchmark_mem_eff_attn_decoder_ck_tiled.py

* Support for Generic Attention Mask Coordinate

* Add ck.FwOp and ck.BwOp to dispatched operations

* Add ck.FwOp and ck.BwOp to ALL_FW_OPS and ALL_BW_OPS

* Update in tests/readme_test_on_rocm.txt

* Add ckF and ck_decoder to benchmark_mem_eff_attn_decoder.py

* Synchronize with the latest ck-tiled commits

* Add is_ck_tiled_used() c++ extension interface for judging if ck-tiled is used

* Remove composable_kernel_tiled submodule

* inner_product removed from splitk kernel code

* remove some commented out debug code

* comment out debug code calling libtorch instead of hip implementation

* remove commented out old and incorrect code fragments

* add python version override to cmakelists

* add conversion from Argument struct to string; fix split1 test crash

 -- fyi device guard needs to be declared to avoid segfaults in the kernel

* add f32 support in the python op

* refactor out input generation in cpp standalone

* set loop unrolls to 1 in order to avoid index errors (will need to be fixed later for perf)

* fix output splits allocation

* fix bug in split attention: sumexp needs timestep bounds in each split

* clang-format-10

* Enable support of attn-bias types with LocalAttention

* Enable support of attn-bias types with LocalAttention

* Synchronize submodule composable_kernel to the latest commits

* Make the efficient_attention_forward_ck() C++ interface consistent with the updating of xformers/ops/fmha API

* Tiny fix in ck.py to make test_backward pass

* some refactorings for standalone tests

* cleanup testing

* Make the efficient_attention_forward_ck() C++ interface consistent with the updating of xformers/ops/fmha API

* Tiny fix in ck.py to make test_backward pass

* fix split1 attention csrc test

* Enable support of flexible head-dim size (but <= 128) for ck-tiled fmha forward

* Use Async pipeline when no padding is used

* implement general split-k split-attention in libtorch, use for testing

* fix split-max and split-sumexp shapes for split attention in libtorch

* implement generic reduce split attention with libtorch

* implement testing split reduce hip vs libtorch; tbd debug split-k=2 numerical mismatch in this test

* refactor repetitive testing code

* address code review: rearrange loops

* address code review: add comment about number of iterations per split

* address code review: remove comments

* address code review: possibly eliminate a bug by using correct timestep range for scaling sumexp in smem

* address code review: add todo

* address code review: shift LDS access by tt_low to avoid smem overbooking

* address code review: simplify reduction loops in split attention

* Tiny update in ck-tiled forward kernel

* address code review: merge for loops

* address code review: simplify coefficient pick

* fix runtime error message in testing code

* fix split reduce test

* address code review: fix smem offsets

* remove redundant comment

* address code review: initialize split attention workspace as empty

* address code review: rename local vars

* address code review: remove unused _rand_seqlens

* address code review: cleanup python tests

* remove redundant new_max local var

* address code review: rename seq_acc

* re-enable loop unroll; adjust tests to handle splits with size divisible by block size; handle empty splits correctly

* test a wider range of split-k in cpp tests; fix torch implementation one more time to handle empty splits

* Synchronize with ck-tiled update to support head-dim-256 and LSE storing

* Add definition of FMHA_FWD_HEADDIM_SWITCH

* Split the ck-tiled inference instances based on head-dim sizes to improve compiling

* Setting k0n1_need_padding according to pipeline kQLoadOnce implementation

* Add fmha forward c++ extension for ck-tiled

* Set SUPPORTED_MAX_K=256 in ck.py

* fix index in split-k attention

* fix index in softmax reduce and complete fixing wavefronts per block optimization

* clang-format-10

* Fix v_dram_transposed transpose transform in the kernel

* Skip triton_splitk for test_forward in test_mem_eff_attention.py

* cleanup commented dead code

* enable ck split-k in benchmark_attn_decoding

* add rocm_ci workflow

* move scipy import from file level under function similar to _vec_binom_test

saves a few keystrokes when setting up environment

* Add including of math_v2.hpp to ck_attention_forward_decoder_splitk.h

* move forward_splitk to ck_splitk; make dispatch aware of ck_splitk and ck_decoder

* Synchronize to latest ck-tiled and update accordingly

* fix benchmark_attn_decoding

* Remove third_party/composable_kernel_tiled

* [Fix] use kK0BlockLength for HeadDim256 padding judging

* Tiny type change for custom_mask_type in param class

* Change to use ROCm repo for ck-tiled submodule

* Remove tests/test_forward_ck_tiled.py

* Update to test_mqa_forward_ck_tiled.py to use common create_attn_bias method

* Add ck-tiled checking in test_mqa_forward_ck_tiled.py

* rearrange smem access in softmax reduction

* Add test_decoder and test_splitk_decoder for ROCM into test_mem_eff_attention.py

* Add ref_attention_splitk and its test to tests/test_mem_eff_attention.py

* Rename test_mem_eff_attention_ck.py as discarded

* Add test_mqa_forward and ref_attention_mqa (for BMHK format mqa/gqa verification) into test_mem_eff_attention.py

* Rename test_mqa_forward_ck_tiled.py as discarded

* Remove CK specific script benchmark_mem_eff_attn_decoder_ck.py

* Refine benchmark_mem_eff_attn_mqa_gqa_ck_tiled.py

* Rename benchmark_mem_eff_attn_mqa_gqa_ck_tiled.py to benchmark_mem_eff_attention_mqa.py

* Remove the runtime_error raised when using logsumexp in attention_forward_generic_ck_tiled.cpp

* Add ck-tiled checking in ck.py

* Remove CK-specific benchmark scripts

* Don't require is_cpu_tensor for seqstart_q/seqstart_k/seqlen_k in attention_forward_generic_ck_tiled

* Remove seqlen_cpu from  _PaddedSeqLenInfo in attn_bias.py

* Change the branch for composable_kernel_tiled submodule and update to latest

* Remove the use of seqlen_cpu in BwOp of ck.py

* Remove the use of seqlen_cpu in BwOp of ck.py

* Align .clang_format with main branch and re-format c++ files

* Synchronize to latest ck-tiled commit

* Add checking of IS_CK_TILED into some testing scripts

* Update to test_mem_eff_attention.py and ck.py

* Building xformers using ck-tiled as default

* ensure ck_decoder does not dispatch

* Add disable_on_rocm on some test scripts

* Update to test_mem_eff_attention.py

* apply isort

* apply black

* fix flake8 suggestions

* add license headers and reapply black

* Tiny update to rocm_ci.yml

* Add conditional compiling for cuda-depending codes in ROCM

* Update to benchmark scripts

* Rename the one script file

* Revert "Add conditional compiling for cuda-depending codes in ROCM"

This reverts commit 12fb41c.

* Update to scripts

* Change and add readme for tests and benchmarks

* Remove the code for supporting old ck

* Remove old composable_kernel from submodule list

* Remove folder third_party/composable_kernel

* Rename the folder

* Remove unused script file

* apply black

* pacify mypy

* fix clang-format

* reapply black

* fix lints

* make test_splitk_reference run on cpu

* add ck modules to docs

* try fixing nvidia build by re-including sparse24 cpp folder into extension sources

* update cutlass to upstream commit

* update flash-attention to upstream commit

* simplify setup.py

* remove duplicate run_batched_infer_causalmask_attnbias_dispatched<f16, true, true, 128>

* add hip version and pytorch hip arch list to xformers build info

* fix build

* patch around the unhappy path in get_hip_version

* skip test_grad_checkpointing for triton_splitk since it doesn't have bwop

* re-enable test_mqa_forward since ck tiled is the current implementation

* make skip test_wrong_alignment more generic

* reapply black

* simplify test_decoder

* put python version check inside triton_splitk op

* fix logic

* cleanup python3.9 checks in tests

* cleanup test_attentions

* cleanup test_checkpoint as test running on cpu does not depend on gpu platform

* fix lints

* try fixing win build by conditional import of triton in triton op

* re-enable test_triton_layernorm as it passes

* re-enable test_triton_blocksparse as it passes

* cleanup test_sparse_tensors

* cleanup test_custom_ops

* reapply black

* cleanup test_core_attention

* benchmark ck ops on rocm only

* fix mypy

* fix lint: black

* fix lints: mypy

* Rename HDim/headdim to MaxK/maxk

* Move some headers files to ck examples for later reusing

* Replace using qs_ks_vs pipeline by qr_ks_vs pipeline while HeadDim is 256 for better performance

* rm test_ck_7

* fix lints

* unskip test_unsupported_alignment

* move out test_splitk_reference

* add license header to file created in prev commit

* roll back fmha/common.py

... so users are forced to provide rank-5 inputs for mqa/gqa

* fix lint

* remove unused ref_attention_mqa

* resolve error in triton_splitk on rocm

> Triton does not support if expressions (ternary operators) with dynamic conditions, use if statements instead

* disable partial attention tests on rocm

---------

Co-authored-by: Max Podkorytov <[email protected]>
Co-authored-by: Grigory Sizov <[email protected]>
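
Several entries in the commit message concern split-k ("split attention") decoding: the key/value sequence is partitioned into splits, each split produces a local row max, a local sum of exponentials and an unnormalised output, and a final reduction rescales and merges them, with empty splits handled explicitly. The following is a minimal pure-Python sketch of that idea for a single query vector. All names here (`partial_attention`, `reduce_splits`, `split_k_attention`) are illustrative assumptions and are not taken from the CK kernels or the libtorch reference in this commit; the usual `1/sqrt(d)` scaling is omitted for brevity.

```python
import math


def partial_attention(q, keys, values, dim_v):
    """Attention over one split: returns (max, sumexp, unnormalised output)."""
    if not keys:
        # An empty split contributes nothing to the final result.
        return -math.inf, 0.0, [0.0] * dim_v
    scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    out = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim_v)]
    return m, sum(weights), out


def reduce_splits(partials, dim_v):
    """Merge per-split results by rescaling each one to the global max."""
    global_max = max(m for m, _, _ in partials)
    total = 0.0
    out = [0.0] * dim_v
    for m, sumexp, partial_out in partials:
        if sumexp == 0.0:
            continue  # skip empty splits
        scale = math.exp(m - global_max)
        total += scale * sumexp
        for d in range(dim_v):
            out[d] += scale * partial_out[d]
    return [x / total for x in out]


def split_k_attention(q, keys, values, num_splits):
    """Assumes at least one key; splits the sequence into equal-sized chunks."""
    dim_v = len(values[0])
    chunk = -(-len(keys) // num_splits)  # ceiling division
    partials = [
        partial_attention(
            q, keys[i * chunk:(i + 1) * chunk], values[i * chunk:(i + 1) * chunk], dim_v
        )
        for i in range(num_splits)
    ]
    return reduce_splits(partials, dim_v)
```

With `num_splits=1` this reduces to ordinary softmax attention, so comparing the multi-split result against the single-split result is a simple numerical check, similar in spirit to the HIP-versus-libtorch tests mentioned above.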
3 people authored Mar 4, 2024
1 parent fe0526b commit 44b0d07
Showing 196 changed files with 9,652 additions and 224 deletions.
71 changes: 71 additions & 0 deletions .github/workflows/rocm_ci.yml
@@ -0,0 +1,71 @@
+name: ROCM_CI
+
+on:
+  pull_request:
+    types: [labeled, synchronize, reopened]
+
+jobs:
+  build:
+    if: contains(github.event.label.name, 'rocm')
+    runs-on: rocm
+
+    steps:
+      - uses: actions/checkout@v2
+      - name: Get CPU info on Ubuntu
+        if: contains(runner.os, 'linux')
+        run: |
+          cat /proc/cpuinfo
+      - name: Get env vars
+        run: |
+          echo GITHUB_WORKFLOW = $GITHUB_WORKFLOW
+          echo HOME = $HOME
+          echo PWD = $PWD
+          echo GITHUB_ACTION = $GITHUB_ACTION
+          echo GITHUB_ACTIONS = $GITHUB_ACTIONS
+          echo GITHUB_REPOSITORY = $GITHUB_REPOSITORY
+          echo GITHUB_EVENT_NAME = $GITHUB_EVENT_NAME
+          echo GITHUB_EVENT_PATH = $GITHUB_EVENT_PATH
+          echo GITHUB_WORKSPACE = $GITHUB_WORKSPACE
+          echo GITHUB_SHA = $GITHUB_SHA
+          echo GITHUB_REF = $GITHUB_REF
+          export GIT_BRANCH=${GITHUB_BASE_REF:-${GITHUB_REF#refs/heads/}}
+          echo GIT_BRANCH = $GIT_BRANCH
+          export ROCM_PATH=/opt/rocm
+          echo ROCM_PATH = $ROCM_PATH
+          export MAX_JOBS=64
+          echo MAX_JOBS = $MAX_JOBS
+          hipcc --version
+          rocm-smi
+          rocminfo | grep "gfx"
+      - name: Build XFormers
+        run: |
+          git clone --recursive -b $GIT_BRANCH $GITHUB_REPOSITORY
+          docker run -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 8G -v $PWD/xformers:/xformers rocm/pytorch-nightly:latest
+          pip3 install --upgrade pip
+          pip3 uninstall -y xformers
+          MAX_JOBS=$MAX_JOBS pip3 install -e /xformers --verbose
+          pip3 install scipy==1.10
+          python3 -c "import torch; print(torch.__version__)"
+          python3 -m xformers.info
+      - name: Run python tests
+        run: |
+          pytest -rpfs /xformers/tests/test_mem_eff_attention.py | tee test_mem_eff_attention.log
+      - name: Archive logs
+        uses: actions/upload-artifact@v3
+        with:
+          name: test results
+          path: test_mem_eff_attention_ck.log
+
+      - name: Process test results
+        run: |
+          echo "Processing test results TBD"
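
The `GIT_BRANCH` line in the workflow relies on nested shell parameter expansion: `${GITHUB_BASE_REF:-...}` falls back when the base ref is unset or empty, and `${GITHUB_REF#refs/heads/}` strips a leading `refs/heads/` if present. A small Python sketch of the same logic may make the behaviour easier to follow; the function name `resolve_branch` and the sample values in the usage note are assumptions for illustration only.

```python
def resolve_branch(env):
    """Mirror ${GITHUB_BASE_REF:-${GITHUB_REF#refs/heads/}} for a dict of env vars."""
    base_ref = env.get("GITHUB_BASE_REF", "")
    if base_ref:
        # ':-' uses the fallback only when the variable is unset or empty.
        return base_ref
    ref = env.get("GITHUB_REF", "")
    prefix = "refs/heads/"
    # '#' removes the prefix only if the value actually starts with it.
    return ref[len(prefix):] if ref.startswith(prefix) else ref
```

For a pull request event with a base ref of `main`, the function returns `main`; for a push to `refs/heads/develop` with no base ref, it returns `develop`.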
10 changes: 10 additions & 0 deletions .gitignore
@@ -60,3 +60,13 @@ outputs
 xformers/_flash_attn
 xformers/version.py
 xformers/cpp_lib.json
+
+## temporary files
+xformers/csrc/attention/hip_fmha/*.cu
+xformers/csrc/attention/hip_fmha/*.hip
+xformers/csrc/attention/hip_fmha/*_hip.h
+xformers/csrc/attention/hip_fmha/instances/*.cu
+xformers/csrc/attention/hip_fmha/instances/*.hip
+xformers/csrc/attention/hip_fmha/instances_tiled/*.cu
+xformers/csrc/attention/hip_fmha/instances_tiled/*.hip
+
4 changes: 4 additions & 0 deletions .gitmodules
@@ -4,3 +4,7 @@
 [submodule "third_party/flash-attention"]
 	path = third_party/flash-attention
 	url = https://github.com/Dao-AILab/flash-attention.git
+[submodule "third_party/composable_kernel_tiled"]
+	path = third_party/composable_kernel_tiled
+	url = https://github.com/ROCm/composable_kernel.git
+	branch = ck_tile/dev
14 changes: 13 additions & 1 deletion docs/source/components/ops.rst
@@ -22,13 +22,25 @@ Available implementations
     :member-order: bysource

 .. automodule:: xformers.ops.fmha.triton
-    :members: FwOp, BwOp
+    :members: FwOp
     :member-order: bysource

 .. automodule:: xformers.ops.fmha.small_k
     :members: FwOp, BwOp
     :member-order: bysource

+.. automodule:: xformers.ops.fmha.ck
+    :members: FwOp, BwOp
+    :member-order: bysource
+
+.. automodule:: xformers.ops.fmha.ck_decoder
+    :members: FwOp
+    :member-order: bysource
+
+.. automodule:: xformers.ops.fmha.ck_splitk
+    :members: FwOp
+    :member-order: bysource
+
 Attention biases
 ~~~~~~~~~~~~~~~~~~~~

2 changes: 1 addition & 1 deletion requirements-test.txt
@@ -25,7 +25,7 @@ hydra-core >= 1.1

 # Dependency for Mixture of Experts
 fairscale >= 0.4.5
-scipy
+scipy >= 1.7

 # Dependency for fused layers, optional
 cmake
82 changes: 82 additions & 0 deletions setup.py
@@ -125,6 +125,23 @@ def get_cuda_version(cuda_dir) -> int:
     return bare_metal_major * 100 + bare_metal_minor


+def get_hip_version(rocm_dir) -> str:
+    hipcc_bin = "hipcc" if rocm_dir is None else os.path.join(rocm_dir, "bin", "hipcc")
+    try:
+        raw_output = subprocess.check_output(
+            [hipcc_bin, "--version"], universal_newlines=True
+        )
+    except Exception as e:
+        print(
+            f"hip installation not found: {e} ROCM_PATH={os.environ.get('ROCM_PATH')}"
+        )
+        return None
+    for line in raw_output.split("\n"):
+        if "HIP version" in line:
+            return line.split()[-1]
+    return None
+
+
 def get_flash_attention_extensions(cuda_version: int, extra_compile_args):
     # XXX: Not supported on windows for cuda<12
     # https://github.com/Dao-AILab/flash-attention/issues/345
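
The `get_hip_version` helper added in the hunk above scans the `hipcc --version` output for a line containing `HIP version` and returns its last whitespace-separated token, or `None` if no such line exists. The sketch below isolates just that parsing step so it can be exercised without a ROCm installation. The name `parse_hip_version` is an assumption, as is the use of `Optional[str]` (the committed function is annotated `-> str` even though it can return `None`); the version string in the usage note is a made-up sample, not output captured from a real machine.

```python
from typing import Optional


def parse_hip_version(hipcc_output: str) -> Optional[str]:
    """Return the last token of the first line mentioning 'HIP version'."""
    for line in hipcc_output.split("\n"):
        if "HIP version" in line:
            return line.split()[-1]
    # No matching line: the caller has to cope with a missing version.
    return None
```

For example, `parse_hip_version("HIP version: 6.0.0-abc123\nclang version 17")` returns `"6.0.0-abc123"`, while text without a `HIP version` line yields `None`.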
@@ -223,11 +240,27 @@ def get_flash_attention_extensions(cuda_version: int, extra_compile_args):
     ]


+def rename_cpp_cu(cpp_files):
+    for entry in cpp_files:
+        shutil.copy(entry, os.path.splitext(entry)[0] + ".cu")
+
+
 def get_extensions():
     extensions_dir = os.path.join("xformers", "csrc")

     sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"), recursive=True)
     source_cuda = glob.glob(os.path.join(extensions_dir, "**", "*.cu"), recursive=True)
+    source_hip = glob.glob(
+        os.path.join(extensions_dir, "attention", "hip_fmha", "**", "*.cpp"),
+        recursive=True,
+    )
+    source_hip_generated = glob.glob(
+        os.path.join(extensions_dir, "attention", "hip_fmha", "**", "*.cu"),
+        recursive=True,
+    )
+    # avoid the temporary .cu files generated under xformers/csrc/attention/hip_fmha
+    source_cuda = list(set(source_cuda) - set(source_hip_generated))
+    sources = list(set(sources) - set(source_hip))

     sputnik_dir = os.path.join(this_dir, "third_party", "sputnik")
     cutlass_dir = os.path.join(this_dir, "third_party", "cutlass", "include")
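
The hunk above keeps the HIP sources out of the generic source lists by set difference: `.cpp` files under `hip_fmha` are removed from `sources`, and the `.cu` copies that `rename_cpp_cu` generates there are removed from `source_cuda`. A minimal sketch of that partitioning on plain path strings follows. The function name `split_sources` and the paths in the usage note are hypothetical, and the results are sorted here only to make the output deterministic (the committed code uses `list(set(...))`, which does not guarantee an order).

```python
def split_sources(all_cpp, all_cu, hip_dir):
    """Separate HIP-only sources from the generic C++ and CUDA source lists."""
    prefix = hip_dir.rstrip("/") + "/"
    hip_cpp = [p for p in all_cpp if p.startswith(prefix)]
    hip_generated_cu = [p for p in all_cu if p.startswith(prefix)]
    # Generic lists exclude everything that lives under the HIP directory.
    sources = sorted(set(all_cpp) - set(hip_cpp))
    source_cuda = sorted(set(all_cu) - set(hip_generated_cu))
    return sources, source_cuda, sorted(hip_cpp)
```

Given `["csrc/a.cpp", "csrc/attention/hip_fmha/f.cpp"]` and `["csrc/k.cu", "csrc/attention/hip_fmha/f.cu"]` with `hip_dir="csrc/attention/hip_fmha"`, only `csrc/a.cpp` and `csrc/k.cu` remain in the generic lists.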
@@ -253,6 +286,7 @@ def get_extensions():
     include_dirs = [extensions_dir]
     ext_modules = []
     cuda_version = None
+    hip_version = None
     flash_version = "0.0.0"

     if (
@@ -294,6 +328,7 @@ def get_extensions():
         flash_extensions = get_flash_attention_extensions(
             cuda_version=cuda_version, extra_compile_args=extra_compile_args
         )
+
         if flash_extensions:
             flash_version = get_flash_version()
             ext_modules += flash_extensions
@@ -306,6 +341,51 @@ def get_extensions():
             "--ptxas-options=-O2",
             "--ptxas-options=-allow-expensive-optimizations=true",
         ]
+    elif torch.cuda.is_available() and torch.version.hip:
+        rename_cpp_cu(source_hip)
+        rocm_home = os.getenv("ROCM_PATH")
+        hip_version = get_hip_version(rocm_home)
+
+        source_hip_cu = []
+        for ff in source_hip:
+            source_hip_cu += [ff.replace(".cpp", ".cu")]
+
+        extension = CUDAExtension
+        sources += source_hip_cu
+        include_dirs += [
+            Path(this_dir) / "xformers" / "csrc" / "attention" / "hip_fmha"
+        ]
+
+        include_dirs += [
+            Path(this_dir)
+            / "third_party"
+            / "composable_kernel_tiled"
+            / "example"
+            / "91_tile_program"
+            / "xformers_fmha"
+        ]
+
+        include_dirs += [
+            Path(this_dir) / "third_party" / "composable_kernel_tiled" / "include"
+        ]
+
+        generator_flag = []
+
+        cc_flag = ["-DBUILD_PYTHON_PACKAGE"]
+        extra_compile_args = {
+            "cxx": ["-O3", "-std=c++17"] + generator_flag,
+            "nvcc": [
+                "-O3",
+                "-std=c++17",
+                f"--offload-arch={os.getenv('HIP_ARCHITECTURES', 'native')}",
+                "-U__CUDA_NO_HALF_OPERATORS__",
+                "-U__CUDA_NO_HALF_CONVERSIONS__",
+                "-DCK_FMHA_FWD_FAST_EXP2=1",
+                "-fgpu-flush-denormals-to-zero",
+            ]
+            + generator_flag
+            + cc_flag,
+        }

     ext_modules.append(
         extension(
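
In the ROCm branch above, the GPU target is taken from the `HIP_ARCHITECTURES` environment variable and defaults to `native` when the variable is not set. The sketch below rebuilds the same flag lists as a standalone function so the effect of the variable can be checked without running `setup.py`. The function name `rocm_compile_args` and its `environ` parameter are assumptions for illustration; the flags themselves are copied from the hunk, with the empty `generator_flag` list omitted.

```python
import os


def rocm_compile_args(environ=None):
    """Build the compile-argument dict used for the HIP extension (sketch)."""
    environ = os.environ if environ is None else environ
    # Falls back to 'native' when HIP_ARCHITECTURES is not set.
    arch = environ.get("HIP_ARCHITECTURES", "native")
    cc_flag = ["-DBUILD_PYTHON_PACKAGE"]
    return {
        "cxx": ["-O3", "-std=c++17"],
        "nvcc": [
            "-O3",
            "-std=c++17",
            f"--offload-arch={arch}",
            "-U__CUDA_NO_HALF_OPERATORS__",
            "-U__CUDA_NO_HALF_CONVERSIONS__",
            "-DCK_FMHA_FWD_FAST_EXP2=1",
            "-fgpu-flush-denormals-to-zero",
        ]
        + cc_flag,
    }
```

Calling `rocm_compile_args({"HIP_ARCHITECTURES": "gfx90a"})` yields `--offload-arch=gfx90a` in the `nvcc` list, while an empty mapping yields `--offload-arch=native`.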
@@ -320,6 +400,7 @@ def get_extensions():
     return ext_modules, {
         "version": {
             "cuda": cuda_version,
+            "hip": hip_version,
             "torch": torch.__version__,
             "python": platform.python_version(),
             "flash": flash_version,
@@ -328,6 +409,7 @@ def get_extensions():
             k: os.environ.get(k)
             for k in [
                 "TORCH_CUDA_ARCH_LIST",
+                "PYTORCH_ROCM_ARCH",
                 "XFORMERS_BUILD_TYPE",
                 "XFORMERS_ENABLE_DEBUG_ASSERTIONS",
                 "NVCC_FLAGS",
13 changes: 13 additions & 0 deletions tests/readme_test_on_rocm.txt
@@ -0,0 +1,13 @@

1. #> pip install -e ./

2. verify testing for generic fmha inference on ROCM

#> pytest tests/test_mem_eff_attention.py::test_forward

3. verify testing for decoder fmha inference on ROCM

#> pytest tests/test_mem_eff_attention.py::test_decoder
#> pytest tests/test_mem_eff_attention.py::test_splitk_decoder


7 changes: 7 additions & 0 deletions tests/test_attentions.py
@@ -107,6 +107,13 @@ def test_order_invariance(
     causal: bool,
     device: torch.device,
 ):
+    if (
+        torch.version.hip
+        and device == torch.device("cuda")
+        and attention_name == "local"
+    ):
+        # Backend calls into Sputnik library which isn't built on ROCm
+        device = torch.device("cpu")

     torch.manual_seed(42)
     torch.cuda.manual_seed_all(42)
15 changes: 14 additions & 1 deletion tests/test_checkpoint.py
@@ -111,7 +111,11 @@ def test_checkpoint_with_grad(policy_fn, input_requires_grad, grad_mode):
     "op",
     [
         xformers.ops.MemoryEfficientAttentionFlashAttentionOp,
-        xformers.ops.MemoryEfficientAttentionCutlassOp,
+        (
+            xformers.ops.MemoryEfficientAttentionCutlassOp
+            if torch.version.cuda
+            else xformers.ops.MemoryEfficientAttentionCkOp
+        ),
     ],
 )
 def test_checkpoint_attention(policy_fn, input_requires_grad, device, autocast, op):
@@ -121,6 +125,15 @@ def test_checkpoint_attention(policy_fn, input_requires_grad, device, autocast,
     ):
         pytest.skip("skipping operator not supported in this arch")

+    if (
+        op is xformers.ops.MemoryEfficientAttentionFlashAttentionOp
+        and torch.version.hip
+    ):
+        pytest.skip("FlashAttentionOp is not supported on ROCM!")
+
+    if op is xformers.ops.MemoryEfficientAttentionCkOp:
+        pytest.skip("Gradience is currently not supported by ck-tiled!")
+
     class Attn(nn.Module):
         def forward(self, x):
             out = xformers.ops.memory_efficient_attention(x, x, x, op=op)
5 changes: 4 additions & 1 deletion tests/test_core_attention.py
@@ -31,7 +31,9 @@ def fn_and_catch_oor(*args, **kwargs):
     return fn_and_catch_oor


-_devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
+_devices = (
+    ["cpu", "cuda"] if torch.cuda.is_available() and torch.version.cuda else ["cpu"]
+)


 def test_core_attention():
@@ -144,6 +146,7 @@ def test_amp_attention_sparsecs(device):
 @pytest.mark.skipif(
     not _is_blocksparse_available, reason="Blocksparse is not available"
 )
+@pytest.mark.skipif(not torch.version.cuda, reason="Sparse ops not supported on ROCm")
 @pytest.mark.parametrize("device", ["cuda"])
 @pytest.mark.parametrize("data_type", [torch.float16, torch.float32])
 @catch_oor
9 changes: 7 additions & 2 deletions tests/test_custom_ops.py
@@ -16,8 +16,13 @@
     _sparse_bmm,
 )

-cuda_only = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
-_devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
+cuda_only = pytest.mark.skipif(
+    not torch.cuda.is_available() or not torch.version.cuda, reason="requires CUDA"
+)
+
+_devices = (
+    ["cpu", "cuda"] if torch.cuda.is_available() and torch.version.cuda else ["cpu"]
+)


 def _baseline_matmul_with_sparse_mask(
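
The test changes above rely on the fact that `torch.cuda.is_available()` also returns `True` on ROCm builds of PyTorch, so the extra `torch.version.cuda` check (which is `None` on ROCm builds) is what keeps CUDA-only tests from running on AMD GPUs. The sketch below expresses that selection as a pure function so it can be read and tested without PyTorch installed; the name `select_devices` and its parameters are assumptions standing in for `torch.cuda.is_available()` and `torch.version.cuda`.

```python
from typing import List, Optional


def select_devices(gpu_available: bool, cuda_version: Optional[str]) -> List[str]:
    """Return the device list for tests that need a real CUDA (not ROCm) build."""
    # cuda_version is None on ROCm builds, even when a GPU is available.
    if gpu_available and cuda_version:
        return ["cpu", "cuda"]
    return ["cpu"]
```

In the tests themselves this would correspond to `select_devices(torch.cuda.is_available(), torch.version.cuda)`.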