diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index da7bc12619bd..fa1861051e2f 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -214,6 +214,7 @@ class DictAttrsNode : public BaseAttrsNode { void VisitNonDefaultAttrs(AttrVisitor* v) final; void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final; Array ListFieldInfo() const final; + // type info static constexpr const char* _type_key = "DictAttrs"; TVM_DECLARE_FINAL_OBJECT_INFO(DictAttrsNode, BaseAttrsNode); @@ -232,6 +233,72 @@ class DictAttrs : public Attrs { */ TVM_DLL explicit DictAttrs(Map dict); + // Utils for accessing attributes + // This needs to be on DictAttrs, not DictAttrsNode because we return the default + // value if DictAttrsNode is not defined. + /*! + * \brief Get a function attribute. + * + * \param attr_key The attribute key. + * \param default_value The default value if the key does not exist, defaults to nullptr. + * + * \return The result + * + * \tparam TOBjectRef the expected object type. + * \throw Error if the key exists but the value does not match TObjectRef + * + * \code + * + * void GetAttrExample(const BaseFunc& f) { + * auto value = f->attrs.GetAttr("AttrKey", 0); + * } + * + * \endcode + */ + template + Optional GetAttr( + const std::string& attr_key, + Optional default_value = Optional(nullptr)) const { + static_assert(std::is_base_of::value, + "Can only call GetAttr with ObjectRef types."); + if (!defined()) return default_value; + const DictAttrsNode* node = this->as(); + + auto it = node->dict.find(attr_key); + if (it != node->dict.end()) { + return Downcast>((*it).second); + } else { + return default_value; + } + } + // variant that uses TObjectRef to enable implicit conversion to default value. + template + Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const { + return GetAttr(attr_key, Optional(default_value)); + } + /*! + * \brief Check whether the function has an non-zero integer attr. + * + * This function can be used to check whether an optional + * attribute mark(e.g. inline) exists. + * + * \param attr_key The key to the attribute. + * \return The check result. + * + * \code + * + * void HasNonzeroAttrExample(const BaseFunc& f) { + * if (f->HasNonzeroAttr(attr::kInline)) { + * // inline the function. + * } + * } + * + * \endcode + */ + bool HasNonzeroAttr(const std::string& attr_key) const { + return GetAttr(attr_key, 0) != 0; + } + TVM_DEFINE_OBJECT_REF_METHODS(DictAttrs, Attrs, DictAttrsNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode); }; @@ -249,6 +316,47 @@ inline TAttrs AttrsWithDefaultValues() { return TAttrs(n); } +/*! + * \brief Copy the function or module, but overrides + * the attribute value key with the value. + * + * \param input The thing to annotate (BaseFunc or IRModule) + * \param attr_key The attribute key. + * \param attr_value The value attribute value. + * + * \tparam TFunc The corresponding function or module type. + * + * \returns The new function or module with updated attributes. + * + * \note This function performs copy on write optimization for func and module. + * If we move a uniquely referenced func or module into WithAttr, + * then no additional copy will be performed. + * + * This is also why we make it as a function instead of a member function + * and why we pass by value in the first argument. + * + * \code + * + * // Recommended way to trigger copy on write + * func = WithAttr(std::move(func), "key1", value1); + * func = WithAttr(std::move(func), "key2", value2); + * + * \endcode + */ +template +inline TFunc WithAttr(TFunc input, const std::string& attr_key, ObjectRef attr_value) { + using TNode = typename TFunc::ContainerType; + static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); + TNode* node = input.CopyOnWrite(); + if (node->attrs.defined()) { + node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value); + } else { + Map dict = {{attr_key, attr_value}}; + node->attrs = DictAttrs(dict); + } + return input; +} + // Namespace containing detail implementations namespace detail { using runtime::TVMArgValue; diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index 09c074cb71bd..13b984d9cb35 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -102,21 +102,14 @@ class BaseFuncNode : public RelayExprNode { Optional GetAttr( const std::string& attr_key, Optional default_value = Optional(nullptr)) const { - static_assert(std::is_base_of::value, - "Can only call GetAttr with ObjectRef types."); - if (!attrs.defined()) return default_value; - auto it = attrs->dict.find(attr_key); - if (it != attrs->dict.end()) { - return Downcast>((*it).second); - } else { - return default_value; - } + return attrs.GetAttr(attr_key, default_value); } // variant that uses TObjectRef to enable implicit conversion to default value. template Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const { return GetAttr(attr_key, Optional(default_value)); } + /*! * \brief Check whether the function has an non-zero integer attr. * @@ -136,9 +129,7 @@ class BaseFuncNode : public RelayExprNode { * * \endcode */ - bool HasNonzeroAttr(const std::string& attr_key) const { - return GetAttr(attr_key, 0) != 0; - } + bool HasNonzeroAttr(const std::string& attr_key) const { return attrs.HasNonzeroAttr(attr_key); } static constexpr const char* _type_key = "BaseFunc"; static constexpr const uint32_t _type_child_slots = 2; @@ -154,48 +145,6 @@ class BaseFunc : public RelayExpr { TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode); }; -/*! - * \brief Create a new function that copies func, but overrides - * the attribute value key with the value. - * - * \param func The input function. - * \param attr_key The attribute key. - * \param attr_value The value attribute value. - * - * \tparam TFunc The corresponding function type. - * - * \returns The new function with updated attributes. - * - * \note This function performs copy on write optimization for func. - * If we move a uniquely referenced func into WithAttr, - * then no additional copy will be performed. - * - * This is also why we make it as a function instead of a member function - * and why we pass by value in the first argument. - * - * \code - * - * // Recommended way to trigger copy on write - * func = WithAttr(std::move(func), "key1", value1); - * func = WithAttr(std::move(func), "key2", value2); - * - * \endcode - */ -template ::value>::type> -inline TFunc WithAttr(TFunc func, const std::string& attr_key, ObjectRef attr_value) { - using TNode = typename TFunc::ContainerType; - static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); - TNode* node = func.CopyOnWrite(); - if (node->attrs.defined()) { - node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value); - } else { - Map dict = {{attr_key, attr_value}}; - node->attrs = DictAttrs(dict); - } - return func; -} - /*! * \brief Generic attribute names that can be attached to any function. * diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 638f132e3179..9ca27ec3b661 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -58,6 +58,60 @@ class IRModuleNode : public Object { Map type_definitions; /*! \brief The source map for the module. */ parser::SourceMap source_map; + /* \brief Additional attributes storing meta-data about the module. */ + DictAttrs attrs; + + /*! + * \brief Get a module attribute. + * + * \param attr_key The attribute key. + * \param default_value The default value if the key does not exist, defaults to nullptr. + * + * \return The result + * + * \tparam TOBjectRef the expected object type. + * \throw Error if the key exists but the value does not match TObjectRef + * + * \code + * + * void GetAttrExample(const IRModule& mod) { + * auto value = f->GetAttr("AttrKey", 0); + * } + * + * \endcode + */ + template + Optional GetAttr( + const std::string& attr_key, + Optional default_value = Optional(nullptr)) const { + return attrs.GetAttr(attr_key, default_value); + } + // variant that uses TObjectRef to enable implicit conversion to default value. + template + Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const { + return GetAttr(attr_key, Optional(default_value)); + } + + /*! + * \brief Check whether the module has an non-zero integer attr. + * + * This function can be used to check whether an optional + * attribute mark(e.g. inline) exists. + * + * \param attr_key The key to the attribute. + * \return The check result. + * + * \code + * + * void HasNonzeroAttrExample(const IRModule& mod) { + * if (mod->HasNonzeroAttr(attr::kInline)) { + * // inline the function. + * } + * } + * + * \endcode + */ + bool HasNonzeroAttr(const std::string& attr_key) const { return attrs.HasNonzeroAttr(attr_key); } IRModuleNode() : source_map() {}