diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index 59aa955b0d9e..829de7381f00 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -225,8 +226,8 @@ class BuildConfigNode : public Object { /*! \brief Whether to partition const loop */ bool partition_const_loop = false; - /*! \brief Whether to dump the IR of each pass (only when building from python) */ - std::vector< std::pair > add_lower_pass; + /*! \brief List of passes to be injected into the low-level pipeline. */ + std::vector> add_lower_pass; /*! \brief Whether to dump the IR of each pass (only when building from python) */ bool dump_pass_ir = false; diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py index 698ddbc68dd7..5ddc5df7d1f5 100644 --- a/python/tvm/autotvm/measure/measure_methods.py +++ b/python/tvm/autotvm/measure/measure_methods.py @@ -615,9 +615,9 @@ def gpu_verify_pass(**kwargs): """Verify the validity of a gpu kernel. This pass will check memory usage and number of threads per block. """ - def verify_pass(stmt): - valid = ir_pass.VerifyGPUCode(stmt, kwargs) + def verify_pass(f, *_): + valid = ir_pass.VerifyGPUCode(f.body, kwargs) if not valid: raise InstantiationError("Skipped because of invalid gpu kernel") - return stmt - return verify_pass + return f + return tvm.tir.transform.prim_func_pass(verify_pass, opt_level=0) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 35700badb04b..dcd6d444f02d 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -123,25 +123,6 @@ def form_irmodule(sch, args, name, binds): return tvm.IRModule({name: func}) -def _wrap_as_prim_func_pass(flist, name): - """Wrap flist as a function pass. - - This is an temporary adapter before we fully - migrate to the new pass manager. - """ - def _transform(func, *_): - stmt = func.body - for f in flist: - stmt = f(stmt) - # create a new function with updated body. - return tvm.tir.PrimFunc(func.params, - stmt, - func.ret_type, - func.buffer_map, - func.attrs) - return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name=name) - - def lower(sch, args, name="main", @@ -190,15 +171,15 @@ def lower(sch, else: mod = sch + pass_list = lower_phase0 # Phase 1 - pass_list = [ - _wrap_as_prim_func_pass(lower_phase0, "Custom-Phase0"), + pass_list += [ tvm.tir.transform.InjectPrefetch(), tvm.tir.transform.StorageFlatten(64, cfg.instrument_bound_checkers), tvm.tir.transform.NarrowDataType(32), tvm.tir.transform.Simplify(), - _wrap_as_prim_func_pass(lower_phase1, "Custom-Phase1"), ] + pass_list += lower_phase1 # Phase 2 if not simple_mode: @@ -214,8 +195,8 @@ def lower(sch, cfg.auto_unroll_max_depth, cfg.auto_unroll_max_extent, cfg.unroll_explicit), - _wrap_as_prim_func_pass(lower_phase2, "Custom-Phase2"), ] + pass_list += lower_phase2 # Phase 3 pass_list += [ @@ -225,7 +206,7 @@ def lower(sch, if not cfg.disable_select_rewriting: pass_list += [tvm.tir.transform.RewriteUnsafeSelect()] - pass_list += [_wrap_as_prim_func_pass(lower_phase3, "Custom-Phase3")] + pass_list += lower_phase3 # Instrument BoundCheckers if cfg.instrument_bound_checkers: diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index 4ec1a71f345e..47ad94f503d8 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -67,3 +67,19 @@ def __init__(self, self.__init_handle_by_constructor__( _ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs) + + def with_body(self, new_body): + """Create a new PrimFunc with the same set signatures but a new body. + + Parameters + ---------- + new_body : Stmt + The new body. + + Returns + ------- + new_func : PrimFunc + The created new function. + """ + return PrimFunc( + self.params, new_body, self.ret_type, self.buffer_map, self.attrs) diff --git a/src/target/target.cc b/src/target/target.cc index 50856d62af30..a72ce1c5b3e4 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -434,12 +434,12 @@ TVM_REGISTER_GLOBAL("target.ExitBuildConfigScope") TVM_REGISTER_GLOBAL("target.BuildConfigSetAddLowerPass") .set_body([](TVMArgs args, TVMRetValue* ret) { BuildConfig cfg = args[0]; - std::vector< std::pair > add_lower_pass; + std::vector> add_lower_pass; CHECK_EQ(args.size() % 2, 1); for (int i = 1; i < args.size(); i += 2) { add_lower_pass.push_back(std::make_pair( args[i].operator int(), - args[i + 1].operator tvm::runtime::PackedFunc())); + args[i + 1].operator transform::Pass())); } cfg->add_lower_pass = add_lower_pass; }); diff --git a/tests/python/relay/test_pass_fold_constant.py b/tests/python/relay/test_pass_fold_constant.py index 4f44d2b3043f..b212b26c99a7 100644 --- a/tests/python/relay/test_pass_fold_constant.py +++ b/tests/python/relay/test_pass_fold_constant.py @@ -51,11 +51,13 @@ def expected(): z = relay.add(y, relay.const(c_data)) return relay.Function([x], z) - def fail(x): - raise RuntimeError() + def FailPass(): + def _transform(m, *args): + raise RuntimeError() + return tvm.transform.module_pass(_transform, opt_level=0) # the fold constant should work on any context. - with tvm.target.build_config(add_lower_pass=[(0, fail)]): + with tvm.target.build_config(add_lower_pass=[(0, FailPass())]): with tvm.target.create("cuda"): zz = run_opt_pass(before(), transform.FoldConstant()) zexpected = run_opt_pass(expected(), transform.InferType()) diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py index 739fc6fda76d..4c2ec2e884bb 100644 --- a/tests/python/unittest/test_target_codegen_cuda.py +++ b/tests/python/unittest/test_target_codegen_cuda.py @@ -182,7 +182,7 @@ def test_cuda_shuffle(): sch[c].bind(xo, thrx) sch[c].vectorize(xi) - def my_vectorize(stmt): + def MyVectorize(): def vectorizer(op): if op.for_type == tvm.tir.For.Vectorized: four = tvm.tir.const(4, 'int32') @@ -198,9 +198,13 @@ def vectorizer(op): new_b = tvm.tir.Shuffle(bs, ids) return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones) return None - return tvm.tir.ir_pass.IRTransform(stmt, None, vectorizer, ['For']) - with tvm.target.build_config(add_lower_pass=[(1, my_vectorize)]): + def _transform(f, *_): + return f.with_body( + tvm.tir.ir_pass.IRTransform(f.body, None, vectorizer, ['For'])) + return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="MyVectorize") + + with tvm.target.build_config(add_lower_pass=[(1, MyVectorize())]): module = tvm.build(sch, [a, b, c], target='cuda') a_ = np.array(list(range(64)), dtype='int32') b_ = np.array((list(range(4))[::-1]) * 16, dtype='int32') diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index 44b05c90ff17..26f93478b4fc 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -671,8 +671,7 @@ def test_llvm_shuffle(): c = te.compute((8, ), lambda x: a[x] + b[7-x]) sch = te.create_schedule(c.op) - def my_vectorize(stmt): - + def my_vectorize(): def vectorizer(op): store = op.body idx = tvm.tir.Ramp(tvm.tir.const(0, 'int32'), tvm.tir.const(1, 'int32'), 8) @@ -684,9 +683,13 @@ def vectorizer(op): value = new_a + new_b return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones) - return tvm.tir.ir_pass.IRTransform(stmt, None, vectorizer, ['For']) + def _transform(f, *_): + return f.with_body( + tvm.tir.ir_pass.IRTransform(f.body, None, vectorizer, ['For'])) + + return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="my_vectorize") - with tvm.target.build_config(add_lower_pass=[(1, my_vectorize)]): + with tvm.target.build_config(add_lower_pass=[(1, my_vectorize())]): ir = tvm.lower(sch, [a, b, c], simple_mode=True) module = tvm.build(sch, [a, b, c]) a_ = tvm.nd.array(np.arange(1, 9, dtype='int32')) diff --git a/tests/python/unittest/test_tir_pass_verify_gpu_code.py b/tests/python/unittest/test_tir_pass_verify_gpu_code.py index 6e138a29b3e9..091a3749dc74 100644 --- a/tests/python/unittest/test_tir_pass_verify_gpu_code.py +++ b/tests/python/unittest/test_tir_pass_verify_gpu_code.py @@ -19,10 +19,10 @@ from tvm import te def get_verify_pass(valid, **kwargs): - def verify_pass(stmt): - valid[0] = tvm.tir.ir_pass.VerifyGPUCode(stmt, kwargs) - return stmt - return verify_pass + def _fverify(f, *_): + valid[0] = tvm.tir.ir_pass.VerifyGPUCode(f.body, kwargs) + return f + return tvm.tir.transform.prim_func_pass(_fverify, opt_level=0) def test_shared_memory(): def check_shared_memory(dtype): diff --git a/tutorials/dev/low_level_custom_pass.py b/tutorials/dev/low_level_custom_pass.py index d35913b1cd83..49e86fdb8e9b 100644 --- a/tutorials/dev/low_level_custom_pass.py +++ b/tutorials/dev/low_level_custom_pass.py @@ -117,19 +117,20 @@ def vectorize8(op): return body return None -def vectorize(stmt): +@tvm.tir.transform.prim_func_pass(opt_level=0) +def vectorize(f, mod, ctx): global loops - tvm.tir.ir_pass.PostOrderVisit(stmt, find_width8) + tvm.tir.ir_pass.PostOrderVisit(f.body, find_width8) if not loops: - return stmt + return sf # The last list arugment indicates what kinds of nodes will be transformed. # Thus, in this case only `For` nodes will call `vectorize8` - stmt = tvm.tir.ir_pass.IRTransform(stmt, None, vectorize8, ['For']) + return f.with_body( + tvm.tir.ir_pass.IRTransform(f.body, None, vectorize8, ['For'])) - return stmt ##################################################################### # Glue to Lowering diff --git a/vta/python/vta/build_module.py b/vta/python/vta/build_module.py index 4c33d36d69b5..40bee86f6451 100644 --- a/vta/python/vta/build_module.py +++ b/vta/python/vta/build_module.py @@ -14,25 +14,22 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=unused-argument +# pylint: disable=unused-argument, invalid-name """VTA specific buildin for runtime.""" import tvm -from . import ir_pass +from . import transform from .environment import get_env -def lift_coproc_scope(x): - """Lift coprocessings cope to the """ - x = ir_pass.lift_alloc_to_scope_begin(x) - x = tvm.tir.ir_pass.LiftAttrScope(x, "coproc_scope", False) - return x - -def early_rewrite(stmt): +def EarlyRewrite(): """Try to do storage rewrite in early pass.""" - try: - return tvm.tir.ir_pass.StorageRewrite(stmt) - except tvm.error.TVMError: - return stmt + def _transform(mod, ctx): + try: + return tvm.tir.transform.StorageRewrite()(mod) + except tvm.error.TVMError: + return mod + return tvm.transform.module_pass( + _transform, opt_level=0, name="tir.vta.EarlyRewrite") def build_config(debug_flag=0, **kwargs): @@ -60,27 +57,32 @@ def build_config(debug_flag=0, **kwargs): vta_module = tvm.build(s, ...) """ env = get_env() - def add_debug(stmt): + + @tvm.tir.transform.prim_func_pass(opt_level=0) + def add_debug(f, *_): debug = tvm.tir.call_extern( "int32", "VTASetDebugMode", env.dev.command_handle, debug_flag) - return tvm.tir.stmt_seq(debug, stmt) - pass_list = [(0, ir_pass.inject_conv2d_transpose_skip), - (1, ir_pass.inject_dma_intrin), - (1, ir_pass.inject_skip_copy), - (1, ir_pass.annotate_alu_coproc_scope), - (1, lambda x: tvm.tir.ir_pass.LiftAttrScope(x, "coproc_uop_scope", True)), - (1, lift_coproc_scope), - (1, ir_pass.inject_coproc_sync), - (1, early_rewrite)] + return f.with_body(tvm.tir.stmt_seq(debug, f.body)) + + + pass_list = [(0, transform.InjectConv2DTransposeSkip()), + (1, transform.InjectDMAIntrin()), + (1, transform.InjectSkipCopy()), + (1, transform.AnnotateALUCoProcScope()), + (1, tvm.tir.transform.LiftAttrScope("coproc_uop_scope")), + (1, transform.LiftAllocToScopeBegin()), + (1, tvm.tir.transform.LiftAttrScope("coproc_scope")), + (1, transform.InjectCoProcSync()), + (1, EarlyRewrite())] if debug_flag: pass_list.append((1, add_debug)) - pass_list.append((2, ir_pass.inject_alu_intrin)) - pass_list.append((3, tvm.tir.ir_pass.LowerStorageAccessInfo)) - pass_list.append((3, ir_pass.fold_uop_loop)) - pass_list.append((3, ir_pass.cpu_access_rewrite)) + pass_list.append((2, transform.InjectALUIntrin())) + pass_list.append((3, tvm.tir.transform.LowerDeviceStorageAccessInfo())) + pass_list.append((3, transform.FoldUopLoop())) + pass_list.append((3, transform.CPUAccessRewrite())) return tvm.target.build_config(add_lower_pass=pass_list, **kwargs) diff --git a/vta/python/vta/ir_pass.py b/vta/python/vta/ir_pass.py deleted file mode 100644 index 9836d133ceb7..000000000000 --- a/vta/python/vta/ir_pass.py +++ /dev/null @@ -1,995 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Additional IR Pass for VTA""" -# pylint: disable=len-as-condition, no-else-return -import tvm -from tvm import te -from topi import util - -from .environment import get_env - - -def _match_pragma(stmt, key): - """Internal helper to match stmt to pragma stmt. - - Parameters - ---------- - stmt : Stmt - The AttrStmt - - key : str - The pragma key - """ - return ((stmt.attr_key == "pragma_" + key) or - (stmt.attr_key == "pragma_scope" and stmt.value.value == key)) - - -def fold_uop_loop(stmt_in): - """Detect and fold uop loop. - - VTA support uop programming model - that recognizes loop structure. - This pass detect the loop structure - and extract that into uop loop AST. - - Parameters - ---------- - stmt_in : Stmt - Input statement - - Returns - ------- - stmt_out : Stmt - Output statement. - """ - env = get_env() - - def _fold_outermost_loop(body): - stmt = body - if not isinstance(stmt, tvm.tir.For): - return None, body, None - - loop_var = stmt.loop_var - gemm_offsets = [None, None, None] - fail = [False] - - def _post_order(op): - assert isinstance(op, tvm.tir.Call) - base_args = 2 - if op.name == "VTAUopPush": - args = [] - args += op.args[:base_args] - for i in range(3): - m = tvm.arith.detect_linear_equation( - op.args[i + base_args], [loop_var]) - if not m: - fail[0] = True - return op - if gemm_offsets[i] is not None: - if not tvm.ir.structural_equal(m[0], gemm_offsets[i]): - fail[0] = True - return op - args.append(m[1]) - else: - gemm_offsets[i] = m[0] - args.append(m[1]) - args += op.args[base_args+3:] - return tvm.tir.call_extern("int32", "VTAUopPush", *args) - if op.name not in ("VTATLSCommandHandle", "tvm_thread_context"): - raise RuntimeError("unexpected op %s" % op) - return op - - ret = tvm.tir.ir_pass.IRTransform( - stmt.body, None, _post_order, ["Call"]) - - if not fail[0] and all(x is not None for x in gemm_offsets): - def _visit(op): - if op.same_as(loop_var): - fail[0] = True - tvm.tir.ir_pass.PostOrderVisit(ret, _visit) - if not fail[0]: - begin = tvm.tir.call_extern( - "int32", "VTAUopLoopBegin", stmt.extent, *gemm_offsets) - end = tvm.tir.call_extern("int32", "VTAUopLoopEnd") - return [begin, ret, end] - raise ValueError("Failed to fold the GEMM instructions..") - - def _do_fold(stmt): - if (stmt.attr_key == "coproc_uop_scope" and - isinstance(stmt.value, tvm.tir.StringImm) and - stmt.value.value == env.dev.vta_push_uop.value): - body = stmt.body - begins = [] - ends = [] - try: - begin, body, end = _fold_outermost_loop(body) - if begin is not None: - begins.append(begin) - if end is not None: - ends.append(end) - begin, body, end = _fold_outermost_loop(body) - if begin is not None: - begins.append(begin) - if end is not None: - ends.append(end) - except ValueError: - pass - if body == stmt.body: - return stmt - ends = list(reversed(ends)) - body = tvm.tir.stmt_seq(*(begins + [body] + ends)) - return tvm.tir.AttrStmt( - stmt.node, stmt.attr_key, stmt.value, body) - return None - out = tvm.tir.ir_pass.IRTransform( - stmt_in, _do_fold, None, ["AttrStmt"]) - return out - - -def cpu_access_rewrite(stmt_in): - """Detect CPU access to VTA buffer and get address correctly. - - VTA's buffer is an opaque handle that do not - correspond to address in CPU. - This pass detect CPU access and rewrite to use pointer - returned VTABufferCPUPtr for CPU access. - - Parameters - ---------- - stmt_in : Stmt - Input statement - - Returns - ------- - stmt_out : Stmt - Transformed statement - """ - env = get_env() - rw_info = {} - def _post_order(op): - if isinstance(op, tvm.tir.Allocate): - buffer_var = op.buffer_var - if not buffer_var in rw_info: - return None - new_var = rw_info[buffer_var] - let_stmt = tvm.tir.LetStmt( - new_var, tvm.tir.call_extern( - "handle", "VTABufferCPUPtr", - env.dev.command_handle, - buffer_var), op.body) - alloc = tvm.tir.Allocate( - buffer_var, op.dtype, op.extents, - op.condition, let_stmt) - del rw_info[buffer_var] - return alloc - if isinstance(op, tvm.tir.Load): - buffer_var = op.buffer_var - if not buffer_var in rw_info: - rw_info[buffer_var] = te.var( - buffer_var.name + "_ptr", "handle") - new_var = rw_info[buffer_var] - return tvm.tir.Load(op.dtype, new_var, op.index) - if isinstance(op, tvm.tir.Store): - buffer_var = op.buffer_var - if not buffer_var in rw_info: - rw_info[buffer_var] = te.var( - buffer_var.name + "_ptr", "handle") - new_var = rw_info[buffer_var] - return tvm.tir.Store(new_var, op.value, op.index) - raise RuntimeError("not reached") - stmt = tvm.tir.ir_pass.IRTransform( - stmt_in, None, _post_order, ["Allocate", "Load", "Store"]) - for buffer_var, new_var in rw_info.items(): - stmt = tvm.tir.LetStmt( - new_var, tvm.tir.call_extern( - "handle", "VTABufferCPUPtr", - env.dev.command_handle, - buffer_var), stmt) - return stmt - - -def lift_alloc_to_scope_begin(stmt_in): - """Lift allocate to beginning of the current scope. - - Parameters - ---------- - stmt_in : Stmt - Input statement - - Returns - ------- - stmt_out : Stmt - Transformed statement - """ - lift_stmt = [[]] - def _merge_block(slist, body): - for op in slist: - if op.body == body: - body = op - elif isinstance(op, tvm.tir.Allocate): - body = tvm.tir.Allocate( - op.buffer_var, op.dtype, - op.extents, op.condition, body) - elif isinstance(op, tvm.tir.AttrStmt): - body = tvm.tir.AttrStmt( - op.node, op.attr_key, op.value, body) - elif isinstance(op, tvm.tir.For): - body = tvm.tir.For( - op.loop_var, op.min, op.extent, op.for_type, - op.device_api, body) - else: - raise RuntimeError("unexpected op") - del slist[:] - return body - - def _pre_order(op): - if isinstance(op, tvm.tir.For): - lift_stmt.append([]) - elif isinstance(op, tvm.tir.AttrStmt): - if op.attr_key == "virtual_thread": - lift_stmt.append([]) - - def _post_order(op): - if isinstance(op, tvm.tir.Allocate): - lift_stmt[-1].append(op) - return op.body - if isinstance(op, tvm.tir.AttrStmt): - if op.attr_key == "storage_scope": - lift_stmt[-1].append(op) - return op.body - if op.attr_key == "virtual_thread": - return _merge_block(lift_stmt.pop() + [op], op.body) - return op - if isinstance(op, tvm.tir.For): - return _merge_block(lift_stmt.pop() + [op], op.body) - raise RuntimeError("not reached") - stmt = tvm.tir.ir_pass.IRTransform( - stmt_in, _pre_order, _post_order, ["Allocate", "AttrStmt", "For"]) - assert len(lift_stmt) == 1 - return _merge_block(lift_stmt[0], stmt) - - -def inject_skip_copy(stmt_in): - """Pass to inject skip copy stmt, used for debug purpose. - - Parameters - ---------- - stmt_in : Stmt - Input statement - - Returns - ------- - stmt_out : Stmt - Transformed statement - """ - def _do_fold(stmt): - if _match_pragma(stmt, "skip_dma_copy"): - return tvm.tir.Evaluate(0) - return None - return tvm.tir.ir_pass.IRTransform( - stmt_in, _do_fold, None, ["AttrStmt"]) - - -def inject_coproc_sync(stmt_in): - """Pass to inject skip copy stmt, used in debug. - - Parameters - ---------- - stmt_in : Stmt - Input statement - - Returns - ------- - stmt_out : Stmt - Transformed statement - """ - success = [False] - def _do_fold(stmt): - if _match_pragma(stmt, "coproc_sync"): - success[0] = True - sync = tvm.tir.Call( - "int32", "vta.coproc_sync", [], tvm.tir.Call.Intrinsic, None, 0) - return tvm.tir.SeqStmt([stmt.body, tvm.tir.Evaluate(sync)]) - if _match_pragma(stmt, "trim_loop"): - op = stmt.body - assert isinstance(op, tvm.tir.For) - return tvm.tir.For( - op.loop_var, op.min, 2, op.for_type, - op.device_api, op.body) - return None - stmt = tvm.tir.ir_pass.IRTransform( - stmt_in, None, _do_fold, ["AttrStmt"]) - stmt = tvm.tir.ir_pass.CoProcSync(stmt) - return stmt - - -def inject_dma_intrin(stmt_in): - """Pass to inject DMA copy intrinsics. - - Parameters - ---------- - stmt_in : Stmt - Input statement - - Returns - ------- - stmt_out : Stmt - Transformed statement - """ - env = get_env() - idxd = tvm.tir.indexdiv - idxm = tvm.tir.indexmod - - def _check_compact(buf): - ndim = len(buf.shape) - size = tvm.tir.const(1, buf.shape[0].dtype) - for i in reversed(range(ndim)): - if not util.equal_const_int(size - buf.strides[i], 0): - raise RuntimeError( - "Cannot prove compact: shape=%s, strides=%s" % (buf.shape, buf.strides)) - size = size * buf.shape[i] - - def _fold_buffer_dim(buf, scope, elem_block): - ndim = len(buf.shape) - x_size = 1 - base = 0 - for i in range(1, ndim + 1): - if not util.equal_const_int(buf.strides[ndim - i] - x_size, 0): - raise RuntimeError("scope %s needs to have block=%d" % (scope, elem_block)) - x_size = x_size * buf.shape[ndim - i] - if util.equal_const_int(x_size - elem_block, 0): - base = i + 1 - break - if base == 0: - raise RuntimeError("scope %s need to have block=%d, shape=%s" % ( - scope, elem_block, buf.shape)) - shape = [elem_block] - strides = [1] - - if base < ndim + 1 and not util.equal_const_int(buf.strides[ndim - base], elem_block): - shape.append(1) - strides.append(elem_block) - - analyzer = tvm.arith.Analyzer() - while base < ndim + 1: - x_size = 1 - x_stride = buf.strides[ndim - base] - next_base = base - if not util.equal_const_int(idxm(x_stride, elem_block), 0): - raise RuntimeError( - "scope %s need to have block=%d, shape=%s, strides=%s" % ( - scope, elem_block, buf.shape, buf.strides)) - for i in range(base, ndim + 1): - k = ndim - i - if not util.equal_const_int(x_size * x_stride - buf.strides[k], 0): - break - x_size = x_size * buf.shape[k] - next_base = i + 1 - shape.append(analyzer.simplify(x_size)) - strides.append(x_stride) - assert next_base != base - base = next_base - - strides = list(reversed(strides)) - shape = list(reversed(shape)) - return shape, strides - - def _get_2d_pattern(buf, elem_width, elem_bytes, dtype, scope, allow_fold): - elem_block = elem_bytes * 8 // elem_width - if buf.dtype != dtype: - raise RuntimeError("Expect buffer type to be %s instead of %s" % - (dtype, buf.dtype)) - shape, strides = buf.shape, buf.strides - if not util.equal_const_int(idxm(buf.elem_offset, elem_block), 0): - raise RuntimeError("scope %s need to have block=%d" % (scope, elem_block)) - if allow_fold: - shape, strides = _fold_buffer_dim(buf, scope, elem_block) - else: - shape = list(x for x in shape) - strides = list(x for x in strides) - - def raise_error(): - """Internal function to raise error """ - raise RuntimeError( - ("Scope[%s]: cannot detect 2d pattern with elem_block=%d:" + - " shape=%s, strides=%s") % (scope, elem_block, buf.shape, buf.strides)) - - ndim = len(shape) - - # Check if the inner-tensor is already flat - flat = util.equal_const_int(shape[-1], elem_block) - - if flat: - if not util.equal_const_int(strides[-1], 1): - raise_error() - - if ndim == 1: - x_size = 1 - x_stride = 1 - y_size = 1 - return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block) - if not util.equal_const_int(strides[-2] - elem_block, 0): - raise_error() - - if ndim == 2: - x_size = shape[-2] - x_stride = shape[-2] - y_size = 1 - return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block) - if not util.equal_const_int(idxm(strides[-3], elem_block), 0): - raise_error() - - if ndim == 3: - x_size = shape[-2] - x_stride = idxd(strides[-3], elem_block) - y_size = shape[-3] - return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block) - - else: - if not util.equal_const_int(strides[-1], 1): - raise_error() - if not util.equal_const_int(strides[-2] - shape[-1], 0): - raise_error() - if not util.equal_const_int(shape[-1] * shape[-2], elem_block): - raise_error() - - if ndim == 2: - x_size = 1 - x_stride = 1 - y_size = 1 - return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block) - if not util.equal_const_int(strides[-3], elem_block): - raise_error() - - if ndim == 3: - x_size = shape[-3] - x_stride = shape[-3] - y_size = 1 - return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block) - if not util.equal_const_int(idxm(strides[-4], elem_block), 0): - raise_error() - - if ndim == 4: - x_size = shape[-3] - x_stride = idxd(strides[-4], elem_block) - y_size = shape[-4] - return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block) - - raise_error() - - - def _inject_copy(src, dst, pad_before, pad_after, pad_value): - # FIXME: pad_value is ignored... - _ = pad_value - if dst.scope == "global": - # Store - if pad_before or pad_after: - raise RuntimeError("Do not support copy into DRAM with pad") - if src.scope == env.acc_scope: - elem_width = env.OUT_WIDTH - elem_bytes = env.OUT_ELEM_BYTES - mem_type = env.dev.MEM_ID_OUT - data_type = "int%d" % env.OUT_WIDTH - task_qid = env.dev.QID_STORE_OUT - else: - raise RuntimeError("Do not support copy %s->dram" % (src.scope)) - _check_compact(src) - x_size, y_size, x_stride, offset = _get_2d_pattern( - dst, elem_width, elem_bytes, data_type, src.scope, allow_fold=True) - irb = tvm.tir.ir_builder.create() - irb.scope_attr(env.dev.vta_axis, "coproc_scope", - env.dev.get_task_qid(task_qid)) - irb.emit(tvm.tir.call_extern( - "int32", "VTAStoreBuffer2D", - env.dev.command_handle, - src.access_ptr("r", "int32"), - mem_type, dst.data, offset, x_size, y_size, x_stride)) - return irb.get() - elif src.scope == "global": - if dst.scope == env.acc_scope: - elem_width = env.ACC_WIDTH - elem_bytes = env.ACC_ELEM_BYTES - mem_type = env.dev.MEM_ID_ACC - data_type = "int%d" % env.ACC_WIDTH - task_qid = env.dev.QID_LOAD_OUT - elif dst.scope == env.inp_scope: - elem_width = env.INP_WIDTH - elem_bytes = env.INP_ELEM_BYTES - mem_type = env.dev.MEM_ID_INP - data_type = "int%d" % env.INP_WIDTH - task_qid = env.dev.QID_LOAD_INP - elif dst.scope == env.wgt_scope: - elem_width = env.WGT_WIDTH - elem_bytes = env.WGT_ELEM_BYTES - mem_type = env.dev.MEM_ID_WGT - data_type = "int%d" % env.WGT_WIDTH - task_qid = env.dev.QID_LOAD_WGT - else: - raise RuntimeError("Do not support copy dram->%s" % (dst.scope)) - # collect pad statistics - if pad_before: - assert pad_after - ndim = len(pad_before) - if ndim <= 2 or ndim > 5: - raise ValueError("Limitation of 2D pad load forbid ndim=%d" % ndim) - if ndim == 5: - # This case occurs when batch size N > 1 - y_pad_before = pad_before[1] - x_pad_before = pad_before[2] - y_pad_after = pad_after[1] - x_pad_after = pad_after[2] - for dim in range(3, ndim): - if not util.equal_const_int(pad_before[dim], 0): - raise ValueError("Do not support pad on the innermost block") - if not util.equal_const_int(pad_after[dim], 0): - raise ValueError("Do not support pad on the innermost block") - else: - y_pad_before = pad_before[0] - x_pad_before = pad_before[1] - y_pad_after = pad_after[0] - x_pad_after = pad_after[1] - for dim in range(2, ndim): - if not util.equal_const_int(pad_before[dim], 0): - raise ValueError("Do not support pad on the innermost block") - if not util.equal_const_int(pad_after[dim], 0): - raise ValueError("Do not support pad on the innermost block") - allow_fold = False - else: - x_pad_before = 0 - y_pad_before = 0 - x_pad_after = 0 - y_pad_after = 0 - allow_fold = True - - _check_compact(dst) - x_size, y_size, x_stride, offset = _get_2d_pattern( - src, elem_width, elem_bytes, data_type, - dst.scope, allow_fold=allow_fold) - - irb = tvm.tir.ir_builder.create() - irb.scope_attr(env.dev.vta_axis, "coproc_scope", - env.dev.get_task_qid(task_qid)) - - irb.emit(tvm.tir.call_extern( - "int32", "VTALoadBuffer2D", - env.dev.command_handle, - src.data, offset, x_size, y_size, x_stride, - x_pad_before, y_pad_before, - x_pad_after, y_pad_after, - dst.access_ptr("r", "int32"), mem_type)) - return irb.get() - - else: - raise RuntimeError("Do not support copy %s->%s" % (src.scope, dst.scope)) - - return tvm.tir.ir_pass.InjectCopyIntrin(stmt_in, "dma_copy", _inject_copy) - - -def _get_gemm_intrin_buffer(): - env = get_env() - wgt_lanes = env.WGT_ELEM_BITS // env.WGT_WIDTH - assert wgt_lanes == env.BLOCK_OUT * env.BLOCK_IN - wgt_shape = (env.BLOCK_OUT, env.BLOCK_IN) - assert wgt_shape[0] * wgt_shape[1] == wgt_lanes - inp_lanes = env.INP_ELEM_BITS // env.INP_WIDTH - assert inp_lanes == env.BATCH * env.BLOCK_IN - inp_shape = (env.BATCH, env.BLOCK_IN) - assert inp_shape[0] * inp_shape[1] == inp_lanes - out_lanes = env.ACC_ELEM_BITS // env.ACC_WIDTH - assert out_lanes == env.BATCH * env.BLOCK_OUT - out_shape = (env.BATCH, env.BLOCK_OUT) - assert out_shape[0] * out_shape[1] == out_lanes - wgt = te.placeholder((wgt_shape[0], wgt_shape[1]), - dtype="int%d" % env.WGT_WIDTH, - name=env.wgt_scope) - inp = te.placeholder((inp_shape[0], inp_shape[1]), - dtype="int%d" % env.INP_WIDTH, - name=env.inp_scope) - k = te.reduce_axis((0, wgt_shape[1]), name="k") - out_dtype = "int%d" % env.ACC_WIDTH - out = te.compute((out_shape[0], out_shape[1]), - lambda i, j: te.sum(inp[i, k].astype(out_dtype) * - wgt[j, k].astype(out_dtype), - axis=[k]), - name="out") - wgt_layout = tvm.tir.decl_buffer( - wgt.shape, wgt.dtype, env.wgt_scope, - scope=env.wgt_scope, offset_factor=wgt_lanes, data_alignment=wgt_lanes) - inp_layout = tvm.tir.decl_buffer( - inp.shape, inp.dtype, env.inp_scope, - scope=env.inp_scope, offset_factor=inp_lanes, data_alignment=inp_lanes) - out_layout = tvm.tir.decl_buffer( - out.shape, out.dtype, env.acc_scope, - scope=env.acc_scope, offset_factor=out_lanes, data_alignment=out_lanes) - - return wgt_layout, inp_layout, out_layout - - -def inject_conv2d_transpose_skip(stmt_in): - """Pass to skip 0-weights in conv2d transpose with stride > 1. - - Parameters - ---------- - stmt_in : Stmt - Input statement - - Returns - ------- - stmt_out : Stmt - Transformed statement - """ - env = get_env() - dwgt, dinp, dout = _get_gemm_intrin_buffer() - - calls = [] - selects = [] - - def _find_basics(op): - if isinstance(op, tvm.tir.BufferLoad): - calls.append(op) - elif isinstance(op, tvm.tir.Select): - selects.append(op) - - def _do_fold(op): - if _match_pragma(op, "conv2d_transpose_gemm"): - is_init = ".init" in str(op) - tvm.tir.ir_pass.PostOrderVisit(op, _find_basics) - - if is_init: - # create inner most block - irb = tvm.tir.ir_builder.create() - dev = env.dev - irb.scope_attr(dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE)) - irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop) - irb.emit(tvm.tir.call_extern("int32", "VTAUopPush", - 0, 1, - dout.access_ptr("rw", "int32"), - 0, 0, - 0, 0, 0)) - inner = irb.get() - # TODO(@tmoreau89): This is only a temporary fix, please take a look. - body = op.body.body - while isinstance(body, tvm.tir.IfThenElse): - body = body.then_case - args = body.indices - res_buffer = body.buffer - tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, env.BLOCK_OUT) - inner = tvm.tir.AttrStmt( - [dout, res_buffer], 'buffer_bind_scope', - tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner) - return inner - else: - conv_call, data_call, kernel_call = calls[-3:] - pad_data_tensor = data_call.buffer - kernel_tensor = kernel_call.buffer - res_tensor = conv_call.buffer - - if selects: - condition = selects[0].condition - else: - condition = tvm.tir.const(1, 'int') - - # create inner most block - irb = tvm.tir.ir_builder.create() - with irb.if_scope(condition): - dev = env.dev - irb.scope_attr(dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE)) - irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop) - irb.emit(tvm.tir.call_extern("int32", "VTAUopPush", - 0, 0, - dout.access_ptr("rw", "int32"), - dinp.access_ptr("r", "int32"), - dwgt.access_ptr("r", "int32"), - 0, 0, 0)) - inner = irb.get() - - args = conv_call.indices - tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], - 1, 0, 1, 0, env.BLOCK_OUT) - inner = tvm.tir.AttrStmt( - [dout, res_tensor], 'buffer_bind_scope', - tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner) - args = kernel_call.indices - tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], - 1, 0, env.BLOCK_OUT, 0, env.BLOCK_IN) - inner = tvm.tir.AttrStmt( - [dwgt, kernel_tensor], 'buffer_bind_scope', - tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner) - args = data_call.indices - tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], - 1, 0, 1, 0, env.BLOCK_IN) - inner = tvm.tir.AttrStmt( - [dinp, pad_data_tensor], 'buffer_bind_scope', - tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner) - return inner - return None - ret = tvm.tir.ir_pass.IRTransform( - stmt_in, _do_fold, None, ["AttrStmt"]) - return ret - - -def annotate_alu_coproc_scope(stmt_in): - """Pass to insert ALU instruction. - - Parameters - ---------- - stmt_in : Stmt - Input statement - - Returns - ------- - stmt_out : Stmt - Transformed statement - """ - env = get_env() - def _do_fold(stmt): - if _match_pragma(stmt, "alu"): - irb = tvm.tir.ir_builder.create() - irb.scope_attr(env.dev.vta_axis, "coproc_scope", - env.dev.get_task_qid(env.dev.QID_COMPUTE)) - irb.scope_attr(env.dev.vta_axis, "coproc_uop_scope", - tvm.tir.StringImm("VTAPushALUOp")) - irb.emit(stmt) - return irb.get() - if _match_pragma(stmt, "skip_alu"): - return tvm.tir.Evaluate(0) - return stmt - - stmt_out = tvm.tir.ir_pass.IRTransform( - stmt_in, None, _do_fold, ["AttrStmt"]) - - return stmt_out - - -def inject_alu_intrin(stmt_in): - """Pass to inject ALU micro-ops. - - Parameters - ---------- - stmt_in : Stmt - Input statement - - Returns - ------- - stmt_out : Stmt - Transformed statement - """ - env = get_env() - idxm = tvm.tir.indexmod - analyzer = tvm.arith.Analyzer() - - def _do_fold(stmt): - def _equal(x, y): - return tvm.ir.structural_equal(analyzer.simplify(x - y), 0) - - def _flatten_loop(src_coeff, dst_coeff, extents): - src_coeff = list(src_coeff) - dst_coeff = list(dst_coeff) - extents = list(extents) - rev_src_coeff = [src_coeff.pop()] - rev_dst_coeff = [dst_coeff.pop()] - rev_extents = [] - assert src_coeff - vsrc = src_coeff.pop() - vdst = dst_coeff.pop() - vext = extents.pop() - while src_coeff: - next_src = src_coeff.pop() - next_dst = dst_coeff.pop() - next_ext = extents.pop() - - if _equal(next_src, vsrc * vext) and _equal(next_dst, vdst * vext): - vext = analyzer.simplify(vext * next_ext) - else: - rev_src_coeff.append(vsrc) - rev_dst_coeff.append(vdst) - rev_extents.append(vext) - vsrc = next_src - vdst = next_dst - vext = next_ext - rev_src_coeff.append(vsrc) - rev_dst_coeff.append(vdst) - rev_extents.append(vext) - rev_src_coeff.reverse() - rev_dst_coeff.reverse() - rev_extents.reverse() - - return rev_src_coeff, rev_dst_coeff, rev_extents - - if _match_pragma(stmt, "alu"): - # Get to the innermost loop body - loop_body = stmt.body - nest_size = 0 - while isinstance(loop_body, tvm.tir.For): - loop_body = loop_body.body - nest_size += 1 - # Get the src/dst arguments - dst_var = loop_body.buffer_var - dst_idx = loop_body.index - # Derive loop variables and extents - tmp_body = stmt.body - indices = [] - extents = [] - for _ in range(nest_size): - indices.append(tmp_body.loop_var) - extents.append(tmp_body.extent) - tmp_body = tmp_body.body - # Derive opcode - if isinstance(loop_body.value, tvm.tir.Add): - alu_opcode = env.dev.ALU_OPCODE_ADD - lhs = loop_body.value.a - rhs = loop_body.value.b - elif isinstance(loop_body.value, tvm.tir.Sub): - alu_opcode = env.dev.ALU_OPCODE_SUB - lhs = loop_body.value.a - rhs = loop_body.value.b - elif isinstance(loop_body.value, tvm.tir.Mul): - alu_opcode = env.dev.ALU_OPCODE_MUL - lhs = loop_body.value.a - rhs = loop_body.value.b - elif isinstance(loop_body.value, tvm.tir.Min): - alu_opcode = env.dev.ALU_OPCODE_MIN - lhs = loop_body.value.a - rhs = loop_body.value.b - elif isinstance(loop_body.value, tvm.tir.Max): - alu_opcode = env.dev.ALU_OPCODE_MAX - lhs = loop_body.value.a - rhs = loop_body.value.b - elif isinstance(loop_body.value, tvm.tir.Call): - if loop_body.value.name == 'shift_left': - alu_opcode = env.dev.ALU_OPCODE_SHR - lhs = loop_body.value.args[0] - rhs = analyzer.simplify(-loop_body.value.args[1]) - elif loop_body.value.name == 'shift_right': - alu_opcode = env.dev.ALU_OPCODE_SHR - lhs = loop_body.value.args[0] - rhs = loop_body.value.args[1] - else: - raise RuntimeError( - "Function call not recognized %s" % (loop_body.value.name)) - elif isinstance(loop_body.value, tvm.tir.Load): - alu_opcode = env.dev.ALU_OPCODE_SHR - lhs = loop_body.value - rhs = tvm.tir.const(0, "int32") - else: - raise RuntimeError( - "Expression not recognized %s, %s, %s" % ( - type(loop_body.value), str(loop_body.value), str(stmt))) - - # Derive array index coefficients - dst_coeff = tvm.arith.detect_linear_equation(dst_idx, indices) - # Check if lhs/rhs is immediate - use_imm = False - imm_val = None - if isinstance(rhs, tvm.tir.IntImm): - assert lhs.buffer_var.same_as(dst_var) - src_coeff = tvm.arith.detect_linear_equation(lhs.index, indices) - use_imm = True - imm_val = rhs - if isinstance(lhs, tvm.tir.IntImm): - assert rhs.buffer_var.same_as(dst_var) - src_coeff = tvm.arith.detect_linear_equation(rhs.index, indices) - use_imm = True - imm_val = lhs - if imm_val is None: - imm_val = 0 - assert lhs.buffer_var.same_as(dst_var) and rhs.buffer_var.same_as(dst_var) - src_lhs_coeff = tvm.arith.detect_linear_equation(lhs.index, indices) - src_rhs_coeff = tvm.arith.detect_linear_equation(rhs.index, indices) - # Determine which side has the same coefficients - lhs_equal = True - rhs_equal = True - for i, coef in enumerate(dst_coeff): - if not tvm.ir.structural_equal(coef, src_lhs_coeff[i]): - lhs_equal = False - if not tvm.ir.structural_equal(coef, src_rhs_coeff[i]): - rhs_equal = False - # Make sure at least one of the source is identical to the - # destination (in-place computation) - assert lhs_equal or rhs_equal - # Assign the source coefficients - if lhs_equal: - src_coeff = src_rhs_coeff - else: - src_coeff = src_lhs_coeff - - # Ensure that we have the proper tensor dimensions in the - # innermost loop (pattern match) - src_coeff = list(src_coeff) - dst_coeff = list(dst_coeff) - extents = list(extents) - assert len(src_coeff) > 1 - assert len(dst_coeff) > 1 - assert len(extents) != 0 - assert tvm.ir.structural_equal( - analyzer.simplify( - idxm(src_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0) - assert tvm.ir.structural_equal( - analyzer.simplify( - idxm(dst_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0) - assert tvm.ir.structural_equal(src_coeff[-2], 1) - assert tvm.ir.structural_equal(dst_coeff[-2], 1) - if env.BATCH > 1: - assert len(src_coeff) > 2 - assert len(dst_coeff) > 2 - assert len(extents) > 1 - assert tvm.ir.structural_equal(src_coeff[-3], env.BLOCK_OUT) - assert tvm.ir.structural_equal(dst_coeff[-3], env.BLOCK_OUT) - - # Apply tensorization of the loop coefficients - src_offset = src_coeff[-1] - dst_offset = dst_coeff[-1] - if env.BATCH == 1: - src_coeff = src_coeff[:-2] - dst_coeff = dst_coeff[:-2] - extents = extents[:-1] - else: - src_coeff = src_coeff[:-3] - dst_coeff = dst_coeff[:-3] - extents = extents[:-2] - src_coeff.append(src_offset) - dst_coeff.append(dst_offset) - src_coeff = [ - analyzer.simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in src_coeff] - dst_coeff = [ - analyzer.simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in dst_coeff] - - # Flatten the outer loops - if extents: - src_coeff, dst_coeff, extents = _flatten_loop(src_coeff, dst_coeff, extents) - - # Insert ALU micro-ops - irb = tvm.tir.ir_builder.create() - for idx, extent in enumerate(extents): - irb.emit(tvm.tir.call_extern( - "int32", "VTAUopLoopBegin", - extent, dst_coeff[idx], src_coeff[idx], 0)) - use_imm = int(use_imm) - irb.emit(tvm.tir.call_extern( - "int32", "VTAUopPush", - 1, 0, - dst_coeff[len(dst_coeff)-1], - src_coeff[len(src_coeff)-1], - 0, - alu_opcode, use_imm, imm_val)) - for extent in extents: - irb.emit(tvm.tir.call_extern( - "int32", "VTAUopLoopEnd")) - return irb.get() - return stmt - - stmt_out = tvm.tir.ir_pass.IRTransform( - stmt_in, None, _do_fold, ["AttrStmt"]) - return stmt_out - - -def debug_print(stmt): - """A debug pass that print the stmt - - Parameters - ---------- - stmt : Stmt - The input statement - - Returns - ------- - stmt : Stmt - The - """ - # pylint: disable=superfluous-parens - print(stmt) - return stmt diff --git a/vta/python/vta/transform.py b/vta/python/vta/transform.py new file mode 100644 index 000000000000..f930b3f1e59c --- /dev/null +++ b/vta/python/vta/transform.py @@ -0,0 +1,962 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Additional Transformation Passes. for VTA""" +# pylint: disable=len-as-condition, no-else-return, unused-argument, invalid-name +import tvm +from tvm import te +from topi import util + +from .environment import get_env + + +def _match_pragma(stmt, key): + """Internal helper to match stmt to pragma stmt. + + Parameters + ---------- + stmt : Stmt + The AttrStmt + + key : str + The pragma key + """ + return ((stmt.attr_key == "pragma_" + key) or + (stmt.attr_key == "pragma_scope" and stmt.value.value == key)) + + +def FoldUopLoop(): + """Detect and fold uop loop. + + VTA support uop programming model + that recognizes loop structure. + This pass detect the loop structure + and extract that into uop loop AST. + + Returns + ------- + fpass : tvm.transform.Pass + The pass + """ + def _fold_outermost_loop(body): + stmt = body + if not isinstance(stmt, tvm.tir.For): + return None, body, None + + loop_var = stmt.loop_var + gemm_offsets = [None, None, None] + fail = [False] + + def _post_order(op): + assert isinstance(op, tvm.tir.Call) + base_args = 2 + if op.name == "VTAUopPush": + args = [] + args += op.args[:base_args] + for i in range(3): + m = tvm.arith.detect_linear_equation( + op.args[i + base_args], [loop_var]) + if not m: + fail[0] = True + return op + if gemm_offsets[i] is not None: + if not tvm.ir.structural_equal(m[0], gemm_offsets[i]): + fail[0] = True + return op + args.append(m[1]) + else: + gemm_offsets[i] = m[0] + args.append(m[1]) + args += op.args[base_args+3:] + return tvm.tir.call_extern("int32", "VTAUopPush", *args) + if op.name not in ("VTATLSCommandHandle", "tvm_thread_context"): + raise RuntimeError("unexpected op %s" % op) + return op + + ret = tvm.tir.ir_pass.IRTransform( + stmt.body, None, _post_order, ["Call"]) + + if not fail[0] and all(x is not None for x in gemm_offsets): + def _visit(op): + if op.same_as(loop_var): + fail[0] = True + tvm.tir.ir_pass.PostOrderVisit(ret, _visit) + if not fail[0]: + begin = tvm.tir.call_extern( + "int32", "VTAUopLoopBegin", stmt.extent, *gemm_offsets) + end = tvm.tir.call_extern("int32", "VTAUopLoopEnd") + return [begin, ret, end] + raise ValueError("Failed to fold the GEMM instructions..") + + def _do_fold(stmt): + env = get_env() + if (stmt.attr_key == "coproc_uop_scope" and + isinstance(stmt.value, tvm.tir.StringImm) and + stmt.value.value == env.dev.vta_push_uop.value): + body = stmt.body + begins = [] + ends = [] + try: + begin, body, end = _fold_outermost_loop(body) + if begin is not None: + begins.append(begin) + if end is not None: + ends.append(end) + begin, body, end = _fold_outermost_loop(body) + if begin is not None: + begins.append(begin) + if end is not None: + ends.append(end) + except ValueError: + pass + if body == stmt.body: + return stmt + ends = list(reversed(ends)) + body = tvm.tir.stmt_seq(*(begins + [body] + ends)) + return tvm.tir.AttrStmt( + stmt.node, stmt.attr_key, stmt.value, body) + return None + + def _ftransform(f, mod, ctx): + return f.with_body(tvm.tir.ir_pass.IRTransform( + f.body, _do_fold, None, ["AttrStmt"])) + + return tvm.tir.transform.prim_func_pass( + _ftransform, opt_level=0, name="tir.vta.FoldUopLoop") + + +def CPUAccessRewrite(): + """Detect CPU access to VTA buffer and get address correctly. + + VTA's buffer is an opaque handle that do not + correspond to address in CPU. + This pass detect CPU access and rewrite to use pointer + returned VTABufferCPUPtr for CPU access. + + Returns + ------- + fpass : tvm.transform.Pass + The pass + """ + def _ftransform(f, mod, ctx): + rw_info = {} + env = get_env() + def _post_order(op): + if isinstance(op, tvm.tir.Allocate): + buffer_var = op.buffer_var + if not buffer_var in rw_info: + return None + new_var = rw_info[buffer_var] + let_stmt = tvm.tir.LetStmt( + new_var, tvm.tir.call_extern( + "handle", "VTABufferCPUPtr", + env.dev.command_handle, + buffer_var), op.body) + alloc = tvm.tir.Allocate( + buffer_var, op.dtype, op.extents, + op.condition, let_stmt) + del rw_info[buffer_var] + return alloc + if isinstance(op, tvm.tir.Load): + buffer_var = op.buffer_var + if not buffer_var in rw_info: + rw_info[buffer_var] = te.var( + buffer_var.name + "_ptr", "handle") + new_var = rw_info[buffer_var] + return tvm.tir.Load(op.dtype, new_var, op.index) + if isinstance(op, tvm.tir.Store): + buffer_var = op.buffer_var + if not buffer_var in rw_info: + rw_info[buffer_var] = te.var( + buffer_var.name + "_ptr", "handle") + new_var = rw_info[buffer_var] + return tvm.tir.Store(new_var, op.value, op.index) + raise RuntimeError("not reached") + + stmt_in = f.body + stmt = tvm.tir.ir_pass.IRTransform( + stmt_in, None, _post_order, ["Allocate", "Load", "Store"]) + + for buffer_var, new_var in rw_info.items(): + stmt = tvm.tir.LetStmt( + new_var, tvm.tir.call_extern( + "handle", "VTABufferCPUPtr", + env.dev.command_handle, + buffer_var), stmt) + return f.with_body(stmt) + return tvm.tir.transform.prim_func_pass( + _ftransform, opt_level=0, name="tir.vta.CPUAccessRewrite") + + +def LiftAllocToScopeBegin(): + """Lift allocate to beginning of the current scope. + + Returns + ------- + fpass : tvm.transform.Pass + The pass + """ + def _ftransform(f, mod, ctx): + lift_stmt = [[]] + def _merge_block(slist, body): + for op in slist: + if op.body == body: + body = op + elif isinstance(op, tvm.tir.Allocate): + body = tvm.tir.Allocate( + op.buffer_var, op.dtype, + op.extents, op.condition, body) + elif isinstance(op, tvm.tir.AttrStmt): + body = tvm.tir.AttrStmt( + op.node, op.attr_key, op.value, body) + elif isinstance(op, tvm.tir.For): + body = tvm.tir.For( + op.loop_var, op.min, op.extent, op.for_type, + op.device_api, body) + else: + raise RuntimeError("unexpected op") + del slist[:] + return body + + def _pre_order(op): + if isinstance(op, tvm.tir.For): + lift_stmt.append([]) + elif isinstance(op, tvm.tir.AttrStmt): + if op.attr_key == "virtual_thread": + lift_stmt.append([]) + + def _post_order(op): + if isinstance(op, tvm.tir.Allocate): + lift_stmt[-1].append(op) + return op.body + if isinstance(op, tvm.tir.AttrStmt): + if op.attr_key == "storage_scope": + lift_stmt[-1].append(op) + return op.body + if op.attr_key == "virtual_thread": + return _merge_block(lift_stmt.pop() + [op], op.body) + return op + if isinstance(op, tvm.tir.For): + return _merge_block(lift_stmt.pop() + [op], op.body) + raise RuntimeError("not reached") + stmt_in = f.body + stmt = tvm.tir.ir_pass.IRTransform( + stmt_in, _pre_order, _post_order, ["Allocate", "AttrStmt", "For"]) + assert len(lift_stmt) == 1 + return f.with_body(_merge_block(lift_stmt[0], stmt)) + + return tvm.tir.transform.prim_func_pass( + _ftransform, opt_level=0, name="tir.vta.LiftAllocToScopeBegin") + + +def InjectSkipCopy(): + """Pass to inject skip copy stmt, used for debug purpose. + + Returns + ------- + fpass : tvm.transform.Pass + The pass + """ + def _do_fold(stmt): + if _match_pragma(stmt, "skip_dma_copy"): + return tvm.tir.Evaluate(0) + return None + + def _ftransform(f, mod, ctx): + return f.with_body(tvm.tir.ir_pass.IRTransform( + f.body, _do_fold, None, ["AttrStmt"])) + + return tvm.tir.transform.prim_func_pass( + _ftransform, opt_level=0, name="tir.vta.InjectSkipCopy") + + +def InjectCoProcSync(): + """Pass inject coproc sync + + Returns + ------- + fpass : tvm.transform.Pass + The pass + """ + def _ftransform(f, *_): + success = [False] + def _do_fold(stmt): + if _match_pragma(stmt, "coproc_sync"): + success[0] = True + sync = tvm.tir.Call( + "int32", "vta.coproc_sync", [], tvm.tir.Call.Intrinsic, None, 0) + return tvm.tir.SeqStmt([stmt.body, tvm.tir.Evaluate(sync)]) + if _match_pragma(stmt, "trim_loop"): + op = stmt.body + assert isinstance(op, tvm.tir.For) + return tvm.tir.For( + op.loop_var, op.min, 2, op.for_type, + op.device_api, op.body) + return None + return f.with_body(tvm.tir.ir_pass.IRTransform( + f.body, None, _do_fold, ["AttrStmt"])) + return tvm.transform.Sequential( + [tvm.tir.transform.prim_func_pass(_ftransform, 0, "tir.vta.InjectCoProcSync"), + tvm.tir.transform.CoProcSync()], + opt_level=0, name="tir.vta.InjectCoProcSync") + + +def InjectDMAIntrin(): + """Pass to inject DMA copy intrinsics. + + Returns + ------- + fpass : tvm.transform.Pass + The pass + """ + idxd = tvm.tir.indexdiv + idxm = tvm.tir.indexmod + + def _check_compact(buf): + ndim = len(buf.shape) + size = tvm.tir.const(1, buf.shape[0].dtype) + for i in reversed(range(ndim)): + if not util.equal_const_int(size - buf.strides[i], 0): + raise RuntimeError( + "Cannot prove compact: shape=%s, strides=%s" % (buf.shape, buf.strides)) + size = size * buf.shape[i] + + def _fold_buffer_dim(buf, scope, elem_block): + ndim = len(buf.shape) + x_size = 1 + base = 0 + for i in range(1, ndim + 1): + if not util.equal_const_int(buf.strides[ndim - i] - x_size, 0): + raise RuntimeError("scope %s needs to have block=%d" % (scope, elem_block)) + x_size = x_size * buf.shape[ndim - i] + if util.equal_const_int(x_size - elem_block, 0): + base = i + 1 + break + if base == 0: + raise RuntimeError("scope %s need to have block=%d, shape=%s" % ( + scope, elem_block, buf.shape)) + shape = [elem_block] + strides = [1] + + if base < ndim + 1 and not util.equal_const_int(buf.strides[ndim - base], elem_block): + shape.append(1) + strides.append(elem_block) + + analyzer = tvm.arith.Analyzer() + while base < ndim + 1: + x_size = 1 + x_stride = buf.strides[ndim - base] + next_base = base + if not util.equal_const_int(idxm(x_stride, elem_block), 0): + raise RuntimeError( + "scope %s need to have block=%d, shape=%s, strides=%s" % ( + scope, elem_block, buf.shape, buf.strides)) + for i in range(base, ndim + 1): + k = ndim - i + if not util.equal_const_int(x_size * x_stride - buf.strides[k], 0): + break + x_size = x_size * buf.shape[k] + next_base = i + 1 + shape.append(analyzer.simplify(x_size)) + strides.append(x_stride) + assert next_base != base + base = next_base + + strides = list(reversed(strides)) + shape = list(reversed(shape)) + return shape, strides + + def _get_2d_pattern(buf, elem_width, elem_bytes, dtype, scope, allow_fold): + elem_block = elem_bytes * 8 // elem_width + if buf.dtype != dtype: + raise RuntimeError("Expect buffer type to be %s instead of %s" % + (dtype, buf.dtype)) + shape, strides = buf.shape, buf.strides + if not util.equal_const_int(idxm(buf.elem_offset, elem_block), 0): + raise RuntimeError("scope %s need to have block=%d" % (scope, elem_block)) + if allow_fold: + shape, strides = _fold_buffer_dim(buf, scope, elem_block) + else: + shape = list(x for x in shape) + strides = list(x for x in strides) + + def raise_error(): + """Internal function to raise error """ + raise RuntimeError( + ("Scope[%s]: cannot detect 2d pattern with elem_block=%d:" + + " shape=%s, strides=%s") % (scope, elem_block, buf.shape, buf.strides)) + + ndim = len(shape) + + # Check if the inner-tensor is already flat + flat = util.equal_const_int(shape[-1], elem_block) + + if flat: + if not util.equal_const_int(strides[-1], 1): + raise_error() + + if ndim == 1: + x_size = 1 + x_stride = 1 + y_size = 1 + return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block) + if not util.equal_const_int(strides[-2] - elem_block, 0): + raise_error() + + if ndim == 2: + x_size = shape[-2] + x_stride = shape[-2] + y_size = 1 + return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block) + if not util.equal_const_int(idxm(strides[-3], elem_block), 0): + raise_error() + + if ndim == 3: + x_size = shape[-2] + x_stride = idxd(strides[-3], elem_block) + y_size = shape[-3] + return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block) + + else: + if not util.equal_const_int(strides[-1], 1): + raise_error() + if not util.equal_const_int(strides[-2] - shape[-1], 0): + raise_error() + if not util.equal_const_int(shape[-1] * shape[-2], elem_block): + raise_error() + + if ndim == 2: + x_size = 1 + x_stride = 1 + y_size = 1 + return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block) + if not util.equal_const_int(strides[-3], elem_block): + raise_error() + + if ndim == 3: + x_size = shape[-3] + x_stride = shape[-3] + y_size = 1 + return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block) + if not util.equal_const_int(idxm(strides[-4], elem_block), 0): + raise_error() + + if ndim == 4: + x_size = shape[-3] + x_stride = idxd(strides[-4], elem_block) + y_size = shape[-4] + return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block) + + raise_error() + + + def _inject_copy(src, dst, pad_before, pad_after, pad_value): + # FIXME: pad_value is ignored... + env = get_env() + _ = pad_value + if dst.scope == "global": + # Store + if pad_before or pad_after: + raise RuntimeError("Do not support copy into DRAM with pad") + if src.scope == env.acc_scope: + elem_width = env.OUT_WIDTH + elem_bytes = env.OUT_ELEM_BYTES + mem_type = env.dev.MEM_ID_OUT + data_type = "int%d" % env.OUT_WIDTH + task_qid = env.dev.QID_STORE_OUT + else: + raise RuntimeError("Do not support copy %s->dram" % (src.scope)) + _check_compact(src) + x_size, y_size, x_stride, offset = _get_2d_pattern( + dst, elem_width, elem_bytes, data_type, src.scope, allow_fold=True) + irb = tvm.tir.ir_builder.create() + irb.scope_attr(env.dev.vta_axis, "coproc_scope", + env.dev.get_task_qid(task_qid)) + irb.emit(tvm.tir.call_extern( + "int32", "VTAStoreBuffer2D", + env.dev.command_handle, + src.access_ptr("r", "int32"), + mem_type, dst.data, offset, x_size, y_size, x_stride)) + return irb.get() + elif src.scope == "global": + if dst.scope == env.acc_scope: + elem_width = env.ACC_WIDTH + elem_bytes = env.ACC_ELEM_BYTES + mem_type = env.dev.MEM_ID_ACC + data_type = "int%d" % env.ACC_WIDTH + task_qid = env.dev.QID_LOAD_OUT + elif dst.scope == env.inp_scope: + elem_width = env.INP_WIDTH + elem_bytes = env.INP_ELEM_BYTES + mem_type = env.dev.MEM_ID_INP + data_type = "int%d" % env.INP_WIDTH + task_qid = env.dev.QID_LOAD_INP + elif dst.scope == env.wgt_scope: + elem_width = env.WGT_WIDTH + elem_bytes = env.WGT_ELEM_BYTES + mem_type = env.dev.MEM_ID_WGT + data_type = "int%d" % env.WGT_WIDTH + task_qid = env.dev.QID_LOAD_WGT + else: + raise RuntimeError("Do not support copy dram->%s" % (dst.scope)) + # collect pad statistics + if pad_before: + assert pad_after + ndim = len(pad_before) + if ndim <= 2 or ndim > 5: + raise ValueError("Limitation of 2D pad load forbid ndim=%d" % ndim) + if ndim == 5: + # This case occurs when batch size N > 1 + y_pad_before = pad_before[1] + x_pad_before = pad_before[2] + y_pad_after = pad_after[1] + x_pad_after = pad_after[2] + for dim in range(3, ndim): + if not util.equal_const_int(pad_before[dim], 0): + raise ValueError("Do not support pad on the innermost block") + if not util.equal_const_int(pad_after[dim], 0): + raise ValueError("Do not support pad on the innermost block") + else: + y_pad_before = pad_before[0] + x_pad_before = pad_before[1] + y_pad_after = pad_after[0] + x_pad_after = pad_after[1] + for dim in range(2, ndim): + if not util.equal_const_int(pad_before[dim], 0): + raise ValueError("Do not support pad on the innermost block") + if not util.equal_const_int(pad_after[dim], 0): + raise ValueError("Do not support pad on the innermost block") + allow_fold = False + else: + x_pad_before = 0 + y_pad_before = 0 + x_pad_after = 0 + y_pad_after = 0 + allow_fold = True + + _check_compact(dst) + x_size, y_size, x_stride, offset = _get_2d_pattern( + src, elem_width, elem_bytes, data_type, + dst.scope, allow_fold=allow_fold) + + irb = tvm.tir.ir_builder.create() + irb.scope_attr(env.dev.vta_axis, "coproc_scope", + env.dev.get_task_qid(task_qid)) + + irb.emit(tvm.tir.call_extern( + "int32", "VTALoadBuffer2D", + env.dev.command_handle, + src.data, offset, x_size, y_size, x_stride, + x_pad_before, y_pad_before, + x_pad_after, y_pad_after, + dst.access_ptr("r", "int32"), mem_type)) + return irb.get() + + else: + raise RuntimeError("Do not support copy %s->%s" % (src.scope, dst.scope)) + + return tvm.tir.transform.InjectCopyIntrin("dma_copy", _inject_copy) + + +def _get_gemm_intrin_buffer(): + env = get_env() + wgt_lanes = env.WGT_ELEM_BITS // env.WGT_WIDTH + assert wgt_lanes == env.BLOCK_OUT * env.BLOCK_IN + wgt_shape = (env.BLOCK_OUT, env.BLOCK_IN) + assert wgt_shape[0] * wgt_shape[1] == wgt_lanes + inp_lanes = env.INP_ELEM_BITS // env.INP_WIDTH + assert inp_lanes == env.BATCH * env.BLOCK_IN + inp_shape = (env.BATCH, env.BLOCK_IN) + assert inp_shape[0] * inp_shape[1] == inp_lanes + out_lanes = env.ACC_ELEM_BITS // env.ACC_WIDTH + assert out_lanes == env.BATCH * env.BLOCK_OUT + out_shape = (env.BATCH, env.BLOCK_OUT) + assert out_shape[0] * out_shape[1] == out_lanes + wgt = te.placeholder((wgt_shape[0], wgt_shape[1]), + dtype="int%d" % env.WGT_WIDTH, + name=env.wgt_scope) + inp = te.placeholder((inp_shape[0], inp_shape[1]), + dtype="int%d" % env.INP_WIDTH, + name=env.inp_scope) + k = te.reduce_axis((0, wgt_shape[1]), name="k") + out_dtype = "int%d" % env.ACC_WIDTH + out = te.compute((out_shape[0], out_shape[1]), + lambda i, j: te.sum(inp[i, k].astype(out_dtype) * + wgt[j, k].astype(out_dtype), + axis=[k]), + name="out") + wgt_layout = tvm.tir.decl_buffer( + wgt.shape, wgt.dtype, env.wgt_scope, + scope=env.wgt_scope, offset_factor=wgt_lanes, data_alignment=wgt_lanes) + inp_layout = tvm.tir.decl_buffer( + inp.shape, inp.dtype, env.inp_scope, + scope=env.inp_scope, offset_factor=inp_lanes, data_alignment=inp_lanes) + out_layout = tvm.tir.decl_buffer( + out.shape, out.dtype, env.acc_scope, + scope=env.acc_scope, offset_factor=out_lanes, data_alignment=out_lanes) + + return wgt_layout, inp_layout, out_layout + + +def InjectConv2DTransposeSkip(): + """Pass to skip 0-weights in conv2d transpose with stride > 1. + + Returns + ------- + fpass : tvm.transform.Pass + The pass + """ + def _ftransform(func, mod, ctx): + env = get_env() + dwgt, dinp, dout = _get_gemm_intrin_buffer() + + calls = [] + selects = [] + + def _find_basics(op): + if isinstance(op, tvm.tir.BufferLoad): + calls.append(op) + elif isinstance(op, tvm.tir.Select): + selects.append(op) + + def _do_fold(op): + if _match_pragma(op, "conv2d_transpose_gemm"): + is_init = ".init" in str(op) + tvm.tir.ir_pass.PostOrderVisit(op, _find_basics) + + if is_init: + # create inner most block + irb = tvm.tir.ir_builder.create() + dev = env.dev + irb.scope_attr(dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE)) + irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop) + irb.emit(tvm.tir.call_extern("int32", "VTAUopPush", + 0, 1, + dout.access_ptr("rw", "int32"), + 0, 0, + 0, 0, 0)) + inner = irb.get() + # TODO(@tmoreau89): This is only a temporary fix, please take a look. + body = op.body.body + while isinstance(body, tvm.tir.IfThenElse): + body = body.then_case + args = body.indices + res_buffer = body.buffer + tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, env.BLOCK_OUT) + inner = tvm.tir.AttrStmt( + [dout, res_buffer], 'buffer_bind_scope', + tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner) + return inner + else: + conv_call, data_call, kernel_call = calls[-3:] + pad_data_tensor = data_call.buffer + kernel_tensor = kernel_call.buffer + res_tensor = conv_call.buffer + + if selects: + condition = selects[0].condition + else: + condition = tvm.tir.const(1, 'int') + + # create inner most block + irb = tvm.tir.ir_builder.create() + with irb.if_scope(condition): + dev = env.dev + irb.scope_attr( + dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE)) + irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop) + irb.emit(tvm.tir.call_extern("int32", "VTAUopPush", + 0, 0, + dout.access_ptr("rw", "int32"), + dinp.access_ptr("r", "int32"), + dwgt.access_ptr("r", "int32"), + 0, 0, 0)) + inner = irb.get() + + args = conv_call.indices + tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], + 1, 0, 1, 0, env.BLOCK_OUT) + inner = tvm.tir.AttrStmt( + [dout, res_tensor], 'buffer_bind_scope', + tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner) + args = kernel_call.indices + tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], + 1, 0, env.BLOCK_OUT, 0, env.BLOCK_IN) + inner = tvm.tir.AttrStmt( + [dwgt, kernel_tensor], 'buffer_bind_scope', + tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner) + args = data_call.indices + tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], + 1, 0, 1, 0, env.BLOCK_IN) + inner = tvm.tir.AttrStmt( + [dinp, pad_data_tensor], 'buffer_bind_scope', + tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner) + return inner + return None + + return func.with_body(tvm.tir.ir_pass.IRTransform( + func.body, _do_fold, None, ["AttrStmt"])) + return tvm.tir.transform.prim_func_pass( + _ftransform, opt_level=0, name="tir.vta.InjectConv2DTrasnposeSkip") + + +def AnnotateALUCoProcScope(): + """Pass to insert ALU instruction. + + Returns + ------- + fpass : tvm.transform.Pass + The pass + """ + def _ftransform(func, mod, ctx): + env = get_env() + def _do_fold(stmt): + if _match_pragma(stmt, "alu"): + irb = tvm.tir.ir_builder.create() + irb.scope_attr(env.dev.vta_axis, "coproc_scope", + env.dev.get_task_qid(env.dev.QID_COMPUTE)) + irb.scope_attr(env.dev.vta_axis, "coproc_uop_scope", + tvm.tir.StringImm("VTAPushALUOp")) + irb.emit(stmt) + return irb.get() + if _match_pragma(stmt, "skip_alu"): + return tvm.tir.Evaluate(0) + return stmt + + return func.with_body(tvm.tir.ir_pass.IRTransform( + func.body, None, _do_fold, ["AttrStmt"])) + return tvm.tir.transform.prim_func_pass( + _ftransform, opt_level=0, name="tir.vta.AnnotateALUCoProcScope") + + +def InjectALUIntrin(): + """Pass to inject ALU micro-ops. + + Returns + ------- + fpass : tvm.transform.Pass + The pass + """ + def _ftransform(func, mod, ctx): + env = get_env() + idxm = tvm.tir.indexmod + analyzer = tvm.arith.Analyzer() + + def _do_fold(stmt): + def _equal(x, y): + return tvm.ir.structural_equal(analyzer.simplify(x - y), 0) + + def _flatten_loop(src_coeff, dst_coeff, extents): + src_coeff = list(src_coeff) + dst_coeff = list(dst_coeff) + extents = list(extents) + rev_src_coeff = [src_coeff.pop()] + rev_dst_coeff = [dst_coeff.pop()] + rev_extents = [] + assert src_coeff + vsrc = src_coeff.pop() + vdst = dst_coeff.pop() + vext = extents.pop() + while src_coeff: + next_src = src_coeff.pop() + next_dst = dst_coeff.pop() + next_ext = extents.pop() + + if _equal(next_src, vsrc * vext) and _equal(next_dst, vdst * vext): + vext = analyzer.simplify(vext * next_ext) + else: + rev_src_coeff.append(vsrc) + rev_dst_coeff.append(vdst) + rev_extents.append(vext) + vsrc = next_src + vdst = next_dst + vext = next_ext + rev_src_coeff.append(vsrc) + rev_dst_coeff.append(vdst) + rev_extents.append(vext) + rev_src_coeff.reverse() + rev_dst_coeff.reverse() + rev_extents.reverse() + + return rev_src_coeff, rev_dst_coeff, rev_extents + + if _match_pragma(stmt, "alu"): + # Get to the innermost loop body + loop_body = stmt.body + nest_size = 0 + while isinstance(loop_body, tvm.tir.For): + loop_body = loop_body.body + nest_size += 1 + # Get the src/dst arguments + dst_var = loop_body.buffer_var + dst_idx = loop_body.index + # Derive loop variables and extents + tmp_body = stmt.body + indices = [] + extents = [] + for _ in range(nest_size): + indices.append(tmp_body.loop_var) + extents.append(tmp_body.extent) + tmp_body = tmp_body.body + # Derive opcode + if isinstance(loop_body.value, tvm.tir.Add): + alu_opcode = env.dev.ALU_OPCODE_ADD + lhs = loop_body.value.a + rhs = loop_body.value.b + elif isinstance(loop_body.value, tvm.tir.Sub): + alu_opcode = env.dev.ALU_OPCODE_SUB + lhs = loop_body.value.a + rhs = loop_body.value.b + elif isinstance(loop_body.value, tvm.tir.Mul): + alu_opcode = env.dev.ALU_OPCODE_MUL + lhs = loop_body.value.a + rhs = loop_body.value.b + elif isinstance(loop_body.value, tvm.tir.Min): + alu_opcode = env.dev.ALU_OPCODE_MIN + lhs = loop_body.value.a + rhs = loop_body.value.b + elif isinstance(loop_body.value, tvm.tir.Max): + alu_opcode = env.dev.ALU_OPCODE_MAX + lhs = loop_body.value.a + rhs = loop_body.value.b + elif isinstance(loop_body.value, tvm.tir.Call): + if loop_body.value.name == 'shift_left': + alu_opcode = env.dev.ALU_OPCODE_SHR + lhs = loop_body.value.args[0] + rhs = analyzer.simplify(-loop_body.value.args[1]) + elif loop_body.value.name == 'shift_right': + alu_opcode = env.dev.ALU_OPCODE_SHR + lhs = loop_body.value.args[0] + rhs = loop_body.value.args[1] + else: + raise RuntimeError( + "Function call not recognized %s" % (loop_body.value.name)) + elif isinstance(loop_body.value, tvm.tir.Load): + alu_opcode = env.dev.ALU_OPCODE_SHR + lhs = loop_body.value + rhs = tvm.tir.const(0, "int32") + else: + raise RuntimeError( + "Expression not recognized %s, %s, %s" % ( + type(loop_body.value), str(loop_body.value), str(stmt))) + + # Derive array index coefficients + dst_coeff = tvm.arith.detect_linear_equation(dst_idx, indices) + # Check if lhs/rhs is immediate + use_imm = False + imm_val = None + if isinstance(rhs, tvm.tir.IntImm): + assert lhs.buffer_var.same_as(dst_var) + src_coeff = tvm.arith.detect_linear_equation(lhs.index, indices) + use_imm = True + imm_val = rhs + if isinstance(lhs, tvm.tir.IntImm): + assert rhs.buffer_var.same_as(dst_var) + src_coeff = tvm.arith.detect_linear_equation(rhs.index, indices) + use_imm = True + imm_val = lhs + if imm_val is None: + imm_val = 0 + assert lhs.buffer_var.same_as(dst_var) and rhs.buffer_var.same_as(dst_var) + src_lhs_coeff = tvm.arith.detect_linear_equation(lhs.index, indices) + src_rhs_coeff = tvm.arith.detect_linear_equation(rhs.index, indices) + # Determine which side has the same coefficients + lhs_equal = True + rhs_equal = True + for i, coef in enumerate(dst_coeff): + if not tvm.ir.structural_equal(coef, src_lhs_coeff[i]): + lhs_equal = False + if not tvm.ir.structural_equal(coef, src_rhs_coeff[i]): + rhs_equal = False + # Make sure at least one of the source is identical to the + # destination (in-place computation) + assert lhs_equal or rhs_equal + # Assign the source coefficients + if lhs_equal: + src_coeff = src_rhs_coeff + else: + src_coeff = src_lhs_coeff + + # Ensure that we have the proper tensor dimensions in the + # innermost loop (pattern match) + src_coeff = list(src_coeff) + dst_coeff = list(dst_coeff) + extents = list(extents) + assert len(src_coeff) > 1 + assert len(dst_coeff) > 1 + assert len(extents) != 0 + assert tvm.ir.structural_equal( + analyzer.simplify( + idxm(src_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0) + assert tvm.ir.structural_equal( + analyzer.simplify( + idxm(dst_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0) + assert tvm.ir.structural_equal(src_coeff[-2], 1) + assert tvm.ir.structural_equal(dst_coeff[-2], 1) + if env.BATCH > 1: + assert len(src_coeff) > 2 + assert len(dst_coeff) > 2 + assert len(extents) > 1 + assert tvm.ir.structural_equal(src_coeff[-3], env.BLOCK_OUT) + assert tvm.ir.structural_equal(dst_coeff[-3], env.BLOCK_OUT) + + # Apply tensorization of the loop coefficients + src_offset = src_coeff[-1] + dst_offset = dst_coeff[-1] + if env.BATCH == 1: + src_coeff = src_coeff[:-2] + dst_coeff = dst_coeff[:-2] + extents = extents[:-1] + else: + src_coeff = src_coeff[:-3] + dst_coeff = dst_coeff[:-3] + extents = extents[:-2] + src_coeff.append(src_offset) + dst_coeff.append(dst_offset) + src_coeff = [ + analyzer.simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in src_coeff] + dst_coeff = [ + analyzer.simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in dst_coeff] + + # Flatten the outer loops + if extents: + src_coeff, dst_coeff, extents = _flatten_loop(src_coeff, dst_coeff, extents) + + # Insert ALU micro-ops + irb = tvm.tir.ir_builder.create() + for idx, extent in enumerate(extents): + irb.emit(tvm.tir.call_extern( + "int32", "VTAUopLoopBegin", + extent, dst_coeff[idx], src_coeff[idx], 0)) + use_imm = int(use_imm) + irb.emit(tvm.tir.call_extern( + "int32", "VTAUopPush", + 1, 0, + dst_coeff[len(dst_coeff)-1], + src_coeff[len(src_coeff)-1], + 0, + alu_opcode, use_imm, imm_val)) + for extent in extents: + irb.emit(tvm.tir.call_extern( + "int32", "VTAUopLoopEnd")) + return irb.get() + return stmt + + return func.with_body(tvm.tir.ir_pass.IRTransform( + func.body, None, _do_fold, ["AttrStmt"])) + + return tvm.tir.transform.prim_func_pass( + _ftransform, opt_level=0, name="tir.vta.InjectALUIntrin")