Add additional rounding modes #110

Closed
20 changes: 19 additions & 1 deletion docs/qonnx-custom-ops/quant_op.md
@@ -21,7 +21,7 @@ This operator is not part of the ONNX standard and is not currently versioned.
<dt><tt>narrow</tt> : int (default is 0)</dt>
<dd>Defines if the value range should be interpreted as narrow, when signed=1. E.g. at 8b regular=[-128, 127] vs narrow=[-127, 127].</dd>
<dt><tt>rounding_mode</tt> : string (default is "ROUND")</dt>
<dd>Defines how rounding should be applied during quantization. Currently available modes are: "ROUND", "CEIL" and "FLOOR". Here "ROUND" implies a round-to-even operation. Lowercase variants for the rounding mode string are also supported: "round", "ceil", "floor".</dd>
<dd>Defines how rounding should be applied during quantization. Available modes are ROUND, CEIL, FLOOR, UP, DOWN, HALF_UP and HALF_DOWN. Here ROUND implies rounding half to even (also accepted as HALF_EVEN). The rounding modes are described in the table below; mode names may be given in upper or lower case.</dd>
</dl>

#### Inputs
@@ -46,6 +46,24 @@ This operator is not part of the ONNX standard and is not currently versioned.
</dl>


#### Rounding modes
<details>
<summary>rounding modes</summary>

| **Number \ ROUNDING_MODE** | ROUND=HALF_EVEN | CEIL | FLOOR | UP | DOWN | HALF_UP | HALF_DOWN |
|----------------------------|-----------------|------|-------|----|------|---------|-----------|
| 5.5 | 6 | 6 | 5 | 6 | 5 | 6 | 5 |
| 2.5 | 2 | 3 | 2 | 3 | 2 | 3 | 2 |
| 1.6 | 2 | 2 | 1 | 2 | 1 | 2 | 2 |
| 1.1 | 1 | 2 | 1 | 2 | 1 | 1 | 1 |
| 1.0 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |
| -1.0 | -1 | -1 | -1 | -1 | -1 | -1 | -1 |
| -1.1 | -1 | -1 | -2 | -2 | -1 | -1 | -1 |
| -1.6 | -2 | -1 | -2 | -2 | -1 | -2 | -2 |
| -2.5 | -2 | -2 | -3 | -3 | -2 | -3 | -2 |
| -5.5 | -6 | -5 | -6 | -6 | -5 | -6 | -5 |
</details>
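
As a quick illustration, the table can be reproduced with the rounding-mode resolver this PR adds in `src/qonnx/custom_op/general/quant.py` (a minimal sketch; see also the new test at the bottom of this diff):

```python
import numpy as np

from qonnx.custom_op.general.quant import resolve_rounding_mode

x = np.array([5.5, 2.5, 1.6, 1.1, 1.0, -1.0, -1.1, -1.6, -2.5, -5.5])
for mode in ["ROUND", "CEIL", "FLOOR", "UP", "DOWN", "HALF_UP", "HALF_DOWN"]:
    # e.g. ROUND (half-to-even) maps 2.5 -> 2, while HALF_UP maps 2.5 -> 3
    print(mode, resolve_rounding_mode(mode)(x))
```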

#### Examples
<details>
<summary>Quant</summary>
51 changes: 51 additions & 0 deletions src/qonnx/converters/keras.py
@@ -1,7 +1,10 @@
import numpy as np
import onnx
import tensorflow as tf
import tf2onnx
from collections import OrderedDict
from qkeras.qlayers import QActivation
from qkeras.quantizers import quantized_bits, quantized_relu
from qkeras.utils import REGISTERED_LAYERS as QKERAS_LAYERS

from qonnx.core.modelwrapper import ModelWrapper
@@ -164,6 +167,41 @@ def _convert_quantizers_to_nodes(onnx_model, quantizers_dict):
return onnx_model.model


def _add_input_quantizer(onnx_model, quantizer):
"Adds an input quantizer to the onnx_model"
iname = onnx_model.graph.input[0].name
scale_init_name = f"{iname}_init_scale"
zp_init_name = f"{iname}_init_zp"
bw_init_name = f"{iname}_init_bw"
onnx_model.set_initializer(scale_init_name, np.array(quantizer.scale))
onnx_model.set_initializer(zp_init_name, np.array(0.0))
onnx_model.set_initializer(bw_init_name, np.array(quantizer.bits))
if isinstance(quantizer, quantized_bits):
signed = quantizer.keep_negative
narrow = quantizer.symmetric
rounding_mode = "ROUND"
elif isinstance(quantizer, quantized_relu):
signed = False
narrow = False
rounding_mode = "HALF_EVEN"
else:
raise NotImplementedError
quant_node = onnx.helper.make_node(
op_type="Quant",
inputs=[iname, scale_init_name, zp_init_name, bw_init_name],
outputs=[f"{iname}_quantized"],
name=f"{iname}_Quant",
domain="qonnx.custom_op.general",
narrow=narrow,
rounding_mode=rounding_mode,
signed=signed,
)
for node in onnx_model.graph.node:
if node.input[0] == iname:
node.input[0] = quant_node.output[0]
onnx_model.graph.node.extend([quant_node])


def from_keras(
model,
name="qkeras_to_qonnx_converted",
@@ -230,6 +268,19 @@ def from_keras(
)

onnx_model = ModelWrapper(model_proto)

# Check if there is a quantizer at the input and add it to the proto;
# this isn't handled by the "qkeras_op_handlers".
for submod in model.submodules:
if (
isinstance(submod, (QActivation, tf.keras.layers.Activation))
and model.input.name == submod.input.name
and isinstance(submod.submodules[0], (quantized_bits, quantized_relu))
):
assert len(submod.submodules) == 1
_add_input_quantizer(onnx_model, submod.submodules[0])
break

# Set the first value of input/output shape to 1, currently this is set to unknown,
# because it is technically the batch size
if not (len(onnx_model.graph.input) == 1 and len(onnx_model.graph.output) == 1):
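For reference, a model that exercises the new input-quantizer path can be sketched as below; the toy architecture and quantizer strings are illustrative, not taken from this PR:

```python
import tensorflow as tf
from qkeras import QActivation, QDense

from qonnx.converters.keras import from_keras

# Toy model whose first layer is a quantized activation applied directly to the
# input; the submodule loop above turns it into a Quant node on the graph input.
inp = tf.keras.Input(shape=(16,))
x = QActivation("quantized_bits(8, 0, alpha=1)")(inp)
x = QDense(4, kernel_quantizer="quantized_bits(8, 0, alpha=1)")(x)
model = tf.keras.Model(inp, x)

converted = from_keras(model, name="toy_with_input_quant")
```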
78 changes: 68 additions & 10 deletions src/qonnx/converters/qkeras/onnx.py
@@ -1,10 +1,15 @@
import logging
import numpy as np
from tf2onnx.late_rewriters import channel_order_rewriters
from tf2onnx.onnx_opset.math import DirectOp, MatMul
from tf2onnx.onnx_opset.nn import BiasAdd, ConvOp

from qonnx.custom_op.general.quant import quant

from .quantizers import get_quant_params

logger = logging.getLogger(__name__)


def get_qkeras_onnx_handlers(all_quantizers):
"""Returns the handlers for each kind of layer
@@ -47,37 +52,85 @@ def _extract_node_name(onnx_node, keras_quantizers):
return None


def check_tensor_is_representable(tensor, quant_params, node):
"Gives a Warning iftensor is not representable with the providede quantization settings"
qtensor = quant(
inp_tensor=np.array(tensor),
scale=np.array(quant_params["inputs"]["scale"]),
zeropt=np.array(quant_params["inputs"]["zero_point"]),
bitwidth=np.array(quant_params["inputs"]["bit_width"]),
signed=quant_params["attributes"]["signed"],
narrow=quant_params["attributes"]["narrow"],
rounding_mode=quant_params["attributes"]["rounding_mode"],
)
if not np.array_equal(tensor, qtensor):
logger.warning(
f"Tensor of node: {node.name} is not representable with the provided quantization settings: {quant_params}"
)


def qlayer_handler(ctx, node, name, args):
all_quantizers = args[0]
keras_name = _extract_node_name(node, all_quantizers)
if not keras_name:
return # Not found in quantizers, nothing to do
quantizers = all_quantizers[keras_name]
if quantizers.get("kernel_quantizer"):

if quantizers.get("kernel_quantizer_cfg"):
weights = node.inputs[1].get_tensor_value(as_list=True)
quant_params = get_quant_params(weights, quantizers["kernel_quantizer"])
quant_params = get_quant_params(weights, quantizers["kernel_quantizer_cfg"])
check_tensor_is_representable(weights, quant_params, node)
attr = quant_params["attributes"]
input_nodes = [node.input[1]]
for key in quant_params["inputs"].keys():
name = f"{node.name}_kernel_quantizer_{key}"
np_val = np.asarray(quant_params["inputs"][key])
ctx.make_const(name, np_val)
input_nodes.append(name)
ctx.insert_new_node_on_input(
quant_node = ctx.insert_new_node_on_input(
node, "Quant", input_nodes, name=node.name + "_kernel_quantizer", **attr, domain="qonnx"
)

if quantizers.get("bias_quantizer") and len(node.input) == 3:
bias = node.inputs[2].get_tensor_value(as_list=True)
quant_params = get_quant_params(bias, quantizers["bias_quantizer"])
if quantizers["kernel_quantizer_cfg"]["class_name"] == "quantized_bits":
bits = quantizers["kernel_quantizer_cfg"]["config"]["bits"]
integer = quantizers["kernel_quantizer_cfg"]["config"]["integer"]
keep_negative = quantizers["kernel_quantizer_cfg"]["config"]["keep_negative"]
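# bits == integer + keep_negative means the quantizer has no fractional bits;
# in that case an explicit Mul re-applies the scale to the Quant node output.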
if bits == integer + keep_negative:
scale_node = ctx.make_const(
name=node.name + "_kernel_scale", np_val=quant_params["inputs"]["scale"].astype(np.float32)
)
ctx.insert_new_node_on_output(
op_type="Mul",
output_name=quant_node.output[0],
name=node.name + "_kernel_requantizer",
inputs=[quant_node.output[0], scale_node.name],
)

if quantizers.get("bias_quantizer_cfg") and len(node.input) == 3:
bias = node.inputs[-1].get_tensor_value(as_list=True)
quant_params = get_quant_params(bias, quantizers["bias_quantizer_cfg"])
check_tensor_is_representable(bias, quant_params, node)
attr = quant_params["attributes"]
input_nodes = [node.input[2]]
input_nodes = [node.input[-1]]
for key in quant_params["inputs"].keys():
name = f"{node.name}_bias_quantizer_{key}"
np_val = np.asarray(quant_params["inputs"][key])
ctx.make_const(name, np_val)
input_nodes.append(name)
quant_node = ctx.insert_new_node_on_input(node, "Quant", input_nodes, name=node.name + "_bias_quantizer", **attr, domain="qonnx")
if quantizers["bias_quantizer_cfg"]["class_name"] == "quantized_bits":
bits = quantizers["bias_quantizer_cfg"]["config"]["bits"]
integer = quantizers["bias_quantizer_cfg"]["config"]["integer"]
keep_negative = quantizers["bias_quantizer_cfg"]["config"]["keep_negative"]
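# Same no-fractional-bits check for the bias: re-apply the scale via a Mul.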
if bits == integer + keep_negative:
scale_node = ctx.make_const(
name=node.name + "_bias_scale", np_val=quant_params["inputs"]["scale"].astype(np.float32)
)
ctx.insert_new_node_on_output(
op_type="Mul",
output_name=quant_node.output[0],
name=node.name + "_bias_requantizer",
inputs=[quant_node.output[0], scale_node.name],
)

if quantizers.get("activation"):
dtypes = [ctx.get_dtype(node.output[0])]
@@ -109,6 +162,11 @@ def qact_handler(ctx, node, name, args):
quantizers = all_quantizers[keras_name]
if quantizers.get("activation"):
dtypes = [ctx.get_dtype(node.output[0])]
if "auto" in quantizers["activation"]:
if not node.graph.get_node_by_output(node.input[0]).is_const():
raise AttributeError(
f"Automatic quantizers (auto/auto_po2) must have a const input. Invalid topology at node: {name}."
)
quant_params = get_quant_params(None, quantizers["activation"])
attr = quant_params["attributes"]
input_nodes = [node.output[0]]
@@ -148,9 +206,9 @@ def bias_handler(ctx, node, name, args):
return # Not found in quantizers, nothing to do
quantizers = all_quantizers[keras_name]

if quantizers.get("bias_quantizer"):
if quantizers.get("bias_quantizer_cfg"):
bias = node.inputs[1].get_tensor_value(as_list=True)
quant_params = get_quant_params(bias, quantizers["bias_quantizer"])
quant_params = get_quant_params(bias, quantizers["bias_quantizer_cfg"])
attr = quant_params["attributes"]
input_nodes = [node.input[1]]
for key in quant_params["inputs"].keys():
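The requantizer branches above fire exactly when the qkeras quantizer has no fractional bits; a quick check of that condition (a sketch, assuming qkeras is installed):

```python
import numpy as np

from qkeras.quantizers import quantized_bits

q = quantized_bits(bits=8, integer=7, keep_negative=True)  # 8 == 7 + 1: no fractional bits
print(np.asarray(q(np.array([3.4, -17.8]))))  # integer-valued output, e.g. [3., -18.]
```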
19 changes: 17 additions & 2 deletions src/qonnx/converters/qkeras/qlayers.py
@@ -104,11 +104,26 @@ def extract_qlayer(layer):

keras_config = layer.get_config()

keras_config.pop("kernel_quantizer", None)
keras_config.pop("bias_quantizer", None)
kernel_quant_cfg = keras_config.pop("kernel_quantizer", None)
bias_quant_cfg = keras_config.pop("bias_quantizer", None)
keras_config.pop("kernel_range", None)
keras_config.pop("bias_range", None)

quantizers["kernel_quantizer_cfg"] = kernel_quant_cfg
quantizers["bias_quantizer_cfg"] = bias_quant_cfg

# Downstream tools can't handle a symbolic auto_po2 alpha, so the concrete scale values are calculated here
if kernel_quant_cfg is not None and kernel_quant_cfg["config"]["alpha"] == "auto_po2":
layer.kernel_quantizer_internal(layer.kernel) # sets .scale (see auto_po2)
quantizers["kernel_quantizer_cfg"]["config"]["alpha"] = (
layer.kernel_quantizer_internal.scale.numpy().flatten().tolist()
)
if bias_quant_cfg is not None and bias_quant_cfg["config"]["alpha"] == "auto_po2":
layer.bias_quantizer_internal(layer.bias)
quantizers["bias_quantizer_cfg"]["config"]["alpha"] = layer.bias_quantizer_internal.scale.numpy().flatten().tolist()
quantizers.pop("kernel_quantizer", None)
quantizers.pop("bias_quantizer", None)

# Check if activation is quantized
if _is_keras_quantizer(keras_config["activation"]):
keras_config["activation"] = _replace_activation(quantizers["activation"])
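For context, calling a qkeras quantizer whose alpha is "auto_po2" computes and stores a per-channel power-of-two scale on the quantizer object, which is exactly what the snippet above reads back (a sketch; the weight shape is illustrative):

```python
import numpy as np

from qkeras.quantizers import quantized_bits

q = quantized_bits(bits=8, integer=0, keep_negative=True, alpha="auto_po2")
w = np.random.uniform(-0.3, 0.3, size=(3, 4)).astype(np.float32)
q(w)  # quantizing once sets q.scale to per-channel power-of-two values
print(np.asarray(q.scale).flatten())
```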
11 changes: 7 additions & 4 deletions src/qonnx/converters/qkeras/quantizers.py
@@ -1,9 +1,11 @@
import numpy as np
import qkeras
import six
import tensorflow as tf


def get_quant_params(tensor, qkeras_quantizer):
if isinstance(qkeras_quantizer, str):
if isinstance(qkeras_quantizer, (str, dict)):
qkeras_quantizer = qkeras.get_quantizer(qkeras_quantizer)

return handler_map[qkeras_quantizer.__class__.__name__](tensor, qkeras_quantizer)
@@ -34,11 +36,12 @@ def convert_quantized_bits(tensor, quantizer):
signed = int(config["keep_negative"])
narrow = int(config["symmetric"])
qscale = _get_quantizer_scale(tensor, quantizer)
assert qscale == 1, "Non-unity alpha is not yet supported"
scale = 1.0 / 2 ** (int(config["bits"]) - int(config["integer"] + signed))
if not isinstance(qscale, (np.ndarray, tf.Tensor)):
qscale = np.array(qscale)
scale = qscale / 2 ** (int(config["bits"]) - int(config["integer"] + signed))
zero_point = 0
bit_width = int(config["bits"])
rounding_mode = "ROUND"
rounding_mode = "HALF_EVEN"

settings = {
"attributes": {"signed": signed, "narrow": narrow, "rounding_mode": rounding_mode},
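A numeric sanity check of the scale formula (a sketch; the string-built quantizer goes through the str/dict path accepted above, and the sample tensor is illustrative):

```python
from qonnx.converters.qkeras.quantizers import get_quant_params

params = get_quant_params([0.5, -0.25], "quantized_bits(8, 0, keep_negative=1)")
print(params["inputs"]["scale"])  # expected: 1 / 2**(8 - 0 - 1) = 0.0078125 for unity alpha
```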
8 changes: 8 additions & 0 deletions src/qonnx/core/modelwrapper.py
@@ -339,6 +339,14 @@ def get_initializer(self, tensor_name, return_dtype=False):
else:
return None

def del_initializer(self, initializer_name):
"""Deletes an initializer from the model."""
graph = self._model_proto.graph
for init in graph.initializer:
if init.name == initializer_name:
graph.initializer.remove(init)
break

def find_producer(self, tensor_name):
"""Finds and returns the node that produces the tensor with given name."""
for x in self._model_proto.graph.node:
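Usage of the new helper (a sketch; the model path and tensor name are illustrative):

```python
import numpy as np

from qonnx.core.modelwrapper import ModelWrapper

model = ModelWrapper("model.onnx")  # hypothetical model file
model.set_initializer("stale_const", np.array(1.0, dtype=np.float32))
model.del_initializer("stale_const")
assert model.get_initializer("stale_const") is None
```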
22 changes: 21 additions & 1 deletion src/qonnx/custom_op/general/quant.py
@@ -135,12 +135,32 @@ def resolve_rounding_mode(mode_string):
"""Resolve the rounding mode string of Quant and Trunc ops
to the corresponding numpy functions."""
normalized_mode_string = mode_string.upper()
if normalized_mode_string == "ROUND":
if normalized_mode_string == "ROUND" or normalized_mode_string == "HALF_EVEN":
return np.round
elif normalized_mode_string == "CEIL":
return np.ceil
elif normalized_mode_string == "FLOOR":
return np.floor
elif normalized_mode_string == "UP":

def round_up(x):
return np.sign(x) * np.ceil(np.abs(x))

return round_up
elif normalized_mode_string == "DOWN":
return np.fix
elif normalized_mode_string == "HALF_UP":

def round_half_up(x):
return np.sign(x) * np.floor(np.abs(x) + 0.5)

return round_half_up
elif normalized_mode_string == "HALF_DOWN":

def round_half_down(x):
return np.sign(x) * np.ceil(np.abs(x) - 0.5)

return round_half_down
else:
raise ValueError(f"Could not resolve rounding mode called: {normalized_mode_string}")

4 changes: 4 additions & 0 deletions src/qonnx/transformation/remove.py
@@ -117,5 +117,9 @@ def apply(self, model):
remove_node_and_rewire(model, n)
graph_modified = True
break
elif n.op_type == "Identity":
remove_node_and_rewire(model, n)
graph_modified = True
break
model = model.transform(InferShapes())
return (model, graph_modified)
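
Call sites are unchanged; Identity nodes are now removed alongside the other identity-like ops (a sketch, assuming the enclosing class is RemoveIdentityOps as in this module):

```python
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.remove import RemoveIdentityOps

model = ModelWrapper("model.onnx")  # hypothetical model file
model = model.transform(RemoveIdentityOps())
```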
23 changes: 23 additions & 0 deletions tests/custom_op/test_runding_mode.py
@@ -0,0 +1,23 @@
import pytest

import numpy as np

from qonnx.custom_op.general.quant import resolve_rounding_mode


@pytest.mark.parametrize(
"rmode,exp",
[
("ROUND", np.array([6, 2, 2, 1, 1, -1, -1, -2, -2, -6])),
("CEIL", np.array([6, 3, 2, 2, 1, -1, -1, -1, -2, -5])),
("FLOOR", np.array([5, 2, 1, 1, 1, -1, -2, -2, -3, -6])),
("UP", np.array([6, 3, 2, 2, 1, -1, -2, -2, -3, -6])),
("DOWN", np.array([5, 2, 1, 1, 1, -1, -1, -1, -2, -5])),
("HALF_UP", np.array([6, 3, 2, 1, 1, -1, -1, -2, -3, -6])),
("HALF_DOWN", np.array([5, 2, 2, 1, 1, -1, -1, -2, -2, -5])),
],
)
def test_rounding_modes(rmode, exp):
test_array = np.array([5.5, 2.5, 1.6, 1.1, 1.0, -1.0, -1.1, -1.6, -2.5, -5.5])
rounding_fn = resolve_rounding_mode(rmode)
assert np.array_equal(rounding_fn(test_array), exp)
Loading