Skip to content

Commit

Permalink
4-way handshake test
Browse files Browse the repository at this point in the history
Signed-off-by: iGxnon <igxnon@gmail.com>
  • Loading branch information
iGxnon committed Jul 16, 2024
1 parent a5590c6 commit e2e55f6
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 54 deletions.
4 changes: 2 additions & 2 deletions src/client/conn/tokio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::codec::tokio::Codec;
use crate::codec::{Decoded, Encoded};
use crate::errors::Error;
use crate::guard::HandleOutgoing;
use crate::io::{SplittedIO, IO};
use crate::io::{SeparatedIO, IO};
use crate::link::TransferLink;
use crate::state::{IncomingStateManage, OutgoingStateManage};
use crate::utils::{Logged, TraceStreamExt};
Expand Down Expand Up @@ -71,6 +71,6 @@ impl ConnectTo for TokioUdpSocket {
.handle_online(addr, config.client_guid, Arc::clone(&ack))
.enter_on_item(Span::noop);

Ok(SplittedIO::new(src, dst))
Ok(SeparatedIO::new(src, dst))
}
}
28 changes: 19 additions & 9 deletions src/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,16 @@ pub trait IO:
fn last_trace_id(&self) -> Option<TraceId>;

/// Split into a Stream and a Sink
fn split(self) -> (impl Stream<Item = Bytes>, impl Sink<Message, Error = Error>);
fn split(
self,
) -> (
impl Stream<Item = Bytes> + TraceInfo + Send,
impl Sink<Message, Error = Error> + Send,
);
}

pin_project! {
pub(crate) struct SplittedIO<I, O> {
pub(crate) struct SeparatedIO<I, O> {
#[pin]
src: I,
#[pin]
Expand All @@ -42,13 +47,13 @@ pin_project! {
}
}

impl<I, O> SplittedIO<I, O>
impl<I, O> SeparatedIO<I, O>
where
I: Stream<Item = Bytes> + TraceInfo + Send,
O: Sink<Message, Error = Error> + Send,
{
pub(crate) fn new(src: I, dst: O) -> Self {
SplittedIO {
SeparatedIO {
src,
dst,
default_reliability: Reliability::ReliableOrdered,
Expand All @@ -57,7 +62,7 @@ where
}
}

impl<I, O> Stream for SplittedIO<I, O>
impl<I, O> Stream for SeparatedIO<I, O>
where
I: Stream<Item = Bytes>,
{
Expand All @@ -68,7 +73,7 @@ where
}
}

impl<I, O> Sink<Bytes> for SplittedIO<I, O>
impl<I, O> Sink<Bytes> for SeparatedIO<I, O>
where
O: Sink<Message, Error = Error>,
{
Expand All @@ -92,7 +97,7 @@ where
}
}

impl<I, O> Sink<Message> for SplittedIO<I, O>
impl<I, O> Sink<Message> for SeparatedIO<I, O>
where
O: Sink<Message, Error = Error>,
{
Expand All @@ -115,7 +120,7 @@ where
}
}

impl<I, O> crate::io::IO for SplittedIO<I, O>
impl<I, O> crate::io::IO for SeparatedIO<I, O>
where
O: Sink<Message, Error = Error> + Send,
I: Stream<Item = Bytes> + TraceInfo + Send,
Expand All @@ -141,7 +146,12 @@ where
self.src.get_last_trace_id()
}

fn split(self) -> (impl Stream<Item = Bytes>, impl Sink<Message, Error = Error>) {
fn split(
self,
) -> (
impl Stream<Item = Bytes> + TraceInfo + Send,
impl Sink<Message, Error = Error> + Send,
) {
(self.src, self.dst)
}
}
29 changes: 10 additions & 19 deletions src/server/incoming/tokio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::pin::Pin;
use std::sync::Arc;
use std::task::{ready, Context, Poll};

