Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions comments/management/commands/sync_snapshot_comments_count.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from django.core.management.base import BaseCommand
from django.db.models import Count, Subquery, OuterRef
from django.db.models.functions import Coalesce

from comments.models import Comment
from posts.models import PostUserSnapshot

# Number of snapshot primary keys handled per bulk UPDATE statement.
BATCH_SIZE = 5000


class Command(BaseCommand):
    help = "Sync PostUserSnapshot.comments_count to match actual comment counts at viewed_at"

    def add_arguments(self, parser):
        parser.add_argument(
            "--dry-run",
            action="store_true",
            help="Only report mismatches without updating",
        )

    def handle(self, *args, **options):
        """Recompute ``comments_count`` for every PostUserSnapshot.

        The correct value is the number of public, non-deleted comments on the
        snapshot's post created at or before the snapshot's ``viewed_at``.
        Rows are processed in batches of BATCH_SIZE to bound query size.

        With ``--dry-run`` the command only counts mismatched rows; otherwise
        it updates them in place via a correlated subquery.
        """
        dry_run = options["dry_run"]

        # Correlated subquery evaluated per snapshot row: count comments on
        # the snapshot's post that were visible at the user's last view time.
        correct_count = Coalesce(
            Subquery(
                Comment.objects.filter(
                    on_post_id=OuterRef("post_id"),
                    is_private=False,
                    is_soft_deleted=False,
                    created_at__lte=OuterRef("viewed_at"),
                )
                # Clear any default ordering so the implicit GROUP BY is only
                # on_post_id and the subquery returns a single row.
                .order_by()
                .values("on_post_id")
                .annotate(cnt=Count("id"))
                .values("cnt")
            ),
            # Posts with no matching comments yield NULL -> treat as 0.
            0,
        )

        total = PostUserSnapshot.objects.count()
        processed = 0
        updated = 0
        batch = []

        def _process_batch(pks):
            """Handle one batch of snapshot PKs; return the mismatch/update count.

            Restricting to rows whose stored count disagrees with the recomputed
            value fixes two problems with a blanket ``.update()``:
            - ``QuerySet.update()`` returns the number of matched rows, not the
              number actually changed, so ``updated`` was always == processed;
            - ``--dry-run`` previously computed nothing, despite its help text
              promising to "report mismatches".
            """
            mismatched = PostUserSnapshot.objects.filter(pk__in=pks).exclude(
                comments_count=correct_count
            )
            if dry_run:
                return mismatched.count()
            return mismatched.update(comments_count=correct_count)

        for pk in PostUserSnapshot.objects.values_list("pk", flat=True).iterator(
            chunk_size=BATCH_SIZE
        ):
            batch.append(pk)
            if len(batch) >= BATCH_SIZE:
                updated += _process_batch(batch)
                processed += len(batch)
                batch = []
                # Periodic progress output (every 10 batches at BATCH_SIZE=5000).
                if processed % 50000 == 0:
                    self.stdout.write(f"  processed {processed}/{total}...")

        # Flush the final partial batch, if any.
        if batch:
            updated += _process_batch(batch)
            processed += len(batch)

        label = "mismatched (dry-run)" if dry_run else "updated"
        self.stdout.write(
            f"\nDone. Total: {total}, processed: {processed}, {label}: {updated}"
        )
1 change: 1 addition & 0 deletions comments/serializers/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class CommentFilterSerializer(serializers.Serializer):
focus_comment_id = serializers.IntegerField(required=False, allow_null=True)
is_private = serializers.BooleanField(required=False, allow_null=True)
include_deleted = serializers.BooleanField(required=False, allow_null=True)
last_viewed_at = serializers.DateTimeField(required=False, allow_null=True)

def validate_post(self, value: int):
try:
Expand Down
38 changes: 37 additions & 1 deletion comments/services/feed.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from datetime import datetime

from django.db.models import Q, Case, When, Value, IntegerField, Exists, OuterRef

from comments.models import Comment
Expand All @@ -13,6 +15,7 @@ def get_comments_feed(
is_private=None,
focus_comment_id: int = None,
include_deleted=False,
last_viewed_at: datetime = None,
):
user = user if user and user.is_authenticated else None
sort = sort or "-created_at"
Expand Down Expand Up @@ -40,6 +43,32 @@ def get_comments_feed(

order_by_args.append("-is_pinned_thread")

# Prioritize threads with unread comments for the current user
if last_viewed_at:
unread_comments = Comment.objects.filter(
on_post=post,
created_at__gt=last_viewed_at,
is_soft_deleted=False,
is_private=False,
)
unread_root_ids = {
root_id or comment_id
for root_id, comment_id in unread_comments.values_list("root_id", "id")
}

if unread_root_ids:
qs = qs.annotate(
has_unread_thread=Case(
When(
Q(pk__in=unread_root_ids) | Q(root_id__in=unread_root_ids),
then=Value(1),
),
default=Value(0),
output_field=IntegerField(),
),
)
order_by_args.append("-has_unread_thread")

if author is not None:
qs = qs.filter(author_id=author)

Expand Down Expand Up @@ -79,7 +108,14 @@ def get_comments_feed(
output_field=IntegerField(),
)
)
order_by_args.append("-is_focused_comment")
# Insert after pinned but before unread prioritization
# so focused comment always appears on the first page
pinned_idx = (
order_by_args.index("-is_pinned_thread") + 1
if "-is_pinned_thread" in order_by_args
else 0
)
order_by_args.insert(pinned_idx, "-is_focused_comment")

