cleanup model parser code
RobinGeens committed Sep 3, 2024
1 parent 5055d7e commit 0227f54
Showing 22 changed files with 280 additions and 324 deletions.
16 changes: 8 additions & 8 deletions stream/classes/hardware/architecture/noc/bus.py
@@ -5,7 +5,7 @@
from stream.classes.hardware.architecture.noc.communication_link import CommunicationLink


def have_shared_memory(a, b):
def have_shared_memory(a: Core, b: Core):
"""Returns True if core a and core b have a shared top level memory
Args:
@@ -25,12 +25,12 @@ def have_shared_memory(a, b):


def get_bus(
cores,
bandwidth,
unit_energy_cost,
pooling_core=None,
simd_core=None,
offchip_core=None,
cores: list[Core],
bandwidth: int,
unit_energy_cost: float,
pooling_core: Core | None = None,
simd_core: Core | None = None,
offchip_core: Core | None = None,
):
"""Return a graph of the cores where each core is connected to a single bus.
@@ -46,7 +46,7 @@ def get_bus(
"""
bus = CommunicationLink("Any", "Any", bandwidth, unit_energy_cost)

edges = []
edges: list[tuple[Core, Core, dict[str, CommunicationLink]]] = []
pairs = [(a, b) for idx, a in enumerate(cores) for b in cores[idx + 1 :]]
for pair in pairs:
(sender, receiver) = pair
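The hunk above gives get_bus explicit types; the function connects every pair of distinct cores through the single shared CommunicationLink. As a standalone illustration of the pair enumeration it relies on (strings stand in for Core objects here, so this is a sketch rather than a real get_bus call):

    cores = ["core0", "core1", "core2"]  # placeholders for Core instances

    # Same comprehension as in get_bus: every unordered pair of distinct cores.
    pairs = [(a, b) for idx, a in enumerate(cores) for b in cores[idx + 1 :]]

    print(pairs)
    # [('core0', 'core1'), ('core0', 'core2'), ('core1', 'core2')]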
2 changes: 1 addition & 1 deletion stream/classes/hardware/architecture/noc/mesh_2d.py
@@ -35,7 +35,7 @@ def get_2d_mesh(
simd_core: Core | None = None,
offchip_core: Core | None = None,
):
"""Return a 2D mesh graph of the cores where each core is connected to its N, E, S, W neighbour.
"""Return a 2D mesh graph of the cores where each core is connected to its N, E, S, W neighbor.
We build the mesh by iterating through the row and then moving to the next column.
Each connection between two cores includes two links, one in each direction, each with specified bandwidth.
Thus there are a total of ((nb_cols-1)*2*nb_rows + (nb_rows-1)*2*nb_cols) links in the noc.
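As a quick sanity check of the link-count formula quoted in the docstring (a worked example, not part of the commit): for a mesh with 2 rows and 3 columns,

    nb_rows, nb_cols = 2, 3
    # (nb_cols-1)*2*nb_rows horizontal links + (nb_rows-1)*2*nb_cols vertical links
    nb_links = (nb_cols - 1) * 2 * nb_rows + (nb_rows - 1) * 2 * nb_cols
    print(nb_links)  # 8 + 6 = 14 directed links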
13 changes: 7 additions & 6 deletions stream/classes/io/accelerator_factory.py
@@ -23,12 +23,13 @@ def create(self) -> Accelerator:
core = core_factory.create(core_id)
cores.append(core)

if self.data["graph"]["type"] == "2d_mesh":
cores_graph = self.create_2d_mesh(cores)
elif self.data["graph"]["type"] == "bus":
cores_graph = self.create_bus(cores)
else:
raise ValueError(f"Invalid graph type {self.data['graph']['type']}.")
match self.data["graph"]["type"]:
case "2d_mesh":
cores_graph = self.create_2d_mesh(cores)
case "bus":
cores_graph = self.create_bus(cores)
case _:
raise ValueError(f"Invalid graph type {self.data['graph']['type']}.")

offchip_core_id: int | None = self.data["graph"]["offchip_core_id"]
return Accelerator(name=self.data["name"], cores=cores_graph, offchip_core_id=offchip_core_id)
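The if/elif chain becomes a match statement (Python 3.10+). A standalone sketch of the same dispatch shape, with hypothetical return values rather than the factory's real graph construction:

    def dispatch(graph_type: str) -> str:
        # Only "2d_mesh" and "bus" are recognized; anything else hits the wildcard.
        match graph_type:
            case "2d_mesh":
                return "build 2d mesh"
            case "bus":
                return "build bus"
            case _:
                raise ValueError(f"Invalid graph type {graph_type}.")

    dispatch("bus")   # -> "build bus"
    # dispatch("ring") raises ValueError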
3 changes: 2 additions & 1 deletion stream/classes/io/accelerator_validator.py
@@ -11,6 +11,7 @@

class AcceleratorValidator:
INPUT_DIR_LOCATION = "stream/inputs/"
GRAPH_TYPES = ["2d_mesh", "bus"]

SCHEMA = {
"name": {"type": "string", "required": True},
@@ -26,7 +27,7 @@ class AcceleratorValidator:
"type": "dict",
"required": True,
"schema": {
"type": {"type": "string", "required": True},
"type": {"type": "string", "required": True, "allowed": GRAPH_TYPES},
"nb_rows": {"type": "integer", "required": False},
"nb_cols": {"type": "integer", "required": False},
"bandwidth": {"type": "integer", "required": True},
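With the new allowed rule, an unsupported graph type is rejected at validation time instead of surfacing later in the factory. A hedged illustration using only the fields visible in this hunk (the real accelerator files contain more sections, e.g. the per-core definitions the factory reads):

    GRAPH_TYPES = ["2d_mesh", "bus"]

    mesh_graph = {"type": "2d_mesh", "nb_rows": 2, "nb_cols": 2, "bandwidth": 64}
    bus_graph = {"type": "bus", "bandwidth": 64}
    bad_graph = {"type": "ring", "bandwidth": 64}  # not in GRAPH_TYPES -> schema error

    assert mesh_graph["type"] in GRAPH_TYPES
    assert bus_graph["type"] in GRAPH_TYPES
    assert bad_graph["type"] not in GRAPH_TYPES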
4 changes: 2 additions & 2 deletions stream/classes/io/onnx/concat.py
@@ -1,10 +1,10 @@
from zigzag.parser.onnx.ONNXOperatorParser import ONNXOperatorParser
from zigzag.parser.onnx.utils import OnnxTensorCategory, get_onnx_tensor_type

from stream.classes.io.onnx.operator_parser import OnnxOperatorParser
from stream.classes.workload.concat_node import ConcatNode


class ConcatParser(ONNXOperatorParser):
class ConcatParser(OnnxOperatorParser):
"""Parses an onnx gather operator into a ConcatNode."""

def run(self):
25 changes: 5 additions & 20 deletions stream/classes/io/onnx/conv.py
@@ -2,37 +2,22 @@
from math import ceil
from typing import Any

from onnx import ModelProto, NodeProto
from zigzag.parser.onnx.ONNXOperatorParser import ONNXOperatorParser
from zigzag.parser.onnx.utils import (
get_attribute_ints_with_name,
get_node_input_output_dimension_shapes,
)
from zigzag.parser.workload_factory import LayerNodeFactory

from stream.classes.hardware.architecture.accelerator import Accelerator
from stream.classes.io.onnx.operator_parser import OnnxOperatorParser
from stream.classes.workload.computation_node import ComputationNode

logger = logging.getLogger(__name__)


class ConvParser(ONNXOperatorParser):
class ConvParser(OnnxOperatorParser):
"""Parser for ONNX Conv and QLinearConv nodes into LayerNode."""

def __init__(
self,
node_id: int,
node: NodeProto,
nodes_outputs: dict[int, Any],
mapping_data: list[dict[str, Any]],
onnx_model: ModelProto,
accelerator: Accelerator,
) -> None:
super().__init__(node_id, node, nodes_outputs, onnx_model)
self.onnx_model = onnx_model
self.mapping_data = mapping_data
self.accelerator = accelerator
self.op_type = "conv"
OP_TYPE = "conv"

def run(self) -> ComputationNode:
"""Run the parser and return the created LayerNode object."""
@@ -57,7 +42,7 @@ def get_layer_node_input_format(
data: dict[str, Any] = {}
data["id"] = self.node_id
data["name"] = f"Layer{self.node_id}"
data["operator_type"] = self.op_type
data["operator_type"] = ConvParser.OP_TYPE
# IMPORTANT: If any of the input loops require padding, they should be defined as the rightmost dimensions in
# the equation. This is because we construct the dimensionality order and then add the padding to those last
# dimensions in the order
@@ -168,6 +153,6 @@ def generate_layer_node_for_conv(self):
node_attr=node_attrs,
input_names=node_input_names,
output_names=node_output_names,
op_type=self.op_type,
op_type=ConvParser.OP_TYPE,
operand_tensor_reshape=None,
)
5 changes: 2 additions & 3 deletions stream/classes/io/onnx/default.py
@@ -1,9 +1,8 @@
from zigzag.parser.onnx.ONNXOperatorParser import ONNXOperatorParser

from stream.classes.io.onnx.operator_parser import OnnxOperatorParser
from stream.classes.workload.dummy_node import DummyNode


class DefaultNodeParser(ONNXOperatorParser):
class DefaultNodeParser(OnnxOperatorParser):
"""Parse an ONNX node into a DummyNode."""

def run(self):
5 changes: 2 additions & 3 deletions stream/classes/io/onnx/elementwise.py
@@ -1,9 +1,8 @@
from zigzag.parser.onnx.ONNXOperatorParser import ONNXOperatorParser

from stream.classes.io.onnx.operator_parser import OnnxOperatorParser
from stream.classes.workload.elementwise_node import ElementwiseNode


class ElementwiseParser(ONNXOperatorParser):
class ElementwiseParser(OnnxOperatorParser):
"""Parser for onnx operators that perform an elementwise operation on two input tensors into a single output tensor.
For example, an Add operator adds two tensors together in every position into one output tensor.
"""
4 changes: 2 additions & 2 deletions stream/classes/io/onnx/flatten.py
@@ -1,10 +1,10 @@
from zigzag.parser.onnx.ONNXOperatorParser import ONNXOperatorParser
from zigzag.parser.onnx.utils import get_attribute_ints_with_name

from stream.classes.io.onnx.operator_parser import OnnxOperatorParser
from stream.classes.workload.flatten_node import FlattenNode


class FlattenParser(ONNXOperatorParser):
class FlattenParser(OnnxOperatorParser):
"""Parses an onnx flatten operator into a FlattenNode."""

def run(self):
4 changes: 2 additions & 2 deletions stream/classes/io/onnx/gather.py
@@ -1,10 +1,10 @@
from onnx import numpy_helper
from zigzag.parser.onnx.ONNXOperatorParser import ONNXOperatorParser

from stream.classes.io.onnx.operator_parser import OnnxOperatorParser
from stream.classes.workload.gather_node import GatherNode


class GatherParser(ONNXOperatorParser):
class GatherParser(OnnxOperatorParser):
"""Parses an onnx gather operator into a GatherNode."""

def run(self):
14 changes: 10 additions & 4 deletions stream/classes/io/onnx/gemm.py
@@ -5,24 +5,30 @@
from zigzag.parser.onnx.GemmParser import GemmParser as GemmParserZigZag

from stream.classes.hardware.architecture.accelerator import Accelerator
from stream.classes.io.onnx.operator_parser import OnnxOperatorParser
from stream.classes.workload.computation_node import ComputationNode

logger = logging.getLogger(__name__)


class GemmParser(GemmParserZigZag):
class GemmParser(GemmParserZigZag, OnnxOperatorParser):
"""Parses an ONNX Gemm operator into a ComputationNode"""

def __init__(
self,
node_id: int,
node: NodeProto,
nodes_outputs: dict[int, list[str]],
mapping_data: list[dict[str, Any]],
nodes_outputs: dict[int, Any],
onnx_model: ModelProto,
*,
mapping_data: list[dict[str, Any]],
accelerator: Accelerator,
) -> None:
super().__init__(node_id, node, nodes_outputs, mapping_data, onnx_model)
self.node_id = node_id
self.node = node
self.nodes_outputs = nodes_outputs
self.onnx_model = onnx_model
self.mapping_data = mapping_data
self.accelerator = accelerator

def run(self):
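Note the bare * in the new __init__: mapping_data and accelerator are now keyword-only. A generic illustration of that syntax (hypothetical function, not the real parser):

    def make_parser(node_id: int, *, mapping_data: list, accelerator: str):
        return node_id, mapping_data, accelerator

    make_parser(3, mapping_data=[], accelerator="acc0")  # OK
    # make_parser(3, [], "acc0") raises a TypeError (positional args not allowed after *)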
5 changes: 2 additions & 3 deletions stream/classes/io/onnx/lpnormalization.py
@@ -1,9 +1,8 @@
from zigzag.parser.onnx.ONNXOperatorParser import ONNXOperatorParser

from stream.classes.io.onnx.operator_parser import OnnxOperatorParser
from stream.classes.workload.lpnormalization_node import LpNormalizationNode


class LpNormalizationParser(ONNXOperatorParser):
class LpNormalizationParser(OnnxOperatorParser):
"""Parses an onnx reshape operator into a LpNormalizationNode."""

def __init__(self, node_id, node, nodes_outputs, mapping, onnx_model) -> None:
49 changes: 3 additions & 46 deletions stream/classes/io/onnx/matmul.py
@@ -1,52 +1,9 @@
import logging
from typing import Any

from onnx import ModelProto, NodeProto
from zigzag.parser.onnx.MatMulParser import MatMulParser as MatMulParserZigZag

from stream.classes.hardware.architecture.accelerator import Accelerator
from stream.classes.workload.computation_node import ComputationNode
from stream.classes.io.onnx.gemm import GemmParser

logger = logging.getLogger(__name__)


class MatMulParser(MatMulParserZigZag):
"""Parses an ONNX MatMul operator into a ComputationNode"""

def __init__(
self,
node_id: int,
node: NodeProto,
nodes_outputs: dict[int, Any],
mapping_data: list[dict[str, Any]],
onnx_model: ModelProto,
accelerator: Accelerator,
) -> None:
super().__init__(node_id, node, nodes_outputs, mapping_data, onnx_model)
self.accelerator = accelerator

def run(self):
"""Run the parser"""
return self.generate_node()

def generate_node(self):
layer_node = self.generate_layer_node()
node_attrs = layer_node.extract_node_attr()

# Override spatial mapping by the one defined in the core's dataflows
core_allocation = node_attrs.core_allocation
spatial_mapping = self.accelerator.get_spatial_mapping_from_core(core_allocation)
node_attrs.spatial_mapping = spatial_mapping

# Get the node's input(s) and output(s) tensor names
node_input_names = list(self.node.input)
node_output_names = list(self.node.output)
return ComputationNode(
node_id=self.node_id,
node_name=self.node.name,
node_attr=node_attrs,
input_names=node_input_names,
output_names=node_output_names,
op_type=node_attrs.layer_type,
operand_tensor_reshape=None,
)
class MatMulParser(GemmParser):
"""! Parses an ONNX MatMul operator into a ComputationNode. Exactly the same as Gemm Parser"""