Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions bench/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ fn bench(c: &mut Criterion) {
CodegenSettings {
is_async: false,
derive_ser: true,
is_recursive: false,
},
)
.unwrap()
Expand All @@ -30,6 +31,7 @@ fn bench(c: &mut Criterion) {
CodegenSettings {
is_async: true,
derive_ser: true,
is_recursive: false,
},
)
.unwrap()
Expand Down
6 changes: 6 additions & 0 deletions cornucopia/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ struct Args {
/// Derive serde's `Serialize` trait for generated types.
#[clap(long)]
serialize: bool,
/// Recursive lookup
#[clap(long)]
recursive: bool,
}

#[derive(Debug, Subcommand)]
Expand All @@ -48,6 +51,7 @@ pub fn run() -> Result<(), Error> {
action,
sync,
serialize,
recursive,
} = Args::parse();

match action {
Expand All @@ -60,6 +64,7 @@ pub fn run() -> Result<(), Error> {
CodegenSettings {
is_async: !sync,
derive_ser: serialize,
is_recursive: recursive,
},
)?;
}
Expand All @@ -73,6 +78,7 @@ pub fn run() -> Result<(), Error> {
CodegenSettings {
is_async: !sync,
derive_ser: serialize,
is_recursive: recursive,
},
) {
container::cleanup(podman).ok();
Expand Down
2 changes: 2 additions & 0 deletions cornucopia/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@ fn gen_row_structs(
CodegenSettings {
is_async,
derive_ser,
..
}: CodegenSettings,
) {
let PreparedItem {
Expand Down Expand Up @@ -614,6 +615,7 @@ fn gen_custom_type(
CodegenSettings {
derive_ser,
is_async,
..
}: CodegenSettings,
) {
let PreparedType {
Expand Down
18 changes: 13 additions & 5 deletions cornucopia/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use codegen::generate as generate_internal;
use error::WriteOutputError;
use parser::parse_query_module;
use prepare_queries::prepare;
use read_queries::read_query_modules;
use read_queries::{read_query_modules, read_query_modules_recursive};

#[doc(hidden)]
pub use cli::run;
Expand All @@ -33,6 +33,7 @@ pub use load_schema::load_schema;
pub struct CodegenSettings {
pub is_async: bool,
pub derive_ser: bool,
pub is_recursive: bool,
}

/// Generates Rust queries from PostgreSQL queries located at `queries_path`,
Expand All @@ -46,10 +47,17 @@ pub fn generate_live(
settings: CodegenSettings,
) -> Result<String, Error> {
// Read
let modules = read_query_modules(queries_path)?
.into_iter()
.map(parse_query_module)
.collect::<Result<_, parser::error::Error>>()?;
let modules = if settings.is_recursive {
read_query_modules_recursive(queries_path)?
.into_iter()
.map(parse_query_module)
.collect::<Result<_, parser::error::Error>>()?
} else {
read_query_modules(queries_path)?
.into_iter()
.map(parse_query_module)
.collect::<Result<_, parser::error::Error>>()?
};
// Generate
let prepared_modules = prepare(client, modules)?;
let generated_code = generate_internal(prepared_modules, settings);
Expand Down
72 changes: 72 additions & 0 deletions cornucopia/src/read_queries.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::path::{Path, PathBuf};

use miette::NamedSource;

use self::error::Error;
Expand Down Expand Up @@ -68,6 +70,76 @@ pub(crate) fn read_query_modules(dir_path: &str) -> Result<Vec<ModuleInfo>, Erro
Ok(modules_info)
}

/// Reads queries in the directory and checks each directory found within given path.
/// Only .sql files are considered.
///
/// # Error
/// Returns an error if `dir_path` does not point to a valid directory or if a query file cannot be parsed.
pub(crate) fn read_query_modules_recursive(dir_path: &str) -> Result<Vec<ModuleInfo>, Error> {
let mut modules_info = Vec::new();
for entry_result in std::fs::read_dir(dir_path).map_err(|err| Error {
err,
path: String::from(dir_path),
})? {
// Directory entry
let entry = entry_result.map_err(|err| Error {
err,
path: dir_path.to_owned(),
})?;
let path_buf = entry.path();

let path_bufs = if path_buf.is_dir() {
find_queries(&path_buf, Vec::<PathBuf>::new())
} else {
vec![path_buf]
};

// Check we're dealing with a .sql file
for path_buf in path_bufs {
if path_buf
.extension()
.map(|extension| extension == "sql")
.unwrap_or_default()
{
let module_name = path_buf
.file_stem()
.expect("is a file")
.to_str()
.expect("file name is valid utf8")
.to_string();

let file_contents = std::fs::read_to_string(&path_buf).map_err(|err| Error {
err,
path: dir_path.to_owned(),
})?;

modules_info.push(ModuleInfo {
path: String::from(path_buf.to_string_lossy()),
name: module_name,
content: file_contents,
});
}
}
}
// Sort module for consistent codegen
modules_info.sort_by(|a, b| a.name.cmp(&b.name));
Ok(modules_info)
}

fn find_queries(start: &Path, mut queries: Vec<PathBuf>) -> Vec<PathBuf> {
for entry in start.read_dir().unwrap() {
let entry = entry.unwrap();
let path = entry.path();
if path.is_dir() {
queries = find_queries(&path, queries);
} else {
queries.push(path);
}
}

queries
}

pub(crate) mod error {
use miette::Diagnostic;
use thiserror::Error as ThisError;
Expand Down
3 changes: 3 additions & 0 deletions integration/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ fn run_errors_test(
CodegenSettings {
is_async: false,
derive_ser: false,
is_recursive: false,
},
)?;
Ok(())
Expand Down Expand Up @@ -235,6 +236,7 @@ fn run_codegen_test(
CodegenSettings {
is_async,
derive_ser,
is_recursive: false,
},
)
.map_err(Error::report)?;
Expand All @@ -254,6 +256,7 @@ fn run_codegen_test(
CodegenSettings {
is_async,
derive_ser,
is_recursive: false,
},
)
.map_err(Error::report)?;
Expand Down