diff --git a/crates/synd_kvsd_protocol/src/connection.rs b/crates/synd_kvsd_protocol/src/connection.rs index c687394..2523f57 100644 --- a/crates/synd_kvsd_protocol/src/connection.rs +++ b/crates/synd_kvsd_protocol/src/connection.rs @@ -7,12 +7,12 @@ use bytes::{Buf as _, BytesMut}; use futures::TryFutureExt as _; use thiserror::Error; use tokio::{ - io::{AsyncRead, AsyncReadExt as _, AsyncWrite, BufWriter}, + io::{AsyncRead, AsyncReadExt as _, AsyncWrite, AsyncWriteExt as _, BufWriter}, net::TcpStream, time::error::Elapsed, }; -use crate::message::{Cursor, FrameError, Message, MessageError, MessageFrames}; +use crate::message::{Cursor, Frame, FrameError, Message, MessageError, MessageFrames}; #[derive(Error, Debug)] pub enum ConnectionError { @@ -26,6 +26,8 @@ pub enum ConnectionError { ReadMessageIo { source: io::Error }, #[error("connection reset by peer")] ResetByPeer, + #[error("write message frame: {source}")] + WriteMessageFrame { source: io::Error }, } impl ConnectionError { @@ -44,6 +46,10 @@ impl ConnectionError { fn read_message_io(source: io::Error) -> Self { ConnectionError::ReadMessageIo { source } } + + fn write_message_frame(source: io::Error) -> Self { + ConnectionError::WriteMessageFrame { source } + } } pub struct Connection { @@ -64,6 +70,35 @@ where } } +impl Connection +where + Stream: AsyncWrite + Unpin, +{ + pub async fn write_message(&mut self, message: Message) -> Result<(), ConnectionError> { + let frames: MessageFrames = message.into(); + + // TODO: impl in Into + // self.stream.write_u8(prefix::MESSAGE_FRAMES).await?; + // self.write_decimal(frames.len()).await?; + + for frame in frames { + self.write_frame(frame).await?; + } + + self.stream + .flush() + .await + .map_err(ConnectionError::write_message_frame) + } + + async fn write_frame(&mut self, frame: Frame) -> Result<(), ConnectionError> { + frame + .write(&mut self.stream) + .await + .map_err(ConnectionError::write_message_frame) + } +} + impl Connection where Stream: AsyncRead + Unpin, diff --git a/crates/synd_kvsd_protocol/src/message/authenticate.rs b/crates/synd_kvsd_protocol/src/message/authenticate.rs index 96977e3..ce977de 100644 --- a/crates/synd_kvsd_protocol/src/message/authenticate.rs +++ b/crates/synd_kvsd_protocol/src/message/authenticate.rs @@ -1,4 +1,7 @@ -use crate::message::parse::{Parse, ParseError}; +use crate::message::{ + parse::{Parse, ParseError}, + MessageFrames, +}; /// `Authenticate` is a message in which client requests the server /// to perform authentication process. @@ -29,3 +32,9 @@ impl Authenticate { Ok(Authenticate::new(username, password)) } } + +impl From for MessageFrames { + fn from(_m: Authenticate) -> Self { + todo!() + } +} diff --git a/crates/synd_kvsd_protocol/src/message/frame.rs b/crates/synd_kvsd_protocol/src/message/frame.rs index d382f9d..56230c6 100644 --- a/crates/synd_kvsd_protocol/src/message/frame.rs +++ b/crates/synd_kvsd_protocol/src/message/frame.rs @@ -1,6 +1,9 @@ +use std::io; + use thiserror::Error; +use tokio::io::AsyncWriteExt; -use crate::message::{cursor::Cursor, MessageError, MessageType}; +use crate::message::{cursor::Cursor, spec, MessageError, MessageType}; #[derive(Error, Debug, PartialEq, Eq)] pub enum FrameError { @@ -74,6 +77,8 @@ impl Frame { } let value = Vec::from(&src.chunk()[..len]); + // TODO: debug assert delimiter + src.skip(n)?; Ok(Frame::Bytes(value)) @@ -93,6 +98,52 @@ impl Frame { _ => unreachable!(), } } + + pub(crate) async fn write(self, mut writer: W) -> Result<(), io::Error> + where + W: AsyncWriteExt + Unpin, + { + match self { + Frame::MessageType(mt) => { + writer.write_u8(prefix::MESSAGE_TYPE).await?; + writer.write_u8(mt.into()).await + } + Frame::String(val) => { + writer.write_u8(prefix::STRING).await?; + writer.write_all(val.as_bytes()).await?; + writer.write_all(spec::DELIMITER).await + } + Frame::Bytes(val) => { + writer.write_u8(prefix::BYTES).await?; + Frame::write_u64(val.len() as u64, &mut writer).await?; + writer.write_all(val.as_ref()).await?; + writer.write_all(spec::DELIMITER).await + } + Frame::Time(val) => { + writer.write_u8(prefix::TIME).await?; + writer.write_all(val.to_rfc3339().as_bytes()).await?; + writer.write_all(spec::DELIMITER).await + } + Frame::Null => writer.write_u8(prefix::NULL).await, + } + } + + // TODO: refactor + async fn write_u64(val: u64, mut writer: W) -> io::Result<()> + where + W: AsyncWriteExt + Unpin, + { + use std::io::Write; + + let mut buf = [0u8; 12]; + let mut buf = std::io::Cursor::new(&mut buf[..]); + write!(&mut buf, "{val}")?; + + #[allow(clippy::cast_possible_truncation)] + let pos = buf.position() as usize; + writer.write_all(&buf.get_ref()[..pos]).await?; + writer.write_all(spec::DELIMITER).await + } } mod prefix { diff --git a/crates/synd_kvsd_protocol/src/message/mod.rs b/crates/synd_kvsd_protocol/src/message/mod.rs index fdc62eb..ac35efe 100644 --- a/crates/synd_kvsd_protocol/src/message/mod.rs +++ b/crates/synd_kvsd_protocol/src/message/mod.rs @@ -1,5 +1,5 @@ mod frame; -pub(crate) use frame::{FrameError, MessageFrames}; +pub(crate) use frame::{Frame, FrameError, MessageFrames}; mod authenticate; use authenticate::Authenticate; @@ -65,6 +65,14 @@ pub enum Message { // Delete(Delete), } +impl From for MessageFrames { + fn from(message: Message) -> Self { + match message { + Message::Authenticate(m) => m.into(), + } + } +} + impl Message { pub(crate) fn parse(frames: MessageFrames) -> Result { let mut parse = Parse::new(frames);