Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions agent_fox/knowledge/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,20 @@ def enrich_extraction_with_causal(
return base_prompt + addendum


def _is_valid_uuid(s: str) -> bool:
    """Return True if *s* is a UUID string in canonical hyphenated form.

    ``uuid.UUID`` on its own is too lenient for validating untrusted IDs:
    it strips braces, ``urn:uuid:`` prefixes and *all* hyphens before
    parsing, so any 32-character hex digest (for example an MD5 hash)
    would be accepted.  The parsed value is therefore rendered back to
    text and compared with the input, which only succeeds for the
    ``8-4-4-4-12`` layout.  The comparison is case-insensitive.

    Args:
        s: Candidate identifier.  Non-string values are rejected instead
            of raising.

    Returns:
        True when *s* is a canonically formatted UUID, False otherwise.
    """
    # Guard first: uuid.UUID(None) raises TypeError, other non-strings
    # raise AttributeError; neither should escape a validation helper.
    if not isinstance(s, str):
        return False
    try:
        parsed = uuid.UUID(s)
    except ValueError:
        return False
    # Round-trip check rejects brace/urn/un-hyphenated spellings that
    # uuid.UUID silently normalises.
    return str(parsed) == s.lower()


def parse_causal_links(extraction_response: str) -> list[tuple[str, str]]:
"""Parse causal link pairs from the extraction model's response.

Returns a list of (cause_id, effect_id) tuples. Silently skips
malformed entries.
Returns a list of (cause_id, effect_id) tuples. Skips entries where
either ID is missing, not a string, or not a valid UUID.
"""
data = extract_json_array(extraction_response, repair_truncated=True)
if data is None:
Expand All @@ -164,10 +173,17 @@ def parse_causal_links(extraction_response: str) -> list[tuple[str, str]]:
continue
cause_id = item.get("cause_id")
effect_id = item.get("effect_id")
if isinstance(cause_id, str) and isinstance(effect_id, str):
links.append((cause_id, effect_id))
else:
if not isinstance(cause_id, str) or not isinstance(effect_id, str):
logger.debug("Skipping malformed causal link entry: %s", item)
continue
if not _is_valid_uuid(cause_id) or not _is_valid_uuid(effect_id):
logger.warning(
"Skipping causal link with malformed UUID: cause_id=%s, effect_id=%s",
cause_id,
effect_id,
)
continue
links.append((cause_id, effect_id))
return links


Expand Down
131 changes: 102 additions & 29 deletions tests/unit/knowledge/test_extraction_causal.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,15 @@
parse_causal_links,
)

# Syntactically valid UUID strings (8-4-4-4-12 hex layout) shared by the tests
# in this module wherever a fact ID or causal-link ID is needed.  Each one
# repeats a single hex digit so the values are easy to tell apart in failures.
_UUID_A = "11111111-1111-1111-1111-111111111111"
_UUID_B = "22222222-2222-2222-2222-222222222222"
_UUID_C = "33333333-3333-3333-3333-333333333333"
_UUID_D = "44444444-4444-4444-4444-444444444444"
_UUID_E = "55555555-5555-5555-5555-555555555555"
_UUID_X1 = "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
_UUID_X2 = "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"


class TestEnrichExtractionPrompt:
"""TS-13-14: Enrich extraction prompt includes prior facts.
Expand All @@ -20,27 +29,27 @@ class TestEnrichExtractionPrompt:

def test_enriched_prompt_contains_base(self) -> None:
"""The enriched prompt includes the original base prompt."""
prior = [{"id": "aaa", "content": "User.email nullable"}]
prior = [{"id": _UUID_A, "content": "User.email nullable"}]
result = enrich_extraction_with_causal("Extract facts:", prior)
assert "Extract facts:" in result

def test_enriched_prompt_contains_causal_section(self) -> None:
"""The enriched prompt includes the Causal Relationships section."""
prior = [{"id": "aaa", "content": "User.email nullable"}]
prior = [{"id": _UUID_A, "content": "User.email nullable"}]
result = enrich_extraction_with_causal("Extract facts:", prior)
assert "Causal Relationships" in result

def test_enriched_prompt_contains_prior_fact_content(self) -> None:
"""The enriched prompt includes prior fact content."""
prior = [{"id": "aaa", "content": "User.email nullable"}]
prior = [{"id": _UUID_A, "content": "User.email nullable"}]
result = enrich_extraction_with_causal("Extract facts:", prior)
assert "User.email nullable" in result

def test_enriched_prompt_with_multiple_prior_facts(self) -> None:
"""Multiple prior facts are all included in the enriched prompt."""
prior = [
{"id": "aaa", "content": "First fact"},
{"id": "bbb", "content": "Second fact"},
{"id": _UUID_A, "content": "First fact"},
{"id": _UUID_B, "content": "Second fact"},
]
result = enrich_extraction_with_causal("Base:", prior)
assert "First fact" in result
Expand All @@ -60,12 +69,15 @@ class TestParseCausalLinks:
"""

