diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h new file mode 100644 index 000000000000..19c5a51162a3 --- /dev/null +++ b/include/tvm/ir/op.h @@ -0,0 +1,627 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ir/op.h + * \brief Primitive operators(builtin intrinsics) + * and registry for them. + */ +#ifndef TVM_IR_OP_H_ +#define TVM_IR_OP_H_ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace tvm { + +// forward declare name. +template +class OpMap; +class GenericOpMap; +class OpRegistry; + +// TODO(tvm-team): migrate low-level intrinsics to use Op +/*! + * \brief Primitive Op(builtin intrinsics) + * + * This data structure stores the meta-data + * about primitive operators that can be invoked via Call. + * + * Low-level IR intrinsics(such as libc.expf) are also + * implemented via Op. + * + * \sa Op + */ +class OpNode : public RelayExprNode { + public: + /*! \brief name of the operator */ + std::string name; + /*! \brief the type of the operator */ + mutable FuncType op_type; + /*! + * \brief detailed description of the operator + * This can be used to generate docstring automatically for the operator. + */ + std::string description; + /* \brief Information of input arguments to the operator */ + Array arguments; + /*! + * \brief The type key of the attribute field + * This can be empty, in which case it defaults to anything. + */ + std::string attrs_type_key; + /*! + * \brief attribute type index, + * this field varies in each run and is not exposed to frontend. + */ + uint32_t attrs_type_index{0}; + /*! + * \brief number of input arguments to the operator, + * -1 means it is variable length + */ + int32_t num_inputs = -1; + /*! + * \brief support level of the operator, + * The lower the more priority it contains. + * This is in analogies to BLAS levels. + */ + int32_t support_level = 10; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("op_type", &op_type); + v->Visit("description", &description); + v->Visit("arguments", &arguments); + v->Visit("attrs_type_key", &attrs_type_key); + v->Visit("num_inputs", &num_inputs); + v->Visit("support_level", &support_level); + } + + /*! + * \brief Check that if current op is a "primtive operator". + * That is the arguments are all type variables, and there is a single + * type relation applied to the input and output types. + */ + bool IsPrimitiveOp() const { + if (is_primitive_ != -1) return is_primitive_ != 0; + is_primitive_ = this->IsPrimitiveOp_() ? 1 : 0; + return is_primitive_ != 0; + } + + static constexpr const char* _type_key = "relay.Op"; + TVM_DECLARE_FINAL_OBJECT_INFO(OpNode, RelayExprNode); + + private: + // friend class + friend class GenericOpMap; + friend class OpRegistry; + friend bool IsPrimitiveOp(const RelayExpr&); + // Program internal unique index of operator. + // Used to help index the program. + uint32_t index_{0}; + // whether this is a primitive op. -1 means unknown. + mutable int is_primitive_{-1}; + // Internal function to compute if it is primitive op + bool IsPrimitiveOp_() const { + const auto& fn_ty = this->op_type; + if (fn_ty->type_constraints.size() != 1) return false; + const TypeRelationNode* rel = fn_ty->type_constraints[0].as(); + if (rel == nullptr) return false; + // validate if the type parameter matches up + for (size_t i = 0; i < fn_ty->type_params.size(); ++i) { + if (!fn_ty->type_params[i].same_as(rel->args[i])) return false; + } + return true; + } +}; + +/*! + * \brief Managed reference class to OpNode. + * \sa OpNode + */ +class Op : public RelayExpr { + public: + /*! \brief default constructor */ + Op() {} + /*! \brief constructor from node pointer */ + explicit Op(ObjectPtr n) : RelayExpr(n) {} + /*! + * \brief access the internal node container + * \return the pointer to the internal node container + */ + inline const OpNode* operator->() const; + /*! + * \brief Get additional registered attribute about operators. + * If nothing has been registered, an empty OpMap will be returned. + * \param attr_name The name of the attribute. + * \return An OpMap of specified attr_name. + * \tparam ValueType The type of the attribute. + */ + template + inline static OpMap GetAttr(const std::string& attr_name); + /*! + * \brief Checks if an attr is present in the registry. + * \param attr_name The name of the attribute. + * \return bool True if the attr is present. + */ + inline static bool HasAttr(const std::string& attr_name); + /*! + * \brief Get an Op for a given operator name. + * Will raise an error if the op has not been registered. + * \param op_name Name of the operator. + * \return Pointer to a Op, valid throughout program lifetime. + */ + TVM_DLL static const Op& Get(const std::string& op_name); + + /*! \brief specify container node */ + using ContainerType = OpNode; + + private: + /*! + * \brief Get generic attrmap given attr name + * \param key The attribute key + * \return reference to GenericOpMap + */ + TVM_DLL static const GenericOpMap& GetGenericAttr(const std::string& key); + /*! + * \brief Checks if the key is present in the registry + * \param key The attribute key + * \return bool True if the key is present + */ + TVM_DLL static const bool HasGenericAttr(const std::string& key); +}; + +/*! + * \brief Helper structure to register operators + * \sa TVM_REGISTER_OP + */ +class OpRegistry { + public: + /*! \return the operator */ + const Op& op() const { return op_; } + /*! + * \brief setter function during registration + * Set the description of operator + * \param descr the description string. + * \return reference to self. + */ + inline OpRegistry& describe(const std::string& descr); // NOLINT(*) + /*! + * \brief Add argument information to the function. + * \param name Name of the argument. + * \param type Type of the argument. + * \param description Description of the argument. + * \return reference to self. + */ + inline OpRegistry& add_argument(const std::string& name, + const std::string& type, + const std::string& description); + /*! + * \brief Attach the type function corresponding to the return type. + * \param rel_name The type relation name to register. + * \param type_rel_func The backing relation function which can solve an arbitrary + * relation on variables. + * \return reference to self. + */ + inline OpRegistry& add_type_rel( + const std::string& rel_name, + runtime::TypedPackedFunc&, + int, + const Attrs&, + const TypeReporter&)> type_rel_func); + /*! + * \brief Set the the attrs type key and index to be AttrsType. + * \tparam AttrsType the attribute type to b set. + * \return reference to self. + */ + template + inline OpRegistry& set_attrs_type(); + /*! + * \brief Set the num_inputs + * \param n The number of inputs to be set. + * \return reference to self. + */ + inline OpRegistry& set_num_inputs(int32_t n); // NOLINT(*) + /*! + * \brief Set the support level of op. + * \param level The support level. + * \return reference to self. + */ + inline OpRegistry& set_support_level(int32_t level); // NOLINT(*) + /*! + * \brief Register additional attributes to operator. + * \param attr_name The name of the attribute. + * \param value The value to be set. + * \param plevel The priority level of this set, + * an higher priority level attribute + * will replace lower priority level attribute. + * Must be bigger than 0. + * + * Cannot set with same plevel twice in the code. + * + * \tparam ValueType The type of the value to be set. + */ + template + inline OpRegistry& set_attr(const std::string& attr_name, // NOLINT(*) + const ValueType& value, int plevel = 10); + + /*! + * \brief Resets an attr of the registry. + * \param attr_name The name of the attribute. + */ + inline void reset_attr(const std::string& attr_name); + + // set the name of the op to be the same as registry + inline OpRegistry& set_name() { // NOLINT(*) + if (get()->name.length() == 0) { + get()->name = name; + } + return *this; + } + /*! \return The global single registry */ + TVM_DLL static ::dmlc::Registry* Registry(); + + private: + friend class ::dmlc::Registry; + // the name + std::string name; + /*! \brief The operator */ + Op op_; + // private constructor + TVM_DLL OpRegistry(); + // return internal pointer to op. + inline OpNode* get(); + // update the attribute OpMap + TVM_DLL void UpdateAttr(const std::string& key, TVMRetValue value, + int plevel); +}; + +/*! + * \brief Generic map to store additional information of Op. + */ +class GenericOpMap { + public: + /*! + * \brief Check if the map has op as key. + * \param op The key to the map + * \return 1 if op is contained in map, 0 otherwise. + */ + inline int count(const Op& op) const; + /*! + * \brief get the corresponding value element at op + * \param op The key to the map + * \return the const reference to the content value. + */ + inline const TVMRetValue& operator[](const Op& op) const; + /*! + * \brief get the corresponding value element at op with default value. + * \param op The key to the map + * \param def_value The default value when the key does not exist. + * \return the const reference to the content value. + * \tparam ValueType The content value type. + */ + template + inline ValueType get(const Op& op, ValueType def_value) const; + /*! + * \brief get the corresponding value element at op with default value. + * \param expr The key to the map + * \param def_value The default value when the key does not exist + * or if expr is not an Op. + * \return the const reference to the content value. + * \tparam ValueType The content value type. + */ + template + inline ValueType get(const RelayExpr& expr, ValueType def_value) const; + + private: + friend class OpRegistry; + // the attribute field. + std::string attr_name_; + // internal data + std::vector > data_; + // The value + GenericOpMap() = default; +}; + +/*! + * \brief Map used to store meta-information about Op. + * \tparam ValueType The type of the value stored in map. + */ +template +class OpMap { + public: + /*! + * \brief Check if the map has op as key. + * \param op The key to the map + * \return 1 if op is contained in map, 0 otherwise. + */ + inline int count(const Op& op) const; + /*! + * \brief get the corresponding value element at op + * \param op The key to the map + * \return the const reference to the content value. + */ + inline ValueType operator[](const Op& op) const; + /*! + * \brief get the corresponding value element at op with default value. + * \param op The key to the map + * \param def_value The default value when the key does not exist. + * \return the const reference to the content value. + */ + inline ValueType get(const Op& op, ValueType def_value) const; + /*! + * \brief get the corresponding value element at op with default value. + * \param expr The key to the map + * \param def_value The default value when the key does not exist + * or if expr is not an Op. + * \return the const reference to the content value. + */ + inline ValueType get(const RelayExpr& expr, ValueType def_value) const; + + private: + friend class Op; + // constructor + explicit OpMap(const GenericOpMap& map) : map_(map) {} + /*! \brief The internal map field */ + const GenericOpMap& map_; +}; + +// internal macros to make +#define TVM_OP_REGISTER_VAR_DEF \ + static DMLC_ATTRIBUTE_UNUSED ::tvm::OpRegistry& __make_##Op + +/*! + * \def TVM_REGISTER_OP + * \brief Register a new operator, or set attribute of the corresponding op. + * + * \param OpName The name of registry + * + * \code + * + * TVM_REGISTER_OP("add") + * .describe("add two inputs together") + * .set_num_inputs(2) + * .set_attr("gpu_kernel", AddKernel); + * + * \endcode + */ +#define TVM_REGISTER_OP(OpName) \ + TVM_STR_CONCAT(TVM_OP_REGISTER_VAR_DEF, __COUNTER__) = \ + ::tvm::OpRegistry::Registry() \ + ->__REGISTER_OR_GET__(OpName) \ + .set_name() + +// implementations +inline const OpNode* Op::operator->() const { + return static_cast(get()); +} + +template +inline OpMap Op::GetAttr(const std::string& key) { + return OpMap(Op::GetGenericAttr(key)); +} + +inline bool Op::HasAttr(const std::string& key) { + return Op::HasGenericAttr(key); +} + +inline OpNode* OpRegistry::get() { + return const_cast(op_.operator->()); +} + +inline OpRegistry& OpRegistry::describe( + const std::string& descr) { // NOLINT(*) + get()->description = descr; + return *this; +} + +inline OpRegistry& OpRegistry::add_argument(const std::string& name, + const std::string& type, + const std::string& description) { + auto n = make_object(); + n->name = name; + n->type_info = type; + n->description = description; + get()->arguments.push_back(AttrFieldInfo(n)); + return *this; +} + +inline OpRegistry& OpRegistry::add_type_rel( + const std::string& rel_name, + runtime::TypedPackedFunc&, + int, + const Attrs&, + const TypeReporter&)> type_rel_func) { + auto func_name = std::string("tvm.relay.type_relation.") + rel_name; + TypeRelationFn env_type_rel_func; + + if (runtime::Registry::Get(func_name)) { + auto env_func = EnvFunc::Get(func_name); + env_type_rel_func = env_func; + } else { + runtime::Registry::Register(func_name) + .set_body(type_rel_func.packed()); + auto env_func = EnvFunc::Get(func_name); + env_type_rel_func = env_func; + } + + Array type_params; + Array arg_types; + + // Add inputs. + std::string input_name_prefix = "in"; + for (int i = 0; i < get()->num_inputs; i++) { + auto name = input_name_prefix + std::to_string(i); + auto param = TypeVarNode::make(name, TypeKind::kType); + type_params.push_back(param); + arg_types.push_back(param); + } + + Array ty_call_args = arg_types; + + // Add output type. + auto out_param = TypeVarNode::make("out", TypeKind::kType); + type_params.push_back(out_param); + // this will trigger copy on write. + ty_call_args.push_back(out_param); + + // The attributes of primitive op is nullptr + // + // The attributes of primitive operator can vary at the call site. + // The type of sum is also dependent on Attrs being passed. + // So puting nullptr in the Attrs means that the operator is polymorphic on Attrs. + // + // A common example is sum(x, axis), where the choice of axis + // can affect the type of the function. + TypeConstraint type_rel = + TypeRelationNode::make(env_type_rel_func, + ty_call_args, + arg_types.size(), + Attrs()); + + auto func_type = + FuncTypeNode::make(arg_types, out_param, type_params, {type_rel}); + + get()->op_type = func_type; + + return *this; +} + +inline OpRegistry& OpRegistry::set_num_inputs(int32_t n) { // NOLINT(*) + get()->num_inputs = n; + return *this; +} + +template +inline OpRegistry& OpRegistry::set_attrs_type() { // NOLINT(*) + get()->attrs_type_key = AttrsType::_type_key; + get()->attrs_type_index = AttrsType::RuntimeTypeIndex(); + return *this; +} + +inline OpRegistry& OpRegistry::set_support_level(int32_t n) { // NOLINT(*) + get()->support_level = n; + return *this; +} + +template +inline OpRegistry& OpRegistry::set_attr( // NOLINT(*) + const std::string& attr_name, const ValueType& value, int plevel) { + CHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0"; + TVMRetValue rv; + rv = value; + UpdateAttr(attr_name, rv, plevel); + return *this; +} + +// member functions of OpMap +inline int GenericOpMap::count(const Op& op) const { + if (op.defined()) { + const uint32_t idx = op->index_; + return idx < data_.size() ? (data_[idx].second != 0) : 0; + } else { + return 0; + } +} + +inline const TVMRetValue& GenericOpMap::operator[](const Op& op) const { + CHECK(op.defined()); + const uint32_t idx = op->index_; + CHECK(idx < data_.size() && data_[idx].second != 0) + << "Attribute " << attr_name_ << " has not been registered for Operator " + << op->name; + return data_[idx].first; +} + +template +inline ValueType GenericOpMap::get(const Op& op, ValueType value) const { + CHECK(op.defined()); + const uint32_t idx = op->index_; + if (idx < data_.size() && data_[idx].second != 0) { + return data_[idx].first; + } else { + return value; + } +} + +template +inline ValueType GenericOpMap::get(const RelayExpr& expr, ValueType value) const { + CHECK(expr.defined()); + if (const OpNode* op = expr.as()) { + const uint32_t idx = op->index_; + if (idx < data_.size() && data_[idx].second != 0) { + return data_[idx].first; + } else { + return value; + } + } else { + return value; + } +} + +template +inline int OpMap::count(const Op& op) const { + return map_.count(op); +} + +template +inline ValueType OpMap::operator[](const Op& op) const { + return map_[op]; +} + +template +inline ValueType OpMap::get(const Op& op, + ValueType def_value) const { + return map_.get(op, def_value); +} + +template +inline ValueType OpMap::get(const RelayExpr& expr, + ValueType def_value) const { + return map_.get(expr, def_value); +} + +/*! + * \brief Check that an expression is a "primitive operator". + * + * Will return true if the expression is an operator which + * matches the form of primitive operators registered directly + * by the Relay codebase. + * + * That is the arguments are all type variables, and there is a single + * type relation applied to the input and output types. + * + * \param expr An expression. + * \return Whether the expression is primitive op. + */ +inline bool IsPrimitiveOp(const RelayExpr& expr) { + const auto* op = expr.as(); + return op != nullptr && op->IsPrimitiveOp(); +} + +} // namespace tvm +#endif // TVM_IR_OP_H_ diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index ab2003e9ec5e..ddabd0fbbcc0 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -51,6 +51,7 @@ #include #include +#include #include #include #include diff --git a/include/tvm/ir/type_relation.h b/include/tvm/ir/type_relation.h new file mode 100644 index 000000000000..71d1d9eb4520 --- /dev/null +++ b/include/tvm/ir/type_relation.h @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ir/type_relation.h + * \brief Type relation function for type checking. + */ +#ifndef TVM_IR_TYPE_RELATION_H_ +#define TVM_IR_TYPE_RELATION_H_ + +#include +#include + +namespace tvm { + +// TODO(tqchen): remove after migrate Module to ir. +namespace relay { +struct Module; +} + +/*! + * \brief reporter that reports back to the + * type resolution information. + */ +class TypeReporterNode : public Object { + public: + /*! + * \brief Create a type equality constraint. + * + * The "assign direction" acts as a hint to the solver + * showing that it is more likely to resolve dst by src. + * But it is possible for the solver to resolve src by dst as well. + */ + TVM_DLL virtual void Assign(const Type& dst, const Type& src) = 0; + + /*! + * \brief assert shape expression comparison. + * \note Use assert only if any of the condition input is symbolic. + * \param cond The condition of operation. + * \return false if assertation can be proven to have failed + * true if solver can still proceed. + */ + TVM_DLL virtual bool Assert(const PrimExpr& cond)= 0; + /*! + * \brief assert shape expression equals each other. + * \param lhs The left operand. + * \param rhs The right operand. + * \return false if assertation can be proven to have failed + * true if solver can still proceed. + */ + TVM_DLL virtual bool AssertEQ(const PrimExpr& lhs, const PrimExpr& rhs) = 0; + + /*! + * \brief Set the location at which to report unification errors. + * \param ref The program node to report the error. + */ + TVM_DLL virtual void SetLocation(const ObjectRef& ref) = 0; + + /*! + * \brief Retrieve the current global module. + * \return The global module. + */ + TVM_DLL virtual relay::Module GetModule() = 0; + + // solver is not serializable. + void VisitAttrs(tvm::AttrVisitor* v) {} + + static constexpr const char* _type_key = "relay.TypeReporter"; + TVM_DECLARE_FINAL_OBJECT_INFO(TypeReporterNode, Object); +}; + +/*! + * \brief Container class of TypeReporter. + * \sa TypeReporterNode + */ +class TypeReporter : public ObjectRef { + public: + TypeReporter() {} + explicit TypeReporter(::tvm::ObjectPtr<::tvm::Object> n) : ObjectRef(n) { + } + TypeReporterNode* operator->() const { + return const_cast( + static_cast(get())); + } + using ContainerType = TypeReporterNode; +}; + +/*! + * \brief User defined type constraint function. + * + * If the input type information can be used to fully decide + * the IncompleteTypes, then the function should call + * reporter.Assign to report the new types, and return true. + * Otherwise, the function should return false. + * + * \param args The arguments to the relation. + * The types are stored in the form of + * [input_type_0, input_type_1, ... input_type_n, + * output_type_0, output_type_1, ... output_type_m] + * + * \param num_inputs Number of input types in the args. + * \param attrs The additional attributes of the operator. + * \param reporter The reporter to report solution to. + * \return false if This relation cannot be resolved. + * true if this relation has been resolved. + */ +using TypeRelationFn = + TypedEnvFunc& args, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter)>; + +/*! + * \brief User defined type relation, is an input-output relation on types. + */ +class TypeRelation; +/*! + * \brief TypeRelation container. + * \note This node is not directly serializable. + * The type function need to be lookedup in the module. + */ +class TypeRelationNode : public TypeConstraintNode { + public: + /*! + * \brief The function on input and output variables which + * this is not directly serializable, + * need to be looked-up in the module. + */ + TypeRelationFn func; + /*! \brief The type arguments to the type function. */ + tvm::Array args; + /*! \brief Number of inputs arguments */ + int num_inputs; + /*! \brief Attributes to the relation function */ + Attrs attrs; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("func", &func); + v->Visit("args", &args); + v->Visit("num_inputs", &num_inputs); + v->Visit("attrs", &attrs); + v->Visit("span", &span); + } + + TVM_DLL static TypeRelation make(TypeRelationFn func, + Array args, + int num_args, + Attrs attrs); + + static constexpr const char* _type_key = "relay.TypeRelation"; + TVM_DECLARE_FINAL_OBJECT_INFO(TypeRelationNode, TypeConstraintNode); +}; + +class TypeRelation : public TypeConstraint { + public: + TVM_DEFINE_OBJECT_REF_METHODS(TypeRelation, TypeConstraint, TypeRelationNode); +}; +} // namespace tvm +#endif // TVM_IR_TYPE_RELATION_H_ diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index 6bd0a359cc69..fa47da226dff 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -19,595 +19,23 @@ /*! * \file tvm/relay/op.h - * \brief Primitive operator definition. + * \brief Primitive operators(builtin intrinsics). */ #ifndef TVM_RELAY_OP_H_ #define TVM_RELAY_OP_H_ -#include - -#include -#include -#include -#include -#include -#include - -#include "base.h" -#include "expr.h" -#include "type.h" +#include +#include +#include namespace tvm { namespace relay { -// forward declare name. -template -class OpMap; -class GenericOpMap; -class OpRegistry; - -/*! - * \brief Node container of operator structure. - */ -class OpNode : public relay::ExprNode { - public: - /*! \brief name of the operator */ - std::string name; - /*! \brief the type of the operator */ - mutable FuncType op_type; - /*! - * \brief detailed description of the operator - * This can be used to generate docstring automatically for the operator. - */ - std::string description; - /* \brief Information of input arguments to the operator */ - Array arguments; - /*! - * \brief The type key of the attribute field - * This can be empty, in which case it defaults to anything. - */ - std::string attrs_type_key; - /*! - * \brief attribute type index, - * this field varies in each run and is not exposed to frontend. - */ - uint32_t attrs_type_index{0}; - /*! - * \brief number of input arguments to the operator, - * -1 means it is variable length - */ - int32_t num_inputs = -1; - /*! - * \brief support level of the operator, - * The lower the more priority it contains. - * This is in analogies to BLAS levels. - */ - int32_t support_level = 10; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("name", &name); - v->Visit("op_type", &op_type); - v->Visit("description", &description); - v->Visit("arguments", &arguments); - v->Visit("attrs_type_key", &attrs_type_key); - v->Visit("num_inputs", &num_inputs); - v->Visit("support_level", &support_level); - } - - /*! - * \brief Check that if current op is a "primtive operator". - * That is the arguments are all type variables, and there is a single - * type relation applied to the input and output types. - */ - bool IsPrimitiveOp() const { - if (is_primitive_ != -1) return is_primitive_ != 0; - is_primitive_ = this->IsPrimitiveOp_() ? 1 : 0; - return is_primitive_ != 0; - } - - static constexpr const char* _type_key = "relay.Op"; - TVM_DECLARE_FINAL_OBJECT_INFO(OpNode, ExprNode); - - private: - // friend class - friend class GenericOpMap; - friend class OpRegistry; - friend bool IsPrimitiveOp(const Expr&); - // Program internal unique index of operator. - // Used to help index the program. - uint32_t index_{0}; - // whether this is a primitive op. -1 means unknown. - mutable int is_primitive_{-1}; - // Internal function to compute if it is primitive op - bool IsPrimitiveOp_() const { - const auto& fn_ty = this->op_type; - if (fn_ty->type_constraints.size() != 1) return false; - const TypeRelationNode* rel = fn_ty->type_constraints[0].as(); - if (rel == nullptr) return false; - // validate if the type parameter matches up - for (size_t i = 0; i < fn_ty->type_params.size(); ++i) { - if (!fn_ty->type_params[i].same_as(rel->args[i])) return false; - } - return true; - } -}; - -/*! - * \brief Operator reference class. - */ -class Op : public relay::Expr { - public: - /*! \brief default constructor */ - Op() {} - /*! \brief constructor from node pointer */ - explicit Op(ObjectPtr n) : RelayExpr(n) {} - /*! - * \brief access the internal node container - * \return the pointer to the internal node container - */ - inline const OpNode* operator->() const; - /*! - * \brief Get additional registered attribute about operators. - * If nothing has been registered, an empty OpMap will be returned. - * \param attr_name The name of the attribute. - * \return An OpMap of specified attr_name. - * \tparam ValueType The type of the attribute. - */ - template - inline static OpMap GetAttr(const std::string& attr_name); - /*! - * \brief Checks if an attr is present in the registry. - * \param attr_name The name of the attribute. - * \return bool True if the attr is present. - */ - inline static bool HasAttr(const std::string& attr_name); - /*! - * \brief Get an Op for a given operator name. - * Will raise an error if the op has not been registered. - * \param op_name Name of the operator. - * \return Pointer to a Op, valid throughout program lifetime. - */ - TVM_DLL static const Op& Get(const std::string& op_name); - - /*! \brief specify container node */ - using ContainerType = OpNode; - - private: - /*! - * \brief Get generic attrmap given attr name - * \param key The attribute key - * \return reference to GenericOpMap - */ - TVM_DLL static const GenericOpMap& GetGenericAttr(const std::string& key); - /*! - * \brief Checks if the key is present in the registry - * \param key The attribute key - * \return bool True if the key is present - */ - TVM_DLL static const bool HasGenericAttr(const std::string& key); -}; - -/*! \brief Helper structure to register operators */ -class OpRegistry { - public: - /*! \return the operator */ - const Op& op() const { return op_; } - /*! - * \brief setter function during registration - * Set the description of operator - * \param descr the description string. - * \return reference to self. - */ - inline OpRegistry& describe(const std::string& descr); // NOLINT(*) - /*! - * \brief Add argument information to the function. - * \param name Name of the argument. - * \param type Type of the argument. - * \param description Description of the argument. - * \return reference to self. - */ - inline OpRegistry& add_argument(const std::string& name, - const std::string& type, - const std::string& description); - /*! - * \brief Attach the type function corresponding to the return type. - * \param rel_name The type relation name to register. - * \param type_rel_func The backing relation function which can solve an arbitrary - * relation on variables. - * \return reference to self. - */ - inline OpRegistry& add_type_rel( - const std::string& rel_name, - runtime::TypedPackedFunc&, - int, - const Attrs&, - const TypeReporter&)> type_rel_func); - /*! - * \brief Set the the attrs type key and index to be AttrsType. - * \tparam AttrsType the attribute type to b set. - * \return reference to self. - */ - template - inline OpRegistry& set_attrs_type(); - /*! - * \brief Set the num_inputs - * \param n The number of inputs to be set. - * \return reference to self. - */ - inline OpRegistry& set_num_inputs(int32_t n); // NOLINT(*) - /*! - * \brief Set the support level of op. - * \param level The support level. - * \return reference to self. - */ - inline OpRegistry& set_support_level(int32_t level); // NOLINT(*) - /*! - * \brief Register additional attributes to operator. - * \param attr_name The name of the attribute. - * \param value The value to be set. - * \param plevel The priority level of this set, - * an higher priority level attribute - * will replace lower priority level attribute. - * Must be bigger than 0. - * - * Cannot set with same plevel twice in the code. - * - * \tparam ValueType The type of the value to be set. - */ - template - inline OpRegistry& set_attr(const std::string& attr_name, // NOLINT(*) - const ValueType& value, int plevel = 10); - - /*! - * \brief Resets an attr of the registry. - * \param attr_name The name of the attribute. - */ - inline void reset_attr(const std::string& attr_name); - - // set the name of the op to be the same as registry - inline OpRegistry& set_name() { // NOLINT(*) - if (get()->name.length() == 0) { - get()->name = name; - } - return *this; - } - /*! \return The global single registry */ - TVM_DLL static ::dmlc::Registry* Registry(); - - private: - friend class ::dmlc::Registry; - // the name - std::string name; - /*! \brief The operator */ - Op op_; - // private constructor - TVM_DLL OpRegistry(); - // return internal pointer to op. - inline OpNode* get(); - // update the attribute OpMap - TVM_DLL void UpdateAttr(const std::string& key, TVMRetValue value, - int plevel); -}; - -/*! - * \brief Generic map to store additional information of Op. - */ -class GenericOpMap { - public: - /*! - * \brief Check if the map has op as key. - * \param op The key to the map - * \return 1 if op is contained in map, 0 otherwise. - */ - inline int count(const Op& op) const; - /*! - * \brief get the corresponding value element at op - * \param op The key to the map - * \return the const reference to the content value. - */ - inline const TVMRetValue& operator[](const Op& op) const; - /*! - * \brief get the corresponding value element at op with default value. - * \param op The key to the map - * \param def_value The default value when the key does not exist. - * \return the const reference to the content value. - * \tparam ValueType The content value type. - */ - template - inline ValueType get(const Op& op, ValueType def_value) const; - /*! - * \brief get the corresponding value element at op with default value. - * \param expr The key to the map - * \param def_value The default value when the key does not exist - * or if expr is not an Op. - * \return the const reference to the content value. - * \tparam ValueType The content value type. - */ - template - inline ValueType get(const Expr& expr, ValueType def_value) const; - - private: - friend class OpRegistry; - // the attribute field. - std::string attr_name_; - // internal data - std::vector > data_; - // The value - GenericOpMap() = default; -}; - -/*! - * \brief Map used to store meta-information about Op. - * \tparam ValueType The type of the value stored in map. - */ -template -class OpMap { - public: - /*! - * \brief Check if the map has op as key. - * \param op The key to the map - * \return 1 if op is contained in map, 0 otherwise. - */ - inline int count(const Op& op) const; - /*! - * \brief get the corresponding value element at op - * \param op The key to the map - * \return the const reference to the content value. - */ - inline ValueType operator[](const Op& op) const; - /*! - * \brief get the corresponding value element at op with default value. - * \param op The key to the map - * \param def_value The default value when the key does not exist. - * \return the const reference to the content value. - */ - inline ValueType get(const Op& op, ValueType def_value) const; - /*! - * \brief get the corresponding value element at op with default value. - * \param expr The key to the map - * \param def_value The default value when the key does not exist - * or if expr is not an Op. - * \return the const reference to the content value. - */ - inline ValueType get(const Expr& expr, ValueType def_value) const; - - private: - friend class Op; - // constructor - explicit OpMap(const GenericOpMap& map) : map_(map) {} - /*! \brief The internal map field */ - const GenericOpMap& map_; -}; - -// internal macros to make -#define RELAY_REGISTER_VAR_DEF \ - static DMLC_ATTRIBUTE_UNUSED ::tvm::relay::OpRegistry& __make_##RelayOp - -/*! - * \def RELAY_REGISTER_OP - * \brief Register a new operator, or set attribute of the corresponding op. - * - * \param OpName The name of registry - * - * \code - * - * RELAY_REGISTER_OP("add") - * .describe("add two inputs together") - * .set_num_inputs(2) - * .set_attr("gpu_kernel", AddKernel); - * - * \endcode - */ -#define RELAY_REGISTER_OP(OpName) \ - DMLC_STR_CONCAT(RELAY_REGISTER_VAR_DEF, __COUNTER__) = \ - ::tvm::relay::OpRegistry::Registry() \ - ->__REGISTER_OR_GET__(OpName) \ - .set_name() - -// implementations -inline const OpNode* Op::operator->() const { - return static_cast(get()); -} - -template -inline OpMap Op::GetAttr(const std::string& key) { - return OpMap(Op::GetGenericAttr(key)); -} - -inline bool Op::HasAttr(const std::string& key) { - return Op::HasGenericAttr(key); -} - -inline OpNode* OpRegistry::get() { - return const_cast(op_.operator->()); -} - -inline OpRegistry& OpRegistry::describe( - const std::string& descr) { // NOLINT(*) - get()->description = descr; - return *this; -} - -inline OpRegistry& OpRegistry::add_argument(const std::string& name, - const std::string& type, - const std::string& description) { - auto n = make_object(); - n->name = name; - n->type_info = type; - n->description = description; - get()->arguments.push_back(AttrFieldInfo(n)); - return *this; -} - -inline OpRegistry& OpRegistry::add_type_rel( - const std::string& rel_name, - runtime::TypedPackedFunc&, - int, - const Attrs&, - const TypeReporter&)> type_rel_func) { - auto func_name = std::string("tvm.relay.type_relation.") + rel_name; - TypeRelationFn env_type_rel_func; - - if (runtime::Registry::Get(func_name)) { - auto env_func = EnvFunc::Get(func_name); - env_type_rel_func = env_func; - } else { - runtime::Registry::Register(func_name) - .set_body(type_rel_func.packed()); - auto env_func = EnvFunc::Get(func_name); - env_type_rel_func = env_func; - } - - Array type_params; - Array arg_types; - - // Add inputs. - std::string input_name_prefix = "in"; - for (int i = 0; i < get()->num_inputs; i++) { - auto name = input_name_prefix + std::to_string(i); - auto param = TypeVarNode::make(name, Kind::kType); - type_params.push_back(param); - arg_types.push_back(param); - } - - Array ty_call_args = arg_types; - - // Add output type. - auto out_param = TypeVarNode::make("out", Kind::kType); - type_params.push_back(out_param); - // this will trigger copy on write. - ty_call_args.push_back(out_param); - - // The attributes of primitive op is nullptr - // - // The attributes of primitive operator can vary at the call site. - // The type of sum is also dependent on Attrs being passed. - // So puting nullptr in the Attrs means that the operator is polymorphic on Attrs. - // - // A common example is sum(x, axis), where the choice of axis - // can affect the type of the function. - TypeConstraint type_rel = - TypeRelationNode::make(env_type_rel_func, - ty_call_args, - arg_types.size(), - Attrs()); - - auto func_type = - FuncTypeNode::make(arg_types, out_param, type_params, {type_rel}); - - get()->op_type = func_type; - - return *this; -} +using Op = tvm::Op; +using OpNode = tvm::OpNode; -inline OpRegistry& OpRegistry::set_num_inputs(int32_t n) { // NOLINT(*) - get()->num_inputs = n; - return *this; -} - -template -inline OpRegistry& OpRegistry::set_attrs_type() { // NOLINT(*) - get()->attrs_type_key = AttrsType::_type_key; - get()->attrs_type_index = AttrsType::RuntimeTypeIndex(); - return *this; -} - -inline OpRegistry& OpRegistry::set_support_level(int32_t n) { // NOLINT(*) - get()->support_level = n; - return *this; -} - -template -inline OpRegistry& OpRegistry::set_attr( // NOLINT(*) - const std::string& attr_name, const ValueType& value, int plevel) { - CHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0"; - TVMRetValue rv; - rv = value; - UpdateAttr(attr_name, rv, plevel); - return *this; -} - -// member functions of OpMap -inline int GenericOpMap::count(const Op& op) const { - if (op.defined()) { - const uint32_t idx = op->index_; - return idx < data_.size() ? (data_[idx].second != 0) : 0; - } else { - return 0; - } -} - -inline const TVMRetValue& GenericOpMap::operator[](const Op& op) const { - CHECK(op.defined()); - const uint32_t idx = op->index_; - CHECK(idx < data_.size() && data_[idx].second != 0) - << "Attribute " << attr_name_ << " has not been registered for Operator " - << op->name; - return data_[idx].first; -} - -template -inline ValueType GenericOpMap::get(const Op& op, ValueType value) const { - CHECK(op.defined()); - const uint32_t idx = op->index_; - if (idx < data_.size() && data_[idx].second != 0) { - return data_[idx].first; - } else { - return value; - } -} - -template -inline ValueType GenericOpMap::get(const Expr& expr, ValueType value) const { - CHECK(expr.defined()); - if (const OpNode* op = expr.as()) { - const uint32_t idx = op->index_; - if (idx < data_.size() && data_[idx].second != 0) { - return data_[idx].first; - } else { - return value; - } - } else { - return value; - } -} - -template -inline int OpMap::count(const Op& op) const { - return map_.count(op); -} - -template -inline ValueType OpMap::operator[](const Op& op) const { - return map_[op]; -} - -template -inline ValueType OpMap::get(const Op& op, - ValueType def_value) const { - return map_.get(op, def_value); -} - -template -inline ValueType OpMap::get(const Expr& expr, - ValueType def_value) const { - return map_.get(expr, def_value); -} - -/*! - * \brief Check that an expression is a "primitive operator". - * - * Will return true if the expression is an operator which - * matches the form of primitive operators registered directly - * by the Relay codebase. - * - * That is the arguments are all type variables, and there is a single - * type relation applied to the input and output types. - */ -inline bool IsPrimitiveOp(const Expr& expr) { - const auto* op = expr.as(); - return op != nullptr && op->IsPrimitiveOp(); -} +#define RELAY_REGISTER_OP(OpName) \ + TVM_REGISTER_OP(OpName) } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 31e85f913238..7748bd108dfb 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -24,8 +24,8 @@ #ifndef TVM_RELAY_TYPE_H_ #define TVM_RELAY_TYPE_H_ - #include +#include #include #include #include @@ -51,6 +51,11 @@ using TypeConstraint = tvm::TypeConstraint; using TypeConstraintNode = tvm::TypeConstraintNode; using FuncType = tvm::FuncType; using FuncTypeNode = tvm::FuncTypeNode; +using TypeRelation = tvm::TypeRelation; +using TypeRelationNode = tvm::TypeRelationNode; +using TypeRelationFn = tvm::TypeRelationFn; +using TypeReporter = tvm::TypeReporter; +using TypeReporterNode = tvm::TypeReporterNode; /*! * \brief Base of all Tensor types @@ -235,146 +240,6 @@ class RefType : public Type { TVM_DEFINE_OBJECT_REF_METHODS(RefType, Type, RefTypeNode); }; -class TypeReporter; - -/*! - * \brief reporter that reports back to the - * type resolution information. - */ -class TypeReporterNode : public Object { - public: - /*! - * \brief Create a type equality constraint. - * - * The "assign direction" acts as a hint to the solver - * showing that it is more likely to resolve dst by src. - * But it is possible for the solver to resolve src by dst as well. - */ - TVM_DLL virtual void Assign(const Type& dst, const Type& src) = 0; - - /*! - * \brief assert shape expression comparison. - * \note Use assert only if any of the condition input is symbolic. - * \param cond The condition of operation. - * \return false if assertation can be proven to have failed - * true if solver can still proceed. - */ - TVM_DLL virtual bool Assert(const IndexExpr& cond)= 0; - /*! - * \brief assert shape expression equals each other. - * \param lhs The left operand. - * \param rhs The right operand. - * \return false if assertation can be proven to have failed - * true if solver can still proceed. - */ - TVM_DLL virtual bool AssertEQ(const IndexExpr& lhs, const IndexExpr& rhs) = 0; - - /*! - * \brief Set the location at which to report unification errors. - * \param ref The program node to report the error. - */ - TVM_DLL virtual void SetLocation(const ObjectRef& ref) = 0; - - /*! - * \brief Retrieve the current global module. - * \return The global module. - */ - TVM_DLL virtual Module GetModule() = 0; - - // solver is not serializable. - void VisitAttrs(tvm::AttrVisitor* v) {} - - static constexpr const char* _type_key = "relay.TypeReporter"; - TVM_DECLARE_FINAL_OBJECT_INFO(TypeReporterNode, Object); -}; - -/*! - * \brief Container class of TypeReporter. - * \sa TypeReporterNode - */ -class TypeReporter : public ObjectRef { - public: - TypeReporter() {} - explicit TypeReporter(::tvm::ObjectPtr<::tvm::Object> n) : ObjectRef(n) { - } - TypeReporterNode* operator->() const { - return const_cast( - static_cast(get())); - } - using ContainerType = TypeReporterNode; -}; - -/*! - * \brief User defined type constraint function. - * - * If the input type information can be used to fully decide - * the IncompleteTypes, then the function should call - * reporter.Assign to report the new types, and return true. - * Otherwise, the function should return false. - * - * \param args The arguments to the relation. - * The types are stored in the form of - * [input_type_0, input_type_1, ... input_type_n, - * output_type_0, output_type_1, ... output_type_m] - * - * \param num_inputs Number of input types in the args. - * \param attrs The additional attributes of the operator. - * \param reporter The reporter to report solution to. - * \return false if This relation cannot be resolved. - * true if this relation has been resolved. - */ -using TypeRelationFn = - TypedEnvFunc& args, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter)>; - -/*! - * \brief User defined type relation, is an input-output relation on types. - */ -class TypeRelation; -/*! - * \brief TypeRelation container. - * \note This node is not directly serializable. - * The type function need to be lookedup in the module. - */ -class TypeRelationNode : public TypeConstraintNode { - public: - /*! - * \brief The function on input and output variables which - * this is not directly serializable, - * need to be looked-up in the module. - */ - TypeRelationFn func; - /*! \brief The type arguments to the type function. */ - tvm::Array args; - /*! \brief Number of inputs arguments */ - int num_inputs; - /*! \brief Attributes to the relation function */ - Attrs attrs; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("func", &func); - v->Visit("args", &args); - v->Visit("num_inputs", &num_inputs); - v->Visit("attrs", &attrs); - v->Visit("span", &span); - } - - TVM_DLL static TypeRelation make(TypeRelationFn func, - Array args, - int num_args, - Attrs attrs); - - static constexpr const char* _type_key = "relay.TypeRelation"; - TVM_DECLARE_FINAL_OBJECT_INFO(TypeRelationNode, TypeConstraintNode); -}; - -class TypeRelation : public TypeConstraint { - public: - TVM_DEFINE_OBJECT_REF_METHODS(TypeRelation, TypeConstraint, TypeRelationNode); -}; - // The following fields contains advanced typing // Only keep the class name and reserved for future usage. class GenericTensorType; diff --git a/src/relay/ir/op.cc b/src/ir/op.cc similarity index 96% rename from src/relay/ir/op.cc rename to src/ir/op.cc index b888ecbd9241..0ed2f3dcb015 100644 --- a/src/relay/ir/op.cc +++ b/src/ir/op.cc @@ -18,11 +18,11 @@ */ /*! - * \file src/tvm/relay/op.cc - * \brief Resolve incomplete types to complete types. + * \file src/tvm/ir/op.cc + * \brief Primitive operators and intrinsics. */ -#include -#include +#include +#include #include #include @@ -31,11 +31,10 @@ namespace dmlc { // enable registry -DMLC_REGISTRY_ENABLE(::tvm::relay::OpRegistry); +DMLC_REGISTRY_ENABLE(::tvm::OpRegistry); } // namespace dmlc namespace tvm { -namespace relay { ::dmlc::Registry* OpRegistry::Registry() { return ::dmlc::Registry::Get(); @@ -230,5 +229,4 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->stream << "Op(" << node->name << ")"; }); -} // namespace relay } // namespace tvm diff --git a/src/ir/type_relation.cc b/src/ir/type_relation.cc new file mode 100644 index 000000000000..cc5ceef7dd3e --- /dev/null +++ b/src/ir/type_relation.cc @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/tvm/ir/type_relation.cc + * \brief Type relation + */ +#include +#include +#include +#include + +namespace tvm { +TypeRelation TypeRelationNode::make(TypeRelationFn func, + Array args, + int num_inputs, + Attrs attrs) { + ObjectPtr n = make_object(); + n->func = std::move(func); + n->args = std::move(args); + n->num_inputs = num_inputs; + n->attrs = std::move(attrs); + return TypeRelation(n); +} + +TVM_REGISTER_NODE_TYPE(TypeRelationNode); + +TVM_REGISTER_GLOBAL("relay._make.TypeRelation") +.set_body_typed(TypeRelationNode::make); + +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TypeRelationNode(" + << node->func->name + << ", " << node->args << ")"; +}); +} // namespace tvm diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index 4ae2ee58ae51..099b8013c895 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -101,30 +101,6 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")"; }); -TypeRelation TypeRelationNode::make(TypeRelationFn func, - Array args, - int num_inputs, - Attrs attrs) { - ObjectPtr n = make_object(); - n->func = std::move(func); - n->args = std::move(args); - n->num_inputs = num_inputs; - n->attrs = std::move(attrs); - return TypeRelation(n); -} - -TVM_REGISTER_NODE_TYPE(TypeRelationNode); - -TVM_REGISTER_GLOBAL("relay._make.TypeRelation") -.set_body_typed(TypeRelationNode::make); - -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TypeRelationNode(" - << node->func->name - << ", " << node->args << ")"; -}); TupleType TupleTypeNode::make(Array fields) { ObjectPtr n = make_object();