-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
a548718
commit d53a8aa
Showing
7 changed files
with
208 additions
and
46 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
from zigzag.parser.onnx.ONNXOperatorParser import ONNXOperatorParser | ||
from zigzag.parser.onnx.utils import OnnxTensorCategory, get_onnx_tensor_type | ||
|
||
from stream.classes.workload.concat_node import ConcatNode | ||
|
||
|
||
class ConcatParser(ONNXOperatorParser): | ||
"""Parses an onnx gather operator into a ConcatNode.""" | ||
|
||
def run(self): | ||
return self.generate_node() | ||
|
||
def generate_node(self): | ||
predecessors = self.get_node_predecessors() | ||
|
||
axis = self.get_axis_value() | ||
output_names = [self.node.output[0]] | ||
|
||
input_1, input_2 = self.node.input[0], self.node.input[1] | ||
|
||
try: # Try first one as constant input | ||
constant_tensor = get_onnx_tensor_type(input_1, self.onnx_model) | ||
if constant_tensor.category != OnnxTensorCategory.HIDDEN or "constant" not in input_1.lower(): | ||
raise ValueError | ||
|
||
constant_shape = tuple(constant_tensor.shape) | ||
variable_input_first = True | ||
input_names = [input_2] | ||
except ValueError: # Try second one as constant input | ||
constant_tensor = get_onnx_tensor_type(input_2, self.onnx_model) | ||
if constant_tensor.category != OnnxTensorCategory.HIDDEN or "constant" not in input_2.lower(): | ||
raise ValueError | ||
|
||
constant_shape = tuple(constant_tensor.shape) | ||
variable_input_first = True | ||
input_names = [input_1] | ||
|
||
return ConcatNode( | ||
node_id=self.node_id, | ||
node_name=self.node.name, | ||
predecessors=predecessors, | ||
axis=axis, | ||
constant_shape=constant_shape, | ||
variable_input_first=variable_input_first, | ||
input_names=input_names, | ||
output_names=output_names, | ||
) | ||
|
||
def get_axis_value(self): | ||
"""Find the value of the axis associated with this concat node in ONNX""" | ||
# `axis` is an attribute of the node | ||
try: | ||
axis_attr = next(filter(lambda x: x.name == "axis", self.node.attribute)) | ||
return axis_attr.i | ||
except StopIteration: | ||
raise ValueError("Axis attribute not found in ONNX node") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
from zigzag.datatypes import LayerOperand | ||
from zigzag.workload.LayerNodeABC import LayerNodeABC | ||
|
||
from stream.classes.workload.node import Node | ||
from stream.utils import NodeTensor | ||
|
||
|
||
class ConcatNode(Node, LayerNodeABC): | ||
"""Class that represents an onnx Concat node with one constant input.""" | ||
|
||
def __init__( | ||
self, | ||
node_id: int, | ||
node_name: str, | ||
predecessors: list[int], | ||
axis: int, | ||
constant_shape: tuple[int, ...], | ||
variable_input_first: bool, | ||
input_names: list[str], | ||
output_names: list[str], | ||
) -> None: | ||
"""Initialize the ConcatNode | ||
Args: | ||
predecessors: The id of this node's parent. | ||
axis: axis in which the input/constants are concatenated | ||
constant_shape: the shape of the constant tensor | ||
variable_input_first: Wether the result is `concat(input, constant_tensor)` or | ||
`concat(constant_tensor, input)` | ||
input_names The input names of this node. | ||
output_names: The output names of this node. | ||
""" | ||
Node.__init__( | ||
self, | ||
node_id=node_id, | ||
node_name=node_name, | ||
type="gather", | ||
onchip_energy=0, | ||
offchip_energy=0, | ||
runtime=0, | ||
possible_core_allocation=[-1], | ||
input_names=input_names, | ||
output_names=output_names, | ||
) | ||
LayerNodeABC.__init__(self, node_id=node_id, node_name=node_name) | ||
|
||
self.axis = axis | ||
self.constant_shape = constant_shape | ||
self.variable_input_first = variable_input_first | ||
|
||
match len(predecessors): | ||
case 0: | ||
self.input_operand_source = {} | ||
case 1: | ||
self.input_operand_source = {LayerOperand("I"): predecessors[0]} | ||
case 2: | ||
# `indices` (the second input) are considered as inputs | ||
self.input_operand_source = {LayerOperand("W"): predecessors[0], LayerOperand("I"): predecessors[1]} | ||
case _: | ||
raise ValueError("More than two inputs for ConcatNode") | ||
|
||
def concat(self, tensor: NodeTensor) -> NodeTensor: | ||
"""Perform gather operation on the tensor.""" | ||
return tensor.concat_with_empty( | ||
shape=self.constant_shape, axis=self.axis, variable_input_first=self.variable_input_first | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters