Skip to content

Commit

Permalink
[NODE][IR] Introduce StructuralEqual Infra for the unified IR. (#5154)
Browse files Browse the repository at this point in the history
* [NODE][IR] Introduce StructuralEqual Infra for the Unified IR.

This PR introduces a new way to handle structural equality
for both TIR and relay nodes in an extensive way.

- Each object can now register an optional SEqualReduce function, which
  describes how to reduce its structural equality to another instance
  into equality of the children.
- Optionally, the object can choose to allow remapping of vars(e.g. function parameters)
  by calling DefEqual
- We implemented a non-recursive structural equality checker that
  recursively traverses the objects and does the structural equality checking.

This PR also fixes a few potential problems in previous relay's AlphaEqual.

- In particular, the new structural equality relation will be communicative.
- It is can be dangerous to use same_as relation to quickly check equality,
  demonstrated by the following case. (%x, %y) are shared vars between two functions.

- function0: fn (%x, %y) { %x + %y }
- function1: fn (%y, %x) { %x + %y }

The new structural equal is intented to supersede AlphaEqual and AttrsEqual.

Follow-up PRs should be performed to redirect the existing usages, and removes
the corresponding implementation.

* Update the rule to distinguish between graph node and non-graph nodes.

* Refactor the test cases to use structural equal.

* address comments

* Mark more relay::Expr as graph node, fix a testcase issue(was bug that was not caught by previous alpha equal)

* Remove unrelated comment

* Fix file comment

* Address review comment

* Relax condition to fit flaky case
  • Loading branch information
tqchen authored Mar 28, 2020
1 parent 9c80662 commit 997a14e
Show file tree
Hide file tree
Showing 46 changed files with 1,781 additions and 271 deletions.
8 changes: 8 additions & 0 deletions include/tvm/arith/analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ class ConstIntBoundNode : public Object {
v->Visit("max_value", &max_value);
}

bool SEqualReduce(const ConstIntBoundNode* other, SEqualReducer equal) const {
return equal(min_value, other->min_value) && equal(max_value, other->max_value);
}

/*! \brief Number to represent +inf */
static const constexpr int64_t kPosInf = std::numeric_limits<int64_t>::max();
/*!
Expand Down Expand Up @@ -170,6 +174,10 @@ class ModularSetNode : public Object {
v->Visit("base", &base);
}

bool SEqualReduce(const ModularSetNode* other, SEqualReducer equal) const {
return equal(coeff, other->coeff) && equal(base, other->base);
}

static constexpr const char* _type_key = "arith.ModularSet";
TVM_DECLARE_FINAL_OBJECT_INFO(ModularSetNode, Object);
};
Expand Down
1 change: 1 addition & 0 deletions include/tvm/arith/int_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ enum SignType {
class IntSetNode : public Object {
public:
static constexpr const char* _type_key = "IntSet";
static constexpr bool _type_has_method_sequal_reduce = false;
TVM_DECLARE_BASE_OBJECT_INFO(IntSetNode, Object);
};

Expand Down
15 changes: 15 additions & 0 deletions include/tvm/ir/adt.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,14 @@ class ConstructorNode : public RelayExprNode {
v->Visit("_checked_type_", &checked_type_);
}

bool SEqualReduce(const ConstructorNode* other, SEqualReducer equal) const {
// Use namehint for now to be consistent with the legacy relay impl
// TODO(tvm-team) revisit, need to check the type var.
return
equal(name_hint, other->name_hint) &&
equal(inputs, other->inputs);
}

static constexpr const char* _type_key = "relay.Constructor";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorNode, RelayExprNode);
};
Expand Down Expand Up @@ -108,6 +116,13 @@ class TypeDataNode : public TypeNode {
v->Visit("span", &span);
}

bool SEqualReduce(const TypeDataNode* other, SEqualReducer equal) const {
return
equal.DefEqual(header, other->header) &&
equal.DefEqual(type_vars, other->type_vars) &&
equal(constructors, other->constructors);
}

static constexpr const char* _type_key = "relay.TypeData";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeDataNode, TypeNode);
};
Expand Down
41 changes: 41 additions & 0 deletions include/tvm/ir/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,9 @@ class AttrFieldInfoNode : public Object {
v->Visit("type_info", &type_info);
v->Visit("description", &description);
}

static constexpr const char* _type_key = "AttrFieldInfo";
static constexpr bool _type_has_method_sequal_reduce = false;
TVM_DECLARE_FINAL_OBJECT_INFO(AttrFieldInfoNode, Object);
};

Expand Down Expand Up @@ -278,6 +280,7 @@ class BaseAttrsNode : public Object {
*/
TVM_DLL virtual size_t ContentHash(AttrsHash hasher) const = 0;

static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const char* _type_key = "Attrs";
TVM_DECLARE_BASE_OBJECT_INFO(BaseAttrsNode, Object);
};
Expand All @@ -302,6 +305,10 @@ class DictAttrsNode : public BaseAttrsNode {
/*! \brief internal attrs map */
Map<std::string, ObjectRef> dict;

bool SEqualReduce(const DictAttrsNode* other, SEqualReducer equal) const {
return equal(dict, other->dict);
}

// implementations
void VisitAttrs(AttrVisitor* v) final;
void VisitNonDefaultAttrs(AttrVisitor* v) final;
Expand Down Expand Up @@ -401,6 +408,33 @@ class AttrsEqualVisitor {
const AttrsEqual& equal_;
};

class AttrsSEqualVisitor {
public:
bool result_{true};
// constructor
AttrsSEqualVisitor(const Object* lhs, const Object* rhs, const SEqualReducer& equal)
: lhs_(lhs), rhs_(rhs), equal_(equal) {
}
template<typename T>
AttrNopEntry operator()(const char* key, T* lhs_value) {
if (!result_) return AttrNopEntry();
const T* rhs_value =
reinterpret_cast<const T*>(
reinterpret_cast<const char*>(rhs_) +
(reinterpret_cast<const char*>(lhs_value) -
reinterpret_cast<const char*>(lhs_)));
if (!equal_(*lhs_value, *rhs_value)) {
result_ = false;
}
return AttrNopEntry();
}

private:
const Object* lhs_;
const Object* rhs_;
const SEqualReducer& equal_;
};

class AttrsHashVisitor {
public:
explicit AttrsHashVisitor(const AttrsHash& hasher)
Expand Down Expand Up @@ -817,6 +851,13 @@ class AttrsNode : public BaseAttrsNode {
}
}

bool SEqualReduce(const DerivedType* other, SEqualReducer equal) const {
DerivedType* pself = self();
::tvm::detail::AttrsSEqualVisitor visitor(pself, other, equal);
self()->__VisitAttrs__(visitor);
return visitor.result_;
}

Array<AttrFieldInfo> ListFieldInfo() const final {
::tvm::detail::AttrDocVisitor visitor;
self()->__VisitAttrs__(visitor);
Expand Down
5 changes: 5 additions & 0 deletions include/tvm/ir/env_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,12 @@ class EnvFuncNode : public Object {
v->Visit("name", &name);
}

bool SEqualReduce(const EnvFuncNode* other, SEqualReducer equal) const {
return this == other;
}

static constexpr const char* _type_key = "EnvFunc";
static constexpr bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(EnvFuncNode, Object);
};

Expand Down
21 changes: 21 additions & 0 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ namespace tvm {
class BaseExprNode : public Object {
public:
static constexpr const char* _type_key = "Expr";
static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object);
};

Expand Down Expand Up @@ -197,6 +198,13 @@ class GlobalVarNode : public RelayExprNode {
v->Visit("_checked_type_", &checked_type_);
}

bool SEqualReduce(const GlobalVarNode* other, SEqualReducer equal) const {
// name matters for global var.
return
equal(name_hint, other->name_hint) &&
equal.FreeVarEqualImpl(this, other);
}

static constexpr const char* _type_key = "GlobalVar";
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, RelayExprNode);
};
Expand Down Expand Up @@ -228,6 +236,10 @@ class IntImmNode : public PrimExprNode {
v->Visit("value", &value);
}

