Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions application/cmd/cre_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from alive_progress import alive_bar
from application.prompt_client import prompt_client as prompt_client
from application.utils import gap_analysis
from application.prompt_client.prompt_client import SIMILARITY_THRESHOLD

logging.basicConfig()
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -315,6 +316,58 @@ def register_standard(
conn.set(standard_hash, value="")


def suggest_cre_mappings(
standard_entries: List[defs.Standard],
collection: db.Node_collection,
confidence_threshold: float = SIMILARITY_THRESHOLD,
) -> Dict[str, Any]:
"""
Given a list of Standard entries, suggest CRE mappings using
cosine similarity on existing embeddings.

Returns high-confidence matches and flags low-confidence ones
for human review.

Args:
standard_entries: list of Standard nodes to map
collection: database connection
confidence_threshold: minimum similarity score to auto-map

Returns:
Dict with 'mapped' (high confidence) and 'needs_review' (low confidence) lists
"""
if not standard_entries:
logger.warning("suggest_cre_mappings() called with no standard_entries")
return {"mapped": [], "needs_review": []}

ph = prompt_client.PromptHandler(database=collection)
results: Dict[str, Any] = {"mapped": [], "needs_review": []}

for node in standard_entries:
text = " ".join(filter(None, [node.name, node.section, node.description]))
if not text.strip():
continue
embedding = ph.get_text_embeddings(text)
cre_id, similarity = ph.get_id_of_most_similar_cre_paginated(
embedding, similarity_threshold=confidence_threshold
)
entry = {
"standard": node.todict(),
"suggested_cre_id": cre_id,
"confidence": round(float(similarity), 4) if similarity else None,
}
if cre_id and similarity and similarity >= confidence_threshold:
results["mapped"].append(entry)
else:
results["needs_review"].append(entry)

logger.info(
f"suggest_cre_mappings: {len(results['mapped'])} mapped, "
f"{len(results['needs_review'])} need review"
)
return results


def parse_standards_from_spreadsheeet(
cre_file: List[Dict[str, Any]],
cache_location: str,
Expand Down
61 changes: 61 additions & 0 deletions application/tests/cre_main_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,67 @@ def test_add_from_disk(
# main.export_to_osib(file_loc=f"{dir}/osib.yaml", cache=cache)
# mocked_db_connect.assert_called_with(path=cache)
# mocked_cre2osib.assert_called_with([defs.CRE(id="000-000", name="c0")])

@patch.object(prompt_client.PromptHandler, "get_text_embeddings")
@patch.object(prompt_client.PromptHandler, "get_id_of_most_similar_cre_paginated")
def test_suggest_cre_mappings(
self,
mock_get_similar_cre,
mock_get_embeddings,
) -> None:
# Arrange
standard_entries = [
defs.Standard(
name="PCI-DSS",
section="Use strong cryptography to protect data in transit",
description="All transmissions of cardholder data must be encrypted.",
),
defs.Standard(
name="PCI-DSS",
section="Some vague control with no good match",
description="",
),
]

fake_embedding = [0.1] * 768
mock_get_embeddings.return_value = fake_embedding

# First standard maps well, second does not
mock_get_similar_cre.side_effect = [
("cre-db-id-123", 0.85), # high confidence
(None, None), # low confidence / no match
]

# Act
result = main.suggest_cre_mappings(
standard_entries=standard_entries,
collection=self.collection,
)

# Assert
self.assertEqual(len(result["mapped"]), 1)
self.assertEqual(len(result["needs_review"]), 1)

mapped = result["mapped"][0]
self.assertEqual(mapped["suggested_cre_id"], "cre-db-id-123")
self.assertEqual(mapped["confidence"], 0.85)
self.assertEqual(mapped["standard"]["name"], "PCI-DSS")

review = result["needs_review"][0]
self.assertIsNone(review["suggested_cre_id"])
self.assertIsNone(review["confidence"])

# Assert embeddings were called for each standard
self.assertEqual(mock_get_embeddings.call_count, 2)

@patch.object(prompt_client.PromptHandler, "get_text_embeddings")
def test_suggest_cre_mappings_empty_input(self, mock_get_embeddings) -> None:
result = main.suggest_cre_mappings(
standard_entries=[],
collection=self.collection,
)
self.assertEqual(result, {"mapped": [], "needs_review": []})
mock_get_embeddings.assert_not_called()


if __name__ == "__main__":
Expand Down