diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py index f241fe227..2db81f1bc 100644 --- a/pytato/visualization/dot.py +++ b/pytato/visualization/dot.py @@ -27,10 +27,10 @@ """ +import gc import html from functools import partial from typing import ( - TYPE_CHECKING, Any, Callable, Mapping, @@ -56,6 +56,7 @@ Stack, ) from pytato.codegen import normalize_outputs +from pytato.distributed.nodes import DistributedSendRefHolder from pytato.distributed.partition import ( DistributedGraphPart, DistributedGraphPartition, @@ -67,10 +68,6 @@ from pytato.transform import ArrayOrNames, CachedMapper, InputGatherer -if TYPE_CHECKING: - from pytato.distributed.nodes import DistributedSendRefHolder - - __doc__ = """ .. currentmodule:: pytato @@ -161,7 +158,8 @@ def emit_subgraph(sg: _SubgraphTree) -> None: class _DotNodeInfo: title: str fields: dict[str, Any] - edges: dict[str, ArrayOrNames | FunctionDefinition] + edges: dict[str, ArrayOrNames | FunctionDefinition | + tuple[int | ArrayOrNames, ArrayOrNames], Array] def stringify_tags(tags: frozenset[Tag | None]) -> str: @@ -178,11 +176,22 @@ def stringify_shape(shape: ShapeType) -> str: return "(" + ", ".join(components) + ")" +def get_object_by_id(object_id: int) -> Any | ArrayOrNames: + """Find an object by its ID.""" + for obj in gc.get_objects(): + if id(obj) == object_id: + return obj + return None + + class ArrayToDotNodeInfoMapper(CachedMapper[ArrayOrNames]): - def __init__(self) -> None: - super().__init__() - self.node_to_dot: dict[ArrayOrNames, _DotNodeInfo] = {} + def __init__(self, count_duplicates: bool = False): + self.node_to_dot: dict[int | ArrayOrNames, _DotNodeInfo] = {} self.functions: set[FunctionDefinition] = set() + self.count_duplicates = count_duplicates + + def get_cache_key(self, expr: ArrayOrNames) -> int | ArrayOrNames: + return id(expr) if self.count_duplicates else expr def get_common_dot_info(self, expr: Array) -> _DotNodeInfo: title = type(expr).__name__ @@ -193,68 +202,93 @@ def get_common_dot_info(self, expr: Array) -> _DotNodeInfo: "non_equality_tags": expr.non_equality_tags, } - edges: dict[str, ArrayOrNames | FunctionDefinition] = {} + edges: dict[str, + ArrayOrNames | + FunctionDefinition | + tuple[ + int | AbstractResultWithNamedArrays | Array, Array + ]] = {} return _DotNodeInfo(title, fields, edges) - # type-ignore-reason: incompatible with supertype - def handle_unsupported_array(self, # type: ignore[override] - expr: Array) -> None: + def process_node(self, expr: ArrayOrNames) -> None: + if isinstance(expr, DataWrapper): + self.map_data_wrapper(expr) + elif isinstance(expr, IndexLambda): + self.map_index_lambda(expr) + elif isinstance(expr, Stack): + self.map_stack(expr) + elif isinstance(expr, IndexBase): + self.map_basic_index(expr) + elif isinstance(expr, Einsum): + self.map_einsum(expr) + elif isinstance(expr, DictOfNamedArrays): + self.map_dict_of_named_arrays(expr) + elif isinstance(expr, LoopyCall): + self.map_loopy_call(expr) + elif isinstance(expr, DistributedSendRefHolder): + self.map_distributed_send_ref_holder(expr) + elif isinstance(expr, Call): + self.map_call(expr) + elif isinstance(expr, NamedCallResult): + self.map_named_call_result(expr) + else: + self.handle_unsupported_array(expr) + + def handle_unsupported_array(self, + expr: Array) -> None: # Default handler, does its best to guess how to handle fields. info = self.get_common_dot_info(expr) - - # pylint: disable=not-an-iterable + expr_key = self.get_cache_key(expr) for field in attrs.fields(type(expr)): if field.name in info.fields: continue attr = getattr(expr, field.name) - if isinstance(attr, Array): - self.rec(attr) - info.edges[field.name] = attr - + self.process_node(attr) + key = self.get_cache_key(attr) + info.edges[field.name] = (key, attr) elif isinstance(attr, AbstractResultWithNamedArrays): - self.rec(attr) - info.edges[field.name] = attr - + self.process_node(attr) + key = self.get_cache_key(attr) + info.edges[field.name] = (key, attr) elif isinstance(attr, tuple): info.fields[field.name] = stringify_shape(attr) - else: info.fields[field.name] = str(attr) - - self.node_to_dot[expr] = info + self.node_to_dot[expr_key] = info def map_data_wrapper(self, expr: DataWrapper) -> None: info = self.get_common_dot_info(expr) if expr.name is not None: info.fields["name"] = expr.name - # Only show summarized data import numpy as np with np.printoptions(threshold=4, precision=2): info.fields["data"] = str(expr.data) - self.node_to_dot[expr] = info + self.node_to_dot[self.get_cache_key(expr)] = info def map_index_lambda(self, expr: IndexLambda) -> None: info = self.get_common_dot_info(expr) info.fields["expr"] = str(expr.expr) for name, val in expr.bindings.items(): - self.rec(val) - info.edges[name] = val + self.process_node(val) + key = self.get_cache_key(val) + info.edges[name] = (key, val) - self.node_to_dot[expr] = info + self.node_to_dot[self.get_cache_key(expr)] = info def map_stack(self, expr: Stack) -> None: info = self.get_common_dot_info(expr) info.fields["axis"] = str(expr.axis) for i, array in enumerate(expr.arrays): - self.rec(array) - info.edges[str(i)] = array + self.process_node(array) + key = self.get_cache_key(array) + info.edges[str(i)] = (key, array) - self.node_to_dot[expr] = info + self.node_to_dot[self.get_cache_key(expr)] = info map_concatenate = map_stack @@ -270,9 +304,10 @@ def map_basic_index(self, expr: IndexBase) -> None: elif isinstance(index, Array): label = f"i{i}" - self.rec(index) + self.process_node(index) + key = self.get_cache_key(index) indices_parts.append(label) - info.edges[label] = index + info.edges[label] = (key, index) elif index is None: indices_parts.append("newaxis") @@ -282,10 +317,11 @@ def map_basic_index(self, expr: IndexBase) -> None: info.fields["indices"] = ", ".join(indices_parts) - self.rec(expr.array) - info.edges["array"] = expr.array + self.process_node(expr.array) + key = self.get_cache_key(expr.array) + info.edges["array"] = (key, expr.array) - self.node_to_dot[expr] = info + self.node_to_dot[self.get_cache_key(expr)] = info map_contiguous_advanced_index = map_basic_index map_non_contiguous_advanced_index = map_basic_index @@ -295,30 +331,35 @@ def map_einsum(self, expr: Einsum) -> None: for iarg, (access_descr, val) in enumerate(zip(expr.access_descriptors, expr.args)): - self.rec(val) - info.edges[f"{iarg}: {access_descr}"] = val + self.process_node(val) + key = self.get_cache_key(val) + info.edges[f"{iarg}: {access_descr}"] = (key, val) - self.node_to_dot[expr] = info + self.node_to_dot[self.get_cache_key(expr)] = info def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> None: - edges: dict[str, ArrayOrNames | FunctionDefinition] = {} + edges: dict[str, ArrayOrNames | FunctionDefinition | + tuple[int | ArrayOrNames, Array]] = {} for name, val in expr._data.items(): - edges[name] = val - self.rec(val) + self.process_node(val) + key = self.get_cache_key(val) + edges[name] = (key, val) - self.node_to_dot[expr] = _DotNodeInfo( + self.node_to_dot[self.get_cache_key(expr)] = _DotNodeInfo( title=type(expr).__name__, fields={}, edges=edges) def map_loopy_call(self, expr: LoopyCall) -> None: - edges: dict[str, ArrayOrNames | FunctionDefinition] = {} + edges: dict[str, ArrayOrNames | FunctionDefinition | + tuple[int | ArrayOrNames, Array]] = {} for name, arg in expr.bindings.items(): if isinstance(arg, Array): - edges[name] = arg - self.rec(arg) + self.process_node(arg) + key = self.get_cache_key(arg) + edges[name] = (key, arg) - self.node_to_dot[expr] = _DotNodeInfo( + self.node_to_dot[self.get_cache_key(expr)] = _DotNodeInfo( title=type(expr).__name__, fields={"addr": hex(id(expr)), "entrypoint": expr.entrypoint}, edges=edges) @@ -328,29 +369,31 @@ def map_distributed_send_ref_holder( info = self.get_common_dot_info(expr) - self.rec(expr.passthrough_data) - info.edges["passthrough"] = expr.passthrough_data + self.process_node(expr.passthrough_data) + key = self.get_cache_key(expr.passthrough_data) + info.edges["passthrough"] = (key, expr.passthrough_data) - self.rec(expr.send.data) - info.edges["sent"] = expr.send.data + self.process_node(expr.send.data) + key = self.get_cache_key(expr.send.data) + info.edges["sent"] = (key, expr.send.data) info.fields["dest_rank"] = str(expr.send.dest_rank) - info.fields["comm_tag"] = str(expr.send.comm_tag) - self.node_to_dot[expr] = info + self.node_to_dot[self.get_cache_key(expr)] = info def map_call(self, expr: Call) -> None: self.functions.add(expr.function) for bnd in expr.bindings.values(): - self.rec(bnd) + self.process_node(bnd) - self.node_to_dot[expr] = _DotNodeInfo( + self.node_to_dot[self.get_cache_key(expr)] = _DotNodeInfo( title=expr.__class__.__name__, edges={ "": expr.function, - **expr.bindings}, + **{name: (self.get_cache_key(bnd), bnd) + for name, bnd in expr.bindings.items()}}, fields={ "addr": hex(id(expr)), "tags": stringify_tags(expr.tags), @@ -358,14 +401,16 @@ def map_call(self, expr: Call) -> None: ) def map_named_call_result(self, expr: NamedCallResult) -> None: - self.rec(expr._container) - self.node_to_dot[expr] = _DotNodeInfo( + self.process_node(expr._container) + key = self.get_cache_key(expr._container) + self.node_to_dot[self.get_cache_key(expr)] = _DotNodeInfo( title=expr.__class__.__name__, - edges={"": expr._container}, + edges={"": (key, expr._container)}, fields={"addr": hex(id(expr)), "name": expr.name}, ) + # }}} @@ -379,6 +424,12 @@ def dot_escape_leave_space(s: str) -> str: return html.escape(s.replace("\\", "\\\\")) +def get_array_key(array: ArrayOrNames | FunctionDefinition | int, + count_duplicates: bool = False) -> Any: + """Return a consistent key for the array.""" + return id(array) if count_duplicates and not isinstance(array, int) else array + + # {{{ emit helpers def _stringify_created_at(non_equality_tags: frozenset[Tag]) -> str: @@ -391,7 +442,7 @@ def _stringify_created_at(non_equality_tags: frozenset[Tag]) -> str: def _emit_array(emit: Callable[[str], None], title: str, fields: dict[str, Any], - dot_node_id: str, color: str = "white") -> None: + dot_node_id: str, color: str = "white") -> None: td_attrib = 'border="0"' table_attrib = 'border="0" cellborder="1" cellspacing="0"' @@ -416,8 +467,10 @@ def _emit_array(emit: Callable[[str], None], title: str, fields: dict[str, Any], def _emit_name_cluster( emit: DotEmitter, subgraph_path: tuple[str, ...], names: Mapping[str, ArrayOrNames], - array_to_id: Mapping[ArrayOrNames, str], id_gen: Callable[[str], str], - label: str) -> None: + array_to_id: Mapping[ + int | ArrayOrNames, str], id_gen: Callable[[str], str], + label: str, + count_duplicates: bool = False) -> None: edges = [] cluster_subgraph_path = (*subgraph_path, f"cluster_{dot_escape(label)}") @@ -428,7 +481,8 @@ def _emit_name_cluster( for name, array in names.items(): name_id = id_gen(dot_escape(name)) emit_cluster(f'{name_id} [label="{dot_escape(name)}"]') - array_id = array_to_id[array] + array_key = get_array_key(array, count_duplicates) + array_id = array_to_id[array_key] # Edges must be outside the cluster. edges.append((name_id, array_id)) @@ -439,16 +493,18 @@ def _emit_name_cluster( def _emit_function( emitter: DotEmitter, subgraph_path: tuple[str, ...], id_gen: UniqueNameGenerator, - node_to_dot: Mapping[ArrayOrNames, _DotNodeInfo], + node_to_dot: Mapping[int | ArrayOrNames, _DotNodeInfo], func_to_id: Mapping[FunctionDefinition, str], - outputs: Mapping[str, Array]) -> None: + outputs: Mapping[str, Array], + count_duplicates: bool = False) -> None: input_arrays: list[Array] = [] - internal_arrays: list[ArrayOrNames] = [] - array_to_id: dict[ArrayOrNames, str] = {} + internal_arrays: list[int | ArrayOrNames] = [] + array_to_id: dict[int | ArrayOrNames, str] = {} emit = partial(emitter, subgraph_path) for array in node_to_dot: - array_to_id[array] = id_gen("array") + key = get_array_key(array, count_duplicates) + array_to_id[key] = id_gen("array") if isinstance(array, InputArgumentBase): input_arrays.append(array) else: @@ -460,36 +516,47 @@ def _emit_function( emit_input('label="Arguments"') for array in input_arrays: + key = get_array_key(array, count_duplicates) _emit_array( emit_input, node_to_dot[array].title, node_to_dot[array].fields, - array_to_id[array]) + array_to_id[key]) # Emit non-inputs. for array in internal_arrays: + key = get_array_key(array, count_duplicates) _emit_array(emit, node_to_dot[array].title, node_to_dot[array].fields, - array_to_id[array]) + array_to_id[key]) # Emit edges. for array, node in node_to_dot.items(): - for label, tail_item in node.edges.items(): - head = array_to_id[array] + key = get_array_key(array, count_duplicates) + for label, edge_info in node.edges.items(): + if isinstance(edge_info, tuple): + tail_key, tail_item = edge_info + else: + tail_item = edge_info + tail_key = get_array_key(tail_item, count_duplicates) + + head = array_to_id[key] if isinstance(tail_item, (Array, AbstractResultWithNamedArrays)): - tail = array_to_id[tail_item] + tail = array_to_id[tail_key] elif isinstance(tail_item, FunctionDefinition): tail = func_to_id[tail_item] else: raise ValueError( - f"unexpected type of tail on edge: {type(tail_item)}") + f"unexpected type of tail on edge: {type(tail_item)}") emit(f'{tail} -> {head} [label="{dot_escape(label)}"]') # Emit output/namespace name mappings. _emit_name_cluster( - emitter, subgraph_path, outputs, array_to_id, id_gen, label="Returns") + emitter, subgraph_path, + outputs, array_to_id, id_gen, + label="Returns", count_duplicates=count_duplicates) # }}} @@ -507,20 +574,21 @@ def _get_function_name(f: FunctionDefinition) -> str | None: def _gather_partition_node_information( id_gen: UniqueNameGenerator, - partition: DistributedGraphPartition + partition: DistributedGraphPartition, + count_duplicates: bool = False ) -> tuple[ - Mapping[PartId, Mapping[FunctionDefinition, str]], - Mapping[tuple[PartId, FunctionDefinition | None], - Mapping[ArrayOrNames, _DotNodeInfo]] - ]: + dict[PartId, dict[FunctionDefinition, str]], + dict[tuple[PartId, FunctionDefinition | None], + dict[int | ArrayOrNames, _DotNodeInfo]]]: part_id_to_func_to_id: dict[PartId, dict[FunctionDefinition, str]] = {} part_id_func_to_node_info: dict[tuple[PartId, FunctionDefinition | None], - dict[ArrayOrNames, _DotNodeInfo]] = {} + dict[int | ArrayOrNames, + _DotNodeInfo]] = {} for part in partition.parts.values(): - mapper = ArrayToDotNodeInfoMapper() + mapper = ArrayToDotNodeInfoMapper(count_duplicates) for out_name in part.output_names: - mapper(partition.name_to_output[out_name]) + mapper.process_node(partition.name_to_output[out_name]) part_id_func_to_node_info[part.pid, None] = mapper.node_to_dot part_id_to_func_to_id[part.pid] = {} @@ -535,9 +603,9 @@ def gather_function_info(f: FunctionDefinition) -> None: if key in part_id_func_to_node_info: return - mapper = ArrayToDotNodeInfoMapper() + mapper = ArrayToDotNodeInfoMapper(count_duplicates) for elem in f.returns.values(): - mapper(elem) + mapper.process_node(elem) part_id_func_to_node_info[key] = mapper.node_to_dot @@ -563,10 +631,12 @@ def gather_function_info(f: FunctionDefinition) -> None: return part_id_to_func_to_id, part_id_func_to_node_info + # }}} -def get_dot_graph(result: Array | DictOfNamedArrays) -> str: +def get_dot_graph(result: Array | DictOfNamedArrays, + count_duplicates: bool = False) -> str: r"""Return a string in the `dot `_ language depicting the graph of the computation of *result*. @@ -576,30 +646,32 @@ def get_dot_graph(result: Array | DictOfNamedArrays) -> str: outputs: DictOfNamedArrays = normalize_outputs(result) - return get_dot_graph_from_partition( - DistributedGraphPartition( - parts={ - None: DistributedGraphPart( - pid=None, - needed_pids=frozenset(), - user_input_names=frozenset( + partition = DistributedGraphPartition( + parts={ + None: DistributedGraphPart( + pid=None, + needed_pids=frozenset(), + user_input_names=frozenset( expr.name for expr in InputGatherer()(outputs) if isinstance(expr, Placeholder) ), - partition_input_names=frozenset(), - output_names=frozenset(outputs.keys()), - name_to_recv_node={}, - name_to_send_nodes={}, - ) - }, - name_to_output=outputs._data, - overall_output_names=tuple(outputs), - )) - - -def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str: - r"""Return a string in the `dot `_ language depicting the + partition_input_names=frozenset(), + output_names=frozenset(outputs.keys()), + name_to_recv_node={}, + name_to_send_nodes={}, + ) + }, + name_to_output=outputs._data, + overall_output_names=tuple(outputs), + ) + + return get_dot_graph_from_partition(partition, count_duplicates) + + +def get_dot_graph_from_partition(partition: DistributedGraphPartition, + count_duplicates: bool = False) -> str: + """Return a string in the `dot `_ language depicting the graph of the partitioned computation of *partition*. :arg partition: Outputs of :func:`~pytato.find_distributed_partition`. @@ -611,9 +683,7 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str: # The "None" function is the body of the partition. part_id_to_func_to_id, part_id_func_to_node_info = \ - _gather_partition_node_information(id_gen, partition) - - # }}} + _gather_partition_node_information(id_gen, partition, count_duplicates) emitter = DotEmitter() emit_root = partial(emitter, ()) @@ -622,8 +692,8 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str: emit_root("node [shape=rectangle]") - placeholder_to_id: dict[ArrayOrNames, str] = {} - part_id_to_array_to_id: dict[PartId, dict[ArrayOrNames, str]] = {} + placeholder_to_id: dict[int | ArrayOrNames, str] = {} + part_id_to_array_to_id: dict[PartId, dict[int | ArrayOrNames, str]] = {} part_id_to_id = {pid: dot_escape(str(pid)) for pid in partition.parts} assert len(set(part_id_to_id.values())) == len(partition.parts) @@ -633,16 +703,18 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str: for part in partition.parts.values(): array_to_id = {} for array in part_id_func_to_node_info[part.pid, None].keys(): + if isinstance(array, int): # if the key is an ID + array = get_object_by_id(array) + key = get_array_key(array, count_duplicates) if isinstance(array, Placeholder): - # Placeholders are only emitted once - if array in placeholder_to_id: - node_id = placeholder_to_id[array] + if key in placeholder_to_id: + node_id = placeholder_to_id[key] else: node_id = id_gen("array") - placeholder_to_id[array] = node_id + placeholder_to_id[key] = node_id else: node_id = id_gen("array") - array_to_id[array] = node_id + array_to_id[key] = node_id part_id_to_array_to_id[part.pid] = array_to_id # }}} @@ -679,22 +751,22 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str: _emit_function(emitter, func_subgraph_path, id_gen, part_id_func_to_node_info[part.pid, func], part_id_to_func_to_id[part.pid], - func.returns) + func.returns, + count_duplicates=count_duplicates) # }}} # {{{ emit receives nodes part_dist_recv_var_name_to_node_id = {} - for name, recv in ( - part.name_to_recv_node.items()): + for name, recv in part.name_to_recv_node.items(): node_id = id_gen("recv") _emit_array(emit_part, "DistributedRecv", { "shape": stringify_shape(recv.shape), "dtype": str(recv.dtype), "src_rank": str(recv.src_rank), "comm_tag": str(recv.comm_tag), - }, node_id) + }, node_id) part_dist_recv_var_name_to_node_id[name] = node_id @@ -705,6 +777,8 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str: internal_arrays: list[ArrayOrNames] = [] for array in part_node_to_info.keys(): + if isinstance(array, int): # if the key is an ID + array = get_object_by_id(array) if isinstance(array, InputArgumentBase): input_arrays.append(array) else: @@ -718,26 +792,26 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str: # subgraphs. for array in input_arrays: + key = array = get_array_key(array, count_duplicates) if not isinstance(array, Placeholder): _emit_array(emit_part, part_node_to_info[array].title, part_node_to_info[array].fields, - array_to_id[array], "deepskyblue") + array_to_id[key], "deepskyblue") else: # Is a Placeholder - if array in emitted_placeholders: + if key in emitted_placeholders: continue - _emit_array(emit_root, part_node_to_info[array].title, part_node_to_info[array].fields, - array_to_id[array], "deepskyblue") + array_to_id[key], "deepskyblue") # Emit cross-partition edges if array.name in part_dist_recv_var_name_to_node_id: tgt = part_dist_recv_var_name_to_node_id[array.name] - emit_root(f"{tgt} -> {array_to_id[array]} [style=dotted]") - emitted_placeholders.add(array) + emit_root(f"{tgt} -> {array_to_id[key]} [style=dotted]") + emitted_placeholders.add(key) elif array.name in part.user_input_names: # no arrows for these pass @@ -750,18 +824,22 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str: break assert computing_pid is not None tgt = part_id_to_array_to_id[computing_pid][ - partition.name_to_output[array.name]] - emit_root(f"{tgt} -> {array_to_id[array]} [style=dashed]") - emitted_placeholders.add(array) + id(partition.name_to_output[array.name]) + if count_duplicates + else partition.name_to_output[array.name]] + emit_root(f"{tgt} -> {array_to_id[key]} [style=dashed]") + emitted_placeholders.add(key) # }}} # Emit internal nodes + for array in internal_arrays: + key = array = get_array_key(array, count_duplicates) _emit_array(emit_part, part_node_to_info[array].title, part_node_to_info[array].fields, - array_to_id[array]) + array_to_id[key]) # {{{ emit send nodes if distributed @@ -772,35 +850,47 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str: _emit_array(emit_part, "DistributedSend", { "dest_rank": str(send.dest_rank), "comm_tag": str(send.comm_tag), - }, node_id) + }, node_id) # If an edge is emitted in a subgraph, it drags its # nodes into the subgraph, too. Not what we want. + data = id(send.data) if count_duplicates else send.data emit_root( - f"{array_to_id[send.data]} -> {node_id}" - f'[style=dotted, label="{dot_escape(name)}"]') + f"{array_to_id[data]} -> {node_id}" + f'[style=dotted, label="{dot_escape(name)}"]') # }}} # Emit intra-partition edges for array, node in part_node_to_info.items(): - for label, tail_item in node.edges.items(): - head = array_to_id[array] + key = get_array_key(array, count_duplicates) + + tail_item: Array | AbstractResultWithNamedArrays | FunctionDefinition + for label, edge_info in node.edges.items(): + if isinstance(edge_info, tuple): + tail_key, tail_item = edge_info + else: + tail_item = edge_info + tail_key = get_array_key(tail_item, count_duplicates) + + head = array_to_id[key] if isinstance(tail_item, (Array, AbstractResultWithNamedArrays)): - tail = array_to_id[tail_item] + tail = array_to_id[tail_key] elif isinstance(tail_item, FunctionDefinition): tail = part_id_to_func_to_id[part.pid][tail_item] else: raise ValueError( - f"unexpected type of tail on edge: {type(tail_item)}") + f"unexpected type of tail on edge: {type(tail_item)}") emit_root(f'{tail} -> {head} [label="{dot_escape(label)}"]') _emit_name_cluster( - emitter, part_subgraph_path, - {name: partition.name_to_output[name] for name in part.output_names}, - array_to_id, id_gen, "Part outputs") + emitter, part_subgraph_path, + {name: partition.name_to_output[name] + for name in part.output_names}, + array_to_id, id_gen, "Part outputs", + count_duplicates) # }}} @@ -810,15 +900,16 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str: # {{{ draw overall outputs - combined_array_to_id: dict[ArrayOrNames, str] = {} + combined_array_to_id: dict[int | ArrayOrNames, str] = {} for part_id in partition.parts.keys(): combined_array_to_id.update(part_id_to_array_to_id[part_id]) _emit_name_cluster( - emitter, (), - {name: partition.name_to_output[name] - for name in partition.overall_output_names}, - combined_array_to_id, id_gen, "Overall outputs") + emitter, (), + {name: partition.name_to_output[name] + for name in partition.overall_output_names}, + combined_array_to_id, id_gen, "Overall outputs", + count_duplicates) # }}} @@ -826,7 +917,8 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str: def show_dot_graph(result: str | Array | DictOfNamedArrays | DistributedGraphPartition, - **kwargs: Any) -> None: + count_duplicates: bool = False, + **kwargs: Any) -> None: """Show a graph representing the computation of *result* in a browser. :arg result: Outputs of the computation (cf. @@ -839,9 +931,9 @@ def show_dot_graph(result: str | Array | DictOfNamedArrays | DistributedGraphPar if isinstance(result, str): dot_code = result elif isinstance(result, DistributedGraphPartition): - dot_code = get_dot_graph_from_partition(result) + dot_code = get_dot_graph_from_partition(result, count_duplicates) else: - dot_code = get_dot_graph(result) + dot_code = get_dot_graph(result, count_duplicates) from pytools.graphviz import show_dot show_dot(dot_code, **kwargs) diff --git a/test/test_codegen.py b/test/test_codegen.py index 3a6f2b7c1..d2260d077 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -581,7 +581,7 @@ def test_dict_to_loopy_kernel(ctx_factory): def test_only_deps_as_knl_args(): # See https://gitlab.tiker.net/inducer/pytato/-/issues/13 x = pt.make_placeholder(name="x", shape=(10, 4), dtype=float) - y = pt.make_placeholder(name="y", shape=(10, 4), dtype=float) # noqa:F841 + pt.make_placeholder(name="y", shape=(10, 4), dtype=float) z = 2*x knl = pt.generate_loopy(z).kernel @@ -941,9 +941,9 @@ def _get_a_shape(_m, _n): def _get_x_shape(_m, _n): return (3*_n+7, ) - A_in = np.random.rand(*_get_a_shape(m_in, n_in)) # noqa: N806 + A_in = np.random.rand(*_get_a_shape(m_in, n_in)) x_in = np.random.rand(*_get_x_shape(m_in, n_in)) - A = pt.make_data_wrapper(A_in, shape=_get_a_shape(m, n)) # noqa: N806 + A = pt.make_data_wrapper(A_in, shape=_get_a_shape(m, n)) x = pt.make_data_wrapper(x_in, shape=_get_x_shape(m, n)) np_out = np.einsum("ij, j -> i", A_in, x_in) @@ -982,7 +982,7 @@ def test_call_loopy_shape_inference1(ctx_factory): rng = default_rng() - A_in = rng.random((20, 37)) # noqa + A_in = rng.random((20, 37)) knl = lp.make_kernel( ["{[i, j]: 0<=i<(2*n + 3*m + 2) and 0<=j<(6*n + 4*m + 3)}", @@ -992,7 +992,7 @@ def test_call_loopy_shape_inference1(ctx_factory): out[ii, jj] = tmp*(ii + jj) """, lang_version=(2018, 2)) - A = pt.make_placeholder(name="x", shape=(20, 37), dtype=np.float64) # noqa: N806 + A = pt.make_placeholder(name="x", shape=(20, 37), dtype=np.float64) y_pt = call_loopy(knl, {"A": A})["out"] _, (out,) = pt.generate_loopy(y_pt)(queue, x=A_in) @@ -1014,7 +1014,7 @@ def test_call_loopy_shape_inference2(ctx_factory): rng = default_rng() - A_in = rng.random((38, 71)) # noqa + A_in = rng.random((38, 71)) knl = lp.make_kernel( ["{[i, j]: 0<=i<(2*n + 3*m + 2) and 0<=j<(6*n + 4*m + 3)}", @@ -1026,7 +1026,7 @@ def test_call_loopy_shape_inference2(ctx_factory): n1 = pt.make_size_param("n1") n2 = pt.make_size_param("n2") - A = pt.make_placeholder(name="x", # noqa: N806 + A = pt.make_placeholder(name="x", shape=(4*n1 + 6*n2 + 2, 12*n1 + 8*n2 + 3), dtype=np.float64) @@ -1284,7 +1284,7 @@ def test_advanced_indexing_fuzz(ctx_factory): cq = cl.CommandQueue(ctx) rng = default_rng(seed=0) - NSAMPLES = 50 # noqa: N806 + NSAMPLES = 50 for i in range(NSAMPLES): input_ndim = rng.integers(1, 8) @@ -1992,7 +1992,9 @@ def test_nested_function_calls(ctx_factory): _, out = prg(cq, x=x_np) np.testing.assert_allclose(out["out1"], 3*x_np) np.testing.assert_allclose(out["out2"], x_np) - ref_tracer = lambda f, *args, identifier: f(*args) # noqa: E731 + + def ref_tracer(f, *args, identifier): + return f(*args) def foo(tracer, x, y): return 2*x + 3*y @@ -2024,6 +2026,79 @@ def call_bar(tracer, x, y): np.testing.assert_allclose(result_out[k], expect_out[k]) +def test_duplicate_node_count_dot_graph(): + from testlib import count_dot_graph_nodes + + from pytato.analysis import get_num_nodes + from pytato.visualization.dot import get_dot_graph + + for i in range(80): + # print("curr i:", i) + dag = get_random_pt_dag(seed=i, axis_len=5) + + # Generate dot graph with duplicates + dot_graph = get_dot_graph(dag, count_duplicates=True) + node_counts = count_dot_graph_nodes(dot_graph) + + assert len(node_counts) == get_num_nodes(dag, count_duplicates=True) + + # Generate dot graph without duplicates + dot_graph = get_dot_graph(dag, count_duplicates=False) + node_counts = count_dot_graph_nodes(dot_graph) + + # Verify node counts without duplicates + assert len(node_counts) == get_num_nodes(dag, count_duplicates=False) + + +def test_duplicate_nodes_with_comm_count_dot_graph(): + from testlib import count_dot_graph_nodes, get_random_pt_dag_with_send_recv_nodes + + from pytato.analysis import get_num_nodes + from pytato.visualization.dot import get_dot_graph + + rank = 0 + size = 2 + for i in range(20): + dag = get_random_pt_dag_with_send_recv_nodes(seed=i, rank=rank, size=size) + + # Generate dot graph with duplicates + dot_graph = get_dot_graph(dag, count_duplicates=True) + node_counts = count_dot_graph_nodes(dot_graph) + + assert len(node_counts) == get_num_nodes(dag, count_duplicates=True) + + # Generate dot graph without duplicates + dot_graph = get_dot_graph(dag, count_duplicates=False) + node_counts = count_dot_graph_nodes(dot_graph) + + # Verify node counts without duplicates + assert len(node_counts) == get_num_nodes(dag, count_duplicates=False) + + +def test_large_dot_graph_with_duplicates_count(): + from testlib import count_dot_graph_nodes, make_large_dag + + from pytato.analysis import get_num_nodes + from pytato.visualization.dot import get_dot_graph + + iterations = 100 + dag = make_large_dag(iterations, seed=42) + + # Generate dot graph with duplicates + dot_graph = get_dot_graph(dag, count_duplicates=True) + node_counts = count_dot_graph_nodes(dot_graph) + + # Verify node counts with duplicates + assert len(node_counts) == get_num_nodes(dag, count_duplicates=True) + + # Generate dot graph without duplicates + dot_graph = get_dot_graph(dag, count_duplicates=False) + node_counts = count_dot_graph_nodes(dot_graph) + + # Verify node counts without duplicates + assert len(node_counts) == get_num_nodes(dag, count_duplicates=False) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) diff --git a/test/test_pytato.py b/test/test_pytato.py index f67e7e5f1..406231646 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -765,6 +765,53 @@ def test_large_dag_with_duplicates_count(): dag, count_duplicates=False) +def test_duplicate_node_count(): + from testlib import get_random_pt_dag + + from pytato.analysis import get_node_multiplicities, get_num_nodes + for i in range(80): + dag = get_random_pt_dag(seed=i, axis_len=5) + + # Get the number of types of expressions + node_count = get_num_nodes(dag, count_duplicates=True) + + # Get the number of expressions and the amount they're called + node_multiplicity = get_node_multiplicities(dag) + + # Get difference in duplicates + num_duplicates = sum( + count - 1 for count in node_multiplicity.values() if count > 1) + # Check that duplicates are correctly calculated + assert node_count - num_duplicates == len( + pt.transform.DependencyMapper()(dag)) + + +def test_duplicate_nodes_with_comm_count(): + from testlib import get_random_pt_dag_with_send_recv_nodes + + from pytato.analysis import get_node_multiplicities, get_num_nodes + + rank = 0 + size = 2 + for i in range(20): + dag = get_random_pt_dag_with_send_recv_nodes( + seed=i, rank=rank, size=size) + + # Get the number of types of expressions + node_count = get_num_nodes(dag, count_duplicates=True) + + # Get the number of expressions and the amount they're called + node_multiplicity = get_node_multiplicities(dag) + + # Get difference in duplicates + num_duplicates = sum( + count - 1 for count in node_multiplicity.values() if count > 1) + + # Check that duplicates are correctly calculated + assert node_count - num_duplicates == len( + pt.transform.DependencyMapper()(dag)) + + def test_rec_get_user_nodes(): x1 = pt.make_placeholder("x1", shape=(10, 4)) x2 = pt.make_placeholder("x2", shape=(10, 4)) diff --git a/test/testlib.py b/test/testlib.py index 53bf79436..e44becd0e 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -2,6 +2,7 @@ import operator import random +import re import types from typing import Any, Callable, Sequence @@ -395,6 +396,23 @@ def make_large_dag_with_duplicates(iterations: int, result = pt.sum(combined_expr, axis=0) return pt.make_dict_of_named_arrays({"result": result}) + +def count_dot_graph_nodes(dot_graph: str) -> dict[Any, int]: + """ + Parses a dot graph and returns a dictionary with + the count of each unique node identifier. + """ + + node_pattern = re.compile( + r'addr:(0x[0-9a-f]+)') + nodes = node_pattern.findall(dot_graph) + + node_counts: dict[Any, int] = {} + for node in nodes: + node_counts[node] = node_counts.get(node, 0) + 1 + + return node_counts + # }}}