bool SEqualReduce(const IntImmNode* other, SEqualReducer equal) const {
return equal(dtype, other->dtype) && equal(value, other->value);
}

static constexpr const char* _type_key = "IntImm";
TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode);
};
Expand Down Expand Up @@ -263,6 +275,10 @@ class FloatImmNode : public PrimExprNode {
v->Visit("value", &value);
}

bool SEqualReduce(const FloatImmNode* other, SEqualReducer equal) const {
return equal(dtype, other->dtype) && equal(value, other->value);
}

static constexpr const char* _type_key = "FloatImm";
TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode);
};
Expand Down Expand Up @@ -353,7 +369,12 @@ class RangeNode : public Object {
v->Visit("extent", &extent);
}

bool SEqualReduce(const RangeNode* other, SEqualReducer equal) const {
return equal(min, other->min) && equal(extent, other->extent);
}

static constexpr const char* _type_key = "Range";
static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(RangeNode, Object);
};

Expand Down
3 changes: 3 additions & 0 deletions include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ class IRModuleNode : public Object {
v->Visit("global_type_var_map_", &global_type_var_map_);
}

TVM_DLL bool SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const;

/*!
* \brief Add a function to the global environment.
* \param var The var of the global function.
Expand Down Expand Up @@ -235,6 +237,7 @@ class IRModuleNode : public Object {
TVM_DLL std::unordered_set<std::string> Imports() const;

static constexpr const char* _type_key = "IRModule";
static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleNode, Object);

private:
Expand Down
5 changes: 5 additions & 0 deletions include/tvm/ir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ class OpNode : public RelayExprNode {
v->Visit("support_level", &support_level);
}

bool SEqualReduce(const OpNode* other, SEqualReducer equal) const {
// pointer equality is fine as there is only one op with the same name.
return this == other;
}

/*!
* \brief Check that if current op is a "primtive operator".
* That is the arguments are all type variables, and there is a single
Expand Down
11 changes: 11 additions & 0 deletions include/tvm/ir/span.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ class SourceNameNode : public Object {
// override attr visitor
void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); }

bool SEqualReduce(const SourceNameNode* other, SEqualReducer equal) const {
return equal(name, other->name);
}

static constexpr const char* _type_key = "SourceName";
TVM_DECLARE_FINAL_OBJECT_INFO(SourceNameNode, Object);
};
Expand Down Expand Up @@ -87,6 +91,13 @@ class SpanNode : public Object {
v->Visit("col_offset", &col_offset);
}

bool SEqualReduce(const SpanNode* other, SEqualReducer equal) const {
return
equal(source, other->source) &&
equal(lineno, other->lineno) &&
equal(col_offset, other->col_offset);
}

TVM_DLL static Span make(SourceName source, int lineno, int col_offset);

static constexpr const char* _type_key = "Span";
Expand Down
6 changes: 6 additions & 0 deletions include/tvm/ir/tensor_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ class TensorTypeNode : public BaseTensorTypeNode {
v->Visit("span", &span);
}

bool SEqualReduce(const TensorTypeNode* other, SEqualReducer equal) const {
return
equal(shape, other->shape) &&
equal(dtype, other->dtype);
}

/*! \brief Return product of elements in the shape.
* \return (d1 * d_2 ... * d_n) if shape is (d_1, d_2, ..., d_n) and 1 if shape size is zero.
*/
Expand Down
2 changes: 2 additions & 0 deletions include/tvm/ir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ class PassContextNode : public Object {
}

