From 1a8380d695f2a7fd2d40a6c011e07332fcad1a17 Mon Sep 17 00:00:00 2001
From: Bobby Noelte
Date: Sun, 24 Nov 2024 09:23:43 +0100
Subject: [PATCH] Add core abstract and base classes.

Signed-off-by: Bobby Noelte
---
 src/akkudoktoreos/core/coreabc.py  |  218 +++
 src/akkudoktoreos/core/dataabc.py  | 1186 ++++++++++++++++++++++++++++
 src/akkudoktoreos/core/pydantic.py |  226 ++++
 tests/test_dataabc.py              |  635 +++++++++++++++
 4 files changed, 2265 insertions(+)
 create mode 100644 src/akkudoktoreos/core/coreabc.py
 create mode 100644 src/akkudoktoreos/core/dataabc.py
 create mode 100644 src/akkudoktoreos/core/pydantic.py
 create mode 100644 tests/test_dataabc.py

diff --git a/src/akkudoktoreos/core/coreabc.py b/src/akkudoktoreos/core/coreabc.py
new file mode 100644
index 0000000..1283d62
--- /dev/null
+++ b/src/akkudoktoreos/core/coreabc.py
@@ -0,0 +1,218 @@
+"""Abstract and base classes for EOS core.
+
+This module provides foundational classes for handling configuration and prediction functionality
+in EOS. It includes base classes that provide convenient access to the global
+configuration and prediction instances through properties.
+
+Classes:
+    - ConfigMixin: Mixin class for managing and accessing global configuration.
+    - PredictionMixin: Mixin class for managing and accessing global prediction data.
+    - SingletonMixin: Mixin class to create singletons.
+"""
+
+import threading
+from typing import Any, ClassVar, Dict, Optional, Type
+
+from pendulum import DateTime
+from pydantic import computed_field
+
+from akkudoktoreos.utils.logutil import get_logger
+
+logger = get_logger(__name__)
+
+
+class ConfigMixin:
+    """Mixin class for managing EOS configuration data.
+
+    This class serves as a foundational component for EOS-related classes requiring access
+    to the global EOS configuration. It provides a `config` property that dynamically retrieves
+    the configuration instance, ensuring up-to-date access to configuration settings.
+
+    Usage:
+        Subclass this base class to gain access to the `config` attribute, which retrieves the
+        global configuration instance lazily to avoid import-time circular dependencies.
+
+    Attributes:
+        config (ConfigEOS): Property to access the global EOS configuration.
+
+    Example:
+        ```python
+        class MyEOSClass(ConfigMixin):
+            def my_method(self):
+                if self.config.myconfigval:
+                    ...
+        ```
+    """
+
+    @property
+    def config(self) -> Any:
+        """Convenience method/attribute to retrieve the EOS configuration data.
+
+        Returns:
+            ConfigEOS: The configuration.
+        """
+        # avoid circular dependency at import time
+        from akkudoktoreos.config.config import get_config
+
+        return get_config()
+
+
+class PredictionMixin:
+    """Mixin class for managing EOS prediction data.
+
+    This class serves as a foundational component for EOS-related classes requiring access
+    to global prediction data. It provides a `prediction` property that dynamically retrieves
+    the prediction instance, ensuring up-to-date access to prediction results.
+
+    Usage:
+        Subclass this base class to gain access to the `prediction` attribute, which retrieves the
+        global prediction instance lazily to avoid import-time circular dependencies.
+
+    Attributes:
+        prediction (Prediction): Property to access the global EOS prediction data.
+
+    Example:
+        ```python
+        class MyOptimizationClass(PredictionMixin):
+            def analyze_myprediction(self):
+                prediction_data = self.prediction.mypredictionresult
+                # Perform analysis
+        ```
+    """
+
+    @property
+    def prediction(self) -> Any:
+        """Convenience method/attribute to retrieve the EOS prediction data.
+
+        Returns:
+            Prediction: The prediction.
+        """
+        # avoid circular dependency at import time
+        from akkudoktoreos.prediction.prediction import get_prediction
+
+        return get_prediction()
+
+
+class EnergyManagementSystemMixin:
+    """Mixin class for managing the EOS energy management system.
+
+    This class serves as a foundational component for EOS-related classes requiring access
+    to the global energy management system. It provides an `ems` property that dynamically
+    retrieves the energy management system instance, ensuring up-to-date access to energy
+    management system control.
+
+    Usage:
+        Subclass this base class to gain access to the `ems` attribute, which retrieves the
+        global EnergyManagementSystem instance lazily to avoid import-time circular dependencies.
+
+    Attributes:
+        ems (EnergyManagementSystem): Property to access the global EOS energy management system.
+
+    Example:
+        ```python
+        class MyOptimizationClass(EnergyManagementSystemMixin):
+            def analyze_myprediction(self):
+                ems_data = self.ems.the_ems_method()
+                # Perform analysis
+        ```
+    """
+
+    @property
+    def ems(self) -> Any:
+        """Convenience method/attribute to retrieve the EOS energy management system.
+
+        Returns:
+            EnergyManagementSystem: The energy management system.
+        """
+        # avoid circular dependency at import time
+        from akkudoktoreos.core.ems import get_ems
+
+        return get_ems()
+
+
+class StartMixin(EnergyManagementSystemMixin):
+    """A mixin to manage the start datetime for energy management.
+
+    Provides property:
+    - `start_datetime`: The starting datetime of the current or latest energy management.
+    """
+
+    # Computed field for start_datetime
+    @computed_field  # type: ignore[prop-decorator]
+    @property
+    def start_datetime(self) -> Optional[DateTime]:
+        """Returns the start datetime of the current or latest energy management.
+
+        Returns:
+            DateTime: The starting datetime of the current or latest energy management, or None.
+        """
+        return self.ems.start_datetime
+
+
+class SingletonMixin:
+    """A thread-safe singleton mixin class.
+
+    Ensures that only one instance of the derived class is created, even when accessed from multiple
+    threads. This mixin is intended to be combined with other classes, such as Pydantic models,
+    to make them singletons.
+
+    Attributes:
+        _instances (Dict[Type, Any]): A dictionary holding instances of each singleton class.
+        _lock (threading.Lock): A lock to synchronize access to singleton instance creation.
+
+    Usage:
+        - Inherit from `SingletonMixin` alongside other classes to make them singletons.
+        - Avoid using `__init__` to reinitialize the singleton instance after it has been created.
+
+    Example:
+        class MySingletonModel(SingletonMixin, PydanticBaseModel):
+            name: str
+
+        instance1 = MySingletonModel(name="Instance 1")
+        instance2 = MySingletonModel(name="Instance 2")
+
+        assert instance1 is instance2  # True
+        print(instance1.name)  # Output: "Instance 1"
+    """
+
+    _lock: ClassVar[threading.Lock] = threading.Lock()
+    _instances: ClassVar[Dict[Type, Any]] = {}
+
+    def __new__(cls: Type["SingletonMixin"], *args: Any, **kwargs: Any) -> "SingletonMixin":
+        """Creates or returns the singleton instance of the class.
+ + Ensures thread-safe instance creation by locking during the first instantiation. + + Args: + *args: Positional arguments for instance creation (ignored if instance exists). + **kwargs: Keyword arguments for instance creation (ignored if instance exists). + + Returns: + SingletonMixin: The singleton instance of the derived class. + """ + if cls not in cls._instances: + with cls._lock: + if cls not in cls._instances: + instance = super().__new__(cls) + cls._instances[cls] = instance + return cls._instances[cls] + + @classmethod + def reset_instance(cls) -> None: + """Resets the singleton instance, forcing it to be recreated on next access.""" + with cls._lock: + if cls in cls._instances: + del cls._instances[cls] + logger.debug(f"{cls.__name__} singleton instance has been reset.") + + def __init__(self, *args: Any, **kwargs: Any) -> None: + """Initializes the singleton instance if it has not been initialized previously. + + Further calls to `__init__` are ignored for the singleton instance. + + Args: + *args: Positional arguments for initialization. + **kwargs: Keyword arguments for initialization. + """ + if not hasattr(self, "_initialized"): + super().__init__(*args, **kwargs) + self._initialized = True diff --git a/src/akkudoktoreos/core/dataabc.py b/src/akkudoktoreos/core/dataabc.py new file mode 100644 index 0000000..9701b73 --- /dev/null +++ b/src/akkudoktoreos/core/dataabc.py @@ -0,0 +1,1186 @@ +"""Abstract and base classes for generic data. + +This module provides classes for managing and processing generic data in a flexible, configurable manner. +It includes classes to handle configurations, record structures, sequences, and containers for generic data, +enabling efficient storage, retrieval, and manipulation of data records. + +This module is designed for use in predictive modeling workflows, facilitating the organization, serialization, +and manipulation of configuration and generic data in a clear, scalable, and structured manner. +""" + +import difflib +import json +from abc import abstractmethod +from collections.abc import MutableMapping, MutableSequence +from itertools import chain +from pathlib import Path +from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union, overload + +import pandas as pd +import pendulum +from pendulum import DateTime +from pydantic import AwareDatetime, ConfigDict, Field, computed_field, field_validator + +from akkudoktoreos.core.coreabc import ConfigMixin, SingletonMixin, StartMixin +from akkudoktoreos.core.pydantic import PydanticBaseModel +from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime +from akkudoktoreos.utils.logutil import get_logger + +logger = get_logger(__name__) + + +class DataBase(ConfigMixin, StartMixin, PydanticBaseModel): + """Base class for handling generic data. + + Enables access to EOS configuration data (attribute `config`). + """ + + pass + + +class DataRecord(DataBase, MutableMapping): + """Base class for data records, enabling dynamic access to fields defined in derived classes. + + Fields can be accessed and mutated both using dictionary-style access (`record['field_name']`) + and attribute-style access (`record.field_name`). + + Attributes: + date_time (Optional[AwareDatetime]): Aware datetime indicating when the data record applies. + + Configurations: + - Allows mutation after creation. + - Supports non-standard data types like `datetime`. 
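+
+    Example (illustrative sketch, not part of the API; assumes a derived record class
+    with a hypothetical `temperature` field):
+        ```python
+        class MyRecord(DataRecord):
+            temperature: Optional[float] = Field(default=None, description="Temperature")
+
+        record = MyRecord(date_time="2024-01-01 00:00:00+00:00", temperature=20.5)
+        assert record["temperature"] == record.temperature == 20.5
+        record["temperature"] = 21.0  # dictionary-style mutation
+        record.temperature = 21.5     # attribute-style mutation
+        ```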
+ """ + + date_time: Optional[AwareDatetime] = Field(default=None, description="DateTime") + + # Pydantic v2 model configuration + model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True) + + @field_validator("date_time", mode="before") + @classmethod + def transform_to_datetime(cls, value: Any) -> DateTime: + """Converts various datetime formats into AwareDatetime.""" + if value is None: + # Allow to set to default. + return None + return to_datetime(value) + + @classmethod + def record_keys(cls) -> List[str]: + """Returns the keys of all fields in the data record.""" + key_list = [] + key_list.extend(list(cls.model_fields.keys())) + key_list.extend(list(cls.__pydantic_decorators__.computed_fields.keys())) + return key_list + + @classmethod + def record_keys_writable(cls) -> List[str]: + """Returns the keys of all fields in the data record that are writable.""" + return list(cls.model_fields.keys()) + + def _validate_key_writable(self, key: str) -> None: + """Verify that a specified key exists and is writable in the current record keys. + + Args: + key (str): The key to check for in the records. + + Raises: + KeyError: If the specified key is not in the expected list of keys for the records. + """ + if key not in self.record_keys_writable(): + raise KeyError( + f"Key '{key}' is not in writable record keys: {self.record_keys_writable()}" + ) + + def __getitem__(self, key: str) -> Any: + """Retrieve the value of a field by key name. + + Args: + key (str): The name of the field to retrieve. + + Returns: + Any: The value of the requested field. + + Raises: + KeyError: If the specified key does not exist. + """ + if key in self.model_fields: + return getattr(self, key) + raise KeyError(f"'{key}' not found in the record fields.") + + def __setitem__(self, key: str, value: Any) -> None: + """Set the value of a field by key name. + + Args: + key (str): The name of the field to set. + value (Any): The value to assign to the field. + + Raises: + KeyError: If the specified key does not exist in the fields. + """ + if key in self.model_fields: + setattr(self, key, value) + else: + raise KeyError(f"'{key}' is not a recognized field.") + + def __delitem__(self, key: str) -> None: + """Delete the value of a field by key name by setting it to None. + + Args: + key (str): The name of the field to delete. + + Raises: + KeyError: If the specified key does not exist in the fields. + """ + if key in self.model_fields: + setattr(self, key, None) # Optional: set to None instead of deleting + else: + raise KeyError(f"'{key}' is not a recognized field.") + + def __iter__(self) -> Iterator[str]: + """Iterate over the field names in the data record. + + Returns: + Iterator[str]: An iterator over field names. + """ + return iter(self.model_fields) + + def __len__(self) -> int: + """Return the number of fields in the data record. + + Returns: + int: The number of defined fields. + """ + return len(self.model_fields) + + def __repr__(self) -> str: + """Provide a string representation of the data record. + + Returns: + str: A string representation showing field names and their values. + """ + field_values = {field: getattr(self, field) for field in self.model_fields} + return f"{self.__class__.__name__}({field_values})" + + def __getattr__(self, key: str) -> Any: + """Dynamic attribute access for fields. + + Args: + key (str): The name of the field to access. + + Returns: + Any: The value of the requested field. + + Raises: + AttributeError: If the field does not exist. 
+ """ + if key in self.model_fields: + return getattr(self, key) + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{key}'") + + def __setattr__(self, key: str, value: Any) -> None: + """Set attribute values directly if they are recognized fields. + + Args: + key (str): The name of the attribute/field to set. + value (Any): The value to assign to the attribute/field. + + Raises: + AttributeError: If the attribute/field does not exist. + """ + if key in self.model_fields: + super().__setattr__(key, value) + else: + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{key}'") + + def __delattr__(self, key: str) -> None: + """Delete an attribute by setting it to None if it exists as a field. + + Args: + key (str): The name of the attribute/field to delete. + + Raises: + AttributeError: If the attribute/field does not exist. + """ + if key in self.model_fields: + setattr(self, key, None) # Optional: set to None instead of deleting + else: + super().__delattr__(key) + + @classmethod + def key_from_description(cls, description: str, threshold: float = 0.8) -> Optional[str]: + """Returns the attribute key that best matches the provided description. + + Fuzzy matching is used. + + Args: + description (str): The description text to search for. + threshold (float): The minimum ratio for a match (0-1). Default is 0.8. + + Returns: + Optional[str]: The attribute key if a match is found above the threshold, else None. + """ + if description is None: + return None + + # Get all descriptions from the fields + descriptions = { + field_name: field_info.description + for field_name, field_info in cls.model_fields.items() + } + + # Use difflib to get close matches + matches = difflib.get_close_matches( + description, descriptions.values(), n=1, cutoff=threshold + ) + + # Check if there is a match + if matches: + best_match = matches[0] + # Return the key that corresponds to the best match + for key, desc in descriptions.items(): + if desc == best_match: + return key + return None + + @classmethod + def keys_from_descriptions( + cls, descriptions: List[str], threshold: float = 0.8 + ) -> List[Optional[str]]: + """Returns a list of attribute keys that best matches the provided list of descriptions. + + Fuzzy matching is used. + + Args: + descriptions (List[str]): A list of description texts to search for. + threshold (float): The minimum ratio for a match (0-1). Default is 0.8. + + Returns: + List[Optional[str]]: A list of attribute keys matching the descriptions, with None for unmatched descriptions. + """ + keys = [] + for description in descriptions: + key = cls.key_from_description(description, threshold) + keys.append(key) + return keys + + +class DataSequence(DataBase, MutableSequence): + """A managed sequence of DataRecord instances with list-like behavior. + + The DataSequence class provides an ordered, mutable collection of DataRecord + instances, allowing list-style access for adding, deleting, and retrieving records. It also + supports advanced data operations such as JSON serialization, conversion to Pandas Series, + and sorting by timestamp. + + Attributes: + records (List[DataRecord]): A list of DataRecord instances representing + individual generic data points. + record_keys (Optional[List[str]]): A list of field names (keys) expected in each + DataRecord. + + Note: + Derived classes have to provide their own records field with correct record type set. 
+
+    Usage:
+        # Example of creating, adding, and using DataSequence
+        class DerivedSequence(DataSequence):
+            records: List[DerivedDataRecord] = Field(default_factory=list,
+                                                     description="List of data records")
+
+        seq = DerivedSequence()
+        seq.insert(0, DerivedDataRecord(date_time=datetime.now(), temperature=72))
+        seq.insert(1, DerivedDataRecord(date_time=datetime.now(), temperature=75))
+
+        # Convert to JSON and back
+        json_data = seq.to_json()
+        new_seq = DerivedSequence.from_json(json_data)
+
+        # Convert to Pandas Series
+        series = seq.key_to_series('temperature')
+    """
+
+    # To be overloaded by derived classes.
+    records: List[DataRecord] = Field(default_factory=list, description="List of data records")
+
+    # Derived fields (computed)
+    @computed_field  # type: ignore[prop-decorator]
+    @property
+    def record_keys(self) -> List[str]:
+        """Returns the keys of all fields in the data records."""
+        key_list = []
+        key_list.extend(list(self.record_class().model_fields.keys()))
+        key_list.extend(list(self.record_class().__pydantic_decorators__.computed_fields.keys()))
+        return key_list
+
+    @computed_field  # type: ignore[prop-decorator]
+    @property
+    def record_keys_writable(self) -> List[str]:
+        """Returns the keys of all fields in the data records that are writable."""
+        return list(self.record_class().model_fields.keys())
+
+    @classmethod
+    def record_class(cls) -> Type:
+        """Returns the class of the data record this data sequence handles."""
+        # Access the model field metadata
+        field_info = cls.model_fields["records"]
+        # Get the list element type from the 'type_' attribute
+        list_element_type = field_info.annotation.__args__[0]
+        if not isinstance(list_element_type(), DataRecord):
+            raise ValueError(
+                f"Data record must be an instance of DataRecord: '{list_element_type}'."
+            )
+        return list_element_type
+
+    def _validate_key(self, key: str) -> None:
+        """Verify that a specified key exists in the current record keys.
+
+        Args:
+            key (str): The key to check for in the records.
+
+        Raises:
+            KeyError: If the specified key is not in the expected list of keys for the records.
+        """
+        if key not in self.record_keys:
+            raise KeyError(f"Key '{key}' is not in record keys: {self.record_keys}")
+
+    def _validate_key_writable(self, key: str) -> None:
+        """Verify that a specified key exists and is writable in the current record keys.
+
+        Args:
+            key (str): The key to check for in the records.
+
+        Raises:
+            KeyError: If the specified key is not in the expected list of keys for the records.
+        """
+        if key not in self.record_keys_writable:
+            raise KeyError(
+                f"Key '{key}' is not in writable record keys: {self.record_keys_writable}"
+            )
+
+    def _validate_record(self, value: DataRecord) -> None:
+        """Check if the provided value is a valid DataRecord with compatible keys.
+
+        Args:
+            value (DataRecord): The record to validate.
+
+        Raises:
+            ValueError: If the value is not an instance of DataRecord or has an invalid date_time type.
+            KeyError: If the value has different keys from those expected in the sequence.
+        """
+        # Assure value is of correct type
+        if value.__class__.__name__ != self.record_class().__name__:
+            raise ValueError(f"Value must be an instance of `{self.record_class().__name__}`.")
+
+        # Assure datetime value can be converted to datetime object
+        value.date_time = to_datetime(value.date_time)
+
+    @overload
+    def __getitem__(self, index: int) -> DataRecord: ...
+
+    @overload
+    def __getitem__(self, index: slice) -> list[DataRecord]: ...
+ + def __getitem__(self, index: Union[int, slice]) -> Union[DataRecord, list[DataRecord]]: + """Retrieve a DataRecord or list of DataRecords by index or slice. + + Supports both single item and slice-based access to the sequence. + + Args: + index (int or slice): The index or slice to access. + + Returns: + DataRecord or list[DataRecord]: A single DataRecord or a list of DataRecords. + + Raises: + IndexError: If the index is invalid or out of range. + """ + if isinstance(index, int): + # Single item access logic + return self.records[index] + elif isinstance(index, slice): + # Slice access logic + return self.records[index] + raise IndexError("Invalid index") + + def __setitem__(self, index: Any, value: Any) -> None: + """Replace a data record or slice of records with new value(s). + + Supports setting a single record at an integer index or + multiple records using a slice. + + Args: + index (int or slice): The index or slice to modify. + value (DataRecord or list[DataRecord]): + Single record or list of records to set. + + Raises: + ValueError: If the number of records does not match the slice length. + IndexError: If the index is out of range. + """ + if isinstance(index, int): + if isinstance(value, list): + raise ValueError("Cannot assign list to single index") + self._validate_record(value) + self.records[index] = value + elif isinstance(index, slice): + if isinstance(value, DataRecord): + raise ValueError("Cannot assign single record to slice") + for record in value: + self._validate_record(record) + self.records[index] = value + else: + # Should never happen + raise TypeError("Invalid type for index") + + def __delitem__(self, index: Any) -> None: + """Remove a single data record or a slice of records. + + Supports deleting a single record by integer index + or multiple records using a slice. + + Args: + index (int or slice): The index or slice to delete. + + Raises: + IndexError: If the index is out of range. + """ + del self.records[index] + + def __len__(self) -> int: + """Get the number of DataRecords in the sequence. + + Returns: + int: The count of records in the sequence. + """ + return len(self.records) + + def __iter__(self) -> Iterator[DataRecord]: + """Create an iterator for accessing DataRecords sequentially. + + Returns: + Iterator[DataRecord]: An iterator for the records. + """ + return iter(self.records) + + def __repr__(self) -> str: + """Provide a string representation of the DataSequence. + + Returns: + str: A string representation of the DataSequence. + """ + return f"{self.__class__.__name__}([{', '.join(repr(record) for record in self.records)}])" + + def insert(self, index: int, value: DataRecord) -> None: + """Insert a DataRecord at a specified index in the sequence. + + This method inserts a `DataRecord` at the specified index within the sequence of records, + shifting subsequent records to the right. If `index` is 0, the record is added at the beginning + of the sequence, and if `index` is equal to the length of the sequence, the record is appended + at the end. + + Args: + index (int): The position before which to insert the new record. An index of 0 inserts + the record at the start, while an index equal to the length of the sequence + appends it to the end. + value (DataRecord): The `DataRecord` instance to insert into the sequence. + + Raises: + ValueError: If `value` is not an instance of `DataRecord`. 
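+
+        Example (illustrative; assumes `seq` is a derived sequence and `record` a matching
+        derived record, as in the class docstring):
+            ```python
+            seq.insert(0, record)         # prepend
+            seq.insert(len(seq), record)  # append
+            ```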
+ """ + self.records.insert(index, value) + + def insert_by_datetime(self, value: DataRecord) -> None: + """Insert or merge a DataRecord into the sequence based on its date. + + If a record with the same date exists, merges new data fields with the existing record. + Otherwise, appends the record and maintains chronological order. + + Args: + value (DataRecord): The record to add or merge. + """ + self._validate_record(value) + # Check if a record with the given date already exists + for record in self.records: + if not isinstance(record.date_time, DateTime): + raise ValueError( + f"Record date '{record.date_time}' is not a datetime, but a `{type(record.date_time).__name__}`." + ) + if compare_datetimes(record.date_time, value.date_time).equal: + # Merge values, only updating fields where data record has a non-None value + for field, val in value.model_dump(exclude_unset=True).items(): + if field in value.record_keys_writable(): + setattr(record, field, val) + break + else: + # Add data record if the date does not exist + self.records.append(value) + # Sort the list by datetime after adding/updating + self.sort_by_datetime() + + def update_value(self, date: DateTime, key: str, value: Any) -> None: + """Updates a specific value in the data record for a given date. + + If a record for the date exists, updates the specified attribute with the new value. + Otherwise, appends a new record with the given value and maintains chronological order. + + Args: + date (datetime): The date for which the weather value is to be added or updated. + key (str): The attribute name to be updated. + value: The new value to set for the specified attribute. + """ + self._validate_key_writable(key) + # Ensure datetime objects are normalized + date = to_datetime(date, to_maxtime=False) + # Check if a record with the given date already exists + for record in self.records: + if not isinstance(record.date_time, DateTime): + raise ValueError( + f"Record date '{record.date_time}' is not a datetime, but a `{type(record.date_time).__name__}`." + ) + if compare_datetimes(record.date_time, date).equal: + # Update the DataRecord with the new value for the specified key + setattr(record, key, value) + break + else: + # Create a new record and append to the list + record = self.record_class()(date_time=date, **{key: value}) + self.records.append(record) + # Sort the list by datetime after adding/updating + self.sort_by_datetime() + + def to_datetimeindex(self) -> pd.DatetimeIndex: + """Generate a Pandas DatetimeIndex from the date_time fields of all records in the sequence. + + Returns: + pd.DatetimeIndex: An index of datetime values corresponding to each record's date_time attribute. + + Raises: + ValueError: If any record does not have a valid date_time attribute. + """ + date_times = [record.date_time for record in self.records if record.date_time is not None] + + if not date_times: + raise ValueError("No valid date_time values found in the records.") + + return pd.DatetimeIndex(date_times) + + def key_to_dict( + self, + key: str, + start_datetime: Optional[DateTime] = None, + end_datetime: Optional[DateTime] = None, + ) -> Dict[DateTime, Any]: + """Extract a dictionary indexed by the date_time field of the DataRecords. + + The dictionary will contain values extracted from the specified key attribute of each DataRecord, + using the date_time field as the key. + + Args: + key (str): The field name in the DataRecord from which to extract values. 
+            start_datetime (datetime, optional): The start date to filter records (inclusive).
+            end_datetime (datetime, optional): The end date to filter records (exclusive).
+
+        Returns:
+            Dict[str, Any]: A dictionary with the ISO 8601 string of each record's date_time
+            as the key and the values extracted from the specified key.
+
+        Raises:
+            KeyError: If the specified key is not found in any of the DataRecords.
+        """
+        self._validate_key(key)
+        # Ensure datetime objects are normalized
+        start_datetime = to_datetime(start_datetime, to_maxtime=False) if start_datetime else None
+        end_datetime = to_datetime(end_datetime, to_maxtime=False) if end_datetime else None
+
+        # Create a dictionary to hold date_time and corresponding values
+        filtered_data = {
+            to_datetime(record.date_time, as_string=True): getattr(record, key, None)
+            for record in self.records
+            if (start_datetime is None or compare_datetimes(record.date_time, start_datetime).ge)
+            and (end_datetime is None or compare_datetimes(record.date_time, end_datetime).lt)
+        }
+
+        return filtered_data
+
+    def key_to_lists(
+        self,
+        key: str,
+        start_datetime: Optional[DateTime] = None,
+        end_datetime: Optional[DateTime] = None,
+    ) -> Tuple[List[DateTime], List[Optional[float]]]:
+        """Extracts two lists from data records within an optional date range.
+
+        The lists are:
+            Dates: List of datetime elements.
+            Values: List of values corresponding to the specified key in the data records.
+
+        Args:
+            key (str): The key of the attribute in DataRecord to extract.
+            start_datetime (datetime, optional): The start date for filtering the records (inclusive).
+            end_datetime (datetime, optional): The end date for filtering the records (exclusive).
+
+        Returns:
+            tuple: A tuple containing a list of datetime values and a list of extracted values.
+
+        Raises:
+            KeyError: If the specified key is not found in any of the DataRecords.
+        """
+        self._validate_key(key)
+        # Ensure datetime objects are normalized
+        start_datetime = to_datetime(start_datetime, to_maxtime=False) if start_datetime else None
+        end_datetime = to_datetime(end_datetime, to_maxtime=False) if end_datetime else None
+
+        # Create two lists to hold date_time and corresponding values
+        filtered_records = []
+        for record in self.records:
+            if record.date_time is None:
+                continue
+            if (
+                start_datetime is None or compare_datetimes(record.date_time, start_datetime).ge
+            ) and (end_datetime is None or compare_datetimes(record.date_time, end_datetime).lt):
+                filtered_records.append(record)
+        dates = [record.date_time for record in filtered_records]
+        values = [getattr(record, key, None) for record in filtered_records]
+
+        return dates, values
+
+    def key_to_series(
+        self,
+        key: str,
+        start_datetime: Optional[DateTime] = None,
+        end_datetime: Optional[DateTime] = None,
+    ) -> pd.Series:
+        """Extract a series indexed by the date_time field from data records within an optional date range.
+
+        Args:
+            key (str): The field name in the DataRecord from which to extract values.
+            start_datetime (datetime, optional): The start date for filtering the records (inclusive).
+            end_datetime (datetime, optional): The end date for filtering the records (exclusive).
+
+        Returns:
+            pd.Series: A Pandas Series with the index as the date_time of each record
+            and the values extracted from the specified key.
+
+        Raises:
+            KeyError: If the specified key is not found in any of the DataRecords.
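+
+        Example (illustrative; assumes a derived sequence with a hypothetical `data_value`
+        field and that `to_datetime` accepts the string format shown):
+            ```python
+            series = seq.key_to_series("data_value", start_datetime="2024-01-01 00:00:00")
+            print(series.mean())
+            ```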
+ """ + dates, values = self.key_to_lists(key, start_datetime, end_datetime) + return pd.Series(data=values, index=pd.DatetimeIndex(dates), name=key) + + def key_from_series(self, key: str, series: pd.Series) -> None: + """Update the DataSequence from a Pandas Series. + + The series index should represent the date_time of each DataRecord, and the series values + should represent the corresponding data values for the specified key. + + Args: + series (pd.Series): A Pandas Series containing data to update the DataSequence. + key (str): The field name in the DataRecord that corresponds to the values in the Series. + """ + self._validate_key_writable(key) + + for date_time, value in series.items(): + # Ensure datetime objects are normalized + date_time = to_datetime(date_time, to_maxtime=False) if date_time else None + # Check if there's an existing record for this date_time + existing_record = next((r for r in self.records if r.date_time == date_time), None) + if existing_record: + # Update existing record's specified key + setattr(existing_record, key, value) + else: + # Create a new DataRecord if none exists + new_record = self.record_class()(date_time=date_time, **{key: value}) + self.records.append(new_record) + self.sort_by_datetime() + + def sort_by_datetime(self, reverse: bool = False) -> None: + """Sort the DataRecords in the sequence by their date_time attribute. + + This method modifies the existing list of records in place, arranging them in order + based on the date_time attribute of each DataRecord. + + Args: + reverse (bool, optional): If True, sorts in descending order. + If False (default), sorts in ascending order. + + Raises: + TypeError: If any record's date_time attribute is None or not comparable. + """ + try: + # Use a default value (-inf or +inf) for None to make all records comparable + self.records.sort( + key=lambda record: record.date_time or pendulum.datetime(1, 1, 1, 0, 0, 0), + reverse=reverse, + ) + except TypeError as e: + # Provide a more informative error message + none_records = [i for i, record in enumerate(self.records) if record.date_time is None] + if none_records: + raise TypeError( + f"Cannot sort: {len(none_records)} record(s) have None date_time " + f"at indices {none_records}" + ) from e + raise + + def delete_by_datetime( + self, start_datetime: Optional[DateTime] = None, end_datetime: Optional[DateTime] = None + ) -> None: + """Delete DataRecords from the sequence within a specified datetime range. + + Removes records with `date_time` attributes that fall between `start_datetime` (inclusive) + and `end_datetime` (exclusive). If only `start_datetime` is provided, records from that date + onward will be removed. If only `end_datetime` is provided, records up to that date will be + removed. If none is given, no record will be deleted. + + Args: + start_datetime (datetime, optional): The start date to begin deleting records (inclusive). + end_datetime (datetime, optional): The end date to stop deleting records (exclusive). + + Raises: + ValueError: If both `start_datetime` and `end_datetime` are None. 
+ """ + # Ensure datetime objects are normalized + start_datetime = to_datetime(start_datetime, to_maxtime=False) if start_datetime else None + end_datetime = to_datetime(end_datetime, to_maxtime=False) if end_datetime else None + + # Retain records that are outside the specified range + retained_records = [] + for record in self.records: + if record.date_time is None: + continue + if ( + ( + start_datetime is not None + and compare_datetimes(record.date_time, start_datetime).lt + ) + or ( + end_datetime is not None + and compare_datetimes(record.date_time, end_datetime).ge + ) + or (start_datetime is None and end_datetime is None) + ): + retained_records.append(record) + self.records = retained_records + + def key_delete_by_datetime( + self, + key: str, + start_datetime: Optional[DateTime] = None, + end_datetime: Optional[DateTime] = None, + ) -> None: + """Delete an attribute specified by `key` from records in the sequence within a given datetime range. + + This method removes the attribute identified by `key` from records that have a `date_time` value falling + within the specified `start_datetime` (inclusive) and `end_datetime` (exclusive) range. + + - If only `start_datetime` is specified, attributes will be removed from records from that date onward. + - If only `end_datetime` is specified, attributes will be removed from records up to that date. + - If neither `start_datetime` nor `end_datetime` is given, the attribute will be removed from all records. + + Args: + key (str): The attribute name to delete from each record. + start_datetime (datetime, optional): The start datetime to begin attribute deletion (inclusive). + end_datetime (datetime, optional): The end datetime to stop attribute deletion (exclusive). + + Raises: + KeyError: If `key` is not a valid attribute of the records. + """ + self._validate_key_writable(key) + # Ensure datetime objects are normalized + start_datetime = to_datetime(start_datetime, to_maxtime=False) if start_datetime else None + end_datetime = to_datetime(end_datetime, to_maxtime=False) if end_datetime else None + + for record in self.records: + if ( + start_datetime is None or compare_datetimes(record.date_time, start_datetime).ge + ) and (end_datetime is None or compare_datetimes(record.date_time, end_datetime).lt): + del record[key] + + def filter_by_datetime( + self, start_datetime: Optional[DateTime] = None, end_datetime: Optional[DateTime] = None + ) -> "DataSequence": + """Returns a new DataSequence object containing only records within the specified datetime range. + + Args: + start_datetime (Optional[datetime]): The start of the datetime range (inclusive). If None, no lower limit. + end_datetime (Optional[datetime]): The end of the datetime range (exclusive). If None, no upper limit. + + Returns: + DataSequence: A new DataSequence object with filtered records. + """ + # Ensure datetime objects are normalized + start_datetime = to_datetime(start_datetime, to_maxtime=False) if start_datetime else None + end_datetime = to_datetime(end_datetime, to_maxtime=False) if end_datetime else None + + filtered_records = [ + record + for record in self.records + if (start_datetime is None or compare_datetimes(record.date_time, start_datetime).ge) + and (end_datetime is None or compare_datetimes(record.date_time, end_datetime).lt) + ] + return self.__class__(records=filtered_records) + + +class DataProvider(SingletonMixin, DataSequence): + """Abstract base class for data providers with singleton thread-safety and configurable data parameters. 
+ + This class serves as a base for managing generic data, providing an interface for derived + classes to maintain a single instance across threads. It offers attributes for managing + data and historical data retention. + + Note: + Derived classes have to provide their own records field with correct record type set. + """ + + update_datetime: Optional[AwareDatetime] = Field( + None, description="Latest update datetime for generic data" + ) + + @abstractmethod + def provider_id(self) -> str: + """Return the unique identifier for the data provider. + + To be implemented by derived classes. + """ + return "DataProvider" + + @abstractmethod + def enabled(self) -> bool: + """Return True if the provider is enabled according to configuration. + + To be implemented by derived classes. + """ + return self.provider_id() == self.config.abstract_provider + + @abstractmethod + def _update_data(self, force_update: Optional[bool] = False) -> None: + """Abstract method for custom data update logic, to be implemented by derived classes. + + Args: + force_update (bool, optional): If True, forces the provider to update the data even if still cached. + """ + pass + + def update_data( + self, + force_enable: Optional[bool] = False, + force_update: Optional[bool] = False, + ) -> None: + """Calls the custom update function if enabled or forced. + + Args: + force_enable (bool, optional): If True, forces the update even if the provider is disabled. + force_update (bool, optional): If True, forces the provider to update the data even if still cached. + """ + # Check after configuration is updated. + if not force_enable and not self.enabled(): + return + + # Call the custom update logic + self._update_data(force_update=force_update) + + # Assure records are sorted. + self.sort_by_datetime() + + +class DataImportProvider(DataProvider): + """Abstract base class for data providers that import generic data. + + This class is designed to handle generic data provided in the form of a key-value dictionary. + - **Keys**: Represent identifiers from the record keys of a specific data. + - **Values**: Are lists of data values starting at a specified `start_datetime`, where + each value corresponds to a subsequent time interval (e.g., hourly). + + Subclasses must implement the logic for managing generic data based on the imported records. + """ + + def import_datetimes(self, value_count: int) -> List[Tuple[DateTime, int]]: + """Generates a list of tuples containing timestamps and their corresponding value indices. + + The function accounts for daylight saving time (DST) transitions: + - During a spring forward transition (e.g., DST begins), skipped hours are omitted. + - During a fall back transition (e.g., DST ends), repeated hours are included, + but they share the same value index. + + Args: + value_count (int): The number of timestamps to generate. + + Returns: + List[Tuple[DateTime, int]]: + A list of tuples, where each tuple contains: + - A `DateTime` object representing an hourly step from `start_datetime`. + - An integer value index corresponding to the logical hour. + + Behavior: + - Skips invalid timestamps during DST spring forward transitions. + - Includes both instances of repeated timestamps during DST fall back transitions. + - Ensures the list contains exactly `value_count` entries. 
+ + Example: + >>> start_datetime = pendulum.datetime(2024, 11, 3, 0, 0, tz="America/New_York") + >>> import_datetimes(5) + [(DateTime(2024, 11, 3, 0, 0, tzinfo=Timezone('America/New_York')), 0), + (DateTime(2024, 11, 3, 1, 0, tzinfo=Timezone('America/New_York')), 1), + (DateTime(2024, 11, 3, 1, 0, tzinfo=Timezone('America/New_York')), 1), # Repeated hour + (DateTime(2024, 11, 3, 2, 0, tzinfo=Timezone('America/New_York')), 2), + (DateTime(2024, 11, 3, 3, 0, tzinfo=Timezone('America/New_York')), 3)] + """ + timestamps_with_indices: List[Tuple[DateTime, int]] = [] + value_datetime = self.start_datetime + value_index = 0 + + while value_index < value_count: + i = len(timestamps_with_indices) + logger.debug(f"{i}: Insert at {value_datetime} with index {value_index}") + timestamps_with_indices.append((value_datetime, value_index)) + + # Check if there is a DST transition + next_time = value_datetime.add(hours=1) + if next_time <= value_datetime: + # Check if there is a DST transition (i.e., ambiguous time during fall back) + # Repeat the hour value (reuse value index) + value_datetime = next_time + logger.debug(f"{i+1}: Repeat at {value_datetime} with index {value_index}") + timestamps_with_indices.append((value_datetime, value_index)) + elif next_time.hour != value_datetime.hour + 1 and value_datetime.hour != 23: + # Skip the hour value (spring forward in value index) + value_index += 1 + logger.debug(f"{i+1}: Skip at {next_time} with index {value_index}") + + # Increment value index and value_datetime for new hour + value_index += 1 + value_datetime = value_datetime.add(hours=1) + + return timestamps_with_indices + + def import_from_json(self, json_str: str, key_prefix: str = "") -> None: + """Updates generic data by importing it from a JSON string. + + This method reads generic data from a JSON string, matches keys based on the + record keys and the provided `key_prefix`, and updates the data values sequentially, + starting from the `start_datetime`. Each data value is associated with an hourly + interval. + + Args: + json_str (str): The JSON string containing the generic data. + key_prefix (str, optional): A prefix to filter relevant keys from the generic data. + Only keys starting with this prefix will be considered. Defaults to an empty string. + + Raises: + JSONDecodeError: If the file content is not valid JSON. + + Example: + Given a JSON string with the following content: + ```json + { + "load0_mean": [20.5, 21.0, 22.1], + "load1_mean": [50, 55, 60] + } + ``` + and `key_prefix = "load1"`, only the "load1_mean" key will be processed even though + both keys are in the record. + """ + import_data = json.loads(json_str) + for key in self.record_keys_writable: + if key.startswith(key_prefix) and key in import_data: + value_list = import_data[key] + value_datetime_mapping = self.import_datetimes(len(value_list)) + for value_datetime, value_index in value_datetime_mapping: + self.update_value(value_datetime, key, value_list[value_index]) + + def import_from_file(self, import_file_path: Path, key_prefix: str = "") -> None: + """Updates generic data by importing it from a file. + + This method reads generic data from a JSON file, matches keys based on the + record keys and the provided `key_prefix`, and updates the data values sequentially, + starting from the `start_datetime`. Each data value is associated with an hourly + interval. + + Args: + import_file_path (Path): The path to the JSON file containing the generic data. 
+ key_prefix (str, optional): A prefix to filter relevant keys from the generic data. + Only keys starting with this prefix will be considered. Defaults to an empty string. + + Raises: + FileNotFoundError: If the specified file does not exist. + JSONDecodeError: If the file content is not valid JSON. + + Example: + Given a JSON file with the following content: + ```json + { + "load0_mean": [20.5, 21.0, 22.1], + "load1_mean": [50, 55, 60] + } + ``` + and `key_prefix = "load1"`, only the "load1_mean" key will be processed even though + both keys are in the record. + """ + with import_file_path.open("r") as import_file: + import_str = import_file.read() + self.import_from_json(import_str, key_prefix) + + +class DataContainer(SingletonMixin, DataBase, MutableMapping): + """A container for managing multiple DataProvider instances. + + This class enables access to data from multiple data providers, supporting retrieval and + aggregation of their data as Pandas Series objects. It acts as a dictionary-like structure + where each key represents a specific data field, and the value is a Pandas Series containing + combined data from all DataProvider instances for that key. + + Note: + Derived classes have to provide their own providers field with correct provider type set. + """ + + # To be overloaded by derived classes. + providers: List[DataProvider] = Field( + default_factory=list, description="List of data providers" + ) + + @field_validator("providers", mode="after") + def check_providers(cls, value: List[DataProvider]) -> List[DataProvider]: + # Check each item in the list + for item in value: + if not isinstance(item, DataProvider): + raise TypeError( + f"Each item in the providers list must be a DataProvider, got {type(item).__name__}" + ) + return value + + def __getitem__(self, key: str) -> pd.Series: + """Retrieve a Pandas Series for a specified key from the data in each DataProvider. + + Iterates through providers to find and return the first available Series for the specified key. + + Args: + key (str): The field name to retrieve, representing a data attribute in DataRecords. + + Returns: + pd.Series: A Pandas Series containing aggregated data for the specified key. + + Raises: + KeyError: If no provider contains data for the specified key. + """ + series = None + for provider in self.providers: + try: + series = provider.key_to_series(key) + break + except KeyError: + continue + + if series is None: + raise KeyError(f"No data found for key '{key}'.") + + return series + + def __setitem__(self, key: str, value: pd.Series) -> None: + """Add or merge a Pandas Series for a specified key into the records of an appropriate provider. + + Attempts to update or insert the provided Series data in each provider. If no provider supports + the specified key, an error is raised. + + Args: + key (str): The field name to update, representing a data attribute in DataRecords. + value (pd.Series): A Pandas Series containing data for the specified key. + + Raises: + ValueError: If `value` is not an instance of `pd.Series`. + KeyError: If no provider supports the specified key. + """ + if not isinstance(value, pd.Series): + raise ValueError("Value must be an instance of pd.Series.") + + for provider in self.providers: + try: + provider.key_from_series(key, value) + break + except KeyError: + continue + else: + raise KeyError(f"Key '{key}' not found in any provider.") + + def __delitem__(self, key: str) -> None: + """Set the value of the specified key in the data records of each provider to None. 
+
+        Args:
+            key (str): The field name in DataRecords to clear.
+
+        Raises:
+            KeyError: If the key is not found in any provider.
+        """
+        for provider in self.providers:
+            try:
+                provider.key_delete_by_datetime(key)
+                break
+            except KeyError:
+                continue
+        else:
+            raise KeyError(f"Key '{key}' not found in any provider.")
+
+    def __iter__(self) -> Iterator[str]:
+        """Return an iterator over all unique keys available across providers.
+
+        Returns:
+            Iterator[str]: An iterator over the unique keys from all providers.
+        """
+        return iter(set(chain.from_iterable(provider.record_keys for provider in self.providers)))
+
+    def __len__(self) -> int:
+        """Return the number of keys in the container.
+
+        Returns:
+            int: The total number of keys in this container.
+        """
+        return len(list(chain.from_iterable(provider.record_keys for provider in self.providers)))
+
+    def __repr__(self) -> str:
+        """Provide a string representation of the DataContainer instance.
+
+        Returns:
+            str: A string representing the container and its contained providers.
+        """
+        return f"{self.__class__.__name__}({self.providers})"
+
+    def update_data(
+        self,
+        force_enable: Optional[bool] = False,
+        force_update: Optional[bool] = False,
+    ) -> None:
+        """Update data.
+
+        Args:
+            force_enable (bool, optional): If True, forces the update even if a provider is disabled.
+            force_update (bool, optional): If True, forces the providers to update the data even if still cached.
+        """
+        for provider in self.providers:
+            provider.update_data(force_enable=force_enable, force_update=force_update)
+
+    def provider_by_id(self, provider_id: str) -> DataProvider:
+        """Retrieves a data provider by its unique identifier.
+
+        This method searches through the list of available providers and
+        returns the first provider whose `provider_id` matches the given
+        `provider_id`. If no matching provider is found, a `ValueError` is raised.
+
+        Args:
+            provider_id (str): The unique identifier of the desired data provider.
+
+        Returns:
+            DataProvider: The data provider matching the given `provider_id`.
+
+        Raises:
+            ValueError: If the provider id is unknown.
+
+        Example:
+            provider = data.provider_by_id("WeatherImport")
+        """
+        providers = {provider.provider_id(): provider for provider in self.providers}
+        if provider_id not in providers:
+            error_msg = f"Unknown provider id: '{provider_id}' of '{list(providers.keys())}'."
+            logger.error(error_msg)
+            raise ValueError(error_msg)
+        return providers[provider_id]
diff --git a/src/akkudoktoreos/core/pydantic.py b/src/akkudoktoreos/core/pydantic.py
new file mode 100644
index 0000000..2dbc9aa
--- /dev/null
+++ b/src/akkudoktoreos/core/pydantic.py
@@ -0,0 +1,226 @@
+"""Module for managing and serializing Pydantic-based models with custom support.
+
+This module introduces the `PydanticBaseModel` class, which extends Pydantic’s `BaseModel` to facilitate
+custom serialization and deserialization for `pendulum.DateTime` objects. The main features include
+automatic handling of `pendulum.DateTime` fields, custom serialization to ISO 8601 format, and utility
+methods for converting model instances to and from dictionary and JSON formats.
+
+Key Classes:
+    - PendulumDateTime: A custom type adapter that provides serialization and deserialization
+      functionality for `pendulum.DateTime` objects, converting them to ISO 8601 strings and back.
+    - PydanticBaseModel: A base model class for handling prediction records or configuration data
+      with automatic Pendulum DateTime handling and additional methods for JSON and dictionary
+      conversion.
+ +Classes: + PendulumDateTime(TypeAdapter[pendulum.DateTime]): Type adapter for `pendulum.DateTime` fields + with ISO 8601 serialization. Includes: + - serialize: Converts `pendulum.DateTime` instances to ISO 8601 string. + - deserialize: Converts ISO 8601 strings to `pendulum.DateTime` instances. + - is_iso8601: Validates if a string matches the ISO 8601 date format. + + PydanticBaseModel(BaseModel): Extends `pydantic.BaseModel` to handle `pendulum.DateTime` fields + and adds convenience methods for dictionary and JSON serialization. Key methods: + - model_dump: Dumps the model, converting `pendulum.DateTime` fields to ISO 8601. + - model_construct: Constructs a model instance with automatic deserialization of + `pendulum.DateTime` fields from ISO 8601. + - to_dict: Serializes the model instance to a dictionary. + - from_dict: Constructs a model instance from a dictionary. + - to_json: Converts the model instance to a JSON string. + - from_json: Creates a model instance from a JSON string. + +Usage Example: + # Define custom settings in a model using PydanticBaseModel + class PredictionCommonSettings(PydanticBaseModel): + prediction_start: pendulum.DateTime = Field(...) + + # Serialize a model instance to a dictionary or JSON + config = PredictionCommonSettings(prediction_start=pendulum.now()) + config_dict = config.to_dict() + config_json = config.to_json() + + # Deserialize from dictionary or JSON + new_config = PredictionCommonSettings.from_dict(config_dict) + restored_config = PredictionCommonSettings.from_json(config_json) + +Dependencies: + - `pendulum`: Required for handling timezone-aware datetime fields. + - `pydantic`: Required for model and validation functionality. + +Notes: + - This module enables custom handling of Pendulum DateTime fields within Pydantic models, + which is particularly useful for applications requiring consistent ISO 8601 datetime formatting + and robust timezone-aware datetime support. +""" + +import json +import re +from typing import Any, Type + +import pendulum +from pydantic import BaseModel, ConfigDict, TypeAdapter + + +# Custom type adapter for Pendulum DateTime fields +class PendulumDateTime(TypeAdapter[pendulum.DateTime]): + @classmethod + def serialize(cls, value: Any) -> str: + """Convert pendulum.DateTime to ISO 8601 string.""" + if isinstance(value, pendulum.DateTime): + return value.to_iso8601_string() + raise ValueError(f"Expected pendulum.DateTime, got {type(value)}") + + @classmethod + def deserialize(cls, value: Any) -> pendulum.DateTime: + """Convert ISO 8601 string to pendulum.DateTime.""" + if isinstance(value, str) and cls.is_iso8601(value): + try: + return pendulum.parse(value) + except pendulum.parsing.exceptions.ParserError as e: + raise ValueError(f"Invalid date format: {value}") from e + elif isinstance(value, pendulum.DateTime): + return value + raise ValueError(f"Expected ISO 8601 string or pendulum.DateTime, got {type(value)}") + + @staticmethod + def is_iso8601(value: str) -> bool: + """Check if the string is a valid ISO 8601 date string.""" + iso8601_pattern = ( + r"^(\d{4}-\d{2}-\d{2}[T ]\d{2}:\d{2}:\d{2}(?:\.\d{1,3})?(?:Z|[+-]\d{2}:\d{2})?)$" + ) + return bool(re.match(iso8601_pattern, value)) + + +class PydanticBaseModel(BaseModel): + """Base model class with automatic serialization and deserialization of `pendulum.DateTime` fields. + + This model serializes pendulum.DateTime objects to ISO 8601 strings and + deserializes ISO 8601 strings to pendulum.DateTime objects. 
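+
+    Example (illustrative sketch; `Event` and `event_time` are hypothetical names):
+        ```python
+        class Event(PydanticBaseModel):
+            event_time: pendulum.DateTime
+
+        event = Event(event_time=pendulum.now())
+        data = event.model_dump()               # event_time serialized as an ISO 8601 string
+        restored = Event.model_construct(data)  # ISO strings parsed back to pendulum.DateTime
+        ```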
+ """ + + # Enable custom serialization globally in config + model_config = ConfigDict( + arbitrary_types_allowed=True, + use_enum_values=True, + validate_assignment=True, + ) + + # Override Pydantic’s serialization for all DateTime fields + def model_dump(self, *args: Any, **kwargs: Any) -> dict: + """Custom dump method to handle serialization for DateTime fields.""" + result = super().model_dump(*args, **kwargs) + for key, value in result.items(): + if isinstance(value, pendulum.DateTime): + result[key] = PendulumDateTime.serialize(value) + return result + + @classmethod + def model_construct(cls, data: dict) -> "PydanticBaseModel": + """Custom constructor to handle deserialization for DateTime fields.""" + for key, value in data.items(): + if isinstance(value, str) and PendulumDateTime.is_iso8601(value): + data[key] = PendulumDateTime.deserialize(value) + return super().model_construct(data) + + def reset(self) -> "PydanticBaseModel": + """Resets all optional fields in the model to None. + + Iterates through all model fields and sets any optional (non-required) + fields to None. The modification is done in-place on the current instance. + + Returns: + PydanticBaseModel: The current instance with all optional fields + reset to None. + + Example: + >>> settings = PydanticBaseModel(name="test", optional_field="value") + >>> settings.reset() + >>> assert settings.optional_field is None + """ + for field_name, field in self.model_fields.items(): + if field.is_required is False: # Check if field is optional + setattr(self, field_name, None) + return self + + def to_dict(self) -> dict: + """Convert this PredictionRecord instance to a dictionary representation. + + Returns: + dict: A dictionary where the keys are the field names of the PydanticBaseModel, + and the values are the corresponding field values. + """ + return self.model_dump() + + @classmethod + def from_dict(cls: Type["PydanticBaseModel"], data: dict) -> "PydanticBaseModel": + """Create a PydanticBaseModel instance from a dictionary. + + Args: + data (dict): A dictionary containing data to initialize the PydanticBaseModel. + Keys should match the field names defined in the model. + + Returns: + PydanticBaseModel: An instance of the PydanticBaseModel populated with the data. + + Notes: + Works with derived classes by ensuring the `cls` argument is used to instantiate the object. + """ + return cls.model_validate(data) + + @classmethod + def from_dict_with_reset(cls, data: dict | None = None) -> "PydanticBaseModel": + """Creates a new instance with reset optional fields, then updates from dict. + + First creates an instance with default values, resets all optional fields + to None, then updates the instance with the provided dictionary data if any. + + Args: + data (dict | None): Dictionary containing field values to initialize + the instance with. Defaults to None. + + Returns: + PydanticBaseModel: A new instance with all optional fields initially + reset to None and then updated with provided data. 
+
+        Example:
+            >>> data = {"name": "test", "optional_field": "value"}
+            >>> settings = PydanticBaseModel.from_dict_with_reset(data)
+            >>> # All non-specified optional fields will be None
+        """
+        # Create instance with model defaults
+        instance = cls()
+
+        # Reset all optional fields to None
+        instance.reset()
+
+        # Update with provided data if any
+        if data:
+            # Use model_validate to ensure proper type conversion and validation
+            updated_instance = instance.model_validate({**instance.model_dump(), **data})
+            return updated_instance
+
+        return instance
+
+    def to_json(self) -> str:
+        """Convert the PydanticBaseModel instance to a JSON string.
+
+        Returns:
+            str: The JSON representation of the instance.
+        """
+        return self.model_dump_json()
+
+    @classmethod
+    def from_json(cls: Type["PydanticBaseModel"], json_str: str) -> "PydanticBaseModel":
+        """Create an instance of the PydanticBaseModel class or its subclass from a JSON string.
+
+        Args:
+            json_str (str): JSON string to parse and convert into a PydanticBaseModel instance.
+
+        Returns:
+            PydanticBaseModel: A new instance of the class, populated with data from the JSON string.
+
+        Notes:
+            Works with derived classes by ensuring the `cls` argument is used to instantiate the object.
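+
+        Example:
+            A minimal sketch, assuming a subclass `MyModel(PydanticBaseModel)`
+            with a single field `value: int` (names are illustrative):
+
+            >>> obj = MyModel.from_json('{"value": 1}')
+            >>> isinstance(obj, MyModel)  # `cls` is used, so the subclass is returned
+            True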
+        """
+        data = json.loads(json_str)
+        return cls.model_validate(data)
diff --git a/tests/test_dataabc.py b/tests/test_dataabc.py
new file mode 100644
index 0000000..e1fa656
--- /dev/null
+++ b/tests/test_dataabc.py
@@ -0,0 +1,635 @@
+from datetime import datetime, timezone
+from typing import Any, ClassVar, List, Optional, Union
+
+import pandas as pd
+import pendulum
+import pytest
+from pydantic import Field, ValidationError
+
+from akkudoktoreos.config.configabc import SettingsBaseModel
+from akkudoktoreos.core.dataabc import (
+    DataBase,
+    DataContainer,
+    DataImportProvider,
+    DataProvider,
+    DataRecord,
+    DataSequence,
+)
+from akkudoktoreos.core.ems import get_ems
+from akkudoktoreos.utils.datetimeutil import compare_datetimes, to_datetime, to_duration
+
+# Derived classes for testing
+# ---------------------------
+
+
+class DerivedConfig(SettingsBaseModel):
+    env_var: Optional[int] = Field(default=None, description="Test config by environment var")
+    instance_field: Optional[str] = Field(default=None, description="Test config by instance field")
+    class_constant: Optional[int] = Field(default=None, description="Test config by class constant")
+
+
+class DerivedBase(DataBase):
+    instance_field: Optional[str] = Field(default=None, description="Field Value")
+    class_constant: ClassVar[int] = 30
+
+
+class DerivedRecord(DataRecord):
+    data_value: Optional[float] = Field(default=None, description="Data Value")
+
+
+class DerivedSequence(DataSequence):
+    # overload
+    records: List[DerivedRecord] = Field(
+        default_factory=list, description="List of DerivedRecord records"
+    )
+
+    @classmethod
+    def record_class(cls) -> Any:
+        return DerivedRecord
+
+
+class DerivedDataProvider(DataProvider):
+    """A concrete subclass of DataProvider for testing purposes."""
+
+    # overload
+    records: List[DerivedRecord] = Field(
+        default_factory=list, description="List of DerivedRecord records"
+    )
+    provider_enabled: ClassVar[bool] = False
+    provider_updated: ClassVar[bool] = False
+
+    @classmethod
+    def record_class(cls) -> Any:
+        return DerivedRecord
+
+    # Implement abstract methods for test purposes
+    def provider_id(self) -> str:
+        return "DerivedDataProvider"
+
+    def enabled(self) -> bool:
+        return self.provider_enabled
+
+    def _update_data(self, force_update: Optional[bool] = False) -> None:
+        # Simulate update logic
+        DerivedDataProvider.provider_updated = True
+
+
+class DerivedDataImportProvider(DataImportProvider):
+    """A concrete subclass of DataImportProvider for testing purposes."""
+
+    # overload
+    records: List[DerivedRecord] = Field(
+        default_factory=list, description="List of DerivedRecord records"
+    )
+    provider_enabled: ClassVar[bool] = False
+    provider_updated: ClassVar[bool] = False
+
+    @classmethod
+    def record_class(cls) -> Any:
+        return DerivedRecord
+
+    # Implement abstract methods for test purposes
+    def provider_id(self) -> str:
+        return "DerivedDataImportProvider"
+
+    def enabled(self) -> bool:
+        return self.provider_enabled
+
+    def _update_data(self, force_update: Optional[bool] = False) -> None:
+        # Simulate update logic
+        DerivedDataImportProvider.provider_updated = True
+
+
+class DerivedDataContainer(DataContainer):
+    providers: List[Union[DerivedDataProvider, DataProvider]] = Field(
+        default_factory=list, description="List of data providers"
+    )
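+
+
+# The derived helpers above are exercised by the tests below. As a minimal sketch
+# (assuming only the DataSequence/DataRecord API that these tests use):
+#
+#     seq = DerivedSequence()
+#     seq.append(DerivedRecord(date_time=datetime(2024, 1, 1), data_value=1.0))
+#     series = seq.key_to_series("data_value")  # pandas Series indexed by datetime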
+
+
+# Tests
+# ----------
+
+
+class TestDataBase:
+    @pytest.fixture
+    def base(self, reset_config, monkeypatch):
+        # Provide default values for configuration
+        derived = DerivedBase()
+        derived.config.update()
+        return derived
+
+    def test_get_config_value_key_error(self, base):
+        with pytest.raises(AttributeError):
+            base.config.non_existent_key
+
+
+class TestDataRecord:
+    def create_test_record(self, date, value):
+        """Helper function to create a test DataRecord."""
+        return DerivedRecord(date_time=date, data_value=value)
+
+    def test_getitem(self):
+        record = self.create_test_record(datetime(2024, 1, 3, tzinfo=timezone.utc), 10.0)
+        assert record["data_value"] == 10.0
+
+    def test_setitem(self):
+        record = self.create_test_record(datetime(2024, 1, 3, tzinfo=timezone.utc), 10.0)
+        record["data_value"] = 20.0
+        assert record.data_value == 20.0
+
+    def test_delitem(self):
+        record = self.create_test_record(datetime(2024, 1, 3, tzinfo=timezone.utc), 10.0)
+        record.data_value = 20.0
+        del record["data_value"]
+        assert record.data_value is None
+
+    def test_len(self):
+        record = self.create_test_record(datetime(2024, 1, 3, tzinfo=timezone.utc), 10.0)
+        record.date_time = None
+        record.data_value = 20.0
+        assert len(record) == 2
+
+    def test_to_dict(self):
+        record = self.create_test_record(datetime(2024, 1, 3, tzinfo=timezone.utc), 10.0)
+        record.data_value = 20.0
+        record_dict = record.to_dict()
+        assert "data_value" in record_dict
+        assert record_dict["data_value"] == 20.0
+        record2 = DerivedRecord.from_dict(record_dict)
+        assert record2 == record
+
+    def test_to_json(self):
+        record = self.create_test_record(datetime(2024, 1, 3, tzinfo=timezone.utc), 10.0)
+        record.data_value = 20.0
+        json_str = record.to_json()
+        assert "data_value" in json_str
+        assert "20.0" in json_str
+        record2 = DerivedRecord.from_json(json_str)
+        assert record2 == record
+
+
+class TestDataSequence:
+    @pytest.fixture
+    def sequence(self):
+        sequence0 = DerivedSequence()
+        assert len(sequence0) == 0
+        return sequence0
+
+    @pytest.fixture
+    def sequence2(self):
+        sequence = DerivedSequence()
+        record1 = self.create_test_record(datetime(1970, 1, 1), 1970)
+        record2 = self.create_test_record(datetime(1971, 1, 1), 1971)
+        sequence.append(record1)
+        sequence.append(record2)
+        assert len(sequence) == 2
+        return sequence
+
+    def create_test_record(self, date, value):
+        """Helper function to create a test DataRecord."""
+        return DerivedRecord(date_time=date, data_value=value)
+
+    # Test cases
+    def test_getitem(self, sequence):
+        assert len(sequence) == 0
+        record = self.create_test_record("2024-01-01 00:00:00", 0)
+        sequence.insert_by_datetime(record)
+        assert isinstance(sequence[0], DerivedRecord)
+
+    def test_setitem(self, sequence2):
+        new_record = self.create_test_record(datetime(2024, 1, 3, tzinfo=timezone.utc), 1)
+        sequence2[0] = new_record
+        assert sequence2[0].date_time == datetime(2024, 1, 3, tzinfo=timezone.utc)
+
+    def test_set_record_at_index(self, sequence2):
+        record1 = self.create_test_record(datetime(2024, 1, 3, tzinfo=timezone.utc), 1)
+        record2 = self.create_test_record(datetime(2023, 11, 5), 0.8)
+        sequence2[1] = record1
+        assert sequence2[1].date_time == datetime(2024, 1, 3, tzinfo=timezone.utc)
+        sequence2[0] = record2
+        assert len(sequence2) == 2
+        assert sequence2[0] == record2
+
+    def test_insert_duplicate_date_record(self, sequence):
+        record1 = self.create_test_record(datetime(2023, 11, 5), 0.8)
+        record2 = self.create_test_record(datetime(2023, 11, 5), 0.9)  # Duplicate date
+        sequence.insert_by_datetime(record1)
+        sequence.insert_by_datetime(record2)
+        assert len(sequence) == 1
+        assert sequence[0].data_value == 0.9  # Record should have merged with new value
+
+    def test_sort_by_datetime_ascending(self, sequence):
+        """Test sorting records in ascending order by date_time."""
+        records = [
+            self.create_test_record(pendulum.datetime(2024, 11, 1), 0.7),
+            self.create_test_record(pendulum.datetime(2024, 10, 1), 0.8),
+            self.create_test_record(pendulum.datetime(2024, 12, 1), 0.9),
+        ]
+        for i, record in enumerate(records):
+            sequence.insert(i, record)
+        sequence.sort_by_datetime()
+        sorted_dates = [record.date_time for record in sequence.records]
+        for i, expected_date in enumerate(
+            [
+                pendulum.datetime(2024, 10, 1),
+                pendulum.datetime(2024, 11, 1),
+                pendulum.datetime(2024, 12, 1),
+            ]
+        ):
+            assert compare_datetimes(sorted_dates[i], expected_date).equal
+
+    def test_sort_by_datetime_descending(self, sequence):
+        """Test sorting records in descending order by date_time."""
+        records = [
+            self.create_test_record(pendulum.datetime(2024, 11, 1), 0.7),
+            self.create_test_record(pendulum.datetime(2024, 10, 1), 0.8),
+            self.create_test_record(pendulum.datetime(2024, 12, 1), 0.9),
+        ]
+        for i, record in enumerate(records):
+            sequence.insert(i, record)
+        sequence.sort_by_datetime(reverse=True)
+        sorted_dates = [record.date_time for record in sequence.records]
+        for i, expected_date in enumerate(
+            [
+                pendulum.datetime(2024, 12, 1),
+                pendulum.datetime(2024, 11, 1),
+                pendulum.datetime(2024, 10, 1),
+            ]
+        ):
+            assert compare_datetimes(sorted_dates[i], expected_date).equal
+
+    def test_sort_by_datetime_with_none(self, sequence):
+        """Test sorting records when some date_time values are None."""
+        records = [
+            self.create_test_record(pendulum.datetime(2024, 11, 1), 0.7),
+            self.create_test_record(pendulum.datetime(2024, 10, 1), 0.8),
+            self.create_test_record(pendulum.datetime(2024, 12, 1), 0.9),
+        ]
+        for i, record in enumerate(records):
+            sequence.insert(i, record)
+        sequence.records[2].date_time = None
+        assert sequence.records[2].date_time is None
+        sequence.sort_by_datetime()
+        sorted_dates = [record.date_time for record in sequence.records]
+        for i, expected_date in enumerate(
+            [
+                None,  # None values should come first
+                pendulum.datetime(2024, 10, 1),
+                pendulum.datetime(2024, 11, 1),
+            ]
+        ):
+            if expected_date is None:
+                assert sorted_dates[i] is None
+            else:
+                assert compare_datetimes(sorted_dates[i], expected_date).equal
+
+    def test_sort_by_datetime_error_on_uncomparable(self, sequence):
+        """Test error is raised when date_time contains uncomparable values."""
+        records = [
+            self.create_test_record(pendulum.datetime(2024, 11, 1), 0.7),
+            self.create_test_record(pendulum.datetime(2024, 12, 1), 0.9),
+            self.create_test_record(pendulum.datetime(2024, 10, 1), 0.8),
+        ]
+        for i, record in enumerate(records):
+            sequence.insert(i, record)
+        with pytest.raises(
+            ValidationError, match="Date string not_a_datetime does not match any known formats."
+        ):
+            sequence.records[2].date_time = "not_a_datetime"  # Invalid date_time
+            sequence.sort_by_datetime()
+
+    def test_key_to_series(self, sequence):
+        record = self.create_test_record(datetime(2023, 11, 6), 0.8)
+        sequence.append(record)
+        series = sequence.key_to_series("data_value")
+        assert isinstance(series, pd.Series)
+        assert series[to_datetime(datetime(2023, 11, 6))] == 0.8
+
+    def test_key_from_series(self, sequence):
+        series = pd.Series(
+            data=[0.8, 0.9], index=pd.to_datetime([datetime(2023, 11, 5), datetime(2023, 11, 6)])
+        )
+        sequence.key_from_series("data_value", series)
+        assert len(sequence) == 2
+        assert sequence[0].data_value == 0.8
+        assert sequence[1].data_value == 0.9
+
+    def test_to_datetimeindex(self, sequence2):
+        record1 = self.create_test_record(datetime(2023, 11, 5), 0.8)
+        record2 = self.create_test_record(datetime(2023, 11, 6), 0.9)
+        sequence2.insert(0, record1)
+        sequence2.insert(1, record2)
+        dt_index = sequence2.to_datetimeindex()
+        assert isinstance(dt_index, pd.DatetimeIndex)
+        assert dt_index[0] == to_datetime(datetime(2023, 11, 5))
+        assert dt_index[1] == to_datetime(datetime(2023, 11, 6))
+
+    def test_delete_by_datetime_range(self, sequence):
+        record1 = self.create_test_record(datetime(2023, 11, 5), 0.8)
+        record2 = self.create_test_record(datetime(2023, 11, 6), 0.9)
+        record3 = self.create_test_record(datetime(2023, 11, 7), 1.0)
+        sequence.append(record1)
+        sequence.append(record2)
+        sequence.append(record3)
+        assert len(sequence) == 3
+        sequence.delete_by_datetime(
+            start_datetime=datetime(2023, 11, 6), end_datetime=datetime(2023, 11, 7)
+        )
+        assert len(sequence) == 2
+        assert sequence[0].date_time == to_datetime(datetime(2023, 11, 5))
+        assert sequence[1].date_time == to_datetime(datetime(2023, 11, 7))
+
+    def test_delete_by_datetime_start(self, sequence):
+        record1 = self.create_test_record(datetime(2023, 11, 5), 0.8)
+        record2 = self.create_test_record(datetime(2023, 11, 6), 0.9)
+        sequence.append(record1)
+        sequence.append(record2)
+        assert len(sequence) == 2
+        sequence.delete_by_datetime(start_datetime=datetime(2023, 11, 6))
+        assert len(sequence) == 1
+        assert sequence[0].date_time == to_datetime(datetime(2023, 11, 5))
+
+    def test_delete_by_datetime_end(self, sequence):
+        record1 = self.create_test_record(datetime(2023, 11, 5), 0.8)
+        record2 = self.create_test_record(datetime(2023, 11, 6), 0.9)
+        sequence.append(record1)
+        sequence.append(record2)
+        assert len(sequence) == 2
+        sequence.delete_by_datetime(end_datetime=datetime(2023, 11, 6))
+        assert len(sequence) == 1
+        assert sequence[0].date_time == to_datetime(datetime(2023, 11, 6))
+
+    def test_filter_by_datetime(self, sequence):
+        record1 = self.create_test_record(datetime(2023, 11, 5), 0.8)
+        record2 = self.create_test_record(datetime(2023, 11, 6), 0.9)
+        sequence.append(record1)
+        sequence.append(record2)
+        filtered_sequence = sequence.filter_by_datetime(start_datetime=datetime(2023, 11, 6))
+        assert len(filtered_sequence) == 1
+        assert filtered_sequence[0].date_time == to_datetime(datetime(2023, 11, 6))
+
+    def test_to_dict(self, sequence):
+        record = self.create_test_record(datetime(2023, 11, 6), 0.8)
+        sequence.append(record)
+        data_dict = sequence.to_dict()
+        assert isinstance(data_dict, dict)
+        sequence_other = sequence.from_dict(data_dict)
+        assert sequence_other == sequence
+
+    def test_to_json(self, sequence):
+        record = self.create_test_record(datetime(2023, 11, 6), 0.8)
+        sequence.append(record)
+        json_str = sequence.to_json()
+        assert isinstance(json_str, str)
+        assert "2023-11-06" in json_str
+        assert ":0.8" in json_str
+
+    def test_from_json(self, sequence, sequence2):
+        json_str = sequence2.to_json()
+        sequence = sequence.from_json(json_str)
+        assert len(sequence) == len(sequence2)
+        assert sequence[0].date_time == sequence2[0].date_time
+        assert sequence[0].data_value == sequence2[0].data_value
+
+    def test_key_to_dict(self, sequence):
+        record1 = self.create_test_record(datetime(2023, 11, 5), 0.8)
+        record2 = self.create_test_record(datetime(2023, 11, 6), 0.9)
+        sequence.append(record1)
+        sequence.append(record2)
+        data_dict = sequence.key_to_dict("data_value")
+        assert isinstance(data_dict, dict)
+        assert data_dict[to_datetime(datetime(2023, 11, 5), as_string=True)] == 0.8
+        assert data_dict[to_datetime(datetime(2023, 11, 6), as_string=True)] == 0.9
+
+    def test_key_to_lists(self, sequence):
+        record1 = self.create_test_record(datetime(2023, 11, 5), 0.8)
+        record2 = self.create_test_record(datetime(2023, 11, 6), 0.9)
+        sequence.append(record1)
+        sequence.append(record2)
+        dates, values = sequence.key_to_lists("data_value")
+        assert dates == [to_datetime(datetime(2023, 11, 5)), to_datetime(datetime(2023, 11, 6))]
+        assert values == [0.8, 0.9]
+
+
+class TestDataProvider:
+    # Fixtures and helper functions
+    @pytest.fixture
+    def provider(self):
+        """Fixture to provide an instance of DerivedDataProvider for testing."""
+        DerivedDataProvider.provider_enabled = True
+        DerivedDataProvider.provider_updated = False
+        return DerivedDataProvider()
+
+    @pytest.fixture
+    def sample_start_datetime(self):
+        """Fixture for a sample start datetime."""
+        return to_datetime(datetime(2024, 11, 1, 12, 0))
+
+    def create_test_record(self, date, value):
+        """Helper function to create a test DataRecord."""
+        return DerivedRecord(date_time=date, data_value=value)
+
+    # Tests
+
+    def test_singleton_behavior(self, provider):
+        """Test that DataProvider enforces singleton behavior."""
+        instance1 = provider
+        instance2 = DerivedDataProvider()
+        assert (
+            instance1 is instance2
+        ), "Singleton pattern is not enforced; instances are not the same."
+
+    def test_update_method_with_defaults(self, provider, sample_start_datetime, monkeypatch):
+        """Test the `update` method with default parameters."""
+        ems_eos = get_ems()
+
+        ems_eos.start_datetime = sample_start_datetime
+        provider.update_data()
+
+        assert provider.start_datetime == sample_start_datetime
+
+    def test_update_method_force_enable(self, provider, monkeypatch):
+        """Test that `update` executes when `force_enable` is True, even if `enabled` is False."""
+        # Override enabled to return False for this test
+        DerivedDataProvider.provider_enabled = False
+        DerivedDataProvider.provider_updated = False
+        provider.update_data(force_enable=True)
+        assert provider.enabled() is False, "Provider should be disabled, but enabled() is True."
+        assert (
+            DerivedDataProvider.provider_updated is True
+        ), "Provider should have been executed, but was not."
+
+    def test_delete_by_datetime(self, provider, sample_start_datetime):
+        """Test `delete_by_datetime` method for removing records by datetime range."""
+        # Add records to the provider for deletion testing
+        provider.records = [
+            self.create_test_record(sample_start_datetime - to_duration("3 hours"), 1),
+            self.create_test_record(sample_start_datetime - to_duration("1 hour"), 2),
+            self.create_test_record(sample_start_datetime + to_duration("1 hour"), 3),
+        ]
+
+        provider.delete_by_datetime(
+            start_datetime=sample_start_datetime - to_duration("2 hours"),
+            end_datetime=sample_start_datetime + to_duration("2 hours"),
+        )
+        assert (
+            len(provider.records) == 1
+        ), "Only one record should remain after deletion by datetime."
+        assert provider.records[0].date_time == sample_start_datetime - to_duration(
+            "3 hours"
+        ), "Unexpected record remains."
+
+
+class TestDataImportProvider:
+    # Fixtures and helper functions
+    @pytest.fixture
+    def provider(self):
+        """Fixture to provide an instance of DerivedDataImportProvider for testing."""
+        DerivedDataImportProvider.provider_enabled = True
+        DerivedDataImportProvider.provider_updated = False
+        return DerivedDataImportProvider()
+
+    @pytest.mark.parametrize(
+        "start_datetime, value_count, expected_mapping_count",
+        [
+            ("2024-11-10 00:00:00", 24, 24),  # No DST in Germany
+            ("2024-08-10 00:00:00", 24, 24),  # DST in Germany
+            ("2024-03-31 00:00:00", 24, 23),  # DST change in Germany (23 hours/day)
+            ("2024-10-27 00:00:00", 24, 25),  # DST change in Germany (25 hours/day)
+        ],
+    )
+    def test_import_datetimes(self, provider, start_datetime, value_count, expected_mapping_count):
+        ems_eos = get_ems()
+        ems_eos.start_datetime = to_datetime(start_datetime, in_timezone="Europe/Berlin")
+
+        value_datetime_mapping = provider.import_datetimes(value_count)
+
+        assert len(value_datetime_mapping) == expected_mapping_count
+
+    @pytest.mark.parametrize(
+        "start_datetime, value_count, expected_mapping_count",
+        [
+            ("2024-11-10 00:00:00", 24, 24),  # No DST in Germany
+            ("2024-08-10 00:00:00", 24, 24),  # DST in Germany
+            ("2024-03-31 00:00:00", 24, 23),  # DST change in Germany (23 hours/day)
+            ("2024-10-27 00:00:00", 24, 25),  # DST change in Germany (25 hours/day)
+        ],
+    )
+    def test_import_datetimes_utc(
+        self, set_other_timezone, provider, start_datetime, value_count, expected_mapping_count
+    ):
+        original_tz = set_other_timezone("Etc/UTC")
+        ems_eos = get_ems()
+        ems_eos.start_datetime = to_datetime(start_datetime, in_timezone="Europe/Berlin")
+        assert ems_eos.start_datetime.timezone.name == "Europe/Berlin"
+
+        value_datetime_mapping = provider.import_datetimes(value_count)
+
+        assert len(value_datetime_mapping) == expected_mapping_count
+
+
+class TestDataContainer:
+    # Fixture and helpers
+    @pytest.fixture
+    def container(self):
+        container = DerivedDataContainer()
+        return container
+
+    @pytest.fixture
+    def container_with_providers(self):
+        record1 = self.create_test_record(datetime(2023, 11, 5), 1)
+        record2 = self.create_test_record(datetime(2023, 11, 6), 2)
+        record3 = self.create_test_record(datetime(2023, 11, 7), 3)
+        provider = DerivedDataProvider()
+        provider.clear()
+        assert len(provider) == 0
+        provider.append(record1)
+        provider.append(record2)
+        provider.append(record3)
+        assert len(provider) == 3
+        container = DerivedDataContainer()
+        container.providers.clear()
+        assert len(container.providers) == 0
+        container.providers.append(provider)
+        assert len(container.providers) == 1
+        return container
+
+    def create_test_record(self, date, value):
+        """Helper function to create a test DataRecord."""
+        return DerivedRecord(date_time=date, data_value=value)
+
+    def test_append_provider(self, container):
+        assert len(container.providers) == 0
+        container.providers.append(DerivedDataProvider())
+        assert len(container.providers) == 1
+        assert isinstance(container.providers[0], DerivedDataProvider)
+
+    @pytest.mark.skip(reason="type check not implemented")
+    def test_append_provider_invalid_type(self, container):
+        with pytest.raises(ValueError, match="must be an instance of DataProvider"):
+            container.providers.append("not_a_provider")
+
+    def test_getitem_existing_key(self, container_with_providers):
+        assert len(container_with_providers.providers) == 1
+        # check all keys are available (don't care for position)
+        for key in ["data_value", "date_time"]:
+            assert key in list(container_with_providers.keys())
+        series = container_with_providers["data_value"]
+        assert isinstance(series, pd.Series)
+        assert series.name == "data_value"
+        assert series.tolist() == [1.0, 2.0, 3.0]
+
+    def test_getitem_non_existing_key(self, container_with_providers):
+        with pytest.raises(KeyError, match="No data found for key 'non_existent_key'"):
+            container_with_providers["non_existent_key"]
+
+    def test_setitem_existing_key(self, container_with_providers):
+        new_series = container_with_providers["data_value"]
+        new_series[:] = [4, 5, 6]
+        container_with_providers["data_value"] = new_series
+        series = container_with_providers["data_value"]
+        assert series.name == "data_value"
+        assert series.tolist() == [4, 5, 6]
+
+    def test_setitem_invalid_value(self, container_with_providers):
+        with pytest.raises(ValueError, match="Value must be an instance of pd.Series"):
+            container_with_providers["test_key"] = "not_a_series"
+
+    def test_setitem_non_existing_key(self, container_with_providers):
+        new_series = pd.Series([4, 5, 6], name="non_existent_key")
+        with pytest.raises(KeyError, match="Key 'non_existent_key' not found"):
+            container_with_providers["non_existent_key"] = new_series
+
+    def test_delitem_existing_key(self, container_with_providers):
+        del container_with_providers["data_value"]
+        series = container_with_providers["data_value"]
+        assert series.name == "data_value"
+        assert series.tolist() == [None, None, None]
+
+    def test_delitem_non_existing_key(self, container_with_providers):
+        with pytest.raises(KeyError, match="Key 'non_existent_key' not found"):
+            del container_with_providers["non_existent_key"]
+
+    def test_len(self, container_with_providers):
+        assert len(container_with_providers) == 3
+
+    def test_repr(self, container_with_providers):
+        representation = repr(container_with_providers)
+        assert representation.startswith("DerivedDataContainer(")
+        assert "DerivedDataProvider" in representation
+
+    def test_to_json(self, container_with_providers):
+        json_str = container_with_providers.to_json()
+        container_other = DerivedDataContainer.from_json(json_str)
+        assert container_other == container_with_providers
+
+    def test_from_json(self, container_with_providers):
+        json_str = container_with_providers.to_json()
+        container = DerivedDataContainer.from_json(json_str)
+        assert isinstance(container, DerivedDataContainer)
+        assert len(container.providers) == 1
+        assert container.providers[0] == container_with_providers.providers[0]
+
+    def test_provider_by_id(self, container_with_providers):
+        provider = container_with_providers.provider_by_id("DerivedDataProvider")
+        assert isinstance(provider, DerivedDataProvider)