Skip to content

Commit

Permalink
refactor(kvsd_protocol): make it possible to parse without check buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
ymgyt committed Oct 29, 2024
1 parent 014ae83 commit ba4488f
Show file tree
Hide file tree
Showing 9 changed files with 71 additions and 418 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/synd_kvsd_protocol/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ futures = { workspace = true }
nom = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true, features = ["net", "time", "io-util"] }
tracing = { workspace = true }

[dev-dependencies]
tokio = { workspace = true, features = ["net", "time", "io-util", "macros", "rt-multi-thread"] }
Expand Down
94 changes: 35 additions & 59 deletions crates/synd_kvsd_protocol/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,17 @@ use std::{
time::Duration,
};

use bytes::{Buf as _, BytesMut};
use bytes::{Buf, BytesMut};
use futures::TryFutureExt as _;
use thiserror::Error;
use tokio::{
io::{AsyncRead, AsyncReadExt as _, AsyncWrite, AsyncWriteExt, BufWriter},
net::TcpStream,
time::error::Elapsed,
};
use tracing::trace;

use crate::message::{Cursor, FrameError, Message, MessageError, MessageFrames};
use crate::message::{FrameError, Message, MessageError, ParseError, Parser};

#[derive(Error, Debug)]
pub enum ConnectionError {
Expand All @@ -24,6 +25,8 @@ pub enum ConnectionError {
ParseMessageFrames { source: FrameError },
#[error("read message io: {source}")]
ReadMessageIo { source: io::Error },
#[error("parse message: {0}")]
ParseMessage(#[from] ParseError),
#[error("connection reset by peer")]
ResetByPeer,
#[error("write message: {source}")]
Expand All @@ -35,14 +38,6 @@ impl ConnectionError {
ConnectionError::ReadTimeout(elapsed)
}

fn read_message_frames(source: MessageError) -> Self {
ConnectionError::ReadMessageFrames { source }
}

fn parse_message_frames(source: FrameError) -> Self {
ConnectionError::ParseMessageFrames { source }
}

fn read_message_io(source: io::Error) -> Self {
ConnectionError::ReadMessageIo { source }
}
Expand Down Expand Up @@ -102,54 +97,36 @@ where
}

pub async fn read_message(&mut self) -> Result<Option<Message>, ConnectionError> {
match self.read_message_frames().await? {
Some(message_frames) => Ok(Some(
Message::parse(message_frames).map_err(ConnectionError::read_message_frames)?,
)),
None => Ok(None),
}
}

async fn read_message_frames(&mut self) -> Result<Option<MessageFrames>, ConnectionError> {
loop {
if let Some(message_frames) = self.parse_message_frames()? {
return Ok(Some(message_frames));
}

if 0 == self
.stream
.read_buf(&mut self.buffer)
.await
.map_err(ConnectionError::read_message_io)?
{
return if self.buffer.is_empty() {
Ok(None)
} else {
Err(ConnectionError::ResetByPeer)
};
}
}
}

fn parse_message_frames(&mut self) -> Result<Option<MessageFrames>, ConnectionError> {
use FrameError::Incomplete;

let mut cursor = Cursor::new(&self.buffer[..]);

match MessageFrames::check_parse(&mut cursor) {
Ok(()) => {
#[allow(clippy::cast_possible_truncation)]
let len = cursor.position() as usize;
cursor.set_position(0);
let message_frames = MessageFrames::parse(&mut cursor)
.map_err(ConnectionError::parse_message_frames)?;
self.buffer.advance(len);

Ok(Some(message_frames))
let input = &self.buffer[..];
match Parser::new().parse(input) {
Ok((remain, message)) => {
let consumed = input.len() - remain.len();
self.buffer.advance(consumed);
return Ok(Some(message));
}
Err(ParseError::Incomplete) => {
let read = self
.stream
.read_buf(&mut self.buffer)
.await
.map_err(ConnectionError::read_message_io)?;
trace!(
bytes = read,
buffer = self.buffer.len(),
"read from connection"
);

if read == 0 {
return if self.buffer.is_empty() {
Ok(None)
} else {
Err(ConnectionError::ResetByPeer)
};
}
}
Err(err) => return Err(ConnectionError::from(err)),
}
Err(Incomplete) => Ok(None),
// TODO: define distinct error
Err(e) => Err(ConnectionError::parse_message_frames(e)),
}
}
}
Expand All @@ -166,16 +143,15 @@ mod tests {

let buf_size = 1024;
let (read, write) = tokio::io::duplex(buf_size);
let (mut _read, mut write) = (
let (mut read, mut write) = (
Connection::new(read, buf_size),
Connection::new(write, buf_size),
);

for message in messages {
write.write_message(message.clone()).await.unwrap();

// TODO: enable
// assert_eq!(read.read_message().await.unwrap(), Some(message));
assert_eq!(read.read_message().await.unwrap(), Some(message));
}
}
}
108 changes: 0 additions & 108 deletions crates/synd_kvsd_protocol/src/message/cursor.rs

This file was deleted.

Loading

0 comments on commit ba4488f

Please sign in to comment.