diff --git a/.gitignore b/.gitignore index 2e2ba01e5..fa384fa63 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ -# Logs -logs +# LOGS + +logs/ *.log npm-debug.log* yarn-debug.log* @@ -7,10 +8,10 @@ yarn-error.log* pnpm-debug.log* lerna-debug.log* -# Editor directories and files +# EDITORS / OS .vscode/* !.vscode/extensions.json -.idea +.idea/ .DS_Store *.suo *.ntvs* @@ -18,16 +19,31 @@ lerna-debug.log* *.sln *.sw? -# Client -client/node_modules -client/dist -client/dist-ssr -client/*.local +# NODE / NEXT.JS +node_modules/ +.next/ +out/ + +# PYTHON +venv/ +__pycache__/ +*.pyc + +# ENV FILES +.env +.env.local + +# FRONTEND + +frontend/node_modules/ +frontend/.next/ + -# Server -server/.env -server/node_modules -server/dist +# BACKEND +backend/venv/ +backend/__pycache__/ +backend/uploads/ -# prompts +# TEMP / NOTES +*.txt prompts/debug diff --git a/backend/.gitignore b/backend/.gitignore index 619b62572..aa0a3160c 100644 --- a/backend/.gitignore +++ b/backend/.gitignore @@ -26,6 +26,7 @@ share/python-wheels/ *.egg MANIFEST + # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. diff --git a/backend/authentication/migrations/0002_customuser_role_alter_customuser_is_active.py b/backend/authentication/migrations/0002_customuser_role_alter_customuser_is_active.py new file mode 100644 index 000000000..24c02f939 --- /dev/null +++ b/backend/authentication/migrations/0002_customuser_role_alter_customuser_is_active.py @@ -0,0 +1,30 @@ +# Generated by Django 5.0.2 on 2026-01-29 18:22 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("authentication", "0001_initial"), + ("chat", "0004_uploadedfile_uploaded_by"), + ] + + operations = [ + migrations.AddField( + model_name="customuser", + name="role", + field=models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="users", + to="chat.role", + ), + ), + migrations.AlterField( + model_name="customuser", + name="is_active", + field=models.BooleanField(default=True), + ), + ] diff --git a/backend/authentication/models.py b/backend/authentication/models.py index 4a565e6cd..60a413594 100644 --- a/backend/authentication/models.py +++ b/backend/authentication/models.py @@ -1,4 +1,8 @@ -from django.contrib.auth.models import AbstractBaseUser, BaseUserManager, PermissionsMixin +from django.contrib.auth.models import ( + AbstractBaseUser, + BaseUserManager, + PermissionsMixin, +) from django.db import models @@ -17,8 +21,8 @@ def create_user(self, email, password, **extra_fields): email = self.normalize_email(email) user = self.model(email=email, **extra_fields) user.set_password(password) + user.is_active = True user.save(using=self._db) - return user def create_superuser(self, email, password, **extra_fields): @@ -31,9 +35,20 @@ def create_superuser(self, email, password, **extra_fields): class CustomUser(AbstractBaseUser, PermissionsMixin): email = models.EmailField(unique=True) - is_active = models.BooleanField(default=False) + + # 🔐 Django auth flags + is_active = models.BooleanField(default=True) is_staff = models.BooleanField(default=False) + # 🔑 ROLE-BASED ACCESS CONTROL + role = models.ForeignKey( + "chat.Role", + on_delete=models.SET_NULL, + null=True, + blank=True, + related_name="users", + ) + objects = CustomUserManager() USERNAME_FIELD = "email" diff --git a/backend/authentication/views.py b/backend/authentication/views.py index 068100805..b9da9c1a7 100644 --- a/backend/authentication/views.py +++ b/backend/authentication/views.py @@ -1,79 +1,192 @@ +# from django.conf import settings +# from django.contrib.auth import authenticate, login, logout +# from django.http import JsonResponse +# from django.middleware.csrf import get_token +# from rest_framework import status +# from rest_framework.decorators import api_view + +# from authentication.models import CustomUser + + +# @api_view(["GET"]) +# def auth_root_view(request): +# return JsonResponse({"message": "Auth endpoint works!"}) + + +# @api_view(["GET"]) +# def csrf_token(request): +# token = get_token(request) +# return JsonResponse({"data": token}) + + +# @api_view(["POST"]) +# def login_view(request): +# email = request.data.get("email") +# password = request.data.get("password") + +# try: +# user = CustomUser.objects.get(email=email) +# except CustomUser.DoesNotExist: +# return JsonResponse({"error": "Invalid credentials"}, status=status.HTTP_401_UNAUTHORIZED) + +# # Check if the user is active +# if not user.is_active: +# return JsonResponse({"error": "User is not active"}, status=status.HTTP_401_UNAUTHORIZED) + +# user = authenticate(request, email=email, password=password) +# if user is not None: +# login(request, user) +# response = JsonResponse({"data": "Login successful"}) + +# # Set session cookie manually +# session_key = request.session.session_key +# session_cookie_name = settings.SESSION_COOKIE_NAME +# max_age = settings.SESSION_COOKIE_AGE +# response.set_cookie(session_cookie_name, session_key, max_age=max_age, httponly=True) + +# return response +# else: +# return JsonResponse({"error": "Invalid credentials"}, status=status.HTTP_401_UNAUTHORIZED) + + +# @api_view(["POST"]) +# def logout_view(request): +# logout(request) +# response = JsonResponse({"data": "Logout successful"}) +# response.delete_cookie(settings.SESSION_COOKIE_NAME) + +# return response + + +# @api_view(["POST"]) +# def register_view(request): +# email = request.data.get("email") +# password = request.data.get("password") +# if not email or not password: +# return JsonResponse({"error": "Email and password are required"}, status=status.HTTP_400_BAD_REQUEST) + +# if CustomUser.objects.filter(email=email).exists(): +# return JsonResponse({"error": "Email is already taken"}, status=status.HTTP_400_BAD_REQUEST) + +# CustomUser.objects.create_user(email, password=password) +# return JsonResponse({"data": "User created successfully"}, status=status.HTTP_201_CREATED) + + +# @api_view(["GET"]) +# def verify_session(request): +# session_cookie = request.COOKIES.get("sessionid") +# is_authenticated = request.user.is_authenticated and session_cookie == request.session.session_key +# return JsonResponse({"data": is_authenticated}) from django.conf import settings from django.contrib.auth import authenticate, login, logout from django.http import JsonResponse from django.middleware.csrf import get_token +from django.views.decorators.csrf import csrf_exempt + from rest_framework import status -from rest_framework.decorators import api_view +from rest_framework.decorators import ( + api_view, + permission_classes, + authentication_classes, +) +from rest_framework.permissions import AllowAny, IsAuthenticated from authentication.models import CustomUser @api_view(["GET"]) +@permission_classes([AllowAny]) def auth_root_view(request): return JsonResponse({"message": "Auth endpoint works!"}) @api_view(["GET"]) +@permission_classes([AllowAny]) def csrf_token(request): - token = get_token(request) - return JsonResponse({"data": token}) + return JsonResponse({"csrfToken": get_token(request)}) +# --------------------------- +# LOGIN (CSRF FIXED) +# --------------------------- +@csrf_exempt @api_view(["POST"]) +@authentication_classes([]) # 🔴 disable DRF SessionAuthentication +@permission_classes([AllowAny]) def login_view(request): email = request.data.get("email") password = request.data.get("password") - try: - user = CustomUser.objects.get(email=email) - except CustomUser.DoesNotExist: - return JsonResponse({"error": "Invalid credentials"}, status=status.HTTP_401_UNAUTHORIZED) - - # Check if the user is active - if not user.is_active: - return JsonResponse({"error": "User is not active"}, status=status.HTTP_401_UNAUTHORIZED) + if not email or not password: + return JsonResponse( + {"error": "Email and password required"}, + status=status.HTTP_400_BAD_REQUEST, + ) user = authenticate(request, email=email, password=password) - if user is not None: - login(request, user) - response = JsonResponse({"data": "Login successful"}) + if not user: + return JsonResponse( + {"error": "Invalid credentials"}, + status=status.HTTP_401_UNAUTHORIZED, + ) - # Set session cookie manually - session_key = request.session.session_key - session_cookie_name = settings.SESSION_COOKIE_NAME - max_age = settings.SESSION_COOKIE_AGE - response.set_cookie(session_cookie_name, session_key, max_age=max_age, httponly=True) - - return response - else: - return JsonResponse({"error": "Invalid credentials"}, status=status.HTTP_401_UNAUTHORIZED) + if not user.is_active: + return JsonResponse( + {"error": "User inactive"}, + status=status.HTTP_403_FORBIDDEN, + ) + + login(request, user) + + response = JsonResponse({"message": "Login successful"}) + response.set_cookie( + settings.SESSION_COOKIE_NAME, + request.session.session_key, + httponly=True, + ) + return response +# --------------------------- +# LOGOUT (CSRF FIXED) +# --------------------------- +@csrf_exempt @api_view(["POST"]) +@authentication_classes([]) # 🔴 disable DRF SessionAuthentication +@permission_classes([AllowAny]) def logout_view(request): logout(request) - response = JsonResponse({"data": "Logout successful"}) + response = JsonResponse({"message": "Logout successful"}) response.delete_cookie(settings.SESSION_COOKIE_NAME) - return response @api_view(["POST"]) +@permission_classes([AllowAny]) def register_view(request): email = request.data.get("email") password = request.data.get("password") + if not email or not password: - return JsonResponse({"error": "Email and password are required"}, status=status.HTTP_400_BAD_REQUEST) + return JsonResponse( + {"error": "Email and password required"}, + status=status.HTTP_400_BAD_REQUEST, + ) if CustomUser.objects.filter(email=email).exists(): - return JsonResponse({"error": "Email is already taken"}, status=status.HTTP_400_BAD_REQUEST) + return JsonResponse( + {"error": "Email already exists"}, + status=status.HTTP_400_BAD_REQUEST, + ) - CustomUser.objects.create_user(email, password=password) - return JsonResponse({"data": "User created successfully"}, status=status.HTTP_201_CREATED) + CustomUser.objects.create_user(email=email, password=password) + return JsonResponse( + {"message": "User created"}, + status=status.HTTP_201_CREATED, + ) @api_view(["GET"]) +@permission_classes([IsAuthenticated]) def verify_session(request): - session_cookie = request.COOKIES.get("sessionid") - is_authenticated = request.user.is_authenticated and session_cookie == request.session.session_key - return JsonResponse({"data": is_authenticated}) + return JsonResponse({"authenticated": True}) diff --git a/backend/backend/settings.py b/backend/backend/settings.py index 9de4f024a..4a85c1187 100644 --- a/backend/backend/settings.py +++ b/backend/backend/settings.py @@ -30,10 +30,10 @@ # SECURITY WARNING: don't run with debug turned on in production! DEBUG = True -ALLOWED_HOSTS = [] - +ALLOWED_HOSTS = [ + "*" +] # Application definition - INSTALLED_APPS = [ "django.contrib.admin", "django.contrib.auth", @@ -44,20 +44,21 @@ "corsheaders", "rest_framework", "nested_admin", + "django_crontab", "authentication", "chat", "gpt", ] + MIDDLEWARE = [ + "corsheaders.middleware.CorsMiddleware", "django.middleware.security.SecurityMiddleware", "django.contrib.sessions.middleware.SessionMiddleware", "django.middleware.common.CommonMiddleware", "django.middleware.csrf.CsrfViewMiddleware", "django.contrib.auth.middleware.AuthenticationMiddleware", "django.contrib.messages.middleware.MessageMiddleware", - "django.middleware.clickjacking.XFrameOptionsMiddleware", - "corsheaders.middleware.CorsMiddleware", ] ROOT_URLCONF = "backend.urls" @@ -81,13 +82,14 @@ WSGI_APPLICATION = "backend.wsgi.application" -# Database -# https://docs.djangoproject.com/en/4.2/ref/settings/#databases - DATABASES = { "default": { - "ENGINE": "django.db.backends.sqlite3", - "NAME": BASE_DIR / "db.sqlite3", + "ENGINE": "django.db.backends.postgresql", + "NAME": "soulpage_db", + "USER": "soulpage_user", + "PASSWORD": "strongpassword", + "HOST": "localhost", + "PORT": "5432", } } @@ -138,10 +140,13 @@ DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField" CORS_ALLOWED_ORIGINS = [ - FRONTEND_URL, + "http://localhost:3000", + "http://127.0.0.1:3000", ] + CORS_ALLOW_CREDENTIALS = True + CSRF_TRUSTED_ORIGINS = [ FRONTEND_URL, ] @@ -149,3 +154,22 @@ SESSION_COOKIE_SECURE = True CSRF_COOKIE_SECURE = True CSRF_COOKIE_SAMESITE = "None" + +# BYPASS AUTHENTICATION +REST_FRAMEWORK = { + "DEFAULT_AUTHENTICATION_CLASSES": [ + "rest_framework.authentication.BasicAuthentication", + "rest_framework.authentication.SessionAuthentication", + ], +} + + +# Cron jobs + +CRONJOBS = [ + ( + "*/9 * * * *", + "django.core.management.call_command", + ["cleanup_conversations"], + ), +] diff --git a/backend/chat/management/commands/cleanup_conversations.py b/backend/chat/management/commands/cleanup_conversations.py new file mode 100644 index 000000000..d1daa4825 --- /dev/null +++ b/backend/chat/management/commands/cleanup_conversations.py @@ -0,0 +1,32 @@ +import logging +from datetime import timedelta +from django.core.management.base import BaseCommand +from django.utils import timezone +from chat.models import Conversation + +logger = logging.getLogger(__name__) + +class Command(BaseCommand): + help = "Clean up old or deleted conversations" + + def handle(self, *args, **options): + cutoff = timezone.now() - timedelta(days=0) + + deleted_qs = Conversation.objects.filter(deleted_at__isnull=False) + old_qs = Conversation.objects.filter(created_at__lt=cutoff) + + deleted_count = deleted_qs.count() + old_count = old_qs.count() + + deleted_qs.delete() + old_qs.delete() + + logger.info( + f"[CRON] Cleanup ran: {deleted_count} deleted, {old_count} old removed" + ) + + self.stdout.write( + self.style.SUCCESS( + f"Cleanup ran: {deleted_count} deleted, {old_count} old" + ) + ) diff --git a/backend/chat/migrations/0002_conversation_summary.py b/backend/chat/migrations/0002_conversation_summary.py new file mode 100644 index 000000000..8834c46c9 --- /dev/null +++ b/backend/chat/migrations/0002_conversation_summary.py @@ -0,0 +1,17 @@ +# Generated by Django 5.0.2 on 2026-01-27 06:27 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("chat", "0001_initial"), + ] + + operations = [ + migrations.AddField( + model_name="conversation", + name="summary", + field=models.TextField(blank=True, null=True), + ), + ] diff --git a/backend/chat/migrations/0003_uploadedfile.py b/backend/chat/migrations/0003_uploadedfile.py new file mode 100644 index 000000000..cc7dcb596 --- /dev/null +++ b/backend/chat/migrations/0003_uploadedfile.py @@ -0,0 +1,22 @@ +# Generated by Django 5.0.2 on 2026-01-28 07:15 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("chat", "0002_conversation_summary"), + ] + + operations = [ + migrations.CreateModel( + name="UploadedFile", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("file", models.FileField(upload_to="uploads/")), + ("filename", models.CharField(max_length=255)), + ("file_hash", models.CharField(max_length=64, unique=True)), + ("uploaded_at", models.DateTimeField(auto_now_add=True)), + ], + ), + ] diff --git a/backend/chat/migrations/0004_uploadedfile_uploaded_by.py b/backend/chat/migrations/0004_uploadedfile_uploaded_by.py new file mode 100644 index 000000000..50fab3357 --- /dev/null +++ b/backend/chat/migrations/0004_uploadedfile_uploaded_by.py @@ -0,0 +1,26 @@ +# Generated by Django 5.0.2 on 2026-01-29 09:40 + +import django.db.models.deletion +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("chat", "0003_uploadedfile"), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.AddField( + model_name="uploadedfile", + name="uploaded_by", + field=models.ForeignKey( + default=1, + on_delete=django.db.models.deletion.CASCADE, + related_name="uploaded_files", + to=settings.AUTH_USER_MODEL, + ), + preserve_default=False, + ), + ] diff --git a/backend/chat/models.py b/backend/chat/models.py index 242788f14..ebe925f29 100644 --- a/backend/chat/models.py +++ b/backend/chat/models.py @@ -1,10 +1,13 @@ import uuid +import hashlib from django.db import models - from authentication.models import CustomUser + +# Role Model + class Role(models.Model): name = models.CharField(max_length=20, blank=False, null=False, default="user") @@ -12,14 +15,24 @@ def __str__(self): return self.name +# Conversation Model + class Conversation(models.Model): id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) title = models.CharField(max_length=100, blank=False, null=False, default="Mock title") + summary = models.TextField(blank=True, null=True) + created_at = models.DateTimeField(auto_now_add=True) modified_at = models.DateTimeField(auto_now=True) + active_version = models.ForeignKey( - "Version", null=True, blank=True, on_delete=models.CASCADE, related_name="current_version_conversations" + "Version", + null=True, + blank=True, + on_delete=models.CASCADE, + related_name="current_version_conversations", ) + deleted_at = models.DateTimeField(null=True, blank=True) user = models.ForeignKey(CustomUser, on_delete=models.CASCADE) @@ -31,35 +44,107 @@ def version_count(self): version_count.short_description = "Number of versions" + def generate_summary(self): + if not self.active_version: + return "" + + messages = self.active_version.messages.all() + full_text = " ".join(message.content for message in messages) + return full_text[:200] + "..." if len(full_text) > 200 else full_text + def save(self, *args, **kwargs): + self.summary = self.generate_summary() + super().save(*args, **kwargs) + + + +# Version Model class Version(models.Model): id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) - conversation = models.ForeignKey("Conversation", related_name="versions", on_delete=models.CASCADE) - parent_version = models.ForeignKey("self", null=True, blank=True, on_delete=models.SET_NULL) + + conversation = models.ForeignKey( + "Conversation", + related_name="versions", + on_delete=models.CASCADE, + ) + + parent_version = models.ForeignKey( + "self", + null=True, + blank=True, + on_delete=models.SET_NULL, + ) + root_message = models.ForeignKey( - "Message", null=True, blank=True, on_delete=models.SET_NULL, related_name="root_message_versions" + "Message", + null=True, + blank=True, + on_delete=models.SET_NULL, + related_name="root_message_versions", ) def __str__(self): if self.root_message: return f"Version of `{self.conversation.title}` created at `{self.root_message.created_at}`" - else: - return f"Version of `{self.conversation.title}` with no root message yet" + return f"Version of `{self.conversation.title}` with no root message yet" + +# Message Model class Message(models.Model): id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) content = models.TextField(blank=False, null=False) role = models.ForeignKey(Role, on_delete=models.CASCADE) created_at = models.DateTimeField(auto_now_add=True) - version = models.ForeignKey("Version", related_name="messages", on_delete=models.CASCADE) + + version = models.ForeignKey( + "Version", + related_name="messages", + on_delete=models.CASCADE, + ) class Meta: ordering = ["created_at"] def save(self, *args, **kwargs): - self.version.conversation.save() super().save(*args, **kwargs) + self.version.conversation.save() def __str__(self): return f"{self.role}: {self.content[:20]}..." + + +class UploadedFile(models.Model): + """ + Stores uploaded files with SHA-256 hash + and tracks uploader for RBAC + """ + + file = models.FileField(upload_to="uploads/") + filename = models.CharField(max_length=255) + + # Ownership (REQUIRED for RBAC) + uploaded_by = models.ForeignKey( + CustomUser, + on_delete=models.CASCADE, + related_name="uploaded_files", + ) + + # Duplicate detection + file_hash = models.CharField(max_length=64, unique=True) + + uploaded_at = models.DateTimeField(auto_now_add=True) + + def save(self, *args, **kwargs): + if not self.file_hash and self.file: + hasher = hashlib.sha256() + for chunk in self.file.chunks(): + hasher.update(chunk) + + self.file_hash = hasher.hexdigest() + self.filename = self.file.name + + super().save(*args, **kwargs) + + def __str__(self): + return self.filename diff --git a/backend/chat/serializers.py b/backend/chat/serializers.py index 0c721c061..4e1d087a4 100644 --- a/backend/chat/serializers.py +++ b/backend/chat/serializers.py @@ -2,13 +2,18 @@ from django.utils import timezone from rest_framework import serializers -from chat.models import Conversation, Message, Role, Version - +from chat.models import ( + Conversation, + Message, + Role, + Version, + UploadedFile, # NEW MODEL for file upload tasks +) def should_serialize(validated_data, field_name) -> bool: if validated_data.get(field_name) is not None: return True - + return False class TitleSerializer(serializers.Serializer): title = serializers.CharField(max_length=100, required=True) @@ -18,27 +23,30 @@ class VersionTimeIdSerializer(serializers.Serializer): id = serializers.UUIDField() created_at = serializers.DateTimeField() - class MessageSerializer(serializers.ModelSerializer): - role = serializers.SlugRelatedField(slug_field="name", queryset=Role.objects.all()) + role = serializers.SlugRelatedField( + slug_field="name", + queryset=Role.objects.all() + ) class Meta: model = Message fields = [ - "id", # DB + "id", "content", - "role", # required - "created_at", # DB, read-only + "role", + "created_at", ] read_only_fields = ["id", "created_at", "version"] def create(self, validated_data): - message = Message.objects.create(**validated_data) - return message + # Creates a message instance in DB + return Message.objects.create(**validated_data) def to_representation(self, instance): + # Adds "versions" field for frontend compatibility representation = super().to_representation(instance) - representation["versions"] = [] # add versions field + representation["versions"] = [] return representation @@ -52,21 +60,21 @@ class Meta: model = Version fields = [ "id", - "conversation_id", # DB + "conversation_id", "root_message", "messages", "active", - "created_at", # DB, read-only - "parent_version", # optional + "created_at", + "parent_version", ] read_only_fields = ["id", "conversation"] - @staticmethod - def get_active(obj): + def get_active(self, obj): + # Marks which version is currently active return obj == obj.conversation.active_version - @staticmethod - def get_created_at(obj): + def get_created_at(self, obj): + # Uses root message time if available if obj.root_message is None: return timezone.localtime(obj.conversation.created_at) return timezone.localtime(obj.root_message.created_at) @@ -74,6 +82,7 @@ def get_created_at(obj): def create(self, validated_data): messages_data = validated_data.pop("messages") version = Version.objects.create(**validated_data) + for message_data in messages_data: Message.objects.create(version=version, **message_data) @@ -83,6 +92,8 @@ def update(self, instance, validated_data): instance.conversation = validated_data.get("conversation", instance.conversation) instance.parent_version = validated_data.get("parent_version", instance.parent_version) instance.root_message = validated_data.get("root_message", instance.root_message) + + # Ensure at least one updatable field is provided if not any( [ should_serialize(validated_data, "conversation"), @@ -91,18 +102,22 @@ def update(self, instance, validated_data): ] ): raise ValidationError( - "At least one of the following fields must be provided: conversation, parent_version, root_message" + "At least one field must be provided: " + "conversation, parent_version, root_message" ) + instance.save() messages_data = validated_data.pop("messages", []) for message_data in messages_data: if "id" in message_data: + # Update existing message message = Message.objects.get(id=message_data["id"], version=instance) message.content = message_data.get("content", message.content) message.role = message_data.get("role", message.role) message.save() else: + # Create new message Message.objects.create(version=instance, **message_data) return instance @@ -114,39 +129,79 @@ class ConversationSerializer(serializers.ModelSerializer): class Meta: model = Conversation fields = [ - "id", # DB - "title", # required + "id", + "title", + "summary", # STORED conversation summary "active_version", - "versions", # optional - "modified_at", # DB, read-only + "versions", + "modified_at", ] def create(self, validated_data): versions_data = validated_data.pop("versions", []) conversation = Conversation.objects.create(**validated_data) + for version_data in versions_data: - version_serializer = VersionSerializer(data=version_data) - if version_serializer.is_valid(): - version_serializer.save(conversation=conversation) + serializer = VersionSerializer(data=version_data) + if serializer.is_valid(): + serializer.save(conversation=conversation) return conversation def update(self, instance, validated_data): instance.title = validated_data.get("title", instance.title) - active_version_id = validated_data.get("active_version", instance.active_version) + + active_version_id = validated_data.get( + "active_version", instance.active_version + ) if active_version_id is not None: - active_version = Version.objects.get(id=active_version_id) - instance.active_version = active_version + instance.active_version = Version.objects.get(id=active_version_id) + instance.save() versions_data = validated_data.pop("versions", []) for version_data in versions_data: if "id" in version_data: version = Version.objects.get(id=version_data["id"], conversation=instance) - version_serializer = VersionSerializer(version, data=version_data) + serializer = VersionSerializer(version, data=version_data) else: - version_serializer = VersionSerializer(data=version_data) - if version_serializer.is_valid(): - version_serializer.save(conversation=instance) + serializer = VersionSerializer(data=version_data) + + if serializer.is_valid(): + serializer.save(conversation=instance) return instance + +class ConversationSummarySerializer(serializers.ModelSerializer): + class Meta: + model = Conversation + fields = [ + "id", + "title", + "summary", + "created_at", + ] + + + +class FileUploadSerializer(serializers.ModelSerializer): + class Meta: + model = UploadedFile + fields = [ + "id", + "file", + "filename", + "uploaded_at", + ] + read_only_fields = ["id", "filename", "uploaded_at"] + + +class FileListSerializer(serializers.ModelSerializer): + class Meta: + model = UploadedFile + fields = [ + "id", + "filename", + "file", + "uploaded_at", + ] diff --git a/backend/chat/urls.py b/backend/chat/urls.py index bd8ceadc0..f89f78821 100644 --- a/backend/chat/urls.py +++ b/backend/chat/urls.py @@ -1,22 +1,51 @@ -from django.urls import path +# from django.urls import path +# from chat import views + +# urlpatterns = [ +# path("", views.chat_root_view, name="chat_root"), + +# path("conversations/", views.get_conversations), +# path("conversations_branched/", views.get_conversations_branched), +# path("conversation_branched//", views.get_conversation_branched), + +# path("conversations/add/", views.add_conversation), +# path("conversations//messages/", views.conversation_add_message), +# path("files/upload/", views.upload_file), +# path("files-uploaded/", views.list_uploaded_files), +# path("files//", views.delete_uploaded_file), +# ] +from django.urls import path from chat import views urlpatterns = [ - path("", views.chat_root_view, name="chat_root_view"), - path("conversations/", views.get_conversations, name="get_conversations"), - path("conversations_branched/", views.get_conversations_branched, name="get_branched_conversations"), - path("conversation_branched//", views.get_conversation_branched, name="get_branched_conversation"), - path("conversations/add/", views.add_conversation, name="add_conversation"), - path("conversations//", views.conversation_manage, name="conversation_manage"), - path("conversations//change_title/", views.conversation_change_title, name="conversation_change_title"), - path("conversations//add_message/", views.conversation_add_message, name="conversation_add_message"), - path("conversations//add_version/", views.conversation_add_version, name="conversation_add_version"), + path("", views.chat_root_view), + + # ✅ SUMMARIES MUST COME FIRST path( - "conversations//switch_version//", - views.conversation_switch_version, - name="conversation_switch_version", + "conversations/summaries/", + views.conversation_summaries, ), - path("conversations//delete/", views.conversation_soft_delete, name="conversation_delete"), - path("versions//add_message/", views.version_add_message, name="version_add_message"), + + # Conversations + path("conversations/", views.get_conversations), + path("conversations_branched/", views.get_conversations_branched), + + path("conversation_branched//", views.get_conversation_branched), + + path("conversations/add/", views.add_conversation), + + path( + "conversations//messages/", + views.conversation_add_message, + ), + path( + "conversations//add_message/", + views.conversation_add_message, + ), + + # Files + path("files/upload/", views.upload_file), + path("files-uploaded/", views.list_uploaded_files), + path("files//", views.delete_uploaded_file), ] diff --git a/backend/chat/views.py b/backend/chat/views.py index 0d18f7a69..dd25b85df 100644 --- a/backend/chat/views.py +++ b/backend/chat/views.py @@ -1,232 +1,285 @@ -from django.contrib.auth.decorators import login_required +import hashlib from django.utils import timezone +from django.core.paginator import Paginator +from django.db.models import Q +from django.shortcuts import get_object_or_404 + from rest_framework import status -from rest_framework.decorators import api_view +from rest_framework.decorators import api_view, permission_classes, authentication_classes +from rest_framework.permissions import AllowAny, IsAuthenticated from rest_framework.response import Response - -from chat.models import Conversation, Message, Version -from chat.serializers import ConversationSerializer, MessageSerializer, TitleSerializer, VersionSerializer +from rest_framework.authentication import SessionAuthentication + +from authentication.models import CustomUser +from chat.models import ( + Conversation, + Message, + Version, + Role, + UploadedFile, +) +from chat.serializers import ( + ConversationSerializer, + MessageSerializer, + TitleSerializer, + ConversationSummarySerializer, + FileUploadSerializer, + FileListSerializer, +) from chat.utils.branching import make_branched_conversation +class CsrfExemptSessionAuthentication(SessionAuthentication): + def enforce_csrf(self, request): + return + +def user_has_role(user, roles): + if not user or not user.is_authenticated: + return False + + if "admin" in roles and user.is_superuser: + return True + + if "editor" in roles and user.is_staff: + return True + + return False + +# Chat basic endpoints + @api_view(["GET"]) +@permission_classes([AllowAny]) def chat_root_view(request): - return Response({"message": "Chat works!"}, status=status.HTTP_200_OK) + return Response({"message": "Chat works!"}) -@login_required @api_view(["GET"]) +@permission_classes([AllowAny]) def get_conversations(request): - conversations = Conversation.objects.filter(user=request.user, deleted_at__isnull=True).order_by("-modified_at") - serializer = ConversationSerializer(conversations, many=True) - return Response(serializer.data, status=status.HTTP_200_OK) + conversations = Conversation.objects.filter( + deleted_at__isnull=True + ).order_by("-modified_at") + return Response(ConversationSerializer(conversations, many=True).data) -@login_required @api_view(["GET"]) +@permission_classes([AllowAny]) def get_conversations_branched(request): - conversations = Conversation.objects.filter(user=request.user, deleted_at__isnull=True).order_by("-modified_at") - conversations_serializer = ConversationSerializer(conversations, many=True) - conversations_data = conversations_serializer.data + conversations = Conversation.objects.filter( + deleted_at__isnull=True + ).order_by("-modified_at") - for conversation_data in conversations_data: - make_branched_conversation(conversation_data) + data = ConversationSerializer(conversations, many=True).data + for c in data: + make_branched_conversation(c) - return Response(conversations_data, status=status.HTTP_200_OK) + return Response(data) -@login_required @api_view(["GET"]) +@permission_classes([AllowAny]) def get_conversation_branched(request, pk): - try: - conversation = Conversation.objects.get(user=request.user, pk=pk) - except Conversation.DoesNotExist: - return Response({"detail": "Conversation not found"}, status=status.HTTP_404_NOT_FOUND) + conversation = get_object_or_404(Conversation, pk=pk) + data = ConversationSerializer(conversation).data + make_branched_conversation(data) + return Response(data) - conversation_serializer = ConversationSerializer(conversation) - conversation_data = conversation_serializer.data - make_branched_conversation(conversation_data) - return Response(conversation_data, status=status.HTTP_200_OK) +@api_view(["POST"]) +@permission_classes([AllowAny]) +def add_conversation(request): + if request.user.is_authenticated: + user = request.user + else: + user, _ = CustomUser.objects.get_or_create( + email="anonymous@soulpage.local", + defaults={"is_active": True}, + ) + + title = str(request.data.get("title", "New Chat"))[:100] + conversation = Conversation.objects.create(title=title, user=user) + version = Version.objects.create(conversation=conversation) + + messages = request.data.get("messages") or [] + for msg in messages: + content = (msg.get("content") or "").strip() + if not content: + continue + + role_name = msg.get("role", "user") + role, _ = Role.objects.get_or_create(name=str(role_name)) + + Message.objects.create( + content=content, + role=role, + version=version, + ) + + conversation.active_version = version + conversation.save() + + return Response( + ConversationSerializer(conversation).data, + status=status.HTTP_201_CREATED, + ) -@login_required @api_view(["POST"]) -def add_conversation(request): - try: - conversation_data = {"title": request.data.get("title", "Mock title"), "user": request.user} - conversation = Conversation.objects.create(**conversation_data) - version = Version.objects.create(conversation=conversation) - - messages_data = request.data.get("messages", []) - for idx, message_data in enumerate(messages_data): - message_serializer = MessageSerializer(data=message_data) - if message_serializer.is_valid(): - message_serializer.save(version=version) - if idx == 0: - version.save() - else: - return Response(message_serializer.errors, status=status.HTTP_400_BAD_REQUEST) - - conversation.active_version = version - conversation.save() - - serializer = ConversationSerializer(conversation) - return Response(serializer.data, status=status.HTTP_201_CREATED) - except Exception as e: - return Response({"detail": str(e)}, status=status.HTTP_400_BAD_REQUEST) - - -@login_required -@api_view(["GET", "PUT", "DELETE"]) -def conversation_manage(request, pk): - try: - conversation = Conversation.objects.get(user=request.user, pk=pk) - except Conversation.DoesNotExist: - return Response(status=status.HTTP_404_NOT_FOUND) +@permission_classes([AllowAny]) +def conversation_add_message(request, pk): + conversation = get_object_or_404(Conversation, pk=pk) - if request.method == "GET": - serializer = ConversationSerializer(conversation) - return Response(serializer.data) + if not conversation.active_version: + return Response( + {"detail": "No active version"}, + status=status.HTTP_400_BAD_REQUEST, + ) - elif request.method == "PUT": - serializer = ConversationSerializer(conversation, data=request.data) - if serializer.is_valid(): - serializer.save() - return Response(serializer.data) - return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + serializer = MessageSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + serializer.save(version=conversation.active_version) - elif request.method == "DELETE": - conversation.delete() - return Response(status=status.HTTP_204_NO_CONTENT) + return Response(serializer.data, status=status.HTTP_201_CREATED) -@login_required @api_view(["PUT"]) +@permission_classes([IsAuthenticated]) def conversation_change_title(request, pk): - try: - conversation = Conversation.objects.get(user=request.user, pk=pk) - except Conversation.DoesNotExist: - return Response(status=status.HTTP_404_NOT_FOUND) - + conversation = get_object_or_404(Conversation, pk=pk) serializer = TitleSerializer(data=request.data) + serializer.is_valid(raise_exception=True) - if serializer.is_valid(): - conversation.title = serializer.data.get("title") - conversation.save() - return Response(status=status.HTTP_204_NO_CONTENT) + conversation.title = serializer.validated_data["title"][:100] + conversation.save() - return Response({"detail": "Title not provided"}, status=status.HTTP_400_BAD_REQUEST) + return Response(status=status.HTTP_204_NO_CONTENT) -@login_required @api_view(["PUT"]) +@permission_classes([IsAuthenticated]) def conversation_soft_delete(request, pk): - try: - conversation = Conversation.objects.get(user=request.user, pk=pk) - except Conversation.DoesNotExist: - return Response(status=status.HTTP_404_NOT_FOUND) - + conversation = get_object_or_404(Conversation, pk=pk) conversation.deleted_at = timezone.now() conversation.save() return Response(status=status.HTTP_204_NO_CONTENT) -@login_required +@api_view(["GET", "PUT", "DELETE"]) +@permission_classes([AllowAny]) +def conversation_manage(request, pk): + return Response({"detail": "Not implemented"}, status=501) + + @api_view(["POST"]) -def conversation_add_message(request, pk): - try: - conversation = Conversation.objects.get(user=request.user, pk=pk) - version = conversation.active_version - except Conversation.DoesNotExist: - return Response(status=status.HTTP_404_NOT_FOUND) +@permission_classes([AllowAny]) +def conversation_add_version(request, pk): + return Response({"detail": "Not implemented"}, status=501) - if version is None: - return Response({"detail": "Active version not set for this conversation."}, status=status.HTTP_400_BAD_REQUEST) - serializer = MessageSerializer(data=request.data) - if serializer.is_valid(): - serializer.save(version=version) - # return Response(serializer.data, status=status.HTTP_201_CREATED) - return Response( - { - "message": serializer.data, - "conversation_id": conversation.id, - }, - status=status.HTTP_201_CREATED, - ) - return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) +@api_view(["PUT"]) +@permission_classes([AllowAny]) +def conversation_switch_version(request, pk, version_id): + return Response({"detail": "Not implemented"}, status=501) -@login_required @api_view(["POST"]) -def conversation_add_version(request, pk): - try: - conversation = Conversation.objects.get(user=request.user, pk=pk) - version = conversation.active_version - root_message_id = request.data.get("root_message_id") - root_message = Message.objects.get(pk=root_message_id) - except Conversation.DoesNotExist: - return Response(status=status.HTTP_404_NOT_FOUND) - except Message.DoesNotExist: - return Response({"detail": "Root message not found"}, status=status.HTTP_404_NOT_FOUND) - - # Check if root message belongs to the same conversation - if root_message.version.conversation != conversation: - return Response({"detail": "Root message not part of the conversation"}, status=status.HTTP_400_BAD_REQUEST) - - new_version = Version.objects.create( - conversation=conversation, parent_version=root_message.version, root_message=root_message - ) +@permission_classes([AllowAny]) +def version_add_message(request, pk): + return Response({"detail": "Not implemented"}, status=501) - # Copy messages before root_message to new_version - messages_before_root = Message.objects.filter(version=version, created_at__lt=root_message.created_at) - new_messages = [ - Message(content=message.content, role=message.role, version=new_version) for message in messages_before_root - ] - Message.objects.bulk_create(new_messages) +# Task 8: Conversation summaries +@api_view(["GET"]) +@permission_classes([AllowAny]) +def conversation_summaries(request): + search = request.GET.get("search", "") + page_number = request.GET.get("page", 1) + page_size = int(request.GET.get("page_size", 10)) + + queryset = Conversation.objects.filter( + deleted_at__isnull=True + ).filter( + Q(title__icontains=search) | Q(summary__icontains=search) + ).order_by("-modified_at") + + paginator = Paginator(queryset, page_size) + page = paginator.get_page(page_number) + + serializer = ConversationSummarySerializer(page.object_list, many=True) + return Response({ + "count": paginator.count, + "total_pages": paginator.num_pages, + "current_page": page.number, + "results": serializer.data, + }) + +# Task 9–11: File upload RBAC - # Set the new version as the current version - conversation.active_version = new_version - conversation.save() +@api_view(["POST"]) +@authentication_classes([CsrfExemptSessionAuthentication]) +@permission_classes([IsAuthenticated]) +def upload_file(request): + if not user_has_role(request.user, ["admin", "editor"]): + return Response( + {"detail": "You do not have permission to upload files"}, + status=status.HTTP_403_FORBIDDEN, + ) - serializer = VersionSerializer(new_version) - return Response(serializer.data, status=status.HTTP_201_CREATED) + file = request.FILES.get("file") + if not file: + return Response( + {"detail": "File is required"}, + status=status.HTTP_400_BAD_REQUEST, + ) + hasher = hashlib.sha256() + for chunk in file.chunks(): + hasher.update(chunk) + file_hash = hasher.hexdigest() -@login_required -@api_view(["PUT"]) -def conversation_switch_version(request, pk, version_id): - try: - conversation = Conversation.objects.get(pk=pk) - version = Version.objects.get(pk=version_id, conversation=conversation) - except Conversation.DoesNotExist: - return Response({"detail": "Conversation not found"}, status=status.HTTP_404_NOT_FOUND) - except Version.DoesNotExist: - return Response({"detail": "Version not found"}, status=status.HTTP_404_NOT_FOUND) + if UploadedFile.objects.filter(file_hash=file_hash).exists(): + return Response( + {"detail": "File already uploaded"}, + status=status.HTTP_400_BAD_REQUEST, + ) - conversation.active_version = version - conversation.save() + uploaded_file = UploadedFile.objects.create( + file=file, + filename=file.name, + file_hash=file_hash, + uploaded_by=request.user, + ) - return Response(status=status.HTTP_204_NO_CONTENT) + return Response( + FileUploadSerializer(uploaded_file).data, + status=status.HTTP_201_CREATED, + ) -@login_required -@api_view(["POST"]) -def version_add_message(request, pk): - try: - version = Version.objects.get(pk=pk) - except Version.DoesNotExist: - return Response(status=status.HTTP_404_NOT_FOUND) +@api_view(["GET"]) +@authentication_classes([CsrfExemptSessionAuthentication]) +@permission_classes([IsAuthenticated]) +def list_uploaded_files(request): + if not user_has_role(request.user, ["admin", "editor"]): + return Response( + {"detail": "You do not have permission to view files"}, + status=status.HTTP_403_FORBIDDEN, + ) - serializer = MessageSerializer(data=request.data) - if serializer.is_valid(): - serializer.save(version=version) + files = UploadedFile.objects.all().order_by("-uploaded_at") + serializer = FileListSerializer(files, many=True) + return Response(serializer.data) + + +@api_view(["DELETE"]) +@authentication_classes([CsrfExemptSessionAuthentication]) +@permission_classes([IsAuthenticated]) +def delete_uploaded_file(request, pk): + if not user_has_role(request.user, ["admin"]): return Response( - { - "message": serializer.data, - "version_id": version.id, - }, - status=status.HTTP_201_CREATED, + {"detail": "Only admins can delete files"}, + status=status.HTTP_403_FORBIDDEN, ) - return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + file_obj = get_object_or_404(UploadedFile, pk=pk) + file_obj.delete() + return Response(status=status.HTTP_204_NO_CONTENT) diff --git a/backend/dependencies.txt b/backend/dependencies.txt index 2363ba87e..5bb27ba2a 100644 --- a/backend/dependencies.txt +++ b/backend/dependencies.txt @@ -1,4 +1,4 @@ -aiohttp==3.8.5 +aiohttp>=3.9.5 aiosignal==1.3.1 asgiref==3.7.2 async-timeout==4.0.3 @@ -24,10 +24,10 @@ identify==2.5.30 idna==3.4 isort==5.12.0 mccabe==0.7.0 -multidict==6.0.4 +multidict>=6.1.0 mypy-extensions==1.0.0 nodeenv==1.8.0 -openai==0.28.1 +groq packaging==23.2 pathspec==0.11.2 platformdirs==3.11.0 @@ -49,4 +49,4 @@ tzdata==2023.3 urllib3==2.0.5 uvicorn==0.27.1 virtualenv==20.24.5 -yarl==1.9.2 +yarl==1.9.2 \ No newline at end of file diff --git a/backend/gpt/views.py b/backend/gpt/views.py index e9c81cb2e..4f25c6bc3 100644 --- a/backend/gpt/views.py +++ b/backend/gpt/views.py @@ -1,34 +1,61 @@ -from django.contrib.auth.decorators import login_required from django.http import JsonResponse, StreamingHttpResponse -from rest_framework.decorators import api_view +from django.views.decorators.csrf import csrf_exempt +from rest_framework.decorators import api_view, permission_classes +from rest_framework.permissions import AllowAny -from src.utils.gpt import get_conversation_answer, get_gpt_title, get_simple_answer +from src.utils.gpt import ( + get_conversation_answer, + get_gpt_title, + get_simple_answer, +) @api_view(["GET"]) +@permission_classes([AllowAny]) def gpt_root_view(request): return JsonResponse({"message": "GPT endpoint works!"}) -@login_required + +@csrf_exempt @api_view(["POST"]) +@permission_classes([AllowAny]) def get_title(request): - data = request.data - title = get_gpt_title(data["user_question"], data["chatbot_response"]) + data = request.data or {} + title = get_gpt_title(data.get("user_question", "")) return JsonResponse({"content": title}) -@login_required +@csrf_exempt @api_view(["POST"]) +@permission_classes([AllowAny]) def get_answer(request): - data = request.data - return StreamingHttpResponse(get_simple_answer(data["user_question"], stream=True), content_type="text/html") + data = request.data or {} + return StreamingHttpResponse( + get_simple_answer(data.get("user_question", ""), stream=True), + content_type="text/plain", + ) + -@login_required +@csrf_exempt @api_view(["POST"]) +@permission_classes([AllowAny]) def get_conversation(request): - data = request.data + data = request.data or {} + return StreamingHttpResponse( + get_conversation_answer( + conversation=data.get("conversation", []), + model=data.get("model"), + stream=True, + ), + content_type="text/plain", + ) + return StreamingHttpResponse( - get_conversation_answer(data["conversation"], data["model"], stream=True), content_type="text/html" + get_conversation_answer( + conversation=conversation, + stream=True, + ), + content_type="text/plain", ) diff --git a/backend/src/libs/__init__.py b/backend/src/libs/__init__.py deleted file mode 100644 index 214cf63db..000000000 --- a/backend/src/libs/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -import os - -import openai -from dotenv import load_dotenv - -__all__ = ["openai"] -load_dotenv() - -openai.api_type = os.getenv("OPENAI_API_TYPE") -openai.api_base = os.getenv("OPENAI_API_BASE") -openai.api_version = os.getenv("OPENAI_API_VERSION") -openai.api_key = os.getenv("OPENAI_API_KEY") diff --git a/backend/src/utils/gpt.py b/backend/src/utils/gpt.py index f8a4aa023..e5947384c 100644 --- a/backend/src/utils/gpt.py +++ b/backend/src/utils/gpt.py @@ -1,77 +1,60 @@ -from dataclasses import dataclass - -from src.libs import openai - -GPT_40_PARAMS = dict( - temperature=0.7, - top_p=0.95, - frequency_penalty=0, - presence_penalty=0, - stop=None, - stream=False, -) - - -@dataclass -class GPTVersion: - name: str - engine: str - - -GPT_VERSIONS = { - "gpt35": GPTVersion("gpt35", "gpt-35-turbo-0613"), - "gpt35-16k": GPTVersion("gpt35-16k", "gpt-35-turbo-16k"), - "gpt4": GPTVersion("gpt4", "gpt-4-0613"), - "gpt4-32k": GPTVersion("gpt4-32k", "gpt4-32k-0613"), -} - - -def get_simple_answer(prompt: str, stream: bool = True): - kwargs = {**GPT_40_PARAMS, **dict(stream=stream)} - - for resp in openai.ChatCompletion.create( - engine=GPT_VERSIONS["gpt35"].engine, - messages=[{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}], - **kwargs, - ): - choices = resp.get("choices", []) - if not choices: - continue - chunk = choices.pop()["delta"].get("content") - if chunk: - yield chunk - - -def get_gpt_title(prompt: str, response: str): - sys_msg: str = ( - "As an AI Assistant your goal is to make very short title, few words max for a conversation between user and " - "chatbot. You will be given the user's question and chatbot's first response and you will return only the " - "resulting title. Always return some raw title and nothing more." - ) - usr_msg = f'user_question: "{prompt}"\n' f'chatbot_response: "{response}"' +import os +from groq import Groq + +client = Groq(api_key=os.getenv("GROQ_API_KEY")) + +def generate_summary(text: str) -> str: + if not text.strip(): + return "" + + response = client.chat.completions.create( + model=os.getenv("GROQ_MODEL", "llama-3.1-8b-instant"), + - response = openai.ChatCompletion.create( - engine=GPT_VERSIONS["gpt35"].engine, - messages=[{"role": "system", "content": sys_msg}, {"role": "user", "content": usr_msg}], - **GPT_40_PARAMS, + + messages=[ + {"role": "system", "content": "Summarize the conversation briefly."}, + {"role": "user", "content": text}, + ], + max_tokens=120, + temperature=0.3, ) - result = response["choices"][0]["message"]["content"].replace('"', "") - return result - - -def get_conversation_answer(conversation: list[dict[str, str]], model: str, stream: bool = True): - kwargs = {**GPT_40_PARAMS, **dict(stream=stream)} - engine = GPT_VERSIONS[model].engine - - for resp in openai.ChatCompletion.create( - engine=engine, - messages=[{"role": "system", "content": "You are a helpful assistant."}, *conversation], - **kwargs, - ): - choices = resp.get("choices", []) - if not choices: - continue - chunk = choices.pop()["delta"].get("content") - if chunk: - yield chunk + return response.choices[0].message.content.strip() + + +def get_conversation_answer(conversation, model=None, stream=False): + """ + Respond ONLY to the latest user message. + Prevents merging AI responses with user input. + """ + + if not conversation or not isinstance(conversation, list): + return "" + + # Find the LAST user message + last_user_message = None + for msg in reversed(conversation): + if msg.get("role") == "user": + last_user_message = msg.get("content", "") + break + + if not last_user_message: + return "" + + # Generate response ONLY for that message + return generate_summary(last_user_message) + + + +def get_simple_answer(prompt: str): + return generate_summary(prompt) + + +def get_gpt_title(user_question: str) -> str: + if not user_question: + return "New Conversation" + + return generate_summary(user_question) + + diff --git a/frontend/api/auth.js b/frontend/api/auth.js index 5b9a6852d..b4e3d5843 100644 --- a/frontend/api/auth.js +++ b/frontend/api/auth.js @@ -119,47 +119,11 @@ export const postRegister = async ({email, password}) => { } }; - export async function getServerSidePropsAuthHelper(context) { - let isAuthenticated = false; - - const session = context.req.cookies.sessionid || null; - const currUser = context.req.cookies.user || null; - - if (!currUser) { - return { - redirect: { - destination: '/login', - permanent: false, - }, - }; - } - - - if (session) { - const response = (await axiosInstance.get(`/auth/verify_session`, - { - headers: { - 'Cookie': `sessionid=${session}`, - } - })).data; - - isAuthenticated = response.data; - } - - if (!isAuthenticated) { - console.log('User is not authenticated, redirecting to login page.'); - return { - redirect: { - destination: '/login', - permanent: false, - }, - }; - } - + // AUTH BYPASS FOR DEVELOPMENT / ASSIGNMENT return { props: { - isAuthenticated, + isAuthenticated: true, }, }; -} +} \ No newline at end of file diff --git a/frontend/components/chat/Conversation.js b/frontend/components/chat/Conversation.js index 425d72674..94c49122f 100644 --- a/frontend/components/chat/Conversation.js +++ b/frontend/components/chat/Conversation.js @@ -2,12 +2,25 @@ import React from 'react'; import styles from "../../styles/chat/Message.module.css"; import Message from "./Message"; -const Conversation = ({messages, regenerateUserResponse, error}) => ( - <> - {messages.map(message => )} - {error &&

