From 9f1d1724dfeccb1a87faf7617bd01029f0e3d814 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Tue, 16 Jun 2020 09:53:47 +0200 Subject: [PATCH] fix relay.build to not change the module argument in place --- src/relay/backend/build_module.cc | 3 ++- tests/python/relay/test_cpp_build_module.py | 7 ++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index f9ce24d410b7..dea923d9013f 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -244,7 +244,8 @@ class RelayBuildModule : public runtime::ModuleNode { GlobalVar main_glb_var = relay_module->GetGlobalVar("main"); Function main_func = Downcast(relay_module->Lookup(main_glb_var)); auto new_main = BindParamsByName(main_func, params); - relay_module->Update(main_glb_var, new_main); + IRModuleNode* relay_module_ptr = relay_module.CopyOnWrite(); + relay_module_ptr->Update(main_glb_var, new_main); } Array pass_seqs; diff --git a/tests/python/relay/test_cpp_build_module.py b/tests/python/relay/test_cpp_build_module.py index 8d5438424e32..fa56eb0eef29 100644 --- a/tests/python/relay/test_cpp_build_module.py +++ b/tests/python/relay/test_cpp_build_module.py @@ -44,7 +44,12 @@ def test_basic_build(): targets = { tvm.tir.IntImm("int32", ctx.device_type): tgt } - g_json, mmod, params = relay.build(tvm.IRModule.from_expr(func), targets, "llvm", params=params) + mod = tvm.IRModule.from_expr(func) + func_in_mod = mod["main"] + assert mod["main"] == func_in_mod, "cannot compare function to itself" + + g_json, mmod, params = relay.build(mod, targets, "llvm", params=params) + assert mod["main"] == func_in_mod, "relay.build changed module in-place" # test rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx)