Skip to content

Commit

Permalink
Implement load_urls_into_new_dataset() convenience task
Browse files Browse the repository at this point in the history
  • Loading branch information
sveinugu committed Nov 6, 2024
1 parent cefdda4 commit 5f9562b
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 4 deletions.
16 changes: 16 additions & 0 deletions src/omnipy/modules/remote/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@

from aiohttp import ClientResponse, ClientSession
from aiohttp_retry import ExponentialRetry, FibonacciRetry, JitterRetry, RandomRetry, RetryClient
from typing_extensions import TypeVar

from omnipy.api.enums import BackoffStrategy
from omnipy.compute.task import TaskTemplate
from omnipy.data.dataset import Dataset

from ..json.datasets import JsonDataset
from ..json.models import JsonModel
from ..raw.datasets import BytesDataset, StrDataset
from ..raw.models import BytesModel, StrModel
from .datasets import HttpUrlDataset
from .models import HttpUrlModel

DEFAULT_RETRIES = 5
Expand Down Expand Up @@ -126,3 +129,16 @@ async def get_bytes_from_api_endpoint(
):
async for response in _call_get(url, cast(ClientSession, retry_session)):
return BytesModel(await response.read())


# Type variable for the dataset class to instantiate; bound to Dataset.
JsonDatasetT = TypeVar('JsonDatasetT', bound=Dataset)


@TaskTemplate()
async def load_urls_into_new_dataset(
    urls: HttpUrlDataset,
    dataset_cls: type[JsonDatasetT] = JsonDataset,
) -> JsonDatasetT:
    """Create a new ``dataset_cls`` instance, load ``urls`` into it and return it.

    Args:
        urls: URLs passed on to the ``load()`` method of the new dataset.
        dataset_cls: Dataset class to instantiate; it is called without
            arguments. Defaults to ``JsonDataset``.

    Returns:
        The newly created dataset, once ``load(urls)`` has been awaited.
    """
    new_dataset = dataset_cls()
    await new_dataset.load(urls)
    return new_dataset
21 changes: 18 additions & 3 deletions tests/modules/remote/cases/request_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@

import pytest_cases as pc

from omnipy import BytesDataset, Dataset, JsonDataset, StrDataset, TaskTemplate
from omnipy import BytesDataset, Dataset, JsonDataset, JsonDictDataset, StrDataset, TaskTemplate
from omnipy.modules.remote.tasks import (get_bytes_from_api_endpoint,
get_json_from_api_endpoint,
get_str_from_api_endpoint)
get_str_from_api_endpoint,
load_urls_into_new_dataset)


@dataclass
class RequestTypeCase:
    """Bundle of parameters describing a single request-type test case."""

    # Task template exercised by the case.
    job: TaskTemplate
    # Keyword arguments for the case; values are arbitrary objects (one case
    # passes a dataset class), hence ``object`` rather than ``str``.
    # NOTE(review): presumably forwarded to ``job`` by the test - confirm.
    kwargs: dict[str, object]
    # Dataset class associated with the case.
    # NOTE(review): presumably the expected type of the result - confirm.
    dataset_cls: type[Dataset]


Expand All @@ -28,3 +29,17 @@ def case_get_str_from_api_endpoint() -> RequestTypeCase:
@pc.case
def case_get_bytes_from_api_endpoint() -> RequestTypeCase:
    """Case: ``get_bytes_from_api_endpoint`` with no extra kwargs and ``BytesDataset``."""
    return RequestTypeCase(
        job=get_bytes_from_api_endpoint,
        kwargs={},
        dataset_cls=BytesDataset,
    )


@pc.case
def case_load_urls_into_new_dataset_default_json() -> RequestTypeCase:
    """Case: ``load_urls_into_new_dataset`` with no extra kwargs and ``JsonDataset``."""
    return RequestTypeCase(
        job=load_urls_into_new_dataset,
        kwargs={},
        dataset_cls=JsonDataset,
    )


@pc.case
def case_load_urls_into_new_dataset_other_dataset_cls() -> RequestTypeCase:
    """Case: ``load_urls_into_new_dataset`` with ``dataset_cls=JsonDictDataset``."""
    job_kwargs = {'dataset_cls': JsonDictDataset}
    return RequestTypeCase(
        job=load_urls_into_new_dataset,
        kwargs=job_kwargs,
        dataset_cls=JsonDictDataset,
    )
2 changes: 1 addition & 1 deletion tests/modules/remote/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def _assert_query_results(assert_model_if_dyn_conv_else_val, case: RequestTypeCa
case omnipy.StrDataset | omnipy.BytesDataset:
json_data = JsonDataset()
json_data.from_json(data.to_data())
case omnipy.JsonDataset:
case omnipy.JsonDataset | omnipy.JsonDictDataset:
json_data = data
case _:
raise ShouldNotOccurException()
Expand Down

0 comments on commit 5f9562b

Please sign in to comment.