Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation registration #134

Merged
merged 10 commits into from
Aug 28, 2023
33 changes: 33 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,39 @@ with fs.open(p.path) as f:
data = f.read()
```

### Register custom UPath implementations

In case you develop a custom UPath implementation, feel free to open an issue to discuss integrating it
in `universal_pathlib`. You can dynamically register your implementation too! Here are your options:

#### Dynamic registration from Python

```python
# for example: my_module/submodule.py
from upath import UPath
from upath.registry import register_implementation

my_protocol = "myproto"
class MyPath(UPath):
... # your custom implementation

register_implementation(my_protocol, MyPath)
```

#### Registration via entry points

```toml
# pyproject.toml
[project.entry-points."universal_pathlib.implementations"]
myproto = "my_module.submodule:MyPath"
```

```ini
# setup.cfg
[options.entry_points]
universal_pathlib.implementations =
myproto = my_module.submodule:MyPath
```

## Contributing

Expand Down
176 changes: 155 additions & 21 deletions upath/registry.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,63 @@
"""upath.registry -- registry for file system specific implementations

Retrieve UPath implementations via `get_upath_class`.
Register custom UPath subclasses in one of two ways:

### directly from Python

>>> from upath import UPath
>>> from upath.registry import register_implementation
>>> my_protocol = "myproto"
>>> class MyPath(UPath):
... pass
>>> register_implementation(my_protocol, MyPath)

### via entry points

```toml
# pyproject.toml
[project.entry-points."universal_pathlib.implementations"]
myproto = "my_module.submodule:MyPath"
```

```ini
# setup.cfg
[options.entry_points]
universal_pathlib.implementations =
myproto = my_module.submodule:MyPath
```
"""
from __future__ import annotations

import importlib
import os
import re
import sys
import warnings
from collections import ChainMap
from functools import lru_cache
from typing import TYPE_CHECKING
from importlib import import_module
from importlib.metadata import entry_points
from typing import Iterator
from typing import MutableMapping

from fsspec.core import get_filesystem_class
from fsspec.registry import available_protocols

if TYPE_CHECKING:
from upath.core import UPath
import upath.core

# Names exported as the public API of this module.
__all__ = [
"get_upath_class",
"available_implementations",
"register_implementation",
]


class _Registry:
# Entry-point group that the registry queries for externally provided
# UPath implementations.
_ENTRY_POINT_GROUP = "universal_pathlib.implementations"


class _Registry(MutableMapping[str, "type[upath.core.UPath]"]):
"""internal registry for UPath subclasses"""

known_implementations: dict[str, str] = {
"abfs": "upath.implementations.cloud.AzurePath",
"abfss": "upath.implementations.cloud.AzurePath",
Expand All @@ -35,26 +76,118 @@ class _Registry:
"webdav+https": "upath.implementations.webdav.WebdavPath",
}

def __getitem__(self, item: str) -> type[UPath] | None:
try:
fqn = self.known_implementations[item]
except KeyError:
return None
module_name, name = fqn.rsplit(".", 1)
mod = importlib.import_module(module_name)
return getattr(mod, name) # type: ignore
def __init__(self) -> None:
    """Collect entry points and set up the layered protocol mapping."""
    # The keyword-based selection is only used on Python >= 3.10; on older
    # versions the result of entry_points() is queried like a dict by group.
    if sys.version_info < (3, 10):
        discovered = entry_points().get(_ENTRY_POINT_GROUP, [])
    else:
        discovered = entry_points(group=_ENTRY_POINT_GROUP)

    by_name = {}
    for entry in discovered:
        by_name[entry.name] = entry
    self._entries = by_name

    # First layer receives runtime registrations; the second layer is the
    # class-level table of built-in implementations.
    self._m = ChainMap({}, self.known_implementations)  # type: ignore

def __contains__(self, item: object) -> bool:
    """Return True if *item* is a registered or entry-point protocol."""
    known_protocols = set()
    known_protocols.update(self._m, self._entries)
    return item in known_protocols

def __getitem__(self, item: str) -> type[upath.core.UPath]:
    """Return the class registered for the protocol *item*.

    A value stored as a string is treated as a dotted path: everything
    before the last dot is imported as a module and the remainder is
    looked up on it. Any other stored value is returned as is.

    Raises
    ------
    KeyError
        If *item* is neither in the mapping nor among the entry points.
    """
    target = self._m.get(item)
    if target is None and item in self._entries:
        # Load the entry point on first access and remember the result.
        target = self._entries[item].load()
        self._m[item] = target
    if target is None:
        raise KeyError(f"{item} not in registry")

    if not isinstance(target, str):
        return target

    module_name, attr_name = target.rsplit(".", 1)
    return getattr(import_module(module_name), attr_name)  # type: ignore

def __setitem__(self, item: str, value: type[upath.core.UPath] | str) -> None:
    """Store *value* for the protocol *item*.

    Raises
    ------
    ValueError
        If *value* is neither a string nor a subclass of
        ``upath.core.UPath``.
    """
    if not isinstance(value, str):
        is_upath_class = isinstance(value, type) and issubclass(
            value, upath.core.UPath
        )
        if not is_upath_class:
            raise ValueError(
                f"expected UPath subclass or FQN-string, got: {type(value).__name__!r}"
            )
    self._m[item] = value

def __delitem__(self, __v: str) -> None:
    """Always raise: entries cannot be removed from the registry."""
    reason = "removal is unsupported"
    raise NotImplementedError(reason)

def __len__(self) -> int:
    """Return the number of distinct protocols across mapping and entry points."""
    distinct_protocols = {*self._m, *self._entries}
    return len(distinct_protocols)

def __iter__(self) -> Iterator[str]:
    """Iterate over the distinct protocols of the mapping and the entry points."""
    merged = set()
    merged.update(self._m, self._entries)
    return iter(merged)


# Module-level registry instance used by the public functions below.
_registry = _Registry()


@lru_cache
def get_upath_class(protocol: str) -> type[UPath] | None:
"""Return the upath cls for the given protocol."""
cls: type[UPath] | None = _registry[protocol]
if cls is not None:
return cls
def available_implementations(*, fallback: bool = False) -> list[str]:
    """Return the protocols for which an implementation is available.

    Parameters
    ----------
    fallback:
        When True, the result additionally contains the protocols of
        fsspec filesystems that have no implementation in
        universal_pathlib.
    """
    registered = list(_registry)
    if fallback:
        # Merge through a set so a protocol known to both is listed once.
        return list({*registered, *available_protocols()})
    return registered


def register_implementation(
    protocol: str,
    cls: type[upath.core.UPath] | str,
    *,
    clobber: bool = False,
) -> None:
    """Register a UPath implementation for a protocol.

    Parameters
    ----------
    protocol:
        Protocol name to associate with the class.
    cls:
        The UPath subclass for the protocol or a str representing the
        full path to an implementation class like package.module.class.
    clobber:
        Whether to overwrite a protocol with the same name; if False,
        will raise instead.

    Raises
    ------
    TypeError
        If `protocol` is not a string (raised by the pattern match).
    ValueError
        If `protocol` does not match the accepted scheme pattern, if the
        protocol is already registered and `clobber` is False, or if
        `cls` is rejected by the registry.
    """
    # fullmatch instead of match + "$": "$" would also accept a protocol
    # that ends in a newline.
    if not re.fullmatch(r"[a-z][a-z0-9+_.]+", protocol):
        raise ValueError(f"{protocol!r} is not a valid URI scheme")
    if not clobber and protocol in _registry:
        raise ValueError(f"{protocol!r} is already in registry and clobber is False!")
    _registry[protocol] = cls
    # get_upath_class memoizes its results via lru_cache. Without clearing
    # it, a lookup performed before this registration (or before a clobber)
    # would keep returning the previously cached answer.
    get_upath_class.cache_clear()


@lru_cache
def get_upath_class(
protocol: str,
*,
fallback: bool = True,
) -> type[upath.core.UPath] | None:
"""Return the upath cls for the given protocol.

Returns `None` if no matching protocol can be found.

Parameters
----------
protocol:
The protocol string
fallback:
If fallback is False, don't return UPath instances for fsspec
filesystems that don't have an implementation registered.
"""
try:
return _registry[protocol]
except KeyError:
if not protocol:
if os.name == "nt":
from upath.implementations.local import WindowsUPath
Expand All @@ -64,6 +197,8 @@ def get_upath_class(protocol: str) -> type[UPath] | None:
from upath.implementations.local import PosixUPath

return PosixUPath
if not fallback:
return None
try:
_ = get_filesystem_class(protocol)
except ValueError:
Expand All @@ -76,5 +211,4 @@ def get_upath_class(protocol: str) -> type[UPath] | None:
UserWarning,
stacklevel=2,
)
mod = importlib.import_module("upath.core")
return mod.UPath # type: ignore
return upath.core.UPath
126 changes: 126 additions & 0 deletions upath/tests/test_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import pytest
from fsspec.registry import available_protocols

from upath import UPath
from upath.registry import available_implementations
from upath.registry import get_upath_class
from upath.registry import register_implementation

# Protocols that the tests below expect available_implementations() to report
# when nothing extra has been registered.
IMPLEMENTATIONS = {
"abfs",
"abfss",
"adl",
"az",
"file",
"gcs",
"gs",
"hdfs",
"http",
"https",
"memory",
"s3",
"s3a",
"webdav+http",
"webdav+https",
}


@pytest.fixture(autouse=True)
def reset_registry():
    """Discard runtime registrations after each test of this module."""
    from upath.registry import _registry

    try:
        yield
    finally:
        # maps[0] is the first layer of the registry's ChainMap, i.e. the
        # one that receives entries written during a test.
        runtime_layer = _registry._m.maps[0]
        runtime_layer.clear()  # type: ignore


@pytest.fixture()
def fake_entrypoint():
    """Temporarily expose ``upath.core:UPath`` as the entry point ``myeps``."""
    from importlib.metadata import EntryPoint

    from upath.registry import _registry

    fake = EntryPoint(
        name="myeps",
        value="upath.core:UPath",
        group="universal_pathlib.implementations",
    )
    saved_entries = _registry._entries.copy()

    try:
        _registry._entries[fake.name] = fake
        yield
    finally:
        # Restore the entry points exactly as they were before the test.
        _registry._entries.clear()
        _registry._entries.update(saved_entries)


def test_available_implementations():
    """The default listing has no duplicates and equals IMPLEMENTATIONS."""
    protocols = available_implementations()
    unique_protocols = set(protocols)
    assert len(protocols) == len(unique_protocols)
    assert unique_protocols == IMPLEMENTATIONS


def test_available_implementations_with_fallback():
    """With fallback=True the fsspec protocols are part of the listing."""
    protocols = available_implementations(fallback=True)
    expected = IMPLEMENTATIONS.union(available_protocols())
    assert set(protocols) == expected


def test_available_implementations_with_entrypoint(fake_entrypoint):
    """A protocol provided through an entry point shows up in the listing."""
    protocols = available_implementations()
    expected = IMPLEMENTATIONS.union({"myeps"})
    assert set(protocols) == expected


def test_register_implementation():
    """A class registered at runtime is returned by get_upath_class."""

    class MyProtoPath(UPath):
        pass

    register_implementation("myproto", MyProtoPath)

    resolved = get_upath_class("myproto")
    assert resolved is MyProtoPath


def test_register_implementation_wrong_input():
    """Invalid registrations raise and leave the registry unchanged."""
    with pytest.raises(TypeError):
        register_implementation(None, UPath)  # type: ignore

    # Malformed protocol, non-UPath class, and duplicate without clobber.
    rejected_calls = [
        (("incorrect**protocol", UPath), {}),
        (("myproto", object), {"clobber": True}),
        (("file", UPath), {"clobber": False}),
    ]
    for args, kwargs in rejected_calls:
        with pytest.raises(ValueError):
            register_implementation(*args, **kwargs)  # type: ignore

    assert set(available_implementations()) == IMPLEMENTATIONS


# sorted(): a set has no stable iteration order, which would make the
# collected test ids differ from run to run.
@pytest.mark.parametrize("protocol", sorted(IMPLEMENTATIONS))
def test_get_upath_class(protocol):
    """Each protocol in IMPLEMENTATIONS resolves to a UPath subclass.

    The parametrized protocol is passed to the lookup; a hard-coded
    protocol would make every parametrized case test the same thing.
    """
    # NOTE(review): this resolves every listed implementation class --
    # confirm none of them needs optional dependencies just to be imported.
    upath_cls = get_upath_class(protocol)
    assert issubclass(upath_cls, UPath)


# NOTE(review): the `clear_registry` fixture is not defined in the visible
# part of this file -- presumably it comes from a conftest; confirm.
def test_get_upath_class_without_implementation(clear_registry):
    """An fsspec-only protocol falls back to UPath and emits a warning."""
    expected_message = "UPath 'mock' filesystem not explicitly implemented."
    with pytest.warns(UserWarning, match=expected_message):
        upath_cls = get_upath_class("mock")
    assert issubclass(upath_cls, UPath)


# NOTE(review): the `clear_registry` fixture is not defined in the visible
# part of this file -- presumably it comes from a conftest; confirm.
def test_get_upath_class_without_implementation_no_fallback(clear_registry):
    """With fallback=False an fsspec-only protocol yields None."""
    resolved = get_upath_class("mock", fallback=False)
    assert resolved is None


# NOTE(review): the `clear_registry` fixture is not defined in the visible
# part of this file -- presumably it comes from a conftest; confirm.
def test_get_upath_class_unknown_protocol(clear_registry):
    """A protocol unknown to the registry and to fsspec yields None."""
    resolved = get_upath_class("doesnotexist")
    assert resolved is None


def test_get_upath_class_from_entrypoint(fake_entrypoint):
    """A protocol provided through an entry point resolves to a UPath subclass."""
    upath_cls = get_upath_class("myeps")
    assert issubclass(upath_cls, UPath)


@pytest.mark.parametrize(
    "protocol",
    [
        pytest.param("", id="empty-str"),
        pytest.param(None, id="none"),
    ],
)
def test_get_upath_class_falsey_protocol(protocol):
    """An empty or None protocol still resolves to a UPath subclass."""
    upath_cls = get_upath_class(protocol)
    assert issubclass(upath_cls, UPath)