diff --git a/scraper.py b/scraper.py index bf44ac30d..da08e6bb8 100644 --- a/scraper.py +++ b/scraper.py @@ -1916,10 +1916,18 @@ async def build_certificate_artifacts( follow_redirects=True, timeout=timeout, ) as client: - tasks = [] - for index, module in enumerate(modules): + pending_tasks: Set[asyncio.Task] = set() + next_index = 0 + task_window = max(CERT_FETCH_CONCURRENCY, PDF_FETCH_CONCURRENCY) + + def schedule_next_certificate() -> None: + nonlocal next_index + if next_index >= len(modules): + return + index = next_index + module = modules[index] cert_number = parse_certificate_number(module) - tasks.append( + pending_tasks.add( asyncio.create_task( process_certificate_record_with_timeout( index, @@ -1939,24 +1947,34 @@ async def build_certificate_artifacts( ) ) ) + next_index += 1 + + total = len(modules) + for _ in range(min(task_window, total)): + schedule_next_certificate() - total = len(tasks) completed = 0 - for task in asyncio.as_completed(tasks): - index, module_out, detail_payload, categories, task_stats = await task - completed += 1 - results[index] = module_out - cert_number = parse_certificate_number(module_out) - if cert_number is not None and detail_payload is not None: - payloads[cert_number] = detail_payload - if cert_number is not None and categories: - algorithms_map[cert_number] = categories - add_processing_stats(stats, task_stats) - if completed % 100 == 0 or completed == total: - print( - f" Progress: {completed}/{total} " - f"({stats['html_reused']} reused, {stats['html_refreshed']} refreshed, {stats['html_failed']} failed)" - ) + while pending_tasks: + done, pending_tasks = await asyncio.wait( + pending_tasks, + return_when=asyncio.FIRST_COMPLETED, + ) + for task in done: + index, module_out, detail_payload, categories, task_stats = await task + completed += 1 + results[index] = module_out + cert_number = parse_certificate_number(module_out) + if cert_number is not None and detail_payload is not None: + payloads[cert_number] = detail_payload + if cert_number is not None and categories: + algorithms_map[cert_number] = categories + add_processing_stats(stats, task_stats) + schedule_next_certificate() + if completed % 100 == 0 or completed == total: + print( + f" Progress: {completed}/{total} " + f"({stats['html_reused']} reused, {stats['html_refreshed']} refreshed, {stats['html_failed']} failed)" + ) return [result or {} for result in results], payloads, algorithms_map, stats diff --git a/test_scraper.py b/test_scraper.py index 23b010a34..dde9245f6 100644 --- a/test_scraper.py +++ b/test_scraper.py @@ -15,6 +15,7 @@ ALGORITHM_CACHE_VERSION, ALGORITHM_EXTRACTION_SCHEMA_VERSION, build_algorithm_extraction_provenance, + build_certificate_artifacts, build_certificate_index_payload, build_certificate_fingerprint, build_extraction_metrics, @@ -944,6 +945,100 @@ async def slow_process(*args, **kwargs): print("✓ Certificate timeout fallback test passed") +def test_build_certificate_artifacts_bounds_active_tasks(): + """Certificate artifact scheduling should not start every module task at once.""" + modules = [ + { + "Certificate Number": str(5200 + index), + "Vendor Name": "Example Vendor", + "Module Name": f"Example Module {index}", + "Module Type": "Software", + "Validation Date": "04/10/2026", + "security_policy_url": f"https://csrc.nist.gov/example/{index}.pdf", + "certificate_detail_url": f"https://csrc.nist.gov/projects/cryptographic-module-validation-program/certificate/{5200 + index}", + } + for index in range(8) + ] + + running = 0 + max_running = 0 + started = [] + + async def fake_process( + index, + module, + dataset, + generated_at, + algorithm_source, + previous_module, + previous_detail, + previous_metadata, + client, + cert_semaphore, + pdf_semaphore, + pdf_cache, + pdf_cache_lock, + database_algorithms_map, + ): + nonlocal running, max_running + running += 1 + max_running = max(max_running, running) + started.append(index) + try: + await asyncio.sleep(0.01 * (len(modules) - index)) + module_out = dict(module) + module_out["detail_available"] = True + cert_number = module["Certificate Number"] + detail_payload = { + "certificate_number": cert_number, + "dataset": dataset, + "generated_at": generated_at, + "nist_page_url": module["certificate_detail_url"], + "certificate_detail_url": module["certificate_detail_url"], + "security_policy_url": module["security_policy_url"], + "vendor_name": module["Vendor Name"], + "module_name": module["Module Name"], + "standard": "FIPS 140-3", + "status": "Active", + "related_files": [], + "validation_history": [], + "vendor": {}, + } + return index, module_out, detail_payload, ["AES"], {"html_refreshed": 1} + finally: + running -= 1 + + original_process = scraper_module.process_certificate_record_with_timeout + original_cert_concurrency = scraper_module.CERT_FETCH_CONCURRENCY + original_pdf_concurrency = scraper_module.PDF_FETCH_CONCURRENCY + scraper_module.process_certificate_record_with_timeout = fake_process + scraper_module.CERT_FETCH_CONCURRENCY = 2 + scraper_module.PDF_FETCH_CONCURRENCY = 3 + try: + enriched, payloads, algorithms_map, stats = asyncio.run( + build_certificate_artifacts( + modules, + "active", + "2026-04-12T03:10:00.961597Z", + "crawl4ai", + {"metadata": {}, "modules": {"active": {}}, "details": {}}, + ) + ) + finally: + scraper_module.process_certificate_record_with_timeout = original_process + scraper_module.CERT_FETCH_CONCURRENCY = original_cert_concurrency + scraper_module.PDF_FETCH_CONCURRENCY = original_pdf_concurrency + + assert max_running <= 3, "Certificate scheduler should bound active tasks to the concurrency window" + assert sorted(started) == list(range(len(modules))), "Every module should be scheduled once" + assert [record["Certificate Number"] for record in enriched] == [module["Certificate Number"] for module in modules], "Output order should match input order" + assert len(payloads) == len(modules), "Each fake detail payload should be retained" + assert len(algorithms_map) == len(modules), "Each fake algorithm payload should be indexed" + assert stats["html_refreshed"] == len(modules), "Stats should accumulate from bounded tasks" + + print("✓ Bounded certificate scheduling test passed") + + def test_prune_orphan_certificate_details(): """Test that stale certificate detail files are removed only for missing certs.""" with tempfile.TemporaryDirectory() as temp_dir: @@ -1267,6 +1362,7 @@ def main(): test_fetch_policy_pdf_bytes_reuses_in_run_cache() test_process_certificate_record_applies_cached_algorithm_provenance() test_process_certificate_record_timeout_preserves_cached_data() + test_build_certificate_artifacts_bounds_active_tasks() test_prune_orphan_certificate_details() test_validate_generated_api_artifacts() test_build_certificate_index_payload()