Merge pull request #1 from TileLang/cy_merge

Merge code into Tilelang

chengyupku authored Oct 4, 2024
2 parents 6e87cff + 536d1e8 commit ef0837f
Showing 54 changed files with 6,231 additions and 1,104 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -277,3 +277,6 @@ gallery/how_to/work_with_microtvm/micro_tvmc.py

# GDB history file
.gdb_history

*.ptx
*.ncu-rep
12 changes: 12 additions & 0 deletions .gitmodules
@@ -31,3 +31,15 @@
[submodule "3rdparty/flashinfer"]
path = 3rdparty/flashinfer
url = https://github.com/flashinfer-ai/flashinfer.git
[submodule "3rdparty/fa3"]
path = 3rdparty/fa3
url = git@github.com:Dao-AILab/flash-attention.git
[submodule "3rdparty/flash-linear-attention"]
path = 3rdparty/flash-linear-attention
url = git@github.com:sustcsonglin/flash-linear-attention.git
[submodule "3rdparty/mamba"]
path = 3rdparty/mamba
url = git@github.com:state-spaces/mamba.git
[submodule "cutlass"]
path = cutlass
url = git@github.com:TileLang/cutlass.git
1 change: 1 addition & 0 deletions 3rdparty/fa3
Submodule fa3 added at 74b076
1 change: 1 addition & 0 deletions 3rdparty/flash-linear-attention
Submodule flash-linear-attention added at 33b89d
1 change: 1 addition & 0 deletions 3rdparty/mamba
Submodule mamba added at 62db60
1 change: 1 addition & 0 deletions cutlass
Submodule cutlass added at a2954a
8 changes: 8 additions & 0 deletions include/tvm/tir/transform.h
@@ -315,6 +315,14 @@ TVM_DLL Pass LowerDeviceKernelLaunch();
*/
TVM_DLL Pass SkipAssert();

/*!
* \brief Insert partial sync
*
* \param storage_scope The storage scope considered.
* \return The pass.
*/
TVM_DLL Pass ThreadPartialSync(String storage_scope);

/*!
* \brief Insert sync between parallel read/write of shared buffers.
*
5 changes: 4 additions & 1 deletion python/tvm/contrib/nvcc.py
@@ -29,7 +29,7 @@
from . import utils


def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target=None):
def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target=None, get_output=False):
"""Compile cuda code with NVCC from env.
Parameters
@@ -121,6 +121,9 @@ def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target

    (out, _) = proc.communicate()

    if get_output:
        print(py_str(out))

    if proc.returncode != 0:
        msg = code
        msg += "\nCompilation error:\n"
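The new get_output flag simply echoes the raw nvcc/ptxas log (useful together with --ptxas-options=--verbose, which engine.py below passes, to read off register usage). A minimal sketch of a call; the kernel string and the sm_80 architecture flag are placeholders, not taken from this commit:

import tvm
from tvm.contrib import nvcc

# Hypothetical CUDA source; any valid kernel string works here.
cuda_src = r"""
extern "C" __global__ void add_one(float* x) { x[threadIdx.x] += 1.0f; }
"""

# get_output=True prints the compiler's stdout (register counts, warnings)
# before the compiled binary is returned.
cubin = nvcc.compile_cuda(
    cuda_src,
    target_format="cubin",
    arch=["-arch=sm_80"],                   # assumed target architecture
    options=["--ptxas-options=--verbose"],  # ask ptxas to report register usage
    get_output=True,
)
print(len(cubin), "bytes of compiled code")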
16 changes: 16 additions & 0 deletions python/tvm/tir/transform/transform.py
@@ -559,6 +559,22 @@ def SkipAssert():
    return _ffi_api.SkipAssert()  # type: ignore


def ThreadPartialSync(storage_scope: str):
    """Insert partial sync.

    Parameters
    ----------
    storage_scope: str
        The target storage scope.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    return _ffi_api.ThreadPartialSync(storage_scope)  # type: ignore


def ThreadSync(storage_scope: str):
    """Insert sync between parallel read/write of shared buffers.
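Like ThreadSync, the new pass is applied directly to an IRModule. A small sketch of where it could sit relative to the existing sync pass, mirroring the "shared.dyn" call added to tl/engine.py further down; the mod argument is assumed to be an already-lowered module:

import tvm
from tvm import tir

def insert_sync_passes(mod: tvm.IRModule) -> tvm.IRModule:
    # Partial sync barriers for dynamic shared memory first,
    # then the usual full barriers for shared buffers.
    mod = tir.transform.ThreadPartialSync("shared.dyn")(mod)
    mod = tir.transform.ThreadSync("shared")(mod)
    return mod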
47 changes: 37 additions & 10 deletions python/tvm/tl/autotuner.py
@@ -20,9 +20,14 @@
from tvm import tl
import inspect
from functools import wraps
from typing import Any, Callable, List, Literal
from typing import Any, Callable, List, Any, Literal
import inspect
import multiprocessing
from tqdm import tqdm
import logging

logging.basicConfig(filename='out.log', filemode='w', level=logging.INFO,
format='%(asctime)s %(levelname)s:%(message)s')
class Autotuner:
def __init__(
self,
@@ -47,21 +52,42 @@ def run(self, *args: Any, **kwds: Any) -> Any:
# print(f"{name} = {value}")
best_latency = 1e8
best_config = None
for config in tqdm(self.configs, desc="Auto-tuning progress"):
tqdm.write(f"Current config: {config}")

def target_fn(pipe, *new_args, **kwds):
try:
latency, ref_latency = self.fn(*new_args, **kwds)
pipe.send((latency, ref_latency))
except Exception as e:
logging.error(f"Fail on config {new_args} with error: {e}")
pipe.send((1e8, None))

progress_bar = tqdm(self.configs, desc="Running configurations")
for config in progress_bar:
new_args = []
for name, value in bound_args.arguments.items():
if name not in self.keys:
new_args.append(value)
else:
new_args.append(config[name])
new_args = tuple(new_args)
# print("auto-tunner new_args:", new_args)
try:
latency, ref_latency = self.fn(*new_args, **kwds)
except Exception as e:
print("Fail on config ", config, " with error: ", e)

parent_pipe, child_pipe = multiprocessing.Pipe()

p = multiprocessing.Process(target=target_fn, args=(child_pipe, *new_args), kwargs=kwds)
p.start()

p.join(40)
if p.is_alive():
logging.error(f"Killing config {config} due to timeout.")
p.terminate()
p.join()
latency = 1e8
else:
latency, ref_latency = parent_pipe.recv()
logging.info(f"Config {config} latency: {latency}")

progress_bar.set_postfix({"best_latency": best_latency})

if latency < best_latency:
best_latency = latency
best_config = config
@@ -83,6 +109,7 @@ def jit(
out_idx: List[int],
supply_type: tl.TensorSupplyType = tl.TensorSupplyType.Normal,
ref_prog: Callable = None,
check_close: bool = True,
rtol: float = 1e-5,
atol: float = 1e-5,
skip_check: bool = False,
@@ -104,9 +131,9 @@ def decorator(*args, **kwargs) -> float:
if (not skip_check) and (ref_prog is not None):
mod.assert_allclose(ref_prog, rtol=rtol, atol=atol)

latency = mod.do_bench(mod.func, warmup = 25, profiler = profiler)
latency = mod.do_bench(mod.func, n_warmup=10, n_repeat=10, profiler=profiler)
if ref_latency_cache is None and ref_prog is not None:
ref_latency_cache = mod.do_bench(ref_prog, warmup = 25)
ref_latency_cache = mod.do_bench(ref_prog, n_warmup=10, n_repeat=10, profiler="torch")
return latency, ref_latency_cache
return decorator
return wrapper
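Each configuration is now benchmarked in a separate process, so a hanging or crashing candidate only costs that one trial. The 40-second join timeout and the 1e8 sentinel latency used above reduce to this pattern (standalone sketch; measure stands in for the tuned kernel):

import multiprocessing

SENTINEL = 1e8  # latency reported for failed or timed-out configs

def measure(x):
    # Stand-in for the real benchmark; replace with the tuned function.
    return x * x, None

def _worker(pipe, *args):
    try:
        pipe.send(measure(*args))
    except Exception:
        pipe.send((SENTINEL, None))

def run_with_timeout(args, timeout=40):
    parent, child = multiprocessing.Pipe()
    p = multiprocessing.Process(target=_worker, args=(child, *args))
    p.start()
    p.join(timeout)
    if p.is_alive():  # candidate hung: kill it and report the sentinel
        p.terminate()
        p.join()
        return SENTINEL, None
    return parent.recv()

if __name__ == "__main__":
    print(run_with_timeout((3,)))  # -> (9, None)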
18 changes: 15 additions & 3 deletions python/tvm/tl/engine.py
@@ -21,7 +21,11 @@
import tvm
from tvm import tir, tl, relay
from tvm.contrib import nvcc

try:
    from tvm.tl.code_replace import replace_code
except ImportError:
    def replace_code(code):
        return code

def is_device_call(func: tir.PrimFunc):
return bool(func.attrs and "calling_conv" in func.attrs and func.attrs["calling_conv"] == 2)
@@ -47,18 +51,20 @@ def tvm_callback_cuda_compile(code, target):
format = "cubin"
else:
arch = [f"-arch=sm_{compute_version}"]
format = "ptx"
format = "cubin"

ptx = nvcc.compile_cuda(
code,
format,
arch,
options=[
"-std=c++17",
"--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage", # printing out number of registers
"--use_fast_math",
"-I" + tl_template_path,
"-I" + cutlass_path,
],
get_output=True,
)
# with open("save.ptx", "wb") as f:
# f.write(ptx)
@@ -86,7 +92,12 @@ def lower(func, target="cuda", runtime_only=False):
mod = tir.transform.Simplify()(mod)

if target.arch == "sm_90":
mod = tl.transform.WarpSpecializedPipeline()(mod)
mod = tl.transform.MultiVersionBuffer()(mod)
mod = tl.transform.WarpSpecialized()(mod)
mod = tl.transform.InjectSoftwarePipeline()(mod)
mod = tir.transform.LowerOpaqueBlock()(mod)
# mod = tl.transform.WarpSpecializedPipeline()(mod)
mod = tl.transform.InjectFenceProxy()(mod)
else:
mod = tir.transform.PlanAndUpdateBufferAllocationLocation()(mod)
mod = tl.transform.PipelinePlanning()(mod)
@@ -117,6 +128,7 @@ def lower(func, target="cuda", runtime_only=False):
# We can find a way better to create var instead
# of putting the LowerThreadAllreduce before
# the Legalization.
mod = tir.transform.ThreadPartialSync("shared.dyn")(mod)
mod = tir.transform.LowerThreadAllreduce()(mod)
mod = tl.transform.LowerHopperIntrin()(mod)
mod = tir.transform.InjectPTXAsyncCopy()(mod)
24 changes: 22 additions & 2 deletions python/tvm/tl/language.py
@@ -43,7 +43,15 @@ def Parallel(*extents: tir.PrimExpr):
return _ffi_api.Parallel(extents) # type: ignore[attr-defined] # pylint: disable=no-member


def Pipelined(start: tir.PrimExpr, stop: tir.PrimExpr = None, num_stages: int = 0):
def Pipelined(
start: tir.PrimExpr,
stop: tir.PrimExpr = None,
num_stages: int = 0,
order: List[int] = None,
stage: List[int] = None,
sync: List[List[int]] = None,
group: List[List[int]] = None
):
"""Tools to construct pipelined for loop.
Parameters
@@ -66,8 +74,16 @@ def Pipelined(start: tir.PrimExpr, stop: tir.PrimExpr = None, num_stages: int =
start = IntImm(start.dtype, 0)
else:
start = 0
if order is None:
order = []
if stage is None:
stage = []
if sync is None:
sync = []
if group is None:
group = []
# type: ignore[attr-defined] # pylint: disable=no-member
return _ffi_api.Pipelined(start, stop, num_stages)
return _ffi_api.Pipelined(start, stop, num_stages, order, stage, sync, group)


@register_object("tl.KernelLaunchFrame")
@@ -314,6 +330,10 @@ def reduce_sum(buffer: tir.Buffer, out: tir.Buffer, dim: int):
    return reduce(buffer, out, "sum", dim, True)


def reduce_abssum(buffer: tir.Buffer, out: tir.Buffer, dim: int):
    return reduce(buffer, out, "abssum", dim, True)


def atomic_add(dst, value):
    return T.call_extern("handle", "atomicAdd", T.address_of(dst), value)

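The extra order, stage, sync, and group arguments let a kernel pin down the software-pipeline schedule instead of leaving it entirely to PipelinePlanning. A hypothetical GEMM-style loop using the new keywords; the surrounding kernel, the buffers, the tile sizes, and the T.copy/T.gemm ops are assumptions for illustration, not taken from this commit:

import tvm.tl.language as T

# Illustrative only: A, B, A_shared, B_shared, C_local, block_M/N/K, bx, by
# are assumed to be defined by the surrounding (hypothetical) kernel.
for k in T.Pipelined(
    K // block_K,
    num_stages=3,
    order=[0, 1, 2],   # issue order of the three statements in the loop body
    stage=[0, 0, 1],   # pipeline stage assigned to each statement
):
    T.copy(A[by * block_M, k * block_K], A_shared)
    T.copy(B[k * block_K, bx * block_N], B_shared)
    T.gemm(A_shared, B_shared, C_local)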
33 changes: 33 additions & 0 deletions python/tvm/tl/transform.py
@@ -106,3 +106,36 @@ def WarpSpecializedPipeline():
        The result pass
    """
    return _ffi_api.WarpSpecializedPipeline()  # type: ignore


def MultiVersionBuffer():
    """MultiVersionBuffer

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    return _ffi_api.MultiVersionBuffer()  # type: ignore


def WarpSpecialized():
    """WarpSpecialized

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    return _ffi_api.WarpSpecialized()  # type: ignore


def InjectFenceProxy():
    """InjectFenceProxy

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    return _ffi_api.InjectFenceProxy()  # type: ignore
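Taken together, these wrappers expose the decomposed Hopper path that engine.py now uses in place of the single WarpSpecializedPipeline pass. A sketch of the sm_90 sequence from the lower() hunk above; the one-line comments are informal glosses, not documentation from this commit:

from tvm import tir, tl

def warp_specialize_sm90(mod):
    # Decomposed replacement for tl.transform.WarpSpecializedPipeline():
    mod = tl.transform.MultiVersionBuffer()(mod)      # multi-version shared buffers
    mod = tl.transform.WarpSpecialized()(mod)         # split producer/consumer warps
    mod = tl.transform.InjectSoftwarePipeline()(mod)  # materialize the pipeline
    mod = tir.transform.LowerOpaqueBlock()(mod)
    mod = tl.transform.InjectFenceProxy()(mod)        # proxy fences for Hopper async ops
    return mod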
17 changes: 15 additions & 2 deletions python/tvm/tl/utils.py
@@ -22,6 +22,8 @@
import torch

import tvm
from torch.utils.dlpack import to_dlpack
from tvm.runtime import ndarray
from tvm.relay import TensorType
from tvm.contrib.dlpack import to_pytorch_func
from torch.utils.dlpack import to_dlpack
@@ -142,7 +144,14 @@ def assert_allclose(self, reference_program: callable, atol: float = 1e-8, rtol:
if isinstance(ref_outs, torch.Tensor):
ref_outs = [ref_outs]
assert len(lib_outs) == len(ref_outs)
# torch.set_printoptions(edgeitems=torch.inf)
for lhs, rhs in zip(lib_outs, ref_outs):
# close_mask = torch.isclose(lhs, rhs, rtol=rtol, atol=atol)
# total_elements = lhs.numel()
# num_not_close = (~close_mask).sum().item()
# percentage_not_close = (num_not_close / total_elements) * 100
# print(f"{percentage_not_close:.2f}% of the elements are not close.")
# print(f"Total elements: {total_elements}, Not close elements: {num_not_close}")
assert torch.allclose(lhs, rhs, rtol=rtol, atol=atol), (lhs, rhs)

def assert_consistent(self, repeat=10):
@@ -155,9 +164,13 @@ def assert_consistent(self, repeat=10):
for lhs, rhs in zip(lib_outs, ref_outs):
assert torch.allclose(lhs, rhs), ["result is not consistent", lhs, rhs]

def run_once(self):
def run_once(self, func=None):
import ctypes
libcuda = ctypes.CDLL("libcuda.so")

ins = self._get_inputs()
return self.__call__(*ins)
if not func:
func = self.__call__

def do_bench(
self,
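The commented-out lines in assert_allclose compute how many elements actually violate the tolerance. Pulled out as a standalone helper, that diagnostic looks roughly like this (a sketch, not part of the commit):

import torch

def mismatch_report(lhs: torch.Tensor, rhs: torch.Tensor,
                    rtol: float = 1e-5, atol: float = 1e-5) -> float:
    """Return the percentage of elements where lhs and rhs are NOT close."""
    close_mask = torch.isclose(lhs, rhs, rtol=rtol, atol=atol)
    total = lhs.numel()
    num_not_close = (~close_mask).sum().item()
    pct = num_not_close / total * 100
    print(f"{pct:.2f}% of {total} elements are not close ({num_not_close} mismatches)")
    return pct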
8 changes: 8 additions & 0 deletions src/tir/transforms/lower_opaque_block.cc
@@ -82,6 +82,14 @@ class OpaqueBlockLower : public StmtExprMutator {
    return body;
  }

  Stmt VisitStmt_(const BlockNode* op) final {
    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
    if (block->annotations.count("stmt_group")) {
      return block->body;
    }
    return block;
  }

  Stmt VisitStmt_(const ForNode* op) final {
    // Step 1. Update unit loop info.
    PrimExpr min = this->VisitExpr(op->min);