diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index c896a002334b5..a67f6dfc22a05 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -205,10 +205,24 @@ def alter_op_layout_conv2d(attrs, inputs, tinfos):
     return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op)
 
 @reg.register_legalize("nn.conv2d")
-def legalize_conv2d(attrs, inputs, arg_dtypes):
-    """Legalize conv2d"""
+def legalize_conv2d(attrs, inputs, types):
+    """Legalize conv2d op.
+
+    Parameters
+    ----------
+    attrs : tvm.attrs.Attrs
+        Attributes of current convolution
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized.
+    types : list of types
+        List of input and output types
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr.
+    """
     from ... import op
-    return topi.nn.conv2d_legalize(attrs, inputs, arg_dtypes, op)
+    return topi.nn.conv2d_legalize(attrs, inputs, types, op)
 
 
 reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
diff --git a/src/relay/pass/legalize.cc b/src/relay/pass/legalize.cc
index c041cb9c668c7..0079dabb88b3a 100644
--- a/src/relay/pass/legalize.cc
+++ b/src/relay/pass/legalize.cc
@@ -42,11 +42,17 @@ Expr Legalizer(const Call& ref_call, const Array<Expr>& new_args, const NodeRef&
   Expr new_e;
   bool modified = false;
   if (fop_legalize.count(op)) {
-    tvm::Array<tvm::relay::Type> arg_types;
+    // Collect input and output types to pass on to the Legalize API.
+    tvm::Array<tvm::relay::Type> types;
     for (auto& expr : ref_call->args) {
-      arg_types.push_back(expr->checked_type());
+      types.push_back(expr->checked_type());
     }
-    Expr legalized_value = fop_legalize[op](ref_call->attrs, new_args, arg_types);
+    types.push_back(ref_call->checked_type());
+
+    // Transform the op by calling the registered legalize function.
+    Expr legalized_value = fop_legalize[op](ref_call->attrs, new_args, types);
+
+    // Check if the transformation succeeded. If not, revert to the original op.
    if (legalized_value.defined()) {
      new_e = legalized_value;
      modified = true;
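For reference, a legalize hook registered from Python now receives the call's inferred output type as the last entry of `types` (input types first, output type appended). A minimal sketch of a consumer of the new contract; the op, level, and dtype test here are illustrative, not part of this patch:

```python
from tvm import relay
from tvm.relay.op import register_legalize

@register_legalize("nn.conv2d", level=111)
def _conv2d_legalize_sketch(attrs, inputs, types):
    # `types` holds one entry per call argument, plus the output type last.
    *in_types, out_type = types
    assert len(in_types) == len(inputs)
    data, weight = inputs
    if out_type.dtype == "int32":
        # Rewrite only integer convolutions; returning a new expr replaces the op.
        return relay.nn.conv2d(data, weight, **attrs)
    return None  # returning None keeps the original op unchanged
```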
diff --git a/src/relay/qnn/op/dequantize.cc b/src/relay/qnn/op/dequantize.cc
index 1e59440149346..e42be2af9b788 100644
--- a/src/relay/qnn/op/dequantize.cc
+++ b/src/relay/qnn/op/dequantize.cc
@@ -74,12 +74,12 @@ Expr DequantizeLower(const Expr& input_tensor,
 
 Expr DequantizeLegalize(const Attrs& attrs, const Array<Expr>& new_args,
-                        const Array<tvm::relay::Type>& arg_types) {
+                        const Array<tvm::relay::Type>& types) {
   CHECK_EQ(new_args.size(), 1);
   auto& data = new_args[0];
   const auto* dequantize_attrs = attrs.as<DequantizeAttrs>();
   CHECK(dequantize_attrs != nullptr);
-  CHECK_EQ(arg_types.size(), 1);
+  CHECK_EQ(types.size(), 2);
   return DequantizeLower(data, dequantize_attrs);
 }
 
diff --git a/src/relay/qnn/op/quantize.cc b/src/relay/qnn/op/quantize.cc
index 2f494008cc46c..675cd4c5a7007 100644
--- a/src/relay/qnn/op/quantize.cc
+++ b/src/relay/qnn/op/quantize.cc
@@ -85,13 +85,13 @@ Expr QuantizeLower(const Expr& input_tensor,
 
 Expr QuantizeLegalize(const Attrs& attrs, const Array<Expr>& new_args,
-                      const Array<tvm::relay::Type>& arg_types) {
+                      const Array<tvm::relay::Type>& types) {
   CHECK_EQ(new_args.size(), 1);
   auto& data = new_args[0];
   const auto* quantize_attrs = attrs.as<QuantizeAttrs>();
   CHECK(quantize_attrs != nullptr);
-  CHECK_EQ(arg_types.size(), 1);
+  CHECK_EQ(types.size(), 2);
   return QuantizeLower(data, quantize_attrs);
 }
 
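As a sanity check on why these checks move from 1 to 2: for a single-input op, `types` is now `[input_type, output_type]`. A quick illustration using a plain unary op as a stand-in for `qnn.quantize`/`qnn.dequantize` (shapes and dtypes here are made up for the example):

```python
from tvm import relay

x = relay.var("x", shape=(1, 4), dtype="int8")
func = relay.Function([x], relay.cast(x, "float32"))
mod = relay.Module.from_expr(func)
mod = relay.transform.InferType()(mod)

call = mod["main"].body
# Mirror what the C++ Legalizer gathers: one checked_type per argument,
# then the call's own checked_type appended at the end.
types = [arg.checked_type for arg in call.args] + [call.checked_type]
assert len(types) == 2  # [input_type, output_type]
print(types[0])         # TensorType([1, 4], int8)
print(types[1])         # TensorType([1, 4], float32)
```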
+ << " Please run infer_type pass."; + auto out_dtype = out_tensor_type->dtype; // Check rounding validity. CHECK(param->rounding == "UPWARD" || param->rounding == "TONEAREST") << "QNN requantize supports two rounding modes - UPWARD and " << "TONEAREST"; - return RequantizeLower(quantized_data, param, input_shape); + return RequantizeLower(quantized_data, param, input_shape, out_dtype); } /* @@ -261,7 +268,7 @@ The requantize operator converts one quantized tensor to another quantized tensor. For the output tensor, we are provided with output scale and zero point. The computation looks like this -Q_output = zp_output + (scale_input)/(scale_ouptut) * (Q_input - zp_input) +Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input) )code" TVM_ADD_FILELINE) .set_attrs_type_key("relay.attrs.RequantizeAttrs") diff --git a/tests/python/relay/test_pass_legalize.py b/tests/python/relay/test_pass_legalize.py index 52deeb58ca351..0dbb5e243e3a7 100644 --- a/tests/python/relay/test_pass_legalize.py +++ b/tests/python/relay/test_pass_legalize.py @@ -47,7 +47,7 @@ def before(): return y @register_legalize("nn.conv2d", level=100) - def legalize_conv2d(attrs, inputs, arg_types): + def legalize_conv2d(attrs, inputs, types): data, weight = inputs weight = relay.multiply(weight, relay.const(2.0, "float32")) return relay.nn.conv2d(data, weight, **attrs) @@ -80,7 +80,7 @@ def before(): called = [False] @register_legalize("nn.global_max_pool2d", level=101) - def legalize_conv2d(attrs, inputs, arg_types): + def legalize_conv2d(attrs, inputs, types): called[0] = True return None @@ -103,12 +103,13 @@ def before(): return func @register_legalize("concatenate", level=100) - def legalize_concatenate(attrs, inputs, arg_types): + def legalize_concatenate(attrs, inputs, types): # Check that the correct multi-input case is handled. assert len(inputs) == 1 assert isinstance(inputs[0], tvm.relay.expr.Tuple) - assert len(arg_types) == 1 - assert isinstance(arg_types[0], tvm.relay.ty.TupleType) + assert len(types) == 2 + assert isinstance(types[0], tvm.relay.ty.TupleType) + assert isinstance(types[1], tvm.relay.ty.TensorType) return None def expected(): @@ -153,9 +154,9 @@ def before(): return func @register_legalize("nn.conv2d", level=101) - def legalize_conv2d(attrs, inputs, arg_types): + def legalize_conv2d(attrs, inputs, types): from topi.arm_cpu.conv2d import _conv2d_legalize - return _conv2d_legalize(attrs, inputs, arg_types, tvm.relay.op) + return _conv2d_legalize(attrs, inputs, types, tvm.relay.op) a = before() b = run_opt_pass(a, transform.Legalize()) diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py index e7ab7ba99990d..4387552dea40f 100644 --- a/topi/python/topi/nn/conv2d.py +++ b/topi/python/topi/nn/conv2d.py @@ -72,7 +72,7 @@ def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=N @tvm.target.generic_func -def conv2d_legalize(attrs, inputs, arg_dtypes, F): +def conv2d_legalize(attrs, inputs, types, F): """Legalizes Conv2D op. Parameters ---------- @@ -80,8 +80,8 @@ def conv2d_legalize(attrs, inputs, arg_dtypes, F): Attributes of current convolution inputs : list of tvm.relay.Expr The args of the Relay expr to be legalized. - arg_dtypes : list of types - List of types of input arguments + types : list of types + List of input and output types F: symbol The context, can be either nnvm.sym or relay.op Note