Skip to content

Commit

Permalink
Refactor MPU logic (#227)
Browse files Browse the repository at this point in the history
  • Loading branch information
yanghua authored Oct 19, 2024
1 parent bc16f6f commit 0d07616
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 53 deletions.
14 changes: 3 additions & 11 deletions tosfs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,7 @@ def __init__(
multipart_staging_dirs: str = tempfile.mkdtemp(),
multipart_size: int = 8 << 20,
multipart_thread_pool_size: int = max(2, os.cpu_count() or 1),
multipart_staging_buffer_size: int = 4 << 10,
multipart_threshold: int = 10 << 20,
multipart_threshold: int = 5 << 20,
enable_crc: bool = True,
enable_verify_ssl: bool = True,
dns_cache_timeout: int = 0,
Expand Down Expand Up @@ -186,11 +185,6 @@ def __init__(
multipart_thread_pool_size : int, optional
The size of thread pool used for uploading multipart in parallel for the
            given object storage. (default is max(2, os.cpu_count())).
multipart_staging_buffer_size : int, optional
The max byte size which will buffer the staging data in-memory before
flushing to the staging file. It will decrease the random write in local
staging disk dramatically if writing plenty of small files.
default is 4096.
        multipart_threshold : int, optional
            The threshold which controls whether to enable multipart upload
            during writing data to the given object storage; if the write data size is less
Expand Down Expand Up @@ -270,7 +264,6 @@ def __init__(
]
self.multipart_size = multipart_size
self.multipart_thread_pool_size = multipart_thread_pool_size
self.multipart_staging_buffer_size = multipart_staging_buffer_size
self.multipart_threshold = multipart_threshold

super().__init__(**kwargs)
Expand Down Expand Up @@ -2252,7 +2245,6 @@ def __init__(
key=key,
part_size=fs.multipart_size,
thread_pool_size=fs.multipart_thread_pool_size,
staging_buffer_size=fs.multipart_staging_buffer_size,
multipart_threshold=fs.multipart_threshold,
)

Expand Down Expand Up @@ -2322,7 +2314,7 @@ def _upload_chunk(self, final: bool = False) -> bool:
self.autocommit
and final
and self.tell()
< max(self.blocksize, self.multipart_uploader.multipart_threshold)
< min(self.blocksize, self.multipart_uploader.multipart_threshold)
):
# only happens when closing small file, use one-shot PUT
pass
Expand Down Expand Up @@ -2387,7 +2379,7 @@ def commit(self) -> None:
logger.debug("Empty file committed %s", self)
self.multipart_uploader.abort_upload()
self.fs.touch(self.path, **self.kwargs)
elif not self.multipart_uploader.staging_files:
elif not self.multipart_uploader.staging_part_mgr.staging_files:
if self.buffer is not None:
logger.debug("One-shot upload of %s", self)
self.buffer.seek(0)
Expand Down
100 changes: 58 additions & 42 deletions tosfs/mpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,57 @@
from tosfs.core import TosFileSystem


class StagingPartMgr:
    """Buffer written data in memory and spill it to part-sized staging files.

    Data arrives via ``write_to_buffer`` in arbitrarily sized chunks; whenever
    the in-memory buffer holds at least ``part_size`` bytes it is flushed to
    one temporary file per part, spread round-robin across the configured
    staging directories.
    """

    def __init__(self, part_size: int, staging_dirs: itertools.cycle):
        """Instantiate a StagingPartMgr object.

        Parameters
        ----------
        part_size : int
            The size in bytes of each staged part file.
        staging_dirs : itertools.cycle
            Endless iterator of local directories used round-robin to hold
            the staging files.
        """
        self.part_size = part_size
        self.staging_dirs = staging_dirs
        # In-memory buffer holding data not yet spilled to a staging file.
        self.staging_buffer = io.BytesIO()
        # Paths of spilled part files, in write (part-number) order.
        self.staging_files: list[str] = []

    def write_to_buffer(self, chunk: bytes) -> None:
        """Append *chunk* to the staging buffer, spilling full parts to disk."""
        self.staging_buffer.write(chunk)
        if self.staging_buffer.tell() >= self.part_size:
            self.flush_buffer(False)

    def flush_buffer(self, final: bool = False) -> None:
        """Flush the staging buffer to part-sized staging files.

        Every complete ``part_size`` chunk is written to its own temp file.
        When ``final`` is False the remaining tail (< part_size bytes) is kept
        in a fresh in-memory buffer for subsequent writes; when True the tail
        is written out as the last (possibly short) part so no data remains
        buffered.
        """
        if self.staging_buffer.tell() == 0:
            return

        buffer_size = self.staging_buffer.tell()
        self.staging_buffer.seek(0)

        # Spill every complete part to its own temp file, cycling directories.
        while buffer_size >= self.part_size:
            staging_dir = next(self.staging_dirs)
            with tempfile.NamedTemporaryFile(delete=False, dir=staging_dir) as tmp:
                tmp.write(self.staging_buffer.read(self.part_size))
                self.staging_files.append(tmp.name)
            buffer_size -= self.part_size

        if not final:
            # Carry the (< part_size) tail over into a fresh buffer.
            remaining_data = self.staging_buffer.read()
            self.staging_buffer = io.BytesIO()
            self.staging_buffer.write(remaining_data)
        else:
            # Fix: only write a tail file when bytes actually remain.
            # Previously an empty staging file was created (and later
            # uploaded as an empty part) whenever the buffer held an exact
            # multiple of part_size.
            if buffer_size > 0:
                staging_dir = next(self.staging_dirs)
                with tempfile.NamedTemporaryFile(
                    delete=False, dir=staging_dir
                ) as tmp:
                    tmp.write(self.staging_buffer.read())
                    self.staging_files.append(tmp.name)
            self.staging_buffer = io.BytesIO()

    def get_staging_files(self) -> list[str]:
        """Return the staging file paths in write order."""
        return self.staging_files

    def clear_staging_files(self) -> None:
        """Forget all recorded staging files (the files are not deleted)."""
        self.staging_files = []


class MultipartUploader:
"""A class to upload large files to the object store using multipart upload."""

Expand All @@ -39,7 +90,6 @@ def __init__(
key: str,
part_size: int,
thread_pool_size: int,
staging_buffer_size: int,
multipart_threshold: int,
):
"""Instantiate a MultipartUploader object."""
Expand All @@ -48,12 +98,11 @@ def __init__(
self.key = key
self.part_size = part_size
self.thread_pool_size = thread_pool_size
self.staging_buffer_size = staging_buffer_size
self.multipart_threshold = multipart_threshold
self.executor = ThreadPoolExecutor(max_workers=self.thread_pool_size)
self.staging_dirs = itertools.cycle(fs.multipart_staging_dirs)
self.staging_files: list[str] = []
self.staging_buffer: io.BytesIO = io.BytesIO()
self.staging_part_mgr = StagingPartMgr(
part_size, itertools.cycle(fs.multipart_staging_dirs)
)
self.parts: list = []
self.mpu: CreateMultipartUploadOutput = None

Expand All @@ -72,46 +121,13 @@ def upload_multiple_chunks(self, buffer: Optional[io.BytesIO]) -> None:
chunk = buffer.read(self.part_size)
if not chunk:
break
self._write_to_staging_buffer(chunk)

def _write_to_staging_buffer(self, chunk: bytes) -> None:
self.staging_buffer.write(chunk)
if self.staging_buffer.tell() >= self.part_size:
self._flush_staging_buffer(False)

def _flush_staging_buffer(self, final: bool = False) -> None:
if self.staging_buffer.tell() == 0:
return

buffer_size = self.staging_buffer.tell()
self.staging_buffer.seek(0)

while buffer_size >= self.part_size:
staging_dir = next(self.staging_dirs)
with tempfile.NamedTemporaryFile(delete=False, dir=staging_dir) as tmp:
tmp.write(self.staging_buffer.read())
self.staging_files.append(tmp.name)
buffer_size -= self.part_size

if not final:
# Move remaining data to a new buffer
remaining_data = self.staging_buffer.read()
self.staging_buffer = io.BytesIO()
self.staging_buffer.write(remaining_data)
else:
staging_dir = next(self.staging_dirs)
with tempfile.NamedTemporaryFile(delete=False, dir=staging_dir) as tmp:
tmp.write(self.staging_buffer.read())
self.staging_files.append(tmp.name)
buffer_size -= self.part_size

self.staging_buffer = io.BytesIO()
self.staging_part_mgr.write_to_buffer(chunk)

def upload_staged_files(self) -> None:
"""Upload the staged files to the object store."""
self._flush_staging_buffer(True)
self.staging_part_mgr.flush_buffer(True)
futures = []
for i, staging_file in enumerate(self.staging_files):
for i, staging_file in enumerate(self.staging_part_mgr.get_staging_files()):
part_number = i + 1
futures.append(
self.executor.submit(
Expand All @@ -123,7 +139,7 @@ def upload_staged_files(self) -> None:
part_info = future.result()
self.parts.append(part_info)

self.staging_files = []
self.staging_part_mgr.clear_staging_files()

def _upload_part_from_file(self, staging_file: str, part_number: int) -> PartInfo:
with open(staging_file, "rb") as f:
Expand Down
25 changes: 25 additions & 0 deletions tosfs/tests/test_tosfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,6 +800,31 @@ def test_file_write_mpu(
assert f.read() == first_part + second_part + third_part


def test_file_write_mpu_content(
    tosfs: TosFileSystem, bucket: str, temporary_workspace: str
) -> None:
    """Round-trip a 13 MiB write and verify size and content survive MPU."""
    file_name = random_str()
    path = f"{bucket}/{temporary_workspace}/{file_name}"

    # Content large enough (5 + 5 + 3 MiB) that writing triggers the
    # multipart-upload code path instead of a one-shot PUT.
    origin_content = "".join(
        random_str(size)
        for size in (5 * 1024 * 1024, 5 * 1024 * 1024, 3 * 1024 * 1024)
    )

    block_size = 4 * 1024 * 1024
    with tosfs.open(path, "w", block_size=block_size) as f:
        f.write(origin_content)

    assert tosfs.info(path)["size"] == len(origin_content)

    with tosfs.open(path, "r") as f:
        assert f.read() == origin_content


def test_file_write_mpu_threshold_check(
tosfs: TosFileSystem, bucket: str, temporary_workspace: str
):
Expand Down

0 comments on commit 0d07616

Please sign in to comment.