Skip to content

Commit

Permalink
experimenting with sql! (#202)
Browse files Browse the repository at this point in the history
  • Loading branch information
mjovanc authored Nov 18, 2024
2 parents 12920af + 55f2844 commit 68b0634
Show file tree
Hide file tree
Showing 5 changed files with 242 additions and 8 deletions.
5 changes: 3 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

36 changes: 32 additions & 4 deletions njord/src/sqlite/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,7 @@ use crate::util::{Join, JoinType};
/// # Returns
///
/// A `SelectQueryBuilder` instance.
pub fn select<'a, T: Table + Default>(
columns: Vec<Column<'a>>,
) -> SelectQueryBuilder<'a, T> {
pub fn select<'a, T: Table + Default>(columns: Vec<Column<'a>>) -> SelectQueryBuilder<'a, T> {
SelectQueryBuilder::new(columns)
}

Expand Down Expand Up @@ -348,7 +346,7 @@ impl<'a, T: Table + Default> SelectQueryBuilder<'a, T> {
}

/// Builds and executes the SELECT query.
///
///
/// # Arguments
///
/// * `conn` - A reference to the database connection.
Expand Down Expand Up @@ -406,3 +404,33 @@ where
self.build_query()
}
}

pub fn raw_execute<T: Table + Default>(sql: &str, conn: &Connection) -> Result<Vec<T>> {
let mut binding = conn.prepare(sql)?;
let iter = binding.query_map((), |row| {
let mut instance = T::default();
let columns = instance.get_column_fields();

for (index, column) in columns.iter().enumerate() {
let value = row.get::<usize, Value>(index)?;

let string_value = match value {
Value::Integer(val) => val.to_string(),
Value::Null => String::new(),
Value::Real(val) => val.to_string(),
Value::Text(val) => val.to_string(),
Value::Blob(val) => String::from_utf8_lossy(&val).to_string(),
};

instance.set_column_value(column, &string_value);
}

Ok(instance)
})?;

let result: Result<Vec<T>> = iter
.map(|row_result| row_result.and_then(|row| Ok(row)))
.collect::<Result<Vec<T>>>();

result.map_err(|err| err.into())
}
57 changes: 57 additions & 0 deletions njord/tests/sqlite/select_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use njord::condition::Condition;
use njord::keys::AutoIncrementPrimaryKey;
use njord::sqlite;
use njord::{column::Column, condition::Value};
use njord_derive::sql;
use std::collections::HashMap;
use std::path::Path;

Expand Down Expand Up @@ -525,3 +526,59 @@ fn select_in() {
Err(e) => panic!("Failed to SELECT: {:?}", e),
};
}

#[test]
fn sql_bang() {
let user_id = 1;

let query = sql! {
SELECT *
FROM user
WHERE id = {user_id}
};

assert_eq!(query.to_string(), "SELECT * FROM user WHERE id = '1'");

let complex_query = sql! {
SELECT a.company, COUNT(i.id) AS total_impressions, COUNT(DISTINCT i.ip_address) AS unique_impressions
FROM impressions i
INNER JOIN cached_content c ON c.content_hash = i.content_hash
INNER JOIN ads a ON a.id = c.ad_id
GROUP BY a.company;
};

assert_eq!(
complex_query.to_string(),
"SELECT a.company, COUNT (i.id) AS total_impressions, COUNT (DISTINCT i.ip_address) AS unique_impressions \
FROM impressions i \
INNER JOIN cached_content c ON c.content_hash = i.content_hash \
INNER JOIN ads a ON a.id = c.ad_id \
GROUP BY a.company;"
);
}

#[test]
fn raw_execute() {
let db_relative_path = "./db/select.db";
let db_path = Path::new(&db_relative_path);
let conn = sqlite::open(db_path);

let username = "mjovanc";

let query = sql! {
SELECT *
FROM users
WHERE username = {username}
};

match conn {
Ok(ref c) => {
let result = sqlite::select::raw_execute::<User>(&query, c);
match result {
Ok(r) => assert_eq!(r.len(), 2),
Err(e) => panic!("Failed to SELECT: {:?}", e),
};
}
Err(e) => panic!("Failed to SELECT: {:?}", e),
};
}
3 changes: 2 additions & 1 deletion njord_derive/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,6 @@ proc-macro = true
[dependencies]
proc-macro2 = "1.0.89"
quote = "1.0"
syn = "2.0.87"
syn = { version = "2.0.87", features = ["full"] }
rusqlite = { version = "0.32.1", features = ["bundled"] }
regex = "1.11.1"
149 changes: 148 additions & 1 deletion njord_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,11 @@ extern crate proc_macro;
use proc_macro::TokenStream;

