diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index c9ce0ecf0dd8..2335c598fef9 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -57,7 +57,10 @@ bool QnnConv2DRel(const Array& types, int num_inputs, const Attrs& attrs, CHECK(IsScalarType(types[2], DataType::Int(32))); // input_zero_point CHECK(IsScalarType(types[3], DataType::Int(32))); // kernel_zero_point CHECK(IsScalarType(types[4], DataType::Float(32))); // input_scale - CHECK(IsScalarType(types[5], DataType::Float(32))); // kernel_scale + // Kernel scale can be a vector of length output_channels or a scalar. + size_t axis = param->kernel_layout.find('O'); + CHECK(axis != std::string::npos) << "Kernel layout attribute is not defined"; + AssignType(types[5], DataType::Float(32), weight->shape[axis], reporter); // kernel_scale // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay // Conv2D infer type function. diff --git a/src/relay/qnn/util.h b/src/relay/qnn/util.h index 2e332413c1f6..2316bed5f5aa 100644 --- a/src/relay/qnn/util.h +++ b/src/relay/qnn/util.h @@ -152,6 +152,8 @@ Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector multiplier, */ static inline bool IsScalarType(const Type& expr_type, const DataType& dtype) { const auto* tensor_type = expr_type.as(); + CHECK(tensor_type) << "Only tensor type can be checked for scalar values. But got " + << AsText(expr_type, false); CHECK_EQ(tensor_type->shape.size(), 0); CHECK(tensor_type->dtype == dtype) << "Expected " << dtype << " but got " << tensor_type->dtype; return true; @@ -168,6 +170,8 @@ static inline void AssignType(const Type& expr_type, const DataType& dtype, cons const TypeReporter& reporter) { // Scale/Zero_points can be either const scalar or a vector with C axis num elems. const auto* tensor_type = expr_type.as(); + CHECK(tensor_type) << "Can assign type to Tensor type only. 
But got " + << AsText(expr_type, false); const auto tensor_dtype = tensor_type->dtype; CHECK(tensor_dtype == dtype) << "Expected type is " << dtype << " but received " << tensor_dtype; if (tensor_type->shape.size() != 0) { diff --git a/tests/python/relay/test_op_qnn_conv2d.py b/tests/python/relay/test_op_qnn_conv2d.py index 9effa6f11ebe..9631ffc2faf3 100644 --- a/tests/python/relay/test_op_qnn_conv2d.py +++ b/tests/python/relay/test_op_qnn_conv2d.py @@ -768,8 +768,8 @@ def test_depthwise_depth_multiplier(): channels=4) verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype) - - + + # Depthwise multiplier = 2 data_shape = (10, 4, 16, 16) data_dtype = 'uint8' @@ -794,7 +794,7 @@ def test_depthwise_depth_multiplier(): channels=8) verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype) - + # uint8 input, NHWC and HWOI # Depthwise multiplier = 1 data_shape = (2, 16, 16, 4) @@ -820,7 +820,7 @@ def test_depthwise_depth_multiplier(): channels=4) verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype) - + # Depthwise multiplier = 2 data_shape = (2, 16, 16, 4) data_dtype = 'uint8' @@ -846,6 +846,35 @@ def test_depthwise_depth_multiplier(): verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype) +def test_per_channel_kernel_scale(): + with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d): + data_shape = (2, 1, 2, 4) + data_dtype = 'uint8' + kernel_shape = (3, 1, 2, 2) + kernel_dtype = 'uint8' + data = relay.var("data", shape=data_shape, + dtype=data_dtype) + kernel = relay.var("kernel", shape=kernel_shape, + dtype=kernel_dtype) + kernel_scales = [2, 2, 2] + kernel_scales = relay.const(np.array(kernel_scales).astype('float32')) + func = relay.qnn.op.conv2d( + data, kernel, + input_zero_point=relay.const(0, 'int32'), + kernel_zero_point=relay.const(0, 'int32'), + input_scale=relay.const(2.0, 'float32'), + kernel_scale=kernel_scales, + kernel_size=(2, 2), + padding=(0, 0), 
+ strides=(1, 1), + dilation=(1, 1), + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="int32") + + mod = relay.Function(relay.analysis.free_vars(func), func) + mod = relay.Module.from_expr(mod) + if __name__ == "__main__": test_no_zero_point() test_input_zero_point() @@ -861,3 +890,4 @@ def test_depthwise_depth_multiplier(): test_tflite_output_multiplier_greater_than_one() test_tflite_anistropic_strides() test_depthwise_depth_multiplier() + test_per_channel_kernel_scale()