Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add PytatoKeyBuilder, persistent_dict test #459

Open
wants to merge 29 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
38e4332
add PytatoKeyBuilder
matthiasdiener Sep 25, 2023
4dd3250
mypy fixes
matthiasdiener Sep 25, 2023
970e7bb
support TaggableCLArray, Subscript
matthiasdiener Sep 28, 2023
95dec09
CL Array, function
matthiasdiener Sep 28, 2023
2ac10ee
add prim.Variable
matthiasdiener Feb 5, 2024
62a13ae
fixes to ndarray, pymb expressions
matthiasdiener Feb 5, 2024
b8e04bf
flake8
matthiasdiener Feb 5, 2024
ad9aa28
improve test
matthiasdiener Feb 5, 2024
60d8e41
add full invocation test
matthiasdiener Feb 5, 2024
9d45e65
lint fixes
matthiasdiener Feb 5, 2024
08be380
add missing pymbolic expressions
matthiasdiener Feb 5, 2024
058f6f9
flake8
matthiasdiener Feb 6, 2024
352bab6
Merge branch 'main' into PytatoKeyBuilder
matthiasdiener Jun 13, 2024
0360e21
remove update_for_function (now handled directly by pytools)
matthiasdiener Jun 13, 2024
8e3277c
Merge remote-tracking branch 'refs/remotes/origin/PytatoKeyBuilder' i…
matthiasdiener Jun 13, 2024
b1aaa97
Merge branch 'main' into PytatoKeyBuilder
matthiasdiener Jul 3, 2024
16516ec
Merge branch 'main' into PytatoKeyBuilder
matthiasdiener Jul 3, 2024
454f273
Merge branch 'main' into PytatoKeyBuilder
matthiasdiener Jul 25, 2024
993dfe4
Merge branch 'main' into PytatoKeyBuilder
matthiasdiener Sep 7, 2024
82a5f25
lint
matthiasdiener Sep 7, 2024
dc53746
Merge branch 'main' into PytatoKeyBuilder
matthiasdiener Sep 9, 2024
97df5d7
Merge branch 'main' into PytatoKeyBuilder
matthiasdiener Sep 19, 2024
93abdc1
Merge branch 'main' into PytatoKeyBuilder
matthiasdiener Sep 27, 2024
b35d841
Merge branch 'main' into PytatoKeyBuilder
matthiasdiener Oct 9, 2024
70a6887
add typecheck, remove pymbolic handling
matthiasdiener Oct 9, 2024
4f9f95f
lint
matthiasdiener Oct 9, 2024
cffda36
pylint
matthiasdiener Oct 9, 2024
47a31d6
Merge branch 'main' into PytatoKeyBuilder
matthiasdiener Nov 7, 2024
57908b1
Merge branch 'main' into PytatoKeyBuilder
matthiasdiener Nov 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions pytato/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from typing import TYPE_CHECKING, Any, Mapping

from loopy.tools import LoopyKeyBuilder
from pymbolic.mapper.optimize import optimize_mapper
from pytools import memoize_method

Expand Down Expand Up @@ -564,4 +565,47 @@ def get_num_call_sites(outputs: Array | DictOfNamedArrays) -> int:

# }}}


# {{{ PytatoKeyBuilder

class PytatoKeyBuilder(LoopyKeyBuilder):
"""A custom :class:`pytools.persistent_dict.KeyBuilder` subclass
for objects within :mod:`pytato`.
"""

def update_for_ndarray(self, key_hash: Any, key: Any) -> None:
self.rec(key_hash, key.data.tobytes())

def update_for_TaggableCLArray(self, key_hash: Any, key: Any) -> None:
self.rec(key_hash, key.get())

def update_for_Array(self, key_hash: Any, key: Any) -> None:
# CL Array
self.rec(key_hash, key.get())
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you use an isinstance check to make sure that this only hashes the intended types?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added in 70a6887


