From f862b1b3ea9ab5dde47fa5afee37e95c6c409a2a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 28 Dec 2025 09:32:48 +0000 Subject: [PATCH 1/5] Initial plan From 8fc350a7dcdc52923e91622dce4624bfa69198dd Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 28 Dec 2025 09:39:41 +0000 Subject: [PATCH 2/5] fix: address code review comments - remove commented code, fix IP filtering, improve imports Co-authored-by: happylittle7 <7501374+happylittle7@users.noreply.github.com> --- submissions/views.py | 191 +++++++------------------------------------ 1 file changed, 30 insertions(+), 161 deletions(-) diff --git a/submissions/views.py b/submissions/views.py index 2ad1efa..07c3da5 100644 --- a/submissions/views.py +++ b/submissions/views.py @@ -13,6 +13,9 @@ import uuid from datetime import datetime import logging +import ipaddress + +logger = logging.getLogger(__name__) # 統一的 API 響應格式 def api_response(data=None, message="OK", status_code=200): @@ -110,130 +113,20 @@ def update_user_problem_stats(submission): stats.solve_status = 'fully_solved' elif stats.best_score > 0: stats.solve_status = 'partial_solved' - elif stats.total_submissions > 0: - stats.solve_status = 'attempted' else: - stats.solve_status = 'never_tried' + stats.solve_status = 'attempted' stats.save() - logger = logging.getLogger(__name__) logger.info(f'Updated solve status for user {submission.user.id} problem {submission.problem_id}: ' f'status={stats.solve_status}, best_score={stats.best_score}, ' f'total_submissions={stats.total_submissions}') except Exception as e: - logger = logging.getLogger(__name__) logger.error(f'Failed to update user problem solve status: {str(e)}', exc_info=True) -''' 這邊不再需要了,算成績的部分交給前端處理 -def update_user_assignment_stats(submission, assignment_id): - """ - 更新使用者作業題目統計(作業層級) - - 根據提交結果更新 UserProblemStats,包括: - - 總提交數 - - 最佳分數 - - 首次 AC 時間 - - 最後提交時間 - - 解題狀態(unsolved/partial/solved) - - 最佳執行時間和記憶體使用 - - 遲交處理 - """ - from django.utils import timezone - from django.db.models import F - from .models import UserProblemStats - from assignments.models import Assignments, Assignment_problems - - try: - # 檢查是否屬於作業 - if not assignment_id: - return - - # 檢查 assignment 和 problem 的關聯 - try: - assignment = Assignments.objects.get(id=assignment_id) - assignment_problem = Assignment_problems.objects.get( - assignment=assignment, - problem_id=submission.problem_id - ) - except (Assignments.DoesNotExist, Assignment_problems.DoesNotExist): - # 如果作業或題目不存在,跳過 - return - - # 取得或創建 UserProblemStats - stats, created = UserProblemStats.objects.get_or_create( - user=submission.user, - assignment_id=assignment_id, - problem_id=submission.problem_id, - defaults={ - 'total_submissions': 0, - 'best_score': 0, - 'max_possible_score': assignment_problem.weight * 100 if assignment_problem.weight else 100, - 'solve_status': 'unsolved', - } - ) - - # 更新總提交數 - stats.total_submissions = F('total_submissions') + 1 - - # 更新最後提交時間 - stats.last_submission_time = submission.created_at - - # 檢查是否遲交 - if assignment.due_time and submission.created_at > assignment.due_time: - stats.is_late = True - # 計算遲交罰分 - if assignment.late_penalty > 0: - stats.penalty_score = assignment.late_penalty - - # 先保存以計算 F() 表達式 - stats.save() - stats.refresh_from_db() - - # 更新最佳分數(考慮遲交罰分) - final_score = submission.score - if stats.is_late and stats.penalty_score > 0: - final_score = int(submission.score * (1 - float(stats.penalty_score) / 100)) - - if final_score > stats.best_score: - stats.best_score = final_score - stats.best_submission = submission - - # 如果是 AC (status='0') 且還沒有 first_ac_time - if submission.status == '0' and not stats.first_ac_time: - stats.first_ac_time = submission.judged_at or timezone.now() - - # 更新最佳執行時間 - if submission.execution_time > 0: - if stats.best_execution_time is None or submission.execution_time < stats.best_execution_time: - stats.best_execution_time = submission.execution_time - - # 更新最佳記憶體使用 - if submission.memory_usage > 0: - if stats.best_memory_usage is None or submission.memory_usage < stats.best_memory_usage: - stats.best_memory_usage = submission.memory_usage - - # 更新解題狀態 - if stats.best_score >= stats.max_possible_score: - stats.solve_status = 'solved' - elif stats.best_score > 0: - stats.solve_status = 'partial' - else: - stats.solve_status = 'unsolved' - - stats.save() - - logger = logging.getLogger(__name__) - logger.info(f'Updated assignment stats for user {submission.user.id} ' - f'assignment {assignment_id} problem {submission.problem_id}: ' - f'status={stats.solve_status}, best_score={stats.best_score}, ' - f'is_late={stats.is_late}') - - except Exception as e: - logger = logging.getLogger(__name__) - logger.error(f'Failed to update user assignment stats: {str(e)}', exc_info=True) -''' + + class BasePermissionMixin: """基礎權限檢查 Mixin - 提供通用權限檢查方法""" @@ -613,8 +506,6 @@ def editorial_like_toggle(request, problem_id, solution_id): except Exception as e: # 記錄完整錯誤但不回傳給客戶端 - import logging - logger = logging.getLogger(__name__) logger.error(f"Editorial update failed: {str(e)}", exc_info=True) return api_response( @@ -680,8 +571,6 @@ def get_queryset(self): problem_ids = Problems.objects.filter(course_id=course_id).values_list('id', flat=True) queryset = queryset.filter(problem_id__in=problem_ids) except (ValueError, TypeError) as e: - import logging - logger = logging.getLogger(__name__) logger.warning(f'Invalid course_id parameter: {course_id}, error: {e}') queryset = queryset.none() # 查詢失敗返回空結果 @@ -712,25 +601,34 @@ def get_queryset(self): try: # 支援 CIDR 格式 (例如 192.168.1.0/24) 或簡單前綴 (例如 192.168.) if '/' in ip_prefix: - # CIDR 格式:使用 ipaddress 模組 - import ipaddress + # CIDR 格式:使用 ipaddress 模組進行範圍檢查 network = ipaddress.ip_network(ip_prefix, strict=False) - # 獲取網段範圍 - start_ip = str(network.network_address) - end_ip = str(network.broadcast_address) - # 篩選 IP 在範圍內的提交 - queryset = queryset.filter( - ip_address__gte=start_ip, - ip_address__lte=end_ip - ) + + # 由於 SQLite 將 IP 存為字串,我們需要在 Python 層面進行過濾 + # 先獲取所有可能的 IP(基於前綴的簡單過濾以減少數據量) + network_prefix = str(network.network_address).rsplit('.', 2)[0] # 取得網段的前綴部分 + candidate_submissions = queryset.filter(ip_address__startswith=network_prefix) + + # 在 Python 中過濾出符合 CIDR 範圍的提交 + valid_submission_ids = [] + for submission in candidate_submissions: + try: + ip = ipaddress.ip_address(submission.ip_address) + if ip in network: + valid_submission_ids.append(submission.id) + except (ValueError, AttributeError): + # 忽略無效的 IP 地址 + pass + + queryset = queryset.filter(id__in=valid_submission_ids) else: # 簡單前綴匹配:例如 "192.168." 會匹配所有 192.168.x.x queryset = queryset.filter(ip_address__startswith=ip_prefix) except (ValueError, TypeError) as e: - import logging - logger = logging.getLogger(__name__) logger.warning(f'Invalid ip_prefix parameter: {ip_prefix}, error: {e}') - pass # 忽略無效的 IP 前綴 + # 返回錯誤而不是靜默忽略 + from rest_framework.exceptions import ValidationError + raise ValidationError({'ip_prefix': f'Invalid IP prefix format: {ip_prefix}'}) return self.get_viewable_submissions(self.request.user, queryset) @@ -760,8 +658,6 @@ def post(self, request, *args, **kwargs): try: # Debug logging - import logging - logger = logging.getLogger(__name__) logger.info(f'POST /submission/ data: {request.data}') serializer = self.get_serializer(data=request.data) @@ -912,8 +808,6 @@ def post(self, request, *args, **kwargs): except Exception as e: # 其他系統錯誤 - import logging - logger = logging.getLogger(__name__) logger.error(f"提交創建失敗: {str(e)}", exc_info=True) return api_response(data=None, message=f"系統錯誤: {str(e)[:200]}", status_code=status.HTTP_400_BAD_REQUEST) @@ -1289,12 +1183,8 @@ def submission_rejudge(request, id): from .tasks import submit_to_sandbox_task try: submit_to_sandbox_task.delay(str(submission.id)) - import logging - logger = logging.getLogger(__name__) logger.info(f'Rejudge queued for submission: {submission.id}') except Exception as e: - import logging - logger = logging.getLogger(__name__) logger.error(f'Failed to queue rejudge for {submission.id}: {str(e)}') # 即使 celery 失敗,也回傳成功(submission 已經重設為 pending) @@ -1485,8 +1375,6 @@ def get_diff_count(name): ) redis_client.ping() # 測試連接 except Exception as e: - import logging - logger = logging.getLogger(__name__) logger.error(f'Redis connection failed: {str(e)}') redis_client = None @@ -1668,8 +1556,6 @@ def submit_custom_test(request, problem_id): ) except Exception as e: - import logging - logger = logging.getLogger(__name__) logger.error(f'Custom test submission error: {str(e)}') return api_response( data=None, @@ -1804,8 +1690,6 @@ def get_custom_test_result(request, custom_test_id): status_code=status.HTTP_500_INTERNAL_SERVER_ERROR ) except Exception as e: - import logging - logger = logging.getLogger(__name__) logger.error(f'Get custom test result error: {str(e)}') return api_response( data=None, @@ -1823,7 +1707,7 @@ class SubmissionCallbackAPIView(APIView): 接收 Sandbox 判題結果的 callback endpoint Sandbox 判題完成後會 POST 到這個 endpoint - URL: POST /{callback_url}/submissions/callback/ + URL: POST /api/submissions/callback/ """ permission_classes = [permissions.AllowAny] # Sandbox 不需要 JWT 認證 @@ -1853,12 +1737,9 @@ def post(self, request): ] } """ - import logging from .models import Submission, SubmissionResult from django.conf import settings - logger = logging.getLogger(__name__) - try: # 1. 驗證請求來源(API Key) api_key = request.headers.get('X-API-KEY') @@ -1939,16 +1820,7 @@ def post(self, request): # 5. 更新 UserProblemSolveStatus(全域層級) update_user_problem_stats(submission) - '''這邊也不需要了,算成績的部分交給前端處理 - - # 6. 更新 UserProblemStats(作業層級) - # 嘗試從 submission 找到對應的 assignment_id - # submission model 需要加上 assignment_id 欄位,或者從 context 傳入 - # 目前先跳過,等 submission model 更新後再啟用 - # if hasattr(submission, 'assignment_id') and submission.assignment_id: - # update_user_assignment_stats(submission, submission.assignment_id) - - ''' + return api_response( data={'submission_id': str(submission_id)}, message='Callback processed successfully', @@ -1989,12 +1861,9 @@ def post(self, request): "exit_code": 0 } """ - import logging from .models import CustomTest from django.conf import settings - logger = logging.getLogger(__name__) - try: # 1. 驗證請求來源(API Key) api_key = request.headers.get('X-API-KEY') From 0565ee77b9383b1028668b6307924b74faacb75a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 28 Dec 2025 09:41:13 +0000 Subject: [PATCH 3/5] test: add comprehensive tests for IP filtering and user stats updates Co-authored-by: happylittle7 <7501374+happylittle7@users.noreply.github.com> --- .../test_file/test_ip_filtering_and_stats.py | 444 ++++++++++++++++++ 1 file changed, 444 insertions(+) create mode 100644 submissions/test_file/test_ip_filtering_and_stats.py diff --git a/submissions/test_file/test_ip_filtering_and_stats.py b/submissions/test_file/test_ip_filtering_and_stats.py new file mode 100644 index 0000000..2f761bc --- /dev/null +++ b/submissions/test_file/test_ip_filtering_and_stats.py @@ -0,0 +1,444 @@ +# submissions/test_file/test_ip_filtering_and_stats.py +""" +Tests for IP filtering functionality and user problem stats updates +""" + +import pytest +import uuid +from datetime import datetime, timezone as dt_timezone +from unittest.mock import patch, Mock + +from django.test import TestCase +from django.contrib.auth import get_user_model +from django.urls import reverse +from rest_framework.test import APITestCase, APIClient +from rest_framework import status + +from ..models import Submission, UserProblemSolveStatus +from ..views import update_user_problem_stats +from problems.models import Problems +from courses.models import Courses + +User = get_user_model() + + +@pytest.mark.django_db +class TestIPFilteringAPI(APITestCase): + """Test IP filtering functionality in submission list API""" + + @classmethod + def setUpTestData(cls): + """Set up test data for IP filtering tests""" + from user.models import UserProfile + + # Create test user + cls.user = User.objects.create_user( + username='ip_test_user', + email='ip_test@test.com', + password='testpass123' + ) + profile, _ = UserProfile.objects.get_or_create(user=cls.user) + profile.email_verified = True + profile.save() + + # Create test problem + cls.course = Courses.objects.create( + course_name='Test Course', + year=2024, + semester=1, + teacher_id=cls.user.id + ) + + cls.problem = Problems.objects.create( + problem_name='Test Problem', + time_limit=1000, + memory_limit=256000, + course_id=cls.course.id + ) + + # Create submissions with different IPs + cls.submission1 = Submission.objects.create( + user=cls.user, + problem_id=cls.problem.id, + language_type=1, + source_code='print("test")', + ip_address='192.168.1.10', + status='-1', + score=0 + ) + + cls.submission2 = Submission.objects.create( + user=cls.user, + problem_id=cls.problem.id, + language_type=1, + source_code='print("test2")', + ip_address='192.168.1.20', + status='-1', + score=0 + ) + + cls.submission3 = Submission.objects.create( + user=cls.user, + problem_id=cls.problem.id, + language_type=1, + source_code='print("test3")', + ip_address='192.168.2.10', + status='-1', + score=0 + ) + + cls.submission4 = Submission.objects.create( + user=cls.user, + problem_id=cls.problem.id, + language_type=1, + source_code='print("test4")', + ip_address='10.0.0.5', + status='-1', + score=0 + ) + + def setUp(self): + """Set up for each test""" + self.client = APIClient() + self.client.force_authenticate(user=self.user) + + def test_simple_prefix_filtering(self): + """Test simple IP prefix filtering (e.g., '192.168.1.')""" + response = self.client.get('/submission/', {'ip_prefix': '192.168.1.'}) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + data = response.json() + + # Should only return submissions with IPs starting with 192.168.1. + submission_ids = [item['id'] for item in data['data']] + self.assertIn(str(self.submission1.id), submission_ids) + self.assertIn(str(self.submission2.id), submission_ids) + self.assertNotIn(str(self.submission3.id), submission_ids) + self.assertNotIn(str(self.submission4.id), submission_ids) + + def test_cidr_filtering(self): + """Test CIDR notation IP filtering (e.g., '192.168.1.0/24')""" + response = self.client.get('/submission/', {'ip_prefix': '192.168.1.0/24'}) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + data = response.json() + + # Should return submissions in the 192.168.1.0/24 range + submission_ids = [item['id'] for item in data['data']] + self.assertIn(str(self.submission1.id), submission_ids) + self.assertIn(str(self.submission2.id), submission_ids) + self.assertNotIn(str(self.submission3.id), submission_ids) + self.assertNotIn(str(self.submission4.id), submission_ids) + + def test_cidr_filtering_larger_network(self): + """Test CIDR filtering with larger network (e.g., '192.168.0.0/16')""" + response = self.client.get('/submission/', {'ip_prefix': '192.168.0.0/16'}) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + data = response.json() + + # Should return all submissions in the 192.168.0.0/16 range + submission_ids = [item['id'] for item in data['data']] + self.assertIn(str(self.submission1.id), submission_ids) + self.assertIn(str(self.submission2.id), submission_ids) + self.assertIn(str(self.submission3.id), submission_ids) + self.assertNotIn(str(self.submission4.id), submission_ids) + + def test_invalid_ip_prefix_returns_error(self): + """Test that invalid IP prefix returns validation error""" + response = self.client.get('/submission/', {'ip_prefix': 'invalid_ip'}) + + # Should return 400 Bad Request with validation error + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + data = response.json() + self.assertIn('ip_prefix', str(data)) + + def test_invalid_cidr_returns_error(self): + """Test that invalid CIDR notation returns validation error""" + response = self.client.get('/submission/', {'ip_prefix': '192.168.1.0/99'}) + + # Should return 400 Bad Request + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_no_ip_prefix_returns_all(self): + """Test that without ip_prefix parameter, all submissions are returned""" + response = self.client.get('/submission/') + + self.assertEqual(response.status_code, status.HTTP_200_OK) + data = response.json() + + # Should return all submissions + submission_ids = [item['id'] for item in data['data']] + self.assertIn(str(self.submission1.id), submission_ids) + self.assertIn(str(self.submission2.id), submission_ids) + self.assertIn(str(self.submission3.id), submission_ids) + self.assertIn(str(self.submission4.id), submission_ids) + + +@pytest.mark.django_db +class TestUserProblemStatsUpdate(TestCase): + """Test update_user_problem_stats function""" + + def setUp(self): + """Set up test data""" + from user.models import UserProfile + + # Create test user + self.user = User.objects.create_user( + username='stats_test_user', + email='stats_test@test.com', + password='testpass123' + ) + profile, _ = UserProfile.objects.get_or_create(user=self.user) + profile.email_verified = True + profile.save() + + # Create test problem + self.course = Courses.objects.create( + course_name='Test Course', + year=2024, + semester=1, + teacher_id=self.user.id + ) + + self.problem = Problems.objects.create( + problem_name='Test Problem', + time_limit=1000, + memory_limit=256000, + course_id=self.course.id + ) + + def test_first_submission_creates_stats(self): + """Test that first submission creates UserProblemSolveStatus""" + submission = Submission.objects.create( + user=self.user, + problem_id=self.problem.id, + language_type=1, + source_code='print("test")', + ip_address='127.0.0.1', + status='1', # WA + score=0, + execution_time=100, + memory_usage=1024 + ) + + update_user_problem_stats(submission) + + stats = UserProblemSolveStatus.objects.get( + user=self.user, + problem_id=self.problem.id + ) + + self.assertEqual(stats.total_submissions, 1) + self.assertEqual(stats.ac_submissions, 0) + self.assertEqual(stats.best_score, 0) + self.assertEqual(stats.solve_status, 'attempted') + self.assertIsNone(stats.first_solve_time) + + def test_ac_submission_updates_stats(self): + """Test that AC submission updates stats correctly""" + submission = Submission.objects.create( + user=self.user, + problem_id=self.problem.id, + language_type=1, + source_code='print("test")', + ip_address='127.0.0.1', + status='0', # AC + score=100, + execution_time=100, + memory_usage=1024, + judged_at=datetime.now(dt_timezone.utc) + ) + + update_user_problem_stats(submission) + + stats = UserProblemSolveStatus.objects.get( + user=self.user, + problem_id=self.problem.id + ) + + self.assertEqual(stats.total_submissions, 1) + self.assertEqual(stats.ac_submissions, 1) + self.assertEqual(stats.best_score, 100) + self.assertEqual(stats.solve_status, 'fully_solved') + self.assertIsNotNone(stats.first_solve_time) + + def test_partial_score_updates_status(self): + """Test that partial score updates solve status to 'partial_solved'""" + submission = Submission.objects.create( + user=self.user, + problem_id=self.problem.id, + language_type=1, + source_code='print("test")', + ip_address='127.0.0.1', + status='1', # WA + score=50, + execution_time=100, + memory_usage=1024 + ) + + update_user_problem_stats(submission) + + stats = UserProblemSolveStatus.objects.get( + user=self.user, + problem_id=self.problem.id + ) + + self.assertEqual(stats.best_score, 50) + self.assertEqual(stats.solve_status, 'partial_solved') + + def test_multiple_submissions_update_correctly(self): + """Test that multiple submissions update stats correctly""" + # First submission - low score + submission1 = Submission.objects.create( + user=self.user, + problem_id=self.problem.id, + language_type=1, + source_code='print("test1")', + ip_address='127.0.0.1', + status='1', # WA + score=30, + execution_time=150, + memory_usage=2048 + ) + update_user_problem_stats(submission1) + + # Second submission - higher score + submission2 = Submission.objects.create( + user=self.user, + problem_id=self.problem.id, + language_type=1, + source_code='print("test2")', + ip_address='127.0.0.1', + status='1', # WA + score=70, + execution_time=100, + memory_usage=1024 + ) + update_user_problem_stats(submission2) + + stats = UserProblemSolveStatus.objects.get( + user=self.user, + problem_id=self.problem.id + ) + + self.assertEqual(stats.total_submissions, 2) + self.assertEqual(stats.best_score, 70) + self.assertEqual(stats.best_execution_time, 100) + self.assertEqual(stats.best_memory_usage, 1024) + self.assertEqual(stats.solve_status, 'partial_solved') + + def test_best_execution_time_updated(self): + """Test that best execution time is tracked correctly""" + submission1 = Submission.objects.create( + user=self.user, + problem_id=self.problem.id, + language_type=1, + source_code='print("test1")', + ip_address='127.0.0.1', + status='0', # AC + score=100, + execution_time=200, + memory_usage=1024 + ) + update_user_problem_stats(submission1) + + submission2 = Submission.objects.create( + user=self.user, + problem_id=self.problem.id, + language_type=1, + source_code='print("test2")', + ip_address='127.0.0.1', + status='0', # AC + score=100, + execution_time=100, + memory_usage=1024 + ) + update_user_problem_stats(submission2) + + stats = UserProblemSolveStatus.objects.get( + user=self.user, + problem_id=self.problem.id + ) + + self.assertEqual(stats.best_execution_time, 100) + + def test_invalid_execution_time_ignored(self): + """Test that invalid (zero or negative) execution times are ignored""" + submission = Submission.objects.create( + user=self.user, + problem_id=self.problem.id, + language_type=1, + source_code='print("test")', + ip_address='127.0.0.1', + status='0', # AC + score=100, + execution_time=0, # Invalid + memory_usage=1024 + ) + + update_user_problem_stats(submission) + + stats = UserProblemSolveStatus.objects.get( + user=self.user, + problem_id=self.problem.id + ) + + # execution_time should not be updated since it's 0 + self.assertIsNone(stats.best_execution_time) + + def test_solve_status_transitions(self): + """Test that solve status transitions correctly""" + # Start with attempted + submission1 = Submission.objects.create( + user=self.user, + problem_id=self.problem.id, + language_type=1, + source_code='print("test")', + ip_address='127.0.0.1', + status='1', # WA + score=0, + execution_time=100, + memory_usage=1024 + ) + update_user_problem_stats(submission1) + + stats = UserProblemSolveStatus.objects.get( + user=self.user, + problem_id=self.problem.id + ) + self.assertEqual(stats.solve_status, 'attempted') + + # Move to partial_solved + submission2 = Submission.objects.create( + user=self.user, + problem_id=self.problem.id, + language_type=1, + source_code='print("test2")', + ip_address='127.0.0.1', + status='1', # WA + score=50, + execution_time=100, + memory_usage=1024 + ) + update_user_problem_stats(submission2) + + stats.refresh_from_db() + self.assertEqual(stats.solve_status, 'partial_solved') + + # Move to fully_solved + submission3 = Submission.objects.create( + user=self.user, + problem_id=self.problem.id, + language_type=1, + source_code='print("test3")', + ip_address='127.0.0.1', + status='0', # AC + score=100, + execution_time=100, + memory_usage=1024 + ) + update_user_problem_stats(submission3) + + stats.refresh_from_db() + self.assertEqual(stats.solve_status, 'fully_solved') From 03fddbfda4bf4da7a890aac8c0f9bc58d5cc755b Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 28 Dec 2025 09:42:47 +0000 Subject: [PATCH 4/5] fix: improve network prefix calculation and add missing default field Co-authored-by: happylittle7 <7501374+happylittle7@users.noreply.github.com> --- submissions/views.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/submissions/views.py b/submissions/views.py index 07c3da5..fd28a11 100644 --- a/submissions/views.py +++ b/submissions/views.py @@ -72,6 +72,7 @@ def update_user_problem_stats(submission): 'ac_submissions': 0, 'best_score': 0, 'solve_status': 'never_tried', + 'total_execution_time': 0, } ) @@ -606,8 +607,20 @@ def get_queryset(self): # 由於 SQLite 將 IP 存為字串,我們需要在 Python 層面進行過濾 # 先獲取所有可能的 IP(基於前綴的簡單過濾以減少數據量) - network_prefix = str(network.network_address).rsplit('.', 2)[0] # 取得網段的前綴部分 - candidate_submissions = queryset.filter(ip_address__startswith=network_prefix) + # 取得網段的前 1-3 個八位元組作為前綴(例如:192.168.1.0/24 -> "192.168") + network_parts = str(network.network_address).split('.') + if network.prefixlen >= 16: + network_prefix = '.'.join(network_parts[:2]) # 取前兩個八位元組 + elif network.prefixlen >= 8: + network_prefix = network_parts[0] # 取第一個八位元組 + else: + network_prefix = '' # 不使用前綴過濾,檢查所有 IP + + if network_prefix: + candidate_submissions = queryset.filter(ip_address__startswith=network_prefix) + else: + candidate_submissions = queryset + # 在 Python 中過濾出符合 CIDR 範圍的提交 valid_submission_ids = [] From 4fea17c3a08db7c1d8e571b0b033da14ba721cf2 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 28 Dec 2025 09:44:11 +0000 Subject: [PATCH 5/5] perf: optimize CIDR filtering with better prefix matching and values_list Co-authored-by: happylittle7 <7501374+happylittle7@users.noreply.github.com> --- submissions/views.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/submissions/views.py b/submissions/views.py index fd28a11..fe9239b 100644 --- a/submissions/views.py +++ b/submissions/views.py @@ -607,28 +607,33 @@ def get_queryset(self): # 由於 SQLite 將 IP 存為字串,我們需要在 Python 層面進行過濾 # 先獲取所有可能的 IP(基於前綴的簡單過濾以減少數據量) - # 取得網段的前 1-3 個八位元組作為前綴(例如:192.168.1.0/24 -> "192.168") + # 取得網段的前綴(例如:192.168.1.0/24 -> "192.168.1") network_parts = str(network.network_address).split('.') - if network.prefixlen >= 16: - network_prefix = '.'.join(network_parts[:2]) # 取前兩個八位元組 + if network.prefixlen >= 24: + network_prefix = '.'.join(network_parts[:3]) + '.' # 取前三個八位元組 + elif network.prefixlen >= 16: + network_prefix = '.'.join(network_parts[:2]) + '.' # 取前兩個八位元組 elif network.prefixlen >= 8: - network_prefix = network_parts[0] # 取第一個八位元組 + network_prefix = network_parts[0] + '.' # 取第一個八位元組(加點避免誤匹配) else: network_prefix = '' # 不使用前綴過濾,檢查所有 IP if network_prefix: - candidate_submissions = queryset.filter(ip_address__startswith=network_prefix) + # 使用 values_list 只獲取需要的欄位以提高效能 + candidate_data = queryset.filter( + ip_address__startswith=network_prefix + ).values_list('id', 'ip_address') else: - candidate_submissions = queryset + candidate_data = queryset.values_list('id', 'ip_address') # 在 Python 中過濾出符合 CIDR 範圍的提交 valid_submission_ids = [] - for submission in candidate_submissions: + for sub_id, ip_addr in candidate_data: try: - ip = ipaddress.ip_address(submission.ip_address) + ip = ipaddress.ip_address(ip_addr) if ip in network: - valid_submission_ids.append(submission.id) + valid_submission_ids.append(sub_id) except (ValueError, AttributeError): # 忽略無效的 IP 地址 pass