def test_parses_valid_links(self) -> None:
"""Valid JSON causal links are parsed correctly."""
response = '[{"cause_id": "aaa", "effect_id": "bbb"}, {"cause_id": "ccc", "effect_id": "ddd"}]'
"""Valid JSON causal links with proper UUIDs are parsed correctly."""
response = (
f'[{{"cause_id": "{_UUID_A}", "effect_id": "{_UUID_B}"}}, '
f'{{"cause_id": "{_UUID_C}", "effect_id": "{_UUID_D}"}}]'
)
links = parse_causal_links(response)
assert len(links) == 2
assert links[0] == ("aaa", "bbb")
assert links[1] == ("ccc", "ddd")
assert links[0] == (_UUID_A, _UUID_B)
assert links[1] == (_UUID_C, _UUID_D)


class TestParseCausalLinksMalformed:
Expand All @@ -76,10 +88,13 @@ class TestParseCausalLinksMalformed:

def test_skips_malformed_entries(self) -> None:
"""Malformed entries are silently skipped, valid ones returned."""
response = '[{"cause_id": "aaa", "effect_id": "bbb"}, {"bad": "entry"}, "not json"]'
response = (
f'[{{"cause_id": "{_UUID_A}", "effect_id": "{_UUID_B}"}}, '
'{"bad": "entry"}, "not json"]'
)
links = parse_causal_links(response)
assert len(links) == 1
assert links[0] == ("aaa", "bbb")
assert links[0] == (_UUID_A, _UUID_B)


class TestParseCausalLinksEmpty:
Expand All @@ -99,17 +114,17 @@ class TestParseCausalLinksMarkdownFences:

def test_parses_json_inside_code_fence(self) -> None:
"""JSON wrapped in ```json fences is parsed correctly."""
response = '```json\n[{"cause_id": "aaa", "effect_id": "bbb"}]\n```'
response = f'```json\n[{{"cause_id": "{_UUID_A}", "effect_id": "{_UUID_B}"}}]\n```'
links = parse_causal_links(response)
assert len(links) == 1
assert links[0] == ("aaa", "bbb")
assert links[0] == (_UUID_A, _UUID_B)

def test_parses_json_inside_plain_fence(self) -> None:
"""JSON wrapped in ``` fences (no language tag) is parsed correctly."""
response = '```\n[{"cause_id": "x1", "effect_id": "x2"}]\n```'
response = f'```\n[{{"cause_id": "{_UUID_X1}", "effect_id": "{_UUID_X2}"}}]\n```'
links = parse_causal_links(response)
assert len(links) == 1
assert links[0] == ("x1", "x2")
assert links[0] == (_UUID_X1, _UUID_X2)


class TestParseCausalLinksWithEchoedRefs:
Expand All @@ -118,16 +133,16 @@ class TestParseCausalLinksWithEchoedRefs:
def test_parses_links_after_echoed_uuid_references(self) -> None:
"""JSON causal links are parsed when LLM echoes [uuid] refs in prose."""
response = (
"Looking at [aaa-111] and [bbb-222], I see a causal chain:\n\n"
'[{"cause_id": "aaa-111", "effect_id": "bbb-222"}]'
f"Looking at [{_UUID_A}] and [{_UUID_B}], I see a causal chain:\n\n"
f'[{{"cause_id": "{_UUID_A}", "effect_id": "{_UUID_B}"}}]'
)
links = parse_causal_links(response)
assert len(links) == 1
assert links[0] == ("aaa-111", "bbb-222")
assert links[0] == (_UUID_A, _UUID_B)

def test_parses_empty_array_after_echoed_refs(self) -> None:
"""Empty JSON array is parsed when LLM echoes [uuid] refs in prose."""
response = "Reviewing [fact-1] and [fact-2], no causal relationship found.\n\n[]"
response = f"Reviewing [{_UUID_A}] and [{_UUID_B}], no causal relationship found.\n\n[]"
links = parse_causal_links(response)
assert len(links) == 0

Expand All @@ -145,7 +160,7 @@ def test_invalid_json_returns_empty_list(self) -> None:

def test_partial_json_no_complete_entries_returns_empty(self) -> None:
"""Truncated JSON with no complete entries returns an empty list."""
links = parse_causal_links('[{"cause_id": "aaa", "effect_id":')
links = parse_causal_links(f'[{{"cause_id": "{_UUID_A}", "effect_id":')
assert len(links) == 0


