Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Range analysis improvements and better input shape override #97

Merged
merged 12 commits into from
Feb 5, 2024
Merged
12 changes: 6 additions & 6 deletions src/qonnx/transformation/extract_conv_bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,29 +27,29 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import warnings
from onnx import TensorProto, helper
from onnx import helper

from qonnx.transformation.base import Transformation


class ExtractBiasFromConv(Transformation):
"""
Extracts the (optional) Bias from a Conv node and inserts it behind the
Conv node as an Add node.
Extracts the (optional) Bias from a Conv(Transpose) node and inserts it behind the
Conv(Transpose) node as an Add node.
"""

def apply(self, model):
graph = model.graph
node_ind = 0
for n in graph.node:
node_ind += 1
if n.op_type == "Conv":
if n.op_type in ["Conv", "ConvTranspose"]:
# Check if the node has a bias input
if len(n.input) > 2:
# Extract bias
bias = model.get_initializer(n.input[2])
if bias is None:
warnings.warn(f"Could not extract bias from Conv node {n}")
warnings.warn(f"Could not extract bias from node {n}")
continue

# Insert bias as Add node behind the Conv node
Expand All @@ -65,7 +65,7 @@ def apply(self, model):

act_add_tensor = helper.make_tensor_value_info(
model.make_new_valueinfo_name(),
TensorProto.FLOAT,
model.get_tensor_valueinfo(n.output[0]).type.tensor_type.elem_type,
out_shape,
)
graph.value_info.append(act_add_tensor)
Expand Down
26 changes: 18 additions & 8 deletions src/qonnx/util/cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from qonnx.transformation.quant_constant_folding import FoldTransposeIntoQuantInit


def cleanup_model(model, preserve_qnt_ops=True, override_batchsize=None, extract_conv_bias=False):
def cleanup_model(model, preserve_qnt_ops=True, override_inpsize=None, extract_conv_bias=False):
"""Execute the transformations for the cleanup function on a model level.
This allows the reuse of the cleanup transformations, without needing to read/write the model from/to disk.

Expand All @@ -61,6 +61,19 @@ def cleanup_model(model, preserve_qnt_ops=True, override_batchsize=None, extract
preserve_qnt_optypes = ["Quant", "BipolarQuant", "QuantizeLinear", "DequantizeLinear"]
else:
preserve_qnt_optypes = []

if override_inpsize is not None:
if type(override_inpsize) is str:
override_inpsize = eval(override_inpsize)
if type(override_inpsize) is int:
override_batchsize = override_inpsize
model = model.transform(ChangeBatchSize(override_batchsize))
elif type(override_inpsize) is tuple:
override_batchsize = override_inpsize[0]
model = model.transform(ChangeBatchSize(override_batchsize))
iname = model.graph.input[0].name
model.set_tensor_shape(iname, override_inpsize)

cleanup_transformations = [
InferShapes(),
GiveUniqueParameterTensors(),
Expand All @@ -80,27 +93,24 @@ def cleanup_model(model, preserve_qnt_ops=True, override_batchsize=None, extract
model = model.transform(GiveUniqueNodeNames())
model = model.transform(GiveReadableTensorNames())

if override_batchsize is not None:
model = model.transform(ChangeBatchSize(override_batchsize))
model = model.transform(InferShapes())

return model


def cleanup(in_file, *, out_file=None, preserve_qnt_ops=True, override_batchsize: int = None, extract_conv_bias=False):
def cleanup(in_file, *, out_file=None, preserve_qnt_ops=True, override_inpsize: str = None, extract_conv_bias=False):
"""Execute a set of graph transformations to clean-up the given ONNX file.

