From 0648fffc9b6fddd27dc04a91ebac9cccd780b3b3 Mon Sep 17 00:00:00 2001 From: Jiawei Liu Date: Wed, 25 Aug 2021 17:34:10 -0500 Subject: [PATCH] [BUG] ToBasicBlockNormalForm immutability (#8778) * ToBasicBlockNormalForm immutability * better comment on ToBasicBlock * refine comment of ToBasicBlockForm --- .../transforms/to_basic_block_normal_form.cc | 11 +++++---- .../test_pass_to_basic_block_normal_form.py | 24 ++++++++++++++++++- 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/src/relay/transforms/to_basic_block_normal_form.cc b/src/relay/transforms/to_basic_block_normal_form.cc index 79157bba1918..d03fc1488aea 100644 --- a/src/relay/transforms/to_basic_block_normal_form.cc +++ b/src/relay/transforms/to_basic_block_normal_form.cc @@ -51,8 +51,11 @@ Expr ToBasicBlockNormalFormAux(const Expr& e) { IRModule ToBasicBlockNormalForm(const IRModule& mod) { DLOG(INFO) << "ToBBlock:" << std::endl << mod; + // Create a new module by shallow copy. + auto mod_ = IRModule(mod->functions, mod->type_definitions, mod->Imports(), mod->source_map); + tvm::Map updates; - auto funcs = mod->functions; + auto funcs = mod_->functions; for (const auto& it : funcs) { ICHECK_EQ(FreeVars(it.second).size(), 0) << "Expected no free variables"; if (const auto* n = it.second.as()) { @@ -63,12 +66,12 @@ IRModule ToBasicBlockNormalForm(const IRModule& mod) { } for (auto pair : updates) { - mod->Add(pair.first, pair.second, true); + mod_->Add(pair.first, pair.second, true); } - DLOG(INFO) << "ToBBlock: transformed" << std::endl << mod; + DLOG(INFO) << "ToBBlock: transformed" << std::endl << mod_; - return mod; + return mod_; } bool BasicBlockNormalFormCheck(const Expr& e) { diff --git a/tests/python/relay/test_pass_to_basic_block_normal_form.py b/tests/python/relay/test_pass_to_basic_block_normal_form.py index 642cab751b79..d04afe15b5bb 100644 --- a/tests/python/relay/test_pass_to_basic_block_normal_form.py +++ b/tests/python/relay/test_pass_to_basic_block_normal_form.py @@ -22,7 +22,7 @@ from tvm.relay.analysis import detect_feature from tvm.relay import op, create_executor, transform from tvm.relay.prelude import Prelude -from tvm.relay.testing import count +from tvm.relay.testing import count, create_workload from tvm.relay.analysis import Feature from tvm.relay.analysis import check_basic_block_normal_form @@ -491,5 +491,27 @@ def test_higher_order_nested(): check_basic_block_normal_form(bblock) +def test_immutability(): + simple_net = relay.nn.conv2d( + data=relay.var("data", relay.TensorType((1, 3, 224, 224), "float32")), + weight=relay.var("weight"), + kernel_size=(5, 5), + channels=3, + padding=(1, 1), + ) + simple_net = relay.Function(relay.analysis.free_vars(simple_net), simple_net) + mod, _ = create_workload(simple_net) + + old_mod = mod + + with tvm.transform.PassContext(opt_level=4): + with tvm.target.Target("llvm"): + seq = tvm.transform.Sequential(passes=[transform.ToBasicBlockNormalForm()], opt_level=4) + new_mod = seq(mod) + + assert old_mod.astext() == mod.astext() + assert old_mod.astext() != new_mod.astext() + + if __name__ == "__main__": pytest.main([__file__])