diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 6da91c17fd94..0c7465b37871 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -172,7 +172,7 @@ def func(x):
             return _op.nn.adaptive_avg_pool2d(x, output_size=output_size)
 
         if input_types[0] == "quint8":
-            return qnn_torch.quantized_adaptive_avg_2d(data, func)
+            return qnn_torch.apply_with_upcast(data, func)
 
         return func(data)
 
@@ -484,14 +484,22 @@ def _impl(inputs, input_types):
         ceil_mode = int(inputs[4])
         count_include_pad = int(inputs[5])
 
-        return _op.nn.avg_pool2d(data,
-                                 pool_size=pool_size,
-                                 strides=strides,
-                                 padding=padding,
-                                 ceil_mode=ceil_mode,
-                                 count_include_pad=count_include_pad)
+        def func(x):
+            return _op.nn.avg_pool2d(x,
+                                     pool_size=pool_size,
+                                     strides=strides,
+                                     padding=padding,
+                                     ceil_mode=ceil_mode,
+                                     count_include_pad=count_include_pad)
+
+        if input_types[0] == "quint8":
+            return qnn_torch.apply_with_upcast(data, func)
+
+        return func(data)
+
     return _impl
 
+
 def _dropout():
     def _impl(inputs, input_types):
         data = inputs[0]
diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py
index 70178be52bb0..e6a015f8a89e 100644
--- a/python/tvm/relay/frontend/qnn_torch.py
+++ b/python/tvm/relay/frontend/qnn_torch.py
@@ -359,10 +359,9 @@ def add_quant_params(params, quant_params):
             params[qparam.bias_var.name_hint] = tvm.nd.array(qparam.bias)
 
 
-def quantized_adaptive_avg_2d(data, func_fp32):
-    # this follows tflite impl
+def apply_with_upcast(data, func):
     inp = _op.cast(data, dtype="int32")
-    out = func_fp32(inp)
+    out = func(inp)
     return _op.cast(out, "uint8")
 
 
diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py
index 23fcb7cf6c94..ebc00bfcf541 100644
--- a/tests/python/frontend/pytorch/qnn_test.py
+++ b/tests/python/frontend/pytorch/qnn_test.py
@@ -218,7 +218,6 @@ def fuse_model(self):
 class UpsamplingBilinear(nn.Module):
     def __init__(self):
         super().__init__()
-        self.relu = QuantWrapper(nn.ReLU())
         self.quant = QuantStub()
         self.dequant = DeQuantStub()
 
@@ -233,12 +232,25 @@ def fuse_model(self):
         pass
 
 
+class AvgPool2d(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.pool = QuantWrapper(nn.AvgPool2d(kernel_size=2))
+
+    def forward(self, x):
+        return self.pool(x)
+
+    def fuse_model(self):
+        pass
+
+
 def test_quantized_modules():
     imagenet_ishape = (1, 3, 224, 224)
 
     qmodules = [
         ("relu", imagenet_ishape, ReLU(), False),
         ("upsample bilinear", (1, 3, 64, 64), UpsamplingBilinear(), False),
+        ("avgpool", imagenet_ishape, AvgPool2d(), False),
     ]
 
     for per_channel in [False, True]:
@@ -276,7 +288,6 @@ def test_quantized_modules():
                 pt_result = script_module(inp.clone()).numpy()
 
             input_name = get_graph_input_names(script_module)[0]
-
             runtime = get_tvm_runtime(script_module, input_name, ishape)
             runtime.set_input(input_name, inp.numpy().copy())
             runtime.run()