diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index f9ce24d410b7..dea923d9013f 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -244,7 +244,8 @@ class RelayBuildModule : public runtime::ModuleNode { GlobalVar main_glb_var = relay_module->GetGlobalVar("main"); Function main_func = Downcast(relay_module->Lookup(main_glb_var)); auto new_main = BindParamsByName(main_func, params); - relay_module->Update(main_glb_var, new_main); + IRModuleNode* relay_module_ptr = relay_module.CopyOnWrite(); + relay_module_ptr->Update(main_glb_var, new_main); } Array pass_seqs; diff --git a/tests/python/relay/test_cpp_build_module.py b/tests/python/relay/test_cpp_build_module.py index 8d5438424e32..fa56eb0eef29 100644 --- a/tests/python/relay/test_cpp_build_module.py +++ b/tests/python/relay/test_cpp_build_module.py @@ -44,7 +44,12 @@ def test_basic_build(): targets = { tvm.tir.IntImm("int32", ctx.device_type): tgt } - g_json, mmod, params = relay.build(tvm.IRModule.from_expr(func), targets, "llvm", params=params) + mod = tvm.IRModule.from_expr(func) + func_in_mod = mod["main"] + assert mod["main"] == func_in_mod, "cannot compare function to itself" + + g_json, mmod, params = relay.build(mod, targets, "llvm", params=params) + assert mod["main"] == func_in_mod, "relay.build changed module in-place" # test rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx)