diff --git a/application/cmd/cre_main.py b/application/cmd/cre_main.py index ead5a428..acc9d9b3 100644 --- a/application/cmd/cre_main.py +++ b/application/cmd/cre_main.py @@ -31,6 +31,9 @@ logger.setLevel(logging.INFO) app = None +DEFAULT_UPSTREAM_API_URL = "https://opencre.org/rest/v1" +UPSTREAM_SYNC_REQUEST_TIMEOUT_SECONDS = 30 +UPSTREAM_SYNC_MAP_ANALYSIS_MAX_PAIRS_ENV = "CRE_UPSTREAM_SYNC_MAX_MAP_ANALYSIS_PAIRS" def register_node(node: defs.Node, collection: db.Node_collection) -> db.Node: @@ -461,14 +464,244 @@ def review_from_spreadsheet(cache: str, spreadsheet_url: str, share_with: str) - # logger.info("A spreadsheet view is at %s" % sheet_url) +def _upstream_api_url() -> str: + return os.environ.get("CRE_UPSTREAM_API_URL", DEFAULT_UPSTREAM_API_URL).rstrip("/") + + +def _progressively_sync_weak_links_for_pair( + collection: db.Node_collection, + upstream_api_url: str, + base_standard: str, + compare_standard: str, + result_payload: Dict[str, Any], +) -> Tuple[int, int]: + weak_attempted = 0 + weak_synced = 0 + + for key, value in result_payload.items(): + if not isinstance(key, str) or not isinstance(value, dict): + continue + + extra = value.get("extra") + try: + extra = int(extra) if extra is not None else 0 + except (TypeError, ValueError): + extra = 0 + if extra <= 0: + continue + + weak_cache_key = gap_analysis.make_subresources_key( + standards=[base_standard, compare_standard], key=key + ) + if collection.gap_analysis_exists(weak_cache_key): + continue + + weak_attempted += 1 + try: + weak_response = requests.get( + f"{upstream_api_url}/map_analysis_weak_links", + params=[ + ("standard", base_standard), + ("standard", compare_standard), + ("key", key), + ], + timeout=UPSTREAM_SYNC_REQUEST_TIMEOUT_SECONDS, + ) + except requests.RequestException as exc: + logger.warning( + "Could not sync weak links for %s >> %s (key=%s): %s", + base_standard, + compare_standard, + key, + exc, + ) + continue + if weak_response.status_code != 200: + logger.info( + "Skipping weak links for %s >> %s (key=%s) from upstream (status=%s)", + base_standard, + compare_standard, + key, + weak_response.status_code, + ) + continue + + try: + weak_payload = weak_response.json() + except ValueError: + logger.warning( + "Skipping weak links for %s >> %s (key=%s) due to invalid JSON payload", + base_standard, + compare_standard, + key, + ) + continue + if not isinstance(weak_payload, dict) or weak_payload.get("result") is None: + continue + + collection.add_gap_analysis_result( + cache_key=weak_cache_key, + ga_object=json.dumps({"result": weak_payload.get("result")}), + ) + weak_synced += 1 + + return weak_attempted, weak_synced + + +def _progressively_sync_gap_analysis_from_upstream( + collection: db.Node_collection, upstream_api_url: str +) -> None: + max_pairs_raw = os.environ.get(UPSTREAM_SYNC_MAP_ANALYSIS_MAX_PAIRS_ENV, "0") + try: + max_pairs = int(max_pairs_raw) + except ValueError: + logger.warning( + "%s should be an integer, got '%s'. Falling back to full sync.", + UPSTREAM_SYNC_MAP_ANALYSIS_MAX_PAIRS_ENV, + max_pairs_raw, + ) + max_pairs = 0 + if max_pairs < 0: + max_pairs = 0 + + try: + standards_response = requests.get( + f"{upstream_api_url}/standards", + timeout=UPSTREAM_SYNC_REQUEST_TIMEOUT_SECONDS, + ) + except requests.RequestException as exc: + logger.warning( + "Failed to fetch standards from upstream map analysis API: %s", exc + ) + return + if standards_response.status_code != 200: + logger.warning( + "Could not fetch standards from upstream (status=%s), skipping map analysis sync", + standards_response.status_code, + ) + return + + try: + standards = standards_response.json() + except ValueError: + logger.warning("Upstream /standards response is not valid JSON, skipping") + return + if not isinstance(standards, list): + logger.warning("Upstream /standards response is not a list, skipping") + return + standards = [standard for standard in standards if isinstance(standard, str)] + standards = list(dict.fromkeys(standards)) + + total_pairs = len(standards) * (len(standards) - 1) + if total_pairs == 0: + logger.info("No standard pairs found for progressive map analysis sync") + return + + logger.info( + "Starting progressive map analysis sync for up to %s pair(s) out of %s total", + max_pairs if max_pairs else "all", + total_pairs, + ) + + attempted_pairs = 0 + synced_pairs = 0 + weak_links_attempted = 0 + weak_links_synced = 0 + + for standard_a in standards: + for standard_b in standards: + if standard_a == standard_b: + continue + + cache_key = gap_analysis.make_resources_key([standard_a, standard_b]) + if collection.gap_analysis_exists(cache_key): + continue + + if max_pairs and synced_pairs >= max_pairs: + logger.info( + "Reached %s=%s after syncing %s pair(s), stopping early", + UPSTREAM_SYNC_MAP_ANALYSIS_MAX_PAIRS_ENV, + max_pairs, + synced_pairs, + ) + return + + attempted_pairs += 1 + try: + response = requests.get( + f"{upstream_api_url}/map_analysis", + params=[("standard", standard_a), ("standard", standard_b)], + timeout=UPSTREAM_SYNC_REQUEST_TIMEOUT_SECONDS, + ) + except requests.RequestException as exc: + logger.warning( + "Could not sync map analysis for %s >> %s: %s", + standard_a, + standard_b, + exc, + ) + continue + if response.status_code != 200: + logger.info( + "Skipping map analysis %s >> %s from upstream (status=%s)", + standard_a, + standard_b, + response.status_code, + ) + continue + + try: + payload = response.json() + except ValueError: + logger.warning( + "Skipping map analysis %s >> %s due to invalid JSON payload", + standard_a, + standard_b, + ) + continue + if not isinstance(payload, dict) or payload.get("result") is None: + continue + + collection.add_gap_analysis_result( + cache_key=cache_key, + ga_object=json.dumps({"result": payload.get("result")}), + ) + synced_pairs += 1 + + weak_attempted, weak_synced = _progressively_sync_weak_links_for_pair( + collection=collection, + upstream_api_url=upstream_api_url, + base_standard=standard_a, + compare_standard=standard_b, + result_payload=payload.get("result"), + ) + weak_links_attempted += weak_attempted + weak_links_synced += weak_synced + + if synced_pairs % 25 == 0: + logger.info( + "Progressive map analysis sync: synced %s pair(s) so far", + synced_pairs, + ) + + logger.info( + "Progressive map analysis sync complete. Attempted %s missing pair(s), synced %s pair(s), attempted %s weak-link result(s), synced %s weak-link result(s)", + attempted_pairs, + synced_pairs, + weak_links_attempted, + weak_links_synced, + ) + + def download_graph_from_upstream(cache: str) -> None: imported_cres = {} collection = db_connect(path=cache).with_graph() + upstream_api_url = _upstream_api_url() def download_cre_from_upstream(creid: str): cre_response = requests.get( - os.environ.get("CRE_UPSTREAM_API_URL", "https://opencre.org/rest/v1") - + f"/id/{creid}" + f"{upstream_api_url}/id/{creid}", + timeout=UPSTREAM_SYNC_REQUEST_TIMEOUT_SECONDS, ) if cre_response.status_code != 200: raise RuntimeError( @@ -487,8 +720,8 @@ def download_cre_from_upstream(creid: str): download_cre_from_upstream(link.document.id) root_cres_response = requests.get( - os.environ.get("CRE_UPSTREAM_API_URL", "https://opencre.org/rest/v1") - + "/root_cres" + f"{upstream_api_url}/root_cres", + timeout=UPSTREAM_SYNC_REQUEST_TIMEOUT_SECONDS, ) if root_cres_response.status_code != 200: raise RuntimeError( @@ -503,6 +736,19 @@ def download_cre_from_upstream(creid: str): if link.document.doctype == defs.Credoctypes.CRE: download_cre_from_upstream(link.document.id) + if not os.environ.get("CRE_NO_NEO4J"): + try: + populate_neo4j_db(cache) + except Exception as exc: + logger.warning( + "Could not populate local neo4j DB during upstream sync: %s", exc + ) + + _progressively_sync_gap_analysis_from_upstream( + collection=collection, + upstream_api_url=upstream_api_url, + ) + # def review_from_disk(cache: str, cre_file_loc: str, share_with: str) -> None: # """--review --cre_loc diff --git a/application/tests/cre_main_test.py b/application/tests/cre_main_test.py index af313c8a..4f0ca5d2 100644 --- a/application/tests/cre_main_test.py +++ b/application/tests/cre_main_test.py @@ -1,3 +1,4 @@ +import json import logging import os import shutil @@ -8,6 +9,7 @@ from unittest.mock import Mock, patch from rq import Queue from application.utils import redis +from application.utils import gap_analysis from application.prompt_client import prompt_client as prompt_client from application.tests.utils import data_gen from application import create_app, sqla # type: ignore @@ -18,6 +20,15 @@ from application.defs.osib_defs import Osib_id, Osib_tree +class StubResponse: + def __init__(self, status_code: int, payload: Any) -> None: + self.status_code = status_code + self._payload = payload + + def json(self) -> Any: + return self._payload + + class TestMain(unittest.TestCase): def tearDown(self) -> None: for tmpdir in self.tmpdirs: @@ -467,6 +478,375 @@ def test_get_standards_files_from_disk(self) -> None: ymls.append(location) self.assertCountEqual(ymls, [x for x in main.get_cre_files_from_disk(loc)]) + @patch("application.cmd.cre_main.db_connect") + @patch("application.cmd.cre_main.register_cre") + @patch("application.cmd.cre_main.populate_neo4j_db") + @patch("application.cmd.cre_main.requests.get") + def test_download_graph_from_upstream_syncs_gap_analysis_progressively( + self, + mocked_requests_get: Mock, + mocked_populate_neo4j_db: Mock, + mocked_register_cre: Mock, + mocked_db_connect: Mock, + ) -> None: + collection = mock.Mock() + collection.with_graph.return_value = collection + cache_entries: Dict[str, str] = {} + collection.gap_analysis_exists.side_effect = lambda key: key in cache_entries + collection.add_gap_analysis_result.side_effect = ( + lambda cache_key, ga_object: cache_entries.__setitem__(cache_key, ga_object) + ) + mocked_db_connect.return_value = collection + + def fake_get(url: str, **kwargs) -> StubResponse: + if url.endswith("/root_cres"): + return StubResponse( + 200, + { + "data": [ + { + "doctype": "CRE", + "id": "111-111", + "name": "Root CRE", + "description": "", + "links": [], + } + ] + }, + ) + if url.endswith("/standards"): + return StubResponse(200, ["ASVS", "Top10"]) + if url.endswith("/map_analysis"): + standards = [ + value + for key, value in kwargs.get("params", []) + if key == "standard" + ] + if standards == ["ASVS", "Top10"]: + return StubResponse( + 200, + { + "result": { + "ASVS:1": { + "start": {"id": "ASVS:1"}, + "paths": {}, + "extra": 0, + } + } + }, + ) + if standards == ["Top10", "ASVS"]: + return StubResponse(200, {"job_id": "job-1"}) + self.fail(f"Unexpected map_analysis query: {standards}") + self.fail(f"Unexpected upstream URL: {url}") + + mocked_requests_get.side_effect = fake_get + + with patch.dict( + os.environ, + { + main.UPSTREAM_SYNC_MAP_ANALYSIS_MAX_PAIRS_ENV: "0", + "CRE_NO_NEO4J": "", + }, + clear=False, + ): + main.download_graph_from_upstream("/tmp/cache.sqlite") + + mocked_register_cre.assert_called_once() + mocked_populate_neo4j_db.assert_called_once_with("/tmp/cache.sqlite") + self.assertIn("ASVS >> Top10", cache_entries) + self.assertNotIn("Top10 >> ASVS", cache_entries) + self.assertEqual( + json.loads(cache_entries["ASVS >> Top10"]), + { + "result": { + "ASVS:1": { + "start": {"id": "ASVS:1"}, + "paths": {}, + "extra": 0, + } + } + }, + ) + + @patch("application.cmd.cre_main.db_connect") + @patch("application.cmd.cre_main.register_cre") + @patch("application.cmd.cre_main.populate_neo4j_db") + @patch("application.cmd.cre_main.requests.get") + def test_download_graph_from_upstream_skips_cached_gap_analysis_pairs( + self, + mocked_requests_get: Mock, + mocked_populate_neo4j_db: Mock, + mocked_register_cre: Mock, + mocked_db_connect: Mock, + ) -> None: + cached_key = "ASVS >> Top10" + cache_entries: Dict[str, str] = { + cached_key: json.dumps({"result": {"cached": 1}}) + } + + collection = mock.Mock() + collection.with_graph.return_value = collection + collection.gap_analysis_exists.side_effect = lambda key: key in cache_entries + collection.add_gap_analysis_result.side_effect = ( + lambda cache_key, ga_object: cache_entries.__setitem__(cache_key, ga_object) + ) + mocked_db_connect.return_value = collection + + def fake_get(url: str, **kwargs) -> StubResponse: + if url.endswith("/root_cres"): + return StubResponse( + 200, + { + "data": [ + { + "doctype": "CRE", + "id": "111-111", + "name": "Root CRE", + "description": "", + "links": [], + } + ] + }, + ) + if url.endswith("/standards"): + return StubResponse(200, ["ASVS", "Top10"]) + if url.endswith("/map_analysis"): + standards = [ + value + for key, value in kwargs.get("params", []) + if key == "standard" + ] + if standards == ["ASVS", "Top10"]: + self.fail("Cached pair should not be fetched from upstream") + if standards == ["Top10", "ASVS"]: + return StubResponse(200, {"result": {"Top10:1": {"paths": {}}}}) + self.fail(f"Unexpected map_analysis query: {standards}") + self.fail(f"Unexpected upstream URL: {url}") + + mocked_requests_get.side_effect = fake_get + + with patch.dict( + os.environ, + { + main.UPSTREAM_SYNC_MAP_ANALYSIS_MAX_PAIRS_ENV: "0", + "CRE_NO_NEO4J": "", + }, + clear=False, + ): + main.download_graph_from_upstream("/tmp/cache.sqlite") + + mocked_register_cre.assert_called_once() + mocked_populate_neo4j_db.assert_called_once_with("/tmp/cache.sqlite") + self.assertEqual( + cache_entries[cached_key], json.dumps({"result": {"cached": 1}}) + ) + self.assertIn("Top10 >> ASVS", cache_entries) + + map_analysis_calls = [ + call + for call in mocked_requests_get.call_args_list + if call.args and call.args[0].endswith("/map_analysis") + ] + self.assertEqual(len(map_analysis_calls), 1) + + @patch("application.cmd.cre_main.db_connect") + @patch("application.cmd.cre_main.register_cre") + @patch("application.cmd.cre_main.populate_neo4j_db") + @patch("application.cmd.cre_main.requests.get") + def test_download_graph_from_upstream_syncs_weak_links_results( + self, + mocked_requests_get: Mock, + mocked_populate_neo4j_db: Mock, + mocked_register_cre: Mock, + mocked_db_connect: Mock, + ) -> None: + cache_entries: Dict[str, str] = {} + collection = mock.Mock() + collection.with_graph.return_value = collection + collection.gap_analysis_exists.side_effect = lambda key: key in cache_entries + collection.add_gap_analysis_result.side_effect = ( + lambda cache_key, ga_object: cache_entries.__setitem__(cache_key, ga_object) + ) + mocked_db_connect.return_value = collection + + def fake_get(url: str, **kwargs) -> StubResponse: + if url.endswith("/root_cres"): + return StubResponse( + 200, + { + "data": [ + { + "doctype": "CRE", + "id": "111-111", + "name": "Root CRE", + "description": "", + "links": [], + } + ] + }, + ) + if url.endswith("/standards"): + return StubResponse(200, ["ASVS", "Top10"]) + if url.endswith("/map_analysis"): + standards = [ + value + for key, value in kwargs.get("params", []) + if key == "standard" + ] + if standards == ["ASVS", "Top10"]: + return StubResponse( + 200, + { + "result": { + "ASVS:1": { + "start": {"id": "ASVS:1"}, + "paths": {}, + "extra": 1, + } + } + }, + ) + if standards == ["Top10", "ASVS"]: + return StubResponse(200, {"job_id": "job-1"}) + self.fail(f"Unexpected map_analysis query: {standards}") + if url.endswith("/map_analysis_weak_links"): + standards = [ + value + for key, value in kwargs.get("params", []) + if key == "standard" + ] + key = [ + value for key, value in kwargs.get("params", []) if key == "key" + ][0] + if standards == ["ASVS", "Top10"] and key == "ASVS:1": + return StubResponse( + 200, + {"result": {"paths": {"Top10:1": {"score": 7, "path": []}}}}, + ) + self.fail( + f"Unexpected weak links query: standards={standards}, key={key}" + ) + self.fail(f"Unexpected upstream URL: {url}") + + mocked_requests_get.side_effect = fake_get + + with patch.dict( + os.environ, + { + main.UPSTREAM_SYNC_MAP_ANALYSIS_MAX_PAIRS_ENV: "0", + "CRE_NO_NEO4J": "", + }, + clear=False, + ): + main.download_graph_from_upstream("/tmp/cache.sqlite") + + main_cache_key = gap_analysis.make_resources_key(["ASVS", "Top10"]) + weak_cache_key = gap_analysis.make_subresources_key( + standards=["ASVS", "Top10"], key="ASVS:1" + ) + self.assertIn(main_cache_key, cache_entries) + self.assertIn(weak_cache_key, cache_entries) + self.assertEqual( + json.loads(cache_entries[weak_cache_key]), + {"result": {"paths": {"Top10:1": {"score": 7, "path": []}}}}, + ) + + weak_call_count = len( + [ + call + for call in mocked_requests_get.call_args_list + if call.args and call.args[0].endswith("/map_analysis_weak_links") + ] + ) + self.assertEqual(weak_call_count, 1) + + mocked_register_cre.assert_called_once() + mocked_populate_neo4j_db.assert_called_once_with("/tmp/cache.sqlite") + + @patch("application.cmd.cre_main.db_connect") + @patch("application.cmd.cre_main.register_cre") + @patch("application.cmd.cre_main.populate_neo4j_db") + @patch("application.cmd.cre_main.requests.get") + def test_download_graph_from_upstream_respects_max_pairs_limit( + self, + mocked_requests_get: Mock, + mocked_populate_neo4j_db: Mock, + mocked_register_cre: Mock, + mocked_db_connect: Mock, + ) -> None: + cache_entries: Dict[str, str] = {} + collection = mock.Mock() + collection.with_graph.return_value = collection + collection.gap_analysis_exists.side_effect = lambda key: key in cache_entries + collection.add_gap_analysis_result.side_effect = ( + lambda cache_key, ga_object: cache_entries.__setitem__(cache_key, ga_object) + ) + mocked_db_connect.return_value = collection + + def fake_get(url: str, **kwargs) -> StubResponse: + if url.endswith("/root_cres"): + return StubResponse( + 200, + { + "data": [ + { + "doctype": "CRE", + "id": "111-111", + "name": "Root CRE", + "description": "", + "links": [], + } + ] + }, + ) + if url.endswith("/standards"): + return StubResponse(200, ["ASVS", "Top10", "NIST"]) + if url.endswith("/map_analysis"): + standards = [ + value + for key, value in kwargs.get("params", []) + if key == "standard" + ] + return StubResponse( + 200, + { + "result": { + f"{standards[0]}:1": { + "start": {"id": f"{standards[0]}:1"}, + "paths": {}, + "extra": 0, + } + } + }, + ) + self.fail(f"Unexpected upstream URL: {url}") + + mocked_requests_get.side_effect = fake_get + + with patch.dict( + os.environ, + { + main.UPSTREAM_SYNC_MAP_ANALYSIS_MAX_PAIRS_ENV: "1", + "CRE_NO_NEO4J": "", + }, + clear=False, + ): + main.download_graph_from_upstream("/tmp/cache.sqlite") + + pair_cache_keys = [key for key in cache_entries.keys() if "->" not in key] + self.assertEqual(len(pair_cache_keys), 1) + + map_analysis_calls = [ + call + for call in mocked_requests_get.call_args_list + if call.args and call.args[0].endswith("/map_analysis") + ] + self.assertEqual(len(map_analysis_calls), 1) + + mocked_register_cre.assert_called_once() + mocked_populate_neo4j_db.assert_called_once_with("/tmp/cache.sqlite") + @patch("application.cmd.cre_main.ai_client_init") @patch("application.cmd.cre_main.db_connect") @patch("application.cmd.cre_main.parse_standards_from_spreadsheeet") diff --git a/application/tests/gap_analysis_db_test.py b/application/tests/gap_analysis_db_test.py index 3e101f36..fc674b53 100644 --- a/application/tests/gap_analysis_db_test.py +++ b/application/tests/gap_analysis_db_test.py @@ -35,8 +35,10 @@ def cypher_side_effect(query, params=None, resolve_objects=True): self.mock_cypher.side_effect = cypher_side_effect - # Call the function - db.NEO_DB.gap_analysis("StandardA", "StandardB") + # Enable optimized mode explicitly for this pruning assertion. + with patch("application.config.Config.GAP_ANALYSIS_OPTIMIZED", True): + # Call the function using the singleton instance + db.NEO_DB.instance().gap_analysis("StandardA", "StandardB") # ASSERTION: # We expect cypher_query to be called.