[Dev][TL] Integrate TL Dequant Implementation into BitBLAS OPs (#214)
* Refactor tilelang dequantize module and add matmul_blocked_weight_only function

* remove unimplemented code.

* Implement BaseScheduler to wrap some related items.

* lint fix

* test skip

* Refactor tilelang dequantize module and add matmul_blocked_weight_only function

* test fix

* hardware tuning demo

* remove debug related items.

* implement tuner and cache fix

* lint fix

* test case fix.

* Adapt Tuning Space generation with Roller

* lint fix

* Refactor select_scheduler function for fine-grained interface

The select_scheduler function in the dense/__init__.py module has been refactored to use a fine-grained interface. This change provides more flexibility and enables the implementation of high-performance kernels.
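As a rough sketch, a fine-grained interface dispatches on the operator configuration to a concrete scheduler class (the dispatch logic and the MatmulBlockScheduler name below are illustrative, not the exact BitBLAS code):

def select_scheduler(M, N, K, in_dtype="float16", out_dtype="float16"):
    # Dynamic-shape M (a list of candidate sizes) favors a blocked kernel.
    if M is None or isinstance(M, (tuple, list)):
        return MatmulBlockScheduler(N=N, K=K, in_dtype=in_dtype, out_dtype=out_dtype)
    # Static shapes can use the fine-grained tensor-core scheduler.
    return MatmulFineGrainScheduler(M=M, N=N, K=K, in_dtype=in_dtype, out_dtype=out_dtype)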

Update MatmulScheduler class in matmul_tensorcore.py

The MatmulScheduler class in the matmul_tensorcore.py module has been updated to calculate the number of threads from the block size and warp size, which ensures an optimal warp configuration on NVIDIA GPUs.
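A sketch of that rule, assuming the CUDA warp size of 32 (the tile-attribute names here are assumptions, not the exact scheduler fields):

warp_size = 32  # threads per warp on NVIDIA GPUs
block_row_warps = block_M // warp_M  # warps tiling the M dimension of the block
block_col_warps = block_N // warp_N  # warps tiling the N dimension of the block
threads = warp_size * block_row_warps * block_col_warps  # threadIdx extent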

Improve test_general_matmul_tilelang_kernel.py

The test_general_matmul_tilelang_kernel.py module has been improved to include additional test cases and assertions for correctness.
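The added assertions follow the usual pattern of comparing kernel output against a reference implementation (a sketch; the tolerances are illustrative):

import torch
# ref_output from torch.matmul, output from the compiled TileLang kernel
torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=1e-2)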

* Refactor select_scheduler function for fine-grained interface

* Refactor NotImplementedError message in BaseTLHint class

* Update submodule reference in 3rdparty/tvm

* Refactor matmul_finetune function to use topk=20 for hardware-aware finetuning
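For reference, this drives the op's tuning entry point roughly as follows, where topk bounds how many candidate configs are compiled and profiled (call site shown is a sketch):

matmul.hardware_aware_finetune(topk=20)  # profile the 20 best roller candidates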

* Refactor submodule reference in 3rdparty/tvm

* lint fix

* Refactor test_general_matmul_tilelang_impl.py and test_tilelang_gemm.py

* Refactor MatmulConfig to enable weight propagation on supported devices

* Refactor test_general_matmul_tilelang_impl.py and test_general_matmul_tilelang_kernel.py to use centered random values for input tensors
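The centered initialization referred to above looks roughly like this; drawing from [-0.5, 0.5) instead of [0, 1) makes products sign-balanced, so fp16 accumulation error stays near zero (sketch; M, N, K are the problem sizes):

import torch
A = torch.rand(M, K, dtype=torch.float16, device="cuda") - 0.5
B = torch.rand(N, K, dtype=torch.float16, device="cuda") - 0.5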

* test fix

* test fix

* Refactor flash attention tests to use centered random values for input tensors

* Refactor flash attention tests to use centered random values for input tensors

* Refactor flash attention tests to skip test if flash_attn is not installed
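One common pytest pattern for such a guard (the tests may implement it differently):

import pytest
flash_attn = pytest.importorskip("flash_attn")  # skips the module when absent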

* lint fix

* test fix

* test fix

* test fix

* Refactor quantization module imports

* lint fix

* Update yapf version in requirements-dev.txt and requirements-test.txt

* Refactor shared memory to global memory storage in MatmulFineGrainScheduler

* test fix

* format

* test fix

* Refactor tensorcore policy to use list comprehension for readability

* lint fix
LeiWang1999 authored Oct 7, 2024
1 parent 314b2a1 commit a6d627c
Showing 30 changed files with 1,450 additions and 234 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated 1 file
+2 −2 python/tvm/tl/language.py
8 changes: 6 additions & 2 deletions benchmark/dsl/convolution.py
@@ -43,8 +43,12 @@ def conv2d_nhwc_hwio(n, f, h, w, c, kh, kw, s, d, p, in_dtype="float16", out_dty
C = te.compute(
out_shape,
lambda n, h, w, f: te.sum(
pad[n, h * stride_h + kh * dilation_h, w * stride_w + kw * dilation_w, c,] * B[kh, kw,
c, f],
pad[
n,
h * stride_h + kh * dilation_h,
w * stride_w + kw * dilation_w,
c,
] * B[kh, kw, c, f],
axis=[kh, kw, c],
),
name="C",
2 changes: 2 additions & 0 deletions bitblas/__init__.py
@@ -10,6 +10,7 @@
if os.path.exists(install_tvm_path) and install_tvm_path not in sys.path:
os.environ["PYTHONPATH"] = install_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "")
os.environ["TL_CUTLASS_PATH"] = install_cutlass_path + "/include"
os.environ["TL_TEMPLATE_PATH"] = os.path.join(install_tvm_path, "src/tl")
sys.path.insert(0, install_tvm_path + "/python")

develop_tvm_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "tvm")
@@ -18,6 +19,7 @@
if os.path.exists(develop_tvm_path) and develop_tvm_path not in sys.path:
os.environ["PYTHONPATH"] = develop_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "")
os.environ["TL_CUTLASS_PATH"] = develop_cutlass_path + "/include"
os.environ["TL_TEMPLATE_PATH"] = os.path.join(install_tvm_path, "src/tl")
sys.path.insert(0, develop_tvm_path + "/python")

import tvm as tvm # noqa: E402
8 changes: 2 additions & 6 deletions bitblas/base/roller/policy/default.py
@@ -285,9 +285,7 @@ def _optimize(node, rstep):
all_steps[k] = list(filter(lambda x: x % rstep[k] == 0, all_steps[k]))

def _score(rstep_id):
rstep = {
k.var.name: all_steps[k.var.name][rstep_id[k.var.name]] for k in node.raxis
}
rstep = {k.var.name: all_steps[k.var.name][rstep_id[k.var.name]] for k in node.raxis}
score = 0
shape = node.propagate_inputs(td.get_tile(node), rstep=rstep)
for i, input_buffer in enumerate(node.input_buffers):
@@ -325,9 +323,7 @@ def _enlarge(rstep_id):
break
else:
cur_rstep_id = new_rstep_id
rstep = {
k.var.name: all_steps[k.var.name][cur_rstep_id[k.var.name]] for k in node.raxis
}
rstep = {k.var.name: all_steps[k.var.name][cur_rstep_id[k.var.name]] for k in node.raxis}
return rstep

for node in self.ordered_nodes:
6 changes: 3 additions & 3 deletions bitblas/base/roller/policy/tensorcore.py
@@ -211,9 +211,9 @@ def get_node_reduce_step_candidates(self, node):
else:
# must be a multiple of wmma_k
return {
k.var.name:
[x * self.wmma_k for x in get_all_factors(int(k.dom.extent) // self.wmma_k)]
for k in node.raxis
k.var.name: [
x * self.wmma_k for x in get_all_factors(int(k.dom.extent) // self.wmma_k)
] for k in node.raxis
}

def check_tile_shape_isvalid(self, td: TileDict):
16 changes: 15 additions & 1 deletion bitblas/builder/lib_generator/__init__.py
@@ -5,6 +5,7 @@
import ctypes
import os
import os.path as osp
import sys
import tempfile
import subprocess
import logging
@@ -47,8 +48,21 @@ def compile_lib(self, timeout: float = None, with_tl: bool = False):
"-gencode",
f"arch=compute_{compute_version},code=sm_{compute_version}",
]

if with_tl:
tvm_root = osp.join(osp.dirname(__file__), "../../../3rdparty/tvm")
install_tvm_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "../..", "3rdparty", "tvm")
develop_tvm_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "../../..", "3rdparty", "tvm")

tvm_root = next((path for path in [install_tvm_path, develop_tvm_path]
if os.path.exists(path) and path not in sys.path), None)

if "TL_TEMPLATE_PATH " in os.environ:
tl_template_path = os.environ["TL_TEMPLATE_PATH"]
else:
tl_template_path = osp.abspath(osp.join(tvm_root, "src/tl"))

tl_template_path = osp.abspath(osp.join(tvm_root, "src/tl"))
if "TL_CUTLASS_PATH" in os.environ:
cutlass_path = os.environ["TL_CUTLASS_PATH"]
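Downstream, the resolved paths feed the include flags of the nvcc invocation; roughly (the flag assembly shown is an illustration, not the verbatim code):

command += [
    "-I" + tl_template_path,  # TL templates: TL_TEMPLATE_PATH or <tvm_root>/src/tl
    "-I" + cutlass_path,      # CUTLASS headers: TL_CUTLASS_PATH
]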
2 changes: 1 addition & 1 deletion bitblas/common.py
@@ -5,4 +5,4 @@

BITBLAS_DEFAULT_CACHE_PATH = os.path.expanduser("~/.cache/bitblas")

MAX_ERROR_MESSAGE_LENGTH = 100
MAX_ERROR_MESSAGE_LENGTH = 200
36 changes: 10 additions & 26 deletions bitblas/gpu/element_wise.py
@@ -45,15 +45,9 @@ def apply_config( # pylint: disable=too-many-locals,missing-docstring
vector_factors = [1] * len(block_factors)
vector_factors[-1] = vec_len

if (
any(
[
sch.get(loop_rv).thread_binding is not None
for loop_rv in sch.get_loops(block)
]
)
or len(sch.get_loops(block)) == 0
):
if (any([
sch.get(loop_rv).thread_binding is not None for loop_rv in sch.get_loops(block)
]) or len(sch.get_loops(block)) == 0):
continue

for loop, iter_type in zip(sch.get_loops(block), dom_kind):
@@ -68,41 +62,31 @@ def apply_config( # pylint: disable=too-many-locals,missing-docstring
thread_loops = []
inner_loops = []
for s_loop, block_factor, step_factor, thread_factor in zip(
s_loops, block_factors, step_factors, thread_factors
):
s_loops, block_factors, step_factors, thread_factors):
block_loop, inner_loop = sch.split(s_loop, factors=[None, block_factor])
vthread_loop, inner_loop = sch.split(
inner_loop, factors=[None, thread_factor * step_factor]
)
thread_loop, inner_loop = sch.split(
inner_loop, factors=[None, step_factor]
)
inner_loop, factors=[None, thread_factor * step_factor])
thread_loop, inner_loop = sch.split(inner_loop, factors=[None, step_factor])
block_loops.append(block_loop)
vthread_loops.append(vthread_loop)
thread_loops.append(thread_loop)
inner_loops.append(inner_loop)

# inner virtual thread first
vthread_loops = list(reversed(vthread_loops))
sch.reorder(
*block_loops,
*vthread_loops,
*thread_loops,
*inner_loops,
*r_loops,
*o_loops
)
sch.reorder(*block_loops, *vthread_loops, *thread_loops, *inner_loops, *r_loops,
*o_loops)
sch.bind(sch.fuse(*block_loops), "blockIdx.x")
sch.bind(sch.fuse(*thread_loops), "threadIdx.x")
if len(vthread_loops) > 3:
vthread_loops = vthread_loops[0:2] + [sch.fuse(*vthread_loops[2:])]

for i, ax in enumerate(vthread_loops):
sch.bind(ax, "vthread" + [".x", ".y", ".z"][i])

# vectorize the last axis
ax = inner_loops[-1]
if sch.get(ax).extent.value > 1:
sch.vectorize(ax)

return sch
10 changes: 10 additions & 0 deletions bitblas/gpu/intrin/lop3.py
@@ -1677,7 +1677,17 @@ def get_lop3_intrin_group(
if is_ladder_stage3:
key += "_offset"

if out_dtype == "float16":
d4f = "f16"
elif out_dtype == "int8":
d4f = "i8s"
else:
raise ValueError("Unsupported target dtype: {}".format(out_dtype))
source_symbol = "u" if source_format == "uint" else "s"
func_name = "decode_i{}{}_to_{}".format(source_bit, source_symbol, d4f)

return {
"func_name": func_name,
"c_source": import_c_map[key],
"compute": _intrin,
}
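For example, the naming scheme above yields symbols such as (derived from the code shown; the input combinations are illustrative):

# source_format="uint", source_bit=4, out_dtype="float16" -> "decode_i4u_to_f16"
# source_format="int",  source_bit=2, out_dtype="int8"    -> "decode_i2s_to_i8s"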
5 changes: 2 additions & 3 deletions bitblas/gpu/matmul_analysis.py
@@ -680,9 +680,8 @@ def check_last_trait(region: List[Range]):
minimal_tensorize_threshold = 16 if in_dtype in ["bfloat16", "float16"] else 32
# the batch dimension is not taken into consideration.
extent = block_stmt.iter_vars[1].dom.extent
if isinstance(extent,
tir.expr.IntImm) and (extent.value <
(1 if allow_gemv else minimal_tensorize_threshold)):
if isinstance(extent, tir.expr.IntImm) and (extent.value < (1 if allow_gemv else
minimal_tensorize_threshold)):
return func, None
for item_var in block_stmt.iter_vars[2:]:
extent = item_var.dom.extent
67 changes: 14 additions & 53 deletions bitblas/gpu/matmul_mma_dequantize.py
@@ -152,6 +152,20 @@ def get_index_map(index_map, l=16, r=16, is_5d=False): # noqa: E741
return get_index_map_3d(index_map, l, r)


def check_weight_decode_info(weight_decode_info):
conditions = []
# check source format in ["int", "fp", "nf"]
conditions.append("source_format" in weight_decode_info)
conditions.append(
weight_decode_info["source_format"]["format"] in ["uint", "int", "fp", "nf", "fp_e4m3"])
# check source bits in [1, 2, 4, 8]
conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8])
# check target format in ["float16", "int8"]
conditions.append("target_format" in weight_decode_info)
conditions.append(weight_decode_info["target_format"] in ["bfloat16", "float16", "int8"])
return all(conditions)
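A dict that satisfies the hoisted check, for illustration (the field values are an example, not taken from this diff):

weight_decode_info = {
    "source_format": {"format": "uint", "bits": 4},
    "target_format": "float16",
}
assert check_weight_decode_info(weight_decode_info)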


class MatmulTensorizationMMAWithDequantizeInfo(GPUScheduleRule):
"""
The schedule rule for float16 tensor core matmul computation.
@@ -212,19 +226,6 @@ def check_dequantize_info(dequantize_info):

(weight_decode_info,) = list(dequantize_info.values())

def check_weight_decode_info(weight_decode_info):
conditions = []
# check source format in ["int", "fp", "nf"]
conditions.append("source_format" in weight_decode_info)
conditions.append(weight_decode_info["source_format"]["format"] in
["uint", "int", "fp", "nf", "fp_e4m3"])
# check source bits in [1, 2, 4, 8]
conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8])
# check target format in ["float16", "int8"]
conditions.append("target_format" in weight_decode_info)
conditions.append(weight_decode_info["target_format"] in ["float16", "int8"])
return all(conditions)

assert check_weight_decode_info(weight_decode_info), "Invalid Weight Decode Info"

# Start Schedule
@@ -727,19 +728,6 @@ def check_dequantize_info(dequantize_info):

(weight_decode_info,) = list(dequantize_info.values())

def check_weight_decode_info(weight_decode_info):
conditions = []
# check source format in ["int", "fp", "nf"]
conditions.append("source_format" in weight_decode_info)
conditions.append(weight_decode_info["source_format"]["format"] in
["uint", "int", "fp", "nf", "fp_e4m3"])
# check source bits in [1, 2, 4, 8]
conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8])
# check target format in ["float16", "int8"]
conditions.append("target_format" in weight_decode_info)
conditions.append(weight_decode_info["target_format"] in ["float16", "int8"])
return all(conditions)

assert check_weight_decode_info(weight_decode_info), "Invalid Weight Decode Info"

# Start Schedule
@@ -1225,20 +1213,6 @@ def check_dequantize_info(dequantize_info):

(weight_decode_info,) = list(dequantize_info.values())

def check_weight_decode_info(weight_decode_info):
conditions = []
# check source format in ["int", "fp", "nf"]
conditions.append("source_format" in weight_decode_info)
conditions.append(weight_decode_info["source_format"]["format"] in
["uint", "int", "fp", "nf", "fp_e4m3"])
# check source bits in [1, 2, 4, 8]
conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8])
# check target format in ["float16", "int8"]
conditions.append("target_format" in weight_decode_info)
conditions.append(
weight_decode_info["target_format"] in ["bfloat16", "float16", "int8"])
return all(conditions)

assert check_weight_decode_info(weight_decode_info), "Invalid B_decode_info"

# Start Schedule
@@ -1820,19 +1794,6 @@ def check_dequantize_info(dequantize_info):

(weight_decode_info,) = list(dequantize_info.values())

def check_weight_decode_info(weight_decode_info):
conditions = []
# check source format in ["int", "fp", "nf"]
conditions.append("source_format" in weight_decode_info)
conditions.append(weight_decode_info["source_format"]["format"] in
["uint", "int", "fp", "nf", "fp_e4m3"])
# check source bits in [1, 2, 4, 8]
conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8])
# check target format in ["float16", "int8"]
conditions.append("target_format" in weight_decode_info)
conditions.append(weight_decode_info["target_format"] in ["float16", "int8"])
return all(conditions)

assert check_weight_decode_info(weight_decode_info), "Invalid B_decode_info"

# Start Schedule
14 changes: 13 additions & 1 deletion bitblas/ops/base_scheduler.py
@@ -1,12 +1,24 @@
from tvm import IRModule
from tvm.tir import PrimFunc
from typing import Union
from typing import Union, Callable
from dataclasses import dataclass, field
from tvm.tir.transform import Simplify
from abc import ABC, abstractmethod
from bitblas.base.arch import TileDevice


# Decorator to simplify the output of a function
def maybe_simplify(self, func: Callable):

def wrapper(*args, **kwargs):
stmt: Union[PrimFunc, IRModule] = func(*args, **kwargs)
if self._enable_simplify:
return self.Simplify(stmt)
return stmt

return wrapper
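Because maybe_simplify takes the scheduler instance explicitly rather than being applied with @, a call site looks roughly like this (the subclass and method names are hypothetical):

scheduler = SomeScheduler()  # placeholder for a concrete BaseScheduler subclass
build = maybe_simplify(scheduler, scheduler.with_default_config)
mod = build()  # simplified only when scheduler._enable_simplify is True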


@dataclass
class BaseScheduler(ABC):

Expand Down
22 changes: 21 additions & 1 deletion bitblas/ops/general_matmul/__init__.py
@@ -12,6 +12,7 @@
from .tirscript.matmul_dequantize_impl import select_implementation as weight_dequantize_implementation
from .tirscript.matmul_impl import select_implementation as consistent_implementation
from .tilelang.dense import select_scheduler as consistent_scheduler
from .tilelang.dequantize import select_scheduler as weight_dequantize_scheduler
from ...base.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2
from bitblas.utils.target_detector import auto_detect_nvidia_target
from dataclasses import dataclass
@@ -591,7 +592,26 @@ def _select_scheduler(self):
propagate_b=self.propagate_b,
)
else:
raise ValueError("Currently only support native compute for scheduler")
return weight_dequantize_scheduler(
M=self.M,
N=self.N,
K=self.K,
in_dtype=self.A_dtype,
out_dtype=self.out_dtype,
accum_dtype=self.accum_dtype,
bit=self.bit,
storage_dtype=self.storage_dtype,
source_format=self.source_format,
with_scaling=self.with_scaling,
with_zeros=self.with_zeros,
group_size=self.group_size,
fast_decoding=self.fast_decoding,
with_bias=self.with_bias,
layout=self.layout,
zeros_mode=self.zeros_mode,
propagate_a=self.propagate_a,
propagate_b=self.propagate_b,
)

def post_process(self, code: str) -> str:
code = tensor_replace_dp4a(code)
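With this change, a weight-only-quantized matmul can resolve to the TileLang dequantize scheduler; a usage sketch (the backend flag and dtype values are assumptions, not taken from this diff):

from bitblas import Matmul, MatmulConfig

config = MatmulConfig(
    M=1, N=1024, K=1024,
    A_dtype="float16", W_dtype="uint4", accum_dtype="float16")
matmul = Matmul(config, backend="tl")  # TL backend -> weight_dequantize_scheduler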