From 5f33e1198192a60eec6ba3ba2a72b6e2e88869cb Mon Sep 17 00:00:00 2001 From: Fiifi Botchway Date: Thu, 25 Jun 2026 12:31:49 +0100 Subject: [PATCH 1/2] feat(mapping): add PGE loop for entity/relationship mapping MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduce agent_mapping_pge — a Planner→Generator→Evaluator mapping engine — additively. Plans a source-model, generates entity + relationship SQL per ontology item, and gates each with a deterministic evaluator + a semantic critic. Coverage is engine-enforced from the ontology (abstract-superclass UNION derivation + synthetic-endpoint fallback so one failed hub can't drop all relationships). Additive: agent_auto_assignment is retained and still reachable via AgentClient.run_auto_assignment; a new AgentClient.run_mapping_pge gateway exposes the PGE engine, so an orchestrator can choose between them. Upstream features preserved (ToolContext.warehouse_id, Mapping._canonicalize_imported_uris). - NEW agent_mapping_pge package + tools/{planner,evaluation}.py - context.py: +source_model/+semantic_eval_report fields (warehouse_id kept) - Mapping.py: run PGE + accumulate source_model/evaluations/run_log; save_mappings_to_session gains 3 optional params (legacy path unaffected) - Tests: 90 in tests/agents/agent_mapping_pge; 208 across units/{agents,mapping} Co-authored-by: Isaac --- changelogs/v0.5.2/FiifiB_2026-06-25.log | 57 + src/agents/agent_mapping_pge/__init__.py | 50 + src/agents/agent_mapping_pge/contracts.py | 344 ++++ src/agents/agent_mapping_pge/coverage.py | 353 ++++ src/agents/agent_mapping_pge/engine.py | 1491 +++++++++++++++++ .../agent_mapping_pge/evaluator/__init__.py | 18 + .../agent_mapping_pge/evaluator/critic.py | 747 +++++++++ .../evaluator/deterministic.py | 539 ++++++ .../agent_mapping_pge/evaluator/report.py | 37 + .../agent_mapping_pge/generators/__init__.py | 31 + .../agent_mapping_pge/generators/entity.py | 812 +++++++++ .../generators/relationship.py 
| 875 ++++++++++ src/agents/agent_mapping_pge/planner.py | 729 ++++++++ src/agents/tools/context.py | 15 +- src/agents/tools/evaluation.py | 205 +++ src/agents/tools/mapping.py | 78 +- src/agents/tools/planner.py | 707 ++++++++ src/back/core/agents/AgentClient.py | 62 + src/back/objects/mapping/Mapping.py | 62 +- tests/agents/agent_mapping_pge/__init__.py | 0 .../agent_mapping_pge/test_contracts.py | 126 ++ .../agents/agent_mapping_pge/test_coverage.py | 140 ++ tests/agents/agent_mapping_pge/test_critic.py | 684 ++++++++ .../test_deterministic_evaluator.py | 916 ++++++++++ tests/agents/agent_mapping_pge/test_engine.py | 1117 ++++++++++++ .../test_entity_generator.py | 662 ++++++++ .../agents/agent_mapping_pge/test_planner.py | 519 ++++++ .../test_relationship_generator.py | 736 ++++++++ 28 files changed, 12104 insertions(+), 8 deletions(-) create mode 100644 changelogs/v0.5.2/FiifiB_2026-06-25.log create mode 100644 src/agents/agent_mapping_pge/__init__.py create mode 100644 src/agents/agent_mapping_pge/contracts.py create mode 100644 src/agents/agent_mapping_pge/coverage.py create mode 100644 src/agents/agent_mapping_pge/engine.py create mode 100644 src/agents/agent_mapping_pge/evaluator/__init__.py create mode 100644 src/agents/agent_mapping_pge/evaluator/critic.py create mode 100644 src/agents/agent_mapping_pge/evaluator/deterministic.py create mode 100644 src/agents/agent_mapping_pge/evaluator/report.py create mode 100644 src/agents/agent_mapping_pge/generators/__init__.py create mode 100644 src/agents/agent_mapping_pge/generators/entity.py create mode 100644 src/agents/agent_mapping_pge/generators/relationship.py create mode 100644 src/agents/agent_mapping_pge/planner.py create mode 100644 src/agents/tools/evaluation.py create mode 100644 src/agents/tools/planner.py create mode 100644 tests/agents/agent_mapping_pge/__init__.py create mode 100644 tests/agents/agent_mapping_pge/test_contracts.py create mode 100644 tests/agents/agent_mapping_pge/test_coverage.py 
create mode 100644 tests/agents/agent_mapping_pge/test_critic.py create mode 100644 tests/agents/agent_mapping_pge/test_deterministic_evaluator.py create mode 100644 tests/agents/agent_mapping_pge/test_engine.py create mode 100644 tests/agents/agent_mapping_pge/test_entity_generator.py create mode 100644 tests/agents/agent_mapping_pge/test_planner.py create mode 100644 tests/agents/agent_mapping_pge/test_relationship_generator.py diff --git a/changelogs/v0.5.2/FiifiB_2026-06-25.log b/changelogs/v0.5.2/FiifiB_2026-06-25.log new file mode 100644 index 00000000..aa803cfa --- /dev/null +++ b/changelogs/v0.5.2/FiifiB_2026-06-25.log @@ -0,0 +1,57 @@ +# 2026-06-25 — feat(mapping): PGE loop for entity/relationship mapping + +## Context + +Entity/relationship mapping previously ran through `agent_auto_assignment` — +a single-agent "implementer marks its own homework" loop with no planning or +independent evaluation. This change introduces `agent_mapping_pge`, a +Planner→Generator→Evaluator (PGE) mapping engine, **additively**: the original +`agent_auto_assignment` engine is retained and still reachable via +`AgentClient.run_auto_assignment`, so a downstream orchestrator can choose which +engine to run. + +The PGE engine plans a source-model, generates entity and relationship SQL per +ontology item, and gates each with a deterministic evaluator + a semantic +critic. Coverage is engine-enforced (computed from the ontology, not left to LLM +discretion), with abstract-superclass UNION derivation and a synthetic-endpoint +fallback so a single failed hub never cascades to drop all relationships. + +## Changes + +1. 
NEW package `src/agents/agent_mapping_pge/` — Planner (`planner.py`), + generators (`generators/{entity,relationship}.py`), evaluator + (`evaluator/{deterministic,critic,report}.py`), engine orchestrator + (`engine.py`, bounded ThreadPool walk + monotonic progress), `contracts.py` + (SourceModel/EvalReport), and `coverage.py` (deterministic ontology-derived + coverage; `skip[]` is advisory and never removes an item). +2. NEW `src/agents/tools/planner.py` + `src/agents/tools/evaluation.py` — + planner/evaluation terminal tools (submit_source_model, submit_evaluation, + normalized_value_overlap) used by the PGE agents. +3. `src/agents/tools/context.py` — ADD `source_model` + `semantic_eval_report` + fields (forward-ref typed to avoid a circular import). `warehouse_id` and all + existing fields are preserved. +4. `src/agents/tools/mapping.py` — additive PGE tool-schema plumbing + (`unmapped_attributes`, `MAPPING_TOOL_DEFINITIONS_BY_NAME`). +5. `src/back/core/agents/AgentClient.py` — ADD `run_mapping_pge()` gateway + (→ `agent_mapping_pge`). `run_auto_assignment()` is unchanged and still + points at `agent_auto_assignment` (the simple engine is retained). +6. `src/back/objects/mapping/Mapping.py` — run the PGE engine in the auto-assign + flow and accumulate the PGE extras (`source_model`, `mapping_evaluations`, + `mapping_run_log`) across chunks and single-item runs; + `save_mappings_to_session` gains three OPTIONAL params (default `None`, so the + legacy path is unaffected). The upstream `_canonicalize_imported_uris` helper + is preserved. +7. Tests: `tests/agents/agent_mapping_pge/` — contracts, coverage, planner, + entity/relationship generators, deterministic evaluator, critic, engine. + +## Modified / added files + +27 files changed, 12047 insertions(+), 8 deletions(-). New `agent_mapping_pge` +package (12 modules) + 2 new tools + 9 test modules; 4 additive modifications +(`context.py`, `mapping.py`, `AgentClient.py`, `Mapping.py`). 
+ +## Tests + +- `uv run pytest tests/agents/agent_mapping_pge -q` → **90 passed**. +- `uv run pytest tests/units/agents tests/units/mapping -q` → **208 passed**. +- Imports resolve on the upstream base (origin/master, v0.5.2). diff --git a/src/agents/agent_mapping_pge/__init__.py b/src/agents/agent_mapping_pge/__init__.py new file mode 100644 index 00000000..0da0713e --- /dev/null +++ b/src/agents/agent_mapping_pge/__init__.py @@ -0,0 +1,50 @@ +"""Planner -> Generator -> Evaluator (PGE) mapping agent. + +Three-stage mapping pipeline that replaces the prior single-loop ReAct +mapping agent: + +* **Planner** — proposes a :class:`SourceModel` (table roles, canonical IDs, + join keys, ordered mapping plan). +* **Generator** — produces individual entity/relationship mappings given the + plan. +* **Evaluator** — checks each submitted mapping; stage 1 is deterministic + (pure SQL counts), stage 2 is semantic. + +All stages ship in this package: the typed contracts, the deterministic +evaluator, the LLM-backed Planner and Generators, the semantic critic, and +the orchestrating loop (:func:`run_agent`). +""" + +from agents.agent_mapping_pge.contracts import ( + CanonicalId, + EvalFailure, + EvalReport, + JoinKey, + MappingPlan, + RetryState, + SkipItem, + SourceModel, + TableRole, + TableRoleCandidate, +) +from agents.agent_mapping_pge.engine import ( + AgentResult, + AgentStep, + run_agent, +) + +__all__ = [ + "AgentResult", + "AgentStep", + "CanonicalId", + "EvalFailure", + "EvalReport", + "JoinKey", + "MappingPlan", + "RetryState", + "SkipItem", + "SourceModel", + "TableRole", + "TableRoleCandidate", + "run_agent", +] diff --git a/src/agents/agent_mapping_pge/contracts.py b/src/agents/agent_mapping_pge/contracts.py new file mode 100644 index 00000000..8172e431 --- /dev/null +++ b/src/agents/agent_mapping_pge/contracts.py @@ -0,0 +1,344 @@ +"""Typed contracts for the mapping PGE pipeline.
+ +These dataclasses are the load-bearing interface between Planner, Generator, +Evaluator, and the orchestrator (added in later sprints). All shapes here +are JSON round-trippable via ``to_dict()`` / ``from_dict()`` so they can be +persisted as artefacts, attached to MLflow traces, or shipped over the wire +to the UI. + +No LLM code lives here; this is a pure-data module. +""" + +from dataclasses import dataclass, field, fields, is_dataclass +from typing import Any, Dict, List, Optional + + +# ===================================================== +# SourceModel — Planner output +# ===================================================== + + +@dataclass +class TableRoleCandidate: + """A candidate ontology class for a given source table.""" + + uri: str + confidence: float # 0.0 .. 1.0 + reason: str = "" + + def to_dict(self) -> Dict[str, Any]: + return {"uri": self.uri, "confidence": self.confidence, "reason": self.reason} + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "TableRoleCandidate": + return cls( + uri=data["uri"], + confidence=float(data["confidence"]), + reason=data.get("reason", ""), + ) + + +@dataclass +class TableRole: + """A source table together with its ranked ontology-class candidates.""" + + table: str # full name catalog.schema.table + ontology_class_candidates: List[TableRoleCandidate] = field(default_factory=list) + + def to_dict(self) -> Dict[str, Any]: + return { + "table": self.table, + "ontology_class_candidates": [ + c.to_dict() for c in self.ontology_class_candidates + ], + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "TableRole": + return cls( + table=data["table"], + ontology_class_candidates=[ + TableRoleCandidate.from_dict(c) + for c in data.get("ontology_class_candidates", []) + ], + ) + + +@dataclass +class CanonicalId: + """Identifier conventions for an ontology class across its source tables. 
+ + ``canonical_column_per_table`` maps a full table name -> the column to + use as the canonical identifier in that table (e.g. NHS number rather + than the trust-local patient id). + """ + + ontology_class: str # class URI + canonical_column_per_table: Dict[str, str] = field(default_factory=dict) + format_note: str = "" + + def to_dict(self) -> Dict[str, Any]: + return { + "ontology_class": self.ontology_class, + "canonical_column_per_table": dict(self.canonical_column_per_table), + "format_note": self.format_note, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "CanonicalId": + return cls( + ontology_class=data["ontology_class"], + canonical_column_per_table=dict( + data.get("canonical_column_per_table", {}) + ), + format_note=data.get("format_note", ""), + ) + + +@dataclass +class JoinKey: + """A proposed join between two table.column references. + + ``kind`` distinguishes within-trust foreign keys from value-matched + cross-source joins (e.g. NHS-number-to-NHS-number across trusts). 
+ """ + + from_ref: str # "table.col" + to_ref: str # "table.col" + confidence: float # 0..1 + overlap_pct: float # 0..1 + kind: str # "same_trust_fk" | "cross_source_value_match" + + def to_dict(self) -> Dict[str, Any]: + return { + "from_ref": self.from_ref, + "to_ref": self.to_ref, + "confidence": self.confidence, + "overlap_pct": self.overlap_pct, + "kind": self.kind, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "JoinKey": + return cls( + from_ref=data["from_ref"], + to_ref=data["to_ref"], + confidence=float(data["confidence"]), + overlap_pct=float(data["overlap_pct"]), + kind=data["kind"], + ) + + +@dataclass +class SkipItem: + """An ontology entity/relationship the planner has decided to skip.""" + + item: str # uri + reason: str = "" + + def to_dict(self) -> Dict[str, Any]: + return {"item": self.item, "reason": self.reason} + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "SkipItem": + return cls(item=data["item"], reason=data.get("reason", "")) + + +@dataclass +class MappingPlan: + """The order in which the Generator should attempt entity/relationship + mappings, plus any items the planner chose to drop.""" + + entity_order: List[str] = field(default_factory=list) + relationship_order: List[str] = field(default_factory=list) + skip: List[SkipItem] = field(default_factory=list) + + def to_dict(self) -> Dict[str, Any]: + return { + "entity_order": list(self.entity_order), + "relationship_order": list(self.relationship_order), + "skip": [s.to_dict() for s in self.skip], + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "MappingPlan": + return cls( + entity_order=list(data.get("entity_order", [])), + relationship_order=list(data.get("relationship_order", [])), + skip=[SkipItem.from_dict(s) for s in data.get("skip", [])], + ) + + +@dataclass +class SourceModel: + """Output of the Planner stage; input to the Generator. 
+ + Contains the planner's understanding of the source schema (table roles, + canonical ids, join keys) and the ordered plan of work for the + Generator. + """ + + table_roles: List[TableRole] = field(default_factory=list) + canonical_ids: List[CanonicalId] = field(default_factory=list) + join_keys: List[JoinKey] = field(default_factory=list) + mapping_plan: MappingPlan = field(default_factory=MappingPlan) + + def to_dict(self) -> Dict[str, Any]: + return { + "table_roles": [t.to_dict() for t in self.table_roles], + "canonical_ids": [c.to_dict() for c in self.canonical_ids], + "join_keys": [j.to_dict() for j in self.join_keys], + "mapping_plan": self.mapping_plan.to_dict(), + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "SourceModel": + return cls( + table_roles=[ + TableRole.from_dict(t) for t in data.get("table_roles", []) + ], + canonical_ids=[ + CanonicalId.from_dict(c) for c in data.get("canonical_ids", []) + ], + join_keys=[JoinKey.from_dict(j) for j in data.get("join_keys", [])], + mapping_plan=MappingPlan.from_dict(data.get("mapping_plan", {})), + ) + + +# ===================================================== +# EvalReport — Evaluator output +# ===================================================== + + +@dataclass +class EvalFailure: + """A single failed check inside an :class:`EvalReport`. + + ``hint`` is the actionable correction text fed back to the Generator on + retry; it should be concrete and template-y, not a free-form essay. + """ + + kind: str # "structural" | "semantic" + check: str # e.g. "dangling_source_pct" + expected: str # e.g. "< 0.05" + observed: str # e.g. 
"0.47" + hint: str = "" + + def to_dict(self) -> Dict[str, Any]: + return { + "kind": self.kind, + "check": self.check, + "expected": self.expected, + "observed": self.observed, + "hint": self.hint, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "EvalFailure": + return cls( + kind=data["kind"], + check=data["check"], + expected=data["expected"], + observed=data["observed"], + hint=data.get("hint", ""), + ) + + +@dataclass +class EvalReport: + """Outcome of evaluating a single submitted mapping. + + ``bubble_to_planner`` signals that the failure cannot reasonably be + fixed by the Generator alone and warrants re-planning (e.g. wrong + canonical id column, table assigned to wrong ontology class). + """ + + status: str # "PASS" | "FAIL" + stage: str # "deterministic" | "semantic" + metrics: Dict[str, Any] = field(default_factory=dict) + failures: List[EvalFailure] = field(default_factory=list) + bubble_to_planner: bool = False + + def to_dict(self) -> Dict[str, Any]: + return { + "status": self.status, + "stage": self.stage, + "metrics": dict(self.metrics), + "failures": [f.to_dict() for f in self.failures], + "bubble_to_planner": self.bubble_to_planner, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "EvalReport": + return cls( + status=data["status"], + stage=data["stage"], + metrics=dict(data.get("metrics", {})), + failures=[EvalFailure.from_dict(f) for f in data.get("failures", [])], + bubble_to_planner=bool(data.get("bubble_to_planner", False)), + ) + + +# ===================================================== +# RetryState — orchestrator bookkeeping (used in Sprint 7) +# ===================================================== + + +@dataclass +class RetryState: + """Per-item retry budget tracked by the orchestrator. + + The orchestrator caps the Generator at 3 attempts per item before + giving up, and bumps the Planner at most twice per item if the + evaluator keeps bubbling failures upstream. 
+ """ + + item_uri: str + generator_attempts: int = 0 + planner_reinvocations: int = 0 + last_eval_report: Optional[EvalReport] = None + + def to_dict(self) -> Dict[str, Any]: + return { + "item_uri": self.item_uri, + "generator_attempts": self.generator_attempts, + "planner_reinvocations": self.planner_reinvocations, + "last_eval_report": ( + self.last_eval_report.to_dict() + if self.last_eval_report is not None + else None + ), + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "RetryState": + last = data.get("last_eval_report") + return cls( + item_uri=data["item_uri"], + generator_attempts=int(data.get("generator_attempts", 0)), + planner_reinvocations=int(data.get("planner_reinvocations", 0)), + last_eval_report=EvalReport.from_dict(last) if last is not None else None, + ) + + +# ===================================================== +# Sanity check — keep dataclass discovery introspectable +# ===================================================== + +_ALL_CONTRACTS = ( + TableRoleCandidate, + TableRole, + CanonicalId, + JoinKey, + SkipItem, + MappingPlan, + SourceModel, + EvalFailure, + EvalReport, + RetryState, +) +for _cls in _ALL_CONTRACTS: + assert is_dataclass(_cls), f"{_cls.__name__} must be a dataclass" + # touch ``fields`` to ensure all defaults are well-formed at import time. + fields(_cls) +del _cls diff --git a/src/agents/agent_mapping_pge/coverage.py b/src/agents/agent_mapping_pge/coverage.py new file mode 100644 index 00000000..9277a3f2 --- /dev/null +++ b/src/agents/agent_mapping_pge/coverage.py @@ -0,0 +1,353 @@ +"""Deterministic coverage enforcement + derived-mapping construction. + +The Planner is good at the *how* (which column is the canonical id, which +tables join, in what order to attempt things) but it must NOT be trusted with +the *what*: the set of entities and relationships that need mapping is fixed — +it is exactly the input ontology. 
Leaving coverage to the LLM produced +non-deterministic partial runs (some classes silently dropped into +``mapping_plan.skip``) and, because relationships were skipped whenever an +endpoint entity was absent, zero relationship coverage. + +This module computes the FULL, dependency-ordered coverage set from the +ontology itself (using the Planner's order only as a hint), and builds the two +kinds of mapping the LLM Generators cannot/should not produce: + +* **Abstract-superclass mappings** — a class with subclasses but no source + table of its own (``Person``, ``Patient``, ``Clinicalencounter``, + ``Clinicalfinding``). Its instances are exactly the UNION of its concrete + leaf subclasses, so we derive it mechanically from the already-validated + subclass mappings. Reusing the subclasses' verbatim ``sql_query`` keeps the + abstract id-universe byte-identical to the union of its parts, which is what + makes relationships pointing at the abstract (e.g. ``managedby`` with domain + ``Clinicalencounter``) join with zero dangling. + +* **Synthetic endpoint id-universe** — when a relationship endpoint entity has + no full mapping yet, we synthesise a minimal ``{sql_query, id_column}`` from + the Planner's ``canonical_ids`` so the relationship can still be attempted + instead of silently skipped. + +All functions here are pure data transforms — no LLM, no I/O — so they are +fast and unit-testable. 
+""" + +from typing import Dict, List, Optional, Set, Tuple + +from back.core.logging import get_logger +from agents.agent_mapping_pge.contracts import SourceModel + +logger = get_logger(__name__) + + +# ===================================================== +# Ontology structure helpers +# ===================================================== + + +def _entities(ontology: dict) -> List[dict]: + return (ontology or {}).get("entities", []) or [] + + +def _relationships(ontology: dict) -> List[dict]: + return (ontology or {}).get("relationships", []) or [] + + +def _uri(entity: dict) -> str: + return entity.get("uri") or entity.get("name") or "" + + +def name_to_uri(ontology: dict) -> Dict[str, str]: + """Map both short name and label to URI for parent/domain/range resolution.""" + out: Dict[str, str] = {} + for e in _entities(ontology): + uri = _uri(e) + if not uri: + continue + out[uri] = uri + if e.get("name"): + out[e["name"]] = uri + if e.get("label"): + out[e["label"]] = uri + return out + + +def parent_uri(entity: dict, n2u: Dict[str, str]) -> Optional[str]: + """Resolve a class's parent (stored as a name/label/uri) to a URI.""" + p = (entity.get("parent") or "").strip() + if not p: + return None + return n2u.get(p, p if p.startswith("http") else None) + + +def _tables_for_class(source_model: Optional[SourceModel]) -> Set[str]: + """URIs that the Planner assigned at least one source table to.""" + if source_model is None: + return set() + out: Set[str] = set() + for role in source_model.table_roles: + for cand in role.ontology_class_candidates: + out.add(cand.uri) + return out + + +def classify( + ontology: dict, + source_model: Optional[SourceModel], + *, + synthesized_uris: Optional[Set[str]] = None, +) -> Tuple[Set[str], Set[str]]: + """Partition classes into (concrete, abstract). + + A class is **abstract/derived** when it has subclasses but no source table + of its own — its rows are the union of its concrete descendants. 
Every + other class is **concrete** (it has, or will have via synthesis, a source + table and gets a normal Generator mapping). + """ + n2u = name_to_uri(ontology) + has_children: Set[str] = set() + for e in _entities(ontology): + p = parent_uri(e, n2u) + if p: + has_children.add(p) + + has_table = _tables_for_class(source_model) | (synthesized_uris or set()) + + concrete: Set[str] = set() + abstract: Set[str] = set() + for e in _entities(ontology): + uri = _uri(e) + if not uri: + continue + if uri in has_children and uri not in has_table: + abstract.add(uri) + else: + concrete.add(uri) + return concrete, abstract + + +def concrete_leaf_descendants( + abstract_uri: str, ontology: dict, concrete: Set[str] +) -> List[str]: + """All concrete descendant class URIs beneath ``abstract_uri`` (transitive).""" + n2u = name_to_uri(ontology) + children_of: Dict[str, List[str]] = {} + for e in _entities(ontology): + p = parent_uri(e, n2u) + if p: + children_of.setdefault(p, []).append(_uri(e)) + + out: List[str] = [] + stack = list(children_of.get(abstract_uri, [])) + seen: Set[str] = set() + while stack: + cur = stack.pop() + if cur in seen: + continue + seen.add(cur) + if cur in concrete: + out.append(cur) + stack.extend(children_of.get(cur, [])) + return out + + +# ===================================================== +# Coverage ordering (engine-enforced — the "what") +# ===================================================== + + +def full_entity_order( + ontology: dict, + source_model: Optional[SourceModel], + *, + synthesized_uris: Optional[Set[str]] = None, +) -> List[str]: + """Complete entity order: every class, concrete-first, abstracts after + their descendants. + + Uses the Planner's ``entity_order`` only to order concrete classes (it + knows base-vs-referencer dependencies); abstracts are appended in + descendant-before-ancestor order so a derived union can read its parts. 
+ """ + concrete, abstract = classify( + ontology, source_model, synthesized_uris=synthesized_uris + ) + planned = ( + list(source_model.mapping_plan.entity_order) if source_model else [] + ) + + ordered: List[str] = [] + seen: Set[str] = set() + + # 1. Concrete classes in the Planner's order first… + for uri in planned: + if uri in concrete and uri not in seen: + ordered.append(uri) + seen.add(uri) + # 2. …then any concrete class the Planner omitted (coverage guarantee). + for e in _entities(ontology): + uri = _uri(e) + if uri in concrete and uri not in seen: + ordered.append(uri) + seen.add(uri) + + # 3. Abstracts, ordered so each appears after all its descendants. + remaining = [u for u in abstract if u not in seen] + + def _depth(uri: str) -> int: + # number of concrete leaves — deeper subtrees (more leaves) last is + # fine; what matters is a class never precedes its own descendant. + return len(concrete_leaf_descendants(uri, ontology, concrete)) + + # Topologically: a class with fewer abstract-ancestors first. Simple stable + # approach: repeatedly emit abstracts whose abstract-children are all done. + n2u = name_to_uri(ontology) + abstract_children: Dict[str, List[str]] = {} + for e in _entities(ontology): + uri = _uri(e) + if uri in abstract: + p = parent_uri(e, n2u) + if p in abstract: + abstract_children.setdefault(p, []).append(uri) + + progress = True + while remaining and progress: + progress = False + for uri in list(remaining): + kids = abstract_children.get(uri, []) + if all(k in seen for k in kids): + ordered.append(uri) + seen.add(uri) + remaining.remove(uri) + progress = True + # Anything left (cycles — shouldn't happen) just append. 
+ for uri in remaining: + ordered.append(uri) + seen.add(uri) + + return ordered + + +def full_relationship_order( + ontology: dict, entity_order: List[str], source_model: Optional[SourceModel] +) -> List[str]: + """Deduplicated relationship order: the Planner's ``relationship_order`` + first, then ontology declaration order (``entity_order`` is currently unused).""" + rels = _relationships(ontology) + rel_uris = [r.get("uri") or r.get("name") for r in rels] + rel_uris = [u for u in rel_uris if u] + + planned = ( + list(source_model.mapping_plan.relationship_order) if source_model else [] + ) + ordered: List[str] = [] + seen: Set[str] = set() + for uri in planned + rel_uris: + if uri and uri not in seen: + ordered.append(uri) + seen.add(uri) + return ordered + + +# ===================================================== +# Derived mappings (the "how" the LLM cannot produce) +# ===================================================== + +_ID = "ID" + + +def _attr_names(entity: dict) -> List[str]: + out: List[str] = [] + for a in entity.get("attributes", []) or []: + if isinstance(a, dict): + n = a.get("name") + if n: + out.append(str(n)) + elif a is not None: + out.append(str(a)) + return out + + +def build_abstract_union_mapping( + abstract_uri: str, + abstract_entity: dict, + subclass_mappings: List[dict], +) -> Optional[dict]: + """Build a derived entity mapping for an abstract class as the UNION ALL of + its concrete subclass mappings. + + Reuses each subclass's verbatim ``sql_query`` (wrapped in a subquery) so the + abstract id-universe equals the union of the parts exactly. Projects the + abstract class's own attributes by re-aliasing each subclass's + ``attribute_mappings`` value to the ontology attribute name; subclasses that + do not carry an attribute contribute ``NULL`` for it.
+ """ + subs = [m for m in subclass_mappings if m and m.get("sql_query")] + if not subs: + return None + + attrs = _attr_names(abstract_entity) + selects: List[str] = [] + for m in subs: + amap = m.get("attribute_mappings") or {} + cols = [_ID] + for attr in attrs: + src_alias = amap.get(attr) + if src_alias: + cols.append(f"{src_alias} AS {attr}") + else: + cols.append(f"CAST(NULL AS STRING) AS {attr}") + selects.append(f"SELECT {', '.join(cols)} FROM ({m['sql_query']}) ") + union = " UNION ALL ".join(selects) + # DISTINCT on the whole projection collapses any incidental duplicates while + # preserving distinct ids (subclass id spaces are disjoint by construction). + sql = f"SELECT DISTINCT * FROM ({union}) _abstract WHERE {_ID} IS NOT NULL" + + return { + "class_uri": abstract_uri, + "ontology_class": abstract_uri, + "class_name": abstract_entity.get("name", abstract_uri), + "sql_query": sql, + "id_column": _ID, + "label_column": _ID, + "attribute_mappings": {a: a for a in attrs}, + "unmapped_attributes": [], + "derived": "abstract_union", + } + + +def synthetic_endpoint_mapping( + source_model: Optional[SourceModel], class_uri: str +) -> Optional[dict]: + """Build a minimal id-universe-only entity mapping for a relationship + endpoint from the Planner's ``canonical_ids``. + + Used as a fallback so a relationship is never skipped just because its + endpoint entity lacks a full mapping. Produces a UNION ALL of + ``SELECT {column} AS ID FROM {table}`` over every table the + Planner recorded for the class.
+ """ + if source_model is None: + return None + cid = next( + (c for c in source_model.canonical_ids if c.ontology_class == class_uri), + None, + ) + if cid is None or not cid.canonical_column_per_table: + return None + selects = [ + f"SELECT {expr} AS {_ID} FROM {table}" + for table, expr in cid.canonical_column_per_table.items() + if expr and table + ] + if not selects: + return None + inner = " UNION ALL ".join(selects) + sql = f"SELECT DISTINCT {_ID} FROM ({inner}) _u WHERE {_ID} IS NOT NULL" + return { + "class_uri": class_uri, + "ontology_class": class_uri, + "sql_query": sql, + "id_column": _ID, + "attribute_mappings": {}, + "unmapped_attributes": [], + "derived": "synthetic_endpoint", + } diff --git a/src/agents/agent_mapping_pge/engine.py b/src/agents/agent_mapping_pge/engine.py new file mode 100644 index 00000000..d6245b2a --- /dev/null +++ b/src/agents/agent_mapping_pge/engine.py @@ -0,0 +1,1491 @@ +""" +OntoBricks Mapping-PGE Orchestrator. + +Wires the Planner, the Entity/Relationship Generators, and the two-stage +Evaluator (deterministic + semantic critic) into a single ``run_agent`` +entry point. + +The public ``run_agent`` signature and :class:`AgentResult` shape match the +prior in-house single-loop mapping agent so ``back/objects/mapping/Mapping.py`` +can call this engine without other changes. + +Control flow per item (entity or relationship) +============================================== + +1. Build a focused slice from the Planner's :class:`SourceModel`. +2. Run the appropriate Generator with ``retry_hint=None``. +3. Run the deterministic evaluator. On FAIL: + * if ``bubble_to_planner=True`` -> escalate to Planner (capped globally at + ``_PLANNER_REINVOCATION_BUDGET`` replans across the whole run); + * else retry the Generator with the first failure's hint. +4. On stage-1 PASS, run the semantic critic (unless ``skip_semantic_critic`` + is set). Same bubble / hint logic on FAIL. +5.
After ``_PER_ITEM_GENERATOR_ATTEMPTS`` unsuccessful attempts (currently 4), the item is + recorded as ``FAIL_BUDGET`` and the orchestrator moves on to the next item. + +Step-log design +=============== + +``AgentResult.steps`` is a HIGH-LEVEL log — one entry per stage transition +(planner-start, generator-start, evaluator-result, critic-result, item-done). +The detailed per-tool steps emitted by each sub-agent stay on the sub-agent's +own result dataclass (``PlannerResult.steps``, ``EntityGenResult.steps``, …) +and are NOT merged into the orchestrator's ``steps`` field. This keeps the +top-level log readable in the UI; the persistence layer can attach sub-agent +step lists separately when needed. +""" + +import concurrent.futures +import threading +import time +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional, Tuple + +from back.core.logging import get_logger +from agents.agent_mapping_pge.contracts import EvalReport, SourceModel +from agents.agent_mapping_pge.coverage import ( + build_abstract_union_mapping, + classify, + concrete_leaf_descendants, + full_entity_order, + full_relationship_order, + synthetic_endpoint_mapping, +) +from agents.agent_mapping_pge.evaluator.critic import run_critic +from agents.agent_mapping_pge.evaluator.deterministic import ( + evaluate_entity_mapping, + evaluate_relationship_mapping, +) +from agents.agent_mapping_pge.generators.entity import run_entity_generator +from agents.agent_mapping_pge.generators.relationship import ( + run_relationship_generator, +) +from agents.agent_mapping_pge.planner import run_planner +from agents.tracing import trace_agent + +logger = get_logger(__name__) + +# Per-item retry budget for the Generator->Evaluator inner loop. +_PER_ITEM_GENERATOR_ATTEMPTS = 4 +# Global cap on Planner re-invocations triggered by escalated failures. +_PLANNER_REINVOCATION_BUDGET = 3 + +# Bounded parallelism for per-item Generator->Evaluator work. 
@dataclass
class AgentStep:
    """A single entry in the orchestrator's high-level step log.

    Field-for-field the same shape as :class:`agents.engine_base.AgentStep`,
    so it stays drop-in-compatible with the prior orchestrator. The only
    difference is the vocabulary accepted in ``step_type``:

    * PGE stage transitions: ``"planner"``, ``"generator"``,
      ``"evaluator"``, ``"critic"``;
    * legacy values, still valid: ``"tool_call"``, ``"tool_result"``,
      ``"output"``.
    """

    # Kind of step — one of the values listed in the class docstring.
    step_type: str
    # Human-readable description of what happened.
    content: str
    # Name of the tool involved, when the step is tool-related; else "".
    tool_name: str = ""
    # Wall-clock duration of the step in milliseconds (0 when not measured).
    duration_ms: int = 0
def _index_by_uri(items: Optional[list]) -> Dict[str, dict]:
    """Key every dict in ``items`` by its ``uri``, falling back to ``name``.

    Items with neither a truthy ``uri`` nor a truthy ``name`` are left out;
    when two items share a key the later one wins.
    """
    index: Dict[str, dict] = {}
    for item in items or []:
        key = item.get("uri") or item.get("name")
        if key:
            index[key] = item
    return index


def _ontology_index(ontology: dict) -> Dict[str, dict]:
    """Return a ``uri -> entity dict`` lookup over ``ontology["entities"]``.

    A missing/``None`` ontology or entity list yields an empty index.
    """
    return _index_by_uri((ontology or {}).get("entities"))


def _relationship_index(ontology: dict) -> Dict[str, dict]:
    """Return a ``uri -> relationship dict`` lookup over
    ``ontology["relationships"]``.

    A missing/``None`` ontology or relationship list yields an empty index.
    """
    return _index_by_uri((ontology or {}).get("relationships"))
def _slice_for_relationship(
    source_model: SourceModel,
    property_uri: str,
    source_entity_mapping: dict,
    target_entity_mapping: dict,
) -> dict:
    """Build the SourceModel slice handed to the RelationshipGenerator.

    The slice carries:

    * ``property_uri`` — echoed back unchanged;
    * ``relevant_joins`` — every join key in the source model (no filtering;
      the Generator chooses among them);
    * ``candidate_tables`` — at most one entry per table role: the first
      class candidate that belongs to the source/target endpoint classes, or
      simply the first candidate when neither endpoint class is known.
    """

    def _class_of(entity_mapping: dict) -> str:
        # An entity mapping names its class under either key; prefer
        # ``ontology_class`` and fall back to ``class_uri``.
        em = entity_mapping or {}
        return em.get("ontology_class") or em.get("class_uri", "")

    endpoint_classes = {
        uri
        for uri in (_class_of(source_entity_mapping), _class_of(target_entity_mapping))
        if uri
    }

    candidate_tables = []
    for role in source_model.table_roles:
        chosen = next(
            (
                cand
                for cand in role.ontology_class_candidates
                if not endpoint_classes or cand.uri in endpoint_classes
            ),
            None,
        )
        if chosen is None:
            continue
        candidate_tables.append(
            {
                "table": role.table,
                "ontology_class": chosen.uri,
                "confidence": chosen.confidence,
                "reason": chosen.reason,
            }
        )

    all_joins = [join.to_dict() for join in source_model.join_keys]

    return {
        "property_uri": property_uri,
        "relevant_joins": all_joins,
        "candidate_tables": candidate_tables,
    }
def _resolve_endpoint_em(
    ref: str,
    by_uri: Dict[str, dict],
    entity_index: Dict[str, dict],
) -> Optional[dict]:
    """Find the entity mapping a relationship endpoint refers to.

    ``ref`` comes from the ontology's ``domain`` / ``range`` and may be a
    full URI or a short name/label. A direct hit in ``by_uri`` wins;
    otherwise the first ontology entity whose ``name`` or ``label`` equals
    ``ref`` *and* whose URI is mapped supplies the result. Returns ``None``
    for an empty ref or when nothing matches.
    """
    if not ref:
        return None
    if ref in by_uri:
        return by_uri[ref]

    mapped_matches = (
        by_uri[uri]
        for uri, entity in entity_index.items()
        if (entity.get("name") == ref or entity.get("label") == ref)
        and uri in by_uri
    )
    return next(mapped_matches, None)


def _ref_to_uri(ref: str, entity_index: Dict[str, dict]) -> str:
    """Translate a domain/range ref (URI, name, or label) into a class URI.

    A ref that is empty or already a key of ``entity_index`` is returned
    as-is; otherwise the URI of the first entity whose ``name`` or ``label``
    equals the ref is returned, and the ref itself when there is no match.
    """
    if ref and ref not in entity_index:
        for uri, entity in entity_index.items():
            if entity.get("name") == ref or entity.get("label") == ref:
                return uri
    return ref
@trace_agent(name="mapping_pge_engine")
def run_agent(
    host: str,
    token: str,
    endpoint_name: str,
    client: Any,
    metadata: dict,
    ontology: dict,
    entity_mappings: Optional[list] = None,
    relationship_mappings: Optional[list] = None,
    documents: Optional[list] = None,
    on_step: Optional[Callable[[str, int], None]] = None,
    max_iterations: Optional[int] = None,
    *,
    skip_semantic_critic: bool = False,
) -> AgentResult:
    """Run the PGE mapping orchestrator.

    Drop-in replacement for the prior in-house single-loop mapping agent —
    same positional/keyword signature, same :class:`AgentResult` shape.

    The run has three stages: (1) one Planner call, whose failure or
    exception ends the run immediately with an error result; (2) the entity
    walk — concrete classes through the bounded pool, then abstract
    superclasses sequentially; (3) the relationship walk through the same
    pool. Pre-seeded items are logged as ``PRESEEDED`` and never re-mapped.

    Args:
        host: Databricks workspace URL.
        token: Bearer token for the serving endpoint.
        endpoint_name: Foundation Model serving endpoint name.
        client: Databricks SQL client exposing ``execute_query(sql)``.
        metadata: Imported table metadata to hand to the Planner.
        ontology: Ontology dict with ``entities`` and ``relationships``.
        entity_mappings: Pre-seeded entity mappings (URI matched -> skipped).
        relationship_mappings: Pre-seeded relationship mappings (likewise).
        documents: Optional pre-loaded domain documents.
        on_step: Optional progress callback ``(msg, pct)``.
        max_iterations: Per-item override for the Generator's iteration cap.
            Kept for API parity with the legacy engine; ``None`` uses each
            sub-agent's default.
        skip_semantic_critic: When ``True``, the orchestrator skips the
            stage-2 critic and accepts every stage-1 PASS as a final PASS.
            Production callers leave this ``False``; tests flip it ``True``
            to avoid LLM calls in the orchestrator's unit tests.

    Returns:
        An :class:`AgentResult` with the submitted mappings, a high-level
        ``steps`` log, per-item ``mapping_run_log``, and PGE-specific
        extras (``source_model``, ``mapping_evaluations``).
    """
    # ------------------------------------------------------------------
    # Per-call state lives entirely on this RunState object — no module-
    # level mutables, so concurrent calls (and tests) cannot collide.
    # ------------------------------------------------------------------
    state = _RunState(
        host=host,
        token=token,
        endpoint_name=endpoint_name,
        client=client,
        metadata=metadata or {},
        ontology=ontology or {},
        documents=list(documents or []),
        on_step=on_step,
        max_iterations=max_iterations,
        skip_semantic_critic=skip_semantic_critic,
    )

    # Pre-seeded mappings carry over verbatim — we never overwrite a URI the
    # caller already mapped.
    pre_entity_list = list(entity_mappings or [])
    pre_rel_list = list(relationship_mappings or [])
    preseeded_entity_uris = {
        m.get("ontology_class") or m.get("class_uri") or "" for m in pre_entity_list
    }
    # Mappings carrying neither key contribute "" above; drop that placeholder.
    preseeded_entity_uris.discard("")
    preseeded_rel_uris = {
        m.get("property") or m.get("property_uri") or "" for m in pre_rel_list
    }
    preseeded_rel_uris.discard("")

    state.entity_mappings.extend(pre_entity_list)
    state.relationship_mappings.extend(pre_rel_list)
    # Register pre-seeded entities by URI so relationship endpoints can
    # resolve to them later in the walk.
    for m in pre_entity_list:
        uri = m.get("ontology_class") or m.get("class_uri")
        if uri:
            state.entity_mapping_by_uri[uri] = m

    entities_in_scope = state.ontology.get("entities", []) or []
    relationships_in_scope = state.ontology.get("relationships", []) or []

    logger.info(
        "===== MAPPING-PGE ENGINE START ===== endpoint=%s, entities=%d, "
        "relationships=%d, preseeded_entities=%d, preseeded_rels=%d, "
        "skip_critic=%s",
        endpoint_name,
        len(entities_in_scope),
        len(relationships_in_scope),
        len(preseeded_entity_uris),
        len(preseeded_rel_uris),
        skip_semantic_critic,
    )

    # ------------------------------------------------------------------
    # 1. Planner
    # ------------------------------------------------------------------
    state.notify("Planning…", pct=2)
    state.add_step("planner", "planner-start")

    t0 = time.time()
    try:
        planner_result = run_planner(
            host=host,
            token=token,
            endpoint_name=endpoint_name,
            client=client,
            metadata=state.metadata,
            ontology=state.ontology,
            documents=state.documents,
            # The Planner gets no progress callback; progress is reported
            # only through state.notify at this level.
            on_step=None,
        )
    except Exception as exc:  # noqa: BLE001 — surface any failure as run failure
        logger.error("Planner raised an exception: %s", exc, exc_info=True)
        return state.finalise(error=f"planner exception: {exc}")

    planner_ms = int((time.time() - t0) * 1000)
    state.add_iterations(planner_result.iterations)
    state.accumulate_usage(planner_result.usage)

    # A Planner that reports failure (or returns no source model) aborts the
    # run: nothing downstream can work without a source model.
    if not planner_result.success or planner_result.source_model is None:
        state.add_step(
            "planner",
            f"planner-fail: {planner_result.error}",
            duration_ms=planner_ms,
        )
        logger.error("===== MAPPING-PGE ENGINE FAILED ===== planner failed")
        state.notify("Planner failed — aborting.", pct=10)
        return state.finalise(
            error=f"planner failed: {planner_result.error or 'no source model'}"
        )

    state.source_model = planner_result.source_model
    # Derives entity_order / relationship_order / abstract_uris from the
    # ontology and the new source model.
    state.refresh_plan()
    state.add_step(
        "planner",
        f"planner-done: entities={len(state.entity_order)}, "
        f"relationships={len(state.relationship_order)}",
        duration_ms=planner_ms,
    )

    # ------------------------------------------------------------------
    # 2. Walk the plan — entities first, then relationships.
    # ------------------------------------------------------------------
    state.entity_index = _ontology_index(state.ontology)
    state.rel_index = _relationship_index(state.ontology)
    state.execute_sql_fn = _wrap_execute_sql(client)
    # Denominator for the progress percentage; pre-seeded items count too
    # because they also bump items_done below.
    state.total_items_planned = len(state.entity_order) + len(
        state.relationship_order
    )

    # ------------------------------------------------------------------
    # Entity walk — three phases:
    #   1. concrete classes — independent, run in a bounded thread pool;
    #   2. abstract superclasses — derived from the concrete mappings, so they
    #      run AFTER phase 1 (cheap, no LLM, kept sequential);
    #   pre-seeded classes are recorded inline and never re-mapped.
    # Per-item work is independent and the SQL client is connection-pooled and
    # thread-safe, so parallelism is the single biggest wall-clock win.
    # ------------------------------------------------------------------
    concrete_items: List[Tuple[str, dict]] = []
    for entity_uri in list(state.entity_order):
        # Fall back to a stub dict when the URI is not in the ontology index.
        ontology_class = state.entity_index.get(entity_uri, {"uri": entity_uri})
        label = ontology_class.get("label") or ontology_class.get("name", entity_uri)
        if entity_uri in preseeded_entity_uris:
            state.mapping_run_log.append(
                {"item": entity_uri, "kind": "entity", "attempts": [],
                 "final_status": "PRESEEDED"}
            )
            state.notify(f"Skipping pre-seeded {label}")
            state.items_done += 1
        elif entity_uri not in state.abstract_uris:
            concrete_items.append((entity_uri, ontology_class))

    def _entity_runner(item):
        # Only the ontology class dict is forwarded; the URI half of the
        # tuple is used by the merge step below.
        uri, oc = item
        return _run_entity_item(state, oc)

    # Outcomes are merged here, in the calling thread, in plan order.
    for (entity_uri, ontology_class), outcome in _run_items_concurrently(
        concrete_items, _entity_runner
    ):
        _merge_entity_result(state, entity_uri, ontology_class, outcome)

    # Phase 2 — abstract superclasses (depend on concrete mappings being present).
    for entity_uri in list(state.entity_order):
        if entity_uri in state.abstract_uris and entity_uri not in preseeded_entity_uris:
            ontology_class = state.entity_index.get(entity_uri, {"uri": entity_uri})
            outcome = _run_abstract_item(state, ontology_class)
            _merge_entity_result(state, entity_uri, ontology_class, outcome)

    # ------------------------------------------------------------------
    # Relationship walk — every relationship is independent once all entity
    # id-universes exist, so the whole set runs in the same bounded pool.
    # ------------------------------------------------------------------
    rel_items: List[Tuple[str, dict, dict, dict]] = []
    for property_uri in list(state.relationship_order):
        prop = state.rel_index.get(property_uri, {"uri": property_uri})
        label = prop.get("label") or prop.get("name", property_uri)
        if property_uri in preseeded_rel_uris:
            state.mapping_run_log.append(
                {"item": property_uri, "kind": "relationship", "attempts": [],
                 "final_status": "PRESEEDED"}
            )
            state.notify(f"Skipping pre-seeded {label}")
            state.items_done += 1
            continue
        # Coverage is engine-enforced: resolve each endpoint to a full mapping
        # or a synthetic id-universe (from canonical_ids) so a relationship is
        # never silently skipped for a missing endpoint.
        source_em = _endpoint_em(state, prop.get("domain", "") or "")
        target_em = _endpoint_em(state, prop.get("range", "") or "")
        if source_em is None or target_em is None:
            # When both endpoints are missing only "source" is reported.
            missing = "source" if source_em is None else "target"
            state.mapping_run_log.append(
                {"item": property_uri, "kind": "relationship", "attempts": [],
                 "final_status": "FAIL_NO_ENDPOINT"}
            )
            state.add_step(
                "evaluator",
                f"relationship {property_uri}: no {missing} id-universe — cannot attempt",
            )
            state.notify(f"Cannot map {label}: {missing} endpoint has no id universe")
            state.items_done += 1
            continue
        rel_items.append((property_uri, prop, source_em, target_em))

    def _rel_runner(item):
        _uri, prop, source_em, target_em = item
        return _run_relationship_item(state, prop, source_em, target_em)

    for (property_uri, prop, _s, _t), outcome in _run_items_concurrently(
        rel_items, _rel_runner
    ):
        _merge_relationship_result(state, property_uri, prop, outcome)

    state.notify("Agent completed!", pct=100)
    return state.finalise()
+ """ + + host: str + token: str + endpoint_name: str + client: Any + metadata: dict + ontology: dict + documents: List[Any] + on_step: Optional[Callable[[str, int], None]] + max_iterations: Optional[int] + skip_semantic_critic: bool + + # Output accumulators + entity_mappings: List[dict] = field(default_factory=list) + relationship_mappings: List[dict] = field(default_factory=list) + entity_mapping_by_uri: Dict[str, dict] = field(default_factory=dict) + mapping_run_log: List[dict] = field(default_factory=list) + mapping_evaluations: Dict[str, dict] = field(default_factory=dict) + steps: List[AgentStep] = field(default_factory=list) + usage: Dict[str, int] = field( + default_factory=lambda: {"prompt_tokens": 0, "completion_tokens": 0} + ) + iterations: int = 0 + submitted_any: bool = False + + # Plan-derived state — refreshed on (re)plan. + source_model: Optional[SourceModel] = None + entity_order: List[str] = field(default_factory=list) + relationship_order: List[str] = field(default_factory=list) + skip_reasons: Dict[str, str] = field(default_factory=dict) + abstract_uris: set = field(default_factory=set) + planner_reinvocations: int = 0 + + # Walk progress + items_done: int = 0 + total_items_planned: int = 0 + + # Per-run caches & lookups + id_universe_cache: Dict[str, set] = field(default_factory=dict) + entity_index: Dict[str, dict] = field(default_factory=dict) + rel_index: Dict[str, dict] = field(default_factory=dict) + execute_sql_fn: Optional[Callable[[str], dict]] = None + + # Guards the shared accumulators (steps/usage/iterations) that per-item + # runners touch while the entity/relationship walks run them in a pool. 
+ _lock: Any = field(default_factory=threading.Lock, repr=False, compare=False) + _replan_lock: Any = field( + default_factory=threading.Lock, repr=False, compare=False + ) + _max_pct: int = 0 + + # -- helpers ---------------------------------------------------------- + + def add_step( + self, + step_type: str, + content: str, + *, + tool_name: str = "", + duration_ms: int = 0, + ) -> None: + with self._lock: + self.steps.append( + AgentStep( + step_type=step_type, + content=content, + tool_name=tool_name, + duration_ms=duration_ms, + ) + ) + + def pct(self) -> int: + total = max(self.total_items_planned, 1) + return min(5 + int((self.items_done / total) * 90), 95) + + def notify(self, msg: str, *, pct: Optional[int] = None) -> None: + actual_pct = pct if pct is not None else self.pct() + # Progress is reported from a thread pool, so clamp to a monotonic + # high-water mark — the bar never visually goes backwards. + with self._lock: + if actual_pct < self._max_pct: + actual_pct = self._max_pct + else: + self._max_pct = actual_pct + logger.info("PGE STEP [%d%%] %s", actual_pct, msg) + if self.on_step: + self.on_step(msg, actual_pct) + + def add_iterations(self, n: int) -> None: + with self._lock: + self.iterations += int(n or 0) + + def accumulate_usage(self, src: Dict[str, int]) -> None: + with self._lock: + for k in ("prompt_tokens", "completion_tokens"): + self.usage[k] = self.usage.get(k, 0) + int((src or {}).get(k, 0)) + + def refresh_plan(self) -> None: + sm = self.source_model + if sm is None: + return + # Coverage is engine-enforced, NOT LLM-discretionary: attempt EVERY + # ontology entity + relationship regardless of what the Planner put in + # mapping_plan.skip. The Planner's order is used only as a hint. 
+ self.entity_order = full_entity_order(self.ontology, sm) + self.relationship_order = full_relationship_order( + self.ontology, self.entity_order, sm + ) + # The Planner may still flag items it judged unmappable; we keep the + # reasons for logging but DO NOT let them remove items from coverage. + self.skip_reasons = {s.item: s.reason for s in sm.mapping_plan.skip} + concrete, abstract = classify(self.ontology, sm) + self.abstract_uris = abstract + logger.info( + "refresh_plan: full coverage — entities=%d (abstract=%d), " + "relationships=%d; planner_skip(advisory)=%d", + len(self.entity_order), + len(abstract), + len(self.relationship_order), + len(self.skip_reasons), + ) + + def replan_once(self) -> bool: + """Re-invoke the Planner once (subject to the global budget). + + Returns ``True`` on success (and updates the plan in place), ``False`` + when the budget is exhausted or the new Planner run failed. + + Serialised by ``_replan_lock`` so concurrent bubbling items (the + entity/relationship walks run in a thread pool) cannot double-invoke + the Planner or corrupt the shared plan state mid-walk. 
+ """ + with self._replan_lock: + return self._replan_once_locked() + + def _replan_once_locked(self) -> bool: + if self.planner_reinvocations >= _PLANNER_REINVOCATION_BUDGET: + return False + self.planner_reinvocations += 1 + self.notify("Re-planning (escalated)…", pct=self.pct()) + self.add_step( + "planner", + f"replan-start (reinvocation #{self.planner_reinvocations})", + ) + t_rp = time.time() + try: + new_result = run_planner( + host=self.host, + token=self.token, + endpoint_name=self.endpoint_name, + client=self.client, + metadata=self.metadata, + ontology=self.ontology, + documents=self.documents, + on_step=None, + ) + except Exception as exc: # noqa: BLE001 + logger.error("Replan raised an exception: %s", exc, exc_info=True) + self.add_step("planner", f"replan-exception: {exc}") + return False + replan_ms = int((time.time() - t_rp) * 1000) + self.add_iterations(new_result.iterations) + self.accumulate_usage(new_result.usage) + if not new_result.success or new_result.source_model is None: + self.add_step( + "planner", + f"replan-fail: {new_result.error}", + duration_ms=replan_ms, + ) + return False + self.source_model = new_result.source_model + self.refresh_plan() + self.add_step("planner", "replan-done", duration_ms=replan_ms) + return True + + def finalise(self, *, error: str = "") -> AgentResult: + """Build the final :class:`AgentResult`.""" + result = AgentResult(success=False) + result.entity_mappings = list(self.entity_mappings) + result.relationship_mappings = list(self.relationship_mappings) + result.steps = list(self.steps) + result.iterations = self.iterations + result.usage = dict(self.usage) + result.mapping_run_log = list(self.mapping_run_log) + result.mapping_evaluations = dict(self.mapping_evaluations) + result.source_model = ( + self.source_model.to_dict() if self.source_model is not None else None + ) + result.stats = { + "total": len(self.entity_order) + len(self.relationship_order), + "entities": len(self.entity_mappings), + 
def _run_items_concurrently(items: List[Any], runner: Callable[[Any], _Outcome]):
    """Run ``runner(item)`` for each item and yield ``(item, outcome)`` pairs
    in the ORIGINAL item order.

    With more than one worker available the items run in a thread pool of at
    most ``_MAX_CONCURRENCY`` threads and nothing is yielded until every item
    has finished; otherwise items run one at a time, each pair being yielded
    as soon as it is ready.

    The per-item runners mutate only thread-safe parts of the shared state
    (lock-guarded usage/iteration/step accumulators); all result MERGING into
    the run accumulators happens in the caller's thread as pairs are
    consumed, so there is no race on the mapping lists/dicts.

    A runner that raises never aborts the walk: on BOTH paths the exception
    is logged and replaced by a ``FAIL_BUDGET`` outcome carrying a single
    attempt entry with the error text.

    Args:
        items: Work items; may be empty (nothing is yielded).
        runner: Callable producing an outcome tuple for one item.

    Yields:
        ``(item, outcome)`` tuples, one per input item, in input order.
    """
    if not items:
        return

    def _safe_run(item):
        # Convert any runner exception into a failed-item outcome so a single
        # bad item cannot take down the whole walk (same policy whether the
        # item runs inline or in the pool).
        try:
            return runner(item)
        except Exception as exc:  # noqa: BLE001 — surface as a failed item
            logger.error("Concurrent item runner raised: %s", exc, exc_info=True)
            return (
                "FAIL_BUDGET",
                [{"attempt": 1, "stage1_status": "skipped",
                  "critic_status": "skipped", "bubble": False,
                  "hint": None, "error": f"runner exception: {exc}"}],
                None,
                None,
            )

    workers = min(_MAX_CONCURRENCY, len(items))
    if workers <= 1:
        for item in items:
            yield item, _safe_run(item)
        return

    with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as pool:
        # Futures are kept in submission order, so outcomes line up with
        # ``items`` positionally — no lookup keyed on object identity, which
        # would collapse if the same object appeared twice in ``items``.
        futures = [pool.submit(_safe_run, item) for item in items]
        outcomes = [fut.result() for fut in futures]
    yield from zip(items, outcomes)
def _run_abstract_item(
    state: "_RunState",
    ontology_class: dict,
) -> Tuple[str, List[dict], Optional[dict], Optional[EvalReport]]:
    """Derive an abstract superclass mapping as the UNION of its concrete
    subclass mappings — no LLM call.

    The abstract class has no source table of its own; its mapping is built
    from the already-mapped concrete leaf subclasses found in
    ``state.entity_mapping_by_uri``. The derived mapping is still passed
    through the deterministic evaluator; the semantic critic is skipped (its
    status is always logged as ``"skipped"``).

    Args:
        state: Run state (ontology, source model, mapped entities, SQL fn).
        ontology_class: Ontology dict of the abstract class.

    Returns:
        ``(final_status, attempts_log, mapping, report)`` where
        ``final_status`` is ``"PASS"`` only when the evaluator passes and
        ``"FAIL_BUDGET"`` otherwise. ``mapping`` is ``None`` when no mapped
        subclass exists; ``report`` is ``None`` when evaluation did not run
        or raised.
    """
    class_uri = ontology_class.get("uri", "")
    label = ontology_class.get("label") or ontology_class.get("name", class_uri)

    def _attempt_entry(stage1_status, hint=None, error=None):
        # Single-attempt log entry; derivation is mechanical, so there is
        # never a generator call, a critic run, or a bubble to record.
        entry = {
            "attempt": 1,
            "generator_ms": 0,
            "stage1_status": stage1_status,
            "critic_status": "skipped",
            "bubble": False,
            "hint": hint,
        }
        if error is not None:
            entry["error"] = error
        return entry

    concrete, _abstract = classify(state.ontology, state.source_model)
    leaf_uris = concrete_leaf_descendants(class_uri, state.ontology, concrete)
    # Only subclasses that actually received a mapping can take part.
    sub_ems = [
        state.entity_mapping_by_uri[u]
        for u in leaf_uris
        if u in state.entity_mapping_by_uri
    ]
    state.notify(
        f"Deriving {label} as UNION of {len(sub_ems)}/{len(leaf_uris)} "
        "concrete subclass(es)…",
        pct=state.pct(),
    )
    state.add_step(
        "generator",
        f"abstract-derive: {class_uri} from {len(sub_ems)} mapped subclasses",
    )

    mapping = build_abstract_union_mapping(class_uri, ontology_class, sub_ems)
    if mapping is None:
        attempts_log = [
            _attempt_entry("skipped", error="no mapped concrete subclasses to union")
        ]
        return "FAIL_BUDGET", attempts_log, None, None

    t_e = time.time()
    try:
        report = evaluate_entity_mapping(
            mapping=mapping,
            ontology_class=ontology_class,
            execute_sql_fn=state.execute_sql_fn,
        )
    except Exception as exc:  # noqa: BLE001
        # exc_info keeps the traceback in the log, matching the other
        # exception handlers in this module.
        logger.error(
            "Abstract derive eval raised on %s: %s", class_uri, exc, exc_info=True
        )
        attempts_log = [_attempt_entry("skipped", error=f"eval exception: {exc}")]
        return "FAIL_BUDGET", attempts_log, mapping, None
    eval_ms = int((time.time() - t_e) * 1000)

    attempts_log = [_attempt_entry(report.status, hint=_first_hint(report))]
    state.add_step(
        "evaluator",
        f"abstract-stage1: {class_uri} status={report.status}",
        duration_ms=eval_ms,
    )
    final_status = "PASS" if report.status == "PASS" else "FAIL_BUDGET"
    return final_status, attempts_log, mapping, report
def _attempt_record(
    attempt: int,
    generator_ms: int,
    stage1_status: str,
    critic_status: str,
    *,
    bubble: bool = False,
    hint: Optional[str] = None,
    error: Optional[str] = None,
) -> dict:
    """Build one ``attempts_log`` entry.

    The ``"error"`` key is only present when *error* is given, matching the
    shape of the hand-written dicts this helper replaces (key order is kept).
    """
    record = {
        "attempt": attempt,
        "generator_ms": generator_ms,
        "stage1_status": stage1_status,
        "critic_status": critic_status,
        "bubble": bubble,
        "hint": hint,
    }
    if error is not None:
        record["error"] = error
    return record


def _run_item_loop(
    state: _RunState,
    *,
    item_kind: str,
    item_uri: str,
    item_label: str,
    item_definition: dict,
    step_prefix: str,
    generator_name: str,
    build_slice,
    generate,
    evaluate,
) -> Tuple[str, List[dict], Optional[dict], Optional[EvalReport]]:
    """Shared Generator -> Evaluator -> Critic loop for one ontology item.

    Args:
        state: Run state; used for progress (``notify``/``pct``), step and
            usage accounting, critic credentials and ``replan_once``.
        item_kind: ``"entity"`` or ``"relationship"``; forwarded to the critic.
        item_uri: URI of the class/property being mapped.
        item_label: Human-readable label used in progress messages.
        item_definition: Ontology dict forwarded to the critic.
        step_prefix: Prefix of the step-log messages (``"entity"`` / ``"rel"``).
        generator_name: Name used in the generator-exception log line.
        build_slice: Zero-arg callable returning the source-model slice. It is
            called on every attempt, so a replan is picked up on restart.
        generate: ``(slice_dict, retry_hint) -> result`` exposing ``success``,
            ``mapping``, ``error``, ``iterations`` and ``usage``.
        evaluate: ``(mapping) -> EvalReport`` — the deterministic stage 1.

    Returns:
        ``(final_status, attempts_log, last_mapping, last_report)`` where
        ``final_status`` is ``"PASS"``, ``"FAIL_BUBBLE"`` or ``"FAIL_BUDGET"``.
    """
    attempts_log: List[dict] = []
    last_mapping: Optional[dict] = None
    last_report: Optional[EvalReport] = None

    # The outer loop lets a successful replan restart the per-item retry
    # budget (and the retry hint) from scratch.
    while True:
        retry_hint: Optional[str] = None
        bubble_requested = False

        for attempt_num in range(1, _PER_ITEM_GENERATOR_ATTEMPTS + 1):
            slice_dict = build_slice()
            state.notify(
                f"Mapping {item_label} (attempt {attempt_num}/"
                f"{_PER_ITEM_GENERATOR_ATTEMPTS})…",
                pct=state.pct(),
            )
            state.add_step(
                "generator",
                f"{step_prefix}-gen-start: {item_uri} attempt {attempt_num}",
            )

            # ---- Generator ------------------------------------------------
            t_g = time.time()
            try:
                gen_result = generate(slice_dict, retry_hint)
            except Exception as exc:  # noqa: BLE001
                logger.error(
                    f"{generator_name} raised on %s attempt %d: %s",
                    item_uri,
                    attempt_num,
                    exc,
                    exc_info=True,
                )
                attempts_log.append(
                    _attempt_record(
                        attempt_num,
                        int((time.time() - t_g) * 1000),
                        "skipped",
                        "skipped",
                        error=f"generator exception: {exc}",
                    )
                )
                continue
            gen_ms = int((time.time() - t_g) * 1000)
            state.add_iterations(gen_result.iterations)
            state.accumulate_usage(gen_result.usage)

            if not gen_result.success or gen_result.mapping is None:
                attempts_log.append(
                    _attempt_record(
                        attempt_num,
                        gen_ms,
                        "skipped",
                        "skipped",
                        error=gen_result.error or "generator failed",
                    )
                )
                state.add_step(
                    "generator",
                    f"{step_prefix}-gen-fail: {item_uri} attempt {attempt_num}: "
                    f"{gen_result.error}",
                    duration_ms=gen_ms,
                )
                # Feed the failure back to the next attempt; keep the previous
                # hint when the generator gave no error text.
                retry_hint = gen_result.error or retry_hint
                continue

            mapping = gen_result.mapping
            last_mapping = mapping

            # ---- Stage 1: deterministic evaluator -------------------------
            # NOTE(review): an exception raised by ``evaluate`` propagates to
            # the caller here, whereas the abstract-derive path catches it and
            # returns FAIL_BUDGET — confirm this asymmetry is intended.
            state.notify(f"Evaluating {item_label}…", pct=state.pct())
            t_e = time.time()
            stage1_report = evaluate(mapping)
            eval_ms = int((time.time() - t_e) * 1000)
            last_report = stage1_report
            state.add_step(
                "evaluator",
                f"{step_prefix}-stage1: {item_uri} status={stage1_report.status} "
                f"bubble={stage1_report.bubble_to_planner}",
                duration_ms=eval_ms,
            )

            if stage1_report.status == "FAIL":
                hint = _first_hint(stage1_report)
                bubble = bool(stage1_report.bubble_to_planner)
                attempts_log.append(
                    _attempt_record(
                        attempt_num,
                        gen_ms,
                        "FAIL",
                        "skipped",
                        bubble=bubble,
                        hint=hint,
                    )
                )
                if bubble:
                    bubble_requested = True
                    break
                retry_hint = hint or retry_hint
                continue

            # ---- Stage 2: semantic critic (optional) ----------------------
            if state.skip_semantic_critic:
                attempts_log.append(
                    _attempt_record(attempt_num, gen_ms, "PASS", "skipped")
                )
                return "PASS", attempts_log, mapping, stage1_report

            state.notify(f"Critiquing {item_label}…", pct=state.pct())
            t_c = time.time()
            try:
                critic_result = run_critic(
                    host=state.host,
                    token=state.token,
                    endpoint_name=state.endpoint_name,
                    client=state.client,
                    item_kind=item_kind,
                    item_uri=item_uri,
                    item_definition=item_definition,
                    submitted_mapping=mapping,
                    source_model_slice=slice_dict,
                    stage1_metrics=dict(stage1_report.metrics),
                )
            except Exception as exc:  # noqa: BLE001
                logger.error(
                    "Critic raised on %s attempt %d: %s",
                    item_uri,
                    attempt_num,
                    exc,
                    exc_info=True,
                )
                # A broken critic must not veto a mapping that already passed
                # the deterministic checks: accept on the stage-1 report.
                attempts_log.append(
                    _attempt_record(
                        attempt_num,
                        gen_ms,
                        "PASS",
                        "skipped",
                        error=f"critic exception: {exc}",
                    )
                )
                return "PASS", attempts_log, mapping, stage1_report
            critic_ms = int((time.time() - t_c) * 1000)
            state.add_iterations(critic_result.iterations)
            state.accumulate_usage(critic_result.usage)

            critic_report = critic_result.report
            state.add_step(
                "critic",
                f"{step_prefix}-critic: {item_uri} status="
                f"{critic_report.status if critic_report else '?'} "
                f"bubble="
                f"{critic_report.bubble_to_planner if critic_report else '?'}",
                duration_ms=critic_ms,
            )

            if not critic_result.success or critic_report is None:
                # Critic terminated without a verdict: same fallback as above.
                attempts_log.append(
                    _attempt_record(
                        attempt_num,
                        gen_ms,
                        "PASS",
                        "skipped",
                        error=critic_result.error or "critic failed",
                    )
                )
                return "PASS", attempts_log, mapping, stage1_report

            if critic_report.status == "PASS":
                attempts_log.append(
                    _attempt_record(attempt_num, gen_ms, "PASS", "PASS")
                )
                return "PASS", attempts_log, mapping, critic_report

            hint = _first_hint(critic_report)
            bubble = bool(critic_report.bubble_to_planner)
            attempts_log.append(
                _attempt_record(
                    attempt_num,
                    gen_ms,
                    "PASS",
                    "FAIL",
                    bubble=bubble,
                    hint=hint,
                )
            )
            last_report = critic_report
            if bubble:
                bubble_requested = True
                break
            retry_hint = hint or retry_hint

        if bubble_requested:
            if state.replan_once():
                continue  # restart the item with the new plan
            return "FAIL_BUBBLE", attempts_log, last_mapping, last_report
        return "FAIL_BUDGET", attempts_log, last_mapping, last_report


def _run_entity_item(
    state: _RunState,
    ontology_class: dict,
) -> Tuple[str, List[dict], Optional[dict], Optional[EvalReport]]:
    """Run the G->E loop for one entity.

    Returns ``(final_status, attempts_log, last_mapping, last_report)``.
    A successful replan restarts the retry budget fresh, which is the intent
    of the bubble-to-planner path.
    """
    class_uri = ontology_class.get("uri", "")
    class_label = ontology_class.get("label") or ontology_class.get(
        "name", class_uri
    )

    def _generate(slice_dict: dict, retry_hint: Optional[str]):
        # Only forward max_iterations when the run overrides it, so the
        # generator's own default applies otherwise.
        return run_entity_generator(
            host=state.host,
            token=state.token,
            endpoint_name=state.endpoint_name,
            client=state.client,
            ontology_class=ontology_class,
            source_model_slice=slice_dict,
            retry_hint=retry_hint,
            on_step=None,
            **(
                {"max_iterations": state.max_iterations}
                if state.max_iterations is not None
                else {}
            ),
        )

    def _evaluate(mapping: dict):
        return evaluate_entity_mapping(
            mapping=mapping,
            ontology_class=ontology_class,
            execute_sql_fn=state.execute_sql_fn,
        )

    return _run_item_loop(
        state,
        item_kind="entity",
        item_uri=class_uri,
        item_label=class_label,
        item_definition=ontology_class,
        step_prefix="entity",
        generator_name="EntityGenerator",
        build_slice=lambda: _slice_for_entity(state.source_model, class_uri),
        generate=_generate,
        evaluate=_evaluate,
    )


def _run_relationship_item(
    state: _RunState,
    ontology_property: dict,
    source_em: dict,
    target_em: dict,
) -> Tuple[str, List[dict], Optional[dict], Optional[EvalReport]]:
    """Run the G->E loop for one relationship.

    Returns ``(final_status, attempts_log, last_mapping, last_report)``.
    """
    property_uri = ontology_property.get("uri", "")
    property_label = ontology_property.get("label") or ontology_property.get(
        "name", property_uri
    )

    def _generate(slice_dict: dict, retry_hint: Optional[str]):
        # Only forward max_iterations when the run overrides it, so the
        # generator's own default applies otherwise.
        return run_relationship_generator(
            host=state.host,
            token=state.token,
            endpoint_name=state.endpoint_name,
            client=state.client,
            ontology_property=ontology_property,
            source_entity_mapping=source_em,
            target_entity_mapping=target_em,
            source_model_slice=slice_dict,
            retry_hint=retry_hint,
            on_step=None,
            **(
                {"max_iterations": state.max_iterations}
                if state.max_iterations is not None
                else {}
            ),
        )

    def _evaluate(mapping: dict):
        return evaluate_relationship_mapping(
            mapping=mapping,
            source_entity_mapping=source_em,
            target_entity_mapping=target_em,
            execute_sql_fn=state.execute_sql_fn,
            id_universe_cache=state.id_universe_cache,
        )

    return _run_item_loop(
        state,
        item_kind="relationship",
        item_uri=property_uri,
        item_label=property_label,
        item_definition=ontology_property,
        step_prefix="rel",
        generator_name="RelationshipGenerator",
        build_slice=lambda: _slice_for_relationship(
            state.source_model,
            property_uri,
            source_em,
            target_em,
        ),
        generate=_generate,
        evaluate=_evaluate,
    )
+ +The loop shape mirrors :mod:`agents.agent_mapping_pge.generators.entity` — +same ``call_serving_endpoint`` + ``dispatch_tool`` ReAct cycle, same +inter-iteration delay mechanism (``_ITERATION_DELAY_SEC``, 1 second here), same MLflow trace decorator. Differences: + +* Smaller default budget (6) — auditing is bounded work; if the Critic can't + conclude in 6 iterations, it defers (PASS with a reasoning note) rather + than falsely escalates. +* Different tool set: only ``sample_table``, ``execute_sql``, + ``get_documents_context``, and the terminal ``submit_evaluation``. +* No single-shot fallback — the Critic produces structured output through + ``submit_evaluation`` only. +""" + +import json +import time +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional + +import requests + +if TYPE_CHECKING: + from agents.agent_mapping_pge.contracts import EvalReport + +from back.core.logging import get_logger +from agents.engine_base import ( + accumulate_usage, + call_serving_endpoint, + dispatch_tool, +) +from agents.tools.context import ToolContext +from agents.tools.documents import ( + GET_DOCUMENTS_CONTEXT_DEF, + tool_get_documents_context, +) +from agents.tools.evaluation import ( + EVALUATION_TOOL_DEFINITIONS, + EVALUATION_TOOL_HANDLERS, +) +from agents.tools.planner import ( + SAMPLE_TABLE_DEF, + tool_sample_table, +) +from agents.tools.sql import ( + SQL_TOOL_DEFINITIONS, + SQL_TOOL_HANDLERS, +) +from agents.tracing import trace_agent + +logger = get_logger(__name__) + +MAX_ITERATIONS = 6 +LLM_TIMEOUT = 180 +_ITERATION_DELAY_SEC = 1 +# See planner._MAX_TOKENS comment — same rationale for submit_evaluation. +_MAX_TOKENS = 50000 + +_TRACE_NAME = "mapping_pge_critic" + + +# ===================================================== +# Tool aggregation +# ===================================================== +# +# The Critic only needs: +# * sample_table – peek at actual values to verify the column +# picked really represents the ontology concept. 
@dataclass
class CriticStep:
    """One observable step of the Critic's execution.

    Mirrors :class:`agents.agent_mapping_pge.generators.entity.EntityGenStep`
    so the orchestrator (Sprint 7) can render a per-audit timeline in the UI.
    """

    # Kind of step recorded by the critic loop.
    step_type: str  # "tool_call" | "tool_result" | "output"
    # Truncated payload: tool arguments, tool result text, or the model's text.
    content: str
    # Name of the tool involved; stays empty for "output" steps.
    tool_name: str = ""
    # Elapsed wall-clock milliseconds; "tool_call" steps leave this at 0.
    duration_ms: int = 0


@dataclass
class CriticResult:
    """Outcome of a single Critic invocation.

    ``report`` is populated when the agent terminated by submitting a verdict
    via ``submit_evaluation``. ``success`` here is the agent-level success
    flag — it indicates a *clean termination*, NOT a PASS verdict. A FAIL
    verdict with ``bubble_to_planner=True`` still has ``success=True``.
    ``success=False`` is reserved for budget exhaustion, text-only output,
    and LLM/transport errors.
    """

    # True only when a verdict was submitted and stored as ``report``.
    success: bool
    # The semantic-stage verdict; None whenever ``success`` is False.
    report: Optional["EvalReport"] = None
    # Chronological trace of tool calls, tool results and text output.
    steps: List[CriticStep] = field(default_factory=list)
    # 1-based index of the iteration at which the run stopped.
    iterations: int = 0
    # Human-readable failure reason; empty on success.
    error: str = ""
    # Accumulated token usage (prompt_tokens / completion_tokens).
    usage: Dict[str, int] = field(default_factory=dict)
+• FAIL with bubble_to_planner=false — the WRONG COLUMN was picked within \ +the RIGHT TABLE. The Generator can fix this on retry. Populate failures[] \ +with the specific column-level issue and a concrete hint. +• FAIL with bubble_to_planner=true — the WRONG TABLE was chosen entirely. \ +The Planner must revise the source model. Populate failures[] and set the \ +bubble flag. + +HINT DISCIPLINE +• Hints must be CONCRETE, ACTIONABLE, single-sentence corrections. +• Good column-level hint: "Sampled rows show `appointment_date` is the \ +booking date, not delivery date. Use `delivery_dttm` instead." +• Good table-level hint: "This mapping uses `antenatal_visits`, but the \ +chosen class is Delivery. Switch to the `labour_delivery` table." +• Bad hint (vague): "consider using a different column" +• Bad hint (chatty): "I think there might be an issue here, you should look \ +into it more carefully" + +HARD RULES +• You are bounded by max_iterations=6. Keep your audit FOCUSED — pick the \ +one or two probes that would change your verdict, not exhaustive ones. +• Call submit_evaluation EXACTLY ONCE. +• If you cannot determine a verdict within 6 iterations, submit PASS with a \ +reasoning note explaining the uncertainty. Do NOT bubble — better to defer \ +than to falsely escalate. +• Do not call get_metadata, get_ontology, column_value_overlap, \ +distinct_count, submit_source_model, submit_entity_mapping, or \ +submit_relationship_mapping — they are not available to you. The audit \ +target and structural metrics are already in the user message. 
def _format_entity_definition(item_definition: dict) -> List[str]:
    """Render the ontology-class part of the audit target as prompt lines.

    Emits the label (falling back to ``name``), the comment when non-empty,
    and one bullet per attribute; dict attributes show their name and, when
    known, their type in parentheses.
    """
    name = item_definition.get("label") or item_definition.get("name", "")
    note = item_definition.get("comment", "") or ""
    attrs = item_definition.get("attributes", []) or []

    lines: List[str] = [f" label: {name}"]
    if note:
        lines.append(f" comment: {note}")
    if not attrs:
        return lines

    lines.append(f" attributes ({len(attrs)}):")
    for entry in attrs:
        if not isinstance(entry, dict):
            # Plain (non-dict) attributes are printed verbatim.
            lines.append(f" - {entry}")
            continue
        attr_name = entry.get("name") or entry.get("label") or entry.get("uri", "?")
        attr_type = entry.get("type") or entry.get("range") or ""
        suffix = f" ({attr_type})" if attr_type else ""
        lines.append(f" - {attr_name}" + suffix)
    return lines


def _format_relationship_definition(item_definition: dict) -> List[str]:
    """Render the ontology-property part of the audit target as prompt lines.

    ``domain`` and ``range`` lines are always emitted (possibly empty) — they
    are what distinguishes a relationship audit from an entity audit.
    """
    name = item_definition.get("label") or item_definition.get("name", "")
    note = item_definition.get("comment", "") or ""
    source_class = item_definition.get("domain", "") or ""
    target_class = item_definition.get("range", "") or ""

    lines: List[str] = [f" label: {name}"]
    if note:
        lines.append(f" comment: {note}")
    lines += [f" domain: {source_class}", f" range: {target_class}"]
    return lines


def _format_submitted_entity_mapping(submitted_mapping: dict) -> List[str]:
    """Render the entity mapping under audit as prompt lines."""
    lines: List[str] = [
        "SUBMITTED MAPPING (entity)",
        f" sql_query: {submitted_mapping.get('sql_query', '')}",
        f" id_column: {submitted_mapping.get('id_column', '')}",
        f" label_column: {submitted_mapping.get('label_column', '')}",
    ]

    column_by_attribute = submitted_mapping.get("attribute_mappings", {}) or {}
    if column_by_attribute:
        lines.append(" attribute_mappings:")
        lines.extend(
            f" {attribute} -> {column}"
            for attribute, column in column_by_attribute.items()
        )

    skipped = submitted_mapping.get("unmapped_attributes", []) or []
    if skipped:
        lines.append(" unmapped_attributes:")
        for entry in skipped:
            if isinstance(entry, dict):
                lines.append(
                    f" - {entry.get('name', '?')}: {entry.get('reason', '')}"
                )
            else:
                lines.append(f" - {entry}")
    return lines


def _format_submitted_relationship_mapping(submitted_mapping: dict) -> List[str]:
    """Render the relationship mapping under audit as prompt lines.

    ``source_class`` / ``target_class`` fall back to the ``domain`` /
    ``range_class`` keys when the primary key is missing or empty.
    """
    source_class = submitted_mapping.get("source_class", "") or submitted_mapping.get(
        "domain", ""
    )
    target_class = submitted_mapping.get("target_class", "") or submitted_mapping.get(
        "range_class", ""
    )
    return [
        "SUBMITTED MAPPING (relationship)",
        f" sql_query: {submitted_mapping.get('sql_query', '')}",
        f" source_id_column: {submitted_mapping.get('source_id_column', '')}",
        f" target_id_column: {submitted_mapping.get('target_id_column', '')}",
        f" source_class: {source_class}",
        f" target_class: {target_class}",
    ]


def _build_user_prompt(
    item_kind: str,
    item_uri: str,
    item_definition: dict,
    submitted_mapping: dict,
    source_model_slice: dict,
    stage1_metrics: dict,
) -> str:
    """Assemble the user message that describes one audit to the Critic.

    Sections, in order: AUDIT TARGET (kind, URI, ontology metadata),
    SUBMITTED MAPPING, PLANNER'S PREDICTION (the source-model slice as JSON),
    STRUCTURAL CHECK METRICS (PASSED) (stage-1 metrics as JSON) and YOUR TASK.
    Any ``item_kind`` other than ``"relationship"`` is rendered as an entity.
    ``None`` inputs are treated as empty dicts.
    """
    if item_kind == "relationship":
        definition_lines = _format_relationship_definition(item_definition or {})
        mapping_lines = _format_submitted_relationship_mapping(
            submitted_mapping or {}
        )
    else:
        definition_lines = _format_entity_definition(item_definition or {})
        mapping_lines = _format_submitted_entity_mapping(submitted_mapping or {})

    sections: List[str] = [
        "AUDIT TARGET",
        f" kind: {item_kind}",
        f" uri: {item_uri}",
        *definition_lines,
        "",
        *mapping_lines,
        "",
        "PLANNER'S PREDICTION",
        # default=str keeps non-JSON-native values from aborting the prompt.
        json.dumps(source_model_slice or {}, indent=2, default=str),
        "",
        "STRUCTURAL CHECK METRICS (PASSED)",
        json.dumps(stage1_metrics or {}, indent=2, default=str),
        "",
        "YOUR TASK",
        "Audit the SEMANTIC correctness of the submitted mapping. Use "
        "sample_table / execute_sql / get_documents_context as needed, then "
        "call submit_evaluation EXACTLY ONCE with your verdict. Follow the "
        "PASS / FAIL(no bubble) / FAIL(bubble) rubric in the system prompt.",
    ]

    prompt = "\n".join(sections)
    logger.debug(
        "_build_user_prompt for %s=%s (%d chars):\n%s",
        item_kind,
        item_uri,
        len(prompt),
        prompt,
    )
    return prompt
+ max_iterations: Upper bound on tool-call iterations (default 6 — + smaller than the Generators because auditing is bounded work). + + Returns: + A :class:`CriticResult`. ``success`` is True iff the Critic + terminated by submitting a verdict; in that case ``report`` holds + the resulting :class:`EvalReport`. On failure (budget exhaustion, + text-only output, transport error), ``error`` explains why. + """ + iteration_limit = max_iterations if max_iterations is not None else MAX_ITERATIONS + + logger.info( + "===== CRITIC START ===== endpoint=%s, kind=%s, uri=%s, max_iter=%d", + endpoint_name, + item_kind, + item_uri, + iteration_limit, + ) + + ctx = ToolContext( + host=host.rstrip("/"), + token=token, + client=client, + # The audit target is in the user prompt; metadata/ontology are not + # needed by the Critic's tools. + metadata={}, + ontology={}, + documents=list(documents or []), + ) + + result = CriticResult(success=False) + + user_prompt = _build_user_prompt( + item_kind=item_kind, + item_uri=item_uri, + item_definition=item_definition or {}, + submitted_mapping=submitted_mapping or {}, + source_model_slice=source_model_slice or {}, + stage1_metrics=stage1_metrics or {}, + ) + messages: List[dict] = [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": user_prompt}, + ] + logger.info( + "Critic conversation initialized: system=%d chars, user=%d chars", + len(SYSTEM_PROMPT), + len(user_prompt), + ) + + total_usage: Dict[str, int] = {"prompt_tokens": 0, "completion_tokens": 0} + + def _progress_pct(iteration_idx: int) -> int: + ratio = (iteration_idx + 1) / max(iteration_limit, 1) + return min(5 + int(ratio * 90), 95) + + def notify(msg: str, *, pct: Optional[int] = None) -> None: + actual_pct = pct if pct is not None else 5 + logger.info("CRITIC STEP [%d%%] %s", actual_pct, msg) + if on_step: + on_step(msg, actual_pct) + + notify(f"Auditing {item_kind} {item_uri}…", pct=1) + + # 
------------------------------------------------------------------ + # Agent loop + # ------------------------------------------------------------------ + for iteration in range(iteration_limit): + if iteration > 0: + logger.debug( + "Iteration %d: waiting %ds before LLM call (rate limit mitigation)", + iteration + 1, + _ITERATION_DELAY_SEC, + ) + time.sleep(_ITERATION_DELAY_SEC) + + current_iteration = iteration + 1 + pct = _progress_pct(iteration) + logger.info( + "----- Critic iteration %d/%d — %d messages, report=%s -----", + current_iteration, + iteration_limit, + len(messages), + "set" if ctx.semantic_eval_report is not None else "unset", + ) + notify( + f"Critic iteration {current_iteration}/{iteration_limit}…", + pct=pct, + ) + + t0 = time.time() + try: + llm_response = call_serving_endpoint( + host, + token, + endpoint_name, + messages, + tools=TOOL_DEFINITIONS, + max_tokens=_MAX_TOKENS, + temperature=0.1, + timeout=LLM_TIMEOUT, + trace_name=_TRACE_NAME, + ) + except requests.exceptions.HTTPError as exc: + status = exc.response.status_code if exc.response is not None else "?" 
+ logger.warning( + "Critic iteration %d: HTTPError status=%s", + current_iteration, + status, + ) + logger.debug( + "Critic iteration %d: HTTPError body: %.500s", + current_iteration, + exc.response.text if exc.response is not None else "N/A", + ) + if exc.response is not None and status in (400, 422): + result.error = "LLM endpoint does not support function calling" + result.iterations = current_iteration + result.usage = total_usage + logger.error( + "Critic: endpoint refused tools — cannot produce an evaluation" + ) + return result + result.error = f"LLM request failed: {exc}" + result.iterations = current_iteration + result.usage = total_usage + logger.error( + "Critic: LLM request failed at iteration %d: %s", + current_iteration, + exc, + ) + return result + except requests.exceptions.ReadTimeout: + result.error = f"LLM request timed out after {LLM_TIMEOUT}s" + result.iterations = current_iteration + result.usage = total_usage + logger.error("Critic: timeout at iteration %d", current_iteration) + return result + except requests.exceptions.RequestException as exc: + result.error = f"LLM request failed: {exc}" + result.iterations = current_iteration + result.usage = total_usage + logger.error( + "Critic: request exception at iteration %d: %s", + current_iteration, + exc, + ) + return result + + elapsed_ms = int((time.time() - t0) * 1000) + logger.info( + "Critic iteration %d: LLM responded in %dms", + current_iteration, + elapsed_ms, + ) + + accumulate_usage(total_usage, llm_response.get("usage", {})) + + choice = llm_response.get("choices", [{}])[0] + finish_reason = choice.get("finish_reason", "?") + message = choice.get("message", {}) + tool_calls = message.get("tool_calls", []) + has_content = bool(message.get("content")) + logger.info( + "Critic iteration %d: finish_reason=%s, tool_calls=%d, has_content=%s", + current_iteration, + finish_reason, + len(tool_calls), + has_content, + ) + + if not tool_calls: + # The Critic must terminate via 
submit_evaluation, never via + # free text. Text-only output is a failure. + content = (message.get("content") or "")[:500] + logger.warning( + "Critic iteration %d: produced text without submitting evaluation — %d chars", + current_iteration, + len(message.get("content") or ""), + ) + result.steps.append( + CriticStep( + step_type="output", + content=content, + duration_ms=elapsed_ms, + ) + ) + result.error = "critic produced text without submitting evaluation" + result.iterations = current_iteration + result.usage = total_usage + notify( + "Critic produced text without submitting evaluation.", + pct=pct, + ) + return result + + logger.info( + "Critic iteration %d: processing %d tool call(s): [%s]", + current_iteration, + len(tool_calls), + ", ".join( + tc.get("function", {}).get("name", "?") for tc in tool_calls + ), + ) + messages.append(message) + + terminal_success = False + for tc_idx, tc in enumerate(tool_calls, 1): + func = tc.get("function", {}) + tool_name = func.get("name", "") + raw_args = func.get("arguments", "{}") + tool_id = tc.get("id", "") + + try: + arguments = json.loads(raw_args) + except json.JSONDecodeError: + arguments = {} + + logger.info( + "Critic iteration %d: calling tool '%s' (%d/%d)", + current_iteration, + tool_name, + tc_idx, + len(tool_calls), + ) + + if tool_name == "submit_evaluation": + notify( + f"Submitting evaluation for {item_uri}…", pct=pct + ) + elif tool_name == "sample_table": + fn = arguments.get("full_name", "?") + notify(f"Sampling {fn}…", pct=pct) + elif tool_name == "execute_sql": + sql_preview = arguments.get("sql", "")[:80] + notify(f"Running SQL: {sql_preview}…", pct=pct) + elif tool_name == "get_documents_context": + notify("Retrieving documents…", pct=pct) + else: + notify(f"Calling {tool_name}…", pct=pct) + + result.steps.append( + CriticStep( + step_type="tool_call", + content=json.dumps(arguments, default=str)[:500], + tool_name=tool_name, + ) + ) + + t1 = time.time() + tool_result = dispatch_tool( + 
TOOL_HANDLERS, ctx, tool_name, arguments, trace_name=_TRACE_NAME + ) + tool_ms = int((time.time() - t1) * 1000) + + logger.info( + "Critic iteration %d: tool '%s' returned %d chars in %dms", + current_iteration, + tool_name, + len(tool_result), + tool_ms, + ) + + result.steps.append( + CriticStep( + step_type="tool_result", + content=( + (tool_result[:500] + "…") + if len(tool_result) > 500 + else tool_result + ), + tool_name=tool_name, + duration_ms=tool_ms, + ) + ) + + messages.append( + { + "role": "tool", + "tool_call_id": tool_id, + "content": tool_result, + } + ) + + # Detect terminal success: submit_evaluation returned success=True + # AND stamped an EvalReport onto the context. An invalid status + # (the handler returns success=False) does NOT terminate the + # loop — the agent continues so it can resubmit a valid verdict. + if tool_name == "submit_evaluation": + try: + parsed = json.loads(tool_result) + except json.JSONDecodeError: + parsed = {} + if ( + parsed.get("success") is True + and ctx.semantic_eval_report is not None + ): + terminal_success = True + logger.info( + "Critic iteration %d: submit_evaluation succeeded — terminating", + current_iteration, + ) + + if terminal_success: + result.success = True + result.report = ctx.semantic_eval_report + result.iterations = current_iteration + result.usage = total_usage + logger.info( + "===== CRITIC COMPLETE ===== uri=%s, status=%s, bubble=%s, " + "iterations=%d, prompt_tokens=%d, completion_tokens=%d", + item_uri, + result.report.status if result.report else "?", + result.report.bubble_to_planner if result.report else "?", + result.iterations, + total_usage["prompt_tokens"], + total_usage["completion_tokens"], + ) + notify(f"Critic verdict submitted for {item_uri}.", pct=100) + return result + + # Budget exhausted without a successful submit. 
+ result.iterations = iteration_limit + result.usage = total_usage + result.error = "critic exhausted iteration budget" + logger.error("===== CRITIC FAILED ===== %s", result.error) + notify(result.error, pct=95) + return result diff --git a/src/agents/agent_mapping_pge/evaluator/deterministic.py b/src/agents/agent_mapping_pge/evaluator/deterministic.py new file mode 100644 index 00000000..b8e72951 --- /dev/null +++ b/src/agents/agent_mapping_pge/evaluator/deterministic.py @@ -0,0 +1,539 @@ +"""Deterministic (stage-1) evaluator for submitted mappings. + +This module is pure-Python and has no LLM dependency. It runs the +submitted mapping's SQL through a caller-supplied ``execute_sql_fn`` and +checks structural invariants (row count, distinct id count, dangling +foreign-key fractions, etc.). + +``execute_sql_fn`` contract:: + + def execute_sql_fn(sql: str) -> dict +returning ``{"columns": [...], "rows": [{col: value, ...}, ...]}``. + +Important: this is the *full* result set, not the 3-row sample emitted by +:func:`agents.tools.sql.tool_execute_sql`. The orchestrator (Sprint 7) is +responsible for plugging in a runner that returns full rows — typically a +thin wrapper around ``DatabricksClient.execute_query``. + +All checks compute every metric even when some fail; the resulting +:class:`~agents.agent_mapping_pge.contracts.EvalReport` lists every failure +so the Generator/Planner can address them in one shot. +""" + +from typing import Any, Callable, Dict, List, Optional, Tuple + +from back.core.logging import get_logger +from agents.agent_mapping_pge.contracts import EvalFailure, EvalReport +from agents.agent_mapping_pge.evaluator.report import build_report + +logger = get_logger(__name__) + +# Thresholds for stage-1 checks. These are intentionally lax — the +# semantic evaluator (stage 2) catches subtler issues. 
# A dangling fraction at or above this value fails the relationship check.
_DANGLING_FK_FAIL_THRESHOLD = 0.05
# A dangling fraction strictly above this value can also set bubble_to_planner.
_DANGLING_FK_BUBBLE_THRESHOLD = 0.5


SqlFn = Callable[[str], dict]


# =====================================================
# Helpers
# =====================================================


def _resolve_id_col(mapping: dict, fallback: str = "ID") -> str:
    """Return the column name that holds the entity identifier in the row dicts."""
    return mapping.get("id_column") or fallback


def _extract_id_values(rows: List[dict], id_col: str) -> List[Any]:
    """Pull the id_col value from each row; missing key -> ``None``."""
    return [r.get(id_col) for r in rows]


def _attribute_names(ontology_class: dict) -> List[str]:
    """Ontology attributes can come in a few shapes; normalise to a list of names.

    String items are taken as-is. Dict items contribute the first truthy value
    among ``name``, ``uri`` and ``label``. Anything else is ignored.
    """
    attrs = ontology_class.get("attributes") or []
    out: List[str] = []
    for a in attrs:
        if isinstance(a, str):
            out.append(a)
        elif isinstance(a, dict):
            name = a.get("name") or a.get("uri") or a.get("label")
            if name:
                out.append(name)
    return out


def _fail(
    *,
    check: str,
    expected: str,
    observed: str,
    hint: str,
    kind: str = "structural",
) -> EvalFailure:
    """Build an :class:`EvalFailure`; ``kind`` defaults to ``"structural"``."""
    return EvalFailure(
        kind=kind, check=check, expected=expected, observed=observed, hint=hint
    )


class _SqlExecError(Exception):
    """A generated mapping's SQL parsed but failed at execution time.

    Wraps the underlying DB driver exception so the deterministic evaluator
    can convert it into an actionable FAIL — never let it crash the run.
    """


def _exec(execute_sql_fn: SqlFn, sql: str) -> dict:
    """Run SQL, normalising any driver-level failure into ``_SqlExecError``.

    Generated mappings routinely produce SQL that *parses* but fails at
    execution (UNION column-type mismatch, invalid CAST, unknown column).
    The PGE contract is that such errors become feedback for the generator,
    so they must surface as a FAIL report rather than an unhandled exception
    that aborts the whole agent run.

    A falsy result from ``execute_sql_fn`` (e.g. ``None``) is returned as ``{}``.
    """
    try:
        return execute_sql_fn(sql) or {}
    except Exception as exc:  # noqa: BLE001 — any driver error becomes feedback
        raise _SqlExecError(str(exc)) from exc


def _first_error_line(sql_error: str, limit: int = 300) -> str:
    """Return the first line of ``sql_error`` truncated to ``limit`` chars.

    Empty, ``None`` and whitespace-only input all yield ``"unknown error"``.
    Whitespace-only text strips to ``""``, and ``"".splitlines()`` is an
    empty list, so the first line must not be indexed unconditionally.
    """
    lines = sql_error.strip().splitlines() if sql_error else []
    return lines[0][:limit] if lines else "unknown error"


def _sql_error_report(*, item: str, sql_error: str) -> EvalReport:
    """Build a FAIL report for a mapping whose SQL failed to execute.

    ``bubble_to_planner`` stays False: a runtime SQL error is the
    Generator's to fix (align types, correct columns), not a signal that the
    Planner's source model is wrong.

    Args:
        item: Name of the class or relationship, used in the hint text.
        sql_error: Text of the execution error; may be empty.
    """
    # Keep the hint compact — driver errors can be very long.
    err = _first_error_line(sql_error)
    return build_report(
        stage="deterministic",
        metrics={"sql_error": err},
        failures=[
            _fail(
                check="sql_execution",
                expected="SQL executes without error",
                observed="execution error",
                hint=(
                    f"The mapping SQL for '{item}' failed to execute: {err}. "
                    "Fix the SQL — e.g. align UNION branch column types with "
                    "explicit CAST (a common cause is one branch typing a "
                    "column as BIGINT and another as STRING/NULL), correct "
                    "column names, or use try_cast for malformed values."
                ),
            )
        ],
        bubble_to_planner=False,
    )


def _id_column_stats(rows: List[dict], id_col: str) -> Tuple[int, int, int, int]:
    """Summarise the id column of ``rows``.

    Returns:
        ``(row_count, null_id_count, distinct_id_count, duplicate_id_count)``.
        ``distinct_id_count`` counts non-null values only, so duplicates are
        measured against the non-null rows. A NULL id is reported once, as a
        null, and never again as a duplicate.
    """
    id_values = _extract_id_values(rows, id_col)
    null_id_count = sum(1 for v in id_values if v is None)
    distinct_id_count = len({v for v in id_values if v is not None})
    duplicate_id_count = (len(id_values) - null_id_count) - distinct_id_count
    return len(id_values), null_id_count, distinct_id_count, duplicate_id_count


# =====================================================
# Entity evaluator
# =====================================================


def evaluate_entity_mapping(
    *,
    mapping: dict,
    ontology_class: dict,
    execute_sql_fn: SqlFn,
) -> EvalReport:
    """Run the stage-1 deterministic checks on a submitted entity mapping.

    Checks (all are evaluated, every failure is listed):

    * ``row_count > 0``
    * no duplicate ids among the rows whose id is not NULL
    * ``null_id_count == 0``
    * every ontology attribute is either in ``attribute_mappings`` or declared
      in ``unmapped_attributes``

    Args:
        mapping: Submitted entity mapping in the shape produced by
            ``tool_submit_entity_mapping``.
        ontology_class: The ontology-class dict the mapping targets; must
            expose an ``attributes`` list (each item being a name string or
            a dict with a ``name`` key).
        execute_sql_fn: Caller-supplied SQL runner — see module docstring.

    Returns:
        An :class:`EvalReport` summarising the metrics and any failures.
        ``bubble_to_planner`` is set when ``row_count == 0`` (typically
        means the mapping is querying the wrong table altogether). If the
        SQL itself fails to execute, a ``sql_execution`` FAIL report is
        returned instead.
    """
    class_name = mapping.get("class_name") or ontology_class.get("name") or "?"
    # ``or ""`` also covers an explicit None so len() in the log call cannot raise.
    sql = mapping.get("sql_query") or ""
    id_col = _resolve_id_col(mapping)
    logger.info(
        "evaluate_entity_mapping: class=%s, id_col=%s, sql_len=%d",
        class_name,
        id_col,
        len(sql),
    )

    try:
        result = _exec(execute_sql_fn, sql)
    except _SqlExecError as exc:
        logger.warning(
            "evaluate_entity_mapping: class=%s SQL failed to execute: %s",
            class_name,
            exc,
        )
        return _sql_error_report(item=class_name, sql_error=str(exc))
    rows = result.get("rows", []) or []
    row_count, null_id_count, distinct_id_count, duplicate_id_count = (
        _id_column_stats(rows, id_col)
    )

    # Declared-unmapped entries may be bare names or {"name": ..., "reason": ...}.
    raw_unmapped = mapping.get("unmapped_attributes") or []
    declared_unmapped: set = set()
    for item in raw_unmapped:
        if isinstance(item, dict):
            name = item.get("name")
            if name:
                declared_unmapped.add(str(name))
        elif item is not None:
            declared_unmapped.add(str(item))
    declared_mapped = set((mapping.get("attribute_mappings") or {}).keys())
    all_attrs = _attribute_names(ontology_class)
    unmapped_attrs = [
        a for a in all_attrs if a not in declared_mapped and a not in declared_unmapped
    ]
    unmapped_pct = (len(unmapped_attrs) / len(all_attrs)) if all_attrs else 0.0

    metrics: Dict[str, Any] = {
        "row_count": row_count,
        "distinct_id_count": distinct_id_count,
        "null_id_count": null_id_count,
        "unmapped_attribute_pct": unmapped_pct,
        "unmapped_attributes": unmapped_attrs,
    }

    failures: List[EvalFailure] = []
    bubble = False

    if row_count == 0:
        failures.append(
            _fail(
                check="row_count",
                expected="> 0",
                observed="0",
                hint=(
                    f"Entity '{class_name}' SQL returned 0 rows. Check the FROM "
                    "table is correct and the WHERE clause is not over-filtering."
                ),
            )
        )
        bubble = True

    if duplicate_id_count > 0:
        # NULL ids are excluded from the baseline; they get their own failure
        # below and must not be double-counted as duplicates.
        non_null_id_count = row_count - null_id_count
        baseline = "row_count" if null_id_count == 0 else "non-null id count"
        failures.append(
            _fail(
                check="distinct_id_count",
                expected=f"== {baseline} ({non_null_id_count})",
                observed=str(distinct_id_count),
                hint=(
                    f"{duplicate_id_count} duplicate '{id_col}' value(s) in entity '{class_name}'. "
                    "Add DISTINCT or use a stricter id column."
                ),
            )
        )

    if null_id_count > 0:
        failures.append(
            _fail(
                check="null_id_count",
                expected="== 0",
                observed=str(null_id_count),
                hint=(
                    f"{null_id_count} row(s) have NULL '{id_col}' in entity "
                    f"'{class_name}'. Add 'WHERE {id_col} IS NOT NULL' to the SQL."
                ),
            )
        )

    if unmapped_pct > 0:
        failures.append(
            _fail(
                check="unmapped_attribute_pct",
                expected="== 0",
                observed=f"{unmapped_pct:.3f}",
                hint=(
                    f"{len(unmapped_attrs)} attribute(s) of '{class_name}' are "
                    f"neither in attribute_mappings nor declared in "
                    f"unmapped_attributes: {unmapped_attrs}. Map them, or list "
                    "them explicitly under 'unmapped_attributes'."
                ),
            )
        )

    logger.info(
        "evaluate_entity_mapping: class=%s -> %s (%d failure(s), bubble=%s)",
        class_name,
        "PASS" if not failures else "FAIL",
        len(failures),
        bubble,
    )
    return build_report(
        stage="deterministic",
        metrics=metrics,
        failures=failures,
        bubble_to_planner=bubble,
    )


# =====================================================
# Relationship evaluator
# =====================================================


def _distinct_id_set(
    entity_mapping: dict,
    execute_sql_fn: SqlFn,
    id_universe_cache: Optional[Dict[str, set]] = None,
) -> set:
    """Materialise the set of valid ids for a given entity mapping.

    When ``id_universe_cache`` is provided it is consulted/populated keyed
    by the entity mapping's SQL string, avoiding redundant SQL execution
    across repeated calls that share endpoint entities.

    Raises:
        _SqlExecError: If running the entity SQL fails.
    """
    sql = entity_mapping.get("sql_query", "")
    id_col = _resolve_id_col(entity_mapping)
    if id_universe_cache is not None and sql in id_universe_cache:
        return id_universe_cache[sql]
    result = _exec(execute_sql_fn, sql)  # may raise _SqlExecError
    rows = result.get("rows", []) or []
    ids = {r.get(id_col) for r in rows if r.get(id_col) is not None}
    if id_universe_cache is not None:
        id_universe_cache[sql] = ids
    return ids


def _resolve_edge_columns(mapping: dict) -> Tuple[str, str]:
    """Return ``(source_col, target_col)`` for a relationship mapping."""
    return (
        mapping.get("source_id_column") or "source_id",
        mapping.get("target_id_column") or "target_id",
    )


def evaluate_relationship_mapping(
    *,
    mapping: dict,
    source_entity_mapping: dict,
    target_entity_mapping: dict,
    execute_sql_fn: SqlFn,
    expected_cross_source_overlap_band: Optional[Tuple[float, float]] = None,
    id_universe_cache: Optional[Dict[str, set]] = None,
) -> EvalReport:
    """Run stage-1 deterministic checks on a relationship mapping.

    Checks:

    * ``total_edges > 0``
    * ``dangling_source_pct < 0.05`` — fraction of source ids that do not
      exist in the source entity's id universe.
    * ``dangling_target_pct < 0.05`` — same for targets (skipped when
      ``expected_cross_source_overlap_band`` is supplied; the band check
      replaces it).
    * If ``expected_cross_source_overlap_band`` is supplied, the realised
      ``overlap_pct`` (fraction of edges whose target id appears in the
      target entity universe) must fall inside the band.

    ``bubble_to_planner`` is set when ``total_edges == 0``, when the source
    dangling fraction exceeds ``0.5``, or when the target dangling fraction
    exceeds ``0.5`` *and* the realised overlap is materially worse than the
    Planner predicted (either no band was supplied, or the band check
    itself failed). These cases typically indicate the relationship was
    built off the wrong join key.

    Args:
        mapping: Submitted relationship mapping; its ``sql_query`` yields the
            edge rows.
        source_entity_mapping: Entity mapping whose ids form the valid
            source universe.
        target_entity_mapping: Entity mapping whose ids form the valid
            target universe.
        execute_sql_fn: Caller-supplied SQL runner — see module docstring.
        expected_cross_source_overlap_band: Optional ``(lo, hi)`` inclusive
            band for the realised overlap fraction.
        id_universe_cache: Optional caller-managed dict mapping an entity
            mapping's ``sql_query`` string to its materialised set of ids.
            When provided, repeated calls across relationships that share
            endpoint entities reuse cached id universes instead of
            re-running the entity SQL via ``execute_sql_fn``. When
            ``None`` (default) behaviour is unchanged — fetch fresh each
            call. No module-level state is involved.

    Returns:
        An :class:`EvalReport`. If any of the three SQL statements fails to
        execute, a ``sql_execution`` FAIL report is returned instead.
    """
    name = mapping.get("property_name") or mapping.get("property") or "?"
    sql = mapping.get("sql_query", "")
    src_col, tgt_col = _resolve_edge_columns(mapping)
    logger.info(
        "evaluate_relationship_mapping: property=%s, src_col=%s, tgt_col=%s",
        name,
        src_col,
        tgt_col,
    )

    try:
        edges_result = _exec(execute_sql_fn, sql)
        edge_rows = edges_result.get("rows", []) or []
        total_edges = len(edge_rows)

        source_universe = _distinct_id_set(
            source_entity_mapping, execute_sql_fn, id_universe_cache
        )
        target_universe = _distinct_id_set(
            target_entity_mapping, execute_sql_fn, id_universe_cache
        )
    except _SqlExecError as exc:
        logger.warning(
            "evaluate_relationship_mapping: property=%s SQL failed to execute: %s",
            name,
            exc,
        )
        return _sql_error_report(item=name, sql_error=str(exc))

    src_values = [r.get(src_col) for r in edge_rows]
    tgt_values = [r.get(tgt_col) for r in edge_rows]

    if total_edges > 0:
        # A NULL endpoint counts as dangling, same as an id missing from the universe.
        dangling_src = sum(
            1 for v in src_values if v is None or v not in source_universe
        )
        dangling_tgt = sum(
            1 for v in tgt_values if v is None or v not in target_universe
        )
        dangling_src_pct = dangling_src / total_edges
        dangling_tgt_pct = dangling_tgt / total_edges
        overlap_pct = 1.0 - dangling_tgt_pct
    else:
        dangling_src_pct = 0.0
        dangling_tgt_pct = 0.0
        overlap_pct = 0.0

    metrics: Dict[str, Any] = {
        "total_edges": total_edges,
        "dangling_source_pct": dangling_src_pct,
        "dangling_target_pct": dangling_tgt_pct,
        "cross_source_overlap_pct": overlap_pct,
        "source_universe_size": len(source_universe),
        "target_universe_size": len(target_universe),
    }

    failures: List[EvalFailure] = []
    bubble = False

    if total_edges == 0:
        failures.append(
            _fail(
                check="total_edges",
                expected="> 0",
                observed="0",
                hint=(
                    f"Relationship '{name}' produced 0 edges. Confirm the join "
                    "predicate is on the right columns and rows are not being "
                    "filtered away."
                ),
            )
        )
        bubble = True

    if total_edges > 0 and dangling_src_pct >= _DANGLING_FK_FAIL_THRESHOLD:
        failures.append(
            _fail(
                check="dangling_source_pct",
                expected=f"< {_DANGLING_FK_FAIL_THRESHOLD}",
                observed=f"{dangling_src_pct:.3f}",
                hint=(
                    f"{dangling_src_pct:.1%} of source_id values in relationship "
                    f"'{name}' are absent from the mapped source entity. The "
                    "source entity's id_column is usually an ALIAS for a derived "
                    "expression (e.g. CONCAT(regexp_extract(,'...'),'-x')). "
                    "Reproduce that exact id expression from the source entity's "
                    "SQL for source_id — do not select a raw/trust-local column."
                ),
            )
        )
        if dangling_src_pct > _DANGLING_FK_BUBBLE_THRESHOLD:
            bubble = True

    # When an explicit cross-source overlap band is provided the relationship
    # is *expected* to be partial (e.g. trust_a-only IDs vs the cross-trust
    # canonical universe). In that case we trust the band check and skip
    # the standard ``dangling_target_pct`` strictness — the partiality is
    # the point. The catastrophic-dangling bubble below still fires, but
    # only when the band itself ALSO fails (i.e. the realised overlap is
    # materially worse than the Planner predicted).
    if (
        total_edges > 0
        and dangling_tgt_pct >= _DANGLING_FK_FAIL_THRESHOLD
        and expected_cross_source_overlap_band is None
    ):
        failures.append(
            _fail(
                check="dangling_target_pct",
                expected=f"< {_DANGLING_FK_FAIL_THRESHOLD}",
                observed=f"{dangling_tgt_pct:.3f}",
                hint=(
                    f"{dangling_tgt_pct:.1%} of target_id values in relationship "
                    f"'{name}' are absent from the mapped target entity. The "
                    "target entity's id_column is usually an ALIAS for a derived "
                    "expression; reproduce that exact id expression from the "
                    "target entity's SQL for target_id — not a raw join column."
                ),
            )
        )

    band_failed = False
    if expected_cross_source_overlap_band is not None:
        lo, hi = expected_cross_source_overlap_band
        if not (lo <= overlap_pct <= hi):
            band_failed = True
            failures.append(
                _fail(
                    check="cross_source_overlap_pct",
                    expected=f"in [{lo:.3f}, {hi:.3f}]",
                    observed=f"{overlap_pct:.3f}",
                    hint=(
                        f"Cross-source overlap for '{name}' is {overlap_pct:.1%}, "
                        f"outside the expected band [{lo:.1%}, {hi:.1%}]. "
                        "Check the join key and the source/target trust assignments."
                    ),
                )
            )

    # Bubble-to-planner on catastrophic target-dangling, with a band-aware gate.
    #
    # * Band absent + dangling > 0.5: the strict dangling_target_pct failure
    #   above already fired; we just flip the bubble flag (no new row needed).
    # * Band present + band PASSED: the Planner predicted this overlap and
    #   was right — do NOT bubble, even if dangling > 0.5 (the partiality
    #   was expected).
    # * Band present + band FAILED + dangling > 0.5: the realised overlap
    #   is materially worse than predicted. Bubble, and emit a dedicated
    #   ``dangling_target_pct_catastrophic`` failure so the FAIL report has
    #   a concrete structural row alongside the band-check failure.
    if total_edges > 0 and dangling_tgt_pct > _DANGLING_FK_BUBBLE_THRESHOLD:
        if expected_cross_source_overlap_band is None:
            bubble = True
        elif band_failed:
            bubble = True
            failures.append(
                _fail(
                    check="dangling_target_pct_catastrophic",
                    expected=f"<= {_DANGLING_FK_BUBBLE_THRESHOLD}",
                    observed=f"{dangling_tgt_pct:.3f}",
                    hint=(
                        f"{dangling_tgt_pct:.1%} of target_id values in "
                        f"relationship '{name}' are absent from the mapped "
                        "target entity AND the realised overlap is outside "
                        "the predicted band. Re-plan the join key and the "
                        "source/target trust assignments."
                    ),
                )
            )

    logger.info(
        "evaluate_relationship_mapping: %s -> %s (%d failure(s), bubble=%s)",
        name,
        "PASS" if not failures else "FAIL",
        len(failures),
        bubble,
    )
    return build_report(
        stage="deterministic",
        metrics=metrics,
        failures=failures,
        bubble_to_planner=bubble,
    )
+ +Each Generator is a narrow tool-calling agent that maps ONE ontology item +(class or relationship) at a time. The orchestrator (Sprint 7) calls them +per-item with a filtered slice of the Planner's :class:`SourceModel` — the +Generators never see the full ontology or full metadata, keeping each +decision cheap and local. + +* Sprint 4 — :mod:`agents.agent_mapping_pge.generators.entity`. +* Sprint 5 — :mod:`agents.agent_mapping_pge.generators.relationship`. +""" + +from agents.agent_mapping_pge.generators.entity import ( + EntityGenResult, + EntityGenStep, + run_entity_generator, +) +from agents.agent_mapping_pge.generators.relationship import ( + RelationshipGenResult, + RelationshipGenStep, + run_relationship_generator, +) + +__all__ = [ + "EntityGenResult", + "EntityGenStep", + "run_entity_generator", + "RelationshipGenResult", + "RelationshipGenStep", + "run_relationship_generator", +] diff --git a/src/agents/agent_mapping_pge/generators/entity.py b/src/agents/agent_mapping_pge/generators/entity.py new file mode 100644 index 00000000..973e9599 --- /dev/null +++ b/src/agents/agent_mapping_pge/generators/entity.py @@ -0,0 +1,812 @@ +""" +OntoBricks Mapping-PGE EntityGenerator Agent. + +Sprint 4 of the Planner-Generator-Evaluator (PGE) redesign. + +The EntityGenerator is a narrow, focused LLM agent that maps **one** ontology +class at a time. The orchestrator (Sprint 7) calls it per item with a +filtered slice of the Planner's :class:`SourceModel`: + +* the single ontology class to map, with its full attribute list, and +* a small SourceModel slice — only the candidate tables / canonical IDs / + joins that are relevant to *this* class. + +The Generator does NOT see the full ontology or full metadata. That is the +core design contract: keep its context bounded and each decision cheap. 
+ +The loop shape mirrors :mod:`agents.agent_mapping_pge.planner` — same +``call_serving_endpoint`` + ``dispatch_tool`` ReAct cycle, a fixed +inter-iteration delay (``_ITERATION_DELAY_SEC``, 1 second here), same MLflow trace decorator — with these differences: + +* Smaller default budget (12 vs 25): mapping one class is bounded work. +* Different tool set: only ``execute_sql``, ``sample_table``, and the + terminal ``submit_entity_mapping``. The slice already carries every piece + of context the Generator needs. +* No single-shot fallback: if the endpoint refuses tools, the Generator + reports failure — it produces structured output through + ``submit_entity_mapping`` only. +* The "NO SILENT DROPS" invariant: every ontology attribute must be either + in ``attribute_mappings`` or in ``unmapped_attributes`` with a one-sentence + reason. The system prompt enforces this; the tool persists it. +""" + +import json +import time +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional + +import requests + +from back.core.logging import get_logger +from agents.engine_base import ( + call_serving_endpoint, + dispatch_tool, + accumulate_usage, +) +from agents.tools.context import ToolContext +from agents.tools.mapping import ( + MAPPING_TOOL_DEFINITIONS_BY_NAME, + MAPPING_TOOL_HANDLERS, +) +from agents.tools.planner import ( + SAMPLE_TABLE_DEF, + tool_sample_table, +) +from agents.tools.sql import ( + SQL_TOOL_DEFINITIONS, + SQL_TOOL_HANDLERS, +) +from agents.tracing import trace_agent + +logger = get_logger(__name__) + +MAX_ITERATIONS = 12 +LLM_TIMEOUT = 180 +_ITERATION_DELAY_SEC = 1 +# See planner._MAX_TOKENS comment — same rationale for the Generator's +# submit_entity_mapping JSON (SQL + attribute_mappings can be large).
+_MAX_TOKENS = 50000 + +_TRACE_NAME = "mapping_pge_entity_generator" + + +# ===================================================== +# Tool aggregation +# ===================================================== +# +# The EntityGenerator only needs: +# * execute_sql – validate the composed SELECT before submitting. +# * sample_table – disambiguate when two candidate tables are equally +# plausible (e.g. same confidence in the slice). +# * submit_entity_mapping – TERMINAL. +# +# We deliberately exclude: +# * get_ontology / get_metadata / get_documents_context — the Planner's +# view; the slice already has what's needed. +# * column_value_overlap / distinct_count — those validate join keys and +# canonical IDs, which the Planner already locked in. +# * submit_relationship_mapping / submit_source_model — wrong stage. + +# Filter MAPPING_TOOL_DEFINITIONS down to just submit_entity_mapping. We +# look up by name from the by-name index in ``mapping.py`` rather than +# scanning the list inline. Sprint 5 will reuse the same pattern for +# ``submit_relationship_mapping``. +_SUBMIT_ENTITY_DEF: dict = MAPPING_TOOL_DEFINITIONS_BY_NAME["submit_entity_mapping"] + +TOOL_DEFINITIONS: List[dict] = ( + SQL_TOOL_DEFINITIONS + + [SAMPLE_TABLE_DEF] + + [_SUBMIT_ENTITY_DEF] +) + +TOOL_HANDLERS: Dict[str, Callable] = { + **SQL_TOOL_HANDLERS, + "sample_table": tool_sample_table, + "submit_entity_mapping": MAPPING_TOOL_HANDLERS["submit_entity_mapping"], +} + + +# ===================================================== +# Data classes +# ===================================================== + + +@dataclass +class EntityGenStep: + """One observable step of the EntityGenerator's execution. + + Mirrors :class:`agents.agent_mapping_pge.planner.PlannerStep` but is + scoped to the Generator so the orchestrator (Sprint 7) can render a + per-class timeline in the UI. 
+ """ + + step_type: str # "tool_call" | "tool_result" | "output" + content: str + tool_name: str = "" + duration_ms: int = 0 + + +@dataclass +class EntityGenResult: + """Outcome of a single EntityGenerator invocation. + + ``mapping`` holds the submitted entity-mapping dict (the same shape the + handler appends to ``ctx.entity_mappings``) when ``success`` is True. + """ + + success: bool + mapping: Optional[dict] = None + steps: List[EntityGenStep] = field(default_factory=list) + iterations: int = 0 + error: str = "" + usage: Dict[str, int] = field(default_factory=dict) + + +# ===================================================== +# System prompt +# ===================================================== +# +# The ENTITY SQL RULES section is lifted verbatim from the legacy in-house +# mapping agent (the section starting "SQL RULES FOR ENTITIES") because +# those rules are correct and load-bearing — every mapping query must +# follow them or downstream SPARQL translation breaks. +# +# The PGE-specific additions are the slice-consumption rules: pick the +# best candidate table from the slice, use the canonical ID exactly as +# the Planner specified it, and account for every ontology attribute. + +SYSTEM_PROMPT = """\ +You are a senior data engineer. Your job is to map ONE ontology class to a \ +single SQL SELECT query against a Databricks source table, validated against \ +real data via execute_sql, and submitted via submit_entity_mapping. + +YOU WILL BE GIVEN +• ontology_class: the class to map (uri, label, comment, attributes list). +• source_model_slice: a small JSON object the Planner already curated for \ +this class: + - candidate_tables[]: {table, confidence, reason} — the tables that could \ +realise this class. + - canonical_id.canonical_column_per_table[
]: the expression that \ +MUST be aliased AS ID for each table. THIS VALUE MAY BE A BARE COLUMN \ +NAME ("MOTHER_NHS_NO") OR A FULL SQL EXPRESSION \ +("regexp_extract(pregnancy_id, '([a-f0-9-]+-preg-[0-9]+)')"). Drop it \ +verbatim into the SELECT and alias it AS ID — do NOT rewrite it, do NOT \ +pick a different column, do NOT strip the function call. The Planner emits \ +SQL expressions when raw column values across trusts are in different \ +formats and need to be normalized to a common canonical key. + - canonical_id.format_note: a one-sentence note describing the canonical \ +key (may be empty). Read it to understand what each row's ID represents. + - relevant_joins[]: optional — any joins the Planner thinks may apply. + +SINGLE-SOURCE vs CROSS-SOURCE (CRITICAL — read carefully) +The number of entries in canonical_id.canonical_column_per_table is the \ +authoritative signal for how to shape your SELECT: + + • If canonical_column_per_table has EXACTLY ONE table → single-source \ +class. Write a flat SELECT from that one table. Pick it from the matching \ +candidate_tables entry. + + • If canonical_column_per_table has TWO OR MORE tables → CROSS-SOURCE \ +class (e.g. the same patient or pregnancy realised across multiple trusts). \ +You MUST emit a UNION ALL across ALL listed tables, NOT pick one. Each \ +branch uses that table's canonical-ID column AS ID. Picking just one would \ +produce a Mother (or Pregnancy, etc.) entity that's missing 60–70% of its \ +real instances, and every relationship pointing at it would then dangle. \ +This is the #1 failure mode the orchestrator catches — do not produce it. + + UNION shape (use exactly this pattern — substitute the canonical-ID \ +EXPRESSION exactly as the Planner specified it for that table; do NOT \ +rewrite it): + SELECT AS ID, AS Label, \ + FROM WHERE IS NOT NULL + UNION ALL + SELECT AS ID, AS Label, \ + FROM WHERE IS NOT NULL + UNION ALL + ... 
+ + All branches must return the SAME columns in the SAME order, AND each \ +column must have the SAME TYPE in every branch. If a branch lacks a column \ +another branch has, project a NULL with a matching alias **cast to the same \ +type the real branch uses** (e.g. if branch A has ``BABY_NHS_NO`` typed \ +BIGINT, branch B must use ``CAST(NULL AS BIGINT) AS BABY_NHS_NO`` — not \ +``AS STRING``). When two branches hold the column with DIFFERENT types, cast \ +BOTH to a common type (``CAST(... AS STRING)`` is the safe default). A \ +``CAST_INVALID_INPUT`` / type-mismatch error from execute_sql always means a \ +column's types differ across branches — fix the casts, do not change the ID. + +TOOLS +You have three tools: + • execute_sql – Validate the composed SELECT before submitting. \ +The tool runs your query with a small LIMIT and returns columns + sample \ +rows; the persisted mapping has no LIMIT. + • sample_table – Up to N random rows from a table. Use only when \ +two candidate tables are equally plausible and you need to peek at real \ +values to disambiguate. + • submit_entity_mapping – TERMINAL. Call exactly once, after execute_sql \ +succeeds, with the full mapping payload. + +SQL RULES FOR ENTITIES (CRITICAL) +• Always use the full table name from the slice (catalog.schema.table). +• The FIRST column MUST be aliased AS ID — it MUST be the canonical-ID \ +column the slice specifies for the chosen table. +• The SECOND column MUST be aliased AS Label — pick the most human-readable \ +available column (typically ``name``, ``label``, ``display_name``, or \ +similar). If no human-readable column exists, fall back to the canonical \ +ID column itself aliased AS Label. +• Add one column per ontology data-property attribute you can satisfy from \ +the chosen table. Use the column's original name (no alias). 
+• If the same column serves as both an alias and an attribute, include it \ +twice: once with the alias (AS ID or AS Label) and once with its original \ +name so it appears in attribute_mappings. +• Add WHERE IS NOT NULL to filter null keys. When the ID is a \ +derived expression, also exclude empty extractions (e.g. \ +``WHERE regexp_extract(...) <> ''``). +• DEDUP COLLAPSED KEYS: when the canonical-ID is a derived EXPRESSION that \ +can repeat across rows (e.g. a ``-del`` key where several episode rows \ +share one pregnancy core), the same ID will appear on multiple rows and the \ +evaluator FAILs on "duplicate ID values". Make each node id unique: wrap the \ +UNION in ``SELECT ... FROM () GROUP BY ID`` (taking MAX() of each \ +attribute) or use ``SELECT DISTINCT`` when there are no attributes. The id \ +column must have exactly one row per distinct value. +• Do NOT add LIMIT — the persisted mapping query must return ALL rows. \ +execute_sql adds a small LIMIT internally for validation only. +• Do NOT use ORDER BY, CTEs, or subqueries unless absolutely necessary. +• Write simple, flat SELECT statements. + +ATTRIBUTE COVERAGE — NO SILENT DROPS (CRITICAL) +For EACH ontology attribute on the class, you must do ONE of: + (a) include a SQL column for it in the SELECT, AND add an entry to \ +attribute_mappings mapping the ontology attribute name to the SQL column \ +name (case-sensitive); OR + (b) add it to unmapped_attributes with a one-sentence reason, using the \ +shape {"name": "", "reason": ""}. + +You may NOT silently drop an attribute. The orchestrator will reject any \ +mapping where some ontology attributes appear in neither list. If a column \ +genuinely does not exist on the chosen table, that's an honest unmapped — \ +say so in the reason. + +WORKFLOW +1. Read the ontology class and the source_model_slice carefully. +2. COUNT the entries in canonical_id.canonical_column_per_table: + - one → single-source: pick that table, compose a flat SELECT. 
def _build_user_prompt(
    ontology_class: dict,
    source_model_slice: dict,
    retry_hint: Optional[str] = None,
) -> str:
    """Render the per-class user prompt.

    The orchestrator hands us `ontology_class` and a focused
    `source_model_slice`. We emit a structured prompt that:
      * surfaces the retry hint up top if one was provided,
      * lists the class metadata and attribute list explicitly so the LLM
        cannot forget any attribute,
      * embeds the slice as JSON so the LLM can refer to it precisely, and
      * closes with an instruction whose SELECT shape matches the slice:
        a flat single-table SELECT when
        ``canonical_id.canonical_column_per_table`` names at most one table,
        a UNION ALL across all of them when it names two or more. (The
        closing line is the last thing the model reads, so it must not
        contradict the system prompt's cross-source rule.)

    Args:
        ontology_class: Dict for the single class (uri, label/name, comment,
            attributes). ``None`` is treated as an empty dict.
        source_model_slice: Planner slice for the class. ``None`` is treated
            as an empty dict.
        retry_hint: Optional hint from a previous failed attempt.

    Returns:
        The full user-message text, newline-joined.
    """
    ontology_class = ontology_class or {}
    source_model_slice = source_model_slice or {}

    parts: List[str] = []

    if retry_hint:
        parts.append(f"RETRY HINT (authoritative): {retry_hint}")
        parts.append("")

    class_uri = ontology_class.get("uri", "")
    class_label = ontology_class.get("label") or ontology_class.get("name", "")
    class_comment = ontology_class.get("comment", "") or ""
    attributes = ontology_class.get("attributes", []) or []

    attr_summary_lines: List[str] = []
    for attr in attributes:
        if isinstance(attr, dict):
            attr_name = attr.get("name") or attr.get("label") or attr.get("uri", "?")
            attr_type = attr.get("type") or attr.get("range") or ""
            attr_summary_lines.append(
                f" - {attr_name}" + (f" ({attr_type})" if attr_type else "")
            )
        else:
            # Attributes may also arrive as bare names.
            attr_summary_lines.append(f" - {attr}")

    parts.append("ONTOLOGY CLASS")
    parts.append(f" uri: {class_uri}")
    parts.append(f" label: {class_label}")
    if class_comment:
        parts.append(f" comment: {class_comment}")
    if attr_summary_lines:
        parts.append(" attributes ({} total):".format(len(attributes)))
        parts.extend(attr_summary_lines)
    else:
        parts.append(" attributes: (none — only ID and Label required)")

    parts.append("")
    parts.append("SOURCE MODEL SLICE")
    parts.append(json.dumps(source_model_slice, indent=2, default=str))

    # The number of canonical-ID entries decides single- vs cross-source.
    # Be defensive about the slice shape: anything that is not a sized
    # dict/list counts as "no entries" and falls back to the flat wording.
    canonical_id = source_model_slice.get("canonical_id")
    per_table = (
        canonical_id.get("canonical_column_per_table")
        if isinstance(canonical_id, dict)
        else None
    )
    n_canonical_tables = len(per_table) if isinstance(per_table, (dict, list)) else 0

    if n_canonical_tables >= 2:
        shape_instruction = (
            "canonical_id.canonical_column_per_table lists "
            f"{n_canonical_tables} tables, so this is a CROSS-SOURCE class: "
            "compose a UNION ALL across ALL of those tables (do NOT pick "
            "just one), each branch using that table's canonical-ID "
            "expression AS ID, "
        )
    else:
        shape_instruction = (
            "Pick the best candidate table from the slice, compose a flat SELECT "
        )

    parts.append("")
    parts.append(
        shape_instruction
        + "following the SQL RULES, validate with execute_sql, then call "
        "submit_entity_mapping exactly once. Every ontology attribute must "
        "appear in either attribute_mappings or unmapped_attributes — no "
        "silent drops."
    )

    prompt = "\n".join(parts)
    logger.debug(
        "_build_user_prompt for class=%s (%d chars):\n%s",
        class_uri,
        len(prompt),
        prompt,
    )
    return prompt
@trace_agent(name="mapping_pge_entity_generator")
def run_entity_generator(
    host: str,
    token: str,
    endpoint_name: str,
    client: Any,
    *,
    ontology_class: dict,
    source_model_slice: dict,
    retry_hint: Optional[str] = None,
    on_step: Optional[Callable[[str, int], None]] = None,
    max_iterations: int = MAX_ITERATIONS,
) -> EntityGenResult:
    """Run the EntityGenerator agent for a single ontology class.

    The agent autonomously composes a SQL SELECT for ``ontology_class``
    against the candidate table(s) in ``source_model_slice``, validates the
    SQL with ``execute_sql``, and submits the validated mapping via the
    terminal ``submit_entity_mapping`` tool.

    Args:
        host: Databricks workspace URL.
        token: Bearer token for the serving endpoint.
        endpoint_name: Foundation Model serving endpoint name.
        client: Databricks SQL client (must expose ``execute_query(sql)``).
        ontology_class: Full dict for the SINGLE class to map (uri, label,
            comment, attributes list).
        source_model_slice: Filtered SourceModel slice with candidate_tables,
            canonical_id, and optional relevant_joins.
        retry_hint: Optional one-sentence hint from the orchestrator's
            previous-attempt evaluation. When present, surfaced at the top of
            the user prompt.
        on_step: Optional progress callback ``(msg, pct)`` for UI updates.
        max_iterations: Upper bound on agent-loop iterations. Defaults to
            ``MAX_ITERATIONS``; ``None`` also falls back to that default.

    Returns:
        An :class:`EntityGenResult`. ``success`` is True iff a mapping was
        successfully submitted; in that case ``mapping`` holds the submitted
        dict. On failure, ``error`` explains why and ``mapping`` is None.
        LLM transport errors are reported through ``error`` rather than
        raised.
    """
    iteration_limit = max_iterations if max_iterations is not None else MAX_ITERATIONS

    ontology_class = ontology_class or {}
    source_model_slice = source_model_slice or {}

    class_uri = ontology_class.get("uri", "")
    class_label = ontology_class.get("label") or ontology_class.get("name", "")
    n_attrs = len(ontology_class.get("attributes") or [])
    n_candidates = len(source_model_slice.get("candidate_tables") or [])

    logger.info(
        "===== ENTITY GENERATOR START ===== endpoint=%s, class=%s (%s), "
        "attributes=%d, candidate_tables=%d, retry_hint=%s, max_iter=%d",
        endpoint_name,
        class_label,
        class_uri,
        n_attrs,
        n_candidates,
        "yes" if retry_hint else "no",
        iteration_limit,
    )

    ctx = ToolContext(
        host=host.rstrip("/"),
        token=token,
        client=client,
        # The slice (delivered through the user prompt) subsumes
        # metadata/ontology for this agent. The unified ToolContext still
        # requires these fields, so they are passed empty.
        metadata={},
        ontology={},
        documents=[],
    )

    result = EntityGenResult(success=False)

    user_prompt = _build_user_prompt(
        ontology_class, source_model_slice, retry_hint=retry_hint
    )
    messages: List[dict] = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
    ]
    logger.info(
        "EntityGenerator conversation initialized: system=%d chars, user=%d chars",
        len(SYSTEM_PROMPT),
        len(user_prompt),
    )

    total_usage: Dict[str, int] = {"prompt_tokens": 0, "completion_tokens": 0}

    def _progress_pct(iteration_idx: int) -> int:
        # Linear ramp 5 → 95 across the iteration budget. submit hits 100.
        ratio = (iteration_idx + 1) / max(iteration_limit, 1)
        return min(5 + int(ratio * 90), 95)

    def notify(msg: str, *, pct: Optional[int] = None) -> None:
        actual_pct = pct if pct is not None else 5
        logger.info("ENTITY GEN STEP [%d%%] %s", actual_pct, msg)
        if on_step:
            on_step(msg, actual_pct)

    notify(f"Generating mapping for {class_label or class_uri}…", pct=1)

    # Snapshot the pre-existing mapping count so the per-iteration log can
    # report whether this run has added a mapping yet. ``ctx`` is fresh
    # today, so the snapshot is 0; it only matters if a ToolContext is ever
    # reused across calls.
    pre_run_mapping_count = len(ctx.entity_mappings)

    # ------------------------------------------------------------------
    # Agent loop
    # ------------------------------------------------------------------
    for iteration in range(iteration_limit):
        if iteration > 0:
            logger.debug(
                "Iteration %d: waiting %ds before LLM call (rate limit mitigation)",
                iteration + 1,
                _ITERATION_DELAY_SEC,
            )
            time.sleep(_ITERATION_DELAY_SEC)

        current_iteration = iteration + 1
        pct = _progress_pct(iteration)
        logger.info(
            "----- EntityGenerator iteration %d/%d — %d messages, mapping=%s -----",
            current_iteration,
            iteration_limit,
            len(messages),
            "set" if len(ctx.entity_mappings) > pre_run_mapping_count else "unset",
        )
        notify(
            f"Mapping iteration {current_iteration}/{iteration_limit}…",
            pct=pct,
        )

        t0 = time.time()
        try:
            llm_response = call_serving_endpoint(
                host,
                token,
                endpoint_name,
                messages,
                tools=TOOL_DEFINITIONS,
                max_tokens=_MAX_TOKENS,
                temperature=0.1,
                timeout=LLM_TIMEOUT,
                trace_name=_TRACE_NAME,
            )
        except requests.exceptions.HTTPError as exc:
            status = exc.response.status_code if exc.response is not None else "?"
            logger.warning(
                "EntityGenerator iteration %d: HTTPError status=%s",
                current_iteration,
                status,
            )
            logger.debug(
                "EntityGenerator iteration %d: HTTPError body: %.500s",
                current_iteration,
                exc.response.text if exc.response is not None else "N/A",
            )
            if exc.response is not None and status in (400, 422):
                result.error = "LLM endpoint does not support function calling"
                result.iterations = current_iteration
                result.usage = total_usage
                logger.error(
                    "EntityGenerator: endpoint refused tools — cannot produce a mapping"
                )
                return result
            result.error = f"LLM request failed: {exc}"
            result.iterations = current_iteration
            result.usage = total_usage
            logger.error(
                "EntityGenerator: LLM request failed at iteration %d: %s",
                current_iteration,
                exc,
            )
            return result
        except requests.exceptions.ReadTimeout:
            result.error = f"LLM request timed out after {LLM_TIMEOUT}s"
            result.iterations = current_iteration
            result.usage = total_usage
            logger.error("EntityGenerator: timeout at iteration %d", current_iteration)
            return result
        except requests.exceptions.RequestException as exc:
            result.error = f"LLM request failed: {exc}"
            result.iterations = current_iteration
            result.usage = total_usage
            logger.error(
                "EntityGenerator: request exception at iteration %d: %s",
                current_iteration,
                exc,
            )
            return result

        elapsed_ms = int((time.time() - t0) * 1000)
        logger.info(
            "EntityGenerator iteration %d: LLM responded in %dms",
            current_iteration,
            elapsed_ms,
        )

        accumulate_usage(total_usage, llm_response.get("usage") or {})

        # Chat-completion payloads may carry explicit nulls ("tool_calls":
        # null, "message": null) or an empty ``choices`` list. ``.get(key,
        # default)`` does not cover a present-but-null key, so normalise
        # with ``or`` before anything calls len()/.get() on these values.
        choices = llm_response.get("choices") or [{}]
        choice = choices[0] or {}
        finish_reason = choice.get("finish_reason", "?")
        message = choice.get("message") or {}
        tool_calls = message.get("tool_calls") or []
        has_content = bool(message.get("content"))
        logger.info(
            "EntityGenerator iteration %d: finish_reason=%s, tool_calls=%d, has_content=%s",
            current_iteration,
            finish_reason,
            len(tool_calls),
            has_content,
        )

        if not tool_calls:
            # The Generator must terminate via submit_entity_mapping, never
            # via free text.
            content = (message.get("content") or "")[:500]
            logger.warning(
                "EntityGenerator iteration %d: produced text without submitting mapping — %d chars",
                current_iteration,
                len(message.get("content") or ""),
            )
            result.steps.append(
                EntityGenStep(
                    step_type="output",
                    content=content,
                    duration_ms=elapsed_ms,
                )
            )
            result.error = "entity generator produced text without submitting mapping"
            result.iterations = current_iteration
            result.usage = total_usage
            notify(
                "Entity generator produced text without submitting mapping.",
                pct=pct,
            )
            return result

        logger.info(
            "EntityGenerator iteration %d: processing %d tool call(s): [%s]",
            current_iteration,
            len(tool_calls),
            ", ".join(
                (tc.get("function") or {}).get("name", "?") for tc in tool_calls
            ),
        )
        messages.append(message)

        terminal_success = False
        for tc_idx, tc in enumerate(tool_calls, 1):
            func = tc.get("function") or {}
            tool_name = func.get("name", "")
            raw_args = func.get("arguments", "{}")
            tool_id = tc.get("id", "")

            # ``arguments`` is normally a JSON-encoded object, but a model
            # can emit malformed JSON, a JSON scalar/array, or null. Every
            # downstream use calls ``arguments.get``, so anything that is
            # not a dict degrades to {} and the tool handler reports the
            # missing parameters back to the LLM.
            if isinstance(raw_args, dict):
                arguments = raw_args
            else:
                try:
                    arguments = json.loads(raw_args or "{}")
                except (json.JSONDecodeError, TypeError):
                    logger.warning(
                        "EntityGenerator iteration %d: tool '%s' arguments are not valid JSON",
                        current_iteration,
                        tool_name,
                    )
                    arguments = {}
                if not isinstance(arguments, dict):
                    arguments = {}

            logger.info(
                "EntityGenerator iteration %d: calling tool '%s' (%d/%d)",
                current_iteration,
                tool_name,
                tc_idx,
                len(tool_calls),
            )

            # Human-readable progress messages per tool.
            if tool_name == "submit_entity_mapping":
                notify(f"Submitting mapping for {class_label or class_uri}…", pct=pct)
            elif tool_name == "sample_table":
                fn = arguments.get("full_name", "?")
                notify(f"Sampling {fn}…", pct=pct)
            elif tool_name == "execute_sql":
                # str(... or "") guards against a null / non-string "sql".
                sql_preview = str(arguments.get("sql") or "")[:80]
                notify(f"Running SQL: {sql_preview}…", pct=pct)
            else:
                notify(f"Calling {tool_name}…", pct=pct)

            result.steps.append(
                EntityGenStep(
                    step_type="tool_call",
                    content=json.dumps(arguments, default=str)[:500],
                    tool_name=tool_name,
                )
            )

            t1 = time.time()
            tool_result = dispatch_tool(
                TOOL_HANDLERS, ctx, tool_name, arguments, trace_name=_TRACE_NAME
            )
            tool_ms = int((time.time() - t1) * 1000)

            logger.info(
                "EntityGenerator iteration %d: tool '%s' returned %d chars in %dms",
                current_iteration,
                tool_name,
                len(tool_result),
                tool_ms,
            )

            result.steps.append(
                EntityGenStep(
                    step_type="tool_result",
                    content=(
                        (tool_result[:500] + "…")
                        if len(tool_result) > 500
                        else tool_result
                    ),
                    tool_name=tool_name,
                    duration_ms=tool_ms,
                )
            )

            messages.append(
                {
                    "role": "tool",
                    "tool_call_id": tool_id,
                    "content": tool_result,
                }
            )

            # Detect terminal success: submit_entity_mapping returned
            # success=True AND a mapping for THIS class_uri is present in
            # ctx.entity_mappings. A submit with a mismatched class_uri (the
            # LLM mapped a different class than requested) is NOT terminal —
            # we coach the LLM via a corrective tool message and let the loop
            # continue so it can resubmit with the right URI.
            if tool_name == "submit_entity_mapping":
                try:
                    parsed = json.loads(tool_result)
                except (json.JSONDecodeError, TypeError):
                    parsed = {}
                if not isinstance(parsed, dict):
                    parsed = {}
                if parsed.get("success") is True:
                    matched = any(
                        m.get("ontology_class") == class_uri
                        for m in ctx.entity_mappings
                    )
                    if matched:
                        terminal_success = True
                        logger.info(
                            "EntityGenerator iteration %d: submit_entity_mapping succeeded — terminating",
                            current_iteration,
                        )
                    else:
                        submitted_uri = arguments.get("class_uri", "")
                        mismatch_msg = (
                            f"submitted class_uri '{submitted_uri}' does not "
                            f"match requested class_uri '{class_uri}'; "
                            f"resubmit with class_uri='{class_uri}'"
                        )
                        logger.warning(
                            "EntityGenerator iteration %d: submit_entity_mapping "
                            "class_uri mismatch — submitted=%s, requested=%s",
                            current_iteration,
                            submitted_uri,
                            class_uri,
                        )
                        corrective_payload = json.dumps(
                            {"success": False, "error": mismatch_msg}
                        )
                        # Replace the recorded tool_result step's content so
                        # the UI / trace reflects the corrective signal
                        # rather than the original (misleading) success
                        # response.
                        result.steps[-1] = EntityGenStep(
                            step_type="tool_result",
                            content=corrective_payload,
                            tool_name=tool_name,
                            duration_ms=result.steps[-1].duration_ms,
                        )
                        # Replace the tool message just appended to
                        # ``messages`` so the LLM sees the corrective
                        # payload on the next turn (one tool message per
                        # tool_call_id — keep the protocol clean).
                        messages[-1] = {
                            "role": "tool",
                            "tool_call_id": tool_id,
                            "content": corrective_payload,
                        }

        if terminal_success:
            # Pull the mapping for this class by strict URI match. The
            # terminal-success guard above already verified an entry with
            # this URI exists; if we somehow can't find one here that's an
            # internal invariant violation, not a recoverable failure.
            submitted = next(
                (
                    m
                    for m in reversed(ctx.entity_mappings)
                    if m.get("ontology_class") == class_uri
                ),
                None,
            )
            if submitted is None:
                result.error = (
                    "internal: submit succeeded but mapping not found for class_uri"
                )
                result.iterations = current_iteration
                result.usage = total_usage
                logger.error(
                    "===== ENTITY GENERATOR FAILED ===== %s (class=%s)",
                    result.error,
                    class_uri,
                )
                return result
            result.success = True
            result.mapping = submitted
            result.iterations = current_iteration
            result.usage = total_usage
            logger.info(
                "===== ENTITY GENERATOR COMPLETE ===== class=%s, iterations=%d, "
                "prompt_tokens=%d, completion_tokens=%d",
                class_uri,
                result.iterations,
                total_usage["prompt_tokens"],
                total_usage["completion_tokens"],
            )
            notify(f"Mapping for {class_label or class_uri} complete!", pct=100)
            return result

    # Budget exhausted without a successful submit.
    result.iterations = iteration_limit
    result.usage = total_usage
    result.error = "entity generator exhausted iteration budget"
    logger.error("===== ENTITY GENERATOR FAILED ===== %s", result.error)
    notify(result.error, pct=95)
    return result
It maps **one** ontology +property (relationship) at a time, given: + +* the property to map (uri, label, comment, domain, range), +* the source and target **entity mappings already produced by the + EntityGenerator** — crucially, the ``id_column`` each side mapped on, and +* a small SourceModel slice that surfaces the relevant join-key subgraph. + +The system prompt FORBIDS picking endpoint columns that do not match the +already-mapped entity IDs: the source/target endpoint columns are GIVEN. +This keeps relationships consistent with the entities they connect — if a +relationship's ``source_id`` doesn't match the source entity's ``id_column``, +the resulting SPARQL graph cannot join. + +The loop semantics mirror :mod:`.entity`: + +* Same default budget (12). +* Same inter-iteration delay (``_ITERATION_DELAY_SEC``, 1 second). +* Same MLflow trace decorator. +* No single-shot fallback (terminate via tool call only). +* Strict ``property_uri`` match on terminal detection — a submit with the + wrong URI is coached via a corrective tool message, not accepted. +""" + +import json +import time +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional + +import requests + +from back.core.logging import get_logger +from agents.engine_base import ( + call_serving_endpoint, + dispatch_tool, + accumulate_usage, +) +from agents.tools.context import ToolContext +from agents.tools.mapping import ( + MAPPING_TOOL_DEFINITIONS_BY_NAME, + MAPPING_TOOL_HANDLERS, +) +from agents.tools.planner import ( + SAMPLE_TABLE_DEF, + tool_sample_table, +) +from agents.tools.sql import ( + SQL_TOOL_DEFINITIONS, + SQL_TOOL_HANDLERS, +) +from agents.tracing import trace_agent + +logger = get_logger(__name__) + +MAX_ITERATIONS = 12 +LLM_TIMEOUT = 180 +_ITERATION_DELAY_SEC = 1 +# See planner._MAX_TOKENS comment — large UNION ALL queries for cross-source +# relationships can exceed a small ceiling.
@dataclass
class RelationshipGenStep:
    """One observable step of the RelationshipGenerator's execution.

    Mirrors :class:`.entity.EntityGenStep` — scoped to the relationship
    generator so the orchestrator (Sprint 7) can render a per-property
    timeline in the UI.
    """

    # Kind of step; one of the three literals listed on the right.
    step_type: str  # "tool_call" | "tool_result" | "output"
    # Payload text for the step.
    # NOTE(review): presumably truncated by the runner the same way the
    # entity generator truncates its step contents — confirm.
    content: str
    # Name of the tool involved; defaults to "" when none is given.
    tool_name: str = ""
    # Duration attributed to the step, in milliseconds; defaults to 0.
    duration_ms: int = 0
@dataclass
class RelationshipGenResult:
    """Outcome of a single RelationshipGenerator invocation.

    ``mapping`` holds the submitted relationship-mapping dict (the same
    shape the handler appends to ``ctx.relationships``) when ``success`` is
    True.
    """

    # Whether a relationship mapping was successfully submitted.
    success: bool
    # The submitted relationship-mapping dict; None unless ``success``.
    mapping: Optional[dict] = None
    # Ordered steps recorded during the run (see RelationshipGenStep).
    steps: List[RelationshipGenStep] = field(default_factory=list)
    # Number of loop iterations used.
    # NOTE(review): presumably 1-based, as in the entity generator — confirm
    # against the relationship runner.
    iterations: int = 0
    # Failure reason; empty string by default.
    error: str = ""
    # Token usage counters.
    # NOTE(review): presumably ``prompt_tokens`` / ``completion_tokens``
    # totals, as in the entity generator — confirm.
    usage: Dict[str, int] = field(default_factory=dict)
+Each entity mints its id with ``SELECT AS `` (the \ +id_column is usually just ``ID``). That expression is often a canonical-key \ +normalization, e.g.:: + + CONCAT(regexp_extract(EPISODE_ID, '([a-f0-9][a-f0-9-]+-preg-[0-9]+)', 1), '-baby') + +There is no ``ID`` column to select. You MUST **reproduce the entity's id \ +EXPRESSION verbatim** (copied from its SQL), applied to your table, for the \ +endpoint. A raw column (a ``*_id`` join key, a trust-local id) will NOT match \ +→ 100% dangling. + +TRAP 2 — building from a table only ONE endpoint entity covers. +The two entities may be sourced from different trusts (compare the FROM \ +tables in each entity's SQL). Their id universes overlap only on the trust(s) \ +BOTH cover. Build the edge from a table present in BOTH entities' FROM lists \ +(the shared-coverage table). Building from a table only the target covers \ +makes every source_id absent from the source → 100% source-dangling (and \ +vice-versa). + +TRAP 3 — column-name / alias mismatch on submit. +Your SELECT MUST alias the two columns exactly ``AS source_id`` and \ +``AS target_id``, and you MUST submit ``source_id_column="source_id"`` and \ +``target_id_column="target_id"``. These name the columns IN YOUR EDGE OUTPUT, \ +NOT the entity's id_column. If they disagree with your SELECT aliases the \ +evaluator reads nothing and every edge dangles. + +WORKED EXAMPLE — ``Baby --hasApgarScore--> Apgar Score`` +Baby is sourced from {trust_a.maternity_episode, trust_b.delivery}; Apgar \ +Score from {trust_a.maternity_episode, trust_c.maternity_event}. Shared \ +coverage = trust_a only (Trap 2). Both ids share the canonical pregnancy core \ +with role suffixes (Trap 1). 
So build from trust_a, reproducing both \ +expressions from one row:: + + SELECT CONCAT(regexp_extract(EPISODE_ID, '([a-f0-9][a-f0-9-]+-preg-[0-9]+)', 1), '-baby') AS source_id, + CONCAT(regexp_extract(EPISODE_ID, '([a-f0-9][a-f0-9-]+-preg-[0-9]+)', 1), '-apgar') AS target_id + FROM fiifi_cdm_demo_catalog.trust_a.maternity_episode + WHERE regexp_extract(EPISODE_ID, '([a-f0-9][a-f0-9-]+-preg-[0-9]+)', 1) <> '' + +Building from trust_c (Apgar's natural home) would dangle 100% on the Baby \ +side, because Baby has no trust_c rows. + +TOOLS + • execute_sql – validate / probe your SELECT (runs with a \ +small LIMIT; the persisted mapping has none). + • sample_table – peek at real values when a column is \ +ambiguous. + • submit_relationship_mapping – TERMINAL. Call EXACTLY ONCE, only after a \ +clean dangling probe (see WORKFLOW step 4). + +SQL RULES +• SELECT exactly two columns: `` AS source_id, AS target_id`` (Trap 1 + Trap 3). +• Build FROM a table both entities cover (Trap 2). Same-trust FK joins: one \ +table, no join. Cross-source: a UNION ALL of per-source SELECTs (each source \ +that holds both cores), or a JOIN on the shared canonical key. +• No LIMIT, no ORDER BY. Always full table names (catalog.schema.table). + +WORKFLOW +1. Read BOTH entity SQLs. Extract each entity's id EXPRESSION (the \ +``SELECT AS ``) and its set of FROM tables. +2. Pick a shared-coverage table (Trap 2). Compose the two-column SELECT, \ +setting source_id to the source entity's id EXPRESSION and target_id to the \ +target entity's id EXPRESSION, reproduced verbatim and aliased ``AS \ +source_id`` / ``AS target_id``. +3. Call execute_sql to confirm the query parses and returns two columns of \ +rows. Read any error and fix it; never submit an un-validated query. +4. SELF-VERIFY THE VALUES BEFORE SUBMITTING (MANDATORY GATE). 
Run this probe \ +via execute_sql: + + WITH rel AS (), + src AS (), + tgt AS () + SELECT + (SELECT COUNT(*) FROM rel) AS edges, + (SELECT COUNT(*) FROM rel r WHERE r.source_id NOT IN (SELECT ID FROM src)) AS dangling_src, + (SELECT COUNT(*) FROM rel r WHERE r.target_id NOT IN (SELECT ID FROM tgt)) AS dangling_tgt + + You may submit ONLY when ``dangling_src`` AND ``dangling_tgt`` are both 0 \ +(or a tiny fraction of edges). If either is high you hit Trap 1 or Trap 2 — \ +fix the endpoint expression or switch to the shared-coverage table, then \ +re-run this probe. Do NOT submit on an unrun or failing probe. +5. submit_relationship_mapping EXACTLY ONCE: property_uri, property_name, \ +sql_query (no LIMIT), source_id_column="source_id", target_id_column=\ +"target_id", domain, range_class. +6. Terminal — emit no free text after submitting. + +GENERAL RULES +• Only ever pass row-returning queries (SELECT / WITH …) to execute_sql. +• Do not call get_metadata, get_ontology, column_value_overlap, \ +distinct_count, submit_entity_mapping, or submit_source_model — they are \ +not available to you. The slice plus the entity mappings carry everything \ +you need. +• If a retry_hint is present at the top of the user message, treat it as \ +authoritative — your previous attempt failed for the reason stated; do NOT \ +repeat the same mistake. +""" + + +# ===================================================== +# Internal helpers +# ===================================================== + + +def _summarise_entity_mapping(em: dict, side: str) -> List[str]: + """One-block textual summary of a previously-produced entity mapping. + + Surfaces exactly the fields the LLM needs to constrain its endpoint + choice: the class_uri, the id_column it locked in, and the SQL it ran. + Anything else (label_column, attribute_mappings, …) is irrelevant to the + relationship task and is intentionally omitted to keep the prompt tight. 
+ """ + em = em or {} + class_uri = ( + em.get("ontology_class") or em.get("class_uri") or em.get("class") or "" + ) + id_column = em.get("id_column", "") + sql_query = em.get("sql_query", "") + return [ + f"{side.upper()} ENTITY MAPPING", + f" class_uri: {class_uri}", + f" id_column: {id_column}", + f" sql: {sql_query}", + ] + + +def _format_join(j: dict) -> str: + """Readable one-line rendering of a join entry from the slice. + + Defensive about missing fields — partial joins still render usefully so + a malformed slice doesn't blow up the prompt build. + """ + from_ref = j.get("from_ref", "?") + to_ref = j.get("to_ref", "?") + kind = j.get("kind", "?") + conf = j.get("confidence") + overlap = j.get("overlap_pct") + extras: List[str] = [] + if conf is not None: + try: + extras.append(f"confidence={float(conf):.2f}") + except (TypeError, ValueError): + extras.append(f"confidence={conf}") + if overlap is not None: + try: + extras.append(f"overlap_pct={float(overlap):.2f}") + except (TypeError, ValueError): + extras.append(f"overlap_pct={overlap}") + suffix = (" — " + ", ".join(extras)) if extras else "" + return f" - {from_ref} -> {to_ref} [{kind}]{suffix}" + + +def _build_user_prompt( + ontology_property: dict, + source_entity_mapping: dict, + target_entity_mapping: dict, + source_model_slice: dict, + retry_hint: Optional[str] = None, +) -> str: + """Render the per-property user prompt. + + Structure: + 1. retry_hint (if any) at the very top + 2. ontology property metadata + 3. source entity mapping summary (class_uri / id_column / sql) + 4. target entity mapping summary + 5. relevant joins (one line per join, readable) + 6. candidate_tables (raw JSON — small) + 7. a reminder block reiterating the two-column / endpoint-match rules + """ + parts: List[str] = [] + + if retry_hint: + parts.append("RETRY HINT (authoritative — your previous attempt FAILED):") + parts.append(retry_hint) + parts.append( + "DO NOT repeat the same column choice. 
If the hint mentions " + "'dangling' or 'canonical id': sample BOTH the candidate endpoint " + "column AND the entity's id_column, compare actual values, and " + "pick the column whose values overlap. Run the dangling-edge " + "probe (step 4 of WORKFLOW) BEFORE submitting this time.\n" + ) + + prop_uri = ontology_property.get("uri", "") + prop_label = ( + ontology_property.get("label") or ontology_property.get("name", "") + ) + prop_comment = ontology_property.get("comment", "") or "" + prop_domain = ontology_property.get("domain", "") or "" + prop_range = ontology_property.get("range", "") or "" + + parts.append("ONTOLOGY PROPERTY") + parts.append(f" uri: {prop_uri}") + parts.append(f" label: {prop_label}") + if prop_comment: + parts.append(f" comment: {prop_comment}") + parts.append(f" domain: {prop_domain}") + parts.append(f" range: {prop_range}") + + parts.append("") + parts.extend(_summarise_entity_mapping(source_entity_mapping, side="source")) + + parts.append("") + parts.extend(_summarise_entity_mapping(target_entity_mapping, side="target")) + + slice_obj = source_model_slice or {} + joins = slice_obj.get("relevant_joins") or [] + candidates = slice_obj.get("candidate_tables") or [] + + parts.append("") + parts.append("RELEVANT JOINS") + if joins: + for j in joins: + parts.append(_format_join(j)) + else: + parts.append(" (none surfaced by the Planner — fall back to a single-table SELECT if possible)") + + if candidates: + parts.append("") + parts.append("CANDIDATE TABLES") + parts.append(json.dumps(candidates, indent=2, default=str)) + + src_id = (source_entity_mapping or {}).get("id_column", "") + tgt_id = (target_entity_mapping or {}).get("id_column", "") + + parts.append("") + parts.append("REMINDERS (CRITICAL)") + parts.append( + " • The persisted SQL MUST return EXACTLY two columns aliased " + "AS source_id and AS target_id." 
+ ) + parts.append( + f" • source_id values MUST come from the column '{src_id}' (the " + "source entity's id_column) — or be directly transformable into it " + "via a join key in the slice." + ) + parts.append( + f" • target_id values MUST come from the column '{tgt_id}' (the " + "target entity's id_column) — same constraint." + ) + parts.append( + " • Validate with execute_sql, then call submit_relationship_mapping " + "exactly once." + ) + + prompt = "\n".join(parts) + logger.debug( + "_build_user_prompt for property=%s (%d chars):\n%s", + prop_uri, + len(prompt), + prompt, + ) + return prompt + + +# ===================================================== +# Public entry point +# ===================================================== + + +@trace_agent(name="mapping_pge_relationship_generator") +def run_relationship_generator( + host: str, + token: str, + endpoint_name: str, + client: Any, + *, + ontology_property: dict, + source_entity_mapping: dict, + target_entity_mapping: dict, + source_model_slice: dict, + retry_hint: Optional[str] = None, + on_step: Optional[Callable[[str, int], None]] = None, + max_iterations: int = MAX_ITERATIONS, +) -> RelationshipGenResult: + """Run the RelationshipGenerator agent for a single ontology property. + + The agent composes a two-column SQL SELECT (``source_id`` / ``target_id``) + that realises the relationship between the source and target entities + using the join-key subgraph in ``source_model_slice``, validates the + SQL via ``execute_sql``, and submits the validated mapping via the + terminal ``submit_relationship_mapping`` tool. + + Args: + host: Databricks workspace URL. + token: Bearer token for the serving endpoint. + endpoint_name: Foundation Model serving endpoint name. + client: Databricks SQL client (must expose ``execute_query(sql)``). + ontology_property: Full dict for the SINGLE property to map (uri, + label, comment, domain, range). 
+ source_entity_mapping: The ALREADY-MAPPED source entity (carries the + ``id_column`` the source endpoint must align with). + target_entity_mapping: The ALREADY-MAPPED target entity (same). + source_model_slice: Filtered SourceModel slice with relevant_joins + and optional candidate_tables. + retry_hint: Optional one-sentence hint from the orchestrator's + previous-attempt evaluation. When present, surfaced at the top + of the user prompt. + on_step: Optional progress callback ``(msg, pct)`` for UI updates. + max_iterations: Upper bound on tool-call iterations (default 12 — + same as the EntityGenerator). + + Returns: + A :class:`RelationshipGenResult`. ``success`` is True iff a mapping + was successfully submitted with the requested ``property_uri``; in + that case ``mapping`` holds the submitted dict. On failure, ``error`` + explains why and ``mapping`` is None. + """ + iteration_limit = max_iterations if max_iterations is not None else MAX_ITERATIONS + + property_uri = (ontology_property or {}).get("uri", "") + property_label = ( + (ontology_property or {}).get("label") + or (ontology_property or {}).get("name", "") + ) + n_joins = len(((source_model_slice or {}).get("relevant_joins") or [])) + n_candidates = len(((source_model_slice or {}).get("candidate_tables") or [])) + + logger.info( + "===== RELATIONSHIP GENERATOR START ===== endpoint=%s, property=%s (%s), " + "joins=%d, candidate_tables=%d, retry_hint=%s, max_iter=%d", + endpoint_name, + property_label, + property_uri, + n_joins, + n_candidates, + "yes" if retry_hint else "no", + iteration_limit, + ) + + ctx = ToolContext( + host=host.rstrip("/"), + token=token, + client=client, + # The slice + entity mappings subsume metadata/ontology for this + # agent; the unified ToolContext still wants these fields, so we + # leave them empty. 
+ metadata={}, + ontology={}, + documents=[], + ) + + result = RelationshipGenResult(success=False) + + user_prompt = _build_user_prompt( + ontology_property or {}, + source_entity_mapping or {}, + target_entity_mapping or {}, + source_model_slice or {}, + retry_hint=retry_hint, + ) + messages: List[dict] = [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": user_prompt}, + ] + logger.info( + "RelationshipGenerator conversation initialized: system=%d chars, user=%d chars", + len(SYSTEM_PROMPT), + len(user_prompt), + ) + + total_usage: Dict[str, int] = {"prompt_tokens": 0, "completion_tokens": 0} + + def _progress_pct(iteration_idx: int) -> int: + ratio = (iteration_idx + 1) / max(iteration_limit, 1) + return min(5 + int(ratio * 90), 95) + + def notify(msg: str, *, pct: Optional[int] = None) -> None: + actual_pct = pct if pct is not None else 5 + logger.info("RELATIONSHIP GEN STEP [%d%%] %s", actual_pct, msg) + if on_step: + on_step(msg, actual_pct) + + notify(f"Generating mapping for {property_label or property_uri}…", pct=1) + + # Snapshot the pre-existing relationship count so we can detect "this + # run added a mapping" without relying on absolute counters. Future-proof + # for an orchestrator that reuses a ToolContext across calls. 
+ pre_run_count = len(ctx.relationships) + + # ------------------------------------------------------------------ + # Agent loop + # ------------------------------------------------------------------ + for iteration in range(iteration_limit): + if iteration > 0: + logger.debug( + "Iteration %d: waiting %ds before LLM call (rate limit mitigation)", + iteration + 1, + _ITERATION_DELAY_SEC, + ) + time.sleep(_ITERATION_DELAY_SEC) + + current_iteration = iteration + 1 + pct = _progress_pct(iteration) + logger.info( + "----- RelationshipGenerator iteration %d/%d — %d messages, mapping=%s -----", + current_iteration, + iteration_limit, + len(messages), + "set" if len(ctx.relationships) > pre_run_count else "unset", + ) + notify( + f"Mapping iteration {current_iteration}/{iteration_limit}…", + pct=pct, + ) + + t0 = time.time() + try: + llm_response = call_serving_endpoint( + host, + token, + endpoint_name, + messages, + tools=TOOL_DEFINITIONS, + max_tokens=_MAX_TOKENS, + temperature=0.1, + timeout=LLM_TIMEOUT, + trace_name=_TRACE_NAME, + ) + except requests.exceptions.HTTPError as exc: + status = exc.response.status_code if exc.response is not None else "?" 
+ logger.warning( + "RelationshipGenerator iteration %d: HTTPError status=%s", + current_iteration, + status, + ) + logger.debug( + "RelationshipGenerator iteration %d: HTTPError body: %.500s", + current_iteration, + exc.response.text if exc.response is not None else "N/A", + ) + if exc.response is not None and status in (400, 422): + result.error = "LLM endpoint does not support function calling" + result.iterations = current_iteration + result.usage = total_usage + logger.error( + "RelationshipGenerator: endpoint refused tools — cannot produce a mapping" + ) + return result + result.error = f"LLM request failed: {exc}" + result.iterations = current_iteration + result.usage = total_usage + logger.error( + "RelationshipGenerator: LLM request failed at iteration %d: %s", + current_iteration, + exc, + ) + return result + except requests.exceptions.ReadTimeout: + result.error = f"LLM request timed out after {LLM_TIMEOUT}s" + result.iterations = current_iteration + result.usage = total_usage + logger.error( + "RelationshipGenerator: timeout at iteration %d", current_iteration + ) + return result + except requests.exceptions.RequestException as exc: + result.error = f"LLM request failed: {exc}" + result.iterations = current_iteration + result.usage = total_usage + logger.error( + "RelationshipGenerator: request exception at iteration %d: %s", + current_iteration, + exc, + ) + return result + + elapsed_ms = int((time.time() - t0) * 1000) + logger.info( + "RelationshipGenerator iteration %d: LLM responded in %dms", + current_iteration, + elapsed_ms, + ) + + accumulate_usage(total_usage, llm_response.get("usage", {})) + + choice = (llm_response.get("choices") or [{}])[0] # tolerate a missing, null or empty choices list + finish_reason = choice.get("finish_reason", "?") + message = choice.get("message") or {} + tool_calls = message.get("tool_calls") or [] # the key may be present but null on text-only replies + has_content = bool(message.get("content")) + logger.info( + "RelationshipGenerator iteration %d: finish_reason=%s, tool_calls=%d, has_content=%s", + current_iteration, + 
finish_reason, + len(tool_calls), + has_content, + ) + + if not tool_calls: + # The Generator must terminate via submit_relationship_mapping, + # never via free text. + content = (message.get("content") or "")[:500] + logger.warning( + "RelationshipGenerator iteration %d: produced text without submitting mapping — %d chars", + current_iteration, + len(message.get("content") or ""), + ) + result.steps.append( + RelationshipGenStep( + step_type="output", + content=content, + duration_ms=elapsed_ms, + ) + ) + result.error = "relationship generator produced text without submitting mapping" + result.iterations = current_iteration + result.usage = total_usage + notify( + "Relationship generator produced text without submitting mapping.", + pct=pct, + ) + return result + + logger.info( + "RelationshipGenerator iteration %d: processing %d tool call(s): [%s]", + current_iteration, + len(tool_calls), + ", ".join( + tc.get("function", {}).get("name", "?") for tc in tool_calls + ), + ) + messages.append(message) + + terminal_success = False + for tc_idx, tc in enumerate(tool_calls, 1): + func = tc.get("function", {}) + tool_name = func.get("name", "") + raw_args = func.get("arguments", "{}") + tool_id = tc.get("id", "") + + try: + arguments = json.loads(raw_args) or {} # a JSON null payload becomes an empty argument dict + except (json.JSONDecodeError, TypeError): # TypeError: arguments was not a str/bytes payload + arguments = {} + + logger.info( + "RelationshipGenerator iteration %d: calling tool '%s' (%d/%d)", + current_iteration, + tool_name, + tc_idx, + len(tool_calls), + ) + + if tool_name == "submit_relationship_mapping": + notify( + f"Submitting mapping for {property_label or property_uri}…", + pct=pct, + ) + elif tool_name == "sample_table": + fn = arguments.get("full_name", "?") + notify(f"Sampling {fn}…", pct=pct) + elif tool_name == "execute_sql": + sql_preview = str(arguments.get("sql") or "")[:80] + notify(f"Running SQL: {sql_preview}…", pct=pct) + else: + notify(f"Calling {tool_name}…", pct=pct) + + result.steps.append( + RelationshipGenStep( + step_type="tool_call", + 
content=json.dumps(arguments, default=str)[:500], + tool_name=tool_name, + ) + ) + + t1 = time.time() + tool_result = dispatch_tool( + TOOL_HANDLERS, ctx, tool_name, arguments, trace_name=_TRACE_NAME + ) + tool_ms = int((time.time() - t1) * 1000) + + logger.info( + "RelationshipGenerator iteration %d: tool '%s' returned %d chars in %dms", + current_iteration, + tool_name, + len(tool_result), + tool_ms, + ) + + result.steps.append( + RelationshipGenStep( + step_type="tool_result", + content=( + (tool_result[:500] + "…") + if len(tool_result) > 500 + else tool_result + ), + tool_name=tool_name, + duration_ms=tool_ms, + ) + ) + + messages.append( + { + "role": "tool", + "tool_call_id": tool_id, + "content": tool_result, + } + ) + + # Detect terminal success: submit_relationship_mapping returned + # success=True AND a mapping for THIS property_uri is present in + # ctx.relationships. A submit with a mismatched property_uri is + # NOT terminal — we coach the LLM via a corrective tool message + # and let the loop continue. 
+ if tool_name == "submit_relationship_mapping": + try: + parsed = json.loads(tool_result) + except json.JSONDecodeError: + parsed = {} + if parsed.get("success") is True: + matched = any( + m.get("property") == property_uri + for m in ctx.relationships + ) + if matched: + terminal_success = True + logger.info( + "RelationshipGenerator iteration %d: submit_relationship_mapping succeeded — terminating", + current_iteration, + ) + else: + submitted_uri = arguments.get("property_uri", "") + mismatch_msg = ( + f"submitted property_uri '{submitted_uri}' does " + f"not match requested property_uri " + f"'{property_uri}'; resubmit with " + f"property_uri='{property_uri}'" + ) + logger.warning( + "RelationshipGenerator iteration %d: submit_relationship_mapping " + "property_uri mismatch — submitted=%s, requested=%s", + current_iteration, + submitted_uri, + property_uri, + ) + corrective_payload = json.dumps( + {"success": False, "error": mismatch_msg} + ) + # Replace the recorded tool_result step's content so + # the UI / trace shows the corrective signal. + result.steps[-1] = RelationshipGenStep( + step_type="tool_result", + content=corrective_payload, + tool_name=tool_name, + duration_ms=result.steps[-1].duration_ms, + ) + # Replace the tool message on the conversation so + # the LLM sees the corrective payload next turn. + messages[-1] = { + "role": "tool", + "tool_call_id": tool_id, + "content": corrective_payload, + } + + if terminal_success: + # Pull the mapping for this property by strict URI match. 
+ submitted = next( + ( + m + for m in reversed(ctx.relationships) + if m.get("property") == property_uri + ), + None, + ) + if submitted is None: + result.error = ( + "internal: submit succeeded but mapping not found for property_uri" + ) + result.iterations = current_iteration + result.usage = total_usage + logger.error( + "===== RELATIONSHIP GENERATOR FAILED ===== %s (property=%s)", + result.error, + property_uri, + ) + return result + result.success = True + result.mapping = submitted + result.iterations = current_iteration + result.usage = total_usage + logger.info( + "===== RELATIONSHIP GENERATOR COMPLETE ===== property=%s, iterations=%d, " + "prompt_tokens=%d, completion_tokens=%d", + property_uri, + result.iterations, + total_usage["prompt_tokens"], + total_usage["completion_tokens"], + ) + notify( + f"Mapping for {property_label or property_uri} complete!", pct=100 + ) + return result + + # Budget exhausted without a successful submit. + result.iterations = iteration_limit + result.usage = total_usage + result.error = "relationship generator exhausted iteration budget" + logger.error("===== RELATIONSHIP GENERATOR FAILED ===== %s", result.error) + notify(result.error, pct=95) + return result diff --git a/src/agents/agent_mapping_pge/planner.py b/src/agents/agent_mapping_pge/planner.py new file mode 100644 index 00000000..00403630 --- /dev/null +++ b/src/agents/agent_mapping_pge/planner.py @@ -0,0 +1,729 @@ +""" +OntoBricks Mapping-PGE Planner Agent. + +Sprint 3 of the Planner-Generator-Evaluator (PGE) redesign. + +The Planner is a single-invocation agent (no internal retry loop — re- +invocations come from the orchestrator on Evaluator escalation in Sprint 7). 
+It consumes the ontology, table metadata, and any imported domain documents, +probes the source data via the planner tools (sample_table, column_value_overlap, +normalized_value_overlap, distinct_count) plus the shared tools (get_metadata, get_ontology, +get_documents_context, execute_sql), and emits a validated +:class:`SourceModel` via the ``submit_source_model`` terminal tool. + +The loop semantics mirror the prior single-loop mapping agent — same +``call_serving_endpoint`` + ``dispatch_tool`` ReAct cycle, a 1-second +inter-iteration delay, same accumulated usage tracking, same MLflow trace +decorator — with two key differences: + +* No fallback to single-shot generation. If the endpoint refuses tools, the + Planner returns failure (the Planner *needs* tools — it produces structured + output through ``submit_source_model``). +* Smaller default iteration budget (50 instead of 60) — the Planner is more + focused than the auto-mapping agent. +""" + +import json +import time +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional + +import requests + +if TYPE_CHECKING: + from agents.agent_mapping_pge.contracts import SourceModel + +from back.core.logging import get_logger +from agents.engine_base import ( + call_serving_endpoint, + dispatch_tool, + accumulate_usage, +) +from agents.tools.context import ToolContext +from agents.tools.documents import ( + GET_DOCUMENTS_CONTEXT_DEF, + tool_get_documents_context, +) +from agents.tools.metadata import ( + GET_METADATA_DEF, + tool_get_metadata, +) +from agents.tools.ontology import ( + ONTOLOGY_TOOL_DEFINITIONS, + ONTOLOGY_TOOL_HANDLERS, +) +from agents.tools.planner import ( + PLANNER_TOOL_DEFINITIONS, + PLANNER_TOOL_HANDLERS, +) +from agents.tools.sql import ( + SQL_TOOL_DEFINITIONS, + SQL_TOOL_HANDLERS, +) +from agents.tracing import trace_agent + +logger = get_logger(__name__) + +MAX_ITERATIONS = 50 +LLM_TIMEOUT = 180 +_ITERATION_DELAY_SEC = 1 + +# The submit_source_model JSON for a 
real-world ontology can run several KB +# (17+ classes × multiple candidates + canonical_ids + join_keys + plan). +# A small ceiling silently truncates the call (finish_reason=length) and the +# dataclass validation fails with no clue to the LLM as to why. 50k removes +# the practical ceiling for any ontology size; you only pay for tokens +# actually generated, so the cost stays bounded by output complexity. +_MAX_TOKENS = 50000 + +_TRACE_NAME = "mapping_pge_planner" + + +# ===================================================== +# Tool aggregation +# ===================================================== +# +# The Planner uses every read tool the auto-mapping agent has — ontology, +# metadata, documents, execute_sql — *plus* the four planner-specific tools. +# It deliberately does NOT receive ``submit_entity_mapping`` / +# ``submit_relationship_mapping``: those belong to the Generator (Sprints 4 +# and 5). The Planner's only terminal tool is ``submit_source_model``. + +TOOL_DEFINITIONS: List[dict] = ( + [GET_METADATA_DEF, GET_DOCUMENTS_CONTEXT_DEF] + + ONTOLOGY_TOOL_DEFINITIONS + + SQL_TOOL_DEFINITIONS + + PLANNER_TOOL_DEFINITIONS +) + +TOOL_HANDLERS: Dict[str, Callable] = { + "get_metadata": tool_get_metadata, + "get_documents_context": tool_get_documents_context, + **ONTOLOGY_TOOL_HANDLERS, + **SQL_TOOL_HANDLERS, + **PLANNER_TOOL_HANDLERS, +} + + +# ===================================================== +# Data classes +# ===================================================== + + +@dataclass +class PlannerStep: + """One observable step of the Planner's execution. + + Mirrors :class:`agents.engine_base.AgentStep` but is scoped to the Planner + so the orchestrator (Sprint 7) can present a stage-specific timeline in + the UI. + """ + + step_type: str # tool_call | tool_result | output + content: str + tool_name: str = "" + duration_ms: int = 0 + + +@dataclass +class PlannerResult: + """Outcome of a single Planner invocation. 
+ + ``source_model`` is populated only when the LLM successfully called + ``submit_source_model`` with a structurally-valid payload. ``error`` is + the short reason string when ``success`` is ``False``. + """ + + success: bool + source_model: Optional["SourceModel"] = None + steps: List[PlannerStep] = field(default_factory=list) + iterations: int = 0 + error: str = "" + usage: Dict[str, int] = field(default_factory=dict) + + +# ===================================================== +# System prompt +# ===================================================== + +SYSTEM_PROMPT = """\ +You are a senior data architect. Your job is to build a SourceModel that \ +bridges a set of source tables to an OWL ontology, so a downstream Generator \ +agent can mechanically emit entity- and relationship-mapping SQL. + +TOOLS +You have these tools available: + • get_ontology – load classes (with attributes) and object \ +properties to be mapped. + • get_metadata – load imported table schemas (full names, \ +columns, types). + • get_documents_context – load any imported domain documents (glossaries, \ +schema docs). + • sample_table – return up to N random rows so you can see \ +actual values, not just column types. Use when a column's role is unclear \ +from its name/type alone. + • column_value_overlap – measure |distinct(from) ∩ distinct(to)| / \ +|distinct(from)| for two bare COLUMNS. Use to VALIDATE a candidate join key \ +with real data — never propose a join_key on the strength of name similarity \ +alone. + • normalized_value_overlap – the same overlap metric, but each side is a \ +scalar SQL EXPRESSION. This is how you PROVE a canonical-key normalization: \ +when two tables for the same class have 0% raw overlap, propose a \ +normalization expression per table and confirm overlap_pct > 0 here BEFORE \ +you submit. A still-zero result means your expression is wrong — fix it. + • distinct_count – row / distinct / null counts plus is_unique \ +and is_complete flags. 
Use to confirm a candidate canonical-ID column is \ +actually unique and complete. + • execute_sql – escape hatch for any check the four tools above \ +do not cover. Use sparingly — prefer the focused tools. + • submit_source_model – TERMINAL. Call exactly once, when the \ +SourceModel is complete and you are ready to hand off to the Generator. + +WORKFLOW +1. Call get_ontology AND get_metadata first to see what needs mapping and \ +what data is available. +2. Call get_documents_context to pick up any pre-loaded domain documents — \ +they often disambiguate column semantics. +3. For each table, decide which ontology class(es) it could realise — these \ +become table_roles[].ontology_class_candidates with a confidence and a one- \ +sentence reason. +4. For each ontology class, decide which column serves as its canonical \ +identifier in each table — record under canonical_ids[]. When you are \ +uncertain, run distinct_count to confirm uniqueness/completeness. +5. For each pair of tables that should join (intra-trust FK or cross-source \ +value match), run column_value_overlap and only record join_keys[] when the \ +realised overlap_pct supports it. Use kind="same_trust_fk" for FK joins and \ +kind="cross_source_value_match" for value-matched joins across sources. \ +For any class mapped to 2+ tables, follow CANONICAL-KEY NORMALIZATION below \ +and PROVE the chosen keys overlap with normalized_value_overlap. +6. Build mapping_plan.entity_order so that BASE classes come first \ +(i.e. classes that are referenced by other classes through object properties \ +should be mapped before their referencers). Build \ +mapping_plan.relationship_order so that, by the time each relationship is \ +attempted, BOTH its domain and range classes have already appeared in \ +entity_order. List anything you cannot reasonably map under \ +mapping_plan.skip[] with a short reason. +7. Finally, call submit_source_model exactly once with the full JSON. 
The \ +call returns success=true when the model is structurally valid; if it \ +returns success=false, fix the indicated problem and call it again. + +CANONICAL-KEY NORMALIZATION (CRITICAL — this is the #1 cause of relationship dangling) +For any class whose canonical_id lists MORE THAN ONE table, run \ +column_value_overlap on a representative column pair to see whether the raw \ +values already share a format: + + • If overlap_pct > 0 → values are in compatible formats. Record bare \ +column names in canonical_column_per_table (e.g. ``"MOTHER_NHS_NO"``). \ +A UNION across the tables produces a coherent ID universe. + + • If overlap_pct == 0 → DO NOT conclude these are "different" or \ +"trust-scoped" entities. When two tables both map to the SAME ontology \ +class, 0% overlap almost always means the SAME real-world key wrapped in \ +DIFFERENT trust-local encodings (prefixes, suffixes, embedded sub-IDs). \ +Leaving them disjoint makes every relationship pointing at this class 100% \ +dangle — that is a FAILURE, not an acceptable outcome. You MUST normalize: + + STEP 1 — sample_table BOTH columns and read the raw values. Look for a \ +shared embedded substring across the trusts — a stable inner identifier \ +(UUID, NHS number, ``...-preg-`` core) that appears in every trust's \ +value with only the surrounding prefix/suffix differing. + + STEP 2 — write ONE scalar SQL expression PER TABLE that strips the \ +trust-specific wrapping and exposes that shared core in an identical form. \ +Prefer extracting the shared core over stripping a single known prefix \ +(extraction is robust to multiple prefixes). 
When matching a hex/UUID core, \ +ALWAYS anchor the regex with a leading character class so a preceding dash \ +is not captured: + ✗ WRONG: regexp_extract(EPISODE_ID, '([a-f0-9-]+-preg-[0-9]+)', 1) + → returns "--preg-1" (leading dash) — will NOT match + ✓ RIGHT: regexp_extract(EPISODE_ID, '([a-f0-9][a-f0-9-]+-preg-[0-9]+)', 1) + → returns "-preg-1" + + STEP 3 — for a DERIVED / child key (e.g. a Delivery, Baby or Apgar that \ +hangs off a pregnancy), DO NOT concatenate a suffix onto the RAW prefixed \ +local id — that re-introduces the trust prefix and the keys stay disjoint. \ +Extract the shared core FIRST, then append the role suffix, so every trust \ +yields the identical synthetic key: + ✗ WRONG: trust_a "CONCAT(EPISODE_ID, '-del')" (→ STA--preg-1-del) + trust_b "delivery_id" (→ BUH-DEL-BUH--preg-1) + ✓ RIGHT: trust_a "CONCAT(regexp_extract(EPISODE_ID, '([a-f0-9][a-f0-9-]+-preg-[0-9]+)', 1), '-del')" + trust_b "CONCAT(regexp_extract(delivery_id, '([a-f0-9][a-f0-9-]+-preg-[0-9]+)', 1), '-del')" + (both → -preg-1-del) + + STEP 4 — PROVE IT. Call normalized_value_overlap with your two \ +expressions. It MUST return overlap_pct > 0. If it is still 0, your \ +expressions land in different value spaces — go back to STEP 1 and fix them. \ +Do NOT call submit_source_model with an unverified normalization. + + (If, after sampling, a table genuinely cannot expose the shared core at \ +all, omit that table from canonical_column_per_table and note why — but this \ +is rare; exhaust STEP 1–4 first.) + + • Whatever expression you record, the EntityGenerator drops it verbatim \ +into the SELECT aliased AS ID. Bare column names and SQL expressions are \ +both valid here. + + • Always update format_note to one sentence describing what the canonical \ +key looks like (e.g. ``"-preg- core extracted from each \ +trust's local pregnancy id"``). Downstream agents read this. 
+ +SOURCEMODEL JSON SCHEMA (these key names are LOAD-BEARING — do not improvise) +The `model` argument to submit_source_model has exactly this shape: + +{ + "table_roles": [ + { + "table": "", // STRING — required key is "table" + "ontology_class_candidates": [ + {"uri": "", "confidence": 0.0, "reason": ""} + ] + } + ], + "canonical_ids": [ + { + "ontology_class": "", // STRING — required key is "ontology_class" + // VALUES may be either a bare column name OR a SQL expression that + // produces the canonical key for that table. Use a SQL expression + // when raw column values across the listed tables are in different + // formats (see CANONICAL-KEY NORMALIZATION below). + "canonical_column_per_table": {"": ""}, + "format_note": "" + } + ], + "join_keys": [ + { + "from_ref": "
.", // STRING — required key is "from_ref" + "to_ref": "
.", // STRING — required key is "to_ref" + "confidence": 0.0, + "overlap_pct": 0.0, + "kind": "same_trust_fk" // or "cross_source_value_match" + } + ], + "mapping_plan": { + "entity_order": ["", "..."], + "relationship_order": ["", "..."], + "skip": [ + {"item": "", "reason": ""} // required keys: "item", "reason" + ] + } +} + +Key-name traps to avoid: +• Use "table" (not "name", "table_name", "uri") in each table_roles[] entry. +• Use "ontology_class" (not "class", "uri") in each canonical_ids[] entry. +• Use "from_ref" / "to_ref" (not "from" / "to" / "source" / "target") in each join_keys[] entry. +• Use "item" (not "uri", "property") in each mapping_plan.skip[] entry. + +INVARIANTS (the orchestrator will enforce these) +• Every URI in entity_order MUST exist in the ontology AND have at least one \ +candidate in table_roles[].ontology_class_candidates. +• Every URI in relationship_order MUST reference a property whose domain \ +class and range class both appear in entity_order at an EARLIER position. +• All confidence values are floats in [0.0, 1.0]. +• kind on each join_key is EXACTLY one of: "same_trust_fk", \ +"cross_source_value_match". +• Call submit_source_model EXACTLY ONCE, at the end. Do not emit a free-text \ +summary afterwards — submit_source_model is the terminal step. + +GENERAL RULES +• Prefer the focused tools (sample_table, column_value_overlap, \ +normalized_value_overlap, distinct_count) over execute_sql. +• Validate candidate join keys with column_value_overlap before adding them \ +to join_keys[]. +• You may batch multiple independent tool calls in a single response. +• Only ever pass row-returning queries (SELECT / WITH …) to execute_sql. 
+""" + + +# ===================================================== +# Internal helpers +# ===================================================== + + +def _build_user_prompt( + entities: List[dict], relationships: List[dict], n_tables: int +) -> str: + parts = [ + ( + f"Build a SourceModel for {n_tables} table(s), {len(entities)} ontology " + f"entity/entities, and {len(relationships)} relationship(s). " + "Start by calling get_ontology, get_metadata, and get_documents_context." + ) + ] + if entities: + names = ", ".join(e.get("name", "?") for e in entities) + parts.append(f"Entities in scope: {names}") + if relationships: + names = ", ".join(r.get("name", "?") for r in relationships) + parts.append(f"Relationships in scope: {names}") + prompt = "\n".join(parts) + logger.debug("_build_user_prompt (%d chars):\n%s", len(prompt), prompt) + return prompt + + +# ===================================================== +# Public entry point +# ===================================================== + + +@trace_agent(name="mapping_pge_planner") +def run_planner( + host: str, + token: str, + endpoint_name: str, + client: Any, + metadata: dict, + ontology: dict, + *, + documents: Optional[list] = None, + on_step: Optional[Callable[[str, int], None]] = None, + max_iterations: int = MAX_ITERATIONS, +) -> PlannerResult: + """Run the Planner agent. + + The Planner autonomously produces a :class:`SourceModel` by exploring the + ontology, metadata, documents, and source data via tool calls. It + terminates as soon as it submits a structurally-valid SourceModel via the + terminal ``submit_source_model`` tool. + + Args: + host: Databricks workspace URL. + token: Bearer token for the serving endpoint. + endpoint_name: Foundation Model serving endpoint name. + client: Databricks SQL client (must expose ``execute_query(sql)``). + metadata: Imported domain metadata (``{"tables": [...]}``). + ontology: Imported ontology (``{"entities": [...], "relationships": [...]}``). 
+ documents: Optional pre-loaded domain documents. + on_step: Optional progress callback ``(msg, pct)`` for UI updates. + max_iterations: Upper bound on tool-call iterations (default 25). + + Returns: + A :class:`PlannerResult`. ``success`` is True iff a SourceModel was + successfully submitted; in that case ``source_model`` holds the + validated dataclass. On failure, ``error`` explains why and + ``source_model`` is None. + """ + iteration_limit = max_iterations if max_iterations is not None else MAX_ITERATIONS + + entities = (ontology or {}).get("entities", []) + relationships = (ontology or {}).get("relationships", []) + n_tables = len((metadata or {}).get("tables", [])) + + logger.info( + "===== PLANNER START ===== endpoint=%s, tables=%d, entities=%d, relationships=%d, max_iter=%d", + endpoint_name, + n_tables, + len(entities), + len(relationships), + iteration_limit, + ) + + ctx = ToolContext( + host=host.rstrip("/"), + token=token, + client=client, + metadata=metadata or {}, + ontology=ontology or {}, + documents=list(documents or []), + ) + + result = PlannerResult(success=False) + + user_prompt = _build_user_prompt(entities, relationships, n_tables) + messages: List[dict] = [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": user_prompt}, + ] + logger.info( + "Planner conversation initialized: system=%d chars, user=%d chars", + len(SYSTEM_PROMPT), + len(user_prompt), + ) + + total_usage: Dict[str, int] = {"prompt_tokens": 0, "completion_tokens": 0} + + def _progress_pct(iteration_idx: int) -> int: + # Linear ramp from 5 → 95 across the iteration budget. The terminal + # submit_source_model call is what sets 100. 
+ ratio = (iteration_idx + 1) / max(iteration_limit, 1) + return min(5 + int(ratio * 90), 95) + + def notify(msg: str, *, pct: Optional[int] = None): + actual_pct = pct if pct is not None else 5 + logger.info("PLANNER STEP [%d%%] %s", actual_pct, msg) + if on_step: + on_step(msg, actual_pct) + + notify("Starting planner…", pct=1) + + # ------------------------------------------------------------------ + # Agent loop + # ------------------------------------------------------------------ + for iteration in range(iteration_limit): + # Rate-limit mitigation — same 3s delay as the legacy mapping agent. + if iteration > 0: + logger.debug( + "Iteration %d: waiting %ds before LLM call (rate limit mitigation)", + iteration + 1, + _ITERATION_DELAY_SEC, + ) + time.sleep(_ITERATION_DELAY_SEC) + + current_iteration = iteration + 1 + pct = _progress_pct(iteration) + logger.info( + "----- Planner iteration %d/%d — %d messages, source_model=%s -----", + current_iteration, + iteration_limit, + len(messages), + "set" if ctx.source_model is not None else "unset", + ) + notify(f"Planning iteration {current_iteration}/{iteration_limit}…", pct=pct) + + t0 = time.time() + try: + llm_response = call_serving_endpoint( + host, + token, + endpoint_name, + messages, + tools=TOOL_DEFINITIONS, + max_tokens=_MAX_TOKENS, + temperature=0.1, + timeout=LLM_TIMEOUT, + trace_name=_TRACE_NAME, + ) + except requests.exceptions.HTTPError as exc: + status = exc.response.status_code if exc.response is not None else "?" + logger.warning( + "Planner iteration %d: HTTPError status=%s", current_iteration, status + ) + logger.debug( + "Planner iteration %d: HTTPError body: %.500s", + current_iteration, + exc.response.text if exc.response is not None else "N/A", + ) + # Tools are non-negotiable for the Planner — no single-shot fallback. 
+ if exc.response is not None and status in (400, 422): + result.error = "LLM endpoint does not support function calling" + result.iterations = current_iteration + result.usage = total_usage + logger.error( + "Planner: endpoint refused tools — cannot produce a SourceModel" + ) + return result + result.error = f"LLM request failed: {exc}" + result.iterations = current_iteration + result.usage = total_usage + logger.error( + "Planner: LLM request failed at iteration %d: %s", + current_iteration, + exc, + ) + return result + except requests.exceptions.ReadTimeout: + result.error = f"LLM request timed out after {LLM_TIMEOUT}s" + result.iterations = current_iteration + result.usage = total_usage + logger.error("Planner: timeout at iteration %d", current_iteration) + return result + except requests.exceptions.RequestException as exc: + result.error = f"LLM request failed: {exc}" + result.iterations = current_iteration + result.usage = total_usage + logger.error( + "Planner: request exception at iteration %d: %s", + current_iteration, + exc, + ) + return result + + elapsed_ms = int((time.time() - t0) * 1000) + logger.info( + "Planner iteration %d: LLM responded in %dms", current_iteration, elapsed_ms + ) + + accumulate_usage(total_usage, llm_response.get("usage", {})) + + choice = llm_response.get("choices", [{}])[0] + finish_reason = choice.get("finish_reason", "?") + message = choice.get("message", {}) + tool_calls = message.get("tool_calls", []) + has_content = bool(message.get("content")) + logger.info( + "Planner iteration %d: finish_reason=%s, tool_calls=%d, has_content=%s", + current_iteration, + finish_reason, + len(tool_calls), + has_content, + ) + # A tool call truncated by the max_tokens ceiling produces malformed + # arguments and the tool can't recover. Flag it loudly so future runs + # don't silently waste iterations resubmitting the same broken JSON. 
+ if finish_reason == "length" and tool_calls: + logger.error( + "Planner iteration %d: finish_reason=length on a tool call — " + "arguments were likely truncated. Consider bumping max_tokens.", + current_iteration, + ) + + if not tool_calls: + # The Planner must end with submit_source_model, not free text. + # If we see text without a terminal call, that's a failure. + content = (message.get("content") or "")[:500] + logger.warning( + "Planner iteration %d: produced text without submitting source model — %d chars", + current_iteration, + len(message.get("content") or ""), + ) + result.steps.append( + PlannerStep( + step_type="output", + content=content, + duration_ms=elapsed_ms, + ) + ) + result.error = "planner produced text without submitting source model" + result.iterations = current_iteration + result.usage = total_usage + notify("Planner produced text without submitting source model.", pct=pct) + return result + + # Tool-call branch — dispatch each call and accumulate steps. + logger.info( + "Planner iteration %d: processing %d tool call(s): [%s]", + current_iteration, + len(tool_calls), + ", ".join( + tc.get("function", {}).get("name", "?") for tc in tool_calls + ), + ) + messages.append(message) + + terminal_success = False + for tc_idx, tc in enumerate(tool_calls, 1): + func = tc.get("function", {}) + tool_name = func.get("name", "") + raw_args = func.get("arguments", "{}") + tool_id = tc.get("id", "") + + try: + arguments = json.loads(raw_args) + except json.JSONDecodeError: + arguments = {} + + logger.info( + "Planner iteration %d: calling tool '%s' (%d/%d)", + current_iteration, + tool_name, + tc_idx, + len(tool_calls), + ) + + # Human-readable progress messages per tool — same pattern as + # the legacy mapping agent for UI consistency. 
+ if tool_name == "submit_source_model": + notify("Submitting source model…", pct=pct) + elif tool_name == "get_metadata": + notify("Retrieving table metadata…", pct=pct) + elif tool_name == "get_ontology": + notify("Retrieving ontology…", pct=pct) + elif tool_name == "get_documents_context": + notify("Retrieving documents…", pct=pct) + elif tool_name == "sample_table": + fn = arguments.get("full_name", "?") + notify(f"Sampling {fn}…", pct=pct) + elif tool_name == "column_value_overlap": + notify("Checking column overlap…", pct=pct) + elif tool_name == "normalized_value_overlap": + notify("Verifying canonical-key normalization…", pct=pct) + elif tool_name == "distinct_count": + notify("Checking distinct count…", pct=pct) + elif tool_name == "execute_sql": + sql_preview = arguments.get("sql", "")[:80] + notify(f"Running SQL: {sql_preview}…", pct=pct) + else: + notify(f"Calling {tool_name}…", pct=pct) + + result.steps.append( + PlannerStep( + step_type="tool_call", + content=json.dumps(arguments, default=str)[:500], + tool_name=tool_name, + ) + ) + + t1 = time.time() + tool_result = dispatch_tool( + TOOL_HANDLERS, ctx, tool_name, arguments, trace_name=_TRACE_NAME + ) + tool_ms = int((time.time() - t1) * 1000) + + logger.info( + "Planner iteration %d: tool '%s' returned %d chars in %dms", + current_iteration, + tool_name, + len(tool_result), + tool_ms, + ) + + result.steps.append( + PlannerStep( + step_type="tool_result", + content=( + (tool_result[:500] + "…") + if len(tool_result) > 500 + else tool_result + ), + tool_name=tool_name, + duration_ms=tool_ms, + ) + ) + + messages.append( + { + "role": "tool", + "tool_call_id": tool_id, + "content": tool_result, + } + ) + + # Detect terminal success: submit_source_model returned success=True + # *and* stamped a SourceModel onto the context. We break *after* + # appending the tool result so the orchestrator sees a complete + # message trail in conversation/replay. 
+ if tool_name == "submit_source_model": + try: + parsed = json.loads(tool_result) + except json.JSONDecodeError: + parsed = {} + if parsed.get("success") is True and ctx.source_model is not None: + terminal_success = True + logger.info( + "Planner iteration %d: submit_source_model succeeded — terminating", + current_iteration, + ) + + if terminal_success: + result.success = True + result.source_model = ctx.source_model + result.iterations = current_iteration + result.usage = total_usage + logger.info( + "===== PLANNER COMPLETE ===== iterations=%d, " + "prompt_tokens=%d, completion_tokens=%d", + result.iterations, + total_usage["prompt_tokens"], + total_usage["completion_tokens"], + ) + notify("Planner completed!", pct=100) + return result + + # Exhausted the iteration budget without ever calling submit_source_model + # successfully (or the LLM kept calling other tools forever). + result.iterations = iteration_limit + result.usage = total_usage + result.error = "planner exhausted iteration budget without submitting source model" + logger.error("===== PLANNER FAILED ===== %s", result.error) + notify(result.error, pct=95) + return result diff --git a/src/agents/tools/context.py b/src/agents/tools/context.py index 3b88df82..6edf3919 100644 --- a/src/agents/tools/context.py +++ b/src/agents/tools/context.py @@ -6,7 +6,10 @@ """ from dataclasses import dataclass, field -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional + +if TYPE_CHECKING: + from agents.agent_mapping_pge.contracts import EvalReport, SourceModel @dataclass @@ -53,3 +56,13 @@ class ToolContext: dtwin_registry_params: dict = field(default_factory=dict) dtwin_domain_name: str = "" dtwin_ontology_labels: dict = field(default_factory=dict) # uri/name → display label + + # Mapping PGE planner output (``agent_mapping_pge``) – populated by the + # ``submit_source_model`` terminal tool. 
Forward-ref string typing avoids a + # circular import between ``agents.tools`` and ``agents.agent_mapping_pge``. + source_model: Optional["SourceModel"] = None + + # Mapping PGE semantic critic output (``agent_mapping_pge``) – populated by + # the ``submit_evaluation`` terminal tool of the Sprint 6 Critic agent. + # Same forward-ref pattern as ``source_model`` to avoid a circular import. + semantic_eval_report: Optional["EvalReport"] = None diff --git a/src/agents/tools/evaluation.py b/src/agents/tools/evaluation.py new file mode 100644 index 00000000..e4f62f94 --- /dev/null +++ b/src/agents/tools/evaluation.py @@ -0,0 +1,205 @@ +"""Terminal tool for the mapping-PGE Semantic Critic (Sprint 6). + +The Critic audits ONE submitted mapping for semantic correctness after the +deterministic (stage-1) evaluator has already passed. It submits its verdict +through ``submit_evaluation`` — the terminal tool defined here — which +constructs an :class:`EvalReport` (stage="semantic") and stamps it onto +``ctx.semantic_eval_report``. + +This module deliberately mirrors the shape of the other terminal tools +(``submit_source_model``, ``submit_entity_mapping``, …) — pure-Python handler +with a JSON-schema definition for OpenAI function calling, exported via +``EVALUATION_TOOL_DEFINITIONS`` / ``EVALUATION_TOOL_HANDLERS`` aggregates. +""" + +import json +from typing import Callable, Dict, List, Optional + +from back.core.logging import get_logger +from agents.tools.context import ToolContext + +logger = get_logger(__name__) + + +# ===================================================== +# OpenAI function-calling definition +# ===================================================== + +SUBMIT_EVALUATION_DEF: dict = { + "type": "function", + "function": { + "name": "submit_evaluation", + "description": ( + "Submit the final semantic evaluation. Terminal tool — call exactly once " + "when you have a confident verdict. status MUST be 'PASS' or 'FAIL'. 
" + "If failing, populate failures[] with at least one entry. " + "Set bubble_to_planner=true ONLY when the wrong TABLE was chosen " + "(not just a wrong column within the right table)." + ), + "parameters": { + "type": "object", + "properties": { + "status": {"type": "string", "enum": ["PASS", "FAIL"]}, + "failures": { + "type": "array", + "items": { + "type": "object", + "properties": { + "check": {"type": "string"}, + "expected": {"type": "string"}, + "observed": {"type": "string"}, + "hint": {"type": "string"}, + }, + "required": ["check", "expected", "observed", "hint"], + }, + "description": "Empty when status is PASS.", + }, + "bubble_to_planner": {"type": "boolean"}, + "reasoning": { + "type": "string", + "description": "One-paragraph summary of the audit reasoning.", + }, + }, + "required": ["status"], + }, + }, +} + + +# ===================================================== +# Handler +# ===================================================== + + +def tool_submit_evaluation( + ctx: ToolContext, + *, + status: str = "", + failures: Optional[list] = None, + bubble_to_planner: bool = False, + reasoning: str = "", + **_kwargs, +) -> str: + """Construct an EvalReport from the critic's submission and store on ctx. + + Contract: + * ``status`` MUST be one of ``"PASS"`` or ``"FAIL"`` — anything else is + rejected as a JSON error so the agent loop can coach the LLM and + continue (it does NOT terminate the loop). + * On ``FAIL`` with an empty ``failures`` list, a generic + ``semantic_audit`` failure is synthesised so the resulting report is + coherent (status=FAIL <=> failures non-empty, matching + :func:`evaluator.report.build_report` semantics). + * ``bubble_to_planner=True`` is demoted to False when status is PASS — + same invariant the deterministic evaluator's :func:`build_report` + enforces (a passing evaluation should not escalate). 
+ """ + logger.info( + "tool_submit_evaluation: status=%s, failures=%d, bubble=%s, reasoning=%d chars", + status, + len(failures or []), + bubble_to_planner, + len(reasoning or ""), + ) + + if status not in ("PASS", "FAIL"): + logger.warning("tool_submit_evaluation: invalid status=%r", status) + return json.dumps( + { + "success": False, + "error": f"invalid status: {status!r} (must be PASS or FAIL)", + } + ) + + # Lazy import — these contracts live in agent_mapping_pge and importing + # them at module load time would create a cycle through + # ``agents.tools.context``. + from agents.agent_mapping_pge.contracts import EvalFailure, EvalReport + + eval_failures: List[EvalFailure] = [] + for f in failures or []: + if not isinstance(f, dict): + continue + eval_failures.append( + EvalFailure( + kind="semantic", + check=str(f.get("check") or ""), + expected=str(f.get("expected") or ""), + observed=str(f.get("observed") or ""), + hint=str(f.get("hint") or ""), + ) + ) + + # status=PASS <=> failures empty. If the LLM submitted both, clamp the + # failures list and warn — keeping a passing report internally coherent. + if status == "PASS" and eval_failures: + logger.warning( + "tool_submit_evaluation: status=PASS with %d failures — clamping to []", + len(eval_failures), + ) + eval_failures = [] + + # If status=FAIL but no failures, synthesise a generic one so the report + # is coherent (status=FAIL <=> failures non-empty). + if status == "FAIL" and not eval_failures: + logger.debug( + "tool_submit_evaluation: synthesising semantic_audit failure for " + "FAIL with no failures[]" + ) + eval_failures.append( + EvalFailure( + kind="semantic", + check="semantic_audit", + expected="PASS", + observed="FAIL", + hint=reasoning or "critic returned FAIL without specific failures", + ) + ) + + # If status=PASS but bubble flag is True, demote — matches + # ``build_report``'s behaviour and the documented invariant: a passing + # evaluation does not escalate to the Planner. 
+ if status == "PASS" and bubble_to_planner: + logger.warning( + "tool_submit_evaluation: bubble_to_planner=True with status=PASS — " + "demoting to False" + ) + bubble_to_planner = False + + metrics: Dict[str, str] = {"reasoning": reasoning} if reasoning else {} + + report = EvalReport( + status=status, + stage="semantic", + metrics=metrics, + failures=eval_failures, + bubble_to_planner=bool(bubble_to_planner), + ) + ctx.semantic_eval_report = report + + logger.info( + "tool_submit_evaluation: stored EvalReport status=%s, failures=%d, bubble=%s", + report.status, + len(report.failures), + report.bubble_to_planner, + ) + + return json.dumps( + { + "success": True, + "status": status, + "failures": len(eval_failures), + "bubble_to_planner": report.bubble_to_planner, + } + ) + + +# ===================================================== +# Aggregates +# ===================================================== + +EVALUATION_TOOL_DEFINITIONS: List[dict] = [SUBMIT_EVALUATION_DEF] + +EVALUATION_TOOL_HANDLERS: Dict[str, Callable] = { + "submit_evaluation": tool_submit_evaluation, +} diff --git a/src/agents/tools/mapping.py b/src/agents/tools/mapping.py index f2eaa19a..87dbecc0 100644 --- a/src/agents/tools/mapping.py +++ b/src/agents/tools/mapping.py @@ -28,9 +28,20 @@ def tool_submit_entity_mapping( id_column: str = "", label_column: str = "", attribute_mappings: Optional[dict] = None, + unmapped_attributes: Optional[list] = None, **_kwargs, ) -> str: - """Record a completed entity mapping.""" + """Record a completed entity mapping. + + ``unmapped_attributes`` lets the Generator stage declare ontology attributes + it intentionally did not map to a column, with a one-sentence ``reason``. + Items may be either bare strings (attribute name only) or dicts of shape + ``{"name": str, "reason": str}`` — the richer dict form is preferred for + downstream consumption but bare strings round-trip too. Anything else is + coerced to a string for safety. 
This enforces the PGE "no silent drops" + invariant: every ontology attribute is either in ``attribute_mappings`` or + in ``unmapped_attributes``. + """ logger.info("tool_submit_entity_mapping: '%s' (uri=%s)", class_name, class_uri) if not class_uri or not sql_query: logger.warning("tool_submit_entity_mapping: missing required fields") @@ -42,6 +53,22 @@ def tool_submit_entity_mapping( .rstrip(";") ) + # Normalise ``unmapped_attributes`` — accept either form, persist as-is for + # dicts, leave bare strings as strings (validation/coverage is downstream). + normalised_unmapped: List = [] + for item in unmapped_attributes or []: + if isinstance(item, dict) and "name" in item: + normalised_unmapped.append( + { + "name": str(item.get("name", "")), + "reason": str(item.get("reason", "")), + } + ) + elif isinstance(item, str): + normalised_unmapped.append(item) + else: + normalised_unmapped.append(str(item)) + mapping = { "ontology_class": class_uri, "class_name": class_name, @@ -49,6 +76,7 @@ def tool_submit_entity_mapping( "id_column": id_column, "label_column": label_column, "attribute_mappings": attribute_mappings or {}, + "unmapped_attributes": normalised_unmapped, } logger.debug( @@ -78,12 +106,14 @@ def tool_submit_entity_mapping( logger.debug("tool_submit_entity_mapping: appended new mapping") mapped_attrs = len(mapping["attribute_mappings"]) + unmapped_count = len(mapping["unmapped_attributes"]) logger.info( - "tool_submit_entity_mapping: '%s' recorded — ID=%s, Label=%s, %d attr(s) mapped", + "tool_submit_entity_mapping: '%s' recorded — ID=%s, Label=%s, %d attr(s) mapped, %d unmapped", class_name, id_column, label_column, mapped_attrs, + unmapped_count, ) return json.dumps( { @@ -92,6 +122,7 @@ def tool_submit_entity_mapping( "id_column": id_column, "label_column": label_column, "attributes_mapped": mapped_attrs, + "attributes_unmapped": unmapped_count, "total_entity_mappings": len(ctx.entity_mappings), } ) @@ -243,6 +274,30 @@ def _extract_label(value: str) -> 
str: ), "additionalProperties": {"type": "string"}, }, + "unmapped_attributes": { + "type": "array", + "description": ( + "Ontology attributes you intentionally did NOT map to a column, " + "each with a one-sentence reason. Use this to satisfy the " + 'no-silent-drops invariant. Preferred shape: ' + '[{"name": "apgarScore", "reason": "absent from source table"}]. ' + "Bare strings are also accepted but discouraged." + ), + "items": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "Ontology attribute name.", + }, + "reason": { + "type": "string", + "description": "Why this attribute was not mapped.", + }, + }, + "required": ["name", "reason"], + }, + }, }, "required": [ "class_uri", @@ -279,11 +334,19 @@ def _extract_label(value: str) -> str: }, "source_id_column": { "type": "string", - "description": "Column name for the source entity identifier.", + "description": ( + "The output-column alias in sql_query that holds the " + "source id — alias it AS source_id and pass " + '"source_id" here (NOT the entity\'s id_column).' + ), }, "target_id_column": { "type": "string", - "description": "Column name for the target entity identifier.", + "description": ( + "The output-column alias in sql_query that holds the " + "target id — alias it AS target_id and pass " + '"target_id" here (NOT the entity\'s id_column).' + ), }, "domain": { "type": "string", @@ -317,3 +380,10 @@ def _extract_label(value: str) -> str: "submit_entity_mapping": tool_submit_entity_mapping, "submit_relationship_mapping": tool_submit_relationship_mapping, } + +# Name-indexed view of MAPPING_TOOL_DEFINITIONS so callers needing a single +# definition (e.g. the EntityGenerator, which only exposes the entity submit +# tool) can look it up by name without re-scanning the list. 
+MAPPING_TOOL_DEFINITIONS_BY_NAME: Dict[str, dict] = { + d["function"]["name"]: d for d in MAPPING_TOOL_DEFINITIONS +} diff --git a/src/agents/tools/planner.py b/src/agents/tools/planner.py new file mode 100644 index 00000000..a1de9af9 --- /dev/null +++ b/src/agents/tools/planner.py @@ -0,0 +1,707 @@ +""" +Planner tools – used by the mapping-PGE Planner agent (Sprint 2+). + +Exposes the OpenAI function-calling tools that let the Planner LLM probe +source tables and submit a validated ``SourceModel`` artefact: + +* ``sample_table`` — N random rows from a table (n capped at 100). +* ``column_value_overlap`` — one-sided distinct-value overlap between two columns. +* ``normalized_value_overlap`` — same metric, but each side is a scalar SQL + expression, so canonical-key normalizations can be proven before commit. +* ``distinct_count`` — uniqueness / completeness of a candidate canonical id. +* ``submit_source_model`` — terminal tool: validates the candidate SourceModel + JSON against :class:`agents.agent_mapping_pge.contracts.SourceModel` and stores + the dataclass instance on :attr:`ToolContext.source_model`. + +All handlers return JSON strings (same convention as ``agents.tools.sql``) +and stringify scalar values for the LLM-facing surface. +""" + +import json +import re +from typing import Any, Callable, Dict, List, Optional, Tuple + +from back.core.logging import get_logger +from agents.tools.context import ToolContext + +logger = get_logger(__name__) + + +# Cap on ``n`` in ``sample_table`` to keep the LLM context bounded. +_SAMPLE_TABLE_MAX_N = 100 +_SAMPLE_TABLE_DEFAULT_N = 20 + + +# Permissive but injection-safe SQL identifier shape. We allow dots (for +# fully-qualified ``catalog.schema.table``) and backticks (for quoted +# identifiers), plus the usual alphanumerics + underscore. Anything else +# — semicolons, whitespace, quotes, comment markers — is rejected. 
# One identifier part is either a bare run of [A-Za-z0-9_] or a fully
# backtick-quoted run (which may itself contain dots). A full identifier is
# one or more parts joined by single dots. Requiring *balanced* backticks
# matters: an odd backtick would open a quoted region that swallows the SQL
# template text that follows the interpolated name.
_IDENTIFIER_PART = r"(?:[A-Za-z0-9_]+|`[A-Za-z0-9_.]+`)"
_IDENTIFIER_RE = re.compile(rf"{_IDENTIFIER_PART}(?:\.{_IDENTIFIER_PART})*\Z")


# SQL keywords whose presence in a "normalization expression" indicates the
# string is no longer a scalar expression but a smuggled clause / subquery /
# DDL. A legitimate canonical-key expression (regexp_extract, regexp_replace,
# concat, substring, lower, upper, trim, coalesce, ||, string literals) needs
# none of these. Matched case-insensitively as whole words.
_EXPR_FORBIDDEN_WORDS = frozenset(
    {
        "select",
        "from",
        "where",
        "join",
        "union",
        "intersect",
        "except",
        "insert",
        "update",
        "delete",
        "drop",
        "alter",
        "create",
        "grant",
        "revoke",
        "table",
        "into",
        "exec",
        "execute",
        "call",
        "merge",
        "values",
        "having",
        "group",
        "order",
        "limit",
    }
)
_EXPR_WORD_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")

# Substrings that terminate a statement or start/end a SQL comment.
_EXPR_FORBIDDEN_MARKERS = (";", "--", "/*", "*/")


def _validate_safe_expression(expr: str, *, role: str) -> Optional[str]:
    """Return None if ``expr`` is a safe scalar SQL expression; else an error.

    Unlike :func:`_validate_identifier`, this permits the parentheses, commas,
    quotes and operators a canonical-key normalization needs (for example a
    ``regexp_extract(...)`` or ``concat(...)`` call). The expression is still
    interpolated into SQL by the callers, so three things are rejected:

    * a value that is not a string, or is blank;
    * statement terminators and comment markers (``;``, ``--``, ``/*``, ``*/``);
    * any word in ``_EXPR_FORBIDDEN_WORDS`` (case-insensitive). The scan is
      deliberately conservative: it also looks inside string literals.

    Args:
        expr: Candidate scalar expression.
        role: Argument name used in the error message (e.g. ``"from_expr"``).

    Returns:
        ``None`` when accepted, otherwise a human-readable error string.
    """
    if not isinstance(expr, str) or not expr.strip():
        return f"invalid {role}: must be a non-empty string"
    if any(marker in expr for marker in _EXPR_FORBIDDEN_MARKERS):
        return (
            f"invalid {role}: must not contain ';' or SQL comment markers "
            f"(got {expr!r})"
        )
    words = {w.lower() for w in _EXPR_WORD_RE.findall(expr)}
    bad = sorted(words & _EXPR_FORBIDDEN_WORDS)
    if bad:
        return (
            f"invalid {role}: a canonical-key expression must be a single scalar "
            f"expression, not a clause/subquery. Forbidden keyword(s): "
            f"{', '.join(bad)} (got {expr!r})"
        )
    return None


def _validate_identifier(name: str, *, role: str) -> Optional[str]:
    """Return None if ``name`` is a valid SQL identifier; else an error message.

    Used to gate identifiers that get interpolated into SQL via f-strings.
    Accepted shapes are dot-separated parts (``catalog.schema.table``,
    ``col``), where each part is bare alphanumerics/underscore or is wrapped
    in a matching pair of backticks. Whitespace, semicolons, quotes, comment
    markers, empty segments (``a..b``) and unbalanced backticks are rejected,
    so a hallucinated value such as ``t; DROP TABLE x`` never reaches SQL.

    Args:
        name: Candidate table or column identifier.
        role: Argument name used in the error message (e.g. ``"full_name"``).

    Returns:
        ``None`` when accepted, otherwise ``"invalid <role>: <repr(name)>"``.
    """
    if not isinstance(name, str) or not _IDENTIFIER_RE.fullmatch(name):
        return f"invalid {role}: {name!r}"
    return None
+ """ + try: + result = ctx.client.execute_query(sql) + return result, None + except Exception as exc: + logger.error( + "%s: query failed: %s\nSQL: %s", tool_name, exc, sql, exc_info=True + ) + return None, str(exc) + + +# ===================================================== +# Tool implementations +# ===================================================== + + +def tool_sample_table( + ctx: ToolContext, *, full_name: str = "", n: Any = _SAMPLE_TABLE_DEFAULT_N, **_kwargs +) -> str: + """Return N random sample rows from ``full_name`` so the agent can see + real values (not just column types). ``n`` is capped at 100. + """ + logger.info("tool_sample_table: full_name=%s, n=%s", full_name, n) + if not full_name: + return json.dumps({"success": False, "error": "full_name is required"}) + + err = _validate_identifier(full_name, role="full_name") + if err is not None: + return json.dumps({"success": False, "error": err}) + + # Strict ``n`` parsing: a malformed value is a tool-call error, not a + # silent fallback. The default (when ``n`` is omitted) is already the int + # ``_SAMPLE_TABLE_DEFAULT_N``, so ``int(n)`` is a no-op in that case. 
+ try: + n_int = int(n) + except (TypeError, ValueError): + return json.dumps({"success": False, "error": f"invalid n: {n!r}"}) + capped_n = max(1, min(n_int, _SAMPLE_TABLE_MAX_N)) + + sql = f"SELECT * FROM {full_name} ORDER BY RAND() LIMIT {capped_n}" + logger.debug("tool_sample_table: SQL=%s", sql) + + rows, err = _run_query(ctx, sql, tool_name="tool_sample_table") + if err is not None: + return json.dumps({"success": False, "error": err}) + + rows = rows or [] + columns: List[str] = list(rows[0].keys()) if rows else [] + stringified_rows: List[List[Optional[str]]] = [] + for row in rows: + stringified_rows.append( + [str(row[c]) if row.get(c) is not None else None for c in columns] + ) + logger.info( + "tool_sample_table: %d row(s) × %d column(s)", + len(stringified_rows), + len(columns), + ) + return json.dumps( + { + "success": True, + "columns": columns, + "rows": stringified_rows, + "row_count": len(stringified_rows), + } + ) + + +def tool_column_value_overlap( + ctx: ToolContext, + *, + from_table: str = "", + from_column: str = "", + to_table: str = "", + to_column: str = "", + **_kwargs, +) -> str: + """Compute the one-sided overlap + ``|distinct(from) ∩ distinct(to)| / |distinct(from)|``. + + The numerator dedupes ``from`` before intersecting. Returns 0.0 (and a + note) when ``from_distinct_count`` is zero to avoid division by zero. 
+ """ + logger.info( + "tool_column_value_overlap: %s.%s ↔ %s.%s", + from_table, + from_column, + to_table, + to_column, + ) + if not (from_table and from_column and to_table and to_column): + return json.dumps( + { + "success": False, + "error": "from_table, from_column, to_table, to_column are all required", + } + ) + + for value, role in ( + (from_table, "from_table"), + (from_column, "from_column"), + (to_table, "to_table"), + (to_column, "to_column"), + ): + err = _validate_identifier(value, role=role) + if err is not None: + return json.dumps({"success": False, "error": err}) + + sql = ( + "WITH from_distinct AS (" + f" SELECT DISTINCT {from_column} AS v FROM {from_table} " + f" WHERE {from_column} IS NOT NULL" + ")," + " to_distinct AS (" + f" SELECT DISTINCT {to_column} AS v FROM {to_table} " + f" WHERE {to_column} IS NOT NULL" + ")," + " inter AS (" + " SELECT v FROM from_distinct INTERSECT SELECT v FROM to_distinct" + ") " + "SELECT (SELECT COUNT(*) FROM from_distinct) AS from_distinct_count, " + " (SELECT COUNT(*) FROM to_distinct) AS to_distinct_count, " + " (SELECT COUNT(*) FROM inter) AS intersection_count" + ) + logger.debug("tool_column_value_overlap: SQL=%s", sql) + + rows, err = _run_query(ctx, sql, tool_name="tool_column_value_overlap") + if err is not None: + return json.dumps({"success": False, "error": err}) + if not rows: + return json.dumps( + {"success": False, "error": "overlap query returned no rows"} + ) + + row = rows[0] + from_distinct = int(row.get("from_distinct_count", 0) or 0) + to_distinct = int(row.get("to_distinct_count", 0) or 0) + intersection = int(row.get("intersection_count", 0) or 0) + + if from_distinct == 0: + result: Dict[str, Any] = { + "success": True, + "overlap_pct": 0.0, + "from_distinct_count": 0, + "to_distinct_count": to_distinct, + "intersection_count": 0, + "note": ( + f"{from_table}.{from_column} has zero distinct non-null values; " + "overlap_pct defaulted to 0.0 (no division by zero)." 
+ ), + } + else: + result = { + "success": True, + "overlap_pct": intersection / from_distinct, + "from_distinct_count": from_distinct, + "to_distinct_count": to_distinct, + "intersection_count": intersection, + # Symmetric shape with the zero-denom branch: downstream consumers + # can read ``note`` unconditionally. + "note": "", + } + logger.info( + "tool_column_value_overlap: overlap_pct=%.4f (%d/%d)", + result["overlap_pct"], + intersection, + from_distinct, + ) + return json.dumps(result) + + +def tool_normalized_value_overlap( + ctx: ToolContext, + *, + from_table: str = "", + from_expr: str = "", + to_table: str = "", + to_expr: str = "", + **_kwargs, +) -> str: + """Like :func:`tool_column_value_overlap`, but each side is an arbitrary + scalar SQL *expression* rather than a bare column. + + This is the tool the Planner uses to PROVE a canonical-key normalization + works before committing it. When two tables that map to the same ontology + class have 0% raw-column overlap, the values are trust-local encodings of + the same key. The Planner proposes a normalization expression per table + (e.g. ``regexp_extract(EPISODE_ID, '([a-f0-9][a-f0-9-]+-preg-\\d+)', 1)``) + and calls this tool to confirm the expressions land in a common value + space (overlap_pct > 0). A still-zero overlap means the normalization is + wrong — fix it before submitting. 
+ """ + logger.info( + "tool_normalized_value_overlap: %s[%s] ↔ %s[%s]", + from_table, + from_expr, + to_table, + to_expr, + ) + if not (from_table and from_expr and to_table and to_expr): + return json.dumps( + { + "success": False, + "error": "from_table, from_expr, to_table, to_expr are all required", + } + ) + + for value, role in ((from_table, "from_table"), (to_table, "to_table")): + err = _validate_identifier(value, role=role) + if err is not None: + return json.dumps({"success": False, "error": err}) + for value, role in ((from_expr, "from_expr"), (to_expr, "to_expr")): + err = _validate_safe_expression(value, role=role) + if err is not None: + return json.dumps({"success": False, "error": err}) + + sql = ( + "WITH from_distinct AS (" + f" SELECT DISTINCT {from_expr} AS v FROM {from_table} " + f" WHERE {from_expr} IS NOT NULL AND {from_expr} <> ''" + ")," + " to_distinct AS (" + f" SELECT DISTINCT {to_expr} AS v FROM {to_table} " + f" WHERE {to_expr} IS NOT NULL AND {to_expr} <> ''" + ")," + " inter AS (" + " SELECT v FROM from_distinct INTERSECT SELECT v FROM to_distinct" + ") " + "SELECT (SELECT COUNT(*) FROM from_distinct) AS from_distinct_count, " + " (SELECT COUNT(*) FROM to_distinct) AS to_distinct_count, " + " (SELECT COUNT(*) FROM inter) AS intersection_count" + ) + logger.debug("tool_normalized_value_overlap: SQL=%s", sql) + + rows, err = _run_query(ctx, sql, tool_name="tool_normalized_value_overlap") + if err is not None: + return json.dumps({"success": False, "error": err}) + if not rows: + return json.dumps( + {"success": False, "error": "overlap query returned no rows"} + ) + + row = rows[0] + from_distinct = int(row.get("from_distinct_count", 0) or 0) + to_distinct = int(row.get("to_distinct_count", 0) or 0) + intersection = int(row.get("intersection_count", 0) or 0) + + if from_distinct == 0: + result: Dict[str, Any] = { + "success": True, + "overlap_pct": 0.0, + "from_distinct_count": 0, + "to_distinct_count": to_distinct, + 
"intersection_count": 0, + "note": ( + f"{from_expr} over {from_table} produced zero distinct non-empty " + "values; the expression likely does not match the data — revise it." + ), + } + else: + result = { + "success": True, + "overlap_pct": intersection / from_distinct, + "from_distinct_count": from_distinct, + "to_distinct_count": to_distinct, + "intersection_count": intersection, + "note": "", + } + logger.info( + "tool_normalized_value_overlap: overlap_pct=%.4f (%d/%d)", + result["overlap_pct"], + intersection, + from_distinct, + ) + return json.dumps(result) + + +def tool_distinct_count( + ctx: ToolContext, *, full_name: str = "", column: str = "", **_kwargs +) -> str: + """Report row / distinct / null counts for ``full_name.column`` and + derive ``is_unique`` and ``is_complete`` flags. + + * ``is_unique = distinct_count == row_count - null_count`` — i.e. the + non-null subset has no duplicates. + * ``is_complete = null_count == 0`` — no missing values. + """ + logger.info("tool_distinct_count: %s.%s", full_name, column) + if not (full_name and column): + return json.dumps( + {"success": False, "error": "full_name and column are required"} + ) + + for value, role in ((full_name, "full_name"), (column, "column")): + err = _validate_identifier(value, role=role) + if err is not None: + return json.dumps({"success": False, "error": err}) + + sql = ( + f"SELECT COUNT(*) AS row_count, " + f" COUNT(DISTINCT {column}) AS distinct_count, " + f" COUNT(*) - COUNT({column}) AS null_count " + f"FROM {full_name}" + ) + logger.debug("tool_distinct_count: SQL=%s", sql) + + rows, err = _run_query(ctx, sql, tool_name="tool_distinct_count") + if err is not None: + return json.dumps({"success": False, "error": err}) + if not rows: + return json.dumps( + {"success": False, "error": "distinct_count query returned no rows"} + ) + + row = rows[0] + row_count = int(row.get("row_count", 0) or 0) + distinct_count = int(row.get("distinct_count", 0) or 0) + null_count = 
int(row.get("null_count", 0) or 0) + non_null_rows = row_count - null_count + + result = { + "success": True, + "row_count": row_count, + "distinct_count": distinct_count, + "null_count": null_count, + "is_unique": distinct_count == non_null_rows, + "is_complete": null_count == 0, + } + logger.info( + "tool_distinct_count: rows=%d, distinct=%d, nulls=%d, unique=%s, complete=%s", + row_count, + distinct_count, + null_count, + result["is_unique"], + result["is_complete"], + ) + return json.dumps(result) + + +def tool_submit_source_model( + ctx: ToolContext, *, model: Optional[dict] = None, **_kwargs +) -> str: + """Terminal Planner tool: validate ``model`` against + :class:`SourceModel` and stash the dataclass on ``ctx.source_model``. + + Only structural validity is checked here (does ``SourceModel.from_dict`` + succeed?). Semantic checks — e.g. coverage against the live ontology — + are the orchestrator's responsibility. + """ + # Local import to keep ``agents.tools`` importable without + # ``agents.agent_mapping_pge`` (avoids circular imports during pkg init). + from agents.agent_mapping_pge.contracts import SourceModel + + logger.info("tool_submit_source_model: validating candidate model") + if model is None or not isinstance(model, dict): + return json.dumps( + {"success": False, "error": "model must be a JSON object"} + ) + + try: + source_model = SourceModel.from_dict(model) + except (KeyError, TypeError, ValueError) as exc: + # ``KeyError`` for missing required fields; ``TypeError`` / ``ValueError`` + # for bad coercions (e.g. confidence not float-parseable). 
+ logger.warning( + "tool_submit_source_model: validation failed: %s: %s", + type(exc).__name__, + exc, + ) + return json.dumps( + { + "success": False, + "error": f"SourceModel validation failed: {type(exc).__name__}: {exc}", + } + ) + + ctx.source_model = source_model + summary = { + "table_roles": len(source_model.table_roles), + "canonical_ids": len(source_model.canonical_ids), + "join_keys": len(source_model.join_keys), + "entity_order_len": len(source_model.mapping_plan.entity_order), + "relationship_order_len": len(source_model.mapping_plan.relationship_order), + } + logger.info("tool_submit_source_model: stored — %s", summary) + return json.dumps({"success": True, "summary": summary}) + + +# ===================================================== +# OpenAI function-calling definitions +# ===================================================== + + +SAMPLE_TABLE_DEF: dict = { + "type": "function", + "function": { + "name": "sample_table", + "description": ( + "Return up to N random sample rows from a table so you can see actual values " + "(not just column types). n defaults to 20 and is capped at 100." + ), + "parameters": { + "type": "object", + "properties": { + "full_name": { + "type": "string", + "description": "Fully-qualified table name (catalog.schema.table).", + }, + "n": { + "type": "integer", + "description": "Sample size (default 20, max 100).", + }, + }, + "required": ["full_name"], + }, + }, +} + + +COLUMN_VALUE_OVERLAP_DEF: dict = { + "type": "function", + "function": { + "name": "column_value_overlap", + "description": ( + "Compute the one-sided overlap |distinct(from) ∩ distinct(to)| / |distinct(from)|. " + "Use this to validate a candidate join key before committing it to the SourceModel." 
+ ), + "parameters": { + "type": "object", + "properties": { + "from_table": { + "type": "string", + "description": "Fully-qualified source table.", + }, + "from_column": { + "type": "string", + "description": "Column on the source side (numerator denominator).", + }, + "to_table": { + "type": "string", + "description": "Fully-qualified target table.", + }, + "to_column": { + "type": "string", + "description": "Column on the target side.", + }, + }, + "required": ["from_table", "from_column", "to_table", "to_column"], + }, + }, +} + + +NORMALIZED_VALUE_OVERLAP_DEF: dict = { + "type": "function", + "function": { + "name": "normalized_value_overlap", + "description": ( + "Same overlap metric as column_value_overlap, but each side is a " + "scalar SQL EXPRESSION instead of a bare column. Use this to PROVE a " + "canonical-key normalization before committing it: when two tables " + "that map to the same ontology class have 0% raw-column overlap, " + "propose a normalization expression per table (e.g. " + "regexp_extract(EPISODE_ID, '([a-f0-9][a-f0-9-]+-preg-\\d+)', 1)) and " + "call this to confirm overlap_pct > 0. A still-zero result means the " + "expression is wrong — fix it before submit_source_model. Expressions " + "must be a single scalar (functions/literals/operators only); " + "subqueries and SQL keywords are rejected." + ), + "parameters": { + "type": "object", + "properties": { + "from_table": { + "type": "string", + "description": "Fully-qualified source table.", + }, + "from_expr": { + "type": "string", + "description": ( + "Scalar SQL expression over the source table that " + "produces the canonical key (e.g. a regexp_extract / " + "concat). Bare column names are also accepted." 
+ ), + }, + "to_table": { + "type": "string", + "description": "Fully-qualified target table.", + }, + "to_expr": { + "type": "string", + "description": "Scalar SQL expression over the target table.", + }, + }, + "required": ["from_table", "from_expr", "to_table", "to_expr"], + }, + }, +} + + +DISTINCT_COUNT_DEF: dict = { + "type": "function", + "function": { + "name": "distinct_count", + "description": ( + "Report row_count / distinct_count / null_count for a column, with is_unique " + "and is_complete flags. Use this to vet a candidate canonical-ID column." + ), + "parameters": { + "type": "object", + "properties": { + "full_name": { + "type": "string", + "description": "Fully-qualified table name (catalog.schema.table).", + }, + "column": { + "type": "string", + "description": "Column to characterise.", + }, + }, + "required": ["full_name", "column"], + }, + }, +} + + +SUBMIT_SOURCE_MODEL_DEF: dict = { + "type": "function", + "function": { + "name": "submit_source_model", + "description": ( + "Terminal Planner tool. Submit the final SourceModel JSON (matching " + "SourceModel.to_dict() shape). Validates the structure and stores the " + "dataclass on the ToolContext for the Generator stage to consume." + ), + "parameters": { + "type": "object", + "properties": { + "model": { + "type": "object", + "description": ( + "JSON-encoded SourceModel with table_roles, canonical_ids, " + "join_keys, and mapping_plan." 
+ ), + } + }, + "required": ["model"], + }, + }, +} + + +# ===================================================== +# Aggregate exports +# ===================================================== + + +PLANNER_TOOL_DEFINITIONS: List[dict] = [ + SAMPLE_TABLE_DEF, + COLUMN_VALUE_OVERLAP_DEF, + NORMALIZED_VALUE_OVERLAP_DEF, + DISTINCT_COUNT_DEF, + SUBMIT_SOURCE_MODEL_DEF, +] + + +PLANNER_TOOL_HANDLERS: Dict[str, Callable] = { + "sample_table": tool_sample_table, + "column_value_overlap": tool_column_value_overlap, + "normalized_value_overlap": tool_normalized_value_overlap, + "distinct_count": tool_distinct_count, + "submit_source_model": tool_submit_source_model, +} diff --git a/src/back/core/agents/AgentClient.py b/src/back/core/agents/AgentClient.py index 5d7ab263..3162c7a3 100644 --- a/src/back/core/agents/AgentClient.py +++ b/src/back/core/agents/AgentClient.py @@ -12,6 +12,7 @@ if TYPE_CHECKING: from agents.agent_owl_generator.engine import AgentResult from agents.agent_auto_assignment.engine import AgentResult as AutoAssignAgentResult + from agents.agent_mapping_pge.engine import AgentResult as MappingPGEAgentResult from agents.agent_auto_icon_assign.engine import ( AgentResult as IconAssignAgentResult, ) @@ -131,6 +132,67 @@ def run_auto_assignment( max_iterations=max_iterations, ) + def run_mapping_pge( + self, + *, + host: str, + token: str, + endpoint_name: str, + client: Any, + metadata: Any, + ontology: Any, + entity_mappings: Any, + relationship_mappings: Any, + documents: Any = None, + on_step: Optional[Callable] = None, + max_iterations: int = 10, + ) -> "MappingPGEAgentResult": + """Propose entity and relationship SQL mappings using the mapping PGE agent. + + This is the Planner–Generator–Evaluator (PGE) mapping engine + (``agents.agent_mapping_pge``). It shares the same call signature as + :meth:`run_auto_assignment` but drives a deterministic, coverage-checked + loop with a semantic critic and a structured run log. 
+ + Args: + host: Databricks workspace host (with or without ``https://``). + token: Bearer token for the workspace APIs. + endpoint_name: Model serving endpoint name for the agent. + client: SQL client (typically :class:`~back.core.databricks.DatabricksClient`) + used to validate or sample queries against the configured warehouse. + metadata: Schema context (for example UC table metadata) for the agent. + ontology: Ontology dict describing classes and properties to map. + entity_mappings: Existing or partial entity mapping list for the agent + to refine or extend. + relationship_mappings: Existing or partial relationship mapping list. + documents: Optional list of document dicts (``name``, ``content``) for + grounding. + on_step: Optional progress callback invoked by the agent loop. + max_iterations: Upper bound on agent refinement iterations. + + Returns: + Structured result from ``agents.agent_mapping_pge`` describing + proposed mappings, per-item status, and PGE diagnostics. + + Raises: + Exception: Propagates any failure raised by ``run_agent``. 
+ """ + from agents.agent_mapping_pge import run_agent + + return run_agent( + host=host, + token=token, + endpoint_name=endpoint_name, + client=client, + metadata=metadata, + ontology=ontology, + entity_mappings=entity_mappings, + relationship_mappings=relationship_mappings, + documents=documents, + on_step=on_step, + max_iterations=max_iterations, + ) + def run_icon_assign( self, *, diff --git a/src/back/objects/mapping/Mapping.py b/src/back/objects/mapping/Mapping.py index 13ed00fe..7e40ea35 100644 --- a/src/back/objects/mapping/Mapping.py +++ b/src/back/objects/mapping/Mapping.py @@ -27,7 +27,7 @@ _MAX_DOC_CHARS = 50_000 if TYPE_CHECKING: - from agents.agent_auto_assignment.engine import AgentResult as AutoAssignAgentResult + from agents.agent_mapping_pge.engine import AgentResult as AutoAssignAgentResult SINGLE_ITEM_MAX_ITERATIONS = 15 @@ -78,13 +78,18 @@ def auto_assign_with_agent( on_step: Optional[Callable[[str, int], None]] = None, max_iterations: Optional[int] = None, ) -> "AutoAssignAgentResult": - """Run ``agent_auto_assignment`` (blocking). + """Run the mapping-PGE agent (``agent_mapping_pge``) — blocking. + + Returns an :class:`AgentResult` with the standard ``entity_mappings`` + and ``relationship_mappings`` plus three PGE-specific extras + (``source_model``, ``mapping_evaluations``, ``mapping_run_log``) that + the caller can persist on the session. ``client`` is typically a :class:`~back.core.databricks.DatabricksClient` built with the domain warehouse. Call from a background thread when started from HTTP. """ - from agents.agent_auto_assignment import run_agent + from agents.agent_mapping_pge import run_agent return run_agent( host=host, @@ -165,6 +170,12 @@ def run_auto_assign_task( total_iterations = 0 total_usage = {"prompt_tokens": 0, "completion_tokens": 0} chunk_errors: List[str] = [] + # PGE-specific extras accumulated across chunks. 
Each chunk + # re-plans, so ``last_source_model`` reflects the most recent + # plan; per-item evaluations / run logs concatenate cleanly. + last_source_model: Optional[Dict[str, Any]] = None + merged_mapping_evaluations: Dict[str, Any] = {} + merged_mapping_run_log: List[Any] = [] for chunk_idx, chunk in enumerate(chunks): chunk_num = chunk_idx + 1 @@ -261,6 +272,19 @@ def on_step(msg: str, progress_pct: int = 0) -> None: for k in total_usage: total_usage[k] += agent_result.usage.get(k, 0) + # PGE extras — accumulate. The new engine returns these as + # dicts/lists (drop-in compatible). The legacy engine omitted + # them; ``getattr`` with defaults keeps us tolerant. + chunk_source_model = getattr(agent_result, "source_model", None) + if chunk_source_model: + last_source_model = chunk_source_model + chunk_evals = getattr(agent_result, "mapping_evaluations", None) or {} + if chunk_evals: + merged_mapping_evaluations.update(chunk_evals) + chunk_run_log = getattr(agent_result, "mapping_run_log", None) or [] + if chunk_run_log: + merged_mapping_run_log.extend(chunk_run_log) + e_done = len(entity_mapping_by_uri) r_done = len(rel_mapping_by_uri) @@ -345,6 +369,9 @@ def on_step(msg: str, progress_pct: int = 0) -> None: all_relationship_mappings, existing_entity_mappings=entity_mappings, existing_relationship_mappings=relationship_mappings, + source_model=last_source_model, + mapping_evaluations=merged_mapping_evaluations or None, + mapping_run_log=merged_mapping_run_log or None, ) message = f"Completed: {e_count} entities, {r_count} relationships mapped" @@ -449,6 +476,11 @@ def on_step(msg: str, progress_pct: int = 0) -> None: tm.fail_task(task.id, "Agent completed but produced no mapping") return + # PGE extras from this single-item run — passed through verbatim. 
+ single_source_model = getattr(agent_result, "source_model", None) + single_evals = getattr(agent_result, "mapping_evaluations", None) or None + single_run_log = getattr(agent_result, "mapping_run_log", None) or None + if item_type == "entity": Mapping.save_mappings_to_session( session_id, @@ -456,6 +488,9 @@ def on_step(msg: str, progress_pct: int = 0) -> None: agent_result.entity_mappings, None, existing_entity_mappings=existing_entity_mappings, + source_model=single_source_model, + mapping_evaluations=single_evals, + mapping_run_log=single_run_log, ) else: Mapping.save_mappings_to_session( @@ -464,6 +499,9 @@ def on_step(msg: str, progress_pct: int = 0) -> None: None, agent_result.relationship_mappings, existing_relationship_mappings=existing_relationship_mappings, + source_model=single_source_model, + mapping_evaluations=single_evals, + mapping_run_log=single_run_log, ) tm.complete_task( @@ -918,6 +956,9 @@ def save_mappings_to_session( *, existing_entity_mappings: Optional[list] = None, existing_relationship_mappings: Optional[list] = None, + source_model: Optional[Dict[str, Any]] = None, + mapping_evaluations: Optional[Dict[str, Any]] = None, + mapping_run_log: Optional[List[Any]] = None, ) -> None: if not session_id: logger.warning("save_mappings_to_session: no session_id — skipping") @@ -986,6 +1027,21 @@ def save_mappings_to_session( else: assignment["relationships"] = relationship_mappings + # Mapping-PGE extras — persisted alongside the assignment so the + # UI (future work) and downstream observability can surface + # planner state, per-item evaluation reports, and the per-item + # attempt log without re-running the agent. 
+ if source_model is not None: + assignment["source_model"] = source_model + if mapping_evaluations is not None: + merged_evals = dict(assignment.get("mapping_evaluations") or {}) + merged_evals.update(mapping_evaluations) + assignment["mapping_evaluations"] = merged_evals + if mapping_run_log is not None: + existing_log = list(assignment.get("mapping_run_log") or []) + existing_log.extend(mapping_run_log) + assignment["mapping_run_log"] = existing_log + domain_node = bucket.setdefault("domain", {}) domain_node["assignment_changed"] = True diff --git a/tests/agents/agent_mapping_pge/__init__.py b/tests/agents/agent_mapping_pge/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/agents/agent_mapping_pge/test_contracts.py b/tests/agents/agent_mapping_pge/test_contracts.py new file mode 100644 index 00000000..35498c15 --- /dev/null +++ b/tests/agents/agent_mapping_pge/test_contracts.py @@ -0,0 +1,126 @@ +"""Smoke tests for the PGE contracts. + +These are intentionally narrow — they only assert that every contract +dataclass round-trips cleanly through ``to_dict`` / ``from_dict`` so that +downstream sprints (Planner, Generator, orchestrator) can rely on +JSON-safe serialisation for MLflow artefacts and registry persistence. 
+""" + +import json + +from agents.agent_mapping_pge.contracts import ( + CanonicalId, + EvalFailure, + EvalReport, + JoinKey, + MappingPlan, + RetryState, + SkipItem, + SourceModel, + TableRole, + TableRoleCandidate, +) + + +def _roundtrip(obj): + """Serialise to dict -> JSON string -> dict -> reconstruct via from_dict.""" + cls = type(obj) + d = obj.to_dict() + encoded = json.dumps(d) + back = cls.from_dict(json.loads(encoded)) + return back, d + + +def test_source_model_roundtrip(): + sm = SourceModel( + table_roles=[ + TableRole( + table="cat.sch.mothers", + ontology_class_candidates=[ + TableRoleCandidate( + uri="http://ex.org#Mother", confidence=0.92, reason="row match" + ), + ], + ), + ], + canonical_ids=[ + CanonicalId( + ontology_class="http://ex.org#Mother", + canonical_column_per_table={"cat.sch.mothers": "nhs_number"}, + format_note="NHS number, 10 digits, no separators", + ) + ], + join_keys=[ + JoinKey( + from_ref="cat.sch.babies.mother_nhs", + to_ref="cat.sch.mothers.nhs_number", + confidence=0.88, + overlap_pct=0.97, + kind="same_trust_fk", + ) + ], + mapping_plan=MappingPlan( + entity_order=["http://ex.org#Mother", "http://ex.org#Baby"], + relationship_order=["http://ex.org#hasBaby"], + skip=[SkipItem(item="http://ex.org#Ghost", reason="no source table")], + ), + ) + back, d = _roundtrip(sm) + assert back.to_dict() == d + assert back.table_roles[0].table == "cat.sch.mothers" + assert back.canonical_ids[0].canonical_column_per_table["cat.sch.mothers"] == "nhs_number" + assert back.join_keys[0].kind == "same_trust_fk" + assert back.mapping_plan.skip[0].item == "http://ex.org#Ghost" + + +def test_eval_report_roundtrip(): + report = EvalReport( + status="FAIL", + stage="deterministic", + metrics={"row_count": 0}, + failures=[ + EvalFailure( + kind="structural", + check="row_count", + expected="> 0", + observed="0", + hint="fix the FROM clause", + ) + ], + bubble_to_planner=True, + ) + back, d = _roundtrip(report) + assert back.to_dict() == d + assert 
back.status == "FAIL" + assert back.failures[0].check == "row_count" + + +def test_retry_state_roundtrip_with_and_without_report(): + rs_empty = RetryState(item_uri="http://ex.org#Mother") + back, d = _roundtrip(rs_empty) + assert back.to_dict() == d + assert back.last_eval_report is None + + rs = RetryState( + item_uri="http://ex.org#Baby", + generator_attempts=2, + planner_reinvocations=1, + last_eval_report=EvalReport( + status="FAIL", + stage="deterministic", + failures=[ + EvalFailure( + kind="structural", + check="total_edges", + expected="> 0", + observed="0", + hint="fix join", + ) + ], + bubble_to_planner=True, + ), + ) + back, d = _roundtrip(rs) + assert back.to_dict() == d + assert back.last_eval_report is not None + assert back.last_eval_report.failures[0].check == "total_edges" diff --git a/tests/agents/agent_mapping_pge/test_coverage.py b/tests/agents/agent_mapping_pge/test_coverage.py new file mode 100644 index 00000000..1baa1b75 --- /dev/null +++ b/tests/agents/agent_mapping_pge/test_coverage.py @@ -0,0 +1,140 @@ +"""Tests for deterministic coverage enforcement + derived mappings. + +These lock in the invariant that broke the pipeline before: coverage is +computed from the ontology, NOT from the Planner's discretionary plan, so every +class and every relationship is always attempted. 
+""" + +from agents.agent_mapping_pge import coverage as cov +from agents.agent_mapping_pge.contracts import ( + CanonicalId, + MappingPlan, + SkipItem, + SourceModel, + TableRole, + TableRoleCandidate, +) + +# A miniature class hierarchy mirroring the maternity shape: +# Person (abstract) -> Patient (abstract) -> {Mother, Baby concrete} +# Encounter (abstract) -> {Visit concrete} +PERSON = "u#Person" +PATIENT = "u#Patient" +MOTHER = "u#Mother" +BABY = "u#Baby" +ENCOUNTER = "u#Encounter" +VISIT = "u#Visit" +ATTENDED = "u#attended" # Visit -> Mother + + +def _ontology() -> dict: + return { + "entities": [ + {"uri": PERSON, "name": "Person", "parent": "", "attributes": []}, + {"uri": PATIENT, "name": "Patient", "parent": "Person", + "attributes": [{"name": "nhsnumber"}]}, + {"uri": MOTHER, "name": "Mother", "parent": "Patient", + "attributes": [{"name": "nhsnumber"}, {"name": "postcode"}]}, + {"uri": BABY, "name": "Baby", "parent": "Patient", + "attributes": [{"name": "nhsnumber"}]}, + {"uri": ENCOUNTER, "name": "Encounter", "parent": "", "attributes": []}, + {"uri": VISIT, "name": "Visit", "parent": "Encounter", "attributes": []}, + ], + "relationships": [ + {"uri": ATTENDED, "name": "attended", "domain": VISIT, "range": MOTHER}, + ], + } + + +def _source_model(skip_some: bool = False) -> SourceModel: + # Planner only assigns tables to the concrete leaves + only plans a SUBSET. 
+ roles = [ + TableRole(table="c.s.mother", ontology_class_candidates=[ + TableRoleCandidate(uri=MOTHER, confidence=0.9)]), + TableRole(table="c.s.baby", ontology_class_candidates=[ + TableRoleCandidate(uri=BABY, confidence=0.9)]), + TableRole(table="c.s.visit", ontology_class_candidates=[ + TableRoleCandidate(uri=VISIT, confidence=0.9)]), + ] + cids = [ + CanonicalId(ontology_class=MOTHER, canonical_column_per_table={"c.s.mother": "nhs"}), + CanonicalId(ontology_class=VISIT, canonical_column_per_table={"c.s.visit": "vid"}), + ] + plan = MappingPlan( + entity_order=[MOTHER], # deliberately incomplete + relationship_order=[], # deliberately empty + skip=[SkipItem(item=BABY, reason="planner skipped it")] if skip_some else [], + ) + return SourceModel(table_roles=roles, canonical_ids=cids, mapping_plan=plan) + + +def test_classify_abstract_vs_concrete(): + concrete, abstract = cov.classify(_ontology(), _source_model()) + assert abstract == {PERSON, PATIENT, ENCOUNTER} + assert {MOTHER, BABY, VISIT} <= concrete + + +def test_full_entity_order_covers_all_classes_abstracts_last(): + order = cov.full_entity_order(_ontology(), _source_model()) + assert set(order) == {PERSON, PATIENT, MOTHER, BABY, ENCOUNTER, VISIT} + # Every concrete leaf precedes its abstract ancestors. + assert order.index(MOTHER) < order.index(PATIENT) < order.index(PERSON) + assert order.index(VISIT) < order.index(ENCOUNTER) + + +def test_skip_list_does_not_reduce_coverage(): + # Even when the Planner skips Baby, full coverage still includes it. 
+ order = cov.full_entity_order(_ontology(), _source_model(skip_some=True)) + assert BABY in order + + +def test_full_relationship_order_includes_all(): + order = cov.full_entity_order(_ontology(), _source_model()) + rels = cov.full_relationship_order(_ontology(), order, _source_model()) + assert rels == [ATTENDED] + + +def test_concrete_leaf_descendants(): + concrete, _ = cov.classify(_ontology(), _source_model()) + assert set(cov.concrete_leaf_descendants(PERSON, _ontology(), concrete)) == {MOTHER, BABY} + assert set(cov.concrete_leaf_descendants(PATIENT, _ontology(), concrete)) == {MOTHER, BABY} + assert set(cov.concrete_leaf_descendants(ENCOUNTER, _ontology(), concrete)) == {VISIT} + + +def test_build_abstract_union_mapping_reuses_subclass_sql(): + mother_em = { + "ontology_class": MOTHER, "id_column": "ID", + "sql_query": "SELECT nhs AS ID, nhs AS nhsnumber, pc AS postcode FROM c.s.mother", + "attribute_mappings": {"nhsnumber": "nhsnumber", "postcode": "postcode"}, + } + baby_em = { + "ontology_class": BABY, "id_column": "ID", + "sql_query": "SELECT CONCAT(nhs,'-baby') AS ID, nhs AS nhsnumber FROM c.s.baby", + "attribute_mappings": {"nhsnumber": "nhsnumber"}, + } + patient = next(e for e in _ontology()["entities"] if e["uri"] == PATIENT) + m = cov.build_abstract_union_mapping(PATIENT, patient, [mother_em, baby_em]) + assert m is not None + assert m["id_column"] == "ID" + assert m["derived"] == "abstract_union" + # Patient's own attribute (nhsnumber) is projected; both subclass SQLs reused. 
+ assert "nhsnumber" in m["attribute_mappings"] + assert "UNION ALL" in m["sql_query"] + assert "c.s.mother" in m["sql_query"] and "c.s.baby" in m["sql_query"] + + +def test_build_abstract_union_mapping_none_when_no_subclasses(): + patient = next(e for e in _ontology()["entities"] if e["uri"] == PATIENT) + assert cov.build_abstract_union_mapping(PATIENT, patient, []) is None + + +def test_synthetic_endpoint_mapping_from_canonical_ids(): + em = cov.synthetic_endpoint_mapping(_source_model(), VISIT) + assert em is not None + assert em["id_column"] == "ID" + assert "c.s.visit" in em["sql_query"] + assert em["derived"] == "synthetic_endpoint" + + +def test_synthetic_endpoint_mapping_none_for_unknown_class(): + assert cov.synthetic_endpoint_mapping(_source_model(), "u#Nope") is None diff --git a/tests/agents/agent_mapping_pge/test_critic.py b/tests/agents/agent_mapping_pge/test_critic.py new file mode 100644 index 00000000..cad85d79 --- /dev/null +++ b/tests/agents/agent_mapping_pge/test_critic.py @@ -0,0 +1,684 @@ +"""Tests for the mapping-PGE Semantic Critic agent (Sprint 6). + +Mirrors the structure of ``test_relationship_generator.py``. The Critic is a +narrow tool-calling ReAct loop terminated by ``submit_evaluation``. These +tests exercise the loop's control flow with a *fake LLM* — a stub that +replaces ``call_serving_endpoint`` at module level and returns canned +responses on a per-call basis. + +No real HTTP, no real Databricks, no MLflow tracing. + +What we DO exercise: +* PASS verdict terminates immediately. +* FAIL with bubble_to_planner=False (column-level). +* FAIL with bubble_to_planner=True (table-level) — the bubble flag survives. +* PASS+bubble is demoted (matches build_report behaviour). +* FAIL with empty failures[] synthesises a generic semantic failure. +* Invalid status does NOT terminate — loop continues, accepts a valid retry. +* Text-only response → failure with "without submitting evaluation". 
+* Iteration-budget exhaustion → failure with "iteration budget". +* User prompt surfaces structural-stage metrics. +* User prompt for relationships includes domain/range sections. +""" + +import json +from typing import Any, Callable, Dict, List, Optional + +import pytest + +from agents.agent_mapping_pge.evaluator import critic as critic_mod +from agents.agent_mapping_pge.evaluator.critic import ( + CriticResult, + CriticStep, + run_critic, +) + + +# ===================================================== +# Fake LLM scaffolding +# ===================================================== + + +_ENTITY_URI = "http://ex.org/maternity#Mother" +_REL_URI = "http://ex.org/maternity#motherOf" + + +def _make_tool_call(name: str, arguments: dict, *, tc_id: str = "tc1") -> dict: + return { + "id": tc_id, + "type": "function", + "function": {"name": name, "arguments": json.dumps(arguments)}, + } + + +def _llm_response( + *, + tool_calls: Optional[List[dict]] = None, + content: Optional[str] = None, + finish_reason: str = "tool_calls", + usage: Optional[Dict[str, int]] = None, +) -> dict: + message: Dict[str, Any] = {"role": "assistant"} + if tool_calls: + message["tool_calls"] = tool_calls + if content is not None: + message["content"] = content + return { + "choices": [{"finish_reason": finish_reason, "message": message}], + "usage": usage or {"prompt_tokens": 10, "completion_tokens": 5}, + } + + +class FakeLLM: + def __init__(self, responses: List[dict]): + self.responses = list(responses) + self.calls = 0 + self.last_messages: Optional[List[dict]] = None + self.first_messages: Optional[List[dict]] = None + + def __call__(self, *args, **kwargs) -> dict: + self.calls += 1 + msgs: Optional[List[dict]] = None + if len(args) >= 4 and isinstance(args[3], list): + msgs = args[3] + elif "messages" in kwargs: + msgs = kwargs["messages"] + if msgs is not None: + snapshot = [dict(m) for m in msgs] + if self.first_messages is None: + self.first_messages = snapshot + self.last_messages = 
snapshot + + if not self.responses: + raise AssertionError( + f"FakeLLM: ran out of canned responses on call #{self.calls}" + ) + return self.responses.pop(0) + + +class CyclingFakeLLM: + """Like FakeLLM but cycles through a fixed list forever.""" + + def __init__(self, responses: List[dict]): + self.responses = list(responses) + self.calls = 0 + + def __call__(self, *args, **kwargs) -> dict: + resp = self.responses[self.calls % len(self.responses)] + self.calls += 1 + return resp + + +@pytest.fixture +def no_sleep(monkeypatch): + """Neutralise the 3-second inter-iteration delay so tests run fast.""" + monkeypatch.setattr(critic_mod.time, "sleep", lambda *_a, **_k: None) + + +def _patch_llm(monkeypatch, fake: Callable[..., dict]) -> None: + monkeypatch.setattr(critic_mod, "call_serving_endpoint", fake) + + +# ===================================================== +# Fixtures +# ===================================================== + + +def _entity_definition() -> dict: + return { + "uri": _ENTITY_URI, + "label": "Mother", + "name": "Mother", + "comment": "A pregnant woman in the maternity dataset.", + "attributes": [ + {"name": "nhsNumber", "type": "string"}, + {"name": "dateOfBirth", "type": "date"}, + ], + } + + +def _relationship_definition() -> dict: + return { + "uri": _REL_URI, + "label": "motherOf", + "name": "motherOf", + "comment": "Links a Mother to each of her babies.", + "domain": _ENTITY_URI, + "range": "http://ex.org/maternity#Baby", + } + + +def _entity_submitted_mapping() -> dict: + return { + "ontology_class": _ENTITY_URI, + "class_name": "Mother", + "sql_query": "SELECT nhs_number AS ID, nhs_number AS Label FROM cat.sch.mothers WHERE nhs_number IS NOT NULL", + "id_column": "nhs_number", + "label_column": "nhs_number", + "attribute_mappings": {"nhsNumber": "nhs_number"}, + "unmapped_attributes": [ + {"name": "dateOfBirth", "reason": "column absent from this table"} + ], + } + + +def _relationship_submitted_mapping() -> dict: + return { + "property": 
_REL_URI, + "property_name": "motherOf", + "sql_query": "SELECT mother_nhs_number AS source_id, baby_id AS target_id FROM cat.sch.babies", + "source_id_column": "nhs_number", + "target_id_column": "baby_id", + "domain": _ENTITY_URI, + "range_class": "http://ex.org/maternity#Baby", + } + + +def _source_model_slice() -> dict: + return { + "candidate_tables": [ + { + "table": "cat.sch.mothers", + "confidence": 0.9, + "reason": "row per mother, nhs_number as PK", + } + ], + "canonical_id": { + "canonical_column_per_table": {"cat.sch.mothers": "nhs_number"}, + "format_note": "10-digit NHS number", + }, + } + + +def _stage1_metrics(**overrides) -> dict: + base = { + "row_count": 100, + "distinct_ids": 100, + "null_ids": 0, + } + base.update(overrides) + return base + + +def _valid_pass_submit() -> dict: + return { + "status": "PASS", + "failures": [], + "bubble_to_planner": False, + "reasoning": "Sampled values match the Mother concept; column semantics OK.", + } + + +def _valid_fail_column_submit() -> dict: + return { + "status": "FAIL", + "failures": [ + { + "check": "column_semantics", + "expected": "delivery date", + "observed": "appointment_date is a booking date", + "hint": "Use `delivery_dttm` instead of `appointment_date`.", + } + ], + "bubble_to_planner": False, + "reasoning": "Wrong column within the right table.", + } + + +def _valid_fail_table_submit() -> dict: + return { + "status": "FAIL", + "failures": [ + { + "check": "table_selection", + "expected": "labour_delivery", + "observed": "antenatal_visits", + "hint": "Switch to `labour_delivery` table for the Delivery class.", + } + ], + "bubble_to_planner": True, + "reasoning": "Wrong table chosen — bubble to Planner.", + } + + +def _run_entity_critic( + fake: Callable[..., dict], + *, + max_iterations: int = 6, + item_kind: str = "entity", + item_uri: str = _ENTITY_URI, + item_definition: Optional[dict] = None, + submitted_mapping: Optional[dict] = None, + stage1_metrics: Optional[dict] = None, + client: Any 
= None, +) -> CriticResult: + return run_critic( + host="https://x", + token="t", + endpoint_name="ep", + client=client, + item_kind=item_kind, + item_uri=item_uri, + item_definition=item_definition + if item_definition is not None + else _entity_definition(), + submitted_mapping=submitted_mapping + if submitted_mapping is not None + else _entity_submitted_mapping(), + source_model_slice=_source_model_slice(), + stage1_metrics=stage1_metrics + if stage1_metrics is not None + else _stage1_metrics(), + max_iterations=max_iterations, + ) + + +# ===================================================== +# 1. PASS verdict terminates immediately +# ===================================================== + + +def test_pass_verdict(monkeypatch, no_sleep): + """First LLM turn submits PASS → success=True, status=PASS, iterations=1.""" + fake = FakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call("submit_evaluation", _valid_pass_submit()) + ] + ) + ] + ) + _patch_llm(monkeypatch, fake) + + result = _run_entity_critic(fake) + + assert isinstance(result, CriticResult) + assert result.success is True + assert result.iterations == 1 + assert result.report is not None + assert result.report.status == "PASS" + assert result.report.stage == "semantic" + assert result.report.failures == [] + assert result.report.bubble_to_planner is False + assert result.error == "" + # Step recording: one tool_call + one tool_result. + assert [s.step_type for s in result.steps] == ["tool_call", "tool_result"] + assert result.steps[0].tool_name == "submit_evaluation" + + +# ===================================================== +# 2. 
FAIL with bubble_to_planner=False (column-level) +# ===================================================== + + +def test_fail_column_level(monkeypatch, no_sleep): + """status=FAIL, bubble_to_planner=False → column-level failure preserved.""" + fake = FakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call("submit_evaluation", _valid_fail_column_submit()) + ] + ) + ] + ) + _patch_llm(monkeypatch, fake) + + result = _run_entity_critic(fake) + + assert result.success is True + assert result.report is not None + assert result.report.status == "FAIL" + assert result.report.bubble_to_planner is False + assert len(result.report.failures) == 1 + failure = result.report.failures[0] + assert failure.kind == "semantic" + assert failure.check == "column_semantics" + assert "delivery_dttm" in failure.hint + + +# ===================================================== +# 3. FAIL with bubble_to_planner=True (table-level) +# ===================================================== + + +def test_fail_table_level_bubbles(monkeypatch, no_sleep): + """status=FAIL, bubble_to_planner=True → bubble flag preserved on report.""" + fake = FakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call("submit_evaluation", _valid_fail_table_submit()) + ] + ) + ] + ) + _patch_llm(monkeypatch, fake) + + result = _run_entity_critic(fake) + + assert result.success is True + assert result.report is not None + assert result.report.status == "FAIL" + assert result.report.bubble_to_planner is True + assert len(result.report.failures) == 1 + failure = result.report.failures[0] + assert failure.check == "table_selection" + assert "labour_delivery" in failure.hint + + +# ===================================================== +# 4. 
PASS with bubble_to_planner=True is demoted +# ===================================================== + + +def test_demotes_pass_with_bubble(monkeypatch, no_sleep): + """A PASS verdict that asks to bubble is demoted to bubble=False.""" + bad_pass = _valid_pass_submit() + bad_pass["bubble_to_planner"] = True + + fake = FakeLLM( + [ + _llm_response( + tool_calls=[_make_tool_call("submit_evaluation", bad_pass)] + ) + ] + ) + _patch_llm(monkeypatch, fake) + + result = _run_entity_critic(fake) + + assert result.success is True + assert result.report is not None + assert result.report.status == "PASS" + # The bubble flag must have been demoted. + assert result.report.bubble_to_planner is False + + +# ===================================================== +# 5. FAIL with empty failures[] synthesises one +# ===================================================== + + +def test_fail_without_failures_synthesises_one(monkeypatch, no_sleep): + """status=FAIL with empty failures[] gets a generic semantic failure + synthesised so the report stays coherent.""" + fail_no_failures = { + "status": "FAIL", + "failures": [], + "bubble_to_planner": False, + "reasoning": "Something is off but I can't pinpoint it.", + } + + fake = FakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call("submit_evaluation", fail_no_failures) + ] + ) + ] + ) + _patch_llm(monkeypatch, fake) + + result = _run_entity_critic(fake) + + assert result.success is True + assert result.report is not None + assert result.report.status == "FAIL" + assert len(result.report.failures) == 1 + f = result.report.failures[0] + assert f.kind == "semantic" + assert f.check == "semantic_audit" + # The reasoning is folded into the synthetic failure's hint when present. + assert "Something is off" in f.hint + + +# ===================================================== +# 6. 
Invalid status does NOT terminate — agent retries +# ===================================================== + + +def test_invalid_status_rejected(monkeypatch, no_sleep): + """A submit with status='UNKNOWN' must NOT terminate the loop; the + Critic must keep going and a follow-up submit with a valid status + should succeed. + """ + fake = FakeLLM( + [ + # Turn 1: invalid status → handler returns success=False, loop continues. + _llm_response( + tool_calls=[ + _make_tool_call( + "submit_evaluation", + { + "status": "UNKNOWN", + "failures": [], + "bubble_to_planner": False, + "reasoning": "n/a", + }, + tc_id="bad", + ) + ] + ), + # Turn 2: valid PASS submit → terminates. + _llm_response( + tool_calls=[ + _make_tool_call( + "submit_evaluation", + _valid_pass_submit(), + tc_id="good", + ) + ] + ), + ] + ) + _patch_llm(monkeypatch, fake) + + result = _run_entity_critic(fake) + + assert result.success is True + assert result.iterations == 2 + assert result.report is not None + assert result.report.status == "PASS" + # Both submit attempts left tool_call + tool_result steps (4 total). + assert len(result.steps) == 4 + + # The corrective tool message on the 2nd LLM call must contain the + # "invalid status" error so the LLM sees why its first attempt failed. + assert fake.last_messages is not None + tool_messages = [m for m in fake.last_messages if m.get("role") == "tool"] + assert tool_messages, "expected at least one tool message on the 2nd call" + first_tool_msg = tool_messages[0].get("content", "") + parsed = json.loads(first_tool_msg) + assert parsed.get("success") is False + assert "invalid status" in parsed.get("error", "") + + +# ===================================================== +# 7. Text without terminal call → failure +# ===================================================== + + +def test_text_without_terminal_fails(monkeypatch, no_sleep): + """A plain-text response is treated as failure — the Critic must + terminate via submit_evaluation. 
+ """ + fake = FakeLLM( + [_llm_response(content="I am thinking…", finish_reason="stop")] + ) + _patch_llm(monkeypatch, fake) + + result = _run_entity_critic(fake) + + assert result.success is False + assert result.iterations == 1 + assert result.report is None + assert "without submitting evaluation" in result.error + assert any(s.step_type == "output" for s in result.steps) + + +# ===================================================== +# 8. Iteration-budget exhaustion → failure +# ===================================================== + + +def test_exhausts_budget(monkeypatch, no_sleep): + """Endless sample_table calls with max_iterations=3 → fail with + ``iteration budget`` and three iterations of steps recorded.""" + fake = CyclingFakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call( + "sample_table", + {"full_name": "cat.sch.mothers"}, + tc_id="probe", + ) + ] + ) + ] + ) + _patch_llm(monkeypatch, fake) + + class FakeClient: + def execute_query(self, sql): + return [{"nhs_number": "1234567890"}] + + result = _run_entity_critic(fake, max_iterations=3, client=FakeClient()) + + assert result.success is False + assert result.iterations == 3 + assert result.report is None + assert "iteration budget" in result.error + # 3 iterations × (tool_call + tool_result) = 6 steps. + assert len(result.steps) == 6 + + +# ===================================================== +# 9. 
User prompt surfaces stage1 metrics +# ===================================================== + + +def test_user_prompt_includes_stage1_metrics(monkeypatch, no_sleep): + """The first LLM call's user message must contain stage1 metric values + so the Critic sees the structural context.""" + fake = FakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call("submit_evaluation", _valid_pass_submit()) + ] + ) + ] + ) + _patch_llm(monkeypatch, fake) + + _run_entity_critic( + fake, + stage1_metrics={"row_count": 1234, "distinct_ids": 1234, "null_ids": 0}, + ) + + assert fake.first_messages is not None + assert fake.first_messages[0]["role"] == "system" + assert fake.first_messages[1]["role"] == "user" + user_content = fake.first_messages[1]["content"] + assert "1234" in user_content + assert "STRUCTURAL CHECK METRICS" in user_content + + +# ===================================================== +# 10. Relationship audit surfaces domain/range +# ===================================================== + + +def test_user_prompt_distinguishes_entity_vs_relationship(monkeypatch, no_sleep): + """When item_kind='relationship', the user prompt must include the + 'domain' and 'range' lines that an entity prompt would not have.""" + fake = FakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call("submit_evaluation", _valid_pass_submit()) + ] + ) + ] + ) + _patch_llm(monkeypatch, fake) + + run_critic( + host="https://x", + token="t", + endpoint_name="ep", + client=None, + item_kind="relationship", + item_uri=_REL_URI, + item_definition=_relationship_definition(), + submitted_mapping=_relationship_submitted_mapping(), + source_model_slice=_source_model_slice(), + stage1_metrics=_stage1_metrics(), + ) + + assert fake.first_messages is not None + user_content = fake.first_messages[1]["content"] + # The kind is surfaced explicitly. + assert "relationship" in user_content + # The relationship-specific domain/range sections appear. 
+ assert "domain:" in user_content + assert "range:" in user_content + # And it is framed as a relationship submitted mapping. + assert "SUBMITTED MAPPING (relationship)" in user_content + # The relationship endpoint columns are surfaced too. + assert "source_id_column" in user_content + assert "target_id_column" in user_content + + +# ===================================================== +# 11. Step recording invariants +# ===================================================== + + +def test_records_steps(monkeypatch, no_sleep): + """Every tool-calling iteration produces one ``tool_call`` step + immediately followed by one ``tool_result`` step with the same tool_name.""" + fake = FakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call( + "sample_table", + {"full_name": "cat.sch.mothers"}, + tc_id="a", + ) + ] + ), + _llm_response( + tool_calls=[ + _make_tool_call( + "submit_evaluation", + _valid_pass_submit(), + tc_id="b", + ) + ] + ), + ] + ) + _patch_llm(monkeypatch, fake) + + class FakeClient: + def execute_query(self, sql): + return [{"nhs_number": "1234567890"}] + + result = _run_entity_critic(fake, client=FakeClient()) + + assert result.success is True + assert len(result.steps) % 2 == 0 + for i in range(0, len(result.steps), 2): + call_step = result.steps[i] + result_step = result.steps[i + 1] + assert call_step.step_type == "tool_call" + assert result_step.step_type == "tool_result" + assert call_step.tool_name == result_step.tool_name + assert isinstance(call_step, CriticStep) + assert isinstance(result_step, CriticStep) diff --git a/tests/agents/agent_mapping_pge/test_deterministic_evaluator.py b/tests/agents/agent_mapping_pge/test_deterministic_evaluator.py new file mode 100644 index 00000000..94ec3f35 --- /dev/null +++ b/tests/agents/agent_mapping_pge/test_deterministic_evaluator.py @@ -0,0 +1,916 @@ +"""Tests for the deterministic (stage-1) evaluator of the mapping PGE pipeline. 
+ +The evaluator is a pure function: it takes a submitted mapping plus an +injectable ``execute_sql_fn`` and returns an ``EvalReport`` summarising +structural failures. No LLM, no Databricks connection. + +``execute_sql_fn`` contract (for the evaluator): + def execute_sql_fn(sql: str) -> dict +returning:: + + {"columns": [...], "rows": [{col: value, ...}, ...]} + +This is the full result set — not the 3-row sample emitted by +``agents.tools.sql.tool_execute_sql``. The PGE orchestrator (Sprint 7) is +responsible for wiring a runner that yields full rows. +""" + +from typing import Dict, List + +import pytest + +from agents.agent_mapping_pge.contracts import EvalFailure, EvalReport +from agents.agent_mapping_pge.evaluator.deterministic import ( + evaluate_entity_mapping, + evaluate_relationship_mapping, +) +from agents.agent_mapping_pge.evaluator.report import build_report + + +# ===================================================== +# Fixtures +# ===================================================== + + +MOTHER_CLASS = { + "uri": "http://ex.org/maternity#Mother", + "name": "Mother", + "attributes": [ + {"name": "firstName"}, + {"name": "lastName"}, + {"name": "nhsNumber"}, + ], +} + +BABY_CLASS = { + "uri": "http://ex.org/maternity#Baby", + "name": "Baby", + "attributes": [ + {"name": "birthWeight"}, + ], +} + + +def _mother_mapping(*, attribute_mappings=None, unmapped_attributes=None): + mapping = { + "ontology_class": MOTHER_CLASS["uri"], + "class_name": "Mother", + "sql_query": "SELECT nhs_number AS ID, full_name AS Label, first_name, last_name, nhs_number FROM cat.sch.mothers", + "id_column": "ID", + "label_column": "Label", + "attribute_mappings": attribute_mappings + if attribute_mappings is not None + else { + "firstName": "first_name", + "lastName": "last_name", + "nhsNumber": "nhs_number", + }, + } + if unmapped_attributes is not None: + mapping["unmapped_attributes"] = unmapped_attributes + return mapping + + +def _baby_mapping(): + return { + 
"ontology_class": BABY_CLASS["uri"], + "class_name": "Baby", + "sql_query": "SELECT baby_id AS ID, baby_id AS Label, birth_weight FROM cat.sch.babies", + "id_column": "ID", + "label_column": "Label", + "attribute_mappings": {"birthWeight": "birth_weight"}, + } + + +def _mother_to_baby_relationship(): + return { + "property": "http://ex.org/maternity#hasBaby", + "property_name": "hasBaby", + "sql_query": ( + "SELECT mother_nhs AS source_id, baby_id AS target_id " + "FROM cat.sch.babies" + ), + "source_id_column": "source_id", + "target_id_column": "target_id", + "source_class": MOTHER_CLASS["uri"], + "target_class": BABY_CLASS["uri"], + } + + +def _make_sql_fn(table: dict): + """Return an execute_sql_fn closure that routes by SQL substring. + + ``table`` maps a SQL substring -> {"columns": [...], "rows": [...]}; the first key (in dict order) found in the SQL wins. + """ + + def fn(sql: str) -> dict: + for needle, payload in table.items(): + if needle in sql: + return payload + raise AssertionError(f"unexpected SQL in test: {sql}") + + return fn + + +# ===================================================== +# Entity evaluator +# ===================================================== + + +class TestEvaluateEntityMapping: + def test_pass_happy_path(self): + mapping = _mother_mapping() + sql_fn = _make_sql_fn( + { + "mothers": { + "columns": ["ID", "Label", "first_name", "last_name", "nhs_number"], + "rows": [ + { + "ID": "NHS-001", + "Label": "Alice Smith", + "first_name": "Alice", + "last_name": "Smith", + "nhs_number": "NHS-001", + }, + { + "ID": "NHS-002", + "Label": "Bob Jones", + "first_name": "Bob", + "last_name": "Jones", + "nhs_number": "NHS-002", + }, + ], + } + } + ) + + report = evaluate_entity_mapping( + mapping=mapping, + ontology_class=MOTHER_CLASS, + execute_sql_fn=sql_fn, + ) + + assert isinstance(report, EvalReport) + assert report.status == "PASS" + assert report.stage == "deterministic" + assert report.failures == [] + assert report.bubble_to_planner is False + assert report.metrics["row_count"] == 2
+ assert report.metrics["distinct_id_count"] == 2 + assert report.metrics["null_id_count"] == 0 + assert report.metrics["unmapped_attribute_pct"] == 0.0 + + def test_fail_row_count_zero_bubbles_to_planner(self): + mapping = _mother_mapping() + sql_fn = _make_sql_fn( + {"mothers": {"columns": ["ID", "Label"], "rows": []}} + ) + + report = evaluate_entity_mapping( + mapping=mapping, + ontology_class=MOTHER_CLASS, + execute_sql_fn=sql_fn, + ) + + assert report.status == "FAIL" + assert report.bubble_to_planner is True + check_names = [f.check for f in report.failures] + assert "row_count" in check_names + + def test_fail_duplicate_ids(self): + mapping = _mother_mapping() + sql_fn = _make_sql_fn( + { + "mothers": { + "columns": ["ID", "Label", "first_name", "last_name", "nhs_number"], + "rows": [ + { + "ID": "NHS-001", + "Label": "Alice", + "first_name": "Alice", + "last_name": "Smith", + "nhs_number": "NHS-001", + }, + { + "ID": "NHS-001", + "Label": "Alice dup", + "first_name": "Alice", + "last_name": "Smith", + "nhs_number": "NHS-001", + }, + ], + } + } + ) + + report = evaluate_entity_mapping( + mapping=mapping, + ontology_class=MOTHER_CLASS, + execute_sql_fn=sql_fn, + ) + + assert report.status == "FAIL" + assert report.bubble_to_planner is False + check_names = [f.check for f in report.failures] + assert "distinct_id_count" in check_names + + def test_fail_null_ids(self): + mapping = _mother_mapping() + sql_fn = _make_sql_fn( + { + "mothers": { + "columns": ["ID", "Label", "first_name", "last_name", "nhs_number"], + "rows": [ + { + "ID": None, + "Label": "Alice", + "first_name": "Alice", + "last_name": "Smith", + "nhs_number": None, + }, + { + "ID": "NHS-002", + "Label": "Bob", + "first_name": "Bob", + "last_name": "Jones", + "nhs_number": "NHS-002", + }, + ], + } + } + ) + + report = evaluate_entity_mapping( + mapping=mapping, + ontology_class=MOTHER_CLASS, + execute_sql_fn=sql_fn, + ) + + assert report.status == "FAIL" + check_names = [f.check for f in 
report.failures] + assert "null_id_count" in check_names + + def test_fail_unmapped_attribute(self): + # Omit lastName from attribute_mappings, no unmapped_attributes list. + mapping = _mother_mapping( + attribute_mappings={ + "firstName": "first_name", + "nhsNumber": "nhs_number", + }, + ) + sql_fn = _make_sql_fn( + { + "mothers": { + "columns": ["ID", "Label", "first_name", "last_name", "nhs_number"], + "rows": [ + { + "ID": "NHS-001", + "Label": "Alice", + "first_name": "Alice", + "last_name": "Smith", + "nhs_number": "NHS-001", + }, + ], + } + } + ) + + report = evaluate_entity_mapping( + mapping=mapping, + ontology_class=MOTHER_CLASS, + execute_sql_fn=sql_fn, + ) + + assert report.status == "FAIL" + check_names = [f.check for f in report.failures] + assert "unmapped_attribute_pct" in check_names + # 1 of 3 attributes missing -> ~0.333 + assert report.metrics["unmapped_attribute_pct"] == pytest.approx(1 / 3) + + def test_pass_when_unmapped_attribute_is_declared(self): + mapping = _mother_mapping( + attribute_mappings={ + "firstName": "first_name", + "nhsNumber": "nhs_number", + }, + unmapped_attributes=["lastName"], + ) + sql_fn = _make_sql_fn( + { + "mothers": { + "columns": ["ID", "Label", "first_name", "last_name", "nhs_number"], + "rows": [ + { + "ID": "NHS-001", + "Label": "Alice", + "first_name": "Alice", + "last_name": "Smith", + "nhs_number": "NHS-001", + }, + ], + } + } + ) + + report = evaluate_entity_mapping( + mapping=mapping, + ontology_class=MOTHER_CLASS, + execute_sql_fn=sql_fn, + ) + + assert report.status == "PASS" + assert report.metrics["unmapped_attribute_pct"] == 0.0 + + def test_pass_when_unmapped_attribute_is_declared_as_dict(self): + """The Generator may emit unmapped_attributes as [{name, reason}, ...]. 
+ Hashing dicts would crash the evaluator — names must be extracted.""" + mapping = _mother_mapping( + attribute_mappings={ + "firstName": "first_name", + "nhsNumber": "nhs_number", + }, + unmapped_attributes=[ + {"name": "lastName", "reason": "no source column"} + ], + ) + sql_fn = _make_sql_fn( + { + "mothers": { + "columns": ["ID", "Label", "first_name", "last_name", "nhs_number"], + "rows": [ + { + "ID": "NHS-001", + "Label": "Alice", + "first_name": "Alice", + "last_name": "Smith", + "nhs_number": "NHS-001", + }, + ], + } + } + ) + + report = evaluate_entity_mapping( + mapping=mapping, + ontology_class=MOTHER_CLASS, + execute_sql_fn=sql_fn, + ) + + assert report.status == "PASS" + assert report.metrics["unmapped_attribute_pct"] == 0.0 + + def test_report_is_json_serialisable(self): + mapping = _mother_mapping() + sql_fn = _make_sql_fn( + { + "mothers": { + "columns": ["ID", "Label", "first_name", "last_name", "nhs_number"], + "rows": [ + { + "ID": "NHS-001", + "Label": "Alice", + "first_name": "Alice", + "last_name": "Smith", + "nhs_number": "NHS-001", + }, + ], + } + } + ) + + report = evaluate_entity_mapping( + mapping=mapping, + ontology_class=MOTHER_CLASS, + execute_sql_fn=sql_fn, + ) + d = report.to_dict() + assert d["status"] == "PASS" + assert d["stage"] == "deterministic" + assert isinstance(d["metrics"], dict) + assert isinstance(d["failures"], list) + + def test_sql_execution_error_becomes_fail_not_crash(self): + """A mapping whose SQL parses but fails at runtime (e.g. a UNION + type mismatch) must yield a FAIL report with the error as a hint — + never propagate and crash the agent run. 
+ """ + mapping = _mother_mapping() + + def boom(sql: str) -> dict: + raise RuntimeError( + "[CAST_INVALID_INPUT] The value 'x-preg-1-baby' of the type " + '"STRING" cannot be cast to "BIGINT"' + ) + + report = evaluate_entity_mapping( + mapping=mapping, + ontology_class=MOTHER_CLASS, + execute_sql_fn=boom, + ) + + assert report.status == "FAIL" + # A runtime SQL error is the Generator's to fix, not a re-plan trigger. + assert report.bubble_to_planner is False + checks = [f.check for f in report.failures] + assert "sql_execution" in checks + # The underlying DB error is surfaced for the generator to act on. + assert "CAST_INVALID_INPUT" in report.metrics.get("sql_error", "") + # Report must remain JSON-serialisable. + import json as _json + + _json.dumps(report.to_dict()) + + +# ===================================================== +# Relationship evaluator +# ===================================================== + + +def _entity_rows(ids): + return { + "columns": ["ID", "Label"], + "rows": [{"ID": i, "Label": i} for i in ids], + } + + +class TestEvaluateRelationshipMapping: + def test_pass_happy_path(self): + rel = _mother_to_baby_relationship() + sql_fn = _make_sql_fn( + { + # Relationship edges + "source_id": { + "columns": ["source_id", "target_id"], + "rows": [ + {"source_id": "NHS-001", "target_id": "B-1"}, + {"source_id": "NHS-002", "target_id": "B-2"}, + ], + }, + # Source entity universe + "mothers": _entity_rows(["NHS-001", "NHS-002", "NHS-003"]), + # Target entity universe + "babies": _entity_rows(["B-1", "B-2", "B-3"]), + } + ) + + report = evaluate_relationship_mapping( + mapping=rel, + source_entity_mapping=_mother_mapping(), + target_entity_mapping=_baby_mapping(), + execute_sql_fn=sql_fn, + ) + + assert report.status == "PASS" + assert report.bubble_to_planner is False + assert report.metrics["total_edges"] == 2 + assert report.metrics["dangling_source_pct"] == 0.0 + assert report.metrics["dangling_target_pct"] == 0.0 + + def 
test_sql_execution_error_becomes_fail_not_crash(self): + """A relationship (or its endpoint-universe) SQL that errors at + runtime must yield a FAIL report, not crash the agent run. + """ + rel = _mother_to_baby_relationship() + + def boom(sql: str) -> dict: + raise RuntimeError("[UNRESOLVED_COLUMN] cannot resolve `target_id`") + + report = evaluate_relationship_mapping( + mapping=rel, + source_entity_mapping=_mother_mapping(), + target_entity_mapping=_baby_mapping(), + execute_sql_fn=boom, + ) + + assert report.status == "FAIL" + assert report.bubble_to_planner is False + assert "sql_execution" in [f.check for f in report.failures] + assert "UNRESOLVED_COLUMN" in report.metrics.get("sql_error", "") + + def test_fail_47_pct_dangling_source_bubbles(self): + rel = _mother_to_baby_relationship() + # 100 edges, 47 source_ids unknown to source universe -> FAIL without bubbling (despite the test name: 0.47 is below the 0.5 bubble threshold). + edge_rows = [ + {"source_id": f"NHS-{i:03d}", "target_id": f"B-{i}"} + for i in range(1, 101) + ] + # Only NHS-001..NHS-053 exist as mothers.
+ mother_ids = [f"NHS-{i:03d}" for i in range(1, 54)] + baby_ids = [f"B-{i}" for i in range(1, 201)] + + sql_fn = _make_sql_fn( + { + "source_id": { + "columns": ["source_id", "target_id"], + "rows": edge_rows, + }, + "mothers": _entity_rows(mother_ids), + "babies": _entity_rows(baby_ids), + } + ) + + report = evaluate_relationship_mapping( + mapping=rel, + source_entity_mapping=_mother_mapping(), + target_entity_mapping=_baby_mapping(), + execute_sql_fn=sql_fn, + ) + + assert report.status == "FAIL" + assert report.bubble_to_planner is False # 0.47 < 0.5 threshold + check_names = [f.check for f in report.failures] + assert "dangling_source_pct" in check_names + assert report.metrics["dangling_source_pct"] == pytest.approx(0.47) + + def test_fail_above_50_pct_dangling_source_bubbles_to_planner(self): + rel = _mother_to_baby_relationship() + edge_rows = [ + {"source_id": f"NHS-{i:03d}", "target_id": f"B-{i}"} + for i in range(1, 101) + ] + # Only NHS-001..NHS-040 are known mothers -> 60% dangling + mother_ids = [f"NHS-{i:03d}" for i in range(1, 41)] + baby_ids = [f"B-{i}" for i in range(1, 201)] + + sql_fn = _make_sql_fn( + { + "source_id": { + "columns": ["source_id", "target_id"], + "rows": edge_rows, + }, + "mothers": _entity_rows(mother_ids), + "babies": _entity_rows(baby_ids), + } + ) + + report = evaluate_relationship_mapping( + mapping=rel, + source_entity_mapping=_mother_mapping(), + target_entity_mapping=_baby_mapping(), + execute_sql_fn=sql_fn, + ) + + assert report.status == "FAIL" + assert report.bubble_to_planner is True + + def test_pass_3_pct_dangling_source_under_threshold(self): + rel = _mother_to_baby_relationship() + # 100 edges, only 3 source ids not in mother universe -> 3%. 
+ edge_rows = [ + {"source_id": f"NHS-{i:03d}", "target_id": f"B-{i}"} + for i in range(1, 101) + ] + mother_ids = [f"NHS-{i:03d}" for i in range(1, 98)] # 97 known, 3 dangling + baby_ids = [f"B-{i}" for i in range(1, 201)] + + sql_fn = _make_sql_fn( + { + "source_id": { + "columns": ["source_id", "target_id"], + "rows": edge_rows, + }, + "mothers": _entity_rows(mother_ids), + "babies": _entity_rows(baby_ids), + } + ) + + report = evaluate_relationship_mapping( + mapping=rel, + source_entity_mapping=_mother_mapping(), + target_entity_mapping=_baby_mapping(), + execute_sql_fn=sql_fn, + ) + + assert report.status == "PASS" + assert report.bubble_to_planner is False + assert report.metrics["dangling_source_pct"] == pytest.approx(0.03) + + def test_fail_zero_edges_bubbles_to_planner(self): + rel = _mother_to_baby_relationship() + sql_fn = _make_sql_fn( + { + "source_id": {"columns": ["source_id", "target_id"], "rows": []}, + "mothers": _entity_rows(["NHS-001"]), + "babies": _entity_rows(["B-1"]), + } + ) + + report = evaluate_relationship_mapping( + mapping=rel, + source_entity_mapping=_mother_mapping(), + target_entity_mapping=_baby_mapping(), + execute_sql_fn=sql_fn, + ) + + assert report.status == "FAIL" + assert report.bubble_to_planner is True + check_names = [f.check for f in report.failures] + assert "total_edges" in check_names + + def test_cross_source_band_fail_when_outside(self): + rel = _mother_to_baby_relationship() + # 100 edges, all source ids in mother universe. 
+ edge_rows = [ + {"source_id": f"NHS-{i:03d}", "target_id": f"B-{i}"} + for i in range(1, 101) + ] + sql_fn = _make_sql_fn( + { + "source_id": { + "columns": ["source_id", "target_id"], + "rows": edge_rows, + }, + "mothers": _entity_rows([f"NHS-{i:03d}" for i in range(1, 101)]), + "babies": _entity_rows([f"B-{i}" for i in range(1, 101)]), + } + ) + + report = evaluate_relationship_mapping( + mapping=rel, + source_entity_mapping=_mother_mapping(), + target_entity_mapping=_baby_mapping(), + execute_sql_fn=sql_fn, + expected_cross_source_overlap_band=(0.25, 0.4), + ) + # overlap_pct = 1.0 (every source row matches a target id); outside band. + assert report.status == "FAIL" + check_names = [f.check for f in report.failures] + assert "cross_source_overlap_pct" in check_names + + def test_cross_source_band_pass_when_inside(self): + rel = _mother_to_baby_relationship() + # Build edges where only ~30% of source ids match a target id (band 0.25..0.4 ). + edge_rows = [] + for i in range(1, 101): + edge_rows.append( + { + "source_id": f"NHS-{i:03d}", + "target_id": f"B-{i}" if i <= 30 else f"X-{i}", + } + ) + sql_fn = _make_sql_fn( + { + "source_id": { + "columns": ["source_id", "target_id"], + "rows": edge_rows, + }, + "mothers": _entity_rows([f"NHS-{i:03d}" for i in range(1, 101)]), + "babies": _entity_rows([f"B-{i}" for i in range(1, 101)]), + } + ) + + report = evaluate_relationship_mapping( + mapping=rel, + source_entity_mapping=_mother_mapping(), + target_entity_mapping=_baby_mapping(), + execute_sql_fn=sql_fn, + expected_cross_source_overlap_band=(0.25, 0.4), + ) + # overlap = 30/100 = 0.3, inside band. + assert report.status == "PASS" + assert report.metrics["cross_source_overlap_pct"] == pytest.approx(0.3) + + def test_band_present_overlap_outside_band_with_catastrophic_dangling_bubbles(self): + """Band FAILS (overlap 0.05 << lo=0.25) AND dangling > 0.5 → bubble. 
+ + The realised overlap is materially worse than the Planner predicted, + so the catastrophic-dangling structural failure fires alongside the + band-check failure, and ``bubble_to_planner`` flips True. + """ + rel = _mother_to_baby_relationship() + # 100 edges, only the first 5 target_ids land in the babies universe + # → overlap = 0.05, dangling_target = 0.95. + edge_rows = [] + for i in range(1, 101): + edge_rows.append( + { + "source_id": f"NHS-{i:03d}", + "target_id": f"B-{i}" if i <= 5 else f"X-{i}", + } + ) + sql_fn = _make_sql_fn( + { + "source_id": { + "columns": ["source_id", "target_id"], + "rows": edge_rows, + }, + "mothers": _entity_rows([f"NHS-{i:03d}" for i in range(1, 101)]), + "babies": _entity_rows([f"B-{i}" for i in range(1, 101)]), + } + ) + + report = evaluate_relationship_mapping( + mapping=rel, + source_entity_mapping=_mother_mapping(), + target_entity_mapping=_baby_mapping(), + execute_sql_fn=sql_fn, + expected_cross_source_overlap_band=(0.25, 0.4), + ) + + assert report.status == "FAIL" + assert report.bubble_to_planner is True + assert report.metrics["dangling_target_pct"] == pytest.approx(0.95) + check_names = [f.check for f in report.failures] + # Both the band failure AND the catastrophic-dangling row must surface. + assert "cross_source_overlap_pct" in check_names + assert "dangling_target_pct_catastrophic" in check_names + # The strict 0.05 dangling_target_pct row is gated behind "band is None" + # — it must NOT appear here. + assert "dangling_target_pct" not in check_names + + def test_band_present_overlap_outside_band_with_mild_dangling_does_not_bubble(self): + """Band FAILS but dangling is exactly at the bubble threshold (not > 0.5) + → status FAIL on the band row but ``bubble_to_planner`` stays False. + """ + rel = _mother_to_baby_relationship() + # 100 edges, 50 land in target universe → overlap = 0.50, dangling = 0.50. + # Band is (0.6, 0.8) so band check fails (0.50 < 0.6); dangling NOT > 0.5. 
+ edge_rows = [] + for i in range(1, 101): + edge_rows.append( + { + "source_id": f"NHS-{i:03d}", + "target_id": f"B-{i}" if i <= 50 else f"X-{i}", + } + ) + sql_fn = _make_sql_fn( + { + "source_id": { + "columns": ["source_id", "target_id"], + "rows": edge_rows, + }, + "mothers": _entity_rows([f"NHS-{i:03d}" for i in range(1, 101)]), + "babies": _entity_rows([f"B-{i}" for i in range(1, 101)]), + } + ) + + report = evaluate_relationship_mapping( + mapping=rel, + source_entity_mapping=_mother_mapping(), + target_entity_mapping=_baby_mapping(), + execute_sql_fn=sql_fn, + expected_cross_source_overlap_band=(0.6, 0.8), + ) + + assert report.status == "FAIL" + assert report.bubble_to_planner is False + assert report.metrics["dangling_target_pct"] == pytest.approx(0.5) + check_names = [f.check for f in report.failures] + assert "cross_source_overlap_pct" in check_names + # No catastrophic row because dangling is not strictly > 0.5. + assert "dangling_target_pct_catastrophic" not in check_names + + def test_relationship_evaluator_uses_id_universe_cache(self): + """Sharing a cache across calls avoids re-running the entity SQLs.""" + rel = _mother_to_baby_relationship() + base_fn = _make_sql_fn( + { + "source_id": { + "columns": ["source_id", "target_id"], + "rows": [ + {"source_id": "NHS-001", "target_id": "B-1"}, + {"source_id": "NHS-002", "target_id": "B-2"}, + ], + }, + "mothers": _entity_rows(["NHS-001", "NHS-002", "NHS-003"]), + "babies": _entity_rows(["B-1", "B-2", "B-3"]), + } + ) + + calls: List[str] = [] + + def counting_fn(sql: str) -> dict: + calls.append(sql) + return base_fn(sql) + + cache: Dict[str, set] = {} + + # First call: source + target entity SQLs + relationship SQL = 3 calls. 
+ evaluate_relationship_mapping( + mapping=rel, + source_entity_mapping=_mother_mapping(), + target_entity_mapping=_baby_mapping(), + execute_sql_fn=counting_fn, + id_universe_cache=cache, + ) + first_call_count = len(calls) + assert first_call_count == 3 + + mother_sql = _mother_mapping()["sql_query"] + baby_sql = _baby_mapping()["sql_query"] + assert mother_sql in cache + assert baby_sql in cache + + # Second call with same cache: only the relationship SQL should be + # re-executed; both entity universes are served from cache. + evaluate_relationship_mapping( + mapping=rel, + source_entity_mapping=_mother_mapping(), + target_entity_mapping=_baby_mapping(), + execute_sql_fn=counting_fn, + id_universe_cache=cache, + ) + + delta = calls[first_call_count:] + assert len(delta) == 1 + assert mother_sql not in delta + assert baby_sql not in delta + + def test_band_absent_catastrophic_target_dangling_bubbles(self): + """No band supplied + dangling_target > 0.5 → strict check fires and bubbles.""" + rel = _mother_to_baby_relationship() + # 100 edges, only 20 target_ids land in babies universe → dangling = 0.80. 
+ edge_rows = [] + for i in range(1, 101): + edge_rows.append( + { + "source_id": f"NHS-{i:03d}", + "target_id": f"B-{i}" if i <= 20 else f"X-{i}", + } + ) + sql_fn = _make_sql_fn( + { + "source_id": { + "columns": ["source_id", "target_id"], + "rows": edge_rows, + }, + "mothers": _entity_rows([f"NHS-{i:03d}" for i in range(1, 101)]), + "babies": _entity_rows([f"B-{i}" for i in range(1, 101)]), + } + ) + + report = evaluate_relationship_mapping( + mapping=rel, + source_entity_mapping=_mother_mapping(), + target_entity_mapping=_baby_mapping(), + execute_sql_fn=sql_fn, + ) + + assert report.status == "FAIL" + assert report.bubble_to_planner is True + assert report.metrics["dangling_target_pct"] == pytest.approx(0.8) + check_names = [f.check for f in report.failures] + assert "dangling_target_pct" in check_names + + +# ===================================================== +# build_report — bubble demotion warning +# ===================================================== + + +def test_build_report_warns_when_bubble_demoted(caplog): + """``bubble_to_planner=True`` with no failures (status PASS) should + emit a warning, AND silently-PASSing reports should not warn. + """ + import logging + + # PASS + bubble_to_planner=True → warning expected, bubble demoted. + caplog.clear() + with caplog.at_level(logging.WARNING): + passing = build_report( + stage="deterministic", + metrics={"row_count": 1}, + failures=[], + bubble_to_planner=True, + ) + assert passing.status == "PASS" + assert passing.bubble_to_planner is False + assert any( + "bubble_to_planner=True" in rec.message and rec.levelname == "WARNING" + for rec in caplog.records + ) + + # PASS + bubble_to_planner=False → no warning. 
+ caplog.clear() + with caplog.at_level(logging.WARNING): + build_report( + stage="deterministic", + metrics={"row_count": 1}, + failures=[], + bubble_to_planner=False, + ) + assert not any( + "bubble_to_planner=True" in rec.message for rec in caplog.records + ) + + # FAIL + bubble_to_planner=True → no demotion, no warning. + caplog.clear() + failure = EvalFailure( + kind="structural", + check="row_count", + expected="> 0", + observed="0", + hint="", + ) + with caplog.at_level(logging.WARNING): + failing = build_report( + stage="deterministic", + metrics={"row_count": 0}, + failures=[failure], + bubble_to_planner=True, + ) + assert failing.status == "FAIL" + assert failing.bubble_to_planner is True + assert not any( + "bubble_to_planner=True" in rec.message for rec in caplog.records + ) diff --git a/tests/agents/agent_mapping_pge/test_engine.py b/tests/agents/agent_mapping_pge/test_engine.py new file mode 100644 index 00000000..760683eb --- /dev/null +++ b/tests/agents/agent_mapping_pge/test_engine.py @@ -0,0 +1,1117 @@ +"""Tests for the mapping-PGE orchestrator (Sprint 7). + +The orchestrator wires Planner -> Generator(s) -> Evaluator(s) into a single +``run_agent`` entry. These tests exercise the control flow with fake versions +of each sub-agent — no real LLM, no real Databricks. Each test patches the +module-level references in :mod:`engine` so the orchestrator calls the fakes +instead of the production functions. + +What we DO exercise: +* Happy path with both entities and relationships. +* Planner failure aborts cleanly. +* Generator failure records FAIL but continues. +* Evaluator FAIL (non-bubble) drives a retry with a hint. +* Bubble-to-planner triggers Planner re-invocation; budget is global. +* 3-attempt retry budget exhaustion records FAIL_BUDGET. +* Critic PASS / FAIL paths, and the ``skip_semantic_critic`` short-circuit. +* Pre-seeded entity mappings and Planner skip[] entries. +* on_step pct stays non-decreasing across the run. 
+* Id-universe cache shares entity universes across relationships. +""" + +from typing import Any, Dict, List, Optional, Tuple + +import pytest + +from agents.agent_mapping_pge import engine as engine_mod +from agents.agent_mapping_pge.contracts import ( + CanonicalId, + EvalFailure, + EvalReport, + JoinKey, + MappingPlan, + SkipItem, + SourceModel, + TableRole, + TableRoleCandidate, +) +from agents.agent_mapping_pge.engine import AgentResult, run_agent +from agents.agent_mapping_pge.evaluator.critic import CriticResult +from agents.agent_mapping_pge.generators.entity import EntityGenResult +from agents.agent_mapping_pge.generators.relationship import RelationshipGenResult +from agents.agent_mapping_pge.planner import PlannerResult + + +# ===================================================== +# Ontology + SourceModel fixtures +# ===================================================== + + +CUSTOMER_URI = "http://test.org/ontology#Customer" +ORDER_URI = "http://test.org/ontology#Order" +HAS_ORDER_URI = "http://test.org/ontology#hasOrder" +ITEM_URI = "http://test.org/ontology#Item" +CONTAINS_URI = "http://test.org/ontology#contains" + +T_CUSTOMERS = "cat.sch.customers" +T_ORDERS = "cat.sch.orders" +T_ITEMS = "cat.sch.items" + + +def _ontology() -> dict: + # Control-flow tests use this two-entity / one-relationship ontology so the + # engine-enforced coverage set equals exactly Customer + Order + hasOrder. + # (The id-universe-cache test below supplies its own 3-entity ontology.) 
+ return { + "entities": [ + { + "uri": CUSTOMER_URI, + "name": "Customer", + "label": "Customer", + "attributes": [{"name": "firstName", "type": "xsd:string"}], + }, + { + "uri": ORDER_URI, + "name": "Order", + "label": "Order", + "attributes": [{"name": "orderDate", "type": "xsd:string"}], + }, + ], + "relationships": [ + { + "uri": HAS_ORDER_URI, + "name": "hasOrder", + "label": "hasOrder", + "domain": CUSTOMER_URI, + "range": ORDER_URI, + }, + ], + } + + +def _source_model(*, with_items: bool = False) -> SourceModel: + table_roles = [ + TableRole( + table=T_CUSTOMERS, + ontology_class_candidates=[ + TableRoleCandidate(uri=CUSTOMER_URI, confidence=0.9, reason="ok") + ], + ), + TableRole( + table=T_ORDERS, + ontology_class_candidates=[ + TableRoleCandidate(uri=ORDER_URI, confidence=0.9, reason="ok") + ], + ), + ] + if with_items: + table_roles.append( + TableRole( + table=T_ITEMS, + ontology_class_candidates=[ + TableRoleCandidate(uri=ITEM_URI, confidence=0.9, reason="ok") + ], + ) + ) + canonical_ids = [ + CanonicalId( + ontology_class=CUSTOMER_URI, + canonical_column_per_table={T_CUSTOMERS: "customer_id"}, + ), + CanonicalId( + ontology_class=ORDER_URI, + canonical_column_per_table={T_ORDERS: "order_id"}, + ), + ] + if with_items: + canonical_ids.append( + CanonicalId( + ontology_class=ITEM_URI, + canonical_column_per_table={T_ITEMS: "item_id"}, + ) + ) + join_keys = [ + JoinKey( + from_ref=f"{T_ORDERS}.customer_id", + to_ref=f"{T_CUSTOMERS}.customer_id", + confidence=0.9, + overlap_pct=0.95, + kind="same_trust_fk", + ), + ] + if with_items: + join_keys.append( + JoinKey( + from_ref=f"{T_ITEMS}.order_id", + to_ref=f"{T_ORDERS}.order_id", + confidence=0.9, + overlap_pct=0.95, + kind="same_trust_fk", + ) + ) + + entity_order = [CUSTOMER_URI, ORDER_URI] + relationship_order = [HAS_ORDER_URI] + if with_items: + entity_order.append(ITEM_URI) + relationship_order.append(CONTAINS_URI) + + return SourceModel( + table_roles=table_roles, + canonical_ids=canonical_ids, + 
join_keys=join_keys, + mapping_plan=MappingPlan( + entity_order=entity_order, + relationship_order=relationship_order, + skip=[], + ), + ) + + +def _entity_mapping(class_uri: str, id_col: str, sql: str) -> dict: + """Shape produced by the EntityGenerator's submit handler.""" + return { + "ontology_class": class_uri, + "class_name": class_uri.rsplit("#", 1)[-1], + "sql_query": sql, + "id_column": id_col, + "label_column": id_col, + "attribute_mappings": {}, + "unmapped_attributes": [], + } + + +def _relationship_mapping( + prop_uri: str, source_col: str, target_col: str, sql: str +) -> dict: + return { + "property": prop_uri, + "property_name": prop_uri.rsplit("#", 1)[-1], + "sql_query": sql, + "source_id_column": source_col, + "target_id_column": target_col, + "domain": CUSTOMER_URI, + "range_class": ORDER_URI, + } + + +# ===================================================== +# Fake sub-agent factories +# ===================================================== + + +class FakePlanner: + """Fake ``run_planner`` returning canned :class:`PlannerResult` values.""" + + def __init__(self, results: List[PlannerResult]): + self.results = list(results) + self.calls = 0 + + def __call__(self, *args: Any, **kwargs: Any) -> PlannerResult: + self.calls += 1 + if not self.results: + raise AssertionError( + f"FakePlanner ran out of canned results on call #{self.calls}" + ) + return self.results.pop(0) + + +class FakeEntityGenerator: + """Routes the call by ontology_class URI to a per-URI list of results.""" + + def __init__(self, results_by_uri: Dict[str, List[EntityGenResult]]): + self.results_by_uri = {k: list(v) for k, v in results_by_uri.items()} + self.calls: List[Tuple[str, Optional[str]]] = [] + + def __call__(self, *args: Any, **kwargs: Any) -> EntityGenResult: + ontology_class = kwargs["ontology_class"] + uri = ontology_class.get("uri", "") + hint = kwargs.get("retry_hint") + self.calls.append((uri, hint)) + queue = self.results_by_uri.get(uri, []) + if not queue: + raise 
AssertionError( + f"FakeEntityGenerator: no canned result for {uri} (call " + f"#{len(self.calls)})" + ) + return queue.pop(0) + + +class FakeRelationshipGenerator: + """Routes the call by ontology_property URI.""" + + def __init__(self, results_by_uri: Dict[str, List[RelationshipGenResult]]): + self.results_by_uri = {k: list(v) for k, v in results_by_uri.items()} + self.calls: List[Tuple[str, Optional[str]]] = [] + + def __call__(self, *args: Any, **kwargs: Any) -> RelationshipGenResult: + prop = kwargs["ontology_property"] + uri = prop.get("uri", "") + hint = kwargs.get("retry_hint") + self.calls.append((uri, hint)) + queue = self.results_by_uri.get(uri, []) + if not queue: + raise AssertionError( + f"FakeRelationshipGenerator: no canned result for {uri}" + ) + return queue.pop(0) + + +class FakeCritic: + """Routes by item_uri.""" + + def __init__(self, reports_by_uri: Dict[str, List[CriticResult]]): + self.reports_by_uri = {k: list(v) for k, v in reports_by_uri.items()} + self.calls: List[str] = [] + + def __call__(self, *args: Any, **kwargs: Any) -> CriticResult: + uri = kwargs["item_uri"] + self.calls.append(uri) + queue = self.reports_by_uri.get(uri, []) + if not queue: + # Default: PASS so tests that don't care about critic still work. 
+ return CriticResult( + success=True, + report=EvalReport( + status="PASS", + stage="semantic", + metrics={}, + failures=[], + bubble_to_planner=False, + ), + ) + return queue.pop(0) + + +class FakeDeterministicEvaluator: + """Stage-1 evaluator stub keyed by mapping uri (class or property).""" + + def __init__(self, reports_by_uri: Dict[str, List[EvalReport]]): + self.reports_by_uri = {k: list(v) for k, v in reports_by_uri.items()} + self.calls: List[str] = [] + + def for_entity(self, *args: Any, **kwargs: Any) -> EvalReport: + mapping = kwargs["mapping"] + uri = mapping.get("ontology_class", "") + return self._next(uri) + + def for_relationship(self, *args: Any, **kwargs: Any) -> EvalReport: + mapping = kwargs["mapping"] + uri = mapping.get("property", "") + return self._next(uri) + + def _next(self, uri: str) -> EvalReport: + self.calls.append(uri) + queue = self.reports_by_uri.get(uri, []) + if not queue: + return EvalReport( + status="PASS", + stage="deterministic", + metrics={}, + failures=[], + bubble_to_planner=False, + ) + return queue.pop(0) + + +# ===================================================== +# Helpers — build typical canned results +# ===================================================== + + +def _ok_entity_gen(class_uri: str, sql: Optional[str] = None) -> EntityGenResult: + id_col = { + CUSTOMER_URI: "customer_id", + ORDER_URI: "order_id", + ITEM_URI: "item_id", + }.get(class_uri, "id") + sql = sql or f"SELECT {id_col} AS ID, {id_col} AS Label FROM tbl_for_{class_uri[-3:]}" + return EntityGenResult( + success=True, + mapping=_entity_mapping(class_uri, id_col, sql), + iterations=2, + usage={"prompt_tokens": 10, "completion_tokens": 5}, + ) + + +def _ok_rel_gen(prop_uri: str) -> RelationshipGenResult: + sql = "SELECT customer_id AS source_id, order_id AS target_id FROM orders" + return RelationshipGenResult( + success=True, + mapping=_relationship_mapping(prop_uri, "customer_id", "order_id", sql), + iterations=2, + usage={"prompt_tokens": 10, 
"completion_tokens": 5}, + ) + + +def _pass_report(stage: str = "deterministic") -> EvalReport: + return EvalReport( + status="PASS", + stage=stage, + metrics={"row_count": 100}, + failures=[], + bubble_to_planner=False, + ) + + +def _fail_report( + *, + stage: str = "deterministic", + hint: str = "wrong column", + bubble: bool = False, +) -> EvalReport: + return EvalReport( + status="FAIL", + stage=stage, + metrics={"row_count": 5}, + failures=[ + EvalFailure( + kind="structural" if stage == "deterministic" else "semantic", + check="some_check", + expected=">0", + observed="0", + hint=hint, + ) + ], + bubble_to_planner=bubble, + ) + + +# ===================================================== +# Common fixtures +# ===================================================== + + +@pytest.fixture +def fake_client() -> Any: + """Lightweight stub with the ``execute_query`` method the orchestrator wraps.""" + + class _Client: + def __init__(self): + self.calls: List[str] = [] + + def execute_query(self, sql: str): + self.calls.append(sql) + # Echo three rows so ``row_count > 0`` if the real evaluator is + # invoked. (Most tests stub the evaluator and never hit this.) 
+ return [ + {"customer_id": 1, "order_id": 10}, + {"customer_id": 2, "order_id": 20}, + {"customer_id": 3, "order_id": 30}, + ] + + return _Client() + + +def _patch_sub_agents( + monkeypatch, + *, + planner: Any, + entity_gen: Any = None, + rel_gen: Any = None, + critic: Any = None, + det_eval: Optional[FakeDeterministicEvaluator] = None, +) -> None: + monkeypatch.setattr(engine_mod, "run_planner", planner) + if entity_gen is not None: + monkeypatch.setattr(engine_mod, "run_entity_generator", entity_gen) + if rel_gen is not None: + monkeypatch.setattr(engine_mod, "run_relationship_generator", rel_gen) + if critic is not None: + monkeypatch.setattr(engine_mod, "run_critic", critic) + if det_eval is not None: + monkeypatch.setattr( + engine_mod, "evaluate_entity_mapping", det_eval.for_entity + ) + monkeypatch.setattr( + engine_mod, + "evaluate_relationship_mapping", + det_eval.for_relationship, + ) + + +def _run(client: Any, **overrides) -> AgentResult: + kwargs = dict( + host="https://test", + token="t", + endpoint_name="ep", + client=client, + metadata={}, + ontology=_ontology(), + skip_semantic_critic=True, + ) + kwargs.update(overrides) + return run_agent(**kwargs) + + +# ===================================================== +# Tests +# ===================================================== + + +def test_happy_path_two_entities_one_relationship(monkeypatch, fake_client): + planner = FakePlanner( + [PlannerResult(success=True, source_model=_source_model(), iterations=1)] + ) + entity_gen = FakeEntityGenerator( + { + CUSTOMER_URI: [_ok_entity_gen(CUSTOMER_URI)], + ORDER_URI: [_ok_entity_gen(ORDER_URI)], + } + ) + rel_gen = FakeRelationshipGenerator({HAS_ORDER_URI: [_ok_rel_gen(HAS_ORDER_URI)]}) + det = FakeDeterministicEvaluator({}) # all default PASS + _patch_sub_agents( + monkeypatch, + planner=planner, + entity_gen=entity_gen, + rel_gen=rel_gen, + det_eval=det, + ) + + result = _run(fake_client) + + assert result.success is True + assert 
len(result.entity_mappings) == 2 + assert len(result.relationship_mappings) == 1 + assert {m["ontology_class"] for m in result.entity_mappings} == { + CUSTOMER_URI, + ORDER_URI, + } + # 3 mapping_run_log entries, all PASS. + assert len(result.mapping_run_log) == 3 + assert all(entry["final_status"] == "PASS" for entry in result.mapping_run_log) + # source_model serialised onto the result. + assert result.source_model is not None + assert "table_roles" in result.source_model + + +def test_planner_failure_aborts(monkeypatch, fake_client): + planner = FakePlanner( + [PlannerResult(success=False, source_model=None, error="LLM rejected tools")] + ) + _patch_sub_agents(monkeypatch, planner=planner) + + result = _run(fake_client) + + assert result.success is False + assert "LLM rejected tools" in result.error + assert result.entity_mappings == [] + assert result.relationship_mappings == [] + + +def test_generator_failure_records_item_failure_continues_run( + monkeypatch, fake_client +): + planner = FakePlanner( + [PlannerResult(success=True, source_model=_source_model(), iterations=1)] + ) + # Customer generator fails 3 times; Order succeeds. + fail = EntityGenResult(success=False, mapping=None, error="generator crashed") + entity_gen = FakeEntityGenerator( + { + CUSTOMER_URI: [fail, fail, fail], + ORDER_URI: [_ok_entity_gen(ORDER_URI)], + } + ) + rel_gen = FakeRelationshipGenerator({HAS_ORDER_URI: [_ok_rel_gen(HAS_ORDER_URI)]}) + det = FakeDeterministicEvaluator({}) + _patch_sub_agents( + monkeypatch, + planner=planner, + entity_gen=entity_gen, + rel_gen=rel_gen, + det_eval=det, + ) + + result = _run(fake_client) + + # Order entity mapped despite Customer failing. 
+ assert any(m["ontology_class"] == ORDER_URI for m in result.entity_mappings) + assert not any(m["ontology_class"] == CUSTOMER_URI for m in result.entity_mappings) + customer_log = next( + e for e in result.mapping_run_log if e["item"] == CUSTOMER_URI + ) + assert customer_log["final_status"] == "FAIL_BUDGET" + # The relationship endpoint for hasOrder requires Customer; with Customer + # missing the relationship is recorded but not mapped. + rel_log = next(e for e in result.mapping_run_log if e["item"] == HAS_ORDER_URI) + assert rel_log["final_status"] in {"FAIL_BUDGET", "PASS"} + + +def test_evaluator_fail_retry_with_hint(monkeypatch, fake_client): + planner = FakePlanner( + [PlannerResult(success=True, source_model=_source_model(), iterations=1)] + ) + entity_gen = FakeEntityGenerator( + { + CUSTOMER_URI: [ + _ok_entity_gen(CUSTOMER_URI, "SELECT bad_col AS ID FROM x"), + _ok_entity_gen(CUSTOMER_URI), + ], + ORDER_URI: [_ok_entity_gen(ORDER_URI)], + } + ) + rel_gen = FakeRelationshipGenerator({HAS_ORDER_URI: [_ok_rel_gen(HAS_ORDER_URI)]}) + # First attempt fails, second passes — non-bubble FAIL with a hint. + det = FakeDeterministicEvaluator( + { + CUSTOMER_URI: [ + _fail_report(hint="use customer_id, not bad_col", bubble=False), + _pass_report(), + ] + } + ) + _patch_sub_agents( + monkeypatch, + planner=planner, + entity_gen=entity_gen, + rel_gen=rel_gen, + det_eval=det, + ) + + result = _run(fake_client) + + assert result.success is True + customer_log = next( + e for e in result.mapping_run_log if e["item"] == CUSTOMER_URI + ) + assert customer_log["final_status"] == "PASS" + assert len(customer_log["attempts"]) == 2 + assert customer_log["attempts"][0]["stage1_status"] == "FAIL" + assert customer_log["attempts"][0]["hint"] == "use customer_id, not bad_col" + # Second EntityGenerator call must have been given the hint. 
+ customer_calls = [c for c in entity_gen.calls if c[0] == CUSTOMER_URI] + assert customer_calls[0][1] is None + assert customer_calls[1][1] == "use customer_id, not bad_col" + + +def test_bubble_to_planner_triggers_replanning(monkeypatch, fake_client): + planner = FakePlanner( + [ + PlannerResult(success=True, source_model=_source_model(), iterations=1), + PlannerResult(success=True, source_model=_source_model(), iterations=1), + ] + ) + entity_gen = FakeEntityGenerator( + { + CUSTOMER_URI: [ + _ok_entity_gen(CUSTOMER_URI), # attempt 1 (bubbles) + _ok_entity_gen(CUSTOMER_URI), # attempt 1 of replan iteration + ], + ORDER_URI: [_ok_entity_gen(ORDER_URI)], + } + ) + rel_gen = FakeRelationshipGenerator({HAS_ORDER_URI: [_ok_rel_gen(HAS_ORDER_URI)]}) + det = FakeDeterministicEvaluator( + { + CUSTOMER_URI: [ + _fail_report(hint="wrong table", bubble=True), + _pass_report(), + ] + } + ) + _patch_sub_agents( + monkeypatch, + planner=planner, + entity_gen=entity_gen, + rel_gen=rel_gen, + det_eval=det, + ) + + result = _run(fake_client) + + assert result.success is True + customer_log = next( + e for e in result.mapping_run_log if e["item"] == CUSTOMER_URI + ) + assert customer_log["final_status"] == "PASS" + # Planner was invoked twice (initial + 1 replan). + assert planner.calls == 2 + assert result.stats["planner_reinvocations"] == 1 + + +def test_planner_reinvocation_budget_exhausted(monkeypatch, fake_client): + # Budget-agnostic: 1 initial planner call + exactly the replan budget. + budget = engine_mod._PLANNER_REINVOCATION_BUDGET + planner = FakePlanner( + [ + PlannerResult(success=True, source_model=_source_model(), iterations=1) + for _ in range(budget + 1) + ] + ) + bubble = _fail_report(hint="wrong table", bubble=True) + # Customer entity bubbles on every attempt forever. 
+ entity_gen = FakeEntityGenerator( + { + CUSTOMER_URI: [_ok_entity_gen(CUSTOMER_URI) for _ in range(40)], + ORDER_URI: [_ok_entity_gen(ORDER_URI)], + } + ) + rel_gen = FakeRelationshipGenerator({HAS_ORDER_URI: [_ok_rel_gen(HAS_ORDER_URI)]}) + det = FakeDeterministicEvaluator( + { + CUSTOMER_URI: [bubble for _ in range(40)], + } + ) + _patch_sub_agents( + monkeypatch, + planner=planner, + entity_gen=entity_gen, + rel_gen=rel_gen, + det_eval=det, + ) + + result = _run(fake_client) + + customer_log = next( + e for e in result.mapping_run_log if e["item"] == CUSTOMER_URI + ) + assert customer_log["final_status"] == "FAIL_BUBBLE" + # 1 initial planner call + exactly `budget` replans. + assert planner.calls == budget + 1 + assert result.stats["planner_reinvocations"] == budget + # Other items still attempted; Order succeeded. + assert any(m["ontology_class"] == ORDER_URI for m in result.entity_mappings) + + +def test_retry_budget_exhausted_marks_item_fail_budget(monkeypatch, fake_client): + planner = FakePlanner( + [PlannerResult(success=True, source_model=_source_model(), iterations=1)] + ) + attempts = engine_mod._PER_ITEM_GENERATOR_ATTEMPTS + entity_gen = FakeEntityGenerator( + { + CUSTOMER_URI: [_ok_entity_gen(CUSTOMER_URI) for _ in range(attempts + 2)], + ORDER_URI: [_ok_entity_gen(ORDER_URI)], + } + ) + rel_gen = FakeRelationshipGenerator({HAS_ORDER_URI: [_ok_rel_gen(HAS_ORDER_URI)]}) + det = FakeDeterministicEvaluator( + { + CUSTOMER_URI: [ + _fail_report(hint=f"hint-{i}", bubble=False) + for i in range(attempts) + ], + } + ) + _patch_sub_agents( + monkeypatch, + planner=planner, + entity_gen=entity_gen, + rel_gen=rel_gen, + det_eval=det, + ) + + result = _run(fake_client) + + customer_log = next( + e for e in result.mapping_run_log if e["item"] == CUSTOMER_URI + ) + assert customer_log["final_status"] == "FAIL_BUDGET" + assert len(customer_log["attempts"]) == attempts + assert all(a["stage1_status"] == "FAIL" for a in customer_log["attempts"]) + assert 
planner.calls == 1 + # Order still mapped. + assert any(m["ontology_class"] == ORDER_URI for m in result.entity_mappings) + + +def test_critic_pass_full_pipeline(monkeypatch, fake_client): + planner = FakePlanner( + [PlannerResult(success=True, source_model=_source_model(), iterations=1)] + ) + entity_gen = FakeEntityGenerator( + { + CUSTOMER_URI: [_ok_entity_gen(CUSTOMER_URI)], + ORDER_URI: [_ok_entity_gen(ORDER_URI)], + } + ) + rel_gen = FakeRelationshipGenerator({HAS_ORDER_URI: [_ok_rel_gen(HAS_ORDER_URI)]}) + det = FakeDeterministicEvaluator({}) # default PASS + critic = FakeCritic( + { + CUSTOMER_URI: [CriticResult(success=True, report=_pass_report("semantic"))], + ORDER_URI: [CriticResult(success=True, report=_pass_report("semantic"))], + HAS_ORDER_URI: [ + CriticResult(success=True, report=_pass_report("semantic")) + ], + } + ) + _patch_sub_agents( + monkeypatch, + planner=planner, + entity_gen=entity_gen, + rel_gen=rel_gen, + critic=critic, + det_eval=det, + ) + + result = _run(fake_client, skip_semantic_critic=False) + + assert result.success is True + customer_log = next( + e for e in result.mapping_run_log if e["item"] == CUSTOMER_URI + ) + last_attempt = customer_log["attempts"][-1] + assert last_attempt["stage1_status"] == "PASS" + assert last_attempt["critic_status"] == "PASS" + # Critic was actually called. 
+ assert CUSTOMER_URI in critic.calls + + +def test_critic_fail_with_bubble(monkeypatch, fake_client): + planner = FakePlanner( + [ + PlannerResult(success=True, source_model=_source_model(), iterations=1), + PlannerResult(success=True, source_model=_source_model(), iterations=1), + ] + ) + entity_gen = FakeEntityGenerator( + { + CUSTOMER_URI: [ + _ok_entity_gen(CUSTOMER_URI), # initial attempt — critic bubbles + _ok_entity_gen(CUSTOMER_URI), # post-replan attempt — passes + ], + ORDER_URI: [_ok_entity_gen(ORDER_URI)], + } + ) + rel_gen = FakeRelationshipGenerator({HAS_ORDER_URI: [_ok_rel_gen(HAS_ORDER_URI)]}) + det = FakeDeterministicEvaluator({}) # default PASS on stage 1 + critic = FakeCritic( + { + CUSTOMER_URI: [ + CriticResult( + success=True, + report=_fail_report( + stage="semantic", hint="wrong table", bubble=True + ), + ), + CriticResult(success=True, report=_pass_report("semantic")), + ], + ORDER_URI: [CriticResult(success=True, report=_pass_report("semantic"))], + HAS_ORDER_URI: [ + CriticResult(success=True, report=_pass_report("semantic")) + ], + } + ) + _patch_sub_agents( + monkeypatch, + planner=planner, + entity_gen=entity_gen, + rel_gen=rel_gen, + critic=critic, + det_eval=det, + ) + + result = _run(fake_client, skip_semantic_critic=False) + + assert result.success is True + assert planner.calls == 2 + customer_log = next( + e for e in result.mapping_run_log if e["item"] == CUSTOMER_URI + ) + assert customer_log["final_status"] == "PASS" + + +def test_skip_semantic_critic_true_skips_critic(monkeypatch, fake_client): + planner = FakePlanner( + [PlannerResult(success=True, source_model=_source_model(), iterations=1)] + ) + entity_gen = FakeEntityGenerator( + { + CUSTOMER_URI: [_ok_entity_gen(CUSTOMER_URI)], + ORDER_URI: [_ok_entity_gen(ORDER_URI)], + } + ) + rel_gen = FakeRelationshipGenerator({HAS_ORDER_URI: [_ok_rel_gen(HAS_ORDER_URI)]}) + det = FakeDeterministicEvaluator({}) + critic = FakeCritic({}) # would default-PASS if called + 
_patch_sub_agents( + monkeypatch, + planner=planner, + entity_gen=entity_gen, + rel_gen=rel_gen, + critic=critic, + det_eval=det, + ) + + result = _run(fake_client, skip_semantic_critic=True) + + assert result.success is True + # Critic was never called. + assert critic.calls == [] + # Every attempt records critic_status="skipped". + for entry in result.mapping_run_log: + for attempt in entry["attempts"]: + assert attempt["critic_status"] == "skipped" + + +def test_preseeded_entity_skipped(monkeypatch, fake_client): + planner = FakePlanner( + [PlannerResult(success=True, source_model=_source_model(), iterations=1)] + ) + entity_gen = FakeEntityGenerator( + { + # Customer must NOT be generated — it's pre-seeded. + ORDER_URI: [_ok_entity_gen(ORDER_URI)], + } + ) + rel_gen = FakeRelationshipGenerator({HAS_ORDER_URI: [_ok_rel_gen(HAS_ORDER_URI)]}) + det = FakeDeterministicEvaluator({}) + _patch_sub_agents( + monkeypatch, + planner=planner, + entity_gen=entity_gen, + rel_gen=rel_gen, + det_eval=det, + ) + + pre = [ + _entity_mapping( + CUSTOMER_URI, + "customer_id", + "SELECT customer_id AS ID FROM cat.sch.customers", + ) + ] + + result = _run(fake_client, entity_mappings=pre) + + assert result.success is True + customer_log = next( + e for e in result.mapping_run_log if e["item"] == CUSTOMER_URI + ) + assert customer_log["final_status"] == "PRESEEDED" + assert customer_log["attempts"] == [] + # The pre-seeded mapping is still in the result list. + assert any(m["ontology_class"] == CUSTOMER_URI for m in result.entity_mappings) + # EntityGenerator never called for Customer. + assert not any(c[0] == CUSTOMER_URI for c in entity_gen.calls) + + +def test_skip_list_is_advisory_item_still_mapped(monkeypatch, fake_client): + # Coverage is engine-enforced: the Planner's skip[] is ADVISORY only and + # must NOT remove an ontology class from the mapping run. Even when the + # Planner asks to skip Order, the engine still maps it. 
+ sm = _source_model() + sm.mapping_plan.skip.append(SkipItem(item=ORDER_URI, reason="planner unsure")) + sm.mapping_plan.entity_order = [CUSTOMER_URI, ORDER_URI] + + planner = FakePlanner([PlannerResult(success=True, source_model=sm, iterations=1)]) + entity_gen = FakeEntityGenerator( + { + CUSTOMER_URI: [_ok_entity_gen(CUSTOMER_URI)], + ORDER_URI: [_ok_entity_gen(ORDER_URI)], # MUST still be attempted + } + ) + rel_gen = FakeRelationshipGenerator({HAS_ORDER_URI: [_ok_rel_gen(HAS_ORDER_URI)]}) + det = FakeDeterministicEvaluator({}) + _patch_sub_agents( + monkeypatch, + planner=planner, + entity_gen=entity_gen, + rel_gen=rel_gen, + det_eval=det, + ) + + result = _run(fake_client) + + order_log = next(e for e in result.mapping_run_log if e["item"] == ORDER_URI) + assert order_log["final_status"] == "PASS" + assert any(c[0] == ORDER_URI for c in entity_gen.calls) + assert any(m["ontology_class"] == ORDER_URI for m in result.entity_mappings) + + +def test_on_step_pct_monotonic(monkeypatch, fake_client): + planner = FakePlanner( + [PlannerResult(success=True, source_model=_source_model(), iterations=1)] + ) + entity_gen = FakeEntityGenerator( + { + CUSTOMER_URI: [_ok_entity_gen(CUSTOMER_URI)], + ORDER_URI: [_ok_entity_gen(ORDER_URI)], + } + ) + rel_gen = FakeRelationshipGenerator({HAS_ORDER_URI: [_ok_rel_gen(HAS_ORDER_URI)]}) + det = FakeDeterministicEvaluator({}) + _patch_sub_agents( + monkeypatch, + planner=planner, + entity_gen=entity_gen, + rel_gen=rel_gen, + det_eval=det, + ) + + pcts: List[int] = [] + + def on_step(msg: str, pct: int) -> None: + pcts.append(pct) + + result = _run(fake_client, on_step=on_step) + + assert result.success is True + assert pcts, "expected at least one on_step call" + # Monotonic non-decreasing — captures the documented design contract + # (we only replan on bubble, which this test does not trigger). 
+ for prev, curr in zip(pcts, pcts[1:]): + assert curr >= prev, f"pct went backwards: {prev} -> {curr}" + # First call planning at low pct, last call completion at 100. + assert pcts[0] <= 5 + assert pcts[-1] == 100 + + +def test_id_universe_cache_used_across_relationships(monkeypatch, fake_client): + # Bare ontology with no attributes so the real deterministic evaluator + # doesn't fire on unmapped_attribute_pct. + bare_ontology = { + "entities": [ + {"uri": CUSTOMER_URI, "name": "Customer", "label": "Customer", "attributes": []}, + {"uri": ORDER_URI, "name": "Order", "label": "Order", "attributes": []}, + {"uri": ITEM_URI, "name": "Item", "label": "Item", "attributes": []}, + ], + "relationships": [ + { + "uri": HAS_ORDER_URI, + "name": "hasOrder", + "label": "hasOrder", + "domain": CUSTOMER_URI, + "range": ORDER_URI, + }, + { + "uri": CONTAINS_URI, + "name": "contains", + "label": "contains", + "domain": ORDER_URI, + "range": ITEM_URI, + }, + ], + } + + # Distinct, recognisable SQL strings per entity — used both as cache keys + # and as a discriminator for the CountingClient routing below. 
+ customer_sql = "SELECT customer_id AS ID, customer_id AS Label FROM cat.sch.customers" + order_sql = "SELECT order_id AS ID, order_id AS Label FROM cat.sch.orders" + item_sql = "SELECT item_id AS ID, item_id AS Label FROM cat.sch.items" + + planner = FakePlanner( + [ + PlannerResult( + success=True, source_model=_source_model(with_items=True), iterations=1 + ) + ] + ) + entity_gen = FakeEntityGenerator( + { + CUSTOMER_URI: [ + EntityGenResult( + success=True, + mapping=_entity_mapping(CUSTOMER_URI, "customer_id", customer_sql), + iterations=1, + ) + ], + ORDER_URI: [ + EntityGenResult( + success=True, + mapping=_entity_mapping(ORDER_URI, "order_id", order_sql), + iterations=1, + ) + ], + ITEM_URI: [ + EntityGenResult( + success=True, + mapping=_entity_mapping(ITEM_URI, "item_id", item_sql), + iterations=1, + ) + ], + } + ) + + # Relationship edges return rows whose source/target values fall inside + # the entity universes so the deterministic evaluator passes. + has_order_sql = "SELECT customer_id AS source_id, order_id AS target_id FROM has_order_edge" + contains_sql = "SELECT order_id AS source_id, item_id AS target_id FROM contains_edge" + + rel_gen = FakeRelationshipGenerator( + { + HAS_ORDER_URI: [ + RelationshipGenResult( + success=True, + mapping={ + "property": HAS_ORDER_URI, + "property_name": "hasOrder", + "sql_query": has_order_sql, + "source_id_column": "source_id", + "target_id_column": "target_id", + }, + iterations=1, + ) + ], + CONTAINS_URI: [ + RelationshipGenResult( + success=True, + mapping={ + "property": CONTAINS_URI, + "property_name": "contains", + "sql_query": contains_sql, + "source_id_column": "source_id", + "target_id_column": "target_id", + }, + iterations=1, + ) + ], + } + ) + # Use the REAL deterministic evaluators here so the cache codepath is + # actually exercised against execute_sql_fn. 
+ _patch_sub_agents( + monkeypatch, + planner=planner, + entity_gen=entity_gen, + rel_gen=rel_gen, + # No det_eval override -> real evaluators used. + ) + + class CountingClient: + """Records every ``execute_query`` call so we can count cache hits.""" + + def __init__(self): + self.sql_calls: List[str] = [] + + def execute_query(self, sql: str): + self.sql_calls.append(sql) + # Entity-universe queries return rows keyed by the entity's id_column. + if sql == customer_sql: + return [{"customer_id": i, "ID": i} for i in range(1, 4)] + if sql == order_sql: + return [{"order_id": i, "ID": i} for i in range(1, 4)] + if sql == item_sql: + return [{"item_id": i, "ID": i} for i in range(1, 4)] + # Edge SQLs: values must overlap with the entity universes so + # dangling_*_pct stays low and the report PASSes. + if sql == has_order_sql: + return [ + {"source_id": i, "target_id": i, "customer_id": i, "order_id": i} + for i in range(1, 4) + ] + if sql == contains_sql: + return [ + {"source_id": i, "target_id": i, "order_id": i, "item_id": i} + for i in range(1, 4) + ] + return [] + + client = CountingClient() + result = _run(client, ontology=bare_ontology) + + assert len(result.entity_mappings) == 3, result.mapping_run_log + assert len(result.relationship_mappings) == 2, result.mapping_run_log + + # Each unique entity SQL is run by the entity evaluator (1) + at most + # ONCE more from the first relationship that references it (cached for + # subsequent relationships). Without the cache, each entity SQL would + # fire from EVERY relationship that touches it — order_sql in + # particular would run 1 (entity) + 2 (hasOrder + contains) = 3 times. 
+ for sql in (customer_sql, order_sql, item_sql): + count = sum(1 for c in client.sql_calls if c == sql) + assert count <= 2, ( + f"entity SQL ran {count} times — id_universe_cache failed:\n{sql}" + ) diff --git a/tests/agents/agent_mapping_pge/test_entity_generator.py b/tests/agents/agent_mapping_pge/test_entity_generator.py new file mode 100644 index 00000000..ed09cb03 --- /dev/null +++ b/tests/agents/agent_mapping_pge/test_entity_generator.py @@ -0,0 +1,662 @@ +"""Tests for the mapping-PGE EntityGenerator agent (Sprint 4). + +The Generator is a narrow tool-calling ReAct loop terminated by +``submit_entity_mapping``. These tests exercise the loop's control flow with +a *fake LLM* — a stub that replaces ``call_serving_endpoint`` at module +level and returns canned tool-call responses on a per-call basis. + +No real HTTP, no real Databricks, no MLflow tracing. + +What we DO exercise: +* Termination on a single submit call. +* Multi-step trajectory (execute_sql → submit). +* ``unmapped_attributes`` round-trips through the tool to the result. +* Text-only output is treated as failure (no terminal call). +* Iteration-budget exhaustion is treated as failure. +* ``retry_hint`` surfaces inside the user message. +* Step recording: every tool call produces both tool_call and tool_result + steps in the right order. 
+""" + +import json +from typing import Any, Callable, Dict, List, Optional + +import pytest + +from agents.agent_mapping_pge.generators import entity as entity_mod +from agents.agent_mapping_pge.generators.entity import ( + EntityGenResult, + EntityGenStep, + run_entity_generator, +) + + +# ===================================================== +# Fake LLM scaffolding (mirrors test_planner.py) +# ===================================================== + + +_CLASS_URI = "http://ex.org/maternity#Mother" + + +def _make_tool_call(name: str, arguments: dict, *, tc_id: str = "tc1") -> dict: + return { + "id": tc_id, + "type": "function", + "function": {"name": name, "arguments": json.dumps(arguments)}, + } + + +def _llm_response( + *, + tool_calls: Optional[List[dict]] = None, + content: Optional[str] = None, + finish_reason: str = "tool_calls", + usage: Optional[Dict[str, int]] = None, +) -> dict: + message: Dict[str, Any] = {"role": "assistant"} + if tool_calls: + message["tool_calls"] = tool_calls + if content is not None: + message["content"] = content + return { + "choices": [{"finish_reason": finish_reason, "message": message}], + "usage": usage or {"prompt_tokens": 10, "completion_tokens": 5}, + } + + +class FakeLLM: + def __init__(self, responses: List[dict]): + self.responses = list(responses) + self.calls = 0 + # Capture the messages list as observed on each call, so tests can + # introspect what the agent put into the prompt. + self.last_messages: Optional[List[dict]] = None + self.first_messages: Optional[List[dict]] = None + + def __call__(self, *args, **kwargs) -> dict: + self.calls += 1 + # ``call_serving_endpoint(host, token, endpoint, messages, ...)`` — + # the messages list is positional arg #3 (zero-indexed). Capture + # defensively in case the call site changes to kwargs. 
+ msgs: Optional[List[dict]] = None + if len(args) >= 4 and isinstance(args[3], list): + msgs = args[3] + elif "messages" in kwargs: + msgs = kwargs["messages"] + if msgs is not None: + # snapshot so later mutations by the loop do not affect what we + # captured. + snapshot = [dict(m) for m in msgs] + if self.first_messages is None: + self.first_messages = snapshot + self.last_messages = snapshot + + if not self.responses: + raise AssertionError( + f"FakeLLM: ran out of canned responses on call #{self.calls}" + ) + return self.responses.pop(0) + + +class CyclingFakeLLM: + """Like FakeLLM but cycles through a fixed list forever.""" + + def __init__(self, responses: List[dict]): + self.responses = list(responses) + self.calls = 0 + + def __call__(self, *args, **kwargs) -> dict: + resp = self.responses[self.calls % len(self.responses)] + self.calls += 1 + return resp + + +@pytest.fixture +def no_sleep(monkeypatch): + """Neutralise the 3-second inter-iteration delay so tests run fast.""" + monkeypatch.setattr(entity_mod.time, "sleep", lambda *_a, **_k: None) + + +def _patch_llm(monkeypatch, fake: Callable[..., dict]) -> None: + monkeypatch.setattr(entity_mod, "call_serving_endpoint", fake) + + +# ===================================================== +# Fixtures +# ===================================================== + + +def _ontology_class() -> dict: + return { + "uri": _CLASS_URI, + "label": "Mother", + "name": "Mother", + "comment": "A mother in the maternity trust dataset.", + "attributes": [ + {"name": "nhsNumber", "type": "string"}, + {"name": "dateOfBirth", "type": "date"}, + {"name": "ethnicity", "type": "string"}, + ], + } + + +def _source_model_slice() -> dict: + return { + "candidate_tables": [ + { + "table": "cat.sch.mothers", + "confidence": 0.92, + "reason": "row per NHS — mother demographics", + } + ], + "canonical_id": { + "canonical_column_per_table": {"cat.sch.mothers": "nhs_number"}, + "format_note": "10-digit NHS", + }, + "relevant_joins": [], + } 
+
+
+def _valid_submit_args(  # canned arguments for a ``submit_entity_mapping`` tool call
+    *,
+    unmapped: Optional[list] = None,
+) -> dict:
+    args: Dict[str, Any] = {
+        "class_uri": _CLASS_URI,
+        "class_name": "Mother",
+        "sql_query": (
+            "SELECT nhs_number AS ID, nhs_number AS Label, nhs_number, dob, ethnicity "
+            "FROM cat.sch.mothers WHERE nhs_number IS NOT NULL"
+        ),
+        "id_column": "nhs_number",
+        "label_column": "nhs_number",
+        "attribute_mappings": {
+            "nhsNumber": "nhs_number",
+            "dateOfBirth": "dob",
+            "ethnicity": "ethnicity",
+        },
+    }
+    if unmapped is not None:  # optional key: left out entirely unless the caller supplies it
+        args["unmapped_attributes"] = unmapped
+    return args
+
+
+# =====================================================
+# 1. Single-shot submit terminates immediately
+# =====================================================
+
+
+def test_terminates_on_submit(monkeypatch, no_sleep):
+    """First LLM turn submits a valid mapping → success, iterations=1."""
+    fake = FakeLLM(
+        [
+            _llm_response(
+                tool_calls=[
+                    _make_tool_call("submit_entity_mapping", _valid_submit_args())
+                ]
+            )
+        ]
+    )
+    _patch_llm(monkeypatch, fake)
+
+    result = run_entity_generator(
+        host="https://x",
+        token="t",
+        endpoint_name="ep",
+        client=None,
+        ontology_class=_ontology_class(),
+        source_model_slice=_source_model_slice(),
+    )
+
+    assert isinstance(result, EntityGenResult)
+    assert result.success is True
+    assert result.iterations == 1
+    assert result.mapping is not None
+    assert result.mapping["ontology_class"] == _CLASS_URI
+    assert result.mapping["id_column"] == "nhs_number"
+    assert result.error == ""
+    step_kinds = [s.step_type for s in result.steps]  # ordered step types recorded by the run
+    assert step_kinds == ["tool_call", "tool_result"]
+    assert result.steps[0].tool_name == "submit_entity_mapping"
+
+
+# =====================================================
+# 2.
execute_sql validation, then submit
+# =====================================================
+
+
+def test_validates_sql_then_submits(monkeypatch, no_sleep):
+    """execute_sql → submit_entity_mapping → success, iterations=2."""
+    fake = FakeLLM(
+        [
+            _llm_response(
+                tool_calls=[
+                    _make_tool_call(
+                        "execute_sql",
+                        {
+                            "sql": (
+                                "SELECT nhs_number AS ID, nhs_number AS Label, "
+                                "nhs_number, dob, ethnicity FROM cat.sch.mothers "
+                                "WHERE nhs_number IS NOT NULL"
+                            )
+                        },
+                        tc_id="a",
+                    )
+                ]
+            ),
+            _llm_response(
+                tool_calls=[
+                    _make_tool_call(
+                        "submit_entity_mapping", _valid_submit_args(), tc_id="b"
+                    )
+                ]
+            ),
+        ]
+    )
+    _patch_llm(monkeypatch, fake)
+
+    class FakeClient:  # stub client: every query returns the same single row
+        def execute_query(self, sql):
+            return [
+                {
+                    "ID": "1234567890",
+                    "Label": "1234567890",
+                    "nhs_number": "1234567890",
+                    "dob": "1990-01-01",
+                    "ethnicity": "white",
+                }
+            ]
+
+    result = run_entity_generator(
+        host="https://x",
+        token="t",
+        endpoint_name="ep",
+        client=FakeClient(),
+        ontology_class=_ontology_class(),
+        source_model_slice=_source_model_slice(),
+    )
+
+    assert result.success is True
+    assert result.iterations == 2  # one iteration per canned LLM response
+    assert result.mapping is not None
+    # Sequence: tool_call(execute_sql), tool_result(execute_sql),
+    # tool_call(submit), tool_result(submit) — 4 steps.
+    assert len(result.steps) == 4
+    assert [s.tool_name for s in result.steps] == [
+        "execute_sql",
+        "execute_sql",
+        "submit_entity_mapping",
+        "submit_entity_mapping",
+    ]
+
+
+# =====================================================
+# 3.
unmapped_attributes round-trips
+# =====================================================
+
+
+def test_unmapped_attributes_round_trip(monkeypatch, no_sleep):
+    """Submit with ``unmapped_attributes`` — the field must appear on the
+    resulting mapping dict in the same (normalised) shape."""
+    unmapped_payload = [
+        {"name": "ethnicity", "reason": "no ethnicity column in this table"}
+    ]
+    args = _valid_submit_args(unmapped=unmapped_payload)
+    # Strip ethnicity from attribute_mappings to make the example coherent.
+    args["attribute_mappings"].pop("ethnicity", None)
+
+    fake = FakeLLM(
+        [
+            _llm_response(
+                tool_calls=[_make_tool_call("submit_entity_mapping", args)]
+            )
+        ]
+    )
+    _patch_llm(monkeypatch, fake)
+
+    result = run_entity_generator(
+        host="https://x",
+        token="t",
+        endpoint_name="ep",
+        client=None,
+        ontology_class=_ontology_class(),
+        source_model_slice=_source_model_slice(),
+    )
+
+    assert result.success is True
+    assert result.mapping is not None
+    assert result.mapping["unmapped_attributes"] == [
+        {"name": "ethnicity", "reason": "no ethnicity column in this table"}
+    ]
+    # Plain-string form is also documented; make sure it survives too.
+    fake2 = FakeLLM(
+        [
+            _llm_response(
+                tool_calls=[
+                    _make_tool_call(
+                        "submit_entity_mapping",
+                        _valid_submit_args(unmapped=["ethnicity"]),
+                    )
+                ]
+            )
+        ]
+    )
+    _patch_llm(monkeypatch, fake2)
+    result2 = run_entity_generator(
+        host="https://x",
+        token="t",
+        endpoint_name="ep",
+        client=None,
+        ontology_class=_ontology_class(),
+        source_model_slice=_source_model_slice(),
+    )
+    assert result2.success is True
+    assert result2.mapping["unmapped_attributes"] == ["ethnicity"]  # bare-name list comes back unchanged
+
+
+# =====================================================
+# 4. Text without terminal call → failure
+# =====================================================
+
+
+def test_text_without_terminal_fails(monkeypatch, no_sleep):
+    """A plain-text response is treated as failure — the Generator must
+    terminate via submit_entity_mapping.
+    """
+    fake = FakeLLM(
+        [_llm_response(content="I am thinking…", finish_reason="stop")]
+    )
+    _patch_llm(monkeypatch, fake)
+
+    result = run_entity_generator(
+        host="https://x",
+        token="t",
+        endpoint_name="ep",
+        client=None,
+        ontology_class=_ontology_class(),
+        source_model_slice=_source_model_slice(),
+    )
+
+    assert result.success is False
+    assert result.iterations == 1
+    assert result.mapping is None
+    assert "without submitting mapping" in result.error
+    assert any(s.step_type == "output" for s in result.steps)  # the text turn is recorded as an "output" step
+
+
+# =====================================================
+# 5. Iteration-budget exhaustion → failure
+# =====================================================
+
+
+def test_exhausts_iteration_budget(monkeypatch, no_sleep):
+    """Endless sample_table calls with max_iterations=3 → fail with
+    ``iteration budget`` and three iterations of steps recorded."""
+    fake = CyclingFakeLLM(  # cycles forever and never submits; the test expects max_iterations to end the run
+        [
+            _llm_response(
+                tool_calls=[
+                    _make_tool_call(
+                        "sample_table",
+                        {"full_name": "cat.sch.mothers"},
+                        tc_id="a",
+                    )
+                ]
+            )
+        ]
+    )
+    _patch_llm(monkeypatch, fake)
+
+    class FakeClient:  # stub client: every query returns one fixed row
+        def execute_query(self, sql):
+            return [{"nhs_number": "1234567890"}]
+
+    result = run_entity_generator(
+        host="https://x",
+        token="t",
+        endpoint_name="ep",
+        client=FakeClient(),
+        ontology_class=_ontology_class(),
+        source_model_slice=_source_model_slice(),
+        max_iterations=3,
+    )
+
+    assert result.success is False
+    assert result.iterations == 3
+    assert result.mapping is None
+    assert "iteration budget" in result.error
+    # 3 iterations × (tool_call + tool_result) = 6 steps.
+    assert len(result.steps) == 6
+
+
+# =====================================================
+# 6.
retry_hint surfaces in the user prompt
+# =====================================================
+
+
+def test_retry_hint_surfaces_in_user_prompt(monkeypatch, no_sleep):
+    """If ``retry_hint`` is provided, the FIRST LLM call's user message must
+    contain the hint verbatim."""
+    fake = FakeLLM(
+        [
+            _llm_response(
+                tool_calls=[
+                    _make_tool_call("submit_entity_mapping", _valid_submit_args())
+                ]
+            )
+        ]
+    )
+    _patch_llm(monkeypatch, fake)
+
+    result = run_entity_generator(
+        host="https://x",
+        token="t",
+        endpoint_name="ep",
+        client=None,
+        ontology_class=_ontology_class(),
+        source_model_slice=_source_model_slice(),
+        retry_hint="Use NHS column, not patient_id.",
+    )
+
+    assert result.success is True
+    assert fake.first_messages is not None
+    # messages[0] is system, messages[1] is user.
+    assert fake.first_messages[0]["role"] == "system"
+    assert fake.first_messages[1]["role"] == "user"
+    user_content = fake.first_messages[1]["content"]
+    assert "Use NHS column, not patient_id." in user_content
+    # RETRY HINT label is present so the LLM understands its provenance.
+    assert "RETRY HINT" in user_content
+
+
+def test_system_prompt_treats_canonical_value_as_sql_expression(monkeypatch, no_sleep):
+    """canonical_column_per_table values may be SQL expressions (e.g.
+    ``regexp_extract(...)``) not just bare columns. The Generator must use
+    them verbatim. Live V1.1 smoke surfaced 100% dangling on cross-trust
+    entities whose canonical IDs needed regex normalization to a common
+    format."""
+    fake = FakeLLM(
+        [
+            _llm_response(
+                tool_calls=[
+                    _make_tool_call("submit_entity_mapping", _valid_submit_args())
+                ]
+            )
+        ]
+    )
+    _patch_llm(monkeypatch, fake)
+
+    run_entity_generator(
+        host="https://x",
+        token="t",
+        endpoint_name="ep",
+        client=None,
+        ontology_class=_ontology_class(),
+        source_model_slice=_source_model_slice(),
+    )
+
+    system_content = fake.first_messages[0]["content"]  # first message of the first LLM call
+    assert "SQL EXPRESSION" in system_content
+    assert "regexp_extract" in system_content
+    assert "verbatim" in system_content
+    assert "do NOT rewrite" in system_content
+
+
+def test_system_prompt_mandates_union_for_cross_source(monkeypatch, no_sleep):
+    """When canonical_id.canonical_column_per_table lists 2+ tables, the
+    Generator MUST UNION across all of them. Single-trust selection on a
+    cross-source class makes relationship dangling 100% — the failure mode
+    surfaced by the live V1.1 smoke."""
+    fake = FakeLLM(
+        [
+            _llm_response(
+                tool_calls=[
+                    _make_tool_call("submit_entity_mapping", _valid_submit_args())
+                ]
+            )
+        ]
+    )
+    _patch_llm(monkeypatch, fake)
+
+    run_entity_generator(
+        host="https://x",
+        token="t",
+        endpoint_name="ep",
+        client=None,
+        ontology_class=_ontology_class(),
+        source_model_slice=_source_model_slice(),
+    )
+
+    system_content = fake.first_messages[0]["content"]  # first message of the first LLM call
+    assert "SINGLE-SOURCE vs CROSS-SOURCE" in system_content
+    assert "UNION ALL" in system_content
+    assert "TWO OR MORE tables" in system_content
+    # Anti-pattern: picking one trust is called out by name.
+    assert "Picking just one" in system_content or "missing" in system_content  # NOTE(review): the "missing" alternative is a very loose match — confirm intent
+
+
+# =====================================================
+# 7.
Step recording invariants
+# =====================================================
+
+
+def test_wrong_class_uri_submission_does_not_terminate(monkeypatch, no_sleep):
+    """A submit_entity_mapping call with a class_uri that doesn't match the
+    requested one must NOT terminate the loop. The Generator must keep going
+    so a follow-up submit (with the correct URI) can succeed, and the LLM
+    must see a corrective tool message describing the mismatch.
+    """
+    requested_uri = _CLASS_URI
+    other_uri = "http://ex.org/maternity#Baby"
+
+    wrong_args = _valid_submit_args()  # fresh dict per call, so mutating it cannot affect turn 2
+    wrong_args["class_uri"] = other_uri
+    wrong_args["class_name"] = "Baby"
+
+    fake = FakeLLM(
+        [
+            # Turn 1: submit with the WRONG class_uri — must NOT terminate.
+            _llm_response(
+                tool_calls=[
+                    _make_tool_call(
+                        "submit_entity_mapping", wrong_args, tc_id="wrong"
+                    )
+                ]
+            ),
+            # Turn 2: submit with the correct class_uri — should terminate.
+            _llm_response(
+                tool_calls=[
+                    _make_tool_call(
+                        "submit_entity_mapping",
+                        _valid_submit_args(),
+                        tc_id="right",
+                    )
+                ]
+            ),
+        ]
+    )
+    _patch_llm(monkeypatch, fake)
+
+    result = run_entity_generator(
+        host="https://x",
+        token="t",
+        endpoint_name="ep",
+        client=None,
+        ontology_class=_ontology_class(),
+        source_model_slice=_source_model_slice(),
+    )
+
+    assert result.success is True
+    assert result.iterations == 2
+    assert result.mapping is not None
+    assert result.mapping["ontology_class"] == requested_uri
+
+    # The LLM's second call must have seen a corrective tool message
+    # describing the mismatch, surfaced through ``messages``.
+    assert fake.last_messages is not None
+    tool_messages = [m for m in fake.last_messages if m.get("role") == "tool"]
+    assert tool_messages, "expected at least one tool message on the 2nd call"
+    corrective = tool_messages[-1]  # most recent tool message captured on the final LLM call
+    corrective_content = corrective.get("content", "")
+    assert other_uri in corrective_content
+    assert requested_uri in corrective_content
+    assert "does not match" in corrective_content
+    # Sanity: the corrective payload is a JSON error (not the original
+    # success=True response).
+    parsed = json.loads(corrective_content)
+    assert parsed.get("success") is False
+    assert "error" in parsed
+
+
+def test_records_steps(monkeypatch, no_sleep):
+    """Every tool-calling iteration produces exactly one ``tool_call`` step
+    immediately followed by one ``tool_result`` step with the same tool_name.
+    """
+    fake = FakeLLM(
+        [
+            _llm_response(
+                tool_calls=[
+                    _make_tool_call(
+                        "sample_table",
+                        {"full_name": "cat.sch.mothers"},
+                        tc_id="a",
+                    )
+                ]
+            ),
+            _llm_response(
+                tool_calls=[
+                    _make_tool_call(
+                        "submit_entity_mapping", _valid_submit_args(), tc_id="b"
+                    )
+                ]
+            ),
+        ]
+    )
+    _patch_llm(monkeypatch, fake)
+
+    class FakeClient:  # stub client: every query returns one fixed row
+        def execute_query(self, sql):
+            return [{"nhs_number": "1234567890"}]
+
+    result = run_entity_generator(
+        host="https://x",
+        token="t",
+        endpoint_name="ep",
+        client=FakeClient(),
+        ontology_class=_ontology_class(),
+        source_model_slice=_source_model_slice(),
+    )
+
+    assert result.success is True
+    assert len(result.steps) % 2 == 0
+    for i in range(0, len(result.steps), 2):  # walk the steps as (call, result) pairs
+        call_step = result.steps[i]
+        result_step = result.steps[i + 1]
+        assert call_step.step_type == "tool_call"
+        assert result_step.step_type == "tool_result"
+        assert call_step.tool_name == result_step.tool_name
+        assert call_step.content != ""
+        assert result_step.content != ""
+        assert isinstance(call_step, EntityGenStep)
+        assert isinstance(result_step, EntityGenStep)
diff --git a/tests/agents/agent_mapping_pge/test_planner.py
b/tests/agents/agent_mapping_pge/test_planner.py
new file mode 100644
index 00000000..7af091b9
--- /dev/null
+++ b/tests/agents/agent_mapping_pge/test_planner.py
@@ -0,0 +1,519 @@
+"""Tests for the mapping-PGE Planner agent (Sprint 3).
+
+The Planner is a tool-calling ReAct loop terminated by ``submit_source_model``.
+These tests exercise the loop's control flow with a *fake LLM* — a stub that
+replaces ``call_serving_endpoint`` at module level and returns canned tool-
+call responses on a per-call basis.
+
+No real HTTP, no real Databricks, no MLflow tracing. The tracing decorator
+is a no-op when MLflow isn't configured (see ``_TRACING_READY`` in
+``agents.tracing``), so it runs cleanly here.
+
+What we DO exercise:
+* The four termination conditions
+  — terminal submit_source_model with success=True breaks the loop
+  — text content with no tool calls is treated as failure
+  — iteration budget exhaustion is treated as failure
+  — submit returning success=False is NOT terminal (allows retry)
+* Step recording: every tool call produces both tool_call and tool_result
+  steps in the right order.
+* Iteration counter accuracy.
+
+What we do NOT exercise (covered elsewhere or out of scope):
+* The actual content of the SourceModel — that's Sprint 1's contracts tests.
+* The four planner tool handlers — that's Sprint 2's test_planner_tools.py.
+* MLflow tracing semantics — the decorator is wrapped in an ``if`` guard.
+"""
+
+import json
+from typing import Any, Callable, Dict, List, Optional
+
+import pytest
+
+from agents.agent_mapping_pge import planner as planner_mod
+from agents.agent_mapping_pge.contracts import SourceModel
+from agents.agent_mapping_pge.planner import (
+    PlannerResult,
+    PlannerStep,
+    run_planner,
+)
+
+
+# =====================================================
+# Fake LLM scaffolding
+# =====================================================
+
+
+def _make_tool_call(name: str, arguments: dict, *, tc_id: str = "tc1") -> dict:
+    """Build an OpenAI-style tool_calls entry; ``arguments`` is JSON-encoded."""
+    return {
+        "id": tc_id,
+        "type": "function",
+        "function": {"name": name, "arguments": json.dumps(arguments)},
+    }
+
+
+def _llm_response(
+    *,
+    tool_calls: Optional[List[dict]] = None,
+    content: Optional[str] = None,
+    finish_reason: str = "tool_calls",
+    usage: Optional[Dict[str, int]] = None,
+) -> dict:
+    """Build a minimal OpenAI-style chat-completions response."""
+    message: Dict[str, Any] = {"role": "assistant"}
+    if tool_calls:  # key is omitted for None or an empty list
+        message["tool_calls"] = tool_calls
+    if content is not None:  # an empty string is still recorded as content
+        message["content"] = content
+    return {
+        "choices": [{"finish_reason": finish_reason, "message": message}],
+        "usage": usage or {"prompt_tokens": 10, "completion_tokens": 5},
+    }
+
+
+class FakeLLM:
+    """A stub for ``call_serving_endpoint`` that returns canned responses.
+
+    The list is consumed front-to-back, one response per call. If a test
+    exhausts the list, the stub raises — that's almost always a test bug
+    (the loop iterated more times than the test author expected).
+    """
+
+    def __init__(self, responses: List[dict]):
+        self.responses = list(responses)  # private copy; the caller's list is not mutated
+        self.calls = 0
+
+    def __call__(self, *args, **kwargs) -> dict:
+        self.calls += 1
+        if not self.responses:
+            raise AssertionError(
+                f"FakeLLM: ran out of canned responses on call #{self.calls}"
+            )
+        return self.responses.pop(0)
+
+
+class CyclingFakeLLM:
+    """Like FakeLLM but cycles through a fixed list forever.
+
+    Used for the iteration-budget-exhaustion test, where the LLM is supposed
+    to be stuck in an infinite loop until the engine cuts it off.
+    """
+
+    def __init__(self, responses: List[dict]):
+        self.responses = list(responses)
+        self.calls = 0
+
+    def __call__(self, *args, **kwargs) -> dict:
+        resp = self.responses[self.calls % len(self.responses)]  # wrap around instead of running out
+        self.calls += 1
+        return resp
+
+
+@pytest.fixture
+def no_sleep(monkeypatch):
+    """Neutralise the 3-second inter-iteration delay so tests run fast."""
+    monkeypatch.setattr(planner_mod.time, "sleep", lambda *_a, **_k: None)
+
+
+def _patch_llm(monkeypatch, fake: Callable[..., dict]) -> None:
+    """Replace the planner's reference to ``call_serving_endpoint``."""
+    monkeypatch.setattr(planner_mod, "call_serving_endpoint", fake)
+
+
+# =====================================================
+# Fixtures: a minimal valid SourceModel payload
+# =====================================================
+
+
+def _valid_source_model_dict() -> Dict[str, Any]:
+    """Same shape as test_planner_tools._valid_source_model_dict — kept
+    independent here so the two test files stay decoupled."""
+    return {
+        "table_roles": [
+            {
+                "table": "cat.sch.mothers",
+                "ontology_class_candidates": [
+                    {
+                        "uri": "http://ex.org/maternity#Mother",
+                        "confidence": 0.9,
+                        "reason": "row per NHS",
+                    }
+                ],
+            }
+        ],
+        "canonical_ids": [
+            {
+                "ontology_class": "http://ex.org/maternity#Mother",
+                "canonical_column_per_table": {"cat.sch.mothers": "nhs_number"},
+                "format_note": "",
+            }
+        ],
+        "join_keys": [],
+        "mapping_plan": {
+            "entity_order": ["http://ex.org/maternity#Mother"],
+            "relationship_order": [],
+            "skip": [],
+        },
+    }
+
+
+def _minimal_metadata() -> dict:  # one-table catalog: cat.sch.mothers (nhs_number, dob)
+    return {
+        "tables": [
+            {
+                "name": "mothers",
+                "full_name": "cat.sch.mothers",
+                "columns": [
+                    {"name": "nhs_number", "type": "STRING"},
+                    {"name": "dob", "type": "DATE"},
+                ],
+            }
+        ]
+    }
+
+
+def _minimal_ontology() -> dict:  # a single "Mother" entity and no relationships
+    return {
+        "entities": [{"name": "Mother", "uri": "http://ex.org/maternity#Mother"}],
+        "relationships": [],
+    }
+
+
+# =====================================================
+# 1. Single-shot submit terminates immediately
+# =====================================================
+
+
+def test_planner_terminates_on_submit_source_model(monkeypatch, no_sleep):
+    """First LLM turn calls submit_source_model with a valid model — Planner
+    must return success=True with iterations=1 and source_model populated.
+    """
+    sm = _valid_source_model_dict()
+    fake = FakeLLM(
+        [
+            _llm_response(
+                tool_calls=[_make_tool_call("submit_source_model", {"model": sm})]
+            )
+        ]
+    )
+    _patch_llm(monkeypatch, fake)
+
+    result = run_planner(
+        host="https://x",
+        token="t",
+        endpoint_name="ep",
+        client=None,  # not used in this scenario
+        metadata=_minimal_metadata(),
+        ontology=_minimal_ontology(),
+    )
+
+    assert isinstance(result, PlannerResult)
+    assert result.success is True
+    assert result.iterations == 1
+    assert isinstance(result.source_model, SourceModel)
+    assert len(result.source_model.table_roles) == 1
+    assert result.error == ""
+    assert result.usage["prompt_tokens"] >= 0
+    # Exactly one tool_call + one tool_result step.
+    step_kinds = [s.step_type for s in result.steps]
+    assert step_kinds == ["tool_call", "tool_result"]
+    assert result.steps[0].tool_name == "submit_source_model"
+    assert result.steps[1].tool_name == "submit_source_model"
+
+
+# =====================================================
+# 2.
Multi-step ReAct trajectory followed by submit +# ===================================================== + + +def test_planner_multi_step_then_submit(monkeypatch, no_sleep): + """get_metadata → get_ontology → sample_table → submit_source_model.""" + sm = _valid_source_model_dict() + fake = FakeLLM( + [ + _llm_response( + tool_calls=[_make_tool_call("get_metadata", {}, tc_id="a")] + ), + _llm_response( + tool_calls=[_make_tool_call("get_ontology", {}, tc_id="b")] + ), + _llm_response( + tool_calls=[ + _make_tool_call( + "sample_table", + {"full_name": "cat.sch.mothers"}, + tc_id="c", + ) + ] + ), + _llm_response( + tool_calls=[ + _make_tool_call( + "submit_source_model", {"model": sm}, tc_id="d" + ) + ] + ), + ] + ) + _patch_llm(monkeypatch, fake) + + # sample_table needs a client — return one row. + class FakeClient: + def execute_query(self, sql): + return [{"nhs_number": "1234567890", "dob": "1990-01-01"}] + + result = run_planner( + host="https://x", + token="t", + endpoint_name="ep", + client=FakeClient(), + metadata=_minimal_metadata(), + ontology=_minimal_ontology(), + ) + + assert result.success is True + assert result.iterations == 4 + assert isinstance(result.source_model, SourceModel) + + # Every iteration produces both a tool_call and a tool_result step. + assert len(result.steps) == 8 + expected_tool_names = [ + "get_metadata", + "get_metadata", + "get_ontology", + "get_ontology", + "sample_table", + "sample_table", + "submit_source_model", + "submit_source_model", + ] + assert [s.tool_name for s in result.steps] == expected_tool_names + assert [s.step_type for s in result.steps] == [ + "tool_call", + "tool_result", + "tool_call", + "tool_result", + "tool_call", + "tool_result", + "tool_call", + "tool_result", + ] + + +# ===================================================== +# 3. 
submit returning success=False does NOT terminate +# ===================================================== + + +def test_planner_invalid_source_model_does_not_terminate(monkeypatch, no_sleep): + """First submit is malformed (missing 'table' on a table_role) — the + tool returns success=False and the Planner keeps going. Second submit + is valid and terminates the loop. + """ + bad = _valid_source_model_dict() + del bad["table_roles"][0]["table"] # break it + + good = _valid_source_model_dict() + # Set entity_order explicitly (same value as the fixture default); asserted below. + good["mapping_plan"]["entity_order"] = ["http://ex.org/maternity#Mother"] + + fake = FakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call("submit_source_model", {"model": bad}, tc_id="x") + ] + ), + _llm_response( + tool_calls=[ + _make_tool_call( + "submit_source_model", {"model": good}, tc_id="y" + ) + ] + ), + ] + ) + _patch_llm(monkeypatch, fake) + + result = run_planner( + host="https://x", + token="t", + endpoint_name="ep", + client=None, + metadata=_minimal_metadata(), + ontology=_minimal_ontology(), + ) + + assert result.success is True + assert result.iterations == 2 + assert isinstance(result.source_model, SourceModel) + # The valid one is what landed on ctx — pull a field from it. + assert result.source_model.mapping_plan.entity_order == [ + "http://ex.org/maternity#Mother" + ] + # Both submit attempts were recorded; the first tool_result must signal + # failure so the orchestrator can attribute the retry. + first_submit_result = result.steps[1] + assert first_submit_result.step_type == "tool_result" + assert first_submit_result.tool_name == "submit_source_model" + payload = json.loads(first_submit_result.content) + assert payload["success"] is False + + +# ===================================================== +# 4.
Free-text output without a terminal tool call → failure +# ===================================================== + + +def test_planner_text_without_terminal_fails(monkeypatch, no_sleep): + """The Planner must terminate via submit_source_model. A plain-text + response is treated as failure. + """ + fake = FakeLLM( + [_llm_response(content="I think we are done.", finish_reason="stop")] + ) + _patch_llm(monkeypatch, fake) + + result = run_planner( + host="https://x", + token="t", + endpoint_name="ep", + client=None, + metadata=_minimal_metadata(), + ontology=_minimal_ontology(), + ) + + assert result.success is False + assert result.iterations == 1 + assert result.source_model is None + assert "without submitting source model" in result.error + # The text was recorded as an output step for debuggability. + assert any(s.step_type == "output" for s in result.steps) + + +# ===================================================== +# 5. Iteration budget exhaustion → failure +# ===================================================== + + +def test_planner_exhausts_iteration_budget(monkeypatch, no_sleep): + """Fake LLM keeps calling get_metadata forever. With max_iterations=3 + the Planner must give up cleanly and report budget exhaustion. + """ + fake = CyclingFakeLLM( + [ + _llm_response( + tool_calls=[_make_tool_call("get_metadata", {}, tc_id="a")] + ) + ] + ) + _patch_llm(monkeypatch, fake) + + result = run_planner( + host="https://x", + token="t", + endpoint_name="ep", + client=None, + metadata=_minimal_metadata(), + ontology=_minimal_ontology(), + max_iterations=3, + ) + + assert result.success is False + assert result.iterations == 3 + assert result.source_model is None + assert "iteration budget" in result.error + # Three iterations × (tool_call + tool_result) = 6 steps. + assert len(result.steps) == 6 + + +# ===================================================== +# 6. 
Step recording invariants +# ===================================================== + + +def test_planner_records_steps(monkeypatch, no_sleep): + """For each tool-calling iteration, the Planner must record exactly one + ``tool_call`` step (with non-empty arguments-as-content) and one + ``tool_result`` step (with non-empty content) — in that order, paired by + ``tool_name``. + """ + sm = _valid_source_model_dict() + fake = FakeLLM( + [ + _llm_response( + tool_calls=[_make_tool_call("get_metadata", {}, tc_id="a")] + ), + _llm_response( + tool_calls=[ + _make_tool_call("submit_source_model", {"model": sm}, tc_id="b") + ] + ), + ] + ) + _patch_llm(monkeypatch, fake) + + result = run_planner( + host="https://x", + token="t", + endpoint_name="ep", + client=None, + metadata=_minimal_metadata(), + ontology=_minimal_ontology(), + ) + + assert result.success is True + # Verify the pairing: every even-indexed step (tool_call) is immediately + # followed by an odd-indexed step (tool_result) with the same tool_name. + assert len(result.steps) % 2 == 0 + for i in range(0, len(result.steps), 2): + call_step = result.steps[i] + result_step = result.steps[i + 1] + assert call_step.step_type == "tool_call" + assert result_step.step_type == "tool_result" + assert call_step.tool_name == result_step.tool_name + assert call_step.content != "" + assert result_step.content != "" + # PlannerStep is the right type. + assert isinstance(call_step, PlannerStep) + assert isinstance(result_step, PlannerStep) + + +# ===================================================== +# Prompt contract — canonical-key normalization guidance +# ===================================================== + + +class TestCanonicalKeyNormalizationPrompt: + """Pin the load-bearing canonical-key guidance in the system prompt.
+ + Issue 2 root cause: the Planner left cross-trust keys disjoint (0% + overlap rationalized as "trust-scoped"), and when it did normalize it + copied a non-anchored regex that returns a leading-dash key. These + assertions keep the corrective guidance from silently regressing. + """ + + def test_offers_expression_overlap_verification_tool(self): + assert "normalized_value_overlap" in planner_mod.SYSTEM_PROMPT + + def test_zero_overlap_is_not_a_terminal_state(self): + prompt = planner_mod.SYSTEM_PROMPT + # The prompt must steer the model AWAY from accepting disjoint keys. + assert "trust-scoped" in prompt # names the trap explicitly + assert "100%" in prompt and "dangle" in prompt + + def test_regex_example_is_anchored(self): + prompt = planner_mod.SYSTEM_PROMPT + # The correct, anchored pattern must be present... + assert "[a-f0-9][a-f0-9-]+-preg-" in prompt + # ...and it must be flagged as the RIGHT one (the WRONG/RIGHT contrast + # teaches the leading-dash pitfall). + assert "✓ RIGHT" in prompt and "✗ WRONG" in prompt + + def test_derived_key_extracts_core_before_suffix(self): + prompt = planner_mod.SYSTEM_PROMPT + # Derived child keys must extract the shared core, then append suffix — + # not concat onto the raw prefixed local id. + assert "regexp_extract" in prompt + assert "-del" in prompt # the worked Delivery example diff --git a/tests/agents/agent_mapping_pge/test_relationship_generator.py b/tests/agents/agent_mapping_pge/test_relationship_generator.py new file mode 100644 index 00000000..d18e4a71 --- /dev/null +++ b/tests/agents/agent_mapping_pge/test_relationship_generator.py @@ -0,0 +1,736 @@ +"""Tests for the mapping-PGE RelationshipGenerator agent (Sprint 5). + +Mirrors the structure of ``test_entity_generator.py``. The Generator is a +narrow tool-calling ReAct loop terminated by ``submit_relationship_mapping``. 
+These tests exercise the loop's control flow with a *fake LLM* — a stub that +replaces ``call_serving_endpoint`` at module level and returns canned +responses on a per-call basis. + +No real HTTP, no real Databricks, no MLflow tracing. + +What we DO exercise: +* Termination on a single submit call. +* Multi-step trajectory (execute_sql → submit). +* Text-only output is treated as failure (no terminal call). +* Iteration-budget exhaustion is treated as failure. +* ``retry_hint`` surfaces inside the user message. +* Strict ``property_uri`` match — submit with a wrong URI is coached, not + terminal. +* Step recording invariants. +* The user prompt surfaces the source/target id_columns verbatim — pins the + Sprint 5 contract that the LLM sees the endpoint columns. +""" + +import json +from typing import Any, Callable, Dict, List, Optional + +import pytest + +from agents.agent_mapping_pge.generators import relationship as rel_mod +from agents.agent_mapping_pge.generators.relationship import ( + RelationshipGenResult, + RelationshipGenStep, + run_relationship_generator, +) + + +# ===================================================== +# Fake LLM scaffolding (mirrors test_entity_generator.py) +# ===================================================== + + +_PROP_URI = "http://ex.org/maternity#motherOf" +_SOURCE_CLASS = "http://ex.org/maternity#Mother" +_TARGET_CLASS = "http://ex.org/maternity#Baby" + + +def _make_tool_call(name: str, arguments: dict, *, tc_id: str = "tc1") -> dict: + return { + "id": tc_id, + "type": "function", + "function": {"name": name, "arguments": json.dumps(arguments)}, + } + + +def _llm_response( + *, + tool_calls: Optional[List[dict]] = None, + content: Optional[str] = None, + finish_reason: str = "tool_calls", + usage: Optional[Dict[str, int]] = None, +) -> dict: + message: Dict[str, Any] = {"role": "assistant"} + if tool_calls: + message["tool_calls"] = tool_calls + if content is not None: + message["content"] = content + return { + "choices": 
[{"finish_reason": finish_reason, "message": message}], + "usage": usage or {"prompt_tokens": 10, "completion_tokens": 5}, + } + + +class FakeLLM: + def __init__(self, responses: List[dict]): + self.responses = list(responses) + self.calls = 0 + self.last_messages: Optional[List[dict]] = None + self.first_messages: Optional[List[dict]] = None + + def __call__(self, *args, **kwargs) -> dict: + self.calls += 1 + msgs: Optional[List[dict]] = None + if len(args) >= 4 and isinstance(args[3], list): + msgs = args[3] + elif "messages" in kwargs: + msgs = kwargs["messages"] + if msgs is not None: + snapshot = [dict(m) for m in msgs] + if self.first_messages is None: + self.first_messages = snapshot + self.last_messages = snapshot + + if not self.responses: + raise AssertionError( + f"FakeLLM: ran out of canned responses on call #{self.calls}" + ) + return self.responses.pop(0) + + +class CyclingFakeLLM: + """Like FakeLLM but cycles through a fixed list forever.""" + + def __init__(self, responses: List[dict]): + self.responses = list(responses) + self.calls = 0 + + def __call__(self, *args, **kwargs) -> dict: + resp = self.responses[self.calls % len(self.responses)] + self.calls += 1 + return resp + + +@pytest.fixture +def no_sleep(monkeypatch): + """Neutralise the 3-second inter-iteration delay so tests run fast.""" + monkeypatch.setattr(rel_mod.time, "sleep", lambda *_a, **_k: None) + + +def _patch_llm(monkeypatch, fake: Callable[..., dict]) -> None: + monkeypatch.setattr(rel_mod, "call_serving_endpoint", fake) + + +# ===================================================== +# Fixtures +# ===================================================== + + +def _ontology_property() -> dict: + return { + "uri": _PROP_URI, + "label": "motherOf", + "name": "motherOf", + "comment": "Links a Mother to each of her babies.", + "domain": _SOURCE_CLASS, + "range": _TARGET_CLASS, + } + + +def _source_entity_mapping() -> dict: + return { + "ontology_class": _SOURCE_CLASS, + "class_name": 
"Mother", + "id_column": "nhs_number", + "label_column": "nhs_number", + "sql_query": ( + "SELECT nhs_number AS ID, nhs_number AS Label FROM cat.sch.mothers " + "WHERE nhs_number IS NOT NULL" + ), + } + + +def _target_entity_mapping() -> dict: + return { + "ontology_class": _TARGET_CLASS, + "class_name": "Baby", + "id_column": "baby_id", + "label_column": "baby_id", + "sql_query": ( + "SELECT baby_id AS ID, baby_id AS Label FROM cat.sch.babies " + "WHERE baby_id IS NOT NULL" + ), + } + + +def _source_model_slice() -> dict: + return { + "relevant_joins": [ + { + "from_ref": "cat.sch.babies.mother_nhs_number", + "to_ref": "cat.sch.mothers.nhs_number", + "confidence": 0.95, + "overlap_pct": 0.98, + "kind": "same_trust_fk", + } + ], + "candidate_tables": [ + {"table": "cat.sch.babies", "reason": "row per baby, has mother FK"} + ], + } + + +def _valid_submit_args() -> dict: + return { + "property_uri": _PROP_URI, + "property_name": "motherOf", + "sql_query": ( + "SELECT mother_nhs_number AS source_id, baby_id AS target_id " + "FROM cat.sch.babies WHERE mother_nhs_number IS NOT NULL" + ), + "source_id_column": "nhs_number", + "target_id_column": "baby_id", + "domain": _SOURCE_CLASS, + "range_class": _TARGET_CLASS, + "direction": "forward", + } + + +# ===================================================== +# 1. 
Single-shot submit terminates immediately +# ===================================================== + + +def test_terminates_on_submit(monkeypatch, no_sleep): + """First LLM turn submits a valid mapping → success, iterations=1.""" + fake = FakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call( + "submit_relationship_mapping", _valid_submit_args() + ) + ] + ) + ] + ) + _patch_llm(monkeypatch, fake) + + result = run_relationship_generator( + host="https://x", + token="t", + endpoint_name="ep", + client=None, + ontology_property=_ontology_property(), + source_entity_mapping=_source_entity_mapping(), + target_entity_mapping=_target_entity_mapping(), + source_model_slice=_source_model_slice(), + ) + + assert isinstance(result, RelationshipGenResult) + assert result.success is True + assert result.iterations == 1 + assert result.mapping is not None + assert result.mapping["property"] == _PROP_URI + assert result.mapping["source_id_column"] == "nhs_number" + assert result.mapping["target_id_column"] == "baby_id" + assert result.error == "" + step_kinds = [s.step_type for s in result.steps] + assert step_kinds == ["tool_call", "tool_result"] + assert result.steps[0].tool_name == "submit_relationship_mapping" + + +# ===================================================== +# 2. 
execute_sql validation, then submit +# ===================================================== + + +def test_validates_sql_then_submits(monkeypatch, no_sleep): + """execute_sql → submit_relationship_mapping → success, iterations=2.""" + fake = FakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call( + "execute_sql", + { + "sql": ( + "SELECT mother_nhs_number AS source_id, baby_id " + "AS target_id FROM cat.sch.babies " + "WHERE mother_nhs_number IS NOT NULL" + ) + }, + tc_id="a", + ) + ] + ), + _llm_response( + tool_calls=[ + _make_tool_call( + "submit_relationship_mapping", + _valid_submit_args(), + tc_id="b", + ) + ] + ), + ] + ) + _patch_llm(monkeypatch, fake) + + class FakeClient: + def execute_query(self, sql): + return [{"source_id": "1234567890", "target_id": "b-1"}] + + result = run_relationship_generator( + host="https://x", + token="t", + endpoint_name="ep", + client=FakeClient(), + ontology_property=_ontology_property(), + source_entity_mapping=_source_entity_mapping(), + target_entity_mapping=_target_entity_mapping(), + source_model_slice=_source_model_slice(), + ) + + assert result.success is True + assert result.iterations == 2 + assert result.mapping is not None + # Sequence: tool_call(execute_sql), tool_result(execute_sql), + # tool_call(submit), tool_result(submit) — 4 steps. + assert len(result.steps) == 4 + assert [s.tool_name for s in result.steps] == [ + "execute_sql", + "execute_sql", + "submit_relationship_mapping", + "submit_relationship_mapping", + ] + + +# ===================================================== +# 3. Text without terminal call → failure +# ===================================================== + + +def test_text_without_terminal_fails(monkeypatch, no_sleep): + """A plain-text response is treated as failure — the Generator must + terminate via submit_relationship_mapping. 
+ """ + fake = FakeLLM( + [_llm_response(content="I am thinking…", finish_reason="stop")] + ) + _patch_llm(monkeypatch, fake) + + result = run_relationship_generator( + host="https://x", + token="t", + endpoint_name="ep", + client=None, + ontology_property=_ontology_property(), + source_entity_mapping=_source_entity_mapping(), + target_entity_mapping=_target_entity_mapping(), + source_model_slice=_source_model_slice(), + ) + + assert result.success is False + assert result.iterations == 1 + assert result.mapping is None + assert "without submitting mapping" in result.error + assert any(s.step_type == "output" for s in result.steps) + + +# ===================================================== +# 4. Iteration-budget exhaustion → failure +# ===================================================== + + +def test_exhausts_iteration_budget(monkeypatch, no_sleep): + """Endless sample_table calls with max_iterations=3 → fail with + ``iteration budget`` and three iterations of steps recorded.""" + fake = CyclingFakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call( + "sample_table", + {"full_name": "cat.sch.babies"}, + tc_id="a", + ) + ] + ) + ] + ) + _patch_llm(monkeypatch, fake) + + class FakeClient: + def execute_query(self, sql): + return [{"mother_nhs_number": "1234567890", "baby_id": "b-1"}] + + result = run_relationship_generator( + host="https://x", + token="t", + endpoint_name="ep", + client=FakeClient(), + ontology_property=_ontology_property(), + source_entity_mapping=_source_entity_mapping(), + target_entity_mapping=_target_entity_mapping(), + source_model_slice=_source_model_slice(), + max_iterations=3, + ) + + assert result.success is False + assert result.iterations == 3 + assert result.mapping is None + assert "iteration budget" in result.error + # 3 iterations × (tool_call + tool_result) = 6 steps. + assert len(result.steps) == 6 + + +# ===================================================== +# 5. 
retry_hint surfaces in the user prompt +# ===================================================== + + +def test_retry_hint_surfaces_in_user_prompt(monkeypatch, no_sleep): + """If ``retry_hint`` is provided, the FIRST LLM call's user message must + contain the hint verbatim and the RETRY HINT label.""" + fake = FakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call( + "submit_relationship_mapping", _valid_submit_args() + ) + ] + ) + ] + ) + _patch_llm(monkeypatch, fake) + + result = run_relationship_generator( + host="https://x", + token="t", + endpoint_name="ep", + client=None, + ontology_property=_ontology_property(), + source_entity_mapping=_source_entity_mapping(), + target_entity_mapping=_target_entity_mapping(), + source_model_slice=_source_model_slice(), + retry_hint="Use mother_nhs_number, not patient_id.", + ) + + assert result.success is True + assert fake.first_messages is not None + assert fake.first_messages[0]["role"] == "system" + assert fake.first_messages[1]["role"] == "user" + user_content = fake.first_messages[1]["content"] + assert "Use mother_nhs_number, not patient_id." in user_content + assert "RETRY HINT" in user_content + # Retry-hint corrective workflow surfaces the dangling-edge probe. 
+ assert "dangling-edge probe" in user_content + assert "DO NOT repeat the same column choice" in user_content + + +def test_system_prompt_mandates_dangling_edge_self_check(monkeypatch, no_sleep): + """The system prompt must instruct the model to run a dangling-edge + probe with execute_sql BEFORE submitting — name-similarity alone is + insufficient and was the root cause of the live smoke failure on + hasapgarscore.""" + fake = FakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call( + "submit_relationship_mapping", _valid_submit_args() + ) + ] + ) + ] + ) + _patch_llm(monkeypatch, fake) + + run_relationship_generator( + host="https://x", + token="t", + endpoint_name="ep", + client=None, + ontology_property=_ontology_property(), + source_entity_mapping=_source_entity_mapping(), + target_entity_mapping=_target_entity_mapping(), + source_model_slice=_source_model_slice(), + ) + + system_content = fake.first_messages[0]["content"] + assert "SELF-VERIFY THE VALUES BEFORE SUBMITTING" in system_content + assert "dangling_src" in system_content + assert "dangling_tgt" in system_content + # The probe must reference both endpoint universes via the entity SQLs. + assert "source entity's SQL" in system_content + assert "target entity's SQL" in system_content + + +def test_system_prompt_teaches_reproducing_derived_id_expression(monkeypatch, no_sleep): + """The id_column is an alias for a derived canonical expression; the + prompt must instruct the model to REPRODUCE that expression for the + endpoints (not select a raw column) — the root cause of the 100% + source-dangling on hasapgarscore/deliveredbaby in the live smoke. 
+ """ + fake = FakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call( + "submit_relationship_mapping", _valid_submit_args() + ) + ] + ) + ] + ) + _patch_llm(monkeypatch, fake) + + run_relationship_generator( + host="https://x", + token="t", + endpoint_name="ep", + client=None, + ontology_property=_ontology_property(), + source_entity_mapping=_source_entity_mapping(), + target_entity_mapping=_target_entity_mapping(), + source_model_slice=_source_model_slice(), + ) + + system_content = fake.first_messages[0]["content"] + assert "ALIAS FOR A DERIVED EXPRESSION" in system_content + assert "regexp_extract" in system_content + # Must steer away from selecting a raw column for the endpoint. + assert "reproduce" in system_content.lower() + + +def test_system_prompt_teaches_shared_coverage_table_rule(monkeypatch, no_sleep): + """Cross-trust endpoint entities only overlap on shared trusts; building + the edge from a table only one entity covers yields 100% dangling on the + other side. The prompt must teach picking a BOTH-covered source table. + """ + fake = FakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call( + "submit_relationship_mapping", _valid_submit_args() + ) + ] + ) + ] + ) + _patch_llm(monkeypatch, fake) + + run_relationship_generator( + host="https://x", + token="t", + endpoint_name="ep", + client=None, + ontology_property=_ontology_property(), + source_entity_mapping=_source_entity_mapping(), + target_entity_mapping=_target_entity_mapping(), + source_model_slice=_source_model_slice(), + ) + + system_content = fake.first_messages[0]["content"] + assert "BOTH" in system_content + assert "coverage" in system_content.lower() + assert "100% source-dangling" in system_content or "100%" in system_content + + +# ===================================================== +# 6. 
Wrong property_uri submission does NOT terminate +# ===================================================== + + +def test_wrong_property_uri_submission_does_not_terminate(monkeypatch, no_sleep): + """A submit_relationship_mapping call with a property_uri that doesn't + match the requested one must NOT terminate the loop. The Generator must + keep going so a follow-up submit (with the correct URI) can succeed, and + the LLM must see a corrective tool message describing the mismatch. + """ + requested_uri = _PROP_URI + other_uri = "http://ex.org/maternity#fatherOf" + + wrong_args = _valid_submit_args() + wrong_args["property_uri"] = other_uri + wrong_args["property_name"] = "fatherOf" + + fake = FakeLLM( + [ + # Turn 1: submit with the WRONG property_uri — must NOT terminate. + _llm_response( + tool_calls=[ + _make_tool_call( + "submit_relationship_mapping", + wrong_args, + tc_id="wrong", + ) + ] + ), + # Turn 2: submit with the correct property_uri — should terminate. + _llm_response( + tool_calls=[ + _make_tool_call( + "submit_relationship_mapping", + _valid_submit_args(), + tc_id="right", + ) + ] + ), + ] + ) + _patch_llm(monkeypatch, fake) + + result = run_relationship_generator( + host="https://x", + token="t", + endpoint_name="ep", + client=None, + ontology_property=_ontology_property(), + source_entity_mapping=_source_entity_mapping(), + target_entity_mapping=_target_entity_mapping(), + source_model_slice=_source_model_slice(), + ) + + assert result.success is True + assert result.iterations == 2 + assert result.mapping is not None + assert result.mapping["property"] == requested_uri + + # The LLM's second call must have seen a corrective tool message + # describing the mismatch, surfaced through ``messages``. 
+ assert fake.last_messages is not None + tool_messages = [m for m in fake.last_messages if m.get("role") == "tool"] + assert tool_messages, "expected at least one tool message on the 2nd call" + corrective = tool_messages[-1] + corrective_content = corrective.get("content", "") + assert other_uri in corrective_content + assert requested_uri in corrective_content + assert "does not match" in corrective_content + # Sanity: the corrective payload is a JSON error (not the original + # success=True response). + parsed = json.loads(corrective_content) + assert parsed.get("success") is False + assert "error" in parsed + + +# ===================================================== +# 7. Step recording invariants +# ===================================================== + + +def test_records_steps(monkeypatch, no_sleep): + """Every tool-calling iteration produces exactly one ``tool_call`` step + immediately followed by one ``tool_result`` step with the same tool_name. + """ + fake = FakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call( + "sample_table", + {"full_name": "cat.sch.babies"}, + tc_id="a", + ) + ] + ), + _llm_response( + tool_calls=[ + _make_tool_call( + "submit_relationship_mapping", + _valid_submit_args(), + tc_id="b", + ) + ] + ), + ] + ) + _patch_llm(monkeypatch, fake) + + class FakeClient: + def execute_query(self, sql): + return [{"mother_nhs_number": "1234567890", "baby_id": "b-1"}] + + result = run_relationship_generator( + host="https://x", + token="t", + endpoint_name="ep", + client=FakeClient(), + ontology_property=_ontology_property(), + source_entity_mapping=_source_entity_mapping(), + target_entity_mapping=_target_entity_mapping(), + source_model_slice=_source_model_slice(), + ) + + assert result.success is True + assert len(result.steps) % 2 == 0 + for i in range(0, len(result.steps), 2): + call_step = result.steps[i] + result_step = result.steps[i + 1] + assert call_step.step_type == "tool_call" + assert result_step.step_type == 
"tool_result" + assert call_step.tool_name == result_step.tool_name + assert call_step.content != "" + assert result_step.content != "" + assert isinstance(call_step, RelationshipGenStep) + assert isinstance(result_step, RelationshipGenStep) + + +# ===================================================== +# 8. User prompt surfaces source/target id_columns +# ===================================================== + + +def test_user_prompt_includes_source_and_target_id_columns(monkeypatch, no_sleep): + """The FIRST call's user message must contain both id_column names + verbatim. This pins the Sprint 5 contract that the Generator surfaces + the endpoint columns to the LLM, so the LLM cannot silently pick + different endpoints. + """ + # Use distinctive id_column names that won't appear anywhere else in + # the slice (mothers/babies join etc.), to make the assertion strict. + src_em = { + "ontology_class": _SOURCE_CLASS, + "class_name": "Mother", + "id_column": "weirdly_named_mother_pk", + "label_column": "weirdly_named_mother_pk", + "sql_query": "SELECT weirdly_named_mother_pk AS ID FROM cat.sch.mothers", + } + tgt_em = { + "ontology_class": _TARGET_CLASS, + "class_name": "Baby", + "id_column": "weirdly_named_baby_pk", + "label_column": "weirdly_named_baby_pk", + "sql_query": "SELECT weirdly_named_baby_pk AS ID FROM cat.sch.babies", + } + + fake = FakeLLM( + [ + _llm_response( + tool_calls=[ + _make_tool_call( + "submit_relationship_mapping", _valid_submit_args() + ) + ] + ) + ] + ) + _patch_llm(monkeypatch, fake) + + run_relationship_generator( + host="https://x", + token="t", + endpoint_name="ep", + client=None, + ontology_property=_ontology_property(), + source_entity_mapping=src_em, + target_entity_mapping=tgt_em, + source_model_slice=_source_model_slice(), + ) + + assert fake.first_messages is not None + user_content = fake.first_messages[1]["content"] + assert "weirdly_named_mother_pk" in user_content + assert "weirdly_named_baby_pk" in user_content From 
1b47960328876b8d65817f618f94f6dc5cbf4ec3 Mon Sep 17 00:00:00 2001 From: Fiifi Botchway Date: Thu, 25 Jun 2026 12:48:38 +0100 Subject: [PATCH 2/2] docs(agents): add SPEC.md + eval dataset for agent_mapping_pge Satisfy the AI-feature lifecycle gate for the new agent: documented contract (purpose, tool surface, eval dimensions, failure modes) + a 20-example baseline eval dataset spanning single-source, multi-source cross-trust reconciliation, and degenerate inputs. Co-authored-by: Isaac --- .planning/agents/agent_mapping_pge/SPEC.md | 84 +++++++++++++++++++ .../datasets/agent_mapping_pge/baseline.jsonl | 20 +++++ 2 files changed, 104 insertions(+) create mode 100644 .planning/agents/agent_mapping_pge/SPEC.md create mode 100644 tests/eval/datasets/agent_mapping_pge/baseline.jsonl diff --git a/.planning/agents/agent_mapping_pge/SPEC.md b/.planning/agents/agent_mapping_pge/SPEC.md new file mode 100644 index 00000000..7b9acec4 --- /dev/null +++ b/.planning/agents/agent_mapping_pge/SPEC.md @@ -0,0 +1,84 @@ +# SPEC: agent_mapping_pge + +> Required by `.cursor/12-ai-feature-lifecycle.mdc`. + +## 1. Purpose + +`agent_mapping_pge` generates entity and relationship SQL mappings for a domain +via a Planner→Generator→Evaluator (PGE) loop. Given source metadata + an ontology +it plans a source model, generates SQL per ontology item, and gates each mapping +with a deterministic evaluator plus an independent semantic critic. It is an additive +alternative to the single-agent `agent_auto_assignment` flow (retained), separating creator +and critic, and enforces coverage from the ontology rather than LLM discretion. + +## 2. Identity + +| Field | Value | +|---|---| +| `agent_name` | `agent_mapping_pge` | +| `module_path` | `src/agents/agent_mapping_pge/` | +| `model_endpoint` | _configured per workspace_ | +| `temperature` | `0.0`–`0.2` | +| `mlflow_experiment` | `/Shared/ontobricks/agents/mapping_pge` | + +## 3.
Tool surface + +| Tool name | Input | Output | Purpose | +|---|---|---|---| +| `submit_source_model` | planner source-model | `SourceModel` | Terminal planner tool | +| `submit_entity_mapping` | entity SQL + id expr | mapping dict | Record an entity mapping | +| `submit_relationship_mapping` | rel SQL + endpoints | mapping dict | Record a relationship mapping | +| `normalized_value_overlap` | two columns | overlap ratio | Verify join-key overlap | +| `submit_evaluation` | critic verdict | `EvalReport` | Terminal critic tool | + +## 4. Success criteria + +1. Every mappable ontology class/relationship is covered (engine-enforced, not + LLM-discretionary). +2. Relationship endpoints reproduce the entity's canonical id expression → + 0% dangling on a valid domain. +3. A failed hub entity does not cascade to drop all its relationships (synthetic + endpoint fallback). + +## 5. Eval dimensions + +| Dimension | Metric | Threshold | Weight | Judge | +|---|---|---|---|---| +| `entity_coverage` | mapped entities / mappable classes | `1.00` | `0.25` | rule-based (`coverage.py`) | +| `relationship_coverage` | mapped rels / ontology object-properties | `1.00` | `0.20` | rule-based | +| `dangling_rate` | proportion of relationship edges with a resolvable endpoint, i.e. the non-dangling share (higher is better; `1.00` = 0% dangling, despite the metric's name) | `1.00` | `0.25` | rule-based (deterministic evaluator) | +| `sql_executes` | generated SQL parses + runs | `0.98` | `0.15` | rule-based | +| `semantic_correctness` | critic agreement that the mapping matches intent | `0.85` | `0.15` | LLM critic (`evaluator/critic.py`) | + +**Aggregate threshold:** ≥ `0.90`. + +## 6. 
Failure modes + +| Symptom | Detection | Mitigation | +|---|---|---| +| Class silently skipped | `entity_coverage` < 1.0 | coverage is computed from the ontology; `skip[]` is advisory and never removes an item | +| Relationship dangles | `dangling_rate` < 1.0 (the metric is the resolvable-endpoint share, so any value below 1.0 means some edges dangle) | relationship generator reproduces the endpoint's canonical id expression | +| One failed hub drops all relationships | `relationship_coverage` collapses | synthetic-endpoint fallback from `canonical_ids` | +| Abstract superclass unmapped | missing union | abstract classes derived as UNION-ALL of concrete subclass SQL | + +## 7. Eval dataset + +- **Baseline:** `tests/eval/datasets/agent_mapping_pge/baseline.jsonl` — ≥ 20 examples + spanning single-source, multi-source cross-trust, and degenerate inputs. +- **Regression:** added on first production mis-mapping. + +## 8. MLflow tracing + +The engine traces planner / generator / evaluator / critic stages; per-item +`mapping_evaluations` + `mapping_run_log` are surfaced on the result. + +## 9. Plan reference + +PGE design notes tracked in session memory; loop pattern per Anthropic's +harness-design (planner/generator/evaluator separation). + +## 10. Sign-off + +- [x] Sections 4, 5, 6, 7 filled. +- [ ] Baseline eval run URI pasted into PR body. +- [x] Aggregate threshold declared in §5. 
diff --git a/tests/eval/datasets/agent_mapping_pge/baseline.jsonl b/tests/eval/datasets/agent_mapping_pge/baseline.jsonl new file mode 100644 index 00000000..d1ffc98e --- /dev/null +++ b/tests/eval/datasets/agent_mapping_pge/baseline.jsonl @@ -0,0 +1,20 @@ +{"id": "single-patient-001", "input": {"tables": [{"name": "patients", "columns": ["patient_id", "name", "dob"]}], "ontology": {"classes": [{"name": "Patient", "attributes": ["name", "dob"]}], "properties": []}}, "expected": {"entities_mapped": ["Patient"], "constraints": [{"kind": "min_entities", "value": 1}, {"kind": "dangling_rate", "value": 0.0}]}, "tags": ["single-source", "happy"]} +{"id": "single-order-002", "input": {"tables": [{"name": "orders", "columns": ["order_id", "amount", "status"]}], "ontology": {"classes": [{"name": "Order", "attributes": ["amount", "status"]}], "properties": []}}, "expected": {"entities_mapped": ["Order"], "constraints": [{"kind": "min_entities", "value": 1}, {"kind": "dangling_rate", "value": 0.0}]}, "tags": ["single-source", "happy"]} +{"id": "single-product-003", "input": {"tables": [{"name": "products", "columns": ["product_id", "title", "price"]}], "ontology": {"classes": [{"name": "Product", "attributes": ["title", "price"]}], "properties": []}}, "expected": {"entities_mapped": ["Product"], "constraints": [{"kind": "min_entities", "value": 1}, {"kind": "dangling_rate", "value": 0.0}]}, "tags": ["single-source", "happy"]} +{"id": "single-employee-004", "input": {"tables": [{"name": "employees", "columns": ["emp_id", "first_name", "last_name"]}], "ontology": {"classes": [{"name": "Employee", "attributes": ["first_name", "last_name"]}], "properties": []}}, "expected": {"entities_mapped": ["Employee"], "constraints": [{"kind": "min_entities", "value": 1}, {"kind": "dangling_rate", "value": 0.0}]}, "tags": ["single-source", "happy"]} +{"id": "single-invoice-005", "input": {"tables": [{"name": "invoices", "columns": ["invoice_id", "total", "issued_date"]}], "ontology": 
{"classes": [{"name": "Invoice", "attributes": ["total", "issued_date"]}], "properties": []}}, "expected": {"entities_mapped": ["Invoice"], "constraints": [{"kind": "min_entities", "value": 1}, {"kind": "dangling_rate", "value": 0.0}]}, "tags": ["single-source", "happy"]} +{"id": "single-device-006", "input": {"tables": [{"name": "devices", "columns": ["device_id", "model", "serial"]}], "ontology": {"classes": [{"name": "Device", "attributes": ["model", "serial"]}], "properties": []}}, "expected": {"entities_mapped": ["Device"], "constraints": [{"kind": "min_entities", "value": 1}, {"kind": "dangling_rate", "value": 0.0}]}, "tags": ["single-source", "happy"]} +{"id": "single-account-007", "input": {"tables": [{"name": "accounts", "columns": ["account_id", "balance", "opened_date"]}], "ontology": {"classes": [{"name": "Account", "attributes": ["balance", "opened_date"]}], "properties": []}}, "expected": {"entities_mapped": ["Account"], "constraints": [{"kind": "min_entities", "value": 1}, {"kind": "dangling_rate", "value": 0.0}]}, "tags": ["single-source", "happy"]} +{"id": "single-ticket-008", "input": {"tables": [{"name": "tickets", "columns": ["ticket_id", "subject", "priority"]}], "ontology": {"classes": [{"name": "Ticket", "attributes": ["subject", "priority"]}], "properties": []}}, "expected": {"entities_mapped": ["Ticket"], "constraints": [{"kind": "min_entities", "value": 1}, {"kind": "dangling_rate", "value": 0.0}]}, "tags": ["single-source", "happy"]} +{"id": "single-book-009", "input": {"tables": [{"name": "books", "columns": ["book_id", "title", "isbn"]}], "ontology": {"classes": [{"name": "Book", "attributes": ["title", "isbn"]}], "properties": []}}, "expected": {"entities_mapped": ["Book"], "constraints": [{"kind": "min_entities", "value": 1}, {"kind": "dangling_rate", "value": 0.0}]}, "tags": ["single-source", "happy"]} +{"id": "single-vehicle-010", "input": {"tables": [{"name": "vehicles", "columns": ["vehicle_id", "vin", "color"]}], "ontology": 
{"classes": [{"name": "Vehicle", "attributes": ["vin", "color"]}], "properties": []}}, "expected": {"entities_mapped": ["Vehicle"], "constraints": [{"kind": "min_entities", "value": 1}, {"kind": "dangling_rate", "value": 0.0}]}, "tags": ["single-source", "happy"]} +{"id": "multi-mother-001", "input": {"tables": [{"name": "trust_a", "columns": ["EPISODE_ID", "MOTHER_NHS_NO", "DELIVERY_DATE"]}, {"name": "trust_b", "columns": ["pregnancy_id", "mother_nhs_no", "booking_date"]}, {"name": "trust_c", "columns": ["event_id", "mother_nhs_number", "event_type"]}], "ontology": {"classes": [{"name": "Mother"}, {"name": "Pregnancy"}, {"name": "Delivery"}], "properties": [{"name": "hasPregnancy", "domain": "Mother", "range": "Pregnancy"}, {"name": "hasDelivery", "domain": "Pregnancy", "range": "Delivery"}]}}, "expected": {"entities_mapped": ["Mother", "Pregnancy", "Delivery"], "relationships_mapped": ["hasPregnancy", "hasDelivery"], "constraints": [{"kind": "dangling_rate", "value": 0.0}, {"kind": "cross_source_reconciliation", "value": true}]}, "tags": ["multi-source", "cross-trust", "reconciliation"]} +{"id": "multi-customer-002", "input": {"tables": [{"name": "crm", "columns": ["customer_id", "full_name", "email"]}, {"name": "erp", "columns": ["CUSTOMER_ID", "NAME", "TAX_CODE"]}], "ontology": {"classes": [{"name": "Customer"}, {"name": "Invoice"}], "properties": [{"name": "hasInvoice", "domain": "Customer", "range": "Invoice"}]}}, "expected": {"entities_mapped": ["Customer", "Invoice"], "relationships_mapped": ["hasInvoice"], "constraints": [{"kind": "dangling_rate", "value": 0.0}, {"kind": "cross_source_reconciliation", "value": true}]}, "tags": ["multi-source", "cross-trust", "reconciliation"]} +{"id": "multi-patient-003", "input": {"tables": [{"name": "ehr", "columns": ["patient_id", "mrn", "dob"]}, {"name": "lab", "columns": ["patient_id", "test_code", "value"]}, {"name": "gp", "columns": ["PATIENT_ID", "NHS_NUMBER", "REG_DATE"]}], "ontology": {"classes": [{"name": 
"Patient"}, {"name": "LabResult"}, {"name": "Observation"}], "properties": [{"name": "hasLabResult", "domain": "Patient", "range": "LabResult"}]}}, "expected": {"entities_mapped": ["Patient", "LabResult", "Observation"], "relationships_mapped": ["hasLabResult"], "constraints": [{"kind": "dangling_rate", "value": 0.0}, {"kind": "cross_source_reconciliation", "value": true}]}, "tags": ["multi-source", "cross-trust", "reconciliation"]} +{"id": "multi-member-004", "input": {"tables": [{"name": "claims", "columns": ["member_id", "provider_id", "amount"]}, {"name": "members", "columns": ["member_id", "name", "plan_id"]}], "ontology": {"classes": [{"name": "Member"}, {"name": "Claim"}, {"name": "Provider"}], "properties": [{"name": "filedBy", "domain": "Claim", "range": "Member"}]}}, "expected": {"entities_mapped": ["Member", "Claim", "Provider"], "relationships_mapped": ["filedBy"], "constraints": [{"kind": "dangling_rate", "value": 0.0}, {"kind": "cross_source_reconciliation", "value": true}]}, "tags": ["multi-source", "cross-trust", "reconciliation"]} +{"id": "multi-account-005", "input": {"tables": [{"name": "trades", "columns": ["TRADE_ID", "ACCOUNT_NO", "SYMBOL"]}, {"name": "positions", "columns": ["account_no", "symbol", "quantity"]}], "ontology": {"classes": [{"name": "Account"}, {"name": "Trade"}, {"name": "Position"}], "properties": [{"name": "onAccount", "domain": "Trade", "range": "Account"}]}}, "expected": {"entities_mapped": ["Account", "Trade", "Position"], "relationships_mapped": ["onAccount"], "constraints": [{"kind": "dangling_rate", "value": 0.0}, {"kind": "cross_source_reconciliation", "value": true}]}, "tags": ["multi-source", "cross-trust", "reconciliation"]} +{"id": "multi-device-006", "input": {"tables": [{"name": "stream", "columns": ["device_id", "ts", "reading"]}, {"name": "registry", "columns": ["DEVICE_ID", "model", "install_date"]}], "ontology": {"classes": [{"name": "Device"}, {"name": "Reading"}], "properties": [{"name": "emits", "domain": 
"Device", "range": "Reading"}]}}, "expected": {"entities_mapped": ["Device", "Reading"], "relationships_mapped": ["emits"], "constraints": [{"kind": "dangling_rate", "value": 0.0}, {"kind": "cross_source_reconciliation", "value": true}]}, "tags": ["multi-source", "cross-trust", "reconciliation"]} +{"id": "multi-product-007", "input": {"tables": [{"name": "sales_us", "columns": ["product_code", "region", "units"]}, {"name": "catalog", "columns": ["product_code", "name", "category"]}], "ontology": {"classes": [{"name": "Product"}, {"name": "Sale"}], "properties": [{"name": "soldAs", "domain": "Product", "range": "Sale"}]}}, "expected": {"entities_mapped": ["Product", "Sale"], "relationships_mapped": ["soldAs"], "constraints": [{"kind": "dangling_rate", "value": 0.0}, {"kind": "cross_source_reconciliation", "value": true}]}, "tags": ["multi-source", "cross-trust", "reconciliation"]} +{"id": "multi-entity-008", "input": {"tables": [{"name": "t1", "columns": ["entity_id", "a"]}, {"name": "t2", "columns": ["entity_id", "b"]}, {"name": "t3", "columns": ["ENTITY_ID", "c"]}], "ontology": {"classes": [{"name": "Entity"}, {"name": "Event"}], "properties": [{"name": "hasEvent", "domain": "Entity", "range": "Event"}]}}, "expected": {"entities_mapped": ["Entity", "Event"], "relationships_mapped": ["hasEvent"], "constraints": [{"kind": "dangling_rate", "value": 0.0}, {"kind": "cross_source_reconciliation", "value": true}]}, "tags": ["multi-source", "cross-trust", "reconciliation"]} +{"id": "degenerate-empty-ontology-019", "input": {"tables": [{"name": "t", "columns": ["id", "x"]}], "ontology": {"classes": [], "properties": []}}, "expected": {"entities_mapped": [], "constraints": [{"kind": "min_entities", "value": 0}]}, "tags": ["degenerate", "empty-ontology"]} +{"id": "degenerate-no-id-column-020", "input": {"tables": [{"name": "freeform", "columns": ["note_text", "label"]}], "ontology": {"classes": [{"name": "Note", "attributes": ["note_text", "label"]}], "properties": []}}, 
"expected": {"entities_mapped": ["Note"], "constraints": [{"kind": "min_entities", "value": 1}]}, "tags": ["degenerate", "no-natural-key"]}