Skip to content

Commit

Permalink
[REFACTOR] relay::Module Def -> TypeDef (apache#4665)
Browse files Browse the repository at this point in the history
* [REFACTOR] relay::Module Def -> TypeDef

The term Def was not very clear about what is the object of interest(could be function def or type def).
Changes the term to TypeDef to be more explicit.

* Update include/tvm/relay/module.h

Co-Authored-By: Wei Chen <[email protected]>

Co-authored-by: Wei Chen <[email protected]>
  • Loading branch information
2 people authored and alexwong committed Feb 28, 2020
1 parent 43a22c5 commit 23f60be
Show file tree
Hide file tree
Showing 8 changed files with 31 additions and 29 deletions.
22 changes: 12 additions & 10 deletions include/tvm/relay/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,18 +104,20 @@ class ModuleNode : public RelayNode {
* \param update Controls whether you can replace a definition in the
* environment.
*/
TVM_DLL void AddDef(const GlobalTypeVar& var, const TypeData& type, bool update = false);
TVM_DLL void AddTypeDef(const GlobalTypeVar& var, const TypeData& type, bool update = false);

/*!
* \brief Add a type definition to the global environment.
* \param var The name of the global function.
* \brief Add a type-level definition to the global environment.
* \param var The var of the global type definition.
* \param type The ADT.
* \param update Controls whether you can replace a definition in the
* environment.
*
* It does not do type inference as AddDef does.
* It does not do type checking as AddTypeDef does.
*/
TVM_DLL void AddDefUnchecked(const GlobalTypeVar& var, const TypeData& type, bool update = false);
TVM_DLL void AddTypeDefUnchecked(const GlobalTypeVar& var,
const TypeData& type,
bool update = false);

/*!
* \brief Update a function in the global environment.
Expand All @@ -129,7 +131,7 @@ class ModuleNode : public RelayNode {
* \param var The name of the global type definition to update.
* \param type The new ADT.
*/
TVM_DLL void UpdateDef(const GlobalTypeVar& var, const TypeData& type);
TVM_DLL void UpdateTypeDef(const GlobalTypeVar& var, const TypeData& type);

/*!
* \brief Remove a function from the global environment.
Expand Down Expand Up @@ -162,7 +164,7 @@ class ModuleNode : public RelayNode {
* \brief Collect all global vars defined in this module.
* \returns An array of global vars
*/
tvm::Array<GlobalVar> GetGlobalVars() const;
TVM_DLL tvm::Array<GlobalVar> GetGlobalVars() const;

/*!
* \brief Look up a global function by its name.
Expand All @@ -175,7 +177,7 @@ class ModuleNode : public RelayNode {
* \brief Collect all global type vars defined in this module.
* \returns An array of global type vars
*/
tvm::Array<GlobalTypeVar> GetGlobalTypeVars() const;
TVM_DLL tvm::Array<GlobalTypeVar> GetGlobalTypeVars() const;

/*!
* \brief Look up a global function by its variable.
Expand All @@ -196,14 +198,14 @@ class ModuleNode : public RelayNode {
* \param var The var of the global type definition.
* \return The type definition.
*/
TVM_DLL TypeData LookupDef(const GlobalTypeVar& var) const;
TVM_DLL TypeData LookupTypeDef(const GlobalTypeVar& var) const;

/*!
* \brief Look up a global type definition by its name.
* \param var The name of the global type definition.
* \return The type definition.
*/
TVM_DLL TypeData LookupDef(const std::string& var) const;
TVM_DLL TypeData LookupTypeDef(const std::string& var) const;

/*!
* \brief Look up a constructor by its tag.
Expand Down
2 changes: 1 addition & 1 deletion src/relay/ir/alpha_equal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class AlphaEqualHandler:
if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) return false;
for (const auto& p : lhsm->type_definitions) {
if (!rhsm->ContainGlobalTypeVar(p.first->name_hint) ||
!Equal(p.second, rhsm->LookupDef(p.first->name_hint))) {
!Equal(p.second, rhsm->LookupTypeDef(p.first->name_hint))) {
return false;
}
}
Expand Down
26 changes: 13 additions & 13 deletions src/relay/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,15 +185,15 @@ void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData&
}
}

void ModuleNode::AddDef(const GlobalTypeVar& var, const TypeData& type, bool update) {
AddDefUnchecked(var, type, update);
void ModuleNode::AddTypeDef(const GlobalTypeVar& var, const TypeData& type, bool update) {
AddTypeDefUnchecked(var, type, update);
// need to kind check at the end because the check can look up
// a definition potentially
CHECK(KindCheck(type, GetRef<Module>(this)) == Kind::kTypeData)
<< "Invalid or malformed typedata given to module: " << type;
}

void ModuleNode::AddDefUnchecked(const GlobalTypeVar& var, const TypeData& type, bool update) {
void ModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var, const TypeData& type, bool update) {
this->type_definitions.Set(var, type);
if (!update) {
// set global type var map
Expand All @@ -208,8 +208,8 @@ void ModuleNode::Update(const GlobalVar& var, const Function& func) {
this->Add(var, func, true);
}

void ModuleNode::UpdateDef(const GlobalTypeVar& var, const TypeData& type) {
this->AddDef(var, type, true);
void ModuleNode::UpdateTypeDef(const GlobalTypeVar& var, const TypeData& type) {
this->AddTypeDef(var, type, true);
}

void ModuleNode::Remove(const GlobalVar& var) {
Expand All @@ -231,16 +231,16 @@ Function ModuleNode::Lookup(const std::string& name) const {
return this->Lookup(id);
}

TypeData ModuleNode::LookupDef(const GlobalTypeVar& var) const {
TypeData ModuleNode::LookupTypeDef(const GlobalTypeVar& var) const {
auto it = type_definitions.find(var);
CHECK(it != type_definitions.end())
<< "There is no definition of " << var->name_hint;
return (*it).second;
}

TypeData ModuleNode::LookupDef(const std::string& name) const {
TypeData ModuleNode::LookupTypeDef(const std::string& name) const {
GlobalTypeVar id = this->GetGlobalTypeVar(name);
return this->LookupDef(id);
return this->LookupTypeDef(id);
}

Constructor ModuleNode::LookupTag(const int32_t tag) {
Expand All @@ -257,13 +257,13 @@ void ModuleNode::Update(const Module& mod) {
this->AddUnchecked(pair.first, pair.second);
}
for (auto pair : mod->type_definitions) {
this->AddDefUnchecked(pair.first, pair.second);
this->AddTypeDefUnchecked(pair.first, pair.second);
}
for (auto pair : mod->functions) {
this->Update(pair.first, pair.second);
}
for (auto pair : mod->type_definitions) {
this->UpdateDef(pair.first, pair.second);
this->UpdateTypeDef(pair.first, pair.second);
}
}

Expand Down Expand Up @@ -347,7 +347,7 @@ TVM_REGISTER_GLOBAL("relay._module.Module_Add")
});

TVM_REGISTER_GLOBAL("relay._module.Module_AddDef")
.set_body_method<Module>(&ModuleNode::AddDef);
.set_body_method<Module>(&ModuleNode::AddTypeDef);

TVM_REGISTER_GLOBAL("relay._module.Module_GetGlobalVar")
.set_body_method<Module>(&ModuleNode::GetGlobalVar);
Expand Down Expand Up @@ -376,12 +376,12 @@ TVM_REGISTER_GLOBAL("relay._module.Module_Lookup_str")

TVM_REGISTER_GLOBAL("relay._module.Module_LookupDef")
.set_body_typed([](Module mod, GlobalTypeVar var) {
return mod->LookupDef(var);
return mod->LookupTypeDef(var);
});

TVM_REGISTER_GLOBAL("relay._module.Module_LookupDef_str")
.set_body_typed([](Module mod, std::string var) {
return mod->LookupDef(var);
return mod->LookupTypeDef(var);
});

TVM_REGISTER_GLOBAL("relay._module.Module_LookupTag")
Expand Down
2 changes: 1 addition & 1 deletion src/relay/pass/eta_expand.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class EtaExpander : public ExprMutator {
params.push_back(VarNode::make("eta_expand_param", param_type));
}
tvm::Array<Type> type_params;
TypeData adt_def = mod_->LookupDef(cons->belong_to);
TypeData adt_def = mod_->LookupTypeDef(cons->belong_to);
for (const auto& type_var : adt_def->type_vars) {
type_params.push_back(type_var_replacer_.VisitType(type_var));
}
Expand Down
2 changes: 1 addition & 1 deletion src/relay/pass/kind_check.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ struct KindChecker : TypeFunctor<Kind(const Type&)> {

// finally we need to check the module to check the number of type params
auto var = GetRef<GlobalTypeVar>(gtv);
auto data = mod->LookupDef(var);
auto data = mod->LookupTypeDef(var);
if (data->type_vars.size() != op->args.size()) {
ReportFatalError(RELAY_ERROR("Expected " << data->type_vars.size() << "arguments for " << tc
<< "; got " << op->args.size()));
Expand Down
2 changes: 1 addition & 1 deletion src/relay/pass/match_exhaustion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ Array<Pattern> ExpandWildcardsConstructor(const PatternConstructor& clause_ctor,

// for a wildcard node, create constructor nodes with wildcards for all args.
if (cand.as<PatternWildcardNode>()) {
TypeData td = mod->LookupDef(gtv);
TypeData td = mod->LookupTypeDef(gtv);
// for each constructor add a candidate.
Array<Pattern> ret;
for (auto constructor : td->constructors) {
Expand Down
2 changes: 1 addition & 1 deletion src/relay/pass/type_infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
CHECK(mod_.defined())
<< "Cannot do type inference without a environment:"
<< c->name_hint;
TypeData td = mod_->LookupDef(c->belong_to);
TypeData td = mod_->LookupTypeDef(c->belong_to);
std::vector<Type> types;
for (const auto & t : td->type_vars) {
types.push_back(t);
Expand Down
2 changes: 1 addition & 1 deletion src/relay/pass/util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ class TypeVarEVisitor : private ExprVisitor {

void VisitExpr_(const ConstructorNode* cn) final {
// for constructors, type vars will be bound in the module
auto data = mod_->LookupDef(cn->belong_to);
auto data = mod_->LookupTypeDef(cn->belong_to);
for (const auto& tv : data->type_vars) {
type_vars_.Insert(tv);
bound_type_vars_.Insert(tv);
Expand Down

0 comments on commit 23f60be

Please sign in to comment.