Skip to content

Commit

Permalink
[Pass] Add utility that asserts that IRModule is not mutated in a pas…
Browse files Browse the repository at this point in the history
…s. (#11498)
  • Loading branch information
gigiblender authored May 30, 2022
1 parent d0b3ec9 commit 559f0c7
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 1 deletion.
4 changes: 4 additions & 0 deletions include/tvm/ir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,10 @@ class Pass : public ObjectRef {
IRModule operator()(IRModule mod, const PassContext& pass_ctx) const;

TVM_DEFINE_OBJECT_REF_METHODS(Pass, ObjectRef, PassNode);

private:
IRModule static AssertImmutableModule(const IRModule& mod, const PassNode* node,
const PassContext& pass_ctx);
};

/*!
Expand Down
25 changes: 24 additions & 1 deletion src/ir/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <dmlc/thread_local.h>
#include <tvm/ir/transform.h>
#include <tvm/node/repr_printer.h>
#include <tvm/node/structural_hash.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>

Expand All @@ -41,6 +42,8 @@ using tvm::ReprPrinter;
using tvm::runtime::TVMArgs;
using tvm::runtime::TVMRetValue;

TVM_REGISTER_PASS_CONFIG_OPTION("testing.immutable_module", Bool);

struct PassContextThreadLocalEntry {
/*! \brief The default pass context. */
PassContext default_context;
Expand Down Expand Up @@ -264,11 +267,31 @@ IRModule Pass::operator()(IRModule mod, const PassContext& pass_ctx) const {
<< " with opt level: " << pass_info->opt_level;
return mod;
}
auto ret = node->operator()(std::move(mod), pass_ctx);
IRModule ret;
if (pass_ctx->GetConfig<Bool>("testing.immutable_module", Bool(false)).value()) {
ret = Pass::AssertImmutableModule(mod, node, pass_ctx);
} else {
ret = node->operator()(std::move(mod), pass_ctx);
}
pass_ctx.InstrumentAfterPass(ret, pass_info);
return std::move(ret);
}

IRModule Pass::AssertImmutableModule(const IRModule& mod, const PassNode* node,
const PassContext& pass_ctx) {
size_t before_pass_hash = tvm::StructuralHash()(mod);
ObjectPtr<Object> module_ptr = ObjectRef::GetDataPtr<Object>(mod);
IRModule copy_mod = IRModule(module_ptr);
IRModule ret = node->operator()(mod, pass_ctx);
size_t after_pass_hash = tvm::StructuralHash()(copy_mod);
if (before_pass_hash != after_pass_hash) {
// The chance of getting a hash conflict between a module and the same module but mutated
// must be very low.
LOG_FATAL << "Immutable module has been modified in pass: " << node->Info()->name;
}
return std::move(ret);
}

/*!
* \brief Module-level passes are designed to implement global
* analysis/optimizations, i.e. interprocedural optimizations (IPO), etc. Passes
Expand Down
86 changes: 86 additions & 0 deletions tests/cpp/pass_immutable_module_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <gtest/gtest.h>
#include <tvm/ir/module.h>
#include <tvm/node/structural_equal.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/type.h>
#include <tvm/te/operation.h>

using namespace tvm;
using namespace transform;

Pass MutateModulePass() {
auto pass_func = [=](IRModule mod, PassContext pc) -> IRModule {
GlobalVar var = mod->GetGlobalVar("dummyFunction");
mod->Remove(var);
return mod;
};
return tvm::transform::CreateModulePass(pass_func, 1, "ImmutableModulev1", {});
}

Pass DoNotMutateModulePass() {
auto pass_func = [=](IRModule mod, PassContext pc) -> IRModule {
IRModule result(mod->functions, mod->type_definitions, mod->Imports(), mod->source_map,
mod->attrs);
GlobalVar var = result->GetGlobalVar("dummyFunction");
result->Remove(var);
return result;
};
return tvm::transform::CreateModulePass(pass_func, 1, "ImmutableModulev2", {});
}

IRModule preamble() {
auto x = relay::Var("x", relay::Type());
auto f = relay::Function(tvm::Array<relay::Var>{x}, x, relay::Type(), {});
ICHECK(f->IsInstance<BaseFuncNode>());

auto global_var = GlobalVar("dummyFunction");
auto mod = IRModule::FromExpr(f, {{global_var, f}}, {});
return mod;
}

TEST(Relay, ModuleIsMutated) {
IRModule mod = preamble();

EXPECT_THROW(
{
auto pass_ctx = relay::transform::PassContext::Create();
pass_ctx->config.Set("testing.immutable_module", Bool(true));
{
tvm::With<relay::transform::PassContext> ctx_scope(pass_ctx);
mod = MutateModulePass()(mod);
}
},
runtime::InternalError);
}

TEST(Relay, ModuleIsNotMutated) {
IRModule mod = preamble();

auto pass_ctx = relay::transform::PassContext::Create();
pass_ctx->config.Set("testing.immutable_module", Bool(true));
{
tvm::With<relay::transform::PassContext> ctx_scope(pass_ctx);
mod = DoNotMutateModulePass()(mod);
}
}

0 comments on commit 559f0c7

Please sign in to comment.