diff --git a/test/test_pytato.py b/test/test_pytato.py index be7da1476..4bca10925 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1115,14 +1115,42 @@ def test_dot_visualizers(): # }}} -def test_persistent_hashing_and_persistent_dict(): +# {{{ Test PytatoKeyBuilder + +def run_test_with_new_python_invocation(f, *args, extra_env_vars = None) -> None: + import os + if extra_env_vars is None: + extra_env_vars = {} + + from base64 import b64encode + from pickle import dumps + from subprocess import check_call + + env_vars = { + "INVOCATION_INFO": b64encode(dumps((f, args))).decode(), + } + env_vars.update(extra_env_vars) + + my_env = os.environ.copy() + my_env.update(env_vars) + + check_call([sys.executable, __file__], env=my_env) + + +def run_test_with_new_python_invocation_inner() -> None: + from base64 import b64decode + from pickle import loads + f, args = loads(b64decode(os.environ["INVOCATION_INFO"].encode())) + + f(*args) + + +def test_persistent_hashing_and_persistent_dict() -> None: from pytools.persistent_dict import WriteOncePersistentDict, ReadOnlyEntryError from pytato.analysis import PytatoKeyBuilder import shutil import tempfile - axis_len = 5 - try: tmpdir = tempfile.mkdtemp() @@ -1134,26 +1162,50 @@ def test_persistent_hashing_and_persistent_dict(): for i in range(100): rdagc = RandomDAGContext(np.random.default_rng(seed=i), - axis_len=axis_len, use_numpy=True) + axis_len=5, use_numpy=True) dag = make_random_dag(rdagc) # Make sure the PytatoKeyBuilder can handle 'dag' pd[dag] = 42 - # make sure the key stays the same across invocations - if i == 0: - assert pkb(dag) == "eaa8ad49c9490cb6f0b61a33c17d0c2fd10fafc6ce02705105cc9c379c91b9c8" - - # Make sure key stays the same + # Make sure that the key stays the same within the same Python invocation with pytest.raises(ReadOnlyEntryError): pd[dag] = 42 + + # Make sure that the key stays the same across Python invocations + run_test_with_new_python_invocation(_test_persistent_hashing_and_persistent_dict_stage2, + tmpdir) finally: shutil.rmtree(tmpdir) +def _test_persistent_hashing_and_persistent_dict_stage2(tmpdir) -> None: + from pytools.persistent_dict import WriteOncePersistentDict, ReadOnlyEntryError + + from pytato.analysis import PytatoKeyBuilder + pkb = PytatoKeyBuilder() + + pd = WriteOncePersistentDict("test_persistent_dict", + key_builder=pkb, + container_dir=tmpdir) + + for i in range(100): + rdagc = RandomDAGContext(np.random.default_rng(seed=i), + axis_len=5, use_numpy=True) + + dag = make_random_dag(rdagc) + + with pytest.raises(ReadOnlyEntryError): + pd[dag] = 42 + +# }}} + if __name__ == "__main__": - if len(sys.argv) > 1: + import os + if "INVOCATION_INFO" in os.environ: + run_test_with_new_python_invocation_inner() + elif len(sys.argv) > 1: exec(sys.argv[1]) else: from pytest import main