Skip to content

Commit

Permalink
Add the tag ALL to dataset registries
Browse files Browse the repository at this point in the history
  • Loading branch information
simonhkswan committed Mar 14, 2023
1 parent 1135cbe commit 957251b
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/synthesized_datasets/_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@


class _Tag(_Enum):
ALL = "all"
CREDIT = "credit"
INSURANCE = "insurance"
FRAUD = "fraud"
Expand All @@ -24,10 +25,11 @@ def __repr__(self):


class _Dataset:
def __init__(self, name: str, url: str, tags: _typing.List[_Tag] = None):
def __init__(self, name: str, url: str, tags: _typing.Optional[_typing.List[_Tag]] = None):
self._name = name
self._url = _ROOT_URL + url
self._tags: _typing.List[_Tag] = tags if tags is not None else []
_REGISTRIES[_Tag.ALL]._register(self)
for tag in self._tags:
_REGISTRIES[tag]._register(self)

Expand Down Expand Up @@ -59,7 +61,7 @@ def __init__(self, tag: _Tag):
self._datasets: _typing.MutableMapping[str, _Dataset] = {}

def _register(self, dataset: _Dataset):
if self._tag not in dataset.tags:
if self._tag not in dataset.tags and self._tag != _Tag.ALL:
raise ValueError(f"_Dataset {dataset.name} is not tagged with {self._tag}")

if dataset.name not in self._datasets:
Expand Down

0 comments on commit 957251b

Please sign in to comment.