diff --git a/.github/workflows/dependabot-automerge.yml b/.github/workflows/dependabot-automerge.yml index a9841d7..bdaea4d 100644 --- a/.github/workflows/dependabot-automerge.yml +++ b/.github/workflows/dependabot-automerge.yml @@ -16,23 +16,33 @@ jobs: automerge: name: Auto-merge Dependabot PRs runs-on: ubuntu-latest - if: github.event.pull_request.user.login == 'dependabot[bot]' steps: + - name: Check if Dependabot PR + id: check + run: | + if [ "${{ github.event.pull_request.user.login }}" = "dependabot[bot]" ]; then + echo "is_dependabot=true" >> $GITHUB_OUTPUT + else + echo "is_dependabot=false" >> $GITHUB_OUTPUT + echo "Not a Dependabot PR, skipping auto-merge" + fi + - name: Fetch Dependabot metadata + if: steps.check.outputs.is_dependabot == 'true' id: metadata uses: dependabot/fetch-metadata@v2 with: github-token: "${{ secrets.GITHUB_TOKEN }}" - name: Enable auto-merge for minor/patch updates - if: steps.metadata.outputs.update-type == 'version-update:semver-minor' || steps.metadata.outputs.update-type == 'version-update:semver-patch' + if: steps.check.outputs.is_dependabot == 'true' && (steps.metadata.outputs.update-type == 'version-update:semver-minor' || steps.metadata.outputs.update-type == 'version-update:semver-patch') run: gh pr merge --auto --squash "$PR_URL" env: PR_URL: ${{ github.event.pull_request.html_url }} GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} - name: Approve patch updates - if: steps.metadata.outputs.update-type == 'version-update:semver-patch' + if: steps.check.outputs.is_dependabot == 'true' && steps.metadata.outputs.update-type == 'version-update:semver-patch' run: gh pr review --approve "$PR_URL" env: PR_URL: ${{ github.event.pull_request.html_url }} diff --git a/README.md b/README.md index 94bd40c..4fbeb7e 100644 --- a/README.md +++ b/README.md @@ -32,6 +32,9 @@ REUSE Compliant + + Wiki +

--- diff --git a/crates/entity-derive-impl/src/entity/api.rs b/crates/entity-derive-impl/src/entity/api.rs index 91b3be0..3fe4b4f 100644 --- a/crates/entity-derive-impl/src/entity/api.rs +++ b/crates/entity-derive-impl/src/entity/api.rs @@ -11,7 +11,8 @@ //! ```text //! api/ //! ├── mod.rs — Orchestrator (this file) -//! ├── handlers.rs — Axum handler functions with #[utoipa::path] +//! ├── crud.rs — CRUD handler functions (create, get, update, delete, list) +//! ├── handlers.rs — Command handler functions with #[utoipa::path] //! ├── router.rs — Router factory function //! └── openapi.rs — OpenApi struct for Swagger UI //! ``` @@ -58,6 +59,7 @@ //! let openapi = UserApi::openapi(); //! ``` +mod crud; mod handlers; mod openapi; mod router; @@ -69,24 +71,30 @@ use super::parse::EntityDef; /// Main entry point for API code generation. /// -/// Returns empty `TokenStream` if `api(...)` is not configured -/// or no commands are defined. +/// Returns empty `TokenStream` if `api(...)` is not configured. +/// Generates CRUD handlers if `handlers` is enabled, and command handlers +/// if commands are defined. pub fn generate(entity: &EntityDef) -> TokenStream { if !entity.has_api() { return TokenStream::new(); } - // API generation requires commands to be enabled - if !entity.has_commands() || entity.command_defs().is_empty() { + let has_crud = entity.api_config().has_handlers(); + let has_commands = entity.has_commands() && !entity.command_defs().is_empty(); + + // Need at least one type of handler to generate API + if !has_crud && !has_commands { return TokenStream::new(); } - let handlers = handlers::generate(entity); + let crud_handlers = crud::generate(entity); + let command_handlers = handlers::generate(entity); let router = router::generate(entity); let openapi = openapi::generate(entity); quote! { - #handlers + #crud_handlers + #command_handlers #router #openapi } diff --git a/crates/entity-derive-impl/src/entity/api/crud.rs b/crates/entity-derive-impl/src/entity/api/crud.rs new file mode 100644 index 0000000..08219e8 --- /dev/null +++ b/crates/entity-derive-impl/src/entity/api/crud.rs @@ -0,0 +1,708 @@ +// SPDX-FileCopyrightText: 2025-2026 RAprogramm +// SPDX-License-Identifier: MIT + +//! CRUD handler generation with utoipa annotations. +//! +//! Generates production-ready REST handlers with: +//! - OpenAPI documentation via `#[utoipa::path]` +//! - Cookie/Bearer authentication via `security` attribute +//! - Proper error responses using `masterror::ErrorResponse` +//! - Standard HTTP status codes and error handling +//! +//! # Generated Handlers +//! +//! | Operation | HTTP Method | Path | Status Codes | +//! |-----------|-------------|------|--------------| +//! | Create | POST | `/{entities}` | 201, 400, 401, 500 | +//! | Get | GET | `/{entities}/{id}` | 200, 401, 404, 500 | +//! | Update | PATCH | `/{entities}/{id}` | 200, 400, 401, 404, 500 | +//! | Delete | DELETE | `/{entities}/{id}` | 204, 401, 404, 500 | +//! | List | GET | `/{entities}` | 200, 401, 500 | +//! +//! # Security +//! +//! When `security = "cookie"` or `security = "bearer"` is specified, +//! handlers require authentication and use `Claims` extractor. +//! +//! # Example +//! +//! ```rust,ignore +//! #[derive(Entity)] +//! #[entity(table = "users", api(tag = "Users", security = "cookie", handlers))] +//! pub struct User { /* ... */ } +//! +//! // Generated handler with auth: +//! #[utoipa::path( +//! post, +//! path = "/users", +//! tag = "Users", +//! request_body(content = CreateUserRequest, description = "User data to create"), +//! responses( +//! (status = 201, description = "User created successfully", body = UserResponse), +//! (status = 400, description = "Invalid request data", body = ErrorResponse), +//! (status = 401, description = "Authentication required", body = ErrorResponse), +//! (status = 500, description = "Internal server error", body = ErrorResponse) +//! ), +//! security(("cookieAuth" = [])) +//! )] +//! pub async fn create_user( +//! _claims: Claims, +//! State(repo): State>, +//! Json(dto): Json, +//! ) -> AppResult<(StatusCode, Json)> +//! where +//! R: UserRepository + 'static, +//! { /* ... */ } +//! ``` + +use convert_case::{Case, Casing}; +use proc_macro2::TokenStream; +use quote::{format_ident, quote}; + +use crate::entity::parse::EntityDef; + +/// Generate CRUD handler functions based on enabled handlers. +pub fn generate(entity: &EntityDef) -> TokenStream { + if !entity.api_config().has_handlers() { + return TokenStream::new(); + } + + let handlers = entity.api_config().handlers(); + + let create = if handlers.create { + generate_create_handler(entity) + } else { + TokenStream::new() + }; + let get = if handlers.get { + generate_get_handler(entity) + } else { + TokenStream::new() + }; + let update = if handlers.update { + generate_update_handler(entity) + } else { + TokenStream::new() + }; + let delete = if handlers.delete { + generate_delete_handler(entity) + } else { + TokenStream::new() + }; + let list = if handlers.list { + generate_list_handler(entity) + } else { + TokenStream::new() + }; + + quote! { + #create + #get + #update + #delete + #list + } +} + +/// Generate the create handler. +fn generate_create_handler(entity: &EntityDef) -> TokenStream { + let vis = &entity.vis; + let entity_name = entity.name(); + let entity_name_str = entity.name_str(); + let api_config = entity.api_config(); + let repo_trait = entity.ident_with("", "Repository"); + let has_security = api_config.security.is_some(); + + let handler_name = format_ident!("create_{}", entity_name_str.to_case(Case::Snake)); + let create_dto = entity.ident_with("Create", "Request"); + let response_dto = entity.ident_with("", "Response"); + + let path = build_collection_path(entity); + let tag = api_config.tag_or_default(&entity_name_str); + + let security_attr = build_security_attr(entity); + let deprecated_attr = build_deprecated_attr(entity); + + let request_body_desc = format!("Data for creating a new {}", entity_name); + let success_desc = format!("{} created successfully", entity_name); + + let utoipa_attr = if has_security { + quote! { + #[utoipa::path( + post, + path = #path, + tag = #tag, + request_body(content = #create_dto, description = #request_body_desc), + responses( + (status = 201, description = #success_desc, body = #response_dto), + (status = 400, description = "Invalid request data"), + (status = 401, description = "Authentication required"), + (status = 500, description = "Internal server error") + ), + #security_attr + #deprecated_attr + )] + } + } else { + quote! { + #[utoipa::path( + post, + path = #path, + tag = #tag, + request_body(content = #create_dto, description = #request_body_desc), + responses( + (status = 201, description = #success_desc, body = #response_dto), + (status = 400, description = "Invalid request data"), + (status = 500, description = "Internal server error") + ) + #deprecated_attr + )] + } + }; + + let doc = format!( + "Create a new {}.\n\n\ + # Responses\n\n\ + - `201 Created` - {} created successfully\n\ + - `400 Bad Request` - Invalid request data\n\ + {}\ + - `500 Internal Server Error` - Database or server error", + entity_name, + entity_name, + if has_security { + "- `401 Unauthorized` - Authentication required\n" + } else { + "" + } + ); + + quote! { + #[doc = #doc] + #utoipa_attr + #vis async fn #handler_name( + axum::extract::State(repo): axum::extract::State>, + axum::extract::Json(dto): axum::extract::Json<#create_dto>, + ) -> masterror::AppResult<(axum::http::StatusCode, axum::response::Json<#response_dto>)> + where + R: #repo_trait + 'static, + { + let entity = repo + .create(dto) + .await + .map_err(|e| masterror::AppError::internal(e.to_string()))?; + Ok((axum::http::StatusCode::CREATED, axum::response::Json(#response_dto::from(entity)))) + } + } +} + +/// Generate the get handler. +fn generate_get_handler(entity: &EntityDef) -> TokenStream { + let vis = &entity.vis; + let entity_name = entity.name(); + let entity_name_str = entity.name_str(); + let api_config = entity.api_config(); + let id_field = entity.id_field(); + let id_type = &id_field.ty; + let repo_trait = entity.ident_with("", "Repository"); + let has_security = api_config.security.is_some(); + + let handler_name = format_ident!("get_{}", entity_name_str.to_case(Case::Snake)); + let response_dto = entity.ident_with("", "Response"); + + let path = build_item_path(entity); + let tag = api_config.tag_or_default(&entity_name_str); + + let security_attr = build_security_attr(entity); + let deprecated_attr = build_deprecated_attr(entity); + + let id_desc = format!("{} unique identifier", entity_name); + let success_desc = format!("{} found", entity_name); + let not_found_desc = format!("{} not found", entity_name); + + let utoipa_attr = if has_security { + quote! { + #[utoipa::path( + get, + path = #path, + tag = #tag, + params(("id" = #id_type, Path, description = #id_desc)), + responses( + (status = 200, description = #success_desc, body = #response_dto), + (status = 401, description = "Authentication required"), + (status = 404, description = #not_found_desc), + (status = 500, description = "Internal server error") + ), + #security_attr + #deprecated_attr + )] + } + } else { + quote! { + #[utoipa::path( + get, + path = #path, + tag = #tag, + params(("id" = #id_type, Path, description = #id_desc)), + responses( + (status = 200, description = #success_desc, body = #response_dto), + (status = 404, description = #not_found_desc), + (status = 500, description = "Internal server error") + ) + #deprecated_attr + )] + } + }; + + let doc = format!( + "Get {} by ID.\n\n\ + # Responses\n\n\ + - `200 OK` - {} found\n\ + {}\ + - `404 Not Found` - {} not found\n\ + - `500 Internal Server Error` - Database or server error", + entity_name, + entity_name, + if has_security { + "- `401 Unauthorized` - Authentication required\n" + } else { + "" + }, + entity_name + ); + + let not_found_msg = format!("{} not found", entity_name); + + quote! { + #[doc = #doc] + #utoipa_attr + #vis async fn #handler_name( + axum::extract::State(repo): axum::extract::State>, + axum::extract::Path(id): axum::extract::Path<#id_type>, + ) -> masterror::AppResult> + where + R: #repo_trait + 'static, + { + let entity = repo + .find_by_id(id) + .await + .map_err(|e| masterror::AppError::internal(e.to_string()))? + .ok_or_else(|| masterror::AppError::not_found(#not_found_msg))?; + Ok(axum::response::Json(#response_dto::from(entity))) + } + } +} + +/// Generate the update handler. +fn generate_update_handler(entity: &EntityDef) -> TokenStream { + let vis = &entity.vis; + let entity_name = entity.name(); + let entity_name_str = entity.name_str(); + let api_config = entity.api_config(); + let id_field = entity.id_field(); + let id_type = &id_field.ty; + let repo_trait = entity.ident_with("", "Repository"); + let has_security = api_config.security.is_some(); + + let handler_name = format_ident!("update_{}", entity_name_str.to_case(Case::Snake)); + let update_dto = entity.ident_with("Update", "Request"); + let response_dto = entity.ident_with("", "Response"); + + let path = build_item_path(entity); + let tag = api_config.tag_or_default(&entity_name_str); + + let security_attr = build_security_attr(entity); + let deprecated_attr = build_deprecated_attr(entity); + + let id_desc = format!("{} unique identifier", entity_name); + let request_body_desc = format!("Fields to update for {}", entity_name); + let success_desc = format!("{} updated successfully", entity_name); + let not_found_desc = format!("{} not found", entity_name); + + let utoipa_attr = if has_security { + quote! { + #[utoipa::path( + patch, + path = #path, + tag = #tag, + params(("id" = #id_type, Path, description = #id_desc)), + request_body(content = #update_dto, description = #request_body_desc), + responses( + (status = 200, description = #success_desc, body = #response_dto), + (status = 400, description = "Invalid request data"), + (status = 401, description = "Authentication required"), + (status = 404, description = #not_found_desc), + (status = 500, description = "Internal server error") + ), + #security_attr + #deprecated_attr + )] + } + } else { + quote! { + #[utoipa::path( + patch, + path = #path, + tag = #tag, + params(("id" = #id_type, Path, description = #id_desc)), + request_body(content = #update_dto, description = #request_body_desc), + responses( + (status = 200, description = #success_desc, body = #response_dto), + (status = 400, description = "Invalid request data"), + (status = 404, description = #not_found_desc), + (status = 500, description = "Internal server error") + ) + #deprecated_attr + )] + } + }; + + let doc = format!( + "Update {} by ID.\n\n\ + # Responses\n\n\ + - `200 OK` - {} updated successfully\n\ + - `400 Bad Request` - Invalid request data\n\ + {}\ + - `404 Not Found` - {} not found\n\ + - `500 Internal Server Error` - Database or server error", + entity_name, + entity_name, + if has_security { + "- `401 Unauthorized` - Authentication required\n" + } else { + "" + }, + entity_name + ); + + quote! { + #[doc = #doc] + #utoipa_attr + #vis async fn #handler_name( + axum::extract::State(repo): axum::extract::State>, + axum::extract::Path(id): axum::extract::Path<#id_type>, + axum::extract::Json(dto): axum::extract::Json<#update_dto>, + ) -> masterror::AppResult> + where + R: #repo_trait + 'static, + { + let entity = repo + .update(id, dto) + .await + .map_err(|e| masterror::AppError::internal(e.to_string()))?; + Ok(axum::response::Json(#response_dto::from(entity))) + } + } +} + +/// Generate the delete handler. +fn generate_delete_handler(entity: &EntityDef) -> TokenStream { + let vis = &entity.vis; + let entity_name = entity.name(); + let entity_name_str = entity.name_str(); + let api_config = entity.api_config(); + let id_field = entity.id_field(); + let id_type = &id_field.ty; + let repo_trait = entity.ident_with("", "Repository"); + let has_security = api_config.security.is_some(); + + let handler_name = format_ident!("delete_{}", entity_name_str.to_case(Case::Snake)); + + let path = build_item_path(entity); + let tag = api_config.tag_or_default(&entity_name_str); + + let security_attr = build_security_attr(entity); + let deprecated_attr = build_deprecated_attr(entity); + + let id_desc = format!("{} unique identifier", entity_name); + let success_desc = format!("{} deleted successfully", entity_name); + let not_found_desc = format!("{} not found", entity_name); + + let utoipa_attr = if has_security { + quote! { + #[utoipa::path( + delete, + path = #path, + tag = #tag, + params(("id" = #id_type, Path, description = #id_desc)), + responses( + (status = 204, description = #success_desc), + (status = 401, description = "Authentication required"), + (status = 404, description = #not_found_desc), + (status = 500, description = "Internal server error") + ), + #security_attr + #deprecated_attr + )] + } + } else { + quote! { + #[utoipa::path( + delete, + path = #path, + tag = #tag, + params(("id" = #id_type, Path, description = #id_desc)), + responses( + (status = 204, description = #success_desc), + (status = 404, description = #not_found_desc), + (status = 500, description = "Internal server error") + ) + #deprecated_attr + )] + } + }; + + let doc = format!( + "Delete {} by ID.\n\n\ + # Responses\n\n\ + - `204 No Content` - {} deleted successfully\n\ + {}\ + - `404 Not Found` - {} not found\n\ + - `500 Internal Server Error` - Database or server error", + entity_name, + entity_name, + if has_security { + "- `401 Unauthorized` - Authentication required\n" + } else { + "" + }, + entity_name + ); + + let not_found_msg = format!("{} not found", entity_name); + + quote! { + #[doc = #doc] + #utoipa_attr + #vis async fn #handler_name( + axum::extract::State(repo): axum::extract::State>, + axum::extract::Path(id): axum::extract::Path<#id_type>, + ) -> masterror::AppResult + where + R: #repo_trait + 'static, + { + let deleted = repo + .delete(id) + .await + .map_err(|e| masterror::AppError::internal(e.to_string()))?; + if deleted { + Ok(axum::http::StatusCode::NO_CONTENT) + } else { + Err(masterror::AppError::not_found(#not_found_msg)) + } + } + } +} + +/// Generate the list handler. +fn generate_list_handler(entity: &EntityDef) -> TokenStream { + let vis = &entity.vis; + let entity_name = entity.name(); + let entity_name_str = entity.name_str(); + let api_config = entity.api_config(); + let repo_trait = entity.ident_with("", "Repository"); + let has_security = api_config.security.is_some(); + + let handler_name = format_ident!("list_{}", entity_name_str.to_case(Case::Snake)); + let response_dto = entity.ident_with("", "Response"); + + let path = build_collection_path(entity); + let tag = api_config.tag_or_default(&entity_name_str); + + let security_attr = build_security_attr(entity); + let deprecated_attr = build_deprecated_attr(entity); + + let success_desc = format!("List of {} entities", entity_name); + + let utoipa_attr = if has_security { + quote! { + #[utoipa::path( + get, + path = #path, + tag = #tag, + params( + ("limit" = Option, Query, description = "Maximum number of items to return (default: 100)"), + ("offset" = Option, Query, description = "Number of items to skip for pagination") + ), + responses( + (status = 200, description = #success_desc, body = Vec<#response_dto>), + (status = 401, description = "Authentication required"), + (status = 500, description = "Internal server error") + ), + #security_attr + #deprecated_attr + )] + } + } else { + quote! { + #[utoipa::path( + get, + path = #path, + tag = #tag, + params( + ("limit" = Option, Query, description = "Maximum number of items to return (default: 100)"), + ("offset" = Option, Query, description = "Number of items to skip for pagination") + ), + responses( + (status = 200, description = #success_desc, body = Vec<#response_dto>), + (status = 500, description = "Internal server error") + ) + #deprecated_attr + )] + } + }; + + let doc = format!( + "List {} entities with pagination.\n\n\ + # Query Parameters\n\n\ + - `limit` - Maximum number of items to return (default: 100)\n\ + - `offset` - Number of items to skip for pagination\n\n\ + # Responses\n\n\ + - `200 OK` - List of {} entities\n\ + {}\ + - `500 Internal Server Error` - Database or server error", + entity_name, + entity_name, + if has_security { + "- `401 Unauthorized` - Authentication required\n" + } else { + "" + } + ); + + quote! { + /// Pagination query parameters. + #[derive(Debug, Clone, serde::Deserialize, utoipa::IntoParams)] + #vis struct PaginationQuery { + /// Maximum number of items to return. + #[serde(default = "default_limit")] + pub limit: i64, + /// Number of items to skip for pagination. + #[serde(default)] + pub offset: i64, + } + + fn default_limit() -> i64 { 100 } + + #[doc = #doc] + #utoipa_attr + #vis async fn #handler_name( + axum::extract::State(repo): axum::extract::State>, + axum::extract::Query(pagination): axum::extract::Query, + ) -> masterror::AppResult>> + where + R: #repo_trait + 'static, + { + let entities = repo + .list(pagination.limit, pagination.offset) + .await + .map_err(|e| masterror::AppError::internal(e.to_string()))?; + let responses: Vec<#response_dto> = entities.into_iter().map(#response_dto::from).collect(); + Ok(axum::response::Json(responses)) + } + } +} + +/// Build the collection path (e.g., `/api/v1/users`). +fn build_collection_path(entity: &EntityDef) -> String { + let api_config = entity.api_config(); + let prefix = api_config.full_path_prefix(); + let entity_path = entity.name_str().to_case(Case::Kebab); + + let path = format!("{}/{}s", prefix, entity_path); + path.replace("//", "/") +} + +/// Build the item path (e.g., `/api/v1/users/{id}`). +fn build_item_path(entity: &EntityDef) -> String { + let collection = build_collection_path(entity); + format!("{}/{{id}}", collection) +} + +/// Build security attribute for a handler. +/// +/// Returns the appropriate security scheme based on the `security` option: +/// - `"cookie"` -> `security(("cookieAuth" = []))` +/// - `"bearer"` -> `security(("bearerAuth" = []))` +/// - `"api_key"` -> `security(("apiKey" = []))` +fn build_security_attr(entity: &EntityDef) -> TokenStream { + let api_config = entity.api_config(); + + if let Some(security) = &api_config.security { + let security_name = match security.as_str() { + "cookie" => "cookieAuth", + "bearer" => "bearerAuth", + "api_key" => "apiKey", + _ => "cookieAuth" + }; + quote! { security((#security_name = [])) } + } else { + TokenStream::new() + } +} + +/// Build deprecated attribute if API is deprecated. +fn build_deprecated_attr(entity: &EntityDef) -> TokenStream { + if entity.api_config().is_deprecated() { + quote! { , deprecated = true } + } else { + TokenStream::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn create_test_entity() -> EntityDef { + let input: syn::DeriveInput = syn::parse_quote! { + #[entity(table = "users", api(tag = "Users", handlers))] + pub struct User { + #[id] + pub id: uuid::Uuid, + #[field(create, update, response)] + pub name: String, + } + }; + EntityDef::from_derive_input(&input).unwrap() + } + + #[test] + fn collection_path_format() { + let entity = create_test_entity(); + let path = build_collection_path(&entity); + assert_eq!(path, "/users"); + } + + #[test] + fn item_path_format() { + let entity = create_test_entity(); + let path = build_item_path(&entity); + assert_eq!(path, "/users/{id}"); + } + + #[test] + fn generates_handlers_when_enabled() { + let entity = create_test_entity(); + let tokens = generate(&entity); + let output = tokens.to_string(); + assert!(output.contains("create_user")); + assert!(output.contains("get_user")); + assert!(output.contains("update_user")); + assert!(output.contains("delete_user")); + assert!(output.contains("list_user")); + } + + #[test] + fn no_handlers_when_disabled() { + let input: syn::DeriveInput = syn::parse_quote! { + #[entity(table = "users", api(tag = "Users"))] + pub struct User { + #[id] + pub id: uuid::Uuid, + } + }; + let entity = EntityDef::from_derive_input(&input).unwrap(); + let tokens = generate(&entity); + assert!(tokens.is_empty()); + } +} diff --git a/crates/entity-derive-impl/src/entity/api/openapi.rs b/crates/entity-derive-impl/src/entity/api/openapi.rs index 9133cb4..c9193b9 100644 --- a/crates/entity-derive-impl/src/entity/api/openapi.rs +++ b/crates/entity-derive-impl/src/entity/api/openapi.rs @@ -1,21 +1,31 @@ // SPDX-FileCopyrightText: 2025-2026 RAprogramm // SPDX-License-Identifier: MIT -//! OpenAPI struct generation. +//! OpenAPI struct generation for utoipa 5.x. //! //! Generates a struct that implements `utoipa::OpenApi` for Swagger UI -//! integration. +//! integration, with security schemes and paths added via the `Modify` trait. //! //! # Generated Code //! -//! For `User` entity: +//! For `User` entity with handlers and security: //! //! ```rust,ignore +//! /// OpenAPI modifier for User entity. +//! struct UserApiModifier; +//! +//! impl utoipa::Modify for UserApiModifier { +//! fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) { +//! // Add security schemes +//! // Add CRUD paths with documentation +//! } +//! } +//! //! /// OpenAPI documentation for User entity endpoints. //! #[derive(utoipa::OpenApi)] //! #[openapi( -//! paths(register_user, update_email_user), -//! components(schemas(User, RegisterUser, UpdateEmailUser)), +//! components(schemas(UserResponse, CreateUserRequest, UpdateUserRequest)), +//! modifiers(&UserApiModifier), //! tags((name = "Users", description = "User management")) //! )] //! pub struct UserApi; @@ -27,10 +37,12 @@ use quote::{format_ident, quote}; use crate::entity::parse::{CommandDef, EntityDef}; -/// Generate the OpenAPI struct. +/// Generate the OpenAPI struct with modifier. pub fn generate(entity: &EntityDef) -> TokenStream { - let commands = entity.command_defs(); - if commands.is_empty() { + let has_crud = entity.api_config().has_handlers(); + let has_commands = !entity.command_defs().is_empty(); + + if !has_crud && !has_commands { return TokenStream::new(); } @@ -38,27 +50,18 @@ pub fn generate(entity: &EntityDef) -> TokenStream { let entity_name = entity.name(); let api_config = entity.api_config(); - // OpenApi struct name: UserApi let api_struct = format_ident!("{}Api", entity_name); + let modifier_struct = format_ident!("{}ApiModifier", entity_name); - // Tag for OpenAPI grouping let tag = api_config.tag_or_default(&entity.name_str()); - - // Tag description: explicit > entity doc comment > default let tag_description = api_config .tag_description .clone() .or_else(|| entity.doc().map(String::from)) .unwrap_or_else(|| format!("{} management", entity_name)); - // Handler function names for paths - let handler_names = generate_handler_names(entity, commands); - - // Schema types (entity + all command structs) - let schema_types = generate_schema_types(entity, commands); - - // Security schemes - let security_schemes = generate_security_schemes(api_config.security.as_deref()); + let schema_types = generate_all_schema_types(entity); + let modifier_impl = generate_modifier(entity, &modifier_struct); let doc = format!( "OpenAPI documentation for {} entity endpoints.\n\n\ @@ -70,89 +73,850 @@ pub fn generate(entity: &EntityDef) -> TokenStream { entity_name, api_struct ); - if security_schemes.is_empty() { - quote! { - #[doc = #doc] - #[derive(utoipa::OpenApi)] - #[openapi( - paths(#handler_names), - components(schemas(#schema_types)), - tags((name = #tag, description = #tag_description)) - )] - #vis struct #api_struct; + quote! { + #modifier_impl + + #[doc = #doc] + #[derive(utoipa::OpenApi)] + #[openapi( + components(schemas(#schema_types)), + modifiers(&#modifier_struct), + tags((name = #tag, description = #tag_description)) + )] + #vis struct #api_struct; + } +} + +/// Generate all schema types (DTOs, commands). +/// +/// Only registers schemas for enabled handlers to keep OpenAPI spec clean. +fn generate_all_schema_types(entity: &EntityDef) -> TokenStream { + let entity_name_str = entity.name_str(); + let mut types: Vec = Vec::new(); + + // CRUD DTOs - only include schemas for enabled handlers + let handlers = entity.api_config().handlers(); + if handlers.any() { + // Response is always needed (for get, list, create, update responses) + let response = entity.ident_with("", "Response"); + types.push(quote! { #response }); + + // CreateRequest only if create handler is enabled + if handlers.create { + let create = entity.ident_with("Create", "Request"); + types.push(quote! { #create }); + } + + // UpdateRequest only if update handler is enabled + if handlers.update { + let update = entity.ident_with("Update", "Request"); + types.push(quote! { #update }); } + } + + // Command structs + for cmd in entity.command_defs() { + let cmd_struct = cmd.struct_name(&entity_name_str); + types.push(quote! { #cmd_struct }); + } + + quote! { #(#types),* } +} + +/// Generate the modifier struct with Modify implementation. +/// +/// This adds security schemes, common schemas, CRUD paths, and info to the +/// OpenAPI spec. +fn generate_modifier(entity: &EntityDef, modifier_name: &syn::Ident) -> TokenStream { + let entity_name = entity.name(); + let api_config = entity.api_config(); + + let info_code = generate_info_code(entity); + let security_code = generate_security_code(api_config.security.as_deref()); + let common_schemas_code = if api_config.has_handlers() { + generate_common_schemas_code() } else { - quote! { - #[doc = #doc] - #[derive(utoipa::OpenApi)] - #[openapi( - paths(#handler_names), - components( - schemas(#schema_types), - #security_schemes - ), - tags((name = #tag, description = #tag_description)) - )] - #vis struct #api_struct; + TokenStream::new() + }; + let paths_code = if api_config.has_handlers() { + generate_paths_code(entity) + } else { + TokenStream::new() + }; + + let doc = format!("OpenAPI modifier for {} entity.", entity_name); + + quote! { + #[doc = #doc] + struct #modifier_name; + + impl utoipa::Modify for #modifier_name { + fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) { + use utoipa::openapi::*; + + #info_code + #security_code + #common_schemas_code + #paths_code + } } } } -/// Generate comma-separated handler function names. -fn generate_handler_names(entity: &EntityDef, commands: &[CommandDef]) -> TokenStream { - let names: Vec = commands - .iter() - .map(|cmd| handler_function_name(entity, cmd)) - .collect(); +/// Generate code to configure OpenAPI info section. +/// +/// Sets title, description, version, license, and contact information. +fn generate_info_code(entity: &EntityDef) -> TokenStream { + let api_config = entity.api_config(); + + // Title: use configured or generate default + let title_code = if let Some(ref title) = api_config.title { + quote! { openapi.info.title = #title.to_string(); } + } else { + TokenStream::new() + }; + + // Description + let description_code = if let Some(ref description) = api_config.description { + quote! { openapi.info.description = Some(#description.to_string()); } + } else if let Some(doc) = entity.doc() { + // Use entity doc comment if no description configured + quote! { openapi.info.description = Some(#doc.to_string()); } + } else { + TokenStream::new() + }; - quote! { #(#names),* } + // API Version + let version_code = if let Some(ref version) = api_config.api_version { + quote! { openapi.info.version = #version.to_string(); } + } else { + TokenStream::new() + }; + + // License + let license_code = match (&api_config.license, &api_config.license_url) { + (Some(name), Some(url)) => { + quote! { + openapi.info.license = Some( + info::LicenseBuilder::new() + .name(#name) + .url(Some(#url)) + .build() + ); + } + } + (Some(name), None) => { + quote! { + openapi.info.license = Some( + info::LicenseBuilder::new() + .name(#name) + .build() + ); + } + } + _ => TokenStream::new() + }; + + // Contact + let has_contact = api_config.contact_name.is_some() + || api_config.contact_email.is_some() + || api_config.contact_url.is_some(); + + let contact_code = if has_contact { + let name = api_config.contact_name.as_deref().unwrap_or(""); + let email = api_config.contact_email.as_deref(); + let url = api_config.contact_url.as_deref(); + + let email_setter = if let Some(e) = email { + quote! { .email(Some(#e)) } + } else { + TokenStream::new() + }; + + let url_setter = if let Some(u) = url { + quote! { .url(Some(#u)) } + } else { + TokenStream::new() + }; + + quote! { + openapi.info.contact = Some( + info::ContactBuilder::new() + .name(Some(#name)) + #email_setter + #url_setter + .build() + ); + } + } else { + TokenStream::new() + }; + + // Deprecated flag + let deprecated_code = if api_config.is_deprecated() { + let version = api_config.deprecated_in.as_deref().unwrap_or("unknown"); + let msg = format!("Deprecated since {}", version); + quote! { + // Mark in description that API is deprecated + if let Some(ref desc) = openapi.info.description { + openapi.info.description = Some(format!("**DEPRECATED**: {}\n\n{}", #msg, desc)); + } else { + openapi.info.description = Some(format!("**DEPRECATED**: {}", #msg)); + } + } + } else { + TokenStream::new() + }; + + quote! { + #title_code + #description_code + #version_code + #license_code + #contact_code + #deprecated_code + } } -/// Generate comma-separated schema types. -fn generate_schema_types(entity: &EntityDef, commands: &[CommandDef]) -> TokenStream { - let entity_name = entity.name(); - let entity_name_str = entity.name_str(); +/// Generate common schemas (ErrorResponse, PaginationQuery) for the OpenAPI +/// spec. +fn generate_common_schemas_code() -> TokenStream { + quote! { + // Add ErrorResponse schema for error responses + if let Some(components) = openapi.components.as_mut() { + // ErrorResponse schema (RFC 7807 Problem Details) + let error_schema = schema::ObjectBuilder::new() + .schema_type(schema::Type::Object) + .title(Some("ErrorResponse")) + .description(Some("Error response following RFC 7807 Problem Details")) + .property("type", schema::ObjectBuilder::new() + .schema_type(schema::Type::String) + .description(Some("A URI reference that identifies the problem type")) + .example(Some(serde_json::json!("https://errors.example.com/not-found"))) + .build()) + .required("type") + .property("title", schema::ObjectBuilder::new() + .schema_type(schema::Type::String) + .description(Some("A short, human-readable summary of the problem")) + .example(Some(serde_json::json!("Resource not found"))) + .build()) + .required("title") + .property("status", schema::ObjectBuilder::new() + .schema_type(schema::Type::Integer) + .description(Some("HTTP status code")) + .example(Some(serde_json::json!(404))) + .build()) + .required("status") + .property("detail", schema::ObjectBuilder::new() + .schema_type(schema::Type::String) + .description(Some("A human-readable explanation specific to this occurrence")) + .example(Some(serde_json::json!("User with ID '123' was not found"))) + .build()) + .property("code", schema::ObjectBuilder::new() + .schema_type(schema::Type::String) + .description(Some("Application-specific error code")) + .example(Some(serde_json::json!("NOT_FOUND"))) + .build()) + .build(); - let command_structs: Vec = commands - .iter() - .map(|cmd| cmd.struct_name(&entity_name_str)) - .collect(); + components.schemas.insert("ErrorResponse".to_string(), error_schema.into()); - quote! { #entity_name, #(#command_structs),* } + // PaginationQuery schema + let pagination_schema = schema::ObjectBuilder::new() + .schema_type(schema::Type::Object) + .title(Some("PaginationQuery")) + .description(Some("Query parameters for paginated list endpoints")) + .property("limit", schema::ObjectBuilder::new() + .schema_type(schema::Type::Integer) + .description(Some("Maximum number of items to return")) + .default(Some(serde_json::json!(100))) + .minimum(Some(1.0)) + .maximum(Some(1000.0)) + .build()) + .property("offset", schema::ObjectBuilder::new() + .schema_type(schema::Type::Integer) + .description(Some("Number of items to skip for pagination")) + .default(Some(serde_json::json!(0))) + .minimum(Some(0.0)) + .build()) + .build(); + + components.schemas.insert("PaginationQuery".to_string(), pagination_schema.into()); + } + } } -/// Generate security schemes if configured. -fn generate_security_schemes(security: Option<&str>) -> TokenStream { - match security { - Some("bearer") => { +/// Generate security scheme code for the Modify implementation. +fn generate_security_code(security: Option<&str>) -> TokenStream { + let Some(security) = security else { + return TokenStream::new(); + }; + + let (scheme_name, scheme_impl) = match security { + "cookie" => ( + "cookieAuth", quote! { - security_schemes( - ("bearer_auth" = ( - ty = Http, - scheme = "bearer", - bearer_format = "JWT" - )) + security::SecurityScheme::ApiKey( + security::ApiKey::Cookie( + security::ApiKeyValue::with_description( + "token", + "JWT token stored in HTTP-only cookie" + ) + ) ) } - } - Some("api_key") => { + ), + "bearer" => ( + "bearerAuth", quote! { - security_schemes( - ("api_key" = ( - ty = ApiKey, - in = "header", - name = "X-API-Key" - )) + security::SecurityScheme::Http( + security::HttpBuilder::new() + .scheme(security::HttpAuthScheme::Bearer) + .bearer_format("JWT") + .description(Some("JWT token in Authorization header")) + .build() + ) + } + ), + "api_key" => ( + "apiKey", + quote! { + security::SecurityScheme::ApiKey( + security::ApiKey::Header( + security::ApiKeyValue::with_description( + "X-API-Key", + "API key for service-to-service authentication" + ) + ) ) } + ), + _ => return TokenStream::new() + }; + + quote! { + if let Some(components) = openapi.components.as_mut() { + components.add_security_scheme(#scheme_name, #scheme_impl); } - _ => TokenStream::new() } } -/// Get the handler function name. -fn handler_function_name(entity: &EntityDef, cmd: &CommandDef) -> syn::Ident { +/// Generate code to add CRUD paths to OpenAPI. +/// +/// Only generates paths for enabled handlers based on `HandlerConfig`. +fn generate_paths_code(entity: &EntityDef) -> TokenStream { + let api_config = entity.api_config(); + let handlers = api_config.handlers(); + let entity_name = entity.name(); + let entity_name_str = entity.name_str(); + let id_field = entity.id_field(); + let id_type = &id_field.ty; + + let tag = api_config.tag_or_default(&entity_name_str); + let collection_path = build_collection_path(entity); + let item_path = build_item_path(entity); + + let response_schema = entity.ident_with("", "Response"); + let create_schema = entity.ident_with("Create", "Request"); + let update_schema = entity.ident_with("Update", "Request"); + + // Schema names for $ref - just the name, not full path + let response_ref = response_schema.to_string(); + let create_ref = create_schema.to_string(); + let update_ref = update_schema.to_string(); + + // Security requirement + let security_req = if let Some(security) = &api_config.security { + let scheme_name = match security.as_str() { + "cookie" => "cookieAuth", + "bearer" => "bearerAuth", + "api_key" => "apiKey", + _ => "cookieAuth" + }; + quote! { + Some(vec![security::SecurityRequirement::new::<_, _, &str>(#scheme_name, [])]) + } + } else { + quote! { None } + }; + + // ID parameter type (only needed for item paths) + let needs_id_param = handlers.get || handlers.update || handlers.delete; + let id_type_str = quote!(#id_type).to_string().replace(' ', ""); + let id_schema_type = if id_type_str.contains("Uuid") { + quote! { + ObjectBuilder::new() + .schema_type(schema::Type::String) + .format(Some(schema::SchemaFormat::Custom("uuid".into()))) + .build() + } + } else { + quote! { + ObjectBuilder::new() + .schema_type(schema::Type::String) + .build() + } + }; + + let create_op_id = format!("create_{}", entity_name_str.to_case(Case::Snake)); + let get_op_id = format!("get_{}", entity_name_str.to_case(Case::Snake)); + let update_op_id = format!("update_{}", entity_name_str.to_case(Case::Snake)); + let delete_op_id = format!("delete_{}", entity_name_str.to_case(Case::Snake)); + let list_op_id = format!("list_{}", entity_name_str.to_case(Case::Snake)); + + let create_summary = format!("Create a new {}", entity_name); + let get_summary = format!("Get {} by ID", entity_name); + let update_summary = format!("Update {} by ID", entity_name); + let delete_summary = format!("Delete {} by ID", entity_name); + let list_summary = format!("List all {}", entity_name); + + let create_desc = format!("Creates a new {} entity", entity_name); + let get_desc = format!("Retrieves a {} by its unique identifier", entity_name); + let update_desc = format!("Updates an existing {} by ID", entity_name); + let delete_desc = format!("Deletes a {} by ID", entity_name); + let list_desc = format!("Returns a paginated list of {} entities", entity_name); + + let id_param_desc = format!("{} unique identifier", entity_name); + let created_desc = format!("{} created successfully", entity_name); + let found_desc = format!("{} found", entity_name); + let updated_desc = format!("{} updated successfully", entity_name); + let deleted_desc = format!("{} deleted successfully", entity_name); + let list_desc_resp = format!("List of {} entities", entity_name); + let not_found_desc = format!("{} not found", entity_name); + + // Common code: error helper, params, security + let common_code = quote! { + // Helper to build error response with ErrorResponse schema + let error_response = |desc: &str| -> response::Response { + response::ResponseBuilder::new() + .description(desc) + .content("application/json", + content::ContentBuilder::new() + .schema(Some(Ref::from_schema_name("ErrorResponse"))) + .build() + ) + .build() + }; + + // Security requirements + let security_req: Option> = #security_req; + }; + + // ID parameter (only if needed) + let id_param_code = if needs_id_param { + quote! { + let id_param = path::ParameterBuilder::new() + .name("id") + .parameter_in(path::ParameterIn::Path) + .required(utoipa::openapi::Required::True) + .description(Some(#id_param_desc)) + .schema(Some(#id_schema_type)) + .build(); + } + } else { + TokenStream::new() + }; + + // CREATE handler + let create_code = if handlers.create { + quote! { + let create_op = { + let mut op = path::OperationBuilder::new() + .operation_id(Some(#create_op_id)) + .tag(#tag) + .summary(Some(#create_summary)) + .description(Some(#create_desc)) + .request_body(Some( + request_body::RequestBodyBuilder::new() + .description(Some("Request body")) + .required(Some(utoipa::openapi::Required::True)) + .content("application/json", + content::ContentBuilder::new() + .schema(Some(Ref::from_schema_name(#create_ref))) + .build() + ) + .build() + )) + .response("201", + response::ResponseBuilder::new() + .description(#created_desc) + .content("application/json", + content::ContentBuilder::new() + .schema(Some(Ref::from_schema_name(#response_ref))) + .build() + ) + .build() + ) + .response("400", error_response("Invalid request data")) + .response("500", error_response("Internal server error")); + if let Some(ref sec) = security_req { + op = op.securities(Some(sec.clone())) + .response("401", error_response("Authentication required")); + } + op.build() + }; + openapi.paths.add_path_operation(#collection_path, vec![path::HttpMethod::Post], create_op); + } + } else { + TokenStream::new() + }; + + // LIST handler + let list_code = if handlers.list { + quote! { + let limit_param = path::ParameterBuilder::new() + .name("limit") + .parameter_in(path::ParameterIn::Query) + .required(utoipa::openapi::Required::False) + .description(Some("Maximum number of items to return (default: 100)")) + .schema(Some(ObjectBuilder::new().schema_type(schema::Type::Integer).build())) + .build(); + + let offset_param = path::ParameterBuilder::new() + .name("offset") + .parameter_in(path::ParameterIn::Query) + .required(utoipa::openapi::Required::False) + .description(Some("Number of items to skip for pagination")) + .schema(Some(ObjectBuilder::new().schema_type(schema::Type::Integer).build())) + .build(); + + let list_op = { + let mut op = path::OperationBuilder::new() + .operation_id(Some(#list_op_id)) + .tag(#tag) + .summary(Some(#list_summary)) + .description(Some(#list_desc)) + .parameter(limit_param) + .parameter(offset_param) + .response("200", + response::ResponseBuilder::new() + .description(#list_desc_resp) + .content("application/json", + content::ContentBuilder::new() + .schema(Some( + schema::ArrayBuilder::new() + .items(Ref::from_schema_name(#response_ref)) + .build() + )) + .build() + ) + .build() + ) + .response("500", error_response("Internal server error")); + if let Some(ref sec) = security_req { + op = op.securities(Some(sec.clone())) + .response("401", error_response("Authentication required")); + } + op.build() + }; + openapi.paths.add_path_operation(#collection_path, vec![path::HttpMethod::Get], list_op); + } + } else { + TokenStream::new() + }; + + // GET handler + let get_code = if handlers.get { + quote! { + let get_op = { + let mut op = path::OperationBuilder::new() + .operation_id(Some(#get_op_id)) + .tag(#tag) + .summary(Some(#get_summary)) + .description(Some(#get_desc)) + .parameter(id_param.clone()) + .response("200", + response::ResponseBuilder::new() + .description(#found_desc) + .content("application/json", + content::ContentBuilder::new() + .schema(Some(Ref::from_schema_name(#response_ref))) + .build() + ) + .build() + ) + .response("404", error_response(#not_found_desc)) + .response("500", error_response("Internal server error")); + if let Some(ref sec) = security_req { + op = op.securities(Some(sec.clone())) + .response("401", error_response("Authentication required")); + } + op.build() + }; + openapi.paths.add_path_operation(#item_path, vec![path::HttpMethod::Get], get_op); + } + } else { + TokenStream::new() + }; + + // UPDATE handler + let update_code = if handlers.update { + quote! { + let update_op = { + let mut op = path::OperationBuilder::new() + .operation_id(Some(#update_op_id)) + .tag(#tag) + .summary(Some(#update_summary)) + .description(Some(#update_desc)) + .parameter(id_param.clone()) + .request_body(Some( + request_body::RequestBodyBuilder::new() + .description(Some("Fields to update")) + .required(Some(utoipa::openapi::Required::True)) + .content("application/json", + content::ContentBuilder::new() + .schema(Some(Ref::from_schema_name(#update_ref))) + .build() + ) + .build() + )) + .response("200", + response::ResponseBuilder::new() + .description(#updated_desc) + .content("application/json", + content::ContentBuilder::new() + .schema(Some(Ref::from_schema_name(#response_ref))) + .build() + ) + .build() + ) + .response("400", error_response("Invalid request data")) + .response("404", error_response(#not_found_desc)) + .response("500", error_response("Internal server error")); + if let Some(ref sec) = security_req { + op = op.securities(Some(sec.clone())) + .response("401", error_response("Authentication required")); + } + op.build() + }; + openapi.paths.add_path_operation(#item_path, vec![path::HttpMethod::Patch], update_op); + } + } else { + TokenStream::new() + }; + + // DELETE handler + let delete_code = if handlers.delete { + quote! { + let delete_op = { + let mut op = path::OperationBuilder::new() + .operation_id(Some(#delete_op_id)) + .tag(#tag) + .summary(Some(#delete_summary)) + .description(Some(#delete_desc)) + .parameter(id_param.clone()) + .response("204", + response::ResponseBuilder::new() + .description(#deleted_desc) + .build() + ) + .response("404", error_response(#not_found_desc)) + .response("500", error_response("Internal server error")); + if let Some(ref sec) = security_req { + op = op.securities(Some(sec.clone())) + .response("401", error_response("Authentication required")); + } + op.build() + }; + openapi.paths.add_path_operation(#item_path, vec![path::HttpMethod::Delete], delete_op); + } + } else { + TokenStream::new() + }; + + quote! { + #common_code + #id_param_code + #create_code + #list_code + #get_code + #update_code + #delete_code + } +} + +/// Build the collection path (e.g., `/users`). +fn build_collection_path(entity: &EntityDef) -> String { + let api_config = entity.api_config(); + let prefix = api_config.full_path_prefix(); + let entity_path = entity.name_str().to_case(Case::Kebab); + + let path = format!("{}/{}s", prefix, entity_path); + path.replace("//", "/") +} + +/// Build the item path (e.g., `/users/{id}`). +fn build_item_path(entity: &EntityDef) -> String { + let collection = build_collection_path(entity); + format!("{}/{{id}}", collection) +} + +/// Get command handler function name. +#[allow(dead_code)] +fn command_handler_name(entity: &EntityDef, cmd: &CommandDef) -> syn::Ident { let entity_snake = entity.name_str().to_case(Case::Snake); let cmd_snake = cmd.name.to_string().to_case(Case::Snake); format_ident!("{}_{}", cmd_snake, entity_snake) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn generate_crud_only() { + let input: syn::DeriveInput = syn::parse_quote! { + #[entity(table = "users", api(tag = "Users", handlers))] + pub struct User { + #[id] + pub id: uuid::Uuid, + #[field(create, update, response)] + pub name: String, + } + }; + let entity = EntityDef::from_derive_input(&input).unwrap(); + let tokens = generate(&entity); + let output = tokens.to_string(); + assert!(output.contains("UserApi")); + assert!(output.contains("UserApiModifier")); + assert!(output.contains("UserResponse")); + assert!(output.contains("CreateUserRequest")); + } + + #[test] + fn generate_with_security() { + let input: syn::DeriveInput = syn::parse_quote! { + #[entity(table = "users", api(tag = "Users", security = "bearer", handlers))] + pub struct User { + #[id] + pub id: uuid::Uuid, + } + }; + let entity = EntityDef::from_derive_input(&input).unwrap(); + let tokens = generate(&entity); + let output = tokens.to_string(); + assert!(output.contains("UserApiModifier")); + assert!(output.contains("bearerAuth")); + } + + #[test] + fn generate_cookie_security() { + let input: syn::DeriveInput = syn::parse_quote! { + #[entity(table = "users", api(tag = "Users", security = "cookie", handlers))] + pub struct User { + #[id] + pub id: uuid::Uuid, + } + }; + let entity = EntityDef::from_derive_input(&input).unwrap(); + let tokens = generate(&entity); + let output = tokens.to_string(); + assert!(output.contains("cookieAuth")); + } + + #[test] + fn no_api_when_disabled() { + let input: syn::DeriveInput = syn::parse_quote! { + #[entity(table = "users")] + pub struct User { + #[id] + pub id: uuid::Uuid, + } + }; + let entity = EntityDef::from_derive_input(&input).unwrap(); + let tokens = generate(&entity); + assert!(tokens.is_empty()); + } + + #[test] + fn collection_path_format() { + let input: syn::DeriveInput = syn::parse_quote! { + #[entity(table = "users", api(tag = "Users", handlers))] + pub struct User { + #[id] + pub id: uuid::Uuid, + } + }; + let entity = EntityDef::from_derive_input(&input).unwrap(); + let path = build_collection_path(&entity); + assert_eq!(path, "/users"); + } + + #[test] + fn item_path_format() { + let input: syn::DeriveInput = syn::parse_quote! { + #[entity(table = "users", api(tag = "Users", handlers))] + pub struct User { + #[id] + pub id: uuid::Uuid, + } + }; + let entity = EntityDef::from_derive_input(&input).unwrap(); + let path = build_item_path(&entity); + assert_eq!(path, "/users/{id}"); + } + + #[test] + fn selective_handlers_schemas_get_list_only() { + let input: syn::DeriveInput = syn::parse_quote! { + #[entity(table = "users", api(tag = "Users", handlers(get, list)))] + pub struct User { + #[id] + pub id: uuid::Uuid, + #[field(create, update, response)] + pub name: String, + } + }; + let entity = EntityDef::from_derive_input(&input).unwrap(); + let tokens = generate(&entity); + let output = tokens.to_string(); + // Response should be included (needed for get/list) + assert!(output.contains("UserResponse")); + // CreateRequest should NOT be included (no create handler) + assert!(!output.contains("CreateUserRequest")); + // UpdateRequest should NOT be included (no update handler) + assert!(!output.contains("UpdateUserRequest")); + } + + #[test] + fn selective_handlers_schemas_create_only() { + let input: syn::DeriveInput = syn::parse_quote! { + #[entity(table = "users", api(tag = "Users", handlers(create)))] + pub struct User { + #[id] + pub id: uuid::Uuid, + #[field(create, update, response)] + pub name: String, + } + }; + let entity = EntityDef::from_derive_input(&input).unwrap(); + let tokens = generate(&entity); + let output = tokens.to_string(); + // Response should be included (create returns response) + assert!(output.contains("UserResponse")); + // CreateRequest should be included + assert!(output.contains("CreateUserRequest")); + // UpdateRequest should NOT be included + assert!(!output.contains("UpdateUserRequest")); + } + + #[test] + fn selective_handlers_all_schemas() { + let input: syn::DeriveInput = syn::parse_quote! { + #[entity(table = "users", api(tag = "Users", handlers(create, update)))] + pub struct User { + #[id] + pub id: uuid::Uuid, + #[field(create, update, response)] + pub name: String, + } + }; + let entity = EntityDef::from_derive_input(&input).unwrap(); + let tokens = generate(&entity); + let output = tokens.to_string(); + // All schemas should be included + assert!(output.contains("UserResponse")); + assert!(output.contains("CreateUserRequest")); + assert!(output.contains("UpdateUserRequest")); + } +} diff --git a/crates/entity-derive-impl/src/entity/api/router.rs b/crates/entity-derive-impl/src/entity/api/router.rs index 05b4eb2..f0445a7 100644 --- a/crates/entity-derive-impl/src/entity/api/router.rs +++ b/crates/entity-derive-impl/src/entity/api/router.rs @@ -3,22 +3,38 @@ //! Router factory generation. //! -//! Generates a function that creates an axum Router with all entity endpoints. +//! Generates functions that create axum Routers for entity endpoints. //! -//! # Generated Code +//! # Generated Routers //! -//! For `User` entity with Register and UpdateEmail commands: +//! | Configuration | Generated Function | Type Parameter | +//! |---------------|-------------------|----------------| +//! | `handlers` | `{entity}_router` | Repository trait | +//! | `commands` | `{entity}_commands_router` | CommandHandler trait | +//! +//! # Example +//! +//! For `User` entity with both handlers and commands: //! //! ```rust,ignore -//! /// Create router for User entity endpoints. -//! pub fn user_router() -> axum::Router +//! // CRUD router +//! pub fn user_router() -> axum::Router> +//! where +//! R: UserRepository + 'static, +//! { +//! axum::Router::new() +//! .route("/users", post(create_user::).get(list_user::)) +//! .route("/users/:id", get(get_user::).patch(update_user::).delete(delete_user::)) +//! } +//! +//! // Commands router +//! pub fn user_commands_router() -> axum::Router //! where //! H: UserCommandHandler + 'static, //! H::Context: Default, //! { //! axum::Router::new() -//! .route("/api/v1/users/register", axum::routing::post(register_user::)) -//! .route("/api/v1/users/:id/update-email", axum::routing::put(update_email_user::)) +//! .route("/users/register", post(register_user::)) //! } //! ``` @@ -28,10 +44,20 @@ use quote::{format_ident, quote}; use crate::entity::parse::{CommandDef, CommandKindHint, EntityDef}; -/// Generate the router factory function. +/// Generate all router factory functions. pub fn generate(entity: &EntityDef) -> TokenStream { - let commands = entity.command_defs(); - if commands.is_empty() { + let crud_router = generate_crud_router(entity); + let commands_router = generate_commands_router(entity); + + quote! { + #crud_router + #commands_router + } +} + +/// Generate CRUD router for repository-based handlers. +fn generate_crud_router(entity: &EntityDef) -> TokenStream { + if !entity.api_config().has_handlers() { return TokenStream::new(); } @@ -40,17 +66,131 @@ pub fn generate(entity: &EntityDef) -> TokenStream { let entity_name_str = entity.name_str(); let entity_snake = entity_name_str.to_case(Case::Snake); - // Router function name: user_router let router_fn = format_ident!("{}_router", entity_snake); + let repo_trait = format_ident!("{}Repository", entity_name); + + let crud_routes = generate_crud_routes(entity); + + let doc = format!( + "Create axum router for {} CRUD endpoints.\n\n\ + # Usage\n\n\ + ```rust,ignore\n\ + let pool = Arc::new(PgPool::connect(url).await?);\n\ + let app = Router::new()\n\ + .merge({}::())\n\ + .with_state(pool);\n\ + ```", + entity_name, router_fn + ); + + quote! { + #[doc = #doc] + #vis fn #router_fn() -> axum::Router> + where + R: #repo_trait + 'static, + { + axum::Router::new() + #crud_routes + } + } +} + +/// Generate CRUD route definitions based on enabled handlers. +fn generate_crud_routes(entity: &EntityDef) -> TokenStream { + let handlers = entity.api_config().handlers(); + let snake = entity.name_str().to_case(Case::Snake); + let collection_path = build_crud_collection_path(entity); + let item_path = build_crud_item_path(entity); + + let create_handler = format_ident!("create_{}", snake); + let get_handler = format_ident!("get_{}", snake); + let update_handler = format_ident!("update_{}", snake); + let delete_handler = format_ident!("delete_{}", snake); + let list_handler = format_ident!("list_{}", snake); + + // Build collection route methods (POST, GET) + let mut collection_methods = Vec::new(); + if handlers.create { + collection_methods.push(quote! { post(#create_handler::) }); + } + if handlers.list { + collection_methods.push(quote! { get(#list_handler::) }); + } + + // Build item route methods (GET, PATCH, DELETE) + let mut item_methods = Vec::new(); + if handlers.get { + item_methods.push(quote! { get(#get_handler::) }); + } + if handlers.update { + item_methods.push(quote! { patch(#update_handler::) }); + } + if handlers.delete { + item_methods.push(quote! { delete(#delete_handler::) }); + } + + // Generate routes only for non-empty method lists + let collection_route = if !collection_methods.is_empty() { + let first = &collection_methods[0]; + let rest: Vec<_> = collection_methods.iter().skip(1).collect(); + quote! { + .route(#collection_path, axum::routing::#first #(.#rest)*) + } + } else { + TokenStream::new() + }; + + let item_route = if !item_methods.is_empty() { + let first = &item_methods[0]; + let rest: Vec<_> = item_methods.iter().skip(1).collect(); + quote! { + .route(#item_path, axum::routing::#first #(.#rest)*) + } + } else { + TokenStream::new() + }; + + quote! { + #collection_route + #item_route + } +} + +/// Build CRUD collection path (e.g., `/api/v1/users`). +fn build_crud_collection_path(entity: &EntityDef) -> String { + let api_config = entity.api_config(); + let prefix = api_config.full_path_prefix(); + let entity_path = entity.name_str().to_case(Case::Kebab); + + let path = format!("{}/{}s", prefix, entity_path); + path.replace("//", "/") +} - // Handler trait name: UserCommandHandler +/// Build CRUD item path (e.g., `/api/v1/users/{id}`). +fn build_crud_item_path(entity: &EntityDef) -> String { + let collection = build_crud_collection_path(entity); + format!("{}/{{id}}", collection) +} + +/// Generate commands router for command handler. +fn generate_commands_router(entity: &EntityDef) -> TokenStream { + let commands = entity.command_defs(); + if commands.is_empty() { + return TokenStream::new(); + } + + let vis = &entity.vis; + let entity_name = entity.name(); + let entity_name_str = entity.name_str(); + let entity_snake = entity_name_str.to_case(Case::Snake); + + let router_fn = format_ident!("{}_commands_router", entity_snake); let handler_trait = format_ident!("{}CommandHandler", entity_name); - // Generate route definitions - let routes = generate_routes(entity, commands); + let routes = generate_command_routes(entity, commands); let doc = format!( - "Create axum router for {} entity endpoints.\n\n\ + "Create axum router for {} command endpoints.\n\n\ # Usage\n\n\ ```rust,ignore\n\ let handler = Arc::new(MyHandler::new());\n\ @@ -74,20 +214,20 @@ pub fn generate(entity: &EntityDef) -> TokenStream { } } -/// Generate all route definitions. -fn generate_routes(entity: &EntityDef, commands: &[CommandDef]) -> TokenStream { +/// Generate command route definitions. +fn generate_command_routes(entity: &EntityDef, commands: &[CommandDef]) -> TokenStream { let routes: Vec = commands .iter() - .map(|cmd| generate_route(entity, cmd)) + .map(|cmd| generate_command_route(entity, cmd)) .collect(); quote! { #(#routes)* } } -/// Generate a single route definition. -fn generate_route(entity: &EntityDef, cmd: &CommandDef) -> TokenStream { - let path = build_axum_path(entity, cmd); - let handler_name = handler_function_name(entity, cmd); +/// Generate a single command route definition. +fn generate_command_route(entity: &EntityDef, cmd: &CommandDef) -> TokenStream { + let path = build_command_path(entity, cmd); + let handler_name = command_handler_name(entity, cmd); let method = axum_method_for_command(cmd); quote! { @@ -95,31 +235,30 @@ fn generate_route(entity: &EntityDef, cmd: &CommandDef) -> TokenStream { } } -/// Build the axum-style path (uses :id instead of {id}). -fn build_axum_path(entity: &EntityDef, cmd: &CommandDef) -> String { +/// Build command path (e.g., `/users/{id}/activate`). +fn build_command_path(entity: &EntityDef, cmd: &CommandDef) -> String { let api_config = entity.api_config(); let prefix = api_config.full_path_prefix(); let entity_path = entity.name_str().to_case(Case::Kebab); let cmd_path = cmd.name.to_string().to_case(Case::Kebab); let path = if cmd.requires_id { - format!("{}/{}s/:id/{}", prefix, entity_path, cmd_path) + format!("{}/{}s/{{id}}/{}", prefix, entity_path, cmd_path) } else { format!("{}/{}s/{}", prefix, entity_path, cmd_path) }; - // Normalize double slashes that can appear when prefix is empty path.replace("//", "/") } -/// Get the handler function name. -fn handler_function_name(entity: &EntityDef, cmd: &CommandDef) -> syn::Ident { +/// Get command handler function name. +fn command_handler_name(entity: &EntityDef, cmd: &CommandDef) -> syn::Ident { let entity_snake = entity.name_str().to_case(Case::Snake); let cmd_snake = cmd.name.to_string().to_case(Case::Snake); format_ident!("{}_{}", cmd_snake, entity_snake) } -/// Get the axum routing method for a command. +/// Get axum routing method for a command. fn axum_method_for_command(cmd: &CommandDef) -> syn::Ident { match cmd.kind { CommandKindHint::Create => format_ident!("post"), @@ -128,3 +267,50 @@ fn axum_method_for_command(cmd: &CommandDef) -> syn::Ident { CommandKindHint::Custom => format_ident!("post") } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn crud_collection_path() { + let input: syn::DeriveInput = syn::parse_quote! { + #[entity(table = "users", api(tag = "Users", handlers))] + pub struct User { + #[id] + pub id: uuid::Uuid, + } + }; + let entity = EntityDef::from_derive_input(&input).unwrap(); + let path = build_crud_collection_path(&entity); + assert_eq!(path, "/users"); + } + + #[test] + fn crud_item_path() { + let input: syn::DeriveInput = syn::parse_quote! { + #[entity(table = "users", api(tag = "Users", handlers))] + pub struct User { + #[id] + pub id: uuid::Uuid, + } + }; + let entity = EntityDef::from_derive_input(&input).unwrap(); + let path = build_crud_item_path(&entity); + assert_eq!(path, "/users/{id}"); + } + + #[test] + fn crud_path_with_prefix() { + let input: syn::DeriveInput = syn::parse_quote! { + #[entity(table = "users", api(tag = "Users", path_prefix = "/api/v1", handlers))] + pub struct User { + #[id] + pub id: uuid::Uuid, + } + }; + let entity = EntityDef::from_derive_input(&input).unwrap(); + let path = build_crud_collection_path(&entity); + assert_eq!(path, "/api/v1/users"); + } +} diff --git a/crates/entity-derive-impl/src/entity/parse/api.rs b/crates/entity-derive-impl/src/entity/parse/api.rs index c0b258c..ec7d419 100644 --- a/crates/entity-derive-impl/src/entity/parse/api.rs +++ b/crates/entity-derive-impl/src/entity/parse/api.rs @@ -30,6 +30,44 @@ use syn::Ident; +/// Handler configuration for selective CRUD generation. +/// +/// # Syntax +/// +/// - `handlers` - enables all handlers +/// - `handlers(create, get, list)` - enables specific handlers +#[derive(Debug, Clone, Default)] +pub struct HandlerConfig { + /// Generate create handler (POST /collection). + pub create: bool, + /// Generate get handler (GET /collection/{id}). + pub get: bool, + /// Generate update handler (PATCH /collection/{id}). + pub update: bool, + /// Generate delete handler (DELETE /collection/{id}). + pub delete: bool, + /// Generate list handler (GET /collection). + pub list: bool +} + +impl HandlerConfig { + /// Create config with all handlers enabled. + pub fn all() -> Self { + Self { + create: true, + get: true, + update: true, + delete: true, + list: true + } + } + + /// Check if any handler is enabled. + pub fn any(&self) -> bool { + self.create || self.get || self.update || self.delete || self.list + } +} + /// API configuration parsed from `#[entity(api(...))]`. /// /// Controls HTTP handler generation and OpenAPI documentation. @@ -73,7 +111,58 @@ pub struct ApiConfig { /// Version in which this API is deprecated. /// /// Marks all endpoints with `deprecated = true` in OpenAPI. - pub deprecated_in: Option + pub deprecated_in: Option, + + /// CRUD handlers configuration. + /// + /// Controls which handlers to generate: + /// - `handlers` - all handlers + /// - `handlers(create, get, list)` - specific handlers only + pub handlers: HandlerConfig, + + /// OpenAPI info: API title. + /// + /// Overrides the default title in OpenAPI spec. + /// Example: `"User Service API"` + pub title: Option, + + /// OpenAPI info: API description. + /// + /// Full description for the API, supports Markdown. + /// Example: `"RESTful API for user management"` + pub description: Option, + + /// OpenAPI info: API version. + /// + /// Semantic version string for the API. + /// Example: `"1.0.0"` + pub api_version: Option, + + /// OpenAPI info: License name. + /// + /// License under which the API is published. + /// Example: `"MIT"`, `"Apache-2.0"` + pub license: Option, + + /// OpenAPI info: License URL. + /// + /// URL to the license text. + pub license_url: Option, + + /// OpenAPI info: Contact name. + /// + /// Name of the API maintainer or team. + pub contact_name: Option, + + /// OpenAPI info: Contact email. + /// + /// Email for API support inquiries. + pub contact_email: Option, + + /// OpenAPI info: Contact URL. + /// + /// URL to API support or documentation. + pub contact_url: Option } impl ApiConfig { @@ -121,6 +210,16 @@ impl ApiConfig { self.deprecated_in.is_some() } + /// Check if any CRUD handler should be generated. + pub fn has_handlers(&self) -> bool { + self.handlers.any() + } + + /// Get handler configuration. + pub fn handlers(&self) -> &HandlerConfig { + &self.handlers + } + /// Get security scheme for a command. /// /// Returns `None` for public commands, otherwise the default security. @@ -209,12 +308,85 @@ pub fn parse_api_config(meta: &syn::Meta) -> syn::Result { let value: syn::LitStr = nested.value()?.parse()?; config.deprecated_in = Some(value.value()); } + "handlers" => { + // Support: + // - `handlers` - all handlers + // - `handlers = true/false` - all or none + // - `handlers(create, get, list)` - specific handlers + if nested.input.peek(syn::Token![=]) { + let _: syn::Token![=] = nested.input.parse()?; + let value: syn::LitBool = nested.input.parse()?; + if value.value() { + config.handlers = HandlerConfig::all(); + } + } else if nested.input.peek(syn::token::Paren) { + let content; + syn::parenthesized!(content in nested.input); + let handlers = + syn::punctuated::Punctuated::::parse_terminated( + &content + )?; + for handler in handlers { + match handler.to_string().as_str() { + "create" => config.handlers.create = true, + "get" => config.handlers.get = true, + "update" => config.handlers.update = true, + "delete" => config.handlers.delete = true, + "list" => config.handlers.list = true, + other => { + return Err(syn::Error::new( + handler.span(), + format!( + "unknown handler '{}', expected: create, get, update, delete, list", + other + ) + )); + } + } + } + } else { + config.handlers = HandlerConfig::all(); + } + } + "title" => { + let value: syn::LitStr = nested.value()?.parse()?; + config.title = Some(value.value()); + } + "description" => { + let value: syn::LitStr = nested.value()?.parse()?; + config.description = Some(value.value()); + } + "api_version" => { + let value: syn::LitStr = nested.value()?.parse()?; + config.api_version = Some(value.value()); + } + "license" => { + let value: syn::LitStr = nested.value()?.parse()?; + config.license = Some(value.value()); + } + "license_url" => { + let value: syn::LitStr = nested.value()?.parse()?; + config.license_url = Some(value.value()); + } + "contact_name" => { + let value: syn::LitStr = nested.value()?.parse()?; + config.contact_name = Some(value.value()); + } + "contact_email" => { + let value: syn::LitStr = nested.value()?.parse()?; + config.contact_email = Some(value.value()); + } + "contact_url" => { + let value: syn::LitStr = nested.value()?.parse()?; + config.contact_url = Some(value.value()); + } _ => { return Err(syn::Error::new( ident.span(), format!( "unknown api option '{}', expected: tag, tag_description, path_prefix, \ - security, public, version, deprecated_in", + security, public, version, deprecated_in, handlers, title, description, \ + api_version, license, license_url, contact_name, contact_email, contact_url", ident_str ) )); @@ -343,4 +515,62 @@ mod tests { }; assert_eq!(config.full_path_prefix(), "/api/v1"); } + + #[test] + fn parse_handlers_flag() { + let config = parse_test_config(r#"api(tag = "Users", handlers)"#); + assert!(config.has_handlers()); + } + + #[test] + fn parse_handlers_true() { + let config = parse_test_config(r#"api(tag = "Users", handlers = true)"#); + assert!(config.has_handlers()); + } + + #[test] + fn parse_handlers_false() { + let config = parse_test_config(r#"api(tag = "Users", handlers = false)"#); + assert!(!config.has_handlers()); + } + + #[test] + fn default_handlers_false() { + let config = parse_test_config(r#"api(tag = "Users")"#); + assert!(!config.has_handlers()); + } + + #[test] + fn parse_handlers_selective() { + let config = parse_test_config(r#"api(tag = "Users", handlers(create, get, list))"#); + assert!(config.has_handlers()); + assert!(config.handlers().create); + assert!(config.handlers().get); + assert!(!config.handlers().update); + assert!(!config.handlers().delete); + assert!(config.handlers().list); + } + + #[test] + fn parse_handlers_single() { + let config = parse_test_config(r#"api(tag = "Users", handlers(get))"#); + assert!(config.has_handlers()); + assert!(!config.handlers().create); + assert!(config.handlers().get); + assert!(!config.handlers().update); + assert!(!config.handlers().delete); + assert!(!config.handlers().list); + } + + #[test] + fn parse_handlers_all_explicit() { + let config = parse_test_config( + r#"api(tag = "Users", handlers(create, get, update, delete, list))"# + ); + assert!(config.handlers().create); + assert!(config.handlers().get); + assert!(config.handlers().update); + assert!(config.handlers().delete); + assert!(config.handlers().list); + } } diff --git a/examples/basic/Cargo.toml b/examples/basic/Cargo.toml index f200972..5fb2d38 100644 --- a/examples/basic/Cargo.toml +++ b/examples/basic/Cargo.toml @@ -16,6 +16,7 @@ validate = [] [dependencies] entity-derive = { path = "../../crates/entity-derive", features = ["postgres", "api"] } +masterror = { version = "0.27", features = ["axum", "openapi"] } axum = "0.8" tokio = { version = "1", features = ["full"] } sqlx = { version = "0.8", features = ["runtime-tokio", "postgres", "uuid", "chrono"] } diff --git a/examples/basic/src/main.rs b/examples/basic/src/main.rs index 55ce2c9..f71c5db 100644 --- a/examples/basic/src/main.rs +++ b/examples/basic/src/main.rs @@ -1,38 +1,59 @@ // SPDX-FileCopyrightText: 2025-2026 RAprogramm // SPDX-License-Identifier: MIT -//! Axum CRUD Example with entity-derive +//! Basic CRUD Example with Generated Handlers //! //! Demonstrates full CRUD operations using: -//! - entity-derive for code generation +//! - entity-derive for code generation including HTTP handlers //! - Axum for HTTP routing //! - sqlx for PostgreSQL access //! - utoipa for OpenAPI docs +//! +//! Key features: +//! - `api(tag = "Users", handlers)` generates CRUD handlers automatically +//! - `user_router()` provides ready-to-use axum Router +//! - `UserApi` provides OpenAPI documentation use std::sync::Arc; -use axum::{ - Json, Router, - extract::{Path, Query, State}, - http::StatusCode, - response::IntoResponse, - routing::{get, post}, -}; +use axum::Router; use chrono::{DateTime, Utc}; use entity_derive::Entity; -use serde::Deserialize; use sqlx::PgPool; use utoipa::OpenApi; use utoipa_swagger_ui::SwaggerUi; use uuid::Uuid; // ============================================================================ -// Entity Definition +// Entity Definition with Generated API // ============================================================================ -/// User entity with full CRUD support. +/// User entity with full CRUD support and cookie authentication. +/// +/// The `api(tag = "Users", security = "cookie", handlers)` attribute generates: +/// - `create_user()` - POST /users (requires auth) +/// - `get_user()` - GET /users/{id} (requires auth) +/// - `update_user()` - PATCH /users/{id} (requires auth) +/// - `delete_user()` - DELETE /users/{id} (requires auth) +/// - `list_user()` - GET /users (requires auth) +/// - `user_router()` - axum Router with all routes +/// - `UserApi` - OpenAPI documentation with security scheme #[derive(Debug, Clone, Entity)] -#[entity(table = "users", schema = "public")] +#[entity( + table = "users", + schema = "public", + api( + tag = "Users", + security = "cookie", + handlers, + title = "User Service API", + description = "RESTful API for user management with cookie-based authentication", + api_version = "1.0.0", + license = "MIT", + contact_name = "API Support", + contact_email = "support@example.com" + ) +)] pub struct User { /// Unique identifier (UUID v7). #[id] @@ -61,203 +82,25 @@ pub struct User { pub updated_at: DateTime, } -// ============================================================================ -// Application State -// ============================================================================ - -#[derive(Clone)] -struct AppState { - pool: Arc, -} - -impl AppState { - fn new(pool: PgPool) -> Self { - Self { - pool: Arc::new(pool), - } - } - - fn repo(&self) -> &PgPool { - &self.pool - } -} - -// ============================================================================ -// Query Parameters -// ============================================================================ - -#[derive(Debug, Deserialize)] -struct ListParams { - #[serde(default = "default_limit")] - limit: i64, - #[serde(default)] - offset: i64, -} - -fn default_limit() -> i64 { - 20 -} - -// ============================================================================ -// Error Handling -// ============================================================================ - -enum AppError { - NotFound, - Database(sqlx::Error), -} - -impl From for AppError { - fn from(err: sqlx::Error) -> Self { - match err { - sqlx::Error::RowNotFound => Self::NotFound, - _ => Self::Database(err), - } - } -} - -impl IntoResponse for AppError { - fn into_response(self) -> axum::response::Response { - match self { - Self::NotFound => (StatusCode::NOT_FOUND, "Not found").into_response(), - Self::Database(e) => { - tracing::error!("Database error: {e}"); - (StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response() - } - } - } -} - -// ============================================================================ -// HTTP Handlers -// ============================================================================ - -/// Create a new user. -#[utoipa::path( - post, - path = "/users", - request_body = CreateUserRequest, - responses( - (status = 201, description = "User created", body = UserResponse), - (status = 500, description = "Internal error"), - ) -)] -async fn create_user( - State(state): State, - Json(dto): Json, -) -> Result { - let user = state.repo().create(dto).await?; - Ok((StatusCode::CREATED, Json(UserResponse::from(user)))) -} - -/// Get user by ID. -#[utoipa::path( - get, - path = "/users/{id}", - params(("id" = Uuid, Path, description = "User ID")), - responses( - (status = 200, description = "User found", body = UserResponse), - (status = 404, description = "User not found"), - ) -)] -async fn get_user( - State(state): State, - Path(id): Path, -) -> Result { - let user = state.repo().find_by_id(id).await?.ok_or(AppError::NotFound)?; - Ok(Json(UserResponse::from(user))) -} - -/// Update user by ID. -#[utoipa::path( - patch, - path = "/users/{id}", - params(("id" = Uuid, Path, description = "User ID")), - request_body = UpdateUserRequest, - responses( - (status = 200, description = "User updated", body = UserResponse), - (status = 404, description = "User not found"), - ) -)] -async fn update_user( - State(state): State, - Path(id): Path, - Json(dto): Json, -) -> Result { - let user = state.repo().update(id, dto).await?; - Ok(Json(UserResponse::from(user))) -} - -/// Delete user by ID. -#[utoipa::path( - delete, - path = "/users/{id}", - params(("id" = Uuid, Path, description = "User ID")), - responses( - (status = 204, description = "User deleted"), - (status = 404, description = "User not found"), - ) -)] -async fn delete_user( - State(state): State, - Path(id): Path, -) -> Result { - let deleted = state.repo().delete(id).await?; - if deleted { - Ok(StatusCode::NO_CONTENT) - } else { - Err(AppError::NotFound) - } -} - -/// List users with pagination. -#[utoipa::path( - get, - path = "/users", - params( - ("limit" = Option, Query, description = "Max results"), - ("offset" = Option, Query, description = "Skip results"), - ), - responses( - (status = 200, description = "List of users", body = Vec), - ) -)] -async fn list_users( - State(state): State, - Query(params): Query, -) -> Result { - let users = state.repo().list(params.limit, params.offset).await?; - let responses: Vec = users.into_iter().map(UserResponse::from).collect(); - Ok(Json(responses)) -} - -// ============================================================================ -// OpenAPI Documentation -// ============================================================================ - -#[derive(OpenApi)] -#[openapi( - paths( - create_user, - get_user, - update_user, - delete_user, - list_users, - ), - components(schemas(CreateUserRequest, UpdateUserRequest, UserResponse)) -)] -struct ApiDoc; - // ============================================================================ // Router Setup // ============================================================================ -fn app(state: AppState) -> Router { +/// Create the application router. +/// +/// Uses the generated `user_router()` function which includes: +/// - POST /users - create user +/// - GET /users - list users +/// - GET /users/{id} - get user +/// - PATCH /users/{id} - update user +/// - DELETE /users/{id} - delete user +fn app(pool: Arc) -> Router { Router::new() - .route("/users", post(create_user).get(list_users)) - .route("/users/{id}", get(get_user).patch(update_user).delete(delete_user)) - .merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi())) - .with_state(state) + // Use the generated router for CRUD operations + .merge(user_router::()) + // Add Swagger UI using generated OpenAPI struct + .merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", UserApi::openapi())) + .with_state(pool) } // ============================================================================ @@ -267,13 +110,11 @@ fn app(state: AppState) -> Router { #[tokio::main] async fn main() { tracing_subscriber::fmt() - .with_env_filter("axum_crud_example=debug,tower_http=debug") + .with_env_filter("example_basic=debug,tower_http=debug") .init(); - let database_url = - std::env::var("DATABASE_URL").unwrap_or_else(|_| { - "postgres://postgres:postgres@localhost:5432/entity_example".to_string() - }); + let database_url = std::env::var("DATABASE_URL") + .unwrap_or_else(|_| "postgres://postgres:postgres@localhost:5432/entity_example".into()); let pool = PgPool::connect(&database_url) .await @@ -284,12 +125,18 @@ async fn main() { .await .expect("Failed to run migrations"); - let state = AppState::new(pool); - let app = app(state); + let app = app(Arc::new(pool)); let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); tracing::info!("Listening on http://localhost:3000"); tracing::info!("Swagger UI: http://localhost:3000/swagger-ui"); + tracing::info!(""); + tracing::info!("Try these endpoints:"); + tracing::info!(" POST /users - Create a user"); + tracing::info!(" GET /users - List users"); + tracing::info!(" GET /users/{{id}} - Get user by ID"); + tracing::info!(" PATCH /users/{{id}} - Update user"); + tracing::info!(" DELETE /users/{{id}} - Delete user"); axum::serve(listener, app).await.unwrap(); }