:param in_file: Filename for the input ONNX model
:param preserve_qnt_ops: Preserve weight quantization operators
:param out_file: If set, filename for the output ONNX model. Set to in_file with _clean
suffix otherwise.
:param override_batchsize: If specified, override the batch size for the ONNX graph
:param override_inpsize: If specified, override the input size (e.g. "(1,3,224,224)" to set all or
just 1 to set batchsize to 1) for the ONNX graph
:param extract_conv_bias: If specified, separate Conv bias into its own Add node
"""

model = ModelWrapper(in_file)
model = cleanup_model(
model, preserve_qnt_ops=preserve_qnt_ops, override_batchsize=override_batchsize, extract_conv_bias=extract_conv_bias
model, preserve_qnt_ops=preserve_qnt_ops, override_inpsize=override_inpsize, extract_conv_bias=extract_conv_bias
)
if out_file is None:
out_file = in_file.replace(".onnx", "_clean.onnx")
Expand Down
98 changes: 72 additions & 26 deletions src/qonnx/util/range_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,6 @@ def calculate_matvec_accumulator_extremum(matrix: np.ndarray, vec_min, vec_max):
return (min_values, max_values)


def propagate_range(node, model, range_dict):
iname = node.input[0]
node_irange = range_dict[iname]
for oname in node.output:
range_dict[oname] = node_irange


def calc_gemm_range(node, model, range_dict):
alpha = get_by_name(node.attribute, "alpha").f
beta = get_by_name(node.attribute, "beta").f
Expand Down Expand Up @@ -172,10 +165,49 @@ def calc_conv_range(node, model, range_dict):
range_dict[oname] = ret


def calc_convtranspose_range(node, model, range_dict):
iname = node.input[0]
wname = node.input[1]
assert len(node.input) == 2, "Found unsupported ConvTranspose with bias"
oname = node.output[0]
irange = range_dict[iname]
imin, imax = irange
weights = model.get_initializer(wname)
assert weights is not None, "Uninitialized ConvTranspose weights"
groups = get_by_name(node.attribute, "group")
if groups is None:
# default to dense convs
groups = 1
else:
groups = groups.i
assert groups == 1, "Only dense (non-grouped) ConvTranspose is supported"
# do weight reshaping to treat Conv similar to MatMul
# (mh, mw) = (ofm, (ifm x k0 x k1 x ...))
conv_ofm = weights.shape[1]
conv_ifm = weights.shape[0]
weights = weights.transpose(1, 0, 2, 3).reshape(conv_ofm, -1)
k_total = weights.shape[1] // conv_ifm
if type(imin) is np.ndarray:
imin_rep = np.repeat(imin, k_total)
imax_rep = np.repeat(imax, k_total)
else:
imin_rep = imin
imax_rep = imax
dw_ret_min = []
dw_ret_max = []
for i in range(conv_ofm):
w_slice = weights[i, :].reshape(1, -1)
dw_ret = calculate_matvec_accumulator_extremum(w_slice, imin_rep, imax_rep)
dw_ret_min.append(dw_ret[0].item())
dw_ret_max.append(dw_ret[1].item())
ret = (np.asarray(dw_ret_min), np.asarray(dw_ret_max))
range_dict[oname] = ret


def get_minmax_prototype_tensors(irange, ishp, inp_vi, i_channel_axis=1):
proto_min = valueinfo_to_tensor(inp_vi)
proto_max = valueinfo_to_tensor(inp_vi)
if type(irange[0]) in [float, int, np.float32, np.float64, np.uint8, np.int8]:
if type(irange[0]) in [float, int, np.float16, np.float32, np.float64, np.uint8, np.int8]:
imin, imax = irange
proto_min[...] = imin
proto_max[...] = imax
Expand Down Expand Up @@ -211,25 +243,34 @@ def calc_monotonic_range(node, model, range_dict, i_channel_axis=1):
inp_vi = model.get_tensor_valueinfo(inp)
proto_vectors.append(get_minmax_prototype_tensors(irange, ishp, inp_vi, i_channel_axis))
# process all combinations of prototype vectors for dynamic inputs
running_min = None
running_max = None
running_min = [None for i in range(len(node.output))]
running_max = [None for i in range(len(node.output))]
# create context for single-node execution
ctx = {x: model.get_initializer(x) for x in node.input}
ctx[oname] = valueinfo_to_tensor(model.get_tensor_valueinfo(oname))
for oname in node.output:
ctx[oname] = valueinfo_to_tensor(model.get_tensor_valueinfo(oname))
# assume all outputs are homogenous wrt data layout (e.g. channel axis
# always lives in the same position)
axes_to_min = [i for i in range(ctx[oname].ndim)]
axes_to_min.remove(i_channel_axis)
axes_to_min = tuple(axes_to_min)
for inps in itertools.product(*proto_vectors):
for i in range(n_dyn_inp):
ctx[dyn_inps[i]] = inps[i]
execute_node(node, ctx, model.graph, opset_version=opset_version)
# grab new output and update running min/max
out = ctx[oname]
chanwise_min = out.min(axis=axes_to_min).flatten()
chanwise_max = out.max(axis=axes_to_min).flatten()
running_min = np.minimum(chanwise_min, running_min).flatten() if running_min is not None else chanwise_min
running_max = np.maximum(chanwise_max, running_max).flatten() if running_max is not None else chanwise_max
range_dict[oname] = (running_min, running_max)
for oind, oname in enumerate(node.output):
# grab new output and update running min/max
out = ctx[oname]
chanwise_min = out.min(axis=axes_to_min).flatten()
chanwise_max = out.max(axis=axes_to_min).flatten()
running_min[oind] = (
np.minimum(chanwise_min, running_min[oind]).flatten() if running_min[oind] is not None else chanwise_min
)
running_max[oind] = (
np.maximum(chanwise_max, running_max[oind]).flatten() if running_max[oind] is not None else chanwise_max
)
for oind, oname in enumerate(node.output):
range_dict[oname] = (running_min[oind], running_max[oind])


def calc_range_outdtype(node, model, range_dict):
Expand All @@ -240,12 +281,13 @@ def calc_range_outdtype(node, model, range_dict):


optype_to_range_calc = {
"Transpose": propagate_range,
"Transpose": calc_monotonic_range,
"MatMul": calc_matmul_range,
"Conv": calc_conv_range,
"ConvTranspose": calc_convtranspose_range,
"QuantMaxNorm": calc_range_outdtype,
"Flatten": propagate_range,
"Reshape": propagate_range,
"Flatten": calc_monotonic_range,
"Reshape": calc_monotonic_range,
"Quant": calc_monotonic_range,
"BipolarQuant": calc_monotonic_range,
"Mul": calc_monotonic_range,
Expand All @@ -254,7 +296,7 @@ def calc_range_outdtype(node, model, range_dict):
"Add": calc_monotonic_range,
"BatchNormalization": calc_monotonic_range,
"Relu": calc_monotonic_range,
"Pad": propagate_range,
"Pad": calc_monotonic_range,
"AveragePool": calc_monotonic_range,
"Trunc": calc_range_outdtype,
"MaxPool": calc_monotonic_range,
Expand All @@ -267,6 +309,7 @@ def calc_range_outdtype(node, model, range_dict):
"Clip": calc_monotonic_range,
"Sigmoid": calc_monotonic_range,
"Concat": calc_monotonic_range,
"Split": calc_monotonic_range,
}


Expand Down Expand Up @@ -320,8 +363,12 @@ def range_analysis(
range_min = None
range_max = None
else:
irange = irange.split(",")
range_min, range_max = float(irange[0]), float(irange[1])
irange = eval(irange)
range_min, range_max = irange
if isinstance(range_min, list):
range_min = np.asarray(range_min, dtype=np.float32)
if isinstance(range_max, list):
range_max = np.asarray(range_max, dtype=np.float32)
elif isinstance(irange, tuple):
range_min, range_max = irange
else:
Expand Down Expand Up @@ -350,9 +397,8 @@ def range_analysis(
for node in model.graph.node:
dyn_inputs = [x for x in node.input if is_dyn_input(x, model)]
inprange_ok = all([x in range_dict.keys() for x in dyn_inputs])
outcount_ok = len(node.output) == 1
op_ok = node.op_type in optype_to_range_calc.keys()
if inprange_ok and op_ok and outcount_ok:
if inprange_ok and op_ok:
range_calc_fxn = optype_to_range_calc[node.op_type]
range_calc_fxn(node, model, range_dict)
out_range = range_dict[node.output[0]]
Expand Down
Loading