From c5ce1c2beb053016e391727152aae025b4c98d25 Mon Sep 17 00:00:00 2001 From: rkritika1508 Date: Mon, 9 Feb 2026 13:16:38 +0530 Subject: [PATCH 1/3] added custom flag in run_guardrails API --- backend/app/api/routes/guardrails.py | 14 ++++++++++---- .../app/tests/test_guardrails_api_integration.py | 4 ++-- backend/app/tests/utils/constants.py | 2 +- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/backend/app/api/routes/guardrails.py b/backend/app/api/routes/guardrails.py index 750ac71..d777120 100644 --- a/backend/app/api/routes/guardrails.py +++ b/backend/app/api/routes/guardrails.py @@ -1,9 +1,9 @@ -import uuid from uuid import UUID +import uuid from fastapi import APIRouter from guardrails.guard import Guard -from guardrails.validators import FailResult +from guardrails.validators import FailResult, PassResult from app.api.deps import AuthDep, SessionDep from app.core.constants import REPHRASE_ON_FAIL_PREFIX @@ -25,6 +25,7 @@ async def run_guardrails( payload: GuardrailRequest, session: SessionDep, _: AuthDep, + include_all_validator_logs: bool = False, ): request_log_crud = RequestLogCrud(session=session) validator_log_crud = ValidatorLogCrud(session=session) @@ -41,6 +42,7 @@ async def run_guardrails( request_log_crud, request_log.id, validator_log_crud, + include_all_validator_logs ) @router.get("/") @@ -73,6 +75,7 @@ async def _validate_with_guard( request_log_crud: RequestLogCrud, request_log_id: UUID, validator_log_crud: ValidatorLogCrud, + include_all_validator_logs: bool = False, ) -> APIResponse: """ Runs Guardrails validation on input/output data, persists request & validator logs, @@ -112,7 +115,7 @@ def _finalize( ) if guard is not None: - add_validator_logs(guard, request_log_id, validator_log_crud) + add_validator_logs(guard, request_log_id, validator_log_crud, include_all_validator_logs) rephrase_needed = ( validated_output is not None @@ -157,7 +160,7 @@ def _finalize( error_message=str(exc), ) -def add_validator_logs(guard: Guard, request_log_id: UUID, validator_log_crud: ValidatorLogCrud): +def add_validator_logs(guard: Guard, request_log_id: UUID, validator_log_crud: ValidatorLogCrud, include_all_validator_logs: bool = False): history = getattr(guard, "history", None) if not history: return @@ -174,6 +177,9 @@ def add_validator_logs(guard: Guard, request_log_id: UUID, validator_log_crud: V for log in iteration.outputs.validator_logs: result = log.validation_result + if include_all_validator_logs == False and isinstance(result, PassResult): + continue + error_message = None if isinstance(result, FailResult): error_message = result.error_message diff --git a/backend/app/tests/test_guardrails_api_integration.py b/backend/app/tests/test_guardrails_api_integration.py index 48ca1e8..485b532 100644 --- a/backend/app/tests/test_guardrails_api_integration.py +++ b/backend/app/tests/test_guardrails_api_integration.py @@ -70,7 +70,7 @@ def test_input_guardrails_with_lexical_slur(integration_client): body = response.json() assert body["success"] is True - assert body["data"][SAFE_TEXT_FIELD] == "this sentence contains [REDACTED_SLUR]" + assert body["data"][SAFE_TEXT_FIELD] == "this sentence contains [REDACTED_SLUR]." def test_input_guardrails_with_lexical_slur_clean_text(integration_client): @@ -123,7 +123,7 @@ def test_input_guardrails_with_multiple_validators(integration_client): assert body["success"] is True assert ( body["data"][SAFE_TEXT_FIELD] - == "this sentence contains [REDACTED_SLUR] cause i want a s done to kill the female foetus" + == "this sentence contains [REDACTED_SLUR] cause i want a s done to kill the female foetus." ) diff --git a/backend/app/tests/utils/constants.py b/backend/app/tests/utils/constants.py index 568bc19..e642c67 100644 --- a/backend/app/tests/utils/constants.py +++ b/backend/app/tests/utils/constants.py @@ -1,2 +1,2 @@ -VALIDATE_API_PATH = "/api/v1/guardrails/validate/" +VALIDATE_API_PATH = "/api/v1/guardrails/" SAFE_TEXT_FIELD = "safe_text" From f208c9ea6a6ac48d7e8421c6917a536d67f321a9 Mon Sep 17 00:00:00 2001 From: rkritika1508 Date: Mon, 9 Feb 2026 15:56:29 +0530 Subject: [PATCH 2/3] resolved comment --- backend/app/api/routes/guardrails.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/app/api/routes/guardrails.py b/backend/app/api/routes/guardrails.py index d777120..27ade90 100644 --- a/backend/app/api/routes/guardrails.py +++ b/backend/app/api/routes/guardrails.py @@ -177,7 +177,7 @@ def add_validator_logs(guard: Guard, request_log_id: UUID, validator_log_crud: V for log in iteration.outputs.validator_logs: result = log.validation_result - if include_all_validator_logs == False and isinstance(result, PassResult): + if not include_all_validator_logs and isinstance(result, PassResult): continue error_message = None From abf3812bea7a47fa27433882c8e00d4f303a95b4 Mon Sep 17 00:00:00 2001 From: Kritika Rupauliha Date: Mon, 9 Feb 2026 21:28:33 +0530 Subject: [PATCH 3/3] Guardrails: Config Management (#30) --- .env.test | 31 -- .env.test.example | 2 +- .../alembic/versions/001_added_request_log.py | 6 +- .../versions/002_added_validator_log.py | 6 +- .../versions/003_added_validator_config.py | 50 +++ backend/app/api/main.py | 3 +- backend/app/api/routes/guardrails.py | 6 +- backend/app/api/routes/validator_configs.py | 94 ++++++ backend/app/core/constants.py | 9 + backend/app/core/enum.py | 12 +- backend/app/core/guardrail_controller.py | 2 +- backend/app/core/validators/__init__.py | 0 .../ban_list_safety_validator_config.py | 2 +- .../config}/base_validator_config.py | 0 ...assumption_bias_safety_validator_config.py | 2 +- .../lexical_slur_safety_validator_config.py | 2 +- .../pii_remover_safety_validator_config.py | 2 +- backend/app/crud/__init__.py | 2 +- backend/app/crud/request_log.py | 2 +- backend/app/crud/validator_config.py | 125 ++++++++ backend/app/crud/validator_log.py | 2 +- backend/app/models/__init__.py | 4 +- backend/app/models/config/validator_config.py | 88 +++++ .../logging/{request.py => request_log.py} | 0 .../{validator.py => validator_log.py} | 0 backend/app/schemas/__init__.py | 0 .../{models => schemas}/guardrail_config.py | 8 +- backend/app/schemas/validator_config.py | 34 ++ backend/app/tests/conftest.py | 67 ++-- backend/app/tests/test_validator_configs.py | 135 ++++++++ .../test_validator_configs_integration.py | 301 ++++++++++++++++++ backend/app/utils.py | 18 ++ 32 files changed, 921 insertions(+), 94 deletions(-) delete mode 100644 .env.test create mode 100644 backend/app/alembic/versions/003_added_validator_config.py create mode 100644 backend/app/api/routes/validator_configs.py create mode 100644 backend/app/core/validators/__init__.py rename backend/app/{models/validators => core/validators/config}/ban_list_safety_validator_config.py (81%) rename backend/app/{models => core/validators/config}/base_validator_config.py (100%) rename backend/app/{models/validators => core/validators/config}/gender_assumption_bias_safety_validator_config.py (86%) rename backend/app/{models/validators => core/validators/config}/lexical_slur_safety_validator_config.py (88%) rename backend/app/{models/validators => core/validators/config}/pii_remover_safety_validator_config.py (87%) create mode 100644 backend/app/crud/validator_config.py create mode 100644 backend/app/models/config/validator_config.py rename backend/app/models/logging/{request.py => request_log.py} (100%) rename backend/app/models/logging/{validator.py => validator_log.py} (100%) create mode 100644 backend/app/schemas/__init__.py rename backend/app/{models => schemas}/guardrail_config.py (64%) create mode 100644 backend/app/schemas/validator_config.py create mode 100644 backend/app/tests/test_validator_configs.py create mode 100644 backend/app/tests/test_validator_configs_integration.py diff --git a/.env.test b/.env.test deleted file mode 100644 index 40e9ad8..0000000 --- a/.env.test +++ /dev/null @@ -1,31 +0,0 @@ -DOMAIN=localhost - -ENVIRONMENT=testing - -PROJECT_NAME="Kaapi-Guardrails" -STACK_NAME=Kaapi-Guardrails - -# API Base URL for cron scripts (defaults to http://localhost:8000 if not set) -API_BASE_URL=http://localhost:8000 - -# Postgres -POSTGRES_SERVER=localhost -POSTGRES_PORT=5432 -POSTGRES_DB=kaapi-guardrails -POSTGRES_USER=postgres -POSTGRES_PASSWORD=postgres - -SENTRY_DSN= - -# Configure these with your own Docker registry images - -DOCKER_IMAGE_BACKEND=kaapi-guardrails-backend - -# Callback Timeouts (in seconds) -CALLBACK_CONNECT_TIMEOUT=3 -CALLBACK_READ_TIMEOUT=10 - -# require as a env if you want to use doc transformation -OPENAI_API_KEY="" -GUARDRAILS_HUB_API_KEY="" -AUTH_TOKEN="" \ No newline at end of file diff --git a/.env.test.example b/.env.test.example index 40e9ad8..368b7d7 100644 --- a/.env.test.example +++ b/.env.test.example @@ -10,8 +10,8 @@ API_BASE_URL=http://localhost:8000 # Postgres POSTGRES_SERVER=localhost +POSTGRES_DB=kaapi_guardrails_testing POSTGRES_PORT=5432 -POSTGRES_DB=kaapi-guardrails POSTGRES_USER=postgres POSTGRES_PASSWORD=postgres diff --git a/backend/app/alembic/versions/001_added_request_log.py b/backend/app/alembic/versions/001_added_request_log.py index 706d504..d83577a 100644 --- a/backend/app/alembic/versions/001_added_request_log.py +++ b/backend/app/alembic/versions/001_added_request_log.py @@ -14,9 +14,9 @@ # revision identifiers, used by Alembic. revision: str = '001' -down_revision: Union[str, Sequence[str], None] = None -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None +down_revision: str | None = None +branch_labels = None +depends_on = None def upgrade() -> None: diff --git a/backend/app/alembic/versions/002_added_validator_log.py b/backend/app/alembic/versions/002_added_validator_log.py index d46513f..e0b8115 100644 --- a/backend/app/alembic/versions/002_added_validator_log.py +++ b/backend/app/alembic/versions/002_added_validator_log.py @@ -14,9 +14,9 @@ # revision identifiers, used by Alembic. revision: str = '002' -down_revision: Union[str, Sequence[str], None] = '001' -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None +down_revision: str = '001' +branch_labels = None +depends_on = None def upgrade() -> None: diff --git a/backend/app/alembic/versions/003_added_validator_config.py b/backend/app/alembic/versions/003_added_validator_config.py new file mode 100644 index 0000000..b50af01 --- /dev/null +++ b/backend/app/alembic/versions/003_added_validator_config.py @@ -0,0 +1,50 @@ +"""Added validator_config table + +Revision ID: 003 +Revises: 002 +Create Date: 2026-02-05 09:42:54.128852 + +""" +from typing import Sequence, Union + +from alembic import op +from sqlalchemy.dialects import postgresql +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = '003' +down_revision: str = '002' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table('validator_config', + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('organization_id', sa.Integer(), nullable=False), + sa.Column('project_id', sa.Integer(), nullable=False), + sa.Column('type', sa.String(), nullable=False), + sa.Column('stage', sa.String(), nullable=False), + sa.Column('on_fail_action', sa.String(), nullable=False), + sa.Column( + "config", + postgresql.JSONB(astext_type=sa.Text()), + nullable=False, + server_default=sa.text("'{}'::jsonb"), + ), + sa.Column('is_enabled', sa.Boolean(), nullable=False, server_default=sa.true()), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('organization_id', 'project_id', 'type', 'stage', name='uq_validator_identity') + ) + + op.create_index("idx_validator_organization", "validator_config", ["organization_id"]) + op.create_index("idx_validator_project", "validator_config", ["project_id"]) + op.create_index("idx_validator_type", "validator_config", ["type"]) + op.create_index("idx_validator_stage", "validator_config", ["stage"]) + + +def downgrade() -> None: + op.drop_table('validator_config') diff --git a/backend/app/api/main.py b/backend/app/api/main.py index 0fc8026..bf78ade 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -1,10 +1,11 @@ from fastapi import APIRouter -from app.api.routes import utils, guardrails +from app.api.routes import utils, guardrails, validator_configs api_router = APIRouter() api_router.include_router(utils.router) api_router.include_router(guardrails.router) +api_router.include_router(validator_configs.router) # if settings.ENVIRONMENT == "local": # api_router.include_router(private.router) diff --git a/backend/app/api/routes/guardrails.py b/backend/app/api/routes/guardrails.py index 27ade90..9c9c019 100644 --- a/backend/app/api/routes/guardrails.py +++ b/backend/app/api/routes/guardrails.py @@ -10,9 +10,9 @@ from app.core.guardrail_controller import build_guard, get_validator_config_models from app.crud.request_log import RequestLogCrud from app.crud.validator_log import ValidatorLogCrud -from app.models.guardrail_config import GuardrailRequest, GuardrailResponse -from app.models.logging.request import RequestLogUpdate, RequestStatus -from app.models.logging.validator import ValidatorLog, ValidatorOutcome +from app.models.logging.request_log import RequestLogUpdate, RequestStatus +from app.models.logging.validator_log import ValidatorLog, ValidatorOutcome +from app.schemas.guardrail_config import GuardrailRequest, GuardrailResponse from app.utils import APIResponse router = APIRouter(prefix="/guardrails", tags=["guardrails"]) diff --git a/backend/app/api/routes/validator_configs.py b/backend/app/api/routes/validator_configs.py new file mode 100644 index 0000000..701e6a8 --- /dev/null +++ b/backend/app/api/routes/validator_configs.py @@ -0,0 +1,94 @@ +from typing import Optional +from uuid import UUID + +from fastapi import APIRouter + +from app.api.deps import AuthDep, SessionDep +from app.core.enum import Stage, ValidatorType +from app.schemas.validator_config import ValidatorCreate, ValidatorResponse, ValidatorUpdate +from app.crud.validator_config import validator_config_crud +from app.utils import APIResponse + + +router = APIRouter( + prefix="/guardrails/validators/configs", + tags=["validator configs"], +) + + +@router.post( + "/", + response_model=APIResponse[ValidatorResponse] + ) +def create_validator( + payload: ValidatorCreate, + session: SessionDep, + organization_id: int, + project_id: int, + _: AuthDep, +): + response_model = validator_config_crud.create(session, organization_id, project_id, payload) + return APIResponse.success_response(data=response_model) + +@router.get( + "/", + response_model=APIResponse[list[ValidatorResponse]] + ) +def list_validators( + organization_id: int, + project_id: int, + session: SessionDep, + _: AuthDep, + stage: Optional[Stage] = None, + type: Optional[ValidatorType] = None, +): + response_model = validator_config_crud.list(session, organization_id, project_id, stage, type) + return APIResponse.success_response(data=response_model) + + +@router.get( + "/{id}", + response_model=APIResponse[ValidatorResponse] + ) +def get_validator( + id: UUID, + organization_id: int, + project_id: int, + session: SessionDep, + _: AuthDep, +): + obj = validator_config_crud.get(session, id, organization_id, project_id) + return APIResponse.success_response(data=validator_config_crud.flatten(obj)) + + +@router.patch( + "/{id}", + response_model=APIResponse[ValidatorResponse] + ) +def update_validator( + id: UUID, + organization_id: int, + project_id: int, + payload: ValidatorUpdate, + session: SessionDep, + _: AuthDep, +): + obj = validator_config_crud.get(session, id, organization_id, project_id) + response_model = validator_config_crud.update(session, obj, payload.model_dump(exclude_unset=True)) + return APIResponse.success_response(data=response_model) + + +@router.delete( + "/{id}", + response_model=APIResponse[dict] + ) +def delete_validator( + id: UUID, + organization_id: int, + project_id: int, + session: SessionDep, + _: AuthDep, +): + obj = validator_config_crud.get(session, id, organization_id, project_id) + validator_config_crud.delete(session, obj) + return APIResponse.success_response(data={"message": "Validator deleted successfully"}) diff --git a/backend/app/core/constants.py b/backend/app/core/constants.py index d6e3a7a..6c3825d 100644 --- a/backend/app/core/constants.py +++ b/backend/app/core/constants.py @@ -6,3 +6,12 @@ SCORE = "score" REPHRASE_ON_FAIL_PREFIX = "Please rephrase the query without unsafe content." + +VALIDATOR_CONFIG_SYSTEM_FIELDS = { + "organization_id", + "project_id", + "type", + "stage", + "on_fail_action", + "is_enabled", +} diff --git a/backend/app/core/enum.py b/backend/app/core/enum.py index 38418e9..6b8351f 100644 --- a/backend/app/core/enum.py +++ b/backend/app/core/enum.py @@ -15,4 +15,14 @@ class BiasCategories(Enum): class GuardrailOnFail(Enum): Exception = "exception" Fix = "fix" - Rephrase = "rephrase" \ No newline at end of file + Rephrase = "rephrase" + +class Stage(Enum): + Input = "input" + Output = "output" + +class ValidatorType(Enum): + LexicalSlur = "uli_slur_match" + PIIRemover = "pii_remover" + GenderAssumptionBias = "gender_assumption_bias" + BanList = "ban_list" diff --git a/backend/app/core/guardrail_controller.py b/backend/app/core/guardrail_controller.py index a935636..c4578e0 100644 --- a/backend/app/core/guardrail_controller.py +++ b/backend/app/core/guardrail_controller.py @@ -2,7 +2,7 @@ from guardrails import Guard -from app.models.guardrail_config import ValidatorConfigItem +from app.schemas.guardrail_config import ValidatorConfigItem def build_guard(validator_items): validators = [v_item.build() for v_item in validator_items] diff --git a/backend/app/core/validators/__init__.py b/backend/app/core/validators/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/models/validators/ban_list_safety_validator_config.py b/backend/app/core/validators/config/ban_list_safety_validator_config.py similarity index 81% rename from backend/app/models/validators/ban_list_safety_validator_config.py rename to backend/app/core/validators/config/ban_list_safety_validator_config.py index 4a853f0..260399e 100644 --- a/backend/app/models/validators/ban_list_safety_validator_config.py +++ b/backend/app/core/validators/config/ban_list_safety_validator_config.py @@ -2,7 +2,7 @@ from guardrails.hub import BanList -from app.models.base_validator_config import BaseValidatorConfig +from app.core.validators.config.base_validator_config import BaseValidatorConfig class BanListSafetyValidatorConfig(BaseValidatorConfig): type: Literal["ban_list"] diff --git a/backend/app/models/base_validator_config.py b/backend/app/core/validators/config/base_validator_config.py similarity index 100% rename from backend/app/models/base_validator_config.py rename to backend/app/core/validators/config/base_validator_config.py diff --git a/backend/app/models/validators/gender_assumption_bias_safety_validator_config.py b/backend/app/core/validators/config/gender_assumption_bias_safety_validator_config.py similarity index 86% rename from backend/app/models/validators/gender_assumption_bias_safety_validator_config.py rename to backend/app/core/validators/config/gender_assumption_bias_safety_validator_config.py index 116c281..7cd3687 100644 --- a/backend/app/models/validators/gender_assumption_bias_safety_validator_config.py +++ b/backend/app/core/validators/config/gender_assumption_bias_safety_validator_config.py @@ -1,8 +1,8 @@ from typing import List, Literal, Optional -from app.models.base_validator_config import BaseValidatorConfig from app.core.enum import BiasCategories from app.core.validators.gender_assumption_bias import GenderAssumptionBias +from app.core.validators.config.base_validator_config import BaseValidatorConfig class GenderAssumptionBiasSafetyValidatorConfig(BaseValidatorConfig): type: Literal["gender_assumption_bias"] diff --git a/backend/app/models/validators/lexical_slur_safety_validator_config.py b/backend/app/core/validators/config/lexical_slur_safety_validator_config.py similarity index 88% rename from backend/app/models/validators/lexical_slur_safety_validator_config.py rename to backend/app/core/validators/config/lexical_slur_safety_validator_config.py index 6378182..d86c0d6 100644 --- a/backend/app/models/validators/lexical_slur_safety_validator_config.py +++ b/backend/app/core/validators/config/lexical_slur_safety_validator_config.py @@ -2,7 +2,7 @@ from app.core.enum import SlurSeverity from app.core.validators.lexical_slur import LexicalSlur -from app.models.base_validator_config import BaseValidatorConfig +from app.core.validators.config.base_validator_config import BaseValidatorConfig class LexicalSlurSafetyValidatorConfig(BaseValidatorConfig): type: Literal["uli_slur_match"] diff --git a/backend/app/models/validators/pii_remover_safety_validator_config.py b/backend/app/core/validators/config/pii_remover_safety_validator_config.py similarity index 87% rename from backend/app/models/validators/pii_remover_safety_validator_config.py rename to backend/app/core/validators/config/pii_remover_safety_validator_config.py index d8d3a18..fc18fa5 100644 --- a/backend/app/models/validators/pii_remover_safety_validator_config.py +++ b/backend/app/core/validators/config/pii_remover_safety_validator_config.py @@ -1,8 +1,8 @@ from __future__ import annotations from typing import List, Literal, Optional -from app.models.base_validator_config import BaseValidatorConfig from app.core.validators.pii_remover import PIIRemover +from app.core.validators.config.base_validator_config import BaseValidatorConfig class PIIRemoverSafetyValidatorConfig(BaseValidatorConfig): type: Literal["pii_remover"] diff --git a/backend/app/crud/__init__.py b/backend/app/crud/__init__.py index c955a67..e58e7bd 100644 --- a/backend/app/crud/__init__.py +++ b/backend/app/crud/__init__.py @@ -1 +1 @@ -from app.crud.request_log import RequestLogCrud \ No newline at end of file +from app.crud.request_log import RequestLogCrud diff --git a/backend/app/crud/request_log.py b/backend/app/crud/request_log.py index 74d5ece..1d75d87 100644 --- a/backend/app/crud/request_log.py +++ b/backend/app/crud/request_log.py @@ -2,7 +2,7 @@ from sqlmodel import Session -from app.models.logging.request import RequestLog, RequestLogUpdate, RequestStatus +from app.models.logging.request_log import RequestLog, RequestLogUpdate, RequestStatus from app.utils import now class RequestLogCrud: diff --git a/backend/app/crud/validator_config.py b/backend/app/crud/validator_config.py new file mode 100644 index 0000000..26dca48 --- /dev/null +++ b/backend/app/crud/validator_config.py @@ -0,0 +1,125 @@ +from typing import Optional +from uuid import UUID + +from fastapi import HTTPException +from sqlalchemy.exc import IntegrityError +from sqlmodel import Session, select + +from app.core.enum import Stage, ValidatorType +from app.models.config.validator_config import ValidatorConfig +from app.schemas.validator_config import ValidatorCreate +from app.utils import now, split_validator_payload + + +class ValidatorConfigCrud: + def create( + self, + session: Session, + organization_id: int, + project_id: int, + payload: ValidatorCreate + ) -> dict: + data = payload.model_dump() + model_fields, config_fields = split_validator_payload(data) + + obj = ValidatorConfig( + organization_id=organization_id, + project_id=project_id, + config=config_fields, + **model_fields, + ) + + session.add(obj) + + try: + session.commit() + except IntegrityError: + session.rollback() + raise HTTPException( + 400, + "Validator already exists for this type and stage", + ) + + session.refresh(obj) + return self.flatten(obj) + + def list( + self, + session: Session, + organization_id: int, + project_id: int, + stage: Optional[Stage] = None, + type: Optional[ValidatorType] = None, + ) -> list[dict]: + query = select(ValidatorConfig).where( + ValidatorConfig.organization_id == organization_id, + ValidatorConfig.project_id == project_id, + ) + + if stage: + query = query.where(ValidatorConfig.stage == stage) + + if type: + query = query.where(ValidatorConfig.type == type) + + rows = session.exec(query).all() + return [self.flatten(r) for r in rows] + + def get( + self, + session: Session, + id: UUID, + organization_id: int, + project_id: int, + ) -> ValidatorConfig: + obj = session.get(ValidatorConfig, id) + + if not obj or obj.organization_id != organization_id or obj.project_id != project_id: + raise HTTPException(404, "Validator not found") + + return obj + + def update( + self, + session: Session, + obj: ValidatorConfig, + update_data: dict + ) -> dict: + model_fields, config_fields = split_validator_payload(update_data) + + for k, v in model_fields.items(): + setattr(obj, k, v) + + if config_fields: + obj.config = {**(obj.config or {}), **config_fields} + + obj.updated_at = now() + try: + session.commit() + except IntegrityError: + session.rollback() + raise HTTPException( + 400, + "Validator already exists for this type and stage", + ) + except Exception: + session.rollback() + raise + + session.refresh(obj) + return self.flatten(obj) + + def delete(self, session: Session, obj: ValidatorConfig): + session.delete(obj) + try: + session.commit() + except Exception: + session.rollback() + raise + + def flatten(self, row: ValidatorConfig) -> dict: + base = row.model_dump(exclude={"config"}) + return {**base, **(row.config or {})} + + +validator_config_crud = ValidatorConfigCrud() diff --git a/backend/app/crud/validator_log.py b/backend/app/crud/validator_log.py index 6eb1c1a..3903129 100644 --- a/backend/app/crud/validator_log.py +++ b/backend/app/crud/validator_log.py @@ -2,7 +2,7 @@ from sqlmodel import Session -from app.models.logging.validator import ValidatorLog +from app.models.logging.validator_log import ValidatorLog from app.utils import now class ValidatorLogCrud: diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 5672003..d116b81 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -1,4 +1,4 @@ from sqlmodel import SQLModel -from app.models.logging.request import RequestLog -from app.models.logging.validator import ValidatorLog +from app.models.logging.request_log import RequestLog +from app.models.logging.validator_log import ValidatorLog diff --git a/backend/app/models/config/validator_config.py b/backend/app/models/config/validator_config.py new file mode 100644 index 0000000..021a41f --- /dev/null +++ b/backend/app/models/config/validator_config.py @@ -0,0 +1,88 @@ +from datetime import datetime +from typing import Any +from uuid import UUID, uuid4 + +from sqlalchemy import Column, UniqueConstraint +from sqlalchemy.dialects.postgresql import JSONB +from sqlmodel import Field as SQLField +from sqlmodel import SQLModel, Field +import sqlalchemy as sa + +from app.core.enum import GuardrailOnFail, Stage, ValidatorType +from app.utils import now + +class ValidatorConfig(SQLModel, table=True): + __tablename__ = "validator_config" + + id: UUID = Field( + default_factory=uuid4, + primary_key=True, + sa_column_kwargs={"comment": "Unique identifier for the validator configuration"}, + ) + + organization_id: int = Field( + index=True, + sa_column_kwargs={"comment": "Identifier for the organization"}, + ) + + project_id: int = Field( + nullable=False, + sa_column_kwargs={"comment": "Identifier for the project"}, + ) + + type: ValidatorType = Field( + nullable=False, + sa_column_kwargs={"comment": "Type of the validator"}, + ) + + stage: Stage = Field( + nullable=False, + sa_column_kwargs={"comment": "Stage at which the validator is applied"}, + ) + + on_fail_action: GuardrailOnFail = Field( + default=GuardrailOnFail.Fix, + nullable=False, + sa_column_kwargs={"comment": "Action to take when the validator fails"}, + ) + + + config: dict[str, Any] = SQLField( + default_factory=dict, + sa_column=Column( + JSONB, + nullable=False, + server_default=sa.text("'{}'::jsonb"), + comment="Configuration for the validator", + ), + description=( + "Configuration for the validator" + ), + ) + + is_enabled: bool = Field( + default=True, + sa_column_kwargs={"comment": "Indicates if the validator is enabled"}, + ) + + created_at: datetime = Field( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the validator config was inserted"}, + ) + + updated_at: datetime = Field( + default_factory=now, + nullable=False, + sa_column_kwargs={ + "comment": "Timestamp when the validator config was last updated", + "onupdate": now, + }, + ) + + __table_args__ = ( + UniqueConstraint( + "organization_id", "project_id", "type", "stage", + name="uq_validator_identity" + ), + ) diff --git a/backend/app/models/logging/request.py b/backend/app/models/logging/request_log.py similarity index 100% rename from backend/app/models/logging/request.py rename to backend/app/models/logging/request_log.py diff --git a/backend/app/models/logging/validator.py b/backend/app/models/logging/validator_log.py similarity index 100% rename from backend/app/models/logging/validator.py rename to backend/app/models/logging/validator_log.py diff --git a/backend/app/schemas/__init__.py b/backend/app/schemas/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/models/guardrail_config.py b/backend/app/schemas/guardrail_config.py similarity index 64% rename from backend/app/models/guardrail_config.py rename to backend/app/schemas/guardrail_config.py index bfe36a6..6dfde4c 100644 --- a/backend/app/models/guardrail_config.py +++ b/backend/app/schemas/guardrail_config.py @@ -5,10 +5,10 @@ # todo this could be improved by having some auto-discovery mechanism inside # validators. We'll not have to list every new validator like this. -from app.models.validators.ban_list_safety_validator_config import BanListSafetyValidatorConfig -from app.models.validators.gender_assumption_bias_safety_validator_config import GenderAssumptionBiasSafetyValidatorConfig -from app.models.validators.lexical_slur_safety_validator_config import LexicalSlurSafetyValidatorConfig -from app.models.validators.pii_remover_safety_validator_config import PIIRemoverSafetyValidatorConfig +from app.core.validators.config.ban_list_safety_validator_config import BanListSafetyValidatorConfig +from app.core.validators.config.gender_assumption_bias_safety_validator_config import GenderAssumptionBiasSafetyValidatorConfig +from app.core.validators.config.lexical_slur_safety_validator_config import LexicalSlurSafetyValidatorConfig +from app.core.validators.config.pii_remover_safety_validator_config import PIIRemoverSafetyValidatorConfig ValidatorConfigItem = Annotated[ # future validators will come here diff --git a/backend/app/schemas/validator_config.py b/backend/app/schemas/validator_config.py new file mode 100644 index 0000000..57c8852 --- /dev/null +++ b/backend/app/schemas/validator_config.py @@ -0,0 +1,34 @@ +from datetime import datetime +from typing import Optional +from uuid import UUID + +from pydantic import ConfigDict +from sqlmodel import SQLModel + +from app.core.enum import GuardrailOnFail, Stage, ValidatorType + + +class ValidatorBase(SQLModel): + model_config = ConfigDict(extra="allow") + + type: ValidatorType + stage: Stage + on_fail_action: GuardrailOnFail + is_enabled: bool = True + + +class ValidatorCreate(ValidatorBase): + pass + + +class ValidatorUpdate(SQLModel): + model_config = ConfigDict(extra="forbid") + + type: Optional[ValidatorType] = None + stage: Optional[Stage] = None + on_fail_action: Optional[GuardrailOnFail] = None + is_enabled: Optional[bool] = None + + +class ValidatorResponse(ValidatorBase): + pass diff --git a/backend/app/tests/conftest.py b/backend/app/tests/conftest.py index f76b33d..d595b97 100644 --- a/backend/app/tests/conftest.py +++ b/backend/app/tests/conftest.py @@ -1,50 +1,43 @@ +# conftest.py import os -from unittest.mock import MagicMock +os.environ["ENVIRONMENT"] = "testing" import pytest from fastapi.testclient import TestClient +from sqlmodel import Session, create_engine, SQLModel -# MUST be set before app import -os.environ["ENVIRONMENT"] = "testing" - -from app.api.deps import SessionDep, verify_bearer_token -from app.api.routes import guardrails from app.main import app +from app.api.deps import SessionDep, verify_bearer_token +from app.core.config import settings + +test_engine = create_engine( + str(settings.SQLALCHEMY_DATABASE_URI), + echo=False, + pool_pre_ping=True, +) + +def override_session(): + with Session(test_engine) as session: + yield session + +@pytest.fixture(scope="session", autouse=True) +def setup_test_db(): + SQLModel.metadata.create_all(test_engine) + yield + SQLModel.metadata.drop_all(test_engine) + +@pytest.fixture(scope="function", autouse=True) +def clean_db(): + with Session(test_engine) as session: + for table in reversed(SQLModel.metadata.sorted_tables): + session.execute(table.delete()) + session.commit() @pytest.fixture(scope="function", autouse=True) -def override_dependencies(monkeypatch): - """ - Override ALL external dependencies: - - Auth - - DB session - - CRUDs - """ - - # ---- Auth override ---- +def override_dependencies(): app.dependency_overrides[verify_bearer_token] = lambda: True - # ---- DB session override ---- - mock_session = MagicMock() - app.dependency_overrides[SessionDep] = lambda: mock_session - - # ---- CRUD override ---- - mock_request_log_crud = MagicMock() - mock_request_log_crud.create.return_value = MagicMock(id=1) - mock_request_log_crud.update.return_value = None - - mock_validator_log_crud = MagicMock() - mock_validator_log_crud.create.return_value = None - - monkeypatch.setattr( - guardrails, - "RequestLogCrud", - lambda session: mock_request_log_crud, - ) - monkeypatch.setattr( - guardrails, - "ValidatorLogCrud", - lambda session: mock_validator_log_crud, - ) + app.dependency_overrides[SessionDep] = override_session yield diff --git a/backend/app/tests/test_validator_configs.py b/backend/app/tests/test_validator_configs.py new file mode 100644 index 0000000..017da20 --- /dev/null +++ b/backend/app/tests/test_validator_configs.py @@ -0,0 +1,135 @@ +import uuid +from unittest.mock import MagicMock + +import pytest +from sqlmodel import Session + +from app.crud.validator_config import validator_config_crud +from app.core.enum import GuardrailOnFail, Stage, ValidatorType +from app.models.config.validator_config import ValidatorConfig + +# Test data constants +TEST_ORGANIZATION_ID = 1 +TEST_PROJECT_ID = 1 +TEST_VALIDATOR_ID = uuid.uuid4() +TEST_TYPE = ValidatorType.LexicalSlur +TEST_STAGE = Stage.Input +TEST_ON_FAIL = GuardrailOnFail.Fix + +@pytest.fixture +def mock_session(): + """Create a mock session for database operations.""" + return MagicMock(spec=Session) + + +@pytest.fixture +def sample_validator(): + """Create a sample validator config for testing.""" + return ValidatorConfig( + id=TEST_VALIDATOR_ID, + organization_id=TEST_ORGANIZATION_ID, + project_id=TEST_PROJECT_ID, + type=TEST_TYPE, + stage=TEST_STAGE, + on_fail_action=TEST_ON_FAIL, + is_enabled=True, + config={"severity": "all", "languages": ["en", "hi"]}, + ) + + +class TestFlatten: + def test_flatten_includes_config_fields(self, sample_validator): + result = validator_config_crud.flatten(sample_validator) + + assert result["severity"] == "all" + assert result["languages"] == ["en", "hi"] + assert result["id"] == TEST_VALIDATOR_ID + + def test_flatten_empty_config(self): + validator = ValidatorConfig( + id=TEST_VALIDATOR_ID, + organization_id=TEST_ORGANIZATION_ID, + project_id=TEST_PROJECT_ID, + type=TEST_TYPE, + stage=TEST_STAGE, + on_fail_action=TEST_ON_FAIL, + is_enabled=True, + config={}, + ) + + result = validator_config_crud.flatten(validator) + + assert "severity" not in result + + +class TestGetOr404: + def test_success(self, sample_validator, mock_session): + mock_session.get.return_value = sample_validator + + result = validator_config_crud.get( + mock_session, + TEST_VALIDATOR_ID, + TEST_ORGANIZATION_ID, + TEST_PROJECT_ID, + ) + + assert result == sample_validator + mock_session.get.assert_called_once() + + def test_not_found(self, mock_session): + mock_session.get.return_value = None + + with pytest.raises(Exception) as exc: + validator_config_crud.get( + mock_session, + TEST_VALIDATOR_ID, + TEST_ORGANIZATION_ID, + TEST_PROJECT_ID, + ) + + assert "Validator not found" in str(exc.value) + + +class TestUpdate: + def test_update_base_fields(self, sample_validator, mock_session): + update_data = { + "type": ValidatorType.PIIRemover, + "on_fail_action": GuardrailOnFail.Exception, + } + + result = validator_config_crud.update( + mock_session, + sample_validator, + update_data, + ) + + assert result["type"] == ValidatorType.PIIRemover + assert result["on_fail_action"] == GuardrailOnFail.Exception + + mock_session.commit.assert_called_once() + mock_session.refresh.assert_called_once() + + def test_update_extra_fields(self, sample_validator, mock_session): + update_data = {"severity": "high", "new_field": "new_value"} + + result = validator_config_crud.update( + mock_session, + sample_validator, + update_data, + ) + + assert result["severity"] == "high" + assert result["new_field"] == "new_value" + assert result["languages"] == ["en", "hi"] + + def test_merge_config(self, sample_validator, mock_session): + sample_validator.config = {"severity": "all", "languages": ["en"]} + + result = validator_config_crud.update( + mock_session, + sample_validator, + {"languages": ["en", "hi"]}, + ) + + assert result["languages"] == ["en", "hi"] + assert result["severity"] == "all" diff --git a/backend/app/tests/test_validator_configs_integration.py b/backend/app/tests/test_validator_configs_integration.py new file mode 100644 index 0000000..bb2de29 --- /dev/null +++ b/backend/app/tests/test_validator_configs_integration.py @@ -0,0 +1,301 @@ +import uuid + +import pytest +from sqlmodel import Session, delete + +from app.core.db import engine +from app.models.config.validator_config import ValidatorConfig + +pytestmark = pytest.mark.integration + +# Test data constants +TEST_ORGANIZATION_ID = 1 +TEST_PROJECT_ID = 1 +BASE_URL = "/api/v1/guardrails/validators/configs/" +DEFAULT_QUERY_PARAMS = f"?organization_id={TEST_ORGANIZATION_ID}&project_id={TEST_PROJECT_ID}" + +VALIDATOR_PAYLOADS = { + "lexical_slur": { + "type": "uli_slur_match", + "stage": "input", + "on_fail_action": "fix", + "severity": "all", + "languages": ["en", "hi"], + }, + "pii_remover_input": { + "type": "pii_remover", + "stage": "input", + "on_fail_action": "fix", + }, + "pii_remover_output": { + "type": "pii_remover", + "stage": "output", + "on_fail_action": "fix", + }, + "minimal": { + "type": "uli_slur_match", + "stage": "input", + "on_fail_action": "fix", + }, +} + + +@pytest.fixture +def clear_database(): + """Clear ValidatorConfig table before and after each test.""" + with Session(engine) as session: + session.exec(delete(ValidatorConfig)) + session.commit() + + yield + + with Session(engine) as session: + session.exec(delete(ValidatorConfig)) + session.commit() + + +class BaseValidatorTest: + """Base class with helper methods for validator tests.""" + + def create_validator(self, client, payload_key="minimal", **kwargs): + """Helper to create a validator.""" + payload = {**VALIDATOR_PAYLOADS[payload_key], **kwargs} + return client.post(f"{BASE_URL}{DEFAULT_QUERY_PARAMS}", json=payload) + + def get_validator(self, client, validator_id): + """Helper to get a specific validator.""" + return client.get(f"{BASE_URL}{validator_id}/{DEFAULT_QUERY_PARAMS}") + + def list_validators(self, client, **query_params): + """Helper to list validators with optional filters.""" + params_str = f"?organization_id={TEST_ORGANIZATION_ID}&project_id={TEST_PROJECT_ID}" + if query_params: + params_str += "&" + "&".join(f"{k}={v}" for k, v in query_params.items()) + return client.get(f"{BASE_URL}{params_str}") + + def update_validator(self, client, validator_id, payload): + """Helper to update a validator.""" + return client.patch(f"{BASE_URL}{validator_id}{DEFAULT_QUERY_PARAMS}", json=payload) + + def delete_validator(self, client, validator_id): + """Helper to delete a validator.""" + return client.delete(f"{BASE_URL}{validator_id}/{DEFAULT_QUERY_PARAMS}") + + +class TestCreateValidator(BaseValidatorTest): + """Tests for POST /guardrails/validators/configs endpoint.""" + + def test_create_validator_success(self, integration_client, clear_database): + """Test successful validator creation.""" + response = self.create_validator(integration_client, "lexical_slur") + + assert response.status_code == 200 + data = response.json()["data"] + assert data["type"] == "uli_slur_match" + assert data["stage"] == "input" + assert data["severity"] == "all" + assert data["languages"] == ["en", "hi"] + assert "id" in data + + def test_create_validator_duplicate_raises_400(self, integration_client, clear_database): + """Test that creating duplicate validator raises 400.""" + # First request should succeed + response1 = self.create_validator(integration_client, "minimal") + assert response1.status_code == 200 + + # Second request with same unique keys should fail + response2 = self.create_validator(integration_client, "minimal") + assert response2.status_code == 400 + + def test_create_validator_missing_required_fields(self, integration_client, clear_database): + """Test that missing required fields returns validation error.""" + response = integration_client.post( + f"{BASE_URL}{DEFAULT_QUERY_PARAMS}", + json={"type": "uli_slur_match"}, + ) + + assert response.status_code == 422 + + +class TestListValidators(BaseValidatorTest): + """Tests for GET /guardrails/validators/configs endpoint.""" + + def test_list_validators_success(self, integration_client, clear_database): + """Test successful validator listing.""" + # Create validators first + self.create_validator(integration_client, "lexical_slur") + self.create_validator(integration_client, "pii_remover_input") + + response = self.list_validators(integration_client) + + assert response.status_code == 200 + data = response.json()["data"] + assert len(data) == 2 + + def test_list_validators_filter_by_stage(self, integration_client, clear_database): + """Test filtering validators by stage.""" + self.create_validator(integration_client, "lexical_slur") + self.create_validator(integration_client, "pii_remover_output") + + response = self.list_validators(integration_client, stage="input") + + assert response.status_code == 200 + data = response.json()["data"] + assert len(data) == 1 + assert data[0]["stage"] == "input" + + def test_list_validators_filter_by_type(self, integration_client, clear_database): + """Test filtering validators by type.""" + self.create_validator(integration_client, "lexical_slur") + self.create_validator(integration_client, "pii_remover_input") + + response = self.list_validators(integration_client, type="pii_remover") + + assert response.status_code == 200 + data = response.json()["data"] + assert len(data) == 1 + assert data[0]["type"] == "pii_remover" + + def test_list_validators_empty(self, integration_client, clear_database): + """Test listing validators when none exist.""" + response = integration_client.get( + f"{BASE_URL}?organization_id=999&project_id=999", + ) + + assert response.status_code == 200 + data = response.json()["data"] + assert len(data) == 0 + + +class TestGetValidator(BaseValidatorTest): + """Tests for GET /guardrails/validators/configs/{id} endpoint.""" + + def test_get_validator_success(self, integration_client, clear_database): + """Test successful validator retrieval.""" + # Create a validator + create_response = self.create_validator( + integration_client, "lexical_slur", severity="all" + ) + validator_id = create_response.json()["data"]["id"] + + # Retrieve it + response = self.get_validator(integration_client, validator_id) + + assert response.status_code == 200 + data = response.json()["data"] + assert data["id"] == validator_id + assert data["severity"] == "all" + + def test_get_validator_not_found(self, integration_client, clear_database): + """Test retrieving non-existent validator returns 404.""" + fake_id = uuid.uuid4() + response = self.get_validator(integration_client, fake_id) + + assert response.status_code == 404 + + def test_get_validator_wrong_org(self, integration_client, clear_database): + """Test that accessing validator from different org returns 404.""" + # Create a validator for org 1 + create_response = self.create_validator(integration_client, "minimal") + validator_id = create_response.json()["data"]["id"] + + # Try to access it as different org + response = integration_client.get( + f"{BASE_URL}{validator_id}/?organization_id=2&project_id=1", + ) + + assert response.status_code == 404 + + +class TestUpdateValidator(BaseValidatorTest): + """Tests for PATCH /guardrails/validators/configs/{id} endpoint.""" + + def test_update_validator_success(self, integration_client, clear_database): + """Test successful validator update.""" + # Create a validator + create_response = self.create_validator( + integration_client, "lexical_slur", severity="all" + ) + validator_id = create_response.json()["data"]["id"] + + # Update it + update_payload = {"on_fail_action": "exception", "is_enabled": False} + response = self.update_validator(integration_client, validator_id, update_payload) + + assert response.status_code == 200 + data = response.json()["data"] + assert data["on_fail_action"] == "exception" + assert data["is_enabled"] is False + + def test_update_validator_partial(self, integration_client, clear_database): + """Test partial update preserves original fields.""" + # Create a validator + create_response = self.create_validator( + integration_client, + "lexical_slur", + severity="all", + languages=["en", "hi"], + ) + validator_id = create_response.json()["data"]["id"] + + # Update only one field + update_payload = {"is_enabled": False} + response = self.update_validator(integration_client, validator_id, update_payload) + + assert response.status_code == 200 + data = response.json()["data"] + assert data["is_enabled"] is False + assert data["on_fail_action"] == "fix" # Original preserved + + def test_update_validator_not_found(self, integration_client, clear_database): + """Test updating non-existent validator returns 404.""" + fake_id = uuid.uuid4() + update_payload = {"is_enabled": False} + + response = self.update_validator(integration_client, fake_id, update_payload) + + assert response.status_code == 404 + + +class TestDeleteValidator(BaseValidatorTest): + """Tests for DELETE /guardrails/validators/configs/{id} endpoint.""" + + def test_delete_validator_success(self, integration_client, clear_database): + """Test successful validator deletion.""" + # Create a validator + create_response = self.create_validator(integration_client, "minimal") + validator_id = create_response.json()["data"]["id"] + + # Delete it + response = self.delete_validator(integration_client, validator_id) + + assert response.status_code == 200 + assert response.json()["success"] is True + + # Verify it's deleted + get_response = self.get_validator(integration_client, validator_id) + assert get_response.status_code == 404 + + def test_delete_validator_not_found(self, integration_client, clear_database): + """Test deleting non-existent validator returns 404.""" + fake_id = uuid.uuid4() + response = self.delete_validator(integration_client, fake_id) + + assert response.status_code == 404 + + def test_delete_validator_wrong_org(self, integration_client, clear_database): + """Test that deleting validator from different org returns 404.""" + # Create a validator for org 1 + create_response = self.create_validator(integration_client, "minimal") + validator_id = create_response.json()["data"]["id"] + + # Try to delete it as different org + response = integration_client.delete( + f"{BASE_URL}{validator_id}/?organization_id=2&project_id=1", + ) + + assert response.status_code == 404 + + # Verify original is still there + get_response = self.get_validator(integration_client, validator_id) + assert get_response.status_code == 200 diff --git a/backend/app/utils.py b/backend/app/utils.py index 4e10f52..3a1def6 100644 --- a/backend/app/utils.py +++ b/backend/app/utils.py @@ -3,6 +3,8 @@ from pydantic import BaseModel from typing import Any, Dict, Generic, Optional, TypeVar +from app.core.constants import VALIDATOR_CONFIG_SYSTEM_FIELDS as SYSTEM_FIELDS + logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -11,6 +13,22 @@ def now(): return datetime.now(timezone.utc).replace(tzinfo=None) +def split_validator_payload(data: dict): + model_fields = {} + config_fields = {} + + for key, value in data.items(): + if key in SYSTEM_FIELDS: + model_fields[key] = value + else: + config_fields[key] = value + + overlap = set(model_fields) & set(config_fields) + if overlap: + raise ValueError(f"Config keys conflict with reserved field names: {overlap}") + + return model_fields, config_fields + class APIResponse(BaseModel, Generic[T]): success: bool data: Optional[T] = None