diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index d889631a4cd8..e92e4cef205d 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -982,6 +982,7 @@ def convert_fully_connected(self, op): weight_value = self.get_tensor_value(weight_tensor) weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str) + weight_shape = _infer_shape(weight_expr) if input_tensor.qnn_params: out = _qnn.op.dense(in_expr, weight_expr, @@ -989,6 +990,7 @@ def convert_fully_connected(self, op): kernel_zero_point=weight_tensor.qnn_params['zero_point'], input_scale=input_tensor.qnn_params['scale'], kernel_scale=weight_tensor.qnn_params['scale'], + units=weight_shape[0], out_dtype='int32') else: out = _op.nn.dense(in_expr, weight_expr) diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py index f76d7b3df9bf..a0eef8d7bd37 100644 --- a/python/tvm/relay/qnn/op/qnn.py +++ b/python/tvm/relay/qnn/op/qnn.py @@ -345,7 +345,7 @@ def dense(data, kernel_zero_point, input_scale, kernel_scale, - units=None, + units, out_dtype="int32"): """Qnn Dense operator. Applies a quantized linear transformation @@ -371,7 +371,7 @@ def dense(data, stored for access to this during relay. This information is not needed in the pass pipeline after qnn.conv2d is lowered to the sequence of steps as in nn.conv2d. See also input_scale in Requantize. - units : int, optional + units : int Number of hidden units of the dense transformation. out_dtype : str, optional Specifies the output data type for mixed precision dense can be int32 or int16. diff --git a/src/relay/qnn/op/dense.cc b/src/relay/qnn/op/dense.cc index b7a12e1a64b3..de3c4dbc7dc1 100644 --- a/src/relay/qnn/op/dense.cc +++ b/src/relay/qnn/op/dense.cc @@ -55,7 +55,7 @@ bool QnnDenseRel(const Array& types, int num_inputs, const Attrs& attrs, CHECK(IsScalarType(types[2], DataType::Int(32))); // input_zero_point CHECK(IsScalarType(types[3], DataType::Int(32))); // kernel_zero_point CHECK(IsScalarType(types[4], DataType::Float(32))); // input_scale - CHECK(IsScalarType(types[5], DataType::Float(32))); // kernel_scale + AssignType(types[5], DataType::Float(32), param->units, reporter); CHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0."; diff --git a/tests/python/relay/test_op_qnn_dense.py b/tests/python/relay/test_op_qnn_dense.py index 0e7c284653f4..43600cbf60c5 100644 --- a/tests/python/relay/test_op_qnn_dense.py +++ b/tests/python/relay/test_op_qnn_dense.py @@ -75,52 +75,8 @@ def make_configuration(quantized_data, return config -def make_uint_configuration(use_bias=False, requantize_output=False): - input_shape, kernel_shape, output_shape = (2, 10), (3,10), (2, 3) - input_zero_point, kernel_zero_point = 127, 127 - input_scale = 0.5 - kernel_scale = 0.5 - output_scale = 1.0 - in_dtype = 'uint8' - out_dtype = 'int32' if not requantize_output else 'uint8' - units = 3 - quantized_data_np = np.array([129, 131, 133, 135, 137, 139, 141, 143, 109, 107, - 129, 131, 133, 135, 137, 139, 141, 111, 145, 107]) \ - .astype(in_dtype) \ - .reshape(input_shape) - quantized_kernel_np = np.array([129, 131, 133, 135, 137, 139, 141, 143, 145, 147, - 129, 131, 133, 135, 137, 139, 141, 143, 145, 147, - 129, 131, 133, 135, 137, 139, 141, 143, 145, 147]) \ - .astype(in_dtype) \ - .reshape(kernel_shape) - bias = np.array([4, 8, 12]).astype(out_dtype).reshape((units, )) if use_bias else None - requant_params = make_requantize_params(input_scale * kernel_scale, output_scale, 127, 'uint8') if requantize_output else None - - if requantize_output: - assert use_bias - output = np.array([151, 152, 153, 185, 186, 187]) - elif use_bias: - output = np.array([96, 100, 104, 232, 236, 240 ]) - else: - output = np.array([92, 92, 92, 228, 228, 228 ]) - output = output.astype(out_dtype).reshape(output_shape) - return make_configuration(quantized_data=quantized_data_np, - quantized_kernel=quantized_kernel_np, - dtype=in_dtype, - input_shape=input_shape, - kernel_shape=kernel_shape, - input_zero_point=input_zero_point, - kernel_zero_point=kernel_zero_point, - input_scale=input_scale, - kernel_scale= kernel_scale, - units=units, - output=output, - bias=bias, - requantize=requant_params) - - -def make_int_configuration(use_bias=False, requantize_output=False): - input_shape, kernel_shape, output_shape = (2, 10), (3,10), (2, 3) +def make_int_configuration(use_bias=False, requantize_output=False, per_channel=False): + input_shape, kernel_shape, output_shape = (2, 10), (3, 10), (2, 3) input_zero_point, kernel_zero_point = -1, -1 in_dtype = 'int8' out_dtype = 'int32' if not requantize_output else 'int8' @@ -138,15 +94,22 @@ def make_int_configuration(use_bias=False, requantize_output=False): kernel_scale = 0.5 output_scale = 1.0 bias = np.array([4, 8, 12]).astype(out_dtype).reshape((units, )) if use_bias else None - requant_params = make_requantize_params(input_scale * kernel_scale, output_scale, -1, 'int8') if requantize_output else None - if requantize_output: + if per_channel: + assert use_bias and requantize_output + kernel_scale = np.array([0.5, 0.3, 0.4], dtype=np.float32) + output = np.array([23, 14, 20, 57, 34, 47]) + elif requantize_output: assert use_bias output = np.array([23, 24, 25, 57, 58, 59]) elif use_bias: - output = np.array([96, 100, 104, 232, 236, 240 ]) + output = np.array([96, 100, 104, 232, 236, 240]) else: - output = np.array([92, 92, 92, 228, 228, 228 ]) + output = np.array([92, 92, 92, 228, 228, 228]) + + requant_params = make_requantize_params(input_scale * kernel_scale, + output_scale, -1, 'int8') if requantize_output else None + output = output.astype(out_dtype).reshape(output_shape) return make_configuration(quantized_data=quantized_data_np, quantized_kernel=quantized_kernel_np, @@ -206,8 +169,8 @@ def qnn_dense_driver(test_configuration): with relay.build_config(opt_level=2): graph, lib, params = relay.build(mod, "llvm", params=None) mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) - mod.set_input(quantized_data_name,test_configuration[quantized_data_name]) - mod.set_input(quantized_kernel_name,test_configuration[quantized_kernel_name]) + mod.set_input(quantized_data_name, test_configuration[quantized_data_name]) + mod.set_input(quantized_kernel_name, test_configuration[quantized_kernel_name]) if test_configuration[bias_name] is not None: mod.set_input(bias_name, test_configuration[bias_name]) mod.set_input(**params) @@ -241,7 +204,15 @@ def test_qnn_dense_with_requantized_output(): qnn_dense_driver(int8_requantized_output_with_bias_params) +def test_per_channel_weight_scale(): + with TempOpAttr("qnn.dense", "FTVMQnnLegalize", legalize_qnn_dense): + config = make_int_configuration(use_bias=True, requantize_output=True, + per_channel=True) + qnn_dense_driver(config) + + if __name__ == "__main__": test_qnn_dense_without_bias() test_qnn_dense_with_bias() test_qnn_dense_with_requantized_output() + test_per_channel_weight_scale() diff --git a/tests/python/relay/test_pass_qnn_legalize.py b/tests/python/relay/test_pass_qnn_legalize.py index 38fdb7dd07b1..37635e3daf7a 100644 --- a/tests/python/relay/test_pass_qnn_legalize.py +++ b/tests/python/relay/test_pass_qnn_legalize.py @@ -190,6 +190,7 @@ def _get_mod(data_dtype, kernel_dtype): kernel_zero_point=relay.const(1, 'int32'), input_scale=relay.const(1, 'float32'), kernel_scale=relay.const(1, 'float32'), + units=kernel_shape[0], out_dtype='int32') mod = relay.Function(relay.analysis.free_vars(func), func)