Skip to content

Commit

Permalink
Refactor InternallyTaggedUnion resolution into UnionModel
Browse files Browse the repository at this point in the history
  • Loading branch information
object-Object committed May 22, 2024
1 parent 0acf8de commit 6d0d831
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 73 deletions.
2 changes: 2 additions & 0 deletions src/hexdoc/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"TemplateModel",
"TypeTaggedTemplate",
"TypeTaggedUnion",
"UnionModel",
"ValidationContextModel",
"init_context",
]
Expand All @@ -41,5 +42,6 @@
TemplateModel,
TypeTaggedTemplate,
TypeTaggedUnion,
UnionModel,
)
from .types import Color
189 changes: 116 additions & 73 deletions src/hexdoc/model/tagged_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,15 @@

from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Any, ClassVar, Generator, Self, Unpack
from typing import (
Any,
ClassVar,
Generator,
Iterable,
LiteralString,
Self,
Unpack,
)

import more_itertools
from pydantic import (
Expand Down Expand Up @@ -40,7 +48,96 @@
_is_loaded = False


class InternallyTaggedUnion(HexdocModel):
class UnionModel(HexdocModel):
@classmethod
def _resolve_union(
cls,
value: Any,
info: ValidationInfo,
*,
model_types: Iterable[type[Self]],
allow_ambiguous: bool,
error_name: LiteralString = "HexdocUnionMatchError",
error_text: Iterable[LiteralString] = [],
error_data: dict[str, Any] = {},
) -> Self:
# try all the types
exceptions: list[InitErrorDetails] = []
matches: dict[type[Self], Self] = {}

for model_type in model_types:
try:
result = matches[model_type] = model_type.model_validate(
value,
context=info.context,
)
if allow_ambiguous:
return result
except Exception as e:
exceptions.append(
InitErrorDetails(
type=PydanticCustomError(
error_name,
"{exception_class}: {exception}",
{
"exception_class": e.__class__.__name__,
"exception": str(e),
},
),
loc=(
cls.__name__,
model_type.__name__,
),
input=value,
)
)

# ensure we only matched one
# if allow_ambiguous is True, we should have returned a value already
match len(matches):
case 1:
return matches.popitem()[1]
case x if x > 1:
ambiguous_types = ", ".join(str(t) for t in matches.keys())
reason = f"Ambiguous union match: {ambiguous_types}"
case _:
reason = "No match found"

# something went wrong, raise an exception
error = PydanticCustomError(
f"{error_name}Group",
"\n ".join(
(
"Failed to match union {class_name}: {reason}",
"Types: {types}",
"Value: {value}",
*error_text,
)
),
{
"class_name": str(cls),
"reason": reason,
"types": ", ".join(str(t) for t in model_types),
"value": repr(value),
**error_data,
},
)

if exceptions:
exceptions.insert(
0,
InitErrorDetails(
type=error,
loc=(cls.__name__,),
input=value,
),
)
raise ValidationError.from_exception_data(error_name, exceptions)

raise RuntimeError(str(error))


class InternallyTaggedUnion(UnionModel):
"""Implements [internally tagged unions](https://serde.rs/enum-representations.html#internally-tagged)
using the [Registry pattern](https://charlesreid1.github.io/python-patterns-the-registry.html).
Expand Down Expand Up @@ -153,79 +250,25 @@ def _resolve_from_dict(
if tag_types is None:
raise TypeError(f"Unhandled tag: {tag_key}={tag_value} for {cls}: {data}")

# try all the types
exceptions: list[InitErrorDetails] = []
matches: dict[type[Self], Self] = {}

for inner_type in tag_types:
try:
matches[inner_type] = inner_type.model_validate(
data, context=info.context
)
except Exception as e:
exceptions.append(
InitErrorDetails(
type=PydanticCustomError(
"TaggedUnionMatchError",
"{exception_class}: {exception}",
{
"exception_class": e.__class__.__name__,
"exception": str(e),
},
),
loc=(
cls.__name__,
inner_type.__name__,
),
input=data,
)
)

# ensure we only matched one
match len(matches):
case 1:
return matches.popitem()[1]
case x if x > 1:
ambiguous_types = ", ".join(str(t) for t in matches.keys())
reason = f"Ambiguous union match: {ambiguous_types}"
case _:
reason = "No match found"

# something went wrong, raise an exception
error = PydanticCustomError(
"TaggedUnionMatchErrorGroup",
(
"Failed to match tagged union {class_name}: {reason}\n"
" Tag: {tag_key}={tag_value}\n"
" Types: {types}\n"
" Data: {data}"
),
{
"class_name": str(cls),
"reason": reason,
"tag_key": cls._tag_key,
"tag_value": tag_value,
"types": ", ".join(str(t) for t in tag_types),
"data": repr(data),
},
)

if exceptions:
try:
return cls._resolve_union(
data,
info,
model_types=tag_types,
allow_ambiguous=False,
error_name="TaggedUnionMatchError",
error_text=[
"Tag: {tag_key}={tag_value}",
],
error_data={
"tag_key": cls._tag_key,
"tag_value": tag_value,
},
)
except Exception:
if _RESOLVED in data:
data.pop(_RESOLVED) # avoid interfering with other types
exceptions.insert(
0,
InitErrorDetails(
type=error,
loc=(cls.__name__,),
input=data,
),
)
raise ValidationError.from_exception_data(
"TaggedUnionMatchError", exceptions
)

raise RuntimeError(str(error))
raise

@model_validator(mode="before")
def _pop_temporary_keys(cls, value: dict[Any, Any] | Any):
Expand Down

0 comments on commit 6d0d831

Please sign in to comment.