Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REDO AFTER GH BUG] Add support for quantized models via QNN #5016

Merged
merged 1 commit into from
Mar 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 79 additions & 9 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# pylint: disable=import-outside-toplevel, simplifiable-if-expression, unnecessary-comprehension
"""PT: PyTorch frontend."""
import itertools
import logging

import numpy as np

Expand All @@ -32,6 +33,8 @@
from .common import infer_shape as _infer_shape
from .common import infer_value as _infer_value

from . import qnn_torch

__all__ = ["from_pytorch"]

# operator implementation
Expand Down Expand Up @@ -146,6 +149,10 @@ def _impl(inputs, input_types):
def _relu():
    """Create the converter for ReLU.

    Returns
    -------
    _impl : callable
        ``_impl(inputs, input_types)``. When ``input_types[0]`` is
        ``"quint8"`` the op is lowered through
        ``qnn_torch.quantized_relu``; otherwise ``_op.nn.relu`` is applied
        to ``inputs[0]``.
    """
    def _impl(inputs, input_types):
        tensor = inputs[0]

        # Non-quantized input: plain relu, nothing else to look at.
        if input_types[0] != "quint8":
            return _op.nn.relu(tensor)

        # Quantized input: the op is expected to carry exactly three inputs,
        # and the last one is used as the input zero point.
        # NOTE(review): inputs[1] is presumably the input scale appended by
        # the qnn_torch graph pass - confirm against qnn_torch.
        assert len(inputs) == 3, "Input quant param not found in op inputs"
        zero_point = _expr.const(inputs[2], dtype="int32")
        return qnn_torch.quantized_relu(tensor, zero_point)

    return _impl

Expand All @@ -154,9 +161,14 @@ def _impl(inputs, input_types):
data = inputs[0]
output_size = _infer_shape(inputs[1])

return _op.nn.adaptive_avg_pool2d(
data,
output_size=output_size)
def func(x):
return _op.nn.adaptive_avg_pool2d(x, output_size=output_size)

if input_types[0] == "quint8":
return qnn_torch.quantized_adaptive_avg_2d(data, func)

return func(data)

return _impl

def _adaptive_max_2d():
Expand Down Expand Up @@ -506,7 +518,18 @@ def _impl(inputs, input_types):
else:
exclude = False

return _op.mean(data, axis, keepdims, exclude)
def func(x):
return _op.mean(x, axis, keepdims, exclude)

if input_types[0] == "quint8":
assert len(inputs) == 6, "Input quant param not found in op inputs"
input_scale = _expr.const(inputs[4])
input_zero_point = _expr.const(inputs[5])
return qnn_torch.quantized_mean(data, input_scale,
input_zero_point, func)

return func(data)

return _impl

def _chunk():
Expand Down Expand Up @@ -668,10 +691,40 @@ def _impl(inputs, input_types):
else:
coord_trans = "half_pixel"

return _op.image.resize(data, out_size, "NCHW", method, coord_trans)
def func(x):
return _op.image.resize(x, out_size, "NCHW", method, coord_trans)

if input_types[0] == "quint8":
import torch
from packaging import version

# Torch version > 1.4 changed upsampling API
if version.parse(torch.__version__) > version.parse("1.4.0"):
num_inputs = 7
else:
num_inputs = 5

assert len(inputs) == num_inputs, "Input quant param not found in op inputs"

input_scale = _expr.const(inputs[-2])
input_zero_point = _expr.const(inputs[-1])
return qnn_torch.quantized_upsample(data, input_scale,
input_zero_point, func)
return func(data)

return _impl


def _expand_as():
    """Create the converter for ``aten::expand_as``.

    Returns
    -------
    _impl : callable
        ``_impl(inputs, input_types)``. It emits no op: it logs a warning
        and returns ``inputs[0]`` unchanged.
    """
    def _impl(inputs, input_types):
        # TODO: maybe fix this
        # The expand is dropped on the assumption that TVM's broadcast op
        # makes an explicit expand_as unnecessary.
        logging.warning(
            "aten::expand_as(...) found, assume it is part of broadcast op")
        return inputs[0]

    return _impl


# Helper functions for operator implementation

def _convert_data_type(input_type):
Expand Down Expand Up @@ -792,6 +845,7 @@ def _convert_elemwise_input(data, input_type):
"aten::detach" : _identity(),
"aten::upsample_bilinear2d" : _upsample("bilinear"),
"aten::upsample_nearest2d" : _upsample("nearest_neighbor"),
"aten::expand_as" : _expand_as()
}


Expand Down Expand Up @@ -842,6 +896,7 @@ def _report_missing_conversion(op_names):
"prim::ListConstruct", "prim::ListUnpack",
"prim::TupleConstruct", "prim::TupleUnpack"]
known_ops += list(_convert_map.keys())
known_ops += list(qnn_torch.convert_map.keys())

missing = [op_name for op_name in op_names
if op_name not in known_ops]
Expand Down Expand Up @@ -1008,6 +1063,7 @@ def parse_params(graph, state_dict):
getattr_nodes = graph.findAllNodes("prim::GetAttr", recurse=True)
params = {}
param_tensors = {}
packed_param_map = {}
seen = set()

for node in getattr_nodes:
Expand All @@ -1020,14 +1076,18 @@ def parse_params(graph, state_dict):
full_attr = _getattr_full_name(getattrs)
full_attr_node_name = _get_output_name(getattrs[-1])

if full_attr in state_dict:
if full_attr.endswith("_packed_params"): # for quantized models
err_msg = "parameter %s not found in state dict" % full_attr
assert full_attr in state_dict, err_msg
packed_param_map[full_attr_node_name] = full_attr
elif full_attr in state_dict:
torch_tensor = state_dict[full_attr]
tensor, var = _get_tensor_and_var(torch_tensor,
full_attr_node_name)
param_tensors[full_attr_node_name] = tensor
params[full_attr_node_name] = var

return params, param_tensors
return params, param_tensors, packed_param_map


def parse_operators(operators, outputs, output_index_map, ret_name):
Expand Down Expand Up @@ -1108,16 +1168,26 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):

params = script_module.state_dict()
input_vars = parse_inputs(graph.inputs(), input_shapes)
param_vars, tensors = parse_params(graph, params)
param_vars, tensors, packed_param_map = parse_params(graph, params)
tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}

input_vars.update(param_vars)
outputs = list(input_vars.values())
output_index_map = dict(zip(input_vars.keys(), range(len(outputs))))
ret_name = _get_input_names(graph.return_node())[0]

# For quantized models
if "aten::quantize_per_tensor" in op_names:
weight_quant_params = qnn_torch.get_weight_quant_params(script_module)
qnn_torch.add_input_quant_params_to_op_inputs(graph)
qnn_torch.add_quant_params_to_outputs(outputs, output_index_map,
packed_param_map,
weight_quant_params)
qnn_torch.add_quant_params(tvm_params, weight_quant_params)
_convert_map.update(qnn_torch.convert_map)

body = parse_operators(_get_operator_nodes(graph.nodes()), outputs,
output_index_map, ret_name)
func = tvm.relay.Function(_analysis.free_vars(body), body)
tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}

return _module.IRModule.from_expr(func), tvm_params
Loading