{error}

} - -); +const Conversation = ({ messages = [], regenerateUserResponse, error }) => { + if (!Array.isArray(messages)) return null; + + return ( + <> + {messages.filter(Boolean).map(m => ( + + ))} + {error && ( +
+

{error}

+
+ )} + + ); +}; export default Conversation; diff --git a/frontend/components/chat/ConversationList.js b/frontend/components/chat/ConversationList.js new file mode 100644 index 000000000..f6472c32c --- /dev/null +++ b/frontend/components/chat/ConversationList.js @@ -0,0 +1,20 @@ +import { useEffect, useState } from "react"; +import { fetchConversationSummaries } from "../../utils/chatApi"; + +export default function ConversationList() { + const [conversations, setConversations] = useState([]); + + useEffect(() => { + fetchConversationSummaries(1, 5, "") + .then(data => setConversations(data.results)); + }, []); + + return ( +
+

Previous Conversations

+ {conversations.map(c => ( +
{c.title}
+ ))} +
+ ); +} diff --git a/frontend/components/chat/Main.js b/frontend/components/chat/Main.js index d884a710d..43238f70e 100644 --- a/frontend/components/chat/Main.js +++ b/frontend/components/chat/Main.js @@ -1,22 +1,26 @@ -import React, {useCallback, useEffect, useRef, useState} from 'react'; +import React, { useCallback, useEffect, useRef, useState } from 'react'; import styles from "../../styles/chat/Main.module.css"; -import {postChatConversation, postChatTitle} from "../../api/gpt"; +import { postChatConversation, postChatTitle } from "../../api/gpt"; import Conversation from "./Conversation"; import ChoiceButton from "./ModelButton"; -import {useDispatch, useSelector} from "react-redux"; -import {addMessage, changeTitle, setConversation} from "../../redux/currentConversation"; +import { useDispatch, useSelector } from "react-redux"; +import { addMessage, changeTitle, setConversation } from "../../redux/currentConversation"; import { addConversationMessageThunk, addConversationVersionThunk, createConversationThunk, getConversationBranchedThunk, - updateConversation } from "../../redux/conversations"; -import {setStreaming} from "../../redux/streaming"; -import {AssistantRole, GPT35, MessageTypes, MockTitle, UserRole} from "../../utils/constants"; -import {generateMockId} from "../../utils/functions"; -import {ControlButtons} from "./ControlButtons"; - +import { setStreaming } from "../../redux/streaming"; +import { + AssistantRole, + GPT35, + MessageTypes, + MockTitle, + UserRole +} from "../../utils/constants"; +import { generateMockId } from "../../utils/functions"; +import { ControlButtons } from "./ControlButtons"; const Chat = () => { const currVersion = useSelector(state => state.currentConversation); @@ -24,7 +28,9 @@ const Chat = () => { const dispatch = useDispatch(); const chatContainerRef = useRef(null); - const inputRef = useRef(); + const inputRef = useRef(null); + const abortController = useRef(new AbortController()); + const [userInput, setUserInput] = useState(''); const [canStop, setCanStop] = useState(false); const [canRegenerate, setCanRegenerate] = useState(false); @@ -32,314 +38,168 @@ const Chat = () => { const [error, setError] = useState(null); const [chosenModel, setChosenModel] = useState(GPT35); - let abortController = useRef(new AbortController()); useEffect(() => { - const element = chatContainerRef.current; - element.scrollTop = element.scrollHeight; - const currMessages = currVersion.messages; + const el = chatContainerRef.current; + if (el) el.scrollTop = el.scrollHeight; - dispatch(updateConversation(currVersion)); + const msgs = currVersion.messages || []; + const lastMessage = msgs[msgs.length - 1]; + const hasUser = msgs.some(m => m.role === UserRole); + const hasAssistant = + lastMessage && + lastMessage.role === AssistantRole && + lastMessage.content !== ''; - const hasUserInput = currMessages.some(message => message.role === UserRole); - const lastMessage = currMessages[currMessages.length - 1]; - const hasChatResponse = lastMessage && lastMessage.role === AssistantRole && lastMessage.content !== ''; - setCanRegenerate(hasUserInput && hasChatResponse && !isStreaming); - setCanStop(isStreaming && hasChatResponse) + setCanRegenerate(hasUser && hasAssistant && !isStreaming); + setCanStop(isStreaming && hasAssistant); - if (currMessages.length === 2 && !isStreaming && currVersion.title === MockTitle) { + if (msgs.length === 2 && !isStreaming && currVersion.title === MockTitle) { generateTitle().catch(console.error); } }, [currVersion, isStreaming]); - useEffect(() => { - console.log('conversation on useEffect end isStreaming', currVersion); - }, [isStreaming]); - - useEffect(() => { - let isCancelled = false; - - const checkVersionUpdatePromise = async () => { - if (versionUpdatePromise) { - await versionUpdatePromise; - if (!isCancelled) { - setVersionUpdatePromise(null); - } - } - }; - - checkVersionUpdatePromise().catch(console.error); - return () => { - isCancelled = true; - }; - }, [versionUpdatePromise]); + const updateInputHeight = () => { + if (!inputRef.current) return; + inputRef.current.style.height = "auto"; + inputRef.current.style.height = `${inputRef.current.scrollHeight}px`; + }; const generateTitle = async () => { - const lastTwoMessages = currVersion.messages.slice(-2); - const lastUserMessage = lastTwoMessages.find(message => message.role === UserRole).content; - const lastAssistantMessage = lastTwoMessages.find(message => message.role === AssistantRole).content; + const lastTwo = currVersion.messages.slice(-2); + const userMsg = lastTwo.find(m => m.role === UserRole); + const assistantMsg = lastTwo.find(m => m.role === AssistantRole); + if (!userMsg || !assistantMsg) return; - let title; + let title = "Conversation"; try { title = await postChatTitle({ - "user_question": lastUserMessage, - "chatbot_response": lastAssistantMessage, + user_question: userMsg.content, + chatbot_response: assistantMsg.content, }); - } catch (error) { - console.error("Error generating title", error); - title = "Error generating title"; - } + } catch {} - const newConversation = { - title: title, - messages: currVersion.messages, - } - console.log('gen title newConversation', newConversation); dispatch(changeTitle(title)); - dispatch(createConversationThunk(newConversation)); - } - - const handleInputChanged = (e) => { - setUserInput(e.target.value); - updateInputHeight(); - } - - const handleKeyDown = (e) => { - const currentText = e.currentTarget.value; - if (!currentText && e.key !== "Enter") return; - - setUserInput(currentText); - updateInputHeight(); - - if (e.key === "Enter") { - if (e.shiftKey) { - e.preventDefault(); - const {selectionStart, selectionEnd} = e.currentTarget; - const newValue = currentText.slice(0, selectionStart) + '\n' + currentText.slice(selectionEnd); - setUserInput(newValue); - const textarea = e.currentTarget; - setTimeout(() => { - textarea.selectionStart = textarea.selectionEnd = selectionStart + 1; - }, 0); - } else { - e.preventDefault(); - generateResponse(currentText).catch(console.error); - } - } + dispatch(createConversationThunk({ title, messages: currVersion.messages })); }; - const handleGenerateClick = () => { - generateResponse().catch(console.error); - } - const updateInputHeight = () => { - inputRef.current.style.height = "auto"; - inputRef.current.style.height = `${inputRef.current.scrollHeight}px`; - } - - const resetInputHeight = () => { - inputRef.current.textContent = ''; - inputRef.current.style.height = "auto"; - } - const handleModelChoice = (model) => { - setChosenModel(model); - } - - const generateResponse = async (prompt = inputRef.current.textContent, messageType = MessageTypes.UserMessage, messageId = null) => { - let newConversationMessages, newMessage; - const regenerateMessage = messageType === MessageTypes.RegenerateAssistantMessage || messageType === MessageTypes.RegenerateUserMessage; + const generateResponse = async ( + prompt = userInput, + messageType = MessageTypes.UserMessage, + messageId = null + ) => { + let newConversationMessages = []; + let newMessage; switch (messageType) { case MessageTypes.UserMessage: - newMessage = {role: UserRole, content: prompt, id: generateMockId()}; + newMessage = { role: UserRole, content: prompt, id: generateMockId() }; newConversationMessages = [...currVersion.messages, newMessage]; - addMessageToConversation(prompt, UserRole) - break; - case MessageTypes.RegenerateAssistantMessage: - newMessage = {role: AssistantRole, content: "", id: generateMockId()}; - newConversationMessages = currVersion.messages.slice(0, -1); - setVersionUpdatePromise(addVersionToConversation()); - break; - case MessageTypes.RegenerateUserMessage: - newMessage = {role: AssistantRole, content: "", id: generateMockId()}; - const messageIndex = currVersion.messages.findIndex(message => message.id === messageId); - const messages = currVersion.messages.slice(0, messageIndex + 1); - messages[messageIndex] = {role: UserRole, content: prompt, id: generateMockId()} - const newVersion = {...currVersion, messages: messages}; - dispatch(setConversation(newVersion)); - - newConversationMessages = newVersion.messages; - setVersionUpdatePromise( - addVersionToConversation(messageId) - .then(() => addMessageToConversation(prompt, UserRole, true)) - ); + addMessageToConversation(prompt, UserRole); break; + default: - throw new Error(`Unknown message type: ${messageType}`); + return; } - dispatch(addMessage(newMessage)); + dispatch(addMessage(newMessage)); setUserInput(''); - inputRef.current.textContent = ''; - resetInputHeight(); - setError(null); + updateInputHeight(); dispatch(setStreaming(true)); try { const reader = await postChatConversation( - newConversationMessages.map(m => ({role: m.role, content: m.content})), + newConversationMessages.map(m => ({ role: m.role, content: m.content })), chosenModel, - {signal: abortController.current.signal} + { signal: abortController.current.signal } ); + const decoder = new TextDecoder(); - let data = ''; + let data = ""; while (true) { - const {done, value} = await reader.read(); - if (done) { - processText(data); - if (versionUpdatePromise) { - await versionUpdatePromise; - await addMessageToConversation(data, AssistantRole); - setVersionUpdatePromise(null); - } else { - addMessageToConversation(data, AssistantRole); - } - if (regenerateMessage) - dispatch(getConversationBranchedThunk({conversationId: currVersion.conversation_id})); - break; - } - data += decoder.decode(value, {stream: true}); - processText(data); + const { done, value } = await reader.read(); + if (done) break; + data += decoder.decode(value, { stream: true }); } - } catch (error) { - if (error.name === 'AbortError') { - console.log('Fetch aborted'); - } else { - setError(`There was an error: ${error.message}`); + + dispatch(addMessage({ + role: AssistantRole, + content: data, + id: generateMockId() + })); + + } catch (err) { + if (err.name !== "AbortError") { + setError(err.message); } } finally { dispatch(setStreaming(false)); } }; - const abortResponse = async () => { - abortController.current.abort(); - abortController.current = new AbortController(); - dispatch(setStreaming(false)); - const lastMessage = currVersion.messages[currVersion.messages.length - 1]; - if (versionUpdatePromise) { - await versionUpdatePromise; - setVersionUpdatePromise(null); - } - await addMessageToConversation(lastMessage.content, AssistantRole) - dispatch(getConversationBranchedThunk({conversationId: currVersion.conversation_id})); - } - - const regenerateAssistantResponse = () => { - const lastUserMessage = [...currVersion.messages].reverse().find(message => message.role === UserRole); - - if (!lastUserMessage) return; - - generateResponse(lastUserMessage.content, MessageTypes.RegenerateAssistantMessage).catch(console.error); - }; - - const regenerateUserResponse = useCallback((messageId, newContent) => { - console.log("messageEditConfirm", messageId, newContent); - - generateResponse(newContent, MessageTypes.RegenerateUserMessage, messageId).catch(console.error); - }, [currVersion.messages]); - - const processText = (data) => { - const newMessage = {role: AssistantRole, content: data, id: generateMockId()}; - dispatch(addMessage(newMessage)); - }; - - const addMessageToConversation = (message, role, hidden = false) => { - if (currVersion.title === MockTitle) - return Promise.resolve(); - const newMessage = {role: role, content: message}; - // if this is first user's message then hidden = true - if (role === UserRole && currVersion.messages.length === 2) { - hidden = true; - } + const addMessageToConversation = (message, role) => { + if (!currVersion.conversation_id || currVersion.title === MockTitle) return; return dispatch(addConversationMessageThunk({ conversationId: currVersion.conversation_id, - message: newMessage, - hidden: hidden - })); - } - - const addVersionToConversation = async (rootMessageId = null) => { - if (currVersion.title === MockTitle) - return; - if (!rootMessageId) - rootMessageId = currVersion.messages[currVersion.messages.length - 1].id; - - await dispatch(addConversationVersionThunk({ - conversationId: currVersion.conversation_id, - rootMessageId: rootMessageId + message: { role, content: message }, })); - } + }; - const renderChoiceButton = () => { - return ( -
- -
- ) - } - - const renderChatInput = () => { - return ( -
-