Skip to content

Commit

Permalink
Add more unit testing covering ListForeignKey
Browse files Browse the repository at this point in the history
  • Loading branch information
tarsil committed Aug 9, 2023
1 parent e75c16c commit d20616e
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 7 deletions.
29 changes: 22 additions & 7 deletions edgy/core/db/fields/list_foreign_key.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import typing
from functools import cached_property
from inspect import isclass
from typing import TYPE_CHECKING, Any, TypeVar

from typing_extensions import get_origin

import edgy
from edgy.core.connection.registry import Registry
from edgy.core.db.constants import CASCADE, RESTRICT
from edgy.core.db.fields._base_fk import BaseField, BaseForeignKey
Expand Down Expand Up @@ -80,12 +85,6 @@ def target(self) -> Any:
"""
The target of the ForeignKey model.
"""
from edgy.core.db.models.model_reference import ModelRef

if not issubclass(self.to, ModelRef):
raise ModelReferenceError(
detail="A model reference must be an object of type ModelRef"
)
if not hasattr(self, "_target"):
if isinstance(self.to.__model__, str):
self._target = self.registry.models[self.to.__model__] # type: ignore
Expand All @@ -97,13 +96,29 @@ def target(self) -> Any:


class ListForeignKey(ForeignKeyFieldFactory, list):
@classmethod
def is_class_and_subclass(cls, value: typing.Any, _type: typing.Any) -> bool:
original = get_origin(value)
if not original and not isclass(value):
return False

try:
if original:
return original and issubclass(original, _type)
return issubclass(value, _type)
except TypeError:
return False

def __new__( # type: ignore
cls,
to: "ModelRef",
null: bool = False,
) -> BaseField:
if not cls.is_class_and_subclass(to, edgy.ModelRef):
raise ModelReferenceError(
detail="A model reference must be an object of type ModelRef"
)
kwargs = {
**{key: value for key, value in locals().items() if key not in CLASS_DEFAULTS},
}

return super().__new__(cls, **kwargs)
1 change: 1 addition & 0 deletions edgy/core/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def extract_model_references(
model_references = {
name: extracted_values.get(name, None)
for name in model_class.meta.model_references.keys() # type: ignore
if extracted_values.get(name)
}
return model_references

Expand Down
53 changes: 53 additions & 0 deletions tests/foreign_keys/test_list_foreignkey.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
import pytest
from pydantic import __version__
from tests.settings import DATABASE_URL

import edgy
from edgy import ModelRef
from edgy.exceptions import ModelReferenceError
from edgy.testclient import DatabaseTestClient as Database

pytestmark = pytest.mark.anyio

database = Database(DATABASE_URL)
models = edgy.Registry(database=database)

pydantic_version = __version__[:3]


class TrackModelRef(ModelRef):
__model__ = "Track"
Expand All @@ -36,6 +40,24 @@ class Meta:
registry = models


class PostRef(ModelRef):
__model__ = "Post"
comment: str


class Post(edgy.Model):
user = edgy.ForeignKey("User")
comment = edgy.CharField(max_length=255)


class User(edgy.Model):
name = edgy.CharField(max_length=100, null=True)
posts = edgy.ListForeignKey(PostRef)

class Meta:
registry = models


@pytest.fixture(autouse=True, scope="module")
async def create_test_database():
await models.create_all()
Expand Down Expand Up @@ -178,3 +200,34 @@ async def test_on_delete_cascade():
await album.delete()

assert await Track.query.count() == 0


@pytest.mark.parametrize(
"to",
[1, {"id": 2}, [3], [4, [4]], Track],
ids=["int", "dict", "list", "list-of_lists", "model"],
)
async def test_raises_model_reference_error(to):
with pytest.raises(ModelReferenceError):

class User(edgy.Model):
name = edgy.CharField(max_length=100)
users = edgy.ListForeignKey(to, null=True)

class Meta:
registry = models


async def test_raise_value_error_on_missing_model_fields():
with pytest.raises(ValueError) as raised:
await User.query.create()

assert raised.value.errors() == [
{
"type": "missing",
"loc": ("posts",),
"msg": "Field required",
"input": {},
"url": f"https://errors.pydantic.dev/{pydantic_version}/v/missing",
}
]

0 comments on commit d20616e

Please sign in to comment.