From 74ed42c74dc620d280cca8c9572a20ee56669088 Mon Sep 17 00:00:00 2001 From: Burak T Date: Fri, 9 Aug 2024 02:28:18 -0400 Subject: [PATCH 1/9] changes for personalization --- src/db_queries.rs | 1 + src/generate.rs | 8 ++- src/main.rs | 35 ++++++++++- src/models.rs | 2 +- src/query_generate.rs | 132 +++++++++++++++++++++++++++++++++--------- src/utils.rs | 18 +++--- 6 files changed, 159 insertions(+), 37 deletions(-) diff --git a/src/db_queries.rs b/src/db_queries.rs index 389881e..46f31c3 100644 --- a/src/db_queries.rs +++ b/src/db_queries.rs @@ -86,6 +86,7 @@ WHERE AND c.table_name != '_sqlx_migrations' AND ($2 IS NULL OR c.table_name = ANY($2)) + AND c.data_type != 'USER-DEFINED' ORDER BY c.table_name, c.ordinal_position; diff --git a/src/generate.rs b/src/generate.rs index c03d390..7281260 100644 --- a/src/generate.rs +++ b/src/generate.rs @@ -15,6 +15,7 @@ pub async fn generate( context: Option<&str>, force: bool, include_tables: Option>, + exclude_tables: Vec, schemas: Option>, ) -> Result<(), Box> { // Connect to the Postgres database @@ -36,7 +37,12 @@ pub async fn generate( for row in &rows { unique.insert(row.table_name.clone()); } - let tables: Vec = unique.into_iter().collect::>(); + let tables: Vec = unique + .into_iter() + .collect::>() + .into_iter() + .filter(|e| !exclude_tables.contains(e)) + .collect(); println!("Outputting tables: {:?}", tables); diff --git a/src/main.rs b/src/main.rs index 8cc8e3f..40247d9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -68,6 +68,16 @@ async fn main() -> Result<(), Box> { .use_delimiter(true) .help("Specify the table name(s)"), ) + .arg( + Arg::with_name("exclude") + .short('e') + .long("exclude") + .takes_value(true) + .value_name("SQLGEN_EXCLUDE") + .multiple(true) + .use_delimiter(true) + .help("Specify the excluded table name(s)"), + ) .arg( Arg::new("force") .short('f') @@ -195,7 +205,30 @@ async fn main() -> Result<(), Box> { let schemas: Option> = matches.values_of("schema").map(|schemas| schemas.collect()); let force = matches.is_present("force"); - generate::generate(output_folder, database_url, context, force, None, schemas).await?; + let include_tables = matches.values_of("table").map(|v| v.collect::>()); + let exclude_tables = matches + .values_of("exclude") + .map(|v| { + v.into_iter() + .map(|e| e.to_string()) + .collect::>() + }) + .unwrap_or(vec![]); + + if !exclude_tables.is_empty() { + println!("Excluding tables: {:?}", exclude_tables); + } + + generate::generate( + output_folder, + database_url, + context, + force, + include_tables, + exclude_tables, + schemas, + ) + .await?; } else if let Some(matches) = matches.subcommand_matches("migrate") { let input_migrations_folder = matches.value_of("migrations").unwrap_or("./migrations"); println!( diff --git a/src/models.rs b/src/models.rs index def398f..9de43b2 100644 --- a/src/models.rs +++ b/src/models.rs @@ -1,4 +1,4 @@ -#[derive(sqlx::FromRow)] +#[derive(sqlx::FromRow, Clone)] pub struct TableColumn { pub(crate) table_name: String, pub(crate) column_name: String, diff --git a/src/query_generate.rs b/src/query_generate.rs index c2e37d8..8f5b994 100644 --- a/src/query_generate.rs +++ b/src/query_generate.rs @@ -28,14 +28,14 @@ pub fn generate_query_code(table_name: &str, rows: &[TableColumn]) -> String { query_code.push_str(&format!("impl {}Set {{\n", struct_name)); // Generate query code for SELECT statements - query_code.push_str(&generate_select_query_code(table_name, schema_name, rows)); + query_code.push_str(&generate_select_query_code(table_name, schema_name, &rows)); query_code.push('\n'); // Generate query code for SELECT BY PK statements query_code.push_str(&generate_select_by_pk_query_code( table_name, schema_name, - rows, + &rows, )); query_code.push('\n'); @@ -43,7 +43,7 @@ pub fn generate_query_code(table_name: &str, rows: &[TableColumn]) -> String { query_code.push_str(&generate_select_many_by_pks_query_code( table_name, schema_name, - rows, + &rows, )); query_code.push('\n'); @@ -51,30 +51,30 @@ pub fn generate_query_code(table_name: &str, rows: &[TableColumn]) -> String { query_code.push_str(&generate_select_by_pk_query_code_optional( table_name, schema_name, - rows, + &rows, )); query_code.push('\n'); - query_code.push_str(&generate_unique_query_code(table_name, schema_name, rows)); - query_code.push('\n'); + // query_code.push_str(&generate_unique_query_code(table_name, schema_name, &rows)); + // query_code.push('\n'); query_code.push_str(&generate_select_all_fk_queries( table_name, schema_name, - rows, + &rows, )); query_code.push('\n'); // Generate query code for INSERT statements - query_code.push_str(&generate_insert_query_code(table_name, schema_name, rows)); + query_code.push_str(&generate_insert_query_code(table_name, schema_name, &rows)); query_code.push('\n'); // Generate query code for UPDATE statements - query_code.push_str(&generate_update_query_code(table_name, schema_name, rows)); + query_code.push_str(&generate_update_query_code(table_name, schema_name, &rows)); query_code.push('\n'); // Generate query code for DELETE statements - query_code.push_str(&generate_delete_query_code(table_name, schema_name, rows)); + query_code.push_str(&generate_delete_query_code(table_name, schema_name, &rows)); query_code.push('\n'); query_code.push_str("}\n"); @@ -121,9 +121,14 @@ fn generate_select_by_pk_query_code( .iter() .filter(|r| r.is_primary_key && r.table_name == table_name) .map(|r| { + let column_name = if r.column_name.as_str() == "type" { + "r#type" + } else { + &r.column_name + }; format!( "{}: {}", - r.column_name, + column_name, convert_data_type(r.udt_name.as_str()) ) }) @@ -146,7 +151,14 @@ fn generate_select_by_pk_query_code( let bind = rows .iter() .filter(|r| r.is_primary_key && r.table_name == table_name) - .map(|r| format!(" .bind({})\n", r.column_name)) + .map(|r| { + let column_name = if r.column_name.as_str() == "type" { + "r#type" + } else { + &r.column_name + }; + format!(" .bind({})\n", column_name) + }) .collect::>() .join(""); @@ -185,9 +197,14 @@ fn generate_select_many_by_pks_query_code( .iter() .filter(|r| r.is_primary_key && r.table_name == table_name) .map(|r| { + let column_name = if r.column_name.as_str() == "type" { + "r#type" + } else { + &r.column_name + }; format!( "{}_list: Vec<{}>", - r.column_name, + column_name, convert_data_type(r.udt_name.as_str()) ) }) @@ -210,7 +227,14 @@ fn generate_select_many_by_pks_query_code( let bind = rows .iter() .filter(|r| r.is_primary_key && r.table_name == table_name) - .map(|r| format!(" .bind({}_list)\n", r.column_name)) + .map(|r| { + let column_name = if r.column_name.as_str() == "type" { + "r#type" + } else { + &r.column_name + }; + format!(" .bind({}_list)\n", column_name) + }) .collect::>() .join(""); @@ -244,9 +268,14 @@ fn generate_select_by_pk_query_code_optional( .iter() .filter(|r| r.is_primary_key && r.table_name == table_name) .map(|r| { + let column_name = if r.column_name.as_str() == "type" { + "r#type" + } else { + &r.column_name + }; format!( "{}: {}", - r.column_name, + column_name, convert_data_type(r.udt_name.as_str()) ) }) @@ -271,7 +300,14 @@ fn generate_select_by_pk_query_code_optional( let bind = rows .iter() .filter(|r| r.is_primary_key && r.table_name == table_name) - .map(|r| format!(" .bind({})\n", r.column_name)) + .map(|r| { + let column_name = if r.column_name.as_str() == "type" { + "r#type" + } else { + &r.column_name + }; + format!(" .bind({})\n", column_name) + }) .collect::>() .join(""); @@ -337,9 +373,14 @@ fn generate_select_by_unique_query_code( .iter() .filter(|r| r.is_unique && r.table_name == table_name && r.column_name == unique_name) .map(|r| { + let column_name = if r.column_name.as_str() == "type" { + "r#type" + } else { + &r.column_name + }; format!( "{}: {}", - r.column_name, + column_name, convert_data_type(r.udt_name.as_str()) ) }) @@ -347,7 +388,7 @@ fn generate_select_by_unique_query_code( .join(", "); select_code.push_str(&format!( - " pub async fn by_{}<'e, E: PgExecutor<'e>>(&self, executor: E, {}) -> Result<{}> {{\n", + " pub async fn unique_by_{}<'e, E: PgExecutor<'e>>(&self, executor: E, {}) -> Result<{}> {{\n", unique_name, fn_args, struct_name )); @@ -362,7 +403,14 @@ fn generate_select_by_unique_query_code( let bind = rows .iter() .filter(|r| r.is_unique && r.table_name == table_name && r.column_name == unique_name) - .map(|r| format!(" .bind({})\n", r.column_name)) + .map(|r| { + let column_name = if r.column_name.as_str() == "type" { + "r#type" + } else { + &r.column_name + }; + format!(" .bind({})\n", column_name) + }) .collect::>() .join(""); @@ -396,9 +444,14 @@ fn generate_select_many_by_uniques_query_code( .iter() .filter(|r| r.is_unique && r.table_name == table_name && r.column_name == unique_name) .map(|r| { + let column_name = if r.column_name.as_str() == "type" { + "r#type" + } else { + &r.column_name + }; format!( "{}_list: Vec<{}>", - r.column_name, + column_name, convert_data_type(r.udt_name.as_str()) ) }) @@ -406,7 +459,7 @@ fn generate_select_many_by_uniques_query_code( .join(", "); select_code.push_str(&format!( - " pub async fn many_by_{}_list<'e, E: PgExecutor<'e>>(&self, executor: E, {}) -> Result> {{\n", + " pub async fn unique_many_by_{}_list<'e, E: PgExecutor<'e>>(&self, executor: E, {}) -> Result> {{\n", unique_name, fn_args, struct_name )); @@ -421,7 +474,14 @@ fn generate_select_many_by_uniques_query_code( let bind = rows .iter() .filter(|r| r.is_unique && r.table_name == table_name && r.column_name == unique_name) - .map(|r| format!(" .bind({}_list)\n", r.column_name)) + .map(|r| { + let column_name = if r.column_name.as_str() == "type" { + "r#type" + } else { + &r.column_name + }; + format!(" .bind({}_list)\n", column_name) + }) .collect::>() .join(""); @@ -450,9 +510,14 @@ fn generate_select_by_unique_query_code_optional( .iter() .filter(|r| r.is_unique && r.table_name == table_name && r.column_name == unique_name) .map(|r| { + let column_name = if r.column_name.as_str() == "type" { + "r#type" + } else { + &r.column_name + }; format!( "{}: {}", - r.column_name, + column_name, convert_data_type(r.udt_name.as_str()) ) }) @@ -460,7 +525,7 @@ fn generate_select_by_unique_query_code_optional( .join(", "); select_code.push_str(&format!( - " pub async fn by_{}_optional<'e, E: PgExecutor<'e>>(&self, executor: E, {}) -> Result> {{\n", + " pub async fn unique_by_{}_optional<'e, E: PgExecutor<'e>>(&self, executor: E, {}) -> Result> {{\n", unique_name, fn_args, struct_name @@ -477,7 +542,14 @@ fn generate_select_by_unique_query_code_optional( let bind = rows .iter() .filter(|r| r.is_unique && r.table_name == table_name && r.column_name == unique_name) - .map(|r| format!(" .bind({})\n", r.column_name)) + .map(|r| { + let column_name = if r.column_name.as_str() == "type" { + "r#type" + } else { + &r.column_name + }; + format!(" .bind({})\n", column_name) + }) .collect::>() .join(""); @@ -543,7 +615,7 @@ fn generate_select_by_fk_query_code( select_code.push_str(&format!( " pub async fn all_by_{}_{}<'e, E: PgExecutor<'e>>(executor: E, {}_{}: {}) -> Result> {{\n", to_snake_case(foreign_row_table_name), - foreign_row_column_name, + to_snake_case(column_name), to_snake_case(foreign_row_table_name), foreign_row_column_name, data_type, @@ -569,6 +641,7 @@ fn generate_select_by_fk_query_code( fn generate_insert_query_code(table_name: &str, schema_name: &str, rows: &[TableColumn]) -> String { let struct_name = to_pascal_case(table_name); let mut insert_code = String::new(); + insert_code.push_str(&format!( " pub async fn insert<'e, E: PgExecutor<'e>>(&self, executor: E, {}: {}) -> Result<{}> {{\n", to_snake_case(table_name), @@ -659,10 +732,15 @@ fn generate_value_list(table_name: &str, rows: &[TableColumn]) -> String { rows.iter() .filter(|row| row.table_name == table_name) .map(|row| { + let column_name = if row.column_name.as_str() == "type" { + "r#type" + } else { + &row.column_name + }; format!( ".bind({}.{})", to_snake_case(&row.table_name), - to_snake_case(&row.column_name) + to_snake_case(column_name) ) }) .collect::>() diff --git a/src/utils.rs b/src/utils.rs index cc051cd..d3bb1b4 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -31,12 +31,15 @@ pub fn generate_struct_code(table_name: &str, rows: &Vec) -> String for row in rows { if row.table_name == table_name { - let column_name = to_snake_case(&row.column_name); + let mut column_name = to_snake_case(&row.column_name); let mut data_type = convert_data_type(&row.udt_name); let optional_type = format!("Option<{}>", data_type); if row.is_nullable { data_type = optional_type; } + if column_name.as_str() == "type" { + column_name = String::from("r#type"); + } struct_code.push_str(&format!(" pub {}: {},\n", column_name, data_type)); } @@ -60,18 +63,19 @@ pub fn convert_data_type(data_type: &str) -> String { "bool" | "boolean" => "bool", "bytea" => "Vec", // is this right? "char" | "bpchar" | "character" => "String", - "date" => "chrono::NaiveDate", + "date" => "time::Date", "float4" | "real" => "f32", - "float8" | "double precision" => "f64", + "float8" | "double precision" | "numeric" => "f64", "int2" | "smallint" | "smallserial" => "i16", "int4" | "int" | "serial" => "i32", "int8" | "bigint" | "bigserial" => "i64", "void" => "()", "jsonb" | "json" => "serde_json::Value", - "text" | "varchar" | "name" | "citext" => "String", - "time" => "chrono::NaiveTime", - "timestamp" => "chrono::NaiveDateTime", - "timestamptz" => "chrono::DateTime", + "text" | "_text" | "varchar" | "name" | "citext" => "String", + "time" => "time::Time", + "timestamp" => "time::OffsetDateTime", + "timestamptz" => "time::OffsetDateTime", + "interval" => "sqlx::postgres::types::PgInterval", "uuid" => "uuid::Uuid", _ => panic!("Unknown type: {}", data_type), } From db093c798d245baa2eeecc329bd7bb67101d0cb9 Mon Sep 17 00:00:00 2001 From: Burak T Date: Sat, 10 Aug 2024 20:38:49 -0400 Subject: [PATCH 2/9] option to derive w serde --- src/generate.rs | 3 ++- src/main.rs | 11 +++++++++++ src/utils.rs | 17 +++++++++++++++-- 3 files changed, 28 insertions(+), 3 deletions(-) diff --git a/src/generate.rs b/src/generate.rs index 7281260..0fa51cb 100644 --- a/src/generate.rs +++ b/src/generate.rs @@ -10,6 +10,7 @@ use crate::utils::{generate_struct_code, to_pascal_case, to_snake_case}; use crate::query_generate::generate_query_code; pub async fn generate( + enable_serde: bool, output_folder: &str, database_url: &str, context: Option<&str>, @@ -54,7 +55,7 @@ pub async fn generate( } } // Generate the struct code based on the row - let struct_code = generate_struct_code(&table, &rows); + let struct_code = generate_struct_code(&table, &rows, enable_serde); // Generate the query code based on the row let query_code = generate_query_code(&table, &rows); diff --git a/src/main.rs b/src/main.rs index 40247d9..64200c2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -22,6 +22,14 @@ async fn main() -> Result<(), Box> { .help("Sets the output folder for generated structs") .takes_value(true), ) + .arg( + Arg::with_name("serde") + .long("serde") + .default_value("true") + .value_name("SQLGEN_ENABLE_SERDE") + .help("Adds Serde derices to created structs") + .takes_value(false), + ) .arg( Arg::with_name("migrations") .short('m') @@ -219,7 +227,10 @@ async fn main() -> Result<(), Box> { println!("Excluding tables: {:?}", exclude_tables); } + let enable_serde = matches.is_present("serde"); + generate::generate( + enable_serde, output_folder, database_url, context, diff --git a/src/utils.rs b/src/utils.rs index d3bb1b4..3a3ff17 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -20,13 +20,26 @@ pub(crate) fn to_snake_case(input: &str) -> String { output } -pub fn generate_struct_code(table_name: &str, rows: &Vec) -> String { +pub fn generate_struct_code( + table_name: &str, + rows: &Vec, + enable_serde: bool, +) -> String { let struct_name = to_pascal_case(table_name); let mut struct_code = String::new(); + let serde_derives = if enable_serde { + ", serde::Serialize, serde::Deserialize" + } else { + "" + }; + struct_code.push_str("#![allow(dead_code)]\n"); struct_code.push_str("// Generated with sql-gen\n// https://github.com/jayy-lmao/sql-gen\n\n"); - struct_code.push_str("#[derive(sqlx::FromRow, Debug)]\n"); + struct_code.push_str(&format!( + "#[derive(sqlx::FromRow, Debug, Clone{})]\n", + serde_derives + )); struct_code.push_str(&format!("pub struct {} {{\n", struct_name)); for row in rows { From 8b4e08b321810907b0fb4d81e6e9ced6504e1f15 Mon Sep 17 00:00:00 2001 From: Burak T Date: Sun, 11 Aug 2024 21:49:50 -0400 Subject: [PATCH 3/9] create enums for user defined types --- src/db_queries.rs | 28 ++++++++++++++++++++++++++-- src/generate.rs | 44 ++++++++++++++++++++++++++++++++++++++++---- src/models.rs | 6 ++++++ src/utils.rs | 40 +++++++++++++++++++++++++++++++++++++--- 4 files changed, 109 insertions(+), 9 deletions(-) diff --git a/src/db_queries.rs b/src/db_queries.rs index 46f31c3..da00159 100644 --- a/src/db_queries.rs +++ b/src/db_queries.rs @@ -1,6 +1,6 @@ use sqlx::PgPool; -use crate::models::TableColumn; +use crate::models::{TableColumn, UserDefinedEnums}; pub async fn get_table_columns( pool: &PgPool, @@ -86,7 +86,6 @@ WHERE AND c.table_name != '_sqlx_migrations' AND ($2 IS NULL OR c.table_name = ANY($2)) - AND c.data_type != 'USER-DEFINED' ORDER BY c.table_name, c.ordinal_position; @@ -99,3 +98,28 @@ ORDER BY .await?; Ok(rows) } + +pub async fn get_user_defined_enums( + udt_names: &Vec, + pool: &PgPool, +) -> sqlx::Result> { + let query = " + SELECT + t.typname AS enum_name, + e.enumlabel AS enum_value + FROM + pg_type t + JOIN pg_enum e ON t.oid = e.enumtypid + WHERE + t.typname = ANY($1) + ORDER BY + t.typname, + e.enumsortorder; + "; + + let rows = sqlx::query_as::<_, UserDefinedEnums>(query) + .bind(udt_names) + .fetch_all(pool) + .await?; + Ok(rows) +} diff --git a/src/generate.rs b/src/generate.rs index 0fa51cb..54e3a3a 100644 --- a/src/generate.rs +++ b/src/generate.rs @@ -3,9 +3,9 @@ use sqlx::PgPool; use std::fs; use std::path::Path; -use crate::db_queries::get_table_columns; +use crate::db_queries::{get_table_columns, get_user_defined_enums}; use crate::models::TableColumn; -use crate::utils::{generate_struct_code, to_pascal_case, to_snake_case}; +use crate::utils::{generate_enum_code, generate_struct_code, to_pascal_case, to_snake_case}; use crate::query_generate::generate_query_code; @@ -30,7 +30,23 @@ pub async fn generate( let default_schema = "public"; let rows = get_table_columns(&pool, schemas.unwrap_or(vec![default_schema]), None).await?; + let user_defined = rows + .iter() + .filter_map(|e| { + if e.data_type.as_str() == "USER-DEFINED" && e.udt_name.as_str() != "geometry" { + Some(e.udt_name.clone()) + } else { + None + } + }) + .collect::>(); + let enum_rows = get_user_defined_enums(&user_defined, &pool).await?; + let mut unique_enums = std::collections::BTreeSet::new(); + for row in &enum_rows { + unique_enums.insert(row.enum_name.clone()); + } + let enums = unique_enums.into_iter().collect::>(); // Create the output folder if it doesn't exist fs::create_dir_all(output_folder)?; @@ -45,8 +61,18 @@ pub async fn generate( .filter(|e| !exclude_tables.contains(e)) .collect(); + if !enums.is_empty() { + println!("Outputting user defined enums: {:?}", enums); + } println!("Outputting tables: {:?}", tables); + let mut rs_enums = Vec::new(); + + for user_enum in enums { + let enum_code = generate_enum_code(&user_enum, &enum_rows, enable_serde); + rs_enums.push(enum_code); + } + // Generate structs and queries for each table for table in &tables { if let Some(ts) = include_tables.clone() { @@ -82,18 +108,28 @@ pub async fn generate( } } - let context_code = generate_db_context(context.unwrap_or(&database_name), &tables, &rows); + let context_code = + generate_db_context(context.unwrap_or(&database_name), &rs_enums, &tables, &rows); let context_file_path = format!("{}/mod.rs", output_folder); fs::write(context_file_path, context_code)?; Ok(()) } -fn generate_db_context(database_name: &str, tables: &[String], _rows: &[TableColumn]) -> String { +fn generate_db_context( + database_name: &str, + enums: &[String], + tables: &[String], + _rows: &[TableColumn], +) -> String { let mut db_context_code = String::new(); db_context_code.push_str("#![allow(dead_code)]\n"); db_context_code .push_str("// Generated with sql-gen\n//https://github.com/jayy-lmao/sql-gen\n\n"); + for enum_item in enums { + db_context_code.push_str(enum_item); + db_context_code.push_str("\n\n"); + } for table in tables { db_context_code.push_str(&format!("pub mod {};\n", to_snake_case(table))); db_context_code.push_str(&format!( diff --git a/src/models.rs b/src/models.rs index 9de43b2..656b0a6 100644 --- a/src/models.rs +++ b/src/models.rs @@ -12,3 +12,9 @@ pub struct TableColumn { // #todo pub(crate) table_schema: String, } + +#[derive(sqlx::FromRow, Clone)] +pub struct UserDefinedEnums { + pub(crate) enum_name: String, + pub(crate) enum_value: String, +} diff --git a/src/utils.rs b/src/utils.rs index 3a3ff17..0d28a51 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,4 +1,4 @@ -use crate::models::TableColumn; +use crate::models::{TableColumn, UserDefinedEnums}; pub(crate) fn to_snake_case(input: &str) -> String { let mut output = String::new(); @@ -20,6 +20,38 @@ pub(crate) fn to_snake_case(input: &str) -> String { output } +pub fn generate_enum_code( + enum_name: &str, + enum_rows: &Vec, + enable_serde: bool, +) -> String { + let rs_enum_name = to_pascal_case(enum_name); + let mut enum_code = String::new(); + + let serde_derives = if enable_serde { + ", serde::Serialize, serde::Deserialize" + } else { + "" + }; + enum_code.push_str(&format!( + "#[derive(sqlx::Type, Debug, Clone, Eq, PartialEq{})]\n", + serde_derives + )); + enum_code.push_str(&format!(r#"#[sqlx(type_name = "{}")]"#, enum_name)); + enum_code.push_str("\n"); + enum_code.push_str(&format!("pub enum {} {{\n", rs_enum_name)); + + for row in enum_rows.iter().filter(|e| &e.enum_name == enum_name) { + enum_code.push_str(&format!(r#" #[sqlx(rename = "{}")]"#, row.enum_value)); + enum_code.push_str("\n"); + enum_code.push_str(&format!(" {},\n", to_pascal_case(&row.enum_value))) + } + + enum_code.push_str("}\n"); + + enum_code +} + pub fn generate_struct_code( table_name: &str, rows: &Vec, @@ -72,6 +104,8 @@ pub fn convert_data_type(data_type: &str) -> String { return vec_type; } + let if_else = format!("crate::{}", to_pascal_case(data_type)); + match data_type { "bool" | "boolean" => "bool", "bytea" => "Vec", // is this right? @@ -84,13 +118,13 @@ pub fn convert_data_type(data_type: &str) -> String { "int8" | "bigint" | "bigserial" => "i64", "void" => "()", "jsonb" | "json" => "serde_json::Value", - "text" | "_text" | "varchar" | "name" | "citext" => "String", + "text" | "_text" | "varchar" | "name" | "citext" | "geometry" => "String", "time" => "time::Time", "timestamp" => "time::OffsetDateTime", "timestamptz" => "time::OffsetDateTime", "interval" => "sqlx::postgres::types::PgInterval", "uuid" => "uuid::Uuid", - _ => panic!("Unknown type: {}", data_type), + _ => if_else.as_str(), } .to_string() } From 353bdc10942d79b238e855bf028309f59eca199d Mon Sep 17 00:00:00 2001 From: Burak T Date: Tue, 13 Aug 2024 13:53:02 -0400 Subject: [PATCH 4/9] data types & date time lib --- Cargo.toml | 1 + src/generate.rs | 10 ++++ src/main.rs | 22 ++++++++- src/query_generate.rs | 112 +++++++++++++----------------------------- src/utils.rs | 99 ++++++++++++++++++++++++++++++------- 5 files changed, 147 insertions(+), 97 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index fbf0cad..252376a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,3 +16,4 @@ tokio = { version = "1", features = ["full"] } dotenv = "0.15.0" testcontainers = { version ="0.15.0" } testcontainers-modules = { version = "0.3.5", features = ["postgres"] } +once_cell = "1.19.0" diff --git a/src/generate.rs b/src/generate.rs index 54e3a3a..a35323e 100644 --- a/src/generate.rs +++ b/src/generate.rs @@ -8,6 +8,8 @@ use crate::models::TableColumn; use crate::utils::{generate_enum_code, generate_struct_code, to_pascal_case, to_snake_case}; use crate::query_generate::generate_query_code; +use crate::utils::{DateTimeLib, SqlGenState}; +use crate::STATE; pub async fn generate( enable_serde: bool, @@ -18,6 +20,7 @@ pub async fn generate( include_tables: Option>, exclude_tables: Vec, schemas: Option>, + date_time_lib: DateTimeLib, ) -> Result<(), Box> { // Connect to the Postgres database let pool = PgPoolOptions::new() @@ -66,6 +69,13 @@ pub async fn generate( } println!("Outputting tables: {:?}", tables); + STATE + .set(SqlGenState { + user_defined: enums.clone(), + date_time_lib, + }) + .expect("Unable to set state"); + let mut rs_enums = Vec::new(); for user_enum in enums { diff --git a/src/main.rs b/src/main.rs index 64200c2..4f7e5ca 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,6 @@ use clap::{App, Arg, SubCommand}; +use once_cell::sync::OnceCell; +use utils::{DateTimeLib, SqlGenState}; mod db_queries; mod generate; @@ -7,6 +9,8 @@ mod models; mod query_generate; mod utils; +pub static STATE: OnceCell = OnceCell::new(); + #[tokio::main] async fn main() -> Result<(), Box> { dotenv::dotenv().ok(); @@ -94,7 +98,16 @@ async fn main() -> Result<(), Box> { .takes_value(false) .required(false) .help("Overwrites existing files sharing names in that folder"), - ); + ) + .arg( + Arg::with_name("datetime-lib") + .long("datetime-lib") + .default_value("chrono") + .possible_values(&["chrono", "time"]) + .value_name("SQLGEN_DATETIME_LIB") + .help("Specifies the library to use for date and time handling") + .takes_value(true), + ); let migrate_subcommand = SubCommand::with_name("migrate") .about("Generate SQL migrations based on struct differences") @@ -229,6 +242,12 @@ async fn main() -> Result<(), Box> { let enable_serde = matches.is_present("serde"); + let date_time_lib = matches + .value_of("datetime-lib") + .map(|e| e.to_string()) + .unwrap(); + let date_time_lib = DateTimeLib::from(date_time_lib); + generate::generate( enable_serde, output_folder, @@ -238,6 +257,7 @@ async fn main() -> Result<(), Box> { include_tables, exclude_tables, schemas, + date_time_lib, ) .await?; } else if let Some(matches) = matches.subcommand_matches("migrate") { diff --git a/src/query_generate.rs b/src/query_generate.rs index 8f5b994..8ed574d 100644 --- a/src/query_generate.rs +++ b/src/query_generate.rs @@ -1,6 +1,6 @@ use crate::{ models::TableColumn, - utils::{convert_data_type, to_pascal_case, to_snake_case}, + utils::{convert_data_type, rust_type_fix, to_pascal_case, to_snake_case}, }; pub fn generate_query_code(table_name: &str, rows: &[TableColumn]) -> String { @@ -121,14 +121,9 @@ fn generate_select_by_pk_query_code( .iter() .filter(|r| r.is_primary_key && r.table_name == table_name) .map(|r| { - let column_name = if r.column_name.as_str() == "type" { - "r#type" - } else { - &r.column_name - }; format!( "{}: {}", - column_name, + rust_type_fix(r.column_name.as_str()), convert_data_type(r.udt_name.as_str()) ) }) @@ -152,12 +147,10 @@ fn generate_select_by_pk_query_code( .iter() .filter(|r| r.is_primary_key && r.table_name == table_name) .map(|r| { - let column_name = if r.column_name.as_str() == "type" { - "r#type" - } else { - &r.column_name - }; - format!(" .bind({})\n", column_name) + format!( + " .bind({})\n", + rust_type_fix(r.column_name.as_str()) + ) }) .collect::>() .join(""); @@ -197,14 +190,9 @@ fn generate_select_many_by_pks_query_code( .iter() .filter(|r| r.is_primary_key && r.table_name == table_name) .map(|r| { - let column_name = if r.column_name.as_str() == "type" { - "r#type" - } else { - &r.column_name - }; format!( "{}_list: Vec<{}>", - column_name, + rust_type_fix(r.column_name.as_str()), convert_data_type(r.udt_name.as_str()) ) }) @@ -228,12 +216,10 @@ fn generate_select_many_by_pks_query_code( .iter() .filter(|r| r.is_primary_key && r.table_name == table_name) .map(|r| { - let column_name = if r.column_name.as_str() == "type" { - "r#type" - } else { - &r.column_name - }; - format!(" .bind({}_list)\n", column_name) + format!( + " .bind({}_list)\n", + rust_type_fix(r.column_name.as_str()) + ) }) .collect::>() .join(""); @@ -268,14 +254,9 @@ fn generate_select_by_pk_query_code_optional( .iter() .filter(|r| r.is_primary_key && r.table_name == table_name) .map(|r| { - let column_name = if r.column_name.as_str() == "type" { - "r#type" - } else { - &r.column_name - }; format!( "{}: {}", - column_name, + rust_type_fix(r.column_name.as_str()), convert_data_type(r.udt_name.as_str()) ) }) @@ -301,12 +282,10 @@ fn generate_select_by_pk_query_code_optional( .iter() .filter(|r| r.is_primary_key && r.table_name == table_name) .map(|r| { - let column_name = if r.column_name.as_str() == "type" { - "r#type" - } else { - &r.column_name - }; - format!(" .bind({})\n", column_name) + format!( + " .bind({})\n", + rust_type_fix(r.column_name.as_str()) + ) }) .collect::>() .join(""); @@ -373,14 +352,9 @@ fn generate_select_by_unique_query_code( .iter() .filter(|r| r.is_unique && r.table_name == table_name && r.column_name == unique_name) .map(|r| { - let column_name = if r.column_name.as_str() == "type" { - "r#type" - } else { - &r.column_name - }; format!( "{}: {}", - column_name, + rust_type_fix(r.column_name.as_str()), convert_data_type(r.udt_name.as_str()) ) }) @@ -404,12 +378,10 @@ fn generate_select_by_unique_query_code( .iter() .filter(|r| r.is_unique && r.table_name == table_name && r.column_name == unique_name) .map(|r| { - let column_name = if r.column_name.as_str() == "type" { - "r#type" - } else { - &r.column_name - }; - format!(" .bind({})\n", column_name) + format!( + " .bind({})\n", + rust_type_fix(r.column_name.as_str()) + ) }) .collect::>() .join(""); @@ -444,14 +416,9 @@ fn generate_select_many_by_uniques_query_code( .iter() .filter(|r| r.is_unique && r.table_name == table_name && r.column_name == unique_name) .map(|r| { - let column_name = if r.column_name.as_str() == "type" { - "r#type" - } else { - &r.column_name - }; format!( "{}_list: Vec<{}>", - column_name, + rust_type_fix(r.column_name.as_str()), convert_data_type(r.udt_name.as_str()) ) }) @@ -475,12 +442,10 @@ fn generate_select_many_by_uniques_query_code( .iter() .filter(|r| r.is_unique && r.table_name == table_name && r.column_name == unique_name) .map(|r| { - let column_name = if r.column_name.as_str() == "type" { - "r#type" - } else { - &r.column_name - }; - format!(" .bind({}_list)\n", column_name) + format!( + " .bind({}_list)\n", + rust_type_fix(r.column_name.as_str()) + ) }) .collect::>() .join(""); @@ -510,14 +475,9 @@ fn generate_select_by_unique_query_code_optional( .iter() .filter(|r| r.is_unique && r.table_name == table_name && r.column_name == unique_name) .map(|r| { - let column_name = if r.column_name.as_str() == "type" { - "r#type" - } else { - &r.column_name - }; format!( "{}: {}", - column_name, + rust_type_fix(r.column_name.as_str()), convert_data_type(r.udt_name.as_str()) ) }) @@ -543,12 +503,10 @@ fn generate_select_by_unique_query_code_optional( .iter() .filter(|r| r.is_unique && r.table_name == table_name && r.column_name == unique_name) .map(|r| { - let column_name = if r.column_name.as_str() == "type" { - "r#type" - } else { - &r.column_name - }; - format!(" .bind({})\n", column_name) + format!( + " .bind({})\n", + rust_type_fix(r.column_name.as_str()) + ) }) .collect::>() .join(""); @@ -732,15 +690,11 @@ fn generate_value_list(table_name: &str, rows: &[TableColumn]) -> String { rows.iter() .filter(|row| row.table_name == table_name) .map(|row| { - let column_name = if row.column_name.as_str() == "type" { - "r#type" - } else { - &row.column_name - }; + let column_name = rust_type_fix(row.column_name.as_str()); format!( ".bind({}.{})", to_snake_case(&row.table_name), - to_snake_case(column_name) + to_snake_case(&column_name) ) }) .collect::>() diff --git a/src/utils.rs b/src/utils.rs index 0d28a51..4b54719 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,4 +1,56 @@ -use crate::models::{TableColumn, UserDefinedEnums}; +use crate::{ + models::{TableColumn, UserDefinedEnums}, + STATE, +}; + +#[derive(Debug)] +pub(crate) enum DateTimeLib { + Time, + Chrono, +} + +impl DateTimeLib { + pub(crate) fn date_type(&self) -> &str { + match self { + DateTimeLib::Time => "time::Date", + DateTimeLib::Chrono => "chrono::NaiveDate", + } + } + pub(crate) fn time_type(&self) -> &str { + match self { + DateTimeLib::Time => "time::Time", + DateTimeLib::Chrono => "chrono::NaiveTime", + } + } + pub(crate) fn timestamp_type(&self) -> &str { + match self { + DateTimeLib::Time => "time::OffsetDateTime", + DateTimeLib::Chrono => "chrono::NaiveDateTime", + } + } + pub(crate) fn timestampz_type(&self) -> &str { + match self { + DateTimeLib::Time => "time::OffsetDateTime", + DateTimeLib::Chrono => "chrono::DateTime", + } + } +} + +impl From for DateTimeLib { + fn from(value: String) -> Self { + if value == "chrono" { + DateTimeLib::Chrono + } else { + DateTimeLib::Time + } + } +} + +#[derive(Debug)] +pub(crate) struct SqlGenState { + pub user_defined: Vec, + pub date_time_lib: DateTimeLib, +} pub(crate) fn to_snake_case(input: &str) -> String { let mut output = String::new(); @@ -76,17 +128,16 @@ pub fn generate_struct_code( for row in rows { if row.table_name == table_name { - let mut column_name = to_snake_case(&row.column_name); + let column_name = to_snake_case(&row.column_name); let mut data_type = convert_data_type(&row.udt_name); - let optional_type = format!("Option<{}>", data_type); if row.is_nullable { - data_type = optional_type; + data_type = format!("Option<{}>", data_type); } - if column_name.as_str() == "type" { - column_name = String::from("r#type"); - } - - struct_code.push_str(&format!(" pub {}: {},\n", column_name, data_type)); + struct_code.push_str(&format!( + " pub {}: {},\n", + rust_type_fix(column_name.as_str()), + data_type + )); } } struct_code.push_str("}\n"); @@ -103,14 +154,13 @@ pub fn convert_data_type(data_type: &str) -> String { let vec_type = format!("Vec<{}>", array_of_type); return vec_type; } - - let if_else = format!("crate::{}", to_pascal_case(data_type)); + let state = STATE.get().unwrap(); match data_type { "bool" | "boolean" => "bool", "bytea" => "Vec", // is this right? "char" | "bpchar" | "character" => "String", - "date" => "time::Date", + "date" => state.date_time_lib.date_type(), "float4" | "real" => "f32", "float8" | "double precision" | "numeric" => "f64", "int2" | "smallint" | "smallserial" => "i16", @@ -118,13 +168,20 @@ pub fn convert_data_type(data_type: &str) -> String { "int8" | "bigint" | "bigserial" => "i64", "void" => "()", "jsonb" | "json" => "serde_json::Value", - "text" | "_text" | "varchar" | "name" | "citext" | "geometry" => "String", - "time" => "time::Time", - "timestamp" => "time::OffsetDateTime", - "timestamptz" => "time::OffsetDateTime", + "text" | "_text" | "varchar" | "name" | "citext" => "String", + "geometry" => "String", // when sqlx supports geo types we could change this + "time" => state.date_time_lib.time_type(), + "timestamp" => state.date_time_lib.timestamp_type(), + "timestamptz" => state.date_time_lib.timestampz_type(), "interval" => "sqlx::postgres::types::PgInterval", "uuid" => "uuid::Uuid", - _ => if_else.as_str(), + _ => { + if state.user_defined.contains(&data_type.to_string()) { + return format!("crate::{}", to_pascal_case(data_type)); + } else { + panic!("Unknown type: {}", data_type) + } + } } .to_string() } @@ -222,3 +279,11 @@ pub fn to_pascal_case(input: &str) -> String { output } + +pub(crate) fn rust_type_fix(column_name: &str) -> String { + if column_name == "type" { + String::from("r#type") + } else { + column_name.to_string() + } +} From 0451b8e26a2eab33c429c2978b76cf1d69ae1251 Mon Sep 17 00:00:00 2001 From: Burak T Date: Thu, 15 Aug 2024 14:25:59 -0400 Subject: [PATCH 5/9] (fix) generated unique queries may have dup functions --- src/query_generate.rs | 53 +++++++++++++++++++++++++++++++------------ src/utils.rs | 2 +- 2 files changed, 40 insertions(+), 15 deletions(-) diff --git a/src/query_generate.rs b/src/query_generate.rs index 8ed574d..ee71886 100644 --- a/src/query_generate.rs +++ b/src/query_generate.rs @@ -55,8 +55,8 @@ pub fn generate_query_code(table_name: &str, rows: &[TableColumn]) -> String { )); query_code.push('\n'); - // query_code.push_str(&generate_unique_query_code(table_name, schema_name, &rows)); - // query_code.push('\n'); + query_code.push_str(&generate_unique_query_code(table_name, schema_name, &rows)); + query_code.push('\n'); query_code.push_str(&generate_select_all_fk_queries( table_name, @@ -304,7 +304,15 @@ fn generate_select_by_pk_query_code_optional( fn generate_unique_query_code(table_name: &str, schema_name: &str, rows: &[TableColumn]) -> String { let mut code = String::new(); - for row in rows.iter().filter(|r| r.is_unique) { + let mut unique_rows = std::collections::HashMap::<(&str, &str), &TableColumn>::new(); + for row in rows + .iter() + .filter(|r| r.table_name.as_str() == table_name && r.is_unique) + { + unique_rows.insert((&row.table_name, &row.column_name), row); + } + let unique_users_vec: Vec<&TableColumn> = unique_rows.into_values().collect(); + for row in unique_users_vec.iter() { code.push_str( generate_select_by_unique_query_code(&row.column_name, table_name, schema_name, rows) .as_str(), @@ -358,8 +366,12 @@ fn generate_select_by_unique_query_code( convert_data_type(r.udt_name.as_str()) ) }) - .collect::>() - .join(", "); + .collect::>(); + + if fn_args.is_empty() { + return String::from(""); + } + let fn_args = fn_args.join(", "); select_code.push_str(&format!( " pub async fn unique_by_{}<'e, E: PgExecutor<'e>>(&self, executor: E, {}) -> Result<{}> {{\n", @@ -416,14 +428,23 @@ fn generate_select_many_by_uniques_query_code( .iter() .filter(|r| r.is_unique && r.table_name == table_name && r.column_name == unique_name) .map(|r| { - format!( - "{}_list: Vec<{}>", - rust_type_fix(r.column_name.as_str()), - convert_data_type(r.udt_name.as_str()) - ) + let rs_type = convert_data_type(r.udt_name.as_str()); + if rs_type.starts_with("Vec<") { + None + } else { + Some(format!( + "{}_list: Vec<{}>", + rust_type_fix(r.column_name.as_str()), + rs_type + )) + } }) - .collect::>() - .join(", "); + .collect::>>(); + + if fn_args.is_none() { + return String::from(""); + } + let fn_args = fn_args.unwrap().join(", "); select_code.push_str(&format!( " pub async fn unique_many_by_{}_list<'e, E: PgExecutor<'e>>(&self, executor: E, {}) -> Result> {{\n", @@ -481,8 +502,12 @@ fn generate_select_by_unique_query_code_optional( convert_data_type(r.udt_name.as_str()) ) }) - .collect::>() - .join(", "); + .collect::>(); + + if fn_args.is_empty() { + return String::from(""); + } + let fn_args = fn_args.join(", "); select_code.push_str(&format!( " pub async fn unique_by_{}_optional<'e, E: PgExecutor<'e>>(&self, executor: E, {}) -> Result> {{\n", diff --git a/src/utils.rs b/src/utils.rs index 4b54719..4adb703 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -168,7 +168,7 @@ pub fn convert_data_type(data_type: &str) -> String { "int8" | "bigint" | "bigserial" => "i64", "void" => "()", "jsonb" | "json" => "serde_json::Value", - "text" | "_text" | "varchar" | "name" | "citext" => "String", + "text" | "varchar" | "name" | "citext" => "String", "geometry" => "String", // when sqlx supports geo types we could change this "time" => state.date_time_lib.time_type(), "timestamp" => state.date_time_lib.timestamp_type(), From 56cf8c3132e14bf6cfb9af62e94162624e50e43e Mon Sep 17 00:00:00 2001 From: Burak T Date: Thu, 15 Aug 2024 14:41:23 -0400 Subject: [PATCH 6/9] (fix) select by fk query code --- src/query_generate.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/query_generate.rs b/src/query_generate.rs index ee71886..9df8d8a 100644 --- a/src/query_generate.rs +++ b/src/query_generate.rs @@ -598,7 +598,7 @@ fn generate_select_by_fk_query_code( select_code.push_str(&format!( " pub async fn all_by_{}_{}<'e, E: PgExecutor<'e>>(executor: E, {}_{}: {}) -> Result> {{\n", to_snake_case(foreign_row_table_name), - to_snake_case(column_name), + foreign_row_column_name, to_snake_case(foreign_row_table_name), foreign_row_column_name, data_type, From e56885401e29bed5118b855de68af8f8477f76eb Mon Sep 17 00:00:00 2001 From: Burak T Date: Thu, 15 Aug 2024 14:53:39 -0400 Subject: [PATCH 7/9] (fix) drop once_cell use std OnceLock --- Cargo.toml | 5 ++--- src/main.rs | 5 +++-- src/utils.rs | 8 +++++++- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 252376a..d8b15f1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,7 @@ description = "A CLI tool for generating models based on a SQL Database using SQ license = "MIT" [dependencies] -sqlx = { version = "0.7", features = ["postgres","runtime-tokio"] } +sqlx = { version = "0.7", features = ["postgres", "runtime-tokio"] } sqlx-cli = "0.7" clap = "3.0" regex = "1.5" @@ -14,6 +14,5 @@ chrono = "0.4" # regex = "1.5" tokio = { version = "1", features = ["full"] } dotenv = "0.15.0" -testcontainers = { version ="0.15.0" } +testcontainers = { version = "0.15.0" } testcontainers-modules = { version = "0.3.5", features = ["postgres"] } -once_cell = "1.19.0" diff --git a/src/main.rs b/src/main.rs index 4f7e5ca..74c0a7b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,6 @@ +use std::sync::OnceLock; + use clap::{App, Arg, SubCommand}; -use once_cell::sync::OnceCell; use utils::{DateTimeLib, SqlGenState}; mod db_queries; @@ -9,7 +10,7 @@ mod models; mod query_generate; mod utils; -pub static STATE: OnceCell = OnceCell::new(); +pub(crate) static STATE: OnceLock = OnceLock::new(); #[tokio::main] async fn main() -> Result<(), Box> { diff --git a/src/utils.rs b/src/utils.rs index 4adb703..5b40efd 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -9,6 +9,12 @@ pub(crate) enum DateTimeLib { Chrono, } +impl Default for DateTimeLib { + fn default() -> Self { + DateTimeLib::Chrono + } +} + impl DateTimeLib { pub(crate) fn date_type(&self) -> &str { match self { @@ -46,7 +52,7 @@ impl From for DateTimeLib { } } -#[derive(Debug)] +#[derive(Debug, Default)] pub(crate) struct SqlGenState { pub user_defined: Vec, pub date_time_lib: DateTimeLib, From 34936c3451be7f636af1e1ee9ae7747977b96166 Mon Sep 17 00:00:00 2001 From: Burak T Date: Thu, 15 Aug 2024 15:18:37 -0400 Subject: [PATCH 8/9] (fix) added args for extra struct and enum derives --- src/generate.rs | 9 ++++--- src/main.rs | 70 ++++++++++++++++++++++++++++++++++++++++++------- src/utils.rs | 35 ++++++++++++------------- 3 files changed, 83 insertions(+), 31 deletions(-) diff --git a/src/generate.rs b/src/generate.rs index a35323e..ca59305 100644 --- a/src/generate.rs +++ b/src/generate.rs @@ -12,7 +12,6 @@ use crate::utils::{DateTimeLib, SqlGenState}; use crate::STATE; pub async fn generate( - enable_serde: bool, output_folder: &str, database_url: &str, context: Option<&str>, @@ -21,6 +20,8 @@ pub async fn generate( exclude_tables: Vec, schemas: Option>, date_time_lib: DateTimeLib, + struct_derives: Vec, + enum_derives: Vec, ) -> Result<(), Box> { // Connect to the Postgres database let pool = PgPoolOptions::new() @@ -73,13 +74,15 @@ pub async fn generate( .set(SqlGenState { user_defined: enums.clone(), date_time_lib, + struct_derives, + enum_derives, }) .expect("Unable to set state"); let mut rs_enums = Vec::new(); for user_enum in enums { - let enum_code = generate_enum_code(&user_enum, &enum_rows, enable_serde); + let enum_code = generate_enum_code(&user_enum, &enum_rows); rs_enums.push(enum_code); } @@ -91,7 +94,7 @@ pub async fn generate( } } // Generate the struct code based on the row - let struct_code = generate_struct_code(&table, &rows, enable_serde); + let struct_code = generate_struct_code(&table, &rows); // Generate the query code based on the row let query_code = generate_query_code(&table, &rows); diff --git a/src/main.rs b/src/main.rs index 74c0a7b..615c22b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,4 @@ -use std::sync::OnceLock; +use std::{collections::BTreeSet, sync::OnceLock}; use clap::{App, Arg, SubCommand}; use utils::{DateTimeLib, SqlGenState}; @@ -101,14 +101,30 @@ async fn main() -> Result<(), Box> { .help("Overwrites existing files sharing names in that folder"), ) .arg( - Arg::with_name("datetime-lib") - .long("datetime-lib") - .default_value("chrono") - .possible_values(&["chrono", "time"]) - .value_name("SQLGEN_DATETIME_LIB") - .help("Specifies the library to use for date and time handling") - .takes_value(true), - ); + Arg::with_name("datetime-lib") + .long("datetime-lib") + .default_value("chrono") + .possible_values(&["chrono", "time"]) + .value_name("SQLGEN_DATETIME_LIB") + .help("Specifies the library to use for date and time handling") + .takes_value(true), + ) + .arg( + Arg::with_name("struct-derive") + .long("struct-derive") + .value_name("SQLGEN_STRUCT_DERIVE") + .help("Derive created structs with given values") + .multiple(true) + .takes_value(true), + ) + .arg( + Arg::with_name("enum-derive") + .long("enum-derive") + .value_name("SQLGEN_ENUM_DERIVE") + .help("Derive created enums with given values") + .multiple(true) + .takes_value(true), + ); let migrate_subcommand = SubCommand::with_name("migrate") .about("Generate SQL migrations based on struct differences") @@ -242,7 +258,40 @@ async fn main() -> Result<(), Box> { } let enable_serde = matches.is_present("serde"); + let mut struct_derives = matches + .values_of("struct-derive") + .map(|v| { + v.into_iter() + .map(|e| e.to_string()) + .collect::>() + }) + .unwrap_or_default(); + let mut enum_derives = matches + .values_of("enum-derive") + .map(|v| { + v.into_iter() + .map(|e| e.to_string()) + .collect::>() + }) + .unwrap_or_default(); + + if enable_serde { + let mut unique_struct_derivies = struct_derives + .clone() + .into_iter() + .collect::>(); + let mut unique_enum_derivies = enum_derives + .clone() + .into_iter() + .collect::>(); + for serde_derive in ["serde::Serialize", "serde::Deserialize"] { + unique_struct_derivies.insert(serde_derive.to_string()); + unique_enum_derivies.insert(serde_derive.to_string()); + } + struct_derives = unique_struct_derivies.into_iter().collect(); + enum_derives = unique_enum_derivies.into_iter().collect(); + } let date_time_lib = matches .value_of("datetime-lib") .map(|e| e.to_string()) @@ -250,7 +299,6 @@ async fn main() -> Result<(), Box> { let date_time_lib = DateTimeLib::from(date_time_lib); generate::generate( - enable_serde, output_folder, database_url, context, @@ -259,6 +307,8 @@ async fn main() -> Result<(), Box> { exclude_tables, schemas, date_time_lib, + struct_derives, + enum_derives, ) .await?; } else if let Some(matches) = matches.subcommand_matches("migrate") { diff --git a/src/utils.rs b/src/utils.rs index 5b40efd..0f0652e 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -55,6 +55,8 @@ impl From for DateTimeLib { #[derive(Debug, Default)] pub(crate) struct SqlGenState { pub user_defined: Vec, + pub struct_derives: Vec, + pub enum_derives: Vec, pub date_time_lib: DateTimeLib, } @@ -78,22 +80,20 @@ pub(crate) fn to_snake_case(input: &str) -> String { output } -pub fn generate_enum_code( - enum_name: &str, - enum_rows: &Vec, - enable_serde: bool, -) -> String { +pub fn generate_enum_code(enum_name: &str, enum_rows: &Vec) -> String { let rs_enum_name = to_pascal_case(enum_name); let mut enum_code = String::new(); - let serde_derives = if enable_serde { - ", serde::Serialize, serde::Deserialize" + let enum_derives = STATE.get().unwrap().enum_derives.join(", "); + + let extra_derives = if !enum_derives.is_empty() { + format!(", {}", enum_derives) } else { - "" + "".to_string() }; enum_code.push_str(&format!( "#[derive(sqlx::Type, Debug, Clone, Eq, PartialEq{})]\n", - serde_derives + extra_derives )); enum_code.push_str(&format!(r#"#[sqlx(type_name = "{}")]"#, enum_name)); enum_code.push_str("\n"); @@ -110,25 +110,23 @@ pub fn generate_enum_code( enum_code } -pub fn generate_struct_code( - table_name: &str, - rows: &Vec, - enable_serde: bool, -) -> String { +pub fn generate_struct_code(table_name: &str, rows: &Vec) -> String { let struct_name = to_pascal_case(table_name); let mut struct_code = String::new(); - let serde_derives = if enable_serde { - ", serde::Serialize, serde::Deserialize" + let struct_derives = STATE.get().unwrap().struct_derives.join(", "); + + let extra_derives = if !struct_derives.is_empty() { + format!(", {}", struct_derives) } else { - "" + "".to_string() }; struct_code.push_str("#![allow(dead_code)]\n"); struct_code.push_str("// Generated with sql-gen\n// https://github.com/jayy-lmao/sql-gen\n\n"); struct_code.push_str(&format!( "#[derive(sqlx::FromRow, Debug, Clone{})]\n", - serde_derives + extra_derives )); struct_code.push_str(&format!("pub struct {} {{\n", struct_name)); @@ -160,6 +158,7 @@ pub fn convert_data_type(data_type: &str) -> String { let vec_type = format!("Vec<{}>", array_of_type); return vec_type; } + let state = STATE.get().unwrap(); match data_type { From 6393fe23ffbd14828b948289cfb8020611e193f6 Mon Sep 17 00:00:00 2001 From: Burak T Date: Thu, 15 Aug 2024 16:44:52 -0400 Subject: [PATCH 9/9] (fix) generate_select_by_fk_query_code --- src/query_generate.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/query_generate.rs b/src/query_generate.rs index 9df8d8a..ee71886 100644 --- a/src/query_generate.rs +++ b/src/query_generate.rs @@ -598,7 +598,7 @@ fn generate_select_by_fk_query_code( select_code.push_str(&format!( " pub async fn all_by_{}_{}<'e, E: PgExecutor<'e>>(executor: E, {}_{}: {}) -> Result> {{\n", to_snake_case(foreign_row_table_name), - foreign_row_column_name, + to_snake_case(column_name), to_snake_case(foreign_row_table_name), foreign_row_column_name, data_type,