From 9eb8b76ba158c9ce5ab913e0f2e44846fb6a90e7 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Thu, 2 Jul 2020 22:23:04 +0000 Subject: [PATCH] MXNet pre-quantized BERT --- include/tvm/relay/qnn/attrs.h | 13 ++ python/tvm/relay/frontend/mxnet.py | 165 +++++++++++++++---- python/tvm/relay/frontend/nnvm_common.py | 51 +++++- python/tvm/relay/qnn/op/qnn.py | 8 +- src/relay/qnn/op/dequantize.cc | 55 ++++++- tests/python/frontend/mxnet/test_forward.py | 38 +++++ tests/python/relay/test_op_qnn_dequantize.py | 30 +++- 7 files changed, 306 insertions(+), 54 deletions(-) diff --git a/include/tvm/relay/qnn/attrs.h b/include/tvm/relay/qnn/attrs.h index 4b5cd89f0b0c..c5213fe07471 100644 --- a/include/tvm/relay/qnn/attrs.h +++ b/include/tvm/relay/qnn/attrs.h @@ -75,6 +75,19 @@ struct QuantizeAttrs : public tvm::AttrsNode { } }; +/*! \brief Attribute for dequantize operator */ +struct DequantizeAttrs : public tvm::AttrsNode { + int axis; + + TVM_DECLARE_ATTRS(DequantizeAttrs, "relay.attrs.DequantizeAttrs") { + TVM_ATTR_FIELD(axis) + .describe( + "The channel axis for channel wise dequantization. Default value is -1," + "which corresponds to the last axis.") + .set_default(-1); + } +}; + } // namespace qnn } // namespace relay } // namespace tvm diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 97b9d7a44997..327bcd483c67 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -1944,18 +1944,27 @@ def _qnn_batch_norm(inputs, attrs): def _qnn_fully_connected(inputs, attrs, subgraphs, params): - def _get_input_scale_zp(_data, _inputs, _has_bias): + def _get_input_scale_zp(_data_dtype, _inputs, _has_bias): data_min_idx, data_max_idx = (3, 4) if _has_bias else (2, 3) data_min, data_max = _inputs[data_min_idx], _inputs[data_max_idx] - data_dtype = _infer_type(_data).checked_type.dtype _data_scale = get_mkldnn_uint8_scale(data_min, data_max) \ - if data_dtype == 'uint8' \ + if _data_dtype == 'uint8' \ else get_mkldnn_int8_scale(data_min, data_max) _data_zp = 0 return _data_scale, _data_zp - def _get_kernel_scale_zp(_kernel, _inputs, _has_bias): + def _get_kernel_scale_zp_tensor_quantized(_kernel, _inputs, _has_bias): kernel_dtype = _infer_type(_kernel).checked_type.dtype + + if kernel_dtype != "int8": + raise tvm.error.OpNotImplemented(\ + "Tensor wise quantized expects weights in int8 data type") + + if isinstance(_kernel, tvm.relay.Call) and _kernel.op.name == "qnn.quantize": + _kernel_scale = _kernel.args[1].data.asnumpy() + _kernel_zp = _kernel.args[2].data.asnumpy() + return _kernel_scale, _kernel_zp + kernel_min_idx, kernel_max_idx = (5, 6) if _has_bias else (4, 5) kernel_min_name = _get_name(_inputs[kernel_min_idx]) kernel_min = params[kernel_min_name].asnumpy()[0] @@ -1967,7 +1976,34 @@ def _get_kernel_scale_zp(_kernel, _inputs, _has_bias): _kernel_zp = 0 return _kernel_scale, _kernel_zp + def _get_kernel_scale_zp_channel_quantized(_kernel, _bias, _data_scale): + kernel_dtype = _infer_type(_kernel).checked_type.dtype + if kernel_dtype != "float32": + raise tvm.error.OpNotImplemented(\ + "Channel wise quantized expects weights in float32 data type") + + # Get the FP32 values, calculate min/max and then channel quantize them + np_kernel = _infer_value(_kernel, params).asnumpy() + kernel_channel_min = np.amin(np_kernel, axis=(1, )) + kernel_channel_max = np.amax(np_kernel, axis=(1, )) + + np_bias = None + if _bias is not None: + np_bias = _infer_value(_bias, params).asnumpy() + return quantize_conv_weights_bias_channel_mkldnn_from_var(_kernel, + np_bias, + kernel_channel_min, + kernel_channel_max, + _data_scale) + def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale): + _bias = _inputs[2] + if isinstance(_bias, tvm.relay.Call) and _bias.op.name == "qnn.quantize": + _bias_scale = _bias.args[1].data.asnumpy() + _bias_requantize_scale = _bias_scale/(_data_scale * _kernel_scale) + _bias_requantize_scale = _expr.const(_bias_requantize_scale, dtype="float32") + return _bias_requantize_scale + bias_min_name = _get_name(_inputs[7]) bias_min = params[bias_min_name].asnumpy()[0] bias_max_name = _get_name(_inputs[8]) @@ -1987,16 +2023,48 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale): return res else: has_bias = not subgraph_dense_attrs.get_bool("no_bias", False) - # input - data = inputs[0] - data_scale, data_zp = _get_input_scale_zp(data, inputs, has_bias) - # kernel - kernel = inputs[1] - kernel_scale, kernel_zp = _get_kernel_scale_zp(kernel, inputs, has_bias) units = subgraph_dense_attrs.get_int("num_hidden") + is_flatten = subgraph_dense_attrs.get_bool("flatten", True) + enable_float_output = attrs.get_bool('enable_float_output', False) + is_channel_quantized = attrs.get_bool('channel_wise_quantize', False) + + ######################## + # Get data, kernel, bias + ######################## + data, kernel = inputs[0], inputs[1] + bias = None + if has_bias: + bias = inputs[2] + + ############################## + # Handle for shape of data > 2 + ############################## + if is_flatten: + data = _op.nn.batch_flatten(data) data_shape = _infer_type(data).checked_type.shape if len(data_shape) > 2: - data = _op.nn.batch_flatten(data) + data = _op.reverse_reshape(data, [-1, 0]) + + ############################### + # Get data scale and zero point + ############################### + data_dtype = _infer_type(data).checked_type.dtype + data_scale, data_zp = _get_input_scale_zp(data_dtype, inputs, has_bias) + + ################################# + # Get weight scale and zero point + ################################# + if is_channel_quantized: + kernel, kernel_scale, kernel_zp = _get_kernel_scale_zp_channel_quantized(kernel, + bias, + data_scale) + else: + kernel_scale, kernel_zp = _get_kernel_scale_zp_tensor_quantized(kernel, inputs, + has_bias) + + ################ + # Call QNN dense + ################ res = relay.qnn.op.dense(data, kernel, input_zero_point=relay.const(data_zp, 'int32'), @@ -2004,22 +2072,46 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale): input_scale=relay.const(data_scale, 'float32'), kernel_scale=relay.const(kernel_scale, 'float32'), units=units) + + ################# + # Handle bias add + ################# if has_bias: - bias_data = inputs[2] - bias_requantize_scale = \ - _get_bias_requantize_scale(inputs, data_scale, kernel_scale) - multiplied_bias = \ - _op.multiply(_op.cast(bias_data, 'float32'), bias_requantize_scale) - rounded_bias = _op.round(multiplied_bias) - clipped_bias = _op.clip(rounded_bias, - a_min=tvm.tir.op.min_value('int32').value, - a_max=tvm.tir.op.max_value('int32').value) - requantized_bias = _op.cast(clipped_bias, 'int32') - res = _op.nn.bias_add(res, requantized_bias, axis=-1) - enable_float_output = attrs.get_bool('enable_float_output', False) - out_dtype = 'uint8' if attrs.get_bool('with_relu', False) else 'int8' - input_scale = np.float32(data_scale * kernel_scale) - if not enable_float_output: + if is_channel_quantized: + bias_scale = data_scale * kernel_scale + int32_bias = quantize_conv_bias_mkldnn_from_var(bias, bias_scale) + res = _op.nn.bias_add(res, int32_bias, axis=-1) + else: + bias_data = inputs[2] + bias_requantize_scale = \ + _get_bias_requantize_scale(inputs, data_scale, kernel_scale) + multiplied_bias = \ + _op.multiply(_op.cast(bias_data, 'float32'), bias_requantize_scale) + rounded_bias = _op.round(multiplied_bias) + clipped_bias = _op.clip(rounded_bias, + a_min=tvm.tir.op.min_value('int32').value, + a_max=tvm.tir.op.max_value('int32').value) + requantized_bias = _op.cast(clipped_bias, 'int32') + res = _op.nn.bias_add(res, requantized_bias, axis=-1) + + ############################################## + # Dequantize if float32 output else Requantize + ############################################## + if enable_float_output: + output_scale = np.float32(data_scale * kernel_scale) + res = relay.qnn.op.dequantize(res, + relay.const(output_scale), + input_zero_point=relay.const(0, 'int32'), + axis=1) + if with_relu: + res = _op.nn.relu(res) + else: + + if is_channel_quantized: + raise tvm.error.OpNotImplemented(\ + "Channel wise quantized dense with non float output is not supported yet") + out_dtype = 'uint8' if attrs.get_bool('with_relu', False) else 'int8' + input_scale = np.float32(data_scale * kernel_scale) min_output_range = attrs.get_float('min_calib_range') max_output_range = attrs.get_float('max_calib_range') output_scale = get_mkldnn_requantize_scale_outDtype(min_output_range, @@ -2034,17 +2126,20 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale): out_dtype=out_dtype) if with_relu: res = _op.nn.relu(res) - return res, min_output_range, max_output_range - else: - output_scale = np.float32(data_scale * kernel_scale) - res = relay.qnn.op.dequantize(res, - relay.const(output_scale, 'float32'), - input_zero_point=relay.const(0, 'int32')) - if with_relu: - res = _op.nn.relu(res) - return res + ############################## + # Handle for shape of data > 2 + ############################## + if len(data_shape) > 2: + new_shape = data_shape[:-1] + new_shape.append(units) + res = _op.reshape(res, new_shape) + + if enable_float_output: + return res + return res, min_output_range, max_output_range + def _mx_broadcast_to(inputs, attrs): data = inputs[0] tgt_shape = attrs.get_int_tuple("shape", []) diff --git a/python/tvm/relay/frontend/nnvm_common.py b/python/tvm/relay/frontend/nnvm_common.py index a2eea94b06f0..cfb63620c07b 100644 --- a/python/tvm/relay/frontend/nnvm_common.py +++ b/python/tvm/relay/frontend/nnvm_common.py @@ -17,10 +17,13 @@ # pylint: disable=invalid-name, import-self, len-as-condition """Utility functions common to NNVM and MxNet conversion.""" import warnings +from ... import error +from ...tir.op import min_value from .. import expr as _expr from .. import op as _op from .common import get_relay_op from .common import infer_type as _infer_type +from .common import infer_shape as _infer_shape def _warn_not_used(attr, op='nnvm'): err = "{} is ignored in {}.".format(attr, op) @@ -57,9 +60,53 @@ def _impl(inputs, attrs): def _softmax_op(new_op): """softmax/log_softmax""" def _impl(inputs, attrs, _dtype='float32'): - # TODO(@icemelon9): currently ignore the 2nd input to softmax for mxnet 1.6 - # assert len(inputs) == 1 axis = attrs.get_int("axis", -1) + use_length = attrs.get_bool("use_length", False) + if use_length: + # The second arg is valid_length. We can use sequence mask to mask the input before + # computing softmax + assert len(inputs) == 2 + + data = inputs[0] + length = inputs[1] + data_shape = _infer_shape(data) + length_shape = _infer_shape(length) + + if axis < 0: + axis = len(data_shape) + axis + + data_ndims = len(data_shape) + length_ndims = len(length_shape) + + # Sequence_mask supports axis = 0 and 1 and requires data to be in specific format. + if axis == data_ndims - 1 and data_ndims > 2 and length_ndims == 2: + new_batch_size = 1 + for dim in range(length_ndims): + assert data_shape[dim] == length_shape[dim] + new_batch_size *= data_shape[dim] + + # Reshape the data and length to satisfy sequence mask + data = _op.reshape(data, newshape=(new_batch_size, -1)) + length = _op.reshape(length, newshape=(new_batch_size)) + + # Input data is now 2D, we can set the axis = 1 + axis = 1 + elif data_ndims > 2: + raise error.OpNotImplemented(\ + "Operator softmax with use_length=True is supported only for axis -1") + + res = _op.sequence_mask(data=data, + valid_length=length, + mask_value=float(min_value("float").value), + axis=axis) + + # Apply softmax + res = new_op(res, axis=axis) + + # Reshape back to input data shape + if len(data_shape) > 2: + return _op.reshape(res, newshape=data_shape) + return res return new_op(inputs[0], axis=axis) return _impl diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py index 5a3106d1e787..14d74bfa42fb 100644 --- a/python/tvm/relay/qnn/op/qnn.py +++ b/python/tvm/relay/qnn/op/qnn.py @@ -121,7 +121,8 @@ def quantize(data, def dequantize(data, input_scale, - input_zero_point): + input_zero_point, + axis=-1): r""" Dequantize op This operator takes quantized int8 and unit8 as input and produces dequantized float32 as output. The output shape is the same as input shape. The input @@ -135,6 +136,8 @@ def dequantize(data, The input zero_point. input_scale : tvm.relay.Expr The input scale. + axis : int + The channel axis for quantization. Default value is -1 which corresponds to the last axis. Returns ------- result : tvm.relay.Expr @@ -143,7 +146,8 @@ def dequantize(data, return _make.dequantize(data, input_scale, - input_zero_point) + input_zero_point, + axis) def concatenate(data, diff --git a/src/relay/qnn/op/dequantize.cc b/src/relay/qnn/op/dequantize.cc index 7c014d71a76a..da804dace60d 100644 --- a/src/relay/qnn/op/dequantize.cc +++ b/src/relay/qnn/op/dequantize.cc @@ -34,6 +34,8 @@ namespace tvm { namespace relay { namespace qnn { +TVM_REGISTER_NODE_TYPE(DequantizeAttrs); + bool DequantizeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 4); @@ -45,9 +47,16 @@ bool DequantizeRel(const Array& types, int num_inputs, const Attrs& attrs, << "Input type should be one of the quantized types [unit8, int8, int32] but was " << input_dtype; - // Check the types of scale and zero points. - CHECK(IsScalarType(types[1], DataType::Float(32))); // input_scale - CHECK(IsScalarType(types[2], DataType::Int(32))); // input_zero_point + const auto* dequantize_attrs = attrs.as(); + int axis = dequantize_attrs->axis; + axis = (axis == -1) ? data->shape.size() - 1 : axis; + CHECK_LT(axis, static_cast(data->shape.size())) + << "axis " << dequantize_attrs->axis << " is out of range"; + CHECK_GE(axis, 0) << "axis " << dequantize_attrs->axis << " is out of range"; + + // Check and assign types for scale and zero points. + AssignType(types[1], DataType::Float(32), data->shape[axis], reporter); // scale + AssignType(types[2], DataType::Int(32), data->shape[axis], reporter); // zero point const Array oshape = data->shape; // assign output type, output will always be float 32. @@ -55,16 +64,34 @@ bool DequantizeRel(const Array& types, int num_inputs, const Attrs& attrs, return true; } -Expr MakeDequantize(Expr data, Expr input_scale, Expr input_zero_point) { +Expr MakeDequantize(Expr data, Expr input_scale, Expr input_zero_point, int axis) { // real_value = scale * (quantized_value - zero_point) // A more detailed explanation can be found here - // https://github.com/google/gemmlowp/blob/master/doc/quantization.md + auto attrs = make_object(); + attrs->axis = axis; static const Op& op = Op::Get("qnn.dequantize"); - return Call(op, {data, input_scale, input_zero_point}, Attrs(), {}); + return Call(op, {data, input_scale, input_zero_point}, Attrs(attrs), {}); } Expr DequantizeLower(const Expr& input_tensor, const Expr& input_scale, - const Expr& input_zero_point) { + const Expr& input_zero_point, const Array& input_shape, + const DequantizeAttrs* attrs) { + const auto axis = attrs->axis; + + size_t n_dim = input_shape.size(); + + // Expand scale and zero point if the input tensor is channel quantized + auto expanded_input_scale = input_scale; + if (!IsConstScalar(input_scale)) { + expanded_input_scale = ExpandBiasToMatchAxis(input_scale, n_dim, {axis}); + } + + auto expanded_input_zero_point = input_zero_point; + if (!IsConstScalar(input_zero_point)) { + expanded_input_zero_point = ExpandBiasToMatchAxis(input_zero_point, n_dim, {axis}); + } + auto shift = Subtract(Cast(input_tensor, DataType::Int(32)), input_zero_point); auto scaled_output = Multiply(Cast(shift, DataType::Float(32)), input_scale); return scaled_output; @@ -77,7 +104,20 @@ Expr DequantizeQnnCanonicalize(const Attrs& attrs, const Array& new_args, auto& input_scale = new_args[1]; auto& input_zero_point = new_args[2]; CHECK_EQ(types.size(), 4); - return DequantizeLower(data, input_scale, input_zero_point); + + // Get attrs. + const auto* dequantize_attrs = attrs.as(); + CHECK(dequantize_attrs != nullptr); + + // Find input shape. + CHECK_EQ(types.size(), 4); + auto in_type = types[0]; + auto in_tensor_type = in_type.as(); + CHECK(in_tensor_type != nullptr) << "Type information missing." + << " Please run infer_type pass."; + Array input_shape = in_tensor_type->shape; + + return DequantizeLower(data, input_scale, input_zero_point, input_shape, dequantize_attrs); } RELAY_REGISTER_OP("qnn.dequantize") @@ -85,6 +125,7 @@ RELAY_REGISTER_OP("qnn.dequantize") The input is always quantized (int8, uint8) and will be converted to float32 given input scale and zero_point. - **data**: Quantized tensor of any shape to dequantize. The input data can be of floating point )code" TVM_ADD_FILELINE) + .set_attrs_type() .set_num_inputs(3) .add_argument("data", "Tensor", "The tensor to dequantize.") .add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.") diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index c8bbf88c96ef..48ad7361d9cf 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -1373,6 +1373,43 @@ def verify(data_shape, anchor_shape, stds=[1, 1, 1, 1], clip=-1, in_format="corn verify((1, 10, 4), (1, 10, 4), in_format="center") +def test_forward_softmax(): + def verify(data_shape, axis, use_length, length): + dtype = "float32" + x = np.random.uniform(low=-100, high=100, size=data_shape).astype(dtype) + if use_length: + ref_res = mx.nd.softmax(data=mx.nd.array(x), + length=mx.nd.array(length, dtype="int32"), + axis=axis, use_length=use_length) + mx_sym = mx.symbol.softmax(data=mx.sym.var("data"), + length=mx.sym.var("length"), + axis=axis, use_length=use_length) + shape_dict = {"data": data_shape, "length": (length.shape)} + dtype_dict = {"data": dtype, "length": "int32"} + mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict, dtype_dict) + else: + ref_res = mx.nd.softmax(data=mx.nd.array(x), axis=axis) + mx_sym = mx.symbol.softmax(data=mx.sym.var("data"), axis=axis) + shape_dict = {"data": data_shape} + mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) + + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + if use_length: + op_res = intrp.evaluate()(x, length) + else: + op_res = intrp.evaluate()(x) + + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5) + + verify((2, 3, 5), -1, False, None) + verify((2, 3, 5), 2, False, None) + verify((2, 3), -1, True, np.array([2, 1]).astype('int32')) + verify((2, 3, 4), -1, True, np.array([[3, 4, 2], [2, 1, 1]]).astype('int32')) + verify((2, 3, 4), 2, True, np.array([[3, 4, 2], [1, 2, 1]]).astype('int32')) + + if __name__ == '__main__': test_forward_mlp() test_forward_vgg() @@ -1449,3 +1486,4 @@ def verify(data_shape, anchor_shape, stds=[1, 1, 1, 1], clip=-1, in_format="corn test_forward_box_decode() test_forward_amp_multicast() test_forward_amp_cast() + test_forward_softmax() diff --git a/tests/python/relay/test_op_qnn_dequantize.py b/tests/python/relay/test_op_qnn_dequantize.py index 3c82b7fa0afa..361d6f0d411d 100644 --- a/tests/python/relay/test_op_qnn_dequantize.py +++ b/tests/python/relay/test_op_qnn_dequantize.py @@ -21,13 +21,14 @@ from tvm import relay from tvm.contrib import graph_runtime -def quantize_test_driver(in_dtype, quant_args, in_data, verify_output_data): +def dequantize_test_driver(in_dtype, quant_args, in_data, verify_output_data, axis): shape = in_data.shape input_data = relay.var("input_data", shape=shape, dtype=in_dtype) input_zero_point = relay.const(quant_args['in_zero_point'], 'int32') input_scale = relay.const(quant_args['in_scale'], 'float32') quantized_output = relay.qnn.op.dequantize(input_data, input_scale=input_scale, - input_zero_point=input_zero_point) + input_zero_point=input_zero_point, + axis=axis) mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output) mod = tvm.IRModule.from_expr(mod) with tvm.transform.PassContext(opt_level=3): @@ -48,8 +49,8 @@ def test_uint8_to_float32(): .astype('float32') \ .reshape((2, 5)) quant_args = {"in_zero_point":127, "in_scale":0.5} - quantize_test_driver(in_dtype='uint8', quant_args=quant_args, in_data=data, - verify_output_data=output) + dequantize_test_driver(in_dtype='uint8', quant_args=quant_args, in_data=data, + verify_output_data=output, axis=-1) def test_int8_to_float32(): data = np.array([-128, -127, -126, -125, -124, 123, 124, 125, 126, 127]) \ @@ -59,18 +60,31 @@ def test_int8_to_float32(): .astype('float32') \ .reshape((2, 5)) quant_args = {"in_zero_point": -1, "in_scale": 0.5} - quantize_test_driver(in_dtype='int8', quant_args=quant_args, in_data=data, - verify_output_data=output) + dequantize_test_driver(in_dtype='int8', quant_args=quant_args, in_data=data, + verify_output_data=output, axis=-1) def test_int32_to_float32(): data = np.array([113, 29, -1052]).astype('int32') output = np.array([0.6550452, 0.16810896, -6.098297]).astype('float32') quant_args = {"in_zero_point": 0, "in_scale": 0.0057968604} - quantize_test_driver(in_dtype='int32', quant_args=quant_args, in_data=data, - verify_output_data=output) + dequantize_test_driver(in_dtype='int32', quant_args=quant_args, in_data=data, + verify_output_data=output, axis=-1) + + +def test_channelwise_axis_1(): + data = np.transpose(np.array([0, 1, 2, 3, 4, 243, 247, 249, 250, 251]) \ + .astype('uint8').reshape((2,5))) + output = np.transpose(np.array([-63.5, -63, -62.5, -62, -61.5, 30, 31, 31.5, 31.75, 32]) \ + .astype('float32').reshape((2,5))) + quant_args = {"in_zero_point" : np.array([127, 123]).astype('int32'), + "in_scale" : np.array([0.5, 0.25]).astype('float32')} + + dequantize_test_driver(in_dtype='uint8', quant_args=quant_args, in_data=data, + verify_output_data=output, axis=1) if __name__ == "__main__": test_uint8_to_float32() test_int8_to_float32() test_int32_to_float32() + test_channelwise_axis_1()