From e2e55f6b8d6cd72fc03b855766947ccf95716870 Mon Sep 17 00:00:00 2001 From: iGxnon Date: Tue, 16 Jul 2024 11:49:13 +0800 Subject: [PATCH] 4-way handshake test Signed-off-by: iGxnon --- src/client/conn/tokio.rs | 4 +- src/io.rs | 28 ++++--- src/server/incoming/tokio.rs | 29 +++----- src/state.rs | 15 ++-- src/tests.rs | 137 +++++++++++++++++++++++++++++++---- 5 files changed, 159 insertions(+), 54 deletions(-) diff --git a/src/client/conn/tokio.rs b/src/client/conn/tokio.rs index c3d5e02..f600d34 100644 --- a/src/client/conn/tokio.rs +++ b/src/client/conn/tokio.rs @@ -14,7 +14,7 @@ use crate::codec::tokio::Codec; use crate::codec::{Decoded, Encoded}; use crate::errors::Error; use crate::guard::HandleOutgoing; -use crate::io::{SplittedIO, IO}; +use crate::io::{SeparatedIO, IO}; use crate::link::TransferLink; use crate::state::{IncomingStateManage, OutgoingStateManage}; use crate::utils::{Logged, TraceStreamExt}; @@ -71,6 +71,6 @@ impl ConnectTo for TokioUdpSocket { .handle_online(addr, config.client_guid, Arc::clone(&ack)) .enter_on_item(Span::noop); - Ok(SplittedIO::new(src, dst)) + Ok(SeparatedIO::new(src, dst)) } } diff --git a/src/io.rs b/src/io.rs index 564358a..e6d5c99 100644 --- a/src/io.rs +++ b/src/io.rs @@ -28,11 +28,16 @@ pub trait IO: fn last_trace_id(&self) -> Option; /// Split into a Stream and a Sink - fn split(self) -> (impl Stream, impl Sink); + fn split( + self, + ) -> ( + impl Stream + TraceInfo + Send, + impl Sink + Send, + ); } pin_project! { - pub(crate) struct SplittedIO { + pub(crate) struct SeparatedIO { #[pin] src: I, #[pin] @@ -42,13 +47,13 @@ pin_project! { } } -impl SplittedIO +impl SeparatedIO where I: Stream + TraceInfo + Send, O: Sink + Send, { pub(crate) fn new(src: I, dst: O) -> Self { - SplittedIO { + SeparatedIO { src, dst, default_reliability: Reliability::ReliableOrdered, @@ -57,7 +62,7 @@ where } } -impl Stream for SplittedIO +impl Stream for SeparatedIO where I: Stream, { @@ -68,7 +73,7 @@ where } } -impl Sink for SplittedIO +impl Sink for SeparatedIO where O: Sink, { @@ -92,7 +97,7 @@ where } } -impl Sink for SplittedIO +impl Sink for SeparatedIO where O: Sink, { @@ -115,7 +120,7 @@ where } } -impl crate::io::IO for SplittedIO +impl crate::io::IO for SeparatedIO where O: Sink + Send, I: Stream + TraceInfo + Send, @@ -141,7 +146,12 @@ where self.src.get_last_trace_id() } - fn split(self) -> (impl Stream, impl Sink) { + fn split( + self, + ) -> ( + impl Stream + TraceInfo + Send, + impl Sink + Send, + ) { (self.src, self.dst) } } diff --git a/src/server/incoming/tokio.rs b/src/server/incoming/tokio.rs index 703a693..ea93421 100644 --- a/src/server/incoming/tokio.rs +++ b/src/server/incoming/tokio.rs @@ -4,7 +4,7 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{ready, Context, Poll}; -use flume::{Receiver, Sender}; +use flume::Sender; use futures::Stream; use log::{debug, error}; use minitrace::collector::SpanContext; @@ -18,7 +18,7 @@ use crate::codec::tokio::Codec; use crate::codec::{Decoded, Encoded}; use crate::errors::CodecError; use crate::guard::HandleOutgoing; -use crate::io::{SplittedIO, IO}; +use crate::io::{SeparatedIO, IO}; use crate::link::TransferLink; use crate::packet::connected::{self, FramesMut}; use crate::packet::Packet; @@ -39,25 +39,18 @@ pin_project! { config: Config, socket: Arc, router: HashMap>>, - drop_receiver: Receiver, - drop_notifier: Sender, } } -impl Incoming { - fn clear_dropped_addr(self: Pin<&mut Self>) { - let mut this = self.project(); - for addr in this.drop_receiver.try_iter() { - this.router.remove(&addr); - this.offline.as_mut().disconnect(&addr); - } - } -} +// impl Incoming { +// fn clear_dropped_addr(self: Pin<&mut Self>) { +// // TODO +// } +// } impl MakeIncoming for TokioUdpSocket { fn make_incoming(self, config: Config) -> impl Stream { let socket = Arc::new(self); - let (drop_notifier, drop_receiver) = flume::unbounded(); Incoming { offline: UdpFramed::new(Arc::clone(&socket), Codec) .logged_err(|err| { @@ -67,8 +60,6 @@ impl MakeIncoming for TokioUdpSocket { socket, config, router: HashMap::new(), - drop_receiver, - drop_notifier, } } } @@ -76,8 +67,8 @@ impl MakeIncoming for TokioUdpSocket { impl Stream for Incoming { type Item = impl IO; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.as_mut().clear_dropped_addr(); + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // self.as_mut().clear_dropped_addr(); let mut this = self.project(); loop { @@ -127,7 +118,7 @@ impl Stream for Incoming { }) }); - return Poll::Ready(Some(SplittedIO::new(src, dst))); + return Poll::Ready(Some(SeparatedIO::new(src, dst))); } } } diff --git a/src/state.rs b/src/state.rs index fe68865..02333ff 100644 --- a/src/state.rs +++ b/src/state.rs @@ -6,7 +6,7 @@ use std::pin::Pin; use std::task::{ready, Context, Poll}; use futures::{Sink, Stream}; -use log::{debug, warn}; +use log::warn; use pin_project_lite::pin_project; use crate::errors::{CodecError, Error}; @@ -115,9 +115,7 @@ where } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - if matches!(self.state, OutgoingState::Closed) { - return Poll::Ready(Err(Error::ConnectionClosed)); - } + // flush is allowed after the connection is closed, it will deliver ack. self.project().frame.poll_flush(cx).map_err(Into::into) } @@ -188,7 +186,6 @@ where fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut this = self.project(); if matches!(this.state, IncomingState::Closed) { - debug!("state closed, poll_next to deliver ack"); // Poll the frame even if the state is closed to because the peer can send the // DisconnectNotification as it did not receive ack. // This will trigger the ack of the DisconnectNotification to be delivered. @@ -299,14 +296,14 @@ mod test { FrameBody::DisconnectNotification )); - let mut closed = SinkExt::::close(&mut goodbye).await.unwrap_err(); + let closed = SinkExt::::close(&mut goodbye).await.unwrap_err(); // closed assert!(matches!(closed, Error::ConnectionClosed)); // No more DisconnectNotification assert_eq!(goodbye.frame.buf.len(), 1); - // closed - closed = SinkExt::::flush(&mut goodbye).await.unwrap_err(); - assert!(matches!(closed, Error::ConnectionClosed)); + std::future::poll_fn(|cx| SinkExt::::poll_ready_unpin(&mut goodbye, cx)) + .await + .unwrap_err(); } } diff --git a/src/tests.rs b/src/tests.rs index 54a4fa2..43bf40d 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -90,23 +90,34 @@ pub(crate) fn test_trace_log_setup() -> TestTraceLogGuard { TestTraceLogGuard { spans } } +fn make_server_conf() -> server::Config { + server::Config::new() + .send_buf_cap(1024) + .sever_guid(1919810) + .advertisement(&b"123456"[..]) + .max_mtu(1500) + .min_mtu(510) + .max_pending(1024) + .support_version(vec![9, 11, 13]) +} + +fn make_client_conf() -> client::Config { + client::Config::new() + .send_buf_cap(1024) + .mtu(1000) + .client_guid(114514) + .protocol_version(11) +} + #[tokio::test(unhandled_panic = "shutdown_runtime")] async fn test_tokio_udp_works() { let _guard = test_trace_log_setup(); let echo_server = async { - let config = server::Config::new() - .send_buf_cap(1024) - .sever_guid(1919810) - .advertisement(&b"123456"[..]) - .max_mtu(1500) - .min_mtu(510) - .max_pending(1024) - .support_version(vec![9, 11, 13]); let mut incoming = UdpSocket::bind("0.0.0.0:19132") .await .unwrap() - .make_incoming(config); + .make_incoming(make_server_conf()); loop { let io = incoming.next().await.unwrap(); tokio::spawn(async move { @@ -130,15 +141,10 @@ async fn test_tokio_udp_works() { tokio::spawn(echo_server); let client = async { - let config = client::Config::new() - .send_buf_cap(1024) - .mtu(1000) - .client_guid(114514) - .protocol_version(11); let mut io = UdpSocket::bind("0.0.0.0:0") .await .unwrap() - .connect_to("127.0.0.1:19132", config) + .connect_to("127.0.0.1:19132", make_client_conf()) .await .unwrap(); io.send(Bytes::from_iter(repeat(0xfe).take(256))) @@ -180,3 +186,104 @@ async fn test_tokio_udp_works() { tokio::spawn(client).await.unwrap(); } + +#[tokio::test(unhandled_panic = "shutdown_runtime")] +async fn test_4way_handshake_client_close() { + let _guard = test_trace_log_setup(); + + let server = async { + let mut incoming = UdpSocket::bind("0.0.0.0:19133") + .await + .unwrap() + .make_incoming(make_server_conf()); + loop { + let io = incoming.next().await.unwrap(); + tokio::spawn(async move { + tokio::pin!(io); + let mut ticker = tokio::time::interval(Duration::from_millis(10)); + loop { + tokio::select! { + None = io.next() => { + break; + } + _ = ticker.tick() => { + SinkExt::::flush(&mut io).await.unwrap(); + } + }; + } + info!("connection closed by client, close the io"); + loop { + tokio::select! { + _ = ticker.tick() => { + // That's ridiculous because all calculations are lazy, we should call `poll_next` on io to receive the ack + // so that close will return. + // TODO: eagerly deliver the ack + assert!(io.next().await.is_none()); + } + _ = SinkExt::::close(&mut io) => { + break; + } + } + } + info!("io closed"); + }); + } + }; + + let client = async { + let mut io = UdpSocket::bind("0.0.0.0:0") + .await + .unwrap() + .connect_to("127.0.0.1:19133", make_client_conf()) + .await + .unwrap(); + + io.send(Bytes::from_iter(repeat(0xfe).take(1024))) + .await + .unwrap(); + + // split it to avoid annoying borrow checker + let (src, dst) = IO::split(io); + + tokio::pin!(src); + tokio::pin!(dst); + + loop { + tokio::select! { + _ = src.next() => {}, // same as above, we should call `poll_next` on src to receive the ack so that dst.close will return + _ = dst.close() => { + break; + } + } + } + + info!("client closed the connection, wait for server to close"); + + let mut ticker = tokio::time::interval(Duration::from_millis(10)); + let mut last_ticker = tokio::time::interval(Duration::from_millis(20)); + let mut tick = 0; + loop { + tokio::select! { + None = src.next(), if tick == 0 => { + info!("received close notification from server, wait for 2MSL"); + tick = 1; + } + _ = ticker.tick() => { + dst.flush().await.unwrap(); + } + _ = last_ticker.tick(), if tick > 0 => { + // same as above, deliver the incoming ack + assert!(src.next().await.is_none()); + tick += 1; + // last 2MSL + if tick > 10 { + break; + } + } + }; + } + }; + + tokio::spawn(server); + tokio::spawn(client).await.unwrap(); +}