update_for_BitwiseAnd = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815
update_for_BitwiseNot = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815
update_for_BitwiseOr = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815
update_for_BitwiseXor = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815
update_for_Call = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815
update_for_CallWithKwargs = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815
update_for_Comparison = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815
update_for_If = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815
update_for_FloorDiv = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815
update_for_LeftShift = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815
update_for_LogicalAnd = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815
update_for_LogicalNot = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815
update_for_LogicalOr = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815
update_for_Lookup = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815
update_for_Power = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815
update_for_Product = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815
update_for_Quotient = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815
update_for_Remainder = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815
update_for_RightShift = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815
update_for_Subscript = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815
update_for_Sum = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815
update_for_Variable = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All this can go now that inducer/pymbolic#125 is in.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed in 70a6887


# }}}

# vim: fdm=marker
96 changes: 95 additions & 1 deletion test/test_pytato.py
Original file line number Diff line number Diff line change
Expand Up @@ -1364,6 +1364,97 @@ def test_dot_visualizers():
# }}}


# {{{ Test PytatoKeyBuilder

def run_test_with_new_python_invocation(f, *args, extra_env_vars=None) -> None:
import os
if extra_env_vars is None:
extra_env_vars = {}

from base64 import b64encode
from pickle import dumps
from subprocess import check_call

env_vars = {
"INVOCATION_INFO": b64encode(dumps((f, args))).decode(),
}
env_vars.update(extra_env_vars)

my_env = os.environ.copy()
my_env.update(env_vars)

check_call([sys.executable, __file__], env=my_env)


def run_test_with_new_python_invocation_inner() -> None:
import os
from base64 import b64decode
from pickle import loads

f, args = loads(b64decode(os.environ["INVOCATION_INFO"].encode()))

f(*args)


def test_persistent_hashing_and_persistent_dict() -> None:
import shutil
import tempfile

from pytools.persistent_dict import ReadOnlyEntryError, WriteOncePersistentDict

from pytato.analysis import PytatoKeyBuilder

try:
tmpdir = tempfile.mkdtemp()

pkb = PytatoKeyBuilder()

pd = WriteOncePersistentDict("test_persistent_dict",
key_builder=pkb,
container_dir=tmpdir)

for i in range(100):
rdagc = RandomDAGContext(np.random.default_rng(seed=i),
axis_len=5, use_numpy=True)

dag = make_random_dag(rdagc)

# Make sure the PytatoKeyBuilder can handle 'dag'
pd[dag] = 42

# Make sure that the key stays the same within the same Python invocation
with pytest.raises(ReadOnlyEntryError):
pd[dag] = 42

# Make sure that the key stays the same across Python invocations
run_test_with_new_python_invocation(
_test_persistent_hashing_and_persistent_dict_stage2, tmpdir)
finally:
shutil.rmtree(tmpdir)


def _test_persistent_hashing_and_persistent_dict_stage2(tmpdir) -> None:
from pytools.persistent_dict import ReadOnlyEntryError, WriteOncePersistentDict

from pytato.analysis import PytatoKeyBuilder
pkb = PytatoKeyBuilder()

pd = WriteOncePersistentDict("test_persistent_dict",
key_builder=pkb,
container_dir=tmpdir)

for i in range(100):
rdagc = RandomDAGContext(np.random.default_rng(seed=i),
axis_len=5, use_numpy=True)

dag = make_random_dag(rdagc)

with pytest.raises(ReadOnlyEntryError):
pd[dag] = 42

# }}}


def test_numpy_type_promotion_with_pytato_arrays():
class NotReallyAnArray:
@property
Expand All @@ -1380,7 +1471,10 @@ def dtype(self):


if __name__ == "__main__":
if len(sys.argv) > 1:
import os
if "INVOCATION_INFO" in os.environ:
run_test_with_new_python_invocation_inner()
elif len(sys.argv) > 1:
exec(sys.argv[1])
else:
from pytest import main
Expand Down
Loading