Commit

add nb_real_predecessors to ComputationNode hash and has_same_performance
asyms committed Oct 17, 2024
1 parent b73bad5 commit f0bafb9
Showing 2 changed files with 55 additions and 28 deletions.
@@ -31,6 +31,8 @@
 
 logger = logging.getLogger(__name__)
 
+Edge = tuple[ComputationNode, ComputationNode, dict]
+
 
 class TensorDimensionMismatchException(Exception):
     """Facilitates error handling in case incorrect tensor dimensions are passed on"""
@@ -65,8 +67,9 @@ def __init__(
 
     def run(self):
         unique_finer_nodes: list[ComputationNode] = []
-        # For each node get all the finer nodes and set the intra edges
-        partitioned_workload = ComputationNodeWorkload()
+        # For each node get all the finer nodes and the edges between them
+        all_finer_nodes = []
+        all_finer_edges = []
         for node in self.workload.topological_sort():
             # If other node types shouldn't be included in finer node graph, add here
             if not isinstance(node, ComputationNode):
@@ -77,12 +80,10 @@ def run(self):
             logger.info(f"{node}: Generated {len(finer_nodes)} finer nodes.")
             self.finer_nodes_dict[node] = finer_nodes
             unique_finer_nodes += unique_nodes
-            # Compute the edges between nodes originating from one bigger node (intra-edges)
             intra_edges = self.get_intra_edges(finer_nodes)
-            partitioned_workload.add_edges_from(intra_edges)
-            # If there is only one finer node for this layer, add the node to the graph
-            if not intra_edges:
-                partitioned_workload.add_nodes_from(finer_nodes)
+            # Add the finer nodes and intra edges to the lists
+            all_finer_nodes += finer_nodes
+            all_finer_edges += intra_edges
 
         # Get all pairs of nodes that we have to extract inter edges for
         all_pairs = self.get_all_node_pairs(self.workload)
@@ -93,10 +94,18 @@ def run(self):
                 inter_edges = self.get_inter_edges_numpy(producer, consumer)
             else:
                 inter_edges = self.get_inter_edges_rtree(producer, consumer, finer_producers, finer_consumers)
-            partitioned_workload.add_edges_from(inter_edges)
+            all_finer_edges += inter_edges
+
+        # Set the base_priority value of all nodes
+        self.set_base_priority_of_nodes(all_finer_nodes, all_finer_edges)
 
-        # Set the base_priority value of all nodes in G
-        self.set_base_priority_of_nodes(partitioned_workload, self.finer_nodes_dict)
+        # Set the number of real predecessors of all nodes
+        self.set_nb_real_predecessors(all_finer_nodes, all_finer_edges)
+
+        # Construct the new finer workload graph
+        # The graph construction needs to happen after the base priority and nb_real_predecessors are set
+        partitioned_workload = ComputationNodeWorkload()
+        partitioned_workload.add_edges_from(all_finer_edges)
 
         logger.info(f"Finer graph: {partitioned_workload}.")

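The reordering above is the crux of the commit: nb_real_predecessors now feeds the node's precomputed hash, so nodes may only be inserted into the workload graph once that attribute is final. A minimal sketch of what goes wrong otherwise (a hypothetical `Node` stand-in, not Stream code):

```python
# Containers key objects by their hash at insertion time; mutating a
# hash-bearing attribute afterwards silently breaks membership lookups.

class Node:
    """Stand-in for ComputationNode: its hash depends on nb_real_predecessors."""

    def __init__(self, node_id: int):
        self.id = node_id
        self.nb_real_predecessors = None

    def __hash__(self):
        return hash((self.id, self.nb_real_predecessors))


node = Node(0)
graph_nodes = {node}           # stored under hash((0, None))
node.nb_real_predecessors = 3  # the hash is now hash((0, 3))
print(node in graph_nodes)     # False: lookup probes a different hash bucket
```

Building `partitioned_workload` last avoids exactly this, since `set_nb_real_predecessors` recomputes each node's static hash.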
@@ -779,26 +788,28 @@ def get_tensor_cns(
         return tensors_cns
 
     @staticmethod
-    def set_base_priority_of_nodes(
-        G: ComputationNodeWorkload, finer_nodes_dict: dict[ComputationNode, list[ComputationNode]]
-    ):
+    def set_base_priority_of_nodes(nodes: list[ComputationNode], edges: list[Edge]):
         """Set the base_priority of all stored tensors of variable operands in every node in finer_nodes
         based on the amount of real (excluding same layer edges) edges.
         Args:
-            finer_nodes (list): List of the nodes for which to set the tensors' base_priority
+            nodes (list): List of nodes.
+            edges (list): List of edges in the form of (producer, consumer, data).
         """
-        nb_nodes_per_layer_id = {layer.id: len(finer_nodes_dict[layer]) for layer in finer_nodes_dict.keys()}
-        nb_seen_nodes_per_layer_id = {layer_id: 0 for layer_id in nb_nodes_per_layer_id.keys()}
-        for node in G.topological_sort():
-            layer_id = node.id
-            for layer_operand in node.layer_operands:
-                tensor: Tensor = node.operand_tensors[layer_operand]
-                if layer_operand == node.output_operand:
-                    # Look at the amount of successors from different layers
-                    successors = [succ for succ in G.successors(node) if succ.id != layer_id]
-                    tensor.set_base_priorities(len(successors))
-            nb_seen_nodes_per_layer_id[layer_id] += 1
+        for node in nodes:
+            output_operand = node.output_operand
+            output_tensor = node.operand_tensors[output_operand]
+            successors = [cons for prod, cons, _ in edges if prod == node]
+            output_tensor.set_base_priorities(len(successors))
+
+    @staticmethod
+    def set_nb_real_predecessors(nodes: list[ComputationNode], edges: list[Edge]):
+        """Set the number of real predecessors for each node in the graph.
+        A real predecessor is a node that is not in the same layer as the node itself.
+        """
+        for node in nodes:
+            nb_real_predecessors = [prod for prod, cons, _ in edges if cons == node and prod.id != cons.id]
+            node.set_nb_real_predecessors(len(nb_real_predecessors))
 
     def get_weight_capacities(self):
         # Get the weight capacity of all cores
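Both helpers now iterate a plain edge list rather than a graph. A toy illustration of the (producer, consumer, data) convention they rely on (a hypothetical `FakeNode`, not Stream code): the base priority counts all consumers of a node, while real predecessors exclude producers from the same layer.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeNode:
    id: int      # layer id
    sub_id: int  # index of the tile within the layer


n0, n1 = FakeNode(id=0, sub_id=0), FakeNode(id=0, sub_id=1)  # two tiles of layer 0
m0 = FakeNode(id=1, sub_id=0)                                # one tile of layer 1

edges = [
    (n0, n1, {}),  # intra-layer edge: same id on both ends
    (n0, m0, {}),  # inter-layer edge: layer 0 feeds layer 1
    (n1, m0, {}),
]

# set_base_priority_of_nodes counts the consumers of each producer:
print(len([cons for prod, cons, _ in edges if prod == n0]))  # 2

# set_nb_real_predecessors counts only producers from other layers (prod.id != cons.id):
print(len([prod for prod, cons, _ in edges if cons == m0 and prod.id != cons.id]))  # 2
print(len([prod for prod, cons, _ in edges if cons == n1 and prod.id != cons.id]))  # 0
```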
20 changes: 18 additions & 2 deletions stream/workload/computation/computation_node.py
@@ -3,6 +3,7 @@
 from typing import TypeAlias
 
 from zigzag.datatypes import Constants, LayerDim, LayerOperand, MemoryOperand
+from zigzag.utils import hash_sha512
 from zigzag.visualization.results.plot_cme import shorten_onnx_layer_name
 from zigzag.workload.layer_attributes import (
     LayerPadding,
@@ -84,7 +85,6 @@ def __init__(
 
         self.sub_id = sub_id
         self.group = group_id
-        self._static_hash_value = self.__compute_static_hash()
         self.operand_tensor_reshape = (
             operand_tensor_reshape if operand_tensor_reshape is not None else self.get_operand_tensor_reshape_default()
         )
@@ -102,6 +102,12 @@ def __init__(
         self.get_node_operand = self.memory_operand_links.mem_to_layer_op
         self.extract_node_info = self.extract_layer_info
 
+        # The number of real predecessors is saved to deal with edge cases where some nodes of the same layer
+        # have differing predecessors. It is used to hash the node and to get an accurate count of unique nodes.
+        # It should be set after the node is created, once the number of predecessors is known.
+        self.nb_real_predecessors = None
+        self._static_hash_value = self.__compute_static_hash()
+
         try:
             self.fusion_partition_dims = ComputationNode.FUSION_DIM_MAPPING[op_type]
         except KeyError:
@@ -151,14 +157,15 @@ def short_name(self) -> str:
     def __compute_static_hash(self):
         """Return a value that can be used to identify unique nodes in sets, dicts and equality. It is pre-computed at
         initialization time to speed up dict lookup and instance equality"""
-        return hash(
+        return hash_sha512(
             (
                 self.layer_dim_sizes,
                 frozenset(self.dimension_relations),
                 self.operand_precision,
                 self.memory_operand_links,
                 self.id,
                 self.sub_id,
+                self.nb_real_predecessors,
             )
         )

@@ -190,6 +197,7 @@ def has_same_performance(self, other: object) -> bool:
             be equal.
         - memory_operand_links: The link between memory operand (paths in mem hierarchy) and this node's operands
           accurate knowledge of the number of unique nodes.
+        - nb_real_predecessors: The number of predecessors of the node. This impacts the required memory size.
 
         Args:
             other (Node): The other node to compare this node with
@@ -204,6 +212,7 @@ def has_same_performance(self, other: object) -> bool:
             and self.operand_precision == other.operand_precision
             and self.memory_operand_links == other.memory_operand_links
             and self.id == other.id
+            and self.nb_real_predecessors == other.nb_real_predecessors
             # NOTE: don't include sub_id
         )
@@ -274,3 +283,10 @@ def extract_inter_core_mapping_attr(self):
             inter_core_tiling=self.inter_core_tiling,
         )
         return deepcopy(mapping_attr)
+
+    def get_nb_real_predecessors(self):
+        return self.nb_real_predecessors
+
+    def set_nb_real_predecessors(self, nb_real_predecessors: int):
+        self.nb_real_predecessors = nb_real_predecessors
+        self._static_hash_value = self.__compute_static_hash()

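Taken together, the new setter and the extended hash change how duplicate nodes are detected: two same-layer nodes that differ only in their number of real predecessors no longer collapse into one unique node. A condensed sketch of that behavior (a hypothetical `FakeCN` mirroring only the relevant fields, not the real ComputationNode API):

```python
class FakeCN:
    """Stand-in mirroring the hash/equality fields this commit touches."""

    def __init__(self, layer_id: int, sub_id: int):
        self.id = layer_id
        self.sub_id = sub_id
        self.nb_real_predecessors = None
        self._static_hash_value = self._compute_static_hash()

    def _compute_static_hash(self):
        # sub_id is in the hash; nb_real_predecessors now is too
        return hash((self.id, self.sub_id, self.nb_real_predecessors))

    def set_nb_real_predecessors(self, n: int):
        self.nb_real_predecessors = n
        self._static_hash_value = self._compute_static_hash()

    def has_same_performance(self, other: "FakeCN") -> bool:
        # sub_id is deliberately excluded; nb_real_predecessors is now included
        return self.id == other.id and self.nb_real_predecessors == other.nb_real_predecessors


a, b = FakeCN(0, 0), FakeCN(0, 1)
a.set_nb_real_predecessors(1)
b.set_nb_real_predecessors(2)          # e.g. a boundary tile with an extra producer
print(a.has_same_performance(b))       # False: b needs its own cost evaluation
print(hash((a.id, a.sub_id, a.nb_real_predecessors)) ==
      hash((b.id, b.sub_id, b.nb_real_predecessors)))  # False as well
```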