Skip to content

Commit

Permalink
[IR] Platform-independent SHash (#14204)
Browse files Browse the repository at this point in the history
This PR introduces the changes necessary to make structural hashing
platform-independent. The changes include:
- Explicitly require the hash type to be `uint64_t` rather than
  platform-dependent `size_t`.
- Implement structural hash for POD types including `double`, `int`
  by unioning them with `uint64_t` explicitly.
- Implement a platform-independent fast hashing algorithm for `std::string`
  • Loading branch information
junrushao authored Mar 7, 2023
1 parent 9b91247 commit 012d6a7
Show file tree
Hide file tree
Showing 5 changed files with 151 additions and 82 deletions.
2 changes: 1 addition & 1 deletion include/tvm/node/reflection.h
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ class ReflectionVTable::Registry {
* static constexpr const std::nullptr_t VisitAttrs = nullptr;
*
* static void SHashReduce(const runtime::StringObj* key, SHashReducer hash_reduce) {
* hash_reduce->SHashReduceHashedValue(runtime::String::HashBytes(key->data, key->size));
* hash_reduce->SHashReduceHashedValue(runtime::String::StableHashBytes(key->data, key->size));
* }
*
* static bool SEqualReduce(const runtime::StringObj* lhs,
Expand Down
74 changes: 41 additions & 33 deletions include/tvm/node/structural_hash.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,52 +36,60 @@ namespace tvm {
* \brief Hash definition of base value classes.
*/
class BaseValueHash {
public:
size_t operator()(const double& key) const { return std::hash<double>()(key); }

size_t operator()(const int64_t& key) const { return std::hash<int64_t>()(key); }

size_t operator()(const uint64_t& key) const { return std::hash<uint64_t>()(key); }

size_t operator()(const int& key) const { return std::hash<int>()(key); }

size_t operator()(const bool& key) const { return std::hash<bool>()(key); }

size_t operator()(const std::string& key) const { return std::hash<std::string>()(key); }

size_t operator()(const runtime::DataType& key) const {
return std::hash<int32_t>()(static_cast<int32_t>(key.code()) |
(static_cast<int32_t>(key.bits()) << 8) |
(static_cast<int32_t>(key.lanes()) << 16));
protected:
/*!
 * \brief Reinterpret the bit pattern of a value of type T as an unsigned integer of type U.
 *
 * \tparam T The source type.
 * \tparam U The integer type the bits are read back as; the static_asserts below
 *           require it to have exactly the same size as T.
 * \param value The value whose bits are reinterpreted.
 * \return The bits of `value` read through U, widened to uint64_t.
 */
template <typename T, typename U>
uint64_t Reinterpret(T value) const {
// Overlay both views on the same storage.
union Union {
T a;
U b;
} u;
// Compile-time guard: T, U and the union must all have the same size, so no
// padding or truncation can take part in the reinterpretation.
static_assert(sizeof(Union) == sizeof(T), "sizeof(Union) != sizeof(T)");
static_assert(sizeof(Union) == sizeof(U), "sizeof(Union) != sizeof(U)");
// Clear the integer view first, then overwrite the storage through the source view.
u.b = 0;
u.a = value;
// NOTE(review): reading `b` after writing `a` relies on union type punning, which
// ISO C++ does not guarantee -- confirm it is supported by every targeted compiler.
return u.b;
}

public:
uint64_t operator()(const float& key) const { return Reinterpret<float, uint32_t>(key); }
uint64_t operator()(const double& key) const { return Reinterpret<double, uint64_t>(key); }
uint64_t operator()(const int64_t& key) const { return Reinterpret<int64_t, uint64_t>(key); }
uint64_t operator()(const uint64_t& key) const { return key; }
uint64_t operator()(const int& key) const { return Reinterpret<int, uint32_t>(key); }
uint64_t operator()(const bool& key) const { return key; }
uint64_t operator()(const runtime::DataType& key) const {
return Reinterpret<DLDataType, uint32_t>(key);
}
template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
bool operator()(const ENum& key) const {
return std::hash<size_t>()(static_cast<size_t>(key));
uint64_t operator()(const ENum& key) const {
return Reinterpret<int64_t, uint64_t>(static_cast<int64_t>(key));
}
uint64_t operator()(const std::string& key) const {
return runtime::String::StableHashBytes(key.data(), key.length());
}
};

/*!
* \brief Content-aware structural hasing.
* \brief Content-aware structural hashing.
*
* The structural hash value is recursively defined in the DAG of IRNodes.
* There are two kinds of nodes:
*
* - Normal node: the hash value is defined by its content and type only.
* - Graph node: each graph node will be assigned a unique index ordered by the
* first occurence during the visit. The hash value of a graph node is
* first occurrence during the visit. The hash value of a graph node is
* combined from the hash values of its contents and the index.
*/
class StructuralHash : public BaseValueHash {
public:
// inheritate operator()
// inherit operator()
using BaseValueHash::operator();
/*!
* \brief Compute structural hashing value for an object.
* \param key The left operand.
* \return The hash value.
*/
TVM_DLL size_t operator()(const ObjectRef& key) const;
TVM_DLL uint64_t operator()(const ObjectRef& key) const;
};

/*!
Expand Down Expand Up @@ -109,23 +117,23 @@ class SHashReducer {
*
* \param hashed_value The hashed value
*/
virtual void SHashReduceHashedValue(size_t hashed_value) = 0;
virtual void SHashReduceHashedValue(uint64_t hashed_value) = 0;
/*!
* \brief Append hash value of key to the current sequence of hashes.
*
* \param key The object to compute hash from.
* \param map_free_vars Whether to map free variables by their occurence number.
* \param map_free_vars Whether to map free variables by their occurrence number.
*/
virtual void SHashReduce(const ObjectRef& key, bool map_free_vars) = 0;
/*!
* \brief Apppend a hash value of free variable to the current sequence of hashes.
* \brief Append a hash value of free variable to the current sequence of hashes.
*
* \param var The var of interest.
* \param map_free_vars Whether to map free variables by their occurence number.
* \param map_free_vars Whether to map free variables by their occurrence number.
*
* \note If map_free_vars is set to be true,
* internally the handler can maintain a counter to encode free variables
* by their order of occurence. This helps to resolve variable
* by their order of occurrence. This helps to resolve variable
* mapping of function parameters and let binding variables.
*
* If map_free_vars is set to be false, the address of the variable will be used.
Expand All @@ -139,7 +147,7 @@ class SHashReducer {
*
* \return Whether there is already a pre-computed hash value.
*/
virtual bool LookupHashedValue(const ObjectRef& key, size_t* hashed_value) = 0;
virtual bool LookupHashedValue(const ObjectRef& key, uint64_t* hashed_value) = 0;
/*!
* \brief Mark current comparison as graph node in hashing.
* Graph node hash will depend on the graph structure.
Expand Down Expand Up @@ -193,7 +201,7 @@ class SHashReducer {
/*! \brief Internal class pointer. */
Handler* handler_;
/*!
* \brief Whether or not to map free variables by their occurence
* \brief Whether or not to map free variables by their occurrence
* If the flag is false, then free variables will be mapped
* by their in-memory address.
*/
Expand All @@ -210,10 +218,10 @@ class SHashHandlerDefault : public SHashReducer::Handler {
SHashHandlerDefault();
virtual ~SHashHandlerDefault();

void SHashReduceHashedValue(size_t hashed_value) override;
void SHashReduceHashedValue(uint64_t hashed_value) override;
void SHashReduce(const ObjectRef& key, bool map_free_vars) override;
void SHashReduceFreeVar(const runtime::Object* var, bool map_free_vars) override;
bool LookupHashedValue(const ObjectRef& key, size_t* hashed_value) override;
bool LookupHashedValue(const ObjectRef& key, uint64_t* hashed_value) override;
void MarkGraphNode() override;

/*!
Expand All @@ -222,7 +230,7 @@ class SHashHandlerDefault : public SHashReducer::Handler {
* \param map_free_vars Whether or not to remap variables if possible.
* \return The hash result.
*/
virtual size_t Hash(const ObjectRef& object, bool map_free_vars);
virtual uint64_t Hash(const ObjectRef& object, bool map_free_vars);

protected:
/*!
Expand Down
73 changes: 67 additions & 6 deletions include/tvm/runtime/container/string.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#ifndef TVM_RUNTIME_CONTAINER_STRING_H_
#define TVM_RUNTIME_CONTAINER_STRING_H_

#include <dmlc/endian.h>
#include <dmlc/logging.h>
#include <tvm/runtime/container/base.h>
#include <tvm/runtime/logging.h>
Expand Down Expand Up @@ -247,10 +248,70 @@ class String : public ObjectRef {
* \param size The size of the bytes.
* \return the hash value.
*/
static size_t HashBytes(const char* data, size_t size) {
// This function falls back to string copy with c++11 compiler and is
// recommended to be compiled with c++14
return std::hash<std::string_view>()(std::string_view(data, size));
static uint64_t StableHashBytes(const char* data, size_t size) {
  // The input is consumed in 8-byte chunks. Each chunk is read as a little-endian
  // 64-bit word (byte i lands in bits [8*i, 8*i+8)); a final short chunk is
  // zero-padded in its high bytes. The running hash is then updated as
  //   result = (result * kMultiplier + word) % kMod
  // so the returned value is always smaller than kMod. The multiplication may wrap
  // around modulo 2^64; unsigned wraparound is well defined and is part of the
  // resulting hash value, so it must not be "fixed".
  //
  // The word is assembled with explicit shifts instead of a union plus per-endianness
  // byte shuffling: shifts give the same value on every host without reading an
  // inactive union member and without depending on an endianness macro. The loop
  // tracks the remaining byte count so that no pointer beyond one-past-the-end of the
  // buffer is ever formed, even for inputs shorter than 8 bytes.
  constexpr uint64_t kMultiplier = 1099511628211ULL;
  constexpr uint64_t kMod = 2147483647ULL;
  constexpr size_t kChunkBytes = sizeof(uint64_t);

  const char* it = data;
  size_t remaining = size;
  uint64_t result = 0;
  while (remaining > 0) {
    const size_t count = remaining < kChunkBytes ? remaining : kChunkBytes;
    uint64_t word = 0;
    for (size_t i = 0; i < count; ++i) {
      // Go through uint8_t first so that bytes >= 0x80 are not sign-extended when
      // `char` is a signed type.
      word |= static_cast<uint64_t>(static_cast<uint8_t>(it[i])) << (8 * i);
    }
    result = (result * kMultiplier + word) % kMod;
    it += count;
    remaining -= count;
  }
  return result;
}

TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(String, ObjectRef, StringObj);
Expand Down Expand Up @@ -448,7 +509,7 @@ inline int String::memncmp(const char* lhs, const char* rhs, size_t lhs_count, s

inline size_t ObjectHash::operator()(const ObjectRef& a) const {
if (const auto* str = a.as<StringObj>()) {
return String::HashBytes(str->data, str->size);
return String::StableHashBytes(str->data, str->size);
}
return ObjectPtrHash()(a);
}
Expand Down Expand Up @@ -476,7 +537,7 @@ namespace std {
template <>
struct hash<::tvm::runtime::String> {
std::size_t operator()(const ::tvm::runtime::String& str) const {
return ::tvm::runtime::String::HashBytes(str.data(), str.size());
return ::tvm::runtime::String::StableHashBytes(str.data(), str.size());
}
};
} // namespace std
Expand Down
Loading

0 comments on commit 012d6a7

Please sign in to comment.