Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REFACTOR][IR] Remove AttrsEqual and AttrsHash related code #5169

Merged
merged 1 commit into from
Mar 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 3 additions & 168 deletions include/tvm/ir/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@

#include <dmlc/common.h>
#include <tvm/ir/expr.h>
#include <tvm/node/structural_equal.h>
#include <tvm/node/structural_hash.h>
#include <tvm/runtime/packed_func.h>

#include <unordered_map>
Expand Down Expand Up @@ -131,95 +133,6 @@ class AttrFieldInfo : public ObjectRef {
TVM_DEFINE_OBJECT_REF_METHODS(AttrFieldInfo, ObjectRef, AttrFieldInfoNode);
};

class AttrsHashHandler;
class AttrsEqualHandler;
/*!
* \brief Content-aware Equality comparator for attrs.
*
* This comparator will recursively deep compare the following Attributes.
*
* - IntImm, UIntImm, FloatImm, StringImm
* - Any subclass of BaseAttrsNode
* - Array of Attributes.
* - Map from string to Attributes.
*/
class AttrsEqual {
public:
bool operator()(const double& lhs, const double& rhs) const {
// fuzzy float pt comparison
constexpr double atol = 1e-9;
if (lhs == rhs) return true;
double diff = lhs - rhs;
return diff > -atol && diff < atol;
}

bool operator()(const int64_t& lhs, const int64_t& rhs) const {
return lhs == rhs;
}
bool operator()(const uint64_t& lhs, const uint64_t& rhs) const {
return lhs == rhs;
}
bool operator()(const int& lhs, const int& rhs) const {
return lhs == rhs;
}
bool operator()(const bool& lhs, const bool& rhs) const {
return lhs == rhs;
}
bool operator()(const std::string& lhs, const std::string& rhs) const {
return lhs == rhs;
}
bool operator()(const DataType& lhs, const DataType& rhs) const {
return lhs == rhs;
}
// node comparator
TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const;

protected:
friend class AttrsEqualHandler;
/*! \brief internal handle. */
AttrsEqualHandler* handler_{nullptr};
};

/*!
* \brief Content-aware hash function.
*
* This hash functor will recursively hash the content of the Attributes.
* It is guaranteed that if AttrsEqual(a, b) == true, then AttrsHash(a) == AttrsHash(b);
*/
class AttrsHash {
public:
size_t operator()(const double& value) const {
return std::hash<double>()(value);
}
size_t operator()(const int64_t& value) const {
return std::hash<int64_t>()(value);
}
size_t operator()(const uint64_t& value) const {
return std::hash<uint64_t>()(value);
}
size_t operator()(const int& value) const {
return std::hash<int>()(value);
}
size_t operator()(const bool& value) const {
return std::hash<bool>()(value);
}
size_t operator()(const std::string& value) const {
return std::hash<std::string>()(value);
}
size_t operator()(const DataType& value) const {
return std::hash<int>()(
static_cast<int>(value.code()) |
(static_cast<int>(value.bits()) << 8) |
(static_cast<int>(value.lanes()) << 16));
}
TVM_DLL size_t operator()(const ObjectRef& value) const;

private:
friend class AttrsHashHandler;
/*! \brief internal handle. */
AttrsHashHandler* handler_{nullptr};
};

/*!
* \brief Base class of all attribute class
* \note Do not subclass AttrBaseNode directly,
Expand Down Expand Up @@ -266,20 +179,6 @@ class BaseAttrsNode : public Object {
* \note This function throws when the required field is not present.
*/
TVM_DLL virtual void InitByPackedArgs(const TVMArgs& kwargs, bool allow_unknown = false) = 0;
/*!
* \brief Whether this attribute's content equals to another node.
* \param other The pointer to another node.
* \param equal The equal comparator
* \return The comparison result.
*/
TVM_DLL virtual bool ContentEqual(
const Object* other, AttrsEqual equal) const = 0;
/*!
* \brief Content aware hash.
* \param hasher The hasher to run the hash.
* \return the hash result.
*/
TVM_DLL virtual size_t ContentHash(AttrsHash hasher) const = 0;

