Skip to content

Commit

Permalink
[QNN][Relay] Calling Dialect passes from inside Relay Build API.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu authored and anijain2305 committed Sep 29, 2019
1 parent 4ba911a commit fee4f73
Show file tree
Hide file tree
Showing 10 changed files with 178 additions and 34 deletions.
17 changes: 17 additions & 0 deletions include/tvm/relay/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,13 @@ class Op : public relay::Expr {
*/
template <typename ValueType>
inline static OpMap<ValueType> GetAttr(const std::string& attr_name);
/*!
* \brief Checks if an attr is present in the registry.
* \param attr_name The name of the attribute.
* \return An OpMap of specified attr_name.
* \tparam bool True if the attr is present.
*/
inline static bool HasAttr(const std::string& attr_name);
/*!
* \brief Get an Op for a given operator name.
* Will raise an error if the op has not been registered.
Expand All @@ -171,6 +178,12 @@ class Op : public relay::Expr {
* \return reference to GenericOpMap
*/
TVM_DLL static const GenericOpMap& GetGenericAttr(const std::string& key);
/*!
* \brief Checks if the key is present in the registryg
* \param key The attribute key
* \return bool True if the key is present
*/
TVM_DLL static const bool HasGenericAttr(const std::string& key);
};

/*! \brief Helper structure to register operators */
Expand Down Expand Up @@ -393,6 +406,10 @@ inline OpMap<ValueType> Op::GetAttr(const std::string& key) {
return OpMap<ValueType>(Op::GetGenericAttr(key));
}

inline bool Op::HasAttr(const std::string& key) {
return Op::HasGenericAttr(key);
}

inline OpNode* OpRegistry::get() {
return const_cast<OpNode*>(op_.operator->());
}
Expand Down
60 changes: 60 additions & 0 deletions include/tvm/relay/qnn/transform.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relay/qnn/transform.h
*
* This file implements a pass manager for QNN ops using Relay Pass manager.
*/
#ifndef TVM_RELAY_QNN_TRANSFORM_H_
#define TVM_RELAY_QNN_TRANSFORM_H_

#include <tvm/runtime/c_runtime_api.h>
#include <tvm/relay/transform.h>

namespace tvm {
namespace relay {

using relay::transform::Pass;

namespace qnn {
namespace transform {

/*!
* \brief Legalizes a QNN expr. Contains specifically two types of Legalizations. First,
* converts/Lowers an expression containing QNN ops to an expression containing only core Relay ops.
* Each QNN op is lowered to a sequence of exisiting Relay ops. This is a target-independent pass.
* One can register the lowering/transformation function for this op using FTVMQnnCanonicalize
* attr_name for FTVMLegalize op attribute. Second, as opposed to Relay Legalize, this one legalizes
* only QNN ops. One can register a transformation/legalization function for an op by using the
* FTVMQnnLegalize attr_name for FTVMLegalize op attribute. The isolation of QNN and Relay Legalize
* gives us separation of concerns, leading to a better software practice. The legalization can be
* configured to happen per target.
*
* \return The pass.
*/
TVM_DLL Pass Legalize();

} // namespace transform

} // namespace qnn
} // namespace relay
} // namespace tvm

#endif // TVM_RELAY_QNN_TRANSFORM_H_
15 changes: 10 additions & 5 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/runtime/vm.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/qnn/transform.h>
#include <memory>

#include "utils.h"
Expand Down Expand Up @@ -282,6 +283,15 @@ class RelayBuildModule : public runtime::ModuleNode {
const TargetsMap& targets,
const std::unordered_map<std::string, runtime::NDArray>& params) {
Array<Pass> pass_seqs;

// Run all dialect legalization passes.
pass_seqs.push_back(relay::qnn::transform::Legalize());

// Legalize pass is restricted to homogeneous execution for now.
if (targets.size() == 1) {
pass_seqs.push_back(transform::Legalize());
}

pass_seqs.push_back(transform::SimplifyInference());
PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
Expr expr = args[0];
Expand All @@ -304,11 +314,6 @@ class RelayBuildModule : public runtime::ModuleNode {
pass_seqs.push_back(transform::CanonicalizeCast());
pass_seqs.push_back(transform::CanonicalizeOps());

// Legalize pass is restricted to homogeneous execution for now.
if (targets.size() == 1) {
pass_seqs.push_back(transform::Legalize());
}

// Alter layout transformation is only applied to homogeneous execution yet.
if (targets.size() == 1) {
pass_seqs.push_back(transform::AlterOpLayout());
Expand Down
11 changes: 11 additions & 0 deletions src/relay/ir/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,17 @@ const GenericOpMap& Op::GetGenericAttr(const std::string& key) {
return *it->second.get();
}

// Check if a key is present in the registry.
const bool Op::HasGenericAttr(const std::string& key) {
OpManager* mgr = OpManager::Global();
std::lock_guard<std::mutex> lock(mgr->mutex);
auto it = mgr->attr.find(key);
if (it == mgr->attr.end()) {
return false;
}
return true;
}

void OpRegistry::UpdateAttr(const std::string& key,
TVMRetValue value,
int plevel) {
Expand Down
58 changes: 33 additions & 25 deletions src/relay/pass/legalize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,32 +46,40 @@ class Legalizer : public ExprMutator {
Expr new_e = ExprMutator::VisitExpr_(call_node);
Call new_call = Downcast<Call>(new_e);

// Check if the string is registered in the OpRegistry.
if (!Op::HasAttr(legalize_map_attr_name_)) {
return new_e;
}

// Collect the registered legalize function.
auto fop_legalize = Op::GetAttr<FTVMLegalize>(legalize_map_attr_name_);
Op op = Downcast<Op>(call_node->op);

if (fop_legalize.count(op)) {
// Collect the new_args.
tvm::Array<Expr> call_args = new_call->args;

// Collect input and output dtypes to pass on to Legalize API.
tvm::Array<tvm::relay::Type> types;
for (auto arg : call_node->args) {
types.push_back(arg->checked_type());
}
types.push_back(call_node->checked_type());

// Transform the op by calling the registered legalize function.
Expr legalized_value = fop_legalize[op](call_node->attrs, call_args, types);

// Reassign new_e if the transformation succeeded.
if (legalized_value.defined()) {
// Check that the returned Expr from legalize is CallNode.
const CallNode* legalized_call_node = legalized_value.as<CallNode>();
CHECK(legalized_call_node)
<< "Can only replace the original operator with another call node";

new_e = legalized_value;
auto call_op = call_node->op;
if (call_op.as<OpNode>()) {
Op op = Downcast<Op>(call_node->op);

if (fop_legalize.count(op)) {
// Collect the new_args.
tvm::Array<Expr> call_args = new_call->args;

// Collect input and output dtypes to pass on to Legalize API.
tvm::Array<tvm::relay::Type> types;
for (auto arg : call_node->args) {
types.push_back(arg->checked_type());
}
types.push_back(call_node->checked_type());

// Transform the op by calling the registered legalize function.
Expr legalized_value = fop_legalize[op](call_node->attrs, call_args, types);

// Reassign new_e if the transformation succeeded.
if (legalized_value.defined()) {
// Check that the returned Expr from legalize is CallNode.
const CallNode* legalized_call_node = legalized_value.as<CallNode>();
CHECK(legalized_call_node)
<< "Can only replace the original operator with another call node";

new_e = legalized_value;
}
}
}

Expand All @@ -95,7 +103,7 @@ Pass Legalize(const std::string& legalize_map_attr_name) {
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(relay::legalize::Legalize(f, legalize_map_attr_name));
};
return CreateFunctionPass(pass_func, 3, "Legalize", {ir::StringImm::make("InferType")});
return CreateFunctionPass(pass_func, 0, "Legalize", {ir::StringImm::make("InferType")});
}

TVM_REGISTER_API("relay._transform.Legalize").set_body_typed(Legalize);
Expand Down
47 changes: 47 additions & 0 deletions src/relay/qnn/pass/legalize.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file relay/qnn/pass/legalize.cc
* \brief The Legalize wrapper for QNN.
*/

#include <tvm/relay/qnn/transform.h>

namespace tvm {
namespace relay {
namespace qnn {

namespace transform {

Pass Legalize() {
Array<Pass> pass_seqs;
pass_seqs.push_back(relay::transform::Legalize("FTVMQnnLegalize"));
pass_seqs.push_back(relay::transform::Legalize("FTVMQnnCanonicalize"));
relay::transform::Pass seq = relay::transform::Sequential(pass_seqs);
return seq;
}

TVM_REGISTER_API("relay.qnn._transform.Legalize").set_body_typed(Legalize);

} // namespace transform

} // namespace qnn
} // namespace relay
} // namespace tvm
1 change: 0 additions & 1 deletion tests/python/relay/test_op_qnn_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def get_qnn_func(data,

mod = relay.Function(relay.analysis.free_vars(func), func)
mod = relay.Module.from_expr(mod)
mod = relay.qnn.transform.CanonicalizeOps()(mod)
return mod

def get_funcs(data_shape,
Expand Down
1 change: 0 additions & 1 deletion tests/python/relay/test_op_qnn_dequantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def quantize_test_driver(in_dtype, quant_args, in_data, verify_output_data):
input_zero_point=input_zero_point)
mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
mod = relay.Module.from_expr(mod)
mod = relay.qnn.transform.CanonicalizeOps()(mod)
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(mod, "llvm", params=None)
rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
Expand Down
1 change: 0 additions & 1 deletion tests/python/relay/test_op_qnn_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def quantize_test_driver(in_dtype, quant_args, out_dtype, in_data, verify_output
output_zero_point=output_zero_point,out_dtype=out_dtype)
mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
mod = relay.Module.from_expr(mod)
mod = relay.qnn.transform.CanonicalizeOps()(mod)
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(mod, "llvm", params=None)
rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
Expand Down
1 change: 0 additions & 1 deletion tests/python/relay/test_op_qnn_requantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ def get_mod(data_shape, data_dtype, out_dtype, input_scale, output_scale,

mod = relay.Function(relay.analysis.free_vars(mod), mod)
mod = relay.Module.from_expr(mod)
mod = relay.qnn.transform.CanonicalizeOps()(mod)
return mod

def same_scale_test():
Expand Down

0 comments on commit fee4f73

Please sign in to comment.