Commit 2b39d03

tune with dynamic range

LeiWang1999 committed Feb 3, 2024
1 parent 1132fce commit 2b39d03
Showing 11 changed files with 1,031 additions and 611 deletions.
1 change: 1 addition & 0 deletions python/tvm/dlight/__init__.py
@@ -17,6 +17,7 @@
"""DLight package provides efficient schedules out-of-box for deep learning workloads."""
from . import gpu
from .base import (
fast_tune,
ApplyDefaultSchedule,
ApplyFastTuning,
BlockInfo,
1 change: 1 addition & 0 deletions python/tvm/dlight/base/__init__.py
@@ -27,3 +27,4 @@
from .common_schedules import get_block, get_output_blocks, try_inline, try_inline_contiguous_spatial
from .schedule_rule import ScheduleRule
from .transform import ApplyDefaultSchedule, ApplyFastTuning
from .utils import fast_tune
32 changes: 5 additions & 27 deletions python/tvm/dlight/base/roller/node.py
@@ -113,32 +113,6 @@ def get_tag(self, k: str) -> Any:
if k not in self._tag:
return None
return self._tag[k]
class BufferNode(Node):
"""BufferNode is a wrapper of tir.Buffer, which is used to store the buffer information."""

def __init__(self, buffer: tir.Buffer, tags: Dict = {}) -> None:
super().__init__()
self.buffer = buffer
self._tag: Dict = {}
for tag in tags:
self.add_tag(tag, tags[tag])
self.set_dtype(tvm.DataType(self.buffer.dtype))

def set_dtype(self, dtype: tvm.DataType, id=0) -> None:
assert isinstance(dtype, tvm.DataType), type(dtype)
if dtype == tvm.DataType("bool"):
dtype = tvm.DataType("int8")
if len(self._dtypes) <= id:
self._dtypes.extend([None for _ in range(id - len(self._dtypes) + 1)])
elif self._dtypes[id] is not None:
assert self._dtypes[id] == dtype, (self._dtypes, dtype)
self._dtypes[id] = dtype

def get_dtype(self, id=0) -> tvm.DataType:
return self._dtypes[id]

def get_buffer_dtype(self, buffer: tir.Buffer) -> tvm.DataType:
return tvm.DataType(buffer.dtype)


class PrimFuncNode(Node):
@@ -161,6 +135,9 @@ def __init__(self, prim_func: PrimFunc, tags: Dict = {}) -> None:

def _specialize_func(self, func: PrimFunc):
# Specialize the function to make it more friendly for analysis.
# set attrs
for k, v in func.attrs.items():
self.set_tag(k, v)
opt_shapes = self.get_tag("opt_shapes")
if opt_shapes:
for name, shape in opt_shapes.items():
@@ -277,7 +254,8 @@ def propogate_inputs(self, tile, rstep={}) -> List[List[int]]:
continue
# should not exceed original shape
trimmed_shape = [
self.extent_warpper(i) for i in list(map(min, zip(shapes[arg.name], self.input_buffers[i].shape)))
self.extent_warpper(i)
for i in list(map(min, zip(shapes[arg.name], self.input_buffers[i].shape)))
]
results.append(trimmed_shape)
return results
2 changes: 1 addition & 1 deletion python/tvm/dlight/base/transform.py
@@ -34,7 +34,7 @@
from .roller.policy import DefaultPolicy, TensorCorePolicy
from .roller.arch import CUDA
from .schedule_rule import ScheduleRule
from .analysis import get_tensorized_func_and_tags
from ..gpu.matmul_analysis import get_tensorized_func_and_tags
from .utils import apply_and_build


252 changes: 230 additions & 22 deletions python/tvm/dlight/base/utils.py
@@ -19,24 +19,31 @@
from tvm.contrib.popen_pool import PopenPoolExecutor, StatusKind, MapResult
from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
from typing import List, Tuple, Optional
from typing import List, Tuple, Optional, Dict
from tvm import tir, IRModule
from tvm.runtime import Module
from tvm.tir import Schedule
from tvm import dlight as dl
from .analysis import get_root_block, get_reduction_blocks
from .roller.arch import Arch
from tvm.dlight.base.roller.arch import CUDA
from tvm.dlight.base.roller.policy import TensorCorePolicy, DefaultPolicy
from tvm.dlight.gpu.matmul_analysis import get_tensorized_func_and_tags
from ..base.roller.rasterization import NoRasterization
import tempfile
import re
import itertools
from tvm.ir.supply import GlobalVarSupply


def match_global_kernel(source: str) -> int:
pattern = r"__global__\s+void\s+[__launch_bounds__\(\d+\)\s+]\w+"
matched = re.findall(pattern, source)
assert len(matched) > 1 # may have statement before kernel
assert len(matched) > 1 # may have statement before kernel
return source.index(matched[0])

def get_rasterization_code(pannel_width:int = 8) -> str:

def get_rasterization_code(pannel_width: int = 8) -> str:
return f"""
const int MAX_BLOCK_N = {pannel_width};
const auto baseBlockIdx = blockIdx.x + gridDim.x *blockIdx.y;
@@ -49,9 +56,8 @@ def get_rasterization_code(pannel_width: int = 8) -> str:
const auto bz = blockIdx.z;
const dim3 blockIdx(bx, by, bz);
"""

...
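
A minimal sketch of how the snippet above is meant to be spliced into generated CUDA source, mirroring the (currently commented-out) postproc callback further down in this file. The helper name insert_rasterization and the panel width are illustrative only:

    def insert_rasterization(code: str, panel_width: int = 8) -> str:
        # inject the rasterization preamble right after the opening brace of
        # the first __global__ kernel found in the generated source
        index = code.index("{", match_global_kernel(code))
        return code[: index + 2] + get_rasterization_code(panel_width) + code[index + 2 :]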



class CompileResult:
"""
Class to store the result of compilation
@@ -123,15 +129,19 @@ def apply_and_build_parallel(func, configs, arch, num_repeats=5, max_workers=10)

def var_warpper(v):
if isinstance(v, tvm.tir.Var):
assert v.name in config.opt_shapes
return config.opt_shapes[v.name]
assert "opt_shapes" in func.attrs
assert v.name in func.attrs["opt_shapes"]
return func.attrs["opt_shapes"][v.name].value
elif isinstance(v, tvm.tir.IntImm):
return v.value
else:
raise RuntimeError("Not supported type: ", type(v))

profile_tensors = []
for param in func.params:
if param not in func.buffer_map:
# dynamic symbolic vars may appear in params without a buffer binding
continue
arg = func.buffer_map[param]
if arg.dtype == "int8":
profile_tensors.append(
@@ -166,21 +176,20 @@ def var_warpper(v):
# build in process parallel
def _build(context) -> str:
idx, mod, arch = context
config = configs[idx]

# TODO(lei):
# this is a trick to implement rasterization; it will be removed in the future
@tvm.register_func(func_name="tvm_callback_cuda_postproc", override=True)
def tvm_callback_cuda_postproc(code, _):
index = code.index("{", match_global_kernel(code))
if not isinstance(config.rasterization_plan, NoRasterization):
factor = config.rasterization_plan.panel_width_
rasterization_code = get_rasterization_code(factor)
code = code[: index + 2] + rasterization_code + code[index + 2 :]
return code

with tvm.transform.PassContext(
config={"tir.use_async_copy": True, "tir.merge_static_smem": False}
):
# config = configs[idx]
# @tvm.register_func(func_name="tvm_callback_cuda_postproc", override=True)
# def tvm_callback_cuda_postproc(code, _):
# index = code.index("{", match_global_kernel(code))
# if not isinstance(config.rasterization_plan, NoRasterization):
# factor = config.rasterization_plan.panel_width_
# rasterization_code = get_rasterization_code(factor)
# code = code[: index + 2] + rasterization_code + code[index + 2 :]
# return code

with tvm.transform.PassContext(config={"tir.use_async_copy": True}):
rt_mod = tvm.build(mod["main"], target=arch.target)

from tvm.contrib.tar import tar # pylint: disable=import-outside-toplevel
@@ -197,7 +206,8 @@ def tvm_callback_cuda_postproc(code, _):
if map_result.status == StatusKind.TIMEOUT:
print("[FastDlight] LocalBuilder: Timeout")
elif map_result.status == StatusKind.EXCEPTION:
print("[FastDlight] LocalBuilder: An exception occurred ", map_result.value)
# TODO(lei): redirect the exception to file if needed
print("[FastDlight] LocalBuilder: An exception occurred ")
continue
elif map_result.status == StatusKind.COMPLETE:
idx, code, artifact_path = map_result.value
@@ -247,3 +257,201 @@ def apply_and_build(
) -> Tuple[List[CompileResult], CompileResult]:
max_workers = 10 if parallel_build else 1
return apply_and_build_parallel(func, configs, arch, max_workers)


def fast_tune(
func: tir.PrimFunc,
target: tvm.target.Target,
topk: int = 10,
parallel_build: bool = True,
):
if target.kind.name != "cuda":
print("[FastDlight] Only support CUDA target")
return func
if "opt_shapes" in func.attrs:
# the opt_shapes values should be concrete ints
if not all([isinstance(v.value, int) for v in func.attrs["opt_shapes"].values()]):
print("[FastDlight] The opt_shapes values should be ints")
return func

arch = CUDA(target)

policy = DefaultPolicy(func=func, arch=arch)
try:
func, tags = get_tensorized_func_and_tags(func, arch.target)
except:
tags = None
if tags:
policy = TensorCorePolicy(func=func, arch=arch, tags=tags)

configs = policy.emit_config(topk)
cpresults, best = apply_and_build(func, configs, arch, parallel_build=parallel_build)

return cpresults, best
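
A minimal usage sketch of fast_tune. It assumes a CUDA GPU and a TVM build with TVMScript; the matmul kernel, the 1024 dimensions, and the pinned shape m=64 are illustrative values, not part of this commit:

    import tvm
    from tvm.script import tir as T
    from tvm.dlight import fast_tune


    @T.prim_func
    def matmul(a: T.handle, b: T.handle, c: T.handle):
        m = T.int32()
        A = T.match_buffer(a, [m, 1024], dtype="float16")
        B = T.match_buffer(b, [1024, 1024], dtype="float16")
        C = T.match_buffer(c, [m, 1024], dtype="float16")
        for i, j, k in T.grid(m, 1024, 1024):
            with T.block("matmul"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[vi, vj] = T.float16(0)
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]


    # pin the dynamic dimension m to a concrete value via opt_shapes,
    # which fast_tune expects to be plain ints (see the check above)
    func = matmul.with_attr("opt_shapes", {"m": 64})
    cpresults, best = fast_tune(func, tvm.target.Target("cuda"), topk=10)
    print(best.sch.mod["main"])  # the best tuned schedule found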


# always use the first function as the base
def collect_buffers_to_declare(func):
params = []
# collect dynamic symbolic
dyn_symbolic: List[tvm.tir.Var] = []
buffers_to_declare = []
for param in func.params:
if param not in func.buffer_map:
continue
buffer = func.buffer_map[param]
for axis in buffer.shape:
if isinstance(axis, tvm.tir.Var) and axis not in dyn_symbolic:
dyn_symbolic.append(axis)
buffers_to_declare.append(buffer)
params.append(buffer.data)

# the args should be buffers + dynamic symbolic
params += list(dyn_symbolic)

return params, buffers_to_declare
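
Continuing the matmul sketch above (illustrative only), collect_buffers_to_declare returns the buffer data pointers followed by the dynamic symbolic vars, plus the buffers the specialized device function has to declare; the import path simply follows this file's location:

    from tvm.dlight.base.utils import collect_buffers_to_declare

    params, buffers_to_declare = collect_buffers_to_declare(matmul)
    # params             -> [A.data, B.data, C.data, m]
    # buffers_to_declare -> [A, B, C]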


def refactor_specialized_func(func, params, buffers_to_declare):
body = func.body
attrs = func.attrs
global_symbol = func.attrs["global_symbol"]
if "opt_shapes" in func.attrs:
opt_shapes = func.attrs["opt_shapes"]

def serialize_name(opt_shapes: Dict):
return "_opt_" + "_".join([f"{k}_{v}" for k, v in opt_shapes.items()])

global_symbol += serialize_name(opt_shapes)
ret_type = func.ret_type
for buf in buffers_to_declare:
body = tvm.tir.DeclBuffer(buf, body=body)

device_func = tvm.tir.PrimFunc(params, body, ret_type, attrs=attrs).without_attr(
"global_symbol"
)
return global_symbol, device_func
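
The suffix attached to global_symbol here is just a serialization of opt_shapes. A standalone illustration (the base name matmul is hypothetical):

    def serialize_name(opt_shapes: dict) -> str:
        return "_opt_" + "_".join([f"{k}_{v}" for k, v in opt_shapes.items()])

    print("matmul" + serialize_name({"m": 32}))  # matmul_opt_m_32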


def create_dispatch_func(func: tir.PrimFunc, refactored_funcs: List[str]):
global_symbol = func.attrs["global_symbol"]
attrs = func.attrs
buffer_map = func.buffer_map
params = func.params
ret_type = func.ret_type

# collect dynamic symbolic
dyn_symbolic: List[tvm.tir.Var] = []
_invoke_params = []
for param in func.params:
if param not in func.buffer_map:
continue
buffer = func.buffer_map[param]
for axis in buffer.shape:
if isinstance(axis, tvm.tir.Var) and axis not in dyn_symbolic:
dyn_symbolic.append(axis)
_invoke_params.append(buffer.data)
_invoke_params += list(dyn_symbolic)

func_range: List[int] = []
global_symbols = []
for g_var, refactor_func in refactored_funcs:
opt_shapes = refactor_func.attrs["opt_shapes"]
func_range.append(list(opt_shapes.values())[0])
global_symbols.append(g_var)

# TODO(lei): generalize the dispatch function to support multiple dynamic symbolic vars
assert len(dyn_symbolic) == 1, "Only one dynamic symbolic var is supported currently"

ib = tvm.tir.ir_builder.create()
syb = list(dyn_symbolic)[-1]
last_range = 0
for i, (_range, g_var) in enumerate(zip(func_range, global_symbols)):
if i == 0:
with ib.if_scope(syb <= _range):
ib.emit(tvm.tir.Call(None, g_var, _invoke_params))
else:
with ib.if_scope(tvm.tir.all(syb > last_range, syb <= _range)):
ib.emit(tvm.tir.Call(None, g_var, _invoke_params))
last_range = _range
with ib.if_scope(syb > last_range):
ib.emit(tvm.tir.Call(None, g_var, _invoke_params))
stmt = ib.get()
dispatch_func = tvm.tir.PrimFunc(params, stmt, ret_type, buffer_map, attrs).with_attrs(
{"tir.is_global_func": True, "global_symbol": global_symbol}
)
return dispatch_func
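
For concreteness: with a single dynamic var m and func_range = [32, 64], the ir_builder loop above emits a guard ladder roughly equivalent to the following pseudo-TIR (kernel names hypothetical):

    if m <= 32:             call matmul_opt_m_32(args)
    if 32 < m and m <= 64:  call matmul_opt_m_64(args)
    if m > 64:              call matmul_opt_m_64(args)  # shapes beyond the largest bucket reuse the last specialized kernel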


def create_dispatch_mod(
original_func: tir.PrimFunc, specialized_funcs: List[tir.PrimFunc]
) -> IRModule:
dispatch_mod: IRModule = tvm.IRModule()
g_var_supply = GlobalVarSupply(dispatch_mod)
refactored_funcs = []
for func in specialized_funcs:
params, buffers_to_declare = collect_buffers_to_declare(func)
global_symbol, device_func = refactor_specialized_func(func, params, buffers_to_declare)
global_symbol = g_var_supply.fresh_global(global_symbol, add_prefix=False)
dispatch_mod[global_symbol] = device_func
refactored_funcs.append((global_symbol, device_func))
dispatch_func = create_dispatch_func(original_func, refactored_funcs=refactored_funcs)
print(dispatch_func)
dispatch_mod.update(tvm.IRModule.from_expr(dispatch_func))
return dispatch_mod


def fast_tune_with_dynamic_range(
func: tir.PrimFunc, target: tvm.target.Target, topk: int = 10, parallel_build: bool = True
) -> IRModule:
if target.kind.name != "cuda":
print("[FastDlight] Only support CUDA target")
return func

if "opt_shapes" not in func.attrs:
print("[FastDlight] The primfunc has no opt_shapes, please set opt_shapes for the primfunc")
return func
else:
# the opt_shapes values should be lists of candidate shapes
if not all([isinstance(v, tvm.ir.Array) for v in func.attrs["opt_shapes"].values()]):
print("[FastDlight] The opt_shapes values should be lists")
return func

print("[FastDlight] Start fast tuning with dynamic range")
opt_shapes = func.attrs["opt_shapes"]

# Step 1. Calculate the Cartesian product using itertools.product
product_list = list(itertools.product(*(opt_shapes[key] for key in opt_shapes)))

# Convert the Cartesian product to a list of dictionaries
specialize_items: List[Dict] = [dict(zip(opt_shapes.keys(), values)) for values in product_list]

specilized_tuned_funcs: List[tir.PrimFunc] = []
for item in specialize_items:
func = func.with_attr("opt_shapes", item)
_, best = fast_tune(func, target, topk, parallel_build)
specilized_tuned_funcs.append(best.sch.mod["main"])

return create_dispatch_mod(func, specilized_tuned_funcs)
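
A standalone sketch of the specialization step (Step 1) above: with illustrative candidate shapes, the Cartesian product enumerates every combination, and each combination becomes one specialized, tuned kernel:

    import itertools

    opt_shapes = {"m": [32, 64], "n": [128]}
    product_list = list(itertools.product(*(opt_shapes[key] for key in opt_shapes)))
    specialize_items = [dict(zip(opt_shapes.keys(), values)) for values in product_list]
    print(specialize_items)  # [{'m': 32, 'n': 128}, {'m': 64, 'n': 128}]

And a hedged end-to-end sketch reusing the matmul PrimFunc from the fast_tune example above, again assuming a CUDA GPU. Note that the dispatch path currently handles a single dynamic var (see the assert in create_dispatch_func), and that fast_tune_with_dynamic_range is imported from tvm.dlight.base.utils since only fast_tune is re-exported by this commit:

    from tvm.dlight.base.utils import fast_tune_with_dynamic_range

    dyn_func = matmul.with_attr("opt_shapes", {"m": [32, 64]})
    dispatch_mod = fast_tune_with_dynamic_range(dyn_func, tvm.target.Target("cuda"), topk=10)
    # dispatch_mod holds one tuned kernel per candidate m plus a dispatcher
    # that branches on the runtime value of m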