Skip to content

Commit

Permalink
Implement async download of data from urls in Dataset.load(). Refactor file-based loading
Browse files Browse the repository at this point in the history
  • Loading branch information
sveinugu committed Nov 6, 2024
1 parent 47c2128 commit cefdda4
Showing 1 changed file with 91 additions and 44 deletions.
135 changes: 91 additions & 44 deletions src/omnipy/data/dataset.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
from collections import UserDict
import asyncio
from collections import defaultdict, UserDict
from collections.abc import Iterable, Mapping, MutableMapping
from copy import copy
import json
import os
import tarfile
from tempfile import TemporaryDirectory
from typing import Any, Callable, cast, Generic, Iterator
from urllib.parse import ParseResult, urlparse
from typing import Any, Callable, cast, Generic, Iterator, TYPE_CHECKING

# from orjson import orjson
from pydantic import Field, PrivateAttr, root_validator, ValidationError
from pydantic.fields import ModelField, Undefined, UndefinedType
from pydantic.generics import GenericModel
Expand All @@ -30,9 +28,16 @@
prepare_selected_items_with_mapping_data,
select_keys)
from omnipy.util.decorators import call_super_if_available
from omnipy.util.helpers import get_default_if_typevar, is_iterable, remove_forward_ref_notation
from omnipy.util.helpers import (get_default_if_typevar,
get_event_loop_and_check_if_loop_is_running,
is_iterable,
remove_forward_ref_notation)
from omnipy.util.web import download_file_to_memory

if TYPE_CHECKING:
from omnipy.modules.remote.datasets import HttpUrlDataset
from omnipy.modules.remote.models import HttpUrlModel

ModelT = TypeVar('ModelT', bound=Model)
GeneralModelT = TypeVar('GeneralModelT', bound=Model)
_DatasetT = TypeVar('_DatasetT')
Expand Down Expand Up @@ -544,45 +549,87 @@ def save(self, path: str):
tar.extractall(path=directory)
tar.close()

def load(self, *path_or_urls: str, by_file_suffix=False):
def load(self,
paths_or_urls: 'str | Iterable[str] | HttpUrlModel | HttpUrlDataset',
by_file_suffix: bool = False) -> list[asyncio.Task] | None:
from omnipy import HttpUrlDataset, HttpUrlModel

match paths_or_urls:
case HttpUrlDataset():
return self._load_http_urls(paths_or_urls)

case HttpUrlModel():
return self._load_http_urls(HttpUrlDataset({str(paths_or_urls): paths_or_urls}))

case str():
try:
http_url_dataset = HttpUrlDataset({paths_or_urls: paths_or_urls})
except ValidationError:
return self._load_paths([paths_or_urls], by_file_suffix)
return self._load_http_urls(http_url_dataset)
case Iterable():
try:
path_or_url_iterable = cast(Iterable[str], paths_or_urls)
http_url_dataset = HttpUrlDataset(
zip(path_or_url_iterable, path_or_url_iterable))
except ValidationError:
return self._load_paths(path_or_url_iterable, by_file_suffix)
return self._load_http_urls(http_url_dataset)
case _:
raise TypeError(f'"paths_or_urls" argument is of incorrect type. Type '
f'{type(paths_or_urls)} is not supported.')