use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{parse_macro_input, DeriveInput, FieldsNamed};

use proc_macro2::{Delimiter, TokenTree as TokenTree2};
use quote::quote;

use util::{extract_table_name, has_default_impl};

mod util;
Expand Down Expand Up @@ -218,3 +220,148 @@ pub fn table_derive(input: TokenStream) -> TokenStream {

output.into()
}

/// The procedural macro `sql!` takes a SQL-like syntax and transforms it into a string.
// #[proc_macro]
// pub fn sql(input: TokenStream) -> TokenStream {
// /*
// GOAL:
// let id = 1;

// let query = sql! {
// SELECT *
// FROM user
// WHERE id = {id}
// };
// */
// let input_string = input.to_string();

// // Remove the outer quotes
// let input_string = input_string.trim_matches(|c| c == '"' || c == '`' || c == '\'');

// let expanded = quote! {
// {
// #input_string
// }
// };

// expanded.into()
// }

#[proc_macro]
pub fn sql(input: TokenStream) -> TokenStream {
let input: proc_macro2::TokenStream = input.into();
let mut tokens = input.into_iter().peekable();
let mut sql_parts = Vec::new();
let mut expressions = Vec::new();
let mut param_types = Vec::new();
let mut current_sql = String::new();
let mut last_token_type = TokenType::Other;

#[derive(PartialEq, Clone)]
enum TokenType {
Dot,
OpenParen,
CloseParen,
Operator,
Other,
}

while let Some(token) = tokens.next() {
match token {
TokenTree2::Group(group) if group.delimiter() == Delimiter::Brace => {
if !current_sql.is_empty() {
sql_parts.push(current_sql);
current_sql = String::new();
}

// Parse the expression to determine its type
let expr = group.stream();
let expr_str = expr.to_string();

// Check if it's an identifier (likely a string variable)
let needs_quotes = !expr_str.contains("as")
&& !expr_str.contains("::")
&& !expr_str.starts_with("Some")
&& !expr_str.parse::<f64>().is_ok()
&& !expr_str.parse::<i64>().is_ok();

if needs_quotes {
sql_parts.push("'{}'".to_string());
} else {
sql_parts.push("{}".to_string());
}

expressions.push(expr);
param_types.push(needs_quotes);
last_token_type = TokenType::Other;
}
token => {
let token_str = token.to_string();
let current_token_type = match token_str.as_str() {
"." => TokenType::Dot,
"(" => TokenType::OpenParen,
")" => TokenType::CloseParen,
"=" | ">" | "<" | ">=" | "<=" | "!=" => TokenType::Operator,
_ => TokenType::Other,
};
match current_token_type {
TokenType::Dot => {
current_sql.push('.');
}
TokenType::OpenParen => {
current_sql.push('(');
}
TokenType::CloseParen => {
current_sql.push(')');
if let Some(next) = tokens.peek() {
let next_str = next.to_string();
if !matches!(next_str.as_str(), "," | "." | ")" | ";") {
current_sql.push(' ');
}
}
}
TokenType::Operator => {
if !current_sql.ends_with(' ') {
current_sql.push(' ');
}
current_sql.push_str(&token_str);
current_sql.push(' ');
}
TokenType::Other => {
let needs_space = !current_sql.is_empty()
&& !current_sql.ends_with(' ')
&& !matches!(last_token_type, TokenType::Dot | TokenType::OpenParen)
&& token_str != ","
&& token_str != ";";
if needs_space {
current_sql.push(' ');
}
current_sql.push_str(&token_str);
if token_str == "," {
current_sql.push(' ');
}
}
}
last_token_type = current_token_type;
}
}
}

if !current_sql.is_empty() {
sql_parts.push(current_sql);
}

let sql_format = sql_parts.join("");
let expanded = if expressions.is_empty() {
quote! {
#sql_format.to_string()
}
} else {
quote! {
format!(#sql_format #(,#expressions)*)
}
};

expanded.into()
}

0 comments on commit 68b0634

Please sign in to comment.