diff --git a/src/omnipy/data/dataset.py b/src/omnipy/data/dataset.py
index 7ea94505..f6ee470f 100644
--- a/src/omnipy/data/dataset.py
+++ b/src/omnipy/data/dataset.py
@@ -1,14 +1,12 @@
-from collections import UserDict
+import asyncio
+from collections import defaultdict, UserDict
 from collections.abc import Iterable, Mapping, MutableMapping
 from copy import copy
 import json
 import os
 import tarfile
-from tempfile import TemporaryDirectory
-from typing import Any, Callable, cast, Generic, Iterator
-from urllib.parse import ParseResult, urlparse
+from typing import Any, Callable, cast, Generic, Iterator, TYPE_CHECKING
 
-# from orjson import orjson
 from pydantic import Field, PrivateAttr, root_validator, ValidationError
 from pydantic.fields import ModelField, Undefined, UndefinedType
 from pydantic.generics import GenericModel
@@ -30,9 +28,16 @@
                                    prepare_selected_items_with_mapping_data,
                                    select_keys)
 from omnipy.util.decorators import call_super_if_available
-from omnipy.util.helpers import get_default_if_typevar, is_iterable, remove_forward_ref_notation
+from omnipy.util.helpers import (get_default_if_typevar,
+                                 get_event_loop_and_check_if_loop_is_running,
+                                 is_iterable,
+                                 remove_forward_ref_notation)
 from omnipy.util.web import download_file_to_memory
 
+if TYPE_CHECKING:
+    from omnipy.modules.remote.datasets import HttpUrlDataset
+    from omnipy.modules.remote.models import HttpUrlModel
+
 ModelT = TypeVar('ModelT', bound=Model)
 GeneralModelT = TypeVar('GeneralModelT', bound=Model)
 _DatasetT = TypeVar('_DatasetT')
@@ -544,45 +549,87 @@ def save(self, path: str):
         tar.extractall(path=directory)
         tar.close()
 
-    def load(self, *path_or_urls: str, by_file_suffix=False):
+    def load(self,
+             paths_or_urls: 'str | Iterable[str] | HttpUrlModel | HttpUrlDataset',
+             by_file_suffix: bool = False) -> 'asyncio.Task | Dataset | None':
+        from omnipy import HttpUrlDataset, HttpUrlModel
+
+        match paths_or_urls:
+            case HttpUrlDataset():
+                return self._load_http_urls(paths_or_urls)
+
+            case HttpUrlModel():
+                return self._load_http_urls(HttpUrlDataset({str(paths_or_urls): paths_or_urls}))
+
+            case str():
+                try:
+                    http_url_dataset = HttpUrlDataset({paths_or_urls: paths_or_urls})
+                except ValidationError:
+                    return self._load_paths([paths_or_urls], by_file_suffix)
+                return self._load_http_urls(http_url_dataset)
+            case Iterable():
+                try:
+                    path_or_url_iterable = cast(Iterable[str], paths_or_urls)
+                    http_url_dataset = HttpUrlDataset(
+                        zip(path_or_url_iterable, path_or_url_iterable))
+                except ValidationError:
+                    return self._load_paths(path_or_url_iterable, by_file_suffix)
+                return self._load_http_urls(http_url_dataset)
+            case _:
+                raise TypeError(f'"paths_or_urls" argument is of incorrect type. Type '
+                                f'{type(paths_or_urls)} is not supported.')
+
+    def _load_http_urls(self, http_url_dataset: 'HttpUrlDataset') -> 'asyncio.Task | Dataset':
+        from omnipy.modules.remote.helpers import RateLimitingClientSession
+        from omnipy.modules.remote.tasks import get_json_from_api_endpoint
+        hosts: defaultdict[str, list[int]] = defaultdict(list)
+        for i, url in enumerate(http_url_dataset.values()):
+            hosts[url.host].append(i)
+
+        async def load_all():
+            tasks = []
+            client_sessions = {}
+            for host in hosts:
+                client_sessions[host] = RateLimitingClientSession(
+                    self.config.http_config_for_host[host].requests_per_time_period,
+                    self.config.http_config_for_host[host].time_period_in_secs)
+
+            for host, indices in hosts.items():
+                task = (
+                    get_json_from_api_endpoint.refine(output_dataset_param='output_dataset').run(
+                        http_url_dataset[indices],
+                        client_session=client_sessions[host],
+                        output_dataset=self))
+                tasks.append(task)
+
+            await asyncio.gather(*tasks)
+            return self
+
+        loop, loop_is_running = get_event_loop_and_check_if_loop_is_running()
+
+        if loop and loop_is_running:
+            return loop.create_task(load_all())
+        else:
+            return asyncio.run(load_all())
+
+    def _load_paths(self, path_or_urls: Iterable[str], by_file_suffix: bool) -> None:
         for path_or_url in path_or_urls:
-            if is_model_instance(path_or_url):
-                path_or_url = path_or_url.contents
-
-            with TemporaryDirectory() as tmp_dir_path:
-                serializer_registry = self._get_serializer_registry()
-
-                parsed_url = urlparse(path_or_url)
-
-                if parsed_url.scheme in ['http', 'https']:
-                    download_path = self._download_file(path_or_url, parsed_url.path, tmp_dir_path)
-                    if download_path is None:
-                        continue
-                    tar_gz_file_path = self._ensure_tar_gz_file(download_path)
-                elif parsed_url.scheme in ['file', '']:
-                    tar_gz_file_path = self._ensure_tar_gz_file(parsed_url.path)
-                elif self._is_windows_path(parsed_url):
-                    tar_gz_file_path = self._ensure_tar_gz_file(path_or_url)
-                else:
-                    raise ValueError(f'Unsupported scheme "{parsed_url.scheme}"')
-
-                if by_file_suffix:
-                    loaded_dataset = \
-                        serializer_registry.load_from_tar_file_path_based_on_file_suffix(
-                            self, tar_gz_file_path, self)
-                else:
-                    loaded_dataset = \
-                        serializer_registry.load_from_tar_file_path_based_on_dataset_cls(
-                            self, tar_gz_file_path, self)
-                if loaded_dataset is not None:
-                    self.absorb(loaded_dataset)
-                    continue
-                else:
-                    raise RuntimeError('Unable to load serializer')
+            serializer_registry = self._get_serializer_registry()
+            tar_gz_file_path = self._ensure_tar_gz_file(path_or_url)
 
-    @staticmethod
-    def _is_windows_path(parsed_url: ParseResult) -> bool:
-        return len(parsed_url.scheme) == 1 and parsed_url.scheme.isalpha()
+            if by_file_suffix:
+                loaded_dataset = \
+                    serializer_registry.load_from_tar_file_path_based_on_file_suffix(
+                        self, tar_gz_file_path, self)
+            else:
+                loaded_dataset = \
+                    serializer_registry.load_from_tar_file_path_based_on_dataset_cls(
+                        self, tar_gz_file_path, self)
+            if loaded_dataset is not None:
+                self.absorb(loaded_dataset)
+                continue
+            else:
+                raise RuntimeError('Unable to load from serializer')
 
     @staticmethod
     def _download_file(url: str, path: str, tmp_dir_path: str) -> str | None:
@@ -638,7 +685,7 @@ def __eq__(self, other: object) -> bool:
                 and self.to_data() == other.to_data()  # last is probably unnecessary, but just in case
 
     def __repr_args__(self):
-        return [(k, v.contents) for k, v in self.data.items()]
+        return [(k, v.contents) if is_model_instance(v) else (k, v) for k, v in self.data.items()]
 
 
 class MultiModelDataset(Dataset[GeneralModelT],
                         Generic[GeneralModelT]):
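
For reviewers, a quick sketch of how the reworked `load()` entry point behaves after this change. The dataset class and URLs below are hypothetical placeholders, not part of the diff:

```python
from omnipy import JsonDataset

dataset = JsonDataset()

# Strings that fail HTTP URL validation fall back to tar.gz file loading
# via _load_paths():
dataset.load('local_data.tar.gz')

# HTTP(S) URLs, given as a string, an iterable of strings, an HttpUrlModel
# or an HttpUrlDataset, are grouped per host and fetched concurrently, one
# rate-limited client session per host:
dataset.load([
    'https://api.example.com/a.json',
    'https://api.example.com/b.json',
])

# In a synchronous context, load() blocks via asyncio.run() and returns the
# populated dataset. Inside a running event loop (e.g. Jupyter), it instead
# returns an asyncio.Task that should be awaited:
#
#     task = dataset.load('https://api.example.com/a.json')
#     dataset = await task
```

Grouping by host means the `requests_per_time_period` / `time_period_in_secs` limits from `config.http_config_for_host` are enforced per server rather than globally.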
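
The sync/async dispatch at the end of `_load_http_urls()` relies on `get_event_loop_and_check_if_loop_is_running()`, whose implementation is not part of this diff. A minimal sketch of the assumed behavior:

```python
import asyncio


def get_event_loop_and_check_if_loop_is_running() -> tuple[asyncio.AbstractEventLoop | None, bool]:
    # Assumed behavior of the helper: return the currently running loop if
    # there is one (e.g. under Jupyter), so the caller can schedule a task
    # on it; otherwise signal that it is safe to block with asyncio.run().
    try:
        loop = asyncio.get_running_loop()
        return loop, True
    except RuntimeError:
        return None, False
```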