diff --git a/Cargo.toml b/Cargo.toml index fbf0cad..d8b15f1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,7 @@ description = "A CLI tool for generating models based on a SQL Database using SQ license = "MIT" [dependencies] -sqlx = { version = "0.7", features = ["postgres","runtime-tokio"] } +sqlx = { version = "0.7", features = ["postgres", "runtime-tokio"] } sqlx-cli = "0.7" clap = "3.0" regex = "1.5" @@ -14,5 +14,5 @@ chrono = "0.4" # regex = "1.5" tokio = { version = "1", features = ["full"] } dotenv = "0.15.0" -testcontainers = { version ="0.15.0" } +testcontainers = { version = "0.15.0" } testcontainers-modules = { version = "0.3.5", features = ["postgres"] } diff --git a/src/db_queries.rs b/src/db_queries.rs index 389881e..da00159 100644 --- a/src/db_queries.rs +++ b/src/db_queries.rs @@ -1,6 +1,6 @@ use sqlx::PgPool; -use crate::models::TableColumn; +use crate::models::{TableColumn, UserDefinedEnums}; pub async fn get_table_columns( pool: &PgPool, @@ -98,3 +98,28 @@ ORDER BY .await?; Ok(rows) } + +pub async fn get_user_defined_enums( + udt_names: &Vec, + pool: &PgPool, +) -> sqlx::Result> { + let query = " + SELECT + t.typname AS enum_name, + e.enumlabel AS enum_value + FROM + pg_type t + JOIN pg_enum e ON t.oid = e.enumtypid + WHERE + t.typname = ANY($1) + ORDER BY + t.typname, + e.enumsortorder; + "; + + let rows = sqlx::query_as::<_, UserDefinedEnums>(query) + .bind(udt_names) + .fetch_all(pool) + .await?; + Ok(rows) +} diff --git a/src/generate.rs b/src/generate.rs index c03d390..ca59305 100644 --- a/src/generate.rs +++ b/src/generate.rs @@ -3,11 +3,13 @@ use sqlx::PgPool; use std::fs; use std::path::Path; -use crate::db_queries::get_table_columns; +use crate::db_queries::{get_table_columns, get_user_defined_enums}; use crate::models::TableColumn; -use crate::utils::{generate_struct_code, to_pascal_case, to_snake_case}; +use crate::utils::{generate_enum_code, generate_struct_code, to_pascal_case, to_snake_case}; use crate::query_generate::generate_query_code; +use crate::utils::{DateTimeLib, SqlGenState}; +use crate::STATE; pub async fn generate( output_folder: &str, @@ -15,7 +17,11 @@ pub async fn generate( context: Option<&str>, force: bool, include_tables: Option>, + exclude_tables: Vec, schemas: Option>, + date_time_lib: DateTimeLib, + struct_derives: Vec, + enum_derives: Vec, ) -> Result<(), Box> { // Connect to the Postgres database let pool = PgPoolOptions::new() @@ -28,7 +34,23 @@ pub async fn generate( let default_schema = "public"; let rows = get_table_columns(&pool, schemas.unwrap_or(vec![default_schema]), None).await?; + let user_defined = rows + .iter() + .filter_map(|e| { + if e.data_type.as_str() == "USER-DEFINED" && e.udt_name.as_str() != "geometry" { + Some(e.udt_name.clone()) + } else { + None + } + }) + .collect::>(); + let enum_rows = get_user_defined_enums(&user_defined, &pool).await?; + let mut unique_enums = std::collections::BTreeSet::new(); + for row in &enum_rows { + unique_enums.insert(row.enum_name.clone()); + } + let enums = unique_enums.into_iter().collect::>(); // Create the output folder if it doesn't exist fs::create_dir_all(output_folder)?; @@ -36,10 +58,34 @@ pub async fn generate( for row in &rows { unique.insert(row.table_name.clone()); } - let tables: Vec = unique.into_iter().collect::>(); - + let tables: Vec = unique + .into_iter() + .collect::>() + .into_iter() + .filter(|e| !exclude_tables.contains(e)) + .collect(); + + if !enums.is_empty() { + println!("Outputting user defined enums: {:?}", enums); + } println!("Outputting tables: {:?}", tables); + STATE + .set(SqlGenState { + user_defined: enums.clone(), + date_time_lib, + struct_derives, + enum_derives, + }) + .expect("Unable to set state"); + + let mut rs_enums = Vec::new(); + + for user_enum in enums { + let enum_code = generate_enum_code(&user_enum, &enum_rows); + rs_enums.push(enum_code); + } + // Generate structs and queries for each table for table in &tables { if let Some(ts) = include_tables.clone() { @@ -75,18 +121,28 @@ pub async fn generate( } } - let context_code = generate_db_context(context.unwrap_or(&database_name), &tables, &rows); + let context_code = + generate_db_context(context.unwrap_or(&database_name), &rs_enums, &tables, &rows); let context_file_path = format!("{}/mod.rs", output_folder); fs::write(context_file_path, context_code)?; Ok(()) } -fn generate_db_context(database_name: &str, tables: &[String], _rows: &[TableColumn]) -> String { +fn generate_db_context( + database_name: &str, + enums: &[String], + tables: &[String], + _rows: &[TableColumn], +) -> String { let mut db_context_code = String::new(); db_context_code.push_str("#![allow(dead_code)]\n"); db_context_code .push_str("// Generated with sql-gen\n//https://github.com/jayy-lmao/sql-gen\n\n"); + for enum_item in enums { + db_context_code.push_str(enum_item); + db_context_code.push_str("\n\n"); + } for table in tables { db_context_code.push_str(&format!("pub mod {};\n", to_snake_case(table))); db_context_code.push_str(&format!( diff --git a/src/main.rs b/src/main.rs index 8cc8e3f..615c22b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,7 @@ +use std::{collections::BTreeSet, sync::OnceLock}; + use clap::{App, Arg, SubCommand}; +use utils::{DateTimeLib, SqlGenState}; mod db_queries; mod generate; @@ -7,6 +10,8 @@ mod models; mod query_generate; mod utils; +pub(crate) static STATE: OnceLock = OnceLock::new(); + #[tokio::main] async fn main() -> Result<(), Box> { dotenv::dotenv().ok(); @@ -22,6 +27,14 @@ async fn main() -> Result<(), Box> { .help("Sets the output folder for generated structs") .takes_value(true), ) + .arg( + Arg::with_name("serde") + .long("serde") + .default_value("true") + .value_name("SQLGEN_ENABLE_SERDE") + .help("Adds Serde derices to created structs") + .takes_value(false), + ) .arg( Arg::with_name("migrations") .short('m') @@ -68,6 +81,16 @@ async fn main() -> Result<(), Box> { .use_delimiter(true) .help("Specify the table name(s)"), ) + .arg( + Arg::with_name("exclude") + .short('e') + .long("exclude") + .takes_value(true) + .value_name("SQLGEN_EXCLUDE") + .multiple(true) + .use_delimiter(true) + .help("Specify the excluded table name(s)"), + ) .arg( Arg::new("force") .short('f') @@ -76,6 +99,31 @@ async fn main() -> Result<(), Box> { .takes_value(false) .required(false) .help("Overwrites existing files sharing names in that folder"), + ) + .arg( + Arg::with_name("datetime-lib") + .long("datetime-lib") + .default_value("chrono") + .possible_values(&["chrono", "time"]) + .value_name("SQLGEN_DATETIME_LIB") + .help("Specifies the library to use for date and time handling") + .takes_value(true), + ) + .arg( + Arg::with_name("struct-derive") + .long("struct-derive") + .value_name("SQLGEN_STRUCT_DERIVE") + .help("Derive created structs with given values") + .multiple(true) + .takes_value(true), + ) + .arg( + Arg::with_name("enum-derive") + .long("enum-derive") + .value_name("SQLGEN_ENUM_DERIVE") + .help("Derive created enums with given values") + .multiple(true) + .takes_value(true), ); let migrate_subcommand = SubCommand::with_name("migrate") @@ -195,7 +243,74 @@ async fn main() -> Result<(), Box> { let schemas: Option> = matches.values_of("schema").map(|schemas| schemas.collect()); let force = matches.is_present("force"); - generate::generate(output_folder, database_url, context, force, None, schemas).await?; + let include_tables = matches.values_of("table").map(|v| v.collect::>()); + let exclude_tables = matches + .values_of("exclude") + .map(|v| { + v.into_iter() + .map(|e| e.to_string()) + .collect::>() + }) + .unwrap_or(vec![]); + + if !exclude_tables.is_empty() { + println!("Excluding tables: {:?}", exclude_tables); + } + + let enable_serde = matches.is_present("serde"); + let mut struct_derives = matches + .values_of("struct-derive") + .map(|v| { + v.into_iter() + .map(|e| e.to_string()) + .collect::>() + }) + .unwrap_or_default(); + let mut enum_derives = matches + .values_of("enum-derive") + .map(|v| { + v.into_iter() + .map(|e| e.to_string()) + .collect::>() + }) + .unwrap_or_default(); + + if enable_serde { + let mut unique_struct_derivies = struct_derives + .clone() + .into_iter() + .collect::>(); + let mut unique_enum_derivies = enum_derives + .clone() + .into_iter() + .collect::>(); + for serde_derive in ["serde::Serialize", "serde::Deserialize"] { + unique_struct_derivies.insert(serde_derive.to_string()); + unique_enum_derivies.insert(serde_derive.to_string()); + } + + struct_derives = unique_struct_derivies.into_iter().collect(); + enum_derives = unique_enum_derivies.into_iter().collect(); + } + let date_time_lib = matches + .value_of("datetime-lib") + .map(|e| e.to_string()) + .unwrap(); + let date_time_lib = DateTimeLib::from(date_time_lib); + + generate::generate( + output_folder, + database_url, + context, + force, + include_tables, + exclude_tables, + schemas, + date_time_lib, + struct_derives, + enum_derives, + ) + .await?; } else if let Some(matches) = matches.subcommand_matches("migrate") { let input_migrations_folder = matches.value_of("migrations").unwrap_or("./migrations"); println!( diff --git a/src/models.rs b/src/models.rs index def398f..656b0a6 100644 --- a/src/models.rs +++ b/src/models.rs @@ -1,4 +1,4 @@ -#[derive(sqlx::FromRow)] +#[derive(sqlx::FromRow, Clone)] pub struct TableColumn { pub(crate) table_name: String, pub(crate) column_name: String, @@ -12,3 +12,9 @@ pub struct TableColumn { // #todo pub(crate) table_schema: String, } + +#[derive(sqlx::FromRow, Clone)] +pub struct UserDefinedEnums { + pub(crate) enum_name: String, + pub(crate) enum_value: String, +} diff --git a/src/query_generate.rs b/src/query_generate.rs index c2e37d8..ee71886 100644 --- a/src/query_generate.rs +++ b/src/query_generate.rs @@ -1,6 +1,6 @@ use crate::{ models::TableColumn, - utils::{convert_data_type, to_pascal_case, to_snake_case}, + utils::{convert_data_type, rust_type_fix, to_pascal_case, to_snake_case}, }; pub fn generate_query_code(table_name: &str, rows: &[TableColumn]) -> String { @@ -28,14 +28,14 @@ pub fn generate_query_code(table_name: &str, rows: &[TableColumn]) -> String { query_code.push_str(&format!("impl {}Set {{\n", struct_name)); // Generate query code for SELECT statements - query_code.push_str(&generate_select_query_code(table_name, schema_name, rows)); + query_code.push_str(&generate_select_query_code(table_name, schema_name, &rows)); query_code.push('\n'); // Generate query code for SELECT BY PK statements query_code.push_str(&generate_select_by_pk_query_code( table_name, schema_name, - rows, + &rows, )); query_code.push('\n'); @@ -43,7 +43,7 @@ pub fn generate_query_code(table_name: &str, rows: &[TableColumn]) -> String { query_code.push_str(&generate_select_many_by_pks_query_code( table_name, schema_name, - rows, + &rows, )); query_code.push('\n'); @@ -51,30 +51,30 @@ pub fn generate_query_code(table_name: &str, rows: &[TableColumn]) -> String { query_code.push_str(&generate_select_by_pk_query_code_optional( table_name, schema_name, - rows, + &rows, )); query_code.push('\n'); - query_code.push_str(&generate_unique_query_code(table_name, schema_name, rows)); + query_code.push_str(&generate_unique_query_code(table_name, schema_name, &rows)); query_code.push('\n'); query_code.push_str(&generate_select_all_fk_queries( table_name, schema_name, - rows, + &rows, )); query_code.push('\n'); // Generate query code for INSERT statements - query_code.push_str(&generate_insert_query_code(table_name, schema_name, rows)); + query_code.push_str(&generate_insert_query_code(table_name, schema_name, &rows)); query_code.push('\n'); // Generate query code for UPDATE statements - query_code.push_str(&generate_update_query_code(table_name, schema_name, rows)); + query_code.push_str(&generate_update_query_code(table_name, schema_name, &rows)); query_code.push('\n'); // Generate query code for DELETE statements - query_code.push_str(&generate_delete_query_code(table_name, schema_name, rows)); + query_code.push_str(&generate_delete_query_code(table_name, schema_name, &rows)); query_code.push('\n'); query_code.push_str("}\n"); @@ -123,7 +123,7 @@ fn generate_select_by_pk_query_code( .map(|r| { format!( "{}: {}", - r.column_name, + rust_type_fix(r.column_name.as_str()), convert_data_type(r.udt_name.as_str()) ) }) @@ -146,7 +146,12 @@ fn generate_select_by_pk_query_code( let bind = rows .iter() .filter(|r| r.is_primary_key && r.table_name == table_name) - .map(|r| format!(" .bind({})\n", r.column_name)) + .map(|r| { + format!( + " .bind({})\n", + rust_type_fix(r.column_name.as_str()) + ) + }) .collect::>() .join(""); @@ -187,7 +192,7 @@ fn generate_select_many_by_pks_query_code( .map(|r| { format!( "{}_list: Vec<{}>", - r.column_name, + rust_type_fix(r.column_name.as_str()), convert_data_type(r.udt_name.as_str()) ) }) @@ -210,7 +215,12 @@ fn generate_select_many_by_pks_query_code( let bind = rows .iter() .filter(|r| r.is_primary_key && r.table_name == table_name) - .map(|r| format!(" .bind({}_list)\n", r.column_name)) + .map(|r| { + format!( + " .bind({}_list)\n", + rust_type_fix(r.column_name.as_str()) + ) + }) .collect::>() .join(""); @@ -246,7 +256,7 @@ fn generate_select_by_pk_query_code_optional( .map(|r| { format!( "{}: {}", - r.column_name, + rust_type_fix(r.column_name.as_str()), convert_data_type(r.udt_name.as_str()) ) }) @@ -271,7 +281,12 @@ fn generate_select_by_pk_query_code_optional( let bind = rows .iter() .filter(|r| r.is_primary_key && r.table_name == table_name) - .map(|r| format!(" .bind({})\n", r.column_name)) + .map(|r| { + format!( + " .bind({})\n", + rust_type_fix(r.column_name.as_str()) + ) + }) .collect::>() .join(""); @@ -289,7 +304,15 @@ fn generate_select_by_pk_query_code_optional( fn generate_unique_query_code(table_name: &str, schema_name: &str, rows: &[TableColumn]) -> String { let mut code = String::new(); - for row in rows.iter().filter(|r| r.is_unique) { + let mut unique_rows = std::collections::HashMap::<(&str, &str), &TableColumn>::new(); + for row in rows + .iter() + .filter(|r| r.table_name.as_str() == table_name && r.is_unique) + { + unique_rows.insert((&row.table_name, &row.column_name), row); + } + let unique_users_vec: Vec<&TableColumn> = unique_rows.into_values().collect(); + for row in unique_users_vec.iter() { code.push_str( generate_select_by_unique_query_code(&row.column_name, table_name, schema_name, rows) .as_str(), @@ -339,15 +362,19 @@ fn generate_select_by_unique_query_code( .map(|r| { format!( "{}: {}", - r.column_name, + rust_type_fix(r.column_name.as_str()), convert_data_type(r.udt_name.as_str()) ) }) - .collect::>() - .join(", "); + .collect::>(); + + if fn_args.is_empty() { + return String::from(""); + } + let fn_args = fn_args.join(", "); select_code.push_str(&format!( - " pub async fn by_{}<'e, E: PgExecutor<'e>>(&self, executor: E, {}) -> Result<{}> {{\n", + " pub async fn unique_by_{}<'e, E: PgExecutor<'e>>(&self, executor: E, {}) -> Result<{}> {{\n", unique_name, fn_args, struct_name )); @@ -362,7 +389,12 @@ fn generate_select_by_unique_query_code( let bind = rows .iter() .filter(|r| r.is_unique && r.table_name == table_name && r.column_name == unique_name) - .map(|r| format!(" .bind({})\n", r.column_name)) + .map(|r| { + format!( + " .bind({})\n", + rust_type_fix(r.column_name.as_str()) + ) + }) .collect::>() .join(""); @@ -396,17 +428,26 @@ fn generate_select_many_by_uniques_query_code( .iter() .filter(|r| r.is_unique && r.table_name == table_name && r.column_name == unique_name) .map(|r| { - format!( - "{}_list: Vec<{}>", - r.column_name, - convert_data_type(r.udt_name.as_str()) - ) + let rs_type = convert_data_type(r.udt_name.as_str()); + if rs_type.starts_with("Vec<") { + None + } else { + Some(format!( + "{}_list: Vec<{}>", + rust_type_fix(r.column_name.as_str()), + rs_type + )) + } }) - .collect::>() - .join(", "); + .collect::>>(); + + if fn_args.is_none() { + return String::from(""); + } + let fn_args = fn_args.unwrap().join(", "); select_code.push_str(&format!( - " pub async fn many_by_{}_list<'e, E: PgExecutor<'e>>(&self, executor: E, {}) -> Result> {{\n", + " pub async fn unique_many_by_{}_list<'e, E: PgExecutor<'e>>(&self, executor: E, {}) -> Result> {{\n", unique_name, fn_args, struct_name )); @@ -421,7 +462,12 @@ fn generate_select_many_by_uniques_query_code( let bind = rows .iter() .filter(|r| r.is_unique && r.table_name == table_name && r.column_name == unique_name) - .map(|r| format!(" .bind({}_list)\n", r.column_name)) + .map(|r| { + format!( + " .bind({}_list)\n", + rust_type_fix(r.column_name.as_str()) + ) + }) .collect::>() .join(""); @@ -452,15 +498,19 @@ fn generate_select_by_unique_query_code_optional( .map(|r| { format!( "{}: {}", - r.column_name, + rust_type_fix(r.column_name.as_str()), convert_data_type(r.udt_name.as_str()) ) }) - .collect::>() - .join(", "); + .collect::>(); + + if fn_args.is_empty() { + return String::from(""); + } + let fn_args = fn_args.join(", "); select_code.push_str(&format!( - " pub async fn by_{}_optional<'e, E: PgExecutor<'e>>(&self, executor: E, {}) -> Result> {{\n", + " pub async fn unique_by_{}_optional<'e, E: PgExecutor<'e>>(&self, executor: E, {}) -> Result> {{\n", unique_name, fn_args, struct_name @@ -477,7 +527,12 @@ fn generate_select_by_unique_query_code_optional( let bind = rows .iter() .filter(|r| r.is_unique && r.table_name == table_name && r.column_name == unique_name) - .map(|r| format!(" .bind({})\n", r.column_name)) + .map(|r| { + format!( + " .bind({})\n", + rust_type_fix(r.column_name.as_str()) + ) + }) .collect::>() .join(""); @@ -543,7 +598,7 @@ fn generate_select_by_fk_query_code( select_code.push_str(&format!( " pub async fn all_by_{}_{}<'e, E: PgExecutor<'e>>(executor: E, {}_{}: {}) -> Result> {{\n", to_snake_case(foreign_row_table_name), - foreign_row_column_name, + to_snake_case(column_name), to_snake_case(foreign_row_table_name), foreign_row_column_name, data_type, @@ -569,6 +624,7 @@ fn generate_select_by_fk_query_code( fn generate_insert_query_code(table_name: &str, schema_name: &str, rows: &[TableColumn]) -> String { let struct_name = to_pascal_case(table_name); let mut insert_code = String::new(); + insert_code.push_str(&format!( " pub async fn insert<'e, E: PgExecutor<'e>>(&self, executor: E, {}: {}) -> Result<{}> {{\n", to_snake_case(table_name), @@ -659,10 +715,11 @@ fn generate_value_list(table_name: &str, rows: &[TableColumn]) -> String { rows.iter() .filter(|row| row.table_name == table_name) .map(|row| { + let column_name = rust_type_fix(row.column_name.as_str()); format!( ".bind({}.{})", to_snake_case(&row.table_name), - to_snake_case(&row.column_name) + to_snake_case(&column_name) ) }) .collect::>() diff --git a/src/utils.rs b/src/utils.rs index cc051cd..0f0652e 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,4 +1,64 @@ -use crate::models::TableColumn; +use crate::{ + models::{TableColumn, UserDefinedEnums}, + STATE, +}; + +#[derive(Debug)] +pub(crate) enum DateTimeLib { + Time, + Chrono, +} + +impl Default for DateTimeLib { + fn default() -> Self { + DateTimeLib::Chrono + } +} + +impl DateTimeLib { + pub(crate) fn date_type(&self) -> &str { + match self { + DateTimeLib::Time => "time::Date", + DateTimeLib::Chrono => "chrono::NaiveDate", + } + } + pub(crate) fn time_type(&self) -> &str { + match self { + DateTimeLib::Time => "time::Time", + DateTimeLib::Chrono => "chrono::NaiveTime", + } + } + pub(crate) fn timestamp_type(&self) -> &str { + match self { + DateTimeLib::Time => "time::OffsetDateTime", + DateTimeLib::Chrono => "chrono::NaiveDateTime", + } + } + pub(crate) fn timestampz_type(&self) -> &str { + match self { + DateTimeLib::Time => "time::OffsetDateTime", + DateTimeLib::Chrono => "chrono::DateTime", + } + } +} + +impl From for DateTimeLib { + fn from(value: String) -> Self { + if value == "chrono" { + DateTimeLib::Chrono + } else { + DateTimeLib::Time + } + } +} + +#[derive(Debug, Default)] +pub(crate) struct SqlGenState { + pub user_defined: Vec, + pub struct_derives: Vec, + pub enum_derives: Vec, + pub date_time_lib: DateTimeLib, +} pub(crate) fn to_snake_case(input: &str) -> String { let mut output = String::new(); @@ -20,25 +80,68 @@ pub(crate) fn to_snake_case(input: &str) -> String { output } +pub fn generate_enum_code(enum_name: &str, enum_rows: &Vec) -> String { + let rs_enum_name = to_pascal_case(enum_name); + let mut enum_code = String::new(); + + let enum_derives = STATE.get().unwrap().enum_derives.join(", "); + + let extra_derives = if !enum_derives.is_empty() { + format!(", {}", enum_derives) + } else { + "".to_string() + }; + enum_code.push_str(&format!( + "#[derive(sqlx::Type, Debug, Clone, Eq, PartialEq{})]\n", + extra_derives + )); + enum_code.push_str(&format!(r#"#[sqlx(type_name = "{}")]"#, enum_name)); + enum_code.push_str("\n"); + enum_code.push_str(&format!("pub enum {} {{\n", rs_enum_name)); + + for row in enum_rows.iter().filter(|e| &e.enum_name == enum_name) { + enum_code.push_str(&format!(r#" #[sqlx(rename = "{}")]"#, row.enum_value)); + enum_code.push_str("\n"); + enum_code.push_str(&format!(" {},\n", to_pascal_case(&row.enum_value))) + } + + enum_code.push_str("}\n"); + + enum_code +} + pub fn generate_struct_code(table_name: &str, rows: &Vec) -> String { let struct_name = to_pascal_case(table_name); let mut struct_code = String::new(); + let struct_derives = STATE.get().unwrap().struct_derives.join(", "); + + let extra_derives = if !struct_derives.is_empty() { + format!(", {}", struct_derives) + } else { + "".to_string() + }; + struct_code.push_str("#![allow(dead_code)]\n"); struct_code.push_str("// Generated with sql-gen\n// https://github.com/jayy-lmao/sql-gen\n\n"); - struct_code.push_str("#[derive(sqlx::FromRow, Debug)]\n"); + struct_code.push_str(&format!( + "#[derive(sqlx::FromRow, Debug, Clone{})]\n", + extra_derives + )); struct_code.push_str(&format!("pub struct {} {{\n", struct_name)); for row in rows { if row.table_name == table_name { let column_name = to_snake_case(&row.column_name); let mut data_type = convert_data_type(&row.udt_name); - let optional_type = format!("Option<{}>", data_type); if row.is_nullable { - data_type = optional_type; + data_type = format!("Option<{}>", data_type); } - - struct_code.push_str(&format!(" pub {}: {},\n", column_name, data_type)); + struct_code.push_str(&format!( + " pub {}: {},\n", + rust_type_fix(column_name.as_str()), + data_type + )); } } struct_code.push_str("}\n"); @@ -56,24 +159,34 @@ pub fn convert_data_type(data_type: &str) -> String { return vec_type; } + let state = STATE.get().unwrap(); + match data_type { "bool" | "boolean" => "bool", "bytea" => "Vec", // is this right? "char" | "bpchar" | "character" => "String", - "date" => "chrono::NaiveDate", + "date" => state.date_time_lib.date_type(), "float4" | "real" => "f32", - "float8" | "double precision" => "f64", + "float8" | "double precision" | "numeric" => "f64", "int2" | "smallint" | "smallserial" => "i16", "int4" | "int" | "serial" => "i32", "int8" | "bigint" | "bigserial" => "i64", "void" => "()", "jsonb" | "json" => "serde_json::Value", "text" | "varchar" | "name" | "citext" => "String", - "time" => "chrono::NaiveTime", - "timestamp" => "chrono::NaiveDateTime", - "timestamptz" => "chrono::DateTime", + "geometry" => "String", // when sqlx supports geo types we could change this + "time" => state.date_time_lib.time_type(), + "timestamp" => state.date_time_lib.timestamp_type(), + "timestamptz" => state.date_time_lib.timestampz_type(), + "interval" => "sqlx::postgres::types::PgInterval", "uuid" => "uuid::Uuid", - _ => panic!("Unknown type: {}", data_type), + _ => { + if state.user_defined.contains(&data_type.to_string()) { + return format!("crate::{}", to_pascal_case(data_type)); + } else { + panic!("Unknown type: {}", data_type) + } + } } .to_string() } @@ -171,3 +284,11 @@ pub fn to_pascal_case(input: &str) -> String { output } + +pub(crate) fn rust_type_fix(column_name: &str) -> String { + if column_name == "type" { + String::from("r#type") + } else { + column_name.to_string() + } +}