Skip to content

Commit

Permalink
feat(kvsd_protocol): impl MessageFrame write
Browse files Browse the repository at this point in the history
  • Loading branch information
ymgyt committed Oct 14, 2024
1 parent cff8743 commit 935bd53
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 5 deletions.
39 changes: 37 additions & 2 deletions crates/synd_kvsd_protocol/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ use bytes::{Buf as _, BytesMut};
use futures::TryFutureExt as _;
use thiserror::Error;
use tokio::{
io::{AsyncRead, AsyncReadExt as _, AsyncWrite, BufWriter},
io::{AsyncRead, AsyncReadExt as _, AsyncWrite, AsyncWriteExt as _, BufWriter},
net::TcpStream,
time::error::Elapsed,
};

use crate::message::{Cursor, FrameError, Message, MessageError, MessageFrames};
use crate::message::{Cursor, Frame, FrameError, Message, MessageError, MessageFrames};

#[derive(Error, Debug)]
pub enum ConnectionError {
Expand All @@ -26,6 +26,8 @@ pub enum ConnectionError {
ReadMessageIo { source: io::Error },
#[error("connection reset by peer")]
ResetByPeer,
#[error("write message frame: {source}")]
WriteMessageFrame { source: io::Error },
}

impl ConnectionError {
Expand All @@ -44,6 +46,10 @@ impl ConnectionError {
fn read_message_io(source: io::Error) -> Self {
ConnectionError::ReadMessageIo { source }
}

fn write_message_frame(source: io::Error) -> Self {
ConnectionError::WriteMessageFrame { source }
}
}

pub struct Connection<Stream = TcpStream> {
Expand All @@ -64,6 +70,35 @@ where
}
}

impl<Stream> Connection<Stream>
where
Stream: AsyncWrite + Unpin,
{
pub async fn write_message(&mut self, message: Message) -> Result<(), ConnectionError> {
let frames: MessageFrames = message.into();

// TODO: impl in Into<MessageFrames>
// self.stream.write_u8(prefix::MESSAGE_FRAMES).await?;
// self.write_decimal(frames.len()).await?;

for frame in frames {
self.write_frame(frame).await?;
}

self.stream
.flush()
.await
.map_err(ConnectionError::write_message_frame)
}

async fn write_frame(&mut self, frame: Frame) -> Result<(), ConnectionError> {
frame
.write(&mut self.stream)
.await
.map_err(ConnectionError::write_message_frame)
}
}

impl<Stream> Connection<Stream>
where
Stream: AsyncRead + Unpin,
Expand Down
11 changes: 10 additions & 1 deletion crates/synd_kvsd_protocol/src/message/authenticate.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use crate::message::parse::{Parse, ParseError};
use crate::message::{
parse::{Parse, ParseError},
MessageFrames,
};

/// `Authenticate` is a message in which client requests the server
/// to perform authentication process.
Expand Down Expand Up @@ -29,3 +32,9 @@ impl Authenticate {
Ok(Authenticate::new(username, password))
}
}

impl From<Authenticate> for MessageFrames {
fn from(_m: Authenticate) -> Self {
todo!()
}
}
53 changes: 52 additions & 1 deletion crates/synd_kvsd_protocol/src/message/frame.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use std::io;

use thiserror::Error;
use tokio::io::AsyncWriteExt;

use crate::message::{cursor::Cursor, MessageError, MessageType};
use crate::message::{cursor::Cursor, spec, MessageError, MessageType};

#[derive(Error, Debug, PartialEq, Eq)]
pub enum FrameError {
Expand Down Expand Up @@ -74,6 +77,8 @@ impl Frame {
}
let value = Vec::from(&src.chunk()[..len]);

// TODO: debug assert delimiter

src.skip(n)?;

Ok(Frame::Bytes(value))
Expand All @@ -93,6 +98,52 @@ impl Frame {
_ => unreachable!(),
}
}

pub(crate) async fn write<W>(self, mut writer: W) -> Result<(), io::Error>
where
W: AsyncWriteExt + Unpin,
{
match self {
Frame::MessageType(mt) => {
writer.write_u8(prefix::MESSAGE_TYPE).await?;
writer.write_u8(mt.into()).await
}
Frame::String(val) => {
writer.write_u8(prefix::STRING).await?;
writer.write_all(val.as_bytes()).await?;
writer.write_all(spec::DELIMITER).await
}
Frame::Bytes(val) => {
writer.write_u8(prefix::BYTES).await?;
Frame::write_u64(val.len() as u64, &mut writer).await?;
writer.write_all(val.as_ref()).await?;
writer.write_all(spec::DELIMITER).await
}
Frame::Time(val) => {
writer.write_u8(prefix::TIME).await?;
writer.write_all(val.to_rfc3339().as_bytes()).await?;
writer.write_all(spec::DELIMITER).await
}
Frame::Null => writer.write_u8(prefix::NULL).await,
}
}

// TODO: refactor
async fn write_u64<W>(val: u64, mut writer: W) -> io::Result<()>
where
W: AsyncWriteExt + Unpin,
{
use std::io::Write;

let mut buf = [0u8; 12];
let mut buf = std::io::Cursor::new(&mut buf[..]);
write!(&mut buf, "{val}")?;

#[allow(clippy::cast_possible_truncation)]
let pos = buf.position() as usize;
writer.write_all(&buf.get_ref()[..pos]).await?;
writer.write_all(spec::DELIMITER).await
}
}

mod prefix {
Expand Down
10 changes: 9 additions & 1 deletion crates/synd_kvsd_protocol/src/message/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
mod frame;
pub(crate) use frame::{FrameError, MessageFrames};
pub(crate) use frame::{Frame, FrameError, MessageFrames};

mod authenticate;
use authenticate::Authenticate;
Expand Down Expand Up @@ -65,6 +65,14 @@ pub enum Message {
// Delete(Delete),
}

impl From<Message> for MessageFrames {
fn from(message: Message) -> Self {
match message {
Message::Authenticate(m) => m.into(),
}
}
}

impl Message {
pub(crate) fn parse(frames: MessageFrames) -> Result<Message, MessageError> {
let mut parse = Parse::new(frames);
Expand Down

0 comments on commit 935bd53

Please sign in to comment.