Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,15 @@ server/dist

# prompts
prompts/debug
# Virtual environment
venv/

# Environment variables
.env

# Python cache
__pycache__/
*.pyc

# DB
db.sqlite3
36 changes: 33 additions & 3 deletions backend/chat/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import uuid

from django.db import models
from django.apps import apps

from authentication.models import CustomUser

Expand All @@ -15,6 +16,13 @@ def __str__(self):
class Conversation(models.Model):
id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
title = models.CharField(max_length=100, blank=False, null=False, default="Mock title")

summary = models.TextField(
blank=True,
null=True,
help_text="Auto-generated summary of the conversation"
)

created_at = models.DateTimeField(auto_now_add=True)
modified_at = models.DateTimeField(auto_now=True)
active_version = models.ForeignKey(
Expand All @@ -31,6 +39,29 @@ def version_count(self):

version_count.short_description = "Number of versions"

def generate_summary(self):
"""
Generate a simple summary using the first few messages
"""
Message = apps.get_model("chat", "Message")

messages = (
Message.objects
.filter(version__conversation=self)
.order_by("created_at")
.values_list("content", flat=True)[:3]
)

if messages:
return " | ".join(messages)

return ""

def save(self, *args, **kwargs):
if not self.summary:
self.summary = self.generate_summary()
super().save(*args, **kwargs)


class Version(models.Model):
id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
Expand All @@ -43,8 +74,7 @@ class Version(models.Model):
def __str__(self):
if self.root_message:
return f"Version of `{self.conversation.title}` created at `{self.root_message.created_at}`"
else:
return f"Version of `{self.conversation.title}` with no root message yet"
return f"Version of `{self.conversation.title}` with no root message yet"


class Message(models.Model):
Expand All @@ -58,8 +88,8 @@ class Meta:
ordering = ["created_at"]

def save(self, *args, **kwargs):
self.version.conversation.save()
super().save(*args, **kwargs)
self.version.conversation.save()

def __str__(self):
return f"{self.role}: {self.content[:20]}..."
93 changes: 47 additions & 46 deletions backend/chat/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@


def should_serialize(validated_data, field_name) -> bool:
if validated_data.get(field_name) is not None:
return True

return validated_data.get(field_name) is not None

class TitleSerializer(serializers.Serializer):
title = serializers.CharField(max_length=100, required=True)
Expand All @@ -18,62 +16,61 @@ class VersionTimeIdSerializer(serializers.Serializer):
id = serializers.UUIDField()
created_at = serializers.DateTimeField()


class MessageSerializer(serializers.ModelSerializer):
role = serializers.SlugRelatedField(slug_field="name", queryset=Role.objects.all())
role = serializers.SlugRelatedField(
slug_field="name",
queryset=Role.objects.all()
)

class Meta:
model = Message
fields = [
"id", # DB
"id",
"content",
"role", # required
"created_at", # DB, read-only
"role",
"created_at",
]
read_only_fields = ["id", "created_at", "version"]

def create(self, validated_data):
message = Message.objects.create(**validated_data)
return message
return Message.objects.create(**validated_data)

def to_representation(self, instance):
representation = super().to_representation(instance)
representation["versions"] = [] # add versions field
representation["versions"] = []
return representation


class VersionSerializer(serializers.ModelSerializer):
messages = MessageSerializer(many=True)
active = serializers.SerializerMethodField()
conversation_id = serializers.UUIDField(source="conversation.id")
conversation_id = serializers.UUIDField(source="conversation.id", read_only=True)
created_at = serializers.SerializerMethodField()

class Meta:
model = Version
fields = [
"id",
"conversation_id", # DB
"conversation_id",
"root_message",
"messages",
"active",
"created_at", # DB, read-only
"parent_version", # optional
"created_at",
"parent_version",
]
read_only_fields = ["id", "conversation"]

@staticmethod
def get_active(obj):
def get_active(self, obj):
return obj == obj.conversation.active_version

@staticmethod
def get_created_at(obj):
if obj.root_message is None:
return timezone.localtime(obj.conversation.created_at)
return timezone.localtime(obj.root_message.created_at)
def get_created_at(self, obj):
if obj.root_message:
return timezone.localtime(obj.root_message.created_at)
return timezone.localtime(obj.conversation.created_at)

def create(self, validated_data):
messages_data = validated_data.pop("messages")
messages_data = validated_data.pop("messages", [])
version = Version.objects.create(**validated_data)

for message_data in messages_data:
Message.objects.create(version=version, **message_data)

Expand All @@ -83,16 +80,15 @@ def update(self, instance, validated_data):
instance.conversation = validated_data.get("conversation", instance.conversation)
instance.parent_version = validated_data.get("parent_version", instance.parent_version)
instance.root_message = validated_data.get("root_message", instance.root_message)

if not any(
[
should_serialize(validated_data, "conversation"),
should_serialize(validated_data, "parent_version"),
should_serialize(validated_data, "root_message"),
]
should_serialize(validated_data, field)
for field in ["conversation", "parent_version", "root_message"]
):
raise ValidationError(
"At least one of the following fields must be provided: conversation, parent_version, root_message"
"At least one of conversation, parent_version, or root_message must be provided."
)

instance.save()

messages_data = validated_data.pop("messages", [])
Expand All @@ -107,46 +103,51 @@ def update(self, instance, validated_data):

return instance


class ConversationSerializer(serializers.ModelSerializer):
versions = VersionSerializer(many=True)
summary = serializers.CharField(read_only=True)

class Meta:
model = Conversation
fields = [
"id", # DB
"title", # required
"id",
"title",
"summary",
"active_version",
"versions", # optional
"modified_at", # DB, read-only
"versions",
"modified_at",
]
read_only_fields = ["id", "summary", "modified_at"]

def create(self, validated_data):
versions_data = validated_data.pop("versions", [])
conversation = Conversation.objects.create(**validated_data)

for version_data in versions_data:
version_serializer = VersionSerializer(data=version_data)
if version_serializer.is_valid():
version_serializer.save(conversation=conversation)
serializer = VersionSerializer(data=version_data)
serializer.is_valid(raise_exception=True)
serializer.save(conversation=conversation)

return conversation

def update(self, instance, validated_data):
instance.title = validated_data.get("title", instance.title)
active_version_id = validated_data.get("active_version", instance.active_version)
if active_version_id is not None:
active_version = Version.objects.get(id=active_version_id)
instance.active_version = active_version

active_version_id = validated_data.get("active_version")
if active_version_id:
instance.active_version = Version.objects.get(id=active_version_id)

instance.save()

versions_data = validated_data.pop("versions", [])
for version_data in versions_data:
if "id" in version_data:
version = Version.objects.get(id=version_data["id"], conversation=instance)
version_serializer = VersionSerializer(version, data=version_data)
serializer = VersionSerializer(version, data=version_data)
else:
version_serializer = VersionSerializer(data=version_data)
if version_serializer.is_valid():
version_serializer.save(conversation=instance)
serializer = VersionSerializer(data=version_data)

serializer.is_valid(raise_exception=True)
serializer.save(conversation=instance)

return instance