Skip to content

Commit

Permalink
fix: more robust hashing & pickling when memoizing (#80)
Browse files Browse the repository at this point in the history
Closes #75
Closes #79

### Summary of Changes

* Catch errors when attempting to pickle/hash a value.
* Make dicts and lambdas serializable.

---------

Co-authored-by: WinPlay02 <[email protected]>
Co-authored-by: megalinter-bot <[email protected]>
  • Loading branch information
3 people authored Apr 8, 2024
1 parent 32974d0 commit 25c49e2
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 16 deletions.
32 changes: 23 additions & 9 deletions src/safeds_runner/server/_memoization_map.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Module that contains the memoization logic and stats."""

import dataclasses
import inspect
import logging
import sys
import time
Expand Down Expand Up @@ -140,7 +141,11 @@ def memoized_function_call(
# Lookup memoized value
lookup_time_start = time.perf_counter_ns()
key = self._create_memoization_key(function_name, parameters, hidden_parameters)
memoized_value = self._lookup_value(key)
try:
memoized_value = self._lookup_value(key)
# Pickling may raise AttributeError, hashing may raise TypeError
except (AttributeError, TypeError):
return function_callable(*parameters)
lookup_time = time.perf_counter_ns() - lookup_time_start

# Hit
Expand Down Expand Up @@ -195,7 +200,7 @@ def _create_memoization_key(
-------
A memoization key, which contains the lists converted to tuples
"""
return function_name, _convert_list_to_tuple(parameters), _convert_list_to_tuple(hidden_parameters)
return function_name, _make_hashable(parameters), _make_hashable(hidden_parameters)

def _lookup_value(self, key: MemoizationKey) -> Any | None:
"""
Expand Down Expand Up @@ -290,21 +295,30 @@ def _update_stats_on_miss(
self._map_stats[function_name] = stats


def _convert_list_to_tuple(values: list) -> tuple:
def _make_hashable(value: Any) -> Any:
"""
Recursively convert a mutable list of values to an immutable tuple containing the same values, to make the values hashable.
Make a value hashable.
Parameters
----------
values : list
Values that should be converted to a tuple
value:
Value to be converted.
Returns
-------
tuple
Converted list containing all the elements of the provided list
converted_value:
Converted value.
"""
return tuple(_convert_list_to_tuple(value) if isinstance(value, list) else value for value in values)
if isinstance(value, dict):
return tuple((_make_hashable(key), _make_hashable(value)) for key, value in value.items())
elif isinstance(value, list):
return tuple(_make_hashable(element) for element in value)
elif callable(value):
# This is a band-aid solution to make callables serializable. Unfortunately, `getsource` returns more than just
# the source code for lambdas.
return inspect.getsource(value)
else:
return value


def _get_size_of_value(value: Any) -> int:
Expand Down
50 changes: 43 additions & 7 deletions tests/safeds_runner/server/test_memoization.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,18 @@
from safeds_runner.server._memoization_map import (
MemoizationMap,
MemoizationStats,
_convert_list_to_tuple,
_get_size_of_value,
_make_hashable,
)
from safeds_runner.server._messages import MessageDataProgram, ProgramMainInformation
from safeds_runner.server._pipeline_manager import PipelineProcess, file_mtime, memoized_call


class UnhashableClass:
def __hash__(self) -> int:
raise TypeError("unhashable type")


@pytest.mark.parametrize(
argnames="function_name,params,hidden_params,expected_result",
argvalues=[
Expand All @@ -39,11 +44,13 @@ def test_memoization_already_present_values(
{},
MemoizationMap({}, {}),
)
_pipeline_manager.current_pipeline.get_memoization_map()._map_values[(
function_name,
_convert_list_to_tuple(params),
_convert_list_to_tuple(hidden_params),
)] = expected_result
_pipeline_manager.current_pipeline.get_memoization_map()._map_values[
(
function_name,
_make_hashable(params),
_make_hashable(hidden_params),
)
] = expected_result
_pipeline_manager.current_pipeline.get_memoization_map()._map_stats[function_name] = MemoizationStats(
[time.perf_counter_ns()],
[],
Expand All @@ -59,8 +66,10 @@ def test_memoization_already_present_values(
argvalues=[
("function_pure", lambda a, b, c: a + b + c, [1, 2, 3], [], 6),
("function_impure_readfile", lambda filename: filename.split(".")[0], ["abc.txt"], [1234567891], "abc"),
("function_dict", lambda x: len(x), [{}], [], 0),
("function_lambda", lambda x: x(), [lambda: 0], [], 0),
],
ids=["function_pure", "function_impure_readfile"],
ids=["function_pure", "function_impure_readfile", "function_dict", "function_lambda"],
)
def test_memoization_not_present_values(
function_name: str,
Expand All @@ -84,6 +93,33 @@ def test_memoization_not_present_values(
assert result2 == expected_result


@pytest.mark.parametrize(
argnames="function_name,function,params,hidden_params,expected_result",
argvalues=[
("unhashable_params", lambda a: type(a).__name__, [UnhashableClass()], [], "UnhashableClass"),
("unhashable_hidden_params", lambda: None, [], [UnhashableClass()], None),
],
ids=["unhashable_params", "unhashable_hidden_params"],
)
def test_memoization_unhashable_values(
function_name: str,
function: typing.Callable,
params: list,
hidden_params: list,
expected_result: Any,
) -> None:
_pipeline_manager.current_pipeline = PipelineProcess(
MessageDataProgram({}, ProgramMainInformation("", "", "")),
"",
Queue(),
{},
MemoizationMap({}, {}),
)

result = memoized_call(function_name, function, params, hidden_params)
assert result == expected_result


def test_file_mtime_exists() -> None:
with tempfile.NamedTemporaryFile() as file:
mtime = file_mtime(file.name)
Expand Down

0 comments on commit 25c49e2

Please sign in to comment.