From 997a14eda9aec3b343e742e55c3018f9dc23d8c3 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Fri, 27 Mar 2020 22:21:00 -0700 Subject: [PATCH] [NODE][IR] Introduce StructuralEqual Infra for the unified IR. (#5154) * [NODE][IR] Introduce StructuralEqual Infra for the Unified IR. This PR introduces a new way to handle structural equality for both TIR and relay nodes in an extensive way. - Each object can now register an optional SEqualReduce function, which describes how to reduce its structural equality to another instance into equality of the children. - Optionally, the object can choose to allow remapping of vars(e.g. function parameters) by calling DefEqual - We implemented a non-recursive structural equality checker that recursively traverses the objects and does the structural equality checking. This PR also fixes a few potential problems in previous relay's AlphaEqual. - In particular, the new structural equality relation will be communicative. - It is can be dangerous to use same_as relation to quickly check equality, demonstrated by the following case. (%x, %y) are shared vars between two functions. - function0: fn (%x, %y) { %x + %y } - function1: fn (%y, %x) { %x + %y } The new structural equal is intented to supersede AlphaEqual and AttrsEqual. Follow-up PRs should be performed to redirect the existing usages, and removes the corresponding implementation. * Update the rule to distinguish between graph node and non-graph nodes. * Refactor the test cases to use structural equal. * address comments * Mark more relay::Expr as graph node, fix a testcase issue(was bug that was not caught by previous alpha equal) * Remove unrelated comment * Fix file comment * Address review comment * Relax condition to fit flaky case --- include/tvm/arith/analyzer.h | 8 + include/tvm/arith/int_set.h | 1 + include/tvm/ir/adt.h | 15 + include/tvm/ir/attrs.h | 41 +++ include/tvm/ir/env_func.h | 5 + include/tvm/ir/expr.h | 21 ++ include/tvm/ir/module.h | 3 + include/tvm/ir/op.h | 5 + include/tvm/ir/span.h | 11 + include/tvm/ir/tensor_type.h | 6 + include/tvm/ir/transform.h | 2 + include/tvm/ir/type.h | 43 +++ include/tvm/ir/type_relation.h | 14 + include/tvm/node/container.h | 20 +- include/tvm/node/node.h | 2 + include/tvm/node/reflection.h | 148 ++++++++- include/tvm/node/structural_equal.h | 225 ++++++++++++++ include/tvm/relay/adt.h | 32 ++ include/tvm/relay/expr.h | 71 +++++ include/tvm/relay/function.h | 11 + include/tvm/runtime/ndarray.h | 22 ++ include/tvm/runtime/object.h | 4 + include/tvm/tir/buffer.h | 15 + include/tvm/tir/expr.h | 153 +++++++++- include/tvm/tir/function.h | 11 + include/tvm/tir/stmt.h | 103 +++++++ python/tvm/ir/__init__.py | 1 + python/tvm/ir/base.py | 73 +++++ src/ir/attr_functor.h | 4 +- src/ir/expr.cc | 8 +- src/ir/module.cc | 19 +- src/node/container.cc | 140 +++++++++ src/node/reflection.cc | 2 +- src/node/structural_equal.cc | 241 +++++++++++++++ src/tir/ir/expr.cc | 18 +- .../frontend/tensorflow/test_forward.py | 2 +- tests/python/relay/test_ir_parser.py | 109 +++---- ...a_equal.py => test_ir_structural_equal.py} | 280 +++++++++--------- .../relay/test_pass_dead_code_elimination.py | 14 +- tests/python/relay/test_pass_partial_eval.py | 26 +- tests/python/relay/test_pass_qnn_legalize.py | 8 +- .../relay/test_pass_to_a_normal_form.py | 4 +- tests/python/relay/test_pass_to_cps.py | 2 +- tests/python/relay/test_type_infer.py | 3 +- tests/python/unittest/test_node_reflection.py | 4 +- .../unittest/test_tir_structural_equal.py | 102 +++++++ 46 files changed, 1781 insertions(+), 271 deletions(-) create mode 100644 include/tvm/node/structural_equal.h create mode 100644 src/node/structural_equal.cc rename tests/python/relay/{test_pass_alpha_equal.py => test_ir_structural_equal.py} (78%) create mode 100644 tests/python/unittest/test_tir_structural_equal.py diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 31f2216b7e2b..e7f5ede22995 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -68,6 +68,10 @@ class ConstIntBoundNode : public Object { v->Visit("max_value", &max_value); } + bool SEqualReduce(const ConstIntBoundNode* other, SEqualReducer equal) const { + return equal(min_value, other->min_value) && equal(max_value, other->max_value); + } + /*! \brief Number to represent +inf */ static const constexpr int64_t kPosInf = std::numeric_limits::max(); /*! @@ -170,6 +174,10 @@ class ModularSetNode : public Object { v->Visit("base", &base); } + bool SEqualReduce(const ModularSetNode* other, SEqualReducer equal) const { + return equal(coeff, other->coeff) && equal(base, other->base); + } + static constexpr const char* _type_key = "arith.ModularSet"; TVM_DECLARE_FINAL_OBJECT_INFO(ModularSetNode, Object); }; diff --git a/include/tvm/arith/int_set.h b/include/tvm/arith/int_set.h index 8b73f871f1d2..86ef906fef0a 100644 --- a/include/tvm/arith/int_set.h +++ b/include/tvm/arith/int_set.h @@ -59,6 +59,7 @@ enum SignType { class IntSetNode : public Object { public: static constexpr const char* _type_key = "IntSet"; + static constexpr bool _type_has_method_sequal_reduce = false; TVM_DECLARE_BASE_OBJECT_INFO(IntSetNode, Object); }; diff --git a/include/tvm/ir/adt.h b/include/tvm/ir/adt.h index 67cfb8d67a18..260161432d2b 100644 --- a/include/tvm/ir/adt.h +++ b/include/tvm/ir/adt.h @@ -63,6 +63,14 @@ class ConstructorNode : public RelayExprNode { v->Visit("_checked_type_", &checked_type_); } + bool SEqualReduce(const ConstructorNode* other, SEqualReducer equal) const { + // Use namehint for now to be consistent with the legacy relay impl + // TODO(tvm-team) revisit, need to check the type var. + return + equal(name_hint, other->name_hint) && + equal(inputs, other->inputs); + } + static constexpr const char* _type_key = "relay.Constructor"; TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorNode, RelayExprNode); }; @@ -108,6 +116,13 @@ class TypeDataNode : public TypeNode { v->Visit("span", &span); } + bool SEqualReduce(const TypeDataNode* other, SEqualReducer equal) const { + return + equal.DefEqual(header, other->header) && + equal.DefEqual(type_vars, other->type_vars) && + equal(constructors, other->constructors); + } + static constexpr const char* _type_key = "relay.TypeData"; TVM_DECLARE_FINAL_OBJECT_INFO(TypeDataNode, TypeNode); }; diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index 4413fc36879c..c3b5831353f1 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -118,7 +118,9 @@ class AttrFieldInfoNode : public Object { v->Visit("type_info", &type_info); v->Visit("description", &description); } + static constexpr const char* _type_key = "AttrFieldInfo"; + static constexpr bool _type_has_method_sequal_reduce = false; TVM_DECLARE_FINAL_OBJECT_INFO(AttrFieldInfoNode, Object); }; @@ -278,6 +280,7 @@ class BaseAttrsNode : public Object { */ TVM_DLL virtual size_t ContentHash(AttrsHash hasher) const = 0; + static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const char* _type_key = "Attrs"; TVM_DECLARE_BASE_OBJECT_INFO(BaseAttrsNode, Object); }; @@ -302,6 +305,10 @@ class DictAttrsNode : public BaseAttrsNode { /*! \brief internal attrs map */ Map dict; + bool SEqualReduce(const DictAttrsNode* other, SEqualReducer equal) const { + return equal(dict, other->dict); + } + // implementations void VisitAttrs(AttrVisitor* v) final; void VisitNonDefaultAttrs(AttrVisitor* v) final; @@ -401,6 +408,33 @@ class AttrsEqualVisitor { const AttrsEqual& equal_; }; +class AttrsSEqualVisitor { + public: + bool result_{true}; + // constructor + AttrsSEqualVisitor(const Object* lhs, const Object* rhs, const SEqualReducer& equal) + : lhs_(lhs), rhs_(rhs), equal_(equal) { + } + template + AttrNopEntry operator()(const char* key, T* lhs_value) { + if (!result_) return AttrNopEntry(); + const T* rhs_value = + reinterpret_cast( + reinterpret_cast(rhs_) + + (reinterpret_cast(lhs_value) - + reinterpret_cast(lhs_))); + if (!equal_(*lhs_value, *rhs_value)) { + result_ = false; + } + return AttrNopEntry(); + } + + private: + const Object* lhs_; + const Object* rhs_; + const SEqualReducer& equal_; +}; + class AttrsHashVisitor { public: explicit AttrsHashVisitor(const AttrsHash& hasher) @@ -817,6 +851,13 @@ class AttrsNode : public BaseAttrsNode { } } + bool SEqualReduce(const DerivedType* other, SEqualReducer equal) const { + DerivedType* pself = self(); + ::tvm::detail::AttrsSEqualVisitor visitor(pself, other, equal); + self()->__VisitAttrs__(visitor); + return visitor.result_; + } + Array ListFieldInfo() const final { ::tvm::detail::AttrDocVisitor visitor; self()->__VisitAttrs__(visitor); diff --git a/include/tvm/ir/env_func.h b/include/tvm/ir/env_func.h index f5b17bb1db08..1064fd1462de 100644 --- a/include/tvm/ir/env_func.h +++ b/include/tvm/ir/env_func.h @@ -51,7 +51,12 @@ class EnvFuncNode : public Object { v->Visit("name", &name); } + bool SEqualReduce(const EnvFuncNode* other, SEqualReducer equal) const { + return this == other; + } + static constexpr const char* _type_key = "EnvFunc"; + static constexpr bool _type_has_method_sequal_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(EnvFuncNode, Object); }; diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 44244df83ff6..fc63da0afd25 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -43,6 +43,7 @@ namespace tvm { class BaseExprNode : public Object { public: static constexpr const char* _type_key = "Expr"; + static constexpr const bool _type_has_method_sequal_reduce = true; TVM_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object); }; @@ -197,6 +198,13 @@ class GlobalVarNode : public RelayExprNode { v->Visit("_checked_type_", &checked_type_); } + bool SEqualReduce(const GlobalVarNode* other, SEqualReducer equal) const { + // name matters for global var. + return + equal(name_hint, other->name_hint) && + equal.FreeVarEqualImpl(this, other); + } + static constexpr const char* _type_key = "GlobalVar"; TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, RelayExprNode); }; @@ -228,6 +236,10 @@ class IntImmNode : public PrimExprNode { v->Visit("value", &value); } + bool SEqualReduce(const IntImmNode* other, SEqualReducer equal) const { + return equal(dtype, other->dtype) && equal(value, other->value); + } + static constexpr const char* _type_key = "IntImm"; TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode); }; @@ -263,6 +275,10 @@ class FloatImmNode : public PrimExprNode { v->Visit("value", &value); } + bool SEqualReduce(const FloatImmNode* other, SEqualReducer equal) const { + return equal(dtype, other->dtype) && equal(value, other->value); + } + static constexpr const char* _type_key = "FloatImm"; TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode); }; @@ -353,7 +369,12 @@ class RangeNode : public Object { v->Visit("extent", &extent); } + bool SEqualReduce(const RangeNode* other, SEqualReducer equal) const { + return equal(min, other->min) && equal(extent, other->extent); + } + static constexpr const char* _type_key = "Range"; + static constexpr const bool _type_has_method_sequal_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(RangeNode, Object); }; diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 4613bec70633..38e583dd4d83 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -62,6 +62,8 @@ class IRModuleNode : public Object { v->Visit("global_type_var_map_", &global_type_var_map_); } + TVM_DLL bool SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const; + /*! * \brief Add a function to the global environment. * \param var The var of the global function. @@ -235,6 +237,7 @@ class IRModuleNode : public Object { TVM_DLL std::unordered_set Imports() const; static constexpr const char* _type_key = "IRModule"; + static constexpr const bool _type_has_method_sequal_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleNode, Object); private: diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h index 8a6ab77427fb..f023e8732545 100644 --- a/include/tvm/ir/op.h +++ b/include/tvm/ir/op.h @@ -101,6 +101,11 @@ class OpNode : public RelayExprNode { v->Visit("support_level", &support_level); } + bool SEqualReduce(const OpNode* other, SEqualReducer equal) const { + // pointer equality is fine as there is only one op with the same name. + return this == other; + } + /*! * \brief Check that if current op is a "primtive operator". * That is the arguments are all type variables, and there is a single diff --git a/include/tvm/ir/span.h b/include/tvm/ir/span.h index 4720dfe0a84e..7194e903549c 100644 --- a/include/tvm/ir/span.h +++ b/include/tvm/ir/span.h @@ -44,6 +44,10 @@ class SourceNameNode : public Object { // override attr visitor void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); } + bool SEqualReduce(const SourceNameNode* other, SEqualReducer equal) const { + return equal(name, other->name); + } + static constexpr const char* _type_key = "SourceName"; TVM_DECLARE_FINAL_OBJECT_INFO(SourceNameNode, Object); }; @@ -87,6 +91,13 @@ class SpanNode : public Object { v->Visit("col_offset", &col_offset); } + bool SEqualReduce(const SpanNode* other, SEqualReducer equal) const { + return + equal(source, other->source) && + equal(lineno, other->lineno) && + equal(col_offset, other->col_offset); + } + TVM_DLL static Span make(SourceName source, int lineno, int col_offset); static constexpr const char* _type_key = "Span"; diff --git a/include/tvm/ir/tensor_type.h b/include/tvm/ir/tensor_type.h index 70a2df19db6a..05c7a95d5601 100644 --- a/include/tvm/ir/tensor_type.h +++ b/include/tvm/ir/tensor_type.h @@ -73,6 +73,12 @@ class TensorTypeNode : public BaseTensorTypeNode { v->Visit("span", &span); } + bool SEqualReduce(const TensorTypeNode* other, SEqualReducer equal) const { + return + equal(shape, other->shape) && + equal(dtype, other->dtype); + } + /*! \brief Return product of elements in the shape. * \return (d1 * d_2 ... * d_n) if shape is (d_1, d_2, ..., d_n) and 1 if shape size is zero. */ diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index 1b6ea25f9e22..ecd234a93f76 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -111,6 +111,7 @@ class PassContextNode : public Object { } static constexpr const char* _type_key = "transform.PassContext"; + static constexpr bool _type_has_method_sequal_reduce = false; TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object); }; @@ -207,6 +208,7 @@ class PassInfoNode : public Object { } static constexpr const char* _type_key = "transform.PassInfo"; + static constexpr bool _type_has_method_sequal_reduce = false; TVM_DECLARE_FINAL_OBJECT_INFO(PassInfoNode, Object); }; diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index c23626e4de7f..dd7002993424 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -79,6 +79,7 @@ class TypeNode : public Object { mutable Span span; static constexpr const char* _type_key = "Type"; + static constexpr const bool _type_has_method_sequal_reduce = true; TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object); }; @@ -110,6 +111,10 @@ class PrimTypeNode : public TypeNode { v->Visit("dtype", &dtype); } + bool SEqualReduce(const PrimTypeNode* other, SEqualReducer equal) const { + return equal(dtype, other->dtype); + } + static constexpr const char* _type_key = "PrimType"; TVM_DECLARE_FINAL_OBJECT_INFO(PrimTypeNode, TypeNode); }; @@ -152,6 +157,10 @@ class PointerTypeNode : public TypeNode { v->Visit("element_type", &element_type); } + bool SEqualReduce(const PointerTypeNode* other, SEqualReducer equal) const { + return equal(element_type, other->element_type); + } + static constexpr const char* _type_key = "PointerType"; TVM_DECLARE_FINAL_OBJECT_INFO(PointerTypeNode, TypeNode); }; @@ -218,6 +227,12 @@ class TypeVarNode : public TypeNode { v->Visit("span", &span); } + bool SEqualReduce(const TypeVarNode* other, SEqualReducer equal) const { + return + equal(kind, other->kind) && + equal.FreeVarEqualImpl(this, other); + } + static constexpr const char* _type_key = "TypeVar"; TVM_DECLARE_FINAL_OBJECT_INFO(TypeVarNode, TypeNode); }; @@ -258,6 +273,13 @@ class GlobalTypeVarNode : public TypeNode { v->Visit("kind", &kind); } + bool SEqualReduce(const GlobalTypeVarNode* other, SEqualReducer equal) const { + // name matters for now in global type var. + return + equal(name_hint, other->name_hint) && + equal.FreeVarEqualImpl(this, other); + } + static constexpr const char* _type_key = "GlobalTypeVar"; TVM_DECLARE_FINAL_OBJECT_INFO(GlobalTypeVarNode, TypeNode); }; @@ -294,6 +316,10 @@ class TupleTypeNode : public TypeNode { v->Visit("span", &span); } + bool SEqualReduce(const TupleTypeNode* other, SEqualReducer equal) const { + return equal(fields, other->fields); + } + static constexpr const char* _type_key = "TupleType"; TVM_DECLARE_FINAL_OBJECT_INFO(TupleTypeNode, TypeNode); }; @@ -386,6 +412,15 @@ class FuncTypeNode : public TypeNode { v->Visit("span", &span); } + bool SEqualReduce(const FuncTypeNode* other, SEqualReducer equal) const { + // type params first as they defines type vars. + return + equal.DefEqual(type_params, other->type_params) && + equal(arg_types, other->arg_types) && + equal(ret_type, other->ret_type) && + equal(type_constraints, other->type_constraints); + } + static constexpr const char* _type_key = "FuncType"; TVM_DECLARE_FINAL_OBJECT_INFO(FuncTypeNode, TypeNode); }; @@ -432,6 +467,10 @@ class IncompleteTypeNode : public TypeNode { v->Visit("span", &span); } + bool SEqualReduce(const IncompleteTypeNode* other, SEqualReducer equal) const { + return equal(kind, other->kind); + } + static constexpr const char* _type_key = "IncompleteType"; TVM_DECLARE_FINAL_OBJECT_INFO(IncompleteTypeNode, TypeNode); }; @@ -469,6 +508,10 @@ class RelayRefTypeNode : public TypeNode { v->Visit("span", &span); } + bool SEqualReduce(const RelayRefTypeNode* other, SEqualReducer equal) const { + return equal(value, other->value); + } + // Keep the relay prefix in the type as this type is specific // to the relay itself. static constexpr const char* _type_key = "relay.RefType"; diff --git a/include/tvm/ir/type_relation.h b/include/tvm/ir/type_relation.h index f7bfb68a54ef..592bf25a7270 100644 --- a/include/tvm/ir/type_relation.h +++ b/include/tvm/ir/type_relation.h @@ -50,6 +50,12 @@ class TypeCallNode : public TypeNode { v->Visit("span", &span); } + bool SEqualReduce(const TypeCallNode* other, SEqualReducer equal) const { + return + equal(func, other->func) && + equal(args, other->args); + } + static constexpr const char* _type_key = "TypeCall"; TVM_DECLARE_FINAL_OBJECT_INFO(TypeCallNode, TypeNode); }; @@ -195,6 +201,14 @@ class TypeRelationNode : public TypeConstraintNode { v->Visit("span", &span); } + bool SEqualReduce(const TypeRelationNode* other, SEqualReducer equal) const { + return + equal(func, other->func) && + equal(args, other->args) && + equal(num_inputs, other->num_inputs) && + equal(attrs, other->attrs); + } + static constexpr const char* _type_key = "TypeRelation"; TVM_DECLARE_FINAL_OBJECT_INFO(TypeRelationNode, TypeConstraintNode); }; diff --git a/include/tvm/node/container.h b/include/tvm/node/container.h index a541385bb575..461fa11b4f30 100644 --- a/include/tvm/node/container.h +++ b/include/tvm/node/container.h @@ -23,7 +23,9 @@ #ifndef TVM_NODE_CONTAINER_H_ #define TVM_NODE_CONTAINER_H_ -#include +#include +#include +#include #include #include @@ -34,15 +36,19 @@ namespace tvm { +using runtime::Object; +using runtime::ObjectPtr; +using runtime::ObjectRef; +using runtime::make_object; +using runtime::ObjectHash; +using runtime::ObjectEqual; + /*! \brief array node content in array */ class ArrayNode : public Object { public: /*! \brief the data content */ std::vector data; - void VisitAttrs(AttrVisitor* visitor) { - } - static constexpr const char* _type_key = "Array"; TVM_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Object); }; @@ -50,9 +56,6 @@ class ArrayNode : public Object { /*! \brief map node content */ class MapNode : public Object { public: - void VisitAttrs(AttrVisitor* visitor) { - } - /*! \brief The corresponding conatiner type */ using ContainerType = std::unordered_map< ObjectRef, @@ -73,9 +76,6 @@ class StrMapNode : public Object { /*! \brief The corresponding conatiner type */ using ContainerType = std::unordered_map; - void VisitAttrs(AttrVisitor* visitor) { - } - /*! \brief the data content */ ContainerType data; diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h index 3ea3d763df74..76e574bcc2b2 100644 --- a/include/tvm/node/node.h +++ b/include/tvm/node/node.h @@ -39,6 +39,8 @@ #include #include #include +#include +#include #include #include diff --git a/include/tvm/node/reflection.h b/include/tvm/node/reflection.h index daffeb859668..d0a9304a5027 100644 --- a/include/tvm/node/reflection.h +++ b/include/tvm/node/reflection.h @@ -29,13 +29,14 @@ #include #include #include +#include #include #include +#include namespace tvm { -// forward declaration using runtime::Object; using runtime::ObjectPtr; using runtime::ObjectRef; @@ -86,6 +87,13 @@ class ReflectionVTable { * does not need as much customization. */ typedef void (*FVisitAttrs)(Object* self, AttrVisitor* visitor); + /*! + * \brief Equality comparison function. + * \note We use function pointer, instead of std::function + * to reduce the dispatch overhead as field visit + * does not need as much customization. + */ + typedef bool (*FSEqualReduce)(const Object* self, const Object* other, SEqualReducer equal); /*! * \brief creator function. * \param global_key Key that identifies a global single object. @@ -111,6 +119,14 @@ class ReflectionVTable { * \return the global key if object has one, otherwise return empty string. */ inline std::string GetGlobalKey(Object* self) const; + /*! + * \brief Dispatch the SEqualReduce function. + * \param self The pointer to the object. + * \param other The pointer to another object to be compared. + * \param equal The equality comparator. + * \return the result. + */ + bool SEqualReduce(const Object* self, const Object* other, SEqualReducer equal) const; /*! * \brief Create an initial object using default constructor * by type_key and global key. @@ -139,12 +155,14 @@ class ReflectionVTable { TVM_DLL static ReflectionVTable* Global(); class Registry; - template + template inline Registry Register(); private: /*! \brief Attribute visitor. */ std::vector fvisit_attrs_; + /*! \brief Structural equal function. */ + std::vector fsequal_; /*! \brief Creation function. */ std::vector fcreate_; /*! \brief Global key function. */ @@ -182,6 +200,44 @@ class ReflectionVTable::Registry { uint32_t type_index_; }; + +#define TVM_REFLECTION_REG_VAR_DEF \ + static TVM_ATTRIBUTE_UNUSED ::tvm::ReflectionVTable::Registry \ + __make_reflectiion + +/*! + * \brief Directly register reflection VTable. + * \param TypeName The name of the type. + * \param TraitName A trait class that implements functions like VisitAttrs and SEqualReduce. + * + * \code + * + * // Example SEQualReduce traits for runtime StringObj. + * + * struct StringObjTrait { + * static constexpr const std::nullptr_t VisitAttrs = nullptr; + * + * static bool SEqualReduce(const runtime::StringObj* lhs, + * const runtime::StringObj* rhs, + * SEqualReducer equal) { + * if (lhs == rhs) return true; + * if (lhs->size != rhs->size) return false; + * if (lhs->data != rhs->data) return true; + * return std::memcmp(lhs->data, rhs->data, lhs->size) != 0; + * } + * }; + * + * TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait); + * + * \endcode + * + * \note This macro can be called in different place as TVM_REGISTER_OBJECT_TYPE. + * And can be used to register the related reflection functions for runtime objects. + */ +#define TVM_REGISTER_REFLECTION_VTABLE(TypeName, TraitName) \ + TVM_STR_CONCAT(TVM_REFLECTION_REG_VAR_DEF, __COUNTER__) = \ + ::tvm::ReflectionVTable::Global()->Register() \ + /*! * \brief Register a node type to object registry and reflection registry. * \param TypeName The name of the type. @@ -189,15 +245,79 @@ class ReflectionVTable::Registry { */ #define TVM_REGISTER_NODE_TYPE(TypeName) \ TVM_REGISTER_OBJECT_TYPE(TypeName); \ - static DMLC_ATTRIBUTE_UNUSED ::tvm::ReflectionVTable::Registry & \ - __make_Node ## _ ## TypeName ## __ = \ - ::tvm::ReflectionVTable::Global()->Register() \ - .set_creator([](const std::string&) -> ObjectPtr { \ - return ::tvm::runtime::make_object(); \ - }) + TVM_REGISTER_REFLECTION_VTABLE(TypeName, ::tvm::detail::ReflectionTrait) \ + .set_creator([](const std::string&) -> ObjectPtr { \ + return ::tvm::runtime::make_object(); \ + }) + // Implementation details +namespace detail { + +template +struct ImplVisitAttrs { + static constexpr const std::nullptr_t VisitAttrs = nullptr; +}; + +template +struct ImplVisitAttrs { + static void VisitAttrs(T* self, AttrVisitor* v) { + self->VisitAttrs(v); + } +}; + +template +struct ImplSEqualReduce { + static constexpr const std::nullptr_t SEqualReduce = nullptr; +}; + +template +struct ImplSEqualReduce { + static bool SEqualReduce(const T* self, const T* other, SEqualReducer equal) { + return self->SEqualReduce(other, equal); + } +}; + template +struct ReflectionTrait : + public ImplVisitAttrs, + public ImplSEqualReduce { +}; + +template::value> +struct SelectVisitAttrs { + static constexpr const std::nullptr_t VisitAttrs = nullptr; +}; + +template +struct SelectVisitAttrs { + static void VisitAttrs(Object* self, AttrVisitor* v) { + TraitName::VisitAttrs(static_cast(self), v); + } +}; + +template::value> +struct SelectSEqualReduce { + static constexpr const std::nullptr_t SEqualReduce = nullptr; +}; + +template +struct SelectSEqualReduce { + static bool SEqualReduce(const Object* self, + const Object* other, + SEqualReducer equal) { + return TraitName::SEqualReduce(static_cast(self), + static_cast(other), + equal); + } +}; +} // namespace detail + +template inline ReflectionVTable::Registry ReflectionVTable::Register() { uint32_t tindex = T::RuntimeTypeIndex(); @@ -205,15 +325,15 @@ ReflectionVTable::Register() { fvisit_attrs_.resize(tindex + 1, nullptr); fcreate_.resize(tindex + 1, nullptr); fglobal_key_.resize(tindex + 1, nullptr); + fsequal_.resize(tindex + 1, nullptr); } // functor that implemnts the redirection. - struct Functor { - static void VisitAttrs(Object* self, AttrVisitor* v) { - static_cast(self)->VisitAttrs(v); - } - }; + fvisit_attrs_[tindex] = + ::tvm::detail::SelectVisitAttrs::VisitAttrs; + + fsequal_[tindex] = + ::tvm::detail::SelectSEqualReduce::SEqualReduce; - fvisit_attrs_[tindex] = Functor::VisitAttrs; return Registry(this, tindex); } diff --git a/include/tvm/node/structural_equal.h b/include/tvm/node/structural_equal.h new file mode 100644 index 000000000000..f719e24f619c --- /dev/null +++ b/include/tvm/node/structural_equal.h @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/node/structural_equal.h + * \brief Structural equality comparison. + */ +#ifndef TVM_NODE_STRUCTURAL_EQUAL_H_ +#define TVM_NODE_STRUCTURAL_EQUAL_H_ + +#include +#include +#include +#include + +namespace tvm { + +/*! + * \brief Equality definition of base value class. + */ +class BaseValueEqual { + public: + bool operator()(const double& lhs, const double& rhs) const { + // fuzzy float pt comparison + constexpr double atol = 1e-9; + if (lhs == rhs) return true; + double diff = lhs - rhs; + return diff > -atol && diff < atol; + } + + bool operator()(const int64_t& lhs, const int64_t& rhs) const { + return lhs == rhs; + } + bool operator()(const uint64_t& lhs, const uint64_t& rhs) const { + return lhs == rhs; + } + bool operator()(const int& lhs, const int& rhs) const { + return lhs == rhs; + } + bool operator()(const bool& lhs, const bool& rhs) const { + return lhs == rhs; + } + bool operator()(const std::string& lhs, const std::string& rhs) const { + return lhs == rhs; + } + bool operator()(const DataType& lhs, const DataType& rhs) const { + return lhs == rhs; + } + template::value>::type> + bool operator()(const ENum& lhs, const ENum& rhs) const { + return lhs == rhs; + } +}; + +/*! + * \brief Content-aware structural equality comparator for objects. + * + * The structural equality is recursively defined in the DAG of IR nodes via SEqual. + * There are two kinds of nodes: + * + * - Graph node: a graph node in lhs can only be mapped as equal to + * one and only one graph node in rhs. + * - Normal node: equality is recursively defined without the restriction + * of graph nodes. + * + * Vars(tir::Var, TypeVar) and non-constant relay expression nodes are graph nodes. + * For example, it means that `%1 = %x + %y; %1 + %1` is not structurally equal + * to `%1 = %x + %y; %2 = %x + %y; %1 + %2` in relay. + * + * A var-type node(e.g. tir::Var, TypeVar) can be mapped as equal to another var + * with the same type if one of the following condition holds: + * + * - They appear in a same definition point(e.g. function argument). + * - They points to the same VarNode via the same_as relation. + * - They appear in a same usage point, and map_free_vars is set to be True. + */ +class StructuralEqual : public BaseValueEqual { + public: + // inheritate operator() + using BaseValueEqual::operator(); + /*! + * \brief Compare objects via strutural equal. + * \param lhs The left operand. + * \param rhs The right operand. + * \return The comparison result. + */ + TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const; +}; + +/*! + * \brief A Reducer class to reduce the structural equality result of two objects. + * + * The reducer will call the SEqualReduce function of each objects recursively. + * Importantly, the reducer may not directly use recursive calls to resolve the + * equality checking. Instead, it can store the necessary equality conditions + * and check later via an internally managed stack. + */ +class SEqualReducer : public BaseValueEqual { + public: + /*! \brief Internal handler that defines custom behaviors.. */ + class Handler { + public: + /*! + * \brief Reduce condition to equality of lhs and rhs. + * + * \param lhs The left operand. + * \param rhs The right operand. + * \param map_free_vars Whether do we allow remap variables if possible. + * + * \return false if there is an immediate failure, true otherwise. + * \note This function may save the equality condition of (lhs == rhs) in an internal + * stack and try to resolve later. + */ + virtual bool SEqualReduce(const ObjectRef& lhs, + const ObjectRef& rhs, + bool map_free_vars) = 0; + /*! + * \brief Lookup the graph node equal map for vars that are already mapped. + * + * This is an auxiliary method to check the Map equality. + * \param lhs an lhs value. + * + * \return The corresponding rhs value if any, nullptr if not available. + */ + virtual ObjectRef MapLhsToRhs(const ObjectRef& lhs) = 0; + /*! + * \brief Mark current comparison as graph node equal comparison. + */ + virtual void MarkGraphNode() = 0; + }; + + using BaseValueEqual::operator(); + + /*! \brief default constructor */ + SEqualReducer() = default; + /*! + * \brief Constructor with a specific handler. + * \param handler The equal handler for objects. + * \param map_free_vars Whether or not to map free variables. + */ + explicit SEqualReducer(Handler* handler, bool map_free_vars) + : handler_(handler), map_free_vars_(map_free_vars) {} + /*! + * \brief Reduce condition to comparison of two objects. + * \param lhs The left operand. + * \param rhs The right operand. + * \return the immediate check result. + */ + bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const { + return handler_->SEqualReduce(lhs, rhs, map_free_vars_); + } + /*! + * \brief Reduce condition to comparison of two definitions, + * where free vars can be mapped. + * + * Call this function to compare definition points such as function params + * and var in a let-binding. + * + * \param lhs The left operand. + * \param rhs The right operand. + * \return the immediate check result. + */ + bool DefEqual(const ObjectRef& lhs, const ObjectRef& rhs) { + return handler_->SEqualReduce(lhs, rhs, true); + } + /*! + * \brief Reduce condition to comparison of two arrays. + * \param lhs The left operand. + * \param rhs The right operand. + * \return the immediate check result. + */ + template + bool operator()(const Array& lhs, const Array& rhs) const { + // quick specialization for Array to reduce amount of recursion + // depth as array comparison is pretty common. + if (lhs.size() != rhs.size()) return false; + for (size_t i = 0; i < lhs.size(); ++i) { + if (!(operator()(lhs[i], rhs[i]))) return false; + } + return true; + } + /*! + * \brief Implementation for equality rule of var type objects(e.g. TypeVar, tir::Var). + * \param lhs The left operand. + * \param rhs The right operand. + * \return the result. + */ + bool FreeVarEqualImpl(const runtime::Object* lhs, const runtime::Object* rhs) const { + // var need to be remapped, so it belongs to graph node. + handler_->MarkGraphNode(); + // We only map free vars if they corresponds to the same address + // or map free_var option is set to be true. + return lhs == rhs || map_free_vars_; + } + + /*! \return Get the internal handler. */ + Handler* operator->() const { + return handler_; + } + + private: + /*! \brief Internal class pointer. */ + Handler* handler_; + /*! \brief Whether or not to map free vars. */ + bool map_free_vars_; +}; + +} // namespace tvm +#endif // TVM_NODE_STRUCTURAL_EQUAL_H_ diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h index 8189b210e6ba..ea13e25f4d0b 100644 --- a/include/tvm/relay/adt.h +++ b/include/tvm/relay/adt.h @@ -46,6 +46,7 @@ using TypeDataNode = tvm::TypeDataNode; class PatternNode : public RelayNode { public: static constexpr const char* _type_key = "relay.Pattern"; + static constexpr const bool _type_has_method_sequal_reduce = true; TVM_DECLARE_BASE_OBJECT_INFO(PatternNode, Object); }; @@ -74,6 +75,10 @@ class PatternWildcardNode : public PatternNode { v->Visit("span", &span); } + bool SEqualReduce(const PatternNode* other, SEqualReducer equal) const { + return true; + } + static constexpr const char* _type_key = "relay.PatternWildcard"; TVM_DECLARE_FINAL_OBJECT_INFO(PatternWildcardNode, PatternNode); }; @@ -118,6 +123,10 @@ class PatternVarNode : public PatternNode { v->Visit("span", &span); } + bool SEqualReduce(const PatternVarNode* other, SEqualReducer equal) const { + return equal.DefEqual(var, other->var); + } + static constexpr const char* _type_key = "relay.PatternVar"; TVM_DECLARE_FINAL_OBJECT_INFO(PatternVarNode, PatternNode); }; @@ -149,6 +158,12 @@ class PatternConstructorNode : public PatternNode { v->Visit("span", &span); } + bool SEqualReduce(const PatternConstructorNode* other, SEqualReducer equal) const { + return + equal(constructor, other->constructor) && + equal(patterns, other->patterns); + } + static constexpr const char* _type_key = "relay.PatternConstructor"; TVM_DECLARE_FINAL_OBJECT_INFO(PatternConstructorNode, PatternNode); }; @@ -178,6 +193,10 @@ class PatternTupleNode : public PatternNode { v->Visit("span", &span); } + bool SEqualReduce(const PatternTupleNode* other, SEqualReducer equal) const { + return equal(patterns, other->patterns); + } + static constexpr const char* _type_key = "relay.PatternTuple"; TVM_DECLARE_FINAL_OBJECT_INFO(PatternTupleNode, PatternNode); }; @@ -208,7 +227,12 @@ class ClauseNode : public Object { v->Visit("rhs", &rhs); } + bool SEqualReduce(const ClauseNode* other, SEqualReducer equal) const { + return equal(lhs, other->lhs) && equal(rhs, other->rhs); + } + static constexpr const char* _type_key = "relay.Clause"; + static constexpr const bool _type_has_method_sequal_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(ClauseNode, Object); }; @@ -248,6 +272,14 @@ class MatchNode : public ExprNode { v->Visit("_checked_type_", &checked_type_); } + bool SEqualReduce(const MatchNode* other, SEqualReducer equal) const { + equal->MarkGraphNode(); + return + equal(data, other->data) && + equal(clauses, other->clauses) && + equal(complete, other->complete); + } + static constexpr const char* _type_key = "relay.Match"; TVM_DECLARE_FINAL_OBJECT_INFO(MatchNode, ExprNode); }; diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 3acb5ddae778..731046e0a8cc 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -26,6 +26,7 @@ #include #include +#include #include #include #include @@ -72,6 +73,10 @@ class ConstantNode : public ExprNode { v->Visit("_checked_type_", &checked_type_); } + bool SEqualReduce(const ConstantNode* other, SEqualReducer equal) const { + return equal(data, other->data); + } + static constexpr const char* _type_key = "relay.Constant"; TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, ExprNode); }; @@ -101,6 +106,16 @@ class TupleNode : public ExprNode { v->Visit("_checked_type_", &checked_type_); } + bool SEqualReduce(const TupleNode* other, SEqualReducer equal) const { + // specially handle empty tuple as a constant is not a graph node. + if (fields.size() == other->fields.size() && fields.size() == 0) { + return true; + } else { + equal->MarkGraphNode(); + return equal(fields, other->fields); + } + } + static constexpr const char* _type_key = "relay.Tuple"; TVM_DECLARE_FINAL_OBJECT_INFO(TupleNode, ExprNode); }; @@ -157,6 +172,12 @@ class VarNode : public ExprNode { v->Visit("_checked_type_", &checked_type_); } + bool SEqualReduce(const VarNode* other, SEqualReducer equal) const { + return + equal(type_annotation, other->type_annotation) && + equal.FreeVarEqualImpl(this, other); + } + TVM_DLL static Var make(std::string name_hint, Type type_annotation); @@ -238,6 +259,16 @@ class CallNode : public ExprNode { v->Visit("_checked_type_", &checked_type_); } + bool SEqualReduce(const CallNode* other, SEqualReducer equal) const { + // skip type_args check for primitive ops. + equal->MarkGraphNode(); + return + equal(op, other->op) && + equal(args, other->args) && + equal(attrs, other->attrs) && + (IsPrimitiveOp(op) || equal(type_args, other->type_args)); + } + static constexpr const char* _type_key = "relay.Call"; TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode); }; @@ -289,6 +320,14 @@ class LetNode : public ExprNode { v->Visit("_checked_type_", &checked_type_); } + bool SEqualReduce(const LetNode* other, SEqualReducer equal) const { + equal->MarkGraphNode(); + return + equal.DefEqual(var, other->var) && + equal(value, other->value) && + equal(body, other->body); + } + static constexpr const char* _type_key = "relay.Let"; TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, ExprNode); }; @@ -336,6 +375,14 @@ class IfNode : public ExprNode { v->Visit("_checked_type_", &checked_type_); } + bool SEqualReduce(const IfNode* other, SEqualReducer equal) const { + equal->MarkGraphNode(); + return + equal(cond, other->cond) && + equal(true_branch, other->true_branch) && + equal(false_branch, other->false_branch); + } + static constexpr const char* _type_key = "relay.If"; TVM_DECLARE_FINAL_OBJECT_INFO(IfNode, ExprNode); }; @@ -369,6 +416,12 @@ class TupleGetItemNode : public ExprNode { v->Visit("_checked_type_", &checked_type_); } + bool SEqualReduce(const TupleGetItemNode* other, SEqualReducer equal) const { + return + equal(tuple, other->tuple) && + equal(index, other->index); + } + static constexpr const char* _type_key = "relay.TupleGetItem"; TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemNode, ExprNode); }; @@ -398,6 +451,11 @@ class RefCreateNode : public ExprNode { v->Visit("_checked_type_", &checked_type_); } + bool SEqualReduce(const RefCreateNode* other, SEqualReducer equal) const { + equal->MarkGraphNode(); + return equal(value, other->value); + } + static constexpr const char* _type_key = "relay.RefCreate"; TVM_DECLARE_FINAL_OBJECT_INFO(RefCreateNode, ExprNode); }; @@ -426,6 +484,11 @@ class RefReadNode : public ExprNode { v->Visit("_checked_type_", &checked_type_); } + bool SEqualReduce(const RefReadNode* other, SEqualReducer equal) const { + equal->MarkGraphNode(); + return equal(ref, other->ref); + } + static constexpr const char* _type_key = "relay.RefRead"; TVM_DECLARE_FINAL_OBJECT_INFO(RefReadNode, ExprNode); }; @@ -456,6 +519,13 @@ class RefWriteNode : public ExprNode { v->Visit("_checked_type_", &checked_type_); } + bool SEqualReduce(const RefWriteNode* other, SEqualReducer equal) const { + equal->MarkGraphNode(); + return + equal(ref, other->ref) && + equal(value, other->value); + } + TVM_DLL static RefWrite make(Expr ref, Expr value); static constexpr const char* _type_key = "relay.RefWrite"; @@ -497,6 +567,7 @@ class TempExprNode : public ExprNode { virtual Expr Realize() const = 0; static constexpr const char* _type_key = "relay.TempExpr"; + static constexpr const bool _type_has_method_sequal_reduce = false; TVM_DECLARE_BASE_OBJECT_INFO(TempExprNode, ExprNode); }; diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h index 5c5bd2673073..ed39caaf2690 100644 --- a/include/tvm/relay/function.h +++ b/include/tvm/relay/function.h @@ -68,6 +68,17 @@ class FunctionNode : public BaseFuncNode { v->Visit("_checked_type_", &checked_type_); } + bool SEqualReduce(const FunctionNode* other, SEqualReducer equal) const { + // Important to make def equal first. + equal->MarkGraphNode(); + return + equal.DefEqual(params, other->params) && + equal.DefEqual(type_params, other->type_params) && + equal(ret_type, other->ret_type) && + equal(attrs, other->attrs) && + equal(body, other->body); + } + /*! * \brief Return the derived function annotation of this expression. * diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index 2441ab659b84..17f81a2a8b68 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -65,6 +65,8 @@ class NDArray : public ObjectRef { inline int use_count() const; /*! \return Pointer to content of DLTensor */ inline const DLTensor* operator->() const; + /*! \return Whether the tensor is contiguous */ + inline bool IsContiguous() const; /*! * \brief Copy data content from another array. * \param other The source array to be copied from. @@ -313,6 +315,26 @@ inline size_t GetDataSize(const DLTensor& arr) { return size; } +/*! + * \brief check if a DLTensor is contiguous. + * \param arr The input DLTensor. + * \return The check result. + */ +inline bool IsContiguous(const DLTensor& arr) { + if (arr.strides == nullptr) return true; + int64_t expected_stride = 1; + for (int32_t i = arr.ndim; i != 0; --i) { + int32_t k = i - 1; + if (arr.strides[k] != expected_stride) return false; + expected_stride *= arr.shape[k]; + } + return true; +} + +inline bool NDArray::IsContiguous() const { + return ::tvm::runtime::IsContiguous(get_mutable()->dl_tensor); +} + inline void NDArray::CopyFrom(const DLTensor* other) { CHECK(data_ != nullptr); CopyFromTo(other, &(get_mutable()->dl_tensor)); diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index fe5e30bba2e8..80b479df31a0 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -211,11 +211,15 @@ class Object { static constexpr bool _type_final = false; static constexpr uint32_t _type_child_slots = 0; static constexpr bool _type_child_slots_can_overflow = true; + // member information + static constexpr bool _type_has_method_visit_attrs = true; + static constexpr bool _type_has_method_sequal_reduce = false; // NOTE: the following field is not type index of Object // but was intended to be used by sub-classes as default value. // The type index of Object is TypeIndex::kRoot static constexpr uint32_t _type_index = TypeIndex::kDynamic; + // Default constructor and copy constructor Object() {} // Override the copy and assign constructors to do nothing. diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index c1723168d40c..60dd4558b30f 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -150,6 +150,20 @@ class BufferNode : public Object { v->Visit("buffer_type", &buffer_type); } + bool SEqualReduce(const BufferNode* other, SEqualReducer equal) const { + // Use DefEqual as buffer can define variables + // in its semantics, skip name as name is not important. + return + equal.DefEqual(data, other->data) && + equal(dtype, other->dtype) && + equal.DefEqual(shape, other->shape) && + equal.DefEqual(strides, other->strides) && + equal.DefEqual(elem_offset, other->elem_offset) && + equal(scope, other->scope) && + equal(data_alignment, other->data_alignment) && + equal(buffer_type, other->buffer_type); + } + /*! \return preferred index type for this buffer node */ DataType DefaultIndexType() const { return shape.size() != 0 ? shape[0].dtype() : DataType::Int(32); @@ -169,6 +183,7 @@ class BufferNode : public Object { BufferType buffer_type); static constexpr const char* _type_key = "Buffer"; + static constexpr const bool _type_has_method_sequal_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(BufferNode, Object); }; diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 90fef87a0a75..28e618659fd8 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -75,6 +75,12 @@ class VarNode : public PrimExprNode { v->Visit("type_annotation", &type_annotation); } + bool SEqualReduce(const VarNode* other, SEqualReducer equal) const { + if (!equal(dtype, other->dtype)) return false; + if (!equal(type_annotation, other->type_annotation)) return false; + return equal.FreeVarEqualImpl(this, other); + } + static constexpr const char* _type_key = "tir.Var"; TVM_DECLARE_BASE_OBJECT_INFO(VarNode, PrimExprNode); }; @@ -288,11 +294,20 @@ class IterVarNode : public Object { v->Visit("thread_tag", &thread_tag); } + bool SEqualReduce(const IterVarNode* other, SEqualReducer equal) const { + return + equal(dom, other->dom) && + equal.DefEqual(var, other->var) && + equal(iter_type, other->iter_type) && + equal(thread_tag, other->thread_tag); + } + TVM_DLL static IterVar make(Range dom, Var var, IterVarType iter_type, std::string thread_tag = ""); static constexpr const char* _type_key = "IterVar"; + static constexpr const bool _type_has_method_sequal_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(IterVarNode, Object); }; @@ -334,6 +349,10 @@ class StringImmNode : public PrimExprNode { v->Visit("value", &value); } + bool SEqualReduce(const StringImmNode* other, SEqualReducer equal) const { + return equal(value, other->value); + } + TVM_DLL PrimExpr static make(std::string value); static constexpr const char* _type_key = "StringImm"; @@ -359,6 +378,10 @@ class CastNode : public PrimExprNode { v->Visit("value", &value); } + bool SEqualReduce(const CastNode* other, SEqualReducer equal) const { + return equal(dtype, other->dtype) && equal(value, other->value); + } + TVM_DLL static PrimExpr make(DataType t, PrimExpr v); static constexpr const char* _type_key = "Cast"; @@ -383,6 +406,13 @@ class BinaryOpNode : public PrimExprNode { v->Visit("b", &b); } + bool SEqualReduce(const T* other, SEqualReducer equal) const { + return + equal(dtype, other->dtype) && + equal(a, other->a) && + equal(b, other->b); + } + static PrimExpr make(PrimExpr a, PrimExpr b) { CHECK(a.defined()) << "ValueError: a is undefined\n"; CHECK(b.defined()) << "ValueError: b is undefined\n"; @@ -475,6 +505,13 @@ class CmpOpNode : public PrimExprNode { v->Visit("b", &b); } + bool SEqualReduce(const T* other, SEqualReducer equal) const { + return + equal(dtype, other->dtype) && + equal(a, other->a) && + equal(b, other->b); + } + static PrimExpr make(PrimExpr a, PrimExpr b) { CHECK(a.defined()) << "ValueError: a is undefined\n"; CHECK(b.defined()) << "ValueError: b is undefined\n"; @@ -539,6 +576,13 @@ class AndNode : public PrimExprNode { v->Visit("b", &b); } + bool SEqualReduce(const AndNode* other, SEqualReducer equal) const { + return + equal(dtype, other->dtype) && + equal(a, other->a) && + equal(b, other->b); + } + TVM_DLL static PrimExpr make(PrimExpr a, PrimExpr b); static constexpr const char* _type_key = "And"; @@ -559,6 +603,13 @@ class OrNode : public PrimExprNode { v->Visit("b", &b); } + bool SEqualReduce(const OrNode* other, SEqualReducer equal) const { + return + equal(dtype, other->dtype) && + equal(a, other->a) && + equal(b, other->b); + } + TVM_DLL static PrimExpr make(PrimExpr a, PrimExpr b); static constexpr const char* _type_key = "Or"; @@ -576,6 +627,10 @@ class NotNode : public PrimExprNode { v->Visit("a", &a); } + bool SEqualReduce(const NotNode* other, SEqualReducer equal) const { + return equal(dtype, other->dtype) && equal(a, other->a); + } + TVM_DLL static PrimExpr make(PrimExpr a); static constexpr const char* _type_key = "Not"; @@ -605,6 +660,14 @@ class SelectNode : public PrimExprNode { v->Visit("false_value", &false_value); } + bool SEqualReduce(const SelectNode* other, SEqualReducer equal) const { + return + equal(dtype, other->dtype) && + equal(condition, other->condition) && + equal(true_value, other->true_value) && + equal(false_value, other->false_value); + } + TVM_DLL static PrimExpr make(PrimExpr condition, PrimExpr true_value, PrimExpr false_value); static constexpr const char* _type_key = "Select"; @@ -642,6 +705,14 @@ class LoadNode : public PrimExprNode { v->Visit("predicate", &predicate); } + bool SEqualReduce(const LoadNode* other, SEqualReducer equal) const { + return + equal(dtype, other->dtype) && + equal(buffer_var, other->buffer_var) && + equal(index, other->index) && + equal(predicate, other->predicate); + } + TVM_DLL static PrimExpr make(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate); static constexpr const char* _type_key = "Load"; @@ -673,6 +744,14 @@ class RampNode : public PrimExprNode { v->Visit("lanes", &lanes); } + bool SEqualReduce(const RampNode* other, SEqualReducer equal) const { + return + equal(dtype, other->dtype) && + equal(base, other->base) && + equal(stride, other->stride) && + equal(lanes, other->lanes); + } + TVM_DLL static PrimExpr make(PrimExpr base, PrimExpr stride, int lanes); static constexpr const char* _type_key = "Ramp"; @@ -693,6 +772,13 @@ class BroadcastNode : public PrimExprNode { v->Visit("lanes", &lanes); } + bool SEqualReduce(const BroadcastNode* other, SEqualReducer equal) const { + return + equal(dtype, other->dtype) && + equal(value, other->value) && + equal(lanes, other->lanes); + } + TVM_DLL static PrimExpr make(PrimExpr value, int lanes); static constexpr const char* _type_key = "Broadcast"; @@ -718,6 +804,14 @@ class LetNode : public PrimExprNode { v->Visit("body", &body); } + bool SEqualReduce(const LetNode* other, SEqualReducer equal) const { + return + equal(dtype, other->dtype) && + equal.DefEqual(var, other->var) && + equal(value, other->value) && + equal(body, other->body); + } + TVM_DLL static PrimExpr make(Var var, PrimExpr value, PrimExpr body); static constexpr const char* _type_key = "Let"; @@ -788,12 +882,22 @@ class CallNode : public PrimExprNode { v->Visit("value_index", &value_index); } + bool SEqualReduce(const CallNode* other, SEqualReducer equal) const { + return + equal(dtype, other->dtype) && + equal(name, other->name) && + equal(args, other->args) && + equal(call_type, other->call_type) && + equal(func, other->func) && + equal(value_index, other->value_index); + } + TVM_DLL static PrimExpr make(DataType dtype, - std::string name, - Array args, - CallType call_type, - FunctionRef func = FunctionRef(), - int value_index = 0); + std::string name, + Array args, + CallType call_type, + FunctionRef func = FunctionRef(), + int value_index = 0); /*! \return Whether call node is pure. */ bool is_pure() const { @@ -856,6 +960,13 @@ class ShuffleNode : public PrimExprNode { v->Visit("indices", &indices); } + bool SEqualReduce(const ShuffleNode* other, SEqualReducer equal) const { + return + equal(dtype, other->dtype) && + equal(vectors, other->vectors) && + equal(indices, other->indices); + } + TVM_DLL static PrimExpr make(Array vectors, Array indices); TVM_DLL static PrimExpr make_concat(Array vectors); TVM_DLL static PrimExpr make_extract_element(PrimExpr vector, int index); @@ -918,7 +1029,16 @@ class CommReducerNode : public Object { v->Visit("identity_element", &identity_element); } + bool SEqualReduce(const CommReducerNode* other, SEqualReducer equal) const { + return + equal.DefEqual(lhs, other->lhs) && + equal.DefEqual(rhs, other->rhs) && + equal(result, other->result) && + equal(identity_element, other->identity_element); + } + static constexpr const char* _type_key = "CommReducer"; + static constexpr const bool _type_has_method_sequal_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(CommReducerNode, Object); }; @@ -948,10 +1068,10 @@ class ReduceNode : public PrimExprNode { /*! \brief construct expr from op and rdom */ TVM_DLL static PrimExpr make(CommReducer combiner, - Array src, - Array rdom, - PrimExpr condition, - int value_index); + Array src, + Array rdom, + PrimExpr condition, + int value_index); void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &dtype); @@ -962,6 +1082,16 @@ class ReduceNode : public PrimExprNode { v->Visit("value_index", &value_index); } + bool SEqualReduce(const ReduceNode* other, SEqualReducer equal) const { + // check axis first so IterVars can define the necessary variables. + return + equal(dtype, other->dtype) && + equal(axis, other->axis) && + equal(combiner, other->combiner) && + equal(source, other->source) && + equal(condition, other->condition) && + equal(value_index, other->value_index); + } static constexpr const char* _type_key = "Reduce"; TVM_DECLARE_FINAL_OBJECT_INFO(ReduceNode, PrimExprNode); }; @@ -970,6 +1100,11 @@ class ReduceNode : public PrimExprNode { class AnyNode : public PrimExprNode { public: void VisitAttrs(AttrVisitor* v) {} + + bool SEqualReduce(const AnyNode* other, SEqualReducer equal) const { + return true; + } + /*! \brief Convert to var. */ Var ToVar() const { return Var("any_dim", DataType::Int(32)); diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 63a8630a9212..26b643a03ebc 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -102,6 +102,16 @@ class PrimFuncNode : public BaseFuncNode { v->Visit("_checked_type_", &checked_type_); } + bool SEqualReduce(const PrimFuncNode* other, SEqualReducer equal) const { + // visit params and buffer_map first as they contains defs. + return + equal.DefEqual(params, other->params) && + equal(buffer_map, other->buffer_map) && + equal(ret_type, other->ret_type) && + equal(body, other->body) && + equal(attrs, other->attrs); + } + /*! * \brief Return the derived function annotation of this function. * @@ -112,6 +122,7 @@ class PrimFuncNode : public BaseFuncNode { TVM_DLL FuncType func_type_annotation() const; static constexpr const char* _type_key = "tir.PrimFunc"; + static constexpr const bool _type_has_method_sequal_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncNode, BaseFuncNode); }; diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index a543737f4065..d4b144dc30d5 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -38,6 +38,7 @@ namespace tir { class StmtNode : public Object { public: static constexpr const char* _type_key = "Stmt"; + static constexpr const bool _type_has_method_sequal_reduce = true; TVM_DECLARE_BASE_OBJECT_INFO(StmtNode, Object); }; @@ -65,6 +66,13 @@ class LetStmtNode : public StmtNode { v->Visit("body", &body); } + bool SEqualReduce(const LetStmtNode* other, SEqualReducer equal) const { + return + equal.DefEqual(var, other->var) && + equal(value, other->value) && + equal(body, other->body); + } + TVM_DLL static Stmt make(Var var, PrimExpr value, Stmt body); static constexpr const char* _type_key = "LetStmt"; @@ -99,6 +107,14 @@ class AttrStmtNode : public StmtNode { v->Visit("body", &body); } + bool SEqualReduce(const AttrStmtNode* other, SEqualReducer equal) const { + return + equal(node, other->node) && + equal(attr_key, other->attr_key) && + equal(value, other->value) && + equal(body, other->body); + } + TVM_DLL static Stmt make(ObjectRef node, std::string type_key, PrimExpr value, @@ -129,6 +145,13 @@ class AssertStmtNode : public StmtNode { v->Visit("body", &body); } + bool SEqualReduce(const AssertStmtNode* other, SEqualReducer equal) const { + return + equal(condition, other->condition) && + equal(message, other->message) && + equal(body, other->body); + } + TVM_DLL static Stmt make(PrimExpr condition, PrimExpr message, Stmt body); static constexpr const char* _type_key = "AssertStmt"; @@ -152,6 +175,13 @@ class ProducerConsumerNode : public StmtNode { v->Visit("body", &body); } + bool SEqualReduce(const ProducerConsumerNode* other, SEqualReducer equal) const { + return + equal(func, other->func) && + equal(is_producer, other->is_producer) && + equal(body, other->body); + } + TVM_DLL static Stmt make(FunctionRef func, bool is_producer, Stmt body); static constexpr const char* _type_key = "ProducerConsumer"; @@ -194,6 +224,14 @@ class StoreNode : public StmtNode { v->Visit("predicate", &predicate); } + bool SEqualReduce(const StoreNode* other, SEqualReducer equal) const { + return + equal(buffer_var, other->buffer_var) && + equal(value, other->value) && + equal(index, other->index) && + equal(predicate, other->predicate); + } + TVM_DLL static Stmt make(Var buffer_var, PrimExpr value, PrimExpr index, @@ -224,6 +262,14 @@ class ProvideNode : public StmtNode { v->Visit("args", &args); } + bool SEqualReduce(const ProvideNode* other, SEqualReducer equal) const { + return + equal(func, other->func) && + equal(value_index, other->value_index) && + equal(value, other->value) && + equal(args, other->args); + } + TVM_DLL static Stmt make(FunctionRef func, int value_index, PrimExpr value, @@ -261,6 +307,15 @@ class AllocateNode : public StmtNode { v->Visit("body", &body); } + bool SEqualReduce(const AllocateNode* other, SEqualReducer equal) const { + return + equal.DefEqual(buffer_var, other->buffer_var) && + equal(dtype, other->dtype) && + equal(extents, other->extents) && + equal(condition, other->condition) && + equal(body, other->body); + } + TVM_DLL static Stmt make(Var buffer_var, DataType dtype, Array extents, @@ -300,6 +355,11 @@ class FreeNode : public StmtNode { v->Visit("buffer_var", &buffer_var); } + bool SEqualReduce(const FreeNode* other, SEqualReducer equal) const { + return + equal(buffer_var, other->buffer_var); + } + TVM_DLL static Stmt make(Var buffer_var); static constexpr const char* _type_key = "Free"; @@ -341,6 +401,16 @@ class RealizeNode : public StmtNode { PrimExpr condition, Stmt body); + bool SEqualReduce(const RealizeNode* other, SEqualReducer equal) const { + return + equal(func, other->func) && + equal(value_index, other->value_index) && + equal(dtype, other->dtype) && + equal(bounds, other->bounds) && + equal(condition, other->condition) && + equal(body, other->body); + } + static constexpr const char* _type_key = "Realize"; TVM_DECLARE_FINAL_OBJECT_INFO(RealizeNode, StmtNode); }; @@ -369,6 +439,10 @@ class SeqStmtNode : public StmtNode { v->Visit("seq", &seq); } + bool SEqualReduce(const SeqStmtNode* other, SEqualReducer equal) const { + return equal(seq, other->seq); + } + static constexpr const char* _type_key = "SeqStmt"; TVM_DECLARE_FINAL_OBJECT_INFO(SeqStmtNode, StmtNode); }; @@ -472,6 +546,13 @@ class IfThenElseNode : public StmtNode { v->Visit("else_case", &else_case); } + bool SEqualReduce(const IfThenElseNode* other, SEqualReducer equal) const { + return + equal(condition, other->condition) && + equal(then_case, other->then_case) && + equal(else_case, other->else_case); + } + TVM_DLL static Stmt make(PrimExpr condition, Stmt then_case, Stmt else_case = Stmt()); static constexpr const char* _type_key = "IfThenElse"; @@ -493,6 +574,10 @@ class EvaluateNode : public StmtNode { v->Visit("value", &value); } + bool SEqualReduce(const EvaluateNode* other, SEqualReducer equal) const { + return equal(value, other->value); + } + TVM_DLL static Stmt make(PrimExpr v); static constexpr const char* _type_key = "Evaluate"; @@ -562,6 +647,16 @@ class ForNode : public StmtNode { v->Visit("body", &body); } + bool SEqualReduce(const ForNode* other, SEqualReducer equal) const { + return + equal.DefEqual(loop_var, other->loop_var) && + equal(min, other->min) && + equal(extent, other->extent) && + equal(for_type, other->for_type) && + equal(device_api, other->device_api) && + equal(body, other->body); + } + static constexpr const char* _type_key = "For"; TVM_DECLARE_FINAL_OBJECT_INFO(ForNode, StmtNode); }; @@ -587,6 +682,14 @@ class PrefetchNode : public StmtNode { v->Visit("bounds", &bounds); } + bool SEqualReduce(const PrefetchNode* other, SEqualReducer equal) const { + return + equal(func, other->func) && + equal(value_index, other->value_index) && + equal(dtype, other->dtype) && + equal(bounds, other->bounds); + } + TVM_DLL static Stmt make(FunctionRef func, int value_index, DataType dtype, diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py index 1e114469d986..88af05ce532d 100644 --- a/python/tvm/ir/__init__.py +++ b/python/tvm/ir/__init__.py @@ -17,6 +17,7 @@ # pylint: disable=unused-import """Common data structures across all IR variants.""" from .base import SourceName, Span, Node, EnvFunc, load_json, save_json +from .base import structural_equal, assert_structural_equal from .type import Type, TypeKind, PrimType, PointerType, TypeVar, GlobalTypeVar, TupleType from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType from .tensor_type import TensorType diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py index 810d78fa00ce..df69a2ce3bf8 100644 --- a/python/tvm/ir/base.py +++ b/python/tvm/ir/base.py @@ -149,3 +149,76 @@ def save_json(node): Saved json string. """ return tvm.runtime._ffi_node_api.SaveJSON(node) + + +def structural_equal(lhs, rhs, map_free_vars=False): + """Check structural equality of lhs and rhs. + + The structural equality is recursively defined in the DAG of IRNodes. + There are two kinds of nodes: + + - Graph node: a graph node in lhs can only be mapped as equal to + one and only one graph node in rhs. + - Normal node: equality is recursively defined without the restriction + of graph nodes. + + Vars(tir::Var, TypeVar) and non-constant relay expression nodes are graph nodes. + For example, it means that `%1 = %x + %y; %1 + %1` is not structurally equal + to `%1 = %x + %y; %2 = %x + %y; %1 + %2` in relay. + + A var-type node(e.g. tir::Var, TypeVar) can be mapped as equal to another var + with the same type if one of the following condition holds: + + - They appear in a same definition point(e.g. function argument). + - They points to the same VarNode via the same_as relation. + - They appear in a same usage point, and map_free_vars is set to be True. + + The rules for var are used to remap variables occurs in function + arguments and let-bindings. + + Parameters + ---------- + lhs : Object + The left operand. + + rhs : Object + The left operand. + + map_free_vars : bool + Whether or not shall we map free vars that does + not bound to any definitions as equal to each other. + + Return + ------ + result : bool + The comparison result. + """ + return tvm.runtime._ffi_node_api.StructuralEqual( + lhs, rhs, False, map_free_vars) + + +def assert_structural_equal(lhs, rhs, map_free_vars=False): + """Assert lhs and rhs are structurally equal to each other. + + Parameters + ---------- + lhs : Object + The left operand. + + rhs : Object + The left operand. + + map_free_vars : bool + Whether or not shall we map free vars that does + not bound to any definitions as equal to each other. + + Raises + ------ + ValueError : if assertion does not hold. + + See Also + -------- + structural_equal + """ + tvm.runtime._ffi_node_api.StructuralEqual( + lhs, rhs, True, map_free_vars) diff --git a/src/ir/attr_functor.h b/src/ir/attr_functor.h index babd08a83c6e..9acc4651d089 100644 --- a/src/ir/attr_functor.h +++ b/src/ir/attr_functor.h @@ -45,8 +45,8 @@ class AttrFunctor; #define ATTR_FUNCTOR_DISPATCH(OP) \ vtable.template set_dispatch( \ - [](const ObjectRef& n, TSelf* self, Args... args) { \ - return self->VisitAttr_(static_cast(n.get()), \ + [](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitAttr_(static_cast(n.get()), \ std::forward(args)...); \ }); \ diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 9731a51b5bc2..b07f04aa6974 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -105,6 +105,7 @@ TVM_REGISTER_GLOBAL("ir.FloatImm") TVM_REGISTER_NODE_TYPE(FloatImmNode); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); @@ -143,17 +144,14 @@ TVM_REGISTER_GLOBAL("ir.Range") *ret = Range(args[0], args[1]); }); +TVM_REGISTER_NODE_TYPE(RangeNode); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')'; }); -TVM_REGISTER_NODE_TYPE(ArrayNode); -TVM_REGISTER_NODE_TYPE(MapNode); -TVM_REGISTER_NODE_TYPE(StrMapNode); -TVM_REGISTER_NODE_TYPE(RangeNode); - GlobalVar::GlobalVar(std::string name_hint) { ObjectPtr n = make_object(); diff --git a/src/ir/module.cc b/src/ir/module.cc index 4ac769b38e73..ca85cb853417 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -65,6 +65,21 @@ IRModule::IRModule(tvm::Map functions, data_ = std::move(n); } + +bool IRModuleNode::SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const { + if (functions.size() != other->functions.size()) return false; + for (const auto& kv : this->functions) { + if (!other->ContainGlobalVar(kv.first->name_hint)) return false; + if (!equal(kv.second, other->Lookup(kv.first->name_hint))) return false; + } + if (type_definitions.size() != other->type_definitions.size()) return false; + for (const auto& kv : this->type_definitions) { + if (!other->ContainGlobalTypeVar(kv.first->name_hint)) return false; + if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint))) return false; + } + return true; +} + bool IRModuleNode::ContainGlobalVar(const std::string& name) const { return global_var_map_.find(name) != global_var_map_.end(); } @@ -305,8 +320,8 @@ IRModule IRModule::FromExpr( const tvm::Map& type_definitions) { auto mod = IRModule(global_funcs, type_definitions); BaseFunc func; - if (auto* func_node = expr.as()) { - func = GetRef(func_node); + if (auto* func_node = expr.as()) { + func = GetRef(func_node); } else { func = relay::Function( relay::FreeVars(expr), expr, Type(), diff --git a/src/node/container.cc b/src/node/container.cc index 25bfe9dbba56..fc5c62a25685 100644 --- a/src/node/container.cc +++ b/src/node/container.cc @@ -21,11 +21,98 @@ * \file src/node/container.cc */ #include +#include #include #include +#include namespace tvm { +// SEQualReduce traits for runtime containers. +struct StringObjTrait { + static constexpr const std::nullptr_t VisitAttrs = nullptr; + + static bool SEqualReduce(const runtime::StringObj* lhs, + const runtime::StringObj* rhs, + SEqualReducer equal) { + if (lhs == rhs) return true; + if (lhs->size != rhs->size) return false; + if (lhs->data != rhs->data) return true; + return std::memcmp(lhs->data, rhs->data, lhs->size) != 0; + } +}; + +TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait); + +struct ADTObjTrait { + static constexpr const std::nullptr_t VisitAttrs = nullptr; + + static bool SEqualReduce(const runtime::ADTObj* lhs, + const runtime::ADTObj* rhs, + SEqualReducer equal) { + if (lhs == rhs) return true; + if (lhs->tag != rhs->tag) return false; + if (lhs->size != rhs->size) return false; + + for (uint32_t i = 0; i < lhs->size; ++i) { + if (!equal((*lhs)[i], (*rhs)[i])) return false; + } + return true; + } +}; + +TVM_REGISTER_REFLECTION_VTABLE(runtime::ADTObj, ADTObjTrait); + + +struct NDArrayContainerTrait { + static constexpr const std::nullptr_t VisitAttrs = nullptr; + + static bool SEqualReduce(const runtime::NDArray::Container* lhs, + const runtime::NDArray::Container* rhs, + SEqualReducer equal) { + if (lhs == rhs) return true; + + auto ldt = lhs->dl_tensor.dtype; + auto rdt = rhs->dl_tensor.dtype; + CHECK_EQ(lhs->dl_tensor.ctx.device_type, kDLCPU) << "can only compare CPU tensor"; + CHECK_EQ(rhs->dl_tensor.ctx.device_type, kDLCPU) << "can only compare CPU tensor"; + CHECK(runtime::IsContiguous(lhs->dl_tensor)) + << "Can only compare contiguous tensor"; + CHECK(runtime::IsContiguous(rhs->dl_tensor)) + << "Can only compare contiguous tensor"; + if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) { + size_t data_size = runtime::GetDataSize(lhs->dl_tensor); + return std::memcmp(lhs->dl_tensor.data, rhs->dl_tensor.data, data_size) == 0; + } else { + return false; + } + } +}; + +TVM_REGISTER_REFLECTION_VTABLE(runtime::NDArray::Container, NDArrayContainerTrait); + + +struct ArrayNodeTrait { + static constexpr const std::nullptr_t VisitAttrs = nullptr; + + static bool SEqualReduce(const ArrayNode* lhs, + const ArrayNode* rhs, + SEqualReducer equal) { + if (lhs->data.size() != rhs->data.size()) return false; + for (size_t i = 0; i < lhs->data.size(); ++i) { + if (!equal(lhs->data[i], rhs->data[i])) return false; + } + return true; + } +}; + +TVM_REGISTER_OBJECT_TYPE(ArrayNode); +TVM_REGISTER_REFLECTION_VTABLE(ArrayNode, ArrayNodeTrait) +.set_creator([](const std::string&) -> ObjectPtr { + return ::tvm::runtime::make_object(); + }); + + TVM_REGISTER_GLOBAL("node.Array") .set_body([](TVMArgs args, TVMRetValue* ret) { std::vector data; @@ -62,6 +149,59 @@ TVM_REGISTER_GLOBAL("node.ArraySize") static_cast(ptr)->data.size()); }); + +struct MapNodeTrait { + static constexpr const std::nullptr_t VisitAttrs = nullptr; + + static bool SEqualReduce(const MapNode* lhs, + const MapNode* rhs, + SEqualReducer equal) { + if (rhs->data.size() != lhs->data.size()) return false; + for (const auto& kv : lhs->data) { + // Only allow equal checking if the keys are already mapped + // This resolves common use cases where we want to store + // Map where Var is defined in the function + // parameters. + ObjectRef rhs_key = equal->MapLhsToRhs(kv.first); + if (!rhs_key.defined()) return false; + auto it = rhs->data.find(rhs_key); + if (it == rhs->data.end()) return false; + if (!equal(kv.second, it->second)) return false; + } + return true; + } +}; + +TVM_REGISTER_OBJECT_TYPE(MapNode); +TVM_REGISTER_REFLECTION_VTABLE(MapNode, MapNodeTrait) +.set_creator([](const std::string&) -> ObjectPtr { + return ::tvm::runtime::make_object(); + }); + + +struct StrMapNodeTrait { + static constexpr const std::nullptr_t VisitAttrs = nullptr; + + static bool SEqualReduce(const StrMapNode* lhs, + const StrMapNode* rhs, + SEqualReducer equal) { + if (rhs->data.size() != lhs->data.size()) return false; + for (const auto& kv : lhs->data) { + auto it = rhs->data.find(kv.first); + if (it == rhs->data.end()) return false; + if (!equal(kv.second, it->second)) return false; + } + return true; + } +}; + +TVM_REGISTER_OBJECT_TYPE(StrMapNode); +TVM_REGISTER_REFLECTION_VTABLE(StrMapNode, StrMapNodeTrait) +.set_creator([](const std::string&) -> ObjectPtr { + return ::tvm::runtime::make_object(); + }); + + TVM_REGISTER_GLOBAL("node.Map") .set_body([](TVMArgs args, TVMRetValue* ret) { CHECK_EQ(args.size() % 2, 0); diff --git a/src/node/reflection.cc b/src/node/reflection.cc index 183079ffc82a..824874f24ab0 100644 --- a/src/node/reflection.cc +++ b/src/node/reflection.cc @@ -180,7 +180,7 @@ ObjectPtr ReflectionVTable::CreateInitObject(const std::string& type_key, const std::string& global_key) const { uint32_t tindex = Object::TypeKey2Index(type_key); - if (tindex >= fvisit_attrs_.size() || fvisit_attrs_[tindex] == nullptr) { + if (tindex >= fcreate_.size() || fcreate_[tindex] == nullptr) { LOG(FATAL) << "TypeError: " << type_key << " is not registered via TVM_REGISTER_NODE_TYPE"; } diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc new file mode 100644 index 000000000000..23dfe1502dea --- /dev/null +++ b/src/node/structural_equal.cc @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/node/structural_equal.cc + */ +#include +#include +#include +#include +#include + +#include + +namespace tvm { + +// Define the dispatch functio here since primary user is in this file. +bool ReflectionVTable:: +SEqualReduce(const Object* self, const Object* other, SEqualReducer equal) const { + uint32_t tindex = self->type_index(); + if (tindex >= fsequal_.size() || fsequal_[tindex] == nullptr) { + LOG(FATAL) << "TypeError: SEqualReduce of " << self->GetTypeKey() + << " is not registered via TVM_REGISTER_NODE_TYPE"; + } + return fsequal_[tindex](self, other, equal); +} + +/*! + * \brief A non recursive stack based SEqual handler that can remaps vars. + * + * This handler pushs the Object equality cases into a stack, and + * traverses the stack to expand the necessary children that need to be checked. + * + * The order of SEqual being called is the same as the order as if we + * eagerly do recursive calls in SEqualReduce. + */ +class RemapVarSEqualHandler : + public SEqualReducer::Handler { + public: + explicit RemapVarSEqualHandler(bool assert_mode) + : assert_mode_(assert_mode) {} + + bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) final { + // We cannot use check lhs.same_as(rhs) to check equality. + // if we choose to enable var remapping. + // + // Counter example below (%x, %y) are shared vars + // between the two functions(possibly before/after rewriting). + // + // - function0: fn (%x, %y) { %x + %y } + // - function1. fn (%y, %x) { %x + %y } + // + // Because we choose to enable var remapping, + // %x is mapped to %y, and %y is mapped to %x, + // the body of the function no longer means the same thing. + // + // Take away: We can either choose only compare Var by address, + // in which case we can use same_as for quick checking, + // or we have to run deep comparison and avoid to use same_as checks. + auto run = [=]() { + if (!lhs.defined() && !rhs.defined()) return true; + if (!lhs.defined() && rhs.defined()) return false; + if (!rhs.defined() && lhs.defined()) return false; + if (lhs->type_index() != rhs->type_index()) return false; + auto it = equal_map_lhs_.find(lhs); + if (it != equal_map_lhs_.end()) { + return it->second.same_as(rhs); + } + if (equal_map_rhs_.count(rhs)) return false; + // need to push to pending tasks in this case + pending_tasks_.emplace_back(Task(lhs, rhs, map_free_vars)); + return true; + }; + return CheckResult(run(), lhs, rhs); + } + + void MarkGraphNode() final { + // need to push to pending tasks in this case + CHECK(!allow_push_to_stack_ && !task_stack_.empty()); + task_stack_.back().graph_equal = true; + } + + ObjectRef MapLhsToRhs(const ObjectRef& lhs) final { + auto it = equal_map_lhs_.find(lhs); + if (it != equal_map_lhs_.end()) return it->second; + return ObjectRef(nullptr); + } + + // Function that implements actual equality check. + bool Equal(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) { + task_stack_.clear(); + pending_tasks_.clear(); + equal_map_lhs_.clear(); + equal_map_rhs_.clear(); + if (!SEqualReduce(lhs, rhs, map_free_vars)) return false; + CHECK_EQ(pending_tasks_.size(), 1U); + CHECK(allow_push_to_stack_); + task_stack_.emplace_back(std::move(pending_tasks_.back())); + pending_tasks_.clear(); + return RunTasks(); + } + + protected: + // Check the result. + bool CheckResult(bool result, const ObjectRef& lhs, const ObjectRef& rhs) { + if (assert_mode_ && !result) { + LOG(FATAL) + << "ValueError: StructuralEqual check failed, caused by\n" + << "lhs = " << lhs << "\nrhs = " << rhs; + } + return result; + } + /*! + * \brief Run tasks until the stack reaches the stack begin + * \param stack_begin The expected beginning of the stack. + * \return The checks we encountered throughout the process. + */ + bool RunTasks() { + while (task_stack_.size() != 0) { + // Caution: entry becomes invalid when the stack changes + auto& entry = task_stack_.back(); + + if (entry.children_expanded) { + // When all the children has expanded and visited. + // This means all the condition checks for + // the current entry has been passed + // We can safely mark lhs and rhs as equal to each other. + auto it = equal_map_lhs_.find(entry.lhs); + if (it != equal_map_lhs_.end()) { + CHECK(it->second.same_as(entry.rhs)); + } + // create the map if the quality is graph equal. + if (entry.graph_equal) { + equal_map_lhs_[entry.lhs] = entry.rhs; + equal_map_rhs_[entry.rhs] = entry.lhs; + } + task_stack_.pop_back(); + } else { + // mark before expand + // Important: because entry becomes invalid when stack changes. + entry.children_expanded = true; + // Expand the objects + // The SEqual of the object can call into this->SEqualReduce + // which populates the pending tasks. + CHECK_EQ(pending_tasks_.size(), 0U); + allow_push_to_stack_ = false; + if (!DispatchSEqualReduce(entry.lhs, entry.rhs, entry.map_free_vars)) return false; + allow_push_to_stack_ = true; + // Push pending tasks in reverse order, so earlier tasks get to + // expand first in the stack + while (pending_tasks_.size() != 0) { + task_stack_.emplace_back(std::move(pending_tasks_.back())); + pending_tasks_.pop_back(); + } + } + } + return true; + } + + // The default equal as registered in the structural equal vtable. + bool DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) { + auto compute = [=]() { + CHECK(lhs.defined() && + rhs.defined() && + lhs->type_index() == rhs->type_index()); + // skip entries that already have equality maps. + auto it = equal_map_lhs_.find(lhs); + if (it != equal_map_lhs_.end()) { + return it->second.same_as(rhs); + } + if (equal_map_rhs_.count(rhs)) return false; + // Run reduce check for free nodes. + return vtable_->SEqualReduce(lhs.get(), rhs.get(), SEqualReducer(this, map_free_vars)); + }; + return CheckResult(compute(), lhs, rhs); + } + + private: + /*! \brief Pending reduce tasks. */ + struct Task { + /*! \brief The lhs operand to be compared. */ + ObjectRef lhs; + /*! \brief The rhs operand to be compared. */ + ObjectRef rhs; + /*! \brief The map free var argument. */ + bool map_free_vars; + /*! \brief Whether the children has been expanded via SEqualReduce */ + bool children_expanded{false}; + /*! \brief whether the task is about graph equality(need remap). */ + bool graph_equal{false}; + + Task() = default; + Task(ObjectRef lhs, ObjectRef rhs, bool map_free_vars) + : lhs(lhs), rhs(rhs), map_free_vars(map_free_vars) {} + }; + // list of pending tasks to be pushed to the stack. + std::vector pending_tasks_; + // Internal task stack to executed the task. + std::vector task_stack_; + // Whether we allow push to stack. + bool allow_push_to_stack_{true}; + // If in assert mode, must return true, and will throw error otherwise. + bool assert_mode_{false}; + // reflection vtable + ReflectionVTable* vtable_ = ReflectionVTable::Global(); + // map from lhs to rhs + std::unordered_map equal_map_lhs_; + // map from rhs to lhs + std::unordered_map equal_map_rhs_; +}; + + +TVM_REGISTER_GLOBAL("node.StructuralEqual") +.set_body_typed([](const ObjectRef& lhs, + const ObjectRef& rhs, + bool assert_mode, + bool map_free_vars) { + return RemapVarSEqualHandler(assert_mode).Equal(lhs, rhs, map_free_vars); +}); + +bool StructuralEqual::operator()(const ObjectRef& lhs, + const ObjectRef& rhs) const { + return RemapVarSEqualHandler(false).Equal(lhs, rhs, false); +} + +} // namespace tvm diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 07759b3f126e..bee025687173 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -81,7 +81,8 @@ TVM_REGISTER_GLOBAL("tir.Var") TVM_REGISTER_GLOBAL("tir.SizeVar") .set_body_typed([](std::string s, DataType t) { return SizeVar(s, t); - }); +}); + IterVar IterVarNode::make(Range dom, Var var, @@ -132,6 +133,7 @@ PrimExpr StringImmNode::make(std::string value) { TVM_REGISTER_GLOBAL("tir.StringImm") .set_body_typed(StringImmNode::make); + PrimExpr CastNode::make(DataType t, PrimExpr value) { CHECK(value.defined()); CHECK_EQ(t.lanes(), value.dtype().lanes()); @@ -141,6 +143,7 @@ PrimExpr CastNode::make(DataType t, PrimExpr value) { return PrimExpr(node); } + PrimExpr AndNode::make(PrimExpr a, PrimExpr b) { CHECK(a.defined()) << "ValueError: a is undefined"; CHECK(b.defined()) << "ValueError: b is undefined"; @@ -169,6 +172,7 @@ PrimExpr OrNode::make(PrimExpr a, PrimExpr b) { return PrimExpr(node); } + PrimExpr NotNode::make(PrimExpr a) { CHECK(a.defined()) << "ValueError: a is undefined"; CHECK(a.dtype().is_bool()); @@ -179,6 +183,8 @@ PrimExpr NotNode::make(PrimExpr a) { return PrimExpr(node); } + + PrimExpr SelectNode::make(PrimExpr condition, PrimExpr true_value, PrimExpr false_value) { CHECK(condition.defined()) << "ValueError: condition is undefined"; CHECK(true_value.defined()) << "ValueError: true_value is undefined"; @@ -270,11 +276,11 @@ bool CallNode::is_vectorizable() const { } PrimExpr CallNode::make(DataType dtype, - std::string name, - Array args, - CallType call_type, - FunctionRef func, - int value_index) { + std::string name, + Array args, + CallType call_type, + FunctionRef func, + int value_index) { for (size_t i = 0; i < args.size(); ++i) { CHECK(args[i].defined()); } diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 9d875c17a198..b1efe4a8c26f 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -1114,7 +1114,7 @@ def test_read_variable_op(): num_output=len(out_name)) for i in range(len(tf_output)): tvm.testing.assert_allclose( - tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5) + tf_output[i], tvm_output[i], atol=1e-4, rtol=1e-5) sess.close() diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index fbe521340930..9e624917ab1a 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -17,8 +17,6 @@ import tvm from tvm import te from tvm import relay -from tvm.relay.analysis import graph_equal, assert_graph_equal -from tvm.relay.analysis import alpha_equal, assert_alpha_equal import pytest from numpy import isclose from typing import Union @@ -69,6 +67,13 @@ } """ +def assert_graph_equal(lhs, rhs): + tvm.ir.assert_structural_equal(lhs, rhs, map_free_vars=True) + +def graph_equal(lhs, rhs): + return tvm.ir.structural_equal(lhs, rhs, map_free_vars=True) + + def roundtrip(expr): x = relay.fromtext(expr.astext()) assert_graph_equal(x, expr) @@ -86,6 +91,12 @@ def parses_as(code, expr): result = graph_equal(parsed, expr) return result + +def assert_parses_as(code, expr): + parsed = parse_text(code) + assert_graph_equal(parsed, expr) + + def get_scalar(x): # type: (relay.Constant) -> (Union[float, int, bool]) return x.data.asnumpy().item() @@ -102,7 +113,7 @@ def get_scalar(x): def test_comments(): - assert parses_as( + assert_parses_as( """ // This is a line comment! () @@ -110,7 +121,7 @@ def test_comments(): UNIT ) - assert parses_as( + assert_parses_as( """ /* This is a block comment! This is still a block comment! @@ -120,7 +131,7 @@ def test_comments(): UNIT ) - assert parses_as( + assert_parses_as( """ /* This is a block comment! /*Block comment is recursive!*/ @@ -172,7 +183,7 @@ def test_negative(): def test_bin_op(): for bin_op in BINARY_OPS.keys(): - assert parses_as( + assert_parses_as( "1 {} 1".format(bin_op), BINARY_OPS.get(bin_op)(relay.const(1), relay.const(1)) ) @@ -213,7 +224,7 @@ def test_vars(): def test_let(): - assert parses_as( + assert_parses_as( "let %x = 1; ()", relay.Let( X, @@ -222,7 +233,7 @@ def test_let(): ) ) - assert parses_as( + assert_parses_as( """ let %x = 1; let %y = 2; @@ -241,7 +252,7 @@ def test_let(): def test_seq(): - assert parses_as( + assert_parses_as( "();; ()", relay.Let( _, @@ -249,7 +260,7 @@ def test_seq(): UNIT) ) - assert parses_as( + assert_parses_as( "let %_ = 1; ()", relay.Let( X, @@ -261,14 +272,10 @@ def test_seq(): def test_graph(): code = "%0 = (); %1 = 1; (%0, %0, %1)" - assert parses_as( + assert_parses_as( code, relay.Tuple([UNIT, UNIT, relay.const(1)]) ) - assert not parses_as( - code, - relay.Tuple([relay.Tuple([]), relay.Tuple([]), relay.const(1)]) - ) @raises_parse_error @@ -287,18 +294,18 @@ def test_let_op(): def test_tuple(): - assert parses_as("()", relay.Tuple([])) + assert_parses_as("()", relay.Tuple([])) - assert parses_as("(0,)", relay.Tuple([relay.const(0)])) + assert_parses_as("(0,)", relay.Tuple([relay.const(0)])) - assert parses_as("(0, 1)", relay.Tuple([relay.const(0), relay.const(1)])) + assert_parses_as("(0, 1)", relay.Tuple([relay.const(0), relay.const(1)])) - assert parses_as("(0, 1, 2)", relay.Tuple([relay.const(0), relay.const(1), relay.const(2)])) + assert_parses_as("(0, 1, 2)", relay.Tuple([relay.const(0), relay.const(1), relay.const(2)])) def test_func(): # 0 args - assert parses_as( + assert_parses_as( "fn () { 0 }", relay.Function( [], @@ -309,7 +316,7 @@ def test_func(): ) # 1 arg - assert parses_as( + assert_parses_as( "fn (%x) { %x }", relay.Function( [X], @@ -320,7 +327,7 @@ def test_func(): ) # 2 args - assert parses_as( + assert_parses_as( "fn (%x, %y) { %x + %y }", relay.Function( [X, Y], @@ -331,7 +338,7 @@ def test_func(): ) # annotations - assert parses_as( + assert_parses_as( "fn (%x: int32) -> int32 { %x }", relay.Function( [X_ANNO], @@ -342,7 +349,7 @@ def test_func(): ) # attributes - assert parses_as( + assert_parses_as( "fn (n=5) { () }", relay.Function([], UNIT, None, None, tvm.ir.make_node("DictAttrs", n=relay.const(5))) ) @@ -370,7 +377,7 @@ def @id(%x: int32) -> int32 { def test_ifelse(): - assert parses_as( + assert_parses_as( """ if (True) { 0 @@ -403,7 +410,7 @@ def test_ifelse_scope(): def test_call(): # select right function to call: simple ident case id_func = relay.Var("id") - assert parses_as( + assert_parses_as( """ let %id = fn (%x) { %x }; 10 * %id(10) @@ -417,7 +424,7 @@ def test_call(): # 0 args constant = relay.Var("constant") - assert parses_as( + assert_parses_as( """ let %constant = fn () { 0 }; %constant() @@ -431,7 +438,7 @@ def test_call(): # 1 arg id_var = relay.Var("id") - assert parses_as( + assert_parses_as( """ let %id = fn (%x) { %x }; %id(1) @@ -445,7 +452,7 @@ def test_call(): # 2 args multiply = relay.Var("multiply") - assert parses_as( + assert_parses_as( """ let %multiply = fn (%x, %y) { %x * %y }; %multiply(0, 0) @@ -463,7 +470,7 @@ def test_call(): ) # anonymous function - assert parses_as( + assert_parses_as( """ (fn (%x) { %x })(0) """, @@ -483,7 +490,7 @@ def test_call(): # TODO(@jmp): re-enable after sequence parsing improvements # curried function # curried_mult = relay.Var("curried_mult") - # assert parses_as( + # assert_parses_as( # """ # let %curried_mult = # fn (%x) { @@ -516,7 +523,7 @@ def test_call(): # ) # op - assert parses_as( + assert_parses_as( "abs(1)", relay.Call(relay.op.get("abs"), [relay.const(1)], None, None) ) @@ -525,7 +532,7 @@ def test_call(): def test_incomplete_type(): - assert parses_as( + assert_parses_as( "let %_ : _ = (); ()", relay.Let( _, @@ -541,7 +548,7 @@ def test_builtin_types(): def test_tensor_type(): - assert parses_as( + assert_parses_as( "let %_ : Tensor[(), float32] = (); ()", relay.Let( relay.Var("_", relay.TensorType((), "float32")), @@ -550,7 +557,7 @@ def test_tensor_type(): ) ) - assert parses_as( + assert_parses_as( "let %_ : Tensor[(1), float32] = (); ()", relay.Let( relay.Var("_", relay.TensorType((1,), "float32")), @@ -559,7 +566,7 @@ def test_tensor_type(): ) ) - assert parses_as( + assert_parses_as( "let %_ : Tensor[(1, 1), float32] = (); ()", relay.Let( relay.Var("_", relay.TensorType((1, 1), "float32")), @@ -570,7 +577,7 @@ def test_tensor_type(): def test_function_type(): - assert parses_as( + assert_parses_as( """ let %_: fn () -> int32 = fn () -> int32 { 0 }; () """, @@ -581,7 +588,7 @@ def test_function_type(): ) ) - assert parses_as( + assert_parses_as( """ let %_: fn (int32) -> int32 = fn (%x: int32) -> int32 { 0 }; () """, @@ -592,7 +599,7 @@ def test_function_type(): ) ) - assert parses_as( + assert_parses_as( """ let %_: fn (int32, int32) -> int32 = fn (%x: int32, %y: int32) -> int32 { 0 }; () """, @@ -605,7 +612,7 @@ def test_function_type(): def test_tuple_type(): - assert parses_as( + assert_parses_as( """ let %_: () = (); () """, @@ -616,7 +623,7 @@ def test_tuple_type(): ) ) - assert parses_as( + assert_parses_as( """ let %_: (int32,) = (0,); () """, @@ -627,7 +634,7 @@ def test_tuple_type(): ) ) - assert parses_as( + assert_parses_as( """ let %_: (int32, int32) = (0, 1); () """, @@ -648,7 +655,7 @@ def test_adt_defn(): [], [relay.Constructor("Nil", [], glob_typ_var)]) mod[glob_typ_var] = prog - assert parses_as( + assert_parses_as( """ type Ayy { Nil } """, @@ -662,7 +669,7 @@ def test_empty_adt_defn(): glob_typ_var = relay.GlobalTypeVar("Ayy") prog = relay.TypeData(glob_typ_var, [], []) mod[glob_typ_var] = prog - assert parses_as( + assert_parses_as( """ type Ayy { } """, @@ -683,7 +690,7 @@ def test_multiple_cons_defn(): relay.Constructor("Nil", [], list_var), ]) mod[list_var] = prog - assert parses_as(LIST_DEFN, mod) + assert_parses_as(LIST_DEFN, mod) def test_multiple_type_param_defn(): @@ -699,7 +706,7 @@ def test_multiple_type_param_defn(): ]) mod = tvm.IRModule() mod[glob_typ_var] = prog - assert parses_as( + assert_parses_as( """ type Either[A, B] { Left(A), @@ -755,7 +762,7 @@ def test_match(): ) mod[length_var] = length_func - assert parses_as( + assert_parses_as( """ %s @@ -796,7 +803,7 @@ def test_adt_cons_expr(): ) mod[make_singleton_var] = make_singleton_func - assert parses_as( + assert_parses_as( """ %s @@ -861,7 +868,7 @@ def test_extern_adt_defn(): extern_def = relay.TypeData(extern_var, [typ_var], []) mod[extern_var] = extern_def - assert parses_as( + assert_parses_as( """ extern type T[A] """, @@ -872,6 +879,7 @@ def test_import_grad(): mod.import_from_std("gradient.rly") if __name__ == "__main__": + test_graph() test_comments() test_int_literal() test_float_literal() @@ -882,7 +890,6 @@ def test_import_grad(): test_op_assoc() test_let() test_seq() - test_graph() test_tuple() test_func() test_defn() @@ -905,4 +912,4 @@ def test_import_grad(): test_duplicate_adt_cons_defn() test_duplicate_global_var() test_extern_adt_defn() - test_import_grad() \ No newline at end of file + test_import_grad() diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_ir_structural_equal.py similarity index 78% rename from tests/python/relay/test_pass_alpha_equal.py rename to tests/python/relay/test_ir_structural_equal.py index 411906dbf83f..5881ab9d178c 100644 --- a/tests/python/relay/test_pass_alpha_equal.py +++ b/tests/python/relay/test_ir_structural_equal.py @@ -21,23 +21,24 @@ from tvm.relay import analysis from tvm.relay.testing import run_opt_pass -def alpha_equal(x, y): +def sequal(x, y): """ Wrapper around alpha equality which ensures that the hash function respects equality. """ - return analysis.alpha_equal(x, y) and analysis.structural_hash(x) == analysis.structural_hash(y) + return (tvm.ir.structural_equal(x, y) and + analysis.structural_hash(x) == analysis.structural_hash(y)) -def alpha_equal_commutative(x, y): +def sequal_commutative(x, y): """ Check for commutative property of equality """ - xy = analysis.alpha_equal(x, y) - yx = analysis.alpha_equal(y, x) + xy = tvm.ir.structural_equal(x, y) + yx = tvm.ir.structural_equal(y, x) assert xy == yx return xy -def test_tensor_type_alpha_equal(): +def test_tensor_type_sequal(): t1 = relay.TensorType((3, 4), "float32") t2 = relay.TensorType((3, 4), "float32") t3 = relay.TensorType((3, 4, 5), "float32") @@ -49,7 +50,7 @@ def test_tensor_type_alpha_equal(): assert t1 == t2 -def test_incomplete_type_alpha_equal(): +def test_incomplete_type_sequal(): t1 = relay.IncompleteType(relay.TypeKind.ShapeVar) t2 = relay.IncompleteType(relay.TypeKind.Type) t3 = relay.IncompleteType(relay.TypeKind.Type) @@ -61,7 +62,7 @@ def test_incomplete_type_alpha_equal(): assert t2 != t3 -def test_type_param_alpha_equal(): +def test_type_param_sequal(): t1 = relay.TypeVar("v1", relay.TypeKind.Type) t2 = relay.TypeVar("v2", relay.TypeKind.ShapeVar) t3 = relay.TypeVar("v3", relay.TypeKind.Type) @@ -83,7 +84,7 @@ def test_type_param_alpha_equal(): assert ft1 != ft3 # kinds still do not match -def test_func_type_alpha_equal(): +def test_func_type_sequal(): t1 = relay.TensorType((1, 2), "float32") t2 = relay.TensorType((1, 2, 3), "float32") @@ -143,7 +144,7 @@ def test_func_type_alpha_equal(): assert ft != more_rels -def test_tuple_type_alpha_equal(): +def test_tuple_type_sequal(): t1 = relay.TensorType((1, 2, 3), "float32") t2 = relay.TensorType((1, 2, 3, 4), "float32") tp1 = relay.TypeVar("v1", relay.TypeKind.Type) @@ -161,7 +162,7 @@ def test_tuple_type_alpha_equal(): assert tup1 != tup4 -def test_type_relation_alpha_equal(): +def test_type_relation_sequal(): t1 = relay.TensorType((1, 2), "float32") t2 = relay.TensorType((1, 2, 3), "float32") t3 = relay.TensorType((1, 2, 3, 4), "float32") @@ -197,7 +198,7 @@ def test_type_relation_alpha_equal(): assert bigger != diff_num_inputs -def test_type_call_alpha_equal(): +def test_type_call_sequal(): h1 = relay.GlobalTypeVar("h1") h2 = relay.GlobalTypeVar("h2") t1 = relay.TensorType((1, 2), "float32") @@ -221,49 +222,49 @@ def test_type_call_alpha_equal(): assert tc != different_order_args -def test_constant_alpha_equal(): +def test_constant_sequal(): x = relay.const(1) y = relay.const(2) - assert alpha_equal(x, x) - assert not alpha_equal(x, y) - assert alpha_equal(x, relay.const(1)) + assert sequal(x, x) + assert not sequal(x, y) + assert sequal(x, relay.const(1)) -def test_type_node_alpha_equal(): +def test_type_node_sequal(): v1 = relay.TypeVar('v1', 6) v2 = relay.TypeVar('v2', 6) - assert not alpha_equal(v1, v2) + assert not sequal(v1, v2) v1 = relay.TypeVar('v1', 0) v2 = relay.TypeVar('v2', 6) - assert not alpha_equal(v1, v2) + assert not sequal(v1, v2) - assert alpha_equal_commutative(v1, v1) + assert sequal_commutative(v1, v1) -def test_type_node_incompatible_alpha_equal(): +def test_type_node_incompatible_sequal(): v1 = relay.TypeVar('v1', 6) v2 = relay.Var("v2") - assert not alpha_equal_commutative(v1, v2) + assert not sequal_commutative(v1, v2) -def test_expr_node_incompatible_alpha_equal(): +def test_expr_node_incompatible_sequal(): v1 = relay.Var("v1") v2 = relay.PatternVar(relay.Var("v2")) - assert not alpha_equal_commutative(v1, v2) + assert not sequal_commutative(v1, v2) -def test_var_alpha_equal(): +def test_var_sequal(): v1 = relay.Var("v1") v2 = relay.Var("v2") # normally only pointer equality - assert alpha_equal(v1, v1) - assert not alpha_equal(v1, v2) + assert sequal(v1, v1) + assert not sequal(v1, v2) # let node allows for setting the eq_map l1 = relay.Let(v1, relay.const(1), v1) l2 = relay.Let(v2, relay.const(1), v2) l3 = relay.Let(v1, relay.const(1), v2) - assert alpha_equal(l1, l2) - assert not alpha_equal(l1, l3) + assert sequal(l1, l2) + assert not sequal(l1, l3) # type annotations tt1 = relay.TensorType([], "int32") @@ -278,34 +279,34 @@ def test_var_alpha_equal(): l6 = relay.Let(v5, relay.const(1), v5) # same annotations - assert alpha_equal(l4, l5) + assert sequal(l4, l5) # different annotations - assert not alpha_equal(l4, l6) + assert not sequal(l4, l6) # one null annotation - assert not alpha_equal(l1, l4) + assert not sequal(l1, l4) -def test_global_var_alpha_equal(): +def test_global_var_sequal(): v1 = relay.GlobalVar("v1") v2 = relay.GlobalVar("v2") # only pointer equality suffices (smoke test) - assert alpha_equal(v1, v1) - assert not alpha_equal(v1, v2) + assert sequal(v1, v1) + assert not sequal(v1, v2) -def test_tuple_alpha_equal(): +def test_tuple_sequal(): v0 = relay.Var("v0") v1 = relay.Var("v1") v2 = relay.Var("v2") # unit value is a valid tuple - assert alpha_equal(relay.Tuple([]), relay.Tuple([])) + assert sequal(relay.Tuple([]), relay.Tuple([])) tup = relay.Tuple([v0, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)])]) same = relay.Tuple([v0, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)])]) - assert alpha_equal(tup, same) + assert sequal(tup, same) # use the eq_map @@ -315,33 +316,33 @@ def test_tuple_alpha_equal(): relay.Tuple([relay.const(4)])]), v2) - assert alpha_equal(let_tup, let_mapped) + assert sequal(let_tup, let_mapped) more_fields = relay.Tuple([v1, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)]), v2]) - assert not alpha_equal(tup, more_fields) + assert not sequal(tup, more_fields) fewer_fields = relay.Tuple([v1, relay.const(2), relay.const(3)]) - assert not alpha_equal(tup, fewer_fields) + assert not sequal(tup, fewer_fields) different_end = relay.Tuple([v1, relay.const(2), relay.const(3), relay.Tuple([relay.const(5)])]) - assert not alpha_equal(tup, different_end) + assert not sequal(tup, different_end) different_start = relay.Tuple([v2, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)])]) - assert not alpha_equal(tup, different_start) + assert not sequal(tup, different_start) longer_at_end = relay.Tuple([v1, relay.const(2), relay.const(3), relay.Tuple([relay.const(4), relay.const(5)])]) - assert not alpha_equal(tup, longer_at_end) + assert not sequal(tup, longer_at_end) -def test_tuple_get_item_alpha_equal(): +def test_tuple_get_item_sequal(): x = relay.Var('x') y = relay.Var('y') - assert not alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(y, 1)) - assert not alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 2)) - assert alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 1)) + assert not sequal(relay.TupleGetItem(x, 1), relay.TupleGetItem(y, 1)) + assert not sequal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 2)) + assert sequal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 1)) def test_function_attr(): @@ -364,10 +365,10 @@ def test_function_attr(): q10 = relay.multiply(p10, w12) func1 = relay.Function([x1, w10, w11, w12], q10) func1 = func1.with_attr("FuncName", tvm.tir.StringImm("b")) - assert not alpha_equal(func0, func1) + assert not sequal(func0, func1) -def test_function_alpha_equal(): +def test_function_sequal(): tt1 = relay.TensorType((1, 2, 3), "float32") tt2 = relay.TensorType((4, 5, 6), "int8") tt3 = relay.TupleType([tt1, tt2]) @@ -389,58 +390,58 @@ def test_function_alpha_equal(): func = relay.Function([v1, v2], v1, tt2, basic_tps) mapped = relay.Function(basic_args, basic_args[0], tt2, basic_tps) - assert alpha_equal(func, mapped) + assert sequal(func, mapped) fewer_params = relay.Function([relay.Var("v4", tt2)], v4, tt2, basic_tps) - assert not alpha_equal(func, fewer_params) + assert not sequal(func, fewer_params) more_params = relay.Function([relay.Var("v3", tt1), relay.Var("v4", tt2), relay.Var("v2", tt2)], v4, tt2, basic_tps) - assert not alpha_equal(func, more_params) + assert not sequal(func, more_params) params_unordered = relay.Function([v2, v1], v1, tt2, basic_tps) - assert not alpha_equal(func, params_unordered) + assert not sequal(func, params_unordered) params_mismatch = relay.Function([v1, v3], v1, tt2, basic_tps) - assert not alpha_equal(func, params_mismatch) + assert not sequal(func, params_mismatch) # also would not typecheck ret_type_mismatch = relay.Function(basic_args, v4, tt1, basic_tps) - assert not alpha_equal(func, ret_type_mismatch) + assert not sequal(func, ret_type_mismatch) # also mis-typed different_body = relay.Function(basic_args, v3, tt2, basic_tps) - assert not alpha_equal(func, different_body) + assert not sequal(func, different_body) fewer_type_params = relay.Function(basic_args, v4, tt2, [tp1]) - assert not alpha_equal(func, fewer_type_params) + assert not sequal(func, fewer_type_params) more_type_params = relay.Function(basic_args, v4, tt2, [tp1, tp2, tp3]) - assert not alpha_equal(func, more_type_params) + assert not sequal(func, more_type_params) type_params_unordered = relay.Function(basic_args, v4, tt2, [tp2, tp1]) - assert not alpha_equal(func, type_params_unordered) + assert not sequal(func, type_params_unordered) different_type_params = relay.Function(basic_args, v4, tt2, [tp3, tp4]) - assert not alpha_equal(func, different_type_params) + assert not sequal(func, different_type_params) # a well-typed example that also differs in body, ret type, and type params tupled_example = relay.Function(basic_args, relay.Tuple([v3, v4]), tt3) - assert not alpha_equal(func, tupled_example) + assert not sequal(func, tupled_example) # nullable no_ret_type = relay.Function(basic_args, v4, None, [tp1, tp2]) # both null - assert alpha_equal(no_ret_type, no_ret_type) + assert sequal(no_ret_type, no_ret_type) # one null - assert not alpha_equal(func, no_ret_type) - assert not alpha_equal(no_ret_type, func) + assert not sequal(func, no_ret_type) + assert not sequal(no_ret_type, func) -def test_call_alpha_equal(): +def test_call_sequal(): v1 = relay.Var("v1") v2 = relay.Var("v2") @@ -458,43 +459,43 @@ def test_call_alpha_equal(): call = relay.Call(v1, [relay.const(1), relay.const(2), v2, relay.Tuple([])], attr1, [tt1]) same = relay.Call(v1, basic_args, attr1, [tt1]) - assert alpha_equal(call, same) + assert sequal(call, same) different_fn = relay.Call(v2, basic_args, attr1, [tt1]) - assert not alpha_equal(call, different_fn) + assert not sequal(call, different_fn) fewer_args = relay.Call(v1, [relay.const(1), relay.const(2), v2], attr1, [tt1]) - assert not alpha_equal(call, fewer_args) + assert not sequal(call, fewer_args) reordered_args = relay.Call(v1, [relay.const(2), relay.const(1), relay.Tuple([]), v2], attr1, [tt1]) - assert not alpha_equal(call, reordered_args) + assert not sequal(call, reordered_args) different_args = relay.Call(v1, [relay.const(1), relay.const(2), relay.const(3)], attr1, [tt1]) - assert not alpha_equal(call, different_args) + assert not sequal(call, different_args) more_args = relay.Call(v1, [relay.const(1), relay.const(2), v2, relay.Tuple([]), relay.const(3), relay.const(4)], attr1, [tt1]) - assert not alpha_equal(call, more_args) + assert not sequal(call, more_args) different_attrs = relay.Call(v1, basic_args, attr2, [tt1]) - assert not alpha_equal(call, different_attrs) + assert not sequal(call, different_attrs) same_attrs = relay.Call(v1, basic_args, attr1_same, [tt1]) - assert alpha_equal(call, same_attrs) + assert sequal(call, same_attrs) no_type_args = relay.Call(v1, basic_args, attr1) - assert not alpha_equal(call, no_type_args) + assert not sequal(call, no_type_args) more_type_args = relay.Call(v1, basic_args, attr1, [tt1, tt2]) - assert not alpha_equal(call, more_type_args) + assert not sequal(call, more_type_args) different_type_arg = relay.Call(v1, basic_args, attr1, [tt2]) - assert not alpha_equal(call, different_type_arg) + assert not sequal(call, different_type_arg) -def test_let_alpha_equal(): +def test_let_sequal(): tt1 = relay.TensorType((), "float32") tt2 = relay.TensorType((), "int8") v1 = relay.Var("v1") @@ -504,57 +505,57 @@ def test_let_alpha_equal(): let = relay.Let(v1, relay.const(2), v1) mapped = relay.Let(v2, relay.const(2), v2) - assert alpha_equal(let, mapped) + assert sequal(let, mapped) mismatched_var = relay.Let(v2, relay.const(2), v3) - assert not alpha_equal(let, mismatched_var) + assert not sequal(let, mismatched_var) different_value = relay.Let(v2, relay.const(3), v2) - assert not alpha_equal(let, different_value) + assert not sequal(let, different_value) different_body = relay.Let(v2, relay.const(3), relay.const(12)) - assert not alpha_equal(let, different_body) + assert not sequal(let, different_body) # specified types must match let_with_type = relay.Let(v1_wtype, relay.const(2), v1_wtype) same_type = relay.Let(v1_wtype, relay.const(2), v1_wtype) - assert alpha_equal(let_with_type, same_type) - assert not alpha_equal(let, let_with_type) + assert sequal(let_with_type, same_type) + assert not sequal(let, let_with_type) v2 = relay.Var("v1", tt2) different_type = relay.Let(v2, relay.const(2), v2) - assert not alpha_equal(let_with_type, different_type) + assert not sequal(let_with_type, different_type) -def test_if_alpha_equal(): +def test_if_sequal(): v1 = relay.Var("v1") v2 = relay.Var("v2") if_sample = relay.If(v1, relay.const(1), relay.Tuple([relay.const(2), relay.const(3)])) same = relay.If(v1, relay.const(1), relay.Tuple([relay.const(2), relay.const(3)])) - assert alpha_equal(if_sample, same) + assert sequal(if_sample, same) different_cond = relay.If(v2, relay.const(1), relay.Tuple([relay.const(2), relay.const(3)])) - assert not alpha_equal(if_sample, different_cond) + assert not sequal(if_sample, different_cond) different_true = relay.If(v1, relay.const(2), relay.Tuple([relay.const(2), relay.const(3)])) - assert not alpha_equal(if_sample, different_true) + assert not sequal(if_sample, different_true) different_false = relay.If(v1, relay.const(1), relay.Tuple([])) - assert not alpha_equal(if_sample, different_false) + assert not sequal(if_sample, different_false) -def test_constructor_alpha_equal(): +def test_constructor_sequal(): # smoke test: it should be pointer equality mod = tvm.IRModule() p = relay.prelude.Prelude(mod) - assert alpha_equal(p.nil, p.nil) - assert alpha_equal(p.cons, p.cons) - assert not alpha_equal(p.nil, p.cons) + assert sequal(p.nil, p.nil) + assert sequal(p.cons, p.cons) + assert not sequal(p.nil, p.cons) -def test_match_alpha_equal(): +def test_match_sequal(): mod = tvm.IRModule() p = relay.prelude.Prelude(mod) @@ -604,27 +605,28 @@ def test_match_alpha_equal(): p.cons(x, p.nil())) ]) - assert alpha_equal(match, match) - assert alpha_equal(match, equivalent) - assert not alpha_equal(match, no_cons) - assert not alpha_equal(match, no_nil) - assert not alpha_equal(match, empty) - assert not alpha_equal(match, different_data) - assert not alpha_equal(match, different_order) - assert not alpha_equal(match, different_nil) - assert not alpha_equal(match, different_cons) - assert not alpha_equal(match, another_case) - assert not alpha_equal(match, wrong_constructors) - - -def test_op_alpha_equal(): + tvm.ir.assert_structural_equal(match, match) + assert sequal(match, match) + assert sequal(match, equivalent) + assert not sequal(match, no_cons) + assert not sequal(match, no_nil) + assert not sequal(match, empty) + assert not sequal(match, different_data) + assert not sequal(match, different_order) + assert not sequal(match, different_nil) + assert not sequal(match, different_cons) + assert not sequal(match, another_case) + assert not sequal(match, wrong_constructors) + + +def test_op_sequal(): # only checks names op1 = relay.op.get("add") op2 = relay.op.get("add") - assert alpha_equal(op1, op2) + assert sequal(op1, op2) op3 = relay.op.get("take") - assert not alpha_equal(op1, op3) + assert not sequal(op1, op3) def test_graph_equal(): @@ -638,14 +640,14 @@ def test_graph_equal(): z3 = relay.add(relay.add(x, x), relay.add(x, x)) - assert alpha_equal(z0, z1) - assert alpha_equal(z0, z1) + assert sequal(z0, z1) + assert sequal(z0, z1) # z3's dataflow format is different from z0 # z0 is computed from a common y0 node # Relay view them as different programs # Check the difference in the text format. - assert not alpha_equal(z0, z3) + assert not sequal(z0, z3) def test_hash_unequal(): x1 = relay.var("x1", shape=(10, 10), dtype="float32") @@ -677,7 +679,7 @@ def test_tuple_match(): b = relay.Var("b") clause = relay.Clause(relay.PatternTuple([relay.PatternVar(a), relay.PatternVar(b)]), a + b) y = relay.Match(relay.Tuple([relay.const(1), relay.const(1)]), [clause]) - assert analysis.alpha_equal(x, y) + assert sequal(x, y) assert analysis.structural_hash(x) == analysis.structural_hash(y) @@ -697,34 +699,34 @@ def test_fn_attribute(): add_1_fn = add_1_fn.with_attr("TestAttribute", tvm.tir.StringImm("test")) add_1_fn = run_opt_pass(add_1_fn, relay.transform.InferType()) - assert not relay.analysis.alpha_equal(add_1_fn, add_fn) - assert not relay.analysis.alpha_equal(add_fn, add_1_fn) + assert not sequal(add_1_fn, add_fn) + assert not sequal(add_fn, add_1_fn) if __name__ == "__main__": - test_tensor_type_alpha_equal() - test_incomplete_type_alpha_equal() - test_constant_alpha_equal() - test_type_node_alpha_equal() - test_type_node_incompatible_alpha_equal() - test_expr_node_incompatible_alpha_equal() - test_func_type_alpha_equal() - test_tuple_type_alpha_equal() - test_type_relation_alpha_equal() - test_type_call_alpha_equal() - test_constant_alpha_equal() - test_global_var_alpha_equal() - test_tuple_alpha_equal() - test_tuple_get_item_alpha_equal() - test_function_alpha_equal() + test_tensor_type_sequal() + test_incomplete_type_sequal() + test_constant_sequal() + test_type_node_sequal() + test_type_node_incompatible_sequal() + test_expr_node_incompatible_sequal() + test_func_type_sequal() + test_tuple_type_sequal() + test_type_relation_sequal() + test_type_call_sequal() + test_constant_sequal() + test_global_var_sequal() + test_tuple_sequal() + test_tuple_get_item_sequal() + test_function_sequal() test_function_attr() - test_call_alpha_equal() - test_let_alpha_equal() - test_if_alpha_equal() - test_constructor_alpha_equal() - test_match_alpha_equal() - test_op_alpha_equal() - test_var_alpha_equal() + test_call_sequal() + test_let_sequal() + test_if_sequal() + test_constructor_sequal() + test_match_sequal() + test_op_sequal() + test_var_sequal() test_graph_equal() test_hash_unequal() test_fn_attribute() diff --git a/tests/python/relay/test_pass_dead_code_elimination.py b/tests/python/relay/test_pass_dead_code_elimination.py index 604ec8969ef7..3a0bf1feccc2 100644 --- a/tests/python/relay/test_pass_dead_code_elimination.py +++ b/tests/python/relay/test_pass_dead_code_elimination.py @@ -57,14 +57,14 @@ def run_opt_pass(expr, opt_pass): def test_let(): orig = relay.Let(e.x, e.y, e.z) orig = run_opt_pass(orig, transform.DeadCodeElimination()) - assert alpha_equal(Function(free_vars(orig), orig), Function([e.z], e.z)) + assert tvm.ir.structural_equal(Function(free_vars(orig), orig), Function([e.z], e.z)) def test_used_let(): orig = relay.Let(e.c, e.one, e.c + e.c) orig = run_opt_pass(orig, transform.DeadCodeElimination()) expected = relay.Let(e.c, e.one, e.c + e.c) - assert alpha_equal(Function([e.c], orig), Function([e.c], expected)) + assert tvm.ir.structural_equal(Function([], orig), Function([], expected)) def test_inline(): orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c)) @@ -75,7 +75,7 @@ def test_inline(): def test_chain_unused_let(): orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.e)) orig = run_opt_pass(orig, transform.DeadCodeElimination()) - assert alpha_equal(Function(free_vars(orig), orig), Function([e.e], e.e)) + assert tvm.ir.structural_equal(Function(free_vars(orig), orig), Function([e.e], e.e)) def use_f(func): @@ -111,13 +111,13 @@ def test_recursion_dead(): x = relay.Let(e.a, e.one, e.three) dced_f = lambda f: x dced = run_opt_pass(use_f(dced_f), transform.DeadCodeElimination()) - assert alpha_equal(dced, e.three) + assert tvm.ir.structural_equal(dced, e.three) def test_op_let(): dced = run_opt_pass(add(relay.Let(e.a, e.one, e.three), e.two), transform.DeadCodeElimination()) - assert alpha_equal(dced, add(e.three, e.two)) + assert tvm.ir.structural_equal(dced, add(e.three, e.two)) def test_tuple_get_item(): @@ -126,10 +126,10 @@ def test_tuple_get_item(): a = relay.Var('a') g = relay.TupleGetItem(t, 0) dced = run_opt_pass(g, transform.DeadCodeElimination()) - assert alpha_equal(Function(free_vars(dced), dced), Function(free_vars(g), g)) + assert tvm.ir.structural_equal(Function(free_vars(dced), dced), Function(free_vars(g), g)) orig = relay.TupleGetItem(relay.Let(a, e.one, t), 0) dced = run_opt_pass(orig, transform.DeadCodeElimination()) - assert alpha_equal(Function(free_vars(dced), dced), Function(free_vars(g), g)) + assert tvm.ir.structural_equal(Function(free_vars(dced), dced), Function(free_vars(g), g)) @pytest.mark.timeout(timeout=10, method="thread") diff --git a/tests/python/relay/test_pass_partial_eval.py b/tests/python/relay/test_pass_partial_eval.py index f54dd6bf69c5..1299084ef740 100644 --- a/tests/python/relay/test_pass_partial_eval.py +++ b/tests/python/relay/test_pass_partial_eval.py @@ -72,7 +72,7 @@ def test_tuple(): f = Function([x], body, None, [t]) expected = relay.Function([x], x, None, [t]) expected = run_opt_pass(expected, transform.InferType()) - assert alpha_equal(dcpe(f), expected) + assert tvm.ir.structural_equal(dcpe(f), expected) def test_const_inline(): @@ -80,7 +80,7 @@ def test_const_inline(): d = Var("d", t) double = Function([d], d + d) orig = double(const(4.0)) - assert alpha_equal(dcpe(orig), const(8.0)) + assert tvm.ir.structural_equal(dcpe(orig), const(8.0)) def test_ref(): @@ -93,7 +93,7 @@ def test_ref(): body = Let(r, RefCreate(d), body) square = Function([d], body) expected = run_opt_pass(Function([d], d * d), transform.InferType()) - assert alpha_equal(dcpe(square), expected) + assert tvm.ir.structural_equal(dcpe(square), expected) def test_empty_ad(): @@ -105,7 +105,7 @@ def test_empty_ad(): g = dcpe(f, grad=True) expected = Function([d], Tuple([d, Tuple([op.ones_like(d)])])) expected = run_opt_pass(expected, transform.InferType()) - assert alpha_equal(g, expected) + assert tvm.ir.structural_equal(g, expected) def test_ad(): @@ -180,7 +180,7 @@ def test_head_cons(): body = hd(p.cons(x, p.nil())) f = Function([x], body, None, [t]) res = dcpe(f, mod) - assert alpha_equal(res, Function([x], x, t, [t])) + assert tvm.ir.structural_equal(res, Function([x], x, t, [t])) def test_map(): @@ -197,7 +197,7 @@ def test_map(): expected = mod["main"] orig = Function([], orig) res = dcpe(orig, mod=mod) - assert alpha_equal(res.body, expected.body) + assert tvm.ir.structural_equal(res.body, expected.body) def test_loop(): @@ -211,7 +211,7 @@ def test_loop(): expected = mod["main"].body call = Function([], loop(const(1))) res = dcpe(call, mod=mod) - assert alpha_equal(res.body, expected) + assert tvm.ir.structural_equal(res.body, expected) def test_swap_loop(): @@ -226,7 +226,7 @@ def test_swap_loop(): prog = loop(make_nat_expr(p, 1), make_nat_expr(p, 2)) res = Function([], prog) res = dcpe(res, mod=mod) - assert alpha_equal(prog, res.body) + assert tvm.ir.structural_equal(prog, res.body) def test_abs_diff(): @@ -248,7 +248,7 @@ def test_abs_diff(): orig = diff(make_nat_expr(p, 7), make_nat_expr(p, 3)) orig = Function([], orig) res = dcpe(orig, mod=mod) - assert alpha_equal(res.body, make_nat_expr(p, 4)) + assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 4)) def test_match_nat_id(): @@ -265,7 +265,7 @@ def test_match_nat_id(): orig = nat_id(make_nat_expr(p, 3)) orig = Function([], orig) res = dcpe(orig, mod=mod) - assert alpha_equal(res.body, make_nat_expr(p, 3)) + assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 3)) def test_nat_id(): @@ -280,7 +280,7 @@ def test_nat_id(): orig = nat_id(make_nat_expr(p, 3)) orig = Function([], orig) res = dcpe(orig, mod=mod) - assert alpha_equal(res.body, make_nat_expr(p, 3)) + assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 3)) def test_global_match_nat_id(): @@ -294,7 +294,7 @@ def test_global_match_nat_id(): orig = Match(make_nat_expr(p, 3), [z_case, s_case]) orig = Function([], orig) res = dcpe(orig, mod=mod) - assert alpha_equal(res.body, make_nat_expr(p, 3)) + assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 3)) def test_double(): @@ -304,7 +304,7 @@ def test_double(): orig = p.double(make_nat_expr(p, 3)) orig = Function([], orig) res = dcpe(orig, mod=mod) - assert alpha_equal(res.body, make_nat_expr(p, 6)) + assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 6)) def test_concat(): diff --git a/tests/python/relay/test_pass_qnn_legalize.py b/tests/python/relay/test_pass_qnn_legalize.py index ed05096aec29..b1648211002c 100644 --- a/tests/python/relay/test_pass_qnn_legalize.py +++ b/tests/python/relay/test_pass_qnn_legalize.py @@ -134,7 +134,7 @@ def _get_mod(data_dtype, kernel_dtype): # Since same dtype, there should not be any transformation with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'): legalized_mod = relay.qnn.transform.Legalize()(mod) - assert alpha_equal(mod, legalized_mod) + assert tvm.ir.structural_equal(mod, legalized_mod) ################################################################ # Check transformations for platforms without fast Int8 support. @@ -157,7 +157,7 @@ def _get_mod(data_dtype, kernel_dtype): # Check no transformation for Intel VNNI. with tvm.target.create('llvm -mcpu=skylake-avx512'): legalized_mod = relay.qnn.transform.Legalize()(mod) - assert alpha_equal(mod, legalized_mod) + assert tvm.ir.structural_equal(mod, legalized_mod) # ARM - so check that transformation has happened. with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'): @@ -221,7 +221,7 @@ def _get_mod(data_dtype, kernel_dtype): # Since same dtype, there should not be any transformation with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'): legalized_mod = relay.qnn.transform.Legalize()(mod) - assert alpha_equal(mod, legalized_mod) + assert tvm.ir.structural_equal(mod, legalized_mod) ################################################################ # Check transformations for platforms without fast Int8 support. @@ -244,7 +244,7 @@ def _get_mod(data_dtype, kernel_dtype): # Check no transformation for Intel VNNI. with tvm.target.create('llvm -mcpu=skylake-avx512'): legalized_mod = relay.qnn.transform.Legalize()(mod) - assert alpha_equal(mod, legalized_mod) + assert tvm.ir.structural_equal(mod, legalized_mod) # ARM - so check that transformation has happened. with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'): diff --git a/tests/python/relay/test_pass_to_a_normal_form.py b/tests/python/relay/test_pass_to_a_normal_form.py index 2a6103ea1fbe..29818f8e136a 100644 --- a/tests/python/relay/test_pass_to_a_normal_form.py +++ b/tests/python/relay/test_pass_to_a_normal_form.py @@ -76,7 +76,7 @@ def test_order(): expected_output = relay.Let(b, y, expected_output) expected_output = relay.Let(a, x, expected_output) expected_output = run_opt_pass(expected_output, transform.InferType()) - assert alpha_equal(anf, expected_output) + assert tvm.ir.structural_equal(anf, expected_output) def test_if(): @@ -93,7 +93,7 @@ def test_if(): expected_output = relay.Let(d, expected_output, d) expected_output = relay.Let(c, cond, expected_output) expected_output = run_opt_pass(expected_output, transform.InferType()) - assert alpha_equal(anf, expected_output) + assert tvm.ir.structural_equal(anf, expected_output) # make sure we dont infinite loop. diff --git a/tests/python/relay/test_pass_to_cps.py b/tests/python/relay/test_pass_to_cps.py index e2ac924e9661..4aaa9a0f0e9a 100644 --- a/tests/python/relay/test_pass_to_cps.py +++ b/tests/python/relay/test_pass_to_cps.py @@ -17,7 +17,7 @@ import numpy as np import tvm from tvm import relay -from tvm.relay.analysis import alpha_equal, detect_feature +from tvm.relay.analysis import detect_feature from tvm.relay.transform import to_cps, un_cps from tvm.relay.analysis import Feature from tvm.relay.prelude import Prelude diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index 74507baa1096..45916180c1d3 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -21,7 +21,6 @@ from tvm import te from tvm import relay from tvm.relay import op, transform, analysis -from tvm.relay.analysis import assert_alpha_equal def run_infer_type(expr, mod=None): @@ -360,7 +359,7 @@ def test_let_polymorphism(): body = relay.Let(id, relay.Function([x], x, xt, [xt]), body) body = run_infer_type(body) int32 = relay.TensorType((), "int32") - assert_alpha_equal(body.checked_type, relay.TupleType([int32, relay.TupleType([])])) + tvm.ir.assert_structural_equal(body.checked_type, relay.TupleType([int32, relay.TupleType([])])) if __name__ == "__main__": diff --git a/tests/python/unittest/test_node_reflection.py b/tests/python/unittest/test_node_reflection.py index a25ba0ab42f0..f2848ff0ef50 100644 --- a/tests/python/unittest/test_node_reflection.py +++ b/tests/python/unittest/test_node_reflection.py @@ -25,7 +25,7 @@ def test_const_saveload_json(): z = z + z json_str = tvm.ir.save_json(z) zz = tvm.ir.load_json(json_str) - assert tvm.ir.save_json(zz) == tvm.ir.save_json(z) + tvm.ir.assert_structural_equal(zz, z, map_free_vars=True) def test_make_smap(): @@ -38,6 +38,7 @@ def test_make_smap(): arr = tvm.ir.load_json(json_str) assert len(arr) == 1 assert arr[0]["z"].a == arr[0]["x"] + tvm.ir.assert_structural_equal(arr, [smap], map_free_vars=True) def test_make_node(): @@ -90,7 +91,6 @@ def test(x): if __name__ == "__main__": test_env_func() - test_make_attrs() test_make_node() test_make_smap() test_const_saveload_json() diff --git a/tests/python/unittest/test_tir_structural_equal.py b/tests/python/unittest/test_tir_structural_equal.py new file mode 100644 index 000000000000..26f3085f8df7 --- /dev/null +++ b/tests/python/unittest/test_tir_structural_equal.py @@ -0,0 +1,102 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import pytest +from tvm import te + + +def test_exprs(): + # save load json + x = tvm.tir.const(1, "int32") + y = tvm.tir.const(10, "int32") + vx = te.var("x") + vy = te.var("y") + vz = te.var("z") + + # test assert trigger. + with pytest.raises(ValueError): + tvm.ir.assert_structural_equal(x, y) + + assert not tvm.ir.structural_equal(vx, vy) + assert tvm.ir.structural_equal(vx, vy, map_free_vars=True) + # corner case lhs:vx == rhs:vy, but cannot map it iteslf + assert not tvm.ir.structural_equal(vx + vx, vy + vx, map_free_vars=True) + # corner case lhs:vx == rhs:vy, lhs:vy == rhs:vx + assert tvm.ir.structural_equal(vx + vy, vy + vx, map_free_vars=True) + # corner case2: rolling remap. + assert tvm.ir.structural_equal(vx + vy + vz, vy + vz + vx, map_free_vars=True) + assert not tvm.ir.structural_equal(vx + 1, vy + 1, map_free_vars=False) + # Defintition remap + assert tvm.ir.structural_equal(tvm.tir.Let(vx, 1, vx - 1), + tvm.tir.Let(vy, 1, vy - 1)) + # Default same address free var remap + assert tvm.ir.structural_equal(tvm.tir.Let(vx, 1, vx // vz), + tvm.tir.Let(vy, 1, vy // vz)) + + zx = vx + vx + zy = vy + vy + assert tvm.ir.structural_equal(zx * zx, zx * zx) + assert tvm.ir.structural_equal(zx * zx, zy * zy, map_free_vars=True) + assert not tvm.ir.structural_equal(zx * zx, zy * zy, map_free_vars=False) + assert tvm.ir.structural_equal(zx * zx, (vx + vx) * (vx + vx), + map_free_vars=False) + + +def test_prim_func(): + x = te.var('x') + y = te.var('y') + # counter example of same equality + func0 = tvm.tir.PrimFunc( + [x, y], tvm.tir.Evaluate(x + y)) + func1 = tvm.tir.PrimFunc( + [x, y], tvm.tir.Evaluate(y + x)) + assert not tvm.ir.structural_equal(func0, func1) + + # new cases + b = tvm.tir.decl_buffer((x,), "float32") + stmt = tvm.tir.LetStmt( + x, 10, tvm.tir.Evaluate(x + 1)) + func0 = tvm.tir.PrimFunc( + [x, y, b], stmt) + # easiest way to deep copy is via save/load + func1 = tvm.ir.load_json(tvm.ir.save_json(func0)) + tvm.ir.assert_structural_equal(func0, func1) + + data0 = tvm.nd.array([1, 2, 3]) + data1 = tvm.nd.array([1, 2, 3]) + # attributes and ndarrays + func0 = func0.with_attr("data", data0) + func1 = func1.with_attr("data", data1) + # IRModules + mod0 = tvm.IRModule.from_expr(func0) + mod1 = tvm.IRModule.from_expr(func1) + tvm.ir.assert_structural_equal(mod0, mod1) + + +def test_attrs(): + x = tvm.ir.make_node("attrs.TestAttrs", axis=1, name="xx") + y = tvm.ir.make_node("attrs.TestAttrs", axis=1, name="xx") + z = tvm.ir.make_node("attrs.TestAttrs", axis=2, name="xx") + tvm.ir.assert_structural_equal(y, x) + assert not tvm.ir.structural_equal(y, z) + + + +if __name__ == "__main__": + test_exprs() + test_prim_func() + test_attrs()