Skip to content

Commit

Permalink
remove AttrsEqual and AttrsHash related code (#5169)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics authored Mar 30, 2020
1 parent a2edd01 commit 6536b35
Show file tree
Hide file tree
Showing 16 changed files with 29 additions and 573 deletions.
171 changes: 3 additions & 168 deletions include/tvm/ir/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@

#include <dmlc/common.h>
#include <tvm/ir/expr.h>
#include <tvm/node/structural_equal.h>
#include <tvm/node/structural_hash.h>
#include <tvm/runtime/packed_func.h>

#include <unordered_map>
Expand Down Expand Up @@ -131,95 +133,6 @@ class AttrFieldInfo : public ObjectRef {
TVM_DEFINE_OBJECT_REF_METHODS(AttrFieldInfo, ObjectRef, AttrFieldInfoNode);
};

class AttrsHashHandler;
class AttrsEqualHandler;
/*!
* \brief Content-aware Equality comparator for attrs.
*
* This comparator will recursively deep compare the following Attributes.
*
* - IntImm, UIntImm, FloatImm, StringImm
* - Any subclass of BaseAttrsNode
* - Array of Attributes.
* - Map from string to Attributes.
*/
class AttrsEqual {
public:
bool operator()(const double& lhs, const double& rhs) const {
// fuzzy float pt comparison
constexpr double atol = 1e-9;
if (lhs == rhs) return true;
double diff = lhs - rhs;
return diff > -atol && diff < atol;
}

bool operator()(const int64_t& lhs, const int64_t& rhs) const {
return lhs == rhs;
}
bool operator()(const uint64_t& lhs, const uint64_t& rhs) const {
return lhs == rhs;
}
bool operator()(const int& lhs, const int& rhs) const {
return lhs == rhs;
}
bool operator()(const bool& lhs, const bool& rhs) const {
return lhs == rhs;
}
bool operator()(const std::string& lhs, const std::string& rhs) const {
return lhs == rhs;
}
bool operator()(const DataType& lhs, const DataType& rhs) const {
return lhs == rhs;
}
// node comparator
TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const;

protected:
friend class AttrsEqualHandler;
/*! \brief internal handle. */
AttrsEqualHandler* handler_{nullptr};
};

/*!
* \brief Content-aware hash function.
*
* This hash functor will recursively hash the content of the Attributes.
* It is guaranteed that if AttrsEqual(a, b) == true, then AttrsHash(a) == AttrsHash(b);
*/
class AttrsHash {
public:
size_t operator()(const double& value) const {
return std::hash<double>()(value);
}
size_t operator()(const int64_t& value) const {
return std::hash<int64_t>()(value);
}
size_t operator()(const uint64_t& value) const {
return std::hash<uint64_t>()(value);
}
size_t operator()(const int& value) const {
return std::hash<int>()(value);
}
size_t operator()(const bool& value) const {
return std::hash<bool>()(value);
}
size_t operator()(const std::string& value) const {
return std::hash<std::string>()(value);
}
size_t operator()(const DataType& value) const {
return std::hash<int>()(
static_cast<int>(value.code()) |
(static_cast<int>(value.bits()) << 8) |
(static_cast<int>(value.lanes()) << 16));
}
TVM_DLL size_t operator()(const ObjectRef& value) const;

private:
friend class AttrsHashHandler;
/*! \brief internal handle. */
AttrsHashHandler* handler_{nullptr};
};

/*!
* \brief Base class of all attribute class
* \note Do not subclass AttrBaseNode directly,
Expand Down Expand Up @@ -266,20 +179,6 @@ class BaseAttrsNode : public Object {
* \note This function throws when the required field is not present.
*/
TVM_DLL virtual void InitByPackedArgs(const TVMArgs& kwargs, bool allow_unknown = false) = 0;
/*!
* \brief Whether this attribute's content equals to another node.
* \param other The pointer to another node.
* \param equal The equal comparator
* \return The comparison result.
*/
TVM_DLL virtual bool ContentEqual(
const Object* other, AttrsEqual equal) const = 0;
/*!
* \brief Content aware hash.
* \param hasher The hasher to run the hash.
* \return the hash result.
*/
TVM_DLL virtual size_t ContentHash(AttrsHash hasher) const = 0;

static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
Expand Down Expand Up @@ -320,8 +219,6 @@ class DictAttrsNode : public BaseAttrsNode {
void VisitNonDefaultAttrs(AttrVisitor* v) final;
void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final;
Array<AttrFieldInfo> ListFieldInfo() const final;
bool ContentEqual(const Object* other, AttrsEqual equal) const final;
size_t ContentHash(AttrsHash hasher) const final;
// type info
static constexpr const char* _type_key = "DictAttrs";
TVM_DECLARE_FINAL_OBJECT_INFO(DictAttrsNode, BaseAttrsNode);
Expand Down Expand Up @@ -386,34 +283,6 @@ class AttrNormalVisitor {
AttrVisitor* visitor_;
};

// Wrapper for normal visitor.
class AttrsEqualVisitor {
public:
bool result_{true};
// constructor
AttrsEqualVisitor(const Object* lhs, const Object* rhs, const AttrsEqual& equal)
: lhs_(lhs), rhs_(rhs), equal_(equal) {
}
template<typename T>
AttrNopEntry operator()(const char* key, T* lhs_value) {
if (!result_) return AttrNopEntry();
const T* rhs_value =
reinterpret_cast<const T*>(
reinterpret_cast<const char*>(rhs_) +
(reinterpret_cast<const char*>(lhs_value) -
reinterpret_cast<const char*>(lhs_)));
if (!equal_(*lhs_value, *rhs_value)) {
result_ = false;
}
return AttrNopEntry();
}

private:
const Object* lhs_;
const Object* rhs_;
const AttrsEqual& equal_;
};

class AttrsSEqualVisitor {
public:
bool result_{true};
Expand Down Expand Up @@ -441,23 +310,6 @@ class AttrsSEqualVisitor {
const SEqualReducer& equal_;
};

class AttrsHashVisitor {
public:
explicit AttrsHashVisitor(const AttrsHash& hasher)
: hasher_(hasher) {}

size_t result_{0};

template<typename T>
AttrNopEntry operator()(const char* key, T* value) {
result_ = dmlc::HashCombine(result_, hasher_(*value));
return AttrNopEntry();
}

private:
const AttrsHash& hasher_;
};

class AttrsSHashVisitor {
public:
explicit AttrsSHashVisitor(const SHashReducer& hash_reducer)
Expand Down Expand Up @@ -760,7 +612,7 @@ struct AttrTriggerNonDefaultEntry {
return *this;
}
TSelf& set_default(const T& value) {
if (AttrsEqual()(value, *data_)) {
if (tvm::StructuralEqual()(value, *data_)) {
trigger_ = false;
}
return *this;
Expand Down Expand Up @@ -890,23 +742,6 @@ class AttrsNode : public BaseAttrsNode {
return visitor.fields_;
}

bool ContentEqual(const Object* other, AttrsEqual equal) const final {
DerivedType* pself = self();
if (pself == other) return true;
if (other == nullptr) return false;
if (pself->type_index() != other->type_index()) return false;
::tvm::detail::AttrsEqualVisitor visitor(pself, other, equal);
self()->__VisitAttrs__(visitor);
return visitor.result_;
}

size_t ContentHash(AttrsHash hasher) const final {
::tvm::detail::AttrsHashVisitor visitor(hasher);
visitor.result_ = this->GetTypeKeyHash();
self()->__VisitAttrs__(visitor);
return visitor.result_;
}

private:
DerivedType* self() const {
return const_cast<DerivedType*>(
Expand Down
89 changes: 0 additions & 89 deletions src/ir/attr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,94 +147,5 @@ class AttrFunctor<R(const ObjectRef& n, Args...)> {
}
};

class AttrsEqualHandler :
protected AttrFunctor<bool(const ObjectRef&, const ObjectRef&)> {
public:
/*!
* \brief Check if lhs equals rhs
* \param lhs The left operand.
* \param rhs The right operand.
*/
bool Equal(const ObjectRef& lhs, const ObjectRef& rhs);

protected:
bool VisitAttrDefault_(const Object* lhs, const ObjectRef& other) final;
bool VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const StrMapNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::IntImmNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::FloatImmNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::StringImmNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::AddNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::SubNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::MulNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::DivNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::ModNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::FloorDivNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::FloorModNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::MinNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::MaxNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::GENode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::GTNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::LTNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::LENode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::EQNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::NENode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::AndNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::OrNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::NotNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::CastNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::CallNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::SelectNode* lhs, const ObjectRef& other) final;
};

class AttrsHashHandler :
protected AttrFunctor<size_t(const ObjectRef&)> {
public:
/*!
* \brief Get hash value of node
* \param node The node to be hashed.
*/
size_t Hash(const ObjectRef& node) {
if (!node.defined()) return 0;
return this->VisitAttr(node);
}

protected:
size_t VisitAttrDefault_(const Object* lhs) final;
size_t VisitAttr_(const tir::IntImmNode* lhs) final;
size_t VisitAttr_(const tir::FloatImmNode* lhs) final;
size_t VisitAttr_(const tir::StringImmNode* lhs) final;
size_t VisitAttr_(const ArrayNode* lhs) final;
size_t VisitAttr_(const StrMapNode* lhs) final;
size_t VisitAttr_(const tir::AddNode* op) final;
size_t VisitAttr_(const tir::SubNode* op) final;
size_t VisitAttr_(const tir::MulNode* op) final;
size_t VisitAttr_(const tir::DivNode* op) final;
size_t VisitAttr_(const tir::ModNode* op) final;
size_t VisitAttr_(const tir::FloorDivNode* op) final;
size_t VisitAttr_(const tir::FloorModNode* op) final;
size_t VisitAttr_(const tir::MinNode* op) final;
size_t VisitAttr_(const tir::MaxNode* op) final;
size_t VisitAttr_(const tir::GENode* op) final;
size_t VisitAttr_(const tir::GTNode* op) final;
size_t VisitAttr_(const tir::LENode* op) final;
size_t VisitAttr_(const tir::LTNode* op) final;
size_t VisitAttr_(const tir::EQNode* op) final;
size_t VisitAttr_(const tir::NENode* op) final;
size_t VisitAttr_(const tir::AndNode* op) final;
size_t VisitAttr_(const tir::OrNode* op) final;
size_t VisitAttr_(const tir::NotNode* op) final;
size_t VisitAttr_(const tir::CastNode* op) final;
size_t VisitAttr_(const tir::CallNode* op) final;
size_t VisitAttr_(const tir::SelectNode* op) final;
/*!
* \brief alias of dmlc::HashCombine
* \param lhs The first hash value.
* \param rhs The second hash value.
*/
static size_t Combine(size_t lhs, size_t rhs) {
return dmlc::HashCombine(lhs, rhs);
}
};
} // namespace tvm
#endif // TVM_IR_ATTR_FUNCTOR_H_
Loading

0 comments on commit 6536b35

Please sign in to comment.