Skip to content

Commit

Permalink
MAIN: Fix poll_write implementation
Browse files Browse the repository at this point in the history
* Make AsyncPDataWriter a proper state machine
* Remove use of `<stream>.write_all` which already loops over input data
  and removes some of our control, switch to manual use of `poll_write`
  on underlying stream
  • Loading branch information
naterichman committed Oct 21, 2024
1 parent 5929c2d commit 521ec2a
Showing 1 changed file with 124 additions and 48 deletions.
172 changes: 124 additions & 48 deletions ul/src/association/pdata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,14 @@ pub mod non_blocking {
pub use super::PDataReader;
use super::{calculate_max_data_len_single, setup_pdata_header};

/// Enum representing state of the Async Writer
enum WriteState {
// Ready to write to the underlying stream
Ready,
// Currently writing to underlying stream, with a position in the buffer
Writing(usize),
}

/// A P-Data async value writer.
///
/// This exposes an API to iteratively construct and send Data messages
Expand Down Expand Up @@ -414,7 +422,7 @@ pub mod non_blocking {
stream: W,
max_data_len: u32,
msg: u32,

Check failure on line 424 in ul/src/association/pdata.rs

View workflow job for this annotation

GitHub Actions / Check (macOS)

field `msg` is never read

Check failure on line 424 in ul/src/association/pdata.rs

View workflow job for this annotation

GitHub Actions / Build (Windows)

field `msg` is never read

Check failure on line 424 in ul/src/association/pdata.rs

View workflow job for this annotation

GitHub Actions / Test (default) (stable)

field `msg` is never read

Check failure on line 424 in ul/src/association/pdata.rs

View workflow job for this annotation

GitHub Actions / Test (default) (beta)

field `msg` is never read
writing: bool,
state: WriteState,
}

#[cfg(feature = "async")]
Expand Down Expand Up @@ -450,11 +458,66 @@ pub mod non_blocking {
]);

AsyncPDataWriter {
stream,
stream, // fn poll_write(
// mut self: Pin<&mut Self>,
// cx: &mut Context<'_>,
// buf: &[u8],
// ) -> Poll<std::result::Result<usize, std::io::Error>> {
// // If we're still writing (i.e. last write was pending), continue writing
// if self.writing {
// let this = self.get_mut();
// let buffer = &this.buffer;
// let mut stream = Pin::new(&mut this.stream);
// // Each call to `poll_write` may or may not write the whole of `self.buffer`
// let write_all = stream.write_all(buffer);
// tokio::pin!(write_all);
// match write_all.poll(cx) {
// Poll::Ready(Ok(_)) => {
// this.writing = false;
// this.msg += 1;
// this.buffer.truncate(12);
// return Poll::Ready(Ok(buf.len()));
// }
// Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
// Poll::Pending => return Poll::Pending,
// }
// }
// let total_len = self.max_data_len as usize + 12;
// if self.buffer.len() + buf.len() <= total_len {
// // accumulate into buffer, do nothing
// self.buffer.extend(buf);
// Poll::Ready(Ok(buf.len()))
// } else {
// // fill in the rest of the buffer, send PDU,
// // and leave out the rest for subsequent writes
// let buf = &buf[..total_len - self.buffer.len()];
// self.buffer.extend(buf);
// debug_assert_eq!(self.buffer.len(), total_len);
// setup_pdata_header(&mut self.buffer, false);
// let this = self.get_mut();
// let buffer = &this.buffer;
// let mut stream = Pin::new(&mut this.stream);
// // Each call to `poll_write` may or may not write the whole of `self.buffer`
// let write_all = stream.write_all(buffer);
// tokio::pin!(write_all);
// match write_all.poll(cx) {
// Poll::Ready(Ok(_)) => {
// this.msg += 1;
// this.buffer.truncate(12);
// Poll::Ready(Ok(buf.len()))
// }
// Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
// Poll::Pending => {
// this.writing = true;
// Poll::Pending
// }
// }
// }
// }
max_data_len: max_data_length,
buffer,
msg: 0,
writing: false,
state: WriteState::Ready,
}
}

