From 559f0c76a0a8ee9c1620ee29ecd8ce1ced07093e Mon Sep 17 00:00:00 2001 From: Florin Blanaru Date: Mon, 30 May 2022 08:18:32 +0100 Subject: [PATCH] [Pass] Add utility that asserts that IRModule is not mutated in a pass. (#11498) --- include/tvm/ir/transform.h | 4 ++ src/ir/transform.cc | 25 ++++++- tests/cpp/pass_immutable_module_test.cc | 86 +++++++++++++++++++++++++ 3 files changed, 114 insertions(+), 1 deletion(-) create mode 100644 tests/cpp/pass_immutable_module_test.cc diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index d8f6632a66ca..febcca5c0107 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -390,6 +390,10 @@ class Pass : public ObjectRef { IRModule operator()(IRModule mod, const PassContext& pass_ctx) const; TVM_DEFINE_OBJECT_REF_METHODS(Pass, ObjectRef, PassNode); + + private: + IRModule static AssertImmutableModule(const IRModule& mod, const PassNode* node, + const PassContext& pass_ctx); }; /*! diff --git a/src/ir/transform.cc b/src/ir/transform.cc index dfd307d715ae..d945278abc72 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include @@ -41,6 +42,8 @@ using tvm::ReprPrinter; using tvm::runtime::TVMArgs; using tvm::runtime::TVMRetValue; +TVM_REGISTER_PASS_CONFIG_OPTION("testing.immutable_module", Bool); + struct PassContextThreadLocalEntry { /*! \brief The default pass context. */ PassContext default_context; @@ -264,11 +267,31 @@ IRModule Pass::operator()(IRModule mod, const PassContext& pass_ctx) const { << " with opt level: " << pass_info->opt_level; return mod; } - auto ret = node->operator()(std::move(mod), pass_ctx); + IRModule ret; + if (pass_ctx->GetConfig("testing.immutable_module", Bool(false)).value()) { + ret = Pass::AssertImmutableModule(mod, node, pass_ctx); + } else { + ret = node->operator()(std::move(mod), pass_ctx); + } pass_ctx.InstrumentAfterPass(ret, pass_info); return std::move(ret); } +IRModule Pass::AssertImmutableModule(const IRModule& mod, const PassNode* node, + const PassContext& pass_ctx) { + size_t before_pass_hash = tvm::StructuralHash()(mod); + ObjectPtr module_ptr = ObjectRef::GetDataPtr(mod); + IRModule copy_mod = IRModule(module_ptr); + IRModule ret = node->operator()(mod, pass_ctx); + size_t after_pass_hash = tvm::StructuralHash()(copy_mod); + if (before_pass_hash != after_pass_hash) { + // The chance of getting a hash conflict between a module and the same module but mutated + // must be very low. + LOG_FATAL << "Immutable module has been modified in pass: " << node->Info()->name; + } + return std::move(ret); +} + /*! * \brief Module-level passes are designed to implement global * analysis/optimizations, i.e. interprocedural optimizations (IPO), etc. Passes diff --git a/tests/cpp/pass_immutable_module_test.cc b/tests/cpp/pass_immutable_module_test.cc new file mode 100644 index 000000000000..b90f1deee737 --- /dev/null +++ b/tests/cpp/pass_immutable_module_test.cc @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace tvm; +using namespace transform; + +Pass MutateModulePass() { + auto pass_func = [=](IRModule mod, PassContext pc) -> IRModule { + GlobalVar var = mod->GetGlobalVar("dummyFunction"); + mod->Remove(var); + return mod; + }; + return tvm::transform::CreateModulePass(pass_func, 1, "ImmutableModulev1", {}); +} + +Pass DoNotMutateModulePass() { + auto pass_func = [=](IRModule mod, PassContext pc) -> IRModule { + IRModule result(mod->functions, mod->type_definitions, mod->Imports(), mod->source_map, + mod->attrs); + GlobalVar var = result->GetGlobalVar("dummyFunction"); + result->Remove(var); + return result; + }; + return tvm::transform::CreateModulePass(pass_func, 1, "ImmutableModulev2", {}); +} + +IRModule preamble() { + auto x = relay::Var("x", relay::Type()); + auto f = relay::Function(tvm::Array{x}, x, relay::Type(), {}); + ICHECK(f->IsInstance()); + + auto global_var = GlobalVar("dummyFunction"); + auto mod = IRModule::FromExpr(f, {{global_var, f}}, {}); + return mod; +} + +TEST(Relay, ModuleIsMutated) { + IRModule mod = preamble(); + + EXPECT_THROW( + { + auto pass_ctx = relay::transform::PassContext::Create(); + pass_ctx->config.Set("testing.immutable_module", Bool(true)); + { + tvm::With ctx_scope(pass_ctx); + mod = MutateModulePass()(mod); + } + }, + runtime::InternalError); +} + +TEST(Relay, ModuleIsNotMutated) { + IRModule mod = preamble(); + + auto pass_ctx = relay::transform::PassContext::Create(); + pass_ctx->config.Set("testing.immutable_module", Bool(true)); + { + tvm::With ctx_scope(pass_ctx); + mod = DoNotMutateModulePass()(mod); + } +}