diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 1ee7c323336d1..4613bec706334 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -162,6 +162,14 @@ class IRModuleNode : public Object { */ TVM_DLL Array GetGlobalTypeVars() const; + /*! + * \brief Find constructor of ADT using name + * \param adt name of the ADT the constructor belongs to + * \param cons name of the constructor + * \returns Constructor of ADT, error if not found + */ + TVM_DLL Constructor GetConstructor(const std::string& adt, const std::string& cons) const; + /*! * \brief Look up a global function by its variable. * \param var The global var to lookup. diff --git a/src/ir/module.cc b/src/ir/module.cc index 45f39d5ade889..a78a7525425c1 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -96,6 +96,18 @@ GlobalTypeVar IRModuleNode::GetGlobalTypeVar(const std::string& name) const { return (*it).second; } +Constructor IRModuleNode::GetConstructor(const std::string& adt, const std::string& cons) const { + TypeData typeDef = this->LookupTypeDef(adt); + for (Constructor c : typeDef->constructors) { + if (cons.compare(c->name_hint) == 0) { + return c; + } + } + + LOG(FATAL) << adt << " does not contain constructor " << cons; + throw std::runtime_error("Constructor Not Found."); +} + tvm::Array IRModuleNode::GetGlobalTypeVars() const { std::vector global_type_vars; for (const auto& pair : global_type_var_map_) { diff --git a/src/relay/transforms/gradient_cell.cc b/src/relay/transforms/gradient_cell.cc index b5504b3ae7eec..eb60176c12253 100644 --- a/src/relay/transforms/gradient_cell.cc +++ b/src/relay/transforms/gradient_cell.cc @@ -66,23 +66,6 @@ namespace tvm { namespace relay { -/*! -* \brief Get constructor of GradCell TypeDef with name_hint -* -* module must have TypeDefinition of GradCell (defined in gradient.rly) -*/ -Constructor getGradCellConstructor(IRModule module, std::string name_hint) { - TypeData gradCell = module->LookupTypeDef("GradCell"); - for (Constructor c : gradCell->constructors) { - if (name_hint.compare(c->name_hint) == 0) { - return c; - } - } - - LOG(FATAL) << "Constructor " << name_hint << "not found in GradCell typedata."; - throw std::runtime_error("Constructor not found in GradCell typedata"); -} - /*! * \brief Visitor to wrap inputs */ @@ -92,7 +75,7 @@ class InputVisitor: public ExprFunctor { Expr wrapExpr(const Expr expr, const Type& type) { if (type.as()) { - return CallNode::make(getGradCellConstructor(module_, "Raw"), + return CallNode::make(module_->GetConstructor("GradCell", "Raw"), {expr}, Attrs(), {type}); } else if (auto* type_anno = type.as()) { tvm::Array fields; @@ -191,14 +174,15 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { } Expr VisitExpr_(const ConstantNode* op) final { - return CallNode::make(getGradCellConstructor(module_, "Raw"), + return CallNode::make(module_->GetConstructor("GradCell", "Raw"), {GetRef(op)}, Attrs(), {op->checked_type()}); } Expr VisitExpr_(const CallNode* call_node) final { // optimize operators if (auto* op = (call_node->op).as()) { - if (op->name.compare("add") == 0 && call_node->args.size() == 2 && + Expr op_expr = GetRef(op); + if (op_expr == Op::Get("add") && call_node->args.size() == 2 && AlphaEqual(call_node->args[0]->checked_type(), call_node->args[1]->checked_type())) { // case: "add" between two tensors of the same size const auto addFunc = module_->GetGlobalVar("AddGradCell"); @@ -217,7 +201,7 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { args.push_back(VisitExpr(expr)); } return CallNode::make(addFunc, args, Attrs(), {paramType}); - } else if (op->name.compare("multiply") == 0 && call_node->args.size() == 2 && + } else if (op_expr == Op::Get("multiply") && call_node->args.size() == 2 && AlphaEqual(call_node->args[0]->checked_type(), call_node->args[1]->checked_type())) { // case: "multiply" between two tensors of the same size const auto multFunc = module_->GetGlobalVar("MultiplyGradCell"); @@ -237,17 +221,17 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { args.push_back(VisitExpr(expr)); } return CallNode::make(multFunc, args, Attrs(), {paramType}); - } else if (op->name.compare("ones") == 0) { + } else if (op_expr == Op::Get("ones")) { // ones operator, use One constructor of GradCell Expr func = Function({}, {ExprMutator::VisitExpr_(call_node)}, {call_node->checked_type()}, {}); - return CallNode::make(getGradCellConstructor(module_, "One"), + return CallNode::make(module_->GetConstructor("GradCell", "One"), {func}, Attrs(), {call_node->checked_type()}); - } else if (op->name.compare("zeros") == 0) { + } else if (op_expr == Op::Get("zeros")) { // zeros operator, use Zero constructor of GradCell Expr func = Function({}, {ExprMutator::VisitExpr_(call_node)}, {call_node->checked_type()}, {}); - return CallNode::make(getGradCellConstructor(module_, "Zero"), + return CallNode::make(module_->GetConstructor("GradCell", "Zero"), {func}, Attrs(), {call_node->checked_type()}); } @@ -264,18 +248,18 @@ class GradientCellTransform: public ExprMutator, public TypeMutator { const Expr tensorRes = CallNode::make(call_node->op, args); - if (op->name.compare("ones_like") == 0) { + if (op_expr == Op::Get("ones_like")) { Expr onesFunction = Function({}, tensorRes, {call_node->checked_type()}, Array()); - return CallNode::make(getGradCellConstructor(module_, "One"), + return CallNode::make(module_->GetConstructor("GradCell", "One"), {onesFunction}, Attrs(), {call_node->checked_type()}); - } else if (op->name.compare("zeros_like") == 0) { + } else if (op_expr == Op::Get("zeros_like")) { Expr zerosFunction = Function({}, tensorRes, {call_node->checked_type()}, Array()); - return CallNode::make(getGradCellConstructor(module_, "Zero"), + return CallNode::make(module_->GetConstructor("GradCell", "Zero"), {zerosFunction}, Attrs(), {call_node->checked_type()}); } - return CallNode::make(getGradCellConstructor(module_, "Raw"), {tensorRes}, + return CallNode::make(module_->GetConstructor("GradCell", "Raw"), {tensorRes}, Attrs(), {call_node->checked_type()}); } // call-> op is not a relay op