diff --git a/Cargo.lock b/Cargo.lock index b1ef5e7..d5597d3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -19,9 +19,9 @@ checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" [[package]] name = "aho-corasick" -version = "1.1.0" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f2135563fb5c609d2b2b87c1e8ce7bc41b0b45430fa9661f457981503dd5bf0" +checksum = "ea5d730647d4fadd988536d06fecce94b7b4f2a7efdae548f1cf4b63205518ab" dependencies = [ "memchr", ] @@ -530,12 +530,11 @@ checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" [[package]] name = "headers" -version = "0.3.8" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3e372db8e5c0d213e0cd0b9be18be2aca3d44cf2fe30a9d46a65581cd454584" +checksum = "06683b93020a07e3dbcf5f8c0f6d40080d725bea7936fc01ad345c01b97dc270" dependencies = [ - "base64 0.13.1", - "bitflags 1.3.2", + "base64 0.21.4", "bytes", "headers-core", "http", @@ -555,9 +554,9 @@ dependencies = [ [[package]] name = "hermit-abi" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "443144c8cdadd93ebf52ddb4056d257f5b52c04d3c804e657d19eb73fc33668b" +checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" [[package]] name = "hex" @@ -776,9 +775,9 @@ checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" [[package]] name = "matchit" -version = "0.7.2" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed1202b2a6f884ae56f04cff409ab315c5ce26b5e58d7412e484f01fd52f52ef" +checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" [[package]] name = "memchr" @@ -1064,9 +1063,9 @@ checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" [[package]] name = "rustix" -version = "0.38.13" +version = "0.38.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7db8590df6dfcd144d22afd1b83b36c21a18d7cbc1dc4bb5295a8712e9eb662" +checksum = "747c788e9ce8e92b12cd485c49ddf90723550b654b32508f979b71a7b1ecda4f" dependencies = [ "bitflags 2.4.0", "errno", @@ -1083,7 +1082,7 @@ checksum = "cd8d6c9f025a446bc4d18ad9632e69aec8f287aa84499ee335599fabd20c3fd8" dependencies = [ "log", "ring", - "rustls-webpki 0.101.5", + "rustls-webpki 0.101.6", "sct", ] @@ -1099,9 +1098,9 @@ dependencies = [ [[package]] name = "rustls-webpki" -version = "0.101.5" +version = "0.101.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45a27e3b59326c16e23d30aeb7a36a24cc0d29e71d68ff611cdfb4a01d013bed" +checksum = "3c7d5dece342910d9ba34d259310cae3e0154b873b35408b787b59bce53d34fe" dependencies = [ "ring", "untrusted", @@ -1237,9 +1236,9 @@ dependencies = [ [[package]] name = "sha1" -version = "0.10.5" +version = "0.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" dependencies = [ "cfg-if", "cpufeatures", @@ -1277,9 +1276,9 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.11.0" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9" +checksum = "942b4a808e05215192e39f4ab80813e599068285906cc91aa64f923db842bd5a" [[package]] name = "socket2" @@ -1571,6 +1570,7 @@ version = "0.1.0" dependencies = [ "anyhow", "axum", + "base64 0.13.1", "chrono", "diesel", "diesel_migrations", @@ -1695,9 +1695,9 @@ checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" [[package]] name = "winapi-util" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178" +checksum = "f29e6f9198ba0d26b4c9f07dbe6f9ed633e1f3d5b8b414090084349e46a52596" dependencies = [ "winapi", ] diff --git a/Cargo.toml b/Cargo.toml index c7da229..4d921a7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,9 +6,9 @@ edition = "2021" [dependencies] anyhow = "1.0" axum = { version = "0.6.16", features = ["headers"] } +base64 = "0.13.1" chrono = { version = "0.4.26", features = ["serde"] } diesel = { version = "2.1", features = ["postgres", "r2d2", "chrono", "numeric"] } -diesel_migrations = "2.1.0" dotenv = "0.15.0" futures = "0.3.28" hex = "0.4.3" @@ -23,3 +23,6 @@ tokio = { version = "1.12.0", features = ["full"] } tower-http = { version = "0.4.0", features = ["cors"] } ureq = { version = "2.5.0", features = ["json"] } + +[dev-dependencies] +diesel_migrations = "2.1.0" \ No newline at end of file diff --git a/README.md b/README.md index a8f864d..3730de6 100644 --- a/README.md +++ b/README.md @@ -8,3 +8,9 @@ You need a postgres database and an authentication key. These can be set in the and `AUTH_KEY` respectively. This can be set in a `.env` file in the root of the project. To run the server, run `cargo run --release` in the root of the project. + +## Stress testing + +``` +AUTH_TOKEN=ey... drill --benchmark drill.yml -o 30 +``` diff --git a/drill.yml b/drill.yml new file mode 100644 index 0000000..8aa278a --- /dev/null +++ b/drill.yml @@ -0,0 +1,188 @@ +base: 'https://vss-staging.fly.dev' +concurrency: 125 +iterations: 2000 + +plan: + - name: Health Check + request: + method: GET + url: /health-check + + # Generate a random UUID and assign it to the test_key variable + - name: Generate unique test key + exec: + command: "echo \"drill_test_$(uuidgen)\"" + assign: test_key + + # Get the Object to fetch the initial version + - name: Get Object Initial (1) - '{{ test_key }}' + request: + method: POST + url: /getObject + body: '{"key": "{{ test_key }}"}' + headers: + Content-Type: 'application/json' + Authorization: Bearer {{ AUTH_TOKEN }} + assign: initial_get_object_response + + - name: Extract and increment version (1) - '{{ test_key }}' + exec: + command: "input='{{ initial_get_object_response.body }}'; [ -z \"$input\" ] && echo 0 || echo \"$input\" | jq '.version? // empty | if type == \"number\" then . + 1 else 0 end'" + assign: version + + - name: Put Objects (1) - '{{ test_key }}' + request: + method: PUT + url: /putObjects + body: '{"transaction_items": [{"key": "{{ test_key }}", "value": [0, 1, 2], "version": {{ version }} }]}' + headers: + Content-Type: 'application/json' + Authorization: Bearer {{ AUTH_TOKEN }} + + - name: Get Object after Put (1) - '{{ test_key }}' + request: + method: POST + url: /getObject + body: '{"key": "{{ test_key }}"}' + headers: + Content-Type: 'application/json' + Authorization: Bearer {{ AUTH_TOKEN }} + assign: get_object_response + + # Basic assertion to make sure we got a good response with the right key + - name: Extract key from response - '{{ test_key }}' + exec: + command: "echo '{{ get_object_response.body }}' | jq '.key'" + assign: retrieved_key + + - name: Compare test_key and retrieved_key - '{{ test_key }}' + exec: + command: "if [ \"{{ test_key }}\" = \"{{ retrieved_key }}\" ]; then echo 'true'; else echo 'false'; fi" + assign: key_comparison_result + + - name: Assert keys match - '{{ test_key }}' + assert: + key: key_comparison_result + value: "true" + + # Compare the version from "Get Object response" with the assigned version + - name: Compare versions with external command (1) - '{{ test_key }}' + exec: + command: "echo '{{ get_object_response.body }}' | jq --arg version '{{ version }}' '.version == ($version | tonumber)'" + assign: version_match_result + + # Assert that the result from the comparison is true + - name: Assert versions match (1) - '{{ test_key }}' + assert: + key: version_match_result + value: "true" + + # + ## Do this a 2nd time with a bigger version + # + - name: Extract and increment version (2) - '{{ test_key }}' + exec: + command: "echo $(({{ version }} + 1))" + assign: version + + - name: Put Objects (2) - '{{ test_key }}' + request: + method: PUT + url: /putObjects + body: '{"transaction_items": [{"key": "{{ test_key }}", "value": [0, 1, 2], "version": {{ version }} }]}' + headers: + Content-Type: 'application/json' + Authorization: Bearer {{ AUTH_TOKEN }} + + - name: Get Object after Put (2) - '{{ test_key }}' + request: + method: POST + url: /getObject + body: '{"key": "{{ test_key }}"}' + headers: + Content-Type: 'application/json' + Authorization: Bearer {{ AUTH_TOKEN }} + assign: get_object_response + + # Compare the version from "Get Object response" with the assigned version + - name: Compare versions with external command (2) - '{{ test_key }}' + exec: + command: "echo '{{ get_object_response.body }}' | jq --arg version '{{ version }}' '.version == ($version | tonumber)'" + assign: version_match_result + + # Basic assertion to make sure we got a good response with the right key + - name: Extract key from response - '{{ test_key }}' + exec: + command: "echo '{{ get_object_response.body }}' | jq '.key'" + assign: retrieved_key + + - name: Compare test_key and retrieved_key - '{{ test_key }}' + exec: + command: "if [ \"{{ test_key }}\" = \"{{ retrieved_key }}\" ]; then echo 'true'; else echo 'false'; fi" + assign: key_comparison_result + + - name: Assert keys match - '{{ test_key }}' + assert: + key: key_comparison_result + value: "true" + + # Assert that the result from the comparison is true + - name: Assert versions match (2) - '{{ test_key }}' + assert: + key: version_match_result + value: "true" + + # + ## Do this a third time with a bigger version + # + - name: Extract and increment version (3) - '{{ test_key }}' + exec: + command: "echo $(({{ version }} + 1))" + assign: version + + - name: Put Objects (3) - '{{ test_key }}' + request: + method: PUT + url: /putObjects + body: '{"transaction_items": [{"key": "{{ test_key }}", "value": [0, 1, 2], "version": {{ version }} }]}' + headers: + Content-Type: 'application/json' + Authorization: Bearer {{ AUTH_TOKEN }} + + - name: Get Object after Put (3) - '{{ test_key }}' + request: + method: POST + url: /getObject + body: '{"key": "{{ test_key }}"}' + headers: + Content-Type: 'application/json' + Authorization: Bearer {{ AUTH_TOKEN }} + assign: get_object_response + + # Basic assertion to make sure we got a good response with the right key + - name: Extract key from response - '{{ test_key }}' + exec: + command: "echo '{{ get_object_response.body }}' | jq '.key'" + assign: retrieved_key + + - name: Compare test_key and retrieved_key - '{{ test_key }}' + exec: + command: "if [ \"{{ test_key }}\" = \"{{ retrieved_key }}\" ]; then echo 'true'; else echo 'false'; fi" + assign: key_comparison_result + + - name: Assert keys match - '{{ test_key }}' + assert: + key: key_comparison_result + value: "true" + + # Compare the version from "Get Object response" with the assigned version + - name: Compare versions with external command (3) - '{{ test_key }}' + exec: + command: "echo '{{ get_object_response.body }}' | jq --arg version '{{ version }}' '.version == ($version | tonumber)'" + assign: version_match_result + + # Assert that the result from the comparison is true + - name: Assert versions match (3) - '{{ test_key }}' + assert: + key: version_match_result + value: "true" diff --git a/migrations/2023-09-18-225828_baseline/up.sql b/migrations/2023-09-18-225828_baseline/up.sql deleted file mode 100644 index a901b6d..0000000 --- a/migrations/2023-09-18-225828_baseline/up.sql +++ /dev/null @@ -1,43 +0,0 @@ -CREATE TABLE vss_db -( - store_id TEXT NOT NULL CHECK (store_id != ''), - key TEXT NOT NULL, - value TEXT, - version BIGINT NOT NULL, - created_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, - updated_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, - PRIMARY KEY (store_id, key) -); - --- triggers to set dates automatically, generated by ChatGPT - --- Function to set created_date and updated_date during INSERT -CREATE OR REPLACE FUNCTION set_created_date() -RETURNS TRIGGER AS $$ -BEGIN - NEW.created_date := CURRENT_TIMESTAMP; - NEW.updated_date := CURRENT_TIMESTAMP; -RETURN NEW; -END; -$$ LANGUAGE plpgsql; - --- Function to set updated_date during UPDATE -CREATE OR REPLACE FUNCTION set_updated_date() -RETURNS TRIGGER AS $$ -BEGIN - NEW.updated_date := CURRENT_TIMESTAMP; -RETURN NEW; -END; -$$ LANGUAGE plpgsql; - --- Trigger for INSERT operation on vss_db -CREATE TRIGGER tr_set_dates_after_insert - BEFORE INSERT ON vss_db - FOR EACH ROW - EXECUTE FUNCTION set_created_date(); - --- Trigger for UPDATE operation on vss_db -CREATE TRIGGER tr_set_dates_after_update - BEFORE UPDATE ON vss_db - FOR EACH ROW - EXECUTE FUNCTION set_updated_date(); \ No newline at end of file diff --git a/migrations/2023-09-20-043550_change-default-timestamp/down.sql b/migrations/2023-09-20-043550_change-default-timestamp/down.sql deleted file mode 100644 index 96dcc8a..0000000 --- a/migrations/2023-09-20-043550_change-default-timestamp/down.sql +++ /dev/null @@ -1,3 +0,0 @@ -ALTER TABLE vss_db - ALTER COLUMN created_date SET DEFAULT CURRENT_TIMESTAMP, - ALTER COLUMN updated_date SET DEFAULT CURRENT_TIMESTAMP; diff --git a/migrations/2023-09-20-043550_change-default-timestamp/up.sql b/migrations/2023-09-20-043550_change-default-timestamp/up.sql deleted file mode 100644 index 1df7ce1..0000000 --- a/migrations/2023-09-20-043550_change-default-timestamp/up.sql +++ /dev/null @@ -1,3 +0,0 @@ -ALTER TABLE vss_db - ALTER COLUMN created_date SET DEFAULT '2023-07-13'::TIMESTAMP, - ALTER COLUMN updated_date SET DEFAULT '2023-07-13'::TIMESTAMP; diff --git a/migrations/2023-09-18-225828_baseline/down.sql b/migrations/2023-09-23-030518_baseline/down.sql similarity index 90% rename from migrations/2023-09-18-225828_baseline/down.sql rename to migrations/2023-09-23-030518_baseline/down.sql index 46feaba..4b78a0e 100644 --- a/migrations/2023-09-18-225828_baseline/down.sql +++ b/migrations/2023-09-23-030518_baseline/down.sql @@ -6,4 +6,4 @@ DROP TRIGGER IF EXISTS tr_set_dates_after_update ON vss_db; DROP FUNCTION IF EXISTS set_created_date(); DROP FUNCTION IF EXISTS set_updated_date(); -DROP TABLE IF EXISTS vss_db; +DROP TABLE IF EXISTS vss_db; \ No newline at end of file diff --git a/migrations/2023-09-23-030518_baseline/up.sql b/migrations/2023-09-23-030518_baseline/up.sql new file mode 100644 index 0000000..63722d3 --- /dev/null +++ b/migrations/2023-09-23-030518_baseline/up.sql @@ -0,0 +1,79 @@ +CREATE TABLE vss_db +( + store_id TEXT NOT NULL CHECK (store_id != ''), + key TEXT NOT NULL, + value bytea, + version BIGINT NOT NULL, + created_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, + updated_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, + PRIMARY KEY (store_id, key) +); + +-- triggers to set dates automatically, generated by ChatGPT + +-- Function to set created_date and updated_date during INSERT +CREATE OR REPLACE FUNCTION set_created_date() + RETURNS TRIGGER AS +$$ +BEGIN + NEW.created_date := CURRENT_TIMESTAMP; + NEW.updated_date := CURRENT_TIMESTAMP; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +-- Function to set updated_date during UPDATE +CREATE OR REPLACE FUNCTION set_updated_date() + RETURNS TRIGGER AS +$$ +BEGIN + NEW.updated_date := CURRENT_TIMESTAMP; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +-- Trigger for INSERT operation on vss_db +CREATE TRIGGER tr_set_dates_after_insert + BEFORE INSERT + ON vss_db + FOR EACH ROW +EXECUTE FUNCTION set_created_date(); + +-- Trigger for UPDATE operation on vss_db +CREATE TRIGGER tr_set_dates_after_update + BEFORE UPDATE + ON vss_db + FOR EACH ROW +EXECUTE FUNCTION set_updated_date(); + +CREATE OR REPLACE FUNCTION upsert_vss_db( + p_store_id TEXT, + p_key TEXT, + p_value bytea, + p_version BIGINT +) RETURNS VOID AS +$$ +BEGIN + + WITH new_values (store_id, key, value, version) AS (VALUES (p_store_id, p_key, p_value, p_version)) + INSERT + INTO vss_db + (store_id, key, value, version) + SELECT new_values.store_id, + new_values.key, + new_values.value, + new_values.version + FROM new_values + LEFT JOIN vss_db AS existing + ON new_values.store_id = existing.store_id + AND new_values.key = existing.key + WHERE CASE + WHEN new_values.version >= 4294967295 THEN new_values.version >= COALESCE(existing.version, -1) + ELSE new_values.version > COALESCE(existing.version, -1) + END + ON CONFLICT (store_id, key) + DO UPDATE SET value = excluded.value, + version = excluded.version; + +END; +$$ LANGUAGE plpgsql; diff --git a/src/kv.rs b/src/kv.rs new file mode 100644 index 0000000..3eba388 --- /dev/null +++ b/src/kv.rs @@ -0,0 +1,87 @@ +use core::fmt; +use serde::de::Visitor; +use serde::*; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct KeyValue { + pub key: String, + pub value: ByteData, + pub version: i64, +} + +impl KeyValue { + pub fn new(key: String, value: Vec, version: i64) -> KeyValue { + KeyValue { + key, + value: ByteData(value), + version, + } + } +} + +#[derive(Debug, Clone)] +pub struct ByteData(pub Vec); + +impl Serialize for ByteData { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + self.0.serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for ByteData { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct ByteDataVisitor; + + impl<'de> Visitor<'de> for ByteDataVisitor { + type Value = ByteData; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a Vec or a base64 encoded string") + } + + fn visit_str(self, v: &str) -> Result + where + E: de::Error, + { + let decoded = + base64::decode(v).map_err(|err| de::Error::custom(err.to_string()))?; + Ok(ByteData(decoded)) + } + + fn visit_seq(self, seq: S) -> Result + where + S: de::SeqAccess<'de>, + { + let vec = Vec::::deserialize(de::value::SeqAccessDeserializer::new(seq))?; + Ok(ByteData(vec)) + } + } + + deserializer.deserialize_any(ByteDataVisitor) + } +} + +// need this for backwards compat for now + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct KeyValueOld { + pub key: String, + pub value: String, + pub version: i64, +} + +impl From for KeyValueOld { + fn from(kv: KeyValue) -> Self { + KeyValueOld { + key: kv.key, + value: base64::encode(kv.value.0), + version: kv.version, + } + } +} diff --git a/src/main.rs b/src/main.rs index 874454d..68722b7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,15 +1,15 @@ -use crate::models::MIGRATIONS; use crate::routes::*; -use axum::http::{Method, StatusCode, Uri}; +use axum::headers::Origin; +use axum::http::{request::Parts, HeaderValue, Method, StatusCode, Uri}; use axum::routing::{get, post, put}; -use axum::{http, Extension, Router}; +use axum::{http, Extension, Router, TypedHeader}; use diesel::r2d2::{ConnectionManager, Pool}; use diesel::PgConnection; -use diesel_migrations::MigrationHarness; use secp256k1::{All, PublicKey, Secp256k1}; -use tower_http::cors::{Any, CorsLayer}; +use tower_http::cors::{AllowOrigin, CorsLayer}; mod auth; +mod kv; mod migration; mod models; mod routes; @@ -29,8 +29,8 @@ const ALLOWED_LOCALHOST: &str = "http://127.0.0.1:"; #[derive(Clone)] pub struct State { db_pool: Pool>, - auth_key: PublicKey, - secp: Secp256k1, + pub auth_key: PublicKey, + pub secp: Secp256k1, } #[tokio::main] @@ -54,17 +54,11 @@ async fn main() -> anyhow::Result<()> { // DB management let manager = ConnectionManager::::new(&pg_url); let db_pool = Pool::builder() - .max_size(16) + .max_size(10) // should be a multiple of 100, our database connection limit .test_on_check_out(true) .build(manager) .expect("Could not build connection pool"); - // run migrations - let mut connection = db_pool.get()?; - connection - .run_pending_migrations(MIGRATIONS) - .expect("migrations could not run"); - let secp = Secp256k1::new(); let state = State { @@ -80,20 +74,34 @@ async fn main() -> anyhow::Result<()> { let server_router = Router::new() .route("/health-check", get(health_check)) .route("/getObject", post(get_object)) + .route("/v2/getObject", post(get_object_v2)) .route("/putObjects", put(put_objects)) + .route("/v2/putObjects", put(put_objects)) .route("/listKeyVersions", post(list_key_versions)) + .route("/v2/listKeyVersions", post(list_key_versions)) .route("/migration", get(migration::migration)) .fallback(fallback) - .layer(Extension(state.clone())) .layer( CorsLayer::new() - .allow_origin(Any) - .allow_headers(vec![ - http::header::CONTENT_TYPE, - http::header::AUTHORIZATION, - ]) - .allow_methods([Method::GET, Method::POST, Method::PUT, Method::OPTIONS]), - ); + .allow_origin(AllowOrigin::predicate( + |origin: &HeaderValue, _request_parts: &Parts| { + let Ok(origin) = origin.to_str() else { + return false; + }; + + valid_origin(origin) + }, + )) + .allow_headers([http::header::CONTENT_TYPE, http::header::AUTHORIZATION]) + .allow_methods([ + Method::GET, + Method::POST, + Method::PUT, + Method::DELETE, + Method::OPTIONS, + ]), + ) + .layer(Extension(state)); let server = axum::Server::bind(&addr).serve(server_router.into_make_service()); @@ -113,6 +121,10 @@ async fn main() -> anyhow::Result<()> { Ok(()) } -async fn fallback(uri: Uri) -> (StatusCode, String) { +async fn fallback(origin: Option>, uri: Uri) -> (StatusCode, String) { + if let Err((status, msg)) = validate_cors(origin) { + return (status, msg); + }; + (StatusCode::NOT_FOUND, format!("No route for {uri}")) } diff --git a/src/migration.rs b/src/migration.rs index 444761b..52acad6 100644 --- a/src/migration.rs +++ b/src/migration.rs @@ -18,7 +18,7 @@ pub struct Item { pub key: String, #[serde(default)] pub value: String, - pub version: u32, + pub version: i64, #[serde(default)] #[serde(deserialize_with = "deserialize_datetime_opt")] @@ -66,8 +66,6 @@ pub async fn migration_impl(admin_key: String, state: &State) -> anyhow::Result< let mut finished = false; - let mut conn = state.db_pool.get()?; - info!("Starting migration"); while !finished { info!("Fetching {limit} items from offset {offset}"); @@ -78,24 +76,22 @@ pub async fn migration_impl(admin_key: String, state: &State) -> anyhow::Result< .post(&url) .set("x-api-key", &admin_key) .send_string(&payload.to_string())?; - let values: Vec = resp.into_json()?; + let items: Vec = resp.into_json()?; + + let mut conn = state.db_pool.get().unwrap(); // Insert values into DB conn.transaction::<_, anyhow::Error, _>(|conn| { - for value in values.iter() { - VssItem::put_item( - conn, - &value.store_id, - &value.key, - &value.value, - value.version as u64, - )?; + for item in items.iter() { + if let Ok(value) = base64::decode(&item.value) { + VssItem::put_item(conn, &item.store_id, &item.key, &value, item.version)?; + } } Ok(()) })?; - if values.len() < limit { + if items.len() < limit { finished = true; } else { offset += limit; diff --git a/src/models/mod.rs b/src/models/mod.rs index 5d7e831..d7f82a5 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -1,14 +1,11 @@ -use crate::routes::KeyValue; +use crate::kv::KeyValue; use diesel::prelude::*; use diesel::sql_query; -use diesel::sql_types::{BigInt, Text}; -use diesel_migrations::{embed_migrations, EmbeddedMigrations}; +use diesel::sql_types::{BigInt, Bytea, Text}; use schema::vss_db; use serde::{Deserialize, Serialize}; -pub mod schema; - -pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!(); +mod schema; #[derive( QueryableByName, @@ -26,29 +23,17 @@ pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!(); pub struct VssItem { pub store_id: String, pub key: String, - pub value: Option, + pub value: Option>, pub version: i64, created_date: chrono::NaiveDateTime, updated_date: chrono::NaiveDateTime, } -#[derive(Insertable, AsChangeset)] -#[diesel(table_name = vss_db)] -pub struct NewVssItem { - pub store_id: String, - pub key: String, - pub value: Option, - pub version: i64, -} - impl VssItem { pub fn into_kv(self) -> Option { - self.value.map(|value| KeyValue { - key: self.key, - value, - version: self.version as u64, - }) + self.value + .map(|value| KeyValue::new(self.key, value, self.version)) } pub fn get_item( @@ -67,20 +52,13 @@ impl VssItem { conn: &mut PgConnection, store_id: &str, key: &str, - value: &str, - version: u64, + value: &[u8], + version: i64, ) -> anyhow::Result<()> { - // safely convert u64 to i64 - let version = if version >= i64::MAX as u64 { - i64::MAX - } else { - version as i64 - }; - - sql_query(include_str!("put_item.sql")) + sql_query("SELECT upsert_vss_db($1, $2, $3, $4)") .bind::(store_id) .bind::(key) - .bind::(value) + .bind::(value) .bind::(version) .execute(conn)?; @@ -112,19 +90,19 @@ mod test { use super::*; use crate::State; use diesel::r2d2::{ConnectionManager, Pool}; - use diesel::{Connection, PgConnection, RunQueryDsl}; - use diesel_migrations::MigrationHarness; + use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness}; use secp256k1::Secp256k1; use std::str::FromStr; const PUBKEY: &str = "04547d92b618856f4eda84a64ec32f1694c9608a3f9dc73e91f08b5daa087260164fbc9e2a563cf4c5ef9f4c614fd9dfca7582f8de429a4799a4b202fbe80a7db5"; + const MIGRATIONS: EmbeddedMigrations = embed_migrations!(); fn init_state() -> State { dotenv::dotenv().ok(); let url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set"); let manager = ConnectionManager::::new(url); let db_pool = Pool::builder() - .max_size(16) + .max_size(10) .test_on_check_out(true) .build(manager) .expect("Could not build connection pool"); @@ -163,22 +141,22 @@ mod test { let store_id = "test_store_id"; let key = "test"; - let value = "test_value"; + let value = [1, 2, 3]; let version = 0; let mut conn = state.db_pool.get().unwrap(); - VssItem::put_item(&mut conn, store_id, key, value, version).unwrap(); + VssItem::put_item(&mut conn, store_id, key, &value, version).unwrap(); let versions = VssItem::list_key_versions(&mut conn, store_id, None).unwrap(); assert_eq!(versions.len(), 1); assert_eq!(versions[0].0, key); - assert_eq!(versions[0].1, version as i64); + assert_eq!(versions[0].1, version); - let new_value = "new_value"; + let new_value = [4, 5, 6]; let new_version = version + 1; - VssItem::put_item(&mut conn, store_id, key, new_value, new_version).unwrap(); + VssItem::put_item(&mut conn, store_id, key, &new_value, new_version).unwrap(); let item = VssItem::get_item(&mut conn, store_id, key) .unwrap() @@ -187,7 +165,7 @@ mod test { assert_eq!(item.store_id, store_id); assert_eq!(item.key, key); assert_eq!(item.value.unwrap(), new_value); - assert_eq!(item.version, new_version as i64); + assert_eq!(item.version, new_version); clear_database(&state); } @@ -199,11 +177,11 @@ mod test { let store_id = "max_test_store_id"; let key = "max_test"; - let value = "test_value"; - let version = u32::MAX as u64; + let value = [1, 2, 3]; + let version = u32::MAX as i64; let mut conn = state.db_pool.get().unwrap(); - VssItem::put_item(&mut conn, store_id, key, value, version).unwrap(); + VssItem::put_item(&mut conn, store_id, key, &value, version).unwrap(); let item = VssItem::get_item(&mut conn, store_id, key) .unwrap() @@ -213,9 +191,9 @@ mod test { assert_eq!(item.key, key); assert_eq!(item.value.unwrap(), value); - let new_value = "new_value"; + let new_value = [4, 5, 6]; - VssItem::put_item(&mut conn, store_id, key, new_value, version).unwrap(); + VssItem::put_item(&mut conn, store_id, key, &new_value, version).unwrap(); let item = VssItem::get_item(&mut conn, store_id, key) .unwrap() @@ -236,13 +214,13 @@ mod test { let store_id = "list_kv_test_store_id"; let key = "kv_test"; let key1 = "other_kv_test"; - let value = "test_value"; + let value = [1, 2, 3]; let version = 0; let mut conn = state.db_pool.get().unwrap(); - VssItem::put_item(&mut conn, store_id, key, value, version).unwrap(); + VssItem::put_item(&mut conn, store_id, key, &value, version).unwrap(); - VssItem::put_item(&mut conn, store_id, key1, value, version).unwrap(); + VssItem::put_item(&mut conn, store_id, key1, &value, version).unwrap(); let versions = VssItem::list_key_versions(&mut conn, store_id, None).unwrap(); assert_eq!(versions.len(), 2); @@ -250,12 +228,12 @@ mod test { let versions = VssItem::list_key_versions(&mut conn, store_id, Some("kv")).unwrap(); assert_eq!(versions.len(), 1); assert_eq!(versions[0].0, key); - assert_eq!(versions[0].1, version as i64); + assert_eq!(versions[0].1, version); let versions = VssItem::list_key_versions(&mut conn, store_id, Some("other")).unwrap(); assert_eq!(versions.len(), 1); assert_eq!(versions[0].0, key1); - assert_eq!(versions[0].1, version as i64); + assert_eq!(versions[0].1, version); clear_database(&state); } diff --git a/src/models/put_item.sql b/src/models/put_item.sql deleted file mode 100644 index feb7b63..0000000 --- a/src/models/put_item.sql +++ /dev/null @@ -1,17 +0,0 @@ -WITH new_values (store_id, key, value, version) AS (VALUES ($1, $2, $3, $4)) -INSERT -INTO vss_db - (store_id, key, value, version) -SELECT new_values.store_id, - new_values.key, - new_values.value, - new_values.version -FROM new_values - LEFT JOIN vss_db AS existing ON new_values.store_id = existing.store_id AND new_values.key = existing.key -WHERE CASE - WHEN new_values.version >= 4294967295 THEN new_values.version >= COALESCE(existing.version, -1) - ELSE new_values.version > COALESCE(existing.version, -1) - END -ON CONFLICT (store_id, key) - DO UPDATE SET value = excluded.value, - version = excluded.version; diff --git a/src/models/schema.rs b/src/models/schema.rs index 0857dac..de1aed0 100644 --- a/src/models/schema.rs +++ b/src/models/schema.rs @@ -4,7 +4,7 @@ diesel::table! { vss_db (store_id, key) { store_id -> Text, key -> Text, - value -> Nullable, + value -> Nullable, version -> Int8, created_date -> Timestamp, updated_date -> Timestamp, diff --git a/src/routes.rs b/src/routes.rs index d9bfc87..7003a60 100644 --- a/src/routes.rs +++ b/src/routes.rs @@ -1,4 +1,5 @@ use crate::auth::verify_token; +use crate::kv::{KeyValue, KeyValueOld}; use crate::models::VssItem; use crate::{State, ALLOWED_LOCALHOST, ALLOWED_ORIGINS, ALLOWED_SUBDOMAIN}; use axum::headers::authorization::Bearer; @@ -10,14 +11,7 @@ use log::{debug, error, trace}; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct KeyValue { - pub key: String, - pub value: String, - pub version: u64, -} - -macro_rules! check_store_id { +macro_rules! ensure_store_id { ($payload:ident, $store_id:expr) => { match $payload.store_id { None => $payload.store_id = Some($store_id), @@ -43,11 +37,11 @@ pub async fn get_object_impl( req: GetObjectRequest, state: &State, ) -> anyhow::Result> { - let mut conn = state.db_pool.get()?; - trace!("get_object_impl: {req:?}"); let store_id = req.store_id.expect("must have"); + let mut conn = state.db_pool.get()?; + let item = VssItem::get_item(&mut conn, &store_id, &req.key)?; Ok(item.and_then(|i| i.into_kv())) @@ -58,16 +52,37 @@ pub async fn get_object( TypedHeader(token): TypedHeader>, Extension(state): Extension, Json(mut payload): Json, -) -> Result>, (StatusCode, String)> { +) -> Result>, (StatusCode, String)> { debug!("get_object: {payload:?}"); validate_cors(origin)?; + + let store_id = verify_token(token.token(), &state)?; + + ensure_store_id!(payload, store_id); + + match get_object_impl(payload, &state).await { + Ok(Some(res)) => Ok(Json(Some(res.into()))), + Ok(None) => Ok(Json(None)), + Err(e) => Err(handle_anyhow_error("get_object", e)), + } +} + +pub async fn get_object_v2( + origin: Option>, + TypedHeader(token): TypedHeader>, + Extension(state): Extension, + Json(mut payload): Json, +) -> Result>, (StatusCode, String)> { + debug!("get_object v2: {payload:?}"); + validate_cors(origin)?; + let store_id = verify_token(token.token(), &state)?; - check_store_id!(payload, store_id); + ensure_store_id!(payload, store_id); match get_object_impl(payload, &state).await { Ok(res) => Ok(Json(res)), - Err(e) => Err(handle_anyhow_error(e)), + Err(e) => Err(handle_anyhow_error("get_object_v2", e)), } } @@ -88,13 +103,16 @@ pub async fn put_objects_impl(req: PutObjectsRequest, state: &State) -> anyhow:: let store_id = req.store_id.expect("must have"); let mut conn = state.db_pool.get()?; - conn.transaction(|conn| { + + conn.transaction::<_, anyhow::Error, _>(|conn| { for kv in req.transaction_items { - VssItem::put_item(conn, &store_id, &kv.key, &kv.value, kv.version)?; + VssItem::put_item(conn, &store_id, &kv.key, &kv.value.0, kv.version)?; } Ok(()) - }) + })?; + + Ok(()) } pub async fn put_objects( @@ -104,13 +122,14 @@ pub async fn put_objects( Json(mut payload): Json, ) -> Result, (StatusCode, String)> { validate_cors(origin)?; + let store_id = verify_token(token.token(), &state)?; - check_store_id!(payload, store_id); + ensure_store_id!(payload, store_id); match put_objects_impl(payload, &state).await { Ok(res) => Ok(Json(res)), - Err(e) => Err(handle_anyhow_error(e)), + Err(e) => Err(handle_anyhow_error("put_objects", e)), } } @@ -126,11 +145,11 @@ pub async fn list_key_versions_impl( req: ListKeyVersionsRequest, state: &State, ) -> anyhow::Result> { - let mut conn = state.db_pool.get()?; - // todo pagination let store_id = req.store_id.expect("must have"); + let mut conn = state.db_pool.get()?; + let versions = VssItem::list_key_versions(&mut conn, &store_id, req.key_prefix.as_deref())?; let json = versions @@ -153,13 +172,14 @@ pub async fn list_key_versions( Json(mut payload): Json, ) -> Result>, (StatusCode, String)> { validate_cors(origin)?; + let store_id = verify_token(token.token(), &state)?; - check_store_id!(payload, store_id); + ensure_store_id!(payload, store_id); match list_key_versions_impl(payload, &state).await { Ok(res) => Ok(Json(res)), - Err(e) => Err(handle_anyhow_error(e)), + Err(e) => Err(handle_anyhow_error("list_key_versions", e)), } } @@ -167,17 +187,22 @@ pub async fn health_check() -> Result, (StatusCode, String)> { Ok(Json(())) } -fn validate_cors(origin: Option>) -> Result<(), (StatusCode, String)> { +pub fn valid_origin(origin: &str) -> bool { + ALLOWED_ORIGINS.contains(&origin) + || origin.ends_with(ALLOWED_SUBDOMAIN) + || origin.starts_with(ALLOWED_LOCALHOST) +} + +pub fn validate_cors(origin: Option>) -> Result<(), (StatusCode, String)> { if let Some(TypedHeader(origin)) = origin { if origin.is_null() { return Ok(()); } let origin_str = origin.to_string(); - if !ALLOWED_ORIGINS.contains(&origin_str.as_str()) - && !origin_str.ends_with(ALLOWED_SUBDOMAIN) - && !origin_str.starts_with(ALLOWED_LOCALHOST) - { + if valid_origin(&origin_str) { + return Ok(()); + } else { // The origin is not in the allowed list block the request return Err((StatusCode::NOT_FOUND, String::new())); } @@ -186,7 +211,7 @@ fn validate_cors(origin: Option>) -> Result<(), (StatusCode, Ok(()) } -pub(crate) fn handle_anyhow_error(err: anyhow::Error) -> (StatusCode, String) { - error!("Error: {err:?}"); +pub(crate) fn handle_anyhow_error(function: &str, err: anyhow::Error) -> (StatusCode, String) { + error!("Error in {function}: {err:?}"); (StatusCode::BAD_REQUEST, format!("{err}")) }