diff --git a/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py b/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py
index 7eb56db382..6699340cac 100644
--- a/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py
+++ b/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py
@@ -659,7 +659,6 @@ def minimize_accumulator_width(self, model):
             # for no-activation nodes, output dt = acc dt
             self.set_nodeattr("outputDataType", adt.name)
         self.set_nodeattr("accDataType", adt.name)
-        return DataType[self.get_nodeattr("accDataType")]

     def minimize_weight_bit_width(self, model):
diff --git a/src/finn/custom_op/fpgadataflow/thresholding_batch.py b/src/finn/custom_op/fpgadataflow/thresholding_batch.py
index 3bcc5c05cf..72ee2f7af6 100644
--- a/src/finn/custom_op/fpgadataflow/thresholding_batch.py
+++ b/src/finn/custom_op/fpgadataflow/thresholding_batch.py
@@ -211,6 +211,8 @@ def minimize_accumulator_width(self, model):
             threshold_tensor
         ).all(), "Thresholds can't be expressed with type %s" % str(tdt)
         self.set_nodeattr("weightDataType", tdt.name)
+        # Update QONNX DataType of tensor for consistency
+        model.set_tensor_datatype(self.onnx_node.input[1], tdt)
         return DataType[self.get_nodeattr("weightDataType")]

     def get_instream_width(self, ind=0):
diff --git a/src/finn/transformation/fpgadataflow/minimize_accumulator_width.py b/src/finn/transformation/fpgadataflow/minimize_accumulator_width.py
index bc020ca428..8d04d5b817 100644
--- a/src/finn/transformation/fpgadataflow/minimize_accumulator_width.py
+++ b/src/finn/transformation/fpgadataflow/minimize_accumulator_width.py
@@ -28,6 +28,7 @@
 from qonnx.custom_op.registry import getCustomOp
 from qonnx.transformation.base import Transformation
+from qonnx.transformation.infer_datatypes import InferDataTypes

 from finn.util.fpgadataflow import is_fpgadataflow_node
@@ -41,9 +42,15 @@ def __init__(self):
         super().__init__()

     def apply(self, model):
-        for node in model.graph.node:
+        for node_id in range(len(model.graph.node)):
+            # Since InferDataTypes potentially changes node attributes in each loop
+            # iteration, the for-loop cannot iterate over a snapshot of the graph's node protos
+            node = model.graph.node[node_id]
             if is_fpgadataflow_node(node) is True:
                 inst = getCustomOp(node)
                 if hasattr(inst, "minimize_accumulator_width"):
                     inst.minimize_accumulator_width(model)
+                    # Since this transformation is applied iteratively, we have to ensure
+                    # that we propagate the new datatype to other layers
+                    model = model.transform(InferDataTypes())
         return (model, False)
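
Note for downstream callers: since `MatrixVectorActivation.minimize_accumulator_width` no longer returns the minimized datatype, the result has to be read back from the node attribute. A minimal sketch of the adjusted call pattern, assuming `model` is a QONNX `ModelWrapper` and `node` is an MVAU node proto (both placeholders, not part of the diff):

```python
from qonnx.core.datatype import DataType
from qonnx.custom_op.registry import getCustomOp

inst = getCustomOp(node)  # node: hypothetical MVAU node from model.graph.node
inst.minimize_accumulator_width(model)
# The accumulator datatype is now only available via the node attribute,
# not as a return value
acc_dt = DataType[inst.get_nodeattr("accDataType")]
```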
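
The `thresholding_batch.py` change keeps the QONNX tensor annotation in sync with the `weightDataType` node attribute. Following the same pattern as above, for a hypothetical `Thresholding_Batch` node `inst`, a consistency check along these lines should now hold:

```python
# input[1] of a Thresholding_Batch node is the threshold initializer
thr_name = inst.onnx_node.input[1]
inst.minimize_accumulator_width(model)
# The tensor annotation now matches the node attribute
assert model.get_tensor_datatype(thr_name).name == inst.get_nodeattr("weightDataType")
```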
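
End to end, the transformation is still applied through `ModelWrapper.transform`; a minimal usage sketch (the model paths are placeholders):

```python
from qonnx.core.modelwrapper import ModelWrapper

from finn.transformation.fpgadataflow.minimize_accumulator_width import (
    MinimizeAccumulatorWidth,
)

model = ModelWrapper("dataflow_model.onnx")  # placeholder path
# With this change, InferDataTypes runs after every minimized node, so the
# narrowed accumulator/threshold datatypes propagate to downstream layers
# within a single transform call
model = model.transform(MinimizeAccumulatorWidth())
model.save("dataflow_model_minimized.onnx")
```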