diff --git a/sshfs/pools/soft.py b/sshfs/pools/soft.py index 14d300d..1c8bfb8 100644 --- a/sshfs/pools/soft.py +++ b/sshfs/pools/soft.py @@ -1,3 +1,4 @@ +import asyncio import heapq from collections import Counter from contextlib import asynccontextmanager @@ -19,20 +20,23 @@ class SFTPSoftChannelPool(BaseSFTPChannelPool): def __init__(self, *args, **kwargs): self._channels = Counter() + self._channels_lock = asyncio.Lock() super().__init__(*args, **kwargs) @asynccontextmanager async def get(self): - [(least_used_channel, num_connections)] = ( - heapq.nsmallest(1, self._channels.items(), lambda kv: kv[1]) - or self._NO_CHANNELS - ) - + least_used_channel, num_connections = self._least_used() if least_used_channel is None or num_connections >= self._THRESHOLD: - channel = await self._maybe_new_channel() - if channel is not None: - least_used_channel = channel - num_connections = 0 + async with self._channels_lock: + channel = await self._maybe_new_channel() + if channel is not None: + least_used_channel = channel + num_connections = 0 + self._channels[least_used_channel] = 0 + + if channel is None: + # another coroutine may have opened a channel while we waited + least_used_channel, num_connections = self._least_used() if least_used_channel is None: raise ValueError("Can't create any SFTP connections!") @@ -46,6 +50,13 @@ async def get(self): async def _cleanup(self): self._channels.clear() + def _least_used(self): + [(least_used_channel, num_connections)] = ( + heapq.nsmallest(1, self._channels.items(), lambda kv: kv[1]) + or self._NO_CHANNELS + ) + return least_used_channel, num_connections + @property def active_channels(self): return len(self._channels)