From fda407ff0ac858c4ffa4f14a8a811aa07a0ab652 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Tue, 27 Feb 2024 17:52:48 +0800 Subject: [PATCH] [Initialization] Migration of Codebase from Azure DevOps. (#1) * base tuner * gpu schedule * matmul ops * initial commit * refactor fast dlight to bit blas * support i8 swizzle * int8xint2 gemm * update keep * update lop3 cpp test * all low int to float16 convert * int8 fast decoding * float16with scale * annotate tc layout propa * impl tir interleve test * impl interleave weight. * weight only propagation * support layout propagate recover schedule of dequantize. * refactor testing * enhance gemv schedule for dynamic * dequantize matmul initilization * [refactor] move comments to BitBLAS * evaluate pytorch integeration * evaluate correctness of weight only decode * annotate mit license * annotate apache/mit lisence * init logger * refactor ops test with pytest * ladder_permutate implementation * append tvm third party lisence * scaling ladder permutate impl * add storage dtype test * implement lop3 permutation ops and related test * support with propagate layout. * update tvm lisence * disable fmt in pytest * implement cpu arch for consistency * seperate gemv schedule and gemv_dequantize schedule. * fix typo * refactor quantization * init testing. * refactor matmul and operators * append dequantize and test items * reslove lisence related items * refactor implementation * init read me. * integration with faster transform imp * integerate bug fix. * update ignore * improve code structure. * update mit lisence * remove gitkeep file * provide simple tir benchmark result. * enhance build * auto layout deduce * fix default tensorize. * update ReadMe * update readme * update read me * update readme * simple fix * readme fix --- .gitignore | 66 ++ 3rdparty/.gitkeep | 0 README.md | 307 +++++- SECURITY.md | 6 +- THIRDPARTYNOTICES.txt | 208 ++++ VERSION | 1 + benchmark/dsl/convolution.py | 168 ++++ benchmark/dsl/matmul.py | 252 +++++ images/auto_tensorize.png | Bin 0 -> 32396 bytes .../kenrel_output/ladder_kernel.cu | 407 ++++++++ .../kenrel_output/ladder_kernel.h | 206 ++++ .../fastertransformer/kernel_generator.py | 224 +++++ .../kernel_template.int2.bitblas.cu.template | 16 + .../kernel_template.int4.bitblas.cu.template | 16 + integration/pytorch/quant_linear.py | 157 +++ integration/pytorch/test_quant_linear.py | 115 +++ maint/scripts/apply_mit_license.sh | 48 + maint/scripts/check_mit_license.sh | 25 + maint/scripts/installation.sh | 24 + maint/scripts/mit_liscense1.txt | 2 + maint/scripts/mit_liscense2.txt | 2 + python/bitblas/__init__.py | 55 ++ python/bitblas/base/__init__.py | 18 + python/bitblas/base/analysis.py | 342 +++++++ python/bitblas/base/common_schedules.py | 163 ++++ python/bitblas/base/roller/__init__.py | 6 + python/bitblas/base/roller/arch/__init__.py | 14 + python/bitblas/base/roller/arch/arch_base.py | 41 + python/bitblas/base/roller/arch/cpu.py | 19 + python/bitblas/base/roller/arch/cuda.py | 60 ++ python/bitblas/base/roller/bestfit.py | 66 ++ python/bitblas/base/roller/config.py | 237 +++++ python/bitblas/base/roller/node.py | 372 +++++++ python/bitblas/base/roller/policy/__init__.py | 5 + python/bitblas/base/roller/policy/common.py | 56 ++ python/bitblas/base/roller/policy/default.py | 770 +++++++++++++++ .../bitblas/base/roller/policy/tensorcore.py | 349 +++++++ python/bitblas/base/roller/rasterization.py | 85 ++ .../base/roller/shape_inference/__init__.py | 4 + .../base/roller/shape_inference/common.py | 66 ++ .../base/roller/shape_inference/tir.py | 399 ++++++++ python/bitblas/base/schedule_rule.py | 149 +++ python/bitblas/base/transform.py | 214 +++++ python/bitblas/base/utils.py | 489 ++++++++++ python/bitblas/generator.py | 17 + python/bitblas/gpu/__init__.py | 21 + python/bitblas/gpu/base.py | 44 + python/bitblas/gpu/element_wise.py | 97 ++ python/bitblas/gpu/fallback.py | 95 ++ python/bitblas/gpu/gemv.py | 860 +++++++++++++++++ python/bitblas/gpu/gemv_dequantize.py | 148 +++ python/bitblas/gpu/general_reduction.py | 465 +++++++++ python/bitblas/gpu/intrin/lop3.py | 708 ++++++++++++++ python/bitblas/gpu/matmul.py | 372 +++++++ python/bitblas/gpu/matmul_analysis.py | 763 +++++++++++++++ python/bitblas/gpu/matmul_mma.py | 684 +++++++++++++ python/bitblas/gpu/matmul_mma_dequantize.py | 637 ++++++++++++ python/bitblas/gpu/matmul_wmma.py | 909 ++++++++++++++++++ python/bitblas/gpu/reduction.py | 301 ++++++ python/bitblas/gpu/rmsnorm.py | 144 +++ python/bitblas/gpu/transpose.py | 133 +++ python/bitblas/gpu/utils.py | 86 ++ python/bitblas/ops/__init__.py | 6 + python/bitblas/ops/gemv_impl.py | 73 ++ python/bitblas/ops/impl/__init__.py | 3 + .../bitblas/ops/impl/ladder_permutate_impl.py | 89 ++ .../bitblas/ops/impl/lop3_permutate_impl.py | 171 ++++ .../ops/impl/matmul_dequantize_impl.py | 291 ++++++ python/bitblas/ops/impl/matmul_impl.py | 617 ++++++++++++ python/bitblas/ops/ladder_permutate.py | 96 ++ python/bitblas/ops/lop3_permutate.py | 75 ++ python/bitblas/ops/matmul.py | 204 ++++ python/bitblas/ops/matmul_dequantize.py | 251 +++++ python/bitblas/ops/operator.py | 214 +++++ python/bitblas/quantization/__init__.py | 8 + python/bitblas/quantization/quantization.py | 148 +++ python/bitblas/quantization/utils.py | 64 ++ python/bitblas/relax/op/interleave_weight.py | 23 + python/bitblas/relax/transform/__init__.py | 5 + .../relax/transform/annotate_decode_block.py | 131 +++ .../relax/transform/weight_only_propagate.py | 463 +++++++++ python/bitblas/testing/__init__.py | 10 + python/bitblas/utils/__init__.py | 10 + python/bitblas/utils/tensor_adapter.py | 16 + python/bitblas_cli.py | 2 + testing/cpp/.gitignore | 2 + testing/cpp/CMakeLists.txt | 15 + .../cpp/lop3_type_conversion/CMakeLists.txt | 12 + .../lop3_type_conversion/fast_decoding.hpp | 441 +++++++++ .../lowprecision_to_float16.cu | 680 +++++++++++++ .../lowprecision_to_int8.cu | 345 +++++++ .../dsl/test_auto_normalized_tensorcore.py | 157 +++ .../python/operators/test_int8xint8_gemm.py | 36 + .../operators/test_ladder_permutate_ops.py | 51 + .../operators/test_lop3_permutate_ops.py | 38 + .../operators/test_matmul_dequantize_ops.py | 276 ++++++ testing/python/operators/test_matmul_ops.py | 211 ++++ .../test_weight_dequantize_matmul_codegen.py | 73 ++ testing/python/test_fused_decode_matmul.py | 132 +++ testing/python/test_lop3_type_conversion.py | 78 ++ testing/python/test_matmul_codegen.py | 221 +++++ testing/python/test_weight_only_transform.py | 353 +++++++ .../python/tir_expr/float16xfloat16_gemm.py | 82 ++ testing/python/tir_expr/int8xint8_gemm.py | 368 +++++++ testing/python/tir_expr/test_tir.py | 74 ++ testing/python/tir_expr/test_tir_0.py | 189 ++++ testing/python/tir_expr/test_tir_1.py | 179 ++++ testing/python/tir_expr/test_tir_2.py | 95 ++ testing/python/tir_expr/test_tir_3.py | 87 ++ .../type_conversion/int4b_fp16_convert.py | 229 +++++ .../test_numpy_compress_convert.py | 2 + .../correctness/test_fp16xint4_correctness.py | 37 + .../python/weight_only/index_map_deduce.py | 24 + testing/python/weight_only/index_map_fuse.py | 81 ++ .../python/weight_only/inverse_index_map.py | 119 +++ 115 files changed, 20345 insertions(+), 25 deletions(-) create mode 100644 .gitignore create mode 100644 3rdparty/.gitkeep create mode 100644 THIRDPARTYNOTICES.txt create mode 100644 VERSION create mode 100644 benchmark/dsl/convolution.py create mode 100644 benchmark/dsl/matmul.py create mode 100644 images/auto_tensorize.png create mode 100644 integration/fastertransformer/kenrel_output/ladder_kernel.cu create mode 100644 integration/fastertransformer/kenrel_output/ladder_kernel.h create mode 100644 integration/fastertransformer/kernel_generator.py create mode 100644 integration/fastertransformer/template/kernel_template.int2.bitblas.cu.template create mode 100644 integration/fastertransformer/template/kernel_template.int4.bitblas.cu.template create mode 100644 integration/pytorch/quant_linear.py create mode 100644 integration/pytorch/test_quant_linear.py create mode 100755 maint/scripts/apply_mit_license.sh create mode 100755 maint/scripts/check_mit_license.sh create mode 100755 maint/scripts/installation.sh create mode 100644 maint/scripts/mit_liscense1.txt create mode 100644 maint/scripts/mit_liscense2.txt create mode 100644 python/bitblas/__init__.py create mode 100644 python/bitblas/base/__init__.py create mode 100644 python/bitblas/base/analysis.py create mode 100644 python/bitblas/base/common_schedules.py create mode 100644 python/bitblas/base/roller/__init__.py create mode 100644 python/bitblas/base/roller/arch/__init__.py create mode 100644 python/bitblas/base/roller/arch/arch_base.py create mode 100644 python/bitblas/base/roller/arch/cpu.py create mode 100644 python/bitblas/base/roller/arch/cuda.py create mode 100644 python/bitblas/base/roller/bestfit.py create mode 100644 python/bitblas/base/roller/config.py create mode 100644 python/bitblas/base/roller/node.py create mode 100644 python/bitblas/base/roller/policy/__init__.py create mode 100644 python/bitblas/base/roller/policy/common.py create mode 100644 python/bitblas/base/roller/policy/default.py create mode 100644 python/bitblas/base/roller/policy/tensorcore.py create mode 100644 python/bitblas/base/roller/rasterization.py create mode 100644 python/bitblas/base/roller/shape_inference/__init__.py create mode 100644 python/bitblas/base/roller/shape_inference/common.py create mode 100644 python/bitblas/base/roller/shape_inference/tir.py create mode 100644 python/bitblas/base/schedule_rule.py create mode 100644 python/bitblas/base/transform.py create mode 100644 python/bitblas/base/utils.py create mode 100644 python/bitblas/generator.py create mode 100644 python/bitblas/gpu/__init__.py create mode 100644 python/bitblas/gpu/base.py create mode 100644 python/bitblas/gpu/element_wise.py create mode 100644 python/bitblas/gpu/fallback.py create mode 100644 python/bitblas/gpu/gemv.py create mode 100644 python/bitblas/gpu/gemv_dequantize.py create mode 100644 python/bitblas/gpu/general_reduction.py create mode 100644 python/bitblas/gpu/intrin/lop3.py create mode 100644 python/bitblas/gpu/matmul.py create mode 100644 python/bitblas/gpu/matmul_analysis.py create mode 100644 python/bitblas/gpu/matmul_mma.py create mode 100644 python/bitblas/gpu/matmul_mma_dequantize.py create mode 100644 python/bitblas/gpu/matmul_wmma.py create mode 100644 python/bitblas/gpu/reduction.py create mode 100644 python/bitblas/gpu/rmsnorm.py create mode 100644 python/bitblas/gpu/transpose.py create mode 100644 python/bitblas/gpu/utils.py create mode 100644 python/bitblas/ops/__init__.py create mode 100644 python/bitblas/ops/gemv_impl.py create mode 100644 python/bitblas/ops/impl/__init__.py create mode 100644 python/bitblas/ops/impl/ladder_permutate_impl.py create mode 100644 python/bitblas/ops/impl/lop3_permutate_impl.py create mode 100644 python/bitblas/ops/impl/matmul_dequantize_impl.py create mode 100644 python/bitblas/ops/impl/matmul_impl.py create mode 100644 python/bitblas/ops/ladder_permutate.py create mode 100644 python/bitblas/ops/lop3_permutate.py create mode 100644 python/bitblas/ops/matmul.py create mode 100644 python/bitblas/ops/matmul_dequantize.py create mode 100644 python/bitblas/ops/operator.py create mode 100644 python/bitblas/quantization/__init__.py create mode 100644 python/bitblas/quantization/quantization.py create mode 100644 python/bitblas/quantization/utils.py create mode 100644 python/bitblas/relax/op/interleave_weight.py create mode 100644 python/bitblas/relax/transform/__init__.py create mode 100644 python/bitblas/relax/transform/annotate_decode_block.py create mode 100644 python/bitblas/relax/transform/weight_only_propagate.py create mode 100644 python/bitblas/testing/__init__.py create mode 100644 python/bitblas/utils/__init__.py create mode 100644 python/bitblas/utils/tensor_adapter.py create mode 100644 python/bitblas_cli.py create mode 100644 testing/cpp/.gitignore create mode 100644 testing/cpp/CMakeLists.txt create mode 100644 testing/cpp/lop3_type_conversion/CMakeLists.txt create mode 100644 testing/cpp/lop3_type_conversion/fast_decoding.hpp create mode 100644 testing/cpp/lop3_type_conversion/lowprecision_to_float16.cu create mode 100644 testing/cpp/lop3_type_conversion/lowprecision_to_int8.cu create mode 100644 testing/python/dsl/test_auto_normalized_tensorcore.py create mode 100644 testing/python/operators/test_int8xint8_gemm.py create mode 100644 testing/python/operators/test_ladder_permutate_ops.py create mode 100644 testing/python/operators/test_lop3_permutate_ops.py create mode 100644 testing/python/operators/test_matmul_dequantize_ops.py create mode 100644 testing/python/operators/test_matmul_ops.py create mode 100644 testing/python/operators/test_weight_dequantize_matmul_codegen.py create mode 100644 testing/python/test_fused_decode_matmul.py create mode 100644 testing/python/test_lop3_type_conversion.py create mode 100644 testing/python/test_matmul_codegen.py create mode 100644 testing/python/test_weight_only_transform.py create mode 100644 testing/python/tir_expr/float16xfloat16_gemm.py create mode 100644 testing/python/tir_expr/int8xint8_gemm.py create mode 100644 testing/python/tir_expr/test_tir.py create mode 100644 testing/python/tir_expr/test_tir_0.py create mode 100644 testing/python/tir_expr/test_tir_1.py create mode 100644 testing/python/tir_expr/test_tir_2.py create mode 100644 testing/python/tir_expr/test_tir_3.py create mode 100644 testing/python/type_conversion/int4b_fp16_convert.py create mode 100644 testing/python/type_conversion/test_numpy_compress_convert.py create mode 100644 testing/python/weight_only/correctness/test_fp16xint4_correctness.py create mode 100644 testing/python/weight_only/index_map_deduce.py create mode 100644 testing/python/weight_only/index_map_fuse.py create mode 100644 testing/python/weight_only/inverse_index_map.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000000..11b6ef5c8e58 --- /dev/null +++ b/.gitignore @@ -0,0 +1,66 @@ +# Compiled Object files +*.slo +*.lo +*.o +*.obj +*.pyc + +# Precompiled Headers +*.gch +*.pch + +# emacs +*~ + +# vim +*.swp +*.swo + +debug/ +build/ +dist/ +__pycache__ +nnfusion.tar.gz + +# makeenv and test intermediate files +tmp/ + +venv/ +.vscode/ +.vs/ + +# VisualGDB files +VisualGDB/ +toolchain.cmake + +# docbuild artifacts +doc/sphinx/build/* +doc/doxygen/*.xml +doc/doxygen/*.html +doc/doxygen/man/* +doc/doxygen/latex/* +doc/doxygen/xml/* +doc/doxygen/html/* + +# git merge +*.orig +\#* +\.#* + +# idea +.idea/* + +# python egg +*.egg-info + +# Macos +**/.DS_Store + +nnfusion_rt/ +models/frozenmodels/ + +# log +*.log + +# pkl +*.pkl_* diff --git a/3rdparty/.gitkeep b/3rdparty/.gitkeep new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/README.md b/README.md index 5cd7cecfc87a..227bb4fb0066 100644 --- a/README.md +++ b/README.md @@ -1,33 +1,296 @@ -# Project +# BitBLAS -> This repo has been populated by an initial template to help get you started. Please -> make sure to update the content to build a great experience for community-building. +BitBLAS is a light weight framework to generate high performance CUDA/HIP code for BLAS operators with swizzling and layout propagation. BitBLAS can achieve comparable performance with cuBLAS and provide more flexibility with DSL (TIR Script). -As the maintainer of this project, please make a few updates: +## Feature -- Improving this README.MD file to provide a great experience -- Updating SUPPORT.MD with content about this project's support experience -- Understanding the security reporting process in SECURITY.MD -- Remove this section from the README +- Auto Tensorization. +- High Performance (FP16xFP16, FP16xINT4/2/1, INT8xINT8, INT8xINT4/2/1). +- Dynamic symbolic support, generate kernel with dynamic shape. + +## Requirements + +To manually install BitBLAS, please checkout `maint/scripts/installation.sh`. + +Also Make sure you already have the cuda toolkit (version >= 11) installed in the system. + +Finally, add ./python and tvm/python to PYTHONPATH. + +## Quick Start +We provide two primary ways to do the code generation: using a high-level DSL (TensorIR Script), or using packed Operators. + +You can find some example dsl implementation in `python/bitblas/ops/impl` and `benchmark/dsl`, see more examples and tutorials in [apache/tvm](https://github.com/apache/tvm) + +### Using BitBLAS from DSL +```python +from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy +from bitblas.base.roller.arch import CUDA +from bitblas.base.utils import apply_and_build +@tvm.script.ir_module +class MatmulNT: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = T.match_buffer(a, [M, K], dtype=in_dtype) + B = T.match_buffer(b, [N, K], dtype=in_dtype) + C = T.match_buffer(c, [M, N], dtype=out_dtype) + + for i, j, k in T.grid(M, N, K): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = tvm.tir.const(0, out_dtype) + C[vi, vj] = C[vi, vj] + A[vi, vk].astype(out_dtype) * B[ + vj, vk + ].astype(out_dtype) + +ir_module = MatmulNT +func = ir_module["main"] +target = tvm.target.Target("nvidia/nvidia-a100") +arch = CUDA(target) +``` + +Get tuning policy and candidates: + +```python +# Tune with SIMT Cuda Core +policy = DefaultPolicy(func=func, arch=arch) +try: + tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) +except: + tags = None +# Tune with Tensor Core if possible +if tags: + policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) + +configs = policy.emit_config(topk=20) +''' +[BitBLAS] Evaluation with config {'block': [64, 64], 'warp': [32, 32], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} +[BitBLAS] Time cost of this config: 0.032 ms +[BitBLAS] Evaluation with config {'block': [32, 128], 'warp': [16, 64], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} +[BitBLAS] Time cost of this config: 0.021 ms +[BitBLAS] Evaluation with config {'block': [128, 32], 'warp': [64, 16], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} +[BitBLAS] Time cost of this config: 0.023 ms +[BitBLAS] Evaluation with config {'block': [32, 32], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} +[BitBLAS] Time cost of this config: 0.023 ms +[BitBLAS] Evaluation with config {'block': [32, 64], 'warp': [16, 32], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} +[BitBLAS] Time cost of this config: 0.027 ms +[BitBLAS] Evaluation with config {'block': [64, 32], 'warp': [32, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} +[BitBLAS] Time cost of this config: 0.025 ms +[BitBLAS] Evaluation with config {'block': [64, 128], 'warp': [32, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} +[BitBLAS] Time cost of this config: 0.023 ms +[BitBLAS] Evaluation with config {'block': [128, 64], 'warp': [64, 32], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} +[BitBLAS] Time cost of this config: 0.025 ms +[BitBLAS] Evaluation with config {'block': [16, 64], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} +[BitBLAS] Time cost of this config: 0.037 ms +[BitBLAS] Evaluation with config {'block': [64, 16], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} +[BitBLAS] Time cost of this config: 0.037 ms +[BitBLAS] Evaluation with config {'block': [128, 128], 'warp': [64, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} +[BitBLAS] Time cost of this config: 0.026 ms +[BitBLAS] Evaluation with config {'block': [16, 128], 'warp': [16, 32], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} +[BitBLAS] Time cost of this config: 0.043 ms +[BitBLAS] Evaluation with config {'block': [128, 16], 'warp': [32, 16], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} +[BitBLAS] Time cost of this config: 0.042 ms +[BitBLAS] Evaluation with config {'block': [32, 256], 'warp': [16, 128], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} +[BitBLAS] Time cost of this config: 0.025 ms +[BitBLAS] Evaluation with config {'block': [256, 32], 'warp': [128, 16], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} +[BitBLAS] Time cost of this config: 0.029 ms +[BitBLAS] Evaluation with config {'block': [64, 256], 'warp': [32, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} +[BitBLAS] Time cost of this config: 0.028 ms +[BitBLAS] Evaluation with config {'block': [256, 64], 'warp': [128, 32], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} +[BitBLAS] Time cost of this config: 0.027 ms +[BitBLAS] Evaluation with config {'block': [128, 256], 'warp': [64, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} +[BitBLAS] Time cost of this config: 0.044 ms +[BitBLAS] Evaluation with config {'block': [256, 128], 'warp': [128, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} +[BitBLAS] Time cost of this config: 0.040 ms +[BitBLAS] Evaluation with config {'block': [16, 256], 'warp': [16, 64], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}} +[BitBLAS] Time cost of this config: 0.047 ms +''' +``` + +Apply and build and get best code generation result: +```python +cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) +# get the best code generation result. +print(best.code) +''' +extern "C" __global__ void __launch_bounds__(128) default_function_kernel(half* __restrict__ A, half* __restrict__ B, half* __restrict__ C) { + ... +} +''' +``` + +we also provide something interesting with DSL. + +#### Auto Tensorization + +Say we currently have two policies, one is for SIMT Cuda Core, another is for TensorCore. The decision to utilize a TensorCore policy over a SIMT Cuda Core policy can be enhanced by the integration of an auto-tensorization strategy, it allows BitBLAS to automatically select if the DSL Expression can uitlize TensorCore. + +![Auto Tensorization](./images/auto_tensorize.png) + +```python +# Assume func is conv2d, after this api, the tensorized_func is the tensorized version of the conv2d, otherwise, the tags is None. +tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) +``` + +#### Tune with dynamic symbolic + +As in LLM Serving, the input shape is dynamic, we can use the dynamic symbolic to generate high performance kernel with dynamic shape. + +```python +@tvm.script.ir_module +class MatmulNT: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + m = T.int32() + A = T.match_buffer(a, [m, K], dtype=in_dtype) + B = T.match_buffer(b, [N, K], dtype=in_dtype) + C = T.match_buffer(c, [m, N], dtype=out_dtype) + + for i, j, k in T.grid(m, N, K): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = tvm.tir.const(0, out_dtype) + C[vi, vj] = C[vi, vj] + A[vi, vk].astype(out_dtype) * B[ + vj, vk + ].astype(out_dtype) + +from bitblas import fast_tune_with_dynamic_range +# Tune with dynamic symbolic +optimized_mod = fast_tune_with_dynamic_range( + func, target, topk=topk, parallel_build=True, + dynamic_range={ + "M": [1, 1024] + } +) + +# fianlly, we will generate a dispatch func to dispatch the kernel with dynamic symbolic. +''' +@IRModule +class MatmulNT: + + def matmul_nt_opt_m_1(A: Tensor, T_reshape: Tensor, m: int): + ... + + def matmul_nt_opt_m_256(A: Tensor, T_reshape: Tensor, m: int): + ... + + def dispatcher(args): + if m <= 1: + matmul_nt_opt_m_1(A.data, T_reshape.data, m) + if m > 1 and m <= 256: + matmul_nt_opt_m_256(A.data, T_reshape.data, m) + if m > 256: + matmul_nt_m_256(A.data, T_reshape.data, m) +''' + +``` + + + +### Using BitBLAS from packed Operators + +We packed some operators in `bitblas/ops/impl` with configs, you can use them directly. Please see more examples in `testing/python/operators` + +```python +from bitblas.ops.matmul import Matmul, MatmulConfig +matmul_config = MatmulConfig( + M=M, + N=N, + K=K, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + with_bias=with_bias, + propagate_a=propagate_a, + propagate_b=propagate_b, + layout=layout, +) +matmul = Matmul( + config=matmul_config, + target=target, +) +``` + +By default, we will apply a default schedule into the operator, you can also get code generation result by calling matmul.codegen(). + +```python +print(matmul.codegen()) +''' +extern "C" __global__ void __launch_bounds__(128) default_function_kernel(half* __restrict__ A, half* __restrict__ B, half* __restrict__ C) { + ... +} +''' +``` + +If you want to tune the operator to get better performance, you can use the api `hardware_aware_finetune`. + +```python +print(matmul.profile_latency()) +matmul.hardware_aware_finetune(topk=20) +print(matmul.profile_latency()) +``` + +The latency will be reduced after tuning. We re-implement OSDI'22 paper Roller to do fast tuning with hardware information. Typically, the 20 candidates is good enough. +#### Tune with Dynamic Symbolic + +```python +matmul_config = MatmulConfig( + M=[1, 1024], + N=N, + K=K, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + with_bias=with_bias, + propagate_a=propagate_a, + propagate_b=propagate_b, + layout=layout, +) +``` +#### Tune with FPA INTB Operators + +Generate High Performance Kernel for WeightOnly Quantization. + +```python +from bitblas.ops.matmul_dequantize import ( + MatmulWeightOnlyDequantize, + MatmulWeightOnlyDequantizeConfig, +) +matmul_config = MatmulWeightOnlyDequantizeConfig( + M=M, + N=N, + K=K, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + bit=bit, + storage_dtype=storage_dtype, + source_format=source_format, + with_scaling=with_scaling, + group_size=group_size, + fast_decoding=fast_decoding, + with_bias=with_bias, + propagate_a=propagate_a, + propagate_b=propagate_b, + layout=layout, +) +matmul = MatmulWeightOnlyDequantize( + config=matmul_config, + target=target, +) +``` ## Contributing -This project welcomes contributions and suggestions. Most contributions require you to agree to a -Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us -the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. +This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. -When you submit a pull request, a CLA bot will automatically determine whether you need to provide -a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions -provided by the bot. You will only need to do this once across all repos using our CLA. +When you submit a pull request, a CLA bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA. -This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). -For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or -contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. +This project has adopted the Microsoft Open Source Code of Conduct. For more information see the Code of Conduct FAQ or contact opencode@microsoft.com with any additional questions or comments. ## Trademarks -This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft -trademarks or logos is subject to and must follow -[Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). -Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. -Any use of third-party trademarks or logos are subject to those third-party's policies. +This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow Microsoft's Trademark & Brand Guidelines. Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos are subject to those third-party's policies. diff --git a/SECURITY.md b/SECURITY.md index b3c89efc852e..7b9e6e8bffa5 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -1,10 +1,10 @@ - + ## Security -Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin). +Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). -If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below. +If you believe you have found a security vulnerability in any Microsoft-owned repository that meets Microsoft's [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)) of a security vulnerability, please report it to us as described below. ## Reporting Security Issues diff --git a/THIRDPARTYNOTICES.txt b/THIRDPARTYNOTICES.txt new file mode 100644 index 000000000000..f377e67bba20 --- /dev/null +++ b/THIRDPARTYNOTICES.txt @@ -0,0 +1,208 @@ +BitBLAS uses third-party material as listed below. The attached notices are +provided for informational purposes only. + +Notice for apache/tvm +------------------------------- + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +------------------------------------------------------------------------------------ diff --git a/VERSION b/VERSION new file mode 100644 index 000000000000..b9f8bf2855b2 --- /dev/null +++ b/VERSION @@ -0,0 +1 @@ +0.0.0.dev \ No newline at end of file diff --git a/benchmark/dsl/convolution.py b/benchmark/dsl/convolution.py new file mode 100644 index 000000000000..592544c3b302 --- /dev/null +++ b/benchmark/dsl/convolution.py @@ -0,0 +1,168 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import numpy as np +import tvm +from tvm.script import tir as T +from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy +from bitblas.base.roller.arch import CUDA +from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags +from bitblas.gpu import Matmul, matmul_mma +from bitblas.base.utils import apply_and_build +import time +from tvm import te, tir + + +def conv2d_nhwc_hwio( + n, f, h, w, c, kh, kw, s, d, p, in_dtype="float16", out_dtype="float16" +): + A = te.placeholder((n, h, w, c), name="input", dtype=in_dtype) + B = te.placeholder((kh, kw, c, f), name="weight", dtype=in_dtype) + + pad_shape = (n, h + 2 * p, w + 2 * p, c) + pad_value = tir.const(0.0, A.dtype) + pad = te.compute( + pad_shape, + lambda n, h, w, c: te.if_then_else( + tir.all( + h >= p, + w >= p, + h < pad_shape[1] - p, + w < pad_shape[2] - p, + ), + A[n, h - p, w - p, c], + pad_value, + ), + name="pad", + ) + kernel_h, kernel_w = kh, kw + stride_h, stride_w = s, s + dilation_h, dilation_w = d, d + out_h = (h + 2 * p - (dilation_h * (kernel_h - 1) + 1)) // stride_h + 1 + out_w = (w + 2 * p - (dilation_w * (kernel_w - 1) + 1)) // stride_w + 1 + out_shape = (n, out_h, out_w, f) + kh = te.reduce_axis((0, kernel_h), name="kh") + kw = te.reduce_axis((0, kernel_w), name="kw") + c = te.reduce_axis((0, c), name="c") + C = te.compute( + out_shape, + lambda n, h, w, f: te.sum( + pad[ + n, + h * stride_h + kh * dilation_h, + w * stride_w + kw * dilation_w, + c, + ] + * B[kh, kw, c, f], + axis=[kh, kw, c], + ), + name="C", + ) + return tvm.ir.IRModule({"main": te.create_prim_func([A, B, C])}) + + +# fmt:off +benchmark_sets = [ + # (prim_func, input_args, BitBLAS_default_schedule), + (conv2d_nhwc_hwio, (128, 64, 224, 224, 64, 1, 1, 2, 1, 3, "float16", "float16"), Matmul), + # (conv2d_nhwc_hwio, (128, 64, 224, 224, 3, 7, 7, 2, 1, 3, "float32", "float32"), Matmul), + # (conv2d_nhwc_hwio, (128, 64, 224, 224, 3, 7, 7, 2, 1, 3, "float16", "float16"), Matmul), +] +# fmt:on +benchmark_results = {} +for get_prim_func, input_args, d_schedule in benchmark_sets: + ir_module = get_prim_func(*input_args) + func = ir_module["main"] + target = tvm.target.Target("nvidia/nvidia-a100") + arch = CUDA(target) + policy = DefaultPolicy(func=func, arch=arch) + try: + tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) + except: + tags = None + if tags: + policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) + configs = policy.emit_config(20) + + tune_start = time.time() + cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) + fast_tune_time = time.time() - tune_start + print( + "[BitBLAS] The best latency of top 1 is {:.3f} ms".format( + cpresults[0].latency * 1e3 + ) + ) + print( + "[BitBLAS] The best latency of top 20 is {:.3f} ms".format(best.latency * 1e3) + ) + + # evaluate the performance of the default schedule + + rule = d_schedule() + default_tune_start = time.time() + sch_default = rule.apply(func, target, False) + with tvm.transform.PassContext(config={"tir.use_async_copy": True}): + mod_default = tvm.build(sch_default.mod["main"], target="cuda") + default_tune_time = time.time() - default_tune_start + + args = func.buffer_map.values() + + profile_tensors = [] + for arg in args: + profile_tensors.append( + tvm.nd.array( + np.random.uniform(0, 1, [int(i) for i in arg.shape]).astype(arg.dtype), + device=arch.device, + ) + ) + + timer_cuda_mod = mod_default.time_evaluator( + mod_default.entry_name, arch.device, number=5 + ) + t = timer_cuda_mod(*profile_tensors).mean + + print("Time cost of Dlight default schedule: {:.3f} ms".format(t * 1e3)) + + profile_config = { + f"{get_prim_func.__name__}-{'-'.join([str(i) for i in input_args])}": { + "BitBLAS_top20_tune_time": fast_tune_time, + "BitBLAS_top1_latency": cpresults[0].latency * 1e3, + "BitBLAS_top20_latency": best.latency * 1e3, + "BitBLAS_default_tune_time": default_tune_time, + "BitBLAS_default_latency": t * 1e3, + } + } + benchmark_results.update(profile_config) + +headers = [ + "PrimFunc", + "Input Arguments", + "BitBLAS Top20 Tune Time", + "BitBLAS Top1 Latency", + "BitBLAS Top20 Latency", + "BitBLAS Default Tune Time", + "BitBLAS Default Latency", +] + +col_width = ( + max(len(word) for row in [headers] + list(profile_config.values()) for word in row) + + 2 +) # padding + +print("".join(word.ljust(col_width) for word in headers)) + +print("-" * col_width * len(headers)) + +for config, values in benchmark_results.items(): + args = config.split("-") + func_name = args[0] + input_args = "-".join(args[1:]) + row = [ + func_name, + input_args, + f" {str(values['BitBLAS_top20_tune_time'])} s", + f"{values['BitBLAS_top1_latency']:.3f} ms", + f"{values['BitBLAS_top20_latency']:.3f} ms", + str(values["BitBLAS_default_tune_time"]), + f"{values['BitBLAS_default_latency']:.3f} ms", + ] + print("".join(word.ljust(col_width) for word in row)) diff --git a/benchmark/dsl/matmul.py b/benchmark/dsl/matmul.py new file mode 100644 index 000000000000..f48ee8412789 --- /dev/null +++ b/benchmark/dsl/matmul.py @@ -0,0 +1,252 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import numpy as np +import tvm +from tvm.script import tir as T +from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy +from bitblas.base.roller.arch import CUDA +from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags +from bitblas.gpu import Matmul +from bitblas.base.utils import apply_and_build +import time + + +def matmul_nt(M, N, K, in_dtype="float16", out_dtype="float16"): + @tvm.script.ir_module + class MatmulNT: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = T.match_buffer(a, [M, K], dtype=in_dtype) + B = T.match_buffer(b, [N, K], dtype=in_dtype) + C = T.match_buffer(c, [M, N], dtype=out_dtype) + + for i, j, k in T.grid(M, N, K): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = tvm.tir.const(0, out_dtype) + C[vi, vj] = C[vi, vj] + A[vi, vk].astype(out_dtype) * B[ + vj, vk + ].astype(out_dtype) + + return MatmulNT + + +def matmul_nn(M, N, K, in_dtype="float16", out_dtype="float16"): + @tvm.script.ir_module + class MatmulNN: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = T.match_buffer(a, [M, K], dtype=in_dtype) + B = T.match_buffer(b, [K, N], dtype=in_dtype) + C = T.match_buffer(c, [M, N], dtype=out_dtype) + + for i, j, k in T.grid(M, N, K): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = tvm.tir.const(0, out_dtype) + C[vi, vj] = C[vi, vj] + A[vi, vk].astype(out_dtype) * B[ + vk, vj + ].astype(out_dtype) + + return MatmulNN + + +def matmul_nt_propagate_b_f16_f16_mma(M, N, K, in_dtype="float16", out_dtype="float16"): + wm, wn, wk = 16, 16, 16 + if in_dtype == "int8": + wm, wn, wk = 16, 16, 32 + + @tvm.script.ir_module + class MyModule: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle): + T.func_attr( + {"global_symbol": "main", "tir.noalias": True, "smooth_b": True} + ) + A = T.match_buffer(a, [M, K], dtype=in_dtype) + B = T.match_buffer(b, [N // wn, K // wk, wn, wk], dtype=in_dtype) + C = T.match_buffer(c, [M, N], dtype=out_dtype) + B_reindex = T.alloc_buffer([N, K], dtype=in_dtype) + + for j, k in T.grid(N, K): + with T.block("B_reindex"): + vj, vk = T.axis.remap("SS", [j, k]) + B_reindex[vj, vk] = B[ + vj // wn, + vk // wk, + vj % wn // 8 * 8 + vj % 4 * 2 + vk % wn // 8, + vj % 8 // 4 * 8 + vk % 8, + ] + + for i, j, k in T.grid(M, N, K): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = tvm.tir.const(0, out_dtype) + C[vi, vj] = C[vi, vj] + A[vi, vk].astype(out_dtype) * B_reindex[ + vj, vk + ].astype(out_dtype) + + return MyModule + + +def matmul_nt_propagate_a_b(M, N, K, in_dtype="float16", out_dtype="float16"): + wm, wn, wk = 16, 16, 16 + if in_dtype == "int8": + wm, wn, wk = 16, 16, 32 + + @tvm.script.ir_module + class MyModule: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle): + T.func_attr( + { + "global_symbol": "main", + "tir.noalias": True, + "smooth_a": True, + "smooth_b": True, + } + ) + A = T.match_buffer(a, [M // wm, K // wk, wm, wk], dtype=in_dtype) + B = T.match_buffer(b, [N // wn, K // wk, wn, wk], dtype=in_dtype) + C = T.match_buffer(c, [M, N], dtype=out_dtype) + A_reindex = T.alloc_buffer([M, K], dtype=in_dtype) + B_reindex = T.alloc_buffer([N, K], dtype=in_dtype) + + for i, k in T.grid(M, K): + with T.block("A_reindex"): + vj, vk = T.axis.remap("SS", [i, k]) + A_reindex[vj, vk] = A[vj // wm, vk // wk, vj % wm, vk % wk] + + for j, k in T.grid(N, K): + with T.block("B_reindex"): + vj, vk = T.axis.remap("SS", [j, k]) + B_reindex[vj, vk] = B[vj // wn, vk // wk, vj % wn, vk % wk] + + for i, j, k in T.grid(M, N, K): + with T.block("C"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = tvm.tir.const(0, out_dtype) + C[vi, vj] = C[vi, vj] + A_reindex[vi, vk].astype( + out_dtype + ) * B_reindex[vj, vk].astype(out_dtype) + + return MyModule + + +# fmt:off +benchmark_sets = [ + # (prim_func, input_args, default_dlight_schedule), + (matmul_nt, (1024, 1024, 1024, "float16", "float16"), Matmul), + (matmul_nt, (16, 8192, 8192, "float16", "float16"), Matmul), + (matmul_nt, (32, 8192, 8192, "float16", "float16"), Matmul), + (matmul_nt, (16384, 16384, 16384, "float16", "float16"), Matmul), + (matmul_nt, (16384, 16384, 16384, "int8", "int32"), Matmul), + (matmul_nn, (1024, 1024, 1024, "float16", "float16"), Matmul), + (matmul_nn, (8192, 8192, 8192, "float16", "float16"), Matmul), + (matmul_nn, (16384, 16384, 16384, "float16", "float16"), Matmul), + (matmul_nt, (1024, 1024, 1024, "float32", "float32"), Matmul), + (matmul_nt_propagate_b_f16_f16_mma, (16384, 16384, 16384), Matmul), + (matmul_nt_propagate_a_b, (16384, 16384, 16384, "int8", "int32"), Matmul), + (matmul_nt_propagate_a_b, (16384, 16384, 16384, "float16", "float16"), Matmul), +] +# fmt:on + +benchmark_results = {} +for get_prim_func, input_args, d_schedule in benchmark_sets: + ir_module = get_prim_func(*input_args) + func = ir_module["main"] + target = tvm.target.Target("nvidia/nvidia-a100") + arch = CUDA(target) + policy = DefaultPolicy(func=func, arch=arch) + try: + tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) + except: + tags = None + if tags: + policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) + + configs = policy.emit_config(20) + + tune_start = time.time() + cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) + fast_tune_time = time.time() - tune_start + print( + "[BitBLAS] The best latency of top 1 is {:.3f} ms".format( + cpresults[0].latency * 1e3 + ) + ) + print( + "[BitBLAS] The best latency of top 20 is {:.3f} ms".format(best.latency * 1e3) + ) + + # evaluate the performance of the default schedule + + rule = d_schedule() + default_tune_start = time.time() + sch_default = rule.apply(func, target, False) + with tvm.transform.PassContext(config={"tir.use_async_copy": True}): + mod_default = tvm.build(sch_default.mod["main"], target="cuda") + default_tune_time = time.time() - default_tune_start + + args = func.buffer_map.values() + + profile_tensors = best.profile_tensors + + timer_cuda_mod = mod_default.time_evaluator( + mod_default.entry_name, arch.device, number=5 + ) + t = timer_cuda_mod(*profile_tensors).mean + + print("Time cost of Dlight default schedule: {:.3f} ms".format(t * 1e3)) + + profile_config = { + f"{get_prim_func.__name__}-{'-'.join([str(i) for i in input_args])}": { + "fast_dlight_top20_tune_time": fast_tune_time, + "fast_dlight_top1_latency": cpresults[0].latency * 1e3, + "fast_dlight_top20_latency": best.latency * 1e3, + "default_dlight_tune_time": default_tune_time, + "default_dlight_latency": t * 1e3, + } + } + + benchmark_results.update(profile_config) + +headers = [ + "PrimFunc", + "Input Arguments", + "BitBLAS Top20 Tune Time", + "BitBLAS Top1 Latency", + "BitBLAS Top20 Latency", + "DefaultDLight Tune Time", + "DefaultDLight Latency", +] + +col_width = ( + max(len(word) for row in [headers] + list(profile_config.values()) for word in row) + + 2 +) # padding + +print("".join(word.ljust(col_width) for word in headers)) + +print("-" * col_width * len(headers)) + +for config, values in benchmark_results.items(): + args = config.split("-") + func_name = args[0] + input_args = "-".join(args[1:]) + row = [ + func_name, + input_args, + f" {str(values['fast_dlight_top20_tune_time'])} s", + f"{values['fast_dlight_top1_latency']:.3f} ms", + f"{values['fast_dlight_top20_latency']:.3f} ms", + str(values["default_dlight_tune_time"]), + f"{values['default_dlight_latency']:.3f} ms", + ] + print("".join(word.ljust(col_width) for word in row)) diff --git a/images/auto_tensorize.png b/images/auto_tensorize.png new file mode 100644 index 0000000000000000000000000000000000000000..f51ed23b09ae953886fd2fd0bb6dd4822d023aa0 GIT binary patch literal 32396 zcmbrm1yq!8*FFje%pfp~q(}`QA&oRBFoY5k(kN+aFfA>G{_(%q;u2>u^_ zzVG|J-&tp^v(7oQmJ6AA=6UYe``XvO_Pz;`mwk)<_}OD5BqVG}2{8pEBs2>oBxGzb z3h>T_eEL4{1KC00tq4--2zkiBdbo2tRyjL8wv=0gEc-^Klz{R%8#j{!;b-9T*Rz| zhH;4};PDfAv?r@LB4d$9UdOjXiKkp9Bc6gBr&>fT(9vYRv`fCx!JUTfp`o4+z*ztD znnH#=%ro6I$!)FZgJ8f@F zRJbkj;@sR-$;u7$9jFwE6!q_QNV~{scZVi_woOWJH>c@@p{PV{vXBlr@0&?#;0jEX zJM~HL)rV`-aoXTWa9hd4P1x_F*1Njt6e#enH2ezmX@QDus$#3-w7=(-p zm&f~OregxDFdx{r=N-V4p;P{dhnt*1c)>6*ONcEL7%bNs?BTtC-I$6L6^aU%`Tu{r ztSW@@eHoJ5c<)hmrUt**+|U&+_;* ze6GRGz@T@g0H%!x5s8_^!o-=>J=KkBheA+zwI|jCfC<9#pWY!J=8L2~HV)^ls#hwQ zZ?w>L2-UO8RUj4(NdzvEjc*nJ<4`QrRGcmcEp?&&I5>;_u>MUL%7g(z2EMsVwnXZ$ zL4yIJ93b`=Cw@%AlZG*oXbWxItF!^Fg}f#q;UT@Lh<$0s(3;iNQU)T%T7&tl{*v@Y z|7#z2nb8Pna{!#>1LIf$Q^EB`23hg6va#=n&8e29n@=70CBI@Aq($obgwbNpau$lI zzNFDq_^H-Q(y8SY9_!N!ys$NA>?PM~vmqxOF63v_{{lAd!e1RYa(%YnmJ~B{2Wq#I zw5XN?=~;Zx2$BnvwTZp_y!#5+9WB8uZ|K7m(lHx&!7Y8y7ec2L&W!?})|2~{NAawl zTO2<|m>27<1603`t(#6*SvO zt=FR-8{7364TGW;KZkidGW+fg)k4b>N)X@$2V_T943L$p(M%ca)mlI;y=@olM+yr& z>cnmDD+9-j+O)A?-#p*^fj!`-Fp;Ajbid2I2HiQam@a>k;Kv(20TT`~)K(98<9ks1 z$70=n|9nVb6cJ&;B#8h={#dBwQFcC7jN*BF$X7vvJ++4^KYm#3ECg&|Q6vzdyJ8n5 zrONZnkPqx^7t2x3=X&XH#3+V^2_VEb&j|{YznIjzyAGRk!2^r?nL7>hN2lk|VxHt5 zz6;aI8@bXucnyva+wt3usuCwfy{!^0Gca!w=iZ4SihJJ7Ul63gK#1zqm!=3oanbF) zA6_L4Ou{gq^oEkARr7i{K}Co^3^=Jyza{>B9J_bz0o{xW`-(%G7ju4yaR_X*)BmhC z9muykJMX!;)k}$9hU-F(F4iT+nJ^s|P1qTeh5$H+oK{{)dA=HvGDI-c@4*SEqi0NQAbzYHUh`>duyXmxt&kY(R78|5l=qR5inhJA z9`qJgj_$WzSd0z8eAh08*u#GB`pH84`eaWOoY5W`Vfw-5gl-+7huasAtUtRa3{7%sI)O7csr z&?O%LVdtm;qxjAbUJJMp@xY_kP6WpxR+o9@6{Z>kk5Ft2P}WrZvJtQsvX_pbk(n{m z*oEpwfdDiQ(t3SggFOuKF0%*`RhOY{CEWMDu-teP#IA%?hvB1(IpLBg1cGzA{YdVJ z@?8Z1q3!j50F7VZAsXQ6w?af*sYPSN%R~6HE&cx#>v)Ml?OrEY-nEay;$L2z?FVgo(5v>Y566sP zGLcDGi0;Nm$!6Z%`);S$W~R=U?CB5S;+T>xZhgf9?%zFJ(!lYYFuw2?@| zspCqp!7%@E)rqupkFph^#PN3pJ^=2KcX71+=Zc%#T5u!9yc2O-UwWsh!@v6;WPX`( zi%{&kt=sJ#81b^X-IH;%|MQyllI-H~3{tz`xl^_ddQvcAk(>X6qIGwpa zqw|5WSLnD{*_*F;Dk+5r??N{?^PCzZlPZs-o-C20J1lIZp%s!rMDRM`P?l~uJ?H+c z8zWrnnDFclC9eHA_Ypf<<;C3Tc7Xe>u*Q7VwC&n-MA2Z|Bx#e`KLF8I9JL0cA0C_C zf@gkt{Can}CS#%*@r;F906hFuojLA_F6JJ`S=G?4TED@dhU|;eQ@xqOQd1VM!dwoF zh{a@?W%H@4uc@k`?xtH$NIJBq3%_I9L7QY&{T1E^PtdQ@V>1^GZOt>d;q9_y-4qW@ zKI_R*rKI2<^Swc&??qhx(%`gEa@yU^udcNBSbhb`<2iA26LmdwoPy8ss~F1v&QSbb zPhq|2a6gu|--vUFr5CCxCHq9#?qq74?>x~$-lPPAs{DkSN~=3c(U2^cdUfY>fg{6^ z3zRA)WPHfX_~Y<`ozpYdFP%JO>N{VUg#=@>rT)!Hmf@y(jD;AW4B}fw6@uhggsuwp zf#C-OO{L6*(69WCDr6sU%VOt4wkthE&iEDkKYw-LGk&pbS8i|+6Pw>N{VkQgKy-2u zk!~26GbBzK^F9?9iz!-KcJh#y%eCo?7lmqMC^uTEuJ zklP1fCd&L23+D-9JWKUqYey&d{z659{#^N3$L258>9Wr~34(mbwuHAnUN_ssm@ijKFUWn@^ zFZ;MWD^F{3o zXod^7Ei@VX$#|h(w=TQru)XK+ANHA2D|JDCh>Cd7{9k$9ETbY!x7^ud1u+m?3~&#^Hd`4wSnF)*8$bz}w;;@Bq;N{F>mmPuo1 zG;;|`$hIeE5IC0ED2CN&^c@of<$vhW|enE&avL9TFzj9@>p>d zG!gm}?1;|Dky4od+!k8=cx#OF?U{R80eQ_EpX9`Y4l+#k58?`TLOwT`xr+CyEDR)#x@MJmx1^ZT|&XfSMs=sI+1W<+-#Z_C98m6 zhBN*cgWV%M7FZG&W}~Y2FGjRr6U_ro%Lz(6>7ZG36&B8oa;H$r9;(UXVpJWF?{FF0N zIBGSda{a5jR$9TMl$pUt>87Vaw0`x?*^w@}fs0fDn4RKWupXb&BF1i{ke2Zw$JOh3n%Mn%u?5Urpe zK2raRJjK1JL$^nD4W@+sD)JF`iUwft;s2hS0H{44TS}V#brfZlS9dD27a4kveNu$^ zIVCBo__8EReze1QRGboW$a&PG=_z=yD{2csO2yLixt~?rLa(as6zdBR>1wY9rhmm! zpNz(&^>vdZfH7*4@h8SdrW~fWH4-7r(!`Y^tHq;*4}11giOgR(a6U7PMl8@rwPCQI zb|AY7kuJS8m>|)ZnO@0l&};vufy_mR6WF*K7xpeD^dnexOBg&ga91d-0)c?4V-Ye= z7o#!|PIobsd`KL(G@_DIAy`z7G4>$-@cPF)$gHD1V6Mto|1#ICW4JEEXM-1YT6+gC z=?Nr5*p;$H`M^n}rAZOpA{q)Q;kARlNxS5zy|lJozHcCYtr}7)V6zC(s8p?LXMGw? z$s1P_dtPIQy)D&l=Nu{iBGn4}6dqbMJwJ$u(g(guU#>UJvLs`0dO;J0#9(?{9I7(c zr%>4tUU2NA3Xf$r$QlBdFDGVK58c~91Tdz#*>YSEJeT+Qml&D9#=KUP1thO*;=O$y z>sKoyPxIA8fZc5JpS&?A{wr3nJ`(`5o0%KxL@+42-Pl2s0~vC1JJWwzQpRvp^TuvU z6JeG0YAAJQq5EgGS9b17JI-=}M?x*+!S`CAb{DEc^J|A%#A-?07k0{Fc<`{XW!zU8 zQwPZp^NFMpb6Q1Kn9S1ZSTjw}R3}UaX_a4zgWyF63gTD#Ss{qk-BS0hL!P6FlCjS} z`~4Q+Fw;TdH-^}hOQp%ohv;VUQ1lW7_4?N~o3LOQkGsw_+6wQ9gk!(FwJ3NwZ*C55 zYY}zhVGCWPjQ>SwHl6c0_4BdcbHZL{vaCd{Zn2ya?}lY|z3 zX&m;#R|oGYT`SZ%w%RbB%$iYt$JK9;dEVtbcAYe!xKTJPFZ4s{;1Q3^Cnlx8#LaQM z11=wvcJ(X=H@NT6CzZDNbIYPUeU4*62MUbWjX?^dzSw6$c_Ein+o`>8%={%08)En< zmdi7fVLA^XjQ_z$a$h^dd1nOeW2XrGRvhCRBC;m~EwEQECjzx6{lsNJ1WzETpk7tF zg(ORsS+lq?EG6!7sunl&%`Z#;EWmBp(FvsC19!4hJv*4o3}#bj%ba`7ZN)S7{)rKq zhUM$3BVJPn>d5X5b4?2Sl*w=Z!@73Aw$Q$N3GLz9>o`-7k3fuN>sAeQ9n3% zbCf-==~j)F)j?;eL*c)@qsIlBB!Omg>d>F9Qouxzbp+#NNk~bY<7~}+w%3O1K0&%D zY{tiG#TT~S!;~|>!5?|As}qHgm$*z~UPt3|P?3$SL7Rk#e03&!t+n{GLPVPf#I1p) zJ?BojWji)tQ9m&Q`mm$K^RHZ^xD8IvtvwqW?w>Eg3&zd~{_=Vq#>p_J^n-JFRLN@) zV=6PTQDdQEx~Y1HDT$;CziP)xR+)Jbo&Xkayl`7{oMPOEeIxvI4}Zwsd|1@}QSW&1 z0YM1{Hvfyu?w|+7F<$I{>41L|xhD@&XFBcwO8tPO4$%Mpt>OJmvi=`$$L@n@D_bG? z8SA9}s>X)rPN)#-4KhR*Xg>a!A@&NqC zK|t?N(_|WluR>fWBno24H1mpMJX06y{*2uzloTpbpwDv}ln{kH`wA*il=G}~s*+`W zB2|!yURG>;bBJ_EN!wAV*d2_T6<+khVIztoSuEzp7_sNzu^inODer~33VY`(bYGeH zGv!k7Fzu*K@p8GX=6L?GY)LSnUv|G*J!1-jc!rLL}Mr|un zsrsYwByRYZ-kB(6gNmyQO7nPrC{fy86{6hX-9(-_6*t*-cg9EF5PIzPCTqNcoknQ~ z_>ub40Wo=PKSa6v$CR1JDdUy}mS;UOQHLq1-0kw5#njR)rkG+SOf&0G5CG)bsI zs;c9E2u%TyKE{-eb4X}4Tnf&znd34I?C2YZ zugz`6B=Sa|6fLu&FU?J+T4VnS2eh7h<+h_ecNOQxBlEOV6)= z$^AmQU-g9$MN>C>6usg7SwGxG)9D>|?|EUC=Y)*F=7K%TocYY2aGV}CDJlVIT7QvR zBCSTP>{taay219zUGQ}=MN;N5hi_&{hq9(v-JsN0X#WZ)WN2W`(4Wetyy+$+!yN9XGfVrBCiXid4vKG&X_&;}tGSfPKMqtex z;IjoH($i~vNuPx)B{OqkNMmLRGmxSR7TK#X5|{SA*MAR!Q!cXyVOu{l2`p(m>y0Is zoQlQlK#dvgw#jfSlGD}J^d{I5er$3DTB@AJmr zt;K)Ol(qoNaGJQHecT8X6)G(q-q{=W&{>9(3y?WMYatxjbfMcXPio<-hUDR$C8D;H?rL?)}L#Ti7S&8 zO&_B{W zpli^Tt>be<+b^^dd=hAWXbV*t2}U|i{RDG=S76!?%XHr<Fh6UghJ zS>5N?Dqqj}=DVZmA&egXdh#am2w{`_x%Sil+XVX!?MN<9Nce794s9_#CWsOi?j;Xu zzkYG>dqkUR0QpxQO@uNDby%t@xg2jtj^vnaLYTi!#lcS(6NRVU@bUm(YD&MVh;oW$ zc>bXpU9MyLZbgj3`g@{oGIfG8HZPUFln~f9W_Df{9v-`7K$sd?1-d9^tX165vDFn# zG-f)B)DtmuMwzL|QaZ?|hd zAzPL!M9tjv$(jZQw!Td#s~ShLA^l_Hb1hH$K#q)*0*x{qV|Bdu%~&^k3{Fsph~^Ch zGo!b{DS=oAA+yPyO@vk|T4gNxfr``735eAzw9jG7A}!e;j#+j2ft=u_ zsbH0PfEf=>^?EWG$p z&&xZ+Znf;iGgIlVH&typjm45hJ?nq=>HB6C^hBsrvJvTugxfINu$r4FHPons5E11s zf(Nsb#j2?flM79tS4kU)P&*pyz&am?*@N62p!-45!3zDwsTP&`Cno4zDK;)`4K3ou z$$sXCcW2S)SfC7Muh*ol?xQy{(!K7lDQ4&PypUgTyFf#wXC4!v71iQayFeG0P(|Km z_P#U)Wy~d%Yk)HBO$qq~OYj9)ec}{aOD(4#_6Z+AX#}Cs%QGVg=Sx5>F~;BpOCvma z2y)p=@Br~7AXf`fSXuj9-rWF@!Sj1xt>qc271QmJ&s?NUIpHC{z5r8gKLy)9CJcw(^{MIOM|#Q!jUJF<80 zuEOHy2N%2s1KHCyGANLl47K}H2^DGPvn)#8=U?XE#mfwDYkbj117+|$eV^++$SR*E zp6g_G`ix`*C7y+KhD&}1Gr8e;E=L(KVsJpmPy86>(<_d4!Qv@~3|!3;MB593X5}M| z#Cj6D7qQ2BR(}p+O@m@Y$;tpDcIM*qB0Bew%!H#2iAw!`< zaP%5B9XCy(o*d%CnGCMYcl^lZ%nN2TnSz5+p9_G)kM*G`W+oQi82okJj~6aL9D7E< z+MA!~2diT(z)ih~5U;Suy$KCKAj;%#KBT^Z#?aoToDsQrnGs-8>&U>Q=Hi{MPk+5a zO;j<_`ka``BZ5sEU0*Q@W)vL=km!;Mq*`N|7^a&$CZ%yE@R+N1dn|9#%Op!xgqD!X>6;XUWZ&j^^` z;;!2kk>^Ge^*D`aJvOCln!W+3y#&60(hg@6g`BxPz337%9l#AONn|k!_`DZIbzEG6 zX9mOscf#N_pJ__9A}0|8fKERvEkdjw zO@mb7yx?eEEgTE%+$V0q!i*sJ=W|cr%BActmWMF>fe)R7cog>SED)56u^$1SR6S6& zLRgYo(=XNhdmt#E>nKg{Z3itXiV*T$^L|1Xn_~3#0Y6Ew8)69FuXon||Z$L!QNQ?+;>8ajnBI!=H z)uw{=<10mmjD;V+AoX9i`#FE|%6oGpQfO|b7Zc@yjM)P~090(sKfLi~>x8Rukw^$E zCyX8HxiW|1AC_vw4;v>nb&5GMXhN%qqfk&l87`!aM?JURr>YUh&G&cGWqa8<0G&fZ z;``hF08|yRKzip))-?HUC_*L{7cuvQ;tAbL#AM0`6N!9wUA1@_b2&_SRm5}ws1T78 z*OT?@ERV9*KRn)B8JeXYcL5vxPEUY)AqhkNHc$LBFragTj4u~qd%aXA@jmT)gC+&* zGfuVn5Mixo%11~@Pyg1Y{w65Ebw5rqO~l{ZwA#Vxn4is->bPyM>JY#5N*DD)Nx)fd zJzMy?BfATHukyc>o$Uk&%36P;H5H%CpZyK4w)&FBZZTfx@8R()Ueb9e4sF1f`WBf_ z+z81{jaYuQr=cXZ6j5l;fZEjl&i!xehl3%g-BL6ihu;kQ#|>sLMZ*^Vv-9!h*|O<2 zlW^5kcj1}+XqTn4J-S)PeOx+K3V`n4$@2(Njc{r?ijT^P& zHeM&%gb`OVX@81uRxl?kr@%`YI-He;!zB>Gql;Rax>DWm;3WENP#(%uhGK7{?{0eR;ueFLlZyiGnU!n=gjioAEq*9+Ub{%VV(3 zvi7+^uc&?;i&VCFnS$;_cFn2O>?&mcI z=I>uvOqFBB4{RdOH%K8-XfDx-+Gl!Lv34_#8<3r8pVw?y3?m;K&s=n2;s*7Nn1N7b zy@6ic@P|hCZcALZy`Q&^`wO}SEvas#xq>>RhU2n7Ao#rCB0c)iN-|U$DYE8)+2qE< zuWWIs(u>bHI?h?`^ZAn%zDiy56qbilrB-PryqY=l49$olVRV&jgO?+LkPlNnsD3c% z8S$WijBxMZdE^%k)A^kdRGsL}`~=!IKFz3nF~ISJDf|J7pJ0ndZ1{sSY`qk`6DVus zV`+BPID6_fCK=Y&OJPaTc=U#N!Tp)0^!a(*5$xUdJRr7e-I!~IC^)GtbLpb$tmCtf zg}xze_`ZOVd8P%Tq~fY+iVPmo*wWJe8mka%+GkRbv4D9(Uhi>a z{I2s>r7IzpC@1cBw?42=;kVLibl>_sOg6O^9b_{t8HO|*@F{HOIfiUX!$;@`8$CI` zSQmEa`Oiui2{f0_2pD@0DJmLs+@}0<6}v)t_2QOKzcDux86l!>+eh&78d62ozS(B4 zdJ@A6b>*-9uy?bQtJ~f|KWn9C`nKPs$W2oZ6^E<4nI~n?$IHhdam!VA=IFsJntxXv|xt6(dw!s}uw1^&YbH%(;njvm~ znr4ToR}IR5<0vls2&)l+d;?s9g-gWAP&MDiJ{xt$^Oy1=FnUxnHwOP9uKWI#{6lp% zh@d8xIN|&p5C$7w5%T@|Tq~vGm}{7M>RBjUiQCyE4?T4kERo*@`b!vSMcc#L{Tj6% zJZH55XwxF~uYUxRQ~?wLNQh- z@j2)7puzAjv~>i&1fHG6BM$y|`8^4$nogsd_N(&wQt!657-TqD@k5(ZOKR3Qd6*T1 z0G&v%?xKJ|OLsKRDYR~Q#df#XD3p17$?tJM;Tb`0K`+BPr;_L>e6=%N`-7m-OQqNLVgnT!t9T@*@1u(Rqm9UP%11k2R@zmWF_{B0^Gq_&0-~r)fNkWM7Dhi-Ks8 zJrq0_2y+F($eUi-Q~Pyb0Xz;Pw5e1HPj??X1pTh~ZvET#mMo=VIpNCah>#7^@r8Q? znk%FT7*FL{QXs$$HL41da#U6HCF9b$281-e_9nY00Xp6bDa?k?AusbzLoKPM7mrCd~GZy^HY^FWA{wsIbN$9;+zr1J8(O0v|j@ z!m4x(4r6AxOcQcBIpR0dhV#F@n82fR{k+m7q|Z*Y|Dx|@CY{Nx#;`9;wU+2KKP-nI zjJNXF-eMI1dvkJuI!fU{Q0ltHhd{5PW#1RAOZda5w<}aZz8w! zx1v6GM_YLvaCPSUdeZxMS05WuUHs2jPL~y?et1H+$`im2_XeKN;?r})03BxEV=d%k z&N1Bwga+ZOQu!h~F)N*FQpptyS{FO8A3r8Tr`Xgx>H_Pe_i&@FIr*=?Shqf&9FGnsr${H+s(lvvLt?UCwa?%uMq{%%i1Nww zLWM(M?5K?=S%mQZsW};{rD!>PJZfhRFVDeBi~RE2CsmsX=!l@JFqP-8A0?Pi$DwKy zs$W{fUM@-7Lvve^eSy$xqzLkr-(`{e$SA3j55Ip$-<{!A`wrtdDP7?unuhk5z<27$ zfy!gA57;!m4!4l{i&0PweHpG8De_aXaWh}s%P+2Fz#n!`ZT~E?Htqyn@E4hecFo*S z3FYwb2J`M72x;@P?3jl8t8D*{)nFdT%Mf*A9R2k zSNt;(8b%svscAm&&3Pz_oCgu%BVTiQG%EttRDzsd5teUYrUB+C?_vzhZ#h>=0hsdd zzjFEG#_m_gH4sp|8LKrcK9wb16wF%}{uG#bnHu@@aZqx5^#m<`|J;5R(D6ykR=P-g z@KA1?9bLMgs!$PE<>0Q)h4_gdhs@IRJP~aY0g~!3N!s-Y;e>9Q8g}@Cur+DAw`AwR z)VSSw*5EJ4z@5m&f@tF~3)Ql|m((q1S6=)C!E2NHCSxeuOT6H)dO33zZ)QN`FxF*p zG+2b24yAf$Q=n8L&E%SclluFz7p%p=U*%s>wtx4Df*iCar>h+(RK5JzuRby zNjmg=x>xFen+{9zgO_{mYiQPWdEe0k%+4XBc_}@;=&9H17mY}7R1&Ott{m=ZEX)@A zM`d=VpbxHOEw>Dp2={t+N}?HL5T%jghY)V#_J=+Y*_$@Llr` z(FW`2Ha(r_v`H1nQ1pE+c(<8kDU^7-!9lbQ8`?d2PU=5H+6ISKmnk|G+%16yDc5!9 zyy`RO*>vVOW`I7{1$Q{&o%ldzW1~jE-K!X-q}RobW;2 zz49DYyzVWfjI`k&WO|3#*-4Ah%DiCz!U{RVk9Spt@3l1>ba@oI@$8_VbOL&+;3CJ1 zGfTL{U)m|GTeIY_u*_AJE+;`3yV-8*K;jUB`*An!z%+bM3HIn$PiVc3=r`x{P){#_ z{SH}_5&lz|=-S0G<^@~Gehxp|b0;_LX$5LW_C!nc$wbftozz;#5+KbhQLVxIDf;&yHX zgR;m?ye_Xih(;%&&#hf$O%%26b}Pkfj=w`U8>^5&dO9RE|DA<+Oj69WYcAx`-pS``F0ziCF#iG`J zOG||F*9KiY$c>`6Smw1lR+Te!s5ZUcQW*9(L;G~E3$!Y-B)%@%#W>*w^Cck4p|Znp z7}LhJ6=FR$j~vPeZzx^)ek*jZt4W&t58tmsDbU)lN8Q<6)2xWj5g@d+`F!R*1M!_@j(AH_5ba?AI9JZ190Dmy!ZI?n=YNjIBC|@xP|Tf8ep#aP8laA2RkB*5SStcWmZ}y6#eV)?G%Tdfx?;j^3JQ^2I<&0>y zfyAjWdK&@cS04dRf*X|aL8&4&VaedPWiVoFLnVcseWg+2zmvb@T!b*&=xG+k7(WJp zX>z40K1g7}X1UBDlOk^-~*&s(;!ntdU`Myk!T;C)%%A zBcro@FK=ZZ!~Ew1*&5};F4^CL^jb)u;Qv$%wfW8%Z{W6^vh(nCis94d!+PAg$VaaC zeJP2BolPY6Is-q39q6*Z_gXhTA)n-pI@vg^Me7$OHQs>@xj#YyN?E{B4PeLEfJZck~_@k+*X$H z9p8IIoehy(f4%Z+}rYR|)HEv9ZQdvEIfgk8hph2`i3KC2mYgcENe;oazKq zA8oh%d+ph85I`RF)IiJ+pu0rOlMw%+D?|)GF*X!0?@LqmUoR zTe75k7cI*_C=fKY4)g2q$12Oj@K6R$w(?hQ zF#x0#kltSQi_~ujVYRl%z1Y=XjW+!R2Kuhi9n+C4gV97g_OBC-{{T*D#G8#?@l*xn zaU)Y9@zTXh0gxW@e|8&a!d-)%iN=-fKS5*~5ZfhM+4tb8uj6j2KT>B3g~o74?5>%H zb(DBDRiv@yf45i?I zH^(q9)n$J>OelPO=F6@ym3so9$mN-oEx-!w(@;J;#x(Ezxu$keSO>(HSwTj<$LXnq zcT8PwuPgQ3qTbpFK6s8AAL@MFOKOlFaq5fzg2THgf>mq*Nq)kHq-pIgG5Tu#g#1t* zB`+pk%$fp?MJG{uvMgrY`XF(BGg-}=8(hv?$xy8^Zldae8U@30{+HGt1qps)zGDC2 z`0lKh+%JX(-$cCs`=g!TEkEDyz$9=CkUxzlJ`o&RD5GRIt42zCNAZ+0zy_#MJlF?< ziuQ5n)f`!);psIHnWP5}SV`k86$9@gLI8{0Q_>axb~C=^_S(xq2hyRBrT7;C$0>$v zp*IgYUIp^(7mdDgo0s#R$kmC0W3=5igb=@L<=4IO z4RXla7w3R%_h(*uB|dx?z_~0rxu=}xfWyA$|8Oo0(lzYsWZJ33hLid~(6h+A@a}Ip zfBpYo7V+LEPK~B=G5jmqcvtQK2lRVag5#%JHOI2daR}}daaf_rXxz;#ry=7-FqSh* z30^eB&x`vd+C6<{;rC-XxHj`nq1#!Q@DH_|N@PjQ*?x&UO_YzKAO z*gls`K5$U%evr-t6>U}kzLbYj6bQgl*umyH?3IIUnP4g#efckJ?KU#a{~lgvuxCv- z8{-788nCNTT?ype6BVW6(TPtum&E<-)yW?32P#Yz78!+_l^*&80qsu_nkNvywy%@V zv(8N_hEblH-pmBzf2-#OSA7?F4fp(Z7g3-3=ZehRWu1ucVI0Q40SY;Jl8xHa3+Rky z=EimfoJR z0rDkcl6bsr8TyJJ;^)$bO%UjZl?|ZnX~?G2J@{cL(oT)0QbI$mbW?+XruR(*PoK zuz^$S#K6Wrv=!my=GJva<{!Y9{ojKU{am)@v?@>j_S{nddGFv&%0v(Rhrrl-KL(%^ zDT6NmOQx(crI*UH@3_~8;OtvMOaA=**mv+zh{#XA4W zr-`AzLh?T-Eh2<}U6`JeYq}9kyO|7V7IRSL_?j(CO6zVxnCk@l0#!U$NPG8bDQ=9B z+mFZSvFqp0pQi7ZFwc^K;FKAP=BIP-x& zHnn;7HY7zz{d~`a0+Hd*&}vhw<_AGPCHvGuY1vfFmfB1gMW~2a=-EHZap$ zB#>FQ(IMp3hmxAy=^VX0EHC_oghrbavZEz%d2c zsUGOvr%yB?`A$e)*k&`izm1Yx5amtSVATHx*EKevh&&vjRb957N#<`F&~kwn__R4G$uXVN8)I{?iOrrrxT;}S@D zo_EE`q8a<`c@u9nlSey23G0!4qqbnl3w2-b@f5yYe{w0LToJ6!@}S8&y=Ok>Fk1Qk z(Oofh(R*~TvGukY<^}s`bI>)be?C+9doArNFFfi1(+CKP-i@jt zo!=1;OK@8e!OjDpD{Fv7T!gj&k`1L9UAXLTUr%9z%|YJzx6$tw;cA=csK5cTwi`F# zpy+a-EfmP+yYg~5Bk@tDhd)yqGUG5Wc}>MwsfMbZ$#fDgyqcv<5Kpq?1&i8&fwM00 z-aqS=`uwJQNt5Z!ehy8+*;EBt0V|6+_~4LxOPwgFES$Y{Bx{@VTQb5sg84}cIW2U?|l#vj=};vZQRu2sW{Y; zasw9&M95J{k2GW1RHBSPK(Bg`b=jV3?2ZuGgLr@b zQ{mzVbm;C^<3E7%Okg9+&2sl?W(rXL5^VKoLxq13b{p&AtfLG@2v7AYb|?T$YM;#C z`DGcP7aMQDrbZjTY~9u^*u1CwyrR!u%}8SmoUo-rgLK#fopiuaVFw_UaU;7k)jZxo z3yqXGbI^1-$hHhFr~LAgRAf0iKe~kskm;Kg4UiLlkthJm%jEh9h4Khiu2H4ypV?Ty zq?J2^4!7f;N-IUcZig)VwX9N*kB$!97~b#oqdZ|BA|TvvkYCt3zEO%#y=(7z&C20j zd7jkJn7ibdFj>^o+RgECn1G61JrOgH6C zd3KVtbiDCAsB3{=yTLE+Muicxy)Pihi-pvb+wFs}F#WbXhx(D@$ z(ZBPhkYMbg9wROuaxBdJ0+dS_VnioND{N6W$B6Cy&6WE;RXaj?trHzIDz#sDckctG zn8D}6&mV{%{GkR$4B&bj$M4ro_Z@a2@8an;TLI1Nb3j)MOZ0=Tr0@_j9gp@2QXUUc z=wr#q<9rPY+21hJod@XdFqwa_KIgjPrsI0jD--#Fa2KAb@BfUy&|+#nF@z6=*AZ&@ zM01+zzYi#=So(?ls4_~QZpk#^=~F?ebB!YB2Mf*Le2$(h93u(vH%hyiZ964$9J5Dx zI)P!vO%l64g!UU_U776Ct1U4)m*jyfq(9V9nry8Pj?MPt2>e7S<>(JTe&AQ8kz1fk zedTHKc9``fHm0!L%1|$%_<3qBqWlmW3>2S5zR|SfPWv`oyl_UM9u$Pl%Ak<^R8&i6{(Dh4NylKIVY|7jCt01mK@U`9_fGx{$sj>q#-7W@Hf zjN*vb&b_4ox{T8$`7p5Kn3AqBAuxcrWySkOjQECTuy>4Ln^!tYQ@= z1m7e(ZBoq>Dgf4tH7`bgTVOs|lh~&IdHbP4^>+~u%dd*u!gY^tjxUj;x7W~s`pC-J za9DQyW(HOlVWbo4wx8L8TMy8KA<9GR(fc-1t(zPb*B5(A*D57;vvj5!mqb?h_(<#$ z#>?${fXQhO-A_ZL)2%~=Ucv$wmyp%japp>uBwfS`YgNJ?C%-a#@lS591hnZUlYO#I zTvc3r6X=#s_QE-3isIwPwcIpr}EzW$EP{#vf019T|fu^q6y za5Db{UFS4J)CO$K#6s2c%Mlr>eg$8L9P6HTQbZ-dBL43?_0xNdWZtq36~0wf!fySd zCAWcLr8k z3S$4C?Rq?Z0yP$ioI%ISMzmback#6AummBM!DWtVfF)X~>)6)fTs!0<(rqkexxS}q z_Z$=t$nE7|)y8EWYSHU*0Ad=M>!y1i^9_#WUsq!5}4C~m2u$Fwr-X(z?N6| z8%=jU$_jkuHPy=)UHY^!MNA2JwMYS1qV-6Z=jXQj+REK$t0bbp~W5*Ba&^U@t@GJet1Tg%P#Lb+kxzUs}q>#7UPGt zEo3C5a;WE##RuRIeGnqg_sapz8|Q8PoplrWl0%VyEV;j!Y%61qE_vKu+x)@1eis7B z=;Mh#=ber%A!j*JT=g0ty%IoorL`w|1oqQq&88SiNk#mSZas^?UoKq%lrt?P;E=RQ zIM-s}55Ur2ZYnn)`ecs-m8=-S-&HHWV{LwKmnFaQWIRN0^=7w(Ohf~X5C<3OgnzlC zHktpypPRDQk85N=;{cki8_;C}5Q$xr8)`*&>N{H?nK!C0zaDB=798~xxnsTm>CJ?) z*S>Od`Dgggb*3P4F}op22Y#wF6{AA{q@&?KwEb_lOT7fBeewSk_my!`b?@3Vh{Vt> zLrRxQBMj0dAPq`M2ndo&$Izu9tsp7gUDDkMN=itlq%`ju^#7daob!9%uPsoNJDA^YLPym;vMj>zwc8|deK$2*$qYTK64$!+^!?KX~;Yh|wSq4k}Frs(Sdi(M;*%^3BBxIn!RgQ5JtWIPxB6%_)1d zJ6KAXM776Uo;tDclE8nkqMJ#=ifbGOP1wFv)GXl^^ZL{s4Y(Yi-*=+-q=Dd%hXy?c z^6%}r*c~yzv?%l^GBTx)pUc8$w}*UlL`5-<#zSe+;`u4?b)PadRzuV$Ouu>IU-Af-HFR zQILX8$S0s!I9z}xZI6Aq4BCmu%lUR>{7zZzwYuqnk!E1jp@;3ztSvqpL$z&+PtM97 z-ZB)Yfimc)LUv#r%|B^LhXfb=y~ooZ{AdSXLSh^4N5v%QpjMvW21KSL10N=?(OzUL zL=o?KaW6I*Amt)1jrMIo<*biHLMHrOAEj>AbEVvt0;vcRL9DtXn)=r}6HW3*+Mz>$ z@L~}!mF~wn)=oQTL5n(O>On;$Vkh#e& zllojBSj8z5dK27lOv4J6@;;1aE@>4iLU)BF_2Y7r5vJ!$&&qkHNdsdk*C zY$JyGBLNEQ(l3AL6+X*J*3RHkuOw#Z*TJ|Gl$aEm`2|t`fFS011H!oKumlXrHr<|O z16M(&JpiHV+b}^dnLCKysvyN-m*W~UeNepvmaUZ!%WkM_W-Y}FV3~F84$&!%aR{hr zSBXw8YL9*tN2`YpPiLjxqs>!zAKKBuSh_KiFBy1~y*CS?ZHH3+HV+m0_q)(x`2&9nELNiA``Cyo7rrf{i+6fEB46NK zR@b+_hmt_Eq(Z9K_Ve-C37#MCxz>J3G^-IP`$y?m@W*1a6+fmdZRZ;rNzaNb)1)X> zPpH+4T+2(!;UUU<;!7B5$qz&2Y|ml?##?tZD)2|t$oD6)dYiGcu4FcN-&=iN9O>TD zk9DTkh;@q2O(tU-d6pl0I~LrJN&vb@ac7`n=>i}58>_l^G9&NduMQQdsOs^qORwng zK*zL9BJv;Cn#09)^oyjsjmnSSfHG_6jFPc9W+^wa3s}CWIkjiwZ|RPIJ@ScyZE7)l zUcAIA`_buv4fn8?;Xh{%0HR#Qi+925l9C}5q>3BxZcZ&Ia%N0nFY)$fQ5DzX5J_G* zIf=QXGg}v|##QlBziuVR`7#KaM5HXjLK~|G_)pB2%5h9;DY@g;rZtq0%?< z{s_;H`LROr)X<0CcSm@Vz-=NiPj-Pp>v30rLIoy^Vn?$sMs6kL)116nmHg66^)Y0n zs#aIZ#{(>xMZXuJ)YpqLuD4NdN@eAyAfV`ZH|Gu+tJ{(3O2v{6={32BI1n9zfjxzQ z_1%B2&w(1_x7=n2@Q{Ry&Vnao!bgI<6E+`~DDpd5?RQ*QCp1^aMCK}1^wzp%KA)k& zq*5-QiAKV$j46?#$+hwgD|L)`8}4Zo2hDN%2yz-}{}5G7CA30ZAt!SgdZpi-x()b= zQ!mp93oG>+RYSnFN4TzzPrP9L#tm=D{`?@<&{Gi0cb}UYZJ7hh#+y#&h#Vc$To#_! z;8Qu^o1-^*>dxJ+U&`*(7%qO*IJE8*z6&uvMetT?@$vTJULL$2jAOa&~Wb7 zb3fO+%Ltq+(-(#t;>mO9PK21IT&ornrrCWoVeJ>)8+vzIcp~9GxLAI^)M!zV^S-Wf z_QneLY=(3xx_Fd?(24I3oDt7o7JYlC&Z;S+WWcOB+XpC2S7mIw!zOnp&CBaTqDR!E z62_V}E|@s44%6eFZKXu zkzvNrNZe|n*4@sqXZp%#q=e9ylbfEV3)YN*)}jhbf6@<`R9}rmABD&Su-%sa93>N3 zpkAS&oA};X{sHahG=KAhL&DefF)^gda}9@qOTPZy6w;1yYV=MyvLh*~l3h~?D!;dm zzP?r3E>jIT)+;o~=~(Tp)QcRcn09?IRTy8OTZi6Ho_|1;On>=USyA24?G1QMFgq&oK%aCM4D<7TbuIJQw`}xoRrSC;zE~-73?(L3P2!e z+06hRK{wT`Hh%4EP@eVh?Lm9XcSAvy0fM))E(|Z+lZAOKg`ZmmBe(|@F21ow9uZ|s^-m6tlq%2zv5bn3KsdLijItmcPDiVRPElF&5 zt$|0h@CV>4;lS8sat?xqrGlSc2+)KbmROZb=Nt^g0h9Z0i3Y>x>%oztQiT|$L|Bo5)8qk zUVzPfwAW02)BRNUUO|Ej1{4lU1s+O%G3l2Jp>@`DNyYO??EjEUQqmP5+>B_JX^taA zVmb@V?)Gk5sndh&B0x6e*@jwL9Su&N955aODX*Xy09ClC82h9%B2OJoTp*K=8ZQw9 z$NM;2AFxsYaB5|@hX77jNTOG(c2Z_dxqYb(JdQ>9Iv7MP;m~3{nkhs`QXL?wDrg52 zZ{ZM@)8J6N->gz=yW7QXAAQY-ZTn%+xN#Z2gh=U_riNv@f)Y)!5MOX|^rAGlgdoVr zvuMwV<^q2KWR$-22B)}>gP1tq>OrH2M|0T&fVvTXEAh&fAe9xtPueP`u%Pmrf!@nO z#K@gLl!xP=P_$k{jr(bG8Wf>nzb=HhvfW^KIQhN3HrLy4j&wsWr4(AU)TpCK* z|9v2ShCbUEdGzKcATqhvDTT4XB1a0%5o=W2f#wP@`6&3@j@R4Wrc^MA+$2;WCXOuD=ac{<0sW`d#PG)1*3>=kVB)$gw~1uP1LZpN5JQjhFJ;+SZ!wNwWL99~prLyj^4bV26$$KCe=4>Tc!6n%wf`<`oqr-sT zI@aa+6XcY+UN#aXe_gq_F=1RKU0e<48jsO(rhIsw>V056^ZW=aZGq{tIV4P@*;#e> z_6xDM4({N$cox71%p3ET@vp~McU(^$yS!Jj-*%8iHr0(Zp#D(ej5!K3p3F=u2D*Rp zQ(fEBH-MA2G0yNoB+#9%6%74h9NFHo?itUlJuX_RYR4B5L*TeOz0VKhox7B0rJsQk z@DNjq^e#Vi;vdb(OfSr6TaBGB7U#nphL=~$B4oDc%}!el9?HE$WT0z)!c1^L;n?FcaJEbwnX8K9pAt2D190ek<{9zHwp zNEUr~2na==kjx|P?^vG!QbcYnn0s^+|62e6;zF%bPFEzga^5)q{TcE{ML@9C7K%hl zMTF(sO}01J@`kd&VUB?4ZaVdBLYT2hreHxKT;k2a%q@Z}qS`l>Ts;77%~#o=qR$G> z{A`}xwyeOn2kN}lb2X_?e1+NyPGNRxQFdp5sLc=6_$Ts!y;wY`-7l+CqY?c0B=Y8C z@v1w0Q_Fi;I(3l)Kl6+-=AvwZ2-!RoAOY?dQbn}A9sLVMU6kl5KWPr5cVF>Pu>>R5 zd_bcTX!Ga|7>TmsjV{jxF=1cI$6P&ei2rG39L$#C=`CfJ%8&IyGb3*}Mp**5j|ubI zpXDf0xnh45Qe1Z(-%tCRUt#GKMs>?dy7_(0owX7orc?7V*qg%-(ik0KDlfCS zY-bKO7HI#6ix%1A+#v(Fu(dNZCzL#q>_7@%m(doD{*=nTx!24Z{-ogtY}n8rk0A;3 zMY#12^umJ=Urj))_GIg|lX|h>`8;y&HDF6wT(bD~?S7F~SM78FnD)Z+yF7UPW*5F% z4_mhQN4vy~{Zoei%j0bUS`&`NW=C7ImcL(DWVtmo{*YW!z6Gy3)p4a0IQ2Ow9c2mZ5?a&zgRfx5=`=vrcbwr?Sv{wRdo5tAy3fccK+;o|JFFa)4n>sasyw8-GL`ty|@9p-PWmH`H^s!(IvN z)&}@;ug9O?P8fWOQ;IXC#lnx`AN_^CgBQid&7yE|5h|!DZCe&R^rl61Nip;Z<9CgH zA)RHnyC@7S33@nT%;=!mj4QtrMg)wUw=Dstovw=x#1Mgq*Fh67A0a@mw#G&ja^!s) z>a`F?fMUIKPJerj#1GBTmf+ zO1(83#*-}NHLe7?|1eIz)R)yUDI#kjspV*0{Lt^5bUH;5af|A5j{*TNYGGL_nwEYa zq(G)k>PvvCFmQlEdNB+m3<7ln?lyx=#WLobf_6Jyiri>DVy>(F_p`VR!SkcPEwrm( z$1vHMctGRNeTQV7e3$%pt)$=ovRo=nWTQ~qZ6Loenqp`CaH)N#r{{pm#wte#v`zq! zUr7|PM3d5-GNG?rVr8Z}snbFO80y?;xCXt!Fg$4mI|%Ng?wgFJmGl{UJzE}3*jpxMyq{IUpQrPg@g%pO;_!y~`c1BLVYpJJUvqTg z;Yr$jR?rIZ5cmGlT|MEx95QYldY7y%UH)ev91102E50QPQiE~%wx58mm^pI8(>`lP z%n+Uu*T>dMN~D=M(qjbjexo)UM9m&NGhu4s@ZrAJW(~tV>NO0@*&o)F|k{vc4JvNfN=?kXq$C)${$x<0A=qsp`ff94K zJ!go{oFO3IG5^Iisliz;HR$0sOW;080_iI8Z}p2n12+Y}sL8@Ons?p7Y|PD!N%>N`1ZLn7q@m+dPUev@N-ra%v6IPgnqt7k^CqG~rbep317^{Zx>3wh5ZYii6 zq0!qt8g4kr`EDW(m<|2-#)WA1R@gOM5^m-}g$DRte)U=m3+8`*Ma98`M+;+4VhwrUd z===_m<7Lo^gFrc6Q3udaehz<7@cF{yA7(-AANjZo@)Ms5;_!zTd&Bx^eG?;s;zu<* zkSkZ0=K0(G<%$()wJ8phC0T2`(gU-go;-`{-tV5g!@~>EH#{gSAVYIjFHSu(dwXg6 zdTXQT(M`T=$_P+)9~E?La`82A-1MEQZkm4twsQv`iESjhkrB}FU*`4` zi!}36Yb3ODghQ@Timagj%KhEjh_aV(E~ngRcll;Wqm6|Qqul-8dQO6@S_IcMef-?* zGppy5=g7cGc!`ene{>HpPnB(j<@m8Uo@vwm_6Jij)94zY@`En~S=6*BH-wu;4sLP; zD2j|)H6x~^!C&-ZSM~tXo9caJPt?6t?E;>iG}b>s`KL1A4_6HUtRD90w&}oE+Z(Z5 zLl22f7r{E2=F|3TklrrU3Fju=AR!VDSmsw=5Zu4Lt--=W8E#TtP1kLTHC1cJca(^S z^56`VCOn(L5b^on1q|SQgOYySm`&21we_)UKR|#r!F^xE>gqRF3fZ~ORo0L&Nw?3M z#I2_qO6q~DJB^JQO{j+`i3LHnwSykW%Tcu3DZl9sQXCCOyE+~b^&exJTFi4Lc>EcK z{uEk(T{*IP0zKoFXJ7O_r~{qVpwN45`N{Y{f0!E8E8`i8;EKD+q5ix?;WR>#yAFKq z-OZZx&=aQsPp0R?uG(wNwBrwoJljbTY@K~PDJZH4bLD@i>kH6Ky7M2~g?;p#R?fk4 zH0PnqpoA5Q)iWv1T~Crj=xNX_y0`~s<&`m1DTHcx$(gDh#OAcSk`%j1`I*$bCvsb{ zCJE4lr0`9P;61M!5QzVV>F?C@AA@>HO2&IqiV69t4(Aeua4<05 zkFNpLo2mS}xV65m_1_d)<=N88Ye=UEh?~LUE$U5l`2A{5@}mTyClXDmB?vLviW-O10S|d=81-^T9q#Hh;hGA`{*{WjJH7T0ybvr;*wx^Y^Z?VpJ zbc&1zV*Yk~3Y>WlF7-tO_+56~r-z7MZJM){oey7n)%5|t+?=&c+lutSE!(kQOIN@Z z=W=!zeI0erU0?Ru55fJ*E!*iR%M80vNgG2)qGwF( z<%Z+KHs8k6$SIuRk@iSUZ|=)!hj$)aWXO(NjJFTYx%;YLot$KN+l1nDU%=!CLt#wh z;Q}NdeQ4t7&EW8P-%ONSBnSSj08J`o@~^IiVymB|OSGdvOX-4SGCtFt$}9BHv68x? ziA9aY)PBZ3^b70w$@wC#V~=;AZ#CClqG1mXB5^z`abopx+oKYpYPR&}=@_kfW zHE#KZ!$OsgZuX80a(#?@-iW#qE=)(`N9xo@Fl!)*a4DPGb2)Ck5>{5A9sc%?Bmc`Q zQKbq^t$VsV-TW)G*FRSG!(N+~7acbjAQ4hI6=*TK?=4;QilpKj3UcA1Kc)0Pg-3ss zU;aeaGc`Z4RYIH#rAEV-p=$ty8+slalBf)0IdeJOObJBV>Om^Rn2O`1KwFbaDh=!D zTe?B#RiRbxKGdqW^(%Y(dnto$?J)x}JRdYrDLofnn*#}SwP$lomeZL;L7f^j@0&92$Ibsf^I`gTt`Tj2p0&ES*sTsq<=P}h1e8LOPt9u5^^)W^wbT{{{A<@tBzV>u%#?MmZl6{rJc{lt4< zg?+=s9;z8^ARN|i+nR9&i@;A;3j)gr*JlAkqfM2+Ar$WhvoH2v^356WqU8f{JRJTD zK6jl&e0{Rt6sz72s;gZqG1OxsSSeHMN=&6-wdv=zG|*qh=5qMTVKg@LwfQ?xKMu{y z1ERLFcTz#Jn`(utN+w#4y?GN&9hXD)5JG~l;Ff8D2>IC!rOtSdS^_D^AC6Ed%t%=P z-mVsQG-MQmKH26EGnr8xYu+(C5 zPu$$#y{oI=yEE+0y>~vbfV-XN^p3c?Y}xf5*a}I-chXu#jg}L^;&A35S8G@AZI)x@ zoDVBkyIk4OQS|uSLKYpb_3f7? z{ZE=R*EYJp`UOrg(AeQ0HP<~Xznc=yI<0aARyUm!T@{=`zNPXBMpcuq84HBpVQ^O$ z6u*51bf&u5LuIWexo2(;Wh7d1Rd6Gn1?ZAy_h#?IO-{w423>E7Fk(@_A79EaRAq1S zIe?1Ay(~(pKuge9O;gpHeNJWJ61v*Mzw>ex6It)+gSKiJ%1+s=;4fx^3mcq1wk?!y zLZ_#?s)UUVWguJDh+O!1@dIMaLU7VR&q!ilzpk zARwwaSSB&lMO9l4K<^2$j)d?$$iW|(97?gI|E{$>RgzDVEcJn6AlFw;TgVZVii;(m zsn`Z zvQ+PR+)0eGoX11i5NR}>UOb3|i5JfHM`~el4lmL2Pnj&0x{c$bvkLZsTPyKzgdhFF z=m?VTdcNZ{l*N%|#pS@DOGW)b#XG@kjZGeX0m^7(81Ho=e!aQxo8SVG3b4IwYf4kN z9JCk|{iXTj7f2zAzVn#qv^@mQDH&-cRU6=C6m2uXaCO zd=-%sNRmoYeu|-kz9zy#<%mNigk_%Ibo%xRbV9TDnfChJ;m0_<)`2?dGv?NA>(CWE zNhunN7i=c=3g^6y>=To%WodB8_r%y(!C9>P&zAB;-lJa#JdR_vp(nNy7wmiHLvUe3 zVOd$cHi=sn+st_fm=L!omVK)OBFp%*K1rlo2bDWS`ZpWwr1n&-?oi0E1TM@eUsi#kUVg$BImoT2=OT`Qt{+ zg+-xO$~aPedaKySJjjeIT_ycJu|d4K!4@(|y4Ce@Usqj7zdU*e%zOmorM=+qvtZ!z`O<-pXFX%t1JZL*&k7KhX7h z0{rP0NK z1SDb;?vUVp9^jSOEd#80nHR#BA`J{Hb9>z^O#Rq;MPr#rF*5e~D*4DfoL8ONCLn{8 z!BCtX-I6EN07r(fXBr-huT1gP|K_%vwwk|Gw2`=p*?{nJ~?7a!^Oh~|vIrH6!n!R!v zr+ce#kUzrm^V9W`ZR^ECp1;dEzJrKV=RvE0kDTyF!VKz19_wmH#S|H{2i?P4<5_Ds zY|GqZgVMH))^OAJ_e46%AIX9MUGib{uf_F!v$*4(^SJ$2WAUts1EG9r!gpat#02Xc z_2CG2g8waW3IbDGF8e}>s70Qj0|=y)i}QEtWM zfUF`(D)4lJ* zw07nJAMe;pXt;S;HLqks+F%J}8KHfX3B!MaP_D{;MbIzJY-Gc#&?@3NdAA?kN*CB; zMbx_rC-Pn2hE*;d(}|{CA90sqF#UH}NrtuvSwvhq^bYlhkCUt^G{S6Hm~^x&J;B%0 z-q+K@tbyKF)7QRgP$10x;EH@udc7jfn}s8PipYQvg#-%}8DE>nC+qL8`>Jglea#Q9 zy)LuQ&`Q*SR(XHu*#(j3Fj|&a{GM0DWAWO~w>B4j#mQG>!nZ6V)qX>HV(~aQQNL+; zieOpfWhUh@PKklZjA;ji8?^<#pD9tz7YShR7uL{`1D{{vu;8yn+LHR5%ch$ht1~<_ zl(=5P8qU1JU>X*W&k=FZHLKh0td_Aim%DSWM&SWouI?4_2rwWc*= z+jD9FWyQ`yY&TnND5y;*jUTB0hTlVGGDPy=xDc4KsD41OV0;F<?ZQ;{5RZ_ZI*b-wdtB_6Pn*ya75}H4-h;sNq+1`zjd)NTE)CUNSrwH` zYGa|g7IGH&DDe?5cCGkd6j^!YQ%LojHAwDAY)z}xazde{yl1-uV~fa%iXB(`_np2zOlSJ~HRb61;fJ#PcZl7i;f;B!#Jxgw$A7iB=@KK?EGNJYP$g^&*xHBiq zu`iDNg8e_>smBk<8GICRcc5|=eRS?oPWXlL+P-BQ3pI&ypEoqTV5%iLa?dvJL*N! zIH6Tl+HK$qDA|TiFLrMlWU%Mc7*&74m<%mtn$?7Ba z?7QO%!n42Fav}an|8J&laObfRUya43r8WnWz-*p!LQZ)Z6jR)v{UjqS%^Ztq_H>{~ zwT&vn1jw)yM?PB-;7QtL&I8XiD(RiLX^3P`S6)GPb0l0cSwX3nKYgyKud3{&8DLfB ztA+%=>^&rwNT+EFM@nMO93kB+V1CX01|6=jhe3GwM7m#VOGTTBm@Bl50=PqGt)blw zXS1*+Lktsu+lsVC252aAht35Z5_S5|_n1*QVmXx`=M_>AMSbF)<_^R@er7E8-Cg&1 z@Ww3RYc~6I>t-o^KDmP6UK+!a6yv?nL&{0gC@X%E?y5VM&JVuAQgJ#dgLTLo(cW2h z@i6#k5%G1z6^{s|xHWvUy7dG*F5?N1csCO zLHy(~`zvYFyu4UPRZQ40y^p+#U<$`z+XpORSEk9I5i*s1>X|m+i;?hJXOMv`p1cP( zZ_X1I@dvStx?%vRmP1ZK@bum(i-VNNA5$@h+#W#B!B19tes<#x| zcrHxQN2sOK-7_nDR>ckPvYV2iiWyZPT2zR_7bl$mb@bc)H_N`ve*iYQYbpc`cEepn zV~a{N)`(_M4X5VCq)1|9j`X?1Glj!f#X395-)0i;(vo zt4-FhyC05`#}~d4>yO&>w#0FDdW6Q|L$B=9pj7$}vi~q!=~Y9Na80voy2p879TMW? zFEJ<#M2IaPlfT^Q^-woWS4btiXAg0)NseyQ!jnwod*`5E3vZ_R^}K;dz~nLg4BpRN z&vQ+G&4IiZo@*<+L1NBdqDzqoXUSchA-yra-{^AWkioWJfA>(-NY`qUQm0P7yoN|lA){eJ2N?d{Z14{10i4b<+mPtn-;aSy>k7-Pr>G}r&MeO{j(fXZ( zw}bvg9Jng!xSBwM5EE^pP=-{pa9+3*^phY-Z#j23@9S{rr?HI(=$X7R@aES3W>Rwt ze%2N(RiG|TUSi_M&%=of?VBCoqIrd8+~{Dv4Lxa8SARm2hmHPus-4y3=RA2J%7kXq zX_}*RS%tN84<3VnUOlESdxd+eY4zD|Mb75)7}twI$5d)L^02(tU;Ow`-eo9SR7lx? zFq%k|Ir@Kf=`F)Hh?F;r&bnE;pEJF7n-nDoB*A+*B-zPmPMv~|vcXGV?}e`ScR@il zvs9hDeyTu^@gS|&CVya37z?{jNGN8II6IV}w$zH6`jJ3)sF)xYZGwk;P~@CUb%+vKrVAdG5WO%#VBuH@0Ti&x1UvZQ z%tK3R9JqhQ5!ye)L(wK|?a5gC-TS`MW#G%_ZF1^1Riye20dXgpyO_zwGOann>(+Kb zgVCvHuW6^Mo3=hsWHdQ=bQFo~D!+G32yQURKt$O$Ddt}OO~{s164ec18X?C+Oq0Xt zmjLESE@^x|e|GFcZ|8YkXa3kJVkeXkYX5uDbK7|)ipT~y z{%&R$C`Yi$bWCy;UHSThQC`%$(2#%K1P{`P54JtMjZ z*_ongNG#ZU7^hjPn_=BqbNT*}%TET8M9{luN5n~k z*PA(}ZseinE+aQjpY?Yb04wpq%?1O z9L%G=OA)kZs7o#g)gN|!QQANGFx+GI2ia0w1KAXNHcsxW>)hIcn^Wbp)$Q)v&Kdqw zNE_wPg4pTrg*aoejY|F3`~aH4hl{)ywifk;(1o0()9;CxXB){DC2BO^;`$g$Zh<pX)wL@6h2(6j%}} zYD0g&h%kb@k;DDjF4zbIF~k`0|Iw`c$w&Vm+i3uI-~X;CT7ZH({?Eq0CU-Fx*bPsy zV==*Y7dT&zD~rKq$cXs0`5hKSHm=jgihe$b)n%}hPkSlK3R%roR~0D&ObYZ z+H*A79qs=7OIesNlk`=lyelIz#Fqjc@?o8iWCwz$#H59@RZMcrXR*NUEt2}=Y*;&muR`q#*D(iBpbqZa@EuqI& z!1dc8`AP**gYmgNVuNoaqon_w{a37QPPX;Dr44$q%tKlU3oPt&%S9*}C&}rBq-}jq zxl1Qc_Isn(9wBMCJ z1+`FB4jD^Dd=@DIm@OlIOK_oRZdo4-gC+x2LqAh0m@#cFo}XdVdv2SKohSPRj-$C6 zEefi%$qI;%BPF2!J&qs6cc`p}xYy}fvNC*lz53sCAWKIr1A6Jnx7#C= z!#z~*C}8#m&M3UMNX9#hsJ~UGcH2ek=Bqvfh5+J^cet_)*2L*abixywv!wLPeT$!c zIExGWOFL;!1{{1(>?iZ`)JQn?*Ojy|@&gZbRXBxn9QQ2(lT#@xWEW#@!pbP`Z{wYI zN7wnjNg*0d-Zis~QZT3CL!3efk7Tzkecr9sPef f)cGW=>Es%DZ0$_t>#)HV;zH!5Riug}jr{%(#tl+v literal 0 HcmV?d00001 diff --git a/integration/fastertransformer/kenrel_output/ladder_kernel.cu b/integration/fastertransformer/kenrel_output/ladder_kernel.cu new file mode 100644 index 000000000000..0039bd64dc9d --- /dev/null +++ b/integration/fastertransformer/kenrel_output/ladder_kernel.cu @@ -0,0 +1,407 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +#include +#include +#include "ladder_kernel.h" +#include "mma.h" +// nvcc ladder_kernel.cu -gencode arch=compute_80,code=sm_80 + +__global__ void __launch_bounds__(128) bitblas_kernel_fp16_int2_fp16_m1n15360k5120_nt(half* __restrict__ A, half* __restrict__ QB, half* __restrict__ D) { + signed char* B = ((int8_t *)QB); + half* Scale = (half *)((int8_t *)QB + 19660800); + // const dim3 GridDim(7680, 1, 1); + // const dim3 BlockDim(64, 2, 1); + // bitblas_kernel_fp16_int2_fp16_m1n15360k5120_nt<<>>(input_0, input_1, output); + + half in_thread_C_local[1]; + int B_local[1]; + half B_decode_local[8]; + half A_local[8]; + __shared__ half red_result[2]; + in_thread_C_local[0] = __float2half_rn(0.000000e+00f); + for (int ax1_0 = 0; ax1_0 < 10; ++ax1_0) { + B_local[0] = *(int*)(B + ((((((int)blockIdx.x) * 5120) + (((int)threadIdx.y) * 2560)) + (ax1_0 * 256)) + (((int)threadIdx.x) * 4))); + decode_i4s_to_f16_scale(B_local, B_decode_local, (&(Scale[((((((int)blockIdx.x) * 80) + (((int)threadIdx.y) * 40)) + (ax1_0 * 4)) + (((int)threadIdx.x) >> 4))])), 8); + *(uint4*)(A_local + 0) = *(uint4*)(A + ((ax1_0 * 512) + (((int)threadIdx.x) * 8))); + for (int ax1_2 = 0; ax1_2 < 8; ++ax1_2) { + in_thread_C_local[0] = (in_thread_C_local[0] + (A_local[ax1_2] * B_decode_local[ax1_2])); + } + } + half red_buf0[1]; + uint mask[1]; + half t0[1]; + half red_buf0_1[1]; + uint mask_1[1]; + half t0_1[1]; + __shared__ half red_buf_staging[4]; + red_buf0_1[0] = in_thread_C_local[0]; + mask_1[0] = __activemask(); + t0_1[0] = __shfl_down_sync(mask_1[0], red_buf0_1[0], 16, 32); + red_buf0_1[0] = (red_buf0_1[0] + t0_1[0]); + t0_1[0] = __shfl_down_sync(mask_1[0], red_buf0_1[0], 8, 32); + red_buf0_1[0] = (red_buf0_1[0] + t0_1[0]); + t0_1[0] = __shfl_down_sync(mask_1[0], red_buf0_1[0], 4, 32); + red_buf0_1[0] = (red_buf0_1[0] + t0_1[0]); + t0_1[0] = __shfl_down_sync(mask_1[0], red_buf0_1[0], 2, 32); + red_buf0_1[0] = (red_buf0_1[0] + t0_1[0]); + t0_1[0] = __shfl_down_sync(mask_1[0], red_buf0_1[0], 1, 32); + red_buf0_1[0] = (red_buf0_1[0] + t0_1[0]); + if ((((int)threadIdx.x) % 32) == 0) { + red_buf_staging[((((int)threadIdx.y) * 2) + (((int)threadIdx.x) >> 5))] = red_buf0_1[0]; + } + __syncthreads(); + if (((int)threadIdx.x) < 2) { + red_buf0[0] = red_buf_staging[((((int)threadIdx.y) * 2) + ((int)threadIdx.x))]; + } + mask[0] = (__activemask() & ((uint)(3 << (((int)threadIdx.y) * 2)))); + t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 1, 32); + red_buf0[0] = (red_buf0[0] + t0[0]); + if (((int)threadIdx.x) == 0) { + ((volatile half*)red_result)[((int)threadIdx.y)] = red_buf0[0]; + } + __syncthreads(); + if (((int)threadIdx.x) == 0) { + D[((((int)blockIdx.x) * 2) + ((int)threadIdx.y))] = (half)(((volatile half*)red_result)[((int)threadIdx.y)]); + } +} + + + +__global__ void __launch_bounds__(128) bitblas_kernel_fp16_int2_fp16_m128n15360k5120_nt(half* __restrict__ A, half* __restrict__ QB, half* __restrict__ D) { + signed char* B = ((int8_t *)QB); + half* Scale = (half *)((int8_t *)QB + 19660800); + // const dim3 GridDim(160, 2, 1); + // const dim3 BlockDim(1, 2, 2); + // bitblas_kernel_fp16_int2_fp16_m128n15360k5120_nt<<>>(input_0, input_1, output); + + + const int MAX_BLOCK_N = 10; + const auto baseBlockIdx = blockIdx.x + gridDim.x *blockIdx.y; + const auto totalPanel = (gridDim.x * gridDim.y +MAX_BLOCK_N * gridDim.x - 1) / (MAX_BLOCK_N * gridDim.x); + const auto totalBlock = gridDim.x * gridDim.y; + const auto panelIdx = baseBlockIdx / (MAX_BLOCK_N *gridDim.x); + const auto strideLd = panelIdx + 1 < totalPanel ?MAX_BLOCK_N : (totalBlock - panelIdx * (MAX_BLOCK_N *gridDim.x)) / gridDim.x; + const auto bx = (panelIdx & 1) ? gridDim.x -(baseBlockIdx - panelIdx * MAX_BLOCK_N * gridDim.x) /strideLd - 1 : (baseBlockIdx - panelIdx * MAX_BLOCK_N *gridDim.x) / strideLd; + const auto by = (baseBlockIdx - panelIdx * MAX_BLOCK_N *gridDim.x) % strideLd + panelIdx * MAX_BLOCK_N; + const auto bz = blockIdx.z; + const dim3 blockIdx(bx, by, bz); + half C_reindex_shared_warp[48]; + __shared__ half A_reindex_shared[4096]; + __shared__ signed char B_shared[4096]; + __shared__ half B_decode_reindex_shared[3072]; + int B_local[1]; + uint4 B_decode_reindex_local[1]; + half A_reindex_shared_warp[16]; + half B_decode_reindex_shared_warp[24]; + int B_local_1[1]; + uint4 B_decode_reindex_local_1[1]; + half A_reindex_shared_warp_1[16]; + half B_decode_reindex_shared_warp_1[24]; + for (int var = 0; var < 1; ++var) { + for (int ax1_0_3_init = 0; ax1_0_3_init < 2; ++ax1_0_3_init) { + for (int ax2_0_3_init = 0; ax2_0_3_init < 3; ++ax2_0_3_init) { + for (int i = 0; i < 8; ++i) { +C_reindex_shared_warp[((ax1_0_3_init * 24) + (ax2_0_3_init * 8)) + i] = 0.0;} +; + } + } + #pragma unroll + for (int ax0_ax1_ax2_fused_0 = 0; ax0_ax1_ax2_fused_0 < 2; ++ax0_ax1_ax2_fused_0) { + + { + unsigned int addr; +#if TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST + addr = static_cast(__cvta_generic_to_shared((void *)(A_reindex_shared + (((((ax0_ax1_ax2_fused_0 * 1024) + (((int)threadIdx.y) * 512)) + (((int)threadIdx.z) * 256)) + ((((int)threadIdx.x) >> 2) * 32)) + (((((int)threadIdx.x) & 3) ^ (((int)threadIdx.x) >> 3)) * 8))))); +#else + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)(A_reindex_shared + (((((ax0_ax1_ax2_fused_0 * 1024) + (((int)threadIdx.y) * 512)) + (((int)threadIdx.z) * 256)) + ((((int)threadIdx.x) >> 2) * 32)) + (((((int)threadIdx.x) & 3) ^ (((int)threadIdx.x) >> 3)) * 8)))) + ); +#endif + __asm__ __volatile__( + #if TVM_ENABLE_L2_PREFETCH + "cp.async.cg.shared.global.L2::128B [%0], [%1], %2;" + #else + "cp.async.cg.shared.global [%0], [%1], %2;" + #endif + :: "r"(addr), "l"((void*)(A + ((((((((int)blockIdx.y) * 327680) + (ax0_ax1_ax2_fused_0 * 163840)) + (((int)threadIdx.y) * 81920)) + (((int)threadIdx.z) * 40960)) + ((((int)threadIdx.x) >> 2) * 5120)) + ((((int)threadIdx.x) & 3) * 8)))), "n"(16) + ); + } + } + #pragma unroll + for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 1; ++ax0_ax1_fused_0) { + if (((((int)threadIdx.z) * 2) + ((int)threadIdx.y)) < 3) { + + { + unsigned int addr; +#if TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST + addr = static_cast(__cvta_generic_to_shared((void *)(B_shared + (((((int)threadIdx.z) * 1024) + (((int)threadIdx.y) * 512)) + (((int)threadIdx.x) * 16))))); +#else + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)(B_shared + (((((int)threadIdx.z) * 1024) + (((int)threadIdx.y) * 512)) + (((int)threadIdx.x) * 16)))) + ); +#endif + __asm__ __volatile__( + #if TVM_ENABLE_L2_PREFETCH + "cp.async.cg.shared.global.L2::128B [%0], [%1], %2;" + #else + "cp.async.cg.shared.global [%0], [%1], %2;" + #endif + :: "r"(addr), "l"((void*)(B + ((((((int)blockIdx.x) * 245760) + (((int)threadIdx.z) * 163840)) + (((int)threadIdx.y) * 81920)) + (((int)threadIdx.x) * 2560)))), "n"(16) + ); + } + } + } +__asm__ __volatile__("cp.async.commit_group;"); + + for (int ax3_0_0 = 0; ax3_0_0 < 159; ++ax3_0_0) { + __syncthreads(); + #pragma unroll + for (int ax0_ax1_ax2_fused_0_1 = 0; ax0_ax1_ax2_fused_0_1 < 2; ++ax0_ax1_ax2_fused_0_1) { + + { + unsigned int addr; +#if TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST + addr = static_cast(__cvta_generic_to_shared((void *)(A_reindex_shared + ((((((((ax3_0_0 + 1) & 1) * 2048) + (ax0_ax1_ax2_fused_0_1 * 1024)) + (((int)threadIdx.y) * 512)) + (((int)threadIdx.z) * 256)) + ((((int)threadIdx.x) >> 2) * 32)) + (((((int)threadIdx.x) & 3) ^ (((int)threadIdx.x) >> 3)) * 8))))); +#else + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)(A_reindex_shared + ((((((((ax3_0_0 + 1) & 1) * 2048) + (ax0_ax1_ax2_fused_0_1 * 1024)) + (((int)threadIdx.y) * 512)) + (((int)threadIdx.z) * 256)) + ((((int)threadIdx.x) >> 2) * 32)) + (((((int)threadIdx.x) & 3) ^ (((int)threadIdx.x) >> 3)) * 8)))) + ); +#endif + __asm__ __volatile__( + #if TVM_ENABLE_L2_PREFETCH + "cp.async.cg.shared.global.L2::128B [%0], [%1], %2;" + #else + "cp.async.cg.shared.global [%0], [%1], %2;" + #endif + :: "r"(addr), "l"((void*)(A + ((((((((((int)blockIdx.y) * 327680) + (ax0_ax1_ax2_fused_0_1 * 163840)) + (((int)threadIdx.y) * 81920)) + (((int)threadIdx.z) * 40960)) + ((((int)threadIdx.x) >> 2) * 5120)) + (ax3_0_0 * 32)) + ((((int)threadIdx.x) & 3) * 8)) + 32))), "n"(16) + ); + } + } + #pragma unroll + for (int ax0_ax1_fused_0_1 = 0; ax0_ax1_fused_0_1 < 1; ++ax0_ax1_fused_0_1) { + if (((((int)threadIdx.z) * 2) + ((int)threadIdx.y)) < 3) { + + { + unsigned int addr; +#if TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST + addr = static_cast(__cvta_generic_to_shared((void *)(B_shared + ((((((ax3_0_0 + 1) & 1) * 2048) + (((int)threadIdx.z) * 1024)) + (((int)threadIdx.y) * 512)) + (((int)threadIdx.x) * 16))))); +#else + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)(B_shared + ((((((ax3_0_0 + 1) & 1) * 2048) + (((int)threadIdx.z) * 1024)) + (((int)threadIdx.y) * 512)) + (((int)threadIdx.x) * 16)))) + ); +#endif + __asm__ __volatile__( + #if TVM_ENABLE_L2_PREFETCH + "cp.async.cg.shared.global.L2::128B [%0], [%1], %2;" + #else + "cp.async.cg.shared.global [%0], [%1], %2;" + #endif + :: "r"(addr), "l"((void*)(B + ((((((((int)blockIdx.x) * 245760) + (((int)threadIdx.z) * 163840)) + (((int)threadIdx.y) * 81920)) + (((int)threadIdx.x) * 2560)) + (ax3_0_0 * 16)) + 16))), "n"(16) + ); + } + } + } +__asm__ __volatile__("cp.async.commit_group;"); + +__asm__ __volatile__("cp.async.wait_group 1;"); + + __syncthreads(); + for (int ax1_ax2_0_fused_0 = 0; ax1_ax2_0_fused_0 < 3; ++ax1_ax2_0_fused_0) { + B_local[0] = *(int*)(B_shared + ((((((ax3_0_0 & 1) * 2048) + (ax1_ax2_0_fused_0 * 512)) + (((int)threadIdx.y) * 256)) + (((int)threadIdx.z) * 128)) + (((int)threadIdx.x) * 4))); + decode_i4s_to_f16_scale(B_local, B_decode_reindex_local, (&(Scale[((((((((int)blockIdx.x) * 3840) + (ax1_ax2_0_fused_0 * 1280)) + (((int)threadIdx.y) * 640)) + (((int)threadIdx.z) * 320)) + ((((int)threadIdx.x) >> 2) * 40)) + (ax3_0_0 >> 2))])), 8); + *(uint4*)(B_decode_reindex_shared + (((((ax1_ax2_0_fused_0 * 1024) + (((int)threadIdx.y) * 512)) + (((int)threadIdx.z) * 256)) + ((((int)threadIdx.x) >> 2) * 32)) + (((((int)threadIdx.x) & 3) ^ (((int)threadIdx.x) >> 3)) * 8))) = B_decode_reindex_local[0]; + } + __syncthreads(); + for (int ax3_0_1 = 0; ax3_0_1 < 2; ++ax3_0_1) { + for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0) { + + { + unsigned int addr; +#if TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST + addr = static_cast(__cvta_generic_to_shared((void *)((&(A_reindex_shared[((((((ax3_0_0 & 1) * 2048) + (((int)threadIdx.y) * 1024)) + (ax1_0 * 512)) + ((((int)threadIdx.x) & 15) * 32)) + ((((ax3_0_1 * 2) + (((int)threadIdx.x) >> 4)) ^ ((((int)threadIdx.x) & 7) >> 1)) * 8))])) + 0))); +#else + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)((&(A_reindex_shared[((((((ax3_0_0 & 1) * 2048) + (((int)threadIdx.y) * 1024)) + (ax1_0 * 512)) + ((((int)threadIdx.x) & 15) * 32)) + ((((ax3_0_1 * 2) + (((int)threadIdx.x) >> 4)) ^ ((((int)threadIdx.x) & 7) >> 1)) * 8))])) + 0)) + ); +#endif + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(A_reindex_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(A_reindex_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(A_reindex_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(A_reindex_shared_warp + (ax1_0 * 8)))[3]) + : "r"(addr) + ); + } + } + for (int ax1_0_1 = 0; ax1_0_1 < 3; ++ax1_0_1) { + + { + unsigned int addr; +#if TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST + addr = static_cast(__cvta_generic_to_shared((void *)((&(B_decode_reindex_shared[(((((((int)threadIdx.z) * 1536) + (ax1_0_1 * 512)) + ((((int)threadIdx.x) >> 4) * 256)) + ((((int)threadIdx.x) & 7) * 32)) + ((((ax3_0_1 * 2) + ((((int)threadIdx.x) & 15) >> 3)) ^ ((((int)threadIdx.x) & 7) >> 1)) * 8))])) + 0))); +#else + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)((&(B_decode_reindex_shared[(((((((int)threadIdx.z) * 1536) + (ax1_0_1 * 512)) + ((((int)threadIdx.x) >> 4) * 256)) + ((((int)threadIdx.x) & 7) * 32)) + ((((ax3_0_1 * 2) + ((((int)threadIdx.x) & 15) >> 3)) ^ ((((int)threadIdx.x) & 7) >> 1)) * 8))])) + 0)) + ); +#endif + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(B_decode_reindex_shared_warp + (ax1_0_1 * 8)))[0]), "=r"(((unsigned *)(B_decode_reindex_shared_warp + (ax1_0_1 * 8)))[1]), "=r"(((unsigned *)(B_decode_reindex_shared_warp + (ax1_0_1 * 8)))[2]), "=r"(((unsigned *)(B_decode_reindex_shared_warp + (ax1_0_1 * 8)))[3]) + : "r"(addr) + ); + } + } + for (int ax1_0_3 = 0; ax1_0_3 < 2; ++ax1_0_3) { + for (int ax2_0_3 = 0; ax2_0_3 < 3; ++ax2_0_3) { + + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16" + "{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n" + : "=r"(((unsigned *)(C_reindex_shared_warp + ((ax1_0_3 * 24) + (ax2_0_3 * 8))))[0]), "=r"(((unsigned *)(C_reindex_shared_warp + ((ax1_0_3 * 24) + (ax2_0_3 * 8))))[1]) + : "r"(((unsigned *)(A_reindex_shared_warp + (ax1_0_3 * 8)))[0]), "r"(((unsigned *)(A_reindex_shared_warp + (ax1_0_3 * 8)))[1]), "r"(((unsigned *)(A_reindex_shared_warp + (ax1_0_3 * 8)))[2]), "r"(((unsigned *)(A_reindex_shared_warp + (ax1_0_3 * 8)))[3]), "r"(((unsigned *)(B_decode_reindex_shared_warp + (ax2_0_3 * 8)))[0]), "r"(((unsigned *)(B_decode_reindex_shared_warp + (ax2_0_3 * 8)))[1]), "r"(((unsigned *)(C_reindex_shared_warp + ((ax1_0_3 * 24) + (ax2_0_3 * 8))))[0]), "r"(((unsigned *)(C_reindex_shared_warp + ((ax1_0_3 * 24) + (ax2_0_3 * 8))))[1])); + } + + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16" + "{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n" + : "=r"(((unsigned *)(C_reindex_shared_warp + (((ax1_0_3 * 24) + (ax2_0_3 * 8)) + 4)))[0]), "=r"(((unsigned *)(C_reindex_shared_warp + (((ax1_0_3 * 24) + (ax2_0_3 * 8)) + 4)))[1]) + : "r"(((unsigned *)(A_reindex_shared_warp + (ax1_0_3 * 8)))[0]), "r"(((unsigned *)(A_reindex_shared_warp + (ax1_0_3 * 8)))[1]), "r"(((unsigned *)(A_reindex_shared_warp + (ax1_0_3 * 8)))[2]), "r"(((unsigned *)(A_reindex_shared_warp + (ax1_0_3 * 8)))[3]), "r"(((unsigned *)(B_decode_reindex_shared_warp + ((ax2_0_3 * 8) + 4)))[0]), "r"(((unsigned *)(B_decode_reindex_shared_warp + ((ax2_0_3 * 8) + 4)))[1]), "r"(((unsigned *)(C_reindex_shared_warp + (((ax1_0_3 * 24) + (ax2_0_3 * 8)) + 4)))[0]), "r"(((unsigned *)(C_reindex_shared_warp + (((ax1_0_3 * 24) + (ax2_0_3 * 8)) + 4)))[1])); + } + } + } + } + } +__asm__ __volatile__("cp.async.wait_group 0;"); + + __syncthreads(); + for (int ax1_ax2_0_fused_0_1 = 0; ax1_ax2_0_fused_0_1 < 3; ++ax1_ax2_0_fused_0_1) { + B_local_1[0] = *(int*)(B_shared + (((((ax1_ax2_0_fused_0_1 * 512) + (((int)threadIdx.y) * 256)) + (((int)threadIdx.z) * 128)) + (((int)threadIdx.x) * 4)) + 2048)); + decode_i4s_to_f16_scale(B_local_1, B_decode_reindex_local_1, (&(Scale[((((((((int)blockIdx.x) * 3840) + (ax1_ax2_0_fused_0_1 * 1280)) + (((int)threadIdx.y) * 640)) + (((int)threadIdx.z) * 320)) + ((((int)threadIdx.x) >> 2) * 40)) + 39)])), 8); + *(uint4*)(B_decode_reindex_shared + (((((ax1_ax2_0_fused_0_1 * 1024) + (((int)threadIdx.y) * 512)) + (((int)threadIdx.z) * 256)) + ((((int)threadIdx.x) >> 2) * 32)) + (((((int)threadIdx.x) & 3) ^ (((int)threadIdx.x) >> 3)) * 8))) = B_decode_reindex_local_1[0]; + } + __syncthreads(); + for (int ax3_0_1_1 = 0; ax3_0_1_1 < 2; ++ax3_0_1_1) { + for (int ax1_0_2 = 0; ax1_0_2 < 2; ++ax1_0_2) { + + { + unsigned int addr; +#if TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST + addr = static_cast(__cvta_generic_to_shared((void *)((&(A_reindex_shared[(((((((int)threadIdx.y) * 1024) + (ax1_0_2 * 512)) + ((((int)threadIdx.x) & 15) * 32)) + ((((ax3_0_1_1 * 2) + (((int)threadIdx.x) >> 4)) ^ ((((int)threadIdx.x) & 7) >> 1)) * 8)) + 2048)])) + 0))); +#else + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)((&(A_reindex_shared[(((((((int)threadIdx.y) * 1024) + (ax1_0_2 * 512)) + ((((int)threadIdx.x) & 15) * 32)) + ((((ax3_0_1_1 * 2) + (((int)threadIdx.x) >> 4)) ^ ((((int)threadIdx.x) & 7) >> 1)) * 8)) + 2048)])) + 0)) + ); +#endif + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(A_reindex_shared_warp_1 + (ax1_0_2 * 8)))[0]), "=r"(((unsigned *)(A_reindex_shared_warp_1 + (ax1_0_2 * 8)))[1]), "=r"(((unsigned *)(A_reindex_shared_warp_1 + (ax1_0_2 * 8)))[2]), "=r"(((unsigned *)(A_reindex_shared_warp_1 + (ax1_0_2 * 8)))[3]) + : "r"(addr) + ); + } + } + for (int ax1_0_4 = 0; ax1_0_4 < 3; ++ax1_0_4) { + + { + unsigned int addr; +#if TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST + addr = static_cast(__cvta_generic_to_shared((void *)((&(B_decode_reindex_shared[(((((((int)threadIdx.z) * 1536) + (ax1_0_4 * 512)) + ((((int)threadIdx.x) >> 4) * 256)) + ((((int)threadIdx.x) & 7) * 32)) + ((((ax3_0_1_1 * 2) + ((((int)threadIdx.x) & 15) >> 3)) ^ ((((int)threadIdx.x) & 7) >> 1)) * 8))])) + 0))); +#else + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)((&(B_decode_reindex_shared[(((((((int)threadIdx.z) * 1536) + (ax1_0_4 * 512)) + ((((int)threadIdx.x) >> 4) * 256)) + ((((int)threadIdx.x) & 7) * 32)) + ((((ax3_0_1_1 * 2) + ((((int)threadIdx.x) & 15) >> 3)) ^ ((((int)threadIdx.x) & 7) >> 1)) * 8))])) + 0)) + ); +#endif + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(B_decode_reindex_shared_warp_1 + (ax1_0_4 * 8)))[0]), "=r"(((unsigned *)(B_decode_reindex_shared_warp_1 + (ax1_0_4 * 8)))[1]), "=r"(((unsigned *)(B_decode_reindex_shared_warp_1 + (ax1_0_4 * 8)))[2]), "=r"(((unsigned *)(B_decode_reindex_shared_warp_1 + (ax1_0_4 * 8)))[3]) + : "r"(addr) + ); + } + } + for (int ax1_0_3_1 = 0; ax1_0_3_1 < 2; ++ax1_0_3_1) { + for (int ax2_0_3_1 = 0; ax2_0_3_1 < 3; ++ax2_0_3_1) { + + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16" + "{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n" + : "=r"(((unsigned *)(C_reindex_shared_warp + ((ax1_0_3_1 * 24) + (ax2_0_3_1 * 8))))[0]), "=r"(((unsigned *)(C_reindex_shared_warp + ((ax1_0_3_1 * 24) + (ax2_0_3_1 * 8))))[1]) + : "r"(((unsigned *)(A_reindex_shared_warp_1 + (ax1_0_3_1 * 8)))[0]), "r"(((unsigned *)(A_reindex_shared_warp_1 + (ax1_0_3_1 * 8)))[1]), "r"(((unsigned *)(A_reindex_shared_warp_1 + (ax1_0_3_1 * 8)))[2]), "r"(((unsigned *)(A_reindex_shared_warp_1 + (ax1_0_3_1 * 8)))[3]), "r"(((unsigned *)(B_decode_reindex_shared_warp_1 + (ax2_0_3_1 * 8)))[0]), "r"(((unsigned *)(B_decode_reindex_shared_warp_1 + (ax2_0_3_1 * 8)))[1]), "r"(((unsigned *)(C_reindex_shared_warp + ((ax1_0_3_1 * 24) + (ax2_0_3_1 * 8))))[0]), "r"(((unsigned *)(C_reindex_shared_warp + ((ax1_0_3_1 * 24) + (ax2_0_3_1 * 8))))[1])); + } + + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16" + "{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n" + : "=r"(((unsigned *)(C_reindex_shared_warp + (((ax1_0_3_1 * 24) + (ax2_0_3_1 * 8)) + 4)))[0]), "=r"(((unsigned *)(C_reindex_shared_warp + (((ax1_0_3_1 * 24) + (ax2_0_3_1 * 8)) + 4)))[1]) + : "r"(((unsigned *)(A_reindex_shared_warp_1 + (ax1_0_3_1 * 8)))[0]), "r"(((unsigned *)(A_reindex_shared_warp_1 + (ax1_0_3_1 * 8)))[1]), "r"(((unsigned *)(A_reindex_shared_warp_1 + (ax1_0_3_1 * 8)))[2]), "r"(((unsigned *)(A_reindex_shared_warp_1 + (ax1_0_3_1 * 8)))[3]), "r"(((unsigned *)(B_decode_reindex_shared_warp_1 + ((ax2_0_3_1 * 8) + 4)))[0]), "r"(((unsigned *)(B_decode_reindex_shared_warp_1 + ((ax2_0_3_1 * 8) + 4)))[1]), "r"(((unsigned *)(C_reindex_shared_warp + (((ax1_0_3_1 * 24) + (ax2_0_3_1 * 8)) + 4)))[0]), "r"(((unsigned *)(C_reindex_shared_warp + (((ax1_0_3_1 * 24) + (ax2_0_3_1 * 8)) + 4)))[1])); + } + } + } + } + for (int ax0 = 0; ax0 < 2; ++ax0) { + for (int ax1 = 0; ax1 < 3; ++ax1) { + __syncthreads(); + for (int local_id = 0; local_id < 8; local_id+=2) { +*((uint *)&(&(B_decode_reindex_shared[((((int)threadIdx.y) * 2048) + (((int)threadIdx.z) * 768))]))[((((((local_id % 4) / 2) * 8) + (threadIdx.x / 4)) * 16) + ((((local_id / 4) * 8) + ((threadIdx.x % 4) * 2)) + (local_id % 2)))]) = *((uint *)&C_reindex_shared_warp[((ax0 * 24) + (ax1 * 8)) + local_id]); +} +; + __syncthreads(); + #pragma unroll + for (int ax0_ax1_ax2_ax3_ax4_fused_0 = 0; ax0_ax1_ax2_ax3_ax4_fused_0 < 1; ++ax0_ax1_ax2_ax3_ax4_fused_0) { + *(uint4*)(D + ((((((((((int)blockIdx.y) * 983040) + (((int)threadIdx.y) * 491520)) + (ax0 * 245760)) + ((((int)threadIdx.x) >> 1) * 15360)) + (((int)blockIdx.x) * 96)) + (((int)threadIdx.z) * 48)) + (ax1 * 16)) + ((((int)threadIdx.x) & 1) * 8))) = *(uint4*)(B_decode_reindex_shared + (((((int)threadIdx.y) * 2048) + (((int)threadIdx.z) * 768)) + (((int)threadIdx.x) * 8))); + } + } + } + } +} + + + + + +int ladder_gemm_fp16xint2_fp16(half *input_0, half *input_1, half *output, const int M, const int N, const int K, const int trans_a, const int trans_b, half *workspace_ptr) +{ + assert(trans_a == 0 && trans_b == 1); + + if (M == 1 && N == 15360 && K == 5120){ + + const dim3 GridDim(7680, 1, 1); + const dim3 BlockDim(64, 2, 1); + bitblas_kernel_fp16_int2_fp16_m1n15360k5120_nt<<>>(input_0, input_1, output); + + return 0; + } + + + if (M == 128 && N == 15360 && K == 5120){ + + const dim3 GridDim(160, 2, 1); + const dim3 BlockDim(1, 2, 2); + bitblas_kernel_fp16_int2_fp16_m128n15360k5120_nt<<>>(input_0, input_1, output); + + return 0; + } + + + return -1; +} \ No newline at end of file diff --git a/integration/fastertransformer/kenrel_output/ladder_kernel.h b/integration/fastertransformer/kenrel_output/ladder_kernel.h new file mode 100644 index 000000000000..ab1c284c898d --- /dev/null +++ b/integration/fastertransformer/kenrel_output/ladder_kernel.h @@ -0,0 +1,206 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +#ifndef __LADDER_KERNEL_H__ +#define __LADDER_KERNEL_H__ +#include +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 800) +#define TVM_ENABLE_L2_PREFETCH 1 +#else +#define TVM_ENABLE_L2_PREFETCH 0 +#endif + +#ifdef _WIN32 + using uint = unsigned int; + using uchar = unsigned char; + using ushort = unsigned short; + using int64_t = long long; + using uint64_t = unsigned long long; +#else + #define uint unsigned int + #define uchar unsigned char + #define ushort unsigned short + #define int64_t long long + #define uint64_t unsigned long long +#endif + + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 800) +#define TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST 1 +#else +#define TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST 0 +#endif + + +template +__device__ void decode_i4b_to_f16(T1 *_i4s, T2 *B_local_decode, const int N = 8) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x000f000f; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 0x64076407 : 0x64006400; + uint const i4s = *reinterpret_cast(_i4s); +#pragma unroll + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[i]) + : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + } +} + +template +__device__ void decode_i4s_to_f16(T1 *_i4s, T2 *B_local_decode, const int N = 8) +{ + decode_i4b_to_f16(_i4s, B_local_decode, N); +} + +template +__device__ void decode_i4u_to_f16(T1 *_i4u, T2 *B_local_decode, const int N = 8) +{ + decode_i4b_to_f16(_i4u, B_local_decode, N); +} +, + + +template +__device__ void decode_i4b_to_f16_scale(T1 *_i4s, T2 *B_local_decode, T3 *scale, const int N = 8) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x000f000f; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 0x64076407 : 0x64006400; + uint const i4s = *reinterpret_cast(_i4s); +#pragma unroll + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[i]) + : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + unsigned v0 = *((unsigned short *)scale); + unsigned v1 = *((unsigned short *)scale); + unsigned __packed_scale = (v1 << 16) | v0; + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[i]) : "r"(h[i]), "r"(__packed_scale), "r"(0)); + } + +} + +template +__device__ void decode_i4s_to_f16_scale(T1 *_i4s, T2 *B_local_decode, T3 *scale, const int N = 8) +{ + decode_i4b_to_f16_scale(_i4s, B_local_decode, scale, N); +} + +template +__device__ void decode_i4u_to_f16_scale(T1 *_i4u, T2 *B_local_decode, T3 *scale, const int N = 8) +{ + decode_i4b_to_f16_scale(_i4u, B_local_decode, scale, N); +} +, + + +template +__device__ void decode_i2b_to_f16(T1 *_i2s, T2 *B_local_decode, const int N = 8) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00030003; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 0x64016401 : 0x64006400; + int16_t const i2s_i16 = *reinterpret_cast(_i2s); + // decode 2 elems at one time. + // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0} + // otherwise the pointer of _i2s should be moved to + int i2s = (i2s_i16 & 0x00ff); + i2s |= ((i2s_i16 & 0xff00) << 8); + +#pragma unroll + for (int i = 0; i < (N / 2); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[i]) + : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + } +} + +template +__device__ void decode_i2s_to_f16(T1 *_i2s, T2 *B_local_decode, const int N = 8) +{ + decode_i2b_to_f16(_i2s, B_local_decode, N); +} + +template +__device__ void decode_i2u_to_f16(T1 *_i2u, T2 *B_local_decode, const int N = 8) +{ + decode_i2b_to_f16(_i2u, B_local_decode, N); +} +, + + +template +__device__ void decode_i2b_to_f16_scale(T1 *_i2s, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00030003; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 0x64016401 : 0x64006400; + int16_t const i2s_i16 = *reinterpret_cast(_i2s); + // decode 2 elems at one time. + // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0} + // otherwise the pointer of _i2s should be moved to + int i2s = (i2s_i16 & 0x00ff); + i2s |= ((i2s_i16 & 0xff00) << 8); + +#pragma unroll + for (int i = 0; i < (N / 2); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[i]) + : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*scale, *scale)), "r"(0)); + } +} + +template +__device__ void decode_i2s_to_f16_scale(T1 *_i2s, T2 *B_local_decode, T3 *scale, const int N = 8) +{ + decode_i2b_to_f16_scale(_i2s, B_local_decode, scale, N); +} + +template +__device__ void decode_i2u_to_f16_scale(T1 *_i2u, T2 *B_local_decode, T3 *scale, const int N = 8) +{ + decode_i2b_to_f16_scale(_i2u, B_local_decode, scale, N); +} +, + + +// Pack two half values. +static inline __device__ __host__ unsigned +__pack_half2(const half x, const half y) { + unsigned v0 = *((unsigned short *)&x); + unsigned v1 = *((unsigned short *)&y); + return (v1 << 16) | v0; +} + + + +int ladder_gemm_fp16xint2_fp16(half *input_0, half *input_1, half *output, const int M, const int N, const int K, const int trans_a, const int trans_b, half *workspace_ptr); + +#endif + + \ No newline at end of file diff --git a/integration/fastertransformer/kernel_generator.py b/integration/fastertransformer/kernel_generator.py new file mode 100644 index 000000000000..c8590c415a68 --- /dev/null +++ b/integration/fastertransformer/kernel_generator.py @@ -0,0 +1,224 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from string import Template +import os +import tvm +from tvm import IRModule +from tvm.target import Target +from bitblas.utils import match_global_kernel +from bitblas.base.analysis import get_reduction_blocks +from bitblas.ops import Operator +from bitblas.ops.matmul_dequantize import ( + MatmulWeightOnlyDequantize, + MatmulWeightOnlyDequantizeConfig, +) +from bitblas.gpu.intrin.lop3 import ( + decode_i2_to_f16, + decode_i2_to_f16_scale, + decode_i4_to_f16, + decode_i4_to_f16_scale, +) +bit = 2 +mask = (1 << bit) - 1 +group_size = 128 + + +ft_shapes = [ + # [1, 5120, 5120], + [1, 15360, 5120], + [128, 15360, 5120], +] + + +target = tvm.target.Target("nvidia/nvidia-a100") + + +def get_template_path(): + cur_dir = os.path.dirname(os.path.abspath(__file__)) + return os.path.join( + cur_dir, f"template/kernel_template.int{bit}.bitblas.cu.template" + ) + + +template_path = get_template_path() + + +def get_codegen_result(ops: Operator, target: Target): + code = ops.codegen(target=target) + return code + + +def get_thread_block_infomation(mod: IRModule): + sch = tvm.tir.Schedule(mod) + root_block = sch.get_block("root") + child_blocks = sch.get_child_blocks(root_block) + reduction_blocks = get_reduction_blocks(sch, child_blocks) + assert len(reduction_blocks) == 1 + (main_block,) = reduction_blocks + loops = sch.get_loops(main_block) + block_info = [1, 1, 1] + grid_info = [1, 1, 1] + for loop in loops: + stmt = sch.get(loop) + thread_binding = stmt.thread_binding + extent = int(stmt.extent) + if thread_binding is None: + continue + if thread_binding.thread_tag == "threadIdx.x": + block_info[0] = extent + elif thread_binding.thread_tag == "threadIdx.y": + block_info[1] = extent + elif thread_binding.thread_tag == "threadIdx.z": + block_info[2] = extent + elif thread_binding.thread_tag == "blockIdx.x": + grid_info[0] = extent + elif thread_binding.thread_tag == "blockIdx.y": + grid_info[1] = extent + elif thread_binding.thread_tag == "blockIdx.z": + grid_info[2] = extent + return block_info, grid_info + +kernel_body = "" +kernel_call = "" +for M, N, K in ft_shapes: + matmul_config = MatmulWeightOnlyDequantizeConfig( + M=M, + N=N, + K=K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + storage_dtype="int8", + source_format="int", + with_scaling=True, + group_size=group_size, + fast_decoding=True, + with_bias=False, + propagate_a=False, + propagate_b=False, + layout="nt", + ) + matmul = MatmulWeightOnlyDequantize( + config=matmul_config, + target=target, + ) + matmul.hardware_aware_finetune(topk=10) + code = get_codegen_result(matmul, target) + index = match_global_kernel(code) + headers = code[:index] + headers.replace('extern "C" ', "") + declarations = code[index:].split(";")[0] + index = code.index("{", index) + + function_body = declarations + code[index:] + # get block infomation from mod + block_size, grid_size = get_thread_block_infomation(matmul.optimized_func) + + new_kernel_name = ( + f"bitblas_kernel_fp16_int{bit}_fp16_m{M}n{N}k{K}_nt" + ) + Qweight_bytes = N * K // 8 * bit + function_body = function_body.replace("main_kernel", new_kernel_name) + call = f""" + // const dim3 GridDim({grid_size[0]}, {grid_size[1]}, {grid_size[2]}); + // const dim3 BlockDim({block_size[0]}, {block_size[1]}, {block_size[2]}); + // {new_kernel_name}<<>>(input_0, input_1, output); + """ + function_body = function_body.replace( + "(half* __restrict__ A, signed char* __restrict__ B, half* __restrict__ D, half* __restrict__ Scale){", + f"(half* __restrict__ A, half* __restrict__ QB, half* __restrict__ D) {{\n\ + signed char* B = ((int8_t *)QB);\n\t half* Scale = (half *)((int8_t *)QB + {Qweight_bytes}); \ + {call}", + ) + kernel_body += function_body + kernel_body += "\n\n" + real_call = call.replace("//", "") + real_call = f""" + if (M == {M} && N == {N} && K == {K}){{ + {real_call} + return 0; + }} + + """ + kernel_call += real_call + + +# make output +cur_dir = os.path.dirname(os.path.abspath(__file__)) +ladder_path = os.path.join(cur_dir, f"kenrel_output") +if not os.path.exists(ladder_path): + os.makedirs(ladder_path) +ladder_kernel_path = os.path.join(ladder_path, f"ladder_kernel.cu") +ladder_header_path = os.path.join(ladder_path, f"ladder_kernel.h") + +with open(template_path, mode="r", encoding="utf-8") as r_f, open( + ladder_kernel_path, mode="w", encoding="utf8" +) as w_f: + template_content = r_f.read() + template = Template(template_content) + data = template.substitute(kernel_body=kernel_body, kernel_call=kernel_call) + w_f.write(data) + +pack_half2 = """ +// Pack two half values. +static inline __device__ __host__ unsigned +__pack_half2(const half x, const half y) { + unsigned v0 = *((unsigned short *)&x); + unsigned v1 = *((unsigned short *)&y); + return (v1 << 16) | v0; +} +""" +with open( + ladder_header_path, mode="w", encoding="utf8" +) as w_f: + headers = f"""// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +#ifndef __LADDER_KERNEL_H__ +#define __LADDER_KERNEL_H__ +#include +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 800) +#define TVM_ENABLE_L2_PREFETCH 1 +#else +#define TVM_ENABLE_L2_PREFETCH 0 +#endif + +#ifdef _WIN32 + using uint = unsigned int; + using uchar = unsigned char; + using ushort = unsigned short; + using int64_t = long long; + using uint64_t = unsigned long long; +#else + #define uint unsigned int + #define uchar unsigned char + #define ushort unsigned short + #define int64_t long long + #define uint64_t unsigned long long +#endif + + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 800) +#define TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST 1 +#else +#define TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST 0 +#endif + +{decode_i4_to_f16}, + +{decode_i4_to_f16_scale}, + +{decode_i2_to_f16}, + +{decode_i2_to_f16_scale}, + +{pack_half2} + + +int ladder_gemm_fp16xint{bit}_fp16(half *input_0, half *input_1, half *output, const int M, const int N, const int K, const int trans_a, const int trans_b, half *workspace_ptr); + +#endif + + """ + w_f.write(headers) diff --git a/integration/fastertransformer/template/kernel_template.int2.bitblas.cu.template b/integration/fastertransformer/template/kernel_template.int2.bitblas.cu.template new file mode 100644 index 000000000000..2ca946133f9a --- /dev/null +++ b/integration/fastertransformer/template/kernel_template.int2.bitblas.cu.template @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +#include +#include +#include "ladder_kernel.h" +#include "mma.h" +// nvcc ladder_kernel.cu -gencode arch=compute_80,code=sm_80 + +${kernel_body} + +int ladder_gemm_fp16xint2_fp16(half *input_0, half *input_1, half *output, const int M, const int N, const int K, const int trans_a, const int trans_b, half *workspace_ptr) +{ + assert(trans_a == 0 && trans_b == 1); + ${kernel_call} + return -1; +} \ No newline at end of file diff --git a/integration/fastertransformer/template/kernel_template.int4.bitblas.cu.template b/integration/fastertransformer/template/kernel_template.int4.bitblas.cu.template new file mode 100644 index 000000000000..e8298d366fd0 --- /dev/null +++ b/integration/fastertransformer/template/kernel_template.int4.bitblas.cu.template @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +#include +#include +#include "ladder_kernel.h" +#include "mma.h" +// nvcc ladder_kernel.cu -gencode arch=compute_80,code=sm_80 + +${kernel_body} + +int ladder_gemm_fp16xint4_fp16(half *input_0, half *input_1, half *output, const int M, const int N, const int K, const int trans_a, const int trans_b, half *workspace_ptr) +{ + assert(trans_a == 0 && trans_b == 1); + ${kernel_call} + return -1; +} \ No newline at end of file diff --git a/integration/pytorch/quant_linear.py b/integration/pytorch/quant_linear.py new file mode 100644 index 000000000000..99fd101619fb --- /dev/null +++ b/integration/pytorch/quant_linear.py @@ -0,0 +1,157 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from logging import getLogger + +import numpy as np +import torch +import torch.nn as nn + + +logger = getLogger(__name__) + +try: + import bitblas +except ImportError as e: + bitblas_import_exception = e + + def error_raiser_bitblas(*args, **kwargs): + raise ValueError( + f"Trying to use the bitblas backend, but could not import dependencies with the following error: {bitblas_import_exception}" + ) + + autogptq_bitblas_cuda = bitblas_import_exception + +from bitblas.quantization.utils import general_compress, interleave_weight +from bitblas.ops.matmul import MatmulWeightOnlyDequantize + + +class QuantLinear(nn.Module): + QUANT_TYPE = "bitblas" + + def __init__( + self, bits, group_size, infeatures, outfeatures, bias, trainable=False, **kwargs + ): + super().__init__() + if infeatures % 128 != 0 or outfeatures % 256 != 0: + raise ValueError( + "`infeatures` must be divisible by 128 and `outfeatures` by 256." + ) + if bits not in [1, 2, 4]: + raise NotImplementedError("Only 1/2/4 bits are supported.") + if infeatures % group_size != 0: + raise ValueError("`infeatures` must be divisible by `group_size`.") + if trainable: + raise NotImplementedError("Bitblas does not support train.") + + self.bits = bits + storage_nbit = 8 # assume int8 storage + n_float_per_elem = storage_nbit // bits + self.infeatures = infeatures + self.outfeatures = outfeatures + self.group_size = group_size if group_size != -1 else infeatures + self.register_buffer( + "qweight", + torch.empty( + (self.outfeatures, self.infeatures // storage_nbit * n_float_per_elem), + dtype=torch.uint8, + ), + ) + self.register_buffer( + "scales", + torch.empty( + (self.outfeatures, self.infeatures // self.group_size), dtype=torch.half + ), + ) + if bias: + self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.half)) + else: + self.bias = None + + self.fast_type_conversion = False + self.weight_propagation = False + + # optimize target shapes for dynamic symbolic + OPTIMIZE_M_RANGE = [1, 16, 32] + self.bitblas_matmul = MatmulWeightOnlyDequantize( + M=1, + N=outfeatures, + K=infeatures, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + propagate_b=self.weight_propagation, + bit=self.bits, + storage_dtype="uint8", + source_format="int", + with_scaling=True, + group_size=self.group_size, + fast_decoding=self.fast_type_conversion, + with_bias=bias, + ) + # self.bitblas_matmul.optimize(topk=20) + + def post_init(self): + pass + + def pack(self, linear, scales): + """Pack a fake-quantized linear layer into this actual Bitblas representation. + @linear: fake-quantized `torch.nn.Linear` layer to convert (must be of type `torch.half`) + @scales: corresponding quantization scales of shape `(infeatures, groups)` + """ + if linear.weight.dtype != torch.half: + raise ValueError("Only `torch.half` weights are supported.") + + # do permutation with (n, k) layout + w = linear.weight.data + # scales shape should be (n, k) as well. + s = scales + # do permutation on weight + intweight = [] + for idx in range(self.infeatures): + g_idx = idx // self.group_size + intweight.append( + torch.round((w[:, idx]) / scales[:, g_idx]).to(torch.int)[:, None] + ) + intweight = torch.cat(intweight, dim=1) + intweight = intweight.contiguous() + intweight = intweight.cpu().numpy().astype(np.int8) + print("bitblas dequantize weight is ") + print(intweight) + intweight = intweight + 7 + print("bitblas dequantize weight +7 is ") + print(intweight) + # quantize to 4bit + qw_np = general_compress( + intweight, source_bits=self.bits, storage_dtype=np.uint8 + ) + # do interleave for fast type conversion + if self.fast_type_conversion: + qw_np = interleave_weight(qw_np, nbits=self.bits, target_dtype="float16") + if self.weight_propagation: + # do permutation on weight + pass + + q = torch.from_numpy(qw_np).to(w.device) + self.qweight = q.to(self.qweight.device).contiguous() + self.scales = s.to(self.scales.device).contiguous() + + if self.bias is not None: + self.bias[:] = linear.bias.data.to(self.bias.device).contiguous() + + def forward(self, A): + A = A.half() + C = torch.empty( + A.shape[:-1] + (self.qweight.shape[0],), dtype=A.dtype, device=A.device + ) + args = [A, self.qweight, self.scales] + if self.bias is not None: + args.append(self.bias) + args.append(C) + + self.bitblas_matmul(*args) + + return C + + +__all__ = ["QuantLinear"] diff --git a/integration/pytorch/test_quant_linear.py b/integration/pytorch/test_quant_linear.py new file mode 100644 index 000000000000..7b7699bae9be --- /dev/null +++ b/integration/pytorch/test_quant_linear.py @@ -0,0 +1,115 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from quant_linear import QuantLinear +import copy +import torch +import torch.nn as nn + +# !pip install auto-gptq +from auto_gptq.nn_modules.qlinear.qlinear_cuda_old import ( + QuantLinear as CudaOldQuantLinear, +) + +torch.manual_seed(0) + + +def gen_quant4(k, n, groupsize=-1): + maxq = 2**4 - 1 + w = torch.randn((k, n), dtype=torch.half, device="cpu") + + original_w = w.clone() + print("original weight is: ") + print(original_w) + if group_size == -1: + groupsize = k + + if groupsize != -1: + w = w.reshape((-1, groupsize, n)) + w = w.permute(1, 0, 2) + w = w.reshape((groupsize, -1)) + + s = torch.max(torch.abs(w), 0, keepdim=True)[0] + s *= 2 / maxq + + # Quantize. + w = torch.round(w / s).int() + + # Unsigned storage. + w += (maxq) // 2 + w = torch.clamp(w, 0, maxq) + print("quantize weight is: ") + print((w - (maxq) // 2)) + # Dequantize. + ref = (w - (maxq) // 2).half() * s + + if groupsize != -1: + + def reshape(w): + w = w.reshape((groupsize, -1, n)) + w = w.permute(1, 0, 2) + w = w.reshape((k, n)).contiguous() + return w + + ref = reshape(ref) + w = reshape(w) + + s = s.reshape((-1, n)).contiguous() + linear = nn.Linear(k, n, bias=False) + linear.weight.data = ref.t() + + return original_w, linear, s, (w - (maxq) // 2) + + +bits = 4 +m = 1 +group_size = -1 +infeatures = 1024 # this is k of weight (n, k) +outfeatures = 4096 # this is n of weight (n, k) +bias = False + +original_w, linear, s, qw = gen_quant4(infeatures, outfeatures, group_size) + +cuda_old_linear = CudaOldQuantLinear( + bits=4, + group_size=group_size, + infeatures=infeatures, + outfeatures=outfeatures, + bias=False, +) + +if group_size == -1: + group_size = infeatures +zeros = torch.full((infeatures // group_size, outfeatures), 7, dtype=torch.int32) + +cuda_old_linear.pack(linear, s.T, zeros.T, g_idx=None) +linear_module = torch.nn.Linear( + in_features=infeatures, + out_features=outfeatures, + bias=False, + dtype=torch.float16, + device="cuda", +) +linear_module.weight.data.copy_( + linear.weight.data +) # Not using dequantized_weight to avoid approx + +scales = s.to("cuda") +bitblas_qlinear = QuantLinear(bits, group_size, infeatures, outfeatures, bias) + +bitblas_qlinear.pack( + linear_module.to("cuda"), + scales=scales.T.contiguous().to("cuda"), +) + +inp = torch.rand(m, infeatures, dtype=torch.float16, device="cuda") + +cuda_old_linear = cuda_old_linear.to("cuda") +bitblas_qlinear = bitblas_qlinear.to("cuda") +with torch.no_grad(): + res_original = linear_module(inp) + res_cuda_old = cuda_old_linear(inp) + res_bitblas = bitblas_qlinear(inp) + +print(res_original) +print(res_cuda_old) +print(res_bitblas) diff --git a/maint/scripts/apply_mit_license.sh b/maint/scripts/apply_mit_license.sh new file mode 100755 index 000000000000..e57717fdf395 --- /dev/null +++ b/maint/scripts/apply_mit_license.sh @@ -0,0 +1,48 @@ +#!/bin/bash + +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +echo "Add MIT liscense boilerplate..." +PWD="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +# TO source code root +pushd "${PWD}/../../" > /dev/null + +EXITCODE=0 + +for SRC_FILE in $(find . -path './thirdparty' -prune -false -o -path './build' -prune -false -o -type f -not -name \ + '*apply_mit_liscense.sh' -not -name '*check_mit_liscense.sh' -and \( -name '*.cpp' -or -name '*.h*' -or -name '*.cu' -or -name '*.in' \) ); do + sed -i '/\/\/\s*Microsoft\s*(c)/Id' ${SRC_FILE} + if !(grep -q "Copyright (c) Microsoft Corporation." "${SRC_FILE}"); then + cat maint/scripts/mit_liscense1.txt ${SRC_FILE} > ${SRC_FILE}.new + mv ${SRC_FILE}.new ${SRC_FILE} + fi +done + +for SRC_FILE in $(find . -path './thirdparty' -prune -false -o -path './build' -prune -false -o -type f -not -name \ + '*apply_mit_liscense.sh' -not -name '*check_mit_liscense.sh' -and \( -name 'CMakeLists.txt' -or -name '*.cmake' \ + -or -name '*.py' -or -name '*.dockerfile' -or -name '*.yaml' \) ); do + sed -i '/\#\s*Microsoft\s*(c)/Id' ${SRC_FILE} + if !(grep -q "Copyright (c) Microsoft Corporation" "${SRC_FILE}"); then + cat maint/scripts/mit_liscense2.txt ${SRC_FILE} > ${SRC_FILE}.new + mv ${SRC_FILE}.new ${SRC_FILE} + fi +done + +for SRC_FILE in $(find . -path './thirdparty' -prune -false -o -path './build' -prune -false -o -type f -not -name \ + '*apply_mit_liscense.sh' -not -name '*check_mit_liscense.sh' -name '*.sh' ); do + sed -i '/\#\s*Microsoft\s*(c)/Id' ${SRC_FILE} + if !(grep -q "Copyright (c) Microsoft Corporation" "${SRC_FILE}"); then + line=$(head -n 1 ${SRC_FILE}) + if [[ $line == "#!/bin/bash"* ]]; then + (echo ${line}; echo ''; cat maint/scripts/mit_liscense2.txt; echo "$(tail -n +2 "${SRC_FILE}")" ) > ${SRC_FILE}.new + else + cat maint/scripts/mit_liscense2.txt ${SRC_FILE} > ${SRC_FILE}.new + fi + mv ${SRC_FILE}.new ${SRC_FILE} + fi +done + +echo "Done." +popd > /dev/null +exit $EXITCODE diff --git a/maint/scripts/check_mit_license.sh b/maint/scripts/check_mit_license.sh new file mode 100755 index 000000000000..d75e679cced1 --- /dev/null +++ b/maint/scripts/check_mit_license.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +echo "Check MIT Liscense boilerplate..." +PWD="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +# TO source code root +pushd "${PWD}/../../" > /dev/null + +EXITCODE=0 + +for SRC_FILE in $(find . -path './thirdparty' -prune -false -o -path './build' -prune -false -o -type f -not -name '*apply_mit_liscense.sh' \ + -not -name '*check_mit_liscense.sh' -and \( -name 'CMakeLists.txt' -or -name '*.cpp' -or -name '*.cu' -or -name '*.h' -or -name '*.hpp' \ + -or -name '*.in' -or -name '*.py' -or -name '*.sh' -or -name '*.dockerfile' -or -name '*.yaml' \) ); do + if !(grep -q "Copyright (c) Microsoft Corporation." "${SRC_FILE}") || !(grep -q "Licensed under the MIT License." "${SRC_FILE}") \ + || (grep -q -i -P "Microsoft( |)\(c\)" "${SRC_FILE}") || (grep -q "Apache License" "${SRC_FILE}"); then + echo "[ERROR] Require: MIT Liscense biolerplate" "${SRC_FILE}" + EXITCODE=1 + fi +done + +echo "Done." +popd > /dev/null +exit $EXITCODE diff --git a/maint/scripts/installation.sh b/maint/scripts/installation.sh new file mode 100755 index 000000000000..827121fcb010 --- /dev/null +++ b/maint/scripts/installation.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# install torch +pip install torch==2.1.0 + +# install llvm +apt-get install llvm-dev + +# clone and build tvm +git clone https://github.com/LeiWang1999/tvm --recursive -b dev/fast_dlight --depth 1 3rdparty/tvm + +cd 3rdparty/tvm +mkdir build +cp cmake/config.cmake build +cd build +echo "set(USE_LLVM ON)" >> config.cmake && echo "set(USE_CUDA ON)" >> config.cmake + +cmake .. && make -j && cd ../../.. + +echo "export TVM_HOME=$(pwd)/3rdparty/tvm" >> ~/.bashrc +echo "export PYTHONPATH=\$TVM_HOME/python:$(pwd)/python" >> ~/.bashrc diff --git a/maint/scripts/mit_liscense1.txt b/maint/scripts/mit_liscense1.txt new file mode 100644 index 000000000000..fc36ab244fad --- /dev/null +++ b/maint/scripts/mit_liscense1.txt @@ -0,0 +1,2 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. diff --git a/maint/scripts/mit_liscense2.txt b/maint/scripts/mit_liscense2.txt new file mode 100644 index 000000000000..59e481eb93dd --- /dev/null +++ b/maint/scripts/mit_liscense2.txt @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. diff --git a/python/bitblas/__init__.py b/python/bitblas/__init__.py new file mode 100644 index 000000000000..176d0ad79a69 --- /dev/null +++ b/python/bitblas/__init__.py @@ -0,0 +1,55 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""DLight package provides efficient schedules out-of-box for deep learning workloads.""" +from . import gpu +from .base import ( + fast_tune, + ApplyDefaultSchedule, + ApplyFastTuning, + BlockInfo, + IterInfo, + ScheduleRule, + normalize_prim_func, + try_inline, + try_inline_contiguous_spatial, +) + +from .relax import transform +from . import testing + +import logging +from tqdm import tqdm + + +# target logger into tqdm.write +class TqdmLoggingHandler(logging.Handler): + def __init__(self, level=logging.NOTSET): + super().__init__(level) + + def emit(self, record): + try: + msg = self.format(record) + tqdm.write(msg) + except Exception: + self.handleError(record) + + +def set_log_level(level): + logger = logging.getLogger(__name__) + logger.setLevel(level) + + +def _init_logger(): + logger = logging.getLogger(__name__) + handler = TqdmLoggingHandler() + formatter = logging.Formatter( + fmt="%(asctime)s [BitBLAS:%(levelname)s]: %(message)s", datefmt="%F %T" + ) + handler.setFormatter(formatter) + logger.addHandler(handler) + logger.propagate = False + set_log_level(logging.INFO) + + +_init_logger() diff --git a/python/bitblas/base/__init__.py b/python/bitblas/base/__init__.py new file mode 100644 index 000000000000..122c44cbd71f --- /dev/null +++ b/python/bitblas/base/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Base infra""" +from .analysis import ( + BlockInfo, + IterInfo, + collect_block_iter_vars_used_in_access_region, + collect_vars_used_in_prim_expr, + detect_dominant_read, + is_broadcast_epilogue, + normalize_prim_func, +) +from .common_schedules import get_block, get_output_blocks, try_inline, try_inline_contiguous_spatial +from .schedule_rule import ScheduleRule +from .transform import ApplyDefaultSchedule, ApplyFastTuning +from .utils import fast_tune, fast_tune_with_dynamic_range +from .roller import * diff --git a/python/bitblas/base/analysis.py b/python/bitblas/base/analysis.py new file mode 100644 index 000000000000..a4147181bd4b --- /dev/null +++ b/python/bitblas/base/analysis.py @@ -0,0 +1,342 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Analysis on TIR blocks, loops and functions.""" +from typing import List, Optional, Set, Union, Tuple, Dict +from typing_extensions import Literal +from dataclasses import dataclass +from enum import Enum + +from tvm import ir, tir, DataType +from tvm.ir import Range +from tvm.tir.analysis import undefined_vars +from tvm._ffi import get_global_func +from tvm.target.target import Target +from tvm.tir import Schedule, IterVar, Var, PrimExpr +from tvm.tir.schedule import BlockRV + + +def get_reduction_blocks(sch, blocks) -> bool: + # Get the main computation block + def is_reduction(block: BlockRV) -> bool: + block_stmt = sch.get(block) + iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars} + return iter_types == {IterVar.CommReduce, IterVar.DataPar} + + def is_spatial(block: BlockRV) -> bool: + block_stmt = sch.get(block) + iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars} + return iter_types == {IterVar.DataPar} + + # NOTE: We assume there is only one reduction block in the function + # all blocks are required to be spatial or reduction + if not all([is_reduction(block) or is_spatial(block) for block in blocks]): + return None + + # There is only one reduction block + reduction_blocks = [block for block in blocks if is_reduction(block)] + if len(reduction_blocks) != 1: + return None + + return reduction_blocks + + +class IterInfo: + """Information about a loop/iter var.""" + + kind: Literal["S", "R", "O"] + var: tir.Var + _dom: tir.PrimExpr + loop_rv: tir.schedule.LoopRV + + def __init__( + self, + kind: Literal["S", "R", "O"], + var: tir.Var, + dom: tir.PrimExpr, + loop_rv: tir.schedule.LoopRV, + ): + """Construct an IterInfo object.""" + self.kind = kind + self.var = var + self._dom = dom + self.loop_rv = loop_rv + + @property + def dom(self) -> Union[int, tir.PrimExpr]: + """The iteration domain of the loop.""" + return int(self._dom) if isinstance(self._dom, tir.IntImm) else self._dom + + def __str__(self) -> str: + return f'Iter("{self.kind}", {self.dom})' + + def __repr__(self) -> str: + return str(self) + + +class BlockInfo: + """Information about a TIR block.""" + + name: str + iters: List[IterInfo] + block_rv: tir.schedule.BlockRV + _reduction_block: bool + + def __init__( + self, + name: str, + iters: List[IterInfo], + block_rv: tir.schedule.BlockRV, + reduction_block: bool = False, + ): + """Construct a BlockInfo object.""" + self.name = name + self.block_rv = block_rv + self.iters = iters + self._reduction_block = reduction_block + + def dom(self) -> List[Union[int, tir.PrimExpr]]: + """The iteration domain of the block.""" + return [i.dom for i in self.iters] + + def dom_kind(self) -> str: + """The iteration domain kind of the block, for example, SSSS, SSSR.""" + return "".join(i.kind for i in self.iters) + + def is_injective(self) -> bool: + """Whether the block is injective, i.e. all its iteration domains are injective.""" + return all(k == "S" for k in self.dom_kind()) + + def is_elementwise(self, sch: tir.Schedule) -> bool: + """Whether the block is elementwise, i.e. trivial mapping between read/write region""" + + def _check_unit_var_range(dom: ir.Range, var: tir.Var) -> bool: + return dom.min.same_as(var) and dom.extent == 1 + + if not self.is_injective(): + return False + block = sch.get(self.block_rv) + if len(block.reads) != 1 or len(block.writes) != 1: + return False + r_region = block.reads[0].region + w_region = block.writes[0].region + if len(r_region) != len(w_region): + return False + for var, r_dom, w_dom in zip(block.iter_vars, r_region, w_region): + if not _check_unit_var_range(var, r_dom) or not _check_unit_var_range( + var, w_dom + ): + return False + return True + + def is_reduction(self) -> bool: + """Whether the block is a reduction workload.""" + # TODO(@junrushao): distinguish GEMV and reduction + return self._reduction_block + + def is_gemv(self) -> bool: + """Whether the block is a GEMV workload.""" + raise NotImplementedError + + def is_gemm(self) -> bool: + """Whether the block is a GEMM workload.""" + raise NotImplementedError + + def __str__(self) -> str: + return f'BlockInfo("{self.name}", "{self.dom_kind()}", {self.dom()})' + + def __repr__(self) -> str: + return str(self) + + +_normalize_prim_func = get_global_func("tir.schedule.NormalizePrimFunc") + + +def normalize_prim_func(sch: tir.Schedule) -> Optional[List[BlockInfo]]: + """Normalize the primfunc to normal form""" + try: + result = _normalize_prim_func(sch) + if result is None: + return None + except Exception: # pylint: disable=broad-except + return None + + def _iter_kind(i: tir.IterVar) -> str: + return { + tir.IterVar.DataPar: "S", + tir.IterVar.CommReduce: "R", + }.get(i.iter_type, "O") + + blocks: List[BlockInfo] = [] + for block, loops, iters, is_reduction in zip(*result): + blocks.append( + BlockInfo( + name=sch.get(block).name_hint, + iters=[ + IterInfo( + kind=_iter_kind(iter), # type: ignore + var=iter.var, + dom=iter.dom, + loop_rv=loop, + ) + for loop, iter in zip(loops, iters) + ], + block_rv=block, + reduction_block=is_reduction, + ) + ) + return blocks + + +def find_var_from_func(func, var: str): + for buffer in func.buffer_map.values(): + for i in buffer.shape: + if isinstance(i, tir.Var) and i.name == var: + return i + return None + + +def check_func_with_dynamic(func): + for buffer in func.buffer_map.values(): + for i in buffer.shape: + if isinstance(i, tir.Var): + return True + return False + + +def _assert_gpu_target(target: Target): + if "gpu" not in target.keys: + raise ValueError(f"Expect a GPU target, but got {target}") + + +def get_max_threads_per_block(target: Target) -> int: + _assert_gpu_target(target) + max_threads_per_block = None + for name in ["max_threads_per_block", "max_num_threads"]: + if max_threads_per_block is None: + max_threads_per_block = target.attrs.get(name, None) + if max_threads_per_block is None: + max_threads_per_block = 64 + return int(max_threads_per_block) + + +def get_max_shared_memory_per_block(target: Target) -> int: + _assert_gpu_target(target) + max_shared_memory_per_block = target.attrs.get("max_shared_memory_per_block", None) + if max_shared_memory_per_block is None: + raise ValueError( + f"Cannot find `max_shared_memory_per_block` in {target}, please specify it manually" + ) + return int(max_shared_memory_per_block) + + +def get_root_block(sch: Schedule, func_name: str = "main") -> BlockRV: + try: + block = sch.mod[func_name].body.block + except: + raise ValueError( + f"The function body is expected to be the root block, but got:\n" + f"{sch.mod[func_name].body}" + ) + return sch.get_block(block.name_hint) + + +def collect_block_iter_vars_used_in_access_region( + block: tir.Block, region: List[ir.Range] +) -> Set[tir.Var]: + """Collect the block iter variables used in the access region of a buffer region.""" + tir_vars = set() + for expr in region: + assert expr.extent == 1 + tir_vars |= collect_vars_used_in_prim_expr(expr.min) + tir_vars &= set(iter_var.var for iter_var in block.iter_vars) + return tir_vars + + +def collect_vars_used_in_prim_expr(expr: tir.PrimExpr) -> Set[tir.Var]: + """Collect the variables used in the PrimExpr.""" + tir_vars = set() + + def _collect_tir_var(expr): + if isinstance(expr, tir.Var): + tir_vars.add(expr) + + tir.stmt_functor.post_order_visit(expr, _collect_tir_var) + return tir_vars + + +def detect_dominant_read(block: tir.Block) -> tir.PrimExpr: + """Detect the dominant read indices in the block.""" + dominant_read = None + num_read_iters = -1 + for buffer_region in block.reads: + tir_vars = collect_block_iter_vars_used_in_access_region( + block, buffer_region.region + ) + if num_read_iters < len(tir_vars): + num_read_iters = len(tir_vars) + dominant_read = buffer_region + assert dominant_read is not None + (result,) = dominant_read.buffer.offset_of([e.min for e in dominant_read.region]) + return result + + +def is_broadcast_epilogue( + sch: tir.Schedule, + block: tir.schedule.BlockRV, + epilogue: tir.schedule.BlockRV, +) -> bool: + """Check if the epilogue block is a broadcast pattern""" + write_buffers = {r.buffer for r in sch.get(block).writes} + epilogue_iters = {i.var: i for i in sch.get(epilogue).iter_vars if i.dom != 1} + for buffer_region in sch.get(epilogue).reads: + if buffer_region.buffer not in write_buffers: + continue + tir_vars = collect_block_iter_vars_used_in_access_region( + sch.get(epilogue), buffer_region.region + ) + if len(tir_vars) < len(epilogue_iters): + return True + return False + + +def get_reduction_blocks( + sch: tir.Schedule, blocks: List[tir.schedule.BlockRV] +) -> List[tir.schedule.BlockRV]: + # Get the main computation block + def is_reduction(block: BlockRV) -> bool: + block_stmt = sch.get(block) + iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars} + return iter_types == {IterVar.CommReduce, IterVar.DataPar} + + def is_spatial(block: BlockRV) -> bool: + block_stmt = sch.get(block) + iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars} + return iter_types == {IterVar.DataPar} + + # NOTE: We assume there is only one reduction block in the function + # all blocks are required to be spatial or reduction + if not all([is_reduction(block) or is_spatial(block) for block in blocks]): + return None + + # There is only one reduction block + reduction_blocks = [block for block in blocks if is_reduction(block)] + if len(reduction_blocks) == 0: + return None + return reduction_blocks + + +def get_coalesced_veclen(block_stmt: tir.Block, target_bits: int = 128) -> int: + # gpu memory prefer 128 bits coalesced access (e.g. four banks) + # 128 bits + block_stmt + buffers: List[tir.Buffer] = [] + for read in block_stmt.reads: + buffers.append(read.buffer) + for write in block_stmt.writes: + buffers.append(write.buffer) + # pick the dtype with the largest bits + max_dtype_bits: int = 0 + for buffer in buffers: + max_dtype_bits = max(max_dtype_bits, DataType(buffer.dtype).bits) + return target_bits // max_dtype_bits diff --git a/python/bitblas/base/common_schedules.py b/python/bitblas/base/common_schedules.py new file mode 100644 index 000000000000..e1852e81c44e --- /dev/null +++ b/python/bitblas/base/common_schedules.py @@ -0,0 +1,163 @@ +# Copyright 2018 The apache/tvm Authors. All Rights Reserved. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# Modifications Copyright (c) Microsoft. +# The code below is mostly copied from apache/tvm common_schedules.py in dlight. +"""Common schedule strategies for TIR.""" +from typing import Callable, List + +from tvm import tir + +from .analysis import BlockInfo + + +def get_block( + sch: tir.Schedule, + blocks: List[BlockInfo], + name: str, +): + """Get the target block from a schedule. + + Parameters + ---------- + sch : tir.Schedule + The TIR schedule used to get target block. + name : str + The name of the target block. + + Returns + ------- + target_block : BlockRV + The target block. + """ + + target_block: tir.BlockRV = None + for block_info in blocks: + block = block_info.block_rv + if sch.get(block).name_hint == name: + target_block = block + return target_block + + +def get_output_blocks( + sch: tir.Schedule, + blocks: List[BlockInfo], +): + """Get the output blocks of a schedule. + + Parameters + ---------- + sch : tir.Schedule + The TIR schedule used to get output blocks. + blocks : List[BlockInfo] + The blocks to be analyzed. + + Returns + ------- + output_blocks : List[BlockInfo] + The output blocks. + """ + + # collect arguments buffer + func = sch.mod["main"] + args = list(func.buffer_map.values()) + + output_blocks = [] + for block_info in blocks: + block = block_info.block_rv + for write in sch.get(block).writes: + if write.buffer in args: + output_blocks.append(block) + + return output_blocks + + +def try_inline( + sch: tir.Schedule, + blocks: List[BlockInfo], +) -> List[BlockInfo]: + """Try to inline as many blocks as possible, and return the remaining blocks. + + Parameters + ---------- + sch : tir.Schedule + The TIR schedule used to inline blocks. + blocks : List[BlockInfo] + The blocks to be inlined. + + Returns + ------- + remaining : List[BlockInfo] + The remaining blocks that cannot be inlined. + """ + + def _trial(func: Callable): + for i, block in enumerate(blocks): + try: + func(block.block_rv) + except: # pylint: disable=bare-except + continue + return i + return None + + while True: + i = _trial(sch.compute_inline) + if i is None: + i = _trial(sch.reverse_compute_inline) + if i is None: + break + blocks.pop(i) + return blocks + + +def try_inline_contiguous_spatial( + sch: tir.Schedule, + block_infos: List[BlockInfo], +) -> List[BlockInfo]: + """Try to inline contiguous spatial blocks in a schedule + + Parameters + ---------- + sch : tir.Schedule + The TIR schedule used to inline blocks. + block_infos : List[BlockInfo] + The blocks to be try. + + Returns + ------- + remaining : List[BlockInfo] + The remaining blocks that cannot be inlined. + """ + + if block_infos is None: + return None + results = [] + spatial_blocks = [] + block: BlockInfo + for block in block_infos: + if block.is_injective(): + spatial_blocks.append(block) + elif spatial_blocks: + results.extend(try_inline(sch, spatial_blocks)) + results.append(block) + spatial_blocks = [] + else: + results.append(block) + if spatial_blocks: + results.extend(try_inline(sch, spatial_blocks)) + return results diff --git a/python/bitblas/base/roller/__init__.py b/python/bitblas/base/roller/__init__.py new file mode 100644 index 000000000000..39180bf1a57d --- /dev/null +++ b/python/bitblas/base/roller/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from .node import PrimFuncNode +from .config import Config +from .policy import DefaultPolicy, TensorCorePolicy +from .arch import Arch, CUDA diff --git a/python/bitblas/base/roller/arch/__init__.py b/python/bitblas/base/roller/arch/__init__.py new file mode 100644 index 000000000000..80293a1336b6 --- /dev/null +++ b/python/bitblas/base/roller/arch/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from .arch_base import Arch +from .cuda import * +from .cpu import * + + +def get_arch(target: tvm.target.Target) -> Arch: + if target.kind.name == "cuda": + return CUDA(target) + elif target.kind.name == "llvm": + return CPU(target) + else: + raise ValueError(f"Unsupported target: {target.kind.name}") diff --git a/python/bitblas/base/roller/arch/arch_base.py b/python/bitblas/base/roller/arch/arch_base.py new file mode 100644 index 000000000000..8628badce1ae --- /dev/null +++ b/python/bitblas/base/roller/arch/arch_base.py @@ -0,0 +1,41 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import List + + +class Arch: + """ + Represents the architecture of a computing device, capturing various hardware specifications. + """ + + def __init__(self) -> None: + self.reg_cap: int = 0 # Register capacity: The amount of register memory available + self.smem_cap: int = 0 # Shared memory capacity: The amount of shared memory available + self.compute_max_core: int = 0 # The maximum number of computing cores + self.warp_size: int = ( + 0 # The size of a warp, a group of threads that execute instructions in lockstep + ) + self.sm_partition: int = 0 # The number of streaming multiprocessor partitions + self.transaction_size: List[int] = [ + 0, + 0, + ] # The size of memory transactions, typically in bytes + self.max_smem_usage: int = 0 # The maximum shared memory usage allowed + self.bandwidth: List[int] = [ + 0, + 0, + ] # Bandwidth specifications, possibly including peak and sustained rates + self.platform: str = "unknown" # The platform or manufacturer of the device + self.compute_capability: str = ( + "unknown" # The compute capability, indicating the feature set and performance level + ) + self.l2_cache_size_bytes: int = 0 + # the number of transaction size in bytes + self.transaction_size: List[int] = [0, 0] # in bytes + # bandwidth in MB/s, will be used for recommend basic tile size + self.bandwidth: List[int] = [0, 0] + + def get_avaliable_tensorintrin_shapes(self): + raise NotImplementedError() + \ No newline at end of file diff --git a/python/bitblas/base/roller/arch/cpu.py b/python/bitblas/base/roller/arch/cpu.py new file mode 100644 index 000000000000..a90015717dce --- /dev/null +++ b/python/bitblas/base/roller/arch/cpu.py @@ -0,0 +1,19 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import tvm +from tvm.target import Target +from .arch_base import Arch +from typing import List, Dict + + +# For LLVM Backend, we do not provide the detailed information of the CPU +# As the LLVM backend do not required tuning, just maintain the consistency +class CPU(Arch): + def __init__(self, target: Target): + self.target = target + device = tvm.runtime.cpu(0) + if not device.exist: + raise RuntimeError("Cannot find cpu device 0.") + self.device: tvm.runtime.Device = device + self.platform: str = "CPU" diff --git a/python/bitblas/base/roller/arch/cuda.py b/python/bitblas/base/roller/arch/cuda.py new file mode 100644 index 000000000000..ebe1a2ee8607 --- /dev/null +++ b/python/bitblas/base/roller/arch/cuda.py @@ -0,0 +1,60 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import tvm +from tvm.target import Target +from .arch_base import Arch +from typing import List, Dict + +def check_sm_version(arch: str) -> int: + sm_version = arch.replace("sm_", "") + return int(sm_version) if sm_version.isdigit() else -1 + +class TensorInstruction(object): + def __init__( + self, + name: str, + intrin_group: Dict, + shape: List[int], + ): + self.name: str = name + self.intrin_group: Dict = intrin_group + # only mantain the shape of M and N + self.shape: List[int] = shape + +class CUDA(Arch): + def __init__(self, target: Target): + self.target = target + self.sm_version = check_sm_version(self.target.arch) + device = tvm.runtime.cuda(0) + if not device.exist: + raise RuntimeError("Cannot find cuda device 0.") + self.device: tvm.runtime.Device = device + self.platform: str = "CUDA" + self.smem_cap = device.max_shared_memory_per_block + self.compute_max_core = device.multi_processor_count + self.warp_size = device.warp_size + self.compute_capability = device.compute_version.replace(".", "") + self.reg_cap: int = 65536 + self.max_smem_usage: int = 2 * self.smem_cap + self.sm_partition: int = 4 + self.l2_cache_size_bytes: int = target.l2_cache_size_bytes + # the number of transaction size in bytes + self.transaction_size: List[int] = [32, 128] # in bytes + # bandwidth in MB/s, will be used for recommend basic tile size + # TODO(lei): find some way to get the real bandwidth + # However, the ratio of bandwidth between different devices can + # be similar. The bandwidth can work for another devices as well. + self.bandwidth: List[int] = [750, 12080] + # get the available tensor instructions during runtime to avoid + # the dependency of the tensor intrinsics registration + self.available_tensor_instructions: List[TensorInstruction] = None + + def get_avaliable_tensorintrin_shapes(self): + from tvm.tir.tensor_intrin.cuda import get_wmma_intrin_group, get_mma_intrin_group + + self.available_tensor_instructions = ( + TensorInstruction("mma", get_mma_intrin_group, [16, 16]), + TensorInstruction("wmma", get_wmma_intrin_group, [16, 16]), + ) + return [t.shape for t in self.available_tensor_instructions] \ No newline at end of file diff --git a/python/bitblas/base/roller/bestfit.py b/python/bitblas/base/roller/bestfit.py new file mode 100644 index 000000000000..b0e541c68b74 --- /dev/null +++ b/python/bitblas/base/roller/bestfit.py @@ -0,0 +1,66 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Benifit For Dlight Schedule""" +class Block: + def __init__(self, start, end, is_free): + self.start = start + self.end = end + self.is_free = is_free + + def size(self) -> int: + return self.end - self.start + + def merge(self, other): + assert self.is_free == other.is_free + self.start = min(self.start, other.start) + self.end = max(self.end, other.end) + + def __repr__(self) -> str: + return "".format(self.start, self.size()) + + +class BestFit: + def __init__(self, align=32): + self.limit = 0 + self.list = [] + self.align = align + + def malloc(self, size) -> Block: + size = (size + self.align - 1) // self.align * self.align + found = None + for block in self.list: + if block.is_free and block.size() >= size: + if not found or found.size() > block.size(): + found = block + if found: + found.is_free = False + remain = found.size() - size + if remain != 0: + found.end -= remain + self.list.insert( + self.list.index(found) + 1, Block(found.end, found.end + remain, True) + ) + return found + elif len(self.list) > 0 and self.list[-1].is_free: + add = size - self.list[-1].size() + self.list[-1].end += add + self.limit = self.list[-1].end + self.list[-1].is_free = False + return self.list[-1] + else: + block = Block(self.limit, self.limit + size, False) + self.list.append(block) + self.limit += size + return block + + def free(self, block: Block) -> None: + assert not block.is_free + idx = self.list.index(block) + self.list[idx] = Block(block.start, block.end, True) + if idx + 1 < len(self.list) and self.list[idx + 1].is_free: + self.list[idx].merge(self.list[idx + 1]) + self.list.pop(idx + 1) + if idx - 1 >= 0 and self.list[idx - 1].is_free: + self.list[idx].merge(self.list[idx - 1]) + self.list.pop(idx - 1) diff --git a/python/bitblas/base/roller/config.py b/python/bitblas/base/roller/config.py new file mode 100644 index 000000000000..f3ce60847d54 --- /dev/null +++ b/python/bitblas/base/roller/config.py @@ -0,0 +1,237 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Config definition for schedule""" +from typing import Dict, List, Optional, Tuple +from ..roller import PrimFuncNode +import numpy as np + + +class TensorCoreExtraConfig: + """ + This class is used to store extra information for tensorcore + """ + + def __init__( + self, + AS_shape: Tuple[int], + BS_shape: Tuple[int], + AF_shape: Tuple[int], + BF_shape: Tuple[int], + tc_axis: Tuple[int], + ) -> None: + self.AS_shape: Tuple[int] = AS_shape + self.BS_shape: Tuple[int] = BS_shape + self.AF_shape: Tuple[int] = AF_shape + self.BF_shape: Tuple[int] = BF_shape + self.tc_axis: Tuple[int] = tc_axis + + +class Stride: + """ + Manages stride information for a given axis of a tensor. + """ + + def __init__(self, stride: int = 1, ax: int = -1) -> None: + # which axis to put stride on + self._ax: int = int(ax) + # the stride size of the axis + self._stride: int = int(stride) + + @property + def ax(self) -> int: + return self._ax + + @property + def stride(self) -> int: + return self._stride + + def compute_strides_from_shape(self, shape: List[int]) -> List[int]: + ndim = len(shape) + strides = [1 for _ in shape] + for i in range(ndim - 2, -1, -1): + if i == self.ax: + strides[i] = self.stride + else: + strides[i] = int(strides[i + 1] * shape[i + 1]) + return strides + + def compute_elements_from_shape(self, shape: List[int]) -> int: + original_shape = np.prod(shape) + if not self.is_valid(): + strided_elem = original_shape + else: + assert self.ax < len(shape) + strided_elem = np.prod(shape[0 : self.ax + 1]) * self.stride + assert strided_elem >= original_shape + return int(strided_elem) + + def is_valid(self) -> bool: + return self.ax >= 0 + + def __repr__(self) -> str: + return f"" + + +class TileDict: + """ + Manages tiling information and configurations for computational tasks. + """ + + def __init__(self, output_tile) -> None: + self.output_tile = output_tile + # schedule config + self.tile_map = {} + self.rstep_map = {} + self.cached_tensors_map = {} + self.output_strides_map = {} + self.tensor_strides_map = {} + + # analysis + self.traffic = -1 + self.smem_cost = -1 + self.block_per_SM = -1 + self.num_wave = -1 + self.grid_size = -1 + self.valid = True + + def get_tile(self, func) -> List[int]: + return self.tile_map[func] + + def get_rstep(self, func) -> Dict[str, int]: + return self.rstep_map + + def __hash__(self) -> int: + return hash(tuple(self.output_tile)) + + +class IntrinInfo: + """ + The information of tensorcore intrinsic related infomation + """ + + def __init__( + self, + in_dtype: str, + out_dtype: str, + trans_b: bool, + smooth_a: bool = False, + smooth_b: bool = False, + ) -> None: + self.in_dtype = in_dtype + self.out_dtype = out_dtype + self.trans_a = False + self.trans_b = trans_b + self.smooth_a = smooth_a + self.smooth_b = smooth_b + + def __repr__(self) -> str: + return f"" + + +class Config(object): + """ + Central configuration class for managing various parameters of computational tasks. + """ + + def __init__(self) -> None: + self.arch = None + self.use_tc = None # todo(lei): this should be renamed. + + # spacial axes tiling info + self.block = [] + self.thread = [] + # special axes for tensorCore + self.warp = [] + # reduce axes tiling info + self.rstep = [] + self.reduce_thread = [] + self.rasterization_plan = None + self.cached_tensors = [] + self.output_strides = {} + self.schedule_stages = None + + # Experimental + self._raxis_order = [] + self._step = [] + self.vectorize: Dict[str, int] = {} + self.pipeline_stage = 1 + self.use_async = False + self.opt_shapes: Dict[str, int] = {} + self.intrin_info = IntrinInfo("float16", "float16", True) + self.shared_scope: str = "shared" + self.pass_context: Dict = {} + + def to_dict(self) -> Dict: + dic = {} + dic["block"] = self.block + if self.use_tc: + dic["warp"] = self.warp + else: + dic["thread"] = self.thread + dic["rstep"] = self.rstep + if np.prod(self.reduce_thread) > 1: + dic["reduce_thread"] = self.reduce_thread + if self.use_tc: + dic["use_tc"] = self.use_tc + if self.output_strides: + dic["strides"] = {} + for k, stride in self.output_strides.items(): + if stride.is_valid(): + dic["strides"][k] = stride + if len(dic["strides"]) == 0: + del dic["strides"] + if np.prod(self._step) > 1: + dic["step"] = self._step + if self._raxis_order != []: + dic["raxis_order"] = self._raxis_order + if self.vectorize != {}: + dic["vectorize"] = self.vectorize + return dic + + def from_dict(self, dic: Dict) -> "Config": + self.__init__() + if "use_tc" in dic: + self.use_tc = dic["use_tc"] + self.block = dic["block"] + if self.use_tc: + self.warp = dic["warp"] + else: + self.thread = dic["thread"] + self.rstep = dic["rstep"] + if "reduce_thread" in dic: + self.reduce_thread = dic["reduce_thread"] + else: + self.reduce_thread = [1 for _ in self.rstep] + if "strides" in dic: + self.output_strides = dic["strides"] + if "step" in dic: + self._step = dic["step"] + if "raxis_order" in dic: + self._raxis_order = dic["raxis_order"] + if "vectorize" in dic: + self.vectorize = dic["vectorize"] + return self + + @property + def raxis_order(self) -> List[int]: + if self._raxis_order != []: + return self._raxis_order + return list(range(len(self.rstep))) + + @property + def step(self) -> List[int]: + if self._step != []: + return self._step + return [1 for _ in self.block] + + def __repr__(self) -> str: + return str(self.to_dict()) + + def complete_config(self, node:PrimFuncNode): + # analysis pass context, for int8 mma, we should merge static shared memory + merge_static_smem = False + if self.use_tc and self.intrin_info.in_dtype == "int8": + merge_static_smem = True + self.pass_context = {"tir.merge_static_smem": merge_static_smem} + return self diff --git a/python/bitblas/base/roller/node.py b/python/bitblas/base/roller/node.py new file mode 100644 index 000000000000..97f7be917a0a --- /dev/null +++ b/python/bitblas/base/roller/node.py @@ -0,0 +1,372 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""PrimFunc Warpper and Block Infomation Analaysis""" + +import tvm +from tvm import tir +from tvm.tir import IterVar, Var, PrimFunc +from typing import Any, Iterable, Dict, List, Tuple +import functools +import numpy as np +from tvm.tir.schedule.schedule import BlockRV +from ..analysis import BlockInfo, get_reduction_blocks +from .. import analysis +from .. import normalize_prim_func +from .shape_inference import get_analyzer_by_tir + + +def pre_order_traverse(block_analyzer, blocks, func): + visited = set() + + def _traverse(block): + if block in visited: + return + visited.add(block) + for dep_blocks in block_analyzer.get_consumer_blocks(block): + _traverse(dep_blocks) + func(block) + + for block in blocks: + _traverse(block) + + +class BlockAnalyzer(object): + def __init__(self, sch) -> None: + self.sch: tir.Schedule = sch + self.block_infos: List[BlockInfo] = normalize_prim_func(self.sch) + + def get_block_name(self, block: BlockRV) -> str: + return self.sch.get(block).name_hint + + def get_block_info(self, block: BlockRV) -> BlockInfo: + for block_info in self.block_infos: + if self.get_block_name(block) == block_info.name: + return block_info + return None + + def get_spatial_axis(self, block: BlockRV) -> List[IterVar]: + block_info = self.get_block_info(block) + axis = [] + for iter in block_info.iters: + if iter.kind == "S": + axis.append(iter) + return axis + + def get_reduce_axis(self, block: BlockRV) -> List[IterVar]: + block_info = self.get_block_info(block) + raxis = [] + for iter in block_info.iters: + if iter.kind == "R": + raxis.append(iter) + return raxis + + def get_input_buffers(self, block: BlockRV) -> List[tir.Buffer]: + buffers = [] + for read in self.sch.get(block).reads: + buffers.append(read.buffer) + return buffers + + def get_output_buffers(self, block: BlockRV) -> List[tir.Buffer]: + buffers = [] + for write in self.sch.get(block).writes: + buffers.append(write.buffer) + return buffers + + def get_buffers(self, block: BlockRV) -> List[tir.Buffer]: + return self.get_input_buffers(block) + self.get_output_buffers(block) + + def get_producer_blocks(self, block: BlockRV) -> List[BlockRV]: + return self.sch.get_producers(block) + + def get_consumer_blocks(self, block: BlockRV) -> List[BlockRV]: + return self.sch.get_consumers(block) + + +class Node(object): + def __init__(self, tags: Dict = {}) -> None: + self._dtypes = [] + self._tag: Dict = {} + for tag in tags: + self.add_tag(tag, tags[tag]) + + def set_tag(self, k: str, v: Any = True) -> None: + self.add_tag(k, v) + + def add_tag(self, k: str, v: Any = True) -> None: + self._tag[k] = v + + def get_tag(self, k: str) -> Any: + if k not in self._tag: + return None + return self._tag[k] + + +class PrimFuncNode(Node): + def __init__(self, prim_func: PrimFunc, tags: Dict = {}) -> None: + super().__init__(tags) + self.prim_func = self._specialize_func(prim_func) + self.sch: tir.Schedule = tir.Schedule(self.prim_func) + self.block_analyzer: BlockAnalyzer = BlockAnalyzer(self.sch) + self.schedule_stages: List[BlockRV] = [] + self.blocks: List[BlockRV] = [] + self.output_blocks: List[BlockRV] = None + self.reduction_block: BlockRV = None + self.raxis = [] + self.input_buffers = [] + self.output_buffers = [] + self.buffers = [] + self.args = [] + self._analysis_funcinfo() + self.ana = get_analyzer_by_tir(self.block_analyzer, self.blocks) + + def _specialize_func(self, func: PrimFunc): + # Specialize the function to make it more friendly for analysis. + # set attrs + for k, v in func.attrs.items(): + self.set_tag(k, v) + if self.get_tag("is_speclized"): + return func + opt_shapes = self.get_tag("opt_shapes") + if opt_shapes: + for name, shape in opt_shapes.items(): + var = analysis.find_var_from_func(func, name) + if var is not None: + func = func.specialize({var: shape.astype(var.dtype)}) + return func + + def _analysis_funcinfo(self): + root_block = analysis.get_root_block(self.sch) + blocks = self.sch.get_child_blocks(root_block) + self.blocks = blocks + + self.output_blocks = self.sch.get_output_blocks(root_block) + reduction_blocks = get_reduction_blocks(self.sch, blocks) + if reduction_blocks is None: + self.reduction_block = None + self.schedule_stages.append(*self.output_blocks) + else: + # analysis on the last reduction block + self.reduction_block = reduction_blocks[-1] + # set raxis + reduce_block_info = self.block_analyzer.get_block_info(self.reduction_block) + for iter in reduce_block_info.iters: + if iter.kind == "R": + self.raxis.append(iter) + self.schedule_stages.append(self.reduction_block) + + # collect output buffers + for output_block in self.output_blocks: + for write in self.sch.get(output_block).writes: + if write not in self.output_buffers: + self.output_buffers.append(write.buffer) + + for param in self.prim_func.params: + if param not in self.prim_func.buffer_map: + # in case of dynamic symbolic may in params + continue + buffer = self.prim_func.buffer_map[param] + if buffer not in self.output_buffers: + self.input_buffers.append(buffer) + + self.args = self.input_buffers + self.output_buffers + self.buffers = [buffer for buffer in self.prim_func.buffer_map.values()] + + # set dtype + self.set_dtype(tvm.DataType(self.output_buffers[0].dtype)) + + def get_opt_shape(self, name) -> int: + opt_shapes = self.get_tag("opt_shapes") + if opt_shapes is None: + return None + return opt_shapes[name] + + def extent_warpper(self, value) -> int: + if isinstance(value, tvm.tir.Var): + return self.get_opt_shape(value.name) + elif isinstance(value, tvm.tir.IntImm): + return int(value) + else: + return value + + @functools.lru_cache() + def get_space_dim(self) -> List[int]: + dim_size = [] + if self.reduction_block: + block_info = self.block_analyzer.get_block_info(self.reduction_block) + for iter in block_info.iters: + if iter.kind == "S": + if isinstance(iter.dom.extent, tvm.tir.IntImm): + dim_size.append(int(iter.dom.extent)) + else: + assert isinstance(iter.dom.extent, tvm.tir.Var) + dim_size.append(self.get_opt_shape(iter.dom.extent.name)) + else: + # assume outer stage has the same shape + loops = self.sch.get_loops(self.schedule_stages[0]) + for loop in loops: + dim_size.append(int(self.sch.get(loop).extent)) + return [int(x) for x in dim_size] + + def set_dtype(self, dtype: tvm.DataType, id=0) -> None: + assert isinstance(dtype, tvm.DataType), type(dtype) + if dtype == tvm.DataType("bool"): + dtype = tvm.DataType("int8") + if len(self._dtypes) <= id: + self._dtypes.extend([None for _ in range(id - len(self._dtypes) + 1)]) + elif self._dtypes[id] is not None: + assert self._dtypes[id] == dtype, (self._dtypes, dtype) + self._dtypes[id] = dtype + + def get_dtype(self, id=0) -> tvm.DataType: + return self._dtypes[id] + + def get_buffer_dtype(self, buffer: tir.Buffer) -> tvm.DataType: + return tvm.DataType(buffer.dtype) + + def propogate(self, tile, rstep={}, targets=None): + shape = { + self.block_analyzer.get_output_buffers(block)[0].name: [ + tvm.arith.ConstIntBound(0, val - 1) for val in tile + ] + for block in self.schedule_stages + } + return self.ana.infer(shape, rstep, targets) + + def propogate_inputs(self, tile, rstep={}) -> List[List[int]]: + read_idx_offset = len(self.input_buffers) + targets = [t.name for t in self.args[:read_idx_offset]] + shapes, intermediate_bind = self.propogate(tile, rstep, targets) + results = [] + for i, arg in enumerate(self.args[:read_idx_offset]): + if arg.name in intermediate_bind: + results.append(shapes[arg.name]) + continue + # should not exceed original shape + trimmed_shape = [ + self.extent_warpper(i) + for i in list(map(min, zip(shapes[arg.name], self.input_buffers[i].shape))) + ] + results.append(trimmed_shape) + return results + + def propogate_outputs(self, tile, rstep={}) -> List[List[int]]: + read_idx_offset = len(self.input_buffers) + targets = [t.name for t in self.args[read_idx_offset:]] + shapes, _ = self.propogate(tile, rstep, targets) + results = [] + for i, arg in enumerate(self.args[read_idx_offset:]): + # should not exceed original shape + trimmed_shape = list(map(min, zip(shapes[arg.name], self.input_buffers[i].shape))) + results.append(trimmed_shape) + return results + + def propogate_reduction_inputs(self, shape, rstep={}) -> Dict[str, List[int]]: + if self.reduction_block is None: + return {} + targets = [b.name for b in self.block_analyzer.get_input_buffers(self.reduction_block)] + results, _ = self.propogate(shape, rstep, targets) + return results + + def get_reduce_inputs_dtype(self): + if self.reduction_block is None: + return {} + return { + b.name: tvm.DataType(b.dtype) + for b in self.block_analyzer.get_input_buffers(self.reduction_block) + } + + @functools.lru_cache() + def infer_tensorcore_axis(self) -> Tuple[int]: + # axis is fixed for one expression, so only inference and cached + assert self.get_tag("tensorcore_config") + + C_ax_m, C_ax_n = self.get_tag("tensorcore_config") + wmma_m, wmma_n, wmma_k = [16, 16, 16] # just for testing, any number is ok + + def get_cl_shapes(c_ax_m, c_ax_n): + output_buffer_shape = ( + self.block_analyzer.sch.get(self.reduction_block).writes[0].buffer.shape + ) + valid_region = [] + for region in output_buffer_shape: + if region.value == 1: + continue + valid_region.append(region) + + num_nvalid_regions = len(output_buffer_shape) - len(valid_region) + + spatial_dim = self.get_space_dim() + assert len(valid_region) == len( + spatial_dim + ), f" {valid_region} mismatch with {spatial_dim}" + cl_shapes = [1] * len(spatial_dim) + cl_shapes[c_ax_m - num_nvalid_regions] = wmma_m + cl_shapes[c_ax_n - num_nvalid_regions] = wmma_n + self.set_tag("tensorcore_config", [s - num_nvalid_regions for s in [c_ax_m, c_ax_n]]) + return cl_shapes + + CL_shape = get_cl_shapes(C_ax_m, C_ax_n) + shapes = self.propogate_reduction_inputs(CL_shape, {x.var.name: 1 for x in self.raxis}) + A_deps, B_deps = shapes.values() + A_ax_m = A_deps.index(wmma_m) + B_ax_n = B_deps.index(wmma_n) + + CL_shape = [1] * len(self.get_space_dim()) + shapes = self.propogate_reduction_inputs(CL_shape, {x.var.name: wmma_k for x in self.raxis}) + A_deps, B_deps = shapes.values() + A_ax_k = len(A_deps) - 1 - A_deps[::-1].index(wmma_k) + B_ax_k = len(B_deps) - 1 - B_deps[::-1].index(wmma_k) + tc_axis = (A_ax_m, A_ax_k, B_ax_k, B_ax_n, C_ax_m, C_ax_n) + return tc_axis + + def footprint(self, shape, rstep, stride_map={}) -> int: + result = 0 + shapes, _ = self.propogate(shape, rstep) + + def is_broadcast_pattern(buffer, output_buffer): + return ( + buffer in self.args + and len(shapes[output_buffer.name]) > len(shapes[buffer.name]) + and np.prod(shapes[output_buffer.name]) > np.prod(shapes[buffer.name]) + ) + + def is_after_reduce_stage(block): + if not self.reduction_block: + return False + reduce_dependent_blocks = getattr(self, "reduce_dependent_blocks", None) + if reduce_dependent_blocks is None: + reduce_dependent_blocks = set() + pre_order_traverse( + self.block_analyzer, + [self.reduction_block], + lambda block: reduce_dependent_blocks.add(block), + ) + self.reduce_dependent_blocks = reduce_dependent_blocks + return block not in reduce_dependent_blocks + + # compute cached stages + cached_tensor = [] + for block in self.blocks: + output_buffer = self.block_analyzer.get_output_buffers(block)[0] + for buffer in self.block_analyzer.get_input_buffers(block): + cache = buffer.name not in cached_tensor and ( + is_broadcast_pattern(buffer, output_buffer) + or self.block_analyzer.get_block_info(block).is_reduction + ) + if not cache: + continue + cached_tensor.append(buffer.name) + if is_after_reduce_stage(block): + continue # cache after reduce op can often reuse buffer in reduce stage + + if buffer.name in stride_map: + num_elem = stride_map[buffer.name].compute_elements_from_shape( + shapes[buffer.name] + ) + else: + num_elem = np.prod(shapes[buffer.name]) + buffer_len = num_elem * int((tvm.DataType(buffer.dtype).bits + 7) // 8) + buffer_len = (buffer_len + 31) // 32 * 32 + result += buffer_len + return result, cached_tensor diff --git a/python/bitblas/base/roller/policy/__init__.py b/python/bitblas/base/roller/policy/__init__.py new file mode 100644 index 000000000000..09ed1d51b130 --- /dev/null +++ b/python/bitblas/base/roller/policy/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from .default import DefaultPolicy +from .tensorcore import TensorCorePolicy diff --git a/python/bitblas/base/roller/policy/common.py b/python/bitblas/base/roller/policy/common.py new file mode 100644 index 000000000000..9141550c8003 --- /dev/null +++ b/python/bitblas/base/roller/policy/common.py @@ -0,0 +1,56 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import List +import numpy as np + + +def get_all_factors(n: int) -> List[int]: + # Calculate the square root of n and round it up to the nearest integer + n0 = int(np.ceil(np.sqrt(n))) + + # Find all divisors of n that are less than n0 + val = np.where(n % np.arange(1, n0) == 0)[0] + 1 + + # If n is a perfect square, add the square root to the list of factors + mid = np.array([], dtype=int) if n0 * n0 != n else [n0] + + # Combine the factors and their corresponding larger pair factors + return [int(x) for x in np.concatenate([val, mid, n // val[::-1]])] + + +def factorize(n: int) -> List[int]: + i = 2 # Start with the smallest prime number + result = [] + + # Iterate through numbers to find factors + while n > 1: + if n % i == 0: # If i is a factor of n + n //= i # Divide n by i and keep the integer part + result.append(i) + else: + i += 1 # Try the next number + return result + + +def coalesced_factor(subtensor: List[int], tensor: List[int]) -> int: + # If the last dimension of the subtensor and tensor differ, or subtensor has only one dimension + if subtensor[-1] != tensor[-1] or len(subtensor) == 1: + return subtensor[-1] + else: + # Recursively calculate the coalesced factor for the remaining dimensions + return subtensor[-1] * coalesced_factor(subtensor[:-1], tensor[:-1]) + + +def coalesced_tensor_shape(subtensor: List[int], tensor: List[int], transaction_size: int) -> int: + # Calculate the total number of elements in the subtensor + bytes = int(np.prod(subtensor)) + + if bytes == 0: + return 0 + + # Calculate the coalesced factor for the subtensor + factor = int(coalesced_factor(subtensor, tensor)) + + # Compute the shape of the coalesced tensor + return transaction_size * bytes / min(transaction_size, factor) diff --git a/python/bitblas/base/roller/policy/default.py b/python/bitblas/base/roller/policy/default.py new file mode 100644 index 000000000000..ac85921cea3d --- /dev/null +++ b/python/bitblas/base/roller/policy/default.py @@ -0,0 +1,770 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Policy for cuda core schedule""" +import functools +import math +from queue import PriorityQueue +from typing import Iterable, Dict, List + +import numpy as np +import tvm + + +from ..arch import Arch +from ..bestfit import BestFit +from ..config import Config, Stride, TileDict +from .common import coalesced_factor, coalesced_tensor_shape, factorize, get_all_factors +from ..node import PrimFuncNode +from ..rasterization import * + + +class DefaultPolicy: + """ + Default Policy for fastdlight, a heuristic plan that tries to + minimize memory traffic and maximize parallelism.for Dlight Schedule. + """ + + def __init__(self, func: tvm.tir.PrimFunc, arch: Arch, tags: Dict = {}) -> None: + self.arch = arch + self.prim_func_node = PrimFuncNode(func, tags) + self.ordered_nodes = [self.prim_func_node] + self.output_nodes = [self.prim_func_node] + + def emit_config(self, topk: int) -> List[Config]: + base_tile = self.get_base_tile() + if base_tile is None: + return [] + + rstep_map = self._assign_reduce_step(self.prim_func_node) + smem_tile_condidates = self.dfs_smem_tile(base_tile, rstep_map) + results = [] + for td in smem_tile_condidates: + if not self.check_tile_shape_isvalid(td): + continue + + self._expand_reduce_axis(td) + for codegen_dicts in self.assign_block_size(td): + results.append(codegen_dicts) + if len(results) >= topk: + break + if len(results) >= topk: + break + return results + + def dfs_smem_tile(self, init_tile, rstep_map) -> Iterable[TileDict]: + _steps = [get_all_factors(n) for n in self.prim_func_node.get_space_dim()] + steps = [step[step.index(t) :] for step, t in zip(_steps, init_tile)] + for i in range(len(steps)): + added = list( + filter( + lambda s: s < steps[i][-1] and s > steps[i][0] and s not in steps[i], + [2, 4, 8, 16, 32], + ) + ) + steps[i].extend(added) + steps[i] = sorted(steps[i]) + visited_tiles = {} + queue = PriorityQueue() + + def prio(td: TileDict): + return (td.traffic + 1) * td.num_wave + + def add_to_queue(tile): + if tuple(tile) in visited_tiles: + return + td = self.compute_tile_dict(tile, rstep_map) + visited_tiles[tuple(tile)] = td + if td.valid: + queue.put([prio(td), tile]) + + add_to_queue(init_tile) + while not (queue.empty() or len(visited_tiles) > 2000): + _, tile = queue.get() + dim_ids = [step.index(t) for step, t in zip(steps, tile)] + for i in reversed(range(len(dim_ids))): + if dim_ids[i] + 1 < len(steps[i]): + new_tile = tile.copy() + new_tile[i] = steps[i][dim_ids[i] + 1] + add_to_queue(new_tile) + + visited_tiles = filter(lambda td: td.valid, visited_tiles.values()) + sorted_tiles = sorted(visited_tiles, key=lambda td: prio(td)) + return sorted_tiles + + def get_base_tile(self): + """ + Gets the minimum tile configuration that satisfies no redundancy in computation. + + Returns + ------- + List[int] + The base tile configuration, which is a list of 1s equal in length to the space dimensions + of the primary function node. + """ + shape = self.prim_func_node.get_space_dim() + base_tile = [1 for _ in shape] + + return base_tile + + # handles multiple output cases + def _get_output_tile_map(self, tile): + """ + Handles multiple output cases by mapping output nodes to their respective tile configurations. + + Parameters + ---------- + tile : List[int] + The tile configuration. + + Returns + ------- + Dict + A dictionary mapping the primary function node to its corresponding tile configuration + based on the output nodes' space dimensions. + """ + tile_map = {} + tile_map[self.prim_func_node] = [ + tile[i] + * self.prim_func_node.get_space_dim()[i] + // self.output_nodes[0].get_space_dim()[i] + for i in range(len(tile)) + ] + return tile_map + + def score_block_size(self, n): + """ + Scores a block size based on its efficiency and fit relative to the architecture's warp size and SM partition. + + Parameters + ---------- + n : int + The block size to score. + + Returns + ------- + Tuple[float, float] + A tuple containing two scores representing efficiency and fit, respectively. + """ + num_wrap = (n + self.arch.warp_size - 1) // self.arch.warp_size + r1 = max(num_wrap / self.arch.sm_partition, self.arch.sm_partition / num_wrap) + r2 = (num_wrap * self.arch.warp_size - n) / n + return (r1, r2) + + def get_block_size(self, n): + """ + Determines the optimal block size for a given constraint, based on scoring various factors. + + Parameters + ---------- + n : int + The constraint size. + + Returns + ------- + int + The optimal block size chosen from the factors of n, constrained by a maximum of 1024 and + scored by the `score_block_size` method. + """ + factors = get_all_factors(n) + factors = list(filter(lambda x: x <= 1024, factors)) + factor_ordered = sorted(factors, key=self.score_block_size) + return factor_ordered[0] + + def get_node_reduce_step_candidates(self, node: PrimFuncNode): + """ + Calculates reduction step candidates for each reduction axis in a PrimFuncNode. General idea : use factor first, since it does not require extra boundary check. for large prime number, which is rare case, use power of 2. + + Parameters + ---------- + node : PrimFuncNode + The node for which to calculate reduction step candidates. It contains reduction axes (raxis) + with their domains (dom.extent). + + Returns + ------- + Dict[str, List[int]] + A dictionary mapping axis variable names to lists of step candidates. For each axis in the node, + this function calculates possible step sizes. For axes with a large prime domain, it uses powers of 2 + as step candidates; for others, it uses all factors of the domain. + """ + + results = {} + for k_iter in node.raxis: + all_factors = get_all_factors(int(k_iter.dom.extent)) + if len(all_factors) == 2 and int(k_iter.dom.extent) > 64: + all_factors = [1] + while all_factors[-1] * 2 < int(k_iter.dom.extent): + all_factors.append(all_factors[-1] * 2) + results[k_iter.var.name] = all_factors + return results + + def _assign_reduce_step(self, node: PrimFuncNode): + """ + Assigns an optimal reduction step for the given PrimFuncNode. + + Parameters + ---------- + node : PrimFuncNode + The node for which the reduction step is to be assigned. + + Returns + ------- + Dict + A dictionary mapping reduction axis variable names to their optimal reduction steps. + """ + if node.reduction_block is None: + return {} + + raxis = node.raxis + tile = [1] * len(node.get_space_dim()) + all_steps = self.get_node_reduce_step_candidates(node) + + def sim(a: int, b: int): + return (2 * a * b) / (a * a + b * b) + + def _score(rstep_id): + rstep = {k: all_steps[k][rstep_id[k]] for k in rstep_id} + score = 0 + shape = node.propogate_inputs(tile, rstep=rstep) + for i, input_buffer in enumerate(node.input_buffers): + read_transaction_elements = self.arch.transaction_size[1] // ( + (node.get_buffer_dtype(input_buffer).bits + 7) // 8 + ) + score += sim( + int(coalesced_factor(shape[i], input_buffer.shape)), + read_transaction_elements, + ) + return score + + def _enlarge(rstep_id): + candidates = [] + candidates.append((rstep_id, _score(rstep_id))) + for ax in rstep_id: + if rstep_id[ax] + 1 == len(all_steps[ax]): + continue + r = rstep_id.copy() + r[ax] += 1 + candidates.append((r, _score(r))) + best = max(candidates, key=lambda x: x[1]) + return best + + # enlarge rstep to ensure read is coaleased + cur_rstep_id = {ax.var.name: 0 for ax in raxis} + cur_score = _score(cur_rstep_id) + while True: + if cur_score == 0: + break + new_rstep, new_score = _enlarge(cur_rstep_id) + if new_score <= cur_score: + break + else: + cur_rstep_id, cur_score = new_rstep, new_score + rstep = {k: all_steps[k][cur_rstep_id[k]] for k in cur_rstep_id} + return rstep + + def _expand_reduce_axis(self, td: TileDict): + """ + Expands the reduction axis in the TileDict based on shared memory limits. + + Parameters + ---------- + td : TileDict + The TileDict object to be optimized. + + Returns + ------- + None + This function modifies the TileDict in place. + """ + smem_limit = min(self.arch.max_smem_usage // td.block_per_SM, self.arch.smem_cap) + rstep_map = td.rstep_map.copy() + + def _optimize(node, rstep): + all_steps = self.get_node_reduce_step_candidates(node) + for k in all_steps: + all_steps[k] = list(filter(lambda x: x % rstep[k] == 0, all_steps[k])) + + def _score(rstep_id): + rstep = { + k.var.name: all_steps[k.var.name][rstep_id[k.var.name]] for k in node.raxis + } + score = 0 + shape = node.propogate_inputs(td.get_tile(node), rstep=rstep) + for i, input_buffer in enumerate(node.input_buffers): + score += coalesced_factor(shape[i], input_buffer.shape) + return score + + def _enlarge(rstep_id): + candidates = [] + for ax in rstep_id: + if rstep_id[ax] + 1 == len(all_steps[ax]): + continue + r = rstep_id.copy() + r[ax] += 1 + candidates.append((r, _score(r))) + if len(candidates) == 0: + return None + return max(candidates, key=lambda x: x[1])[0] + + cur_rstep_id = { + k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis + } + new_rstep_map = rstep_map.copy() + while True: + new_rstep_id = _enlarge(cur_rstep_id) + if new_rstep_id is None: + break + new_rstep_map = { + k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]] for k in node.raxis + } + old_rstep_map = td.rstep_map + td.rstep_map = new_rstep_map + smem_usage, _ = self._compute_shared_memory_usage(td) + td.rstep_map = old_rstep_map + if smem_usage > smem_limit: + break + else: + cur_rstep_id = new_rstep_id + rstep = { + k.var.name: all_steps[k.var.name][cur_rstep_id[k.var.name]] for k in node.raxis + } + return rstep + + for node in self.ordered_nodes: + if len(node.raxis) > 0: + rstep = _optimize(node, rstep_map) + rstep_map = rstep + td.rstep_map = rstep_map + td.smem_cost, td.cached_tensors_map = self._compute_shared_memory_usage(td) + + def _compute_memory_traffic(self, output_tile): + """ + Computes the memory traffic for a given output tile configuration. + + Parameters + ---------- + output_tile : List[int] + The output tile configuration. + + Returns + ------- + Tuple[int, Dict] + The total memory traffic and a map of operation tiles. + """ + op_tile_map = self._get_output_tile_map(output_tile) + traffic = 0 + for node in reversed(self.ordered_nodes): + tile = op_tile_map[node] + input_shapes = node.propogate_inputs(tile) + output_shapes = node.propogate_outputs(tile) + for i, buffer in enumerate(node.input_buffers): + nbytes = (node.get_buffer_dtype(buffer).bits + 7) // 8 + read_transaction_elements = self.arch.transaction_size[1] // nbytes + traffic += ( + coalesced_tensor_shape(input_shapes[i], buffer.shape, read_transaction_elements) + * nbytes + ) + for i, buffer in enumerate(node.output_buffers): + nbytes = (node.get_buffer_dtype(buffer).bits + 7) // 8 + write_transaction_elements = self.arch.transaction_size[0] // nbytes + traffic += ( + coalesced_tensor_shape( + output_shapes[i], buffer.shape, write_transaction_elements + ) + * nbytes + ) + return traffic, op_tile_map + + def infer_node_smem_usage(self, td: TileDict, node: PrimFuncNode): + """ + Infers the shared memory usage of a node given a TileDict configuration. + + Parameters + ---------- + td : TileDict + The TileDict object containing the tile configuration. + node : PrimFuncNode + The node for which to infer the shared memory usage. + + Returns + ------- + int + The estimated amount of shared memory used by the node. + """ + return node.footprint(td.get_tile(node), td.get_rstep(node), td.tensor_strides_map[node]) + + def _compute_shared_memory_usage(self, td: TileDict): + """ + Computes the stride map for a given node and TileDict configuration. + + Parameters + ---------- + node : PrimFuncNode + The node for which to compute the stride map. + td : TileDict + The TileDict object containing the tile configuration. + + Returns + ------- + Tuple[Dict, Dict] + The output strides and tensor strides. + """ + self._compute_stride_map(td) + allocator = BestFit() + block_map = {} + cached_tensors_map = {} + + node_internal_bytes, cached_tensors_map[self.prim_func_node] = self.infer_node_smem_usage( + td, self.prim_func_node + ) + block = allocator.malloc(node_internal_bytes) + allocator.free(block) + assert len(block_map) == 0 + return allocator.limit, cached_tensors_map + + def compute_node_stride_map(self, node: PrimFuncNode, td: TileDict): + """ + Computes the stride map for a given node based on the TileDict configuration. + + Parameters + ---------- + node : PrimFuncNode + The node for which to compute the stride map. + td : TileDict + The TileDict object containing the tile configuration. + + Returns + ------- + Tuple[Dict, Dict] + A tuple of dictionaries containing the output strides and tensor strides. + """ + output_strides = { + int(i + len(node.input_buffers)): Stride() for i, _ in enumerate(node.output_buffers) + } + tensor_strides = {} + return output_strides, tensor_strides + + def _compute_stride_map(self, td: TileDict): + """ + Computes the stride map for all nodes in a TileDict. + + Parameters + ---------- + td : TileDict + The TileDict object for which to compute the stride maps. + + Returns + ------- + None + This function updates the TileDict object in place with the computed stride maps. + """ + output_strides_map = {} + tensor_strides_map = {} + for node in self.ordered_nodes: + output_strides_map[node], tensor_strides_map[node] = self.compute_node_stride_map( + node, td + ) + td.output_strides_map, td.tensor_strides_map = output_strides_map, tensor_strides_map + + def compute_tile_dict(self, output_tile: List[int], rstep_map) -> TileDict: + """ + Computes and returns a TileDict object for a given output tile configuration and reduction step map. + + Parameters + ---------- + output_tile : List[int] + The output tile configuration. + rstep_map : Dict + The reduction step map. + + Returns + ------- + TileDict + A TileDict object containing the computed tile configuration, memory traffic, shared memory cost, + grid size, and other related parameters. + """ + td = TileDict(output_tile) + td.rstep_map = rstep_map + td.traffic, td.tile_map = self._compute_memory_traffic(output_tile) + td.smem_cost, td.cached_tensors_map = self._compute_shared_memory_usage(td) + if td.smem_cost > self.arch.smem_cap: + td.valid = False + return td + output_shape = self.output_nodes[0].get_space_dim() + td.grid_size = int(np.prod([(y + x - 1) // x for x, y in zip(output_tile, output_shape)])) + # estimated reg usage + reg_usage = int( + 2 + * max( + [ + np.prod(td.get_tile(node)) * node.get_dtype().bits / 32 + for node in self.ordered_nodes + ] + ) + ) + if reg_usage > self.arch.reg_cap: + td.valid = False + return td + td.block_per_SM = min( + self.arch.max_smem_usage // max(td.smem_cost, 1), + self.arch.reg_cap // max(reg_usage, 1), + self.arch.sm_partition, + ) + td.num_wave = int(np.ceil(td.grid_size / int(td.block_per_SM * self.arch.compute_max_core))) + return td + + def check_tile_shape_isvalid(self, td: TileDict) -> bool: + """ + Checks if the tile shapes in the TileDict are valid for the nodes in this context. + + Parameters: + - td (TileDict): The TileDict object containing tile shapes and other configurations. + + Returns: + - bool: True if all tile shapes are valid, False otherwise. + """ + for node in self.ordered_nodes: + if np.prod(td.get_tile(node)) == 0: + return False + node_grid_size = np.prod( + [(y + x - 1) // x for x, y in zip(td.get_tile(node), node.get_space_dim())] + ) + if node_grid_size != td.grid_size: + return False + if ( + hasattr(node, "reduce_op") + and node.reduce_op is not None + and len(node.reduce_op.axis) == len(td.output_tile) + ): + for i, tile_extent in enumerate(td.output_tile): + if node.reduce_op.axis[i].dom.extent % tile_extent: + return False + + return True + + def recommend_block_size(self, td: TileDict) -> List[int]: + """ + Recommends optimal block sizes based on the TileDict configuration. + + Parameters + ---------- + td : TileDict + The TileDict object containing the tile configuration. + + Returns + ------- + List[int] + A list of recommended block sizes sorted based on their score. + """ + node_space_sizes = [int(np.prod(td.get_tile(node))) for node in self.ordered_nodes] + max_block_size = functools.reduce(math.gcd, node_space_sizes) + + if max_block_size < self.arch.warp_size * self.arch.sm_partition and max_block_size == min( + node_space_sizes + ): + node_reduce_sizes = [ + int(np.prod(list(td.get_rstep(node).values()))) for node in self.ordered_nodes + ] + total_sizes = [x * y for x, y in zip(node_space_sizes, node_reduce_sizes)] + max_possible_size = functools.reduce(math.gcd, total_sizes) + possible_block_sizes = list( + filter( + lambda x: x % max_block_size == 0 and x <= 1024, + get_all_factors(max_possible_size), + ) + ) + possible_block_sizes = list( + filter( # either be a factor of space or cover fully cover the space + lambda x: all([x % s == 0 or s % x == 0 for s in node_space_sizes]), + possible_block_sizes, + ) + ) + factor_ordered = sorted(possible_block_sizes, key=self.score_block_size) + return factor_ordered + else: + possible_block_sizes = get_all_factors(max_block_size) + possible_block_sizes = list(filter(lambda x: x <= 1024, possible_block_sizes)) + factor_ordered = sorted(possible_block_sizes, key=self.score_block_size) + return factor_ordered + + def assign_block_size(self, td: TileDict, topk=1): + """ + Assigns block sizes to the TileDict based on the recommended block sizes. + + Parameters + ---------- + td : TileDict + The TileDict object to assign block sizes to. + topk : int, optional + The number of top block sizes to consider. + + Yields + ------- + Dict + The block size assignment for the primary function node. + """ + block_size_ordered = self.recommend_block_size(td) + for block_size in block_size_ordered: + result = {} + failed = False + result = self._assign_block_size(self.prim_func_node, td, block_size) + if result is None: + failed = True + break + if failed: + continue + else: + yield result + topk -= 1 + if topk == 0: + break + + def _assign_block_size(self, node: PrimFuncNode, td: TileDict, block_size: int): + """ + Assigns a block size to a given PrimFuncNode based on the TileDict configuration and the specified block size. + + Parameters + ---------- + node : PrimFuncNode + The node to assign the block size to. + td : TileDict + The TileDict object containing the tile configuration. + block_size : int + The block size to be assigned. + + Returns + ------- + Config + A Config object containing the assigned block size and other related settings. + """ + tile, rsteps = td.get_tile(node), td.get_rstep(node) + factors = factorize(block_size) + cur_threads = [1 for _ in tile] + reduce_thread = {k: 1 for k in rsteps} + ndim = len(tile) + + def _score(node, thread): # small is better + score = 0 + block_tile = [int(np.ceil(tile[i] / thread[i])) for i in range(ndim)] + shape = node.propogate_inputs(block_tile) + for i, buffer in enumerate(node.input_buffers): + score += np.prod(shape[i]) / self.arch.bandwidth[1] + for buffer in node.output_buffers: + score += coalesced_tensor_shape(thread, buffer.shape, 8) / self.arch.bandwidth[0] + return score + + for factor in reversed(factors): + score_map = {} + for i in range(ndim): + if cur_threads[i] >= tile[i]: + continue + if (tile[i] % (cur_threads[i] * factor)) != 0: + continue + cur_threads[i] *= factor + score_map[i] = (_score(node, cur_threads), i) + cur_threads[i] //= factor + if len(score_map) > 0: + # assign to space axis + dim_order = sorted(score_map.keys(), key=lambda x: score_map[x]) + cur_threads[dim_order[0]] *= factor + else: + # assign to reduce axis + target_ax = None + for ax, ax_len in reversed(list(rsteps.items())): + if ax_len % (reduce_thread[ax] * factor) == 0: + target_ax = ax + break + assert target_ax + reduce_thread[target_ax] *= factor + + codegen_dict = Config() + codegen_dict.block = tile + codegen_dict.thread = cur_threads + codegen_dict.rstep = [rsteps[ax.var.name] for ax in node.raxis] + codegen_dict.reduce_thread = [reduce_thread[ax.var.name] for ax in node.raxis] + codegen_dict.cached_tensors = td.cached_tensors_map[node] + codegen_dict.rasterization_plan = self.plan_rasterization(td) + + if node.get_dtype().bits == 16: # set step=2 for 16bit case to ensure coalesced access + codegen_dict._step = [1 for _ in range(ndim)] + for i in reversed(range(ndim)): + if codegen_dict.block[i] // codegen_dict.thread[i] % 2 == 0: + codegen_dict._step[i] = 2 + break + elif node.get_dtype().bits == 8: # set step=4 for 8bit case to ensure coalesced access + codegen_dict._step = [1 for _ in range(ndim)] + for i in reversed(range(ndim)): + if codegen_dict.block[i] // codegen_dict.thread[i] % 4 == 0: + codegen_dict._step[i] = 4 + break + # Plan vectorize + codegen_dict.vectorize = self._plan_vectorize(node, td, block_size) + codegen_dict.arch = self.arch + codegen_dict.opt_shapes = self.prim_func_node.get_tag("opt_shapes") + return codegen_dict + + def _plan_vectorize(self, node: PrimFuncNode, td: TileDict, block_size: int): + """ + Plans vectorization for a given PrimFuncNode based on the TileDict configuration and block size. + + Parameters + ---------- + node : PrimFuncNode + The node for which to plan vectorization. + td : TileDict + The TileDict object containing the tile configuration. + block_size : int + The block size used for vectorization planning. + + Returns + ------- + Dict + A dictionary mapping tensors to their vectorization size. + """ + + def is_cont(shape, vec): + if len(shape) == 0: + return vec == 1 + last = shape[-1] + if last == 1: + return is_cont(shape[0:-1], vec // last) + else: + return last % vec == 0 + + def is_shape_aligned(shape, factor): + return int(np.prod(shape)) % factor == 0 + + def is_type_allowed(dtype, vec): + return dtype.bits * vec <= 128 + + vectorize_sizes = [16, 8, 4, 2] + dtypes = node.get_reduce_inputs_dtype() + shapes = node.propogate_reduction_inputs(td.get_tile(node), td.get_rstep(node)) + vectorize_result = {} + for tensor, shape in shapes.items(): + for v in vectorize_sizes: + if ( + is_shape_aligned(shape, block_size * v) + and is_cont(shape, v) + and is_type_allowed(dtypes[tensor], v) + ): + vectorize_result[tensor] = v + break + return vectorize_result + + def plan_rasterization(self, td: TileDict): # pylint: disable=unused-argument + """ + Plans the rasterization for the given TileDict. This function is not implemented yet. + + Parameters + ---------- + td : TileDict + The TileDict object to plan rasterization for. + + Raises + ------- + RasterRationPlan + This function is not implemented yet. + """ + return NoRasterization() diff --git a/python/bitblas/base/roller/policy/tensorcore.py b/python/bitblas/base/roller/policy/tensorcore.py new file mode 100644 index 000000000000..90f432ebc594 --- /dev/null +++ b/python/bitblas/base/roller/policy/tensorcore.py @@ -0,0 +1,349 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Policy for tensorcore schedule""" +import tvm +from typing import Dict, List, Tuple +import numpy as np + +from ..arch import Arch +from ..config import Config, Stride, TileDict, IntrinInfo +from ..node import PrimFuncNode +from .common import coalesced_factor, factorize, get_all_factors +from .default import DefaultPolicy +from ..rasterization import * + + +class TensorCorePolicy(DefaultPolicy): + def __init__(self, func: tvm.tir.PrimFunc, arch: Arch, tags: Dict = {}) -> None: + super().__init__(func, arch, tags) + # this is the trick for wmma. + # However, for int8 mma, the wmma_k should be 32. + self.wmma_k = 16 + self.pipeline_stage: int = 1 + self.use_async_copy: bool = False + self._legalize_info() + + def _legalize_info(self): + pipleline_stage = self.prim_func_node.get_tag("pipeline_stage") + if pipleline_stage: + self.pipeline_stage = pipleline_stage + else: + if self.arch.compute_capability == "sm_80": + self.pipeline_stage = 2 + else: + self.pipeline_stage = 1 + use_async_copy = self.prim_func_node.get_tag("use_async_copy") + if use_async_copy: + self.use_async_copy = use_async_copy + else: + if self.arch.compute_capability == "sm_80": + self.use_async_copy = 1 + else: + self.use_async_copy = 0 + + def _compute_tc_strides( + self, node: PrimFuncNode, tile: List[int], rstep: Dict[str, int] = {} + ) -> Tuple[Stride, Stride, Stride]: + # strides was used for shared memory padding. which is necessary for avoiding + # shared memory load bank conflict when we do not applying tensorcore layout. + shapes = node.propogate_reduction_inputs(tile, rstep) + AS_shape, BS_shape = shapes.values() + CS_shape = tile + A_ax_m, A_ax_k, B_ax_k, B_ax_n, C_ax_m, C_ax_n = node.infer_tensorcore_axis() + + # applying strides + # TODO(leiwang1999): offset should be dynamically set. we can use tag -> enable_offset to control this option.. + offset = 8 + A_high_ax = min(A_ax_m, A_ax_k) + B_high_ax = min(B_ax_n, B_ax_k) + C_high_ax = min(C_ax_m, C_ax_n) + A_stride = Stride( + stride=np.prod(AS_shape[A_high_ax + 1 :]) + offset, ax=A_high_ax + ) + B_stride = Stride( + stride=np.prod(BS_shape[B_high_ax + 1 :]) + offset, ax=B_high_ax + ) + C_stride = Stride( + stride=np.prod(CS_shape[C_high_ax + 1 :]) + offset, ax=C_high_ax + ) + return A_stride, B_stride, C_stride + + def infer_node_smem_usage(self, td: TileDict, node: PrimFuncNode): + value, cached_tensors = super().infer_node_smem_usage(td, node) + value *= self.pipeline_stage + return value, cached_tensors + + def _assign_reduce_step(self, node): + if not node.get_tag("tensorcore_config"): + return super()._assign_reduce_step(node) + # get reduce input size + target_transaction = self.arch.transaction_size[0] * 2 + # 512 bytes // type bits + reduce_input_dtype = node.get_buffer_dtype( + node.block_analyzer.get_input_buffers(node.reduction_block)[0] + ) + basic = (target_transaction * 8) // reduce_input_dtype.bits + + result = {} + for iter_info in node.raxis: + iter_name = iter_info.var.name + iter_dom = iter_info.dom.extent + if iter_dom % 16 > 0: + result[iter_name] = ( + 16 if iter_dom < basic else basic + ) # for the case of padding + elif iter_dom % basic == 0: + result[iter_name] = basic + else: + return super()._assign_reduce_step(node) + return result + + def _expand_reduce_axis(self, td: TileDict): + # For tensorcore program, if we got a small tilesize, we should consider expand the reduce axis + # to improve compute efficiency. + def _check_small_tile(td: TileDict): + minimal_threadhold = 32 + for node in self.ordered_nodes: + tile = td.get_tile(node) + if any([t <= minimal_threadhold for t in tile]): + return True + return False + + if not _check_small_tile(td): + return None + + smem_limit = min( + self.arch.max_smem_usage // td.block_per_SM, self.arch.smem_cap + ) + rstep_map = td.rstep_map.copy() + + def _optimize(node, rstep): + all_steps = self.get_node_reduce_step_candidates(node) + # todo(lei): optimzie the all_steps enlarge policy to be a multiple of the original all_steps[k] + for k in all_steps: + all_steps[k] = list(filter(lambda x: x % rstep[k] == 0, all_steps[k])) + if any([v == [] for v in all_steps.values()]): + return rstep + + def _shared_memory_usage(td: TileDict): + return node.footprint( + td.output_tile, new_rstep_map, td.tensor_strides_map[node] + ) + + def _score(rstep_id): + rstep = { + k.var.name: all_steps[k.var.name][rstep_id[k.var.name]] + for k in node.raxis + } + score = 0 + shape = node.propogate_inputs(td.get_tile(node), rstep=rstep) + for i, input_buffer in enumerate(node.input_buffers): + score += coalesced_factor(shape[i], input_buffer.shape) + return score + + def _enlarge(rstep_id): + candidates = [] + for ax in rstep_id: + if rstep_id[ax] + 1 == len(all_steps[ax]): + continue + r = rstep_id.copy() + r[ax] += 1 + candidates.append((r, _score(r))) + if len(candidates) == 0: + return None + return max(candidates, key=lambda x: x[1])[0] + + cur_rstep_id = { + k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) + for k in node.raxis + } + new_rstep_map = rstep_map.copy() + while True: + new_rstep_id = _enlarge(cur_rstep_id) + if new_rstep_id is None: + break + new_rstep_map = { + k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]] + for k in node.raxis + } + old_rstep_map = td.rstep_map + td.rstep_map = new_rstep_map + smem_usage, _ = _shared_memory_usage(td) + td.rstep_map = old_rstep_map + if smem_usage > smem_limit: + break + else: + cur_rstep_id = new_rstep_id + rstep = { + k.var.name: all_steps[k.var.name][cur_rstep_id[k.var.name]] + for k in node.raxis + } + return rstep + + for node in self.ordered_nodes: + if len(node.raxis) > 0: + rstep = _optimize(node, rstep_map) + rstep_map = rstep + td.rstep_map = rstep_map + td.smem_cost, td.cached_tensors_map = self._compute_shared_memory_usage(td) + return + + def get_node_reduce_step_candidates(self, node): + if not node.get_tag("tensorcore_config"): + return super().get_node_reduce_step_candidates(node) + else: + # must be a a multiple of wmma_k + return { + k.var.name: [ + x * self.wmma_k + for x in get_all_factors(int(k.dom.extent) // self.wmma_k) + ] + for k in node.raxis + } + + def check_tile_shape_isvalid(self, td: TileDict): + for node in self.ordered_nodes: + if node.get_tag("tensorcore_config"): + ax_m, ax_n = node.get_tag("tensorcore_config") + block_m, block_n = td.tile_map[node][ax_m], td.tile_map[node][ax_n] + # check the tile size is valid + wmma_invalid = [ + block_m < wmma_m or block_n < wmma_n + for wmma_m, wmma_n in self.arch.get_avaliable_tensorintrin_shapes() + ] + if all(wmma_invalid): + return False + if any( + [y % x for x, y in zip(td.tile_map[node], node.get_space_dim())] + ): + return False + return super().check_tile_shape_isvalid(td) + + def _can_implement_layout(self, node: PrimFuncNode, td: TileDict): + # Not implemented yet + # This function is used to check whether we can implement swizzling + # layout under this tile config + return False + + def compute_node_stride_map(self, node: PrimFuncNode, td: TileDict): + if not node.get_tag("tensorcore_config"): + return super().compute_node_stride_map(node, td) + use_layout = self._can_implement_layout(node, td) + + AS_stride, BS_stride, C_stride = self._compute_tc_strides( + node, td.get_tile(node), td.get_rstep(node) + ) + A_stride, B_stride, _ = self._compute_tc_strides(node, td.get_tile(node)) + tensor_strides = {} + output_strides = { + int(i + len(node.input_buffers)): Stride() + for i, _ in enumerate(node.output_buffers) + } + tensor_strides = {} + # when connected to shared input, should use full stride without rstep + for i, (stride, stride_full) in enumerate( + zip([AS_stride, BS_stride], [A_stride, B_stride]) + ): + if use_layout: + continue + _ = node.block_analyzer.get_input_buffers(node.reduction_block)[i].name + # TODO(lei): should dig further for shared memory connection case. + + return output_strides, tensor_strides + + def _assign_block_size(self, node: PrimFuncNode, td: TileDict, block_size: int): + if not node.get_tag("tensorcore_config"): + return super()._assign_block_size(node, td, block_size) + ax_m, ax_n = node.get_tag("tensorcore_config") + if block_size % self.arch.warp_size != 0: + return None + tile, rsteps = td.get_tile(node), td.get_rstep(node) + warps = block_size // self.arch.warp_size + ndim = len(tile) + + wmma = self.arch.get_avaliable_tensorintrin_shapes()[-1] + wmma_tile = [1 for _ in range(ndim)] + wmma_tile[ax_m] = wmma[0] + wmma_tile[ax_n] = wmma[1] + + space = [tile[i] // wmma_tile[i] for i in range(ndim)] + if tile[ax_m] < wmma_tile[ax_m] or tile[ax_n] < wmma_tile[ax_n]: + # allow pad, otherwise, we can not get a valid tile shape + return None + if np.prod(space) % warps != 0: + return None + factors = factorize(np.prod(space) // warps) + + def _score(node, thread): # small is better + score = 0 + block_tile = [int(np.ceil(tile[i] / thread[i])) for i in range(ndim)] + shape = node.propogate_inputs(block_tile) + for i, _ in enumerate(node.input_buffers): + score += np.prod(shape[i]) / self.arch.bandwidth[1] + return score + + warp_tile = wmma_tile.copy() + for factor in reversed(factors): + score_map = {} + for i in range(ndim): + if tile[i] % (warp_tile[i] * factor) != 0: + continue + warp_tile[i] *= factor + score_map[i] = (_score(node, warp_tile), i) + warp_tile[i] //= factor + if len(score_map) == 0: + return None + dim_order = sorted(score_map.keys(), key=lambda x: score_map[x]) + warp_tile[dim_order[0]] *= factor + + codegen_dict = Config() + codegen_dict.block = tile + codegen_dict.warp = warp_tile + codegen_dict.use_tc = True + codegen_dict.pipeline_stage = self.pipeline_stage + codegen_dict.use_async = self.use_async_copy + codegen_dict.rstep = [int(rsteps[ax.var.name]) for ax in node.raxis] + codegen_dict.cached_tensors = td.cached_tensors_map[node] + codegen_dict.rasterization_plan = self.plan_rasterization(td) + + intrin_info = node.get_tag("intrin_info") + if intrin_info: + codegen_dict.intrin_info = IntrinInfo(**intrin_info) + # smem capacity + if td.smem_cost > self.arch.smem_cap: + codegen_dict.shared_scope = "shared.dyn" + + codegen_dict.complete_config(node) + codegen_dict.vectorize = self._plan_vectorize( + self.prim_func_node, td, block_size + ) + codegen_dict.arch = self.arch + codegen_dict.opt_shapes = self.prim_func_node.get_tag("opt_shapes") + return codegen_dict + + def plan_rasterization(self, td: TileDict): + conditions = [] + # only support single node for now + conditions.append(len(self.ordered_nodes) > 1) + # small op don't need imporve l2 cache + conditions.append(td.num_wave < 4) + # only on Ampere+ arch + conditions.append(self.arch.compute_capability < "80") + + def _check_memory_size(): + overall_gmem_size_in_bytes: int = 0 + for node in self.ordered_nodes: + for arg in node.args: + overall_gmem_size_in_bytes += ( + int(np.prod(arg.shape)) * tvm.DataType(arg.dtype).bits // 8 + ) + return overall_gmem_size_in_bytes < (self.arch.l2_cache_size_bytes * 4) + + conditions.append(_check_memory_size()) + if any(conditions): + return NoRasterization() + # otherwise, simply provide a block rasterization factor + raster_factor = int(self.arch.compute_max_core**0.5) + + return Rasterization2DColumn(raster_factor) diff --git a/python/bitblas/base/roller/rasterization.py b/python/bitblas/base/roller/rasterization.py new file mode 100644 index 000000000000..a15b0d8dc3ab --- /dev/null +++ b/python/bitblas/base/roller/rasterization.py @@ -0,0 +1,85 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Rasteration Plan For L2 Cache Locality""" + +from typing import List + + +class Rasterization: + def __init__(self) -> None: + pass + + def get_code(self) -> List[str]: + raise NotImplementedError() + + +class NoRasterization(Rasterization): + def __init__(self) -> None: + super().__init__() + + def __repr__(self) -> str: + return "" + + def get_code(self) -> List[str]: + return [] + + +class Rasterization2DRow(Rasterization): + """ + Rasterization by Row, each Row line width is panel_width + _________ + _________| + |_________ + __________| + """ + + def __init__(self, panel_width=4) -> None: + super().__init__() + self.panel_width_ = panel_width + + def __repr__(self) -> str: + return f"" + + def get_code(self) -> List[str]: + raise NotImplementedError() + + +class Rasterization2DColumn(Rasterization): + """ + Rasterization by Column, each column line width is panel_width + _ + | | | | + | | | | + |_| |_| + """ + + def __init__(self, panel_width=4) -> None: + super().__init__() + self.panel_width_ = panel_width + + def __repr__(self) -> str: + return f"" + + def get_device_function(self) -> str: + return """ +__device__ dim3 rasterization2DColumn(const int panel_width) { + const auto baseBlockIdx = blockIdx.x + gridDim.x *blockIdx.y; + const auto totalPanel = (gridDim.x * gridDim.y +panel_width * gridDim.x - 1) / (panel_width * gridDim.x); + const auto totalBlock = gridDim.x * gridDim.y; + const auto panelIdx = baseBlockIdx / (panel_width *gridDim.x); + const auto strideLd = panelIdx + 1 < totalPanel ?panel_width : (totalBlock - panelIdx * (panel_width *gridDim.x)) / gridDim.x; + const auto bx = (panelIdx & 1) ? gridDim.x -(baseBlockIdx - panelIdx * panel_width * gridDim.x) /strideLd - 1 : (baseBlockIdx - panelIdx * panel_width *gridDim.x) / strideLd; + const auto by = (baseBlockIdx - panelIdx * panel_width *gridDim.x) % strideLd + panelIdx * panel_width; + const auto bz = blockIdx.z; + + dim3 blockIdx(bx, by, bz); + return blockIdx; +} + """ + + def get_code(self) -> List[str]: + return [ + self.get_device_function(), + "const dim3 blockIdx(rasterization2DColumn({});".format(self.panel_width_), + ] diff --git a/python/bitblas/base/roller/shape_inference/__init__.py b/python/bitblas/base/roller/shape_inference/__init__.py new file mode 100644 index 000000000000..188aa0bb70a7 --- /dev/null +++ b/python/bitblas/base/roller/shape_inference/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from .tir import get_analyzer_by_tir # pylint: disable=unused-import diff --git a/python/bitblas/base/roller/shape_inference/common.py b/python/bitblas/base/roller/shape_inference/common.py new file mode 100644 index 000000000000..730bbbeef4c8 --- /dev/null +++ b/python/bitblas/base/roller/shape_inference/common.py @@ -0,0 +1,66 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from collections import OrderedDict +from typing import Dict, List + +from tvm import arith + + +class Statement(): + def __init__(self, output: str, dependent_region: dict, var_map: OrderedDict, range_map: OrderedDict): + self.output = output + self.dependent_region = dependent_region + self.var_map = var_map + self.range_map = range_map + +def _merge_two_bounds(x: arith.ConstIntBound, y: arith.ConstIntBound): + return arith.ConstIntBound(min(x.min_value, y.min_value), max(x.max_value, y.max_value)) + +class InputShapeInference(): + def __init__(self, deps: List[Statement]): + self.deps = deps + + def _infer(self, shape: Dict[str, List[arith.ConstIntBound]], rstep: Dict[str, int]): + shape = shape.copy() + ana = arith.Analyzer() + for dep in reversed(self.deps): + for var, bound in zip(dep.var_map.values(), shape[dep.output]): + ana.update(var, bound) + for var, bound in dep.range_map.items(): + if var.name in rstep: + bound = arith.ConstIntBound(0, min(bound.max_value, rstep[var.name] - 1)) + ana.update(var, bound) + for name, regions in dep.dependent_region.items(): + for region in regions: + bounds = [ana.const_int_bound(index) for index in region] + if name in shape: # simply merge two bounds + bounds = [_merge_two_bounds(x, y) for x, y in zip(shape[name], bounds)] + shape[name] = bounds + + for name, bounds in shape.items(): + shape[name] = [c.max_value - c.min_value + 1 for c in bounds] + return shape + + def infer(self, shape, rstep: Dict[str, int] = {}): + if isinstance(shape, (list, tuple)): + shape = {"output0" : [arith.ConstIntBound(0, val - 1) for val in shape]} + shape = self._infer(shape, rstep) + return shape + + def get_input_exprs(self, output_exprs): + result = output_exprs.copy() + ana = arith.Analyzer() + for dep in reversed(self.deps): + for var, expr in zip(dep.var_map.values(), result[dep.output]): + ana.bind(var, expr) + for var in dep.range_map: + ana.bind(var, 0) + for name, regions in dep.dependent_region.items(): + if name in result: + continue + region = regions[0] + input_expr = [ana.simplify(index) for index in region] + result[name] = input_expr + return result + diff --git a/python/bitblas/base/roller/shape_inference/tir.py b/python/bitblas/base/roller/shape_inference/tir.py new file mode 100644 index 000000000000..35bf0b7d864f --- /dev/null +++ b/python/bitblas/base/roller/shape_inference/tir.py @@ -0,0 +1,399 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import Dict, List, Tuple, Set, Mapping +from tvm.tir.schedule.schedule import BlockRV +from tvm.ir import structural_equal +from tvm import arith, tir + + +class Statement: + def __init__(self, block_analyzer, block: BlockRV): + self.block_analyzer = block_analyzer + self.block = block + # assume one tir block only has one output buffer + self.dep_name = block_analyzer.get_output_buffers(block)[0].name + self.dependent_region = _extract_dependent_region(block_analyzer, block) + + self.reverse_bound_inference = {} + + def make_reverse(self, input_name: str, input_iter: List[tir.PrimExpr]): + if len(self.block_analyzer.get_reduce_axis(self.block)) > 0: + return None + if len(self.dependent_region[input_name]) != 1: + return None + indices = self.dependent_region[input_name][0] + iter_map_range = { + _iter.var: _iter.dom for _iter in self.block_analyzer.get_spatial_axis(self.block) + } + iter_map_result = arith.detect_iter_map( + indices, + iter_map_range, + check_level=arith.iter_affine_map.IterMapLevel.Surjective, + simplify_trivial_iterators=False, + ) + if len(iter_map_result.errors) > 0: + return None + results = arith.iter_affine_map.inverse_affine_iter_map(iter_map_result.indices, input_iter) + output_indices = [] + for _iter in self.block_analyzer.get_spatial_axis(self.block): + if _iter.var in results: + output_indices.append(results[_iter.var]) + else: + # not Bijective mapping case + output_indices.append(tir.Var("undefined", dtype="int32") % int(_iter.dom.extent)) + return output_indices + + +def _merge_two_bounds(x: arith.ConstIntBound, y: arith.ConstIntBound): + return arith.ConstIntBound(min(x.min_value, y.min_value), max(x.max_value, y.max_value)) + + +class TensorDepNode(object): + """ + For tensor dependency analysis. + """ + + def __init__(self, name): + self.name = name + self._next = [] + self._prev = [] + + def add_next(self, node): + self._next.append(node) + self.deduplicate(self._next) + + def add_prev(self, node): + self._prev.append(node) + self.deduplicate(self._prev) + + def deduplicate(self, lst): + seen = set() + lst[:] = [n for n in lst if not (n in seen or seen.add(n))] + + def __str__(self): + return self.name + + def __repr__(self): + return self.name + + +class DependencyAnalysis(object): + def __init__(self, deps): + self.deps = deps + # issue: duplicate name when we have two same ops. + self.name2dep = self._construct_unique_name2dep(deps) + self.mapping = {} # name -> TensorDepNode + + def _construct_unique_name2dep(self, deps): + """ + This is a workaround for the issue that we have two same ops' fuse case. + See https://github.com/apache/tvm/issues/16433 + """ + _names:Set = set() + name2dep:Mapping = {} + for dep in deps: + output_buffer = dep.block_analyzer.get_output_buffers(dep.block)[0] + base_name = output_buffer.name + if base_name not in _names: + _names.add(base_name) + else: + i = 1 + while f"{base_name}_{i}" in _names: + i += 1 + base_name = f"{base_name}_{i}" + _names.add(base_name) + name2dep[base_name] = dep + return name2dep + + def get_or_create_node(self, name): + if name not in self.mapping: + self.mapping[name] = TensorDepNode(name) + return self.mapping[name] + + def traverse_dependencies(self, compute): + if isinstance(compute, Statement): + node = self.get_or_create_node( + compute.block_analyzer.get_output_buffers(compute.block)[0].name + ) + # Loop through input tensors + for input_buffer in compute.block_analyzer.get_input_buffers(compute.block): + # Get the input node + input_node = self.traverse_dependencies(input_buffer) + input_node.add_next(node) + node.add_prev(input_node) + elif isinstance(compute, tir.Buffer): + node = self.get_or_create_node(compute.name) + return node + + def analyze(self): + # Starting point for traversal + for _, compute in self.name2dep.items(): + self.traverse_dependencies(compute) + + def print_dependencies(self): + for name, node in self.mapping.items(): + print(f"{name} depends on {', '.join([prev.name for prev in node._prev])}") + + def find_path_from_source(self, start_name, target_name): + """ + Finds the path (if it exists) from a starting node (source) to a target node. + Returns the path as a list of nodes. + """ + visited = set() + path = [] + if self._find_path_recursive(self.mapping[start_name], target_name, visited, path): + return path + return [] + + def _find_path_recursive(self, current_node, target_name, visited, path): + """ + Recursive helper function for find_path_from_source. + """ + if current_node.name == target_name: + path.append(current_node) + return True + + if current_node.name in visited: + return False + + visited.add(current_node.name) + path.append(current_node) + + for next_node in current_node._next: + if self._find_path_recursive(next_node, target_name, visited, path): + return True + + path.pop() + return False + + +class InputShapeInference: + def __init__(self, deps: List[Statement]): + self.deps = deps + self.target_mapping = {} + self.buffer_mapping = {} + self.reduce_axes = [] + for dep in self.deps: + for ax in dep.block_analyzer.get_reduce_axis(dep.block): + self.reduce_axes.append(ax) + self.dep_analysis = DependencyAnalysis(self.deps) + self.dep_analysis.analyze() + + def construct_dependency_target(self, targets: Tuple[str]): + if targets in self.target_mapping: + return self.target_mapping[targets] + # should be buffer name instead of block name + name2dep = { + dep.block_analyzer.get_output_buffers(dep.block)[0].name: dep for dep in self.deps + } + mapping = {} + input_vars = [] + for target in targets: + vars = [ + iter.var + for iter in name2dep[target].block_analyzer.get_spatial_axis(name2dep[target].block) + ] + input_vars.append(vars) + mapping[target] = [vars] + ana = arith.Analyzer() + + for dep in self.deps: + for name in dep.dependent_region: + if name not in mapping: + continue + dep_name = dep.dep_name + indices = mapping[name][0] + output_indices = dep.make_reverse(name, indices) + if dep_name in targets: + continue + if dep_name not in mapping: + mapping[dep_name] = [output_indices] + elif not region_exist_in_list(output_indices, mapping[dep_name]): + mapping[dep_name].append(output_indices) + + for dep in reversed(self.deps): + indices_list = mapping[dep.dep_name] + ax_vars = [iter.var for iter in dep.block_analyzer.get_spatial_axis(dep.block)] + for input_name, regions in dep.dependent_region.items(): + if input_name in targets: + continue + if input_name not in mapping: + mapping[input_name] = [] + for indices in indices_list: + for region in regions: + vmap = { + k: (tir.Cast(k.dtype, v) if v.dtype != k.dtype else v) + for k, v in zip(ax_vars, indices) + } + region = [ + ana.simplify(tir.stmt_functor.substitute(ax, vmap)) for ax in region + ] + if not region_exist_in_list(region, mapping[input_name]): + mapping[input_name].append(region) + buffers = [] + for dep in self.deps: + for buffer in dep.block_analyzer.get_buffers(dep.block): + buffers.append(buffer) + + for buffer in buffers: + self.buffer_mapping[buffer.name] = buffer + + self.target_mapping[targets] = input_vars, mapping + return input_vars, mapping + + def infer( + self, shape: Dict[str, List[arith.ConstIntBound]], rstep: Dict[str, int] = {}, targets=None + ): + compute_targets = tuple(shape.keys()) + input_vars, mapping = self.construct_dependency_target(compute_targets) + ana = arith.Analyzer() + results = {} + intermediate_bind = {} + for vars, bounds in zip(input_vars, shape.values()): + for var, bound in zip(vars, bounds): + ana.update(var, bound, True) + for ax in self.reduce_axes: + # assume the dom.min is always 0, maybe we can extend the IterInfo to include the min value. + if ax.var.name in rstep: + bound = arith.ConstIntBound( + int(ax.dom.min), int(ax.dom.min + min(ax.dom.extent, rstep[ax.var.name]) - 1) + ) + else: + bound = arith.ConstIntBound(int(ax.dom.min), int(ax.dom.min + ax.dom.extent - 1)) + ana.update(ax.var, bound, True) + + for name, regions in mapping.items(): + if targets is not None and name not in targets: + continue + if compute_targets[0:1] == compute_targets: + (compute_target,) = compute_targets + path = self.dep_analysis.find_path_from_source(name, compute_target) + if len(path) > 2: + intermediate_nodes = path[1:-1] + for node in intermediate_nodes: + iters = mapping[node.name] + if len(iters) != len(regions) or len(iters) != 1: + continue + if len(*iters) != len(*regions): + break + regions = iters + intermediate_bind[name] = compute_target + + for region in regions: + bound = [ana.const_int_bound(indice) for indice in region] + if name in results: # simply merge two bounds + bound = [_merge_two_bounds(x, y) for x, y in zip(results[name], bound)] + results[name] = bound + else: + for region in regions: + bound = [ana.const_int_bound(indice) for indice in region] + if name in results: # simply merge two bounds + bound = [_merge_two_bounds(x, y) for x, y in zip(results[name], bound)] + results[name] = bound + + for name, bounds in results.items(): + results[name] = [c.max_value - c.min_value + 1 for c in bounds] + return results, intermediate_bind + + def get_input_exprs(self, output_exprs): + input_vars, mapping = self.construct_dependency_target(tuple(output_exprs.keys())) + ana = arith.Analyzer() + for ax in self.reduce_axes: + ana.bind(ax.var, 0) + vmap = {} + for vars, exprs in zip(input_vars, output_exprs.values()): + for var, expr in zip(vars, exprs): + if expr.dtype != var.dtype: + expr = tir.Cast(var.dtype, expr) + vmap[var] = expr + result = {} + + for name, regions in mapping.items(): + region = regions[0] + result[name] = [ + ana.simplify(tir.stmt_functor.substitute(index, vmap)) for index in region + ] + return result + + +def region_exist_in_list(a, list) -> bool: + def expr_is_same(a, b) -> bool: + if isinstance(a, tir.IntImm) and isinstance(b, tir.IntImm): + return a.value == b.value + return structural_equal(a, b) + + def region_is_same(a, b) -> bool: + for indice_a, indice_b in zip(a, b): + if not expr_is_same(indice_a, indice_b): + return False + return True + + return any([region_is_same(a, x) for x in list]) + + +def walk_indice(expr): + if isinstance(expr, tir.expr.BinaryOpExpr): + a = walk_indice(expr.a) + b = walk_indice(expr.b) + if a is not None and b is not None: + return expr + else: + return None + elif isinstance(expr, tir.expr.ConstExpr): + return expr + elif isinstance(expr, tir.Var): + return expr + elif isinstance(expr, tir.ProducerLoad): + return None + elif isinstance(expr, tir.Cast): + a = walk_indice(expr.value) + if a is not None: + return expr + return None + elif isinstance(expr, tir.Call): + return None + else: + raise Exception("Unhandled node type in walk_indice(): %s" % expr) + + +def _extract_dependent_region(block_analyzer, block: BlockRV) -> Dict[str, List[tir.PrimExpr]]: + input_buffers = block_analyzer.get_input_buffers(block) + dependent_region = {buffer.name: [] for buffer in input_buffers} + + def fvisit(x): + if not isinstance(x, tir.BufferLoad): + return + if x.buffer.name not in dependent_region: + return + index = [] + for indice, shape_limit in zip(x.indices, x.buffer.shape): + expr = walk_indice(indice) + if expr is None: + expr = tir.Var("undefined", dtype="int8") % shape_limit + if isinstance(expr, tir.IntImm) and expr.value == 0: + """for tensor ir zero dim smplification case. + for ax0, ax1, ax2 in T.grid(T.int64(1024), T.int64(1024), T.int64(1024)): + with T.block("T_dense"): + v0, v1, v2 = T.axis.remap("SSR", [ax0, ax1, ax2]) + T.reads(A_reindex[T.int64(0), v0, v2], B_reindex[T.int64(0), v1, v2]) + T.writes(T_dense_reindex[T.int64(0), v0, v1]) + with T.init(): + T_dense_reindex[T.int64(0), v0, v1] = T.float16(0) + T_dense_reindex[T.int64(0), v0, v1] = T_dense_reindex[T.int64(0), v0, v1] + A_reindex[T.int64(0), v0, v2] * B_reindex[T.int64(0), v1, v2] + For exmaple, the T_dense_reindex has three dims, however there're only two spatial loops. + """ + continue + index.append(expr) + if not region_exist_in_list(index, dependent_region[x.buffer.name]): + dependent_region[x.buffer.name].append(index) + + stmt = block_analyzer.sch.get(block) + tir.stmt_functor.post_order_visit(stmt, fvisit=fvisit) + return dependent_region + + +def get_analyzer_by_tir(block_analyzer, args) -> InputShapeInference: + deps = [Statement(block_analyzer, block) for block in args] + + return InputShapeInference(deps) diff --git a/python/bitblas/base/schedule_rule.py b/python/bitblas/base/schedule_rule.py new file mode 100644 index 000000000000..53319b4fcfef --- /dev/null +++ b/python/bitblas/base/schedule_rule.py @@ -0,0 +1,149 @@ +# Copyright 2018 The apache/tvm Authors. All Rights Reserved. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# Modifications Copyright (c) Microsoft. +# The code below is mostly copied from apache/tvm schedule_rule.py in dlight. +"""A lightweight wrapper on an arbitrary function that can be used to schedule a TIR PrimFunc.""" +from typing import Callable, List, Union + +from tvm import tir +from tvm.target import Target + + +class ScheduleRule: # pylint: disable=too-few-public-methods + """A thin wrapper on an arbitrary function that can be used to schedule a TIR PrimFunc. + + Given a PrimFunc, a target, and a tunable flag, the apply method of a ScheduleRule + returns either a Schedule, a list of Schedules, or None, where None means that the rule + is not applicable to the given PrimFunc. If the tunable flag is True, the ScheduleRule is + allowed to return either a Schedule or a list of Schedules, and the Schedules are allowed to + contain tunable instructions. If the tunable flag is False, the ScheduleRule is only allowed to + return a Schedule, and the Schedule is not allowed to contain tunable instructions. + """ + + def apply( + self, + func: tir.PrimFunc, + target: Target, + tunable: bool, + ) -> Union[None, tir.Schedule, List[tir.Schedule]]: + """Apply the ScheduleRule to the given PrimFunc. + + Parameters + ---------- + func : tir.PrimFunc + The PrimFunc to apply the ScheduleRule to. + target : Target + The compilation target the schedule is supposed to be built for. + tunable : bool + Whether the schedule is allowed to contain tunable instructions. + + Returns + ------- + results : Union[None, tir.Schedule, List[tir.Schedule]] + Either a Schedule, a list of Schedules, or None, where None means that the rule + is not applicable to the given PrimFunc. + """ + raise NotImplementedError + + def apply_config( + self, + func: tir.PrimFunc, + config, + ): + """Apply the ScheduleRule to the given PrimFunc. + + Parameters + ---------- + func : tir.PrimFunc + The PrimFunc to apply the ScheduleRule to. + target : Target + The compilation target the schedule is supposed to be built for. + configs : + # todo: Discribe the configs + Returns + ------- + results : Union[None, tir.Schedule, List[tir.Schedule]] + Either a Schedule, a list of Schedules, or None, where None means that the rule + is not applicable to the given PrimFunc. + """ + raise NotImplementedError + + @staticmethod + def from_callable( + name, + ) -> Callable[ + [ + Callable[ + [tir.PrimFunc, Target, bool], + Union[None, tir.Schedule, List[tir.Schedule]], + ], + ], + "ScheduleRule", + ]: + """Create a ScheduleRule from a callable. + + Parameters + ---------- + name : str + + Returns + ------- + decorator : Callable + A decorator that takes a callable and returns a ScheduleRule. + + Examples + -------- + .. code-block:: python + + @ScheduleRule.from_callable("MyRule") + def my_rule(func: tir.PrimFunc, target: Target, tunable: bool) -> Union[None, Schedule] + # Do something with func and target + """ + + def decorator(f) -> "ScheduleRule": # pylint: disable=invalid-name + class _Rule(ScheduleRule): + def apply( + self, + func: tir.PrimFunc, + target: Target, + tunable: bool, + ) -> Union[None, tir.Schedule, List[tir.Schedule]]: + return f(func, target, tunable) + + _Rule.__name__ = name + return _Rule() + + return decorator + + def is_target_available( + self, target: Target + ) -> bool: # pylint: disable=unused-argument + """Check whether the rule is available for the given target. + + Parameters + ---------- + target : Target + The compilation target the schedule is supposed to be built for. + + Returns + ------- + available : bool + Whether the rule is available for the given target. + """ + return True diff --git a/python/bitblas/base/transform.py b/python/bitblas/base/transform.py new file mode 100644 index 000000000000..e1e59db880b4 --- /dev/null +++ b/python/bitblas/base/transform.py @@ -0,0 +1,214 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Apply ScheduleRules onto an IRModule to generate default schedules without tuning, +or a space for MetaSchedule tuning +""" +from typing import List, Optional, Dict +import os +import shutil +import tempfile +import os.path as osp +import tvm +from tvm import tir +from tvm import meta_schedule as ms +from tvm.ir import IRModule +from tvm.ir.transform import PassContext, module_pass +from tvm.target import Target +from .roller.policy import DefaultPolicy, TensorCorePolicy +from .roller.arch import CUDA +from .schedule_rule import ScheduleRule +from ..gpu.matmul_analysis import get_tensorized_func_and_tags +from ..base.analysis import check_func_with_dynamic +from .utils import apply_and_build, fast_tune, fast_tune_with_dynamic_range + + +def _is_scheduled(func: tir.PrimFunc) -> bool: + if not isinstance(func, tir.PrimFunc): + return False + if not func.attrs: + return False + if "tir.is_scheduled" not in func.attrs: + return False + return func.attrs["tir.is_scheduled"] == 1 + + +@module_pass(opt_level=0, name="ApplyDefaultSchedule") +class ApplyDefaultSchedule: # pylint: disable=too-few-public-methods + """A IRModule pass that applies a list of ScheduleRules to all PrimFuncs in the module.""" + + def __init__(self, *rules: ScheduleRule): + """Construct a new ApplyDefaultSchedule pass. + + Parameters + ---------- + *rules : ScheduleRule + The ScheduleRules to apply to all PrimFuncs in the module. + """ + self.rules = list(rules) + + def transform_module( # pylint: disable=missing-function-docstring + self, + mod: IRModule, + _: PassContext, + ) -> IRModule: + target = Target.current(allow_none=False) + + updated_functions = {} + for g_var, func in mod.functions_items(): + if isinstance(func, tir.PrimFunc) and not _is_scheduled(func): + sch = _apply_rules(func, target, self.rules, tunable=False) + if sch is not None: + assert len(sch) == 1 + updated_functions[g_var] = sch[0].mod["main"].with_attr("tir.is_scheduled", 1) + for g_var, func in updated_functions.items(): + mod[g_var] = func + return mod + + +@module_pass(opt_level=0, name="ApplyFastTuning") +class ApplyFastTuning: # pylint: disable=too-few-public-methods + """A IRModule pass that applies a list of ScheduleRules to all PrimFuncs in the module.""" + + def __init__( + self, + topk: int = 10, + target: Optional[Target] = None, + parallel_build: bool = True, + meta_database_dir: str = None, + whitelist: List[str] = [], + dynamic_range: Dict[str, List[int]] = {}, + ): + """Construct a new ApplyFastTuning pass. + + Parameters + ---------- + meta_database : str + The path of database. + dynamic_range : Dict[str, List[int]] + Use for generate kernel based on dynamic range. + """ + self.topk = topk + self.target = Target.current() if target is None else target + self.parallel_build = parallel_build + self.meta_database_dir = meta_database_dir + self.whitelist = whitelist + self.dynamic_range = dynamic_range + self.temp_dir = tempfile.TemporaryDirectory() + print(f"[BitBLAS] Using meta database dir {self.temp_dir}") + path_workload = osp.join(self.temp_dir.name, "database_workload.json") + path_tuning_record = osp.join(self.temp_dir.name, "database_tuning_record.json") + self.cache_meta_database = ms.database.JSONDatabase( + path_workload, path_tuning_record, module_equality="structural" + ) + + def _in_white_list(self, func_name: str) -> bool: + if len(self.whitelist) == 0: + return True + for name in self.whitelist: + if name in func_name: + return True + return False + + def transform_module( # pylint: disable=missing-function-docstring + self, + mod: IRModule, + _: PassContext, + ) -> IRModule: + target = self.target + updated_functions = {} + + for g_var, func in mod.functions_items(): + if isinstance(func, tir.PrimFunc) and not _is_scheduled(func): + if not self._in_white_list(g_var.name_hint): + continue + print(f"[BitBLAS] Start to apply fast tuning for {g_var}") + normalize_mod_func_ = tvm._ffi.get_global_func("tvm.meta_schedule.normalize_mod") + _normalized_func_mod = normalize_mod_func_(func) + + if self.cache_meta_database.has_workload(_normalized_func_mod): + tuning_record = self.cache_meta_database.query_tuning_record( + _normalized_func_mod, + target, + g_var.name_hint, + ) + if tuning_record: + trace = tuning_record.trace + sch = tvm.tir.Schedule(func) + trace.apply_to_schedule(sch, remove_postproc=False) + print(f"[BitBLAS] Find Cache for {g_var}") + updated_functions[g_var] = sch.mod["main"].with_attr("tir.is_scheduled", 1) + continue + + if check_func_with_dynamic(func): + + dispatch_mod = fast_tune_with_dynamic_range( + func, + target=target, + topk=self.topk, + parallel_build=self.parallel_build, + global_symbol=g_var.name_hint, + dynamic_range=self.dynamic_range, + ) + + if dispatch_mod: + for g, f in dispatch_mod.functions_items(): + if g.name_hint == g_var.name_hint: + # avoid duplicated global symbol + updated_functions[g_var] = f.without_attr("global_symbol").with_attr("tir.is_scheduled", 1) + else: + updated_functions[g] = f.with_attr("tir.is_scheduled", 1) + # cannot reuse meta database as it canot be recorvered from the trace + workload = self.cache_meta_database.commit_workload(_normalized_func_mod) + else: + # otherwise is static shape analysis + _, best = fast_tune( + func, target=target, topk=self.topk, parallel_build=self.parallel_build + ) + + if best is not None: + updated_functions[g_var] = best.sch.mod["main"].with_attr("tir.is_scheduled", 1) + workload = self.cache_meta_database.commit_workload(_normalized_func_mod) + # only record the best schedule + self.cache_meta_database.commit_tuning_record( + ms.database.TuningRecord( + best.sch.trace, + workload, + [best.latency], + target, + ms.arg_info.ArgInfo.from_prim_func(func=best.sch.mod["main"]), + ) + ) + + for g_var, func in updated_functions.items(): + mod[g_var] = func + + # copy database + if self.meta_database_dir is not None: + if not osp.exists(self.meta_database_dir): + os.makedirs(self.meta_database_dir) + # TODO(lei): maybe another way to copy the database + shutil.copytree(self.temp_dir.name, self.meta_database_dir, dirs_exist_ok=True) + + return mod + + def __del__(self): + # clean up the temp cache + self.temp_dir.cleanup() + + +def _apply_rules( + func: tir.PrimFunc, + target: Target, + rules: List[ScheduleRule], + tunable: bool, +) -> Optional[List[tir.Schedule]]: + for rule in rules: + space = rule.apply(func, target, tunable) + if space is None: + continue + if isinstance(space, tir.Schedule): + space = [space] + return space + return None diff --git a/python/bitblas/base/utils.py b/python/bitblas/base/utils.py new file mode 100644 index 000000000000..19721174c1db --- /dev/null +++ b/python/bitblas/base/utils.py @@ -0,0 +1,489 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import tvm +import os +from tvm.contrib.popen_pool import PopenPoolExecutor, StatusKind, MapResult +from concurrent.futures import ThreadPoolExecutor, as_completed +import numpy as np +from typing import List, Tuple, Optional, Dict, Union +from tvm import tir, IRModule +from tvm.runtime import Module +from tvm.tir import Schedule +from tvm.relax.expr import Function +import bitblas +from .analysis import get_root_block, get_reduction_blocks, find_var_from_func +from bitblas.base.roller.arch import CUDA +from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy +from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags +from bitblas.base.roller.rasterization import NoRasterization +import tempfile +import itertools +from tvm.ir.supply import GlobalVarSupply +from bitblas.utils import match_global_kernel + + +def get_rasterization_code(pannel_width: int = 8) -> str: + return f""" + const int MAX_BLOCK_N = {pannel_width}; + const auto baseBlockIdx = blockIdx.x + gridDim.x *blockIdx.y; + const auto totalPanel = (gridDim.x * gridDim.y +MAX_BLOCK_N * gridDim.x - 1) / (MAX_BLOCK_N * gridDim.x); + const auto totalBlock = gridDim.x * gridDim.y; + const auto panelIdx = baseBlockIdx / (MAX_BLOCK_N *gridDim.x); + const auto strideLd = panelIdx + 1 < totalPanel ?MAX_BLOCK_N : (totalBlock - panelIdx * (MAX_BLOCK_N *gridDim.x)) / gridDim.x; + const auto bx = (panelIdx & 1) ? gridDim.x -(baseBlockIdx - panelIdx * MAX_BLOCK_N * gridDim.x) /strideLd - 1 : (baseBlockIdx - panelIdx * MAX_BLOCK_N *gridDim.x) / strideLd; + const auto by = (baseBlockIdx - panelIdx * MAX_BLOCK_N *gridDim.x) % strideLd + panelIdx * MAX_BLOCK_N; + const auto bz = blockIdx.z; + const dim3 blockIdx(bx, by, bz); + """ + + +class CompileResult: + """ + Class to store the result of compilation + """ + + def __init__(self, config, sch, mod: Module): + self.config = config + self.sch = sch + self.mod = mod + self.code = mod.imported_modules[0].get_source() if mod else None + self.latency = 1e9 + self.profile_tensors = [] + self.time_evaluator = None + + def profile(self): + return self.time_evaluator(*self.profile_tensors).mean + + +def _apply_config( + func: tir.PrimFunc, + config=None, # todo(lei): update typing +) -> Optional[List[tir.Schedule]]: + """ + find rules: + case 1. if the main block has no reduce op, then use the Elementwise rule. + case 2. if the config enabled tensorcore, then use the TensorCore rule. + case 3. if any([t > 1 for t in config.reduce_thread]), we should use the InnerThread Reduction Rule. + case 4. else we should use general reduction rule. + """ + print("[BitBLAS] Apply config ", config) + + sch = tir.Schedule(func) + root_block = get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + reduction_blocks = get_reduction_blocks(sch, blocks) + + if not reduction_blocks: + return bitblas.gpu.ElementWise().apply_config(func, config) + elif config.use_tc: + if config.arch.sm_version >= 80: + # For A100(sm_80) or more advanced gpu, use MMA tensorization. + return bitblas.gpu.MatmulTensorizationMMA().apply_config(func, config) + else: + # For other GPUs, use WMMA tensorization. + return bitblas.gpu.MatmulTensorizationWMMA().apply_config(func, config) + else: + _reduction_rules = [] + + _reduction_rules.append(bitblas.gpu.GEMV()) + if not any([t > 1 for t in config.reduce_thread]): + # Matrix multiplication template doesn't support inner thread reduction + _reduction_rules.append(bitblas.gpu.Matmul()) + _reduction_rules.append(bitblas.gpu.GeneralReduction()) + + for rule in _reduction_rules: + sch = rule.apply_config(func, config) + try: + sch = rule.apply_config(func, config) + except Exception as e_msg: + print("[BitBLAS] Apply config failed: ", e_msg) + continue + if sch is not None: + return sch + return None + + +def get_dummy_input_arrays( + func: Union[tir.PrimFunc, Function], device: tvm.runtime.Device +): + def var_wrapper(v): + if isinstance(v, tvm.tir.Var): + assert "opt_shapes" in func.attrs + assert v.name in func.attrs["opt_shapes"] + return func.attrs["opt_shapes"][v.name].value + elif isinstance(v, tvm.tir.IntImm): + return v.value + else: + raise RuntimeError("Not supported type: ", type(v)) + + profile_tensors = [] + for param in func.params: + if isinstance(func, tir.PrimFunc): + if param not in func.buffer_map: + # in case of dynamic symbolic may in params + continue + arg = func.buffer_map[param] + elif isinstance(func, Function): + arg = param.struct_info + else: + raise ValueError("Not supported type: ", type(func)) + + profile_tensors.append( + tvm.nd.array( + np.random.uniform(-1, 1, [var_wrapper(i) for i in arg.shape]).astype( + arg.dtype + ), + device=device, + ) + ) + return profile_tensors + + +def apply_and_build_parallel( + func, configs, arch, num_repeats=5, max_workers=10 +) -> CompileResult: + cpresults = [] + + profile_tensors = get_dummy_input_arrays(func, arch.device) + max_workers = min(len(configs), os.cpu_count(), max_workers) + + # apply config in thread parallel + _sched: List[Schedule] = [] + with ThreadPoolExecutor(max_workers=4) as schduler: + futures = { + schduler.submit(lambda f, c: _apply_config(f, c), func, config) + for config in configs + } + for future in as_completed(futures): + _sched.append(future.result()) + + builder = PopenPoolExecutor(max_workers=max_workers) + + # build in process parallel + def _build(context) -> str: + idx, mod, arch = context + if mod is None: + return idx, None, None + # TODO(lei): + # this is a trick to implement rasteration, will be removed in the future + config = configs[idx] + + @tvm.register_func(func_name="tvm_callback_cuda_postproc", override=True) + def tvm_callback_cuda_postproc(code, _): + index = code.index("{", match_global_kernel(code)) + if not isinstance(config.rasterization_plan, NoRasterization): + factor = config.rasterization_plan.panel_width_ + rasterization_code = get_rasterization_code(factor) + code = code[: index + 2] + rasterization_code + code[index + 2 :] + return code + + with tvm.transform.PassContext( + config={"tir.use_async_copy": True, **config.pass_context} + ): + rt_mod = tvm.build(mod["main"], target=arch.target) + + from tvm.contrib.tar import tar # pylint: disable=import-outside-toplevel + + artifact_path = os.path.join( + tempfile.mkdtemp(), "tvm_tmp_mod." + tar.output_format + ) + code = rt_mod.imported_modules[0].get_source() + rt_mod.export_library(artifact_path, fcompile=tar) + return idx, code, artifact_path + + _mods = [sch.mod if sch is not None else None for sch in _sched] + + for map_result in builder.map_with_error_catching( + _build, + [(i, mod, arch) for i, mod in enumerate(_mods)], + ): + if map_result.status == StatusKind.TIMEOUT: + print("[BitBLAS] LocalBuilder: Timeout") + elif map_result.status == StatusKind.EXCEPTION: + # TODO(lei): redirect the exception to file if needed + print("[BitBLAS] LocalBuilder: An exception occurred ", map_result.value) + continue + elif map_result.status == StatusKind.COMPLETE: + idx, code, artifact_path = map_result.value + if artifact_path is None: + print("[BitBLAS] Artifact path is None") + continue + sch = _sched[idx] + config = configs[idx] + rt_mod = tvm.runtime.load_module(artifact_path) + cpresult = CompileResult(config, sch, rt_mod) + timer_cuda_mod = rt_mod.time_evaluator( + rt_mod.entry_name, arch.device, number=num_repeats + ) + cpresult.profile_tensors = profile_tensors + cpresult.time_evaluator = timer_cuda_mod + cpresult.code = code + cpresults.append(cpresult) + else: + raise ValueError(f"Unreachable: unexpected result: {map_result}") + + del builder + + best = None + best_latency = 1e9 + for cpresult in cpresults: + config = cpresult.config + try: + latency = cpresult.profile() + except Exception as e_mesg: + print("[BitBLAS] Evaluation with config failed: ", e_mesg) + continue + print("[BitBLAS] Evaluation with config ", config) + print("[BitBLAS] Time cost of this config: {:.3f} ms".format(latency * 1e3)) + + cpresult.latency = latency + if latency < best_latency: + best_latency = latency + best = cpresult + + return cpresults, best + + +def apply_and_build( + func, + configs, + arch, + parallel_build=False, +) -> Tuple[List[CompileResult], CompileResult]: + max_workers = 10 if parallel_build else 1 + return apply_and_build_parallel(func, configs, arch, max_workers) + + +def fast_tune( + func: tir.PrimFunc, + target: tvm.target.Target, + topk: int = 10, + parallel_build: bool = True, +): + if target.kind.name != "cuda": + print("[BitBLAS] Only support CUDA target") + return None, None + + specilized_func = func + if func.attrs is not None and "opt_shapes" in func.attrs: + opt_shapes = func.attrs["opt_shapes"] + # should be int value + if not all([isinstance(v.value, int) for v in opt_shapes.values()]): + print("[BitBLAS] The opt_shapes should be int value") + return None, None + # currently only support one dynmaic range + if len(opt_shapes) > 1: + print("[BitBLAS] Currently only support one dynamic range") + return None, None + + for buffer in func.buffer_map.values(): + for axis in buffer.shape: + if isinstance(axis, tvm.tir.Var): + if axis.name not in opt_shapes: + raise NotImplementedError( + "Currently do not support fast tune with none-dynamic range set" + ) + if opt_shapes: + for name, shape in opt_shapes.items(): + var = find_var_from_func(func, name) + specilized_func = func.specialize( + {var: shape.astype(var.dtype)} + ).with_attr("is_specialized") + + arch = CUDA(target) + + policy = DefaultPolicy(func=func, arch=arch) + try: + specilized_func, tags = get_tensorized_func_and_tags( + specilized_func, arch.target + ) + except Exception as e_msg: + print("[BitBLAS] Get tensorized func and tags failed: ", e_msg) + tags = None + if tags: + policy = TensorCorePolicy(func=specilized_func, arch=arch, tags=tags) + + configs = policy.emit_config(topk) + cpresults, best = apply_and_build( + func, configs, arch, parallel_build=parallel_build + ) + + return cpresults, best + + +# always use the first function as the base +def collect_buffers_to_declare(func): + params = [] + # collect dynamic symbolic + dyn_symbolic: List[tvm.tir.Var] = [] + buffers_to_declare = [] + for param in func.params: + if param not in func.buffer_map: + continue + buffer = func.buffer_map[param] + for axis in buffer.shape: + if isinstance(axis, tvm.tir.Var) and axis not in dyn_symbolic: + dyn_symbolic.append(axis) + buffers_to_declare.append(buffer) + params.append(buffer.data) + + # the args should be buffers + dynamic symbolic + params += list(dyn_symbolic) + + return params, buffers_to_declare + + +def refactor_specialized_func(g_var, func, params, buffers_to_declare): + body = func.body + attrs = func.attrs + global_symbol = g_var + if "opt_shapes" in func.attrs: + opt_shapes = func.attrs["opt_shapes"] + + def serialize_name(opt_shapes: Dict): + return "_opt_" + "_".join([f"{k}_{v}" for k, v in opt_shapes.items()]) + + global_symbol += serialize_name(opt_shapes) + ret_type = func.ret_type + for buf in buffers_to_declare: + body = tvm.tir.DeclBuffer(buf, body=body) + + # devide func must be private + device_func = tvm.tir.PrimFunc(params, body, ret_type, attrs=attrs).without_attr( + "global_symbol" + ) + return global_symbol, device_func + + +def create_dispatch_func(g_var: str, func: tir.PrimFunc, refactored_funcs: List[str]): + global_symbol = g_var + attrs = func.attrs + buffer_map = func.buffer_map + params = func.params + ret_type = func.ret_type + + # collect dynamic symbolic + dyn_symbolic: List[tvm.tir.Var] = [] + _invoke_params = [] + for param in func.params: + if param not in func.buffer_map: + continue + buffer = func.buffer_map[param] + for axis in buffer.shape: + if isinstance(axis, tvm.tir.Var) and axis not in dyn_symbolic: + dyn_symbolic.append(axis) + _invoke_params.append(buffer.data) + _invoke_params += list(dyn_symbolic) + + func_range: List[int] = [] + global_symbols = [] + for g_var, refactor_func in refactored_funcs: + opt_shapes = refactor_func.attrs["opt_shapes"] + func_range.append(list(opt_shapes.values())[0]) + global_symbols.append(g_var) + + # TODO(lei): general the dispatch function to support multiple dynamic symbolics + assert len(dyn_symbolic) == 1, "Only support one dyanmic symbolics currently" + + ib = tvm.tir.ir_builder.create() + syb = list(dyn_symbolic)[-1] + last_range = 0 + for i, (_range, g_var) in enumerate(zip(func_range, global_symbols)): + if i == 0: + with ib.if_scope(syb <= _range): + ib.emit(tvm.tir.Call(None, g_var, _invoke_params)) + else: + with ib.if_scope(tvm.tir.all(syb > last_range, syb <= _range)): + ib.emit(tvm.tir.Call(None, g_var, _invoke_params)) + last_range = _range + with ib.if_scope(syb > last_range): + ib.emit(tvm.tir.Call(None, g_var, _invoke_params)) + stmt = ib.get() + dispatch_func = tvm.tir.PrimFunc( + params, stmt, ret_type, buffer_map, attrs + ).with_attrs({"tir.is_global_func": True, "global_symbol": global_symbol}) + return dispatch_func + + +def create_dispatch_mod( + g_var: str, original_func: tir.PrimFunc, specialized_funcs: List[tir.PrimFunc] +) -> IRModule: + dispatch_mod: IRModule = tvm.IRModule() + g_var_supply = GlobalVarSupply(dispatch_mod) + refactored_funcs = [] + for func in specialized_funcs: + params, buffers_to_declare = collect_buffers_to_declare(func) + global_symbol, device_func = refactor_specialized_func( + g_var, func, params, buffers_to_declare + ) + global_symbol = g_var_supply.fresh_global(global_symbol, add_prefix=False) + dispatch_mod[global_symbol] = device_func + refactored_funcs.append((global_symbol, device_func)) + dispatch_func = create_dispatch_func( + g_var, original_func, refactored_funcs=refactored_funcs + ) + dispatch_mod.update(tvm.IRModule.from_expr(dispatch_func)) + return dispatch_mod + + +def fast_tune_with_dynamic_range( + func: tir.PrimFunc, + target: tvm.target.Target, + topk: int = 10, + parallel_build: bool = True, + global_symbol: Optional[str] = None, + dynamic_range: Dict[str, List[int]] = {}, +) -> IRModule: + if target.kind.name != "cuda": + print("[BitBLAS] Only support CUDA target") + return None + if not global_symbol: + global_symbol = func.attrs["global_symbol"] + + # set opt_shapes for the primfunc with dynamc symbolic + opt_shapes: Dict[str, List[int]] = {} + for buffer in func.buffer_map.values(): + for axis in buffer.shape: + if isinstance(axis, tvm.tir.Var): + if axis.name in dynamic_range: + opt_shapes[axis.name] = dynamic_range[axis.name] + else: + raise ValueError( + f"[BitBLAS] The axis {axis.name} is not in dynamic_range" + ) + func = func.with_attr("opt_shapes", opt_shapes) + + if "opt_shapes" not in func.attrs: + print( + "[BitBLAS] The primfunc has no opt_shapes, please set opt_shapes for the primfunc" + ) + return None + else: + # should be list value + if not all( + [isinstance(v, tvm.ir.Array) for v in func.attrs["opt_shapes"].values()] + ): + print("[BitBLAS] The opt_shapes should be list value") + return None + + print("[BitBLAS] Start fast tuning with dynamic range") + opt_shapes = func.attrs["opt_shapes"] + + # Step 1.Calculate the Cartesian product using itertools.product + product_list = list(itertools.product(*(opt_shapes[key] for key in opt_shapes))) + + # Convert the Cartesian product to a list of dictionaries + specialize_items: List[Dict] = [ + dict(zip(opt_shapes.keys(), values)) for values in product_list + ] + + specilized_tuned_funcs: List[tir.PrimFunc] = [] + for item in specialize_items: + func = func.with_attr("opt_shapes", item) + _, best = fast_tune(func, target, topk, parallel_build) + if best is None: + return None + specilized_tuned_funcs.append(best.sch.mod["main"]) + + return create_dispatch_mod(global_symbol, func, specilized_tuned_funcs) diff --git a/python/bitblas/generator.py b/python/bitblas/generator.py new file mode 100644 index 000000000000..4cbe697e2d80 --- /dev/null +++ b/python/bitblas/generator.py @@ -0,0 +1,17 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +class BitBLASGenerator: + def __init__(self, input_size, data_type='float', optimization_level=1): + self.input_size = input_size + self.data_type = data_type + self.optimization_level = optimization_level + # 其他初始化代码 + + def generate_cuda_code(self): + # 生成CUDA代码的逻辑 + pass + + def generate_header(self): + # 生成Header文件的逻辑 + pass diff --git a/python/bitblas/gpu/__init__.py b/python/bitblas/gpu/__init__.py new file mode 100644 index 000000000000..9fbe8ba93478 --- /dev/null +++ b/python/bitblas/gpu/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +GPU-generic schedule rules. +For CUDA/ROCm/Vulkan/Metal-specific rules, use `tvm.dlight.cuda/rocm/vulkan/metal` instead +""" +from .fallback import Fallback +from .element_wise import ElementWise +from .gemv import GEMV +from .general_reduction import GeneralReduction +from .matmul import ( + Matmul, + MatmulTensorizationMMA, + MatmulTensorizationWMMA, + MatmulTensorizationLegacy, +) +from .matmul_mma_dequantize import MatmulTensorizationMMAWithDequantizeInfo + +from .reduction import Reduction +from .transpose import Transpose diff --git a/python/bitblas/gpu/base.py b/python/bitblas/gpu/base.py new file mode 100644 index 000000000000..3bf927244936 --- /dev/null +++ b/python/bitblas/gpu/base.py @@ -0,0 +1,44 @@ +# Copyright 2018 The apache/tvm Authors. All Rights Reserved. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# /* Modifications Copyright (c) Microsoft. */ +# The code below is mostly copied from apache/tvm base.py in dlight. +"""Base schedule rule for GPU operators.""" + +from tvm.target import Target + +from ..base import ScheduleRule + + +class GPUScheduleRule(ScheduleRule): # pylint: disable=too-few-public-methods + """The Schedule Rule specific to GPU targets, will return None if the target is not GPU.""" + + def is_target_available(self, target: Target) -> bool: + """Check whether the target is available for gpu rule. + + Parameters + ---------- + target : Target + The compilation target to check. + + Returns + ------- + available : bool + Whether the target is available for this rule. + """ + return super().is_target_available(target) and "gpu" in target.keys diff --git a/python/bitblas/gpu/element_wise.py b/python/bitblas/gpu/element_wise.py new file mode 100644 index 000000000000..07ea3a27e2d2 --- /dev/null +++ b/python/bitblas/gpu/element_wise.py @@ -0,0 +1,97 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# pylint: disable=missing-docstring +"""A fallback schedule rule for GPU operators.""" +from typing import List + +from tvm import tir + +from ..base import ScheduleRule, normalize_prim_func, try_inline + + +class ElementWise(ScheduleRule): + """ + An elementwise schedule rule for GPU operators. + """ + + def apply_config( # pylint: disable=too-many-locals,missing-docstring + self, + func: tir.PrimFunc, + config, + ) -> tir.Schedule: + block_factors = config.block + thread_factors = config.thread + step_factors = config.step + + sch = tir.Schedule(func) + block_infos = normalize_prim_func(sch) + + if block_infos is None: + return None + + block_infos = try_inline(sch, block_infos) + + for block in block_infos: + s_loops: List[tir.schedule.LoopRV] = [] + r_loops: List[tir.schedule.LoopRV] = [] + o_loops: List[tir.schedule.LoopRV] = [] + dom_kind = block.dom_kind() + block = block.block_rv + + if ( + any( + [ + sch.get(loop_rv).thread_binding is not None + for loop_rv in sch.get_loops(block) + ] + ) + or len(sch.get_loops(block)) == 0 + ): + continue + + for loop, iter_type in zip(sch.get_loops(block), dom_kind): + {"S": s_loops, "R": r_loops, "O": o_loops}[iter_type].append(loop) + + if not s_loops: + s_loops.append(sch.add_unit_loop(block)) + sch.reorder(*s_loops, *r_loops, *o_loops) + + block_loops = [] + vthread_loops = [] + thread_loops = [] + inner_loops = [] + for s_loop, block_factor, step_factor, thread_factor in zip( + s_loops, block_factors, step_factors, thread_factors + ): + block_loop, inner_loop = sch.split(s_loop, factors=[None, block_factor]) + vthread_loop, inner_loop = sch.split( + inner_loop, factors=[None, thread_factor * step_factor] + ) + thread_loop, inner_loop = sch.split( + inner_loop, factors=[None, step_factor] + ) + block_loops.append(block_loop) + vthread_loops.append(vthread_loop) + thread_loops.append(thread_loop) + inner_loops.append(inner_loop) + + # inner virtual thread first + vthread_loops = list(reversed(vthread_loops)) + sch.reorder( + *block_loops, + *vthread_loops, + *thread_loops, + *inner_loops, + *r_loops, + *o_loops + ) + sch.bind(sch.fuse(*block_loops), "blockIdx.x") + sch.bind(sch.fuse(*thread_loops), "threadIdx.x") + if len(vthread_loops) > 3: + vthread_loops = vthread_loops[0:2] + [sch.fuse(*vthread_loops[2:])] + + for i, ax in enumerate(vthread_loops): + sch.bind(ax, "vthread" + [".x", ".y", ".z"][i]) + + return sch diff --git a/python/bitblas/gpu/fallback.py b/python/bitblas/gpu/fallback.py new file mode 100644 index 000000000000..3711d3682c9e --- /dev/null +++ b/python/bitblas/gpu/fallback.py @@ -0,0 +1,95 @@ +# Copyright 2018 The apache/tvm Authors. All Rights Reserved. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# Modifications Copyright (c) Microsoft. +# The code below is mostly copied from apache/tvm fallback.py in dlight. +# pylint: disable=missing-docstring +"""A fallback schedule rule for GPU operators.""" +from typing import List, Tuple + +from tvm import tir +from tvm.target import Target + +from ..base import normalize_prim_func, try_inline +from . import utils +from .base import GPUScheduleRule + + +class Fallback(GPUScheduleRule): + """ + A fallback schedule rule for all GPU operators. It will try to inline all the blocks first, + and then apply a simple block/grid mapping to the spatial loops on top of the remaining blocks. + """ + + def apply( # pylint: disable=too-many-locals,missing-docstring + self, + func: tir.PrimFunc, + target: Target, + _: bool, + ) -> tir.Schedule: + if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): + return None + max_threads_per_block = utils.max_threads_per_block(target) + + sch = tir.Schedule(func) + block_infos = normalize_prim_func(sch) + + if block_infos is None: + return None + + block_infos = try_inline(sch, block_infos) + reduction_blocks: List[Tuple[tir.schedule.BlockRV, tir.schedule.LoopRV]] = [] + for block in block_infos: + s_loops: List[tir.schedule.LoopRV] = [] + r_loops: List[tir.schedule.LoopRV] = [] + o_loops: List[tir.schedule.LoopRV] = [] + dom_kind = block.dom_kind() + block = block.block_rv + + if ( + any( + [ + sch.get(loop_rv).thread_binding is not None + for loop_rv in sch.get_loops(block) + ] + ) + or len(sch.get_loops(block)) == 0 + ): + continue + + for loop, iter_type in zip(sch.get_loops(block), dom_kind): + {"S": s_loops, "R": r_loops, "O": o_loops}[iter_type].append(loop) + + if not s_loops: + s_loops.append(sch.add_unit_loop(block)) + sch.reorder(*s_loops, *r_loops, *o_loops) + bx, tx = sch.split( # pylint: disable=invalid-name + sch.fuse(*s_loops), + factors=[None, max_threads_per_block], + ) + sch.bind(bx, "blockIdx.x") + sch.bind(tx, "threadIdx.x") + + if len(r_loops) > 0: + reduction_blocks.append((block, r_loops[0])) + + for block, r_loop in reduction_blocks: + sch.decompose_reduction(block, r_loop) + + return sch + \ No newline at end of file diff --git a/python/bitblas/gpu/gemv.py b/python/bitblas/gpu/gemv.py new file mode 100644 index 000000000000..81a3d48af947 --- /dev/null +++ b/python/bitblas/gpu/gemv.py @@ -0,0 +1,860 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# Copyright 2018 The apache/tvm Authors. All Rights Reserved. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# Modifications Copyright (c) Microsoft. +# The code below is mostly copied from apache/tvm gemv.py in dlight. +"""A rule for GEMV and DecodeGEMV.""" +import re +from functools import reduce +from typing import List, Optional, Union, Dict + +from tvm.tir.function import PrimFunc +from tvm import DataType, arith, ir, tir +from tvm.target import Target + +from ..base import ( + BlockInfo, + collect_block_iter_vars_used_in_access_region, + collect_vars_used_in_prim_expr, + detect_dominant_read, + is_broadcast_epilogue, + normalize_prim_func, + try_inline_contiguous_spatial, + get_output_blocks, +) +from .base import GPUScheduleRule +from .gemv_dequantize import GEMVWithDequantizeInfo +from ..base.analysis import ( + get_coalesced_veclen +) + +def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]: + # Detect and return `Y` in `X[...] = X[...] + Y` + buffer_store = block.body + if not isinstance(buffer_store, tir.BufferStore): + return None + if not isinstance(buffer_store.value, tir.Add): + return None + if not ir.structural_equal( + buffer_store.value.a, + tir.BufferLoad(buffer_store.buffer, block.body.indices), + map_free_vars=True, + ): + return None + return buffer_store.value.b + + +def get_extent(sch: tir.Schedule, loop_rv: tir.schedule.LoopRV): + loop: tir.For = sch.get(loop_rv) + return loop.extent.value if isinstance(loop.extent, tir.IntImm) else loop.extent + + +def get_bytes(dtype: Union[DataType, str]) -> int: + num = re.findall(r"\d+", dtype) + if len(num) != 1: + raise ValueError(f"Cannot get bytes from {dtype}") + return int(num[0]) // 8 + + +def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffer]]: + """Check if the block is a GEMV. + + Parameters + ---------- + + sch : tir.Schedule + The schedule + + block_info : BlockInfo + The block info to be checked + + + Returns + ------- + ret : Optional[List[tir.Buffer]] + The vector buffers used in the GEMV if it is a GEMV, otherwise None. + """ + block = block_info.block_rv + block_stmt = sch.get(block) + conditions = [] + conditions.append(block_info.is_reduction()) + conditions.append(len(block_stmt.reads) >= 2) + conditions.append(len(block_stmt.writes) == 1) + conditions.append(_get_reduction_expr(block_stmt) is not None) + conditions.append( + len( + collect_block_iter_vars_used_in_access_region( + block_stmt, block_stmt.writes[0].region + ) + ) + > 0 + ) + if not all(conditions): + return None + + iter_num = len(block_stmt.iter_vars) + ret = [ + read.buffer + for read in block_stmt.reads + if len(collect_block_iter_vars_used_in_access_region(block_stmt, read.region)) + < iter_num + and len(collect_block_iter_vars_used_in_access_region(block_stmt, read.region)) + > 0 + ] + if len(ret) == len(block_stmt.reads): + func = sch.mod["main"] + opt_shapes: Dict = {} + if "opt_shapes" in func.attrs: + opt_shapes = func.attrs["opt_shapes"] + # check with dynamic symbolic and at least one is unit + if not all([opt_shapes.get(buf.name, (1,))[0] == 1 for buf in ret]): + return None + elif len(ret) == 0: + return None + return ret + + +def normalize( + sch: tir.Schedule, + block_info: BlockInfo, +) -> Optional[bool]: + """Normalize the main block.""" + block_stmt: tir.Block = sch.get(block_info.block_rv) + access = arith.normalize_to_iter_sum( + detect_dominant_read(block_stmt), + input_iters={i.var: i.dom for i in block_stmt.iter_vars}, + ) + buffers_use_vars = [ + collect_block_iter_vars_used_in_access_region(block_stmt, buf.region) + for buf in block_stmt.writes + ] + buffers_use_vars.extend( + [ + collect_block_iter_vars_used_in_access_region(block_stmt, buf.region) + for buf in block_stmt.reads + ] + ) + if collect_vars_used_in_prim_expr(access.base) & set( + iter_var.var for iter_var in block_stmt.iter_vars + ): + return None + iter_to_info = {i.var: i for i in block_info.iters} + batch_loops, s_loops, r_loops, c_loops = [], [], [], [] + inner_axis = access.args[-1].source.source + is_inner_reduction = iter_to_info[inner_axis].kind == "R" + + for split_expr in access.args: + var = split_expr.source.source + info = iter_to_info.get(var) + loop = info.loop_rv + is_reduction = info.kind == "R" + if split_expr.lower_factor > 1: + if c_loops: + return None + loop, c_loop = sch.split(loop, factors=[None, split_expr.lower_factor]) + # we only support the reduction dim being grouped atm + if not is_reduction: + return None + c_loops.append(c_loop) + if is_reduction: + r_loops.append(loop) + elif all([var in buf_vars for buf_vars in buffers_use_vars]): + batch_loops.append(loop) + else: + s_loops.append(loop) + + assert s_loops + assert r_loops + if not c_loops: + c_loops = [sch.add_unit_loop(block_info.block_rv)] + if not batch_loops: + batch_loops = [sch.add_unit_loop(block_info.block_rv)] + sch.reorder(*batch_loops, *s_loops, *r_loops, *c_loops) + sch.fuse(*batch_loops) + sch.fuse(*s_loops) + sch.fuse(*r_loops) + return is_inner_reduction + + +class GEMV(GPUScheduleRule): + """A rule for GEMV and DecodeGEMV.""" + + def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements + self, + func: tir.PrimFunc, + target: Target, + _: bool, + ) -> Union[None, tir.Schedule, List[tir.Schedule]]: + if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): + return None + sch = tir.Schedule(func) + block_infos = normalize_prim_func(sch) + block_infos = try_inline_contiguous_spatial(sch, block_infos) + if len(block_infos) == 1: + epilogue = None + elif len(block_infos) == 2: + epilogue = block_infos[1] + if not epilogue.is_injective(): + return None + else: + return None + + block_info = block_infos[0] + if len(block_info.iters) not in [2, 3]: + # either [B, S, R] = [B, S, R] * [B, R] + # or [S, R] = [S, R] * [R] + return None + block = block_info.block_rv + vector_input_buffers = is_gemv(sch, block_info) + if vector_input_buffers is None: + return None + + # Step 1. Normalize the block, merge spatial and reduction iters + is_inner_reduction = normalize(sch, block_info) + + # Step 2. Do the scheduling + if is_inner_reduction is None: + return None + elif is_inner_reduction: + self.sch_inner_reduction(sch, target, block, vector_input_buffers, epilogue) + return sch + else: + return self.sch_outer_reduction( + sch, target, block, vector_input_buffers, epilogue + ) + + def sch_inner_reduction( # pylint: disable=too-many-arguments, invalid-name, unused-argument + self, + sch: tir.Schedule, + target: Target, + block: tir.schedule.BlockRV, + vector_input_buffers: List[tir.Buffer], + epilogue_info: Optional[BlockInfo], + ): + """Schedule the inner reduction block.""" + + def get_max_factor(n, factors): + factors = sorted(factors, reverse=True) + for factor in factors: + if n % factor == 0: + return factor + return 1 + + def apply( + sch: tir.Schedule, + gemv, + TAG_S, + TAG_R, + TS, + TR, + TILE_S, + TILE_R, + VEC_LOAD, + VEC_C, + LOAD_V_SHARED, + LOAD_V_VEC, + UNROLL, + ): + # rfactor: reduce to tx * vec_c + _, s, r, c = sch.get_loops(block=gemv) + s = sch.fuse(_, s) + r = sch.fuse(r, c) + bx, ts, tile_s = sch.split( + s, factors=[None, TS, TILE_S], preserve_unit_iters=True + ) + r, tr, tile_r_vec_n, vec_c = sch.split( + r, factors=[None, TR, TILE_R // VEC_C, VEC_C], preserve_unit_iters=True + ) + sch.reorder(r, tile_r_vec_n, tr, vec_c) + tr_vec_c = sch.fuse(tr, vec_c) + rf = sch.rfactor(tr_vec_c, 0) + + # rfactor: reduce to tx + bx, ts, tile_s, tr_vec_c = sch.get_loops(block=gemv) + tr, vec_c = sch.split( + tr_vec_c, factors=[TR, None], preserve_unit_iters=True + ) + rf2 = sch.rfactor(tr, 0) + + # bind, vectorize compute + bx, ts, tile_s, r, tile_r_vec_n, tr_vec_c = sch.get_loops(block=rf) + tr, vec_c = sch.split( + tr_vec_c, factors=[TR, None], preserve_unit_iters=True + ) + sch.reorder(bx, ts, tr, r, tile_s, tile_r_vec_n, vec_c) + sch.bind(bx, "blockIdx.x") + sch.bind(ts, TAG_S) + sch.bind(tr, TAG_R) + sch.vectorize(vec_c) + + shared_mem_usage = 0 + for buf in vector_input_buffers: + buf_size = reduce( + lambda x, y: x * y, buf.shape, tir.IntImm(buf.shape[0].dtype, 1) + ) * get_bytes(buf.dtype) + shared_mem_usage += buf_size + try: + max_shared_memory_per_block = target.max_shared_memory_per_block + except: + max_shared_memory_per_block = 49152 + LOAD_V_SHARED = ( + LOAD_V_SHARED + and isinstance(shared_mem_usage, tir.IntImm) + and shared_mem_usage.value <= max_shared_memory_per_block + ) + + # vectorize load A + # (TODO) this is now actually problematic since the number of loops is dependent on the + # number of dimensions of A_q + Aq_local = sch.cache_read(rf, read_buffer_index=1, storage_scope="local") + sch.compute_at(Aq_local, r, preserve_unit_loops=True) + s_local, r_local = sch.get_loops(block=Aq_local)[-2:] + s_local, vec_load = sch.split( + s_local, factors=[None, VEC_LOAD], preserve_unit_iters=True + ) + sch.reorder( + s_local, r_local, vec_load + ) # either s_local or r_local should be 1 + sch.vectorize(vec_load) + + # load vector into shared memory, shape should be the whole vector + if LOAD_V_SHARED: + V_shared = sch.cache_read( + rf, read_buffer_index=0, storage_scope="shared" + ) + sch.compute_at(V_shared, tr, preserve_unit_loops=True) + l = sch.get_loops(block=V_shared)[-1] + loop: tir.For = sch.get(l) + if isinstance(loop.extent, tir.IntImm): + # avoid introducing predicates when vector length is too large + vec_length = max( + min( + get_max_factor( + (int)(loop.extent), + [TS * TR * 1, TS * TR * 2, TS * TR * 4, TS * TR * 8], + ) + // TS + // TR, + LOAD_V_VEC, + ), + 1, + ) + else: + vec_length = LOAD_V_VEC + if TAG_R == "threadIdx.x": + _, ty, tx, vec = sch.split( + l, factors=[None, TS, TR, vec_length], preserve_unit_iters=True + ) + else: + _, ty, tx, vec = sch.split( + l, factors=[None, TR, TS, vec_length], preserve_unit_iters=True + ) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + sch.vectorize(vec) + + # reduce tile_s * tr * vec to tile_s * tr + sch.reverse_compute_at(rf2, loop=bx, preserve_unit_loops=True) + tr, vec_c, *ts_tile_s = sch.get_loops(block=rf2)[1:] + ts_tile_s = sch.fuse(*ts_tile_s) + ts, tile_s = sch.split( + ts_tile_s, factors=[TS, None], preserve_unit_iters=True + ) + tile_s, vec_s = sch.split( + tile_s, + factors=[None, get_max_factor(TILE_S, [1, 2, 4, 8])], + preserve_unit_iters=True, + ) + sch.reorder(ts, tr, tile_s, vec_s, vec_c) + sch.bind(ts, TAG_S) + sch.bind(tr, TAG_R) + sch.vectorize(vec_s) + + # reduce tile_s * tr to tile_s + sch.reverse_compute_at(gemv, loop=bx, preserve_unit_loops=True) + tr, *ts_tile_s = sch.get_loops(block=gemv)[1:] + ts_tile_s = sch.fuse(*ts_tile_s) + ts, tile_s = sch.split( + ts_tile_s, factors=[TS, None], preserve_unit_iters=True + ) + sch.reorder(tile_s, ts, tr) + sch.bind(ts, TAG_S) + sch.bind(tr, TAG_R) + + sch.decompose_reduction(rf, loop=sch.get_loops(block=rf)[3]) + sch.decompose_reduction(rf2, loop=sch.get_loops(block=rf2)[-1]) + + sch.set_scope(rf, buffer_index=0, storage_scope="local") + sch.set_scope(rf2, buffer_index=0, storage_scope="local") + + unroll_factor = UNROLL + + sch.annotate( + block_or_loop=sch.get_loops(rf)[3], + ann_key="pragma_auto_unroll_max_step", + ann_val=unroll_factor, + ) + sch.annotate( + block_or_loop=sch.get_loops(rf)[3], + ann_key="pragma_unroll_explicit", + ann_val=1, + ) + + sch.annotate( + block_or_loop=sch.get_loops(rf2)[3], + ann_key="pragma_auto_unroll_max_step", + ann_val=unroll_factor, + ) + sch.annotate( + block_or_loop=sch.get_loops(rf2)[3], + ann_key="pragma_unroll_explicit", + ann_val=1, + ) + + if LOAD_V_SHARED: + sch.annotate( + block_or_loop=sch.get_loops(V_shared)[-4], + ann_key="pragma_unroll_explicit", + ann_val=unroll_factor, + ) + sch.annotate( + block_or_loop=sch.get_loops(V_shared)[-4], + ann_key="pragma_vectorize", + ann_val=1, + ) + + # Schedule epilogue + if epilogue_info is not None: + epilogue = epilogue_info.block_rv + if is_broadcast_epilogue(sch, block, epilogue): + sch.reverse_compute_at(epilogue, bx) + sch.set_scope(block, 0, "shared") + _, _, *s = sch.get_loops(epilogue) # pylint: disable=invalid-name + _, tx = sch.split(sch.fuse(*s), factors=[None, TX]) + sch.bind(tx, "threadIdx.x") + else: + sch.reverse_compute_at(epilogue, bx, preserve_unit_loops=True) + ts_tile_s = sch.fuse(*sch.get_loops(epilogue)[1:]) + ts_tile_s = sch.get_loops(epilogue)[-1] + ts, tile_s = sch.split( + ts_tile_s, factors=[TS, None], preserve_unit_iters=True + ) + sch.bind(ts, TAG_S) + sch.set_scope(block, 0, "local") + # pylint: enable=invalid-name + return sch + + # Specify the `len_tx` and `len_ty` according to the loop extent + batch, s, r, c = sch.get_loops(block=block) + len_batch, len_s, len_r, len_c = ( + get_extent(sch, batch), + get_extent(sch, s), + get_extent(sch, r), + get_extent(sch, c), + ) + len_S = len_batch * len_s + len_R = len_r * len_c + + TAG_S, TAG_R = "threadIdx.y", "threadIdx.x" + if target.kind.name == "cuda": + VEC_C = 4 + LOAD_V_SHARED = True + LOAD_V_VEC = 8 + UNROLL = 256 + if isinstance(len_S, int): + if len_S > len_R: + TS, TR = 4, 64 + else: + TS, TR = 16, 32 + elif target.kind.name == "metal": + # Note that the following tile size is tuned on M2 Ultra for 7B + TAG_S, TAG_R = "threadIdx.x", "threadIdx.y" + VEC_C = 1 + LOAD_V_SHARED = False + LOAD_V_VEC = -1 + UNROLL = 256 + if isinstance(len_S, int): + if len_S > len_R: + TS, TR = 4, 16 + else: + TS, TR = 2, 64 + elif target.kind.name == "rocm": + VEC_C = 4 + LOAD_V_SHARED = True + LOAD_V_VEC = 8 + UNROLL = 256 + if isinstance(len_S, int): + if len_S > len_R: + TS, TR = 1, 128 + else: + TS, TR = 8, 64 + elif target.kind.name == "opencl" and "android" in str(target.host): + TAG_S, TAG_R = "threadIdx.x", "threadIdx.y" + VEC_C = 8 + LOAD_V_SHARED = False + LOAD_V_VEC = -1 + UNROLL = 8 + TS, TR = 2, 32 + elif target.kind.name == "vulkan": + VEC_C = 4 + LOAD_V_SHARED = True + LOAD_V_VEC = 4 + UNROLL = 256 + if isinstance(len_S, int): + if len_S > len_R: + TS, TR = 4, 32 + else: + TS, TR = 16, 32 + elif target.kind.name == "opencl" and "mali" in str(target.attrs): + VEC_C = 8 + LOAD_V_SHARED = False + LOAD_V_VEC = -1 + UNROLL = 64 + TS, TR = 1, 64 + else: + VEC_C = 1 + LOAD_V_SHARED = False + LOAD_V_VEC = -1 + UNROLL = 64 + TS, TR = 1, 64 + + if not isinstance(len_S, int): + TS, TR = 1, 64 + + while TS * TR > target.max_num_threads: + if TS > 1: + TS //= 2 + else: + TR //= 2 + + TILE_S, TILE_R = ( + 1, + ( + len_c + if len_c > 1 + else max( + get_max_factor(len_r, [TR * 1, TR * 2, TR * 4, TR * 8]) // TR, 1 + ) + ), + ) + VEC_C = min(get_max_factor(TILE_R, [1, 2, 4, 8]), VEC_C) + VEC_LOAD = 1 + + return apply( + sch, + gemv=block, + TAG_S=TAG_S, + TAG_R=TAG_R, + TS=TS, + TR=TR, + TILE_S=TILE_S, + TILE_R=TILE_R, + VEC_LOAD=VEC_LOAD, + VEC_C=VEC_C, + LOAD_V_SHARED=LOAD_V_SHARED, + LOAD_V_VEC=LOAD_V_VEC, + UNROLL=UNROLL, + ) + + def sch_outer_reduction( # pylint: disable=too-many-arguments, invalid-name, unused-argument + self, + sch: tir.Schedule, + target: Target, + block: tir.schedule.BlockRV, + vector_input_buffers: List[tir.Buffer], + epilogue_info: Optional[BlockInfo], + ): + """Schedule the outer reduction block.""" + # NOTE: Only Android is supported so far + if not (target.kind.name == "opencl" and "android" in str(target.host)): + return None + batch, s, r, c = sch.get_loops(block) + len_s = get_extent(sch, s) + + # The config is designed for Adreno + tx_len = 64 + vec_len = (4 if len_s > 4096 else 2) if isinstance(len_s, int) else 1 + inner_r = 4 + + bx, tx, vec = sch.split(s, factors=[None, tx_len, vec_len]) + r0, r1 = sch.split(r, factors=[None, inner_r]) + sch.bind(batch, "blockIdx.y") + sch.bind(bx, "blockIdx.x") + sch.bind(tx, "threadIdx.x") + sch.reorder(bx, tx, r0, r1, c, vec) + + sch.annotate(tx, ann_key="pragma_auto_unroll_max_step", ann_val=8) + sch.annotate(tx, ann_key="pragma_unroll_explicit", ann_val=1) + + cache_v = sch.cache_read(block, vector_input_buffers[0], "local") + sch.compute_at(cache_v, r1, preserve_unit_loops=True) + sch.vectorize(sch.get_loops(cache_v)[-1]) + + sch.vectorize(vec) + + # Schedule epilogue + if epilogue_info is not None: + sch.reverse_compute_at(epilogue_info.block_rv, tx) + + sch.set_scope(block, 0, "local") + + sch.decompose_reduction(block, r0) + + return sch + + def sch_inner_reduction_with_config( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements + self, + func: tir.PrimFunc, + config, + ): + sch = tir.Schedule(func) + + block_infos = normalize_prim_func(sch) + + if block_infos is None: + return None + + reduction_block: tir.schedule.BlockRV = None + for block in block_infos: + s_loops: List[tir.schedule.LoopRV] = [] + r_loops: List[tir.schedule.LoopRV] = [] + o_loops: List[tir.schedule.LoopRV] = [] + dom_kind = block.dom_kind() + block = block.block_rv + + if ( + any( + [ + sch.get(loop_rv).thread_binding is not None + for loop_rv in sch.get_loops(block) + ] + ) + or len(sch.get_loops(block)) == 0 + ): + continue + + for loop, iter_type in zip(sch.get_loops(block), dom_kind): + {"S": s_loops, "R": r_loops, "O": o_loops}[iter_type].append(loop) + + if not s_loops: + s_loops.append(sch.add_unit_loop(block)) + if len(r_loops) > 0: + reduction_block = block + # skip analysis for following blocks + break + + def prod(iterable): + return reduce(lambda x, y: x * y, iterable, 1) + + vec = 1 + if len(config.vectorize): + vec = list(config.vectorize.values())[-1] + + num_warps = int(prod(config.thread)) + warp_size = int(prod(config.reduce_thread)) + + block_b = reduction_block + output_blocks = get_output_blocks(sch, block_infos) + # compute inline + for block_info in reversed(block_infos): + block = block_info.block_rv + if block not in (reduction_block, *output_blocks): + sch.compute_inline(block) + try: + i, j, k = sch.get_loops(block_b) + except: + j, k = sch.get_loops(block_b) + block_local_A = sch.cache_read(block_b, 0, "local") + block_local_B = sch.cache_read(block_b, 1, "local") + block_local_C = sch.cache_write(block_b, 0, "local") + # reverse inline + if reduction_block != None and reduction_block != output_blocks[0]: + sch.reverse_compute_inline(output_blocks[0]) + + bx, j = sch.split(j, factors=[None, num_warps]) + k, tx, vk = sch.split(k, factors=[None, warp_size, vec]) + sch.reorder(bx, j, k, tx) + + sch.bind(bx, "blockIdx.x") + sch.bind(tx, "threadIdx.x") + sch.bind(j, "threadIdx.y") + + self.block_size = [sch.get(tx).extent, sch.get(j).extent, 1] + self.grid_size = [sch.get(bx).extent, 1, 1] + + sch.compute_at(block_local_A, tx, preserve_unit_loops=True) + sch.compute_at(block_local_B, tx, preserve_unit_loops=True) + sch.reverse_compute_at(block_local_C, j, preserve_unit_loops=True) + + block_local_a_v = sch.get_loops(block_local_A)[-1] + sch.vectorize(block_local_a_v) + block_local_b_v = sch.get_loops(block_local_B)[-1] + sch.vectorize(block_local_b_v) + + return sch + + def sch_outer_reduction_with_config( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements + self, + func: tir.PrimFunc, + config, + ): + sch = tir.Schedule(func) + block_infos = normalize_prim_func(sch) + + if block_infos is None: + return None + + reduction_block: tir.schedule.BlockRV = None + for block in block_infos: + s_loops: List[tir.schedule.LoopRV] = [] + r_loops: List[tir.schedule.LoopRV] = [] + o_loops: List[tir.schedule.LoopRV] = [] + dom_kind = block.dom_kind() + block = block.block_rv + + if ( + any( + [ + sch.get(loop_rv).thread_binding is not None + for loop_rv in sch.get_loops(block) + ] + ) + or len(sch.get_loops(block)) == 0 + ): + continue + + for loop, iter_type in zip(sch.get_loops(block), dom_kind): + {"S": s_loops, "R": r_loops, "O": o_loops}[iter_type].append(loop) + + if not s_loops: + s_loops.append(sch.add_unit_loop(block)) + if len(r_loops) > 0: + reduction_block = block + # skip analysis for following blocks + break + + C = reduction_block + CL = sch.cache_write(reduction_block, 0, "local") + + blck_axis = [] + vthd_axis = [] + thrd_axis = [] + tile_axis = [] + # for gemv, we should skip dynamic symbolic in s_loops + s_loops = [ + loop for loop in s_loops if isinstance(sch.get(loop).extent, tir.IntImm) + ] + assert len(s_loops) == len( + config.block + ), f"{len(s_loops)} != {len(config.block)}" + for i, loop in enumerate(s_loops): + if sch.get(loop).extent % config.block[i]: + raise NotImplementedError( + "Undivisible block in TIR schedule is still buggy." + ) + bx, _t = sch.split(loop, factors=[None, config.block[i]]) + blck_axis.append(bx) + if config.step[i] > 1: + _t, tn = sch.split(_t, factors=[None, config.step[i]]) + tile_axis.append(tn) + if config.block[i] <= config.thread[i] * config.step[i]: + tx = _t + else: + vx, tx = sch.split(_t, factors=[None, config.thread[i]]) + vthd_axis.append(vx) + thrd_axis.append(tx) + + reduce_outer_axis, reduce_inner_axis = [], [] + + for i in config.raxis_order: + loop = r_loops[i] + ro, ri = sch.split(loop, factors=[None, config.rstep[i]]) + reduce_outer_axis.append(ro) + reduce_inner_axis.append(ri) + + vthd_axis = list(reversed(vthd_axis)) # inner virtual thread first + axis_order = ( + blck_axis + + vthd_axis + + thrd_axis + + reduce_outer_axis + + reduce_inner_axis + + tile_axis + ) + + sch.reorder(*axis_order) + blck_fused = sch.fuse(*blck_axis) + thrd_fused = sch.fuse(*thrd_axis) + sch.bind(blck_fused, "blockIdx.x") + sch.bind(thrd_fused, "threadIdx.x") + if len(vthd_axis) > 3: + vthd_axis = vthd_axis[0:2] + [sch.fuse(*vthd_axis[2:])] + for i, ax in enumerate(vthd_axis): + sch.bind(ax, "vthread" + [".x", ".y", ".z"][i]) + for ax in tile_axis: + sch.unroll(ax) + + sch.reverse_compute_at(CL, thrd_fused) + if len(tile_axis) > 0: + for ax in sch.get_loops(CL)[-len(tile_axis) :]: + sch.unroll(ax) + + sch.decompose_reduction(C, reduce_outer_axis[0]) + + try_inline_contiguous_spatial(sch, block_infos) + + return sch + + def apply_config( # pylint: disable=too-many-locals,missing-docstring + self, + func: tir.PrimFunc, + config, + ) -> tir.Schedule: + if not isinstance(func, tir.PrimFunc): + return None + sch = tir.Schedule(func) + block_infos = normalize_prim_func(sch) + block_infos = try_inline_contiguous_spatial(sch, block_infos) + if len(block_infos) == 1: + epilogue = None + elif len(block_infos) == 2: + epilogue = block_infos[1] + if not epilogue.is_injective(): + return None + else: + return None + + block_info = block_infos[0] + if len(block_info.iters) not in [2, 3]: + # either [B, S, R] = [B, S, R] * [B, R] + # or [S, R] = [S, R] * [R] + return None + + if is_gemv(sch, block_info) is None: + return None + + if "dequantize_info" in func.attrs: + dequantize_rule = GEMVWithDequantizeInfo() + return dequantize_rule.apply_config(func, config) + + if any([t > 1 for t in config.reduce_thread]): + return self.sch_inner_reduction_with_config(func, config) + + return self.sch_outer_reduction_with_config(func, config) diff --git a/python/bitblas/gpu/gemv_dequantize.py b/python/bitblas/gpu/gemv_dequantize.py new file mode 100644 index 000000000000..b3ac0578ded1 --- /dev/null +++ b/python/bitblas/gpu/gemv_dequantize.py @@ -0,0 +1,148 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""A rule for GEMV and DecodeGEMV.""" +import re +from functools import reduce +from typing import List, Optional, Union, Dict + +from tvm.tir.function import PrimFunc +from tvm import DataType, arith, ir, tir +from tvm.target import Target + +from ..base import ( + normalize_prim_func, + get_output_blocks, + get_block, +) +from .base import GPUScheduleRule + + +class GEMVWithDequantizeInfo(GPUScheduleRule): + """A rule for Dequantized GEMV.""" + + def sch_inner_reduction_with_config( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements + self, + func: tir.PrimFunc, + config, + ): + sch = tir.Schedule(func) + from .intrin.lop3 import get_lop3_intrin_group + + dequantize_info = func.attrs["dequantize_info"] + + def check_dequantize_info(dequantize_info): + conditions = [] + conditions.append(len(dequantize_info) == 1) + # more conditions, e.g. check the format is in [fp, nf, int] + # check if the dequantize value name is weight + return all(conditions) + + assert check_dequantize_info(dequantize_info) + + (B_decode_info,) = list(dequantize_info.values()) + block_infos = normalize_prim_func(sch) + + if block_infos is None: + return None + + reduction_block: tir.schedule.BlockRV = None + for block in block_infos: + s_loops: List[tir.schedule.LoopRV] = [] + r_loops: List[tir.schedule.LoopRV] = [] + o_loops: List[tir.schedule.LoopRV] = [] + dom_kind = block.dom_kind() + block = block.block_rv + + if ( + any( + [ + sch.get(loop_rv).thread_binding is not None + for loop_rv in sch.get_loops(block) + ] + ) + or len(sch.get_loops(block)) == 0 + ): + continue + + for loop, iter_type in zip(sch.get_loops(block), dom_kind): + {"S": s_loops, "R": r_loops, "O": o_loops}[iter_type].append(loop) + + if not s_loops: + s_loops.append(sch.add_unit_loop(block)) + if len(r_loops) > 0: + reduction_block = block + + def prod(iterable): + return reduce(lambda x, y: x * y, iterable, 1) + + def get_vectorize_factor(target_format): + # coalseced access requires the vectorize factor to be the same as the transaction size + return config.arch.transaction_size[-1] // DataType(target_format).bits + + vec = get_vectorize_factor(B_decode_info["target_format"]) + num_warps = int(prod(config.thread)) + warp_size = int(prod(config.reduce_thread)) + + block_b = reduction_block + output_blocks = get_output_blocks(sch, block_infos) + B_decode_block = get_block(sch, block_infos, B_decode_info["decode_block"]) + + # compute inline + for block_info in reversed(block_infos): + block = block_info.block_rv + if block not in (reduction_block, *output_blocks, B_decode_block): + sch.compute_inline(block) + + block_decode_B = sch.cache_read(block_b, 1, "local") + sch.compute_inline(B_decode_block) + + j, k = sch.get_loops(block_b)[-2:] + + block_shared_local_A = sch.cache_read(block_b, 0, "local") + block_shared_local_B = sch.cache_read(block_decode_B, 0, "local") + block_local_C = sch.cache_write(block_b, 0, "local") + # reverse inline + if reduction_block != None and reduction_block != output_blocks[0]: + sch.reverse_compute_inline(output_blocks[0]) + + bx, j = sch.split(j, factors=[None, num_warps]) + k, tx, vk = sch.split(k, factors=[None, warp_size, vec]) + sch.reorder(bx, j, k, tx) + + sch.bind(bx, "blockIdx.x") + sch.bind(tx, "threadIdx.x") + sch.bind(j, "threadIdx.y") + + self.block_size = [sch.get(tx).extent, sch.get(j).extent, 1] + self.grid_size = [sch.get(bx).extent, 1, 1] + + sch.compute_at(block_decode_B, tx, preserve_unit_loops=True) + sch.compute_at(block_shared_local_A, tx, preserve_unit_loops=True) + sch.compute_at(block_shared_local_B, tx, preserve_unit_loops=True) + sch.reverse_compute_at(block_local_C, j, preserve_unit_loops=True) + + block_local_a_v = sch.get_loops(block_shared_local_A)[-1] + sch.vectorize(block_local_a_v) + block_local_b_v = sch.get_loops(block_shared_local_B)[-1] + sch.vectorize(block_local_b_v) + if "fast_decoding" in B_decode_info and B_decode_info["fast_decoding"]: + source_bit = B_decode_info["source_format"]["bits"] + out_dtype = B_decode_info["target_format"] + intrin_info = get_lop3_intrin_group( + out_dtype=out_dtype, + storage_dtype=B_decode_info["storage_dtype"], + source_format=B_decode_info["source_format"]["format"], + source_bit=source_bit, + with_scaling=B_decode_info["with_scaling"], + ) + sch.tensorize(sch.get_loops(block_decode_B)[-1], intrin_info["compute"]) + sch.annotate( + block_b, ann_key="pragma_import_c", ann_val=intrin_info["c_source"] + ) + return sch + + def apply_config(self, func: PrimFunc, config): + if any([t > 1 for t in config.reduce_thread]): + return self.sch_inner_reduction_with_config(func, config) + else: + return None diff --git a/python/bitblas/gpu/general_reduction.py b/python/bitblas/gpu/general_reduction.py new file mode 100644 index 000000000000..cc03acd993c4 --- /dev/null +++ b/python/bitblas/gpu/general_reduction.py @@ -0,0 +1,465 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# pylint: disable=invalid-name +"""Reduction rule for operators including softmax, layer norm, RMS norm, etc""" +from typing import List, Union +from functools import reduce + +from tvm import tir +from tvm.target import Target + +from ..base import normalize_prim_func, try_inline_contiguous_spatial +from ..base.analysis import get_root_block, get_reduction_blocks, BlockInfo +from .base import GPUScheduleRule + + +class GeneralReduction(GPUScheduleRule): + """General Reduction rule for operators including softmax, layer norm, RMS norm, etc""" + + def apply( # pylint: disable=too-many-locals + self, + func: tir.PrimFunc, + target: Target, + _: bool, + ) -> Union[None, tir.Schedule, List[tir.Schedule]]: + if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): + return None + + if target.kind.name == "cuda": + len_tx = 256 + unroll_depth = 256 + else: + len_tx = 64 + unroll_depth = 64 + + sch = tir.Schedule(func) + block_infos = normalize_prim_func(sch) + block_infos = try_inline_contiguous_spatial(sch, block_infos) + if block_infos is None or len(block_infos) == 0: + return None + + dom_kind = block_infos[0].dom_kind() + num_leading_s = len(dom_kind) - len(dom_kind.lstrip("S")) + num_trailing_r = len(dom_kind) - len(dom_kind.rstrip("R")) + + # Align the number of block iters of the last block. + num_last_block_iter = len(block_infos[-1].dom_kind()) + if num_last_block_iter < len(dom_kind): + index_map = tir.IndexMap.from_func( + lambda *iters: ( + [tir.const(0, iters[0].dtype)] * (len(dom_kind) - num_last_block_iter) + + list(iters) + ), + ndim=num_last_block_iter, + ) + sch.transform_block_layout(block_infos[-1].block_rv, index_map) + + try: + # TODO: fix num_leading_s = 0 case + assert num_trailing_r > 0 + for block in block_infos[1:-1]: + assert block.dom_kind() == dom_kind + assert block_infos[-1].is_injective() + assert len(block_infos[-1].dom_kind()) <= len(dom_kind) + except AssertionError: + return None + + loops = sch.get_loops(block_infos[-1].block_rv) + bx = sch.fuse(*loops[:num_leading_s]) + r_loop, tx = sch.split(loops[-1], [None, len_tx]) + sch.reorder(tx, r_loop) + sch.bind(bx, "blockIdx.x") + sch.bind(tx, "threadIdx.x") + sch.annotate(r_loop, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth) + sch.annotate(r_loop, ann_key="pragma_unroll_explicit", ann_val=1) + + for block in reversed(block_infos[:-1]): + block = block.block_rv + for i, _ in enumerate(sch.get(block).writes): + sch.set_scope(block, buffer_index=i, storage_scope="shared") + sch.compute_at(block, bx, preserve_unit_loops=True) + r_loop = sch.fuse(*sch.get_loops(block)[-num_trailing_r:]) + r_loop, tx = sch.split(r_loop, [None, len_tx]) + sch.reorder(tx, r_loop) + sch.bind(tx, "threadIdx.x") + sch.annotate(r_loop, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth) + sch.annotate(r_loop, ann_key="pragma_unroll_explicit", ann_val=1) + + # TODO: It's just a workaround to avoid unroll spatial loops, because of the bug of + # the pass lower-thread-allreduce. We should fix it in the future. + # sch.annotate(bx, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth) + # sch.annotate(bx, ann_key="pragma_unroll_explicit", ann_val=1) + return sch + + def sch_inner_reduction_with_config( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements + self, + func: tir.PrimFunc, + config, + ): + block_factors = config.block + thread_factors = config.thread + reduce_therad_factors = config.reduce_thread + + # For inter thread reduction case, one thread must only compute one element + assert thread_factors == block_factors + + # inline all the other blocks + sch = tir.Schedule(func) + block_infos = normalize_prim_func(sch) + + schedule_block: tir.schedule.BlockRV = None + reduction_blocks: List[tir.schedule.BlockRV] = [] + for block in block_infos: + s_loops: List[tir.schedule.LoopRV] = [] + r_loops: List[tir.schedule.LoopRV] = [] + o_loops: List[tir.schedule.LoopRV] = [] + dom_kind = block.dom_kind() + block_rv = block.block_rv + + if ( + any( + [ + sch.get(loop_rv).thread_binding is not None + for loop_rv in sch.get_loops(block_rv) + ] + ) + or len(sch.get_loops(block.block_rv)) == 0 + ): + continue + + for loop, iter_type in zip(sch.get_loops(block_rv), dom_kind): + {"S": s_loops, "R": r_loops, "O": o_loops}[iter_type].append(loop) + + if not s_loops: + s_loops.append(sch.add_unit_loop(block_rv)) + if len(r_loops) > 0: + # always use the last reduction block for scheduling + schedule_block = block + reduction_blocks.append(block_rv) + + # Align the number of block iters of the last block. + dom_kind = schedule_block.dom_kind() + num_leading_s = len(dom_kind) - len(dom_kind.lstrip("S")) + num_trailing_r = len(dom_kind) - len(dom_kind.rstrip("R")) + + schedule_block = schedule_block.block_rv + loops = sch.get_loops(schedule_block) + s_loops = loops[:num_leading_s] + r_loops = loops[-num_trailing_r:] + + block_axis = [] + thread_axis = [] + + for s_loop, block_factor in zip(s_loops, block_factors): + block_loop, thread_loop = sch.split(s_loop, factors=[None, block_factor]) + block_axis.append(block_loop) + thread_axis.append(thread_loop) + + axis_order = block_axis + thread_axis + + sch.reorder(*axis_order) + blck_fused = sch.fuse(*block_axis) + thrd_fused = sch.fuse(*thread_axis) + sch.bind(blck_fused, "blockIdx.x") + sch.bind(thrd_fused, "threadIdx.y") + + reduce_outer_axis, reduce_inner_axis, reduce_inter_threads = [], [], [] + for i in config.raxis_order: + loop = r_loops[i] + ro, ri = sch.split(loop, factors=[None, config.rstep[i]]) + ri, thd = sch.split(ri, factors=[None, config.reduce_thread[i]]) + reduce_inter_threads.append(thd) + reduce_outer_axis.append(ro) + reduce_inner_axis.append(ri) + + axis_order = reduce_inter_threads + reduce_outer_axis + reduce_inner_axis + sch.reorder(*axis_order) + fused_reduce_inter_threads = sch.fuse(*reduce_inter_threads) + sch.bind(fused_reduce_inter_threads, "threadIdx.x") + + def prod(iterable): + return reduce(lambda x, y: x * y, iterable, 1) + + reg_tile = sch.cache_write(schedule_block, 0, "local") + + # todo(lei): should add the shared_inputs/stride memory pad analysis at shared memory fusion stage. + for i, input_region in enumerate(sch.get(schedule_block).reads): + if input_region.buffer.name not in config.cached_tensors: + continue + + # otherwise cooperative fetch in shared memory. + cache_shared = sch.cache_read(schedule_block, i, "shared") + sch.compute_at(cache_shared, reduce_outer_axis[-1]) + + dim_offset = ( + len(reduce_inner_axis) + len(reduce_outer_axis) + 2 + ) # outer loops are: blck_fused, thrd_fused, vthread_axis, reduce_outer_axis + if input_region.buffer.name in config.vectorize: + vectorize = config.vectorize[input_region.buffer.name] + else: + vectorize = 1 + + loops = sch.get_loops(cache_shared) + if len(loops) == dim_offset: + # handle fetching only one element + loops.append(sch.add_unit_loop(schedule_block)) + assert len(loops) > dim_offset + + _, ty, tx, tv = sch.split( + sch.fuse(*loops[dim_offset:]), + factors=[ + None, + int(prod(thread_factors)), + int(prod(reduce_therad_factors)), + vectorize, + ], + ) + sch.vectorize(tv) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + sch.reverse_compute_at(reg_tile, thrd_fused) + + # resolve compute_at + block_infos = try_inline_contiguous_spatial(sch, block_infos) + if block_infos is None or len(block_infos) == 0: + return None + return sch + + def sch_outer_reduction_with_config( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements + self, + func: tir.PrimFunc, + config, + ): + block_factors = config.block + thread_factors = config.thread + step_factors = config.step + + # inline all the other blocks + sch = tir.Schedule(func) + block_infos = normalize_prim_func(sch) + + schedule_block: BlockInfo = None + for block in block_infos: + s_loops: List[tir.schedule.LoopRV] = [] + r_loops: List[tir.schedule.LoopRV] = [] + o_loops: List[tir.schedule.LoopRV] = [] + dom_kind = block.dom_kind() + block_rv = block.block_rv + + if ( + any( + [ + sch.get(loop_rv).thread_binding is not None + for loop_rv in sch.get_loops(block_rv) + ] + ) + or len(sch.get_loops(block.block_rv)) == 0 + ): + continue + + for loop, iter_type in zip(sch.get_loops(block_rv), dom_kind): + {"S": s_loops, "R": r_loops, "O": o_loops}[iter_type].append(loop) + + if not s_loops: + s_loops.append(sch.add_unit_loop(block_rv)) + if len(r_loops) > 0: + # always use the last reduction block for scheduling + schedule_block = block + + # Align the number of block iters of the last block. + dom_kind = schedule_block.dom_kind() + num_leading_s = len(dom_kind) - len(dom_kind.lstrip("S")) + num_trailing_r = len(dom_kind) - len(dom_kind.rstrip("R")) + + num_last_block_iter = len(block_infos[-1].dom_kind()) + if num_last_block_iter < len(dom_kind): + index_map = tir.IndexMap.from_func( + lambda *iters: ( + [tir.const(0, iters[0].dtype)] * (len(dom_kind) - num_last_block_iter) + + list(iters) + ), + ndim=num_last_block_iter, + ) + sch.transform_block_layout(block_infos[-1].block_rv, index_map) + + schedule_block = schedule_block.block_rv + loops = sch.get_loops(schedule_block) + s_loops = loops[:num_leading_s] + r_loops = loops[-num_trailing_r:] + + reg_tile = sch.cache_write(schedule_block, 0, "local") + + block_axis = [] + vthread_axis = [] + thread_axis = [] + inner_axis = [] + for s_loop, block_factor, step_factor, thread_factor in zip( + s_loops, block_factors, step_factors, thread_factors + ): + block_loop, inner_loop = sch.split(s_loop, factors=[None, block_factor]) + vthread_loop, inner_loop = sch.split( + inner_loop, factors=[None, thread_factor * step_factor] + ) + thread_loop, inner_loop = sch.split(inner_loop, factors=[None, step_factor]) + block_axis.append(block_loop) + vthread_axis.append(vthread_loop) + thread_axis.append(thread_loop) + inner_axis.append(inner_loop) + + reduce_outer_axis, reduce_inner_axis = [], [] + for i in config.raxis_order: + loop = r_loops[i] + ro, ri = sch.split(loop, factors=[None, config.rstep[i]]) + reduce_outer_axis.append(ro) + reduce_inner_axis.append(ri) + + vthread_axis = list(reversed(vthread_axis)) # inner virtual thread first + axis_order = ( + block_axis + + vthread_axis + + thread_axis + + reduce_outer_axis + + reduce_inner_axis + + inner_axis + ) + + sch.reorder(*axis_order) + blck_fused = sch.fuse(*block_axis) + thrd_fused = sch.fuse(*thread_axis) + sch.bind(blck_fused, "blockIdx.x") + sch.bind(thrd_fused, "threadIdx.x") + if len(vthread_axis) > 3: + vthread_axis = vthread_axis[0:2] + [sch.fuse(*vthread_axis[2:])] + for i, ax in enumerate(vthread_axis): + sch.bind(ax, "vthread" + [".x", ".y", ".z"][i]) + + # todo(lei): should add the shared_inputs/stride memory pad analysis at shared memory fusion stage. + for i, input_region in enumerate(sch.get(schedule_block).reads): + if input_region.buffer.name not in config.cached_tensors: + continue + + # otherwise cooperative fetch in shared memory. + cache_shared = sch.cache_read(schedule_block, i, "shared") + sch.compute_at(cache_shared, reduce_outer_axis[-1]) + + dim_offset = ( + len(vthread_axis) + len(reduce_outer_axis) + 2 + ) # outer loops are: blck_fused, thrd_fused, vthread_axis, reduce_outer_axis + if input_region.buffer.name in config.vectorize: + vectorize = config.vectorize[input_region.buffer.name] + else: + vectorize = 1 + + loops = sch.get_loops(cache_shared) + if len(loops) == dim_offset: + # handle fetching only one element + loops.append(sch.add_unit_loop(schedule_block)) + assert len(loops) > dim_offset + + def prod(iterable): + return reduce(lambda x, y: x * y, iterable, 1) + + _, tx, tv = sch.split( + sch.fuse(*loops[dim_offset:]), factors=[None, int(prod(thread_factors)), vectorize] + ) + sch.vectorize(tv) + sch.bind(tx, "threadIdx.x") + + sch.reverse_compute_at(reg_tile, thrd_fused) + + sch.decompose_reduction(schedule_block, reduce_outer_axis[0]) + + # resolve compute_at + block_infos = try_inline_contiguous_spatial(sch, block_infos) + if block_infos is None or len(block_infos) == 0: + return None + + return sch + + def sch_mutiple_reductions_with_config( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements + self, + func: tir.PrimFunc, + config, + ): + block_factors = config.block + thread_factors = config.thread + reduce_therad_factors = config.reduce_thread + + sch = tir.Schedule(func) + block_infos = normalize_prim_func(sch) + block_infos = try_inline_contiguous_spatial(sch, block_infos) + if block_infos is None or len(block_infos) == 0: + return None + + def prod(iterable): + return reduce(lambda x, y: x * y, iterable, 1) + + len_tx = prod(thread_factors) * prod(reduce_therad_factors) + block_factor = prod(block_factors) + + dom_kind = block_infos[0].dom_kind() + num_leading_s = len(dom_kind) - len(dom_kind.lstrip("S")) + num_trailing_r = len(dom_kind) - len(dom_kind.rstrip("R")) + + # Align the number of block iters of the last block. + num_last_block_iter = len(block_infos[-1].dom_kind()) + if num_last_block_iter < len(dom_kind): + index_map = tir.IndexMap.from_func( + lambda *iters: ( + [tir.const(0, iters[0].dtype)] * (len(dom_kind) - num_last_block_iter) + + list(iters) + ), + ndim=num_last_block_iter, + ) + sch.transform_block_layout(block_infos[-1].block_rv, index_map) + + try: + # TODO: fix num_leading_s = 0 case + assert num_trailing_r > 0 + for block in block_infos[1:-1]: + assert block.dom_kind() == dom_kind + assert block_infos[-1].is_injective() + assert len(block_infos[-1].dom_kind()) <= len(dom_kind) + except AssertionError: + return None + + loops = sch.get_loops(block_infos[-1].block_rv) + bx, _ = sch.split(sch.fuse(*loops[:num_leading_s]), factors=[None, block_factor]) + r_loop, tx = sch.split(loops[-1], [None, len_tx]) + sch.reorder(tx, r_loop) + sch.bind(bx, "blockIdx.x") + sch.bind(tx, "threadIdx.x") + + for block in reversed(block_infos[:-1]): + block = block.block_rv + for i, _ in enumerate(sch.get(block).writes): + sch.set_scope(block, buffer_index=i, storage_scope="shared") + sch.compute_at(block, bx, preserve_unit_loops=True) + r_loop = sch.fuse(*sch.get_loops(block)[-num_trailing_r:]) + r_loop, tx = sch.split(r_loop, [None, len_tx]) + sch.reorder(tx, r_loop) + sch.bind(tx, "threadIdx.x") + + return sch + + def apply_config( # pylint: disable=too-many-locals,missing-docstring + self, + func: tir.PrimFunc, + config, + ) -> tir.Schedule: + # check the number of reduction blocks + sch = tir.Schedule(func) + root_block = get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + reduction_blocks = get_reduction_blocks(sch, blocks) + if len(reduction_blocks) > 1: + # schedule for multiple reduction blocks (e.g. softmax) + return self.sch_mutiple_reductions_with_config(func, config) + + if any([t > 1 for t in config.reduce_thread]): + # todo(lei) should implement block reduction schedule + return self.sch_inner_reduction_with_config(func, config) + else: + return self.sch_outer_reduction_with_config(func, config) diff --git a/python/bitblas/gpu/intrin/lop3.py b/python/bitblas/gpu/intrin/lop3.py new file mode 100644 index 000000000000..c32e5c8d7ccd --- /dev/null +++ b/python/bitblas/gpu/intrin/lop3.py @@ -0,0 +1,708 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import tvm +from tvm.tir.function import TensorIntrin +from tvm.script import tir as T +from typing import Dict, Literal +from bitblas.quantization import ( + _tir_packed_to_signed_float, + _tir_packed_to_unsigned_float, +) + + +decode_i4_to_f16 = """ +template +__device__ void decode_i4b_to_f16(T1 *_i4s, T2 *B_local_decode, const int N = 8) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x000f000f; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 0x64076407 : 0x64006400; + uint const i4s = *reinterpret_cast(_i4s); +#pragma unroll + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + } +} + +template +__device__ void decode_i4s_to_f16(T1 *_i4s, T2 *B_local_decode, const int N = 8) +{ + decode_i4b_to_f16(_i4s, B_local_decode, N); +} + +template +__device__ void decode_i4u_to_f16(T1 *_i4u, T2 *B_local_decode, const int N = 8) +{ + decode_i4b_to_f16(_i4u, B_local_decode, N); +} +""" + +decode_i4_to_f16_scale = """ +template +__device__ void decode_i4b_to_f16_scale(T1 *_i4s, T2 *B_local_decode, T3 *scale, const int N = 8) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x000f000f; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 0x64076407 : 0x64006400; + uint const i4s = *reinterpret_cast(_i4s); +#pragma unroll + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + unsigned v0 = *((unsigned short *)scale); + unsigned v1 = *((unsigned short *)scale); + unsigned __packed_scale = (v1 << 16) | v0; + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__packed_scale), "r"(0)); + } + +} + +template +__device__ void decode_i4s_to_f16_scale(T1 *_i4s, T2 *B_local_decode, T3 *scale, const int N = 8) +{ + decode_i4b_to_f16_scale(_i4s, B_local_decode, scale, N); +} + +template +__device__ void decode_i4u_to_f16_scale(T1 *_i4u, T2 *B_local_decode, T3 *scale, const int N = 8) +{ + decode_i4b_to_f16_scale(_i4u, B_local_decode, scale, N); +} +""" + +decode_i2_to_f16 = """ +template +__device__ void decode_i2b_to_f16(T1 *_i2s, T2 *B_local_decode, const int N = 8) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00030003; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 0x64016401 : 0x64006400; + int16_t const i2s_i16 = *reinterpret_cast(_i2s); + // decode 2 elems at one time. + // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0} + // otherwise the pointer of _i2s should be moved to + int i2s = (i2s_i16 & 0x00ff); + i2s |= ((i2s_i16 & 0xff00) << 8); + +#pragma unroll + for (int i = 0; i < (N / 2); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + } +} + +template +__device__ void decode_i2s_to_f16(T1 *_i2s, T2 *B_local_decode, const int N = 8) +{ + decode_i2b_to_f16(_i2s, B_local_decode, N); +} + +template +__device__ void decode_i2u_to_f16(T1 *_i2u, T2 *B_local_decode, const int N = 8) +{ + decode_i2b_to_f16(_i2u, B_local_decode, N); +} +""" + +decode_i2_to_f16_scale = """ +template +__device__ void decode_i2b_to_f16_scale(T1 *_i2s, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00030003; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 0x64016401 : 0x64006400; + int16_t const i2s_i16 = *reinterpret_cast(_i2s); + // decode 2 elems at one time. + // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0} + // otherwise the pointer of _i2s should be moved to + int i2s = (i2s_i16 & 0x00ff); + i2s |= ((i2s_i16 & 0xff00) << 8); + +#pragma unroll + for (int i = 0; i < (N / 2); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*scale, *scale)), "r"(0)); + } +} + +template +__device__ void decode_i2s_to_f16_scale(T1 *_i2s, T2 *B_local_decode, T3 *scale, const int N = 8) +{ + decode_i2b_to_f16_scale(_i2s, B_local_decode, scale, N); +} + +template +__device__ void decode_i2u_to_f16_scale(T1 *_i2u, T2 *B_local_decode, T3 *scale, const int N = 8) +{ + decode_i2b_to_f16_scale(_i2u, B_local_decode, scale, N); +} +""" + +decode_i1s_to_i8s_l16 = """template +__device__ void decode_i1s_to_i8s_l16(T1 *_i1s, T2 *_i8s, const int N = 16) +{ + int *i8s = reinterpret_cast(_i8s); + int16_t i1s_i16 = *reinterpret_cast(_i1s); + // permutate: {e0,e4,e8,e12,e2,e6,e10,e14,e1,e5,e9,e13,e3,e7,e11,e15} + // into: {e0,e4,e8,e12,x,x,x,x,e1,e5,e9,x,x,x,x,e13,e2,e6,e10,e14,e1,e5,e9,e13,e3,e7,e11,e15,x,x,x,x} + int i1s = (i1s_i16 & 0x0f0f); + i1s |= ((i1s_i16 & 0xf0f0) << 12); + // i1s {0..,e15,e14,e13,e12,e11,e10,e9,e8,e7,e6,e5,e4,e3,e2,e1,e0} + // interleave {0..,e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0} + // First, we extract the i1s and construct an intermediate fp16 number. + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010 + static constexpr uint BOTTOM_MASK = 0x01010101; // 0x1 -> 0b01 select 0,1 + static constexpr uint I8s_MAGIC_NUM = 0x00000000; + + for (int i = 0; i < N / 4; i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(i8s[i]) + : "r"(i1s >> i), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut)); + } +} +""" + +decode_i2s_to_i8s = """template +__device__ void decode_i2s_to_i8s(T1 *_i2s, T2 *_i8s, const int N = 16) +{ + // convert 8 int2b_t to 8 int8b_t -> 2 int32 + uint *i8s = reinterpret_cast(_i8s); + + // i2s = {e7,e6,e5,e4,e3,e2,e1,e0} + // also require interleave {e7,e3,e6,e2,e5,e1,e4,e0} + uint const i2s = *_i2s; + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010 + static constexpr uint BOTTOM_MASK = 0x03030303; // 0xf -> 0b11 select 0,3 + static constexpr uint I4s_TO_INT8_TO_I8s_MAGIC_NUM = 0x00000000; // 1024 + +#pragma unroll + for (int i = 0; i < (N / 2); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(i8s[i]) + : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(I4s_TO_INT8_TO_I8s_MAGIC_NUM), "n"(immLut)); + } +} +""" + +decode_i4s_to_i8s = """template +__device__ void decode_i4s_to_i8s(T1 *_i4s, T2 *_i8s, const int N = 16) +{ + uint *i8s = reinterpret_cast(_i8s); + uint *i4s = reinterpret_cast(_i4s); + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010 + static constexpr uint BOTTOM_MASK = 0x0f0f0f0f; // 0xf -> 0b1111 select 0,4 + static constexpr uint I4s_TO_INT8_TO_I8s_MAGIC_NUM = 0x00000000; // 1024 +#pragma unroll + for (int i = 0; i < (N / 8); i++) + { + // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(i8s[i]) + : "r"(i4s[0] >> (4 * i)), "n"(BOTTOM_MASK), "n"(I4s_TO_INT8_TO_I8s_MAGIC_NUM), "n"(immLut)); + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(i8s[i + 2]) + : "r"(i4s[1] >> (4 * i)), "n"(BOTTOM_MASK), "n"(I4s_TO_INT8_TO_I8s_MAGIC_NUM), "n"(immLut)); + } +} +""" + + +def get_fast_decode_intrin( + source_bit=4, + storage_dtype="int8", + source_format="uint", + target_dtype="float16", + loops_extent=8, + with_scale=False, +): + """ + loops extent is the number of elements to be decoded in one stage + for memory friendly process, the loops_extent should be a multiple of (sizeof(int) // 8). + However, for the case of int1b, it is not possible to decode 8 elements in one stage, so we have to use 16. + """ + if target_dtype == "float16": + d4f = "f16" + elif target_dtype == "int8": + d4f = "i8s" + else: + raise ValueError("Unsupported target dtype: {}".format(target_dtype)) + source_symbol = "u" if source_format == "uint" else "s" + func_name = "decode_i{}{}_to_{}".format(source_bit, source_symbol, d4f) + if with_scale: + func_name += "_scale" + + assert storage_dtype in ["int8", "int32", "uint32"] + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) + elem_per_unit = storage_nbit // source_bit + n_storage_elems = loops_extent // elem_per_unit + + if source_format == "int": + decode_func = _tir_packed_to_signed_float(storage_type, storage_nbit) + elif source_format == "uint": + decode_func = _tir_packed_to_unsigned_float(storage_type, storage_nbit) + else: + raise ValueError("Unsupported source_format: {}".format(source_format)) + + if with_scale is False: + + @T.prim_func + def fast_decode_desc(compressed: T.handle, decompressed: T.handle) -> None: + Compressed = T.match_buffer( + compressed, + [ + n_storage_elems, + ], + dtype=storage_dtype, + scope="local", + ) + Decompressed = T.match_buffer( + decompressed, + [ + loops_extent, + ], + dtype=target_dtype, + scope="local", + ) + + with T.block("root"): + T.reads(Compressed[0:n_storage_elems]) + T.writes(Decompressed[0:loops_extent]) + for i in T.grid(loops_extent): + with T.block("decode"): + vi = T.axis.remap("S", [i]) + Decompressed[vi] = decode_func( + source_bit, + Compressed[vi // elem_per_unit], + vi % elem_per_unit, + dtype=target_dtype, + ) + + @T.prim_func + def fast_decode_impl(compressed: T.handle, decompressed: T.handle) -> None: + Compressed = T.match_buffer( + compressed, + [ + n_storage_elems, + ], + dtype=storage_dtype, + scope="local", + ) + Decompressed = T.match_buffer( + decompressed, + [ + loops_extent, + ], + dtype=target_dtype, + scope="local", + ) + + with T.block("root"): + T.reads(Compressed[0:n_storage_elems]) + T.writes(Decompressed[0:loops_extent]) + T.call_extern( + "handle", + func_name, + Compressed.data, + Decompressed.data, + loops_extent, + ) + + else: + + @T.prim_func + def fast_decode_desc( + compressed: T.handle, decompressed: T.handle, scale: T.handle + ) -> None: + Compressed = T.match_buffer( + compressed, + [ + n_storage_elems, + ], + dtype=storage_dtype, + scope="local", + ) + Decompressed = T.match_buffer( + decompressed, + [ + loops_extent, + ], + dtype=target_dtype, + scope="local", + ) + Scale = T.match_buffer( + scale, + [ + 1, + ], + dtype=target_dtype, + ) + with T.block("root"): + T.reads(Compressed[0:n_storage_elems], Scale[0:1]) + T.writes(Decompressed[0:loops_extent]) + for i in T.grid(loops_extent): + with T.block("decode"): + vi = T.axis.remap("S", [i]) + Decompressed[vi] = ( + decode_func( + source_bit, + Compressed[vi // elem_per_unit], + vi % elem_per_unit, + dtype=target_dtype, + ) + * Scale[0] + ) + + @T.prim_func + def fast_decode_impl( + compressed: T.handle, decompressed: T.handle, scale: T.handle + ) -> None: + s0 = T.int32() + + Compressed = T.match_buffer( + compressed, + [ + n_storage_elems, + ], + dtype=storage_dtype, + scope="local", + ) + Decompressed = T.match_buffer( + decompressed, + [ + loops_extent, + ], + dtype=target_dtype, + scope="local", + ) + Scale = T.match_buffer( + scale, + [ + 1, + ], + dtype=target_dtype, + offset_factor=1, + strides=[s0], + ) + with T.block("root"): + T.reads(Compressed[0:n_storage_elems], Scale[0:1]) + T.writes(Decompressed[0:loops_extent]) + T.call_extern( + "handle", + func_name, + Compressed.data, + Decompressed.data, + Scale.access_ptr("r"), + loops_extent, + ) + + return fast_decode_desc, fast_decode_impl + + +LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_INTRIN = ( + "lop3_fast_decode_u4_to_int8_to_f16_l8_" +) +TensorIntrin.register( + LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_INTRIN, + *get_fast_decode_intrin( + source_bit=4, storage_dtype="int8", target_dtype="float16", loops_extent=8 + ), +) + + +LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_INTRIN = ( + "lop3_fast_decode_u2_to_int8_to_f16_l8_" +) +TensorIntrin.register( + LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_INTRIN, + *get_fast_decode_intrin( + source_bit=2, storage_dtype="int8", target_dtype="float16", loops_extent=8 + ), +) + +LOP3_FAST_DECODE_UINT4_TO_INT32_TO_FP16_L8_INTRIN = ( + "lop3_fast_decode_u4_to_int32_to_f16_l8_" +) +TensorIntrin.register( + LOP3_FAST_DECODE_UINT4_TO_INT32_TO_FP16_L8_INTRIN, + *get_fast_decode_intrin( + source_bit=4, storage_dtype="int32", target_dtype="float16", loops_extent=8 + ), +) + + +LOP3_FAST_DECODE_UINT4_TO_INT32_TO_FP16_L8_SCALE_INTRIN = ( + "lop3_fast_decode_u4_to_int32_to_f16_l8_scale_" +) +TensorIntrin.register( + LOP3_FAST_DECODE_UINT4_TO_INT32_TO_FP16_L8_SCALE_INTRIN, + *get_fast_decode_intrin( + source_bit=4, + storage_dtype="int32", + target_dtype="float16", + loops_extent=8, + with_scale=True, + ), +) + +LOP3_FAST_DECODE_UINT4_TO_UINT32_TO_FP16_L8_INTRIN = ( + "lop3_fast_decode_u4_to_uint32_to_f16_l8_" +) +TensorIntrin.register( + LOP3_FAST_DECODE_UINT4_TO_UINT32_TO_FP16_L8_INTRIN, + *get_fast_decode_intrin( + source_bit=4, storage_dtype="uint32", target_dtype="float16", loops_extent=8 + ), +) + + +LOP3_FAST_DECODE_UINT4_TO_UINT32_TO_FP16_L8_SCALE_INTRIN = ( + "lop3_fast_decode_u4_to_uint32_to_f16_l8_scale_" +) +TensorIntrin.register( + LOP3_FAST_DECODE_UINT4_TO_UINT32_TO_FP16_L8_SCALE_INTRIN, + *get_fast_decode_intrin( + source_bit=4, + storage_dtype="uint32", + target_dtype="float16", + loops_extent=8, + with_scale=True, + ), +) + + +LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_INTRIN = ( + "lop3_fast_decode_u4_to_int8_to_f16_l8_scale_" +) +TensorIntrin.register( + LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_INTRIN, + *get_fast_decode_intrin( + source_bit=4, + storage_dtype="int8", + target_dtype="float16", + loops_extent=8, + with_scale=True, + ), +) + +LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_INTRIN = ( + "lop3_fast_decode_u2_to_int8_to_f16_l8_scale_" +) +TensorIntrin.register( + LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_INTRIN, + *get_fast_decode_intrin( + source_bit=2, + storage_dtype="int8", + target_dtype="float16", + loops_extent=8, + with_scale=True, + ), +) + + +LOP3_FAST_DECODE_UINT4_TO_INT8_TO_INT8_L8_INTRIN = ( + "lop3_fast_decode_u4_to_int8_to_i8_l8_" +) +TensorIntrin.register( + LOP3_FAST_DECODE_UINT4_TO_INT8_TO_INT8_L8_INTRIN, + *get_fast_decode_intrin( + source_bit=4, storage_dtype="int8", target_dtype="int8", loops_extent=8 + ), +) + +LOP3_FAST_DECODE_UINT4_TO_INT8_TO_INT8_L16_INTRIN = ( + "lop3_fast_decode_u4_to_int8_to_i8_l16_" +) +TensorIntrin.register( + LOP3_FAST_DECODE_UINT4_TO_INT8_TO_INT8_L16_INTRIN, + *get_fast_decode_intrin( + source_bit=4, storage_dtype="int8", target_dtype="int8", loops_extent=16 + ), +) + +LOP3_FAST_DECODE_UINT2_TO_INT8_TO_INT8_L16_INTRIN = ( + "lop3_fast_decode_u2_to_int8_to_i8_l16_" +) +TensorIntrin.register( + LOP3_FAST_DECODE_UINT2_TO_INT8_TO_INT8_L16_INTRIN, + *get_fast_decode_intrin( + source_bit=2, storage_dtype="int8", target_dtype="int8", loops_extent=16 + ), +) + +LOP3_FAST_DECODE_INT2_TO_INT8_TO_INT8_L16_INTRIN = ( + "lop3_fast_decode_i2_to_int8_to_i8_l16_" +) +TensorIntrin.register( + LOP3_FAST_DECODE_INT2_TO_INT8_TO_INT8_L16_INTRIN, + *get_fast_decode_intrin( + source_bit=2, storage_dtype="int8", target_dtype="int8", loops_extent=16 + ), +) + +LOP3_FAST_DECODE_UINT1_TO_INT8_TO_INT8_L16_INTRIN = ( + "LOP3_FAST_DECODE_UINT1_to_int8_to_i8_l16_" +) +TensorIntrin.register( + LOP3_FAST_DECODE_UINT1_TO_INT8_TO_INT8_L16_INTRIN, + *get_fast_decode_intrin( + source_bit=1, storage_dtype="int8", target_dtype="int8", loops_extent=16 + ), +) + +LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_INTRIN = ( + "lop3_fast_decode_i4_to_int8_to_f16_l8_" +) +TensorIntrin.register( + LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_INTRIN, + *get_fast_decode_intrin( + source_bit=4, + storage_dtype="int8", + source_format="int", + target_dtype="float16", + loops_extent=8, + ), +) + +LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_SCALE_INTRIN = ( + "lop3_fast_decode_i4_to_int8_to_f16_l8_scale_" +) +TensorIntrin.register( + LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_SCALE_INTRIN, + *get_fast_decode_intrin( + source_bit=4, + storage_dtype="int8", + source_format="int", + target_dtype="float16", + loops_extent=8, + with_scale=True, + ), +) + +LOP3_FAST_DECODE_INT2_TO_INT8_TO_FP16_L8_INTRIN = ( + "lop3_fast_decode_i2_to_int8_to_f16_l8_" +) +TensorIntrin.register( + LOP3_FAST_DECODE_INT2_TO_INT8_TO_FP16_L8_INTRIN, + *get_fast_decode_intrin( + source_bit=2, + storage_dtype="int8", + source_format="int", + target_dtype="float16", + loops_extent=8, + ), +) + +LOP3_FAST_DECODE_INT2_TO_INT8_TO_FP16_L8_SCALE_INTRIN = ( + "lop3_fast_decode_i2_to_int8_to_f16_l8_scale_" +) +TensorIntrin.register( + LOP3_FAST_DECODE_INT2_TO_INT8_TO_FP16_L8_SCALE_INTRIN, + *get_fast_decode_intrin( + source_bit=2, + storage_dtype="int8", + source_format="int", + target_dtype="float16", + loops_extent=8, + with_scale=True, + ), +) + +def get_lop3_intrin_group( + out_dtype: Literal["float16", "int8"], + source_format: Literal["int", "uint"] = "uint", + source_bit: int = 4, + storage_dtype: Literal["int32", "int8"] = "int8", + with_scaling: bool = False, +) -> Dict[str, str]: + """ + This function is used to get the intrinsic group of the LOP3 operation to avoid the overhead of fast decoding. + LOP3 is a type of logic operation that takes three inputs. The intrinsic group refers to the set of + intrinsic operations that can be performed on these inputs. This function retrieves and returns this group. + + Parameters + ---------- + in_dtype : Literal["int8"] + The data type of the input. It should be "int8". + + out_dtype : Literal["float16", "int8"] + The data type of the output. It can be either "float16" or "int8". + + storage_nbit : int, optional + The number of bits used for storage. By default, it is 4. + + with_scale : bool, optional + A boolean parameter that indicates whether scaling should be applied. By default, it is False. + + Returns + ------- + Dict[str, str] + A dictionary mapping the names of the intrinsics to their corresponding implementations. + """ + assert out_dtype in ["float16", "int8"] + + dtype_mapping = {"float16": "f16", "int8": "i8", "int32": "i32"} + target_dtype = dtype_mapping[out_dtype] + target_bits = tvm.DataType(out_dtype).bits + loop_extent = 128 // target_bits + if source_format not in ["int", "uint"]: + raise ValueError("Invalid source_format. Expected 'int' or 'uint'.") + source_symbol = "i" if source_format == "int" else "u" + + _intrin = f"lop3_fast_decode_{source_symbol}{source_bit}_to_{storage_dtype}_to_{target_dtype}_l{loop_extent}_" + if with_scaling: + _intrin += "scale_" + + import_c_map = { + "i4_to_f16": decode_i4_to_f16, + "i2_to_f16": decode_i2_to_f16, + "i4_to_f16_scale": decode_i4_to_f16_scale, + "i2_to_f16_scale": decode_i2_to_f16_scale, + "i1_to_i8": decode_i1s_to_i8s_l16, + "i2_to_i8": decode_i2s_to_i8s, + "i4_to_i8": decode_i4s_to_i8s, + } + key = f"i{source_bit}_to_{target_dtype}" + if with_scaling: + key += "_scale" + + return { + "c_source": import_c_map[key], + "compute": _intrin, + } diff --git a/python/bitblas/gpu/matmul.py b/python/bitblas/gpu/matmul.py new file mode 100644 index 000000000000..4147e44d020c --- /dev/null +++ b/python/bitblas/gpu/matmul.py @@ -0,0 +1,372 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# pylint: disable=missing-docstring, invalid-name +"""A GEMM schedule rule for GPU operators.""" +from dataclasses import dataclass +from typing import Optional + +from tvm import tir +from tvm.target import Target +from tvm.tir.stmt import ForKind + +from ..base import analysis +from .base import GPUScheduleRule +from . import utils +from .matmul_analysis import ( + auto_inline_consumer_chain, + auto_inline_producers, + get_in_out_dtypes, + get_index_map, + normalize_to_matmul, + get_reduction_blocks, +) +from .matmul_mma import MatmulTensorizationMMA +from .matmul_wmma import MatmulInt8Tensorization, MatmulTensorizationWMMA, MatmulTensorizationLegacy +from functools import reduce + +class Matmul(GPUScheduleRule): + """The schedule rule for matmul-like computation""" + + @dataclass + class Config: + block_size_x: int = 8 + block_size_y: int = 8 + vthread_x: int = 1 + vthread_y: int = 1 + micro_size_x: int = 4 + micro_size_y: int = 4 + micro_size_k: int = 8 + vector_size: int = 1 + unroll: int = 256 # 0 means no unroll + use_shared: bool = True + storage_align: bool = False + inner_x: bool = False + + def get_configs(self, target: Target) -> Config: + """Get the schedule config for the target""" + if target.kind.name == "cuda" or target.kind.name == "rocm": + return Matmul.Config( + block_size_x=8, + block_size_y=16, + vthread_x=1, + vthread_y=1, + micro_size_x=4, + micro_size_y=4, + micro_size_k=16, + vector_size=2, + unroll=256, + use_shared=True, + storage_align=True, + inner_x=False, + ) + elif target.kind.name == "opencl" and "android" in str(target.host): + return Matmul.Config( + block_size_x=8, + block_size_y=8, + vthread_x=1, + vthread_y=1, + micro_size_x=8, + micro_size_y=2, + micro_size_k=16, + vector_size=8, + unroll=64, + use_shared=False, + storage_align=False, + inner_x=True, + ) + else: + return Matmul.Config() + + def apply( # pylint: disable=too-many-locals,missing-docstring + self, + func: tir.PrimFunc, + target: Target, + _: bool, + ) -> Optional[tir.Schedule]: + if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): + return None + sch = tir.Schedule(func) + root_block = analysis.get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + + reduction_blocks = get_reduction_blocks(sch, blocks) + if reduction_blocks is None: + return None + + main_block = reduction_blocks[0] + block_stmt = sch.get(main_block) + sch = normalize_to_matmul(sch, main_block) + if sch is None: + return None + + # Step 1. Check Tensor Core support + # Tensorization config: + # If any value of I, J, K is fixed and less than this threshold, + # tensorization rule will not be applied. + minimal_tensorize_threshold = 64 + block_stmt = sch.get(main_block) + if target.kind.name == "cuda" and utils.get_sm_version(target) >= 70: + apply_tensorization: bool = True + # the batch dimension is not taken into consideration. + # Analyze read/write buffers and choose correct tensorizer: int8 or fp16. + in_dtype, out_dtype = get_in_out_dtypes(block_stmt) + if in_dtype not in ["int8", "float16"]: + apply_tensorization = False + for item_var in block_stmt.iter_vars[1:]: + extent = item_var.dom.extent + if isinstance(extent, tir.expr.IntImm): + if extent.value <= minimal_tensorize_threshold: + apply_tensorization = False + if apply_tensorization: + if in_dtype == "int8" and out_dtype == "int32": + tensorize_sch = MatmulInt8Tensorization().apply(func, target, _) + elif utils.get_sm_version(target) >= 80: + # For A100(sm_80) or more advanced gpu, use MMA tensorization. + tensorize_sch = MatmulTensorizationMMA().apply(func, target, _) + else: + # For other GPUs, use WMMA tensorization. + tensorize_sch = MatmulTensorizationWMMA().apply(func, target, _) + if tensorize_sch is not None: + return tensorize_sch + + # Step 2. Get schedule config. + config = self.get_configs(target) + + # Step 3. Schedule matmul + y_kernel_size = config.vthread_y * config.block_size_y * config.micro_size_y + x_kernel_size = config.vthread_x * config.block_size_x * config.micro_size_x + if config.inner_x: + sch.pad_einsum( + main_block, + [1, y_kernel_size, x_kernel_size, config.micro_size_k], + ) + batch, y, x, k = sch.get_loops(main_block) + else: + sch.pad_einsum( + main_block, + [1, x_kernel_size, y_kernel_size, config.micro_size_k], + ) + batch, x, y, k = sch.get_loops(main_block) + by, vy, ty, yi = sch.split( + y, [None, config.vthread_y, config.block_size_y, config.micro_size_y] + ) + bx, vx, tx, xi = sch.split( + x, [None, config.vthread_x, config.block_size_x, config.micro_size_x] + ) + ko, ki = sch.split(k, factors=[None, config.micro_size_k]) + sch.reorder(by, bx, vy, vx, ty, tx, ko, ki, yi, xi) + by = sch.fuse(batch, by) + sch.bind(bx, "blockIdx.x") + sch.bind(by, "blockIdx.y") + sch.bind(vy, "vthread.y") + sch.bind(vx, "vthread.x") + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + inner_loop = config.micro_size_x if config.inner_x else config.micro_size_y + if inner_loop % config.vector_size == 0: + _, v = sch.split(xi, [None, config.vector_size]) + sch.vectorize(v) + + if config.unroll > 0: + sch.annotate(tx, ann_key="pragma_auto_unroll_max_step", ann_val=config.unroll) + sch.annotate(tx, ann_key="pragma_unroll_explicit", ann_val=1) + + l2g = sch.cache_write(main_block, 0, "local") + sch.reverse_compute_at(l2g, tx, preserve_unit_loops=True) + if config.micro_size_x % config.vector_size == 0: + _, v = sch.split(sch.get_loops(l2g)[-1], [None, config.vector_size]) + sch.vectorize(v) + + if config.use_shared: + + def _cooperative_fetch(index, vec_len): + block = sch.cache_read(main_block, index, "shared") + num_loops = len(sch.get_loops(block)) + sch.compute_at(block, ko, preserve_unit_loops=True) + loops = sch.get_loops(block)[-num_loops:] + ty, tx, _, vec = sch.split( + sch.fuse(*loops), + factors=[config.block_size_y, config.block_size_x, None, vec_len], + ) + sch.vectorize(vec) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + if config.storage_align: + sch.storage_align(block, 0, axis=1, factor=8, offset=vec_len) + return block + + a_g2s = _cooperative_fetch(0, vec_len=config.vector_size) + b_g2s = _cooperative_fetch(1, vec_len=config.vector_size) + + auto_inline_producers(sch, a_g2s) + auto_inline_producers(sch, b_g2s) + else: + auto_inline_producers(sch, main_block) + + auto_inline_consumer_chain(sch, l2g) + sch.decompose_reduction(main_block, ko) + + # Step 4. Check if there are unbound blocks. Execute fallback scheduling to them. + def is_scheduled(block: tir.schedule.BlockRV) -> bool: + loops = sch.get_loops(block) + loop_kinds = {sch.get(loop).kind for loop in loops} + return loop_kinds != {ForKind.SERIAL} + + blocks = sch.get_child_blocks(root_block) + max_threads_per_block = utils.max_threads_per_block(target) + for block in blocks: + if is_scheduled(block): + continue + # no axis of the block is bound to thread or block + s_loops = sch.get_loops(block) + bx, tx = sch.split( + sch.fuse(*s_loops), + factors=[ + None, + 256, + ], + ) + sch.bind(bx, "blockIdx.x") + sch.bind(tx, "threadIdx.x") + + return sch + + def apply_config( # pylint: disable=too-many-locals,missing-docstring + self, + func: tir.PrimFunc, + config, + ) -> tir.Schedule: + sch = tir.Schedule(func) + root_block = analysis.get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + + reduction_blocks = get_reduction_blocks(sch, blocks) + if reduction_blocks is None: + return None + + # in some case conv template will use this rule, but the tile config is not + # analyzed by matmul expr. + if len(config.block) != 2: + print(f"Warning: block config {config.block} is not valid for matmul, skip.") + return None + + main_block = reduction_blocks[0] + + block_stmt = sch.get(main_block) + + # cuda core prefer b is [k, j] layout without swizzling. + index_maps = get_index_map(block_stmt, ["n", "n", "n"]) + if index_maps is None: + return None + matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps + + # Step 0. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K] + block = sch.reindex(main_block, ("read", 0)) + sch.transform_layout(block, ("write", 0), a_index_map) + block = sch.reindex(main_block, ("read", 1)) + sch.transform_layout(block, ("write", 0), b_index_map) + block = sch.reindex(main_block, ("write", 0)) + sch.transform_layout(block, ("read", 0), c_index_map) + sch.transform_block_layout(main_block, matmul_index_map) + + # Step 2. Get schedule config. + block_row_warps = config.block[0] // (config.thread[0] * config.step[0]) + block_col_warps = config.block[1] // (config.thread[1] * config.step[1]) + thread_row_tiles = config.thread[1] // (config.step[0] * 2) + thread_col_tiles = config.thread[1] // (config.step[1] * 2) + vthread_row_tiles = config.step[0] * 2 # expand vtrhead to avoid load band conflict + vthread_col_tiles = config.step[1] * 2 # expand vtrhead to avoid load band conflict + chunk = config.rstep[0] + + # Step 3. Schedule matmul + BM = block_row_warps * vthread_row_tiles * thread_row_tiles + BN = block_col_warps * vthread_col_tiles * thread_col_tiles + BK = chunk + + sch.pad_einsum( + main_block, + [1, BM, BN, BK], + ) + batch, y, x, k = sch.get_loops(main_block) + by, vy, ty, yi = sch.split(y, [None, vthread_row_tiles, block_row_warps, thread_row_tiles]) + bx, vx, tx, xi = sch.split(x, [None, vthread_col_tiles, block_col_warps, thread_col_tiles]) + ko, ki = sch.split(k, factors=[None, BK]) + sch.reorder(by, bx, vy, vx, ty, tx, ko, ki, yi, xi) + by = sch.fuse(batch, by) + sch.bind(bx, "blockIdx.x") + sch.bind(by, "blockIdx.y") + sch.bind(vy, "vthread.y") + sch.bind(vx, "vthread.x") + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + def prod(iterable): + return reduce(lambda x, y: x * y, iterable, 1) + + l2g = sch.cache_write(main_block, 0, "local") + sch.reverse_compute_at(l2g, tx, preserve_unit_loops=True) + + def _cooperative_fetch(index, vec_len): + block = sch.cache_read(main_block, index, "shared") + num_loops = len(sch.get_loops(block)) + block_local = sch.cache_read(main_block, index, "local") + sch.compute_at(block_local, ki, preserve_unit_loops=True) + sch.compute_at(block, ko, preserve_unit_loops=True) + loops = sch.get_loops(block)[-num_loops:] + _, ty, tx, vec = sch.split( + sch.fuse(*loops), + factors=[None, block_row_warps, block_col_warps, vec_len], + ) + + auto_inline_producers(sch, block) + + def is_trivial_load(block): + # avoid vectorize under global[v2, v1]] shared[v1, v2] case + reads = sch.get(block).reads + writes = sch.get(block).writes + if len(reads) != 1 or len(writes) != 1: + return False + return all( + read.region[-1] == write.region[-1] for read, write in zip(reads, writes) + ) + + if is_trivial_load(block): + sch.vectorize(vec) + + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + _, vec = sch.split( + sch.fuse(*sch.get_loops(block_local)[-2:]), + [None, vec_len // prod(config.step)], + ) + sch.vectorize(vec) + + return block + + for i, input_region in enumerate(sch.get(main_block).reads): + _buffer_name = input_region.buffer.name.replace("_reindex", "").replace("_pad", "") + if _buffer_name not in config.cached_tensors: + print( + f"Warning: {_buffer_name} is not in cached_tensors {config.cached_tensors}, skip." + ) + continue + + # otherwise cooperative fetch in shared memory. + if _buffer_name in config.vectorize: + vectorize = config.vectorize[_buffer_name] + else: + vectorize = 1 + + _cooperative_fetch(i, vec_len=vectorize) + + auto_inline_consumer_chain(sch, l2g) + + _, vec = sch.split( + sch.fuse(*sch.get_loops(l2g)[-2:]), [None, vectorize // prod(config.step)] + ) + sch.vectorize(vec) + + sch.decompose_reduction(main_block, ko) + return sch diff --git a/python/bitblas/gpu/matmul_analysis.py b/python/bitblas/gpu/matmul_analysis.py new file mode 100644 index 000000000000..4003d825c03c --- /dev/null +++ b/python/bitblas/gpu/matmul_analysis.py @@ -0,0 +1,763 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# pylint: disable=missing-docstring, invalid-name +"""A GEMM schedule rule for GPU operators.""" +from dataclasses import dataclass +from enum import Enum +from typing import List, Optional, Set, Union, Tuple, Dict +from tvm import tir, DataType +from tvm.ir import Range +from tvm.tir import IterVar, PrimExpr, Var +from tvm.tir.analysis import undefined_vars +from tvm.tir.schedule.schedule import BlockRV +from ..base.analysis import ( + collect_block_iter_vars_used_in_access_region, + get_root_block, + get_reduction_blocks, +) +from tvm.target.target import Target +from tvm.tir import IndexMap + + +def _is_one(x: PrimExpr) -> bool: + return isinstance(x, tir.IntImm) and x.value == 1 + + +def _collect_producers(sch: tir.Schedule, block: tir.schedule.BlockRV): + result = [] + for producer in sch.get_producers(block): + result.append(producer) + result.extend(_collect_producers(sch, producer)) + return result + + +def _collect_consumers(sch: tir.Schedule, block: tir.schedule.BlockRV): + result = [] + for consumer in sch.get_consumers(block): + result.append(consumer) + result.extend(_collect_consumers(sch, consumer)) + return result + + +def auto_inline_producers( + sch: tir.Schedule, + block: tir.schedule.BlockRV, + skip_blocks: Optional[List[tir.schedule.BlockRV]] = None, +): + skip_blocks = skip_blocks or [] + while True: + inlined_cnt = 0 + producers = _collect_producers(sch, block) + for producer in producers: + if any( + sch.get(producer) == sch.get(skip_block) for skip_block in skip_blocks + ): + continue + try: + sch.compute_inline(producer) + inlined_cnt += 1 + except: # pylint: disable=bare-except + continue + if inlined_cnt == 0: + return + + +def auto_inline_consumers( + sch: tir.Schedule, + block: tir.schedule.BlockRV, +): + while True: + inlined_cnt = 0 + consumers = _collect_consumers(sch, block) + for consumer in consumers: + try: + sch.compute_inline(consumer) + inlined_cnt += 1 + except: # pylint: disable=bare-except + continue + for consumer in consumers: + try: + sch.reverse_compute_inline(consumer) + inlined_cnt += 1 + except: # pylint: disable=bare-except + continue + if inlined_cnt == 0: + return + + +def auto_inline_consumer_chain( + sch: tir.Schedule, + block: tir.schedule.BlockRV, +): + auto_inline_consumers(sch, block) + remaining_consumers = sch.get_consumers(block) + + if len(remaining_consumers) != 0: + # Some blocks have failed to be inlined to the producer cache-write stage. + # This could be due to another producer block that has not been scheduled. + for c in remaining_consumers: + for p in sch.get_producers(c): + if sch.get(p) != sch.get(block): + sch.compute_inline(p) + + # Try inlining into the cache-write stage again, this time it should succeed. + auto_inline_consumers(sch, block) + + +# find the block that required to be reindex and scope. +def find_last_producer_from_buffer( + sch, main_block, buffer: tir.Buffer +) -> Optional[BlockRV]: + # block that most near to the arguments + block = main_block + buffer = buffer + while True: + last_buffer = buffer + producers = sch.get_producers(block) + + if len(producers) == 0: + # do not have any producer means it is the first block + break + + for producer in producers: + for write in sch.get(producer).writes: + if write.buffer == buffer: + block = producer + buffer = sch.get(producer).reads[0].buffer + if buffer == last_buffer: + break + return block + + +def find_arg_idx_from_buffer_chain( + sch: tir.Schedule, main_block: tir.schedule.BlockRV, buffer: tir.Buffer +) -> int: + """traverse to find the arg index from the buffer""" + producers = sch.get_producers(main_block) + + # a head buffer has no producer blocks + def find_args_index(sch: tir.Schedule, buffer: tir.Buffer): + for i, param in enumerate(sch.mod["main"].params): + if sch.mod["main"].buffer_map[param] == buffer: + return i + return None + + is_head_buffer = len(producers) == 0 + if is_head_buffer: + return find_args_index(sch, buffer) + for block in sch.get_producers(main_block): + if len(sch.get(block).reads) != 1 or len(sch.get(block).writes) != 1: + continue + for write in sch.get(block).writes: + if write.buffer == buffer: + return find_arg_idx_from_buffer_chain(sch, block, buffer) + + # if no buffer producer block found, it means the buffer is an input buffer + return find_args_index(sch, buffer) + + +class IterKind(Enum): + """Iter kinds for GEMM-liked programs. + We can simplify the computation to C[S, I, J] += A[S, I, K] * B[S, J, K], + where `I, J, K` are fundamental axes for gemm and `S` represents all + other spatial axes (e.g. batches) + kIter_S: spatial axes + kIter_I: I axes + kIter_J: J axes + kIter_K: K axes + kIter_T: trivial axes (i.e. with extent 1) + """ + + kIter_S = 0 + kIter_I = 1 + kIter_J = 2 + kIter_K = 3 + kIter_T = 4 + + +@dataclass +class IterTrait: + kind: IterKind + extent: PrimExpr + + +def make_iter_fusion_index_map( + traits: List[IterTrait], + kind_order: List[IterKind], +) -> tir.IndexMap: + fused_iters: Dict[IterKind, PrimExpr] = {} + input_iters: List[tir.Var] = [] + for i, trait in enumerate(traits): + v_i = tir.Var(f"i{i}", trait.extent.dtype) + input_iters.append(v_i) + if trait.kind == IterKind.kIter_T: + continue + if trait.kind not in kind_order: + raise ValueError(f"Unknown iter kind {trait.kind}") + if trait.kind in fused_iters: + fused_iters[trait.kind] = fused_iters[trait.kind] * trait.extent + v_i + else: + fused_iters[trait.kind] = v_i + + final_indices: List[tir.PrimExpr] = [ + fused_iters.get(kind, tir.IntImm(traits[0].extent.dtype, 0)) + for kind in kind_order + ] + + return tir.IndexMap(input_iters, final_indices, None) + + +def detect_iter_traits(block: tir.Block) -> Optional[Tuple[List[IterTrait]]]: + """Detect iter traits based on the pattern C[S, I, J] += A[S, I, K] * B[S, J, K] + + Parameters + ---------- + block : tir.Block + The block to be analyzed + + Returns + ------- + traits : Optional[Tuple[List[IterTrait]]] + The detected iter traits for axes in A, B and C. None if the block + does not match the pattern. + + """ + + if len(block.reads) != 2 or len(block.writes) != 1: + return None + + def get_access_axes(region: List[Range]) -> Set[Var]: + axes: Set[Var] = set() + for r in region: + if not _is_one(r.extent): + raise ValueError("Expect elemwise block access") + axes = axes.union(set(undefined_vars(r.min))) + return axes + + try: + A_axes = get_access_axes(block.reads[0].region) + B_axes = get_access_axes(block.reads[1].region) + C_axes = get_access_axes(block.writes[0].region) + except ValueError: + return None + + traits: Dict[Var, IterTrait] = {} + for iter_var in block.iter_vars: + var = iter_var.var + kind: IterKind + if _is_one(iter_var.dom.extent): + if iter_var.iter_type == tir.IterVar.CommReduce: + # for simplified case (e.g. 1x1 conv kernel) + kind = IterKind.kIter_K + else: + kind = IterKind.kIter_T + elif iter_var.iter_type == iter_var.DataPar: + if var in A_axes and var in B_axes and var in C_axes: + kind = IterKind.kIter_S + elif var in A_axes and var in C_axes: + kind = IterKind.kIter_I + elif var in B_axes and var in C_axes: + kind = IterKind.kIter_J + else: + return None + elif iter_var.iter_type == tir.IterVar.CommReduce: + if var in A_axes and var in B_axes and var not in C_axes: + kind = IterKind.kIter_K + else: + return None + else: + return None + traits[var] = IterTrait(kind, iter_var.dom.extent) + + # A Gemm-kernel requires have I, J and K axes + gemm_traits = {IterKind.kIter_I, IterKind.kIter_J, IterKind.kIter_K} + if {x.kind for x in traits.values()}.intersection(gemm_traits) != gemm_traits: + return None + + A_traits = [ + traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in A_axes + ] + B_traits = [ + traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in B_axes + ] + C_traits = [ + traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in C_axes + ] + block_traits = [traits[i.var] for i in block.iter_vars] + return A_traits, B_traits, C_traits, block_traits + + +def get_index_map( + block: tir.Block, layout: List[str] = ["n", "t", "n"] +) -> Optional[Tuple[tir.IndexMap, ...]]: + """Get index maps for the block + + Parameters + ---------- + block : tir.Block + The block to be analyzed + + layout : List[str] + the target layout index map to be used. + 'n' for [i, k] layout + 't' for [k, j] layout + 'a' for auto inference based on whether the last axis is reduction. + + Returns + ------- + index_maps : Optional[Tuple[tir.IndexMap]] + The index maps for the block, or None if the block is not a gemm-liked kernel + """ + traits = detect_iter_traits(block) + if traits is None: + return None + A_traits, B_traits, C_traits, block_traits = traits + + def get_ordered_axes(region: List[Range]) -> Set[Var]: + axes: List[Var] = [] + for r in region: + if not _is_one(r.extent): + raise ValueError("Expect elemwise block access") + axes.append(r.min) + return axes + + def is_common_reduce(var: Var) -> bool: + for iter_var in block.iter_vars: + if iter_var.var == var and iter_var.iter_type == IterVar.CommReduce: + return True + return False + + def check_last_trait(region: List[Range]): + axes = get_ordered_axes(region) + return is_common_reduce(axes[-1]) + + def infer_layout(layout: str, region: List[Range], kind: str = "A"): + """ + Infer the layout based on the region and the kind of buffer + kind: "A", "B", "C" + """ + primary_iter, secondary_iter, reduction_iter = { + "A": (IterKind.kIter_I, IterKind.kIter_K, IterKind.kIter_K), + "B": (IterKind.kIter_K, IterKind.kIter_J, IterKind.kIter_K), + "C": (IterKind.kIter_I, IterKind.kIter_J, None), + }[kind] + + spatial_iter = { + "A": IterKind.kIter_I, + "B": IterKind.kIter_J, + "C": None, + }[kind] + + if layout == "n": + return [IterKind.kIter_S, primary_iter, secondary_iter] + elif layout == "t": + return [IterKind.kIter_S, secondary_iter, primary_iter] + elif layout == "a": + # auto inference layout + # for buffer with reduction axis, we put it as the last axis + # otherwise, we put it as the first axis + if kind == "C": + return [IterKind.kIter_S, primary_iter, secondary_iter] + else: + return ( + [IterKind.kIter_S, spatial_iter, reduction_iter] + if check_last_trait(region) + else [IterKind.kIter_S, reduction_iter, spatial_iter] + ) + else: + raise ValueError(f"Unknown layout {layout}") + + A_index_map = make_iter_fusion_index_map( + A_traits, infer_layout(layout[0], block.reads[0].region, kind="A") + ) + B_index_map = make_iter_fusion_index_map( + B_traits, infer_layout(layout[1], block.reads[1].region, kind="B") + ) + C_index_map = make_iter_fusion_index_map( + C_traits, infer_layout(layout[2], block.writes[0].region, kind="C") + ) + + matmul_index_map = make_iter_fusion_index_map( + block_traits, + [IterKind.kIter_S, IterKind.kIter_I, IterKind.kIter_J, IterKind.kIter_K], + ) + + return ( + matmul_index_map, + A_index_map, + B_index_map, + C_index_map, + ) + + +def get_in_out_dtypes(block: tir.Block) -> Tuple[str]: + """ + Detect In/Out data types for the given block based on the analysis if read/write buffers. + """ + assert len(block.reads) > 0 and len(block.writes) > 0 + in_dtype = block.reads[0].buffer.dtype + out_dtype = block.writes[0].buffer.dtype + return (in_dtype, out_dtype) + + +def get_dequantize_block(sch, blocks) -> Optional[BlockRV]: + # check at least two input and one output + # at lease one input has uint dtype, and the output dtype is float + def is_dequantize(block: BlockRV) -> bool: + block_stmt = sch.get(block) + if len(block_stmt.reads) < 2: + return False + has_uint_input = any( + "uint" in str(region.buffer.dtype) for region in block_stmt.reads + ) + if not has_uint_input: + return False + if len(block_stmt.writes) != 1 or "float" not in str( + block_stmt.writes[0].buffer.dtype + ): + return False + return True + + dequantize_blocks = [block for block in blocks if is_dequantize(block)] + return dequantize_blocks[0] if len(dequantize_blocks) == 1 else None + + +def is_identity_or_transpose_block(block_stmt: tir.Block) -> bool: + iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars} + if iter_types != {IterVar.DataPar}: + return False, False + if not isinstance(block_stmt.body, tir.BufferStore): + return False, False + if not isinstance(block_stmt.body.value, tir.BufferLoad): + return False, False + + def get_access_vars(region: List[Range]) -> List[Var]: + axes: List[Var] = [] + for r in region: + if not _is_one(r.extent): + return None + axes.extend(undefined_vars(r.min)) + # remove trivial axis + trivial_vars = set( + iter_var.var + for iter_var in block_stmt.iter_vars + if _is_one(iter_var.dom.extent) + ) + axes = [axis for axis in axes if axis not in trivial_vars] + # remove duplicate axis + axes = [var for i, var in enumerate(axes) if i == 0 or var != axes[i - 1]] + return axes + + lhs_access_vars = get_access_vars(block_stmt.reads[0].region)[-2:] + rhs_access_vars = get_access_vars(block_stmt.writes[0].region)[-2:] + is_identity = list(lhs_access_vars) == list(rhs_access_vars) + is_transpose = list(lhs_access_vars) != list(rhs_access_vars) and set( + lhs_access_vars + ) == set(rhs_access_vars) + return is_identity, is_transpose + + +def is_identity_block(block_stmt: tir.Block) -> bool: + return is_identity_or_transpose_block(block_stmt)[0] + + +def is_transpose_block(block_stmt: tir.Block) -> bool: + return is_identity_or_transpose_block(block_stmt)[1] + + +def inline_transpose_block(sch: tir.Schedule, blocks: List[tir.schedule.BlockRV]): + result_blocks = [] + for block in blocks: + if not is_transpose_block(sch.get(block)): + result_blocks.append(block) + continue + try: + sch.compute_inline(block) + except: + try: + sch.reverse_compute_inline(block) + except: + result_blocks.append(block) + return result_blocks + + +def normalize_to_matmul( + sch: tir.Schedule, main_block: BlockRV, layout: List[str] = ["n", "t", "n"] +) -> Optional[tir.Schedule]: + block_stmt = sch.get(main_block) + + # let layout be 'a' to auto inference the layout + index_maps = get_index_map(block_stmt, layout=layout) + if index_maps is None: + print("[WARNING] Cannot find the appropriate index map for tensorcore") + return None + + matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps + + # `skip_simplify` to avoid the bug in the 1x1 conv + block = sch.reindex(main_block, ("read", 0), skip_simplify=True) + sch.transform_layout(block, ("write", 0), a_index_map) + block = sch.reindex(main_block, ("read", 1), skip_simplify=True) + sch.transform_layout(block, ("write", 0), b_index_map) + block = sch.reindex(main_block, ("write", 0), skip_simplify=True) + sch.transform_layout(block, ("read", 0), c_index_map) + sch.transform_block_layout(main_block, matmul_index_map) + sch.mod["main"] = sch.mod["main"].with_attr("dlight.tensorcore_prenormlized", True) + return sch + + +def get_tensorized_func_and_tags( + func: tir.PrimFunc, + target: Target, + layout: List[str] = ["a", "a", "a"], + skip_normalize: bool = False, + allow_gemv: bool = False, +) -> Tuple[tir.PrimFunc, Dict[str, Union[List[int], int]]]: + from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel + get_wmma_intrin_group, + ) + + """ + transform function to matmul if necessary (e.g. transform conv2d with im2col) + """ + # step1. detect whether the function can utilize tensorcore + sch = tir.Schedule(func) + root_block = get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + reduction_blocks = get_reduction_blocks(sch, blocks) + if not reduction_blocks or len(reduction_blocks) != 1: + return func, None + + def _can_be_tensorized(sch: tir.Schedule, block: BlockRV) -> bool: + block_stmt = sch.get(block) + conditions = [] + conditions.append(len(block_stmt.reads) == 2) + conditions.append(len(block_stmt.writes) == 1) + conditions.append( + len( + collect_block_iter_vars_used_in_access_region( + block_stmt, block_stmt.writes[0].region + ) + ) + > 0 + ) + if not all(conditions): + return False + return True + + # step2. transform function to tensorcore matmul (e.g. conv2d with im2col) + def check_sm_version(arch: str) -> int: + sm_version = arch.replace("sm_", "") + return int(sm_version) if sm_version.isdigit() else -1 + + def analysis_tensorcore_tags( + sch: tir.Schedule, block: BlockRV, target: Target + ) -> bool: + tags: Dict[str, Union[List[int], int]] = {} + block_stmt = sch.get(block) + + # analysis tensorcore axis + # todo(lei): maybe we can remove this in the future + (write_buffer_region,) = block_stmt.writes + out_axis = len(write_buffer_region.buffer.shape) + tags["tensorcore_config"] = [out_axis - 2, out_axis - 1] + + # analysis pipeline stage + # todo(lei): maybe we can integrate this into policy in the future + tags["pipeline_stage"] = 1 + if target.kind.name == "cuda" and check_sm_version(target.arch) >= 80: + # enable pipleline stage only for sm_80 devices + tags["pipeline_stage"] = 2 + + # analysis async copy + # todo(lei): maybe we can integrate this into policy in the future + tags["use_async_copy"] = 0 + if tags["pipeline_stage"] == 2 and check_sm_version(target.arch) >= 80: + # async copy only works in software pipeline. + tags["use_async_copy"] = 1 + + # analysis intrin infomation + def get_ordered_axes(region: List[Range]) -> Set[Var]: + axes: List[Var] = [] + for r in region: + if not _is_one(r.extent): + raise ValueError("Expect elemwise block access") + axes.append(r.min) + return axes + + def is_common_reduce(var: Var) -> bool: + for iter_var in block_stmt.iter_vars: + if iter_var.var == var and iter_var.iter_type == IterVar.CommReduce: + return True + return False + + def check_last_trait(region: List[Range]): + axes = get_ordered_axes(region) + return is_common_reduce(axes[-1]) + + intrin_info: dict = {} + in_dtype, out_dtype = get_in_out_dtypes(block_stmt) + intrin_info["in_dtype"] = in_dtype + intrin_info["out_dtype"] = out_dtype + # if the last dimension is reduce axis, the B is transposed + intrin_info["trans_b"] = check_last_trait(block_stmt.reads[1].region) + if func.attrs is not None and "smooth_a" in func.attrs: + intrin_info["smooth_a"] = func.attrs["smooth_a"] + if func.attrs is not None and "smooth_b" in func.attrs: + intrin_info["smooth_b"] = func.attrs["smooth_b"] + tags["intrin_info"] = intrin_info + + return tags + + (main_block,) = reduction_blocks + if _can_be_tensorized(sch, main_block) is None: + return func, None + + block_stmt = sch.get(main_block) + if target.kind.name == "cuda" and check_sm_version(target.arch) >= 70: + in_dtype, out_dtype = get_in_out_dtypes(block_stmt) + try: + _ = get_wmma_intrin_group( + in_dtype=in_dtype, + out_dtype=out_dtype, + ) + except: + print("[BitBLAS][WARNING] Cannot find the corresponding wmma intrin group") + return func, None + + # reindex and transform functions + # Normalize tensor functions to C[S, I, J] += A[S, I, K] * B[S, J, K] + # or C[S, I, J] += A[S, I, K] * B[S, K, J] + # skip normalize when we want to detect tags only. + if not skip_normalize: + sch = normalize_to_matmul(sch, main_block, layout) + if sch is None: + return func, None + + block_stmt = sch.get(main_block) + + minimal_tensorize_threshold = 16 + # the batch dimension is not taken into consideration. + extent = block_stmt.iter_vars[1].dom.extent + if isinstance(extent, tir.expr.IntImm): + if extent.value < (1 if allow_gemv else minimal_tensorize_threshold): + return func, None + for item_var in block_stmt.iter_vars[2:]: + extent = item_var.dom.extent + if isinstance(extent, tir.expr.IntImm): + if extent.value < minimal_tensorize_threshold: + return func, None + tags = analysis_tensorcore_tags(sch, main_block, target) + return sch.mod["main"], tags + + return func, None + + +def get_propagate_map( + trans: bool = True, dtype="float16", matrix_name="A", index_dtype="int32" +): + from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel + ldmatrix_32x8_to_shared_16x16_layout, + ldmatrix_trans_32x8_to_shared_16x16_layout, + ldmatrix_32x16_to_shared_16x32_layout_a, + ldmatrix_32x16_to_shared_16x32_layout_b, + ) + + assert dtype in ["float16", "int8"], "Only support float16 for now" + if dtype == "float16": + ldmatrix_layout = ldmatrix_32x8_to_shared_16x16_layout + ldmatrix_layout_trans = ldmatrix_trans_32x8_to_shared_16x16_layout + elif dtype == "int8": + # int8 mma only support 32x16 to 16x32 layout + if matrix_name == "A" and trans == False: + ldmatrix_layout = ldmatrix_32x16_to_shared_16x32_layout_a + elif matrix_name == "B" and trans == True: + ldmatrix_layout = ldmatrix_32x16_to_shared_16x32_layout_b + else: + raise ValueError("Unknown matrix name ", matrix_name) + + # IntraWarp memory layout was occurred by ldmatrix, we should lift the ld_matrix out + def ldmatrix_permutation_16x16_32x8_16x16(kernel_i, kernel_j): + thread_id = kernel_i * 2 + kernel_j // 8 + local_id = kernel_j % 8 + return ldmatrix_layout(thread_id, local_id) + + def ldmatrix_trans_permutation_16x16_32x8_16x16(kernel_i, kernel_j): + thread_id = kernel_i * 2 + kernel_j // 8 + local_id = kernel_j % 8 + return ldmatrix_layout_trans(thread_id, local_id) + + def ldmatrix_permutation_16x32_32x16_32x16(kernel_i, kernel_j): + thread_id = kernel_i * 2 + kernel_j // 16 + local_id = kernel_j % 16 + return ldmatrix_layout(thread_id, local_id) + + if dtype == "float16": + ldmatrix_index_map = ( + ldmatrix_trans_permutation_16x16_32x8_16x16 + if trans + else ldmatrix_permutation_16x16_32x8_16x16 + ) + else: + ldmatrix_index_map = ldmatrix_permutation_16x32_32x16_32x16 + + ldmatrix_index_map = IndexMap.from_func(ldmatrix_index_map, index_dtype=index_dtype) + # TODO(lei): index_dtype should be analyzed from the schedule + row, col = [16, 16] if dtype == "float16" else [16, 32] + inversed_index_map = ldmatrix_index_map.inverse([row, col]) + return ldmatrix_index_map, inversed_index_map + + +def layout_propagate_chain( + sch: tir.Schedule, + start_block: BlockRV, + start_buffer: tir.Buffer, + end_block: BlockRV, + index_map: IndexMap, +): + # some layout transformation may only apply to the last n dimensions + # propagate the layout transformation to the chain of blocks + block = start_block + buffer = start_buffer + index_map = index_map + while True: + last_buffer = buffer + producers = sch.get_producers(block) + if len(producers) == 0: + break + for producer in producers: + if len(sch.get(producer).writes) != 1: + return index_map + if sch.get(producer) == sch.get(end_block): + return index_map + (write,) = sch.get(producer).writes + read = sch.get(producer).reads[0] + if write.buffer == buffer: + block = producer + buffer = sch.get(producer).reads[0].buffer + write_indices = [r.min for r in write.region] + read_indices = [r.min for r in read.region] + # reverse index map from [vi // x] -> [vi * x] to match the inconsistent layout + tmp_index_map = IndexMap(write_indices, read_indices, None) + tmp_index_map = tmp_index_map.non_surjective_inverse( + write.buffer.shape + )[0] + + # if dequantize like ops are used, the scaling factor should be considered + # to be applied to the final indices + scaling_factor = 1 + for i, j in zip(write.buffer.shape, read.buffer.shape): + scaling_factor *= i // j + final_indices = list( + index_map.map_indices(tmp_index_map.map_indices(write_indices)) + ) + final_indices[-1] = final_indices[-1] // scaling_factor + index_map = IndexMap( + write_indices, + final_indices, + None, + ) + if buffer == last_buffer: + break + return index_map diff --git a/python/bitblas/gpu/matmul_mma.py b/python/bitblas/gpu/matmul_mma.py new file mode 100644 index 000000000000..593a4293203e --- /dev/null +++ b/python/bitblas/gpu/matmul_mma.py @@ -0,0 +1,684 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# pylint: disable=missing-docstring, invalid-name +"""A GEMM schedule rule for GPU operators.""" +from typing import Literal, Optional, List + +from tvm import tir +from tvm.target import Target + +from ..base.roller.rasterization import NoRasterization +from ..base import analysis +from .base import GPUScheduleRule +from .matmul_mma_dequantize import MatmulTensorizationMMAWithDequantizeInfo +from ..base.analysis import get_coalesced_veclen +from .matmul_analysis import ( + auto_inline_consumer_chain, + is_transpose_block, + is_identity_block, + _collect_producers, + inline_transpose_block, + auto_inline_producers, + get_index_map, + get_reduction_blocks, + get_dequantize_block, + normalize_to_matmul, + get_propagate_map +) + + +def get_index_map_3d(index_map, l=16, r=16): + def index_map_3d(b, i, j): + return ( + b, + i // l, + j // r, + *index_map(i % l, j % r), + ) + + return index_map_3d + + +def get_index_map_5d(index_map): + """ + for layout transformed gemm, the index map should be 5d + """ + + def index_map_5d(b, i, j, ii, jj): + return ( + b, + i, + j, + *index_map(ii, jj), + ) + + return index_map_5d + + +def get_warp_index_map(index_map, l=16, r=16, is_5d=False): + if is_5d: + return get_index_map_5d(index_map) + return get_index_map_3d(index_map, l, r) + + +class MatmulTensorizationMMA(GPUScheduleRule): + """ + The schedule rule for float16 tensor core matmul computation. + func with attr 'dlight.do_not_tensorize' will not be tensorized. + """ + + def apply( # pylint: disable=too-many-locals,missing-docstring + self, + func: tir.PrimFunc, + target: Target, + _: bool, + ) -> Optional[tir.Schedule]: + sch = tir.Schedule(func) + root_block = analysis.get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + + if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): + return None + + # We first inline all transpose blocks for later analysis of transposed A and B + blocks = inline_transpose_block(sch, blocks) + + reduction_blocks = get_reduction_blocks(sch, blocks) + if reduction_blocks is None: + return None + + dequantize_block = get_dequantize_block(sch, blocks) + + main_block = reduction_blocks[0] + main_block_stmt = sch.get(main_block) + + # Supported data types: + # fp16, fp16, fp16: fp16 precision + # fp16, fp16, fp32: fp16 mixed precision + dtype_a = main_block_stmt.reads[0].buffer.dtype + dtype_b = main_block_stmt.reads[1].buffer.dtype + dtype_c = main_block_stmt.writes[0].buffer.dtype + if dtype_a != dtype_b: + return None + + # Get index maps + index_maps = get_index_map(main_block_stmt) + if index_maps is None: + return None + matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps + + # Start Schedule + # Step 0. Get schedule config. + # NOTE: we can analyze the config by the hardware spec in the future + + # Tensorization by hardware intrinsics + from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel + get_mma_intrin_group, + shared_16x16_to_mma_32x8_layout, + ) + + # tile size + block_m, block_n, block_k = 128, 128, 32 + + # tensor core intrinsic size + micro_size_m, micro_size_n, micro_size_k = 16, 16, 16 + + # thread size + # thread_x == warp_size + thread_z, thread_y, thread_x = 2, 2, 32 + + vector_size = 8 + unroll_depth = 4 + + # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K] + block = sch.reindex(main_block, ("read", 0)) + sch.transform_layout(block, ("write", 0), a_index_map) + is_transpose_a = is_transpose_block(sch.get(block)) + block = sch.reindex(main_block, ("read", 1)) + sch.transform_layout(block, ("write", 0), b_index_map) + is_transpose_b = is_identity_block(sch.get(block)) + block = sch.reindex(main_block, ("write", 0)) + sch.transform_layout(block, ("read", 0), c_index_map) + sch.transform_block_layout(main_block, matmul_index_map) + + batch, i, j, k = sch.get_loops(main_block) + + swizzle_factor_for_l2_m = [1, None] + swizzle_factor_for_l2_n = [1, None] + + # Step 2. Padding for dynamic shape kernels + sch.pad_einsum( + main_block, + [ + 1, + swizzle_factor_for_l2_m[0] * block_m, + swizzle_factor_for_l2_n[0] * block_n, + block_k, + ], + ) + + # Step 3. Reorder loops for tiling + + # Step 3.1 inner loops for tensor core computation + i, i_inner = sch.split(i, factors=[None, micro_size_m]) + j, j_inner = sch.split(j, factors=[None, micro_size_n]) + k, k_inner = sch.split(k, factors=[None, micro_size_k]) + + sch.reorder(i, j, k, i_inner, j_inner, k_inner) + + block_inner = main_block + block_outer = sch.blockize(i_inner) + + # Step 3.2 outer loops for tiling + # split factors for i, j, and k + micro_block_cnt_in_warp_m = block_m // thread_z // micro_size_m + micro_block_cnt_in_warp_n = block_n // thread_y // micro_size_n + micro_block_cnt_in_warp_k = block_k // micro_size_k + + i_factors = swizzle_factor_for_l2_m + [thread_z, micro_block_cnt_in_warp_m] + j_factors = swizzle_factor_for_l2_n + [thread_y, micro_block_cnt_in_warp_n] + k_factors = [None, micro_block_cnt_in_warp_k] + + i0, i1, i2, i3 = sch.split(i, factors=i_factors) + j0, j1, j2, j3 = sch.split(j, factors=j_factors) + k0, k1 = sch.split(k, factors=k_factors) + + sch.reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3) + + block_axis = sch.fuse(batch, i0, j0, i1, j1) + sch.bind(block_axis, "blockIdx.x") + + sch.bind(i2, "threadIdx.z") + sch.bind(j2, "threadIdx.y") + + # Step 4. Read/write to shared mem and register + def fetch_input(block_outer, read_buffer_idx, tensor_name: Literal["A", "B"], is_transpose): + # 1) Read to shared memory + block_read_smem = sch.cache_read(block_outer, read_buffer_idx, "shared.dyn") + sch.compute_at(block_read_smem, k0) + auto_inline_producers( + sch, block_read_smem, [dequantize_block] if dequantize_block else [] + ) + + # For transposed read, we directly load transposed tensor from global + # Then use ldmatrix.trans to handle transpose later + if (tensor_name == "A" and is_transpose) or (tensor_name == "B" and not is_transpose): + # specifical handle transpose read (for NN matmul or TT matmul) + v0, v1 = sch.get_loops(block_read_smem)[-2:] + sch.reorder(v1, v0) + sch.transform_layout(block_read_smem, ("write", 0), lambda b, i, j: (b, j, i)) + + # bind loops + fused = sch.fuse(*sch.get_loops(block_read_smem)[-2:]) + f0, f1, f2, f3, f4 = sch.split(fused, [None, thread_z, thread_y, thread_x, vector_size]) + sch.bind(f1, "threadIdx.z") + sch.bind(f2, "threadIdx.y") + sch.bind(f3, "threadIdx.x") + sch.vectorize(f4) + + # swizzling + sch.annotate(block_read_smem, ann_key="permuted_layout", ann_val=1) + + # 2) Read to register + block_read_reg = sch.cache_read(block_outer, read_buffer_idx, "warp") + sch.compute_at(block_read_reg, k1) + + # bind_loops + micro_size_spatial = micro_size_m if tensor_name == "A" else micro_size_n + micro_size_1, micro_size_2 = ( + (micro_size_spatial, micro_size_k) + if not is_transpose + else (micro_size_k, micro_size_spatial) + ) + v00, v01 = sch.split(sch.get_loops(block_read_reg)[-2], [None, micro_size_1]) + v10, v11 = sch.split(sch.get_loops(block_read_reg)[-1], [None, micro_size_2]) + sch.reorder(v00, v10, v01, v11) + + # reorder read axis to match the layout of ldmatrix + sch.transform_layout( + block_read_reg, + ("write", 0), + lambda v0, v1, v2: ( + v0, + v1 // micro_size_1, + v2 // micro_size_2, + *shared_16x16_to_mma_32x8_layout(v1 % micro_size_1, v2 % micro_size_2), + ), + ) + + # swizzling + mma_read_block = sch.blockize(sch.get_loops(block_read_reg)[-2]) + sch.annotate(mma_read_block, ann_key="permuted_layout", ann_val=1) + + return block_read_smem, block_read_reg + + block_read_a, block_read_reg_a = fetch_input(block_outer, 0, "A", is_transpose_a) + block_read_b, block_read_reg_b = fetch_input(block_outer, 1, "B", is_transpose_b) + + # Write to register, and then smem + def store_output(block_outer, write_buffer_idx): + # 1) Write to shared memory + block_write_smem = sch.cache_write(block_outer, write_buffer_idx, "shared.dyn") + sch.reverse_compute_at(block_write_smem, block_axis) + auto_inline_consumer_chain(sch, block_write_smem) + + # bind loops + fused = sch.fuse(*sch.get_loops(block_write_smem)[-2:]) + f0, f1, f2 = sch.split(fused, [None, thread_x, vector_size]) + sch.bind(f1, "threadIdx.x") + sch.vectorize(f2) + + # 2) Write to register + block_write_reg = sch.cache_write(block_outer, write_buffer_idx, "warp") + + # bind loops + v0, v1, v2 = sch.get_loops(block_write_reg)[-3:] + v11, v12, v13 = sch.split(v1, factors=[thread_z, None, micro_size_m]) + v21, v22, v23 = sch.split(v2, factors=[thread_y, None, micro_size_n]) + sch.reorder(v11, v21, v12, v22, v13, v23) + sch.bind(v11, "threadIdx.z") + sch.bind(v21, "threadIdx.y") + + # reorder write axis to match the layout of ldmatrix + sch.transform_layout( + block_write_reg, + ("read", 0), + lambda v0, v1, v2: ( + v0, + v1 // micro_size_m, + v2 // micro_size_n, + *shared_16x16_to_mma_32x8_layout(v1 % micro_size_m, v2 % micro_size_n), + ), + ) + + return block_write_smem, block_write_reg + + block_write_smem, block_write_reg = store_output(block_outer, 0) + + # Step 5. Schedule tensor core computation + block_init = sch.decompose_reduction(block_outer, k0) + block_init_inner = sch.get_child_blocks(block_init)[0] + + intrin_group = get_mma_intrin_group( + load_scope="shared.dyn", + store_scope="shared.dyn", + in_dtype=str(dtype_a), + out_dtype=str(dtype_c), + trans_a=is_transpose_a, + trans_b=is_transpose_b, + not_use_mma_store_intrinic=False, + ) + + sch.tensorize(sch.get_loops(block_init_inner)[-2], intrin_group["init"]) + sch.tensorize(sch.get_loops(block_read_reg_a)[-2], intrin_group["load_a"]) + sch.tensorize(sch.get_loops(block_read_reg_b)[-2], intrin_group["load_b"]) + sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) + sch.tensorize(sch.get_loops(block_write_reg)[-2], intrin_group["store"]) + + # Step 6. Async pipeline + sch.annotate(k0, ann_key="software_pipeline_stage", ann_val=[0, 0, 3]) + sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2]) + sch.annotate(k0, ann_key="software_pipeline_async_stages", ann_val=[0]) + + # Step 7. Handle dequantize block + # Now we just add a dummy kernel to compute dequantize + if dequantize_block is not None: + auto_inline_producers(sch, dequantize_block) + loops = sch.get_loops(dequantize_block) + loop = sch.fuse(*loops) + v0, v1, v2, v3 = sch.split(loop, [None, 128, 2, 4]) + sch.bind(v0, "blockIdx.x") + sch.bind(v1, "threadIdx.x") + sch.unroll(v2) + sch.vectorize(v3) + return sch + + def apply_config( # pylint: disable=too-many-locals,missing-docstring + self, + func: tir.PrimFunc, + config, + ) -> Optional[tir.Schedule]: + if "dequantize_info" in func.attrs: + dequantize_rule = MatmulTensorizationMMAWithDequantizeInfo() + return dequantize_rule.apply_config(func, config) + + from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel + get_mma_intrin_group, + ) + + sch = tir.Schedule(func) + root_block = analysis.get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + + if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): + return None + + reduction_blocks = get_reduction_blocks(sch, blocks) + if reduction_blocks is None: + return None + + main_block = reduction_blocks[0] + + # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]/B[S, K, J] + if not (func.attrs is not None and "dlight.tensorcore_prenormlized" in func.attrs.keys()): + sch = normalize_to_matmul(sch, main_block, ["a", "a", "a"]) + + output_blocks = [sch.get(block) for block in sch.get_output_blocks(root_block)] + + def check_require_cache(func: tir.PrimFunc): + conditions: List[bool] = [] + + # check if has dynamic symbolic + def check_has_dynamic(func: tir.PrimFunc): + for param in func.params: + if param not in func.buffer_map: + continue + arg = func.buffer_map[param] + for i in arg.shape: + if isinstance(i, tir.Var): + return True + return False + + conditions.append(check_has_dynamic(func)) + # check if has post process + conditions.append(sch.get(main_block) not in output_blocks) + return any(conditions) + + cache_write_required = check_require_cache(func) + + shared_scope = config.shared_scope + + intrin_info = config.intrin_info + intrin_group = get_mma_intrin_group( + load_scope=shared_scope, + store_scope=shared_scope if cache_write_required else "global", + in_dtype=intrin_info.in_dtype, + out_dtype=intrin_info.out_dtype, + trans_a=intrin_info.trans_a, + trans_b=intrin_info.trans_b, + smooth_a=intrin_info.smooth_a, + smooth_b=intrin_info.smooth_b, + not_use_mma_store_intrinic=False, + ) + + # Start Schedule + # Step 0. Get schedule config. + # NOTE: we can analyze the config by the hardware spec in the future + + # tensor core intrinsic size + warp_row_tiles = config.warp[0] + warp_col_tiles = config.warp[1] + block_row_warps = config.block[0] // warp_row_tiles + block_col_warps = config.block[1] // warp_col_tiles + stage = config.pipeline_stage + use_async = config.use_async + chunk = config.rstep[0] + + micro_size_x, micro_size_y, micro_size_k = intrin_group["micro_kernel"] + + # get the axis for layout transform + def get_axis(l, r, trans): + return (r, l) if trans else (l, r) + + a_lr = get_axis(micro_size_x, micro_size_k, intrin_info.trans_a) + b_lr = get_axis(micro_size_k, micro_size_y, intrin_info.trans_b) + + def can_enable_swizzle(dtype: str, smooth: bool): + # inject_permuted_layout only support float16 currently + if dtype == "float16" or dtype == "int8": + # if we use smooth layout, we don't need to do swizzling + return not smooth + return False + + can_swizzle_a = can_enable_swizzle(intrin_info.in_dtype, intrin_info.smooth_a) + can_swizzle_b = can_enable_swizzle(intrin_info.in_dtype, intrin_info.smooth_b) + + warp_size = 32 + + i_factors, j_factors, k_factors = ( + [None, 1, block_row_warps, warp_row_tiles // micro_size_x], + [1, None, block_col_warps, warp_col_tiles // micro_size_y], + [None, chunk // micro_size_k], + ) + + num_ty = i_factors[2] + num_tz = j_factors[2] + x_pad_factor = i_factors[2] * i_factors[3] + y_pad_factor = j_factors[2] * j_factors[3] + k_pad_factor = k_factors[1] + + # Step 2. Padding for dynamic shape kernels + sch.pad_einsum( + main_block, + [ + 1, + micro_size_x * x_pad_factor, + micro_size_y * y_pad_factor, + micro_size_k * k_pad_factor, + ], + ) + + # Step 3. Schedule matmul to use tensor core + block = main_block + + batch, i, j, k = sch.get_loops(block) + + # inner loops for tensor core computation + i, i_inner = sch.split(i, factors=[None, micro_size_x]) + j, j_inner = sch.split(j, factors=[None, micro_size_y]) + k, k_inner = sch.split(k, factors=[None, micro_size_k]) + + sch.reorder(i, j, k, i_inner, j_inner, k_inner) + + block_inner = block + block_outer = sch.blockize(i_inner) + + i0, i1, i2, i3 = sch.split(i, factors=i_factors) + j0, j1, j2, j3 = sch.split(j, factors=j_factors) + k0, k1 = sch.split(k, k_factors) + + sch.reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3) + + block_idy = sch.fuse(i0, j0) + block_idx = sch.fuse(i1, j1) + thread_idy = i2 + thread_idz = j2 + + # plan rasteration + if ( + not isinstance(config.rasterization_plan, NoRasterization) + and sch.get(batch).extent.value == 1 + ): + device_func, invoke_func = config.rasterization_plan.get_code() + factor = config.rasterization_plan.panel_width_ + + # TODO(lei): this is a trick for rasterization implementation + # is not optimal. (5% performance loss) + # require a solution for general block rasterization + factor = 8 # should be divisible by block_idx + if sch.get(block_idx).extent.value % factor == 0: + block_k, block_idx = sch.split(block_idx, factors=[None, factor]) + sch.reorder(block_k, block_idy, block_idx) + sch.bind(block_k, "blockIdx.z") + else: + sch.bind(batch, "blockIdx.z") + + sch.bind(block_idx, "blockIdx.x") + sch.bind(block_idy, "blockIdx.y") + sch.bind(thread_idy, "threadIdx.y") + sch.bind(thread_idz, "threadIdx.z") + + # rewrite smooth layout of shared memory + def smooth_smem_layout_rewrite(block, scope, l=16, r=16, enable=True): + if not enable: + return + sch.transform_layout( + block, + scope, + lambda b, i, j: ( + b, + i // l, + j // r, + i % l, + j % r, + ), + ) + + smooth_smem_layout_rewrite(block_outer, ("read", 0), *a_lr, enable=intrin_info.smooth_a) + smooth_smem_layout_rewrite(block_outer, ("read", 1), *b_lr, enable=intrin_info.smooth_b) + smooth_smem_layout_rewrite(block_outer, ("write", 0), enable=True) + + def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False, trans=False): + block_read = sch.cache_read(block, idx, shared_scope) + sch.compute_at(block_read, k0, preserve_unit_loops=True) + ndim = len(sch.get(block_read).iter_vars) + fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) + + f_0, f_1, f_2, f_3, f_4 = sch.split( + fused, factors=[num_ty, num_tz, None, warp_size, vec_len] + ) + + sch.bind(f_3, "threadIdx.x") + sch.bind(f_1, "threadIdx.z") + sch.bind(f_0, "threadIdx.y") + sch.vectorize(f_4) + sch.unroll(f_2) + # Apply Swizzling + sch.annotate(block_read, ann_key="permuted_layout", ann_val=can_swizzle) + # if not, apply padding to alleviate bank conflict + if not (can_swizzle or is_smooth): + pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 + sch.storage_align(block_read, 0, axis=-2, factor=16, offset=pad_offset) + sch.annotate(f_2, "pragma_unroll_explicit", False) + return block_read + + if len(config.vectorize.values()) < 2: + return None + + a_g2s = fetch_to_shared( + block_outer, + 0, + vec_len=list(config.vectorize.values())[0], + can_swizzle=can_swizzle_a, + is_smooth=intrin_info.smooth_a, + trans=intrin_info.trans_a, + ) + b_g2s = fetch_to_shared( + block_outer, + 1, + vec_len=list(config.vectorize.values())[1], + can_swizzle=can_swizzle_b, + is_smooth=intrin_info.smooth_b, + trans=intrin_info.trans_b, + ) + + # rewrite global smooth layout + def smooth_gmem_layout_rewrite(sch, block, enable=True, trans=False, matrix_name="A"): + if not enable: + return + # step1: find the first producer block + # Notes: we assume the layout propagate happens in the first producer block + # otherwise, the layout transform will have no effect as it will transform both + # read and write buffer + producers = _collect_producers(sch, block) + + propagate_block: tir.Block = producers[-1] + + # step2: transform the layout with inverse permutation + _, inverse_indexmap = get_propagate_map(trans=trans, dtype=intrin_info.in_dtype, matrix_name=matrix_name) + + def inverse_permutation(i, j, ii, jj): + return (i, j, *inverse_indexmap.map_indices([ii, jj])) + + sch.transform_layout(propagate_block, ("read", 0), inverse_permutation) + + smooth_gmem_layout_rewrite(sch, a_g2s, intrin_info.smooth_a, intrin_info.trans_a, matrix_name="A") + smooth_gmem_layout_rewrite(sch, b_g2s, intrin_info.smooth_b, intrin_info.trans_b, matrix_name="B") + auto_inline_producers(sch, a_g2s) + auto_inline_producers(sch, b_g2s) + + # create read cache to load matrix from shared memory to wmma fragments + A_mat = sch.cache_read(block_outer, 0, "warp") + B_mat = sch.cache_read(block_outer, 1, "warp") + sch.compute_at(A_mat, k1) + sch.compute_at(B_mat, k1) + + # create write cache to store matrix from wmma fragments to shared memory and global memory + if cache_write_required: + accumulator_shared_to_global = sch.cache_write(block_outer, 0, shared_scope) + + store = sch.cache_write(block_outer, 0, "warp") + sch.reverse_compute_at(store, j2) + + # split the store loop to match hardware intrinsic pattern + i, j = sch.get_loops(store)[-2:] + i0, i1 = sch.split(i, factors=[None, micro_size_x], preserve_unit_iters=False) + j0, j1 = sch.split(j, factors=[None, micro_size_y], preserve_unit_iters=False) + sch.reorder(i0, j0, i1, j1) + + if cache_write_required: + auto_inline_consumer_chain(sch, accumulator_shared_to_global) + sch.reverse_compute_at( + accumulator_shared_to_global, sch.get_loops(store)[-5], preserve_unit_loops=True + ) + vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global)) + fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-5:]) + f0, f1, f2 = sch.split( + fused, factors=[None, warp_size, vec_len] + ) + sch.bind(f1, "threadIdx.x") + sch.vectorize(f2) + sch.unroll(f0) + sch.annotate(f0, "pragma_unroll_explicit", False) + else: + auto_inline_consumer_chain(sch, store) + + block_init_c = sch.decompose_reduction(block_outer, k0) + block_init_c_inner = sch.get_child_blocks(block_init_c)[0] + + # Tensorization by hardware intrinsics + index_map_a, index_map_b, index_map_c = intrin_group["index_map"] + + sch.transform_layout( + A_mat, ("write", 0), get_warp_index_map(index_map_a, *a_lr, intrin_info.smooth_a) + ) + sch.transform_layout( + B_mat, ("write", 0), get_warp_index_map(index_map_b, *b_lr, intrin_info.smooth_b) + ) + sch.transform_layout( + store, + ("read", 0), + get_warp_index_map(index_map_c, is_5d=True), + ) + + i, j = sch.get_loops(A_mat)[-2:] + i0, i1 = sch.split(i, factors=[None, a_lr[0]]) + j0, j1 = sch.split(j, factors=[None, a_lr[1]]) + sch.reorder(i0, j0, i1, j1) + ba = sch.blockize(i1) + sch.annotate(ba, ann_key="permuted_layout", ann_val=can_swizzle_a) + sch.tensorize(ba, intrin_group["load_a"]) + + i, j = sch.get_loops(B_mat)[-2:] + i0, i1 = sch.split(i, factors=[None, b_lr[0]]) + j0, j1 = sch.split(j, factors=[None, b_lr[1]]) + sch.reorder(i0, j0, i1, j1) + bb = sch.blockize(i1) + sch.annotate(bb, ann_key="permuted_layout", ann_val=can_swizzle_b) + sch.tensorize(bb, intrin_group["load_b"]) + + def tensorize_init_store_compute(): + sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"]) + sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"]) + sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) + + tensorize_init_store_compute() + + if stage > 1: + sch.annotate(k0, ann_key="software_pipeline_stage", ann_val=[0, 0, stage - 1]) + sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2]) + if use_async: + sch.annotate(k0, "software_pipeline_async_stages", [0]) + + return sch diff --git a/python/bitblas/gpu/matmul_mma_dequantize.py b/python/bitblas/gpu/matmul_mma_dequantize.py new file mode 100644 index 000000000000..6900d1e3e378 --- /dev/null +++ b/python/bitblas/gpu/matmul_mma_dequantize.py @@ -0,0 +1,637 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# pylint: disable=missing-docstring, invalid-name +"""A GEMM schedule rule for GPU operators.""" +from typing import Literal, Optional + +from tvm import tir +from tvm.target import Target + +from ..base.roller.rasterization import NoRasterization +from ..base import analysis +from .base import GPUScheduleRule +from ..base.analysis import get_coalesced_veclen +from .matmul_analysis import ( + auto_inline_consumer_chain, + auto_inline_producers, + get_reduction_blocks, + normalize_to_matmul, + get_propagate_map, + layout_propagate_chain, + find_last_producer_from_buffer, +) + + +def get_index_map_3d(index_map, l=16, r=16): + def index_map_3d(b, i, j): + return ( + b, + i // l, + j // r, + *index_map(i % l, j % r), + ) + + return index_map_3d + + +def get_index_map_5d(index_map): + """ + for layout transformed gemm, the index map should be 5d + """ + + def index_map_5d(b, i, j, ii, jj): + return ( + b, + i, + j, + *index_map(ii, jj), + ) + + return index_map_5d + + +def get_index_map(index_map, l=16, r=16, is_5d=False): + if is_5d: + return get_index_map_5d(index_map) + return get_index_map_3d(index_map, l, r) + + +class MatmulTensorizationMMAWithDequantizeInfo(GPUScheduleRule): + """ + The schedule rule for float16 tensor core matmul computation. + func with attr 'dlight.do_not_tensorize' will not be tensorized. + """ + + def sch_dequantize_in_register_with_config( + self, + func: tir.PrimFunc, + config, + ): + """ + For devices without async copy, we can use a simple dequantize schedule without shared memory prefetch. + quantized weight + | + V + dequantized in register + | + V + save into shared memory + | + V + compute + """ + + return None + + def sch_shared_memory_prefetch_with_config( + self, + func: tir.PrimFunc, + config, + ): + """ + For A100 Like devices, the shared memory prefetch(async) is required + to achieve optimal performance. + quantized weight + | + V + shared memory prefetch (with async copy) + | + V + dequantized into shared memory + | + V + compute + """ + from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel + get_mma_intrin_group, + ) + from .intrin.lop3 import get_lop3_intrin_group + + sch = tir.Schedule(func) + root_block = analysis.get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + + if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): + return None + + reduction_blocks = get_reduction_blocks(sch, blocks) + if reduction_blocks is None: + return None + + main_block = reduction_blocks[0] + # always enable shared memory rewrite + cache_write_required = True + + # Check Dequantize Info + # TODO(leiwang): this is a hack to get the configuaration, can be improved by writing a pass to analysis the dequantize block. + dequantize_info = func.attrs["dequantize_info"] + + def check_dequantize_info(dequantize_info): + conditions = [] + # currently only support weight only dequantization + conditions.append(len(dequantize_info) == 1) + # TODO(@lei) check if the dequantize value name is weight + return all(conditions) + + assert check_dequantize_info(dequantize_info) + + (weight_decode_info,) = list(dequantize_info.values()) + + def check_b_decode_info(weight_decode_info): + conditions = [] + # check source format in ["int", "fp", "af"] + conditions.append("source_format" in weight_decode_info) + conditions.append( + weight_decode_info["source_format"]["format"] + in ["uint", "int", "fp", "af"] + ) + # check source bits in [1, 2, 4, 8] + conditions.append( + weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8] + ) + # check target format in ["float16", "int8"] + conditions.append("target_format" in weight_decode_info) + conditions.append( + weight_decode_info["target_format"] in ["float16", "int8"] + ) + return all(conditions) + + assert check_b_decode_info(weight_decode_info), "Invalid B_decode_info" + + # Start Schedule + # Step 0. Get schedule config. + # NOTE: we can analyze the config by the hardware spec in the future + + # tensor core intrinsic size + intrin_info = config.intrin_info + shared_scope = config.shared_scope + + intrin_info = config.intrin_info + intrin_group = get_mma_intrin_group( + load_scope=shared_scope, + store_scope=shared_scope if cache_write_required else "global", + in_dtype=intrin_info.in_dtype, + out_dtype=intrin_info.out_dtype, + trans_a=intrin_info.trans_a, + trans_b=intrin_info.trans_b, + smooth_a=intrin_info.smooth_a, + smooth_b=intrin_info.smooth_b, + not_use_mma_store_intrinic=False, + ) + + warp_row_tiles = config.warp[0] + warp_col_tiles = config.warp[1] + block_row_warps = config.block[0] // warp_row_tiles + block_col_warps = config.block[1] // warp_col_tiles + stage = config.pipeline_stage + use_async = config.use_async + chunk = config.rstep[0] + + micro_size_x, micro_size_y, micro_size_k = intrin_group["micro_kernel"] + + # get the axis for layout transform + def get_axis(l, r, trans): + return (r, l) if trans else (l, r) + + a_lr = get_axis(micro_size_x, micro_size_k, intrin_info.trans_a) + b_lr = get_axis(micro_size_k, micro_size_y, intrin_info.trans_b) + + def can_enable_swizzle(dtype: str, smooth: bool): + # inject_permuted_layout only support float16 currently + if dtype == "float16" or dtype == "int8": + # if we use smooth layout, we don't need to do swizzling + return not smooth + return False + + can_swizzle_a = can_enable_swizzle(intrin_info.in_dtype, intrin_info.smooth_a) + can_swizzle_b = can_enable_swizzle(intrin_info.in_dtype, intrin_info.smooth_b) + + # rewrite global smooth layout, for dequantize, currently only support weight only recover. + def smooth_gmem_layout_rewrite( + sch, main_block, enable=True, trans=False, matrix_name="A" + ): + if not enable: + return + + # normalized block may have three read buffers, while the first one is the write buffer. + buffer_offset = ( + 1 + if sch.get(main_block).reads[0].buffer + == sch.get(main_block).writes[0].buffer + else 0 + ) + buffer_idx = 0 if matrix_name == "A" else 1 + source_buffer = sch.get(main_block).reads[buffer_offset + buffer_idx].buffer + + # step1: find the first producer block + # Notes: we assume the layout propagate happens in the first producer block + # otherwise, the layout transform will have no effect as it will transform both + # read and write buffer + propagate_block: tir.Block = find_last_producer_from_buffer( + sch, main_block, source_buffer + ) + # some trick impl may not have reindex block + (weight_dequantize_info,) = dequantize_info.values() + if ( + sch.get(propagate_block).name_hint + == weight_dequantize_info["decode_block"] + ): + return + # step2: transform the layout with inverse permutation + _, inverse_indexmap = get_propagate_map( + trans=trans, dtype=intrin_info.in_dtype, matrix_name=matrix_name + ) + + # step3: propagate the matmul layout to the first reindex block + + inverse_indexmap = layout_propagate_chain( + sch, + start_block=main_block, + start_buffer=source_buffer, + end_block=propagate_block, + index_map=inverse_indexmap, + ) + + def inverse_permutation(i, j, ii, jj): + return (i, j, *inverse_indexmap.map_indices([ii, jj])) + + sch.transform_layout(propagate_block, ("read", 0), inverse_permutation) + + smooth_gmem_layout_rewrite( + sch, main_block, intrin_info.smooth_a, intrin_info.trans_a, matrix_name="A" + ) + + smooth_gmem_layout_rewrite( + sch, main_block, intrin_info.smooth_b, intrin_info.trans_b, matrix_name="B" + ) + + warp_size = 32 + + i_factors, j_factors, k_factors = ( + [None, 1, block_row_warps, warp_row_tiles // micro_size_x], + [1, None, block_col_warps, warp_col_tiles // micro_size_y], + [None, chunk // micro_size_k], + ) + + num_ty = i_factors[2] + num_tz = j_factors[2] + x_pad_factor = i_factors[2] * i_factors[3] + y_pad_factor = j_factors[2] * j_factors[3] + k_pad_factor = k_factors[1] + + # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]/B[S, K, J] + if not ( + func.attrs is not None + and "dlight.tensorcore_prenormlized" in func.attrs.keys() + ): + sch = normalize_to_matmul(sch, main_block, ["a", "a", "a"]) + + # Step 2. Padding for dynamic shape kernels + sch.pad_einsum( + main_block, + [ + 1, + micro_size_x * x_pad_factor, + micro_size_y * y_pad_factor, + micro_size_k * k_pad_factor, + ], + ) + + # Step 3. Schedule matmul to use tensor core + block = main_block + + batch, i, j, k = sch.get_loops(block) + + # inner loops for tensor core computation + i, i_inner = sch.split(i, factors=[None, micro_size_x]) + j, j_inner = sch.split(j, factors=[None, micro_size_y]) + k, k_inner = sch.split(k, factors=[None, micro_size_k]) + + sch.reorder(i, j, k, i_inner, j_inner, k_inner) + + block_inner = block + block_outer = sch.blockize(i_inner) + + i0, i1, i2, i3 = sch.split(i, factors=i_factors) + j0, j1, j2, j3 = sch.split(j, factors=j_factors) + k0, k1 = sch.split(k, k_factors) + + sch.reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3) + + block_idy = sch.fuse(i0, j0) + block_idx = sch.fuse(i1, j1) + thread_idy = i2 + thread_idz = j2 + + # plan rasteration + if ( + not isinstance(config.rasterization_plan, NoRasterization) + and sch.get(batch).extent.value == 1 + ): + device_func, invoke_func = config.rasterization_plan.get_code() + factor = config.rasterization_plan.panel_width_ + + # TODO(lei): this is a trick for rasterization implementation + # is not optimal. + # require a solution for general block rasterization + # factor = 8 # should be divisible by block_idy + # if sch.get(block_idx).extent.value % factor == 0: + # block_k, block_idx = sch.split(block_idx, factors=[None, factor]) + # sch.bind(block_k, "blockIdx.z") + else: + sch.bind(batch, "blockIdx.z") + + sch.bind(block_idx, "blockIdx.x") + sch.bind(block_idy, "blockIdx.y") + sch.bind(thread_idy, "threadIdx.y") + sch.bind(thread_idz, "threadIdx.z") + + def smooth_layout_recover(block, scope, l=16, r=16, enable=True): + if not enable: + return + sch.transform_layout( + block, + scope, + lambda b, i, j: ( + b, + i // l, + j // r, + i % l, + j % r, + ), + ) + + smooth_layout_recover( + block_outer, ("read", 0), *a_lr, enable=intrin_info.smooth_a + ) + smooth_layout_recover( + block_outer, + ("read", 1), + *b_lr, + enable=intrin_info.smooth_b, + ) + smooth_layout_recover(block_outer, ("write", 0), enable=True) + + def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False): + block_read = sch.cache_read(block, idx, shared_scope) + sch.compute_at(block_read, k0, preserve_unit_loops=True) + ndim = len(sch.get(block_read).iter_vars) + fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) + + f_0, f_1, f_2, f_3, f_4 = sch.split( + fused, factors=[None, num_ty, num_tz, warp_size, vec_len] + ) + + sch.bind(f_3, "threadIdx.x") + sch.bind(f_2, "threadIdx.z") + sch.bind(f_1, "threadIdx.y") + sch.vectorize(f_4) + sch.unroll(f_0) + # Apply Swizzling + sch.annotate(block_read, ann_key="permuted_layout", ann_val=can_swizzle) + # if not, apply padding to alleviate bank conflict + if not (can_swizzle or is_smooth): + pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 + sch.storage_align(block_read, 0, axis=-2, factor=16, offset=pad_offset) + sch.annotate(f_2, "pragma_unroll_explicit", False) + return block_read + + a_g2s = fetch_to_shared( + block_outer, + 0, + vec_len=list(config.vectorize.values())[0], + can_swizzle=can_swizzle_a, + is_smooth=intrin_info.smooth_a, + ) + + auto_inline_producers(sch, a_g2s) + + def decode_fetch_to_shared(block, idx): + # step1. create memory hierarchy + # global -> local -> shared + block_shared = sch.cache_read(block, idx, shared_scope) + sch.compute_at(block_shared, k0, preserve_unit_loops=True) + + # TODO(lei): the factor shoule be analyzed more deeper. + decode_factor = get_coalesced_veclen(sch.get(block_shared)) + _, B_shared_vi, _ = sch.split( + sch.get_loops(block_shared)[-1], factors=[None, 1, decode_factor] + ) + block_shared_local = sch.cache_read(block_shared, 0, "local") + # global -> dequantzed_local -> shared + # step2. inline to local block + auto_inline_producers(sch, block_shared_local) + + # get target dequantize buffer's idx + def get_idx(): + # for LUT dequantize, the expr is LUT(w), the idx is 1 + # maybe we can use a more general and structual based way + # to analysis the idx + if weight_decode_info["source_format"]["format"] == "af": + return 1 + return 0 + + b_idx = get_idx() + # global -> prefetch_local -> dequantzed_local -> shared + block_shared_local_local = sch.cache_read( + block_shared_local, b_idx, "local" + ) + # global -> prefetch_shared -> vector load -> dequantzed_local -> shared + block_shared_local_local_shared = sch.cache_read( + block_shared_local_local, 0, shared_scope + ) + sch.compute_at(block_shared_local, B_shared_vi, preserve_unit_loops=True) + sch.compute_at( + block_shared_local_local, B_shared_vi, preserve_unit_loops=True + ) + + dequantize_block_local = block_shared_local + # fast type conversion + if ( + "fast_decoding" in weight_decode_info + and weight_decode_info["fast_decoding"] + ): + source_bit = weight_decode_info["source_format"]["bits"] + out_dtype = weight_decode_info["target_format"] + lop3_intrin_info = get_lop3_intrin_group( + out_dtype=out_dtype, + storage_dtype=weight_decode_info["storage_dtype"], + source_format=weight_decode_info["source_format"]["format"], + source_bit=source_bit, + with_scaling=weight_decode_info["with_scaling"], + ) + sch.tensorize( + sch.get_loops(dequantize_block_local)[-1], + lop3_intrin_info["compute"], + ) + sch.annotate( + thread_idz, + ann_key="pragma_import_c", + ann_val=lop3_intrin_info["c_source"], + ) + + sch.annotate(block_shared, ann_key="permuted_layout", ann_val=can_swizzle_b) + union_len = (2 + 4) if intrin_info.smooth_b else (2 + 2) + B_shared_fused = sch.fuse(*sch.get_loops(block_shared)[-union_len:-2]) + _, B_shared_ty, B_shared_tz, B_shared_tx = sch.split( + B_shared_fused, factors=[None, num_ty, num_tz, warp_size] + ) + if not (can_swizzle_b or intrin_info.smooth_b): + pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 + sch.storage_align( + block_shared, 0, axis=-2, factor=16, offset=pad_offset + ) + sch.bind(B_shared_tx, "threadIdx.x") + sch.bind(B_shared_ty, "threadIdx.y") + sch.bind(B_shared_tz, "threadIdx.z") + sch.vectorize(sch.get_loops(block_shared)[-1]) + sch.vectorize(sch.get_loops(block_shared_local_local)[-1]) + + sch.compute_at( + block_shared_local_local_shared, k0, preserve_unit_loops=True + ) + ndim = len(sch.get(block_shared_local_local_shared).iter_vars) + fused = sch.fuse(*sch.get_loops(block_shared_local_local_shared)[-ndim:]) + + f_0, f_1, f_2, f_3, f_4 = sch.split( + fused, + factors=[ + None, + num_tz, + num_ty, + warp_size, + get_coalesced_veclen(sch.get(block_shared_local_local_shared)), + ], + ) + + sch.bind(f_3, "threadIdx.x") + sch.bind(f_2, "threadIdx.y") + sch.bind(f_1, "threadIdx.z") + sch.vectorize(f_4) + sch.unroll(f_0) + sch.annotate(f_0, "pragma_unroll_explicit", False) + + # cache small tensors, e.g. LUT + if b_idx: + block_shared_lut = sch.cache_read( + dequantize_block_local, 0, shared_scope + ) + sch.reverse_compute_at(block_shared_lut, j2) + _, B_shared_tx = sch.split( + sch.get_loops(block_shared_lut)[-1], factors=[None, warp_size] + ) + sch.bind(B_shared_tx, "threadIdx.x") + return block_shared_local + + _ = decode_fetch_to_shared(block_outer, 1) + + # create read cache to load matrix from shared memory to wmma fragments + A_mat = sch.cache_read(block_outer, 0, "warp") + B_mat = sch.cache_read(block_outer, 1, "warp") + sch.compute_at(A_mat, k1, preserve_unit_loops=True) + sch.compute_at(B_mat, k1, preserve_unit_loops=True) + + # create write cache to store matrix from wmma fragments to shared memory and global memory + if cache_write_required: + accumulator_shared_to_global = sch.cache_write(block_outer, 0, shared_scope) + + store = sch.cache_write(block_outer, 0, "warp") + sch.reverse_compute_at(store, j2) + + # split the store loop to match hardware intrinsic pattern + i, j = sch.get_loops(store)[-2:] + i0, i1 = sch.split(i, factors=[None, micro_size_x]) + j0, j1 = sch.split(j, factors=[None, micro_size_y]) + sch.reorder(i0, j0, i1, j1) + + if cache_write_required: + auto_inline_consumer_chain(sch, accumulator_shared_to_global) + sch.reverse_compute_at( + accumulator_shared_to_global, + sch.get_loops(store)[-5], + preserve_unit_loops=True, + ) + vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global)) + fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-5:]) + f0, f1, f2 = sch.split(fused, factors=[None, warp_size, vec_len]) + sch.bind(f1, "threadIdx.x") + sch.vectorize(f2) + sch.unroll(f0) + sch.annotate(f0, "pragma_unroll_explicit", False) + else: + auto_inline_consumer_chain(sch, store) + + block_init_c = sch.decompose_reduction(block_outer, k0) + block_init_c_inner = sch.get_child_blocks(block_init_c)[0] + + # Tensorization by hardware intrinsics + + index_map_a, index_map_b, index_map_c = intrin_group["index_map"] + + sch.transform_layout( + A_mat, ("write", 0), get_index_map(index_map_a, *a_lr, intrin_info.smooth_a) + ) + sch.transform_layout( + B_mat, ("write", 0), get_index_map(index_map_b, *b_lr, intrin_info.smooth_b) + ) + sch.transform_layout( + store, + ("read", 0), + get_index_map(index_map_c, is_5d=True), + ) + + i, j = sch.get_loops(A_mat)[-2:] + i0, i1 = sch.split(i, factors=[None, a_lr[0]]) + j0, j1 = sch.split(j, factors=[None, a_lr[1]]) + sch.reorder(i0, j0, i1, j1) + ba = sch.blockize(i1) + sch.annotate(ba, ann_key="permuted_layout", ann_val=can_swizzle_a) + sch.tensorize(ba, intrin_group["load_a"]) + + i, j = sch.get_loops(B_mat)[-2:] + i0, i1 = sch.split(i, factors=[None, b_lr[0]]) + j0, j1 = sch.split(j, factors=[None, b_lr[1]]) + sch.reorder(i0, j0, i1, j1) + bb = sch.blockize(i1) + sch.annotate(bb, ann_key="permuted_layout", ann_val=can_swizzle_b) + sch.tensorize(bb, intrin_group["load_b"]) + + def tensorize_init_store_compute(): + sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"]) + sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"]) + sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) + + tensorize_init_store_compute() + + if stage > 1: + sch.annotate( + k0, + ann_key="software_pipeline_stage", + ann_val=[0, 0, stage - 1, stage - 1], + ) + sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2, 3]) + if use_async: + sch.annotate(k0, "software_pipeline_async_stages", [0]) + return sch + + def apply_config( # pylint: disable=too-many-locals,missing-docstring + self, + func: tir.PrimFunc, + config, + ) -> Optional[tir.Schedule]: + def check_sm_version(arch: str) -> int: + sm_version = arch.replace("sm_", "") + return int(sm_version) if sm_version.isdigit() else -1 + + if check_sm_version(config.arch.target.arch) < 80: + """MMA Template only support sm_80 and above""" + return None + + if ( + config.arch.target.kind.name == "cuda" + and check_sm_version(config.arch.target.arch) == 80 + ): + return self.sch_shared_memory_prefetch_with_config(func, config) + else: + return self.sch_with_config(func, config) diff --git a/python/bitblas/gpu/matmul_wmma.py b/python/bitblas/gpu/matmul_wmma.py new file mode 100644 index 000000000000..765860f8811f --- /dev/null +++ b/python/bitblas/gpu/matmul_wmma.py @@ -0,0 +1,909 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# pylint: disable=missing-docstring, invalid-name +"""A GEMM schedule rule for GPU operators.""" +import math +from typing import Literal, Optional + +from tvm import DataType, tir +from tvm.target import Target +from tvm.tir.stmt import ForKind + +from ..base.roller.rasterization import NoRasterization +from ..base import analysis +from .base import GPUScheduleRule +from .matmul_analysis import ( + auto_inline_consumer_chain, + auto_inline_consumers, + auto_inline_producers, + get_index_map, + get_reduction_blocks, + normalize_to_matmul, +) + + +class MatmulTensorizationWMMA(GPUScheduleRule): + """ + The schedule rule for float16 tensor core matmul computation. + func with attr 'dlight.do_not_tensorize' will not be tensorized. + """ + + def apply( # pylint: disable=too-many-locals,missing-docstring + self, + func: tir.PrimFunc, + target: Target, + _: bool, + ) -> Optional[tir.Schedule]: + sch = tir.Schedule(func) + root_block = analysis.get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + + if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): + return None + + reduction_blocks = get_reduction_blocks(sch, blocks) + if reduction_blocks is None: + return None + + main_block = reduction_blocks[0] + block_stmt = sch.get(main_block) + index_maps = get_index_map(block_stmt) + if index_maps is None: + return None + matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps + + # Start Schedule + # Step 0. Get schedule config. + # NOTE: we can analyze the config by the hardware spec in the future + + block_m = 128 + block_n = 128 + block_k = 32 + + # tensor core intrinsic size + micro_size_m = 16 + micro_size_n = 16 + micro_size_k = 16 + + thread_z = 2 + thread_y = 2 + warp_size = 32 + thread_cnt = thread_y * thread_z * warp_size + + vector_size = 8 + unroll_depth = 256 + + # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K] + block = sch.reindex(main_block, ("read", 0)) + sch.transform_layout(block, ("write", 0), a_index_map) + block = sch.reindex(main_block, ("read", 1)) + sch.transform_layout(block, ("write", 0), b_index_map) + block = sch.reindex(main_block, ("write", 0)) + sch.transform_layout(block, ("read", 0), c_index_map) + sch.transform_block_layout(main_block, matmul_index_map) + + # Step 2. Padding for dynamic shape kernels + + # # Step 2.1 Swizzle for l2, for better performance on inputs exceeding l2 size + # # Get input shape + batch, i, j, k = sch.get_loops(main_block) + # input_b, input_m, input_n, input_k = [sch.get(loop).extent for loop in [batch, i, j, k]] + + # # Get input/output dtype + dtype_a, dtype_b = [DataType(region.buffer.dtype) for region in sch.get(main_block).reads] + dtype_c = DataType(sch.get(main_block).writes[0].buffer.dtype) + # dtype_a_bytes, dtype_b_bytes = [math.ceil(d.bits / 8) for d in [dtype_a, dtype_b]] + + # # Get l2 size + # l2_size = target.l2_cache_size_bytes + + # # Analyse swizzle factor + # def get_swizzle_factor(l2_size, input_k, dtype_bytes, input_spatial, block_size): + # if l2_size != 0 and isinstance(input_k, (int, tir.IntImm)): + # # div by 3: suppose the two inputs and the output uses the same amount of l2 + # swizzle_factor = l2_size / 3 / int(input_k) / dtype_bytes / block_size + # # optimization: try find the best swizzle factor (aka the least additional padding) + # if isinstance(input_spatial, (int, tir.IntImm)): + # block_cnt = math.ceil(int(input_spatial) / block_size) + # swizzle_factor = math.ceil(block_cnt / math.ceil(block_cnt / swizzle_factor)) + # else: + # swizzle_factor = math.floor(swizzle_factor) + # return [None, swizzle_factor] + # else: + # return [4, None] + + # swizzle_factor_m = get_swizzle_factor(l2_size, input_k, dtype_a_bytes, input_m, block_m) + # swizzle_factor_n = get_swizzle_factor(l2_size, input_k, dtype_b_bytes, input_n, block_n) + + swizzle_factor_m = [4, None] + swizzle_factor_n = [4, None] + + # Step 2.2 Add padding + sch.pad_einsum( + main_block, + [ + 1, + (swizzle_factor_m[0] or swizzle_factor_m[1]) * block_m, + (swizzle_factor_n[0] or swizzle_factor_n[1]) * block_n, + block_k, + ], + ) + + # Step 3. Reorder loops for tiling + + # inner loops for tensor core computation + i, i_inner = sch.split(i, factors=[None, micro_size_m]) + j, j_inner = sch.split(j, factors=[None, micro_size_n]) + k, k_inner = sch.split(k, factors=[None, micro_size_k]) + + sch.reorder(i, j, k, i_inner, j_inner, k_inner) + + block_inner = main_block + block_outer = sch.blockize(i_inner) + + # split factors for i, j, and k + in_wrap_block_cnt_m = block_m // thread_z // micro_size_m + in_wrap_block_cnt_n = block_n // thread_y // micro_size_n + in_wrap_block_cnt_k = block_k // micro_size_k + + i_factors = swizzle_factor_m + [thread_z, in_wrap_block_cnt_m] + j_factors = swizzle_factor_n + [thread_y, in_wrap_block_cnt_n] + k_factors = [None, in_wrap_block_cnt_k] + + i0, i1, i2, i3 = sch.split(i, factors=i_factors) + j0, j1, j2, j3 = sch.split(j, factors=j_factors) + k0, k1 = sch.split(k, factors=k_factors) + + sch.reorder(i0, j0, i1, j1, k0, i2, j2, k1, i3, j3) + block_axis = sch.fuse(batch, i0, j0, i1, j1) + + sch.bind(block_axis, "blockIdx.x") + sch.bind(i2, "threadIdx.z") + sch.bind(j2, "threadIdx.y") + + # Step 4. Read to/write from shared mem, and from/to wmma fragments + def fetch_input(block_outer, read_buffer_idx, tensor_name: Literal["A", "B"], wmma_name): + block_read = sch.cache_read(block_outer, read_buffer_idx, "shared.dyn") + sch.compute_at(block_read, k0) + fused = sch.fuse(*sch.get_loops(block_read)[-2:]) + + f0, f1, f2, f3, f4 = sch.split( + fused, [None, thread_z, thread_y, warp_size, vector_size] + ) + + sch.bind(f1, "threadIdx.z") + sch.bind(f2, "threadIdx.y") + sch.bind(f3, "threadIdx.x") + sch.vectorize(f4) + sch.storage_align(block_read, 0, axis=-2, factor=16, offset=8) + + auto_inline_producers(sch, block_read) + + wmma_read = sch.cache_read(block_outer, read_buffer_idx, wmma_name) + sch.compute_at(wmma_read, k1) + + micro_size_spatial = micro_size_m if tensor_name == "A" else micro_size_n + v0, v1 = sch.get_loops(wmma_read)[-2:] + sch.split(v0, factors=[None, micro_size_spatial]) + + return wmma_read + + wmma_read_a = fetch_input( + block_outer, 0, [block_m, block_k, micro_size_m, micro_size_k], "wmma.matrix_a" + ) + wmma_read_b = fetch_input( + block_outer, 1, [block_n, block_k, micro_size_n, micro_size_k], "wmma.matrix_b" + ) + + def store_output(block_outer, write_buffer_idx, wmma_name): + block_write = sch.cache_write(block_outer, write_buffer_idx, "shared.dyn") + sch.reverse_compute_at(block_write, block_axis) + + fused = sch.fuse(*sch.get_loops(block_write)[-2:]) + + f0, f1, f2, f3, f4 = sch.split( + fused, [None, thread_z, thread_y, warp_size, vector_size] + ) + + sch.bind(f1, "threadIdx.z") + sch.bind(f2, "threadIdx.y") + sch.bind(f3, "threadIdx.x") + sch.vectorize(f4) + # sch.storage_align(block_write, 0, axis=-2, factor=128, offset=16) + + auto_inline_consumer_chain(sch, block_write) + + wmma_store = sch.cache_write(block_outer, write_buffer_idx, wmma_name) + v0, v1 = sch.get_loops(wmma_store)[-2:] + v00, v01, v02 = sch.split(v0, factors=[thread_z, None, micro_size_m]) + v10, v11, v12 = sch.split(v1, factors=[thread_y, None, micro_size_n]) + sch.reorder(v00, v10, v01, v11, v02, v12) + sch.bind(v00, "threadIdx.z") + sch.bind(v10, "threadIdx.y") + return wmma_store + + wmma_store = store_output(block_outer, 0, "wmma.accumulator") + + block_init = sch.decompose_reduction(block_outer, k0) + block_init_inner = sch.get_child_blocks(block_init)[0] + + # unroll k + sch.unroll(k0) + + # Step 5. Schedule tensor core computation + from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel + get_wmma_intrin_group, + ) + + intrin_group = get_wmma_intrin_group( + load_scope="shared.dyn", + store_scope="shared.dyn", + in_dtype=str(dtype_a), + out_dtype=str(dtype_c), + trans_b=True, + ) + + sch.tensorize(sch.get_loops(block_init_inner)[-2], intrin_group["init"]) + sch.tensorize(sch.get_loops(wmma_read_a)[-2], intrin_group["load_a"]) + sch.tensorize(sch.get_loops(wmma_read_b)[-2], intrin_group["load_b"]) + sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) + sch.tensorize(sch.get_loops(wmma_store)[-2], intrin_group["store"]) + + return sch + + +class MatmulInt8Tensorization(GPUScheduleRule): + """ + The schedule rule for int8 tensor core matmul computation. + func with attr 'dlight.do_not_tensorize' will not be tensorized. + """ + + def apply( # pylint: disable=too-many-locals,missing-docstring + self, + func: tir.PrimFunc, + target: Target, + _: bool, + ) -> Optional[tir.Schedule]: + from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel + get_wmma_intrin_group, + ) + + sch = tir.Schedule(func) + root_block = analysis.get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + + if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): + return None + + reduction_blocks = get_reduction_blocks(sch, blocks) + if reduction_blocks is None: + return None + + main_block = reduction_blocks[0] + block_stmt = sch.get(main_block) + index_maps = get_index_map(block_stmt) + if index_maps is None: + return None + matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps + + # Start Schedule + # Step 0. Get schedule config. + # NOTE: we can analyze the config by the hardware spec in the future + + # tensor core intrinsic size + micro_size_x = 16 + micro_size_y = 16 + micro_size_k = 16 + + warp_size = 32 + vector_size = 4 + + i_factors, j_factors, k_factors = ( + [None, 1, 4, 2], + [1, None, 4, 2], + [None, 1], + ) + + num_ty = i_factors[2] * j_factors[2] + x_pad_factor = i_factors[2] * i_factors[3] + y_pad_factor = j_factors[2] * j_factors[3] + k_pad_factor = k_factors[1] + + # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K] + block = sch.reindex(main_block, ("read", 0)) + sch.transform_layout(block, ("write", 0), a_index_map) + block = sch.reindex(main_block, ("read", 1)) + sch.transform_layout(block, ("write", 0), b_index_map) + block = sch.reindex(main_block, ("write", 0)) + sch.transform_layout(block, ("read", 0), c_index_map) + sch.transform_block_layout(main_block, matmul_index_map) + + # Step 2. Padding for dynamic shape kernels + sch.pad_einsum( + main_block, + [ + 1, + micro_size_x * x_pad_factor, + micro_size_y * y_pad_factor, + micro_size_k * k_pad_factor, + ], + ) + + # Step 3. Schedule matmul to use tensor core + block = main_block + + batch, i, j, k = sch.get_loops(block) + + # inner loops for tensor core computation + i, i_inner = sch.split(i, factors=[None, micro_size_x]) + j, j_inner = sch.split(j, factors=[None, micro_size_y]) + k, k_inner = sch.split(k, factors=[None, micro_size_k]) + + sch.reorder(i, j, k, i_inner, j_inner, k_inner) + + block_inner = block + block_outer = sch.blockize(i_inner) + + i0, i1, i2, i3 = sch.split(i, factors=i_factors) + j0, j1, j2, j3 = sch.split(j, factors=j_factors) + k0, k1 = sch.split(k, k_factors) + sch.annotate(k0, "software_pipeline_order", [0, 3, 1, 4, 5, 2, 6]) + sch.annotate(k0, "software_pipeline_stage", [0, 0, 0, 0, 0, 1, 1]) + sch.annotate(k1, "software_pipeline_order", [0, 1, 2]) + sch.annotate(k1, "software_pipeline_stage", [0, 0, 1]) + + sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3) + + block_idx = sch.fuse(i0, j0) + block_idy = sch.fuse(i1, j1) + thread_idy = sch.fuse(j2, i2) + sch.bind(batch, "blockIdx.z") + sch.bind(block_idx, "blockIdx.x") + sch.bind(block_idy, "blockIdx.y") + sch.bind(thread_idy, "threadIdx.y") + + def fetch_to_shared(block, idx, ndim): + block_read = sch.cache_read(block, idx, "shared.dyn") + sch.compute_at(block_read, k0) + fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) + + _, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, warp_size, vector_size]) + + sch.bind(f_2, "threadIdx.x") + sch.bind(f_1, "threadIdx.y") + sch.vectorize(f_3) + + sch.storage_align(block_read, 0, axis=-2, factor=32, offset=16) + sch.annotate(block_read, "tir.manifest_shared_memory_local_stage", 1) + sch.annotate(block_read, "double_buffer_scope", 0) + return block_read + + a_g2s = fetch_to_shared(block_outer, 0, 2) + b_g2s = fetch_to_shared(block_outer, 1, 2) + + auto_inline_producers(sch, a_g2s) + auto_inline_producers(sch, b_g2s) + + # create read cache to load matrix from shared memory to wmma fragments + A_mat = sch.cache_read(block_outer, 0, "wmma.matrix_a") + B_mat = sch.cache_read(block_outer, 1, "wmma.matrix_b") + sch.compute_at(A_mat, k1) + sch.compute_at(B_mat, k1) + + # create write cache to store matrix from wmma fragments to shared memory and global memory + accumulator_shared_to_global = sch.cache_write(block_outer, 0, "shared.dyn") + sch.storage_align(accumulator_shared_to_global, 0, -2, 16, 4) + + store = sch.cache_write(block_outer, 0, "wmma.accumulator") + sch.reverse_compute_at(store, thread_idy) + sch.reverse_compute_at(accumulator_shared_to_global, thread_idy) + + # split the store loop to match hardware intrinsic pattern + i, j = sch.get_loops(store)[-2:] + i0, i1 = sch.split(i, factors=[None, 16]) + j0, j1 = sch.split(j, factors=[None, 16]) + sch.reorder(i0, j0, i1, j1) + + block_init_c = sch.decompose_reduction(block_outer, k0) + block_init_c_inner = sch.get_child_blocks(block_init_c)[0] + + # Tensorization by hardware intrinsics + intrin_group = get_wmma_intrin_group( + load_scope="shared.dyn", + store_scope="shared.dyn", + in_dtype="int8", + out_dtype="int32", + trans_b=True, + ) + + try: + i, j = sch.get_loops(A_mat)[-2:] + i0, i1 = sch.split(i, factors=[None, 16]) + j0, j1 = sch.split(j, factors=[None, 16]) + sch.reorder(i0, j0, i1, j1) + sch.unroll(i0) + sch.unroll(j0) + sch.tensorize(i1, intrin_group["load_a"]) + + i, j = sch.get_loops(B_mat)[-2:] + i0, i1 = sch.split(i, factors=[None, 16]) + j0, j1 = sch.split(j, factors=[None, 16]) + sch.reorder(i0, j0, i1, j1) + sch.unroll(i0) + sch.unroll(j0) + sch.tensorize(i1, intrin_group["load_b"]) + except: # pylint: disable=bare-except + return None + + def tensorize_init_store_compute(): + sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"]) + sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"]) + sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) + + try: + tensorize_init_store_compute() + except: # pylint: disable=bare-except + return None + + auto_inline_consumer_chain(sch, accumulator_shared_to_global) + + fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-2:]) + _, f1, f2 = sch.split(fused, factors=[None, warp_size, vector_size]) + sch.bind(f1, "threadIdx.x") + sch.vectorize(f2) + + return sch + + +class MatmulTensorizationLegacy(GPUScheduleRule): + """ + The schedule rule for float16 tensor core matmul computation. + func with attr 'dlight.do_not_tensorize' will not be tensorized. + """ + + def apply( # pylint: disable=too-many-locals,missing-docstring + self, + func: tir.PrimFunc, + target: Target, + _: bool, + ) -> Optional[tir.Schedule]: + from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel + get_wmma_intrin_group, + ) + + sch = tir.Schedule(func) + root_block = analysis.get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + + if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): + return None + + reduction_blocks = get_reduction_blocks(sch, blocks) + if reduction_blocks is None: + return None + + main_block = reduction_blocks[0] + block_stmt = sch.get(main_block) + index_maps = get_index_map(block_stmt) + if index_maps is None: + return None + matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps + + # Start Schedule + # Step 0. Get schedule config. + # NOTE: we can analyze the config by the hardware spec in the future + + # tensor core intrinsic size + micro_size_x = 16 + micro_size_y = 16 + micro_size_k = 16 + + warp_size = 32 + vector_size = 4 + + i_factors, j_factors, k_factors = ( + [None, 1, 4, 2], + [1, None, 4, 2], + [None, 4], + ) + + num_ty = i_factors[2] * j_factors[2] + x_pad_factor = i_factors[2] * i_factors[3] + y_pad_factor = j_factors[2] * j_factors[3] + k_pad_factor = k_factors[1] + + # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K] + block = sch.reindex(main_block, ("read", 0)) + sch.transform_layout(block, ("write", 0), a_index_map) + block = sch.reindex(main_block, ("read", 1)) + sch.transform_layout(block, ("write", 0), b_index_map) + block = sch.reindex(main_block, ("write", 0)) + sch.transform_layout(block, ("read", 0), c_index_map) + sch.transform_block_layout(main_block, matmul_index_map) + + # Step 2. Padding for dynamic shape kernels + sch.pad_einsum( + main_block, + [ + 1, + micro_size_x * x_pad_factor, + micro_size_y * y_pad_factor, + micro_size_k * k_pad_factor, + ], + ) + + # Step 3. Schedule matmul to use tensor core + block = main_block + + batch, i, j, k = sch.get_loops(block) + + # inner loops for tensor core computation + i, i_inner = sch.split(i, factors=[None, micro_size_x]) + j, j_inner = sch.split(j, factors=[None, micro_size_y]) + k, k_inner = sch.split(k, factors=[None, micro_size_k]) + + sch.reorder(i, j, k, i_inner, j_inner, k_inner) + + block_inner = block + block_outer = sch.blockize(i_inner) + + i0, i1, i2, i3 = sch.split(i, factors=i_factors) + j0, j1, j2, j3 = sch.split(j, factors=j_factors) + k0, k1 = sch.split(k, k_factors) + sch.annotate(k0, "software_pipeline_order", [0, 3, 1, 4, 5, 2, 6]) + sch.annotate(k0, "software_pipeline_stage", [0, 0, 0, 0, 0, 1, 1]) + sch.annotate(k1, "software_pipeline_order", [0, 1, 2]) + sch.annotate(k1, "software_pipeline_stage", [0, 0, 1]) + + sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3) + + block_idx = sch.fuse(i0, j0) + block_idy = sch.fuse(i1, j1) + thread_idy = sch.fuse(j2, i2) + sch.bind(batch, "blockIdx.z") + sch.bind(block_idx, "blockIdx.x") + sch.bind(block_idy, "blockIdx.y") + sch.bind(thread_idy, "threadIdx.y") + + def fetch_to_shared(block, idx, ndim): + block_read = sch.cache_read(block, idx, "shared.dyn") + sch.compute_at(block_read, k0) + fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) + + _, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, warp_size, vector_size]) + + sch.bind(f_2, "threadIdx.x") + sch.bind(f_1, "threadIdx.y") + sch.vectorize(f_3) + + sch.storage_align(block_read, 0, axis=-2, factor=16, offset=8) + sch.annotate(block_read, "tir.manifest_shared_memory_local_stage", 1) + sch.annotate(block_read, "double_buffer_scope", 0) + return block_read + + a_g2s = fetch_to_shared(block_outer, 0, 2) + b_g2s = fetch_to_shared(block_outer, 1, 2) + + auto_inline_producers(sch, a_g2s) + auto_inline_producers(sch, b_g2s) + + # create read cache to load matrix from shared memory to wmma fragments + A_mat = sch.cache_read(block_outer, 0, "wmma.matrix_a") + B_mat = sch.cache_read(block_outer, 1, "wmma.matrix_b") + sch.compute_at(A_mat, k1) + sch.compute_at(B_mat, k1) + + # create write cache to store matrix from wmma fragments to shared memory and global memory + accumulator_shared_to_global = sch.cache_write(block_outer, 0, "shared.dyn") + sch.storage_align(accumulator_shared_to_global, 0, -2, 16, 4) + + store = sch.cache_write(block_outer, 0, "wmma.accumulator") + sch.reverse_compute_at(store, thread_idy) + sch.reverse_compute_at(accumulator_shared_to_global, thread_idy) + + # split the store loop to match hardware intrinsic pattern + i, j = sch.get_loops(store)[-2:] + i0, i1 = sch.split(i, factors=[None, 16]) + j0, j1 = sch.split(j, factors=[None, 16]) + sch.reorder(i0, j0, i1, j1) + + block_init_c = sch.decompose_reduction(block_outer, k0) + block_init_c_inner = sch.get_child_blocks(block_init_c)[0] + + # Tensorization by hardware intrinsics + intrin_group = get_wmma_intrin_group( + load_scope="shared.dyn", + store_scope="shared.dyn", + in_dtype="float16", + out_dtype="float32", + trans_b=True, + ) + + try: + i, j = sch.get_loops(A_mat)[-2:] + i0, i1 = sch.split(i, factors=[None, 16]) + j0, j1 = sch.split(j, factors=[None, 16]) + sch.reorder(i0, j0, i1, j1) + sch.unroll(i0) + sch.unroll(j0) + sch.tensorize(i1, intrin_group["load_a"]) + + i, j = sch.get_loops(B_mat)[-2:] + i0, i1 = sch.split(i, factors=[None, 16]) + j0, j1 = sch.split(j, factors=[None, 16]) + sch.reorder(i0, j0, i1, j1) + sch.unroll(i0) + sch.unroll(j0) + sch.tensorize(i1, intrin_group["load_b"]) + except: # pylint: disable=bare-except + return None + + # Try to tensorize the init, store and compute block with f16 or f32 intrinsics + tensorize_success: bool = False + + def tensorize_init_store_compute(): + sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"]) + sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"]) + sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) + + try: + tensorize_init_store_compute() + tensorize_success = True + except: # pylint: disable=bare-except + intrin_group = get_wmma_intrin_group( + load_scope="shared.dyn", + store_scope="shared.dyn", + in_dtype="float16", + out_dtype="float16", + trans_b=True, + ) + + if not tensorize_success: + try: + tensorize_init_store_compute() + tensorize_success = True + except: # pylint: disable=bare-except + return None + auto_inline_consumer_chain(sch, accumulator_shared_to_global) + + fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-2:]) + _, f1, f2 = sch.split(fused, factors=[None, warp_size, vector_size]) + sch.bind(f1, "threadIdx.x") + sch.vectorize(f2) + + return sch if tensorize_success else None + + def apply_config( # pylint: disable=too-many-locals,missing-docstring + self, + func: tir.PrimFunc, + config, + ) -> Optional[tir.Schedule]: + from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel + get_wmma_intrin_group, + ) + + sch = tir.Schedule(func) + root_block = analysis.get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + + if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): + return None + + reduction_blocks = get_reduction_blocks(sch, blocks) + if reduction_blocks is None: + return None + + main_block = reduction_blocks[0] + + # Start Schedule + # Step 0. Get schedule config. + # NOTE: we can analyze the config by the hardware spec in the future + + # tensor core intrinsic size + intrin_info = config.intrin_info + warp_row_tiles = config.warp[0] + warp_col_tiles = config.warp[1] + block_row_warps = config.block[0] // warp_row_tiles + block_col_warps = config.block[1] // warp_col_tiles + stage = config.pipeline_stage + use_async = config.use_async + chunk = config.rstep[0] + + micro_size_x = 16 + micro_size_y = 16 + micro_size_k = 16 + + warp_size = 32 + + i_factors, j_factors, k_factors = ( + [None, 1, block_row_warps, warp_row_tiles // micro_size_x], + [1, None, block_col_warps, warp_col_tiles // micro_size_y], + [None, chunk // micro_size_k], + ) + + num_ty = i_factors[2] * j_factors[2] + x_pad_factor = i_factors[2] * i_factors[3] + y_pad_factor = j_factors[2] * j_factors[3] + k_pad_factor = k_factors[1] + + # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]/B[S, K, J] + if not (func.attrs is not None and "dlight.tensorcore_prenormlized" in func.attrs.keys()): + sch = normalize_to_matmul(sch, main_block, ["a", "a", "a"]) + + # Step 2. Padding for dynamic shape kernels + sch.pad_einsum( + main_block, + [ + 1, + micro_size_x * x_pad_factor, + micro_size_y * y_pad_factor, + micro_size_k * k_pad_factor, + ], + ) + + # Step 3. Schedule matmul to use tensor core + block = main_block + + batch, i, j, k = sch.get_loops(block) + + # inner loops for tensor core computation + i, i_inner = sch.split(i, factors=[None, micro_size_x]) + j, j_inner = sch.split(j, factors=[None, micro_size_y]) + k, k_inner = sch.split(k, factors=[None, micro_size_k]) + + sch.reorder(i, j, k, i_inner, j_inner, k_inner) + + block_inner = block + block_outer = sch.blockize(i_inner) + + i0, i1, i2, i3 = sch.split(i, factors=i_factors) + j0, j1, j2, j3 = sch.split(j, factors=j_factors) + k0, k1 = sch.split(k, k_factors) + + sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3) + + block_idx = sch.fuse(i0, j0) + block_idy = sch.fuse(i1, j1) + thread_idy = sch.fuse(j2, i2) + # plan rasteration + if ( + not isinstance(config.rasterization_plan, NoRasterization) + and sch.get(batch).extent.value == 1 + ): + device_func, invoke_func = config.rasterization_plan.get_code() + factor = config.rasterization_plan.panel_width_ + + # TODO(lei): this is a trick for rasterization implementation + # wait for https://github.com/apache/tvm/pull/16113 to be merged + # require a solution for general block rasterization + factor = 8 # should be divisible by block_idy + if sch.get(block_idy).extent.value % factor == 0: + block_k, block_idy = sch.split(block_idy, factors=[None, factor]) + sch.bind(block_k, "blockIdx.z") + else: + sch.bind(batch, "blockIdx.z") + + sch.bind(block_idx, "blockIdx.x") + sch.bind(block_idy, "blockIdx.y") + sch.bind(thread_idy, "threadIdx.y") + + def fetch_to_shared(block, idx, ndim, vec_len, dtype="float16"): + block_read = sch.cache_read(block, idx, "shared.dyn") + sch.compute_at(block_read, k0) + fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) + + _, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, warp_size, vec_len]) + + sch.bind(f_2, "threadIdx.x") + sch.bind(f_1, "threadIdx.y") + sch.vectorize(f_3) + offset: int = 0 + if dtype == "float16": + offset = 8 + elif dtype == "int8": + offset = 16 + # todo(lei): the pad value should be varied according to the data type + sch.storage_align(block_read, 0, axis=-2, factor=16, offset=offset) + return block_read + + a_g2s = fetch_to_shared( + block_outer, + 0, + 2, + vec_len=list(config.vectorize.values())[0], + dtype=intrin_info.in_dtype, + ) + b_g2s = fetch_to_shared( + block_outer, + 1, + 2, + vec_len=list(config.vectorize.values())[1], + dtype=intrin_info.in_dtype, + ) + + auto_inline_producers(sch, a_g2s) + auto_inline_producers(sch, b_g2s) + + # create read cache to load matrix from shared memory to wmma fragments + A_mat = sch.cache_read(block_outer, 0, "wmma.matrix_a") + B_mat = sch.cache_read(block_outer, 1, "wmma.matrix_b") + sch.compute_at(A_mat, k1) + sch.compute_at(B_mat, k1) + + # create write cache to store matrix from wmma fragments to shared memory and global memory + accumulator_shared_to_global = sch.cache_write(block_outer, 0, "shared.dyn") + sch.storage_align(accumulator_shared_to_global, 0, -2, 16, 4) + + store = sch.cache_write(block_outer, 0, "wmma.accumulator") + sch.reverse_compute_at(store, thread_idy) + sch.reverse_compute_at(accumulator_shared_to_global, thread_idy) + + # split the store loop to match hardware intrinsic pattern + i, j = sch.get_loops(store)[-2:] + i0, i1 = sch.split(i, factors=[None, 16]) + j0, j1 = sch.split(j, factors=[None, 16]) + sch.reorder(i0, j0, i1, j1) + + block_init_c = sch.decompose_reduction(block_outer, k0) + block_init_c_inner = sch.get_child_blocks(block_init_c)[0] + + # Tensorization by hardware intrinsics + intrin_group = get_wmma_intrin_group( + load_scope="shared.dyn", + store_scope="shared.dyn", + in_dtype=intrin_info.in_dtype, + out_dtype=intrin_info.out_dtype, + trans_b=intrin_info.trans_b, + ) + + try: + i, j = sch.get_loops(A_mat)[-2:] + i0, i1 = sch.split(i, factors=[None, 16]) + j0, j1 = sch.split(j, factors=[None, 16]) + sch.reorder(i0, j0, i1, j1) + sch.unroll(i0) + sch.unroll(j0) + sch.tensorize(i1, intrin_group["load_a"]) + + i, j = sch.get_loops(B_mat)[-2:] + i0, i1 = sch.split(i, factors=[None, 16]) + j0, j1 = sch.split(j, factors=[None, 16]) + sch.reorder(i0, j0, i1, j1) + sch.unroll(i0) + sch.unroll(j0) + sch.tensorize(i1, intrin_group["load_b"]) + except: # pylint: disable=bare-except + return None + + # Try to tensorize the init, store and compute block with f16 or f32 intrinsics + tensorize_success: bool = False + + def tensorize_init_store_compute(): + sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"]) + sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"]) + sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) + + try: + tensorize_init_store_compute() + tensorize_success = True + except: # pylint: disable=bare-except + return None + + auto_inline_consumer_chain(sch, accumulator_shared_to_global) + + fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-2:]) + _, f1, f2 = sch.split( + fused, factors=[None, warp_size, max(list(config.vectorize.values()))] + ) + sch.bind(f1, "threadIdx.x") + sch.vectorize(f2) + + if stage > 1: + sch.annotate(k0, ann_key="software_pipeline_stage", ann_val=[0, 0, stage - 1]) + sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2]) + if use_async: + sch.annotate(k0, "software_pipeline_async_stages", [0]) + + return sch if tensorize_success else None + diff --git a/python/bitblas/gpu/reduction.py b/python/bitblas/gpu/reduction.py new file mode 100644 index 000000000000..9d6aada75985 --- /dev/null +++ b/python/bitblas/gpu/reduction.py @@ -0,0 +1,301 @@ +# Copyright 2018 The apache/tvm Authors. All Rights Reserved. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# Modifications Copyright (c) Microsoft. +# The code below is mostly copied from apache/tvm reduction.py in dlight. +"""A rule for reduction. """ +from typing import List, Optional, Tuple, Union + +from tvm import arith, ir, tir +from tvm.target import Target + +from ..base import ( + BlockInfo, + normalize_prim_func, + try_inline_contiguous_spatial, + detect_dominant_read, + is_broadcast_epilogue, +) +from . import utils +from .base import GPUScheduleRule + + +def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]: + # Detect and return `Y` in `X[...] = X[...] + Y` + buffer_store = block.body + if not isinstance(buffer_store, tir.BufferStore): + return None + if not isinstance(buffer_store.value, tir.Add): + return None + if not ir.structural_equal( + buffer_store.value.a, + tir.BufferLoad(buffer_store.buffer, block.body.indices), + map_free_vars=True, + ): + return None + return buffer_store.value.b + + +class Reduction(GPUScheduleRule): + """A rule for Reduction.""" + + def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements + self, + func: tir.PrimFunc, + target: Target, + _: bool, + ) -> Union[None, tir.Schedule, List[tir.Schedule]]: + if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): + return None + sch = tir.Schedule(func) + block_infos = normalize_prim_func(sch) + if block_infos is None: + return None + block_infos = try_inline_contiguous_spatial(sch, block_infos) + if len(block_infos) == 1: + epilogue = None + elif len(block_infos) == 2: + epilogue = block_infos[1] + if not epilogue.is_injective(): + return None + else: + return None + + block_info = block_infos[0] + block = block_info.block_rv + block_stmt = sch.get(block) + + # Step 1. Check reduction block + if ( + (not block_info.is_reduction()) + or len(block_stmt.writes) != 1 + or _get_reduction_expr(block_stmt) is None + ): + return None + # Step 2. Normalize the block, merge spatial and reduction iters + is_inner_reduction, c_factor, loop_order, s_split_index = self._normalize( + sch, + block_info, + arith.normalize_to_iter_sum( + detect_dominant_read(block_stmt), + input_iters={i.var: i.dom for i in block_stmt.iter_vars}, + ), + ) + if is_inner_reduction is None and c_factor is None: + return None + # Step 3. Do the scheduling + if is_inner_reduction: + self._sch_inner_reduction( + sch, target, block, c_factor, epilogue, loop_order, s_split_index + ) + else: + self._sch_inner_spatial( + sch, target, block, block_info, c_factor, epilogue, loop_order, s_split_index + ) + return sch + + def _normalize( # pylint: disable=too-many-branches + self, + sch: tir.Schedule, + block_info: BlockInfo, + access: arith.IterSumExpr, + ) -> Tuple[Optional[bool], Optional[int]]: + if access.base != 0: + return None, None, None, None + iter_to_info = {i.var: i for i in block_info.iters} + s_loops, r_loops, c_loops, c_factor = [], [], [], None + s_split_loop, s_split_index = None, None + for split_expr in access.args: + var = split_expr.source.source + info = iter_to_info.pop(var) + loop = info.loop_rv + is_inner_reduction = info.kind == "R" + if split_expr.lower_factor > 1: + if c_loops: + return None, None, None, None + s_split_loop = loop + s_split_index = len(s_loops) + loop, c_loop = sch.split(loop, factors=[None, split_expr.lower_factor]) + c_loops.append(c_loop) + if not is_inner_reduction: + c_factor = split_expr.lower_factor + if is_inner_reduction: + r_loops.append(loop) + else: + s_loops.append(loop) + + if iter_to_info: + for var, info in iter_to_info.items(): + if info.kind == "S" and info.dom.extent == 1: + s_loops.append(info.loop_rv) + else: + return None, None, None, None + + loop_order = {} + s_block_var_loops = [] + for i in block_info.iters: + if i.loop_rv in s_loops or i.loop_rv == s_split_loop: + s_block_var_loops.append(i.loop_rv) + + for i in range(len(s_block_var_loops)): + for j in range(len(s_loops)): + if s_block_var_loops[i] == s_loops[j]: + loop_order[i] = j + break + if s_block_var_loops[i] == s_split_loop: + loop_order[i] = s_split_index + break + + assert s_loops + assert r_loops + if len(s_loops) != len([i for i in block_info.iters if i.kind == "S"]): + return None, None + if not c_loops: + c_loops = [sch.add_unit_loop(block_info.block_rv)] + sch.reorder(*s_loops, *r_loops, *c_loops) + sch.fuse(*s_loops) + sch.fuse(*r_loops) + return is_inner_reduction, c_factor, loop_order, s_split_index + + def _sch_inner_reduction( # pylint: disable=too-many-arguments + self, + sch: tir.Schedule, + target: Target, + block: tir.schedule.BlockRV, + unroll_spatial_factor: Optional[int], + epilogue_info: Optional[BlockInfo], + loop_order, + s_split_index, + ): + # pylint: disable=invalid-name + _, r, _ = sch.get_loops(block) + (len_tx,) = utils.suggest_threads_per_block( # pylint: disable=unbalanced-tuple-unpacking + target, [sch.get(r)] + ) + + _, tx = sch.split(r, factors=[None, len_tx]) + # Schedule the RF block + rf = sch.rfactor(tx, 0) + bx, r, tx, _ = sch.get_loops(rf) + sch.reorder(bx, tx, r) + sch.bind(bx, "blockIdx.x") + sch.bind(tx, "threadIdx.x") + sch.annotate(tx, ann_key="pragma_auto_unroll_max_step", ann_val=256) + sch.annotate(tx, ann_key="pragma_unroll_explicit", ann_val=1) + sch.set_scope(rf, 0, "local") + sch.decompose_reduction(rf, r) + # Schedule the write back block + sch.reverse_compute_at(block, bx, preserve_unit_loops=True) + _, tx, *s = sch.get_loops(block) + + if unroll_spatial_factor: + assert len(s) == len(loop_order) + new_order_s = [s[loop_order[i]] for i in range(len(s))] + sch.reorder(*new_order_s) + new_order_s[s_split_index], c = sch.split( + new_order_s[s_split_index], factors=[None, unroll_spatial_factor] + ) + sch.reorder(*new_order_s, c) + s = sch.fuse(*new_order_s) + sch.reorder(s, tx, c) + else: + s = sch.fuse(*s) + sch.reorder(s, tx) + sch.bind(tx, "threadIdx.x") + # Schedule epilogue + if epilogue_info is not None: + epilogue = epilogue_info.block_rv + sch.reverse_compute_at(epilogue, bx) + if is_broadcast_epilogue(sch, block, epilogue): + sch.set_scope(block, 0, "shared") + _, *s = sch.get_loops(epilogue) # pylint: disable=invalid-name + _, tx = sch.split(sch.fuse(*s), factors=[None, len_tx]) + sch.bind(tx, "threadIdx.x") + else: + sch.set_scope(block, 0, "local") + # pylint: enable=invalid-name + + def _sch_inner_spatial( + self, + sch: tir.Schedule, + _: Target, + block: tir.schedule.BlockRV, + block_info: BlockInfo, + unroll_spatial_factor: Optional[int], + epilogue_info: Optional[BlockInfo], + loop_order, + s_split_index, + ): + # pylint: disable=invalid-name + s, r, _ = sch.get_loops(block) + len_tx, len_ty = 16, 16 + s_factor = [i.dom.extent for i in block_info.iters if i.kind == "S"][-1] + # get perfect spatial factor, spatial factor should be divide the innermost spatial loop so + # that the block after r_factor and be reversed compute at the original scope + while len_tx > 1: + if s_factor % len_tx == 0: + break + len_tx -= 1 + _, _ = sch.split(s, factors=[None, len_tx]) + _, ty = sch.split(r, factors=[None, len_ty]) + # Schedule the RF block + rf = sch.rfactor(ty, 0) + bx, tx, r, ty, _ = sch.get_loops(rf) + sch.reorder(bx, tx, ty, r) + sch.bind(tx, "threadIdx.x") + sch.bind(ty, "threadIdx.y") + sch.bind(bx, "blockIdx.x") + sch.set_scope(rf, 0, "local") + sch.decompose_reduction(rf, r) + # Schedule the write back block + sch.reverse_compute_at(block, bx, preserve_unit_loops=True) + _, r, *s = sch.get_loops(block) + if unroll_spatial_factor: + assert len(s) == len(loop_order) + new_order_s = [s[loop_order[i]] for i in range(len(s))] + sch.reorder(*new_order_s) + new_order_s[s_split_index], c = sch.split( + new_order_s[s_split_index], factors=[None, unroll_spatial_factor] + ) + sch.reorder(*new_order_s, c) + s = sch.fuse(*new_order_s) + sch.reorder(s, c, r) + else: + s = sch.fuse(*s) + sch.reorder(s, r) + sch.bind(s, "threadIdx.x") + sch.bind(r, "threadIdx.y") + + # Schedule epilogue + if epilogue_info is not None: + epilogue = epilogue_info.block_rv + sch.reverse_compute_at(epilogue, bx) + if is_broadcast_epilogue(sch, block, epilogue): + sch.set_scope(block, 0, "shared") + _, *s = sch.get_loops(epilogue) # pylint: disable=invalid-name + _, tx, ty = sch.split(sch.fuse(*s), factors=[None, len_tx, len_ty]) + sch.bind(tx, "threadIdx.x") + sch.bind(ty, "threadIdx.y") + else: + # The epilogue is element-wise without broadcasting. + # Thus the remaining spatial part should be bind to tx. + sch.set_scope(block, 0, "local") + _, *s = sch.get_loops(epilogue) # pylint: disable=invalid-name + tx, _ = sch.split(sch.fuse(*s), factors=[len_tx, None]) + sch.bind(tx, "threadIdx.x") + # pylint: enable=invalid-name diff --git a/python/bitblas/gpu/rmsnorm.py b/python/bitblas/gpu/rmsnorm.py new file mode 100644 index 000000000000..6e6d3e247905 --- /dev/null +++ b/python/bitblas/gpu/rmsnorm.py @@ -0,0 +1,144 @@ +# Copyright 2018 The apache/tvm Authors. All Rights Reserved. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# Modifications Copyright (c) Microsoft. +# The code below is mostly copied from apache/tvm rmsnorm.py in dlight. +# pylint: disable=missing-docstring +"""A RMS norm schedule rule for GPU operators.""" + +import tvm +from tvm import tir +from tvm.tir import Block, BufferStore +from tvm.tir.expr import Cast, BufferLoad, Call +from tvm.target import Target + +from ..base import ScheduleRule + + +def identify_cast_or_load_block(block: Block) -> bool: + if len(block.reads) != 1 or len(block.writes) != 1: + return False + + if not isinstance(block.body, BufferStore): + return False + store = block.body + + # check types + if isinstance(store.value, BufferLoad): + load = store.value + elif isinstance(store.value, Cast): + load = store.value.value + if not isinstance(load, BufferLoad): + return False + else: + return False + + # check indices + if len(load.indices) != len(store.indices): + return False + + for lhs, rhs in zip(load.indices, store.indices): + if not lhs.same_as(rhs): + return False + + return True + + +def identify_rsqrt_block(block: Block) -> bool: + if len(block.reads) != 1 or len(block.writes) != 1: + return False + + if not isinstance(block.body, BufferStore): + return False + store = block.body + + if not isinstance(store.value, Call): + return False + call = store.value + op = call.op + + return op == tvm.ir.op.Op.get("tir.rsqrt") + + +class RMSNorm(ScheduleRule): + """A rule for RMS norm.""" + + def apply( # pylint: disable=too-many-locals,missing-docstring + self, + func: tir.PrimFunc, + target: Target, + _: bool, + ) -> tir.Schedule: + if target.kind.name == "cuda": + num_tx = 512 + else: + num_tx = 64 + + sch = tir.Schedule(func) + root = sch.get_block(name="root", func_name="main") + + blocks = sch.get_child_blocks(root) + + if not any([identify_rsqrt_block(sch.get(block)) for block in blocks]): + return None + + read = sch.cache_read(block=blocks[0], read_buffer_index=0, storage_scope="local") + write = sch.cache_write(block=blocks[-1], write_buffer_index=0, storage_scope="local") + + for block in blocks: + if identify_cast_or_load_block(sch.get(block)): + sch.compute_inline(block) + + blocks = sch.get_child_blocks(root) + + read, sqr, redsum, rsqrt, norm, write = blocks + + if not identify_rsqrt_block(sch.get(rsqrt)): + return None + + for name in [read, sqr, redsum, rsqrt, norm, write]: + loops = sch.get_loops(name) + sch.fuse(*loops[:-1]) + + block_loop, loops = sch.get_loops(block=read) + thread_loop, _, _ = sch.split( + loop=loops, factors=[num_tx, None, 8], preserve_unit_iters=True + ) + sch.bind(block_loop, thread_axis="blockIdx.x") + sch.bind(thread_loop, thread_axis="threadIdx.x") + sch.vectorize(sch.get_loops(block=read)[-1]) + sch.reverse_compute_at(block=sqr, loop=thread_loop) + sch.reverse_compute_at(block=redsum, loop=thread_loop) + + sch.reverse_compute_at(block=rsqrt, loop=block_loop, index=-1) + sch.reverse_compute_at(block=norm, loop=block_loop, index=-1) + block_loop, loops = sch.get_loops(block=norm) + thread_loop, _, _ = sch.split( + loop=loops, factors=[num_tx, None, 8], preserve_unit_iters=True + ) + sch.bind(thread_loop, thread_axis="threadIdx.x") + + sch.reverse_compute_at(block=write, loop=thread_loop, index=-1) + sch.vectorize(sch.get_loops(block=write)[-1]) + + sch.set_scope(block=sqr, buffer_index=0, storage_scope="local") + sch.set_scope(block=redsum, buffer_index=0, storage_scope="local") + sch.set_scope(block=rsqrt, buffer_index=0, storage_scope="shared") + sch.set_scope(block=norm, buffer_index=0, storage_scope="local") + + return sch diff --git a/python/bitblas/gpu/transpose.py b/python/bitblas/gpu/transpose.py new file mode 100644 index 000000000000..6dc025c07c9d --- /dev/null +++ b/python/bitblas/gpu/transpose.py @@ -0,0 +1,133 @@ +# Copyright 2018 The apache/tvm Authors. All Rights Reserved. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# Modifications Copyright (c) Microsoft. +# The code below is mostly copied from apache/tvm transpose.py in dlight. +"""Reduction rule for operators including softmax, layer norm, RMS norm, etc""" +from typing import List, Union + +from tvm import arith, tir +from tvm.target import Target +from tvm.tir import Schedule +from tvm.tir.schedule import BlockRV + +from ..base import ( + detect_dominant_read, + normalize_prim_func, + try_inline_contiguous_spatial, +) +from .base import GPUScheduleRule + + +class Transpose(GPUScheduleRule): + """Schedule rule for transpose""" + + def is_transpose(self, sch: Schedule, block_rv: BlockRV): + block = sch.get(block_rv) + if isinstance(block.body, tir.BufferStore): + rhs = block.body.value + if isinstance(rhs, tir.BufferLoad): + lhs_indices = block.body.indices + rhs_indices = rhs.indices + if list(lhs_indices) != list(rhs_indices) and set(lhs_indices) == set(rhs_indices): + return True + return False + + def apply( # pylint: disable=too-many-locals + self, + func: tir.PrimFunc, + target: Target, + _: bool, + ) -> Union[None, tir.Schedule, List[tir.Schedule]]: + # pylint: disable=invalid-name + if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): + return None + if target.kind.name == "cuda": + len_tx = 16 + len_ty = 8 + unroll_depth = 256 + else: + len_tx = 8 + len_ty = 4 + unroll_depth = 64 + len_vec = 4 + + sch = tir.Schedule(func) + blocks = normalize_prim_func(sch) + transpose_block_idx = -1 + for idx, block in reversed(list(enumerate(blocks))): + if self.is_transpose(sch, block.block_rv): + transpose_block_idx = idx + break + if not block.is_injective(): + return None + if transpose_block_idx == -1: + return None + transpose_block = blocks[transpose_block_idx].block_rv + + prologue = None # the optional decoding block + if transpose_block_idx > 0: + spatials = try_inline_contiguous_spatial(sch, blocks[: transpose_block_idx - 1]) + assert len(spatials) == 0 + prologue = blocks[transpose_block_idx - 1].block_rv + + loops = sch.get_loops(transpose_block) + if len(loops) != 2: + # transpose with more than 2 axes is not supported + return None + + c_factor = 1 + if prologue is not None: + block_stmt = sch.get(prologue) + result = arith.normalize_to_iter_sum( + detect_dominant_read(block_stmt), + input_iters={i.var: i.dom.extent for i in block_stmt.iter_vars}, + ) + if len(result.args) > 0: + c_factor = int(result.args[0].lower_factor) + + i, j = loops + i, vi = sch.split(i, factors=[None, c_factor], preserve_unit_iters=True) + bi, ti = sch.split(i, factors=[None, len_ty], preserve_unit_iters=True) + bj, tj = sch.split(j, factors=[None, len_tx], preserve_unit_iters=True) + sch.reorder(bi, bj, ti, tj, vi) + sch.bind(bi, "blockIdx.y") + sch.bind(bj, "blockIdx.x") + sch.bind(ti, "threadIdx.y") + sch.bind(tj, "threadIdx.x") + len_vec = min(len_vec, c_factor) + _, vi = sch.split(vi, factors=[None, len_vec]) + if len_vec > 1: + sch.vectorize(vi) + + cache_read = sch.cache_read(transpose_block, read_buffer_index=0, storage_scope="shared") + sch.compute_at(cache_read, bj) + loops = sch.get_loops(cache_read)[2:] + fused = sch.fuse(*loops) + _, ty, tx, v = sch.split(fused, factors=[None, len_ty, len_tx, c_factor]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + sch.unroll(v) + sch.storage_align(block=cache_read, buffer_index=0, axis=0, factor=32, offset=1) + + sch.annotate(bi, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth) + sch.annotate(bi, ann_key="pragma_unroll_explicit", ann_val=1) + + if prologue is not None: + sch.compute_inline(prologue) + return sch diff --git a/python/bitblas/gpu/utils.py b/python/bitblas/gpu/utils.py new file mode 100644 index 000000000000..e3a5b6098fad --- /dev/null +++ b/python/bitblas/gpu/utils.py @@ -0,0 +1,86 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# pylint: disable=missing-docstring +"""Utility methods for generic GPU.""" +from typing import List, Optional + +from tvm import tir +from tvm.target import Target + + +def max_threads_per_block(target: Target) -> int: + """Get the maximum number of threads per block for a given target. + + Parameters + ---------- + target : Target + The target to get the maximum number of threads per block for. + + Returns + ------- + max_threads_per_block : int + The maximum number of threads per block for the given target. + """ + for name in ["max_threads_per_block", "max_num_threads"]: + result = target.attrs.get(name, None) + if result is not None: + return result + if target.kind.name == "cuda": + return 1024 + return 256 + + +def suggest_threads_per_block( + target: Target, + loops: List[tir.For], + max_threads_for_dynamic_loop: int = 32, +) -> List[int]: + if target.kind.name == "cuda": + threads = 1024 + elif target.kind.name == "rocm": + threads = 256 + elif target.kind.name == "metal": + threads = 256 + else: + threads = 64 + results: List[Optional[int]] = [] + dynamic: List[int] = [] + for i, loop in enumerate(loops): + loop_extent = loop.extent + if isinstance(loop_extent, tir.IntImm): + loop_extent = loop_extent.value + extent = 1 + while extent <= loop_extent and extent <= threads: + extent *= 2 + extent //= 2 + assert extent >= 1 + assert threads % extent == 0 + threads //= extent + results.append(extent) + else: + results.append(None) + dynamic.append(i) + + for i in dynamic: + extent = 1 + while extent <= max_threads_for_dynamic_loop and extent <= threads: + extent *= 2 + extent //= 2 + assert extent >= 1 + assert threads % extent == 0 + threads //= extent + results[i] = extent + + if dynamic: + results[dynamic[0]] *= threads + + return results + + +def get_sm_version(target: Target) -> int: + if target.kind.name != "cuda": + return -1 + arch = target.arch + sm_version = arch.replace("sm_", "") + return int(sm_version) if sm_version.isdigit() else -1 diff --git a/python/bitblas/ops/__init__.py b/python/bitblas/ops/__init__.py new file mode 100644 index 000000000000..08fd3d5d8c25 --- /dev/null +++ b/python/bitblas/ops/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from .operator import Operator +from .matmul import Matmul, MatmulConfig +from .ladder_permutate import LadderPermutate, LadderPermutateConfig +from .lop3_permutate import LOP3Permutate, LOP3PermutateConfig diff --git a/python/bitblas/ops/gemv_impl.py b/python/bitblas/ops/gemv_impl.py new file mode 100644 index 000000000000..7a3aac445f56 --- /dev/null +++ b/python/bitblas/ops/gemv_impl.py @@ -0,0 +1,73 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# pre-transformed tir expression of gemv +import tvm +from tvm.script import tir as T +from tvm import te + + +def gemv_i4(M, N, K, dtype="float16"): + bit = 4 + n_float_per_i8 = 8 // bit + + def _tir_u8_to_int_to_float( + nbit: int, val: tvm.tir.PrimExpr, pos: tvm.tir.PrimExpr, dtype: str + ): + assert val.dtype == "int8" + mask = tvm.tir.const((1 << nbit) - 1, "int8") + return ((val >> (pos * nbit).astype("int8")) & mask).astype(dtype) + + A = te.placeholder((M, K), name="A", dtype=dtype) + B = te.placeholder((N, K // 8 * bit), name="B", dtype="int8") + + def decode_func(n, k): + w = _tir_u8_to_int_to_float( + bit, B[n, k // n_float_per_i8], k % n_float_per_i8, dtype=dtype + ) + return w + + B_decode = te.compute((N, K), decode_func, name="B_decode") + + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, K), name="k") + C = te.compute( + (M, N), lambda i, j: te.sum(A[i, k] * B_decode[j, k], axis=k), name="C" + ) + func = te.create_prim_func([A, B, C]).with_attr( + "dequantize_info", + { + "B": { + "decode_block": "B_decode", + "fast_decoding": True, + "source_format": { + "bits": 4, + "format": "int", + }, + "target_format": { + "bits": 16, + "format": "float", + }, + } + }, + ) + return tvm.IRModule.from_expr(func) + + +def gemv(M, N, K, dtype="float16"): + @tvm.script.ir_module + class GEMV: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = T.match_buffer(a, [M, K], dtype=dtype) + B = T.match_buffer(b, [N, K], dtype=dtype) + C = T.match_buffer(c, [M, N], dtype=dtype) + + for i, j, k in T.grid(M, N, K): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + return GEMV diff --git a/python/bitblas/ops/impl/__init__.py b/python/bitblas/ops/impl/__init__.py new file mode 100644 index 000000000000..a254dc7fb2ad --- /dev/null +++ b/python/bitblas/ops/impl/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from .lop3_permutate_impl import tir_interleave_weight diff --git a/python/bitblas/ops/impl/ladder_permutate_impl.py b/python/bitblas/ops/impl/ladder_permutate_impl.py new file mode 100644 index 000000000000..5a09fb09b39e --- /dev/null +++ b/python/bitblas/ops/impl/ladder_permutate_impl.py @@ -0,0 +1,89 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from bitblas.gpu.matmul_analysis import get_propagate_map +from typing import Literal +from tvm import te, IRModule, DataType +from tvm.tir import IndexMap + + +def select_implementation( + M: int, + N: int, + datatype: Literal["float16", "int8"] = "float16", + dequantize_bits: int = -1, + storage_dtype: Literal["float16", "int8", "uint8", "int32", "uint32"] = "float16", + propagate_kind: Literal["A", "B"] = "B", + transpose_matrix: bool = False, + transform_kind: int = 0, + target_instruction: Literal["nvidia-mma"] = "nvidia-mma", +): + if target_instruction != "nvidia-mma": + raise ValueError("Currently only support nvidia-mma instruction") + + # This is trick to get the basic tile size for the current datatype + # as for nvidia tensorcore instruction, the basic tile size is 16x16/16x32 for float16/int8 + l = r = 16 + if datatype == "int8": + l, r = 16, 32 + + intra_index_map, _ = get_propagate_map( + transpose_matrix, dtype=datatype, matrix_name=propagate_kind + ) + + target_dtype = DataType(datatype) + scaling_factor = 1 + if dequantize_bits > 0 and dequantize_bits < target_dtype.bits: + scaling_factor = ( + (target_dtype.bits // dequantize_bits) + * DataType(storage_dtype).bits + // target_dtype.bits + ) + r = r // scaling_factor + initial_indices = intra_index_map.initial_indices + scaling_final_indices = intra_index_map.map_indices( + initial_indices[:-1] + [initial_indices[-1] * scaling_factor] + ) + scaling_final_indices = scaling_final_indices[:-1] + [ + scaling_final_indices[-1] // scaling_factor + ] + intra_index_map = IndexMap( + initial_indices, + scaling_final_indices, + None, + ) + + inp = te.placeholder((M, N // scaling_factor), name="inp", dtype=storage_dtype) + args = [inp] + + if transform_kind >= 1: + arg = args[-1] + + inter_warp = te.compute( + (M // l, (N // scaling_factor) // r, l, r), + lambda i, j, ii, jj: arg[i * l + ii, j * r + jj], + name="inter_warp_permutate", + ) + args.append(inter_warp) + if transform_kind >= 2: + # tir required inverse layout transform. + arg = args[-1] + intra_index_map = intra_index_map.inverse([l, r]) + + def fcompute(*args): + warp_i, warp_j = args[-2:] + spatial_args = args[:-2] + permutate_i, permutate_j = intra_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, permutate_i, permutate_j) + return arg[new_index] + + intra_warp = te.compute( + (M // l, (N // scaling_factor) // r, l, r), + fcompute, + name="intra_warp_permutate", + ) + args.append(intra_warp) + args = [args[0], args[-1]] + + func = te.create_prim_func(args) + + return IRModule.from_expr(func) diff --git a/python/bitblas/ops/impl/lop3_permutate_impl.py b/python/bitblas/ops/impl/lop3_permutate_impl.py new file mode 100644 index 000000000000..aa77793132d4 --- /dev/null +++ b/python/bitblas/ops/impl/lop3_permutate_impl.py @@ -0,0 +1,171 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from typing import Literal +from tvm import IRModule +from tvm.ir import GlobalVar +from tvm.script import tir as T + +# fmt: off +# TIR interleave weight impl-> 2D implementation +def tir_interleave_weight( + N: int = 2, + K: int = 16, + bits: int = 4, + QK: int = -1, + target_dtype: str = "float16", + storage_dtype: str = "int32", +): + if QK == -1: + QK = K * bits // 32 + bits_stride = 16 + mask = (1 << bits) - 1 # for 4bit the val is 0x0000000f + num_groups = 32 // bits_stride + elems_per_group = bits_stride // bits + + @T.prim_func + def interleave_weight( + A: T.Buffer((N, QK), storage_dtype), B: T.Buffer((N, QK), storage_dtype) + ): + for ax0, ax1, ax2, ax3 in T.grid(N, QK, num_groups, elems_per_group): + with T.block("B"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + offset = v2 * elems_per_group + v3 + shift = (offset % num_groups) * bits_stride + ( + offset // num_groups + ) * bits + B[v0, v1] = B[v0, v1] | ( + ((A[v0, v1] >> (bits * offset)) & mask) << shift + ) + + @T.prim_func + def interleave_weight_f16_2b( + A: T.Buffer((N, QK), storage_dtype), B: T.Buffer((N, QK), storage_dtype) + ): + B_tmp_1 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_2 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_3 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + for ax0, ax1, ax2, ax3 in T.grid(N, QK, num_groups, elems_per_group): + with T.block("B_tmp"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + offset = v2 * elems_per_group + v3 + shift = (offset % num_groups) * bits_stride + ( + offset // num_groups + ) * bits + B[v0, v1] = B[v0, v1] | ( + ((A[v0, v1] >> (bits * offset)) & mask) << shift + ) + + for ax0, ax1 in T.grid(N, QK): + with T.block("B"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + B_tmp_1[v0, v1] = B[v0, v1] & T.uint32(0xFF0000FF) + B_tmp_2[v0, v1] = ((B[v0, v1] & T.uint32(0x00FF0000)) << 8) >> 16 + B_tmp_3[v0, v1] = ((B[v0, v1] & T.uint32(0x0000FF00)) << 16) >> 8 + B[v0, v1] = B_tmp_1[v0, v1] | B_tmp_2[v0, v1] | B_tmp_3[v0, v1] + + @T.prim_func + def interleave_weight_f16_1b( + A: T.Buffer((N, QK), storage_dtype), B: T.Buffer((N, QK), storage_dtype) + ): + B_tmp_1 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_2 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_3 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_4 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_5 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_6 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_7 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + for ax0, ax1, ax2, ax3 in T.grid(N, QK, num_groups, elems_per_group): + with T.block("B_tmp"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + offset = v2 * elems_per_group + v3 + shift = (offset % num_groups) * bits_stride + ( + offset // num_groups + ) * bits + B[v0, v1] = B[v0, v1] | ( + ((A[v0, v1] >> (bits * offset)) & mask) << shift + ) + + for ax0, ax1 in T.grid(N, QK): + with T.block("B"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + B_tmp_1[v0, v1] = B[v0, v1] & T.uint32(0xF000000F) + B_tmp_2[v0, v1] = ((B[v0, v1] & T.uint32(0x000000F0)) >> 4) << 8 + B_tmp_3[v0, v1] = ((B[v0, v1] & T.uint32(0x00000F00)) >> 8) << 16 + B_tmp_4[v0, v1] = ((B[v0, v1] & T.uint32(0x0000F000)) >> 12) << 24 + B_tmp_5[v0, v1] = ((B[v0, v1] & T.uint32(0x000F0000)) >> 16) << 8 + B_tmp_6[v0, v1] = ((B[v0, v1] & T.uint32(0x00F00000)) >> 20) << 12 + B_tmp_7[v0, v1] = ((B[v0, v1] & T.uint32(0x00F00000)) >> 24) << 20 + B[v0, v1] = ( + B_tmp_1[v0, v1] + | B_tmp_2[v0, v1] + | B_tmp_3[v0, v1] + | B_tmp_4[v0, v1] + | B_tmp_5[v0, v1] + | B_tmp_6[v0, v1] + | B_tmp_7[v0, v1] + ) + + @T.prim_func + def interleave_weight_int8_1b( + A: T.Buffer((N, QK), storage_dtype), B: T.Buffer((N, QK), storage_dtype) + ): + B_tmp_1 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_2 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_3 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_4 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_5 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + for ax0, ax1, ax2, ax3 in T.grid(N, QK, num_groups, elems_per_group): + with T.block("B_tmp"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + offset = v2 * elems_per_group + v3 + shift = (offset % num_groups) * bits_stride + ( + offset // num_groups + ) * bits + B[v0, v1] = B[v0, v1] | ( + ((A[v0, v1] >> (bits * offset)) & mask) << shift + ) + + for ax0, ax1 in T.grid(N, QK): + with T.block("B"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + B_tmp_1[v0, v1] = B[v0, v1] & T.uint32(0xF0F00F0F) + B_tmp_2[v0, v1] = ((B[v0, v1] & T.uint32(0x000000F0)) >> 4) << 16 + B_tmp_3[v0, v1] = ((B[v0, v1] & T.uint32(0x0000F000)) >> 12) << 24 + B_tmp_4[v0, v1] = ((B[v0, v1] & T.uint32(0x000F0000)) >> 16) << 4 + B_tmp_5[v0, v1] = ((B[v0, v1] & T.uint32(0x0F000000)) >> 24) << 12 + B[v0, v1] = ( + B_tmp_1[v0, v1] + | B_tmp_2[v0, v1] + | B_tmp_3[v0, v1] + | B_tmp_4[v0, v1] + | B_tmp_5[v0, v1] + ) + + if target_dtype == "float16" and bits == 2: + return interleave_weight_f16_2b + elif target_dtype == "float16" and bits == 1: + return interleave_weight_f16_1b + elif target_dtype == "int8" and bits == 1: + return interleave_weight_int8_1b + + return interleave_weight +# fmt: on + + +def select_implementation( + M: int, + N: int, + datatype: Literal["float16", "int8"] = "float16", + storage_dtype: Literal["int8", "uint8", "int32", "uint32"] = "int32", + dequantize_bits: int = 4, +): + func = tir_interleave_weight( + N=N, + K=M, + bits=dequantize_bits, + target_dtype=datatype, + storage_dtype=storage_dtype, + ) + mod = IRModule() + mod.update_func(GlobalVar("main"), func) + return mod diff --git a/python/bitblas/ops/impl/matmul_dequantize_impl.py b/python/bitblas/ops/impl/matmul_dequantize_impl.py new file mode 100644 index 000000000000..0031cbd62eaf --- /dev/null +++ b/python/bitblas/ops/impl/matmul_dequantize_impl.py @@ -0,0 +1,291 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# pre-transformed tir expression of matmul +import tvm +from tvm.script import tir as T +from tvm import te, tir, DataType +from tvm.tir import IndexMap +from bitblas.gpu.matmul_analysis import get_propagate_map +from bitblas.quantization import ( + _tir_packed_to_signed_float, + _tir_packed_to_unsigned_float, +) + + +def matmul_nt_dequantize_b( + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + group_size=-1, + fast_decoding=False, + with_bias=False, +): + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) + n_float_per_elem = storage_nbit // bit + if group_size == -1: + group_size = K + A = te.placeholder((M, K), name="A", dtype=in_dtype) + B = te.placeholder((N, K // storage_nbit * bit), name="B", dtype=storage_dtype) + Scale = te.placeholder((N, K // group_size), name="Scale", dtype=in_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + def decode_func(n, k): + if source_format == "uint": + w = _tir_packed_to_unsigned_float(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype + ) + elif source_format == "int": + w = _tir_packed_to_signed_float(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype + ) + else: + raise ValueError("Unsupported source_format: {}".format(source_format)) + + if with_scaling: + w = w * Scale[n, k // group_size] + return w + + B_decode = te.compute((N, K), decode_func, name="B_decode") + + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, K), name="k") + C = te.compute( + (M, N), + lambda i, j: te.sum( + A[i, k].astype(accum_dtype) * B_decode[j, k].astype(accum_dtype), axis=k + ), + name="C", + ) + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + args = [A, B] + last_output = D + if with_scaling: + args.append(Scale) + if with_bias: + E = te.compute((M, N), lambda i, j: D[i, j] + Bias[j], name="E") + last_output = E + args.append(Bias) + args.append(last_output) + + func = te.create_prim_func(args).with_attr( + "dequantize_info", + { + "B_decode": { + "decode_block": "B_decode", + "fast_decoding": fast_decoding, + "source_format": { + "bits": bit, + "format": source_format, + }, + "storage_dtype": storage_dtype, + "target_format": in_dtype, + "with_scaling": with_scaling, + "group_size": group_size, + } + }, + ) + return tvm.IRModule.from_expr(func) + + +def matmul_nt_dequantize_b_propagate_b( + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + group_size=-1, + fast_decoding=False, + with_bias=False, +): + l = r = 16 + if in_dtype == "int8": + l, r = 16, 32 + + intra_index_map, _ = get_propagate_map(trans=True, dtype=in_dtype, matrix_name="B") + target_dtype = DataType(in_dtype) + scaling_factor = 1 + if bit > 0 and bit < target_dtype.bits: + scaling_factor = ( + (target_dtype.bits // bit) + * DataType(storage_dtype).bits + // target_dtype.bits + ) + initial_indices = intra_index_map.initial_indices + scaling_final_indices = intra_index_map.map_indices( + initial_indices[:-1] + [initial_indices[-1] * scaling_factor] + ) + scaling_final_indices = scaling_final_indices[:-1] + [ + scaling_final_indices[-1] // scaling_factor + ] + intra_index_map = IndexMap( + initial_indices, + scaling_final_indices, + None, + ) + + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) + n_float_per_elem = storage_nbit // bit + if group_size == -1: + group_size = K + qr = r // storage_nbit * bit + A = te.placeholder((M, K), name="A", dtype=in_dtype) + B = te.placeholder( + (N // l, (K // scaling_factor) // qr, l, qr), name="B", dtype=storage_dtype + ) + Scale = te.placeholder((N, K // group_size), name="Scale", dtype=in_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + def fcompute(i, j): + warp_i, warp_j = i % l, j % qr + spatial_args = i // l, j // qr + permutate_i, permutate_j = intra_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, permutate_i, permutate_j) + return B[new_index] + + B_reindex = te.compute( + (N, K // storage_nbit * bit), + fcompute, + name="B_reindex", + ) + + def decode_func(n, k): + if source_format == "uint": + w = _tir_packed_to_unsigned_float(storage_type, storage_nbit)( + bit, + B_reindex[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype=in_dtype, + ) + elif source_format == "int": + w = _tir_packed_to_signed_float(storage_type, storage_nbit)( + bit, + B_reindex[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype=in_dtype, + ) + else: + raise ValueError("Unsupported source_format: {}".format(source_format)) + + if with_scaling: + w = w * Scale[n, k // group_size] + return w + + B_decode = te.compute((N, K), decode_func, name="B_decode") + + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, K), name="k") + C = te.compute( + (M, N), + lambda i, j: te.sum( + A[i, k].astype(accum_dtype) * B_decode[j, k].astype(accum_dtype), axis=k + ), + name="C", + ) + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + args = [A, B] + last_output = D + if with_scaling: + args.append(Scale) + if with_bias: + E = te.compute((M, N), lambda i, j: D[i, j] + Bias[j], name="E") + last_output = E + args.append(Bias) + args.append(last_output) + + func = te.create_prim_func(args).with_attr( + "dequantize_info", + { + "B_decode": { + "decode_block": "B_decode", + "fast_decoding": fast_decoding, + "source_format": { + "bits": bit, + "format": source_format, + }, + "storage_dtype": storage_dtype, + "target_format": in_dtype, + "with_scaling": with_scaling, + "group_size": group_size, + } + }, + ) + return tvm.IRModule.from_expr(func) + + +def select_implementation( + M=None, + N=1024, + K=1024, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + group_size=-1, + fast_decoding=False, + with_bias=False, + layout="nt", + propagate_a=False, + propagate_b=False, +): + if not isinstance(M, int): + raise ValueError("Currently do not implement with dynamic symbolic") + if layout == "nn": + raise ValueError( + "Currently only support propagate_a=False and propagate_b=False for layout=nn in Dequantize Implementation" + ) + elif layout == "nt": + if propagate_a and propagate_b: + raise NotImplementedError + elif propagate_a: + raise NotImplementedError + elif propagate_b: + return matmul_nt_dequantize_b_propagate_b( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + bit, + storage_dtype, + source_format, + with_scaling, + group_size, + fast_decoding, + with_bias, + ) + else: + return matmul_nt_dequantize_b( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + bit, + storage_dtype, + source_format, + with_scaling, + group_size, + fast_decoding, + with_bias, + ) + else: + raise ValueError(f"Unsupported layout: {layout}") diff --git a/python/bitblas/ops/impl/matmul_impl.py b/python/bitblas/ops/impl/matmul_impl.py new file mode 100644 index 000000000000..74b6c89c9958 --- /dev/null +++ b/python/bitblas/ops/impl/matmul_impl.py @@ -0,0 +1,617 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# pre-transformed tir expression of matmul +import tvm +from tvm.script import tir as T +from tvm import te, tir +from bitblas.gpu.matmul_analysis import get_propagate_map + + +def matmul_nt_dyn_m( + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, +): + @tvm.script.ir_module + class MatmulNT: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + m = T.int32() + A = T.match_buffer(a, [m, K], dtype=in_dtype) + B = T.match_buffer(b, [N, K], dtype=in_dtype) + C = T.match_buffer(c, [m, N], dtype=out_dtype) + + for i, j, k in T.grid(m, N, K): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = tvm.tir.const(0, out_dtype) + C[vi, vj] = C[vi, vj] + A[vi, vk].astype(out_dtype) * B[ + vj, vk + ].astype(out_dtype) + + @tvm.script.ir_module + class MatmulNTWithAccum: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + m = T.int32() + A = T.match_buffer(a, [m, K], dtype=in_dtype) + B = T.match_buffer(b, [N, K], dtype=in_dtype) + C = T.match_buffer(c, [m, N], dtype=out_dtype) + accum = T.alloc_buffer([m, N], dtype=accum_dtype) + for i, j, k in T.grid(m, N, K): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + accum[vi, vj] = tvm.tir.const(0, accum_dtype) + accum[vi, vj] = accum[vi, vj] + A[vi, vk].astype(accum_dtype) * B[ + vj, vk + ].astype(accum_dtype) + + for i, j in T.grid(m, N): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = accum[vi, vj].astype(out_dtype) + + @tvm.script.ir_module + class MatmulNTWithAccumBias: + @T.prim_func + def main(a: T.handle, b: T.handle, bias: T.handle, c: T.handle): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + m = T.int32() + A = T.match_buffer(a, [m, K], dtype=in_dtype) + B = T.match_buffer(b, [N, K], dtype=in_dtype) + Bias = T.match_buffer(bias, [N], dtype=out_dtype) + C = T.match_buffer(c, [m, N], dtype=out_dtype) + accum = T.alloc_buffer([m, N], dtype=accum_dtype) + accum_bias = T.alloc_buffer([m, N], dtype=out_dtype) + for i, j, k in T.grid(m, N, K): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + accum[vi, vj] = tvm.tir.const(0, accum_dtype) + accum[vi, vj] = accum[vi, vj] + A[vi, vk].astype(accum_dtype) * B[ + vj, vk + ].astype(accum_dtype) + + for i, j in T.grid(m, N): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + accum_bias[vi, vj] = accum[vi, vj].astype(out_dtype) + + for i, j in T.grid(m, N): + with T.block("Bias"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = accum_bias[vi, vj] + Bias[vj] + + final_module = MatmulNT + if with_bias: + final_module = MatmulNTWithAccumBias + elif accum_dtype != out_dtype: + final_module = MatmulNTWithAccum + + return final_module + + +def matmul_nn_dyn_m( + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, +): + @tvm.script.ir_module + class MatmulNN: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + m = T.int32() + A = T.match_buffer(a, [m, K], dtype=in_dtype) + B = T.match_buffer(b, [K, N], dtype=in_dtype) + C = T.match_buffer(c, [m, N], dtype=out_dtype) + + for i, j, k in T.grid(m, N, K): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = tvm.tir.const(0, out_dtype) + C[vi, vj] = C[vi, vj] + A[vi, vk].astype(out_dtype) * B[ + vk, vj + ].astype(out_dtype) + + @tvm.script.ir_module + class MatmulNNWithAccum: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + m = T.int32() + A = T.match_buffer(a, [m, K], dtype=in_dtype) + B = T.match_buffer(b, [K, N], dtype=in_dtype) + C = T.match_buffer(c, [m, N], dtype=out_dtype) + accum = T.alloc_buffer([m, N], dtype=accum_dtype) + for i, j, k in T.grid(m, N, K): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + accum[vi, vj] = tvm.tir.const(0, accum_dtype) + accum[vi, vj] = accum[vi, vj] + A[vi, vk].astype(accum_dtype) * B[ + vk, vj + ].astype(accum_dtype) + + for i, j in T.grid(m, N): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = accum[vi, vj].astype(out_dtype) + + @tvm.script.ir_module + class MatmulNNWithAccumBias: + @T.prim_func + def main(a: T.handle, b: T.handle, bias: T.handle, c: T.handle): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + m = T.int32() + A = T.match_buffer(a, [m, K], dtype=in_dtype) + B = T.match_buffer(b, [K, N], dtype=in_dtype) + Bias = T.match_buffer(bias, [N], dtype=out_dtype) + C = T.match_buffer(c, [m, N], dtype=out_dtype) + accum = T.alloc_buffer([m, N], dtype=accum_dtype) + accum_bias = T.alloc_buffer([m, N], dtype=out_dtype) + for i, j, k in T.grid(m, N, K): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + accum[vi, vj] = tvm.tir.const(0, accum_dtype) + accum[vi, vj] = accum[vi, vj] + A[vi, vk].astype(accum_dtype) * B[ + vk, vj + ].astype(accum_dtype) + + for i, j in T.grid(m, N): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + accum_bias[vi, vj] = accum[vi, vj].astype(out_dtype) + + for i, j in T.grid(m, N): + with T.block("Bias"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = accum_bias[vi, vj] + Bias[vj] + + final_module = MatmulNN + if with_bias: + final_module = MatmulNNWithAccumBias + elif accum_dtype != out_dtype: + final_module = MatmulNNWithAccum + + return final_module + + +def matmul_nn( + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, +): + A = te.placeholder((M, K), name="A", dtype=in_dtype) + B = te.placeholder((K, N), name="B", dtype=in_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, K), name="k") + C = te.compute( + (M, N), + lambda i, j: te.sum( + A[i, k].astype(accum_dtype) * B[k, j].astype(accum_dtype), axis=k + ), + name="C", + ) + last_output = C + if accum_dtype != out_dtype: + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = D + + if with_bias: + E = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") + last_output = E + + if with_bias: + args = [A, B, Bias, last_output] + else: + args = [A, B, last_output] + + func = te.create_prim_func(args) + + return tvm.IRModule.from_expr(func) + + +def matmul_nt( + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, +): + A = te.placeholder((M, K), name="A", dtype=in_dtype) + B = te.placeholder((N, K), name="B", dtype=in_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, K), name="k") + C = te.compute( + (M, N), + lambda i, j: te.sum( + A[i, k].astype(accum_dtype) * B[j, k].astype(accum_dtype), axis=k + ), + name="C", + ) + last_output = C + if accum_dtype != out_dtype: + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = D + + if with_bias: + E = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") + last_output = E + + if with_bias: + args = [A, B, Bias, last_output] + else: + args = [A, B, last_output] + + func = te.create_prim_func(args) + + return tvm.IRModule.from_expr(func) + + +def matmul_dyn_m( + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, + layout="nt", +): + if layout == "nn": + return matmul_nn_dyn_m(N, K, in_dtype, out_dtype, accum_dtype, with_bias) + return matmul_nt_dyn_m(N, K, in_dtype, out_dtype, accum_dtype, with_bias) + + +def matmul( + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, + layout="nt", +): + if layout == "nn": + return matmul_nn(M, N, K, in_dtype, out_dtype, accum_dtype, with_bias) + return matmul_nt(M, N, K, in_dtype, out_dtype, accum_dtype, with_bias) + + +# always assume propagate both intra and inter layout in BitBLAS +# as we do not have to do runtime layout transform +def matmul_nt_propagate_a_dyn_m( + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, +): ... + + +def matmul_nt_propagate_b_dyn_m( + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, +): ... + + +def matmul_nt_propagate_a_propagate_b_dyn_m( + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, +): ... + + +def matmul_nt_propagate_a( + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, +): + l = r = 16 + if in_dtype == "int8": + l, r = 16, 32 + + intra_index_map, _ = get_propagate_map(trans=False, dtype=in_dtype, matrix_name="A") + + A = te.placeholder((M // l, K // r, l, r), name="A", dtype=in_dtype) + B = te.placeholder((N, K), name="B", dtype=in_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + def fcompute(i, j): + warp_i, warp_j = i % l, j % r + spatial_args = i // l, j // r + permutate_i, permutate_j = intra_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, permutate_i, permutate_j) + return A[new_index] + + A_reindex = te.compute( + (M, K), + fcompute, + name="A_reindex", + ) + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, K), name="k") + C = te.compute( + (M, N), + lambda i, j: te.sum( + A_reindex[i, k].astype(accum_dtype) * B[j, k].astype(accum_dtype), axis=k + ), + name="C", + ) + last_output = C + if accum_dtype != out_dtype: + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = D + + if with_bias: + E = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") + last_output = E + + if with_bias: + args = [A, B, Bias, last_output] + else: + args = [A, B, last_output] + + func = te.create_prim_func(args) + func = func.with_attr("smooth_a", True) + + return tvm.IRModule.from_expr(func) + + +def matmul_nt_propagate_b( + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, +): + l = r = 16 + if in_dtype == "int8": + l, r = 16, 32 + + intra_index_map, _ = get_propagate_map(trans=True, dtype=in_dtype, matrix_name="B") + + A = te.placeholder((M, K), name="A", dtype=in_dtype) + B = te.placeholder((N // l, K // r, l, r), name="B", dtype=in_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + def fcompute(i, j): + warp_i, warp_j = i % l, j % r + spatial_args = i // l, j // r + permutate_i, permutate_j = intra_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, permutate_i, permutate_j) + return B[new_index] + + B_reindex = te.compute( + (N, K), + fcompute, + name="B_reindex", + ) + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, K), name="k") + C = te.compute( + (M, N), + lambda i, j: te.sum( + A[i, k].astype(accum_dtype) * B_reindex[j, k].astype(accum_dtype), axis=k + ), + name="C", + ) + last_output = C + if accum_dtype != out_dtype: + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = D + + if with_bias: + E = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") + last_output = E + + if with_bias: + args = [A, B, Bias, last_output] + else: + args = [A, B, last_output] + + func = te.create_prim_func(args) + func = func.with_attr("smooth_b", True) + + return tvm.IRModule.from_expr(func) + + +def matmul_nt_propagate_a_propagate_b( + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, +): + l = r = 16 + if in_dtype == "int8": + l, r = 16, 32 + + A = te.placeholder((M // l, K // r, l, r), name="A", dtype=in_dtype) + B = te.placeholder((N // l, K // r, l, r), name="B", dtype=in_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + intra_index_map, _ = get_propagate_map(trans=False, dtype=in_dtype, matrix_name="A") + + def fcompute(i, j): + warp_i, warp_j = i % l, j % r + spatial_args = i // l, j // r + permutate_i, permutate_j = intra_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, permutate_i, permutate_j) + return A[new_index] + + A_reindex = te.compute( + (M, K), + fcompute, + name="A_reindex", + ) + + intra_index_map, _ = get_propagate_map(trans=False, dtype=in_dtype, matrix_name="B") + + def fcompute(i, j): + warp_i, warp_j = i % l, j % r + spatial_args = i // l, j // r + permutate_i, permutate_j = intra_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, permutate_i, permutate_j) + return B[new_index] + + B_reindex = te.compute( + (N, K), + fcompute, + name="B_reindex", + ) + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, K), name="k") + C = te.compute( + (M, N), + lambda i, j: te.sum( + A_reindex[i, k].astype(accum_dtype) * B_reindex[j, k].astype(accum_dtype), + axis=k, + ), + name="C", + ) + last_output = C + if accum_dtype != out_dtype: + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = D + + if with_bias: + E = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") + last_output = E + + if with_bias: + args = [A, B, Bias, last_output] + else: + args = [A, B, last_output] + + func = te.create_prim_func(args) + func = func.with_attr("smooth_a", True) + func = func.with_attr("smooth_b", True) + + return tvm.IRModule.from_expr(func) + + +def _select_implementation_dyn_m( + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, + layout="nt", + propagate_a=False, + propagate_b=False, +): + if layout == "nn": + if propagate_a or propagate_b: + raise ValueError( + "Currently only support propagate_a=False and propagate_b=False for layout=nn" + ) + return matmul_dyn_m(N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout) + elif layout == "nt": + if propagate_a and propagate_b: + return matmul_nt_propagate_a_propagate_b_dyn_m( + N, K, in_dtype, out_dtype, accum_dtype, with_bias + ) + elif propagate_a: + return matmul_nt_propagate_a_dyn_m( + N, K, in_dtype, out_dtype, accum_dtype, with_bias + ) + elif propagate_b: + return matmul_nt_propagate_b_dyn_m( + N, K, in_dtype, out_dtype, accum_dtype, with_bias + ) + else: + return matmul_dyn_m( + N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout + ) + + +def select_implementation( + M=None, + N=16384, + K=16384, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, + layout="nt", + propagate_a=False, + propagate_b=False, +): + if not isinstance(M, int): + return _select_implementation_dyn_m( + N, + K, + in_dtype, + out_dtype, + accum_dtype, + with_bias, + layout, + propagate_a, + propagate_b, + ) + if layout == "nn": + if propagate_a or propagate_b: + raise ValueError( + "Currently only support propagate_a=False and propagate_b=False for layout=nn" + ) + return matmul(M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout) + elif layout == "nt": + if propagate_a and propagate_b: + return matmul_nt_propagate_a_propagate_b( + M, N, K, in_dtype, out_dtype, accum_dtype, with_bias + ) + elif propagate_a: + return matmul_nt_propagate_a( + M, N, K, in_dtype, out_dtype, accum_dtype, with_bias + ) + elif propagate_b: + return matmul_nt_propagate_b( + M, N, K, in_dtype, out_dtype, accum_dtype, with_bias + ) + else: + return matmul(M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout) + else: + raise ValueError(f"Unsupported layout: {layout}") diff --git a/python/bitblas/ops/ladder_permutate.py b/python/bitblas/ops/ladder_permutate.py new file mode 100644 index 000000000000..9851cbe52200 --- /dev/null +++ b/python/bitblas/ops/ladder_permutate.py @@ -0,0 +1,96 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import tvm +from tvm.target import Target +from typing import List, Union, Literal +from .operator import Operator +from .impl.ladder_permutate_impl import select_implementation +from dataclasses import dataclass + + +@dataclass +class LadderPermutateConfig: + M: int + N: int + datatype: Literal["float16", "int8"] = "float16" + dequantize_bits: int = -1 + storage_dtype: Literal["float16", "int8", "uint8", "int32", "uint32"] = "float16" + propagate_kind: Literal["A", "B"] = "B" # "A" or "B" + transpose_matrix: bool = False + transform_kind: int = 2 # 0: none, 1: inter_warp 2: intra_warp + target_instruction: Literal["nvidia-mma"] = ( + "nvidia-mma" # maybe extend to "cdna-mfma" in future. + ) + + +class LadderPermutate(Operator): + def __init__( + self, + config: LadderPermutateConfig, + name: str = "permutate", + target: Target = tvm.target.Target("llvm"), # assume to do permutation on gpu. + ): + # consider to warp the arguments to MatmulConfig + super().__init__(name, target) + self.config = config + + if target.kind.name != "llvm": + raise ValueError("Currently only support llvm target for Permutation") + + prim_func_mod = self._select_implementation() + self.prim_func_mod = prim_func_mod + self.target = target + self._build_runtime_module(target) + + # select implementation based on the Operator config + def _select_implementation(self): + return select_implementation( + M=self.M, + N=self.N, + datatype=self.datatype, + dequantize_bits=self.dequantize_bits, + storage_dtype=self.storage_dtype, + propagate_kind=self.propagate_kind, + transpose_matrix=self.transpose_matrix, + transform_kind=self.transform_kind, + target_instruction=self.target_instruction, + ) + + @property + def M(self): + return self.config.M + + @property + def N(self): + return self.config.N + + @property + def datatype(self): + return self.config.datatype + + @property + def dequantize_bits(self): + return self.config.dequantize_bits + + @property + def storage_dtype(self): + return self.config.storage_dtype + + @property + def propagate_kind(self): + return self.config.propagate_kind + + @property + def transpose_matrix(self): + return self.config.transpose_matrix + + @property + def transform_kind(self): + return self.config.transform_kind + + @property + def target_instruction(self): + return self.config.target_instruction + + +__all__ = ["LadderPermutate", "LadderPermutateConfig"] diff --git a/python/bitblas/ops/lop3_permutate.py b/python/bitblas/ops/lop3_permutate.py new file mode 100644 index 000000000000..2d2069721593 --- /dev/null +++ b/python/bitblas/ops/lop3_permutate.py @@ -0,0 +1,75 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import tvm +from tvm.target import Target +from typing import Literal +from .operator import Operator +from .impl.lop3_permutate_impl import select_implementation +from dataclasses import dataclass +from bitblas.utils.tensor_adapter import tvm_tensor_to_torch +import torch + +@dataclass +class LOP3PermutateConfig: + M: int + N: int + datatype: Literal["float16", "int8"] = "float16" + storage_dtype: Literal["int8", "uint8", "int32", "uint32"] = "int32" + dequantize_bits: int = 4 + + +class LOP3Permutate(Operator): + def __init__( + self, + config: LOP3PermutateConfig, + name: str = "permutate", + target: Target = tvm.target.Target("llvm"), # assume to do permutation on gpu. + ): + # consider to warp the arguments to MatmulConfig + super().__init__(name, target) + self.config = config + + if target.kind.name != "llvm": + raise ValueError("Currently only support llvm target for Permutation") + + prim_func_mod = self._select_implementation() + self.prim_func_mod = prim_func_mod + self.target = target + self._build_runtime_module(target) + + def _select_implementation(self): + return select_implementation( + M=self.M, + N=self.N, + datatype=self.datatype, + dequantize_bits=self.dequantize_bits, + ) + + def forward_from_torch(self, weight, res): + # reintepret the input tensor to int32 format + _tvm_args = [self._tensor_adapter(arg.view(torch.int32), self.arch.device) for arg in [weight, res]] + self.rt_mod(*_tvm_args) + return tvm_tensor_to_torch(_tvm_args[-1]).view(weight.dtype) + + @property + def M(self): + return self.config.M + + @property + def N(self): + return self.config.N + + @property + def datatype(self): + return self.config.datatype + + @property + def storage_dtype(self): + return self.config.storage_dtype + + @property + def dequantize_bits(self): + return self.config.dequantize_bits + + +__all__ = ["LOP3Permutate", "LOP3PermutateConfig"] diff --git a/python/bitblas/ops/matmul.py b/python/bitblas/ops/matmul.py new file mode 100644 index 000000000000..d38250365d15 --- /dev/null +++ b/python/bitblas/ops/matmul.py @@ -0,0 +1,204 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import tvm +import numpy as np +from tvm.target import Target +from bitblas.base.roller.arch.cuda import CUDA +from typing import List, Union +from .operator import Operator +from .impl.matmul_impl import select_implementation +from ..base.utils import get_rasterization_code +from bitblas.utils import match_global_kernel +from dataclasses import dataclass +from .ladder_permutate import LadderPermutate, LadderPermutateConfig + + +@dataclass +class MatmulConfig: + M: Union[int, List] + N: int + K: int + in_dtype: str = "float16" + out_dtype: str = "float16" + accum_dtype: str = "float16" + with_bias: bool = False + layout: str = "nt" + propagate_a: bool = False + propagate_b: bool = False + + +class Matmul(Operator): + def __init__( + self, + config: MatmulConfig, + name: str = "matmul", + target: Target = tvm.target.Target("cuda"), + ): + super().__init__(name) + self.config = config + + if target.kind.name != "cuda": + raise ValueError("Currently only support cuda target") + self.arch = CUDA(target) + assert self.propagate_a is False, "Currently only support propagate_a=False" + + prim_func_mod = self._select_implementation() + self.prim_func_mod = prim_func_mod + self.optimized_func = self.apply_default_schedule(prim_func_mod, target) + + if isinstance(self.M, List): + self.dynamic_range = {"m": self.M} + self.update_func( + self.prim_func.with_attrs({"opt_shapes": self.dynamic_range}) + ) + else: + self.dynamic_range = None + self.target = target + self._build_runtime_module(target) + + if self.propagate_a: + ladder_permutate_config = LadderPermutateConfig( + M=self.M, + N=self.K, + datatype=self.in_dtype, + storage_dtype=self.in_dtype, + propagate_kind="A", + transpose_matrix=False, + transform_kind=2, + ) + self.ladder_permutate_a = LadderPermutate( + config=ladder_permutate_config, + target=tvm.target.Target("llvm"), + ) + else: + self.ladder_permutate_a = None + + if self.propagate_b: + ladder_permutate_config = LadderPermutateConfig( + M=self.N, + N=self.K, + datatype=self.in_dtype, + storage_dtype=self.in_dtype, + propagate_kind="B", + transpose_matrix=True if self.layout == "nt" else False, + transform_kind=2, + ) + self.ladder_permutate_b = LadderPermutate( + config=ladder_permutate_config, + target=tvm.target.Target("llvm"), + ) + else: + self.ladder_permutate_b = None + + def _select_implementation(self): + return select_implementation( + M=self.M, + N=self.N, + K=self.K, + in_dtype=self.in_dtype, + out_dtype=self.out_dtype, + accum_dtype=self.accum_dtype, + with_bias=self.with_bias, + layout=self.layout, + propagate_a=self.propagate_a, + propagate_b=self.propagate_b, + ) + + def post_process(self, code: str) -> str: + index = code.index("{", match_global_kernel(code)) + # some tricky judge to decide whether to insert rasterization code + if self.N * self.K > 10**6: + rasterization_code = get_rasterization_code(10) + code = code[: index + 2] + rasterization_code + code[index + 2 :] + return code + + def _profile_latency_with_dynamic_range(self) -> List: + func = self.prim_func_mod["main"] + device = self.arch.device + + def var_warpper(v, m): + if isinstance(v, tvm.tir.Var): + assert "opt_shapes" in func.attrs + assert v.name in func.attrs["opt_shapes"] + return m + elif isinstance(v, tvm.tir.IntImm): + return v.value + else: + raise RuntimeError("Not supported type: ", type(v)) + + benchmark_latencies = [] + for m in self.dynamic_range["m"]: + profile_tensors = [] + for param in func.params: + if param not in func.buffer_map: + # in case of dynamic symbolic may in params + continue + arg = func.buffer_map[param] + profile_tensors.append( + tvm.nd.array( + np.random.uniform( + 0, 1, [var_warpper(i, m) for i in arg.shape] + ).astype(arg.dtype), + device=device, + ) + ) + self.profile_tensors = profile_tensors + latency = self.time_evaluator(*profile_tensors).mean * 1e3 + benchmark_latencies.append({"m": m, "latency": latency}) + # ms + return benchmark_latencies + + @property + def M(self): + return self.config.M + + @property + def N(self): + return self.config.N + + @property + def K(self): + return self.config.K + + @property + def in_dtype(self): + return self.config.in_dtype + + @property + def out_dtype(self): + return self.config.out_dtype + + @property + def accum_dtype(self): + return self.config.accum_dtype + + @property + def layout(self): + return self.config.layout + + @property + def with_bias(self): + return self.config.with_bias + + @property + def propagate_a(self): + return self.config.propagate_a + + @property + def propagate_b(self): + return self.config.propagate_b + + @property + def input_transform(self): + if self.ladder_permutate_a is not None: + return self.ladder_permutate_a + return None + + @property + def weight_transform(self): + if self.ladder_permutate_b is not None: + return self.ladder_permutate_b + return None + + +__all__ = ["Matmul", "MatmulConfig"] diff --git a/python/bitblas/ops/matmul_dequantize.py b/python/bitblas/ops/matmul_dequantize.py new file mode 100644 index 000000000000..70f1f5386a1f --- /dev/null +++ b/python/bitblas/ops/matmul_dequantize.py @@ -0,0 +1,251 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import tvm +from tvm.target import Target +from bitblas.base.roller.arch.cuda import CUDA +from typing import Any, List +from .operator import Operator +from .impl.matmul_dequantize_impl import select_implementation +from ..base.utils import get_rasterization_code +from bitblas.utils.tensor_adapter import tvm_tensor_to_torch +from dataclasses import dataclass +from bitblas.utils import match_global_kernel +from .ladder_permutate import LadderPermutate, LadderPermutateConfig +from .lop3_permutate import LOP3Permutate, LOP3PermutateConfig + + +class WeightExecutorCPU: + def __init__(self, operators: List[Operator] = []): + self.operators = operators + + def append(self, op): + self.operators.append(op) + + def is_none(self): + return len(self.operators) == 0 + + def forward(self, weight): + inputs = [weight] + for op in self.operators: + inputs.append(tvm_tensor_to_torch(op.get_profile_tensors()[-1]).cpu()) + inputs = [op.forward(*inputs)] + return inputs[-1] + + def __call__(self, *args: Any, **kwds: Any) -> Any: + return self.forward(*args, **kwds) + + @property + def size(self): + return len(self.operators) + + +@dataclass +class MatmulWeightOnlyDequantizeConfig: + M: int + N: int + K: int + in_dtype: str = "float16" + out_dtype: str = "float16" + accum_dtype: str = "float16" + bit: int = 4 + storage_dtype: str = "int8" + source_format: str = "int" + with_scaling: bool = False + group_size: int = -1 + fast_decoding: bool = False + with_bias: bool = False + propagate_a: bool = False + propagate_b: bool = False + layout: str = "nt" + + +class MatmulWeightOnlyDequantize(Operator): + + def __init__( + self, + config: MatmulWeightOnlyDequantizeConfig, + name: str = "matmul_weight_only_dequantize", + target: Target = tvm.target.Target("cuda"), + ): + super().__init__(name) + self.config = config + + if target.kind.name != "cuda": + raise ValueError("Currently only support cuda target") + self.arch = CUDA(target) + assert self.propagate_a is False, "Currently only support propagate_a=False" + + self.prim_func_mod = self._select_implementation() + + self.optimized_func = self.apply_default_schedule(self.prim_func_mod, target) + if isinstance(self.M, List): + self.dynamic_range = {"m": self.M} + self.prim_func_mod["main"] = self.prim_func_mod["main"].with_attrs( + {"opt_shapes": self.dynamic_range} + ) + else: + self.dynamic_range = None + self.target = target + self._build_runtime_module(target) + + if self.propagate_a: + ladder_permutate_config = LadderPermutateConfig( + M=self.M, + N=self.K, + datatype=self.in_dtype, + storage_dtype=self.in_dtype, + propagate_kind="A", + transpose_matrix=False, + transform_kind=2, + ) + self.ladder_permutate_a = LadderPermutate( + config=ladder_permutate_config, + target=tvm.target.Target("llvm"), + ) + else: + self.ladder_permutate_a = None + + if self.propagate_b: + ladder_permutate_config = LadderPermutateConfig( + M=self.N, + N=self.K, + datatype=self.in_dtype, + dequantize_bits=self.bit, + storage_dtype=self.storage_dtype, + propagate_kind="B", + transpose_matrix=True if self.layout == "nt" else False, + transform_kind=2, + ) + self.ladder_permutate_b = LadderPermutate( + config=ladder_permutate_config, + target=tvm.target.Target("llvm"), + ) + else: + self.ladder_permutate_b = None + + if self.fast_decoding: + lop3_permutate_config = LOP3PermutateConfig( + M=self.N, + N=self.K, + datatype=self.in_dtype, + dequantize_bits=self.bit, + storage_dtype=self.storage_dtype, + ) + self.lop3_permutate = LOP3Permutate( + config=lop3_permutate_config, + target=tvm.target.Target("llvm"), + ) + else: + self.lop3_permutate = None + + weight_executors = WeightExecutorCPU() + if self.lop3_permutate is not None: + weight_executors.append(self.lop3_permutate) + if self.ladder_permutate_b is not None: + weight_executors.append(self.ladder_permutate_b) + + self.weight_executors = weight_executors + + def _select_implementation(self): + return select_implementation( + M=self.M, + N=self.N, + K=self.K, + in_dtype=self.in_dtype, + out_dtype=self.out_dtype, + accum_dtype=self.accum_dtype, + bit=self.bit, + storage_dtype=self.storage_dtype, + source_format=self.source_format, + with_scaling=self.with_scaling, + group_size=self.group_size, + fast_decoding=self.fast_decoding, + with_bias=self.with_bias, + layout=self.layout, + propagate_a=self.propagate_a, + propagate_b=self.propagate_b, + ) + + def post_process(self, code: str) -> str: + index = code.index("{", match_global_kernel(code)) + # some tricky judge to decide whether to insert rasterization code + if self.M == 1: + return code + if self.N * self.K > 10**6: + rasterization_code = get_rasterization_code(10) + code = code[: index + 2] + rasterization_code + code[index + 2 :] + return code + + @property + def M(self): + return self.config.M + + @property + def N(self): + return self.config.N + + @property + def K(self): + return self.config.K + + @property + def in_dtype(self): + return self.config.in_dtype + + @property + def out_dtype(self): + return self.config.out_dtype + + @property + def accum_dtype(self): + return self.config.accum_dtype + + @property + def bit(self): + return self.config.bit + + @property + def storage_dtype(self): + return self.config.storage_dtype + + @property + def source_format(self): + return self.config.source_format + + @property + def with_scaling(self): + return self.config.with_scaling + + @property + def group_size(self): + return self.config.group_size + + @property + def fast_decoding(self): + return self.config.fast_decoding + + @property + def with_bias(self): + return self.config.with_bias + + @property + def propagate_a(self): + return self.config.propagate_a + + @property + def propagate_b(self): + return self.config.propagate_b + + @property + def layout(self): + return self.config.layout + + @property + def input_transform(self): + if self.ladder_permutate_a is not None: + return self.ladder_permutate_a + return None + + @property + def weight_transform(self): + return self.weight_executors if self.weight_executors.size else None diff --git a/python/bitblas/ops/operator.py b/python/bitblas/ops/operator.py new file mode 100644 index 000000000000..13983f1d6726 --- /dev/null +++ b/python/bitblas/ops/operator.py @@ -0,0 +1,214 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from abc import ABC, abstractmethod +import tvm +from tvm import IRModule +from tvm.target import Target +from tvm.tir import PrimFunc +import bitblas +from typing import List, Dict, Any +import numpy as np +from ..base import fast_tune, fast_tune_with_dynamic_range +from copy import deepcopy +from bitblas.base.roller.arch import get_arch +from bitblas.utils.tensor_adapter import tvm_tensor_to_torch + + +class Operator(ABC): + def __init__(self, name, target: Target = None): + self.name = name + self.prim_func_mod = None + self.optimized_func = None + self.rt_mod = None + self.time_evaluator = None + self.profile_tensors = None + self.arch = get_arch(target) if target else None + self.dynamic_range = None + + def codegen(self, target: Target = None) -> str: + if target is None: + target = self.target + if self.rt_mod is None: + self._build_runtime_module(target) + return ( + self.post_process(self.rt_mod.imported_modules[0].get_source()) + if self.rt_mod + else None + ) + + def _build_runtime_module(self, target: Target): + """ + Builds the runtime module based on the architecture platform. + + This function attempts to build a runtime module (rt_mod) for the specified target. + If the platform is CUDA and an optimized function is available, it tries to build + using the optimized function with a specific pass context. Otherwise, it falls back + to building with the primary function. After successful build, it initializes a + time evaluator for performance measurement. + + Args: + target (Target): The compilation target specification. + + Returns: + The compiled runtime module or None if the build was unsuccessful. + """ + + # Initialize rt_mod as None to handle cases where build fails or is skipped + rt_mod = None + + # Check if the platform is CUDA and we have an optimized function + if self.arch.platform == "CUDA" and self.optimized_func: + try: + # Use a specific TVM pass context for CUDA platforms + with tvm.transform.PassContext(config={"tir.use_async_copy": True}): + rt_mod = tvm.build(self.optimized_func, target=target, name=self.name) + except Exception as e: + # Log the exception for debugging purposes. Replace 'print' with logging if necessary. + print(f"Failed to build optimized function for CUDA target due to: {e}") + else: + # For non-CUDA platforms or when no optimized function is available, build with the primary function + rt_mod = tvm.build(self.prim_func, target=target, name=self.name) + + # If the runtime module was successfully built, set up for evaluation + if rt_mod: + self.rt_mod = rt_mod + # Initialize a time evaluator with the built module, specifying the device and the number of runs + self.time_evaluator = rt_mod.time_evaluator( + rt_mod.entry_name, self.arch.device, number=10 + ) + + return rt_mod + + def apply_default_schedule(self, func_mod: IRModule, target: Target) -> IRModule: + mod_for_opt = deepcopy(func_mod) + with target: + optimized_mod = ( + bitblas.ApplyDefaultSchedule( # pylint: disable=not-callable + bitblas.gpu.Matmul(), + bitblas.gpu.GEMV(), + bitblas.gpu.Reduction(), + bitblas.gpu.GeneralReduction(), + bitblas.gpu.Fallback(), + )(mod_for_opt) + ) + + if optimized_mod is not None: + return optimized_mod + return None + + def post_process(self, code: str) -> str: + return code + + def apply_fast_tuning( + self, func: PrimFunc, target: Target, topk: int = 20 + ) -> IRModule: + _, best = fast_tune(func, target, topk=topk, parallel_build=True) + if best is not None: + return best.sch.mod + return None + + def apply_fast_tuning_with_dynamic_range( + self, + func: PrimFunc, + target: Target, + topk: int = 20, + dynamic_range: Dict[str, List[int]] = None, + ): + optimized_mod = fast_tune_with_dynamic_range( + func, target, topk=topk, parallel_build=True, dynamic_range=dynamic_range + ) + if optimized_mod is not None: + return optimized_mod + return None + + def hardware_aware_finetune(self, topk: int = 20, target: tvm.target.Target = None): + if target is None: + target = self.target + dynamic_range = self.dynamic_range + func = self.prim_func + if dynamic_range is not None: + self.optimized_func = self.apply_fast_tuning_with_dynamic_range( + func, target, topk, dynamic_range + ) + else: + self.optimized_func = self.apply_fast_tuning(func, target, topk) + self._build_runtime_module(self.target) + + def get_profile_tensors(self): + func = self.prim_func + device = self.arch.device + + def var_warpper(v): + if isinstance(v, tvm.tir.Var): + assert "opt_shapes" in func.attrs + assert v.name in func.attrs["opt_shapes"] + return func.attrs["opt_shapes"][v.name].value + elif isinstance(v, tvm.tir.IntImm): + return v.value + else: + raise RuntimeError("Not supported type: ", type(v)) + + profile_tensors = [] + for param in func.params: + if param not in func.buffer_map: + # in case of dynamic symbolic may in params + continue + arg = func.buffer_map[param] + profile_tensors.append( + tvm.nd.array( + np.random.uniform(0, 1, [var_warpper(i) for i in arg.shape]).astype( + arg.dtype + ), + device=device, + ) + ) + self.profile_tensors = profile_tensors + return profile_tensors + + def profile_latency(self) -> str: + if self.dynamic_range is not None: + return self._profile_latency_with_dynamic_range() + + profile_tensors = self.get_profile_tensors() + latency = self.time_evaluator(*profile_tensors).mean * 1e3 + return latency + + def _profile_latency_with_dynamic_range(self) -> List: + raise NotImplementedError + + def _tensor_adapter(self, tensor, device): + import torch + from torch.utils.dlpack import to_dlpack + + if isinstance(tensor, tvm.te.Tensor): + return tensor + elif isinstance(tensor, torch.Tensor): + return tvm.runtime.ndarray.from_dlpack(to_dlpack(tensor)) + elif isinstance(tensor, np.ndarray): + return tvm.nd.array(tensor, device=device) + else: + raise RuntimeError("Not supported type: ", type(tensor)) + + def forward_from_torch(self, *args): + # convert tensor from torch to tvm + _tvm_args = [self._tensor_adapter(arg, self.arch.device) for arg in args] + self.rt_mod(*_tvm_args) + return tvm_tensor_to_torch(_tvm_args[-1]) + + def forward(self, *args): + # "Currently only support forward from torch tensor" + return self.forward_from_torch(*args) + + def __call__(self, *args: Any, **kwds: Any) -> Any: + return self.forward(*args, **kwds) + + def update_func(self, func: PrimFunc): + self.prim_func_mod["main"] = func + + @abstractmethod + def _select_implementation(self) -> IRModule: + pass + + @property + def prim_func(self): + return self.prim_func_mod["main"] diff --git a/python/bitblas/quantization/__init__.py b/python/bitblas/quantization/__init__.py new file mode 100644 index 000000000000..0d50de0ddfe2 --- /dev/null +++ b/python/bitblas/quantization/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from .quantization import ( + _tir_packed_int_to_int_to_float, + _tir_packed_uint_to_uint_to_float, + _tir_packed_to_signed_float, + _tir_packed_to_unsigned_float, +) diff --git a/python/bitblas/quantization/quantization.py b/python/bitblas/quantization/quantization.py new file mode 100644 index 000000000000..8a1bab0b391e --- /dev/null +++ b/python/bitblas/quantization/quantization.py @@ -0,0 +1,148 @@ +# Copyright 2018 The apache/tvm Authors. All Rights Reserved. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# Modifications Copyright (c) Microsoft. +# The code below is mostly copied from mlc.ai quantization.py in mlc-llm. +# pylint: disable=invalid-name,missing-function-docstring,unused-variable +"""TIR computation utilities for quantization.""" + +import tvm +from tvm import tir + +# fmt: off +def _tir_f32x2_to_bf16x2_to_u32(v0: tir.PrimExpr, v1: tir.PrimExpr, round_to_even: bool=True): + mask = tir.const((1 << 16) - 1, "uint32") + res = [] + for data in [v0, v1]: + u32_val = tir.reinterpret("uint32", data) + if round_to_even: + rounding_bias = ((u32_val >> tir.const(16, "uint32")) & tir.const(1, "uint32")) + tir.const(0x7FFF, "uint32") + u32_val += rounding_bias + res.append((u32_val >> tir.const(16, "uint32")) & mask) + return res[0] | (res[1] << tir.const(16, "uint32")) + + +def _tir_u32_to_bf16x2_to_f32x2(x: tir.PrimExpr): + mask = tir.const((1 << 16) - 1, "uint32") + x0 = x & mask + x1 = (x >> 16) & mask + return (tir.reinterpret("float32", x << tir.const(16, "uint32")) for x in [x0, x1]) + + +def _tir_u32_to_int_to_float(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): + assert val.dtype == "uint32" + mask = tvm.tir.const((1 << nbit) - 1, "uint32") + return tir.Cast(dtype, (val >> (pos * nbit).astype("uint32")) & mask) + + +def _tir_packed_uint_to_uint_to_float(storage_nbit: int): + storage_dtype = "uint" + str(storage_nbit) + + def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): + assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" + max_int_value = (1 << (nbit - 1)) - 1 + return ((val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & tir.const((1 << nbit) - 1, "uint32")).astype(dtype) - tir.const(max_int_value, dtype) + + return f_convert + + +def _tir_packed_int_to_int_to_float(storage_nbit: int): + storage_dtype = "int" + str(storage_nbit) + + def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): + assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" + mask = tir.const((1 << nbit) - 1, "int32") + unextended = (val >> (pos.astype("int32") * tir.const(nbit, "int32"))) & mask + return tir.Cast(dtype, (unextended << tir.const(32 - nbit, "int32")) >> tir.const(32 - nbit, "int32")) + + return f_convert + + +def _tir_f32_to_uint_to_f4(val: tir.PrimExpr): + assert val.dtype == "float32" + val_u32 = tir.reinterpret("uint32", val) + # e_f32 > 120 -> e_f4 = min(e_f32 - 120 + M_h, 7) + # e_f32 == 120 -> e_f4 = 1 + # e_f32 < 120 -> e_f4 = 0 + m_h = (val_u32 >> tir.const(22, "uint32")) & tir.const(1, "uint32") + e_f32 = (val_u32 >> tir.const(23, "uint32")) & tir.const(255, "uint32") + s = (val_u32 >> tir.const(31, "uint32")) + e_f4 = tir.Select(e_f32 > tir.const(120, "uint32"), tir.Min(e_f32 - tir.const(120, "uint32") + m_h, tir.const(7, "uint32")), tir.Select(e_f32 == tir.const(120, "uint32"), tir.const(1, "uint32"), tir.const(0, "uint32"))) + return (s << tir.const(3, "uint32")) | e_f4 + + +def _tir_f16_to_uint_to_f4(val: tir.PrimExpr): + assert val.dtype == "float16" + val_u32 = tir.Cast("uint32", tir.reinterpret("uint16", val)) + m_h = (val_u32 >> tir.const(9, "uint32")) & tir.const(1, "uint32") + e_f16 = (val_u32 >> tir.const(10, "uint32")) & tir.const(31, "uint32") + s = (val_u32 >> tir.const(15, "uint32")) + e_f4 = tir.Select(e_f16 > tir.const(8, "uint32"), tir.Min(e_f16 - tir.const(8, "uint32") + m_h, tir.const(7, "uint32")), tir.Select(e_f16 == tir.const(8, "uint32"), tir.const(1, "uint32"), tir.const(0, "uint32"))) + return (s << tir.const(3, "uint32")) | e_f4 + + +def _tir_u32_to_f4_to_f32(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): + assert nbit == 4 + assert dtype == "float32" + assert val.dtype == "uint32" + # e_f4 == 0 -> e_f32 = 0 + # e_f4 != 0 -> e_f32 = e_f4 + 120 = e_f4 | (1111000)_2 + mask = tvm.tir.const((1 << nbit) - 1, "uint32") + f4 = (val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & mask + s = f4 >> tir.const(3, "uint32") + e_f4 = f4 & tir.const(7, "uint32") + e_f32 = e_f4 | tir.const(120, "uint32") + val_f32 = tir.reinterpret("float32", (e_f32 | (s << tir.const(8, "uint32"))) << tir.const(23, "uint32")) + return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float32"), val_f32) + + +def _tir_u32_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): + assert nbit == 4 + assert dtype == "float16" + assert val.dtype == "uint32" + # e_f4 == 0 -> e_f16 = 0 + # e_f4 != 0 -> e_f16 = e_f4 + 8 = e_f4 | (1000)_2 + mask = tvm.tir.const((1 << nbit) - 1, "uint32") + f4 = (val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & mask + s = f4 >> tir.const(3, "uint32") + e_f4 = f4 & tir.const(7, "uint32") + e_f16 = e_f4 | tir.const(8, "uint32") + val_f16 = tir.reinterpret("float16", (e_f16 | (s << tir.const(5, "uint32"))) << tir.const(10, "uint32")) + return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16) + +def _tir_packed_to_signed_float(storage_type="uint", storage_nbit=8): + storage_dtype = storage_type + str(storage_nbit) + + def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): + assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" + max_int_value = (1 << (nbit - 1)) - 1 + return ((val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & tir.const((1 << nbit) - 1, "uint32")).astype(dtype) - tir.const(max_int_value, dtype) + + return f_convert + +def _tir_packed_to_unsigned_float(storage_type="uint", storage_nbit=8): + storage_dtype = storage_type + str(storage_nbit) + + def f_convert( + nbit: int, val: tvm.tir.PrimExpr, pos: tvm.tir.PrimExpr, dtype: str + ): + assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" + mask = tvm.tir.const((1 << nbit) - 1, storage_dtype) + return ((val >> (pos * nbit).astype(storage_dtype)) & mask).astype(dtype) + return f_convert +# fmt: on diff --git a/python/bitblas/quantization/utils.py b/python/bitblas/quantization/utils.py new file mode 100644 index 000000000000..4ac3a60a13a7 --- /dev/null +++ b/python/bitblas/quantization/utils.py @@ -0,0 +1,64 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import numpy as np + + +def general_compress(lowprecision_weight, source_bits=4, storage_dtype=np.int8): + elems_per_byte = 8 // source_bits + if lowprecision_weight.dtype == np.float16: + lowprecision_weight = lowprecision_weight.astype(dtype=np.int8) + int8_weight = np.zeros( + ( + *lowprecision_weight.shape[:-1], + lowprecision_weight.shape[-1] // elems_per_byte, + ), + dtype=np.int8, + ) + for j in range(lowprecision_weight.shape[-1] // elems_per_byte): + for k in range(elems_per_byte): + int8_weight[:, j] |= lowprecision_weight[:, j * elems_per_byte + k] << ( + source_bits * k + ) + + return int8_weight.view(storage_dtype) + + +# interleave weight numpy implementation +def interleave_weight(qweight, nbits=4, target_dtype="float16"): + assert target_dtype in ["float16", "int8"] + # reinterpret the data type of qweight to int32 + qweight = qweight.view(np.int32) + new_qweight = np.zeros_like(qweight) + bits_stride = 8 if target_dtype == "int8" else 16 + mask = (1 << nbits) - 1 # for 4bit the val is 0x0000000f + num_groups = 32 // bits_stride + elems_per_group = bits_stride // nbits + for i in range(num_groups): + for j in range(elems_per_group): + offset = i * elems_per_group + j + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * nbits + new_qweight |= ((qweight >> (nbits * offset)) & mask) << shift + + if nbits == 1 and target_dtype == "int8": + # special handling for 1b interleave + n16_weight = new_qweight & np.int32(0xF0F00F0F) + n16_weight |= ((new_qweight & np.int32(0x000000F0)) >> 4) << 16 + n16_weight |= ((new_qweight & np.int32(0x0000F000)) >> 12) << 24 + n16_weight |= ((new_qweight & np.int32(0x000F0000)) >> 16) << 4 + n16_weight |= ((new_qweight & np.int32(0x0F000000)) >> 24) << 12 + return n16_weight.view(np.int8) + elif nbits == 2 and target_dtype == "float16": + n8_weight = new_qweight & np.int32(0xFF0000FF) + n8_weight |= ((new_qweight & np.int32(0x0000FF00)) >> 8) << 16 + n8_weight |= ((new_qweight & np.int32(0x00FF0000)) >> 16) << 8 + return n8_weight.view(np.int8) + elif nbits == 1 and target_dtype == "float16": + n8_weight = new_qweight & 0xF000000F + n8_weight |= ((new_qweight & 0x000000F0) >> 4) << 8 + n8_weight |= ((new_qweight & 0x00000F00) >> 8) << 16 + n8_weight |= ((new_qweight & 0x0000F000) >> 12) << 24 + n8_weight |= ((new_qweight & 0x000F0000) >> 16) << 4 + n8_weight |= ((new_qweight & 0x00F00000) >> 20) << 12 + n8_weight |= ((new_qweight & 0x0F000000) >> 24) << 20 + + return new_qweight.view(np.int8) diff --git a/python/bitblas/relax/op/interleave_weight.py b/python/bitblas/relax/op/interleave_weight.py new file mode 100644 index 000000000000..98b1f5cd4f30 --- /dev/null +++ b/python/bitblas/relax/op/interleave_weight.py @@ -0,0 +1,23 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from tvm.relax.block_builder import BlockBuilder +from tvm.relax.expr import Call, Expr +from tvm.relax.transform.legalize_ops.common import register_legalize + +from bitblas.ops.impl import tir_interleave_weight + + +@register_legalize("bitblas.interleave_weight") +def _interleave_weight(bb: BlockBuilder, call: Call) -> Expr: + nbits = call.attrs.nbits + target_dtype = call.attrs.target_dtype + out_dtype = call.attrs.out_dtype + + return bb.call_te( + tir_interleave_weight(nbits, target_dtype, out_dtype), + call.args[0], + primfunc_name_hint="interleave_weight", + ) + + +__all__ = ["_interleave_weight"] diff --git a/python/bitblas/relax/transform/__init__.py b/python/bitblas/relax/transform/__init__.py new file mode 100644 index 000000000000..b92f2c0b411a --- /dev/null +++ b/python/bitblas/relax/transform/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from .annotate_decode_block import AnnotateDecodeInformation +from .weight_only_propagate import WeightOnlyLayoutPropagation diff --git a/python/bitblas/relax/transform/annotate_decode_block.py b/python/bitblas/relax/transform/annotate_decode_block.py new file mode 100644 index 000000000000..c08f55db66d0 --- /dev/null +++ b/python/bitblas/relax/transform/annotate_decode_block.py @@ -0,0 +1,131 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from typing import Dict, Tuple +from tvm.ir import IRModule +from tvm.ir.transform import PassContext, module_pass +from tvm import tir +from tvm.tir.schedule import BlockRV +from mlc_llm.quantization import quantization_schemes, GroupQuantizationSpec +from bitblas.gpu.gemv import is_gemv +from bitblas.gpu.matmul_analysis import ( + get_reduction_blocks, + get_index_map, + get_root_block, + get_dequantize_block, +) +from bitblas.base import ( + normalize_prim_func, + try_inline_contiguous_spatial, +) + + +# Define a module pass to annotate dequantization information +@module_pass(opt_level=0, name="AnnotateDecodeInformation") +class AnnotateDecodeInformation: + def __init__(self, spec: str = "q4f16_0"): + # Validate and store the specified quantization scheme + if spec not in quantization_schemes: + raise ValueError(f"Quantization scheme {spec} not found") + self.quantize_scheme = quantization_schemes[spec] + + def detect_matmul(self, func: tir.PrimFunc) -> bool: + """Detect if the given function represents a matrix multiplication.""" + sch = tir.Schedule(func) + root_block = get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + + # Identify reduction blocks to infer matmul operations + reduction_blocks = get_reduction_blocks(sch, blocks) + if not reduction_blocks: + return False + + # Check for index map patterns typical of matmul operations + main_block = reduction_blocks[0] + main_block_stmt = sch.get(main_block) + index_maps = get_index_map(main_block_stmt) + _is_matmul = index_maps is not None + + block_infos = normalize_prim_func(sch) + block_infos = try_inline_contiguous_spatial(sch, block_infos) + block_info = block_infos[0] + _is_gemv = True + if len(block_info.iters) not in [2, 3]: + # either [B, S, R] = [B, S, R] * [B, R] + # or [S, R] = [S, R] * [R] + _is_gemv = False + if _is_gemv: + _is_gemv = is_gemv(sch, block_info) + return _is_matmul or _is_gemv + + def transform_module(self, mod: IRModule, _: PassContext) -> IRModule: + """Annotate dequantize information for all applicable functions in the module.""" + for g_var, func in mod.functions.items(): + if not isinstance(func, tir.PrimFunc) or g_var.name_hint == "main": + continue + + if not self.detect_matmul(func): + continue # Process only if matmul is detected + + sch = tir.Schedule(func) + root_block = get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + dequantize_block = get_dequantize_block(sch, blocks) + if dequantize_block is None: + continue # Skip if no dequantize block is found + + # Prepare dequantize info annotation + dequantize_info = self.prepare_dequantize_info(sch, dequantize_block) + + # Annotate function with dequantize information + mod[g_var] = func.with_attr("dequantize_info", dequantize_info) + return mod + + def prepare_dequantize_info( + self, sch: tir.Schedule, dequantize_block: BlockRV + ) -> Dict: + """Generate dequantize information for a given block.""" + block_stmt = sch.get(dequantize_block) + block_name = block_stmt.name_hint + dequantize_info = { + block_name: {"decode_block": block_name, "fast_decoding": False} + } + + quantize_spec = self.quantize_scheme.linear_weight + if isinstance(quantize_spec, GroupQuantizationSpec): + dequantize_info[block_name].update( + { + "with_scaling": True, + "group_size": quantize_spec.group_size, + } + ) + + # Determine source format based on quantization mode + quantize_mod = quantize_spec.mode + bits, source_format = self.parse_quantize_mode(quantize_mod) + dequantize_info[block_name]["source_format"] = { + "bits": bits, + "format": source_format, + } + + # Set storage and target data types + storage_dtype = self.get_storage_dtype(block_stmt, source_format) + dequantize_info[block_name]["storage_dtype"] = storage_dtype + dequantize_info[block_name]["target_format"] = quantize_spec.dtype + + return dequantize_info + + def parse_quantize_mode(self, quantize_mod: str) -> Tuple[int, str]: + """Extract bits and format from quantization mode.""" + if quantize_mod.startswith("int"): + return int(quantize_mod[3:]), "int" + elif quantize_mod.startswith("uint"): + return int(quantize_mod[4:]), "uint" + raise ValueError(f"Unsupported mode {quantize_mod}") + + def get_storage_dtype(self, block_stmt: BlockRV, source_format: str) -> str: + """Determine storage data type based on source format.""" + return ( + block_stmt.reads[0].buffer.dtype + if "af" not in source_format + else block_stmt.reads[1].buffer.dtype + ) diff --git a/python/bitblas/relax/transform/weight_only_propagate.py b/python/bitblas/relax/transform/weight_only_propagate.py new file mode 100644 index 000000000000..309068f4094f --- /dev/null +++ b/python/bitblas/relax/transform/weight_only_propagate.py @@ -0,0 +1,463 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from typing import Optional, Tuple, Union, List, Dict +from tvm.ir import IRModule +from tvm.ir.transform import PassContext, module_pass +from tvm import relax +from tvm import tir +from enum import Enum +from tvm.ir import GlobalVar +from tvm.tir import IndexMap +from tvm.target import Target +from tvm.tir import IterVar +from tvm.tir.schedule.schedule import BlockRV +from tvm.relax import PyExprMutator +from tvm.relax.expr import Call +from bitblas.gpu.matmul_analysis import ( + get_tensorized_func_and_tags, + get_propagate_map, + find_last_producer_from_buffer, + find_arg_idx_from_buffer_chain, + layout_propagate_chain, +) +from tvm.dlight.base import ( + analysis, +) +from dataclasses import dataclass + + +def get_reduction_blocks(sch, blocks) -> bool: + # Get the main computation block + def is_reduction(block: BlockRV) -> bool: + block_stmt = sch.get(block) + iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars} + return iter_types == {IterVar.CommReduce, IterVar.DataPar} + + def is_spatial(block: BlockRV) -> bool: + block_stmt = sch.get(block) + iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars} + return iter_types == {IterVar.DataPar} + + # NOTE: We assume there is only one reduction block in the function + # all blocks are required to be spatial or reduction + if not all([is_reduction(block) or is_spatial(block) for block in blocks]): + return None + + # There is only one reduction block + reduction_blocks = [block for block in blocks if is_reduction(block)] + if len(reduction_blocks) != 1: + return None + + return reduction_blocks + + +class TransformKind(Enum): + NonTransform = 0 + InterWarpTransform = 1 + IntraWarpTransform = 2 + + +def check_sm_version(arch: str) -> int: + sm_version = arch.replace("sm_", "") + return int(sm_version) if sm_version.isdigit() else -1 + + +def get_in_out_dtypes(block: tir.Block) -> Tuple[str]: + """ + Detect In/Out data types for the given block based on the analysis if read/write buffers. + """ + assert len(block.reads) > 0 and len(block.writes) > 0 + in_dtype = block.reads[0].buffer.dtype + out_dtype = block.writes[0].buffer.dtype + return (in_dtype, out_dtype) + + +@dataclass +class LayoutTransformHint: + """ + A dataclass to store the layout transformation hint. + """ + + transform_level: TransformKind + inter_warp_layout: IndexMap + intra_warp_layout: IndexMap + apply_arg_idx: int + + +@module_pass(opt_level=0, name="InsertLayoutTransform") +class WeightOnlyLayoutPropagation: + def __init__( + self, + transform_level: Union[int, TransformKind] = TransformKind.InterWarpTransform, + target: Optional[Target] = None, + faster_conversion: bool = False, + ) -> None: + if isinstance(transform_level, int): + transform_level = TransformKind(transform_level) + assert transform_level in [ + TransformKind.NonTransform, + TransformKind.InterWarpTransform, + TransformKind.IntraWarpTransform, + ] + # transform_level 1: only transform the inter-warp memory layout + # transform_level 2: transform the inter-warp memory layout and the intra-warp memory layout + self.transform_level = transform_level + self.target = Target.current() if target is None else target + # fast type conversion on nvidia gpu also requires weight permutation + self.faster_conversion = faster_conversion + # layout transform info to sync the layout in both graph and tir + self.layout_transform_hints: Dict[str, List[LayoutTransformHint]] = {} + + def detect_propagate_matmul(self, func: tir.PrimFunc, target: Target): + _, tags = get_tensorized_func_and_tags( + func, target, skip_normalize=True, allow_gemv=True + ) + if tags is None: + return False, None + return True, tags["intrin_info"] + + def transform_matmul(self, g_var: GlobalVar, func: tir.PrimFunc, intrin_info): + from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel + get_mma_intrin_group, + ) + + sch = tir.Schedule(func) + root_block = analysis.get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + + reduction_blocks = get_reduction_blocks(sch, blocks) + if reduction_blocks is None or len(reduction_blocks) != 1: + return False + (main_block,) = reduction_blocks + + intrin_group = get_mma_intrin_group( + load_scope="shared", + store_scope="shared", + in_dtype=intrin_info["in_dtype"], + out_dtype=intrin_info["out_dtype"], + trans_a=False, + trans_b=intrin_info["trans_b"], + ) + + _, inter_j, inter_k = intrin_group["micro_kernel"] + + # weight only propagation + target_scope = ("read", 1) + weight_buffer = sch.get(main_block).reads[1].buffer + + # checkout whether the weight buffer has dynamic symbol + def check_dynamic_symbol(buffer): + for axis in buffer.shape: + if isinstance(axis, tir.Var): + return True + return False + + if check_dynamic_symbol(weight_buffer): + print( + "[BitBLAS] Weight buffer has dynamic symbol, skip weight propagation." + ) + return False + + transformed_block = find_last_producer_from_buffer( + sch, main_block, weight_buffer + ) + if transformed_block is None: + return False + if transformed_block != main_block: + target_scope = ("read", 0) + + reindex_block = sch.cache_read(transformed_block, target_scope[1], "global") + + # create inter-warp memory layout index map + inter_warp_layout = IndexMap.from_func( + lambda i, j: (i // inter_j, j // inter_k, i % inter_j, j % inter_k) + ) + + inter_warp_layout = layout_propagate_chain( + sch, + main_block, + sch.get(main_block).reads[1].buffer, + reindex_block, + inter_warp_layout, + ) + + sch.transform_layout( + reindex_block, + ("read", 0), + lambda i, j: inter_warp_layout.map_indices([i, j]), + ) + arg_idx = find_arg_idx_from_buffer_chain( + sch, reindex_block, sch.get(reindex_block).reads[0].buffer + ) + + intra_warp_layout = None + if self.transform_level.value >= TransformKind.IntraWarpTransform.value: + intra_warp_layout, _ = get_propagate_map(intrin_info["trans_b"]) + intra_warp_layout = layout_propagate_chain( + sch, + main_block, + sch.get(main_block).reads[1].buffer, + reindex_block, + intra_warp_layout, + ) + sch.transform_layout( + reindex_block, + ("read", 0), + lambda i, j, ii, jj: ( + i, + j, + *intra_warp_layout.map_indices([ii, jj]), + ), + ) + + self.layout_transform_hints[g_var] = [ + LayoutTransformHint( + transform_level=self.transform_level, + inter_warp_layout=inter_warp_layout, + intra_warp_layout=intra_warp_layout, + apply_arg_idx=arg_idx, + ) + ] + + return sch.mod["main"] + + def transform_module( # pylint: disable=missing-function-docstring + self, + mod: IRModule, + _: PassContext, + ) -> IRModule: + if self.target.kind.name != "cuda": + # currently weight propagation only support nvidia gpus + return mod + + propogate_candidates = {} + propogated_funcs = {} # some funcs may not be able to transform + candidates_intrin_info = {} + decoded_funcs = {} + for g_var, func in mod.functions_items(): + if not isinstance(func, tir.PrimFunc): + continue + if g_var.name_hint != "main": + # Note: this can be applied to any function which can be transformed to matmul (e.g., conv2d) + # for mlc we only consider matmul + # detect the pattern + is_matmul, intrin_info = self.detect_propagate_matmul(func, self.target) + + if ( + func.attrs is not None + and "dlight.do_not_tensorize" in func.attrs.keys() + ): + # currently we only support tensorize propagation + continue + + if is_matmul: + if "dequantize_info" in func.attrs: + decoded_funcs[g_var] = func + if self.transform_level != TransformKind.NonTransform: + # lift tags to the function as it has intrinsic information that can be reused. + propogate_candidates[g_var] = func + candidates_intrin_info[g_var] = intrin_info + + for g_var, func in propogate_candidates.items(): + updated_func = self.transform_matmul( + g_var, func, candidates_intrin_info[g_var] + ) + if updated_func: + updated_func = updated_func.with_attrs( + { + "transform_kind": self.transform_level.value, + "smooth_b": True, + } + ) + propogated_funcs[g_var] = updated_func + mod[g_var] = updated_func + + @relax.expr_functor.mutator + class TensorCoreLayoutMutator(PyExprMutator): + """Mutator that performs transformation.""" + + def __init__( + self, + transform_level: TransformKind = TransformKind.NonTransform, + layout_transform_hints: Dict[str, List[LayoutTransformHint]] = {}, + ): + super().__init__() + self.transform_level = transform_level + self.layout_transform_hints = layout_transform_hints + + def tc_layout_transform(self, call_node: Call) -> Call: + if self.transform_level == TransformKind.NonTransform: + return super().visit_call_(call_node) + g_var = call_node.args[0] + if g_var not in propogated_funcs.keys(): + return super().visit_call_(call_node) + args = list(call_node.args[1]) + # assume we only have weight propagation currently + (weight_layout_hint,) = self.layout_transform_hints[g_var] + weight = args[weight_layout_hint.apply_arg_idx] + weight = self.builder_.emit( + relax.op.layout_transform( + weight, + index_map=lambda i, j: weight_layout_hint.inter_warp_layout.map_indices( + [i, j] + ), + ) + ) + if self.transform_level.value >= TransformKind.IntraWarpTransform.value: + weight = self.builder_.emit( + relax.op.layout_transform( + weight, + index_map=lambda i, j, ii, jj: ( + i, + j, + *weight_layout_hint.intra_warp_layout.map_indices( + [ii, jj] + ), + ), + ) + ) + + call_node = self.builder_.emit( + relax.call_tir( + g_var, + args[: weight_layout_hint.apply_arg_idx] + + [weight] + + args[weight_layout_hint.apply_arg_idx + 1 :], + out_sinfo=call_node.struct_info, + ) + ) + return call_node + + def visit_call_(self, call_node: Call): + return self.tc_layout_transform(call_node) + + def transform( + self, + mod: IRModule, + ): + for gv, func in mod.functions_items(): + if isinstance(func, relax.Function): + updated_func = self.visit_expr(func) + self.builder_.update_func(gv, updated_func) + new_mod = self.builder_.get() + new_mod = new_mod.with_attrs(mod.attrs) if mod.attrs else new_mod + for gv, func in new_mod.functions_items(): + mod.update_func(gv, func) + return mod + + mod = TensorCoreLayoutMutator( + transform_level=self.transform_level, + layout_transform_hints=self.layout_transform_hints, + ).transform(mod) + + @relax.expr_functor.mutator + class FastTypeConversionLayoutMutator(PyExprMutator): + """Mutator that performs transformation.""" + + def __init__(self, faster_conversion: bool = False): + super().__init__() + self.faster_conversion = faster_conversion + + def lop3_layout_transform(self, call_node: Call) -> Call: + if not self.faster_conversion: + return super().visit_call_(call_node) + + from bitblas.ops.impl import tir_interleave_weight + + g_var = call_node.args[0] + if g_var not in decoded_funcs.keys(): + return super().visit_call_(call_node) + + args = list(call_node.args[1]) + func = decoded_funcs[g_var] + if "dequantize_info" not in func.attrs: + return super().visit_call_(call_node) + dequantize_info = dict(func.attrs["dequantize_info"]) + assert len(dequantize_info) == 1 + (weight_dequantize_info,) = dequantize_info.values() + + sch = tir.Schedule(func) + dequantize_block = sch.get_block(weight_dequantize_info["decode_block"]) + + # weight is the first read buffer if format in ["int", "uint"], otherwise the second read buffer, af .etc + source_format = weight_dequantize_info["source_format"]["format"] + source_bits = weight_dequantize_info["source_format"]["bits"] + target_dtype = weight_dequantize_info["target_format"] + + if source_format in ["int", "uint"]: + weight_buffer = sch.get(dequantize_block).reads[0].buffer + elif source_format in ["af"]: + weight_buffer = sch.get(dequantize_block).reads[1].buffer + else: + raise ValueError(f"Unsupported source format {source_format}") + + # update func with dequantize_info + dequantize_info["fast_decoding"] = True + self.builder_.update_func( + g_var, func.with_attrs({"dequantize_info": dequantize_info}) + ) + + weight_idx = find_arg_idx_from_buffer_chain( + sch, dequantize_block, weight_buffer + ) + weight = args[weight_idx] + + weight_shape = weight_buffer.shape + # reshape the weight shape to 2d + reshape_weight = self.builder_.emit( + relax.op.reshape(weight, (-1, weight_shape[-1])) + ) + # register g_var to the func + lop3_interleave_func = tir_interleave_weight( + N=reshape_weight.struct_info.shape[0], + QK=reshape_weight.struct_info.shape[1], + bits=source_bits, + target_dtype=target_dtype, + storage_dtype=reshape_weight.struct_info.dtype, + ) + interleave_gvar = self.builder_.add_func( + lop3_interleave_func.without_attr("global_symbol"), + "tir_interleave_weight", + ) + lop3_interleave_weight = self.builder_.emit( + relax.call_tir( + interleave_gvar, + [reshape_weight], + out_sinfo=reshape_weight.struct_info, + ), + ) + reshape_weight = self.builder_.emit( + relax.op.reshape(lop3_interleave_weight, weight_shape) + ) + call_node = self.builder_.emit( + relax.call_tir( + g_var, + args[:weight_idx] + [reshape_weight] + args[weight_idx + 1 :], + out_sinfo=call_node.struct_info, + ), + ) + + return call_node + + def visit_call_(self, call_node: Call): + return self.lop3_layout_transform(call_node) + + def transform( + self, + mod: IRModule, + ): + for gv, func in mod.functions_items(): + if isinstance(func, relax.Function): + updated_func = self.visit_expr(func) + self.builder_.update_func(gv, updated_func) + new_mod = self.builder_.get() + new_mod = new_mod.with_attrs(mod.attrs) if mod.attrs else new_mod + for gv, func in new_mod.functions_items(): + mod.update_func(gv, func) + return mod + + mod = FastTypeConversionLayoutMutator( + faster_conversion=self.faster_conversion + ).transform(mod) + mod = relax.transform.LegalizeOps()(mod) + return mod diff --git a/python/bitblas/testing/__init__.py b/python/bitblas/testing/__init__.py new file mode 100644 index 000000000000..0240eb69e0eb --- /dev/null +++ b/python/bitblas/testing/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import sys +import inspect +import pytest + +# pytest.main() wrapper to allow running single test file +def main(): + test_file = inspect.getsourcefile(sys._getframe(1)) + sys.exit(pytest.main([test_file] + sys.argv[1:])) diff --git a/python/bitblas/utils/__init__.py b/python/bitblas/utils/__init__.py new file mode 100644 index 000000000000..b7d5b2f42d36 --- /dev/null +++ b/python/bitblas/utils/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from .tensor_adapter import tvm_tensor_to_torch +import re + +def match_global_kernel(source: str) -> int: + pattern = r"__global__\s+void\s+[__launch_bounds__\(\d+\)\s+]\w+" + matched = re.findall(pattern, source) + assert len(matched) > 1 # may have statement before kernel + return source.index(matched[0]) diff --git a/python/bitblas/utils/tensor_adapter.py b/python/bitblas/utils/tensor_adapter.py new file mode 100644 index 000000000000..7708cc40c0bb --- /dev/null +++ b/python/bitblas/utils/tensor_adapter.py @@ -0,0 +1,16 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import tvm +from typing import Union + + +def tvm_tensor_to_torch(tensor: Union[tvm.te.Tensor, tvm.nd.NDArray]): + import torch + from torch.utils.dlpack import from_dlpack + + if isinstance(tensor, tvm.te.Tensor): + return torch.from_numpy(tensor.numpy()) + elif isinstance(tensor, tvm.nd.NDArray): + return from_dlpack(tensor) + else: + raise RuntimeError("Not supported type: ", type(tensor)) diff --git a/python/bitblas_cli.py b/python/bitblas_cli.py new file mode 100644 index 000000000000..59e481eb93dd --- /dev/null +++ b/python/bitblas_cli.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. diff --git a/testing/cpp/.gitignore b/testing/cpp/.gitignore new file mode 100644 index 000000000000..f65b0cab7dcb --- /dev/null +++ b/testing/cpp/.gitignore @@ -0,0 +1,2 @@ +# ignore the build directory +build/ diff --git a/testing/cpp/CMakeLists.txt b/testing/cpp/CMakeLists.txt new file mode 100644 index 000000000000..cf8eb0d3a1dc --- /dev/null +++ b/testing/cpp/CMakeLists.txt @@ -0,0 +1,15 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +cmake_minimum_required(VERSION 3.18 FATAL_ERROR) + +project(YourProjectTests LANGUAGES CXX CUDA) + +# Set the C++ standard to C++17 +set(CMAKE_CXX_STANDARD 17) + +# Find GTest +find_package(GTest REQUIRED) + +include_directories(${GTEST_INCLUDE_DIRS}) + +add_subdirectory(lop3_type_conversion) diff --git a/testing/cpp/lop3_type_conversion/CMakeLists.txt b/testing/cpp/lop3_type_conversion/CMakeLists.txt new file mode 100644 index 000000000000..61903faf4aa3 --- /dev/null +++ b/testing/cpp/lop3_type_conversion/CMakeLists.txt @@ -0,0 +1,12 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +function (ADD_CUDA_TEST_EXECUTABLE name) + add_executable(${name} ${name}.cu) + set_target_properties(${name} PROPERTIES CUDA_ARCHITECTURES 60) + set_target_properties(${name} PROPERTIES + CUDA_SEPARABLE_COMPILATION ON) + target_link_libraries(${name} gtest gtest_main) +endfunction(ADD_CUDA_TEST_EXECUTABLE) + +ADD_CUDA_TEST_EXECUTABLE(lowprecision_to_float16) +ADD_CUDA_TEST_EXECUTABLE(lowprecision_to_int8) diff --git a/testing/cpp/lop3_type_conversion/fast_decoding.hpp b/testing/cpp/lop3_type_conversion/fast_decoding.hpp new file mode 100644 index 000000000000..627916a0a5eb --- /dev/null +++ b/testing/cpp/lop3_type_conversion/fast_decoding.hpp @@ -0,0 +1,441 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +#include +#include + +// Pack two half values. +static inline __device__ __host__ unsigned +__pack_half2(const half x, const half y) +{ + unsigned v0 = *((unsigned short *)&x); + unsigned v1 = *((unsigned short *)&y); + return (v1 << 16) | v0; +} + +void general_compress(const int8_t *lowbit, int8_t *compressed, const int nbit, const int N, bool isSigned = false) +{ + int zero_point = isSigned ? ((1 << (nbit - 1)) - 1) : 0; + const int nbit_per_byte = 8 / nbit; + + for (int i = 0; i < N / nbit_per_byte; i++) + { + compressed[i] = 0; + for (int j = 0; j < nbit_per_byte; j++) + { + compressed[i] |= ((lowbit[nbit_per_byte * i + j] + zero_point) << (nbit * j)); + } + } +} + +void general_interleave_fp16(int8_t *origin_arr, int8_t *interleaved, const int nbit, size_t size_in_bytes, bool verbose = false) +{ + // For fp16 example + // i4s {e7,e6,e5,e4,e3,e2,e1,e0} + // |-8b-||-8b-||-8b-||-8b-| + // interleave {e7,e5,e3,e1,e6,e4,e2,e0} + /* + BOTTOM_MASK 0 0 0 f 0 0 0 f + i4s e7 e5 e3 e1 e6 e4 e2 e0 + selectedVal 0000 0000 0000 e1 0000 0000 0000 e0 // selectedVal = i4s & BOTTOM_MASK + h[0] 0110 0100 0 e1 0110 0100 0 e0 // selectVal | 0x6400 + */ + // i2s {e15,e14,e13,e12,e11,e10,e9,e8,e7,e6,e5,e4,e3,e2,e1,e0} + // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0} + // i1s {e31,e30,e29,e28,e27,e26,e25,e24,e23,e22,e21,e20,e19,e18,e17,e16,e15,e14,e13,e12,e11,e10,e9,e8,e7,e6,e5,e4,e3,e2,e1,e0} + // interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0} + // Assuming size is the number of int32 elements in origin_arr + size_t size = size_in_bytes / sizeof(int32_t); + int32_t *int32_origin = (int32_t *)origin_arr; + int32_t *int32_interleaved = (int32_t *)interleaved; + + int mask = (1 << nbit) - 1; + int num_groups = (32 / nbit) / 2; + + for (int idx = 0; idx < size; ++idx) + { + int32_t current_value = int32_origin[idx]; + int32_t new_value = 0; + + for (int i = 0; i < num_groups; ++i) + { + int left_shift = nbit * i; + int right_shift = nbit * (num_groups - i - 1); + new_value |= (current_value & (mask << nbit * (2 * i))) >> left_shift; + new_value |= (current_value & (mask << nbit * (2 * i + 1))) << right_shift; + if (verbose) + { + printf("put %d to %d\n", (2 * i), (nbit * (2 * i) - left_shift) / nbit); + printf("put %d to %d\n", (2 * i + 1), (nbit * (2 * i + 1) + right_shift) / nbit); + } + } + if (nbit == 2) + { + int32_t _new_value_n16 = (new_value & 0xff0000ff); + _new_value_n16 |= ((new_value & 0x0000ff00) >> 8) << 16; + _new_value_n16 |= ((new_value & 0x00ff0000) >> 16) << 8; + int32_interleaved[idx] = _new_value_n16; + } + else if (nbit == 1) + { + int32_t _new_value_n16 = (new_value & 0xf000000f); + _new_value_n16 |= ((new_value & 0x000000f0) >> 4) << 8; + _new_value_n16 |= ((new_value & 0x00000f00) >> 8) << 16; + _new_value_n16 |= ((new_value & 0x0000f000) >> 12) << 24; + _new_value_n16 |= ((new_value & 0x000f0000) >> 16) << 4; + _new_value_n16 |= ((new_value & 0x00f00000) >> 20) << 12; + _new_value_n16 |= ((new_value & 0x0f000000) >> 24) << 20; + int32_interleaved[idx] = _new_value_n16; + } + else + int32_interleaved[idx] = new_value; + } + + // Convert back to int8_t if needed + memcpy(interleaved, int32_interleaved, size * sizeof(int32_t)); +} + +template +__device__ void decode_i4b_to_f16(T1 *_i4s, T2 *B_local_decode, const int N = 8, const half *scale = nullptr) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x000f000f; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + // Minus 7 to scale the value to signed + static constexpr uint MEDIAN_NUM = isSigned ? 0x64076407 : 0x64006400; + uint const i4s = *reinterpret_cast(_i4s); +#pragma unroll + // decode 2 elems at one time. + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[i]) + : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + if constexpr (withScaling) + { + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*scale, *scale)), "r"(0)); + } + } +} + +template +__device__ void decode_i4s_to_f16(T1 *_i4s, T2 *B_local_decode, const int N = 8) +{ + decode_i4b_to_f16(_i4s, B_local_decode, N); +} + +template +__device__ void decode_i4u_to_f16(T1 *_i4u, T2 *B_local_decode, const int N = 8) +{ + decode_i4b_to_f16(_i4u, B_local_decode, N); +} + +template +__device__ void decode_i4s_to_f16_scale(T1 *_i4s, T2 *B_local_decode, half *scale = nullptr, const int N = 8) +{ + decode_i4b_to_f16(_i4s, B_local_decode, N, scale); +} + +template +__device__ void decode_i4u_to_f16_scale(T1 *_i4u, T2 *B_local_decode, half *scale = nullptr, const int N = 8) +{ + decode_i4b_to_f16(_i4u, B_local_decode, N, scale); +} + +template +__device__ void decode_i2b_to_f16(T1 *_i2s, T2 *B_local_decode, const int N = 8, half *scale = nullptr) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00030003; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 0x64016401 : 0x64006400; + int16_t const i2s_i16 = *reinterpret_cast(_i2s); + // decode 2 elems at one time. + // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0} + // otherwise the pointer of _i2s should be moved to + int i2s = (i2s_i16 & 0x00ff); + i2s |= ((i2s_i16 & 0xff00) << 8); + +#pragma unroll + for (int i = 0; i < (N / 2); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[i]) + : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + if constexpr (withScaling) + { + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*scale, *scale)), "r"(0)); + } + } +} + +template +__device__ void decode_i2s_to_f16(T1 *_i2s, T2 *B_local_decode, const int N = 8) +{ + decode_i2b_to_f16(_i2s, B_local_decode, N); +} + +template +__device__ void decode_i2u_to_f16(T1 *_i2u, T2 *B_local_decode, const int N = 8) +{ + decode_i2b_to_f16(_i2u, B_local_decode, N); +} + +template +__device__ void decode_i2s_to_f16_scale(T1 *_i2s, T2 *B_local_decode, half *scale = nullptr, const int N = 8) +{ + decode_i2b_to_f16(_i2s, B_local_decode, N, scale); +} + +template +__device__ void decode_i2u_to_f16_scale(T1 *_i2u, T2 *B_local_decode, half *scale = nullptr, const int N = 8) +{ + decode_i2b_to_f16(_i2u, B_local_decode, N, scale); +} + +template +__device__ void decode_i1b_to_f16(T1 *_i1s, T2 *B_local_decode, const int N = 8, half *scale = nullptr) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00010001; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 0x64006400 : 0x64006400; + // interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode e7,e5,e3,e1,e8,e6,e4,e2,e0 + int8_t const i1s_i16 = *reinterpret_cast(_i1s); + int i1s = (i1s_i16 & 0x0f); + i1s |= ((i1s_i16 & 0xf0) << 12); +#pragma unroll + // decode 2 elems at one time. + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[i]) + : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + if constexpr (withScaling) + { + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*scale, *scale)), "r"(0)); + } + } +} + +template +__device__ void decode_i1s_to_f16(T1 *_i1s, T2 *B_local_decode, const int N = 8) +{ + decode_i1b_to_f16(_i1s, B_local_decode, N); +} + +template +__device__ void decode_i1u_to_f16(T1 *_i1u, T2 *B_local_decode, const int N = 8) +{ + decode_i1b_to_f16(_i1u, B_local_decode, N); +} + +template +__device__ void decode_i1s_to_f16_scale(T1 *_i1s, T2 *B_local_decode, half *scale = nullptr, const int N = 8) +{ + decode_i1b_to_f16(_i1s, B_local_decode, N, scale); +} + +template +__device__ void decode_i1u_to_f16_scale(T1 *_i1u, T2 *B_local_decode, half *scale = nullptr, const int N = 8) +{ + decode_i1b_to_f16(_i1u, B_local_decode, N, scale); +} + +void general_interleave_int8(int8_t *origin_arr, int8_t *interleaved, const int nbit, size_t size_in_bytes, bool verbose = false) +{ + // For fp16 example + // i4s {e7,e6,e5,e4,e3,e2,e1,e0} + // |-8b-||-8b-||-8b-||-8b-| + // interleave {e7,e3,e6,e2,e5,e1,e4,e0} + /* + BOTTOM_MASK 0 0 0 f 0 0 0 f + i4s e7 e3 e6 e2 e5 e1 e4 e0 + selectedVal 0000 e3 0000 e2 0000 e1 0000 e0 // selectedVal = i4s & BOTTOM_MASK + s[0] 0 e3 0 e2 0 e1 0 e0 + */ + + // |-----8b-------||-------8b----||----8b---||-----8b----| + // i2s {e15,e14,e13,e12,e11,e10,e9,e8,e7,e6,e5,e4,e3,e2,e1,e0} + // interleave {e15,e11,e7,e3,e14,e10,e6,e2,e13,e9,e5,e1,e12,e8,e4,e0} + + // |-------------8b----------------||--------------8b--------------||------------8b--------------||--------8b-----------| + // i1s {e31,e30,e29,e28,e27,e26,e25,e24,e23,e22,e21,e20,e19,e18,e17,e16,e15,e14,e13,e12,e11,e10,e9,e8,e7,e6,e5,e4,e3,e2,e1,e0} + // interleave {e31,e27,e23,e19,e15,e11,e7,e3,e30,e26,e22,e18,e14,e10,e6,e2,e29,e25,e21,e17,e13,e9,e5,e1,e28,e24,e20,e16,e12,e8,e4,e0} + // Assuming size is the number of int32 elements in origin_arr + size_t size = size_in_bytes / sizeof(int32_t); + int32_t *int32_origin = (int32_t *)origin_arr; + int32_t *int32_interleaved = (int32_t *)interleaved; + + constexpr int bits_stride = 8; + int elems_per_group = bits_stride / nbit; + int mask = (1 << nbit) - 1; + int num_groups = 32 / bits_stride; + + for (int idx = 0; idx < size; ++idx) + { + int32_t current_value = int32_origin[idx]; + int32_t new_value = 0; + for (int i = 0; i < num_groups; ++i) + { + for (int j = 0; j < elems_per_group; ++j) + { + int offset = i * elems_per_group + j; + int shift = (offset % num_groups) * bits_stride + (offset / num_groups) * nbit; + int group_value = (current_value >> (nbit * (i * elems_per_group + j))) & mask; + new_value |= group_value << shift; + if (verbose) + printf("put %d to %d\n", offset, shift); + } + } + if (nbit == 1) + { + int32_t _new_value_n16 = (new_value & 0xf0f00f0f); + _new_value_n16 |= ((new_value & 0x000000f0) >> 4) << 16; + _new_value_n16 |= ((new_value & 0x0000f000) >> 12) << 24; + _new_value_n16 |= ((new_value & 0x000f0000) >> 16) << 4; + _new_value_n16 |= ((new_value & 0x0f000000) >> 24) << 12; + int32_interleaved[idx] = _new_value_n16; + } + else + int32_interleaved[idx] = new_value; + } + + // Convert back to int8_t if needed + memcpy(interleaved, int32_interleaved, size * sizeof(int32_t)); +} + +template +__device__ void decode_i4b_to_i8s(T1 *_i4b, T2 *_i8s, const int N = 16) +{ + uint *i8s = reinterpret_cast(_i8s); + uint *i4b = reinterpret_cast(_i4b); + // First, we extract the i4s and construct an intermediate i8 number. + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x0f0f0f0f; // 0xf -> 0b1111 select 0,4,8,12 + static constexpr uint I4b_TO_I8s_MAGIC_NUM = 0x00000000; // 0 + static constexpr uint MEDIAN_NUM = isSigned ? 0x07070707 : 0x00000000; +#pragma unroll + for (int i = 0; i < (N / 8); i++) + { + // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(i8s[i]) + : "r"(i4b[0] >> (4 * i)), "n"(BOTTOM_MASK), "n"(I4b_TO_I8s_MAGIC_NUM), "n"(immLut)); + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(i8s[i + 2]) + : "r"(i4b[1] >> (4 * i)), "n"(BOTTOM_MASK), "n"(I4b_TO_I8s_MAGIC_NUM), "n"(immLut)); + if constexpr (isSigned) + { + i8s[i] = __vsubss4(i8s[i], MEDIAN_NUM); + i8s[i + 2] = __vsubss4(i8s[i + 2], MEDIAN_NUM); + } + } +} + +template +__device__ void decode_i4s_to_i8s(T1 *_i4s, T2 *B_local_decode, const int N = 16) +{ + decode_i4b_to_i8s(_i4s, B_local_decode, N); +} + +template +__device__ void decode_i4u_to_i8s(T1 *_i4u, T2 *B_local_decode, const int N = 16) +{ + decode_i4b_to_i8s(_i4u, B_local_decode, N); +} + +template +__device__ void decode_i2b_to_i8s(T1 *_i2b, T2 *_i8s, const int N = 16) +{ + // convert 8 int2b_t to 8 int8b_t -> 2 int32 + uint *i8s = reinterpret_cast(_i8s); + + // i2b = {e7,e6,e5,e4,e3,e2,e1,e0} + // also require interleave {e7,e3,e6,e2,e5,e1,e4,e0} + uint const i2b = *reinterpret_cast(_i2b); + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010 + static constexpr uint BOTTOM_MASK = 0x03030303; // 0xf -> 0b11 select 0,3 + static constexpr uint I8s_MAGIC_NUM = 0x00000000; // 1024 + static constexpr uint MEDIAN_NUM = isSigned ? 0x01010101 : 0x00000000; +#pragma unroll + for (int i = 0; i < (N / 4); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(i8s[i]) + : "r"(i2b >> (2 * i)), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut)); + if constexpr (isSigned) + { + i8s[i] = __vsubss4(i8s[i], MEDIAN_NUM); + } + } +} + +template +__device__ void decode_i2s_to_i8s(T1 *_i2s, T2 *B_local_decode, const int N = 16) +{ + decode_i2b_to_i8s(_i2s, B_local_decode, N); +} + +template +__device__ void decode_i2u_to_i8s(T1 *_i2u, T2 *B_local_decode, const int N = 16) +{ + decode_i2b_to_i8s(_i2u, B_local_decode, N); +} + +template +__device__ void decode_i1b_to_i8s(T1 *_i1b, T2 *_i8s, const int N = 16) +{ + int *i8s = reinterpret_cast(_i8s); + int16_t i1b_i16 = *reinterpret_cast(_i1b); + // permutate: {e0,e4,e8,e12,e2,e6,e10,e14,e1,e5,e9,e13,e3,e7,e11,e15} + // into: {e0,e4,e8,e12,x,x,x,x,e1,e5,e9,x,x,x,x,e13,e2,e6,e10,e14,e1,e5,e9,e13,e3,e7,e11,e15,x,x,x,x} + int i1b = (i1b_i16 & 0x0f0f); + i1b |= ((i1b_i16 & 0xf0f0) << 12); + // i1b {0..,e15,e14,e13,e12,e11,e10,e9,e8,e7,e6,e5,e4,e3,e2,e1,e0} + // interleave {0..,e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0} + // First, we extract the i1b and construct an intermediate fp16 number. + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010 + static constexpr uint BOTTOM_MASK = 0x01010101; // 0x1 -> 0b01 select 0,1 + static constexpr uint I8s_MAGIC_NUM = 0x00000000; + static constexpr uint MEDIAN_NUM = isSigned ? 0x00000000 : 0x00000000; + + for (int i = 0; i < N / 4; i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(i8s[i]) + : "r"(i1b >> i), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut)); + + if constexpr (isSigned) + { + i8s[i] = __vsubss4(i8s[i], MEDIAN_NUM); + } + } +} + +template +__device__ void decode_i1s_to_i8s(T1 *_i1s, T2 *B_local_decode, const int N = 16) +{ + decode_i1b_to_i8s(_i1s, B_local_decode, N); +} + +template +__device__ void decode_i1u_to_i8s(T1 *_i1u, T2 *B_local_decode, const int N = 16) +{ + decode_i1b_to_i8s(_i1u, B_local_decode, N); +} diff --git a/testing/cpp/lop3_type_conversion/lowprecision_to_float16.cu b/testing/cpp/lop3_type_conversion/lowprecision_to_float16.cu new file mode 100644 index 000000000000..b08ca911d7b0 --- /dev/null +++ b/testing/cpp/lop3_type_conversion/lowprecision_to_float16.cu @@ -0,0 +1,680 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +#include +#include +#include +#include +#include "fast_decoding.hpp" + +#define cudaCheckLastError(ans) \ + { \ + gpuAssert((ans), __FILE__, __LINE__); \ + } +inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort = true) +{ + if (code != cudaSuccess) + { + fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line); + if (abort) + exit(code); + } +} + +#define REGISTER_GLOBAL_DEVICE_INVOKER(kernel, function) \ + template \ + __global__ void kernel(Args... args) \ + { \ + function(args...); \ + } + +REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i4s_to_f16, decode_i4s_to_f16) +REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i4u_to_f16, decode_i4u_to_f16) +REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i2s_to_f16, decode_i2s_to_f16) +REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i2u_to_f16, decode_i2u_to_f16) +REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i1s_to_f16, decode_i1s_to_f16) +REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i1u_to_f16, decode_i1u_to_f16) +REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i4s_to_f16_scale, decode_i4s_to_f16_scale) +REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i4u_to_f16_scale, decode_i4u_to_f16_scale) +REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i2s_to_f16_scale, decode_i2s_to_f16_scale) +REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i2u_to_f16_scale, decode_i2u_to_f16_scale) +REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i1s_to_f16_scale, decode_i1s_to_f16_scale) +REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i1u_to_f16_scale, decode_i1u_to_f16_scale) + +TEST(DecodeTest, DecodeInt4ToFloat16) +{ + constexpr int nbits = 4; + constexpr int N = 32 / nbits; + constexpr int QN = N / 8 * nbits; + constexpr bool isSigned = true; + constexpr int zero_point = isSigned ? ((1 << (nbits - 1)) - 1) : 0; + + // create four int8_t values + int8_t in_data[N] = { + 0, + }; + // breed seed + srand(0); + + // random initializations with nbits range + for (int i = 0; i < N; i++) + { + in_data[i] = (rand() % (1 << nbits)) - zero_point; + } + + int8_t *ins = new int8_t[QN]; + general_compress(in_data, ins, nbits, N, isSigned); + + int8_t *interleaved = new int8_t[QN]; + general_interleave_fp16(ins, interleaved, nbits, QN * sizeof(int8_t), false); + half *decoded = new half[N]; + int8_t *ins_gpu; + half *decoded_gpu; + + cudaCheckLastError(cudaMalloc((void **)&ins_gpu, QN * sizeof(int8_t))); + cudaCheckLastError(cudaMalloc((void **)&decoded_gpu, N * sizeof(half))); + cudaCheckLastError(cudaMemcpy(ins_gpu, interleaved, QN * sizeof(int8_t), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaMemcpy(decoded_gpu, decoded, N * sizeof(half), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaDeviceSynchronize()); + + kernelWrapper_i4s_to_f16<<>>(ins_gpu, decoded_gpu); + cudaCheckLastError(cudaDeviceSynchronize()); + cudaCheckLastError(cudaMemcpy(decoded, decoded_gpu, N * sizeof(half), cudaMemcpyDeviceToHost)); + cudaCheckLastError(cudaFree(ins_gpu)); + cudaCheckLastError(cudaFree(decoded_gpu)); + for (int i = 0; i < N; i++) + { + EXPECT_EQ(in_data[i], int(decoded[i])); + } + free(ins); + free(interleaved); + free(decoded); +} + +TEST(DecodeTest, DecodeUInt4ToFloat16) +{ + constexpr int nbits = 4; + constexpr int N = 32 / nbits; + constexpr int QN = N / 8 * nbits; + constexpr bool isSigned = false; + constexpr int zero_point = isSigned ? ((1 << (nbits - 1)) - 1) : 0; + + // create four int8_t values + int8_t in_data[N] = { + 0, + }; + + // breed seed + srand(0); + + // random initializations with nbits range + for (int i = 0; i < N; i++) + { + in_data[i] = (rand() % (1 << nbits)) - zero_point; + } + + int8_t *ins = new int8_t[QN]; + general_compress(in_data, ins, nbits, N, isSigned); + int8_t *interleaved = new int8_t[QN]; + general_interleave_fp16(ins, interleaved, nbits, QN * sizeof(int8_t), false); + + half *decoded = new half[N]; + int8_t *ins_gpu; + half *decoded_gpu; + + cudaCheckLastError(cudaMalloc((void **)&ins_gpu, QN * sizeof(int8_t))); + cudaCheckLastError(cudaMalloc((void **)&decoded_gpu, N * sizeof(half))); + cudaCheckLastError(cudaMemcpy(ins_gpu, interleaved, QN * sizeof(int8_t), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaMemcpy(decoded_gpu, decoded, N * sizeof(half), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaDeviceSynchronize()); + + kernelWrapper_i4u_to_f16<<>>(ins_gpu, decoded_gpu); + + cudaCheckLastError(cudaDeviceSynchronize()); + cudaCheckLastError(cudaMemcpy(decoded, decoded_gpu, N * sizeof(half), cudaMemcpyDeviceToHost)); + cudaCheckLastError(cudaFree(ins_gpu)); + cudaCheckLastError(cudaFree(decoded_gpu)); + for (int i = 0; i < N; i++) + { + EXPECT_EQ(in_data[i], int(decoded[i])); + } + + free(ins); + free(interleaved); + free(decoded); +} + +TEST(DecodeTest, DecodeInt2ToFloat16) +{ + constexpr int nbits = 2; + constexpr int N = 32 / nbits; + constexpr int QN = N / 8 * nbits; + constexpr bool isSigned = true; + constexpr int zero_point = isSigned ? ((1 << (nbits - 1)) - 1) : 0; + + // create four int8_t values + int8_t in_data[N] = { + 0, + }; + // breed seed + srand(0); + + // random initializations with nbits range + for (int i = 0; i < N; i++) + { + in_data[i] = (rand() % (1 << nbits)) - zero_point; + } + + int8_t *ins = new int8_t[QN]; + general_compress(in_data, ins, nbits, N, isSigned); + + int8_t *interleaved = new int8_t[QN]; + general_interleave_fp16(ins, interleaved, nbits, QN * sizeof(int8_t), false); + half *decoded = new half[N]; + int8_t *ins_gpu; + half *decoded_gpu; + + cudaCheckLastError(cudaMalloc((void **)&ins_gpu, QN * sizeof(int8_t))); + cudaCheckLastError(cudaMalloc((void **)&decoded_gpu, N * sizeof(half))); + cudaCheckLastError(cudaMemcpy(ins_gpu, interleaved, QN * sizeof(int8_t), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaMemcpy(decoded_gpu, decoded, N * sizeof(half), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaDeviceSynchronize()); + + kernelWrapper_i2s_to_f16<<>>(ins_gpu, decoded_gpu); + kernelWrapper_i2s_to_f16<<>>(ins_gpu + QN / 2, decoded_gpu + N / 2); + cudaCheckLastError(cudaDeviceSynchronize()); + cudaCheckLastError(cudaMemcpy(decoded, decoded_gpu, N * sizeof(half), cudaMemcpyDeviceToHost)); + cudaCheckLastError(cudaFree(ins_gpu)); + cudaCheckLastError(cudaFree(decoded_gpu)); + for (int i = 0; i < N; i++) + { + EXPECT_EQ(in_data[i], int(decoded[i])); + } + free(ins); + free(interleaved); + free(decoded); +} + +TEST(DecodeTest, DecodeUInt2ToFloat16) +{ + constexpr int nbits = 2; + constexpr int N = 32 / nbits; + constexpr int QN = N / 8 * nbits; + constexpr bool isSigned = false; + constexpr int zero_point = isSigned ? ((1 << (nbits - 1)) - 1) : 0; + + // create four int8_t values + int8_t in_data[N] = { + 0, + }; + + // breed seed + srand(0); + + // random initializations with nbits range + for (int i = 0; i < N; i++) + { + in_data[i] = (rand() % (1 << nbits)) - zero_point; + } + + int8_t *ins = new int8_t[QN]; + general_compress(in_data, ins, nbits, N, isSigned); + int8_t *interleaved = new int8_t[QN]; + general_interleave_fp16(ins, interleaved, nbits, QN * sizeof(int8_t), false); + half *decoded = new half[N]; + int8_t *ins_gpu; + half *decoded_gpu; + + cudaCheckLastError(cudaMalloc((void **)&ins_gpu, QN * sizeof(int8_t))); + cudaCheckLastError(cudaMalloc((void **)&decoded_gpu, N * sizeof(half))); + cudaCheckLastError(cudaMemcpy(ins_gpu, interleaved, QN * sizeof(int8_t), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaMemcpy(decoded_gpu, decoded, N * sizeof(half), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaDeviceSynchronize()); + + kernelWrapper_i2u_to_f16<<>>(ins_gpu, decoded_gpu); + kernelWrapper_i2u_to_f16<<>>(ins_gpu + QN / 2, decoded_gpu + N / 2); + + cudaCheckLastError(cudaDeviceSynchronize()); + cudaCheckLastError(cudaMemcpy(decoded, decoded_gpu, N * sizeof(half), cudaMemcpyDeviceToHost)); + cudaCheckLastError(cudaFree(ins_gpu)); + cudaCheckLastError(cudaFree(decoded_gpu)); + for (int i = 0; i < N; i++) + { + EXPECT_EQ(in_data[i], int(decoded[i])); + } + + free(ins); + free(interleaved); + free(decoded); +} + +TEST(DecodeTest, DecodeInt1ToFloat16) +{ + constexpr int nbits = 1; + constexpr int N = 32 / nbits; + constexpr int QN = N / 8 * nbits; + constexpr bool isSigned = true; + constexpr int zero_point = isSigned ? ((1 << (nbits - 1)) - 1) : 0; + + // create four int8_t values + int8_t in_data[N] = { + 0, + }; + // breed seed + srand(0); + + // random initializations with nbits range + for (int i = 0; i < N; i++) + { + in_data[i] = (rand() % (1 << nbits)) - zero_point; + } + + int8_t *ins = new int8_t[QN]; + general_compress(in_data, ins, nbits, N, isSigned); + + int8_t *interleaved = new int8_t[QN]; + general_interleave_fp16(ins, interleaved, nbits, QN * sizeof(int8_t), false); + half *decoded = new half[N]; + int8_t *ins_gpu; + half *decoded_gpu; + + cudaCheckLastError(cudaMalloc((void **)&ins_gpu, QN * sizeof(int8_t))); + cudaCheckLastError(cudaMalloc((void **)&decoded_gpu, N * sizeof(half))); + cudaCheckLastError(cudaMemcpy(ins_gpu, interleaved, QN * sizeof(int8_t), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaMemcpy(decoded_gpu, decoded, N * sizeof(half), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaDeviceSynchronize()); + + kernelWrapper_i1s_to_f16<<>>(ins_gpu, decoded_gpu); + kernelWrapper_i1s_to_f16<<>>(ins_gpu + QN / 4, decoded_gpu + N / 4); + kernelWrapper_i1s_to_f16<<>>(ins_gpu + QN / 2, decoded_gpu + N / 2); + kernelWrapper_i1s_to_f16<<>>(ins_gpu + +QN / 2 + QN / 4, decoded_gpu + +N / 2 + N / 4); + cudaCheckLastError(cudaDeviceSynchronize()); + cudaCheckLastError(cudaMemcpy(decoded, decoded_gpu, N * sizeof(half), cudaMemcpyDeviceToHost)); + cudaCheckLastError(cudaFree(ins_gpu)); + cudaCheckLastError(cudaFree(decoded_gpu)); + for (int i = 0; i < N; i++) + { + EXPECT_EQ(in_data[i], int(decoded[i])); + } + free(ins); + free(interleaved); + free(decoded); +} + +TEST(DecodeTest, DecodeUInt1ToFloat16) +{ + constexpr int nbits = 1; + constexpr int N = 32 / nbits; + constexpr int QN = N / 8 * nbits; + constexpr bool isSigned = false; + constexpr int zero_point = isSigned ? ((1 << (nbits - 1)) - 1) : 0; + + // create four int8_t values + int8_t in_data[N] = { + 0, + }; + // breed seed + srand(0); + + // random initializations with nbits range + for (int i = 0; i < N; i++) + { + in_data[i] = (rand() % (1 << nbits)) - zero_point; + } + + int8_t *ins = new int8_t[QN]; + general_compress(in_data, ins, nbits, N, isSigned); + + int8_t *interleaved = new int8_t[QN]; + general_interleave_fp16(ins, interleaved, nbits, QN * sizeof(int8_t), false); + half *decoded = new half[N]; + int8_t *ins_gpu; + half *decoded_gpu; + + cudaCheckLastError(cudaMalloc((void **)&ins_gpu, QN * sizeof(int8_t))); + cudaCheckLastError(cudaMalloc((void **)&decoded_gpu, N * sizeof(half))); + cudaCheckLastError(cudaMemcpy(ins_gpu, interleaved, QN * sizeof(int8_t), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaMemcpy(decoded_gpu, decoded, N * sizeof(half), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaDeviceSynchronize()); + + kernelWrapper_i1u_to_f16<<>>(ins_gpu, decoded_gpu); + kernelWrapper_i1u_to_f16<<>>(ins_gpu + QN / 4, decoded_gpu + N / 4); + kernelWrapper_i1u_to_f16<<>>(ins_gpu + QN / 2, decoded_gpu + N / 2); + kernelWrapper_i1u_to_f16<<>>(ins_gpu + +QN / 2 + QN / 4, decoded_gpu + +N / 2 + N / 4); + cudaCheckLastError(cudaDeviceSynchronize()); + cudaCheckLastError(cudaMemcpy(decoded, decoded_gpu, N * sizeof(half), cudaMemcpyDeviceToHost)); + cudaCheckLastError(cudaFree(ins_gpu)); + cudaCheckLastError(cudaFree(decoded_gpu)); + for (int i = 0; i < N; i++) + { + EXPECT_EQ(in_data[i], int(decoded[i])); + } + free(ins); + free(interleaved); + free(decoded); +} + +TEST(DecodeTest, DecodeInt4ToFloat16WithScaling) +{ + constexpr int nbits = 4; + constexpr int N = 32 / nbits; + constexpr int QN = N / 8 * nbits; + constexpr bool isSigned = true; + constexpr int zero_point = isSigned ? ((1 << (nbits - 1)) - 1) : 0; + + // create four int8_t values + int8_t in_data[N] = { + 0, + }; + half scale[1] = {__float2half(0.314)}; + // breed seed + srand(0); + + // random initializations with nbits range + for (int i = 0; i < N; i++) + { + in_data[i] = (rand() % (1 << nbits)) - zero_point; + } + + int8_t *ins = new int8_t[QN]; + general_compress(in_data, ins, nbits, N, isSigned); + + int8_t *interleaved = new int8_t[QN]; + general_interleave_fp16(ins, interleaved, nbits, QN * sizeof(int8_t), false); + half *decoded = new half[N]; + int8_t *ins_gpu; + half *decoded_gpu, *scale_gpu; + + cudaCheckLastError(cudaMalloc((void **)&ins_gpu, QN * sizeof(int8_t))); + cudaCheckLastError(cudaMalloc((void **)&decoded_gpu, N * sizeof(half))); + cudaCheckLastError(cudaMalloc((void **)&scale_gpu, 1 * sizeof(half))); + cudaCheckLastError(cudaMemcpy(ins_gpu, interleaved, QN * sizeof(int8_t), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaMemcpy(decoded_gpu, decoded, N * sizeof(half), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaMemcpy(scale_gpu, scale, 1 * sizeof(half), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaDeviceSynchronize()); + + kernelWrapper_i4s_to_f16_scale<<>>(ins_gpu, decoded_gpu, scale_gpu); + cudaCheckLastError(cudaDeviceSynchronize()); + cudaCheckLastError(cudaMemcpy(decoded, decoded_gpu, N * sizeof(half), cudaMemcpyDeviceToHost)); + cudaCheckLastError(cudaFree(ins_gpu)); + cudaCheckLastError(cudaFree(decoded_gpu)); + for (int i = 0; i < N; i++) + { + EXPECT_NEAR(in_data[i] * float(scale[0]), float(decoded[i]), 1e-2); + } + free(ins); + free(interleaved); + free(decoded); +} + +TEST(DecodeTest, DecodeUInt4ToFloat16WithScaling) +{ + constexpr int nbits = 4; + constexpr int N = 32 / nbits; + constexpr int QN = N / 8 * nbits; + constexpr bool isSigned = false; + constexpr int zero_point = isSigned ? ((1 << (nbits - 1)) - 1) : 0; + + // create four int8_t values + int8_t in_data[N] = { + 0, + }; + half scale[1] = {__float2half(1.2)}; + // breed seed + srand(0); + + // random initializations with nbits range + for (int i = 0; i < N; i++) + { + in_data[i] = (rand() % (1 << nbits)) - zero_point; + } + + int8_t *ins = new int8_t[QN]; + general_compress(in_data, ins, nbits, N, isSigned); + + int8_t *interleaved = new int8_t[QN]; + general_interleave_fp16(ins, interleaved, nbits, QN * sizeof(int8_t), false); + half *decoded = new half[N]; + int8_t *ins_gpu; + half *decoded_gpu, *scale_gpu; + + cudaCheckLastError(cudaMalloc((void **)&ins_gpu, QN * sizeof(int8_t))); + cudaCheckLastError(cudaMalloc((void **)&decoded_gpu, N * sizeof(half))); + cudaCheckLastError(cudaMalloc((void **)&scale_gpu, 1 * sizeof(half))); + cudaCheckLastError(cudaMemcpy(ins_gpu, interleaved, QN * sizeof(int8_t), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaMemcpy(decoded_gpu, decoded, N * sizeof(half), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaMemcpy(scale_gpu, scale, 1 * sizeof(half), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaDeviceSynchronize()); + + kernelWrapper_i4u_to_f16_scale<<>>(ins_gpu, decoded_gpu, scale_gpu); + cudaCheckLastError(cudaDeviceSynchronize()); + cudaCheckLastError(cudaMemcpy(decoded, decoded_gpu, N * sizeof(half), cudaMemcpyDeviceToHost)); + cudaCheckLastError(cudaFree(ins_gpu)); + cudaCheckLastError(cudaFree(decoded_gpu)); + for (int i = 0; i < N; i++) + { + EXPECT_NEAR(in_data[i] * float(scale[0]), float(decoded[i]), 1e-2); + } + free(ins); + free(interleaved); + free(decoded); +} + +TEST(DecodeTest, DecodeInt2ToFloat16WithScaling) +{ + constexpr int nbits = 2; + constexpr int N = 32 / nbits; + constexpr int QN = N / 8 * nbits; + constexpr bool isSigned = true; + constexpr int zero_point = isSigned ? ((1 << (nbits - 1)) - 1) : 0; + + // create four int8_t values + int8_t in_data[N] = { + 0, + }; + half scale[1] = {__float2half(0.314)}; + // breed seed + srand(0); + + // random initializations with nbits range + for (int i = 0; i < N; i++) + { + in_data[i] = (rand() % (1 << nbits)) - zero_point; + } + + int8_t *ins = new int8_t[QN]; + general_compress(in_data, ins, nbits, N, isSigned); + + int8_t *interleaved = new int8_t[QN]; + general_interleave_fp16(ins, interleaved, nbits, QN * sizeof(int8_t), false); + half *decoded = new half[N]; + int8_t *ins_gpu; + half *decoded_gpu, *scale_gpu; + + cudaCheckLastError(cudaMalloc((void **)&ins_gpu, QN * sizeof(int8_t))); + cudaCheckLastError(cudaMalloc((void **)&decoded_gpu, N * sizeof(half))); + cudaCheckLastError(cudaMalloc((void **)&scale_gpu, 1 * sizeof(half))); + cudaCheckLastError(cudaMemcpy(ins_gpu, interleaved, QN * sizeof(int8_t), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaMemcpy(decoded_gpu, decoded, N * sizeof(half), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaMemcpy(scale_gpu, scale, 1 * sizeof(half), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaDeviceSynchronize()); + + kernelWrapper_i2s_to_f16_scale<<>>(ins_gpu, decoded_gpu, scale_gpu); + kernelWrapper_i2s_to_f16_scale<<>>(ins_gpu + QN / 2, decoded_gpu + N / 2, scale_gpu); + cudaCheckLastError(cudaDeviceSynchronize()); + cudaCheckLastError(cudaMemcpy(decoded, decoded_gpu, N * sizeof(half), cudaMemcpyDeviceToHost)); + cudaCheckLastError(cudaFree(ins_gpu)); + cudaCheckLastError(cudaFree(decoded_gpu)); + for (int i = 0; i < N; i++) + { + EXPECT_NEAR(in_data[i] * float(scale[0]), float(decoded[i]), 1e-2); + } + free(ins); + free(interleaved); + free(decoded); +} + +TEST(DecodeTest, DecodeUInt2ToFloat16WithScaling) +{ + constexpr int nbits = 2; + constexpr int N = 32 / nbits; + constexpr int QN = N / 8 * nbits; + constexpr bool isSigned = false; + constexpr int zero_point = isSigned ? ((1 << (nbits - 1)) - 1) : 0; + + // create four int8_t values + int8_t in_data[N] = { + 0, + }; + half scale[1] = {__float2half(1.0)}; + // breed seed + srand(0); + + // random initializations with nbits range + for (int i = 0; i < N; i++) + { + in_data[i] = (rand() % (1 << nbits)) - zero_point; + } + + int8_t *ins = new int8_t[QN]; + general_compress(in_data, ins, nbits, N, isSigned); + + int8_t *interleaved = new int8_t[QN]; + general_interleave_fp16(ins, interleaved, nbits, QN * sizeof(int8_t), false); + half *decoded = new half[N]; + int8_t *ins_gpu; + half *decoded_gpu, *scale_gpu; + + cudaCheckLastError(cudaMalloc((void **)&ins_gpu, QN * sizeof(int8_t))); + cudaCheckLastError(cudaMalloc((void **)&decoded_gpu, N * sizeof(half))); + cudaCheckLastError(cudaMalloc((void **)&scale_gpu, 1 * sizeof(half))); + cudaCheckLastError(cudaMemcpy(ins_gpu, interleaved, QN * sizeof(int8_t), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaMemcpy(decoded_gpu, decoded, N * sizeof(half), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaMemcpy(scale_gpu, scale, 1 * sizeof(half), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaDeviceSynchronize()); + + kernelWrapper_i2u_to_f16_scale<<>>(ins_gpu, decoded_gpu, scale_gpu); + kernelWrapper_i2u_to_f16_scale<<>>(ins_gpu + QN / 2, decoded_gpu + N / 2, scale_gpu); + cudaCheckLastError(cudaDeviceSynchronize()); + cudaCheckLastError(cudaMemcpy(decoded, decoded_gpu, N * sizeof(half), cudaMemcpyDeviceToHost)); + cudaCheckLastError(cudaFree(ins_gpu)); + cudaCheckLastError(cudaFree(decoded_gpu)); + for (int i = 0; i < N; i++) + { + EXPECT_NEAR(in_data[i] * float(scale[0]), float(decoded[i]), 1e-2); + } + free(ins); + free(interleaved); + free(decoded); +} + +TEST(DecodeTest, DecodeInt1ToFloat16WithScaling) +{ + constexpr int nbits = 1; + constexpr int N = 32 / nbits; + constexpr int QN = N / 8 * nbits; + constexpr bool isSigned = true; + constexpr int zero_point = isSigned ? ((1 << (nbits - 1)) - 1) : 0; + + // create four int8_t values + int8_t in_data[N] = { + 0, + }; + half scale[1] = {__float2half(0.314)}; + // breed seed + srand(0); + + // random initializations with nbits range + for (int i = 0; i < N; i++) + { + in_data[i] = (rand() % (1 << nbits)) - zero_point; + } + + int8_t *ins = new int8_t[QN]; + general_compress(in_data, ins, nbits, N, isSigned); + + int8_t *interleaved = new int8_t[QN]; + general_interleave_fp16(ins, interleaved, nbits, QN * sizeof(int8_t), false); + half *decoded = new half[N]; + int8_t *ins_gpu; + half *decoded_gpu, *scale_gpu; + + cudaCheckLastError(cudaMalloc((void **)&ins_gpu, QN * sizeof(int8_t))); + cudaCheckLastError(cudaMalloc((void **)&decoded_gpu, N * sizeof(half))); + cudaCheckLastError(cudaMalloc((void **)&scale_gpu, 1 * sizeof(half))); + cudaCheckLastError(cudaMemcpy(ins_gpu, interleaved, QN * sizeof(int8_t), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaMemcpy(decoded_gpu, decoded, N * sizeof(half), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaMemcpy(scale_gpu, scale, 1 * sizeof(half), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaDeviceSynchronize()); + + kernelWrapper_i1s_to_f16_scale<<>>(ins_gpu, decoded_gpu, scale_gpu); + kernelWrapper_i1s_to_f16_scale<<>>(ins_gpu + QN / 4, decoded_gpu + N / 4, scale_gpu); + kernelWrapper_i1s_to_f16_scale<<>>(ins_gpu + QN / 2, decoded_gpu + N / 2, scale_gpu); + kernelWrapper_i1s_to_f16_scale<<>>(ins_gpu + QN / 2 + QN / 4, decoded_gpu + N / 2 + N / 4, scale_gpu); + cudaCheckLastError(cudaDeviceSynchronize()); + cudaCheckLastError(cudaMemcpy(decoded, decoded_gpu, N * sizeof(half), cudaMemcpyDeviceToHost)); + cudaCheckLastError(cudaFree(ins_gpu)); + cudaCheckLastError(cudaFree(decoded_gpu)); + for (int i = 0; i < N; i++) + { + EXPECT_NEAR(in_data[i] * float(scale[0]), float(decoded[i]), 1e-2); + } + free(ins); + free(interleaved); + free(decoded); +} + +TEST(DecodeTest, DecodeUInt1ToFloat16WithScaling) +{ + constexpr int nbits = 1; + constexpr int N = 32 / nbits; + constexpr int QN = N / 8 * nbits; + constexpr bool isSigned = false; + constexpr int zero_point = isSigned ? ((1 << (nbits - 1)) - 1) : 0; + + // create four int8_t values + int8_t in_data[N] = { + 0, + }; + half scale[1] = {__float2half(1.0)}; + // breed seed + srand(0); + + // random initializations with nbits range + for (int i = 0; i < N; i++) + { + in_data[i] = (rand() % (1 << nbits)) - zero_point; + } + + int8_t *ins = new int8_t[QN]; + general_compress(in_data, ins, nbits, N, isSigned); + + int8_t *interleaved = new int8_t[QN]; + general_interleave_fp16(ins, interleaved, nbits, QN * sizeof(int8_t), false); + half *decoded = new half[N]; + int8_t *ins_gpu; + half *decoded_gpu, *scale_gpu; + + cudaCheckLastError(cudaMalloc((void **)&ins_gpu, QN * sizeof(int8_t))); + cudaCheckLastError(cudaMalloc((void **)&decoded_gpu, N * sizeof(half))); + cudaCheckLastError(cudaMalloc((void **)&scale_gpu, 1 * sizeof(half))); + cudaCheckLastError(cudaMemcpy(ins_gpu, interleaved, QN * sizeof(int8_t), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaMemcpy(decoded_gpu, decoded, N * sizeof(half), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaMemcpy(scale_gpu, scale, 1 * sizeof(half), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaDeviceSynchronize()); + + kernelWrapper_i1u_to_f16_scale<<>>(ins_gpu, decoded_gpu, scale_gpu); + kernelWrapper_i1u_to_f16_scale<<>>(ins_gpu + QN / 4, decoded_gpu + N / 4, scale_gpu); + kernelWrapper_i1u_to_f16_scale<<>>(ins_gpu + QN / 2, decoded_gpu + N / 2, scale_gpu); + kernelWrapper_i1u_to_f16_scale<<>>(ins_gpu + QN / 2 + QN / 4, decoded_gpu + N / 2 + N / 4, scale_gpu); + cudaCheckLastError(cudaDeviceSynchronize()); + cudaCheckLastError(cudaMemcpy(decoded, decoded_gpu, N * sizeof(half), cudaMemcpyDeviceToHost)); + cudaCheckLastError(cudaFree(ins_gpu)); + cudaCheckLastError(cudaFree(decoded_gpu)); + for (int i = 0; i < N; i++) + { + EXPECT_NEAR(in_data[i] * float(scale[0]), float(decoded[i]), 1e-2); + } + free(ins); + free(interleaved); + free(decoded); +} diff --git a/testing/cpp/lop3_type_conversion/lowprecision_to_int8.cu b/testing/cpp/lop3_type_conversion/lowprecision_to_int8.cu new file mode 100644 index 000000000000..fe1b1dd7198f --- /dev/null +++ b/testing/cpp/lop3_type_conversion/lowprecision_to_int8.cu @@ -0,0 +1,345 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +#include +#include +#include +#include +#include "fast_decoding.hpp" + +#define cudaCheckLastError(ans) \ + { \ + gpuAssert((ans), __FILE__, __LINE__); \ + } +inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort = true) +{ + if (code != cudaSuccess) + { + fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line); + if (abort) + exit(code); + } +} + +#define REGISTER_GLOBAL_DEVICE_INVOKER(kernel, function) \ + template \ + __global__ void kernel(Args... args) \ + { \ + function(args...); \ + } + +REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i4s_to_i8s, decode_i4s_to_i8s) +REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i4u_to_i8s, decode_i4u_to_i8s) +REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i2s_to_i8s, decode_i2s_to_i8s) +REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i2u_to_i8s, decode_i2u_to_i8s) +REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i1s_to_i8s, decode_i1s_to_i8s) +REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i1u_to_i8s, decode_i1u_to_i8s) + +TEST(DecodeTest, DecodeInt4ToINT8) +{ + using target_dtype = int8_t; + constexpr int nbits = 4; + constexpr int N = 32 / nbits; + constexpr int QN = N / 8 * nbits; + constexpr bool isSigned = true; + constexpr int zero_point = isSigned ? ((1 << (nbits - 1)) - 1) : 0; + + // create four int8_t values + int8_t in_data[N] = { + 0, + }; + // breed seed + srand(0); + + // random initializations with nbits range + for (int i = 0; i < N; i++) + { + in_data[i] = (rand() % (1 << nbits)) - zero_point; + } + + int8_t *ins = new int8_t[QN]; + general_compress(in_data, ins, nbits, N, isSigned); + + int8_t *interleaved = new int8_t[QN]; + general_interleave_int8(ins, interleaved, nbits, QN * sizeof(int8_t), false); + target_dtype *decoded = new target_dtype[N]; + int8_t *ins_gpu; + target_dtype *decoded_gpu; + + cudaCheckLastError(cudaMalloc((void **)&ins_gpu, QN * sizeof(int8_t))); + cudaCheckLastError(cudaMalloc((void **)&decoded_gpu, N * sizeof(target_dtype))); + cudaCheckLastError(cudaMemcpy(ins_gpu, interleaved, QN * sizeof(int8_t), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaMemcpy(decoded_gpu, decoded, N * sizeof(target_dtype), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaDeviceSynchronize()); + + kernelWrapper_i4s_to_i8s<<>>(ins_gpu, decoded_gpu); + cudaCheckLastError(cudaDeviceSynchronize()); + cudaCheckLastError(cudaMemcpy(decoded, decoded_gpu, N * sizeof(target_dtype), cudaMemcpyDeviceToHost)); + cudaCheckLastError(cudaFree(ins_gpu)); + cudaCheckLastError(cudaFree(decoded_gpu)); + for (int i = 0; i < N; i++) + { + EXPECT_EQ(in_data[i], int(decoded[i])); + } + free(ins); + free(interleaved); + free(decoded); +} + +TEST(DecodeTest, DecodeUInt4ToINT8) +{ + using target_dtype = int8_t; + constexpr int nbits = 4; + constexpr int N = 32 / nbits; + constexpr int QN = N / 8 * nbits; + constexpr bool isSigned = false; + constexpr int zero_point = isSigned ? ((1 << (nbits - 1)) - 1) : 0; + + // create four int8_t values + int8_t in_data[N] = { + 0, + }; + // breed seed + srand(0); + + // random initializations with nbits range + for (int i = 0; i < N; i++) + { + in_data[i] = (rand() % (1 << nbits)) - zero_point; + } + + int8_t *ins = new int8_t[QN]; + general_compress(in_data, ins, nbits, N, isSigned); + + int8_t *interleaved = new int8_t[QN]; + general_interleave_int8(ins, interleaved, nbits, QN * sizeof(int8_t), false); + target_dtype *decoded = new target_dtype[N]; + int8_t *ins_gpu; + target_dtype *decoded_gpu; + + cudaCheckLastError(cudaMalloc((void **)&ins_gpu, QN * sizeof(int8_t))); + cudaCheckLastError(cudaMalloc((void **)&decoded_gpu, N * sizeof(target_dtype))); + cudaCheckLastError(cudaMemcpy(ins_gpu, interleaved, QN * sizeof(int8_t), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaMemcpy(decoded_gpu, decoded, N * sizeof(target_dtype), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaDeviceSynchronize()); + + kernelWrapper_i4u_to_i8s<<>>(ins_gpu, decoded_gpu); + cudaCheckLastError(cudaDeviceSynchronize()); + cudaCheckLastError(cudaMemcpy(decoded, decoded_gpu, N * sizeof(target_dtype), cudaMemcpyDeviceToHost)); + cudaCheckLastError(cudaFree(ins_gpu)); + cudaCheckLastError(cudaFree(decoded_gpu)); + for (int i = 0; i < N; i++) + { + EXPECT_EQ(in_data[i], int(decoded[i])); + } + free(ins); + free(interleaved); + free(decoded); +} + +TEST(DecodeTest, DecodeInt2ToINT8) +{ + using target_dtype = int8_t; + constexpr int nbits = 2; + constexpr int N = 32 / nbits; + constexpr int QN = N / 8 * nbits; + constexpr bool isSigned = true; + constexpr int zero_point = isSigned ? ((1 << (nbits - 1)) - 1) : 0; + + // create four int8_t values + int8_t in_data[N] = { + 0, + }; + // breed seed + srand(0); + + // random initializations with nbits range + for (int i = 0; i < N; i++) + { + in_data[i] = (rand() % (1 << nbits)) - zero_point; + } + + int8_t *ins = new int8_t[QN]; + general_compress(in_data, ins, nbits, N, isSigned); + + int8_t *interleaved = new int8_t[QN]; + general_interleave_int8(ins, interleaved, nbits, QN * sizeof(int8_t), false); + target_dtype *decoded = new target_dtype[N]; + int8_t *ins_gpu; + target_dtype *decoded_gpu; + + cudaCheckLastError(cudaMalloc((void **)&ins_gpu, QN * sizeof(int8_t))); + cudaCheckLastError(cudaMalloc((void **)&decoded_gpu, N * sizeof(target_dtype))); + cudaCheckLastError(cudaMemcpy(ins_gpu, interleaved, QN * sizeof(int8_t), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaMemcpy(decoded_gpu, decoded, N * sizeof(target_dtype), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaDeviceSynchronize()); + + kernelWrapper_i2s_to_i8s<<>>(ins_gpu, decoded_gpu); + cudaCheckLastError(cudaDeviceSynchronize()); + cudaCheckLastError(cudaMemcpy(decoded, decoded_gpu, N * sizeof(target_dtype), cudaMemcpyDeviceToHost)); + cudaCheckLastError(cudaFree(ins_gpu)); + cudaCheckLastError(cudaFree(decoded_gpu)); + for (int i = 0; i < N; i++) + { + EXPECT_EQ(in_data[i], int(decoded[i])); + } + free(ins); + free(interleaved); + free(decoded); +} + +TEST(DecodeTest, DecodeUInt2ToINT8) +{ + using target_dtype = int8_t; + constexpr int nbits = 2; + constexpr int N = 32 / nbits; + constexpr int QN = N / 8 * nbits; + constexpr bool isSigned = false; + constexpr int zero_point = isSigned ? ((1 << (nbits - 1)) - 1) : 0; + + // create four int8_t values + int8_t in_data[N] = { + 0, + }; + // breed seed + srand(0); + + // random initializations with nbits range + for (int i = 0; i < N; i++) + { + in_data[i] = (rand() % (1 << nbits)) - zero_point; + } + + int8_t *ins = new int8_t[QN]; + general_compress(in_data, ins, nbits, N, isSigned); + + int8_t *interleaved = new int8_t[QN]; + general_interleave_int8(ins, interleaved, nbits, QN * sizeof(int8_t), false); + + target_dtype *decoded = new target_dtype[N]; + int8_t *ins_gpu; + target_dtype *decoded_gpu; + + cudaCheckLastError(cudaMalloc((void **)&ins_gpu, QN * sizeof(int8_t))); + cudaCheckLastError(cudaMalloc((void **)&decoded_gpu, N * sizeof(target_dtype))); + cudaCheckLastError(cudaMemcpy(ins_gpu, interleaved, QN * sizeof(int8_t), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaMemcpy(decoded_gpu, decoded, N * sizeof(target_dtype), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaDeviceSynchronize()); + + kernelWrapper_i2u_to_i8s<<>>(ins_gpu, decoded_gpu); + cudaCheckLastError(cudaDeviceSynchronize()); + cudaCheckLastError(cudaMemcpy(decoded, decoded_gpu, N * sizeof(target_dtype), cudaMemcpyDeviceToHost)); + cudaCheckLastError(cudaFree(ins_gpu)); + cudaCheckLastError(cudaFree(decoded_gpu)); + for (int i = 0; i < N; i++) + { + EXPECT_EQ(in_data[i], int(decoded[i])); + } + free(ins); + free(interleaved); + free(decoded); +} + +TEST(DecodeTest, DecodeInt1ToINT8) +{ + using target_dtype = int8_t; + constexpr int nbits = 1; + constexpr int N = 32 / nbits; + constexpr int QN = N / 8 * nbits; + constexpr bool isSigned = true; + constexpr int zero_point = isSigned ? ((1 << (nbits - 1)) - 1) : 0; + + // create four int8_t values + int8_t in_data[N] = { + 0, + }; + // breed seed + srand(0); + + // random initializations with nbits range + for (int i = 0; i < N; i++) + { + in_data[i] = (rand() % (1 << nbits)) - zero_point; + } + + int8_t *ins = new int8_t[QN]; + general_compress(in_data, ins, nbits, N, isSigned); + + int8_t *interleaved = new int8_t[QN]; + general_interleave_int8(ins, interleaved, nbits, QN * sizeof(int8_t), false); + target_dtype *decoded = new target_dtype[N]; + int8_t *ins_gpu; + target_dtype *decoded_gpu; + + cudaCheckLastError(cudaMalloc((void **)&ins_gpu, QN * sizeof(int8_t))); + cudaCheckLastError(cudaMalloc((void **)&decoded_gpu, N * sizeof(target_dtype))); + cudaCheckLastError(cudaMemcpy(ins_gpu, interleaved, QN * sizeof(int8_t), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaMemcpy(decoded_gpu, decoded, N * sizeof(target_dtype), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaDeviceSynchronize()); + + kernelWrapper_i1s_to_i8s<<>>(ins_gpu, decoded_gpu); + kernelWrapper_i1s_to_i8s<<>>(ins_gpu + QN / 2, decoded_gpu + N / 2); + cudaCheckLastError(cudaDeviceSynchronize()); + cudaCheckLastError(cudaMemcpy(decoded, decoded_gpu, N * sizeof(target_dtype), cudaMemcpyDeviceToHost)); + cudaCheckLastError(cudaFree(ins_gpu)); + cudaCheckLastError(cudaFree(decoded_gpu)); + for (int i = 0; i < N; i++) + { + EXPECT_EQ(in_data[i], int(decoded[i])); + } + free(ins); + free(interleaved); + free(decoded); +} + +TEST(DecodeTest, DecodeUInt1ToINT8) +{ + using target_dtype = int8_t; + constexpr int nbits = 1; + constexpr int N = 32 / nbits; + constexpr int QN = N / 8 * nbits; + constexpr bool isSigned = false; + constexpr int zero_point = isSigned ? ((1 << (nbits - 1)) - 1) : 0; + + // create four int8_t values + int8_t in_data[N] = { + 0, + }; + // breed seed + srand(0); + + // random initializations with nbits range + for (int i = 0; i < N; i++) + { + in_data[i] = (rand() % (1 << nbits)) - zero_point; + } + + int8_t *ins = new int8_t[QN]; + general_compress(in_data, ins, nbits, N, isSigned); + + int8_t *interleaved = new int8_t[QN]; + general_interleave_int8(ins, interleaved, nbits, QN * sizeof(int8_t), false); + + target_dtype *decoded = new target_dtype[N]; + int8_t *ins_gpu; + target_dtype *decoded_gpu; + + cudaCheckLastError(cudaMalloc((void **)&ins_gpu, QN * sizeof(int8_t))); + cudaCheckLastError(cudaMalloc((void **)&decoded_gpu, N * sizeof(target_dtype))); + cudaCheckLastError(cudaMemcpy(ins_gpu, interleaved, QN * sizeof(int8_t), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaMemcpy(decoded_gpu, decoded, N * sizeof(target_dtype), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaDeviceSynchronize()); + + kernelWrapper_i1u_to_i8s<<>>(ins_gpu, decoded_gpu); + kernelWrapper_i1u_to_i8s<<>>(ins_gpu + QN / 2, decoded_gpu + N / 2); + cudaCheckLastError(cudaDeviceSynchronize()); + cudaCheckLastError(cudaMemcpy(decoded, decoded_gpu, N * sizeof(target_dtype), cudaMemcpyDeviceToHost)); + cudaCheckLastError(cudaFree(ins_gpu)); + cudaCheckLastError(cudaFree(decoded_gpu)); + for (int i = 0; i < N; i++) + { + EXPECT_EQ(in_data[i], int(decoded[i])); + } + free(ins); + free(interleaved); + free(decoded); +} diff --git a/testing/python/dsl/test_auto_normalized_tensorcore.py b/testing/python/dsl/test_auto_normalized_tensorcore.py new file mode 100644 index 000000000000..171fef0fab39 --- /dev/null +++ b/testing/python/dsl/test_auto_normalized_tensorcore.py @@ -0,0 +1,157 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import numpy as np +import tvm +from tvm.script import tir as T +from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy +from bitblas.base.roller.arch import CUDA +from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags +from bitblas.gpu import Matmul +from bitblas.base.utils import apply_and_build +import time +from tvm import te, tir + + +def conv2d_nhwc_hwio(n, f, h, w, c, kh, kw, s, d, p, in_dtype="float16", out_dtype="float16"): + A = te.placeholder((n, h, w, c), name="input", dtype=in_dtype) + B = te.placeholder((kh, kw, c, f), name="weight", dtype=in_dtype) + + pad_shape = (n, h + 2 * p, w + 2 * p, c) + pad_value = tir.const(0.0, A.dtype) + pad = te.compute( + pad_shape, + lambda n, h, w, c: te.if_then_else( + tir.all( + h >= p, + w >= p, + h < pad_shape[1] - p, + w < pad_shape[2] - p, + ), + A[n, h - p, w - p, c], + pad_value, + ), + name="pad", + ) + kernel_h, kernel_w = kh, kw + stride_h, stride_w = s, s + dilation_h, dilation_w = d, d + out_h = (h + 2 * p - (dilation_h * (kernel_h - 1) + 1)) // stride_h + 1 + out_w = (w + 2 * p - (dilation_w * (kernel_w - 1) + 1)) // stride_w + 1 + out_shape = (n, out_h, out_w, f) + kh = te.reduce_axis((0, kernel_h), name="kh") + kw = te.reduce_axis((0, kernel_w), name="kw") + c = te.reduce_axis((0, c), name="c") + C = te.compute( + out_shape, + lambda n, h, w, f: te.sum( + pad[ + n, + h * stride_h + kh * dilation_h, + w * stride_w + kw * dilation_w, + c, + ] + * B[kh, kw, c, f], + axis=[kh, kw, c], + ), + name="C", + ) + return tvm.ir.IRModule({"main": te.create_prim_func([A, B, C])}) + + +benchmark_sets = [ + # (prim_func, input_args, default_dlight_schedule), + # (conv2d_nhwc_hwio, (128, 64, 224, 224, 3, 7, 7, 2, 1, 3, "float16", "float16"), Matmul), + # (conv2d_nhwc_hwio, (128, 64, 224, 224, 64, 1, 1, 2, 1, 3, "float16", "float16"), Matmul), + # (conv2d_nhwc_hwio, (128, 64, 224, 224, 3, 7, 7, 2, 1, 3, "float32", "float32"), Matmul), + (conv2d_nhwc_hwio, (128, 64, 224, 224, 3, 7, 7, 2, 1, 3, "float16", "float16"), Matmul), +] +benchmark_results = {} +for get_prim_func, input_args, d_schedule in benchmark_sets: + ir_module = get_prim_func(*input_args) + func = ir_module["main"] + target = tvm.target.Target("nvidia/nvidia-a100") + arch = CUDA(target) + policy = DefaultPolicy(func=func, arch=arch) + try: + func, tags = get_tensorized_func_and_tags(func, arch.target) + except: + tags = None + if tags: + policy = TensorCorePolicy(func=func, arch=arch, tags=tags) + + configs = policy.emit_config(20) + + tune_start = time.time() + cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) + fast_tune_time = time.time() - tune_start + print("[BitBLAS] The best latency of top 1 is {:.3f} ms".format(cpresults[0].latency * 1e3)) + print("[BitBLAS] The best latency of top 20 is {:.3f} ms".format(best.latency * 1e3)) + + # evaluate the performance of the default schedule + + rule = d_schedule() + default_tune_start = time.time() + sch_default = rule.apply(func, target, False) + with tvm.transform.PassContext(config={"tir.use_async_copy": True}): + mod_default = tvm.build(sch_default.mod["main"], target="cuda") + default_tune_time = time.time() - default_tune_start + + args = func.buffer_map.values() + + profile_tensors = [] + for arg in args: + profile_tensors.append( + tvm.nd.array( + np.random.uniform(0, 1, [int(i) for i in arg.shape]).astype(arg.dtype), + device=arch.device, + ) + ) + + timer_cuda_mod = mod_default.time_evaluator(mod_default.entry_name, arch.device, number=5) + t = timer_cuda_mod(*profile_tensors).mean + + print("Time cost of Dlight default schedule: {:.3f} ms".format(t * 1e3)) + + profile_config = { + f"{get_prim_func.__name__}-{'-'.join([str(i) for i in input_args])}": { + "fast_dlight_top20_tune_time": fast_tune_time, + "fast_dlight_top1_latency": cpresults[0].latency * 1e3, + "fast_dlight_top20_latency": best.latency * 1e3, + "default_dlight_tune_time": default_tune_time, + "default_dlight_latency": t * 1e3, + } + } + benchmark_results.update(profile_config) + +headers = [ + "PrimFunc", + "Input Arguments", + "FastDLight Top20 Tune Time", + "FastDLight Top1 Latency", + "FastDLight Top20 Latency", + "DefaultDLight Tune Time", + "DefaultDLight Latency", +] + +col_width = ( + max(len(word) for row in [headers] + list(profile_config.values()) for word in row) + 2 +) # padding + +print("".join(word.ljust(col_width) for word in headers)) + +print("-" * col_width * len(headers)) + +for config, values in benchmark_results.items(): + args = config.split("-") + func_name = args[0] + input_args = "-".join(args[1:]) + row = [ + func_name, + input_args, + f" {str(values['fast_dlight_top20_tune_time'])} s", + f"{values['fast_dlight_top1_latency']:.3f} ms", + f"{values['fast_dlight_top20_latency']:.3f} ms", + str(values["default_dlight_tune_time"]), + f"{values['default_dlight_latency']:.3f} ms", + ] + print("".join(word.ljust(col_width) for word in row)) diff --git a/testing/python/operators/test_int8xint8_gemm.py b/testing/python/operators/test_int8xint8_gemm.py new file mode 100644 index 000000000000..f02ebb32bd4c --- /dev/null +++ b/testing/python/operators/test_int8xint8_gemm.py @@ -0,0 +1,36 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import tvm +import bitblas +from bitblas.ops import Matmul +import numpy as np + + +def test_matmul_codegen_static_shape_optimize_s8(): + M = 16384 + N = 16384 + K = 16384 + + target = tvm.target.Target("nvidia/nvidia-a100") + + matmul = Matmul( + M=M, + N=N, + K=K, + a_dtype="int8", + b_dtype="int8", + c_dtype="int32", + propagate_a=False, + propagate_b=False, + layout="nt", + target=target, + ) + matmul.optimize() + code = matmul.codegen(target=target) + latency = matmul.profile_latency() + print(latency) + assert code + + +if __name__ == "__main__": + test_matmul_codegen_static_shape_optimize_s8() diff --git a/testing/python/operators/test_ladder_permutate_ops.py b/testing/python/operators/test_ladder_permutate_ops.py new file mode 100644 index 000000000000..a50fb11c7638 --- /dev/null +++ b/testing/python/operators/test_ladder_permutate_ops.py @@ -0,0 +1,51 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import pytest +import tvm +import bitblas +from bitblas.ops.ladder_permutate import LadderPermutate, LadderPermutateConfig + +target = tvm.target.Target("llvm") + +# fmt: off +@pytest.mark.parametrize("M,N,datatype,dequantize_bits,storage_dtype,propagate_kind,transpose_matrix,transform_kind,target_instruction", [ + (1024, 1024, "float16", -1, "float16", "B", True, 0, "nvidia-mma"), + (1024, 1024, "float16", -1, "float16", "B", True, 1, "nvidia-mma"), + (1024, 1024, "float16", -1, "float16", "B", True, 2, "nvidia-mma"), + # dequantize propagation + (1024, 1024, "float16", 4, "uint32", "B", True, 2, "nvidia-mma"), +]) +def test_ladder_permutate_profile_latency( + M, + N, + datatype, + dequantize_bits, + storage_dtype, + propagate_kind, + transpose_matrix, + transform_kind, + target_instruction, +): + + ladder_permutate_config = LadderPermutateConfig( + M=M, + N=N, + datatype=datatype, + dequantize_bits=dequantize_bits, + storage_dtype=storage_dtype, + propagate_kind=propagate_kind, + transpose_matrix=transpose_matrix, + transform_kind=transform_kind, + target_instruction=target_instruction, + ) + ladder_permutate = LadderPermutate( + config=ladder_permutate_config, + target=target, + ) + latency = ladder_permutate.profile_latency() + assert latency + +# fmt: on + +if __name__ == "__main__": + bitblas.testing.main() diff --git a/testing/python/operators/test_lop3_permutate_ops.py b/testing/python/operators/test_lop3_permutate_ops.py new file mode 100644 index 000000000000..31017d91eeeb --- /dev/null +++ b/testing/python/operators/test_lop3_permutate_ops.py @@ -0,0 +1,38 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import pytest +import tvm +import bitblas +from bitblas.ops.lop3_permutate import LOP3Permutate, LOP3PermutateConfig + +target = tvm.target.Target("llvm") + +# fmt: off +@pytest.mark.parametrize("M,N,datatype,dequantize_bits,storage_dtype", [ + (1024, 1024, "float16", 4, "uint32"), +]) +def test_lop3_permutate_profile_latency( + M, + N, + datatype, + dequantize_bits, + storage_dtype +): + + lop3_permutate_config = LOP3PermutateConfig( + M=M, + N=N, + datatype=datatype, + dequantize_bits=dequantize_bits, + storage_dtype=storage_dtype + ) + lop3_permutate = LOP3Permutate( + config=lop3_permutate_config, + target=target, + ) + latency = lop3_permutate.profile_latency() + assert latency +# fmt: on + +if __name__ == "__main__": + bitblas.testing.main() diff --git a/testing/python/operators/test_matmul_dequantize_ops.py b/testing/python/operators/test_matmul_dequantize_ops.py new file mode 100644 index 000000000000..1dd80c3fd3b9 --- /dev/null +++ b/testing/python/operators/test_matmul_dequantize_ops.py @@ -0,0 +1,276 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import pytest +import tvm +import bitblas +from bitblas.ops.matmul_dequantize import ( + MatmulWeightOnlyDequantize, + MatmulWeightOnlyDequantizeConfig, +) +from bitblas.utils.tensor_adapter import tvm_tensor_to_torch + +target = tvm.target.Target("nvidia/nvidia-a100") + + +def get_codegen_result(ops, target): + code = ops.codegen(target=target) + return code + + +# fmt: off +@pytest.mark.parametrize( + "M,N,K,in_dtype,out_dtype,accum_dtype,bit,storage_dtype,source_format,with_scaling,group_size,fast_decoding,with_bias,propagate_a,propagate_b,layout", + [ + (1, 1024, 1024, "float16", "float16", "float16", 4, "int8", "uint", False, -1, False, False, False, False, "nt"), + ], +) +def test_matmul_dequantize_codegen_default( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + bit, + storage_dtype, + source_format, + with_scaling, + group_size, + fast_decoding, + with_bias, + propagate_a, + propagate_b, + layout, +): + + matmul_config = MatmulWeightOnlyDequantizeConfig( + M=M, + N=N, + K=K, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + bit=bit, + storage_dtype=storage_dtype, + source_format=source_format, + with_scaling=with_scaling, + group_size=group_size, + fast_decoding=fast_decoding, + with_bias=with_bias, + propagate_a=propagate_a, + propagate_b=propagate_b, + layout=layout, + ) + matmul = MatmulWeightOnlyDequantize( + config=matmul_config, + target=target, + ) + assert get_codegen_result(matmul, target) + + +@pytest.mark.parametrize( + "M,N,K,in_dtype,out_dtype,accum_dtype,bit,storage_dtype,source_format,with_scaling,group_size,fast_decoding,with_bias,propagate_a,propagate_b,layout", + [ + (1, 1024, 1024, "float16", "float16", "float16", 4, "int8", "uint", False, -1, False, False, False, False, "nt",), + ], +) +def test_matmul_dequantize_codegen_finetune( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + bit, + storage_dtype, + source_format, + with_scaling, + group_size, + fast_decoding, + with_bias, + propagate_a, + propagate_b, + layout, +): + + matmul_config = MatmulWeightOnlyDequantizeConfig( + M=M, + N=N, + K=K, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + bit=bit, + storage_dtype=storage_dtype, + source_format=source_format, + with_scaling=with_scaling, + group_size=group_size, + fast_decoding=fast_decoding, + with_bias=with_bias, + propagate_a=propagate_a, + propagate_b=propagate_b, + layout=layout, + ) + matmul = MatmulWeightOnlyDequantize( + config=matmul_config, + target=target, + ) + matmul.hardware_aware_finetune(topk=20) + assert get_codegen_result(matmul, target) + + +@pytest.mark.parametrize( + "M,N,K,in_dtype,out_dtype,accum_dtype,bit,storage_dtype,source_format,with_scaling,group_size,fast_decoding,with_bias,propagate_a,propagate_b,layout", + [ + (1, 1024, 1024, "float16", "float16", "float16", 4, "int8", "uint", False, -1, False, False, False, False, "nt",), + ], +) +def test_matmul_dequantize_profile_latency( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + bit, + storage_dtype, + source_format, + with_scaling, + group_size, + fast_decoding, + with_bias, + propagate_a, + propagate_b, + layout, +): + + matmul_config = MatmulWeightOnlyDequantizeConfig( + M=M, + N=N, + K=K, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + bit=bit, + storage_dtype=storage_dtype, + source_format=source_format, + with_scaling=with_scaling, + group_size=group_size, + fast_decoding=fast_decoding, + with_bias=with_bias, + propagate_a=propagate_a, + propagate_b=propagate_b, + layout=layout, + ) + matmul = MatmulWeightOnlyDequantize( + config=matmul_config, + target=target, + ) + matmul.hardware_aware_finetune(topk=20) + latency = matmul.profile_latency() + assert latency + + +@pytest.mark.parametrize( + "M,N,K,in_dtype,out_dtype,accum_dtype,bit,storage_dtype,source_format,with_scaling,group_size,fast_decoding,with_bias,propagate_a,propagate_b,layout", + [ + (1, 1024, 1024, "float16", "float16", "float16", 4, "int8", "uint", False, -1, False, False, False, False, "nt",), + (1, 1024, 1024, "float16", "float16", "float16", 4, "int8", "uint", False, -1, True, False, False, False, "nt",), + (1, 1024, 1024, "float16", "float16", "float16", 4, "int8", "int", False, -1, False, False, False, False, "nt",), + (1, 1024, 1024, "float16", "float16", "float16", 4, "int8", "int", False, -1, True, False, False, False, "nt",), + (1, 1024, 1024, "float16", "float16", "float16", 2, "int8", "int", False, -1, True, False, False, False, "nt",), + (1, 1024, 1024, "float16", "float16", "float16", 2, "int8", "int", True, -1, True, False, False, False, "nt",), + (1, 1024, 1024, "float16", "float16", "float16", 2, "int8", "int", True, 128, True, False, False, False, "nt",), + (1024, 1024, 1024, "float16", "float16", "float16", 2, "int8", "int", True, 128, False, False, False, False, "nt",), + ], +) +def test_matmul_dequantize_torch_forward( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + bit, + storage_dtype, + source_format, + with_scaling, + group_size, + fast_decoding, + with_bias, + propagate_a, + propagate_b, + layout, +): + import torch + import numpy as np + from bitblas.quantization.utils import general_compress + + matmul_config = MatmulWeightOnlyDequantizeConfig( + M=M, + N=N, + K=K, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + bit=bit, + storage_dtype=storage_dtype, + source_format=source_format, + with_scaling=with_scaling, + group_size=group_size, + fast_decoding=fast_decoding, + with_bias=with_bias, + propagate_a=propagate_a, + propagate_b=propagate_b, + layout=layout, + ) + matmul = MatmulWeightOnlyDequantize( + config=matmul_config, + target=target, + ) + matmul.hardware_aware_finetune(topk=20) + input_shape = (M, K) + weight_shape = (N, K) if layout == "nt" else (K, N) + output_shape = (M, N) + inputs = [] + inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda()) + maxq = 2 ** (bit - 1) - 1 + if source_format == "uint": + inputs.append(torch.randint(0, maxq, weight_shape, dtype=torch.int8).cuda()) + elif source_format == "int": + inputs.append(torch.randint(-maxq, maxq, weight_shape, dtype=torch.int8).cuda()) + else: + raise NotImplementedError + + inputs.append(torch.rand(output_shape, dtype=torch.float16).cuda()) + ref_result = torch.matmul(inputs[0], (inputs[1].t() if layout == "nt" else inputs[1]).to(torch.float16)) + + intweight = inputs[1] + intweight = intweight.cpu().numpy().astype(np.int8) + if source_format == "int": + intweight = intweight + maxq + # quantize to 4bit + qw_np = general_compress( + intweight, source_bits=bit, storage_dtype=np.int8 + ) + qw_torch = torch.from_numpy(qw_np).cuda() + permuted_inputs = [] + permuted_inputs.append(inputs[0]) + if matmul.weight_transform is not None: + permuted_inputs.append( + matmul.weight_transform(qw_torch.cpu()).cuda() + ) + else: + permuted_inputs.append(qw_torch) + if with_scaling: + if group_size == -1: + group_size = K + permuted_inputs.append(torch.ones([N, K // group_size], dtype=torch.float16).cuda()) + permuted_inputs.append(inputs[2]) + matmul(*permuted_inputs) + torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e-2, atol=1e-2) + +# fmt: on + +if __name__ == "__main__": + bitblas.testing.main() diff --git a/testing/python/operators/test_matmul_ops.py b/testing/python/operators/test_matmul_ops.py new file mode 100644 index 000000000000..86de6f9007a0 --- /dev/null +++ b/testing/python/operators/test_matmul_ops.py @@ -0,0 +1,211 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import pytest +import tvm +import bitblas +from bitblas.ops.matmul import Matmul, MatmulConfig +from bitblas.utils import tvm_tensor_to_torch + +target = tvm.target.Target("nvidia/nvidia-a100") + + +def get_codegen_result(ops, target): + code = ops.codegen(target=target) + return code + + +# fmt: off +@pytest.mark.parametrize( + "M,N,K,in_dtype,out_dtype,accum_dtype,with_bias,propagate_a,propagate_b,layout", + [ + (16384, 16384, 16384, "float16", "float16", "float16", False, False, False, "nt"), + # dynamic shape + ([1], 16384, 16384, "float16", "float16", "float16", False, False, False, "nt"), + ], +) +def test_matmul_codegen_default( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + with_bias, + propagate_a, + propagate_b, + layout, +): + + matmul_config = MatmulConfig( + M=M, + N=N, + K=K, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + with_bias=with_bias, + propagate_a=propagate_a, + propagate_b=propagate_b, + layout=layout, + ) + matmul = Matmul( + config=matmul_config, + target=target, + ) + assert get_codegen_result(matmul, target) + + +@pytest.mark.parametrize( + "M,N,K,in_dtype,out_dtype,accum_dtype,with_bias,propagate_a,propagate_b,layout", + [ + (16384, 16384, 16384, "float16", "float16", "float16", False, False, False, "nt"), + # dynamic shape + ([1], 16384, 16384, "float16", "float16", "float16", False, False, False, "nt"), + ], +) +def test_matmul_codegen_finetune( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + with_bias, + propagate_a, + propagate_b, + layout, +): + + matmul_config = MatmulConfig( + M=M, + N=N, + K=K, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + with_bias=with_bias, + propagate_a=propagate_a, + propagate_b=propagate_b, + layout=layout, + ) + matmul = Matmul( + config=matmul_config, + target=target, + ) + matmul.hardware_aware_finetune(topk=20) + assert get_codegen_result(matmul, target) + + +@pytest.mark.parametrize( + "M,N,K,in_dtype,out_dtype,accum_dtype,with_bias,propagate_a,propagate_b,layout", + [ + (1024, 1024, 1024, "float16", "float16", "float16", False, False, False, "nt"), + ], +) +def test_matmul_profile_latency( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + with_bias, + propagate_a, + propagate_b, + layout, +): + matmul_config = MatmulConfig( + M=M, + N=N, + K=K, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + with_bias=with_bias, + propagate_a=propagate_a, + propagate_b=propagate_b, + layout=layout, + ) + matmul = Matmul( + config=matmul_config, + target=target, + ) + latency = matmul.profile_latency() + assert latency + + +@pytest.mark.parametrize( + "M,N,K,in_dtype,out_dtype,accum_dtype,with_bias,propagate_a,propagate_b,layout", + [ + (256, 256, 256, "float16", "float16", "float16", False, False, False, "nt"), + ], +) +def test_matmul_torch_forward( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + with_bias, + propagate_a, + propagate_b, + layout, +): + import torch + + matmul_config = MatmulConfig( + M=M, + N=N, + K=K, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + with_bias=with_bias, + propagate_a=propagate_a, + propagate_b=propagate_b, + layout=layout, + ) + matmul = Matmul( + config=matmul_config, + target=target, + ) + + # convert tensors to torch + input_shape = (M, K) + weight_shape = (N, K) if layout == "nt" else (K, N) + output_shape = (M, N) + inputs = [] + inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda()) + inputs.append(torch.rand(weight_shape, dtype=torch.float16).cuda()) + inputs.append(torch.rand(output_shape, dtype=torch.float16).cuda()) + ref_result = torch.matmul(inputs[0], inputs[1].t() if layout == "nt" else inputs[1]) + + permuted_inputs = [] + if matmul.input_transform is not None: + permuted_input = tvm_tensor_to_torch( + matmul.input_transform.get_profile_tensors()[-1] + ) + permuted_inputs.append( + matmul.input_transform(inputs[0].cpu(), permuted_input.cpu()) + ).cuda() + else: + permuted_inputs.append(inputs[0]) + if matmul.weight_transform is not None: + permuted_input = tvm_tensor_to_torch( + matmul.weight_transform.get_profile_tensors()[-1] + ) + permuted_inputs.append( + matmul.weight_transform(inputs[1].cpu(), permuted_input.cpu()).cuda() + ) + else: + permuted_inputs.append(inputs[1]) + permuted_inputs.append(inputs[2]) + matmul(*permuted_inputs) + torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e-2, atol=1e-2) + +# fmt: on + +if __name__ == "__main__": + bitblas.testing.main() + diff --git a/testing/python/operators/test_weight_dequantize_matmul_codegen.py b/testing/python/operators/test_weight_dequantize_matmul_codegen.py new file mode 100644 index 000000000000..9abc3df4036c --- /dev/null +++ b/testing/python/operators/test_weight_dequantize_matmul_codegen.py @@ -0,0 +1,73 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import tvm +import bitblas +from bitblas.ops import Matmul, MatmulWeightOnlyDequantize +import numpy as np + + +def test_weight_only_matmul_codegen_static_shape_optimize(): + M = 16384 + N = 16384 + K = 16384 + + target = tvm.target.Target("nvidia/nvidia-a100") + + matmul = MatmulWeightOnlyDequantize( + M=M, + N=N, + K=K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + propagate_b=True, + bit=4, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + group_size=-1, + fast_decoding=True, + with_bias=False, + target=target, + ) + matmul.optimize(topk=20) + code = matmul.codegen(target=target) + latency = matmul.profile_latency() + print(latency) + assert code + + +def test_weight_only_matmul_codegen_static_shape_optimize_s8(): + M = 16384 + N = 16384 + K = 16384 + + target = tvm.target.Target("nvidia/nvidia-a100") + + matmul = MatmulWeightOnlyDequantize( + M=M, + N=N, + K=K, + in_dtype="int8", + out_dtype="int8", + accum_dtype="int32", + propagate_b=True, + bit=2, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + group_size=-1, + fast_decoding=True, + with_bias=False, + target=target, + ) + matmul.optimize() + code = matmul.codegen(target=target) + latency = matmul.profile_latency() + print(latency) + assert code + + +if __name__ == "__main__": + test_weight_only_matmul_codegen_static_shape_optimize() + # test_weight_only_matmul_codegen_static_shape_optimize_s8() diff --git a/testing/python/test_fused_decode_matmul.py b/testing/python/test_fused_decode_matmul.py new file mode 100644 index 000000000000..fb90ede368fe --- /dev/null +++ b/testing/python/test_fused_decode_matmul.py @@ -0,0 +1,132 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from tvm.script import ir as I +from tvm.script import tir as T +from tvm.script import relax as R + + +@T.prim_func +def fused_fused_decode3_fused_NT_matmul8_add1( + lv47: T.Buffer((T.int64(4096), T.int64(512)), "uint32"), + lv48: T.Buffer((T.int64(4096), T.int64(128)), "float16"), + p_lv41: T.handle, + p_lv2: T.handle, + p_output0: T.handle, +): + T.func_attr( + { + "tir.noalias": T.bool(True), + "dequantize_info": { + "B": { + "decode_block": "decode", + "fast_decoding": True, + "source_format": { + "bits": 4, + "format": "int", + }, + "with_scaling": True, + "storage_dtype": "uint32", + "group_size": 32, + "target_format": "float16", + } + }, + } + ) + n = T.int64() + lv41 = T.match_buffer(p_lv41, (T.int64(1), n, T.int64(4096)), "float16") + lv2 = T.match_buffer(p_lv2, (T.int64(1), n, T.int64(4096)), "float16") + T_add_intermediate_intermediate = T.match_buffer( + p_output0, (T.int64(1), n, T.int64(4096)), "float16" + ) + # with T.block("root"): + decode_intermediate_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16") + NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(4096)), "float16") + for i, j in T.grid(T.int64(4096), T.int64(4096)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv47[v_i, v_j // T.int64(8)], lv48[v_i, v_j // T.int64(32)]) + T.writes(decode_intermediate_intermediate[v_i, v_j]) + decode_intermediate_intermediate[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv47[v_i, v_j // T.int64(8)], + T.Cast("uint32", v_j % T.int64(8)) * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv48[v_i, v_j // T.int64(32)] + for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(4096)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv41[v_i0, v_i1, v_k], decode_intermediate_intermediate[v_i2, v_k]) + T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + NT_matmul_intermediate[v_i0, v_i1, v_i2] = ( + NT_matmul_intermediate[v_i0, v_i1, v_i2] + + lv41[v_i0, v_i1, v_k] * decode_intermediate_intermediate[v_i2, v_k] + ) + for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(lv2[v_ax0, v_ax1, v_ax2], NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]) + T.writes(T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2]) + T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2] = ( + lv2[v_ax0, v_ax1, v_ax2] + NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + ) + + +import tvm +from tvm import dlight as dl +import bitblas +from tvm import relax +import bitblas +from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy +from bitblas.base.roller.arch import CUDA +from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags + +dispatch_target = tvm.target.Target("cuda") +mod_deploy = tvm.IRModule.from_expr(fused_fused_decode3_fused_NT_matmul8_add1.specialize({"n": T.int64(1)})) +target = tvm.target.Target("nvidia/nvidia-a100") +arch = CUDA(target) +func = fused_fused_decode3_fused_NT_matmul8_add1.specialize({"n": T.int64(1)}) +policy = DefaultPolicy(func=func, arch=arch) +try: + tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) +except: + tags = None +if tags: + policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) + +configs = policy.emit_config(20) +print(configs[0]) +sch = bitblas.gpu.gemv.GEMVWithDequantizeInfo().apply_config(func, configs[0]) + +# print(sch.mod) +# with dispatch_target: +# mod_deploy = dl.ApplyDefaultSchedule( # pylint: disable=not-callable +# dl.gpu.Matmul(), +# dl.gpu.GEMV(), +# dl.gpu.Reduction(), +# dl.gpu.GeneralReduction(), +# dl.gpu.Fallback(), +# )(mod_deploy) +# dynamic_range = { +# "n": [64], +# } +# mod_deploy = bitblas.ApplyFastTuning( +# topk=20, +# target=dispatch_target, +# meta_database_dir="vicuna_tune", +# whitelist=["matmul"], +# )(mod_deploy) + +# with tvm.transform.PassContext(config={"tir.use_async_copy": False}): +# mod = tvm.build(mod_deploy, target=dispatch_target) + +# with open("debug/test_dl_fused_decode_matmul.cu", "+w") as f: +# f.write(mod.imported_modules[0].get_source()) diff --git a/testing/python/test_lop3_type_conversion.py b/testing/python/test_lop3_type_conversion.py new file mode 100644 index 000000000000..5fb32cfe1767 --- /dev/null +++ b/testing/python/test_lop3_type_conversion.py @@ -0,0 +1,78 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import tvm +from tvm.script import tir as T +import bitblas +from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy +from bitblas.base.roller.arch import CUDA +from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags +from bitblas.base.utils import apply_and_build +from bitblas.ops.matmul_impl import matmul_nt, matmul_nt_dequantize_b +import numpy as np + + +def test_f16_f16_gemm(): + ir_module = matmul_nt(1, 16384, 16384, "float16", "float16") + func = ir_module["main"] + target = tvm.target.Target("nvidia/nvidia-a100") + arch = CUDA(target) + policy = DefaultPolicy(func=func, arch=arch) + try: + tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) + except: + tags = None + if tags: + policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) + + configs = policy.emit_config(20) + + cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) + print( + "[BitBLAS] The best latency of top 1 is {:.3f} ms".format( + cpresults[0].latency * 1e3 + ) + ) + print( + "[BitBLAS] The best latency of top 20 is {:.3f} ms".format(best.latency * 1e3) + ) + + +def test_f16_i4_gemm(M=1, N=16384, K=16384, bit=4, fast_decoding=True): + ir_module = matmul_nt_dequantize_b( + M, + N, + K, + "float16", + bit=bit, + storage_dtype="uint32", + with_scaling=True, + group_size=-1, + fast_decoding=fast_decoding, + ) + func = ir_module["main"] + target = tvm.target.Target("nvidia/nvidia-a100") + arch = CUDA(target) + policy = DefaultPolicy(func=func, arch=arch) + try: + tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) + except: + tags = None + if tags: + policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) + + configs = policy.emit_config(20) + # sch = bitblas.gpu.gemv.GEMVWithDequantizeInfo().apply_config(func, configs[0]) + cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) + print( + "[BitBLAS] The best latency of top 1 is {:.3f} ms".format( + cpresults[0].latency * 1e3 + ) + ) + print( + "[BitBLAS] The best latency of top 20 is {:.3f} ms".format(best.latency * 1e3) + ) + with open("debug/tmp.cu", "w") as f: + f.write(str(best.code)) + + +test_f16_i4_gemm() diff --git a/testing/python/test_matmul_codegen.py b/testing/python/test_matmul_codegen.py new file mode 100644 index 000000000000..7fe7a483150c --- /dev/null +++ b/testing/python/test_matmul_codegen.py @@ -0,0 +1,221 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import tvm +from tvm.script import ir as I +from tvm.script import tir as T +import bitblas +from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy +from bitblas.base.roller.arch import CUDA +from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags +from bitblas.base.utils import apply_and_build +from bitblas.ops.matmul_impl import matmul_nt, matmul_nt_dequantize_b +import numpy as np + + +@I.ir_module +class Module: + @T.prim_func + def dequantize_gemv( + lv47: T.Buffer((T.int64(256), T.int64(256), T.int64(16), T.int64(2)), "uint32"), + lv48: T.Buffer((T.int64(4096), T.int64(128)), "float16"), + lv41: T.Buffer((T.int64(1), 1, T.int64(4096)), "float16"), + NT_matmul_intermediate: T.Buffer((T.int64(1), 1, T.int64(4096)), "float16"), + ): + T.func_attr( + { + "dequantize_info": { + "decode": { + "decode_block": "decode", + "fast_decoding": T.bool(True), + "group_size": 32, + "source_format": {"bits": 4, "format": "int"}, + "storage_dtype": "uint32", + "target_format": "float16", + "with_scaling": T.bool(True), + } + }, + "smooth_b": T.bool(True), + "tir.noalias": T.bool(True), + "transform_kind": 2, + } + ) + # with T.block("root"): + decode_intermediate_intermediate = T.alloc_buffer( + (T.int64(4096), T.int64(4096)), "float16" + ) + lv47_global = T.alloc_buffer((T.int64(4096), T.int64(512)), "uint32") + for ax0, ax1 in T.grid(T.int64(4096), T.int64(512)): + with T.block("lv47_global"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + T.reads( + lv47[ + v0 // T.int64(16), + v1 // T.int64(2), + v0 % T.int64(16) // T.int64(8) * T.int64(8) + + v0 % T.int64(4) * T.int64(2) + + v1 % T.int64(2), + v0 % T.int64(8) // T.int64(4), + ] + ) + T.writes(lv47_global[v0, v1]) + lv47_global[v0, v1] = lv47[ + v0 // T.int64(16), + v1 // T.int64(2), + v0 % T.int64(16) // T.int64(8) * T.int64(8) + + v0 % T.int64(4) * T.int64(2) + + v1 % T.int64(2), + v0 % T.int64(8) // T.int64(4), + ] + for i, j in T.grid(T.int64(4096), T.int64(4096)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads( + lv47_global[v_i, v_j // T.int64(8)], lv48[v_i, v_j // T.int64(32)] + ) + T.writes(decode_intermediate_intermediate[v_i, v_j]) + decode_intermediate_intermediate[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv47_global[v_i, v_j // T.int64(8)], + T.Cast("uint32", v_j % T.int64(8)) * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv48[v_i, v_j // T.int64(32)] + for i0, i1, i2, k in T.grid(T.int64(1), 1, T.int64(4096), T.int64(4096)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads( + lv41[v_i0, v_i1, v_k], decode_intermediate_intermediate[v_i2, v_k] + ) + T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + NT_matmul_intermediate[v_i0, v_i1, v_i2] = ( + NT_matmul_intermediate[v_i0, v_i1, v_i2] + + lv41[v_i0, v_i1, v_k] + * decode_intermediate_intermediate[v_i2, v_k] + ) + + @T.prim_func + def dequantize_gemm( + lv47: T.Buffer((T.int64(256), T.int64(256), T.int64(16), T.int64(2)), "uint32"), + lv48: T.Buffer((T.int64(4096), T.int64(128)), "float16"), + lv41: T.Buffer((T.int64(1), T.int64(4096), T.int64(4096)), "float16"), + NT_matmul_intermediate: T.Buffer((T.int64(1), 4096, T.int64(4096)), "float16"), + ): + T.func_attr( + { + "dequantize_info": { + "decode": { + "decode_block": "decode", + "fast_decoding": T.bool(True), + "group_size": 32, + "source_format": {"bits": 4, "format": "int"}, + "storage_dtype": "uint32", + "target_format": "float16", + "with_scaling": T.bool(True), + } + }, + "smooth_b": T.bool(True), + "tir.noalias": T.bool(True), + "transform_kind": 2, + } + ) + # with T.block("root"): + decode_intermediate_intermediate = T.alloc_buffer( + (T.int64(4096), T.int64(4096)), "float16" + ) + lv47_global = T.alloc_buffer((T.int64(4096), T.int64(512)), "uint32") + for ax0, ax1 in T.grid(T.int64(4096), T.int64(512)): + with T.block("lv47_global"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + T.reads( + lv47[ + v0 // T.int64(16), + v1 // T.int64(2), + v0 % T.int64(16) // T.int64(8) * T.int64(8) + + v0 % T.int64(4) * T.int64(2) + + v1 % T.int64(2), + v0 % T.int64(8) // T.int64(4), + ] + ) + T.writes(lv47_global[v0, v1]) + lv47_global[v0, v1] = lv47[ + v0 // T.int64(16), + v1 // T.int64(2), + v0 % T.int64(16) // T.int64(8) * T.int64(8) + + v0 % T.int64(4) * T.int64(2) + + v1 % T.int64(2), + v0 % T.int64(8) // T.int64(4), + ] + for i, j in T.grid(T.int64(4096), T.int64(4096)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads( + lv47_global[v_i, v_j // T.int64(8)], lv48[v_i, v_j // T.int64(32)] + ) + T.writes(decode_intermediate_intermediate[v_i, v_j]) + decode_intermediate_intermediate[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv47_global[v_i, v_j // T.int64(8)], + T.Cast("uint32", v_j % T.int64(8)) * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv48[v_i, v_j // T.int64(32)] + for i0, i1, i2, k in T.grid( + T.int64(1), T.int64(4096), T.int64(4096), T.int64(4096) + ): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads( + lv41[v_i0, v_i1, v_k], decode_intermediate_intermediate[v_i2, v_k] + ) + T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + NT_matmul_intermediate[v_i0, v_i1, v_i2] = ( + NT_matmul_intermediate[v_i0, v_i1, v_i2] + + lv41[v_i0, v_i1, v_k] + * decode_intermediate_intermediate[v_i2, v_k] + ) + + +ir_module = Module +func = ir_module["dequantize_gemm"] +target = tvm.target.Target("nvidia/nvidia-a100") +arch = CUDA(target) +policy = DefaultPolicy(func=func, arch=arch) + +tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) +if tags: + policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) + +configs = policy.emit_config(20) +print(configs) +sch = bitblas.gpu.MatmulTensorizationMMAWithDequantizeInfo().apply_config( + func, configs[0] +) +# cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) +# print( +# "[BitBLAS] The best latency of top 1 is {:.3f} ms".format( +# cpresults[0].latency * 1e3 +# ) +# ) +# print( +# "[BitBLAS] The best latency of top 20 is {:.3f} ms".format( +# best.latency * 1e3 +# ) +# ) +# with open("debug/tmp.cu", "w") as f: +# f.write(str(best.code)) diff --git a/testing/python/test_weight_only_transform.py b/testing/python/test_weight_only_transform.py new file mode 100644 index 000000000000..3e8e91979be4 --- /dev/null +++ b/testing/python/test_weight_only_transform.py @@ -0,0 +1,353 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from tvm.script import ir as I +from tvm.script import tir as T +from tvm.script import relax as R + +import tvm +import tvm.testing +from tvm import relax +from tvm.script import ir as I, relax as R, tir as T +from tvm import tir +from tvm.ir import IRModule +from tvm.ir.transform import PassContext, module_pass +from bitblas.base.utils import get_dummy_input_arrays +from copy import deepcopy +import bitblas +from bitblas.relax.transform.annotate_decode_block import AnnotateDecodeInformation +from bitblas.relax.transform.weight_only_propagate import WeightOnlyLayoutPropagation +import numpy as np + +np.random.seed(0) + + +def get_ref_result(ref_mod, input_tensors): + # input_tensors to cpu + device = tvm.cpu(0) + target = tvm.target.Target("llvm") + input_tensors = [tvm.nd.array(x, device) for x in input_tensors] + ref_mod = tvm.tir.transform.MakePackedAPI()(ref_mod) + ex = relax.build(ref_mod, target) + vm = relax.VirtualMachine(ex, device) + res = vm["main"](*input_tensors) + return res + + +def get_default_result(ref_mod, input_tensors, target, device): + with target: + ref_mod = bitblas.ApplyDefaultSchedule( # pylint: disable=not-callable + bitblas.gpu.GEMV(), + bitblas.gpu.Matmul(), + bitblas.gpu.Reduction(), + bitblas.gpu.GeneralReduction(), + bitblas.gpu.Fallback(), + )(ref_mod) + ref_mod = tvm.tir.transform.MakePackedAPI()(ref_mod) + ex = relax.build(ref_mod, target) + vm = relax.VirtualMachine(ex, device) + res = vm["main"](*input_tensors) + return res + + +def get_fast_tune_result(ref_mod, input_tensors, target, device): + ref_mod = bitblas.ApplyFastTuning(target=target)(ref_mod) + ref_mod = tvm.tir.transform.MakePackedAPI()(ref_mod) + print(ref_mod) + ex = relax.build(ref_mod, target) + vm = relax.VirtualMachine(ex, device) + res = vm["main"](*input_tensors) + return res + + +def test_lop3_transform(): + @I.ir_module + class Before: + @T.prim_func(private=True) + def fused_fused_decode3_fused_NT_matmul8_add1( + lv47: T.Buffer((T.int64(4096), T.int64(512)), "uint32"), + lv48: T.Buffer((T.int64(4096), T.int64(128)), "float16"), + p_lv41: T.handle, + p_output0: T.handle, + ): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv41 = T.match_buffer(p_lv41, (T.int64(1), 1, T.int64(4096)), "float16") + NT_matmul_intermediate = T.match_buffer( + p_output0, (T.int64(1), 1, T.int64(4096)), "float16" + ) + # with T.block("root"): + decode_intermediate_intermediate = T.alloc_buffer( + (T.int64(4096), T.int64(4096)), "float16" + ) + for i, j in T.grid(T.int64(4096), T.int64(4096)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv47[v_i, v_j // T.int64(8)], lv48[v_i, v_j // T.int64(32)]) + T.writes(decode_intermediate_intermediate[v_i, v_j]) + decode_intermediate_intermediate[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv47[v_i, v_j // T.int64(8)], + T.Cast("uint32", v_j % T.int64(8)) * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv48[v_i, v_j // T.int64(32)] + for i0, i1, i2, k in T.grid(T.int64(1), 1, T.int64(4096), T.int64(4096)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads( + lv41[v_i0, v_i1, v_k], + decode_intermediate_intermediate[v_i2, v_k], + ) + T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + NT_matmul_intermediate[v_i0, v_i1, v_i2] = ( + NT_matmul_intermediate[v_i0, v_i1, v_i2] + + lv41[v_i0, v_i1, v_k] + * decode_intermediate_intermediate[v_i2, v_k] + ) + + @R.function + def main( + lv47: R.Tensor((T.int64(4096), T.int64(512)), dtype="uint32"), + lv48: R.Tensor((T.int64(4096), T.int64(128)), dtype="float16"), # type: ignore + p_lv41: R.Tensor((T.int64(1), 1, T.int64(4096)), dtype="float16"), + ) -> R.Tensor((1, 4096), dtype="float16"): + R.func_attr({"Primitive": 1}) + # n = T.int64() + cls = Before + with R.dataflow(): + gv = R.call_tir( + cls.fused_fused_decode3_fused_NT_matmul8_add1, + (lv47, lv48, p_lv41), + out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"), + ) + R.output(gv) + return gv + + relax_mod = Before + ref_mod = deepcopy(relax_mod) + dispatch_target = tvm.target.Target("cuda") + # input_arrays = get_dummy_input_arrays(relax_mod) + + relax_mod = AnnotateDecodeInformation()(relax_mod) + with dispatch_target: + relax_mod = WeightOnlyLayoutPropagation( + transform_level=0, faster_conversion=False + )(relax_mod) + + input_tensors = get_dummy_input_arrays(ref_mod["main"], tvm.cpu()) + + ref_mod = tvm.tir.transform.MakePackedAPI()(ref_mod) + ex = relax.build(ref_mod, "llvm") + + device = tvm.cpu(0) + vm = relax.VirtualMachine(ex, device) + res = vm["main"](*input_tensors) + print("ref ", res) + + print(relax_mod) + relax_mod = tvm.tir.transform.MakePackedAPI()(relax_mod) + ex = relax.build(relax_mod, "llvm") + + device = tvm.cpu(0) + vm = relax.VirtualMachine(ex, device) + res = vm["main"](*input_tensors) + print("relax ", res) + + +def test_matmul_transform(transform_level = 2): + + @I.ir_module + class Before: + @T.prim_func(private=True) + def fused_fused_decode3_fused_NT_matmul8_add1( + p_lv41: T.handle, + lv47: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), + p_output0: T.handle, + ): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv41 = T.match_buffer(p_lv41, (T.int64(1), 1, T.int64(4096)), "float16") + NT_matmul_intermediate = T.match_buffer( + p_output0, (T.int64(1), 1, T.int64(4096)), "float16" + ) + + for i0, i1, i2, k in T.grid(T.int64(1), 1, T.int64(4096), T.int64(4096)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv41[v_i0, v_i1, v_k], lv47[v_i2, v_k]) + T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + NT_matmul_intermediate[v_i0, v_i1, v_i2] = ( + NT_matmul_intermediate[v_i0, v_i1, v_i2] + + lv41[v_i0, v_i1, v_k] * lv47[v_i2, v_k] + ) + + @R.function + def main( + lv47: R.Tensor((T.int64(4096), T.int64(4096)), dtype="float16"), + p_lv41: R.Tensor((T.int64(1), 1, T.int64(4096)), dtype="float16"), + ) -> R.Tensor((1, 4096), dtype="float16"): + R.func_attr({"Primitive": 1}) + # n = T.int64() + cls = Before + with R.dataflow(): + gv = R.call_tir( + cls.fused_fused_decode3_fused_NT_matmul8_add1, + (p_lv41, lv47), + out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"), + ) + R.output(gv) + return gv + + relax_mod = Before + ref_mod = deepcopy(relax_mod) + dispatch_target = tvm.target.Target("cuda") + + relax_mod = AnnotateDecodeInformation()(relax_mod) + with dispatch_target: + relax_mod = WeightOnlyLayoutPropagation( + transform_level=transform_level, faster_conversion=False + )(relax_mod) + + input_tensors = get_dummy_input_arrays(ref_mod["main"], tvm.cpu()) + + ref_mod = tvm.tir.transform.MakePackedAPI()(ref_mod) + ex = relax.build(ref_mod, "llvm") + + device = tvm.cpu(0) + vm = relax.VirtualMachine(ex, device) + res = vm["main"](*input_tensors) + print("ref ", res) + + print(relax_mod) + relax_mod = tvm.tir.transform.MakePackedAPI()(relax_mod) + ex = relax.build(relax_mod, "llvm") + + device = tvm.cpu(0) + vm = relax.VirtualMachine(ex, device) + res = vm["main"](*input_tensors) + print("relax ", res) + + +def test_dequantize_matmul_transform(transform_level=1): + + @I.ir_module + class Before: + @T.prim_func(private=True) + def fused_fused_decode3_fused_NT_matmul8_add1( + lv47: T.Buffer((T.int64(4096), T.int64(512)), "uint32"), + lv48: T.Buffer((T.int64(4096), T.int64(128)), "float16"), + p_lv41: T.handle, + p_output0: T.handle, + ): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + lv41 = T.match_buffer( + p_lv41, (T.int64(1), T.int64(1), T.int64(4096)), "float16" + ) + NT_matmul_intermediate = T.match_buffer( + p_output0, (T.int64(1), T.int64(1), T.int64(4096)), "float16" + ) + # with T.block("root"): + decode_intermediate_intermediate = T.alloc_buffer( + (T.int64(4096), T.int64(4096)), "float16" + ) + for i, j in T.grid(T.int64(4096), T.int64(4096)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv47[v_i, v_j // T.int64(8)], lv48[v_i, v_j // T.int64(32)]) + T.writes(decode_intermediate_intermediate[v_i, v_j]) + decode_intermediate_intermediate[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv47[v_i, v_j // T.int64(8)], + T.Cast("uint32", v_j % T.int64(8)) * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv48[v_i, v_j // T.int64(32)] + for i0, i1, i2, k in T.grid(T.int64(1), 1, T.int64(4096), T.int64(4096)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads( + lv41[v_i0, v_i1, v_k], + decode_intermediate_intermediate[v_i2, v_k], + ) + T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + NT_matmul_intermediate[v_i0, v_i1, v_i2] = ( + NT_matmul_intermediate[v_i0, v_i1, v_i2] + + lv41[v_i0, v_i1, v_k] + * decode_intermediate_intermediate[v_i2, v_k] + ) + + @R.function + def main( + lv47: R.Tensor((T.int64(4096), T.int64(512)), dtype="uint32"), + lv48: R.Tensor((T.int64(4096), T.int64(128)), dtype="float16"), # type: ignore + p_lv41: R.Tensor((T.int64(1), T.int64(1), T.int64(4096)), dtype="float16"), + ) -> R.Tensor((1, 4096), dtype="float16"): + R.func_attr({"Primitive": 1}) + # n = T.int64() + cls = Before + with R.dataflow(): + gv = R.call_tir( + cls.fused_fused_decode3_fused_NT_matmul8_add1, + (lv47, lv48, p_lv41), + out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"), + ) + R.output(gv) + return gv + + relax_mod = Before + ref_mod = deepcopy(relax_mod) + dispatch_target = tvm.target.Target("cuda") + # input_arrays = get_dummy_input_arrays(relax_mod) + device = tvm.cpu(0) + if dispatch_target.kind.name == "cuda": + device = tvm.cuda(0) + + relax_mod = AnnotateDecodeInformation()(relax_mod) + with dispatch_target: + relax_mod = WeightOnlyLayoutPropagation( + transform_level=transform_level, faster_conversion=False + )(relax_mod) + input_tensors = get_dummy_input_arrays(ref_mod["main"], device) + + print("=======================ref llvm result=======================") + # ref_res = get_ref_result(ref_mod, input_tensors) + # print("ref_mod", ref_res) + # bitblas_res = get_ref_result(relax_mod, input_tensors) + # print("bitblas_res", bitblas_res) + print("=======================default gpu result=======================") + # ref_res = get_default_result(ref_mod, input_tensors, dispatch_target, device) + # print("ref_mod", ref_res) + # bitblas_res = get_default_result(relax_mod, input_tensors, dispatch_target, device) + # print("bitblas_res", bitblas_res) + # print("=======================fast tune gpu result=======================") + ref_res = get_fast_tune_result(ref_mod, input_tensors, dispatch_target, device) + print("ref_mod", ref_res) + print(relax_mod) + bitblas_res = get_fast_tune_result( + relax_mod, input_tensors, dispatch_target, device + ) + print("bitblas_res", bitblas_res) + + +# test_lop3_transform() +# test_matmul_transform() +# test_dequantize_matmul_transform() diff --git a/testing/python/tir_expr/float16xfloat16_gemm.py b/testing/python/tir_expr/float16xfloat16_gemm.py new file mode 100644 index 000000000000..75fc24ef0939 --- /dev/null +++ b/testing/python/tir_expr/float16xfloat16_gemm.py @@ -0,0 +1,82 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import tvm +from tvm.script import tir as T +from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy +from bitblas.base.roller.arch import CUDA +from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags +from bitblas.base.utils import apply_and_build +from bitblas.ops.matmul_impl import matmul_nt, matmul_nt_propagate_b_s8_s8_s32_mma +import numpy as np + + +def test_f16_f16_gemm(): + ir_module = matmul_nt(1024, 1024, 1024, "float16", "float16") + func = ir_module["main"] + target = tvm.target.Target("nvidia/nvidia-a100") + arch = CUDA(target) + policy = DefaultPolicy(func=func, arch=arch) + try: + tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) + except: + tags = None + if tags: + policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) + + configs = policy.emit_config(1) + + cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) + print( + "[BitBLAS] The best latency of top 1 is {:.3f} ms".format( + cpresults[0].latency * 1e3 + ) + ) + print( + "[BitBLAS] The best latency of top 1 is {:.3f} ms".format(best.latency * 1e3) + ) + + numpy_a = np.random.randint(-4, 3, (1024, 1024)).astype("float16") + numpy_b = np.random.randint(-4, 3, (1024, 1024)).astype("float16") + numpy_c = np.matmul(numpy_a.astype("float16"), numpy_b.T.astype("float16")) + ctx = tvm.cuda() + tvm_a = tvm.nd.array(numpy_a, device=ctx) + tvm_b = tvm.nd.array(numpy_b, device=ctx) + tvm_c = tvm.nd.array(np.zeros((1024, 1024), dtype="float16"), device=ctx) + print(best.code) + best.mod(tvm_a, tvm_b, tvm_c) + print(best.config) + print("numpy_c ", numpy_c) + print("tvm_c.asnumpy() ", tvm_c.asnumpy()) + + +def test_i8_i8_gemm_propagate_b(): + ir_module = matmul_nt_propagate_b_s8_s8_s32_mma( + 16384, 16384, 16384, "int8", "int32" + ) + func = ir_module["main"] + target = tvm.target.Target("nvidia/nvidia-a100") + arch = CUDA(target) + policy = DefaultPolicy(func=func, arch=arch) + try: + tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) + except: + tags = None + if tags: + policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) + + configs = policy.emit_config(1) + + cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) + print( + "[BitBLAS] The best latency of top 1 is {:.3f} ms".format( + cpresults[0].latency * 1e3 + ) + ) + print( + "[BitBLAS] The best latency of top 1 is {:.3f} ms".format(best.latency * 1e3) + ) + print(best.sch.mod) + + +test_f16_f16_gemm() +# test_i8_i8_gemm_propagate_b() diff --git a/testing/python/tir_expr/int8xint8_gemm.py b/testing/python/tir_expr/int8xint8_gemm.py new file mode 100644 index 000000000000..1515c307cd22 --- /dev/null +++ b/testing/python/tir_expr/int8xint8_gemm.py @@ -0,0 +1,368 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import tvm +import bitblas +import numpy as np +from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy +from bitblas.base.roller.arch import CUDA +from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags +from bitblas.base.utils import apply_and_build +from bitblas.ops.matmul_impl import ( + matmul_nt, + matmul_nt_dequantize_b, + matmul_nt_dequantize_b_propagate_b, + matmul_nt_dequantize_b_propagate_a_b, + matmul_nt_propagate_b_s8_s8_s32_mma, + matmul_nt_propagate_b_s8_s8_s32_cast_s8_mma, + matmul_nt_propagate_a_propagate_b_s8_s8_s32_mma, + matmul_nt_propagate_a_propagate_b_s8_s8_s32_mma_cast_s8, +) + + +def test_i8_i8_gemm(): + ir_module = matmul_nt(16384, 16384, 16384, "int8", "int32") + func = ir_module["main"] + target = tvm.target.Target("nvidia/nvidia-a100") + arch = CUDA(target) + policy = DefaultPolicy(func=func, arch=arch) + try: + tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) + except: + tags = None + if tags: + policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) + + configs = policy.emit_config(20) + + cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) + print( + "[BitBLAS] The best latency of top 1 is {:.3f} ms".format( + cpresults[0].latency * 1e3 + ) + ) + print("[BitBLAS] The best latency of top 1 is {:.3f} ms".format(best.latency * 1e3)) + with open("debug/after_memory_rewrite.cu", "+w") as f: + f.write(best.code) + + +def test_i8_i8_gemm_correctness(): + ir_module = matmul_nt(1024, 1024, 1024, "int8", "int32") + func = ir_module["main"] + target = tvm.target.Target("nvidia/nvidia-a100") + arch = CUDA(target) + policy = DefaultPolicy(func=func, arch=arch) + try: + tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) + except: + tags = None + if tags: + policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) + + configs = policy.emit_config(20) + + cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) + print( + "[BitBLAS] The best latency of top 1 is {:.3f} ms".format( + cpresults[0].latency * 1e3 + ) + ) + print("[BitBLAS] The best latency of top 1 is {:.3f} ms".format(best.latency * 1e3)) + + numpy_a = np.random.randint(-4, 3, (1024, 1024)).astype("int8") + numpy_b = np.random.randint(-4, 3, (1024, 1024)).astype("int8") + numpy_c = np.matmul(numpy_a.astype("int32"), numpy_b.T.astype("int32")) + ctx = tvm.cuda() + tvm_a = tvm.nd.array(numpy_a, device=ctx) + tvm_b = tvm.nd.array(numpy_b, device=ctx) + tvm_c = tvm.nd.array(np.zeros((1024, 1024), dtype="int32"), device=ctx) + # print(best.sch.mod) + # print(best.code) + best.mod(tvm_a, tvm_b, tvm_c) + print(best.config) + print("numpy_c ", numpy_c) + print("tvm_c.asnumpy() ", tvm_c.asnumpy()) + + np.testing.assert_allclose(tvm_c.asnumpy(), numpy_c, atol=1e-5) + # print(best.code) + + +def test_i8_i8_i32_gemm_propagate_b(): + ir_module = matmul_nt_propagate_b_s8_s8_s32_mma( + 16384, 16384, 16384, "int8", "int32" + ) + func = ir_module["main"] + target = tvm.target.Target("nvidia/nvidia-a100") + arch = CUDA(target) + policy = DefaultPolicy(func=func, arch=arch) + try: + tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) + except: + tags = None + if tags: + policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) + + configs = policy.emit_config(20) + + cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) + print( + "[BitBLAS] The best latency of top 1 is {:.3f} ms".format( + cpresults[0].latency * 1e3 + ) + ) + print("[BitBLAS] The best latency of top 1 is {:.3f} ms".format(best.latency * 1e3)) + + +def test_i8_i8_i32_cast_i8_gemm_propagate_b(): + ir_module = matmul_nt_propagate_b_s8_s8_s32_cast_s8_mma( + 16384, 16384, 16384, "int8", "int32" + ) + func = ir_module["main"] + target = tvm.target.Target("nvidia/nvidia-a100") + arch = CUDA(target) + policy = DefaultPolicy(func=func, arch=arch) + try: + tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) + except: + tags = None + if tags: + policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) + + configs = policy.emit_config(20) + + cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) + print( + "[BitBLAS] The best latency of top 1 is {:.3f} ms".format( + cpresults[0].latency * 1e3 + ) + ) + print("[BitBLAS] The best latency of top 1 is {:.3f} ms".format(best.latency * 1e3)) + + +def test_i8_i8_i32_gemm_propagate_a_propagate_b(): + ir_module = matmul_nt_propagate_a_propagate_b_s8_s8_s32_mma( + 16384, 16384, 16384, "int8", "int32" + ) + func = ir_module["main"] + target = tvm.target.Target("nvidia/nvidia-a100") + arch = CUDA(target) + policy = DefaultPolicy(func=func, arch=arch) + try: + tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) + except: + tags = None + if tags: + policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) + + configs = policy.emit_config(20) + + cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) + print( + "[BitBLAS] The best latency of top 1 is {:.3f} ms".format( + cpresults[0].latency * 1e3 + ) + ) + print("[BitBLAS] The best latency of top 1 is {:.3f} ms".format(best.latency * 1e3)) + + +def test_i8_i8_i32_gemm_propagate_a_propagate_b_cast_s8(): + ir_module = matmul_nt_propagate_a_propagate_b_s8_s8_s32_mma_cast_s8( + 16384, 16384, 16384, "int8", "int32" + ) + func = ir_module["main"] + target = tvm.target.Target("nvidia/nvidia-a100") + arch = CUDA(target) + policy = DefaultPolicy(func=func, arch=arch) + try: + tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) + except: + tags = None + if tags: + policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) + + configs = policy.emit_config(20) + + cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) + print( + "[BitBLAS] The best latency of top 1 is {:.3f} ms".format( + cpresults[0].latency * 1e3 + ) + ) + print("[BitBLAS] The best latency of top 1 is {:.3f} ms".format(best.latency * 1e3)) + + +def test_i8_i4_gemm(): + ir_module = matmul_nt_dequantize_b(16384, 16384, 16384, "int8", "int32") + func = ir_module["main"] + target = tvm.target.Target("nvidia/nvidia-a100") + arch = CUDA(target) + policy = DefaultPolicy(func=func, arch=arch) + try: + tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) + except: + tags = None + if tags: + policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) + + configs = policy.emit_config(20) + + cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) + print( + "[BitBLAS] The best latency of top 1 is {:.3f} ms".format( + cpresults[0].latency * 1e3 + ) + ) + print("[BitBLAS] The best latency of top 1 is {:.3f} ms".format(best.latency * 1e3)) + + +def test_i8_i4_propagate_b_gemm(): + ir_module = matmul_nt_dequantize_b_propagate_b(16384, 16384, 16384, "int8", "int32") + func = ir_module["main"] + target = tvm.target.Target("nvidia/nvidia-a100") + arch = CUDA(target) + policy = DefaultPolicy(func=func, arch=arch) + try: + tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) + except: + tags = None + if tags: + policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) + + configs = policy.emit_config(20) + + cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) + print( + "[BitBLAS] The best latency of top 1 is {:.3f} ms".format( + cpresults[0].latency * 1e3 + ) + ) + print("[BitBLAS] The best latency of top 1 is {:.3f} ms".format(best.latency * 1e3)) + # print(best.sch.mod) + print(best.code) + + +def test_i8_i4_propagate_a_propagate_b_gemm(): + ir_module = matmul_nt_dequantize_b_propagate_a_b( + 16384, 16384, 16384, "int8", "int32" + ) + func = ir_module["main"] + target = tvm.target.Target("nvidia/nvidia-a100") + arch = CUDA(target) + policy = DefaultPolicy(func=func, arch=arch) + try: + tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) + except: + tags = None + if tags: + policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) + + configs = policy.emit_config(20) + + cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) + print( + "[BitBLAS] The best latency of top 1 is {:.3f} ms".format( + cpresults[0].latency * 1e3 + ) + ) + print("[BitBLAS] The best latency of top 1 is {:.3f} ms".format(best.latency * 1e3)) + print(best.config) + + +def test_i8_i2_gemm(): + ir_module = matmul_nt_dequantize_b(1, 16384, 16384, "int8", "int32", bit=2) + func = ir_module["main"] + target = tvm.target.Target("nvidia/nvidia-a100") + arch = CUDA(target) + policy = DefaultPolicy(func=func, arch=arch) + try: + tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) + except: + tags = None + if tags: + policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) + + configs = policy.emit_config(20) + print(configs) + cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) + print( + "[BitBLAS] The best latency of top 1 is {:.3f} ms".format( + cpresults[0].latency * 1e3 + ) + ) + print("[BitBLAS] The best latency of top 1 is {:.3f} ms".format(best.latency * 1e3)) + print(best.code) + + +def test_i8_i2_propagate_b_gemm(): + ir_module = matmul_nt_dequantize_b_propagate_b( + 16384, + 16384, + 16384, + "int8", + "int8", + accum_dtype="int32", + bit=2, + fast_decoding=True, + ) + func = ir_module["main"] + target = tvm.target.Target("nvidia/nvidia-a100") + arch = CUDA(target) + policy = DefaultPolicy(func=func, arch=arch) + try: + tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) + except: + tags = None + if tags: + policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) + + configs = policy.emit_config(20) + cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) + print( + "[BitBLAS] The best latency of top 1 is {:.3f} ms".format( + cpresults[0].latency * 1e3 + ) + ) + print("[BitBLAS] The best latency of top 1 is {:.3f} ms".format(best.latency * 1e3)) + with open("debug/after_memory_rewrite.cu", "+w") as f: + f.write(best.code) + + +def test_i8_i2_propagate_a_propagate_b_gemm(): + ir_module = matmul_nt_dequantize_b_propagate_a_b( + 16384, 16384, 16384, "int8", "int32", "int8", bit=2, fast_decoding=False + ) + func = ir_module["main"] + target = tvm.target.Target("nvidia/nvidia-a100") + arch = CUDA(target) + policy = DefaultPolicy(func=func, arch=arch) + try: + tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) + except: + tags = None + if tags: + policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) + + configs = policy.emit_config(20) + cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) + print( + "[BitBLAS] The best latency of top 1 is {:.3f} ms".format( + cpresults[0].latency * 1e3 + ) + ) + print("[BitBLAS] The best latency of top 1 is {:.3f} ms".format(best.latency * 1e3)) + with open("debug/after_memory_rewrite.cu", "+w") as f: + f.write(best.code) + + +# test_i8_i8_gemm() +# test_i8_i8_gemm_correctness() +# test_i8_i8_i32_gemm_propagate_b() +# test_i8_i8_i32_cast_i8_gemm_propagate_b() +# test_i8_i8_i32_gemm_propagate_a_propagate_b() +# test_i8_i8_i32_gemm_propagate_a_propagate_b_cast_s8() +# test_i8_i4_gemm() +# test_i8_i4_propagate_b_gemm() +# test_i8_i4_propagate_a_propagate_b_gemm() + +test_i8_i2_gemm() +# test_i8_i2_propagate_b_gemm() +# test_i8_i2_propagate_a_propagate_b_gemm() diff --git a/testing/python/tir_expr/test_tir.py b/testing/python/tir_expr/test_tir.py new file mode 100644 index 000000000000..819444d0f806 --- /dev/null +++ b/testing/python/tir_expr/test_tir.py @@ -0,0 +1,74 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# Metadata omitted. Use show_meta=True in script() method to show it. +from tvm.script import ir as I +from tvm.script import tir as T + +@I.ir_module +class Module: + @T.prim_func + def main(A: T.Buffer((1, 16384), "float16"), B: T.Buffer((16384, 8192), "int8"), Scale: T.Buffer((16384, 512), "float16"), D: T.Buffer((1, 16384), "float16")): + T.func_attr({"dequantize_info": {"B": {"decode_block": "B_decode", "fast_decoding": T.bool(True), "group_size": 32, "source_format": {"bits": 4, "format": "uint"}, "target_format": "float16", "with_scaling": T.bool(True)}}, "tir.noalias": T.bool(True)}) + # with T.block("root"): + B_decode_local = T.alloc_buffer((16384, 16384), "float16", scope="local") + A_local = T.alloc_buffer((1, 16384), "float16", scope="local") + B_local = T.alloc_buffer((16384, 8192), "int8", scope="local") + C_local = T.alloc_buffer((1, 16384), "float16", scope="local") + for ax0_0 in T.thread_binding(8192, thread="blockIdx.x"): + for ax0_1 in T.thread_binding(2, thread="threadIdx.y"): + for ax1_0 in range(32): + for ax1_1 in T.thread_binding(64, thread="threadIdx.x"): + for ax0 in range(1): + for ax1 in T.vectorized(4): + with T.block("B_local"): + v0 = T.axis.spatial(16384, ax0_0 * 2 + ax0_1 + ax0) + v1 = T.axis.spatial(8192, ax1_0 * 256 + ax1_1 * 4 + ax1) + T.reads(B[v0, v1]) + T.writes(B_local[v0, v1]) + B_local[v0, v1] = B[v0, v1] + for ax0 in range(1): + with T.block("B_decode_local_o"): + v0_o = T.axis.spatial(16384, ax0_0 * 2 + ax0_1 + ax0) + v1_o = T.axis.spatial(2048, ax1_0 * 64 + ax1_1) + T.reads(B_local[v0_o, v1_o * 4:v1_o * 4 + 4], Scale[v0_o, v1_o // 4]) + T.writes(B_decode_local[v0_o, v1_o * 8:v1_o * 8 + 8]) + Compressed = T.match_buffer(B_local[v0_o, v1_o * 4:v1_o * 4 + 4], (4,), "int8", scope="local") + Decompressed = T.match_buffer(B_decode_local[v0_o, v1_o * 8:v1_o * 8 + 8], (8,), "float16", scope="local") + # Scale_1 = T.match_buffer(Scale[v0_o, v1_o // 4: v1_o // 4 + 1], (1,), "float16") + Scale_1 = T.match_buffer(Scale[v0_o, v1_o // 4], (1,), "float16", elem_offset=Scale.elem_offset) + T.call_extern("handle", "decode_i4s_to_f16_scale", Compressed.data, Decompressed.data, Scale_1.access_ptr("r"), 8) + for ax0 in range(1): + for ax1 in T.vectorized(8): + with T.block("A_local"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(16384, ax1_0 * 512 + ax1_1 * 8 + ax1) + T.reads(A[v0, v1]) + T.writes(A_local[v0, v1]) + A_local[v0, v1] = A[v0, v1] + for ax1_2 in range(8): + with T.block("C"): + v0 = T.axis.spatial(16384, ax0_0 * 2 + ax0_1) + v1 = T.axis.reduce(16384, ax1_0 * 512 + ax1_1 * 8 + ax1_2) + T.reads(A_local[0, v1], B_decode_local[v0, v1]) + T.writes(C_local[0, v0]) + with T.init(): + C_local[0, v0] = T.float16(0) + C_local[0, v0] = C_local[0, v0] + A_local[0, v1] * B_decode_local[v0, v1] + for ax0, ax1 in T.grid(1, 1): + with T.block("C_local"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(16384, ax0_0 * 2 + ax0_1 + ax1) + T.reads(C_local[v0, v1]) + T.writes(D[0, v1]) + D[0, v1] = C_local[v0, v1] + + +import tvm +mod = Module +sch = tvm.tir.Schedule(mod, debug_mask="all") +with tvm.transform.PassContext( + config={"tir.use_async_copy": True} + ): + dense_relu_0_rt_mod = tvm.build(sch.mod, target="cuda") +with open("debug/after_memory_rewrite.cu", "+w") as f: + f.write(dense_relu_0_rt_mod.imported_modules[0].get_source()) diff --git a/testing/python/tir_expr/test_tir_0.py b/testing/python/tir_expr/test_tir_0.py new file mode 100644 index 000000000000..bd33ce8d98f5 --- /dev/null +++ b/testing/python/tir_expr/test_tir_0.py @@ -0,0 +1,189 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import tvm +from tvm.script import ir as I +from tvm.script import tir as T +from tvm.tir.tensor_intrin.cuda import get_mma_intrin_group + +@I.ir_module +class Module: + @T.prim_func + def main(A: T.Buffer((1024, 512, 16, 32), "int8"), B: T.Buffer((1024, 512, 16, 8), "int8"), C: T.Buffer((16384, 16384), "int32")): + T.func_attr({"dequantize_info": {"B": {"decode_block": "B_decode", "fast_decoding": T.bool(False), "source_format": {"bits": 2, "format": "int"}, "target_format": "int8"}}, "dlight.tensorcore_prenormlized": T.bool(True), "smooth_a": T.bool(True), "smooth_b": T.bool(True), "tir.noalias": T.bool(True)}) + # with T.block("root"): + A_reindex_reindex_shared = T.alloc_buffer((1, 1024, 512, 16, 32), "int8", scope="shared") + B_reindex_reindex_shared = T.alloc_buffer((1, 1024, 512, 16, 32), "int8", scope="shared") + B_reindex_reindex_local = T.alloc_buffer((1, 1024, 512, 16, 32), "int8", scope="local") + B_local = T.alloc_buffer((1024, 512, 16, 8), "int8", scope="local") + B_shared = T.alloc_buffer((1024, 512, 16, 8), "int8", scope="shared") + A_reindex_reindex_shared_warp = T.alloc_buffer((1, 1024, 512, 32, 16), "int8", scope="warp") + B_reindex_reindex_shared_warp = T.alloc_buffer((1, 1024, 512, 32, 16), "int8", scope="warp") + C_reindex_shared = T.alloc_buffer((1, 1024, 1024, 16, 16), "int32", scope="shared") + C_reindex_shared_warp = T.alloc_buffer((1, 1024, 1024, 32, 8), "int32", scope="warp") + for ax0 in range(1): + for ax1_0_0_ax2_0_0_fused in T.thread_binding(64, thread="blockIdx.y"): + for ax1_0_1_ax2_0_1_fused in T.thread_binding(256, thread="blockIdx.x"): + for ax1_0_2 in T.thread_binding(2, thread="threadIdx.y"): + for ax2_0_2 in T.thread_binding(2, thread="threadIdx.z"): + for ax1_0_3_init, ax2_0_3_init in T.grid(8, 2): + with T.block("C_o_init"): + v0_o = T.axis.spatial(1, ax0) + v1_o = T.axis.spatial(1024, ax1_0_0_ax2_0_0_fused * 16 + ax1_0_2 * 8 + ax1_0_3_init) + v2_o = T.axis.spatial(1024, ax1_0_1_ax2_0_1_fused * 4 + ax2_0_2 * 2 + ax2_0_3_init) + T.reads() + T.writes(C_reindex_shared_warp[0, v1_o, v2_o, 0:32, 0:8]) + with T.block("C_init_o"): + v1_i_init_o = T.axis.spatial(1, 0) + v2_i_init_o = T.axis.spatial(1, 0) + T.reads() + T.writes(C_reindex_shared_warp[0, v1_o, v2_o, 0:32, 0:8]) + C_warp = T.match_buffer(C_reindex_shared_warp[0, v1_o, v2_o, 0:32, 0:8], (32, 8), "int32", scope="warp", offset_factor=1) + for tx in T.thread_binding(32, thread="threadIdx.x"): + T.mma_fill("int32", 8, C_warp.data, C_warp.elem_offset) + for ax3_0_0 in T.serial(256, annotations={"software_pipeline_async_stages": [0], "software_pipeline_order": [0, 1, 2, 3], "software_pipeline_stage": [0, 0, 1, 1]}): + for ax0_ax1_ax2_ax3_ax4_fused_0 in T.thread_binding(2, thread="threadIdx.y"): + for ax0_ax1_ax2_ax3_ax4_fused_1 in T.thread_binding(2, thread="threadIdx.z"): + for ax0_ax1_ax2_ax3_ax4_fused_2 in T.unroll(8, annotations={"pragma_unroll_explicit": 0}): + for ax0_ax1_ax2_ax3_ax4_fused_3 in T.thread_binding(32, thread="threadIdx.x"): + for ax0_ax1_ax2_ax3_ax4_fused_4 in T.vectorized(16): + with T.block("A_reindex_reindex_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(1024, ax1_0_0_ax2_0_0_fused * 16 + (ax0_ax1_ax2_ax3_ax4_fused_0 * 8192 + ax0_ax1_ax2_ax3_ax4_fused_1 * 4096 + ax0_ax1_ax2_ax3_ax4_fused_2 * 512 + ax0_ax1_ax2_ax3_ax4_fused_3 * 16 + ax0_ax1_ax2_ax3_ax4_fused_4) // 1024) + v2 = T.axis.spatial(512, ax3_0_0 * 2 + (ax0_ax1_ax2_ax3_ax4_fused_0 * 8192 + ax0_ax1_ax2_ax3_ax4_fused_1 * 4096 + ax0_ax1_ax2_ax3_ax4_fused_2 * 512 + ax0_ax1_ax2_ax3_ax4_fused_3 * 16 + ax0_ax1_ax2_ax3_ax4_fused_4) % 1024 // 512) + v3 = T.axis.spatial(16, (ax0_ax1_ax2_ax3_ax4_fused_0 * 8192 + ax0_ax1_ax2_ax3_ax4_fused_1 * 4096 + ax0_ax1_ax2_ax3_ax4_fused_2 * 512 + ax0_ax1_ax2_ax3_ax4_fused_3 * 16 + ax0_ax1_ax2_ax3_ax4_fused_4) % 512 // 32) + v4 = T.axis.spatial(32, (ax0_ax1_ax2_ax3_ax4_fused_0 * 8192 + ax0_ax1_ax2_ax3_ax4_fused_1 * 4096 + ax0_ax1_ax2_ax3_ax4_fused_2 * 512 + ax0_ax1_ax2_ax3_ax4_fused_3 * 16 + ax0_ax1_ax2_ax3_ax4_fused_4) % 32) + T.reads(A[v1, v2, v3, v4]) + T.writes(A_reindex_reindex_shared[v0, v1, v2, v3, v4]) + T.block_attr({"permuted_layout": 0}) + A_reindex_reindex_shared[v0, v1, v2, v3, v4] = A[v1, v2, v3, v4] + for ax0_ax1_ax2_ax3_fused_0 in T.unroll(1, annotations={"pragma_unroll_explicit": 0}): + for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(2, thread="threadIdx.z"): + for ax0_ax1_ax2_ax3_fused_2 in T.thread_binding(2, thread="threadIdx.y"): + for ax0_ax1_ax2_ax3_fused_3 in T.thread_binding(32, thread="threadIdx.x"): + for ax0_ax1_ax2_ax3_fused_4 in T.vectorized(16): + with T.block("B_shared"): + v0 = T.axis.spatial(1024, ax1_0_1_ax2_0_1_fused * 4 + (ax0_ax1_ax2_ax3_fused_0 * 2048 + ax0_ax1_ax2_ax3_fused_1 * 1024 + ax0_ax1_ax2_ax3_fused_2 * 512 + ax0_ax1_ax2_ax3_fused_3 * 16 + ax0_ax1_ax2_ax3_fused_4) // 256) + v1 = T.axis.spatial(512, ax3_0_0 * 2 + (ax0_ax1_ax2_ax3_fused_0 * 2048 + ax0_ax1_ax2_ax3_fused_1 * 1024 + ax0_ax1_ax2_ax3_fused_2 * 512 + ax0_ax1_ax2_ax3_fused_3 * 16 + ax0_ax1_ax2_ax3_fused_4) % 256 // 128) + v2 = T.axis.spatial(16, (ax0_ax1_ax2_ax3_fused_0 * 2048 + ax0_ax1_ax2_ax3_fused_1 * 1024 + ax0_ax1_ax2_ax3_fused_2 * 512 + ax0_ax1_ax2_ax3_fused_3 * 16 + ax0_ax1_ax2_ax3_fused_4) % 128 // 8) + v3 = T.axis.spatial(8, (ax0_ax1_ax2_ax3_fused_0 * 2048 + ax0_ax1_ax2_ax3_fused_1 * 1024 + ax0_ax1_ax2_ax3_fused_2 * 512 + ax0_ax1_ax2_ax3_fused_3 * 16 + ax0_ax1_ax2_ax3_fused_4) % 8) + T.where((((ax0_ax1_ax2_ax3_fused_0 * 2 + ax0_ax1_ax2_ax3_fused_1) * 2 + ax0_ax1_ax2_ax3_fused_2) * 32 + ax0_ax1_ax2_ax3_fused_3) * 16 + ax0_ax1_ax2_ax3_fused_4 < 1024) + T.reads(B[v0, v1, v2, v3]) + T.writes(B_shared[v0, v1, v2, v3]) + B_shared[v0, v1, v2, v3] = B[v0, v1, v2, v3] + for ax0_1, ax1_ax2_ax3_ax4_0_fused_0 in T.grid(1, 2): + for ax1_ax2_ax3_ax4_0_fused_1 in T.thread_binding(2, thread="threadIdx.y"): + for ax1_ax2_ax3_ax4_0_fused_2 in T.thread_binding(2, thread="threadIdx.z"): + for ax1_ax2_ax3_ax4_0_fused_3 in T.thread_binding(32, thread="threadIdx.x"): + for ax4_1 in range(1): + for ax0_2, ax1, ax2 in T.grid(1, 1, 1): + for ax3 in T.vectorized(4): + with T.block("B_local"): + v0 = T.axis.spatial(1024, ax1_0_1_ax2_0_1_fused * 4 + (ax1_ax2_ax3_ax4_0_fused_0 * 128 + ax1_ax2_ax3_ax4_0_fused_1 * 64 + ax1_ax2_ax3_ax4_0_fused_2 * 32 + ax1_ax2_ax3_ax4_0_fused_3) // 64 + ax0_2) + v1 = T.axis.spatial(512, ax3_0_0 * 2 + (ax1_ax2_ax3_ax4_0_fused_0 * 128 + ax1_ax2_ax3_ax4_0_fused_1 * 64 + ax1_ax2_ax3_ax4_0_fused_2 * 32 + ax1_ax2_ax3_ax4_0_fused_3) % 64 // 32 + ax1) + v2 = T.axis.spatial(16, (ax1_ax2_ax3_ax4_0_fused_0 * 128 + ax1_ax2_ax3_ax4_0_fused_1 * 64 + ax1_ax2_ax3_ax4_0_fused_2 * 32 + ax1_ax2_ax3_ax4_0_fused_3) % 32 // 2 + ax2) + v3 = T.axis.spatial(8, (ax1_ax2_ax3_ax4_0_fused_0 * 128 + ax1_ax2_ax3_ax4_0_fused_1 * 64 + ax1_ax2_ax3_ax4_0_fused_2 * 32 + ax1_ax2_ax3_ax4_0_fused_3) % 2 * 4 + ax3) + T.reads(B_shared[v0, v1, v2, v3]) + T.writes(B_local[v0, v1, v2, v3]) + B_local[v0, v1, v2, v3] = B_shared[v0, v1, v2, v3] + for ax0_2, ax1, ax2, ax3, ax4 in T.grid(1, 1, 1, 1, 16): + with T.block("B_reindex_reindex_local"): + v0 = T.axis.spatial(1, ax0_2) + v1 = T.axis.spatial(1024, ax1_0_1_ax2_0_1_fused * 4 + (ax1_ax2_ax3_ax4_0_fused_0 * 128 + ax1_ax2_ax3_ax4_0_fused_1 * 64 + ax1_ax2_ax3_ax4_0_fused_2 * 32 + ax1_ax2_ax3_ax4_0_fused_3) // 64 + ax1) + v2 = T.axis.spatial(512, ax3_0_0 * 2 + (ax1_ax2_ax3_ax4_0_fused_0 * 128 + ax1_ax2_ax3_ax4_0_fused_1 * 64 + ax1_ax2_ax3_ax4_0_fused_2 * 32 + ax1_ax2_ax3_ax4_0_fused_3) % 64 // 32 + ax2) + v3 = T.axis.spatial(16, (ax1_ax2_ax3_ax4_0_fused_0 * 128 + ax1_ax2_ax3_ax4_0_fused_1 * 64 + ax1_ax2_ax3_ax4_0_fused_2 * 32 + ax1_ax2_ax3_ax4_0_fused_3) % 32 // 2 + ax3) + v4 = T.axis.spatial(32, (ax1_ax2_ax3_ax4_0_fused_0 * 128 + ax1_ax2_ax3_ax4_0_fused_1 * 64 + ax1_ax2_ax3_ax4_0_fused_2 * 32 + ax1_ax2_ax3_ax4_0_fused_3) % 2 * 16 + ax4) + T.reads(B_local[v1, v2, v3, v4 // 4]) + T.writes(B_reindex_reindex_local[v0, v1, v2, v3, v4]) + B_reindex_reindex_local[v0, v1, v2, v3, v4] = T.bitwise_and(T.shift_right(B_local[v1, v2, v3, v4 // 4], T.Cast("int8", v4 % 4 * 2)), T.int8(3)) + for ax4_2 in T.vectorized(16): + with T.block("B_reindex_reindex_shared"): + v0 = T.axis.spatial(1, ax0_1) + v1 = T.axis.spatial(1024, ax1_0_1_ax2_0_1_fused * 4 + (ax1_ax2_ax3_ax4_0_fused_0 * 128 + ax1_ax2_ax3_ax4_0_fused_1 * 64 + ax1_ax2_ax3_ax4_0_fused_2 * 32 + ax1_ax2_ax3_ax4_0_fused_3) // 64) + v2 = T.axis.spatial(512, ax3_0_0 * 2 + (ax1_ax2_ax3_ax4_0_fused_0 * 128 + ax1_ax2_ax3_ax4_0_fused_1 * 64 + ax1_ax2_ax3_ax4_0_fused_2 * 32 + ax1_ax2_ax3_ax4_0_fused_3) % 64 // 32) + v3 = T.axis.spatial(16, (ax1_ax2_ax3_ax4_0_fused_0 * 128 + ax1_ax2_ax3_ax4_0_fused_1 * 64 + ax1_ax2_ax3_ax4_0_fused_2 * 32 + ax1_ax2_ax3_ax4_0_fused_3) % 32 // 2) + v4 = T.axis.spatial(32, (ax1_ax2_ax3_ax4_0_fused_0 * 128 + ax1_ax2_ax3_ax4_0_fused_1 * 64 + ax1_ax2_ax3_ax4_0_fused_2 * 32 + ax1_ax2_ax3_ax4_0_fused_3) % 2 * 16 + ax4_1 * 16 + ax4_2) + T.reads(B_reindex_reindex_local[v0, v1, v2, v3, v4]) + T.writes(B_reindex_reindex_shared[v0, v1, v2, v3, v4]) + T.block_attr({"permuted_layout": 0}) + B_reindex_reindex_shared[v0, v1, v2, v3, v4] = B_reindex_reindex_local[v0, v1, v2, v3, v4] + for ax3_0_1 in range(2): + for ax0_1, ax1, ax2, ax3_0, ax4_0 in T.grid(1, 8, 1, 1, 1): + with T.block("A_reindex_reindex_shared_warp_o"): + v0_o = T.axis.spatial(1, ax0_1) + v1_o = T.axis.spatial(1024, ax1_0_0_ax2_0_0_fused * 16 + ax1_0_2 * 8 + ax1) + v2_o = T.axis.spatial(512, ax3_0_0 * 2 + ax3_0_1 + ax2) + v3_o, v4_o = T.axis.remap("SS", [ax3_0, ax4_0]) + T.reads(A_reindex_reindex_shared[v0_o, v1_o, v2_o, 0:16, 0:32]) + T.writes(A_reindex_reindex_shared_warp[v0_o, v1_o, v2_o, 0:32, 0:16]) + T.block_attr({"permuted_layout": 0}) + warp = T.match_buffer(A_reindex_reindex_shared_warp[v0_o, v1_o, v2_o, 0:32, 0:16], (32, 16), "int8", scope="warp", offset_factor=32) + shared = T.match_buffer(A_reindex_reindex_shared[v0_o, v1_o, v2_o, 0:16, 0:32], (16, 32), "int8", strides=("shared_s0", "shared_s1"), scope="shared", offset_factor=32) + for tx in T.thread_binding(32, thread="threadIdx.x"): + T.ptx_ldmatrix("int8", T.bool(False), 4, ".b16", warp.data, warp.elem_offset + 16 * tx, T.tvm_access_ptr(T.type_annotation("int8"), shared.data, shared.elem_offset, shared.strides[0] * 16, 1), tx * 16) + for ax0_1, ax1, ax2, ax3_0, ax4_0 in T.grid(1, 2, 1, 1, 1): + with T.block("B_reindex_reindex_shared_warp_o"): + v0_o = T.axis.spatial(1, ax0_1) + v1_o = T.axis.spatial(1024, ax1_0_1_ax2_0_1_fused * 4 + ax2_0_2 * 2 + ax1) + v2_o = T.axis.spatial(512, ax3_0_0 * 2 + ax3_0_1 + ax2) + v3_o, v4_o = T.axis.remap("SS", [ax3_0, ax4_0]) + T.reads(B_reindex_reindex_shared[v0_o, v1_o, v2_o, 0:16, 0:32]) + T.writes(B_reindex_reindex_shared_warp[v0_o, v1_o, v2_o, 0:32, 0:16]) + T.block_attr({"permuted_layout": 0}) + warp = T.match_buffer(B_reindex_reindex_shared_warp[v0_o, v1_o, v2_o, 0:32, 0:16], (32, 16), "int8", scope="warp", offset_factor=32) + shared = T.match_buffer(B_reindex_reindex_shared[v0_o, v1_o, v2_o, 0:16, 0:32], (16, 32), "int8", strides=("shared_s0", "shared_s1"), scope="shared", offset_factor=32) + for tx in T.thread_binding(32, thread="threadIdx.x"): + T.ptx_ldmatrix("int8", T.bool(False), 4, ".b16", warp.data, warp.elem_offset + 16 * tx, T.tvm_access_ptr(T.type_annotation("int8"), shared.data, shared.elem_offset, shared.strides[0] * 16, 1), tx * 16) + for ax1_0_3, ax2_0_3 in T.grid(8, 2): + with T.block("C_o_update"): + v0_o = T.axis.spatial(1, ax0) + v1_o = T.axis.spatial(1024, ax1_0_0_ax2_0_0_fused * 16 + ax1_0_2 * 8 + ax1_0_3) + v2_o = T.axis.spatial(1024, ax1_0_1_ax2_0_1_fused * 4 + ax2_0_2 * 2 + ax2_0_3) + v3_o = T.axis.reduce(512, ax3_0_0 * 2 + ax3_0_1) + T.reads(C_reindex_shared_warp[0, v1_o, v2_o, 0:32, 0:8], A_reindex_reindex_shared_warp[0, v1_o, v3_o, 0:32, 0:16], B_reindex_reindex_shared_warp[0, v2_o, v3_o, 0:32, 0:16]) + T.writes(C_reindex_shared_warp[0, v1_o, v2_o, 0:32, 0:8]) + with T.block("C_o"): + v1_i_o = T.axis.spatial(1, 0) + v2_i_o = T.axis.spatial(1, 0) + v3_i_o = T.axis.reduce(1, 0) + T.reads(C_reindex_shared_warp[0, v1_o, v2_o, 0:32, 0:8], A_reindex_reindex_shared_warp[0, v1_o, v3_o, 0:32, 0:16], B_reindex_reindex_shared_warp[0, v2_o, v3_o, 0:32, 0:16]) + T.writes(C_reindex_shared_warp[0, v1_o, v2_o, 0:32, 0:8]) + A_1 = T.match_buffer(A_reindex_reindex_shared_warp[0, v1_o, v3_o, 0:32, 0:16], (32, 16), "int8", scope="warp", offset_factor=32) + B_1 = T.match_buffer(B_reindex_reindex_shared_warp[0, v2_o, v3_o, 0:32, 0:16], (32, 16), "int8", scope="warp", offset_factor=32) + C_1 = T.match_buffer(C_reindex_shared_warp[0, v1_o, v2_o, 0:32, 0:8], (32, 8), "int32", scope="warp", offset_factor=16) + for tx in T.thread_binding(32, thread="threadIdx.x"): + T.ptx_mma("int32", "m16n8k32", "row", "col", "int8", "int8", "int32", A_1.data, A_1.elem_offset + tx * 16, B_1.data, B_1.elem_offset + tx * 16, C_1.data, C_1.elem_offset + tx * 8, T.bool(False)) + T.ptx_mma("int32", "m16n8k32", "row", "col", "int8", "int8", "int32", A_1.data, A_1.elem_offset + tx * 16, B_1.data, B_1.elem_offset + tx * 16 + 8, C_1.data, C_1.elem_offset + tx * 8 + 4, T.bool(False)) + for ax0_1, ax1 in T.grid(8, 2): + for ax2_0, ax3_0 in T.grid(1, 1): + with T.block("C_reindex_shared_warp_o"): + v0_o = T.axis.spatial(1, 0) + v1_o = T.axis.spatial(1024, ax1_0_0_ax2_0_0_fused * 16 + ax1_0_2 * 8 + ax0_1) + v2_o = T.axis.spatial(1024, ax1_0_1_ax2_0_1_fused * 4 + ax2_0_2 * 2 + ax1) + v3_o, v4_o = T.axis.remap("SS", [ax2_0, ax3_0]) + T.reads(C_reindex_shared_warp[v0_o, v1_o, v2_o, 0:32, 0:8]) + T.writes(C_reindex_shared[v0_o, v1_o, v2_o, 0:16, 0:16]) + C_warp = T.match_buffer(C_reindex_shared_warp[v0_o, v1_o, v2_o, 0:32, 0:8], (32, 8), "int32", scope="warp", offset_factor=1) + C_1 = T.match_buffer(C_reindex_shared[v0_o, v1_o, v2_o, 0:16, 0:16], (16, 16), "int32", strides=("C_s0", "C_s1"), scope="shared", offset_factor=1) + for tx in T.thread_binding(32, thread="threadIdx.x"): + T.mma_store("int32", 16, 16, T.tvm_access_ptr(T.type_annotation("int32"), C_1.data, C_1.elem_offset, C_1.strides[0] * 16, 2), C_warp.data, C_warp.elem_offset, C_1.strides[0]) + for ax0_ax1_ax2_ax3_ax4_fused_0 in T.unroll(2, annotations={"pragma_unroll_explicit": 0}): + for ax0_ax1_ax2_ax3_ax4_fused_1 in T.thread_binding(32, thread="threadIdx.x"): + for ax0_ax1_ax2_ax3_ax4_fused_2 in T.vectorized(4): + with T.block("C_reindex_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(1024, ax1_0_0_ax2_0_0_fused * 16 + ax1_0_2 * 8 + ax0_1) + v2 = T.axis.spatial(1024, ax1_0_1_ax2_0_1_fused * 4 + ax2_0_2 * 2 + ax1) + v3 = T.axis.spatial(16, (ax0_ax1_ax2_ax3_ax4_fused_0 * 128 + ax0_ax1_ax2_ax3_ax4_fused_1 * 4 + ax0_ax1_ax2_ax3_ax4_fused_2) // 16) + v4 = T.axis.spatial(16, (ax0_ax1_ax2_ax3_ax4_fused_0 * 128 + ax0_ax1_ax2_ax3_ax4_fused_1 * 4 + ax0_ax1_ax2_ax3_ax4_fused_2) % 16) + T.reads(C_reindex_shared[v0, v1, v2, v3, v4]) + T.writes(C[v3 + v1 * 16, v4 + v2 * 16]) + C[v3 + v1 * 16, v4 + v2 * 16] = C_reindex_shared[v0, v1, v2, v3, v4] + +mod = Module +sch = tvm.tir.Schedule(mod, debug_mask="all") +with tvm.transform.PassContext( + config={"tir.use_async_copy": True} + ): + dense_relu_0_rt_mod = tvm.build(sch.mod, target="cuda") +with open("after_memory_rewrite.cu", "+w") as f: + f.write(dense_relu_0_rt_mod.imported_modules[0].get_source()) diff --git a/testing/python/tir_expr/test_tir_1.py b/testing/python/tir_expr/test_tir_1.py new file mode 100644 index 000000000000..49d1f71ec372 --- /dev/null +++ b/testing/python/tir_expr/test_tir_1.py @@ -0,0 +1,179 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import tvm +from tvm.script import ir as I +from tvm.script import tir as T +from tvm.tir.tensor_intrin.cuda import * + +# from tvm.script import tir as T +@T.prim_func +def main(input0: T.Buffer[(1024, 512, 16, 32), "int8"], input1: T.Buffer[(1024, 512, 16, 8), "int8"], output0: T.Buffer[(16384, 16384), "int8"]): + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # var definition + tx = T.env_thread("threadIdx.x") + C_s0 = T.var("int32") + C_s1 = T.var("int32") + shared_s0 = T.var("int32") + shared_s0_1 = T.var("int32") + shared_s1 = T.var("int32") + shared_s1_1 = T.var("int32") + # body + # with T.block("root") + input0_shared = T.alloc_buffer([1024, 512, 16, 32], dtype="int8", scope="shared") + mediate0_shared = T.alloc_buffer([1024, 512, 16, 32], dtype="int8", scope="shared") + mediate1_shared = T.alloc_buffer([1024, 1024, 16, 16], dtype="int32", scope="shared") + mediate1_shared_warp = T.alloc_buffer([1024, 1024, 32, 8], dtype="int32", scope="warp") + mediate0_local = T.alloc_buffer([1024, 512, 16, 32], dtype="int8", scope="local") + input1_shared = T.alloc_buffer([1024, 512, 16, 8], dtype="int8", scope="shared") + input1_shared_local = T.alloc_buffer([1024, 512, 16, 8], dtype="int8", scope="local") + input0_shared_warp = T.alloc_buffer([1024, 512, 32, 16], dtype="int8", scope="warp") + mediate0_shared_warp = T.alloc_buffer([1024, 512, 32, 16], dtype="int8", scope="warp") + for i_0 in T.thread_binding(256, thread="blockIdx.y"): + for j_0 in T.thread_binding(64, thread="blockIdx.x"): + for i_1 in T.thread_binding(2, thread="threadIdx.y"): + for j_1 in T.thread_binding(2, thread="threadIdx.z"): + for i_2_init in T.serial(2, annotations={"pragma_unroll_explicit":0, "thread_rasterization":10}): + for j_2_init in T.serial(8, annotations={"pragma_unroll_explicit":0}): + with T.block("mediate1_init_o"): + v_i = T.axis.spatial(1024, i_0 * 4 + i_1 * 2 + i_2_init) + v_j = T.axis.spatial(1024, j_0 * 16 + j_1 * 8 + j_2_init) + v_ii_o = T.axis.spatial(1, 0) + v_jj_o = T.axis.spatial(1, 0) + T.reads() + T.writes(mediate1_shared_warp[v_i, v_j, 0 : 32, 0 : 8]) + C_warp = T.match_buffer(mediate1_shared_warp[v_i, v_j, 0 : 32, 0 : 8], [32, 8], dtype="int32", scope="warp", offset_factor=1) + T.launch_thread(tx, 32) + T.mma_fill(8, C_warp.data, C_warp.elem_offset, dtype="int32") + for k_0 in T.serial(256, annotations={"software_pipeline_async_stages":[0], "software_pipeline_order":[0, 1, 2, 3], "software_pipeline_stage":[0, 0, 1, 1]}): + for ax0_ax1_ax2_ax3_0_fused_0 in T.unroll(2, annotations={"pragma_unroll_explicit":0}): + for ax0_ax1_ax2_ax3_0_fused_1 in T.thread_binding(2, thread="threadIdx.y"): + for ax0_ax1_ax2_ax3_0_fused_2 in T.thread_binding(2, thread="threadIdx.z"): + for ax0_ax1_ax2_ax3_0_fused_3 in T.thread_binding(32, thread="threadIdx.x"): + for ax3_1 in T.vectorized(16): + with T.block("input0_shared"): + v0 = T.axis.spatial(1024, i_0 * 4 + (ax0_ax1_ax2_ax3_0_fused_0 * 128 + ax0_ax1_ax2_ax3_0_fused_1 * 64 + ax0_ax1_ax2_ax3_0_fused_2 * 32 + ax0_ax1_ax2_ax3_0_fused_3) // 64) + v1 = T.axis.spatial(512, k_0 * 2 + (ax0_ax1_ax2_ax3_0_fused_0 * 128 + ax0_ax1_ax2_ax3_0_fused_1 * 64 + ax0_ax1_ax2_ax3_0_fused_2 * 32 + ax0_ax1_ax2_ax3_0_fused_3) % 64 // 32) + v2 = T.axis.spatial(16, (ax0_ax1_ax2_ax3_0_fused_0 * 128 + ax0_ax1_ax2_ax3_0_fused_1 * 64 + ax0_ax1_ax2_ax3_0_fused_2 * 32 + ax0_ax1_ax2_ax3_0_fused_3) % 32 // 2) + v3 = T.axis.spatial(32, (ax0_ax1_ax2_ax3_0_fused_0 * 128 + ax0_ax1_ax2_ax3_0_fused_1 * 64 + ax0_ax1_ax2_ax3_0_fused_2 * 32 + ax0_ax1_ax2_ax3_0_fused_3) % 2 * 16 + ax3_1) + T.reads(input0[v0, v1, v2, v3]) + T.writes(input0_shared[v0, v1, v2, v3]) + input0_shared[v0, v1, v2, v3] = input0[v0, v1, v2, v3] + for ax0_ax1_ax2_ax3_fused_0_0_0_0 in T.serial(2): + for ax0_ax1_ax2_ax3_fused_0_0_0_1 in T.thread_binding(2, thread="threadIdx.z"): + for ax0_ax1_ax2_ax3_fused_0_0_1 in T.thread_binding(2, thread="threadIdx.y"): + for ax0_ax1_ax2_ax3_fused_0_1 in T.thread_binding(32, thread="threadIdx.x"): + for ax0_ax1_ax2_ax3_fused_1 in T.vectorized(16): + with T.block("input1_shared"): + v0 = T.axis.spatial(1024, j_0 * 16 + (ax0_ax1_ax2_ax3_fused_0_0_0_0 * 2048 + ax0_ax1_ax2_ax3_fused_0_0_0_1 * 1024 + ax0_ax1_ax2_ax3_fused_0_0_1 * 512 + ax0_ax1_ax2_ax3_fused_0_1 * 16 + ax0_ax1_ax2_ax3_fused_1) // 256) + v1 = T.axis.spatial(512, k_0 * 2 + (ax0_ax1_ax2_ax3_fused_0_0_0_0 * 2048 + ax0_ax1_ax2_ax3_fused_0_0_0_1 * 1024 + ax0_ax1_ax2_ax3_fused_0_0_1 * 512 + ax0_ax1_ax2_ax3_fused_0_1 * 16 + ax0_ax1_ax2_ax3_fused_1) % 256 // 128) + v2 = T.axis.spatial(16, (ax0_ax1_ax2_ax3_fused_0_0_0_0 * 2048 + ax0_ax1_ax2_ax3_fused_0_0_0_1 * 1024 + ax0_ax1_ax2_ax3_fused_0_0_1 * 512 + ax0_ax1_ax2_ax3_fused_0_1 * 16 + ax0_ax1_ax2_ax3_fused_1) % 128 // 8) + v3 = T.axis.spatial(8, (ax0_ax1_ax2_ax3_fused_0_0_0_0 * 2048 + ax0_ax1_ax2_ax3_fused_0_0_0_1 * 1024 + ax0_ax1_ax2_ax3_fused_0_0_1 * 512 + ax0_ax1_ax2_ax3_fused_0_1 * 16 + ax0_ax1_ax2_ax3_fused_1) % 8) + T.reads(input1[v0, v1, v2, v3]) + T.writes(input1_shared[v0, v1, v2, v3]) + input1_shared[v0, v1, v2, v3] = input1[v0, v1, v2, v3] + for ax0_ax1_ax2_ax3_0_fused_0 in T.serial(8): + for ax0_ax1_ax2_ax3_0_fused_1 in T.thread_binding(2, thread="threadIdx.y"): + for ax0_ax1_ax2_ax3_0_fused_2 in T.thread_binding(2, thread="threadIdx.z"): + for ax0_ax1_ax2_ax3_0_fused_3 in T.thread_binding(32, thread="threadIdx.x"): + for ax3_1 in T.serial(1): + for ax0 in T.vectorized(4): + with T.block("input1_shared_local"): + v0 = T.axis.spatial(1024, j_0 * 16 + (ax0_ax1_ax2_ax3_0_fused_0 * 128 + ax0_ax1_ax2_ax3_0_fused_1 * 64 + ax0_ax1_ax2_ax3_0_fused_2 * 32 + ax0_ax1_ax2_ax3_0_fused_3) // 64) + v1 = T.axis.spatial(512, k_0 * 2 + (ax0_ax1_ax2_ax3_0_fused_0 * 128 + ax0_ax1_ax2_ax3_0_fused_1 * 64 + ax0_ax1_ax2_ax3_0_fused_2 * 32 + ax0_ax1_ax2_ax3_0_fused_3) % 64 // 32) + v2 = T.axis.spatial(16, (ax0_ax1_ax2_ax3_0_fused_0 * 128 + ax0_ax1_ax2_ax3_0_fused_1 * 64 + ax0_ax1_ax2_ax3_0_fused_2 * 32 + ax0_ax1_ax2_ax3_0_fused_3) % 32 // 2) + v3 = T.axis.spatial(8, (ax0_ax1_ax2_ax3_0_fused_0 * 128 + ax0_ax1_ax2_ax3_0_fused_1 * 64 + ax0_ax1_ax2_ax3_0_fused_2 * 32 + ax0_ax1_ax2_ax3_0_fused_3) % 2 * 4 + ax0) + T.reads(input1_shared[v0, v1, v2, v3]) + T.writes(input1_shared_local[v0, v1, v2, v3]) + input1_shared_local[v0, v1, v2, v3] = input1_shared[v0, v1, v2, v3] + for ax0 in T.serial(16): + with T.block("mediate0_local"): + v0 = T.axis.spatial(1024, j_0 * 16 + (ax0_ax1_ax2_ax3_0_fused_0 * 128 + ax0_ax1_ax2_ax3_0_fused_1 * 64 + ax0_ax1_ax2_ax3_0_fused_2 * 32 + ax0_ax1_ax2_ax3_0_fused_3) // 64) + v1 = T.axis.spatial(512, k_0 * 2 + (ax0_ax1_ax2_ax3_0_fused_0 * 128 + ax0_ax1_ax2_ax3_0_fused_1 * 64 + ax0_ax1_ax2_ax3_0_fused_2 * 32 + ax0_ax1_ax2_ax3_0_fused_3) % 64 // 32) + v2 = T.axis.spatial(16, (ax0_ax1_ax2_ax3_0_fused_0 * 128 + ax0_ax1_ax2_ax3_0_fused_1 * 64 + ax0_ax1_ax2_ax3_0_fused_2 * 32 + ax0_ax1_ax2_ax3_0_fused_3) % 32 // 2) + v3 = T.axis.spatial(32, (ax0_ax1_ax2_ax3_0_fused_0 * 128 + ax0_ax1_ax2_ax3_0_fused_1 * 64 + ax0_ax1_ax2_ax3_0_fused_2 * 32 + ax0_ax1_ax2_ax3_0_fused_3) % 2 * 16 + ax0) + T.reads(input1_shared_local[v0, v1, v2, v3 // 4]) + T.writes(mediate0_local[v0, v1, v2, v3]) + mediate0_local[v0, v1, v2, v3] = T.bitwise_and(T.shift_right(input1_shared_local[v0, v1, v2, v3 // 4], T.Cast("int8", v3 % 4), dtype="int8"), T.int8(1), dtype="int8") + for ax3_2 in T.vectorized(16): + with T.block("mediate0_shared"): + v0 = T.axis.spatial(1024, j_0 * 16 + (ax0_ax1_ax2_ax3_0_fused_0 * 128 + ax0_ax1_ax2_ax3_0_fused_1 * 64 + ax0_ax1_ax2_ax3_0_fused_2 * 32 + ax0_ax1_ax2_ax3_0_fused_3) // 64) + v1 = T.axis.spatial(512, k_0 * 2 + (ax0_ax1_ax2_ax3_0_fused_0 * 128 + ax0_ax1_ax2_ax3_0_fused_1 * 64 + ax0_ax1_ax2_ax3_0_fused_2 * 32 + ax0_ax1_ax2_ax3_0_fused_3) % 64 // 32) + v2 = T.axis.spatial(16, (ax0_ax1_ax2_ax3_0_fused_0 * 128 + ax0_ax1_ax2_ax3_0_fused_1 * 64 + ax0_ax1_ax2_ax3_0_fused_2 * 32 + ax0_ax1_ax2_ax3_0_fused_3) % 32 // 2) + v3 = T.axis.spatial(32, (ax0_ax1_ax2_ax3_0_fused_0 * 128 + ax0_ax1_ax2_ax3_0_fused_1 * 64 + ax0_ax1_ax2_ax3_0_fused_2 * 32 + ax0_ax1_ax2_ax3_0_fused_3) % 2 * 16 + ax3_1 * 16 + ax3_2) + T.reads(mediate0_local[v0, v1, v2, v3]) + T.writes(mediate0_shared[v0, v1, v2, v3]) + mediate0_shared[v0, v1, v2, v3] = mediate0_local[v0, v1, v2, v3] + for k_1 in T.serial(2): + for ax0, ax1 in T.grid(2, 1): + with T.block("input0_shared_warp_o"): + v0 = T.axis.spatial(1024, i_0 * 4 + i_1 * 2 + ax0) + v1 = T.axis.spatial(512, ax1 * 512 + k_0 * 2 + k_1) + v2_o = T.axis.spatial(1, 0) + v3_o = T.axis.spatial(1, 0) + T.reads(input0_shared[v0, v1, 0 : 16, 0 : 32]) + T.writes(input0_shared_warp[v0, v1, 0 : 32, 0 : 16]) + warp = T.match_buffer(input0_shared_warp[v0, v1, 0 : 32, 0 : 16], [32, 16], dtype="int8", scope="warp", offset_factor=16) + shared = T.match_buffer(input0_shared[v0, v1, 0 : 16, 0 : 32], [16, 32], dtype="int8", strides=[shared_s0, shared_s1], scope="shared", offset_factor=16) + T.launch_thread(tx, 32) + T.ptx_ldmatrix(False, 4, ".b16", warp.data, warp.elem_offset + 16 * tx, T.tvm_access_ptr(T.type_annotation(dtype="int8"), shared.data, shared.elem_offset, shared_s0 * 16, 1, dtype="handle"), 16 * tx, dtype="int8") + for ax0, ax1 in T.grid(8, 1): + with T.block("mediate0_shared_warp_o"): + v0 = T.axis.spatial(1024, j_0 * 16 + j_1 * 8 + ax0) + v1 = T.axis.spatial(512, ax1 * 512 + k_0 * 2 + k_1) + v2_o = T.axis.spatial(1, 0) + v3_o = T.axis.spatial(1, 0) + T.reads(mediate0_shared[v0, v1, 0 : 16, 0 : 32]) + T.writes(mediate0_shared_warp[v0, v1, 0 : 32, 0 : 16]) + warp_1 = T.match_buffer(mediate0_shared_warp[v0, v1, 0 : 32, 0 : 16], [32, 16], dtype="int8", scope="warp", offset_factor=16) + shared_1 = T.match_buffer(mediate0_shared[v0, v1, 0 : 16, 0 : 32], [16, 32], dtype="int8", strides=[shared_s0_1, shared_s1_1], scope="shared", offset_factor=16) + T.launch_thread(tx, 32) + T.ptx_ldmatrix(False, 4, ".b16", warp_1.data, warp_1.elem_offset + 16 * tx, T.tvm_access_ptr(T.type_annotation(dtype="int8"), shared_1.data, shared_1.elem_offset, shared_s0_1 * 16, 1, dtype="handle"), 16 * tx, dtype="int8") + for i_2, j_2 in T.grid(2, 8): + with T.block("mediate1_update_o"): + v_i = T.axis.spatial(1024, i_0 * 4 + i_1 * 2 + i_2) + v_j = T.axis.spatial(1024, j_0 * 16 + j_1 * 8 + j_2) + v_ii_o = T.axis.spatial(1, 0) + v_jj_o = T.axis.spatial(1, 0) + v_k = T.axis.reduce(512, k_0 * 2 + k_1) + v_kk_o = T.axis.reduce(1, 0) + T.reads(mediate1_shared_warp[v_i, v_j, 0 : 32, 0 : 8], input0_shared_warp[v_i, v_k, 0 : 32, 0 : 16], mediate0_shared_warp[v_j, v_k, 0 : 32, 0 : 16]) + T.writes(mediate1_shared_warp[v_i, v_j, 0 : 32, 0 : 8]) + A = T.match_buffer(input0_shared_warp[v_i, v_k, 0 : 32, 0 : 16], [32, 16], dtype="int8", scope="warp", offset_factor=16) + B = T.match_buffer(mediate0_shared_warp[v_j, v_k, 0 : 32, 0 : 16], [32, 16], dtype="int8", scope="warp", offset_factor=16) + C = T.match_buffer(mediate1_shared_warp[v_i, v_j, 0 : 32, 0 : 8], [32, 8], dtype="int32", scope="warp", offset_factor=16) + T.launch_thread(tx, 32) + T.ptx_mma("m16n8k32", "row", "col", "int8", "int8", "int32", A.data, A.elem_offset + tx * 16, B.data, B.elem_offset + tx * 16, C.data, C.elem_offset + tx * 8, False, dtype="int32") + T.ptx_mma("m16n8k32", "row", "col", "int8", "int8", "int32", A.data, A.elem_offset + tx * 16, B.data, B.elem_offset + tx * 16 + T.FloorDiv(16, 2), C.data, C.elem_offset + tx * 8 + T.FloorDiv(8, 2), False, dtype="int32") + for ax0, ax1 in T.grid(2, 8): + with T.block("mediate1_shared_warp_o"): + v0 = T.axis.spatial(1024, i_0 * 4 + i_1 * 2 + ax0) + v1 = T.axis.spatial(1024, j_0 * 16 + j_1 * 8 + ax1) + v2_o = T.axis.spatial(1, 0) + v3_o = T.axis.spatial(1, 0) + T.reads(mediate1_shared_warp[v0, v1, 0 : 32, 0 : 8]) + T.writes(mediate1_shared[v0, v1, 0 : 16, 0 : 16]) + C_warp_1 = T.match_buffer(mediate1_shared_warp[v0, v1, 0 : 32, 0 : 8], [32, 8], dtype="int32", scope="warp", offset_factor=1) + C_1 = T.match_buffer(mediate1_shared[v0, v1, 0 : 16, 0 : 16], [16, 16], dtype="int32", strides=[C_s0, C_s1], scope="shared", offset_factor=1) + T.launch_thread(tx, 32) + T.mma_store(16, 16, T.tvm_access_ptr(T.type_annotation(dtype="int32"), C_1.data, C_1.elem_offset, C_s0 * 16, 2, dtype="handle"), C_warp_1.data, C_warp_1.elem_offset, C_s0, dtype="int32") + for ax0_ax1_ax2_ax3_fused_0 in T.unroll(2, annotations={"pragma_unroll_explicit":0}): + for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(32, thread="threadIdx.x"): + for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(4): + with T.block("mediate1_shared"): + v0 = T.axis.spatial(1024, i_0 * 4 + i_1 * 2 + ax0) + v1 = T.axis.spatial(1024, j_0 * 16 + j_1 * 8 + ax1) + v2 = T.axis.spatial(16, (ax0_ax1_ax2_ax3_fused_0 * 128 + ax0_ax1_ax2_ax3_fused_1 * 4 + ax0_ax1_ax2_ax3_fused_2) // 16) + v3 = T.axis.spatial(16, (ax0_ax1_ax2_ax3_fused_0 * 128 + ax0_ax1_ax2_ax3_fused_1 * 4 + ax0_ax1_ax2_ax3_fused_2) % 16) + T.reads(mediate1_shared[v0, v1, v2, v3]) + T.writes(output0[v0 * 16 + v2, v1 * 16 + v3]) + output0[v2 + v0 * 16, v3 + v1 * 16] = T.Cast("int8", mediate1_shared[v0, v1, v2, v3]) + +mod = main +sch = tvm.tir.Schedule(mod, debug_mask="all") +with tvm.transform.PassContext( + config={"tir.use_async_copy": True} + ): + dense_relu_0_rt_mod = tvm.build(sch.mod, target="cuda") +with open("after_memory_rewrite.cu", "+w") as f: + f.write(dense_relu_0_rt_mod.imported_modules[0].get_source()) diff --git a/testing/python/tir_expr/test_tir_2.py b/testing/python/tir_expr/test_tir_2.py new file mode 100644 index 000000000000..d2a93aeadd7d --- /dev/null +++ b/testing/python/tir_expr/test_tir_2.py @@ -0,0 +1,95 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from tvm.script import ir as I +from tvm.script import tir as T +from tvm.script import relax as R +import bitblas + + +@T.prim_func +def fused_fused_decode3_fused_NT_matmul8_add1( + lv47: T.Buffer((T.int64(256), T.int64(256), T.int64(16), T.int64(2)), "uint32"), + lv48: T.Buffer((T.int64(4096), T.int64(128)), "float16"), + lv41: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), + NT_matmul_intermediate: T.Buffer( + (T.int64(1), T.int64(1), T.int64(4096)), "float16" + ), +): + T.func_attr( + { + "dequantize_info": { + "decode": { + "decode_block": "decode", + "fast_decoding": T.bool(False), + "group_size": 32, + "source_format": {"bits": 4, "format": "int"}, + "storage_dtype": "uint32", + "target_format": "float16", + "with_scaling": T.bool(True), + } + }, + "smooth_b": T.bool(True), + "tir.noalias": T.bool(True), + "transform_kind": 1, + } + ) + # with T.block("root"): + decode_intermediate_intermediate = T.alloc_buffer( + (T.int64(4096), T.int64(4096)), "float16" + ) + lv47_global = T.alloc_buffer((T.int64(4096), T.int64(512)), "uint32") + for ax0, ax1 in T.grid(T.int64(4096), T.int64(512)): + with T.block("lv47_global"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + T.reads( + lv47[ + v0 // T.int64(16), + v1 // T.int64(2), + v0 % T.int64(16), + v1 % T.int64(2), + ] + ) + T.writes(lv47_global[v0, v1]) + lv47_global[v0, v1] = lv47[ + v0 // T.int64(16), + v1 // T.int64(2), + v0 % T.int64(16), + v1 % T.int64(2), + ] + for i, j in T.grid(T.int64(4096), T.int64(4096)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv47_global[v_i, v_j // T.int64(8)], lv48[v_i, v_j // T.int64(32)]) + T.writes(decode_intermediate_intermediate[v_i, v_j]) + decode_intermediate_intermediate[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv47_global[v_i, v_j // T.int64(8)], + T.Cast("uint32", v_j % T.int64(8)) * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv48[v_i, v_j // T.int64(32)] + for i0, i1, i2, k in T.grid(T.int64(1), 1, T.int64(4096), T.int64(4096)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv41[v_i0, v_i1, v_k], decode_intermediate_intermediate[v_i2, v_k]) + T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + NT_matmul_intermediate[v_i0, v_i1, v_i2] = ( + NT_matmul_intermediate[v_i0, v_i1, v_i2] + + lv41[v_i0, v_i1, v_k] * decode_intermediate_intermediate[v_i2, v_k] + ) + + +import tvm + +sch = bitblas.gpu.GEMV().apply( + fused_fused_decode3_fused_NT_matmul8_add1, tvm.target.Target("cuda"), False +) +print(sch) diff --git a/testing/python/tir_expr/test_tir_3.py b/testing/python/tir_expr/test_tir_3.py new file mode 100644 index 000000000000..1a879859cd66 --- /dev/null +++ b/testing/python/tir_expr/test_tir_3.py @@ -0,0 +1,87 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from tvm.script import ir as I +from tvm.script import tir as T +from tvm.script import relax as R +import bitblas + + +@T.prim_func +def fused_fused_decode3_fused_NT_matmul8_add1( + lv47: T.Buffer((T.int64(4096), T.int64(512)), "uint32"), + lv48: T.Buffer((T.int64(4096), T.int64(128)), "float16"), + p_lv41: T.handle, + p_output0: T.handle, +): + T.func_attr( + { + "dequantize_info": { + "decode": { + "decode_block": "decode", + "fast_decoding": T.bool(False), + "group_size": 32, + "source_format": {"bits": 4, "format": "int"}, + "storage_dtype": "uint32", + "target_format": "float16", + "with_scaling": T.bool(True), + } + }, + "tir.is_scheduled": 1, + "tir.noalias": T.bool(True), + } + ) + n = T.int64() + lv41 = T.match_buffer(p_lv41, (T.int64(1), T.int64(1), T.int64(4096)), "float16") + NT_matmul_intermediate = T.match_buffer( + p_output0, (T.int64(1), T.int64(1), T.int64(4096)), "float16" + ) + # with T.block("root"): + decode_intermediate_intermediate = T.alloc_buffer( + (T.int64(4096), T.int64(4096)), "float16" + ) + for i, j in T.grid(T.int64(4096), T.int64(4096)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv47[v_i, v_j // T.int64(8)], lv48[v_i, v_j // T.int64(32)]) + T.writes(decode_intermediate_intermediate[v_i, v_j]) + decode_intermediate_intermediate[v_i, v_j] = ( + T.Cast( + "float16", + T.bitwise_and( + T.shift_right( + lv47[v_i, v_j // T.int64(8)], + T.Cast("uint32", v_j % T.int64(8)) * T.uint32(4), + ), + T.uint32(15), + ), + ) + - T.float16(7) + ) * lv48[v_i, v_j // T.int64(32)] + for i0, i1, i2, k in T.grid(T.int64(1), 1, T.int64(4096), T.int64(4096)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads( + lv41[v_i0, v_i1, v_k], + decode_intermediate_intermediate[v_i2, v_k], + ) + T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + NT_matmul_intermediate[v_i0, v_i1, v_i2] = ( + NT_matmul_intermediate[v_i0, v_i1, v_i2] + + lv41[v_i0, v_i1, v_k] * decode_intermediate_intermediate[v_i2, v_k] + ) + + +import tvm +from bitblas.base.roller.policy import DefaultPolicy +from bitblas.base.roller.arch import CUDA + +func = fused_fused_decode3_fused_NT_matmul8_add1 +target = tvm.target.Target("nvidia/nvidia-a100") +arch = CUDA(target) +policy = DefaultPolicy(func=func, arch=arch) +configs = policy.emit_config(20) +print(configs) +sch = bitblas.gpu.gemv.GEMVWithDequantizeInfo().apply_config(func, configs[0]) +print(sch.mod) diff --git a/testing/python/type_conversion/int4b_fp16_convert.py b/testing/python/type_conversion/int4b_fp16_convert.py new file mode 100644 index 000000000000..92f5f46a8928 --- /dev/null +++ b/testing/python/type_conversion/int4b_fp16_convert.py @@ -0,0 +1,229 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import tvm +import torch +import numpy as np +import tvm.testing +from tvm.script import tir as T +import os +from tvm import te +import numpy as np + + +def general_compress_to_int8(lowprecision_weight, source_bits=4): + elems_per_byte = 8 // source_bits + if lowprecision_weight.dtype == np.float16: + lowprecision_weight = lowprecision_weight.astype(dtype=np.int8) + int8_weight = np.zeros( + ( + *lowprecision_weight.shape[:-1], + lowprecision_weight.shape[-1] // elems_per_byte, + ), + dtype=np.int8, + ) + for j in range(lowprecision_weight.shape[-1] // elems_per_byte): + for k in range(elems_per_byte): + int8_weight[:, j] |= lowprecision_weight[:, j * elems_per_byte + k] << ( + source_bits * k + ) + return int8_weight + + +def interleave_weight(qweight, nbits=4, target_dtype="float16"): + assert target_dtype in ["float16", "int8"] + # reinterpret the data type of qweight to int32 + qweight = qweight.view(np.int32) + new_qweight = np.zeros_like(qweight) + bits_stride = 8 if target_dtype == "int8" else 16 + mask = (1 << nbits) - 1 # for 4bit the val is 0x0000000f + num_groups = 32 // bits_stride + elems_per_group = bits_stride // nbits + for i in range(num_groups): + for j in range(elems_per_group): + offset = i * elems_per_group + j + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * nbits + new_qweight |= ((qweight >> (nbits * offset)) & mask) << shift + + if nbits == 1 and target_dtype == "int8": + # special handling for 1b interleave + n16_weight = new_qweight & np.int32(0xF0F00F0F) + n16_weight |= ((new_qweight & np.int32(0x000000F0)) >> 4) << 16 + n16_weight |= ((new_qweight & np.int32(0x0000F000)) >> 12) << 24 + n16_weight |= ((new_qweight & np.int32(0x000F0000)) >> 16) << 4 + n16_weight |= ((new_qweight & np.int32(0x0F000000)) >> 24) << 12 + return n16_weight.view(np.int8) + elif nbits == 2 and target_dtype == "float16": + n8_weight = new_qweight & np.int32(0xFF0000FF) + n8_weight |= ((new_qweight & np.int32(0x0000FF00)) >> 8) << 16 + n8_weight |= ((new_qweight & np.int32(0x00FF0000)) >> 16) << 8 + return n8_weight.view(np.int8) + elif nbits == 1 and target_dtype == "float16": + n8_weight = new_qweight & 0xF000000F + n8_weight |= ((new_qweight & 0x000000F0) >> 4) << 8 + n8_weight |= ((new_qweight & 0x00000F00) >> 8) << 16 + n8_weight |= ((new_qweight & 0x0000F000) >> 12) << 24 + n8_weight |= ((new_qweight & 0x000F0000) >> 16) << 4 + n8_weight |= ((new_qweight & 0x00F00000) >> 20) << 12 + n8_weight |= ((new_qweight & 0x0F000000) >> 24) << 20 + + return new_qweight.view(np.int8) + + +def tir_interleave_weight(N=2, K=16, bits=4, target_dtype="float16"): + QK = K * bits // 32 + bits_stride = 16 + mask = (1 << bits) - 1 # for 4bit the val is 0x0000000f + num_groups = 32 // bits_stride + elems_per_group = bits_stride // bits + + @T.prim_func + def interleave_weight(A: T.Buffer((N, QK), "int32"), B: T.Buffer((N, QK), "int32")): + for ax0, ax1, ax2, ax3 in T.grid(N, QK, num_groups, elems_per_group): + with T.block("B"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + offset = v2 * elems_per_group + v3 + shift = (offset % num_groups) * bits_stride + ( + offset // num_groups + ) * bits + B[v0, v1] = B[v0, v1] | ( + ((A[v0, v1] >> (bits * offset)) & mask) << shift + ) + + @T.prim_func + def interleave_weight_f16_2b( + A: T.Buffer((N, QK), "int32"), B: T.Buffer((N, QK), "int32") + ): + B_tmp_1 = T.alloc_buffer((N, QK), "int32", scope="local") + B_tmp_2 = T.alloc_buffer((N, QK), "int32", scope="local") + B_tmp_3 = T.alloc_buffer((N, QK), "int32", scope="local") + for ax0, ax1, ax2, ax3 in T.grid(N, QK, num_groups, elems_per_group): + with T.block("B_tmp"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + offset = v2 * elems_per_group + v3 + shift = (offset % num_groups) * bits_stride + ( + offset // num_groups + ) * bits + B[v0, v1] = B[v0, v1] | ( + ((A[v0, v1] >> (bits * offset)) & mask) << shift + ) + + for ax0, ax1 in T.grid(N, QK): + with T.block("B"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + B_tmp_1[v0, v1] = B[v0, v1] & T.uint32(0xFF0000FF) + B_tmp_2[v0, v1] = ((B[v0, v1] & T.uint32(0x00FF0000)) << 8) >> 16 + B_tmp_3[v0, v1] = ((B[v0, v1] & T.uint32(0x0000FF00)) << 16) >> 8 + B[v0, v1] = B_tmp_1[v0, v1] | B_tmp_2[v0, v1] | B_tmp_3[v0, v1] + + @T.prim_func + def interleave_weight_f16_1b( + A: T.Buffer((N, QK), "int32"), B: T.Buffer((N, QK), "int32") + ): + B_tmp_1 = T.alloc_buffer((N, QK), "int32", scope="local") + B_tmp_2 = T.alloc_buffer((N, QK), "int32", scope="local") + B_tmp_3 = T.alloc_buffer((N, QK), "int32", scope="local") + B_tmp_4 = T.alloc_buffer((N, QK), "int32", scope="local") + B_tmp_5 = T.alloc_buffer((N, QK), "int32", scope="local") + B_tmp_6 = T.alloc_buffer((N, QK), "int32", scope="local") + B_tmp_7 = T.alloc_buffer((N, QK), "int32", scope="local") + for ax0, ax1, ax2, ax3 in T.grid(N, QK, num_groups, elems_per_group): + with T.block("B_tmp"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + offset = v2 * elems_per_group + v3 + shift = (offset % num_groups) * bits_stride + ( + offset // num_groups + ) * bits + B[v0, v1] = B[v0, v1] | ( + ((A[v0, v1] >> (bits * offset)) & mask) << shift + ) + + for ax0, ax1 in T.grid(N, QK): + with T.block("B"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + B_tmp_1[v0, v1] = B[v0, v1] & T.uint32(0xF000000F) + B_tmp_2[v0, v1] = ((B[v0, v1] & T.uint32(0x000000F0)) >> 4) << 8 + B_tmp_3[v0, v1] = ((B[v0, v1] & T.uint32(0x00000F00)) >> 8) << 16 + B_tmp_4[v0, v1] = ((B[v0, v1] & T.uint32(0x0000F000)) >> 12) << 24 + B_tmp_5[v0, v1] = ((B[v0, v1] & T.uint32(0x000F0000)) >> 16) << 8 + B_tmp_6[v0, v1] = ((B[v0, v1] & T.uint32(0x00F00000)) >> 20) << 12 + B_tmp_7[v0, v1] = ((B[v0, v1] & T.uint32(0x00F00000)) >> 24) << 20 + B[v0, v1] = ( + B_tmp_1[v0, v1] + | B_tmp_2[v0, v1] + | B_tmp_3[v0, v1] + | B_tmp_4[v0, v1] + | B_tmp_5[v0, v1] + | B_tmp_6[v0, v1] + | B_tmp_7[v0, v1] + ) + + @T.prim_func + def interleave_weight_int8_1b( + A: T.Buffer((N, QK), "int32"), B: T.Buffer((N, QK), "int32") + ): + B_tmp_1 = T.alloc_buffer((N, QK), "int32", scope="local") + B_tmp_2 = T.alloc_buffer((N, QK), "int32", scope="local") + B_tmp_3 = T.alloc_buffer((N, QK), "int32", scope="local") + B_tmp_4 = T.alloc_buffer((N, QK), "int32", scope="local") + B_tmp_5 = T.alloc_buffer((N, QK), "int32", scope="local") + for ax0, ax1, ax2, ax3 in T.grid(N, QK, num_groups, elems_per_group): + with T.block("B_tmp"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + offset = v2 * elems_per_group + v3 + shift = (offset % num_groups) * bits_stride + ( + offset // num_groups + ) * bits + B[v0, v1] = B[v0, v1] | ( + ((A[v0, v1] >> (bits * offset)) & mask) << shift + ) + + for ax0, ax1 in T.grid(N, QK): + with T.block("B"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + B_tmp_1[v0, v1] = B[v0, v1] & T.uint32(0xF0F00F0F) + B_tmp_2[v0, v1] = ((B[v0, v1] & T.uint32(0x000000F0)) >> 4) << 16 + B_tmp_3[v0, v1] = ((B[v0, v1] & T.uint32(0x0000F000)) >> 12) << 24 + B_tmp_4[v0, v1] = ((B[v0, v1] & T.uint32(0x000F0000)) >> 16) << 4 + B_tmp_5[v0, v1] = ((B[v0, v1] & T.uint32(0x0F000000)) >> 24) << 12 + B[v0, v1] = ( + B_tmp_1[v0, v1] + | B_tmp_2[v0, v1] + | B_tmp_3[v0, v1] + | B_tmp_4[v0, v1] + | B_tmp_5[v0, v1] + ) + + if target_dtype == "float16" and bits == 2: + return interleave_weight_f16_2b + elif target_dtype == "float16" and bits == 1: + return interleave_weight_f16_1b + elif target_dtype == "int8" and bits == 1: + return interleave_weight_int8_1b + + return interleave_weight + + +def test_lop3_interleave_weight(): + source_nbits = 2 + N = 2 + K = 16 + target_dtype = "float16" + torch.manual_seed(0) + uint_max = 2 ** (source_nbits) - 1 + raw_data = torch.randint(0, uint_max, (N, K), dtype=torch.int8).cpu().numpy() + compressed_b = general_compress_to_int8(raw_data, source_nbits) + interleaved_weight = interleave_weight(compressed_b, source_nbits, target_dtype) + interleave_func = tir_interleave_weight(N, K, source_nbits, target_dtype) + + ref_func = tvm.build(interleave_func, target="llvm") + ctx = tvm.cpu(0) + compressed_b_cast_32 = compressed_b.view(np.int32) + tvm_compress_b = tvm.nd.array(compressed_b_cast_32, ctx) + tvm_interleaved_b = tvm.nd.array(np.zeros_like(compressed_b_cast_32), ctx) + ref_func(tvm_compress_b, tvm_interleaved_b) + tvm_interleaved_b_np = tvm_interleaved_b.asnumpy() + tvm_interleaved_b_np_int8 = tvm_interleaved_b_np.view(np.int8) + np.testing.assert_allclose(tvm_interleaved_b_np_int8, interleaved_weight, atol=1e-5) + + +test_lop3_interleave_weight() diff --git a/testing/python/type_conversion/test_numpy_compress_convert.py b/testing/python/type_conversion/test_numpy_compress_convert.py new file mode 100644 index 000000000000..59e481eb93dd --- /dev/null +++ b/testing/python/type_conversion/test_numpy_compress_convert.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. diff --git a/testing/python/weight_only/correctness/test_fp16xint4_correctness.py b/testing/python/weight_only/correctness/test_fp16xint4_correctness.py new file mode 100644 index 000000000000..7f5b3027dbf5 --- /dev/null +++ b/testing/python/weight_only/correctness/test_fp16xint4_correctness.py @@ -0,0 +1,37 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import torch + +import bitblas +import numpy as np + +from bitblas.quantization.utils import general_compress, interleave_weight +from bitblas.ops.matmul import MatmulWeightOnlyDequantize + +M = 1 +N = 4096 +K = 1024 +bitblas_matmul = MatmulWeightOnlyDequantize( + M=M, + N=N, + K=K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + propagate_b=False, + bit=4, + storage_dtype="uint8", + source_format="int", + with_scaling=False, + group_size=128, + fast_decoding=False, + with_bias=False, +) + +torch_arrs = [] +torch_arrs.append(torch.randint(0, 10, (M, K), dtype=torch.float16, device="cuda")) +torch_arrs.append(torch.randint(0, 7, (N, K), dtype=torch.float16, device="cuda")) +torch_arrs.append(torch.zeros((M, K), dtype=torch.float16, device="cuda")) + +print("torch: {}".format(torch_arrs[-1])) + diff --git a/testing/python/weight_only/index_map_deduce.py b/testing/python/weight_only/index_map_deduce.py new file mode 100644 index 000000000000..892cea202e1e --- /dev/null +++ b/testing/python/weight_only/index_map_deduce.py @@ -0,0 +1,24 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import numpy as np +import pytest + +import tvm +import tvm.testing +from tvm.ir import assert_structural_equal +from tvm.runtime import const +from tvm.tir import IndexMap, IntImm, floordiv, floormod +from tvm import tir +index_map = IndexMap.from_func(lambda i: [i // 4, i % 4], index_dtype="int32") +initial_i = index_map.initial_indices[0] + +# but what we have is i <=> i // 4 +# should do inverse + +block_iter_map = IndexMap.from_func(lambda i: [i // 4], index_dtype="int32") +inverse_block_iter_map = index_map.inverse([32,]) + +new_final_indices = index_map.map_indices([initial_i * 4]) + +# # tir.IndexMap([initial_i // 4], final_indices, None) +# print(new_final_indices) diff --git a/testing/python/weight_only/index_map_fuse.py b/testing/python/weight_only/index_map_fuse.py new file mode 100644 index 000000000000..6660903cac2e --- /dev/null +++ b/testing/python/weight_only/index_map_fuse.py @@ -0,0 +1,81 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import tvm +from tvm.script import tir as T +from tvm.tir import IndexMap +from tvm.tir.tensor_intrin.cuda import ( + ldmatrix_trans_32x8_to_shared_16x16_layout, + ldmatrix_32x16_to_shared_16x32_layout_a, + ldmatrix_32x16_to_shared_16x32_layout_b, +) + +def ldmatrix_trans_permutation_16x16_32x8_16x16(kernel_i, kernel_j): + thread_id = kernel_i * 2 + kernel_j // 8 + local_id = kernel_j % 8 + return ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id) + +@tvm.script.ir_module +class LDMATRIX_16x16: + @T.prim_func + def main(a: T.handle, b: T.handle): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = T.match_buffer(a, [16, 16], dtype="float16") + B = T.match_buffer(b, [16, 16], dtype="float16") + + for i, j in T.grid(16, 16): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B[vi, vj]) + T.writes(A[vi, vj]) + A[vi, vj] = B[vi, vj] + +ir_module = LDMATRIX_16x16 +sch = tvm.tir.Schedule(ir_module) + +block_b = sch.get_block("B") +sch.transform_layout(block_b, ('read', 0), ldmatrix_trans_permutation_16x16_32x8_16x16) +print("========================inject transform=============================") +print(sch.mod["main"].script()) + +index_map = IndexMap.from_func(ldmatrix_trans_permutation_16x16_32x8_16x16, index_dtype="int32") +inversed_index_map = index_map.inverse([16, 16]) +def inverse_permutation(i, j): + return inversed_index_map.map_indices([i, j]) +sch.transform_layout(block_b, ('read', 0), inverse_permutation) +print("========================inverse inject transform=============================") +print(sch.mod["main"].script()) + + +def ldmatrix_trans_permutation_16x32_16x32_16x32(kernel_i, kernel_j): + thread_id = kernel_i * 2 + kernel_j // 16 + local_id = kernel_j % 16 + return ldmatrix_32x16_to_shared_16x32_layout_a(thread_id, local_id) + +@tvm.script.ir_module +class LDMATRIX_16x32_A: + @T.prim_func + def main(a: T.handle, b: T.handle): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = T.match_buffer(a, [16, 32], dtype="float16") + B = T.match_buffer(b, [16, 32], dtype="float16") + + for i, j in T.grid(16, 32): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B[vi, vj]) + T.writes(A[vi, vj]) + A[vi, vj] = B[vi, vj] + +ir_module = LDMATRIX_16x32_A +sch = tvm.tir.Schedule(ir_module) + +block_b = sch.get_block("B") +sch.transform_layout(block_b, ('read', 0), ldmatrix_trans_permutation_16x32_16x32_16x32) +print("========================inject transform=============================") +print(sch.mod["main"].script()) + +index_map_inter = IndexMap.from_func(lambda i, j: (i // 16, j // 16, i % 16, j % 16), index_dtype="int32") + +index_map_intra = IndexMap.from_func(ldmatrix_trans_permutation_16x32_16x32_16x32, index_dtype="int32") + +print("index_map_inter", index_map_inter) \ No newline at end of file diff --git a/testing/python/weight_only/inverse_index_map.py b/testing/python/weight_only/inverse_index_map.py new file mode 100644 index 000000000000..bfc1782972d2 --- /dev/null +++ b/testing/python/weight_only/inverse_index_map.py @@ -0,0 +1,119 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import tvm +from tvm.script import tir as T +from tvm.tir import IndexMap +from tvm.tir.tensor_intrin.cuda import ( + ldmatrix_trans_32x8_to_shared_16x16_layout, + ldmatrix_32x16_to_shared_16x32_layout_a, + ldmatrix_32x16_to_shared_16x32_layout_b, +) + +def ldmatrix_trans_permutation_16x16_32x8_16x16(kernel_i, kernel_j): + thread_id = kernel_i * 2 + kernel_j // 8 + local_id = kernel_j % 8 + return ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id) + +@tvm.script.ir_module +class LDMATRIX_16x16: + @T.prim_func + def main(a: T.handle, b: T.handle): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = T.match_buffer(a, [16, 16], dtype="float16") + B = T.match_buffer(b, [16, 16], dtype="float16") + + for i, j in T.grid(16, 16): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B[vi, vj]) + T.writes(A[vi, vj]) + A[vi, vj] = B[vi, vj] + +ir_module = LDMATRIX_16x16 +sch = tvm.tir.Schedule(ir_module) + +block_b = sch.get_block("B") +sch.transform_layout(block_b, ('read', 0), ldmatrix_trans_permutation_16x16_32x8_16x16) +print("========================inject transform=============================") +print(sch.mod["main"].script()) + +index_map = IndexMap.from_func(ldmatrix_trans_permutation_16x16_32x8_16x16, index_dtype="int32") +inversed_index_map = index_map.inverse([16, 16]) +def inverse_permutation(i, j): + return inversed_index_map.map_indices([i, j]) +sch.transform_layout(block_b, ('read', 0), inverse_permutation) +print("========================inverse inject transform=============================") +print(sch.mod["main"].script()) + + +def ldmatrix_trans_permutation_16x32_16x32_16x32(kernel_i, kernel_j): + thread_id = kernel_i * 2 + kernel_j // 16 + local_id = kernel_j % 16 + return ldmatrix_32x16_to_shared_16x32_layout_a(thread_id, local_id) + +@tvm.script.ir_module +class LDMATRIX_16x32_A: + @T.prim_func + def main(a: T.handle, b: T.handle): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = T.match_buffer(a, [16, 32], dtype="float16") + B = T.match_buffer(b, [16, 32], dtype="float16") + + for i, j in T.grid(16, 32): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B[vi, vj]) + T.writes(A[vi, vj]) + A[vi, vj] = B[vi, vj] + +ir_module = LDMATRIX_16x32_A +sch = tvm.tir.Schedule(ir_module) + +block_b = sch.get_block("B") +sch.transform_layout(block_b, ('read', 0), ldmatrix_trans_permutation_16x32_16x32_16x32) +print("========================inject transform=============================") +print(sch.mod["main"].script()) + +index_map = IndexMap.from_func(ldmatrix_trans_permutation_16x32_16x32_16x32, index_dtype="int32") +inversed_index_map = index_map.inverse([16, 32]) +def inverse_permutation(i, j): + return inversed_index_map.map_indices([i, j]) +sch.transform_layout(block_b, ('read', 0), inverse_permutation) +print("========================inverse inject transform=============================") +print(sch.mod["main"].script()) + +def ldmatrix_trans_permutation_16x32_16x32_16x32(kernel_i, kernel_j): + thread_id = kernel_i * 2 + kernel_j // 16 + local_id = kernel_j % 16 + return ldmatrix_32x16_to_shared_16x32_layout_b(thread_id, local_id) + +@tvm.script.ir_module +class LDMATRIX_16x32_B: + @T.prim_func + def main(a: T.handle, b: T.handle): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = T.match_buffer(a, [16, 32], dtype="float16") + B = T.match_buffer(b, [16, 32], dtype="float16") + + for i, j in T.grid(16, 32): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B[vi, vj]) + T.writes(A[vi, vj]) + A[vi, vj] = B[vi, vj] + +ir_module = LDMATRIX_16x32_B +sch = tvm.tir.Schedule(ir_module) + +block_b = sch.get_block("B") +sch.transform_layout(block_b, ('read', 0), ldmatrix_trans_permutation_16x32_16x32_16x32) +print("========================inject transform=============================") +print(sch.mod["main"].script()) + +index_map = IndexMap.from_func(ldmatrix_trans_permutation_16x32_16x32_16x32, index_dtype="int32") +inversed_index_map = index_map.inverse([16, 32]) +def inverse_permutation(i, j): + return inversed_index_map.map_indices([i, j]) +sch.transform_layout(block_b, ('read', 0), inverse_permutation) +print("========================inverse inject transform=============================") +print(sch.mod["main"].script())