use flume::{Receiver, Sender};
use flume::Sender;
use futures::Stream;
use log::{debug, error};
use minitrace::collector::SpanContext;
Expand All @@ -18,7 +18,7 @@ use crate::codec::tokio::Codec;
use crate::codec::{Decoded, Encoded};
use crate::errors::CodecError;
use crate::guard::HandleOutgoing;
use crate::io::{SplittedIO, IO};
use crate::io::{SeparatedIO, IO};
use crate::link::TransferLink;
use crate::packet::connected::{self, FramesMut};
use crate::packet::Packet;
Expand All @@ -39,25 +39,18 @@ pin_project! {
config: Config,
socket: Arc<TokioUdpSocket>,
router: HashMap<SocketAddr, Sender<connected::Packet<FramesMut>>>,
drop_receiver: Receiver<SocketAddr>,
drop_notifier: Sender<SocketAddr>,
}
}

impl Incoming {
fn clear_dropped_addr(self: Pin<&mut Self>) {
let mut this = self.project();
for addr in this.drop_receiver.try_iter() {
this.router.remove(&addr);
this.offline.as_mut().disconnect(&addr);
}
}
}
// impl Incoming {
// fn clear_dropped_addr(self: Pin<&mut Self>) {
// // TODO
// }
// }

impl MakeIncoming for TokioUdpSocket {
fn make_incoming(self, config: Config) -> impl Stream<Item = impl IO> {
let socket = Arc::new(self);
let (drop_notifier, drop_receiver) = flume::unbounded();
Incoming {
offline: UdpFramed::new(Arc::clone(&socket), Codec)
.logged_err(|err| {
Expand All @@ -67,17 +60,15 @@ impl MakeIncoming for TokioUdpSocket {
socket,
config,
router: HashMap::new(),
drop_receiver,
drop_notifier,
}
}
}

impl Stream for Incoming {
type Item = impl IO;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.as_mut().clear_dropped_addr();
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
// self.as_mut().clear_dropped_addr();

let mut this = self.project();
loop {
Expand Down Expand Up @@ -127,7 +118,7 @@ impl Stream for Incoming {
})
});

return Poll::Ready(Some(SplittedIO::new(src, dst)));
return Poll::Ready(Some(SeparatedIO::new(src, dst)));
}
}
}
15 changes: 6 additions & 9 deletions src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::pin::Pin;
use std::task::{ready, Context, Poll};

use futures::{Sink, Stream};
use log::{debug, warn};
use log::warn;
use pin_project_lite::pin_project;

use crate::errors::{CodecError, Error};
Expand Down Expand Up @@ -115,9 +115,7 @@ where
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
if matches!(self.state, OutgoingState::Closed) {
return Poll::Ready(Err(Error::ConnectionClosed));
}
// flush is allowed after the connection is closed, it will deliver ack.
self.project().frame.poll_flush(cx).map_err(Into::into)
}

Expand Down Expand Up @@ -188,7 +186,6 @@ where
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut this = self.project();
if matches!(this.state, IncomingState::Closed) {
debug!("state closed, poll_next to deliver ack");
// Poll the frame even if the state is closed to because the peer can send the
// DisconnectNotification as it did not receive ack.
// This will trigger the ack of the DisconnectNotification to be delivered.
Expand Down Expand Up @@ -299,14 +296,14 @@ mod test {
FrameBody::DisconnectNotification
));

let mut closed = SinkExt::<FrameBody>::close(&mut goodbye).await.unwrap_err();
let closed = SinkExt::<FrameBody>::close(&mut goodbye).await.unwrap_err();
// closed
assert!(matches!(closed, Error::ConnectionClosed));
// No more DisconnectNotification
assert_eq!(goodbye.frame.buf.len(), 1);

// closed
closed = SinkExt::<Message>::flush(&mut goodbye).await.unwrap_err();
assert!(matches!(closed, Error::ConnectionClosed));
std::future::poll_fn(|cx| SinkExt::<FrameBody>::poll_ready_unpin(&mut goodbye, cx))
.await
.unwrap_err();
}
}
137 changes: 122 additions & 15 deletions src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,23 +90,34 @@ pub(crate) fn test_trace_log_setup() -> TestTraceLogGuard {
TestTraceLogGuard { spans }
}

fn make_server_conf() -> server::Config {
server::Config::new()
.send_buf_cap(1024)
.sever_guid(1919810)
.advertisement(&b"123456"[..])
.max_mtu(1500)
.min_mtu(510)
.max_pending(1024)
.support_version(vec![9, 11, 13])
}

