From 4d3796cecd2603dc7aade5a77d9260add964baae Mon Sep 17 00:00:00 2001
From: Tianqi Chen
Date: Wed, 15 Apr 2020 11:11:28 -0700
Subject: [PATCH] [PYTHON] Enhance with_attr API, cleanup MakeAPILegacy in testcases (#5335)

---
 python/tvm/ir/function.py                     | 31 ++++++++++++++++
 python/tvm/relay/function.py                  | 19 ----------
 python/tvm/testing.py                         | 37 -------------------
 python/tvm/tir/function.py                    | 19 ----------
 src/ir/function.cc                            | 26 +++++++++++++
 src/ir/module.cc                              |  8 +++-
 src/relay/ir/function.cc                      |  6 ---
 src/tir/ir/function.cc                        |  6 ---
 .../python/unittest/test_runtime_extension.py |  3 +-
 .../unittest/test_runtime_module_load.py      |  7 +++-
 .../unittest/test_target_codegen_llvm.py      | 12 ++++--
 .../test_target_codegen_static_init.py        | 23 ++++--------
 .../unittest/test_target_codegen_vm_basic.py  | 34 +++++++----------
 tests/python/unittest/test_tir_nodes.py       |  2 +-
 .../unittest/test_tir_pass_storage_flatten.py |  4 +-
 .../test_tir_transform_lower_warp_memory.py   |  2 +-
 .../test_tir_transform_make_packed_api.py     | 10 ++---
 .../test_tir_transform_thread_sync.py         |  9 +++--
 18 files changed, 117 insertions(+), 141 deletions(-)

diff --git a/python/tvm/ir/function.py b/python/tvm/ir/function.py
index afc8c1066b1c..d28ffa68339e 100644
--- a/python/tvm/ir/function.py
+++ b/python/tvm/ir/function.py
@@ -16,6 +16,8 @@
 # under the License.
 """Function defintiions."""
 from enum import IntEnum
+import tvm.runtime
+
 from .expr import RelayExpr
 from . import _ffi_api
 
@@ -34,3 +36,32 @@ def attrs(self):
         """Return the attrs member of the function.
         """
         return _ffi_api.BaseFunc_Attrs(self)
+
+    def with_attr(self, attr_key_or_dict, attr_value=None):
+        """Create a new copy of the function and update the attribute.
+
+        Parameters
+        ----------
+        attr_key_or_dict : Union[str, dict]
+            The attribute key to use or a dict containing multiple key value pairs.
+
+        attr_value : Object
+            The new attribute value.
+
+        Returns
+        -------
+        func : Function
+            A new copy of the function
+        """
+        # make sure we first copy so that we can safely do copy on write
+        # for multiple updates.
+        res = _ffi_api.BaseFuncCopy(self)
+
+        if isinstance(attr_key_or_dict, dict):
+            for key, val in attr_key_or_dict.items():
+                res = _ffi_api.BaseFuncWithAttr(
+                    res._move(), key, tvm.runtime.convert(val))
+            return res
+
+        return _ffi_api.BaseFuncWithAttr(
+            res._move(), attr_key_or_dict, tvm.runtime.convert(attr_value))
diff --git a/python/tvm/relay/function.py b/python/tvm/relay/function.py
index 786a7f4cfc24..568dd4165160 100644
--- a/python/tvm/relay/function.py
+++ b/python/tvm/relay/function.py
@@ -65,22 +65,3 @@ def __call__(self, *args):
             Arguments.
         """
         return Call(self, args, None, None)
-
-    def with_attr(self, attr_key, attr_value):
-        """Create a new copy of the function and update the attribute
-
-        Parameters
-        ----------
-        attr_key : str
-            The attribute key to use.
-
-        attr_value : Object
-            The new attribute value.
-
-        Returns
-        -------
-        func : Function
-            A new copy of the function
-        """
-        return _ffi_api.FunctionWithAttr(
-            self, attr_key, convert(attr_value))
diff --git a/python/tvm/testing.py b/python/tvm/testing.py
index 064c43891d2d..0f50636d68d8 100644
--- a/python/tvm/testing.py
+++ b/python/tvm/testing.py
@@ -168,41 +168,4 @@ def compare_derivative(j, n_der, grad):
                 x_name, grad.shape, dist, max_diff, avg_diff)
 
 
-def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias):
-    """Legacy adapter to build a Module from statement.
-
-    Used for migrating existing test cases only.
-
-    Parameters
-    ----------
-    stmt: Stmt
-        The input statement.
-
-    name: str
-        The name of the funciton.
-
-    args: list of Buffer or Vars
-        The function arguments
-
-    num_unpacked_args: int
-        Number of unpacked arguments.
-
-    nolias: bool
-        Whether allow noalias.
-
-    Returns
-    -------
-    mod : IRModule
-        The created IRModule.
-    """
-    assert num_unpacked_args == 0
-    f = tvm.tir.PrimFunc(args, stmt).with_attr(
-        "global_symbol", tvm.runtime.String(name))
-    f = f.with_attr("tir.is_entry_func", True)
-    if noalias:
-        f = f.with_attr("tir.noalias", True)
-    mod = tvm.IRModule({name: f})
-    return mod
-
-
 tvm._ffi._init_api("testing", __name__)
diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py
index 0ed1762a889c..4ec1a71f345e 100644
--- a/python/tvm/tir/function.py
+++ b/python/tvm/tir/function.py
@@ -67,22 +67,3 @@ def __init__(self,
         self.__init_handle_by_constructor__(
             _ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs)
-
-    def with_attr(self, attr_key, attr_value):
-        """Create a new copy of the function and update the attribute
-
-        Parameters
-        ----------
-        attr_key : str
-            The attribute key to use.
-
-        attr_value : Object
-            The new attribute value.
-
-        Returns
-        -------
-        func : Function
-            A new copy of the function
-        """
-        return _ffi_api.PrimFuncWithAttr(
-            self, attr_key, tvm.runtime.convert(attr_value))
diff --git a/src/ir/function.cc b/src/ir/function.cc
index e7ccbbe73e7b..08cdc93e28b5 100644
--- a/src/ir/function.cc
+++ b/src/ir/function.cc
@@ -23,6 +23,14 @@
  */
 #include <tvm/ir/function.h>
 #include <tvm/runtime/registry.h>
+// NOTE: reverse dependency on relay, tir/
+// These dependencies do not happen at the interface-level,
+// and are only used in minimum cases where they are clearly marked.
+//
+// Rationale: We call into the type specific WithAttr function
+#include <tvm/tir/function.h>
+#include <tvm/relay/function.h>
+
 
 namespace tvm {
 
@@ -31,4 +39,22 @@ TVM_REGISTER_GLOBAL("ir.BaseFunc_Attrs")
   return func->attrs;
 });
 
+TVM_REGISTER_GLOBAL("ir.BaseFuncCopy")
+.set_body_typed([](BaseFunc func) {
+  return func;
+});
+
+TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttr")
+.set_body_typed([](BaseFunc func, std::string key, ObjectRef value) -> BaseFunc {
+  if (func->IsInstance<tir::PrimFuncNode>()) {
+    return WithAttr(Downcast<tir::PrimFunc>(std::move(func)), key, value);
+  } else if (func->IsInstance<relay::FunctionNode>()) {
+    return WithAttr(Downcast<relay::Function>(std::move(func)), key, value);
+  } else {
+    LOG(FATAL) << "Do not support function type " << func->GetTypeKey();
+    return func;
+  }
+});
+
+
 }  // namespace tvm
diff --git a/src/ir/module.cc b/src/ir/module.cc
index bcf56aacb859..6262150556c7 100644
--- a/src/ir/module.cc
+++ b/src/ir/module.cc
@@ -362,13 +362,19 @@ IRModule IRModule::FromExpr(
     const tvm::Map<GlobalTypeVar, TypeData>& type_definitions) {
   auto mod = IRModule(global_funcs, type_definitions);
   BaseFunc func;
+  std::string gv_name = "main";
+
   if (auto* func_node = expr.as<BaseFuncNode>()) {
     func = GetRef<BaseFunc>(func_node);
+    if (auto opt = func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
+      gv_name = opt.value();
+    }
+
   } else {
     func = relay::Function(
         relay::FreeVars(expr), expr, Type(),
         relay::FreeTypeVars(expr, mod), {});
   }
-  auto main_gv = GlobalVar("main");
+  auto main_gv = GlobalVar(gv_name);
   mod->Add(main_gv, func);
   return mod;
 }
diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc
index 48cb4d8a8693..12a80c5698af 100644
--- a/src/relay/ir/function.cc
+++ b/src/relay/ir/function.cc
@@ -74,11 +74,5 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       << node->attrs << ")";
 });
 
-TVM_REGISTER_GLOBAL("relay.ir.FunctionWithAttr")
-.set_body_typed(
-  [](Function func, std::string name, ObjectRef ref) {
-    return WithAttr(std::move(func), name, ref);
-  });
-
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc
index 0891c47ab58c..ecaad586f894 100644
--- a/src/tir/ir/function.cc
+++ b/src/tir/ir/function.cc
@@ -84,11 +84,5 @@ TVM_REGISTER_GLOBAL("tir.PrimFunc")
   return PrimFunc(params, body, ret_type, buffer_map, attrs);
 });
 
-
-TVM_REGISTER_GLOBAL("tir.PrimFuncWithAttr")
-.set_body_typed([](PrimFunc func, std::string name, ObjectRef ref) {
-  return WithAttr(std::move(func), name, ref);
-});
-
 }  // namespace tir
 }  // namespace tvm
diff --git a/tests/python/unittest/test_runtime_extension.py b/tests/python/unittest/test_runtime_extension.py
index 52fc8c233a12..48eaf7dd306b 100644
--- a/tests/python/unittest/test_runtime_extension.py
+++ b/tests/python/unittest/test_runtime_extension.py
@@ -39,7 +39,8 @@ def test_dltensor_compatible():
         A[i + 1] = A[i] + 1
     stmt = ib.get()
 
-    mod = tvm.testing.MakeAPILegacy(stmt, "arange", [Ab], 0, True)
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "arange"))
     f = tvm.build(mod, target="stackvm")
     a = tvm.nd.array(np.zeros(10, dtype=dtype))
     aview = MyTensorView(a)
diff --git a/tests/python/unittest/test_runtime_module_load.py b/tests/python/unittest/test_runtime_module_load.py
index f6abebd2fbc4..c7a5544f4a30 100644
--- a/tests/python/unittest/test_runtime_module_load.py
+++ b/tests/python/unittest/test_runtime_module_load.py
@@ -57,8 +57,11 @@ def save_object(names):
         tvm.tir.Store(Ab.data,
                       tvm.tir.Load(dtype, Ab.data, i) + 1,
                       i + 1))
-    m = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
-    m = tvm.driver.build(m, target="llvm")
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([Ab], stmt).with_attr(
+            "global_symbol", "main")
+    )
+    m = tvm.driver.build(mod, target="llvm")
     for name in names:
         m.save(name)
 
diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py
index 76f96d4a3ba0..44b05c90ff17 100644
--- a/tests/python/unittest/test_target_codegen_llvm.py
+++ b/tests/python/unittest/test_target_codegen_llvm.py
@@ -36,8 +36,11 @@ def test_llvm_intrin():
             "int32", "prefetch", args, tvm.tir.Call.Intrinsic, None, 0)))
     body = ib.get()
-    func = tvm.testing.MakeAPILegacy(body, "prefetch", [A], 0, True)
-    fcode = tvm.build(func, None, "llvm")
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([A], body).with_attr(
+            "global_symbol", "prefetch")
+    )
+    fcode = tvm.build(mod, None, "llvm")
 
 
 def test_llvm_overloaded_intrin():
@@ -111,8 +114,9 @@ def test_llvm_lookup_intrin():
     x = tvm.tir.call_llvm_intrin("uint8x8", "llvm.ctpop.v8i8", tvm.tir.const(1, 'uint32'), A[z])
     ib.emit(x)
     body = ib.get()
-    func = tvm.testing.MakeAPILegacy(body, "ctpop", [A], 0, True)
-    fcode = tvm.build(func, None, "llvm")
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([A], body).with_attr("global_symbol", "main"))
+    fcode = tvm.build(mod, None, "llvm")
 
 
 def test_llvm_large_uintimm():
diff --git a/tests/python/unittest/test_target_codegen_static_init.py b/tests/python/unittest/test_target_codegen_static_init.py
index 64bb698ced9a..e3b8ff3ed5ed 100644
--- a/tests/python/unittest/test_target_codegen_static_init.py
+++ b/tests/python/unittest/test_target_codegen_static_init.py
@@ -20,17 +20,6 @@
 import numpy as np
 
 
-def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias):
-    """Legacy adapter to create a API"""
-    f = tvm.tir.PrimFunc(args, stmt).with_attr(
-        "global_symbol", tvm.runtime.String(name))
-    f = f.with_attr("tir.is_entry_func", True)
-    if noalias:
-        f = f.with_attr("tir.noalias", True)
-    mod = tvm.IRModule.from_expr(f)
-    return tvm.tir.transform.MakePackedAPI()(mod)
-
-
 def test_static_callback():
     dtype = 'int64'
     n = te.size_var('n')
@@ -44,8 +33,11 @@ def test_static_callback():
     with ib.for_range(0, n, "i", for_type="parallel") as i:
         A[i] = A[i] + 1
     stmt = ib.get()
-    fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
-    f = tvm.driver.build(fapi, target="llvm")
+
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "ramp")
+    )
+    f = tvm.driver.build(mod, target="llvm")
     a = tvm.nd.array(np.zeros(10, dtype=dtype))
     f(a)
     f(a)
@@ -67,8 +59,9 @@ def test_cb(sh, A):
         return sh
 
     stmt = ib.get()
-    fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
-    f = tvm.driver.build(fapi, target="llvm")
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "ramp"))
+    f = tvm.driver.build(mod, target="llvm")
     a = tvm.nd.array(np.zeros(10, dtype=dtype))
     f(a)
 
diff --git a/tests/python/unittest/test_target_codegen_vm_basic.py b/tests/python/unittest/test_target_codegen_vm_basic.py
index 6af4279e6ffd..e03d689b6c5f 100644
--- a/tests/python/unittest/test_target_codegen_vm_basic.py
+++ b/tests/python/unittest/test_target_codegen_vm_basic.py
@@ -26,18 +26,6 @@ def run_jit(fapi, check):
     s = f.get_source()
     check(f)
 
-
-def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias):
-    """Legacy adapter to create a API"""
-    f = tvm.tir.PrimFunc(args, stmt).with_attr(
-        "global_symbol", tvm.runtime.String(name))
-    f = f.with_attr("tir.is_entry_func", True)
-    if noalias:
-        f = f.with_attr("tir.noalias", True)
-    mod = tvm.IRModule.from_expr(f)
-    return tvm.tir.transform.MakePackedAPI()(mod)
-
-
 def test_stack_vm_basic():
     a = tvm.nd.array(np.zeros(10, dtype='float32'))
     @tvm.register_func
@@ -48,8 +36,11 @@ def tvm_call_back_get_shape(shape0):
     n = te.size_var('n')
     Ab = tvm.tir.decl_buffer((n, ), "float32")
     stmt = tvm.tir.Evaluate(tvm.tir.call_packed("tvm_call_back_get_shape", Ab.shape[0]))
-    fapi = tvm.testing.MakeAPILegacy(stmt, "print_shape", [Ab], 0, True)
-    run_jit(fapi, lambda f: f(a))
+
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "print_shape"))
+
+    run_jit(mod, lambda f: f(a))
 
 
 @tvm.register_func
@@ -69,12 +60,13 @@ def test_stack_vm_loop():
         ib.emit(tvm.tir.call_packed("tvm_stack_vm_print", i))
 
     stmt = ib.get()
-    fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "ramp"))
     a = tvm.nd.array(np.zeros(10, dtype=dtype))
     def check(f):
         f(a)
         np.testing.assert_equal(a.asnumpy(), np.arange(a.shape[0]))
-    run_jit(fapi, check)
+    run_jit(mod, check)
 
 
 def test_stack_vm_cond():
@@ -91,14 +83,15 @@ def test_stack_vm_cond():
             A[i + 1] = A[i] + 2
 
     stmt = ib.get()
-    fapi = tvm.testing.MakeAPILegacy(stmt, "test", [Ab], 0, True)
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "test"))
    def check(f):
         a = tvm.nd.array(np.zeros(10, dtype=dtype))
         f(a)
         y = np.arange(a.shape[0]) * 2
         y[5:] -= 1
         np.testing.assert_equal(a.asnumpy(), y)
-    run_jit(fapi, check)
+    run_jit(mod, check)
 
 def test_vm_parallel():
     dtype = 'int64'
@@ -110,12 +103,13 @@ def test_vm_parallel():
     with ib.for_range(0, n, "i", for_type="parallel") as i:
         A[i] = A[i] + 1
     stmt = ib.get()
-    fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "test"))
     def check(f):
         a = tvm.nd.array(np.zeros(10, dtype=dtype))
         f(a)
         np.testing.assert_equal(a.asnumpy(), np.ones(a.shape[0]))
-    run_jit(fapi, check)
+    run_jit(mod, check)
 
 
 if __name__ == "__main__":
diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py
index 00ac7e346152..9f4ccadde94d 100644
--- a/tests/python/unittest/test_tir_nodes.py
+++ b/tests/python/unittest/test_tir_nodes.py
@@ -277,7 +277,7 @@ def test_prim_func():
     assert func.buffer_map[func.params[2]].same_as(b)
     assert len(func.buffer_map) == 1
 
-    f2 = func.with_attr("calling_conv", 1)
+    f2 = func.with_attr({"calling_conv": 1, "tir.noalias": True})
     assert f2.attrs["calling_conv"].value == 1
     assert func.attrs is None
 
diff --git a/tests/python/unittest/test_tir_pass_storage_flatten.py b/tests/python/unittest/test_tir_pass_storage_flatten.py
index 88799c4736d9..1eaadb35009d 100644
--- a/tests/python/unittest/test_tir_pass_storage_flatten.py
+++ b/tests/python/unittest/test_tir_pass_storage_flatten.py
@@ -92,7 +92,9 @@ def test_flatten_double_buffer():
     stmt = tvm.tir.ir_pass.Simplify(stmt)
     assert isinstance(stmt.body.body, tvm.tir.Allocate)
     assert stmt.body.body.extents[0].value == 2
-    mod = tvm.testing.MakeAPILegacy(stmt, "db", [A.asobject(), C.asobject()], 0, True)
+
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([A, C], stmt).with_attr("global_symbol", "db"))
     f = tvm.tir.transform.ThreadSync("shared")(mod)["db"]
 
     count = [0]
diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
index 25204eb1d906..8a31a1537ca2 100644
--- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py
+++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
@@ -43,7 +43,7 @@ def test_lower_warp_memory_local_scope():
     mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod)
     fdevice = tvm.tir.transform.SplitHostDevice()(mod)["f_kernel0"]
     mod = tvm.IRModule.from_expr(fdevice)
-    fdevice = tvm.tir.transform.LowerWarpMemory()(mod)["main"]
+    fdevice = tvm.tir.transform.LowerWarpMemory()(mod)["f_kernel0"]
     assert(fdevice.body.body.value.value == "local")
     assert(fdevice.body.body.body.extents[0].value == 2)
 
diff --git a/tests/python/unittest/test_tir_transform_make_packed_api.py b/tests/python/unittest/test_tir_transform_make_packed_api.py
index 7222a617ced5..fb76597577b6 100644
--- a/tests/python/unittest/test_tir_transform_make_packed_api.py
+++ b/tests/python/unittest/test_tir_transform_make_packed_api.py
@@ -35,11 +35,11 @@ def test_makeapi():
     stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb}, 64)
 
     num_unpacked_args = 2
-    f = tvm.tir.PrimFunc([n, Ab, Bb, Cb], stmt)
-    f = f.with_attr("global_symbol", "myadd")
-    f = f.with_attr("target", tvm.target.create("llvm"))
-
-    mod = tvm.IRModule.from_expr(f)
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([n, Ab, Bb, Cb], stmt).with_attr({
+            "global_symbol": "main",
+            "target": tvm.target.create("llvm")
+        }))
     f = tvm.tir.transform.MakePackedAPI(num_unpacked_args)(mod)["main"]
     assert(len(f.params) == 7)
 
diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py
index 64b454fe7645..9257f6cd3320 100644
--- a/tests/python/unittest/test_tir_transform_thread_sync.py
+++ b/tests/python/unittest/test_tir_transform_thread_sync.py
@@ -39,12 +39,15 @@ def test_thread_storage_sync():
     stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64)
 
     cuda_target = tvm.target.create("cuda")
-    mod = tvm.testing.MakeAPILegacy(stmt, "test", [Ab, A2b], 0, True)
-    mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod)
+
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([Ab, A2b], stmt).with_attr({
+            "global_symbol": "test", "target": cuda_target}))
+
     fdevice = tvm.tir.transform.SplitHostDevice()(mod)["test_kernel0"]
     mod = tvm.IRModule.from_expr(fdevice)
     cuda_target = tvm.target.create("cuda")
-    f = tvm.tir.transform.ThreadSync("shared")(mod)["main"]
+    f = tvm.tir.transform.ThreadSync("shared")(mod)["test_kernel0"]
     body_list = tvm.tir.stmt_list(f.body.body.body.body)
     assert(body_list[1].value.name == "tvm_storage_sync")
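
Usage sketch of the API after this patch, assuming a TVM build that includes it: with_attr now accepts either a single key/value pair or a dict of attributes, and IRModule.from_expr picks up the function's global_symbol attribute as the global var name instead of always using "main". The buffer and loop setup below mirrors the migrated tests; the variable names are illustrative only, not part of the patch.

import numpy as np
import tvm
from tvm import te

# Build a trivial TIR statement, in the same style as the migrated tests.
dtype = "int64"
n = te.size_var("n")
Ab = tvm.tir.decl_buffer((n,), dtype)
ib = tvm.tir.ir_builder.create()
A = ib.buffer_ptr(Ab)
with ib.for_range(0, n - 1, "i") as i:
    A[i + 1] = A[i] + 1
stmt = ib.get()

# Single key/value form (unchanged behavior).
func = tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "arange")
# New dict form: set several attributes in one call, as the updated tests do.
func = func.with_attr({"tir.is_entry_func": True, "tir.noalias": True})

# from_expr now uses the global_symbol attribute ("arange") as the global var name.
mod = tvm.IRModule.from_expr(func)
f = tvm.build(mod, target="llvm")

a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)
np.testing.assert_equal(a.asnumpy(), np.arange(10))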