From ea00fe81030d7cb70c8e353c6222e223d97059b3 Mon Sep 17 00:00:00 2001 From: RobinGeens Date: Tue, 13 Aug 2024 12:03:52 +0200 Subject: [PATCH] more typehint fixes --- .../cost_model/communication_manager.py | 30 ++++----- stream/classes/cost_model/cost_model.py | 2 +- stream/classes/cost_model/scheduler.py | 41 +++++++------ .../hardware/architecture/accelerator.py | 61 ++++++++++++------- .../architecture/noc/communication_link.py | 31 ++++++---- .../classes/stages/IntraCoreMappingStage.py | 7 ++- stream/classes/workload/computation_node.py | 27 ++++---- stream/classes/workload/onnx_workload.py | 26 ++++++++ stream/classes/workload/tensor.py | 19 +++--- 9 files changed, 147 insertions(+), 97 deletions(-) diff --git a/stream/classes/cost_model/communication_manager.py b/stream/classes/cost_model/communication_manager.py index 3128e64..96c2f4a 100644 --- a/stream/classes/cost_model/communication_manager.py +++ b/stream/classes/cost_model/communication_manager.py @@ -2,7 +2,6 @@ from math import ceil from typing import TYPE_CHECKING, Any -import networkx as nx from zigzag.datatypes import Constants, MemoryOperand from zigzag.hardware.architecture.Core import Core @@ -12,6 +11,7 @@ if TYPE_CHECKING: from stream.classes.hardware.architecture.accelerator import Accelerator + from stream.classes.hardware.architecture.noc.communication_link import CommunicationLink class CommunicationEvent: @@ -50,7 +50,9 @@ class CommunicationLinkEvent: * the percentage of the link bandwidth used """ - def __init__(self, type, start, end, tensors, energy, activity=100) -> None: + def __init__( + self, type: str, start: int, end: int, tensors: list[Tensor], energy: float, activity: float = 100 + ) -> None: self.type = type self.start = start self.end = end @@ -89,12 +91,10 @@ def __init__(self, accelerator: "Accelerator") -> None: def get_shortest_paths(self): # For each core pair save a shortest path - shortest_paths: dict[tuple[Core, Core], Any] = {} - for producer_core, consumer_core in itertools.product( - self.accelerator.cores.nodes(), self.accelerator.cores.nodes() - ): - shortest_paths[(producer_core, consumer_core)] = nx.shortest_path( - self.accelerator.cores, producer_core, consumer_core + shortest_paths: dict[tuple[Core, Core], list[Core]] = {} + for producer_core, consumer_core in itertools.product(self.accelerator.core_list, self.accelerator.core_list): + shortest_paths[(producer_core, consumer_core)] = self.accelerator.cores.shortest_path( + producer_core, consumer_core ) return shortest_paths @@ -137,7 +137,7 @@ def update_links( tensor: Tensor, sender: Core | int, receiver: Core | int, - receiver_memory_operand: str, + receiver_memory_operand: MemoryOperand, start_timestep: int, duration: int, ) -> tuple[int, int, float, float]: @@ -163,7 +163,7 @@ def update_links( receiver = self.accelerator.get_core(receiver) links = self.get_links_for_pair(sender, receiver) if not links: # When sender == receiver - return 0, 0 + return 0, 0, 0, 0 cles = [ CommunicationLinkEvent( @@ -214,7 +214,7 @@ def block_offchip_links( duration (int): The duration of the blocking in cycles. cn (ComputationNode): The computational node for which we are blocking the links. 
""" - links_to_block = dict() + links_to_block: dict["CommunicationLink", int] = {} core = self.accelerator.get_core(core_id) offchip_core = self.accelerator.get_core(self.accelerator.offchip_core_id) if Constants.OUTPUT_MEM_OP in too_large_operands: @@ -230,7 +230,7 @@ def block_offchip_links( if not too_large_operands: return start_timestep # Get the tensors for which we are blocking based on the operands - tensors = [] + tensors: list[Tensor] = [] for mem_op in too_large_operands: layer_op = cn.memory_operand_links.mem_to_layer_op(mem_op) tensors.append(cn.operand_tensors[layer_op]) @@ -242,7 +242,9 @@ def block_offchip_links( link.block(block_start, duration, tensors, activity=req_bw) return block_start - def get_links_idle_window(self, links: dict, best_case_start: int, duration: int, tensors: list[Tensor]) -> int: + def get_links_idle_window( + self, links: dict["CommunicationLink", int], best_case_start: int, duration: int, tensors: list[Tensor] + ) -> int: """Return the timestep at which tensor can be transfered across the links. Both links must have an idle window large enough for the transfer. The timestep must be greater than or equal to best_case_start. @@ -254,7 +256,7 @@ def get_links_idle_window(self, links: dict, best_case_start: int, duration: int tensors (list): The tensors to be transferred. Used to broadcast from previous transfer. """ assert len(links) > 0 - idle_intersections = [] + idle_intersections: list[tuple[int, int]] = [] for i, (link, req_bw) in enumerate(links.items()): req_bw = min(req_bw, link.bandwidth) # ceil the bw windows = link.get_idle_window(req_bw, duration, best_case_start, tensors) diff --git a/stream/classes/cost_model/cost_model.py b/stream/classes/cost_model/cost_model.py index 44e0021..38f082b 100644 --- a/stream/classes/cost_model/cost_model.py +++ b/stream/classes/cost_model/cost_model.py @@ -16,7 +16,7 @@ def __init__( workload: ComputationNodeWorkload, accelerator: Accelerator, operands_to_prefetch: list[str], - scheduling_order: list[int], + scheduling_order: list[tuple[int, int]], ) -> None: # Initialize the SCME by setting the workload graph to be scheduled self.workload = workload diff --git a/stream/classes/cost_model/scheduler.py b/stream/classes/cost_model/scheduler.py index 27b6f8a..3417715 100644 --- a/stream/classes/cost_model/scheduler.py +++ b/stream/classes/cost_model/scheduler.py @@ -1,25 +1,27 @@ import logging from operator import itemgetter +from typing import TYPE_CHECKING -from networkx import DiGraph from zigzag.datatypes import Constants, LayerOperand, MemoryOperand from zigzag.hardware.architecture.Core import Core -from stream.classes.hardware.architecture.accelerator import Accelerator from stream.classes.workload.computation_node import ComputationNode from stream.classes.workload.onnx_workload import ComputationNodeWorkload from stream.classes.workload.tensor import Tensor +if TYPE_CHECKING: + from stream.classes.hardware.architecture.accelerator import Accelerator + logger = logging.getLogger(__name__) -def initialize_priorities(workload: ComputationNodeWorkload, accelerator: Accelerator): +def initialize_priorities(workload: ComputationNodeWorkload, accelerator: "Accelerator"): for n in workload.node_list: for tensor in n.operand_tensors.values(): tensor.initialize_instance_priorities(workload, n, accelerator) -def initialize_offchip_tensors(workload: ComputationNodeWorkload, accelerator: Accelerator): +def initialize_offchip_tensors(workload: ComputationNodeWorkload, accelerator: "Accelerator"): offchip_core_id 
= accelerator.offchip_core_id assert offchip_core_id is not None, "No offchip core found for this accelerator" offchip_core = accelerator.get_core(offchip_core_id) @@ -44,7 +46,7 @@ def initialize_offchip_tensors(workload: ComputationNodeWorkload, accelerator: A ) -def prefetch_constant_operands(G: ComputationNodeWorkload, accelerator: Accelerator, operands_to_prefetch: list[str]): +def prefetch_constant_operands(G: ComputationNodeWorkload, accelerator: "Accelerator", operands_to_prefetch: list[str]): operands_to_prefetch_converted = [LayerOperand(x) for x in operands_to_prefetch] total_cn_offchip_link_energy = 0 total_cn_offchip_memory_energy = 0 @@ -78,11 +80,14 @@ def prefetch_constant_operands(G: ComputationNodeWorkload, accelerator: Accelera ) -def get_best_candidate(candidates: list[ComputationNode], scheduling_order: list[int]) -> tuple[ComputationNode, int]: +def get_best_candidate( + candidates: list[tuple[int, ComputationNode]], scheduling_order: list[tuple[int, int]] +) -> tuple[ComputationNode, int]: # If this core doesn't have any candidates, continue to the next core if not candidates: raise ValueError("There are no candidates to schedule.") preds_ends, cn_candidates = zip(*candidates) + cn_candidates: list[ComputationNode] idxs = [scheduling_order.index((n.id, n.sub_id)) for n in cn_candidates] best_candidate_idx = idxs.index(min(idxs)) best_candidate = cn_candidates[best_candidate_idx] @@ -132,7 +137,7 @@ def get_tensors_needed_for_node(node: ComputationNode, G: ComputationNodeWorkloa def clear_memories( - accelerator: Accelerator, + accelerator: "Accelerator", core: Core, memory_operands: list[MemoryOperand], timestep: int, @@ -158,7 +163,7 @@ def clear_memories( def decrease_priority( tensors: list[Tensor], tensors_operands: list[MemoryOperand], - accelerator: Accelerator, + accelerator: "Accelerator", node: ComputationNode, ): for tensor_used_by_node, tensor_memory_operand in zip(tensors, tensors_operands): @@ -171,18 +176,15 @@ def decrease_priority( def check_for_removal( tensors: list[Tensor], - accelerator: Accelerator, + accelerator: "Accelerator", node: ComputationNode, - G: DiGraph, + G: ComputationNodeWorkload, timestep: int, ): offchip_core_id = accelerator.offchip_core_id for tensor_used_by_node in tensors: if tensor_used_by_node.get_total_priority() == 0: - ( - instances_storing_tensor, - _, - ) = accelerator.memory_manager.find_tensor_in_top_instances(tensor_used_by_node) + instances_storing_tensor, _ = accelerator.memory_manager.find_tensor_in_top_instances(tensor_used_by_node) for instance_storing_tensor in instances_storing_tensor: core_ids_of_instance = [ core.id for core in accelerator.memory_manager.cores_per_top_instance[instance_storing_tensor] @@ -220,11 +222,11 @@ def check_for_removal( def schedule_graph( G: ComputationNodeWorkload, - accelerator: Accelerator, + accelerator: "Accelerator", cores_idle_from: dict[int, int] | None = None, operands_to_prefetch: list[str] = [], - scheduling_order=None, -): + scheduling_order: list[tuple[int, int]] | None = None, +) -> tuple[int, float, float, float, float, float, float, float, float, float]: """Schedule the nodes of graph G across the cores in the system. Each node should have a core_allocation and runtime set. 
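Note on the `scheduling_order: list[tuple[int, int]]` hint introduced above: the order is a flat list of (node id, sub id) pairs, and get_best_candidate picks the ready node whose pair appears earliest in that list. A minimal, self-contained sketch of that ranking rule as read from this hunk (FakeNode and the literal values below are illustrative stand-ins, not part of Stream):

from dataclasses import dataclass

@dataclass(frozen=True)
class FakeNode:
    # Stand-in for ComputationNode: only the two fields the ranking rule needs.
    id: int
    sub_id: int

def pick_best(candidates: list[tuple[int, FakeNode]],
              scheduling_order: list[tuple[int, int]]) -> tuple[FakeNode, int]:
    # Mirrors get_best_candidate: the candidate with the lowest index in
    # scheduling_order wins; its predecessors' end time is returned alongside it.
    preds_ends, nodes = zip(*candidates)
    idxs = [scheduling_order.index((n.id, n.sub_id)) for n in nodes]
    best = idxs.index(min(idxs))
    return nodes[best], preds_ends[best]

order = [(0, 0), (0, 1), (1, 0)]
ready = [(5, FakeNode(1, 0)), (3, FakeNode(0, 1))]
assert pick_best(ready, order) == (FakeNode(0, 1), 3)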
@@ -264,7 +266,7 @@ def schedule_graph( # Put the very first nodes of a layer that doesn't have any incoming edges as the first candidates for source_node in (n for n, d in G.in_degree() if d == 0): core_allocation = source_node.chosen_core_allocation - candidates.append((cores_idle_from[core_allocation], source_node)) + candidates.append((cores_idle_from[core_allocation], source_node)) # type: ignore # Get all the nodes with no successors that produce final outputs, used for off-loading final outputs sink_layers = sorted(set(n.id for n, d in G.out_degree() if d == 0)) @@ -272,6 +274,7 @@ def schedule_graph( # Get the offchip core id and core offchip_core_id = accelerator.offchip_core_id + assert offchip_core_id is not None offchip_core = accelerator.get_core(offchip_core_id) # Schedule preparation: @@ -433,7 +436,7 @@ def schedule_graph( # Only push back sink node outputs if they're generated and stored on the core if best_candidate.output_operand not in best_candidate.too_large_operands: ( - current_timestep, + _, link_energy_cost, memory_energy_cost, ) = accelerator.remove( diff --git a/stream/classes/hardware/architecture/accelerator.py b/stream/classes/hardware/architecture/accelerator.py index f58ef6e..161626c 100644 --- a/stream/classes/hardware/architecture/accelerator.py +++ b/stream/classes/hardware/architecture/accelerator.py @@ -1,6 +1,6 @@ from math import ceil -from typing import Iterator +import networkx as nx from networkx import DiGraph from zigzag.datatypes import MemoryOperand from zigzag.hardware.architecture.Core import Core @@ -9,9 +9,21 @@ from stream.classes.cost_model.communication_manager import CommunicationManager from stream.classes.cost_model.memory_manager import MemoryManager +from stream.classes.workload.computation_node import ComputationNode from stream.classes.workload.tensor import Tensor +class CoreGraph(DiGraph): + """Represents the core structure of an accelerator""" + + @property + def node_list(self) -> list[Core]: + return list(self.nodes()) # type: ignore + + def shortest_path(self, producer: Core, consumer: Core) -> list[Core]: + return nx.shortest_path(self, producer, consumer) # type: ignore + + class Accelerator: """ The Accelerator class houses a set of Cores with an additional Global Buffer. @@ -22,7 +34,7 @@ class Accelerator: def __init__( self, name: str, - cores: DiGraph, + cores: CoreGraph, offchip_core_id: int | None = None, ): self.name = name @@ -48,35 +60,35 @@ def get_core(self, core_id: int) -> Core: Return the core with id 'core_id'. Raises ValueError() when a core_id is not found in the available cores. """ - core = next((core for core in self.core_iterator if core.id == core_id), None) + core = next((core for core in self.core_list if core.id == core_id), None) if core is None: raise ValueError(f"Requested core with id {core_id} is not present in accelerator.") return core - @property - def core_iterator(self) -> Iterator[Core]: - return self.cores.nodes() # type: ignore + # @property + # def core_iterator(self) -> Iterator[Core]: + # return self.cores.nodes() # type: ignore @property - def core_list(self) -> list[Core]: - return list(self.cores.nodes()) # type: ignore + def core_list(self): + return self.cores.node_list def spawn( self, tensor: Tensor, core: Core, - memory_op: str, + memory_op: MemoryOperand, initial_timestep: int, available_timestep: int, ): """Spawns a tensor on a core. Args: - tensor (Tensor): The tensor to be spawned. - core (Core): The core on which to spawn the tensor. 
- memory_op (str): The memory operand on the core where the tensor will spawn. - initial_timestep (int): The timestep at which space will be reserved for the tensor. - available_timestep (int): The timestep at which the tensor will become available. Different from + tensor: The tensor to be spawned. + core: The core on which to spawn the tensor. + memory_op: The memory operand on the core where the tensor will spawn. + initial_timestep: The timestep at which space will be reserved for the tensor. + available_timestep: The timestep at which the tensor will become available. Different from initial_timestep when it is transferred. """ self.memory_manager.add_tensor_to_core(tensor, core, initial_timestep, available_timestep, memory_op) @@ -93,6 +105,7 @@ def remove( timestep (int): The timestep to remove the tensor at. write_back_to_offchip (bool, optional): Write the tensor to offchip before removal. Defaults to False. """ + assert self.offchip_core_id is not None ################################# STEP 1 ################################# # Transfer the tensor to off-chip if required and not present there link_energy_cost = 0 @@ -234,7 +247,7 @@ def transfer_tensor_to_core( tensor_operand: MemoryOperand, non_evictable_tensors: list[Tensor], sending_core_id: int | None = None, - ): + ) -> tuple[int, float, float, float, float, bool]: """ Transfer a tensor to a given core id. If the tensor is already present on the receiving core, nothing happens. @@ -275,10 +288,7 @@ def transfer_tensor_to_core( tensor.equality_hash() ] else: - ( - _, - available_since_timesteps, - ) = self.find_tensor_in_top_instances(tensor) + (_, available_since_timesteps) = self.find_tensor_in_top_instances(tensor) # Pick the core that has stored the tensor the longest available_since_timestep = min(available_since_timesteps.values()) storing_instance = next( @@ -323,9 +333,7 @@ def transfer_tensor_to_core( links, evictions_complete_timestep, transfer_duration, - [ - tensor, - ], + [tensor], ) transfer_end = transfer_start + transfer_duration ################################# STEP 5 ################################# @@ -401,7 +409,14 @@ def get_memory_energy_cost_of_transfer( return sender_energy + receiver_energy - def block_offchip_links(self, too_large_operands, core_id, start_timestep, duration, cn) -> int: + def block_offchip_links( + self, + too_large_operands: list[MemoryOperand], + core_id: int, + start_timestep: int, + duration: int, + cn: ComputationNode, + ) -> int: return self.communication_manager.block_offchip_links(too_large_operands, core_id, start_timestep, duration, cn) def contains_tensor(self, tensor: Tensor, top_instance: int | MemoryInstance): diff --git a/stream/classes/hardware/architecture/noc/communication_link.py b/stream/classes/hardware/architecture/noc/communication_link.py index 884ff13..e1667b3 100644 --- a/stream/classes/hardware/architecture/noc/communication_link.py +++ b/stream/classes/hardware/architecture/noc/communication_link.py @@ -1,12 +1,21 @@ +from typing import TYPE_CHECKING + import numpy as np from stream.classes.cost_model.communication_manager import CommunicationLinkEvent +if TYPE_CHECKING: + from zigzag.hardware.architecture.Core import Core + + from stream.classes.workload.tensor import Tensor + class CommunicationLink: """Represents a fixed-bandwidth communication link used to communicate between two cores.""" - def __init__(self, sender, receiver, bandwidth, unit_energy_cost, bidirectional=False) -> None: + def __init__( + self, sender: "Core", receiver: "Core", 
bandwidth: int, unit_energy_cost: float, bidirectional: bool = False + ) -> None: self.sender = sender self.receiver = receiver self.bandwidth = bandwidth @@ -36,8 +45,8 @@ def __hash__(self) -> int: ) ) - def __eq__(self, other) -> bool: - return (self.sender, self.receiver, self.bandwidth) == ( + def __eq__(self, other: object) -> bool: + return isinstance(other, CommunicationLink) and (self.sender, self.receiver, self.bandwidth) == ( other.sender, other.receiver, other.bandwidth, @@ -54,9 +63,9 @@ def transfer(self, cle: CommunicationLinkEvent) -> float: The transfer can take longer than necessary for this link if another lower-bandwidth link is involved. Args: - tensor (Tensor): The tensor to be transferred. - start (int): The timestep in clock cyles to start the transfer. - duration (int): The duration of the transfer. + tensor : The tensor to be transferred. + start : The timestep in clock cyles to start the transfer. + duration : The duration of the transfer. Returns: int: The end time when communication on this link is finished @@ -69,7 +78,7 @@ def block( self, start: int, duration: int, - tensors: list, + tensors: list["Tensor"], activity: int = 100, ): """Block this communication link from start timestep for a given duration. @@ -121,12 +130,12 @@ def update_activity(self, event: CommunicationLinkEvent): self.tensors[tensor] = self.tensors.get(tensor, []) + [event] self.events.append(event) - def get_idle_window(self, activity, duration, earliest_t, tensors): + def get_idle_window(self, activity: float, duration: int, earliest_t: int, tensors: list["Tensor"]): """ Get the earliest time window of duration 'duration' from 'earliest_t' with atleast 'activity' percent available. """ - valid_windows = [] + valid_windows: list[tuple[int, int]] = [] ## Check if this tensor has already been transferred on this link before # If so, check duration and earliest timestep requirements of this call for tensor in tensors: @@ -154,11 +163,11 @@ def get_idle_window(self, activity, duration, earliest_t, tensors): idxs.append(len(updated_ts) - 1) start = earliest_t for idx in idxs: - end = updated_ts[idx] + end: int = updated_ts[idx] if end - start >= duration: valid_windows.append((start, end)) try: - start = updated_ts[idx + 1] + start: int = updated_ts[idx + 1] except IndexError: break if not valid_windows: diff --git a/stream/classes/stages/IntraCoreMappingStage.py b/stream/classes/stages/IntraCoreMappingStage.py index 2087719..237ab83 100644 --- a/stream/classes/stages/IntraCoreMappingStage.py +++ b/stream/classes/stages/IntraCoreMappingStage.py @@ -98,7 +98,7 @@ def run(self): for core_id in core_ids: core = self.accelerator.get_core(core_id) # Offchip memory core doesn't have operational units - if core.operational_array.total_area == 0: + if core.operational_array.total_unit_count == 0: continue # It's possible this node might not fully fit within the core's top level memories. If so, we update # the core @@ -148,7 +148,8 @@ def run(self): ) answers = main_stage.run() assert len(answers) == 1, "IntraCoreMappingStage's subflow returned more than one CME" - cme = answers[0][0] + cme: CostModelEvaluation = answers[0][0] # type: ignore + # TODO should this be `chosen_core_allocation`? 
node.core_allocation = None # Reset the node's core allocation self.node_hw_performances[node][core] = cme self.save_node_hw_performances() # Save the hw performances dict after every node is finished @@ -175,7 +176,7 @@ def visualize_node_hw_performances(self): if "visualize_node_hw_performances_path": # Get the scale factors scale_factors = { - n.id: len(list(cn for cn in self.workload if cn == n)) for n in self.node_hw_performances + n.id: len([cn for cn in self.workload.node_list if cn == n]) for n in self.node_hw_performances } # Run the visualization visualize_node_hw_performances_pickle( diff --git a/stream/classes/workload/computation_node.py b/stream/classes/workload/computation_node.py index 3038d2f..84f498f 100644 --- a/stream/classes/workload/computation_node.py +++ b/stream/classes/workload/computation_node.py @@ -23,9 +23,6 @@ class ComputationNode(LayerNode, Node): On top of that, some new information is added for correct dependency generation for the finer graph that is built when a layer is split into one and is a producer/consumer of another layer. - - Args: - LayerNode (_type_): _description_ """ def __init__( @@ -58,17 +55,16 @@ def __init__( self.sub_id = sub_id self.group = group_id - self.__hash_value = hash((self.id, self.sub_id)) + self.__static_hash_value = hash((self.id, self.sub_id)) self.operand_tensor_reshape = ( operand_tensor_reshape if operand_tensor_reshape is not None else self.get_operand_tensor_reshape_default() ) # Whether this ComputationNode produces a final output self.produces_final_output = produces_final_output - # self.loop_ranges: dict[str, tuple] = node_attrs.get( - # "loop_ranges", {dim: (0, size) for dim, size in self.loop_dim_size.items()} - # ) - self.loop_ranges: LoopRanges = {layer_dim: (0, size) for layer_dim, size in self.layer_dim_sizes.items()} + self.loop_ranges: LoopRanges = { # type: ignore + layer_dim: (0, size) for layer_dim, size in self.layer_dim_sizes.items() + } # adds pr dimensions loop ranges to self.loop_ranges self.calculate_pr_loop_ranges() @@ -129,16 +125,16 @@ def __str__(self): return f"ComputationNode{self.id}_{self.sub_id}" def __hash__(self) -> int: - """The hash operator of a node depending on its id. The id is a tuple that can be of variable depth. + """The hash operator of a node. Returns: - int: the pre-computed hash + the pre-computed hash """ - return self.__hash_value + return self.__static_hash_value def __eq__(self, other: object): """Fast equality comparison between two nodes""" - return self.__hash_value == hash(other) + return self.__static_hash_value == hash(other) def is_equal_extended(self, other: object) -> bool: """Compare the equality between two nodes. 
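Aside on the `__static_hash_value` rename in the hunk above: the hash is computed once from (id, sub_id) at construction time, and __eq__ simply compares against hash(other), presumably so that node lookups in the workload graph stay cheap and stable even when other node attributes change later. A stripped-down sketch of that pattern (MiniNode is an illustrative stand-in, not Stream's class):

class MiniNode:
    """Illustrative stand-in for ComputationNode's precomputed-hash scheme."""

    def __init__(self, node_id: int, sub_id: int):
        self.id = node_id
        self.sub_id = sub_id
        # Frozen once; later attribute changes never move the node to a
        # different hash bucket.
        self.__static_hash = hash((node_id, sub_id))

    def __hash__(self) -> int:
        return self.__static_hash

    def __eq__(self, other: object) -> bool:
        # Fast-path equality via the precomputed hash, as in the patched code.
        return self.__static_hash == hash(other)

a, b = MiniNode(3, 0), MiniNode(3, 0)
assert a == b and len({a, b}) == 1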
@@ -202,11 +198,10 @@ def calculate_pr_loop_ranges(self): # Assume that there is always 2 dimensions involved in the calculation of a pr dimension pr_dim_val_min = -padding_begin pr_dim_val_max = -padding_begin - for related_dimension, scaling_factor in related_dims_and_scalings.items(): + for related_dimension, scaling_factor in related_dims_and_scalings: pr_dim_val_min += scaling_factor * self.loop_ranges[related_dimension][0] - pr_dim_val_max += scaling_factor * ( - self.loop_ranges[related_dimension][1] - 1 - ) # convert to inclusive upper limit + # convert to inclusive upper limit + pr_dim_val_max += scaling_factor * (self.loop_ranges[related_dimension][1] - 1) pr_dim_val_max += 1 # convert to exclusive upper range self.loop_ranges[pr_dim] = (pr_dim_val_min, pr_dim_val_max) diff --git a/stream/classes/workload/onnx_workload.py b/stream/classes/workload/onnx_workload.py index 0cf500d..bac4d39 100644 --- a/stream/classes/workload/onnx_workload.py +++ b/stream/classes/workload/onnx_workload.py @@ -69,12 +69,38 @@ def in_edges( # type: ignore """Overwrite DiGraph method with type hints""" return super().in_edges(node, data) # type: ignore + @overload + def out_edges(self, node: T, data: Literal[True]) -> list[tuple[T, T, dict[str, Any]]]: + ... # type: ignore + + @overload + def out_edges(self, node: T, data: Literal[False]) -> list[tuple[T, T]]: + ... # type: ignore + + @overload + def out_edges(self, node: T) -> list[tuple[T, T]]: + ... # type: ignore + + def out_edges( # type: ignore + self, + node: T, + data: bool = False, + ) -> list[tuple[T, T]] | list[tuple[T, T, dict[str, Any]]]: + """Overwrite DiGraph method with type hints""" + return super().out_edges(node, data) # type: ignore + def in_degree(self) -> Iterator[tuple[T, int]]: # type: ignore return super().in_degree() # type:ignore def out_degree(self) -> Iterator[tuple[T, int]]: # type: ignore return super().out_degree() # type:ignore + def successors(self, node: T) -> Iterator[T]: # type: ignore + return super().successors(node) # type: ignore + + def predecessors(self, node: T) -> Iterator[T]: # type: ignore + return super().predecessors(node) # type: ignore + class ComputationNodeWorkload(WorkloadABC[ComputationNode]): """Workload graph with only ComputationNodes""" diff --git a/stream/classes/workload/tensor.py b/stream/classes/workload/tensor.py index 602f28b..530658c 100644 --- a/stream/classes/workload/tensor.py +++ b/stream/classes/workload/tensor.py @@ -1,11 +1,14 @@ from typing import TYPE_CHECKING -from networkx import DiGraph from zigzag.datatypes import LayerDim, LayerOperand if TYPE_CHECKING: + from zigzag.hardware.architecture.MemoryInstance import MemoryInstance + + from stream.classes.cost_model.memory_manager import MemoryManager from stream.classes.hardware.architecture.accelerator import Accelerator from stream.classes.workload.computation_node import ComputationNode + from stream.classes.workload.onnx_workload import ComputationNodeWorkload class Tensor: @@ -37,7 +40,7 @@ def __init__( self.loop_dimensions = loop_dimensions self.loop_ranges = loop_ranges self.base_priority: None | int = None # Will be set when we know how many successors this node has (static) - self.instance_priorities = {} + self.instance_priorities: dict[MemoryInstance, int] = {} self.id = (self.origin.id, self.origin.sub_id, layer_operand) def __str__(self) -> str: @@ -55,19 +58,13 @@ def __hash__(self) -> int: def __lt__(self, __o: object) -> bool: return isinstance(__o, Tensor) and self.size < __o.size - # def __eq__(self, __o: 
object) -> bool: - # return isinstance(__o, Tensor) and \ - # self.origin.id == __o.origin.id and \ - # self.layer_operand == __o.layer_operand and \ - # self.loop_ranges == __o.loop_ranges - def equality_hash(self): return hash((self.origin.id, self.layer_operand, self.loop_ranges)) def set_base_priorities(self, base_priority: int): self.base_priority = base_priority - def get_instance_priority(self, top_instance, memory_manager): + def get_instance_priority(self, top_instance: "MemoryInstance", memory_manager: "MemoryManager"): if top_instance in self.instance_priorities: return self.instance_priorities[top_instance] else: @@ -80,7 +77,9 @@ def get_instance_priority(self, top_instance, memory_manager): ) return not_storing_priority - def initialize_instance_priorities(self, G: DiGraph, node: "ComputationNode", accelerator: "Accelerator"): + def initialize_instance_priorities( + self, G: "ComputationNodeWorkload", node: "ComputationNode", accelerator: "Accelerator" + ): if self.layer_operand == node.output_operand: out_edges = [(succ, d) for n, succ, d in G.out_edges(node, data=True) if succ.id != n.id] for successor, data in out_edges:
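A closing note on the device used throughout this patch: imports of Accelerator, Core, Tensor, MemoryInstance and similar classes are moved under typing.TYPE_CHECKING and referenced as quoted annotations, which keeps the new hints visible to the type checker without creating import cycles at runtime. A generic sketch of the pattern, with placeholder module and class names rather than Stream's:

# node.py -- hypothetical module, not part of Stream
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by the type checker; at runtime this import never runs,
    # so node.py and graph.py may depend on each other without a cycle.
    from graph import Graph

class Node:
    def attach(self, graph: "Graph") -> None:
        # The quoted annotation is resolved lazily, so there is no runtime
        # dependency on graph.py.
        self.graph = graph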