From 37bb5537bec3e18a26485d3b54d97c26b287c8df Mon Sep 17 00:00:00 2001 From: Subhash Gupta Date: Sun, 22 Feb 2026 21:00:34 +0530 Subject: [PATCH] feat: add suggest_cre_mappings() for automatic CRE mapping via embeddings --- application/cmd/cre_main.py | 53 ++++++++++++++++++++++++++ application/tests/cre_main_test.py | 61 ++++++++++++++++++++++++++++++ 2 files changed, 114 insertions(+) diff --git a/application/cmd/cre_main.py b/application/cmd/cre_main.py index ead5a4281..ae527b74f 100644 --- a/application/cmd/cre_main.py +++ b/application/cmd/cre_main.py @@ -25,6 +25,7 @@ from alive_progress import alive_bar from application.prompt_client import prompt_client as prompt_client from application.utils import gap_analysis +from application.prompt_client.prompt_client import SIMILARITY_THRESHOLD logging.basicConfig() logger = logging.getLogger(__name__) @@ -315,6 +316,58 @@ def register_standard( conn.set(standard_hash, value="") +def suggest_cre_mappings( + standard_entries: List[defs.Standard], + collection: db.Node_collection, + confidence_threshold: float = SIMILARITY_THRESHOLD, +) -> Dict[str, Any]: + """ + Given a list of Standard entries, suggest CRE mappings using + cosine similarity on existing embeddings. + + Returns high-confidence matches and flags low-confidence ones + for human review. + + Args: + standard_entries: list of Standard nodes to map + collection: database connection + confidence_threshold: minimum similarity score to auto-map + + Returns: + Dict with 'mapped' (high confidence) and 'needs_review' (low confidence) lists + """ + if not standard_entries: + logger.warning("suggest_cre_mappings() called with no standard_entries") + return {"mapped": [], "needs_review": []} + + ph = prompt_client.PromptHandler(database=collection) + results: Dict[str, Any] = {"mapped": [], "needs_review": []} + + for node in standard_entries: + text = " ".join(filter(None, [node.name, node.section, node.description])) + if not text.strip(): + continue + embedding = ph.get_text_embeddings(text) + cre_id, similarity = ph.get_id_of_most_similar_cre_paginated( + embedding, similarity_threshold=confidence_threshold + ) + entry = { + "standard": node.todict(), + "suggested_cre_id": cre_id, + "confidence": round(float(similarity), 4) if similarity else None, + } + if cre_id and similarity and similarity >= confidence_threshold: + results["mapped"].append(entry) + else: + results["needs_review"].append(entry) + + logger.info( + f"suggest_cre_mappings: {len(results['mapped'])} mapped, " + f"{len(results['needs_review'])} need review" + ) + return results + + def parse_standards_from_spreadsheeet( cre_file: List[Dict[str, Any]], cache_location: str, diff --git a/application/tests/cre_main_test.py b/application/tests/cre_main_test.py index af313c8a6..a3cb9a308 100644 --- a/application/tests/cre_main_test.py +++ b/application/tests/cre_main_test.py @@ -783,6 +783,67 @@ def test_add_from_disk( # main.export_to_osib(file_loc=f"{dir}/osib.yaml", cache=cache) # mocked_db_connect.assert_called_with(path=cache) # mocked_cre2osib.assert_called_with([defs.CRE(id="000-000", name="c0")]) + + @patch.object(prompt_client.PromptHandler, "get_text_embeddings") + @patch.object(prompt_client.PromptHandler, "get_id_of_most_similar_cre_paginated") + def test_suggest_cre_mappings( + self, + mock_get_similar_cre, + mock_get_embeddings, + ) -> None: + # Arrange + standard_entries = [ + defs.Standard( + name="PCI-DSS", + section="Use strong cryptography to protect data in transit", + description="All transmissions of cardholder data must be encrypted.", + ), + defs.Standard( + name="PCI-DSS", + section="Some vague control with no good match", + description="", + ), + ] + + fake_embedding = [0.1] * 768 + mock_get_embeddings.return_value = fake_embedding + + # First standard maps well, second does not + mock_get_similar_cre.side_effect = [ + ("cre-db-id-123", 0.85), # high confidence + (None, None), # low confidence / no match + ] + + # Act + result = main.suggest_cre_mappings( + standard_entries=standard_entries, + collection=self.collection, + ) + + # Assert + self.assertEqual(len(result["mapped"]), 1) + self.assertEqual(len(result["needs_review"]), 1) + + mapped = result["mapped"][0] + self.assertEqual(mapped["suggested_cre_id"], "cre-db-id-123") + self.assertEqual(mapped["confidence"], 0.85) + self.assertEqual(mapped["standard"]["name"], "PCI-DSS") + + review = result["needs_review"][0] + self.assertIsNone(review["suggested_cre_id"]) + self.assertIsNone(review["confidence"]) + + # Assert embeddings were called for each standard + self.assertEqual(mock_get_embeddings.call_count, 2) + + @patch.object(prompt_client.PromptHandler, "get_text_embeddings") + def test_suggest_cre_mappings_empty_input(self, mock_get_embeddings) -> None: + result = main.suggest_cre_mappings( + standard_entries=[], + collection=self.collection, + ) + self.assertEqual(result, {"mapped": [], "needs_review": []}) + mock_get_embeddings.assert_not_called() if __name__ == "__main__":