Skip to content

Commit

Permalink
relay::StructuralHash to tvm::StructuralHash (#5166)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics authored Mar 29, 2020
1 parent 919ae88 commit a2edd01
Show file tree
Hide file tree
Showing 9 changed files with 11 additions and 490 deletions.
22 changes: 0 additions & 22 deletions include/tvm/relay/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -225,28 +225,6 @@ TVM_DLL Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr);
*/
TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const IRModule& mod);

/*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash {
/*! \brief Hash a Relay type.
*
* Implements structural hashing of a Relay type.
*
* \param type the type to hash.
*
* \return the hash value.
*/
size_t operator()(const Type& type) const;
/*! \brief Hash a Relay expression.
*
* Implements structural hashing of a Relay expression.
*
* \param expr the expression to hash.
*
* \return the hash value.
*/
size_t operator()(const Expr& expr) const;
};

} // namespace relay
} // namespace tvm

Expand Down
26 changes: 1 addition & 25 deletions python/tvm/relay/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,10 @@
This file contains the set of passes for Relay, which exposes an interface for
configuring the passes and scripting them in Python.
"""
from tvm.ir import RelayExpr, IRModule
from tvm.ir import IRModule

from . import _ffi_api
from .feature import Feature
from ..ty import Type


def post_order_visit(expr, fvisit):
Expand Down Expand Up @@ -314,29 +313,6 @@ def detect_feature(a, b=None):
return {Feature(int(x)) for x in _ffi_api.detect_feature(a, b)}


def structural_hash(value):
"""Hash a Relay expression structurally.
Parameters
----------
expr : Union[tvm.relay.Expr, tvm.relay.Type]
The expression to hash.
Returns
-------
result : int
The hash value
"""
if isinstance(value, RelayExpr):
return int(_ffi_api._expr_hash(value))
elif isinstance(value, Type):
return int(_ffi_api._type_hash(value))
else:
msg = ("found value of type {0} expected" +
"relay.Expr or relay.Type").format(type(value))
raise TypeError(msg)


def extract_fused_functions(mod):
"""Pass to extract IRModule of only fused primitive functions.
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

from tvm.ir import IRModule
from tvm.relay.prelude import Prelude
from tvm.relay.analysis import structural_hash as s_hash
from tvm.ir import structural_hash as s_hash

from .. import analysis
from .. import expr as _expr
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/testing/py_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def create_op_call(self, op: Function, relay_args, py_args):

# compile the function and register globally
cc_key = compile_engine.CCacheKey(op, self.tgt)
func_hash = relay.analysis.structural_hash(op)
func_hash = tvm.ir.structural_hash(op)
op_name = '_lowered_op_{}'.format(func_hash)
if not tvm.get_global_func(op_name, allow_missing=True):
jitted = self.engine.jit(cc_key, self.tgt)
Expand Down
3 changes: 2 additions & 1 deletion src/relay/analysis/extract_fused_functions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
* \file extract_fused_functions.cc
* \brief Apply fusion and extract fused primitive functions from an IRModule
*/
#include <tvm/node/structural_hash.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
Expand Down Expand Up @@ -55,7 +56,7 @@ class FusedFunctionExtractorWrapper : private ExprVisitor {
if (n->HasNonzeroAttr(attr::kPrimitive)) {
// Add function to functions, keyed by function hash string
Function func = Function(n->params, n->body, n->ret_type, n->type_params, n->attrs);
size_t hash_ = StructuralHash()(func);
size_t hash_ = tvm::StructuralHash()(func);
this->functions.Set(std::to_string(hash_), func);
}

Expand Down
3 changes: 2 additions & 1 deletion src/relay/backend/compile_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#define TVM_RELAY_BACKEND_COMPILE_ENGINE_H_

#include <tvm/node/structural_equal.h>
#include <tvm/node/structural_hash.h>
#include <tvm/tir/lowered_func.h>
#include <tvm/runtime/module.h>
#include <tvm/relay/analysis.h>
Expand Down Expand Up @@ -258,7 +259,7 @@ bool IsDynamic(const Type& ty);
inline size_t CCacheKeyNode::Hash() const {
if (hash_ != 0) return hash_;
// do structral hash, avoid 0.
hash_ = StructuralHash()(this->source_func);
hash_ = tvm::StructuralHash()(this->source_func);
hash_ = dmlc::HashCombine(
hash_, std::hash<std::string>()(target->str()));
if (hash_ == 0) hash_ = 1;
Expand Down
3 changes: 2 additions & 1 deletion src/relay/backend/vm/lambda_lift.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
*/

#include <tvm/node/structural_equal.h>
#include <tvm/node/structural_hash.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/support/logging.h>
Expand All @@ -39,7 +40,7 @@ namespace relay {
namespace vm {

inline std::string GenerateName(const Function& func) {
size_t hash = StructuralHash()(func);
size_t hash = tvm::StructuralHash()(func);
return std::string("lifted_name") + std::to_string(hash);
}

Expand Down
Loading

0 comments on commit a2edd01

Please sign in to comment.