diff --git a/simple_parsing/helpers/serialization/decoding.py b/simple_parsing/helpers/serialization/decoding.py index 25264913..f40ee84b 100644 --- a/simple_parsing/helpers/serialization/decoding.py +++ b/simple_parsing/helpers/serialization/decoding.py @@ -1,5 +1,7 @@ """ Functions for decoding dataclass fields from "raw" values (e.g. from json). """ +from __future__ import annotations + import inspect import warnings from collections import OrderedDict @@ -9,14 +11,16 @@ from functools import lru_cache, partial from logging import getLogger from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union +from typing import Any, Callable, TypeVar from simple_parsing.annotation_utils.get_field_annotations import ( evaluate_string_annotation, ) from simple_parsing.utils import ( get_bound, + get_forward_arg, get_type_arguments, + is_dataclass_type, is_dict, is_enum, is_forward_ref, @@ -35,7 +39,7 @@ V = TypeVar("V") # Dictionary mapping from types/type annotations to their decoding functions. -_decoding_fns: Dict[Type[T], Callable[[Any], T]] = { +_decoding_fns: dict[type[T], Callable[[Any], T]] = { # the 'primitive' types are decoded using the type fn as a constructor. t: t for t in [str, float, int, bytes] @@ -51,7 +55,7 @@ def decode_bool(v: Any) -> bool: _decoding_fns[bool] = decode_bool -def decode_field(field: Field, raw_value: Any, containing_dataclass: Optional[type] = None) -> Any: +def decode_field(field: Field, raw_value: Any, containing_dataclass: type | None = None) -> Any: """Converts a "raw" value (e.g. from json file) to the type of the `field`. When serializing a dataclass to json, all objects are converted to dicts. @@ -84,7 +88,7 @@ def decode_field(field: Field, raw_value: Any, containing_dataclass: Optional[ty @lru_cache(maxsize=100) -def get_decoding_fn(t: Type[T]) -> Callable[[Any], T]: +def get_decoding_fn(type_annotation: type[T] | str) -> Callable[..., T]: """Fetches/Creates a decoding function for the given type annotation. This decoding function can then be used to create an instance of the type @@ -111,67 +115,54 @@ def get_decoding_fn(t: Type[T]) -> Callable[[Any], T]: A function that decodes a 'raw' value to an instance of type `t`. """ - # cache_info = get_decoding_fn.cache_info() - # logger.debug(f"called for type {t}! Cache info: {cache_info}") - - def _get_potential_keys(annotation: str) -> List[str]: - # Type annotation is a string. - # This can happen when the `from __future__ import annotations` feature is used. - potential_keys: List[Type] = [] - for key in _decoding_fns: - if inspect.isclass(key): - if key.__qualname__ == annotation: - # Qualname is more specific, there can't possibly be another match, so break. - potential_keys.append(key) - break - if key.__qualname__ == annotation: - # For just __name__, there could be more than one match. - potential_keys.append(key) - return potential_keys - - if isinstance(t, str): - if t in _decoding_fns: - return _decoding_fns[t] - - potential_keys = _get_potential_keys(t) - - if not potential_keys: - # Try to replace the new-style annotation str with the old style syntax, and see if we - # find a match. - # try: - try: - evaluated_t = evaluate_string_annotation(t) - # NOTE: We now have a 'live'/runtime type annotation object from the typing module. - except (ValueError, TypeError) as err: - logger.error(f"Unable to evaluate the type annotation string {t}: {err}.") - else: - if evaluated_t in _decoding_fns: - return _decoding_fns[evaluated_t] - # If we still don't have this annotation stored in our dict of known functions, we - # recurse, to try to deconstruct this annotation into its parts, and construct the - # decoding function for the annotation. If this doesn't work, we just raise the - # errors. - return get_decoding_fn(evaluated_t) - - raise ValueError( - f"Couldn't find a decoding function for the string annotation '{t}'.\n" - f"This is probably a bug. If it is, please make an issue on GitHub so we can get " - f"to work on fixing it.\n" - f"Types with a known decoding function: {list(_decoding_fns.keys())}" + from .serializable import from_dict + + logger.debug(f"Getting the decoding function for {type_annotation!r}") + + if isinstance(type_annotation, str): + # Check first if there are any matching registered decoding functions. + # TODO: Might be better to actually use the scope of the field, right? + matching_entries = { + key: decoding_fn + for key, decoding_fn in _decoding_fns.items() + if (inspect.isclass(key) and key.__name__ == type_annotation) + } + if len(matching_entries) == 1: + _, decoding_fn = matching_entries.popitem() + return decoding_fn + elif len(matching_entries) > 1: + # Multiple decoding functions match the type. Can't tell. + logger.warning( + RuntimeWarning( + f"More than one potential decoding functions were found for types that match " + f"the string annotation {type_annotation!r}. This will simply try each one " + f"and return the first one that works." + ) ) - if len(potential_keys) == 1: - t = potential_keys[0] + return try_functions(*(decoding_fn for _, decoding_fn in matching_entries.items())) else: - raise ValueError( - f"Multiple decoding functions registered for a type {t}: {potential_keys} \n" - f"This could be a bug, but try to use different names for each type, or add the " - f"modules they come from as a prefix, perhaps?" - ) + # Try to evaluate the string annotation. + t = evaluate_string_annotation(type_annotation) + + elif is_forward_ref(type_annotation): + forward_arg: str = get_forward_arg(type_annotation) + # Recurse until we've resolved the forward reference. + return get_decoding_fn(forward_arg) + + else: + t = type_annotation + + logger.debug(f"{type_annotation!r} -> {t!r}") + + # T should now be a type or one of the objects from the typing module. if t in _decoding_fns: # The type has a dedicated decoding function. return _decoding_fns[t] + if is_dataclass_type(t): + return partial(from_dict, t) + if t is Any: logger.debug(f"Decoding an Any type: {t}") return no_op @@ -214,31 +205,6 @@ def _get_potential_keys(annotation: str) -> List[str]: logger.debug(f"Decoding an Enum field: {t}") return decode_enum(t) - from .serializable import SerializableMixin, get_dataclass_types_from_forward_ref - - if is_forward_ref(t): - dcs = get_dataclass_types_from_forward_ref(t) - if len(dcs) == 1: - dc = dcs[0] - return dc.from_dict - if len(dcs) > 1: - logger.warning( - RuntimeWarning( - f"More than one potential Serializable dataclass was found with a name matching " - f"the type annotation {t}. This will simply try each one, and return the " - f"first one that works. Potential classes: {dcs}" - ) - ) - return try_functions(*[partial(dc.from_dict, drop_extra_fields=False) for dc in dcs]) - else: - # No idea what the forward ref refers to! - logger.warning( - f"Unable to find a dataclass that matches the forward ref {t} inside the " - f"registered {SerializableMixin} subclasses. Leaving the value as-is." - f"(Consider using Serializable or FrozenSerializable as a base class?)." - ) - return no_op - if is_typevar(t): bound = get_bound(t) logger.debug(f"Decoding a typevar: {t}, bound type is {bound}.") @@ -256,31 +222,31 @@ def _get_potential_keys(annotation: str) -> List[str]: return try_constructor(t) -def _register(t: Type, func: Callable) -> None: +def _register(t: type, func: Callable) -> None: if t not in _decoding_fns: # logger.debug(f"Registering the type {t} with decoding function {func}") _decoding_fns[t] = func -def register_decoding_fn(some_type: Type[T], function: Callable[[Any], T]) -> None: +def register_decoding_fn(some_type: type[T], function: Callable[[Any], T]) -> None: """Register a decoding function for the type `some_type`.""" _register(some_type, function) -def decode_optional(t: Type[T]) -> Callable[[Optional[Any]], Optional[T]]: +def decode_optional(t: type[T]) -> Callable[[Any | None], T | None]: decode = get_decoding_fn(t) - def _decode_optional(val: Optional[Any]) -> Optional[T]: + def _decode_optional(val: Any | None) -> T | None: return val if val is None else decode(val) return _decode_optional -def try_functions(*funcs: Callable[[Any], T]) -> Callable[[Any], Union[T, Any]]: +def try_functions(*funcs: Callable[[Any], T]) -> Callable[[Any], T | Any]: """Tries to use the functions in succession, else returns the same value unchanged.""" - def _try_functions(val: Any) -> Union[T, Any]: - e: Optional[Exception] = None + def _try_functions(val: Any) -> T | Any: + e: Exception | None = None for func in funcs: try: return func(val) @@ -293,30 +259,30 @@ def _try_functions(val: Any) -> Union[T, Any]: return _try_functions -def decode_union(*types: Type[T]) -> Callable[[Any], Union[T, Any]]: +def decode_union(*types: type[T]) -> Callable[[Any], T | Any]: types = list(types) optional = type(None) in types # Partition the Union into None and non-None types. while type(None) in types: types.remove(type(None)) - decoding_fns: List[Callable[[Any], T]] = [ + decoding_fns: list[Callable[[Any], T]] = [ decode_optional(t) if optional else get_decoding_fn(t) for t in types ] # Try using each of the non-None types, in succession. Worst case, return the value. return try_functions(*decoding_fns) -def decode_list(t: Type[T]) -> Callable[[List[Any]], List[T]]: +def decode_list(t: type[T]) -> Callable[[list[Any]], list[T]]: decode_item = get_decoding_fn(t) - def _decode_list(val: List[Any]) -> List[T]: + def _decode_list(val: list[Any]) -> list[T]: return [decode_item(v) for v in val] return _decode_list -def decode_tuple(*tuple_item_types: Type[T]) -> Callable[[List[T]], Tuple[T, ...]]: +def decode_tuple(*tuple_item_types: type[T]) -> Callable[[list[T]], tuple[T, ...]]: """Makes a parsing function for creating tuples. Can handle tuples with different item types, for instance: @@ -338,7 +304,7 @@ def decode_tuple(*tuple_item_types: Type[T]) -> Callable[[List[T]], Tuple[T, ... # Note, if there are more values than types in the tuple type, then the # last type is used. - def _decode_tuple(val: Tuple[Any, ...]) -> Tuple[T, ...]: + def _decode_tuple(val: tuple[Any, ...]) -> tuple[T, ...]: if has_ellipsis: return tuple(decoding_fn(v) for v in val) else: @@ -347,7 +313,7 @@ def _decode_tuple(val: Tuple[Any, ...]) -> Tuple[T, ...]: return _decode_tuple -def decode_set(item_type: Type[T]) -> Callable[[List[T]], Set[T]]: +def decode_set(item_type: type[T]) -> Callable[[list[T]], set[T]]: """Makes a parsing function for creating sets with items of type `item_type`. Args: @@ -359,13 +325,13 @@ def decode_set(item_type: Type[T]) -> Callable[[List[T]], Set[T]]: # Get the parse fn for a list of items of type `item_type`. parse_list_fn = decode_list(item_type) - def _decode_set(val: List[Any]) -> Set[T]: + def _decode_set(val: list[Any]) -> set[T]: return set(parse_list_fn(val)) return _decode_set -def decode_dict(K_: Type[K], V_: Type[V]) -> Callable[[List[Tuple[Any, Any]]], Dict[K, V]]: +def decode_dict(K_: type[K], V_: type[V]) -> Callable[[list[tuple[Any, Any]]], dict[K, V]]: """Creates a decoding function for a dict type. Works with OrderedDict too. Args: @@ -379,8 +345,8 @@ def decode_dict(K_: Type[K], V_: Type[V]) -> Callable[[List[Tuple[Any, Any]]], D decode_k = get_decoding_fn(K_) decode_v = get_decoding_fn(V_) - def _decode_dict(val: Union[Dict[Any, Any], List[Tuple[Any, Any]]]) -> Dict[K, V]: - result: Dict[K, V] = {} + def _decode_dict(val: dict[Any, Any] | list[tuple[Any, Any]]) -> dict[K, V]: + result: dict[K, V] = {} if isinstance(val, list): result = OrderedDict() items = val @@ -399,7 +365,7 @@ def _decode_dict(val: Union[Dict[Any, Any], List[Tuple[Any, Any]]]) -> Dict[K, V return _decode_dict -def decode_enum(item_type: Type[Enum]) -> Callable[[str], Enum]: +def decode_enum(item_type: type[Enum]) -> Callable[[str], Enum]: """ Creates a decoding function for an enum type. @@ -428,7 +394,7 @@ def no_op(v: T) -> T: return v -def try_constructor(t: Type[T]) -> Callable[[Any], Union[T, Any]]: +def try_constructor(t: type[T]) -> Callable[[Any], T | Any]: """Tries to use the type as a constructor. If that fails, returns the value as-is. Args: diff --git a/simple_parsing/helpers/serialization/serializable.py b/simple_parsing/helpers/serialization/serializable.py index 4359918b..22bfb425 100644 --- a/simple_parsing/helpers/serialization/serializable.py +++ b/simple_parsing/helpers/serialization/serializable.py @@ -338,9 +338,10 @@ class SimpleSerializable(SerializableMixin, decode_into_subclasses=True): S = TypeVar("S", bound=SerializableMixin) -def get_dataclass_types_from_forward_ref( +def get_serializable_dataclass_types_from_forward_ref( forward_ref: type, serializable_base_class: type[S] = SerializableMixin ) -> list[type[S]]: + """Gets all the subclasses of `serializable_base_class` that have the same name as the argument of this forward reference annotation.""" arg = get_forward_arg(forward_ref) potential_classes: list[type] = [] for serializable_class in serializable_base_class.subclasses: diff --git a/simple_parsing/utils.py b/simple_parsing/utils.py index 755194ba..916b29c9 100644 --- a/simple_parsing/utils.py +++ b/simple_parsing/utils.py @@ -24,6 +24,7 @@ ClassVar, Container, Dict, + ForwardRef, Iterable, List, Mapping, @@ -50,12 +51,12 @@ def get_bound(t): raise TypeError(f"type is not a `TypeVar`: {t}") -def is_forward_ref(t): +def is_forward_ref(t) -> TypeGuard[typing.ForwardRef]: return isinstance(t, typing.ForwardRef) -def get_forward_arg(fr): - return getattr(fr, "__forward_arg__", None) +def get_forward_arg(fr: ForwardRef) -> str: + return getattr(fr, "__forward_arg__") logger = getLogger(__name__) diff --git a/test/helpers/test_from_dict.py b/test/helpers/test_from_dict.py new file mode 100644 index 00000000..022758f0 --- /dev/null +++ b/test/helpers/test_from_dict.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +import functools +from dataclasses import dataclass, field, replace + +import pytest + +from simple_parsing.helpers.serialization import from_dict, to_dict +from simple_parsing.utils import Dataclass + + +def test_replace_and_from_dict_already_call_post_init(): + n_post_init_calls = 0 + + @dataclass + class Bob: + a: int = 123 + + def __post_init__(self): + nonlocal n_post_init_calls + n_post_init_calls += 1 + + assert n_post_init_calls == 0 + bob = Bob() + assert n_post_init_calls == 1 + _ = replace(bob, a=456) + assert n_post_init_calls == 2 + + _ = from_dict(Bob, {"a": 456}) + assert n_post_init_calls == 3 + + +@dataclass +class InnerConfig: + arg1: int = 1 + arg2: str = "foo" + arg1_post_init: str = field(init=False) + + def __post_init__(self): + self.arg1_post_init = str(self.arg1) + + +@dataclass +class OuterConfig1: + out_arg: int = 0 + inner: InnerConfig = field(default_factory=InnerConfig) + + +@dataclass +class OuterConfig2: + out_arg: int = 0 + inner: InnerConfig = field(default_factory=functools.partial(InnerConfig, arg2="bar")) + + +@dataclass +class Level1: + arg: int = 1 + + +@dataclass +class Level2: + arg: int = 1 + prev: Level1 = field(default_factory=Level1) + + +@dataclass +class Level3: + arg: int = 1 + prev: Level2 = field(default_factory=Level2) + + +@pytest.mark.parametrize( + ("config"), + [ + OuterConfig1(), + OuterConfig2(), + Level1(arg=2), + Level2(arg=2, prev=Level1(arg=3)), + Level2(), + Level3(), + ], +) +def test_issue_210_nested_dataclasses_serialization(config: Dataclass): + _from_dict = functools.partial(from_dict, type(config)) + assert _from_dict(to_dict(config)) == config + assert _from_dict(to_dict(config), drop_extra_fields=True) == config + # More 'intense' comparisons, to make sure that the serialization is reversible: + assert to_dict(_from_dict(to_dict(config))) == to_dict(config) + assert _from_dict(to_dict(_from_dict(to_dict(config)))) == _from_dict(to_dict(config)) diff --git a/test/utils/test_serialization.py b/test/helpers/test_serializable.py similarity index 100% rename from test/utils/test_serialization.py rename to test/helpers/test_serializable.py diff --git a/test/test_issue_46.py b/test/test_issue_46.py index 834cfbeb..e5133876 100644 --- a/test/test_issue_46.py +++ b/test/test_issue_46.py @@ -22,7 +22,6 @@ def test_issue_46(assert_equals_stdout): parser.add_arguments(JBuildRelease, dest="jbuild", prefix="jbuild") s = StringIO() - parser.print_help parser.print_help(s) s.seek(0) output = str(s.read())