[Torch][Quantized] Fix converting serialized quantized models (apache#5839)

* [Torch] Fix converting serialized quantized models

* clean up dtype check

* comment clean up
masahi authored and Trevor Morris committed Jun 30, 2020
1 parent 96d4bd3 commit fa4ec78
Showing 2 changed files with 67 additions and 20 deletions.
42 changes: 25 additions & 17 deletions python/tvm/relay/frontend/pytorch.py
@@ -115,6 +115,14 @@ def inplace_add_to_add(op_name):
     return False


+def _is_quantized_tensor(data, prelude):
+    # If a quantized Torch module is saved and loaded back, dtype will be dropped
+    # Since dtypes from Torch tensors are not reliable in such cases, we use
+    # Relay's type inference result to decide if an input tensor is quantized
+    ty = _infer_type_with_prelude(data, prelude)
+    return ty.dtype == "uint8"
+
+
 # operator implementation
 def _elemwise(name):
     def _impl(inputs, input_types):
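The failure mode this helper works around is easy to reproduce: a freshly traced quantized module annotates its graph values with quantized dtypes, but after a `torch.jit.save`/`torch.jit.load` round trip those annotations can degrade to plain `Tensor`, so `input_types[0] == "quint8"` checks stop firing. A minimal sketch of the round trip (the module and shapes here are illustrative, not from this commit):

```python
import torch
from torch.quantization import (QuantStub, DeQuantStub, get_default_qconfig,
                                prepare, convert)

class TinyQuantModel(torch.nn.Module):
    """Hypothetical repro module: quantize -> relu -> dequantize."""
    def __init__(self):
        super().__init__()
        self.quant, self.dequant = QuantStub(), DeQuantStub()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.dequant(self.relu(self.quant(x)))

model = TinyQuantModel().eval()
model.qconfig = get_default_qconfig("fbgemm")
prepare(model, inplace=True)
model(torch.rand(1, 3, 8, 8))   # calibration pass for the observers
convert(model, inplace=True)

traced = torch.jit.trace(model, torch.rand(1, 3, 8, 8))
torch.jit.save(traced, "tmp.pt")
loaded = torch.jit.load("tmp.pt")
# Compare the value types around aten::relu in both graphs: the traced
# graph carries QUInt8 annotations, the reloaded one may not.
print(traced.graph)
print(loaded.graph)
```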
@@ -530,10 +538,10 @@ def _impl(inputs, input_types):
     return _impl


-def _relu():
+def _relu(prelude):
     def _impl(inputs, input_types):
         data = inputs[0]
-        if input_types[0] == "quint8":
+        if _is_quantized_tensor(data, prelude):
             assert len(inputs) == 3, "Input quant param not found in op inputs"
             input_zero_point = _expr.const(inputs[2], dtype="int32")
             return qnn_torch.quantized_relu(data, input_zero_point)
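For context, `qnn_torch.quantized_relu` stays entirely in the quantized domain: since the zero point is the uint8 encoding of real 0.0, ReLU reduces to clamping at the zero point. Roughly (a sketch of the helper in `qnn_torch.py`, shown here only for orientation):

```python
from tvm.relay import op as _op

def quantized_relu(data, input_zero_point):
    # Values below the zero point encode negative reals, so
    # max(x, zero_point) in uint8 space is exactly ReLU.
    zp = _op.cast(input_zero_point, dtype="uint8")
    return _op.tensor.maximum(data, zp)
```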
@@ -595,15 +603,15 @@ def _impl(inputs, input_types):
         return _op.log(_op.tensor.sigmoid(data))
     return _impl

-def _adaptive_avg_pool_2d():
+def _adaptive_avg_pool_2d(prelude):
     def _impl(inputs, input_types):
         data = inputs[0]
         output_size = _infer_shape(inputs[1])

         def func(x):
             return _op.nn.adaptive_avg_pool2d(x, output_size=output_size)

-        if input_types[0] == "quint8":
+        if _is_quantized_tensor(data, prelude):
             return qnn_torch.apply_with_upcast(data, func)

         return func(data)
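`qnn_torch.apply_with_upcast` lets plain Relay pooling run on quantized data without dequantizing: averaging a uniformly quantized tensor preserves its scale and zero point, so only the integer values need widening. A sketch of what the helper does (assuming its current form in `qnn_torch.py`):

```python
from tvm.relay import op as _op

def apply_with_upcast(data, func):
    # Widen uint8 to int32 so the intermediate sums inside the wrapped
    # op (here, adaptive_avg_pool2d) cannot overflow, then narrow back.
    inp = _op.cast(data, dtype="int32")
    out = func(inp)
    return _op.cast(out, "uint8")
```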
@@ -1108,7 +1116,7 @@ def _impl(inputs, input_types):
         return _op.log(_op.exp(inputs[0] * beta) + _expr.const(1.)) / beta
     return _impl

-def _avg_pool2d():
+def _avg_pool2d(prelude):
     def _impl(inputs, input_types):
         data = inputs[0]

@@ -1130,7 +1138,7 @@ def func(x):
                                      ceil_mode=ceil_mode,
                                      count_include_pad=count_include_pad)

-        if input_types[0] == "quint8":
+        if _is_quantized_tensor(data, prelude):
             return qnn_torch.apply_with_upcast(data, func)

         return func(data)
@@ -1254,7 +1262,7 @@ def _impl(inputs, input_types):

     return _impl

-def _mean():
+def _mean(prelude):
     def _impl(inputs, input_types):
         data = inputs[0]

@@ -1274,7 +1282,7 @@ def _impl(inputs, input_types):
         def func(x):
             return _op.mean(x, axis, keepdims, exclude)

-        if input_types[0] == "quint8":
+        if _is_quantized_tensor(data, prelude):
             assert len(inputs) == 6, "Input quant param not found in op inputs"
             input_scale = _expr.const(inputs[4])
             input_zero_point = _expr.const(inputs[5])
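The hunk is folded just before the call, but the scale and zero point gathered here presumably feed `qnn_torch.quantized_mean`. One plausible shape for that helper, mirroring PyTorch's qreduction kernels (a sketch, not copied from `qnn_torch.py`):

```python
from tvm import relay

def quantized_mean(data, input_scale, input_zero_point, func_fp32):
    # Dequantize, reduce in float, then requantize with the unchanged
    # input quant params, since mean does not change the value range.
    dequantized = relay.qnn.op.dequantize(data, input_scale, input_zero_point)
    out = func_fp32(dequantized)
    return relay.qnn.op.quantize(out, input_scale, input_zero_point,
                                 out_dtype="uint8", axis=1)
```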
@@ -1492,7 +1500,7 @@ def _impl(inputs, input_types):

     return _impl

-def _upsample(method):
+def _upsample(method, prelude):
     def _impl(inputs, input_types):
         if isinstance(inputs[1], _expr.Var):
             out_size = _infer_shape(inputs[1])
@@ -1516,7 +1524,7 @@ def _impl(inputs, input_types):
         def func(x):
             return _op.image.resize(x, out_size, "NCHW", method, coord_trans)

-        if input_types[0] == "quint8":
+        if _is_quantized_tensor(data, prelude):
             import torch
             from packaging import version

@@ -1835,8 +1843,8 @@ def _get_convert_map(prelude):
         "aten::take" : _take(),
         "aten::where" : _where(),
         "aten::topk" : _topk(),
-        "aten::relu" : _relu(),
-        "aten::relu_" : _relu(),
+        "aten::relu" : _relu(prelude),
+        "aten::relu_" : _relu(prelude),
         "aten::prelu" : _prelu(),
         "aten::leaky_relu" : _leaky_relu(),
         "aten::elu" : _elu(),
@@ -1845,7 +1853,7 @@ def _get_convert_map(prelude):
         "aten::gelu" : _gelu(),
         "aten::selu" : _selu(),
         "aten::log_sigmoid" : _log_sigmoid(),
-        "aten::adaptive_avg_pool2d" : _adaptive_avg_pool_2d(),
+        "aten::adaptive_avg_pool2d" : _adaptive_avg_pool_2d(prelude),
         "aten::adaptive_max_pool2d" : _adaptive_max_pool_2d(),
         "aten::max_pool2d" : _maxpool_2d(),
         "aten::max_pool2d_with_indices" : _maxpool_2d_with_indices(),
@@ -1874,13 +1882,13 @@ def _get_convert_map(prelude):
         "aten::log_softmax" : _log_softmax(),
         "aten::sigmoid" : _sigmoid(),
         "aten::softplus" : _softplus(),
-        "aten::avg_pool2d" : _avg_pool2d(),
+        "aten::avg_pool2d" : _avg_pool2d(prelude),
         "aten::avg_pool3d" : _avg_pool3d(),
         "aten::dropout" : _dropout(),
         "aten::dropout_" : _dropout(),
         "aten::feature_dropout" : _dropout(),
         "aten::alpha_dropout" : _dropout(),
-        "aten::mean" : _mean(),
+        "aten::mean" : _mean(prelude),
         "aten::chunk" : _chunk(prelude),
         "aten::matmul" : _matmul(prelude),
         "aten::expand" : _expand(),
@@ -1932,8 +1940,8 @@ def _get_convert_map(prelude):
         "aten::isnan" : _unary("isnan"),
         "aten::clamp" : _clamp(),
         "aten::detach" : _identity(),
-        "aten::upsample_bilinear2d" : _upsample("bilinear"),
-        "aten::upsample_nearest2d" : _upsample("nearest_neighbor"),
+        "aten::upsample_bilinear2d" : _upsample("bilinear", prelude),
+        "aten::upsample_nearest2d" : _upsample("nearest_neighbor", prelude),
         "aten::upsample_trilinear3d" : _upsample3d("trilinear"),
         "aten::upsample_nearest3d" : _upsample3d("nearest_neighbor"),
         "aten::expand_as" : _expand_as(),
45 changes: 42 additions & 3 deletions tests/python/frontend/pytorch/qnn_test.py
@@ -63,7 +63,7 @@ def get_qconfig(per_channel):
                        weight=default_weight_observer)


-def quantize_model(model, inp, per_channel=False, dummy=True):
+def quantize_model(model, inp, per_channel=False):
     model.fuse_model()
     model.qconfig = get_qconfig(per_channel)
     torch.quantization.prepare(model, inplace=True)
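With the `dummy` flag gone, the helper takes the standard eager-mode post-training quantization path on the provided input; the folded remainder of the function presumably finishes it along these lines (a sketch, using the file's `get_qconfig` helper, not copied from the file):

```python
import torch

def quantize_model(model, inp, per_channel=False):
    model.fuse_model()
    model.qconfig = get_qconfig(per_channel)          # helper defined above
    torch.quantization.prepare(model, inplace=True)   # insert observers
    model(inp)                                        # calibration pass on real data
    torch.quantization.convert(model, inplace=True)   # swap in quantized ops
```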
@@ -243,6 +243,18 @@ def fuse_model(self):
         pass


+class AdaptiveAvgPool2d(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.pool = QuantWrapper(nn.AdaptiveAvgPool2d((1, 1)))
+
+    def forward(self, x):
+        return self.pool(x)
+
+    def fuse_model(self):
+        pass
+
+
 def test_quantized_modules():
     imagenet_ishape = (1, 3, 224, 224)

@@ -280,7 +292,7 @@ def test_quantized_modules():
         raw_module.eval()
         inp = torch.rand(ishape)

-        quantize_model(raw_module, inp, per_channel=per_channel, dummy=True)
+        quantize_model(raw_module, inp, per_channel=per_channel)
         script_module = torch.jit.trace(raw_module, inp).eval()

         with torch.no_grad():
@@ -376,7 +388,7 @@ def get_imagenet_input():
     inp = get_imagenet_input()
     pt_inp = torch.from_numpy(inp)

-    quantize_model(raw_model, pt_inp, per_channel=per_channel, dummy=False)
+    quantize_model(raw_model, pt_inp, per_channel=per_channel)
     script_module = torch.jit.trace(raw_model, pt_inp).eval()

     with torch.no_grad():
@@ -465,3 +477,30 @@ def get_imagenet_input():
     mean abs_diff: 0.054197952
     558 in 1000 raw outputs identical.
     """
+
+
+def test_serialized_modules():
+    ishape = (1, 16, 64, 64)
+    raw_module = AdaptiveAvgPool2d().eval()
+    inp = torch.rand(ishape)
+
+    quantize_model(raw_module, inp)
+    script_module = torch.jit.trace(raw_module, inp).eval()
+
+    fname = "tmp.pt"
+    torch.jit.save(script_module, fname)
+    loaded = torch.jit.load(fname)
+    os.remove(fname)
+
+    with torch.no_grad():
+        pt_result = loaded(inp.clone()).numpy()
+
+    input_name = "input"
+    runtime = get_tvm_runtime(loaded, input_name, ishape)
+    runtime.set_input(input_name, inp.numpy().copy())
+    runtime.run()
+    tvm_result = runtime.get_output(0).asnumpy()
+
+    num_identical = np.sum(tvm_result == pt_result)
+    match_ratio = num_identical / float(np.prod(tvm_result.shape))
+    assert match_ratio > 0.2
