Skip to content

Commit

Permalink
Merge pull request #8 from casper-network/fix-large-vol-test-v2
Browse files Browse the repository at this point in the history
Apply fixes for deadlock in RPC test
  • Loading branch information
marc-casperlabs authored Mar 14, 2024
2 parents 90f92f0 + 9db8b5e commit 24ace0e
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 51 deletions.
126 changes: 99 additions & 27 deletions src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,11 @@ impl Channel {
self.outgoing_requests.len() < self.config.request_limit as usize
}

/// Returns the configured request limit for this channel.
pub fn request_limit(&self) -> u16 {
self.config.request_limit
}

/// Creates a new request, bypassing all client-side checks.
///
/// Low-level function that does nothing but create a syntactically correct request and track
Expand Down Expand Up @@ -474,7 +479,7 @@ impl Display for CompletedRead {
}
}

/// The caller of the this crate has violated the protocol.
/// The caller of this crate has violated the protocol.
///
/// A correct implementation of a client should never encounter this, thus simply unwrapping every
/// instance of this as part of a `Result<_, LocalProtocolViolation>` is usually a valid choice.
Expand All @@ -487,24 +492,45 @@ pub enum LocalProtocolViolation {
///
/// Wait for additional requests to be cancelled or answered. Calling
/// [`JulietProtocol::allowed_to_send_request()`] beforehand is recommended.
#[error("sending would exceed request limit")]
WouldExceedRequestLimit,
#[error("sending would exceed request limit of {limit}")]
WouldExceedRequestLimit {
/// The configured limit for requests on the channel.
limit: u16,
},
/// The channel given does not exist.
///
/// The given [`ChannelId`] exceeds `N` of [`JulietProtocol<N>`].
#[error("invalid channel")]
InvalidChannel(ChannelId),
#[error("channel {channel} not a member of configured {channel_count} channels")]
InvalidChannel {
/// The provided channel ID.
channel: ChannelId,
/// The configured number of channels.
channel_count: usize,
},
/// The given payload exceeds the configured limit.
///
/// See [`ChannelConfiguration::with_max_request_payload_size()`] and
/// [`ChannelConfiguration::with_max_response_payload_size()`] for details.
#[error("payload exceeds configured limit")]
PayloadExceedsLimit,
#[error("payload length of {payload_length} bytes exceeds configured limit of {limit}")]
PayloadExceedsLimit {
/// The payload length in bytes.
payload_length: usize,
/// The configured upper limit for payload length in bytes.
limit: usize,
},
/// The given error payload exceeds a single frame.
///
/// Error payloads may not span multiple frames, shorten the payload or increase frame size.
#[error("error payload would be multi-frame")]
ErrorPayloadIsMultiFrame,
#[error(
"error payload of {payload_length} bytes exceeds a single frame with configured max size \
of {max_frame_size})"
)]
ErrorPayloadIsMultiFrame {
/// The payload length in bytes.
payload_length: usize,
/// The configured maximum frame size in bytes.
max_frame_size: u32,
},
}

