From 506d83d8cd5161d581826669003c26ff3eda4519 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Thu, 15 Aug 2019 19:10:21 +0000 Subject: [PATCH] [Legalize][QNN] Pass out_types to Legalize. Update QNN requantize to read from out_types. --- python/tvm/relay/op/nn/_nn.py | 21 ++++++++++++--- src/relay/pass/legalize.cc | 12 ++++++--- src/relay/qnn/op/dequantize.cc | 4 +-- src/relay/qnn/op/quantize.cc | 4 +-- src/relay/qnn/op/requantize.cc | 33 ++++++++++++++---------- tests/python/relay/test_pass_legalize.py | 15 ++++++----- topi/python/topi/arm_cpu/conv2d.py | 31 ++++++++++++++++------ topi/python/topi/nn/conv2d.py | 21 +++++++-------- 8 files changed, 91 insertions(+), 50 deletions(-) diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index c896a002334b5..66e93250a009e 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -205,10 +205,23 @@ def alter_op_layout_conv2d(attrs, inputs, tinfos): return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op) @reg.register_legalize("nn.conv2d") -def legalize_conv2d(attrs, inputs, arg_dtypes): - """Legalize conv2d""" - from ... import op - return topi.nn.conv2d_legalize(attrs, inputs, arg_dtypes, op) +def legalize_conv2d(attrs, inputs, types): + """Legalize conv2d op. + Parameters + ---------- + attrs : tvm.attrs.Attrs + Attributes of current convolution + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + types : list of types + List of input and output types + + Returns + ------- + result : tvm.relay.Expr + The legalized expr + """ + return topi.nn.conv2d_legalize(attrs, inputs, types) reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE) diff --git a/src/relay/pass/legalize.cc b/src/relay/pass/legalize.cc index c041cb9c668c7..0079dabb88b3a 100644 --- a/src/relay/pass/legalize.cc +++ b/src/relay/pass/legalize.cc @@ -42,11 +42,17 @@ Expr Legalizer(const Call& ref_call, const Array& new_args, const NodeRef& Expr new_e; bool modified = false; if (fop_legalize.count(op)) { - tvm::Array arg_types; + // Collect input and output dtypes to pass on to Legalize API. + tvm::Array types; for (auto& expr : ref_call->args) { - arg_types.push_back(expr->checked_type()); + types.push_back(expr->checked_type()); } - Expr legalized_value = fop_legalize[op](ref_call->attrs, new_args, arg_types); + types.push_back(ref_call->checked_type()); + + // Transform the op by calling the registered legalize function. + Expr legalized_value = fop_legalize[op](ref_call->attrs, new_args, types); + + // Check if the transformation succeeded. If not, revert back to the original ref_call->op. if (legalized_value.defined()) { new_e = legalized_value; modified = true; diff --git a/src/relay/qnn/op/dequantize.cc b/src/relay/qnn/op/dequantize.cc index 1e59440149346..e42be2af9b788 100644 --- a/src/relay/qnn/op/dequantize.cc +++ b/src/relay/qnn/op/dequantize.cc @@ -74,12 +74,12 @@ Expr DequantizeLower(const Expr& input_tensor, Expr DequantizeLegalize(const Attrs& attrs, const Array& new_args, - const Array& arg_types) { + const Array& types) { CHECK_EQ(new_args.size(), 1); auto& data = new_args[0]; const auto* dequantize_attrs = attrs.as(); CHECK(dequantize_attrs != nullptr); - CHECK_EQ(arg_types.size(), 1); + CHECK_EQ(types.size(), 2); return DequantizeLower(data, dequantize_attrs); } diff --git a/src/relay/qnn/op/quantize.cc b/src/relay/qnn/op/quantize.cc index 2f494008cc46c..675cd4c5a7007 100644 --- a/src/relay/qnn/op/quantize.cc +++ b/src/relay/qnn/op/quantize.cc @@ -85,13 +85,13 @@ Expr QuantizeLower(const Expr& input_tensor, Expr QuantizeLegalize(const Attrs& attrs, const Array& new_args, - const Array& arg_types) { + const Array& types) { CHECK_EQ(new_args.size(), 1); auto& data = new_args[0]; const auto* quantize_attrs = attrs.as(); CHECK(quantize_attrs != nullptr); - CHECK_EQ(arg_types.size(), 1); + CHECK_EQ(types.size(), 2); return QuantizeLower(data, quantize_attrs); } diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index e3052b71c4caf..ebc537e1ac900 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -109,7 +109,7 @@ std::pair GetFixedPointMultiplierShift(double double_multiplie * 7) Cast to the out_dtype. */ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param, - const Array& input_shape) { + const Array& input_shape, const DataType& out_dtype) { double double_multiplier = param->input_scale / param->output_scale; // Choose high precision datatype to be int64. This is for avoiding overflow @@ -173,10 +173,10 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param, auto shifted_int64_t = Add(output_zp, scaled_int64_t); // 7) Clip to the out_dtype min/max. - auto q_min = GetQmin(param->out_dtype); - auto q_max = GetQmax(param->out_dtype); + auto q_min = GetQmin(out_dtype); + auto q_max = GetQmax(out_dtype); auto clipped_t = Clip(shifted_int64_t, q_min, q_max); - return Cast(clipped_t, param->out_dtype); + return Cast(clipped_t, out_dtype); } /* @@ -193,25 +193,32 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param, * Q_output = zp_output + (scale_input)/(scale_ouptut) * (Q_input - zp_input) */ Expr RequantizeLegalize(const Attrs& attrs, const Array& new_args, - const Array& arg_types) { + const Array& types) { CHECK_EQ(new_args.size(), 1); auto& quantized_data = new_args[0]; const auto* param = attrs.as(); CHECK(param != nullptr); // Find input shape. - CHECK_EQ(arg_types.size(), 1); - auto input_dtype = arg_types[0]; - auto input_tensor_type = input_dtype.as(); - CHECK(input_tensor_type != nullptr) << "Type information missing." - << " Please run infer_type pass."; - Array input_shape = input_tensor_type->shape; + CHECK_EQ(types.size(), 2); + auto in_type = types[0]; + auto in_tensor_type = in_type.as(); + CHECK(in_tensor_type != nullptr) << "Type information missing." + << " Please run infer_type pass."; + Array input_shape = in_tensor_type->shape; + + // Find the output dtype. + auto out_type = types[1]; + auto out_tensor_type = out_type.as(); + CHECK(out_tensor_type != nullptr) << "Type information missing." + << " Please run infer_type pass."; + auto out_dtype = out_tensor_type->dtype; // Check rounding validity. CHECK(param->rounding == "UPWARD" || param->rounding == "TONEAREST") << "QNN requantize supports two rounding modes - UPWARD and " << "TONEAREST"; - return RequantizeLower(quantized_data, param, input_shape); + return RequantizeLower(quantized_data, param, input_shape, out_dtype); } /* @@ -261,7 +268,7 @@ The requantize operator converts one quantized tensor to another quantized tensor. For the output tensor, we are provided with output scale and zero point. The computation looks like this -Q_output = zp_output + (scale_input)/(scale_ouptut) * (Q_input - zp_input) +Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input) )code" TVM_ADD_FILELINE) .set_attrs_type_key("relay.attrs.RequantizeAttrs") diff --git a/tests/python/relay/test_pass_legalize.py b/tests/python/relay/test_pass_legalize.py index 52deeb58ca351..393c86282be68 100644 --- a/tests/python/relay/test_pass_legalize.py +++ b/tests/python/relay/test_pass_legalize.py @@ -47,7 +47,7 @@ def before(): return y @register_legalize("nn.conv2d", level=100) - def legalize_conv2d(attrs, inputs, arg_types): + def legalize_conv2d(attrs, inputs, types): data, weight = inputs weight = relay.multiply(weight, relay.const(2.0, "float32")) return relay.nn.conv2d(data, weight, **attrs) @@ -80,7 +80,7 @@ def before(): called = [False] @register_legalize("nn.global_max_pool2d", level=101) - def legalize_conv2d(attrs, inputs, arg_types): + def legalize_conv2d(attrs, inputs, types): called[0] = True return None @@ -103,12 +103,13 @@ def before(): return func @register_legalize("concatenate", level=100) - def legalize_concatenate(attrs, inputs, arg_types): + def legalize_concatenate(attrs, inputs, types): # Check that the correct multi-input case is handled. assert len(inputs) == 1 assert isinstance(inputs[0], tvm.relay.expr.Tuple) - assert len(arg_types) == 1 - assert isinstance(arg_types[0], tvm.relay.ty.TupleType) + assert len(types) == 2 + assert isinstance(types[0], tvm.relay.ty.TupleType) + assert isinstance(types[1], tvm.relay.ty.TensorType) return None def expected(): @@ -153,9 +154,9 @@ def before(): return func @register_legalize("nn.conv2d", level=101) - def legalize_conv2d(attrs, inputs, arg_types): + def legalize_conv2d(attrs, inputs, types): from topi.arm_cpu.conv2d import _conv2d_legalize - return _conv2d_legalize(attrs, inputs, arg_types, tvm.relay.op) + return _conv2d_legalize(attrs, inputs, types) a = before() b = run_opt_pass(a, transform.Legalize()) diff --git a/topi/python/topi/arm_cpu/conv2d.py b/topi/python/topi/arm_cpu/conv2d.py index 95342b6896b76..31f5b4980760a 100644 --- a/topi/python/topi/arm_cpu/conv2d.py +++ b/topi/python/topi/arm_cpu/conv2d.py @@ -22,6 +22,7 @@ import tvm from tvm import autotvm +from tvm import relay import tvm.contrib.nnpack from ..generic import schedule_conv2d_nchw, schedule_conv2d_winograd_without_weight_transform, \ @@ -786,17 +787,31 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F): return None @conv2d_legalize.register("arm_cpu") -def _conv2d_legalize(attrs, inputs, arg_types, F): - if F.__name__ != 'tvm.relay.op': - return None +def _conv2d_legalize(attrs, inputs, arg_types): + """Legalizes Conv2D op. + Parameters + ---------- + attrs : tvm.attrs.Attrs + Attributes of current convolution + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + types : list of types + List of input and output types + + Returns + ------- + result : tvm.relay.Expr + The legalized expr + """ + if attrs['data_layout'] == 'NHWC': data, kernel = inputs if attrs['kernel_layout'] == 'HWIO': # Handle HWIO layout. This is common in TF graph. - kernel = F.transpose(kernel, axes=(3, 2, 0, 1)) + kernel = relay.transpose(kernel, axes=(3, 2, 0, 1)) elif attrs['kernel_layout'] == 'HWOI': # Handle HWOI layout. This is common in TF depthwise conv2d graph. - kernel = F.transpose(kernel, axes=(2, 3, 0, 1)) + kernel = relay.transpose(kernel, axes=(2, 3, 0, 1)) elif attrs['kernel_layout'] != 'OIHW': return None @@ -808,9 +823,9 @@ def _conv2d_legalize(attrs, inputs, arg_types, F): new_attrs['kernel_layout'] = 'OIHW' # Convert from NHWC to NCHW. - data = F.transpose(data, axes=(0, 3, 1, 2)) - conv = F.nn.conv2d(data, kernel, **new_attrs) + data = relay.transpose(data, axes=(0, 3, 1, 2)) + conv = relay.nn.conv2d(data, kernel, **new_attrs) # Convert back to original NHWC layout. - out = F.transpose(conv, axes=(0, 2, 3, 1)) + out = relay.transpose(conv, axes=(0, 2, 3, 1)) return out return None diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py index e7ab7ba99990d..bf1e2fd3dcebb 100644 --- a/topi/python/topi/nn/conv2d.py +++ b/topi/python/topi/nn/conv2d.py @@ -72,22 +72,21 @@ def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=N @tvm.target.generic_func -def conv2d_legalize(attrs, inputs, arg_dtypes, F): +def conv2d_legalize(attrs, inputs, types): """Legalizes Conv2D op. Parameters ---------- - attrs : nnvm.top.AttrDict or tvm.attrs.Attrs + attrs : tvm.attrs.Attrs Attributes of current convolution inputs : list of tvm.relay.Expr - The args of the Relay expr to be legalized. - arg_dtypes : list of types - List of types of input arguments - F: symbol - The context, can be either nnvm.sym or relay.op - Note - ---- - Unlike other TOPI functions, this function operates on both graph level and operator level, - so we have to pass 'F' to make it support our two versions of graph IR, NNVM and Relay. + The args of the Relay expr to be legalized + types : list of types + List of input and output types + + Returns + ------- + result : tvm.relay.Expr + The legalized expr """ # not to change by default return None