Commit

add Concat node

RobinGeens committed Aug 29, 2024
1 parent a548718 commit d53a8aa
Showing 7 changed files with 208 additions and 46 deletions.
56 changes: 56 additions & 0 deletions stream/classes/io/onnx/concat.py
@@ -0,0 +1,56 @@
from zigzag.parser.onnx.ONNXOperatorParser import ONNXOperatorParser
from zigzag.parser.onnx.utils import OnnxTensorCategory, get_onnx_tensor_type

from stream.classes.workload.concat_node import ConcatNode


class ConcatParser(ONNXOperatorParser):
"""Parses an onnx gather operator into a ConcatNode."""

def run(self):
return self.generate_node()

def generate_node(self):
predecessors = self.get_node_predecessors()

axis = self.get_axis_value()
output_names = [self.node.output[0]]

input_1, input_2 = self.node.input[0], self.node.input[1]

try: # Try first one as constant input
constant_tensor = get_onnx_tensor_type(input_1, self.onnx_model)
if constant_tensor.category != OnnxTensorCategory.HIDDEN or "constant" not in input_1.lower():
raise ValueError

constant_shape = tuple(constant_tensor.shape)
variable_input_first = True
input_names = [input_2]
except ValueError: # Try second one as constant input
constant_tensor = get_onnx_tensor_type(input_2, self.onnx_model)
if constant_tensor.category != OnnxTensorCategory.HIDDEN or "constant" not in input_2.lower():
                raise ValueError("No constant input found for this Concat node")

constant_shape = tuple(constant_tensor.shape)
variable_input_first = True
input_names = [input_1]

return ConcatNode(
node_id=self.node_id,
node_name=self.node.name,
predecessors=predecessors,
axis=axis,
constant_shape=constant_shape,
variable_input_first=variable_input_first,
input_names=input_names,
output_names=output_names,
)

def get_axis_value(self):
"""Find the value of the axis associated with this concat node in ONNX"""
# `axis` is an attribute of the node
try:
axis_attr = next(filter(lambda x: x.name == "axis", self.node.attribute))
return axis_attr.i
except StopIteration:
raise ValueError("Axis attribute not found in ONNX node")
2 changes: 1 addition & 1 deletion stream/classes/io/onnx/lpnormalization.py
@@ -4,7 +4,7 @@


class LpNormalizationParser(ONNXOperatorParser):
"""Parses an onnx reshape operator into a ReshapeNode."""
"""Parses an onnx reshape operator into a LpNormalizationNode."""

def __init__(self, node_id, node, nodes_outputs, mapping, onnx_model) -> None:
raise NotImplementedError
11 changes: 5 additions & 6 deletions stream/classes/io/onnx/model.py
@@ -6,6 +6,7 @@
from zigzag.stages.WorkloadParserStage import WorkloadParserStage

from stream.classes.hardware.architecture.accelerator import Accelerator
from stream.classes.io.onnx.concat import ConcatParser
from stream.classes.io.onnx.conv import ConvParser
from stream.classes.io.onnx.default import DefaultNodeParser
from stream.classes.io.onnx.flatten import FlattenParser
@@ -109,12 +110,7 @@ def parse_workload_from_onnx_model_and_mapping(self):
accelerator=self.accelerator,
)
logger.info("Parsed Gemm node %s.", node.name)
elif node.op_type in [
"MaxPool",
"AveragePool",
"GlobalMaxPool",
"GlobalAveragePool",
]:
elif node.op_type in ["MaxPool", "AveragePool", "GlobalMaxPool", "GlobalAveragePool"]:
parser = PoolingParser(
node_id=node_id,
node=node,
@@ -183,6 +179,9 @@ def parse_workload_from_onnx_model_and_mapping(self):
elif node.op_type in ["LpNormalization"]:
parser = LpNormalizationParser(node_id, node, nodes_outputs, self.mapping_data, self.onnx_model)
logger.info("Parsed LpNormalization node %s.", node.name)
elif node.op_type in ["Concat"]:
parser = ConcatParser(node_id, node, nodes_outputs, self.onnx_model)
logger.info("Parsed LpNormalization node %s.", node.name)
# it is not any of the above, so create a DummyNode
else:
parser = DefaultNodeParser(node_id, node, nodes_outputs, self.onnx_model)
2 changes: 1 addition & 1 deletion stream/classes/io/onnx/transpose.py
@@ -4,7 +4,7 @@


class TransposeParser(ONNXOperatorParser):
"""Parses an onnx reshape operator into a ReshapeNode."""
"""Parses an onnx reshape operator into a TransposeNode."""

def run(self):
return self.generate_layer_node_for_transpose()
97 changes: 70 additions & 27 deletions stream/classes/stages/GenerateCNWorkloadHybridStage.py
@@ -17,6 +17,7 @@
)
from stream.classes.opt.splitting.TemporalLoop import TemporalLoop
from stream.classes.workload.computation_node import ComputationNode, LoopRanges
from stream.classes.workload.concat_node import ConcatNode
from stream.classes.workload.dummy_node import DummyNode
from stream.classes.workload.elementwise_node import ElementwiseNode
from stream.classes.workload.flatten_node import FlattenNode
@@ -602,9 +603,9 @@ def get_tensor_cn_for_op(node: ComputationNode, dependent_operand: LayerOperand)

for path_between in self.workload.all_simple_paths(producer, consumer):
# First node in the path is a ComputationNode, of which we extract the output operand dependency tensor
node = path_between[0]
assert isinstance(node, ComputationNode), "First node in path should be ComputationNode"
tensor = get_tensor_cn_for_op(node, dependent_operand=Constants.OUTPUT_LAYER_OP)
first_node = path_between[0]
assert isinstance(first_node, ComputationNode), "First node in path should be ComputationNode"
tensor = get_tensor_cn_for_op(first_node, dependent_operand=Constants.OUTPUT_LAYER_OP)

# Propagate through intermediate, non-computation nodes
for _, node in enumerate(path_between[1:-1], start=1):
@@ -613,24 +614,64 @@ def get_tensor_cn_for_op(node: ComputationNode, dependent_operand: LayerOperand)
tensor = self.propagate_cn_production_for_non_cn(node, tensor)

# Final node: Computation node
final_node = path_between[-1]
final_node: ComputationNode = path_between[-1] # type: ignore
assert isinstance(final_node, ComputationNode), "Last node in path should be ComputationNode"

# Find the operand for which this last node connects to its predecessor
dependent_operand = next(
op for op, dependent_node_id in final_node.input_operand_source.items() if dependent_node_id == node.id
)

try:
# Error handling of shape mismatches in tensor propagation
def get_final_tensor_alt_operand():
"""Error handling case 1: sources for `W` and `I` operand are swapped for this node
-> try the other one"""
try:
alt_operand = next(op for op in final_node.input_operand_source if op != dependent_operand)
except StopIteration:
# No alt operand was found -> we're still in trouble
raise TensorDimensionMismatchException
return get_tensor_cn_for_op(final_node, alt_operand)

def get_shape_inferred_propagated_tensor(tensor: NodeTensor, final_tensor: NodeTensor):
"""Error handling case 2: dimensions of ComputationNode (`final_tensor`) were altered by stream
(e.g. to be properly divisible) but this is not reflected in `ConcatNode` with constant shape.
-> manually fix shape"""
if not any(isinstance(node, ConcatNode) for node in path_between[1:-1]):
raise TensorDimensionMismatchException(
"This function only solves the case of errors due to constant shapes in ConcatNode"
)

target_shape = final_tensor.tensor_shape
propagated_shape = tensor.tensor_shape
extension_axis = next(i for i in range(len(target_shape)) if target_shape[i] != propagated_shape[i])
extension_value = target_shape[extension_axis] - propagated_shape[extension_axis]
if extension_value <= 0:
raise TensorDimensionMismatchException(
"Propagated shape cannot be larger than (extended) found shape"
)
extension_shape = tuple(
val if i != extension_axis else extension_value for i, val in enumerate(target_shape)
)
return tensor.concat_with_empty(extension_shape, extension_axis, variable_input_first=False)

try: # Regular case
final_tensor = get_tensor_cn_for_op(final_node, dependent_operand)
inter_edges = self.get_inter_edges_tensor_based(tensor, final_tensor)
except TensorDimensionMismatchException as exc:
try:
# Possible cause: Sources for `W` and `I` operand are swapped for this node -> try the other one
dependent_operand = next(op for op in final_node.input_operand_source if op != dependent_operand)
final_tensor = get_tensor_cn_for_op(final_node, dependent_operand)
except TensorDimensionMismatchException:
try: # Error case 1
final_tensor = get_final_tensor_alt_operand()
inter_edges = self.get_inter_edges_tensor_based(tensor, final_tensor)
except StopIteration:
raise exc
except TensorDimensionMismatchException:
try: # Error case 2
final_tensor = get_tensor_cn_for_op(final_node, dependent_operand)
tensor = get_shape_inferred_propagated_tensor(tensor, final_tensor)
inter_edges = self.get_inter_edges_tensor_based(tensor, final_tensor)
except TensorDimensionMismatchException:
# Error case 1 and 2 combined
final_tensor = get_final_tensor_alt_operand()
tensor = get_shape_inferred_propagated_tensor(tensor, final_tensor)
inter_edges = self.get_inter_edges_tensor_based(tensor, final_tensor)

for producer, cons in inter_edges:
all_inter_edges.append(
@@ -646,21 +687,23 @@ def get_tensor_cn_for_op(node: ComputationNode, dependent_operand: LayerOperand)
return all_inter_edges

def propagate_cn_production_for_non_cn(self, node: Node, input_tensor: NodeTensor) -> NodeTensor:
if isinstance(node, ReshapeNode):
output_tensor = node.reshape_operand_tensor(input_tensor)
elif isinstance(node, TransposeNode):
output_tensor = node.transpose(input_tensor)
elif isinstance(node, LpNormalizationNode):
output_tensor = node.lpnormalization_operand_tensor(input_tensor)
elif isinstance(node, FlattenNode):
output_tensor = node.flatten(input_tensor)
elif isinstance(node, ElementwiseNode):
output_tensor = input_tensor.copy()
elif isinstance(node, GatherNode):
output_tensor = node.gather_operand_tensor(input_tensor)
else:
raise NotImplementedError(f"Tensor propagation not implemented for node {node.name}.")
return output_tensor
match node:
case ReshapeNode():
return node.reshape_operand_tensor(input_tensor)
case TransposeNode():
return node.transpose(input_tensor)
case LpNormalizationNode():
return node.lpnormalization_operand_tensor(input_tensor)
case FlattenNode():
return node.flatten(input_tensor)
case ElementwiseNode():
return input_tensor.copy()
case GatherNode():
return node.gather_operand_tensor(input_tensor)
case ConcatNode():
return node.concat(input_tensor)
case _:
raise NotImplementedError(f"Tensor propagation not implemented for node {node.name}.")

@staticmethod
def get_inter_edges_tensor_based(producer_output_tensor: NodeTensor, consumer_input_tensor: NodeTensor):
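To make the shape-extension fallback above concrete, here is a standalone numpy sketch (illustrative shapes, not the stage's own code) of what `get_shape_inferred_propagated_tensor` computes when a ConcatNode's constant shape lags behind a stream-adjusted ComputationNode shape:

import numpy as np

propagated_shape = (1, 8, 64)   # shape of `tensor` after propagating through the ConcatNode (assumed)
target_shape = (1, 10, 64)      # shape of `final_tensor` at the consuming ComputationNode (assumed)

# Find the single axis where the shapes disagree and how much is missing.
extension_axis = next(i for i in range(len(target_shape)) if target_shape[i] != propagated_shape[i])
extension_value = target_shape[extension_axis] - propagated_shape[extension_axis]
assert extension_value > 0, "Propagated shape cannot be larger than the (extended) found shape"

extension_shape = tuple(
    val if i != extension_axis else extension_value for i, val in enumerate(target_shape)
)  # -> (1, 2, 64)

# Pad the propagated tensor with an empty block so both shapes line up again.
propagated = np.zeros(propagated_shape, dtype=object)
extended = np.concatenate((propagated, np.zeros(extension_shape, dtype=object)), axis=extension_axis)
assert extended.shape == target_shape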
66 changes: 66 additions & 0 deletions stream/classes/workload/concat_node.py
@@ -0,0 +1,66 @@
from zigzag.datatypes import LayerOperand
from zigzag.workload.LayerNodeABC import LayerNodeABC

from stream.classes.workload.node import Node
from stream.utils import NodeTensor


class ConcatNode(Node, LayerNodeABC):
"""Class that represents an onnx Concat node with one constant input."""

def __init__(
self,
node_id: int,
node_name: str,
predecessors: list[int],
axis: int,
constant_shape: tuple[int, ...],
variable_input_first: bool,
input_names: list[str],
output_names: list[str],
) -> None:
"""Initialize the ConcatNode
Args:
            predecessors: The ids of this node's parents.
axis: axis in which the input/constants are concatenated
constant_shape: the shape of the constant tensor
            variable_input_first: Whether the result is `concat(input, constant_tensor)` or
`concat(constant_tensor, input)`
            input_names: The input names of this node.
output_names: The output names of this node.
"""
Node.__init__(
self,
node_id=node_id,
node_name=node_name,
type="gather",
onchip_energy=0,
offchip_energy=0,
runtime=0,
possible_core_allocation=[-1],
input_names=input_names,
output_names=output_names,
)
LayerNodeABC.__init__(self, node_id=node_id, node_name=node_name)

self.axis = axis
self.constant_shape = constant_shape
self.variable_input_first = variable_input_first

match len(predecessors):
case 0:
self.input_operand_source = {}
case 1:
self.input_operand_source = {LayerOperand("I"): predecessors[0]}
case 2:
                # Two variable inputs: map the first predecessor to `W` and the second to `I`
self.input_operand_source = {LayerOperand("W"): predecessors[0], LayerOperand("I"): predecessors[1]}
case _:
raise ValueError("More than two inputs for ConcatNode")

def concat(self, tensor: NodeTensor) -> NodeTensor:
"""Perform gather operation on the tensor."""
return tensor.concat_with_empty(
shape=self.constant_shape, axis=self.axis, variable_input_first=self.variable_input_first
)
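
As a usage note (illustrative values, not taken from the commit), constructing the node for a single variable predecessor looks like this; with one predecessor, `input_operand_source` maps it to `LayerOperand("I")`:

from stream.classes.workload.concat_node import ConcatNode

node = ConcatNode(
    node_id=7,
    node_name="/decoder/Concat",   # hypothetical ONNX node name
    predecessors=[6],              # single variable predecessor -> {LayerOperand("I"): 6}
    axis=1,
    constant_shape=(1, 2, 64),     # shape of the constant operand (assumed)
    variable_input_first=True,
    input_names=["past_key"],      # hypothetical tensor names
    output_names=["present_key"],
)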
20 changes: 9 additions & 11 deletions stream/utils.py
@@ -174,9 +174,7 @@ def get_nb_empty_elements(self, slices: tuple[slice, ...]):
"""Returns the number of points for which there are no ComputationNodes."""
assert self.is_valid_shape_dimension(slices), "Last dimension of tensor is reserved for CNs"
extended_slices = slices + (slice(0, self.__node_count),)
# tensor_slice = self.as_ndarray()[slices][: self.__node_count]
tensor_slice = self.as_ndarray()[extended_slices]
# all_empty_mask = np.all(tensor_slice == 0, axis=-1)
all_empty_mask = np.logical_and.reduce(tensor_slice == 0, axis=-1)
return int(np.sum(all_empty_mask))

@@ -191,21 +189,12 @@ def extend_with_node(self, slices: tuple[slice, ...], node: object) -> "NodeTensor"
except IndexError:
# Happens when all allocated space has been used up. Create new one and double allocated space
new_tensor_np = np.concat((self, np.zeros(self.full_shape, dtype=object)), axis=-1)
assert new_tensor_np.shape[-1] == 2 * self.__pre_allocation_size
new_tensor = NodeTensor(new_tensor_np, pre_allocation_size=2 * self.__pre_allocation_size)
# Update the node pointer
assert self.__node_count == self.__pre_allocation_size
new_tensor.__node_count = self.__node_count
new_tensor = new_tensor.extend_with_node(slices, node)
print(f"EXTENDING TENSORNODE: {self.__pre_allocation_size}->{new_tensor.__pre_allocation_size}")
return new_tensor

# # Slice of thickness 1
# new_tensor_slice = np.zeros(self.tensor_shape + (1,), dtype=object)
# new_tensor_slice[slices] = node

# return NodeTensor(np.concat((self, new_tensor_slice), axis=-1))

def reshape(self, new_shape: tuple[int, ...] | None) -> "NodeTensor": # type: ignore
"""Wrap the numpy reshape method such that the user is agnostic to the last dimension on which nodes are
accumulated"""
@@ -228,6 +217,15 @@ def gather(self, gather_indices: int | list[int], axis: int) -> "NodeTensor":
axis = axis - 1 if axis < 0 else axis
return (np.take(self.as_ndarray(), gather_indices, axis=axis)).view(NodeTensor)

def concat_with_empty(self, shape: tuple[int, ...], axis: int, variable_input_first: bool):
        empty_shape = self.convert_to_full_shape(shape)
        empty_tensor = np.zeros(empty_shape, dtype=object)
axis = axis - 1 if axis < 0 else axis
if variable_input_first:
return np.concat((empty_tensor, self.as_ndarray()), axis=axis).view(NodeTensor)
else:
return np.concat((self.as_ndarray(), empty_tensor), axis=axis).view(NodeTensor)

def __repr__(self):
return f"TensorNode{self.tensor_shape}[depth={self.__node_count}]"

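A plain-numpy sketch (ignoring NodeTensor's node-count bookkeeping; shapes and depth are assumptions) of what `concat_with_empty` produces: the constant operand carries no ComputationNodes, so it is represented by a zero block of the constant shape extended with the trailing node dimension.

import numpy as np

depth = 4                                                   # assumed pre-allocated node dimension
variable = np.ones((1, 8, 64, depth), dtype=object)         # full shape of the variable NodeTensor
empty = np.zeros((1, 2, 64, depth), dtype=object)           # constant shape converted to full shape

# The two branches of concat_with_empty only differ in operand order.
out_a = np.concatenate((empty, variable), axis=1)           # order used when variable_input_first is True
out_b = np.concatenate((variable, empty), axis=1)           # order used otherwise
assert out_a.shape == out_b.shape == (1, 10, 64, depth)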
