[Relax][AOT] Add pass that mangles TIR PrimFunc names
Fix edge case in ConvertRelaxToDPS
gigiblender committed Dec 5, 2022
1 parent c099d06 commit 08b2ec1
Showing 7 changed files with 237 additions and 20 deletions.
6 changes: 6 additions & 0 deletions include/tvm/tir/transform.h
@@ -696,6 +696,12 @@ TVM_DLL Pass ManifestSharedMemoryLocalStage();
*/
TVM_DLL Pass InstrumentProfileIntrinsics();

/*!
 * \brief Mangle TIR function names by adding a prefix to avoid symbol collisions.
* \return The pass.
*/
TVM_DLL Pass TIRFuncRename();

} // namespace transform
} // namespace tir
} // namespace tvm
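
Illustrative aside (not part of the diff): the core idea behind the new pass is a collision-avoiding fresh-name supply. The real implementation uses tvm::NameSupply in C++; the small Python stand-in below is hypothetical and only sketches the naming scheme.

class FreshNames:
    """Hands out symbol names that are unique within one module."""

    def __init__(self, tag: str):
        self.tag = tag
        self.taken = set()

    def fresh(self, base: str) -> str:
        # Mangle the base name with the tag, then disambiguate with a counter
        # if the mangled name is itself already taken.
        candidate = f"{base}{self.tag}"
        counter = 0
        while candidate in self.taken:
            counter += 1
            candidate = f"{base}{self.tag}_{counter}"
        self.taken.add(candidate)
        return candidate

names = FreshNames("_tvm_gen")
assert names.fresh("add") == "add_tvm_gen"
assert names.fresh("add") == "add_tvm_gen_1"  # a repeated request still gets a unique name
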
1 change: 1 addition & 0 deletions src/relax/backend/aot/codegen_aot.cc
@@ -68,6 +68,7 @@ runtime::Module Build(IRModule mod, String mod_name, CompilationConfig config, r
mod = relax::transform::UnifiedStaticMemoryPlanner()(mod);
mod = AOTLowerMain(mod_name, config)(mod);
mod = tir::transform::LegalizePackedCalls()(mod);
mod = tir::transform::TIRFuncRename()(mod);

auto lowered_funcs = tvm::relay::tec::GetPerTargetModules(mod);
auto exec_metadata = tvm::relay::backend::aot::CreateExecutorMetadata(mod, mod_name, executor, workspace_byte_alignment,
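
Illustrative aside (not part of the diff): the Build() pipeline above is plain left-to-right composition of module-to-module passes, and TIRFuncRename is placed after LegalizePackedCalls so that the packed-call sites it rewrites already exist. A hypothetical Python sketch of that composition pattern (placeholder passes, not the TVM API):

from functools import reduce
from typing import Callable, Sequence

Module = dict                           # stand-in for an IRModule
ModulePass = Callable[[Module], Module]

def run_pipeline(mod: Module, passes: Sequence[ModulePass]) -> Module:
    # Each pass consumes the module produced by the previous one, mirroring
    # the chain of `mod = SomePass()(mod)` calls in Build() above.
    return reduce(lambda m, p: p(m), passes, mod)

# Placeholder passes; the second only does useful work after the first has run.
legalize_packed_calls = lambda m: {**m, "calls_legalized": True}
rename_funcs = lambda m: {**m, "renamed": m.get("calls_legalized", False)}

assert run_pipeline({}, [legalize_packed_calls, rename_funcs]) == {
    "calls_legalized": True,
    "renamed": True,
}
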
18 changes: 12 additions & 6 deletions src/relax/usmp/transform/convert_relax_to_dps.cc
@@ -123,12 +123,18 @@ class ConvertRelaxMainToDPS : public ExprMutator {
for (auto iter = block->bindings.rbegin(); iter != block->bindings.rend(); iter++) {
Binding binding = *iter;
if (const auto* var_binding = binding.as<VarBindingNode>()) {
if (var_binding->value->IsInstance<VarNode>() &&
return_alias_.count(var_binding->var) > 0) {
// Alias. Update alias map and do not emit binding.
return_alias_[runtime::Downcast<Var>(var_binding->value)] =
return_alias_[var_binding->var];
continue;
if (var_binding->value->IsInstance<VarNode>()) {
if (return_alias_.count(var_binding->var) > 0) {
// Alias. Update alias map and do not emit binding.
return_alias_[runtime::Downcast<Var>(var_binding->value)] =
return_alias_[var_binding->var];
continue;
}
if (return_alias_.count(var_binding->var) == 0
&& return_alias_.count(var_binding->value) > 0) {
// Creating an alias for a dead var. Do not emit binding.
continue;
}
}

if (var_binding->value->IsInstance<relay::TupleNode>() &&
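
Illustrative aside (not part of the diff): the bindings are walked in reverse order, so by the time a binding like `output_1 = alloc` is visited the pass already knows that `alloc` aliases an output parameter (learned from the returned tuple). The hypothetical Python sketch below models only that alias bookkeeping, not the full DPS rewrite; it mirrors the TestTupleBothAllocDeadCode test added further down in this commit.

def drop_alias_bindings(bindings, return_alias):
    """bindings: ordered (var, value) pairs; a plain-string value means the
    value is itself a variable. return_alias maps a variable to the output
    parameter it ultimately aliases."""
    kept = []
    for var, value in reversed(bindings):
        if isinstance(value, str):                 # var = some_other_var
            if var in return_alias:
                # Known alias of an output: propagate the alias, drop binding.
                return_alias[value] = return_alias[var]
                continue
            if value in return_alias:
                # The fixed edge case: `var` re-names something that already
                # aliases an output, but `var` itself is dead (never returned).
                continue
        kept.append((var, value))
    kept.reverse()
    return kept

aliases = {"alloc": "out0", "alloc2": "out1"}      # from the returned tuple
bindings = [
    ("alloc", ("call", "prim_func_2")),
    ("output_1", "alloc"),                         # dead alias: dropped
    ("alloc2", ("call", "prim_func_1")),
    ("output_2", "alloc2"),                        # dead alias: dropped
]
assert drop_alias_bindings(bindings, aliases) == [
    ("alloc", ("call", "prim_func_2")),
    ("alloc2", ("call", "prim_func_1")),
]
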
128 changes: 128 additions & 0 deletions src/tir/transforms/tir_func_rename.cc
@@ -0,0 +1,128 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/


/*!
 * \file src/tir/transforms/tir_func_rename.cc
* \brief Mangles TIR function names to avoid symbol conflicts.
* Appends "_tvm_gen" to all function names in the IRModule.
*/

#include <utility>

#include "tvm/ir/name_supply.h"
#include "tvm/ir/transform.h"
#include "tvm/tir/builtin.h"
#include "tvm/tir/stmt_functor.h"

namespace tvm {
namespace tir {
namespace aot {

class TIRMangleFuncName : public StmtExprMutator {

public:
explicit TIRMangleFuncName(IRModule mod) : mod_(std::move(mod)) {
ICHECK(mod_->ContainGlobalVar(runtime::symbol::tvm_module_main)) << "Expecting module to have"
<< " symbol " << runtime::symbol::tvm_module_main << " attached.";
auto main_func_gv = mod_->GetGlobalVar(runtime::symbol::tvm_module_main);
NameSupply name_supply = NameSupply("_tvm_gen");
for (auto pair : mod_->functions) {
if (pair.first.same_as(main_func_gv)) {
// Ignore the main function.
continue;
}
auto prim_func = runtime::Downcast<PrimFunc>(pair.second);
auto func_name = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(func_name.defined()) << "Expecting global_symbol attribute to be attached to the"
" function";
name_map_[func_name.value()] = name_supply->FreshName(func_name.value());
}
}

IRModule operator()() {
auto main_func_gv = mod_->GetGlobalVar(runtime::symbol::tvm_module_main);

Map<GlobalVar, BaseFunc> func_map = Map<GlobalVar, BaseFunc>();
for (auto pair : mod_->functions) {
auto prim_func = runtime::Downcast<PrimFunc>(pair.second);
auto func_name = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);

Stmt new_body = this->VisitStmt(prim_func->body);
if (pair.first.same_as(main_func_gv)) {
// No need to set a new global var and global symbol for the main function.
func_map.Set(pair.first, PrimFunc(prim_func->params, new_body, prim_func->ret_type,
prim_func->buffer_map, prim_func->attrs, prim_func->span));
} else {
ICHECK(name_map_.count(func_name.value()) > 0) << "Expecting new name in name_map_ at "
"this stage.";
GlobalVar new_var = GlobalVar(name_map_[func_name.value()]);
PrimFunc new_func = PrimFunc(prim_func->params, new_body, prim_func->ret_type,
prim_func->buffer_map, prim_func->attrs, prim_func->span);
new_func = WithAttr(new_func, tvm::attr::kGlobalSymbol,
String(name_map_[func_name.value()]));
func_map.Set(new_var, new_func);
}
}

IRModule new_mod = IRModule(func_map, mod_->type_definitions, mod_->Imports(),
mod_->source_map, mod_->attrs);
return new_mod;
}

private:
PrimExpr VisitExpr_(const CallNode* op) override {
String func_name;
if (op->op.same_as(builtin::call_extern()) || op->op.same_as(builtin::tvm_call_cpacked())) {
func_name = Downcast<StringImm>(op->args[0])->value;
}
if (op->op->IsInstance<PrimFuncNode>()) {
func_name = Downcast<StringImm>(op->args[0])->value;
}
if (func_name.defined() && mod_->ContainGlobalVar(func_name) &&
mod_->Lookup(func_name)->IsInstance<PrimFuncNode>()) {
ICHECK(name_map_.count(func_name) > 0) << "Name map should contain a name.";
StringImm new_name = StringImm(name_map_[func_name]);
Array<PrimExpr> new_args = { new_name };
new_args.insert(new_args.end(), op->args.begin() + 1, op->args.end());
return Call(op->dtype, op->op, new_args, op->span);
}
return StmtExprMutator::VisitExpr_(op);
}

std::unordered_map<std::string, std::string> name_map_;
IRModule mod_;
};

} // namespace aot

namespace transform {

tvm::transform::Pass TIRFuncRename() {
auto pass_func = [=](IRModule m, tvm::transform::PassContext ctx) {
return runtime::Downcast<IRModule>(tvm::tir::aot::TIRMangleFuncName(m)());
};

return tvm::transform::CreateModulePass(pass_func, 0,
"tir.transform.TIRFuncRename", {});
}

} // namespace transform
} // namespace tir
} // namespace tvm
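
Illustrative aside (not part of the diff): the pass has a two-phase shape, first choosing a new symbol for every function except the entry point, then rebuilding the module while rewriting call sites whose first argument names a renamed function (as call_extern and tvm_call_cpacked do). The standalone Python sketch below is hypothetical; the module representation and the exact placement of the "_tvm_gen" tag are illustrative, not the TVM API.

MAIN = "__tvm_main__"   # stand-in for runtime::symbol::tvm_module_main

def mangle_module(module):
    """module: {symbol: list of calls}, each call a list whose first element
    is the callee's symbol. Returns a new, consistently renamed module."""
    # Phase 1: pick fresh names for everything but the entry point.
    name_map = {sym: sym + "_tvm_gen" for sym in module if sym != MAIN}

    # Phase 2: rebuild definitions and rewrite call sites together, so every
    # use of a renamed function points at its new symbol.
    new_module = {}
    for sym, calls in module.items():
        new_calls = [[name_map.get(c[0], c[0]), *c[1:]] for c in calls]
        new_module[name_map.get(sym, sym)] = new_calls
    return new_module

mod = {MAIN: [["add", "x", "y", "out"]], "add": []}
assert mangle_module(mod) == {MAIN: [["add_tvm_gen", "x", "y", "out"]],
                              "add_tvm_gen": []}
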
17 changes: 10 additions & 7 deletions tests/python/relax/aot/test_aot_build.py
@@ -48,7 +48,7 @@ def _relay():
def _reference(inputs):
x = inputs["x"]
return np.abs(x) # abs

relax_mod = relay_translator.from_relay(
_relay(),
target,
@@ -75,7 +75,7 @@ def _relay():
def _reference(inputs):
x = inputs["x"]
return np.add(x, -1) # add

relax_mod = relay_translator.from_relay(
_relay(),
target,
@@ -102,7 +102,7 @@ def _relay():
def _reference(inputs):
x = inputs["x"]
return np.add(x, np.array([[1, 2], [3, 4]])) # add

relax_mod = relay_translator.from_relay(
_relay(),
target,
@@ -119,7 +119,10 @@ def _reference(inputs):
def test_multi_input():
dtype = "int32"
target = "llvm"
inputs = {"x": np.array([[-10, 1], [5, 1]], dtype=dtype), "y": np.array([[1, 2], [3, 4]], dtype=dtype)}
inputs = {
"x": np.array([[-10, 1], [5, 1]], dtype=dtype),
"y": np.array([[1, 2], [3, 4]], dtype=dtype),
}

def _relay():
x = relay.var("x", shape=(2, 2), dtype=dtype)
@@ -131,7 +134,7 @@ def _reference(inputs):
x = inputs["x"]
y = inputs["y"]
return np.add(x, y) # add

relax_mod = relay_translator.from_relay(
_relay(),
target,
@@ -159,10 +162,10 @@ def _relay():

def _reference(inputs):
x = inputs["x"]
abs = np.abs(x) # abs
abs = np.abs(x) # abs
out = abs - 1
return [abs, out]

relax_mod = relay_translator.from_relay(
_relay(),
target,
36 changes: 29 additions & 7 deletions tests/python/relax/aot/test_pass_aot_lower_main.py
@@ -124,7 +124,11 @@ def test_multi_input():
@tvm.script.ir_module
class MultiInput:
@R.function
def main(a: R.Tensor((5, 7), "float32"), b: R.Tensor((5, 7), "float32"), output: R.Tensor((5, 7), "float32")):
def main(
a: R.Tensor((5, 7), "float32"),
b: R.Tensor((5, 7), "float32"),
output: R.Tensor((5, 7), "float32"),
):
R.func_attr({"input_vars": [a, b], "output_vars": [output]})
tid_0 = output
_ = R.call_packed("add", a, b, tid_0, type_args=R.Tensor(ndim=2, dtype="float32"))
@@ -149,11 +153,17 @@ def test_multi_output():
@tvm.script.ir_module
class MultiOutput:
@R.function
def main(a: R.Tensor((5, 7), "float32"), output_0: R.Tensor((5, 7), "float32"), output_1: R.Tensor((5, 7), "float32")):
def main(
a: R.Tensor((5, 7), "float32"),
output_0: R.Tensor((5, 7), "float32"),
output_1: R.Tensor((5, 7), "float32"),
):
R.func_attr({"input_vars": [a], "output_vars": [output_0, output_1]})
tid_0 = output_0
tid_1 = output_1
_ = R.call_packed("duplicate", a, tid_0, tid_1, type_args=R.Tensor(ndim=2, dtype="float32"))
_ = R.call_packed(
"duplicate", a, tid_0, tid_1, type_args=R.Tensor(ndim=2, dtype="float32")
)
return ()

# fmt: off
@@ -202,7 +212,9 @@ class TupleGetItem:
def main(a: R.Tensor((5, 7), "float32"), output: R.Tensor((5, 7), "float32")):
R.func_attr({"input_vars": [a], "output_vars": [output]})
tup = (a, a)
_ = R.call_packed("identity", tup[1], output, type_args=R.Tensor(ndim=2, dtype="float32"))
_ = R.call_packed(
"identity", tup[1], output, type_args=R.Tensor(ndim=2, dtype="float32")
)
return ()

# fmt: off
@@ -235,7 +247,9 @@ def main(a: R.Tensor((5, 7), "float32"), output: R.Tensor((5, 7), "float32")):
tid_2 = R.memory.alloc_tensor(alloc_2, (5, 7), offset=0, dtype="float32")
_ = R.call_packed("identity", tid_0, tid_2, type_args=R.Tensor(ndim=2, dtype="float32"))
tid_3 = output
_ = R.call_packed("add", tid_1, tid_2, tid_3, type_args=R.Tensor(ndim=2, dtype="float32"))
_ = R.call_packed(
"add", tid_1, tid_2, tid_3, type_args=R.Tensor(ndim=2, dtype="float32")
)
return ()

# fmt: off
@@ -268,9 +282,17 @@ def test_device_hooks():
@tvm.script.ir_module
class DeviceHooks:
@T.prim_func
def identity(a: T.handle, output: T.handle, device_context_example_target_hook: T.handle) -> None:
def identity(
a: T.handle, output: T.handle, device_context_example_target_hook: T.handle
) -> None:
# function attr dict
T.func_attr({"global_symbol": "identity", "runner_function": True, "target": T.target({"kind":"llvm", "tag":"", "keys":["cpu"]})})
T.func_attr(
{
"global_symbol": "identity",
"runner_function": True,
"target": T.target({"kind": "llvm", "tag": "", "keys": ["cpu"]}),
}
)
a_buffer = T.match_buffer(a, [5, 7], dtype="float32", align=16)
output_buffer = T.match_buffer(output, [5, 7], dtype="float32", align=16)
# body
51 changes: 51 additions & 0 deletions tests/python/relax/test_relax_usmp_convert_to_dps.py
@@ -233,6 +233,57 @@ def test_tuple_both_alloc():
# tvm.ir.assert_structural_equal(actual_func, ref_func)


# fmt: off
@tvm.script.ir_module
class TestTupleBothAllocDeadCode:
@R.function
def main(input: R.Tensor((16, 16), "uint8")) -> R.Tuple(R.Tensor(None, "float32", ndim = 2), R.Tensor(None, "int32", ndim = 2)):
# block 0
tsid_11 = R.builtin.alloc_tensor((1, 1), dtype="int8", runtime_device_index=0)
alloc = R.builtin.alloc_tensor((5, 7), dtype="float32", runtime_device_index=0)
_ = R.call_packed("prim_func_2", input, tsid_11, alloc, type_args=(R.Tensor(ndim=2, dtype="float32")))
output_1 = alloc

alloc1 = R.builtin.alloc_tensor((5, 7), dtype="int8", runtime_device_index=0)
_1 = R.call_packed("prim_func_3", input, alloc, alloc1, type_args=(R.Tensor(ndim=2, dtype="int8")))
lv0 = alloc1

tsid_12 = R.builtin.alloc_tensor((1, 1), dtype="int8", runtime_device_index=0)
alloc2 = R.builtin.alloc_tensor((802816, 1), dtype="int32", runtime_device_index=0)
_2 = R.call_packed("prim_func_1", input, lv0, tsid_12, alloc2, type_args=(R.Tensor(ndim=2, dtype="int32")))
output_2 = alloc2
output = (alloc, alloc2)
gv = output
return output


@tvm.script.ir_module
class TestTupleBothAllocDeadCodeExpected:
@R.function
def main(input: R.Tensor((16, 16), "uint8"), alloc: R.Tensor((5, 7), "float32"), alloc2: R.Tensor((802816, 1), "int32")):
# block 0
tsid_11 = R.builtin.alloc_tensor((1, 1), dtype="int8", runtime_device_index=0)
_ = R.call_packed("prim_func_2", input, tsid_11, alloc, type_args=(R.Tensor(ndim=2, dtype="float32")))
alloc1 = R.builtin.alloc_tensor((5, 7), dtype="int8", runtime_device_index=0)
_1 = R.call_packed("prim_func_3", input, alloc, alloc1, type_args=(R.Tensor(ndim=2, dtype="int8")))
lv0 = alloc1
tsid_12 = R.builtin.alloc_tensor((1, 1), dtype="int8", runtime_device_index=0)
_2 = R.call_packed("prim_func_1", input, lv0, tsid_12, alloc2, type_args=(R.Tensor(ndim=2, dtype="int32")))
return R.Tuple()

# fmt: on


def test_tuple_both_alloc_dead_code():
before_mod = TestTupleBothAllocDeadCode
after_mod = tvm.relax.transform.ConvertRelaxMainToDPS(attach_io_to_attrs=False)(before_mod)
expected_mod = TestTupleBothAllocDeadCodeExpected
for gv, ref_func in expected_mod.functions.items():
actual_func = after_mod[gv.name_hint]
assert str(actual_func) == str(ref_func)
# tvm.ir.assert_structural_equal(actual_func, ref_func)
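
Illustrative aside (not part of the diff): "destination-passing style" simply means results are written into caller-provided buffers instead of being allocated and returned, which is what the expected module above shows for Relax main (the returned tensors become the alloc and alloc2 parameters and the function returns an empty tuple). A plain NumPy analogy:

import numpy as np

def add_returning(x, y):
    return x + y                      # callee allocates and returns the result

def add_dps(x, y, out):
    np.add(x, y, out=out)             # caller supplies the destination buffer

x = np.ones((5, 7), dtype="float32")
y = np.ones((5, 7), dtype="float32")
out = np.empty((5, 7), dtype="float32")
add_dps(x, y, out)
assert np.array_equal(out, add_returning(x, y))
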


# fmt: off
@tvm.script.ir_module
class TestTupleOneAllocOneParam:
