diff --git a/Cargo.lock b/Cargo.lock index 01e8d0d..06ffca7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4304,6 +4304,7 @@ dependencies = [ "nom 8.0.0-alpha2", "thiserror", "tokio", + "tracing", ] [[package]] diff --git a/crates/synd_kvsd_protocol/Cargo.toml b/crates/synd_kvsd_protocol/Cargo.toml index 1e33234..2458be4 100644 --- a/crates/synd_kvsd_protocol/Cargo.toml +++ b/crates/synd_kvsd_protocol/Cargo.toml @@ -21,6 +21,7 @@ futures = { workspace = true } nom = { workspace = true } thiserror = { workspace = true } tokio = { workspace = true, features = ["net", "time", "io-util"] } +tracing = { workspace = true } [dev-dependencies] tokio = { workspace = true, features = ["net", "time", "io-util", "macros", "rt-multi-thread"] } diff --git a/crates/synd_kvsd_protocol/src/connection.rs b/crates/synd_kvsd_protocol/src/connection.rs index fa9902f..e91ecdd 100644 --- a/crates/synd_kvsd_protocol/src/connection.rs +++ b/crates/synd_kvsd_protocol/src/connection.rs @@ -3,7 +3,7 @@ use std::{ time::Duration, }; -use bytes::{Buf as _, BytesMut}; +use bytes::{Buf, BytesMut}; use futures::TryFutureExt as _; use thiserror::Error; use tokio::{ @@ -11,8 +11,9 @@ use tokio::{ net::TcpStream, time::error::Elapsed, }; +use tracing::trace; -use crate::message::{Cursor, FrameError, Message, MessageError, MessageFrames}; +use crate::message::{FrameError, Message, MessageError, ParseError, Parser}; #[derive(Error, Debug)] pub enum ConnectionError { @@ -24,6 +25,8 @@ pub enum ConnectionError { ParseMessageFrames { source: FrameError }, #[error("read message io: {source}")] ReadMessageIo { source: io::Error }, + #[error("parse message: {0}")] + ParseMessage(#[from] ParseError), #[error("connection reset by peer")] ResetByPeer, #[error("write message: {source}")] @@ -35,14 +38,6 @@ impl ConnectionError { ConnectionError::ReadTimeout(elapsed) } - fn read_message_frames(source: MessageError) -> Self { - ConnectionError::ReadMessageFrames { source } - } - - fn parse_message_frames(source: FrameError) -> Self { - ConnectionError::ParseMessageFrames { source } - } - fn read_message_io(source: io::Error) -> Self { ConnectionError::ReadMessageIo { source } } @@ -102,54 +97,36 @@ where } pub async fn read_message(&mut self) -> Result, ConnectionError> { - match self.read_message_frames().await? { - Some(message_frames) => Ok(Some( - Message::parse(message_frames).map_err(ConnectionError::read_message_frames)?, - )), - None => Ok(None), - } - } - - async fn read_message_frames(&mut self) -> Result, ConnectionError> { loop { - if let Some(message_frames) = self.parse_message_frames()? { - return Ok(Some(message_frames)); - } - - if 0 == self - .stream - .read_buf(&mut self.buffer) - .await - .map_err(ConnectionError::read_message_io)? - { - return if self.buffer.is_empty() { - Ok(None) - } else { - Err(ConnectionError::ResetByPeer) - }; - } - } - } - - fn parse_message_frames(&mut self) -> Result, ConnectionError> { - use FrameError::Incomplete; - - let mut cursor = Cursor::new(&self.buffer[..]); - - match MessageFrames::check_parse(&mut cursor) { - Ok(()) => { - #[allow(clippy::cast_possible_truncation)] - let len = cursor.position() as usize; - cursor.set_position(0); - let message_frames = MessageFrames::parse(&mut cursor) - .map_err(ConnectionError::parse_message_frames)?; - self.buffer.advance(len); - - Ok(Some(message_frames)) + let input = &self.buffer[..]; + match Parser::new().parse(input) { + Ok((remain, message)) => { + let consumed = input.len() - remain.len(); + self.buffer.advance(consumed); + return Ok(Some(message)); + } + Err(ParseError::Incomplete) => { + let read = self + .stream + .read_buf(&mut self.buffer) + .await + .map_err(ConnectionError::read_message_io)?; + trace!( + bytes = read, + buffer = self.buffer.len(), + "read from connection" + ); + + if read == 0 { + return if self.buffer.is_empty() { + Ok(None) + } else { + Err(ConnectionError::ResetByPeer) + }; + } + } + Err(err) => return Err(ConnectionError::from(err)), } - Err(Incomplete) => Ok(None), - // TODO: define distinct error - Err(e) => Err(ConnectionError::parse_message_frames(e)), } } } @@ -166,7 +143,7 @@ mod tests { let buf_size = 1024; let (read, write) = tokio::io::duplex(buf_size); - let (mut _read, mut write) = ( + let (mut read, mut write) = ( Connection::new(read, buf_size), Connection::new(write, buf_size), ); @@ -174,8 +151,7 @@ mod tests { for message in messages { write.write_message(message.clone()).await.unwrap(); - // TODO: enable - // assert_eq!(read.read_message().await.unwrap(), Some(message)); + assert_eq!(read.read_message().await.unwrap(), Some(message)); } } } diff --git a/crates/synd_kvsd_protocol/src/message/cursor.rs b/crates/synd_kvsd_protocol/src/message/cursor.rs deleted file mode 100644 index 490af87..0000000 --- a/crates/synd_kvsd_protocol/src/message/cursor.rs +++ /dev/null @@ -1,108 +0,0 @@ -use std::io; - -use bytes::Buf; - -use crate::message::{spec, FrameError}; - -pub(crate) struct Cursor<'a> { - cursor: io::Cursor<&'a [u8]>, -} - -impl<'a> Cursor<'a> { - pub(crate) fn new(buf: &'a [u8]) -> Self { - Self { - cursor: io::Cursor::new(buf), - } - } - - pub(crate) fn position(&self) -> u64 { - self.cursor.position() - } - - pub(crate) fn set_position(&mut self, pos: u64) { - self.cursor.set_position(pos); - } - - pub(super) fn skip(&mut self, n: usize) -> Result<(), FrameError> { - if self.cursor.remaining() < n { - Err(FrameError::Incomplete) - } else { - self.cursor.advance(n); - Ok(()) - } - } - - pub(super) fn remaining(&self) -> usize { - self.cursor.remaining() - } - - pub(super) fn chunk(&self) -> &[u8] { - self.cursor.chunk() - } - - pub(super) fn u8(&mut self) -> Result { - if self.cursor.has_remaining() { - Ok(self.cursor.get_u8()) - } else { - Err(FrameError::Incomplete) - } - } - - pub(super) fn u64(&mut self) -> Result { - let line = self.line()?; - atoi::atoi::(line).ok_or_else(|| FrameError::Invalid("invalid u64 line".into())) - } - - /// Return the buffer up to the line delimiter. - /// If the line delimiter is not found within the buffer, return [`FrameError::Incomplete`]. - /// When the line delimiter is found, set the cursor position to the next position after the line delimiter - /// so that subsequent reads do not need to be aware of the line delimiter. - pub(super) fn line(&mut self) -> Result<&'a [u8], FrameError> { - let slice = *self.cursor.get_ref(); - #[allow(clippy::cast_possible_truncation)] - let start = self.cursor.position() as usize; - let end = slice.len() - (spec::DELIMITER.len() - 1); - - for i in start..end { - if &slice[i..i + spec::DELIMITER.len()] == spec::DELIMITER { - self.cursor.set_position((i + spec::DELIMITER.len()) as u64); - return Ok(&slice[start..i]); - } - } - - Err(FrameError::Incomplete) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn u8() { - let buf = [128, 129]; - let mut cursor = Cursor::new(&buf[..]); - assert_eq!(cursor.u8(), Ok(128)); - assert_eq!(cursor.u8(), Ok(129)); - assert_eq!(cursor.u8(), Err(FrameError::Incomplete)); - assert_eq!(cursor.u8(), Err(FrameError::Incomplete)); - } - - #[test] - fn line() { - let buf = [b'x', b'x', b'\r', b'\n', b'y']; - let mut cursor = Cursor::new(&buf[..]); - assert_eq!(cursor.line(), Ok([b'x', b'x'].as_slice())); - assert_eq!(cursor.line(), Err(FrameError::Incomplete)); - assert_eq!(cursor.line(), Err(FrameError::Incomplete)); - } - - #[test] - fn skip() { - let buf = [b'_', b'_', b'a']; - let mut cursor = Cursor::new(&buf[..]); - assert_eq!(cursor.skip(2), Ok(())); - assert_eq!(cursor.u8(), Ok(b'a')); - assert_eq!(cursor.skip(1), Err(FrameError::Incomplete)); - } -} diff --git a/crates/synd_kvsd_protocol/src/message/frame.rs b/crates/synd_kvsd_protocol/src/message/frame.rs index 9f7aeb8..d478bd0 100644 --- a/crates/synd_kvsd_protocol/src/message/frame.rs +++ b/crates/synd_kvsd_protocol/src/message/frame.rs @@ -1,8 +1,9 @@ use std::io; use thiserror::Error; +use tokio::io::AsyncWriteExt; -use crate::message::{cursor::Cursor, ioext::MessageWriteExt, spec, MessageError, MessageType}; +use crate::message::{spec, MessageError, MessageType}; pub(in crate::message) mod prefix { pub(in crate::message) const MESSAGE_START: u8 = b'*'; @@ -30,6 +31,7 @@ pub enum FrameError { // Should support time type ? pub(crate) type Time = chrono::DateTime; +#[expect(dead_code)] #[derive(Debug, Clone, PartialEq)] pub(crate) enum Frame { MessageStart, @@ -42,34 +44,8 @@ pub(crate) enum Frame { } impl Frame { - fn check(src: &mut Cursor) -> Result<(), FrameError> { - match src.u8()? { - prefix::MESSAGE_TYPE => { - src.u8()?; - Ok(()) - } - prefix::STRING => { - #[allow(clippy::cast_possible_truncation)] - let len = src.u64()? as usize; - src.skip(len + spec::DELIMITER.len()) - } - prefix::BYTES => { - #[allow(clippy::cast_possible_truncation)] - let len = src.u64()? as usize; - // skip bytes length + delimiter - src.skip(len + spec::DELIMITER.len()) - } - prefix::TIME => { - src.line()?; - Ok(()) - } - prefix::NULL => Ok(()), - unexpected => Err(FrameError::Invalid(format!( - "unexpected prefix: {unexpected}" - ))), - } - } - + // TODO: remove + /* fn read(src: &mut Cursor) -> Result { match src.u8()? { prefix::MESSAGE_START => Ok(Frame::MessageStart), @@ -123,10 +99,11 @@ impl Frame { _ => unreachable!(), } } + */ pub(crate) async fn write(self, mut writer: W) -> Result<(), io::Error> where - W: MessageWriteExt, + W: AsyncWriteExt + Unpin, { match self { Frame::MessageStart => writer.write_u8(prefix::MESSAGE_START).await, @@ -146,7 +123,7 @@ impl Frame { } Frame::Bytes(val) => { writer.write_u8(prefix::BYTES).await?; - writer.write_u64m(val.len() as u64).await?; + writer.write_u64(val.len() as u64).await?; writer.write_all(val.as_ref()).await?; writer.write_all(spec::DELIMITER).await } @@ -183,82 +160,7 @@ impl MessageFrames { MessageFrames(v) } - pub(crate) fn check_parse(src: &mut Cursor) -> Result<(), FrameError> { - let frames_len = MessageFrames::frames_len(src)?; - - for _ in 0..frames_len { - Frame::check(src)?; - } - - Ok(()) - } - - pub(crate) fn parse(src: &mut Cursor) -> Result { - #[allow(clippy::cast_possible_truncation)] - let frames_len = (MessageFrames::frames_len(src)? - 1) as usize; - - let message_type = match Frame::read(src) { - Ok(Frame::MessageType(mt)) => mt, - Ok(frame) => return Err(FrameError::Invalid(format!("invalid frame {frame:?}"))), - Err(err) => return Err(err), - }; - - let mut frames = MessageFrames::new(message_type, frames_len); - - for _ in 0..frames_len { - frames.0.push(Frame::read(src)?); - } - - Ok(frames) - } - - fn frames_len(src: &mut Cursor) -> Result { - if src.u8()? != prefix::MESSAGE_START { - return Err(FrameError::Invalid("message frames prefix expected".into())); - } - if src.u8()? != prefix::FRAME_LENGTH { - return Err(FrameError::Invalid("message frame length expected".into())); - } - src.u64() - } - pub(super) fn push_string(&mut self, s: impl Into) { self.0.push(Frame::String(s.into())); } } - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - async fn write_read_u64() { - for val in [0, 1, 1024, u64::MAX] { - let mut buf = Vec::new(); - buf.write_u64m(val).await.unwrap(); - let mut cursor = Cursor::new(&buf); - assert_eq!(cursor.u64(), Ok(val)); - } - } - - #[tokio::test] - async fn frame_message_type() { - let types = vec![ - MessageType::Ping, - MessageType::Authenticate, - MessageType::Success, - MessageType::Fail, - MessageType::Set, - MessageType::Get, - MessageType::Delete, - ]; - - for ty in types { - let mut buf = Vec::new(); - let frame = Frame::MessageType(ty); - frame.write(&mut buf).await.unwrap(); - let mut cursor = Cursor::new(&buf); - assert_eq!(Frame::read(&mut cursor), Ok(Frame::MessageType(ty))); - } - } -} diff --git a/crates/synd_kvsd_protocol/src/message/ioext.rs b/crates/synd_kvsd_protocol/src/message/ioext.rs deleted file mode 100644 index 4becf21..0000000 --- a/crates/synd_kvsd_protocol/src/message/ioext.rs +++ /dev/null @@ -1,20 +0,0 @@ -use tokio::io::AsyncWriteExt; - -use crate::message::spec; - -pub(crate) trait MessageWriteExt: AsyncWriteExt + Unpin { - async fn write_u64m(&mut self, val: u64) -> std::io::Result<()> { - use std::io::Write; - - // for write u64::MAX - let mut buf = [0u8; 20]; - let mut buf = std::io::Cursor::new(&mut buf[..]); - write!(&mut buf, "{val}")?; - - let pos: usize = buf.position().try_into().unwrap(); - self.write_all(&buf.get_ref()[..pos]).await?; - self.write_all(spec::DELIMITER).await - } -} - -impl MessageWriteExt for T where T: AsyncWriteExt + Unpin {} diff --git a/crates/synd_kvsd_protocol/src/message/mod.rs b/crates/synd_kvsd_protocol/src/message/mod.rs index 0617739..678796c 100644 --- a/crates/synd_kvsd_protocol/src/message/mod.rs +++ b/crates/synd_kvsd_protocol/src/message/mod.rs @@ -1,20 +1,16 @@ mod frame; pub(crate) use frame::{FrameError, MessageFrames}; -mod cursor; -pub(crate) use cursor::Cursor; -mod ioext; -pub(crate) use ioext::MessageWriteExt; mod parse; +pub(crate) use parse::{ParseError, Parser}; mod payload; pub use payload::authenticate::Authenticate; +use tokio::io::AsyncWriteExt; mod spec; use std::io; use thiserror::Error; -use crate::message::parse::Parse; - #[derive(Error, Debug, PartialEq, Eq)] pub enum MessageError { #[error("unknown message type {message_type}")] @@ -76,29 +72,9 @@ impl From for MessageFrames { } impl Message { - pub(crate) fn parse(frames: MessageFrames) -> Result { - let mut parse = Parse::new(frames); - // skip message_start and frame_length - parse.skip(2); - let message_type = parse.message_type().ok_or(MessageError::ParseFrame { - message: "message type not found", - })?; - - let message = match message_type { - MessageType::Authenticate => Message::Authenticate( - Authenticate::parse_frames(&mut parse) - .map_err(|_| MessageError::ParseFrame { message: "TODO" })?, - ), - // TODO: impl - _ => unimplemented!(), - }; - - Ok(message) - } - pub(crate) async fn write(self, mut writer: W) -> Result<(), io::Error> where - W: MessageWriteExt, + W: AsyncWriteExt + Unpin, { let frames: MessageFrames = self.into(); diff --git a/crates/synd_kvsd_protocol/src/message/parse.rs b/crates/synd_kvsd_protocol/src/message/parse.rs index c4ba832..a3a03eb 100644 --- a/crates/synd_kvsd_protocol/src/message/parse.rs +++ b/crates/synd_kvsd_protocol/src/message/parse.rs @@ -1,17 +1,17 @@ -use std::vec; +use std::string::FromUtf8Error; use thiserror::Error; -use crate::message::{frame::Frame, Message, MessageError, MessageFrames, MessageType}; +use crate::message::{Authenticate, Message, MessageError, MessageType}; #[derive(Error, Debug)] -pub(super) enum ParseError { +pub enum ParseError { #[error("end of stream")] EndOfStream, - #[error("unexpecte frame: {frame:?}")] - UnexpectedFrame { frame: Frame }, #[error("invalid message type: {0}")] InvalidMessageType(#[from] MessageError), + #[error("invalid utf8: {0}")] + InvalidUtf8(#[from] FromUtf8Error), #[error("expect frame: {0}")] Expect(&'static str), #[error("incomplete")] @@ -29,87 +29,14 @@ impl ParseError { } } -pub(super) struct Parse { - frames: vec::IntoIter, -} -impl Parse { - pub(super) fn new(message_frames: MessageFrames) -> Self { - Self { - frames: message_frames.into_iter(), - } - } - - pub(super) fn skip(&mut self, n: usize) { - for _ in 0..n { - self.frames.next(); - } - } - - pub(super) fn message_type(&mut self) -> Option { - self.next().ok().and_then(|frame| match frame { - Frame::MessageType(mt) => Some(mt), - _ => None, - }) - } +pub(crate) struct Parser; - pub(crate) fn next_string(&mut self) -> Result { - match self.next()? { - Frame::String(s) => Ok(s), - frame => Err(ParseError::UnexpectedFrame { frame }), - } - } - /* - pub(crate) fn next_bytes(&mut self) -> Result, ParseError> { - match self.next()? { - Frame::Bytes(val) => Ok(val), - frame => Err(format!("unexpected frame. want bytes got {:?}", frame).into()), - } - } - - pub(crate) fn next_bytes_or_null(&mut self) -> Result>, ParseError> { - match self.next()? { - Frame::Bytes(val) => Ok(Some(val)), - Frame::Null => Ok(None), - frame => Err(format!("unexpected frame. want (bytes|null) got {:?}", frame).into()), - } - } - - pub(crate) fn next_time_or_null(&mut self) -> Result, ParseError> { - match self.next()? { - Frame::Time(time) => Ok(Some(time)), - Frame::Null => Ok(None), - frame => Err(format!("unexpected frame. want (time|null) got {:?} ", frame).into()), - } - } - - // Make sure that caller has parse all the frames. - pub(crate) fn expect_consumed(&mut self) -> Result<()> { - match self.next() { - Ok(frame) => Err(ErrorKind::NetworkFraming(format!( - "unparsed frame still remains {:?}", - frame - )) - .into()), - Err(ParseError::EndOfStream) => Ok(()), - Err(err) => Err(err.into()), - } - } - */ - - fn next(&mut self) -> Result { - self.frames.next().ok_or(ParseError::EndOfStream) - } -} - -pub(super) struct Parser; - -#[expect(dead_code)] impl Parser { - pub(super) fn new() -> Self { + pub(crate) fn new() -> Self { Self } - pub(super) fn parse(&self, input: &[u8]) -> Result { + pub(crate) fn parse<'a>(&self, input: &'a [u8]) -> Result<(&'a [u8], Message), ParseError> { let (input, _start) = parse::message_start(input).map_err(|err| ParseError::expect(err, "message_start"))?; @@ -122,7 +49,16 @@ impl Parser { match message_type { MessageType::Ping => todo!(), - MessageType::Authenticate => parse::authenticate(input).map(Message::Authenticate), + MessageType::Authenticate => { + let (input, (username, password)) = parse::authenticate(input) + .map_err(|err| ParseError::expect(err, "message authenticate"))?; + let username = String::from_utf8(username.to_vec())?; + let password = String::from_utf8(password.to_vec())?; + Ok(( + input, + Message::Authenticate(Authenticate::new(username, password)), + )) + } MessageType::Success => todo!(), MessageType::Fail => todo!(), MessageType::Set => todo!(), @@ -137,11 +73,11 @@ mod parse { bytes::streaming::{tag, take}, combinator::map, number::streaming::{be_u64, u8}, - sequence::{preceded, terminated}, + sequence::{pair, preceded, terminated}, IResult, Parser as _, }; - use crate::message::{frame::prefix, parse::ParseError, spec, Authenticate}; + use crate::message::{frame::prefix, spec}; pub(super) fn message_start(input: &[u8]) -> IResult<&[u8], &[u8]> { tag([prefix::MESSAGE_START].as_slice())(input) } @@ -154,8 +90,8 @@ mod parse { preceded(tag([prefix::MESSAGE_TYPE].as_slice()), u8).parse(input) } - pub(super) fn authenticate(_input: &[u8]) -> Result { - todo!() + pub(super) fn authenticate(input: &[u8]) -> IResult<&[u8], (&[u8], &[u8])> { + pair(string, string).parse(input) } fn delimiter(input: &[u8]) -> IResult<&[u8], ()> { diff --git a/crates/synd_kvsd_protocol/src/message/payload/authenticate.rs b/crates/synd_kvsd_protocol/src/message/payload/authenticate.rs index bf05fa3..8b20b25 100644 --- a/crates/synd_kvsd_protocol/src/message/payload/authenticate.rs +++ b/crates/synd_kvsd_protocol/src/message/payload/authenticate.rs @@ -1,7 +1,4 @@ -use crate::message::{ - parse::{Parse, ParseError}, - MessageFrames, MessageType, -}; +use crate::message::{MessageFrames, MessageType}; /// `Authenticate` is a message in which client requests the server /// to perform authentication process. @@ -23,14 +20,6 @@ impl Authenticate { password: password.into(), } } - - // TODO: impl in trait - pub(in crate::message) fn parse_frames(parse: &mut Parse) -> Result { - let username = parse.next_string()?; - let password = parse.next_string()?; - - Ok(Authenticate::new(username, password)) - } } impl From for MessageFrames {