From 622ad7ba726de53bcaa871de986effe82c827dbd Mon Sep 17 00:00:00 2001 From: 0xZensh Date: Thu, 18 Jul 2024 16:08:17 +0800 Subject: [PATCH] feat: add memory cacher to `idempotent-proxy-server` --- .env | 3 +- Cargo.lock | 5 + src/idempotent-proxy-server/Cargo.toml | 9 + .../src/cache/memory.rs | 211 ++++++++++++++++++ .../src/cache/mod.rs} | 64 +++++- .../src/{ => cache}/redis.rs | 41 ++-- src/idempotent-proxy-server/src/handler.rs | 8 +- src/idempotent-proxy-server/src/main.rs | 29 ++- src/idempotent-proxy-types/src/lib.rs | 1 - 9 files changed, 327 insertions(+), 44 deletions(-) create mode 100644 src/idempotent-proxy-server/src/cache/memory.rs rename src/{idempotent-proxy-types/src/cache.rs => idempotent-proxy-server/src/cache/mod.rs} (83%) rename src/idempotent-proxy-server/src/{ => cache}/redis.rs (79%) diff --git a/.env b/.env index e8ba787..7019693 100644 --- a/.env +++ b/.env @@ -1,5 +1,6 @@ SERVER_ADDR=127.0.0.1:8080 -REDIS_URL=127.0.0.1:6379 +# if not set, use in-memory cache +# REDIS_URL=127.0.0.1:6379 POLL_INTERVAL=100 # in milliseconds REQUEST_TIMEOUT=10000 # in milliseconds LOG_LEVEL=info # debug, info, warn, error diff --git a/Cargo.lock b/Cargo.lock index 8559082..a33352c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1103,13 +1103,18 @@ dependencies = [ "ciborium", "dotenvy", "ed25519-dalek", + "futures", + "hex-conservative", "http", "idempotent-proxy-types", "k256", "log", + "rand_core", "reqwest", "rustis", "serde", + "serde_bytes", + "serde_json", "structured-logger", "tokio", ] diff --git a/src/idempotent-proxy-server/Cargo.toml b/src/idempotent-proxy-server/Cargo.toml index abfd9f1..1b7b204 100644 --- a/src/idempotent-proxy-server/Cargo.toml +++ b/src/idempotent-proxy-server/Cargo.toml @@ -17,6 +17,7 @@ name = "idempotent-proxy-server" axum = { workspace = true } axum-server = { workspace = true } tokio = { workspace = true } +futures = { workspace = true } reqwest = { workspace = true } dotenvy = { workspace = true } log = { workspace = true } @@ -26,9 +27,17 @@ rustis = { workspace = true } bb8 = { workspace = true } async-trait = { workspace = true } serde = { workspace = true } +serde_bytes = { workspace = true } +serde_json = { workspace = true } ciborium = { workspace = true } anyhow = { workspace = true } k256 = { workspace = true } ed25519-dalek = { workspace = true } base64 = { workspace = true } idempotent-proxy-types = { path = "../idempotent-proxy-types", version = "1" } + +[dev-dependencies] +rand_core = "0.6" +hex = { package = "hex-conservative", version = "0.2", default-features = false, features = [ + "alloc", +] } diff --git a/src/idempotent-proxy-server/src/cache/memory.rs b/src/idempotent-proxy-server/src/cache/memory.rs new file mode 100644 index 0000000..e0e7af9 --- /dev/null +++ b/src/idempotent-proxy-server/src/cache/memory.rs @@ -0,0 +1,211 @@ +use async_trait::async_trait; +use std::{ + collections::{ + hash_map::{Entry, HashMap}, + BTreeSet, + }, + sync::Arc, +}; +use structured_logger::unix_ms; +use tokio::{ + sync::RwLock, + time::{sleep, Duration}, +}; + +use super::Cacher; + +#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq)] +struct PriorityKey(u64, String); + +#[derive(Clone, Default)] +pub struct MemoryCacher { + priority_queue: Arc>>, + kv: Arc)>>>, +} + +impl MemoryCacher { + fn clean_expired_values(&self) -> tokio::task::JoinHandle<()> { + let kv = self.kv.clone(); + let priority_queue = self.priority_queue.clone(); + tokio::spawn(async move { + let now = unix_ms(); + let mut pq = priority_queue.write().await; + let mut kv = kv.write().await; + while let Some(PriorityKey(expire_at, key)) = pq.pop_first() { + if expire_at > now { + pq.insert(PriorityKey(expire_at, key)); + break; + } + + kv.remove(&key); + } + () + }) + } +} + +#[async_trait] +impl Cacher for MemoryCacher { + async fn obtain(&self, key: &str, ttl: u64) -> Result { + let mut kv = self.kv.write().await; + let now = unix_ms(); + match kv.entry(key.to_string()) { + Entry::Occupied(mut entry) => { + let (expire_at, value) = entry.get_mut(); + if *expire_at > now { + return Ok(false); + } + + let mut pq = self.priority_queue.write().await; + pq.remove(&PriorityKey(*expire_at, key.to_string())); + + *expire_at = now + ttl; + *value = vec![0]; + pq.insert(PriorityKey(*expire_at, key.to_string())); + Ok(true) + } + Entry::Vacant(entry) => { + let expire_at = now + ttl; + entry.insert((expire_at, vec![0])); + self.priority_queue + .write() + .await + .insert(PriorityKey(expire_at, key.to_string())); + Ok(true) + } + } + } + + async fn polling_get( + &self, + key: &str, + poll_interval: u64, + mut counter: u64, + ) -> Result, String> { + while counter > 0 { + let kv = self.kv.read().await; + let res = kv.get(key); + match res { + None => return Err("not obtained".to_string()), + Some((expire_at, value)) => { + if *expire_at <= unix_ms() { + drop(kv); + self.kv.write().await.remove(key); + self.clean_expired_values(); + return Err("value expired".to_string()); + } + + if value.len() > 1 { + return Ok(value.clone()); + } + } + } + + counter -= 1; + sleep(Duration::from_millis(poll_interval)).await; + } + + Err(("polling get cache timeout").to_string()) + } + + async fn set(&self, key: &str, val: Vec, ttl: u64) -> Result { + let mut kv = self.kv.write().await; + match kv.get_mut(key) { + Some((expire_at, value)) => { + let now = unix_ms(); + if *expire_at <= now { + kv.remove(key); + self.clean_expired_values(); + return Err("value expired".to_string()); + } + + let mut pq = self.priority_queue.write().await; + pq.remove(&PriorityKey(*expire_at, key.to_string())); + + *expire_at = now + ttl; + *value = val; + pq.insert(PriorityKey(*expire_at, key.to_string())); + Ok(true) + } + None => Err("not obtained".to_string()), + } + } + + async fn del(&self, key: &str) -> Result<(), String> { + let mut kv = self.kv.write().await; + if let Some(val) = kv.remove(key) { + let mut pq = self.priority_queue.write().await; + pq.remove(&PriorityKey(val.0, key.to_string())); + } + self.clean_expired_values(); + Ok(()) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[tokio::test] + async fn memory_cacher() { + let mc = MemoryCacher::default(); + + assert!(mc.obtain("key1", 100).await.unwrap()); + assert!(!mc.obtain("key1", 100).await.unwrap()); + assert!(mc.polling_get("key1", 10, 2).await.is_err()); + assert!(mc.set("key", vec![1, 2, 3, 4], 100).await.is_err()); + assert!(mc.set("key1", vec![1, 2, 3, 4], 100).await.is_ok()); + assert!(!mc.obtain("key1", 100).await.unwrap()); + assert_eq!( + mc.polling_get("key1", 10, 2).await.unwrap(), + vec![1, 2, 3, 4] + ); + assert_eq!( + mc.polling_get("key1", 10, 2).await.unwrap(), + vec![1, 2, 3, 4] + ); + + assert!(mc.del("key").await.is_ok()); + assert!(mc.del("key1").await.is_ok()); + assert!(mc.polling_get("key1", 10, 2).await.is_err()); + assert!(mc.set("key1", vec![1, 2, 3, 4], 100).await.is_err()); + assert!(mc.obtain("key1", 100).await.unwrap()); + assert!(mc.set("key1", vec![1, 2, 3, 4], 100).await.is_ok()); + assert_eq!( + mc.polling_get("key1", 10, 2).await.unwrap(), + vec![1, 2, 3, 4] + ); + + sleep(Duration::from_millis(200)).await; + assert!(mc.polling_get("key1", 10, 2).await.is_err()); + assert!(mc.set("key1", vec![1, 2, 3, 4], 100).await.is_err()); + assert!(mc.del("key1").await.is_ok()); + + assert!(mc.obtain("key1", 100).await.unwrap()); + sleep(Duration::from_millis(200)).await; + let _ = mc.clean_expired_values().await; + println!("{:?}", mc.priority_queue.read().await); + + let res = futures::try_join!( + mc.obtain("key1", 100), + mc.obtain("key1", 100), + mc.obtain("key1", 100), + ) + .unwrap(); + match res { + (true, false, false) | (false, true, false) | (false, false, true) => {} + _ => panic!("unexpected result"), + } + + assert_eq!(mc.kv.read().await.len(), 1); + assert_eq!(mc.priority_queue.read().await.len(), 1); + + sleep(Duration::from_millis(200)).await; + assert_eq!(mc.kv.read().await.len(), 1); + assert_eq!(mc.priority_queue.read().await.len(), 1); + let _ = mc.clean_expired_values().await; + + assert!(mc.kv.read().await.is_empty()); + assert!(mc.priority_queue.read().await.is_empty()); + } +} diff --git a/src/idempotent-proxy-types/src/cache.rs b/src/idempotent-proxy-server/src/cache/mod.rs similarity index 83% rename from src/idempotent-proxy-types/src/cache.rs rename to src/idempotent-proxy-server/src/cache/mod.rs index d3e55cd..01d18dd 100644 --- a/src/idempotent-proxy-types/src/cache.rs +++ b/src/idempotent-proxy-server/src/cache/mod.rs @@ -8,10 +8,36 @@ use http::{ header::{HeaderMap, HeaderName, HeaderValue}, StatusCode, }; +use idempotent_proxy_types::err_string; use serde::{Deserialize, Serialize}; use serde_bytes::ByteBuf; -use crate::err_string; +mod memory; +mod redis; + +pub use memory::*; +pub use redis::*; + +pub struct HybridCacher { + pub poll_interval: u64, + pub cache_ttl: u64, + cache: CacherEntry, +} + +impl HybridCacher { + pub fn new(poll_interval: u64, cache_ttl: u64, cache: CacherEntry) -> Self { + Self { + poll_interval, + cache_ttl, + cache, + } + } +} + +pub enum CacherEntry { + Memory(MemoryCacher), + Redis(RedisClient), +} #[async_trait] pub trait Cacher { @@ -26,6 +52,42 @@ pub trait Cacher { async fn del(&self, key: &str) -> Result<(), String>; } +#[async_trait] +impl Cacher for HybridCacher { + async fn obtain(&self, key: &str, ttl: u64) -> Result { + match &self.cache { + CacherEntry::Memory(cacher) => cacher.obtain(key, ttl).await, + CacherEntry::Redis(cacher) => cacher.obtain(key, ttl).await, + } + } + + async fn polling_get( + &self, + key: &str, + poll_interval: u64, + counter: u64, + ) -> Result, String> { + match &self.cache { + CacherEntry::Memory(cacher) => cacher.polling_get(key, poll_interval, counter).await, + CacherEntry::Redis(cacher) => cacher.polling_get(key, poll_interval, counter).await, + } + } + + async fn set(&self, key: &str, val: Vec, ttl: u64) -> Result { + match &self.cache { + CacherEntry::Memory(cacher) => cacher.set(key, val, ttl).await, + CacherEntry::Redis(cacher) => cacher.set(key, val, ttl).await, + } + } + + async fn del(&self, key: &str) -> Result<(), String> { + match &self.cache { + CacherEntry::Memory(cacher) => cacher.del(key).await, + CacherEntry::Redis(cacher) => cacher.del(key).await, + } + } +} + #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] pub struct ResponseData { pub status: u16, diff --git a/src/idempotent-proxy-server/src/redis.rs b/src/idempotent-proxy-server/src/cache/redis.rs similarity index 79% rename from src/idempotent-proxy-server/src/redis.rs rename to src/idempotent-proxy-server/src/cache/redis.rs index 10393a0..ae54358 100644 --- a/src/idempotent-proxy-server/src/redis.rs +++ b/src/idempotent-proxy-server/src/cache/redis.rs @@ -1,39 +1,32 @@ use async_trait::async_trait; +use idempotent_proxy_types::err_string; use rustis::bb8::{CustomizeConnection, ErrorSink, Pool}; use rustis::client::PooledClientManager; use rustis::commands::{GenericCommands, SetCondition, SetExpiration, StringCommands}; use rustis::resp::BulkString; use tokio::time::{sleep, Duration}; -use idempotent_proxy_types::{cache::Cacher, err_string}; +use super::Cacher; pub struct RedisClient { pool: Pool, - pub poll_interval: u64, - pub cache_ttl: u64, } -pub async fn new( - url: &str, - poll_interval: u64, - cache_ttl: u64, -) -> Result { - let manager = PooledClientManager::new(url).unwrap(); - let pool = Pool::builder() - .max_size(10) - .min_idle(Some(1)) - .max_lifetime(None) - .idle_timeout(Some(Duration::from_secs(600))) - .connection_timeout(Duration::from_secs(3)) - .error_sink(Box::new(RedisMonitor {})) - .connection_customizer(Box::new(RedisMonitor {})) - .build(manager) - .await?; - Ok(RedisClient { - pool, - poll_interval, - cache_ttl, - }) +impl RedisClient { + pub async fn new(url: &str) -> Result { + let manager = PooledClientManager::new(url).unwrap(); + let pool = Pool::builder() + .max_size(10) + .min_idle(Some(1)) + .max_lifetime(None) + .idle_timeout(Some(Duration::from_secs(600))) + .connection_timeout(Duration::from_secs(3)) + .error_sink(Box::new(RedisMonitor {})) + .connection_customizer(Box::new(RedisMonitor {})) + .build(manager) + .await?; + Ok(RedisClient { pool }) + } } #[derive(Debug, Clone, Copy)] diff --git a/src/idempotent-proxy-server/src/handler.rs b/src/idempotent-proxy-server/src/handler.rs index fd06264..214de7f 100644 --- a/src/idempotent-proxy-server/src/handler.rs +++ b/src/idempotent-proxy-server/src/handler.rs @@ -5,6 +5,7 @@ use axum::{ }; use base64::{engine::general_purpose, Engine}; use http::{header::AsHeaderName, HeaderMap, HeaderValue, StatusCode}; +use idempotent_proxy_types::*; use k256::ecdsa; use reqwest::Client; use std::{ @@ -12,15 +13,12 @@ use std::{ sync::Arc, }; -use crate::redis::RedisClient; -use idempotent_proxy_types::auth; -use idempotent_proxy_types::cache::{Cacher, ResponseData}; -use idempotent_proxy_types::*; +use crate::cache::{Cacher, HybridCacher, ResponseData}; #[derive(Clone)] pub struct AppState { pub http_client: Arc, - pub cacher: Arc, + pub cacher: Arc, pub agents: Arc>, pub url_vars: Arc>, pub header_vars: Arc>, diff --git a/src/idempotent-proxy-server/src/main.rs b/src/idempotent-proxy-server/src/main.rs index 79d4877..7759abf 100644 --- a/src/idempotent-proxy-server/src/main.rs +++ b/src/idempotent-proxy-server/src/main.rs @@ -14,8 +14,8 @@ use std::{ use structured_logger::{async_json::new_writer, get_env_level, Builder}; use tokio::signal; +mod cache; mod handler; -mod redis; const APP_NAME: &str = env!("CARGO_PKG_NAME"); const APP_VERSION: &str = env!("CARGO_PKG_VERSION"); @@ -32,6 +32,10 @@ async fn main() { .map(|n| n.parse().unwrap()) .unwrap_or(10000u64) .max(1000u64); + let poll_interval: u64 = std::env::var("POLL_INTERVAL") + .map(|n| n.parse().unwrap()) + .unwrap_or(100u64) + .max(10u64); let http_client = ClientBuilder::new() .http2_keep_alive_interval(Some(Duration::from_secs(25))) @@ -43,16 +47,13 @@ async fn main() { .build() .unwrap(); - let redis_client = redis::new( - &std::env::var("REDIS_URL").expect("REDIS_URL not found"), - std::env::var("POLL_INTERVAL") - .map(|n| n.parse().unwrap()) - .unwrap_or(100u64) - .max(10u64), - req_timeout, - ) - .await - .unwrap(); + let cacher_entry = match std::env::var("REDIS_URL") { + Ok(url) => { + let redis_client = cache::RedisClient::new(&url).await.unwrap(); + cache::CacherEntry::Redis(redis_client) + } + Err(_) => cache::CacherEntry::Memory(cache::MemoryCacher::default()), + }; let agents: BTreeSet = std::env::var("ALLOW_AGENTS") .unwrap_or_default() @@ -106,7 +107,11 @@ async fn main() { .route("/*any", routing::any(handler::proxy)) .with_state(handler::AppState { http_client: Arc::new(http_client), - cacher: Arc::new(redis_client), + cacher: Arc::new(cache::HybridCacher::new( + poll_interval, + req_timeout, + cacher_entry, + )), agents: Arc::new(agents), url_vars: Arc::new(url_vars), header_vars: Arc::new(header_vars), diff --git a/src/idempotent-proxy-types/src/lib.rs b/src/idempotent-proxy-types/src/lib.rs index d7eb542..5228ecb 100644 --- a/src/idempotent-proxy-types/src/lib.rs +++ b/src/idempotent-proxy-types/src/lib.rs @@ -1,7 +1,6 @@ use http::header::HeaderName; pub mod auth; -pub mod cache; pub static HEADER_PROXY_AUTHORIZATION: HeaderName = HeaderName::from_static("proxy-authorization"); pub static HEADER_X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for");