diff --git a/sshfs/file.py b/sshfs/file.py index 5187d5b..4b03762 100644 --- a/sshfs/file.py +++ b/sshfs/file.py @@ -26,32 +26,34 @@ def __init__( self.mode = mode self.max_requests = max_requests or _MAX_SFTP_REQUESTS - if block_size is None: - # "The OpenSSH SFTP server will close the connection - # if it receives a message larger than 256 KB, and - # limits read requests to returning no more than - # 64 KB." - # - # We are going to use the maximum block_size possible - # with a 16KB margin (so instead of sending 256 KB data, - # we'll send 240 KB + headers for write requests) - - if self.readable(): - block_size = READ_BLOCK_SIZE - else: - block_size = WRITE_BLOCK_SIZE - # The blocksize is often used with constructs like # shutil.copyfileobj(src, dst, length=file.blocksize) and since we are # using pipelining, we are going to reflect the total size rather than # a size of chunk to our limits. - self.blocksize = block_size * self.max_requests + self.blocksize = None if block_size is None else block_size * self.max_requests self.kwargs = kwargs self._file = sync(self.loop, self._open_file) self._closed = False + def _determine_block_size(self, channel): + # Use the asyncssh block sizes to ensure the best performance. + limits = getattr(channel, "limits", None) + if limits: + return limits.max_read_len if self.readable() else limits.max_write_len + + # "The OpenSSH SFTP server will close the connection + # if it receives a message larger than 256 KB, and + # limits read requests to returning no more than + # 64 KB." + # + # We are going to use the maximum block_size possible + # with a 16KB margin (so instead of sending 256 KB data, + # we'll send 240 KB + headers for write requests) + return READ_BLOCK_SIZE if self.readable() else WRITE_BLOCK_SIZE + + @wrap_exceptions async def _open_file(self): # TODO: this needs to keep a reference to the @@ -61,6 +63,8 @@ async def _open_file(self): # it's operations but the pool it thinking this # channel is freed. async with self.fs._pool.get() as channel: + if self.blocksize is None: + self.blocksize = self._determine_block_size(channel) * self.max_requests return await channel.open( self.path, self.mode,