bugfix in reduce_1d: explicitly manage the keep_dim option
RobinGeens committed Nov 8, 2024
1 parent d1db0d9 commit 5485268
Showing 6 changed files with 52 additions and 18 deletions.
12 changes: 0 additions & 12 deletions stream/parser/onnx/conv.py
@@ -49,18 +49,6 @@ def get_layer_node_user_format(
        # 1D Conv case: append dimensions of size 1 so equation holds. Conv in FY dimension
        is_1d_conv = len(kernel_shape) == 1

        # if len(kernel_shape) == 1:
        #     kernel_shape.insert(0, 1)
        #     input_shape.append(1)
        #     output_shape.append(1)
        #     strides.append(1)
        #     dilations.append(1)
        #     assert len(input_shape) == 4
        #     assert len(output_shape) == 4

        # if len(padding) == 2:
        #     padding = 2 * padding

        # Get dimension sizes from input parameters
        assert input_shape[0] == output_shape[0], "Batch size is different for input and output activations."
        B = output_shape[0]
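The comments deleted above sketched how a 1D conv is normalized to the 4D layout before the new `is_1d_conv` flag took over that role. For reference, a runnable restatement of the removed logic (the helper name is hypothetical, not part of the repo):

```python
def normalize_1d_conv(kernel_shape, input_shape, output_shape, strides, dilations, padding):
    """Hypothetical helper restating the deleted comments: pad 1D-conv
    attributes with size-1 entries so the 4D conv equation holds (conv in FY)."""
    if len(kernel_shape) == 1:
        kernel_shape.insert(0, 1)  # prepend FY = 1
        input_shape.append(1)
        output_shape.append(1)
        strides.append(1)
        dilations.append(1)
        assert len(input_shape) == 4
        assert len(output_shape) == 4
    if len(padding) == 2:
        padding = 2 * padding  # duplicate the 2-entry padding into 4 entries
    return kernel_shape, input_shape, output_shape, strides, dilations, padding
```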
9 changes: 8 additions & 1 deletion stream/parser/onnx/model.py
@@ -17,6 +17,7 @@
from stream.parser.onnx.mul import MulParser
from stream.parser.onnx.operator_parser import OnnxOperatorParser
from stream.parser.onnx.pooling import PoolingParser
from stream.parser.onnx.reduce_1d import Reduce1DParser
from stream.parser.onnx.reshape import ReshapeParser
from stream.parser.onnx.simd import SimdParser
from stream.parser.onnx.slice import SliceParser
@@ -34,6 +35,7 @@ class ONNXModelParser:

    # Map the node's op_type to the corresponding Parser class
    OP_TYPE_TO_PARSER: dict[str, Type[OnnxOperatorParser]] = {
        # General
        "QLinearConv": ConvParser,
        "Conv": ConvParser,
        "MatMul": MatMulParser,
@@ -46,10 +48,15 @@ class ONNXModelParser:
"Add": MulParser,
"Mul": MulParser,
"Softmax": SoftmaxParser,
# Activations
# Single-input element-wise
"ReduceMean": Reduce1DParser,
"Relu": SimdParser,
"Gelu": SimdParser,
"Silu": SimdParser,
"Sqrt": SimdParser,
"Div": SimdParser,
"Pow": SimdParser,
"Reciprocal": SimdParser, # Div with 1 as numerator
# Dependency propagation
"LpNormalization": LpNormalizationParser,
"Gather": GatherParser,
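The map above drives parser dispatch by ONNX op_type. A minimal sketch of how such a lookup is typically used; the helper and its fallback behavior are assumptions for illustration, not taken from this commit:

```python
from typing import Type

from stream.parser.onnx.model import ONNXModelParser
from stream.parser.onnx.operator_parser import OnnxOperatorParser


def get_parser_class(op_type: str) -> Type[OnnxOperatorParser]:
    # Look up the registered parser for this op_type; the error handling is illustrative.
    parser_class = ONNXModelParser.OP_TYPE_TO_PARSER.get(op_type)
    if parser_class is None:
        raise NotImplementedError(f"No parser registered for op_type {op_type!r}")
    return parser_class


assert get_parser_class("ReduceMean").__name__ == "Reduce1DParser"  # mapping added in this commit
```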
3 changes: 3 additions & 0 deletions stream/parser/onnx/mul.py
@@ -37,6 +37,9 @@ def get_operand_source_input_format(self, shape_of_w: list[int]):
        shape is always at `W`"""
        predecessors = self.get_node_predecessors()
        match len(predecessors):
            case 0:
                # e.g. first node of graph
                return {"W": self.node_id, "I": self.node_id}
            case 1:
                # One source operand, one constant
                return {"W": self.node_id, "I": predecessors[0]}
3 changes: 3 additions & 0 deletions stream/parser/onnx/operator_parser.py
@@ -40,6 +40,9 @@ def generate_node(self) -> Node: ...
    def get_operand_source_input_format(self):
        predecessors = self.get_node_predecessors()
        match len(predecessors):
            case 0:
                # e.g. first node of graph
                return {"W": self.node_id, "I": self.node_id}
            case 1:
                # One source operand, one constant
                return {"W": self.node_id, "I": predecessors[0]}
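Both `MulParser` and the base `OnnxOperatorParser` now handle the zero-predecessor case identically. A standalone restatement with made-up node ids (cases beyond those visible in this diff are omitted):

```python
def operand_sources(node_id: int, predecessors: list[int]) -> dict[str, int]:
    # Illustrative restatement of the match blocks above.
    match len(predecessors):
        case 0:
            # e.g. first node of graph: both operands refer back to the node itself
            return {"W": node_id, "I": node_id}
        case 1:
            # One source operand, one constant
            return {"W": node_id, "I": predecessors[0]}
        case _:
            raise NotImplementedError("cases not shown in this diff")


assert operand_sources(7, []) == {"W": 7, "I": 7}   # first node of the graph
assert operand_sources(7, [3]) == {"W": 7, "I": 3}  # one predecessor, one constant
```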
42 changes: 37 additions & 5 deletions stream/parser/onnx/reduce_1d.py
@@ -8,12 +8,36 @@ class Reduce1DParser(OnnxComputeOperatorParser):
    e.g. sum over one row or max of a single row
    """

    def get_reduction_dim(self, input_shape: list[int], output_shape: list[int]):
        """Returns the axis in which the dimension is reduced"""

        # The case where keepdim=True: the reduced dimension is kept with size 1
        if len(input_shape) == len(output_shape):
            different_size = [a != b for a, b in zip(input_shape, output_shape)]
            if sum(different_size) != 1:
                raise ValueError(f"Input and output shapes {input_shape}, {output_shape} should only differ in one dim")
            reduction_dim = different_size.index(True)
            if output_shape[reduction_dim] != 1:
                raise ValueError(f"The reduced dimension at axis {reduction_dim} in {output_shape} is not 1")
            return reduction_dim

        # Otherwise: assume that the reduction is at axis=-1
        if not all(a == b for a, b in zip(input_shape, output_shape)):
            raise NotImplementedError("Reduce node with reduction axis other than -1 not implemented yet.")
        return len(input_shape) - 1  # Last dimension
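A quick sanity check of the two branches above, restated standalone with made-up shapes:

```python
def reduction_dim(input_shape: list[int], output_shape: list[int]) -> int:
    # Standalone restatement of get_reduction_dim above, for illustration only.
    if len(input_shape) == len(output_shape):  # keepdim=True
        diffs = [a != b for a, b in zip(input_shape, output_shape)]
        assert sum(diffs) == 1 and output_shape[diffs.index(True)] == 1
        return diffs.index(True)
    return len(input_shape) - 1  # keepdim=False: reduction assumed at the last axis


assert reduction_dim([2, 8, 16], [2, 8, 1]) == 2  # keepdim=True
assert reduction_dim([2, 8, 16], [2, 8]) == 2     # keepdim=False
```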

    def get_layer_node_user_format(self, input_shape: list[int], output_shape: list[int]):
        """
        Generate the necessary dictionary items required for the LayerNode creation.
        """
        # TODO check the output shape as well?
        assert len(self.get_node_predecessors()) == 1
        if len(self.get_node_predecessors()) != 1:
            raise NotImplementedError

        if self.get_reduction_dim(input_shape, output_shape) != len(input_shape) - 1:
            raise NotImplementedError("Only reduction in axis=-1 is supported")

        # This is an ONNX node property but can be inferred from the shapes
        keep_dim = len(input_shape) == len(output_shape)

        data: dict[str, Any] = {}
        data["id"] = self.node_id
@@ -24,17 +24,48 @@ def get_layer_node_user_format(self, input_shape: list[int], output_shape: list[
data["dimension_relations"] = []
data["loop_sizes"] = input_shape

# C is always the reduction dim
# If keep_dim: add an arbitrary dim of size 1
reduced_dim_output = "CR" # C reduced to 1
eq_part_CR = f"[{reduced_dim_output}]" if keep_dim else ""
match len(input_shape):
case 2:
data["equation"] = "O[k]+=I[k][c]*W[]"
data["equation"] = f"O[k]{eq_part_CR}+=I[k][c]*W[]"
data["loop_dims"] = ["K", "C"]
case 3:
data["equation"] = "O[b][k]+=I[b][k][c]*W[]"
data["equation"] = f"O[b][k]{eq_part_CR}+=I[b][k][c]*W[]"
data["loop_dims"] = ["B", "K", "C"]
case 4:
data["equation"] = "O[b][h][k]+=I[b][h][k][c]*W[]"
data["equation"] = f"O[b][h][k]{eq_part_CR}+=I[b][h][k][c]*W[]"
data["loop_dims"] = ["B", "H", "K", "C"]
case _:
raise NotImplementedError

if keep_dim:
data["loop_dims"] += [reduced_dim_output]
data["loop_sizes"] += [1]

return data
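The keep_dim handling above can be checked with a small standalone snippet (3D case, values made up):

```python
# Self-check of the 3D equation template, for both keep_dim values.
for keep_dim, expected in [
    (True, "O[b][k][CR]+=I[b][k][c]*W[]"),
    (False, "O[b][k]+=I[b][k][c]*W[]"),
]:
    eq_part_CR = "[CR]" if keep_dim else ""
    assert f"O[b][k]{eq_part_CR}+=I[b][k][c]*W[]" == expected
# With keep_dim=True, "CR" is additionally appended to loop_dims with loop size 1.
```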
1 change: 1 addition & 0 deletions stream/stages/generation/tiled_workload_generation.py
@@ -376,6 +376,7 @@ def get_bounding_box_dimensions(
        # where the onnx tensors are always flattened back to 4D (merging the G+C or G+K into one channel dimension)
        dimensions, loop_ranges = self.flatten_grouped_convolution_ranges(producer, consumer, dimensions, loop_ranges)
        bounding_box = [loop_ranges[dim] for dim in dimensions]
        # TODO can bounding box have size 1? Will probably crash if so

        if not interleaved:
            bounding_box_flat = tuple([item for sublist in bounding_box for item in sublist])
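For intuition, a minimal sketch of the bounding-box construction and flattening in the hunk above (dimension names and loop ranges are made up):

```python
# Made-up dimensions and loop ranges illustrating the bounding-box step.
dimensions = ["B", "K", "C"]
loop_ranges = {"B": (0, 2), "K": (0, 8), "C": (0, 16)}

bounding_box = [loop_ranges[dim] for dim in dimensions]  # [(0, 2), (0, 8), (0, 16)]
bounding_box_flat = tuple(item for sublist in bounding_box for item in sublist)
assert bounding_box_flat == (0, 2, 0, 8, 0, 16)
```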
