diff --git a/Cargo.lock b/Cargo.lock index f33548a..55ed45a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -328,6 +328,15 @@ dependencies = [ "num-traits", ] +[[package]] +name = "atoi" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f28d99ec8bfea296261ca1af174f24225171fea9664ba9003cbebee704810528" +dependencies = [ + "num-traits", +] + [[package]] name = "autocfg" version = "1.1.0" @@ -2087,7 +2096,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "70f7f96eab66f057d0ce7139840edadacc114fbc978d41898550731b473186ad" dependencies = [ "async-trait", - "atoi", + "atoi 0.3.3", "backtrace", "bytes", "chrono", @@ -4298,9 +4307,10 @@ dependencies = [ name = "synd-kvsd-protocol" version = "0.1.0" dependencies = [ - "atoi", + "atoi 2.0.0", "bytes", "chrono", + "futures", "thiserror", "tokio", ] diff --git a/Cargo.toml b/Cargo.toml index c01bc89..3e33928 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,6 +29,7 @@ either = { version = "1.13.0" } fake = { version = "2.10.0", features = ["derive", "chrono"] } fdlimit = { version = "0.3.0", default-features = false } feed-rs = { version = "1.5", default-features = false } +futures = { version = "0.3.30" } futures-util = { version = "0.3.30", default-features = false } graphql_client = { version = "0.13.0", default-features = false } headers = { version = "0.4.0" } diff --git a/crates/synd_kvsd_protocol/Cargo.toml b/crates/synd_kvsd_protocol/Cargo.toml index 146b844..361429b 100644 --- a/crates/synd_kvsd_protocol/Cargo.toml +++ b/crates/synd_kvsd_protocol/Cargo.toml @@ -14,12 +14,12 @@ name = "synd-kvsd-protocol" version = "0.1.0" [dependencies] -# TODO: use latest -atoi = { version = "0.3.3" } +atoi = { version = "2.0.0" } bytes = { workspace = true } chrono = { workspace = true } +futures = { workspace = true } thiserror = { workspace = true } -tokio = { workspace = true, features = ["net"] } +tokio = { workspace = true, features = ["net", "time", "io-util"] } [lints] workspace = true diff --git a/crates/synd_kvsd_protocol/src/connection.rs b/crates/synd_kvsd_protocol/src/connection.rs index ef1506b..c687394 100644 --- a/crates/synd_kvsd_protocol/src/connection.rs +++ b/crates/synd_kvsd_protocol/src/connection.rs @@ -1,9 +1,10 @@ use std::{ - io::{self, Cursor}, + io::{self}, time::Duration, }; use bytes::{Buf as _, BytesMut}; +use futures::TryFutureExt as _; use thiserror::Error; use tokio::{ io::{AsyncRead, AsyncReadExt as _, AsyncWrite, BufWriter}, @@ -11,7 +12,7 @@ use tokio::{ time::error::Elapsed, }; -use crate::message::{FrameError, Message, MessageError, MessageFrames}; +use crate::message::{Cursor, FrameError, Message, MessageError, MessageFrames}; #[derive(Error, Debug)] pub enum ConnectionError { @@ -72,10 +73,9 @@ where &mut self, duration: Duration, ) -> Result, ConnectionError> { - match tokio::time::timeout(duration, self.read_message()).await { - Ok(read_result) => read_result, - Err(elapsed) => Err(ConnectionError::read_timeout(elapsed)), - } + tokio::time::timeout(duration, self.read_message()) + .map_err(ConnectionError::read_timeout) + .await? } pub async fn read_message(&mut self) -> Result, ConnectionError> { @@ -111,14 +111,14 @@ where fn parse_message_frames(&mut self) -> Result, ConnectionError> { use FrameError::Incomplete; - let mut buf = Cursor::new(&self.buffer[..]); + let mut cursor = Cursor::new(&self.buffer[..]); - match MessageFrames::check_parse(&mut buf) { + match MessageFrames::check_parse(&mut cursor) { Ok(()) => { #[allow(clippy::cast_possible_truncation)] - let len = buf.position() as usize; - buf.set_position(0); - let message_frames = MessageFrames::parse(&mut buf) + let len = cursor.position() as usize; + cursor.set_position(0); + let message_frames = MessageFrames::parse(&mut cursor) .map_err(ConnectionError::parse_message_frames)?; self.buffer.advance(len); diff --git a/crates/synd_kvsd_protocol/src/message/cursor.rs b/crates/synd_kvsd_protocol/src/message/cursor.rs new file mode 100644 index 0000000..0b01674 --- /dev/null +++ b/crates/synd_kvsd_protocol/src/message/cursor.rs @@ -0,0 +1,108 @@ +use std::io; + +use bytes::Buf; + +use crate::message::{spec, FrameError}; + +pub(crate) struct Cursor<'a> { + cursor: io::Cursor<&'a [u8]>, +} + +impl<'a> Cursor<'a> { + pub(crate) fn new(buf: &'a [u8]) -> Self { + Self { + cursor: io::Cursor::new(buf), + } + } + + pub(crate) fn position(&self) -> u64 { + self.cursor.position() + } + + pub(crate) fn set_position(&mut self, pos: u64) { + self.cursor.set_position(pos); + } + + pub(super) fn skip(&mut self, n: usize) -> Result<(), FrameError> { + if self.cursor.remaining() < n { + Err(FrameError::Incomplete) + } else { + self.cursor.advance(n); + Ok(()) + } + } + + pub(super) fn remaining(&self) -> usize { + self.cursor.remaining() + } + + pub(super) fn chunk(&self) -> &[u8] { + self.cursor.chunk() + } + + pub(super) fn u8(&mut self) -> Result { + if self.cursor.has_remaining() { + Ok(self.cursor.get_u8()) + } else { + Err(FrameError::Incomplete) + } + } + + pub(super) fn u64(&mut self) -> Result { + let line = self.line()?; + atoi::atoi::(line).ok_or_else(|| FrameError::Invalid("invalid u64".into())) + } + + /// Return the buffer up to the line delimiter. + /// If the line delimiter is not found within the buffer, return [`FrameError::Incomplete`]. + /// When the line delimiter is found, set the cursor position to the next position after the line delimiter + /// so that subsequent reads do not need to be aware of the line delimiter. + pub(super) fn line(&mut self) -> Result<&'a [u8], FrameError> { + let slice = *self.cursor.get_ref(); + #[allow(clippy::cast_possible_truncation)] + let start = self.cursor.position() as usize; + let end = slice.len() - (spec::DELIMITER.len() - 1); + + for i in start..end { + if &slice[i..i + spec::DELIMITER.len()] == spec::DELIMITER { + self.cursor.set_position((i + spec::DELIMITER.len()) as u64); + return Ok(&slice[start..i]); + } + } + + Err(FrameError::Incomplete) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn u8() { + let buf = [128, 129]; + let mut cursor = Cursor::new(&buf[..]); + assert_eq!(cursor.u8(), Ok(128)); + assert_eq!(cursor.u8(), Ok(129)); + assert_eq!(cursor.u8(), Err(FrameError::Incomplete)); + assert_eq!(cursor.u8(), Err(FrameError::Incomplete)); + } + + #[test] + fn line() { + let buf = [b'x', b'x', b'\r', b'\n', b'y']; + let mut cursor = Cursor::new(&buf[..]); + assert_eq!(cursor.line(), Ok([b'x', b'x'].as_slice())); + assert_eq!(cursor.line(), Err(FrameError::Incomplete)); + assert_eq!(cursor.line(), Err(FrameError::Incomplete)); + } + + #[test] + fn skip() { + let buf = [b'_', b'_', b'a']; + let mut cursor = Cursor::new(&buf[..]); + assert_eq!(cursor.skip(2), Ok(())); + assert_eq!(cursor.u8(), Ok(b'a')); + assert_eq!(cursor.skip(1), Err(FrameError::Incomplete)); + } +} diff --git a/crates/synd_kvsd_protocol/src/message/frame.rs b/crates/synd_kvsd_protocol/src/message/frame.rs index 9e0524e..d382f9d 100644 --- a/crates/synd_kvsd_protocol/src/message/frame.rs +++ b/crates/synd_kvsd_protocol/src/message/frame.rs @@ -1,9 +1,8 @@ -use bytes::Buf as _; use thiserror::Error; -use crate::message::{MessageError, MessageType}; +use crate::message::{cursor::Cursor, MessageError, MessageType}; -#[derive(Error, Debug)] +#[derive(Error, Debug, PartialEq, Eq)] pub enum FrameError { /// Not enough data is available to decode a message frames from buffer. #[error("incomplete")] @@ -28,58 +27,60 @@ pub(crate) enum Frame { } impl Frame { - fn check(src: &mut ByteCursor) -> Result<(), FrameError> { - match cursor::get_u8(src)? { - frameprefix::MESSAGE_TYPE => { - cursor::get_u8(src)?; + fn check(src: &mut Cursor) -> Result<(), FrameError> { + match src.u8()? { + prefix::MESSAGE_TYPE => { + src.u8()?; Ok(()) } - frameprefix::STRING => { - cursor::get_line(src)?; + prefix::STRING => { + src.line()?; Ok(()) } - frameprefix::BYTES => { + prefix::BYTES => { #[allow(clippy::cast_possible_truncation)] - let len = cursor::get_decimal(src)? as usize; + let len = src.u64()? as usize; // skip bytes length + delimiter - cursor::skip(src, len + 2) + src.skip(len + 2) } - frameprefix::TIME => { - cursor::get_line(src)?; + prefix::TIME => { + src.line()?; Ok(()) } - frameprefix::NULL => Ok(()), - _ => unreachable!(), + prefix::NULL => Ok(()), + unexpected => Err(FrameError::Invalid(format!( + "unexpected prefix: {unexpected}" + ))), } } - fn parse(src: &mut ByteCursor) -> Result { - match cursor::get_u8(src)? { - frameprefix::MESSAGE_TYPE => { + fn parse(src: &mut Cursor) -> Result { + match src.u8()? { + prefix::MESSAGE_TYPE => { Err(FrameError::Invalid("unexpected message type frame".into())) } - frameprefix::STRING => { - let line = cursor::get_line(src)?.to_vec(); + prefix::STRING => { + let line = src.line()?.to_vec(); let string = String::from_utf8(line).map_err(|e| FrameError::Invalid(e.to_string()))?; Ok(Frame::String(string)) } - frameprefix::BYTES => { + prefix::BYTES => { #[allow(clippy::cast_possible_truncation)] - let len = cursor::get_decimal(src)? as usize; + let len = src.u64()? as usize; let n = len + 2; if src.remaining() < n { return Err(FrameError::Incomplete); } let value = Vec::from(&src.chunk()[..len]); - cursor::skip(src, n)?; + src.skip(n)?; Ok(Frame::Bytes(value)) } - frameprefix::TIME => { + prefix::TIME => { use chrono::{DateTime, Utc}; - let line = cursor::get_line(src)?.to_vec(); + let line = src.line()?.to_vec(); let string = String::from_utf8(line).map_err(|e| FrameError::Invalid(e.to_string()))?; Ok(Frame::Time( @@ -88,13 +89,13 @@ impl Frame { .unwrap(), )) } - frameprefix::NULL => Ok(Frame::Null), + prefix::NULL => Ok(Frame::Null), _ => unreachable!(), } } } -mod frameprefix { +mod prefix { pub(super) const MESSAGE_FRAMES: u8 = b'*'; pub(super) const MESSAGE_TYPE: u8 = b'#'; pub(super) const STRING: u8 = b'+'; @@ -106,8 +107,6 @@ mod frameprefix { #[derive(Clone, PartialEq, Debug)] pub(crate) struct MessageFrames(Vec); -type ByteCursor<'a> = std::io::Cursor<&'a [u8]>; - impl IntoIterator for MessageFrames { type Item = Frame; type IntoIter = std::vec::IntoIter; @@ -124,8 +123,8 @@ impl MessageFrames { Self(v) } - pub(crate) fn check_parse(src: &mut ByteCursor) -> Result<(), FrameError> { - let frames_len = MessageFrames::ensure_prefix_format(src)?; + pub(crate) fn check_parse(src: &mut Cursor) -> Result<(), FrameError> { + let frames_len = MessageFrames::frames_len(src)?; for _ in 0..frames_len { Frame::check(src)?; @@ -134,14 +133,14 @@ impl MessageFrames { Ok(()) } - pub(crate) fn parse(src: &mut ByteCursor) -> Result { + pub(crate) fn parse(src: &mut Cursor) -> Result { #[allow(clippy::cast_possible_truncation)] - let frames_len = (MessageFrames::ensure_prefix_format(src)? - 1) as usize; + let frames_len = (MessageFrames::frames_len(src)? - 1) as usize; - if cursor::get_u8(src)? != frameprefix::MESSAGE_TYPE { + if src.u8()? != prefix::MESSAGE_TYPE { return Err(FrameError::Invalid("message type expected".into())); } - let message_type = cursor::get_u8(src)?; + let message_type = src.u8()?; let message_type = MessageType::try_from(message_type).map_err(FrameError::InvalidMessageType)?; @@ -154,59 +153,10 @@ impl MessageFrames { Ok(frames) } - fn ensure_prefix_format(src: &mut ByteCursor) -> Result { - if cursor::get_u8(src)? != frameprefix::MESSAGE_FRAMES { + fn frames_len(src: &mut Cursor) -> Result { + if src.u8()? != prefix::MESSAGE_FRAMES { return Err(FrameError::Invalid("message frames prefix expected".into())); } - - cursor::get_decimal(src) - } -} - -/// cursor utilities. -// TODO: impl to ByteCursor -mod cursor { - use bytes::Buf as _; - - use super::*; - use crate::message::spec; - - pub(super) fn get_u8(src: &mut ByteCursor) -> Result { - if !src.has_remaining() { - return Err(FrameError::Incomplete); - } - Ok(src.get_u8()) - } - - pub(super) fn skip(src: &mut ByteCursor, n: usize) -> Result<(), FrameError> { - if src.remaining() < n { - return Err(FrameError::Incomplete); - } - src.advance(n); - Ok(()) - } - - pub(super) fn get_decimal(src: &mut ByteCursor) -> Result { - let line = get_line(src)?; - - atoi::atoi::(line) - .ok_or_else(|| FrameError::Invalid("invalid protocol decimal format".into())) - } - - pub(super) fn get_line<'a>(src: &'a mut ByteCursor) -> Result<&'a [u8], FrameError> { - #[allow(clippy::cast_possible_truncation)] - let start = src.position() as usize; - let end = src.get_ref().len() - 1; - - for i in start..end { - if src.get_ref()[i] == spec::DELIMITER[0] && src.get_ref()[i + 1] == spec::DELIMITER[1] - { - src.set_position((i + 2) as u64); - - return Ok(&src.get_ref()[start..i]); - } - } - - Err(FrameError::Incomplete) + src.u64() } } diff --git a/crates/synd_kvsd_protocol/src/message/mod.rs b/crates/synd_kvsd_protocol/src/message/mod.rs index ad28412..fdc62eb 100644 --- a/crates/synd_kvsd_protocol/src/message/mod.rs +++ b/crates/synd_kvsd_protocol/src/message/mod.rs @@ -4,6 +4,8 @@ pub(crate) use frame::{FrameError, MessageFrames}; mod authenticate; use authenticate::Authenticate; +mod cursor; +pub(crate) use cursor::Cursor; mod parse; mod spec; @@ -11,7 +13,7 @@ use thiserror::Error; use crate::message::parse::Parse; -#[derive(Error, Debug)] +#[derive(Error, Debug, PartialEq, Eq)] pub enum MessageError { #[error("unknown message type {message_type}")] UnknownMessageType { message_type: u8 },