From 252077281f5680f0aa80970bc6c318a058ab3aaa Mon Sep 17 00:00:00 2001 From: Maciej Olko Date: Tue, 22 Oct 2024 20:42:30 +0200 Subject: [PATCH 1/8] replace fast loop exit in _build_value_for_union() with popping from union_matches at the end as a base for checking the matches when strict_unions_match is false --- dacite/core.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/dacite/core.py b/dacite/core.py index 9e45129..919a0c4 100644 --- a/dacite/core.py +++ b/dacite/core.py @@ -120,16 +120,14 @@ def _build_value_for_union(union: Type, data: Any, config: Config) -> Any: except Exception: # pylint: disable=broad-except continue if is_instance(value, inner_type): - if config.strict_unions_match: - union_matches[inner_type] = value - else: - return value + union_matches[inner_type] = value except DaciteError: pass - if config.strict_unions_match: - if len(union_matches) > 1: - raise StrictUnionMatchError(union_matches) - return union_matches.popitem()[1] + if len(union_matches) > 1 and config.strict_unions_match: + raise StrictUnionMatchError(union_matches) + if union_matches: + k = next(iter(union_matches)) + return union_matches.pop(k) if not config.check_types: return data raise UnionMatchError(field_type=union, value=data) From 01a6dc9447a6549cff87ad6fee05b35414cd8fe1 Mon Sep 17 00:00:00 2001 From: Maciej Olko Date: Tue, 22 Oct 2024 20:43:02 +0200 Subject: [PATCH 2/8] stub of test for the issue --- tests/core/test_union.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/core/test_union.py b/tests/core/test_union.py index 3f22c18..d0b235b 100644 --- a/tests/core/test_union.py +++ b/tests/core/test_union.py @@ -182,3 +182,21 @@ class Y: result = from_dict(Y, {"d": {"x": {"i": 42}, "z": {"i": 37}}}) assert result == Y(d={"x": X(i=42), "z": X(i=37)}) + + +def test(): + @dataclass + class X: + i: Optional[int] + + @dataclass + class Y: + j: int + + @dataclass + class Z: + d: Union[X, Y] + + result = from_dict(Z, 
{"d": {"j": 42}}) + + assert result == Z(d=Y(j=42)) From 846b7f6a908e3e6ee01a43c0c30a3ab1367cdf77 Mon Sep 17 00:00:00 2001 From: Maciej Olko Date: Tue, 22 Oct 2024 22:27:52 +0200 Subject: [PATCH 3/8] get the match by most matching fields --- dacite/core.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/dacite/core.py b/dacite/core.py index 919a0c4..779e219 100644 --- a/dacite/core.py +++ b/dacite/core.py @@ -1,4 +1,5 @@ from dataclasses import is_dataclass +from functools import partial from itertools import zip_longest from typing import TypeVar, Type, Optional, get_type_hints, Mapping, Any, Collection, MutableMapping @@ -126,13 +127,19 @@ def _build_value_for_union(union: Type, data: Any, config: Config) -> Any: if len(union_matches) > 1 and config.strict_unions_match: raise StrictUnionMatchError(union_matches) if union_matches: - k = next(iter(union_matches)) - return union_matches.pop(k) + return union_matches[sorted(union_matches.keys(), key=partial(_field_key_matches, data))[0]] if not config.check_types: return data raise UnionMatchError(field_type=union, value=data) +def _field_key_matches(data: Any, inner_type: Type) -> int: + if not is_dataclass(inner_type): + return 0 + data_class_fields = cache(get_fields)(inner_type) + return len(set(data.keys()) | {f.name for f in data_class_fields}) + + def _build_value_for_collection(collection: Type, data: Any, config: Config) -> Any: data_type = data.__class__ if isinstance(data, Mapping) and is_subclass(collection, Mapping): From f19a50715755cf8fe0bb678a260618d59d35897d Mon Sep 17 00:00:00 2001 From: Maciej Olko Date: Tue, 22 Oct 2024 22:30:54 +0200 Subject: [PATCH 4/8] rename the test --- tests/core/test_union.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/core/test_union.py b/tests/core/test_union.py index d0b235b..8d83551 100644 --- a/tests/core/test_union.py +++ b/tests/core/test_union.py @@ -184,7 +184,7 @@ class Y: assert result == Y(d={"x": 
X(i=42), "z": X(i=37)}) -def test(): +def test_from_dict_with_union_of_data_classes_selects_type_by_number_of_matching_fields(): @dataclass class X: i: Optional[int] From 39d01b7bcd2cf442c4c7a19091f009f9c83e95e7 Mon Sep 17 00:00:00 2001 From: Maciej Olko Date: Wed, 23 Oct 2024 13:15:00 +0200 Subject: [PATCH 5/8] fix the ordering function and selecting the match the sort orders ascending --- dacite/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dacite/core.py b/dacite/core.py index 779e219..6ec58e3 100644 --- a/dacite/core.py +++ b/dacite/core.py @@ -127,7 +127,7 @@ def _build_value_for_union(union: Type, data: Any, config: Config) -> Any: if len(union_matches) > 1 and config.strict_unions_match: raise StrictUnionMatchError(union_matches) if union_matches: - return union_matches[sorted(union_matches.keys(), key=partial(_field_key_matches, data))[0]] + return union_matches[sorted(union_matches.keys(), key=partial(_field_key_matches, data))[-1]] if not config.check_types: return data raise UnionMatchError(field_type=union, value=data) @@ -137,7 +137,7 @@ def _field_key_matches(data: Any, inner_type: Type) -> int: if not is_dataclass(inner_type): return 0 data_class_fields = cache(get_fields)(inner_type) - return len(set(data.keys()) | {f.name for f in data_class_fields}) + return len(set(data.keys()) & {f.name for f in data_class_fields}) def _build_value_for_collection(collection: Type, data: Any, config: Config) -> Any: From 7f1d246b14525fa0d3bec69874224ff42b142af0 Mon Sep 17 00:00:00 2001 From: Maciej Olko Date: Wed, 23 Oct 2024 13:39:08 +0200 Subject: [PATCH 6/8] extend the test to better detect an issue in data type selection covers regression I've introduced unintentionally --- tests/core/test_union.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/core/test_union.py b/tests/core/test_union.py index 8d83551..a670fb4 100644 --- a/tests/core/test_union.py +++ b/tests/core/test_union.py @@ 
-191,12 +191,17 @@ class X: @dataclass class Y: - j: int + j: Optional[int] @dataclass class Z: - d: Union[X, Y] + j: int + k: int + + @dataclass + class A: + d: Union[X, Y, Z] - result = from_dict(Z, {"d": {"j": 42}}) + result = from_dict(A, {"d": {"j": 42, "k": 42}}) - assert result == Z(d=Y(j=42)) + assert result == A(d=Z(j=42, k=42)) From d100836868b42e6da001bdce23d6040c72375cfe Mon Sep 17 00:00:00 2001 From: Maciej Olko Date: Thu, 24 Oct 2024 09:03:36 +0200 Subject: [PATCH 7/8] README docs update --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index da32ad7..d612501 100644 --- a/README.md +++ b/README.md @@ -349,7 +349,7 @@ exception. ### Strict unions match `Union` allows to define multiple possible types for a given field. By default -`dacite` is trying to find the first matching type for a provided data and it +`dacite` tries to find the best matching type for the provided data, by number of matching fields, and it returns instance of this type. It means that it's possible that there are other matching types further on the `Union` types list. With `strict_unions_match` only a single match is allowed, otherwise `dacite` raises `StrictUnionMatchError`. From dc331d27bea75738fef443b494c47ff854901756 Mon Sep 17 00:00:00 2001 From: Maciej Olko Date: Thu, 24 Oct 2024 10:26:11 +0200 Subject: [PATCH 8/8] add changelog entry --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index d9ea14e..34572a4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] - Fix issues with caching internal function calls +- Improve non-strict dataclasses union match ## [1.8.1] - 2023-05-12