Skip to content

Commit

Permalink
New cleaned up tryout without annotated+optional hack
Browse files Browse the repository at this point in the history
  • Loading branch information
sveinugu committed Sep 2, 2024
1 parent 01995a8 commit b64ccfc
Show file tree
Hide file tree
Showing 10 changed files with 75 additions and 51 deletions.
16 changes: 9 additions & 7 deletions src/omnipy/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from omnipy.util.web import download_file_to_memory

ModelT = TypeVar('ModelT', bound=Model)
# ModelT = TypeVar('ModelT', bound=Model, default=Model[object])
GeneralModelT = TypeVar('GeneralModelT', bound=Model)
_DatasetT = TypeVar('_DatasetT')

Expand Down Expand Up @@ -242,14 +243,13 @@ def _get_data_field(cls) -> ModelField:
return cast(ModelField, cls.__fields__.get(DATA_KEY))

@classmethod
def get_model_class(cls) -> Type[Model]:
def get_model_class(cls) -> type[Model] | None:
"""
Returns the concrete Model class used for all data files in the dataset, e.g.:
`Model[list[int]]`
:return: The concrete Model class used for all data files in the dataset
"""
model_type = cls._get_data_field().type_
return model_type
return cls._get_data_field().type_

@staticmethod
def _raise_no_model_exception() -> None:
Expand Down Expand Up @@ -315,7 +315,9 @@ def update_forward_refs(cls, **localns: Any) -> None:
"""
Try to update ForwardRefs on fields based on this Model, globalns and localns.
"""
cls.get_model_class().update_forward_refs(**localns) # Update Model cls
model_cls = cls.get_model_class()
# if model_cls:
model_cls.update_forward_refs(**localns) # Update Model cls
super().update_forward_refs(**localns)
cls.__name__ = remove_forward_ref_notation(cls.__name__)
cls.__qualname__ = remove_forward_ref_notation(cls.__qualname__)
Expand Down Expand Up @@ -618,9 +620,9 @@ def _to_data_if_model(data_obj: Any):
return data_obj


_KwargValT = TypeVar('_KwargValT', bound=object)
_ParamModelT = TypeVar('_ParamModelT', bound=ParamModel)
_ListOfParamModelT = TypeVar('_ListOfParamModelT', bound=ListOfParamModel)
_KwargValT = TypeVar('_KwargValT', bound=object, default=object)
_ParamModelT = TypeVar('_ParamModelT', bound=ParamModel, default=ParamModel)
_ListOfParamModelT = TypeVar('_ListOfParamModelT', bound=ListOfParamModel, default=ListOfParamModel)

ParamModelSuperKwargsType: TypeAlias = \
dict[str, dict[str, _ParamModelT | DataWithParams[_ParamModelT, _KwargValT]]]
Expand Down
55 changes: 33 additions & 22 deletions src/omnipy/data/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from pydantic import NoneIsNotAllowedError
from pydantic import Protocol as PydanticProtocol
from pydantic import root_validator, ValidationError
from pydantic.fields import DeferredType, ModelField, Undefined, UndefinedType
from pydantic.fields import DeferredType, Field, ModelField, Undefined, UndefinedType
from pydantic.generics import GenericModel
from pydantic.main import BaseModel, ModelMetaclass, validate_model
from pydantic.typing import display_as_type, is_none_type
Expand Down Expand Up @@ -70,7 +70,9 @@
_IterT = TypeVar('_IterT')
_ReturnT = TypeVar('_ReturnT')
_IdxT = TypeVar('_IdxT', bound=SupportsIndex)
_RootT = TypeVar('_RootT', bound=object | None, default=object)
# _RootT = TypeVar('_RootT', bound=object | None, default=object)
# _RootT = TypeVar('_RootT', bound=object | None)
_RootT = TypeVar('_RootT')
_ModelT = TypeVar('_ModelT')