Expand All @@ -155,25 +170,83 @@ class TestParseCausalLinksTruncatedRecovery:
def test_recovers_complete_entries_from_truncated_array(self) -> None:
"""Complete entries before the truncation point are recovered."""
response = (
'[{"cause_id": "aaa", "effect_id": "bbb"}, '
'{"cause_id": "ccc", "effect_id": "ddd"}, '
'{"cause_id": "eee", "effect_'
f'[{{"cause_id": "{_UUID_A}", "effect_id": "{_UUID_B}"}}, '
f'{{"cause_id": "{_UUID_C}", "effect_id": "{_UUID_D}"}}, '
f'{{"cause_id": "{_UUID_E}", "effect_'
)
links = parse_causal_links(response)
assert len(links) == 2
assert links[0] == ("aaa", "bbb")
assert links[1] == ("ccc", "ddd")
assert links[0] == (_UUID_A, _UUID_B)
assert links[1] == (_UUID_C, _UUID_D)

def test_recovers_from_truncated_fenced_response(self) -> None:
"""Truncated ```json fenced response recovers valid entries."""
response = '```json\n[{"cause_id": "x1", "effect_id": "x2"}, {"cause_id": "y1"'
response = (
f'```json\n[{{"cause_id": "{_UUID_X1}", "effect_id": "{_UUID_X2}"}}, '
f'{{"cause_id": "{_UUID_A}"'
)
links = parse_causal_links(response)
assert len(links) == 1
assert links[0] == ("x1", "x2")
assert links[0] == (_UUID_X1, _UUID_X2)

def test_single_complete_entry_before_truncation(self) -> None:
"""A single complete entry followed by truncation is recovered."""
response = '[{"cause_id": "a", "effect_id": "b"}, {"cause_id":'
response = f'[{{"cause_id": "{_UUID_A}", "effect_id": "{_UUID_B}"}}, {{"cause_id":'
links = parse_causal_links(response)
assert len(links) == 1
assert links[0] == ("a", "b")
assert links[0] == (_UUID_A, _UUID_B)


# ---------------------------------------------------------------------------
# Regression: malformed UUID filtering (fixes #474)
# ---------------------------------------------------------------------------


class TestParseCausalLinksUUIDValidation:
    """Regression checks: parse_causal_links drops links whose IDs are not UUIDs."""

    def test_truncated_uuid_filtered(self) -> None:
        """An effect ID that lost one UUID segment causes the link to be dropped."""
        short_id = "bcdd143f-4363-a85f-77b6748add6c"
        payload = f'[{{"cause_id": "{_UUID_A}", "effect_id": "{short_id}"}}]'
        parsed = parse_causal_links(payload)
        assert len(parsed) == 0

    def test_git_sha_filtered(self) -> None:
        """A 40-character git SHA used as the cause ID is not accepted as a UUID."""
        sha = "b7f2ab9cf46b4552d505a3fac075a1935a653b22"
        payload = f'[{{"cause_id": "{sha}", "effect_id": "{_UUID_A}"}}]'
        parsed = parse_causal_links(payload)
        assert len(parsed) == 0

    def test_valid_uuid_passes(self) -> None:
        """A link whose IDs are both canonical UUID strings is returned unchanged."""
        payload = f'[{{"cause_id": "{_UUID_A}", "effect_id": "{_UUID_B}"}}]'
        parsed = parse_causal_links(payload)
        assert len(parsed) == 1
        assert parsed[0] == (_UUID_A, _UUID_B)

    def test_mixed_valid_and_invalid_keeps_only_valid(self) -> None:
        """Of three links, only the two with UUIDs on both sides survive, in order."""
        sha = "b7f2ab9cf46b4552d505a3fac075a1935a653b22"
        good_first = f'{{"cause_id": "{_UUID_A}", "effect_id": "{_UUID_B}"}}'
        bad_middle = f'{{"cause_id": "{sha}", "effect_id": "{_UUID_C}"}}'
        good_last = f'{{"cause_id": "{_UUID_C}", "effect_id": "{_UUID_D}"}}'
        payload = f"[{good_first}, {bad_middle}, {good_last}]"
        parsed = parse_causal_links(payload)
        assert len(parsed) == 2
        assert parsed[0] == (_UUID_A, _UUID_B)
        assert parsed[1] == (_UUID_C, _UUID_D)

    def test_plain_string_filtered(self) -> None:
        """An arbitrary non-hex string as the cause ID causes the link to be dropped."""
        payload = f'[{{"cause_id": "not-a-uuid", "effect_id": "{_UUID_A}"}}]'
        parsed = parse_causal_links(payload)
        assert len(parsed) == 0

    def test_empty_string_filtered(self) -> None:
        """An empty-string cause ID causes the link to be dropped."""
        payload = f'[{{"cause_id": "", "effect_id": "{_UUID_A}"}}]'
        parsed = parse_causal_links(payload)
        assert len(parsed) == 0