Skip to content

Commit

Permalink
Keep hash consistent for enum in newer Python versions (#189)
Browse files Browse the repository at this point in the history
Fix #188
  • Loading branch information
albertz authored May 7, 2024
1 parent 0da764a commit 816ffac
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 1 deletion.
15 changes: 14 additions & 1 deletion sisyphus/hash.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import enum
import hashlib
from inspect import isclass, isfunction

Expand Down Expand Up @@ -61,6 +62,11 @@ def get_object_state(obj):
assert args is not None, "Failed to get object state of: %s" % repr(obj)
state = None

if isinstance(obj, enum.Enum):
assert isinstance(state, dict)
# In Python >=3.11, keep hash same as in Python <=3.10, https://github.com/rwth-i6/sisyphus/issues/188
state.pop("_sort_order_", None)

if args is None:
return state
else:
Expand All @@ -76,7 +82,7 @@ def sis_hash_helper(obj):
"""

# Store type to ensure it's unique
byte_list = [type(obj).__qualname__.encode()]
byte_list = [_obj_type_qualname(obj)]

# Using type and not isinstance to avoid derived types
if isinstance(obj, bytes):
Expand Down Expand Up @@ -116,3 +122,10 @@ def sis_hash_helper(obj):
return hashlib.sha256(byte_str).digest()
else:
return byte_str


def _obj_type_qualname(obj) -> bytes:
if type(obj) is enum.EnumMeta: # EnumMeta is old alias for EnumType
# In Python >=3.11, keep hash same as in Python <=3.10, https://github.com/rwth-i6/sisyphus/issues/188
return b"EnumMeta"
return type(obj).__qualname__.encode()
14 changes: 14 additions & 0 deletions tests/hash_unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ def b():
pass


class MyEnum(enum.Enum):
Entry0 = 0
Entry1 = 1


class HashTest(unittest.TestCase):
def test_get_object_state(self):

Expand All @@ -18,6 +23,15 @@ def d():
self.assertEqual(sis_hash_helper(b), b"(function, (tuple, (str, '" + __name__.encode() + b"'), (str, 'b')))")
self.assertRaises(AssertionError, sis_hash_helper, c)

def test_enum(self):
self.assertEqual(
sis_hash_helper(MyEnum.Entry1),
b"(%s, (dict, (tuple, (str, '__objclass__')," % MyEnum.__name__.encode()
+ b" (EnumMeta, (tuple, (str, '%s'), (str, '%s')))),"
% (MyEnum.__module__.encode(), MyEnum.__name__.encode())
+ b" (tuple, (str, '_name_'), (str, 'Entry1')), (tuple, (str, '_value_'), (int, 1))))",
)


if __name__ == "__main__":
unittest.main()

0 comments on commit 816ffac

Please sign in to comment.