diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index 7f826710b9c9..646d5f1508b3 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -696,6 +696,12 @@ TVM_DLL Pass ManifestSharedMemoryLocalStage();
  */
 TVM_DLL Pass InstrumentProfileIntrinsics();
 
+/*!
+ * \brief Mangle TIR function names by prepending a prefix to avoid symbol collisions.
+ * \return The pass.
+ */
+TVM_DLL Pass TIRFuncRename();
+
 }  // namespace transform
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/relax/backend/aot/codegen_aot.cc b/src/relax/backend/aot/codegen_aot.cc
index 5a3da9756194..ff28a9137cc1 100644
--- a/src/relax/backend/aot/codegen_aot.cc
+++ b/src/relax/backend/aot/codegen_aot.cc
@@ -68,6 +68,7 @@ runtime::Module Build(IRModule mod, String mod_name, CompilationConfig config, r
   mod = relax::transform::UnifiedStaticMemoryPlanner()(mod);
   mod = AOTLowerMain(mod_name, config)(mod);
   mod = tir::transform::LegalizePackedCalls()(mod);
+  mod = tir::transform::TIRFuncRename()(mod);
   auto lowered_funcs = tvm::relay::tec::GetPerTargetModules(mod);
   auto exec_metadata = tvm::relay::backend::aot::CreateExecutorMetadata(mod, mod_name, executor,
                                                                         workspace_byte_alignment,
diff --git a/src/relax/usmp/transform/convert_relax_to_dps.cc b/src/relax/usmp/transform/convert_relax_to_dps.cc
index 26bfddd04765..b9588326ee0a 100644
--- a/src/relax/usmp/transform/convert_relax_to_dps.cc
+++ b/src/relax/usmp/transform/convert_relax_to_dps.cc
@@ -123,12 +123,18 @@ class ConvertRelaxMainToDPS : public ExprMutator {
     for (auto iter = block->bindings.rbegin(); iter != block->bindings.rend(); iter++) {
       Binding binding = *iter;
       if (const auto* var_binding = binding.as<VarBindingNode>()) {
-        if (var_binding->value->IsInstance<VarNode>() &&
-            return_alias_.count(var_binding->var) > 0) {
-          // Alias. Update alias map and do not emit binding.
-          return_alias_[runtime::Downcast<Var>(var_binding->value)] =
-              return_alias_[var_binding->var];
-          continue;
+        if (var_binding->value->IsInstance<VarNode>()) {
+          if (return_alias_.count(var_binding->var) > 0) {
+            // Alias. Update alias map and do not emit binding.
+            return_alias_[runtime::Downcast<Var>(var_binding->value)] =
+                return_alias_[var_binding->var];
+            continue;
+          }
+          if (return_alias_.count(var_binding->var) == 0 &&
+              return_alias_.count(runtime::Downcast<Var>(var_binding->value)) > 0) {
+            // Creating an alias for a dead var. Do not emit binding.
+            continue;
+          }
         }
 
         if (var_binding->value->IsInstance<CallNode>() &&
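Note on the convert_relax_to_dps.cc change above: the mutator walks a block's bindings in reverse, so when it reaches a var-to-var binding it already knows whether the left-hand side feeds a returned value. The new branch drops bindings whose left-hand side is dead (never recorded as a return alias) while the right-hand side already aliases a returned tensor. A minimal Relax-level sketch of the pattern now handled, excerpted from the TestTupleBothAllocDeadCode case added at the end of this diff (gv is the dead alias):

    output = (alloc, alloc2)  # tuple of returned tensors
    gv = output               # dead alias of a return value; binding no longer emitted
    return output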
diff --git a/src/tir/transforms/tir_func_rename.cc b/src/tir/transforms/tir_func_rename.cc
new file mode 100644
index 000000000000..de30f0fe3f76
--- /dev/null
+++ b/src/tir/transforms/tir_func_rename.cc
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tir/transforms/tir_func_rename.cc
+ * \brief Mangles TIR function names to avoid symbol conflicts.
+ *        Prepends "_tvm_gen" to all function names in the IRModule.
+ */
+
+#include <unordered_map>
+
+#include "tvm/ir/name_supply.h"
+#include "tvm/ir/transform.h"
+#include "tvm/tir/builtin.h"
+#include "tvm/tir/stmt_functor.h"
+
+namespace tvm {
+namespace tir {
+namespace aot {
+
+class TIRMangleFuncName : public StmtExprMutator {
+ public:
+  explicit TIRMangleFuncName(IRModule mod) : mod_(std::move(mod)) {
+    ICHECK(mod_->ContainGlobalVar(runtime::symbol::tvm_module_main))
+        << "Expecting module to have symbol " << runtime::symbol::tvm_module_main << " attached.";
+    auto main_func_gv = mod_->GetGlobalVar(runtime::symbol::tvm_module_main);
+    NameSupply name_supply = NameSupply("_tvm_gen");
+    for (auto pair : mod_->functions) {
+      if (pair.first.same_as(main_func_gv)) {
+        // Ignore the main function.
+        continue;
+      }
+      auto prim_func = runtime::Downcast<PrimFunc>(pair.second);
+      auto func_name = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+      ICHECK(func_name.defined()) << "Expecting global_symbol attribute to be attached to the"
+                                     " function";
+      name_map_[func_name.value()] = name_supply->FreshName(func_name.value());
+    }
+  }
+
+  IRModule operator()() {
+    auto main_func_gv = mod_->GetGlobalVar(runtime::symbol::tvm_module_main);
+
+    Map<GlobalVar, BaseFunc> func_map = Map<GlobalVar, BaseFunc>();
+    for (auto pair : mod_->functions) {
+      auto prim_func = runtime::Downcast<PrimFunc>(pair.second);
+      auto func_name = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+
+      Stmt new_body = this->VisitStmt(prim_func->body);
+      if (pair.first.same_as(main_func_gv)) {
+        // No need to set a new global var and global symbol for the main function.
+        func_map.Set(pair.first, PrimFunc(prim_func->params, new_body, prim_func->ret_type,
+                                          prim_func->buffer_map, prim_func->attrs,
+                                          prim_func->span));
+      } else {
+        ICHECK(name_map_.count(func_name.value()) > 0)
+            << "Expecting new name in name_map_ at this stage.";
+        GlobalVar new_var = GlobalVar(name_map_[func_name.value()]);
+        PrimFunc new_func = PrimFunc(prim_func->params, new_body, prim_func->ret_type,
+                                     prim_func->buffer_map, prim_func->attrs, prim_func->span);
+        new_func = WithAttr(new_func, tvm::attr::kGlobalSymbol,
+                            String(name_map_[func_name.value()]));
+        func_map.Set(new_var, new_func);
+      }
+    }
+
+    IRModule new_mod = IRModule(func_map, mod_->type_definitions, mod_->Imports(),
+                                mod_->source_map, mod_->attrs);
+    return new_mod;
+  }
+
+ private:
+  PrimExpr VisitExpr_(const CallNode* op) override {
+    String func_name;
+    if (op->op.same_as(builtin::call_extern()) || op->op.same_as(builtin::tvm_call_cpacked())) {
+      func_name = Downcast<StringImm>(op->args[0])->value;
+    }
+    if (op->op->IsInstance<GlobalVarNode>()) {
+      func_name = Downcast<GlobalVar>(op->op)->name_hint;
+    }
+    if (func_name.defined() && mod_->ContainGlobalVar(func_name) &&
+        mod_->Lookup(func_name)->IsInstance<PrimFuncNode>()) {
+      ICHECK(name_map_.count(func_name) > 0) << "Name map should contain a name.";
+      StringImm new_name = StringImm(name_map_[func_name]);
+      Array<PrimExpr> new_args = {new_name};
+      new_args.insert(new_args.end(), op->args.begin() + 1, op->args.end());
+      return Call(op->dtype, op->op, new_args, op->span);
+    }
+    return StmtExprMutator::VisitExpr_(op);
+  }
+
+  std::unordered_map<String, String> name_map_;
+  IRModule mod_;
+};
+
+}  // namespace aot
+
+namespace transform {
+
+tvm::transform::Pass TIRFuncRename() {
+  auto pass_func = [=](IRModule m, tvm::transform::PassContext ctx) {
+    return runtime::Downcast<IRModule>(tvm::tir::aot::TIRMangleFuncName(m)());
+  };
+
+  return tvm::transform::CreateModulePass(pass_func, 0, "tir.transform.TIRFuncRename", {});
+}
+
+}  // namespace transform
+}  // namespace tir
+}  // namespace tvm
\ No newline at end of file
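Note on the new pass: TIRFuncRename renames every PrimFunc except the entry point (runtime::symbol::tvm_module_main, which must be present in the module) and rewrites call_extern/tvm_call_cpacked call sites to the new names. A minimal sketch of the intended effect, assuming NameSupply joins its "_tvm_gen" prefix to each name with an underscore; the invocation in the trailing comment is hypothetical, since this diff registers no Python binding and only calls the pass from C++ in codegen_aot.cc above:

    import tvm
    from tvm.script import tir as T

    @tvm.script.ir_module
    class Mod:
        @T.prim_func
        def add(a: T.handle, b: T.handle) -> None:
            T.func_attr({"global_symbol": "add"})
            T.evaluate(0)

    # Hypothetical: renamed = TIRFuncRename()(Mod)
    # "add" would become "_tvm_gen_add", while a "__tvm_main__" function
    # (required by the pass) would keep its name and global symbol.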
+ "tir.transform.TIRFuncRename", {}); +} + +} // namespace transform +} // namespace tir +} // namespace tvm \ No newline at end of file diff --git a/tests/python/relax/aot/test_aot_build.py b/tests/python/relax/aot/test_aot_build.py index a84a783a113d..2ba6c7181b63 100644 --- a/tests/python/relax/aot/test_aot_build.py +++ b/tests/python/relax/aot/test_aot_build.py @@ -48,7 +48,7 @@ def _relay(): def _reference(inputs): x = inputs["x"] return np.abs(x) # abs - + relax_mod = relay_translator.from_relay( _relay(), target, @@ -75,7 +75,7 @@ def _relay(): def _reference(inputs): x = inputs["x"] return np.add(x, -1) # add - + relax_mod = relay_translator.from_relay( _relay(), target, @@ -102,7 +102,7 @@ def _relay(): def _reference(inputs): x = inputs["x"] return np.add(x, np.array([[1, 2], [3, 4]])) # add - + relax_mod = relay_translator.from_relay( _relay(), target, @@ -119,7 +119,10 @@ def _reference(inputs): def test_multi_input(): dtype = "int32" target = "llvm" - inputs = {"x": np.array([[-10, 1], [5, 1]], dtype=dtype), "y": np.array([[1, 2], [3, 4]], dtype=dtype)} + inputs = { + "x": np.array([[-10, 1], [5, 1]], dtype=dtype), + "y": np.array([[1, 2], [3, 4]], dtype=dtype), + } def _relay(): x = relay.var("x", shape=(2, 2), dtype=dtype) @@ -131,7 +134,7 @@ def _reference(inputs): x = inputs["x"] y = inputs["y"] return np.add(x, y) # add - + relax_mod = relay_translator.from_relay( _relay(), target, @@ -159,10 +162,10 @@ def _relay(): def _reference(inputs): x = inputs["x"] - abs = np.abs(x) # abs + abs = np.abs(x) # abs out = abs - 1 return [abs, out] - + relax_mod = relay_translator.from_relay( _relay(), target, diff --git a/tests/python/relax/aot/test_pass_aot_lower_main.py b/tests/python/relax/aot/test_pass_aot_lower_main.py index 7a0d5380f072..c0a664f7921f 100644 --- a/tests/python/relax/aot/test_pass_aot_lower_main.py +++ b/tests/python/relax/aot/test_pass_aot_lower_main.py @@ -124,7 +124,11 @@ def test_multi_input(): @tvm.script.ir_module class MultiInput: @R.function - def main(a: R.Tensor((5, 7), "float32"), b: R.Tensor((5, 7), "float32"), output: R.Tensor((5, 7), "float32")): + def main( + a: R.Tensor((5, 7), "float32"), + b: R.Tensor((5, 7), "float32"), + output: R.Tensor((5, 7), "float32"), + ): R.func_attr({"input_vars": [a, b], "output_vars": [output]}) tid_0 = output _ = R.call_packed("add", a, b, tid_0, type_args=R.Tensor(ndim=2, dtype="float32")) @@ -149,11 +153,17 @@ def test_multi_output(): @tvm.script.ir_module class MultiOutput: @R.function - def main(a: R.Tensor((5, 7), "float32"), output_0: R.Tensor((5, 7), "float32"), output_1: R.Tensor((5, 7), "float32")): + def main( + a: R.Tensor((5, 7), "float32"), + output_0: R.Tensor((5, 7), "float32"), + output_1: R.Tensor((5, 7), "float32"), + ): R.func_attr({"input_vars": [a], "output_vars": [output_0, output_1]}) tid_0 = output_0 tid_1 = output_1 - _ = R.call_packed("duplicate", a, tid_0, tid_1, type_args=R.Tensor(ndim=2, dtype="float32")) + _ = R.call_packed( + "duplicate", a, tid_0, tid_1, type_args=R.Tensor(ndim=2, dtype="float32") + ) return () # fmt: off @@ -202,7 +212,9 @@ class TupleGetItem: def main(a: R.Tensor((5, 7), "float32"), output: R.Tensor((5, 7), "float32")): R.func_attr({"input_vars": [a], "output_vars": [output]}) tup = (a, a) - _ = R.call_packed("identity", tup[1], output, type_args=R.Tensor(ndim=2, dtype="float32")) + _ = R.call_packed( + "identity", tup[1], output, type_args=R.Tensor(ndim=2, dtype="float32") + ) return () # fmt: off @@ -235,7 +247,9 @@ def main(a: R.Tensor((5, 7), "float32"), output: 
             tid_2 = R.memory.alloc_tensor(alloc_2, (5, 7), offset=0, dtype="float32")
             _ = R.call_packed("identity", tid_0, tid_2, type_args=R.Tensor(ndim=2, dtype="float32"))
             tid_3 = output
-            _ = R.call_packed("add", tid_1, tid_2, tid_3, type_args=R.Tensor(ndim=2, dtype="float32"))
+            _ = R.call_packed(
+                "add", tid_1, tid_2, tid_3, type_args=R.Tensor(ndim=2, dtype="float32")
+            )
             return ()
 
     # fmt: off
@@ -268,9 +282,17 @@ def test_device_hooks():
     @tvm.script.ir_module
     class DeviceHooks:
         @T.prim_func
-        def identity(a: T.handle, output: T.handle, device_context_example_target_hook: T.handle) -> None:
+        def identity(
+            a: T.handle, output: T.handle, device_context_example_target_hook: T.handle
+        ) -> None:
             # function attr dict
-            T.func_attr({"global_symbol": "identity", "runner_function": True, "target": T.target({"kind":"llvm", "tag":"", "keys":["cpu"]})})
+            T.func_attr(
+                {
+                    "global_symbol": "identity",
+                    "runner_function": True,
+                    "target": T.target({"kind": "llvm", "tag": "", "keys": ["cpu"]}),
+                }
+            )
             a_buffer = T.match_buffer(a, [5, 7], dtype="float32", align=16)
             output_buffer = T.match_buffer(output, [5, 7], dtype="float32", align=16)
             # body
diff --git a/tests/python/relax/test_relax_usmp_convert_to_dps.py b/tests/python/relax/test_relax_usmp_convert_to_dps.py
index b0339a95a603..561185c9513f 100644
--- a/tests/python/relax/test_relax_usmp_convert_to_dps.py
+++ b/tests/python/relax/test_relax_usmp_convert_to_dps.py
@@ -233,6 +233,57 @@ def test_tuple_both_alloc():
     # tvm.ir.assert_structural_equal(actual_func, ref_func)
 
 
+# fmt: off
+@tvm.script.ir_module
+class TestTupleBothAllocDeadCode:
+    @R.function
+    def main(input: R.Tensor((16, 16), "uint8")) -> R.Tuple(R.Tensor(None, "float32", ndim = 2), R.Tensor(None, "int32", ndim = 2)):
+        # block 0
+        tsid_11 = R.builtin.alloc_tensor((1, 1), dtype="int8", runtime_device_index=0)
+        alloc = R.builtin.alloc_tensor((5, 7), dtype="float32", runtime_device_index=0)
+        _ = R.call_packed("prim_func_2", input, tsid_11, alloc, type_args=(R.Tensor(ndim=2, dtype="float32")))
+        output_1 = alloc
+
+        alloc1 = R.builtin.alloc_tensor((5, 7), dtype="int8", runtime_device_index=0)
+        _1 = R.call_packed("prim_func_3", input, alloc, alloc1, type_args=(R.Tensor(ndim=2, dtype="int8")))
+        lv0 = alloc1
+
+        tsid_12 = R.builtin.alloc_tensor((1, 1), dtype="int8", runtime_device_index=0)
+        alloc2 = R.builtin.alloc_tensor((802816, 1), dtype="int32", runtime_device_index=0)
+        _2 = R.call_packed("prim_func_1", input, lv0, tsid_12, alloc2, type_args=(R.Tensor(ndim=2, dtype="int32")))
+        output_2 = alloc2
+        output = (alloc, alloc2)
+        gv = output
+        return output
+
+
+@tvm.script.ir_module
+class TestTupleBothAllocDeadCodeExpected:
+    @R.function
+    def main(input: R.Tensor((16, 16), "uint8"), alloc: R.Tensor((5, 7), "float32"), alloc2: R.Tensor((802816, 1), "int32")):
+        # block 0
+        tsid_11 = R.builtin.alloc_tensor((1, 1), dtype="int8", runtime_device_index=0)
+        _ = R.call_packed("prim_func_2", input, tsid_11, alloc, type_args=(R.Tensor(ndim=2, dtype="float32")))
+        alloc1 = R.builtin.alloc_tensor((5, 7), dtype="int8", runtime_device_index=0)
+        _1 = R.call_packed("prim_func_3", input, alloc, alloc1, type_args=(R.Tensor(ndim=2, dtype="int8")))
+        lv0 = alloc1
+        tsid_12 = R.builtin.alloc_tensor((1, 1), dtype="int8", runtime_device_index=0)
+        _2 = R.call_packed("prim_func_1", input, lv0, tsid_12, alloc2, type_args=(R.Tensor(ndim=2, dtype="int32")))
+        return R.Tuple()
+
+# fmt: on
+
+
+def test_tuple_both_alloc_dead_code():
+    before_mod = TestTupleBothAllocDeadCode
+    after_mod = tvm.relax.transform.ConvertRelaxMainToDPS(attach_io_to_attrs=False)(before_mod)
+    expected_mod = TestTupleBothAllocDeadCodeExpected
+    for gv, ref_func in expected_mod.functions.items():
+        actual_func = after_mod[gv.name_hint]
+        assert str(actual_func) == str(ref_func)
+    # tvm.ir.assert_structural_equal(actual_func, ref_func)
+
+
 # fmt: off
 @tvm.script.ir_module
 class TestTupleOneAllocOneParam: