From b2d71a36a0ec0316ba1e222fdd9166943dddc49c Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 25 Sep 2023 17:36:38 -0500 Subject: [PATCH] add PytatoKeyBuilder --- pytato/analysis/__init__.py | 29 +++++++++++++++++++++++++++++ test/test_pytato.py | 20 ++++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index d0fd6ef1e..01bd81c9c 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -36,6 +36,7 @@ from pytato.loopy import LoopyCall from pymbolic.mapper.optimize import optimize_mapper from pytools import memoize_method +from loopy.tools import LoopyKeyBuilder, PersistentHashWalkMapper if TYPE_CHECKING: from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder @@ -453,4 +454,32 @@ def get_num_call_sites(outputs: Union[Array, DictOfNamedArrays]) -> int: # }}} + +# {{{ PytatoKeyBuilder + +class PytatoKeyBuilder(LoopyKeyBuilder): + """A custom :class:`pytools.persistent_dict.KeyBuilder` subclass + for objects within :mod:`pytato`. + """ + + def update_for_ndarray(self, key_hash: Any, key: Any) -> None: + self.rec(key_hash, hash(key.data.tobytes())) # type: ignore[no-untyped-call] + + def update_for_pymbolic_expression(self, key_hash: Any, key: Any) -> None: + if key is None: + self.update_for_NoneType(key_hash, key) # type: ignore[no-untyped-call] + else: + PersistentHashWalkMapper(key_hash)(key) + + update_for_Product = update_for_pymbolic_expression # noqa: N815 + update_for_Sum = update_for_pymbolic_expression # noqa: N815 + update_for_If = update_for_pymbolic_expression # noqa: N815 + update_for_LogicalOr = update_for_pymbolic_expression # noqa: N815 + update_for_Call = update_for_pymbolic_expression # noqa: N815 + update_for_Comparison = update_for_pymbolic_expression # noqa: N815 + update_for_Quotient = update_for_pymbolic_expression # noqa: N815 + update_for_Power = update_for_pymbolic_expression # noqa: N815 + +# }}} + # vim: fdm=marker diff --git a/test/test_pytato.py b/test/test_pytato.py index 98393b95b..7e73c6267 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1116,6 +1116,26 @@ def test_dot_visualizers(): # }}} +def test_persistent_dict(): + from pytools.persistent_dict import WriteOncePersistentDict, ReadOnlyEntryError + from pytato.analysis import PytatoKeyBuilder + + axis_len = 5 + + pd = WriteOncePersistentDict("test_persistent_dict", key_builder=PytatoKeyBuilder(), container_dir="./pytest-pdict") + + for i in range(100): + rdagc = RandomDAGContext(np.random.default_rng(seed=i), + axis_len=axis_len, use_numpy=True) + + dag = make_random_dag(rdagc) + pd[dag] = 42 + + # Make sure key stays the same + with pytest.raises(ReadOnlyEntryError): + pd[dag] = 42 + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])