diff --git a/tests/mssql/tests/all_tests.rs b/tests/mssql/tests/all_tests.rs index cc54520..1f98819 100644 --- a/tests/mssql/tests/all_tests.rs +++ b/tests/mssql/tests/all_tests.rs @@ -349,3 +349,24 @@ async fn should_be_able_to_write_mapquery_with_a_column_rename() { eprintln!("SQL: {}", sql); q.run(&conn).await.unwrap(); } + +#[tokio::test] +async fn should_be_able_to_select_hourse_or_dog() { + use welds::query::clause::or; + let conn = get_conn().await; + use mssql_test::models::product::ProductSchema; + + // verify pulling out lambda into variable + let clause = |x: ProductSchema| or(x.name.like("horse"), x.name.like("dog")); + let q = Product::all().where_col(clause); + + eprintln!("SQL: {}", q.to_sql(Syntax::Sqlite)); + let data = q.run(&conn).await.unwrap(); + assert_eq!(data.len(), 2, "Expected horse and dog",); + + // verify inline clause + let q2 = Product::all().where_col(|x| or(x.name.like("horse"), x.name.like("dog"))); + eprintln!("SQL: {}", q2.to_sql(Syntax::Sqlite)); + let data = q2.run(&conn).await.unwrap(); + assert_eq!(data.len(), 2, "Expected horse and dog",); +} diff --git a/tests/mysql/tests/all_tests.rs b/tests/mysql/tests/all_tests.rs index effb282..f2e6038 100644 --- a/tests/mysql/tests/all_tests.rs +++ b/tests/mysql/tests/all_tests.rs @@ -374,3 +374,26 @@ fn should_be_able_to_write_a_custom_set() { q.run(&conn).await.unwrap(); }) } + +#[test] +fn should_be_able_to_select_hourse_or_dog() { + async_std::task::block_on(async { + use welds::query::clause::or; + let conn = get_conn().await; + use mysql_test::models::product::ProductSchema; + + // verify pulling out lambda into variable + let clause = |x: ProductSchema| or(x.name.like("horse"), x.name.like("dog")); + let q = Product::all().where_col(clause); + + eprintln!("SQL: {}", q.to_sql(Syntax::Mysql)); + let data = q.run(&conn).await.unwrap(); + assert_eq!(data.len(), 2, "Expected horse and dog",); + + // verify inline clause + let q2 = Product::all().where_col(|x| or(x.name.like("horse"), x.name.like("dog"))); + eprintln!("SQL: {}", q2.to_sql(Syntax::Mysql)); + let data = q2.run(&conn).await.unwrap(); + assert_eq!(data.len(), 2, "Expected horse and dog",); + }) +} diff --git a/tests/postgres/tests/all_tests.rs b/tests/postgres/tests/all_tests.rs index 94a8fc6..c5e12f2 100644 --- a/tests/postgres/tests/all_tests.rs +++ b/tests/postgres/tests/all_tests.rs @@ -786,3 +786,26 @@ fn should_be_able_to_select_all_orders_with_there_products() { assert_eq!(o3_products[0].product_id, 1); }) } + +#[test] +fn should_be_able_to_select_hourse_or_dog() { + async_std::task::block_on(async { + use welds::query::clause::or; + let conn = get_conn().await; + use postgres_test::models::product::ProductSchema; + + // verify pulling out lambda into variable + let clause = |x: ProductSchema| or(x.name.like("horse"), x.name.like("dog")); + let q = Product::all().where_col(clause); + + eprintln!("SQL: {}", q.to_sql(Syntax::Postgres)); + let data = q.run(&conn).await.unwrap(); + assert_eq!(data.len(), 2, "Expected horse and dog",); + + // verify inline clause + let q2 = Product::all().where_col(|x| or(x.name.like("horse"), x.name.like("dog"))); + eprintln!("SQL: {}", q2.to_sql(Syntax::Postgres)); + let data = q2.run(&conn).await.unwrap(); + assert_eq!(data.len(), 2, "Expected horse and dog",); + }) +} diff --git a/tests/sqlite/tests/all_tests.rs b/tests/sqlite/tests/all_tests.rs index 858be25..2b815f4 100644 --- a/tests/sqlite/tests/all_tests.rs +++ b/tests/sqlite/tests/all_tests.rs @@ -606,3 +606,26 @@ fn should_be_able_to_fetch_a_single_object() { .unwrap(); }) } + +#[test] +fn should_be_able_to_select_hourse_or_dog() { + async_std::task::block_on(async { + use welds::query::clause::or; + let conn = get_conn().await; + use sqlite_test::models::product::ProductSchema; + + // verify pulling out lambda into variable + let clause = |x: ProductSchema| or(x.name.like("horse"), x.name.like("dog")); + let q = Product::all().where_col(clause); + + eprintln!("SQL: {}", q.to_sql(Syntax::Sqlite)); + let data = q.run(&conn).await.unwrap(); + assert_eq!(data.len(), 2, "Expected horse and dog",); + + // verify inline clause + let q2 = Product::all().where_col(|x| or(x.name.like("horse"), x.name.like("dog"))); + eprintln!("SQL: {}", q2.to_sql(Syntax::Sqlite)); + let data = q2.run(&conn).await.unwrap(); + assert_eq!(data.len(), 2, "Expected horse and dog",); + }) +} diff --git a/welds-macros/src/blocks/define_schema.rs b/welds-macros/src/blocks/define_schema.rs index 91276bb..58c2079 100644 --- a/welds-macros/src/blocks/define_schema.rs +++ b/welds-macros/src/blocks/define_schema.rs @@ -23,6 +23,7 @@ pub(crate) fn write(info: &Info) -> TokenStream { quote! { + #[derive(Copy,Clone)] pub struct #name { #(#fields),* } @@ -75,6 +76,7 @@ mod tests { let code = ts.to_string(); let expected: &str = r#" + #[derive(Copy,Clone)] pub struct MockSchema { pub id: welds::query::clause::Numeric } diff --git a/welds-macros/src/hook.rs b/welds-macros/src/hook.rs index 9424a2e..8cae64d 100644 --- a/welds-macros/src/hook.rs +++ b/welds-macros/src/hook.rs @@ -1,7 +1,6 @@ +use crate::errors::Result; use proc_macro2::{TokenStream, TokenTree}; use quote::ToTokens; -use crate::errors::Result; -//use syn::Ident; use syn::{MetaList, Path}; /// User has defined a Hook on the model @@ -32,7 +31,7 @@ impl Hook { .to_owned()) }; - let list= &list.tokens.clone().into_iter().collect::>(); + let list = &list.tokens.clone().into_iter().collect::>(); if list.len() > 5 { return badformat(); } @@ -41,15 +40,16 @@ impl Hook { if list.len() == 5 { match &list[3] { - TokenTree::Punct(punct)=> - if punct.as_char() != '=' { - return badformat(); + TokenTree::Punct(punct) => { + if punct.as_char() != '=' { + return badformat(); + } } _ => return badformat(), } match &list[2] { - TokenTree::Ident(ident)=> { - if ident.to_string() != "async" { + TokenTree::Ident(ident) => { + if *ident != "async" { return badformat(); } } @@ -58,14 +58,14 @@ impl Hook { match &list[4] { TokenTree::Ident(ident) => { - if ident.to_string()=="true" { + if *ident == "true" { is_async = true; - } else if ident.to_string()=="false" { + } else if *ident == "false" { is_async = false; } else { return badformat(); } - }, + } _ => return badformat(), } } @@ -77,7 +77,7 @@ impl Hook { tokens }; - let callback: syn::Result =syn::parse2(token_stream); + let callback: syn::Result = syn::parse2(token_stream); let callback = match callback { Ok(path) => path, diff --git a/welds-macros/src/relation/basic.rs b/welds-macros/src/relation/basic.rs index 2392983..3dbd6fc 100644 --- a/welds-macros/src/relation/basic.rs +++ b/welds-macros/src/relation/basic.rs @@ -1,7 +1,7 @@ use super::Relation; use super::{read_as_ident, read_as_path, read_as_string}; use crate::errors::Result; -use syn::{Expr, Ident, Token}; +use syn::{Expr, Ident}; use syn::MetaList; use syn::parse::Parser; use syn::punctuated::Punctuated; diff --git a/welds/src/query/builder/mod.rs b/welds/src/query/builder/mod.rs index 48af6fc..b7c7265 100644 --- a/welds/src/query/builder/mod.rs +++ b/welds/src/query/builder/mod.rs @@ -150,7 +150,7 @@ where FN: AsFieldName, { let field = col(Default::default()); - let colname = field.colname().to_string(); + let colname = field.colname(); let params: ManualParam = params.into(); let c = clause::ClauseColManual { col: Some(colname), diff --git a/welds/src/query/clause/assignment_adder.rs b/welds/src/query/clause/assignment_adder.rs index 0ff6144..fdd8fa3 100644 --- a/welds/src/query/clause/assignment_adder.rs +++ b/welds/src/query/clause/assignment_adder.rs @@ -31,7 +31,7 @@ where fn clause(&self, syntax: Syntax, _alias: &str, next_params: &NextParam) -> Option { // build the column name - let colname = ColumnWriter::new(syntax).excape(&self.col); + let colname = ColumnWriter::new(syntax).excape(self.col); let mut parts = vec![colname.as_str()]; // handle null clones @@ -104,7 +104,7 @@ impl AssignmentAdder for AssignmentManual { // build the column name let mut parts = vec![]; - let colname = ColumnWriter::new(syntax).excape(&self.col); + let colname = ColumnWriter::new(syntax).excape(self.col); parts.push(colname); parts.push(" = ( ".to_string()); diff --git a/welds/src/query/clause/basic.rs b/welds/src/query/clause/basic.rs index 7798ef2..5c4d194 100644 --- a/welds/src/query/clause/basic.rs +++ b/welds/src/query/clause/basic.rs @@ -2,26 +2,28 @@ use super::{AsFieldName, ClauseColVal, ClauseColValEqual, ClauseColValIn}; use std::marker::PhantomData; use welds_connections::Param; +#[derive(Clone)] pub struct Basic { - col: String, - field: String, + col: &'static str, + field: &'static str, _t: PhantomData, } impl AsFieldName for Basic { - fn colname(&self) -> &str { - self.col.as_str() + fn colname(&self) -> &'static str { + self.col } - fn fieldname(&self) -> &str { - self.field.as_str() + fn fieldname(&self) -> &'static str { + self.field } } +impl Copy for Basic {} impl Basic where T: 'static + Clone + Send + Sync, { - pub fn new(col: impl Into, field: impl Into) -> Self { + pub fn new(col: &'static str, field: &'static str) -> Self { Self { col: col.into(), field: field.into(), diff --git a/welds/src/query/clause/basicopt.rs b/welds/src/query/clause/basicopt.rs index 97936c4..b4dffd7 100644 --- a/welds/src/query/clause/basicopt.rs +++ b/welds/src/query/clause/basicopt.rs @@ -4,20 +4,22 @@ use crate::query::optional::Optional; use std::marker::PhantomData; use welds_connections::Param; +#[derive(Clone)] pub struct BasicOpt { - col: String, - field: String, + col: &'static str, + field: &'static str, _t: PhantomData, } impl AsFieldName for BasicOpt { - fn colname(&self) -> &str { - self.col.as_str() + fn colname(&self) -> &'static str { + self.col } - fn fieldname(&self) -> &str { - self.field.as_str() + fn fieldname(&self) -> &'static str { + self.field } } +impl Copy for BasicOpt {} impl AsOptField for BasicOpt {} @@ -25,7 +27,7 @@ impl BasicOpt where T: 'static + Clone + Send + Sync, { - pub fn new(col: impl Into, field: impl Into) -> Self { + pub fn new(col: &'static str, field: &'static str) -> Self { Self { col: col.into(), field: field.into(), diff --git a/welds/src/query/clause/clause_adder.rs b/welds/src/query/clause/clause_adder.rs index 259bdfe..e3b8a6a 100644 --- a/welds/src/query/clause/clause_adder.rs +++ b/welds/src/query/clause/clause_adder.rs @@ -201,3 +201,4 @@ impl ClauseAdder for ClauseColManual { Some(clause) } } + diff --git a/welds/src/query/clause/mod.rs b/welds/src/query/clause/mod.rs index 9f5e38f..b35d672 100644 --- a/welds/src/query/clause/mod.rs +++ b/welds/src/query/clause/mod.rs @@ -30,12 +30,22 @@ pub use clause_adder::ClauseAdder; // trait used to write assignments in a sql statement mod assignment_adder; + +#[cfg(feature = "unstable-api")] +mod or_and; +#[cfg(feature = "unstable-api")] +pub use or_and::ClauseAdderAndOrExt; +#[cfg(feature = "unstable-api")] +pub use or_and::and; +#[cfg(feature = "unstable-api")] +pub use or_and::or; + pub use assignment_adder::AssignmentAdder; pub struct ClauseColVal { pub null_clause: bool, pub not_clause: bool, - pub col: String, + pub col: &'static str, pub operator: &'static str, pub val: Option, } @@ -43,31 +53,31 @@ pub struct ClauseColVal { pub struct ClauseColValEqual { pub null_clause: bool, pub not_clause: bool, - pub col: String, + pub col: &'static str, pub operator: &'static str, pub val: Option, } pub struct ClauseColValIn { - pub col: String, + pub col: &'static str, pub operator: &'static str, pub list: Vec, } pub struct ClauseColValList { - pub col: String, + pub col: &'static str, pub operator: &'static str, pub list: Vec, } pub struct ClauseColManual { - pub(crate) col: Option, + pub(crate) col: Option<&'static str>, pub(crate) sql: String, pub(crate) params: Vec>, } pub struct AssignmentManual { - pub(crate) col: String, + pub(crate) col: &'static str, pub(crate) sql: String, pub(crate) params: Vec>, } @@ -77,8 +87,8 @@ pub struct AssignmentManual { // fieldname refers to what we want to get the column out as. // for example: select id as ids from bla. pub trait AsFieldName { - fn colname(&self) -> &str; - fn fieldname(&self) -> &str; + fn colname(&self) -> &'static str; + fn fieldname(&self) -> &'static str; } // marker trait to make sure a field is nullable diff --git a/welds/src/query/clause/numeric.rs b/welds/src/query/clause/numeric.rs index cb28d10..d159cef 100644 --- a/welds/src/query/clause/numeric.rs +++ b/welds/src/query/clause/numeric.rs @@ -5,18 +5,19 @@ use std::marker::PhantomData; use welds_connections::Param; /// Clauses for numeric types such as int, float, etc +#[derive(Copy,Clone)] pub struct Numeric { - col: String, - field: String, + col: &'static str, + field: &'static str, _t: PhantomData, } impl AsFieldName for Numeric { - fn colname(&self) -> &str { - self.col.as_str() + fn colname(&self) -> &'static str { + self.col } - fn fieldname(&self) -> &str { - self.field.as_str() + fn fieldname(&self) -> &'static str { + self.field } } @@ -24,7 +25,7 @@ impl Numeric where T: 'static + Clone + Send + Sync, { - pub fn new(col: impl Into, field: impl Into) -> Self { + pub fn new(col: &'static str, field: &'static str) -> Self { Self { col: col.into(), field: field.into(), diff --git a/welds/src/query/clause/numericopt.rs b/welds/src/query/clause/numericopt.rs index c722bcd..561021b 100644 --- a/welds/src/query/clause/numericopt.rs +++ b/welds/src/query/clause/numericopt.rs @@ -6,18 +6,19 @@ use crate::query::optional::Optional; use std::marker::PhantomData; use welds_connections::Param; +#[derive(Copy,Clone)] pub struct NumericOpt { - col: String, - field: String, + col: &'static str, + field: &'static str, _t: PhantomData, } impl AsFieldName for NumericOpt { - fn colname(&self) -> &str { - self.col.as_str() + fn colname(&self) -> &'static str { + self.col } - fn fieldname(&self) -> &str { - self.field.as_str() + fn fieldname(&self) -> &'static str { + self.field } } @@ -27,7 +28,7 @@ impl NumericOpt where T: 'static + Clone + Send + Sync, { - pub fn new(col: impl Into, field: impl Into) -> Self { + pub fn new(col: &'static str, field: &'static str) -> Self { Self { col: col.into(), field: field.into(), diff --git a/welds/src/query/clause/or_and.rs b/welds/src/query/clause/or_and.rs new file mode 100644 index 0000000..8344f43 --- /dev/null +++ b/welds/src/query/clause/or_and.rs @@ -0,0 +1,156 @@ +use crate::query::clause::{ClauseAdder, ParamArgs}; +use crate::writers::NextParam; +use welds_connections::Syntax; + +enum LogicalOp { + And, + Or, +} + +pub struct LogicalClause { + left_clause: Box, + operator: LogicalOp, + right_clause: Box, +} + +impl LogicalOp { + pub fn to_str(&self) -> &'static str { + match self { + LogicalOp::And => "AND", + LogicalOp::Or => "OR", + } + } +} + +pub fn or( + left_clause: Box, + right_clause: Box, +) -> Box { + Box::new(LogicalClause { + left_clause, + operator: LogicalOp::Or, + right_clause, + }) +} + +pub fn and( + left_clause: Box, + right_clause: Box, +) -> Box { + Box::new(LogicalClause { + left_clause, + operator: LogicalOp::And, + right_clause, + }) +} + +impl ClauseAdder for LogicalClause { + fn bind<'lam, 'args, 'p>(&'lam self, args: &'args mut ParamArgs<'p>) + where + 'lam: 'p, + { + self.left_clause.bind(args); + self.right_clause.bind(args); + } + + fn clause(&self, syntax: Syntax, alias: &str, next_params: &NextParam) -> Option { + let left = self.left_clause.clause(syntax, alias, next_params); + let right = self.right_clause.clause(syntax, alias, next_params); + let operator = self.operator.to_str(); + + // Both have some + if let Some(left) = &left + && let Some(right) = &right + { + return Some(format!("({left} {operator} {right})")); + } + // Left has some + if left.is_some() { + return left; + } + // Right has some + if right.is_some() { + return right; + } + // Both are none + None + } +} + +/// Extensions on ClauseAdder to add builder style (and/or) methods +pub trait ClauseAdderAndOrExt { + fn and(self: Box, other: Box) -> Box; + fn or(self: Box, other: Box) -> Box; +} + +impl ClauseAdderAndOrExt for CA +where + CA: ClauseAdder + 'static, +{ + fn and(self: Box, other: Box) -> Box { + and(self, other) + } + fn or(self: Box, other: Box) -> Box { + or(self, other) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::WeldsModel; + use welds_connections::Syntax; + + #[derive(Debug, Default, WeldsModel)] + #[welds(table = "test_table")] + #[welds_path(crate)] + struct TestModel { + #[welds(primary_key)] + pub id: i32, + #[welds(rename = "name_column")] + pub name: String, + pub is_active: bool, + pub score: f64, + } + + #[test] + fn test_and_logical_clause() { + let a = TestModelSchema::default(); + + let and_clause = and(a.id.equal(1), a.is_active.equal(true)); + + let sql = and_clause.clause(Syntax::Postgres, "t1", &NextParam::new(Syntax::Postgres)); + assert!(sql.is_some()); + assert_eq!(sql.unwrap(), "(t1.id = $1 AND t1.is_active = $2)"); + } + + #[test] + fn test_or_logical_clause() { + let a = TestModelSchema::default(); + + let or_clause = or(a.score.gt(0.5), a.name.equal("test".to_string())); + + let sql = or_clause.clause(Syntax::Postgres, "t1", &NextParam::new(Syntax::Postgres)); + assert!(sql.is_some()); + assert_eq!(sql.unwrap(), "(t1.score > $1 OR t1.name_column = $2)"); + } + + #[test] + fn test_nested_logical_clauses() { + let a = TestModelSchema::default(); + + // (id = 1 AND is_active = true) OR score > 0.5 + let and_clause = and(a.id.equal(1), a.is_active.equal(true)); + let nested = or(and_clause, a.score.gte(0.5)); + + let sql = nested.clause(Syntax::Postgres, "t1", &NextParam::new(Syntax::Postgres)); + assert!(sql.is_some()); + let sql_str = sql.unwrap(); + assert!(sql_str.contains("AND")); + assert!(sql_str.contains("OR")); + assert_eq!( + sql_str, + "((t1.id = $1 AND t1.is_active = $2) OR t1.score >= $3)" + ); + } +} diff --git a/welds/src/query/clause/text.rs b/welds/src/query/clause/text.rs index 49639a4..0d26266 100644 --- a/welds/src/query/clause/text.rs +++ b/welds/src/query/clause/text.rs @@ -2,26 +2,28 @@ use super::{AsFieldName, ClauseColVal, ClauseColValEqual, ClauseColValIn}; use std::marker::PhantomData; use welds_connections::Param; +#[derive(Clone)] pub struct Text { - col: String, - field: String, + col: &'static str, + field: &'static str, _t: PhantomData, } impl AsFieldName for Text { - fn colname(&self) -> &str { - self.col.as_str() + fn colname(&self) -> &'static str { + self.col } - fn fieldname(&self) -> &str { - self.field.as_str() + fn fieldname(&self) -> &'static str { + self.field } } +impl Copy for Text {} impl Text where T: 'static + Clone + Send + Sync, { - pub fn new(col: impl Into, field: impl Into) -> Self { + pub fn new(col: &'static str, field: &'static str) -> Self { Self { col: col.into(), field: field.into(), diff --git a/welds/src/query/clause/textopt.rs b/welds/src/query/clause/textopt.rs index 6f6bbfe..3963b69 100644 --- a/welds/src/query/clause/textopt.rs +++ b/welds/src/query/clause/textopt.rs @@ -4,20 +4,22 @@ use crate::query::optional::Optional; use std::marker::PhantomData; use welds_connections::Param; +#[derive(Clone)] pub struct TextOpt { - col: String, - field: String, + col: &'static str, + field: &'static str, _t: PhantomData, } impl AsFieldName for TextOpt { - fn colname(&self) -> &str { - self.col.as_str() + fn colname(&self) -> &'static str { + self.col } - fn fieldname(&self) -> &str { - self.field.as_str() + fn fieldname(&self) -> &'static str { + self.field } } +impl Copy for TextOpt {} impl AsOptField for TextOpt {} @@ -25,7 +27,7 @@ impl TextOpt where T: 'static + Clone + Send + Sync, { - pub fn new(col: impl Into, field: impl Into) -> Self { + pub fn new(col: &'static str, field: &'static str) -> Self { Self { col: col.into(), field: field.into(), diff --git a/welds/src/query/update/bulk/mod.rs b/welds/src/query/update/bulk/mod.rs index 23495a8..f3cb76c 100644 --- a/welds/src/query/update/bulk/mod.rs +++ b/welds/src/query/update/bulk/mod.rs @@ -168,7 +168,7 @@ where { let params: ManualParam = params.into(); let field = lam(Default::default()); - let col_raw = field.colname().to_string(); + let col_raw = field.colname(); let adder = AssignmentManual { col: col_raw,