From b3f8239709803f96fee339376eee3a9cb4006fc7 Mon Sep 17 00:00:00 2001 From: Krrish Ghimire Date: Fri, 27 Feb 2026 21:05:14 +0545 Subject: [PATCH 01/65] make account approval configurable --- backend/config/settings.py | 2 ++ backend/core/middleware.py | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/backend/config/settings.py b/backend/config/settings.py index 3dba965..02fc472 100644 --- a/backend/config/settings.py +++ b/backend/config/settings.py @@ -164,3 +164,5 @@ API_BASE_URL = os.getenv("API_BASE_URL") FRONTEND_URL = os.getenv("FRONTEND_URL") DISCORD_SIGNUP_WEBHOOK_URL = os.getenv("DISCORD_SIGNUP_WEBHOOK_URL") + +REQUIRE_ACCOUNT_APPROVAL = os.getenv("REQUIRE_ACCOUNT_APPROVAL", "False").lower() == "true" diff --git a/backend/core/middleware.py b/backend/core/middleware.py index fcb90bf..fb68c61 100644 --- a/backend/core/middleware.py +++ b/backend/core/middleware.py @@ -2,6 +2,7 @@ from core.models import AccountStatus from rest_framework.authtoken.models import Token from django.http import JsonResponse +from django.conf import settings class AccountStatusMiddleware(MiddlewareMixin): @@ -16,6 +17,9 @@ class AccountStatusMiddleware(MiddlewareMixin): ] def process_request(self, request): + if not getattr(settings, 'REQUIRE_ACCOUNT_APPROVAL', False): + return None + if any(request.path.startswith(path) for path in self.EXCLUDED_PATHS): return None From 412cc2ed77a32300fee822fce01758a5409ccec0 Mon Sep 17 00:00:00 2001 From: Krrish Ghimire Date: Fri, 27 Feb 2026 21:06:45 +0545 Subject: [PATCH 02/65] make account approval configurable --- backend/.env.example | 1 + 1 file changed, 1 insertion(+) diff --git a/backend/.env.example b/backend/.env.example index d598391..630ce67 100644 --- a/backend/.env.example +++ b/backend/.env.example @@ -29,3 +29,4 @@ API_BASE_URL=http://localhost:8000/api FRONTEND_URL=http://localhost:3000 CLOSED_ALPHA_SIGN_UPS='["test@example.com"]' +REQUIRE_ACCOUNT_APPROVAL=False From 3ad0a69ffdbe1dd33223afe54158852bd240c81f Mon Sep 17 00:00:00 2001 From: Anish Ghimire Date: Fri, 27 Feb 2026 21:13:50 +0545 Subject: [PATCH 03/65] chore: upgrade dependencies --- backend/requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/backend/requirements.txt b/backend/requirements.txt index 71dc9c8..629b592 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -107,7 +107,7 @@ python-docx==1.2.0 python-dotenv==1.1.1 pytz==2025.2 PyYAML==6.0.2 -qdrant-client==1.14.3 +qdrant-client==1.17.0 redis==6.2.0 regex==2024.11.6 requests==2.32.4 @@ -128,7 +128,7 @@ tenacity==8.5.0 threadpoolctl==3.6.0 tokenizers==0.21.2 tomli==2.2.1 -torch==2.7.1 +torch==2.2.2 tqdm==4.67.1 transformers==4.53.0 Twisted==25.5.0 From 38e553ea8118f609edaea27ad227a1987301d52a Mon Sep 17 00:00:00 2001 From: Krrish Ghimire Date: Fri, 27 Feb 2026 22:48:51 +0545 Subject: [PATCH 04/65] change password --- backend/core/serializers/change_password.py | 60 ++++ backend/core/urls.py | 2 + backend/core/views/change_password.py | 32 ++ frontend/components/AppSidebar.vue | 19 ++ frontend/middleware/init.global.ts | 6 + frontend/pages/settings/change-password.vue | 306 ++++++++++++++++++++ 6 files changed, 425 insertions(+) create mode 100644 backend/core/serializers/change_password.py create mode 100644 backend/core/views/change_password.py create mode 100644 frontend/pages/settings/change-password.vue diff --git a/backend/core/serializers/change_password.py b/backend/core/serializers/change_password.py new file mode 100644 index 0000000..d168ae5 --- /dev/null +++ b/backend/core/serializers/change_password.py @@ -0,0 +1,60 @@ +from rest_framework import serializers +from django.contrib.auth.password_validation import validate_password +from django.core.exceptions import ValidationError +import re + + +class ChangePasswordSerializer(serializers.Serializer): + current_password = serializers.CharField(write_only=True, required=True) + new_password = serializers.CharField(write_only=True, required=True, min_length=8) + confirm_password = serializers.CharField(write_only=True, required=True) + + def validate_current_password(self, value): + user = self.context['request'].user + if not user.check_password(value): + raise serializers.ValidationError("Current password is incorrect") + return value + + def validate_new_password(self, value): + errors = [] + + if len(value) < 8: + errors.append("Password must be at least 8 characters long") + + if not re.search(r'[a-z]', value) or not re.search(r'[A-Z]', value): + errors.append("Password must contain both uppercase and lowercase letters") + + if not re.search(r'[0-9]', value): + errors.append("Password must contain at least one number") + + if not re.search(r'[^a-zA-Z0-9]', value): + errors.append("Password must contain at least one special character") + + if errors: + raise serializers.ValidationError(errors) + + try: + validate_password(value) + except ValidationError as django_error: + if isinstance(django_error.messages, list): + errors.extend(django_error.messages) + else: + errors.append(str(django_error)) + raise serializers.ValidationError(errors) + + return value + + def validate(self, attrs): + if attrs['new_password'] != attrs['confirm_password']: + raise serializers.ValidationError("New passwords don't match") + + if attrs['current_password'] == attrs['new_password']: + raise serializers.ValidationError("New password must be different from current password") + + return attrs + + def save(self): + user = self.context['request'].user + user.set_password(self.validated_data['new_password']) + user.save() + return user diff --git a/backend/core/urls.py b/backend/core/urls.py index 04012d8..effd1ab 100644 --- a/backend/core/urls.py +++ b/backend/core/urls.py @@ -8,6 +8,7 @@ MeView, GenerateAPIKeyView, IntegrationViewSet, WidgetView, LoadAvailableConfigurationView, AppNotificationUpdateView ) +from core.views.change_password import ChangePasswordView from core.views.resend_verification import ResendVerificationView from core.views.custom_auth import CustomAuthToken from core.views.application import ApplicationViewSet, ApplicationChatRoomsPreviewView @@ -94,4 +95,5 @@ path('forgot-password/', ForgotPasswordView.as_view(), name='forgot-password'), path('reset-password/', ResetPasswordView.as_view(), name='reset-password'), # remove path('reset-password//', ResetPasswordView.as_view(), name='reset-password'), + path('change-password/', ChangePasswordView.as_view(), name='change-password'), ] diff --git a/backend/core/views/change_password.py b/backend/core/views/change_password.py new file mode 100644 index 0000000..19d5356 --- /dev/null +++ b/backend/core/views/change_password.py @@ -0,0 +1,32 @@ +from rest_framework import status, permissions +from rest_framework.permissions import IsAuthenticated +from rest_framework.response import Response +from rest_framework.views import APIView +from django.contrib.auth import update_session_auth_hash +from rest_framework.authtoken.models import Token + +from core.serializers.change_password import ChangePasswordSerializer + + +class ChangePasswordView(APIView): + permission_classes = [IsAuthenticated] + + def post(self, request): + serializer = ChangePasswordSerializer(data=request.data, context={'request': request}) + + if serializer.is_valid(): + user = serializer.save() + + update_session_auth_hash(request, user) + + try: + token = Token.objects.get(user=user) + except Token.DoesNotExist: + token = Token.objects.create(user=user) + + return Response({ + 'message': 'Password changed successfully', + 'token': token.key + }, status=status.HTTP_200_OK) + + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) diff --git a/frontend/components/AppSidebar.vue b/frontend/components/AppSidebar.vue index e9ecc7c..9fd1313 100644 --- a/frontend/components/AppSidebar.vue +++ b/frontend/components/AppSidebar.vue @@ -14,6 +14,7 @@ import { ChevronDown, ChevronUp, Puzzle, + Lock, } from 'lucide-vue-next' import SlideOver from '~/components/SlideOver.vue' import { ref, computed, onMounted, watch } from 'vue' @@ -258,6 +259,24 @@ async function initNewChat() { Notification Profile + + + + + Change Password + + +import { ref, computed } from 'vue' +import { Eye, EyeOff, Lock, CheckCircle } from 'lucide-vue-next' +import { toast } from 'vue-sonner' +import { Button } from '@/components/ui/button' +import { Input } from '@/components/ui/input' +import { Label } from '@/components/ui/label' +import { Card, CardContent, CardDescription, CardFooter, CardHeader, CardTitle } from '@/components/ui/card' +import { Alert, AlertDescription } from '@/components/ui/alert' +import { useHttpClient } from '~/composables/useHttpClient' + +definePageMeta({ + layout: 'default', +}) + +const currentPassword = ref('') +const newPassword = ref('') +const confirmPassword = ref('') +const showCurrentPassword = ref(false) +const showNewPassword = ref(false) +const showConfirmPassword = ref(false) +const loading = ref(false) +const success = ref(false) + +const { httpPost } = useHttpClient() + +const validatePasswords = () => { + const errors = [] + + if (newPassword.value.length < 8) { + errors.push('New password must be at least 8 characters long') + } + + if (!/[a-z]/.test(newPassword.value) || !/[A-Z]/.test(newPassword.value)) { + errors.push('Password must contain both uppercase and lowercase letters') + } + + if (!/[0-9]/.test(newPassword.value)) { + errors.push('Password must contain at least one number') + } + + if (!/[^a-zA-Z0-9]/.test(newPassword.value)) { + errors.push('Password must contain at least one special character') + } + + if (newPassword.value !== confirmPassword.value) { + errors.push('New passwords do not match') + } + + if (currentPassword.value === newPassword.value) { + errors.push('New password must be different from current password') + } + + if (errors.length > 0) { + errors.forEach(error => toast.error(error)) + return false + } + + return true +} + +const handlePasswordChange = async () => { + if (loading.value) return + + if (!validatePasswords()) return + + loading.value = true + + try { + const response = await httpPost<{ message: string; token: string }>( + '/change-password/', + { + current_password: currentPassword.value, + new_password: newPassword.value, + confirm_password: confirmPassword.value, + } + ) + + if (response?.message) { + success.value = true + toast.success('Password changed successfully!') + + if (response.token) { + const cookie = useCookie('auth_token') + cookie.value = response.token + } + + setTimeout(() => { + currentPassword.value = '' + newPassword.value = '' + confirmPassword.value = '' + success.value = false + }, 3000) + } + } catch (err: any) { + console.error('Password change error:', err) + + const errors = err?.errors || {} + + if (errors.current_password) { + toast.error(errors.current_password[0] || 'Current password is incorrect') + } else if (errors.new_password) { + toast.error(errors.new_password[0] || 'New password is not valid') + } else if (errors.non_field_errors) { + toast.error(errors.non_field_errors[0] || 'Password change failed') + } else if (errors.error) { + toast.error(errors.error) + } else { + toast.error('Failed to change password. Please try again.') + } + } finally { + loading.value = false + } +} + +const passwordStrength = computed(() => { + const password = newPassword.value + if (!password) return { score: 0, text: '', color: '' } + + let score = 0 + if (password.length >= 8) score++ + if (password.length >= 12) score++ + if (/[a-z]/.test(password) && /[A-Z]/.test(password)) score++ + if (/[0-9]/.test(password)) score++ + if (/[^a-zA-Z0-9]/.test(password)) score++ + + const strengthLevels = [ + { score: 0, text: 'Very Weak', color: 'text-red-500' }, + { score: 1, text: 'Weak', color: 'text-red-400' }, + { score: 2, text: 'Fair', color: 'text-yellow-500' }, + { score: 3, text: 'Good', color: 'text-blue-500' }, + { score: 4, text: 'Strong', color: 'text-green-500' }, + { score: 5, text: 'Very Strong', color: 'text-green-600' } + ] + + return strengthLevels[Math.min(score, 4)] +}) + + + From df6d76076a5c76161555250541e43bfc61dcb7ab Mon Sep 17 00:00:00 2001 From: Krrish Ghimire Date: Fri, 27 Feb 2026 22:58:33 +0545 Subject: [PATCH 05/65] update ui for login and register pages --- frontend/pages/login.vue | 166 ++++++++++++++++++----------------- frontend/pages/register.vue | 169 ++++++++++++++++++++---------------- 2 files changed, 180 insertions(+), 155 deletions(-) diff --git a/frontend/pages/login.vue b/frontend/pages/login.vue index 8a579e2..071e650 100644 --- a/frontend/pages/login.vue +++ b/frontend/pages/login.vue @@ -1,10 +1,11 @@ From eb473c8d6cbae415b8367d3779828acbaf984ace Mon Sep 17 00:00:00 2001 From: Anish Ghimire Date: Sun, 1 Mar 2026 02:04:33 +0545 Subject: [PATCH 06/65] feat: ai providers api + test --- backend/.env.example | 6 + backend/config/settings.py | 7 + backend/config/test_settings.py | 28 +++ backend/core/fields.py | 16 ++ ...12_basemodel_remove_appintegration_uuid.py | 28 +++ .../0013_aiprovider_delete_basemodel.py | 37 ++++ ...ey_aiprovider_provider_api_key_and_more.py | 30 +++ .../0015_alter_aiprovider_provider_api_key.py | 19 ++ backend/core/models/__init__.py | 4 +- backend/core/models/ai_provider.py | 15 ++ backend/core/models/app_integration.py | 2 - backend/core/models/base_model.py | 13 ++ backend/core/qdrant.py | 46 +++-- backend/core/serializers/__init__.py | 3 +- backend/core/serializers/ai_provider.py | 36 ++++ backend/core/tests/__init__.py | 0 backend/core/tests/conftest.py | 78 ++++++++ backend/core/tests/factories.py | 35 ++++ backend/core/tests/test_ai_provider.py | 176 ++++++++++++++++++ backend/core/tests/test_basic.py | 22 +++ backend/core/urls.py | 2 + backend/core/views/__init__.py | 3 +- backend/core/views/ai_provider.py | 41 ++++ backend/pytest.ini | 20 ++ backend/requirements.txt | 5 + 25 files changed, 647 insertions(+), 25 deletions(-) create mode 100644 backend/config/test_settings.py create mode 100644 backend/core/fields.py create mode 100644 backend/core/migrations/0012_basemodel_remove_appintegration_uuid.py create mode 100644 backend/core/migrations/0013_aiprovider_delete_basemodel.py create mode 100644 backend/core/migrations/0014_rename_key_aiprovider_provider_api_key_and_more.py create mode 100644 backend/core/migrations/0015_alter_aiprovider_provider_api_key.py create mode 100644 backend/core/models/ai_provider.py create mode 100644 backend/core/models/base_model.py create mode 100644 backend/core/serializers/ai_provider.py create mode 100644 backend/core/tests/__init__.py create mode 100644 backend/core/tests/conftest.py create mode 100644 backend/core/tests/factories.py create mode 100644 backend/core/tests/test_ai_provider.py create mode 100644 backend/core/tests/test_basic.py create mode 100644 backend/core/views/ai_provider.py create mode 100644 backend/pytest.ini diff --git a/backend/.env.example b/backend/.env.example index 630ce67..cdd4ad4 100644 --- a/backend/.env.example +++ b/backend/.env.example @@ -30,3 +30,9 @@ FRONTEND_URL=http://localhost:3000 CLOSED_ALPHA_SIGN_UPS='["test@example.com"]' REQUIRE_ACCOUNT_APPROVAL=False + +TEST_DB_NAME=ch8rtests +TEST_DB_USER=anish +TEST_DB_PASSWORD=Anish@1996 +TEST_DB_HOST=localhost +TEST_DB_PORT=5432 diff --git a/backend/config/settings.py b/backend/config/settings.py index 02fc472..70f9dcb 100644 --- a/backend/config/settings.py +++ b/backend/config/settings.py @@ -71,6 +71,13 @@ 'PASSWORD': os.getenv('PASSWORD'), 'HOST': os.getenv('DB_HOST'), 'PORT': os.getenv('PORT'), + 'TEST': { + 'NAME': os.getenv('TEST_DB_NAME', 'test_db'), + 'USER': os.getenv('TEST_DB_USER'), + 'PASSWORD': os.getenv('TEST_DB_PASSWORD'), + 'HOST': os.getenv('TEST_DB_HOST', 'localhost'), + 'PORT': os.getenv('TEST_DB_PORT', '5432'), + } } } diff --git a/backend/config/test_settings.py b/backend/config/test_settings.py new file mode 100644 index 0000000..08fffb6 --- /dev/null +++ b/backend/config/test_settings.py @@ -0,0 +1,28 @@ +import os +from .settings import * + +DEBUG = False + +CONNECT_TO_LOCAL_VECTOR_DB = False + +DATABASES = { + 'default': { + 'ENGINE': 'django.db.backends.postgresql', + 'NAME': os.getenv('TEST_DB_NAME', 'test_db'), + 'USER': os.getenv('TEST_DB_USER'), + 'PASSWORD': os.getenv('TEST_DB_PASSWORD'), + 'HOST': os.getenv('TEST_DB_HOST', 'localhost'), + 'PORT': os.getenv('TEST_DB_PORT', '5432'), + } +} + +EMAIL_BACKEND = 'django.core.mail.backends.locmem.LocMemBackend' + +CELERY_TASK_ALWAYS_EAGER = True +CELERY_TASK_EAGER_PROPAGATES = True + +CACHES = { + 'default': { + 'BACKEND': 'django.core.cache.backends.dummy.DummyCache', + } +} diff --git a/backend/core/fields.py b/backend/core/fields.py new file mode 100644 index 0000000..2e517ac --- /dev/null +++ b/backend/core/fields.py @@ -0,0 +1,16 @@ +from django.db import models +from cryptography.fernet import Fernet +from config import settings + +fernet = Fernet(settings.SECRET_ENCRYPTION_KEY) + +class EncryptedCharField(models.CharField): + def get_prep_value(self, value): + if value: + return fernet.encrypt(value.encode()).decode() + return value + + def from_db_value(self, value, expression, connection): + if value: + return fernet.decrypt(value.encode()).decode() + return value diff --git a/backend/core/migrations/0012_basemodel_remove_appintegration_uuid.py b/backend/core/migrations/0012_basemodel_remove_appintegration_uuid.py new file mode 100644 index 0000000..4bd8929 --- /dev/null +++ b/backend/core/migrations/0012_basemodel_remove_appintegration_uuid.py @@ -0,0 +1,28 @@ +# Generated by Django 5.2.3 on 2026-02-27 17:34 + +import uuid +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('core', '0011_alter_accountstatus_id'), + ] + + operations = [ + migrations.CreateModel( + name='BaseModel', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('uuid', models.UUIDField(default=uuid.uuid4, editable=False, unique=True)), + ('metadata', models.JSONField(blank=True, null=True)), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('updated_at', models.DateTimeField(auto_now=True)), + ], + ), + migrations.RemoveField( + model_name='appintegration', + name='uuid', + ), + ] diff --git a/backend/core/migrations/0013_aiprovider_delete_basemodel.py b/backend/core/migrations/0013_aiprovider_delete_basemodel.py new file mode 100644 index 0000000..848064d --- /dev/null +++ b/backend/core/migrations/0013_aiprovider_delete_basemodel.py @@ -0,0 +1,37 @@ +# Generated by Django 5.2.3 on 2026-02-27 17:39 + +import django.db.models.deletion +import uuid +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('core', '0012_basemodel_remove_appintegration_uuid'), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.CreateModel( + name='AIProvider', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('uuid', models.UUIDField(default=uuid.uuid4, editable=False, unique=True)), + ('metadata', models.JSONField(blank=True, null=True)), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('updated_at', models.DateTimeField(auto_now=True)), + ('name', models.CharField(blank=True, max_length=255, null=True)), + ('provider', models.CharField(max_length=255)), + ('key', models.CharField(max_length=1000)), + ('creator', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)), + ], + options={ + 'abstract': False, + }, + ), + migrations.DeleteModel( + name='BaseModel', + ), + ] diff --git a/backend/core/migrations/0014_rename_key_aiprovider_provider_api_key_and_more.py b/backend/core/migrations/0014_rename_key_aiprovider_provider_api_key_and_more.py new file mode 100644 index 0000000..290983b --- /dev/null +++ b/backend/core/migrations/0014_rename_key_aiprovider_provider_api_key_and_more.py @@ -0,0 +1,30 @@ +# Generated by Django 5.2.3 on 2026-02-27 17:48 + +import django.utils.timezone +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('core', '0013_aiprovider_delete_basemodel'), + ] + + operations = [ + migrations.RenameField( + model_name='aiprovider', + old_name='key', + new_name='provider_api_key', + ), + migrations.AddField( + model_name='aiprovider', + name='base_url', + field=models.CharField(default=django.utils.timezone.now, max_length=100), + preserve_default=False, + ), + migrations.AddField( + model_name='aiprovider', + name='is_builtin', + field=models.BooleanField(blank=True, default=False), + ), + ] diff --git a/backend/core/migrations/0015_alter_aiprovider_provider_api_key.py b/backend/core/migrations/0015_alter_aiprovider_provider_api_key.py new file mode 100644 index 0000000..4ee9047 --- /dev/null +++ b/backend/core/migrations/0015_alter_aiprovider_provider_api_key.py @@ -0,0 +1,19 @@ +# Generated by Django 5.2.3 on 2026-02-27 18:06 + +import core.fields +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('core', '0014_rename_key_aiprovider_provider_api_key_and_more'), + ] + + operations = [ + migrations.AlterField( + model_name='aiprovider', + name='provider_api_key', + field=core.fields.EncryptedCharField(max_length=1000), + ), + ] diff --git a/backend/core/models/__init__.py b/backend/core/models/__init__.py index 15bd2dc..1b81a66 100644 --- a/backend/core/models/__init__.py +++ b/backend/core/models/__init__.py @@ -1,3 +1,4 @@ +from .base_model import BaseModel from .application import Application from .chatroom import ChatRoom from .chatroom_participant import ChatroomParticipant @@ -12,4 +13,5 @@ from .llm_model import LLMModel from .app_model import AppModel from .app_integration import AppIntegration -from .account_status import AccountStatus \ No newline at end of file +from .account_status import AccountStatus +from .ai_provider import AIProvider diff --git a/backend/core/models/ai_provider.py b/backend/core/models/ai_provider.py new file mode 100644 index 0000000..3e7ebbe --- /dev/null +++ b/backend/core/models/ai_provider.py @@ -0,0 +1,15 @@ +from django.contrib.auth.models import User +from django.db import models + +from core.fields import EncryptedCharField + +from .base_model import BaseModel + +class AIProvider(BaseModel): + name = models.CharField(max_length=255, null=True, blank=True) + provider = models.CharField(max_length=255) + provider_api_key = EncryptedCharField(max_length=1000) + base_url = models.CharField(max_length=100) + is_builtin = models.BooleanField(default=False, blank=True) + + creator = models.ForeignKey(User, on_delete=models.CASCADE) diff --git a/backend/core/models/app_integration.py b/backend/core/models/app_integration.py index e09d4cb..f00cc60 100644 --- a/backend/core/models/app_integration.py +++ b/backend/core/models/app_integration.py @@ -1,10 +1,8 @@ from django.db import models from core.models import Application, Integration -import uuid class AppIntegration(models.Model): id = models.AutoField(primary_key=True) - uuid = models.UUIDField(default=uuid.uuid4, editable=False, unique=True) application = models.ForeignKey( Application, on_delete=models.CASCADE, related_name="app_integrations" ) diff --git a/backend/core/models/base_model.py b/backend/core/models/base_model.py new file mode 100644 index 0000000..f783951 --- /dev/null +++ b/backend/core/models/base_model.py @@ -0,0 +1,13 @@ +from django.db import models +import uuid + +class BaseModel(models.Model): + uuid = models.UUIDField(default=uuid.uuid4, editable=False, unique=True) + + metadata = models.JSONField(blank=True, null=True) + + created_at = models.DateTimeField(auto_now_add=True) + updated_at = models.DateTimeField(auto_now=True) + + class Meta: + abstract = True diff --git a/backend/core/qdrant.py b/backend/core/qdrant.py index f712dbc..3652daf 100644 --- a/backend/core/qdrant.py +++ b/backend/core/qdrant.py @@ -11,33 +11,39 @@ logger = logging.getLogger(__name__) load_dotenv() -connect_to_local = os.getenv("CONNECT_TO_LOCAL_VECTOR_DB", "false").lower() == "true" - -if connect_to_local: - print('Connecting to local vector db') - qdrant = QdrantClient( - host=os.getenv("QDRANT_LOCAL_HOST", "localhost"), - port=int(os.getenv("QDRANT_LOCAL_PORT", "6333")), - prefer_grpc=True, - ) +if os.getenv("DISABLE_VECTOR_DB", "false").lower() == "true": + qdrant = None else: - print('Connecting to remote vector db') + connect_to_local = os.getenv("CONNECT_TO_LOCAL_VECTOR_DB", "false").lower() == "true" + + if connect_to_local: + print('Connecting to local vector db') + qdrant = QdrantClient( + host=os.getenv("QDRANT_LOCAL_HOST", "localhost"), + port=int(os.getenv("QDRANT_LOCAL_PORT", "6333")), + prefer_grpc=True, + ) + else: + print('Connecting to remote vector db') - cloud_host = os.getenv("QDRANT_CLOUD_HOST") - cloud_port = os.getenv("QDRANT_CLOUD_PORT", "6333") - api_key = os.getenv("QDRANT_CLOUD_API_KEY") + cloud_host = os.getenv("QDRANT_CLOUD_HOST") + cloud_port = os.getenv("QDRANT_CLOUD_PORT", "6333") + api_key = os.getenv("QDRANT_CLOUD_API_KEY") - full_url = f"{cloud_host}:{cloud_port}" + full_url = f"{cloud_host}:{cloud_port}" - qdrant = QdrantClient( - url=full_url, - api_key=api_key, - prefer_grpc=False, - ) + qdrant = QdrantClient( + url=full_url, + api_key=api_key, + prefer_grpc=False, + ) COLLECTION_NAME = "advq" def init_qdrant(retries=3, delay=2): + if qdrant is None: + return + for attempt in range(retries): try: if not qdrant.collection_exists(COLLECTION_NAME): @@ -66,4 +72,4 @@ def ensure_payload_indexes(): field_schema=PayloadSchemaType.KEYWORD ) except Exception as e: - print(f"Payload index for '{field}' may already exist or failed:", e) \ No newline at end of file + print(f"Payload index for '{field}' may already exist or failed:", e) diff --git a/backend/core/serializers/__init__.py b/backend/core/serializers/__init__.py index 78469d8..0e36463 100644 --- a/backend/core/serializers/__init__.py +++ b/backend/core/serializers/__init__.py @@ -11,4 +11,5 @@ from .configure_app import LoadAppConfigurationSerializer, ConfigureAppIntegrationSerializer from .app_integration import AppIntegrationViewSerializer from .app_model import AppModelViewSerializer, ConfigureAppModelsSerializer -from .password import ForgotPasswordSerializer, ResetPasswordSerializer \ No newline at end of file +from .password import ForgotPasswordSerializer, ResetPasswordSerializer +from .ai_provider import AIProviderSerializer diff --git a/backend/core/serializers/ai_provider.py b/backend/core/serializers/ai_provider.py new file mode 100644 index 0000000..0669479 --- /dev/null +++ b/backend/core/serializers/ai_provider.py @@ -0,0 +1,36 @@ +from rest_framework import serializers +from core.models.ai_provider import AIProvider + +class AIProviderSerializer(serializers.ModelSerializer): + class Meta: + model = AIProvider + fields = ['id', 'uuid', 'name', 'provider', 'base_url', 'is_builtin', 'creator', 'created_at', 'updated_at'] + +class AIProviderCreateSerializer(serializers.ModelSerializer): + class Meta: + model = AIProvider + fields = ['name', 'provider', 'base_url', 'provider_api_key', 'creator'] + read_only_fields = ['creator'] + + def create(self, validated_data): + validated_data['creator'] = self.context['request'].user + return super().create(validated_data) + +class AIProviderUpdateSerializer(serializers.ModelSerializer): + provider_api_key = serializers.CharField(write_only=True, required=False, allow_blank=True) + + class Meta: + model = AIProvider + fields = ['name', 'base_url', 'provider_api_key'] + read_only_fields = ['creator'] + + def update(self, instance, validated_data): + api_key = validated_data.pop('provider_api_key', None) + if api_key: + instance.provider_api_key = api_key + + for attr, value in validated_data.items(): + setattr(instance, attr, value) + + instance.save() + return instance diff --git a/backend/core/tests/__init__.py b/backend/core/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/core/tests/conftest.py b/backend/core/tests/conftest.py new file mode 100644 index 0000000..32b387c --- /dev/null +++ b/backend/core/tests/conftest.py @@ -0,0 +1,78 @@ +import os +import django +from django.conf import settings +from pathlib import Path + +if not settings.configured: + os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.test_settings') + + base_dir = Path(__file__).resolve().parent.parent.parent + env_path = base_dir / '.env' + if env_path.exists(): + from dotenv import load_dotenv + load_dotenv(env_path) + + django.setup() + +import pytest +from django.test import TestCase +from rest_framework.test import APITestCase, APIClient +from factory.django import DjangoModelFactory +import factory + +def pytest_configure(config): + config.addinivalue_line("markers", "unit: Unit tests") + config.addinivalue_line("markers", "integration: Integration tests") + config.addinivalue_line("markers", "api: API endpoint tests") + config.addinivalue_line("markers", "slow: Slow running tests") + + +@pytest.fixture(autouse=True) +def enable_db_access_for_all_tests(db): + pass + + +@pytest.fixture +def api_client(): + return APIClient() + + +@pytest.fixture +def authenticated_client(api_client, user_factory): + user = user_factory() + api_client.force_authenticate(user=user) + return api_client, user + + +class BaseTestCase(TestCase): + def setUp(self): + super().setUp() + + def tearDown(self): + super().tearDown() + + +class BaseAPITestCase(APITestCase): + def setUp(self): + super().setUp() + self.client = APIClient() + + def tearDown(self): + super().tearDown() + + +class BaseServiceTestCase(TestCase): + def setUp(self): + super().setUp() + + def tearDown(self): + super().tearDown() + + +class BaseFactory(DjangoModelFactory): + class Meta: + abstract = True + + @classmethod + def _create(cls, model_class, *args, **kwargs): + return super()._create(model_class, *args, **kwargs) diff --git a/backend/core/tests/factories.py b/backend/core/tests/factories.py new file mode 100644 index 0000000..e879d0e --- /dev/null +++ b/backend/core/tests/factories.py @@ -0,0 +1,35 @@ +import factory +from django.contrib.auth.models import User +from core.models import ( + Application, + AIProvider +) + +class UserFactory(factory.django.DjangoModelFactory): + class Meta: + model = User + + username = factory.Sequence(lambda n: f'testuser{n}') + email = factory.Sequence(lambda n: f'testuser{n}@example.com') + first_name = factory.Faker('first_name') + last_name = factory.Faker('last_name') + is_active = True + +class ApplicationFactory(factory.django.DjangoModelFactory): + class Meta: + model = Application + + owner = factory.SubFactory(UserFactory) + name = factory.Faker('company') + uuid = factory.Faker('uuid4') + +class AIProviderFactory(factory.django.DjangoModelFactory): + class Meta: + model = AIProvider + + name = factory.Faker('company') + provider = factory.Iterator(['openai', 'anthropic', 'google']) + provider_api_key = factory.Faker('password') + base_url = factory.Faker('url') + is_builtin = False + creator = factory.SubFactory(UserFactory) diff --git a/backend/core/tests/test_ai_provider.py b/backend/core/tests/test_ai_provider.py new file mode 100644 index 0000000..e44f645 --- /dev/null +++ b/backend/core/tests/test_ai_provider.py @@ -0,0 +1,176 @@ +import pytest +from rest_framework import status +from rest_framework.test import APIClient + +from core.models import AIProvider +from core.tests.conftest import BaseAPITestCase +from core.tests.factories import UserFactory, AIProviderFactory + + +@pytest.mark.api +class TestAIProviderAPI(BaseAPITestCase): + """Test suite for AI Provider API endpoints.""" + + def setUp(self): + super().setUp() + self.list_url = '/api/ai-providers/' + + def test_list_ai_providers_authenticated_user(self): + """Test that authenticated users can list their AI providers.""" + user = UserFactory() + self.client.force_authenticate(user=user) + + provider1 = AIProviderFactory(creator=user, name="OpenAI Provider") + provider2 = AIProviderFactory(creator=user, name="Anthropic Provider") + + other_user = UserFactory() + other_provider = AIProviderFactory(creator=other_user, name="Other User Provider") + + response = self.client.get(self.list_url) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + data = response.json() + self.assertEqual(len(data), 2) + + provider_names = [provider['name'] for provider in data] + self.assertIn("OpenAI Provider", provider_names) + self.assertIn("Anthropic Provider", provider_names) + self.assertNotIn("Other User Provider", provider_names) + + provider_data = data[0] + expected_fields = ['id', 'uuid', 'name', 'provider', 'base_url', 'is_builtin', 'creator', 'created_at', 'updated_at'] + for field in expected_fields: + self.assertIn(field, provider_data) + + def test_list_ai_providers_unauthenticated(self): + """Test that unauthenticated users cannot list AI providers.""" + response = self.client.get(self.list_url) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_retrieve_other_users_provider(self): + """Test that user A cannot retrieve user B's AI provider by ID.""" + user_a = UserFactory() + user_b = UserFactory() + + provider_b = AIProviderFactory(creator=user_b, name="User B's Provider") + + self.client.force_authenticate(user=user_a) + + detail_url = f'/api/ai-providers/{provider_b.id}/' + response = self.client.get(detail_url) + + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + def test_update_other_users_provider(self): + """Test that user A cannot update user B's AI provider.""" + user_a = UserFactory() + user_b = UserFactory() + + provider_b = AIProviderFactory(creator=user_b, name="User B's Provider") + + self.client.force_authenticate(user=user_a) + + detail_url = f'/api/ai-providers/{provider_b.id}/' + update_data = {'name': 'Hacked Name'} + response = self.client.patch(detail_url, update_data, format='json') + + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + def test_delete_other_users_provider(self): + """Test that user A cannot delete user B's AI provider.""" + user_a = UserFactory() + user_b = UserFactory() + + provider_b = AIProviderFactory(creator=user_b, name="User B's Provider") + + self.client.force_authenticate(user=user_a) + + detail_url = f'/api/ai-providers/{provider_b.id}/' + response = self.client.delete(detail_url) + + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + def test_create_ai_provider(self): + """Test that authenticated user can create their own AI provider.""" + user = UserFactory() + self.client.force_authenticate(user=user) + + create_data = { + 'name': 'My OpenAI Provider', + 'provider': 'openai', + 'base_url': 'https://api.openai.com/v1', + 'provider_api_key': 'sk-test123456789' + } + + response = self.client.post(self.list_url, create_data, format='json') + + print(response.data) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + data = response.json() + + self.assertEqual(data['name'], 'My OpenAI Provider') + self.assertEqual(data['provider'], 'openai') + self.assertEqual(data['base_url'], 'https://api.openai.com/v1') + self.assertEqual(data['creator'], user.id) + + provider = AIProvider.objects.get(id=data['id']) + self.assertEqual(provider.creator, user) + self.assertEqual(provider.name, 'My OpenAI Provider') + + def test_update_own_provider(self): + """Test that authenticated user can update their own AI provider.""" + user = UserFactory() + self.client.force_authenticate(user=user) + + provider = AIProviderFactory(creator=user, name="Original Name") + + detail_url = f'/api/ai-providers/{provider.id}/' + update_data = {'name': 'Updated Name'} + response = self.client.patch(detail_url, update_data, format='json') + + self.assertEqual(response.status_code, status.HTTP_200_OK) + data = response.json() + + print(response) + self.assertEqual(data['name'], 'Updated Name') + self.assertEqual(data['provider'], provider.provider) + self.assertEqual(data['creator'], user.id) + + provider.refresh_from_db() + self.assertEqual(provider.name, 'Updated Name') + + def test_cannot_update_provider_field(self): + """Test that authenticated user cannot update the provider field of their AI provider.""" + user = UserFactory() + self.client.force_authenticate(user=user) + + provider = AIProviderFactory(creator=user, provider='openai', name="Original Name") + + detail_url = f'/api/ai-providers/{provider.id}/' + update_data = { + 'name': 'Updated Name', + 'provider': 'anthropic' + } + response = self.client.patch(detail_url, update_data, format='json') + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + provider.refresh_from_db() + self.assertEqual(provider.name, 'Updated Name') + self.assertEqual(provider.provider, 'openai') + + def test_delete_own_provider(self): + """Test that authenticated user can delete their own AI provider.""" + user = UserFactory() + self.client.force_authenticate(user=user) + + provider = AIProviderFactory(creator=user, name="Provider to Delete") + + detail_url = f'/api/ai-providers/{provider.id}/' + response = self.client.delete(detail_url) + + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + + with self.assertRaises(AIProvider.DoesNotExist): + AIProvider.objects.get(id=provider.id) diff --git a/backend/core/tests/test_basic.py b/backend/core/tests/test_basic.py new file mode 100644 index 0000000..cd59557 --- /dev/null +++ b/backend/core/tests/test_basic.py @@ -0,0 +1,22 @@ +import pytest + + +class TestBasic: + @pytest.mark.unit + def test_math(self): + assert 2 + 2 == 4 + assert 1 + 1 == 2 + assert 3 * 2 == 6 + + @pytest.mark.unit + def test_strings(self): + assert "hello" + "world" == "helloworld" + assert len("test") == 4 + assert "test".upper() == "TEST" + + @pytest.mark.unit + def test_lists(self): + my_list = [1, 2, 3, 4] + assert len(my_list) == 4 + assert my_list[0] == 1 + assert 5 not in my_list diff --git a/backend/core/urls.py b/backend/core/urls.py index effd1ab..334a000 100644 --- a/backend/core/urls.py +++ b/backend/core/urls.py @@ -22,9 +22,11 @@ from core.views.app_model import ConfigureAppModelsView from core.views.forgot_password import ForgotPasswordView from core.views.reset_password import ResetPasswordView, ResetPasswordVerifyView +from core.views.ai_provider import AIProviderViewSet router = DefaultRouter() router.register(r'applications', ApplicationViewSet, basename='applications') +router.register(r'ai-providers', AIProviderViewSet, basename='ai-provider') router.register(r'notification-profiles', NotificationProfileViewSet, basename='notificationprofile') router.register(r'models', LLMModelViewSet, basename='model'), router.register(r'integrations', IntegrationViewSet, basename='integration'), diff --git a/backend/core/views/__init__.py b/backend/core/views/__init__.py index 1cbdc71..b62b15d 100644 --- a/backend/core/views/__init__.py +++ b/backend/core/views/__init__.py @@ -14,4 +14,5 @@ from .custom_auth import CustomAuthToken from .app_model import ConfigureAppModelsView from .reset_password import ResetPasswordView, ResetPasswordVerifyView -from .forgot_password import ForgotPasswordView \ No newline at end of file +from .forgot_password import ForgotPasswordView +from .ai_provider import AIProviderViewSet diff --git a/backend/core/views/ai_provider.py b/backend/core/views/ai_provider.py new file mode 100644 index 0000000..4e9403a --- /dev/null +++ b/backend/core/views/ai_provider.py @@ -0,0 +1,41 @@ +from rest_framework import status, viewsets, permissions +from rest_framework.response import Response +from core.serializers.ai_provider import AIProviderCreateSerializer, AIProviderSerializer, AIProviderUpdateSerializer +from core.models import AIProvider + +class AIProviderViewSet(viewsets.ModelViewSet): + permission_classes = [permissions.IsAuthenticated] + http_method_names = ['get', 'post', 'put','patch', 'delete'] + + queryset = AIProvider.objects.all() + + def get_serializer_class(self): + if self.action == 'create': + return AIProviderCreateSerializer + elif self.action in ['update', 'partial_update']: + return AIProviderUpdateSerializer + return AIProviderSerializer + + def get_queryset(self): + user = self.request.user + return AIProvider.objects.filter(creator=user) + + def perform_create(self, serializer): + serializer.save(creator=self.request.user) + + def create(self, request, *args, **kwargs): + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + serializer.save() + + read_serializer = AIProviderSerializer(serializer.instance, context={'request': request}) + return Response(read_serializer.data, status=status.HTTP_201_CREATED) + + def update(self, request, *args, **kwargs): + instance = self.get_object() + serializer = self.get_serializer(instance, data=request.data, partial=True) + serializer.is_valid(raise_exception=True) + serializer.save() + + read_serializer = AIProviderSerializer(instance, context={'request': request}) + return Response(read_serializer.data, status=status.HTTP_200_OK) diff --git a/backend/pytest.ini b/backend/pytest.ini new file mode 100644 index 0000000..a211eea --- /dev/null +++ b/backend/pytest.ini @@ -0,0 +1,20 @@ +[tool:pytest] +DJANGO_SETTINGS_MODULE = config.test_settings +python_files = tests.py test_*.py *_tests.py +python_classes = Test* +python_functions = test_* +addopts = + --strict-markers + --strict-config + --disable-warnings + --tb=short + --cov=core + --cov-report=html + --cov-report=term-missing + --cov-fail-under=80 +testpaths = core/tests +markers = + unit: Unit tests + integration: Integration tests + api: API endpoint tests + slow: Slow running tests diff --git a/backend/requirements.txt b/backend/requirements.txt index 629b592..7ef7856 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -144,3 +144,8 @@ wcwidth==0.2.13 websockets==15.0.1 zope.interface==7.2 zstandard==0.23.0 +pytest==8.3.4 +pytest-django==4.9.0 +pytest-cov==6.0.0 +factory-boy==3.3.1 +django-factory-boy==1.0.0 From 722267d1ed883c5e1b1faf4823b81e9ab657355a Mon Sep 17 00:00:00 2001 From: Anish Ghimire Date: Sun, 1 Mar 2026 23:03:31 +0545 Subject: [PATCH 07/65] improve: refactor ai providers logic + add tests --- backend/core/serializers/ai_provider.py | 19 +++-- backend/core/tests/test_ai_provider.py | 101 ++++++++++++++++++++++++ backend/core/views/ai_provider.py | 6 +- 3 files changed, 112 insertions(+), 14 deletions(-) diff --git a/backend/core/serializers/ai_provider.py b/backend/core/serializers/ai_provider.py index 0669479..01aa563 100644 --- a/backend/core/serializers/ai_provider.py +++ b/backend/core/serializers/ai_provider.py @@ -4,7 +4,7 @@ class AIProviderSerializer(serializers.ModelSerializer): class Meta: model = AIProvider - fields = ['id', 'uuid', 'name', 'provider', 'base_url', 'is_builtin', 'creator', 'created_at', 'updated_at'] + exclude = ['provider_api_key'] class AIProviderCreateSerializer(serializers.ModelSerializer): class Meta: @@ -12,21 +12,20 @@ class Meta: fields = ['name', 'provider', 'base_url', 'provider_api_key', 'creator'] read_only_fields = ['creator'] + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if self.instance is not None: + self.fields['provider_api_key'].required = False + self.fields['provider_api_key'].allow_blank = True + self.fields.pop('provider', None) + def create(self, validated_data): validated_data['creator'] = self.context['request'].user return super().create(validated_data) -class AIProviderUpdateSerializer(serializers.ModelSerializer): - provider_api_key = serializers.CharField(write_only=True, required=False, allow_blank=True) - - class Meta: - model = AIProvider - fields = ['name', 'base_url', 'provider_api_key'] - read_only_fields = ['creator'] - def update(self, instance, validated_data): api_key = validated_data.pop('provider_api_key', None) - if api_key: + if api_key and isinstance(api_key, str) and api_key.strip(): instance.provider_api_key = api_key for attr, value in validated_data.items(): diff --git a/backend/core/tests/test_ai_provider.py b/backend/core/tests/test_ai_provider.py index e44f645..d4f79ee 100644 --- a/backend/core/tests/test_ai_provider.py +++ b/backend/core/tests/test_ai_provider.py @@ -174,3 +174,104 @@ def test_delete_own_provider(self): with self.assertRaises(AIProvider.DoesNotExist): AIProvider.objects.get(id=provider.id) + + def test_update_without_api_key_does_not_change_api_key(self): + """Test that update request without specifying provider api key does not update provider api key.""" + user = UserFactory() + self.client.force_authenticate(user=user) + + provider = AIProviderFactory(creator=user, name="Test Provider") + original_api_key = provider.provider_api_key + + detail_url = f'/api/ai-providers/{provider.id}/' + update_data = {'name': 'Updated Name'} + response = self.client.patch(detail_url, update_data, format='json') + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + provider.refresh_from_db() + self.assertEqual(provider.name, 'Updated Name') + self.assertEqual(provider.provider_api_key, original_api_key) + + def test_update_with_api_key_changes_api_key(self): + """Test that update request with provider api key updates the provider api key.""" + user = UserFactory() + self.client.force_authenticate(user=user) + + provider = AIProviderFactory(creator=user, name="Test Provider") + original_api_key = provider.provider_api_key + new_api_key = 'new-api-key-12345' + + detail_url = f'/api/ai-providers/{provider.id}/' + update_data = {'name': 'Updated Name', 'provider_api_key': new_api_key} + response = self.client.patch(detail_url, update_data, format='json') + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + provider.refresh_from_db() + self.assertEqual(provider.name, 'Updated Name') + self.assertEqual(provider.provider_api_key, new_api_key) + self.assertNotEqual(provider.provider_api_key, original_api_key) + + def test_update_with_empty_api_key_does_not_change_api_key(self): + """Test that update request with empty string provider api key does not update the provider api key.""" + user = UserFactory() + self.client.force_authenticate(user=user) + + provider = AIProviderFactory(creator=user, name="Test Provider") + original_api_key = provider.provider_api_key + + detail_url = f'/api/ai-providers/{provider.id}/' + update_data = {'name': 'Updated Name', 'provider_api_key': ''} + response = self.client.patch(detail_url, update_data, format='json') + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + provider.refresh_from_db() + self.assertEqual(provider.name, 'Updated Name') + self.assertEqual(provider.provider_api_key, original_api_key) + + def test_update_with_whitespace_api_key_does_not_change_api_key(self): + """Test that update request with whitespace-only provider api key does not update the provider api key.""" + user = UserFactory() + self.client.force_authenticate(user=user) + + provider = AIProviderFactory(creator=user, name="Test Provider") + original_api_key = provider.provider_api_key + + detail_url = f'/api/ai-providers/{provider.id}/' + update_data = {'name': 'Updated Name', 'provider_api_key': ' '} + response = self.client.patch(detail_url, update_data, format='json') + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + provider.refresh_from_db() + self.assertEqual(provider.name, 'Updated Name') + self.assertEqual(provider.provider_api_key, original_api_key) + + def test_api_key_is_encrypted_in_database(self): + """Test that provider_api_key is encrypted when stored in the database.""" + user = UserFactory() + self.client.force_authenticate(user=user) + + api_key = 'test-api-key-12345' + create_data = { + 'name': 'Test Provider', + 'provider': 'openai', + 'base_url': 'https://api.openai.com/v1', + 'provider_api_key': api_key + } + + response = self.client.post(self.list_url, create_data, format='json') + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + + provider = AIProvider.objects.get(id=response.json()['id']) + + self.assertEqual(provider.provider_api_key, api_key) + + from django.db import connection + with connection.cursor() as cursor: + cursor.execute("SELECT provider_api_key FROM core_aiprovider WHERE id = %s", [provider.id]) + raw_db_value = cursor.fetchone()[0] + self.assertNotEqual(raw_db_value, api_key) + self.assertTrue(raw_db_value.startswith('gAAAAA')) # Fernet encrypted strings start with this diff --git a/backend/core/views/ai_provider.py b/backend/core/views/ai_provider.py index 4e9403a..0d2d2a5 100644 --- a/backend/core/views/ai_provider.py +++ b/backend/core/views/ai_provider.py @@ -1,6 +1,6 @@ from rest_framework import status, viewsets, permissions from rest_framework.response import Response -from core.serializers.ai_provider import AIProviderCreateSerializer, AIProviderSerializer, AIProviderUpdateSerializer +from core.serializers.ai_provider import AIProviderCreateSerializer, AIProviderSerializer from core.models import AIProvider class AIProviderViewSet(viewsets.ModelViewSet): @@ -10,10 +10,8 @@ class AIProviderViewSet(viewsets.ModelViewSet): queryset = AIProvider.objects.all() def get_serializer_class(self): - if self.action == 'create': + if self.action in ['create', 'update', 'partial_update']: return AIProviderCreateSerializer - elif self.action in ['update', 'partial_update']: - return AIProviderUpdateSerializer return AIProviderSerializer def get_queryset(self): From 80b651addffa1344f1dcb68360ab8a81cae51a40 Mon Sep 17 00:00:00 2001 From: Anish Ghimire Date: Sun, 1 Mar 2026 23:25:59 +0545 Subject: [PATCH 08/65] improve: refactor ai providers logic + add tests --- backend/core/serializers/ai_provider.py | 3 +++ backend/core/views/ai_provider.py | 19 +------------------ 2 files changed, 4 insertions(+), 18 deletions(-) diff --git a/backend/core/serializers/ai_provider.py b/backend/core/serializers/ai_provider.py index 01aa563..b48b110 100644 --- a/backend/core/serializers/ai_provider.py +++ b/backend/core/serializers/ai_provider.py @@ -19,6 +19,9 @@ def __init__(self, *args, **kwargs): self.fields['provider_api_key'].allow_blank = True self.fields.pop('provider', None) + def to_representation(self, instance): + return AIProviderSerializer(instance, context=self.context).data + def create(self, validated_data): validated_data['creator'] = self.context['request'].user return super().create(validated_data) diff --git a/backend/core/views/ai_provider.py b/backend/core/views/ai_provider.py index 0d2d2a5..04fa943 100644 --- a/backend/core/views/ai_provider.py +++ b/backend/core/views/ai_provider.py @@ -19,21 +19,4 @@ def get_queryset(self): return AIProvider.objects.filter(creator=user) def perform_create(self, serializer): - serializer.save(creator=self.request.user) - - def create(self, request, *args, **kwargs): - serializer = self.get_serializer(data=request.data) - serializer.is_valid(raise_exception=True) - serializer.save() - - read_serializer = AIProviderSerializer(serializer.instance, context={'request': request}) - return Response(read_serializer.data, status=status.HTTP_201_CREATED) - - def update(self, request, *args, **kwargs): - instance = self.get_object() - serializer = self.get_serializer(instance, data=request.data, partial=True) - serializer.is_valid(raise_exception=True) - serializer.save() - - read_serializer = AIProviderSerializer(instance, context={'request': request}) - return Response(read_serializer.data, status=status.HTTP_200_OK) + serializer.save(creator=self.request.user) \ No newline at end of file From 5bbdbcba4c1828355e29fc36290820d334d14bcf Mon Sep 17 00:00:00 2001 From: Anish Ghimire Date: Mon, 2 Mar 2026 00:02:03 +0545 Subject: [PATCH 09/65] feat: list user owned and builtin ai providers --- backend/core/tests/test_ai_provider.py | 38 ++++++++++++++++++++++++++ backend/core/views/ai_provider.py | 5 +++- 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/backend/core/tests/test_ai_provider.py b/backend/core/tests/test_ai_provider.py index d4f79ee..5976bee 100644 --- a/backend/core/tests/test_ai_provider.py +++ b/backend/core/tests/test_ai_provider.py @@ -43,6 +43,44 @@ def test_list_ai_providers_authenticated_user(self): for field in expected_fields: self.assertIn(field, provider_data) + def test_list_ai_providers_includes_builtin_and_user_owned(self): + """Test that authenticated users can list their AI providers plus builtin providers.""" + user = UserFactory() + self.client.force_authenticate(user=user) + + user_provider1 = AIProviderFactory(creator=user, name="User OpenAI Provider", is_builtin=False) + user_provider2 = AIProviderFactory(creator=user, name="User Anthropic Provider", is_builtin=False) + + builtin_provider1 = AIProviderFactory(creator=user, name="Builtin OpenAI", is_builtin=True) + builtin_provider2 = AIProviderFactory(creator=user, name="Builtin Claude", is_builtin=True) + + other_user = UserFactory() + other_user_provider = AIProviderFactory(creator=other_user, name="Other User Provider", is_builtin=False) + + response = self.client.get(self.list_url) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + data = response.json() + self.assertEqual(len(data), 4) + + provider_names = [provider['name'] for provider in data] + self.assertIn("User OpenAI Provider", provider_names) + self.assertIn("User Anthropic Provider", provider_names) + self.assertIn("Builtin OpenAI", provider_names) + self.assertIn("Builtin Claude", provider_names) + self.assertNotIn("Other User Provider", provider_names) + + builtin_providers = [p for p in data if p['is_builtin']] + user_owned_providers = [p for p in data if not p['is_builtin']] + + self.assertEqual(len(builtin_providers), 2) + self.assertEqual(len(user_owned_providers), 2) + + builtin_names = [p['name'] for p in builtin_providers] + self.assertIn("Builtin OpenAI", builtin_names) + self.assertIn("Builtin Claude", builtin_names) + def test_list_ai_providers_unauthenticated(self): """Test that unauthenticated users cannot list AI providers.""" response = self.client.get(self.list_url) diff --git a/backend/core/views/ai_provider.py b/backend/core/views/ai_provider.py index 04fa943..6939cde 100644 --- a/backend/core/views/ai_provider.py +++ b/backend/core/views/ai_provider.py @@ -1,5 +1,6 @@ from rest_framework import status, viewsets, permissions from rest_framework.response import Response +from django.db import models from core.serializers.ai_provider import AIProviderCreateSerializer, AIProviderSerializer from core.models import AIProvider @@ -16,7 +17,9 @@ def get_serializer_class(self): def get_queryset(self): user = self.request.user - return AIProvider.objects.filter(creator=user) + return AIProvider.objects.filter( + models.Q(creator=user) | models.Q(is_builtin=True) + ) def perform_create(self, serializer): serializer.save(creator=self.request.user) \ No newline at end of file From e344ae1a93d4fc83169f1bd4f56b96d4a0a9ecd0 Mon Sep 17 00:00:00 2001 From: Anish Ghimire Date: Mon, 2 Mar 2026 01:43:50 +0545 Subject: [PATCH 10/65] feat: api for app ai provider configuration --- backend/core/migrations/0016_appaiprovider.py | 34 +++ backend/core/models/__init__.py | 1 + backend/core/models/app_ai_provider.py | 49 ++++ backend/core/serializers/__init__.py | 1 + backend/core/serializers/app_ai_provider.py | 62 +++++ backend/core/tests/test_app_ai_provider.py | 214 ++++++++++++++++++ backend/core/urls.py | 2 + backend/core/views/app_ai_provider.py | 55 +++++ 8 files changed, 418 insertions(+) create mode 100644 backend/core/migrations/0016_appaiprovider.py create mode 100644 backend/core/models/app_ai_provider.py create mode 100644 backend/core/serializers/app_ai_provider.py create mode 100644 backend/core/tests/test_app_ai_provider.py create mode 100644 backend/core/views/app_ai_provider.py diff --git a/backend/core/migrations/0016_appaiprovider.py b/backend/core/migrations/0016_appaiprovider.py new file mode 100644 index 0000000..766fc49 --- /dev/null +++ b/backend/core/migrations/0016_appaiprovider.py @@ -0,0 +1,34 @@ +# Generated by Django 5.2.3 on 2026-03-01 18:41 + +import django.db.models.deletion +import uuid +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('core', '0015_alter_aiprovider_provider_api_key'), + ] + + operations = [ + migrations.CreateModel( + name='AppAIProvider', + fields=[ + ('id', models.AutoField(primary_key=True, serialize=False)), + ('uuid', models.UUIDField(default=uuid.uuid4, editable=False, unique=True)), + ('context', models.CharField(default='default', max_length=50)), + ('capability', models.CharField(default='text', max_length=50)), + ('priority', models.PositiveIntegerField(default=100)), + ('external_model_id', models.CharField(blank=True, max_length=255, null=True)), + ('is_active', models.BooleanField(default=True)), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('updated_at', models.DateTimeField(auto_now=True)), + ('ai_provider', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='application_configs', to='core.aiprovider')), + ('application', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='ai_provider_configs', to='core.application')), + ], + options={ + 'ordering': ['context', 'capability', 'priority'], + }, + ), + ] diff --git a/backend/core/models/__init__.py b/backend/core/models/__init__.py index 1b81a66..e1bc3f5 100644 --- a/backend/core/models/__init__.py +++ b/backend/core/models/__init__.py @@ -15,3 +15,4 @@ from .app_integration import AppIntegration from .account_status import AccountStatus from .ai_provider import AIProvider +from .app_ai_provider import AppAIProvider diff --git a/backend/core/models/app_ai_provider.py b/backend/core/models/app_ai_provider.py new file mode 100644 index 0000000..8c752a1 --- /dev/null +++ b/backend/core/models/app_ai_provider.py @@ -0,0 +1,49 @@ +import uuid +from django.db import models + + +class AppAIProvider(models.Model): + id = models.AutoField(primary_key=True) + uuid = models.UUIDField(default=uuid.uuid4, editable=False, unique=True) + + application = models.ForeignKey( + "Application", + on_delete=models.CASCADE, + related_name="ai_provider_configs" + ) + ai_provider = models.ForeignKey( + "AIProvider", + on_delete=models.CASCADE, + related_name="application_configs" + ) + + context = models.CharField(max_length=50) + capability = models.CharField(max_length=50, default='text') + priority = models.PositiveIntegerField(default=100) + external_model_id = models.CharField(max_length=255, blank=True, null=True) + + is_active = models.BooleanField(default=True) + + created_at = models.DateTimeField(auto_now_add=True) + updated_at = models.DateTimeField(auto_now=True) + + class Meta: + ordering = ['context', 'capability', 'priority'] + + def __str__(self): + return f"{self.application.name} - {self.ai_provider.name} ({self.context}:{self.capability})" + + def save(self, *args, **kwargs): + if not self.priority or self.priority == 100: + existing_configs = AppAIProvider.objects.filter( + application=self.application, + context=self.context, + capability=self.capability + ).exclude(id=self.id).order_by('-priority') + + if existing_configs.exists(): + self.priority = existing_configs.first().priority + 100 + else: + self.priority = 100 + + super().save(*args, **kwargs) diff --git a/backend/core/serializers/__init__.py b/backend/core/serializers/__init__.py index 0e36463..7bfabaa 100644 --- a/backend/core/serializers/__init__.py +++ b/backend/core/serializers/__init__.py @@ -13,3 +13,4 @@ from .app_model import AppModelViewSerializer, ConfigureAppModelsSerializer from .password import ForgotPasswordSerializer, ResetPasswordSerializer from .ai_provider import AIProviderSerializer +from .app_ai_provider import AppAIProviderSerializer, AppAIProviderCreateSerializer, AppAIProviderUpdateSerializer \ No newline at end of file diff --git a/backend/core/serializers/app_ai_provider.py b/backend/core/serializers/app_ai_provider.py new file mode 100644 index 0000000..5110617 --- /dev/null +++ b/backend/core/serializers/app_ai_provider.py @@ -0,0 +1,62 @@ +from rest_framework import serializers +from core.models.app_ai_provider import AppAIProvider +from core.models.ai_provider import AIProvider +from .ai_provider import AIProviderSerializer + +class AppAIProviderSerializer(serializers.ModelSerializer): + ai_provider = AIProviderSerializer(read_only=True) + + class Meta: + model = AppAIProvider + fields = [ + 'id', 'uuid', 'ai_provider', 'context', 'capability', + 'priority', 'external_model_id', 'is_active', + 'created_at', 'updated_at' + ] + read_only_fields = ['id', 'uuid', 'priority', 'created_at', 'updated_at'] + + +class AppAIProviderCreateSerializer(serializers.ModelSerializer): + ai_provider_id = serializers.IntegerField(write_only=True) + context = serializers.CharField() + capability = serializers.CharField() + external_model_id = serializers.CharField(required=False, allow_blank=True) + + class Meta: + model = AppAIProvider + fields = ['ai_provider_id', 'context', 'capability', 'external_model_id'] + + def to_representation(self, instance): + return AppAIProviderSerializer(instance, context=self.context).data + + def validate_ai_provider_id(self, value): + try: + ai_provider = AIProvider.objects.get(id=value) + if ai_provider.creator != self.context['request'].user: + raise serializers.ValidationError("You don't own this AI provider") + return value + except AIProvider.DoesNotExist: + raise serializers.ValidationError("AI provider not found") + + def validate(self, data): + return data + + def create(self, validated_data): + ai_provider_id = validated_data.pop('ai_provider_id') + ai_provider = AIProvider.objects.get(id=ai_provider_id) + application = self.context['application'] + + return AppAIProvider.objects.create( + application=application, + ai_provider=ai_provider, + **validated_data + ) + + +class AppAIProviderUpdateSerializer(serializers.ModelSerializer): + class Meta: + model = AppAIProvider + fields = ['external_model_id'] + + def to_representation(self, instance): + return AppAIProviderSerializer(instance, context=self.context).data diff --git a/backend/core/tests/test_app_ai_provider.py b/backend/core/tests/test_app_ai_provider.py new file mode 100644 index 0000000..ebd705d --- /dev/null +++ b/backend/core/tests/test_app_ai_provider.py @@ -0,0 +1,214 @@ +import pytest +from django.urls import reverse +from rest_framework import status +from rest_framework.test import APITestCase +from django.contrib.auth.models import User + +from core.models import Application, AIProvider, AppAIProvider + + +class AppAIProviderTest(APITestCase): + def setUp(self): + self.user = User.objects.create_user(username='test', password='test') + self.application = Application.objects.create(owner=self.user, name='Test App') + self.ai_provider = AIProvider.objects.create( + name='Test Provider', + provider='openai', + base_url='https://api.openai.com', + provider_api_key='test-key', + creator=self.user + ) + + def test_create_app_ai_provider(self): + self.client.force_authenticate(user=self.user) + url = reverse('application-ai-providers-list', kwargs={'application_uuid': self.application.uuid}) + data = { + 'ai_provider_id': self.ai_provider.id, + 'context': 'widget', + 'capability': 'text', + 'external_model_id': 'gpt-4' + } + response = self.client.post(url, data) + print(response.data) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.data['context'], 'widget') + self.assertEqual(response.data['capability'], 'text') + self.assertEqual(response.data['priority'], 100) + self.assertEqual(response.data['external_model_id'], 'gpt-4') + self.assertEqual(response.data['ai_provider']['id'], self.ai_provider.id) + + def test_list_app_ai_providers(self): + config = AppAIProvider.objects.create( + application=self.application, + ai_provider=self.ai_provider, + context='widget', + capability='text', + external_model_id='gpt-4' + ) + + self.client.force_authenticate(user=self.user) + url = reverse('application-ai-providers-list', kwargs={'application_uuid': self.application.uuid}) + response = self.client.get(url) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data), 1) + self.assertEqual(response.data[0]['id'], config.id) + + def test_filter_by_context(self): + AppAIProvider.objects.create( + application=self.application, + ai_provider=self.ai_provider, + context='widget', + capability='text' + ) + AppAIProvider.objects.create( + application=self.application, + ai_provider=self.ai_provider, + context='dashboard', + capability='text' + ) + + self.client.force_authenticate(user=self.user) + url = reverse('application-ai-providers-list', kwargs={'application_uuid': self.application.uuid}) + response = self.client.get(url, {'context': 'widget'}) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data), 1) + self.assertEqual(response.data[0]['context'], 'widget') + + def test_filter_by_capability(self): + AppAIProvider.objects.create( + application=self.application, + ai_provider=self.ai_provider, + context='widget', + capability='text' + ) + AppAIProvider.objects.create( + application=self.application, + ai_provider=self.ai_provider, + context='widget', + capability='image' + ) + + self.client.force_authenticate(user=self.user) + url = reverse('application-ai-providers-list', kwargs={'application_uuid': self.application.uuid}) + response = self.client.get(url, {'capability': 'text'}) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data), 1) + self.assertEqual(response.data[0]['capability'], 'text') + + def test_update_app_ai_provider(self): + config = AppAIProvider.objects.create( + application=self.application, + ai_provider=self.ai_provider, + context='widget', + capability='text', + external_model_id='gpt-4' + ) + + self.client.force_authenticate(user=self.user) + url = reverse('application-ai-providers-detail', kwargs={ + 'application_uuid': self.application.uuid, + 'uuid': config.uuid + }) + data = {'external_model_id': 'gpt-3.5-turbo'} + response = self.client.put(url, data) + self.assertEqual(response.status_code, status.HTTP_200_OK) + config.refresh_from_db() + self.assertEqual(config.external_model_id, 'gpt-3.5-turbo') + + def test_delete_app_ai_provider(self): + config = AppAIProvider.objects.create( + application=self.application, + ai_provider=self.ai_provider + ) + + self.client.force_authenticate(user=self.user) + url = reverse('application-ai-providers-detail', kwargs={ + 'application_uuid': self.application.uuid, + 'uuid': config.uuid + }) + response = self.client.delete(url) + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + self.assertFalse(AppAIProvider.objects.filter(id=config.id).exists()) + + def test_priority_auto_assignment(self): + config1 = AppAIProvider.objects.create( + application=self.application, + ai_provider=self.ai_provider, + context='widget', + capability='text' + ) + self.assertEqual(config1.priority, 100) + + config2 = AppAIProvider.objects.create( + application=self.application, + ai_provider=self.ai_provider, + context='widget', + capability='text' + ) + self.assertEqual(config2.priority, 200) + + config3 = AppAIProvider.objects.create( + application=self.application, + ai_provider=self.ai_provider, + context='dashboard', + capability='text' + ) + self.assertEqual(config3.priority, 100) + + def test_unauthorized_access(self): + other_user = User.objects.create_user(username='other', password='other') + other_app = Application.objects.create(owner=other_user, name='Other App') + other_ai_provider = AIProvider.objects.create( + name='Other AI Provider', + provider='openai', + base_url='https://api.openai.com', + provider_api_key='test', + creator=other_user + ) + other_config = AppAIProvider.objects.create( + application=other_app, + ai_provider=other_ai_provider, + context='widget', + capability='text', + external_model_id='gpt-4' + ) + + self.client.force_authenticate(user=self.user) + + url = reverse('application-ai-providers-list', kwargs={'application_uuid': other_app.uuid}) + response = self.client.get(url) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + data = { + 'ai_provider_id': self.ai_provider.id, + 'context': 'widget', + 'capability': 'text', + 'external_model_id': 'gpt-4' + } + response = self.client.post(url, data) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + update_url = reverse('application-ai-providers-detail', kwargs={ + 'application_uuid': other_app.uuid, + 'uuid': other_config.uuid + }) + update_data = {'external_model_id': 'gpt-3.5-turbo'} + response = self.client.put(update_url, update_data) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + response = self.client.get(update_url) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + response = self.client.delete(update_url) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + def test_invalid_ai_provider(self): + self.client.force_authenticate(user=self.user) + url = reverse('application-ai-providers-list', kwargs={'application_uuid': self.application.uuid}) + data = { + 'ai_provider_id': 999, + 'context': 'widget', + 'capability': 'text' + } + response = self.client.post(url, data) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) diff --git a/backend/core/urls.py b/backend/core/urls.py index 334a000..a4fef32 100644 --- a/backend/core/urls.py +++ b/backend/core/urls.py @@ -23,6 +23,7 @@ from core.views.forgot_password import ForgotPasswordView from core.views.reset_password import ResetPasswordView, ResetPasswordVerifyView from core.views.ai_provider import AIProviderViewSet +from core.views.app_ai_provider import AppAIProviderViewSet router = DefaultRouter() router.register(r'applications', ApplicationViewSet, basename='applications') @@ -33,6 +34,7 @@ nested_router = NestedDefaultRouter(router, r'applications', lookup='application') nested_router.register(r'knowledge-bases', KnowledgeBaseViewSet, basename='application-knowledge-bases') +nested_router.register(r'ai-providers', AppAIProviderViewSet, basename='application-ai-providers') urlpatterns = [ path('login/', CustomAuthToken.as_view(), name='api_login'), diff --git a/backend/core/views/app_ai_provider.py b/backend/core/views/app_ai_provider.py new file mode 100644 index 0000000..3cd350e --- /dev/null +++ b/backend/core/views/app_ai_provider.py @@ -0,0 +1,55 @@ +from rest_framework import viewsets, status, permissions +from rest_framework.response import Response +from django.shortcuts import get_object_or_404 + +from core.models.app_ai_provider import AppAIProvider +from core.models.application import Application +from core.serializers.app_ai_provider import ( + AppAIProviderSerializer, + AppAIProviderCreateSerializer, + AppAIProviderUpdateSerializer +) + +class AppAIProviderViewSet(viewsets.ModelViewSet): + lookup_field = 'uuid' + permission_classes = [permissions.IsAuthenticated] + http_method_names = ['get', 'post', 'put', 'patch', 'delete'] + + def get_queryset(self): + application = get_object_or_404( + Application, + uuid=self.kwargs['application_uuid'], + owner=self.request.user + ) + + queryset = AppAIProvider.objects.filter(application=application) + + context = self.request.query_params.get('context') + capability = self.request.query_params.get('capability') + + if context: + queryset = queryset.filter(context=context) + if capability: + queryset = queryset.filter(capability=capability) + + return queryset + + def get_serializer_class(self): + if self.action == 'create': + return AppAIProviderCreateSerializer + elif self.action in ['update', 'partial_update']: + return AppAIProviderUpdateSerializer + return AppAIProviderSerializer + + def get_serializer_context(self): + context = super().get_serializer_context() + application = get_object_or_404( + Application, + uuid=self.kwargs['application_uuid'], + owner=self.request.user + ) + context['application'] = application + return context + + def perform_create(self, serializer): + serializer.save() From c126bd659569848d1f3f2f73a4aecf440b2a784a Mon Sep 17 00:00:00 2001 From: Anish Ghimire Date: Mon, 2 Mar 2026 01:50:16 +0545 Subject: [PATCH 11/65] refactor: use uuid lookup for ai provider apis and fix tests --- backend/core/tests/test_ai_provider.py | 20 ++++++++++---------- backend/core/views/ai_provider.py | 1 + 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/backend/core/tests/test_ai_provider.py b/backend/core/tests/test_ai_provider.py index 5976bee..ce072e4 100644 --- a/backend/core/tests/test_ai_provider.py +++ b/backend/core/tests/test_ai_provider.py @@ -95,7 +95,7 @@ def test_retrieve_other_users_provider(self): self.client.force_authenticate(user=user_a) - detail_url = f'/api/ai-providers/{provider_b.id}/' + detail_url = f'/api/ai-providers/{provider_b.uuid}/' response = self.client.get(detail_url) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) @@ -109,7 +109,7 @@ def test_update_other_users_provider(self): self.client.force_authenticate(user=user_a) - detail_url = f'/api/ai-providers/{provider_b.id}/' + detail_url = f'/api/ai-providers/{provider_b.uuid}/' update_data = {'name': 'Hacked Name'} response = self.client.patch(detail_url, update_data, format='json') @@ -124,7 +124,7 @@ def test_delete_other_users_provider(self): self.client.force_authenticate(user=user_a) - detail_url = f'/api/ai-providers/{provider_b.id}/' + detail_url = f'/api/ai-providers/{provider_b.uuid}/' response = self.client.delete(detail_url) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) @@ -163,7 +163,7 @@ def test_update_own_provider(self): provider = AIProviderFactory(creator=user, name="Original Name") - detail_url = f'/api/ai-providers/{provider.id}/' + detail_url = f'/api/ai-providers/{provider.uuid}/' update_data = {'name': 'Updated Name'} response = self.client.patch(detail_url, update_data, format='json') @@ -185,7 +185,7 @@ def test_cannot_update_provider_field(self): provider = AIProviderFactory(creator=user, provider='openai', name="Original Name") - detail_url = f'/api/ai-providers/{provider.id}/' + detail_url = f'/api/ai-providers/{provider.uuid}/' update_data = { 'name': 'Updated Name', 'provider': 'anthropic' @@ -205,7 +205,7 @@ def test_delete_own_provider(self): provider = AIProviderFactory(creator=user, name="Provider to Delete") - detail_url = f'/api/ai-providers/{provider.id}/' + detail_url = f'/api/ai-providers/{provider.uuid}/' response = self.client.delete(detail_url) self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) @@ -221,7 +221,7 @@ def test_update_without_api_key_does_not_change_api_key(self): provider = AIProviderFactory(creator=user, name="Test Provider") original_api_key = provider.provider_api_key - detail_url = f'/api/ai-providers/{provider.id}/' + detail_url = f'/api/ai-providers/{provider.uuid}/' update_data = {'name': 'Updated Name'} response = self.client.patch(detail_url, update_data, format='json') @@ -240,7 +240,7 @@ def test_update_with_api_key_changes_api_key(self): original_api_key = provider.provider_api_key new_api_key = 'new-api-key-12345' - detail_url = f'/api/ai-providers/{provider.id}/' + detail_url = f'/api/ai-providers/{provider.uuid}/' update_data = {'name': 'Updated Name', 'provider_api_key': new_api_key} response = self.client.patch(detail_url, update_data, format='json') @@ -259,7 +259,7 @@ def test_update_with_empty_api_key_does_not_change_api_key(self): provider = AIProviderFactory(creator=user, name="Test Provider") original_api_key = provider.provider_api_key - detail_url = f'/api/ai-providers/{provider.id}/' + detail_url = f'/api/ai-providers/{provider.uuid}/' update_data = {'name': 'Updated Name', 'provider_api_key': ''} response = self.client.patch(detail_url, update_data, format='json') @@ -277,7 +277,7 @@ def test_update_with_whitespace_api_key_does_not_change_api_key(self): provider = AIProviderFactory(creator=user, name="Test Provider") original_api_key = provider.provider_api_key - detail_url = f'/api/ai-providers/{provider.id}/' + detail_url = f'/api/ai-providers/{provider.uuid}/' update_data = {'name': 'Updated Name', 'provider_api_key': ' '} response = self.client.patch(detail_url, update_data, format='json') diff --git a/backend/core/views/ai_provider.py b/backend/core/views/ai_provider.py index 6939cde..f524f5d 100644 --- a/backend/core/views/ai_provider.py +++ b/backend/core/views/ai_provider.py @@ -6,6 +6,7 @@ class AIProviderViewSet(viewsets.ModelViewSet): permission_classes = [permissions.IsAuthenticated] + lookup_field = 'uuid' http_method_names = ['get', 'post', 'put','patch', 'delete'] queryset = AIProvider.objects.all() From e45cf86bd2a5e7181c71db7c2b38af83c7b702f3 Mon Sep 17 00:00:00 2001 From: Anish Ghimire Date: Wed, 4 Mar 2026 14:31:33 +0545 Subject: [PATCH 12/65] improve: refactor and integrate AI Provider setting --- backend/config/settings.py | 4 +- backend/core/consts.py | 15 +- backend/core/serializers/ai_provider.py | 11 +- backend/core/views/ai_provider.py | 16 +- .../components/AIProvider/NewAIProvider.vue | 191 ++++++++++++++++++ .../UpdateAIProvider.vue} | 10 +- frontend/components/AppSidebar.vue | 2 +- frontend/components/C8Select.vue | 14 +- frontend/components/Model/NewModel.vue | 162 --------------- frontend/composables/useUniqueName.ts | 18 ++ .../settings/{models.vue => ai-providers.vue} | 22 +- frontend/stores/aiProvider.ts | 98 +++++++++ frontend/stores/model.ts | 84 -------- 13 files changed, 370 insertions(+), 277 deletions(-) create mode 100644 frontend/components/AIProvider/NewAIProvider.vue rename frontend/components/{Model/UpdateModel.vue => AIProvider/UpdateAIProvider.vue} (94%) delete mode 100644 frontend/components/Model/NewModel.vue create mode 100644 frontend/composables/useUniqueName.ts rename frontend/pages/settings/{models.vue => ai-providers.vue} (72%) create mode 100644 frontend/stores/aiProvider.ts delete mode 100644 frontend/stores/model.ts diff --git a/backend/config/settings.py b/backend/config/settings.py index 70f9dcb..a9fc06f 100644 --- a/backend/config/settings.py +++ b/backend/config/settings.py @@ -113,7 +113,9 @@ 'rest_framework.parsers.MultiPartParser', 'rest_framework.parsers.FormParser', 'rest_framework.parsers.JSONParser', - ] + ], + 'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.PageNumberPagination', + 'PAGE_SIZE': 20, } CORS_ALLOWED_ORIGINS = [ diff --git a/backend/core/consts.py b/backend/core/consts.py index 2d301bf..3f4a39a 100644 --- a/backend/core/consts.py +++ b/backend/core/consts.py @@ -5,4 +5,17 @@ AI_ROLE_HUMAN_AGENT="assistant" AI_ROLE_USER="user" AI_ROLE_SYSTEM="system" -AI_ROLE_UNKNOWN="unknown" \ No newline at end of file +AI_ROLE_UNKNOWN="unknown" + +SUPPORTED_AI_PROVIDERS = [ + { + 'id': 'gemini', + 'label': 'Google Gemini', + 'base_url': 'https://generativelanguage.googleapis.com' + }, + { + 'id': 'custom', + 'label': 'Custom Provider', + 'base_url': '' + } +] diff --git a/backend/core/serializers/ai_provider.py b/backend/core/serializers/ai_provider.py index b48b110..e0f1380 100644 --- a/backend/core/serializers/ai_provider.py +++ b/backend/core/serializers/ai_provider.py @@ -1,5 +1,6 @@ from rest_framework import serializers from core.models.ai_provider import AIProvider +from core.consts import SUPPORTED_AI_PROVIDERS class AIProviderSerializer(serializers.ModelSerializer): class Meta: @@ -19,8 +20,14 @@ def __init__(self, *args, **kwargs): self.fields['provider_api_key'].allow_blank = True self.fields.pop('provider', None) - def to_representation(self, instance): - return AIProviderSerializer(instance, context=self.context).data + def validate_provider(self, value): + supported_provider_ids = [p['id'] for p in SUPPORTED_AI_PROVIDERS] + if value not in supported_provider_ids: + supported_labels = [p['label'] for p in SUPPORTED_AI_PROVIDERS] + raise serializers.ValidationError( + f"Provider '{value}' is not supported. Supported providers are: {', '.join(supported_labels)}" + ) + return value def create(self, validated_data): validated_data['creator'] = self.context['request'].user diff --git a/backend/core/views/ai_provider.py b/backend/core/views/ai_provider.py index f524f5d..4bf723e 100644 --- a/backend/core/views/ai_provider.py +++ b/backend/core/views/ai_provider.py @@ -1,13 +1,16 @@ from rest_framework import status, viewsets, permissions from rest_framework.response import Response +from rest_framework.pagination import PageNumberPagination from django.db import models from core.serializers.ai_provider import AIProviderCreateSerializer, AIProviderSerializer from core.models import AIProvider +from core.consts import SUPPORTED_AI_PROVIDERS class AIProviderViewSet(viewsets.ModelViewSet): permission_classes = [permissions.IsAuthenticated] lookup_field = 'uuid' http_method_names = ['get', 'post', 'put','patch', 'delete'] + pagination_class = PageNumberPagination queryset = AIProvider.objects.all() @@ -23,4 +26,15 @@ def get_queryset(self): ) def perform_create(self, serializer): - serializer.save(creator=self.request.user) \ No newline at end of file + serializer.save(creator=self.request.user) + + def list(self, request, *args, **kwargs): + response = super().list(request, *args, **kwargs) + if isinstance(response.data, dict): + response.data['supported_ai_providers'] = SUPPORTED_AI_PROVIDERS + else: + response.data = { + 'results': response.data, + 'supported_ai_providers': SUPPORTED_AI_PROVIDERS + } + return response diff --git a/frontend/components/AIProvider/NewAIProvider.vue b/frontend/components/AIProvider/NewAIProvider.vue new file mode 100644 index 0000000..d01daac --- /dev/null +++ b/frontend/components/AIProvider/NewAIProvider.vue @@ -0,0 +1,191 @@ + + diff --git a/frontend/components/Model/UpdateModel.vue b/frontend/components/AIProvider/UpdateAIProvider.vue similarity index 94% rename from frontend/components/Model/UpdateModel.vue rename to frontend/components/AIProvider/UpdateAIProvider.vue index 6b9545a..eaf0e6c 100644 --- a/frontend/components/Model/UpdateModel.vue +++ b/frontend/components/AIProvider/UpdateAIProvider.vue @@ -107,9 +107,9 @@ import { z } from 'zod' import { setBackendErrors } from '~/lib/utils' import { ModelTypes } from '~/lib/consts' -const updateModelSlide = ref | null>(null) +const updateAIProviderSlide = ref | null>(null) -const modelStore = useModelStore() +const AIProviderStore = useAIProviderStore() const selectedModelType: Ref<{ label: string, value: string } | null> = ref(null) @@ -153,7 +153,7 @@ function open(model: LLMModel) { }) selectedModelType.value = ModelTypes.find((m) => m.value === modelType) || null - updateModelSlide.value?.openSlide() + updateAIProviderSlide.value?.openSlide() } defineExpose({ @@ -162,8 +162,8 @@ defineExpose({ const updateModel = form.handleSubmit(async (values) => { try { - await modelStore.update(values) - updateModelSlide.value?.closeSlide() + await AIProviderStore.update(values) + updateAIProviderSlide.value?.closeSlide() toast.success('Model updated') } catch (e: unknown) { setBackendErrors(form, e.errors) diff --git a/frontend/components/AppSidebar.vue b/frontend/components/AppSidebar.vue index 9fd1313..37993ff 100644 --- a/frontend/components/AppSidebar.vue +++ b/frontend/components/AppSidebar.vue @@ -215,7 +215,7 @@ async function initNewChat() { ]" > diff --git a/frontend/components/C8Select.vue b/frontend/components/C8Select.vue index 1766493..ab6a930 100644 --- a/frontend/components/C8Select.vue +++ b/frontend/components/C8Select.vue @@ -31,10 +31,10 @@ import { import { computed } from 'vue' const props = defineProps<{ - modelValue: { label: string; value: string | number } | null + modelValue?: string | null label?: string placeholder?: string - options: { label: string; value: string | number }[] + options: { label: string; value: string | number; logo?: string }[] disabled?: boolean containerClass?: string triggerClass?: string @@ -43,10 +43,6 @@ const props = defineProps<{ const emit = defineEmits(['update:modelValue']) const internalValue = computed({ - get: () => props.modelValue?.value ?? null, - set: (val) => { - const option = props.options.find((opt) => opt.value === val) || null - emit('update:modelValue', option) - }, -}) - + get: () => props.modelValue ?? null, + set: (val) => emit('update:modelValue', val), +}) diff --git a/frontend/components/Model/NewModel.vue b/frontend/components/Model/NewModel.vue deleted file mode 100644 index 4bf8b92..0000000 --- a/frontend/components/Model/NewModel.vue +++ /dev/null @@ -1,162 +0,0 @@ - - diff --git a/frontend/composables/useUniqueName.ts b/frontend/composables/useUniqueName.ts new file mode 100644 index 0000000..196b68e --- /dev/null +++ b/frontend/composables/useUniqueName.ts @@ -0,0 +1,18 @@ +export function useUniqueName() { + const generateUniqueName = (baseName: string, separator: string = ' '): string => { + const timestamp = Date.now() + return `${baseName}${separator}${timestamp}` + } + + const generateShortUniqueName = (baseName: string, separator: string = ' '): string => { + const timestamp = Date.now() + const randomPart = Math.floor(Math.random() * 100) + const uniqueSuffix = `${(timestamp % 1000).toString().padStart(3, '0')}${randomPart.toString().padStart(2, '0')}` // 3 timestamp + 2 random = 5 digits + return `${baseName}${separator}${uniqueSuffix}` + } + + return { + generateUniqueName, + generateShortUniqueName + } +} diff --git a/frontend/pages/settings/models.vue b/frontend/pages/settings/ai-providers.vue similarity index 72% rename from frontend/pages/settings/models.vue rename to frontend/pages/settings/ai-providers.vue index f3d19c4..1b39226 100644 --- a/frontend/pages/settings/models.vue +++ b/frontend/pages/settings/ai-providers.vue @@ -3,7 +3,7 @@
- +
@@ -15,20 +15,20 @@ :update-fn="updateModel" /> - +
diff --git a/frontend/stores/aiProvider.ts b/frontend/stores/aiProvider.ts new file mode 100644 index 0000000..62e7ab2 --- /dev/null +++ b/frontend/stores/aiProvider.ts @@ -0,0 +1,98 @@ +import { defineStore } from 'pinia' +import { useHttpClient } from '@/composables/useHttpClient' + +export type AIProviderType = 'text' | 'embedding' | 'image' | 'rerank' | 'other' + +export interface AIProvider { + id: number + uuid: string + + owner: number + + name: string + provider_api_key?: string | null + base_url?: string | null + provider: string + + is_builtin: boolean + + created_at: string +} + +export interface SupportedAIProvider { + id: string + label: string + base_url: string +} + +export interface AIProvidersResponse { + count: number + next: string | null + previous: string | null + results: AIProvider[] + supported_ai_providers: SupportedAIProvider[] +} + +export const useAIProviderStore = defineStore('aiProvider', { + state: () => ({ + AIProviders: [] as AIProvider[], + supportedAIProviders: [] as SupportedAIProvider[], + }), + + actions: { + async load() { + const { httpGet } = useHttpClient() + const response = await httpGet(`/ai-providers/`) + this.AIProviders = response.results + this.supportedAIProviders = response.supported_ai_providers + return response + }, + + async create(values: Record) { + const { httpPost } = useHttpClient() + const response = await httpPost('/ai-providers/', { + name: values.name, + provider_api_key: values.provider_api_key, + base_url: values.base_url, + provider: values.provider, + }) + this.AIProviders = [...this.AIProviders, response] + return response + }, + + async update(values: Record) { + const { httpPatch } = useHttpClient() + + const body: Record = { + name: values.name, + base_url: values.base_url, + model_name: values.model_name, + } + + if (values.provider_api_key) { + body.provider_api_key = values.provider_api_key + } + + const response = await httpPatch(`/ai-providers/${values.uuid}/`, body) + + const index = this.AIProviders.findIndex(p => p.uuid === values.uuid) + if (index !== -1 && response?.name) { + this.AIProviders[index] = { ...this.AIProviders[index], ...response } + } + + return response + }, + + async delete(uuid: string) { + const { httpDelete } = useHttpClient() + + const response = await httpDelete<{detail: string}>(`/ai-providers/${uuid}/`) + + if (response?.detail === 'Deleted') { + this.AIProviders = this.AIProviders.filter((p) => p.uuid !== uuid) + } + + return response + }, + }, +}) diff --git a/frontend/stores/model.ts b/frontend/stores/model.ts deleted file mode 100644 index 2a03f1e..0000000 --- a/frontend/stores/model.ts +++ /dev/null @@ -1,84 +0,0 @@ -import { defineStore } from 'pinia' -import { useHttpClient } from '@/composables/useHttpClient' - -export type LLMModelType = 'text' | 'embedding' | 'image' | 'rerank' | 'other' - -export interface LLMModel { - id: number - uuid: string - - owner: number - - name: string - api_key?: string | null - base_url?: string | null - model_name: string - - model_type: LLMModelType - is_default: boolean - - created_at: string -} - -export const useModelStore = defineStore('model', { - state: () => ({ - models: [] as LLMModel[], - }), - - actions: { - async load() { - const { httpGet } = useHttpClient() - const response = await httpGet(`/models/`) - this.models = response - return response - }, - - async create(values: Record) { - const { httpPost } = useHttpClient() - const response = await httpPost('/models/', { - name: values.name, - api_key: values.api_key, - base_url: values.base_url, - model_name: values.model_name, - model_type: values.model_type, - }) - this.models = [...this.models, response] - return response - }, - - async update(values: Record) { - const { httpPatch } = useHttpClient() - - const body: Record = { - name: values.name, - base_url: values.base_url, - model_name: values.model_name, - } - - if (values.api_key) { - body.api_key = values.api_key - } - - const response = await httpPatch(`/models/${values.uuid}/`, body) - - const index = this.models.findIndex(m => m.uuid === values.uuid) - if (index !== -1 && response?.name) { - this.models[index] = { ...this.models[index], ...response } - } - - return response - }, - - async delete(uuid: string) { - const { httpDelete } = useHttpClient() - - const response = await httpDelete<{detail: string}>(`/models/${uuid}/`) - - if (response?.detail === 'Deleted') { - this.models = this.models.filter((m) => m.uuid !== uuid) - } - - return response - }, - }, -}) From aac6c4576b371990fba59100e4d32ede6658eb2d Mon Sep 17 00:00:00 2001 From: Anish Ghimire Date: Wed, 4 Mar 2026 23:10:06 +0545 Subject: [PATCH 13/65] feat: validate ai providers --- backend/core/consts.py | 2 +- backend/core/models/ai_provider.py | 3 + backend/core/serializers/ai_provider.py | 4 +- backend/core/services/__init__.py | 7 +- backend/core/services/ai_service.py | 86 ++++++++++++++ backend/core/services/contracts/__init__.py | 1 + .../contracts/ai_provider_contract.py | 20 ++++ backend/core/services/factories/__init__.py | 1 + .../services/factories/ai_provider_factory.py | 49 ++++++++ backend/core/services/providers/__init__.py | 2 + .../core/services/providers/ai/__init__.py | 2 + .../services/providers/ai/custom_provider.py | 26 +++++ .../services/providers/ai/gemini_provider.py | 49 ++++++++ backend/core/tests.py | 3 - backend/core/tests/test_ai_provider.py | 105 ++++++++++++++---- backend/core/tests/test_app_ai_provider.py | 16 +-- backend/core/views/ai_provider.py | 49 +++++++- backend/core/views/app_ai_provider.py | 3 +- .../components/AIProvider/NewAIProvider.vue | 39 ++++++- 19 files changed, 422 insertions(+), 45 deletions(-) create mode 100644 backend/core/services/ai_service.py create mode 100644 backend/core/services/contracts/__init__.py create mode 100644 backend/core/services/contracts/ai_provider_contract.py create mode 100644 backend/core/services/factories/__init__.py create mode 100644 backend/core/services/factories/ai_provider_factory.py create mode 100644 backend/core/services/providers/__init__.py create mode 100644 backend/core/services/providers/ai/__init__.py create mode 100644 backend/core/services/providers/ai/custom_provider.py create mode 100644 backend/core/services/providers/ai/gemini_provider.py delete mode 100644 backend/core/tests.py diff --git a/backend/core/consts.py b/backend/core/consts.py index 3f4a39a..0b1c4e9 100644 --- a/backend/core/consts.py +++ b/backend/core/consts.py @@ -11,7 +11,7 @@ { 'id': 'gemini', 'label': 'Google Gemini', - 'base_url': 'https://generativelanguage.googleapis.com' + 'base_url': 'https://generativelanguage.googleapis.com/v1beta' }, { 'id': 'custom', diff --git a/backend/core/models/ai_provider.py b/backend/core/models/ai_provider.py index 3e7ebbe..c12feef 100644 --- a/backend/core/models/ai_provider.py +++ b/backend/core/models/ai_provider.py @@ -13,3 +13,6 @@ class AIProvider(BaseModel): is_builtin = models.BooleanField(default=False, blank=True) creator = models.ForeignKey(User, on_delete=models.CASCADE) + + class Meta: + ordering = ['created_at'] diff --git a/backend/core/serializers/ai_provider.py b/backend/core/serializers/ai_provider.py index e0f1380..fb4e8e6 100644 --- a/backend/core/serializers/ai_provider.py +++ b/backend/core/serializers/ai_provider.py @@ -10,8 +10,8 @@ class Meta: class AIProviderCreateSerializer(serializers.ModelSerializer): class Meta: model = AIProvider - fields = ['name', 'provider', 'base_url', 'provider_api_key', 'creator'] - read_only_fields = ['creator'] + fields = ['uuid', 'name', 'provider', 'base_url', 'provider_api_key', 'creator'] + read_only_fields = ['uuid', 'creator'] def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/backend/core/services/__init__.py b/backend/core/services/__init__.py index 922f59c..6f4d711 100644 --- a/backend/core/services/__init__.py +++ b/backend/core/services/__init__.py @@ -2,4 +2,9 @@ from .file_extractors import extract_text_from_file from .notifications import notify_users from .encryption import encrypt, decrypt, generate_verification_token, verify_verification_token -from .private_key_encryption import decrypt_with_private_key \ No newline at end of file +from .private_key_encryption import decrypt_with_private_key +from .contracts.ai_provider_contract import AIProviderContract +from .factories.ai_provider_factory import AIProviderFactory +from .providers.ai.gemini_provider import GeminiProvider +from .providers.ai.custom_provider import CustomProvider +from .ai_service import AIService \ No newline at end of file diff --git a/backend/core/services/ai_service.py b/backend/core/services/ai_service.py new file mode 100644 index 0000000..0b69a5e --- /dev/null +++ b/backend/core/services/ai_service.py @@ -0,0 +1,86 @@ +from typing import Optional, Any, Dict +from django.db.models import Q + +from core.models import Application, AppAIProvider +from .factories.ai_provider_factory import AIProviderFactory + + +class AIService: + + def __init__(self): + self.provider_factory = AIProviderFactory() + + def get_provider_for_app(self, application: Application, context: str = 'widget', + capability: str = 'text') -> Optional[Any]: + try: + config = AppAIProvider.objects.filter( + application=application, + context=context, + capability=capability, + is_active=True, + ai_provider__is_builtin=True + ).select_related('ai_provider').first() + + if not config: + config = AppAIProvider.objects.filter( + application=application, + context=context, + capability=capability, + is_active=True + ).select_related('ai_provider').order_by('priority').first() + + if not config: + return None + + ai_provider = config.ai_provider + + provider_instance = self.provider_factory.create_provider( + provider_type=ai_provider.provider, + api_key=ai_provider.provider_api_key, + base_url=ai_provider.base_url + ) + + return provider_instance + + except Exception as e: + print(f"Error getting AI provider: {e}") + return None + + def generate_content(self, application: Application, contents: str, + model: Optional[str] = None, context: str = 'widget', + capability: str = 'text', **kwargs) -> Optional[str]: + provider = self.get_provider_for_app(application, context, capability) + if not provider: + return None + + if model is None: + config = AppAIProvider.objects.filter( + application=application, + context=context, + capability=capability, + is_active=True + ).select_related('ai_provider').order_by('priority').first() + + if config and config.external_model_id: + model = config.external_model_id + else: + supported_models = provider.get_models() + model = supported_models[0] if supported_models else 'default' + + try: + return provider.generate_content(model, contents, **kwargs) + except Exception as e: + print(f"Error generating content: {e}") + return None + + def validate_provider_connection(self, application: Application, + context: str = 'widget', + capability: str = 'text') -> tuple[bool, list[str]]: + provider = self.get_provider_for_app(application, context, capability) + if not provider: + return False, [] + + try: + return provider.validate_connection() + except Exception: + return False, [] diff --git a/backend/core/services/contracts/__init__.py b/backend/core/services/contracts/__init__.py new file mode 100644 index 0000000..8f0d13f --- /dev/null +++ b/backend/core/services/contracts/__init__.py @@ -0,0 +1 @@ +from .ai_provider_contract import AIProviderContract \ No newline at end of file diff --git a/backend/core/services/contracts/ai_provider_contract.py b/backend/core/services/contracts/ai_provider_contract.py new file mode 100644 index 0000000..50f188d --- /dev/null +++ b/backend/core/services/contracts/ai_provider_contract.py @@ -0,0 +1,20 @@ +from abc import ABC, abstractmethod +from typing import Optional + + +class AIProviderContract(ABC): + def __init__(self, api_key: str, base_url: Optional[str] = None): + self.api_key = api_key + self.base_url = base_url + + @abstractmethod + def generate_content(self, model: str, contents: str, **kwargs) -> str: + pass + + @abstractmethod + def validate_connection(self) -> tuple[bool, list[str]]: + pass + + @abstractmethod + def get_models(self) -> list[str]: + pass diff --git a/backend/core/services/factories/__init__.py b/backend/core/services/factories/__init__.py new file mode 100644 index 0000000..d0a1138 --- /dev/null +++ b/backend/core/services/factories/__init__.py @@ -0,0 +1 @@ +from .ai_provider_factory import AIProviderFactory \ No newline at end of file diff --git a/backend/core/services/factories/ai_provider_factory.py b/backend/core/services/factories/ai_provider_factory.py new file mode 100644 index 0000000..5c1907b --- /dev/null +++ b/backend/core/services/factories/ai_provider_factory.py @@ -0,0 +1,49 @@ +from typing import Optional + +from ..contracts.ai_provider_contract import AIProviderContract +from ..providers.ai.custom_provider import CustomProvider +from ..providers.ai.gemini_provider import GeminiProvider + + +class AIProviderFactory: + PROVIDER_CLASSES = { + 'gemini': GeminiProvider, + 'custom': CustomProvider, + } + + @staticmethod + def create_provider(provider_type: str, api_key: str, base_url: Optional[str] = None) -> AIProviderContract: + provider_class = AIProviderFactory.PROVIDER_CLASSES.get(provider_type.lower()) + + if provider_class is None: + supported_providers = list(AIProviderFactory.PROVIDER_CLASSES.keys()) + raise ValueError( + f"Unsupported provider type: {provider_type}. " + f"Supported providers: {supported_providers}" + ) + + try: + return provider_class(api_key=api_key, base_url=base_url) + except Exception as e: + raise ValueError(f"Failed to create {provider_type} provider: {e}") + + @staticmethod + def validate_provider(provider_type: str, api_key: str, base_url: Optional[str] = None) -> tuple[bool, list[str]]: + provider_class = AIProviderFactory.PROVIDER_CLASSES.get(provider_type.lower()) + + if provider_class is None: + supported_providers = list(AIProviderFactory.PROVIDER_CLASSES.keys()) + raise ValueError( + f"Unsupported provider type: {provider_type}. " + f"Supported providers: {supported_providers}" + ) + + try: + provider = provider_class(api_key=api_key, base_url=base_url) + return provider.validate_connection() + except Exception as e: + return False, [] + + @staticmethod + def get_supported_providers() -> list[str]: + return list(AIProviderFactory.PROVIDER_CLASSES.keys()) diff --git a/backend/core/services/providers/__init__.py b/backend/core/services/providers/__init__.py new file mode 100644 index 0000000..dda9abe --- /dev/null +++ b/backend/core/services/providers/__init__.py @@ -0,0 +1,2 @@ +from .ai.gemini_provider import GeminiProvider +from .ai.custom_provider import CustomProvider \ No newline at end of file diff --git a/backend/core/services/providers/ai/__init__.py b/backend/core/services/providers/ai/__init__.py new file mode 100644 index 0000000..56e5361 --- /dev/null +++ b/backend/core/services/providers/ai/__init__.py @@ -0,0 +1,2 @@ +from .gemini_provider import GeminiProvider +from .custom_provider import CustomProvider diff --git a/backend/core/services/providers/ai/custom_provider.py b/backend/core/services/providers/ai/custom_provider.py new file mode 100644 index 0000000..cc5c5db --- /dev/null +++ b/backend/core/services/providers/ai/custom_provider.py @@ -0,0 +1,26 @@ +from typing import Optional + +from ...contracts.ai_provider_contract import AIProviderContract + + +class CustomProvider(AIProviderContract): + def __init__(self, api_key: str, base_url: Optional[str] = None): + super().__init__(api_key, base_url) + + def generate_content(self, model: str, contents: str, **kwargs) -> str: + raise NotImplementedError( + "Custom provider generate_content method not implemented. " + "Please implement this method in your custom provider class." + ) + + def validate_connection(self) -> tuple[bool, list[str]]: + raise NotImplementedError( + "Custom provider validate_connection method not implemented. " + "Please implement this method in your custom provider class." + ) + + def get_models(self) -> list[str]: + raise NotImplementedError( + "Custom provider get_supported_models method not implemented. " + "Please implement this method in your custom provider class." + ) diff --git a/backend/core/services/providers/ai/gemini_provider.py b/backend/core/services/providers/ai/gemini_provider.py new file mode 100644 index 0000000..cf90056 --- /dev/null +++ b/backend/core/services/providers/ai/gemini_provider.py @@ -0,0 +1,49 @@ +from typing import Optional +from google import genai +from ...contracts.ai_provider_contract import AIProviderContract + + +class GeminiProvider(AIProviderContract): + SUPPORTED_MODELS = [ + 'gemini-1.5-pro', + 'gemini-1.5-flash', + 'gemini-1.0-pro', + 'gemini-pro-vision' + ] + + def __init__(self, api_key: str, base_url: Optional[str] = None): + super().__init__(api_key, base_url) + + try: + self.client = genai.Client(api_key=api_key) + except Exception as e: + raise ValueError(f"Failed to initialize Gemini client: {e}") + + def generate_content(self, model: str, contents: str, **kwargs) -> str: + if model not in self.SUPPORTED_MODELS: + raise ValueError(f"Unsupported model: {model}. Supported models: {self.SUPPORTED_MODELS}") + + try: + response = self.client.models.generate_content( + model=model, + contents=contents + ) + + return response.text + + except Exception as e: + raise ValueError(f"Gemini API error: {e}") + + def validate_connection(self) -> tuple[bool, list[str]]: + try: + models = self.get_models() + return True, models + except Exception as e: + return False, [] + + def get_models(self) -> list[str]: + try: + models = list(self.client.models.list()) + return [model.name for model in models] + except Exception as e: + raise ValueError(f"Failed to retrieve models from Gemini API: {e}") diff --git a/backend/core/tests.py b/backend/core/tests.py deleted file mode 100644 index 7ce503c..0000000 --- a/backend/core/tests.py +++ /dev/null @@ -1,3 +0,0 @@ -from django.test import TestCase - -# Create your tests here. diff --git a/backend/core/tests/test_ai_provider.py b/backend/core/tests/test_ai_provider.py index ce072e4..38b8bc8 100644 --- a/backend/core/tests/test_ai_provider.py +++ b/backend/core/tests/test_ai_provider.py @@ -1,8 +1,11 @@ import pytest from rest_framework import status from rest_framework.test import APIClient +from django.contrib.auth.models import AnonymousUser from core.models import AIProvider +from core.serializers.ai_provider import AIProviderCreateSerializer +from core.consts import SUPPORTED_AI_PROVIDERS from core.tests.conftest import BaseAPITestCase from core.tests.factories import UserFactory, AIProviderFactory @@ -31,18 +34,21 @@ def test_list_ai_providers_authenticated_user(self): self.assertEqual(response.status_code, status.HTTP_200_OK) data = response.json() - self.assertEqual(len(data), 2) + self.assertEqual(len(data['results']), 2) - provider_names = [provider['name'] for provider in data] + provider_names = [provider['name'] for provider in data['results']] self.assertIn("OpenAI Provider", provider_names) self.assertIn("Anthropic Provider", provider_names) self.assertNotIn("Other User Provider", provider_names) - provider_data = data[0] + provider_data = data['results'][0] expected_fields = ['id', 'uuid', 'name', 'provider', 'base_url', 'is_builtin', 'creator', 'created_at', 'updated_at'] for field in expected_fields: self.assertIn(field, provider_data) + self.assertIn('supported_ai_providers', data) + self.assertIsInstance(data['supported_ai_providers'], list) + def test_list_ai_providers_includes_builtin_and_user_owned(self): """Test that authenticated users can list their AI providers plus builtin providers.""" user = UserFactory() @@ -62,17 +68,17 @@ def test_list_ai_providers_includes_builtin_and_user_owned(self): self.assertEqual(response.status_code, status.HTTP_200_OK) data = response.json() - self.assertEqual(len(data), 4) + self.assertEqual(len(data['results']), 4) - provider_names = [provider['name'] for provider in data] + provider_names = [provider['name'] for provider in data['results']] self.assertIn("User OpenAI Provider", provider_names) self.assertIn("User Anthropic Provider", provider_names) self.assertIn("Builtin OpenAI", provider_names) self.assertIn("Builtin Claude", provider_names) self.assertNotIn("Other User Provider", provider_names) - builtin_providers = [p for p in data if p['is_builtin']] - user_owned_providers = [p for p in data if not p['is_builtin']] + builtin_providers = [p for p in data['results'] if p['is_builtin']] + user_owned_providers = [p for p in data['results'] if not p['is_builtin']] self.assertEqual(len(builtin_providers), 2) self.assertEqual(len(user_owned_providers), 2) @@ -135,26 +141,25 @@ def test_create_ai_provider(self): self.client.force_authenticate(user=user) create_data = { - 'name': 'My OpenAI Provider', - 'provider': 'openai', - 'base_url': 'https://api.openai.com/v1', + 'name': 'My Gemini Provider', + 'provider': 'gemini', + 'base_url': 'https://generativelanguage.googleapis.com', 'provider_api_key': 'sk-test123456789' } response = self.client.post(self.list_url, create_data, format='json') - print(response.data) self.assertEqual(response.status_code, status.HTTP_201_CREATED) data = response.json() - self.assertEqual(data['name'], 'My OpenAI Provider') - self.assertEqual(data['provider'], 'openai') - self.assertEqual(data['base_url'], 'https://api.openai.com/v1') + self.assertEqual(data['name'], 'My Gemini Provider') + self.assertEqual(data['provider'], 'gemini') + self.assertEqual(data['base_url'], 'https://generativelanguage.googleapis.com') self.assertEqual(data['creator'], user.id) - provider = AIProvider.objects.get(id=data['id']) + provider = AIProvider.objects.get(uuid=data['uuid']) self.assertEqual(provider.creator, user) - self.assertEqual(provider.name, 'My OpenAI Provider') + self.assertEqual(provider.name, 'My Gemini Provider') def test_update_own_provider(self): """Test that authenticated user can update their own AI provider.""" @@ -170,9 +175,7 @@ def test_update_own_provider(self): self.assertEqual(response.status_code, status.HTTP_200_OK) data = response.json() - print(response) self.assertEqual(data['name'], 'Updated Name') - self.assertEqual(data['provider'], provider.provider) self.assertEqual(data['creator'], user.id) provider.refresh_from_db() @@ -295,15 +298,15 @@ def test_api_key_is_encrypted_in_database(self): api_key = 'test-api-key-12345' create_data = { 'name': 'Test Provider', - 'provider': 'openai', - 'base_url': 'https://api.openai.com/v1', + 'provider': 'gemini', + 'base_url': 'https://generativelanguage.googleapis.com', 'provider_api_key': api_key } response = self.client.post(self.list_url, create_data, format='json') self.assertEqual(response.status_code, status.HTTP_201_CREATED) - provider = AIProvider.objects.get(id=response.json()['id']) + provider = AIProvider.objects.get(uuid=response.json()['uuid']) self.assertEqual(provider.provider_api_key, api_key) @@ -312,4 +315,62 @@ def test_api_key_is_encrypted_in_database(self): cursor.execute("SELECT provider_api_key FROM core_aiprovider WHERE id = %s", [provider.id]) raw_db_value = cursor.fetchone()[0] self.assertNotEqual(raw_db_value, api_key) - self.assertTrue(raw_db_value.startswith('gAAAAA')) # Fernet encrypted strings start with this + self.assertTrue(raw_db_value.startswith('gAAAAA')) + + + def test_create_with_supported_provider_gemini(self): + """Test that AI provider can be created with supported 'gemini' provider.""" + user = UserFactory() + data = { + 'name': 'My Google Gemini Provider', + 'provider': 'gemini', + 'base_url': 'https://generativelanguage.googleapis.com', + 'provider_api_key': 'test-api-key-12345' + } + + serializer = AIProviderCreateSerializer(data=data, context={'request': type('MockRequest', (), {'user': user})()}) + + assert serializer.is_valid(), f"Serializer should be valid but got errors: {serializer.errors}" + provider = serializer.save() + + assert provider.name == 'My Google Gemini Provider' + assert provider.provider == 'gemini' + assert provider.base_url == 'https://generativelanguage.googleapis.com' + assert provider.creator == user + + def test_create_with_supported_provider_custom(self): + """Test that AI provider can be created with supported 'custom' provider.""" + user = UserFactory() + data = { + 'name': 'My Custom Provider', + 'provider': 'custom', + 'base_url': 'https://my-custom-api.com', + 'provider_api_key': 'test-api-key-67890' + } + + serializer = AIProviderCreateSerializer(data=data, context={'request': type('MockRequest', (), {'user': user})()}) + + assert serializer.is_valid() + provider = serializer.save() + + assert provider.name == 'My Custom Provider' + assert provider.provider == 'custom' + assert provider.base_url == 'https://my-custom-api.com' + assert provider.creator == user + + def test_create_with_unsupported_provider_fails(self): + """Test that AI provider creation fails with unsupported provider.""" + user = UserFactory() + data = { + 'name': 'My Unsupported Provider', + 'provider': 'unsupported_provider', + 'base_url': 'https://unsupported-api.com', + 'provider_api_key': 'test-api-key-12345' + } + + serializer = AIProviderCreateSerializer(data=data, context={'request': type('MockRequest', (), {'user': user})()}) + + assert not serializer.is_valid() + assert 'provider' in serializer.errors + assert "not supported" in str(serializer.errors['provider'][0]) + assert "Google Gemini" in str(serializer.errors['provider'][0]) diff --git a/backend/core/tests/test_app_ai_provider.py b/backend/core/tests/test_app_ai_provider.py index ebd705d..6605fee 100644 --- a/backend/core/tests/test_app_ai_provider.py +++ b/backend/core/tests/test_app_ai_provider.py @@ -19,6 +19,9 @@ def setUp(self): creator=self.user ) + def tearDown(self): + AppAIProvider.objects.all().delete() + def test_create_app_ai_provider(self): self.client.force_authenticate(user=self.user) url = reverse('application-ai-providers-list', kwargs={'application_uuid': self.application.uuid}) @@ -29,7 +32,6 @@ def test_create_app_ai_provider(self): 'external_model_id': 'gpt-4' } response = self.client.post(url, data) - print(response.data) self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.data['context'], 'widget') self.assertEqual(response.data['capability'], 'text') @@ -50,8 +52,8 @@ def test_list_app_ai_providers(self): url = reverse('application-ai-providers-list', kwargs={'application_uuid': self.application.uuid}) response = self.client.get(url) self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(len(response.data), 1) - self.assertEqual(response.data[0]['id'], config.id) + self.assertEqual(len(response.data['results']), 1) + self.assertEqual(response.data['results'][0]['id'], config.id) def test_filter_by_context(self): AppAIProvider.objects.create( @@ -71,8 +73,8 @@ def test_filter_by_context(self): url = reverse('application-ai-providers-list', kwargs={'application_uuid': self.application.uuid}) response = self.client.get(url, {'context': 'widget'}) self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(len(response.data), 1) - self.assertEqual(response.data[0]['context'], 'widget') + self.assertEqual(len(response.data['results']), 1) + self.assertEqual(response.data['results'][0]['context'], 'widget') def test_filter_by_capability(self): AppAIProvider.objects.create( @@ -92,8 +94,8 @@ def test_filter_by_capability(self): url = reverse('application-ai-providers-list', kwargs={'application_uuid': self.application.uuid}) response = self.client.get(url, {'capability': 'text'}) self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(len(response.data), 1) - self.assertEqual(response.data[0]['capability'], 'text') + self.assertEqual(len(response.data['results']), 1) + self.assertEqual(response.data['results'][0]['capability'], 'text') def test_update_app_ai_provider(self): config = AppAIProvider.objects.create( diff --git a/backend/core/views/ai_provider.py b/backend/core/views/ai_provider.py index 4bf723e..4d79852 100644 --- a/backend/core/views/ai_provider.py +++ b/backend/core/views/ai_provider.py @@ -25,8 +25,53 @@ def get_queryset(self): models.Q(creator=user) | models.Q(is_builtin=True) ) - def perform_create(self, serializer): - serializer.save(creator=self.request.user) + def create(self, request, *args, **kwargs): + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + + validated_data = serializer.validated_data + + from core.services.factories.ai_provider_factory import AIProviderFactory + + factory = AIProviderFactory() + try: + is_valid, models = factory.validate_provider( + provider_type=validated_data['provider'], + api_key=validated_data['provider_api_key'], + base_url=validated_data.get('base_url') + ) + + if not is_valid: + return Response( + { + 'error': 'Failed to validate AI provider connection', + 'details': 'Unable to connect to the AI provider with the provided credentials' + }, + status=status.HTTP_400_BAD_REQUEST + ) + + ai_provider = serializer.save() + + response_serializer = AIProviderSerializer(ai_provider) + return Response( + { + 'ai_provider': response_serializer.data, + 'validation': { + 'is_valid': True, + 'models': models + } + }, + status=status.HTTP_201_CREATED + ) + + except Exception as e: + return Response( + { + 'error': 'Failed to validate AI provider connection', + 'details': str(e) + }, + status=status.HTTP_400_BAD_REQUEST + ) def list(self, request, *args, **kwargs): response = super().list(request, *args, **kwargs) diff --git a/backend/core/views/app_ai_provider.py b/backend/core/views/app_ai_provider.py index 3cd350e..af99045 100644 --- a/backend/core/views/app_ai_provider.py +++ b/backend/core/views/app_ai_provider.py @@ -1,5 +1,4 @@ -from rest_framework import viewsets, status, permissions -from rest_framework.response import Response +from rest_framework import viewsets, permissions from django.shortcuts import get_object_or_404 from core.models.app_ai_provider import AppAIProvider diff --git a/frontend/components/AIProvider/NewAIProvider.vue b/frontend/components/AIProvider/NewAIProvider.vue index d01daac..141c5e5 100644 --- a/frontend/components/AIProvider/NewAIProvider.vue +++ b/frontend/components/AIProvider/NewAIProvider.vue @@ -6,7 +6,16 @@ +
+ + + {{ formError.error }} + +

{{ formError.details }}

+
+
+ @@ -111,12 +120,17 @@ import { z } from 'zod' import { useForm } from 'vee-validate' import { setBackendErrors } from '~/lib/utils' import { useUniqueName } from '~/composables/useUniqueName' -import { Sparkles } from 'lucide-vue-next' +import { Sparkles, AlertCircleIcon } from 'lucide-vue-next' const newAIProviderSlideOver = ref | null>(null) const AIProviderStore = useAIProviderStore() const { generateShortUniqueName } = useUniqueName() +const formError = ref<{ + error?: string + details?: string +} | null>(null) + const providerOptions = computed(() => AIProviderStore.supportedAIProviders.map(p => ({ label: p.label, value: p.id, baseUrl: p.base_url })) ) @@ -146,7 +160,6 @@ onMounted(async () => { const firstProvider = AIProviderStore.supportedAIProviders[0] setFieldValue('provider', firstProvider.id) setFieldValue('base_url', firstProvider.base_url) - // Auto-generate unique connection name const uniqueName = generateShortUniqueName('Connection') setFieldValue('name', uniqueName) } @@ -167,14 +180,30 @@ const generateUniqueConnectionName = () => { } const createNewAIProvider = form.handleSubmit(async (values) => { + formError.value = null + try { await AIProviderStore.create(values) newAIProviderSlideOver.value?.closeSlide() toast.success('AI provider created') } catch (error: unknown) { - const err = error as { errors?: Record } - if (err.errors) { - setBackendErrors(form, err.errors) + const err = error as { + errors?: Record | { error?: string; details?: string } + } + + if (err.errors && typeof err.errors === 'object' && 'error' in err.errors) { + const errorObj = err.errors as { error?: string; details?: string } + formError.value = { + error: errorObj.error, + details: errorObj.details + } + } else if (err.errors && typeof err.errors === 'object') { + setBackendErrors(form, err.errors as Record) + } else { + formError.value = { + error: 'Unexpected Error', + details: 'An unexpected error occurred while creating the AI provider' + } } } }) From 0c77eeb0ca5a5c136913a4b7eab9dde022ea67a5 Mon Sep 17 00:00:00 2001 From: Anish Ghimire Date: Thu, 5 Mar 2026 00:42:46 +0545 Subject: [PATCH 14/65] feat: make ai providers config dynamic and configurable --- .../0017_migrate_base_url_to_metadata.py | 27 +++++++++++ .../0018_migrate_base_url_to_metadata_data.py | 41 ++++++++++++++++ .../0019_remove_aiprovider_base_url.py | 17 +++++++ .../migrations/0020_aiprovider_base_url.py | 20 ++++++++ .../0021_remove_aiprovider_base_url.py | 17 +++++++ backend/core/models/ai_provider.py | 4 +- backend/core/models/app_ai_provider.py | 3 +- backend/core/serializers/ai_provider.py | 48 ++++++++++++++++++- backend/core/services/ai_service.py | 5 +- .../contracts/ai_provider_contract.py | 6 +-- .../services/factories/ai_provider_factory.py | 10 ++-- .../services/providers/ai/custom_provider.py | 30 +++++------- .../services/providers/ai/gemini_provider.py | 7 ++- backend/core/views/ai_provider.py | 12 ++++- .../components/AIProvider/NewAIProvider.vue | 3 +- 15 files changed, 211 insertions(+), 39 deletions(-) create mode 100644 backend/core/migrations/0017_migrate_base_url_to_metadata.py create mode 100644 backend/core/migrations/0018_migrate_base_url_to_metadata_data.py create mode 100644 backend/core/migrations/0019_remove_aiprovider_base_url.py create mode 100644 backend/core/migrations/0020_aiprovider_base_url.py create mode 100644 backend/core/migrations/0021_remove_aiprovider_base_url.py diff --git a/backend/core/migrations/0017_migrate_base_url_to_metadata.py b/backend/core/migrations/0017_migrate_base_url_to_metadata.py new file mode 100644 index 0000000..38f763d --- /dev/null +++ b/backend/core/migrations/0017_migrate_base_url_to_metadata.py @@ -0,0 +1,27 @@ +# Generated by Django 5.2.3 on 2026-03-04 17:34 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('core', '0016_appaiprovider'), + ] + + operations = [ + migrations.AlterModelOptions( + name='aiprovider', + options={'ordering': ['created_at']}, + ), + migrations.AddField( + model_name='appaiprovider', + name='metadata', + field=models.JSONField(blank=True, null=True), + ), + migrations.AlterField( + model_name='appaiprovider', + name='context', + field=models.CharField(max_length=50), + ), + ] diff --git a/backend/core/migrations/0018_migrate_base_url_to_metadata_data.py b/backend/core/migrations/0018_migrate_base_url_to_metadata_data.py new file mode 100644 index 0000000..35308d0 --- /dev/null +++ b/backend/core/migrations/0018_migrate_base_url_to_metadata_data.py @@ -0,0 +1,41 @@ +# Generated manually for AI Provider Configurability Enhancement + +from django.db import migrations, models + + +def migrate_base_url_to_metadata(apps, schema_editor): + """ + Migrate existing base_url values to metadata field for backward compatibility. + """ + AIProvider = apps.get_model('core', 'AIProvider') + + for provider in AIProvider.objects.all(): + if provider.base_url and not provider.metadata: + provider.metadata = {'base_url': provider.base_url} + provider.save() + + +def reverse_migrate_base_url_to_metadata(apps, schema_editor): + """ + Reverse migration: extract base_url from metadata if present. + """ + AIProvider = apps.get_model('core', 'AIProvider') + + for provider in AIProvider.objects.all(): + if provider.metadata and isinstance(provider.metadata, dict) and 'base_url' in provider.metadata: + provider.base_url = provider.metadata['base_url'] + provider.save() + + +class Migration(migrations.Migration): + + dependencies = [ + ('core', '0017_migrate_base_url_to_metadata'), + ] + + operations = [ + migrations.RunPython( + migrate_base_url_to_metadata, + reverse_migrate_base_url_to_metadata, + ), + ] diff --git a/backend/core/migrations/0019_remove_aiprovider_base_url.py b/backend/core/migrations/0019_remove_aiprovider_base_url.py new file mode 100644 index 0000000..d10b5a3 --- /dev/null +++ b/backend/core/migrations/0019_remove_aiprovider_base_url.py @@ -0,0 +1,17 @@ +# Generated manually for AI Provider Configurability Enhancement - Remove base_url field + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('core', '0018_migrate_base_url_to_metadata_data'), + ] + + operations = [ + migrations.RemoveField( + model_name='aiprovider', + name='base_url', + ), + ] diff --git a/backend/core/migrations/0020_aiprovider_base_url.py b/backend/core/migrations/0020_aiprovider_base_url.py new file mode 100644 index 0000000..855f23b --- /dev/null +++ b/backend/core/migrations/0020_aiprovider_base_url.py @@ -0,0 +1,20 @@ +# Generated by Django 5.2.3 on 2026-03-04 17:49 + +import django.utils.timezone +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('core', '0019_remove_aiprovider_base_url'), + ] + + operations = [ + migrations.AddField( + model_name='aiprovider', + name='base_url', + field=models.CharField(default=django.utils.timezone.now, max_length=100), + preserve_default=False, + ), + ] diff --git a/backend/core/migrations/0021_remove_aiprovider_base_url.py b/backend/core/migrations/0021_remove_aiprovider_base_url.py new file mode 100644 index 0000000..df607be --- /dev/null +++ b/backend/core/migrations/0021_remove_aiprovider_base_url.py @@ -0,0 +1,17 @@ +# Generated by Django 5.2.3 on 2026-03-04 17:50 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('core', '0020_aiprovider_base_url'), + ] + + operations = [ + migrations.RemoveField( + model_name='aiprovider', + name='base_url', + ), + ] diff --git a/backend/core/models/ai_provider.py b/backend/core/models/ai_provider.py index c12feef..4dd5571 100644 --- a/backend/core/models/ai_provider.py +++ b/backend/core/models/ai_provider.py @@ -9,9 +9,9 @@ class AIProvider(BaseModel): name = models.CharField(max_length=255, null=True, blank=True) provider = models.CharField(max_length=255) provider_api_key = EncryptedCharField(max_length=1000) - base_url = models.CharField(max_length=100) is_builtin = models.BooleanField(default=False, blank=True) - + metadata = models.JSONField(blank=True, null=True) + creator = models.ForeignKey(User, on_delete=models.CASCADE) class Meta: diff --git a/backend/core/models/app_ai_provider.py b/backend/core/models/app_ai_provider.py index 8c752a1..ab3f51a 100644 --- a/backend/core/models/app_ai_provider.py +++ b/backend/core/models/app_ai_provider.py @@ -21,7 +21,8 @@ class AppAIProvider(models.Model): capability = models.CharField(max_length=50, default='text') priority = models.PositiveIntegerField(default=100) external_model_id = models.CharField(max_length=255, blank=True, null=True) - + metadata = models.JSONField(blank=True, null=True) + is_active = models.BooleanField(default=True) created_at = models.DateTimeField(auto_now_add=True) diff --git a/backend/core/serializers/ai_provider.py b/backend/core/serializers/ai_provider.py index fb4e8e6..a81a7ab 100644 --- a/backend/core/serializers/ai_provider.py +++ b/backend/core/serializers/ai_provider.py @@ -8,9 +8,15 @@ class Meta: exclude = ['provider_api_key'] class AIProviderCreateSerializer(serializers.ModelSerializer): + base_url = serializers.CharField(max_length=500, required=False, allow_blank=True) + name = serializers.CharField(max_length=500, required=True, allow_blank=False) + class Meta: model = AIProvider - fields = ['uuid', 'name', 'provider', 'base_url', 'provider_api_key', 'creator'] + fields = [ + 'uuid', 'name', 'provider', 'provider_api_key', + 'base_url', 'creator' + ] read_only_fields = ['uuid', 'creator'] def __init__(self, *args, **kwargs): @@ -29,8 +35,32 @@ def validate_provider(self, value): ) return value + def validate(self, attrs): + provider = attrs.get('provider') + base_url = attrs.get('base_url', '') + + if provider == 'custom' and not base_url.strip(): + raise serializers.ValidationError({ + 'base_url': 'Custom provider requires base_url' + }) + + return attrs + def create(self, validated_data): + main_fields = ['name', 'provider', 'provider_api_key'] + + metadata = {} + for field, value in validated_data.items(): + if field not in main_fields: + metadata[field] = str(value).strip() if value is not None else '' + + for field in list(validated_data.keys()): + if field not in main_fields: + validated_data.pop(field) + + validated_data['metadata'] = metadata validated_data['creator'] = self.context['request'].user + return super().create(validated_data) def update(self, instance, validated_data): @@ -38,6 +68,22 @@ def update(self, instance, validated_data): if api_key and isinstance(api_key, str) and api_key.strip(): instance.provider_api_key = api_key + main_fields = ['name', 'provider'] + + new_metadata = {} + for field, value in validated_data.items(): + if field not in main_fields: + new_metadata[field] = str(value).strip() if value is not None else '' + + for field in list(validated_data.keys()): + if field not in main_fields: + validated_data.pop(field) + + existing_metadata = instance.metadata or {} + for field, value in new_metadata.items(): + existing_metadata[field] = value + validated_data['metadata'] = existing_metadata + for attr, value in validated_data.items(): setattr(instance, attr, value) diff --git a/backend/core/services/ai_service.py b/backend/core/services/ai_service.py index 0b69a5e..4b5fde4 100644 --- a/backend/core/services/ai_service.py +++ b/backend/core/services/ai_service.py @@ -1,5 +1,4 @@ -from typing import Optional, Any, Dict -from django.db.models import Q +from typing import Optional, Any from core.models import Application, AppAIProvider from .factories.ai_provider_factory import AIProviderFactory @@ -37,7 +36,7 @@ def get_provider_for_app(self, application: Application, context: str = 'widget' provider_instance = self.provider_factory.create_provider( provider_type=ai_provider.provider, api_key=ai_provider.provider_api_key, - base_url=ai_provider.base_url + config=ai_provider.metadata or {} ) return provider_instance diff --git a/backend/core/services/contracts/ai_provider_contract.py b/backend/core/services/contracts/ai_provider_contract.py index 50f188d..fa65055 100644 --- a/backend/core/services/contracts/ai_provider_contract.py +++ b/backend/core/services/contracts/ai_provider_contract.py @@ -1,11 +1,11 @@ from abc import ABC, abstractmethod -from typing import Optional +from typing import Optional, Dict, Any class AIProviderContract(ABC): - def __init__(self, api_key: str, base_url: Optional[str] = None): + def __init__(self, api_key: str, config: Optional[Dict[str, Any]] = None): self.api_key = api_key - self.base_url = base_url + self.config = config or {} @abstractmethod def generate_content(self, model: str, contents: str, **kwargs) -> str: diff --git a/backend/core/services/factories/ai_provider_factory.py b/backend/core/services/factories/ai_provider_factory.py index 5c1907b..7027e45 100644 --- a/backend/core/services/factories/ai_provider_factory.py +++ b/backend/core/services/factories/ai_provider_factory.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Dict, Any from ..contracts.ai_provider_contract import AIProviderContract from ..providers.ai.custom_provider import CustomProvider @@ -12,7 +12,7 @@ class AIProviderFactory: } @staticmethod - def create_provider(provider_type: str, api_key: str, base_url: Optional[str] = None) -> AIProviderContract: + def create_provider(provider_type: str, api_key: str, config: Optional[Dict[str, Any]] = None) -> AIProviderContract: provider_class = AIProviderFactory.PROVIDER_CLASSES.get(provider_type.lower()) if provider_class is None: @@ -23,12 +23,12 @@ def create_provider(provider_type: str, api_key: str, base_url: Optional[str] = ) try: - return provider_class(api_key=api_key, base_url=base_url) + return provider_class(api_key=api_key, config=config or {}) except Exception as e: raise ValueError(f"Failed to create {provider_type} provider: {e}") @staticmethod - def validate_provider(provider_type: str, api_key: str, base_url: Optional[str] = None) -> tuple[bool, list[str]]: + def validate_provider(provider_type: str, api_key: str, config: Optional[Dict[str, Any]] = None) -> tuple[bool, list[str]]: provider_class = AIProviderFactory.PROVIDER_CLASSES.get(provider_type.lower()) if provider_class is None: @@ -39,7 +39,7 @@ def validate_provider(provider_type: str, api_key: str, base_url: Optional[str] ) try: - provider = provider_class(api_key=api_key, base_url=base_url) + provider = provider_class(api_key=api_key, config=config or {}) return provider.validate_connection() except Exception as e: return False, [] diff --git a/backend/core/services/providers/ai/custom_provider.py b/backend/core/services/providers/ai/custom_provider.py index cc5c5db..6581047 100644 --- a/backend/core/services/providers/ai/custom_provider.py +++ b/backend/core/services/providers/ai/custom_provider.py @@ -1,26 +1,20 @@ -from typing import Optional - +import json +import requests +from typing import Optional, Dict, Any, List from ...contracts.ai_provider_contract import AIProviderContract class CustomProvider(AIProviderContract): - def __init__(self, api_key: str, base_url: Optional[str] = None): - super().__init__(api_key, base_url) + def __init__(self, api_key: str, config: Optional[Dict[str, Any]] = None): + super().__init__(api_key, config) + + raise NotImplementedError("Not implemented") def generate_content(self, model: str, contents: str, **kwargs) -> str: - raise NotImplementedError( - "Custom provider generate_content method not implemented. " - "Please implement this method in your custom provider class." - ) + raise NotImplementedError("Not implemented") - def validate_connection(self) -> tuple[bool, list[str]]: - raise NotImplementedError( - "Custom provider validate_connection method not implemented. " - "Please implement this method in your custom provider class." - ) + def validate_connection(self) -> tuple[bool, List[str]]: + raise NotImplementedError("Not implemented") - def get_models(self) -> list[str]: - raise NotImplementedError( - "Custom provider get_supported_models method not implemented. " - "Please implement this method in your custom provider class." - ) + def get_models(self) -> List[str]: + raise NotImplementedError("Not implemented") diff --git a/backend/core/services/providers/ai/gemini_provider.py b/backend/core/services/providers/ai/gemini_provider.py index cf90056..eef2112 100644 --- a/backend/core/services/providers/ai/gemini_provider.py +++ b/backend/core/services/providers/ai/gemini_provider.py @@ -1,8 +1,7 @@ -from typing import Optional +from typing import Optional, Dict, Any from google import genai from ...contracts.ai_provider_contract import AIProviderContract - class GeminiProvider(AIProviderContract): SUPPORTED_MODELS = [ 'gemini-1.5-pro', @@ -11,8 +10,8 @@ class GeminiProvider(AIProviderContract): 'gemini-pro-vision' ] - def __init__(self, api_key: str, base_url: Optional[str] = None): - super().__init__(api_key, base_url) + def __init__(self, api_key: str, config: Optional[Dict[str, Any]] = None): + super().__init__(api_key, config) try: self.client = genai.Client(api_key=api_key) diff --git a/backend/core/views/ai_provider.py b/backend/core/views/ai_provider.py index 4d79852..879ffa3 100644 --- a/backend/core/views/ai_provider.py +++ b/backend/core/views/ai_provider.py @@ -35,10 +35,20 @@ def create(self, request, *args, **kwargs): factory = AIProviderFactory() try: + main_fields = ['name', 'provider', 'provider_api_key'] + config = {} + + for field, value in validated_data.items(): + if field not in main_fields: + if field == 'timeout': + config[field] = int(value) if value is not None else None + else: + config[field] = str(value).strip() if value is not None else '' + is_valid, models = factory.validate_provider( provider_type=validated_data['provider'], api_key=validated_data['provider_api_key'], - base_url=validated_data.get('base_url') + config=config ) if not is_valid: diff --git a/frontend/components/AIProvider/NewAIProvider.vue b/frontend/components/AIProvider/NewAIProvider.vue index 141c5e5..4fd6075 100644 --- a/frontend/components/AIProvider/NewAIProvider.vue +++ b/frontend/components/AIProvider/NewAIProvider.vue @@ -138,6 +138,7 @@ const schema = z.object({ name: z.string().nonempty({ message: 'Required' }).min(1).max(255), base_url: z .string() + .url() .nonempty({ message: 'Required' }), provider: z.string().nonempty({ message: 'Required' }), provider_api_key: z.string().nonempty({ message: 'Required' }), @@ -149,7 +150,7 @@ const form = useForm({ name: '', base_url: '', provider_api_key: '', - provider: '' + provider: '', } }) const { isSubmitting, setFieldValue } = form From 58814306b74383dc1c09de066023630e08902a69 Mon Sep 17 00:00:00 2001 From: Anish Ghimire Date: Thu, 5 Mar 2026 01:13:53 +0545 Subject: [PATCH 15/65] feat: extract and merge additional fields to metadata --- backend/core/serializers/ai_provider.py | 30 ++++--------------------- backend/core/utils.py | 24 ++++++++++++++++++++ 2 files changed, 28 insertions(+), 26 deletions(-) diff --git a/backend/core/serializers/ai_provider.py b/backend/core/serializers/ai_provider.py index a81a7ab..11edc3f 100644 --- a/backend/core/serializers/ai_provider.py +++ b/backend/core/serializers/ai_provider.py @@ -1,6 +1,7 @@ from rest_framework import serializers from core.models.ai_provider import AIProvider from core.consts import SUPPORTED_AI_PROVIDERS +from core.utils import extract_and_merge_fields class AIProviderSerializer(serializers.ModelSerializer): class Meta: @@ -47,17 +48,7 @@ def validate(self, attrs): return attrs def create(self, validated_data): - main_fields = ['name', 'provider', 'provider_api_key'] - - metadata = {} - for field, value in validated_data.items(): - if field not in main_fields: - metadata[field] = str(value).strip() if value is not None else '' - - for field in list(validated_data.keys()): - if field not in main_fields: - validated_data.pop(field) - + metadata = extract_and_merge_fields(validated_data, ['name', 'provider', 'provider_api_key']) validated_data['metadata'] = metadata validated_data['creator'] = self.context['request'].user @@ -68,21 +59,8 @@ def update(self, instance, validated_data): if api_key and isinstance(api_key, str) and api_key.strip(): instance.provider_api_key = api_key - main_fields = ['name', 'provider'] - - new_metadata = {} - for field, value in validated_data.items(): - if field not in main_fields: - new_metadata[field] = str(value).strip() if value is not None else '' - - for field in list(validated_data.keys()): - if field not in main_fields: - validated_data.pop(field) - - existing_metadata = instance.metadata or {} - for field, value in new_metadata.items(): - existing_metadata[field] = value - validated_data['metadata'] = existing_metadata + metadata = extract_and_merge_fields(validated_data, ['name', 'provider'], instance.metadata or {}) + validated_data['metadata'] = metadata for attr, value in validated_data.items(): setattr(instance, attr, value) diff --git a/backend/core/utils.py b/backend/core/utils.py index 15087c7..7e084b9 100644 --- a/backend/core/utils.py +++ b/backend/core/utils.py @@ -1,5 +1,6 @@ import json import re +from typing import Any, Callable, Dict, List, Optional, Union def parse_llm_response(content: str) -> dict: try: @@ -12,3 +13,26 @@ def parse_llm_response(content: str) -> dict: except json.JSONDecodeError as e: print("Failed to parse LLM response as JSON:", e) raise + + +def extract_and_merge_fields( + validated_data: Dict[str, Any], + field_selector: Union[List[str], Callable[[str], bool]], + existing_data: Optional[Dict[str, Any]] = None, + merge: bool = True +) -> Dict[str, Any]: + extracted_data = existing_data.copy() if existing_data and merge else {} + + if isinstance(field_selector, list): + main_fields = set(field_selector) + fields_to_extract = [field for field in validated_data.keys() if field not in main_fields] + elif callable(field_selector): + fields_to_extract = [field for field in validated_data.keys() if field_selector(field)] + else: + raise ValueError("field_selector must be a list of field names or a callable") + + for field in fields_to_extract: + value = validated_data.pop(field) + extracted_data[field] = str(value).strip() if value is not None else '' + + return extracted_data From 1dc85f645fb1da466d7abcfbb185d60fd4c0aeeb Mon Sep 17 00:00:00 2001 From: Anish Ghimire Date: Fri, 6 Mar 2026 00:11:06 +0545 Subject: [PATCH 16/65] feat: integrate AI providers UI and refactored backend as needed --- backend/core/serializers/ai_provider.py | 14 +- backend/core/tests/factories.py | 2 +- backend/core/tests/test_ai_provider.py | 57 ++++-- backend/core/tests/test_app_ai_provider.py | 6 +- backend/core/views/ai_provider.py | 106 ++++++++++-- backend/core/views/app_ai_provider.py | 11 +- .../components/AIProvider/NewAIProvider.vue | 94 +++++----- .../AIProvider/UpdateAIProvider.vue | 150 ++++++++++------ frontend/components/C8APIAlert.vue | 23 +++ frontend/components/C8Dialog.vue | 81 +++++++++ frontend/components/C8Item.vue | 69 ++++++++ frontend/components/C8Select.vue | 32 +++- frontend/components/icons/GeminiIcon.vue | 162 ++++++++++++++++++ .../ui/alert-dialog/AlertDialog.vue | 15 ++ .../ui/alert-dialog/AlertDialogAction.vue | 18 ++ .../ui/alert-dialog/AlertDialogCancel.vue | 25 +++ .../ui/alert-dialog/AlertDialogContent.vue | 38 ++++ .../alert-dialog/AlertDialogDescription.vue | 22 +++ .../ui/alert-dialog/AlertDialogFooter.vue | 21 +++ .../ui/alert-dialog/AlertDialogHeader.vue | 16 ++ .../ui/alert-dialog/AlertDialogTitle.vue | 20 +++ .../ui/alert-dialog/AlertDialogTrigger.vue | 12 ++ frontend/components/ui/alert-dialog/index.ts | 9 + frontend/components/ui/item/Item.vue | 27 +++ frontend/components/ui/item/ItemActions.vue | 17 ++ frontend/components/ui/item/ItemContent.vue | 17 ++ .../components/ui/item/ItemDescription.vue | 21 +++ frontend/components/ui/item/ItemFooter.vue | 17 ++ frontend/components/ui/item/ItemGroup.vue | 18 ++ frontend/components/ui/item/ItemHeader.vue | 17 ++ frontend/components/ui/item/ItemMedia.vue | 21 +++ frontend/components/ui/item/ItemSeparator.vue | 18 ++ frontend/components/ui/item/ItemTitle.vue | 17 ++ frontend/components/ui/item/index.ts | 54 ++++++ .../components/ui/separator/Separator.vue | 19 +- frontend/composables/useAIProviderIcon.ts | 22 +++ frontend/composables/useApiErrorHandling.ts | 47 +++++ frontend/package-lock.json | 50 +++--- frontend/package.json | 2 +- frontend/pages/settings/ai-providers.vue | 115 ++++++++----- frontend/stores/aiProvider.ts | 27 +-- 41 files changed, 1279 insertions(+), 250 deletions(-) create mode 100644 frontend/components/C8APIAlert.vue create mode 100644 frontend/components/C8Dialog.vue create mode 100644 frontend/components/C8Item.vue create mode 100644 frontend/components/icons/GeminiIcon.vue create mode 100644 frontend/components/ui/alert-dialog/AlertDialog.vue create mode 100644 frontend/components/ui/alert-dialog/AlertDialogAction.vue create mode 100644 frontend/components/ui/alert-dialog/AlertDialogCancel.vue create mode 100644 frontend/components/ui/alert-dialog/AlertDialogContent.vue create mode 100644 frontend/components/ui/alert-dialog/AlertDialogDescription.vue create mode 100644 frontend/components/ui/alert-dialog/AlertDialogFooter.vue create mode 100644 frontend/components/ui/alert-dialog/AlertDialogHeader.vue create mode 100644 frontend/components/ui/alert-dialog/AlertDialogTitle.vue create mode 100644 frontend/components/ui/alert-dialog/AlertDialogTrigger.vue create mode 100644 frontend/components/ui/alert-dialog/index.ts create mode 100644 frontend/components/ui/item/Item.vue create mode 100644 frontend/components/ui/item/ItemActions.vue create mode 100644 frontend/components/ui/item/ItemContent.vue create mode 100644 frontend/components/ui/item/ItemDescription.vue create mode 100644 frontend/components/ui/item/ItemFooter.vue create mode 100644 frontend/components/ui/item/ItemGroup.vue create mode 100644 frontend/components/ui/item/ItemHeader.vue create mode 100644 frontend/components/ui/item/ItemMedia.vue create mode 100644 frontend/components/ui/item/ItemSeparator.vue create mode 100644 frontend/components/ui/item/ItemTitle.vue create mode 100644 frontend/components/ui/item/index.ts create mode 100644 frontend/composables/useAIProviderIcon.ts create mode 100644 frontend/composables/useApiErrorHandling.ts diff --git a/backend/core/serializers/ai_provider.py b/backend/core/serializers/ai_provider.py index 11edc3f..4323653 100644 --- a/backend/core/serializers/ai_provider.py +++ b/backend/core/serializers/ai_provider.py @@ -37,12 +37,18 @@ def validate_provider(self, value): return value def validate(self, attrs): - provider = attrs.get('provider') - base_url = attrs.get('base_url', '') + provider = attrs.get('provider') or (self.instance.provider if self.instance else None) + base_url = attrs.get('base_url') - if provider == 'custom' and not base_url.strip(): + if self.instance is None and provider == 'custom': + if base_url is None or not base_url.strip(): + raise serializers.ValidationError({ + 'base_url': 'Custom provider requires base url' + }) + + if self.instance is not None and base_url is not None and provider == 'custom' and not base_url.strip(): raise serializers.ValidationError({ - 'base_url': 'Custom provider requires base_url' + 'base_url': 'Custom provider requires base url' }) return attrs diff --git a/backend/core/tests/factories.py b/backend/core/tests/factories.py index e879d0e..915b582 100644 --- a/backend/core/tests/factories.py +++ b/backend/core/tests/factories.py @@ -30,6 +30,6 @@ class Meta: name = factory.Faker('company') provider = factory.Iterator(['openai', 'anthropic', 'google']) provider_api_key = factory.Faker('password') - base_url = factory.Faker('url') + metadata = factory.LazyAttribute(lambda obj: {'base_url': 'https://example.com'}) is_builtin = False creator = factory.SubFactory(UserFactory) diff --git a/backend/core/tests/test_ai_provider.py b/backend/core/tests/test_ai_provider.py index 38b8bc8..d10de67 100644 --- a/backend/core/tests/test_ai_provider.py +++ b/backend/core/tests/test_ai_provider.py @@ -1,11 +1,9 @@ import pytest from rest_framework import status -from rest_framework.test import APIClient -from django.contrib.auth.models import AnonymousUser +from unittest.mock import patch from core.models import AIProvider from core.serializers.ai_provider import AIProviderCreateSerializer -from core.consts import SUPPORTED_AI_PROVIDERS from core.tests.conftest import BaseAPITestCase from core.tests.factories import UserFactory, AIProviderFactory @@ -42,7 +40,7 @@ def test_list_ai_providers_authenticated_user(self): self.assertNotIn("Other User Provider", provider_names) provider_data = data['results'][0] - expected_fields = ['id', 'uuid', 'name', 'provider', 'base_url', 'is_builtin', 'creator', 'created_at', 'updated_at'] + expected_fields = ['id', 'uuid', 'name', 'provider', 'is_builtin', 'creator', 'created_at', 'updated_at', 'metadata'] for field in expected_fields: self.assertIn(field, provider_data) @@ -135,8 +133,11 @@ def test_delete_other_users_provider(self): self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) - def test_create_ai_provider(self): + @patch('core.services.factories.ai_provider_factory.AIProviderFactory.validate_provider') + def test_create_ai_provider(self, mock_validate): """Test that authenticated user can create their own AI provider.""" + mock_validate.return_value = (True, ['gemini-1.5-pro', 'gemini-1.5-flash']) + user = UserFactory() self.client.force_authenticate(user=user) @@ -152,17 +153,23 @@ def test_create_ai_provider(self): self.assertEqual(response.status_code, status.HTTP_201_CREATED) data = response.json() - self.assertEqual(data['name'], 'My Gemini Provider') - self.assertEqual(data['provider'], 'gemini') - self.assertEqual(data['base_url'], 'https://generativelanguage.googleapis.com') - self.assertEqual(data['creator'], user.id) + self.assertEqual(data['ai_provider']['name'], 'My Gemini Provider') + self.assertEqual(data['ai_provider']['provider'], 'gemini') + self.assertEqual(data['ai_provider']['metadata']['base_url'], 'https://generativelanguage.googleapis.com') + self.assertEqual(data['ai_provider']['creator'], user.id) + self.assertTrue(data['validation']['is_valid']) + self.assertEqual(data['validation']['models'], ['gemini-1.5-pro', 'gemini-1.5-flash']) - provider = AIProvider.objects.get(uuid=data['uuid']) + provider = AIProvider.objects.get(uuid=data['ai_provider']['uuid']) self.assertEqual(provider.creator, user) self.assertEqual(provider.name, 'My Gemini Provider') + self.assertEqual(provider.metadata['base_url'], 'https://generativelanguage.googleapis.com') - def test_update_own_provider(self): + @patch('core.services.factories.ai_provider_factory.AIProviderFactory.validate_provider') + def test_update_own_provider(self, mock_validate): """Test that authenticated user can update their own AI provider.""" + mock_validate.return_value = (True, ['gemini-1.5-pro', 'gemini-1.5-flash']) + user = UserFactory() self.client.force_authenticate(user=user) @@ -211,7 +218,7 @@ def test_delete_own_provider(self): detail_url = f'/api/ai-providers/{provider.uuid}/' response = self.client.delete(detail_url) - self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + self.assertEqual(response.status_code, status.HTTP_200_OK) with self.assertRaises(AIProvider.DoesNotExist): AIProvider.objects.get(id=provider.id) @@ -234,8 +241,11 @@ def test_update_without_api_key_does_not_change_api_key(self): self.assertEqual(provider.name, 'Updated Name') self.assertEqual(provider.provider_api_key, original_api_key) - def test_update_with_api_key_changes_api_key(self): + @patch('core.services.factories.ai_provider_factory.AIProviderFactory.validate_provider') + def test_update_with_api_key_changes_api_key(self, mock_validate): """Test that update request with provider api key updates the provider api key.""" + mock_validate.return_value = (True, ['gemini-1.5-pro', 'gemini-1.5-flash']) + user = UserFactory() self.client.force_authenticate(user=user) @@ -290,8 +300,11 @@ def test_update_with_whitespace_api_key_does_not_change_api_key(self): self.assertEqual(provider.name, 'Updated Name') self.assertEqual(provider.provider_api_key, original_api_key) - def test_api_key_is_encrypted_in_database(self): + @patch('core.services.factories.ai_provider_factory.AIProviderFactory.validate_provider') + def test_api_key_is_encrypted_in_database(self, mock_validate): """Test that provider_api_key is encrypted when stored in the database.""" + mock_validate.return_value = (True, ['gemini-1.5-pro', 'gemini-1.5-flash']) + user = UserFactory() self.client.force_authenticate(user=user) @@ -306,7 +319,7 @@ def test_api_key_is_encrypted_in_database(self): response = self.client.post(self.list_url, create_data, format='json') self.assertEqual(response.status_code, status.HTTP_201_CREATED) - provider = AIProvider.objects.get(uuid=response.json()['uuid']) + provider = AIProvider.objects.get(uuid=response.json()['ai_provider']['uuid']) self.assertEqual(provider.provider_api_key, api_key) @@ -318,8 +331,11 @@ def test_api_key_is_encrypted_in_database(self): self.assertTrue(raw_db_value.startswith('gAAAAA')) - def test_create_with_supported_provider_gemini(self): + @patch('core.services.factories.ai_provider_factory.AIProviderFactory.validate_provider') + def test_create_with_supported_provider_gemini(self, mock_validate): """Test that AI provider can be created with supported 'gemini' provider.""" + mock_validate.return_value = (True, ['gemini-1.5-pro', 'gemini-1.5-flash']) + user = UserFactory() data = { 'name': 'My Google Gemini Provider', @@ -335,11 +351,14 @@ def test_create_with_supported_provider_gemini(self): assert provider.name == 'My Google Gemini Provider' assert provider.provider == 'gemini' - assert provider.base_url == 'https://generativelanguage.googleapis.com' + assert provider.metadata['base_url'] == 'https://generativelanguage.googleapis.com' assert provider.creator == user - def test_create_with_supported_provider_custom(self): + @patch('core.services.factories.ai_provider_factory.AIProviderFactory.validate_provider') + def test_create_with_supported_provider_custom(self, mock_validate): """Test that AI provider can be created with supported 'custom' provider.""" + mock_validate.return_value = (True, ['custom-model-1', 'custom-model-2']) + user = UserFactory() data = { 'name': 'My Custom Provider', @@ -355,7 +374,7 @@ def test_create_with_supported_provider_custom(self): assert provider.name == 'My Custom Provider' assert provider.provider == 'custom' - assert provider.base_url == 'https://my-custom-api.com' + assert provider.metadata['base_url'] == 'https://my-custom-api.com' assert provider.creator == user def test_create_with_unsupported_provider_fails(self): diff --git a/backend/core/tests/test_app_ai_provider.py b/backend/core/tests/test_app_ai_provider.py index 6605fee..a0ad0cd 100644 --- a/backend/core/tests/test_app_ai_provider.py +++ b/backend/core/tests/test_app_ai_provider.py @@ -14,8 +14,8 @@ def setUp(self): self.ai_provider = AIProvider.objects.create( name='Test Provider', provider='openai', - base_url='https://api.openai.com', provider_api_key='test-key', + metadata={'base_url': 'https://api.openai.com'}, creator=self.user ) @@ -129,7 +129,7 @@ def test_delete_app_ai_provider(self): 'uuid': config.uuid }) response = self.client.delete(url) - self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertFalse(AppAIProvider.objects.filter(id=config.id).exists()) def test_priority_auto_assignment(self): @@ -163,8 +163,8 @@ def test_unauthorized_access(self): other_ai_provider = AIProvider.objects.create( name='Other AI Provider', provider='openai', - base_url='https://api.openai.com', provider_api_key='test', + metadata={'base_url': 'https://api.openai.com'}, creator=other_user ) other_config = AppAIProvider.objects.create( diff --git a/backend/core/views/ai_provider.py b/backend/core/views/ai_provider.py index 879ffa3..ffb8908 100644 --- a/backend/core/views/ai_provider.py +++ b/backend/core/views/ai_provider.py @@ -25,31 +25,48 @@ def get_queryset(self): models.Q(creator=user) | models.Q(is_builtin=True) ) + def _validate_ai_provider(self, validated_data, instance=None): + from core.services.factories.ai_provider_factory import AIProviderFactory + + factory = AIProviderFactory() + main_fields = ['name', 'provider', 'provider_api_key'] + config = {} + + if instance: + current_data = { + 'name': instance.name, + 'provider': instance.provider, + 'provider_api_key': instance.provider_api_key + } + if instance.metadata: + config.update(instance.metadata) + update_data = {**current_data, **validated_data} + if not update_data['provider_api_key']: + update_data['provider_api_key'] = instance.provider_api_key + validation_data = update_data + else: + validation_data = validated_data + + for field, value in validation_data.items(): + if field not in main_fields: + config[field] = str(value).strip() if value is not None else '' + + is_valid, models = factory.validate_provider( + provider_type=validation_data['provider'], + api_key=validation_data['provider_api_key'], + config=config + ) + + return is_valid, models + def create(self, request, *args, **kwargs): serializer = self.get_serializer(data=request.data) serializer.is_valid(raise_exception=True) validated_data = serializer.validated_data - from core.services.factories.ai_provider_factory import AIProviderFactory - - factory = AIProviderFactory() try: - main_fields = ['name', 'provider', 'provider_api_key'] - config = {} - - for field, value in validated_data.items(): - if field not in main_fields: - if field == 'timeout': - config[field] = int(value) if value is not None else None - else: - config[field] = str(value).strip() if value is not None else '' - - is_valid, models = factory.validate_provider( - provider_type=validated_data['provider'], - api_key=validated_data['provider_api_key'], - config=config - ) + is_valid, models = self._validate_ai_provider(validated_data) if not is_valid: return Response( @@ -93,3 +110,56 @@ def list(self, request, *args, **kwargs): 'supported_ai_providers': SUPPORTED_AI_PROVIDERS } return response + + def update(self, request, *args, **kwargs): + partial = kwargs.pop('partial', False) + instance = self.get_object() + serializer = self.get_serializer(instance, data=request.data, partial=partial) + serializer.is_valid(raise_exception=True) + + validated_data = serializer.validated_data + + api_key_to_validate = validated_data.get('provider_api_key') or instance.provider_api_key + + if api_key_to_validate and api_key_to_validate.strip(): + try: + is_valid, models = self._validate_ai_provider(validated_data, instance) + + if not is_valid: + return Response( + { + 'error': 'Failed to validate AI provider connection', + 'details': 'Unable to connect to the AI provider with the provided credentials' + }, + status=status.HTTP_400_BAD_REQUEST + ) + + except Exception as e: + return Response( + { + 'error': 'Failed to validate AI provider connection', + 'details': str(e) + }, + status=status.HTTP_400_BAD_REQUEST + ) + else: + return Response( + { + 'error': 'API key is required', + 'details': 'An API key must be provided to validate the AI provider connection' + }, + status=status.HTTP_400_BAD_REQUEST + ) + + updated_instance = serializer.save() + + response_serializer = AIProviderSerializer(updated_instance) + return Response(response_serializer.data) + + def destroy(self, request, *args, **kwargs): + instance = self.get_object() + self.perform_destroy(instance) + return Response( + {"detail": "deleted"}, + status=status.HTTP_200_OK + ) diff --git a/backend/core/views/app_ai_provider.py b/backend/core/views/app_ai_provider.py index af99045..fb4cfab 100644 --- a/backend/core/views/app_ai_provider.py +++ b/backend/core/views/app_ai_provider.py @@ -1,4 +1,5 @@ -from rest_framework import viewsets, permissions +from rest_framework import viewsets, permissions, status +from rest_framework.response import Response from django.shortcuts import get_object_or_404 from core.models.app_ai_provider import AppAIProvider @@ -52,3 +53,11 @@ def get_serializer_context(self): def perform_create(self, serializer): serializer.save() + + def destroy(self, request, *args, **kwargs): + instance = self.get_object() + self.perform_destroy(instance) + return Response( + {"detail": "deleted"}, + status=status.HTTP_200_OK + ) diff --git a/frontend/components/AIProvider/NewAIProvider.vue b/frontend/components/AIProvider/NewAIProvider.vue index 4fd6075..da2a78a 100644 --- a/frontend/components/AIProvider/NewAIProvider.vue +++ b/frontend/components/AIProvider/NewAIProvider.vue @@ -1,6 +1,6 @@