[Dev][TL] Integrate TL Dequant Implementation into BitBLAS OPs #214

Merged
merged 49 commits into from
Oct 7, 2024
49 commits
f3b1eb9
Refactor tilelang dequantize module and add matmul_blocked_weight_onl…
LeiWang1999 Sep 28, 2024
730d13e
remove un-implemented code.
LeiWang1999 Sep 28, 2024
8047ee7
Implement BaseScheduler to wrap some related items.
LeiWang1999 Sep 28, 2024
64db065
lint fix
LeiWang1999 Sep 28, 2024
cef04a8
test skip
LeiWang1999 Sep 28, 2024
f1652e9
Refactor tilelang dequantize module and add matmul_blocked_weight_onl…
LeiWang1999 Sep 29, 2024
4f6c545
Merge branch 'main' of https://github.com/microsoft/BitBLAS into tl_o…
LeiWang1999 Sep 29, 2024
c485b68
test fix
LeiWang1999 Sep 29, 2024
ebe42a6
hardware tuning demo
LeiWang1999 Sep 29, 2024
88230ec
Merge branch 'main' of https://github.com/microsoft/BitBLAS into tl_o…
LeiWang1999 Sep 29, 2024
44246a1
remove debug related items.
LeiWang1999 Sep 30, 2024
bb51e15
imlement tuner and cache fix
LeiWang1999 Oct 1, 2024
f42a3b9
Merge branch 'main' of https://github.com/microsoft/BitBLAS into tl_o…
LeiWang1999 Oct 1, 2024
de7ae18
lint fix
LeiWang1999 Oct 1, 2024
ef40bd8
test case fix.
LeiWang1999 Oct 1, 2024
85f0a5f
Adapt Tuning Space generation with Roller
LeiWang1999 Oct 1, 2024
e9f7db3
Merge branch 'main' of https://github.com/microsoft/BitBLAS into tl_o…
LeiWang1999 Oct 1, 2024
9e31336
lint fix
LeiWang1999 Oct 1, 2024
2f1a260
Refactor select_scheduler function for fine-grained interface
LeiWang1999 Oct 1, 2024
f1378d4
Refactor select_scheduler function for fine-grained interface
LeiWang1999 Oct 1, 2024
137cce3
Refactor NotImplementedError message in BaseTLHint class
LeiWang1999 Oct 1, 2024
fc19fa2
Update submodule reference in 3rdparty/tvm
LeiWang1999 Oct 2, 2024
fe51bb1
Refactor matmul_finetune function to use topk=20 for hardware-aware f…
LeiWang1999 Oct 2, 2024
79878cb
Refactor submodule reference in 3rdparty/tvm
LeiWang1999 Oct 2, 2024
0fc7ab9
lint fix
LeiWang1999 Oct 2, 2024
255e925
Refactor test_general_matmul_tilelang_impl.py and test_tilelang_gemm.py
LeiWang1999 Oct 2, 2024
df47f63
Refactor MatmulConfig to enable weight propagation on supported devices
LeiWang1999 Oct 2, 2024
826255d
Merge branch 'main' of https://github.com/microsoft/BitBLAS into tl_o…
LeiWang1999 Oct 2, 2024
48dc94e
Refactor test_general_matmul_tilelang_impl.py and test_general_matmul…
LeiWang1999 Oct 2, 2024
82f39d7
test fix
LeiWang1999 Oct 2, 2024
02ef258
Merge branch 'main' of https://github.com/microsoft/BitBLAS into tl_o…
LeiWang1999 Oct 2, 2024
e753ef2
test fix
LeiWang1999 Oct 2, 2024
f6dd744
Refactor flash attention tests to use centered random values for inpu…
LeiWang1999 Oct 2, 2024
7417372
Refactor flash attention tests to use centered random values for inpu…
LeiWang1999 Oct 2, 2024
145a850
Refactor flash attention tests to skip test if flash_attn is not inst…
LeiWang1999 Oct 2, 2024
3384458
lint fix
LeiWang1999 Oct 3, 2024
82f50ea
test fix
LeiWang1999 Oct 3, 2024
d2ed936
test fix
LeiWang1999 Oct 3, 2024
6c56273
test fix
LeiWang1999 Oct 3, 2024
2e59e58
Merge branch 'main' of https://github.com/microsoft/BitBLAS into tl_o…
LeiWang1999 Oct 6, 2024
074b9ca
Refactor quantization module imports
LeiWang1999 Oct 6, 2024
0923344
lint fix
LeiWang1999 Oct 6, 2024
b30bcd4
Update yapf version in requirements-dev.txt and requirements-test.txt
LeiWang1999 Oct 6, 2024
d0a88ac
Refactor shared memory to global memory storage in MatmulFineGrainSch…
LeiWang1999 Oct 6, 2024
62303e2
test fix
LeiWang1999 Oct 6, 2024
01dc3f9
format
LeiWang1999 Oct 6, 2024
c621664
test fix
LeiWang1999 Oct 7, 2024
f934635
Refactor tensorcore policy to use list comprehension for readability
LeiWang1999 Oct 7, 2024
754cf75
lint fix
LeiWang1999 Oct 7, 2024
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated 1 files
+2 −2 python/tvm/tl/language.py
8 changes: 6 additions & 2 deletions benchmark/dsl/convolution.py
Expand Up @@ -43,8 +43,12 @@ def conv2d_nhwc_hwio(n, f, h, w, c, kh, kw, s, d, p, in_dtype="float16", out_dty
C = te.compute(
out_shape,
lambda n, h, w, f: te.sum(
pad[n, h * stride_h + kh * dilation_h, w * stride_w + kw * dilation_w, c,] * B[kh, kw,
c, f],
pad[
n,
h * stride_h + kh * dilation_h,
w * stride_w + kw * dilation_w,
c,
] * B[kh, kw, c, f],
axis=[kh, kw, c],
),
name="C",
2 changes: 2 additions & 0 deletions bitblas/__init__.py
Expand Up @@ -10,6 +10,7 @@
 if os.path.exists(install_tvm_path) and install_tvm_path not in sys.path:
     os.environ["PYTHONPATH"] = install_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "")
     os.environ["TL_CUTLASS_PATH"] = install_cutlass_path + "/include"
+    os.environ["TL_TEMPLATE_PATH"] = os.path.join(install_tvm_path, "src/tl")
     sys.path.insert(0, install_tvm_path + "/python")

develop_tvm_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "tvm")
Expand All @@ -18,6 +19,7 @@
 if os.path.exists(develop_tvm_path) and develop_tvm_path not in sys.path:
     os.environ["PYTHONPATH"] = develop_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "")
     os.environ["TL_CUTLASS_PATH"] = develop_cutlass_path + "/include"
+    os.environ["TL_TEMPLATE_PATH"] = os.path.join(install_tvm_path, "src/tl")
     sys.path.insert(0, develop_tvm_path + "/python")

import tvm as tvm # noqa: E402
8 changes: 2 additions & 6 deletions bitblas/base/roller/policy/default.py
Expand Up @@ -285,9 +285,7 @@ def _optimize(node, rstep):
all_steps[k] = list(filter(lambda x: x % rstep[k] == 0, all_steps[k]))

def _score(rstep_id):
rstep = {
k.var.name: all_steps[k.var.name][rstep_id[k.var.name]] for k in node.raxis
}
rstep = {k.var.name: all_steps[k.var.name][rstep_id[k.var.name]] for k in node.raxis}
score = 0
shape = node.propagate_inputs(td.get_tile(node), rstep=rstep)
for i, input_buffer in enumerate(node.input_buffers):
Expand Down Expand Up @@ -325,9 +323,7 @@ def _enlarge(rstep_id):
break
else:
cur_rstep_id = new_rstep_id
rstep = {
k.var.name: all_steps[k.var.name][cur_rstep_id[k.var.name]] for k in node.raxis
}
rstep = {k.var.name: all_steps[k.var.name][cur_rstep_id[k.var.name]] for k in node.raxis}
return rstep

for node in self.ordered_nodes:
6 changes: 3 additions & 3 deletions bitblas/base/roller/policy/tensorcore.py
Expand Up @@ -211,9 +211,9 @@ def get_node_reduce_step_candidates(self, node):
else:
# must be a multiple of wmma_k
return {
k.var.name:
[x * self.wmma_k for x in get_all_factors(int(k.dom.extent) // self.wmma_k)]
for k in node.raxis
k.var.name: [
x * self.wmma_k for x in get_all_factors(int(k.dom.extent) // self.wmma_k)
] for k in node.raxis
}

def check_tile_shape_isvalid(self, td: TileDict):
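Review note: the reformatted comprehension in `get_node_reduce_step_candidates` keeps the same behavior, and it can be reproduced outside the policy class. The following is a minimal sketch, assuming `get_all_factors` returns the sorted positive divisors of its argument. The helper below is a stand-in written for this note, not the BitBLAS implementation, and the axis names and extents are made up.

```python
from typing import Dict, List


def get_all_factors(n: int) -> List[int]:
    # Stand-in helper (assumption): sorted positive divisors of n.
    return [d for d in range(1, n + 1) if n % d == 0]


def reduce_step_candidates(extents: Dict[str, int], wmma_k: int) -> Dict[str, List[int]]:
    # Each candidate is wmma_k times a divisor of (extent // wmma_k),
    # so every candidate is a multiple of wmma_k.
    return {
        name: [x * wmma_k for x in get_all_factors(extent // wmma_k)]
        for name, extent in extents.items()
    }


candidates = reduce_step_candidates({"k": 64}, wmma_k=16)
print(candidates)  # {'k': [16, 32, 64]}
```

With an extent of 64 and `wmma_k = 16`, the divisors of 4 are 1, 2 and 4, which gives the three reduce steps printed above.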
16 changes: 15 additions & 1 deletion bitblas/builder/lib_generator/__init__.py
Expand Up @@ -5,6 +5,7 @@
 import ctypes
 import os
 import os.path as osp
+import sys
 import tempfile
 import subprocess
 import logging
Expand Down Expand Up @@ -47,8 +48,21 @@ def compile_lib(self, timeout: float = None, with_tl: bool = False):
             "-gencode",
             f"arch=compute_{compute_version},code=sm_{compute_version}",
         ]
+
         if with_tl:
-            tvm_root = osp.join(osp.dirname(__file__), "../../../3rdparty/tvm")
+            install_tvm_path = os.path.join(
+                os.path.dirname(os.path.abspath(__file__)), "../..", "3rdparty", "tvm")
+            develop_tvm_path = os.path.join(
+                os.path.dirname(os.path.abspath(__file__)), "../../..", "3rdparty", "tvm")
+
+            tvm_root = next((path for path in [install_tvm_path, develop_tvm_path]
+                             if os.path.exists(path) and path not in sys.path), None)
+
+            if "TL_TEMPLATE_PATH " in os.environ:
+                tl_template_path = os.environ["TL_TEMPLATE_PATH"]
+            else:
+                tl_template_path = osp.abspath(osp.join(tvm_root, "src/tl"))
+
             tl_template_path = osp.abspath(osp.join(tvm_root, "src/tl"))
             if "TL_CUTLASS_PATH" in os.environ:
                 cutlass_path = os.environ["TL_CUTLASS_PATH"]
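Review note: in the hunk as shown, the lookup key is `"TL_TEMPLATE_PATH "` with a trailing space, and the unchanged line after the `if`/`else` recomputes `tl_template_path` from `tvm_root`, so the environment override does not appear to take effect. In addition, `tvm_root` may be `None` when neither candidate path qualifies. Below is a sketch of how the intended lookup could be written; `resolve_tl_template_path` is a hypothetical helper name introduced for this note and is not part of the PR.

```python
import os
import os.path as osp
from typing import Mapping, Optional


def resolve_tl_template_path(tvm_root: Optional[str],
                             environ: Mapping[str, str] = os.environ) -> str:
    # Hypothetical helper: prefer the TL_TEMPLATE_PATH override when it is set,
    # otherwise fall back to <tvm_root>/src/tl.
    override = environ.get("TL_TEMPLATE_PATH")
    if override:
        return override
    if tvm_root is None:
        raise FileNotFoundError("no TVM root found and TL_TEMPLATE_PATH is not set")
    return osp.abspath(osp.join(tvm_root, "src/tl"))


# Illustrative paths only; they do not need to exist for this lookup.
print(resolve_tl_template_path("/opt/tvm", environ={}))
print(resolve_tl_template_path(None, environ={"TL_TEMPLATE_PATH": "/custom/tl"}))
```

Passing the environment as a mapping keeps the helper easy to test without touching the real process environment.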
2 changes: 1 addition & 1 deletion bitblas/common.py
Expand Up @@ -5,4 +5,4 @@

BITBLAS_DEFAULT_CACHE_PATH = os.path.expanduser("~/.cache/bitblas")

-MAX_ERROR_MESSAGE_LENGTH = 100
+MAX_ERROR_MESSAGE_LENGTH = 200
36 changes: 10 additions & 26 deletions bitblas/gpu/element_wise.py
Expand Up @@ -45,15 +45,9 @@ def apply_config( # pylint: disable=too-many-locals,missing-docstring
vector_factors = [1] * len(block_factors)
vector_factors[-1] = vec_len

if (
any(
[
sch.get(loop_rv).thread_binding is not None
for loop_rv in sch.get_loops(block)
]
)
or len(sch.get_loops(block)) == 0
):
if (any([
sch.get(loop_rv).thread_binding is not None for loop_rv in sch.get_loops(block)
]) or len(sch.get_loops(block)) == 0):
continue

for loop, iter_type in zip(sch.get_loops(block), dom_kind):
Expand All @@ -68,41 +62,31 @@ def apply_config( # pylint: disable=too-many-locals,missing-docstring
thread_loops = []
inner_loops = []
for s_loop, block_factor, step_factor, thread_factor in zip(
s_loops, block_factors, step_factors, thread_factors
):
s_loops, block_factors, step_factors, thread_factors):
block_loop, inner_loop = sch.split(s_loop, factors=[None, block_factor])
vthread_loop, inner_loop = sch.split(
inner_loop, factors=[None, thread_factor * step_factor]
)
thread_loop, inner_loop = sch.split(
inner_loop, factors=[None, step_factor]
)
inner_loop, factors=[None, thread_factor * step_factor])
thread_loop, inner_loop = sch.split(inner_loop, factors=[None, step_factor])
block_loops.append(block_loop)
vthread_loops.append(vthread_loop)
thread_loops.append(thread_loop)
inner_loops.append(inner_loop)

# inner virtual thread first
vthread_loops = list(reversed(vthread_loops))
sch.reorder(
*block_loops,
*vthread_loops,
*thread_loops,
*inner_loops,
*r_loops,
*o_loops
)
sch.reorder(*block_loops, *vthread_loops, *thread_loops, *inner_loops, *r_loops,
*o_loops)
sch.bind(sch.fuse(*block_loops), "blockIdx.x")
sch.bind(sch.fuse(*thread_loops), "threadIdx.x")
if len(vthread_loops) > 3:
vthread_loops = vthread_loops[0:2] + [sch.fuse(*vthread_loops[2:])]

for i, ax in enumerate(vthread_loops):
sch.bind(ax, "vthread" + [".x", ".y", ".z"][i])

# vectorize the last axis
ax = inner_loops[-1]
if sch.get(ax).extent.value > 1:
sch.vectorize(ax)

return sch
10 changes: 10 additions & 0 deletions bitblas/gpu/intrin/lop3.py
Expand Up @@ -1677,7 +1677,17 @@ def get_lop3_intrin_group(
     if is_ladder_stage3:
         key += "_offset"

+    if out_dtype == "float16":
+        d4f = "f16"
+    elif out_dtype == "int8":
+        d4f = "i8s"
+    else:
+        raise ValueError("Unsupported target dtype: {}".format(target_dtype))
+    source_symbol = "u" if source_format == "uint" else "s"
+    func_name = "decode_i{}{}_to_{}".format(source_bit, source_symbol, d4f)
+
     return {
+        "func_name": func_name,
         "c_source": import_c_map[key],
         "compute": _intrin,
     }
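Review note: the new `func_name` entry follows a simple naming scheme, `decode_i<bits><u|s>_to_<f16|i8s>`. The error branch in the hunk formats `target_dtype`, a name that does not appear elsewhere in the lines shown; the sketch below uses `out_dtype` there instead. `lop3_func_name` is a hypothetical standalone function written for this note, not an API of BitBLAS.

```python
def lop3_func_name(out_dtype: str, source_format: str, source_bit: int) -> str:
    # Mirrors the naming logic from the hunk; the wrapper function is illustrative only.
    suffix_by_dtype = {"float16": "f16", "int8": "i8s"}
    if out_dtype not in suffix_by_dtype:
        raise ValueError("Unsupported target dtype: {}".format(out_dtype))
    # Unsigned sources get "u"; every other source format gets "s".
    source_symbol = "u" if source_format == "uint" else "s"
    return "decode_i{}{}_to_{}".format(source_bit, source_symbol, suffix_by_dtype[out_dtype])


print(lop3_func_name("float16", "uint", 4))  # decode_i4u_to_f16
print(lop3_func_name("int8", "int", 2))      # decode_i2s_to_i8s
```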
5 changes: 2 additions & 3 deletions bitblas/gpu/matmul_analysis.py
Expand Up @@ -680,9 +680,8 @@ def check_last_trait(region: List[Range]):
minimal_tensorize_threshold = 16 if in_dtype in ["bfloat16", "float16"] else 32
# the batch dimension is not taken into consideration.
extent = block_stmt.iter_vars[1].dom.extent
if isinstance(extent,
tir.expr.IntImm) and (extent.value <
(1 if allow_gemv else minimal_tensorize_threshold)):
if isinstance(extent, tir.expr.IntImm) and (extent.value < (1 if allow_gemv else
minimal_tensorize_threshold)):
return func, None
for item_var in block_stmt.iter_vars[2:]:
extent = item_var.dom.extent
67 changes: 14 additions & 53 deletions bitblas/gpu/matmul_mma_dequantize.py
Expand Up @@ -152,6 +152,20 @@ def get_index_map(index_map, l=16, r=16, is_5d=False): # noqa: E741
     return get_index_map_3d(index_map, l, r)


+def check_weight_decode_info(weight_decode_info):
+    conditions = []
+    # check source format in ["int", "fp", "nf"]
+    conditions.append("source_format" in weight_decode_info)
+    conditions.append(
+        weight_decode_info["source_format"]["format"] in ["uint", "int", "fp", "nf", "fp_e4m3"])
+    # check source bits in [1, 2, 4, 8]
+    conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8])
+    # check target format in ["float16", "int8"]
+    conditions.append("target_format" in weight_decode_info)
+    conditions.append(weight_decode_info["target_format"] in ["bfloat16", "float16", "int8"])
+    return all(conditions)
+
+
 class MatmulTensorizationMMAWithDequantizeInfo(GPUScheduleRule):
     """
     The schedule rule for float16 tensor core matmul computation.
Expand Down Expand Up @@ -212,19 +226,6 @@ def check_dequantize_info(dequantize_info):

         (weight_decode_info,) = list(dequantize_info.values())

-        def check_weight_decode_info(weight_decode_info):
-            conditions = []
-            # check source format in ["int", "fp", "nf"]
-            conditions.append("source_format" in weight_decode_info)
-            conditions.append(weight_decode_info["source_format"]["format"] in
-                              ["uint", "int", "fp", "nf", "fp_e4m3"])
-            # check source bits in [1, 2, 4, 8]
-            conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8])
-            # check target format in ["float16", "int8"]
-            conditions.append("target_format" in weight_decode_info)
-            conditions.append(weight_decode_info["target_format"] in ["float16", "int8"])
-            return all(conditions)
-
         assert check_weight_decode_info(weight_decode_info), "Invalid Weight Decode Info"

         # Start Schedule
Expand Down Expand Up @@ -727,19 +728,6 @@ def check_dequantize_info(dequantize_info):

         (weight_decode_info,) = list(dequantize_info.values())

-        def check_weight_decode_info(weight_decode_info):
-            conditions = []
-            # check source format in ["int", "fp", "nf"]
-            conditions.append("source_format" in weight_decode_info)
-            conditions.append(weight_decode_info["source_format"]["format"] in
-                              ["uint", "int", "fp", "nf", "fp_e4m3"])
-            # check source bits in [1, 2, 4, 8]
-            conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8])
-            # check target format in ["float16", "int8"]
-            conditions.append("target_format" in weight_decode_info)
-            conditions.append(weight_decode_info["target_format"] in ["float16", "int8"])
-            return all(conditions)
-
         assert check_weight_decode_info(weight_decode_info), "Invalid Weight Decode Info"

         # Start Schedule
Expand Down Expand Up @@ -1225,20 +1213,6 @@ def check_dequantize_info(dequantize_info):

         (weight_decode_info,) = list(dequantize_info.values())

-        def check_weight_decode_info(weight_decode_info):
-            conditions = []
-            # check source format in ["int", "fp", "nf"]
-            conditions.append("source_format" in weight_decode_info)
-            conditions.append(weight_decode_info["source_format"]["format"] in
-                              ["uint", "int", "fp", "nf", "fp_e4m3"])
-            # check source bits in [1, 2, 4, 8]
-            conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8])
-            # check target format in ["float16", "int8"]
-            conditions.append("target_format" in weight_decode_info)
-            conditions.append(
-                weight_decode_info["target_format"] in ["bfloat16", "float16", "int8"])
-            return all(conditions)
-
         assert check_weight_decode_info(weight_decode_info), "Invalid B_decode_info"

         # Start Schedule
Expand Down Expand Up @@ -1820,19 +1794,6 @@ def check_dequantize_info(dequantize_info):

         (weight_decode_info,) = list(dequantize_info.values())

-        def check_weight_decode_info(weight_decode_info):
-            conditions = []
-            # check source format in ["int", "fp", "nf"]
-            conditions.append("source_format" in weight_decode_info)
-            conditions.append(weight_decode_info["source_format"]["format"] in
-                              ["uint", "int", "fp", "nf", "fp_e4m3"])
-            # check source bits in [1, 2, 4, 8]
-            conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8])
-            # check target format in ["float16", "int8"]
-            conditions.append("target_format" in weight_decode_info)
-            conditions.append(weight_decode_info["target_format"] in ["float16", "int8"])
-            return all(conditions)
-
         assert check_weight_decode_info(weight_decode_info), "Invalid B_decode_info"

         # Start Schedule
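Review note: hoisting `check_weight_decode_info` to module level replaces four nested copies with one shared helper, and the shared version accepts `"bfloat16"` as a target format for every schedule rule. The sketch below copies the hoisted helper from this diff and calls it on made-up dictionaries to show its behavior. One observation from reading the code: because every condition is evaluated eagerly, a dictionary without a `"source_format"` key raises `KeyError` instead of returning `False`.

```python
def check_weight_decode_info(weight_decode_info):
    # Copied from the hoisted helper in this diff.
    conditions = []
    conditions.append("source_format" in weight_decode_info)
    conditions.append(
        weight_decode_info["source_format"]["format"] in ["uint", "int", "fp", "nf", "fp_e4m3"])
    conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8])
    conditions.append("target_format" in weight_decode_info)
    conditions.append(weight_decode_info["target_format"] in ["bfloat16", "float16", "int8"])
    return all(conditions)


# Illustrative inputs (made up for this note; real decode info may carry more keys).
valid = {"source_format": {"format": "uint", "bits": 4}, "target_format": "float16"}
bad_bits = {"source_format": {"format": "uint", "bits": 3}, "target_format": "float16"}

print(check_weight_decode_info(valid))     # True
print(check_weight_decode_info(bad_bits))  # False

try:
    check_weight_decode_info({"target_format": "float16"})
except KeyError as exc:
    # The membership test is recorded, but the next lookup still runs and fails.
    print("KeyError:", exc)
```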
14 changes: 13 additions & 1 deletion bitblas/ops/base_scheduler.py
@@ -1,12 +1,24 @@
 from tvm import IRModule
 from tvm.tir import PrimFunc
-from typing import Union
+from typing import Union, Callable
 from dataclasses import dataclass, field
 from tvm.tir.transform import Simplify
 from abc import ABC, abstractmethod
 from bitblas.base.arch import TileDevice


+# Decorator to simplify the output of a function
+def maybe_simplify(self, func: Callable):
+
+    def wrapper(*args, **kwargs):
+        stmt: Union[PrimFunc, IRModule] = (func)(*args, **kwargs)
+        if self._enable_simplify:
+            return self.Simplify(stmt)
+        return stmt
+
+    return wrapper
+
+
 @dataclass
 class BaseScheduler(ABC):
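Review note: as written in this hunk, `maybe_simplify` takes `self` as its first parameter, so it cannot be applied with plain `@maybe_simplify` syntax on a method; how it is meant to be called is not visible in the lines shown. A common alternative is to take `self` from the call instead. The following is a sketch of that variant, not the PR's implementation: `ToyScheduler` and its whitespace-collapsing `Simplify` are stand-ins made up for this note so the example runs without TVM.

```python
from functools import wraps
from typing import Callable


def maybe_simplify(func: Callable) -> Callable:
    # Variant for illustration: `self` comes from the method call, not the decorator.
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        stmt = func(self, *args, **kwargs)
        if self._enable_simplify:
            return self.Simplify(stmt)
        return stmt

    return wrapper


class ToyScheduler:
    # Hypothetical stand-in for a scheduler class.
    def __init__(self, enable_simplify: bool = True):
        self._enable_simplify = enable_simplify

    @staticmethod
    def Simplify(stmt: str) -> str:
        # Stand-in for a real simplification pass: collapse repeated whitespace.
        return " ".join(stmt.split())

    @maybe_simplify
    def build(self) -> str:
        return "C  =  A   @ B"


print(ToyScheduler(True).build())   # simplified output
print(ToyScheduler(False).build())  # output returned unchanged
```

Reading the flag at call time means toggling `_enable_simplify` on an existing instance changes the behavior of later calls.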

22 changes: 21 additions & 1 deletion bitblas/ops/general_matmul/__init__.py
Expand Up @@ -12,6 +12,7 @@
 from .tirscript.matmul_dequantize_impl import select_implementation as weight_dequantize_implementation
 from .tirscript.matmul_impl import select_implementation as consistent_implementation
 from .tilelang.dense import select_scheduler as consistent_scheduler
+from .tilelang.dequantize import select_scheduler as weight_dequantize_scheduler
 from ...base.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2
 from bitblas.utils.target_detector import auto_detect_nvidia_target
 from dataclasses import dataclass
Expand Down Expand Up @@ -591,7 +592,26 @@ def _select_scheduler(self):
                 propagate_b=self.propagate_b,
             )
         else:
-            raise ValueError("Currently only support native compute for scheduler")
+            return weight_dequantize_scheduler(
+                M=self.M,
+                N=self.N,
+                K=self.K,
+                in_dtype=self.A_dtype,
+                out_dtype=self.out_dtype,
+                accum_dtype=self.accum_dtype,
+                bit=self.bit,
+                storage_dtype=self.storage_dtype,
+                source_format=self.source_format,
+                with_scaling=self.with_scaling,
+                with_zeros=self.with_zeros,
+                group_size=self.group_size,
+                fast_decoding=self.fast_decoding,
+                with_bias=self.with_bias,
+                layout=self.layout,
+                zeros_mode=self.zeros_mode,
+                propagate_a=self.propagate_a,
+                propagate_b=self.propagate_b,
+            )

     def post_process(self, code: str) -> str:
         code = tensor_replace_dp4a(code)