Skip to content

Commit

Permalink
[NODE][IR] Introduce StructuralHash for the Unified IR. (apache#5160)
Browse files Browse the repository at this point in the history
* [NODE][IR] Introduce StructuralHash for the Unified IR.

This PR introduces a new way to handle structural hash for the unified IR.

- Each object can now register an optional SEqualHash function, which
  describes how to reduce its structural equality to sequence of hash values.
- Optionally, the object can choose to allow labeling of vars(e.g. function parameters)
  by calling DefHash
- We implemented a non-recursive structural hasher that maintains its own stack
  to traverse te IR.

This PR also improves the hash value property from the previous relay's hash utility.
In particular, the graph node mode hashs a DAG differently from a tree
by attaching an unique occurence index to each graph node.

In all of the test cases so far, structural_hash is consistent with structural_equal.
- if structrual(x, y) then structural_hash(x) == structural_hash(y)
- if structural_hash(x) == structural_hash(y) then highly likely structural_equal(x, y)
  - hash no collison is found in our testcases.

Ideally we should work on automatically generating these functions in the future.

* Fix cases for EnvFunc and Array dims

* fix testcase

* Update src/node/structural_hash.cc

Co-Authored-By: 雾雨魔理沙 <[email protected]>

Co-authored-by: 雾雨魔理沙 <[email protected]>
  • Loading branch information
2 people authored and Trevor Morris committed Apr 16, 2020
1 parent fdb4ca9 commit 0dbfa14
Show file tree
Hide file tree
Showing 29 changed files with 1,407 additions and 154 deletions.
11 changes: 11 additions & 0 deletions include/tvm/ir/adt.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ class ConstructorNode : public RelayExprNode {
equal(inputs, other->inputs);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(name_hint);
hash_reduce(inputs);
}

static constexpr const char* _type_key = "relay.Constructor";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorNode, RelayExprNode);
};
Expand Down Expand Up @@ -123,6 +128,12 @@ class TypeDataNode : public TypeNode {
equal(constructors, other->constructors);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce.DefHash(header);
hash_reduce.DefHash(type_vars);
hash_reduce(constructors);
}

static constexpr const char* _type_key = "relay.TypeData";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeDataNode, TypeNode);
};
Expand Down
26 changes: 26 additions & 0 deletions include/tvm/ir/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ class AttrFieldInfoNode : public Object {

static constexpr const char* _type_key = "AttrFieldInfo";
static constexpr bool _type_has_method_sequal_reduce = false;
static constexpr bool _type_has_method_shash_reduce = false;
TVM_DECLARE_FINAL_OBJECT_INFO(AttrFieldInfoNode, Object);
};

Expand Down Expand Up @@ -281,6 +282,7 @@ class BaseAttrsNode : public Object {
TVM_DLL virtual size_t ContentHash(AttrsHash hasher) const = 0;

static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
static constexpr const char* _type_key = "Attrs";
TVM_DECLARE_BASE_OBJECT_INFO(BaseAttrsNode, Object);
};
Expand Down Expand Up @@ -309,6 +311,10 @@ class DictAttrsNode : public BaseAttrsNode {
return equal(dict, other->dict);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dict);
}

// implementations
void VisitAttrs(AttrVisitor* v) final;
void VisitNonDefaultAttrs(AttrVisitor* v) final;
Expand Down Expand Up @@ -452,6 +458,21 @@ class AttrsHashVisitor {
const AttrsHash& hasher_;
};

class AttrsSHashVisitor {
public:
explicit AttrsSHashVisitor(const SHashReducer& hash_reducer)
: hash_reducer_(hash_reducer) {}

template<typename T>
AttrNopEntry operator()(const char* key, T* value) {
hash_reducer_(*value);
return AttrNopEntry();
}

private:
const SHashReducer& hash_reducer_;
};

// helper entry that does initialization, set default.
template<typename T>
struct AttrInitEntry {
Expand Down Expand Up @@ -858,6 +879,11 @@ class AttrsNode : public BaseAttrsNode {
return visitor.result_;
}

void SHashReduce(SHashReducer hash_reducer) const {
::tvm::detail::AttrsSHashVisitor visitor(hash_reducer);
self()->__VisitAttrs__(visitor);
}

Array<AttrFieldInfo> ListFieldInfo() const final {
::tvm::detail::AttrDocVisitor visitor;
self()->__VisitAttrs__(visitor);
Expand Down
9 changes: 8 additions & 1 deletion include/tvm/ir/env_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,18 @@ class EnvFuncNode : public Object {
}

bool SEqualReduce(const EnvFuncNode* other, SEqualReducer equal) const {
return this == other;
// name uniquely identifies the env function.
return name == other->name;
}

void SHashReduce(SHashReducer hash_reduce) const {
// Name uniquely identifies the env function.
hash_reduce(name);
}

static constexpr const char* _type_key = "EnvFunc";
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(EnvFuncNode, Object);
};

Expand Down
22 changes: 22 additions & 0 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class BaseExprNode : public Object {
public:
static constexpr const char* _type_key = "Expr";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object);
};

Expand Down Expand Up @@ -205,6 +206,11 @@ class GlobalVarNode : public RelayExprNode {
equal.FreeVarEqualImpl(this, other);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(name_hint);
hash_reduce.FreeVarHashImpl(this);
}

static constexpr const char* _type_key = "GlobalVar";
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, RelayExprNode);
};
Expand Down Expand Up @@ -240,6 +246,11 @@ class IntImmNode : public PrimExprNode {
return equal(dtype, other->dtype) && equal(value, other->value);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(value);
}

