Skip to content

Commit

Permalink
add full invocation test
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener authored Feb 5, 2024
1 parent cac2e3c commit 544b677
Showing 1 changed file with 62 additions and 10 deletions.
72 changes: 62 additions & 10 deletions test/test_pytato.py
Original file line number Diff line number Diff line change
Expand Up @@ -1115,14 +1115,42 @@ def test_dot_visualizers():
# }}}


def test_persistent_hashing_and_persistent_dict():
# {{{ Test PytatoKeyBuilder

def run_test_with_new_python_invocation(f, *args, extra_env_vars = None) -> None:
import os
if extra_env_vars is None:
extra_env_vars = {}

from base64 import b64encode
from pickle import dumps
from subprocess import check_call

env_vars = {
"INVOCATION_INFO": b64encode(dumps((f, args))).decode(),
}
env_vars.update(extra_env_vars)

my_env = os.environ.copy()
my_env.update(env_vars)

check_call([sys.executable, __file__], env=my_env)


def run_test_with_new_python_invocation_inner() -> None:
from base64 import b64decode
from pickle import loads
f, args = loads(b64decode(os.environ["INVOCATION_INFO"].encode()))

f(*args)


def test_persistent_hashing_and_persistent_dict() -> None:
from pytools.persistent_dict import WriteOncePersistentDict, ReadOnlyEntryError
from pytato.analysis import PytatoKeyBuilder
import shutil
import tempfile

axis_len = 5

try:
tmpdir = tempfile.mkdtemp()

Expand All @@ -1134,26 +1162,50 @@ def test_persistent_hashing_and_persistent_dict():

for i in range(100):
rdagc = RandomDAGContext(np.random.default_rng(seed=i),
axis_len=axis_len, use_numpy=True)
axis_len=5, use_numpy=True)

dag = make_random_dag(rdagc)

# Make sure the PytatoKeyBuilder can handle 'dag'
pd[dag] = 42

# make sure the key stays the same across invocations
if i == 0:
assert pkb(dag) == "eaa8ad49c9490cb6f0b61a33c17d0c2fd10fafc6ce02705105cc9c379c91b9c8"

# Make sure key stays the same
# Make sure that the key stays the same within the same Python invocation
with pytest.raises(ReadOnlyEntryError):
pd[dag] = 42

# Make sure that the key stays the same across Python invocations
run_test_with_new_python_invocation(_test_persistent_hashing_and_persistent_dict_stage2,
tmpdir)
finally:
shutil.rmtree(tmpdir)

def _test_persistent_hashing_and_persistent_dict_stage2(tmpdir) -> None:
from pytools.persistent_dict import WriteOncePersistentDict, ReadOnlyEntryError

from pytato.analysis import PytatoKeyBuilder
pkb = PytatoKeyBuilder()

pd = WriteOncePersistentDict("test_persistent_dict",
key_builder=pkb,
container_dir=tmpdir)

for i in range(100):
rdagc = RandomDAGContext(np.random.default_rng(seed=i),
axis_len=5, use_numpy=True)

dag = make_random_dag(rdagc)

with pytest.raises(ReadOnlyEntryError):
pd[dag] = 42

# }}}


if __name__ == "__main__":
if len(sys.argv) > 1:
import os
if "INVOCATION_INFO" in os.environ:
run_test_with_new_python_invocation_inner()
elif len(sys.argv) > 1:
exec(sys.argv[1])
else:
from pytest import main
Expand Down

0 comments on commit 544b677

Please sign in to comment.