From 2b0119ee1dcf3a1f4c65c62a5696d9c82288f672 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Sat, 11 Jun 2022 13:01:32 -0700 Subject: [PATCH] `prim_func` methods (#42) --- python/tvm/script/builder/tir/__init__.py | 2 +- python/tvm/script/builder/tir/op.py | 4 + .../tvm/script/builder/tir/prim_func_frame.py | 71 +++++++++++++++- python/tvm/script/builder/tir/var.py | 2 +- python/tvm/tir/__init__.py | 27 +----- src/script/builder/tir/op.cc | 1 + src/script/builder/tir/op.h | 2 + src/script/builder/tir/prim_func_frame.cc | 85 ++++++++++++++++++- src/script/builder/tir/prim_func_frame.h | 21 +++++ src/script/builder/tir/var.h | 6 +- .../test_builder_basic.py} | 14 ++- 11 files changed, 203 insertions(+), 32 deletions(-) rename tests/python/{unittest/test_tvmscript_builder.py => tvmscript/test_builder_basic.py} (68%) diff --git a/python/tvm/script/builder/tir/__init__.py b/python/tvm/script/builder/tir/__init__.py index 4aa0309b8741..2fdee6a46d10 100644 --- a/python/tvm/script/builder/tir/__init__.py +++ b/python/tvm/script/builder/tir/__init__.py @@ -29,6 +29,6 @@ unroll, vectorized, ) -from .prim_func_frame import arg, prim_func +from .prim_func_frame import arg, func_attr, func_ret, prim_func, match_buffer, preflattened_buffer from .var import Buffer from .op import * diff --git a/python/tvm/script/builder/tir/op.py b/python/tvm/script/builder/tir/op.py index d70f1f0a2920..d75e242ba71a 100644 --- a/python/tvm/script/builder/tir/op.py +++ b/python/tvm/script/builder/tir/op.py @@ -109,6 +109,10 @@ def float64(expr): return _ffi_api.PrimType("float64", expr) +def handle(): + return _ffi_api.Handle() + + def min(a, b, span=None): """Compute the minimum value of two expressions. diff --git a/python/tvm/script/builder/tir/prim_func_frame.py b/python/tvm/script/builder/tir/prim_func_frame.py index 525be3b66c2c..59b5ce251743 100644 --- a/python/tvm/script/builder/tir/prim_func_frame.py +++ b/python/tvm/script/builder/tir/prim_func_frame.py @@ -15,11 +15,12 @@ # specific language governing permissions and limitations # under the License. """TVM Script TIR Prim Func Frame""" -from typing import Union +from typing import Union, Dict, Any from tvm._ffi import register_object as _register_object from tvm.tir.buffer import Buffer from tvm.tir.expr import Var +from tvm.ir import Type from ..builder import Builder from . import _ffi_api @@ -40,3 +41,71 @@ def arg(name, obj) -> Union[Var, Buffer]: setattr(prim_func, "dispatch_token", "tir") + + +def func_attr(attrs: Dict[str, Any]) -> None: + return _ffi_api.FuncAttrs(attrs) # pylint: disable=no-member # type: ignore + + +def func_ret(ret_type) -> Type: + return _ffi_api.FuncRet(ret_type) # pylint: disable=no-member # type: ignore + + +def match_buffer( + param, + shape, + dtype="float32", + data=None, + strides=[], + elem_offset=None, + storage_scope="", + align=-1, + offset_factor=0, + buffer_type="default", + axis_separators=None, + span=None, +) -> Buffer: + return _ffi_api.MatchBuffer( + param, + shape, + dtype, + data, + strides, + elem_offset, + storage_scope, + align, + offset_factor, + buffer_type, + axis_separators, + span, + ) + + +def preflattened_buffer( + postflattened, + shape, + dtype="float32", + data=None, + strides=[], + elem_offset=None, + storage_scope="", + align=-1, + offset_factor=0, + buffer_type="default", + axis_separators=None, + span=None, +) -> None: + _ffi_api.PreflattenedBuffer( + postflattened, + shape, + dtype, + data, + strides, + elem_offset, + storage_scope, + align, + offset_factor, + buffer_type, + axis_separators, + span, + ) diff --git a/python/tvm/script/builder/tir/var.py b/python/tvm/script/builder/tir/var.py index 4c4163cb941a..3e2c48065be7 100644 --- a/python/tvm/script/builder/tir/var.py +++ b/python/tvm/script/builder/tir/var.py @@ -22,7 +22,7 @@ def Buffer( # pylint: disable=invalid-name shape, - dtype, + dtype="float32", name="buffer", storage_scope="", ) -> tir.Buffer: diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 173f4b8c4dbe..6db93b6ad091 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -44,36 +44,17 @@ from .function import PrimFunc, TensorIntrin, IndexMap -from .op import call_packed, call_intrin, call_pure_extern, call_extern -from .op import ( - call_llvm_intrin, - call_llvm_pure_intrin, - ret, - all, - any, - min_value, - max_value, - trace, -) +from .op import call_packed, call_cpacked, call_intrin, call_pure_extern, call_extern +from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz from .op import sin, sinh, asin, asinh from .op import cos, cosh, acos, acosh from .op import tan, tanh, atan, atan2, atanh from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot -from .op import ( - trunc, - abs, - round, - nextafter, - nearbyint, - power, - popcount, - fmod, - if_then_else, -) +from .op import trunc, abs, round, nextafter, nearbyint, power, popcount, fmod, if_then_else from .op import isnan, isfinite, isinf, copysign from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod -from .op import comm_reducer, min, max, sum, infinity, reinterpret +from .op import comm_reducer, min, max, sum from .op import q_multiply_shift from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError diff --git a/src/script/builder/tir/op.cc b/src/script/builder/tir/op.cc index 777ac8a4a407..6bd5a45f7bbf 100644 --- a/src/script/builder/tir/op.cc +++ b/src/script/builder/tir/op.cc @@ -28,6 +28,7 @@ PrimExpr prim_type(String type_name, PrimExpr expr) { } TVM_REGISTER_GLOBAL("script.builder.tir.PrimType").set_body_typed(prim_type); +TVM_REGISTER_GLOBAL("script.builder.tir.Handle").set_body_typed(handle); TVM_REGISTER_GLOBAL("script.builder.tir.min").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { return tvm::min(a, b, span); }); diff --git a/src/script/builder/tir/op.h b/src/script/builder/tir/op.h index 9f7a668f8330..9eb8629f73ab 100644 --- a/src/script/builder/tir/op.h +++ b/src/script/builder/tir/op.h @@ -48,6 +48,8 @@ PrimExpr bool_(PrimExpr expr) { return cast(DataType::Bool(), expr); } PrimExpr prim_type(String type_name, PrimExpr expr); +tvm::tir::Var handle() { return tvm::tir::Var("", DataType::Handle()); } + using tvm::cast; using tvm::if_then_else; using tvm::infinity; diff --git a/src/script/builder/tir/prim_func_frame.cc b/src/script/builder/tir/prim_func_frame.cc index 9d85e2193c56..d4acc151c571 100644 --- a/src/script/builder/tir/prim_func_frame.cc +++ b/src/script/builder/tir/prim_func_frame.cc @@ -21,6 +21,8 @@ #include +#include "./block_frame.h" + namespace tvm { namespace script { namespace builder { @@ -33,7 +35,9 @@ void PrimFuncFrameNode::ExitWithScope() { PrimFunc func(/*params=*/args, /*body=*/AsStmt(stmts), /*ret_type=*/ret_type, - /*buffer_map=*/buffer_map); + /*buffer_map=*/buffer_map, + /*preflattened_buffer_map=*/preflattened_buffer_map, + /*attrs=*/DictAttrs(attrs)); if (builder->frames.empty()) { ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; builder->result = func; @@ -52,6 +56,8 @@ PrimFuncFrame PrimFunc_(String name) { n->args.clear(); n->ret_type = TupleType::Empty(); n->buffer_map.clear(); + n->preflattened_buffer_map.clear(); + n->attrs.clear(); return PrimFuncFrame(n); } @@ -72,6 +78,79 @@ tvm::tir::Buffer Arg(String name, tvm::tir::Buffer buffer) { return buffer; } +void FuncAttrs(Map attrs) { + using namespace tvm::tir; + PrimFuncFrame frame = Builder::Current()->FindFrame().value(); + frame->attrs = attrs; +} + +tvm::Type FuncRet(tvm::Type ret_type) { + PrimFuncFrame frame = Builder::Current()->FindFrame().value(); + frame->ret_type = ret_type; + return ret_type; +} + +tvm::tir::Buffer MatchBuffer(ObjectRef param, Array shape, DataType dtype, + Optional data, Array strides, + PrimExpr elem_offset, String storage_scope, int align, + int offset_factor, String buffer_type_str, + Array axis_separators, Span span) { + using namespace tvm::tir; + Var buffer_data; + if (!data.defined()) { + DataType storage_dtype = dtype; + if (storage_dtype == DataType::Bool()) { + storage_dtype = DataType::Int(8); + } + buffer_data = Var("", PointerType(PrimType(storage_dtype), storage_scope), span); + } else { + buffer_data = data.value(); + } + BufferType buffer_type = (buffer_type_str == "auto_broadcast") ? kAutoBroadcast : kDefault; + Buffer buffer(buffer_data, dtype, shape, strides, elem_offset, "", align, offset_factor, + buffer_type, axis_separators, span); + PrimFuncFrame frame = Builder::Current()->FindFrame().value(); + if (const auto* var = param.as()) { + Var v = GetRef(var); + for (auto const& arg : frame->args) { + if (arg.same_as(v)) { + frame->buffer_map.Set(v, buffer); + return buffer; + } + } + LOG(FATAL) << "ValueError: Can not bind non-input param to buffer."; + } else if (const auto* buffer_region = param.as()) { + BlockFrame block_frame = Builder::Current()->FindFrame().value(); + block_frame->match_buffers.push_back( + MatchBufferRegion(buffer, GetRef(buffer_region))); + } else { + LOG(FATAL) << "ValueError: Unexpected type for TIR MatchBuffer."; + } + return buffer; +}; + +void PreflattenedBuffer(tvm::tir::Buffer postflattened_buffer, Array shape, + DataType dtype, Optional data, Array strides, + PrimExpr elem_offset, String storage_scope, int align, int offset_factor, + String buffer_type_str, Array axis_separators, Span span) { + using namespace tvm::tir; + PrimFuncFrame frame = Builder::Current()->FindFrame().value(); + for (auto const& p : frame->buffer_map) { + if (p.second.same_as(postflattened_buffer)) { + Var buffer_data = (data.defined()) ? data.value() : frame->buffer_map.at(p.first)->data; + String buffer_name(postflattened_buffer->name + "_preflatten"); + BufferType buffer_type = (buffer_type_str == "auto_broadcast") ? kAutoBroadcast : kDefault; + Buffer buffer(buffer_data, dtype, shape, strides, elem_offset, buffer_name, align, + offset_factor, buffer_type, axis_separators, span); + Namer::Name(buffer, buffer_name); + frame->preflattened_buffer_map.Set(p.first, buffer); + return; + } + } + LOG(FATAL) << "ValueError: postflattened buffer " << postflattened_buffer->name + << " does not exist."; +}; + TVM_REGISTER_NODE_TYPE(PrimFuncFrameNode); TVM_REGISTER_GLOBAL("script.builder.tir.PrimFuncFrame").set_body_typed(PrimFunc_); TVM_REGISTER_GLOBAL("script.builder.tir.Arg") @@ -86,6 +165,10 @@ TVM_REGISTER_GLOBAL("script.builder.tir.Arg") LOG(FATAL) << "ValueError: Unexpected type for TIR Arg."; throw; }); +TVM_REGISTER_GLOBAL("script.builder.tir.FuncAttrs").set_body_typed(FuncAttrs); +TVM_REGISTER_GLOBAL("script.builder.tir.FuncRet").set_body_typed(FuncRet); +TVM_REGISTER_GLOBAL("script.builder.tir.MatchBuffer").set_body_typed(MatchBuffer); +TVM_REGISTER_GLOBAL("script.builder.tir.PreflattenedBuffer").set_body_typed(PreflattenedBuffer); } // namespace tir } // namespace builder diff --git a/src/script/builder/tir/prim_func_frame.h b/src/script/builder/tir/prim_func_frame.h index 11a6a564deff..4b03985a0bbd 100644 --- a/src/script/builder/tir/prim_func_frame.h +++ b/src/script/builder/tir/prim_func_frame.h @@ -32,6 +32,8 @@ class PrimFuncFrameNode : public TIRFrameNode { Array args; Type ret_type; Map buffer_map; + Map preflattened_buffer_map; + Map attrs; void VisitAttrs(tvm::AttrVisitor* v) { TIRFrameNode::VisitAttrs(v); @@ -39,6 +41,8 @@ class PrimFuncFrameNode : public TIRFrameNode { v->Visit("args", &args); v->Visit("ret_type", &ret_type); v->Visit("buffer_map", &buffer_map); + v->Visit("preflattened_buffer_map", &preflattened_buffer_map); + v->Visit("attrs", &attrs); } static constexpr const char* _type_key = "script.builder.tir.PrimFuncFrame"; @@ -56,6 +60,23 @@ class PrimFuncFrame : public TIRFrame { PrimFuncFrame PrimFunc_(String name); tvm::tir::Var Arg(String name, tvm::tir::Var var); tvm::tir::Buffer Arg(String name, tvm::tir::Buffer buffer); +void FuncAttrs(Map attrs); +tvm::Type FuncRet(tvm::Type ret_type); + +tvm::tir::Buffer MatchBuffer(ObjectRef param, Array shape, + DataType dtype = DataType::Float(32), + Optional data = NullOpt, Array strides = {}, + PrimExpr elem_offset = PrimExpr(), String storage_scope = "", + int align = -1, int offset_factor = 0, + String buffer_type_str = "default", Array axis_separators = {}, + Span span = Span()); + +void PreflattenedBuffer(tvm::tir::Buffer postflattened_buffer, Array shape, + DataType dtype = DataType::Float(32), + Optional data = NullOpt, Array strides = {}, + PrimExpr elem_offset = PrimExpr(), String storage_scope = "", + int align = -1, int offset_factor = 0, String buffer_type_str = "default", + Array axis_separators = {}, Span span = Span()); } // namespace tir } // namespace builder diff --git a/src/script/builder/tir/var.h b/src/script/builder/tir/var.h index 41257f0ca6d3..81120cadb892 100644 --- a/src/script/builder/tir/var.h +++ b/src/script/builder/tir/var.h @@ -26,9 +26,9 @@ namespace script { namespace builder { namespace tir { -tvm::tir::Buffer Buffer_(Array shape, // - DataType dtype, // - String name = "buffer", // +tvm::tir::Buffer Buffer_(Array shape, // + DataType dtype = DataType::Float(32), // + String name = "buffer", // String storage_scope = ""); } diff --git a/tests/python/unittest/test_tvmscript_builder.py b/tests/python/tvmscript/test_builder_basic.py similarity index 68% rename from tests/python/unittest/test_tvmscript_builder.py rename to tests/python/tvmscript/test_builder_basic.py index 83ab995e76b7..ee50b8a717ec 100644 --- a/tests/python/unittest/test_tvmscript_builder.py +++ b/tests/python/tvmscript/test_builder_basic.py @@ -23,8 +23,18 @@ def test_builder_basic(): with Builder() as b: with T.prim_func(name="main"): - A = T.arg("A", T.Buffer((128, 128, 128), "float32")) - B = T.arg("B", T.Buffer((128, 128, 128), "float32")) + T.func_attr({"global_symbol": "main"}) + arg_a = T.arg("a", T.handle()) + arg_b = T.arg("b", T.handle()) + buffer_c = T.Buffer((128,), "float32") + buffer_d = T.Buffer((128,), "float32") + arg_c = T.arg("c", buffer_c) + arg_d = T.arg("d", buffer_d) + T.func_ret(tvm.ir.PrimType("int8")) + A = def_("A", T.match_buffer(arg_a, (128, 128, 128))) + B = def_("B", T.match_buffer(arg_b, (128, 128, 128))) + T.preflattened_buffer(buffer_c, (128,), data=buffer_c.data) + T.preflattened_buffer(buffer_d, (128,), data=buffer_d.data) with T.grid(128, 128, 128) as (i, j, k): def_many(["i", "j", "k"], [i, j, k]) with T.block(name="block"):