Skip to content

Commit

Permalink
Ruff changes
Browse files Browse the repository at this point in the history
  • Loading branch information
kajalpatelinfo committed Aug 21, 2024
1 parent d96a4bf commit 724c799
Showing 1 changed file with 41 additions and 133 deletions.
174 changes: 41 additions & 133 deletions pytato/visualization/dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,25 +27,16 @@
"""


import gc
import html
import re
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
Callable,
Mapping,
Dict,
Tuple,
Union,
List,
Any,
FrozenSet,
Set,
Optional
)

import gc
import re
import attrs

from pytools import UniqueNameGenerator
Expand All @@ -66,6 +57,7 @@
Stack,
)
from pytato.codegen import normalize_outputs
from pytato.distributed.nodes import DistributedSendRefHolder
from pytato.distributed.partition import (
DistributedGraphPart,
DistributedGraphPartition,
Expand All @@ -75,7 +67,6 @@
from pytato.loopy import LoopyCall
from pytato.tags import FunctionIdentifier
from pytato.transform import ArrayOrNames, CachedMapper, InputGatherer
from pytato.distributed.nodes import DistributedSendRefHolder


__doc__ = """
Expand Down Expand Up @@ -173,84 +164,6 @@ def simplify_indexlambda_node_to_symbol_only(s):
return s


def extract_operation_symbol(expr):

operation_replacements = {
r"NaN_if": "if",
r"else": "else",
r"isnan": "is NaN",
r"<": "&lt;",
r">": "&gt;",
r"\s*==\s*": "==",
r"\s*!=\s*": "!=",
r"\s*<=\s*": "<=",
r"\s*>=\s*": ">=",
r"\s*\+\s*": "+",
r"\s*\-\s*": "-",
r"\s*\*\*\s*": "**",
r"\s*\*\s*": "*",
r"\s*/\s*": "/",
r"\s*//\s*": "//",
r"\s*%\s*": "%",
r"\s*or\s*": "or",
r"\s*and\s*": "and",
r"\s*not\s*": "not",
r"\s*<<\s*": "<<",
r"\s*>>\s*": ">>",
r"\s*\|\s*": "|",
r"\s*\^\s*": "^",
r"~\s*": "~",
r"\s*@\s*": "@",
r"\s*SumReductionOperation\s*": "Σ",
r"&lt;": "&lt;",
r"&gt;": "&gt;",
r"&": "&amp;",
}

for pattern, replacement in operation_replacements.items():
if re.search(pattern, expr.strip()):
return replacement

return expr


def simplify_indexlambda_node_to_symbol_only(s):
if "IndexLambda" in s:
expr_match = re.search(
r'expr:</td><td border="0"><FONT FACE=\'monospace\'>(.*?)</FONT></td>', s
)

if expr_match:
original_expr = expr_match.group(1)
operation_symbol = extract_operation_symbol(original_expr)

tooltip_content = []
tooltip_matches = re.findall(
r'<tr><td border="0">(.*?)</td><td border="0">'
r'<FONT FACE=\'monospace\'>(.*?)</FONT></td></tr>',
s
)

for key, value in tooltip_matches:
tooltip_content.append(f"{key}: {value}")

tooltip_text = ",\n".join(tooltip_content)

new_label = (
f'<tr><td colspan="2" border="0" align="center">'
f'<FONT POINT-SIZE="20">{operation_symbol}</FONT>'
f'</td></tr>'
)

s = (
f'{new_label}</table>> '
f'style=filled fillcolor="white" '
f'tooltip="{tooltip_text}"];'
)

return s


class DotEmitter:
def __init__(self) -> None:
self.subgraph_to_lines: dict[tuple[str, ...], list[str]] = {}
Expand Down Expand Up @@ -325,12 +238,10 @@ def emit_subgraph(sg: _SubgraphTree) -> None:
@attrs.define
class _DotNodeInfo:
title: str
fields: Dict[str, Any]
edges: Dict[str, Union[
ArrayOrNames,
FunctionDefinition,
Tuple[Union[int, ArrayOrNames], ArrayOrNames]],
Array]
fields: dict[str, Any]
edges: dict[str, ArrayOrNames |
FunctionDefinition | tuple[int |
ArrayOrNames, ArrayOrNames], Array]


def stringify_tags(tags: frozenset[Tag | None]) -> str:
Expand All @@ -347,7 +258,7 @@ def stringify_shape(shape: ShapeType) -> str:
return "(" + ", ".join(components) + ")"


def get_object_by_id(object_id: int) -> Union[Any, ArrayOrNames]:
def get_object_by_id(object_id: int) -> Any | ArrayOrNames:
"""Find an object by its ID."""
for obj in gc.get_objects():
if id(obj) == object_id:
Expand All @@ -357,11 +268,11 @@ def get_object_by_id(object_id: int) -> Union[Any, ArrayOrNames]:

class ArrayToDotNodeInfoMapper(CachedMapper[ArrayOrNames]):
def __init__(self, count_duplicates: bool = False):
self.node_to_dot: Dict[Union[int, ArrayOrNames], _DotNodeInfo] = {}
self.functions: Set[FunctionDefinition] = set()
self.node_to_dot: dict[int | ArrayOrNames, _DotNodeInfo] = {}
self.functions: set[FunctionDefinition] = set()
self.count_duplicates = count_duplicates

def get_cache_key(self, expr: ArrayOrNames) -> Union[int, ArrayOrNames]:
def get_cache_key(self, expr: ArrayOrNames) -> int | ArrayOrNames:
return id(expr) if self.count_duplicates else expr

def get_common_dot_info(self, expr: Array) -> _DotNodeInfo:
Expand All @@ -373,10 +284,10 @@ def get_common_dot_info(self, expr: Array) -> _DotNodeInfo:
"non_equality_tags": expr.non_equality_tags,
}

edges: Dict[str,
Union[ArrayOrNames, FunctionDefinition,
Tuple[Union[int, AbstractResultWithNamedArrays,
Array], Array]]] = {}
edges: dict[str,
ArrayOrNames | FunctionDefinition |
tuple[int | AbstractResultWithNamedArrays |
Array, Array]] = {}
return _DotNodeInfo(title, fields, edges)

def process_node(self, expr: ArrayOrNames) -> None:
Expand Down Expand Up @@ -507,8 +418,8 @@ def map_einsum(self, expr: Einsum) -> None:
self.node_to_dot[self.get_cache_key(expr)] = info

def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> None:
edges: Dict[str, Union[ArrayOrNames, FunctionDefinition, Tuple[Union[
int, ArrayOrNames], Array]]] = {}
edges: dict[str, ArrayOrNames | FunctionDefinition |
tuple[int | ArrayOrNames, Array]] = {}
for name, val in expr._data.items():
self.process_node(val)
key = self.get_cache_key(val)
Expand All @@ -520,8 +431,8 @@ def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> None:
edges=edges)

def map_loopy_call(self, expr: LoopyCall) -> None:
edges: Dict[str, Union[ArrayOrNames, FunctionDefinition, Tuple[Union[
int, ArrayOrNames], Array]]] = {}
edges: dict[str, ArrayOrNames | FunctionDefinition |
tuple[int | ArrayOrNames, Array]] = {}
for name, arg in expr.bindings.items():
if isinstance(arg, Array):
self.process_node(arg)
Expand Down Expand Up @@ -593,7 +504,7 @@ def dot_escape_leave_space(s: str) -> str:
return html.escape(s.replace("\\", "\\\\"))


