From bcdc9a2c4831ed912386083392055e0a5ef45481 Mon Sep 17 00:00:00 2001 From: Andreas Poehlmann Date: Fri, 6 Oct 2023 08:19:30 +0200 Subject: [PATCH] upath.implementations.webdav: working webdav implementation --- upath/core312plus.py | 2 +- upath/implementations/webdav.py | 45 ++++++++++++++++++++++++++------- upath/registry.py | 1 + upath/tests/test_registry.py | 1 + 4 files changed, 39 insertions(+), 10 deletions(-) diff --git a/upath/core312plus.py b/upath/core312plus.py index fe133f73..828052a7 100644 --- a/upath/core312plus.py +++ b/upath/core312plus.py @@ -101,7 +101,7 @@ def get_upath_protocol( raise TypeError(f"expected a str or PurePath instance, got: {pth!r}") # if storage_options and not protocol: # protocol = "file" - if protocol and pth_protocol and protocol != pth_protocol: + if protocol and pth_protocol and not pth_protocol.startswith(protocol): raise ValueError( f"requested protocol {protocol!r} incompatible with {pth_protocol!r}" ) diff --git a/upath/implementations/webdav.py b/upath/implementations/webdav.py index d5744bc2..ff1f1225 100644 --- a/upath/implementations/webdav.py +++ b/upath/implementations/webdav.py @@ -3,6 +3,7 @@ import sys from typing import Any from urllib.parse import ParseResult +from urllib.parse import urlsplit from urllib.parse import urlunsplit import upath.core @@ -71,17 +72,43 @@ def storage_options(self) -> dict[str, Any]: if sys.version_info >= (3, 12): + import upath.core312plus + + class WebdavPath(upath.core312plus.UPath): # noqa + __slots__ = () - class WebdavPath(upath.core.UPath): # noqa def __init__( self, *args, protocol: str | None = None, **storage_options: Any ) -> None: - if self._protocol == "webdav+http": - ... - elif self._protocol == "webdav+https": - ... - elif self._protocol == "webdav": - ... + base_options = getattr(self, "_storage_options", {}) # when unpickling + if args: + args0, *argsN = args + url = urlsplit(str(args0)) + args0 = urlunsplit(url._replace(scheme="", netloc="")) or "/" + if "base_url" not in storage_options: + if self._protocol == "webdav+http": + storage_options["base_url"] = urlunsplit( + url._replace(scheme="http", path="") + ) + elif self._protocol == "webdav+https": + storage_options["base_url"] = urlunsplit( + url._replace(scheme="https", path="") + ) else: - raise NotImplementedError - super().__init__(*args, protocol="webdav", **storage_options) + args0, argsN = "/", () + storage_options = {**base_options, **storage_options} + if "base_url" not in storage_options: + raise ValueError( + f"must provide `base_url` storage option for args: {args!r}" + ) + self._protocol = "webdav" + super().__init__(args0, *argsN, protocol="webdav", **storage_options) + + @property + def path(self) -> str: + # webdav paths don't start at "/" + return super().path.removeprefix("/") + + def __str__(self): + base_url = self.storage_options["base_url"].removesuffix("/") + return super().__str__().replace("webdav://", f"webdav+{base_url}", 1) diff --git a/upath/registry.py b/upath/registry.py index c03547bc..3ed65674 100644 --- a/upath/registry.py +++ b/upath/registry.py @@ -72,6 +72,7 @@ class _Registry(MutableMapping[str, "type[upath.UPath]"]): "memory": "upath.implementations.memory.MemoryPath", "s3": "upath.implementations.cloud.S3Path", "s3a": "upath.implementations.cloud.S3Path", + "webdav": "upath.implementations.webdav.WebdavPath", "webdav+http": "upath.implementations.webdav.WebdavPath", "webdav+https": "upath.implementations.webdav.WebdavPath", } diff --git a/upath/tests/test_registry.py b/upath/tests/test_registry.py index 93388f11..38cadb45 100644 --- a/upath/tests/test_registry.py +++ b/upath/tests/test_registry.py @@ -20,6 +20,7 @@ "memory", "s3", "s3a", + "webdav", "webdav+http", "webdav+https", }