diff --git a/benches/bulk.rs b/benches/bulk.rs index 89dc9f2..aa0511e 100644 --- a/benches/bulk.rs +++ b/benches/bulk.rs @@ -11,6 +11,9 @@ use raknet_rs::server::{self, MakeIncoming}; use tokio::net::UdpSocket as TokioUdpSocket; use tokio::runtime::Runtime; +const SEND_BUF_CAP: usize = 1024; +const MTU: u16 = 1500; + pub fn bulk_benchmark(c: &mut Criterion) { let mut group = c.benchmark_group("bulk_benchmark"); let server_addr = spawn_server(); @@ -44,6 +47,14 @@ pub fn bulk_benchmark(c: &mut Criterion) { }); } + // The following benchmarks are not stable, for the following reason: + // An ack may fail to wake up a client: if the ack arrives before the client goes to sleep + // waiting for the RTO, nothing wakes the client once it does fall asleep. This is almost + // impossible in real-life scenarios, but here the client then waits for a complete RTO + // period before it notices the ack. This ultimately causes that round of benchmarking to + // stall for a while, affecting the benchmark results. + + // TODO: find a way to make the benchmark stable { group.throughput(Throughput::Bytes(short_data.len() as u64 * 10)); group.bench_function("short_data_10_clients", |bencher| { @@ -79,9 +90,8 @@ fn configure_bencher( sock.connect_to( server_addr, client::Config::default() - .send_buf_cap(1024) - .mtu(1400) - .client_guid(1919810) + .send_buf_cap(SEND_BUF_CAP) + .mtu(MTU) .protocol_version(11), ) .await @@ -100,10 +110,6 @@ fn configure_bencher( tokio::pin!(client); client.feed(Bytes::from_static(data)).await.unwrap(); debug!("client {} finished feeding", i); - // TODO: This is the culprit that currently causes the benchmark to be very - // slow. The current implementation avoids spinning in close check by waiting - // for an RTO each time before starting the check, which usually takes a long - // time. 
client.close().await.unwrap(); // make sure all data is sent debug!("client {} closed", i); }); @@ -122,11 +128,10 @@ fn spawn_server() -> SocketAddr { let server_addr = sock.local_addr().unwrap(); rt().spawn(async move { let config = server::Config::new() - .send_buf_cap(1024) - .sever_guid(114514) - .advertisement(&b"Hello, I am proxy server"[..]) + .send_buf_cap(SEND_BUF_CAP) + .advertisement(&b"Hello, I am server"[..]) .min_mtu(500) - .max_mtu(1400) + .max_mtu(MTU) .support_version(vec![9, 11, 13]) .max_pending(1024); let mut incoming = sock.make_incoming(config); diff --git a/src/guard.rs b/src/guard.rs index 41c7f9e..6353fa1 100644 --- a/src/guard.rs +++ b/src/guard.rs @@ -61,7 +61,7 @@ where peer, role, cap, - resend: ResendMap::new(), + resend: ResendMap::new(role), } } } @@ -227,8 +227,10 @@ where } /// Close the outgoing guard, notice that it may resend infinitely if you do not cancel it. + /// Ensure all frames are received by the peer at the point of closing fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - // insure all frames are received by the peer at the point of closing + // maybe go to sleep, turn on the waking + self.link.turn_on_waking(); loop { ready!(self.as_mut().try_empty(cx))?; debug_assert!(self.buf.is_empty() && self.link.flush_empty()); @@ -240,16 +242,10 @@ where ); break; } - // wait for the next resend - // TODO: When receiving an ack, we should immediately stop waiting and check if it can - // be terminated. 
- trace!( - "[{}] poll_wait for next timeout, resend map size: {}", - self.role, - self.resend.size() - ); ready!(self.resend.poll_wait(cx)); } + // no need to wake up + self.link.turn_off_waking(); self.project().frame.poll_close(cx) } } diff --git a/src/link.rs b/src/link.rs index 5fad6cd..9e4e623 100644 --- a/src/link.rs +++ b/src/link.rs @@ -1,15 +1,16 @@ use std::cmp::Reverse; use std::collections::{BinaryHeap, VecDeque}; +use std::sync::atomic::AtomicBool; use std::sync::Arc; use async_channel::Sender; use concurrent_queue::ConcurrentQueue; use futures::Stream; -use log::{trace, warn}; +use log::{debug, trace, warn}; -use crate::packet::connected::{self, AckOrNack, Frame, FrameBody, FrameSet, FramesMut}; +use crate::packet::connected::{self, AckOrNack, Frame, FrameBody, FrameSet, FramesMut, Record}; use crate::packet::unconnected; -use crate::resend_map::ResendMap; +use crate::resend_map::{reactor, ResendMap}; use crate::utils::u24; use crate::RoleContext; @@ -21,6 +22,7 @@ pub(crate) type SharedLink = Arc; pub(crate) struct TransferLink { incoming_ack: ConcurrentQueue, incoming_nack: ConcurrentQueue, + forward_waking: AtomicBool, outgoing_ack: parking_lot::Mutex>>, // TODO: nack channel should always be in order according to [`DeFragment::poll_next`], replace @@ -60,6 +62,7 @@ impl TransferLink { Arc::new(Self { incoming_ack: ConcurrentQueue::bounded(MAX_ACK_BUFFER), incoming_nack: ConcurrentQueue::bounded(MAX_ACK_BUFFER), + forward_waking: AtomicBool::new(false), outgoing_ack: parking_lot::Mutex::new(BinaryHeap::with_capacity(MAX_ACK_BUFFER)), outgoing_nack: parking_lot::Mutex::new(BinaryHeap::with_capacity(MAX_ACK_BUFFER)), unconnected: ConcurrentQueue::unbounded(), @@ -68,7 +71,41 @@ impl TransferLink { }) } + pub(crate) fn turn_on_waking(&self) { + self.forward_waking + .store(true, std::sync::atomic::Ordering::Relaxed); + } + + fn should_waking(&self) -> bool { + self.forward_waking + .load(std::sync::atomic::Ordering::Relaxed) + } + + pub(crate) fn 
turn_off_waking(&self) { + self.forward_waking + .store(false, std::sync::atomic::Ordering::Relaxed); + } + pub(crate) fn incoming_ack(&self, records: AckOrNack) { + let to_wakes = if self.should_waking() { + let mut wakers = Vec::new(); + let mut guard = reactor::Reactor::get().lock(); + for record in &records.records { + match record { + Record::Range(start, end) => { + for seq_num in start.to_u32()..=end.to_u32() { + guard.cancel_timer(seq_num.into(), self.role.guid(), &mut wakers); + } + } + Record::Single(seq_num) => { + guard.cancel_timer(*seq_num, self.role.guid(), &mut wakers); + } + } + } + Some(wakers) + } else { + None + }; if let Some(dropped) = self.incoming_ack.force_push(records).unwrap() { warn!( "[{}] discard received ack {dropped:?}, total count: {}", @@ -76,6 +113,18 @@ impl TransferLink { dropped.total_cnt() ); } + // wake up only after the ack has been queued + if let Some(wakers) = to_wakes { + debug!( + "[{}] wake up {} wakers after receives ack", + self.role, + wakers.len() + ); + for waker in wakers { + // safe to panic + waker.wake(); + } + } } pub(crate) fn incoming_nack(&self, records: AckOrNack) { diff --git a/src/resend_map.rs b/src/resend_map.rs index d349814..48d51dd 100644 --- a/src/resend_map.rs +++ b/src/resend_map.rs @@ -2,8 +2,11 @@ use std::collections::{HashMap, VecDeque}; use std::task::{Context, Poll}; use std::time::{Duration, Instant}; +use log::trace; + use crate::packet::connected::{AckOrNack, Frame, Frames, Record}; use crate::utils::u24; +use crate::RoleContext; // TODO: use RTTEstimator to get adaptive RTO const RTO: Duration = Duration::from_secs(1); @@ -15,13 +18,15 @@ struct ResendEntry { pub(crate) struct ResendMap { map: HashMap, + role: RoleContext, last_record_expired_at: Instant, } impl ResendMap { - pub(crate) fn new() -> Self { + pub(crate) fn new(role: RoleContext) -> Self { Self { map: HashMap::new(), + role, last_record_expired_at: Instant::now(), } } @@ -89,6 +94,12 @@ impl ResendMap { } }); debug_assert!(min_expired_at > 
now); + trace!( + "[{}]: process stales, {} entries left, next expired at {:?}", + self.role, + self.map.len(), + min_expired_at + ); self.last_record_expired_at = min_expired_at; } @@ -96,61 +107,83 @@ impl ResendMap { self.map.is_empty() } - pub(crate) fn size(&self) -> usize { - self.map.len() - } - /// `poll_wait` suspends the task when the resend map needs to wait for the next resend pub(crate) fn poll_wait(&self, cx: &mut Context<'_>) -> Poll<()> { let expired_at; - if let Some((_, entry)) = self.map.iter().min_by_key(|(_, entry)| entry.expired_at) - && entry.expired_at > Instant::now() + let seq_num; + let now = Instant::now(); + if let Some((seq, entry)) = self.map.iter().min_by_key(|(_, entry)| entry.expired_at) + && entry.expired_at > now { expired_at = entry.expired_at; + seq_num = *seq; } else { return Poll::Ready(()); } - reactor::Reactor::get().insert_timer(expired_at, cx.waker()); + trace!( + "[{}]: wait for resend seq_num {} within {:?}", + self.role, + seq_num, + expired_at - now + ); + reactor::Reactor::get().insert_timer(expired_at, seq_num, self.role.guid(), cx.waker()); Poll::Pending } } -/// Timer reactor -mod reactor { - use std::collections::BTreeMap; - use std::sync::atomic::{AtomicUsize, Ordering}; +/// Specialized timer reactor for resend map +pub(crate) mod reactor { + use std::collections::{BTreeMap, HashMap}; use std::sync::OnceLock; use std::task::Waker; use std::time::{Duration, Instant}; use std::{mem, panic, thread}; - use log::trace; - use parking_lot::{Condvar, Mutex}; + use crate::utils::u24; + + /// A unique sequence number with a global unique ID. + /// This is used to identify a timer across the different peers. + #[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Default)] + struct UniqueSeq { + seq_num: u24, + guid: u64, + } /// A simple timer reactor. /// /// There is only one global instance of this type, accessible by [`Reactor::get()`]. 
pub(crate) struct Reactor { - /// An ordered map of registered timers. - /// - /// Timers are in the order in which they fire. The `usize` in this type is a timer ID used - /// to distinguish timers that fire at the same time. The `Waker` represents the - /// task awaiting the timer. - timers: Mutex>, - cond: Condvar, + /// Inner state of the timer reactor. + inner: parking_lot::Mutex, + /// A condition variable to notify the reactor of new timers. + cond: parking_lot::Condvar, } - fn main_loop() { - let reactor = Reactor::get(); - loop { - reactor.process_timers(); - } + /// Inner state of the timer reactor. + struct ReactorInner { + /// An ordered map of registered timers. + /// + /// Timers are in the order in which they fire. The `UniqueSeq` in this type is relative to + /// the timer and plays a role in a unique ID for same timeout. The `Waker` + /// represents the task awaiting the timer. + timers: BTreeMap<(Instant, UniqueSeq), Waker>, + /// A mapping of unique seq num to their respective `Instant`s. + /// + /// This is used to cancel timers with a given sequence number. + mapping: HashMap, } impl Reactor { pub(crate) fn get() -> &'static Reactor { static REACTOR: OnceLock = OnceLock::new(); + fn main_loop() { + let reactor = Reactor::get(); + loop { + reactor.process_timers(); + } + } + REACTOR.get_or_init(|| { // Spawn the daemon thread to motivate the reactor. thread::Builder::new() @@ -159,36 +192,39 @@ mod reactor { .expect("cannot spawn timer-reactor thread"); Reactor { - timers: Mutex::new(BTreeMap::new()), - cond: Condvar::new(), + inner: parking_lot::Mutex::new(ReactorInner { + timers: BTreeMap::new(), + mapping: HashMap::new(), + }), + cond: parking_lot::Condvar::new(), } }) } + /// Locks the reactor for exclusive access. + pub(crate) fn lock(&self) -> ReactorLock<'_> { + ReactorLock { + inner: self.inner.lock(), + cond: &self.cond, + } + } + /// Registers a timer in the reactor. /// /// Returns the inserted timer's ID. 
- pub(crate) fn insert_timer(&self, when: Instant, waker: &Waker) -> usize { - // Generate a new timer ID. - static ID_GENERATOR: AtomicUsize = AtomicUsize::new(1); - let id = ID_GENERATOR.fetch_add(1, Ordering::Relaxed); - - let mut guard = self.timers.lock(); - guard.insert((when, id), waker.clone()); + pub(crate) fn insert_timer(&self, when: Instant, seq_num: u24, guid: u64, waker: &Waker) { + let mut guard = self.inner.lock(); + let unique_seq = UniqueSeq { seq_num, guid }; + guard.mapping.insert(unique_seq, when); + guard.timers.insert((when, unique_seq), waker.clone()); // Notify that a timer has been inserted. self.cond.notify_one(); - - drop(guard); - - id } - /// Processes ready timers and extends the list of wakers to wake. - /// - /// Returns the duration until the next timer before this method was called. + /// Processes ready timers and waits for the next timer to be inserted. fn process_timers(&self) { - let mut timers = self.timers.lock(); + let mut inner = self.inner.lock(); let now = Instant::now(); @@ -196,27 +232,55 @@ mod reactor { // // Careful to split just *after* `now`, so that a timer set for exactly `now` is // considered ready. - let pending = timers.split_off(&(now + Duration::from_nanos(1), 0)); - let ready = mem::replace(&mut *timers, pending); + let pending = inner + .timers + .split_off(&(now + Duration::from_nanos(1), UniqueSeq::default())); + let ready = mem::replace(&mut inner.timers, pending); - // Calculate the duration until the next event. - let dur = timers - .keys() - .next() - .map(|(when, _)| when.saturating_duration_since(now)); - - for (_, waker) in ready { + for ((_, seq_num), waker) in ready { + inner.mapping.remove(&seq_num); // TODO: wake up maybe slow down the reactor // Don't let a panicking waker blow everything up. panic::catch_unwind(|| waker.wake()).ok(); } + // Calculate the duration until the next event. 
+ let dur = inner + .timers + .keys() + .next() + .map(|(when, _)| when.saturating_duration_since(now)); + if let Some(dur) = dur { - trace!("[timer_reactor] wait for {dur:?}"); - self.cond.wait_for(&mut timers, dur); + self.cond.wait_for(&mut inner, dur); } else { - trace!("[timer_reactor] wait for next timer insertion"); - self.cond.wait(&mut timers); + self.cond.wait(&mut inner); + } + } + } + + pub(crate) struct ReactorLock<'a> { + inner: parking_lot::MutexGuard<'a, ReactorInner>, + cond: &'a parking_lot::Condvar, + } + + impl Drop for ReactorLock<'_> { + fn drop(&mut self) { + // Notify the reactor that the inner state has changed. + self.cond.notify_one(); + } + } + + impl ReactorLock<'_> { + pub(crate) fn cancel_timer(&mut self, seq_num: u24, guid: u64, wakers: &mut Vec) { + let unique_seq = UniqueSeq { seq_num, guid }; + if let Some(when) = self.inner.mapping.remove(&unique_seq) { + wakers.push( + self.inner + .timers + .remove(&(when, unique_seq)) + .expect("timer should exist"), + ); } } } @@ -233,13 +297,13 @@ mod test { use super::ResendMap; use crate::packet::connected::{AckOrNack, Flags, Frame}; use crate::tests::test_trace_log_setup; - use crate::Reliability; + use crate::{Reliability, RoleContext}; const TEST_RTO: Duration = Duration::from_millis(1200); #[test] fn test_resend_map_works() { - let mut map = ResendMap::new(); + let mut map = ResendMap::new(RoleContext::test_server()); map.record(0.into(), vec![]); map.record(1.into(), vec![]); map.record(2.into(), vec![]); @@ -294,7 +358,7 @@ mod test { #[test] fn test_resend_map_stales() { - let mut map = ResendMap::new(); + let mut map = ResendMap::new(RoleContext::test_server()); map.record(0.into(), vec![]); map.record(1.into(), vec![]); map.record(2.into(), vec![]); @@ -309,7 +373,7 @@ mod test { async fn test_resend_map_poll_wait() { let _guard = test_trace_log_setup(); - let mut map = ResendMap::new(); + let mut map = ResendMap::new(RoleContext::test_server()); map.record(0.into(), vec![]); 
std::thread::sleep(TEST_RTO); map.record(1.into(), vec![]);