diff --git a/src/omnipy/api/protocols/public/data.py b/src/omnipy/api/protocols/public/data.py
index 5e038db9..740524c5 100644
--- a/src/omnipy/api/protocols/public/data.py
+++ b/src/omnipy/api/protocols/public/data.py
@@ -1,13 +1,234 @@
-from typing import Any, Callable, IO, Iterator, Protocol, Type, TypeVar
+from abc import abstractmethod
+from collections.abc import Sized
+from pathlib import Path
+from typing import (AbstractSet,
+                    Any,
+                    BinaryIO,
+                    Callable,
+                    Hashable,
+                    IO,
+                    Iterable,
+                    Iterator,
+                    overload,
+                    Protocol,
+                    runtime_checkable,
+                    Type,
+                    TypeVar)
 
 from pydantic.fields import Undefined, UndefinedType
 
 from omnipy.api.protocols.private.log import CanLog
 
-_ModelT = TypeVar('_ModelT')
+_RootT = TypeVar('_RootT', covariant=True)
+_ModelT = TypeVar('_ModelT', bound='IsModel')
+_ModelTContra = TypeVar('_ModelTContra', bound='IsModel', contravariant=True)
+_ModelTCov = TypeVar('_ModelTCov', bound='IsModel', covariant=True)
+KeyT = TypeVar('KeyT')
+KeyContraT = TypeVar('KeyContraT', bound=Hashable, contravariant=True)
+ValT = TypeVar('ValT')
+ValCoT = TypeVar('ValCoT', covariant=True)
+RootT = TypeVar('RootT')
 
-class IsDataset(Protocol[_ModelT]):
+
+class SupportsKeysAndGetItem(Protocol[KeyT, ValCoT]):
+    def keys(self) -> Iterable[KeyT]:
+        ...
+
+    def __getitem__(self, __key: KeyT) -> ValCoT:
+        ...
+
+
+class IsSet(Protocol):
+    """
+    IsSet is a protocol with the same interface as the abstract class Set.
+    It covers the comparison and set-operator methods of a finite, iterable container.
+    """
+    def __le__(self, other: AbstractSet) -> bool:
+        ...
+
+    def __lt__(self, other: AbstractSet) -> bool:
+        ...
+
+    def __gt__(self, other: AbstractSet) -> bool:
+        ...
+
+    def __ge__(self, other: AbstractSet) -> bool:
+        ...
+
+    def __eq__(self, other: object) -> bool:
+        ...
+
+    def __and__(self, other: Iterable) -> 'IsSet':
+        ...
+
+    def __rand__(self, other: Iterable) -> 'IsSet':
+        ...
+
+    def isdisjoint(self, other: AbstractSet) -> bool:
+        ...
+
+    def __or__(self, other: Iterable) -> 'IsSet':
+        ...
+
+    def __ror__(self, other: Iterable) -> 'IsSet':
+        ...
+
+    def __sub__(self, other: Iterable) -> 'IsSet':
+        ...
+
+    def __rsub__(self, other: Iterable) -> 'IsSet':
+        ...
+
+    def __xor__(self, other: Iterable) -> 'IsSet':
+        ...
+
+    def __rxor__(self, other: Iterable) -> 'IsSet':
+        ...
+
+
+class IsMapping(Protocol[KeyT, ValT]):
+    """
+    IsMapping is a protocol with the same interface as the abstract class Mapping.
+    It is the protocol of a generic container for associating key/value pairs.
+    """
+    @abstractmethod
+    def __getitem__(self, key: KeyT) -> ValT:
+        raise KeyError
+
+    def get(self, key: KeyT, /) -> ValT | None:
+        """
+        D.get(k) -> D[k] if k in D, else None.
+        """
+        ...
+
+    def __contains__(self, key: KeyT) -> bool:
+        ...
+
+    def keys(self) -> 'IsKeysView[KeyT]':
+        """
+        D.keys() -> a set-like object providing a view on D's keys
+        """
+        ...
+
+    def items(self) -> 'IsItemsView[KeyT, ValT]':
+        """
+        D.items() -> a set-like object providing a view on D's items
+        """
+        ...
+
+    def values(self) -> 'IsValuesView[ValT]':
+        """
+        D.values() -> an object providing a view on D's values
+        """
+        ...
+
+    def __eq__(self, other: object) -> bool:
+        ...
+
+
+class IsMappingView(Sized, Protocol):
+    def __len__(self) -> int:
+        ...
+
+    def __repr__(self) -> str:
+        ...
+
+
+class IsKeysView(IsMappingView, Protocol[KeyT]):
+    def __contains__(self, key: KeyT) -> bool:
+        ...
+
+    def __iter__(self) -> Iterator[KeyT]:
+        ...
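+
+
+# Illustrative sketch (editor's addition, not part of the original patch): built-in dict
+# views are expected to satisfy these view protocols structurally, e.g. for static type
+# checking:
+def _demo_keys_view(d: dict[str, int]) -> None:
+    keys: IsKeysView[str] = d.keys()  # dict_keys matches IsKeysView[str] structurally
+    for key in keys:
+        assert isinstance(key, str)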
+
+
+class IsItemsView(IsMappingView, Protocol[KeyT, ValT]):
+    def __contains__(self, item: tuple[KeyT, ValT]) -> bool:
+        ...
+
+    def __iter__(self) -> Iterator[tuple[KeyT, ValT]]:
+        ...
+
+
+class IsValuesView(IsMappingView, Protocol[ValT]):
+    def __contains__(self, value: ValT) -> bool:
+        ...
+
+    def __iter__(self) -> Iterator[ValT]:
+        ...
+
+
+class IsMutableMapping(IsMapping[KeyT, ValT], Protocol[KeyT, ValT]):
+    """
+    IsMutableMapping is a protocol with the same interface as the abstract class MutableMapping.
+    It is the protocol of a generic mutable container for associating key/value pairs.
+    """
+    @abstractmethod
+    def __setitem__(self, key: KeyT, value: ValT) -> None:
+        raise KeyError
+
+    @abstractmethod
+    def __delitem__(self, key: KeyT) -> None:
+        raise KeyError
+
+    def pop(self, key: KeyT) -> ValT:
+        """
+        D.pop(k) -> v, remove specified key and return the corresponding value.
+        If key is not found, KeyError is raised.
+        """
+        ...
+
+    def popitem(self) -> tuple[KeyT, ValT]:
+        """
+        D.popitem() -> (k, v), remove and return some (key, value) pair
+        as a 2-tuple; but raise KeyError if D is empty.
+        """
+        ...
+
+    def clear(self) -> None:
+        """
+        D.clear() -> None. Remove all items from D.
+        """
+        ...
+
+    @overload
+    def update(self, other: SupportsKeysAndGetItem[KeyT, ValT], /, **kwargs: ValT) -> None:
+        ...
+
+    @overload
+    def update(self, other: Iterable[tuple[KeyT, ValT]], /, **kwargs: ValT) -> None:
+        ...
+
+    @overload
+    def update(self, /, **kwargs: ValT) -> None:
+        ...
+
+    def update(self, other: Any = None, /, **kwargs: ValT) -> None:
+        """
+        D.update([E, ]**F) -> None. Update D from mapping/iterable E and F.
+        If E present and has a .keys() method, does: for k in E: D[k] = E[k]
+        If E present and lacks .keys() method, does: for (k, v) in E: D[k] = v
+        In either case, this is followed by: for k, v in F.items(): D[k] = v
+        """
+        ...
+
+    def setdefault(self, key: KeyT, default: ValT | None = None) -> ValT | None:
+        """
+        D.setdefault(k[,d]) -> D.get(k,d), also set D[k]=d if k not in D
+        """
+        ...
+
+
+@runtime_checkable
+class IsModel(Protocol[_RootT]):
+    # @property
+    # def contents(self) -> _RootT:
+    #     ...
+    ...
+
+
+class IsDataset(IsMutableMapping, Protocol[_ModelTCov]):
     """
     Dict-based container of data files that follow a specific Model
     """
@@ -21,7 +242,7 @@ def __init__(
         ...
 
     @classmethod
-    def get_model_class(cls) -> Type[_ModelT]:
+    def get_model_class(cls) -> type[_ModelTCov]:
         """
         Returns the concrete Model class used for all data files in the dataset, e.g.:
         `Model[list[int]]`
@@ -49,22 +270,64 @@ def from_json(self,
     def to_json_schema(cls, pretty=True) -> str | dict[str, str]:
         ...
 
-    def as_multi_model_dataset(self) -> 'MultiModelDataset[_ModelT]':
-        ...
+    # def as_multi_model_dataset(self) -> 'IsMultiModelDataset[_ModelT]':
+    #     ...
 
 
-class MultiModelDataset(Protocol[_ModelT]):
+class IsMultiModelDataset(IsDataset[_ModelTCov], Protocol[_ModelTCov]):
     """
     Variant of Dataset that allows custom models to be set on individual data files
     """
-    def set_model(self, data_file: str, model: _ModelT) -> None:
+    def set_model(self, data_file: str, model: type[IsModel]) -> None:
+        ...
+
+    # def get_model(self, data_file: str) -> type[_ModelT]:
+    #     ...
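+
+
+# Illustrative sketch (editor's addition): the update() overloads on IsMutableMapping mirror
+# dict.update, accepting a mapping, an iterable of pairs, or keyword arguments:
+def _demo_mutable_mapping_update(m: IsMutableMapping[str, int]) -> None:
+    m.update({'a': 1})    # anything with .keys() and __getitem__
+    m.update([('b', 2)])  # iterable of (key, value) pairs
+    m.update(c=3)         # keyword arguments only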
+
+
+def _static_protocol_conformance_check() -> None:
+    # Static typing smoke test: these assignments fail type checking if the concrete
+    # classes stop satisfying the public data protocols
+    from omnipy import Dataset, Model
+    from omnipy.data.dataset import MultiModelDataset
+
+    t: IsModel[int] = Model[int]
+
+    d: IsDataset[IsModel[int]] = Dataset[Model[int]]()
+    e: IsMultiModelDataset[IsModel[int]] = MultiModelDataset[Model[int]]()
+
+
+class CanSerialize(Protocol[RootT]):
+    @classmethod
+    def serialize_to_bytes(cls, dataset: IsDataset[IsModel[RootT]]) -> BinaryIO:
+        ...
+
+    @classmethod
+    def deserialize_from_bytes(cls, data: BinaryIO) -> IsDataset[IsModel[RootT]]:
+        ...
+
+    @classmethod
+    def serialize_to_directory(cls, dataset: IsDataset[IsModel[RootT]],
+                               dir_path: Path | str) -> None:
+        ...
+
+    @classmethod
+    def deserialize_from_directory(cls, dir_path: Path | str) -> IsDataset[IsModel[RootT]]:
+        ...
+
+
+class IsDataEncoder(Protocol[RootT]):
+    @classmethod
+    def encode_data(cls, dataset_key: str, data: RootT) -> bytes:
         ...
 
-    def get_model(self, data_file: str) -> _ModelT:
+    @classmethod
+    def decode_data(cls, dataset_key: str, encoded_data: bytes) -> RootT:
         ...
 
 
-class IsSerializer(Protocol):
-    """"""
+class SupportsGeneralSerializerQueries(Protocol):
+    """
+    General capability queries shared by all serializers.
+    """
     @classmethod
     def is_dataset_directly_supported(cls, dataset: IsDataset) -> bool:
@@ -72,19 +335,18 @@ def is_dataset_directly_supported(cls, dataset: IsDataset) -> bool:
 
     @classmethod
     def get_dataset_cls_for_new(cls) -> Type[IsDataset]:
-        pass
+        ...
 
     @classmethod
     def get_output_file_suffix(cls) -> str:
         pass
 
-    @classmethod
-    def serialize(cls, dataset: IsDataset) -> bytes | memoryview:
-        pass
-
-    @classmethod
-    def deserialize(cls, serialized: bytes, any_file_suffix=False) -> IsDataset:
-        pass
+
+class IsSerializer(IsDataEncoder[RootT],
+                   CanSerialize[RootT],
+                   SupportsGeneralSerializerQueries,
+                   Protocol[RootT]):
+    ...
 
 
 class IsTarFileSerializer(IsSerializer):
diff --git a/src/omnipy/api/typedefs.py b/src/omnipy/api/typedefs.py
index f340f33e..5295e9dc 100644
--- a/src/omnipy/api/typedefs.py
+++ b/src/omnipy/api/typedefs.py
@@ -1,4 +1,12 @@
-from typing import Callable, TypeAlias, TypeVar
+from types import UnionType
+from typing import (_AnnotatedAlias,
+                    _GenericAlias,
+                    _LiteralGenericAlias,
+                    _SpecialForm,
+                    _UnionGenericAlias,
+                    Callable,
+                    TypeAlias,
+                    TypeVar)
 
 GeneralDecorator = Callable[[Callable], Callable]
 LocaleType: TypeAlias = str | tuple[str | None, str | None]
@@ -10,3 +18,7 @@
 TaskTemplateT = TypeVar('TaskTemplateT')
 TaskTemplateContraT = TypeVar('TaskTemplateContraT', contravariant=True)
 TaskTemplateCovT = TypeVar('TaskTemplateCovT', covariant=True)
+
+# TODO: While waiting for https://github.com/python/mypy/issues/9773
+TypeForm: TypeAlias = (type | UnionType | _UnionGenericAlias | _AnnotatedAlias
+                       | _GenericAlias | _LiteralGenericAlias | _SpecialForm)
\ No newline at end of file
diff --git a/src/omnipy/config/data.py b/src/omnipy/config/data.py
index 8ba61bdf..8a2ab529 100644
--- a/src/omnipy/config/data.py
+++ b/src/omnipy/config/data.py
@@ -6,6 +6,6 @@
 
 @dataclass
 class DataConfig:
-    interactive_mode: bool = True
+    interactive_mode: bool = False
     terminal_size_columns: int = _terminal_size.columns
     terminal_size_lines: int = _terminal_size.lines
diff --git a/src/omnipy/data/dataset.py b/src/omnipy/data/dataset.py
index 2954ef5c..1b4deca4 100644
--- a/src/omnipy/data/dataset.py
+++ b/src/omnipy/data/dataset.py
@@ -13,7 +13,6 @@
     get_origin,
     Iterator,
     Optional,
-    Type,
     TypeAlias,
     TypeVar)
 from urllib.parse import ParseResult, urlparse
@@ -26,6 +25,7 @@
 from pydantic.generics import GenericModel
 from pydantic.utils import lenient_isinstance, lenient_issubclass
 
+from omnipy.api.protocols.public.data import IsModel
 from omnipy.data.model import (_cleanup_name_qualname_and_module,
                                _is_interactive_mode,
                                _waiting_for_terminal_repr,
@@ -110,7 +110,7 @@ class Config:
 
     data: dict[str, ModelT] = Field(default={})
 
-    def __class_getitem__(cls, model: ModelT) -> ModelT:
+    def __class_getitem__(cls, model: ModelT) -> ModelT:  # type: ignore[override]
         # TODO: change model type to params: Type[Any] | tuple[Type[Any], ...]
         #  as in GenericModel.
 
@@ -140,7 +140,7 @@ def __class_getitem__(cls, model: ModelT) -> ModelT:
         if cls == Dataset and not is_optional(model):  # TODO: Handle MultiModelDataset??
             model = Annotated[Optional[model], 'Fake Optional from Dataset']
-        created_dataset = super().__class_getitem__(model)
+        created_dataset = super().__class_getitem__(model)  # type: ignore[override]
 
         _cleanup_name_qualname_and_module(cls, created_dataset, model, orig_model)
 
@@ -225,7 +225,7 @@ def _get_data_field(cls) -> ModelField:
         return cast(ModelField, cls.__fields__.get(DATA_KEY))
 
     @classmethod
-    def get_model_class(cls) -> Type[Model]:
+    def get_model_class(cls) -> type[ModelT]:
         """
         Returns the concrete Model class used for all data files in the dataset, e.g.:
         `Model[list[int]]`
@@ -272,6 +272,28 @@ def _get_standard_field_description(cls) -> str:
             'particular specialization of the Model class. Both main classes are wrapping '
             'the excellent Python package named `pydantic`.')
 
+    # def copy(
+    #         self: 'Model',
+    #         *,
+    #         include: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None,
+    #         exclude: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None,
+    #         update: Optional['DictStrAny'] = None,
+    #         deep: bool = False,
+    # ) -> 'Model':
+    #     pass
+
+    def copy(self):
+        return self.__copy__()
+
+    def deepcopy(self):
+        return self.__deepcopy__()
+
+    def __copy__(self):
+        return GenericModel.copy(self, update={'data': self.data.copy()}, deep=False)
+
+    def __deepcopy__(self, _memo=None):
+        return GenericModel.copy(self, deep=True)
+
     def __setitem__(self, data_file: str, data_obj: Any) -> None:
         has_prev_value = data_file in self.data
         prev_value = self.data.get(data_file)
@@ -573,9 +595,12 @@ class MultiModelDataset(Dataset[ModelT], Generic[ModelT]):
     custom models.
     """
 
-    _custom_field_models: dict[str, ModelT] = PrivateAttr(default={})
+    # Custom field models should really be a subtype of ModelT; however, this is currently
+    # not checkable in the type system. Instead, we rely on the _validate method to ensure
+    # that the custom field models are valid.
+ _custom_field_models: dict[str, type[IsModel]] = PrivateAttr(default={}) - def set_model(self, data_file: str, model: ModelT) -> None: + def set_model(self, data_file: str, model: type[IsModel]) -> None: try: self._custom_field_models[data_file] = model if data_file in self.data: @@ -586,7 +611,7 @@ def set_model(self, data_file: str, model: ModelT) -> None: del self._custom_field_models[data_file] raise - def get_model(self, data_file: str) -> ModelT: + def get_model(self, data_file: str) -> type[IsModel]: if data_file in self._custom_field_models: return self._custom_field_models[data_file] else: diff --git a/src/omnipy/data/model.py b/src/omnipy/data/model.py index e45424af..c9bc3bc2 100644 --- a/src/omnipy/data/model.py +++ b/src/omnipy/data/model.py @@ -8,6 +8,7 @@ from types import ModuleType, NoneType, UnionType from typing import (Annotated, Any, + Callable, cast, ContextManager, Generic, @@ -441,10 +442,10 @@ def _validate_contents_from_value(self, value: object) -> _RootT: raise validation_error return values[ROOT_KEY] - def _get_restorable_contents(self): + def _get_restorable_contents(self) -> RestorableContents: if not id(self) in _restorable_content_cache: _restorable_content_cache[id(self)] = RestorableContents() - return _restorable_content_cache.get(id(self)) + return _restorable_content_cache[id(self)] def _take_snapshot_of_validated_contents(self): interactive_mode = _is_interactive_mode() @@ -525,6 +526,11 @@ def contents(self) -> _RootT: @contents.setter def contents(self, value: _RootT) -> None: + """ + Sets the contents of the model. Note: in contrast to the `__init__()`, `from_data()` and + `from_json()` methods, the contents are not validated automatically. To validate the contents, + call the `validate_contents()` method explicitly. 
+ """ super().__setattr__(ROOT_KEY, value) def dict(self, *args, **kwargs) -> dict[str, object]: @@ -648,38 +654,7 @@ def _special_method(self, name: str, info: MethodInfo, *args: object, method = self._getattr_from_contents_obj(name) if info.state_changing: - restorable = self._get_restorable_contents() - reset_solution: ContextManager - - if _is_interactive_mode(): - if restorable.has_snapshot() \ - and restorable.last_snapshot_taken_of_same_obj(self.contents) \ - and restorable.differs_from_last_snapshot(self.contents): - - # Current contents not validated - reset_contents_to_last_snapshot = AttribHolder( - self, 'contents', restorable.get_last_snapshot(), reset_to_other=True) - with reset_contents_to_last_snapshot: - validated_contents = self._validate_contents_from_value(self.contents) - - reset_contents_to_validated_prev = AttribHolder( - self, 'contents', validated_contents, reset_to_other=True) - reset_solution = reset_contents_to_validated_prev - else: - reset_contents_to_prev = AttribHolder(self, 'contents', copy_attr=True) - reset_solution = reset_contents_to_prev - else: - reset_solution = nothing() - - with reset_solution: - ret = method(*args, **kwargs) - if _is_interactive_mode(): - needs_validation = restorable.differs_from_last_snapshot(self.contents) \ - if restorable.has_snapshot() else True - else: - needs_validation = True - if needs_validation: - self.validate_contents() + ret = self._call_method_with_content_reset_if_validation_error(method, *args, **kwargs) else: ret = method(*args, **kwargs) @@ -688,6 +663,52 @@ def _special_method(self, name: str, info: MethodInfo, *args: object, return ret + def _call_method_with_content_reset_if_validation_error(self, + method: Callable, + *args: object, + **kwargs: object) -> object: + restorable_contents = self._get_restorable_contents() + contents_reset_solution = self._get_contents_reset_solution(restorable_contents) + + with contents_reset_solution: + ret = method(*args, **kwargs) + + if _is_interactive_mode(): + needs_validation = \ + restorable_contents.differs_from_last_snapshot(self.contents) \ + if restorable_contents.has_snapshot() else True + else: + needs_validation = True + + if needs_validation: + self.validate_contents() + + return ret + + def _get_contents_reset_solution(self, restorable) -> ContextManager: + reset_solution: ContextManager + + if _is_interactive_mode(): + if restorable.has_snapshot() \ + and restorable.last_snapshot_taken_of_same_obj(self.contents) \ + and restorable.differs_from_last_snapshot(self.contents): + + # Current contents not validated + reset_contents_to_last_snapshot = AttribHolder( + self, 'contents', restorable.get_last_snapshot(), reset_to_other=True) + with reset_contents_to_last_snapshot: + validated_contents = self._validate_contents_from_value(self.contents) + + reset_contents_to_validated_prev = AttribHolder( + self, 'contents', validated_contents, reset_to_other=True) + reset_solution = reset_contents_to_validated_prev + else: + reset_contents_to_prev = AttribHolder(self, 'contents', copy_attr=True) + reset_solution = reset_contents_to_prev + else: + reset_solution = nothing() + return reset_solution + def _convert_to_model_if_reasonable(self, args, name, ret): if not isinstance(ret, self.__class__): if name == '__getitem__': diff --git a/src/omnipy/data/serializer.py b/src/omnipy/data/serializer.py index b080feca..0b8247fc 100644 --- a/src/omnipy/data/serializer.py +++ b/src/omnipy/data/serializer.py @@ -1,14 +1,25 @@ from abc import ABC, abstractmethod +import ast from io 
import BytesIO
 import os
+from pathlib import Path
 import tarfile
 from tarfile import TarInfo
-from typing import Any, Callable, IO, Type
+from typing import Any, BinaryIO, Callable, Generic, IO, Type, TypeVar
 
 from pydantic import ValidationError
 
 from omnipy.api.protocols.private.log import CanLog
-from omnipy.api.protocols.public.data import IsDataset, IsSerializer, IsTarFileSerializer
+from omnipy.api.protocols.public.data import (IsDataEncoder,
+                                              IsDataset,
+                                              IsModel,
+                                              IsSerializer,
+                                              IsTarFileSerializer,
+                                              RootT,
+                                              SupportsGeneralSerializerQueries)
+from omnipy.util.helpers import ensure_path_obj
+
+DatasetT = TypeVar('DatasetT', bound=IsDataset)
 
 
 class Serializer(ABC):
@@ -16,27 +27,78 @@
     @classmethod
     @abstractmethod
     def is_dataset_directly_supported(cls, dataset: IsDataset) -> bool:
-        pass
+        ...
 
     @classmethod
     @abstractmethod
     def get_dataset_cls_for_new(cls) -> Type[IsDataset]:
-        pass
+        ...
 
     @classmethod
     @abstractmethod
     def get_output_file_suffix(cls) -> str:
-        pass
+        ...
+
+    # @classmethod
+    # @abstractmethod
+    # def serialize(cls, dataset: IsDataset) -> bytes | memoryview:
+    #     ...
+    #
+    # @classmethod
+    # @abstractmethod
+    # def deserialize(cls, serialized: bytes, any_file_suffix=False) -> IsDataset:
+    #     ...
+
+
+class BytesSerializerMixin(IsDataEncoder[RootT], SupportsGeneralSerializerQueries, Generic[RootT]):
     @classmethod
-    @abstractmethod
-    def serialize(cls, dataset: IsDataset) -> bytes | memoryview:
-        pass
+    def serialize_to_bytes(cls, dataset: IsDataset[IsModel[RootT]]) -> BinaryIO:
+        # Note: assumes each encoded data file is valid UTF-8, so that it can be embedded
+        # in the repr of a str-keyed dict
+        dataset_as_dict = {
+            key: cls.encode_data(key, val.contents).decode('utf8') for (key, val) in dataset.items()
+        }
+        return BytesIO(repr(dataset_as_dict).encode('utf8'))
 
     @classmethod
-    @abstractmethod
-    def deserialize(cls, serialized: bytes, any_file_suffix=False) -> IsDataset:
-        pass
+    def deserialize_from_bytes(cls, data: BinaryIO) -> IsDataset[IsModel[RootT]]:
+        dataset_cls = cls.get_dataset_cls_for_new()
+        dataset = dataset_cls()
+
+        dataset_as_dict_repr = data.read().decode('utf8')
+        dataset_as_dict: dict[str, str] = ast.literal_eval(dataset_as_dict_repr)
+        for key, val in dataset_as_dict.items():
+            dataset[key] = cls.decode_data(key, val.encode('utf8'))
+
+        return dataset
+
+
+class DirectorySerializerMixin(IsDataEncoder[RootT],
+                               SupportsGeneralSerializerQueries,
+                               Generic[RootT]):
+    @classmethod
+    def serialize_to_directory(cls, dataset: IsDataset[IsModel[RootT]],
+                               dir_path: Path | str) -> None:
+        dir_path = ensure_path_obj(dir_path)
+        os.makedirs(dir_path)
+
+        for key, val in dataset.items():
+            with open(dir_path / f'{key}.{cls.get_output_file_suffix()}', 'bw') as file:
+                file.write(cls.encode_data(key, val.contents))
+
+    @classmethod
+    def deserialize_from_directory(cls, dir_path: Path | str) -> IsDataset[IsModel[RootT]]:
+        dir_path = ensure_path_obj(dir_path)
+
+        dataset_cls = cls.get_dataset_cls_for_new()
+        dataset = dataset_cls()
+
+        for root, dirs, files in os.walk(dir_path):
+            for filename in files:
+                basename, suffix = os.path.splitext(filename)
+                assert suffix == f'{os.path.extsep}{cls.get_output_file_suffix()}'
+                # Open relative to the directory currently being walked, not the top-level
+                # dir_path, so files in subdirectories are read from the correct location
+                with open(Path(root) / filename, 'br') as file:
+                    dataset[basename] = cls.decode_data(basename, file.read())
+
+        return dataset
 
 
 class TarFileSerializer(Serializer, ABC):
@@ -66,11 +128,36 @@ def create_dataset_from_tarfile(cls,
         with tarfile.open(fileobj=BytesIO(tarfile_bytes), mode='r:gz') as tarfile_stream:
             for filename in tarfile_stream.getnames():
                 data_file = tarfile_stream.extractfile(filename)
+                assert data_file is not None
                 if not any_file_suffix:
                     assert filename.endswith(f'.{cls.get_output_file_suffix()}')
                 data_file_name = os.path.basename('.'.join(filename.split('.')[:-1]))
-                getattr(dataset, import_method)(
-                    dictify_object_func(data_file_name, data_decode_func(data_file)))
+                # The dictify_object_func parameter is no longer applied here; a plain
+                # dict is passed to the import method directly
+                getattr(dataset, import_method)({data_file_name: data_decode_func(data_file)})
+
+
+class DatasetToTarFileSerializer(TarFileSerializer):
+    # Work in progress: the capability queries and (de)serialization methods below are
+    # still stubs
+    def __init__(self, registry: 'SerializerRegistry'):
+        self._registry = registry
+
+    @classmethod
+    def is_dataset_directly_supported(cls, dataset: IsDataset) -> bool:
+        ...
+
+    @classmethod
+    def get_dataset_cls_for_new(cls) -> Type[IsDataset]:
+        ...
+
+    @classmethod
+    def get_output_file_suffix(cls) -> str:
+        return 'num'
+
+    @classmethod
+    def serialize(cls, dataset: IsDataset) -> bytes | memoryview:
+        ...
+
+    @classmethod
+    def deserialize(cls, tarfile_bytes: bytes, any_file_suffix=False) -> IsDataset:
+        ...
 
 
 class SerializerRegistry:
@@ -129,7 +216,7 @@ def _to_data_from_data_if_direct(dataset: IsDataset, serializer: IsSerializer):
             new_dataset = func(dataset, serializer)
             return new_dataset, serializer
         except (TypeError, ValueError, ValidationError, AssertionError):
-            pass
+            ...
 
     return None, None
diff --git a/src/omnipy/modules/json/serializers.py b/src/omnipy/modules/json/serializers.py
index f0abd348..fea9e6d3 100644
--- a/src/omnipy/modules/json/serializers.py
+++ b/src/omnipy/modules/json/serializers.py
@@ -36,14 +36,14 @@ def deserialize(cls, tarfile_bytes: bytes, any_file_suffix=False) -> JsonDataset
         def json_decode_func(file_stream: IO[bytes]) -> str:
             return file_stream.read().decode('utf8')
 
-        def json_dictify_object(data_file: str, obj_val: str) -> dict[str, str]:
-            return {f'{data_file}': f'{obj_val}'}
+        def python_dictify_object(data_file: str, obj_val: object) -> dict:
+            return {data_file: obj_val}
 
         cls.create_dataset_from_tarfile(
             json_dataset,
             tarfile_bytes,
             data_decode_func=json_decode_func,
-            dictify_object_func=json_dictify_object,
+            dictify_object_func=python_dictify_object,
             import_method='from_json',
             any_file_suffix=any_file_suffix,
         )
diff --git a/src/omnipy/util/contexts.py b/src/omnipy/util/contexts.py
index 3eaf86f0..9f8b1192 100644
--- a/src/omnipy/util/contexts.py
+++ b/src/omnipy/util/contexts.py
@@ -1,5 +1,6 @@
 from contextlib import AbstractContextManager, contextmanager
 from copy import deepcopy
+from typing import Iterator
 
 from omnipy.util.helpers import all_equals
 
@@ -87,7 +88,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
 
 
 @contextmanager
-def nothing(*args, **kwds):
+def nothing(*args, **kwds) -> Iterator[None]:
     yield None
diff --git a/src/omnipy/util/helpers.py b/src/omnipy/util/helpers.py
index 19c856c5..f51a49bd 100644
--- a/src/omnipy/util/helpers.py
+++ b/src/omnipy/util/helpers.py
@@ -3,6 +3,7 @@
 import inspect
 from inspect import getmodule, isclass
 import locale as pkg_locale
+from pathlib import Path
 from types import GenericAlias, ModuleType, UnionType
 from typing import (Annotated,
                     Any,
@@ -13,7 +14,6 @@
                     Mapping,
                     NamedTuple,
                     Protocol,
-                    Type,
                     TypeVar,
                     Union)
 
@@ -24,7 +24,7 @@
 from pydantic.typing import display_as_type
 from typing_inspect import get_generic_bases, is_generic_type
 
-from omnipy.api.typedefs import LocaleType
+from omnipy.api.typedefs import LocaleType, TypeForm
 
 _KeyT = TypeVar('_KeyT', bound=Hashable)
 
@@ -172,12 +172,11 @@ class IsDataclass(Protocol):
     __dataclass_fields__: ClassVar[dict]
 
 
-def remove_annotated_plus_optional_if_present(
-        type_or_class: Type | UnionType | object) -> Type | UnionType | object:
+def remove_annotated_plus_optional_if_present(type_or_class: TypeForm) -> TypeForm:
     if get_origin(type_or_class) == Annotated:
         type_or_class = get_args(type_or_class)[0]
         if is_optional(type_or_class):
-            args = get_args(type_or_class)
+            args: tuple[type, ...] = get_args(type_or_class)
             if len(args) == 2:
                 type_or_class = args[0]
             else:
@@ -274,3 +273,9 @@ def called_from_omnipy_tests() -> bool:
             and 'omnipy/tests' in module.__file__:
         return True
     return False
+
+
+def ensure_path_obj(dir_path: Path | str) -> Path:
+    if isinstance(dir_path, str):
+        dir_path = Path(dir_path)
+    return dir_path
diff --git a/tests/data/helpers/classes.py b/tests/data/helpers/classes.py
new file mode 100644
index 00000000..bdbe23a9
--- /dev/null
+++ b/tests/data/helpers/classes.py
@@ -0,0 +1,12 @@
+class MyPath:
+    def __init__(self, *args, **kwargs):
+        self._path = args[0] if len(args) > 0 else '.'
+
+    def __eq__(self, other):
+        return self._path == other._path
+
+    def __str__(self):
+        return self._path
+
+    def __truediv__(self, append_path: str) -> 'MyPath':
+        return MyPath(f'{self._path}/{append_path}')
\ No newline at end of file
diff --git a/tests/data/helpers/functions.py b/tests/data/helpers/functions.py
index 06da42f4..994ecfa9 100644
--- a/tests/data/helpers/functions.py
+++ b/tests/data/helpers/functions.py
@@ -12,3 +12,9 @@ def assert_tar_file_contents(tarfile_bytes: bytes,
         file_contents = tarfile_stream.extractfile(f'{data_file_name}.{file_suffix}')
         assert file_contents is not None
         assert decode_func(file_contents.read()) == exp_contents
+
+
+def assert_directory_in_tar_file(tarfile_bytes: bytes, dir_file_name: str):
+    with tarfile.open(fileobj=BytesIO(tarfile_bytes), mode='r:gz') as tarfile_stream:
+        tar_info = tarfile_stream.getmember(dir_file_name)
+        assert tar_info.isdir()
diff --git a/tests/data/helpers/mocks.py b/tests/data/helpers/mocks.py
index a6f59141..3e2bfd7f 100644
--- a/tests/data/helpers/mocks.py
+++ b/tests/data/helpers/mocks.py
@@ -1,20 +1,28 @@
+from io import BytesIO
 import sys
-from typing import Any, IO, Type
+from typing import Any, BinaryIO, Callable, IO, Type
 
 from omnipy.api.protocols.public.data import IsDataset
 from omnipy.data.dataset import Dataset
 from omnipy.data.model import Model
-from omnipy.data.serializer import Serializer, TarFileSerializer
+from omnipy.data.serializer import (BytesSerializerMixin,
+                                    DirectorySerializerMixin,
+                                    Serializer,
+                                    TarFileSerializer)
 
 
 class NumberDataset(Dataset[Model[int]]):
     ...
 
 
-class MockNumberSerializer(Serializer):
+class TextDataset(Dataset[Model[str]]):
+    ...
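+
+
+# Illustrative sketch (editor's addition): the mock serializers below combine the Serializer
+# base class with the new mixins; a concrete class only supplies the per-file encode/decode
+# pair and the general queries, while the byte-level round trip comes from
+# BytesSerializerMixin. Expected usage, assuming a dataset of small non-negative ints
+# (MockNumberSerializer encodes each int as a single byte):
+def _demo_serializer_roundtrip(number_data: NumberDataset) -> None:
+    serializer = MockNumberSerializer()  # defined below; name is resolved at call time
+    stream = serializer.serialize_to_bytes(number_data)
+    restored = serializer.deserialize_from_bytes(stream)
+    assert restored.to_data() == number_data.to_data()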
+ + +class MockNumberSerializer(Serializer, BytesSerializerMixin[int], DirectorySerializerMixin[int]): @classmethod def is_dataset_directly_supported(cls, dataset: IsDataset) -> bool: - return isinstance(dataset, NumberDataset) + return isinstance(dataset, Dataset) and dataset.get_model_class() == Model[int] @classmethod def get_dataset_cls_for_new(cls) -> Type[IsDataset]: @@ -25,16 +33,34 @@ def get_output_file_suffix(cls) -> str: return 'num' @classmethod - def serialize(cls, number_dataset: NumberDataset) -> bytes | memoryview: - return ','.join( - ':'.join([k, str(v.contents)]) for (k, v) in number_dataset.items()).encode('utf8') + def encode_data(cls, dataset_key: str, data: int) -> bytes: + return bytes([data]) @classmethod - def deserialize(cls, serialized_bytes: bytes, any_file_suffix=False) -> NumberDataset: - number_dataset = NumberDataset() - for key, val in [_.split(':') for _ in serialized_bytes.decode('utf8').split(',')]: - number_dataset[key] = int(val) - return number_dataset + def decode_data(cls, dataset_key: str, encoded_data: bytes) -> int: + return int.from_bytes(encoded_data, byteorder=sys.byteorder) + + +class MockTextSerializer(Serializer, BytesSerializerMixin[str], DirectorySerializerMixin[str]): + @classmethod + def is_dataset_directly_supported(cls, dataset: IsDataset) -> bool: + return isinstance(dataset, Dataset) and dataset.get_model_class() == Model[str] + + @classmethod + def get_dataset_cls_for_new(cls) -> Type[IsDataset]: + return TextDataset + + @classmethod + def get_output_file_suffix(cls) -> str: + return 'txt' + + @classmethod + def encode_data(cls, dataset_key: str, data: str) -> bytes: + return data.encode('utf8') + + @classmethod + def decode_data(cls, dataset_key: str, encoded_data: bytes) -> str: + return encoded_data.decode('utf8') class MockNumberToTarFileSerializer(TarFileSerializer): diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py index 5ffbd1b7..f8e19256 100644 --- a/tests/data/test_dataset.py +++ b/tests/data/test_dataset.py @@ -198,9 +198,13 @@ def test_more_dict_methods_with_parsing(): Dataset[Model[str]](data_file_1='321', data_file_2='321') assert len(dataset) == 3 + assert 'data_file_2' in dataset + assert 'data_file_3' in dataset + assert 'data_file_4' in dataset dataset.pop('data_file_3') assert len(dataset) == 2 + assert 'data_file_3' not in dataset # UserDict() implementation of popitem pops items FIFO contrary of the LIFO specified # in the standard library: https://docs.python.org/3/library/stdtypes.html#dict.popitem @@ -540,6 +544,62 @@ def test_import_export_custom_parser_to_other_type(): }''') # noqa: Q001 +def test_copy(): + from copy import copy + numbers = [1, 2, 3] + dataset = Dataset[Model[list[int]]]({'data_file_1': numbers, 'data_file_2': numbers}) + + for dataset_copy in [dataset.copy(), copy(dataset)]: + assert dataset == dataset_copy + assert type(dataset) == type(dataset_copy) + + dataset['data_file_1'].append(4) + assert dataset == dataset_copy + + dataset['data_file_3'] = numbers + assert dataset != dataset_copy + assert len(dataset) != len(dataset_copy) + + assert dataset['data_file_1'].contents == [1, 2, 3, 4] + assert dataset_copy['data_file_1'].contents == [1, 2, 3, 4] + + assert dataset['data_file_2'].contents == [1, 2, 3] + assert dataset_copy['data_file_2'].contents == [1, 2, 3] + + assert dataset['data_file_3'].contents == [1, 2, 3] + assert 'data_file_3' not in dataset_copy + + dataset['data_file_1'].remove(4) + del dataset['data_file_3'] + + +def test_deepcopy(): + from copy import 
deepcopy + numbers = [1, 2, 3] + dataset = Dataset[Model[list[int]]]({'data_file_1': numbers, 'data_file_2': numbers}) + + for dataset_deepcopy in [dataset.deepcopy(), deepcopy(dataset)]: + assert dataset == dataset_deepcopy + + dataset['data_file_1'].append(4) + assert dataset != dataset_deepcopy + + dataset['data_file_3'] = numbers + assert len(dataset) != len(dataset_deepcopy) + + assert dataset['data_file_1'].contents == [1, 2, 3, 4] + assert dataset_deepcopy['data_file_1'].contents == [1, 2, 3] + + assert dataset['data_file_2'].contents == [1, 2, 3] + assert dataset_deepcopy['data_file_2'].contents == [1, 2, 3] + + assert dataset['data_file_3'].contents == [1, 2, 3] + assert 'data_file_3' not in dataset_deepcopy + + dataset['data_file_1'].remove(4) + del dataset['data_file_3'] + + def test_generic_dataset_unbound_typevar(): # Note that the TypeVars for generic Dataset classes do not need to be bound, in contrast to # TypeVars used for generic Model classes (see test_generic_dataset_bound_typevar() below). diff --git a/tests/data/test_model.py b/tests/data/test_model.py index 800bc7ee..0ebb306c 100644 --- a/tests/data/test_model.py +++ b/tests/data/test_model.py @@ -1,3 +1,4 @@ +from datetime import datetime from math import floor import os from textwrap import dedent @@ -19,9 +20,12 @@ from pydantic.generics import GenericModel import pytest +from omnipy.api.protocols.public.data import IsModel +from omnipy.api.protocols.public.hub import IsRuntime from omnipy.data.model import Model from omnipy.modules.general.typedefs import FrozenDict +from .helpers.classes import MyPath from .helpers.models import (DefaultStrModel, ListOfUpperStrModel, LiteralFiveModel, @@ -1049,6 +1053,7 @@ def test_import_export_methods() -> None: model_dict = Model[dict]() model_dict.from_json('{"a": 2}') + assert model_dict.contents == {'a': 2} assert model_dict.to_data() == {'a': 2} model_dict.contents = {'b': 3} @@ -1172,6 +1177,24 @@ def test_model_of_pydantic_model() -> None: } +def test_mimic_validation_failure_recovery_with_interactive_mode( + runtime: Annotated[IsRuntime, pytest.fixture]) -> None: + model = Model[list[int]]([12]) + assert model.contents == [12] + + runtime.config.data.interactive_mode = False + with pytest.raises(ValidationError): + model.append('abc') + assert model.contents == [12, 'abc'] + + del model[-1] + model.validate_contents() + runtime.config.data.interactive_mode = True + with pytest.raises(ValidationError): + model.append('abc') + assert model.contents == [12] + + def test_mimic_simple_list_operations() -> None: model = Model[list[int]]() assert len(model) == 0 @@ -1639,9 +1662,23 @@ def test_mimic_operations_on_literal_models() -> None: LiteralFiveOrTextModel('text') / 2 -@pytest.mark.skipif(os.getenv('OMNIPY_FORCE_SKIPPED_TEST') != '1', reason="Not implemented") -def test_model_copy() -> None: - ... 
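+# Editor's note (hedged): Model.copy() is expected to follow pydantic's shallow-copy
+# semantics, so the copy shares its `contents` list with the original; the assertions
+# below pin this down (appending to one side is visible on both).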
+def test_model_copy(runtime: Annotated[IsRuntime, pytest.fixture]) -> None:
+    numbers = [1, 2, 3]
+    model = Model[list[int]](numbers)
+
+    model_copy = model.copy()
+    assert model == model_copy
+    assert model is not model_copy
+    assert model.contents == model_copy.contents == [1, 2, 3]
+
+    numbers.append(4)
+    assert model.contents == model_copy.contents == [1, 2, 3]
+
+    model.append(4)
+    assert model.contents == model_copy.contents == [1, 2, 3, 4]
+
+    with pytest.raises(ValidationError):
+        model.append('five')
+    assert model.contents == [1, 2, 3, 4]
 
 
 def test_json_schema_generic_model_one_level() -> None:
@@ -1929,36 +1966,104 @@ class ProductFactorDictInRomanNumerals(Model[dict[str, list[str]]]):
 }""")
 
 
-def test_pandas_dataframe_non_builtin_direct() -> None:
-    # TODO: Using pandas here to test concept of non-builtin data structures. Switch to other
-    #       example to remove dependency, to prepare splitting of pandas module to separate repo
+def test_non_builtin_model_with_parser() -> None:
+    class PathModelWithParser(Model[MyPath | str]):
+        @classmethod
+        def _parse_data(cls, data: MyPath | str) -> MyPath:
+            if isinstance(data, str):
+                return MyPath(data)
+            return data
 
-    import pandas as pd
+        def to_data(self) -> str:
+            return str(self.contents)
 
-    class PandasDataFrameModel(Model[pd.DataFrame]):
-        ...
+        def __str__(self):
+            return str(self.contents)
+
+    _assert_path_model(PathModelWithParser)
 
-    dataframe = pd.DataFrame([[1, 2, 3], [4, 5, 6]])
+    str_path = PathModelWithParser('tests/data')
+    assert isinstance(str_path.contents, MyPath)
+    assert str_path.contents == MyPath('tests/data')
+    assert str_path.to_data() == 'tests/data'
+    assert str(str_path) == 'tests/data'
 
-    model_1 = PandasDataFrameModel()
-    assert isinstance(model_1.contents, pd.DataFrame) and model_1.contents.empty
+    int_path = PathModelWithParser(123)
+    assert int_path.contents == MyPath('123')
+    assert int_path.to_data() == '123'
+    assert str(int_path) == '123'
 
-    model_1.contents = dataframe
 
-    pd.testing.assert_frame_equal(
-        model_1.contents,
-        dataframe,
-    )
+def test_non_builtin_model_with_from_data() -> None:
+    class PathModelWithFromData(Model[MyPath]):
+        def from_data(self, value: MyPath | str) -> None:
+            if isinstance(value, str):
+                self._validate_and_set_contents(MyPath(value))
+            else:
+                self._validate_and_set_contents(value)
+
+        def to_data(self) -> str:
+            return str(self.contents)
+
+        def __str__(self):
+            return str(self.contents)
+
+    _assert_path_model(PathModelWithFromData)
 
     with pytest.raises(ValidationError):
-        PandasDataFrameModel([[1, 2, 3], [4, 5, 6]])
+        PathModelWithFromData('tests/data')
+
+    with pytest.raises(ValidationError):
+        PathModelWithFromData(123)
+
+
+def _assert_path_model(PathModel: type[IsModel[MyPath]]) -> None:
+    path = PathModel()
+    assert isinstance(path.contents, MyPath)
+    assert path.contents == MyPath()
+    assert path.to_data() == '.'
+    assert str(path) == '.'
 
-    model_2 = PandasDataFrameModel(dataframe)
+    path.from_data('tests/data')
+    assert isinstance(path.contents, MyPath)
+    assert path.contents == MyPath('tests/data')
 
-    pd.testing.assert_frame_equal(
-        model_2.contents,
-        dataframe,
-    )
+    new_path = path / 'test_model.py'
+    assert isinstance(new_path, PathModel)
+    assert new_path.contents == MyPath('tests/data/test_model.py')
+    assert str(new_path) == 'tests/data/test_model.py'
+
+
+def test_non_builtin_model_with_custom_default_value() -> None:
+    # Hack to provide a custom default value for models where calling the root type without
+    # arguments does not produce a default value.
+ class DefaultDatetime(datetime): + def __new__(cls, *args, **kwargs): + if len(args) == 0: + return datetime.min + return datetime.__new__(datetime, *args, **kwargs) + + class DatetimeModel(Model[DefaultDatetime | datetime | str]): + @classmethod + def _parse_data(cls, data: datetime | str) -> datetime: + if isinstance(data, str): + return datetime.fromisoformat(data) + return data + + event_time = datetime(year=2024, month=5, day=17, hour=8) + + model = DatetimeModel() + assert isinstance(model.contents, datetime) + assert model.contents == datetime.min + assert model.to_data() == datetime.min - model_2 = PandasDataFrameModel(dataframe) + model.from_data('2024-05-17T08:00:00') + assert isinstance(model.contents, datetime) + assert model.contents == event_time + assert model.to_data() == event_time - pd.testing.assert_frame_equal( - model_2.contents, - dataframe, - ) + model.from_data(event_time) + assert isinstance(model.contents, datetime) + assert model.contents == event_time + assert model.to_data() == event_time def test_parametrized_model() -> None: diff --git a/tests/data/test_serializer.py b/tests/data/test_serializer.py index 2f48e6c9..0826bf35 100644 --- a/tests/data/test_serializer.py +++ b/tests/data/test_serializer.py @@ -1,32 +1,190 @@ +from dataclasses import dataclass +from io import BytesIO +import os +from pathlib import Path import sys +from typing import Annotated, cast, Generic, NamedTuple, Type, TypeVar -from omnipy.data.dataset import Dataset +import pytest +import pytest_cases as pc + +from omnipy.api.protocols.public.data import CanSerialize, IsDataset, IsModel, IsSerializer +from omnipy.data.dataset import Dataset, MultiModelDataset from omnipy.data.model import Model -from omnipy.data.serializer import SerializerRegistry +from omnipy.data.serializer import DatasetToTarFileSerializer, SerializerRegistry -from .helpers.functions import assert_tar_file_contents -from .helpers.mocks import MockNumberSerializer, MockNumberToTarFileSerializer, NumberDataset +from .helpers.functions import assert_directory_in_tar_file, assert_tar_file_contents +from .helpers.mocks import (MockNumberSerializer, + MockNumberToTarFileSerializer, + MockTextSerializer, + NumberDataset, + TextDataset) +RootT = TypeVar('RootT') -def test_number_dataset_serializer(): - number_data = Dataset[Model[int]]() - number_data['data_file_1'] = 35 - number_data['data_file_2'] = 12 +@dataclass +class DatasetSerializationCase(Generic[RootT]): + dataset: IsDataset[IsModel[RootT]] + serializer: IsSerializer[RootT] + decoded_dataset_cls: type[IsDataset[IsModel[RootT]]] + data_files_encoded: dict[str, bytes] + dataset_encoded: bytes - serializer = MockNumberSerializer() - assert serializer.get_dataset_cls_for_new() is NumberDataset +def populate_number_dataset(dataset_cls: Type[IsDataset[IsModel[int]]]) -> IsDataset[IsModel[int]]: + number_data = dataset_cls() + + number_data['data_file_æ'] = 35 + number_data['data_file_ø'] = 12 + + return number_data + - serialized_bytes = serializer.serialize(number_data) - assert serialized_bytes == b'data_file_1:35,data_file_2:12' +def populate_text_dataset(dataset_cls: Type[IsDataset[IsModel[str]]]) -> IsDataset[IsModel[str]]: + str_data = dataset_cls() - deserialized_obj = serializer.deserialize(serialized_bytes) - assert deserialized_obj.to_data() == number_data.to_data() - assert type(deserialized_obj) is NumberDataset + str_data['data_file_å'] = 'thirty-five æ' + str_data['data_file_ß'] = 'twelve ø' + return str_data -def 
test_number_dataset_to_tar_file_serializer(): + +@pc.fixture +def number_data_files_encoded() -> Annotated[dict[str, bytes], pc.fixture]: + return {'data_file_æ': b'#', 'data_file_ø': b'\x0c'} + + +@pc.fixture +def number_dataset_encoded() -> Annotated[bytes, pc.fixture]: + return b"{'data_file_\xc3\xa6': '#', 'data_file_\xc3\xb8': '\\x0c'}" + + +@pc.case(id='MockNumberSerializer', tags=['dataset']) +@pc.parametrize("number_dataset_cls", [NumberDataset, Dataset[Model[int]]]) +def case_mock_number_serializer( + number_dataset_cls: Annotated[Type[IsDataset[IsModel[int]]], pc.case], + number_data_files_encoded: Annotated[dict[str, bytes], pc.fixture], + number_dataset_encoded: Annotated[bytes, pc.fixture], +) -> Annotated[DatasetSerializationCase[int], pc.case]: + return DatasetSerializationCase( + dataset=populate_number_dataset(number_dataset_cls), + serializer=MockNumberSerializer(), + decoded_dataset_cls=NumberDataset, + data_files_encoded=number_data_files_encoded, + dataset_encoded=number_dataset_encoded, + ) + + +@pc.fixture +def text_data_files_encoded() -> Annotated[dict[str, bytes], pc.fixture]: + return {'data_file_å': b'thirty-five \xc3\xa6', 'data_file_ß': b'twelve \xc3\xb8'} + + +@pc.fixture +def text_dataset_encoded() -> Annotated[bytes, pc.fixture]: + return b"{'data_file_\xc3\xa5': 'thirty-five \xc3\xa6', 'data_file_\xc3\x9f': 'twelve \xc3\xb8'}" + + +@pc.case(id='MockTextSerializer', tags=['dataset']) +@pc.parametrize("text_dataset_cls", [TextDataset, Dataset[Model[str]]]) +def case_mock_text_serializer( + text_dataset_cls: Annotated[Type[IsDataset[IsModel[str]]], pc.case], + text_data_files_encoded: Annotated[dict[str, bytes], pc.fixture], + text_dataset_encoded: Annotated[bytes, pc.fixture], +) -> Annotated[DatasetSerializationCase[str], pc.case]: + return DatasetSerializationCase( + dataset=populate_text_dataset(text_dataset_cls), + serializer=MockTextSerializer(), + decoded_dataset_cls=TextDataset, + data_files_encoded=text_data_files_encoded, + dataset_encoded=text_dataset_encoded, + ) + + +# @pc.case(id='SerializerRegistry', tags=['dataset']) +# def case_multi_model_dataset( +# text_data_files_encoded: Annotated[dict[str, bytes], pc.fixture], +# text_dataset_encoded: Annotated[bytes, pc.fixture], +# ) -> Annotated[DatasetSerializationCase, pc.case]: +# dataset = MultiModelDataset[Model[int | str]]() +# dataset.set_model('number', Model[int]) +# dataset.set_model('text', Model[str]) +# return DatasetSerializationCase( +# dataset=populate_text_dataset(text_dataset_cls), +# serializer=MockTextSerializer, +# decoded_dataset_cls=TextDataset, +# data_files_encoded=text_data_files_encoded, +# dataset_encoded=text_dataset_encoded, +# ) + + +@pc.parametrize_with_cases('case', has_tag='dataset', cases='.') +def test_mock_serializers(case: Annotated[DatasetSerializationCase, pc.case],) -> None: + assert case.serializer.get_dataset_cls_for_new() is case.decoded_dataset_cls + assert case.serializer.is_dataset_directly_supported(case.dataset) + + +@pc.parametrize_with_cases('case', has_tag='dataset', cases='.') +def test_dataset_serialization_to_bytes(case: Annotated[DatasetSerializationCase, pc.case]) -> None: + serialized_bytes = case.serializer.serialize_to_bytes(case.dataset) + assert cast(BytesIO, serialized_bytes).getvalue() == case.dataset_encoded + + deserialized_obj = case.serializer.deserialize_from_bytes(serialized_bytes) + assert deserialized_obj.to_data() == case.dataset.to_data() + assert type(deserialized_obj) is case.decoded_dataset_cls + + 
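+# Illustrative sketch (editor's addition): BytesSerializerMixin encodes a dataset as the
+# repr() of a dict of decoded strings and parses it back with ast.literal_eval(), which is
+# what the *_dataset_encoded fixture values above spell out byte for byte:
+def _demo_bytes_encoding_roundtrip() -> None:
+    import ast
+    encoded = repr({'data_file_æ': '#'}).encode('utf8')
+    assert encoded == b"{'data_file_\xc3\xa6': '#'}"
+    assert ast.literal_eval(encoded.decode('utf8')) == {'data_file_æ': '#'}
+
+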
+@pc.parametrize_with_cases('case', has_tag='dataset', cases='.')
+def test_dataset_serialization_to_directory(
+        case: Annotated[DatasetSerializationCase[RootT], pc.case],
+        tmp_path: Annotated[Path, pytest.fixture],
+) -> None:
+    dir_path = tmp_path / case.dataset.__class__.__name__
+    type2suffix = dict(int='num', str='txt', dict='json')  # maps contents type to file suffix
+
+    case.serializer.serialize_to_directory(case.dataset, dir_path)
+
+    assert os.path.exists(dir_path) and len(os.listdir(dir_path)) == len(case.dataset) == 2
+    for root, dirs, files in os.walk(dir_path):
+        assert len(dirs) == 0
+
+    key: str
+    val: IsModel[RootT]
+    for key, val in case.dataset.items():
+        data_file_name = f'{key}.{type2suffix[type(val.contents).__name__]}'
+        assert data_file_name in files
+        with open(dir_path / data_file_name, 'br') as file:
+            assert file.read() == case.data_files_encoded[key]
+
+    deserialized_dataset = case.serializer.deserialize_from_directory(dir_path)
+    assert deserialized_dataset.to_data() == case.dataset.to_data()
+    assert type(deserialized_dataset) is case.decoded_dataset_cls
+
+
+def test_number_dataset_serialization_to_tar_file():
     number_data = NumberDataset()
 
     number_data['data_file_1'] = 35
@@ -35,6 +193,7 @@ def test_number_dataset_to_tar_file():
     serializer = MockNumberToTarFileSerializer()
 
     assert serializer.get_dataset_cls_for_new() is NumberDataset
+    assert serializer.is_dataset_directly_supported(number_data)
 
     tarfile_bytes = serializer.serialize(number_data)
     decode_func = lambda x: int.from_bytes(x, byteorder=sys.byteorder)  # noqa
@@ -47,6 +206,37 @@ def test_number_dataset_to_tar_file():
     assert deserialized_json_data == number_data
 
 
+def test_multi_model_dataset_of_datasets_to_tar_file():
+    dataset_of_datasets = MultiModelDataset[Model[Dataset[Model[int | str]]]]()
+
+    dataset_of_datasets['data_dir_1'] = dict(data_file_1=35, data_file_2=27)
+    dataset_of_datasets['data_dir_2'] = dict(data_file_2=13, data_file_3=45)
+
+    registry = SerializerRegistry()
+    registry.register(DatasetToTarFileSerializer)
+    registry.register(MockNumberToTarFileSerializer)
+
+    serializer = DatasetToTarFileSerializer(registry)
+
+    assert serializer.get_dataset_cls_for_new() is Dataset[Model[NumberDataset]]
+    assert serializer.is_dataset_directly_supported(dataset_of_datasets)
+
+    tarfile_bytes = serializer.serialize(dataset_of_datasets)
+    decode_func = lambda x: int.from_bytes(x, byteorder=sys.byteorder)  # noqa
+
+    assert_directory_in_tar_file(tarfile_bytes, 'data_dir_1')
+    assert_directory_in_tar_file(tarfile_bytes, 'data_dir_2')
+
+    assert_tar_file_contents(tarfile_bytes, 'data_dir_1/data_file_1', 'num', decode_func, 35)
+    assert_tar_file_contents(tarfile_bytes, 'data_dir_1/data_file_2', 'num', decode_func, 27)
+    assert_tar_file_contents(tarfile_bytes, 'data_dir_2/data_file_2', 'num', decode_func, 13)
+    
assert_tar_file_contents(tarfile_bytes, 'data_dir_2/data_file_3', 'num', decode_func, 45) + + deserialized_json_data = serializer.deserialize(tarfile_bytes) + + assert deserialized_json_data == dataset_of_datasets + + def test_serializer_registry(): registry = SerializerRegistry() diff --git a/tests/integration/novel/serialize/cases/datasets.py b/tests/integration/novel/serialize/cases/datasets.py index a649f334..da40e3b2 100644 --- a/tests/integration/novel/serialize/cases/datasets.py +++ b/tests/integration/novel/serialize/cases/datasets.py @@ -6,6 +6,8 @@ from omnipy.modules.json.datasets import JsonDataset from omnipy.modules.pandas.models import PandasDataset +from .models import PydanticModel, TwoLevelPydanticModel + pandas_dataset = PandasDataset() pandas_dataset.from_data({ 'pandas_person': @@ -79,3 +81,17 @@ python_dataset = Dataset[Model[object]]() python_dataset['python_a'] = [{'a': 1, 'b': [2, 3, 4], 'c': {'yes': True, 'no': False}}] python_dataset['python_b'] = lambda x: x + 1 + +pydantic_dataset = Dataset[Model[PydanticModel]]() +pydantic_dataset['pydantic_a'] = dict(number=3, string='three') +pydantic_dataset['pydantic_b'] = dict(number=5, string='five') + +two_level_pydantic_dataset = Dataset[Model[TwoLevelPydanticModel]]() +two_level_pydantic_dataset['two_level_pydantic_a'] = dict( + a=dict(number=3, string='three'), + b=dict(number=4, string='four'), +) +two_level_pydantic_dataset['two_level_pydantic_b'] = dict( + a=dict(number=4, string='four'), + b=dict(number=5, string='five'), +) diff --git a/tests/integration/novel/serialize/cases/functions.py b/tests/integration/novel/serialize/cases/functions.py index 80199116..4117b6bf 100644 --- a/tests/integration/novel/serialize/cases/functions.py +++ b/tests/integration/novel/serialize/cases/functions.py @@ -11,9 +11,12 @@ json_table_as_str_dataset, json_table_dataset, pandas_dataset, + pydantic_dataset, python_dataset, str_dataset, - str_unicode_dataset) + str_unicode_dataset, + two_level_pydantic_dataset) +from .models import PydanticModel, TwoLevelPydanticModel def pandas_func() -> PandasDataset: @@ -54,3 +57,11 @@ def str_unicode_func() -> StrDataset: def python_func() -> Dataset[Model[object]]: return python_dataset + + +def pydantic_func() -> Dataset[Model[PydanticModel]]: + return pydantic_dataset + + +def two_level_pydantic_func() -> Dataset[Model[TwoLevelPydanticModel]]: + return two_level_pydantic_dataset diff --git a/tests/integration/novel/serialize/cases/jobs.py b/tests/integration/novel/serialize/cases/jobs.py index 56493a10..8d0893aa 100644 --- a/tests/integration/novel/serialize/cases/jobs.py +++ b/tests/integration/novel/serialize/cases/jobs.py @@ -10,9 +10,11 @@ json_table_as_str_func, json_table_func, pandas_func, + pydantic_func, python_func, str_func, - str_unicode_func) + str_unicode_func, + two_level_pydantic_func) @pc.case(tags=['task']) @@ -113,3 +115,23 @@ def fail_case_python_task_tmpl() -> TaskTemplate: @pc.case(tags=['flow']) def fail_case_python_flow_tmpl() -> LinearFlowTemplate: return LinearFlowTemplate(fail_case_python_task_tmpl())(python_func) + + +@pc.case(tags=['task']) +def case_pydantic_task_tmpl() -> TaskTemplate: + return TaskTemplate()(pydantic_func) + + +@pc.case(tags=['flow']) +def case_pydantic_flow_tmpl() -> LinearFlowTemplate: + return LinearFlowTemplate(case_pydantic_task_tmpl())(pydantic_func) + + +@pc.case(tags=['task']) +def case_two_level_pydantic_task_tmpl() -> TaskTemplate: + return TaskTemplate()(two_level_pydantic_func) + + +@pc.case(tags=['flow']) +def 
case_two_level_pydantic_flow_tmpl() -> LinearFlowTemplate: + return LinearFlowTemplate(case_two_level_pydantic_task_tmpl())(two_level_pydantic_func) diff --git a/tests/integration/novel/serialize/cases/models.py b/tests/integration/novel/serialize/cases/models.py new file mode 100644 index 00000000..eda99df3 --- /dev/null +++ b/tests/integration/novel/serialize/cases/models.py @@ -0,0 +1,11 @@ +from pydantic import BaseModel + + +class PydanticModel(BaseModel): + number: int = 0 + string: str = '' + + +class TwoLevelPydanticModel(BaseModel): + a: PydanticModel = PydanticModel() + b: PydanticModel = PydanticModel() diff --git a/tests/util/test_helpers.py b/tests/util/test_helpers.py index 889ecc30..66918f82 100644 --- a/tests/util/test_helpers.py +++ b/tests/util/test_helpers.py @@ -168,14 +168,14 @@ def __len__(self): return 1 a = MyClass() - a.__len__ = MethodType(__len__, a) + a.__len__ = MethodType(__len__, a) # type: ignore[attr-defined] assert has_items(a) is True def test_get_first_item() -> None: with pytest.raises(AssertionError): - get_first_item(42) + get_first_item(42) # type: ignore[arg-type] with pytest.raises(AssertionError): get_first_item('') @@ -204,7 +204,7 @@ def test_is_union() -> None: assert is_union(Union[Union[str, int], None]) is True assert is_union(Union[Union[str, None], int]) is True - assert is_union(Union[str, int] | None) is True + assert is_union(Union[str, int] | None) is True # type: ignore[operator] assert is_union(Union[str, None] | int) is True assert is_union(Union) is True @@ -237,7 +237,7 @@ def test_is_optional() -> None: assert is_optional(Union[Union[str, NoneType], int]) is True assert is_optional(Union[Union[str, None], int]) is True - assert is_optional(Union[str, int] | None) is True + assert is_optional(Union[str, int] | None) is True # type: ignore[operator] assert is_optional(Union[str, int] | NoneType) is True assert is_optional(Union[str, NoneType] | int) is True assert is_optional(Union[str, None] | int) is True @@ -435,7 +435,7 @@ def test_restorable_contents(): def test_get_calling_module_name() -> None: - def local_call_get_calling_module_name() -> str: + def local_call_get_calling_module_name() -> str | None: return get_calling_module_name() from .helpers.other_module import (calling_module_name_when_importing_other_module,