def _load_http_urls(self, http_url_dataset: 'HttpUrlDataset') -> list[asyncio.Task]:
    """Download all URLs in ``http_url_dataset`` asynchronously into this dataset.

    URLs are grouped by host so that each host shares a single rate-limited
    client session, configured from ``self.config.http_config_for_host``.
    One download task is created per host; results are written into ``self``
    via ``output_dataset=self``.

    If called while an event loop is already running (e.g. inside Jupyter),
    the work is scheduled and the pending ``asyncio.Task`` is returned for the
    caller to await; otherwise the work runs to completion via ``asyncio.run``
    and ``self`` is returned.

    NOTE(review): the declared return type ``list[asyncio.Task]`` matches
    neither branch (a single ``Task``, or ``self``) — confirm the intended
    contract, also against ``load()``'s ``list[asyncio.Task] | None``.
    NOTE(review): the per-host client sessions are never explicitly closed
    here — presumably ``RateLimitingClientSession`` or the task framework
    handles cleanup; verify, or close them after ``asyncio.gather``.
    """
    from omnipy.modules.remote.helpers import RateLimitingClientSession
    from omnipy.modules.remote.tasks import get_json_from_api_endpoint

    # Map host -> indices of that host's URLs within http_url_dataset.
    hosts: defaultdict[str, list[int]] = defaultdict(list)
    for i, url in enumerate(http_url_dataset.values()):
        hosts[url.host].append(i)

    async def load_all():
        tasks = []
        client_sessions = {}
        for host in hosts:
            # Per-host rate limiting, driven by per-host config.
            client_sessions[host] = RateLimitingClientSession(
                self.config.http_config_for_host[host].requests_per_time_period,
                self.config.http_config_for_host[host].time_period_in_secs)

        for host, indices in hosts.items():
            # One task per host, fetching all of that host's URLs and
            # absorbing the JSON responses directly into this dataset.
            task = (
                get_json_from_api_endpoint.refine(output_dataset_param='output_dataset').run(
                    http_url_dataset[indices],
                    client_session=client_sessions[host],
                    output_dataset=self))
            tasks.append(task)

        await asyncio.gather(*tasks)
        return self

    loop, loop_is_running = get_event_loop_and_check_if_loop_is_running()

    if loop and loop_is_running:
        # A loop is already running: we cannot block, so hand back a Task.
        return loop.create_task(load_all())
    else:
        # No running loop: safe to drive one to completion ourselves.
        return asyncio.run(load_all())

def _load_paths(self, path_or_urls: Iterable[str], by_file_suffix: bool) -> None:
    """Load each local path into this dataset via the serializer registry.

    Each path is normalized to a ``.tar.gz`` file (see ``_ensure_tar_gz_file``)
    and handed to the serializer registry; successfully loaded contents are
    merged into this dataset with ``absorb``.

    :param path_or_urls: file paths (or Model instances wrapping paths) to load.
    :param by_file_suffix: select serializer by file suffix instead of by
        dataset class.
    :raises RuntimeError: if the registry cannot load a file with any serializer.
    """
    # Loop-invariant: the registry does not depend on the individual path.
    serializer_registry = self._get_serializer_registry()

    for path_or_url in path_or_urls:
        # Unwrap Model instances to their raw contents (e.g. a path string).
        if is_model_instance(path_or_url):
            path_or_url = path_or_url.contents

        tar_gz_file_path = self._ensure_tar_gz_file(path_or_url)

        if by_file_suffix:
            loaded_dataset = \
                serializer_registry.load_from_tar_file_path_based_on_file_suffix(
                    self, tar_gz_file_path, self)
        else:
            loaded_dataset = \
                serializer_registry.load_from_tar_file_path_based_on_dataset_cls(
                    self, tar_gz_file_path, self)

        # Guard clause: fail fast if no serializer could handle the file.
        if loaded_dataset is None:
            raise RuntimeError('Unable to load from serializer')
        self.absorb(loaded_dataset)

@staticmethod
def _download_file(url: str, path: str, tmp_dir_path: str) -> str | None:
Expand Down Expand Up @@ -638,7 +685,7 @@ def __eq__(self, other: object) -> bool:
and self.to_data() == other.to_data() # last is probably unnecessary, but just in case

def __repr_args__(self):
    # Show each item's underlying contents when it is a Model instance, but
    # fall back to the raw value for plain (not-yet-wrapped) entries, e.g.
    # while validation is still in progress.
    return [(k, v.contents) if is_model_instance(v) else (k, v) for k, v in self.data.items()]


class MultiModelDataset(Dataset[GeneralModelT], Generic[GeneralModelT]):
Expand Down

0 comments on commit cefdda4

Please sign in to comment.