From a537bf03a02d6affc9f60a70997c58557a31cc49 Mon Sep 17 00:00:00 2001 From: rkritika1508 Date: Wed, 4 Feb 2026 14:04:22 +0530 Subject: [PATCH 01/11] Rearranged files in folder --- backend/app/api/routes/guardrails.py | 10 +++++----- backend/app/core/enum.py | 12 +++++++++++- backend/app/core/guardrail_controller.py | 2 +- backend/app/core/validators/__init__.py | 0 .../config}/ban_list_safety_validator_config.py | 2 +- .../validators/config}/base_validator_config.py | 0 ...gender_assumption_bias_safety_validator_config.py | 2 +- .../config}/lexical_slur_safety_validator_config.py | 2 +- .../config}/pii_remover_safety_validator_config.py | 2 +- backend/app/crud/__init__.py | 2 +- .../app/crud/{request_log.py => request_log_repo.py} | 2 +- .../crud/{validator_log.py => validator_log_repo.py} | 2 +- backend/app/models/__init__.py | 4 ++-- .../logging/{request.py => request_log_table.py} | 0 .../logging/{validator.py => validator_log_table.py} | 0 backend/app/schemas/__init__.py | 0 backend/app/{models => schemas}/guardrail_config.py | 8 ++++---- 17 files changed, 30 insertions(+), 20 deletions(-) create mode 100644 backend/app/core/validators/__init__.py rename backend/app/{models/validators => core/validators/config}/ban_list_safety_validator_config.py (81%) rename backend/app/{models => core/validators/config}/base_validator_config.py (100%) rename backend/app/{models/validators => core/validators/config}/gender_assumption_bias_safety_validator_config.py (86%) rename backend/app/{models/validators => core/validators/config}/lexical_slur_safety_validator_config.py (88%) rename backend/app/{models/validators => core/validators/config}/pii_remover_safety_validator_config.py (87%) rename backend/app/crud/{request_log.py => request_log_repo.py} (93%) rename backend/app/crud/{validator_log.py => validator_log_repo.py} (85%) rename backend/app/models/logging/{request.py => request_log_table.py} (100%) rename backend/app/models/logging/{validator.py => validator_log_table.py} (100%) create mode 100644 backend/app/schemas/__init__.py rename backend/app/{models => schemas}/guardrail_config.py (64%) diff --git a/backend/app/api/routes/guardrails.py b/backend/app/api/routes/guardrails.py index 750ac71..90f7b3c 100644 --- a/backend/app/api/routes/guardrails.py +++ b/backend/app/api/routes/guardrails.py @@ -8,11 +8,11 @@ from app.api.deps import AuthDep, SessionDep from app.core.constants import REPHRASE_ON_FAIL_PREFIX from app.core.guardrail_controller import build_guard, get_validator_config_models -from app.crud.request_log import RequestLogCrud -from app.crud.validator_log import ValidatorLogCrud -from app.models.guardrail_config import GuardrailRequest, GuardrailResponse -from app.models.logging.request import RequestLogUpdate, RequestStatus -from app.models.logging.validator import ValidatorLog, ValidatorOutcome +from app.crud.request_log_repo import RequestLogCrud +from app.crud.validator_log_repo import ValidatorLogCrud +from app.models.logging.request_log_table import RequestLogUpdate, RequestStatus +from app.models.logging.validator_log_table import ValidatorLog, ValidatorOutcome +from app.schemas.guardrail_config import GuardrailRequest, GuardrailResponse from app.utils import APIResponse router = APIRouter(prefix="/guardrails", tags=["guardrails"]) diff --git a/backend/app/core/enum.py b/backend/app/core/enum.py index 38418e9..6b8351f 100644 --- a/backend/app/core/enum.py +++ b/backend/app/core/enum.py @@ -15,4 +15,14 @@ class BiasCategories(Enum): class GuardrailOnFail(Enum): Exception = "exception" Fix = "fix" - Rephrase = "rephrase" \ No newline at end of file + Rephrase = "rephrase" + +class Stage(Enum): + Input = "input" + Output = "output" + +class ValidatorType(Enum): + LexicalSlur = "uli_slur_match" + PIIRemover = "pii_remover" + GenderAssumptionBias = "gender_assumption_bias" + BanList = "ban_list" diff --git a/backend/app/core/guardrail_controller.py b/backend/app/core/guardrail_controller.py index a935636..c4578e0 100644 --- a/backend/app/core/guardrail_controller.py +++ b/backend/app/core/guardrail_controller.py @@ -2,7 +2,7 @@ from guardrails import Guard -from app.models.guardrail_config import ValidatorConfigItem +from app.schemas.guardrail_config import ValidatorConfigItem def build_guard(validator_items): validators = [v_item.build() for v_item in validator_items] diff --git a/backend/app/core/validators/__init__.py b/backend/app/core/validators/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/models/validators/ban_list_safety_validator_config.py b/backend/app/core/validators/config/ban_list_safety_validator_config.py similarity index 81% rename from backend/app/models/validators/ban_list_safety_validator_config.py rename to backend/app/core/validators/config/ban_list_safety_validator_config.py index 4a853f0..260399e 100644 --- a/backend/app/models/validators/ban_list_safety_validator_config.py +++ b/backend/app/core/validators/config/ban_list_safety_validator_config.py @@ -2,7 +2,7 @@ from guardrails.hub import BanList -from app.models.base_validator_config import BaseValidatorConfig +from app.core.validators.config.base_validator_config import BaseValidatorConfig class BanListSafetyValidatorConfig(BaseValidatorConfig): type: Literal["ban_list"] diff --git a/backend/app/models/base_validator_config.py b/backend/app/core/validators/config/base_validator_config.py similarity index 100% rename from backend/app/models/base_validator_config.py rename to backend/app/core/validators/config/base_validator_config.py diff --git a/backend/app/models/validators/gender_assumption_bias_safety_validator_config.py b/backend/app/core/validators/config/gender_assumption_bias_safety_validator_config.py similarity index 86% rename from backend/app/models/validators/gender_assumption_bias_safety_validator_config.py rename to backend/app/core/validators/config/gender_assumption_bias_safety_validator_config.py index 116c281..7cd3687 100644 --- a/backend/app/models/validators/gender_assumption_bias_safety_validator_config.py +++ b/backend/app/core/validators/config/gender_assumption_bias_safety_validator_config.py @@ -1,8 +1,8 @@ from typing import List, Literal, Optional -from app.models.base_validator_config import BaseValidatorConfig from app.core.enum import BiasCategories from app.core.validators.gender_assumption_bias import GenderAssumptionBias +from app.core.validators.config.base_validator_config import BaseValidatorConfig class GenderAssumptionBiasSafetyValidatorConfig(BaseValidatorConfig): type: Literal["gender_assumption_bias"] diff --git a/backend/app/models/validators/lexical_slur_safety_validator_config.py b/backend/app/core/validators/config/lexical_slur_safety_validator_config.py similarity index 88% rename from backend/app/models/validators/lexical_slur_safety_validator_config.py rename to backend/app/core/validators/config/lexical_slur_safety_validator_config.py index 6378182..d86c0d6 100644 --- a/backend/app/models/validators/lexical_slur_safety_validator_config.py +++ b/backend/app/core/validators/config/lexical_slur_safety_validator_config.py @@ -2,7 +2,7 @@ from app.core.enum import SlurSeverity from app.core.validators.lexical_slur import LexicalSlur -from app.models.base_validator_config import BaseValidatorConfig +from app.core.validators.config.base_validator_config import BaseValidatorConfig class LexicalSlurSafetyValidatorConfig(BaseValidatorConfig): type: Literal["uli_slur_match"] diff --git a/backend/app/models/validators/pii_remover_safety_validator_config.py b/backend/app/core/validators/config/pii_remover_safety_validator_config.py similarity index 87% rename from backend/app/models/validators/pii_remover_safety_validator_config.py rename to backend/app/core/validators/config/pii_remover_safety_validator_config.py index d8d3a18..fc18fa5 100644 --- a/backend/app/models/validators/pii_remover_safety_validator_config.py +++ b/backend/app/core/validators/config/pii_remover_safety_validator_config.py @@ -1,8 +1,8 @@ from __future__ import annotations from typing import List, Literal, Optional -from app.models.base_validator_config import BaseValidatorConfig from app.core.validators.pii_remover import PIIRemover +from app.core.validators.config.base_validator_config import BaseValidatorConfig class PIIRemoverSafetyValidatorConfig(BaseValidatorConfig): type: Literal["pii_remover"] diff --git a/backend/app/crud/__init__.py b/backend/app/crud/__init__.py index c955a67..37bc215 100644 --- a/backend/app/crud/__init__.py +++ b/backend/app/crud/__init__.py @@ -1 +1 @@ -from app.crud.request_log import RequestLogCrud \ No newline at end of file +from app.crud.request_log_repo import RequestLogCrud \ No newline at end of file diff --git a/backend/app/crud/request_log.py b/backend/app/crud/request_log_repo.py similarity index 93% rename from backend/app/crud/request_log.py rename to backend/app/crud/request_log_repo.py index 74d5ece..9d3b1e5 100644 --- a/backend/app/crud/request_log.py +++ b/backend/app/crud/request_log_repo.py @@ -2,7 +2,7 @@ from sqlmodel import Session -from app.models.logging.request import RequestLog, RequestLogUpdate, RequestStatus +from app.models.logging.request_log_table import RequestLog, RequestLogUpdate, RequestStatus from app.utils import now class RequestLogCrud: diff --git a/backend/app/crud/validator_log.py b/backend/app/crud/validator_log_repo.py similarity index 85% rename from backend/app/crud/validator_log.py rename to backend/app/crud/validator_log_repo.py index 6eb1c1a..649d6aa 100644 --- a/backend/app/crud/validator_log.py +++ b/backend/app/crud/validator_log_repo.py @@ -2,7 +2,7 @@ from sqlmodel import Session -from app.models.logging.validator import ValidatorLog +from app.models.logging.validator_log_table import ValidatorLog from app.utils import now class ValidatorLogCrud: diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 5672003..2ba735f 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -1,4 +1,4 @@ from sqlmodel import SQLModel -from app.models.logging.request import RequestLog -from app.models.logging.validator import ValidatorLog +from app.models.logging.request_log_table import RequestLog +from app.models.logging.validator_log_table import ValidatorLog diff --git a/backend/app/models/logging/request.py b/backend/app/models/logging/request_log_table.py similarity index 100% rename from backend/app/models/logging/request.py rename to backend/app/models/logging/request_log_table.py diff --git a/backend/app/models/logging/validator.py b/backend/app/models/logging/validator_log_table.py similarity index 100% rename from backend/app/models/logging/validator.py rename to backend/app/models/logging/validator_log_table.py diff --git a/backend/app/schemas/__init__.py b/backend/app/schemas/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/models/guardrail_config.py b/backend/app/schemas/guardrail_config.py similarity index 64% rename from backend/app/models/guardrail_config.py rename to backend/app/schemas/guardrail_config.py index bfe36a6..6dfde4c 100644 --- a/backend/app/models/guardrail_config.py +++ b/backend/app/schemas/guardrail_config.py @@ -5,10 +5,10 @@ # todo this could be improved by having some auto-discovery mechanism inside # validators. We'll not have to list every new validator like this. -from app.models.validators.ban_list_safety_validator_config import BanListSafetyValidatorConfig -from app.models.validators.gender_assumption_bias_safety_validator_config import GenderAssumptionBiasSafetyValidatorConfig -from app.models.validators.lexical_slur_safety_validator_config import LexicalSlurSafetyValidatorConfig -from app.models.validators.pii_remover_safety_validator_config import PIIRemoverSafetyValidatorConfig +from app.core.validators.config.ban_list_safety_validator_config import BanListSafetyValidatorConfig +from app.core.validators.config.gender_assumption_bias_safety_validator_config import GenderAssumptionBiasSafetyValidatorConfig +from app.core.validators.config.lexical_slur_safety_validator_config import LexicalSlurSafetyValidatorConfig +from app.core.validators.config.pii_remover_safety_validator_config import PIIRemoverSafetyValidatorConfig ValidatorConfigItem = Annotated[ # future validators will come here From 4174e90b6d656deb72cec24e6af364563a4e33ee Mon Sep 17 00:00:00 2001 From: rkritika1508 Date: Wed, 4 Feb 2026 17:12:01 +0530 Subject: [PATCH 02/11] Added validator config management changes --- backend/app/api/main.py | 3 +- backend/app/api/routes/validator_configs.py | 143 ++++++++++++++++++ backend/app/core/constants.py | 9 ++ .../models/config/validator_config_table.py | 76 ++++++++++ backend/app/schemas/validator_config.py | 35 +++++ backend/app/utils.py | 15 ++ 6 files changed, 280 insertions(+), 1 deletion(-) create mode 100644 backend/app/api/routes/validator_configs.py create mode 100644 backend/app/models/config/validator_config_table.py create mode 100644 backend/app/schemas/validator_config.py diff --git a/backend/app/api/main.py b/backend/app/api/main.py index 0fc8026..bf78ade 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -1,10 +1,11 @@ from fastapi import APIRouter -from app.api.routes import utils, guardrails +from app.api.routes import utils, guardrails, validator_configs api_router = APIRouter() api_router.include_router(utils.router) api_router.include_router(guardrails.router) +api_router.include_router(validator_configs.router) # if settings.ENVIRONMENT == "local": # api_router.include_router(private.router) diff --git a/backend/app/api/routes/validator_configs.py b/backend/app/api/routes/validator_configs.py new file mode 100644 index 0000000..3fb440a --- /dev/null +++ b/backend/app/api/routes/validator_configs.py @@ -0,0 +1,143 @@ +from typing import List, Optional +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException +from sqlmodel import select + +from app.api.deps import AuthDep, SessionDep +from app.models.config.validator_config_table import ValidatorConfig +from app.schemas.validator_config import * +from app.utils import split_validator_payload + +router = APIRouter(prefix="/guardrails/validators/configs", tags=["validator configs"]) + + +@router.post( + "/", + response_model=ValidatorResponse + ) +async def create_validator( + payload: ValidatorCreate, + session: SessionDep, + org_id: int, + project_id: int, + _: AuthDep, +): + data = payload.model_dump() + base, config = split_validator_payload(data) + obj = ValidatorConfig( + org_id=org_id, + project_id=project_id, + config=config, + **base, + ) + + session.add(obj) + session.commit() + session.refresh(obj) + return obj + +@router.get( + "/", + response_model=List[ValidatorResponse] + ) +async def list_validators( + org_id: int, + project_id: int, + session: SessionDep, + _: AuthDep, + stage: Optional[Stage] = None, + type: Optional[ValidatorType] = None, +): + query = select(ValidatorConfig).where( + ValidatorConfig.org_id == org_id, + ValidatorConfig.project_id == project_id + ) + + if stage: + query = query.where(ValidatorConfig.stage == stage) + + if type: + query = query.where(ValidatorConfig.type == type) + + rows = session.exec(query).all() + return [flatten_validator(r) for r in rows] + + +@router.get( + "/{id}", + response_model=ValidatorResponse + ) +async def get_validator( + id: UUID, + org_id: int, + project_id: int, + session: SessionDep, + _: AuthDep, +): + obj = session.get(ValidatorConfig, id) + + if not obj or obj.org_id != org_id or obj.project_id != project_id: + raise HTTPException(404) + + return flatten_validator(obj) + + +@router.patch( + "/{id}", + response_model=ValidatorResponse + ) +async def update_validator( + id: UUID, + org_id: int, + project_id: int, + payload: ValidatorUpdate, + session: SessionDep, + _: AuthDep, +): + obj = session.get(ValidatorConfig, id) + + if not obj or obj.org_id != org_id or obj.project_id != project_id: + raise HTTPException(404) + + data = payload.model_dump(exclude_unset=True) + base, config = split_validator_payload(data) + + print("base", base) + print("config", config) + for k, v in base.items(): + setattr(obj, k, v) + + if config: + obj.config = {**(obj.config or {}), **config} + + session.add(obj) + session.commit() + session.refresh(obj) + + return flatten_validator(obj) + + +@router.delete("/{id}") +async def delete_validator( + id: UUID, + org_id: int, + project_id: int, + session: SessionDep, + _: AuthDep, +): + obj = session.get(ValidatorConfig, id) + + if not obj or obj.org_id != org_id or obj.project_id != project_id: + raise HTTPException(404) + + session.delete(obj) + session.commit() + + return {"success": True} + +def flatten_validator(row: ValidatorConfig) -> dict: + base = row.model_dump(exclude={"config"}) + + print(base) + return {**base, **(row.config or {})} diff --git a/backend/app/core/constants.py b/backend/app/core/constants.py index d6e3a7a..115ad21 100644 --- a/backend/app/core/constants.py +++ b/backend/app/core/constants.py @@ -6,3 +6,12 @@ SCORE = "score" REPHRASE_ON_FAIL_PREFIX = "Please rephrase the query without unsafe content." + +VALIDATOR_CONFIG_SYSTEM_FIELDS = { + "org_id", + "project_id", + "type", + "stage", + "on_fail_action", + "is_enabled", +} diff --git a/backend/app/models/config/validator_config_table.py b/backend/app/models/config/validator_config_table.py new file mode 100644 index 0000000..8e030e0 --- /dev/null +++ b/backend/app/models/config/validator_config_table.py @@ -0,0 +1,76 @@ +from datetime import datetime +from typing import Any, Optional +from uuid import UUID, uuid4 + +from sqlalchemy import Column +from sqlalchemy.dialects.postgresql import JSONB +from sqlmodel import Field as SQLField +from sqlmodel import SQLModel, Field + +from app.core.enum import GuardrailOnFail, Stage, ValidatorType +from app.utils import now + +class ValidatorConfig(SQLModel, table=True): + __tablename__ = "validator_config" + + id: UUID = Field( + default_factory=uuid4, + primary_key=True, + sa_column_kwargs={"comment": "Unique identifier for the validator configuration"}, + ) + + org_id: int = Field( + index=True, + sa_column_kwargs={"comment": "Identifier for the organization"}, + ) + + project_id: Optional[int] = Field( + default=None, + index=True, + sa_column_kwargs={"comment": "Identifier for the project"}, + ) + + type: ValidatorType = Field( + nullable=False, + sa_column_kwargs={"comment": "Type of the validator"}, + ) + + stage: Stage = Field( + nullable=False, + sa_column_kwargs={"comment": "Stage at which the validator is applied"}, + ) + + on_fail_action: GuardrailOnFail = Field( + default=GuardrailOnFail.Fix, + nullable=False, + sa_column_kwargs={"comment": "Action to take when the validator fails"}, + ) + + config: dict[str, Any] = SQLField( + default_factory=dict, + sa_column=Column( + JSONB, + nullable=False, + comment="Configuration for the validator", + ), + description=( + "Configuration for the validator" + ), + ) + + is_enabled: bool = Field( + default=True, + sa_column_kwargs={"comment": "Indicates if the validator is enabled"}, + ) + + created_at: datetime = Field( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the validator config was inserted"}, + ) + + updated_at: datetime = Field( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the validator config was last updated"}, + ) diff --git a/backend/app/schemas/validator_config.py b/backend/app/schemas/validator_config.py new file mode 100644 index 0000000..4452e24 --- /dev/null +++ b/backend/app/schemas/validator_config.py @@ -0,0 +1,35 @@ +from typing import Optional +from uuid import UUID, uuid4 + +from sqlmodel import SQLModel + +from app.core.enum import GuardrailOnFail, Stage, ValidatorType + + +class ValidatorBase(SQLModel): + model_config = {"extra": "allow"} + + type: ValidatorType + stage: Stage + on_fail_action: GuardrailOnFail + is_enabled: bool = True + + +class ValidatorCreate(ValidatorBase): + pass + + +class ValidatorUpdate(SQLModel): + # also allow extras for partial updates + model_config = {"extra": "allow"} + + type: Optional[ValidatorType] = None + stage: Optional[Stage] = None + on_fail_action: Optional[GuardrailOnFail] = None + is_enabled: Optional[bool] = None + + +class ValidatorResponse(ValidatorBase): + id: UUID + org_id: int + project_id: Optional[int] = None diff --git a/backend/app/utils.py b/backend/app/utils.py index 4e10f52..30543d5 100644 --- a/backend/app/utils.py +++ b/backend/app/utils.py @@ -3,6 +3,8 @@ from pydantic import BaseModel from typing import Any, Dict, Generic, Optional, TypeVar +from app.core.constants import VALIDATOR_CONFIG_SYSTEM_FIELDS as SYSTEM_FIELDS + logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -11,6 +13,19 @@ def now(): return datetime.now(timezone.utc).replace(tzinfo=None) +def split_validator_payload(data: dict): + base = {} + config = {} + + for k, v in data.items(): + if k in SYSTEM_FIELDS: + base[k] = v + else: + config[k] = v + + return base, config + + class APIResponse(BaseModel, Generic[T]): success: bool data: Optional[T] = None From f691bcb92b77ffc3849014cc49085b6ebb09f164 Mon Sep 17 00:00:00 2001 From: rkritika1508 Date: Thu, 5 Feb 2026 12:44:46 +0530 Subject: [PATCH 03/11] Updated validator config management API --- .../versions/003_added_validator_config.py | 50 +++++++++ backend/app/api/routes/validator_configs.py | 105 ++++++++++++------ .../models/config/validator_config_table.py | 16 ++- backend/app/schemas/validator_config.py | 4 +- 4 files changed, 131 insertions(+), 44 deletions(-) create mode 100644 backend/app/alembic/versions/003_added_validator_config.py diff --git a/backend/app/alembic/versions/003_added_validator_config.py b/backend/app/alembic/versions/003_added_validator_config.py new file mode 100644 index 0000000..2ee47d3 --- /dev/null +++ b/backend/app/alembic/versions/003_added_validator_config.py @@ -0,0 +1,50 @@ +"""Added validator_config table + +Revision ID: 003 +Revises: 001 +Create Date: 2026-02-05 09:42:54.128852 + +""" +from typing import Sequence, Union + +from alembic import op +from sqlalchemy.dialects import postgresql +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = '003' +down_revision: Union[str, Sequence[str], None] = "002" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table('validator_config', + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('org_id', sa.Integer(), nullable=False), + sa.Column('project_id', sa.Integer(), nullable=False), + sa.Column('type', sa.String(), nullable=False), + sa.Column('stage', sa.String(), nullable=False), + sa.Column('on_fail_action', sa.String(), nullable=False), + sa.Column( + "config", + postgresql.JSONB(astext_type=sa.Text()), + nullable=False, + server_default=sa.text("'{}'::jsonb"), + ), + sa.Column('is_enabled', sa.Boolean(), nullable=False, server_default=sa.true()), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('org_id', 'project_id', 'type', 'stage', name='uq_validator_identity') + ) + + op.create_index("idx_validator_org", "validator_config", ["org_id"]) + op.create_index("idx_validator_project", "validator_config", ["project_id"]) + op.create_index("idx_validator_type", "validator_config", ["type"]) + op.create_index("idx_validator_stage", "validator_config", ["stage"]) + + +def downgrade() -> None: + op.drop_table('validator_config') diff --git a/backend/app/api/routes/validator_configs.py b/backend/app/api/routes/validator_configs.py index 3fb440a..e35ce1c 100644 --- a/backend/app/api/routes/validator_configs.py +++ b/backend/app/api/routes/validator_configs.py @@ -1,7 +1,8 @@ from typing import List, Optional from uuid import UUID -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, HTTPException +from sqlalchemy.exc import IntegrityError from sqlmodel import select from app.api.deps import AuthDep, SessionDep @@ -33,9 +34,18 @@ async def create_validator( ) session.add(obj) - session.commit() + + try: + session.commit() + except IntegrityError: + session.rollback() + raise HTTPException( + status_code=400, + detail="Validator already exists for this type and stage", + ) + session.refresh(obj) - return obj + return flatten_validator(obj) @router.get( "/", @@ -75,11 +85,7 @@ async def get_validator( session: SessionDep, _: AuthDep, ): - obj = session.get(ValidatorConfig, id) - - if not obj or obj.org_id != org_id or obj.project_id != project_id: - raise HTTPException(404) - + obj = get_validator_or_404(id, org_id, project_id, session) return flatten_validator(obj) @@ -95,27 +101,13 @@ async def update_validator( session: SessionDep, _: AuthDep, ): - obj = session.get(ValidatorConfig, id) - - if not obj or obj.org_id != org_id or obj.project_id != project_id: - raise HTTPException(404) - - data = payload.model_dump(exclude_unset=True) - base, config = split_validator_payload(data) - - print("base", base) - print("config", config) - for k, v in base.items(): - setattr(obj, k, v) - - if config: - obj.config = {**(obj.config or {}), **config} - - session.add(obj) - session.commit() - session.refresh(obj) - - return flatten_validator(obj) + obj = get_validator_or_404(id, org_id, project_id, session) + updated_obj = update_validator_config( + obj, + payload.model_dump(exclude_unset=True), + session + ) + return flatten_validator(updated_obj) @router.delete("/{id}") @@ -126,18 +118,57 @@ async def delete_validator( session: SessionDep, _: AuthDep, ): - obj = session.get(ValidatorConfig, id) - - if not obj or obj.org_id != org_id or obj.project_id != project_id: - raise HTTPException(404) - + obj = get_validator_or_404(id, org_id, project_id, session) session.delete(obj) session.commit() - return {"success": True} def flatten_validator(row: ValidatorConfig) -> dict: + """ + Flatten validator config: combines base fields with config dict. + Returns a dict with all fields including config extras. + """ base = row.model_dump(exclude={"config"}) + flattened = {**base, **(row.config or {})} + print("FLATTENED:", flattened) + return flattened - print(base) - return {**base, **(row.config or {})} + +def get_validator_or_404( + id: UUID, + org_id: int, + project_id: int, + session: SessionDep, +) -> ValidatorConfig: + """Fetch validator by id, org_id, and project_id, or raise 404.""" + obj = session.query(ValidatorConfig).filter( + ValidatorConfig.id == id, + ValidatorConfig.org_id == org_id, + ValidatorConfig.project_id == project_id + ).first() + + if not obj: + raise HTTPException(status_code=404, detail="Validator not found") + + return obj + + +def update_validator_config( + obj: ValidatorConfig, + update_data: dict, + session: SessionDep, +) -> ValidatorConfig: + """Update validator config fields and return the updated object.""" + base, config = split_validator_payload(update_data) + + for k, v in base.items(): + setattr(obj, k, v) + + if config: + obj.config = {**(obj.config or {}), **config} + + session.add(obj) + session.commit() + session.refresh(obj) + + return obj diff --git a/backend/app/models/config/validator_config_table.py b/backend/app/models/config/validator_config_table.py index 8e030e0..68d57cf 100644 --- a/backend/app/models/config/validator_config_table.py +++ b/backend/app/models/config/validator_config_table.py @@ -1,8 +1,8 @@ from datetime import datetime -from typing import Any, Optional +from typing import Any from uuid import UUID, uuid4 -from sqlalchemy import Column +from sqlalchemy import Column, UniqueConstraint from sqlalchemy.dialects.postgresql import JSONB from sqlmodel import Field as SQLField from sqlmodel import SQLModel, Field @@ -24,9 +24,8 @@ class ValidatorConfig(SQLModel, table=True): sa_column_kwargs={"comment": "Identifier for the organization"}, ) - project_id: Optional[int] = Field( - default=None, - index=True, + project_id: int = Field( + nullable=False, sa_column_kwargs={"comment": "Identifier for the project"}, ) @@ -74,3 +73,10 @@ class ValidatorConfig(SQLModel, table=True): nullable=False, sa_column_kwargs={"comment": "Timestamp when the validator config was last updated"}, ) + + __table_args__ = ( + UniqueConstraint( + "org_id", "project_id", "type", "stage", + name="uq_validator_identity" + ), + ) \ No newline at end of file diff --git a/backend/app/schemas/validator_config.py b/backend/app/schemas/validator_config.py index 4452e24..c8a2ae1 100644 --- a/backend/app/schemas/validator_config.py +++ b/backend/app/schemas/validator_config.py @@ -1,5 +1,5 @@ from typing import Optional -from uuid import UUID, uuid4 +from uuid import UUID from sqlmodel import SQLModel @@ -32,4 +32,4 @@ class ValidatorUpdate(SQLModel): class ValidatorResponse(ValidatorBase): id: UUID org_id: int - project_id: Optional[int] = None + project_id: int From cfe8a849e14dbbe1b9cd0817b810254fff3b33a4 Mon Sep 17 00:00:00 2001 From: rkritika1508 Date: Thu, 5 Feb 2026 15:57:28 +0530 Subject: [PATCH 04/11] Updated tests --- backend/app/tests/conftest.py | 67 ++-- .../tests/test_guardrails_api_integration.py | 4 +- backend/app/tests/test_validator_configs.py | 173 ++++++++++ .../test_validator_configs_integration.py | 302 ++++++++++++++++++ backend/app/tests/utils/constants.py | 2 +- 5 files changed, 508 insertions(+), 40 deletions(-) create mode 100644 backend/app/tests/test_validator_configs.py create mode 100644 backend/app/tests/test_validator_configs_integration.py diff --git a/backend/app/tests/conftest.py b/backend/app/tests/conftest.py index f76b33d..2ff34d0 100644 --- a/backend/app/tests/conftest.py +++ b/backend/app/tests/conftest.py @@ -1,50 +1,43 @@ +# conftest.py import os -from unittest.mock import MagicMock +os.environ["ENVIRONMENT"] = "testing" import pytest from fastapi.testclient import TestClient +from sqlmodel import Session, create_engine, SQLModel -# MUST be set before app import -os.environ["ENVIRONMENT"] = "testing" - -from app.api.deps import SessionDep, verify_bearer_token -from app.api.routes import guardrails from app.main import app +from app.api.deps import SessionDep, verify_bearer_token +from app.core.config import settings + +test_engine = create_engine( + str(settings.SQLALCHEMY_DATABASE_URI), + echo=False, + pool_pre_ping=True, +) + +def override_session(): + with Session(test_engine) as session: + yield session + +@pytest.fixture(scope="session", autouse=True) +def setup_test_db(): + SQLModel.metadata.create_all(test_engine) + yield + SQLModel.metadata.drop_all(test_engine) + +@pytest.fixture(scope="function", autouse=True) +def clean_db(): + with Session(test_engine) as session: + for table in reversed(SQLModel.metadata.sorted_tables): + session.exec(table.delete()) + session.commit() @pytest.fixture(scope="function", autouse=True) -def override_dependencies(monkeypatch): - """ - Override ALL external dependencies: - - Auth - - DB session - - CRUDs - """ - - # ---- Auth override ---- +def override_dependencies(): app.dependency_overrides[verify_bearer_token] = lambda: True - # ---- DB session override ---- - mock_session = MagicMock() - app.dependency_overrides[SessionDep] = lambda: mock_session - - # ---- CRUD override ---- - mock_request_log_crud = MagicMock() - mock_request_log_crud.create.return_value = MagicMock(id=1) - mock_request_log_crud.update.return_value = None - - mock_validator_log_crud = MagicMock() - mock_validator_log_crud.create.return_value = None - - monkeypatch.setattr( - guardrails, - "RequestLogCrud", - lambda session: mock_request_log_crud, - ) - monkeypatch.setattr( - guardrails, - "ValidatorLogCrud", - lambda session: mock_validator_log_crud, - ) + app.dependency_overrides[SessionDep] = override_session yield diff --git a/backend/app/tests/test_guardrails_api_integration.py b/backend/app/tests/test_guardrails_api_integration.py index 48ca1e8..485b532 100644 --- a/backend/app/tests/test_guardrails_api_integration.py +++ b/backend/app/tests/test_guardrails_api_integration.py @@ -70,7 +70,7 @@ def test_input_guardrails_with_lexical_slur(integration_client): body = response.json() assert body["success"] is True - assert body["data"][SAFE_TEXT_FIELD] == "this sentence contains [REDACTED_SLUR]" + assert body["data"][SAFE_TEXT_FIELD] == "this sentence contains [REDACTED_SLUR]." def test_input_guardrails_with_lexical_slur_clean_text(integration_client): @@ -123,7 +123,7 @@ def test_input_guardrails_with_multiple_validators(integration_client): assert body["success"] is True assert ( body["data"][SAFE_TEXT_FIELD] - == "this sentence contains [REDACTED_SLUR] cause i want a s done to kill the female foetus" + == "this sentence contains [REDACTED_SLUR] cause i want a s done to kill the female foetus." ) diff --git a/backend/app/tests/test_validator_configs.py b/backend/app/tests/test_validator_configs.py new file mode 100644 index 0000000..08c82ac --- /dev/null +++ b/backend/app/tests/test_validator_configs.py @@ -0,0 +1,173 @@ +import uuid +from unittest.mock import MagicMock, patch + +import pytest +from sqlmodel import Session, delete + +from app.api.routes.validator_configs import ( + flatten_validator, + get_validator_or_404, + update_validator_config, +) +from app.core.enum import GuardrailOnFail, Stage, ValidatorType +from app.core.db import engine +from app.models.config.validator_config_table import ValidatorConfig + +# Test data constants +TEST_ORG_ID = 1 +TEST_PROJECT_ID = 1 +TEST_VALIDATOR_ID = uuid.uuid4() +TEST_TYPE = ValidatorType.LexicalSlur +TEST_STAGE = Stage.Input +TEST_ON_FAIL = GuardrailOnFail.Fix + + +@pytest.fixture +def clear_database(): + """Clear ValidatorConfig table before and after each test.""" + with Session(engine) as session: + session.exec(delete(ValidatorConfig)) + session.commit() + yield + with Session(engine) as session: + session.exec(delete(ValidatorConfig)) + session.commit() + + +@pytest.fixture +def mock_session(): + """Create a mock session for database operations.""" + return MagicMock(spec=Session) + + +@pytest.fixture +def sample_validator(): + """Create a sample validator config for testing.""" + return ValidatorConfig( + id=TEST_VALIDATOR_ID, + org_id=TEST_ORG_ID, + project_id=TEST_PROJECT_ID, + type=TEST_TYPE, + stage=TEST_STAGE, + on_fail_action=TEST_ON_FAIL, + is_enabled=True, + config={"severity": "all", "languages": ["en", "hi"]}, + ) + + +class TestFlattenValidator: + """Tests for flatten_validator helper function.""" + + def test_flatten_validator_includes_config_fields(self, sample_validator): + """Test that flatten_validator includes config fields in output.""" + result = flatten_validator(sample_validator) + + assert result["id"] == TEST_VALIDATOR_ID + assert result["org_id"] == TEST_ORG_ID + assert result["project_id"] == TEST_PROJECT_ID + assert result["type"] == TEST_TYPE + assert result["severity"] == "all" + assert result["languages"] == ["en", "hi"] + + def test_flatten_validator_with_empty_config(self): + """Test flatten_validator with empty config dict.""" + validator = ValidatorConfig( + id=TEST_VALIDATOR_ID, + org_id=TEST_ORG_ID, + project_id=TEST_PROJECT_ID, + type=TEST_TYPE, + stage=TEST_STAGE, + on_fail_action=TEST_ON_FAIL, + is_enabled=True, + config={}, + ) + + result = flatten_validator(validator) + + assert result["id"] == TEST_VALIDATOR_ID + assert "severity" not in result + # Base fields: id, org_id, project_id, type, stage, on_fail_action, is_enabled, created_at, updated_at + assert len(result) == 9 + + def test_flatten_validator_with_none_config(self): + """Test flatten_validator with None config.""" + validator = ValidatorConfig( + id=TEST_VALIDATOR_ID, + org_id=TEST_ORG_ID, + project_id=TEST_PROJECT_ID, + type=TEST_TYPE, + stage=TEST_STAGE, + on_fail_action=TEST_ON_FAIL, + is_enabled=True, + config=None, + ) + + result = flatten_validator(validator) + + assert result["id"] == TEST_VALIDATOR_ID + assert "severity" not in result + + +class TestGetValidatorOr404: + """Tests for get_validator_or_404 helper function.""" + + def test_get_validator_success(self, sample_validator, mock_session): + """Test successful validator retrieval.""" + mock_session.query.return_value.filter.return_value.first.return_value = ( + sample_validator + ) + + result = get_validator_or_404( + TEST_VALIDATOR_ID, TEST_ORG_ID, TEST_PROJECT_ID, mock_session + ) + + assert result == sample_validator + mock_session.query.assert_called_once_with(ValidatorConfig) + + def test_get_validator_not_found(self, mock_session): + """Test validator not found raises 404.""" + mock_session.query.return_value.filter.return_value.first.return_value = None + + with pytest.raises(Exception) as exc_info: + get_validator_or_404(TEST_VALIDATOR_ID, TEST_ORG_ID, TEST_PROJECT_ID, mock_session) + + assert "404" in str(exc_info.value) + + +class TestUpdateValidatorConfig: + """Tests for update_validator_config helper function.""" + + def test_update_validator_config_base_fields(self, sample_validator, mock_session): + """Test updating base validator fields.""" + update_data = { + "type": ValidatorType.PIIRemover, + "on_fail_action": GuardrailOnFail.Exception, + } + + result = update_validator_config(sample_validator, update_data, mock_session) + + assert result.type == ValidatorType.PIIRemover + assert result.on_fail_action == GuardrailOnFail.Exception + mock_session.commit.assert_called_once() + mock_session.refresh.assert_called_once() + + def test_update_validator_config_extra_fields(self, sample_validator, mock_session): + """Test updating extra config fields.""" + update_data = {"severity": "high", "new_field": "new_value"} + + result = update_validator_config(sample_validator, update_data, mock_session) + + assert result.config["severity"] == "high" + assert result.config["new_field"] == "new_value" + assert result.config["languages"] == ["en", "hi"] # Original values preserved + + def test_update_validator_merges_config(self, sample_validator, mock_session): + """Test that updating config merges with existing config.""" + sample_validator.config = {"severity": "all", "languages": ["en"]} + update_data = {"languages": ["en", "hi", "mr"]} + + result = update_validator_config(sample_validator, update_data, mock_session) + + assert result.config["languages"] == ["en", "hi", "mr"] + assert result.config["severity"] == "all" + diff --git a/backend/app/tests/test_validator_configs_integration.py b/backend/app/tests/test_validator_configs_integration.py new file mode 100644 index 0000000..effffd5 --- /dev/null +++ b/backend/app/tests/test_validator_configs_integration.py @@ -0,0 +1,302 @@ +import uuid + +import pytest +from sqlalchemy.exc import OperationalError +from sqlmodel import Session, delete + +from app.core.db import engine +from app.models.config.validator_config_table import ValidatorConfig + +pytestmark = pytest.mark.integration + +# Test data constants +TEST_ORG_ID = 1 +TEST_PROJECT_ID = 1 +BASE_URL = "/api/v1/guardrails/validators/configs/" +DEFAULT_QUERY_PARAMS = f"?org_id={TEST_ORG_ID}&project_id={TEST_PROJECT_ID}" + +VALIDATOR_PAYLOADS = { + "lexical_slur": { + "type": "uli_slur_match", + "stage": "input", + "on_fail_action": "fix", + "severity": "all", + "languages": ["en", "hi"], + }, + "pii_remover_input": { + "type": "pii_remover", + "stage": "input", + "on_fail_action": "fix", + }, + "pii_remover_output": { + "type": "pii_remover", + "stage": "output", + "on_fail_action": "fix", + }, + "minimal": { + "type": "uli_slur_match", + "stage": "input", + "on_fail_action": "fix", + }, +} + + +@pytest.fixture +def clear_database(): + """Clear ValidatorConfig table before and after each test.""" + with Session(engine) as session: + session.exec(delete(ValidatorConfig)) + session.commit() + + yield + + with Session(engine) as session: + session.exec(delete(ValidatorConfig)) + session.commit() + + +class BaseValidatorTest: + """Base class with helper methods for validator tests.""" + + def create_validator(self, client, payload_key="minimal", **kwargs): + """Helper to create a validator.""" + payload = {**VALIDATOR_PAYLOADS[payload_key], **kwargs} + return client.post(f"{BASE_URL}{DEFAULT_QUERY_PARAMS}", json=payload) + + def get_validator(self, client, validator_id): + """Helper to get a specific validator.""" + return client.get(f"{BASE_URL}{validator_id}/{DEFAULT_QUERY_PARAMS}") + + def list_validators(self, client, **query_params): + """Helper to list validators with optional filters.""" + params_str = f"?org_id={TEST_ORG_ID}&project_id={TEST_PROJECT_ID}" + if query_params: + params_str += "&" + "&".join(f"{k}={v}" for k, v in query_params.items()) + return client.get(f"{BASE_URL}{params_str}") + + def update_validator(self, client, validator_id, payload): + """Helper to update a validator.""" + return client.patch(f"{BASE_URL}{validator_id}/{DEFAULT_QUERY_PARAMS}", json=payload) + + def delete_validator(self, client, validator_id): + """Helper to delete a validator.""" + return client.delete(f"{BASE_URL}{validator_id}/{DEFAULT_QUERY_PARAMS}") + + +class TestCreateValidator(BaseValidatorTest): + """Tests for POST /guardrails/validators/configs endpoint.""" + + def test_create_validator_success(self, integration_client, clear_database): + """Test successful validator creation.""" + response = self.create_validator(integration_client, "lexical_slur") + + assert response.status_code == 200 + data = response.json() + assert data["type"] == "uli_slur_match" + assert data["stage"] == "input" + assert data["severity"] == "all" + assert data["languages"] == ["en", "hi"] + assert "id" in data + + def test_create_validator_duplicate_raises_400(self, integration_client, clear_database): + """Test that creating duplicate validator raises 400.""" + # First request should succeed + response1 = self.create_validator(integration_client, "minimal") + assert response1.status_code == 200 + + # Second request with same unique keys should fail + response2 = self.create_validator(integration_client, "minimal") + assert response2.status_code == 400 + + def test_create_validator_missing_required_fields(self, integration_client, clear_database): + """Test that missing required fields returns validation error.""" + response = integration_client.post( + f"{BASE_URL}{DEFAULT_QUERY_PARAMS}", + json={"type": "uli_slur_match"}, + ) + + assert response.status_code == 422 + + +class TestListValidators(BaseValidatorTest): + """Tests for GET /guardrails/validators/configs endpoint.""" + + def test_list_validators_success(self, integration_client, clear_database): + """Test successful validator listing.""" + # Create validators first + self.create_validator(integration_client, "lexical_slur") + self.create_validator(integration_client, "pii_remover_input") + + response = self.list_validators(integration_client) + + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + + def test_list_validators_filter_by_stage(self, integration_client, clear_database): + """Test filtering validators by stage.""" + self.create_validator(integration_client, "lexical_slur") + self.create_validator(integration_client, "pii_remover_output") + + response = self.list_validators(integration_client, stage="input") + + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["stage"] == "input" + + def test_list_validators_filter_by_type(self, integration_client, clear_database): + """Test filtering validators by type.""" + self.create_validator(integration_client, "lexical_slur") + self.create_validator(integration_client, "pii_remover_input") + + response = self.list_validators(integration_client, type="pii_remover") + + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["type"] == "pii_remover" + + def test_list_validators_empty(self, integration_client, clear_database): + """Test listing validators when none exist.""" + response = integration_client.get( + f"{BASE_URL}?org_id=999&project_id=999", + ) + + assert response.status_code == 200 + data = response.json() + assert len(data) == 0 + + +class TestGetValidator(BaseValidatorTest): + """Tests for GET /guardrails/validators/configs/{id} endpoint.""" + + def test_get_validator_success(self, integration_client, clear_database): + """Test successful validator retrieval.""" + # Create a validator + create_response = self.create_validator( + integration_client, "lexical_slur", severity="all" + ) + validator_id = create_response.json()["id"] + + # Retrieve it + response = self.get_validator(integration_client, validator_id) + + assert response.status_code == 200 + data = response.json() + assert data["id"] == validator_id + assert data["severity"] == "all" + + def test_get_validator_not_found(self, integration_client, clear_database): + """Test retrieving non-existent validator returns 404.""" + fake_id = uuid.uuid4() + response = self.get_validator(integration_client, fake_id) + + assert response.status_code == 404 + + def test_get_validator_wrong_org(self, integration_client, clear_database): + """Test that accessing validator from different org returns 404.""" + # Create a validator for org 1 + create_response = self.create_validator(integration_client, "minimal") + validator_id = create_response.json()["id"] + + # Try to access it as different org + response = integration_client.get( + f"{BASE_URL}{validator_id}/?org_id=2&project_id=1", + ) + + assert response.status_code == 404 + + +class TestUpdateValidator(BaseValidatorTest): + """Tests for PATCH /guardrails/validators/configs/{id} endpoint.""" + + def test_update_validator_success(self, integration_client, clear_database): + """Test successful validator update.""" + # Create a validator + create_response = self.create_validator( + integration_client, "lexical_slur", severity="all" + ) + validator_id = create_response.json()["id"] + + # Update it + update_payload = {"on_fail_action": "exception", "severity": "high"} + response = self.update_validator(integration_client, validator_id, update_payload) + + assert response.status_code == 200 + data = response.json() + assert data["on_fail_action"] == "exception" + assert data["severity"] == "high" + + def test_update_validator_partial(self, integration_client, clear_database): + """Test partial update preserves original fields.""" + # Create a validator + create_response = self.create_validator( + integration_client, + "lexical_slur", + severity="all", + languages=["en", "hi"], + ) + validator_id = create_response.json()["id"] + + # Update only one field + update_payload = {"severity": "low"} + response = self.update_validator(integration_client, validator_id, update_payload) + + assert response.status_code == 200 + data = response.json() + assert data["severity"] == "low" + assert data["languages"] == ["en", "hi"] # Original preserved + + def test_update_validator_not_found(self, integration_client, clear_database): + """Test updating non-existent validator returns 404.""" + fake_id = uuid.uuid4() + update_payload = {"severity": "low"} + + response = self.update_validator(integration_client, fake_id, update_payload) + + assert response.status_code == 404 + + +class TestDeleteValidator(BaseValidatorTest): + """Tests for DELETE /guardrails/validators/configs/{id} endpoint.""" + + def test_delete_validator_success(self, integration_client, clear_database): + """Test successful validator deletion.""" + # Create a validator + create_response = self.create_validator(integration_client, "minimal") + validator_id = create_response.json()["id"] + + # Delete it + response = self.delete_validator(integration_client, validator_id) + + assert response.status_code == 200 + assert response.json()["success"] is True + + # Verify it's deleted + get_response = self.get_validator(integration_client, validator_id) + assert get_response.status_code == 404 + + def test_delete_validator_not_found(self, integration_client, clear_database): + """Test deleting non-existent validator returns 404.""" + fake_id = uuid.uuid4() + response = self.delete_validator(integration_client, fake_id) + + assert response.status_code == 404 + + def test_delete_validator_wrong_org(self, integration_client, clear_database): + """Test that deleting validator from different org returns 404.""" + # Create a validator for org 1 + create_response = self.create_validator(integration_client, "minimal") + validator_id = create_response.json()["id"] + + # Try to delete it as different org + response = integration_client.delete( + f"{BASE_URL}{validator_id}/?org_id=2&project_id=1", + ) + + assert response.status_code == 404 + + # Verify original is still there + get_response = self.get_validator(integration_client, validator_id) + assert get_response.status_code == 200 diff --git a/backend/app/tests/utils/constants.py b/backend/app/tests/utils/constants.py index 568bc19..e642c67 100644 --- a/backend/app/tests/utils/constants.py +++ b/backend/app/tests/utils/constants.py @@ -1,2 +1,2 @@ -VALIDATE_API_PATH = "/api/v1/guardrails/validate/" +VALIDATE_API_PATH = "/api/v1/guardrails/" SAFE_TEXT_FIELD = "safe_text" From 3a0fa811e83fed86305927d7d0539b8de4e84bde Mon Sep 17 00:00:00 2001 From: rkritika1508 Date: Thu, 5 Feb 2026 16:47:54 +0530 Subject: [PATCH 05/11] resolved comments --- backend/app/api/routes/validator_configs.py | 118 +++-------------- backend/app/crud/validator_config_crud.py | 109 ++++++++++++++++ backend/app/tests/test_validator_configs.py | 132 ++++++++------------ 3 files changed, 180 insertions(+), 179 deletions(-) create mode 100644 backend/app/crud/validator_config_crud.py diff --git a/backend/app/api/routes/validator_configs.py b/backend/app/api/routes/validator_configs.py index e35ce1c..2829a01 100644 --- a/backend/app/api/routes/validator_configs.py +++ b/backend/app/api/routes/validator_configs.py @@ -1,20 +1,21 @@ from typing import List, Optional from uuid import UUID -from fastapi import APIRouter, HTTPException -from sqlalchemy.exc import IntegrityError -from sqlmodel import select +from fastapi import APIRouter from app.api.deps import AuthDep, SessionDep -from app.models.config.validator_config_table import ValidatorConfig from app.schemas.validator_config import * -from app.utils import split_validator_payload +from app.crud.validator_config_crud import validator_config_crud -router = APIRouter(prefix="/guardrails/validators/configs", tags=["validator configs"]) + +router = APIRouter( + prefix="/guardrails/validators/configs", + tags=["validator configs"], +) @router.post( - "/", + "/", response_model=ValidatorResponse ) async def create_validator( @@ -24,28 +25,7 @@ async def create_validator( project_id: int, _: AuthDep, ): - data = payload.model_dump() - base, config = split_validator_payload(data) - obj = ValidatorConfig( - org_id=org_id, - project_id=project_id, - config=config, - **base, - ) - - session.add(obj) - - try: - session.commit() - except IntegrityError: - session.rollback() - raise HTTPException( - status_code=400, - detail="Validator already exists for this type and stage", - ) - - session.refresh(obj) - return flatten_validator(obj) + return validator_config_crud.create(session, org_id, project_id, payload) @router.get( "/", @@ -59,19 +39,7 @@ async def list_validators( stage: Optional[Stage] = None, type: Optional[ValidatorType] = None, ): - query = select(ValidatorConfig).where( - ValidatorConfig.org_id == org_id, - ValidatorConfig.project_id == project_id - ) - - if stage: - query = query.where(ValidatorConfig.stage == stage) - - if type: - query = query.where(ValidatorConfig.type == type) - - rows = session.exec(query).all() - return [flatten_validator(r) for r in rows] + return validator_config_crud.list(session, org_id, project_id, stage, type) @router.get( @@ -85,8 +53,8 @@ async def get_validator( session: SessionDep, _: AuthDep, ): - obj = get_validator_or_404(id, org_id, project_id, session) - return flatten_validator(obj) + obj = validator_config_crud.get_or_404(session, id, org_id, project_id) + return validator_config_crud._flatten(obj) @router.patch( @@ -101,13 +69,12 @@ async def update_validator( session: SessionDep, _: AuthDep, ): - obj = get_validator_or_404(id, org_id, project_id, session) - updated_obj = update_validator_config( + obj = validator_config_crud.get_or_404(session, id, org_id, project_id) + return validator_config_crud.update( + session, obj, payload.model_dump(exclude_unset=True), - session ) - return flatten_validator(updated_obj) @router.delete("/{id}") @@ -118,57 +85,6 @@ async def delete_validator( session: SessionDep, _: AuthDep, ): - obj = get_validator_or_404(id, org_id, project_id, session) - session.delete(obj) - session.commit() + obj = validator_config_crud.get_or_404(session, id, org_id, project_id) + validator_config_crud.delete(session, obj) return {"success": True} - -def flatten_validator(row: ValidatorConfig) -> dict: - """ - Flatten validator config: combines base fields with config dict. - Returns a dict with all fields including config extras. - """ - base = row.model_dump(exclude={"config"}) - flattened = {**base, **(row.config or {})} - print("FLATTENED:", flattened) - return flattened - - -def get_validator_or_404( - id: UUID, - org_id: int, - project_id: int, - session: SessionDep, -) -> ValidatorConfig: - """Fetch validator by id, org_id, and project_id, or raise 404.""" - obj = session.query(ValidatorConfig).filter( - ValidatorConfig.id == id, - ValidatorConfig.org_id == org_id, - ValidatorConfig.project_id == project_id - ).first() - - if not obj: - raise HTTPException(status_code=404, detail="Validator not found") - - return obj - - -def update_validator_config( - obj: ValidatorConfig, - update_data: dict, - session: SessionDep, -) -> ValidatorConfig: - """Update validator config fields and return the updated object.""" - base, config = split_validator_payload(update_data) - - for k, v in base.items(): - setattr(obj, k, v) - - if config: - obj.config = {**(obj.config or {}), **config} - - session.add(obj) - session.commit() - session.refresh(obj) - - return obj diff --git a/backend/app/crud/validator_config_crud.py b/backend/app/crud/validator_config_crud.py new file mode 100644 index 0000000..803dcf3 --- /dev/null +++ b/backend/app/crud/validator_config_crud.py @@ -0,0 +1,109 @@ +from typing import List, Optional +from uuid import UUID + +from fastapi import HTTPException +from sqlalchemy.exc import IntegrityError +from sqlmodel import Session, select + +from app.models.config.validator_config_table import ValidatorConfig +from app.schemas.validator_config import Stage, ValidatorType +from app.utils import split_validator_payload + + +class ValidatorConfigCrud: + def create( + self, + session: Session, + org_id: int, + project_id: int, + payload + ): + data = payload.model_dump() + base, config = split_validator_payload(data) + + obj = ValidatorConfig( + org_id=org_id, + project_id=project_id, + config=config, + **base, + ) + + session.add(obj) + + try: + session.commit() + except IntegrityError: + session.rollback() + raise HTTPException( + 400, + "Validator already exists for this type and stage", + ) + + session.refresh(obj) + return self._flatten(obj) + + def list( + self, + session: Session, + org_id: int, + project_id: int, + stage: Optional[Stage] = None, + type: Optional[ValidatorType] = None, + ) -> List[dict]: + query = select(ValidatorConfig).where( + ValidatorConfig.org_id == org_id, + ValidatorConfig.project_id == project_id, + ) + + if stage: + query = query.where(ValidatorConfig.stage == stage) + + if type: + query = query.where(ValidatorConfig.type == type) + + rows = session.exec(query).all() + return [self._flatten(r) for r in rows] + + def get_or_404( + self, + session: Session, + id: UUID, + org_id: int, + project_id: int, + ) -> ValidatorConfig: + obj = session.get(ValidatorConfig, id) + + if not obj or obj.org_id != org_id or obj.project_id != project_id: + raise HTTPException(404, "Validator not found") + + return obj + + def update( + self, + session: Session, + obj: ValidatorConfig, + update_data: dict + ): + base, config = split_validator_payload(update_data) + + for k, v in base.items(): + setattr(obj, k, v) + + if config: + obj.config = {**(obj.config or {}), **config} + + session.commit() + session.refresh(obj) + + return self._flatten(obj) + + def delete(self, session: Session, obj: ValidatorConfig): + session.delete(obj) + session.commit() + + def _flatten(self, row: ValidatorConfig) -> dict: + base = row.model_dump(exclude={"config"}) + return {**base, **(row.config or {})} + + +validator_config_crud = ValidatorConfigCrud() diff --git a/backend/app/tests/test_validator_configs.py b/backend/app/tests/test_validator_configs.py index 08c82ac..51a0739 100644 --- a/backend/app/tests/test_validator_configs.py +++ b/backend/app/tests/test_validator_configs.py @@ -1,14 +1,10 @@ import uuid -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest from sqlmodel import Session, delete -from app.api.routes.validator_configs import ( - flatten_validator, - get_validator_or_404, - update_validator_config, -) +from app.crud.validator_config_crud import validator_config_crud from app.core.enum import GuardrailOnFail, Stage, ValidatorType from app.core.db import engine from app.models.config.validator_config_table import ValidatorConfig @@ -55,42 +51,15 @@ def sample_validator(): ) -class TestFlattenValidator: - """Tests for flatten_validator helper function.""" +class TestFlatten: + def test_flatten_includes_config_fields(self, sample_validator): + result = validator_config_crud._flatten(sample_validator) - def test_flatten_validator_includes_config_fields(self, sample_validator): - """Test that flatten_validator includes config fields in output.""" - result = flatten_validator(sample_validator) - - assert result["id"] == TEST_VALIDATOR_ID - assert result["org_id"] == TEST_ORG_ID - assert result["project_id"] == TEST_PROJECT_ID - assert result["type"] == TEST_TYPE assert result["severity"] == "all" assert result["languages"] == ["en", "hi"] - - def test_flatten_validator_with_empty_config(self): - """Test flatten_validator with empty config dict.""" - validator = ValidatorConfig( - id=TEST_VALIDATOR_ID, - org_id=TEST_ORG_ID, - project_id=TEST_PROJECT_ID, - type=TEST_TYPE, - stage=TEST_STAGE, - on_fail_action=TEST_ON_FAIL, - is_enabled=True, - config={}, - ) - - result = flatten_validator(validator) - assert result["id"] == TEST_VALIDATOR_ID - assert "severity" not in result - # Base fields: id, org_id, project_id, type, stage, on_fail_action, is_enabled, created_at, updated_at - assert len(result) == 9 - def test_flatten_validator_with_none_config(self): - """Test flatten_validator with None config.""" + def test_flatten_empty_config(self): validator = ValidatorConfig( id=TEST_VALIDATOR_ID, org_id=TEST_ORG_ID, @@ -99,75 +68,82 @@ def test_flatten_validator_with_none_config(self): stage=TEST_STAGE, on_fail_action=TEST_ON_FAIL, is_enabled=True, - config=None, + config={}, ) - result = flatten_validator(validator) + result = validator_config_crud._flatten(validator) - assert result["id"] == TEST_VALIDATOR_ID assert "severity" not in result -class TestGetValidatorOr404: - """Tests for get_validator_or_404 helper function.""" +class TestGetOr404: + def test_success(self, sample_validator, mock_session): + mock_session.get.return_value = sample_validator - def test_get_validator_success(self, sample_validator, mock_session): - """Test successful validator retrieval.""" - mock_session.query.return_value.filter.return_value.first.return_value = ( - sample_validator - ) - - result = get_validator_or_404( - TEST_VALIDATOR_ID, TEST_ORG_ID, TEST_PROJECT_ID, mock_session + result = validator_config_crud.get_or_404( + mock_session, + TEST_VALIDATOR_ID, + TEST_ORG_ID, + TEST_PROJECT_ID, ) assert result == sample_validator - mock_session.query.assert_called_once_with(ValidatorConfig) - - def test_get_validator_not_found(self, mock_session): - """Test validator not found raises 404.""" - mock_session.query.return_value.filter.return_value.first.return_value = None + mock_session.get.assert_called_once() - with pytest.raises(Exception) as exc_info: - get_validator_or_404(TEST_VALIDATOR_ID, TEST_ORG_ID, TEST_PROJECT_ID, mock_session) + def test_not_found(self, mock_session): + mock_session.get.return_value = None - assert "404" in str(exc_info.value) + with pytest.raises(Exception) as exc: + validator_config_crud.get_or_404( + mock_session, + TEST_VALIDATOR_ID, + TEST_ORG_ID, + TEST_PROJECT_ID, + ) + assert "Validator not found" in str(exc.value) -class TestUpdateValidatorConfig: - """Tests for update_validator_config helper function.""" - def test_update_validator_config_base_fields(self, sample_validator, mock_session): - """Test updating base validator fields.""" +class TestUpdate: + def test_update_base_fields(self, sample_validator, mock_session): update_data = { "type": ValidatorType.PIIRemover, "on_fail_action": GuardrailOnFail.Exception, } - result = update_validator_config(sample_validator, update_data, mock_session) + result = validator_config_crud.update( + mock_session, + sample_validator, + update_data, + ) + + assert result["type"] == ValidatorType.PIIRemover + assert result["on_fail_action"] == GuardrailOnFail.Exception - assert result.type == ValidatorType.PIIRemover - assert result.on_fail_action == GuardrailOnFail.Exception mock_session.commit.assert_called_once() mock_session.refresh.assert_called_once() - def test_update_validator_config_extra_fields(self, sample_validator, mock_session): - """Test updating extra config fields.""" + def test_update_extra_fields(self, sample_validator, mock_session): update_data = {"severity": "high", "new_field": "new_value"} - result = update_validator_config(sample_validator, update_data, mock_session) + result = validator_config_crud.update( + mock_session, + sample_validator, + update_data, + ) - assert result.config["severity"] == "high" - assert result.config["new_field"] == "new_value" - assert result.config["languages"] == ["en", "hi"] # Original values preserved + assert result["severity"] == "high" + assert result["new_field"] == "new_value" + assert result["languages"] == ["en", "hi"] - def test_update_validator_merges_config(self, sample_validator, mock_session): - """Test that updating config merges with existing config.""" + def test_merge_config(self, sample_validator, mock_session): sample_validator.config = {"severity": "all", "languages": ["en"]} - update_data = {"languages": ["en", "hi", "mr"]} - - result = update_validator_config(sample_validator, update_data, mock_session) - assert result.config["languages"] == ["en", "hi", "mr"] - assert result.config["severity"] == "all" + result = validator_config_crud.update( + mock_session, + sample_validator, + {"languages": ["en", "hi"]}, + ) + assert result["languages"] == ["en", "hi"] + assert result["severity"] == "all" From 0d2d609efb2a92a1512bbcfcc88ff065c958fb5e Mon Sep 17 00:00:00 2001 From: rkritika1508 Date: Thu, 5 Feb 2026 17:05:07 +0530 Subject: [PATCH 06/11] removed .env.test --- .env.test | 31 ------------------------------- 1 file changed, 31 deletions(-) delete mode 100644 .env.test diff --git a/.env.test b/.env.test deleted file mode 100644 index 40e9ad8..0000000 --- a/.env.test +++ /dev/null @@ -1,31 +0,0 @@ -DOMAIN=localhost - -ENVIRONMENT=testing - -PROJECT_NAME="Kaapi-Guardrails" -STACK_NAME=Kaapi-Guardrails - -# API Base URL for cron scripts (defaults to http://localhost:8000 if not set) -API_BASE_URL=http://localhost:8000 - -# Postgres -POSTGRES_SERVER=localhost -POSTGRES_PORT=5432 -POSTGRES_DB=kaapi-guardrails -POSTGRES_USER=postgres -POSTGRES_PASSWORD=postgres - -SENTRY_DSN= - -# Configure these with your own Docker registry images - -DOCKER_IMAGE_BACKEND=kaapi-guardrails-backend - -# Callback Timeouts (in seconds) -CALLBACK_CONNECT_TIMEOUT=3 -CALLBACK_READ_TIMEOUT=10 - -# require as a env if you want to use doc transformation -OPENAI_API_KEY="" -GUARDRAILS_HUB_API_KEY="" -AUTH_TOKEN="" \ No newline at end of file From d445ad96d06be21fae3bc30268c97d3c9c8a09e3 Mon Sep 17 00:00:00 2001 From: rkritika1508 Date: Thu, 5 Feb 2026 17:17:32 +0530 Subject: [PATCH 07/11] resolved comments --- .../versions/003_added_validator_config.py | 2 +- backend/app/api/routes/validator_configs.py | 9 ++++---- backend/app/crud/validator_config_crud.py | 21 ++++++++++--------- .../models/config/validator_config_table.py | 7 +++++-- backend/app/tests/conftest.py | 2 +- 5 files changed, 23 insertions(+), 18 deletions(-) diff --git a/backend/app/alembic/versions/003_added_validator_config.py b/backend/app/alembic/versions/003_added_validator_config.py index 2ee47d3..29a5a06 100644 --- a/backend/app/alembic/versions/003_added_validator_config.py +++ b/backend/app/alembic/versions/003_added_validator_config.py @@ -1,7 +1,7 @@ """Added validator_config table Revision ID: 003 -Revises: 001 +Revises: 002 Create Date: 2026-02-05 09:42:54.128852 """ diff --git a/backend/app/api/routes/validator_configs.py b/backend/app/api/routes/validator_configs.py index 2829a01..755cf82 100644 --- a/backend/app/api/routes/validator_configs.py +++ b/backend/app/api/routes/validator_configs.py @@ -1,10 +1,11 @@ -from typing import List, Optional +from typing import Optional from uuid import UUID from fastapi import APIRouter from app.api.deps import AuthDep, SessionDep -from app.schemas.validator_config import * +from app.core.enum import Stage, ValidatorType +from app.schemas.validator_config import ValidatorCreate, ValidatorResponse, ValidatorUpdate from app.crud.validator_config_crud import validator_config_crud @@ -29,7 +30,7 @@ async def create_validator( @router.get( "/", - response_model=List[ValidatorResponse] + response_model=list[ValidatorResponse] ) async def list_validators( org_id: int, @@ -54,7 +55,7 @@ async def get_validator( _: AuthDep, ): obj = validator_config_crud.get_or_404(session, id, org_id, project_id) - return validator_config_crud._flatten(obj) + return validator_config_crud.flatten(obj) @router.patch( diff --git a/backend/app/crud/validator_config_crud.py b/backend/app/crud/validator_config_crud.py index 803dcf3..73afd33 100644 --- a/backend/app/crud/validator_config_crud.py +++ b/backend/app/crud/validator_config_crud.py @@ -1,13 +1,14 @@ -from typing import List, Optional +from typing import Optional from uuid import UUID from fastapi import HTTPException from sqlalchemy.exc import IntegrityError from sqlmodel import Session, select +from app.core.enum import Stage, ValidatorType from app.models.config.validator_config_table import ValidatorConfig -from app.schemas.validator_config import Stage, ValidatorType -from app.utils import split_validator_payload +from app.schemas.validator_config import ValidatorCreate +from app.utils import now, split_validator_payload class ValidatorConfigCrud: @@ -16,7 +17,7 @@ def create( session: Session, org_id: int, project_id: int, - payload + payload: ValidatorCreate ): data = payload.model_dump() base, config = split_validator_payload(data) @@ -40,7 +41,7 @@ def create( ) session.refresh(obj) - return self._flatten(obj) + return self.flatten(obj) def list( self, @@ -49,7 +50,7 @@ def list( project_id: int, stage: Optional[Stage] = None, type: Optional[ValidatorType] = None, - ) -> List[dict]: + ) -> list[dict]: query = select(ValidatorConfig).where( ValidatorConfig.org_id == org_id, ValidatorConfig.project_id == project_id, @@ -62,7 +63,7 @@ def list( query = query.where(ValidatorConfig.type == type) rows = session.exec(query).all() - return [self._flatten(r) for r in rows] + return [self.flatten(r) for r in rows] def get_or_404( self, @@ -91,17 +92,17 @@ def update( if config: obj.config = {**(obj.config or {}), **config} - + obj.updated_at = now() session.commit() session.refresh(obj) - return self._flatten(obj) + return self.flatten(obj) def delete(self, session: Session, obj: ValidatorConfig): session.delete(obj) session.commit() - def _flatten(self, row: ValidatorConfig) -> dict: + def flatten(self, row: ValidatorConfig) -> dict: base = row.model_dump(exclude={"config"}) return {**base, **(row.config or {})} diff --git a/backend/app/models/config/validator_config_table.py b/backend/app/models/config/validator_config_table.py index 68d57cf..0f33590 100644 --- a/backend/app/models/config/validator_config_table.py +++ b/backend/app/models/config/validator_config_table.py @@ -71,7 +71,10 @@ class ValidatorConfig(SQLModel, table=True): updated_at: datetime = Field( default_factory=now, nullable=False, - sa_column_kwargs={"comment": "Timestamp when the validator config was last updated"}, + sa_column_kwargs={ + "comment": "Timestamp when the validator config was last updated", + "onupdate": now, + }, ) __table_args__ = ( @@ -79,4 +82,4 @@ class ValidatorConfig(SQLModel, table=True): "org_id", "project_id", "type", "stage", name="uq_validator_identity" ), - ) \ No newline at end of file + ) diff --git a/backend/app/tests/conftest.py b/backend/app/tests/conftest.py index 2ff34d0..d595b97 100644 --- a/backend/app/tests/conftest.py +++ b/backend/app/tests/conftest.py @@ -30,7 +30,7 @@ def setup_test_db(): def clean_db(): with Session(test_engine) as session: for table in reversed(SQLModel.metadata.sorted_tables): - session.exec(table.delete()) + session.execute(table.delete()) session.commit() @pytest.fixture(scope="function", autouse=True) From 59deffe0dc11cdb46f11a4392035d0feaba84eee Mon Sep 17 00:00:00 2001 From: rkritika1508 Date: Thu, 5 Feb 2026 17:50:10 +0530 Subject: [PATCH 08/11] resolved comments --- backend/app/tests/test_validator_configs.py | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/backend/app/tests/test_validator_configs.py b/backend/app/tests/test_validator_configs.py index 51a0739..e014b5c 100644 --- a/backend/app/tests/test_validator_configs.py +++ b/backend/app/tests/test_validator_configs.py @@ -2,7 +2,7 @@ from unittest.mock import MagicMock import pytest -from sqlmodel import Session, delete +from sqlmodel import Session from app.crud.validator_config_crud import validator_config_crud from app.core.enum import GuardrailOnFail, Stage, ValidatorType @@ -17,19 +17,6 @@ TEST_STAGE = Stage.Input TEST_ON_FAIL = GuardrailOnFail.Fix - -@pytest.fixture -def clear_database(): - """Clear ValidatorConfig table before and after each test.""" - with Session(engine) as session: - session.exec(delete(ValidatorConfig)) - session.commit() - yield - with Session(engine) as session: - session.exec(delete(ValidatorConfig)) - session.commit() - - @pytest.fixture def mock_session(): """Create a mock session for database operations.""" @@ -53,7 +40,7 @@ def sample_validator(): class TestFlatten: def test_flatten_includes_config_fields(self, sample_validator): - result = validator_config_crud._flatten(sample_validator) + result = validator_config_crud.flatten(sample_validator) assert result["severity"] == "all" assert result["languages"] == ["en", "hi"] @@ -71,7 +58,7 @@ def test_flatten_empty_config(self): config={}, ) - result = validator_config_crud._flatten(validator) + result = validator_config_crud.flatten(validator) assert "severity" not in result From 97dae6c09d5ae75823711de641f673270eb38395 Mon Sep 17 00:00:00 2001 From: rkritika1508 Date: Thu, 5 Feb 2026 17:57:11 +0530 Subject: [PATCH 09/11] resolved comments --- .env.test.example | 2 +- backend/app/crud/validator_config_crud.py | 3 ++- backend/app/tests/test_validator_configs.py | 1 - 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.env.test.example b/.env.test.example index 40e9ad8..b7e4927 100644 --- a/.env.test.example +++ b/.env.test.example @@ -11,7 +11,7 @@ API_BASE_URL=http://localhost:8000 # Postgres POSTGRES_SERVER=localhost POSTGRES_PORT=5432 -POSTGRES_DB=kaapi-guardrails +POSTGRES_DB=kaapi_guardrails_testing POSTGRES_USER=postgres POSTGRES_PASSWORD=postgres diff --git a/backend/app/crud/validator_config_crud.py b/backend/app/crud/validator_config_crud.py index 73afd33..8663fdd 100644 --- a/backend/app/crud/validator_config_crud.py +++ b/backend/app/crud/validator_config_crud.py @@ -92,7 +92,8 @@ def update( if config: obj.config = {**(obj.config or {}), **config} - obj.updated_at = now() + + obj.updated_at = now() session.commit() session.refresh(obj) diff --git a/backend/app/tests/test_validator_configs.py b/backend/app/tests/test_validator_configs.py index e014b5c..0d26ec8 100644 --- a/backend/app/tests/test_validator_configs.py +++ b/backend/app/tests/test_validator_configs.py @@ -6,7 +6,6 @@ from app.crud.validator_config_crud import validator_config_crud from app.core.enum import GuardrailOnFail, Stage, ValidatorType -from app.core.db import engine from app.models.config.validator_config_table import ValidatorConfig # Test data constants From a8f4727e34ebb1dcddaa842177b9e7ed3ccc5be3 Mon Sep 17 00:00:00 2001 From: rkritika1508 Date: Thu, 5 Feb 2026 19:56:16 +0530 Subject: [PATCH 10/11] added ban list api changes --- .../versions/004_added_ban_list_config.py | 39 ++++ backend/app/api/main.py | 3 +- backend/app/api/routes/ban_list_configs.py | 118 ++++++++++ backend/app/crud/ban_list_crud.py | 79 +++++++ backend/app/models/config/ban_list_table.py | 75 ++++++ backend/app/schemas/ban_list_config.py | 33 +++ backend/app/tests/test_ban_list_configs.py | 200 ++++++++++++++++ .../app/tests/test_ban_lists_integration.py | 221 ++++++++++++++++++ 8 files changed, 767 insertions(+), 1 deletion(-) create mode 100644 backend/app/alembic/versions/004_added_ban_list_config.py create mode 100644 backend/app/api/routes/ban_list_configs.py create mode 100644 backend/app/crud/ban_list_crud.py create mode 100644 backend/app/models/config/ban_list_table.py create mode 100644 backend/app/schemas/ban_list_config.py create mode 100644 backend/app/tests/test_ban_list_configs.py create mode 100644 backend/app/tests/test_ban_lists_integration.py diff --git a/backend/app/alembic/versions/004_added_ban_list_config.py b/backend/app/alembic/versions/004_added_ban_list_config.py new file mode 100644 index 0000000..c024911 --- /dev/null +++ b/backend/app/alembic/versions/004_added_ban_list_config.py @@ -0,0 +1,39 @@ +"""Added ban_list table + +Revision ID: 004 +Revises: 003 +Create Date: 2026-02-05 09:42:54.128852 + +""" +from typing import Sequence, Union + +from alembic import op +from sqlalchemy.dialects import postgresql +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = '004' +down_revision: Union[str, Sequence[str], None] = "003" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table('ban_list', + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('name', sa.String(), nullable=False), + sa.Column('description', sa.String(), nullable=False), + sa.Column('org_id', sa.Integer(), nullable=False), + sa.Column('project_id', sa.Integer(), nullable=False), + sa.Column('domain', sa.String(), nullable=False), + sa.Column('is_public', sa.Boolean(), nullable=False, server_default=sa.false()), + sa.Column("banned_words", postgresql.ARRAY(sa.String()), nullable=False, server_default="{}"), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + + sa.PrimaryKeyConstraint('id'), + ) + + +def downgrade() -> None: + op.drop_table('validator_config') diff --git a/backend/app/api/main.py b/backend/app/api/main.py index bf78ade..cac0cb1 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -1,11 +1,12 @@ from fastapi import APIRouter -from app.api.routes import utils, guardrails, validator_configs +from app.api.routes import ban_list_configs, guardrails, utils, validator_configs api_router = APIRouter() api_router.include_router(utils.router) api_router.include_router(guardrails.router) api_router.include_router(validator_configs.router) +api_router.include_router(ban_list_configs.router) # if settings.ENVIRONMENT == "local": # api_router.include_router(private.router) diff --git a/backend/app/api/routes/ban_list_configs.py b/backend/app/api/routes/ban_list_configs.py new file mode 100644 index 0000000..9d8b0a0 --- /dev/null +++ b/backend/app/api/routes/ban_list_configs.py @@ -0,0 +1,118 @@ +from typing import List, Optional +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException +from sqlmodel import Session + +from app.api.deps import AuthDep, SessionDep +from app.crud.ban_list_crud import ban_list_crud +from app.schemas.ban_list_config import ( + BanListCreate, + BanListUpdate, + BanListResponse +) + +router = APIRouter( + prefix="/guardrails/ban-lists", + tags=["Ban Lists"] +) + + +def check_owner(obj, org_id, project_id): + if obj.org_id != org_id or obj.project_id != project_id: + raise HTTPException(status_code=403, detail="Not owner") + + +@router.post( + "/", + response_model=BanListResponse + ) +def create_ban_list( + payload: BanListCreate, + session: SessionDep, + org_id: int, + project_id: int, + _: AuthDep, +): + return ban_list_crud.create( + session, + data=payload, + org_id=org_id, + project_id=project_id, + ) + + +@router.get( + "/", + response_model=list[BanListResponse] + ) +def list_ban_lists( + org_id: int, + project_id: int, + session: SessionDep, + _: AuthDep, + domain: Optional[str] = None, +): + return ban_list_crud.list( + session, + org_id=org_id, + project_id=project_id, + domain=domain, + ) + + +@router.get( + "/{id}", + response_model=BanListResponse + ) +def get_ban_list( + id: UUID, + org_id: int, + project_id: int, + session: SessionDep, + _: AuthDep, +): + obj = ban_list_crud.get(session, id) + if not obj: + raise HTTPException(404) + + if not obj.is_public: + check_owner(obj, org_id, project_id) + return obj + + +@router.patch( + "/{id}", + response_model=BanListResponse + ) +def update_ban_list( + id: UUID, + org_id: int, + project_id: int, + payload: BanListUpdate, + session: SessionDep, + _: AuthDep, +): + obj = ban_list_crud.get(session, id) + if not obj: + raise HTTPException(404) + + check_owner(obj, org_id, project_id) + return ban_list_crud.update(session, obj=obj, data=payload) + + +@router.delete("/{id}") +def delete_ban_list( + id: UUID, + org_id: int, + project_id: int, + session: SessionDep, + _: AuthDep, +): + obj = ban_list_crud.get(session, id) + if not obj: + raise HTTPException(404) + + check_owner(obj, org_id, project_id) + ban_list_crud.delete(session, obj) + return {"success": True} diff --git a/backend/app/crud/ban_list_crud.py b/backend/app/crud/ban_list_crud.py new file mode 100644 index 0000000..215012d --- /dev/null +++ b/backend/app/crud/ban_list_crud.py @@ -0,0 +1,79 @@ +from datetime import datetime +from typing import List, Optional +from uuid import UUID + +from sqlmodel import Session, select + +from app.models.config.ban_list_table import BanList +from app.schemas.ban_list_config import BanListCreate, BanListUpdate +from app.utils import now + +class BanListCrud: + + def create( + self, + db: Session, + *, + data: BanListCreate, + org_id: int, + project_id: int, + ) -> BanList: + obj = BanList( + **data.model_dump(), + org_id=org_id, + project_id=project_id, + ) + db.add(obj) + db.commit() + db.refresh(obj) + return obj + + def get(self, db: Session, id: UUID) -> Optional[BanList]: + return db.get(BanList, id) + + def list( + self, + db: Session, + *, + org_id: int, + project_id: int, + domain: Optional[str] = None, + ) -> List[BanList]: + stmt = select(BanList).where( + ( + (BanList.org_id == org_id) & + (BanList.project_id == project_id) + ) | + (BanList.is_public == True) + ) + + if domain: + stmt = stmt.where(BanList.domain == domain) + + return list(db.exec(stmt)) + + def update( + self, + db: Session, + *, + obj: BanList, + data: BanListUpdate, + ) -> BanList: + update_data = data.model_dump(exclude_unset=True) + + for k, v in update_data.items(): + setattr(obj, k, v) + + obj.updated_at = now() + + db.add(obj) + db.commit() + db.refresh(obj) + return obj + + def delete(self, db: Session, obj: BanList): + db.delete(obj) + db.commit() + + +ban_list_crud = BanListCrud() diff --git a/backend/app/models/config/ban_list_table.py b/backend/app/models/config/ban_list_table.py new file mode 100644 index 0000000..b34b351 --- /dev/null +++ b/backend/app/models/config/ban_list_table.py @@ -0,0 +1,75 @@ +from datetime import datetime +from typing import List, Optional +from uuid import UUID, uuid4 + +from sqlalchemy import Column, String +from sqlalchemy.dialects.postgresql import ARRAY +from sqlmodel import Field, SQLModel + +from app.utils import now + +class BanList(SQLModel, table=True): + __tablename__ = "ban_list" + + id: UUID = Field( + default_factory=uuid4, + primary_key=True, + index=True, + sa_column_kwargs={"comment": "Unique identifier for the ban list entry"} + ) + + name: str = Field( + nullable=False, + sa_column_kwargs={"comment": "Name of the ban list entry"} + ) + + description: Optional[str] = Field( + nullable=False, + sa_column_kwargs={"comment": "Description of the ban list entry"} + ) + + banned_words: list[str] = Field( + default_factory=list, + sa_column=Column( + ARRAY(String), + nullable=False, + comment="List of banned words", + ), + description=("List of banned words") + ) + + org_id: int = Field( + index=True, + nullable=False, + sa_column_kwargs={"comment": "Identifier for the organization"}, + ) + + project_id: int = Field( + index=True, + nullable=False, + sa_column_kwargs={"comment": "Identifier for the project"}, + ) + + domain: str = Field( + default=None, + index=False, + nullable=False, + sa_column_kwargs={"comment": "Domain or context for the ban list entry"} + ) + + is_public: bool = Field( + default=False, + sa_column_kwargs={"comment": "Whether the ban list entry is public or private"} + ) + + created_at: datetime = Field( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the ban list entry was created"} + ) + + updated_at: datetime = Field( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the ban list entry was last updated"} + ) diff --git a/backend/app/schemas/ban_list_config.py b/backend/app/schemas/ban_list_config.py new file mode 100644 index 0000000..8b39d44 --- /dev/null +++ b/backend/app/schemas/ban_list_config.py @@ -0,0 +1,33 @@ +from uuid import UUID +from datetime import datetime +from typing import List, Optional + +from pydantic import BaseModel, Field +from sqlmodel import SQLModel + +class BanListBase(SQLModel): + name: str + description: str + banned_words: list[str] + domain: str + is_public: bool = False + + +class BanListCreate(BanListBase): + pass + + +class BanListUpdate(SQLModel): + name: Optional[str] = None + description: Optional[str] = None + banned_words: Optional[list[str]] = None + domain: Optional[str] = None + is_public: Optional[bool] = None + + +class BanListResponse(BanListBase): + id: UUID + org_id: int + project_id: int + created_at: datetime + updated_at: datetime diff --git a/backend/app/tests/test_ban_list_configs.py b/backend/app/tests/test_ban_list_configs.py new file mode 100644 index 0000000..da4b5ef --- /dev/null +++ b/backend/app/tests/test_ban_list_configs.py @@ -0,0 +1,200 @@ +import uuid +from unittest.mock import MagicMock, patch + +import pytest +from fastapi import HTTPException +from sqlmodel import Session + +from app.api.routes.ban_list_configs import ( + create_ban_list, + list_ban_lists, + get_ban_list, + update_ban_list, + delete_ban_list, +) +from app.schemas.ban_list_config import ( + BanListCreate, + BanListUpdate, +) + +TEST_ID = uuid.uuid4() +TEST_ORG_ID = 1 +TEST_PROJECT_ID = 10 + +@pytest.fixture +def mock_session(): + return MagicMock(spec=Session) + + +@pytest.fixture +def sample_ban_list(): + obj = MagicMock() + obj.id = TEST_ID + obj.name = "test" + obj.description = "desc" + obj.banned_words = ["bad"] + obj.org_id = TEST_ORG_ID + obj.project_id = TEST_PROJECT_ID + obj.domain = "health" + obj.is_public = False + return obj + + +@pytest.fixture +def create_payload(): + return BanListCreate( + name="test", + description="desc", + banned_words=["bad"], + domain="health", + is_public=False, + ) + + +def test_create_calls_crud(mock_session, create_payload, sample_ban_list): + with patch("app.api.routes.ban_list_configs.ban_list_crud") as crud: + crud.create.return_value = sample_ban_list + + result = create_ban_list( + payload=create_payload, + session=mock_session, + org_id=TEST_ORG_ID, + project_id=TEST_PROJECT_ID, + _=None, + ) + + crud.create.assert_called_once() + assert result == sample_ban_list + + +def test_list_returns_data(mock_session, sample_ban_list): + with patch("app.api.routes.ban_list_configs.ban_list_crud") as crud: + crud.list.return_value = [sample_ban_list] + + result = list_ban_lists( + org_id=TEST_ORG_ID, + project_id=TEST_PROJECT_ID, + session=mock_session, + _=None, + ) + + assert len(result) == 1 + crud.list.assert_called_once() + + +def test_get_success(mock_session, sample_ban_list): + with patch("app.api.routes.ban_list_configs.ban_list_crud") as crud: + crud.get.return_value = sample_ban_list + + result = get_ban_list( + id=TEST_ID, + org_id=TEST_ORG_ID, + project_id=TEST_PROJECT_ID, + session=mock_session, + _=None, + ) + + assert result == sample_ban_list + + +def test_get_not_found(mock_session): + with patch("app.api.routes.ban_list_configs.ban_list_crud") as crud: + crud.get.return_value = None + + with pytest.raises(HTTPException) as exc: + get_ban_list( + id=TEST_ID, + org_id=TEST_ORG_ID, + project_id=TEST_PROJECT_ID, + session=mock_session, + _=None, + ) + + assert exc.value.status_code == 404 + + +def test_get_forbidden(mock_session, sample_ban_list): + sample_ban_list.org_id = 999 # different owner + + with patch("app.api.routes.ban_list_configs.ban_list_crud") as crud: + crud.get.return_value = sample_ban_list + + with pytest.raises(HTTPException) as exc: + get_ban_list( + id=TEST_ID, + org_id=TEST_ORG_ID, + project_id=TEST_PROJECT_ID, + session=mock_session, + _=None, + ) + + assert exc.value.status_code == 403 + + +def test_update_success(mock_session, sample_ban_list): + with patch("app.api.routes.ban_list_configs.ban_list_crud") as crud: + crud.get.return_value = sample_ban_list + crud.update.return_value = sample_ban_list + + result = update_ban_list( + id=TEST_ID, + org_id=TEST_ORG_ID, + project_id=TEST_PROJECT_ID, + payload=BanListUpdate(name="new"), + session=mock_session, + _=None, + ) + + crud.update.assert_called_once() + assert result == sample_ban_list + + +def test_update_forbidden(mock_session, sample_ban_list): + sample_ban_list.org_id = 999 + + with patch("app.api.routes.ban_list_configs.ban_list_crud") as crud: + crud.get.return_value = sample_ban_list + + with pytest.raises(HTTPException) as exc: + update_ban_list( + id=TEST_ID, + org_id=TEST_ORG_ID, + project_id=TEST_PROJECT_ID, + payload=BanListUpdate(name="new"), + session=mock_session, + _=None, + ) + + assert exc.value.status_code == 403 + + +def test_delete_success(mock_session, sample_ban_list): + with patch("app.api.routes.ban_list_configs.ban_list_crud") as crud: + crud.get.return_value = sample_ban_list + + result = delete_ban_list( + id=TEST_ID, + org_id=TEST_ORG_ID, + project_id=TEST_PROJECT_ID, + session=mock_session, + _=None, + ) + + crud.delete.assert_called_once() + assert result["success"] is True + + +def test_delete_forbidden(mock_session, sample_ban_list): + sample_ban_list.org_id = 999 + + with patch("app.api.routes.ban_list_configs.ban_list_crud") as crud: + crud.get.return_value = sample_ban_list + + with pytest.raises(HTTPException): + delete_ban_list( + id=TEST_ID, + org_id=TEST_ORG_ID, + project_id=TEST_PROJECT_ID, + session=mock_session, + _=None, + ) diff --git a/backend/app/tests/test_ban_lists_integration.py b/backend/app/tests/test_ban_lists_integration.py new file mode 100644 index 0000000..956e932 --- /dev/null +++ b/backend/app/tests/test_ban_lists_integration.py @@ -0,0 +1,221 @@ +import uuid +import pytest +from sqlmodel import Session, delete + +from app.core.db import engine +from app.models.config.ban_list_table import BanList + +pytestmark = pytest.mark.integration + + +# Test data constants +TEST_ORG_ID = 1 +TEST_PROJECT_ID = 1 +BASE_URL = "/api/v1/guardrails/ban-lists/" +DEFAULT_QUERY = f"?org_id={TEST_ORG_ID}&project_id={TEST_PROJECT_ID}" + + +BAN_LIST_PAYLOADS = { + "minimal": { + "name": "default", + "description": "basic list", + "banned_words": ["bad"], + "domain": "general", + }, + "health": { + "name": "health-list", + "description": "healthcare words", + "banned_words": ["gender detection", "sonography"], + "domain": "health", + }, + "edu": { + "name": "edu-list", + "description": "education words", + "banned_words": ["cheating"], + "domain": "edu", + }, + "public": { + "name": "public-list", + "description": "shared", + "banned_words": ["shared"], + "is_public": True, + "domain": "general", + }, +} + + +@pytest.fixture +def clear_database(): + with Session(engine) as session: + session.exec(delete(BanList)) + session.commit() + + yield + + with Session(engine) as session: + session.exec(delete(BanList)) + session.commit() + + +class BaseBanListTest: + + def create(self, client, payload_key="minimal", **kwargs): + payload = {**BAN_LIST_PAYLOADS[payload_key], **kwargs} + return client.post(f"{BASE_URL}{DEFAULT_QUERY}", json=payload) + + def list(self, client, **filters): + params = DEFAULT_QUERY + if filters: + params += "&" + "&".join(f"{k}={v}" for k, v in filters.items()) + return client.get(f"{BASE_URL}{params}") + + def get(self, client, id, org=TEST_ORG_ID, project=TEST_PROJECT_ID): + return client.get(f"{BASE_URL}{id}/?org_id={org}&project_id={project}") + + def update(self, client, id, payload): + return client.patch(f"{BASE_URL}{id}/{DEFAULT_QUERY}", json=payload) + + def delete(self, client, id): + return client.delete(f"{BASE_URL}{id}/{DEFAULT_QUERY}") + + +class TestCreateBanList(BaseBanListTest): + + def test_create_success(self, integration_client, clear_database): + response = self.create(integration_client, "minimal") + + assert response.status_code == 200 + data = response.json() + + assert data["name"] == "default" + assert data["banned_words"] == ["bad"] + assert "id" in data + + def test_create_validation_error(self, integration_client, clear_database): + response = integration_client.post( + f"{BASE_URL}{DEFAULT_QUERY}", + json={"name": "missing words"}, + ) + + assert response.status_code == 422 + + +class TestListBanLists(BaseBanListTest): + + def test_list_success(self, integration_client, clear_database): + self.create(integration_client, "minimal") + self.create(integration_client, "health") + + response = self.list(integration_client) + + assert response.status_code == 200 + assert len(response.json()) == 2 + + def test_filter_by_domain(self, integration_client, clear_database): + self.create(integration_client, "health") + self.create(integration_client, "edu") + + response = self.list(integration_client, domain="health") + + data = response.json() + assert len(data) == 1 + assert data[0]["domain"] == "health" + + def test_list_empty(self, integration_client, clear_database): + response = self.list(integration_client) + assert response.json() == [] + + +class TestPublicAccess(BaseBanListTest): + + def test_public_visible_to_other_org(self, integration_client, clear_database): + create_resp = self.create(integration_client, "public") + ban_id = create_resp.json()["id"] + + response = self.get(integration_client, ban_id, org=999, project=999) + + # public lists should still be readable + assert response.status_code == 200 + + +class TestGetBanList(BaseBanListTest): + + def test_get_success(self, integration_client, clear_database): + create_resp = self.create(integration_client, "minimal") + ban_id = create_resp.json()["id"] + + response = self.get(integration_client, ban_id) + + assert response.status_code == 200 + + def test_get_not_found(self, integration_client, clear_database): + fake = uuid.uuid4() + response = self.get(integration_client, fake) + + assert response.status_code == 404 + + def test_get_wrong_owner_private(self, integration_client, clear_database): + create_resp = self.create(integration_client, "minimal") + ban_id = create_resp.json()["id"] + + response = self.get(integration_client, ban_id, org=2, project=2) + + assert response.status_code in (403, 404) + + +class TestUpdateBanList(BaseBanListTest): + + def test_update_success(self, integration_client, clear_database): + create_resp = self.create(integration_client, "minimal") + ban_id = create_resp.json()["id"] + + response = self.update( + integration_client, + ban_id, + {"banned_words": ["bad", "worse"]}, + ) + + assert response.status_code == 200 + assert response.json()["banned_words"] == ["bad", "worse"] + + def test_partial_update(self, integration_client, clear_database): + create_resp = self.create(integration_client, "minimal") + ban_id = create_resp.json()["id"] + + response = self.update(integration_client, ban_id, {"name": "updated"}) + + assert response.json()["name"] == "updated" + + def test_update_not_found(self, integration_client, clear_database): + fake = uuid.uuid4() + + response = self.update(integration_client, fake, {"name": "x"}) + assert response.status_code == 404 + + +class TestDeleteBanList(BaseBanListTest): + + def test_delete_success(self, integration_client, clear_database): + create_resp = self.create(integration_client, "minimal") + ban_id = create_resp.json()["id"] + + response = self.delete(integration_client, ban_id) + + assert response.status_code == 200 + assert response.json()["success"] is True + + def test_delete_not_found(self, integration_client, clear_database): + fake = uuid.uuid4() + + response = self.delete(integration_client, fake) + assert response.status_code == 404 + + def test_delete_wrong_owner(self, integration_client, clear_database): + create_resp = self.create(integration_client, "minimal") + ban_id = create_resp.json()["id"] + + response = integration_client.delete( + f"{BASE_URL}{ban_id}/?org_id=999&project_id=999" + ) + + assert response.status_code in (403, 404) From 3b5528d44eff638726f43d2e675f8d3b1a471e81 Mon Sep 17 00:00:00 2001 From: Kritika Rupauliha Date: Mon, 9 Feb 2026 21:28:33 +0530 Subject: [PATCH 11/11] Guardrails: Config Management (#30) --- .env.test.example | 2 +- .../alembic/versions/001_added_request_log.py | 6 +- .../versions/002_added_validator_log.py | 6 +- .../versions/003_added_validator_config.py | 12 +- backend/app/api/routes/guardrails.py | 8 +- backend/app/api/routes/validator_configs.py | 59 +++++---- backend/app/core/constants.py | 2 +- backend/app/crud/__init__.py | 2 +- backend/app/crud/request_log_repo.py | 2 +- backend/app/crud/validator_config.py | 125 ++++++++++++++++++ backend/app/crud/validator_log_repo.py | 2 +- backend/app/models/__init__.py | 4 +- backend/app/models/config/validator_config.py | 88 ++++++++++++ backend/app/models/logging/request_log.py | 68 ++++++++++ backend/app/models/logging/validator_log.py | 63 +++++++++ backend/app/schemas/validator_config.py | 11 +- backend/app/tests/test_validator_configs.py | 18 +-- .../test_validator_configs_integration.py | 57 ++++---- backend/app/utils.py | 17 ++- 19 files changed, 450 insertions(+), 102 deletions(-) create mode 100644 backend/app/crud/validator_config.py create mode 100644 backend/app/models/config/validator_config.py create mode 100644 backend/app/models/logging/request_log.py create mode 100644 backend/app/models/logging/validator_log.py diff --git a/.env.test.example b/.env.test.example index b7e4927..368b7d7 100644 --- a/.env.test.example +++ b/.env.test.example @@ -10,8 +10,8 @@ API_BASE_URL=http://localhost:8000 # Postgres POSTGRES_SERVER=localhost -POSTGRES_PORT=5432 POSTGRES_DB=kaapi_guardrails_testing +POSTGRES_PORT=5432 POSTGRES_USER=postgres POSTGRES_PASSWORD=postgres diff --git a/backend/app/alembic/versions/001_added_request_log.py b/backend/app/alembic/versions/001_added_request_log.py index 706d504..d83577a 100644 --- a/backend/app/alembic/versions/001_added_request_log.py +++ b/backend/app/alembic/versions/001_added_request_log.py @@ -14,9 +14,9 @@ # revision identifiers, used by Alembic. revision: str = '001' -down_revision: Union[str, Sequence[str], None] = None -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None +down_revision: str | None = None +branch_labels = None +depends_on = None def upgrade() -> None: diff --git a/backend/app/alembic/versions/002_added_validator_log.py b/backend/app/alembic/versions/002_added_validator_log.py index d46513f..e0b8115 100644 --- a/backend/app/alembic/versions/002_added_validator_log.py +++ b/backend/app/alembic/versions/002_added_validator_log.py @@ -14,9 +14,9 @@ # revision identifiers, used by Alembic. revision: str = '002' -down_revision: Union[str, Sequence[str], None] = '001' -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None +down_revision: str = '001' +branch_labels = None +depends_on = None def upgrade() -> None: diff --git a/backend/app/alembic/versions/003_added_validator_config.py b/backend/app/alembic/versions/003_added_validator_config.py index 29a5a06..b50af01 100644 --- a/backend/app/alembic/versions/003_added_validator_config.py +++ b/backend/app/alembic/versions/003_added_validator_config.py @@ -13,15 +13,15 @@ # revision identifiers, used by Alembic. revision: str = '003' -down_revision: Union[str, Sequence[str], None] = "002" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None +down_revision: str = '002' +branch_labels = None +depends_on = None def upgrade() -> None: op.create_table('validator_config', sa.Column('id', sa.Uuid(), nullable=False), - sa.Column('org_id', sa.Integer(), nullable=False), + sa.Column('organization_id', sa.Integer(), nullable=False), sa.Column('project_id', sa.Integer(), nullable=False), sa.Column('type', sa.String(), nullable=False), sa.Column('stage', sa.String(), nullable=False), @@ -37,10 +37,10 @@ def upgrade() -> None: sa.Column('updated_at', sa.DateTime(), nullable=False), sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('org_id', 'project_id', 'type', 'stage', name='uq_validator_identity') + sa.UniqueConstraint('organization_id', 'project_id', 'type', 'stage', name='uq_validator_identity') ) - op.create_index("idx_validator_org", "validator_config", ["org_id"]) + op.create_index("idx_validator_organization", "validator_config", ["organization_id"]) op.create_index("idx_validator_project", "validator_config", ["project_id"]) op.create_index("idx_validator_type", "validator_config", ["type"]) op.create_index("idx_validator_stage", "validator_config", ["stage"]) diff --git a/backend/app/api/routes/guardrails.py b/backend/app/api/routes/guardrails.py index 90f7b3c..4f422e5 100644 --- a/backend/app/api/routes/guardrails.py +++ b/backend/app/api/routes/guardrails.py @@ -8,10 +8,10 @@ from app.api.deps import AuthDep, SessionDep from app.core.constants import REPHRASE_ON_FAIL_PREFIX from app.core.guardrail_controller import build_guard, get_validator_config_models -from app.crud.request_log_repo import RequestLogCrud -from app.crud.validator_log_repo import ValidatorLogCrud -from app.models.logging.request_log_table import RequestLogUpdate, RequestStatus -from app.models.logging.validator_log_table import ValidatorLog, ValidatorOutcome +from app.crud.request_log import RequestLogCrud +from app.crud.validator_log import ValidatorLogCrud +from app.models.logging.request_log import RequestLogUpdate, RequestStatus +from app.models.logging.validator_log import ValidatorLog, ValidatorOutcome from app.schemas.guardrail_config import GuardrailRequest, GuardrailResponse from app.utils import APIResponse diff --git a/backend/app/api/routes/validator_configs.py b/backend/app/api/routes/validator_configs.py index 755cf82..701e6a8 100644 --- a/backend/app/api/routes/validator_configs.py +++ b/backend/app/api/routes/validator_configs.py @@ -6,7 +6,8 @@ from app.api.deps import AuthDep, SessionDep from app.core.enum import Stage, ValidatorType from app.schemas.validator_config import ValidatorCreate, ValidatorResponse, ValidatorUpdate -from app.crud.validator_config_crud import validator_config_crud +from app.crud.validator_config import validator_config_crud +from app.utils import APIResponse router = APIRouter( @@ -17,75 +18,77 @@ @router.post( "/", - response_model=ValidatorResponse + response_model=APIResponse[ValidatorResponse] ) -async def create_validator( +def create_validator( payload: ValidatorCreate, session: SessionDep, - org_id: int, + organization_id: int, project_id: int, _: AuthDep, ): - return validator_config_crud.create(session, org_id, project_id, payload) + response_model = validator_config_crud.create(session, organization_id, project_id, payload) + return APIResponse.success_response(data=response_model) @router.get( "/", - response_model=list[ValidatorResponse] + response_model=APIResponse[list[ValidatorResponse]] ) -async def list_validators( - org_id: int, +def list_validators( + organization_id: int, project_id: int, session: SessionDep, _: AuthDep, stage: Optional[Stage] = None, type: Optional[ValidatorType] = None, ): - return validator_config_crud.list(session, org_id, project_id, stage, type) + response_model = validator_config_crud.list(session, organization_id, project_id, stage, type) + return APIResponse.success_response(data=response_model) @router.get( "/{id}", - response_model=ValidatorResponse + response_model=APIResponse[ValidatorResponse] ) -async def get_validator( +def get_validator( id: UUID, - org_id: int, + organization_id: int, project_id: int, session: SessionDep, _: AuthDep, ): - obj = validator_config_crud.get_or_404(session, id, org_id, project_id) - return validator_config_crud.flatten(obj) + obj = validator_config_crud.get(session, id, organization_id, project_id) + return APIResponse.success_response(data=validator_config_crud.flatten(obj)) @router.patch( "/{id}", - response_model=ValidatorResponse + response_model=APIResponse[ValidatorResponse] ) -async def update_validator( +def update_validator( id: UUID, - org_id: int, + organization_id: int, project_id: int, payload: ValidatorUpdate, session: SessionDep, _: AuthDep, ): - obj = validator_config_crud.get_or_404(session, id, org_id, project_id) - return validator_config_crud.update( - session, - obj, - payload.model_dump(exclude_unset=True), - ) + obj = validator_config_crud.get(session, id, organization_id, project_id) + response_model = validator_config_crud.update(session, obj, payload.model_dump(exclude_unset=True)) + return APIResponse.success_response(data=response_model) -@router.delete("/{id}") -async def delete_validator( +@router.delete( + "/{id}", + response_model=APIResponse[dict] + ) +def delete_validator( id: UUID, - org_id: int, + organization_id: int, project_id: int, session: SessionDep, _: AuthDep, ): - obj = validator_config_crud.get_or_404(session, id, org_id, project_id) + obj = validator_config_crud.get(session, id, organization_id, project_id) validator_config_crud.delete(session, obj) - return {"success": True} + return APIResponse.success_response(data={"message": "Validator deleted successfully"}) diff --git a/backend/app/core/constants.py b/backend/app/core/constants.py index 115ad21..6c3825d 100644 --- a/backend/app/core/constants.py +++ b/backend/app/core/constants.py @@ -8,7 +8,7 @@ REPHRASE_ON_FAIL_PREFIX = "Please rephrase the query without unsafe content." VALIDATOR_CONFIG_SYSTEM_FIELDS = { - "org_id", + "organization_id", "project_id", "type", "stage", diff --git a/backend/app/crud/__init__.py b/backend/app/crud/__init__.py index 37bc215..e58e7bd 100644 --- a/backend/app/crud/__init__.py +++ b/backend/app/crud/__init__.py @@ -1 +1 @@ -from app.crud.request_log_repo import RequestLogCrud \ No newline at end of file +from app.crud.request_log import RequestLogCrud diff --git a/backend/app/crud/request_log_repo.py b/backend/app/crud/request_log_repo.py index 9d3b1e5..1d75d87 100644 --- a/backend/app/crud/request_log_repo.py +++ b/backend/app/crud/request_log_repo.py @@ -2,7 +2,7 @@ from sqlmodel import Session -from app.models.logging.request_log_table import RequestLog, RequestLogUpdate, RequestStatus +from app.models.logging.request_log import RequestLog, RequestLogUpdate, RequestStatus from app.utils import now class RequestLogCrud: diff --git a/backend/app/crud/validator_config.py b/backend/app/crud/validator_config.py new file mode 100644 index 0000000..26dca48 --- /dev/null +++ b/backend/app/crud/validator_config.py @@ -0,0 +1,125 @@ +from typing import Optional +from uuid import UUID + +from fastapi import HTTPException +from sqlalchemy.exc import IntegrityError +from sqlmodel import Session, select + +from app.core.enum import Stage, ValidatorType +from app.models.config.validator_config import ValidatorConfig +from app.schemas.validator_config import ValidatorCreate +from app.utils import now, split_validator_payload + + +class ValidatorConfigCrud: + def create( + self, + session: Session, + organization_id: int, + project_id: int, + payload: ValidatorCreate + ) -> dict: + data = payload.model_dump() + model_fields, config_fields = split_validator_payload(data) + + obj = ValidatorConfig( + organization_id=organization_id, + project_id=project_id, + config=config_fields, + **model_fields, + ) + + session.add(obj) + + try: + session.commit() + except IntegrityError: + session.rollback() + raise HTTPException( + 400, + "Validator already exists for this type and stage", + ) + + session.refresh(obj) + return self.flatten(obj) + + def list( + self, + session: Session, + organization_id: int, + project_id: int, + stage: Optional[Stage] = None, + type: Optional[ValidatorType] = None, + ) -> list[dict]: + query = select(ValidatorConfig).where( + ValidatorConfig.organization_id == organization_id, + ValidatorConfig.project_id == project_id, + ) + + if stage: + query = query.where(ValidatorConfig.stage == stage) + + if type: + query = query.where(ValidatorConfig.type == type) + + rows = session.exec(query).all() + return [self.flatten(r) for r in rows] + + def get( + self, + session: Session, + id: UUID, + organization_id: int, + project_id: int, + ) -> ValidatorConfig: + obj = session.get(ValidatorConfig, id) + + if not obj or obj.organization_id != organization_id or obj.project_id != project_id: + raise HTTPException(404, "Validator not found") + + return obj + + def update( + self, + session: Session, + obj: ValidatorConfig, + update_data: dict + ) -> dict: + model_fields, config_fields = split_validator_payload(update_data) + + for k, v in model_fields.items(): + setattr(obj, k, v) + + if config_fields: + obj.config = {**(obj.config or {}), **config_fields} + + obj.updated_at = now() + try: + session.commit() + except IntegrityError: + session.rollback() + raise HTTPException( + 400, + "Validator already exists for this type and stage", + ) + except Exception: + session.rollback() + raise + + session.refresh(obj) + return self.flatten(obj) + + def delete(self, session: Session, obj: ValidatorConfig): + session.delete(obj) + try: + session.commit() + except Exception: + session.rollback() + raise + + def flatten(self, row: ValidatorConfig) -> dict: + base = row.model_dump(exclude={"config"}) + return {**base, **(row.config or {})} + + +validator_config_crud = ValidatorConfigCrud() diff --git a/backend/app/crud/validator_log_repo.py b/backend/app/crud/validator_log_repo.py index 649d6aa..3903129 100644 --- a/backend/app/crud/validator_log_repo.py +++ b/backend/app/crud/validator_log_repo.py @@ -2,7 +2,7 @@ from sqlmodel import Session -from app.models.logging.validator_log_table import ValidatorLog +from app.models.logging.validator_log import ValidatorLog from app.utils import now class ValidatorLogCrud: diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 2ba735f..d116b81 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -1,4 +1,4 @@ from sqlmodel import SQLModel -from app.models.logging.request_log_table import RequestLog -from app.models.logging.validator_log_table import ValidatorLog +from app.models.logging.request_log import RequestLog +from app.models.logging.validator_log import ValidatorLog diff --git a/backend/app/models/config/validator_config.py b/backend/app/models/config/validator_config.py new file mode 100644 index 0000000..021a41f --- /dev/null +++ b/backend/app/models/config/validator_config.py @@ -0,0 +1,88 @@ +from datetime import datetime +from typing import Any +from uuid import UUID, uuid4 + +from sqlalchemy import Column, UniqueConstraint +from sqlalchemy.dialects.postgresql import JSONB +from sqlmodel import Field as SQLField +from sqlmodel import SQLModel, Field +import sqlalchemy as sa + +from app.core.enum import GuardrailOnFail, Stage, ValidatorType +from app.utils import now + +class ValidatorConfig(SQLModel, table=True): + __tablename__ = "validator_config" + + id: UUID = Field( + default_factory=uuid4, + primary_key=True, + sa_column_kwargs={"comment": "Unique identifier for the validator configuration"}, + ) + + organization_id: int = Field( + index=True, + sa_column_kwargs={"comment": "Identifier for the organization"}, + ) + + project_id: int = Field( + nullable=False, + sa_column_kwargs={"comment": "Identifier for the project"}, + ) + + type: ValidatorType = Field( + nullable=False, + sa_column_kwargs={"comment": "Type of the validator"}, + ) + + stage: Stage = Field( + nullable=False, + sa_column_kwargs={"comment": "Stage at which the validator is applied"}, + ) + + on_fail_action: GuardrailOnFail = Field( + default=GuardrailOnFail.Fix, + nullable=False, + sa_column_kwargs={"comment": "Action to take when the validator fails"}, + ) + + + config: dict[str, Any] = SQLField( + default_factory=dict, + sa_column=Column( + JSONB, + nullable=False, + server_default=sa.text("'{}'::jsonb"), + comment="Configuration for the validator", + ), + description=( + "Configuration for the validator" + ), + ) + + is_enabled: bool = Field( + default=True, + sa_column_kwargs={"comment": "Indicates if the validator is enabled"}, + ) + + created_at: datetime = Field( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the validator config was inserted"}, + ) + + updated_at: datetime = Field( + default_factory=now, + nullable=False, + sa_column_kwargs={ + "comment": "Timestamp when the validator config was last updated", + "onupdate": now, + }, + ) + + __table_args__ = ( + UniqueConstraint( + "organization_id", "project_id", "type", "stage", + name="uq_validator_identity" + ), + ) diff --git a/backend/app/models/logging/request_log.py b/backend/app/models/logging/request_log.py new file mode 100644 index 0000000..f04d35b --- /dev/null +++ b/backend/app/models/logging/request_log.py @@ -0,0 +1,68 @@ +from datetime import datetime +from enum import Enum +from typing import Optional +from uuid import UUID, uuid4 + +from sqlmodel import SQLModel, Field + +from app.utils import now + +class RequestStatus(str, Enum): + PROCESSING = "processing" + SUCCESS = "success" + ERROR = "error" + WARNING = "warning" + + +class RequestLog(SQLModel, table=True): + __tablename__ = "request_log" + + id: UUID = Field( + default_factory=uuid4, + primary_key=True, + sa_column_kwargs={"comment": "Unique identifier for the request log entry"}, + ) + + request_id: UUID = Field( + nullable=False, + sa_column_kwargs={"comment": "Identifier for the request"}, + ) + + response_id: Optional[UUID] = Field( + default=None, + nullable=True, + sa_column_kwargs={"comment": "Identifier for the response"}, + ) + + status: RequestStatus = Field( + default=RequestStatus.PROCESSING, + sa_column_kwargs={"comment": "Status of the request processing"}, + ) + + request_text: str = Field( + nullable=False, + sa_column_kwargs={"comment": "Text of the request made"}, + ) + + response_text: Optional[str] = Field( + default=None, + nullable=True, + sa_column_kwargs={"comment": "Text of the response received"}, + ) + + inserted_at: datetime = Field( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the entry was created"}, + ) + + updated_at: datetime = Field( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the entry was last updated"}, + ) + + +class RequestLogUpdate(SQLModel): + response_text: str + response_id: UUID diff --git a/backend/app/models/logging/validator_log.py b/backend/app/models/logging/validator_log.py new file mode 100644 index 0000000..ca2166b --- /dev/null +++ b/backend/app/models/logging/validator_log.py @@ -0,0 +1,63 @@ +from datetime import datetime +from enum import Enum +from uuid import UUID, uuid4 + +from sqlmodel import SQLModel, Field + +from app.utils import now + +class ValidatorOutcome(str, Enum): + PASS = "PASS" + FAIL = "FAIL" + +class ValidatorLog(SQLModel, table=True): + __tablename__ = "validator_log" + + id: UUID = Field( + default_factory=uuid4, + primary_key=True, + sa_column_kwargs={"comment": "Unique identifier for the validator log entry"}, + ) + + request_id: UUID = Field( + foreign_key="request_log.id", + nullable=False, + sa_column_kwargs={"comment": "Foreign key to the associated request log entry"}, + ) + + name: str = Field( + nullable=False, + sa_column_kwargs={"comment": "Name of the validator used"}, + ) + + input: str = Field( + nullable=False, + sa_column_kwargs={"comment": "Input message for the validator to check"}, + ) + + output: str | None = Field( + nullable=True, + sa_column_kwargs={"comment": "Output message post validation"}, + ) + + error: str | None = Field( + nullable=True, + sa_column_kwargs={"comment": "Error message if the validator throws an exception"}, + ) + + outcome: ValidatorOutcome = Field( + nullable=False, + sa_column_kwargs={"comment": "Validator outcome (whether the validation failed or passed)"}, + ) + + inserted_at: datetime = Field( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the entry was created"}, + ) + + updated_at: datetime = Field( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the entry was last updated"}, + ) \ No newline at end of file diff --git a/backend/app/schemas/validator_config.py b/backend/app/schemas/validator_config.py index c8a2ae1..57c8852 100644 --- a/backend/app/schemas/validator_config.py +++ b/backend/app/schemas/validator_config.py @@ -1,13 +1,15 @@ +from datetime import datetime from typing import Optional from uuid import UUID +from pydantic import ConfigDict from sqlmodel import SQLModel from app.core.enum import GuardrailOnFail, Stage, ValidatorType class ValidatorBase(SQLModel): - model_config = {"extra": "allow"} + model_config = ConfigDict(extra="allow") type: ValidatorType stage: Stage @@ -20,8 +22,7 @@ class ValidatorCreate(ValidatorBase): class ValidatorUpdate(SQLModel): - # also allow extras for partial updates - model_config = {"extra": "allow"} + model_config = ConfigDict(extra="forbid") type: Optional[ValidatorType] = None stage: Optional[Stage] = None @@ -30,6 +31,4 @@ class ValidatorUpdate(SQLModel): class ValidatorResponse(ValidatorBase): - id: UUID - org_id: int - project_id: int + pass diff --git a/backend/app/tests/test_validator_configs.py b/backend/app/tests/test_validator_configs.py index 0d26ec8..017da20 100644 --- a/backend/app/tests/test_validator_configs.py +++ b/backend/app/tests/test_validator_configs.py @@ -4,12 +4,12 @@ import pytest from sqlmodel import Session -from app.crud.validator_config_crud import validator_config_crud +from app.crud.validator_config import validator_config_crud from app.core.enum import GuardrailOnFail, Stage, ValidatorType -from app.models.config.validator_config_table import ValidatorConfig +from app.models.config.validator_config import ValidatorConfig # Test data constants -TEST_ORG_ID = 1 +TEST_ORGANIZATION_ID = 1 TEST_PROJECT_ID = 1 TEST_VALIDATOR_ID = uuid.uuid4() TEST_TYPE = ValidatorType.LexicalSlur @@ -27,7 +27,7 @@ def sample_validator(): """Create a sample validator config for testing.""" return ValidatorConfig( id=TEST_VALIDATOR_ID, - org_id=TEST_ORG_ID, + organization_id=TEST_ORGANIZATION_ID, project_id=TEST_PROJECT_ID, type=TEST_TYPE, stage=TEST_STAGE, @@ -48,7 +48,7 @@ def test_flatten_includes_config_fields(self, sample_validator): def test_flatten_empty_config(self): validator = ValidatorConfig( id=TEST_VALIDATOR_ID, - org_id=TEST_ORG_ID, + organization_id=TEST_ORGANIZATION_ID, project_id=TEST_PROJECT_ID, type=TEST_TYPE, stage=TEST_STAGE, @@ -66,10 +66,10 @@ class TestGetOr404: def test_success(self, sample_validator, mock_session): mock_session.get.return_value = sample_validator - result = validator_config_crud.get_or_404( + result = validator_config_crud.get( mock_session, TEST_VALIDATOR_ID, - TEST_ORG_ID, + TEST_ORGANIZATION_ID, TEST_PROJECT_ID, ) @@ -80,10 +80,10 @@ def test_not_found(self, mock_session): mock_session.get.return_value = None with pytest.raises(Exception) as exc: - validator_config_crud.get_or_404( + validator_config_crud.get( mock_session, TEST_VALIDATOR_ID, - TEST_ORG_ID, + TEST_ORGANIZATION_ID, TEST_PROJECT_ID, ) diff --git a/backend/app/tests/test_validator_configs_integration.py b/backend/app/tests/test_validator_configs_integration.py index effffd5..bb2de29 100644 --- a/backend/app/tests/test_validator_configs_integration.py +++ b/backend/app/tests/test_validator_configs_integration.py @@ -1,19 +1,18 @@ import uuid import pytest -from sqlalchemy.exc import OperationalError from sqlmodel import Session, delete from app.core.db import engine -from app.models.config.validator_config_table import ValidatorConfig +from app.models.config.validator_config import ValidatorConfig pytestmark = pytest.mark.integration # Test data constants -TEST_ORG_ID = 1 +TEST_ORGANIZATION_ID = 1 TEST_PROJECT_ID = 1 BASE_URL = "/api/v1/guardrails/validators/configs/" -DEFAULT_QUERY_PARAMS = f"?org_id={TEST_ORG_ID}&project_id={TEST_PROJECT_ID}" +DEFAULT_QUERY_PARAMS = f"?organization_id={TEST_ORGANIZATION_ID}&project_id={TEST_PROJECT_ID}" VALIDATOR_PAYLOADS = { "lexical_slur": { @@ -69,14 +68,14 @@ def get_validator(self, client, validator_id): def list_validators(self, client, **query_params): """Helper to list validators with optional filters.""" - params_str = f"?org_id={TEST_ORG_ID}&project_id={TEST_PROJECT_ID}" + params_str = f"?organization_id={TEST_ORGANIZATION_ID}&project_id={TEST_PROJECT_ID}" if query_params: params_str += "&" + "&".join(f"{k}={v}" for k, v in query_params.items()) return client.get(f"{BASE_URL}{params_str}") def update_validator(self, client, validator_id, payload): """Helper to update a validator.""" - return client.patch(f"{BASE_URL}{validator_id}/{DEFAULT_QUERY_PARAMS}", json=payload) + return client.patch(f"{BASE_URL}{validator_id}{DEFAULT_QUERY_PARAMS}", json=payload) def delete_validator(self, client, validator_id): """Helper to delete a validator.""" @@ -91,7 +90,7 @@ def test_create_validator_success(self, integration_client, clear_database): response = self.create_validator(integration_client, "lexical_slur") assert response.status_code == 200 - data = response.json() + data = response.json()["data"] assert data["type"] == "uli_slur_match" assert data["stage"] == "input" assert data["severity"] == "all" @@ -130,7 +129,7 @@ def test_list_validators_success(self, integration_client, clear_database): response = self.list_validators(integration_client) assert response.status_code == 200 - data = response.json() + data = response.json()["data"] assert len(data) == 2 def test_list_validators_filter_by_stage(self, integration_client, clear_database): @@ -141,7 +140,7 @@ def test_list_validators_filter_by_stage(self, integration_client, clear_databas response = self.list_validators(integration_client, stage="input") assert response.status_code == 200 - data = response.json() + data = response.json()["data"] assert len(data) == 1 assert data[0]["stage"] == "input" @@ -153,18 +152,18 @@ def test_list_validators_filter_by_type(self, integration_client, clear_database response = self.list_validators(integration_client, type="pii_remover") assert response.status_code == 200 - data = response.json() + data = response.json()["data"] assert len(data) == 1 assert data[0]["type"] == "pii_remover" def test_list_validators_empty(self, integration_client, clear_database): """Test listing validators when none exist.""" response = integration_client.get( - f"{BASE_URL}?org_id=999&project_id=999", + f"{BASE_URL}?organization_id=999&project_id=999", ) assert response.status_code == 200 - data = response.json() + data = response.json()["data"] assert len(data) == 0 @@ -177,13 +176,13 @@ def test_get_validator_success(self, integration_client, clear_database): create_response = self.create_validator( integration_client, "lexical_slur", severity="all" ) - validator_id = create_response.json()["id"] + validator_id = create_response.json()["data"]["id"] # Retrieve it response = self.get_validator(integration_client, validator_id) assert response.status_code == 200 - data = response.json() + data = response.json()["data"] assert data["id"] == validator_id assert data["severity"] == "all" @@ -198,11 +197,11 @@ def test_get_validator_wrong_org(self, integration_client, clear_database): """Test that accessing validator from different org returns 404.""" # Create a validator for org 1 create_response = self.create_validator(integration_client, "minimal") - validator_id = create_response.json()["id"] + validator_id = create_response.json()["data"]["id"] # Try to access it as different org response = integration_client.get( - f"{BASE_URL}{validator_id}/?org_id=2&project_id=1", + f"{BASE_URL}{validator_id}/?organization_id=2&project_id=1", ) assert response.status_code == 404 @@ -217,16 +216,16 @@ def test_update_validator_success(self, integration_client, clear_database): create_response = self.create_validator( integration_client, "lexical_slur", severity="all" ) - validator_id = create_response.json()["id"] + validator_id = create_response.json()["data"]["id"] # Update it - update_payload = {"on_fail_action": "exception", "severity": "high"} + update_payload = {"on_fail_action": "exception", "is_enabled": False} response = self.update_validator(integration_client, validator_id, update_payload) assert response.status_code == 200 - data = response.json() + data = response.json()["data"] assert data["on_fail_action"] == "exception" - assert data["severity"] == "high" + assert data["is_enabled"] is False def test_update_validator_partial(self, integration_client, clear_database): """Test partial update preserves original fields.""" @@ -237,21 +236,21 @@ def test_update_validator_partial(self, integration_client, clear_database): severity="all", languages=["en", "hi"], ) - validator_id = create_response.json()["id"] + validator_id = create_response.json()["data"]["id"] # Update only one field - update_payload = {"severity": "low"} + update_payload = {"is_enabled": False} response = self.update_validator(integration_client, validator_id, update_payload) assert response.status_code == 200 - data = response.json() - assert data["severity"] == "low" - assert data["languages"] == ["en", "hi"] # Original preserved + data = response.json()["data"] + assert data["is_enabled"] is False + assert data["on_fail_action"] == "fix" # Original preserved def test_update_validator_not_found(self, integration_client, clear_database): """Test updating non-existent validator returns 404.""" fake_id = uuid.uuid4() - update_payload = {"severity": "low"} + update_payload = {"is_enabled": False} response = self.update_validator(integration_client, fake_id, update_payload) @@ -265,7 +264,7 @@ def test_delete_validator_success(self, integration_client, clear_database): """Test successful validator deletion.""" # Create a validator create_response = self.create_validator(integration_client, "minimal") - validator_id = create_response.json()["id"] + validator_id = create_response.json()["data"]["id"] # Delete it response = self.delete_validator(integration_client, validator_id) @@ -288,11 +287,11 @@ def test_delete_validator_wrong_org(self, integration_client, clear_database): """Test that deleting validator from different org returns 404.""" # Create a validator for org 1 create_response = self.create_validator(integration_client, "minimal") - validator_id = create_response.json()["id"] + validator_id = create_response.json()["data"]["id"] # Try to delete it as different org response = integration_client.delete( - f"{BASE_URL}{validator_id}/?org_id=2&project_id=1", + f"{BASE_URL}{validator_id}/?organization_id=2&project_id=1", ) assert response.status_code == 404 diff --git a/backend/app/utils.py b/backend/app/utils.py index 30543d5..3a1def6 100644 --- a/backend/app/utils.py +++ b/backend/app/utils.py @@ -14,17 +14,20 @@ def now(): return datetime.now(timezone.utc).replace(tzinfo=None) def split_validator_payload(data: dict): - base = {} - config = {} + model_fields = {} + config_fields = {} - for k, v in data.items(): - if k in SYSTEM_FIELDS: - base[k] = v + for key, value in data.items(): + if key in SYSTEM_FIELDS: + model_fields[key] = value else: - config[k] = v + config_fields[key] = value - return base, config + overlap = set(model_fields) & set(config_fields) + if overlap: + raise ValueError(f"Config keys conflict with reserved field names: {overlap}") + return model_fields, config_fields class APIResponse(BaseModel, Generic[T]): success: bool