diff --git a/include/tvm/relay/analysis.h b/include/tvm/relay/analysis.h index 51eae5a9ab7d..e04b4e639dd8 100644 --- a/include/tvm/relay/analysis.h +++ b/include/tvm/relay/analysis.h @@ -225,28 +225,6 @@ TVM_DLL Map CollectDeviceAnnotationOps(const Expr& expr); */ TVM_DLL Array UnmatchedCases(const Match& match, const IRModule& mod); -/*! \brief A hashing structure in the style of std::hash. */ -struct StructuralHash { - /*! \brief Hash a Relay type. - * - * Implements structural hashing of a Relay type. - * - * \param type the type to hash. - * - * \return the hash value. - */ - size_t operator()(const Type& type) const; - /*! \brief Hash a Relay expression. - * - * Implements structural hashing of a Relay expression. - * - * \param expr the expression to hash. - * - * \return the hash value. - */ - size_t operator()(const Expr& expr) const; -}; - } // namespace relay } // namespace tvm diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py index b09a40bb9957..21f3edfb99eb 100644 --- a/python/tvm/relay/analysis/analysis.py +++ b/python/tvm/relay/analysis/analysis.py @@ -20,11 +20,10 @@ This file contains the set of passes for Relay, which exposes an interface for configuring the passes and scripting them in Python. """ -from tvm.ir import RelayExpr, IRModule +from tvm.ir import IRModule from . import _ffi_api from .feature import Feature -from ..ty import Type def post_order_visit(expr, fvisit): @@ -314,29 +313,6 @@ def detect_feature(a, b=None): return {Feature(int(x)) for x in _ffi_api.detect_feature(a, b)} -def structural_hash(value): - """Hash a Relay expression structurally. - - Parameters - ---------- - expr : Union[tvm.relay.Expr, tvm.relay.Type] - The expression to hash. - - Returns - ------- - result : int - The hash value - """ - if isinstance(value, RelayExpr): - return int(_ffi_api._expr_hash(value)) - elif isinstance(value, Type): - return int(_ffi_api._type_hash(value)) - else: - msg = ("found value of type {0} expected" + - "relay.Expr or relay.Type").format(type(value)) - raise TypeError(msg) - - def extract_fused_functions(mod): """Pass to extract IRModule of only fused primitive functions. diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index d0b90e5fb295..56aa1d6dcaf8 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -27,7 +27,7 @@ from tvm.ir import IRModule from tvm.relay.prelude import Prelude -from tvm.relay.analysis import structural_hash as s_hash +from tvm.ir import structural_hash as s_hash from .. import analysis from .. import expr as _expr diff --git a/python/tvm/relay/testing/py_converter.py b/python/tvm/relay/testing/py_converter.py index e85000052231..eec5e16fdd13 100644 --- a/python/tvm/relay/testing/py_converter.py +++ b/python/tvm/relay/testing/py_converter.py @@ -238,7 +238,7 @@ def create_op_call(self, op: Function, relay_args, py_args): # compile the function and register globally cc_key = compile_engine.CCacheKey(op, self.tgt) - func_hash = relay.analysis.structural_hash(op) + func_hash = tvm.ir.structural_hash(op) op_name = '_lowered_op_{}'.format(func_hash) if not tvm.get_global_func(op_name, allow_missing=True): jitted = self.engine.jit(cc_key, self.tgt) diff --git a/src/relay/analysis/extract_fused_functions.cc b/src/relay/analysis/extract_fused_functions.cc index 8cb517f7e33d..ff3756cd318d 100644 --- a/src/relay/analysis/extract_fused_functions.cc +++ b/src/relay/analysis/extract_fused_functions.cc @@ -21,6 +21,7 @@ * \file extract_fused_functions.cc * \brief Apply fusion and extract fused primitive functions from an IRModule */ +#include #include #include #include @@ -55,7 +56,7 @@ class FusedFunctionExtractorWrapper : private ExprVisitor { if (n->HasNonzeroAttr(attr::kPrimitive)) { // Add function to functions, keyed by function hash string Function func = Function(n->params, n->body, n->ret_type, n->type_params, n->attrs); - size_t hash_ = StructuralHash()(func); + size_t hash_ = tvm::StructuralHash()(func); this->functions.Set(std::to_string(hash_), func); } diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h index eec2bd344f15..9bd6a4ef31b6 100644 --- a/src/relay/backend/compile_engine.h +++ b/src/relay/backend/compile_engine.h @@ -26,6 +26,7 @@ #define TVM_RELAY_BACKEND_COMPILE_ENGINE_H_ #include +#include #include #include #include @@ -258,7 +259,7 @@ bool IsDynamic(const Type& ty); inline size_t CCacheKeyNode::Hash() const { if (hash_ != 0) return hash_; // do structral hash, avoid 0. - hash_ = StructuralHash()(this->source_func); + hash_ = tvm::StructuralHash()(this->source_func); hash_ = dmlc::HashCombine( hash_, std::hash()(target->str())); if (hash_ == 0) hash_ = 1; diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index 398760ffa789..80745e1a1114 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -23,6 +23,7 @@ */ #include +#include #include #include #include @@ -39,7 +40,7 @@ namespace relay { namespace vm { inline std::string GenerateName(const Function& func) { - size_t hash = StructuralHash()(func); + size_t hash = tvm::StructuralHash()(func); return std::string("lifted_name") + std::to_string(hash); } diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc deleted file mode 100644 index ce15e2a3fe70..000000000000 --- a/src/relay/ir/hash.cc +++ /dev/null @@ -1,437 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/relay/ir/hash.cc - * \brief Hash functions for Relay types and expressions. - */ -#include -#include -#include -#include -#include -#include -#include -#include "../../ir/attr_functor.h" - -namespace tvm { -namespace relay { - -// Hash handler for Relay. -class RelayHashHandler: - public AttrsHashHandler, - public TypeFunctor, - public ExprFunctor, - public PatternFunctor { - public: - explicit RelayHashHandler() {} - - /*! - * Compute hash of a node. - * \param ref The node to hash. - * \return the hash value. - */ - size_t Hash(const ObjectRef& ref) { - if (!ref.defined()) return ObjectHash()(ref); - - if (ref->IsInstance()) { - return TypeHash(Downcast(ref)); - } - if (ref->IsInstance()) { - return ExprHash(Downcast(ref)); - } - return AttrHash(ref); - } - - /*! - * Compute hash of the attributes. - * \param ref The attributes. - * \return the hash value - */ - size_t AttrHash(const ObjectRef& ref) { - if (!ref.defined()) { - return ObjectHash()(ref); - } - return AttrsHashHandler::Hash(ref); - } - /*! - * Compute hash of a Relay type. - * \param ref The type to hash. - * \param rhs The right hand operand. - * \return the hash value. - */ - size_t TypeHash(const Type& type) { - if (!type.defined()) { - return ObjectHash()(type); - } - auto found = hash_map_.find(type); - if (found != hash_map_.end()) { - return found->second; - } else { - auto hash = this->VisitType(type); - hash_map_.insert({type, hash}); - return hash; - } - } - /*! - * Compute the hash of an expression. - * - * \note We run graph structural equality checking when comparing two Exprs. - * This means that AlphaEqualHandler can only be used once for each pair. - * The equality checker checks data-flow equvalence of the Expr DAG. - * This function also runs faster as it memomizes equal_map. - * - * \param expr The expression to hash. - * \return the hash value. - */ - size_t ExprHash(const Expr& expr) { - if (!expr.defined()) { - return ObjectHash()(expr); - } - auto found = hash_map_.find(expr); - if (found != hash_map_.end()) { - return found->second; - } else { - auto hash = this->VisitExpr(expr); - hash_map_.insert({expr, hash}); - return hash; - } - } - - protected: - /*! - * \brief Hash a DataType. - * \param dtype The dtype to hash. - * \return the hash value. - */ - size_t DataTypeHash(const DataType& dtype) { - return ::tvm::AttrsHash()(dtype); - } - - using AttrsHashHandler::VisitAttr_; - size_t VisitAttr_(const tvm::tir::VarNode* var) final { - size_t hash = std::hash()(VarNode::_type_key); - auto it = hash_map_.find(GetRef(var)); - if (it != hash_map_.end()) { - return it->second; - } - return Combine(hash, std::hash()(var->name_hint)); - } - - // Type hashing - size_t VisitType_(const TensorTypeNode* tensor_type) final { - size_t hash = std::hash()(TensorTypeNode::_type_key); - hash = Combine(hash, DataTypeHash(tensor_type->dtype)); - hash = Combine(hash, Hash(tensor_type->shape)); - return hash; - } - - size_t VisitType_(const IncompleteTypeNode* incomplete) final { - size_t hash = std::hash()(IncompleteTypeNode::_type_key); - return Combine(hash, std::hash()(incomplete->kind)); - } - - size_t VisitType_(const TypeVarNode* tyvar) final { - /* - TypeVar/Var/Variable have two locations where they are hashed: - - The declaration site of a function, let, or function type. - The first occurence in the term. - - We will only reach this code if the TypeVar itself is unbound, we assign - a free variable index to it, meaning this hashing function implements - structural equality for both open (i.e graph equality) and closed terms - (i.e alpha_equality). - */ - return BindVar(GetRef(tyvar)); - } - - size_t VisitType_(const FuncTypeNode* func_type) final { - size_t hash = std::hash()(FuncTypeNode::_type_key); - - for (auto type_param : func_type->type_params) { - hash = Combine(hash, BindVar(type_param)); - } - - for (auto arg : func_type->arg_types) { - hash = Combine(hash, TypeHash(arg)); - } - - hash = Combine(hash, TypeHash(func_type->ret_type)); - for (auto cs : func_type->type_constraints) { - hash = Combine(hash, TypeHash(cs)); - } - - return hash; - } - - size_t VisitType_(const TypeRelationNode* type_rel) final { - size_t hash = std::hash()(TypeRelationNode::_type_key); - hash = Combine(hash, std::hash()(type_rel->func->name)); - hash = Combine(hash, AttrHash(type_rel->attrs)); - - for (auto arg : type_rel->args) { - hash = Combine(hash, TypeHash(arg)); - } - - return hash; - } - - size_t VisitType_(const TupleTypeNode* tuple_type) final { - size_t hash = std::hash()(TupleTypeNode::_type_key); - for (size_t i = 0; i < tuple_type->fields.size(); i++) { - hash = Combine(hash, TypeHash(tuple_type->fields[i])); - } - return hash; - } - - size_t VisitType_(const RelayRefTypeNode* rtn) final { - size_t hash = std::hash()(RelayRefTypeNode::_type_key); - hash = Combine(hash, TypeHash(rtn->value)); - return hash; - } - - // Expr hashing. - size_t NDArrayHash(const runtime::NDArray& array) { - size_t hash = std::hash()(array->dtype.code); - hash = Combine(hash, std::hash()(array->dtype.bits)); - hash = Combine(hash, std::hash()(array->dtype.lanes)); - CHECK_EQ(array->ctx.device_type, kDLCPU) << "can only compare CPU tensor"; - size_t data_size = runtime::GetDataSize(*array.operator->()); - uint8_t * data = reinterpret_cast(array->data); - for (size_t i = 0; i < data_size; i++) { - hash = Combine(hash, std::hash()(data[i])); - } - return hash; - } - - size_t BindVar(const ObjectRef& var) { - size_t hash = std::hash()(var_counter++); - CHECK_EQ(hash_map_.count(var), 0); - if (auto var_node = var.as()) { - hash = Combine(hash, TypeHash(var_node->type_annotation)); - } - hash_map_[var] = hash; - return hash; - } - - size_t VisitExpr_(const VarNode* var) final { - // hash free variable - size_t name_hash = std::hash()(var->vid.get()); - return Combine(name_hash, TypeHash(var->type_annotation)); - } - - size_t VisitExpr_(const GlobalVarNode* global) final { - return std::hash()(global->name_hint); - } - - size_t VisitExpr_(const TupleNode* tuple) final { - size_t hash = std::hash()(TupleNode::_type_key); - for (size_t i = 0; i < tuple->fields.size(); i++) { - hash = Combine(hash, ExprHash(tuple->fields[i])); - } - return hash; - } - - size_t VisitExpr_(const FunctionNode* func) final { - size_t hash = std::hash()(FunctionNode::_type_key); - for (auto type_param : func->type_params) { - hash = Combine(hash, BindVar(type_param)); - } - - for (auto param : func->params) { - hash = Combine(hash, BindVar(param)); - } - - hash = Combine(hash, TypeHash(func->ret_type)); - hash = Combine(hash, ExprHash(func->body)); - - hash = Combine(hash, AttrHash(func->attrs)); - - return hash; - } - - size_t VisitExpr_(const CallNode* call) final { - size_t hash = std::hash()(CallNode::_type_key); - hash = Combine(hash, ExprHash(call->op)); - - for (auto arg : call->args) { - hash = Combine(hash, ExprHash(arg)); - } - - for (auto t : call->type_args) { - CHECK(t.defined()); - hash = Combine(hash, TypeHash(t)); - } - - hash = Combine(hash, AttrHash(call->attrs)); - - return hash; - } - - size_t VisitExpr_(const LetNode* let) final { - size_t hash = std::hash()(LetNode::_type_key); - hash = Combine(hash, BindVar(let->var)); - hash = Combine(hash, ExprHash(let->value)); - hash = Combine(hash, ExprHash(let->body)); - return hash; - } - - size_t VisitExpr_(const IfNode* ite) final { - size_t key = std::hash()(IfNode::_type_key); - size_t hash = key; - hash = Combine(hash, ExprHash(ite->cond)); - hash = Combine(hash, ExprHash(ite->true_branch)); - hash = Combine(hash, ExprHash(ite->false_branch)); - return hash; - } - - size_t VisitExpr_(const OpNode* op) final { - return ObjectHash()(GetRef(op)); - } - - size_t VisitExpr_(const ConstantNode* rconst) final { - return NDArrayHash(rconst->data); - } - - size_t VisitExpr_(const TupleGetItemNode* get_item) final { - size_t hash = std::hash()(TupleGetItemNode::_type_key); - hash = Combine(hash, ExprHash(get_item->tuple)); - hash = Combine(hash, std::hash()(get_item->index)); - return hash; - } - - size_t VisitExpr_(const RefCreateNode* rn) final { - size_t hash = std::hash()(RefCreateNode::_type_key); - hash = Combine(hash, ExprHash(rn->value)); - return hash; - } - - size_t VisitExpr_(const RefReadNode* rn) final { - size_t hash = std::hash()(RefReadNode::_type_key); - hash = Combine(hash, ExprHash(rn->ref)); - return hash; - } - - size_t VisitExpr_(const RefWriteNode* rn) final { - size_t hash = std::hash()(RefWriteNode::_type_key); - hash = Combine(hash, ExprHash(rn->ref)); - hash = Combine(hash, ExprHash(rn->value)); - return hash; - } - - size_t VisitExpr_(const MatchNode* mn) final { - size_t hash = std::hash()(MatchNode::_type_key); - hash = Combine(hash, ExprHash(mn->data)); - for (const auto& c : mn->clauses) { - hash = Combine(hash, PatternHash(c->lhs)); - hash = Combine(hash, ExprHash(c->rhs)); - } - hash = Combine(hash, std::hash()(mn->complete)); - return hash; - } - - size_t VisitExpr_(const ConstructorNode* cn) final { - size_t hash = std::hash()(ConstructorNode::_type_key); - hash = Combine(hash, std::hash()(cn->name_hint)); - return hash; - } - - size_t VisitType_(const TypeCallNode* tcn) final { - size_t hash = std::hash()(TypeCallNode::_type_key); - hash = Combine(hash, TypeHash(tcn->func)); - for (const auto& t : tcn->args) { - hash = Combine(hash, TypeHash(t)); - } - return hash; - } - - size_t VisitType_(const TypeDataNode* tdn) final { - size_t hash = std::hash()(TypeDataNode::_type_key); - hash = Combine(hash, TypeHash(tdn->header)); - for (const auto& tv : tdn->type_vars) { - hash = Combine(hash, TypeHash(tv)); - } - for (const auto& cn : tdn->constructors) { - hash = Combine(hash, ExprHash(cn)); - } - return hash; - } - - size_t VisitType_(const GlobalTypeVarNode* tvn) final { - return BindVar(GetRef(tvn)); - } - - size_t PatternHash(const Pattern& p) { - return VisitPattern(p); - } - - size_t VisitPattern_(const PatternConstructorNode* pcn) final { - size_t hash = std::hash()(PatternConstructorNode::_type_key); - hash = Combine(hash, ExprHash(pcn->constructor)); - for (const auto& p : pcn->patterns) { - hash = Combine(hash, PatternHash(p)); - } - return hash; - } - - size_t VisitPattern_(const PatternTupleNode* ptn) final { - size_t hash = std::hash()(PatternTupleNode::_type_key); - for (const auto& p : ptn->patterns) { - hash = Combine(hash, PatternHash(p)); - } - return hash; - } - - size_t VisitPattern_(const PatternVarNode* pvn) final { - size_t hash = std::hash()(PatternVarNode::_type_key); - hash = Combine(hash, BindVar(pvn->var)); - return hash; - } - - size_t VisitPattern_(const PatternWildcardNode* pwn) final { - size_t hash = std::hash()(PatternWildcardNode::_type_key); - return hash; - } - private: - // renaming of NodeRef to indicate two nodes equals to each other - std::unordered_map hash_map_; - int var_counter = 0; -}; - -size_t StructuralHash::operator()(const Type& type) const { - return RelayHashHandler().TypeHash(type); -} - -size_t StructuralHash::operator()(const Expr& expr) const { - return RelayHashHandler().ExprHash(expr); -} - -TVM_REGISTER_GLOBAL("relay.analysis._expr_hash") -.set_body_typed([](ObjectRef ref) { - return static_cast(RelayHashHandler().Hash(ref)); -}); - -TVM_REGISTER_GLOBAL("relay.analysis._type_hash") -.set_body_typed([](Type type) { - return static_cast(RelayHashHandler().TypeHash(type)); -}); - -} // namespace relay -} // namespace tvm diff --git a/tests/python/relay/test_pass_qnn_legalize.py b/tests/python/relay/test_pass_qnn_legalize.py index e7980e712035..c291c4e6b170 100644 --- a/tests/python/relay/test_pass_qnn_legalize.py +++ b/tests/python/relay/test_pass_qnn_legalize.py @@ -31,7 +31,8 @@ def alpha_equal(x, y): """ x = x['main'] y = y['main'] - return tvm.ir.structural_equal(x, y) and analysis.structural_hash(x) == analysis.structural_hash(y) + return tvm.ir.structural_equal(x, y) and \ + tvm.ir.structural_hash(x) == tvm.ir.structural_hash(y) def run_opt_pass(expr, passes): passes = passes if isinstance(passes, list) else [passes]