Skip to content

Commit

Permalink
orm: permit Decimal in attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
dev-zero committed Jul 21, 2021
1 parent 758ebf1 commit dd5d359
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 2 deletions.
16 changes: 15 additions & 1 deletion aiida/common/hashing.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from functools import singledispatch
from itertools import chain
from operator import itemgetter
from decimal import Decimal

import pytz

Expand Down Expand Up @@ -222,11 +223,24 @@ def _(mapping, **kwargs):
def _(val, **kwargs):
"""
Before hashing a float, convert to a string (via rounding) and with a fixed number of digits after the comma.
Note that the `_singe_digest` requires a bytes object so we need to encode the utf-8 string first
Note that the `_single_digest` requires a bytes object so we need to encode the utf-8 string first
"""
return [_single_digest('float', float_to_text(val, sig=AIIDA_FLOAT_PRECISION).encode('utf-8'))]


@_make_hash.register(Decimal)
def _(val, **kwargs):
"""
While a decimal can be converted exactly to a string which captures all characteristics of the underlying
implementation, we also need compatibility with "equal" representations as int or float. Hence we are checking
for the exponent (which is negative if there is a fractional component, 0 otherwise) and get the same hash
as for a corresponding float or int.
"""
if val.as_tuple().exponent < 0:
return [_single_digest('float', float_to_text(val, sig=AIIDA_FLOAT_PRECISION).encode('utf-8'))]
return [_single_digest('int', f'{val}'.encode('utf-8'))]


@_make_hash.register(numbers.Complex)
def _(val, **kwargs):
"""
Expand Down
3 changes: 2 additions & 1 deletion aiida/orm/implementation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"""Utility methods for backend non-specific implementations."""
import math
import numbers
from decimal import Decimal

from collections.abc import Iterable, Mapping

Expand Down Expand Up @@ -63,7 +64,7 @@ def clean_builtin(val):
It mainly checks that we don't store NaN or Inf.
"""
# This is a whitelist of all the things we understand currently
if val is None or isinstance(val, (bool, str)):
if val is None or isinstance(val, (bool, str, Decimal)):
return val

# This fixes #2773 - in python3, ``numpy.int64(-1)`` cannot be json-serialized
Expand Down
13 changes: 13 additions & 0 deletions tests/common/test_hashing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from datetime import datetime
import hashlib
import uuid
from decimal import Decimal

import numpy as np
import pytz
Expand Down Expand Up @@ -175,6 +176,18 @@ def test_numpy_types(self):
) # pylint: disable=no-member
self.assertEqual(make_hash(np.int64(42)), '9468692328de958d7a8039e8a2eb05cd6888b7911bbc3794d0dfebd8df3482cd') # pylint: disable=no-member

def test_decimal(self):
self.assertEqual(
make_hash(Decimal('3.141')), 'b3302aad550413e14fe44d5ead10b3aeda9884055fca77f9368c48517916d4be'
) # pylint: disable=no-member

# make sure we get the same hashes as for corresponding float or int
self.assertEqual(make_hash(Decimal('3.141')), make_hash(3.141)) # pylint: disable=no-member

self.assertEqual(make_hash(Decimal('3.')), make_hash(3)) # pylint: disable=no-member

self.assertEqual(make_hash(Decimal('3141')), make_hash(3141)) # pylint: disable=no-member

def test_unhashable_type(self):

class MadeupClass:
Expand Down
8 changes: 8 additions & 0 deletions tests/orm/node/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import logging
import os
import tempfile
from decimal import Decimal

import pytest

Expand Down Expand Up @@ -416,6 +417,13 @@ def test_extras_keys(self):
self.node.set_extra_many(extras)
assert set(self.node.extras_keys()) == set(extras)

def test_attribute_decimal(self):
"""Test that the `Node.set_attribute` method supports Decimal."""
self.node.set_attribute('a_val', Decimal('3.141'))
self.node.store()
# ensure the returned node is a float
assert self.node.get_attribute('a_val') == 3.141


@pytest.mark.usefixtures('clear_database_before_test_class')
class TestNodeLinks:
Expand Down

0 comments on commit dd5d359

Please sign in to comment.