Skip to content

Commit

Permalink
feat: add memory cacher to idempotent-proxy-server
Browse files Browse the repository at this point in the history
  • Loading branch information
zensh committed Jul 18, 2024
1 parent 0aae62a commit 622ad7b
Show file tree
Hide file tree
Showing 9 changed files with 327 additions and 44 deletions.
3 changes: 2 additions & 1 deletion .env
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
SERVER_ADDR=127.0.0.1:8080
REDIS_URL=127.0.0.1:6379
# if not set, use in-memory cache
# REDIS_URL=127.0.0.1:6379
POLL_INTERVAL=100 # in milliseconds
REQUEST_TIMEOUT=10000 # in milliseconds
LOG_LEVEL=info # debug, info, warn, error
Expand Down
5 changes: 5 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 9 additions & 0 deletions src/idempotent-proxy-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ name = "idempotent-proxy-server"
axum = { workspace = true }
axum-server = { workspace = true }
tokio = { workspace = true }
futures = { workspace = true }
reqwest = { workspace = true }
dotenvy = { workspace = true }
log = { workspace = true }
Expand All @@ -26,9 +27,17 @@ rustis = { workspace = true }
bb8 = { workspace = true }
async-trait = { workspace = true }
serde = { workspace = true }
serde_bytes = { workspace = true }
serde_json = { workspace = true }
ciborium = { workspace = true }
anyhow = { workspace = true }
k256 = { workspace = true }
ed25519-dalek = { workspace = true }
base64 = { workspace = true }
idempotent-proxy-types = { path = "../idempotent-proxy-types", version = "1" }

[dev-dependencies]
rand_core = "0.6"
hex = { package = "hex-conservative", version = "0.2", default-features = false, features = [
"alloc",
] }
211 changes: 211 additions & 0 deletions src/idempotent-proxy-server/src/cache/memory.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
use async_trait::async_trait;
use std::{
collections::{
hash_map::{Entry, HashMap},
BTreeSet,
},
sync::Arc,
};
use structured_logger::unix_ms;
use tokio::{
sync::RwLock,
time::{sleep, Duration},
};

use super::Cacher;

#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq)]
struct PriorityKey(u64, String);

#[derive(Clone, Default)]
pub struct MemoryCacher {
priority_queue: Arc<RwLock<BTreeSet<PriorityKey>>>,
kv: Arc<RwLock<HashMap<String, (u64, Vec<u8>)>>>,
}

impl MemoryCacher {
fn clean_expired_values(&self) -> tokio::task::JoinHandle<()> {
let kv = self.kv.clone();
let priority_queue = self.priority_queue.clone();
tokio::spawn(async move {
let now = unix_ms();
let mut pq = priority_queue.write().await;
let mut kv = kv.write().await;
while let Some(PriorityKey(expire_at, key)) = pq.pop_first() {
if expire_at > now {
pq.insert(PriorityKey(expire_at, key));
break;
}

kv.remove(&key);
}
()
})
}
}

#[async_trait]
impl Cacher for MemoryCacher {
async fn obtain(&self, key: &str, ttl: u64) -> Result<bool, String> {
let mut kv = self.kv.write().await;
let now = unix_ms();
match kv.entry(key.to_string()) {
Entry::Occupied(mut entry) => {
let (expire_at, value) = entry.get_mut();
if *expire_at > now {
return Ok(false);
}

let mut pq = self.priority_queue.write().await;
pq.remove(&PriorityKey(*expire_at, key.to_string()));

*expire_at = now + ttl;
*value = vec![0];
pq.insert(PriorityKey(*expire_at, key.to_string()));
Ok(true)
}
Entry::Vacant(entry) => {
let expire_at = now + ttl;
entry.insert((expire_at, vec![0]));
self.priority_queue
.write()
.await
.insert(PriorityKey(expire_at, key.to_string()));
Ok(true)
}
}
}

async fn polling_get(
&self,
key: &str,
poll_interval: u64,
mut counter: u64,
) -> Result<Vec<u8>, String> {
while counter > 0 {
let kv = self.kv.read().await;
let res = kv.get(key);
match res {
None => return Err("not obtained".to_string()),
Some((expire_at, value)) => {
if *expire_at <= unix_ms() {
drop(kv);
self.kv.write().await.remove(key);
self.clean_expired_values();
return Err("value expired".to_string());
}

if value.len() > 1 {
return Ok(value.clone());
}
}
}

counter -= 1;
sleep(Duration::from_millis(poll_interval)).await;
}

Err(("polling get cache timeout").to_string())
}

async fn set(&self, key: &str, val: Vec<u8>, ttl: u64) -> Result<bool, String> {
let mut kv = self.kv.write().await;
match kv.get_mut(key) {
Some((expire_at, value)) => {
let now = unix_ms();
if *expire_at <= now {
kv.remove(key);
self.clean_expired_values();
return Err("value expired".to_string());
}

let mut pq = self.priority_queue.write().await;
pq.remove(&PriorityKey(*expire_at, key.to_string()));

*expire_at = now + ttl;
*value = val;
pq.insert(PriorityKey(*expire_at, key.to_string()));
Ok(true)
}
None => Err("not obtained".to_string()),
}
}

async fn del(&self, key: &str) -> Result<(), String> {
let mut kv = self.kv.write().await;
if let Some(val) = kv.remove(key) {
let mut pq = self.priority_queue.write().await;
pq.remove(&PriorityKey(val.0, key.to_string()));
}
self.clean_expired_values();
Ok(())
}
}

