[TensorIR][M1c] Lower and build TensorIR #8044

Merged · 4 commits · May 15, 2021
5 changes: 3 additions & 2 deletions include/tvm/tir/analysis.h
@@ -189,9 +189,10 @@ TVM_DLL size_t CalculateWorkspaceBytes(const PrimFunc& func);
* access(BufferLoad, BufferStore) and low-level access(Load, Store and opaque access).
* The LCA may be a For loop or a Block.
* \param func The PrimFunc to be detected.
* \return The Map from buffer to the LCA of all access to it.
* \return The Map from buffer to the LCA of all accesses to it. The LCA is the function root
* if the mapped value is NullOpt.
*/
TVM_DLL Map<Buffer, Stmt> DetectBufferAccessLCA(const PrimFunc& func);
TVM_DLL Map<Buffer, Optional<Stmt>> DetectBufferAccessLCA(const PrimFunc& func);

// Pass variants of verification analysis
// directly throws RuntimeError when verification fails.
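
A minimal sketch of how the new `Optional` return surfaces on the Python side, assuming the registered global `tir.analysis.detect_buffer_access_lca` is fetched via `tvm.get_global_func`; the trivial TE compute is only there to obtain some PrimFunc to probe:

```python
import tvm
from tvm import te

detect_lca = tvm.get_global_func("tir.analysis.detect_buffer_access_lca")

# Any TIR PrimFunc will do; here one comes from lowering a trivial TE compute.
A = te.placeholder((8,), name="A")
B = te.compute((8,), lambda i: A[i] * 2.0, name="B")
func = tvm.lower(te.create_schedule(B.op), [A, B])["main"]

for buffer, lca in detect_lca(func).items():
    if lca is None:
        # NullOpt crosses the FFI as None: the LCA is the function root.
        print(buffer.name, "-> function root")
    else:
        # Otherwise the LCA is a concrete For loop or Block statement.
        print(buffer.name, "->", type(lca).__name__)
```
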
2 changes: 1 addition & 1 deletion python/tvm/contrib/nvcc.py
@@ -194,7 +194,7 @@ def find_libdevice_path(arch):
selected_ver = 0
selected_path = None
cuda_ver = get_cuda_version(cuda_path)
if cuda_ver in (9.0, 9.1, 10.0, 10.1, 10.2, 11.0, 11.1, 11.2):
if cuda_ver in (9.0, 9.1, 10.0, 10.1, 10.2, 11.0, 11.1, 11.2, 11.3):
path = os.path.join(lib_path, "libdevice.10.bc")
else:
for fn in os.listdir(lib_path):
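
For context, a hedged sketch of what this one-line change enables, assuming a local CUDA 11.3 toolkit:

```python
from tvm.contrib import nvcc

# With CUDA 11.3 installed, find_libdevice_path now takes the fast path and
# returns the unified libdevice.10.bc instead of scanning per-version files.
# The arch argument is only consulted on the legacy (pre-9.0) path.
path = nvcc.find_libdevice_path(70)
print(path)  # e.g. <cuda_home>/nvvm/libdevice/libdevice.10.bc
```
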
94 changes: 68 additions & 26 deletions python/tvm/driver/build_module.py
@@ -18,18 +18,24 @@
# pylint: disable=invalid-name
"""The build utils in python.
"""

from typing import Union, Optional, List, Mapping
import warnings

import tvm.tir

from tvm.runtime import ndarray
from tvm.ir import container
from tvm.ir import CallingConv
from tvm.tir import PrimFunc
from tvm.ir.module import IRModule
from tvm.ir.transform import PassContext
from tvm.target import codegen
from tvm.te import tensor
from tvm.te import schedule
from tvm.target import Target
from tvm.tir.buffer import Buffer
from tvm.tir.expr import Var


def get_binds(args, compact=False, binds=None):
@@ -119,32 +125,39 @@ def form_irmodule(sch, args, name, binds):
return tvm.IRModule({name: func})


def lower(sch, args, name="main", binds=None, simple_mode=False):
def lower(
inputs: Union[schedule.Schedule, PrimFunc, IRModule],
args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None,
name: str = "main",
binds: Optional[Mapping[tensor.Tensor, Buffer]] = None,
simple_mode: bool = False,
) -> IRModule:
"""Lowering step before build into target.

Parameters
----------
sch : tvm.te.schedule.Schedule
The schedule to be built
inputs : Union[schedule.Schedule, PrimFunc, IRModule]
The TE schedule or TensorIR PrimFunc/IRModule to be built

args : list of Buffer or Tensor or Var
The argument lists to the function.
args : Optional[List[Union[Buffer, tensor.Tensor, Var]]]
The argument lists to the function for TE schedule.
It should be None if we want to lower TensorIR.

name : str, optional
name : str
The name of result function.

binds : dict of :any:`Tensor` to :any:`Buffer`, optional
binds : Optional[Mapping[tensor.Tensor, Buffer]]
Dictionary that maps the Tensor to Buffer which specifies the data layout
requirement of the function. By default, a new compact buffer is created
for each tensor in the argument.

simple_mode : bool, optional
simple_mode : bool
Whether to only output simple and compact statements; this will skip
LoopPartition, api wrapper generation and Unrolling.

Returns
-------
m : IRModule or Stmt
m : IRModule
The result IRModule
"""
@@ -160,16 +173,38 @@ def lower(sch, args, name="main", binds=None, simple_mode=False):
lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2]

# Phase 0
if isinstance(sch, schedule.Schedule):
mod = form_irmodule(sch, args, name, binds)
pass_list = lower_phase0
is_legacy_te_schedule: bool = False

if isinstance(inputs, schedule.Schedule):
if args is None:
raise ValueError("args must be given for lowering from TE schedule")
mod = form_irmodule(inputs, args, name, binds)
is_legacy_te_schedule = True
elif isinstance(inputs, PrimFunc):
func = inputs.with_attr("global_symbol", name)
if pass_ctx.config.get("tir.noalias", True):
func = func.with_attr("tir.noalias", True)
mod = tvm.IRModule({name: func})
elif isinstance(inputs, IRModule):
mod = inputs
else:
mod = sch
raise TypeError(
f"tvm.lower expected te.Schedule, PrimFunc or IRModule, but got {type(inputs)}"
)

pass_list = lower_phase0
# Phase 1
if is_legacy_te_schedule:
pass_list += [
tvm.tir.transform.InjectPrefetch(),
tvm.tir.transform.StorageFlatten(64, instrument_bound_checkers),
]
pass_list += [
tvm.tir.transform.InjectPrefetch(),
tvm.tir.transform.StorageFlatten(64, instrument_bound_checkers),
tvm.tir.transform.LowerInitBlock(),
tvm.tir.transform.PlanAndUpdateBufferAllocationLocation(),
tvm.tir.transform.ConvertBlocksToOpaque(),
tvm.tir.transform.CompactBufferAllocation(),
tvm.tir.transform.FlattenBuffer(),
tvm.tir.transform.BF16Legalize(),
tvm.tir.transform.NarrowDataType(32),
tvm.tir.transform.Simplify(),
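
To make the new dispatch above concrete, a short sketch of the accepted input kinds; `matmul_primfunc` is a stand-in for any `@tvm.script.tir` PrimFunc (see the tests added below):

```python
import tvm
from tvm import te

# TE schedule: args stay mandatory, as the new ValueError enforces.
A = te.placeholder((16,), name="A")
B = te.compute((16,), lambda i: A[i] + 1.0, name="B")
s = te.create_schedule(B.op)
mod_from_te = tvm.lower(s, [A, B])

# TensorIR: a PrimFunc or IRModule lowers directly, with no args.
# mod_from_func = tvm.lower(matmul_primfunc)
# mod_from_mod = tvm.lower(tvm.IRModule({"main": matmul_primfunc}))
```
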
@@ -297,22 +332,29 @@ def _build_for_device(input_mod, target, target_host):
return mod_host, rt_mod_dev


def build(inputs, args=None, target=None, target_host=None, name="default_function", binds=None):
def build(
inputs: Union[schedule.Schedule, PrimFunc, IRModule, Mapping[str, IRModule]],
args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None,
target: Optional[Union[str, Target]] = None,
target_host: Optional[Union[str, Target]] = None,
name: Optional[str] = "default_function",
binds: Optional[Mapping[tensor.Tensor, Buffer]] = None,
):
"""Build a function with arguments as signature. Code will be generated
for devices coupled with target information.

Parameters
----------
inputs : tvm.te.Schedule, IRModule, or dict of target to IRModule
The schedule to be built
inputs : Union[schedule.Schedule, PrimFunc, IRModule, Mapping[str, IRModule]]
The input to be built

args : list of Buffer or Tensor or Var, optional
args : Optional[List[Union[Buffer, tensor.Tensor, Var]]]
The argument lists to the function.

target : str or :any:`tvm.target.Target`, optional
target : Optional[Union[str, Target]]
The target and option of the compilation.

target_host : str or :any:`tvm.target.Target` optional
target_host : Optional[Union[str, Target]]
Host compilation target, if target is device.
When TVM compiles device specific program such as CUDA,
we also need host(CPU) side code to interact with the driver
@@ -321,10 +363,10 @@ def build(inputs, args=None, target=None, target_host=None, name="default_functi
By default, llvm is used if it is enabled,
otherwise a stackvm interpreter is used.

name : str, optional
name : Optional[str]
The name of result function.

binds : dict, optional
binds : Optional[Mapping[tensor.Tensor, Buffer]]
Dictionary that maps the binding of symbolic buffer to Tensor.
By default, a new buffer is created for each tensor in the argument.
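
As a usage note on the host/device split described above, a hedged example; it assumes a CUDA-enabled TVM build, with `s`, `A`, `B`, `C` being any TE schedule and its tensors:

```python
# Device code is generated for the CUDA target; the llvm target_host
# supplies the CPU-side glue that launches kernels and talks to the driver.
rt_mod = tvm.build(s, [A, B, C], target="cuda", target_host="llvm")
```
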

@@ -375,10 +417,10 @@ def build(inputs, args=None, target=None, target_host=None, name="default_functi
elif isinstance(inputs, (list, tuple, container.Array)):
merged_mod = tvm.IRModule({})
for x in inputs:
merged_mod.update(x)
merged_mod.update(lower(x))
input_mod = merged_mod
elif isinstance(inputs, tvm.IRModule):
input_mod = inputs
elif isinstance(inputs, (tvm.IRModule, PrimFunc)):
input_mod = lower(inputs)
elif not isinstance(inputs, (dict, container.Map)):
raise ValueError(
f"Inputs must be Schedule, IRModule or dict of target to IRModule, "
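
The net effect of the branch changes above is that every non-dict input now funnels through `tvm.lower` before codegen. A sketch, with `matmul_primfunc` again standing in for any TVMScript PrimFunc:

```python
# Building straight from TensorIR; build() now lowers the PrimFunc itself,
# so no TE schedule and no args list are needed. List inputs are likewise
# lowered entry by entry before being merged into one IRModule.
rt_mod = tvm.build(matmul_primfunc, target="llvm")
```
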
4 changes: 2 additions & 2 deletions src/te/schedule/schedule_postproc_to_primfunc.cc
@@ -159,13 +159,13 @@ PrimFunc SchedulePostProcToPrimFunc(Array<ObjectRef> arg_list, Stmt body,
ICHECK(!extern_buffer.count(tensor));

tir::Buffer buffer = CreateBufferFor(tensor);
tir::Var bptr(buffer->name, DataType::Handle());
tir::Var bptr(buffer->name, PrimType(DataType::Handle()));
params.push_back(bptr);
buffer_map.Set(bptr, buffer);
extern_buffer[tensor] = buffer;
} else {
tir::Buffer buffer = Downcast<tir::Buffer>(var);
tir::Var bptr(buffer->name, DataType::Handle());
tir::Var bptr(buffer->name, PrimType(DataType::Handle()));
params.push_back(bptr);
buffer_map.Set(bptr, buffer);
}
14 changes: 9 additions & 5 deletions src/tir/analysis/buffer_access_lca_detector.cc
@@ -36,17 +36,20 @@ namespace tir {
*/
class LCADetector : public StmtExprVisitor {
public:
static Map<Buffer, Stmt> Detect(const PrimFunc& func) {
static Map<Buffer, Optional<Stmt>> Detect(const PrimFunc& func) {
LCADetector detector;
for (const auto& kv : func->buffer_map) {
const Buffer& buffer = kv.second;
detector.buffer_var_map_.emplace(buffer->data.get(), buffer.get());
}

detector(func->body);
// Prepare the return
Map<Buffer, Stmt> buffer_lca;
Map<Buffer, Optional<Stmt>> buffer_lca;
for (const auto& kv : detector.buffer_lca_) {
buffer_lca.Set(GetRef<Buffer>(kv.first), GetRef<Stmt>(kv.second->stmt));
const Buffer& buffer = GetRef<Buffer>(kv.first);
const Optional<Stmt> stmt = kv.second ? GetRef<Optional<Stmt>>(kv.second->stmt) : NullOpt;
buffer_lca.Set(buffer, stmt);
}
return buffer_lca;
}
@@ -131,7 +134,6 @@ }
}

static const ScopeInfo* LowestCommonAncestor(const ScopeInfo* lhs, const ScopeInfo* rhs) {
ICHECK(lhs || rhs);
if (lhs == nullptr) return rhs;
if (rhs == nullptr) return lhs;
while (lhs->parent_scope_info != nullptr && //
@@ -166,7 +168,9 @@ class LCADetector : public StmtExprVisitor {
support::Arena arena_;
};

Map<Buffer, Stmt> DetectBufferAccessLCA(const PrimFunc& func) { return LCADetector::Detect(func); }
Map<Buffer, Optional<Stmt>> DetectBufferAccessLCA(const PrimFunc& func) {
return LCADetector::Detect(func);
}

TVM_REGISTER_GLOBAL("tir.analysis.detect_buffer_access_lca").set_body_typed(DetectBufferAccessLCA);
} // namespace tir
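
For intuition, the merge in `LowestCommonAncestor` is the standard climb-to-equal-depth walk. A Python sketch of the same logic, assuming hypothetical scope nodes with `parent` and `depth` attributes, and with `None` playing the role of the nullptr case that the removed `ICHECK` used to reject:

```python
def lowest_common_ancestor(lhs, rhs):
    # None means "no scope recorded yet"; merging with it keeps the other side.
    if lhs is None:
        return rhs
    if rhs is None:
        return lhs
    # Climb the deeper side until both sit at the same depth...
    while lhs.depth > rhs.depth:
        lhs = lhs.parent
    while rhs.depth > lhs.depth:
        rhs = rhs.parent
    # ...then walk up in lockstep until the scopes coincide.
    while lhs is not rhs:
        lhs, rhs = lhs.parent, rhs.parent
    return lhs  # may be None, i.e. the function root
```
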
4 changes: 2 additions & 2 deletions src/tir/transforms/plan_update_buffer_allocation_location.cc
@@ -32,7 +32,7 @@ namespace tir {
class BufferAllocationLocator : public StmtExprMutator {
public:
explicit BufferAllocationLocator(const PrimFunc& func) {
Map<Buffer, Stmt> buffer_lca = DetectBufferAccessLCA(func);
Map<Buffer, Optional<Stmt>> buffer_lca = DetectBufferAccessLCA(func);
std::unordered_set<const BufferNode*> arg_buffers;
for (const auto& kv : func->buffer_map) {
const Buffer& buffer = kv.second;
@@ -42,7 +42,7 @@ class BufferAllocationLocator : public StmtExprMutator {
// create buffers to be allocated at each stmts
for (const auto& kv : buffer_lca) {
const Buffer& buffer = kv.first;
const StmtNode* stmt = kv.second.get();
const StmtNode* stmt = kv.second.defined() ? kv.second.value().get() : nullptr;
if (arg_buffers.count(buffer.get())) {
continue;
}
117 changes: 117 additions & 0 deletions tests/python/unittest/test_lower_build.py
@@ -0,0 +1,117 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import numpy as np

import tvm
from tvm import te, tir
from tvm.ir.module import IRModule
from tvm.script import ty
import tvm.testing


def _check_module_with_numpy(mod, shape=(128, 128, 128)):
m, n, k = shape
a = tvm.nd.array(np.random.rand(m, k).astype("float32"))
b = tvm.nd.array(np.random.rand(n, k).astype("float32"))
c = tvm.nd.array(np.zeros((m, n), dtype="float32"))
c_np = np.dot(a.asnumpy(), b.asnumpy().transpose())
mod(a, b, c)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)


# pylint: disable=no-self-argument, missing-class-docstring, missing-function-docstring
@tvm.script.tir
def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
A = tir.match_buffer(a, [128, 128])
B = tir.match_buffer(b, [128, 128])
C = tir.match_buffer(c, [128, 128])
for i, j in tir.grid(128, 128):
with tir.block([128, 128], "init") as [vi, vj]:
C[vi, vj] = tir.float32(0)
for k in range(0, 128):
with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]


@tvm.script.tir
class LoweredModule:
def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
# function attr dict
tir.func_attr({"global_symbol": "main", "tir.noalias": True})
A = tir.match_buffer(a, [128, 128])
B = tir.match_buffer(b, [128, 128])
C = tir.match_buffer(c, [128, 128])
# body
for x, y in tir.grid(128, 128):
C.data[x * 128 + y] = 0.0
for k in tir.serial(0, 128):
C.data[x * 128 + y] = tir.load("float32", C.data, x * 128 + y) + tir.load(
"float32", A.data, x * 128 + k
) * tir.load("float32", B.data, y * 128 + k)


def test_lower_build_te_schedule():
m, n, k = 128, 128, 128
axis_k = te.reduce_axis((0, k), "k")
A = te.placeholder((m, k), name="A")
B = te.placeholder((k, n), name="B")
C = te.compute((m, n), lambda x, y: te.sum(A[x, axis_k] * B[y, axis_k], axis=axis_k), name="C")
s = te.create_schedule(C.op)
# check lowering
ir_mod = tvm.lower(s, [A, B, C])
tvm.ir.assert_structural_equal(ir_mod, LoweredModule())
# check building
mod = tvm.build(s, [A, B, C], target="llvm")
_check_module_with_numpy(mod)


def test_lower_build_tir_func():
# check lowering
ir_mod = tvm.lower(matmul)
tvm.ir.assert_structural_equal(ir_mod, LoweredModule())
# check building
mod = tvm.build(matmul, target="llvm")
_check_module_with_numpy(mod)


def test_lower_build_tir_module():
func = matmul.with_attr("global_symbol", "main")
func = func.with_attr("tir.noalias", True)
ir_mod = IRModule({"main": func})
# check lowering
lowered_mod = tvm.lower(ir_mod)
tvm.ir.assert_structural_equal(lowered_mod, LoweredModule())
# check building
mod = tvm.build(ir_mod, target="llvm")
_check_module_with_numpy(mod)


def test_lower_build_lowered_module():
# check lowering
ir_mod = tvm.lower(LoweredModule())
tvm.ir.assert_structural_equal(ir_mod, LoweredModule())
# check building
mod = tvm.build(ir_mod, target="llvm")
_check_module_with_numpy(mod)


if __name__ == "__main__":
test_lower_build_te_schedule()
test_lower_build_tir_func()
test_lower_build_tir_module()
test_lower_build_lowered_module()