From 370c2dfe69701e6af2996a64ceeb43d7439b23e5 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sat, 11 Jan 2020 13:02:29 -0800 Subject: [PATCH] [REFACTOR][IR] Unified IR Primitive Op and Registry (#4687) This PR migrates relay's Op into the ir folder. Op and its registry provides an useful mechanism to store any attribute meta-data of an operator include function signatures, lowering rules, side effect etc. These features are not only useful for Relay, but also needed in the low-level IR. At the current moment, intrinsic functions in the low-level IR are simply represented by a string. This means we cannot type-check the low-level IR when the type does not meet the constraint, nor can we obtain further information such as side-effect and read write relation of these intrinsics wrt to arguments. Op will be used as the way to handle primitive ops(in DL terminology) (builtin intrinsics or in compiler terminology). We will perform follow-up refactors to make low-level CallNode take Op as the function argument. --- include/tvm/ir/op.h | 627 +++++++++++++++++++++++++++++++++ include/tvm/ir/type.h | 1 + include/tvm/ir/type_relation.h | 175 +++++++++ include/tvm/relay/op.h | 588 +------------------------------ include/tvm/relay/type.h | 147 +------- src/{relay => }/ir/op.cc | 12 +- src/ir/type_relation.cc | 54 +++ src/relay/ir/type.cc | 24 -- 8 files changed, 876 insertions(+), 752 deletions(-) create mode 100644 include/tvm/ir/op.h create mode 100644 include/tvm/ir/type_relation.h rename src/{relay => }/ir/op.cc (96%) create mode 100644 src/ir/type_relation.cc diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h new file mode 100644 index 0000000000000..19c5a51162a35 --- /dev/null +++ b/include/tvm/ir/op.h @@ -0,0 +1,627 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ir/op.h + * \brief Primitive operators(builtin intrinsics) + * and registry for them. + */ +#ifndef TVM_IR_OP_H_ +#define TVM_IR_OP_H_ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace tvm { + +// forward declare name. +template +class OpMap; +class GenericOpMap; +class OpRegistry; + +// TODO(tvm-team): migrate low-level intrinsics to use Op +/*! + * \brief Primitive Op(builtin intrinsics) + * + * This data structure stores the meta-data + * about primitive operators that can be invoked via Call. + * + * Low-level IR intrinsics(such as libc.expf) are also + * implemented via Op. + * + * \sa Op + */ +class OpNode : public RelayExprNode { + public: + /*! \brief name of the operator */ + std::string name; + /*! \brief the type of the operator */ + mutable FuncType op_type; + /*! + * \brief detailed description of the operator + * This can be used to generate docstring automatically for the operator. + */ + std::string description; + /* \brief Information of input arguments to the operator */ + Array arguments; + /*! + * \brief The type key of the attribute field + * This can be empty, in which case it defaults to anything. + */ + std::string attrs_type_key; + /*! + * \brief attribute type index, + * this field varies in each run and is not exposed to frontend. + */ + uint32_t attrs_type_index{0}; + /*! + * \brief number of input arguments to the operator, + * -1 means it is variable length + */ + int32_t num_inputs = -1; + /*! + * \brief support level of the operator, + * The lower the more priority it contains. + * This is in analogies to BLAS levels. + */ + int32_t support_level = 10; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("op_type", &op_type); + v->Visit("description", &description); + v->Visit("arguments", &arguments); + v->Visit("attrs_type_key", &attrs_type_key); + v->Visit("num_inputs", &num_inputs); + v->Visit("support_level", &support_level); + } + + /*! + * \brief Check that if current op is a "primtive operator". + * That is the arguments are all type variables, and there is a single + * type relation applied to the input and output types. + */ + bool IsPrimitiveOp() const { + if (is_primitive_ != -1) return is_primitive_ != 0; + is_primitive_ = this->IsPrimitiveOp_() ? 1 : 0; + return is_primitive_ != 0; + } + + static constexpr const char* _type_key = "relay.Op"; + TVM_DECLARE_FINAL_OBJECT_INFO(OpNode, RelayExprNode); + + private: + // friend class + friend class GenericOpMap; + friend class OpRegistry; + friend bool IsPrimitiveOp(const RelayExpr&); + // Program internal unique index of operator. + // Used to help index the program. + uint32_t index_{0}; + // whether this is a primitive op. -1 means unknown. + mutable int is_primitive_{-1}; + // Internal function to compute if it is primitive op + bool IsPrimitiveOp_() const { + const auto& fn_ty = this->op_type; + if (fn_ty->type_constraints.size() != 1) return false; + const TypeRelationNode* rel = fn_ty->type_constraints[0].as(); + if (rel == nullptr) return false; + // validate if the type parameter matches up + for (size_t i = 0; i < fn_ty->type_params.size(); ++i) { + if (!fn_ty->type_params[i].same_as(rel->args[i])) return false; + } + return true; + } +}; + +/*! + * \brief Managed reference class to OpNode. + * \sa OpNode + */ +class Op : public RelayExpr { + public: + /*! \brief default constructor */ + Op() {} + /*! \brief constructor from node pointer */ + explicit Op(ObjectPtr n) : RelayExpr(n) {} + /*! + * \brief access the internal node container + * \return the pointer to the internal node container + */ + inline const OpNode* operator->() const; + /*! + * \brief Get additional registered attribute about operators. + * If nothing has been registered, an empty OpMap will be returned. + * \param attr_name The name of the attribute. + * \return An OpMap of specified attr_name. + * \tparam ValueType The type of the attribute. + */ + template + inline static OpMap GetAttr(const std::string& attr_name); + /*! + * \brief Checks if an attr is present in the registry. + * \param attr_name The name of the attribute. + * \return bool True if the attr is present. + */ + inline static bool HasAttr(const std::string& attr_name); + /*! + * \brief Get an Op for a given operator name. + * Will raise an error if the op has not been registered. + * \param op_name Name of the operator. + * \return Pointer to a Op, valid throughout program lifetime. + */ + TVM_DLL static const Op& Get(const std::string& op_name); + + /*! \brief specify container node */ + using ContainerType = OpNode; + + private: + /*! + * \brief Get generic attrmap given attr name + * \param key The attribute key + * \return reference to GenericOpMap + */ + TVM_DLL static const GenericOpMap& GetGenericAttr(const std::string& key); + /*! + * \brief Checks if the key is present in the registry + * \param key The attribute key + * \return bool True if the key is present + */ + TVM_DLL static const bool HasGenericAttr(const std::string& key); +}; + +/*! + * \brief Helper structure to register operators + * \sa TVM_REGISTER_OP + */ +class OpRegistry { + public: + /*! \return the operator */ + const Op& op() const { return op_; } + /*! + * \brief setter function during registration + * Set the description of operator + * \param descr the description string. + * \return reference to self. + */ + inline OpRegistry& describe(const std::string& descr); // NOLINT(*) + /*! + * \brief Add argument information to the function. + * \param name Name of the argument. + * \param type Type of the argument. + * \param description Description of the argument. + * \return reference to self. + */ + inline OpRegistry& add_argument(const std::string& name, + const std::string& type, + const std::string& description); + /*! + * \brief Attach the type function corresponding to the return type. + * \param rel_name The type relation name to register. + * \param type_rel_func The backing relation function which can solve an arbitrary + * relation on variables. + * \return reference to self. + */ + inline OpRegistry& add_type_rel( + const std::string& rel_name, + runtime::TypedPackedFunc&, + int, + const Attrs&, + const TypeReporter&)> type_rel_func); + /*! + * \brief Set the the attrs type key and index to be AttrsType. + * \tparam AttrsType the attribute type to b set. + * \return reference to self. + */ + template + inline OpRegistry& set_attrs_type(); + /*! + * \brief Set the num_inputs + * \param n The number of inputs to be set. + * \return reference to self. + */ + inline OpRegistry& set_num_inputs(int32_t n); // NOLINT(*) + /*! + * \brief Set the support level of op. + * \param level The support level. + * \return reference to self. + */ + inline OpRegistry& set_support_level(int32_t level); // NOLINT(*) + /*! + * \brief Register additional attributes to operator. + * \param attr_name The name of the attribute. + * \param value The value to be set. + * \param plevel The priority level of this set, + * an higher priority level attribute + * will replace lower priority level attribute. + * Must be bigger than 0. + * + * Cannot set with same plevel twice in the code. + * + * \tparam ValueType The type of the value to be set. + */ + template + inline OpRegistry& set_attr(const std::string& attr_name, // NOLINT(*) + const ValueType& value, int plevel = 10); + + /*! + * \brief Resets an attr of the registry. + * \param attr_name The name of the attribute. + */ + inline void reset_attr(const std::string& attr_name); + + // set the name of the op to be the same as registry + inline OpRegistry& set_name() { // NOLINT(*) + if (get()->name.length() == 0) { + get()->name = name; + } + return *this; + } + /*! \return The global single registry */ + TVM_DLL static ::dmlc::Registry* Registry(); + + private: + friend class ::dmlc::Registry; + // the name + std::string name; + /*! \brief The operator */ + Op op_; + // private constructor + TVM_DLL OpRegistry(); + // return internal pointer to op. + inline OpNode* get(); + // update the attribute OpMap + TVM_DLL void UpdateAttr(const std::string& key, TVMRetValue value, + int plevel); +}; + +/*! + * \brief Generic map to store additional information of Op. + */ +class GenericOpMap { + public: + /*! + * \brief Check if the map has op as key. + * \param op The key to the map + * \return 1 if op is contained in map, 0 otherwise. + */ + inline int count(const Op& op) const; + /*! + * \brief get the corresponding value element at op + * \param op The key to the map + * \return the const reference to the content value. + */ + inline const TVMRetValue& operator[](const Op& op) const; + /*! + * \brief get the corresponding value element at op with default value. + * \param op The key to the map + * \param def_value The default value when the key does not exist. + * \return the const reference to the content value. + * \tparam ValueType The content value type. + */ + template + inline ValueType get(const Op& op, ValueType def_value) const; + /*! + * \brief get the corresponding value element at op with default value. + * \param expr The key to the map + * \param def_value The default value when the key does not exist + * or if expr is not an Op. + * \return the const reference to the content value. + * \tparam ValueType The content value type. + */ + template + inline ValueType get(const RelayExpr& expr, ValueType def_value) const; + + private: + friend class OpRegistry; + // the attribute field. + std::string attr_name_; + // internal data + std::vector > data_; + // The value + GenericOpMap() = default; +}; + +/*! + * \brief Map used to store meta-information about Op. + * \tparam ValueType The type of the value stored in map. + */ +template +class OpMap { + public: + /*! + * \brief Check if the map has op as key. + * \param op The key to the map + * \return 1 if op is contained in map, 0 otherwise. + */ + inline int count(const Op& op) const; + /*! + * \brief get the corresponding value element at op + * \param op The key to the map + * \return the const reference to the content value. + */ + inline ValueType operator[](const Op& op) const; + /*! + * \brief get the corresponding value element at op with default value. + * \param op The key to the map + * \param def_value The default value when the key does not exist. + * \return the const reference to the content value. + */ + inline ValueType get(const Op& op, ValueType def_value) const; + /*! + * \brief get the corresponding value element at op with default value. + * \param expr The key to the map + * \param def_value The default value when the key does not exist + * or if expr is not an Op. + * \return the const reference to the content value. + */ + inline ValueType get(const RelayExpr& expr, ValueType def_value) const; + + private: + friend class Op; + // constructor + explicit OpMap(const GenericOpMap& map) : map_(map) {} + /*! \brief The internal map field */ + const GenericOpMap& map_; +}; + +// internal macros to make +#define TVM_OP_REGISTER_VAR_DEF \ + static DMLC_ATTRIBUTE_UNUSED ::tvm::OpRegistry& __make_##Op + +/*! + * \def TVM_REGISTER_OP + * \brief Register a new operator, or set attribute of the corresponding op. + * + * \param OpName The name of registry + * + * \code + * + * TVM_REGISTER_OP("add") + * .describe("add two inputs together") + * .set_num_inputs(2) + * .set_attr("gpu_kernel", AddKernel); + * + * \endcode + */ +#define TVM_REGISTER_OP(OpName) \ + TVM_STR_CONCAT(TVM_OP_REGISTER_VAR_DEF, __COUNTER__) = \ + ::tvm::OpRegistry::Registry() \ + ->__REGISTER_OR_GET__(OpName) \ + .set_name() + +// implementations +inline const OpNode* Op::operator->() const { + return static_cast(get()); +} + +template +inline OpMap Op::GetAttr(const std::string& key) { + return OpMap(Op::GetGenericAttr(key)); +} + +inline bool Op::HasAttr(const std::string& key) { + return Op::HasGenericAttr(key); +} + +inline OpNode* OpRegistry::get() { + return const_cast(op_.operator->()); +} + +inline OpRegistry& OpRegistry::describe( + const std::string& descr) { // NOLINT(*) + get()->description = descr; + return *this; +} + +inline OpRegistry& OpRegistry::add_argument(const std::string& name, + const std::string& type, + const std::string& description) { + auto n = make_object(); + n->name = name; + n->type_info = type; + n->description = description; + get()->arguments.push_back(AttrFieldInfo(n)); + return *this; +} + +inline OpRegistry& OpRegistry::add_type_rel( + const std::string& rel_name, + runtime::TypedPackedFunc&, + int, + const Attrs&, + const TypeReporter&)> type_rel_func) { + auto func_name = std::string("tvm.relay.type_relation.") + rel_name; + TypeRelationFn env_type_rel_func; + + if (runtime::Registry::Get(func_name)) { + auto env_func = EnvFunc::Get(func_name); + env_type_rel_func = env_func; + } else { + runtime::Registry::Register(func_name) + .set_body(type_rel_func.packed()); + auto env_func = EnvFunc::Get(func_name); + env_type_rel_func = env_func; + } + + Array type_params; + Array arg_types; + + // Add inputs. + std::string input_name_prefix = "in"; + for (int i = 0; i < get()->num_inputs; i++) { + auto name = input_name_prefix + std::to_string(i); + auto param = TypeVarNode::make(name, TypeKind::kType); + type_params.push_back(param); + arg_types.push_back(param); + } + + Array ty_call_args = arg_types; + + // Add output type. + auto out_param = TypeVarNode::make("out", TypeKind::kType); + type_params.push_back(out_param); + // this will trigger copy on write. + ty_call_args.push_back(out_param); + + // The attributes of primitive op is nullptr + // + // The attributes of primitive operator can vary at the call site. + // The type of sum is also dependent on Attrs being passed. + // So puting nullptr in the Attrs means that the operator is polymorphic on Attrs. + // + // A common example is sum(x, axis), where the choice of axis + // can affect the type of the function. + TypeConstraint type_rel = + TypeRelationNode::make(env_type_rel_func, + ty_call_args, + arg_types.size(), + Attrs()); + + auto func_type = + FuncTypeNode::make(arg_types, out_param, type_params, {type_rel}); + + get()->op_type = func_type; + + return *this; +} + +inline OpRegistry& OpRegistry::set_num_inputs(int32_t n) { // NOLINT(*) + get()->num_inputs = n; + return *this; +} + +template +inline OpRegistry& OpRegistry::set_attrs_type() { // NOLINT(*) + get()->attrs_type_key = AttrsType::_type_key; + get()->attrs_type_index = AttrsType::RuntimeTypeIndex(); + return *this; +} + +inline OpRegistry& OpRegistry::set_support_level(int32_t n) { // NOLINT(*) + get()->support_level = n; + return *this; +} + +template +inline OpRegistry& OpRegistry::set_attr( // NOLINT(*) + const std::string& attr_name, const ValueType& value, int plevel) { + CHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0"; + TVMRetValue rv; + rv = value; + UpdateAttr(attr_name, rv, plevel); + return *this; +} + +// member functions of OpMap +inline int GenericOpMap::count(const Op& op) const { + if (op.defined()) { + const uint32_t idx = op->index_; + return idx < data_.size() ? (data_[idx].second != 0) : 0; + } else { + return 0; + } +} + +inline const TVMRetValue& GenericOpMap::operator[](const Op& op) const { + CHECK(op.defined()); + const uint32_t idx = op->index_; + CHECK(idx < data_.size() && data_[idx].second != 0) + << "Attribute " << attr_name_ << " has not been registered for Operator " + << op->name; + return data_[idx].first; +} + +template +inline ValueType GenericOpMap::get(const Op& op, ValueType value) const { + CHECK(op.defined()); + const uint32_t idx = op->index_; + if (idx < data_.size() && data_[idx].second != 0) { + return data_[idx].first; + } else { + return value; + } +} + +template +inline ValueType GenericOpMap::get(const RelayExpr& expr, ValueType value) const { + CHECK(expr.defined()); + if (const OpNode* op = expr.as()) { + const uint32_t idx = op->index_; + if (idx < data_.size() && data_[idx].second != 0) { + return data_[idx].first; + } else { + return value; + } + } else { + return value; + } +} + +template +inline int OpMap::count(const Op& op) const { + return map_.count(op); +} + +template +inline ValueType OpMap::operator[](const Op& op) const { + return map_[op]; +} + +template +inline ValueType OpMap::get(const Op& op, + ValueType def_value) const { + return map_.get(op, def_value); +} + +template +inline ValueType OpMap::get(const RelayExpr& expr, + ValueType def_value) const { + return map_.get(expr, def_value); +} + +/*! + * \brief Check that an expression is a "primitive operator". + * + * Will return true if the expression is an operator which + * matches the form of primitive operators registered directly + * by the Relay codebase. + * + * That is the arguments are all type variables, and there is a single + * type relation applied to the input and output types. + * + * \param expr An expression. + * \return Whether the expression is primitive op. + */ +inline bool IsPrimitiveOp(const RelayExpr& expr) { + const auto* op = expr.as(); + return op != nullptr && op->IsPrimitiveOp(); +} + +} // namespace tvm +#endif // TVM_IR_OP_H_ diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index ab2003e9ec5e4..ddabd0fbbcc0b 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -51,6 +51,7 @@ #include #include +#include #include #include #include diff --git a/include/tvm/ir/type_relation.h b/include/tvm/ir/type_relation.h new file mode 100644 index 0000000000000..71d1d9eb4520e --- /dev/null +++ b/include/tvm/ir/type_relation.h @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ir/type_relation.h + * \brief Type relation function for type checking. + */ +#ifndef TVM_IR_TYPE_RELATION_H_ +#define TVM_IR_TYPE_RELATION_H_ + +#include +#include + +namespace tvm { + +// TODO(tqchen): remove after migrate Module to ir. +namespace relay { +struct Module; +} + +/*! + * \brief reporter that reports back to the + * type resolution information. + */ +class TypeReporterNode : public Object { + public: + /*! + * \brief Create a type equality constraint. + * + * The "assign direction" acts as a hint to the solver + * showing that it is more likely to resolve dst by src. + * But it is possible for the solver to resolve src by dst as well. + */ + TVM_DLL virtual void Assign(const Type& dst, const Type& src) = 0; + + /*! + * \brief assert shape expression comparison. + * \note Use assert only if any of the condition input is symbolic. + * \param cond The condition of operation. + * \return false if assertation can be proven to have failed + * true if solver can still proceed. + */ + TVM_DLL virtual bool Assert(const PrimExpr& cond)= 0; + /*! + * \brief assert shape expression equals each other. + * \param lhs The left operand. + * \param rhs The right operand. + * \return false if assertation can be proven to have failed + * true if solver can still proceed. + */ + TVM_DLL virtual bool AssertEQ(const PrimExpr& lhs, const PrimExpr& rhs) = 0; + + /*! + * \brief Set the location at which to report unification errors. + * \param ref The program node to report the error. + */ + TVM_DLL virtual void SetLocation(const ObjectRef& ref) = 0; + + /*! + * \brief Retrieve the current global module. + * \return The global module. + */ + TVM_DLL virtual relay::Module GetModule() = 0; + + // solver is not serializable. + void VisitAttrs(tvm::AttrVisitor* v) {} + + static constexpr const char* _type_key = "relay.TypeReporter"; + TVM_DECLARE_FINAL_OBJECT_INFO(TypeReporterNode, Object); +}; + +/*! + * \brief Container class of TypeReporter. + * \sa TypeReporterNode + */ +class TypeReporter : public ObjectRef { + public: + TypeReporter() {} + explicit TypeReporter(::tvm::ObjectPtr<::tvm::Object> n) : ObjectRef(n) { + } + TypeReporterNode* operator->() const { + return const_cast( + static_cast(get())); + } + using ContainerType = TypeReporterNode; +}; + +/*! + * \brief User defined type constraint function. + * + * If the input type information can be used to fully decide + * the IncompleteTypes, then the function should call + * reporter.Assign to report the new types, and return true. + * Otherwise, the function should return false. + * + * \param args The arguments to the relation. + * The types are stored in the form of + * [input_type_0, input_type_1, ... input_type_n, + * output_type_0, output_type_1, ... output_type_m] + * + * \param num_inputs Number of input types in the args. + * \param attrs The additional attributes of the operator. + * \param reporter The reporter to report solution to. + * \return false if This relation cannot be resolved. + * true if this relation has been resolved. + */ +using TypeRelationFn = + TypedEnvFunc& args, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter)>; + +/*! + * \brief User defined type relation, is an input-output relation on types. + */ +class TypeRelation; +/*! + * \brief TypeRelation container. + * \note This node is not directly serializable. + * The type function need to be lookedup in the module. + */ +class TypeRelationNode : public TypeConstraintNode { + public: + /*! + * \brief The function on input and output variables which + * this is not directly serializable, + * need to be looked-up in the module. + */ + TypeRelationFn func; + /*! \brief The type arguments to the type function. */ + tvm::Array args; + /*! \brief Number of inputs arguments */ + int num_inputs; + /*! \brief Attributes to the relation function */ + Attrs attrs; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("func", &func); + v->Visit("args", &args); + v->Visit("num_inputs", &num_inputs); + v->Visit("attrs", &attrs); + v->Visit("span", &span); + } + + TVM_DLL static TypeRelation make(TypeRelationFn func, + Array args, + int num_args, + Attrs attrs); + + static constexpr const char* _type_key = "relay.TypeRelation"; + TVM_DECLARE_FINAL_OBJECT_INFO(TypeRelationNode, TypeConstraintNode); +}; + +class TypeRelation : public TypeConstraint { + public: + TVM_DEFINE_OBJECT_REF_METHODS(TypeRelation, TypeConstraint, TypeRelationNode); +}; +} // namespace tvm +#endif // TVM_IR_TYPE_RELATION_H_ diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index 6bd0a359cc69e..fa47da226dffc 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -19,595 +19,23 @@ /*! * \file tvm/relay/op.h - * \brief Primitive operator definition. + * \brief Primitive operators(builtin intrinsics). */ #ifndef TVM_RELAY_OP_H_ #define TVM_RELAY_OP_H_ -#include - -#include -#include -#include -#include -#include -#include - -#include "base.h" -#include "expr.h" -#include "type.h" +#include +#include +#include namespace tvm { namespace relay { -// forward declare name. -template -class OpMap; -class GenericOpMap; -class OpRegistry; - -/*! - * \brief Node container of operator structure. - */ -class OpNode : public relay::ExprNode { - public: - /*! \brief name of the operator */ - std::string name; - /*! \brief the type of the operator */ - mutable FuncType op_type; - /*! - * \brief detailed description of the operator - * This can be used to generate docstring automatically for the operator. - */ - std::string description; - /* \brief Information of input arguments to the operator */ - Array arguments; - /*! - * \brief The type key of the attribute field - * This can be empty, in which case it defaults to anything. - */ - std::string attrs_type_key; - /*! - * \brief attribute type index, - * this field varies in each run and is not exposed to frontend. - */ - uint32_t attrs_type_index{0}; - /*! - * \brief number of input arguments to the operator, - * -1 means it is variable length - */ - int32_t num_inputs = -1; - /*! - * \brief support level of the operator, - * The lower the more priority it contains. - * This is in analogies to BLAS levels. - */ - int32_t support_level = 10; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("name", &name); - v->Visit("op_type", &op_type); - v->Visit("description", &description); - v->Visit("arguments", &arguments); - v->Visit("attrs_type_key", &attrs_type_key); - v->Visit("num_inputs", &num_inputs); - v->Visit("support_level", &support_level); - } - - /*! - * \brief Check that if current op is a "primtive operator". - * That is the arguments are all type variables, and there is a single - * type relation applied to the input and output types. - */ - bool IsPrimitiveOp() const { - if (is_primitive_ != -1) return is_primitive_ != 0; - is_primitive_ = this->IsPrimitiveOp_() ? 1 : 0; - return is_primitive_ != 0; - } - - static constexpr const char* _type_key = "relay.Op"; - TVM_DECLARE_FINAL_OBJECT_INFO(OpNode, ExprNode); - - private: - // friend class - friend class GenericOpMap; - friend class OpRegistry; - friend bool IsPrimitiveOp(const Expr&); - // Program internal unique index of operator. - // Used to help index the program. - uint32_t index_{0}; - // whether this is a primitive op. -1 means unknown. - mutable int is_primitive_{-1}; - // Internal function to compute if it is primitive op - bool IsPrimitiveOp_() const { - const auto& fn_ty = this->op_type; - if (fn_ty->type_constraints.size() != 1) return false; - const TypeRelationNode* rel = fn_ty->type_constraints[0].as(); - if (rel == nullptr) return false; - // validate if the type parameter matches up - for (size_t i = 0; i < fn_ty->type_params.size(); ++i) { - if (!fn_ty->type_params[i].same_as(rel->args[i])) return false; - } - return true; - } -}; - -/*! - * \brief Operator reference class. - */ -class Op : public relay::Expr { - public: - /*! \brief default constructor */ - Op() {} - /*! \brief constructor from node pointer */ - explicit Op(ObjectPtr n) : RelayExpr(n) {} - /*! - * \brief access the internal node container - * \return the pointer to the internal node container - */ - inline const OpNode* operator->() const; - /*! - * \brief Get additional registered attribute about operators. - * If nothing has been registered, an empty OpMap will be returned. - * \param attr_name The name of the attribute. - * \return An OpMap of specified attr_name. - * \tparam ValueType The type of the attribute. - */ - template - inline static OpMap GetAttr(const std::string& attr_name); - /*! - * \brief Checks if an attr is present in the registry. - * \param attr_name The name of the attribute. - * \return bool True if the attr is present. - */ - inline static bool HasAttr(const std::string& attr_name); - /*! - * \brief Get an Op for a given operator name. - * Will raise an error if the op has not been registered. - * \param op_name Name of the operator. - * \return Pointer to a Op, valid throughout program lifetime. - */ - TVM_DLL static const Op& Get(const std::string& op_name); - - /*! \brief specify container node */ - using ContainerType = OpNode; - - private: - /*! - * \brief Get generic attrmap given attr name - * \param key The attribute key - * \return reference to GenericOpMap - */ - TVM_DLL static const GenericOpMap& GetGenericAttr(const std::string& key); - /*! - * \brief Checks if the key is present in the registry - * \param key The attribute key - * \return bool True if the key is present - */ - TVM_DLL static const bool HasGenericAttr(const std::string& key); -}; - -/*! \brief Helper structure to register operators */ -class OpRegistry { - public: - /*! \return the operator */ - const Op& op() const { return op_; } - /*! - * \brief setter function during registration - * Set the description of operator - * \param descr the description string. - * \return reference to self. - */ - inline OpRegistry& describe(const std::string& descr); // NOLINT(*) - /*! - * \brief Add argument information to the function. - * \param name Name of the argument. - * \param type Type of the argument. - * \param description Description of the argument. - * \return reference to self. - */ - inline OpRegistry& add_argument(const std::string& name, - const std::string& type, - const std::string& description); - /*! - * \brief Attach the type function corresponding to the return type. - * \param rel_name The type relation name to register. - * \param type_rel_func The backing relation function which can solve an arbitrary - * relation on variables. - * \return reference to self. - */ - inline OpRegistry& add_type_rel( - const std::string& rel_name, - runtime::TypedPackedFunc&, - int, - const Attrs&, - const TypeReporter&)> type_rel_func); - /*! - * \brief Set the the attrs type key and index to be AttrsType. - * \tparam AttrsType the attribute type to b set. - * \return reference to self. - */ - template - inline OpRegistry& set_attrs_type(); - /*! - * \brief Set the num_inputs - * \param n The number of inputs to be set. - * \return reference to self. - */ - inline OpRegistry& set_num_inputs(int32_t n); // NOLINT(*) - /*! - * \brief Set the support level of op. - * \param level The support level. - * \return reference to self. - */ - inline OpRegistry& set_support_level(int32_t level); // NOLINT(*) - /*! - * \brief Register additional attributes to operator. - * \param attr_name The name of the attribute. - * \param value The value to be set. - * \param plevel The priority level of this set, - * an higher priority level attribute - * will replace lower priority level attribute. - * Must be bigger than 0. - * - * Cannot set with same plevel twice in the code. - * - * \tparam ValueType The type of the value to be set. - */ - template - inline OpRegistry& set_attr(const std::string& attr_name, // NOLINT(*) - const ValueType& value, int plevel = 10); - - /*! - * \brief Resets an attr of the registry. - * \param attr_name The name of the attribute. - */ - inline void reset_attr(const std::string& attr_name); - - // set the name of the op to be the same as registry - inline OpRegistry& set_name() { // NOLINT(*) - if (get()->name.length() == 0) { - get()->name = name; - } - return *this; - } - /*! \return The global single registry */ - TVM_DLL static ::dmlc::Registry* Registry(); - - private: - friend class ::dmlc::Registry; - // the name - std::string name; - /*! \brief The operator */ - Op op_; - // private constructor - TVM_DLL OpRegistry(); - // return internal pointer to op. - inline OpNode* get(); - // update the attribute OpMap - TVM_DLL void UpdateAttr(const std::string& key, TVMRetValue value, - int plevel); -}; - -/*! - * \brief Generic map to store additional information of Op. - */ -class GenericOpMap { - public: - /*! - * \brief Check if the map has op as key. - * \param op The key to the map - * \return 1 if op is contained in map, 0 otherwise. - */ - inline int count(const Op& op) const; - /*! - * \brief get the corresponding value element at op - * \param op The key to the map - * \return the const reference to the content value. - */ - inline const TVMRetValue& operator[](const Op& op) const; - /*! - * \brief get the corresponding value element at op with default value. - * \param op The key to the map - * \param def_value The default value when the key does not exist. - * \return the const reference to the content value. - * \tparam ValueType The content value type. - */ - template - inline ValueType get(const Op& op, ValueType def_value) const; - /*! - * \brief get the corresponding value element at op with default value. - * \param expr The key to the map - * \param def_value The default value when the key does not exist - * or if expr is not an Op. - * \return the const reference to the content value. - * \tparam ValueType The content value type. - */ - template - inline ValueType get(const Expr& expr, ValueType def_value) const; - - private: - friend class OpRegistry; - // the attribute field. - std::string attr_name_; - // internal data - std::vector > data_; - // The value - GenericOpMap() = default; -}; - -/*! - * \brief Map used to store meta-information about Op. - * \tparam ValueType The type of the value stored in map. - */ -template -class OpMap { - public: - /*! - * \brief Check if the map has op as key. - * \param op The key to the map - * \return 1 if op is contained in map, 0 otherwise. - */ - inline int count(const Op& op) const; - /*! - * \brief get the corresponding value element at op - * \param op The key to the map - * \return the const reference to the content value. - */ - inline ValueType operator[](const Op& op) const; - /*! - * \brief get the corresponding value element at op with default value. - * \param op The key to the map - * \param def_value The default value when the key does not exist. - * \return the const reference to the content value. - */ - inline ValueType get(const Op& op, ValueType def_value) const; - /*! - * \brief get the corresponding value element at op with default value. - * \param expr The key to the map - * \param def_value The default value when the key does not exist - * or if expr is not an Op. - * \return the const reference to the content value. - */ - inline ValueType get(const Expr& expr, ValueType def_value) const; - - private: - friend class Op; - // constructor - explicit OpMap(const GenericOpMap& map) : map_(map) {} - /*! \brief The internal map field */ - const GenericOpMap& map_; -}; - -// internal macros to make -#define RELAY_REGISTER_VAR_DEF \ - static DMLC_ATTRIBUTE_UNUSED ::tvm::relay::OpRegistry& __make_##RelayOp - -/*! - * \def RELAY_REGISTER_OP - * \brief Register a new operator, or set attribute of the corresponding op. - * - * \param OpName The name of registry - * - * \code - * - * RELAY_REGISTER_OP("add") - * .describe("add two inputs together") - * .set_num_inputs(2) - * .set_attr("gpu_kernel", AddKernel); - * - * \endcode - */ -#define RELAY_REGISTER_OP(OpName) \ - DMLC_STR_CONCAT(RELAY_REGISTER_VAR_DEF, __COUNTER__) = \ - ::tvm::relay::OpRegistry::Registry() \ - ->__REGISTER_OR_GET__(OpName) \ - .set_name() - -// implementations -inline const OpNode* Op::operator->() const { - return static_cast(get()); -} - -template -inline OpMap Op::GetAttr(const std::string& key) { - return OpMap(Op::GetGenericAttr(key)); -} - -inline bool Op::HasAttr(const std::string& key) { - return Op::HasGenericAttr(key); -} - -inline OpNode* OpRegistry::get() { - return const_cast(op_.operator->()); -} - -inline OpRegistry& OpRegistry::describe( - const std::string& descr) { // NOLINT(*) - get()->description = descr; - return *this; -} - -inline OpRegistry& OpRegistry::add_argument(const std::string& name, - const std::string& type, - const std::string& description) { - auto n = make_object(); - n->name = name; - n->type_info = type; - n->description = description; - get()->arguments.push_back(AttrFieldInfo(n)); - return *this; -} - -inline OpRegistry& OpRegistry::add_type_rel( - const std::string& rel_name, - runtime::TypedPackedFunc&, - int, - const Attrs&, - const TypeReporter&)> type_rel_func) { - auto func_name = std::string("tvm.relay.type_relation.") + rel_name; - TypeRelationFn env_type_rel_func; - - if (runtime::Registry::Get(func_name)) { - auto env_func = EnvFunc::Get(func_name); - env_type_rel_func = env_func; - } else { - runtime::Registry::Register(func_name) - .set_body(type_rel_func.packed()); - auto env_func = EnvFunc::Get(func_name); - env_type_rel_func = env_func; - } - - Array type_params; - Array arg_types; - - // Add inputs. - std::string input_name_prefix = "in"; - for (int i = 0; i < get()->num_inputs; i++) { - auto name = input_name_prefix + std::to_string(i); - auto param = TypeVarNode::make(name, Kind::kType); - type_params.push_back(param); - arg_types.push_back(param); - } - - Array ty_call_args = arg_types; - - // Add output type. - auto out_param = TypeVarNode::make("out", Kind::kType); - type_params.push_back(out_param); - // this will trigger copy on write. - ty_call_args.push_back(out_param); - - // The attributes of primitive op is nullptr - // - // The attributes of primitive operator can vary at the call site. - // The type of sum is also dependent on Attrs being passed. - // So puting nullptr in the Attrs means that the operator is polymorphic on Attrs. - // - // A common example is sum(x, axis), where the choice of axis - // can affect the type of the function. - TypeConstraint type_rel = - TypeRelationNode::make(env_type_rel_func, - ty_call_args, - arg_types.size(), - Attrs()); - - auto func_type = - FuncTypeNode::make(arg_types, out_param, type_params, {type_rel}); - - get()->op_type = func_type; - - return *this; -} +using Op = tvm::Op; +using OpNode = tvm::OpNode; -inline OpRegistry& OpRegistry::set_num_inputs(int32_t n) { // NOLINT(*) - get()->num_inputs = n; - return *this; -} - -template -inline OpRegistry& OpRegistry::set_attrs_type() { // NOLINT(*) - get()->attrs_type_key = AttrsType::_type_key; - get()->attrs_type_index = AttrsType::RuntimeTypeIndex(); - return *this; -} - -inline OpRegistry& OpRegistry::set_support_level(int32_t n) { // NOLINT(*) - get()->support_level = n; - return *this; -} - -template -inline OpRegistry& OpRegistry::set_attr( // NOLINT(*) - const std::string& attr_name, const ValueType& value, int plevel) { - CHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0"; - TVMRetValue rv; - rv = value; - UpdateAttr(attr_name, rv, plevel); - return *this; -} - -// member functions of OpMap -inline int GenericOpMap::count(const Op& op) const { - if (op.defined()) { - const uint32_t idx = op->index_; - return idx < data_.size() ? (data_[idx].second != 0) : 0; - } else { - return 0; - } -} - -inline const TVMRetValue& GenericOpMap::operator[](const Op& op) const { - CHECK(op.defined()); - const uint32_t idx = op->index_; - CHECK(idx < data_.size() && data_[idx].second != 0) - << "Attribute " << attr_name_ << " has not been registered for Operator " - << op->name; - return data_[idx].first; -} - -template -inline ValueType GenericOpMap::get(const Op& op, ValueType value) const { - CHECK(op.defined()); - const uint32_t idx = op->index_; - if (idx < data_.size() && data_[idx].second != 0) { - return data_[idx].first; - } else { - return value; - } -} - -template -inline ValueType GenericOpMap::get(const Expr& expr, ValueType value) const { - CHECK(expr.defined()); - if (const OpNode* op = expr.as()) { - const uint32_t idx = op->index_; - if (idx < data_.size() && data_[idx].second != 0) { - return data_[idx].first; - } else { - return value; - } - } else { - return value; - } -} - -template -inline int OpMap::count(const Op& op) const { - return map_.count(op); -} - -template -inline ValueType OpMap::operator[](const Op& op) const { - return map_[op]; -} - -template -inline ValueType OpMap::get(const Op& op, - ValueType def_value) const { - return map_.get(op, def_value); -} - -template -inline ValueType OpMap::get(const Expr& expr, - ValueType def_value) const { - return map_.get(expr, def_value); -} - -/*! - * \brief Check that an expression is a "primitive operator". - * - * Will return true if the expression is an operator which - * matches the form of primitive operators registered directly - * by the Relay codebase. - * - * That is the arguments are all type variables, and there is a single - * type relation applied to the input and output types. - */ -inline bool IsPrimitiveOp(const Expr& expr) { - const auto* op = expr.as(); - return op != nullptr && op->IsPrimitiveOp(); -} +#define RELAY_REGISTER_OP(OpName) \ + TVM_REGISTER_OP(OpName) } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 31e85f9132383..7748bd108dfb0 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -24,8 +24,8 @@ #ifndef TVM_RELAY_TYPE_H_ #define TVM_RELAY_TYPE_H_ - #include +#include #include #include #include @@ -51,6 +51,11 @@ using TypeConstraint = tvm::TypeConstraint; using TypeConstraintNode = tvm::TypeConstraintNode; using FuncType = tvm::FuncType; using FuncTypeNode = tvm::FuncTypeNode; +using TypeRelation = tvm::TypeRelation; +using TypeRelationNode = tvm::TypeRelationNode; +using TypeRelationFn = tvm::TypeRelationFn; +using TypeReporter = tvm::TypeReporter; +using TypeReporterNode = tvm::TypeReporterNode; /*! * \brief Base of all Tensor types @@ -235,146 +240,6 @@ class RefType : public Type { TVM_DEFINE_OBJECT_REF_METHODS(RefType, Type, RefTypeNode); }; -class TypeReporter; - -/*! - * \brief reporter that reports back to the - * type resolution information. - */ -class TypeReporterNode : public Object { - public: - /*! - * \brief Create a type equality constraint. - * - * The "assign direction" acts as a hint to the solver - * showing that it is more likely to resolve dst by src. - * But it is possible for the solver to resolve src by dst as well. - */ - TVM_DLL virtual void Assign(const Type& dst, const Type& src) = 0; - - /*! - * \brief assert shape expression comparison. - * \note Use assert only if any of the condition input is symbolic. - * \param cond The condition of operation. - * \return false if assertation can be proven to have failed - * true if solver can still proceed. - */ - TVM_DLL virtual bool Assert(const IndexExpr& cond)= 0; - /*! - * \brief assert shape expression equals each other. - * \param lhs The left operand. - * \param rhs The right operand. - * \return false if assertation can be proven to have failed - * true if solver can still proceed. - */ - TVM_DLL virtual bool AssertEQ(const IndexExpr& lhs, const IndexExpr& rhs) = 0; - - /*! - * \brief Set the location at which to report unification errors. - * \param ref The program node to report the error. - */ - TVM_DLL virtual void SetLocation(const ObjectRef& ref) = 0; - - /*! - * \brief Retrieve the current global module. - * \return The global module. - */ - TVM_DLL virtual Module GetModule() = 0; - - // solver is not serializable. - void VisitAttrs(tvm::AttrVisitor* v) {} - - static constexpr const char* _type_key = "relay.TypeReporter"; - TVM_DECLARE_FINAL_OBJECT_INFO(TypeReporterNode, Object); -}; - -/*! - * \brief Container class of TypeReporter. - * \sa TypeReporterNode - */ -class TypeReporter : public ObjectRef { - public: - TypeReporter() {} - explicit TypeReporter(::tvm::ObjectPtr<::tvm::Object> n) : ObjectRef(n) { - } - TypeReporterNode* operator->() const { - return const_cast( - static_cast(get())); - } - using ContainerType = TypeReporterNode; -}; - -/*! - * \brief User defined type constraint function. - * - * If the input type information can be used to fully decide - * the IncompleteTypes, then the function should call - * reporter.Assign to report the new types, and return true. - * Otherwise, the function should return false. - * - * \param args The arguments to the relation. - * The types are stored in the form of - * [input_type_0, input_type_1, ... input_type_n, - * output_type_0, output_type_1, ... output_type_m] - * - * \param num_inputs Number of input types in the args. - * \param attrs The additional attributes of the operator. - * \param reporter The reporter to report solution to. - * \return false if This relation cannot be resolved. - * true if this relation has been resolved. - */ -using TypeRelationFn = - TypedEnvFunc& args, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter)>; - -/*! - * \brief User defined type relation, is an input-output relation on types. - */ -class TypeRelation; -/*! - * \brief TypeRelation container. - * \note This node is not directly serializable. - * The type function need to be lookedup in the module. - */ -class TypeRelationNode : public TypeConstraintNode { - public: - /*! - * \brief The function on input and output variables which - * this is not directly serializable, - * need to be looked-up in the module. - */ - TypeRelationFn func; - /*! \brief The type arguments to the type function. */ - tvm::Array args; - /*! \brief Number of inputs arguments */ - int num_inputs; - /*! \brief Attributes to the relation function */ - Attrs attrs; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("func", &func); - v->Visit("args", &args); - v->Visit("num_inputs", &num_inputs); - v->Visit("attrs", &attrs); - v->Visit("span", &span); - } - - TVM_DLL static TypeRelation make(TypeRelationFn func, - Array args, - int num_args, - Attrs attrs); - - static constexpr const char* _type_key = "relay.TypeRelation"; - TVM_DECLARE_FINAL_OBJECT_INFO(TypeRelationNode, TypeConstraintNode); -}; - -class TypeRelation : public TypeConstraint { - public: - TVM_DEFINE_OBJECT_REF_METHODS(TypeRelation, TypeConstraint, TypeRelationNode); -}; - // The following fields contains advanced typing // Only keep the class name and reserved for future usage. class GenericTensorType; diff --git a/src/relay/ir/op.cc b/src/ir/op.cc similarity index 96% rename from src/relay/ir/op.cc rename to src/ir/op.cc index b888ecbd92410..0ed2f3dcb015a 100644 --- a/src/relay/ir/op.cc +++ b/src/ir/op.cc @@ -18,11 +18,11 @@ */ /*! - * \file src/tvm/relay/op.cc - * \brief Resolve incomplete types to complete types. + * \file src/tvm/ir/op.cc + * \brief Primitive operators and intrinsics. */ -#include -#include +#include +#include #include #include @@ -31,11 +31,10 @@ namespace dmlc { // enable registry -DMLC_REGISTRY_ENABLE(::tvm::relay::OpRegistry); +DMLC_REGISTRY_ENABLE(::tvm::OpRegistry); } // namespace dmlc namespace tvm { -namespace relay { ::dmlc::Registry* OpRegistry::Registry() { return ::dmlc::Registry::Get(); @@ -230,5 +229,4 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->stream << "Op(" << node->name << ")"; }); -} // namespace relay } // namespace tvm diff --git a/src/ir/type_relation.cc b/src/ir/type_relation.cc new file mode 100644 index 0000000000000..cc5ceef7dd3eb --- /dev/null +++ b/src/ir/type_relation.cc @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/tvm/ir/type_relation.cc + * \brief Type relation + */ +#include +#include +#include +#include + +namespace tvm { +TypeRelation TypeRelationNode::make(TypeRelationFn func, + Array args, + int num_inputs, + Attrs attrs) { + ObjectPtr n = make_object(); + n->func = std::move(func); + n->args = std::move(args); + n->num_inputs = num_inputs; + n->attrs = std::move(attrs); + return TypeRelation(n); +} + +TVM_REGISTER_NODE_TYPE(TypeRelationNode); + +TVM_REGISTER_GLOBAL("relay._make.TypeRelation") +.set_body_typed(TypeRelationNode::make); + +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TypeRelationNode(" + << node->func->name + << ", " << node->args << ")"; +}); +} // namespace tvm diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index 4ae2ee58ae51a..099b8013c895a 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -101,30 +101,6 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")"; }); -TypeRelation TypeRelationNode::make(TypeRelationFn func, - Array args, - int num_inputs, - Attrs attrs) { - ObjectPtr n = make_object(); - n->func = std::move(func); - n->args = std::move(args); - n->num_inputs = num_inputs; - n->attrs = std::move(attrs); - return TypeRelation(n); -} - -TVM_REGISTER_NODE_TYPE(TypeRelationNode); - -TVM_REGISTER_GLOBAL("relay._make.TypeRelation") -.set_body_typed(TypeRelationNode::make); - -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TypeRelationNode(" - << node->func->name - << ", " << node->args << ")"; -}); TupleType TupleTypeNode::make(Array fields) { ObjectPtr n = make_object();