diff --git a/CHANGELOG.md b/CHANGELOG.md index d9ea14e..34572a4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] - Fix issues with caching internal function calls +- Improve non-strict dataclasses union match ## [1.8.1] - 2023-05-12 diff --git a/README.md b/README.md index da32ad7..d612501 100644 --- a/README.md +++ b/README.md @@ -349,7 +349,7 @@ exception. ### Strict unions match `Union` allows to define multiple possible types for a given field. By default -`dacite` is trying to find the first matching type for a provided data and it +`dacite` is trying to find the best matching type for a provided data by number of matching fields and it returns instance of this type. It means that it's possible that there are other matching types further on the `Union` types list. With `strict_unions_match` only a single match is allowed, otherwise `dacite` raises `StrictUnionMatchError`. diff --git a/dacite/core.py b/dacite/core.py index 9e45129..6ec58e3 100644 --- a/dacite/core.py +++ b/dacite/core.py @@ -1,4 +1,5 @@ from dataclasses import is_dataclass +from functools import partial from itertools import zip_longest from typing import TypeVar, Type, Optional, get_type_hints, Mapping, Any, Collection, MutableMapping @@ -120,21 +121,25 @@ def _build_value_for_union(union: Type, data: Any, config: Config) -> Any: except Exception: # pylint: disable=broad-except continue if is_instance(value, inner_type): - if config.strict_unions_match: - union_matches[inner_type] = value - else: - return value + union_matches[inner_type] = value except DaciteError: pass - if config.strict_unions_match: - if len(union_matches) > 1: - raise StrictUnionMatchError(union_matches) - return union_matches.popitem()[1] + if len(union_matches) > 1 and config.strict_unions_match: + raise StrictUnionMatchError(union_matches) + if union_matches: + return union_matches[sorted(union_matches.keys(), key=partial(_field_key_matches, data))[-1]] if not config.check_types: return data raise UnionMatchError(field_type=union, value=data) +def _field_key_matches(data: Any, inner_type: Type) -> int: + if not is_dataclass(inner_type): + return 0 + data_class_fields = cache(get_fields)(inner_type) + return len(set(data.keys()) & {f.name for f in data_class_fields}) + + def _build_value_for_collection(collection: Type, data: Any, config: Config) -> Any: data_type = data.__class__ if isinstance(data, Mapping) and is_subclass(collection, Mapping): diff --git a/tests/core/test_union.py b/tests/core/test_union.py index 3f22c18..a670fb4 100644 --- a/tests/core/test_union.py +++ b/tests/core/test_union.py @@ -182,3 +182,26 @@ class Y: result = from_dict(Y, {"d": {"x": {"i": 42}, "z": {"i": 37}}}) assert result == Y(d={"x": X(i=42), "z": X(i=37)}) + + +def test_from_dict_with_union_of_data_classes_selects_type_by_number_of_matching_fields(): + @dataclass + class X: + i: Optional[int] + + @dataclass + class Y: + j: Optional[int] + + @dataclass + class Z: + j: int + k: int + + @dataclass + class A: + d: Union[X, Y, Z] + + result = from_dict(A, {"d": {"j": 42, "k": 42}}) + + assert result == A(d=Z(j=42, k=42))