diff --git a/snitun/server/worker.py b/snitun/server/worker.py index 082346f5..77299cf8 100644 --- a/snitun/server/worker.py +++ b/snitun/server/worker.py @@ -3,7 +3,7 @@ import logging from multiprocessing import Process, Manager, Queue from threading import Thread -from typing import Dict, Optional, List +from typing import TYPE_CHECKING, Dict, Optional, List from socket import socket from .listener_peer import PeerListener @@ -13,6 +13,9 @@ _LOGGER = logging.getLogger(__name__) +if TYPE_CHECKING: + from multiprocessing.managers import SyncManager + class ServerWorker(Process): """Worker for multiplexer.""" @@ -35,14 +38,15 @@ def __init__( self._loop: Optional[asyncio.BaseEventLoop] = None # Communication between Parent/Child - self._manager: Manager = Manager() + self._manager: SyncManager = Manager() self._new: Queue = self._manager.Queue() self._sync: Dict[str, None] = self._manager.dict() + self._peer_count = self._manager.Value("peer_count", 0) @property def peer_size(self) -> int: """Return amount of managed peers.""" - return len(self._sync) + return self._peer_count.value def is_responsible_peer(self, sni: str) -> bool: """Return True if worker is responsible for this peer domain.""" @@ -61,9 +65,11 @@ async def _async_init(self) -> None: def _event_stream(self, peer: Peer, event: PeerManagerEvent) -> None: """Event stream peer connection data.""" if event == PeerManagerEvent.CONNECTED: + self._peer_count.set(self._peer_count.value + 1) for hostname in peer.all_hostnames: self._sync[hostname] = None else: + self._peer_count.set(self._peer_count.value - 1) for hostname in peer.all_hostnames: self._sync.pop(hostname, None)