diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index f0bdd806975e3..8249345ae4eea 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -151,7 +151,7 @@ def _impl(inputs, input_types):
         data = inputs[0]
         if input_types[0] == "quint8":
             assert len(inputs) == 3, "Input quant param not found in op inputs"
-            input_zero_point = _expr.const(inputs[2])
+            input_zero_point = _expr.const(inputs[2], dtype="int32")
             return qnn_torch.quantized_relu(data, input_zero_point)
         return _op.nn.relu(data)
     return _impl
diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py
index 2f0553cfc7fd2..1a93832bee24c 100644
--- a/python/tvm/relay/frontend/qnn_torch.py
+++ b/python/tvm/relay/frontend/qnn_torch.py
@@ -26,7 +26,7 @@
 from tvm.relay.frontend.common import infer_shape
 
 
-class QuantParam:
+class QNNParam:
     """ A placeholder for weight quantization parameters """
 
     def __init__(self, weight, bias, scale, zero_point, param_key):
@@ -55,13 +55,13 @@ def _unpack_quant_params(param_name, packed_params, unpack_func):
     import torch
 
     if qweight.qscheme() == torch.per_tensor_affine:
-        param = QuantParam(weight_np, bias, qweight.q_scale(),
+        param = QNNParam(weight_np, bias, qweight.q_scale(),
                          int(qweight.q_zero_point()), param_name)
     else:
         scales = qweight.q_per_channel_scales().numpy()
         zero_points = qweight.q_per_channel_zero_points().numpy()
         assert np.all(zero_points == 0)
-        param = QuantParam(weight_np, bias, scales, 0, param_name)
+        param = QNNParam(weight_np, bias, scales, 0, param_name)
 
     return param
 
@@ -119,7 +119,7 @@ def _get_quant_param_for_input(input_value):
     are embeded in a QTensor data structure, not visible statically).
     We know that it is quantized using output scale and zp
     of some previous quantized op. The purpose of this function
-    is to find that pair of paramters.
+    is to find that pair of parameters.
     """
     # Indices for output scale and zp
     # For example, in quantized::conv2d(%input, %1, %2, %3, %4, %5, %6, %7),
@@ -245,7 +245,7 @@ def _add_output_quant_params_to_scalar_op(node, graph,
             _get_add_scalar_output_quant_param(input_scale, input_zero_point,
                                                scalar)
     else:
-        assert False, "unsupported scalar op: %s" % operator
+        raise NotImplementedError("unsupported scalar op: %s" % operator)
 
     # create new constant nodes and add them to graph
     out_scale_node = graph.create("prim::Constant")
diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py
index 6607fea40e573..846a9ba03bf11 100644
--- a/tests/python/frontend/pytorch/qnn_test.py
+++ b/tests/python/frontend/pytorch/qnn_test.py
@@ -282,15 +282,34 @@ def test_quantized_modules():
         runtime.run()
         tvm_result = runtime.get_output(0).asnumpy()
 
-        # we cannot make any guarantee on how close the raw output is to torch
-        # tvm.testing.assert_allclose(tvm_result, pt_result, rtol=1e-1, atol=1e-1)
-
         max_abs_diff = np.max(np.abs(tvm_result - pt_result))
         mean_abs_diff = np.mean(np.abs(tvm_result - pt_result))
         num_identical = np.sum(tvm_result == pt_result)
-        correct_ratio = num_identical / float(np.prod(tvm_result.shape))
+        match_ratio = num_identical / float(np.prod(tvm_result.shape))
+
+        print(module_name, max_abs_diff, mean_abs_diff, match_ratio)
+
+        # sample outputs
+        """
+        relu 0.0039215684 2.6052087e-08 0.9999933567176871
+        upsample bilinear 0.0 0.0 1.0
+        conv_bn 0.22062653 0.011478779 0.6909348115006899
+        conv_bn_relu 0.3700896 0.010921672 0.7489366477964451
+        linear 0.15987062 0.009231662 0.794921875
+        linear_relu 0.14180502 0.0053220326 0.8828125
+        conv_bn, per_channel 0.01654929 2.9486866e-06 0.9998218235127019
+        conv_bn_relu, per_channel 0.009089053 1.4926576e-06 0.9998357732732732
+        linear, per_channel 0.0 0.0 1.0
+        linear_relu, per_channel 0.0 0.0 1.0
+        hsigmoid 0.002614379 0.00020525524 0.9214896896258503
+        hswish 0.0052286386 0.00063522335 0.7587359162414966
+        semodule, per_channel 0.0039885044 0.0008620687 0.7838592529296875
+        mul_scalar negative 0.0011764616 7.815566e-09 0.9999933567176871
+        """
 
-        print(module_name, max_abs_diff, mean_abs_diff, correct_ratio)
+        # we cannot make any guarantee on how close the raw output is to torch
+        # tvm.testing.assert_allclose(tvm_result, pt_result, rtol=1e-1, atol=1e-1)
+        assert match_ratio > 0.6
 
 
 def test_quantized_imagenet():
@@ -361,22 +380,77 @@ def get_imagenet_input():
         results.append((model_name, pt_result[0], tvm_result[0]))
 
-        pt_top3_labels = np.argsort(pt_result[0])[::-1][:3]
-        tvm_top3_labels = np.argsort(pt_result[0])[::-1][:3]
-
-        assert set(pt_top3_labels) == set(tvm_top3_labels)
-
-        print("Torch top3 label:", pt_top3_labels)
-        print("TVM top3 label:", tvm_top3_labels)
-
     for (model_name, pt_result, tvm_result) in results:
         max_abs_diff = np.max(np.abs(tvm_result - pt_result))
         mean_abs_diff = np.mean(np.abs(tvm_result - pt_result))
         num_identical = np.sum(tvm_result == pt_result)
+        pt_top3_labels = np.argsort(pt_result)[::-1][:3]
+        tvm_top3_labels = np.argsort(pt_result)[::-1][:3]
 
         print("\nModel name: %s" % model_name)
-        print("PyTorch top3 label:", np.argsort(pt_result)[::-1][:3])
-        print("TVM top3 label:", np.argsort(tvm_result)[::-1][:3])
+        print("PyTorch top3 label:", pt_top3_labels)
+        print("TVM top3 label:", tvm_top3_labels)
         print("max abs diff:", max_abs_diff)
         print("mean abs_diff:", mean_abs_diff)
         print("%d in 1000 raw outputs identical."
               % num_identical)
+
+        assert set(pt_top3_labels) == set(tvm_top3_labels)
+
+    # sample outputs
+    """
+    Model name: resnet18, per tensor quantization
+    PyTorch top3 label: [386 101 385]
+    TVM top3 label: [386 101 385]
+    max abs diff: 0.65681696
+    mean abs_diff: 0.14055882
+    236 in 1000 raw outputs identical.
+
+    Model name: mobilenet_v2, per tensor quantization
+    PyTorch top3 label: [101 386 385]
+    TVM top3 label: [101 386 385]
+    max abs diff: 2.1262953
+    mean abs_diff: 0.41025686
+    101 in 1000 raw outputs identical.
+
+    Model name: inception_v3, per tensor quantization
+    PyTorch top3 label: [101 386 385]
+    TVM top3 label: [101 386 385]
+    max abs diff: 0.9994669
+    mean abs_diff: 0.098697364
+    272 in 1000 raw outputs identical.
+
+    Model name: googlenet, per tensor quantization
+    PyTorch top3 label: [101 386 385]
+    TVM top3 label: [101 386 385]
+    max abs diff: 0.28248847
+    mean abs_diff: 0.0634469
+    274 in 1000 raw outputs identical.
+
+    Model name: resnet18, per channel quantization
+    PyTorch top3 label: [101 386 385]
+    TVM top3 label: [101 386 385]
+    max abs diff: 0.65908074
+    mean abs_diff: 0.1274223
+    469 in 1000 raw outputs identical.
+
+    Model name: mobilenet_v2, per channel quantization
+    PyTorch top3 label: [101 386 385]
+    TVM top3 label: [101 386 385]
+    max abs diff: 0.71120834
+    mean abs_diff: 0.15883648
+    423 in 1000 raw outputs identical.
+
+    Model name: inception_v3, per channel quantization
+    PyTorch top3 label: [386 101 385]
+    TVM top3 label: [386 101 385]
+    max abs diff: 1.3372154
+    mean abs_diff: 0.1225224
+    401 in 1000 raw outputs identical.
+
+    Model name: googlenet, per channel quantization
+    PyTorch top3 label: [101 386 385]
+    TVM top3 label: [101 386 385]
+    max abs diff: 0.34015465
+    mean abs_diff: 0.054197952
+    558 in 1000 raw outputs identical.
+    """
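
For reference, the sketch below shows the end-to-end flow these tests exercise: post-training quantization of a torchvision model in PyTorch, tracing with torch.jit.trace, import through relay.frontend.from_pytorch, and comparison of the raw outputs against PyTorch. It is a minimal sketch, not part of this patch; the model choice, input name and shape, and the exact build/runtime calls are illustrative assumptions and may need adjusting to the TVM and PyTorch versions in use.

# Minimal, illustrative sketch (not part of this patch) of the flow qnn_test.py
# exercises. Model choice, input name/shape and the build/runtime API calls are
# assumptions and may differ between TVM/PyTorch versions.
import numpy as np
import torch
import torchvision

import tvm
from tvm import relay
from tvm.contrib import graph_runtime

# Post-training quantization of a torchvision model (fbgemm backend assumed).
model = torchvision.models.quantization.resnet18(pretrained=True).eval()
model.fuse_model()
model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
torch.quantization.prepare(model, inplace=True)
model(torch.rand(1, 3, 224, 224))  # calibration pass with dummy data
torch.quantization.convert(model, inplace=True)

# Trace the quantized model and import it into Relay via the PyTorch frontend.
inp = torch.rand(1, 3, 224, 224)
script_module = torch.jit.trace(model, inp).eval()
input_name = "input"  # hypothetical input name used for the Relay var
mod, params = relay.frontend.from_pytorch(script_module,
                                          [(input_name, (1, 3, 224, 224))])

# Build and run with the graph runtime, then compare raw outputs with PyTorch.
with relay.build_config(opt_level=3):
    graph_json, lib, params = relay.build(mod, target="llvm", params=params)

runtime = graph_runtime.create(graph_json, lib, tvm.cpu(0))
runtime.set_input(**params)
runtime.set_input(input_name, inp.numpy())
runtime.run()
tvm_result = runtime.get_output(0).asnumpy()

with torch.no_grad():
    pt_result = script_module(inp).numpy()
print("max abs diff:", np.max(np.abs(tvm_result - pt_result)))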