From 38ac99cfb688749a0d74e5f2a7665d1d412eebad Mon Sep 17 00:00:00 2001 From: Duarte Nunes Date: Mon, 15 Jul 2024 23:45:46 -0300 Subject: [PATCH] Add support for type mapping Similar to Kanel's customTypeMap. One use case for this is allowing one to treat PG extension types as some other types that they behave as (e.g., map public.ulid to String). Fixes #246 --- benches/codegen.rs | 2 + benches/execution/diesel_benches.rs | 5 +- crates/cornucopia/Cargo.toml | 4 ++ crates/cornucopia/src/cli.rs | 62 ++++++++++++++++++++++-- crates/cornucopia/src/config.rs | 12 +++++ crates/cornucopia/src/lib.rs | 9 ++-- crates/cornucopia/src/prepare_queries.rs | 33 +++++++------ crates/cornucopia/src/type_registrar.rs | 62 ++++++++++++++++++++---- crates/cornucopia/src/utils.rs | 13 ++++- test_integration/src/fixtures.rs | 2 + 10 files changed, 165 insertions(+), 39 deletions(-) create mode 100644 crates/cornucopia/src/config.rs diff --git a/benches/codegen.rs b/benches/codegen.rs index 4eabc30f..8bca128d 100644 --- a/benches/codegen.rs +++ b/benches/codegen.rs @@ -17,6 +17,7 @@ fn bench(c: &mut Criterion) { gen_sync: true, gen_async: false, derive_ser: true, + config: Default::default(), }, ) .unwrap() @@ -32,6 +33,7 @@ fn bench(c: &mut Criterion) { gen_sync: true, gen_async: false, derive_ser: true, + config: Default::default(), }, ) .unwrap() diff --git a/benches/execution/diesel_benches.rs b/benches/execution/diesel_benches.rs index 5c7407e1..b5f2ee66 100644 --- a/benches/execution/diesel_benches.rs +++ b/benches/execution/diesel_benches.rs @@ -164,10 +164,7 @@ pub fn bench_insert(b: &mut Bencher, conn: &mut PgConnection, size: usize) { }; let insert = &insert; - b.iter(|| { - let insert = insert; - insert(conn) - }) + b.iter(|| insert(conn)) } pub fn loading_associations_sequentially(b: &mut Bencher, conn: &mut PgConnection) { diff --git a/crates/cornucopia/Cargo.toml b/crates/cornucopia/Cargo.toml index 04773127..f7e21b0d 100644 --- a/crates/cornucopia/Cargo.toml +++ b/crates/cornucopia/Cargo.toml @@ -33,3 +33,7 @@ heck = "0.4.0" # Order-preserving map to work around borrowing issues indexmap = "2.0.2" + +# Config handling +serde = { version = "1.0.203", features = ["derive"] } +toml = "0.8.14" diff --git a/crates/cornucopia/src/cli.rs b/crates/cornucopia/src/cli.rs index 73692260..e85502f4 100644 --- a/crates/cornucopia/src/cli.rs +++ b/crates/cornucopia/src/cli.rs @@ -1,8 +1,12 @@ -use std::path::PathBuf; +use miette::Diagnostic; +use std::{fs, path::PathBuf}; +use thiserror::Error as ThisError; use clap::{Parser, Subcommand}; -use crate::{conn, container, error::Error, generate_live, generate_managed, CodegenSettings}; +use crate::{ + config::Config, conn, container, error::Error, generate_live, generate_managed, CodegenSettings, +}; /// Command line interface to interact with Cornucopia SQL. #[derive(Parser, Debug)] @@ -28,6 +32,13 @@ struct Args { /// Derive serde's `Serialize` trait for generated types. #[clap(long)] serialize: bool, + /// The location of the configuration file. + #[clap(short, long, default_value = default_config_path())] + config: PathBuf, +} + +const fn default_config_path() -> &'static str { + "cornucopia.toml" } #[derive(Debug, Subcommand)] @@ -44,8 +55,26 @@ enum Action { }, } +/// Enumeration of the errors reported by the CLI. +#[derive(ThisError, Debug, Diagnostic)] +pub enum CliError { + /// An error occurred while loading the configuration file. + #[error("Could not load config `{path}`: ({err})")] + MissingConfig { path: String, err: std::io::Error }, + /// An error occurred while parsing the configuration file. + #[error("Could not parse config `{path}`: ({err})")] + ConfigContents { + path: String, + err: Box, + }, + /// An error occurred while running the CLI. + #[error(transparent)] + #[diagnostic(transparent)] + Internal(#[from] Error), +} + // Main entrypoint of the CLI. Parses the args and calls the appropriate routines. -pub fn run() -> Result<(), Error> { +pub fn run() -> Result<(), CliError> { let Args { podman, queries_path, @@ -54,17 +83,40 @@ pub fn run() -> Result<(), Error> { sync, r#async, serialize, + config, } = Args::parse(); + let config = match fs::read_to_string(config.as_path()) { + Ok(contents) => match toml::from_str(&contents) { + Ok(config) => config, + Err(err) => { + return Err(CliError::ConfigContents { + path: config.to_string_lossy().into_owned(), + err: err.into(), + }); + } + }, + Err(err) => { + if config.as_path().as_os_str() != default_config_path() { + return Err(CliError::MissingConfig { + path: config.to_string_lossy().into_owned(), + err, + }); + } else { + Config::default() + } + } + }; let settings = CodegenSettings { gen_async: r#async || !sync, gen_sync: sync, derive_ser: serialize, + config, }; match action { Action::Live { url } => { - let mut client = conn::from_url(&url)?; + let mut client = conn::from_url(&url).map_err(|e| CliError::Internal(e.into()))?; generate_live(&mut client, &queries_path, Some(&destination), settings)?; } Action::Schema { schema_files } => { @@ -77,7 +129,7 @@ pub fn run() -> Result<(), Error> { settings, ) { container::cleanup(podman).ok(); - return Err(e); + return Err(CliError::Internal(e)); } } }; diff --git a/crates/cornucopia/src/config.rs b/crates/cornucopia/src/config.rs new file mode 100644 index 00000000..92419e0f --- /dev/null +++ b/crates/cornucopia/src/config.rs @@ -0,0 +1,12 @@ +//! Configuration for Cornucopia. + +use std::collections::HashMap; + +use serde::Deserialize; + +/// Configuration for Cornucopia. +#[derive(Clone, Deserialize, Default, Debug)] +pub struct Config { + /// Contains a map of what given type should map to. + pub custom_type_map: HashMap, +} diff --git a/crates/cornucopia/src/lib.rs b/crates/cornucopia/src/lib.rs index b879ee28..7da88f9d 100644 --- a/crates/cornucopia/src/lib.rs +++ b/crates/cornucopia/src/lib.rs @@ -1,5 +1,6 @@ mod cli; mod codegen; +mod config; mod error; mod load_schema; mod parser; @@ -16,6 +17,7 @@ pub mod container; use std::path::Path; +use config::Config; use postgres::Client; use codegen::generate as generate_internal; @@ -31,11 +33,12 @@ pub use error::Error; pub use load_schema::load_schema; /// Struct containing the settings for code generation. -#[derive(Clone, Copy)] +#[derive(Clone)] pub struct CodegenSettings { pub gen_async: bool, pub gen_sync: bool, pub derive_ser: bool, + pub config: Config, } /// Generates Rust queries from PostgreSQL queries located at `queries_path`, @@ -54,7 +57,7 @@ pub fn generate_live>( .map(parse_query_module) .collect::>()?; // Generate - let prepared_modules = prepare(client, modules)?; + let prepared_modules = prepare(client, modules, settings.clone())?; let generated_code = generate_internal(prepared_modules, settings); // Write if let Some(d) = destination { @@ -86,7 +89,7 @@ pub fn generate_managed>( container::setup(podman)?; let mut client = conn::cornucopia_conn()?; load_schema(&mut client, schema_files)?; - let prepared_modules = prepare(&mut client, modules)?; + let prepared_modules = prepare(&mut client, modules, settings.clone())?; let generated_code = generate_internal(prepared_modules, settings); container::cleanup(podman)?; diff --git a/crates/cornucopia/src/prepare_queries.rs b/crates/cornucopia/src/prepare_queries.rs index 9edf6eee..bd2afbff 100644 --- a/crates/cornucopia/src/prepare_queries.rs +++ b/crates/cornucopia/src/prepare_queries.rs @@ -9,10 +9,9 @@ use crate::{ codegen::GenCtx, parser::{Module, NullableIdent, Query, Span, TypeAnnotation}, read_queries::ModuleInfo, - type_registrar::CornucopiaType, - type_registrar::TypeRegistrar, + type_registrar::{CornucopiaType, TypeRegistrar}, utils::KEYWORD, - validation, + validation, CodegenSettings, }; use self::error::Error; @@ -226,8 +225,12 @@ impl PreparedModule { } /// Prepares all modules -pub(crate) fn prepare(client: &mut Client, modules: Vec) -> Result { - let mut registrar = TypeRegistrar::default(); +pub(crate) fn prepare( + client: &mut Client, + modules: Vec, + settings: CodegenSettings, +) -> Result { + let mut registrar = TypeRegistrar::new(settings.config.custom_type_map); let mut tmp = Preparation { modules: Vec::new(), types: IndexMap::new(), @@ -244,16 +247,12 @@ pub(crate) fn prepare(client: &mut Client, modules: Vec) -> Result { - entry.get_mut().push(ty); - } - Entry::Vacant(entry) => { - entry.insert(vec![ty]); - } - } + for (schema_key, ty) in registrar.types() { + if let Some(ty) = prepare_type(®istrar, schema_key.name, ty, &declared) { + tmp.types + .entry(schema_key.schema.to_owned()) + .or_default() + .push(ty); } } Ok(tmp) @@ -301,7 +300,9 @@ fn prepare_type( }) .collect(), ), - _ => unreachable!(), + _ => { + return None; + } }; Some(PreparedType { name: name.to_string(), diff --git a/crates/cornucopia/src/type_registrar.rs b/crates/cornucopia/src/type_registrar.rs index 1b7ca70f..aefc9cb7 100644 --- a/crates/cornucopia/src/type_registrar.rs +++ b/crates/cornucopia/src/type_registrar.rs @@ -1,4 +1,4 @@ -use std::rc::Rc; +use std::{collections::HashMap, rc::Rc}; use heck::ToUpperCamelCase; use indexmap::{map::Entry, IndexMap}; @@ -18,7 +18,7 @@ use self::error::Error; pub(crate) enum CornucopiaType { Simple { pg_ty: Type, - rust_name: &'static str, + rust_name: String, is_copy: bool, }, Array { @@ -33,6 +33,7 @@ pub(crate) enum CornucopiaType { struct_name: String, is_copy: bool, is_params: bool, + is_mapped: bool, }, } @@ -168,6 +169,11 @@ impl CornucopiaType { } } CornucopiaType::Domain { inner, .. } => inner.own_ty(false, ctx), + CornucopiaType::Custom { + is_mapped, + struct_name, + .. + } if *is_mapped => struct_name.to_string(), CornucopiaType::Custom { struct_name, pg_ty, .. } => custom_ty_path(pg_ty.schema(), struct_name, ctx), @@ -287,9 +293,14 @@ impl CornucopiaType { is_copy, pg_ty, struct_name, + is_mapped, .. } => { - let path = custom_ty_path(pg_ty.schema(), struct_name, ctx); + let path = if *is_mapped { + struct_name.to_string() + } else { + custom_ty_path(pg_ty.schema(), struct_name, ctx) + }; if *is_copy { path } else { @@ -311,12 +322,27 @@ pub fn custom_ty_path(schema: &str, struct_name: &str, ctx: &GenCtx) -> String { } /// Data structure holding all types known to this particular run of Cornucopia. -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone)] pub(crate) struct TypeRegistrar { - pub types: IndexMap<(String, String), Rc>, + types: IndexMap<(String, String), Rc>, + type_mappings: HashMap, } impl TypeRegistrar { + /// Create a new type registrar using the specified type mappings. + pub(crate) fn new(type_mappings: HashMap) -> Self { + Self { + types: IndexMap::new(), + type_mappings, + } + } + + pub(crate) fn types(&self) -> impl Iterator { + self.types + .iter() + .map(|((schema, name), ty)| (SchemaKey::new(schema, name), ty.as_ref())) + } + pub(crate) fn register( &mut self, name: &str, @@ -331,6 +357,7 @@ impl TypeRegistrar { struct_name: rust_ty_name, is_copy, is_params, + is_mapped: false, } } @@ -345,6 +372,23 @@ impl TypeRegistrar { return Ok(&self.types[idx]); } + if let Some(mapped_type) = self.type_mappings.get(ty.name()).cloned() { + if matches!(mapped_type.as_str(), "String" | "str") { + return Ok(self.insert(ty, move || CornucopiaType::Simple { + pg_ty: Type::VARCHAR, + rust_name: "String".to_string(), + is_copy: false, + })); + } + return Ok(self.insert(ty, move || CornucopiaType::Custom { + pg_ty: ty.clone(), + is_copy: false, + struct_name: mapped_type, + is_params: true, + is_mapped: true, + })); + } + Ok(match ty.kind() { Kind::Enum(_) => self.insert(ty, || custom(ty, true, true)), Kind::Array(inner_ty) => { @@ -397,12 +441,12 @@ impl TypeRegistrar { query: query_name.span, col_name: name.to_string(), col_ty: ty.to_string(), - }) + }); } }; self.insert(ty, || CornucopiaType::Simple { pg_ty: ty.clone(), - rust_name, + rust_name: rust_name.to_owned(), is_copy, }) } @@ -412,7 +456,7 @@ impl TypeRegistrar { query: query_name.span, col_name: name.to_string(), col_ty: ty.to_string(), - }) + }); } }) } @@ -424,7 +468,7 @@ impl TypeRegistrar { .clone() } - fn insert(&mut self, ty: &Type, call: impl Fn() -> CornucopiaType) -> &Rc { + fn insert(&mut self, ty: &Type, call: impl FnOnce() -> CornucopiaType) -> &Rc { let index = match self .types .entry((ty.schema().to_owned(), ty.name().to_owned())) diff --git a/crates/cornucopia/src/utils.rs b/crates/cornucopia/src/utils.rs index e9748da3..87442abb 100644 --- a/crates/cornucopia/src/utils.rs +++ b/crates/cornucopia/src/utils.rs @@ -5,8 +5,17 @@ use postgres_types::Type; /// Allows us to query a map using type schema as key without having to own the key strings #[derive(PartialEq, Eq, Hash)] pub struct SchemaKey<'a> { - schema: &'a str, - name: &'a str, + /// The schema of this type. + pub schema: &'a str, + /// The name of this type. + pub name: &'a str, +} + +impl<'a> SchemaKey<'a> { + /// Creates a new [`SchemaKey`] from the specified components. + pub fn new(schema: &'a str, name: &'a str) -> Self { + SchemaKey { schema, name } + } } impl<'a> From<&'a Type> for SchemaKey<'a> { diff --git a/test_integration/src/fixtures.rs b/test_integration/src/fixtures.rs index 308c81c2..7769e0d0 100644 --- a/test_integration/src/fixtures.rs +++ b/test_integration/src/fixtures.rs @@ -77,6 +77,7 @@ impl From<&CodegenTest> for CodegenSettings { gen_async: codegen_test.r#async || !codegen_test.sync, gen_sync: codegen_test.sync, derive_ser: codegen_test.derive_ser, + config: Default::default(), } } } @@ -96,6 +97,7 @@ impl From<&ErrorTest> for CodegenSettings { derive_ser: false, gen_async: false, gen_sync: true, + config: Default::default(), } } }