#[cfg(test)]
mod test {
use super::*;

#[tokio::test]
async fn memory_cacher() {
let mc = MemoryCacher::default();

assert!(mc.obtain("key1", 100).await.unwrap());
assert!(!mc.obtain("key1", 100).await.unwrap());
assert!(mc.polling_get("key1", 10, 2).await.is_err());
assert!(mc.set("key", vec![1, 2, 3, 4], 100).await.is_err());
assert!(mc.set("key1", vec![1, 2, 3, 4], 100).await.is_ok());
assert!(!mc.obtain("key1", 100).await.unwrap());
assert_eq!(
mc.polling_get("key1", 10, 2).await.unwrap(),
vec![1, 2, 3, 4]
);
assert_eq!(
mc.polling_get("key1", 10, 2).await.unwrap(),
vec![1, 2, 3, 4]
);

assert!(mc.del("key").await.is_ok());
assert!(mc.del("key1").await.is_ok());
assert!(mc.polling_get("key1", 10, 2).await.is_err());
assert!(mc.set("key1", vec![1, 2, 3, 4], 100).await.is_err());
assert!(mc.obtain("key1", 100).await.unwrap());
assert!(mc.set("key1", vec![1, 2, 3, 4], 100).await.is_ok());
assert_eq!(
mc.polling_get("key1", 10, 2).await.unwrap(),
vec![1, 2, 3, 4]
);

sleep(Duration::from_millis(200)).await;
assert!(mc.polling_get("key1", 10, 2).await.is_err());
assert!(mc.set("key1", vec![1, 2, 3, 4], 100).await.is_err());
assert!(mc.del("key1").await.is_ok());

assert!(mc.obtain("key1", 100).await.unwrap());
sleep(Duration::from_millis(200)).await;
let _ = mc.clean_expired_values().await;
println!("{:?}", mc.priority_queue.read().await);

let res = futures::try_join!(
mc.obtain("key1", 100),
mc.obtain("key1", 100),
mc.obtain("key1", 100),
)
.unwrap();
match res {
(true, false, false) | (false, true, false) | (false, false, true) => {}
_ => panic!("unexpected result"),
}

assert_eq!(mc.kv.read().await.len(), 1);
assert_eq!(mc.priority_queue.read().await.len(), 1);

sleep(Duration::from_millis(200)).await;
assert_eq!(mc.kv.read().await.len(), 1);
assert_eq!(mc.priority_queue.read().await.len(), 1);
let _ = mc.clean_expired_values().await;

assert!(mc.kv.read().await.is_empty());
assert!(mc.priority_queue.read().await.is_empty());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,36 @@ use http::{
header::{HeaderMap, HeaderName, HeaderValue},
StatusCode,
};
use idempotent_proxy_types::err_string;
use serde::{Deserialize, Serialize};
use serde_bytes::ByteBuf;

use crate::err_string;
mod memory;
mod redis;

pub use memory::*;
pub use redis::*;

pub struct HybridCacher {
pub poll_interval: u64,
pub cache_ttl: u64,
cache: CacherEntry,
}

impl HybridCacher {
pub fn new(poll_interval: u64, cache_ttl: u64, cache: CacherEntry) -> Self {
Self {
poll_interval,
cache_ttl,
cache,
}
}
}

pub enum CacherEntry {
Memory(MemoryCacher),
Redis(RedisClient),
}

#[async_trait]
pub trait Cacher {
Expand All @@ -26,6 +52,42 @@ pub trait Cacher {
async fn del(&self, key: &str) -> Result<(), String>;
}

#[async_trait]
impl Cacher for HybridCacher {
async fn obtain(&self, key: &str, ttl: u64) -> Result<bool, String> {
match &self.cache {
CacherEntry::Memory(cacher) => cacher.obtain(key, ttl).await,
CacherEntry::Redis(cacher) => cacher.obtain(key, ttl).await,
}
}

async fn polling_get(
&self,
key: &str,
poll_interval: u64,
counter: u64,
) -> Result<Vec<u8>, String> {
match &self.cache {
CacherEntry::Memory(cacher) => cacher.polling_get(key, poll_interval, counter).await,
CacherEntry::Redis(cacher) => cacher.polling_get(key, poll_interval, counter).await,
}
}

async fn set(&self, key: &str, val: Vec<u8>, ttl: u64) -> Result<bool, String> {
match &self.cache {
CacherEntry::Memory(cacher) => cacher.set(key, val, ttl).await,
CacherEntry::Redis(cacher) => cacher.set(key, val, ttl).await,
}
}

async fn del(&self, key: &str) -> Result<(), String> {
match &self.cache {
CacherEntry::Memory(cacher) => cacher.del(key).await,
CacherEntry::Redis(cacher) => cacher.del(key).await,
}
}
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct ResponseData {
pub status: u16,
Expand Down
Loading

0 comments on commit 622ad7b

Please sign in to comment.