From dcee93770cb2c8658dcdbfa96ef326e79e9f3788 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Wed, 2 Oct 2019 15:39:54 -0700 Subject: [PATCH] [QNN][Relay] Calling Dialect passes from inside Relay Build API. (#3971) --- include/tvm/relay/op.h | 16 ++++++ include/tvm/relay/qnn/transform.h | 60 ++++++++++++++++++++ src/relay/backend/build_module.cc | 15 +++-- src/relay/ir/op.cc | 11 ++++ src/relay/pass/legalize.cc | 58 +++++++++++-------- src/relay/qnn/pass/legalize.cc | 47 +++++++++++++++ tests/python/relay/test_op_qnn_conv2d.py | 1 - tests/python/relay/test_op_qnn_dequantize.py | 1 - tests/python/relay/test_op_qnn_quantize.py | 1 - tests/python/relay/test_op_qnn_requantize.py | 1 - 10 files changed, 177 insertions(+), 34 deletions(-) create mode 100644 include/tvm/relay/qnn/transform.h create mode 100644 src/relay/qnn/pass/legalize.cc diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index e4c96490ff53..0a6d3725655f 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -153,6 +153,12 @@ class Op : public relay::Expr { */ template inline static OpMap GetAttr(const std::string& attr_name); + /*! + * \brief Checks if an attr is present in the registry. + * \param attr_name The name of the attribute. + * \return bool True if the attr is present. + */ + inline static bool HasAttr(const std::string& attr_name); /*! * \brief Get an Op for a given operator name. * Will raise an error if the op has not been registered. @@ -171,6 +177,12 @@ class Op : public relay::Expr { * \return reference to GenericOpMap */ TVM_DLL static const GenericOpMap& GetGenericAttr(const std::string& key); + /*! + * \brief Checks if the key is present in the registry + * \param key The attribute key + * \return bool True if the key is present + */ + TVM_DLL static const bool HasGenericAttr(const std::string& key); }; /*! \brief Helper structure to register operators */ @@ -393,6 +405,10 @@ inline OpMap Op::GetAttr(const std::string& key) { return OpMap(Op::GetGenericAttr(key)); } +inline bool Op::HasAttr(const std::string& key) { + return Op::HasGenericAttr(key); +} + inline OpNode* OpRegistry::get() { return const_cast(op_.operator->()); } diff --git a/include/tvm/relay/qnn/transform.h b/include/tvm/relay/qnn/transform.h new file mode 100644 index 000000000000..10cd19afe6f3 --- /dev/null +++ b/include/tvm/relay/qnn/transform.h @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relay/qnn/transform.h + * + * This file implements a pass manager for QNN ops using Relay Pass manager. + */ +#ifndef TVM_RELAY_QNN_TRANSFORM_H_ +#define TVM_RELAY_QNN_TRANSFORM_H_ + +#include +#include + +namespace tvm { +namespace relay { + +using relay::transform::Pass; + +namespace qnn { +namespace transform { + +/*! + * \brief Legalizes a QNN expr. Contains specifically two types of Legalizations. First, + * converts/Lowers an expression containing QNN ops to an expression containing only core Relay ops. + * Each QNN op is lowered to a sequence of exisiting Relay ops. This is a target-independent pass. + * One can register the lowering/transformation function for this op using FTVMQnnCanonicalize + * attr_name for FTVMLegalize op attribute. Second, as opposed to Relay Legalize, this one legalizes + * only QNN ops. One can register a transformation/legalization function for an op by using the + * FTVMQnnLegalize attr_name for FTVMLegalize op attribute. The isolation of QNN and Relay Legalize + * gives us separation of concerns, leading to a better software practice. The legalization can be + * configured to happen per target. + * + * \return The pass. + */ +TVM_DLL Pass Legalize(); + +} // namespace transform + +} // namespace qnn +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_QNN_TRANSFORM_H_ diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index ef3ab723a28b..20e760f9956b 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include "utils.h" @@ -286,6 +287,15 @@ class RelayBuildModule : public runtime::ModuleNode { const TargetsMap& targets, const std::unordered_map& params) { Array pass_seqs; + + // Run all dialect legalization passes. + pass_seqs.push_back(relay::qnn::transform::Legalize()); + + // Legalize pass is restricted to homogeneous execution for now. + if (targets.size() == 1) { + pass_seqs.push_back(transform::Legalize()); + } + pass_seqs.push_back(transform::SimplifyInference()); PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) { Expr expr = args[0]; @@ -309,11 +319,6 @@ class RelayBuildModule : public runtime::ModuleNode { pass_seqs.push_back(transform::CanonicalizeCast()); pass_seqs.push_back(transform::CanonicalizeOps()); - // Legalize pass is restricted to homogeneous execution for now. - if (targets.size() == 1) { - pass_seqs.push_back(transform::Legalize()); - } - // Alter layout transformation is only applied to homogeneous execution yet. if (targets.size() == 1) { pass_seqs.push_back(transform::AlterOpLayout()); diff --git a/src/relay/ir/op.cc b/src/relay/ir/op.cc index 76b56ae954ad..d098863208fc 100644 --- a/src/relay/ir/op.cc +++ b/src/relay/ir/op.cc @@ -84,6 +84,17 @@ const GenericOpMap& Op::GetGenericAttr(const std::string& key) { return *it->second.get(); } +// Check if a key is present in the registry. +const bool Op::HasGenericAttr(const std::string& key) { + OpManager* mgr = OpManager::Global(); + std::lock_guard lock(mgr->mutex); + auto it = mgr->attr.find(key); + if (it == mgr->attr.end()) { + return false; + } + return true; +} + void OpRegistry::UpdateAttr(const std::string& key, TVMRetValue value, int plevel) { diff --git a/src/relay/pass/legalize.cc b/src/relay/pass/legalize.cc index 07b1d81e042b..f57d9103412e 100644 --- a/src/relay/pass/legalize.cc +++ b/src/relay/pass/legalize.cc @@ -46,32 +46,40 @@ class Legalizer : public ExprMutator { Expr new_e = ExprMutator::VisitExpr_(call_node); Call new_call = Downcast(new_e); + // Check if the string is registered in the OpRegistry. + if (!Op::HasAttr(legalize_map_attr_name_)) { + return new_e; + } + // Collect the registered legalize function. auto fop_legalize = Op::GetAttr(legalize_map_attr_name_); - Op op = Downcast(call_node->op); - - if (fop_legalize.count(op)) { - // Collect the new_args. - tvm::Array call_args = new_call->args; - - // Collect input and output dtypes to pass on to Legalize API. - tvm::Array types; - for (auto arg : call_node->args) { - types.push_back(arg->checked_type()); - } - types.push_back(call_node->checked_type()); - - // Transform the op by calling the registered legalize function. - Expr legalized_value = fop_legalize[op](call_node->attrs, call_args, types); - - // Reassign new_e if the transformation succeeded. - if (legalized_value.defined()) { - // Check that the returned Expr from legalize is CallNode. - const CallNode* legalized_call_node = legalized_value.as(); - CHECK(legalized_call_node) - << "Can only replace the original operator with another call node"; - - new_e = legalized_value; + auto call_op = call_node->op; + if (call_op.as()) { + Op op = Downcast(call_node->op); + + if (fop_legalize.count(op)) { + // Collect the new_args. + tvm::Array call_args = new_call->args; + + // Collect input and output dtypes to pass on to Legalize API. + tvm::Array types; + for (auto arg : call_node->args) { + types.push_back(arg->checked_type()); + } + types.push_back(call_node->checked_type()); + + // Transform the op by calling the registered legalize function. + Expr legalized_value = fop_legalize[op](call_node->attrs, call_args, types); + + // Reassign new_e if the transformation succeeded. + if (legalized_value.defined()) { + // Check that the returned Expr from legalize is CallNode. + const CallNode* legalized_call_node = legalized_value.as(); + CHECK(legalized_call_node) + << "Can only replace the original operator with another call node"; + + new_e = legalized_value; + } } } @@ -95,7 +103,7 @@ Pass Legalize(const std::string& legalize_map_attr_name) { [=](Function f, Module m, PassContext pc) { return Downcast(relay::legalize::Legalize(f, legalize_map_attr_name)); }; - return CreateFunctionPass(pass_func, 3, "Legalize", {ir::StringImm::make("InferType")}); + return CreateFunctionPass(pass_func, 0, "Legalize", {ir::StringImm::make("InferType")}); } TVM_REGISTER_API("relay._transform.Legalize").set_body_typed(Legalize); diff --git a/src/relay/qnn/pass/legalize.cc b/src/relay/qnn/pass/legalize.cc new file mode 100644 index 000000000000..07864add4827 --- /dev/null +++ b/src/relay/qnn/pass/legalize.cc @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relay/qnn/pass/legalize.cc + * \brief The Legalize wrapper for QNN. + */ + +#include + +namespace tvm { +namespace relay { +namespace qnn { + +namespace transform { + +Pass Legalize() { + Array pass_seqs; + pass_seqs.push_back(relay::transform::Legalize("FTVMQnnLegalize")); + pass_seqs.push_back(relay::transform::Legalize("FTVMQnnCanonicalize")); + relay::transform::Pass seq = relay::transform::Sequential(pass_seqs); + return seq; +} + +TVM_REGISTER_API("relay.qnn._transform.Legalize").set_body_typed(Legalize); + +} // namespace transform + +} // namespace qnn +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_op_qnn_conv2d.py b/tests/python/relay/test_op_qnn_conv2d.py index dd4ad8d491fa..c8e479d99ee4 100644 --- a/tests/python/relay/test_op_qnn_conv2d.py +++ b/tests/python/relay/test_op_qnn_conv2d.py @@ -77,7 +77,6 @@ def get_qnn_func(data, mod = relay.Function(relay.analysis.free_vars(func), func) mod = relay.Module.from_expr(mod) - mod = relay.qnn.transform.CanonicalizeOps()(mod) return mod def get_funcs(data_shape, diff --git a/tests/python/relay/test_op_qnn_dequantize.py b/tests/python/relay/test_op_qnn_dequantize.py index 76b61ae48ab8..51258651ab36 100644 --- a/tests/python/relay/test_op_qnn_dequantize.py +++ b/tests/python/relay/test_op_qnn_dequantize.py @@ -31,7 +31,6 @@ def quantize_test_driver(in_dtype, quant_args, in_data, verify_output_data): input_zero_point=input_zero_point) mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output) mod = relay.Module.from_expr(mod) - mod = relay.qnn.transform.CanonicalizeOps()(mod) with relay.build_config(opt_level=3): graph, lib, params = relay.build(mod, "llvm", params=None) rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) diff --git a/tests/python/relay/test_op_qnn_quantize.py b/tests/python/relay/test_op_qnn_quantize.py index 15319a734894..9805db56ced7 100644 --- a/tests/python/relay/test_op_qnn_quantize.py +++ b/tests/python/relay/test_op_qnn_quantize.py @@ -31,7 +31,6 @@ def quantize_test_driver(in_dtype, quant_args, out_dtype, in_data, verify_output output_zero_point=output_zero_point,out_dtype=out_dtype) mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output) mod = relay.Module.from_expr(mod) - mod = relay.qnn.transform.CanonicalizeOps()(mod) with relay.build_config(opt_level=3): graph, lib, params = relay.build(mod, "llvm", params=None) rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) diff --git a/tests/python/relay/test_op_qnn_requantize.py b/tests/python/relay/test_op_qnn_requantize.py index 131500c094bf..18e2f308969b 100644 --- a/tests/python/relay/test_op_qnn_requantize.py +++ b/tests/python/relay/test_op_qnn_requantize.py @@ -49,7 +49,6 @@ def get_mod(data_shape, data_dtype, out_dtype, input_scale, output_scale, mod = relay.Function(relay.analysis.free_vars(mod), mod) mod = relay.Module.from_expr(mod) - mod = relay.qnn.transform.CanonicalizeOps()(mod) return mod def same_scale_test():