Skip to content

Commit

Permalink
[Torch, QNN] Add missing upcast to uint8 avg_pool conversion (apache#…
Browse files Browse the repository at this point in the history
…5089)

* add missing upcast to avgpool

* add avg pool test
  • Loading branch information
masahi authored and Trevor Morris committed Apr 16, 2020
1 parent 29932d8 commit 6f253b3
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 12 deletions.
22 changes: 15 additions & 7 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def func(x):
return _op.nn.adaptive_avg_pool2d(x, output_size=output_size)

if input_types[0] == "quint8":
return qnn_torch.quantized_adaptive_avg_2d(data, func)
return qnn_torch.apply_with_upcast(data, func)

return func(data)

Expand Down Expand Up @@ -484,14 +484,22 @@ def _impl(inputs, input_types):
ceil_mode = int(inputs[4])
count_include_pad = int(inputs[5])

return _op.nn.avg_pool2d(data,
pool_size=pool_size,
strides=strides,
padding=padding,
ceil_mode=ceil_mode,
count_include_pad=count_include_pad)
def func(x):
return _op.nn.avg_pool2d(x,
pool_size=pool_size,
strides=strides,
padding=padding,
ceil_mode=ceil_mode,
count_include_pad=count_include_pad)

if input_types[0] == "quint8":
return qnn_torch.apply_with_upcast(data, func)

return func(data)

return _impl


def _dropout():
def _impl(inputs, input_types):
data = inputs[0]
Expand Down
5 changes: 2 additions & 3 deletions python/tvm/relay/frontend/qnn_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,10 +359,9 @@ def add_quant_params(params, quant_params):
params[qparam.bias_var.name_hint] = tvm.nd.array(qparam.bias)


def quantized_adaptive_avg_2d(data, func_fp32):
# this follows tflite impl
def apply_with_upcast(data, func):
inp = _op.cast(data, dtype="int32")
out = func_fp32(inp)
out = func(inp)
return _op.cast(out, "uint8")


Expand Down
15 changes: 13 additions & 2 deletions tests/python/frontend/pytorch/qnn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,6 @@ def fuse_model(self):
class UpsamplingBilinear(nn.Module):
def __init__(self):
super().__init__()
self.relu = QuantWrapper(nn.ReLU())
self.quant = QuantStub()
self.dequant = DeQuantStub()

Expand All @@ -233,12 +232,25 @@ def fuse_model(self):
pass


class AvgPool2d(nn.Module):
def __init__(self):
super().__init__()
self.pool = QuantWrapper(nn.AvgPool2d(kernel_size=2))

def forward(self, x):
return self.pool(x)

def fuse_model(self):
pass


def test_quantized_modules():
imagenet_ishape = (1, 3, 224, 224)

qmodules = [
("relu", imagenet_ishape, ReLU(), False),
("upsample bilinear", (1, 3, 64, 64), UpsamplingBilinear(), False),
("avgpool", imagenet_ishape, AvgPool2d(), False),
]

for per_channel in [False, True]:
Expand Down Expand Up @@ -276,7 +288,6 @@ def test_quantized_modules():
pt_result = script_module(inp.clone()).numpy()

input_name = get_graph_input_names(script_module)[0]

runtime = get_tvm_runtime(script_module, input_name, ishape)
runtime.set_input(input_name, inp.numpy().copy())
runtime.run()
Expand Down

0 comments on commit 6f253b3

Please sign in to comment.