static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
Expand Down Expand Up @@ -320,8 +219,6 @@ class DictAttrsNode : public BaseAttrsNode {
void VisitNonDefaultAttrs(AttrVisitor* v) final;
void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final;
Array<AttrFieldInfo> ListFieldInfo() const final;
bool ContentEqual(const Object* other, AttrsEqual equal) const final;
size_t ContentHash(AttrsHash hasher) const final;
// type info
static constexpr const char* _type_key = "DictAttrs";
TVM_DECLARE_FINAL_OBJECT_INFO(DictAttrsNode, BaseAttrsNode);
Expand Down Expand Up @@ -386,34 +283,6 @@ class AttrNormalVisitor {
AttrVisitor* visitor_;
};

// Wrapper for normal visitor.
class AttrsEqualVisitor {
public:
bool result_{true};
// constructor
AttrsEqualVisitor(const Object* lhs, const Object* rhs, const AttrsEqual& equal)
: lhs_(lhs), rhs_(rhs), equal_(equal) {
}
template<typename T>
AttrNopEntry operator()(const char* key, T* lhs_value) {
if (!result_) return AttrNopEntry();
const T* rhs_value =
reinterpret_cast<const T*>(
reinterpret_cast<const char*>(rhs_) +
(reinterpret_cast<const char*>(lhs_value) -
reinterpret_cast<const char*>(lhs_)));
if (!equal_(*lhs_value, *rhs_value)) {
result_ = false;
}
return AttrNopEntry();
}

private:
const Object* lhs_;
const Object* rhs_;
const AttrsEqual& equal_;
};

class AttrsSEqualVisitor {
public:
bool result_{true};
Expand Down Expand Up @@ -441,23 +310,6 @@ class AttrsSEqualVisitor {
const SEqualReducer& equal_;
};

class AttrsHashVisitor {
public:
explicit AttrsHashVisitor(const AttrsHash& hasher)
: hasher_(hasher) {}

size_t result_{0};

template<typename T>
AttrNopEntry operator()(const char* key, T* value) {
result_ = dmlc::HashCombine(result_, hasher_(*value));
return AttrNopEntry();
}

private:
const AttrsHash& hasher_;
};

class AttrsSHashVisitor {
public:
explicit AttrsSHashVisitor(const SHashReducer& hash_reducer)
Expand Down Expand Up @@ -760,7 +612,7 @@ struct AttrTriggerNonDefaultEntry {
return *this;
}
TSelf& set_default(const T& value) {
if (AttrsEqual()(value, *data_)) {
if (tvm::StructuralEqual()(value, *data_)) {
trigger_ = false;
}
return *this;
Expand Down Expand Up @@ -890,23 +742,6 @@ class AttrsNode : public BaseAttrsNode {
return visitor.fields_;
}

bool ContentEqual(const Object* other, AttrsEqual equal) const final {
DerivedType* pself = self();
if (pself == other) return true;
if (other == nullptr) return false;
if (pself->type_index() != other->type_index()) return false;
::tvm::detail::AttrsEqualVisitor visitor(pself, other, equal);
self()->__VisitAttrs__(visitor);
return visitor.result_;
}

size_t ContentHash(AttrsHash hasher) const final {
::tvm::detail::AttrsHashVisitor visitor(hasher);
visitor.result_ = this->GetTypeKeyHash();
self()->__VisitAttrs__(visitor);
return visitor.result_;
}

private:
DerivedType* self() const {
return const_cast<DerivedType*>(
Expand Down
89 changes: 0 additions & 89 deletions src/ir/attr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,94 +147,5 @@ class AttrFunctor<R(const ObjectRef& n, Args...)> {
}
};

class AttrsEqualHandler :
protected AttrFunctor<bool(const ObjectRef&, const ObjectRef&)> {
public:
/*!
* \brief Check if lhs equals rhs
* \param lhs The left operand.
* \param rhs The right operand.
*/
bool Equal(const ObjectRef& lhs, const ObjectRef& rhs);

protected:
bool VisitAttrDefault_(const Object* lhs, const ObjectRef& other) final;
bool VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const StrMapNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::IntImmNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::FloatImmNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::StringImmNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::AddNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::SubNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::MulNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::DivNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::ModNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::FloorDivNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::FloorModNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::MinNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::MaxNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::GENode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::GTNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::LTNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::LENode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::EQNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::NENode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::AndNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::OrNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::NotNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::CastNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::CallNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const tir::SelectNode* lhs, const ObjectRef& other) final;
};

class AttrsHashHandler :
protected AttrFunctor<size_t(const ObjectRef&)> {
public:
/*!
* \brief Get hash value of node
* \param node The node to be hashed.
*/
size_t Hash(const ObjectRef& node) {
if (!node.defined()) return 0;
return this->VisitAttr(node);
}

protected:
size_t VisitAttrDefault_(const Object* lhs) final;
size_t VisitAttr_(const tir::IntImmNode* lhs) final;
size_t VisitAttr_(const tir::FloatImmNode* lhs) final;
size_t VisitAttr_(const tir::StringImmNode* lhs) final;
size_t VisitAttr_(const ArrayNode* lhs) final;
size_t VisitAttr_(const StrMapNode* lhs) final;
size_t VisitAttr_(const tir::AddNode* op) final;
size_t VisitAttr_(const tir::SubNode* op) final;
size_t VisitAttr_(const tir::MulNode* op) final;
size_t VisitAttr_(const tir::DivNode* op) final;
size_t VisitAttr_(const tir::ModNode* op) final;
size_t VisitAttr_(const tir::FloorDivNode* op) final;
size_t VisitAttr_(const tir::FloorModNode* op) final;
size_t VisitAttr_(const tir::MinNode* op) final;
size_t VisitAttr_(const tir::MaxNode* op) final;
size_t VisitAttr_(const tir::GENode* op) final;
size_t VisitAttr_(const tir::GTNode* op) final;
size_t VisitAttr_(const tir::LENode* op) final;
size_t VisitAttr_(const tir::LTNode* op) final;
size_t VisitAttr_(const tir::EQNode* op) final;
size_t VisitAttr_(const tir::NENode* op) final;
size_t VisitAttr_(const tir::AndNode* op) final;
size_t VisitAttr_(const tir::OrNode* op) final;
size_t VisitAttr_(const tir::NotNode* op) final;
size_t VisitAttr_(const tir::CastNode* op) final;
size_t VisitAttr_(const tir::CallNode* op) final;
size_t VisitAttr_(const tir::SelectNode* op) final;
/*!
* \brief alias of dmlc::HashCombine
* \param lhs The first hash value.
* \param rhs The second hash value.
*/
static size_t Combine(size_t lhs, size_t rhs) {
return dmlc::HashCombine(lhs, rhs);
}
};
} // namespace tvm
#endif // TVM_IR_ATTR_FUNCTOR_H_
Loading