static constexpr const char* _type_key = "transform.PassContext";
static constexpr bool _type_has_method_sequal_reduce = false;
TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object);
};

Expand Down Expand Up @@ -207,6 +208,7 @@ class PassInfoNode : public Object {
}

static constexpr const char* _type_key = "transform.PassInfo";
static constexpr bool _type_has_method_sequal_reduce = false;
TVM_DECLARE_FINAL_OBJECT_INFO(PassInfoNode, Object);
};

Expand Down
43 changes: 43 additions & 0 deletions include/tvm/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class TypeNode : public Object {
mutable Span span;

static constexpr const char* _type_key = "Type";
static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object);
};

Expand Down Expand Up @@ -110,6 +111,10 @@ class PrimTypeNode : public TypeNode {
v->Visit("dtype", &dtype);
}

bool SEqualReduce(const PrimTypeNode* other, SEqualReducer equal) const {
return equal(dtype, other->dtype);
}

static constexpr const char* _type_key = "PrimType";
TVM_DECLARE_FINAL_OBJECT_INFO(PrimTypeNode, TypeNode);
};
Expand Down Expand Up @@ -152,6 +157,10 @@ class PointerTypeNode : public TypeNode {
v->Visit("element_type", &element_type);
}

bool SEqualReduce(const PointerTypeNode* other, SEqualReducer equal) const {
return equal(element_type, other->element_type);
}

