Skip to content

Commit

Permalink
keep hash consistent for enum in newer Python versions
Browse files Browse the repository at this point in the history
Fix #188
  • Loading branch information
albertz committed May 7, 2024
1 parent 0da764a commit 6751a98
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
13 changes: 12 additions & 1 deletion sisyphus/hash.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import enum
import hashlib
from inspect import isclass, isfunction

Expand Down Expand Up @@ -61,6 +62,10 @@ def get_object_state(obj):
assert args is not None, "Failed to get object state of: %s" % repr(obj)
state = None

if isinstance(obj, enum.Enum):
assert isinstance(state, dict)
state.pop("_sort_order_", None) # compat with Python <=3.10, https://github.com/rwth-i6/sisyphus/issues/188

if args is None:
return state
else:
Expand All @@ -76,7 +81,7 @@ def sis_hash_helper(obj):
"""

# Store type to ensure it's unique
byte_list = [type(obj).__qualname__.encode()]
byte_list = [_obj_type_qualname(obj)]

# Using type and not isinstance to avoid derived types
if isinstance(obj, bytes):
Expand Down Expand Up @@ -116,3 +121,9 @@ def sis_hash_helper(obj):
return hashlib.sha256(byte_str).digest()
else:
return byte_str


def _obj_type_qualname(obj) -> bytes:
if type(obj) is enum.EnumMeta: # EnumMeta is old alias for EnumType
return b"EnumMeta" # compat with Python <=3.10, https://github.com/rwth-i6/sisyphus/issues/188
return type(obj).__qualname__.encode()
14 changes: 14 additions & 0 deletions tests/hash_unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ def b():
pass


class MyEnum(enum.Enum):
Entry0 = 0
Entry1 = 1


class HashTest(unittest.TestCase):
def test_get_object_state(self):

Expand All @@ -18,6 +23,15 @@ def d():
self.assertEqual(sis_hash_helper(b), b"(function, (tuple, (str, '" + __name__.encode() + b"'), (str, 'b')))")
self.assertRaises(AssertionError, sis_hash_helper, c)

def test_enum(self):
self.assertEqual(
sis_hash_helper(MyEnum.Entry1),
b"(%s, (dict, (tuple, (str, '__objclass__')," % MyEnum.__name__.encode()
+ b" (EnumMeta, (tuple, (str, '%s'), (str, '%s')))),"
% (MyEnum.__module__.encode(), MyEnum.__name__.encode())
+ b" (tuple, (str, '_name_'), (str, 'Entry1')), (tuple, (str, '_value_'), (int, 1))))",
)


if __name__ == "__main__":
unittest.main()

0 comments on commit 6751a98

Please sign in to comment.