From 25c49e29da2506d514485b001dd8fc27caf230f9 Mon Sep 17 00:00:00 2001 From: Lars Reimann Date: Mon, 8 Apr 2024 19:10:02 +0200 Subject: [PATCH] fix: more robust hashing & pickling when memoizing (#80) Closes #75 Closes #79 ### Summary of Changes * Catch errors when attempting to pickle/hash a value. * Make dicts and lambdas serializable. --------- Co-authored-by: WinPlay02 Co-authored-by: megalinter-bot <129584137+megalinter-bot@users.noreply.github.com> --- src/safeds_runner/server/_memoization_map.py | 32 ++++++++---- .../safeds_runner/server/test_memoization.py | 50 ++++++++++++++++--- 2 files changed, 66 insertions(+), 16 deletions(-) diff --git a/src/safeds_runner/server/_memoization_map.py b/src/safeds_runner/server/_memoization_map.py index 801c519..a90209a 100644 --- a/src/safeds_runner/server/_memoization_map.py +++ b/src/safeds_runner/server/_memoization_map.py @@ -1,6 +1,7 @@ """Module that contains the memoization logic and stats.""" import dataclasses +import inspect import logging import sys import time @@ -140,7 +141,11 @@ def memoized_function_call( # Lookup memoized value lookup_time_start = time.perf_counter_ns() key = self._create_memoization_key(function_name, parameters, hidden_parameters) - memoized_value = self._lookup_value(key) + try: + memoized_value = self._lookup_value(key) + # Pickling may raise AttributeError, hashing may raise TypeError + except (AttributeError, TypeError): + return function_callable(*parameters) lookup_time = time.perf_counter_ns() - lookup_time_start # Hit @@ -195,7 +200,7 @@ def _create_memoization_key( ------- A memoization key, which contains the lists converted to tuples """ - return function_name, _convert_list_to_tuple(parameters), _convert_list_to_tuple(hidden_parameters) + return function_name, _make_hashable(parameters), _make_hashable(hidden_parameters) def _lookup_value(self, key: MemoizationKey) -> Any | None: """ @@ -290,21 +295,30 @@ def _update_stats_on_miss( self._map_stats[function_name] = stats -def _convert_list_to_tuple(values: list) -> tuple: +def _make_hashable(value: Any) -> Any: """ - Recursively convert a mutable list of values to an immutable tuple containing the same values, to make the values hashable. + Make a value hashable. Parameters ---------- - values : list - Values that should be converted to a tuple + value: + Value to be converted. Returns ------- - tuple - Converted list containing all the elements of the provided list + converted_value: + Converted value. """ - return tuple(_convert_list_to_tuple(value) if isinstance(value, list) else value for value in values) + if isinstance(value, dict): + return tuple((_make_hashable(key), _make_hashable(value)) for key, value in value.items()) + elif isinstance(value, list): + return tuple(_make_hashable(element) for element in value) + elif callable(value): + # This is a band-aid solution to make callables serializable. Unfortunately, `getsource` returns more than just + # the source code for lambdas. + return inspect.getsource(value) + else: + return value def _get_size_of_value(value: Any) -> int: diff --git a/tests/safeds_runner/server/test_memoization.py b/tests/safeds_runner/server/test_memoization.py index 656fc63..33960ae 100644 --- a/tests/safeds_runner/server/test_memoization.py +++ b/tests/safeds_runner/server/test_memoization.py @@ -11,13 +11,18 @@ from safeds_runner.server._memoization_map import ( MemoizationMap, MemoizationStats, - _convert_list_to_tuple, _get_size_of_value, + _make_hashable, ) from safeds_runner.server._messages import MessageDataProgram, ProgramMainInformation from safeds_runner.server._pipeline_manager import PipelineProcess, file_mtime, memoized_call +class UnhashableClass: + def __hash__(self) -> int: + raise TypeError("unhashable type") + + @pytest.mark.parametrize( argnames="function_name,params,hidden_params,expected_result", argvalues=[ @@ -39,11 +44,13 @@ def test_memoization_already_present_values( {}, MemoizationMap({}, {}), ) - _pipeline_manager.current_pipeline.get_memoization_map()._map_values[( - function_name, - _convert_list_to_tuple(params), - _convert_list_to_tuple(hidden_params), - )] = expected_result + _pipeline_manager.current_pipeline.get_memoization_map()._map_values[ + ( + function_name, + _make_hashable(params), + _make_hashable(hidden_params), + ) + ] = expected_result _pipeline_manager.current_pipeline.get_memoization_map()._map_stats[function_name] = MemoizationStats( [time.perf_counter_ns()], [], @@ -59,8 +66,10 @@ def test_memoization_already_present_values( argvalues=[ ("function_pure", lambda a, b, c: a + b + c, [1, 2, 3], [], 6), ("function_impure_readfile", lambda filename: filename.split(".")[0], ["abc.txt"], [1234567891], "abc"), + ("function_dict", lambda x: len(x), [{}], [], 0), + ("function_lambda", lambda x: x(), [lambda: 0], [], 0), ], - ids=["function_pure", "function_impure_readfile"], + ids=["function_pure", "function_impure_readfile", "function_dict", "function_lambda"], ) def test_memoization_not_present_values( function_name: str, @@ -84,6 +93,33 @@ def test_memoization_not_present_values( assert result2 == expected_result +@pytest.mark.parametrize( + argnames="function_name,function,params,hidden_params,expected_result", + argvalues=[ + ("unhashable_params", lambda a: type(a).__name__, [UnhashableClass()], [], "UnhashableClass"), + ("unhashable_hidden_params", lambda: None, [], [UnhashableClass()], None), + ], + ids=["unhashable_params", "unhashable_hidden_params"], +) +def test_memoization_unhashable_values( + function_name: str, + function: typing.Callable, + params: list, + hidden_params: list, + expected_result: Any, +) -> None: + _pipeline_manager.current_pipeline = PipelineProcess( + MessageDataProgram({}, ProgramMainInformation("", "", "")), + "", + Queue(), + {}, + MemoizationMap({}, {}), + ) + + result = memoized_call(function_name, function, params, hidden_params) + assert result == expected_result + + def test_file_mtime_exists() -> None: with tempfile.NamedTemporaryFile() as file: mtime = file_mtime(file.name)