fn make_client_conf() -> client::Config {
client::Config::new()
.send_buf_cap(1024)
.mtu(1000)
.client_guid(114514)
.protocol_version(11)
}

#[tokio::test(unhandled_panic = "shutdown_runtime")]
async fn test_tokio_udp_works() {
let _guard = test_trace_log_setup();

let echo_server = async {
let config = server::Config::new()
.send_buf_cap(1024)
.sever_guid(1919810)
.advertisement(&b"123456"[..])
.max_mtu(1500)
.min_mtu(510)
.max_pending(1024)
.support_version(vec![9, 11, 13]);
let mut incoming = UdpSocket::bind("0.0.0.0:19132")
.await
.unwrap()
.make_incoming(config);
.make_incoming(make_server_conf());
loop {
let io = incoming.next().await.unwrap();
tokio::spawn(async move {
Expand All @@ -130,15 +141,10 @@ async fn test_tokio_udp_works() {
tokio::spawn(echo_server);

let client = async {
let config = client::Config::new()
.send_buf_cap(1024)
.mtu(1000)
.client_guid(114514)
.protocol_version(11);
let mut io = UdpSocket::bind("0.0.0.0:0")
.await
.unwrap()
.connect_to("127.0.0.1:19132", config)
.connect_to("127.0.0.1:19132", make_client_conf())
.await
.unwrap();
io.send(Bytes::from_iter(repeat(0xfe).take(256)))
Expand Down Expand Up @@ -180,3 +186,104 @@ async fn test_tokio_udp_works() {

tokio::spawn(client).await.unwrap();
}

#[tokio::test(unhandled_panic = "shutdown_runtime")]
async fn test_4way_handshake_client_close() {
let _guard = test_trace_log_setup();

let server = async {
let mut incoming = UdpSocket::bind("0.0.0.0:19133")
.await
.unwrap()
.make_incoming(make_server_conf());
loop {
let io = incoming.next().await.unwrap();
tokio::spawn(async move {
tokio::pin!(io);
let mut ticker = tokio::time::interval(Duration::from_millis(10));
loop {
tokio::select! {
None = io.next() => {
break;
}
_ = ticker.tick() => {
SinkExt::<Bytes>::flush(&mut io).await.unwrap();
}
};
}
info!("connection closed by client, close the io");
loop {
tokio::select! {
_ = ticker.tick() => {
// That's ridiculous because all calculations are lazy, we should call `poll_next` on io to receive the ack
// so that close will return.
// TODO: eagerly deliver the ack
assert!(io.next().await.is_none());
}
_ = SinkExt::<Bytes>::close(&mut io) => {
break;
}
}
}
info!("io closed");
});
}
};

let client = async {
let mut io = UdpSocket::bind("0.0.0.0:0")
.await
.unwrap()
.connect_to("127.0.0.1:19133", make_client_conf())
.await
.unwrap();

io.send(Bytes::from_iter(repeat(0xfe).take(1024)))
.await
.unwrap();

// split it to avoid annoying borrow checker
let (src, dst) = IO::split(io);

tokio::pin!(src);
tokio::pin!(dst);

loop {
tokio::select! {
_ = src.next() => {}, // same as above, we should call `poll_next` on src to receive the ack so that dst.close will return
_ = dst.close() => {
break;
}
}
}

info!("client closed the connection, wait for server to close");

let mut ticker = tokio::time::interval(Duration::from_millis(10));
let mut last_ticker = tokio::time::interval(Duration::from_millis(20));
let mut tick = 0;
loop {
tokio::select! {
None = src.next(), if tick == 0 => {
info!("received close notification from server, wait for 2MSL");
tick = 1;
}
_ = ticker.tick() => {
dst.flush().await.unwrap();
}
_ = last_ticker.tick(), if tick > 0 => {
// same as above, deliver the incoming ack
assert!(src.next().await.is_none());
tick += 1;
// last 2MSL
if tick > 10 {
break;
}
}
};
}
};

tokio::spawn(server);
tokio::spawn(client).await.unwrap();
}

0 comments on commit e2e55f6

Please sign in to comment.