diff --git a/main.py b/main.py index 819c8733..f10c5b7f 100644 --- a/main.py +++ b/main.py @@ -489,7 +489,7 @@ def push_rules( log.info(f"Folder {sanitize_for_log(folder_name)} - no new rules to push after filtering duplicates") return True - total_batches = len(range(0, len(filtered_hostnames), BATCH_SIZE)) + total_batches = (len(filtered_hostnames) + BATCH_SIZE - 1) // BATCH_SIZE # Helper for processing a single batch def _push_batch(batch_idx: int, start_idx: int) -> bool: @@ -532,7 +532,20 @@ def _push_batch(batch_idx: int, start_idx: int) -> bool: } for future in concurrent.futures.as_completed(futures): - if future.result(): + batch_idx = futures[future] + try: + result = future.result() + except Exception as e: + log.error( + "Unexpected error while processing batch %d for folder %s: %s", + batch_idx, + sanitize_for_log(folder_name), + sanitize_for_log(e), + ) + log.debug("Unexpected exception details", exc_info=True) + continue + + if result: successful_batches += 1 if successful_batches == total_batches: diff --git a/tests/test_performance.py b/tests/test_performance.py index 24d2290d..4217799f 100644 --- a/tests/test_performance.py +++ b/tests/test_performance.py @@ -2,8 +2,8 @@ from unittest.mock import MagicMock, patch import time import threading -import main from main import push_rules, BATCH_SIZE +import httpx class TestPushRulesPerformance(unittest.TestCase): def setUp(self): @@ -16,7 +16,7 @@ def setUp(self): self.existing_rules = set() @patch('main._api_post_form') - def test_push_rules_correctness_with_lock(self, mock_post): + def test_push_rules_parallel_with_lock(self, mock_post): # Create enough hostnames for 5 batches num_batches = 5 hostnames = [f"host-{i}.com" for i in range(BATCH_SIZE * num_batches)] @@ -44,11 +44,11 @@ def test_push_rules_correctness_with_lock(self, mock_post): self.assertEqual(mock_post.call_count, num_batches) self.assertEqual(len(self.existing_rules), len(hostnames)) - print(f"\n[Sequential Baseline (Lock)] Duration: {duration:.4f}s") + print(f"\n[Parallel with Lock] Duration: {duration:.4f}s") @patch('main._api_post_form') def test_push_rules_concurrency(self, mock_post): - # Create enough hostnames for 5 batches + # Create enough hostnames for 10 batches num_batches = 10 hostnames = [f"host-{i}.com" for i in range(BATCH_SIZE * num_batches)] @@ -77,5 +77,42 @@ def delayed_post(*args, **kwargs): print(f"\n[Performance Test] Duration for {num_batches} batches with 0.1s latency: {duration:.4f}s") + @patch('main._api_post_form') + def test_push_rules_partial_failure(self, mock_post): + # Create enough hostnames for 5 batches + num_batches = 5 + hostnames = [f"host-{i}.com" for i in range(BATCH_SIZE * num_batches)] + + # Mock failure for some batches + call_count = 0 + def partial_failure(*args, **kwargs): + nonlocal call_count + call_count += 1 + # Fail batches 2 and 4 + if call_count in [2, 4]: + raise httpx.HTTPError("Simulated API failure") + return MagicMock(status_code=200) + + mock_post.side_effect = partial_failure + + success = push_rules( + self.profile_id, + self.folder_name, + self.folder_id, + self.do, + self.status, + hostnames, + self.existing_rules, + self.client + ) + + # Should return False when some batches fail + self.assertFalse(success) + self.assertEqual(mock_post.call_count, num_batches) + # Only 3 batches should have succeeded and updated existing_rules + self.assertEqual(len(self.existing_rules), BATCH_SIZE * 3) + + print(f"\n[Partial Failure Test] {mock_post.call_count} batches attempted, 3 succeeded") + if __name__ == '__main__': unittest.main()