diff --git a/src/protocol.rs b/src/protocol.rs index 82145a5..759ddc1 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -305,6 +305,11 @@ impl Channel { self.outgoing_requests.len() < self.config.request_limit as usize } + /// Returns the configured request limit for this channel. + pub fn request_limit(&self) -> u16 { + self.config.request_limit + } + /// Creates a new request, bypassing all client-side checks. /// /// Low-level function that does nothing but create a syntactically correct request and track @@ -474,7 +479,7 @@ impl Display for CompletedRead { } } -/// The caller of the this crate has violated the protocol. +/// The caller of this crate has violated the protocol. /// /// A correct implementation of a client should never encounter this, thus simply unwrapping every /// instance of this as part of a `Result<_, LocalProtocolViolation>` is usually a valid choice. @@ -487,24 +492,45 @@ pub enum LocalProtocolViolation { /// /// Wait for additional requests to be cancelled or answered. Calling /// [`JulietProtocol::allowed_to_send_request()`] beforehand is recommended. - #[error("sending would exceed request limit")] - WouldExceedRequestLimit, + #[error("sending would exceed request limit of {limit}")] + WouldExceedRequestLimit { + /// The configured limit for requests on the channel. + limit: u16, + }, /// The channel given does not exist. /// /// The given [`ChannelId`] exceeds `N` of [`JulietProtocol`]. - #[error("invalid channel")] - InvalidChannel(ChannelId), + #[error("channel {channel} not a member of configured {channel_count} channels")] + InvalidChannel { + /// The provided channel ID. + channel: ChannelId, + /// The configured number of channels. + channel_count: usize, + }, /// The given payload exceeds the configured limit. /// /// See [`ChannelConfiguration::with_max_request_payload_size()`] and /// [`ChannelConfiguration::with_max_response_payload_size()`] for details. - #[error("payload exceeds configured limit")] - PayloadExceedsLimit, + #[error("payload length of {payload_length} bytes exceeds configured limit of {limit}")] + PayloadExceedsLimit { + /// The payload length in bytes. + payload_length: usize, + /// The configured upper limit for payload length in bytes. + limit: usize, + }, /// The given error payload exceeds a single frame. /// /// Error payloads may not span multiple frames, shorten the payload or increase frame size. - #[error("error payload would be multi-frame")] - ErrorPayloadIsMultiFrame, + #[error( + "error payload of {payload_length} bytes exceeds a single frame with configured max size \ + of {max_frame_size})" + )] + ErrorPayloadIsMultiFrame { + /// The payload length in bytes. + payload_length: usize, + /// The configured maximum frame size in bytes. + max_frame_size: u32, + }, } macro_rules! log_frame { @@ -534,7 +560,10 @@ impl JulietProtocol { #[inline(always)] const fn lookup_channel(&self, channel: ChannelId) -> Result<&Channel, LocalProtocolViolation> { if channel.0 as usize >= N { - Err(LocalProtocolViolation::InvalidChannel(channel)) + Err(LocalProtocolViolation::InvalidChannel { + channel, + channel_count: N, + }) } else { Ok(&self.channels[channel.0 as usize]) } @@ -549,7 +578,10 @@ impl JulietProtocol { channel: ChannelId, ) -> Result<&mut Channel, LocalProtocolViolation> { if channel.0 as usize >= N { - Err(LocalProtocolViolation::InvalidChannel(channel)) + Err(LocalProtocolViolation::InvalidChannel { + channel, + channel_count: N, + }) } else { Ok(&mut self.channels[channel.0 as usize]) } @@ -595,12 +627,17 @@ impl JulietProtocol { if let Some(ref payload) = payload { if payload.len() > chan.config.max_request_payload_size as usize { - return Err(LocalProtocolViolation::PayloadExceedsLimit); + return Err(LocalProtocolViolation::PayloadExceedsLimit { + payload_length: payload.len(), + limit: chan.config.max_request_payload_size as usize, + }); } } if !chan.allowed_to_send_request() { - return Err(LocalProtocolViolation::WouldExceedRequestLimit); + return Err(LocalProtocolViolation::WouldExceedRequestLimit { + limit: chan.request_limit(), + }); } Ok(chan.create_unchecked_request(channel, payload)) @@ -637,7 +674,10 @@ impl JulietProtocol { if let Some(ref payload) = payload { if payload.len() > chan.config.max_response_payload_size as usize { - return Err(LocalProtocolViolation::PayloadExceedsLimit); + return Err(LocalProtocolViolation::PayloadExceedsLimit { + payload_length: payload.len(), + limit: chan.config.max_request_payload_size as usize, + }); } } @@ -712,11 +752,15 @@ impl JulietProtocol { id: Id, payload: Bytes, ) -> Result { - let header = Header::new_error(header::ErrorKind::Other, channel, id); + let header = Header::new_error(ErrorKind::Other, channel, id); + let payload_length = payload.len(); let msg = OutgoingMessage::new(header, Some(payload)); if msg.is_multi_frame(self.max_frame_size) { - Err(LocalProtocolViolation::ErrorPayloadIsMultiFrame) + Err(LocalProtocolViolation::ErrorPayloadIsMultiFrame { + payload_length, + max_frame_size: self.max_frame_size.0, + }) } else { Ok(msg) } @@ -1253,7 +1297,8 @@ mod tests { #[test] fn test_channel_lookups_work() { - let mut protocol: JulietProtocol<3> = ProtocolBuilder::new().build(); + const CHANNEL_COUNT: usize = 3; + let mut protocol: JulietProtocol = ProtocolBuilder::new().build(); // We mark channels by inserting an ID into them, that way we can ensure we're not getting // back the same channel every time. @@ -1274,15 +1319,24 @@ mod tests { .insert(Id::new(102)); assert!(matches!( protocol.lookup_channel_mut(ChannelId(3)), - Err(LocalProtocolViolation::InvalidChannel(ChannelId(3))) + Err(LocalProtocolViolation::InvalidChannel { + channel: ChannelId(3), + channel_count: CHANNEL_COUNT + }) )); assert!(matches!( protocol.lookup_channel_mut(ChannelId(4)), - Err(LocalProtocolViolation::InvalidChannel(ChannelId(4))) + Err(LocalProtocolViolation::InvalidChannel { + channel: ChannelId(4), + channel_count: CHANNEL_COUNT + }) )); assert!(matches!( protocol.lookup_channel_mut(ChannelId(255)), - Err(LocalProtocolViolation::InvalidChannel(ChannelId(255))) + Err(LocalProtocolViolation::InvalidChannel { + channel: ChannelId(255), + channel_count: CHANNEL_COUNT + }) )); // Now look up the channels and ensure they contain the right values @@ -1309,15 +1363,24 @@ mod tests { ); assert!(matches!( protocol.lookup_channel(ChannelId(3)), - Err(LocalProtocolViolation::InvalidChannel(ChannelId(3))) + Err(LocalProtocolViolation::InvalidChannel { + channel: ChannelId(3), + channel_count: CHANNEL_COUNT + }) )); assert!(matches!( protocol.lookup_channel(ChannelId(4)), - Err(LocalProtocolViolation::InvalidChannel(ChannelId(4))) + Err(LocalProtocolViolation::InvalidChannel { + channel: ChannelId(4), + channel_count: CHANNEL_COUNT + }) )); assert!(matches!( protocol.lookup_channel(ChannelId(255)), - Err(LocalProtocolViolation::InvalidChannel(ChannelId(255))) + Err(LocalProtocolViolation::InvalidChannel { + channel: ChannelId(255), + channel_count: CHANNEL_COUNT + }) )); } @@ -1442,7 +1505,10 @@ mod tests { // Try an invalid channel, should result in an error. assert!(matches!( protocol.create_request(ChannelId::new(2), payload.get()), - Err(LocalProtocolViolation::InvalidChannel(ChannelId(2))) + Err(LocalProtocolViolation::InvalidChannel { + channel: ChannelId(2), + channel_count: 2 + }) )); assert!(protocol @@ -1454,7 +1520,7 @@ mod tests { assert!(matches!( protocol.create_request(channel, payload.get()), - Err(LocalProtocolViolation::WouldExceedRequestLimit) + Err(LocalProtocolViolation::WouldExceedRequestLimit { limit: 1 }) )); } } @@ -2202,7 +2268,10 @@ mod tests { .create_request(env.common_channel, payload.get()) .expect_err("should not be able to create too large request"); - assert_matches!(violation, LocalProtocolViolation::PayloadExceedsLimit); + assert_matches!( + violation, + LocalProtocolViolation::PayloadExceedsLimit { .. } + ); // If we force the issue, Bob must refuse it instead. let bob_result = env.inject_and_send_request(Alice, payload.get()); @@ -2219,7 +2288,10 @@ mod tests { .bob .create_request(env.common_channel, payload.get()) .expect_err("should not be able to create too large response"); - assert_matches!(violation, LocalProtocolViolation::PayloadExceedsLimit); + assert_matches!( + violation, + LocalProtocolViolation::PayloadExceedsLimit { .. } + ); // If we force the issue, Alice must refuse it. let alice_result = env.inject_and_send_response(Bob, id, payload.get()); diff --git a/src/protocol/multiframe.rs b/src/protocol/multiframe.rs index de6a913..75c908a 100644 --- a/src/protocol/multiframe.rs +++ b/src/protocol/multiframe.rs @@ -44,6 +44,7 @@ pub(super) enum MultiframeReceiver { /// The outcome of a multiframe acceptance. #[derive(Debug)] +#[allow(clippy::enum_variant_names)] pub(crate) enum CompletedFrame { /// A new multi-frame transfer was started. NewMultiFrame, diff --git a/src/rpc.rs b/src/rpc.rs index cc2cff6..76165c5 100644 --- a/src/rpc.rs +++ b/src/rpc.rs @@ -847,7 +847,7 @@ impl IncomingRequest { // Do nothing, just discard the response. } EnqueueError::BufferLimitHit(_) => { - // TODO: Add seperate type to avoid this. + // TODO: Add separate type to avoid this. unreachable!("cannot hit request limit when responding") } } @@ -932,7 +932,10 @@ mod tests { use bytes::Bytes; use futures::FutureExt; - use tokio::io::{DuplexStream, ReadHalf, WriteHalf}; + use tokio::{ + io::{DuplexStream, ReadHalf, WriteHalf}, + sync::mpsc, + }; use tracing::{error_span, info, span, Instrument, Level}; use crate::{ @@ -1004,7 +1007,7 @@ mod tests { async fn run_echo_client( mut rpc_server: JulietRpcServer, WriteHalf>, ) { - while let Some(inc) = rpc_server + if let Some(inc) = rpc_server .next_request() .await .expect("client rpc_server error") @@ -1411,7 +1414,7 @@ mod tests { large_volume_test::<1>(spec).await; } - #[tokio::test(flavor = "multi_thread", worker_threads = 4)] + #[tokio::test(flavor = "multi_thread", worker_threads = 5)] async fn run_large_volume_test_with_default_values_10_channels() { tracing_subscriber::fmt() .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) @@ -1433,7 +1436,7 @@ mod tests { let (mut alice, mut bob) = LargeVolumeTestSpec::::default().mk_rpc(); - // Alice server. Will close the connection after enough bytes have been sent. + // Alice server. Will close the connection after enough bytes have been received. let mut remaining = spec.min_send_bytes; let alice_server = tokio::spawn( async move { @@ -1452,6 +1455,7 @@ mod tests { request.respond(None); remaining = remaining.saturating_sub(payload_size); + tracing::debug!("payload_size: {payload_size}, remaining: {remaining}"); if remaining == 0 { // We've reached the volume we were looking for, end test. break; @@ -1501,14 +1505,18 @@ mod tests { Err(guard) => { // Not ready, but we are not going to wait. - tokio::spawn(async move { - if let Err(err) = guard.wait_for_response().await { - match err { - RequestError::RemoteClosed(_) | RequestError::Shutdown => {} - err => panic!("{}", err), + tokio::spawn( + async move { + if let Err(err) = guard.wait_for_response().await { + match err { + RequestError::RemoteClosed(_) + | RequestError::Shutdown => {} + err => panic!("{}", err), + } } } - }); + .in_current_span(), + ); } } } @@ -1518,10 +1526,11 @@ mod tests { .instrument(error_span!("alice_client")), ); - // Bob server. + // A channel to allow Bob's server to notify Bob's client to send a new request to Alice. + let (notify_tx, mut notify_rx) = mpsc::unbounded_channel(); + // Bob server. Will shut down once Alice closes the connection. let bob_server = tokio::spawn( async move { - let mut bob_counter = 0; while let Some(request) = bob .server .next_request() @@ -1540,7 +1549,19 @@ mod tests { let channel = request.channel(); // Just discard the message payload, but acknowledge receiving it. request.respond(None); + // Notify Bob client to send a new request to Alice. + notify_tx.send(channel).unwrap(); + } + info!("exiting"); + } + .instrument(error_span!("bob_server")), + ); + // Bob client. Will shut down once Alice closes the connection. + let bob_client = tokio::spawn( + async move { + let mut bob_counter = 0; + while let Some(channel) = notify_rx.recv().await { let payload_size = spec.gen_payload_size(bob_counter); let large_payload: Bytes = iter::repeat(0xFF) .take(payload_size) @@ -1551,11 +1572,11 @@ mod tests { let bobs_request: RequestGuard = bob .client .create_request(channel) - .with_payload(large_payload.clone()) + .with_payload(large_payload) .queue_for_sending() .await; - info!(bob_counter, "bob enqueued request"); + info!(bob_counter, payload_size, "bob enqueued request"); bob_counter += 1; match bobs_request.try_get_response() { @@ -1573,26 +1594,30 @@ mod tests { Err(guard) => { // Do not wait, instead attempt to retrieve next request. - tokio::spawn(async move { - if let Err(err) = guard.wait_for_response().await { - match err { - RequestError::RemoteClosed(_) | RequestError::Shutdown => {} - err => panic!("{}", err), + tokio::spawn( + async move { + if let Err(err) = guard.wait_for_response().await { + match err { + RequestError::RemoteClosed(_) + | RequestError::Shutdown => {} + err => panic!("{}", err), + } } } - }); + .in_current_span(), + ); } } } - info!("exiting"); } - .instrument(error_span!("bob_server")), + .instrument(error_span!("bob_client")), ); alice_server.await.expect("failed to join alice server"); alice_client.await.expect("failed to join alice client"); bob_server.await.expect("failed to join bob server"); + bob_client.await.expect("failed to join bob client"); info!("all joined"); } @@ -1632,7 +1657,7 @@ mod tests { let mut bob = CompleteSetup::new(&rpc_builder, bob_stream); let alice_join_handle = tokio::spawn(async move { - while let Some(incoming_request) = alice + if let Some(incoming_request) = alice .server .next_request() .await