Skip to content

Commit

Permalink
address comments, add sample outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Mar 2, 2020
1 parent 9c9556d commit 7b9664a
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 21 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def _impl(inputs, input_types):
data = inputs[0]
if input_types[0] == "quint8":
assert len(inputs) == 3, "Input quant param not found in op inputs"
input_zero_point = _expr.const(inputs[2])
input_zero_point = _expr.const(inputs[2], dtype="int32")
return qnn_torch.quantized_relu(data, input_zero_point)
return _op.nn.relu(data)
return _impl
Expand Down
10 changes: 5 additions & 5 deletions python/tvm/relay/frontend/qnn_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from tvm.relay.frontend.common import infer_shape


class QuantParam:
class QNNParam:
""" A placeholder for weight quantization parameters """

def __init__(self, weight, bias, scale, zero_point, param_key):
Expand Down Expand Up @@ -55,13 +55,13 @@ def _unpack_quant_params(param_name, packed_params, unpack_func):

import torch
if qweight.qscheme() == torch.per_tensor_affine:
param = QuantParam(weight_np, bias, qweight.q_scale(),
param = QNNParam(weight_np, bias, qweight.q_scale(),
int(qweight.q_zero_point()), param_name)
else:
scales = qweight.q_per_channel_scales().numpy()
zero_points = qweight.q_per_channel_zero_points().numpy()
assert np.all(zero_points == 0)
param = QuantParam(weight_np, bias, scales, 0, param_name)
param = QNNParam(weight_np, bias, scales, 0, param_name)

return param

Expand Down Expand Up @@ -119,7 +119,7 @@ def _get_quant_param_for_input(input_value):
are embeded in a QTensor data structure, not visible statically).
We know that it is quantized using output scale and zp
of some previous quantized op. The purpose of this function
is to find that pair of paramters.
is to find that pair of parameters.
"""
# Indices for output scale and zp
# For example, in quantized::conv2d(%input, %1, %2, %3, %4, %5, %6, %7),
Expand Down Expand Up @@ -245,7 +245,7 @@ def _add_output_quant_params_to_scalar_op(node, graph,
_get_add_scalar_output_quant_param(input_scale, input_zero_point,
scalar)
else:
assert False, "unsupported scalar op: %s" % operator
raise NotImplementedError("unsupported scalar op: %s" % operator)

# create new constant nodes and add them to graph
out_scale_node = graph.create("prim::Constant")
Expand Down
104 changes: 89 additions & 15 deletions tests/python/frontend/pytorch/qnn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,15 +282,34 @@ def test_quantized_modules():
runtime.run()
tvm_result = runtime.get_output(0).asnumpy()

# we cannot make any guarantee on how close the raw output is to torch
# tvm.testing.assert_allclose(tvm_result, pt_result, rtol=1e-1, atol=1e-1)

max_abs_diff = np.max(np.abs(tvm_result - pt_result))
mean_abs_diff = np.mean(np.abs(tvm_result - pt_result))
num_identical = np.sum(tvm_result == pt_result)
correct_ratio = num_identical / float(np.prod(tvm_result.shape))
match_ratio = num_identical / float(np.prod(tvm_result.shape))

print(module_name, max_abs_diff, mean_abs_diff, match_ratio)

# sample outputs
"""
relu 0.0039215684 2.6052087e-08 0.9999933567176871
upsample bilinear 0.0 0.0 1.0
conv_bn 0.22062653 0.011478779 0.6909348115006899
conv_bn_relu 0.3700896 0.010921672 0.7489366477964451
linear 0.15987062 0.009231662 0.794921875
linear_relu 0.14180502 0.0053220326 0.8828125
conv_bn, per_channel 0.01654929 2.9486866e-06 0.9998218235127019
conv_bn_relu, per_channel 0.009089053 1.4926576e-06 0.9998357732732732
linear, per_channel 0.0 0.0 1.0
linear_relu, per_channel 0.0 0.0 1.0
hsigmoid 0.002614379 0.00020525524 0.9214896896258503
hswish 0.0052286386 0.00063522335 0.7587359162414966
semodule, per_channel 0.0039885044 0.0008620687 0.7838592529296875
mul_scalar negative 0.0011764616 7.815566e-09 0.9999933567176871
"""

print(module_name, max_abs_diff, mean_abs_diff, correct_ratio)
# we cannot make any guarantee on how close the raw output is to torch
# tvm.testing.assert_allclose(tvm_result, pt_result, rtol=1e-1, atol=1e-1)
assert match_ratio > 0.6


def test_quantized_imagenet():
Expand Down Expand Up @@ -361,22 +380,77 @@ def get_imagenet_input():

results.append((model_name, pt_result[0], tvm_result[0]))

pt_top3_labels = np.argsort(pt_result[0])[::-1][:3]
tvm_top3_labels = np.argsort(pt_result[0])[::-1][:3]

assert set(pt_top3_labels) == set(tvm_top3_labels)

print("Torch top3 label:", pt_top3_labels)
print("TVM top3 label:", tvm_top3_labels)

for (model_name, pt_result, tvm_result) in results:
max_abs_diff = np.max(np.abs(tvm_result - pt_result))
mean_abs_diff = np.mean(np.abs(tvm_result - pt_result))
num_identical = np.sum(tvm_result == pt_result)
pt_top3_labels = np.argsort(pt_result)[::-1][:3]
tvm_top3_labels = np.argsort(pt_result)[::-1][:3]

print("\nModel name: %s" % model_name)
print("PyTorch top3 label:", np.argsort(pt_result)[::-1][:3])
print("TVM top3 label:", np.argsort(tvm_result)[::-1][:3])
print("PyTorch top3 label:", pt_top3_labels)
print("TVM top3 label:", tvm_top3_labels)
print("max abs diff:", max_abs_diff)
print("mean abs_diff:", mean_abs_diff)
print("%d in 1000 raw outputs identical." % num_identical)

assert set(pt_top3_labels) == set(tvm_top3_labels)

# sample outputs
"""
Model name: resnet18, per tensor quantization
PyTorch top3 label: [386 101 385]
TVM top3 label: [386 101 385]
max abs diff: 0.65681696
mean abs_diff: 0.14055882
236 in 1000 raw outputs identical.
Model name: mobilenet_v2, per tensor quantization
PyTorch top3 label: [101 386 385]
TVM top3 label: [101 386 385]
max abs diff: 2.1262953
mean abs_diff: 0.41025686
101 in 1000 raw outputs identical.
Model name: inception_v3, per tensor quantization
PyTorch top3 label: [101 386 385]
TVM top3 label: [101 386 385]
max abs diff: 0.9994669
mean abs_diff: 0.098697364
272 in 1000 raw outputs identical.
Model name: googlenet, per tensor quantization
PyTorch top3 label: [101 386 385]
TVM top3 label: [101 386 385]
max abs diff: 0.28248847
mean abs_diff: 0.0634469
274 in 1000 raw outputs identical.
Model name: resnet18, per channel quantization
PyTorch top3 label: [101 386 385]
TVM top3 label: [101 386 385]
max abs diff: 0.65908074
mean abs_diff: 0.1274223
469 in 1000 raw outputs identical.
Model name: mobilenet_v2, per channel quantization
PyTorch top3 label: [101 386 385]
TVM top3 label: [101 386 385]
max abs diff: 0.71120834
mean abs_diff: 0.15883648
423 in 1000 raw outputs identical.
Model name: inception_v3, per channel quantization
PyTorch top3 label: [386 101 385]
TVM top3 label: [386 101 385]
max abs diff: 1.3372154
mean abs_diff: 0.1225224
401 in 1000 raw outputs identical.
Model name: googlenet, per channel quantization
PyTorch top3 label: [101 386 385]
TVM top3 label: [101 386 385]
max abs diff: 0.34015465
mean abs_diff: 0.054197952
558 in 1000 raw outputs identical.
"""

0 comments on commit 7b9664a

Please sign in to comment.