2 changes: 1 addition & 1 deletion forum/__init__.py
@@ -2,4 +2,4 @@
Openedx forum app.
"""

__version__ = "0.3.9"
__version__ = "0.4.0"
207 changes: 152 additions & 55 deletions forum/backends/mysql/api.py
@@ -10,18 +10,19 @@
from django.core.exceptions import ObjectDoesNotExist
from django.core.paginator import Paginator
from django.db.models import (
Count,
Case,
Count,
Exists,
F,
IntegerField,
Max,
OuterRef,
Q,
Subquery,
When,
Sum,
When,
)
from django.db.models.functions import Coalesce
from django.utils import timezone
from rest_framework import status
from rest_framework.response import Response
@@ -308,8 +309,17 @@ def validate_thread_and_user(
ValueError: If the thread or user is not found.
"""
try:
thread = CommentThread.objects.get(pk=int(thread_id))
user = ForumUser.objects.get(user__pk=user_id)
# Optimize: Use select_related to avoid N+1 queries
thread = CommentThread.objects.select_related("author", "closed_by").get(
pk=int(thread_id)
)
user = (
ForumUser.objects.select_related("user")
.prefetch_related(
"user__course_stats", "user__read_states__last_read_times"
)
.get(user__pk=user_id)
)
except ObjectDoesNotExist as exc:
raise ValueError("User / Thread doesn't exist") from exc

@@ -348,8 +358,17 @@ def get_pinned_unpinned_thread_serialized_data(
Raises:
ValueError: If the serialization is not valid.
"""
user = ForumUser.objects.get(user__pk=user_id)
updated_thread = CommentThread.objects.get(pk=thread_id)
# Optimize: Use select_related to avoid N+1 queries
user = (
ForumUser.objects.select_related("user")
.prefetch_related(
"user__course_stats", "user__read_states__last_read_times"
)
.get(user__pk=user_id)
)
updated_thread = CommentThread.objects.select_related(
"author", "closed_by"
).get(pk=thread_id)
user_data = user.to_dict()
context = {
"user_id": user_data["_id"],
@@ -401,35 +420,41 @@ def get_abuse_flagged_count(thread_ids: list[str]) -> dict[str, int]:
Returns:
dict[str, int]: A dictionary mapping thread IDs to their corresponding abuse-flagged comment count.
"""
abuse_flagger_count_subquery = (
# Optimize: Use aggregation to count abuse flaggers per thread in bulk
comment_content_type = ContentType.objects.get_for_model(Comment)

# Get all comments for these threads
comment_ids = Comment.objects.filter(
comment_thread__pk__in=thread_ids
).values_list("pk", flat=True)

if not comment_ids:
return {}

# Count abuse flaggers per comment using aggregation
abuse_flagged_counts = (
AbuseFlagger.objects.filter(
content_type=ContentType.objects.get_for_model(Comment),
content_object_id=OuterRef("pk"),
content_type=comment_content_type,
content_object_id__in=comment_ids,
)
.values("content_object_id")
.annotate(count=Count("pk"))
.values("count")
)

abuse_flagged_comments = (
Comment.objects.filter(
comment_thread__pk__in=thread_ids,
)
.annotate(
abuse_flaggers_count=Subquery(
abuse_flagger_count_subquery, output_field=IntegerField()
)
# Map comment IDs back to thread IDs
comment_to_thread = dict(
Comment.objects.filter(pk__in=comment_ids).values_list(
"pk", "comment_thread_id"
)
.filter(abuse_flaggers_count__gt=0)
)

result = {}
for comment in abuse_flagged_comments:
thread_pk = str(comment.comment_thread.pk)
if thread_pk not in result:
result[thread_pk] = 0
abuse_flaggers = "abuse_flaggers_count"
result[thread_pk] += getattr(comment, abuse_flaggers)
result: dict[str, int] = {}
for item in abuse_flagged_counts:
comment_id = item["content_object_id"]
thread_id = comment_to_thread.get(comment_id)
if thread_id:
thread_pk = str(thread_id)
result[thread_pk] = result.get(thread_pk, 0) + item["count"]

return result
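Editor's note: the rewrite above replaces a correlated subquery evaluated per comment row with one grouped COUNT over `AbuseFlagger` plus an in-memory comment-to-thread mapping. A reduced sketch of that pattern, outside this diff (import paths are assumptions):

# Illustrative sketch, not part of the diff; import paths are assumptions.
from collections import defaultdict

from django.contrib.contenttypes.models import ContentType
from django.db.models import Count

from forum.backends.mysql.models import AbuseFlagger, Comment

def flag_counts_by_thread(thread_ids: list[int]) -> dict[int, int]:
    # One query: comment pk -> thread pk for every comment in the threads.
    comment_to_thread = dict(
        Comment.objects.filter(comment_thread_id__in=thread_ids).values_list(
            "pk", "comment_thread_id"
        )
    )
    # One GROUP BY query over the flag table, counting flags per comment.
    grouped = (
        AbuseFlagger.objects.filter(
            content_type=ContentType.objects.get_for_model(Comment),
            content_object_id__in=comment_to_thread.keys(),
        )
        .values("content_object_id")
        .annotate(n=Count("pk"))
    )
    totals: dict[int, int] = defaultdict(int)
    for row in grouped:
        totals[comment_to_thread[row["content_object_id"]]] += row["n"]
    return dict(totals)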

@@ -457,28 +482,43 @@ def get_read_states(
except User.DoesNotExist:
return read_states

threads = CommentThread.objects.filter(pk__in=thread_ids)
# Convert thread_ids to integers for database queries
try:
thread_ids_int = [int(tid) for tid in thread_ids]
except (ValueError, TypeError):
return read_states

threads = CommentThread.objects.filter(pk__in=thread_ids_int).values(
"pk", "last_activity_at"
)
thread_dict = {thread["pk"]: thread for thread in threads}

read_state = ReadState.objects.filter(user=user, course_id=course_id).first()
if not read_state:
return read_states

read_dates = read_state.last_read_times
last_read_times = read_state.last_read_times.select_related(
"comment_thread"
).filter(comment_thread_id__in=thread_ids_int)

for thread in threads:
read_date = read_dates.filter(comment_thread=thread).first()
if not read_date:
for read_date in last_read_times:
thread_id = read_date.comment_thread.pk
thread = thread_dict.get(thread_id)
if not thread:
continue

last_activity_at = thread.last_activity_at
last_activity_at = thread["last_activity_at"]
is_read = read_date.timestamp >= last_activity_at

# Count unread comments for this thread
unread_comment_count = (
Comment.objects.filter(
comment_thread=thread, created_at__gte=read_date.timestamp
comment_thread_id=thread_id, created_at__gte=read_date.timestamp
)
.exclude(author__pk=user_id)
.count()
)
read_states[str(thread.pk)] = [is_read, unread_comment_count]
read_states[str(thread_id)] = [is_read, unread_comment_count]

return read_states
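Editor's note: two details above do the work: `.values("pk", "last_activity_at")` fetches only the two columns the loop needs, and the read timestamps are filtered to `thread_ids_int` in a single query, so the loop becomes plain dict lookups instead of one `filter(...).first()` per thread. The values-plus-dict half of that, in isolation (the import path is an assumption):

# Illustrative sketch, not part of the diff; the import path is an assumption.
from datetime import datetime

from forum.backends.mysql.models import CommentThread

def last_activity_map(thread_ids: list[int]) -> dict[int, datetime]:
    # One query, two columns, no full model instances materialised.
    rows = CommentThread.objects.filter(pk__in=thread_ids).values(
        "pk", "last_activity_at"
    )
    return {row["pk"]: row["last_activity_at"] for row in rows}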

@@ -524,11 +564,14 @@ def get_endorsed(thread_ids: list[str]) -> dict[str, bool]:
Returns:
dict[str, bool]: A dictionary of thread IDs to their endorsed status (True if endorsed, False otherwise).
"""
endorsed_comments = Comment.objects.filter(
comment_thread__pk__in=thread_ids, endorsed=True
# Optimize: Use values_list to avoid loading full objects
endorsed_thread_ids = (
Comment.objects.filter(comment_thread__pk__in=thread_ids, endorsed=True)
.values_list("comment_thread_id", flat=True)
.distinct()
)

return {str(comment.comment_thread.pk): True for comment in endorsed_comments}
return {str(thread_id): True for thread_id in endorsed_thread_ids}

@staticmethod
def get_user_read_state_by_course_id(
@@ -729,24 +772,44 @@ def handle_threads_query(
base_query = base_query.filter(
commentable_id__in=commentable_ids,
)
# Annotate comments count
base_query = base_query.annotate(
votes_point=Sum("uservote__vote"),
comments_count=Count("comment", distinct=True),
)

base_query = base_query.annotate(
votes_point=Sum("uservote__vote", distinct=True),
comments_count=Count("comment", distinct=True),
)

sort_criteria = cls.get_sort_criteria(sort_key)

# Only annotate votes_point if sorting by votes to avoid performance issues
# Otherwise calculate votes separately in bulk
if sort_key == "votes":
comment_thread_content_type = ContentType.objects.get_for_model(
CommentThread
)
base_query = base_query.annotate(
votes_point=Coalesce(
Subquery(
UserVote.objects.filter(
content_type=comment_thread_content_type,
content_object_id=OuterRef("pk"),
)
.values("content_object_id")
.annotate(votes_sum=Sum("vote"))
.values("votes_sum")[:1],
output_field=IntegerField(),
),
0,
),
)

base_query = base_query.select_related("author", "closed_by")

comment_threads = (
base_query.order_by(*sort_criteria) if sort_criteria else base_query
)
thread_count = base_query.count()

if raw_query:
comment_threads = comment_threads.prefetch_related("comment_set")
return {
"result": [
comment_thread.to_dict() for comment_thread in comment_threads
@@ -762,6 +825,7 @@ def handle_threads_query(
to_skip = (page - 1) * per_page
has_more = False

# Note: iterator() doesn't support prefetch_related, so we don't use it here
for thread in comment_threads.iterator():
thread_key = str(thread.pk)
if (
@@ -777,6 +841,8 @@ def handle_threads_query(
skipped += 1
num_pages = page + 1 if has_more else page
else:
# Apply prefetch_related when not using iterator()
comment_threads = comment_threads.prefetch_related("comment_set")
threads = [thread.pk for thread in comment_threads]
page = max(1, page)
start = per_page * (page - 1)
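Editor's note on the `votes_point` annotation added earlier in this hunk: it is only attached when `sort_key == "votes"`, a correlated `Subquery` sums `UserVote.vote` per thread, and `Coalesce(..., 0)` turns the NULL a voteless thread would get into a sortable 0. The same shape in isolation (import paths are assumptions):

# Illustrative sketch, not part of the diff; import paths are assumptions.
from django.contrib.contenttypes.models import ContentType
from django.db.models import IntegerField, OuterRef, Subquery, Sum
from django.db.models.functions import Coalesce

from forum.backends.mysql.models import CommentThread, UserVote

def threads_ordered_by_votes():
    thread_type = ContentType.objects.get_for_model(CommentThread)
    vote_sum = (
        UserVote.objects.filter(
            content_type=thread_type, content_object_id=OuterRef("pk")
        )
        .values("content_object_id")  # group the correlated rows per thread
        .annotate(total=Sum("vote"))
        .values("total")[:1]
    )
    return CommentThread.objects.annotate(
        votes_point=Coalesce(Subquery(vote_sum, output_field=IntegerField()), 0)
    ).order_by("-votes_point")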
@@ -820,7 +886,10 @@ def prepare_thread(
Returns:
dict[str, Any]: A dictionary representing the prepared thread data.
"""
thread = CommentThread.objects.get(pk=thread_id)
# Optimize: Use select_related to avoid N+1 queries
thread = CommentThread.objects.select_related("author", "closed_by").get(
pk=thread_id
)
return {
**thread.to_dict(),
"type": "thread",
@@ -850,7 +919,13 @@ def threads_presentor(
Returns:
list[dict[str, Any]]: A list of prepared thread data.
"""
threads = CommentThread.objects.filter(pk__in=thread_ids)

threads = CommentThread.objects.filter(pk__in=thread_ids).select_related(
"author", "closed_by"
)

threads_dict = {thread.pk: thread for thread in threads}

read_states = cls.get_read_states(thread_ids, user_id, course_id)
threads_endorsed = cls.get_endorsed(thread_ids)
threads_flagged = (
@@ -859,7 +934,9 @@ def threads_presentor(

presenters = []
for thread_id in thread_ids:
thread = threads.get(id=thread_id)
thread = threads_dict.get(int(thread_id))
if not thread:
continue
is_read, unread_count = read_states.get(
str(thread.pk), (False, thread.comment_count)
)
@@ -1693,7 +1770,10 @@ def update_comment(comment_id: str, **kwargs: Any) -> int:
@staticmethod
def get_thread_id_from_comment(comment_id: str) -> dict[str, Any] | None:
"""Return thread_id from comment_id."""
comment = Comment.objects.get(pk=comment_id)
# Optimize: Use select_related to avoid N+1 queries
comment = Comment.objects.select_related(
"comment_thread__author", "comment_thread__closed_by"
).get(pk=comment_id)
if comment.comment_thread:
return comment.comment_thread.to_dict()
raise ValueError("Comment doesn't have the thread.")
@@ -2114,20 +2194,37 @@ def get_paginated_user_stats(
cls, course_id: str, page: int, per_page: int, sort_criterion: dict[str, Any]
) -> dict[str, Any]:
"""Get paginated user stats."""
users = User.objects.filter(
Q(course_stats__course_id=course_id)
& Q(course_stats__course_id__isnull=False)
).order_by(
*[f"-{key}" for key, value in sort_criterion.items() if value == -1],
*[key for key, value in sort_criterion.items() if value == 1],

users = (
User.objects.filter(
Q(course_stats__course_id=course_id)
& Q(course_stats__course_id__isnull=False)
)
.select_related("forum")
.prefetch_related("course_stats", "read_states__last_read_times")
.order_by(
*[f"-{key}" for key, value in sort_criterion.items() if value == -1],
*[key for key, value in sort_criterion.items() if value == 1],
)
)

paginator = Paginator(users, per_page)
paginated_users = paginator.page(page)

user_ids = [user.pk for user in paginated_users.object_list]
forum_users_dict = {
fu.user.pk: fu
for fu in ForumUser.objects.filter(user__pk__in=user_ids)
.select_related("user")
.prefetch_related(
"user__course_stats", "user__read_states__last_read_times"
)
}

forum_users = [
ForumUser.objects.get(user_id=user_id)
for user_id in paginated_users.object_list
forum_users_dict[user_id]
for user_id in user_ids
if user_id in forum_users_dict
]
return {
"pagination": [{"total_count": paginator.count}],
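Editor's note on the last hunk: it replaces one `ForumUser.objects.get(...)` call per paginated user with a single `filter(user__pk__in=...)` query indexed into a dict. A reduced sketch of the fetch-then-index pattern (the import path is an assumption):

# Illustrative sketch, not part of the diff; the import path is an assumption.
from forum.backends.mysql.models import ForumUser

def forum_users_for(user_ids: list[int]) -> list[ForumUser]:
    # One query instead of len(user_ids) individual .get() calls.
    by_user_id = {
        fu.user_id: fu
        for fu in ForumUser.objects.filter(user__pk__in=user_ids).select_related(
            "user"
        )
    }
    # Preserve the caller's ordering and skip ids without a forum profile.
    return [by_user_id[uid] for uid in user_ids if uid in by_user_id]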