static constexpr const char* _type_key = "IntImm";
TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode);
};
Expand Down Expand Up @@ -279,6 +290,11 @@ class FloatImmNode : public PrimExprNode {
return equal(dtype, other->dtype) && equal(value, other->value);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(value);
}

static constexpr const char* _type_key = "FloatImm";
TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode);
};
Expand Down Expand Up @@ -373,8 +389,14 @@ class RangeNode : public Object {
return equal(min, other->min) && equal(extent, other->extent);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(min);
hash_reduce(extent);
}

static constexpr const char* _type_key = "Range";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(RangeNode, Object);
};

Expand Down
3 changes: 3 additions & 0 deletions include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ class IRModuleNode : public Object {

TVM_DLL bool SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const;

TVM_DLL void SHashReduce(SHashReducer hash_reduce) const;

/*!
* \brief Add a function to the global environment.
* \param var The var of the global function.
Expand Down Expand Up @@ -238,6 +240,7 @@ class IRModuleNode : public Object {

static constexpr const char* _type_key = "IRModule";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleNode, Object);

private:
Expand Down
5 changes: 5 additions & 0 deletions include/tvm/ir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ class OpNode : public RelayExprNode {
return this == other;
}

void SHashReduce(SHashReducer hash_reduce) const {
// Name uniquely identifies an Op.
hash_reduce(name);
}

/*!
* \brief Check that if current op is a "primtive operator".
* That is the arguments are all type variables, and there is a single
Expand Down
5 changes: 5 additions & 0 deletions include/tvm/ir/tensor_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ class TensorTypeNode : public BaseTensorTypeNode {
equal(dtype, other->dtype);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(shape);
hash_reduce(dtype);
}

/*! \brief Return product of elements in the shape.
* \return (d1 * d_2 ... * d_n) if shape is (d_1, d_2, ..., d_n) and 1 if shape size is zero.
*/
Expand Down
38 changes: 38 additions & 0 deletions include/tvm/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ class TypeNode : public Object {

static constexpr const char* _type_key = "Type";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object);
};

Expand Down Expand Up @@ -115,6 +116,10 @@ class PrimTypeNode : public TypeNode {
return equal(dtype, other->dtype);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
}

static constexpr const char* _type_key = "PrimType";
TVM_DECLARE_FINAL_OBJECT_INFO(PrimTypeNode, TypeNode);
};
Expand Down Expand Up @@ -161,6 +166,10 @@ class PointerTypeNode : public TypeNode {
return equal(element_type, other->element_type);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(element_type);
}

static constexpr const char* _type_key = "PointerType";
TVM_DECLARE_FINAL_OBJECT_INFO(PointerTypeNode, TypeNode);
};
Expand Down Expand Up @@ -233,6 +242,11 @@ class TypeVarNode : public TypeNode {
equal.FreeVarEqualImpl(this, other);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(kind);
hash_reduce.FreeVarHashImpl(this);
}

static constexpr const char* _type_key = "TypeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeVarNode, TypeNode);
};
Expand Down Expand Up @@ -280,6 +294,11 @@ class GlobalTypeVarNode : public TypeNode {
equal.FreeVarEqualImpl(this, other);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(name_hint);
hash_reduce.FreeVarHashImpl(this);
}

static constexpr const char* _type_key = "GlobalTypeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalTypeVarNode, TypeNode);
};
Expand Down Expand Up @@ -320,6 +339,10 @@ class TupleTypeNode : public TypeNode {
return equal(fields, other->fields);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(fields);
}

static constexpr const char* _type_key = "TupleType";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleTypeNode, TypeNode);
};
Expand Down Expand Up @@ -421,6 +444,13 @@ class FuncTypeNode : public TypeNode {
equal(type_constraints, other->type_constraints);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce.DefHash(type_params);
hash_reduce(arg_types);
hash_reduce(ret_type);
hash_reduce(type_constraints);
}

static constexpr const char* _type_key = "FuncType";
TVM_DECLARE_FINAL_OBJECT_INFO(FuncTypeNode, TypeNode);
};
Expand Down Expand Up @@ -471,6 +501,10 @@ class IncompleteTypeNode : public TypeNode {
return equal(kind, other->kind);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(kind);
}

static constexpr const char* _type_key = "IncompleteType";
TVM_DECLARE_FINAL_OBJECT_INFO(IncompleteTypeNode, TypeNode);
};
Expand Down Expand Up @@ -512,6 +546,10 @@ class RelayRefTypeNode : public TypeNode {
return equal(value, other->value);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(value);
}

// Keep the relay prefix in the type as this type is specific
// to the relay itself.
static constexpr const char* _type_key = "relay.RefType";
Expand Down
12 changes: 12 additions & 0 deletions include/tvm/ir/type_relation.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ class TypeCallNode : public TypeNode {
equal(args, other->args);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(func);
hash_reduce(args);
}

static constexpr const char* _type_key = "TypeCall";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeCallNode, TypeNode);
};
Expand Down Expand Up @@ -209,6 +214,13 @@ class TypeRelationNode : public TypeConstraintNode {
equal(attrs, other->attrs);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(func);
hash_reduce(args);
hash_reduce(num_inputs);
hash_reduce(attrs);
}

static constexpr const char* _type_key = "TypeRelation";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeRelationNode, TypeConstraintNode);
};
Expand Down
1 change: 1 addition & 0 deletions include/tvm/node/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#include <tvm/node/repr_printer.h>
#include <tvm/node/container.h>
#include <tvm/node/structural_equal.h>
#include <tvm/node/structural_hash.h>

#include <string>
#include <vector>
Expand Down
Loading

0 comments on commit 0dbfa14

Please sign in to comment.