diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h index 262c82df5c5d..7895f32d7e21 100644 --- a/include/tvm/relay/module.h +++ b/include/tvm/relay/module.h @@ -104,18 +104,20 @@ class ModuleNode : public RelayNode { * \param update Controls whether you can replace a definition in the * environment. */ - TVM_DLL void AddDef(const GlobalTypeVar& var, const TypeData& type, bool update = false); + TVM_DLL void AddTypeDef(const GlobalTypeVar& var, const TypeData& type, bool update = false); /*! - * \brief Add a type definition to the global environment. - * \param var The name of the global function. + * \brief Add a type-level definition to the global environment. + * \param var The var of the global type definition. * \param type The ADT. * \param update Controls whether you can replace a definition in the * environment. * - * It does not do type inference as AddDef does. + * It does not do kind checking as AddTypeDef does. */ - TVM_DLL void AddDefUnchecked(const GlobalTypeVar& var, const TypeData& type, bool update = false); + TVM_DLL void AddTypeDefUnchecked(const GlobalTypeVar& var, + const TypeData& type, + bool update = false); /*! * \brief Update a function in the global environment. @@ -129,7 +131,7 @@ class ModuleNode : public RelayNode { * \param var The name of the global type definition to update. * \param type The new ADT. */ - TVM_DLL void UpdateDef(const GlobalTypeVar& var, const TypeData& type); + TVM_DLL void UpdateTypeDef(const GlobalTypeVar& var, const TypeData& type); /*! * \brief Remove a function from the global environment. @@ -162,7 +164,7 @@ class ModuleNode : public RelayNode { * \brief Collect all global vars defined in this module. * \returns An array of global vars */ - tvm::Array GetGlobalVars() const; + TVM_DLL tvm::Array GetGlobalVars() const; /*! * \brief Look up a global function by its name. @@ -175,7 +177,7 @@ class ModuleNode : public RelayNode { * \brief Collect all global type vars defined in this module. * \returns An array of global type vars */ - tvm::Array GetGlobalTypeVars() const; + TVM_DLL tvm::Array GetGlobalTypeVars() const; /*! * \brief Look up a global function by its variable. @@ -196,14 +198,14 @@ class ModuleNode : public RelayNode { * \param var The var of the global type definition. * \return The type definition. */ - TVM_DLL TypeData LookupDef(const GlobalTypeVar& var) const; + TVM_DLL TypeData LookupTypeDef(const GlobalTypeVar& var) const; /*! * \brief Look up a global type definition by its name. * \param var The name of the global type definition. * \return The type definition. */ - TVM_DLL TypeData LookupDef(const std::string& var) const; + TVM_DLL TypeData LookupTypeDef(const std::string& var) const; /*! * \brief Look up a constructor by its tag. diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc index b41d381dd827..bef0454432f2 100644 --- a/src/relay/ir/alpha_equal.cc +++ b/src/relay/ir/alpha_equal.cc @@ -70,7 +70,7 @@ class AlphaEqualHandler: if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) return false; for (const auto& p : lhsm->type_definitions) { if (!rhsm->ContainGlobalTypeVar(p.first->name_hint) || - !Equal(p.second, rhsm->LookupDef(p.first->name_hint))) { + !Equal(p.second, rhsm->LookupTypeDef(p.first->name_hint))) { return false; } } diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index 81dae73b6642..fdaa607e380f 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -185,15 +185,15 @@ void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData& } } -void ModuleNode::AddDef(const GlobalTypeVar& var, const TypeData& type, bool update) { - AddDefUnchecked(var, type, update); +void ModuleNode::AddTypeDef(const GlobalTypeVar& var, const TypeData& type, bool update) { + AddTypeDefUnchecked(var, type, update); // need to kind check at the end because the check can look up // a definition potentially CHECK(KindCheck(type, GetRef(this)) == Kind::kTypeData) << "Invalid or malformed typedata given to module: " << type; } -void ModuleNode::AddDefUnchecked(const GlobalTypeVar& var, const TypeData& type, bool update) { +void ModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var, const TypeData& type, bool update) { this->type_definitions.Set(var, type); if (!update) { // set global type var map @@ -208,8 +208,8 @@ void ModuleNode::Update(const GlobalVar& var, const Function& func) { this->Add(var, func, true); } -void ModuleNode::UpdateDef(const GlobalTypeVar& var, const TypeData& type) { - this->AddDef(var, type, true); +void ModuleNode::UpdateTypeDef(const GlobalTypeVar& var, const TypeData& type) { + this->AddTypeDef(var, type, true); } void ModuleNode::Remove(const GlobalVar& var) { @@ -231,16 +231,16 @@ Function ModuleNode::Lookup(const std::string& name) const { return this->Lookup(id); } -TypeData ModuleNode::LookupDef(const GlobalTypeVar& var) const { +TypeData ModuleNode::LookupTypeDef(const GlobalTypeVar& var) const { auto it = type_definitions.find(var); CHECK(it != type_definitions.end()) << "There is no definition of " << var->name_hint; return (*it).second; } -TypeData ModuleNode::LookupDef(const std::string& name) const { +TypeData ModuleNode::LookupTypeDef(const std::string& name) const { GlobalTypeVar id = this->GetGlobalTypeVar(name); - return this->LookupDef(id); + return this->LookupTypeDef(id); } Constructor ModuleNode::LookupTag(const int32_t tag) { @@ -257,13 +257,13 @@ void ModuleNode::Update(const Module& mod) { this->AddUnchecked(pair.first, pair.second); } for (auto pair : mod->type_definitions) { - this->AddDefUnchecked(pair.first, pair.second); + this->AddTypeDefUnchecked(pair.first, pair.second); } for (auto pair : mod->functions) { this->Update(pair.first, pair.second); } for (auto pair : mod->type_definitions) { - this->UpdateDef(pair.first, pair.second); + this->UpdateTypeDef(pair.first, pair.second); } } @@ -347,7 +347,7 @@ TVM_REGISTER_GLOBAL("relay._module.Module_Add") }); TVM_REGISTER_GLOBAL("relay._module.Module_AddDef") -.set_body_method(&ModuleNode::AddDef); +.set_body_method(&ModuleNode::AddTypeDef); TVM_REGISTER_GLOBAL("relay._module.Module_GetGlobalVar") .set_body_method(&ModuleNode::GetGlobalVar); @@ -376,12 +376,12 @@ TVM_REGISTER_GLOBAL("relay._module.Module_Lookup_str") TVM_REGISTER_GLOBAL("relay._module.Module_LookupDef") .set_body_typed([](Module mod, GlobalTypeVar var) { - return mod->LookupDef(var); + return mod->LookupTypeDef(var); }); TVM_REGISTER_GLOBAL("relay._module.Module_LookupDef_str") .set_body_typed([](Module mod, std::string var) { - return mod->LookupDef(var); + return mod->LookupTypeDef(var); }); TVM_REGISTER_GLOBAL("relay._module.Module_LookupTag") diff --git a/src/relay/pass/eta_expand.cc b/src/relay/pass/eta_expand.cc index 672d551113ee..b9973f3da1f3 100644 --- a/src/relay/pass/eta_expand.cc +++ b/src/relay/pass/eta_expand.cc @@ -101,7 +101,7 @@ class EtaExpander : public ExprMutator { params.push_back(VarNode::make("eta_expand_param", param_type)); } tvm::Array type_params; - TypeData adt_def = mod_->LookupDef(cons->belong_to); + TypeData adt_def = mod_->LookupTypeDef(cons->belong_to); for (const auto& type_var : adt_def->type_vars) { type_params.push_back(type_var_replacer_.VisitType(type_var)); } diff --git a/src/relay/pass/kind_check.cc b/src/relay/pass/kind_check.cc index 5b7e1c004735..081f132d0612 100644 --- a/src/relay/pass/kind_check.cc +++ b/src/relay/pass/kind_check.cc @@ -139,7 +139,7 @@ struct KindChecker : TypeFunctor { // finally we need to check the module to check the number of type params auto var = GetRef(gtv); - auto data = mod->LookupDef(var); + auto data = mod->LookupTypeDef(var); if (data->type_vars.size() != op->args.size()) { ReportFatalError(RELAY_ERROR("Expected " << data->type_vars.size() << "arguments for " << tc << "; got " << op->args.size())); diff --git a/src/relay/pass/match_exhaustion.cc b/src/relay/pass/match_exhaustion.cc index 2392b18e768f..6e18c630d2a7 100644 --- a/src/relay/pass/match_exhaustion.cc +++ b/src/relay/pass/match_exhaustion.cc @@ -183,7 +183,7 @@ Array ExpandWildcardsConstructor(const PatternConstructor& clause_ctor, // for a wildcard node, create constructor nodes with wildcards for all args. if (cand.as()) { - TypeData td = mod->LookupDef(gtv); + TypeData td = mod->LookupTypeDef(gtv); // for each constructor add a candidate. Array ret; for (auto constructor : td->constructors) { diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 78101bcf045a..23ed83cafc0f 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -574,7 +574,7 @@ class TypeInferencer : private ExprFunctor, CHECK(mod_.defined()) << "Cannot do type inference without a environment:" << c->name_hint; - TypeData td = mod_->LookupDef(c->belong_to); + TypeData td = mod_->LookupTypeDef(c->belong_to); std::vector types; for (const auto & t : td->type_vars) { types.push_back(t); diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc index 3ad5dd11b7d8..577f492120c0 100644 --- a/src/relay/pass/util.cc +++ b/src/relay/pass/util.cc @@ -140,7 +140,7 @@ class TypeVarEVisitor : private ExprVisitor { void VisitExpr_(const ConstructorNode* cn) final { // for constructors, type vars will be bound in the module - auto data = mod_->LookupDef(cn->belong_to); + auto data = mod_->LookupTypeDef(cn->belong_to); for (const auto& tv : data->type_vars) { type_vars_.Insert(tv); bound_type_vars_.Insert(tv);