if sort:
if "vote_score" in sort:
Expand Down
17 changes: 9 additions & 8 deletions front_end/src/components/comment_feed/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -241,17 +241,18 @@ const CommentFeed: FC<Props> = ({

if (user?.id && postId) {
fetchUserComments(user.id);
// Send BE request that user has read the post
const handler = setTimeout(() => {
markPostAsRead(postId).then();
}, 200);

return () => {
clearTimeout(handler);
};
}
}, [postId, user?.id]);

// Mark post as read after initial comments load completes
const hasMarkedAsRead = useRef(false);
useEffect(() => {
if (!isLoading && postId && user?.id && !hasMarkedAsRead.current) {
hasMarkedAsRead.current = true;
markPostAsRead(postId);
}
}, [isLoading, postId, user?.id]);

const feedOptions: GroupButton<FeedOptions>[] = [
{
value: "public",
Expand Down
8 changes: 7 additions & 1 deletion posts/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,11 +1041,17 @@ class Meta:

@classmethod
def update_last_forecast_date(cls, post: Post, user: User):
    """Record that *user* just made a forecast on *post*.

    Updates ``last_forecast_date`` on the existing snapshot, or creates a
    new snapshot seeded with the post's current comment count and
    ``viewed_at`` set to now — presumably so a freshly created snapshot
    does not treat every existing comment as unread (NOTE(review): confirm
    against the unread-thread logic in comments/services/feed.py).
    """
    # Single timestamp so last_forecast_date and viewed_at agree on create.
    now = timezone.now()
    cls.objects.update_or_create(
        user=user,
        post=post,
        # Applied on both update and create paths.
        defaults={
            "last_forecast_date": now,
        },
        # Applied only when a new row is inserted (Django 5.0+ feature).
        create_defaults={
            "last_forecast_date": now,
            "comments_count": post.get_comment_count(),
            "viewed_at": now,
        },
    )

Expand Down
115 changes: 115 additions & 0 deletions tests/unit/test_comments/test_views.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import timedelta
from urllib.parse import urlencode

import pytest # noqa
from django.urls import reverse
Expand All @@ -14,6 +15,7 @@
KeyFactorNews,
)
from comments.services.feed import get_comments_feed
from posts.models import PostUserSnapshot
from questions.models import Forecast
from questions.services.forecasts import create_forecast
from tests.unit.test_comments.factories import factory_comment, factory_key_factor
Expand Down Expand Up @@ -842,3 +844,116 @@ def test_comment_edit_include_forecast_closed_question(
comment.refresh_from_db()
# Should attach forecast active at closure time
assert comment.included_forecast == forecast


class TestUnreadThreadPrioritization:
    """Test that threads with unread comments are prioritized on the first page
    and that ordering remains consistent across paginated requests even after
    markPostAsRead updates the snapshot."""

    @pytest.fixture()
    def setup(self, user1, user2):
        # Fixture layout: 8 root comments relative to viewed_at = now - 1h.
        #   - 3 "read" roots created ~3h ago with no new activity
        #   - 2 old roots whose REPLIES arrive after viewed_at (unread threads)
        #   - 3 roots created after viewed_at (unread threads)
        # freeze_time pins created_at so read/unread classification is exact.
        post = factory_post(author=user1)
        now = timezone.now()
        viewed_at = now - timedelta(hours=1)

        # Create 3 old root comments (before viewed_at) — all "read"
        old_roots = []
        for i in range(3):
            # minutes=i staggers creation times so -created_at ordering is stable
            with freeze_time(now - timedelta(hours=3, minutes=i)):
                c = factory_comment(author=user2, on_post=post, text=f"old_root_{i}")
            old_roots.append(c)

        # Create 2 old root comments that have NEW replies (after viewed_at)
        roots_with_new_replies = []
        for i in range(2):
            with freeze_time(now - timedelta(hours=3, minutes=10 + i)):
                root = factory_comment(
                    author=user2, on_post=post, text=f"root_with_reply_{i}"
                )
            # Reply lands after viewed_at, making the whole thread "unread"
            with freeze_time(now - timedelta(minutes=30 - i)):
                factory_comment(
                    author=user2,
                    on_post=post,
                    parent=root,
                    text=f"new_reply_{i}",
                )
            roots_with_new_replies.append(root)

        # Create 3 new root comments (after viewed_at) — all "unread"
        new_roots = []
        for i in range(3):
            with freeze_time(now - timedelta(minutes=20 - i)):
                c = factory_comment(author=user2, on_post=post, text=f"new_root_{i}")
            new_roots.append(c)

        return {
            "post": post,
            "viewed_at": viewed_at,
            "old_roots": old_roots,
            "roots_with_new_replies": roots_with_new_replies,
            "new_roots": new_roots,
        }

    def test_unread_threads_first_across_pages(self, user1_client, setup, user1):
        """With 8 root comments (5 unread threads, 3 read), page size 3:
        - Page 1 should contain 3 unread threads
        - Page 2 should contain 2 unread + 1 read thread
        - Page 3 should contain remaining 2 read threads
        Even after markPostAsRead fires between page 1 and page 2,
        the ordering should stay consistent because last_viewed_at is
        passed as a query param."""
        post = setup["post"]
        viewed_at = setup["viewed_at"]
        old_roots = setup["old_roots"]
        roots_with_new_replies = setup["roots_with_new_replies"]
        new_roots = setup["new_roots"]

        # Expected partition of root IDs into unread vs read threads.
        unread_root_ids = {c.pk for c in new_roots} | {
            c.pk for c in roots_with_new_replies
        }
        read_root_ids = {c.pk for c in old_roots}

        params = urlencode(
            {
                "post": post.pk,
                "limit": 3,
                "sort": "-created_at",
                "use_root_comments_pagination": "true",
                # Pinning last_viewed_at in the query string is what keeps
                # ordering stable across pages (see markPostAsRead below).
                "last_viewed_at": viewed_at.isoformat(),
            }
        )

        # Page 1: all 3 root comments should be from unread threads
        response = user1_client.get(f"/api/comments/?{params}")
        assert response.status_code == 200
        # Responses may include replies; keep only roots (parent_id is None).
        page1_root_ids = {
            r["id"] for r in response.data["results"] if r["parent_id"] is None
        }
        assert page1_root_ids.issubset(unread_root_ids)
        assert len(page1_root_ids) == 3

        # Simulate markPostAsRead (updates snapshot to now)
        PostUserSnapshot.update_viewed_at(post, user1)

        # Page 2: uses same last_viewed_at, so ordering is stable
        response = user1_client.get(response.data["next"])
        assert response.status_code == 200
        page2_root_ids = {
            r["id"] for r in response.data["results"] if r["parent_id"] is None
        }
        # Should have remaining 2 unread + 1 read
        assert (page2_root_ids & unread_root_ids) == unread_root_ids - page1_root_ids
        assert len(page2_root_ids & read_root_ids) >= 1

        # Page 3: remaining read threads
        response = user1_client.get(response.data["next"])
        assert response.status_code == 200
        page3_root_ids = {
            r["id"] for r in response.data["results"] if r["parent_id"] is None
        }
        assert page3_root_ids.issubset(read_root_ids)

        # All roots accounted for, no duplicates
        all_paged = page1_root_ids | page2_root_ids | page3_root_ids
        assert all_paged == unread_root_ids | read_root_ids
13 changes: 9 additions & 4 deletions users/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@
from projects.models import ProjectUserPermission
from questions.models import Forecast
from users.models import User, UserCampaignRegistration, UserSpamActivity
from users.services.common import (
clean_user_data_delete,
mark_user_as_spam,
soft_delete_user,
)
from users.services.spam_detection import (
CONFIDENCE_THRESHOLD,
check_profile_data_for_spam,
Expand Down Expand Up @@ -306,15 +311,15 @@ def bio_length(self, obj):

def mark_selected_as_spam(self, request, queryset: QuerySet[User]):
    """Admin action: mark each selected user as spam via the service layer."""
    for selected_user in queryset:
        mark_user_as_spam(selected_user)

def soft_delete_selected(self, request, queryset: QuerySet[User]):
    """Admin action: soft-delete each selected user via the service layer."""
    for selected_user in queryset:
        soft_delete_user(selected_user)

def clean_user_data_deletion(self, request, queryset: QuerySet[User]):
    """Admin action: run the personal-data deletion service on each selected user."""
    for selected_user in queryset:
        clean_user_data_delete(selected_user)

clean_user_data_deletion.short_description = (
"One click Personal Data deletion (GDPR compliant)"
Expand All @@ -329,7 +334,7 @@ def run_profile_spam_detection_on_selected(self, request, queryset: QuerySet[Use
)

if is_spam:
user.mark_as_spam()
mark_user_as_spam(user)
send_deactivation_email(user.email)

def get_fields(self, request, obj=None):
Expand Down
Loading
Loading