[REFACTOR][IR] Migrate IRModule ObjectRef to not-null (apache#5654)
ANSHUMAN TRIPATHY authored and Trevor Morris committed Jun 18, 2020
1 parent 8e6e824 commit c13cab5
Showing 8 changed files with 26 additions and 25 deletions.
12 changes: 5 additions & 7 deletions include/tvm/ir/module.h
@@ -285,7 +285,7 @@ class IRModule : public ObjectRef {
       Map<GlobalTypeVar, TypeData> type_definitions = {},
       std::unordered_set<String> import_set = {});
   /*! \brief default constructor */
-  IRModule() {}
+  IRModule() : IRModule(Map<GlobalVar, BaseFunc>()) {}
   /*!
    * \brief constructor
    * \param n The object pointer.
@@ -298,12 +298,6 @@ class IRModule : public ObjectRef {
     return static_cast<IRModuleNode*>(ptr);
   }
 
-  /*!
-   * \brief Construct an empty module.
-   *
-   * \returns The constructed module
-   */
-  static IRModule Empty() { return IRModule(Map<GlobalVar, BaseFunc>()); }
   /*!
    * \brief Construct a module from a standalone expression.
    *
@@ -330,6 +324,10 @@ class IRModule : public ObjectRef {
 
   /*! \brief Declare the container type. */
   using ContainerType = IRModuleNode;
+
+  /*! \brief Declare whether Ref is nullable. */
+  static constexpr bool _type_is_nullable = false;
+
   // allow copy on write.
   TVM_DEFINE_OBJECT_REF_COW_METHOD(IRModuleNode);
 };
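
The _type_is_nullable = false flag added above is the heart of this refactor: once the reference type is declared non-nullable, a default-constructed IRModule must hold a real node, which is why the default constructor now delegates to the map constructor and the separate Empty() factory can go. A self-contained sketch of the pattern, with std::shared_ptr standing in for TVM's object system and illustrative ModuleRef/ModuleNode names rather than TVM's actual machinery:

#include <cassert>
#include <map>
#include <memory>
#include <string>

// Toy node type standing in for IRModuleNode.
struct ModuleNode {
  std::map<std::string, std::string> functions;
};

// Toy reference type standing in for IRModule. Declaring the ref
// non-nullable means even the default constructor must allocate a node.
class ModuleRef {
 public:
  // Before the change: ModuleRef() {} left data_ null.
  // After: delegate to the explicit constructor with an empty map.
  ModuleRef() : ModuleRef(std::map<std::string, std::string>{}) {}
  explicit ModuleRef(std::map<std::string, std::string> functions)
      : data_(std::make_shared<ModuleNode>(ModuleNode{std::move(functions)})) {}

  ModuleNode* operator->() const { return data_.get(); }
  bool defined() const { return data_ != nullptr; }

  // Analogous to TVM's _type_is_nullable flag.
  static constexpr bool type_is_nullable = false;

 private:
  std::shared_ptr<ModuleNode> data_;
};

int main() {
  ModuleRef mod;          // no longer a null reference
  assert(mod.defined());  // holds an empty module, as Empty() used to return
  mod->functions["main"] = "fn() {}";
  return 0;
}

Call sites never need a defined() check before dereferencing a default-constructed reference; validity is guaranteed by construction.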
3 changes: 2 additions & 1 deletion python/tvm/relay/transform/transform.py
@@ -656,7 +656,8 @@ def to_cps(func, mod=None):
     result: tvm.relay.Function
         The output function.
     """
-    return _ffi_api.to_cps(func, mod)
+    use_mod = mod if mod is not None else tvm.ir.IRModule()
+    return _ffi_api.to_cps(func, use_mod)
 
 
 def un_cps(func):
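
On the Python side, None can no longer flow into a parameter that is now a non-nullable IRModule in C++, so the wrapper materializes an empty module before crossing the FFI. The same guard sketched in C++, with std::optional playing the role of Python's None and stand-in ToCps/Function/Module types rather than TVM's API:

#include <iostream>
#include <optional>
#include <string>

// Stand-ins: "Module" for tvm.ir.IRModule, "Function" for a Relay function.
struct Module {};
struct Function { std::string name; };

// The non-nullable callee: always receives a real module.
Function ToCps(const Function& func, const Module& /*mod*/) {
  return Function{func.name + "_cps"};
}

// Mirrors the Python fix: default-construct a module when none is given.
Function ToCpsOrDefault(const Function& func, const std::optional<Module>& mod) {
  Module use_mod = mod.has_value() ? *mod : Module();  // "use_mod" in the diff
  return ToCps(func, use_mod);
}

int main() {
  std::cout << ToCpsOrDefault(Function{"f"}, std::nullopt).name << "\n";  // f_cps
  return 0;
}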
7 changes: 5 additions & 2 deletions src/relay/analysis/feature.cc
@@ -96,8 +96,11 @@ FeatureSet DetectFeature(const IRModule& mod) {
   return fs;
 }
 
-Array<Integer> PyDetectFeature(const Expr& expr, const IRModule& mod) {
-  FeatureSet fs = DetectFeature(expr) + DetectFeature(mod);
+Array<Integer> PyDetectFeature(const Expr& expr, const Optional<IRModule>& mod) {
+  FeatureSet fs = DetectFeature(expr);
+  if (mod.defined()) {
+    fs = fs + DetectFeature(mod.value());
+  }
   return static_cast<Array<Integer>>(fs);
 }
 
Expand Down
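With IRModule itself never null, "no module supplied" now has to be stated in the signature, which is what Optional<IRModule> does here. A compilable analog of the fixed PyDetectFeature, using std::optional in place of tvm::Optional and invented Module/DetectFeatures names; the same defined()-then-value() shape appears again in match_exhaustion.cc below:

#include <iostream>
#include <optional>
#include <string>

// Toy stand-in for IRModule.
struct Module { int num_functions = 0; };

// Feature detection over an expression alone, and over a module.
int DetectFeatures(const std::string& expr) { return static_cast<int>(expr.size()); }
int DetectFeatures(const Module& mod) { return mod.num_functions; }

// Mirrors the fixed PyDetectFeature: the module's contribution is only
// folded in when a module was actually supplied.
int DetectFeatures(const std::string& expr, const std::optional<Module>& mod) {
  int fs = DetectFeatures(expr);
  if (mod.has_value()) {         // tvm::Optional spells this defined()
    fs += DetectFeatures(*mod);  // and value()
  }
  return fs;
}

int main() {
  std::cout << DetectFeatures("x + y", std::nullopt) << "\n";  // expr only: 5
  std::cout << DetectFeatures("x + y", Module{3}) << "\n";     // expr + mod: 8
  return 0;
}
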
7 changes: 2 additions & 5 deletions src/relay/analysis/match_exhaustion.cc
@@ -305,11 +305,8 @@ Array<Pattern> UnmatchedCases(const Match& match, const IRModule& mod) {
 
 // expose for testing only
 TVM_REGISTER_GLOBAL("relay.analysis.unmatched_cases")
-    .set_body_typed([](const Match& match, const IRModule& mod_ref) {
-      IRModule call_mod = mod_ref;
-      if (!call_mod.defined()) {
-        call_mod = IRModule({}, {});
-      }
+    .set_body_typed([](const Match& match, const Optional<IRModule>& mod_ref) {
+      IRModule call_mod = mod_ref.defined() ? mod_ref.value() : IRModule({}, {});
       return UnmatchedCases(match, call_mod);
     });
 
2 changes: 1 addition & 1 deletion src/relay/backend/compile_engine.h
@@ -82,7 +82,7 @@ struct CachedFuncNode : public Object {
   /*! \brief The schedule to the function */
   te::Schedule schedule;
   /*! \brief The lowered functions to support the function. */
-  IRModule funcs = IRModule::Empty();
+  IRModule funcs = IRModule();
 
   /*! \brief Parameter usage states in the shape function. */
   tvm::Array<Integer> shape_func_param_states;
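
Because a default-constructed IRModule is now a valid empty module, member initializers like this one simply default-construct where they previously called the removed IRModule::Empty() factory. A minimal sketch of the equivalence, using an illustrative value-type Module rather than TVM's reference type:

#include <cassert>
#include <map>
#include <string>

struct Module {
  std::map<std::string, std::string> functions;
  // The kind of factory this commit removes: it adds nothing once a
  // default-constructed instance is already the empty module.
  static Module Empty() { return Module{}; }
};

int main() {
  Module m1 = Module::Empty();  // old spelling
  Module m2;                    // new spelling: default construction
  assert(m1.functions == m2.functions);
  return 0;
}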
4 changes: 2 additions & 2 deletions src/relay/backend/graph_runtime_codegen.cc
@@ -207,7 +207,7 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator<std::vector<G
 
     for (auto& kv : lowered_funcs_) {
       if (ret.lowered_funcs.count(kv.first) == 0) {
-        ret.lowered_funcs.Set(kv.first, IRModule::Empty());
+        ret.lowered_funcs.Set(kv.first, IRModule());
       }
       auto& mod = ret.lowered_funcs[kv.first];
       mod->Update(kv.second);
@@ -395,7 +395,7 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator<std::vector<G
     CCacheKey key = (*pf0)(func, target);
     CachedFunc lowered_func = (*pf1)(compile_engine_, key);
     if (!lowered_funcs_.count(target->str())) {
-      lowered_funcs_[target->str()] = IRModule::Empty();
+      lowered_funcs_[target->str()] = IRModule();
     }
     lowered_funcs_[target->str()]->Update(lowered_func->funcs);
     return GraphAddCallNode(op, _GetUniqueName(lowered_func->func_name), lowered_func->func_name);
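
Both hunks use the same idiom: on first sight of a target, install a fresh empty module under that key, then merge lowered functions into it. A stand-alone sketch of that default-construct-then-update flow, where Module is just a name-to-definition map standing in for IRModule, Update mirrors IRModuleNode::Update, and the "cuda"/"fused_add" values are made up:

#include <iostream>
#include <map>
#include <string>

using Module = std::map<std::string, std::string>;

// Mirrors mod->Update(other): merge another module's functions in.
void Update(Module& mod, const Module& other) {
  for (const auto& kv : other) mod[kv.first] = kv.second;
}

int main() {
  std::map<std::string, Module> lowered_funcs;
  Module incoming = {{"fused_add", "lowered body"}};

  // The idiom from the diff: default-construct an empty module for a
  // new target (valid now that the ref is non-nullable), then merge.
  if (!lowered_funcs.count("cuda")) {
    lowered_funcs["cuda"] = Module();  // was IRModule::Empty()
  }
  Update(lowered_funcs["cuda"], incoming);

  std::cout << lowered_funcs["cuda"].size() << " function(s) for cuda\n";
  return 0;
}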
14 changes: 8 additions & 6 deletions src/relay/transforms/gradient.cc
@@ -68,7 +68,7 @@ Type WithGradientType(const Type&);
 /*! return an expression that represent differentiation of e (according to WithGradientType).
  * This version only work on first order code without control flow.
  */
-Expr FirstOrderGradient(const Expr& e, const IRModule& mod);
+Expr FirstOrderGradient(const Expr& e, const Optional<IRModule>& mod);
 
 Type WithGradientType(const Type& t) {
   // TODO(M.K.): stricter checking
@@ -78,9 +78,11 @@ Type WithGradientType(const Type& t) {
 }
 
 //! \brief if the expression is a GlobalVar, transform to it's expression.
-Expr DeGlobal(const IRModule& mod, const Expr& e) {
-  if (const auto* x = e.as<GlobalVarNode>()) {
-    BaseFunc base_func = mod->Lookup(GetRef<GlobalVar>(x));
+Expr DeGlobal(const Optional<IRModule>& mod, const Expr& e) {
+  const auto* x = e.as<GlobalVarNode>();
+
+  if (mod.defined() && (x)) {
+    BaseFunc base_func = mod.value()->Lookup(GetRef<GlobalVar>(x));
     if (auto* n = base_func.as<FunctionNode>()) {
       return n->body;
     } else {
@@ -214,7 +216,7 @@ Type GradRetType(const Function& f) {
   return TupleType({f->ret_type, TupleType(vt)});
 }
 
-Expr FirstOrderGradient(const Expr& re, const IRModule& mod) {
+Expr FirstOrderGradient(const Expr& re, const Optional<IRModule>& mod) {
   // Currently we first remove any global functions for the first
   // order case.
   auto e = DeGlobal(mod, re);
@@ -482,7 +484,7 @@ bool MissingGrad(const Expr& e) {
   return false;
 }
 
-Expr Gradient(const Expr& re, const IRModule& mod) {
+Expr Gradient(const Expr& re, const Optional<IRModule>& mod) {
   auto e = DeGlobal(mod, re);
   auto f = e.as<FunctionNode>();
   CHECK(f) << "input need to be a function";
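
DeGlobal shows the full shape of the migration: the module may be absent, and the expression may or may not be a GlobalVar, so the lookup runs only when both conditions hold. A compilable analog, with std::optional in place of tvm::Optional and toy Expr/Module types in place of Relay's:

#include <iostream>
#include <map>
#include <optional>
#include <string>

// Stand-ins: an "expression" is either a reference to a global name or an
// inline body; a "module" maps global names to bodies, like IRModule::Lookup.
struct Expr {
  std::optional<std::string> global;  // set when the expr is a GlobalVar
  std::string body;                   // used otherwise
};
using Module = std::map<std::string, std::string>;

// Mirrors the fixed DeGlobal: dereference the module only when a module was
// supplied AND the expression is actually a global reference.
std::string DeGlobal(const std::optional<Module>& mod, const Expr& e) {
  if (mod.has_value() && e.global.has_value()) {
    return mod->at(*e.global);  // Lookup(GetRef<GlobalVar>(x)) in the diff
  }
  return e.body;  // not a global, or no module to resolve it in
}

int main() {
  Module mod = {{"main", "fn(x) { x + x }"}};
  std::cout << DeGlobal(mod, Expr{"main", ""}) << "\n";            // resolved
  std::cout << DeGlobal(std::nullopt, Expr{{}, "x * 2"}) << "\n";  // passthrough
  return 0;
}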
2 changes: 1 addition & 1 deletion src/tir/transforms/split_host_device.cc
@@ -275,7 +275,7 @@ Pass SplitHostDevice() {
   auto pass_func = [](IRModule mod, PassContext ctx) {
     IRModuleNode* mod_ptr = mod.CopyOnWrite();
     auto* func_dict = mod_ptr->functions.CopyOnWrite();
-    IRModule device_mod = IRModule::Empty();
+    IRModule device_mod = IRModule();
 
     for (auto& kv : func_dict->data) {
       if (kv.second->IsInstance<PrimFuncNode>()) {