ROOT_KEY = '__root__'
Expand Down Expand Up @@ -219,7 +221,7 @@ class MyNumberList(Model[list[int]]):
See also docs of the Dataset class for more usage examples.
"""

__root__: _RootT | None
__root__: _RootT | None = Field(default_factory=lambda x: None)

class Config:
arbitrary_types_allowed = True
Expand Down Expand Up @@ -316,18 +318,21 @@ def _populate_root_field(cls, model: type[_RootT] | TypeVar) -> type[_RootT]:
config=cls.__config__)

if not all_typevars:
data_field.field_info.extra['orig_model'] = model
cls.__fields__[ROOT_KEY] = data_field
cls.__annotations__[ROOT_KEY] = prepared_model
# data_field.field_info.extra['orig_model'] = model
cls.__fields__[ROOT_KEY].field_info.extra['orig_model'] = model
# cls.__fields__[ROOT_KEY] = data_field
# cls.__annotations__[ROOT_KEY] = prepared_model

return model

@classmethod
def _depopulate_root_field(cls):
if ROOT_KEY in cls.__fields__:
del cls.__fields__[ROOT_KEY]
if ROOT_KEY in cls.__annotations__:
del cls.__annotations__[ROOT_KEY]
root_field = cls._get_root_field()
if 'orig_model' in root_field.field_info.extra:
del root_field.field_info.extra['orig_model']
# del cls.__config__.fields[ROOT_KEY]
# del cls.__fields__[ROOT_KEY]
# del cls.__annotations__[ROOT_KEY]

@classmethod
def _prepare_cls_members_to_mimic_model(cls, created_model: 'Model[type[_RootT]]') -> None:
Expand Down Expand Up @@ -379,11 +384,14 @@ def __class_getitem__( # type: ignore[override]
)
with model_prepare as prepared_model:
created_model = cast(Model, super().__class_getitem__(prepared_model))

cls._remove_annotated_optional_hack_from_model(created_model)
else:
cls._add_annotated_optional_hack_to_model(cls)

# if not isinstance(params, tuple) and not isinstance(params, str) and not isinstance(
# params, NoneType):
# params = params | None

created_model = cast(Model, super().__class_getitem__(params))

cls._remove_annotated_optional_hack_from_model(cls)
Expand Down Expand Up @@ -623,16 +631,17 @@ def _identify_all_forward_refs_in_model_field(field: ModelField,
# super().update_forward_refs(**localns)
cls._remove_annotated_optional_hack_from_model(cls, recursive=True)

root_field = cls._get_root_field()
if root_field:
assert root_field.allow_none

# if root_field.sub_fields and not (is_union(root_field.outer_type_) or get_origin(root_field.outer_type_) in [list, dict]):
if root_field.sub_fields and not (get_origin(root_field.outer_type_) in [list, dict]):
# if root_field.sub_fields:
for sub_field in root_field.sub_fields:
if sub_field.type_.__class__ is not ForwardRef:
...
# root_field = cls._get_root_field()
# if root_field:
# assert root_field.allow_none
#
# # if root_field.sub_fields and not (is_union(root_field.outer_type_) or get_origin(root_field.outer_type_) in [list, dict]):
# if root_field.sub_fields and not (get_origin(root_field.outer_type_) in [list, dict]):
# # if root_field.sub_fields:
# for sub_field in root_field.sub_fields:
# if sub_field.type_.__class__ is not ForwardRef:
# ...

cls.__name__ = remove_forward_ref_notation(cls.__name__)
cls.__qualname__ = remove_forward_ref_notation(cls.__qualname__)

Expand Down Expand Up @@ -874,6 +883,7 @@ def _parse_with_root_type_if_model(cls,
default_value = root_field.get_default()
none_default = default_value is None or (is_model_instance(default_value)
and default_value.contents is None)
# none_default = default_value is None
root_type_is_none = is_none_type(root_type)
root_type_is_optional = get_origin(root_type) is Union \
and any(is_none_type(arg) for arg in get_args(root_type))
Expand Down Expand Up @@ -957,6 +967,7 @@ def is_nested_type(cls) -> bool:
return not cls.inner_type(with_args=True) == cls.outer_type(with_args=True)

@classmethod
# Refactor: Remove is_param_model
def is_param_model(cls) -> bool:
if cls.outer_type() is list:
type_to_check = cls.inner_type(with_args=True)
Expand Down Expand Up @@ -1590,7 +1601,7 @@ def _validate_and_set_contents_with_params(self, contents: _ParamRootT, **kwargs
self._validate_and_set_value(DataWithParams(data=contents, params=kwargs))


_ParamModelT = TypeVar('_ParamModelT', bound='ParamModel')
_ParamModelT = TypeVar('_ParamModelT', bound='ParamModel', default='ParamModel')


class ListOfParamModel(ParamModel[list[_ParamModelT
Expand Down
6 changes: 4 additions & 2 deletions src/omnipy/modules/json/datasets.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Generic, TypeVar
from typing import Generic

from typing_extensions import TypeVar

from omnipy.data.dataset import Dataset
from omnipy.data.model import Model
Expand Down Expand Up @@ -29,7 +31,7 @@
# TODO: call omnipy modules something else than modules, to distinguish from Python modules.
# Perhaps plugins?
#
_JsonModelT = TypeVar('_JsonModelT', bound=Model)
_JsonModelT = TypeVar('_JsonModelT', bound=Model, default=JsonModel)


class _JsonBaseDataset(Dataset[_JsonModelT], Generic[_JsonModelT]):
Expand Down
1 change: 1 addition & 0 deletions src/omnipy/modules/prefect/engine/prefect.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def task_flow(*inner_args, **inner_kwargs):
# LinearFlowRunnerEngine
def _init_linear_flow(self, linear_flow: IsLinearFlow) -> Any:
assert isinstance(self._config, PrefectEngineConfig)
# flow_kwargs = dict(name=linear_flow.name, persist_result=True, result_storage='S3/minio-s3')
flow_kwargs = dict(name=linear_flow.name,)
call_func = self.default_linear_flow_run_decorator(linear_flow)

Expand Down
1 change: 1 addition & 0 deletions src/omnipy/util/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ def get_deepcopy_object_ids(self) -> SetDeque[int]:
return SetDeque(self._sub_obj_ids.keys())

def setup_deepcopy(self, obj):
print(f'setup_deepcopy({obj})')
assert self._cur_deepcopy_obj_id is None, \
f'self._cur_deepcopy_obj_id is not None, but {self._cur_deepcopy_obj_id}'
assert len(self._cur_keep_alive_list) == 0, \
Expand Down
32 changes: 16 additions & 16 deletions tests/data/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,19 +968,19 @@ class DictOfInt2NoneModel(Model[dict[int, NoneModel]]):
# that. Also example in test_mimic_nested_dict_operations_with_model_containers


@pytest.mark.skipif(
os.getenv('OMNIPY_FORCE_SKIPPED_TEST') != '1',
reason='Known issue, unknown why. Most probably related to pydantic v1 hack')
# @pytest.mark.skipif(
# os.getenv('OMNIPY_FORCE_SKIPPED_TEST') != '1',
# reason='Known issue, unknown why. Most probably related to pydantic v1 hack')
def test_model_union_none_known_issue() -> None:
with pytest.raises(ValidationError):
Model[int | float](None)


@pytest.mark.skipif(
os.getenv('OMNIPY_FORCE_SKIPPED_TEST') != '1',
reason='Current pydantic v1 hack requires nested types like list and dict to explicitly'
'include Optional in their arguments to support parsing of None when the level of '
'nesting is 2 or more')
# @pytest.mark.skipif(
# os.getenv('OMNIPY_FORCE_SKIPPED_TEST') != '1',
# reason='Current pydantic v1 hack requires nested types like list and dict to explicitly'
# 'include Optional in their arguments to support parsing of None when the level of '
# 'nesting is 2 or more')
def test_doubly_nested_list_and_dict_of_none_model_known_issue() -> None:
class NoneModel(Model[None]):
...
Expand Down Expand Up @@ -1137,18 +1137,18 @@ class ListModel(GenericListModel['FullModel']):
assert ListModel([None]).contents == [MaybeNumberModel(None)]


@pytest.mark.skipif(
os.getenv('OMNIPY_FORCE_SKIPPED_TEST') != '1',
reason="""
Known issue that popped up in omnipy.modules.json.models. Might be solved by pydantic v2.
Dropping JsonBaseModel (here: BaseModel) is one workaround as it (in contrast to _JsonBaseDataset)
does not seem to be needed.
""")
# @pytest.mark.skipif(
# os.getenv('OMNIPY_FORCE_SKIPPED_TEST') != '1',
# reason="""
# Known issue that popped up in omnipy.modules.json.models. Might be solved by pydantic v2.
# Dropping JsonBaseModel (here: BaseModel) is one workaround as it (in contrast to _JsonBaseDataset)
# does not seem to be needed.
# """)
def test_union_nested_model_classes_inner_forwardref_double_generic_none_as_default_known_issue(
) -> None:
MaybeNumber: TypeAlias = Optional[int]

