Skip to content

Commit

Permalink
fix bug in conv: dont remove dims of size 1, except K and C (not pres…
Browse files Browse the repository at this point in the history
…ent in some 1D convs)
  • Loading branch information
RobinGeens committed Nov 8, 2024
1 parent a38309a commit d1db0d9
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 20 deletions.
31 changes: 11 additions & 20 deletions stream/parser/onnx/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,7 @@ def get_layer_node_user_format(

if is_1d_conv:
# No FY, OY, IY
data["loop_sizes"] = [B, K, G, OX, C, FX]
data["loop_dims"] = ["B", "K", "G", "OX", "C", "FX"]
loop_size_dict = {"B": B, "K": K, "G": G, "OX": OX, "C": C, "FX": FX}
data["equation"] = f"O[b][g][k][ox]+=W[{weight_dim}][c][fx]*I[b][g][c][ix]"
data["pr_loop_dims"] = ["IX"]
data["pr_loop_sizes"] = [IX]
Expand All @@ -99,8 +98,7 @@ def get_layer_node_user_format(
FY = kernel_shape[1] # TODO is kernel_shape in (FX, FY) format or (FY, FX)? (I assumed the former)
IY = input_shape[3]
OY = output_shape[3]
data["loop_sizes"] = [B, K, G, OX, C, FX, OY, FY]
data["loop_dims"] = ["B", "K", "G", "OX", "C", "FX", "OY", "FY"]
loop_size_dict = {"B": B, "K": K, "G": G, "OX": OX, "C": C, "FX": FX, "OY": OY, "FY": FY}
data["equation"] = f"O[b][g][k][oy][ox]+=W[{weight_dim}][c][fy][fx]*I[b][g][c][iy][ix]"
data["pr_loop_dims"] = ["IX", "IY"]
data["pr_loop_sizes"] = [IX, IY]
Expand All @@ -113,22 +111,15 @@ def get_layer_node_user_format(
[padding[1], padding[3]],
]

# Remove dims with size 1, except batch
dim_sizes_larger_than_1 = {
dim: size for dim, size in zip(data["loop_dims"], data["loop_sizes"]) if size > 1 or dim == "B"
}
dims_with_size_1 = [dim for dim in data["loop_dims"] if dim not in dim_sizes_larger_than_1]
data["loop_dims"] = list(dim_sizes_larger_than_1.keys())
data["loop_sizes"] = list(dim_sizes_larger_than_1.values())
for dim in dims_with_size_1:
data["equation"] = data["equation"].replace(f"[{dim.lower()}]", "")

# Filter out loops with size 1
# loop_sizes = {"B": B, "K": K, "G": G, "OX": OX, "OY": OY, "C": C, "FX": FX, "FY": FY}
# dims_with_size_1 = [k for k, v in loop_sizes.items() if v == 1]
# loop_sizes = {k: v for k, v in loop_sizes.items() if v > 1}
# data["loop_dims"] = list(loop_sizes.keys())
# data["loop_sizes"] = list(loop_sizes.values())
# Remove C/K if they have size 1
for dim in ["C", "K"]:
if loop_size_dict[dim] == 1:
del loop_size_dict[dim]
# Remove from equation
data["equation"] = data["equation"].replace(f"[{dim.lower()}]", "")

data["loop_dims"] = list(loop_size_dict.keys())
data["loop_sizes"] = list(loop_size_dict.values())

return data

Expand Down
8 changes: 8 additions & 0 deletions stream/parser/onnx/operator_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,14 @@ def get_operand_precision_user_format(self) -> dict[str, int]:
intermediate_output_precision: int = self.get_intermediate_output_precision()
predecessors = self.get_node_predecessors()
match len(predecessors):
case 0:
# e.g. the first node in the network -> assume only one variable input
return {
"W": weight_precision,
"I": act_precision,
"O_final": act_precision,
"O": intermediate_output_precision,
}
case 1:
# One source operand, one constant
return {
Expand Down
4 changes: 4 additions & 0 deletions stream/workload/computation/pooling_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,21 @@


class PoolingNode(ComputationNode):
"""TODO this node can be replaced by instantiating ComputationNode directly"""

def __init__(
self,
node_id: int,
node_name: str,
node_attr: LayerNodeAttributes,
mapping_attr: InterCoreMappingAttributes,
input_names: list[str] = [],
):
super().__init__(
node_id=node_id,
node_name=node_name,
node_attr=node_attr,
mapping_attr=mapping_attr,
op_type="pooling",
input_names=input_names,
)

0 comments on commit d1db0d9

Please sign in to comment.