Skip to content

Commit

Permalink
Handle alias in the worker (#203)
Browse files Browse the repository at this point in the history
* Handle alias in the worker

* Add _peer_count to track peers
  • Loading branch information
ludeeus authored Aug 23, 2023
1 parent d05765e commit eec9050
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 10 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from setuptools import setup

VERSION = "0.36.0"
VERSION = "0.36.1"

setup(
name="snitun",
Expand Down
5 changes: 5 additions & 0 deletions snitun/server/peer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ def alias(self) -> List[str]:
"""Return the alias."""
return self._alias

@property
def all_hostnames(self) -> List[str]:
"""Return a list of the base hostname and any alias."""
return [self._hostname, *self._alias]

@property
def is_connected(self) -> bool:
"""Return True if we are connected to peer."""
Expand Down
5 changes: 2 additions & 3 deletions snitun/server/peer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,8 @@ def remove_peer(self, peer: Peer) -> None:
if self._peers.get(peer.hostname) != peer:
return
_LOGGER.debug("Close peer connection: %s", peer.hostname)
self._peers.pop(peer.hostname)
for alias in peer.alias:
self._peers.pop(alias, None)
for hostname in peer.all_hostnames:
self._peers.pop(hostname, None)

if self._event_callback:
self._loop.call_soon(
Expand Down
18 changes: 13 additions & 5 deletions snitun/server/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
from multiprocessing import Process, Manager, Queue
from threading import Thread
from typing import Dict, Optional, List
from typing import TYPE_CHECKING, Dict, Optional, List
from socket import socket

from .listener_peer import PeerListener
Expand All @@ -13,6 +13,9 @@

_LOGGER = logging.getLogger(__name__)

if TYPE_CHECKING:
from multiprocessing.managers import SyncManager


class ServerWorker(Process):
"""Worker for multiplexer."""
Expand All @@ -35,14 +38,15 @@ def __init__(
self._loop: Optional[asyncio.BaseEventLoop] = None

# Communication between Parent/Child
self._manager: Manager = Manager()
self._manager: SyncManager = Manager()
self._new: Queue = self._manager.Queue()
self._sync: Dict[str, None] = self._manager.dict()
self._peer_count = self._manager.Value("peer_count", 0)

@property
def peer_size(self) -> int:
"""Return amount of managed peers."""
return len(self._sync)
return self._peer_count.value

def is_responsible_peer(self, sni: str) -> bool:
"""Return True if worker is responsible for this peer domain."""
Expand All @@ -61,9 +65,13 @@ async def _async_init(self) -> None:
def _event_stream(self, peer: Peer, event: PeerManagerEvent) -> None:
"""Event stream peer connection data."""
if event == PeerManagerEvent.CONNECTED:
self._sync[peer.hostname] = None
self._peer_count.set(self._peer_count.value + 1)
for hostname in peer.all_hostnames:
self._sync[hostname] = None
else:
self._sync.pop(peer.hostname, None)
self._peer_count.set(self._peer_count.value - 1)
for hostname in peer.all_hostnames:
self._sync.pop(hostname, None)

def shutdown(self) -> None:
"""Shutdown child process."""
Expand Down
7 changes: 6 additions & 1 deletion tests/server/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,10 @@ def test_sni_connection(
aes_key = os.urandom(32)
aes_iv = os.urandom(16)
hostname = "localhost"
fernet_token = create_peer_config(valid.timestamp(), hostname, aes_key, aes_iv)
alias = ["localhost.custom"]
fernet_token = create_peer_config(
valid.timestamp(), hostname, aes_key, aes_iv, alias=alias
)

worker.start()
crypto = CryptoTransport(aes_key, aes_iv)
Expand All @@ -102,6 +105,8 @@ def test_sni_connection(

time.sleep(1)
assert worker.is_responsible_peer(hostname)
for entry in alias:
assert worker.is_responsible_peer(entry)

worker.handover_connection(test_server_sync[1], TLS_1_2, hostname)
assert len(test_client_sync.recv(1048)) == 32
Expand Down

0 comments on commit eec9050

Please sign in to comment.