From dd5d3596ebe15f58cd7f9a7ef38fbf363dd013b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tiziano=20M=C3=BCller?= Date: Thu, 27 May 2021 10:24:24 +0200 Subject: [PATCH] orm: permit Decimal in attributes --- aiida/common/hashing.py | 16 +++++++++++++++- aiida/orm/implementation/utils.py | 3 ++- tests/common/test_hashing.py | 13 +++++++++++++ tests/orm/node/test_node.py | 8 ++++++++ 4 files changed, 38 insertions(+), 2 deletions(-) diff --git a/aiida/common/hashing.py b/aiida/common/hashing.py index 61dbbaf1c4..0c1f86be8d 100644 --- a/aiida/common/hashing.py +++ b/aiida/common/hashing.py @@ -19,6 +19,7 @@ from functools import singledispatch from itertools import chain from operator import itemgetter +from decimal import Decimal import pytz @@ -222,11 +223,24 @@ def _(mapping, **kwargs): def _(val, **kwargs): """ Before hashing a float, convert to a string (via rounding) and with a fixed number of digits after the comma. - Note that the `_singe_digest` requires a bytes object so we need to encode the utf-8 string first + Note that the `_single_digest` requires a bytes object so we need to encode the utf-8 string first """ return [_single_digest('float', float_to_text(val, sig=AIIDA_FLOAT_PRECISION).encode('utf-8'))] +@_make_hash.register(Decimal) +def _(val, **kwargs): + """ + While a decimal can be converted exactly to a string which captures all characteristics of the underlying + implementation, we also need compatibility with "equal" representations as int or float. Hence we are checking + for the exponent (which is negative if there is a fractional component, 0 otherwise) and get the same hash + as for a corresponding float or int. + """ + if val.as_tuple().exponent < 0: + return [_single_digest('float', float_to_text(val, sig=AIIDA_FLOAT_PRECISION).encode('utf-8'))] + return [_single_digest('int', f'{val}'.encode('utf-8'))] + + @_make_hash.register(numbers.Complex) def _(val, **kwargs): """ diff --git a/aiida/orm/implementation/utils.py b/aiida/orm/implementation/utils.py index 2964cf6865..f84482e869 100644 --- a/aiida/orm/implementation/utils.py +++ b/aiida/orm/implementation/utils.py @@ -10,6 +10,7 @@ """Utility methods for backend non-specific implementations.""" import math import numbers +from decimal import Decimal from collections.abc import Iterable, Mapping @@ -63,7 +64,7 @@ def clean_builtin(val): It mainly checks that we don't store NaN or Inf. """ # This is a whitelist of all the things we understand currently - if val is None or isinstance(val, (bool, str)): + if val is None or isinstance(val, (bool, str, Decimal)): return val # This fixes #2773 - in python3, ``numpy.int64(-1)`` cannot be json-serialized diff --git a/tests/common/test_hashing.py b/tests/common/test_hashing.py index 76a85dddec..50656d2043 100644 --- a/tests/common/test_hashing.py +++ b/tests/common/test_hashing.py @@ -16,6 +16,7 @@ from datetime import datetime import hashlib import uuid +from decimal import Decimal import numpy as np import pytz @@ -175,6 +176,18 @@ def test_numpy_types(self): ) # pylint: disable=no-member self.assertEqual(make_hash(np.int64(42)), '9468692328de958d7a8039e8a2eb05cd6888b7911bbc3794d0dfebd8df3482cd') # pylint: disable=no-member + def test_decimal(self): + self.assertEqual( + make_hash(Decimal('3.141')), 'b3302aad550413e14fe44d5ead10b3aeda9884055fca77f9368c48517916d4be' + ) # pylint: disable=no-member + + # make sure we get the same hashes as for corresponding float or int + self.assertEqual(make_hash(Decimal('3.141')), make_hash(3.141)) # pylint: disable=no-member + + self.assertEqual(make_hash(Decimal('3.')), make_hash(3)) # pylint: disable=no-member + + self.assertEqual(make_hash(Decimal('3141')), make_hash(3141)) # pylint: disable=no-member + def test_unhashable_type(self): class MadeupClass: diff --git a/tests/orm/node/test_node.py b/tests/orm/node/test_node.py index 280b0c7347..f95955ff22 100644 --- a/tests/orm/node/test_node.py +++ b/tests/orm/node/test_node.py @@ -12,6 +12,7 @@ import logging import os import tempfile +from decimal import Decimal import pytest @@ -416,6 +417,13 @@ def test_extras_keys(self): self.node.set_extra_many(extras) assert set(self.node.extras_keys()) == set(extras) + def test_attribute_decimal(self): + """Test that the `Node.set_attribute` method supports Decimal.""" + self.node.set_attribute('a_val', Decimal('3.141')) + self.node.store() + # ensure the returned node is a float + assert self.node.get_attribute('a_val') == 3.141 + @pytest.mark.usefixtures('clear_database_before_test_class') class TestNodeLinks: