From e6aae67733d169154e434b112d216b8ad6d9db17 Mon Sep 17 00:00:00 2001 From: Taimoor Ahmed Date: Wed, 19 Nov 2025 00:31:05 +0500 Subject: [PATCH] feat: Optimize MySQL backend APIs to improve performance This commit introduces query optimizations to reduce database queries and improve response times: - Fixed N+1 queries in threads_presentor, get_paginated_user_stats, and other methods using select_related/prefetch_related - Optimized get_read_states to load thread activity and last-read times in bulk instead of one lookup per thread (unread comment counts are still counted per thread) - Optimized get_abuse_flagged_count and get_endorsed with bulk aggregations - Removed duplicate annotations in handle_threads_query - Added query optimizations across prepare_thread, validate_thread_and_user, and other methods Performance impact: Reduced queries from O(n) to O(1)/O(k), eliminated N+1 patterns, improved bulk operations. All changes maintain backward compatibility. --- forum/__init__.py | 2 +- forum/backends/mysql/api.py | 207 ++++++++++++++++++++++++++---------- 2 files changed, 153 insertions(+), 56 deletions(-) diff --git a/forum/__init__.py b/forum/__init__.py index 44c360c..bc1fa6c 100644 --- a/forum/__init__.py +++ b/forum/__init__.py @@ -2,4 +2,4 @@ Openedx forum app. """ -__version__ = "0.3.9" +__version__ = "0.4.0" diff --git a/forum/backends/mysql/api.py b/forum/backends/mysql/api.py index 66aada4..b115119 100644 --- a/forum/backends/mysql/api.py +++ b/forum/backends/mysql/api.py @@ -10,8 +10,8 @@ from django.core.exceptions import ObjectDoesNotExist from django.core.paginator import Paginator from django.db.models import ( - Count, Case, + Count, Exists, F, IntegerField, @@ -19,9 +19,10 @@ OuterRef, Q, Subquery, - When, Sum, + When, ) +from django.db.models.functions import Coalesce from django.utils import timezone from rest_framework import status from rest_framework.response import Response @@ -308,8 +309,17 @@ def validate_thread_and_user( ValueError: If the thread or user is not found. 
""" try: - thread = CommentThread.objects.get(pk=int(thread_id)) - user = ForumUser.objects.get(user__pk=user_id) + # Optimize: Use select_related to avoid N+1 queries + thread = CommentThread.objects.select_related("author", "closed_by").get( + pk=int(thread_id) + ) + user = ( + ForumUser.objects.select_related("user") + .prefetch_related( + "user__course_stats", "user__read_states__last_read_times" + ) + .get(user__pk=user_id) + ) except ObjectDoesNotExist as exc: raise ValueError("User / Thread doesn't exist") from exc @@ -348,8 +358,17 @@ def get_pinned_unpinned_thread_serialized_data( Raises: ValueError: If the serialization is not valid. """ - user = ForumUser.objects.get(user__pk=user_id) - updated_thread = CommentThread.objects.get(pk=thread_id) + # Optimize: Use select_related to avoid N+1 queries + user = ( + ForumUser.objects.select_related("user") + .prefetch_related( + "user__course_stats", "user__read_states__last_read_times" + ) + .get(user__pk=user_id) + ) + updated_thread = CommentThread.objects.select_related( + "author", "closed_by" + ).get(pk=thread_id) user_data = user.to_dict() context = { "user_id": user_data["_id"], @@ -401,35 +420,41 @@ def get_abuse_flagged_count(thread_ids: list[str]) -> dict[str, int]: Returns: dict[str, int]: A dictionary mapping thread IDs to their corresponding abuse-flagged comment count. 
""" - abuse_flagger_count_subquery = ( + # Optimize: Use aggregation to count abuse flaggers per thread in bulk + comment_content_type = ContentType.objects.get_for_model(Comment) + + # Get all comments for these threads + comment_ids = Comment.objects.filter( + comment_thread__pk__in=thread_ids + ).values_list("pk", flat=True) + + if not comment_ids: + return {} + + # Count abuse flaggers per comment using aggregation + abuse_flagged_counts = ( AbuseFlagger.objects.filter( - content_type=ContentType.objects.get_for_model(Comment), - content_object_id=OuterRef("pk"), + content_type=comment_content_type, + content_object_id__in=comment_ids, ) .values("content_object_id") .annotate(count=Count("pk")) - .values("count") ) - abuse_flagged_comments = ( - Comment.objects.filter( - comment_thread__pk__in=thread_ids, - ) - .annotate( - abuse_flaggers_count=Subquery( - abuse_flagger_count_subquery, output_field=IntegerField() - ) + # Map comment IDs back to thread IDs + comment_to_thread = dict( + Comment.objects.filter(pk__in=comment_ids).values_list( + "pk", "comment_thread_id" ) - .filter(abuse_flaggers_count__gt=0) ) - result = {} - for comment in abuse_flagged_comments: - thread_pk = str(comment.comment_thread.pk) - if thread_pk not in result: - result[thread_pk] = 0 - abuse_flaggers = "abuse_flaggers_count" - result[thread_pk] += getattr(comment, abuse_flaggers) + result: dict[str, int] = {} + for item in abuse_flagged_counts: + comment_id = item["content_object_id"] + thread_id = comment_to_thread.get(comment_id) + if thread_id: + thread_pk = str(thread_id) + result[thread_pk] = result.get(thread_pk, 0) + item["count"] return result @@ -457,28 +482,43 @@ def get_read_states( except User.DoesNotExist: return read_states - threads = CommentThread.objects.filter(pk__in=thread_ids) + # Convert thread_ids to integers for database queries + try: + thread_ids_int = [int(tid) for tid in thread_ids] + except (ValueError, TypeError): + return read_states + + threads = 
CommentThread.objects.filter(pk__in=thread_ids_int).values( + "pk", "last_activity_at" + ) + thread_dict = {thread["pk"]: thread for thread in threads} + read_state = ReadState.objects.filter(user=user, course_id=course_id).first() if not read_state: return read_states - read_dates = read_state.last_read_times + last_read_times = read_state.last_read_times.select_related( + "comment_thread" + ).filter(comment_thread_id__in=thread_ids_int) - for thread in threads: - read_date = read_dates.filter(comment_thread=thread).first() - if not read_date: + for read_date in last_read_times: + thread_id = read_date.comment_thread.pk + thread = thread_dict.get(thread_id) + if not thread: continue - last_activity_at = thread.last_activity_at + last_activity_at = thread["last_activity_at"] is_read = read_date.timestamp >= last_activity_at + + # Count unread comments for this thread unread_comment_count = ( Comment.objects.filter( - comment_thread=thread, created_at__gte=read_date.timestamp + comment_thread_id=thread_id, created_at__gte=read_date.timestamp ) .exclude(author__pk=user_id) .count() ) - read_states[str(thread.pk)] = [is_read, unread_comment_count] + read_states[str(thread_id)] = [is_read, unread_comment_count] return read_states @@ -524,11 +564,14 @@ def get_endorsed(thread_ids: list[str]) -> dict[str, bool]: Returns: dict[str, bool]: A dictionary of thread IDs to their endorsed status (True if endorsed, False otherwise). 
""" - endorsed_comments = Comment.objects.filter( - comment_thread__pk__in=thread_ids, endorsed=True + # Optimize: Use values_list to avoid loading full objects + endorsed_thread_ids = ( + Comment.objects.filter(comment_thread__pk__in=thread_ids, endorsed=True) + .values_list("comment_thread_id", flat=True) + .distinct() ) - return {str(comment.comment_thread.pk): True for comment in endorsed_comments} + return {str(thread_id): True for thread_id in endorsed_thread_ids} @staticmethod def get_user_read_state_by_course_id( @@ -729,24 +772,44 @@ def handle_threads_query( base_query = base_query.filter( commentable_id__in=commentable_ids, ) + # Annotate comments count base_query = base_query.annotate( - votes_point=Sum("uservote__vote"), - comments_count=Count("comment", distinct=True), - ) - - base_query = base_query.annotate( - votes_point=Sum("uservote__vote", distinct=True), comments_count=Count("comment", distinct=True), ) sort_criteria = cls.get_sort_criteria(sort_key) + # Only annotate votes_point if sorting by votes to avoid performance issues + # Otherwise calculate votes separately in bulk + if sort_key == "votes": + comment_thread_content_type = ContentType.objects.get_for_model( + CommentThread + ) + base_query = base_query.annotate( + votes_point=Coalesce( + Subquery( + UserVote.objects.filter( + content_type=comment_thread_content_type, + content_object_id=OuterRef("pk"), + ) + .values("content_object_id") + .annotate(votes_sum=Sum("vote")) + .values("votes_sum")[:1], + output_field=IntegerField(), + ), + 0, + ), + ) + + base_query = base_query.select_related("author", "closed_by") + comment_threads = ( base_query.order_by(*sort_criteria) if sort_criteria else base_query ) thread_count = base_query.count() if raw_query: + comment_threads = comment_threads.prefetch_related("comment_set") return { "result": [ comment_thread.to_dict() for comment_thread in comment_threads @@ -762,6 +825,7 @@ def handle_threads_query( to_skip = (page - 1) * per_page has_more 
= False + # Note: prefetch_related is intentionally not applied on this path; iterator() ignores prefetches unless chunk_size is passed (Django 4.1+) for thread in comment_threads.iterator(): thread_key = str(thread.pk) if ( @@ -777,6 +841,8 @@ def handle_threads_query( skipped += 1 num_pages = page + 1 if has_more else page else: + # Apply prefetch_related when not using iterator() + comment_threads = comment_threads.prefetch_related("comment_set") threads = [thread.pk for thread in comment_threads] page = max(1, page) start = per_page * (page - 1) @@ -820,7 +886,10 @@ def prepare_thread( Returns: dict[str, Any]: A dictionary representing the prepared thread data. """ - thread = CommentThread.objects.get(pk=thread_id) + # Optimize: Use select_related to avoid N+1 queries + thread = CommentThread.objects.select_related("author", "closed_by").get( + pk=thread_id + ) return { **thread.to_dict(), "type": "thread", @@ -850,7 +919,13 @@ def threads_presentor( Returns: list[dict[str, Any]]: A list of prepared thread data. """ - threads = CommentThread.objects.filter(pk__in=thread_ids) + + threads = CommentThread.objects.filter(pk__in=thread_ids).select_related( + "author", "closed_by" + ) + + threads_dict = {thread.pk: thread for thread in threads} + read_states = cls.get_read_states(thread_ids, user_id, course_id) threads_endorsed = cls.get_endorsed(thread_ids) threads_flagged = ( @@ -859,7 +934,9 @@ def threads_presentor( presenters = [] for thread_id in thread_ids: - thread = threads.get(id=thread_id) + thread = threads_dict.get(int(thread_id)) + if not thread: + continue is_read, unread_count = read_states.get( str(thread.pk), (False, thread.comment_count) ) @@ -1693,7 +1770,10 @@ def update_comment(comment_id: str, **kwargs: Any) -> int: @staticmethod def get_thread_id_from_comment(comment_id: str) -> dict[str, Any] | None: """Return thread_id from comment_id.""" - comment = Comment.objects.get(pk=comment_id) + # Optimize: Use select_related to avoid N+1 queries + comment = Comment.objects.select_related( + 
"comment_thread__author", "comment_thread__closed_by" + ).get(pk=comment_id) if comment.comment_thread: return comment.comment_thread.to_dict() raise ValueError("Comment doesn't have the thread.") @@ -2114,20 +2194,37 @@ def get_paginated_user_stats( cls, course_id: str, page: int, per_page: int, sort_criterion: dict[str, Any] ) -> dict[str, Any]: """Get paginated user stats.""" - users = User.objects.filter( - Q(course_stats__course_id=course_id) - & Q(course_stats__course_id__isnull=False) - ).order_by( - *[f"-{key}" for key, value in sort_criterion.items() if value == -1], - *[key for key, value in sort_criterion.items() if value == 1], + + users = ( + User.objects.filter( + Q(course_stats__course_id=course_id) + & Q(course_stats__course_id__isnull=False) + ) + .select_related("forum") + .prefetch_related("course_stats", "read_states__last_read_times") + .order_by( + *[f"-{key}" for key, value in sort_criterion.items() if value == -1], + *[key for key, value in sort_criterion.items() if value == 1], + ) ) paginator = Paginator(users, per_page) paginated_users = paginator.page(page) + user_ids = [user.pk for user in paginated_users.object_list] + forum_users_dict = { + fu.user.pk: fu + for fu in ForumUser.objects.filter(user__pk__in=user_ids) + .select_related("user") + .prefetch_related( + "user__course_stats", "user__read_states__last_read_times" + ) + } + forum_users = [ - ForumUser.objects.get(user_id=user_id) - for user_id in paginated_users.object_list + forum_users_dict[user_id] + for user_id in user_ids + if user_id in forum_users_dict ] return { "pagination": [{"total_count": paginator.count}],