Skip to content

Commit

Permalink
Add a switch to control if use glob search
Browse files Browse the repository at this point in the history
  • Loading branch information
yanghua committed Nov 6, 2024
1 parent 6942153 commit 59f9538
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 9 deletions.
77 changes: 69 additions & 8 deletions tosfs/compatible.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
# limitations under the License.

"""The compatible module about AbstractFileSystem in fsspec."""
import os
import re
from typing import Any, Optional
from typing import Any, Optional, Union

from fsspec import AbstractFileSystem
from fsspec.implementations.local import LocalFileSystem
from fsspec.utils import other_paths

magic_check_bytes = re.compile(b"([*?[])")
Expand All @@ -32,6 +34,66 @@ def has_magic(s: str) -> bool:
return match is not None


class EnhancedLocalFileSystem(LocalFileSystem):
"""Enhanced LocalFileSystem than fsspec's LocalFileSystem."""

def expand_path(
self,
path: Union[str, list[str]],
recursive: bool = False,
maxdepth: Optional[int] = None,
disable_glob: bool = False,
**kwargs: Any,
) -> Union[str, list[str]]:
"""Turn one or more globs or dirs into a list of files or dirs.
kwargs are passed to ``glob`` or ``find``, which may in turn call ``ls``
"""
# type: ignore
if maxdepth is not None and maxdepth < 1:
raise ValueError("maxdepth must be at least 1")

if isinstance(path, (str, os.PathLike)):
out = self.expand_path(
[path], recursive, maxdepth, disable_glob=disable_glob
)
else:
out = set() # type: ignore
path = [self._strip_protocol(p) for p in path]
for p in path:
if has_magic(p) and not disable_glob:
bit = set(self.glob(p, maxdepth=maxdepth, **kwargs))
out |= bit # type: ignore
if recursive:
# glob call above expanded one depth so if maxdepth is defined
# then decrement it in expand_path call below. If it is zero
# after decrementing then avoid expand_path call.
if maxdepth is not None and maxdepth <= 1:
continue
out |= set( # type: ignore
self.expand_path(
list(bit),
recursive=recursive,
maxdepth=maxdepth - 1 if maxdepth is not None else None,
**kwargs,
)
)
continue
elif recursive:
rec = set(
self.find(
p, maxdepth=maxdepth, withdirs=True, detail=False, **kwargs
)
)
out |= rec # type: ignore
if p not in out and (recursive is False or self.exists(p)):
# should only check once, for the root
out.add(p) # type: ignore
if not out:
raise FileNotFoundError(path)
return sorted(out)


class FsspecCompatibleFS(AbstractFileSystem):
"""A fsspec compatible file system.
Expand Down Expand Up @@ -200,6 +262,7 @@ def put(
recursive: bool = False,
callback: Any = None,
maxdepth: Optional[int] = None,
disable_glob: bool = False,
**kwargs: Any,
) -> None:
"""Copy file(s) from local.
Expand All @@ -216,17 +279,15 @@ def put(
rpaths = rpath
lpaths = lpath
else:
from fsspec.implementations.local import (
LocalFileSystem,
make_path_posix,
trailing_sep,
)
from fsspec.implementations.local import make_path_posix, trailing_sep

source_is_str = isinstance(lpath, str)
if source_is_str:
lpath = make_path_posix(lpath)
fs = LocalFileSystem()
lpaths = fs.expand_path(lpath, recursive=recursive, maxdepth=maxdepth)
fs = EnhancedLocalFileSystem()
lpaths = fs.expand_path(
lpath, recursive=recursive, maxdepth=maxdepth, disable_glob=disable_glob
)
if source_is_str and (not recursive or maxdepth is not None):
# Non-recursive glob does not copy directories
lpaths = [p for p in lpaths if not (trailing_sep(p) or fs.isdir(p))]
Expand Down
5 changes: 4 additions & 1 deletion tosfs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -907,6 +907,7 @@ def put(
recursive: bool = False,
callback: Any = None,
maxdepth: Optional[int] = None,
disable_glob: bool = False,
**kwargs: Any,
) -> None:
"""Copy file(s) from local.
Expand All @@ -917,7 +918,9 @@ def put(
Calls put_file for each source.
"""
super().put(lpath, rpath, recursive=recursive, **kwargs)
super().put(
lpath, rpath, recursive=recursive, disable_glob=disable_glob, **kwargs
)

def put_file(
self,
Expand Down
66 changes: 66 additions & 0 deletions tosfs/tests/test_tosfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,72 @@ def test_put_file(tosfs: TosFileSystem, bucket: str, temporary_workspace: str) -
assert f.read() == b"a" * 1024 * 1024 * 6


def test_put(tosfs: TosFileSystem, bucket: str, temporary_workspace: str):
with tempfile.TemporaryDirectory() as local_temp_dir:
dir_1 = f"{local_temp_dir}/生技[2005]174号文/"
os.makedirs(dir_1)
with open(f"{dir_1}/test.txt", "w") as f:
f.write("hello world")
tosfs.put(
local_temp_dir,
f"{bucket}/{temporary_workspace}",
recursive=True,
disable_glob=True,
)
assert tosfs.exists(
f"{bucket}/{temporary_workspace}/{os.path.basename(local_temp_dir)}"
f"/生技[2005]174号文/"
)
assert tosfs.exists(
f"{bucket}/{temporary_workspace}/{os.path.basename(local_temp_dir)}"
f"/生技[2005]174号文/test.txt"
)
with tosfs.open(
f"{bucket}/{temporary_workspace}/"
f"{os.path.basename(local_temp_dir)}/生技[2005]174号文/test.txt",
mode="r",
) as file:
assert file.read() == "hello world"

with tempfile.TemporaryDirectory() as local_temp_dir:
dir_2 = f"{local_temp_dir}/生技??174号文/"
dir_3 = f"{local_temp_dir}/生技**174号文/"
dir_4 = f"{local_temp_dir}/生技_=+&^%#174号文/"
os.makedirs(dir_2)
os.makedirs(dir_3)
os.makedirs(dir_4)
with open(f"{dir_2}/test.txt", "w") as f:
f.write("hello world")
tosfs.put(
local_temp_dir,
f"{bucket}/{temporary_workspace}",
recursive=True,
disable_glob=True,
)
assert tosfs.exists(
f"{bucket}/{temporary_workspace}/{os.path.basename(local_temp_dir)}"
f"/生技??174号文/"
)
assert tosfs.exists(
f"{bucket}/{temporary_workspace}/{os.path.basename(local_temp_dir)}"
f"/生技??174号文/test.txt"
)
assert tosfs.exists(
f"{bucket}/{temporary_workspace}/{os.path.basename(local_temp_dir)}"
f"/生技**174号文/"
)
assert tosfs.exists(
f"{bucket}/{temporary_workspace}/{os.path.basename(local_temp_dir)}"
f"/生技_=+&^%#174号文/"
)
with tosfs.open(
f"{bucket}/{temporary_workspace}/"
f"{os.path.basename(local_temp_dir)}/生技??174号文/test.txt",
mode="r",
) as file:
assert file.read() == "hello world"


def test_get_file(tosfs: TosFileSystem, bucket: str, temporary_workspace: str) -> None:
file_name = random_str()
file_content = "hello world"
Expand Down

0 comments on commit 59f9538

Please sign in to comment.