Skip to content

Commit

Permalink
Manually implement rkyv for gateway messages
Browse files Browse the repository at this point in the history
  • Loading branch information
novacrazy committed Apr 11, 2024
1 parent cbf459c commit 6b1f46a
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/api/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ pub trait Command: sealed::Sealed {
/// Body to be serialized as request body or query parameters (if GET)
fn body(&self) -> &Self::Body;

/// Used to collect the [`Result`](Self::Result) from an arbitrary [`Stream`] of items.
/// Used to collect the [`Result`](Self::Result) from an arbitrary [`Stream`](futures::Stream) of items.
fn collect<S, E>(stream: S) -> impl std::future::Future<Output = Result<Self::Result, E>> + Send
where
S: futures::Stream<Item = Result<Self::Item, E>> + Send,
Expand Down
139 changes: 127 additions & 12 deletions src/models/gateway.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ pub mod message {
),*$(,)*
}
) => {paste::paste!{
#[doc = "OpCodes for [" $name "]"]
#[doc = "OpCodes for [`" $name "`]"]
#[derive(Debug, Clone, Copy, Serialize_repr, Deserialize_repr)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema_repr))]
#[repr(u8)]
Expand All @@ -330,7 +330,7 @@ pub mod message {
$(
$(#[$variant_meta])*
#[doc = ""]
#[doc = "Payload struct for [" $name "::" $opcode "]"]
#[doc = "Payload struct for [`" $name "::" $opcode "`]"]
#[derive(Debug, Serialize, Deserialize)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[cfg_attr(feature = "rkyv", derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize))]
Expand Down Expand Up @@ -453,7 +453,7 @@ pub mod message {
)*
}

#[doc = "Handler callbacks for [" $name "]"]
#[doc = "Handler callbacks for [`" $name "`]"]
#[cfg(feature = "framework")]
#[async_trait::async_trait]
pub trait [<$name Handlers>]<C, U = ()>: Send + Sync where C: Send + 'static {
Expand All @@ -476,7 +476,7 @@ pub mod message {
$(
$(#[$variant_meta])*
#[doc = ""]
#[doc = "Handler callback for [" $name "::" $opcode "]"]
#[doc = "Handler callback for [`" $name "::" $opcode "`]"]
#[inline(always)]
fn [<$opcode:snake>]<'life0, 'async_trait>(&'life0 self, ctx: C, $($field: $ty,)*)
-> std::pin::Pin<Box<dyn Future<Output = U> + Send + 'async_trait>>
Expand All @@ -491,16 +491,130 @@ pub mod message {
$(#[$meta])*
#[derive(Debug)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[cfg_attr(feature = "rkyv", derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize))]
// #[cfg_attr(feature = "rkyv", archive(check_bytes))] // TODO: Doesn't compile?
#[repr(u8)]
pub enum $name {
$(
$(#[$variant_meta])*
#[doc = ""]
#[doc = "See [" [<new_ $opcode:snake>] "](" $name "::" [<new_ $opcode:snake>] ") for an easy way to create this message."]
#[doc = "See [`" [<new_ $opcode:snake>] "`](" $name "::" [<new_ $opcode:snake>] ") for an easy way to create this message."]
#[cfg_attr(feature = "schema", schemars(description = "" $name "::" $opcode "" ))]
$opcode([<$name:snake _payloads>]::[<$opcode Payload>])
,)*
$opcode([<$name:snake _payloads>]::[<$opcode Payload>]),
)*
}

#[cfg(feature = "rkyv")]
pub use [<$name:snake _proc_impl>]::{[<Archived $name>], [<$name Resolver>]};

#[cfg(feature = "rkyv")]
mod [<$name:snake _proc_impl>] {
use super::*;

use core::marker::PhantomData;
use rkyv::{Archive, Archived, Serialize, Fallible, Deserialize, Resolver};

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8, align(1))]
enum ArchivedTag {
$($opcode = $code,)* // NOTE: mirror tag if discriminator grows to u16
}

#[doc = "An archived [`" $name "`]" ]
#[repr(u8, align(1))]
pub enum [<Archived $name>] {
$($opcode(Archived<[<$name:snake _payloads>]::[<$opcode Payload>]>) = ArchivedTag::$opcode as u8,)*
}

#[doc = "Resolver for an archived [`" $name "`]" ]
#[repr(u8)]
pub enum [<$name Resolver>] {
$($opcode(Resolver<[<$name:snake _payloads>]::[<$opcode Payload>]>) = $code,)*
}

$(
#[repr(C)]
struct [<Archived $opcode Variant>] {
tag: ArchivedTag,
op: Archived<[<$name:snake _payloads>]::[<$opcode Payload>]>,
mkr: PhantomData<$name>,
}
)*

impl Archive for $name {
type Archived = [<Archived $name>];
type Resolver = [<$name Resolver>];

unsafe fn resolve(&self, pos: usize, resolver: Self::Resolver, out: *mut Self::Archived) {
match resolver {$(
[<$name Resolver>]::$opcode(resolver_0) => match self {
$name::$opcode(self_0) => {
let out = out.cast::<[<Archived $opcode Variant>]>();
core::ptr::addr_of_mut!((*out).tag).write(ArchivedTag::$opcode);
let (fp, fo) = rkyv::out_field!(out.op);
rkyv::Archive::resolve(self_0, pos + fp, resolver_0, fo);
},
_ => core::hint::unreachable_unchecked(),
},
)*}
}
}

impl<S: Fallible + ?Sized> Serialize<S> for $name
where $([<$name:snake _payloads>]::[<$opcode Payload>]: Serialize<S>,)*
{
fn serialize(&self, serializer: &mut S) -> Result<[<$name Resolver>], S::Error> {
Ok(match self {
$($name::$opcode(op) => [<$name Resolver>]::$opcode(Serialize::serialize(op, serializer)?),)*
})
}
}

impl<D: Fallible + ?Sized> Deserialize<$name, D> for [<Archived $name>]
where $(Archived<[<$name:snake _payloads>]::[<$opcode Payload>]>:
Deserialize<[<$name:snake _payloads>]::[<$opcode Payload>], D>,)*
{
fn deserialize(&self, deserializer: &mut D) -> Result<$name, D::Error> {
Ok(match self {$(
[<Archived $name>]::$opcode(op) => $name::$opcode(Deserialize::deserialize(op, deserializer)?),
)*})
}
}

use rkyv::bytecheck::{CheckBytes, EnumCheckError, ErrorBox, TupleStructCheckError};

impl<C: ?Sized> CheckBytes<C> for [<Archived $name>]
where $(Archived<[<$name:snake _payloads>]::[<$opcode Payload>]>: CheckBytes<C>,)*
{
type Error = EnumCheckError<u8>;

unsafe fn check_bytes<'a>(value: *const Self, context: &mut C) -> Result<&'a Self, Self::Error> {
let tag = *value.cast::<u8>();

struct Discriminant;

#[allow(non_upper_case_globals)]
impl Discriminant {
$(pub const $opcode: u8 = ArchivedTag::$opcode as u8;)*
}

match tag {
$(
Discriminant::$opcode => {
let value = value.cast::<[<Archived $opcode Variant>]>();

if let Err(e) = CheckBytes::<C>::check_bytes(core::ptr::addr_of!((*value).op), context) {
return Err(EnumCheckError::InvalidTuple {
variant_name: stringify!($opcode),
inner: TupleStructCheckError { field_index: 0, inner: ErrorBox::new(e) }
});
}
}
)*
_ => return Err(EnumCheckError::InvalidTag(tag)),
}

Ok(&*value)
}
}
}

impl $name {
Expand All @@ -521,7 +635,7 @@ pub mod message {

impl $name {
$(
#[doc = "Create new [" $opcode "](" $name "::" $opcode ") message from payload fields."]
#[doc = "Create new [`" $opcode "`](" $name "::" $opcode ") message from payload fields."]
#[doc = ""]
$(#[$variant_meta])*
#[inline]
Expand Down Expand Up @@ -566,11 +680,12 @@ pub mod message {
use std::fmt;

#[derive(Clone, Copy, Deserialize)]
#[serde(rename_all = "lowercase")]
enum Field {
#[serde(rename = "o")]
#[serde(alias = "o")]
Opcode,

#[serde(rename = "p")]
#[serde(alias = "p")]
Payload,
}

Expand Down

0 comments on commit 6b1f46a

Please sign in to comment.