BaseT = TypeVar('BaseT', default=list | 'FullModel' | MaybeNumber)
BaseT = TypeVar('BaseT', default='list | FullModel | MaybeNumber')

class BaseModel(Model[BaseT], Generic[BaseT]):
...
Expand Down
2 changes: 2 additions & 0 deletions tests/integration/novel/full/helpers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ def record_schema_factory(data_file: str,
class Config(BaseConfig):
extra = Extra.forbid

# Force config.dynamically_convert... is False

return create_model(
data_file,
__base__=RecordSchemaBase,
Expand Down
1 change: 1 addition & 0 deletions tests/integration/novel/full/test_multi_model_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def test_specialize_record_models_signature_and_return_type_func(
@pc.parametrize_with_cases('case', cases='.cases.flows', has_tag='specialize_record_models')
def test_run_specialize_record_models_consistent_types(
runtime_all_engines: Annotated[None, pytest.fixture], # noqa
skip_test_if_dynamically_convert_elements_to_models,
case: FlowCase):
specialize_record_models = case.flow_template.apply()

Expand Down
4 changes: 4 additions & 0 deletions tests/modules/frozen/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest_cases as pc

from omnipy.data.model import Model
from omnipy.modules.frozen.models import NestedFrozenDictsOrTuplesModel
from omnipy.modules.frozen.typedefs import FrozenDict

from ..helpers.classes import CaseInfo
Expand Down Expand Up @@ -36,6 +37,8 @@ class FrozenDictOfInt2NoneModel(Model[FrozenDict[int, NoneModel]]):

@pc.parametrize_with_cases('case', cases='.cases.frozen_data')
def test_nested_frozen_models(case: CaseInfo) -> None:
# NestedFrozenDictsOrTuplesModel[str, None | int](None)

for field in fields(case.data_points):
name = field.name
for model_cls in case.model_classes_for_data_point(name):
Expand All @@ -51,6 +54,7 @@ def test_nested_frozen_models(case: CaseInfo) -> None:
model_cls(data)
# print(f'Error: {e}')
else:
print(data)
model_obj = model_cls(data)

# print(f'repr(model_obj): {repr(model_obj)}')
Expand Down
8 changes: 4 additions & 4 deletions tests/util/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,17 +462,17 @@ def test_remove_annotated_optional_if_present() -> None:

assert remove_annotated_plus_opt(Annotated[str, 'something']) == str
assert remove_annotated_plus_opt(Annotated[str | list[int], 'something']) == str | list[int]
assert remove_annotated_plus_opt(Annotated[Union[str, list[int]], 'something']) == \
Union[str, list[int]]
assert (remove_annotated_plus_opt(Annotated[Union[str, list[int]],
'something']) == str | list[int])

assert remove_annotated_plus_opt(Annotated[None, 'something']) == NoneType
assert remove_annotated_plus_opt(Annotated[NoneType, 'something']) == NoneType
assert remove_annotated_plus_opt(Annotated[str | None, 'something']) == str
assert remove_annotated_plus_opt(Annotated[Union[str, None], 'something']) == str
assert remove_annotated_plus_opt(Annotated[Optional[str], 'something']) == str

assert remove_annotated_plus_opt(Annotated[str | list[int] | None, 'something']) == \
Union[str, list[int]]
assert remove_annotated_plus_opt(Annotated[str | list[int] | None,
'something']) == Union[str, list[int]]
assert remove_annotated_plus_opt(Annotated[Union[str, list[int], None], 'something']) == \
Union[str, list[int]]
assert remove_annotated_plus_opt(Annotated[Optional[Union[str, list[int]]], 'something']) == \
Expand Down

0 comments on commit b64ccfc

Please sign in to comment.