From 0d07616b4af033ca1591cefadf4fa3b59d422a80 Mon Sep 17 00:00:00 2001
From: vinoyang
Date: Sat, 19 Oct 2024 16:58:38 +0800
Subject: [PATCH] Refactor MPU logic (#227)

---
 tosfs/core.py             |  14 ++----
 tosfs/mpu.py              | 100 ++++++++++++++++++++++----------------
 tosfs/tests/test_tosfs.py |  25 ++++++++++
 3 files changed, 86 insertions(+), 53 deletions(-)

diff --git a/tosfs/core.py b/tosfs/core.py
index 5dad01a..1878b57 100644
--- a/tosfs/core.py
+++ b/tosfs/core.py
@@ -118,8 +118,7 @@ def __init__(
         multipart_staging_dirs: str = tempfile.mkdtemp(),
         multipart_size: int = 8 << 20,
         multipart_thread_pool_size: int = max(2, os.cpu_count() or 1),
-        multipart_staging_buffer_size: int = 4 << 10,
-        multipart_threshold: int = 10 << 20,
+        multipart_threshold: int = 5 << 20,
         enable_crc: bool = True,
         enable_verify_ssl: bool = True,
         dns_cache_timeout: int = 0,
@@ -186,11 +185,6 @@ def __init__(
         multipart_thread_pool_size : int, optional
             The size of thread pool used for uploading multipart in parallel for the
             given object storage. (default is max(2, os.cpu_count()).
-        multipart_staging_buffer_size : int, optional
-            The max byte size which will buffer the staging data in-memory before
-            flushing to the staging file. It will decrease the random write in local
-            staging disk dramatically if writing plenty of small files.
-            default is 4096.
         multipart_threshold : int, optional
             The threshold which control whether enable multipart upload during
             writing data to the given object storage, if the write data size is less
@@ -270,7 +264,6 @@ def __init__(
         ]
         self.multipart_size = multipart_size
         self.multipart_thread_pool_size = multipart_thread_pool_size
-        self.multipart_staging_buffer_size = multipart_staging_buffer_size
         self.multipart_threshold = multipart_threshold
 
         super().__init__(**kwargs)
@@ -2252,7 +2245,6 @@ def __init__(
             key=key,
             part_size=fs.multipart_size,
             thread_pool_size=fs.multipart_thread_pool_size,
-            staging_buffer_size=fs.multipart_staging_buffer_size,
             multipart_threshold=fs.multipart_threshold,
         )
 
@@ -2322,7 +2314,7 @@ def _upload_chunk(self, final: bool = False) -> bool:
             self.autocommit
             and final
             and self.tell()
-            < max(self.blocksize, self.multipart_uploader.multipart_threshold)
+            < min(self.blocksize, self.multipart_uploader.multipart_threshold)
         ):
             # only happens when closing small file, use one-shot PUT
             pass
@@ -2387,7 +2379,7 @@ def commit(self) -> None:
             logger.debug("Empty file committed %s", self)
             self.multipart_uploader.abort_upload()
             self.fs.touch(self.path, **self.kwargs)
-        elif not self.multipart_uploader.staging_files:
+        elif not self.multipart_uploader.staging_part_mgr.staging_files:
             if self.buffer is not None:
                 logger.debug("One-shot upload of %s", self)
                 self.buffer.seek(0)
diff --git a/tosfs/mpu.py b/tosfs/mpu.py
index 1510db9..8b44adc 100644
--- a/tosfs/mpu.py
+++ b/tosfs/mpu.py
@@ -29,6 +29,57 @@
     from tosfs.core import TosFileSystem
 
 
+class StagingPartMgr:
+    """A class to handle staging parts for multipart upload."""
+
+    def __init__(self, part_size: int, staging_dirs: itertools.cycle):
+        """Instantiate a StagingPartMgr object."""
+        self.part_size = part_size
+        self.staging_dirs = staging_dirs
+        self.staging_buffer = io.BytesIO()
+        self.staging_files: list[str] = []
+
+    def write_to_buffer(self, chunk: bytes) -> None:
+        """Write data to the staging buffer."""
+        self.staging_buffer.write(chunk)
+        if self.staging_buffer.tell() >= self.part_size:
+            self.flush_buffer(False)
+
+    def flush_buffer(self, final: bool = False) -> None:
+        """Flush the staging buffer."""
+        if self.staging_buffer.tell() == 0:
+            return
+
+        buffer_size = self.staging_buffer.tell()
+        self.staging_buffer.seek(0)
+
+        while buffer_size >= self.part_size:
+            staging_dir = next(self.staging_dirs)
+            with tempfile.NamedTemporaryFile(delete=False, dir=staging_dir) as tmp:
+                tmp.write(self.staging_buffer.read(self.part_size))
+                self.staging_files.append(tmp.name)
+            buffer_size -= self.part_size
+
+        if not final:
+            remaining_data = self.staging_buffer.read()
+            self.staging_buffer = io.BytesIO()
+            self.staging_buffer.write(remaining_data)
+        else:
+            staging_dir = next(self.staging_dirs)
+            with tempfile.NamedTemporaryFile(delete=False, dir=staging_dir) as tmp:
+                tmp.write(self.staging_buffer.read())
+                self.staging_files.append(tmp.name)
+            self.staging_buffer = io.BytesIO()
+
+    def get_staging_files(self) -> list[str]:
+        """Get the staging files."""
+        return self.staging_files
+
+    def clear_staging_files(self) -> None:
+        """Clear the staging files."""
+        self.staging_files = []
+
+
 class MultipartUploader:
     """A class to upload large files to the object store using multipart upload."""
 
@@ -39,7 +90,6 @@ def __init__(
         key: str,
         part_size: int,
         thread_pool_size: int,
-        staging_buffer_size: int,
         multipart_threshold: int,
     ):
         """Instantiate a MultipartUploader object."""
@@ -48,12 +98,11 @@ def __init__(
         self.key = key
         self.part_size = part_size
         self.thread_pool_size = thread_pool_size
-        self.staging_buffer_size = staging_buffer_size
        self.multipart_threshold = multipart_threshold
         self.executor = ThreadPoolExecutor(max_workers=self.thread_pool_size)
-        self.staging_dirs = itertools.cycle(fs.multipart_staging_dirs)
-        self.staging_files: list[str] = []
-        self.staging_buffer: io.BytesIO = io.BytesIO()
+        self.staging_part_mgr = StagingPartMgr(
+            part_size, itertools.cycle(fs.multipart_staging_dirs)
+        )
         self.parts: list = []
         self.mpu: CreateMultipartUploadOutput = None
 
@@ -72,46 +121,13 @@ def upload_multiple_chunks(self, buffer: Optional[io.BytesIO]) -> None:
             chunk = buffer.read(self.part_size)
             if not chunk:
                 break
-            self._write_to_staging_buffer(chunk)
-
-    def _write_to_staging_buffer(self, chunk: bytes) -> None:
-        self.staging_buffer.write(chunk)
-        if self.staging_buffer.tell() >= self.part_size:
-            self._flush_staging_buffer(False)
-
-    def _flush_staging_buffer(self, final: bool = False) -> None:
-        if self.staging_buffer.tell() == 0:
-            return
-
-        buffer_size = self.staging_buffer.tell()
-        self.staging_buffer.seek(0)
-
-        while buffer_size >= self.part_size:
-            staging_dir = next(self.staging_dirs)
-            with tempfile.NamedTemporaryFile(delete=False, dir=staging_dir) as tmp:
-                tmp.write(self.staging_buffer.read())
-                self.staging_files.append(tmp.name)
-            buffer_size -= self.part_size
-
-        if not final:
-            # Move remaining data to a new buffer
-            remaining_data = self.staging_buffer.read()
-            self.staging_buffer = io.BytesIO()
-            self.staging_buffer.write(remaining_data)
-        else:
-            staging_dir = next(self.staging_dirs)
-            with tempfile.NamedTemporaryFile(delete=False, dir=staging_dir) as tmp:
-                tmp.write(self.staging_buffer.read())
-                self.staging_files.append(tmp.name)
-            buffer_size -= self.part_size
-
-        self.staging_buffer = io.BytesIO()
+            self.staging_part_mgr.write_to_buffer(chunk)
 
     def upload_staged_files(self) -> None:
         """Upload the staged files to the object store."""
-        self._flush_staging_buffer(True)
+        self.staging_part_mgr.flush_buffer(True)
         futures = []
-        for i, staging_file in enumerate(self.staging_files):
+        for i, staging_file in enumerate(self.staging_part_mgr.get_staging_files()):
             part_number = i + 1
             futures.append(
                 self.executor.submit(
@@ -123,7 +139,7 @@ def upload_staged_files(self) -> None:
             part_info = future.result()
             self.parts.append(part_info)
 
-        self.staging_files = []
+        self.staging_part_mgr.clear_staging_files()
 
     def _upload_part_from_file(self, staging_file: str, part_number: int) -> PartInfo:
         with open(staging_file, "rb") as f:
diff --git a/tosfs/tests/test_tosfs.py b/tosfs/tests/test_tosfs.py
index 8b72379..95dede4 100644
--- a/tosfs/tests/test_tosfs.py
+++ b/tosfs/tests/test_tosfs.py
@@ -800,6 +800,31 @@ def test_file_write_mpu(
         assert f.read() == first_part + second_part + third_part
 
 
+def test_file_write_mpu_content(
+    tosfs: TosFileSystem, bucket: str, temporary_workspace: str
+) -> None:
+    file_name = random_str()
+
+    # mock content large enough for the write logic to trigger MPU:
+    origin_content = (
+        random_str(5 * 1024 * 1024)
+        + random_str(5 * 1024 * 1024)
+        + random_str(3 * 1024 * 1024)
+    )
+    block_size = 4 * 1024 * 1024
+    with tosfs.open(
+        f"{bucket}/{temporary_workspace}/{file_name}", "w", block_size=block_size
+    ) as f:
+        f.write(origin_content)
+
+    assert tosfs.info(f"{bucket}/{temporary_workspace}/{file_name}")["size"] == len(
+        origin_content
+    )
+
+    with tosfs.open(f"{bucket}/{temporary_workspace}/{file_name}", "r") as f:
+        assert f.read() == origin_content
+
+
 def test_file_write_mpu_threshold_check(
     tosfs: TosFileSystem, bucket: str, temporary_workspace: str
 ):
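
Reviewer note: below is a minimal standalone sketch of the staging semantics this patch centralizes in `StagingPartMgr`, runnable without a TOS bucket. The class body mirrors the hunk above; the 8-byte `part_size` and the `__main__` driver are illustrative assumptions for demonstration, not part of the patch.

```python
# Standalone sketch of the staging behavior introduced by StagingPartMgr
# (mirrors the class added in tosfs/mpu.py). The tiny part size below is an
# assumption for demonstration; tosfs defaults to 8 MiB parts.
import io
import itertools
import os
import tempfile


class StagingPartMgr:
    """Buffer chunks in memory; spill each full part to its own staging file."""

    def __init__(self, part_size: int, staging_dirs: itertools.cycle):
        self.part_size = part_size
        self.staging_dirs = staging_dirs
        self.staging_buffer = io.BytesIO()
        self.staging_files: list[str] = []

    def write_to_buffer(self, chunk: bytes) -> None:
        self.staging_buffer.write(chunk)
        if self.staging_buffer.tell() >= self.part_size:
            self.flush_buffer(False)

    def flush_buffer(self, final: bool = False) -> None:
        if self.staging_buffer.tell() == 0:
            return

        buffer_size = self.staging_buffer.tell()
        self.staging_buffer.seek(0)

        # Spill every complete part, reading exactly part_size bytes per file
        # (the removed code read the entire buffer at this point).
        while buffer_size >= self.part_size:
            staging_dir = next(self.staging_dirs)
            with tempfile.NamedTemporaryFile(delete=False, dir=staging_dir) as tmp:
                tmp.write(self.staging_buffer.read(self.part_size))
                self.staging_files.append(tmp.name)
            buffer_size -= self.part_size

        if not final:
            # Carry the trailing partial part over to a fresh buffer.
            remaining_data = self.staging_buffer.read()
            self.staging_buffer = io.BytesIO()
            self.staging_buffer.write(remaining_data)
        else:
            # On the final flush, the partial part becomes the last staging file.
            staging_dir = next(self.staging_dirs)
            with tempfile.NamedTemporaryFile(delete=False, dir=staging_dir) as tmp:
                tmp.write(self.staging_buffer.read())
                self.staging_files.append(tmp.name)
            self.staging_buffer = io.BytesIO()


if __name__ == "__main__":
    # 19 bytes with an 8-byte part size should yield staging files of
    # 8, 8 and 3 bytes: two full parts spilled eagerly, the remainder
    # written on the final flush.
    mgr = StagingPartMgr(8, itertools.cycle([tempfile.mkdtemp()]))
    mgr.write_to_buffer(b"0123456789abcdefghi")
    mgr.flush_buffer(final=True)
    sizes = [os.path.getsize(p) for p in mgr.staging_files]
    assert sizes == [8, 8, 3], sizes
    print(sizes)
```

Bounding each read to `part_size` is what keeps every staging file mapped to exactly one upload part, which the parallel `upload_staged_files` loop then relies on when it assigns part numbers by enumeration.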