Skip to content

Commit

Permalink
[REDO AFTER GH BUG] Add support for quantized models via QNN (apache#…
Browse files Browse the repository at this point in the history
…5016)

This reverts commit f346c60.
  • Loading branch information
masahi authored and zhiics committed Apr 17, 2020
1 parent 04d9310 commit 4a0a4ab
Show file tree
Hide file tree
Showing 4 changed files with 1,232 additions and 9 deletions.
88 changes: 79 additions & 9 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# pylint: disable=import-outside-toplevel, simplifiable-if-expression, unnecessary-comprehension
"""PT: PyTorch frontend."""
import itertools
import logging

import numpy as np

Expand All @@ -32,6 +33,8 @@
from .common import infer_shape as _infer_shape
from .common import infer_value as _infer_value

from . import qnn_torch

__all__ = ["from_pytorch"]

# operator implementation
Expand Down Expand Up @@ -146,6 +149,10 @@ def _impl(inputs, input_types):
def _relu():
def _impl(inputs, input_types):
data = inputs[0]
if input_types[0] == "quint8":
assert len(inputs) == 3, "Input quant param not found in op inputs"
input_zero_point = _expr.const(inputs[2], dtype="int32")
return qnn_torch.quantized_relu(data, input_zero_point)
return _op.nn.relu(data)
return _impl

Expand All @@ -154,9 +161,14 @@ def _impl(inputs, input_types):
data = inputs[0]
output_size = _infer_shape(inputs[1])

return _op.nn.adaptive_avg_pool2d(
data,
output_size=output_size)
def func(x):
return _op.nn.adaptive_avg_pool2d(x, output_size=output_size)

if input_types[0] == "quint8":
return qnn_torch.quantized_adaptive_avg_2d(data, func)

return func(data)

return _impl

def _adaptive_max_2d():
Expand Down Expand Up @@ -506,7 +518,18 @@ def _impl(inputs, input_types):
else:
exclude = False

return _op.mean(data, axis, keepdims, exclude)
def func(x):
return _op.mean(x, axis, keepdims, exclude)

if input_types[0] == "quint8":
assert len(inputs) == 6, "Input quant param not found in op inputs"
input_scale = _expr.const(inputs[4])
input_zero_point = _expr.const(inputs[5])
return qnn_torch.quantized_mean(data, input_scale,
input_zero_point, func)

return func(data)

return _impl

def _chunk():
Expand Down Expand Up @@ -668,10 +691,40 @@ def _impl(inputs, input_types):
else:
coord_trans = "half_pixel"

return _op.image.resize(data, out_size, "NCHW", method, coord_trans)
def func(x):
return _op.image.resize(x, out_size, "NCHW", method, coord_trans)

if input_types[0] == "quint8":
import torch
from packaging import version

# Torch version > 1.4 changed upsampling API
if version.parse(torch.__version__) > version.parse("1.4.0"):
num_inputs = 7
else:
num_inputs = 5

assert len(inputs) == num_inputs, "Input quant param not found in op inputs"

input_scale = _expr.const(inputs[-2])
input_zero_point = _expr.const(inputs[-1])
return qnn_torch.quantized_upsample(data, input_scale,
input_zero_point, func)
return func(data)

return _impl


def _expand_as():
def _impl(inputs, input_types):
# TODO: maybe fix this
# This assumes expand_as can be removed because TVM has broadcast op
msg = "aten::expand_as(...) found, assume it is part of broadcast op"
logging.warning(msg)
return inputs[0]
return _impl


# Helper functions for operator implementation

def _convert_data_type(input_type):
Expand Down Expand Up @@ -792,6 +845,7 @@ def _convert_elemwise_input(data, input_type):
"aten::detach" : _identity(),
"aten::upsample_bilinear2d" : _upsample("bilinear"),
"aten::upsample_nearest2d" : _upsample("nearest_neighbor"),
"aten::expand_as" : _expand_as()
}


Expand Down Expand Up @@ -842,6 +896,7 @@ def _report_missing_conversion(op_names):
"prim::ListConstruct", "prim::ListUnpack",
"prim::TupleConstruct", "prim::TupleUnpack"]
known_ops += list(_convert_map.keys())
known_ops += list(qnn_torch.convert_map.keys())

missing = [op_name for op_name in op_names
if op_name not in known_ops]
Expand Down Expand Up @@ -1008,6 +1063,7 @@ def parse_params(graph, state_dict):
getattr_nodes = graph.findAllNodes("prim::GetAttr", recurse=True)
params = {}
param_tensors = {}
packed_param_map = {}
seen = set()

for node in getattr_nodes:
Expand All @@ -1020,14 +1076,18 @@ def parse_params(graph, state_dict):
full_attr = _getattr_full_name(getattrs)
full_attr_node_name = _get_output_name(getattrs[-1])

if full_attr in state_dict:
if full_attr.endswith("_packed_params"): # for quantized models
err_msg = "parameter %s not found in state dict" % full_attr
assert full_attr in state_dict, err_msg
packed_param_map[full_attr_node_name] = full_attr
elif full_attr in state_dict:
torch_tensor = state_dict[full_attr]
tensor, var = _get_tensor_and_var(torch_tensor,
full_attr_node_name)
param_tensors[full_attr_node_name] = tensor
params[full_attr_node_name] = var

return params, param_tensors
return params, param_tensors, packed_param_map


def parse_operators(operators, outputs, output_index_map, ret_name):
Expand Down Expand Up @@ -1108,16 +1168,26 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):

params = script_module.state_dict()
input_vars = parse_inputs(graph.inputs(), input_shapes)
param_vars, tensors = parse_params(graph, params)
param_vars, tensors, packed_param_map = parse_params(graph, params)
tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}

input_vars.update(param_vars)
outputs = list(input_vars.values())
output_index_map = dict(zip(input_vars.keys(), range(len(outputs))))
ret_name = _get_input_names(graph.return_node())[0]

# For quantized models
if "aten::quantize_per_tensor" in op_names:
weight_quant_params = qnn_torch.get_weight_quant_params(script_module)
qnn_torch.add_input_quant_params_to_op_inputs(graph)
qnn_torch.add_quant_params_to_outputs(outputs, output_index_map,
packed_param_map,
weight_quant_params)
qnn_torch.add_quant_params(tvm_params, weight_quant_params)
_convert_map.update(qnn_torch.convert_map)

body = parse_operators(_get_operator_nodes(graph.nodes()), outputs,
output_index_map, ret_name)
func = tvm.relay.Function(_analysis.free_vars(body), body)
tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}

return _module.IRModule.from_expr(func), tvm_params
Loading

0 comments on commit 4a0a4ab

Please sign in to comment.