[QNN] More doc fix on quantize and convolution (#4874)
* [QNN] Doc fix on quantize and convolution

* update test
masahi authored Feb 14, 2020
1 parent 7013fc9 commit 24c53a3
Showing 3 changed files with 32 additions and 20 deletions.
10 changes: 5 additions & 5 deletions python/tvm/relay/qnn/op/qnn.py
@@ -104,7 +104,7 @@ def quantize(data,
axis : int
The channel axis for quantization. Default value is -1 which corresponds to the last axis.
out_dtype : str, optional
- The data type of the input tensor. Can be [int8, uint8]
+ The data type of the input tensor. Can be [int8, uint8, int32]
Returns
-------
result : tvm.relay.Expr
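
For reference, a minimal sketch of calling the operator with the newly documented int32 output type (the tensor shape and quantization constants below are illustrative, not taken from this commit):

    from tvm import relay

    # Quantize float32 data straight to int32, now listed as a valid out_dtype.
    data = relay.var("data", shape=(1, 4), dtype="float32")
    q = relay.qnn.op.quantize(data,
                              output_scale=relay.const(0.5, "float32"),
                              output_zero_point=relay.const(0, "int32"),
                              axis=-1,
                              out_dtype="int32")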
@@ -202,11 +202,11 @@ def conv2d(data,
input_scale,
kernel_scale,
kernel_size,
+ channels,
strides=(1, 1),
padding=(0, 0),
dilation=(1, 1),
groups=1,
- channels=None,
data_layout="NCHW",
kernel_layout="OIHW",
out_layout="",
@@ -247,6 +247,9 @@ def conv2d(data,
kernel_size : tuple of int
The spatial width and height of the convolution kernel.
+ channels : int
+ Number of output channels of this convolution.
strides : tuple of int, optional
The strides of convolution.
@@ -259,9 +262,6 @@
groups : int, optional
Number of groups for grouped convolution.
- channels : int, optional
- Number of output channels of this convolution.
data_layout : str, optional
Layout of the input.
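Because `channels` moves from an optional keyword to a required argument placed right after `kernel_size`, callers now have to pass it explicitly. A minimal sketch against the updated signature (shapes and quantization constants are illustrative):

    from tvm import relay

    data = relay.var("data", shape=(1, 4, 8, 8), dtype="uint8")       # NCHW
    kernel = relay.var("kernel", shape=(16, 4, 3, 3), dtype="uint8")  # OIHW
    conv = relay.qnn.op.conv2d(data, kernel,
                               input_zero_point=relay.const(0, "int32"),
                               kernel_zero_point=relay.const(0, "int32"),
                               input_scale=relay.const(1.0, "float32"),
                               kernel_scale=relay.const(1.0, "float32"),
                               kernel_size=(3, 3),
                               channels=16,   # required now; equals kernel_shape[0] for OIHW
                               data_layout="NCHW",
                               kernel_layout="OIHW",
                               out_dtype="int32")
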
41 changes: 26 additions & 15 deletions tests/python/relay/test_op_qnn_conv2d.py
@@ -79,8 +79,8 @@ def get_qnn_func(data,
data_layout,
kernel_layout,
out_dtype,
- groups,
- channels=None):
+ channels,
+ groups):
func = relay.qnn.op.conv2d(
data, kernel,
input_zero_point=relay.const(input_zero_point, 'int32'),
@@ -116,12 +116,23 @@ def get_funcs(data_shape,
data_layout,
kernel_layout,
out_dtype,
- groups=1,
- channels=None):
+ groups=1):
data = relay.var("data", shape=data_shape,
dtype=data_dtype)
kernel = relay.var("kernel", shape=kernel_shape,
dtype=kernel_dtype)

+    if groups > 1:
+        channels = groups
+    elif kernel_layout == "OIHW":
+        channels = kernel_shape[0]
+    elif kernel_layout == "HWIO":
+        channels = kernel_shape[3]
+    elif kernel_layout == "HWOI":
+        channels = kernel_shape[2]
+    else:
+        raise NotImplementedError

ref_func = get_ref_func(data,
kernel,
input_zero_point,
@@ -152,8 +163,9 @@ def get_funcs(data_shape,
data_layout,
kernel_layout,
out_dtype,
- groups,
- channels)
+ channels,
+ groups)

return (ref_func, qnn_func)

def verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape,
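
The helper above now derives `channels` from the kernel shape and layout instead of accepting it as an argument; a short worked example of that logic (values are illustrative):

    # Regular convolution: groups == 1, so channels is the output dim of the layout.
    kernel_shape = (16, 4, 3, 3)   # OIHW
    channels = kernel_shape[0]     # -> 16

    # Depthwise convolution: groups > 1, so channels equals the number of groups.
    groups = 4
    channels = groups              # -> 4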
@@ -418,7 +430,7 @@ def test_layout():
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)

- # NHWC and HWIO layout. Used in depthwise conv.
+ # NHWC and HWOI layout. Used in depthwise conv.
data_shape = (2, 2, 4, 1) # NHWC
data_dtype = 'uint8'
kernel_shape = (2, 2, 1, 1) # HWOI
@@ -568,6 +580,7 @@ def test_const_folding():
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="int32",
+ channels=kernel_shape[0],
groups=1)
folded_mod = transform.FoldConstant()(qnn_func)
folded_func = folded_mod["main"]
@@ -787,8 +800,8 @@ def test_depthwise_depth_multiplier():
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="int32",
- groups=4,
- channels=4)
+ groups=4)

verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)

@@ -813,8 +826,7 @@ def test_depthwise_depth_multiplier():
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="int32",
- groups=8,
- channels=8)
+ groups=8)
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)

@@ -839,8 +851,7 @@ def test_depthwise_depth_multiplier():
data_layout="NHWC",
kernel_layout="HWOI",
out_dtype="int32",
- groups=4,
- channels=4)
+ groups=4)
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)

@@ -864,8 +875,7 @@ def test_depthwise_depth_multiplier():
data_layout="NHWC",
kernel_layout="HWOI",
out_dtype="int32",
- groups=8,
- channels=8)
+ groups=8)
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)

@@ -888,6 +898,7 @@ def test_per_channel_kernel_scale():
input_scale=relay.const(2.0, 'float32'),
kernel_scale=kernel_scales,
kernel_size=(2, 2),
+ channels=kernel_shape[0],
padding=(0, 0),
strides=(1, 1),
dilation=(1, 1),
1 change: 1 addition & 0 deletions tests/python/relay/test_pass_qnn_legalize.py
@@ -107,6 +107,7 @@ def _get_mod(data_dtype, kernel_dtype):
input_scale=relay.const(1.0, 'float32'),
kernel_scale=relay.const(1.0, 'float32'),
kernel_size=(3, 3),
+ channels=kernel_shape[0],
strides=(1, 1),
dilation=(1, 1),
out_dtype='int32',