def get_array_key(array: Union[ArrayOrNames, FunctionDefinition, int],
def get_array_key(array: ArrayOrNames | FunctionDefinition | int,
count_duplicates: bool = False) -> Any:
"""Return a consistent key for the array."""
return id(array) if count_duplicates and not isinstance(array, int) else array
Expand All @@ -610,7 +521,7 @@ def _stringify_created_at(non_equality_tags: frozenset[Tag]) -> str:
return "<unknown>"


def _emit_array(emit: Callable[[str], None], title: str, fields: Dict[str, Any],
def _emit_array(emit: Callable[[str], None], title: str, fields: dict[str, Any],
dot_node_id: str, color: str = "white") -> None:
td_attrib = 'border="0"'
table_attrib = 'border="0" cellborder="1" cellspacing="0"'
Expand All @@ -637,7 +548,7 @@ def _emit_name_cluster(
emit: DotEmitter, subgraph_path: tuple[str, ...],
names: Mapping[str, ArrayOrNames],
array_to_id: Mapping[
Union[int, ArrayOrNames], str], id_gen: Callable[[str], str],
int | ArrayOrNames, str], id_gen: Callable[[str], str],
label: str,
count_duplicates: bool = False) -> None:
edges = []
Expand All @@ -649,7 +560,7 @@ def _emit_name_cluster(

for name, array in names.items():
name_id = id_gen(dot_escape(name))
emit_cluster('%s [label="%s"]' % (name_id, dot_escape(name)))
emit_cluster(f'{name_id} [label="{dot_escape(name)}"]')
array_key = get_array_key(array, count_duplicates)
array_id = array_to_id[array_key]
# Edges must be outside the cluster.
Expand All @@ -662,13 +573,13 @@ def _emit_name_cluster(
def _emit_function(
emitter: DotEmitter, subgraph_path: tuple[str, ...],
id_gen: UniqueNameGenerator,
node_to_dot: Mapping[Union[int, ArrayOrNames], _DotNodeInfo],
node_to_dot: Mapping[int | ArrayOrNames, _DotNodeInfo],
func_to_id: Mapping[FunctionDefinition, str],
outputs: Mapping[str, Array],
count_duplicates: bool = False) -> None:
input_arrays: List[Array] = []
internal_arrays: List[Union[int, ArrayOrNames]] = []
array_to_id: Dict[Union[int, ArrayOrNames], str] = {}
input_arrays: list[Array] = []
internal_arrays: list[int | ArrayOrNames] = []
array_to_id: dict[int | ArrayOrNames, str] = {}

emit = partial(emitter, subgraph_path)
for array in node_to_dot:
Expand Down Expand Up @@ -745,13 +656,13 @@ def _gather_partition_node_information(
id_gen: UniqueNameGenerator,
partition: DistributedGraphPartition,
count_duplicates: bool = False
) -> Tuple[
Dict[PartId, Dict[FunctionDefinition, str]],
Dict[Tuple[PartId, Optional[FunctionDefinition]],
Dict[Union[int, ArrayOrNames], _DotNodeInfo]]]:
part_id_to_func_to_id: Dict[PartId, Dict[FunctionDefinition, str]] = {}
part_id_func_to_node_info: Dict[Tuple[PartId, Optional[FunctionDefinition]],
Dict[Union[int, ArrayOrNames],
) -> tuple[
dict[PartId, dict[FunctionDefinition, str]],
dict[tuple[PartId, FunctionDefinition | None],
dict[int | ArrayOrNames, _DotNodeInfo]]]:
part_id_to_func_to_id: dict[PartId, dict[FunctionDefinition, str]] = {}
part_id_func_to_node_info: dict[tuple[PartId, FunctionDefinition | None],
dict[int | ArrayOrNames,
_DotNodeInfo]] = {}

for part in partition.parts.values():
Expand Down Expand Up @@ -804,7 +715,7 @@ def gather_function_info(f: FunctionDefinition) -> None:
# }}}


def get_dot_graph(result: Union[Array, DictOfNamedArrays],
def get_dot_graph(result: Array | DictOfNamedArrays,
count_duplicates: bool = False) -> str:
r"""Return a string in the `dot <https://graphviz.org>`_ language depicting the
graph of the computation of *result*.
Expand Down Expand Up @@ -861,8 +772,8 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition,

emit_root("node [shape=rectangle]")

placeholder_to_id: Dict[Union[int, ArrayOrNames], str] = {}
part_id_to_array_to_id: Dict[PartId, Dict[Union[int, ArrayOrNames], str]] = {}
placeholder_to_id: dict[int | ArrayOrNames, str] = {}
part_id_to_array_to_id: dict[PartId, dict[int | ArrayOrNames, str]] = {}

part_id_to_id = {pid: dot_escape(str(pid)) for pid in partition.parts}
assert len(set(part_id_to_id.values())) == len(partition.parts)
Expand Down Expand Up @@ -1034,8 +945,7 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition,
for array, node in part_node_to_info.items():
key = get_array_key(array, count_duplicates)

tail_item: Union[Array, AbstractResultWithNamedArrays,
FunctionDefinition]
tail_item: Array | AbstractResultWithNamedArrays | FunctionDefinition
for label, edge_info in node.edges.items():
if isinstance(edge_info, tuple):
tail_key, tail_item = edge_info
Expand All @@ -1053,8 +963,7 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition,
raise ValueError(
f"unexpected type of tail on edge: {type(tail_item)}")

emit_root('%s -> %s [label="%s"]' %
(tail, head, dot_escape(label)))
emit_root(f'{tail} -> {head} [label="{dot_escape(label)}"]')

_emit_name_cluster(
emitter, part_subgraph_path,
Expand All @@ -1071,7 +980,7 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition,

# {{{ draw overall outputs

combined_array_to_id: Dict[Union[int, ArrayOrNames], str] = {}
combined_array_to_id: dict[int | ArrayOrNames, str] = {}
for part_id in partition.parts.keys():
combined_array_to_id.update(part_id_to_array_to_id[part_id])

Expand All @@ -1087,8 +996,7 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition,
return emitter.generate()


def show_dot_graph(result: Union[str, Array, DictOfNamedArrays,
DistributedGraphPartition],
def show_dot_graph(result: str | Array | DictOfNamedArrays | DistributedGraphPartition,
count_duplicates: bool = False,
**kwargs: Any) -> None:
"""Show a graph representing the computation of *result* in a browser.
Expand Down

0 comments on commit 724c799

Please sign in to comment.