diff --git a/src/avram/join.cr b/src/avram/join.cr index a3b566bc0..14583d779 100644 --- a/src/avram/join.cr +++ b/src/avram/join.cr @@ -76,4 +76,60 @@ module Avram::Join "FULL" end end + + class Raw + @clause : String + + def self.new(statement : String, *bind_vars) + new(statement, args: bind_vars.to_a) + end + + def initialize(statement : String, *, args bind_vars : Array) + ensure_enough_bind_variables_for!(statement, bind_vars) + @clause = build_clause(statement, bind_vars) + end + + def prepare(placeholder_supplier : Proc(String)) : String + @clause + end + + def to_sql : String + @clause + end + + def clone : self + self + end + + private def ensure_enough_bind_variables_for!(statement, bind_vars) + bindings = statement.chars.select(&.== '?') + if bindings.size != bind_vars.size + raise "wrong number of bind variables (#{bind_vars.size} for #{bindings.size}) in #{statement}" + end + end + + private def build_clause(statement, bind_vars) + bind_vars.each do |arg| + encoded_arg = prepare_for_execution(arg) + statement = statement.sub('?', encoded_arg) + end + statement + end + + private def prepare_for_execution(value) + if value.is_a?(Array) + "'#{PQ::Param.encode_array(value)}'" + else + escape_if_needed(value) + end + end + + private def escape_if_needed(value) + if value.is_a?(String) || value.is_a?(Slice(UInt8)) + PG::EscapeHelper.escape_literal(value) + else + value + end + end + end end diff --git a/src/avram/query_builder.cr b/src/avram/query_builder.cr index 1d67e8703..c00a6660a 100644 --- a/src/avram/query_builder.cr +++ b/src/avram/query_builder.cr @@ -8,6 +8,7 @@ class Avram::QueryBuilder @offset : Int32? @wheres = [] of Avram::Where::Condition @joins = [] of Avram::Join::SqlClause + @raw_joins = [] of Avram::Join::Raw @orders = [] of Avram::OrderByClause @groups = [] of ColumnName @selections : String = "*" @@ -318,6 +319,15 @@ class Avram::QueryBuilder @joins.uniq(&.to_sql) end + def join(raw_join_clause : Avram::Join::Raw) : self + @raw_joins << raw_join_clause + self + end + + def joins : Array(Avram::Join::Raw) + @raw_joins.uniq(&.to_sql) + end + private def joins_sql : String joins.join(" ", &.to_sql) end diff --git a/src/avram/queryable.cr b/src/avram/queryable.cr index ccbb9bdb0..e9ae5b082 100644 --- a/src/avram/queryable.cr +++ b/src/avram/queryable.cr @@ -124,6 +124,14 @@ module Avram::Queryable(T) clone.tap &.query.join(join_clause) end + def join(statement : String, *bind_vars) : self + join(statement, args: bind_vars.to_a) + end + + def join(statement : String, *, args bind_vars : Array) : self + clone.tap &.query.join(Avram::Join::Raw.new(statement, args: bind_vars)) + end + def where(column : Symbol, value) : self clone.tap &.query.where(Avram::Where::Equal.new(column, value.to_s)) end