From 59a0ae56a578b7d7bbb5969fb8a9161e148100b2 Mon Sep 17 00:00:00 2001 From: rkritika1508 Date: Tue, 10 Feb 2026 11:56:29 +0530 Subject: [PATCH 1/8] Added banlist management code --- .../versions/004_added_banlist_config.py | 43 +++++++ backend/app/api/main.py | 5 +- backend/app/api/routes/banlist_configs.py | 93 ++++++++++++++ backend/app/crud/banlist.py | 115 ++++++++++++++++++ backend/app/models/config/banlist.py | 75 ++++++++++++ backend/app/schemas/banlist.py | 28 +++++ backend/pyproject.toml | 2 +- 7 files changed, 358 insertions(+), 3 deletions(-) create mode 100644 backend/app/alembic/versions/004_added_banlist_config.py create mode 100644 backend/app/api/routes/banlist_configs.py create mode 100644 backend/app/crud/banlist.py create mode 100644 backend/app/models/config/banlist.py create mode 100644 backend/app/schemas/banlist.py diff --git a/backend/app/alembic/versions/004_added_banlist_config.py b/backend/app/alembic/versions/004_added_banlist_config.py new file mode 100644 index 0000000..b030425 --- /dev/null +++ b/backend/app/alembic/versions/004_added_banlist_config.py @@ -0,0 +1,43 @@ +"""Added ban_list table + +Revision ID: 004 +Revises: 003 +Create Date: 2026-02-05 09:42:54.128852 + +""" +from typing import Sequence, Union + +from alembic import op +from sqlalchemy.dialects import postgresql +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = '004' +down_revision = '003' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table('ban_list', + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('name', sa.String(), nullable=False), + sa.Column('description', sa.String(), nullable=False), + sa.Column('organization_id', sa.Integer(), nullable=False), + sa.Column('project_id', sa.Integer(), nullable=False), + sa.Column('domain', sa.String(), nullable=False), + sa.Column('is_public', sa.Boolean(), nullable=False, server_default=sa.false()), + sa.Column("banned_words", postgresql.ARRAY(sa.String()), nullable=False, server_default="{}"), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + + sa.PrimaryKeyConstraint('id'), + ) + + op.create_index("idx_banlist_organization", "ban_list", ["organization_id"]) + op.create_index("idx_banlist_project", "ban_list", ["project_id"]) + op.create_index("idx_banlist_domain", "ban_list", ["domain"]) + op.create_index("idx_banlist_is_public", "ban_list", ["is_public"]) + +def downgrade() -> None: + op.drop_table('ban_list') \ No newline at end of file diff --git a/backend/app/api/main.py b/backend/app/api/main.py index bf78ade..7e5f93e 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -1,11 +1,12 @@ from fastapi import APIRouter -from app.api.routes import utils, guardrails, validator_configs +from app.api.routes import banlist_configs, guardrails, validator_configs, utils api_router = APIRouter() -api_router.include_router(utils.router) +api_router.include_router(banlist_configs.router) api_router.include_router(guardrails.router) api_router.include_router(validator_configs.router) +api_router.include_router(utils.router) # if settings.ENVIRONMENT == "local": # api_router.include_router(private.router) diff --git a/backend/app/api/routes/banlist_configs.py b/backend/app/api/routes/banlist_configs.py new file mode 100644 index 0000000..28fc7a1 --- /dev/null +++ b/backend/app/api/routes/banlist_configs.py @@ -0,0 +1,93 @@ +from typing import Optional +from uuid import UUID + +from fastapi import APIRouter + +from app.api.deps import AuthDep, SessionDep +from app.crud.banlist import banlist_crud +from app.schemas.banlist import ( + BanListCreate, + BanListUpdate, + BanListResponse +) +from app.utils import APIResponse + +router = APIRouter( + prefix="/guardrails/ban-lists", + tags=["Ban Lists"] +) + +@router.post( + "/", + response_model=APIResponse[BanListResponse] + ) +def create_banlist( + payload: BanListCreate, + session: SessionDep, + organization_id: int, + project_id: int, + _: AuthDep, +): + response_model = banlist_crud.create(session, payload, organization_id, project_id) + return APIResponse.success_response(data=response_model) + +@router.get( + "/", + response_model=APIResponse[list[BanListResponse]] + ) +def list_banlists( + organization_id: int, + project_id: int, + session: SessionDep, + _: AuthDep, + domain: Optional[str] = None, +): + response_model = banlist_crud.list(session, organization_id, project_id, domain) + return APIResponse.success_response(data=response_model) + + +@router.get( + "/{id}", + response_model=APIResponse[BanListResponse] + ) +def get_banlist( + id: UUID, + organization_id: int, + project_id: int, + session: SessionDep, + _: AuthDep, +): + obj = banlist_crud.get(session, id, organization_id, project_id) + return APIResponse.success_response(data=obj) + + +@router.patch( + "/{id}", + response_model=APIResponse[BanListResponse] + ) +def update_banlist( + id: UUID, + organization_id: int, + project_id: int, + payload: BanListUpdate, + session: SessionDep, + _: AuthDep, +): + obj = banlist_crud.get(session, id, organization_id, project_id) + response_model = banlist_crud.update(session, obj=obj, data=payload) + return APIResponse.success_response(data=response_model) + +@router.delete( + "/{id}", + response_model=APIResponse[dict] + ) +def delete_banlist( + id: UUID, + organization_id: int, + project_id: int, + session: SessionDep, + _: AuthDep, +): + obj = banlist_crud.get(session, id, organization_id, project_id) + banlist_crud.delete(session, obj) + return APIResponse.success_response(data={"message": "Banlist deleted successfully"}) diff --git a/backend/app/crud/banlist.py b/backend/app/crud/banlist.py new file mode 100644 index 0000000..5b9b0e9 --- /dev/null +++ b/backend/app/crud/banlist.py @@ -0,0 +1,115 @@ +from datetime import datetime +from typing import List, Optional +from uuid import UUID + +from fastapi import HTTPException +from sqlalchemy.exc import IntegrityError +from sqlmodel import Session, select + +from app.models.config.banlist import BanList +from app.schemas.banlist import BanListCreate, BanListUpdate +from app.utils import now + +class BanListCrud: + def create( + self, + session: Session, + data: BanListCreate, + organization_id: int, + project_id: int, + ) -> BanList: + obj = BanList( + **data.model_dump(), + organization_id=organization_id, + project_id=project_id, + ) + session.add(obj) + + try: + session.commit() + except IntegrityError: + session.rollback() + raise HTTPException( + 400, + "Banlist already exists for the given configuration" + ) + + session.refresh(obj) + return obj + + def get( + self, + session: Session, + id: UUID, + organization_id: int, + project_id: int + ) -> BanList: + obj = session.get(BanList, id) + + if not obj.is_public: + self.check_owner(obj, organization_id, project_id) + + return obj + + def list( + self, + session: Session, + organization_id: int, + project_id: int, + domain: Optional[str] = None, + ) -> List[BanList]: + stmt = select(BanList).where( + ( + (BanList.organization_id == organization_id) & + (BanList.project_id == project_id) + ) | + (BanList.is_public == True) + ) + + if domain: + stmt = stmt.where(BanList.domain == domain) + + return list(session.exec(stmt)) + + def update( + self, + session: Session, + obj: BanList, + data: BanListUpdate, + ) -> BanList: + update_data = data.model_dump(exclude_unset=True) + + for k, v in update_data.items(): + setattr(obj, k, v) + + obj.updated_at = now() + + session.add(obj) + try: + session.commit() + except IntegrityError: + session.rollback() + raise HTTPException( + 400, + "Banlist already exists for the given configuration" + ) + except Exception: + session.rollback() + raise + + session.refresh(obj) + return obj + + def delete(self, session: Session, obj: BanList): + session.delete(obj) + try: + session.commit() + except Exception: + session.rollback() + raise + + def check_owner(self, obj, organization_id, project_id): + if obj.organization_id != organization_id or obj.project_id != project_id: + raise HTTPException(status_code=403, detail="Not owner") + +banlist_crud = BanListCrud() \ No newline at end of file diff --git a/backend/app/models/config/banlist.py b/backend/app/models/config/banlist.py new file mode 100644 index 0000000..b6dae54 --- /dev/null +++ b/backend/app/models/config/banlist.py @@ -0,0 +1,75 @@ +from datetime import datetime +from typing import List, Optional +from uuid import UUID, uuid4 + +from sqlalchemy import Column, String +from sqlalchemy.dialects.postgresql import ARRAY +from sqlmodel import Field, SQLModel + +from app.utils import now + +class BanList(SQLModel, table=True): + __tablename__ = "ban_list" + + id: UUID = Field( + default_factory=uuid4, + primary_key=True, + index=True, + sa_column_kwargs={"comment": "Unique identifier for the ban list entry"} + ) + + name: str = Field( + nullable=False, + sa_column_kwargs={"comment": "Name of the ban list entry"} + ) + + description: Optional[str] = Field( + nullable=False, + sa_column_kwargs={"comment": "Description of the ban list entry"} + ) + + banned_words: list[str] = Field( + default_factory=list, + sa_column=Column( + ARRAY(String), + nullable=False, + comment="List of banned words", + ), + description=("List of banned words") + ) + + organization_id: int = Field( + index=True, + nullable=False, + sa_column_kwargs={"comment": "Identifier for the organization"}, + ) + + project_id: int = Field( + index=True, + nullable=False, + sa_column_kwargs={"comment": "Identifier for the project"}, + ) + + domain: str = Field( + default=None, + index=False, + nullable=False, + sa_column_kwargs={"comment": "Domain or context for the ban list entry"} + ) + + is_public: bool = Field( + default=False, + sa_column_kwargs={"comment": "Whether the ban list entry is public or private"} + ) + + created_at: datetime = Field( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the ban list entry was created"} + ) + + updated_at: datetime = Field( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the ban list entry was last updated"} + ) \ No newline at end of file diff --git a/backend/app/schemas/banlist.py b/backend/app/schemas/banlist.py new file mode 100644 index 0000000..f4cb47b --- /dev/null +++ b/backend/app/schemas/banlist.py @@ -0,0 +1,28 @@ +from uuid import UUID +from datetime import datetime +from typing import List, Optional + +from sqlmodel import SQLModel + +class BanListBase(SQLModel): + name: str + description: str + banned_words: list[str] + domain: str + is_public: bool = False + + +class BanListCreate(BanListBase): + pass + + +class BanListUpdate(SQLModel): + name: Optional[str] = None + description: Optional[str] = None + banned_words: Optional[list[str]] = None + domain: Optional[str] = None + is_public: Optional[bool] = None + + +class BanListResponse(BanListBase): + pass \ No newline at end of file diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 979075e..5d5e895 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ "sentry-sdk[fastapi]<2.0.0,>=1.40.6", "pyjwt<3.0.0,>=2.8.0", "asgi-correlation-id>=4.3.4", - "guardrails-ai>=0.7.2", + "guardrails-ai[hub]>=0.8.0", "emoji", "ftfy", "presidio_analyzer>=2.2.360", From c1078caa718c485aaa9c8cd2100122e8d8885f0f Mon Sep 17 00:00:00 2001 From: rkritika1508 Date: Tue, 10 Feb 2026 12:16:12 +0530 Subject: [PATCH 2/8] Added tests --- backend/app/crud/banlist.py | 3 + backend/app/schemas/banlist.py | 5 +- backend/app/tests/test_banlist_configs.py | 125 ++++++++++ .../app/tests/test_banlists_integration.py | 225 ++++++++++++++++++ 4 files changed, 357 insertions(+), 1 deletion(-) create mode 100644 backend/app/tests/test_banlist_configs.py create mode 100644 backend/app/tests/test_banlists_integration.py diff --git a/backend/app/crud/banlist.py b/backend/app/crud/banlist.py index 5b9b0e9..0e76f11 100644 --- a/backend/app/crud/banlist.py +++ b/backend/app/crud/banlist.py @@ -46,6 +46,9 @@ def get( ) -> BanList: obj = session.get(BanList, id) + if obj is None: + raise HTTPException(status_code=404, detail="Banlist not found") + if not obj.is_public: self.check_owner(obj, organization_id, project_id) diff --git a/backend/app/schemas/banlist.py b/backend/app/schemas/banlist.py index f4cb47b..94b1b01 100644 --- a/backend/app/schemas/banlist.py +++ b/backend/app/schemas/banlist.py @@ -2,9 +2,12 @@ from datetime import datetime from typing import List, Optional +from pydantic import ConfigDict from sqlmodel import SQLModel class BanListBase(SQLModel): + model_config = ConfigDict(extra="allow") + name: str description: str banned_words: list[str] @@ -25,4 +28,4 @@ class BanListUpdate(SQLModel): class BanListResponse(BanListBase): - pass \ No newline at end of file + id: UUID \ No newline at end of file diff --git a/backend/app/tests/test_banlist_configs.py b/backend/app/tests/test_banlist_configs.py new file mode 100644 index 0000000..2004f3e --- /dev/null +++ b/backend/app/tests/test_banlist_configs.py @@ -0,0 +1,125 @@ +import uuid +from unittest.mock import MagicMock, patch + +import pytest +from fastapi import HTTPException +from sqlmodel import Session + +from app.api.routes.banlist_configs import ( + create_banlist, + list_banlists, + get_banlist, + update_banlist, + delete_banlist, +) +from app.schemas.banlist import BanListCreate, BanListUpdate + + +TEST_ID = uuid.uuid4() +TEST_ORGANIZATION_ID = 1 +TEST_PROJECT_ID = 10 + + +@pytest.fixture +def mock_session(): + return MagicMock(spec=Session) + + +@pytest.fixture +def sample_banlist(): + obj = MagicMock() + obj.id = TEST_ID + obj.name = "test" + obj.description = "desc" + obj.banned_words = ["bad"] + obj.organization_id = TEST_ORGANIZATION_ID + obj.project_id = TEST_PROJECT_ID + obj.domain = "health" + obj.is_public = False + return obj + + +@pytest.fixture +def create_payload(): + return BanListCreate( + name="test", + description="desc", + banned_words=["bad"], + domain="health", + is_public=False, + ) + + +def test_create_calls_crud(mock_session, create_payload, sample_banlist): + with patch("app.api.routes.banlist_configs.banlist_crud") as crud: + crud.create.return_value = sample_banlist + + result = create_banlist( + payload=create_payload, + session=mock_session, + organization_id=TEST_ORGANIZATION_ID, + project_id=TEST_PROJECT_ID, + _=None, + ) + + assert result.data == sample_banlist + + +def test_list_returns_data(mock_session, sample_banlist): + with patch("app.api.routes.banlist_configs.banlist_crud") as crud: + crud.list.return_value = [sample_banlist] + + result = list_banlists( + organization_id=TEST_ORGANIZATION_ID, + project_id=TEST_PROJECT_ID, + session=mock_session, + _=None, + ) + + assert len(result.data) == 1 + + +def test_get_success(mock_session, sample_banlist): + with patch("app.api.routes.banlist_configs.banlist_crud") as crud: + crud.get.return_value = sample_banlist + + result = get_banlist( + id=TEST_ID, + organization_id=TEST_ORGANIZATION_ID, + project_id=TEST_PROJECT_ID, + session=mock_session, + _=None, + ) + + assert result.data == sample_banlist + +def test_update_success(mock_session, sample_banlist): + with patch("app.api.routes.banlist_configs.banlist_crud") as crud: + crud.get.return_value = sample_banlist + crud.update.return_value = sample_banlist + + result = update_banlist( + id=TEST_ID, + organization_id=TEST_ORGANIZATION_ID, + project_id=TEST_PROJECT_ID, + payload=BanListUpdate(name="new"), + session=mock_session, + _=None, + ) + + assert result.data == sample_banlist + + +def test_delete_success(mock_session, sample_banlist): + with patch("app.api.routes.banlist_configs.banlist_crud") as crud: + crud.get.return_value = sample_banlist + + result = delete_banlist( + id=TEST_ID, + organization_id=TEST_ORGANIZATION_ID, + project_id=TEST_PROJECT_ID, + session=mock_session, + _=None, + ) + + assert result.success is True diff --git a/backend/app/tests/test_banlists_integration.py b/backend/app/tests/test_banlists_integration.py new file mode 100644 index 0000000..b582842 --- /dev/null +++ b/backend/app/tests/test_banlists_integration.py @@ -0,0 +1,225 @@ +import uuid +import pytest +from sqlmodel import Session, delete + +from app.core.db import engine +from app.models.config.banlist import BanList + +pytestmark = pytest.mark.integration + + +TEST_ORGANIZATION_ID = 1 +TEST_PROJECT_ID = 1 +BASE_URL = "/api/v1/guardrails/ban-lists/" +DEFAULT_QUERY = f"?organization_id={TEST_ORGANIZATION_ID}&project_id={TEST_PROJECT_ID}" + + +BAN_LIST_PAYLOADS = { + "minimal": { + "name": "default", + "description": "basic list", + "banned_words": ["bad"], + "domain": "general", + }, + "health": { + "name": "health-list", + "description": "healthcare words", + "banned_words": ["gender detection", "sonography"], + "domain": "health", + }, + "edu": { + "name": "edu-list", + "description": "education words", + "banned_words": ["cheating"], + "domain": "edu", + }, + "public": { + "name": "public-list", + "description": "shared", + "banned_words": ["shared"], + "is_public": True, + "domain": "general", + }, +} + + +@pytest.fixture +def clear_database(): + with Session(engine) as session: + session.exec(delete(BanList)) + session.commit() + + yield + + with Session(engine) as session: + session.exec(delete(BanList)) + session.commit() + + +class BaseBanListTest: + + def create(self, client, payload_key="minimal", **kwargs): + payload = {**BAN_LIST_PAYLOADS[payload_key], **kwargs} + return client.post(f"{BASE_URL}{DEFAULT_QUERY}", json=payload) + + def list(self, client, **filters): + params = DEFAULT_QUERY + if filters: + params += "&" + "&".join(f"{k}={v}" for k, v in filters.items()) + return client.get(f"{BASE_URL}{params}") + + def get(self, client, id, org=TEST_ORGANIZATION_ID, project=TEST_PROJECT_ID): + return client.get(f"{BASE_URL}{id}/?organization_id={org}&project_id={project}") + + def update(self, client, id, payload): + return client.patch(f"{BASE_URL}{id}/{DEFAULT_QUERY}", json=payload) + + def delete(self, client, id): + return client.delete(f"{BASE_URL}{id}/{DEFAULT_QUERY}") + + +class TestCreateBanList(BaseBanListTest): + + def test_create_success(self, integration_client, clear_database): + response = self.create(integration_client, "minimal") + + assert response.status_code == 200 + data = response.json()["data"] + + assert data["name"] == "default" + assert data["banned_words"] == ["bad"] + + def test_create_validation_error(self, integration_client, clear_database): + response = integration_client.post( + f"{BASE_URL}{DEFAULT_QUERY}", + json={"name": "missing words"}, + ) + + assert response.status_code == 422 + + +class TestListBanLists(BaseBanListTest): + + def test_list_success(self, integration_client, clear_database): + self.create(integration_client, "minimal") + self.create(integration_client, "health") + + response = self.list(integration_client) + + assert response.status_code == 200 + data = response.json()["data"] + assert len(data) == 2 + + def test_filter_by_domain(self, integration_client, clear_database): + self.create(integration_client, "health") + self.create(integration_client, "edu") + + response = self.list(integration_client, domain="health") + + data = response.json()["data"] + + assert len(data) == 1 + assert data[0]["domain"] == "health" + + def test_list_empty(self, integration_client, clear_database): + response = self.list(integration_client) + + assert response.json()["data"] == [] + + +class TestPublicAccess(BaseBanListTest): + + def test_public_visible_to_other_org(self, integration_client, clear_database): + create_resp = self.create(integration_client, "public") + ban_id = create_resp.json()["data"]["id"] + + response = self.get(integration_client, ban_id, org=999, project=999) + + assert response.status_code == 200 + + +class TestGetBanList(BaseBanListTest): + + def test_get_success(self, integration_client, clear_database): + create_resp = self.create(integration_client, "minimal") + ban_id = create_resp.json()["data"]["id"] + + response = self.get(integration_client, ban_id) + + assert response.status_code == 200 + + def test_get_not_found(self, integration_client, clear_database): + fake = uuid.uuid4() + response = self.get(integration_client, fake) + + assert response.status_code == 404 + + def test_get_wrong_owner_private(self, integration_client, clear_database): + create_resp = self.create(integration_client, "minimal") + ban_id = create_resp.json()["data"]["id"] + + response = self.get(integration_client, ban_id, org=2, project=2) + + assert response.status_code in (403, 404) + + +class TestUpdateBanList(BaseBanListTest): + + def test_update_success(self, integration_client, clear_database): + create_resp = self.create(integration_client, "minimal") + ban_id = create_resp.json()["data"]["id"] + + response = self.update( + integration_client, + ban_id, + {"banned_words": ["bad", "worse"]}, + ) + + assert response.status_code == 200 + + data = response.json()["data"] + assert data["banned_words"] == ["bad", "worse"] + + def test_partial_update(self, integration_client, clear_database): + create_resp = self.create(integration_client, "minimal") + ban_id = create_resp.json()["data"]["id"] + + response = self.update(integration_client, ban_id, {"name": "updated"}) + + assert response.json()["data"]["name"] == "updated" + + def test_update_not_found(self, integration_client, clear_database): + fake = uuid.uuid4() + + response = self.update(integration_client, fake, {"name": "x"}) + + assert response.status_code == 404 + + +class TestDeleteBanList(BaseBanListTest): + + def test_delete_success(self, integration_client, clear_database): + create_resp = self.create(integration_client, "minimal") + ban_id = create_resp.json()["data"]["id"] + + response = self.delete(integration_client, ban_id) + + assert response.status_code == 200 + assert response.json()["success"] is True + + def test_delete_not_found(self, integration_client, clear_database): + fake = uuid.uuid4() + + response = self.delete(integration_client, fake) + + assert response.status_code == 404 + + def test_delete_wrong_owner(self, integration_client, clear_database): + create_resp = self.create(integration_client, "minimal") + ban_id = create_resp.json()["data"]["id"] + + response = integration_client.delete( + f"{BASE_URL}{ban_id}/?organization_id=999&project_id=999" + ) + + assert response.status_code in (403, 404) From 8c0b7de955f00ee9a874f74af1142f20fe377815 Mon Sep 17 00:00:00 2001 From: rkritika1508 Date: Tue, 10 Feb 2026 13:09:14 +0530 Subject: [PATCH 3/8] resolved comments --- backend/app/alembic/versions/004_added_banlist_config.py | 1 + backend/app/crud/banlist.py | 5 ++++- backend/app/models/config/banlist.py | 3 +-- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/backend/app/alembic/versions/004_added_banlist_config.py b/backend/app/alembic/versions/004_added_banlist_config.py index b030425..2eda6f3 100644 --- a/backend/app/alembic/versions/004_added_banlist_config.py +++ b/backend/app/alembic/versions/004_added_banlist_config.py @@ -32,6 +32,7 @@ def upgrade() -> None: sa.Column('updated_at', sa.DateTime(), nullable=False), sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('name', 'organization_id', 'project_id', name='uq_banlist_name_org_project'), ) op.create_index("idx_banlist_organization", "ban_list", ["organization_id"]) diff --git a/backend/app/crud/banlist.py b/backend/app/crud/banlist.py index 0e76f11..003842b 100644 --- a/backend/app/crud/banlist.py +++ b/backend/app/crud/banlist.py @@ -33,7 +33,10 @@ def create( 400, "Banlist already exists for the given configuration" ) - + except Exception: + session.rollback() + raise + session.refresh(obj) return obj diff --git a/backend/app/models/config/banlist.py b/backend/app/models/config/banlist.py index b6dae54..f3d0cf2 100644 --- a/backend/app/models/config/banlist.py +++ b/backend/app/models/config/banlist.py @@ -23,7 +23,7 @@ class BanList(SQLModel, table=True): sa_column_kwargs={"comment": "Name of the ban list entry"} ) - description: Optional[str] = Field( + description: str = Field( nullable=False, sa_column_kwargs={"comment": "Description of the ban list entry"} ) @@ -51,7 +51,6 @@ class BanList(SQLModel, table=True): ) domain: str = Field( - default=None, index=False, nullable=False, sa_column_kwargs={"comment": "Domain or context for the ban list entry"} From 8757facc58492c51d9d58dd733b9f1a3f9eae148 Mon Sep 17 00:00:00 2001 From: rkritika1508 Date: Tue, 10 Feb 2026 16:41:27 +0530 Subject: [PATCH 4/8] resolved comment --- backend/app/models/config/banlist.py | 2 +- backend/app/schemas/banlist.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/backend/app/models/config/banlist.py b/backend/app/models/config/banlist.py index f3d0cf2..e97b2b0 100644 --- a/backend/app/models/config/banlist.py +++ b/backend/app/models/config/banlist.py @@ -9,7 +9,7 @@ from app.utils import now class BanList(SQLModel, table=True): - __tablename__ = "ban_list" + __tablename__ = "banlist" id: UUID = Field( default_factory=uuid4, diff --git a/backend/app/schemas/banlist.py b/backend/app/schemas/banlist.py index 94b1b01..f81a3f0 100644 --- a/backend/app/schemas/banlist.py +++ b/backend/app/schemas/banlist.py @@ -6,8 +6,6 @@ from sqlmodel import SQLModel class BanListBase(SQLModel): - model_config = ConfigDict(extra="allow") - name: str description: str banned_words: list[str] From 42be28f34be105c43ebc94bde6f94c2a87fb4b0d Mon Sep 17 00:00:00 2001 From: rkritika1508 Date: Tue, 10 Feb 2026 17:05:05 +0530 Subject: [PATCH 5/8] Added seed database and updated tests --- .env.example | 4 - .env.test.example | 4 - backend/app/tests/conftest.py | 51 +++++++- backend/app/tests/seed_data.json | 90 ++++++++++++++ backend/app/tests/seed_data.py | 75 ++++++++++++ backend/app/tests/test_banlist_configs.py | 60 ++++------ .../app/tests/test_banlists_integration.py | 95 ++++++--------- backend/app/tests/test_validator_configs.py | 60 ++++------ .../test_validator_configs_integration.py | 113 ++++++------------ 9 files changed, 338 insertions(+), 214 deletions(-) create mode 100644 backend/app/tests/seed_data.json create mode 100644 backend/app/tests/seed_data.py diff --git a/.env.example b/.env.example index f869ad9..32afe10 100644 --- a/.env.example +++ b/.env.example @@ -21,10 +21,6 @@ SENTRY_DSN= DOCKER_IMAGE_BACKEND=kaapi-guardrails-backend -# Callback Timeouts (in seconds) -CALLBACK_CONNECT_TIMEOUT=3 -CALLBACK_READ_TIMEOUT=10 - # require as a env if you want to use doc transformation OPENAI_API_KEY="" GUARDRAILS_HUB_API_KEY="" diff --git a/.env.test.example b/.env.test.example index 368b7d7..b91edf9 100644 --- a/.env.test.example +++ b/.env.test.example @@ -21,10 +21,6 @@ SENTRY_DSN= DOCKER_IMAGE_BACKEND=kaapi-guardrails-backend -# Callback Timeouts (in seconds) -CALLBACK_CONNECT_TIMEOUT=3 -CALLBACK_READ_TIMEOUT=10 - # require as a env if you want to use doc transformation OPENAI_API_KEY="" GUARDRAILS_HUB_API_KEY="" diff --git a/backend/app/tests/conftest.py b/backend/app/tests/conftest.py index d595b97..93ce161 100644 --- a/backend/app/tests/conftest.py +++ b/backend/app/tests/conftest.py @@ -4,11 +4,23 @@ import pytest from fastapi.testclient import TestClient -from sqlmodel import Session, create_engine, SQLModel +from sqlmodel import Session, create_engine, SQLModel, delete from app.main import app from app.api.deps import SessionDep, verify_bearer_token from app.core.config import settings +from app.core.enum import GuardrailOnFail, Stage, ValidatorType +from app.models.config.banlist import BanList +from app.models.config.validator_config import ValidatorConfig +from app.tests.seed_data import ( + BANLIST_INTEGRATION_ORGANIZATION_ID, + BANLIST_INTEGRATION_PROJECT_ID, + BAN_LIST_PAYLOADS, + VALIDATOR_INTEGRATION_ORGANIZATION_ID, + VALIDATOR_INTEGRATION_PROJECT_ID, + VALIDATOR_PAYLOADS, +) +from app.utils import split_validator_payload test_engine = create_engine( str(settings.SQLALCHEMY_DATABASE_URI), @@ -20,6 +32,33 @@ def override_session(): with Session(test_engine) as session: yield session + +def seed_test_data(session: Session) -> None: + for payload in BAN_LIST_PAYLOADS.values(): + session.add( + BanList( + **payload, + organization_id=BANLIST_INTEGRATION_ORGANIZATION_ID, + project_id=BANLIST_INTEGRATION_PROJECT_ID, + ) + ) + + for payload in VALIDATOR_PAYLOADS.values(): + model_fields, config_fields = split_validator_payload(payload) + session.add( + ValidatorConfig( + organization_id=VALIDATOR_INTEGRATION_ORGANIZATION_ID, + project_id=VALIDATOR_INTEGRATION_PROJECT_ID, + type=ValidatorType(model_fields["type"]), + stage=Stage(model_fields["stage"]), + on_fail_action=GuardrailOnFail(model_fields["on_fail_action"]), + is_enabled=model_fields.get("is_enabled", True), + config=config_fields, + ) + ) + + session.commit() + @pytest.fixture(scope="session", autouse=True) def setup_test_db(): SQLModel.metadata.create_all(test_engine) @@ -43,6 +82,16 @@ def override_dependencies(): app.dependency_overrides.clear() + +@pytest.fixture(scope="function") +def seed_db(): + with Session(test_engine) as session: + session.exec(delete(BanList)) + session.exec(delete(ValidatorConfig)) + session.commit() + seed_test_data(session) + yield + @pytest.fixture(scope="function") def client(): with TestClient(app) as c: diff --git a/backend/app/tests/seed_data.json b/backend/app/tests/seed_data.json new file mode 100644 index 0000000..9ea1d22 --- /dev/null +++ b/backend/app/tests/seed_data.json @@ -0,0 +1,90 @@ +{ + "banlist": { + "unit": { + "test_id": "11111111-1111-1111-1111-111111111111", + "organization_id": 1, + "project_id": 10, + "sample": { + "name": "test", + "description": "desc", + "banned_words": ["bad"], + "domain": "health", + "is_public": false + } + }, + "integration": { + "organization_id": 1, + "project_id": 1, + "payloads": { + "minimal": { + "name": "default", + "description": "basic list", + "banned_words": ["bad"], + "domain": "general" + }, + "health": { + "name": "health-list", + "description": "healthcare words", + "banned_words": ["gender detection", "sonography"], + "domain": "health" + }, + "edu": { + "name": "edu-list", + "description": "education words", + "banned_words": ["cheating"], + "domain": "edu" + }, + "public": { + "name": "public-list", + "description": "shared", + "banned_words": ["shared"], + "is_public": true, + "domain": "general" + } + } + } + }, + "validator": { + "unit": { + "validator_id": "22222222-2222-2222-2222-222222222222", + "organization_id": 1, + "project_id": 1, + "type": "LexicalSlur", + "stage": "Input", + "on_fail_action": "Fix", + "is_enabled": true, + "config": { + "severity": "all", + "languages": ["en", "hi"] + } + }, + "integration": { + "organization_id": 1, + "project_id": 1, + "payloads": { + "lexical_slur": { + "type": "uli_slur_match", + "stage": "input", + "on_fail_action": "fix", + "severity": "all", + "languages": ["en", "hi"] + }, + "pii_remover_input": { + "type": "pii_remover", + "stage": "input", + "on_fail_action": "fix" + }, + "pii_remover_output": { + "type": "pii_remover", + "stage": "output", + "on_fail_action": "fix" + }, + "minimal": { + "type": "gender_assumption_bias", + "stage": "input", + "on_fail_action": "fix" + } + } + } + } +} diff --git a/backend/app/tests/seed_data.py b/backend/app/tests/seed_data.py new file mode 100644 index 0000000..50b9a16 --- /dev/null +++ b/backend/app/tests/seed_data.py @@ -0,0 +1,75 @@ +import json +from pathlib import Path +import uuid +from unittest.mock import MagicMock + +from app.core.enum import GuardrailOnFail, Stage, ValidatorType +from app.models.config.validator_config import ValidatorConfig +from app.schemas.banlist import BanListCreate + +SEED_DATA_PATH = Path(__file__).with_name("seed_data.json") + + +def _load_seed_data() -> dict: + with SEED_DATA_PATH.open("r", encoding="utf-8") as f: + return json.load(f) + + +DATA = _load_seed_data() + +BANLIST_UNIT = DATA["banlist"]["unit"] +BANLIST_INTEGRATION = DATA["banlist"]["integration"] + +VALIDATOR_UNIT = DATA["validator"]["unit"] +VALIDATOR_INTEGRATION = DATA["validator"]["integration"] + +BANLIST_TEST_ID = uuid.UUID(BANLIST_UNIT["test_id"]) +BANLIST_TEST_ORGANIZATION_ID = BANLIST_UNIT["organization_id"] +BANLIST_TEST_PROJECT_ID = BANLIST_UNIT["project_id"] + +BANLIST_INTEGRATION_ORGANIZATION_ID = BANLIST_INTEGRATION["organization_id"] +BANLIST_INTEGRATION_PROJECT_ID = BANLIST_INTEGRATION["project_id"] +BAN_LIST_PAYLOADS = BANLIST_INTEGRATION["payloads"] + +VALIDATOR_TEST_ID = uuid.UUID(VALIDATOR_UNIT["validator_id"]) +VALIDATOR_TEST_ORGANIZATION_ID = VALIDATOR_UNIT["organization_id"] +VALIDATOR_TEST_PROJECT_ID = VALIDATOR_UNIT["project_id"] +VALIDATOR_TEST_TYPE = ValidatorType[VALIDATOR_UNIT["type"]] +VALIDATOR_TEST_STAGE = Stage[VALIDATOR_UNIT["stage"]] +VALIDATOR_TEST_ON_FAIL = GuardrailOnFail[VALIDATOR_UNIT["on_fail_action"]] +VALIDATOR_TEST_CONFIG = VALIDATOR_UNIT["config"] +VALIDATOR_TEST_IS_ENABLED = VALIDATOR_UNIT["is_enabled"] + +VALIDATOR_INTEGRATION_ORGANIZATION_ID = VALIDATOR_INTEGRATION["organization_id"] +VALIDATOR_INTEGRATION_PROJECT_ID = VALIDATOR_INTEGRATION["project_id"] +VALIDATOR_PAYLOADS = VALIDATOR_INTEGRATION["payloads"] + + +def build_banlist_create_payload() -> BanListCreate: + return BanListCreate(**BANLIST_UNIT["sample"]) + + +def build_sample_banlist_mock() -> MagicMock: + obj = MagicMock() + obj.id = BANLIST_TEST_ID + obj.name = BANLIST_UNIT["sample"]["name"] + obj.description = BANLIST_UNIT["sample"]["description"] + obj.banned_words = BANLIST_UNIT["sample"]["banned_words"] + obj.organization_id = BANLIST_TEST_ORGANIZATION_ID + obj.project_id = BANLIST_TEST_PROJECT_ID + obj.domain = BANLIST_UNIT["sample"]["domain"] + obj.is_public = BANLIST_UNIT["sample"].get("is_public", False) + return obj + + +def build_sample_validator_config() -> ValidatorConfig: + return ValidatorConfig( + id=VALIDATOR_TEST_ID, + organization_id=VALIDATOR_TEST_ORGANIZATION_ID, + project_id=VALIDATOR_TEST_PROJECT_ID, + type=VALIDATOR_TEST_TYPE, + stage=VALIDATOR_TEST_STAGE, + on_fail_action=VALIDATOR_TEST_ON_FAIL, + is_enabled=VALIDATOR_TEST_IS_ENABLED, + config=VALIDATOR_TEST_CONFIG, + ) diff --git a/backend/app/tests/test_banlist_configs.py b/backend/app/tests/test_banlist_configs.py index 2004f3e..eda6261 100644 --- a/backend/app/tests/test_banlist_configs.py +++ b/backend/app/tests/test_banlist_configs.py @@ -2,7 +2,6 @@ from unittest.mock import MagicMock, patch import pytest -from fastapi import HTTPException from sqlmodel import Session from app.api.routes.banlist_configs import ( @@ -12,12 +11,14 @@ update_banlist, delete_banlist, ) -from app.schemas.banlist import BanListCreate, BanListUpdate - - -TEST_ID = uuid.uuid4() -TEST_ORGANIZATION_ID = 1 -TEST_PROJECT_ID = 10 +from app.schemas.banlist import BanListUpdate +from app.tests.seed_data import ( + BANLIST_TEST_ID, + BANLIST_TEST_ORGANIZATION_ID, + BANLIST_TEST_PROJECT_ID, + build_banlist_create_payload, + build_sample_banlist_mock, +) @pytest.fixture @@ -27,27 +28,12 @@ def mock_session(): @pytest.fixture def sample_banlist(): - obj = MagicMock() - obj.id = TEST_ID - obj.name = "test" - obj.description = "desc" - obj.banned_words = ["bad"] - obj.organization_id = TEST_ORGANIZATION_ID - obj.project_id = TEST_PROJECT_ID - obj.domain = "health" - obj.is_public = False - return obj + return build_sample_banlist_mock() @pytest.fixture def create_payload(): - return BanListCreate( - name="test", - description="desc", - banned_words=["bad"], - domain="health", - is_public=False, - ) + return build_banlist_create_payload() def test_create_calls_crud(mock_session, create_payload, sample_banlist): @@ -57,8 +43,8 @@ def test_create_calls_crud(mock_session, create_payload, sample_banlist): result = create_banlist( payload=create_payload, session=mock_session, - organization_id=TEST_ORGANIZATION_ID, - project_id=TEST_PROJECT_ID, + organization_id=BANLIST_TEST_ORGANIZATION_ID, + project_id=BANLIST_TEST_PROJECT_ID, _=None, ) @@ -70,8 +56,8 @@ def test_list_returns_data(mock_session, sample_banlist): crud.list.return_value = [sample_banlist] result = list_banlists( - organization_id=TEST_ORGANIZATION_ID, - project_id=TEST_PROJECT_ID, + organization_id=BANLIST_TEST_ORGANIZATION_ID, + project_id=BANLIST_TEST_PROJECT_ID, session=mock_session, _=None, ) @@ -84,9 +70,9 @@ def test_get_success(mock_session, sample_banlist): crud.get.return_value = sample_banlist result = get_banlist( - id=TEST_ID, - organization_id=TEST_ORGANIZATION_ID, - project_id=TEST_PROJECT_ID, + id=BANLIST_TEST_ID, + organization_id=BANLIST_TEST_ORGANIZATION_ID, + project_id=BANLIST_TEST_PROJECT_ID, session=mock_session, _=None, ) @@ -99,9 +85,9 @@ def test_update_success(mock_session, sample_banlist): crud.update.return_value = sample_banlist result = update_banlist( - id=TEST_ID, - organization_id=TEST_ORGANIZATION_ID, - project_id=TEST_PROJECT_ID, + id=BANLIST_TEST_ID, + organization_id=BANLIST_TEST_ORGANIZATION_ID, + project_id=BANLIST_TEST_PROJECT_ID, payload=BanListUpdate(name="new"), session=mock_session, _=None, @@ -115,9 +101,9 @@ def test_delete_success(mock_session, sample_banlist): crud.get.return_value = sample_banlist result = delete_banlist( - id=TEST_ID, - organization_id=TEST_ORGANIZATION_ID, - project_id=TEST_PROJECT_ID, + id=BANLIST_TEST_ID, + organization_id=BANLIST_TEST_ORGANIZATION_ID, + project_id=BANLIST_TEST_PROJECT_ID, session=mock_session, _=None, ) diff --git a/backend/app/tests/test_banlists_integration.py b/backend/app/tests/test_banlists_integration.py index b582842..773ab78 100644 --- a/backend/app/tests/test_banlists_integration.py +++ b/backend/app/tests/test_banlists_integration.py @@ -4,43 +4,20 @@ from app.core.db import engine from app.models.config.banlist import BanList +from app.tests.seed_data import ( + BANLIST_INTEGRATION_ORGANIZATION_ID, + BANLIST_INTEGRATION_PROJECT_ID, + BAN_LIST_PAYLOADS, +) pytestmark = pytest.mark.integration -TEST_ORGANIZATION_ID = 1 -TEST_PROJECT_ID = 1 BASE_URL = "/api/v1/guardrails/ban-lists/" -DEFAULT_QUERY = f"?organization_id={TEST_ORGANIZATION_ID}&project_id={TEST_PROJECT_ID}" - - -BAN_LIST_PAYLOADS = { - "minimal": { - "name": "default", - "description": "basic list", - "banned_words": ["bad"], - "domain": "general", - }, - "health": { - "name": "health-list", - "description": "healthcare words", - "banned_words": ["gender detection", "sonography"], - "domain": "health", - }, - "edu": { - "name": "edu-list", - "description": "education words", - "banned_words": ["cheating"], - "domain": "edu", - }, - "public": { - "name": "public-list", - "description": "shared", - "banned_words": ["shared"], - "is_public": True, - "domain": "general", - }, -} +DEFAULT_QUERY = ( + f"?organization_id={BANLIST_INTEGRATION_ORGANIZATION_ID}" + f"&project_id={BANLIST_INTEGRATION_PROJECT_ID}" +) @pytest.fixture @@ -68,7 +45,13 @@ def list(self, client, **filters): params += "&" + "&".join(f"{k}={v}" for k, v in filters.items()) return client.get(f"{BASE_URL}{params}") - def get(self, client, id, org=TEST_ORGANIZATION_ID, project=TEST_PROJECT_ID): + def get( + self, + client, + id, + org=BANLIST_INTEGRATION_ORGANIZATION_ID, + project=BANLIST_INTEGRATION_PROJECT_ID, + ): return client.get(f"{BASE_URL}{id}/?organization_id={org}&project_id={project}") def update(self, client, id, payload): @@ -100,19 +83,15 @@ def test_create_validation_error(self, integration_client, clear_database): class TestListBanLists(BaseBanListTest): - def test_list_success(self, integration_client, clear_database): - self.create(integration_client, "minimal") - self.create(integration_client, "health") + def test_list_success(self, integration_client, clear_database, seed_db): response = self.list(integration_client) assert response.status_code == 200 data = response.json()["data"] - assert len(data) == 2 + assert len(data) == 4 - def test_filter_by_domain(self, integration_client, clear_database): - self.create(integration_client, "health") - self.create(integration_client, "edu") + def test_filter_by_domain(self, integration_client, clear_database, seed_db): response = self.list(integration_client, domain="health") @@ -140,9 +119,9 @@ def test_public_visible_to_other_org(self, integration_client, clear_database): class TestGetBanList(BaseBanListTest): - def test_get_success(self, integration_client, clear_database): - create_resp = self.create(integration_client, "minimal") - ban_id = create_resp.json()["data"]["id"] + def test_get_success(self, integration_client, clear_database, seed_db): + list_resp = self.list(integration_client) + ban_id = list_resp.json()["data"][0]["id"] response = self.get(integration_client, ban_id) @@ -154,9 +133,9 @@ def test_get_not_found(self, integration_client, clear_database): assert response.status_code == 404 - def test_get_wrong_owner_private(self, integration_client, clear_database): - create_resp = self.create(integration_client, "minimal") - ban_id = create_resp.json()["data"]["id"] + def test_get_wrong_owner_private(self, integration_client, clear_database, seed_db): + list_resp = self.list(integration_client) + ban_id = list_resp.json()["data"][0]["id"] response = self.get(integration_client, ban_id, org=2, project=2) @@ -165,9 +144,9 @@ def test_get_wrong_owner_private(self, integration_client, clear_database): class TestUpdateBanList(BaseBanListTest): - def test_update_success(self, integration_client, clear_database): - create_resp = self.create(integration_client, "minimal") - ban_id = create_resp.json()["data"]["id"] + def test_update_success(self, integration_client, clear_database, seed_db): + list_resp = self.list(integration_client) + ban_id = list_resp.json()["data"][0]["id"] response = self.update( integration_client, @@ -180,9 +159,9 @@ def test_update_success(self, integration_client, clear_database): data = response.json()["data"] assert data["banned_words"] == ["bad", "worse"] - def test_partial_update(self, integration_client, clear_database): - create_resp = self.create(integration_client, "minimal") - ban_id = create_resp.json()["data"]["id"] + def test_partial_update(self, integration_client, clear_database, seed_db): + list_resp = self.list(integration_client) + ban_id = list_resp.json()["data"][0]["id"] response = self.update(integration_client, ban_id, {"name": "updated"}) @@ -198,9 +177,9 @@ def test_update_not_found(self, integration_client, clear_database): class TestDeleteBanList(BaseBanListTest): - def test_delete_success(self, integration_client, clear_database): - create_resp = self.create(integration_client, "minimal") - ban_id = create_resp.json()["data"]["id"] + def test_delete_success(self, integration_client, clear_database, seed_db): + list_resp = self.list(integration_client) + ban_id = list_resp.json()["data"][0]["id"] response = self.delete(integration_client, ban_id) @@ -214,9 +193,9 @@ def test_delete_not_found(self, integration_client, clear_database): assert response.status_code == 404 - def test_delete_wrong_owner(self, integration_client, clear_database): - create_resp = self.create(integration_client, "minimal") - ban_id = create_resp.json()["data"]["id"] + def test_delete_wrong_owner(self, integration_client, clear_database, seed_db): + list_resp = self.list(integration_client) + ban_id = list_resp.json()["data"][0]["id"] response = integration_client.delete( f"{BASE_URL}{ban_id}/?organization_id=999&project_id=999" diff --git a/backend/app/tests/test_validator_configs.py b/backend/app/tests/test_validator_configs.py index 017da20..212bad9 100644 --- a/backend/app/tests/test_validator_configs.py +++ b/backend/app/tests/test_validator_configs.py @@ -1,20 +1,21 @@ -import uuid from unittest.mock import MagicMock import pytest from sqlmodel import Session from app.crud.validator_config import validator_config_crud -from app.core.enum import GuardrailOnFail, Stage, ValidatorType +from app.core.enum import GuardrailOnFail, ValidatorType from app.models.config.validator_config import ValidatorConfig - -# Test data constants -TEST_ORGANIZATION_ID = 1 -TEST_PROJECT_ID = 1 -TEST_VALIDATOR_ID = uuid.uuid4() -TEST_TYPE = ValidatorType.LexicalSlur -TEST_STAGE = Stage.Input -TEST_ON_FAIL = GuardrailOnFail.Fix +from app.tests.seed_data import ( + VALIDATOR_TEST_CONFIG, + VALIDATOR_TEST_ID, + VALIDATOR_TEST_ON_FAIL, + VALIDATOR_TEST_ORGANIZATION_ID, + VALIDATOR_TEST_PROJECT_ID, + VALIDATOR_TEST_STAGE, + VALIDATOR_TEST_TYPE, + build_sample_validator_config, +) @pytest.fixture def mock_session(): @@ -25,16 +26,7 @@ def mock_session(): @pytest.fixture def sample_validator(): """Create a sample validator config for testing.""" - return ValidatorConfig( - id=TEST_VALIDATOR_ID, - organization_id=TEST_ORGANIZATION_ID, - project_id=TEST_PROJECT_ID, - type=TEST_TYPE, - stage=TEST_STAGE, - on_fail_action=TEST_ON_FAIL, - is_enabled=True, - config={"severity": "all", "languages": ["en", "hi"]}, - ) + return build_sample_validator_config() class TestFlatten: @@ -43,16 +35,16 @@ def test_flatten_includes_config_fields(self, sample_validator): assert result["severity"] == "all" assert result["languages"] == ["en", "hi"] - assert result["id"] == TEST_VALIDATOR_ID + assert result["id"] == VALIDATOR_TEST_ID def test_flatten_empty_config(self): validator = ValidatorConfig( - id=TEST_VALIDATOR_ID, - organization_id=TEST_ORGANIZATION_ID, - project_id=TEST_PROJECT_ID, - type=TEST_TYPE, - stage=TEST_STAGE, - on_fail_action=TEST_ON_FAIL, + id=VALIDATOR_TEST_ID, + organization_id=VALIDATOR_TEST_ORGANIZATION_ID, + project_id=VALIDATOR_TEST_PROJECT_ID, + type=VALIDATOR_TEST_TYPE, + stage=VALIDATOR_TEST_STAGE, + on_fail_action=VALIDATOR_TEST_ON_FAIL, is_enabled=True, config={}, ) @@ -68,9 +60,9 @@ def test_success(self, sample_validator, mock_session): result = validator_config_crud.get( mock_session, - TEST_VALIDATOR_ID, - TEST_ORGANIZATION_ID, - TEST_PROJECT_ID, + VALIDATOR_TEST_ID, + VALIDATOR_TEST_ORGANIZATION_ID, + VALIDATOR_TEST_PROJECT_ID, ) assert result == sample_validator @@ -82,9 +74,9 @@ def test_not_found(self, mock_session): with pytest.raises(Exception) as exc: validator_config_crud.get( mock_session, - TEST_VALIDATOR_ID, - TEST_ORGANIZATION_ID, - TEST_PROJECT_ID, + VALIDATOR_TEST_ID, + VALIDATOR_TEST_ORGANIZATION_ID, + VALIDATOR_TEST_PROJECT_ID, ) assert "Validator not found" in str(exc.value) @@ -120,7 +112,7 @@ def test_update_extra_fields(self, sample_validator, mock_session): assert result["severity"] == "high" assert result["new_field"] == "new_value" - assert result["languages"] == ["en", "hi"] + assert result["languages"] == VALIDATOR_TEST_CONFIG["languages"] def test_merge_config(self, sample_validator, mock_session): sample_validator.config = {"severity": "all", "languages": ["en"]} diff --git a/backend/app/tests/test_validator_configs_integration.py b/backend/app/tests/test_validator_configs_integration.py index bb2de29..2a95723 100644 --- a/backend/app/tests/test_validator_configs_integration.py +++ b/backend/app/tests/test_validator_configs_integration.py @@ -5,39 +5,19 @@ from app.core.db import engine from app.models.config.validator_config import ValidatorConfig +from app.tests.seed_data import ( + VALIDATOR_INTEGRATION_ORGANIZATION_ID, + VALIDATOR_INTEGRATION_PROJECT_ID, + VALIDATOR_PAYLOADS, +) pytestmark = pytest.mark.integration -# Test data constants -TEST_ORGANIZATION_ID = 1 -TEST_PROJECT_ID = 1 BASE_URL = "/api/v1/guardrails/validators/configs/" -DEFAULT_QUERY_PARAMS = f"?organization_id={TEST_ORGANIZATION_ID}&project_id={TEST_PROJECT_ID}" - -VALIDATOR_PAYLOADS = { - "lexical_slur": { - "type": "uli_slur_match", - "stage": "input", - "on_fail_action": "fix", - "severity": "all", - "languages": ["en", "hi"], - }, - "pii_remover_input": { - "type": "pii_remover", - "stage": "input", - "on_fail_action": "fix", - }, - "pii_remover_output": { - "type": "pii_remover", - "stage": "output", - "on_fail_action": "fix", - }, - "minimal": { - "type": "uli_slur_match", - "stage": "input", - "on_fail_action": "fix", - }, -} +DEFAULT_QUERY_PARAMS = ( + f"?organization_id={VALIDATOR_INTEGRATION_ORGANIZATION_ID}" + f"&project_id={VALIDATOR_INTEGRATION_PROJECT_ID}" +) @pytest.fixture @@ -68,7 +48,10 @@ def get_validator(self, client, validator_id): def list_validators(self, client, **query_params): """Helper to list validators with optional filters.""" - params_str = f"?organization_id={TEST_ORGANIZATION_ID}&project_id={TEST_PROJECT_ID}" + params_str = ( + f"?organization_id={VALIDATOR_INTEGRATION_ORGANIZATION_ID}" + f"&project_id={VALIDATOR_INTEGRATION_PROJECT_ID}" + ) if query_params: params_str += "&" + "&".join(f"{k}={v}" for k, v in query_params.items()) return client.get(f"{BASE_URL}{params_str}") @@ -120,40 +103,33 @@ def test_create_validator_missing_required_fields(self, integration_client, clea class TestListValidators(BaseValidatorTest): """Tests for GET /guardrails/validators/configs endpoint.""" - def test_list_validators_success(self, integration_client, clear_database): + def test_list_validators_success(self, integration_client, clear_database, seed_db): """Test successful validator listing.""" - # Create validators first - self.create_validator(integration_client, "lexical_slur") - self.create_validator(integration_client, "pii_remover_input") response = self.list_validators(integration_client) assert response.status_code == 200 data = response.json()["data"] - assert len(data) == 2 + assert len(data) == 4 - def test_list_validators_filter_by_stage(self, integration_client, clear_database): + def test_list_validators_filter_by_stage(self, integration_client, clear_database, seed_db): """Test filtering validators by stage.""" - self.create_validator(integration_client, "lexical_slur") - self.create_validator(integration_client, "pii_remover_output") response = self.list_validators(integration_client, stage="input") assert response.status_code == 200 data = response.json()["data"] - assert len(data) == 1 + assert len(data) == 3 assert data[0]["stage"] == "input" - def test_list_validators_filter_by_type(self, integration_client, clear_database): + def test_list_validators_filter_by_type(self, integration_client, clear_database, seed_db): """Test filtering validators by type.""" - self.create_validator(integration_client, "lexical_slur") - self.create_validator(integration_client, "pii_remover_input") response = self.list_validators(integration_client, type="pii_remover") assert response.status_code == 200 data = response.json()["data"] - assert len(data) == 1 + assert len(data) == 2 assert data[0]["type"] == "pii_remover" def test_list_validators_empty(self, integration_client, clear_database): @@ -170,13 +146,10 @@ def test_list_validators_empty(self, integration_client, clear_database): class TestGetValidator(BaseValidatorTest): """Tests for GET /guardrails/validators/configs/{id} endpoint.""" - def test_get_validator_success(self, integration_client, clear_database): + def test_get_validator_success(self, integration_client, clear_database, seed_db): """Test successful validator retrieval.""" - # Create a validator - create_response = self.create_validator( - integration_client, "lexical_slur", severity="all" - ) - validator_id = create_response.json()["data"]["id"] + list_response = self.list_validators(integration_client) + validator_id = list_response.json()["data"][0]["id"] # Retrieve it response = self.get_validator(integration_client, validator_id) @@ -193,11 +166,10 @@ def test_get_validator_not_found(self, integration_client, clear_database): assert response.status_code == 404 - def test_get_validator_wrong_org(self, integration_client, clear_database): + def test_get_validator_wrong_org(self, integration_client, clear_database, seed_db): """Test that accessing validator from different org returns 404.""" - # Create a validator for org 1 - create_response = self.create_validator(integration_client, "minimal") - validator_id = create_response.json()["data"]["id"] + list_response = self.list_validators(integration_client) + validator_id = list_response.json()["data"][0]["id"] # Try to access it as different org response = integration_client.get( @@ -210,13 +182,10 @@ def test_get_validator_wrong_org(self, integration_client, clear_database): class TestUpdateValidator(BaseValidatorTest): """Tests for PATCH /guardrails/validators/configs/{id} endpoint.""" - def test_update_validator_success(self, integration_client, clear_database): + def test_update_validator_success(self, integration_client, clear_database, seed_db): """Test successful validator update.""" - # Create a validator - create_response = self.create_validator( - integration_client, "lexical_slur", severity="all" - ) - validator_id = create_response.json()["data"]["id"] + list_response = self.list_validators(integration_client) + validator_id = list_response.json()["data"][0]["id"] # Update it update_payload = {"on_fail_action": "exception", "is_enabled": False} @@ -227,16 +196,10 @@ def test_update_validator_success(self, integration_client, clear_database): assert data["on_fail_action"] == "exception" assert data["is_enabled"] is False - def test_update_validator_partial(self, integration_client, clear_database): + def test_update_validator_partial(self, integration_client, clear_database, seed_db): """Test partial update preserves original fields.""" - # Create a validator - create_response = self.create_validator( - integration_client, - "lexical_slur", - severity="all", - languages=["en", "hi"], - ) - validator_id = create_response.json()["data"]["id"] + list_response = self.list_validators(integration_client) + validator_id = list_response.json()["data"][0]["id"] # Update only one field update_payload = {"is_enabled": False} @@ -260,11 +223,10 @@ def test_update_validator_not_found(self, integration_client, clear_database): class TestDeleteValidator(BaseValidatorTest): """Tests for DELETE /guardrails/validators/configs/{id} endpoint.""" - def test_delete_validator_success(self, integration_client, clear_database): + def test_delete_validator_success(self, integration_client, clear_database, seed_db): """Test successful validator deletion.""" - # Create a validator - create_response = self.create_validator(integration_client, "minimal") - validator_id = create_response.json()["data"]["id"] + list_response = self.list_validators(integration_client) + validator_id = list_response.json()["data"][0]["id"] # Delete it response = self.delete_validator(integration_client, validator_id) @@ -283,11 +245,10 @@ def test_delete_validator_not_found(self, integration_client, clear_database): assert response.status_code == 404 - def test_delete_validator_wrong_org(self, integration_client, clear_database): + def test_delete_validator_wrong_org(self, integration_client, clear_database, seed_db): """Test that deleting validator from different org returns 404.""" - # Create a validator for org 1 - create_response = self.create_validator(integration_client, "minimal") - validator_id = create_response.json()["data"]["id"] + list_response = self.list_validators(integration_client) + validator_id = list_response.json()["data"][0]["id"] # Try to delete it as different org response = integration_client.delete( From fe5e1bbf0b73bea5ba94758ed4969854083a36fd Mon Sep 17 00:00:00 2001 From: rkritika1508 Date: Tue, 10 Feb 2026 17:15:48 +0530 Subject: [PATCH 6/8] resolved comment --- backend/app/crud/banlist.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/backend/app/crud/banlist.py b/backend/app/crud/banlist.py index 003842b..dd3d092 100644 --- a/backend/app/crud/banlist.py +++ b/backend/app/crud/banlist.py @@ -114,8 +114,16 @@ def delete(self, session: Session, obj: BanList): session.rollback() raise - def check_owner(self, obj, organization_id, project_id): - if obj.organization_id != organization_id or obj.project_id != project_id: - raise HTTPException(status_code=403, detail="Not owner") +def check_owner(self, obj, organization_id, project_id): + is_owner = ( + obj.organization_id == organization_id + and obj.project_id == project_id + ) + + if not is_owner: + raise HTTPException( + status_code=403, + detail="You do not have permission to access this resource." + ) banlist_crud = BanListCrud() \ No newline at end of file From 840990b50b3234bbc4e09c6e12bfc7a368e151b0 Mon Sep 17 00:00:00 2001 From: rkritika1508 Date: Tue, 10 Feb 2026 17:19:38 +0530 Subject: [PATCH 7/8] resolved comments --- backend/app/crud/banlist.py | 31 ++++++++++--------- backend/app/tests/conftest.py | 15 +++++++++ .../app/tests/test_banlists_integration.py | 17 ---------- .../test_validator_configs_integration.py | 18 ----------- 4 files changed, 31 insertions(+), 50 deletions(-) diff --git a/backend/app/crud/banlist.py b/backend/app/crud/banlist.py index dd3d092..2464e6c 100644 --- a/backend/app/crud/banlist.py +++ b/backend/app/crud/banlist.py @@ -1,4 +1,3 @@ -from datetime import datetime from typing import List, Optional from uuid import UUID @@ -10,8 +9,9 @@ from app.schemas.banlist import BanListCreate, BanListUpdate from app.utils import now + class BanListCrud: - def create( + def create( self, session: Session, data: BanListCreate, @@ -41,8 +41,8 @@ def create( return obj def get( - self, - session: Session, + self, + session: Session, id: UUID, organization_id: int, project_id: int @@ -114,16 +114,17 @@ def delete(self, session: Session, obj: BanList): session.rollback() raise -def check_owner(self, obj, organization_id, project_id): - is_owner = ( - obj.organization_id == organization_id - and obj.project_id == project_id - ) - - if not is_owner: - raise HTTPException( - status_code=403, - detail="You do not have permission to access this resource." + def check_owner(self, obj: BanList, organization_id: int, project_id: int) -> None: + is_owner = ( + obj.organization_id == organization_id + and obj.project_id == project_id ) -banlist_crud = BanListCrud() \ No newline at end of file + if not is_owner: + raise HTTPException( + status_code=403, + detail="You do not have permission to access this resource." + ) + + +banlist_crud = BanListCrud() diff --git a/backend/app/tests/conftest.py b/backend/app/tests/conftest.py index 93ce161..51ef2e5 100644 --- a/backend/app/tests/conftest.py +++ b/backend/app/tests/conftest.py @@ -92,6 +92,21 @@ def seed_db(): seed_test_data(session) yield +@pytest.fixture +def clear_database(): + """Clear key config tables before and after each test.""" + with Session(test_engine) as session: + session.exec(delete(BanList)) + session.exec(delete(ValidatorConfig)) + session.commit() + + yield + + with Session(test_engine) as session: + session.exec(delete(BanList)) + session.exec(delete(ValidatorConfig)) + session.commit() + @pytest.fixture(scope="function") def client(): with TestClient(app) as c: diff --git a/backend/app/tests/test_banlists_integration.py b/backend/app/tests/test_banlists_integration.py index 773ab78..59c67a9 100644 --- a/backend/app/tests/test_banlists_integration.py +++ b/backend/app/tests/test_banlists_integration.py @@ -1,9 +1,5 @@ import uuid import pytest -from sqlmodel import Session, delete - -from app.core.db import engine -from app.models.config.banlist import BanList from app.tests.seed_data import ( BANLIST_INTEGRATION_ORGANIZATION_ID, BANLIST_INTEGRATION_PROJECT_ID, @@ -20,19 +16,6 @@ ) -@pytest.fixture -def clear_database(): - with Session(engine) as session: - session.exec(delete(BanList)) - session.commit() - - yield - - with Session(engine) as session: - session.exec(delete(BanList)) - session.commit() - - class BaseBanListTest: def create(self, client, payload_key="minimal", **kwargs): diff --git a/backend/app/tests/test_validator_configs_integration.py b/backend/app/tests/test_validator_configs_integration.py index 2a95723..d554e01 100644 --- a/backend/app/tests/test_validator_configs_integration.py +++ b/backend/app/tests/test_validator_configs_integration.py @@ -1,10 +1,6 @@ import uuid import pytest -from sqlmodel import Session, delete - -from app.core.db import engine -from app.models.config.validator_config import ValidatorConfig from app.tests.seed_data import ( VALIDATOR_INTEGRATION_ORGANIZATION_ID, VALIDATOR_INTEGRATION_PROJECT_ID, @@ -20,20 +16,6 @@ ) -@pytest.fixture -def clear_database(): - """Clear ValidatorConfig table before and after each test.""" - with Session(engine) as session: - session.exec(delete(ValidatorConfig)) - session.commit() - - yield - - with Session(engine) as session: - session.exec(delete(ValidatorConfig)) - session.commit() - - class BaseValidatorTest: """Base class with helper methods for validator tests.""" From d3678558d1689cf9df5b9875e268d9790ea1375a Mon Sep 17 00:00:00 2001 From: rkritika1508 Date: Tue, 10 Feb 2026 17:25:40 +0530 Subject: [PATCH 8/8] resolved comment --- backend/app/api/routes/banlist_configs.py | 56 +++++++++++++++++------ backend/app/crud/banlist.py | 5 +- 2 files changed, 47 insertions(+), 14 deletions(-) diff --git a/backend/app/api/routes/banlist_configs.py b/backend/app/api/routes/banlist_configs.py index 28fc7a1..bca46cf 100644 --- a/backend/app/api/routes/banlist_configs.py +++ b/backend/app/api/routes/banlist_configs.py @@ -1,7 +1,7 @@ from typing import Optional from uuid import UUID -from fastapi import APIRouter +from fastapi import APIRouter, HTTPException from app.api.deps import AuthDep, SessionDep from app.crud.banlist import banlist_crud @@ -28,8 +28,13 @@ def create_banlist( project_id: int, _: AuthDep, ): - response_model = banlist_crud.create(session, payload, organization_id, project_id) - return APIResponse.success_response(data=response_model) + try: + response_model = banlist_crud.create(session, payload, organization_id, project_id) + return APIResponse.success_response(data=response_model) + except HTTPException as exc: + return APIResponse.failure_response(error=str(exc.detail)) + except Exception as exc: + return APIResponse.failure_response(error=str(exc)) @router.get( "/", @@ -42,8 +47,13 @@ def list_banlists( _: AuthDep, domain: Optional[str] = None, ): - response_model = banlist_crud.list(session, organization_id, project_id, domain) - return APIResponse.success_response(data=response_model) + try: + response_model = banlist_crud.list(session, organization_id, project_id, domain) + return APIResponse.success_response(data=response_model) + except HTTPException as exc: + return APIResponse.failure_response(error=str(exc.detail)) + except Exception as exc: + return APIResponse.failure_response(error=str(exc)) @router.get( @@ -57,8 +67,13 @@ def get_banlist( session: SessionDep, _: AuthDep, ): - obj = banlist_crud.get(session, id, organization_id, project_id) - return APIResponse.success_response(data=obj) + try: + obj = banlist_crud.get(session, id, organization_id, project_id) + return APIResponse.success_response(data=obj) + except HTTPException as exc: + return APIResponse.failure_response(error=str(exc.detail)) + except Exception as exc: + return APIResponse.failure_response(error=str(exc)) @router.patch( @@ -73,9 +88,19 @@ def update_banlist( session: SessionDep, _: AuthDep, ): - obj = banlist_crud.get(session, id, organization_id, project_id) - response_model = banlist_crud.update(session, obj=obj, data=payload) - return APIResponse.success_response(data=response_model) + try: + response_model = banlist_crud.update( + session, + id=id, + organization_id=organization_id, + project_id=project_id, + data=payload, + ) + return APIResponse.success_response(data=response_model) + except HTTPException as exc: + return APIResponse.failure_response(error=str(exc.detail)) + except Exception as exc: + return APIResponse.failure_response(error=str(exc)) @router.delete( "/{id}", @@ -88,6 +113,11 @@ def delete_banlist( session: SessionDep, _: AuthDep, ): - obj = banlist_crud.get(session, id, organization_id, project_id) - banlist_crud.delete(session, obj) - return APIResponse.success_response(data={"message": "Banlist deleted successfully"}) + try: + obj = banlist_crud.get(session, id, organization_id, project_id) + banlist_crud.delete(session, obj) + return APIResponse.success_response(data={"message": "Banlist deleted successfully"}) + except HTTPException as exc: + return APIResponse.failure_response(error=str(exc.detail)) + except Exception as exc: + return APIResponse.failure_response(error=str(exc)) diff --git a/backend/app/crud/banlist.py b/backend/app/crud/banlist.py index 2464e6c..e283da3 100644 --- a/backend/app/crud/banlist.py +++ b/backend/app/crud/banlist.py @@ -80,9 +80,12 @@ def list( def update( self, session: Session, - obj: BanList, + id: UUID, + organization_id: int, + project_id: int, data: BanListUpdate, ) -> BanList: + obj = self.get(session, id, organization_id, project_id) update_data = data.model_dump(exclude_unset=True) for k, v in update_data.items():