Skip to content

Commit

Permalink
fix: prevent reconnecting of client after invoking disconnect (1c3t3a…
Browse files Browse the repository at this point in the history
  • Loading branch information
SenseiHiraku committed Sep 23, 2024
1 parent 2ef32ec commit f4e6f8b
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 28 deletions.
82 changes: 59 additions & 23 deletions socketio/src/client/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::{
time::Duration,
};

use super::{ClientBuilder, RawClient};
use super::{raw_client::DisconnectReason, ClientBuilder, RawClient};
use crate::{
error::Result,
packet::{Packet, PacketId},
Expand Down Expand Up @@ -165,6 +165,11 @@ impl Client {
client.disconnect()
}

fn do_disconnect(&self) -> Result<()> {
let client = self.client.read()?;
client.do_disconnect()
}

fn reconnect(&mut self) -> Result<()> {
let mut reconnect_attempts = 0;
let (reconnect, max_reconnect_attempts) = {
Expand All @@ -174,6 +179,17 @@ impl Client {

if reconnect {
loop {
// Check if disconnect_reason is Manual
{
let disconnect_reason = {
let client = self.client.read()?;
client.get_disconnect_reason()
};
if disconnect_reason == DisconnectReason::Manual {
// Exit the loop, stop reconnecting
break;
}
}
if let Some(max_reconnect_attempts) = max_reconnect_attempts {
reconnect_attempts += 1;
if reconnect_attempts > max_reconnect_attempts {
Expand All @@ -186,6 +202,12 @@ impl Client {
}

if self.do_reconnect().is_ok() {
// Reset disconnect_reason to Unknown after successful reconnection
{
let client = self.client.read()?;
let mut reason = client.disconnect_reason.write()?;
*reason = DisconnectReason::Unknown;
}
break;
}
}
Expand Down Expand Up @@ -213,29 +235,43 @@ impl Client {
let mut self_clone = self.clone();
// Use thread to consume items in iterator in order to call callbacks
std::thread::spawn(move || {
// tries to restart a poll cycle whenever a 'normal' error occurs,
// it just panics on network errors, in case the poll cycle returned
// `Result::Ok`, the server receives a close frame so it's safe to
// terminate
for packet in self_clone.iter() {
let should_reconnect = match packet {
Err(Error::IncompleteResponseFromEngineIo(_)) => {
//TODO: 0.3.X handle errors
//TODO: logging error
true
loop {
let next_item = self_clone.iter().next();
match next_item {
Some(Ok(_packet)) => {
// Process packet normally
continue;
}
Some(Err(_)) => {
let should_reconnect = {
let disconnect_reason = {
let client = self_clone.client.read().unwrap();
client.get_disconnect_reason()
};
match disconnect_reason {
DisconnectReason::Unknown => {
let builder = self_clone.builder.lock().unwrap();
builder.reconnect
}
DisconnectReason::Manual => false,
DisconnectReason::Server => {
let builder = self_clone.builder.lock().unwrap();
builder.reconnect_on_disconnect
}
}
};
if should_reconnect {
let _ = self_clone.do_disconnect();
let _ = self_clone.reconnect();
} else {
// No reconnection needed, exit the loop
break;
}
}
None => {
// Iterator has ended, exit the loop
break;
}
Ok(Packet {
packet_type: PacketId::Disconnect,
..
}) => match self_clone.builder.lock() {
Ok(builder) => builder.reconnect_on_disconnect,
Err(_) => false,
},
_ => false,
};
if should_reconnect {
let _ = self_clone.disconnect();
let _ = self_clone.reconnect();
}
}
});
Expand Down
50 changes: 45 additions & 5 deletions socketio/src/client/raw_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,23 @@ use crate::client::callback::{SocketAnyCallback, SocketCallback};
use crate::error::Result;
use std::collections::HashMap;
use std::ops::DerefMut;
use std::sync::{Arc, Mutex};
use std::sync::{Arc, Mutex, RwLock};
use std::time::Duration;
use std::time::Instant;

use crate::socket::Socket as InnerSocket;

#[derive(Default, Clone, Copy, PartialEq)]
pub enum DisconnectReason {
/// There is no known reason for the disconnect; likely a network error
#[default]
Unknown,
/// The user disconnected manually
Manual,
/// The server disconnected
Server,
}

/// Represents an `Ack` as given back to the caller. Holds the internal `id` as
/// well as the current ack'ed state. Holds data which will be accessible as
/// soon as the ack'ed state is set to true. An `Ack` that didn't get ack'ed
Expand All @@ -41,6 +52,7 @@ pub struct RawClient {
nsp: String,
// Data send in the opening packet (commonly used as for auth)
auth: Option<Value>,
pub(crate) disconnect_reason: Arc<RwLock<DisconnectReason>>,
}

impl RawClient {
Expand All @@ -62,6 +74,7 @@ impl RawClient {
on_any,
outstanding_acks: Arc::new(Mutex::new(Vec::new())),
auth,
disconnect_reason: Arc::new(RwLock::new(DisconnectReason::default())),
})
}

Expand Down Expand Up @@ -142,7 +155,14 @@ impl RawClient {
///
/// ```
pub fn disconnect(&self) -> Result<()> {
let disconnect_packet =
*(self.disconnect_reason.write()?) = DisconnectReason::Manual;
self.do_disconnect()
}

/// Disconnects this client the same way as `disconnect()` but
/// without setting the `DisconnectReason` to `DisconnectReason::Manual`
pub fn do_disconnect(&self) -> Result<()> {
let disconnect_packet =
Packet::new(PacketId::Disconnect, self.nsp.clone(), None, None, 0, None);

// TODO: logging
Expand All @@ -153,6 +173,10 @@ impl RawClient {
Ok(())
}

pub fn get_disconnect_reason(&self) -> DisconnectReason {
*self.disconnect_reason.read().unwrap()
}

/// Sends a message to the server but `alloc`s an `ack` to check whether the
/// server responded in a given time span. This message takes an event, which
/// could either be one of the common events like "message" or "error" or a
Expand Down Expand Up @@ -222,18 +246,32 @@ impl RawClient {
}

pub(crate) fn poll(&self) -> Result<Option<Packet>> {
{
let disconnect_reason = *self.disconnect_reason.read()?;
if disconnect_reason == DisconnectReason::Manual {
// If disconnected manually, return Ok(None) to end iterator
return Ok(None);
}
}
loop {
match self.socket.poll() {
Err(err) => {
self.callback(&Event::Error, err.to_string())?;
return Err(err);
// Check if the disconnection was manual
let disconnect_reason = *self.disconnect_reason.read()?;
if disconnect_reason == DisconnectReason::Manual {
// Return Ok(None) to signal the end of the iterator
return Ok(None);
} else {
self.callback(&Event::Error, err.to_string())?;
return Err(err);
}
}
Ok(Some(packet)) => {
if packet.nsp == self.nsp {
self.handle_socketio_packet(&packet)?;
return Ok(Some(packet));
} else {
// Not our namespace continue polling
// Not our namespace, continue polling
}
}
Ok(None) => return Ok(None),
Expand Down Expand Up @@ -369,9 +407,11 @@ impl RawClient {
}
}
PacketId::Connect => {
*(self.disconnect_reason.write()?) = DisconnectReason::default();
self.callback(&Event::Connect, "")?;
}
PacketId::Disconnect => {
*(self.disconnect_reason.write()?) = DisconnectReason::Server;
self.callback(&Event::Close, "")?;
}
PacketId::ConnectError => {
Expand Down

0 comments on commit f4e6f8b

Please sign in to comment.