-
Notifications
You must be signed in to change notification settings - Fork 40
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
141 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import HGQ | ||
import qkeras | ||
from qkeras.quantizers import BaseQuantizer | ||
import tensorflow as tf | ||
|
||
|
||
# TODO: this should be implemented in HGQ so we can import it here | ||
# from HGQ.utils import REGISTERED_LAYERS as HGQ_LAYERS | ||
HGQ_LAYERS = ["FixedPointQuantizer"] | ||
|
||
def extract_quantizers_from_hgq_layer(layer, model): | ||
""" """ | ||
layer_class = layer.__class__.__name__ | ||
if layer_class in HGQ_LAYERS: | ||
handler = handler_map.get(layer_class, None) | ||
if handler: | ||
return handler_map[layer_class](layer, model) | ||
else: | ||
return layer_class, layer.get_config(), None | ||
else: | ||
return layer_class, layer.get_config(), None | ||
|
||
|
||
def extract_FixedPointQuantizer(layer, model): | ||
|
||
quantizers = layer.get_config() | ||
|
||
if "overrides" not in quantizers: | ||
# TODO: add support for FixedPointQuantizer which dont override | ||
raise ValueError(f"Not supported: FixedPointQuantizer has no layers to override") | ||
|
||
quantizers["inputs"] = { | ||
"keep_negative": layer.keep_negative.numpy(), | ||
"bits": layer.bits.numpy(), | ||
"integer_bits": layer.integers.numpy(), | ||
} | ||
keras_config = {'name': quantizers["name"], 'dtype': 'float32'} | ||
|
||
return "Identity", keras_config, quantizers | ||
|
||
|
||
handler_map = { | ||
"FixedPointQuantizer": extract_FixedPointQuantizer | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import numpy as np | ||
|
||
from .quantizers import get_quant_params | ||
|
||
|
||
def get_hgq_onnx_handlers(all_quantizers): | ||
"""Returns the handlers for each kind of layer | ||
Args: | ||
all_quantizers: All the quantizers of the model in dictionary format *check | ||
Returns: | ||
Dictionary containing the handler information for every type of layer | ||
""" | ||
return { | ||
# NOTE: we replace the StatefulPartitionedCall layers with Identity layers | ||
# after them we are adding now FixedPoint layers for the quantitzation | ||
"Identity": ( | ||
FixedPoint, ["FixedPoint", all_quantizers] | ||
), | ||
} | ||
|
||
|
||
def _extract_node_name(onnx_node, keras_quantizers): | ||
""" | ||
Args: | ||
onnx_node: The onnx node to get the information from | ||
keras_quantizers: The dictionary of all the keras quantizers | ||
""" | ||
onnx_name = onnx_node.name | ||
print(onnx_node) | ||
keras_names = keras_quantizers.keys() | ||
print(keras_names, onnx_name) | ||
for keras_name in keras_names: | ||
match = "/" + keras_name + "/" | ||
if match in onnx_name: | ||
return keras_name | ||
|
||
return None | ||
|
||
|
||
def FixedPoint(ctx, node, name, args): | ||
all_quantizers = args[0] | ||
keras_name = _extract_node_name(node, all_quantizers) | ||
if not keras_name: | ||
return # Not found in quantizers, nothing to do | ||
quantizers = all_quantizers[keras_name] | ||
# if we have overrides we are converting a FixedPointQuantizer from HGQ | ||
if quantizers.get("overrides"): | ||
quant_params = get_quant_params(None, quantizers) | ||
attr = quant_params["attributes"] | ||
input_nodes = [node.output[0]] | ||
print(node.input[0]) | ||
for key in quantizers["inputs"].keys(): | ||
name = f"{node.name}_FixedPointQuantizer_quantizer_{key}" | ||
np_val = np.asarray(quant_params["inputs"][key]) | ||
ctx.make_const(name, np_val) | ||
input_nodes.append(name) | ||
quant_fixed_node = ctx.make_node( | ||
"FixedPoint", | ||
input_nodes, | ||
dtypes=None, # TODO: we have to get the type here | ||
name=node.name + "_FixedPoint_quantizer", | ||
attr=attr, | ||
domain="qonnx", | ||
) | ||
ctx.insert_node_on_output(quant_fixed_node, node.output[0]) | ||
ctx.set_shape(quant_fixed_node.output[0], ctx.get_shape(node.output[0])) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
|
||
|
||
def get_quant_params(tensor, hgq_quantizer): | ||
|
||
return handler_map[hgq_quantizer["keras_layer"]](tensor, hgq_quantizer) | ||
|
||
|
||
def convert_quantized_bits(tensor, quantizer): | ||
|
||
settings = { | ||
"attributes": { | ||
"RND": quantizer["RND"], | ||
"SAT": quantizer["SAT"], | ||
}, | ||
"inputs": { | ||
"integer_bits": quantizer["inputs"]["integers"], | ||
"keep_negative": quantizer["inputs"]["keep_negative"], | ||
"bits": quantizer["inputs"]["bits"], | ||
}, | ||
} | ||
|
||
return settings | ||
|
||
|
||
handler_map = { | ||
"Identity": convert_quantized_bits, | ||
} |