Skip to content

Commit

Permalink
Prevent races in multi-frame sends causing protocol violations
Browse files Browse the repository at this point in the history
  • Loading branch information
marc-casperlabs committed Feb 28, 2024
1 parent c317574 commit 67ff477
Showing 1 changed file with 28 additions and 21 deletions.
49 changes: 28 additions & 21 deletions src/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,9 @@ pub struct IoCore<const N: usize, R, W> {
/// The maximum time allowed for a peer to receive an error.
error_timeout: Duration,

/// The frame in the process of being sent, which may be partially transferred already.
current_frame: Option<OutgoingFrame>,
/// The frame in the process of being sent, which may be partially transferred already. Also
/// indicates if the current frame is the final frame of a message.
current_frame: Option<(OutgoingFrame, bool)>,
/// The headers of active current multi-frame transfers.
active_multi_frame: [Option<Header>; N],
/// Frames waiting to be sent.
Expand Down Expand Up @@ -543,27 +544,42 @@ where
tokio::select! {
biased; // We actually like the bias, avoid the randomness overhead.

write_result = write_all_buf_if_some(&mut self.writer, self.current_frame.as_mut())
write_result = write_all_buf_if_some(&mut self.writer,
self.current_frame.as_mut()
.map(|(ref mut frame, _)| frame))
, if self.current_frame.is_some() => {

write_result.map_err(CoreError::WriteFailed)?;

// Clear `current_frame` via `Option::take` and examine what was sent.
if let Some(frame_sent) = self.current_frame.take() {
if let Some((frame_sent, was_final)) = self.current_frame.take() {
#[cfg(feature = "tracing")]
tracing::trace!(frame=%frame_sent, "sent");

if frame_sent.header().is_error() {
let header_sent = frame_sent.header();

// If we finished the active multi frame send, clear it.
if was_final {
let channel_idx = header_sent.channel().get() as usize;
if let Some(ref active_multi_frame) =
self.active_multi_frame[channel_idx] {
if header_sent == *active_multi_frame {
self.active_multi_frame[channel_idx] = None;
}
}
}

if header_sent.is_error() {
// We finished sending an error frame, time to exit.
return Err(CoreError::RemoteProtocolViolation(frame_sent.header()));
return Err(CoreError::RemoteProtocolViolation(header_sent));
}

// TODO: We should restrict the dirty-queue processing here a little bit
// (only check when completing a multi-frame message).
// A message has completed sending, process the wait queue in case we have
// to start sending a multi-frame message like a response that was delayed
// only because of the one-multi-frame-per-channel restriction.
self.process_wait_queue(frame_sent.header().channel())?;
self.process_wait_queue(header_sent.channel())?;
} else {
#[cfg(feature = "tracing")]
tracing::error!("current frame should not disappear");
Expand Down Expand Up @@ -798,23 +814,14 @@ where
.next_owned(self.juliet.max_frame_size());

// If there are more frames after this one, schedule the remainder.
if let Some(next_frame_iter) = additional_frames {
let is_final = if let Some(next_frame_iter) = additional_frames {
self.ready_queue.push_back(next_frame_iter);
false
} else {
// No additional frames. Check if sending the next frame finishes a multi-frame message.
let about_to_finish = frame.header();
if let Some(ref active_multi) =
self.active_multi_frame[about_to_finish.channel().get() as usize]
{
if about_to_finish == *active_multi {
// Once the scheduled frame is processed, we will finished the multi-frame
// transfer, so we can allow for the next multi-frame transfer to be scheduled.
self.active_multi_frame[about_to_finish.channel().get() as usize] = None;
}
}
}
true
};

self.current_frame = Some(frame);
self.current_frame = Some((frame, is_final));
Ok(())
}

Expand Down

0 comments on commit 67ff477

Please sign in to comment.