diff --git a/stream/parser/onnx/conv.py b/stream/parser/onnx/conv.py
index d286e4c..d22dafc 100644
--- a/stream/parser/onnx/conv.py
+++ b/stream/parser/onnx/conv.py
@@ -83,8 +83,7 @@ def get_layer_node_user_format(
 
         if is_1d_conv:
             # No FY, OY, IY
-            data["loop_sizes"] = [B, K, G, OX, C, FX]
-            data["loop_dims"] = ["B", "K", "G", "OX", "C", "FX"]
+            loop_size_dict = {"B": B, "K": K, "G": G, "OX": OX, "C": C, "FX": FX}
             data["equation"] = f"O[b][g][k][ox]+=W[{weight_dim}][c][fx]*I[b][g][c][ix]"
             data["pr_loop_dims"] = ["IX"]
             data["pr_loop_sizes"] = [IX]
@@ -99,8 +98,7 @@ def get_layer_node_user_format(
             FY = kernel_shape[1]  # TODO is kernel_shape in (FX, FY) format or (FY, FX)? (I assumed the former)
             IY = input_shape[3]
             OY = output_shape[3]
-            data["loop_sizes"] = [B, K, G, OX, C, FX, OY, FY]
-            data["loop_dims"] = ["B", "K", "G", "OX", "C", "FX", "OY", "FY"]
+            loop_size_dict = {"B": B, "K": K, "G": G, "OX": OX, "C": C, "FX": FX, "OY": OY, "FY": FY}
             data["equation"] = f"O[b][g][k][oy][ox]+=W[{weight_dim}][c][fy][fx]*I[b][g][c][iy][ix]"
             data["pr_loop_dims"] = ["IX", "IY"]
             data["pr_loop_sizes"] = [IX, IY]
@@ -113,22 +111,15 @@ def get_layer_node_user_format(
                 [padding[1], padding[3]],
             ]
 
-        # Remove dims with size 1, except batch
-        dim_sizes_larger_than_1 = {
-            dim: size for dim, size in zip(data["loop_dims"], data["loop_sizes"]) if size > 1 or dim == "B"
-        }
-        dims_with_size_1 = [dim for dim in data["loop_dims"] if dim not in dim_sizes_larger_than_1]
-        data["loop_dims"] = list(dim_sizes_larger_than_1.keys())
-        data["loop_sizes"] = list(dim_sizes_larger_than_1.values())
-        for dim in dims_with_size_1:
-            data["equation"] = data["equation"].replace(f"[{dim.lower()}]", "")
-
-        # Filter out loops with size 1
-        # loop_sizes = {"B": B, "K": K, "G": G, "OX": OX, "OY": OY, "C": C, "FX": FX, "FY": FY}
-        # dims_with_size_1 = [k for k, v in loop_sizes.items() if v == 1]
-        # loop_sizes = {k: v for k, v in loop_sizes.items() if v > 1}
-        # data["loop_dims"] = list(loop_sizes.keys())
-        # data["loop_sizes"] = list(loop_sizes.values())
+        # Remove C/K if they have size 1
+        for dim in ["C", "K"]:
+            if loop_size_dict[dim] == 1:
+                del loop_size_dict[dim]
+                # Remove from equation
+                data["equation"] = data["equation"].replace(f"[{dim.lower()}]", "")
+
+        data["loop_dims"] = list(loop_size_dict.keys())
+        data["loop_sizes"] = list(loop_size_dict.values())
 
         return data
 
diff --git a/stream/parser/onnx/operator_parser.py b/stream/parser/onnx/operator_parser.py
index a434589..78d5e99 100644
--- a/stream/parser/onnx/operator_parser.py
+++ b/stream/parser/onnx/operator_parser.py
@@ -68,6 +68,14 @@ def get_operand_precision_user_format(self) -> dict[str, int]:
         intermediate_output_precision: int = self.get_intermediate_output_precision()
         predecessors = self.get_node_predecessors()
         match len(predecessors):
+            case 0:
+                # e.g. the first node in the network -> assume only one variable input
+                return {
+                    "W": weight_precision,
+                    "I": act_precision,
+                    "O_final": act_precision,
+                    "O": intermediate_output_precision,
+                }
             case 1:
                 # One source operand, one constant
                 return {
diff --git a/stream/workload/computation/pooling_node.py b/stream/workload/computation/pooling_node.py
index 0c4151c..74a2716 100644
--- a/stream/workload/computation/pooling_node.py
+++ b/stream/workload/computation/pooling_node.py
@@ -5,12 +5,15 @@
 
 
 class PoolingNode(ComputationNode):
+    """TODO this node can be replaced by instantiating ComputationNode directly"""
+
     def __init__(
         self,
         node_id: int,
         node_name: str,
         node_attr: LayerNodeAttributes,
         mapping_attr: InterCoreMappingAttributes,
+        input_names: list[str] = [],
     ):
         super().__init__(
             node_id=node_id,
@@ -18,4 +21,5 @@ def __init__(
             node_attr=node_attr,
             mapping_attr=mapping_attr,
             op_type="pooling",
+            input_names=input_names,
         )
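Note: a minimal standalone sketch of the new size-1 filtering in conv.py, with made-up loop sizes (illustrative only, not part of the patch):

    # Illustrative values only; in the parser these come from the ONNX node's shapes.
    loop_size_dict = {"B": 1, "K": 1, "G": 1, "OX": 32, "C": 1, "FX": 3}
    equation = "O[b][g][k][ox]+=W[k][c][fx]*I[b][g][c][ix]"

    # Only C and K are dropped when they have size 1; B, G, etc. are kept even at size 1.
    for dim in ["C", "K"]:
        if loop_size_dict[dim] == 1:
            del loop_size_dict[dim]
            equation = equation.replace(f"[{dim.lower()}]", "")

    print(loop_size_dict)  # {'B': 1, 'G': 1, 'OX': 32, 'FX': 3}
    print(equation)        # O[b][g][ox]+=W[fx]*I[b][g][ix]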