From f79dc4a96d67c585c775f3b711ccb48a5a6f0f5c Mon Sep 17 00:00:00 2001
From: yanghua
Date: Fri, 18 Oct 2024 16:02:26 +0800
Subject: [PATCH] Optimize MPU staging memory logic

Extract the staging-buffer handling out of MultipartUploader into a
dedicated StagingPartMgr class, and fix the flush logic so that each
staging file holds at most part_size bytes: the old
_flush_staging_buffer read the entire remaining buffer inside its
loop, so a single staging file could absorb several parts' worth of
data. Also add a test that verifies content written through the MPU
path round-trips intact.
---
 tosfs/core.py             |  2 +-
 tosfs/mpu.py              | 98 +++++++++++++++++++++++----------------
 tosfs/tests/test_tosfs.py | 25 ++++++++++
 3 files changed, 84 insertions(+), 41 deletions(-)

diff --git a/tosfs/core.py b/tosfs/core.py
index 5dad01a..7d159be 100644
--- a/tosfs/core.py
+++ b/tosfs/core.py
@@ -2387,7 +2387,7 @@ def commit(self) -> None:
                 logger.debug("Empty file committed %s", self)
                 self.multipart_uploader.abort_upload()
                 self.fs.touch(self.path, **self.kwargs)
-        elif not self.multipart_uploader.staging_files:
+        elif not self.multipart_uploader.staging_part_mgr.staging_files:
             if self.buffer is not None:
                 logger.debug("One-shot upload of %s", self)
                 self.buffer.seek(0)
diff --git a/tosfs/mpu.py b/tosfs/mpu.py
index 1510db9..f3572b6 100644
--- a/tosfs/mpu.py
+++ b/tosfs/mpu.py
@@ -29,6 +29,57 @@
     from tosfs.core import TosFileSystem
 
 
+class StagingPartMgr:
+    """A class to handle staging parts for multipart upload."""
+
+    def __init__(self, part_size: int, staging_dirs: itertools.cycle):
+        """Instantiate a StagingPartMgr object."""
+        self.part_size = part_size
+        self.staging_dirs = staging_dirs
+        self.staging_buffer = io.BytesIO()
+        self.staging_files: list[str] = []
+
+    def write_to_buffer(self, chunk: bytes) -> None:
+        """Write data to the staging buffer."""
+        self.staging_buffer.write(chunk)
+        if self.staging_buffer.tell() >= self.part_size:
+            self.flush_buffer(False)
+
+    def flush_buffer(self, final: bool = False) -> None:
+        """Flush the staging buffer."""
+        if self.staging_buffer.tell() == 0:
+            return
+
+        buffer_size = self.staging_buffer.tell()
+        self.staging_buffer.seek(0)
+
+        while buffer_size >= self.part_size:
+            staging_dir = next(self.staging_dirs)
+            with tempfile.NamedTemporaryFile(delete=False, dir=staging_dir) as tmp:
+                tmp.write(self.staging_buffer.read(self.part_size))
+            self.staging_files.append(tmp.name)
+            buffer_size -= self.part_size
+
+        if not final:
+            remaining_data = self.staging_buffer.read()
+            self.staging_buffer = io.BytesIO()
+            self.staging_buffer.write(remaining_data)
+        else:
+            staging_dir = next(self.staging_dirs)
+            with tempfile.NamedTemporaryFile(delete=False, dir=staging_dir) as tmp:
+                tmp.write(self.staging_buffer.read())
+            self.staging_files.append(tmp.name)
+            self.staging_buffer = io.BytesIO()
+
+    def get_staging_files(self) -> list[str]:
+        """Get the staging files."""
+        return self.staging_files
+
+    def clear_staging_files(self) -> None:
+        """Clear the staging files."""
+        self.staging_files = []
+
+
 class MultipartUploader:
     """A class to upload large files to the object store using multipart upload."""
 
@@ -51,9 +102,9 @@ def __init__(
         self.staging_buffer_size = staging_buffer_size
         self.multipart_threshold = multipart_threshold
         self.executor = ThreadPoolExecutor(max_workers=self.thread_pool_size)
-        self.staging_dirs = itertools.cycle(fs.multipart_staging_dirs)
-        self.staging_files: list[str] = []
-        self.staging_buffer: io.BytesIO = io.BytesIO()
+        self.staging_part_mgr = StagingPartMgr(
+            part_size, itertools.cycle(fs.multipart_staging_dirs)
+        )
         self.parts: list = []
         self.mpu: CreateMultipartUploadOutput = None
 
@@ -72,46 +123,13 @@ def upload_multiple_chunks(self, buffer: Optional[io.BytesIO]) -> None:
             chunk = buffer.read(self.part_size)
             if not chunk:
                 break
-            self._write_to_staging_buffer(chunk)
-
-    def _write_to_staging_buffer(self, chunk: bytes) -> None:
-        self.staging_buffer.write(chunk)
-        if self.staging_buffer.tell() >= self.part_size:
-            self._flush_staging_buffer(False)
-
-    def _flush_staging_buffer(self, final: bool = False) -> None:
-        if self.staging_buffer.tell() == 0:
-            return
-
-        buffer_size = self.staging_buffer.tell()
-        self.staging_buffer.seek(0)
-
-        while buffer_size >= self.part_size:
-            staging_dir = next(self.staging_dirs)
-            with tempfile.NamedTemporaryFile(delete=False, dir=staging_dir) as tmp:
-                tmp.write(self.staging_buffer.read())
-            self.staging_files.append(tmp.name)
-            buffer_size -= self.part_size
-
-        if not final:
-            # Move remaining data to a new buffer
-            remaining_data = self.staging_buffer.read()
-            self.staging_buffer = io.BytesIO()
-            self.staging_buffer.write(remaining_data)
-        else:
-            staging_dir = next(self.staging_dirs)
-            with tempfile.NamedTemporaryFile(delete=False, dir=staging_dir) as tmp:
-                tmp.write(self.staging_buffer.read())
-            self.staging_files.append(tmp.name)
-            buffer_size -= self.part_size
-
-        self.staging_buffer = io.BytesIO()
+            self.staging_part_mgr.write_to_buffer(chunk)
 
     def upload_staged_files(self) -> None:
         """Upload the staged files to the object store."""
-        self._flush_staging_buffer(True)
+        self.staging_part_mgr.flush_buffer(True)
         futures = []
-        for i, staging_file in enumerate(self.staging_files):
+        for i, staging_file in enumerate(self.staging_part_mgr.get_staging_files()):
             part_number = i + 1
             futures.append(
                 self.executor.submit(
@@ -123,7 +141,7 @@ def upload_staged_files(self) -> None:
             part_info = future.result()
             self.parts.append(part_info)
 
-        self.staging_files = []
+        self.staging_part_mgr.clear_staging_files()
 
     def _upload_part_from_file(self, staging_file: str, part_number: int) -> PartInfo:
         with open(staging_file, "rb") as f:
diff --git a/tosfs/tests/test_tosfs.py b/tosfs/tests/test_tosfs.py
index 8b72379..95dede4 100644
--- a/tosfs/tests/test_tosfs.py
+++ b/tosfs/tests/test_tosfs.py
@@ -800,6 +800,31 @@ def test_file_write_mpu(
         assert f.read() == first_part + second_part + third_part
 
 
+def test_file_write_mpu_content(
+    tosfs: TosFileSystem, bucket: str, temporary_workspace: str
+) -> None:
+    file_name = random_str()
+
+    # Build content large enough to push the write path into multipart upload:
+    origin_content = (
+        random_str(5 * 1024 * 1024)
+        + random_str(5 * 1024 * 1024)
+        + random_str(3 * 1024 * 1024)
+    )
+    block_size = 4 * 1024 * 1024
+    with tosfs.open(
+        f"{bucket}/{temporary_workspace}/{file_name}", "w", block_size=block_size
+    ) as f:
+        f.write(origin_content)
+
+    assert tosfs.info(f"{bucket}/{temporary_workspace}/{file_name}")["size"] == len(
+        origin_content
+    )
+
+    with tosfs.open(f"{bucket}/{temporary_workspace}/{file_name}", "r") as f:
+        assert f.read() == origin_content
+
+
 def test_file_write_mpu_threshold_check(
     tosfs: TosFileSystem, bucket: str, temporary_workspace: str
 ):
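
A minimal standalone sketch of the new staging flow (not part of the patch;
it assumes the patch is applied so that StagingPartMgr is importable from
tosfs.mpu, and it relies on the staging files persisting because they are
created with NamedTemporaryFile(delete=False)):

import itertools
import os
import tempfile

from tosfs.mpu import StagingPartMgr

MIB = 1024 * 1024
part_size = 4 * MIB

mgr = StagingPartMgr(part_size, itertools.cycle([tempfile.mkdtemp()]))

# Feed five 3 MiB chunks; write_to_buffer flushes whenever the buffer
# reaches part_size, and flush_buffer(final=True) stages the remainder.
for _ in range(5):
    mgr.write_to_buffer(os.urandom(3 * MIB))
mgr.flush_buffer(final=True)

# Every staging file is now capped at part_size; the old
# _flush_staging_buffer wrote the whole buffer into one temp file, which
# produced oversized 6 MiB staging files from this same input.
sizes = [os.path.getsize(f) for f in mgr.get_staging_files()]
assert sizes == [4 * MIB, 4 * MIB, 4 * MIB, 3 * MIB]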