Skip to content

Commit

Permalink
Add support for recursive Union
Browse files Browse the repository at this point in the history
  • Loading branch information
Fatal1ty committed Nov 23, 2024
1 parent 89f12f8 commit 55354a4
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 2 deletions.
7 changes: 7 additions & 0 deletions mashumaro/core/meta/types/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ def __init__(self, expression: str):
class FieldContext:
name: str
metadata: Mapping
packer: Optional[str] = None
unpacker: Optional[str] = None

def copy(self, **changes: Any) -> "FieldContext":
return replace(self, **changes)
Expand Down Expand Up @@ -181,8 +183,13 @@ def _get_call_expr(self, spec: ValueSpec, method_name: str) -> str:
def _before_build(self, spec: ValueSpec) -> None:
pass

def _get_existing_method(self, spec: ValueSpec) -> Optional[str]:
return None

def build(self, spec: ValueSpec) -> str:
self._before_build(spec)
if method := self._get_existing_method(spec):
return method
lines = CodeLines()
method_name = self._add_definition(spec, lines)
with lines.indent():
Expand Down
20 changes: 19 additions & 1 deletion mashumaro/core/meta/types/pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,11 +289,27 @@ def pack_any(spec: ValueSpec) -> Optional[Expression]:
def pack_union(
spec: ValueSpec, args: tuple[type, ...], prefix: str = "union"
) -> Expression:
if spec.type is spec.owner and spec.field_ctx.packer:
return spec.field_ctx.packer
lines = CodeLines()

method_name = (
f"__pack_{prefix}_{spec.builder.cls.__name__}_{spec.field_ctx.name}__"
f"{random_hex()}"
)

if not spec.field_ctx.packer:
method_args = ", ".join(
filter(None, ("value", spec.builder.get_pack_method_flags()))
)
if spec.builder.is_nailed:
union_packer = (
f"{spec.self_attrs_name}.{method_name}({method_args})"
)
else:
union_packer = f"{method_name}({method_args})"
spec.field_ctx.packer = union_packer

method_args = "self, value" if spec.builder.is_nailed else "value"
default_kwargs = spec.builder.get_pack_method_default_flag_values()
if default_kwargs:
Expand All @@ -304,7 +320,7 @@ def pack_union(
packer_arg_types: dict[str, list[type]] = {}
for type_arg in args:
packer = PackerRegistry.get(
spec.copy(type=type_arg, expression="value")
spec.copy(type=type_arg, expression="value", owner=spec.type)
)
if packer not in packers:
if packer == "value":
Expand Down Expand Up @@ -363,7 +379,9 @@ def pack_union(
if spec.builder.get_config().debug:
print(f"{type_name(spec.builder.cls)}:")
print(lines.as_text())

exec(lines.as_text(), spec.builder.globals, spec.builder.__dict__)

method_args = ", ".join(
filter(None, (spec.expression, spec.builder.get_pack_method_flags()))
)
Expand Down
16 changes: 15 additions & 1 deletion mashumaro/core/meta/types/unpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,21 @@ def _get_call_expr(self, spec: ValueSpec, method_name: str) -> str:
class UnionUnpackerBuilder(AbstractUnpackerBuilder):
def __init__(self, args: tuple[type, ...]):
self.union_args = args
self.method_name: str | None = None

def get_method_prefix(self) -> str:
return "union"

def _generate_method_name(self, spec: ValueSpec) -> str:
method_name = super()._generate_method_name(spec)
self.method_name = method_name
return method_name

def _add_body(self, spec: ValueSpec, lines: CodeLines) -> None:
if not spec.field_ctx.unpacker and self.method_name:
spec.field_ctx.unpacker = self._get_call_expr(
spec, self.method_name
)
orig_lines = lines
lines = CodeLines()
unpackers = set()
Expand All @@ -175,7 +185,7 @@ def _add_body(self, spec: ValueSpec, lines: CodeLines) -> None:
type_match_statements = 0
for type_arg in self.union_args:
unpacker = UnpackerRegistry.get(
spec.copy(type=type_arg, expression="value")
spec.copy(type=type_arg, expression="value", owner=spec.type)
)
type_arg_unpackers.append((type_arg, unpacker))
if isinstance(unpacker, TypeMatchEligibleExpression):
Expand Down Expand Up @@ -230,6 +240,10 @@ def _add_body(self, spec: ValueSpec, lines: CodeLines) -> None:
orig_lines.append("__value_type = type(value)")
orig_lines.extend(lines)

def _get_existing_method(self, spec: ValueSpec) -> Optional[str]:
if spec.owner is spec.type:
return spec.field_ctx.unpacker


class TypeVarUnpackerBuilder(UnionUnpackerBuilder):
def get_method_prefix(self) -> str:
Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
collect_ignore = [
"test_generics_pep_695.py",
"test_pep_695.py",
"test_recursive_union.py",
]

if PY_313_MIN:
Expand Down
72 changes: 72 additions & 0 deletions tests/test_recursive_union.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from dataclasses import dataclass

from mashumaro import DataClassDictMixin
from mashumaro.codecs import BasicDecoder, BasicEncoder

type JSON = str | int | float | bool | dict[str, JSON] | list[JSON] | None


@dataclass
class MyClass:
x: str
y: JSON


def test_encoder_with_recursive_union():
encoder = BasicEncoder(JSON)
assert encoder.encode(
{"x": [{"x": {"x": [{"x": ["x", 1, 1.0, True, None]}]}}]}
) == {"x": [{"x": {"x": [{"x": ["x", 1, 1.0, True, None]}]}}]}


def test_encoder_with_recursive_union_in_dataclass():
encoder = BasicEncoder(MyClass)
assert encoder.encode(
MyClass(
x="x", y={"x": [{"x": {"x": [{"x": ["x", 1, 1.0, True, None]}]}}]}
)
) == {
"x": "x",
"y": {"x": [{"x": {"x": [{"x": ["x", 1, 1.0, True, None]}]}}]},
}


def test_decoder_with_recursive_union():
decoder = BasicDecoder(JSON)
assert decoder.decode(
{"x": [{"x": {"x": [{"x": ["x", 1, 1.0, True, None]}]}}]}
) == {"x": [{"x": {"x": [{"x": ["x", 1, 1.0, True, None]}]}}]}


def test_decoder_with_recursive_union_in_dataclass():
decoder = BasicDecoder(MyClass)
assert decoder.decode(
{
"x": "x",
"y": {"x": [{"x": {"x": [{"x": ["x", 1, 1.0, True, None]}]}}]},
}
) == MyClass(
x="x", y={"x": [{"x": {"x": [{"x": ["x", 1, 1.0, True, None]}]}}]}
)


def test_dataclass_dict_mixin_with_recursive_union():
@dataclass
class MyClassWithMixin(DataClassDictMixin):
x: str
y: JSON

assert MyClassWithMixin(
x="x", y={"x": [{"x": {"x": [{"x": ["x", 1, 1.0, True, None]}]}}]}
).to_dict() == {
"x": "x",
"y": {"x": [{"x": {"x": [{"x": ["x", 1, 1.0, True, None]}]}}]},
}
assert MyClassWithMixin.from_dict(
{
"x": "x",
"y": {"x": [{"x": {"x": [{"x": ["x", 1, 1.0, True, None]}]}}]},
}
) == MyClassWithMixin(
x="x", y={"x": [{"x": {"x": [{"x": ["x", 1, 1.0, True, None]}]}}]}
)

0 comments on commit 55354a4

Please sign in to comment.