From 55354a46027e3fa2d432576f1021be044841e8f6 Mon Sep 17 00:00:00 2001 From: Alexander Tikhonov Date: Sat, 23 Nov 2024 19:21:18 +0300 Subject: [PATCH] Add support for recursive Union --- mashumaro/core/meta/types/common.py | 7 +++ mashumaro/core/meta/types/pack.py | 20 +++++++- mashumaro/core/meta/types/unpack.py | 16 ++++++- tests/conftest.py | 1 + tests/test_recursive_union.py | 72 +++++++++++++++++++++++++++++ 5 files changed, 114 insertions(+), 2 deletions(-) create mode 100644 tests/test_recursive_union.py diff --git a/mashumaro/core/meta/types/common.py b/mashumaro/core/meta/types/common.py index 67777462..8366a8ce 100644 --- a/mashumaro/core/meta/types/common.py +++ b/mashumaro/core/meta/types/common.py @@ -60,6 +60,8 @@ def __init__(self, expression: str): class FieldContext: name: str metadata: Mapping + packer: Optional[str] = None + unpacker: Optional[str] = None def copy(self, **changes: Any) -> "FieldContext": return replace(self, **changes) @@ -181,8 +183,13 @@ def _get_call_expr(self, spec: ValueSpec, method_name: str) -> str: def _before_build(self, spec: ValueSpec) -> None: pass + def _get_existing_method(self, spec: ValueSpec) -> Optional[str]: + return None + def build(self, spec: ValueSpec) -> str: self._before_build(spec) + if method := self._get_existing_method(spec): + return method lines = CodeLines() method_name = self._add_definition(spec, lines) with lines.indent(): diff --git a/mashumaro/core/meta/types/pack.py b/mashumaro/core/meta/types/pack.py index b3ff00b6..07aad539 100644 --- a/mashumaro/core/meta/types/pack.py +++ b/mashumaro/core/meta/types/pack.py @@ -289,11 +289,27 @@ def pack_any(spec: ValueSpec) -> Optional[Expression]: def pack_union( spec: ValueSpec, args: tuple[type, ...], prefix: str = "union" ) -> Expression: + if spec.type is spec.owner and spec.field_ctx.packer: + return spec.field_ctx.packer lines = CodeLines() + method_name = ( f"__pack_{prefix}_{spec.builder.cls.__name__}_{spec.field_ctx.name}__" f"{random_hex()}" ) + + if not spec.field_ctx.packer: + method_args = ", ".join( + filter(None, ("value", spec.builder.get_pack_method_flags())) + ) + if spec.builder.is_nailed: + union_packer = ( + f"{spec.self_attrs_name}.{method_name}({method_args})" + ) + else: + union_packer = f"{method_name}({method_args})" + spec.field_ctx.packer = union_packer + method_args = "self, value" if spec.builder.is_nailed else "value" default_kwargs = spec.builder.get_pack_method_default_flag_values() if default_kwargs: @@ -304,7 +320,7 @@ def pack_union( packer_arg_types: dict[str, list[type]] = {} for type_arg in args: packer = PackerRegistry.get( - spec.copy(type=type_arg, expression="value") + spec.copy(type=type_arg, expression="value", owner=spec.type) ) if packer not in packers: if packer == "value": @@ -363,7 +379,9 @@ def pack_union( if spec.builder.get_config().debug: print(f"{type_name(spec.builder.cls)}:") print(lines.as_text()) + exec(lines.as_text(), spec.builder.globals, spec.builder.__dict__) + method_args = ", ".join( filter(None, (spec.expression, spec.builder.get_pack_method_flags())) ) diff --git a/mashumaro/core/meta/types/unpack.py b/mashumaro/core/meta/types/unpack.py index 5120fe46..f783d6a2 100644 --- a/mashumaro/core/meta/types/unpack.py +++ b/mashumaro/core/meta/types/unpack.py @@ -162,11 +162,21 @@ def _get_call_expr(self, spec: ValueSpec, method_name: str) -> str: class UnionUnpackerBuilder(AbstractUnpackerBuilder): def __init__(self, args: tuple[type, ...]): self.union_args = args + self.method_name: str | None = None def get_method_prefix(self) -> str: return "union" + def _generate_method_name(self, spec: ValueSpec) -> str: + method_name = super()._generate_method_name(spec) + self.method_name = method_name + return method_name + def _add_body(self, spec: ValueSpec, lines: CodeLines) -> None: + if not spec.field_ctx.unpacker and self.method_name: + spec.field_ctx.unpacker = self._get_call_expr( + spec, self.method_name + ) orig_lines = lines lines = CodeLines() unpackers = set() @@ -175,7 +185,7 @@ def _add_body(self, spec: ValueSpec, lines: CodeLines) -> None: type_match_statements = 0 for type_arg in self.union_args: unpacker = UnpackerRegistry.get( - spec.copy(type=type_arg, expression="value") + spec.copy(type=type_arg, expression="value", owner=spec.type) ) type_arg_unpackers.append((type_arg, unpacker)) if isinstance(unpacker, TypeMatchEligibleExpression): @@ -230,6 +240,10 @@ def _add_body(self, spec: ValueSpec, lines: CodeLines) -> None: orig_lines.append("__value_type = type(value)") orig_lines.extend(lines) + def _get_existing_method(self, spec: ValueSpec) -> Optional[str]: + if spec.owner is spec.type: + return spec.field_ctx.unpacker + class TypeVarUnpackerBuilder(UnionUnpackerBuilder): def get_method_prefix(self) -> str: diff --git a/tests/conftest.py b/tests/conftest.py index 3dbe5f7b..97783728 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,7 @@ collect_ignore = [ "test_generics_pep_695.py", "test_pep_695.py", + "test_recursive_union.py", ] if PY_313_MIN: diff --git a/tests/test_recursive_union.py b/tests/test_recursive_union.py new file mode 100644 index 00000000..189bd4f1 --- /dev/null +++ b/tests/test_recursive_union.py @@ -0,0 +1,72 @@ +from dataclasses import dataclass + +from mashumaro import DataClassDictMixin +from mashumaro.codecs import BasicDecoder, BasicEncoder + +type JSON = str | int | float | bool | dict[str, JSON] | list[JSON] | None + + +@dataclass +class MyClass: + x: str + y: JSON + + +def test_encoder_with_recursive_union(): + encoder = BasicEncoder(JSON) + assert encoder.encode( + {"x": [{"x": {"x": [{"x": ["x", 1, 1.0, True, None]}]}}]} + ) == {"x": [{"x": {"x": [{"x": ["x", 1, 1.0, True, None]}]}}]} + + +def test_encoder_with_recursive_union_in_dataclass(): + encoder = BasicEncoder(MyClass) + assert encoder.encode( + MyClass( + x="x", y={"x": [{"x": {"x": [{"x": ["x", 1, 1.0, True, None]}]}}]} + ) + ) == { + "x": "x", + "y": {"x": [{"x": {"x": [{"x": ["x", 1, 1.0, True, None]}]}}]}, + } + + +def test_decoder_with_recursive_union(): + decoder = BasicDecoder(JSON) + assert decoder.decode( + {"x": [{"x": {"x": [{"x": ["x", 1, 1.0, True, None]}]}}]} + ) == {"x": [{"x": {"x": [{"x": ["x", 1, 1.0, True, None]}]}}]} + + +def test_decoder_with_recursive_union_in_dataclass(): + decoder = BasicDecoder(MyClass) + assert decoder.decode( + { + "x": "x", + "y": {"x": [{"x": {"x": [{"x": ["x", 1, 1.0, True, None]}]}}]}, + } + ) == MyClass( + x="x", y={"x": [{"x": {"x": [{"x": ["x", 1, 1.0, True, None]}]}}]} + ) + + +def test_dataclass_dict_mixin_with_recursive_union(): + @dataclass + class MyClassWithMixin(DataClassDictMixin): + x: str + y: JSON + + assert MyClassWithMixin( + x="x", y={"x": [{"x": {"x": [{"x": ["x", 1, 1.0, True, None]}]}}]} + ).to_dict() == { + "x": "x", + "y": {"x": [{"x": {"x": [{"x": ["x", 1, 1.0, True, None]}]}}]}, + } + assert MyClassWithMixin.from_dict( + { + "x": "x", + "y": {"x": [{"x": {"x": [{"x": ["x", 1, 1.0, True, None]}]}}]}, + } + ) == MyClassWithMixin( + x="x", y={"x": [{"x": {"x": [{"x": ["x", 1, 1.0, True, None]}]}}]} + )