diff --git a/example/src/operator.cc b/example/src/operator.cc
index 1533b6a4b250..34e4529ecb0b 100644
--- a/example/src/operator.cc
+++ b/example/src/operator.cc
@@ -84,6 +84,7 @@ NNVM_REGISTER_OP(reshape)
 NNVM_REGISTER_OP(cast)
 .describe("cast source type to target")
 .set_num_inputs(1)
+.include("ElementwiseOpAttr")
 .set_attr_parser(
     [](NodeAttrs* attrs) {
       // parse attr parser to get target attribute
@@ -92,7 +93,6 @@ NNVM_REGISTER_OP(cast)
       CHECK(is >> dtype);
       attrs->parsed = std::move(dtype);
     })
-.set_attr<FInferShape>("FInferShape", SameShape)
 .set_attr<FInferType>(
     "FInferType", [](const NodeAttrs& attrs,
                      std::vector<int> *itype,
@@ -101,23 +101,10 @@ NNVM_REGISTER_OP(cast)
       return true;
     });
 
-NNVM_REGISTER_OP(exp)
-.describe("take exponential")
-.set_num_inputs(1)
-.set_attr<FInferShape>("FInferShape", SameShape)
-.set_attr<FGradient>(
-    "FGradient", [](const NodePtr& n,
-                    const std::vector<NodeEntry>& ograds) {
-      return std::vector<NodeEntry>{
-        MakeNode("mul", n->attrs.name + "_grad",
-                 {ograds[0], NodeEntry{n, 0, 0}})
-      };
-    });
-
 NNVM_REGISTER_OP(identity)
 .describe("identity function")
 .set_num_inputs(1)
-.set_attr<FInferShape>("FInferShape", SameShape)
+.include("ElementwiseOpAttr")
 .set_attr<FGradient>(
     "FGradient", [](const NodePtr& n,
                     const std::vector<NodeEntry>& ograds) {
@@ -128,7 +115,7 @@ NNVM_REGISTER_OP(add)
 .describe("add two data together")
 .set_num_inputs(2)
 .add_alias("__add_symbol__")
-.set_attr<FInferShape>("FInferShape", SameShape)
+.include("ElementwiseOpAttr")
 .set_attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0)
 .set_attr<FGradient>(
     "FGradient", [](const NodePtr& n,
@@ -139,6 +126,7 @@ NNVM_REGISTER_OP(add)
 NNVM_REGISTER_OP(mul)
 .describe("multiply two data together")
 .set_num_inputs(2)
+.include("ElementwiseOpAttr")
 .set_attr<FInferShape>("FInferShape", SameShape)
 .set_attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0)
 .set_attr<FGradient>(
@@ -187,4 +175,22 @@ NNVM_REGISTER_OP(assign)
       return std::vector<uint32_t>{0};
     });
 
+NNVM_REGISTER_OP_GROUP(ElementwiseOpAttr)
+.set_attr<FInferShape>("FInferShape", SameShape);
+
+
+NNVM_REGISTER_OP(exp)
+.describe("take exponential")
+.set_num_inputs(1)
+.include("ElementwiseOpAttr")
+.set_attr<FGradient>(
+    "FGradient", [](const NodePtr& n,
+                    const std::vector<NodeEntry>& ograds) {
+      return std::vector<NodeEntry>{
+        MakeNode("mul", n->attrs.name + "_grad",
+                 {ograds[0], NodeEntry{n, 0, 0}})
+      };
+    });
+
+
 }  // namespace myproject
diff --git a/include/nnvm/op.h b/include/nnvm/op.h
index b79f170ac878..23648980e1f9 100644
--- a/include/nnvm/op.h
+++ b/include/nnvm/op.h
@@ -22,6 +22,7 @@
 class Node;
 struct NodeAttrs;
 template<typename ValueType>
 class OpMap;
+class OpGroup;
 class OpRegistryEntry;
 using dmlc::ParamFieldInfo;
@@ -44,7 +45,13 @@ static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();
  *  NNVM_REGISTER_OP(add)
  *  .describe("add two inputs together")
  *  .set_num_inputs(2)
- *  .set_attr<OpKernel>("gpu_kernel", AddKernel);
+ *  .set_attr<OpKernel>("OpKernel", AddKernel)
+ *  .include("ElementwiseOpAttr");
+ *
+ *  // can register attribute by group
+ *  // all the ops that include the group get the attribute.
+ *  NNVM_REGISTER_OP_GROUP(ElementwiseOpAttr)
+ *  .set_attr<FInferShape>("FInferShape", ElementwiseInferShape);
  *
  *  NNVM_REGISTER_OP(sub)
  *  .describe("substract one tensor from another")
@@ -53,7 +60,8 @@ static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();
  *  // Can call regster multiple times in different files
  *  // to register different part of information
  *  NNVM_REGISTER_OP(sub)
- *  .set_attr<OpKernel>("gpu_kernel", SubKernel);
+ *  .set_attr<OpKernel>("OpKernel", SubKernel);
+ *  .include("ElementwiseOpAttr");
  *
  *  // get operators from registry.
  *  void my_function() {
@@ -65,7 +73,7 @@ static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();
  *
  *    // get additional registered information,
  *    // Assume user registered a OpKernel type attribute as gpu_kernel on each operator.
- *    const OpMap<OpKernel>& kernel = Op::GetAttr<OpKernel>("gpu_kernel");
+ *    const OpMap<OpKernel>& kernel = Op::GetAttr<OpKernel>("OpKernel");
  *    // we can get the kernel functions by using operator as key.
  *    auto add_kernel = kernel[add];
  *    auto sub_kernel = kernel[sub];
@@ -199,6 +207,23 @@ class Op {
   * \return reference to self.
   */
  inline Op& set_attr_parser(std::function<void (NodeAttrs* attrs)> fn);  // NOLINT(*)
+  /*!
+   * \brief Register additional attributes to operator.
+   * \param attr_name The name of the attribute.
+   * \param value The value to be set.
+   * \param plevel The priority level of this set;
+   *  a higher priority level attribute
+   *  will replace a lower priority level attribute.
+   *  Must be bigger than 0.
+   *
+   *  Cannot set with the same plevel twice in the code.
+   *
+   * \tparam ValueType The type of the value to be set.
+   */
+  template<typename ValueType>
+  inline Op& set_attr(const std::string& attr_name,  // NOLINT(*)
+                      const ValueType& value,
+                      int plevel = 10);
  /*!
   * \brief Add another alias to this operator.
   *  The same Op can be queried with Op::Get(alias)
@@ -207,14 +232,13 @@ class Op {
   */
  Op& add_alias(const std::string& alias);  // NOLINT(*)
  /*!
-   * \brief Register additional attributes to operator.
-   * \param attr_name The name of the attribute.
-   * \param value The value to be set.
-   * \tparam ValueType The type of the value to be set.
+   * \brief Include all the attributes from a registered op group.
+   * \param group_name The name of the group.
+   * \return reference to self.
+   *
+   * \sa NNVM_REGISTER_OP_GROUP
   */
-  template<typename ValueType>
-  inline Op& set_attr(const std::string& attr_name,  // NOLINT(*)
-                      const ValueType& value);
+  Op& include(const std::string& group_name);
  /*!
   * \brief Get an Op for a given operator name.
   *  Will raise an error if the op has not been registered.
@@ -235,6 +259,7 @@ class Op {
 private:
  template<typename ValueType>
  friend class OpMap;
+  friend class OpGroup;
  friend class dmlc::Registry<Op>;
  // Program internal unique index of operator.
  // Used to help index the program.
@@ -246,6 +271,13 @@ class Op {
  // update the attribute OpMap
  static void UpdateAttrMap(const std::string& key,
                            std::function<void(any*)> updater);
+  // add a trigger based on tag matching on a certain tag attribute.
+  // This will apply the trigger to all ops that
+  // include the corresponding group.
+  // The trigger will also be applied to all future registrations
+  // that call include.
+  static void AddGroupTrigger(const std::string& group_name,
+                              std::function<void(Op*)> trigger);
 };
 
 /*!
@@ -285,14 +317,44 @@ class OpMap {
  OpMap() = default;
 };
 
+/*!
+ * \brief auxiliary data structure used to
+ *  set attributes to a group of operators
+ */
+class OpGroup {
+ public:
+  /*! \brief the tag key to be matched */
+  std::string group_name;
+  /*!
+   * \brief Register additional attributes to operator group.
+   * \param attr_name The name of the attribute.
+   * \param value The value to be set.
+   * \param plevel The priority level of this set;
+   *  a higher priority level attribute
+   *  will replace a lower priority level attribute.
+   *  Must be bigger than 0.
+   *
+   *  Cannot set with the same plevel twice in the code.
+   *
+   * \tparam ValueType The type of the value to be set.
+   */
+  template<typename ValueType>
+  inline OpGroup& set_attr(const std::string& attr_name,  // NOLINT(*)
+                           const ValueType& value,
+                           int plevel = 1);
+};
+
 // internal macros to make
-#define NNVM_REGISTER_VAR_DEF(OpName)                                  \
+#define NNVM_REGISTER_VAR_DEF(OpName)                                     \
  static DMLC_ATTRIBUTE_UNUSED ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName
 
+#define NNVM_REGISTER_GVAR_DEF(TagName)                                   \
+  static DMLC_ATTRIBUTE_UNUSED ::nnvm::OpGroup __make_ ## NnvmOpGroup ## _ ## TagName
+
 /*!
  * \def NNVM_REGISTER_OP
- * \brief Register
- * This macro must be used under namespace dmlc, and only used once in cc file.
+ * \brief Register a new operator, or set attributes of the corresponding op.
+ *
  * \param OpName The name of registry
  *
  * \code
@@ -308,6 +370,31 @@ class OpMap {
  DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) =            \
  ::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(#OpName)
 
+/*!
+ * \def NNVM_REGISTER_OP_GROUP
+ * \brief Register attributes to a group of operators.
+ *  These attributes will be registered to the Ops that include the group.
+ *
+ * \param GroupName The name of the group.
+ *
+ * \code
+ *
+ *  NNVM_REGISTER_OP(add)
+ *  .include("ElementwiseOpAttr");
+ *
+ *  // register the same attributes to all the ops that include the group
+ *  NNVM_REGISTER_OP_GROUP(ElementwiseOpAttr)
+ *  .set_attr<FInferShape>("FInferShape", ElementwiseInferShape);
+ *
+ *  NNVM_REGISTER_OP(mul)
+ *  .include("ElementwiseOpAttr");
+ *
+ * \endcode
+ */
+#define NNVM_REGISTER_OP_GROUP(GroupName)                                 \
+  DMLC_STR_CONCAT(NNVM_REGISTER_GVAR_DEF(GroupName), __COUNTER__) =       \
+  ::nnvm::OpGroup {#GroupName}
+
 // implementations of template functions after this.
 // member function of Op
 template<typename ValueType>
@@ -330,9 +417,14 @@ inline const OpMap<ValueType>& Op::GetAttr(const std::string& key) {
 
 template<typename ValueType>
 inline Op& Op::set_attr(  // NOLINT(*)
-    const std::string& attr_name, const ValueType& value) {
+    const std::string& attr_name,
+    const ValueType& value,
+    int plevel) {
+  CHECK_GT(plevel, 0)
+      << "plevel in set_attr must be greater than 0";
   // update the attribute map of the key by creating new empty if needed.
-  UpdateAttrMap(attr_name, [this, attr_name, value](any* pmap) {
+  UpdateAttrMap(attr_name,
+                [this, attr_name, value, plevel](any* pmap) {
       // the callback is in lockscope so is threadsafe.
       if (pmap->empty()) {
        OpMap<ValueType> pm;
@@ -353,15 +445,18 @@ inline Op& Op::set_attr(  // NOLINT(*)
            std::make_pair(ValueType(), 0));
      }
      std::pair<ValueType, int>& p = vec[index_];
-      CHECK(p.second == 0)
+      CHECK(p.second != plevel)
          << "Attribute " << attr_name
          << " of operator " << this->name
-          << " is already registered.";
-      vec[index_] = std::make_pair(value, 1);
+          << " is already registered with same plevel=" << plevel;
+      if (p.second < plevel) {
+        vec[index_] = std::make_pair(value, plevel);
+      }
    });
  return *this;
 }
 
+
 inline Op& Op::describe(const std::string& descr) {  // NOLINT(*)
  this->description = descr;
  return *this;
 }
@@ -409,7 +504,7 @@ template<typename ValueType>
 inline int OpMap<ValueType>::count(const Op* op) const {
  if (op == nullptr) return 0;
  const uint32_t idx = op->index_;
-  return idx < data_.size() ? data_[idx].second : 0;
+  return idx < data_.size() ? (data_[idx].second != 0) : 0;
 }
 
 template<typename ValueType>
@@ -433,6 +528,17 @@ inline const ValueType& OpMap<ValueType>::get(const Op* op, const ValueType& def
  }
 }
 
+template<typename ValueType>
+inline OpGroup& OpGroup::set_attr(const std::string& attr_name,
+                                  const ValueType& value,
+                                  int plevel) {
+  auto trigger = [attr_name, value, plevel](Op* op) {
+    op->set_attr<ValueType>(attr_name, value, plevel);
+  };
+  Op::AddGroupTrigger(group_name, trigger);
+  return *this;
+}
+
 }  // namespace nnvm
 
 #endif  // NNVM_OP_H_
diff --git a/src/core/op.cc b/src/core/op.cc
index a8e54ab92cab..e554d36b4e8c 100644
--- a/src/core/op.cc
+++ b/src/core/op.cc
@@ -9,6 +9,7 @@
 #include <memory>
 #include <atomic>
 #include <mutex>
+#include <unordered_set>
 
 namespace dmlc {
 // enable registry
@@ -20,11 +21,16 @@ namespace nnvm {
 // single manager of operator information.
 struct OpManager {
  // mutex to avoid registration from multiple threads.
-  std::mutex mutex;
+  // recursive is needed for trigger (which calls UpdateAttrMap)
+  std::recursive_mutex mutex;
  // global operator counter
  std::atomic<int> op_counter{0};
  // storage of additional attribute table.
  std::unordered_map<std::string, std::unique_ptr<any> > attr;
+  // storage of existing triggers
+  std::unordered_map<std::string, std::vector<std::function<void(Op*)> > > tmap;
+  // group of each operator.
+  std::vector<std::unordered_set<std::string> > op_group;
  // get singleton of the
  static OpManager* Global() {
    static OpManager inst;
@@ -66,10 +72,42 @@ const any* Op::GetAttrMap(const std::string& key) {
 void Op::UpdateAttrMap(const std::string& key,
                       std::function<void(any*)> updater) {
  OpManager* mgr = OpManager::Global();
-  std::lock_guard<std::mutex>(mgr->mutex);
+  std::lock_guard<std::recursive_mutex>(mgr->mutex);
  std::unique_ptr<any>& value = mgr->attr[key];
  if (value.get() == nullptr) value.reset(new any());
  if (updater != nullptr) updater(value.get());
 }
 
+void Op::AddGroupTrigger(const std::string& group_name,
+                         std::function<void(Op*)> trigger) {
+  OpManager* mgr = OpManager::Global();
+  std::lock_guard<std::recursive_mutex>(mgr->mutex);
+  auto& tvec = mgr->tmap[group_name];
+  tvec.push_back(trigger);
+  auto& op_group = mgr->op_group;
+  for (const Op* op : dmlc::Registry<Op>::List()) {
+    if (op->index_ < op_group.size() &&
+        op_group[op->index_].count(group_name) != 0) {
+      trigger((Op*)op);  // NOLINT(*)
+    }
+  }
+}
+
+Op& Op::include(const std::string& group_name) {
+  OpManager* mgr = OpManager::Global();
+  std::lock_guard<std::recursive_mutex>(mgr->mutex);
+  auto it = mgr->tmap.find(group_name);
+  if (it != mgr->tmap.end()) {
+    for (auto& trigger : it->second) {
+      trigger(this);
+    }
+  }
+  auto& op_group = mgr->op_group;
+  if (index_ >= op_group.size()) {
+    op_group.resize(index_ + 1);
+  }
+  op_group[index_].insert(group_name);
+  return *this;
+}
+
 }  // namespace nnvm
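
A minimal usage sketch of the group mechanism this patch introduces, as it could look in a downstream operator file. The group name UnaryOpAttr, the relu operator, and the reuse of FInferShape/SameShape from example/src/operator.cc above are illustrative assumptions, not part of the patch:

// Hypothetical downstream registration built on the new OpGroup support.
// Assumes the same headers and helpers as example/src/operator.cc.
NNVM_REGISTER_OP_GROUP(UnaryOpAttr)
.set_attr<FInferShape>("FInferShape", SameShape);  // group set_attr defaults to plevel = 1

NNVM_REGISTER_OP(relu)
.describe("rectified linear unit")
.set_num_inputs(1)
.include("UnaryOpAttr");  // picks up FInferShape now, and from any later group set_attr

// A direct set_attr on the op (default plevel = 10) overrides the group value (plevel 1),
// which is why mul above can keep its own FInferShape while also including the group.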