Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow for union discriminators different from class names #87

Closed
wants to merge 1 commit into from
Closed

Allow for union discriminators different from class names #87

wants to merge 1 commit into from

Conversation

Philipp-Userlike
Copy link

I work with an application that uses lots of nested JSON structures which are modelled with dataclass unions on our end.
Union discriminators are often part of the message payload and cannot be arbitrarily defined. Often they are short lowercase words which cannot be made to class names, since that would violate python conventions.

For example a dataclass could look like this. It is not possible to change the 'type' Literals, since they are part of an API specification.

@dataclasses.dataclass
class Wood:
    species: str
    type: Literal["wood"] = "wood"


@dataclasses.dataclass
class Steel:
    alloy: str
    type: Literal["steel"] = "steel"


@dataclasses.dataclass
class Building:
    material: typing.Union[Wood, Steel]

From what I see, to make those dataclasses work with serializers, I would have to subclass the UnionField, override get_discriminator and explicitly define this on my serializer:
This is already pretty verbose and it would get worse when there is further nesting, since I always have to explicitly redefine the nested structure as a serializer.

It would be nice to have a way on the dataclass to define its discriminator. I made an attempt here to add this by having a special attribute on the dataclass, but I am not sure if this is the best solution.

@oxan
Copy link
Owner

oxan commented Sep 21, 2023

From what I see, to make those dataclasses work with serializers, I would have to subclass the UnionField, override get_discriminator and explicitly define this on my serializer:

Defining your custom field on every serializer is not the only solution here. You can also subclass DataclassSerializer and use that serializer instead (see the documentation):

class MyUnionField(UnionField):
    def get_discriminator(self, tp: type) -> str:
        return getattr(tp, "serializer_discriminant", tp.__name__)

class MyDataclassSerializer(DataclassSerializer):
    serializer_union_field = MyUnionField

    @property
    def serializer_dataclass_field(self):
        return MyDataclassSerializer

If you want this policy to be used for your whole project, you can also monkey-patch the serializer_union_field property of DataclassSerializer itself.

Having said that, I do agree that having a nicer way to specify the discriminant for a type would be good, but I'm also not sure if this is the best solution. I'll give it some thought.

@Philipp-Userlike
Copy link
Author

Thanks for the reply! I got this working now with patching, but in general I noticed that the approach of serializing type-information as discriminator only works well when you have full control over the protocol.
If you create dataclasses to model an existing protocol, you run into some issues.

I think for that an approach similar to dacite would work better for deserialization of payload into dataclasses:
https://github.com/konradhalas/dacite/blob/10a9ec40fc5874ae434aa68b975d1b1bf667a42f/dacite/core.py#L110C28-L110C28
They simply try to initialize every member of the union with the payload value, skipping values where exceptions happen.
There is a strict_unions_match mode that allows to make sure only one union member is successfully initialized.

@Philipp-Userlike
Copy link
Author

The requirement to have union members be a mapping was also a problem for me, when payloads had unions of different primitive values.
I saw the solution with nest_value, but this again requires control over the payload protocol definition.

@oxan
Copy link
Owner

oxan commented Oct 8, 2023

If you create dataclasses to model an existing protocol, you run into some issues.

That doesn't surprise me, as it's not a common usecase.

They simply try to initialize every member of the union with the payload value, skipping values where exceptions happen.

That sounds quite fragile to me. It doesn't work for values that can initialize more than one union member, and could also cause side-effects from the serializers and constructors of the non-active members. I very much prefer an explicit solution.
You can of course always implement your own, custom UnionField that initializes this way.

@codebutler
Copy link

codebutler commented Oct 30, 2023

Hi, what are the concerns with using Literal?

Here is how I monkey-patched this for reference:

from typing import get_args, get_origin, Literal

_orig_get_discriminator = UnionField.get_discriminator


def _get_discriminator(self, tp: type):
    if self.discriminator_field_name not in tp.__annotations__:
        return _orig_get_discriminator(self, tp)
    literal_hint = tp.__annotations__[self.discriminator_field_name]
    if get_origin(literal_hint) is not Literal:
        raise TypeError(
            f"{tp} has a {self.discriminator_field_name} attribute that is not a Literal"
        )
    return get_args(literal_hint)[0]


UnionField.get_discriminator = _get_discriminator

@oxan
Copy link
Owner

oxan commented Oct 31, 2023

Hi, what are the concerns with using Literal?

I don't like that it creates an extra field on the dataclass, which mixes actual data with metadata. It also allows you to do nonsensical things such as Steel(type="wood").

The three options I'm currently considering for custom discriminators are:

  • Placing the discriminator value in a decorator on the value class. Cons: implementation complexity, discriminator must be the same everywhere.
    @dataclass
    @discriminator('my_dog')
    class Dog:
        # ...
  • Placing the discriminator values in the field metadata. Cons: clumsy syntax, requires writing the discriminators everywhere the classes are used.
    obj: Dog | Cat = dataclasses.field(metadata={'serializer_kwargs': {'discriminator_values': {Dog: 'dog', Cat: 'cat'}}})
  • Using typing.Annotated. Cons: somewhat clumsy syntax.
    obj: Annotated[Dog, Discriminator('dog')] | Annotated[Cat, Discriminator('cat')]

I'm not convinced of any of them yet, though.

@userlike userlike closed this by deleting the head repository Jul 15, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants