diff --git a/backend/app/core/exception_handlers.py b/backend/app/core/exception_handlers.py index f6166b1..3164d35 100644 --- a/backend/app/core/exception_handlers.py +++ b/backend/app/core/exception_handlers.py @@ -8,12 +8,45 @@ from app.utils import APIResponse +def _format_validation_errors(errors: list[dict]) -> str: + missing_fields = [] + invalid_fields = [] + + for err in errors: + loc = [x for x in err["loc"] if x != "body"] + + if not loc: + continue + + field = ".".join(str(x) for x in loc) + + if err["msg"] == "Field required": + missing_fields.append(field) + else: + invalid_fields.append(f"{field} ({err['msg']})") + + messages = [] + + if missing_fields: + messages.append( + f"Missing required field(s): {', '.join(missing_fields)}" + ) + + if invalid_fields: + messages.append( + f"Invalid field(s): {', '.join(invalid_fields)}" + ) + + return ". ".join(messages) + + def register_exception_handlers(app: FastAPI): @app.exception_handler(RequestValidationError) async def validation_error_handler(request: Request, exc: RequestValidationError): + formatted_message = _format_validation_errors(exc.errors()) return JSONResponse( status_code=HTTP_422_UNPROCESSABLE_ENTITY, - content=APIResponse.failure_response(exc.errors()).model_dump(), + content=APIResponse.failure_response(error=formatted_message).model_dump(), ) @app.exception_handler(HTTPException) diff --git a/backend/app/models/base_validator_config.py b/backend/app/models/base_validator_config.py index f4a2345..d9b051a 100644 --- a/backend/app/models/base_validator_config.py +++ b/backend/app/models/base_validator_config.py @@ -1,5 +1,6 @@ from guardrails import OnFailAction from guardrails.validators import Validator +from pydantic import ConfigDict from sqlmodel import SQLModel from app.core.enum import GuardrailOnFail @@ -13,9 +14,12 @@ } class BaseValidatorConfig(SQLModel): - on_fail: GuardrailOnFail = GuardrailOnFail.Fix + model_config = ConfigDict( + extra="forbid", + arbitrary_types_allowed=True + ) - model_config = {"arbitrary_types_allowed": True} + on_fail: GuardrailOnFail = GuardrailOnFail.Fix def resolve_on_fail(self): try: diff --git a/backend/app/models/guardrail_config.py b/backend/app/models/guardrail_config.py index bfe36a6..1da302d 100644 --- a/backend/app/models/guardrail_config.py +++ b/backend/app/models/guardrail_config.py @@ -1,6 +1,7 @@ from typing import Annotated, List, Optional, Union from uuid import UUID +from pydantic import ConfigDict from sqlmodel import Field, SQLModel # todo this could be improved by having some auto-discovery mechanism inside @@ -22,6 +23,7 @@ ] class GuardrailRequest(SQLModel): + model_config = ConfigDict(extra="forbid") request_id: str input: str validators: List[ValidatorConfigItem]