diff --git a/sisyphus/hash.py b/sisyphus/hash.py index d4de9ee..2b61eae 100644 --- a/sisyphus/hash.py +++ b/sisyphus/hash.py @@ -1,3 +1,4 @@ +import enum import hashlib from inspect import isclass, isfunction @@ -61,6 +62,10 @@ def get_object_state(obj): assert args is not None, "Failed to get object state of: %s" % repr(obj) state = None + if isinstance(obj, enum.Enum): + assert isinstance(state, dict) + state.pop("_sort_order_", None) # compat with Python <=3.10, https://github.com/rwth-i6/sisyphus/issues/188 + if args is None: return state else: @@ -76,7 +81,7 @@ def sis_hash_helper(obj): """ # Store type to ensure it's unique - byte_list = [type(obj).__qualname__.encode()] + byte_list = [_obj_type_qualname(obj)] # Using type and not isinstance to avoid derived types if isinstance(obj, bytes): @@ -116,3 +121,9 @@ def sis_hash_helper(obj): return hashlib.sha256(byte_str).digest() else: return byte_str + + +def _obj_type_qualname(obj) -> bytes: + if type(obj) is enum.EnumMeta: # EnumMeta is old alias for EnumType + return b"EnumMeta" # compat with Python <=3.10, https://github.com/rwth-i6/sisyphus/issues/188 + return type(obj).__qualname__.encode() diff --git a/tests/hash_unittest.py b/tests/hash_unittest.py index 8b78279..a653e6c 100644 --- a/tests/hash_unittest.py +++ b/tests/hash_unittest.py @@ -7,6 +7,11 @@ def b(): pass +class MyEnum(enum.Enum): + Entry0 = 0 + Entry1 = 1 + + class HashTest(unittest.TestCase): def test_get_object_state(self): @@ -18,6 +23,15 @@ def d(): self.assertEqual(sis_hash_helper(b), b"(function, (tuple, (str, '" + __name__.encode() + b"'), (str, 'b')))") self.assertRaises(AssertionError, sis_hash_helper, c) + def test_enum(self): + self.assertEqual( + sis_hash_helper(MyEnum.Entry1), + b"(%s, (dict, (tuple, (str, '__objclass__')," % MyEnum.__name__.encode() + + b" (EnumMeta, (tuple, (str, '%s'), (str, '%s'))))," + % (MyEnum.__module__.encode(), MyEnum.__name__.encode()) + + b" (tuple, (str, '_name_'), (str, 'Entry1')), (tuple, (str, '_value_'), (int, 1))))", + ) + if __name__ == "__main__": unittest.main()