Expand Down Expand Up @@ -492,53 +555,66 @@ pub mod non_blocking {
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::result::Result<usize, std::io::Error>> {
// If we're still writing (i.e. last write was pending), continue writing
if self.writing {
let this = self.get_mut();
let buffer = &this.buffer;
let mut stream = Pin::new(&mut this.stream);
// Each call to `poll_write` may or may not write the whole of `self.buffer`
let write_all = stream.write_all(buffer);
tokio::pin!(write_all);
match write_all.poll(cx) {
Poll::Ready(Ok(_)) => {
this.writing = false;
this.msg += 1;
this.buffer.truncate(12);
return Poll::Ready(Ok(buf.len()));
}
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending => return Poll::Pending,
}
}
let total_len = self.max_data_len as usize + 12;
if self.buffer.len() + buf.len() <= total_len {
// accumulate into buffer, do nothing
self.buffer.extend(buf);
Poll::Ready(Ok(buf.len()))
} else {
// fill in the rest of the buffer, send PDU,
// and leave out the rest for subsequent writes
let buf = &buf[..total_len - self.buffer.len()];
self.buffer.extend(buf);
debug_assert_eq!(self.buffer.len(), total_len);
setup_pdata_header(&mut self.buffer, false);
let this = self.get_mut();
let buffer = &this.buffer;
let mut stream = Pin::new(&mut this.stream);
// Each call to `poll_write` may or may not write the whole of `self.buffer`
let write_all = stream.write_all(buffer);
tokio::pin!(write_all);
match write_all.poll(cx) {
Poll::Ready(Ok(_)) => {
this.msg += 1;
this.buffer.truncate(12);
// Each call to `poll_write` on the underlying stream may or may not
// write the whole of `self.buffer`, therefore we need to keep track
// of how much we've written, this is done in `self.state`
match self.state {
WriteState::Ready => {
// If we're in ready state, we can prepare another PDU
let total_len = self.max_data_len as usize + 12;
if self.buffer.len() + buf.len() <= total_len {
// Still have space in `self.buffer`, accumulate into buffer
self.buffer.extend(buf);
Poll::Ready(Ok(buf.len()))
} else {
// `self.buffer` is full, fill in the rest of the
// buffer, prepare to send PDU
let slice = &buf[..total_len - self.buffer.len()];
self.buffer.extend(slice);
debug_assert_eq!(self.buffer.len(), total_len);
setup_pdata_header(&mut self.buffer, false);
let this = self.get_mut();
// Attempt to send PDU on wire
match Pin::new(&mut this.stream).poll_write(cx, &this.buffer) {
Poll::Ready(Ok(n)) => {
if n == this.buffer.len() {
// If we wrote the whole buffer, reset `self.buffer`
this.buffer.truncate(12);
Poll::Ready(Ok(slice.len()))
} else {
// Otherwise keep track of how much we wrote and change state to Writing
this.state = WriteState::Writing(n);
Poll::Pending
}
}
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
Poll::Pending => {
// Nothing was written yet, change state to writing at position 0
this.state = WriteState::Writing(0);
Poll::Pending
}
}
}
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
Poll::Pending => {
this.writing = true;
Poll::Pending
}
WriteState::Writing(pos) => {
// Continue writing to stream from current position
let buflen = self.buffer.len();
let this = self.get_mut();
match Pin::new(&mut this.stream).poll_write(cx, &this.buffer[pos..]) {
Poll::Ready(Ok(n)) => {
if (n + pos) == this.buffer.len() {
// If we wrote the whole buffer, reset `self.buffer` and change state back to ready
this.buffer.truncate(12);
this.state = WriteState::Ready;
Poll::Ready(Ok(buflen - 12))
} else {
// Otherwise add to current position
this.state = WriteState::Writing(n + pos);
Poll::Pending
}
}
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
Poll::Pending => Poll::Pending,
}
}
}
Expand Down

0 comments on commit 521ec2a

Please sign in to comment.