Skip to content

Commit

Permalink
Implement transactions to secure data (#3)
Browse files Browse the repository at this point in the history
* New sub command 'db-console' for development

* Implement transactions to secure data

* Fix data racing

* Fix data racing
  • Loading branch information
linw1995 authored Aug 6, 2024
1 parent 5634f99 commit 1224850
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 53 deletions.
4 changes: 4 additions & 0 deletions scripts/cli.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ help() {
echo ""
echo " setup: Setup the project development environment"
echo " teardown: Teardown the project development environment"
echo " db-console: Open the database console"
echo " serve: Run the api server"
echo " test: Run tests"
echo " coverage: Run tests with coverage"
Expand Down Expand Up @@ -71,6 +72,9 @@ main() {
docker compose down
echo ">>> Done"
;;
"db-console")
docker compose exec db psql -Upostgres
;;
"serve")
echo ">>> Running the api server"
cargo run --bin serve
Expand Down
98 changes: 52 additions & 46 deletions src/api/bookmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use super::errors::Error;
use super::fairings::db::Db;
use crate::db::{self, bookmark, folder, tag};

use diesel_async::scoped_futures::ScopedFutureExt;
use diesel_async::AsyncConnection;
use rocket::serde::json::Json;
use rocket::serde::{Deserialize, Serialize};
use rocket_db_pools::Connection;
Expand Down Expand Up @@ -43,16 +45,22 @@ pub async fn create_bookmark(
},
payload.tags,
);
let m = bookmark::create_bookmark(&mut db, &new).await;
let (m, f, ts) = db
.transaction::<_, Error, _>(|db| {
async move {
let m = bookmark::create_bookmark(db, &new).await;

tag::update_bookmark_tags(&mut db, &m, &tags).await;
tag::update_bookmark_tags(db, &m, &tags).await;

if let Some(folder_id) = payload.folder_id {
// TODO: transaction required?
folder::move_bookmarks(&mut db, folder_id, &vec![m.id]).await?;
}
if let Some(folder_id) = payload.folder_id {
folder::move_bookmarks(db, folder_id, &vec![m.id]).await?;
}

let (m, f, ts) = db::get_bookmark_details(&mut db, vec![m]).await.remove(0);
Ok(db::get_bookmark_details(db, vec![m]).await.remove(0))
}
.scope_boxed()
})
.await?;

