Skip to content

Commit

Permalink
Broaden from_dict applicability to non-Serializable dataclasses (#217)
Browse files Browse the repository at this point in the history
* [temp] Save local changes

Signed-off-by: Fabrice Normandin <fabrice.normandin@gmail.com>

* Fix issue with resolving of forward references

Signed-off-by: Fabrice Normandin <fabrice.normandin@gmail.com>

* Fix logging format string for py37

Signed-off-by: Fabrice Normandin <fabrice.normandin@gmail.com>

* Move some test files over to test/helpers

Signed-off-by: Fabrice Normandin <fabrice.normandin@gmail.com>

* Fuse test_serialization into test_from_dict

Signed-off-by: Fabrice Normandin <fabrice.normandin@gmail.com>

---------

Signed-off-by: Fabrice Normandin <fabrice.normandin@gmail.com>
  • Loading branch information
lebrice authored Feb 12, 2023
1 parent 404f7f3 commit 3122f6e
Show file tree
Hide file tree
Showing 6 changed files with 164 additions and 108 deletions.
172 changes: 69 additions & 103 deletions simple_parsing/helpers/serialization/decoding.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
""" Functions for decoding dataclass fields from "raw" values (e.g. from json).
"""
from __future__ import annotations

import inspect
import warnings
from collections import OrderedDict
Expand All @@ -9,14 +11,16 @@
from functools import lru_cache, partial
from logging import getLogger
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
from typing import Any, Callable, TypeVar

from simple_parsing.annotation_utils.get_field_annotations import (
evaluate_string_annotation,
)
from simple_parsing.utils import (
get_bound,
get_forward_arg,
get_type_arguments,
is_dataclass_type,
is_dict,
is_enum,
is_forward_ref,
Expand All @@ -35,7 +39,7 @@
V = TypeVar("V")

# Dictionary mapping from types/type annotations to their decoding functions.
_decoding_fns: Dict[Type[T], Callable[[Any], T]] = {
_decoding_fns: dict[type[T], Callable[[Any], T]] = {
# the 'primitive' types are decoded using the type fn as a constructor.
t: t
for t in [str, float, int, bytes]
Expand All @@ -51,7 +55,7 @@ def decode_bool(v: Any) -> bool:
_decoding_fns[bool] = decode_bool


def decode_field(field: Field, raw_value: Any, containing_dataclass: Optional[type] = None) -> Any:
def decode_field(field: Field, raw_value: Any, containing_dataclass: type | None = None) -> Any:
"""Converts a "raw" value (e.g. from json file) to the type of the `field`.
When serializing a dataclass to json, all objects are converted to dicts.
Expand Down Expand Up @@ -84,7 +88,7 @@ def decode_field(field: Field, raw_value: Any, containing_dataclass: Optional[ty


@lru_cache(maxsize=100)
def get_decoding_fn(t: Type[T]) -> Callable[[Any], T]:
def get_decoding_fn(type_annotation: type[T] | str) -> Callable[..., T]:
"""Fetches/Creates a decoding function for the given type annotation.
This decoding function can then be used to create an instance of the type
Expand All @@ -111,67 +115,54 @@ def get_decoding_fn(t: Type[T]) -> Callable[[Any], T]:
A function that decodes a 'raw' value to an instance of type `t`.
"""
# cache_info = get_decoding_fn.cache_info()
# logger.debug(f"called for type {t}! Cache info: {cache_info}")

def _get_potential_keys(annotation: str) -> List[str]:
# Type annotation is a string.
# This can happen when the `from __future__ import annotations` feature is used.
potential_keys: List[Type] = []
for key in _decoding_fns:
if inspect.isclass(key):
if key.__qualname__ == annotation:
# Qualname is more specific, there can't possibly be another match, so break.
potential_keys.append(key)
break
if key.__qualname__ == annotation:
# For just __name__, there could be more than one match.
potential_keys.append(key)
return potential_keys

if isinstance(t, str):
if t in _decoding_fns:
return _decoding_fns[t]

potential_keys = _get_potential_keys(t)

if not potential_keys:
# Try to replace the new-style annotation str with the old style syntax, and see if we
# find a match.
# try:
try:
evaluated_t = evaluate_string_annotation(t)
# NOTE: We now have a 'live'/runtime type annotation object from the typing module.
except (ValueError, TypeError) as err:
logger.error(f"Unable to evaluate the type annotation string {t}: {err}.")
else:
if evaluated_t in _decoding_fns:
return _decoding_fns[evaluated_t]
# If we still don't have this annotation stored in our dict of known functions, we
# recurse, to try to deconstruct this annotation into its parts, and construct the
# decoding function for the annotation. If this doesn't work, we just raise the
# errors.
return get_decoding_fn(evaluated_t)

raise ValueError(
f"Couldn't find a decoding function for the string annotation '{t}'.\n"
f"This is probably a bug. If it is, please make an issue on GitHub so we can get "
f"to work on fixing it.\n"
f"Types with a known decoding function: {list(_decoding_fns.keys())}"
from .serializable import from_dict

logger.debug(f"Getting the decoding function for {type_annotation!r}")

if isinstance(type_annotation, str):
# Check first if there are any matching registered decoding functions.
# TODO: Might be better to actually use the scope of the field, right?
matching_entries = {
key: decoding_fn
for key, decoding_fn in _decoding_fns.items()
if (inspect.isclass(key) and key.__name__ == type_annotation)
}
if len(matching_entries) == 1:
_, decoding_fn = matching_entries.popitem()
return decoding_fn
elif len(matching_entries) > 1:
# Multiple decoding functions match the type. Can't tell.
logger.warning(
RuntimeWarning(
f"More than one potential decoding functions were found for types that match "
f"the string annotation {type_annotation!r}. This will simply try each one "
f"and return the first one that works."
)
)
if len(potential_keys) == 1:
t = potential_keys[0]
return try_functions(*(decoding_fn for _, decoding_fn in matching_entries.items()))
else:
raise ValueError(
f"Multiple decoding functions registered for a type {t}: {potential_keys} \n"
f"This could be a bug, but try to use different names for each type, or add the "
f"modules they come from as a prefix, perhaps?"
)
# Try to evaluate the string annotation.
t = evaluate_string_annotation(type_annotation)

elif is_forward_ref(type_annotation):
forward_arg: str = get_forward_arg(type_annotation)
# Recurse until we've resolved the forward reference.
return get_decoding_fn(forward_arg)

else:
t = type_annotation

logger.debug(f"{type_annotation!r} -> {t!r}")

# T should now be a type or one of the objects from the typing module.

if t in _decoding_fns:
# The type has a dedicated decoding function.
return _decoding_fns[t]

if is_dataclass_type(t):
return partial(from_dict, t)

if t is Any:
logger.debug(f"Decoding an Any type: {t}")
return no_op
Expand Down Expand Up @@ -214,31 +205,6 @@ def _get_potential_keys(annotation: str) -> List[str]:
logger.debug(f"Decoding an Enum field: {t}")
return decode_enum(t)

from .serializable import SerializableMixin, get_dataclass_types_from_forward_ref

if is_forward_ref(t):
dcs = get_dataclass_types_from_forward_ref(t)
if len(dcs) == 1:
dc = dcs[0]
return dc.from_dict
if len(dcs) > 1:
logger.warning(
RuntimeWarning(
f"More than one potential Serializable dataclass was found with a name matching "
f"the type annotation {t}. This will simply try each one, and return the "
f"first one that works. Potential classes: {dcs}"
)
)
return try_functions(*[partial(dc.from_dict, drop_extra_fields=False) for dc in dcs])
else:
# No idea what the forward ref refers to!
logger.warning(
f"Unable to find a dataclass that matches the forward ref {t} inside the "
f"registered {SerializableMixin} subclasses. Leaving the value as-is."
f"(Consider using Serializable or FrozenSerializable as a base class?)."
)
return no_op

if is_typevar(t):
bound = get_bound(t)
logger.debug(f"Decoding a typevar: {t}, bound type is {bound}.")
Expand All @@ -256,31 +222,31 @@ def _get_potential_keys(annotation: str) -> List[str]:
return try_constructor(t)


def _register(t: Type, func: Callable) -> None:
def _register(t: type, func: Callable) -> None:
if t not in _decoding_fns:
# logger.debug(f"Registering the type {t} with decoding function {func}")
_decoding_fns[t] = func


def register_decoding_fn(some_type: Type[T], function: Callable[[Any], T]) -> None:
def register_decoding_fn(some_type: type[T], function: Callable[[Any], T]) -> None:
"""Register a decoding function for the type `some_type`."""
_register(some_type, function)


def decode_optional(t: Type[T]) -> Callable[[Optional[Any]], Optional[T]]:
def decode_optional(t: type[T]) -> Callable[[Any | None], T | None]:
decode = get_decoding_fn(t)

def _decode_optional(val: Optional[Any]) -> Optional[T]:
def _decode_optional(val: Any | None) -> T | None:
return val if val is None else decode(val)

return _decode_optional


def try_functions(*funcs: Callable[[Any], T]) -> Callable[[Any], Union[T, Any]]:
def try_functions(*funcs: Callable[[Any], T]) -> Callable[[Any], T | Any]:
"""Tries to use the functions in succession, else returns the same value unchanged."""

def _try_functions(val: Any) -> Union[T, Any]:
e: Optional[Exception] = None
def _try_functions(val: Any) -> T | Any:
e: Exception | None = None
for func in funcs:
try:
return func(val)
Expand All @@ -293,30 +259,30 @@ def _try_functions(val: Any) -> Union[T, Any]:
return _try_functions


def decode_union(*types: Type[T]) -> Callable[[Any], Union[T, Any]]:
def decode_union(*types: type[T]) -> Callable[[Any], T | Any]:
types = list(types)
optional = type(None) in types
# Partition the Union into None and non-None types.
while type(None) in types:
types.remove(type(None))

decoding_fns: List[Callable[[Any], T]] = [
decoding_fns: list[Callable[[Any], T]] = [
decode_optional(t) if optional else get_decoding_fn(t) for t in types
]
# Try using each of the non-None types, in succession. Worst case, return the value.
return try_functions(*decoding_fns)


def decode_list(t: Type[T]) -> Callable[[List[Any]], List[T]]:
def decode_list(t: type[T]) -> Callable[[list[Any]], list[T]]:
decode_item = get_decoding_fn(t)

def _decode_list(val: List[Any]) -> List[T]:
def _decode_list(val: list[Any]) -> list[T]:
return [decode_item(v) for v in val]

return _decode_list


def decode_tuple(*tuple_item_types: Type[T]) -> Callable[[List[T]], Tuple[T, ...]]:
def decode_tuple(*tuple_item_types: type[T]) -> Callable[[list[T]], tuple[T, ...]]:
"""Makes a parsing function for creating tuples.
Can handle tuples with different item types, for instance:
Expand All @@ -338,7 +304,7 @@ def decode_tuple(*tuple_item_types: Type[T]) -> Callable[[List[T]], Tuple[T, ...
# Note, if there are more values than types in the tuple type, then the
# last type is used.

def _decode_tuple(val: Tuple[Any, ...]) -> Tuple[T, ...]:
def _decode_tuple(val: tuple[Any, ...]) -> tuple[T, ...]:
if has_ellipsis:
return tuple(decoding_fn(v) for v in val)
else:
Expand All @@ -347,7 +313,7 @@ def _decode_tuple(val: Tuple[Any, ...]) -> Tuple[T, ...]:
return _decode_tuple


def decode_set(item_type: Type[T]) -> Callable[[List[T]], Set[T]]:
def decode_set(item_type: type[T]) -> Callable[[list[T]], set[T]]:
"""Makes a parsing function for creating sets with items of type `item_type`.
Args:
Expand All @@ -359,13 +325,13 @@ def decode_set(item_type: Type[T]) -> Callable[[List[T]], Set[T]]:
# Get the parse fn for a list of items of type `item_type`.
parse_list_fn = decode_list(item_type)

def _decode_set(val: List[Any]) -> Set[T]:
def _decode_set(val: list[Any]) -> set[T]:
return set(parse_list_fn(val))

return _decode_set


def decode_dict(K_: Type[K], V_: Type[V]) -> Callable[[List[Tuple[Any, Any]]], Dict[K, V]]:
def decode_dict(K_: type[K], V_: type[V]) -> Callable[[list[tuple[Any, Any]]], dict[K, V]]:
"""Creates a decoding function for a dict type. Works with OrderedDict too.
Args:
Expand All @@ -379,8 +345,8 @@ def decode_dict(K_: Type[K], V_: Type[V]) -> Callable[[List[Tuple[Any, Any]]], D
decode_k = get_decoding_fn(K_)
decode_v = get_decoding_fn(V_)

def _decode_dict(val: Union[Dict[Any, Any], List[Tuple[Any, Any]]]) -> Dict[K, V]:
result: Dict[K, V] = {}
def _decode_dict(val: dict[Any, Any] | list[tuple[Any, Any]]) -> dict[K, V]:
result: dict[K, V] = {}
if isinstance(val, list):
result = OrderedDict()
items = val
Expand All @@ -399,7 +365,7 @@ def _decode_dict(val: Union[Dict[Any, Any], List[Tuple[Any, Any]]]) -> Dict[K, V
return _decode_dict


def decode_enum(item_type: Type[Enum]) -> Callable[[str], Enum]:
def decode_enum(item_type: type[Enum]) -> Callable[[str], Enum]:
"""
Creates a decoding function for an enum type.
Expand Down Expand Up @@ -428,7 +394,7 @@ def no_op(v: T) -> T:
return v


def try_constructor(t: Type[T]) -> Callable[[Any], Union[T, Any]]:
def try_constructor(t: type[T]) -> Callable[[Any], T | Any]:
"""Tries to use the type as a constructor. If that fails, returns the value as-is.
Args:
Expand Down
3 changes: 2 additions & 1 deletion simple_parsing/helpers/serialization/serializable.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,9 +338,10 @@ class SimpleSerializable(SerializableMixin, decode_into_subclasses=True):
S = TypeVar("S", bound=SerializableMixin)


def get_dataclass_types_from_forward_ref(
def get_serializable_dataclass_types_from_forward_ref(
forward_ref: type, serializable_base_class: type[S] = SerializableMixin
) -> list[type[S]]:
"""Gets all the subclasses of `serializable_base_class` that have the same name as the argument of this forward reference annotation."""
arg = get_forward_arg(forward_ref)
potential_classes: list[type] = []
for serializable_class in serializable_base_class.subclasses:
Expand Down
7 changes: 4 additions & 3 deletions simple_parsing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
ClassVar,
Container,
Dict,
ForwardRef,
Iterable,
List,
Mapping,
Expand All @@ -50,12 +51,12 @@ def get_bound(t):
raise TypeError(f"type is not a `TypeVar`: {t}")


def is_forward_ref(t):
def is_forward_ref(t) -> TypeGuard[typing.ForwardRef]:
return isinstance(t, typing.ForwardRef)


def get_forward_arg(fr):
return getattr(fr, "__forward_arg__", None)
def get_forward_arg(fr: ForwardRef) -> str:
return getattr(fr, "__forward_arg__")


logger = getLogger(__name__)
Expand Down
Loading

0 comments on commit 3122f6e

Please sign in to comment.