macro_rules! log_frame {
Expand Down Expand Up @@ -534,7 +560,10 @@ impl<const N: usize> JulietProtocol<N> {
#[inline(always)]
const fn lookup_channel(&self, channel: ChannelId) -> Result<&Channel, LocalProtocolViolation> {
if channel.0 as usize >= N {
Err(LocalProtocolViolation::InvalidChannel(channel))
Err(LocalProtocolViolation::InvalidChannel {
channel,
channel_count: N,
})
} else {
Ok(&self.channels[channel.0 as usize])
}
Expand All @@ -549,7 +578,10 @@ impl<const N: usize> JulietProtocol<N> {
channel: ChannelId,
) -> Result<&mut Channel, LocalProtocolViolation> {
if channel.0 as usize >= N {
Err(LocalProtocolViolation::InvalidChannel(channel))
Err(LocalProtocolViolation::InvalidChannel {
channel,
channel_count: N,
})
} else {
Ok(&mut self.channels[channel.0 as usize])
}
Expand Down Expand Up @@ -595,12 +627,17 @@ impl<const N: usize> JulietProtocol<N> {

if let Some(ref payload) = payload {
if payload.len() > chan.config.max_request_payload_size as usize {
return Err(LocalProtocolViolation::PayloadExceedsLimit);
return Err(LocalProtocolViolation::PayloadExceedsLimit {
payload_length: payload.len(),
limit: chan.config.max_request_payload_size as usize,
});
}
}

if !chan.allowed_to_send_request() {
return Err(LocalProtocolViolation::WouldExceedRequestLimit);
return Err(LocalProtocolViolation::WouldExceedRequestLimit {
limit: chan.request_limit(),
});
}

Ok(chan.create_unchecked_request(channel, payload))
Expand Down Expand Up @@ -637,7 +674,10 @@ impl<const N: usize> JulietProtocol<N> {

if let Some(ref payload) = payload {
if payload.len() > chan.config.max_response_payload_size as usize {
return Err(LocalProtocolViolation::PayloadExceedsLimit);
return Err(LocalProtocolViolation::PayloadExceedsLimit {
payload_length: payload.len(),
limit: chan.config.max_request_payload_size as usize,
});
}
}

Expand Down Expand Up @@ -712,11 +752,15 @@ impl<const N: usize> JulietProtocol<N> {
id: Id,
payload: Bytes,
) -> Result<OutgoingMessage, LocalProtocolViolation> {
let header = Header::new_error(header::ErrorKind::Other, channel, id);
let header = Header::new_error(ErrorKind::Other, channel, id);

let payload_length = payload.len();
let msg = OutgoingMessage::new(header, Some(payload));
if msg.is_multi_frame(self.max_frame_size) {
Err(LocalProtocolViolation::ErrorPayloadIsMultiFrame)
Err(LocalProtocolViolation::ErrorPayloadIsMultiFrame {
payload_length,
max_frame_size: self.max_frame_size.0,
})
} else {
Ok(msg)
}
Expand Down Expand Up @@ -1253,7 +1297,8 @@ mod tests {

#[test]
fn test_channel_lookups_work() {
let mut protocol: JulietProtocol<3> = ProtocolBuilder::new().build();
const CHANNEL_COUNT: usize = 3;
let mut protocol: JulietProtocol<CHANNEL_COUNT> = ProtocolBuilder::new().build();

// We mark channels by inserting an ID into them, that way we can ensure we're not getting
// back the same channel every time.
Expand All @@ -1274,15 +1319,24 @@ mod tests {
.insert(Id::new(102));
assert!(matches!(
protocol.lookup_channel_mut(ChannelId(3)),
Err(LocalProtocolViolation::InvalidChannel(ChannelId(3)))
Err(LocalProtocolViolation::InvalidChannel {
channel: ChannelId(3),
channel_count: CHANNEL_COUNT
})
));
assert!(matches!(
protocol.lookup_channel_mut(ChannelId(4)),
Err(LocalProtocolViolation::InvalidChannel(ChannelId(4)))
Err(LocalProtocolViolation::InvalidChannel {
channel: ChannelId(4),
channel_count: CHANNEL_COUNT
})
));
assert!(matches!(
protocol.lookup_channel_mut(ChannelId(255)),
Err(LocalProtocolViolation::InvalidChannel(ChannelId(255)))
Err(LocalProtocolViolation::InvalidChannel {
channel: ChannelId(255),
channel_count: CHANNEL_COUNT
})
));

// Now look up the channels and ensure they contain the right values
Expand All @@ -1309,15 +1363,24 @@ mod tests {
);
assert!(matches!(
protocol.lookup_channel(ChannelId(3)),
Err(LocalProtocolViolation::InvalidChannel(ChannelId(3)))
Err(LocalProtocolViolation::InvalidChannel {
channel: ChannelId(3),
channel_count: CHANNEL_COUNT
})
));
assert!(matches!(
protocol.lookup_channel(ChannelId(4)),
Err(LocalProtocolViolation::InvalidChannel(ChannelId(4)))
Err(LocalProtocolViolation::InvalidChannel {
channel: ChannelId(4),
channel_count: CHANNEL_COUNT
})
));
assert!(matches!(
protocol.lookup_channel(ChannelId(255)),
Err(LocalProtocolViolation::InvalidChannel(ChannelId(255)))
Err(LocalProtocolViolation::InvalidChannel {
channel: ChannelId(255),
channel_count: CHANNEL_COUNT
})
));
}

Expand Down Expand Up @@ -1442,7 +1505,10 @@ mod tests {
// Try an invalid channel, should result in an error.
assert!(matches!(
protocol.create_request(ChannelId::new(2), payload.get()),
Err(LocalProtocolViolation::InvalidChannel(ChannelId(2)))
Err(LocalProtocolViolation::InvalidChannel {
channel: ChannelId(2),
channel_count: 2
})
));

assert!(protocol
Expand All @@ -1454,7 +1520,7 @@ mod tests {

assert!(matches!(
protocol.create_request(channel, payload.get()),
Err(LocalProtocolViolation::WouldExceedRequestLimit)
Err(LocalProtocolViolation::WouldExceedRequestLimit { limit: 1 })
));
}
}
Expand Down Expand Up @@ -2202,7 +2268,10 @@ mod tests {
.create_request(env.common_channel, payload.get())
.expect_err("should not be able to create too large request");

assert_matches!(violation, LocalProtocolViolation::PayloadExceedsLimit);
assert_matches!(
violation,
LocalProtocolViolation::PayloadExceedsLimit { .. }
);

// If we force the issue, Bob must refuse it instead.
let bob_result = env.inject_and_send_request(Alice, payload.get());
Expand All @@ -2219,7 +2288,10 @@ mod tests {
.bob
.create_request(env.common_channel, payload.get())
.expect_err("should not be able to create too large response");
assert_matches!(violation, LocalProtocolViolation::PayloadExceedsLimit);
assert_matches!(
violation,
LocalProtocolViolation::PayloadExceedsLimit { .. }
);

// If we force the issue, Alice must refuse it.
let alice_result = env.inject_and_send_response(Bob, id, payload.get());
Expand Down
1 change: 1 addition & 0 deletions src/protocol/multiframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ pub(super) enum MultiframeReceiver {

/// The outcome of a multiframe acceptance.
#[derive(Debug)]
#[allow(clippy::enum_variant_names)]
pub(crate) enum CompletedFrame {
/// A new multi-frame transfer was started.
NewMultiFrame,
Expand Down
Loading

0 comments on commit 24ace0e

Please sign in to comment.