diff --git a/edgy/core/db/fields/list_foreign_key.py b/edgy/core/db/fields/list_foreign_key.py index 3fee5415..1bd27956 100644 --- a/edgy/core/db/fields/list_foreign_key.py +++ b/edgy/core/db/fields/list_foreign_key.py @@ -1,6 +1,11 @@ +import typing from functools import cached_property +from inspect import isclass from typing import TYPE_CHECKING, Any, TypeVar +from typing_extensions import get_origin + +import edgy from edgy.core.connection.registry import Registry from edgy.core.db.constants import CASCADE, RESTRICT from edgy.core.db.fields._base_fk import BaseField, BaseForeignKey @@ -80,12 +85,6 @@ def target(self) -> Any: """ The target of the ForeignKey model. """ - from edgy.core.db.models.model_reference import ModelRef - - if not issubclass(self.to, ModelRef): - raise ModelReferenceError( - detail="A model reference must be an object of type ModelRef" - ) if not hasattr(self, "_target"): if isinstance(self.to.__model__, str): self._target = self.registry.models[self.to.__model__] # type: ignore @@ -97,13 +96,29 @@ def target(self) -> Any: class ListForeignKey(ForeignKeyFieldFactory, list): + @classmethod + def is_class_and_subclass(cls, value: typing.Any, _type: typing.Any) -> bool: + original = get_origin(value) + if not original and not isclass(value): + return False + + try: + if original: + return original and issubclass(original, _type) + return issubclass(value, _type) + except TypeError: + return False + def __new__( # type: ignore cls, to: "ModelRef", null: bool = False, ) -> BaseField: + if not cls.is_class_and_subclass(to, edgy.ModelRef): + raise ModelReferenceError( + detail="A model reference must be an object of type ModelRef" + ) kwargs = { **{key: value for key, value in locals().items() if key not in CLASS_DEFAULTS}, } - return super().__new__(cls, **kwargs) diff --git a/edgy/core/utils/models.py b/edgy/core/utils/models.py index cdec7f0e..7da7f857 100644 --- a/edgy/core/utils/models.py +++ b/edgy/core/utils/models.py @@ -59,6 +59,7 @@ def extract_model_references( model_references = { name: extracted_values.get(name, None) for name in model_class.meta.model_references.keys() # type: ignore + if extracted_values.get(name) } return model_references diff --git a/tests/foreign_keys/test_list_foreignkey.py b/tests/foreign_keys/test_list_foreignkey.py index 527f7e76..a8d1865f 100644 --- a/tests/foreign_keys/test_list_foreignkey.py +++ b/tests/foreign_keys/test_list_foreignkey.py @@ -1,8 +1,10 @@ import pytest +from pydantic import __version__ from tests.settings import DATABASE_URL import edgy from edgy import ModelRef +from edgy.exceptions import ModelReferenceError from edgy.testclient import DatabaseTestClient as Database pytestmark = pytest.mark.anyio @@ -10,6 +12,8 @@ database = Database(DATABASE_URL) models = edgy.Registry(database=database) +pydantic_version = __version__[:3] + class TrackModelRef(ModelRef): __model__ = "Track" @@ -36,6 +40,24 @@ class Meta: registry = models +class PostRef(ModelRef): + __model__ = "Post" + comment: str + + +class Post(edgy.Model): + user = edgy.ForeignKey("User") + comment = edgy.CharField(max_length=255) + + +class User(edgy.Model): + name = edgy.CharField(max_length=100, null=True) + posts = edgy.ListForeignKey(PostRef) + + class Meta: + registry = models + + @pytest.fixture(autouse=True, scope="module") async def create_test_database(): await models.create_all() @@ -178,3 +200,34 @@ async def test_on_delete_cascade(): await album.delete() assert await Track.query.count() == 0 + + +@pytest.mark.parametrize( + "to", + [1, {"id": 2}, [3], [4, [4]], Track], + ids=["int", "dict", "list", "list-of_lists", "model"], +) +async def test_raises_model_reference_error(to): + with pytest.raises(ModelReferenceError): + + class User(edgy.Model): + name = edgy.CharField(max_length=100) + users = edgy.ListForeignKey(to, null=True) + + class Meta: + registry = models + + +async def test_raise_value_error_on_missing_model_fields(): + with pytest.raises(ValueError) as raised: + await User.query.create() + + assert raised.value.errors() == [ + { + "type": "missing", + "loc": ("posts",), + "msg": "Field required", + "input": {}, + "url": f"https://errors.pydantic.dev/{pydantic_version}/v/missing", + } + ]