[Legalize][QNN] Pass out_types to Legalize. Update QNN requantize to read from out_types.
anijain2305 committed Aug 17, 2019
1 parent d3eb9cb commit 506d83d
Showing 8 changed files with 91 additions and 50 deletions.
21 changes: 17 additions & 4 deletions python/tvm/relay/op/nn/_nn.py
@@ -205,10 +205,23 @@ def alter_op_layout_conv2d(attrs, inputs, tinfos):
     return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op)
 
 @reg.register_legalize("nn.conv2d")
-def legalize_conv2d(attrs, inputs, arg_dtypes):
-    """Legalize conv2d"""
-    from ... import op
-    return topi.nn.conv2d_legalize(attrs, inputs, arg_dtypes, op)
+def legalize_conv2d(attrs, inputs, types):
+    """Legalize conv2d op.
+    Parameters
+    ----------
+    attrs : tvm.attrs.Attrs
+        Attributes of current convolution
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    types : list of types
+        List of input and output types
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr
+    """
+    return topi.nn.conv2d_legalize(attrs, inputs, types)
 
 reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)

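With this change a legalize hook receives the inferred output type alongside the argument types. A minimal sketch of a custom hook under the new convention — the level and body are illustrative, and the import follows the updated tests below:

from tvm.relay.op import register_legalize

@register_legalize("nn.conv2d", level=102)
def _sketch_conv2d_legalize(attrs, inputs, types):
    # One type per call argument, with the inferred output type appended last.
    data_type, kernel_type, out_type = types
    assert len(types) == len(inputs) + 1
    # Hooks may now branch on out_type.dtype; returning None keeps the op unchanged.
    return None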
12 changes: 9 additions & 3 deletions src/relay/pass/legalize.cc
@@ -42,11 +42,17 @@ Expr Legalizer(const Call& ref_call, const Array<Expr>& new_args, const NodeRef&
   Expr new_e;
   bool modified = false;
   if (fop_legalize.count(op)) {
-    tvm::Array<tvm::relay::Type> arg_types;
+    // Collect input and output dtypes to pass on to Legalize API.
+    tvm::Array<tvm::relay::Type> types;
     for (auto& expr : ref_call->args) {
-      arg_types.push_back(expr->checked_type());
+      types.push_back(expr->checked_type());
     }
-    Expr legalized_value = fop_legalize[op](ref_call->attrs, new_args, arg_types);
+    types.push_back(ref_call->checked_type());
+
+    // Transform the op by calling the registered legalize function.
+    Expr legalized_value = fop_legalize[op](ref_call->attrs, new_args, types);
+
+    // Check if the transformation succeeded. If not, revert back to the original ref_call->op.
     if (legalized_value.defined()) {
       new_e = legalized_value;
       modified = true;
4 changes: 2 additions & 2 deletions src/relay/qnn/op/dequantize.cc
@@ -74,12 +74,12 @@ Expr DequantizeLower(const Expr& input_tensor,
 
 Expr DequantizeLegalize(const Attrs& attrs,
                         const Array<Expr>& new_args,
-                        const Array<tvm::relay::Type>& arg_types) {
+                        const Array<tvm::relay::Type>& types) {
   CHECK_EQ(new_args.size(), 1);
   auto& data = new_args[0];
   const auto* dequantize_attrs = attrs.as<DequantizeAttrs>();
   CHECK(dequantize_attrs != nullptr);
-  CHECK_EQ(arg_types.size(), 1);
+  CHECK_EQ(types.size(), 2);
   return DequantizeLower(data, dequantize_attrs);
 }

4 changes: 2 additions & 2 deletions src/relay/qnn/op/quantize.cc
@@ -85,13 +85,13 @@ Expr QuantizeLower(const Expr& input_tensor,
 
 Expr QuantizeLegalize(const Attrs& attrs,
                       const Array<Expr>& new_args,
-                      const Array<tvm::relay::Type>& arg_types) {
+                      const Array<tvm::relay::Type>& types) {
   CHECK_EQ(new_args.size(), 1);
   auto& data = new_args[0];
   const auto* quantize_attrs = attrs.as<QuantizeAttrs>();
   CHECK(quantize_attrs != nullptr);
 
-  CHECK_EQ(arg_types.size(), 1);
+  CHECK_EQ(types.size(), 2);
   return QuantizeLower(data, quantize_attrs);
 }

33 changes: 20 additions & 13 deletions src/relay/qnn/op/requantize.cc
@@ -109,7 +109,7 @@ std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplie
  * 7) Cast to the out_dtype.
  */
 Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
-                     const Array<IndexExpr>& input_shape) {
+                     const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
   double double_multiplier = param->input_scale / param->output_scale;
 
   // Choose high precision datatype to be int64. This is for avoiding overflow
@@ -173,10 +173,10 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
   auto shifted_int64_t = Add(output_zp, scaled_int64_t);
 
   // 7) Clip to the out_dtype min/max.
-  auto q_min = GetQmin(param->out_dtype);
-  auto q_max = GetQmax(param->out_dtype);
+  auto q_min = GetQmin(out_dtype);
+  auto q_max = GetQmax(out_dtype);
   auto clipped_t = Clip(shifted_int64_t, q_min, q_max);
-  return Cast(clipped_t, param->out_dtype);
+  return Cast(clipped_t, out_dtype);
 }

/*
@@ -193,25 +193,32 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
  * Q_output = zp_output + (scale_input)/(scale_ouptut) * (Q_input - zp_input)
  */
 Expr RequantizeLegalize(const Attrs& attrs, const Array<Expr>& new_args,
-                        const Array<tvm::relay::Type>& arg_types) {
+                        const Array<tvm::relay::Type>& types) {
   CHECK_EQ(new_args.size(), 1);
   auto& quantized_data = new_args[0];
   const auto* param = attrs.as<RequantizeAttrs>();
   CHECK(param != nullptr);
 
   // Find input shape.
-  CHECK_EQ(arg_types.size(), 1);
-  auto input_dtype = arg_types[0];
-  auto input_tensor_type = input_dtype.as<TensorTypeNode>();
-  CHECK(input_tensor_type != nullptr) << "Type information missing."
-                                      << " Please run infer_type pass.";
-  Array<IndexExpr> input_shape = input_tensor_type->shape;
+  CHECK_EQ(types.size(), 2);
+  auto in_type = types[0];
+  auto in_tensor_type = in_type.as<TensorTypeNode>();
+  CHECK(in_tensor_type != nullptr) << "Type information missing."
+                                   << " Please run infer_type pass.";
+  Array<IndexExpr> input_shape = in_tensor_type->shape;
+
+  // Find the output dtype.
+  auto out_type = types[1];
+  auto out_tensor_type = out_type.as<TensorTypeNode>();
+  CHECK(out_tensor_type != nullptr) << "Type information missing."
+                                    << " Please run infer_type pass.";
+  auto out_dtype = out_tensor_type->dtype;
 
   // Check rounding validity.
   CHECK(param->rounding == "UPWARD" || param->rounding == "TONEAREST")
       << "QNN requantize supports two rounding modes - UPWARD and "
       << "TONEAREST";
-  return RequantizeLower(quantized_data, param, input_shape);
+  return RequantizeLower(quantized_data, param, input_shape, out_dtype);
 }

/*
@@ -261,7 +268,7 @@ The requantize operator converts one quantized tensor to another quantized
 tensor. For the output tensor, we are provided with output scale and zero
 point. The computation looks like this
-Q_output = zp_output + (scale_input)/(scale_ouptut) * (Q_input - zp_input)
+Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input)
 )code" TVM_ADD_FILELINE)
 .set_attrs_type_key("relay.attrs.RequantizeAttrs")
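The clip bounds now come from the type checker rather than from out_dtype in the attrs. A float-arithmetic reference of the final rescale/shift/clip/cast steps, as a hedged sketch (the real lowering works in int64 fixed point with the UPWARD or TONEAREST rounding modes):

import numpy as np

def requantize_ref(q_input, input_scale, input_zp, output_scale, output_zp, out_dtype):
    # Q_output = zp_output + (scale_input / scale_output) * (Q_input - zp_input)
    multiplier = input_scale / output_scale
    scaled = np.rint((q_input.astype(np.int64) - input_zp) * multiplier)
    shifted = scaled + output_zp
    # qmin/qmax are derived from the inferred output dtype, as in RequantizeLower.
    info = np.iinfo(np.dtype(out_dtype))
    return np.clip(shifted, info.min, info.max).astype(out_dtype)

# requantize_ref(np.array([10, 120], dtype=np.int8), 0.5, 0, 0.25, 0, "int8")
# -> array([ 20, 127], dtype=int8)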
15 changes: 8 additions & 7 deletions tests/python/relay/test_pass_legalize.py
@@ -47,7 +47,7 @@ def before():
         return y
 
     @register_legalize("nn.conv2d", level=100)
-    def legalize_conv2d(attrs, inputs, arg_types):
+    def legalize_conv2d(attrs, inputs, types):
         data, weight = inputs
         weight = relay.multiply(weight, relay.const(2.0, "float32"))
         return relay.nn.conv2d(data, weight, **attrs)
@@ -80,7 +80,7 @@ def before():
     called = [False]
 
     @register_legalize("nn.global_max_pool2d", level=101)
-    def legalize_conv2d(attrs, inputs, arg_types):
+    def legalize_conv2d(attrs, inputs, types):
         called[0] = True
         return None

@@ -103,12 +103,13 @@ def before():
         return func
 
     @register_legalize("concatenate", level=100)
-    def legalize_concatenate(attrs, inputs, arg_types):
+    def legalize_concatenate(attrs, inputs, types):
         # Check that the correct multi-input case is handled.
         assert len(inputs) == 1
         assert isinstance(inputs[0], tvm.relay.expr.Tuple)
-        assert len(arg_types) == 1
-        assert isinstance(arg_types[0], tvm.relay.ty.TupleType)
+        assert len(types) == 2
+        assert isinstance(types[0], tvm.relay.ty.TupleType)
+        assert isinstance(types[1], tvm.relay.ty.TensorType)
         return None
 
     def expected():
@@ -153,9 +154,9 @@ def before():
         return func
 
     @register_legalize("nn.conv2d", level=101)
-    def legalize_conv2d(attrs, inputs, arg_types):
+    def legalize_conv2d(attrs, inputs, types):
         from topi.arm_cpu.conv2d import _conv2d_legalize
-        return _conv2d_legalize(attrs, inputs, arg_types, tvm.relay.op)
+        return _conv2d_legalize(attrs, inputs, types)
 
     a = before()
     b = run_opt_pass(a, transform.Legalize())
31 changes: 23 additions & 8 deletions topi/python/topi/arm_cpu/conv2d.py
@@ -22,6 +22,7 @@
 
 import tvm
 from tvm import autotvm
+from tvm import relay
 import tvm.contrib.nnpack
 
 from ..generic import schedule_conv2d_nchw, schedule_conv2d_winograd_without_weight_transform, \
@@ -786,17 +787,31 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
     return None
 
 @conv2d_legalize.register("arm_cpu")
-def _conv2d_legalize(attrs, inputs, arg_types, F):
-    if F.__name__ != 'tvm.relay.op':
-        return None
+def _conv2d_legalize(attrs, inputs, arg_types):
+    """Legalizes Conv2D op.
+    Parameters
+    ----------
+    attrs : tvm.attrs.Attrs
+        Attributes of current convolution
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    types : list of types
+        List of input and output types
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr
+    """
+
     if attrs['data_layout'] == 'NHWC':
         data, kernel = inputs
         if attrs['kernel_layout'] == 'HWIO':
             # Handle HWIO layout. This is common in TF graph.
-            kernel = F.transpose(kernel, axes=(3, 2, 0, 1))
+            kernel = relay.transpose(kernel, axes=(3, 2, 0, 1))
         elif attrs['kernel_layout'] == 'HWOI':
             # Handle HWOI layout. This is common in TF depthwise conv2d graph.
-            kernel = F.transpose(kernel, axes=(2, 3, 0, 1))
+            kernel = relay.transpose(kernel, axes=(2, 3, 0, 1))
         elif attrs['kernel_layout'] != 'OIHW':
             return None
 
@@ -808,9 +823,9 @@ def _conv2d_legalize(attrs, inputs, arg_types, F):
         new_attrs['kernel_layout'] = 'OIHW'
 
         # Convert from NHWC to NCHW.
-        data = F.transpose(data, axes=(0, 3, 1, 2))
-        conv = F.nn.conv2d(data, kernel, **new_attrs)
+        data = relay.transpose(data, axes=(0, 3, 1, 2))
+        conv = relay.nn.conv2d(data, kernel, **new_attrs)
         # Convert back to original NHWC layout.
-        out = F.transpose(conv, axes=(0, 2, 3, 1))
+        out = relay.transpose(conv, axes=(0, 2, 3, 1))
         return out
     return None
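A hedged usage sketch of the hook above: an NHWC/HWIO convolution legalized under an arm_cpu target, with transposes inserted around an NCHW/OIHW conv2d. Shapes and the target string are illustrative, and the module/pass API is assumed to match this era of TVM:

import tvm
from tvm import relay
from tvm.relay import transform

data = relay.var("data", shape=(1, 56, 56, 64), dtype="float32")
weight = relay.var("weight", shape=(3, 3, 64, 64), dtype="float32")
y = relay.nn.conv2d(data, weight, channels=64, kernel_size=(3, 3),
                    padding=(1, 1), data_layout="NHWC", kernel_layout="HWIO")
mod = relay.Module.from_expr(relay.Function([data, weight], y))

# conv2d_legalize is a generic_func, so dispatch follows the current target.
with tvm.target.create("llvm -device=arm_cpu"):
    mod = transform.Legalize()(mod)
print(mod)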
21 changes: 10 additions & 11 deletions topi/python/topi/nn/conv2d.py
@@ -72,22 +72,21 @@ def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=N
 
 
 @tvm.target.generic_func
-def conv2d_legalize(attrs, inputs, arg_dtypes, F):
+def conv2d_legalize(attrs, inputs, types):
     """Legalizes Conv2D op.
     Parameters
     ----------
-    attrs : nnvm.top.AttrDict or tvm.attrs.Attrs
+    attrs : tvm.attrs.Attrs
         Attributes of current convolution
     inputs : list of tvm.relay.Expr
-        The args of the Relay expr to be legalized.
-    arg_dtypes : list of types
-        List of types of input arguments
-    F: symbol
-        The context, can be either nnvm.sym or relay.op
-    Note
-    ----
-    Unlike other TOPI functions, this function operates on both graph level and operator level,
-    so we have to pass 'F' to make it support our two versions of graph IR, NNVM and Relay.
+        The args of the Relay expr to be legalized
+    types : list of types
+        List of input and output types
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr
     """
     # not to change by default
     return None
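Backends override this generic per target key, as the arm_cpu hook above does. A hedged sketch of the registration pattern, with a hypothetical rule keyed on the new output type:

from topi.nn.conv2d import conv2d_legalize

@conv2d_legalize.register("cpu")
def _conv2d_legalize_cpu(attrs, inputs, types):
    # Hypothetical: only consider rewriting when the inferred output is int32.
    out_type = types[-1]
    if out_type.dtype != "int32":
        return None
    # ... construct and return a replacement Relay expression here ...
    return None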