static constexpr const char* _type_key = "PointerType";
TVM_DECLARE_FINAL_OBJECT_INFO(PointerTypeNode, TypeNode);
};
Expand Down Expand Up @@ -218,6 +227,12 @@ class TypeVarNode : public TypeNode {
v->Visit("span", &span);
}

bool SEqualReduce(const TypeVarNode* other, SEqualReducer equal) const {
return
equal(kind, other->kind) &&
equal.FreeVarEqualImpl(this, other);
}

static constexpr const char* _type_key = "TypeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeVarNode, TypeNode);
};
Expand Down Expand Up @@ -258,6 +273,13 @@ class GlobalTypeVarNode : public TypeNode {
v->Visit("kind", &kind);
}

bool SEqualReduce(const GlobalTypeVarNode* other, SEqualReducer equal) const {
// name matters for now in global type var.
return
equal(name_hint, other->name_hint) &&
equal.FreeVarEqualImpl(this, other);
}

static constexpr const char* _type_key = "GlobalTypeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalTypeVarNode, TypeNode);
};
Expand Down Expand Up @@ -294,6 +316,10 @@ class TupleTypeNode : public TypeNode {
v->Visit("span", &span);
}

bool SEqualReduce(const TupleTypeNode* other, SEqualReducer equal) const {
return equal(fields, other->fields);
}

static constexpr const char* _type_key = "TupleType";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleTypeNode, TypeNode);
};
Expand Down Expand Up @@ -386,6 +412,15 @@ class FuncTypeNode : public TypeNode {
v->Visit("span", &span);
}

bool SEqualReduce(const FuncTypeNode* other, SEqualReducer equal) const {
// type params first as they defines type vars.
return
equal.DefEqual(type_params, other->type_params) &&
equal(arg_types, other->arg_types) &&
equal(ret_type, other->ret_type) &&
equal(type_constraints, other->type_constraints);
}

static constexpr const char* _type_key = "FuncType";
TVM_DECLARE_FINAL_OBJECT_INFO(FuncTypeNode, TypeNode);
};
Expand Down Expand Up @@ -432,6 +467,10 @@ class IncompleteTypeNode : public TypeNode {
v->Visit("span", &span);
}

bool SEqualReduce(const IncompleteTypeNode* other, SEqualReducer equal) const {
return equal(kind, other->kind);
}

static constexpr const char* _type_key = "IncompleteType";
TVM_DECLARE_FINAL_OBJECT_INFO(IncompleteTypeNode, TypeNode);
};
Expand Down Expand Up @@ -469,6 +508,10 @@ class RelayRefTypeNode : public TypeNode {
v->Visit("span", &span);
}

bool SEqualReduce(const RelayRefTypeNode* other, SEqualReducer equal) const {
return equal(value, other->value);
}

// Keep the relay prefix in the type as this type is specific
// to the relay itself.
static constexpr const char* _type_key = "relay.RefType";
Expand Down
Loading

0 comments on commit 997a14e

Please sign in to comment.