Ok(Json(Bookmark {
id: m.id,
Expand Down Expand Up @@ -139,37 +147,34 @@ pub async fn update_bookmark(
return Err(Error::BadRequest("No changes".to_string()));
}

let m = if let Some(payload) = modify_bookmark {
bookmark::update_bookmark(&mut db, id, payload).await
} else {
bookmark::Bookmark::get(&mut db, id).await
}
.ok_or_else(|| Error::NotFound("Bookmark not found".to_string()))?;

if let Some(payload) = modify_tags {
tag::update_bookmark_tags(&mut db, &m, &payload).await;
}

let rv = db::get_bookmark_details(&mut db, vec![m.clone()]).await;
if let Some((m, folder, tags)) = rv.first() {
return Ok(Json(Bookmark {
id: m.id,
title: m.title.clone(),
url: m.url.clone(),
folder: folder.clone().map(|f| f.path),
tags: tags.iter().map(|t| t.name.clone()).collect(),
created_at: m.created_at,
updated_at: m.updated_at,
deleted_at: m.deleted_at,
}));
}
let (m, folder, tags) = db
.transaction::<_, Error, _>(|db| {
async move {
let m = if let Some(payload) = modify_bookmark {
bookmark::update_bookmark(db, id, payload).await
} else {
bookmark::Bookmark::get(db, id).await
}
.ok_or_else(|| Error::NotFound("Bookmark not found".to_string()))?;

if let Some(payload) = modify_tags {
tag::update_bookmark_tags(db, &m, &payload).await;
}

Ok(db::get_bookmark_details(db, vec![m.clone()])
.await
.remove(0))
}
.scope_boxed()
})
.await?;

Ok(Json(Bookmark {
id: m.id,
title: m.title,
url: m.url,
folder: None,
tags: vec![],
title: m.title.clone(),
url: m.url.clone(),
folder: folder.clone().map(|f| f.path),
tags: tags.iter().map(|t| t.name.clone()).collect(),
created_at: m.created_at,
updated_at: m.updated_at,
deleted_at: m.deleted_at,
Expand All @@ -189,7 +194,8 @@ pub fn routes() -> Vec<rocket::Route> {
mod test {
use super::*;
use crate::db::bookmark::test::rand_bookmark;
use crate::utils;
use crate::utils::percent_encoding;
use crate::utils::rand::rand_str;

use itertools::Itertools;
use rocket::http::Status;
Expand All @@ -210,7 +216,7 @@ mod test {
url: "https://www.rust-lang.org".to_string(),
title: "Rust".to_string(),
folder_id: None,
tags: vec!["rust".to_string(), "programming".to_string()],
tags: vec![rand_str(4), rand_str(4)],
};
let response = client
.post(uri!(super::create_bookmark))
Expand All @@ -231,7 +237,7 @@ mod test {
url: "https://www.rust-lang.org".to_string(),
title: "Rust".to_string(),
folder_id: None,
tags: vec!["rust".to_string(), "programming".to_string()],
tags: vec![rand_str(4), rand_str(4)],
};
let response = client
.post(uri!(super::create_bookmark))
Expand Down Expand Up @@ -300,42 +306,42 @@ mod test {
);

assert_get_bookmarks!(
format!("/?q={}", utils::percent_encoding("#global weather")),
format!("/?q={}", percent_encoding("#global weather")),
results.len() == 1,
"Expected 1 bookmark, got {}",
results.len()
);

assert_get_bookmarks!(
format!("/?q={}", utils::percent_encoding("#west weather")),
format!("/?q={}", percent_encoding("#west weather")),
results.len() == 1,
"Expected 1 bookmark, got {}",
results.len()
);

assert_get_bookmarks!(
format!("/?q={}", utils::percent_encoding("#global #west weather")),
format!("/?q={}", percent_encoding("#global #west weather")),
results.len() == 0,
"Expected 0 bookmark, got {}",
results.len()
);

assert_get_bookmarks!(
format!("/?q={}", utils::percent_encoding("#weather")),
format!("/?q={}", percent_encoding("#weather")),
results.len() == 3,
"Expected 3 bookmarks, got {}",
results.len()
);
assert_get_bookmarks!(
format!("/?q={}&limit=1", utils::percent_encoding("#weather")),
format!("/?q={}&limit=1", percent_encoding("#weather")),
results.len() == 1,
"Expected 1 bookmark, got {}",
results.len()
);
assert_get_bookmarks!(
format!(
"/?q={}&limit=3&before={}",
utils::percent_encoding("#weather"),
percent_encoding("#weather"),
results[0].id
),
results.len() == 2,
Expand All @@ -351,7 +357,7 @@ mod test {
url: payload.url,
title: payload.title,
folder_id: None,
tags: vec!["rust".to_string(), "programming".to_string()],
tags: vec![rand_str(4), rand_str(4)],
};
info!(?payload, "creating");
let title = payload.title.clone();
Expand Down Expand Up @@ -401,7 +407,7 @@ mod test {
url: m.url,
title: m.title,
folder_id: None,
tags: vec!["rust".to_string(), "programming".to_string()],
tags: vec![rand_str(4), rand_str(4)],
};
let response = client
.post(uri!(super::create_bookmark))
Expand Down
8 changes: 8 additions & 0 deletions src/api/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ pub enum Error {
NotFound(String),
#[response(status = 400)]
BadRequest(String),
#[response(status = 500)]
InternalServer(String),
}

impl From<DatabaseError> for Error {
Expand All @@ -30,3 +32,9 @@ impl From<BearQLError> for Error {
}
}
}

impl From<diesel::result::Error> for Error {
fn from(e: diesel::result::Error) -> Self {
Error::InternalServer(e.to_string())
}
}
15 changes: 8 additions & 7 deletions src/db/tag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ pub struct ModifyTag {
pub name: Option<String>,
}

pub async fn get_tags(conn: &mut Connection, tags: &Vec<String>) -> Vec<Tag> {
pub async fn get_tags(conn: &mut Connection, tags: &[String]) -> Vec<Tag> {
tags::table
.filter(tags::name.eq_any(tags))
.load(conn)
Expand All @@ -49,23 +49,24 @@ pub async fn get_tags(conn: &mut Connection, tags: &Vec<String>) -> Vec<Tag> {
}

pub async fn get_or_create_tags(conn: &mut Connection, tags: &[String]) -> Vec<Tag> {
use diesel::{dsl::now, ExpressionMethods};
let exists_tags = get_tags(conn, tags).await;

let tags = tags
.iter()
.filter(|name| !exists_tags.iter().any(|tag| &&tag.name == name))
.map(|name| NewTag {
name: name.to_string(),
})
.collect::<Vec<_>>();

diesel::insert_into(tags::table)
let new_tags = diesel::insert_into(tags::table)
.values(&tags)
.on_conflict(tags::name)
.do_update()
.set(tags::updated_at.eq(now))
.returning(Tag::as_returning())
.get_results(conn)
.await
.expect("Error creating tags")
.expect("Error creating tags");

exists_tags.into_iter().chain(new_tags).collect()
}

pub async fn update_bookmark_tags(conn: &mut Connection, bookmark: &Bookmark, tags: &[String]) {
Expand Down

0 comments on commit 1224850

Please sign in to comment.